From dc3f2394b72a13c94faa0893f155ac4893ee85df Mon Sep 17 00:00:00 2001 From: sirius <916108538@qq.com> Date: Thu, 17 Apr 2025 11:43:14 +0800 Subject: [PATCH 1/2] FIX:Optimistic Lock in BeforeUpdate: PK Condition Placement Affecting DB Plan Efficiency --- callbacks/update.go | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/callbacks/update.go b/callbacks/update.go index 7cde7f61..be18e181 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -259,6 +259,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { switch updatingValue.Kind() { case reflect.Struct: set = make([]clause.Assignment, 0, len(stmt.Schema.FieldsByDBName)) + priExpr := make([]clause.Expression, 0) for _, dbName := range stmt.Schema.DBNames { if field := updatingSchema.LookUpField(dbName); field != nil { if !field.PrimaryKey || !updatingValue.CanAddr() || stmt.Dest != stmt.Model { @@ -290,11 +291,26 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { } } else { if value, isZero := field.ValueOf(stmt.Context, updatingValue); !isZero { - stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}}) + // stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}}) + priExpr = append(priExpr, clause.Eq{Column: field.DBName, Value: value}) + } } } } + if len(priExpr) > 0 { + where := clause.Where{Exprs: priExpr} + wname := where.Name() + existWc := stmt.Clauses[wname] + existWc.Name = wname + if existingWhere, ok := existWc.Expression.(clause.Where); ok { + where.Exprs = append(priExpr, existingWhere.Exprs...) + existWc.Expression = where + stmt.Clauses[wname] = existWc + } + existWc.Expression = where + stmt.Clauses[wname] = existWc + } default: stmt.AddError(gorm.ErrInvalidData) } From ff8d0b3ddb2b1edfab014e0efa7cfa342515ea5d Mon Sep 17 00:00:00 2001 From: sirius <916108538@qq.com> Date: Thu, 17 Apr 2025 15:05:25 +0800 Subject: [PATCH 2/2] TEST:Optimistic Lock in BeforeUpdate: PK Condition Placement Affecting DB Plan Efficiency --- tests/update_test.go | 65 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 65 insertions(+) diff --git a/tests/update_test.go b/tests/update_test.go index 9eb9dbfc..a767b99d 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -931,3 +931,68 @@ func TestUpdateFrom(t *testing.T) { } } } + +type Hzw struct { + Id int32 `gorm:"column:ID;primarykey"` + Type int32 `gorm:"column:TYPE;primarykey"` + Name string `gorm:"column:NAME;size:100;not null"` + Version int32 `gorm:"column:VERSION;default:0"` +} + +func (h *Hzw) BeforeUpdate(tx *gorm.DB) (err error) { + cv := h.Version + h.Version++ + nExprs := tx.Statement.BuildCondition("VERSION", cv) + newwhere := clause.Where{Exprs: nExprs} + tx.Statement.AddClause(newwhere) + return nil +} + +func TestUpdateHookOptimisticLock(t *testing.T) { + hzw := &Hzw{ + Id: 1, + Type: 2, + Name: "hzw", + Version: 0, + } + + statement := DB.Session(&gorm.Session{DryRun: true}).Save(hzw).Statement + sqlstr := statement.SQL.String() + + // Find the WHERE clause + whereIndex := strings.Index(strings.ToUpper(sqlstr), "WHERE") + if whereIndex == -1 { + t.Errorf("No WHERE clause found in the SQL statement") + return + } + whereClause := sqlstr[whereIndex+len("WHERE"):] + + // Define the expected order of conditions + expectedOrder := []string{"ID", "TYPE", "VERSION"} + + // Use regular expression to match column names + re := regexp.MustCompile(`\b(?:ID|TYPE|VERSION)\b`) + matches := re.FindAllString(whereClause, -1) + + if len(matches) < len(expectedOrder) { + t.Errorf("The actual number of WHERE conditions is less than the expected number") + return + } + + for i, expected := range expectedOrder { + if strings.ToUpper(matches[i]) != expected { + t.Errorf("The order of WHERE conditions is incorrect. Expected %s at position %d, but got %s", expected, i+1, matches[i]) + return + } + } + + vars := statement.Vars + cversion := vars[4] + vversion := vars[1] + if cversion != int32(0) { + t.Fatalf("current VERSION should be 0") + } + if vversion != int32(1) { + t.Fatalf("value VERSION should be 1") + } +}