diff --git a/clause/clause.go b/clause/clause.go index 1354fc05..ae4e49f8 100644 --- a/clause/clause.go +++ b/clause/clause.go @@ -64,6 +64,7 @@ func (c Clause) Build(builder Builder) { const ( PrimaryKey string = "~~~py~~~" // primary key + PrimaryKeys string = "~~~ps~~~" // primary keys CurrentTable string = "~~~ct~~~" // current table Associations string = "~~~as~~~" // associations ) diff --git a/finisher_api.go b/finisher_api.go index 6802945c..3d09e899 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -118,7 +118,7 @@ func (db *DB) Save(value interface{}) (tx *DB) { // First finds the first record ordered by primary key, matching given conditions conds func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) { tx = db.Limit(1).Order(clause.OrderByColumn{ - Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, + Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKeys}, }) if len(conds) > 0 { if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 { @@ -146,7 +146,7 @@ func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) { // Last finds the last record ordered by primary key, matching given conditions conds func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) { tx = db.Limit(1).Order(clause.OrderByColumn{ - Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, + Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKeys}, Desc: true, }) if len(conds) > 0 { @@ -173,9 +173,10 @@ func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) { // FindInBatches finds all records in batches of batchSize func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, batch int) error) *DB { + // Use PrimaryKeys to handle composite primary key situations var ( tx = db.Order(clause.OrderByColumn{ - Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, + Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKeys}, }).Session(&Session{}) queryDB = tx rowsAffected int64 @@ -232,12 +233,71 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat break } - primaryValue, zero := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1)) - if zero { - tx.AddError(ErrPrimaryKeyRequired) - break + // The following will build a where clause like this: + // struct { + // col1 uint `gorm:"primaryKey;autoIncrement:false"` + // col2 uint `gorm:"primaryKey;autoIncrement:false"` + // col3 uint `gorm:"primaryKey;autoIncrement:false"` + // } + // last row returned was col1 = 2, col2 = 3, col3 = 5 + // where clause will be generated as follows + // WHERE (col1 > 2 OR (col1 = 2 AND col2 > 3) OR (col1 = 2 AND col2 = 3 AND col3 > 5)) + // Detect composite primary keys + if result.Statement.Schema.PrimaryFields != nil { + pkCount := len(result.Statement.Schema.PrimaryFields) + + // Handle composite primary key Where clauses + if pkCount > 1 { + var f *schema.Field + var orClauses []clause.Expression + for i := 0; i < pkCount; i++ { + var andClauses []clause.Expression + // Build 1st column GT clause + if i == 0 { + f = result.Statement.Schema.PrimaryFields[i] + primaryValue, zero := f.ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1)) + if zero { + tx.AddError(ErrPrimaryKeyRequired) + break + } + orClauses = append(orClauses, clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: f.DBName}, Value: primaryValue}) + } else { + // Build AND clause and append to OR clauses + for j := 0; j <= i; j++ { + f = result.Statement.Schema.PrimaryFields[j] + primaryValue, zero := f.ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1)) + if zero { + tx.AddError(ErrPrimaryKeyRequired) + break + } + if j == i { + // Build current outer column GT clause + andClauses = append(andClauses, clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: f.DBName}, Value: primaryValue}) + } else { + // Build all other columns EQ clause + andClauses = append(andClauses, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: f.DBName}, Value: primaryValue}) + } + } + orClauses = append(orClauses, clause.And(andClauses...)) + } + } + queryDB = tx.Clauses(clause.Or(orClauses...)) + } else { + primaryValue, zero := result.Statement.Schema.PrimaryFields[0].ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1)) + if zero { + tx.AddError(ErrPrimaryKeyRequired) + break + } + queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: result.Statement.Schema.PrimaryFields[0].DBName}, Value: primaryValue}) + } + } else { + primaryValue, zero := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1)) + if zero { + tx.AddError(ErrPrimaryKeyRequired) + break + } + queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue}) } - queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue}) } tx.RowsAffected = rowsAffected @@ -307,7 +367,7 @@ func (db *DB) assignInterfacesToValue(values ...interface{}) { // // user -> User{Name: "jinzhu", Age: 20, Email: "fake@fake.org"} func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { queryTx := db.Limit(1).Order(clause.OrderByColumn{ - Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, + Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKeys}, }) if tx = queryTx.Find(dest, conds...); tx.RowsAffected == 0 { @@ -347,7 +407,7 @@ func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) { tx = db.getInstance() queryTx := db.Session(&Session{}).Limit(1).Order(clause.OrderByColumn{ - Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, + Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKeys}, }) result := queryTx.Find(dest, conds...) diff --git a/statement.go b/statement.go index 39e05d09..c23afe21 100644 --- a/statement.go +++ b/statement.go @@ -105,27 +105,61 @@ func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) { write(v.Raw, v.Alias) } case clause.Column: - if v.Table != "" { - if v.Table == clause.CurrentTable { - write(v.Raw, stmt.Table) - } else { - write(v.Raw, v.Table) - } - writer.WriteByte('.') - } - - if v.Name == clause.PrimaryKey { + // Handle composite primary keys explicitly + if v.Name == clause.PrimaryKeys { if stmt.Schema == nil { stmt.DB.AddError(ErrModelValueRequired) - } else if stmt.Schema.PrioritizedPrimaryField != nil { - write(v.Raw, stmt.Schema.PrioritizedPrimaryField.DBName) + } else if stmt.Schema.PrimaryFields != nil { + for idx, s := range stmt.Schema.PrimaryFieldDBNames { + if idx > 0 { + writer.WriteByte(',') + } + if v.Table != "" { + if v.Table == clause.CurrentTable { + write(v.Raw, stmt.Table) + } else { + write(v.Raw, v.Table) + } + writer.WriteByte('.') + } + write(v.Raw, s) + } } else if len(stmt.Schema.DBNames) > 0 { + if v.Table != "" { + if v.Table == clause.CurrentTable { + write(v.Raw, stmt.Table) + } else { + write(v.Raw, v.Table) + } + writer.WriteByte('.') + } write(v.Raw, stmt.Schema.DBNames[0]) } else { stmt.DB.AddError(ErrModelAccessibleFieldsRequired) //nolint:typecheck,errcheck } } else { - write(v.Raw, v.Name) + if v.Table != "" { + if v.Table == clause.CurrentTable { + write(v.Raw, stmt.Table) + } else { + write(v.Raw, v.Table) + } + writer.WriteByte('.') + } + + if v.Name == clause.PrimaryKey { + if stmt.Schema == nil { + stmt.DB.AddError(ErrModelValueRequired) + } else if stmt.Schema.PrioritizedPrimaryField != nil { + write(v.Raw, stmt.Schema.PrioritizedPrimaryField.DBName) + } else if len(stmt.Schema.DBNames) > 0 { + write(v.Raw, stmt.Schema.DBNames[0]) + } else { + stmt.DB.AddError(ErrModelAccessibleFieldsRequired) //nolint:typecheck,errcheck + } + } else { + write(v.Raw, v.Name) + } } if v.Alias != "" {