From ff8d0b3ddb2b1edfab014e0efa7cfa342515ea5d Mon Sep 17 00:00:00 2001 From: sirius <916108538@qq.com> Date: Thu, 17 Apr 2025 15:05:25 +0800 Subject: [PATCH] TEST:Optimistic Lock in BeforeUpdate: PK Condition Placement Affecting DB Plan Efficiency --- tests/update_test.go | 65 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 65 insertions(+) diff --git a/tests/update_test.go b/tests/update_test.go index 9eb9dbfc..a767b99d 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -931,3 +931,68 @@ func TestUpdateFrom(t *testing.T) { } } } + +type Hzw struct { + Id int32 `gorm:"column:ID;primarykey"` + Type int32 `gorm:"column:TYPE;primarykey"` + Name string `gorm:"column:NAME;size:100;not null"` + Version int32 `gorm:"column:VERSION;default:0"` +} + +func (h *Hzw) BeforeUpdate(tx *gorm.DB) (err error) { + cv := h.Version + h.Version++ + nExprs := tx.Statement.BuildCondition("VERSION", cv) + newwhere := clause.Where{Exprs: nExprs} + tx.Statement.AddClause(newwhere) + return nil +} + +func TestUpdateHookOptimisticLock(t *testing.T) { + hzw := &Hzw{ + Id: 1, + Type: 2, + Name: "hzw", + Version: 0, + } + + statement := DB.Session(&gorm.Session{DryRun: true}).Save(hzw).Statement + sqlstr := statement.SQL.String() + + // Find the WHERE clause + whereIndex := strings.Index(strings.ToUpper(sqlstr), "WHERE") + if whereIndex == -1 { + t.Errorf("No WHERE clause found in the SQL statement") + return + } + whereClause := sqlstr[whereIndex+len("WHERE"):] + + // Define the expected order of conditions + expectedOrder := []string{"ID", "TYPE", "VERSION"} + + // Use regular expression to match column names + re := regexp.MustCompile(`\b(?:ID|TYPE|VERSION)\b`) + matches := re.FindAllString(whereClause, -1) + + if len(matches) < len(expectedOrder) { + t.Errorf("The actual number of WHERE conditions is less than the expected number") + return + } + + for i, expected := range expectedOrder { + if strings.ToUpper(matches[i]) != expected { + t.Errorf("The order of WHERE conditions is incorrect. Expected %s at position %d, but got %s", expected, i+1, matches[i]) + return + } + } + + vars := statement.Vars + cversion := vars[4] + vversion := vars[1] + if cversion != int32(0) { + t.Fatalf("current VERSION should be 0") + } + if vversion != int32(1) { + t.Fatalf("value VERSION should be 1") + } +}