diff --git a/clause/clause.go b/clause/clause.go index 1354fc05..ae4e49f8 100644 --- a/clause/clause.go +++ b/clause/clause.go @@ -64,6 +64,7 @@ func (c Clause) Build(builder Builder) { const ( PrimaryKey string = "~~~py~~~" // primary key + PrimaryKeys string = "~~~ps~~~" // primary keys CurrentTable string = "~~~ct~~~" // current table Associations string = "~~~as~~~" // associations ) diff --git a/finisher_api.go b/finisher_api.go index 57809d17..c90e5ccd 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -119,7 +119,7 @@ func (db *DB) Save(value interface{}) (tx *DB) { // First finds the first record ordered by primary key, matching given conditions conds func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) { tx = db.Limit(1).Order(clause.OrderByColumn{ - Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, + Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKeys}, }) if len(conds) > 0 { if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 { @@ -147,7 +147,7 @@ func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) { // Last finds the last record ordered by primary key, matching given conditions conds func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) { tx = db.Limit(1).Order(clause.OrderByColumn{ - Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, + Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKeys}, Desc: true, }) if len(conds) > 0 { @@ -174,9 +174,10 @@ func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) { // FindInBatches finds all records in batches of batchSize func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, batch int) error) *DB { + // Use PrimaryKeys to handle composite primary key situations var ( tx = db.Order(clause.OrderByColumn{ - Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, + Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKeys}, }).Session(&Session{}) queryDB = tx rowsAffected int64 @@ -200,6 +201,7 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat } } +find: for { result := queryDB.Limit(batchSize).Find(dest) rowsAffected += result.RowsAffected @@ -228,17 +230,76 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat // Optimize for-break resultsValue := reflect.Indirect(reflect.ValueOf(dest)) - if result.Statement.Schema.PrioritizedPrimaryField == nil { + if result.Statement.Schema.PrioritizedPrimaryField == nil && result.Statement.Schema.PrimaryFields != nil && len(result.Statement.Schema.PrimaryFields) == 1 { tx.AddError(ErrPrimaryKeyRequired) break } - primaryValue, zero := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1)) - if zero { - tx.AddError(ErrPrimaryKeyRequired) - break + // The following will build a where clause like this: + // struct { + // col1 uint `gorm:"primaryKey;autoIncrement:false"` + // col2 uint `gorm:"primaryKey;autoIncrement:false"` + // col3 uint `gorm:"primaryKey;autoIncrement:false"` + // } + // last row returned was col1 = 2, col2 = 3, col3 = 5 + // where clause will be generated as follows + // WHERE (col1 > 2 OR (col1 = 2 AND col2 > 3) OR (col1 = 2 AND col2 = 3 AND col3 > 5)) + // Detect composite primary keys + if result.Statement.Schema.PrimaryFields != nil { + pkCount := len(result.Statement.Schema.PrimaryFields) + + // Handle composite primary key Where clauses + if pkCount > 1 { + var f *schema.Field + var orClauses []clause.Expression + for i := 0; i < pkCount; i++ { + var andClauses []clause.Expression + // Build 1st column GT clause + if i == 0 { + f = result.Statement.Schema.PrimaryFields[i] + primaryValue, zero := f.ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1)) + if zero { + tx.AddError(ErrPrimaryKeyRequired) //nolint:typecheck,errcheck,gosec + break find + } + orClauses = append(orClauses, clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: f.DBName}, Value: primaryValue}) + } else { + // Build AND clause and append to OR clauses + for j := 0; j <= i; j++ { + f = result.Statement.Schema.PrimaryFields[j] + primaryValue, zero := f.ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1)) + if zero { + tx.AddError(ErrPrimaryKeyRequired) //nolint:typecheck,errcheck,gosec + break find + } + if j == i { + // Build current outer column GT clause + andClauses = append(andClauses, clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: f.DBName}, Value: primaryValue}) + } else { + // Build all other columns EQ clause + andClauses = append(andClauses, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: f.DBName}, Value: primaryValue}) + } + } + orClauses = append(orClauses, clause.And(andClauses...)) + } + } + queryDB = tx.Clauses(clause.Or(orClauses...)) + } else { + primaryValue, zero := result.Statement.Schema.PrimaryFields[0].ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1)) + if zero { + tx.AddError(ErrPrimaryKeyRequired) //nolint:typecheck,errcheck,gosec + break + } + queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: result.Statement.Schema.PrimaryFields[0].DBName}, Value: primaryValue}) + } + } else { + primaryValue, zero := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1)) + if zero { + tx.AddError(ErrPrimaryKeyRequired) //nolint:typecheck,errcheck,gosec + break + } + queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue}) } - queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue}) } tx.RowsAffected = rowsAffected @@ -308,7 +369,7 @@ func (db *DB) assignInterfacesToValue(values ...interface{}) { // // user -> User{Name: "jinzhu", Age: 20, Email: "fake@fake.org"} func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { queryTx := db.Limit(1).Order(clause.OrderByColumn{ - Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, + Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKeys}, }) if tx = queryTx.Find(dest, conds...); tx.RowsAffected == 0 { @@ -348,7 +409,7 @@ func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) { tx = db.getInstance() queryTx := db.Session(&Session{}).Limit(1).Order(clause.OrderByColumn{ - Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, + Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKeys}, }) result := queryTx.Find(dest, conds...) diff --git a/statement.go b/statement.go index 74feaedd..fc745c4e 100644 --- a/statement.go +++ b/statement.go @@ -108,27 +108,62 @@ func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) { write(v.Raw, v.Alias) } case clause.Column: - if v.Table != "" { - if v.Table == clause.CurrentTable { - write(v.Raw, stmt.Table) - } else { - write(v.Raw, v.Table) - } - writer.WriteByte('.') - } - - if v.Name == clause.PrimaryKey { + // Handle composite primary keys explicitly + if v.Name == clause.PrimaryKeys { if stmt.Schema == nil { stmt.DB.AddError(ErrModelValueRequired) - } else if stmt.Schema.PrioritizedPrimaryField != nil { - write(v.Raw, stmt.Schema.PrioritizedPrimaryField.DBName) + } else if stmt.Schema.PrimaryFields != nil { + for idx, s := range stmt.Schema.PrimaryFieldDBNames { + if idx > 0 { + writer.WriteByte(',') //nolint:typecheck,errcheck,gosec + } + if v.Table != "" { + if v.Table == clause.CurrentTable { + write(v.Raw, stmt.Table) + } else { + write(v.Raw, v.Table) + } + writer.WriteByte('.') //nolint:typecheck,errcheck,gosec + } + write(v.Raw, s) + } } else if len(stmt.Schema.DBNames) > 0 { + if v.Table != "" { + if v.Table == clause.CurrentTable { + write(v.Raw, stmt.Table) + } else { + write(v.Raw, v.Table) + } + writer.WriteByte('.') //nolint:typecheck,errcheck,gosec + } write(v.Raw, stmt.Schema.DBNames[0]) } else { stmt.DB.AddError(ErrModelAccessibleFieldsRequired) //nolint:typecheck,errcheck } } else { - write(v.Raw, v.Name) + if v.Table != "" { + if v.Table == clause.CurrentTable { + write(v.Raw, stmt.Table) + } else { + write(v.Raw, v.Table) + } + writer.WriteByte('.') //nolint:typecheck,errcheck,gosec + } + + if v.Name == clause.PrimaryKey { + switch { + case stmt.Schema == nil: + stmt.DB.AddError(ErrModelValueRequired) //nolint:typecheck,errcheck,gosec,staticcheck + case stmt.Schema.PrioritizedPrimaryField != nil: + write(v.Raw, stmt.Schema.PrioritizedPrimaryField.DBName) + case len(stmt.Schema.DBNames) > 0: + write(v.Raw, stmt.Schema.DBNames[0]) + default: + stmt.DB.AddError(ErrModelAccessibleFieldsRequired) //nolint:typecheck,errcheck,gosec,staticcheck + } + } else { + write(v.Raw, v.Name) + } } if v.Alias != "" { diff --git a/tests/query_test.go b/tests/query_test.go index 6151855e..533ddddd 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -418,6 +418,75 @@ func TestFindInBatchesWithError(t *testing.T) { } } +func TestFindInBatchesCompositeKey(t *testing.T) { + coupons := []Coupon{ + {AmountOff: 1.0, PercentOff: 0.5, AppliesToProduct: []*CouponProduct{ + {ProductId: "1", Desc: "find_in_batches"}, + {ProductId: "2", Desc: "find_in_batches"}, + {ProductId: "3", Desc: "find_in_batches"}, + }}, + {AmountOff: 2.0, PercentOff: 0.5, AppliesToProduct: []*CouponProduct{ + {ProductId: "1", Desc: "find_in_batches"}, + {ProductId: "2", Desc: "find_in_batches"}, + {ProductId: "3", Desc: "find_in_batches"}, + }}, + {AmountOff: 3.0, PercentOff: 0.5, AppliesToProduct: []*CouponProduct{ + {ProductId: "1", Desc: "find_in_batches"}, + {ProductId: "2", Desc: "find_in_batches"}, + {ProductId: "3", Desc: "find_in_batches"}, + }}, + {AmountOff: 4.0, PercentOff: 0.5, AppliesToProduct: []*CouponProduct{ + {ProductId: "1", Desc: "find_in_batches"}, + {ProductId: "2", Desc: "find_in_batches"}, + {ProductId: "3", Desc: "find_in_batches"}, + }}, + } + + DB.Create(&coupons) + + var ( + results []CouponProduct + lastBatch int + ) + + if result := DB.Table("coupon_products as cp").Where(&CouponProduct{Desc: "find_in_batches"}).FindInBatches(&results, 2, func(tx *gorm.DB, batch int) error { + lastBatch = batch + + if tx.RowsAffected != 2 { + t.Errorf("Incorrect affected rows, expects: 2, got %v", tx.RowsAffected) + } + + if len(results) != 2 { + t.Errorf("Incorrect coupon_product length, expects: 2, got %v", len(results)) + } + + for idx := range results { + results[idx].Desc = results[idx].Desc + "_new" + } + + if err := tx.Save(results).Error; err != nil { + t.Fatalf("failed to save coupon_product, got error %v", err) + } + + return nil + }); result.Error != nil || result.RowsAffected != 12 { + t.Errorf("Failed to batch find, got error %v, rows affected: %v", result.Error, result.RowsAffected) + } + + if lastBatch != 6 { + t.Errorf("incorrect final batch, expects: %v, got %v", 6, lastBatch) + } + + var count int64 + DB.Model(&CouponProduct{}).Where(&CouponProduct{Desc: "find_in_batches_new"}).Count(&count) + if count != 12 { + t.Errorf("incorrect count after update, expects: %v, got %v", 12, count) + } + + DB.Unscoped().Where(&CouponProduct{Desc: "find_in_batches_new"}).Delete(&CouponProduct{}) + DB.Unscoped().Where("id in (1,2,3)").Delete(&Coupon{}) +} + func TestFillSmallerStruct(t *testing.T) { user := User{Name: "SmallerUser", Age: 100} DB.Save(&user)