diff --git a/callbacks/callmethod.go b/callbacks/callmethod.go index 14b4bf9c..900f0fcb 100644 --- a/callbacks/callmethod.go +++ b/callbacks/callmethod.go @@ -13,14 +13,15 @@ func callMethod(db *gorm.DB, fc func(value interface{}, tx *gorm.DB) bool) { case reflect.Slice, reflect.Array: db.Statement.CurDestIndex = 0 for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - fc(reflect.Indirect(db.Statement.ReflectValue.Index(i)).Addr().Interface(), tx) + if value := reflect.Indirect(db.Statement.ReflectValue.Index(i)); value.CanAddr() { + fc(value.Addr().Interface(), tx) + } db.Statement.CurDestIndex++ } case reflect.Struct: - if !db.Statement.ReflectValue.CanAddr() { - db.Statement.ReflectValue = reflect.New(db.Statement.ReflectValue.Type()).Elem() + if db.Statement.ReflectValue.CanAddr() { + fc(db.Statement.ReflectValue.Addr().Interface(), tx) } - fc(db.Statement.ReflectValue.Addr().Interface(), tx) } } } diff --git a/callbacks/update.go b/callbacks/update.go index b596df9a..fe6f0994 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -137,7 +137,9 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { case reflect.Slice, reflect.Array: assignValue = func(field *schema.Field, value interface{}) { for i := 0; i < stmt.ReflectValue.Len(); i++ { - field.Set(stmt.Context, stmt.ReflectValue.Index(i), value) + if stmt.ReflectValue.CanAddr() { + field.Set(stmt.Context, stmt.ReflectValue.Index(i), value) + } } } case reflect.Struct: diff --git a/schema/utils.go b/schema/utils.go index acf1a739..09c1ae3b 100644 --- a/schema/utils.go +++ b/schema/utils.go @@ -134,7 +134,11 @@ func GetIdentityFieldValuesMap(ctx context.Context, reflectValue reflect.Value, elem := reflectValue.Index(i) elemKey := elem.Interface() if elem.Kind() != reflect.Ptr { - elemKey = elem.Addr().Interface() + if elem.CanAddr() { + elemKey = elem.Addr().Interface() + } else { + elemKey = elem.Interface() + } } if _, ok := loaded[elemKey]; ok { diff --git a/tests/hooks_test.go b/tests/hooks_test.go index 6ef1151b..13c54dab 100644 --- a/tests/hooks_test.go +++ b/tests/hooks_test.go @@ -540,7 +540,17 @@ func TestUpdateCallbacks(t *testing.T) { } DB.Model(Product5{}).Where("id", p.ID).Update("name", "update_name_2") + if beforeUpdateCall != 1 { + t.Fatalf("before update should not be called") + } + + DB.Model([1]*Product5{&p}).Update("name", "update_name_3") if beforeUpdateCall != 2 { t.Fatalf("before update should be called") } + + DB.Model([1]Product5{p}).Update("name", "update_name_4") + if beforeUpdateCall != 2 { + t.Fatalf("before update should not be called") + } }