From 881bd7747b95ebc13363c026bf745765acadc638 Mon Sep 17 00:00:00 2001 From: Keith Martin Date: Mon, 12 May 2025 21:36:41 +1000 Subject: [PATCH 1/5] Add handling for Composite Primary Keys to First, Last, FindInBatches, FirstOrInit and FirstOrCreate. --- clause/clause.go | 1 + finisher_api.go | 80 ++++++++++++++++++++++++++++++++++++++++++------ statement.go | 60 ++++++++++++++++++++++++++++-------- 3 files changed, 118 insertions(+), 23 deletions(-) diff --git a/clause/clause.go b/clause/clause.go index 1354fc05..ae4e49f8 100644 --- a/clause/clause.go +++ b/clause/clause.go @@ -64,6 +64,7 @@ func (c Clause) Build(builder Builder) { const ( PrimaryKey string = "~~~py~~~" // primary key + PrimaryKeys string = "~~~ps~~~" // primary keys CurrentTable string = "~~~ct~~~" // current table Associations string = "~~~as~~~" // associations ) diff --git a/finisher_api.go b/finisher_api.go index 6802945c..3d09e899 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -118,7 +118,7 @@ func (db *DB) Save(value interface{}) (tx *DB) { // First finds the first record ordered by primary key, matching given conditions conds func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) { tx = db.Limit(1).Order(clause.OrderByColumn{ - Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, + Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKeys}, }) if len(conds) > 0 { if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 { @@ -146,7 +146,7 @@ func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) { // Last finds the last record ordered by primary key, matching given conditions conds func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) { tx = db.Limit(1).Order(clause.OrderByColumn{ - Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, + Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKeys}, Desc: true, }) if len(conds) > 0 { @@ -173,9 +173,10 @@ func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) { // FindInBatches finds all records in batches of batchSize func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, batch int) error) *DB { + // Use PrimaryKeys to handle composite primary key situations var ( tx = db.Order(clause.OrderByColumn{ - Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, + Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKeys}, }).Session(&Session{}) queryDB = tx rowsAffected int64 @@ -232,12 +233,71 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat break } - primaryValue, zero := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1)) - if zero { - tx.AddError(ErrPrimaryKeyRequired) - break + // The following will build a where clause like this: + // struct { + // col1 uint `gorm:"primaryKey;autoIncrement:false"` + // col2 uint `gorm:"primaryKey;autoIncrement:false"` + // col3 uint `gorm:"primaryKey;autoIncrement:false"` + // } + // last row returned was col1 = 2, col2 = 3, col3 = 5 + // where clause will be generated as follows + // WHERE (col1 > 2 OR (col1 = 2 AND col2 > 3) OR (col1 = 2 AND col2 = 3 AND col3 > 5)) + // Detect composite primary keys + if result.Statement.Schema.PrimaryFields != nil { + pkCount := len(result.Statement.Schema.PrimaryFields) + + // Handle composite primary key Where clauses + if pkCount > 1 { + var f *schema.Field + var orClauses []clause.Expression + for i := 0; i < pkCount; i++ { + var andClauses []clause.Expression + // Build 1st column GT clause + if i == 0 { + f = result.Statement.Schema.PrimaryFields[i] + primaryValue, zero := f.ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1)) + if zero { + tx.AddError(ErrPrimaryKeyRequired) + break + } + orClauses = append(orClauses, clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: f.DBName}, Value: primaryValue}) + } else { + // Build AND clause and append to OR clauses + for j := 0; j <= i; j++ { + f = result.Statement.Schema.PrimaryFields[j] + primaryValue, zero := f.ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1)) + if zero { + tx.AddError(ErrPrimaryKeyRequired) + break + } + if j == i { + // Build current outer column GT clause + andClauses = append(andClauses, clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: f.DBName}, Value: primaryValue}) + } else { + // Build all other columns EQ clause + andClauses = append(andClauses, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: f.DBName}, Value: primaryValue}) + } + } + orClauses = append(orClauses, clause.And(andClauses...)) + } + } + queryDB = tx.Clauses(clause.Or(orClauses...)) + } else { + primaryValue, zero := result.Statement.Schema.PrimaryFields[0].ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1)) + if zero { + tx.AddError(ErrPrimaryKeyRequired) + break + } + queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: result.Statement.Schema.PrimaryFields[0].DBName}, Value: primaryValue}) + } + } else { + primaryValue, zero := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1)) + if zero { + tx.AddError(ErrPrimaryKeyRequired) + break + } + queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue}) } - queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue}) } tx.RowsAffected = rowsAffected @@ -307,7 +367,7 @@ func (db *DB) assignInterfacesToValue(values ...interface{}) { // // user -> User{Name: "jinzhu", Age: 20, Email: "fake@fake.org"} func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { queryTx := db.Limit(1).Order(clause.OrderByColumn{ - Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, + Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKeys}, }) if tx = queryTx.Find(dest, conds...); tx.RowsAffected == 0 { @@ -347,7 +407,7 @@ func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) { tx = db.getInstance() queryTx := db.Session(&Session{}).Limit(1).Order(clause.OrderByColumn{ - Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, + Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKeys}, }) result := queryTx.Find(dest, conds...) diff --git a/statement.go b/statement.go index 39e05d09..c23afe21 100644 --- a/statement.go +++ b/statement.go @@ -105,27 +105,61 @@ func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) { write(v.Raw, v.Alias) } case clause.Column: - if v.Table != "" { - if v.Table == clause.CurrentTable { - write(v.Raw, stmt.Table) - } else { - write(v.Raw, v.Table) - } - writer.WriteByte('.') - } - - if v.Name == clause.PrimaryKey { + // Handle composite primary keys explicitly + if v.Name == clause.PrimaryKeys { if stmt.Schema == nil { stmt.DB.AddError(ErrModelValueRequired) - } else if stmt.Schema.PrioritizedPrimaryField != nil { - write(v.Raw, stmt.Schema.PrioritizedPrimaryField.DBName) + } else if stmt.Schema.PrimaryFields != nil { + for idx, s := range stmt.Schema.PrimaryFieldDBNames { + if idx > 0 { + writer.WriteByte(',') + } + if v.Table != "" { + if v.Table == clause.CurrentTable { + write(v.Raw, stmt.Table) + } else { + write(v.Raw, v.Table) + } + writer.WriteByte('.') + } + write(v.Raw, s) + } } else if len(stmt.Schema.DBNames) > 0 { + if v.Table != "" { + if v.Table == clause.CurrentTable { + write(v.Raw, stmt.Table) + } else { + write(v.Raw, v.Table) + } + writer.WriteByte('.') + } write(v.Raw, stmt.Schema.DBNames[0]) } else { stmt.DB.AddError(ErrModelAccessibleFieldsRequired) //nolint:typecheck,errcheck } } else { - write(v.Raw, v.Name) + if v.Table != "" { + if v.Table == clause.CurrentTable { + write(v.Raw, stmt.Table) + } else { + write(v.Raw, v.Table) + } + writer.WriteByte('.') + } + + if v.Name == clause.PrimaryKey { + if stmt.Schema == nil { + stmt.DB.AddError(ErrModelValueRequired) + } else if stmt.Schema.PrioritizedPrimaryField != nil { + write(v.Raw, stmt.Schema.PrioritizedPrimaryField.DBName) + } else if len(stmt.Schema.DBNames) > 0 { + write(v.Raw, stmt.Schema.DBNames[0]) + } else { + stmt.DB.AddError(ErrModelAccessibleFieldsRequired) //nolint:typecheck,errcheck + } + } else { + write(v.Raw, v.Name) + } } if v.Alias != "" { From f121f183b08dbc21c4cd875268fcfd0f40f3cbe7 Mon Sep 17 00:00:00 2001 From: Keith Martin Date: Mon, 12 May 2025 21:38:37 +1000 Subject: [PATCH 2/5] Add test for Composite Key FindInBatches --- tests/query_test.go | 69 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 69 insertions(+) diff --git a/tests/query_test.go b/tests/query_test.go index 566763c5..a627f4c5 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -418,6 +418,75 @@ func TestFindInBatchesWithError(t *testing.T) { } } +func TestFindInBatchesCompositeKey(t *testing.T) { + coupons := []Coupon{ + {AmountOff: 1.0, PercentOff: 0.5, AppliesToProduct: []*CouponProduct{ + {ProductId: "1", Desc: "find_in_batches"}, + {ProductId: "2", Desc: "find_in_batches"}, + {ProductId: "3", Desc: "find_in_batches"}, + }}, + {AmountOff: 2.0, PercentOff: 0.5, AppliesToProduct: []*CouponProduct{ + {ProductId: "1", Desc: "find_in_batches"}, + {ProductId: "2", Desc: "find_in_batches"}, + {ProductId: "3", Desc: "find_in_batches"}, + }}, + {AmountOff: 3.0, PercentOff: 0.5, AppliesToProduct: []*CouponProduct{ + {ProductId: "1", Desc: "find_in_batches"}, + {ProductId: "2", Desc: "find_in_batches"}, + {ProductId: "3", Desc: "find_in_batches"}, + }}, + {AmountOff: 4.0, PercentOff: 0.5, AppliesToProduct: []*CouponProduct{ + {ProductId: "1", Desc: "find_in_batches"}, + {ProductId: "2", Desc: "find_in_batches"}, + {ProductId: "3", Desc: "find_in_batches"}, + }}, + } + + DB.Create(&coupons) + + var ( + results []CouponProduct + lastBatch int + ) + + if result := DB.Table("coupon_products as cp").Where(&CouponProduct{Desc: "find_in_batches"}).FindInBatches(&results, 2, func(tx *gorm.DB, batch int) error { + lastBatch = batch + + if tx.RowsAffected != 2 { + t.Errorf("Incorrect affected rows, expects: 2, got %v", tx.RowsAffected) + } + + if len(results) != 2 { + t.Errorf("Incorrect coupon_product length, expects: 2, got %v", len(results)) + } + + for idx := range results { + results[idx].Desc = results[idx].Desc + "_new" + } + + if err := tx.Save(results).Error; err != nil { + t.Fatalf("failed to save coupon_product, got error %v", err) + } + + return nil + }); result.Error != nil || result.RowsAffected != 12 { + t.Errorf("Failed to batch find, got error %v, rows affected: %v", result.Error, result.RowsAffected) + } + + if lastBatch != 6 { + t.Errorf("incorrect final batch, expects: %v, got %v", 6, lastBatch) + } + + var count int64 + DB.Model(&CouponProduct{}).Where(&CouponProduct{Desc: "find_in_batches_new"}).Count(&count) + if count != 12 { + t.Errorf("incorrect count after update, expects: %v, got %v", 12, count) + } + + DB.Unscoped().Where(&CouponProduct{Desc: "find_in_batches_new"}).Delete(&CouponProduct{}) + DB.Unscoped().Where("id in (1,2,3)").Delete(&Coupon{}) +} + func TestFillSmallerStruct(t *testing.T) { user := User{Name: "SmallerUser", Age: 100} DB.Save(&user) From 0fe079686b469e3d5b2f65dbdc48a797deb20dfe Mon Sep 17 00:00:00 2001 From: Keith Martin Date: Mon, 12 May 2025 22:21:23 +1000 Subject: [PATCH 3/5] Correct check for single column pk to handle multi column pk --- finisher_api.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/finisher_api.go b/finisher_api.go index 3d09e899..3c0df360 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -228,7 +228,7 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat // Optimize for-break resultsValue := reflect.Indirect(reflect.ValueOf(dest)) - if result.Statement.Schema.PrioritizedPrimaryField == nil { + if result.Statement.Schema.PrioritizedPrimaryField == nil && result.Statement.Schema.PrimaryFields != nil && len(result.Statement.Schema.PrimaryFields) == 1 { tx.AddError(ErrPrimaryKeyRequired) break } From c9984634acb437009e9abdea8a8ed9fa37c22d06 Mon Sep 17 00:00:00 2001 From: Keith Martin Date: Mon, 12 May 2025 23:05:34 +1000 Subject: [PATCH 4/5] Address lint issues --- finisher_api.go | 11 ++++++----- statement.go | 19 ++++++++++--------- 2 files changed, 16 insertions(+), 14 deletions(-) diff --git a/finisher_api.go b/finisher_api.go index 3c0df360..8773236f 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -200,6 +200,7 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat } } +find: for { result := queryDB.Limit(batchSize).Find(dest) rowsAffected += result.RowsAffected @@ -257,8 +258,8 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat f = result.Statement.Schema.PrimaryFields[i] primaryValue, zero := f.ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1)) if zero { - tx.AddError(ErrPrimaryKeyRequired) - break + tx.AddError(ErrPrimaryKeyRequired) //nolint:typecheck,errcheck,gosec + break find } orClauses = append(orClauses, clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: f.DBName}, Value: primaryValue}) } else { @@ -267,8 +268,8 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat f = result.Statement.Schema.PrimaryFields[j] primaryValue, zero := f.ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1)) if zero { - tx.AddError(ErrPrimaryKeyRequired) - break + tx.AddError(ErrPrimaryKeyRequired) //nolint:typecheck,errcheck,gosec + break find } if j == i { // Build current outer column GT clause @@ -285,7 +286,7 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat } else { primaryValue, zero := result.Statement.Schema.PrimaryFields[0].ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1)) if zero { - tx.AddError(ErrPrimaryKeyRequired) + tx.AddError(ErrPrimaryKeyRequired) //nolint:typecheck,errcheck,gosec break } queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: result.Statement.Schema.PrimaryFields[0].DBName}, Value: primaryValue}) diff --git a/statement.go b/statement.go index c23afe21..eddd8dee 100644 --- a/statement.go +++ b/statement.go @@ -112,7 +112,7 @@ func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) { } else if stmt.Schema.PrimaryFields != nil { for idx, s := range stmt.Schema.PrimaryFieldDBNames { if idx > 0 { - writer.WriteByte(',') + writer.WriteByte(',') //nolint:typecheck,errcheck,gosec } if v.Table != "" { if v.Table == clause.CurrentTable { @@ -120,7 +120,7 @@ func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) { } else { write(v.Raw, v.Table) } - writer.WriteByte('.') + writer.WriteByte('.') //nolint:typecheck,errcheck,gosec } write(v.Raw, s) } @@ -131,7 +131,7 @@ func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) { } else { write(v.Raw, v.Table) } - writer.WriteByte('.') + writer.WriteByte('.') //nolint:typecheck,errcheck,gosec } write(v.Raw, stmt.Schema.DBNames[0]) } else { @@ -148,14 +148,15 @@ func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) { } if v.Name == clause.PrimaryKey { - if stmt.Schema == nil { - stmt.DB.AddError(ErrModelValueRequired) - } else if stmt.Schema.PrioritizedPrimaryField != nil { + switch { + case stmt.Schema == nil: + stmt.DB.AddError(ErrModelValueRequired) //nolint:typecheck,errcheck,gosec + case stmt.Schema.PrioritizedPrimaryField != nil: write(v.Raw, stmt.Schema.PrioritizedPrimaryField.DBName) - } else if len(stmt.Schema.DBNames) > 0 { + case len(stmt.Schema.DBNames) > 0: write(v.Raw, stmt.Schema.DBNames[0]) - } else { - stmt.DB.AddError(ErrModelAccessibleFieldsRequired) //nolint:typecheck,errcheck + default: + stmt.DB.AddError(ErrModelAccessibleFieldsRequired) //nolint:typecheck,errcheck,gosec } } else { write(v.Raw, v.Name) From 904b1289a4e883b6a43dc667c49ebd0178b57a7f Mon Sep 17 00:00:00 2001 From: Keith Martin Date: Mon, 12 May 2025 23:16:14 +1000 Subject: [PATCH 5/5] Address lint issues --- finisher_api.go | 2 +- statement.go | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/finisher_api.go b/finisher_api.go index 8773236f..8159c2bb 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -294,7 +294,7 @@ find: } else { primaryValue, zero := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1)) if zero { - tx.AddError(ErrPrimaryKeyRequired) + tx.AddError(ErrPrimaryKeyRequired) //nolint:typecheck,errcheck,gosec break } queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue}) diff --git a/statement.go b/statement.go index eddd8dee..c35e0b3a 100644 --- a/statement.go +++ b/statement.go @@ -144,19 +144,19 @@ func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) { } else { write(v.Raw, v.Table) } - writer.WriteByte('.') + writer.WriteByte('.') //nolint:typecheck,errcheck,gosec } if v.Name == clause.PrimaryKey { switch { case stmt.Schema == nil: - stmt.DB.AddError(ErrModelValueRequired) //nolint:typecheck,errcheck,gosec + stmt.DB.AddError(ErrModelValueRequired) //nolint:typecheck,errcheck,gosec,staticcheck case stmt.Schema.PrioritizedPrimaryField != nil: write(v.Raw, stmt.Schema.PrioritizedPrimaryField.DBName) case len(stmt.Schema.DBNames) > 0: write(v.Raw, stmt.Schema.DBNames[0]) default: - stmt.DB.AddError(ErrModelAccessibleFieldsRequired) //nolint:typecheck,errcheck,gosec + stmt.DB.AddError(ErrModelAccessibleFieldsRequired) //nolint:typecheck,errcheck,gosec,staticcheck } } else { write(v.Raw, v.Name)