Add handling for Composite Primary Keys to First, Last, FindInBatches, FirstOrInit and FirstOrCreate.

This commit is contained in:
Keith Martin 2025-05-12 21:36:41 +10:00
parent e5b867e785
commit 881bd7747b
3 changed files with 118 additions and 23 deletions

View File

@ -64,6 +64,7 @@ func (c Clause) Build(builder Builder) {
const ( const (
PrimaryKey string = "~~~py~~~" // primary key PrimaryKey string = "~~~py~~~" // primary key
PrimaryKeys string = "~~~ps~~~" // primary keys
CurrentTable string = "~~~ct~~~" // current table CurrentTable string = "~~~ct~~~" // current table
Associations string = "~~~as~~~" // associations Associations string = "~~~as~~~" // associations
) )

View File

@ -118,7 +118,7 @@ func (db *DB) Save(value interface{}) (tx *DB) {
// First finds the first record ordered by primary key, matching given conditions conds // First finds the first record ordered by primary key, matching given conditions conds
func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) { func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) {
tx = db.Limit(1).Order(clause.OrderByColumn{ tx = db.Limit(1).Order(clause.OrderByColumn{
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKeys},
}) })
if len(conds) > 0 { if len(conds) > 0 {
if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 { if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
@ -146,7 +146,7 @@ func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) {
// Last finds the last record ordered by primary key, matching given conditions conds // Last finds the last record ordered by primary key, matching given conditions conds
func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) { func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) {
tx = db.Limit(1).Order(clause.OrderByColumn{ tx = db.Limit(1).Order(clause.OrderByColumn{
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKeys},
Desc: true, Desc: true,
}) })
if len(conds) > 0 { if len(conds) > 0 {
@ -173,9 +173,10 @@ func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) {
// FindInBatches finds all records in batches of batchSize // FindInBatches finds all records in batches of batchSize
func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, batch int) error) *DB { func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, batch int) error) *DB {
// Use PrimaryKeys to handle composite primary key situations
var ( var (
tx = db.Order(clause.OrderByColumn{ tx = db.Order(clause.OrderByColumn{
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKeys},
}).Session(&Session{}) }).Session(&Session{})
queryDB = tx queryDB = tx
rowsAffected int64 rowsAffected int64
@ -232,12 +233,71 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat
break break
} }
primaryValue, zero := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1)) // The following will build a where clause like this:
if zero { // struct {
tx.AddError(ErrPrimaryKeyRequired) // col1 uint `gorm:"primaryKey;autoIncrement:false"`
break // col2 uint `gorm:"primaryKey;autoIncrement:false"`
// col3 uint `gorm:"primaryKey;autoIncrement:false"`
// }
// last row returned was col1 = 2, col2 = 3, col3 = 5
// where clause will be generated as follows
// WHERE (col1 > 2 OR (col1 = 2 AND col2 > 3) OR (col1 = 2 AND col2 = 3 AND col3 > 5))
// Detect composite primary keys
if result.Statement.Schema.PrimaryFields != nil {
pkCount := len(result.Statement.Schema.PrimaryFields)
// Handle composite primary key Where clauses
if pkCount > 1 {
var f *schema.Field
var orClauses []clause.Expression
for i := 0; i < pkCount; i++ {
var andClauses []clause.Expression
// Build 1st column GT clause
if i == 0 {
f = result.Statement.Schema.PrimaryFields[i]
primaryValue, zero := f.ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1))
if zero {
tx.AddError(ErrPrimaryKeyRequired)
break
}
orClauses = append(orClauses, clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: f.DBName}, Value: primaryValue})
} else {
// Build AND clause and append to OR clauses
for j := 0; j <= i; j++ {
f = result.Statement.Schema.PrimaryFields[j]
primaryValue, zero := f.ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1))
if zero {
tx.AddError(ErrPrimaryKeyRequired)
break
}
if j == i {
// Build current outer column GT clause
andClauses = append(andClauses, clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: f.DBName}, Value: primaryValue})
} else {
// Build all other columns EQ clause
andClauses = append(andClauses, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: f.DBName}, Value: primaryValue})
}
}
orClauses = append(orClauses, clause.And(andClauses...))
}
}
queryDB = tx.Clauses(clause.Or(orClauses...))
} else {
primaryValue, zero := result.Statement.Schema.PrimaryFields[0].ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1))
if zero {
tx.AddError(ErrPrimaryKeyRequired)
break
}
queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: result.Statement.Schema.PrimaryFields[0].DBName}, Value: primaryValue})
}
} else {
primaryValue, zero := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1))
if zero {
tx.AddError(ErrPrimaryKeyRequired)
break
}
queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue})
} }
queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue})
} }
tx.RowsAffected = rowsAffected tx.RowsAffected = rowsAffected
@ -307,7 +367,7 @@ func (db *DB) assignInterfacesToValue(values ...interface{}) {
// // user -> User{Name: "jinzhu", Age: 20, Email: "fake@fake.org"} // // user -> User{Name: "jinzhu", Age: 20, Email: "fake@fake.org"}
func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) {
queryTx := db.Limit(1).Order(clause.OrderByColumn{ queryTx := db.Limit(1).Order(clause.OrderByColumn{
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKeys},
}) })
if tx = queryTx.Find(dest, conds...); tx.RowsAffected == 0 { if tx = queryTx.Find(dest, conds...); tx.RowsAffected == 0 {
@ -347,7 +407,7 @@ func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) {
func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) { func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
queryTx := db.Session(&Session{}).Limit(1).Order(clause.OrderByColumn{ queryTx := db.Session(&Session{}).Limit(1).Order(clause.OrderByColumn{
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKeys},
}) })
result := queryTx.Find(dest, conds...) result := queryTx.Find(dest, conds...)

View File

@ -105,27 +105,61 @@ func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) {
write(v.Raw, v.Alias) write(v.Raw, v.Alias)
} }
case clause.Column: case clause.Column:
if v.Table != "" { // Handle composite primary keys explicitly
if v.Table == clause.CurrentTable { if v.Name == clause.PrimaryKeys {
write(v.Raw, stmt.Table)
} else {
write(v.Raw, v.Table)
}
writer.WriteByte('.')
}
if v.Name == clause.PrimaryKey {
if stmt.Schema == nil { if stmt.Schema == nil {
stmt.DB.AddError(ErrModelValueRequired) stmt.DB.AddError(ErrModelValueRequired)
} else if stmt.Schema.PrioritizedPrimaryField != nil { } else if stmt.Schema.PrimaryFields != nil {
write(v.Raw, stmt.Schema.PrioritizedPrimaryField.DBName) for idx, s := range stmt.Schema.PrimaryFieldDBNames {
if idx > 0 {
writer.WriteByte(',')
}
if v.Table != "" {
if v.Table == clause.CurrentTable {
write(v.Raw, stmt.Table)
} else {
write(v.Raw, v.Table)
}
writer.WriteByte('.')
}
write(v.Raw, s)
}
} else if len(stmt.Schema.DBNames) > 0 { } else if len(stmt.Schema.DBNames) > 0 {
if v.Table != "" {
if v.Table == clause.CurrentTable {
write(v.Raw, stmt.Table)
} else {
write(v.Raw, v.Table)
}
writer.WriteByte('.')
}
write(v.Raw, stmt.Schema.DBNames[0]) write(v.Raw, stmt.Schema.DBNames[0])
} else { } else {
stmt.DB.AddError(ErrModelAccessibleFieldsRequired) //nolint:typecheck,errcheck stmt.DB.AddError(ErrModelAccessibleFieldsRequired) //nolint:typecheck,errcheck
} }
} else { } else {
write(v.Raw, v.Name) if v.Table != "" {
if v.Table == clause.CurrentTable {
write(v.Raw, stmt.Table)
} else {
write(v.Raw, v.Table)
}
writer.WriteByte('.')
}
if v.Name == clause.PrimaryKey {
if stmt.Schema == nil {
stmt.DB.AddError(ErrModelValueRequired)
} else if stmt.Schema.PrioritizedPrimaryField != nil {
write(v.Raw, stmt.Schema.PrioritizedPrimaryField.DBName)
} else if len(stmt.Schema.DBNames) > 0 {
write(v.Raw, stmt.Schema.DBNames[0])
} else {
stmt.DB.AddError(ErrModelAccessibleFieldsRequired) //nolint:typecheck,errcheck
}
} else {
write(v.Raw, v.Name)
}
} }
if v.Alias != "" { if v.Alias != "" {