From feacf577b397f0f1a786d1d6933f32d1ea624697 Mon Sep 17 00:00:00 2001 From: demoManito <1430482733@qq.com> Date: Wed, 21 Aug 2024 21:45:57 +0800 Subject: [PATCH] feat: BeforeUpdate hook supports update using struct --- callbacks/update.go | 48 +++++++++++++++++++++++++++++++++++---------- tests/hooks_test.go | 41 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 79 insertions(+), 10 deletions(-) diff --git a/callbacks/update.go b/callbacks/update.go index 7cde7f61..7e0ebd9a 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -32,22 +32,50 @@ func SetupUpdateReflectValue(db *gorm.DB) { // BeforeUpdate before update hooks func BeforeUpdate(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) { - callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { + callMethod(db, func(value interface{}, tx *gorm.DB) bool { + var ( + beforeSaveInterface BeforeSaveInterface + isBeforeSaveHook bool + beforeUpdateInterface BeforeUpdateInterface + isBeforeUpdateHook bool + ) if db.Statement.Schema.BeforeSave { - if i, ok := value.(BeforeSaveInterface); ok { - called = true - db.AddError(i.BeforeSave(tx)) - } + beforeSaveInterface, isBeforeSaveHook = value.(BeforeSaveInterface) } - if db.Statement.Schema.BeforeUpdate { - if i, ok := value.(BeforeUpdateInterface); ok { - called = true - db.AddError(i.BeforeUpdate(tx)) + beforeUpdateInterface, isBeforeUpdateHook = value.(BeforeUpdateInterface) + } + if !isBeforeSaveHook && !isBeforeUpdateHook { + return false + } + + // save a snapshot of the struct before the hook was called + rv := reflect.Indirect(reflect.ValueOf(value)) + rvSnapshot := reflect.New(rv.Type()).Elem() + rvSnapshot.Set(rv) + + if isBeforeSaveHook { + db.AddError(beforeSaveInterface.BeforeSave(tx)) + } + if isBeforeUpdateHook { + db.AddError(beforeUpdateInterface.BeforeUpdate(tx)) + } + + for _, field := range db.Statement.Schema.Fields { + if field.PrimaryKey { + continue + } + dbFieldName, ok := field.TagSettings["COLUMN"] + if !ok { + continue + } + // compare with the snapshot and update the field if there is a difference + if !reflect.DeepEqual(rv.FieldByName(field.Name).Interface(), rvSnapshot.FieldByName(field.Name).Interface()) { + db.Statement.SetColumn(dbFieldName, rv.FieldByName(field.Name).Interface()) } } - return called + return true }) } } diff --git a/tests/hooks_test.go b/tests/hooks_test.go index 04f62bde..a3cbbc72 100644 --- a/tests/hooks_test.go +++ b/tests/hooks_test.go @@ -609,3 +609,44 @@ func TestPropagateUnscoped(t *testing.T) { t.Fatalf("unscoped did not propagate") } } + +type StructUpdate struct { + ID uint `gorm:"column:id;primary_key"` + Version int `gorm:"column:version"` + Name string `gorm:"column:name"` +} + +func (StructUpdate) TableName() string { + return "struct_updates" +} + +func (su *StructUpdate) BeforeUpdate(*gorm.DB) error { + su.Version++ + return nil +} + +func TestBeforeUpdateWithStructColumn(t *testing.T) { + DB.Migrator().DropTable(&StructUpdate{}) + DB.AutoMigrate(&StructUpdate{}) + + su := StructUpdate{ + ID: 1, + Version: 1, + } + err := DB.Create(&su).Error + if err != nil { + t.Fatalf("create struct failed: %v", err) + } + + err = DB.Model(&su).Update("name", "demoManito").Error + if err != nil { + t.Fatalf("update struct failed: %v", err) + } + if su.Version != 2 { + t.Fatalf("update version failed: %v", su.Version) + } + err = DB.Find(&su, "id = ?", 1).Error + if err != nil { + t.Fatalf("find struct failed: %v", err) + } +}