From aa3fd6de13fee7e0ae715eeaad3bc2f329db2366 Mon Sep 17 00:00:00 2001 From: Jess Smith Date: Sat, 10 Feb 2018 01:26:01 -0500 Subject: [PATCH 1/8] Sort column names before generating SQL in `DB.UpdateColumns` (#1734) --- callback_update.go | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/callback_update.go b/callback_update.go index 6948439f..373bd726 100644 --- a/callback_update.go +++ b/callback_update.go @@ -3,6 +3,7 @@ package gorm import ( "errors" "fmt" + "sort" "strings" ) @@ -59,7 +60,16 @@ func updateCallback(scope *Scope) { var sqls []string if updateAttrs, ok := scope.InstanceGet("gorm:update_attrs"); ok { - for column, value := range updateAttrs.(map[string]interface{}) { + // Sort the column names so that the generated SQL is the same every time. + updateMap := updateAttrs.(map[string]interface{}) + var columns []string + for c := range updateMap { + columns = append(columns, c) + } + sort.Strings(columns) + + for _, column := range columns { + value := updateMap[column] sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(column), scope.AddToVars(value))) } } else { From 9235b47ea28d816ef25d6bf4e037ccb5c7c7096b Mon Sep 17 00:00:00 2001 From: joe-at-startupmedia Date: Wed, 4 Oct 2017 08:19:16 +0000 Subject: [PATCH 2/8] Allows foreign keys to be saved without saving the assoication when specified #1628 --- callback_save.go | 53 ++++++++++++++++++++++++------------------------ 1 file changed, 26 insertions(+), 27 deletions(-) diff --git a/callback_save.go b/callback_save.go index f4bc918e..ad4eda2f 100644 --- a/callback_save.go +++ b/callback_save.go @@ -11,35 +11,34 @@ func commitOrRollbackTransactionCallback(scope *Scope) { } func saveFieldAsAssociation(scope *Scope, field *Field) (bool, *Relationship) { - if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored { - if value, ok := field.TagSettings["SAVE_ASSOCIATIONS"]; !ok || (value != "false" && value != "skip") { - if relationship := field.Relationship; relationship != nil { - return true, relationship - } - } - } - return false, nil + if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored { + if value, ok := field.TagSettings["SAVE_ASSOCIATIONS"]; !ok || (value != "false" && value != "skip") { + return true, field.Relationship + } + return false, field.Relationship + } + return false, nil } func saveBeforeAssociationsCallback(scope *Scope) { - if !scope.shouldSaveAssociations() { - return - } - for _, field := range scope.Fields() { - if ok, relationship := saveFieldAsAssociation(scope, field); ok && relationship.Kind == "belongs_to" { - fieldValue := field.Field.Addr().Interface() - scope.Err(scope.NewDB().Save(fieldValue).Error) - if len(relationship.ForeignFieldNames) != 0 { - // set value's foreign key - for idx, fieldName := range relationship.ForeignFieldNames { - associationForeignName := relationship.AssociationForeignDBNames[idx] - if foreignField, ok := scope.New(fieldValue).FieldByName(associationForeignName); ok { - scope.Err(scope.SetColumn(fieldName, foreignField.Field.Interface())) - } - } - } - } - } + for _, field := range scope.Fields() { + ok, relationship := saveFieldAsAssociation(scope, field); + if relationship != nil && relationship.Kind == "belongs_to" { + fieldValue := field.Field.Addr().Interface() + if ok && scope.shouldSaveAssociations() { + scope.Err(scope.NewDB().Save(fieldValue).Error) + } + if len(relationship.ForeignFieldNames) != 0 { + // set value's foreign key + for idx, fieldName := range relationship.ForeignFieldNames { + associationForeignName := relationship.AssociationForeignDBNames[idx] + if foreignField, ok := scope.New(fieldValue).FieldByName(associationForeignName); ok { + scope.Err(scope.SetColumn(fieldName, foreignField.Field.Interface())) + } + } + } + } + } } func saveAfterAssociationsCallback(scope *Scope) { @@ -47,7 +46,7 @@ func saveAfterAssociationsCallback(scope *Scope) { return } for _, field := range scope.Fields() { - if ok, relationship := saveFieldAsAssociation(scope, field); ok && + if ok, relationship := saveFieldAsAssociation(scope, field); ok && relationship != nil && (relationship.Kind == "has_one" || relationship.Kind == "has_many" || relationship.Kind == "many_to_many") { value := field.Field From 9f409820dfdfc2ab7cb20a56d4cefdf1a111c315 Mon Sep 17 00:00:00 2001 From: joe-at-startupmedia Date: Tue, 10 Oct 2017 18:20:56 +0000 Subject: [PATCH 3/8] Formatting code with gomt --- callback_save.go | 50 ++++++++++++++++++++++++------------------------ 1 file changed, 25 insertions(+), 25 deletions(-) diff --git a/callback_save.go b/callback_save.go index ad4eda2f..fa32c907 100644 --- a/callback_save.go +++ b/callback_save.go @@ -11,34 +11,34 @@ func commitOrRollbackTransactionCallback(scope *Scope) { } func saveFieldAsAssociation(scope *Scope, field *Field) (bool, *Relationship) { - if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored { - if value, ok := field.TagSettings["SAVE_ASSOCIATIONS"]; !ok || (value != "false" && value != "skip") { - return true, field.Relationship - } - return false, field.Relationship - } - return false, nil + if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored { + if value, ok := field.TagSettings["SAVE_ASSOCIATIONS"]; !ok || (value != "false" && value != "skip") { + return true, field.Relationship + } + return false, field.Relationship + } + return false, nil } func saveBeforeAssociationsCallback(scope *Scope) { - for _, field := range scope.Fields() { - ok, relationship := saveFieldAsAssociation(scope, field); - if relationship != nil && relationship.Kind == "belongs_to" { - fieldValue := field.Field.Addr().Interface() - if ok && scope.shouldSaveAssociations() { - scope.Err(scope.NewDB().Save(fieldValue).Error) - } - if len(relationship.ForeignFieldNames) != 0 { - // set value's foreign key - for idx, fieldName := range relationship.ForeignFieldNames { - associationForeignName := relationship.AssociationForeignDBNames[idx] - if foreignField, ok := scope.New(fieldValue).FieldByName(associationForeignName); ok { - scope.Err(scope.SetColumn(fieldName, foreignField.Field.Interface())) - } - } - } - } - } + for _, field := range scope.Fields() { + ok, relationship := saveFieldAsAssociation(scope, field) + if relationship != nil && relationship.Kind == "belongs_to" { + fieldValue := field.Field.Addr().Interface() + if ok && scope.shouldSaveAssociations() { + scope.Err(scope.NewDB().Save(fieldValue).Error) + } + if len(relationship.ForeignFieldNames) != 0 { + // set value's foreign key + for idx, fieldName := range relationship.ForeignFieldNames { + associationForeignName := relationship.AssociationForeignDBNames[idx] + if foreignField, ok := scope.New(fieldValue).FieldByName(associationForeignName); ok { + scope.Err(scope.SetColumn(fieldName, foreignField.Field.Interface())) + } + } + } + } + } } func saveAfterAssociationsCallback(scope *Scope) { From 63cb513b4978a49870ff20d27fb18c721f64d977 Mon Sep 17 00:00:00 2001 From: Ezequiel Muns Date: Wed, 1 Nov 2017 18:45:08 +0100 Subject: [PATCH 4/8] Tests for saving foreign key when save_associations:false --- association_test.go | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/association_test.go b/association_test.go index c84f84ed..f37047d1 100644 --- a/association_test.go +++ b/association_test.go @@ -902,6 +902,20 @@ func TestSkipSaveAssociation(t *testing.T) { DB.Save(&User{Name: "jinzhu", Company: Company{Name: "skip_save_association"}}) if !DB.Where("name = ?", "skip_save_association").First(&Company{}).RecordNotFound() { - t.Errorf("Company skip_save_association should not been saved") + t.Errorf("Company skip_save_association should not have been saved") + } + + // if foreign key is set, this should be saved even if association isn't + company := Company{Name: "skip_save_association"} + DB.Save(&company) + company.Name = "skip_save_association_modified" + user := User{Name: "jinzhu", CompanyID: company.ID, Company: company} + DB.Save(&user) + + if !DB.Where("name = ?", "skip_save_association_modified").First(&Company{}).RecordNotFound() { + t.Errorf("Company skip_save_association should not have been updated") + } + if DB.Where("id = ? AND company_id = ?", user.ID, company.ID).First(&User{}).RecordNotFound() { + t.Errorf("User's foreign key should have been saved") } } From 43dc867644b879f8f87fd0598ac0b459232d9293 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 10 Feb 2018 15:16:20 +0800 Subject: [PATCH 5/8] Allow save association relations w/o saving association --- association_test.go | 2 +- callback_save.go | 31 ++++++++++++++++++------------- scope.go | 2 +- 3 files changed, 20 insertions(+), 15 deletions(-) diff --git a/association_test.go b/association_test.go index f37047d1..34822dbc 100644 --- a/association_test.go +++ b/association_test.go @@ -909,7 +909,7 @@ func TestSkipSaveAssociation(t *testing.T) { company := Company{Name: "skip_save_association"} DB.Save(&company) company.Name = "skip_save_association_modified" - user := User{Name: "jinzhu", CompanyID: company.ID, Company: company} + user := User{Name: "jinzhu", Company: company} DB.Save(&user) if !DB.Where("name = ?", "skip_save_association_modified").First(&Company{}).RecordNotFound() { diff --git a/callback_save.go b/callback_save.go index fa32c907..544354d0 100644 --- a/callback_save.go +++ b/callback_save.go @@ -12,22 +12,25 @@ func commitOrRollbackTransactionCallback(scope *Scope) { func saveFieldAsAssociation(scope *Scope, field *Field) (bool, *Relationship) { if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored { - if value, ok := field.TagSettings["SAVE_ASSOCIATIONS"]; !ok || (value != "false" && value != "skip") { - return true, field.Relationship + if field.Relationship != nil { + if value, ok := field.TagSettings["SAVE_ASSOCIATIONS"]; (!ok || (value != "false" && value != "skip")) && scope.allowSaveAssociations() { + return true, field.Relationship + } + return false, field.Relationship } - return false, field.Relationship } return false, nil } func saveBeforeAssociationsCallback(scope *Scope) { for _, field := range scope.Fields() { - ok, relationship := saveFieldAsAssociation(scope, field) - if relationship != nil && relationship.Kind == "belongs_to" { + if allowSaveAssociation, relationship := saveFieldAsAssociation(scope, field); relationship != nil && relationship.Kind == "belongs_to" { fieldValue := field.Field.Addr().Interface() - if ok && scope.shouldSaveAssociations() { + + if allowSaveAssociation { scope.Err(scope.NewDB().Save(fieldValue).Error) } + if len(relationship.ForeignFieldNames) != 0 { // set value's foreign key for idx, fieldName := range relationship.ForeignFieldNames { @@ -42,11 +45,8 @@ func saveBeforeAssociationsCallback(scope *Scope) { } func saveAfterAssociationsCallback(scope *Scope) { - if !scope.shouldSaveAssociations() { - return - } for _, field := range scope.Fields() { - if ok, relationship := saveFieldAsAssociation(scope, field); ok && relationship != nil && + if allowSaveAssociation, relationship := saveFieldAsAssociation(scope, field); relationship != nil && (relationship.Kind == "has_one" || relationship.Kind == "has_many" || relationship.Kind == "many_to_many") { value := field.Field @@ -70,9 +70,11 @@ func saveAfterAssociationsCallback(scope *Scope) { scope.Err(newScope.SetColumn(relationship.PolymorphicType, relationship.PolymorphicValue)) } - scope.Err(newDB.Save(elem).Error) + if allowSaveAssociation { + scope.Err(newDB.Save(elem).Error) + } - if joinTableHandler := relationship.JoinTableHandler; joinTableHandler != nil { + if joinTableHandler := relationship.JoinTableHandler; joinTableHandler != nil && !newScope.PrimaryKeyZero() { scope.Err(joinTableHandler.Add(joinTableHandler, newDB, scope.Value, newScope.Value)) } } @@ -91,7 +93,10 @@ func saveAfterAssociationsCallback(scope *Scope) { if relationship.PolymorphicType != "" { scope.Err(newScope.SetColumn(relationship.PolymorphicType, relationship.PolymorphicValue)) } - scope.Err(scope.NewDB().Save(elem).Error) + + if allowSaveAssociation { + scope.Err(scope.NewDB().Save(elem).Error) + } } } } diff --git a/scope.go b/scope.go index a10cb3a2..9ae33913 100644 --- a/scope.go +++ b/scope.go @@ -993,7 +993,7 @@ func (scope *Scope) changeableField(field *Field) bool { return true } -func (scope *Scope) shouldSaveAssociations() bool { +func (scope *Scope) allowSaveAssociations() bool { if saveAssociations, ok := scope.Get("gorm:save_associations"); ok { if v, ok := saveAssociations.(bool); ok && !v { return false From b2b568daa8e27966c39c942e5aefc74bcc8af88d Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 10 Feb 2018 16:47:48 +0800 Subject: [PATCH 6/8] Add tag association_autoupdate, association_autocreate, association_save_reference support --- association_test.go | 147 +++++++++++++++++++++++++++++++++++++++--- callback_save.go | 153 +++++++++++++++++++++++++++++++------------- query_test.go | 2 +- scope.go | 12 ---- 4 files changed, 248 insertions(+), 66 deletions(-) diff --git a/association_test.go b/association_test.go index 34822dbc..60d0cf48 100644 --- a/association_test.go +++ b/association_test.go @@ -885,7 +885,7 @@ func TestHasManyChildrenWithOneStruct(t *testing.T) { DB.Save(&category) } -func TestSkipSaveAssociation(t *testing.T) { +func TestAutoSaveBelongsToAssociation(t *testing.T) { type Company struct { gorm.Model Name string @@ -895,27 +895,156 @@ func TestSkipSaveAssociation(t *testing.T) { gorm.Model Name string CompanyID uint - Company Company `gorm:"save_associations:false"` + Company Company `gorm:"association_autoupdate:false;association_autocreate:false;"` } + + DB.Where("name = ?", "auto_save_association").Delete(&Company{}) DB.AutoMigrate(&Company{}, &User{}) - DB.Save(&User{Name: "jinzhu", Company: Company{Name: "skip_save_association"}}) + DB.Save(&User{Name: "jinzhu", Company: Company{Name: "auto_save_association"}}) - if !DB.Where("name = ?", "skip_save_association").First(&Company{}).RecordNotFound() { - t.Errorf("Company skip_save_association should not have been saved") + if !DB.Where("name = ?", "auto_save_association").First(&Company{}).RecordNotFound() { + t.Errorf("Company auto_save_association should not have been saved when autosave is false") } // if foreign key is set, this should be saved even if association isn't - company := Company{Name: "skip_save_association"} + company := Company{Name: "auto_save_association"} DB.Save(&company) - company.Name = "skip_save_association_modified" + + company.Name = "auto_save_association_new_name" user := User{Name: "jinzhu", Company: company} + DB.Save(&user) - if !DB.Where("name = ?", "skip_save_association_modified").First(&Company{}).RecordNotFound() { - t.Errorf("Company skip_save_association should not have been updated") + if !DB.Where("name = ?", "auto_save_association_new_name").First(&Company{}).RecordNotFound() { + t.Errorf("Company should not have been updated") } + if DB.Where("id = ? AND company_id = ?", user.ID, company.ID).First(&User{}).RecordNotFound() { t.Errorf("User's foreign key should have been saved") } + + user2 := User{Name: "jinzhu_2", Company: Company{Name: "auto_save_association_2"}} + DB.Set("gorm:association_autocreate", true).Save(&user2) + if DB.Where("name = ?", "auto_save_association_2").First(&Company{}).RecordNotFound() { + t.Errorf("Company auto_save_association_2 should been created when autocreate is true") + } + + user2.Company.Name = "auto_save_association_2_newname" + DB.Set("gorm:association_autoupdate", true).Save(&user2) + + if DB.Where("name = ?", "auto_save_association_2_newname").First(&Company{}).RecordNotFound() { + t.Errorf("Company should been updated") + } +} + +func TestAutoSaveHasOneAssociation(t *testing.T) { + type Company struct { + gorm.Model + UserID uint + Name string + } + + type User struct { + gorm.Model + Name string + Company Company `gorm:"association_autoupdate:false;association_autocreate:false;"` + } + + DB.Where("name = ?", "auto_save_has_one_association").Delete(&Company{}) + DB.AutoMigrate(&Company{}, &User{}) + + DB.Save(&User{Name: "jinzhu", Company: Company{Name: "auto_save_has_one_association"}}) + + if !DB.Where("name = ?", "auto_save_has_one_association").First(&Company{}).RecordNotFound() { + t.Errorf("Company auto_save_has_one_association should not have been saved when autosave is false") + } + + company := Company{Name: "auto_save_has_one_association"} + DB.Save(&company) + + company.Name = "auto_save_has_one_association_new_name" + user := User{Name: "jinzhu", Company: company} + + DB.Save(&user) + + if !DB.Where("name = ?", "auto_save_has_one_association_new_name").First(&Company{}).RecordNotFound() { + t.Errorf("Company should not have been updated") + } + + if !DB.Where("name = ? AND user_id = ?", "auto_save_has_one_association", user.ID).First(&Company{}).RecordNotFound() { + t.Errorf("Company should not have been updated") + } + + if user.Company.UserID == 0 { + t.Errorf("UserID should be assigned") + } + + company.Name = "auto_save_has_one_association_2_new_name" + DB.Set("gorm:association_autoupdate", true).Save(&user) + + if DB.Where("name = ? AND user_id = ?", "auto_save_has_one_association_new_name", user.ID).First(&Company{}).RecordNotFound() { + t.Errorf("Company should been updated") + } + + user2 := User{Name: "jinzhu_2", Company: Company{Name: "auto_save_has_one_association_2"}} + DB.Set("gorm:association_autocreate", true).Save(&user2) + if DB.Where("name = ?", "auto_save_has_one_association_2").First(&Company{}).RecordNotFound() { + t.Errorf("Company auto_save_has_one_association_2 should been created when autocreate is true") + } +} + +func TestAutoSaveMany2ManyAssociation(t *testing.T) { + type Company struct { + gorm.Model + Name string + } + + type User struct { + gorm.Model + Name string + Companies []Company `gorm:"many2many:user_companies;association_autoupdate:false;association_autocreate:false;"` + } + + DB.AutoMigrate(&Company{}, &User{}) + + DB.Save(&User{Name: "jinzhu", Companies: []Company{{Name: "auto_save_m2m_association"}}}) + + if !DB.Where("name = ?", "auto_save_m2m_association").First(&Company{}).RecordNotFound() { + t.Errorf("Company auto_save_m2m_association should not have been saved when autosave is false") + } + + company := Company{Name: "auto_save_m2m_association"} + DB.Save(&company) + + company.Name = "auto_save_m2m_association_new_name" + user := User{Name: "jinzhu", Companies: []Company{company, {Name: "auto_save_m2m_association_new_name_2"}}} + + DB.Save(&user) + + if !DB.Where("name = ?", "auto_save_m2m_association_new_name").First(&Company{}).RecordNotFound() { + t.Errorf("Company should not have been updated") + } + + if !DB.Where("name = ?", "auto_save_m2m_association_new_name_2").First(&Company{}).RecordNotFound() { + t.Errorf("Company should not been created") + } + + if DB.Model(&user).Association("Companies").Count() != 1 { + t.Errorf("Relationship should been saved") + } + + DB.Set("gorm:association_autoupdate", true).Set("gorm:association_autocreate", true).Save(&user) + + if DB.Where("name = ?", "auto_save_m2m_association_new_name").First(&Company{}).RecordNotFound() { + t.Errorf("Company should been updated") + } + + if DB.Where("name = ?", "auto_save_m2m_association_new_name_2").First(&Company{}).RecordNotFound() { + t.Errorf("Company should been created") + } + + if DB.Model(&user).Association("Companies").Count() != 2 { + t.Errorf("Relationship should been updated") + } } diff --git a/callback_save.go b/callback_save.go index 544354d0..243c986e 100644 --- a/callback_save.go +++ b/callback_save.go @@ -1,6 +1,9 @@ package gorm -import "reflect" +import ( + "reflect" + "strings" +) func beginTransactionCallback(scope *Scope) { scope.Begin() @@ -10,33 +13,79 @@ func commitOrRollbackTransactionCallback(scope *Scope) { scope.CommitOrRollback() } -func saveFieldAsAssociation(scope *Scope, field *Field) (bool, *Relationship) { - if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored { - if field.Relationship != nil { - if value, ok := field.TagSettings["SAVE_ASSOCIATIONS"]; (!ok || (value != "false" && value != "skip")) && scope.allowSaveAssociations() { - return true, field.Relationship +func saveAssociationCheck(scope *Scope, field *Field) (autoUpdate bool, autoCreate bool, saveReference bool, r *Relationship) { + checkTruth := func(value interface{}) bool { + if v, ok := value.(bool); ok && !v { + return false + } + + if v, ok := value.(string); ok { + v = strings.ToLower(v) + if v == "false" || v != "skip" { + return false + } + } + + return true + } + + if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored { + if r = field.Relationship; r != nil { + autoUpdate, autoCreate, saveReference = true, true, true + + if value, ok := scope.Get("gorm:save_associations"); ok { + autoUpdate = checkTruth(value) + autoCreate = autoUpdate + } else if value, ok := field.TagSettings["SAVE_ASSOCIATIONS"]; ok { + autoUpdate = checkTruth(value) + autoCreate = autoUpdate + } + + if value, ok := scope.Get("gorm:association_autoupdate"); ok { + autoUpdate = checkTruth(value) + } else if value, ok := field.TagSettings["ASSOCIATION_AUTOUPDATE"]; ok { + autoUpdate = checkTruth(value) + } + + if value, ok := scope.Get("gorm:association_autocreate"); ok { + autoCreate = checkTruth(value) + } else if value, ok := field.TagSettings["ASSOCIATION_AUTOCREATE"]; ok { + autoCreate = checkTruth(value) + } + + if value, ok := field.TagSettings["ASSOCIATION_SAVE_REFERENCE"]; ok { + saveReference = checkTruth(value) } - return false, field.Relationship } } - return false, nil + + return } func saveBeforeAssociationsCallback(scope *Scope) { for _, field := range scope.Fields() { - if allowSaveAssociation, relationship := saveFieldAsAssociation(scope, field); relationship != nil && relationship.Kind == "belongs_to" { - fieldValue := field.Field.Addr().Interface() + autoUpdate, autoCreate, saveReference, relationship := saveAssociationCheck(scope, field) - if allowSaveAssociation { + if relationship != nil && relationship.Kind == "belongs_to" { + fieldValue := field.Field.Addr().Interface() + newScope := scope.New(fieldValue) + + if newScope.PrimaryKeyZero() { + if autoCreate { + scope.Err(scope.NewDB().Save(fieldValue).Error) + } + } else if autoUpdate { scope.Err(scope.NewDB().Save(fieldValue).Error) } - if len(relationship.ForeignFieldNames) != 0 { - // set value's foreign key - for idx, fieldName := range relationship.ForeignFieldNames { - associationForeignName := relationship.AssociationForeignDBNames[idx] - if foreignField, ok := scope.New(fieldValue).FieldByName(associationForeignName); ok { - scope.Err(scope.SetColumn(fieldName, foreignField.Field.Interface())) + if saveReference { + if len(relationship.ForeignFieldNames) != 0 { + // set value's foreign key + for idx, fieldName := range relationship.ForeignFieldNames { + associationForeignName := relationship.AssociationForeignDBNames[idx] + if foreignField, ok := scope.New(fieldValue).FieldByName(associationForeignName); ok { + scope.Err(scope.SetColumn(fieldName, foreignField.Field.Interface())) + } } } } @@ -46,8 +95,9 @@ func saveBeforeAssociationsCallback(scope *Scope) { func saveAfterAssociationsCallback(scope *Scope) { for _, field := range scope.Fields() { - if allowSaveAssociation, relationship := saveFieldAsAssociation(scope, field); relationship != nil && - (relationship.Kind == "has_one" || relationship.Kind == "has_many" || relationship.Kind == "many_to_many") { + autoUpdate, autoCreate, saveReference, relationship := saveAssociationCheck(scope, field) + + if relationship != nil && (relationship.Kind == "has_one" || relationship.Kind == "has_many" || relationship.Kind == "many_to_many") { value := field.Field switch value.Kind() { @@ -57,7 +107,41 @@ func saveAfterAssociationsCallback(scope *Scope) { elem := value.Index(i).Addr().Interface() newScope := newDB.NewScope(elem) - if relationship.JoinTableHandler == nil && len(relationship.ForeignFieldNames) != 0 { + if saveReference { + if relationship.JoinTableHandler == nil && len(relationship.ForeignFieldNames) != 0 { + for idx, fieldName := range relationship.ForeignFieldNames { + associationForeignName := relationship.AssociationForeignDBNames[idx] + if f, ok := scope.FieldByName(associationForeignName); ok { + scope.Err(newScope.SetColumn(fieldName, f.Field.Interface())) + } + } + } + + if relationship.PolymorphicType != "" { + scope.Err(newScope.SetColumn(relationship.PolymorphicType, relationship.PolymorphicValue)) + } + } + + if newScope.PrimaryKeyZero() { + if autoCreate { + scope.Err(newDB.Save(elem).Error) + } + } else if autoUpdate { + scope.Err(newDB.Save(elem).Error) + } + + if !scope.New(newScope.Value).PrimaryKeyZero() && saveReference { + if joinTableHandler := relationship.JoinTableHandler; joinTableHandler != nil { + scope.Err(joinTableHandler.Add(joinTableHandler, newDB, scope.Value, newScope.Value)) + } + } + } + default: + elem := value.Addr().Interface() + newScope := scope.New(elem) + + if saveReference { + if len(relationship.ForeignFieldNames) != 0 { for idx, fieldName := range relationship.ForeignFieldNames { associationForeignName := relationship.AssociationForeignDBNames[idx] if f, ok := scope.FieldByName(associationForeignName); ok { @@ -69,32 +153,13 @@ func saveAfterAssociationsCallback(scope *Scope) { if relationship.PolymorphicType != "" { scope.Err(newScope.SetColumn(relationship.PolymorphicType, relationship.PolymorphicValue)) } - - if allowSaveAssociation { - scope.Err(newDB.Save(elem).Error) - } - - if joinTableHandler := relationship.JoinTableHandler; joinTableHandler != nil && !newScope.PrimaryKeyZero() { - scope.Err(joinTableHandler.Add(joinTableHandler, newDB, scope.Value, newScope.Value)) - } - } - default: - elem := value.Addr().Interface() - newScope := scope.New(elem) - if len(relationship.ForeignFieldNames) != 0 { - for idx, fieldName := range relationship.ForeignFieldNames { - associationForeignName := relationship.AssociationForeignDBNames[idx] - if f, ok := scope.FieldByName(associationForeignName); ok { - scope.Err(newScope.SetColumn(fieldName, f.Field.Interface())) - } - } } - if relationship.PolymorphicType != "" { - scope.Err(newScope.SetColumn(relationship.PolymorphicType, relationship.PolymorphicValue)) - } - - if allowSaveAssociation { + if newScope.PrimaryKeyZero() { + if autoCreate { + scope.Err(scope.NewDB().Save(elem).Error) + } + } else if autoUpdate { scope.Err(scope.NewDB().Save(elem).Error) } } diff --git a/query_test.go b/query_test.go index def84e04..98721800 100644 --- a/query_test.go +++ b/query_test.go @@ -389,7 +389,7 @@ func TestOffset(t *testing.T) { DB.Save(&User{Name: fmt.Sprintf("OffsetUser%v", i)}) } var users1, users2, users3, users4 []User - DB.Limit(100).Order("age desc").Find(&users1).Offset(3).Find(&users2).Offset(5).Find(&users3).Offset(-1).Find(&users4) + DB.Limit(100).Where("name like ?", "OffsetUser%").Order("age desc").Find(&users1).Offset(3).Find(&users2).Offset(5).Find(&users3).Offset(-1).Find(&users4) if (len(users1) != len(users4)) || (len(users1)-len(users2) != 3) || (len(users1)-len(users3) != 5) { t.Errorf("Offset should work") diff --git a/scope.go b/scope.go index 9ae33913..125e02b0 100644 --- a/scope.go +++ b/scope.go @@ -993,18 +993,6 @@ func (scope *Scope) changeableField(field *Field) bool { return true } -func (scope *Scope) allowSaveAssociations() bool { - if saveAssociations, ok := scope.Get("gorm:save_associations"); ok { - if v, ok := saveAssociations.(bool); ok && !v { - return false - } - if v, ok := saveAssociations.(string); ok && (v != "skip") { - return false - } - } - return true && !scope.HasError() -} - func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope { toScope := scope.db.NewScope(value) tx := scope.db.Set("gorm:association:source", scope.Value) From 2940c553eb9763e966effbdca702e2d5b2b255da Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 10 Feb 2018 18:01:41 +0800 Subject: [PATCH 7/8] Add DB setting gorm:association_save_reference --- callback_save.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/callback_save.go b/callback_save.go index 243c986e..ef267141 100644 --- a/callback_save.go +++ b/callback_save.go @@ -53,7 +53,9 @@ func saveAssociationCheck(scope *Scope, field *Field) (autoUpdate bool, autoCrea autoCreate = checkTruth(value) } - if value, ok := field.TagSettings["ASSOCIATION_SAVE_REFERENCE"]; ok { + if value, ok := scope.Get("gorm:association_save_reference"); ok { + saveReference = checkTruth(value) + } else if value, ok := field.TagSettings["ASSOCIATION_SAVE_REFERENCE"]; ok { saveReference = checkTruth(value) } } From c6ce739b2a4d3b26af9326a31723883b4f136a74 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 10 Feb 2018 19:25:58 +0800 Subject: [PATCH 8/8] Convert auto_increment's value to lower case when checking its value --- dialect_common.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dialect_common.go b/dialect_common.go index 1e5e3b61..fbbaef33 100644 --- a/dialect_common.go +++ b/dialect_common.go @@ -40,7 +40,7 @@ func (commonDialect) Quote(key string) string { func (s *commonDialect) fieldCanAutoIncrement(field *StructField) bool { if value, ok := field.TagSettings["AUTO_INCREMENT"]; ok { - return value != "FALSE" + return strings.ToLower(value) != "false" } return field.IsPrimaryKey }