From 17517b05ec1908bfd1a8be971d020d60319683a8 Mon Sep 17 00:00:00 2001 From: sgsv <-> Date: Thu, 9 May 2024 16:10:24 +0200 Subject: [PATCH] Fix association replace non-addressable panic --- association.go | 11 +++++++++++ tests/associations_has_many_test.go | 12 ++++++++++++ 2 files changed, 23 insertions(+) diff --git a/association.go b/association.go index 7c93ebea..711df8d7 100644 --- a/association.go +++ b/association.go @@ -384,6 +384,11 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ ) appendToRelations := func(source, rv reflect.Value, clear bool) { + if !rv.CanAddr() { + association.Error = ErrInvalidValue + return + } + switch association.Relationship.Type { case schema.HasOne, schema.BelongsTo: switch rv.Kind() { @@ -510,6 +515,9 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ for i := 0; i < reflectValue.Len(); i++ { appendToRelations(reflectValue.Index(i), reflect.Indirect(reflect.ValueOf(values[i])), clear) + if association.Error != nil { + return + } // TODO support save slice data, sql with case? association.Error = associationDB.Updates(reflectValue.Index(i).Addr().Interface()).Error @@ -531,6 +539,9 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ for idx, value := range values { rv := reflect.Indirect(reflect.ValueOf(value)) appendToRelations(reflectValue, rv, clear && idx == 0) + if association.Error != nil { + return + } } if len(values) > 0 { diff --git a/tests/associations_has_many_test.go b/tests/associations_has_many_test.go index b8e8ff5e..db397eb7 100644 --- a/tests/associations_has_many_test.go +++ b/tests/associations_has_many_test.go @@ -554,3 +554,15 @@ func TestHasManyAssociationUnscoped(t *testing.T) { t.Errorf("expected %d contents, got %d", 0, len(contents)) } } + +func TestHasManyAssociationReplaceWithNonValidValue(t *testing.T) { + user := User{Name: "jinzhu", Languages: []Language{{Name: "EN"}}} + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + if err := DB.Model(&user).Association("Languages").Replace(Language{Name: "DE"}, Language{Name: "FR"}); err == nil { + t.Error("expected association error to be not nil") + } +}