feat: BeforeUpdate hook supports update using struct
This commit is contained in:
parent
4a50b36f63
commit
feacf577b3
@ -32,22 +32,50 @@ func SetupUpdateReflectValue(db *gorm.DB) {
|
|||||||
// BeforeUpdate before update hooks
|
// BeforeUpdate before update hooks
|
||||||
func BeforeUpdate(db *gorm.DB) {
|
func BeforeUpdate(db *gorm.DB) {
|
||||||
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) {
|
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) {
|
||||||
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
|
callMethod(db, func(value interface{}, tx *gorm.DB) bool {
|
||||||
|
var (
|
||||||
|
beforeSaveInterface BeforeSaveInterface
|
||||||
|
isBeforeSaveHook bool
|
||||||
|
beforeUpdateInterface BeforeUpdateInterface
|
||||||
|
isBeforeUpdateHook bool
|
||||||
|
)
|
||||||
if db.Statement.Schema.BeforeSave {
|
if db.Statement.Schema.BeforeSave {
|
||||||
if i, ok := value.(BeforeSaveInterface); ok {
|
beforeSaveInterface, isBeforeSaveHook = value.(BeforeSaveInterface)
|
||||||
called = true
|
|
||||||
db.AddError(i.BeforeSave(tx))
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
if db.Statement.Schema.BeforeUpdate {
|
if db.Statement.Schema.BeforeUpdate {
|
||||||
if i, ok := value.(BeforeUpdateInterface); ok {
|
beforeUpdateInterface, isBeforeUpdateHook = value.(BeforeUpdateInterface)
|
||||||
called = true
|
}
|
||||||
db.AddError(i.BeforeUpdate(tx))
|
if !isBeforeSaveHook && !isBeforeUpdateHook {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// save a snapshot of the struct before the hook was called
|
||||||
|
rv := reflect.Indirect(reflect.ValueOf(value))
|
||||||
|
rvSnapshot := reflect.New(rv.Type()).Elem()
|
||||||
|
rvSnapshot.Set(rv)
|
||||||
|
|
||||||
|
if isBeforeSaveHook {
|
||||||
|
db.AddError(beforeSaveInterface.BeforeSave(tx))
|
||||||
|
}
|
||||||
|
if isBeforeUpdateHook {
|
||||||
|
db.AddError(beforeUpdateInterface.BeforeUpdate(tx))
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, field := range db.Statement.Schema.Fields {
|
||||||
|
if field.PrimaryKey {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
dbFieldName, ok := field.TagSettings["COLUMN"]
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// compare with the snapshot and update the field if there is a difference
|
||||||
|
if !reflect.DeepEqual(rv.FieldByName(field.Name).Interface(), rvSnapshot.FieldByName(field.Name).Interface()) {
|
||||||
|
db.Statement.SetColumn(dbFieldName, rv.FieldByName(field.Name).Interface())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return called
|
return true
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -609,3 +609,44 @@ func TestPropagateUnscoped(t *testing.T) {
|
|||||||
t.Fatalf("unscoped did not propagate")
|
t.Fatalf("unscoped did not propagate")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type StructUpdate struct {
|
||||||
|
ID uint `gorm:"column:id;primary_key"`
|
||||||
|
Version int `gorm:"column:version"`
|
||||||
|
Name string `gorm:"column:name"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (StructUpdate) TableName() string {
|
||||||
|
return "struct_updates"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (su *StructUpdate) BeforeUpdate(*gorm.DB) error {
|
||||||
|
su.Version++
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBeforeUpdateWithStructColumn(t *testing.T) {
|
||||||
|
DB.Migrator().DropTable(&StructUpdate{})
|
||||||
|
DB.AutoMigrate(&StructUpdate{})
|
||||||
|
|
||||||
|
su := StructUpdate{
|
||||||
|
ID: 1,
|
||||||
|
Version: 1,
|
||||||
|
}
|
||||||
|
err := DB.Create(&su).Error
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("create struct failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = DB.Model(&su).Update("name", "demoManito").Error
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("update struct failed: %v", err)
|
||||||
|
}
|
||||||
|
if su.Version != 2 {
|
||||||
|
t.Fatalf("update version failed: %v", su.Version)
|
||||||
|
}
|
||||||
|
err = DB.Find(&su, "id = ?", 1).Error
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("find struct failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user