diff --git a/association.go b/association.go index 342dd6cd..81c25f14 100644 --- a/association.go +++ b/association.go @@ -30,19 +30,44 @@ func (association *Association) Append(values ...interface{}) *Association { scope := association.Scope field := association.Field + createJoinTable := func(reflectValue reflect.Value) { + var value = reflectValue.Interface() + if reflectValue.Kind() != reflect.Ptr { + reflectPtr := reflect.New(reflectValue.Type()) + reflectPtr.Elem().Set(reflectValue) + value = reflectPtr.Interface() + } + + if scope.New(value).PrimaryKeyZero() { + scope.NewDB().Save(value) + } + + relationship := association.Field.Relationship + association.setErr(relationship.JoinTableHandler.Add(relationship.JoinTableHandler, scope.NewDB(), scope.Value, value)) + + result := reflect.ValueOf(value) + fieldElemType := field.Field.Type().Elem() + if result.Type().AssignableTo(fieldElemType) { + field.Set(reflect.Append(field.Field, result)) + } else if result.Type().Elem().AssignableTo(fieldElemType) { + field.Set(reflect.Append(field.Field, result.Elem())) + } + } + for _, value := range values { - reflectvalue := reflect.Indirect(reflect.ValueOf(value)) - if reflectvalue.Kind() == reflect.Struct { - field.Set(reflect.Append(field.Field, reflectvalue)) - } else if reflectvalue.Kind() == reflect.Slice { - field.Set(reflect.AppendSlice(field.Field, reflectvalue)) + reflectValue := reflect.Indirect(reflect.ValueOf(value)) + + if reflectValue.Kind() == reflect.Struct { + createJoinTable(reflectValue) + } else if reflectValue.Kind() == reflect.Slice { + for i := 0; i < reflectValue.Len(); i++ { + createJoinTable(reflectValue.Index(i)) + } } else { association.setErr(errors.New("invalid association type")) } } - scope.Search.Select(association.Column) - scope.callCallbacks(scope.db.parent.callback.updates) - return association.setErr(scope.db.Error) + return association } func (association *Association) Delete(values ...interface{}) *Association { @@ -230,7 +255,7 @@ func toQueryMarks(primaryValues [][]interface{}) string { for _, primaryValue := range primaryValues { var marks []string - for _,_ = range primaryValue { + for _, _ = range primaryValue { marks = append(marks, "?") } diff --git a/association_test.go b/association_test.go index fa018bdd..eb606bc5 100644 --- a/association_test.go +++ b/association_test.go @@ -160,7 +160,7 @@ func TestManyToMany(t *testing.T) { languageA := Language{Name: "AA"} DB.Save(&languageA) - DB.Model(&User{Id: user.Id}).Association("Languages").Append(languageA) + DB.Model(&User{Id: user.Id}).Association("Languages").Append(&languageA) languageC := Language{Name: "CC"} DB.Save(&languageC)