diff --git a/association.go b/association.go index 3d522ccc..8c6d9864 100644 --- a/association.go +++ b/association.go @@ -107,7 +107,7 @@ func (association *Association) Replace(values ...interface{}) *Association { if sourcePrimaryKeys := scope.getColumnAsArray(sourceForeignFieldNames, scope.Value); len(sourcePrimaryKeys) > 0 { newDB = newDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(sourcePrimaryKeys)), toQueryValues(sourcePrimaryKeys)...) - association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB, relationship)) + association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB)) } } else if relationship.Kind == "has_one" || relationship.Kind == "has_many" { // has_one or has_many relations, set foreign key to be nil (TODO or delete them?) @@ -173,7 +173,7 @@ func (association *Association) Delete(values ...interface{}) *Association { sql := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(deletingPrimaryKeys)) newDB = newDB.Where(sql, toQueryValues(deletingPrimaryKeys)...) - association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB, relationship)) + association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB)) } else { var foreignKeyMap = map[string]interface{}{} for _, foreignKey := range relationship.ForeignDBNames { diff --git a/association_test.go b/association_test.go index c84f84ed..60d0cf48 100644 --- a/association_test.go +++ b/association_test.go @@ -885,7 +885,7 @@ func TestHasManyChildrenWithOneStruct(t *testing.T) { DB.Save(&category) } -func TestSkipSaveAssociation(t *testing.T) { +func TestAutoSaveBelongsToAssociation(t *testing.T) { type Company struct { gorm.Model Name string @@ -895,13 +895,156 @@ func TestSkipSaveAssociation(t *testing.T) { gorm.Model Name string CompanyID uint - Company Company `gorm:"save_associations:false"` + Company Company `gorm:"association_autoupdate:false;association_autocreate:false;"` } + + DB.Where("name = ?", "auto_save_association").Delete(&Company{}) DB.AutoMigrate(&Company{}, &User{}) - DB.Save(&User{Name: "jinzhu", Company: Company{Name: "skip_save_association"}}) + DB.Save(&User{Name: "jinzhu", Company: Company{Name: "auto_save_association"}}) - if !DB.Where("name = ?", "skip_save_association").First(&Company{}).RecordNotFound() { - t.Errorf("Company skip_save_association should not been saved") + if !DB.Where("name = ?", "auto_save_association").First(&Company{}).RecordNotFound() { + t.Errorf("Company auto_save_association should not have been saved when autosave is false") + } + + // if foreign key is set, this should be saved even if association isn't + company := Company{Name: "auto_save_association"} + DB.Save(&company) + + company.Name = "auto_save_association_new_name" + user := User{Name: "jinzhu", Company: company} + + DB.Save(&user) + + if !DB.Where("name = ?", "auto_save_association_new_name").First(&Company{}).RecordNotFound() { + t.Errorf("Company should not have been updated") + } + + if DB.Where("id = ? AND company_id = ?", user.ID, company.ID).First(&User{}).RecordNotFound() { + t.Errorf("User's foreign key should have been saved") + } + + user2 := User{Name: "jinzhu_2", Company: Company{Name: "auto_save_association_2"}} + DB.Set("gorm:association_autocreate", true).Save(&user2) + if DB.Where("name = ?", "auto_save_association_2").First(&Company{}).RecordNotFound() { + t.Errorf("Company auto_save_association_2 should been created when autocreate is true") + } + + user2.Company.Name = "auto_save_association_2_newname" + DB.Set("gorm:association_autoupdate", true).Save(&user2) + + if DB.Where("name = ?", "auto_save_association_2_newname").First(&Company{}).RecordNotFound() { + t.Errorf("Company should been updated") + } +} + +func TestAutoSaveHasOneAssociation(t *testing.T) { + type Company struct { + gorm.Model + UserID uint + Name string + } + + type User struct { + gorm.Model + Name string + Company Company `gorm:"association_autoupdate:false;association_autocreate:false;"` + } + + DB.Where("name = ?", "auto_save_has_one_association").Delete(&Company{}) + DB.AutoMigrate(&Company{}, &User{}) + + DB.Save(&User{Name: "jinzhu", Company: Company{Name: "auto_save_has_one_association"}}) + + if !DB.Where("name = ?", "auto_save_has_one_association").First(&Company{}).RecordNotFound() { + t.Errorf("Company auto_save_has_one_association should not have been saved when autosave is false") + } + + company := Company{Name: "auto_save_has_one_association"} + DB.Save(&company) + + company.Name = "auto_save_has_one_association_new_name" + user := User{Name: "jinzhu", Company: company} + + DB.Save(&user) + + if !DB.Where("name = ?", "auto_save_has_one_association_new_name").First(&Company{}).RecordNotFound() { + t.Errorf("Company should not have been updated") + } + + if !DB.Where("name = ? AND user_id = ?", "auto_save_has_one_association", user.ID).First(&Company{}).RecordNotFound() { + t.Errorf("Company should not have been updated") + } + + if user.Company.UserID == 0 { + t.Errorf("UserID should be assigned") + } + + company.Name = "auto_save_has_one_association_2_new_name" + DB.Set("gorm:association_autoupdate", true).Save(&user) + + if DB.Where("name = ? AND user_id = ?", "auto_save_has_one_association_new_name", user.ID).First(&Company{}).RecordNotFound() { + t.Errorf("Company should been updated") + } + + user2 := User{Name: "jinzhu_2", Company: Company{Name: "auto_save_has_one_association_2"}} + DB.Set("gorm:association_autocreate", true).Save(&user2) + if DB.Where("name = ?", "auto_save_has_one_association_2").First(&Company{}).RecordNotFound() { + t.Errorf("Company auto_save_has_one_association_2 should been created when autocreate is true") + } +} + +func TestAutoSaveMany2ManyAssociation(t *testing.T) { + type Company struct { + gorm.Model + Name string + } + + type User struct { + gorm.Model + Name string + Companies []Company `gorm:"many2many:user_companies;association_autoupdate:false;association_autocreate:false;"` + } + + DB.AutoMigrate(&Company{}, &User{}) + + DB.Save(&User{Name: "jinzhu", Companies: []Company{{Name: "auto_save_m2m_association"}}}) + + if !DB.Where("name = ?", "auto_save_m2m_association").First(&Company{}).RecordNotFound() { + t.Errorf("Company auto_save_m2m_association should not have been saved when autosave is false") + } + + company := Company{Name: "auto_save_m2m_association"} + DB.Save(&company) + + company.Name = "auto_save_m2m_association_new_name" + user := User{Name: "jinzhu", Companies: []Company{company, {Name: "auto_save_m2m_association_new_name_2"}}} + + DB.Save(&user) + + if !DB.Where("name = ?", "auto_save_m2m_association_new_name").First(&Company{}).RecordNotFound() { + t.Errorf("Company should not have been updated") + } + + if !DB.Where("name = ?", "auto_save_m2m_association_new_name_2").First(&Company{}).RecordNotFound() { + t.Errorf("Company should not been created") + } + + if DB.Model(&user).Association("Companies").Count() != 1 { + t.Errorf("Relationship should been saved") + } + + DB.Set("gorm:association_autoupdate", true).Set("gorm:association_autocreate", true).Save(&user) + + if DB.Where("name = ?", "auto_save_m2m_association_new_name").First(&Company{}).RecordNotFound() { + t.Errorf("Company should been updated") + } + + if DB.Where("name = ?", "auto_save_m2m_association_new_name_2").First(&Company{}).RecordNotFound() { + t.Errorf("Company should been created") + } + + if DB.Model(&user).Association("Companies").Count() != 2 { + t.Errorf("Relationship should been updated") } } diff --git a/callback_save.go b/callback_save.go index f4bc918e..ef267141 100644 --- a/callback_save.go +++ b/callback_save.go @@ -1,6 +1,9 @@ package gorm -import "reflect" +import ( + "reflect" + "strings" +) func beginTransactionCallback(scope *Scope) { scope.Begin() @@ -10,31 +13,81 @@ func commitOrRollbackTransactionCallback(scope *Scope) { scope.CommitOrRollback() } -func saveFieldAsAssociation(scope *Scope, field *Field) (bool, *Relationship) { +func saveAssociationCheck(scope *Scope, field *Field) (autoUpdate bool, autoCreate bool, saveReference bool, r *Relationship) { + checkTruth := func(value interface{}) bool { + if v, ok := value.(bool); ok && !v { + return false + } + + if v, ok := value.(string); ok { + v = strings.ToLower(v) + if v == "false" || v != "skip" { + return false + } + } + + return true + } + if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored { - if value, ok := field.TagSettings["SAVE_ASSOCIATIONS"]; !ok || (value != "false" && value != "skip") { - if relationship := field.Relationship; relationship != nil { - return true, relationship + if r = field.Relationship; r != nil { + autoUpdate, autoCreate, saveReference = true, true, true + + if value, ok := scope.Get("gorm:save_associations"); ok { + autoUpdate = checkTruth(value) + autoCreate = autoUpdate + } else if value, ok := field.TagSettings["SAVE_ASSOCIATIONS"]; ok { + autoUpdate = checkTruth(value) + autoCreate = autoUpdate + } + + if value, ok := scope.Get("gorm:association_autoupdate"); ok { + autoUpdate = checkTruth(value) + } else if value, ok := field.TagSettings["ASSOCIATION_AUTOUPDATE"]; ok { + autoUpdate = checkTruth(value) + } + + if value, ok := scope.Get("gorm:association_autocreate"); ok { + autoCreate = checkTruth(value) + } else if value, ok := field.TagSettings["ASSOCIATION_AUTOCREATE"]; ok { + autoCreate = checkTruth(value) + } + + if value, ok := scope.Get("gorm:association_save_reference"); ok { + saveReference = checkTruth(value) + } else if value, ok := field.TagSettings["ASSOCIATION_SAVE_REFERENCE"]; ok { + saveReference = checkTruth(value) } } } - return false, nil + + return } func saveBeforeAssociationsCallback(scope *Scope) { - if !scope.shouldSaveAssociations() { - return - } for _, field := range scope.Fields() { - if ok, relationship := saveFieldAsAssociation(scope, field); ok && relationship.Kind == "belongs_to" { + autoUpdate, autoCreate, saveReference, relationship := saveAssociationCheck(scope, field) + + if relationship != nil && relationship.Kind == "belongs_to" { fieldValue := field.Field.Addr().Interface() - scope.Err(scope.NewDB().Save(fieldValue).Error) - if len(relationship.ForeignFieldNames) != 0 { - // set value's foreign key - for idx, fieldName := range relationship.ForeignFieldNames { - associationForeignName := relationship.AssociationForeignDBNames[idx] - if foreignField, ok := scope.New(fieldValue).FieldByName(associationForeignName); ok { - scope.Err(scope.SetColumn(fieldName, foreignField.Field.Interface())) + newScope := scope.New(fieldValue) + + if newScope.PrimaryKeyZero() { + if autoCreate { + scope.Err(scope.NewDB().Save(fieldValue).Error) + } + } else if autoUpdate { + scope.Err(scope.NewDB().Save(fieldValue).Error) + } + + if saveReference { + if len(relationship.ForeignFieldNames) != 0 { + // set value's foreign key + for idx, fieldName := range relationship.ForeignFieldNames { + associationForeignName := relationship.AssociationForeignDBNames[idx] + if foreignField, ok := scope.New(fieldValue).FieldByName(associationForeignName); ok { + scope.Err(scope.SetColumn(fieldName, foreignField.Field.Interface())) + } } } } @@ -43,12 +96,10 @@ func saveBeforeAssociationsCallback(scope *Scope) { } func saveAfterAssociationsCallback(scope *Scope) { - if !scope.shouldSaveAssociations() { - return - } for _, field := range scope.Fields() { - if ok, relationship := saveFieldAsAssociation(scope, field); ok && - (relationship.Kind == "has_one" || relationship.Kind == "has_many" || relationship.Kind == "many_to_many") { + autoUpdate, autoCreate, saveReference, relationship := saveAssociationCheck(scope, field) + + if relationship != nil && (relationship.Kind == "has_one" || relationship.Kind == "has_many" || relationship.Kind == "many_to_many") { value := field.Field switch value.Kind() { @@ -58,7 +109,41 @@ func saveAfterAssociationsCallback(scope *Scope) { elem := value.Index(i).Addr().Interface() newScope := newDB.NewScope(elem) - if relationship.JoinTableHandler == nil && len(relationship.ForeignFieldNames) != 0 { + if saveReference { + if relationship.JoinTableHandler == nil && len(relationship.ForeignFieldNames) != 0 { + for idx, fieldName := range relationship.ForeignFieldNames { + associationForeignName := relationship.AssociationForeignDBNames[idx] + if f, ok := scope.FieldByName(associationForeignName); ok { + scope.Err(newScope.SetColumn(fieldName, f.Field.Interface())) + } + } + } + + if relationship.PolymorphicType != "" { + scope.Err(newScope.SetColumn(relationship.PolymorphicType, relationship.PolymorphicValue)) + } + } + + if newScope.PrimaryKeyZero() { + if autoCreate { + scope.Err(newDB.Save(elem).Error) + } + } else if autoUpdate { + scope.Err(newDB.Save(elem).Error) + } + + if !scope.New(newScope.Value).PrimaryKeyZero() && saveReference { + if joinTableHandler := relationship.JoinTableHandler; joinTableHandler != nil { + scope.Err(joinTableHandler.Add(joinTableHandler, newDB, scope.Value, newScope.Value)) + } + } + } + default: + elem := value.Addr().Interface() + newScope := scope.New(elem) + + if saveReference { + if len(relationship.ForeignFieldNames) != 0 { for idx, fieldName := range relationship.ForeignFieldNames { associationForeignName := relationship.AssociationForeignDBNames[idx] if f, ok := scope.FieldByName(associationForeignName); ok { @@ -70,29 +155,15 @@ func saveAfterAssociationsCallback(scope *Scope) { if relationship.PolymorphicType != "" { scope.Err(newScope.SetColumn(relationship.PolymorphicType, relationship.PolymorphicValue)) } + } - scope.Err(newDB.Save(elem).Error) - - if joinTableHandler := relationship.JoinTableHandler; joinTableHandler != nil { - scope.Err(joinTableHandler.Add(joinTableHandler, newDB, scope.Value, newScope.Value)) + if newScope.PrimaryKeyZero() { + if autoCreate { + scope.Err(scope.NewDB().Save(elem).Error) } + } else if autoUpdate { + scope.Err(scope.NewDB().Save(elem).Error) } - default: - elem := value.Addr().Interface() - newScope := scope.New(elem) - if len(relationship.ForeignFieldNames) != 0 { - for idx, fieldName := range relationship.ForeignFieldNames { - associationForeignName := relationship.AssociationForeignDBNames[idx] - if f, ok := scope.FieldByName(associationForeignName); ok { - scope.Err(newScope.SetColumn(fieldName, f.Field.Interface())) - } - } - } - - if relationship.PolymorphicType != "" { - scope.Err(newScope.SetColumn(relationship.PolymorphicType, relationship.PolymorphicValue)) - } - scope.Err(scope.NewDB().Save(elem).Error) } } } diff --git a/callback_update.go b/callback_update.go index 6948439f..373bd726 100644 --- a/callback_update.go +++ b/callback_update.go @@ -3,6 +3,7 @@ package gorm import ( "errors" "fmt" + "sort" "strings" ) @@ -59,7 +60,16 @@ func updateCallback(scope *Scope) { var sqls []string if updateAttrs, ok := scope.InstanceGet("gorm:update_attrs"); ok { - for column, value := range updateAttrs.(map[string]interface{}) { + // Sort the column names so that the generated SQL is the same every time. + updateMap := updateAttrs.(map[string]interface{}) + var columns []string + for c := range updateMap { + columns = append(columns, c) + } + sort.Strings(columns) + + for _, column := range columns { + value := updateMap[column] sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(column), scope.AddToVars(value))) } } else { diff --git a/customize_column_test.go b/customize_column_test.go index ddb536b8..5e19d6f4 100644 --- a/customize_column_test.go +++ b/customize_column_test.go @@ -279,3 +279,68 @@ func TestBelongsToWithPartialCustomizedColumn(t *testing.T) { t.Errorf("should preload discount from coupon") } } + +type SelfReferencingUser struct { + gorm.Model + Name string + Friends []*SelfReferencingUser `gorm:"many2many:UserFriends;association_jointable_foreignkey:friend_id"` +} + +func TestSelfReferencingMany2ManyColumn(t *testing.T) { + DB.DropTable(&SelfReferencingUser{}, "UserFriends") + DB.AutoMigrate(&SelfReferencingUser{}) + + friend1 := SelfReferencingUser{Name: "friend1_m2m"} + if err := DB.Create(&friend1).Error; err != nil { + t.Errorf("no error should happen, but got %v", err) + } + + friend2 := SelfReferencingUser{Name: "friend2_m2m"} + if err := DB.Create(&friend2).Error; err != nil { + t.Errorf("no error should happen, but got %v", err) + } + + user := SelfReferencingUser{ + Name: "self_m2m", + Friends: []*SelfReferencingUser{&friend1, &friend2}, + } + + if err := DB.Create(&user).Error; err != nil { + t.Errorf("no error should happen, but got %v", err) + } + + if DB.Model(&user).Association("Friends").Count() != 2 { + t.Errorf("Should find created friends correctly") + } + + var newUser = SelfReferencingUser{} + + if err := DB.Preload("Friends").First(&newUser, "id = ?", user.ID).Error; err != nil { + t.Errorf("no error should happen, but got %v", err) + } + + if len(newUser.Friends) != 2 { + t.Errorf("Should preload created frineds for self reference m2m") + } + + DB.Model(&newUser).Association("Friends").Append(&SelfReferencingUser{Name: "friend3_m2m"}) + if DB.Model(&user).Association("Friends").Count() != 3 { + t.Errorf("Should find created friends correctly") + } + + DB.Model(&newUser).Association("Friends").Replace(&SelfReferencingUser{Name: "friend4_m2m"}) + if DB.Model(&user).Association("Friends").Count() != 1 { + t.Errorf("Should find created friends correctly") + } + + friend := SelfReferencingUser{} + DB.Model(&newUser).Association("Friends").Find(&friend) + if friend.Name != "friend4_m2m" { + t.Errorf("Should find created friends correctly") + } + + DB.Model(&newUser).Association("Friends").Delete(friend) + if DB.Model(&user).Association("Friends").Count() != 0 { + t.Errorf("All friends should be deleted") + } +} diff --git a/dialect_common.go b/dialect_common.go index 1e5e3b61..fbbaef33 100644 --- a/dialect_common.go +++ b/dialect_common.go @@ -40,7 +40,7 @@ func (commonDialect) Quote(key string) string { func (s *commonDialect) fieldCanAutoIncrement(field *StructField) bool { if value, ok := field.TagSettings["AUTO_INCREMENT"]; ok { - return value != "FALSE" + return strings.ToLower(value) != "false" } return field.IsPrimaryKey } diff --git a/join_table_handler.go b/join_table_handler.go index 2d1a5055..f07541ba 100644 --- a/join_table_handler.go +++ b/join_table_handler.go @@ -82,38 +82,40 @@ func (s JoinTableHandler) Table(db *DB) string { return s.TableName } -func (s JoinTableHandler) getSearchMap(db *DB, sources ...interface{}) map[string]interface{} { - values := map[string]interface{}{} - +func (s JoinTableHandler) updateConditionMap(conditionMap map[string]interface{}, db *DB, joinTableSources []JoinTableSource, sources ...interface{}) { for _, source := range sources { scope := db.NewScope(source) modelType := scope.GetModelStruct().ModelType - if s.Source.ModelType == modelType { - for _, foreignKey := range s.Source.ForeignKeys { - if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok { - values[foreignKey.DBName] = field.Field.Interface() - } - } - } else if s.Destination.ModelType == modelType { - for _, foreignKey := range s.Destination.ForeignKeys { - if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok { - values[foreignKey.DBName] = field.Field.Interface() + for _, joinTableSource := range joinTableSources { + if joinTableSource.ModelType == modelType { + for _, foreignKey := range joinTableSource.ForeignKeys { + if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok { + conditionMap[foreignKey.DBName] = field.Field.Interface() + } } + break } } } - return values } // Add create relationship in join table for source and destination func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error { - scope := db.NewScope("") - searchMap := s.getSearchMap(db, source, destination) + var ( + scope = db.NewScope("") + conditionMap = map[string]interface{}{} + ) + + // Update condition map for source + s.updateConditionMap(conditionMap, db, []JoinTableSource{s.Source}, source) + + // Update condition map for destination + s.updateConditionMap(conditionMap, db, []JoinTableSource{s.Destination}, destination) var assignColumns, binVars, conditions []string var values []interface{} - for key, value := range searchMap { + for key, value := range conditionMap { assignColumns = append(assignColumns, scope.Quote(key)) binVars = append(binVars, `?`) conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key))) @@ -141,12 +143,15 @@ func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source // Delete delete relationship in join table for sources func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error { var ( - scope = db.NewScope(nil) - conditions []string - values []interface{} + scope = db.NewScope(nil) + conditions []string + values []interface{} + conditionMap = map[string]interface{}{} ) - for key, value := range s.getSearchMap(db, sources...) { + s.updateConditionMap(conditionMap, db, []JoinTableSource{s.Source, s.Destination}, sources...) + + for key, value := range conditionMap { conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key))) values = append(values, value) } diff --git a/main.go b/main.go index b23ae2f2..fc4859ac 100644 --- a/main.go +++ b/main.go @@ -274,7 +274,7 @@ func (s *DB) Assign(attrs ...interface{}) *DB { // First find first record that match given conditions, order by primary key func (s *DB) First(out interface{}, where ...interface{}) *DB { - newScope := s.clone().NewScope(out) + newScope := s.NewScope(out) newScope.Search.Limit(1) return newScope.Set("gorm:order_by_primary_key", "ASC"). inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db @@ -282,7 +282,7 @@ func (s *DB) First(out interface{}, where ...interface{}) *DB { // Last find last record that match given conditions, order by primary key func (s *DB) Last(out interface{}, where ...interface{}) *DB { - newScope := s.clone().NewScope(out) + newScope := s.NewScope(out) newScope.Search.Limit(1) return newScope.Set("gorm:order_by_primary_key", "DESC"). inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db @@ -290,12 +290,12 @@ func (s *DB) Last(out interface{}, where ...interface{}) *DB { // Find find records that match given conditions func (s *DB) Find(out interface{}, where ...interface{}) *DB { - return s.clone().NewScope(out).inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db + return s.NewScope(out).inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db } // Scan scan value to a struct func (s *DB) Scan(dest interface{}) *DB { - return s.clone().NewScope(s.Value).Set("gorm:query_destination", dest).callCallbacks(s.parent.callbacks.queries).db + return s.NewScope(s.Value).Set("gorm:query_destination", dest).callCallbacks(s.parent.callbacks.queries).db } // Row return `*sql.Row` with given conditions @@ -311,8 +311,8 @@ func (s *DB) Rows() (*sql.Rows, error) { // ScanRows scan `*sql.Rows` to give struct func (s *DB) ScanRows(rows *sql.Rows, result interface{}) error { var ( - clone = s.clone() - scope = clone.NewScope(result) + scope = s.NewScope(result) + clone = scope.db columns, err = rows.Columns() ) @@ -337,7 +337,7 @@ func (s *DB) Count(value interface{}) *DB { // Related get related associations func (s *DB) Related(value interface{}, foreignKeys ...string) *DB { - return s.clone().NewScope(s.Value).related(value, foreignKeys...).db + return s.NewScope(s.Value).related(value, foreignKeys...).db } // FirstOrInit find first matched record or initialize a new one with given conditions (only works with struct, map conditions) @@ -377,7 +377,7 @@ func (s *DB) Update(attrs ...interface{}) *DB { // Updates update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update func (s *DB) Updates(values interface{}, ignoreProtectedAttrs ...bool) *DB { - return s.clone().NewScope(s.Value). + return s.NewScope(s.Value). Set("gorm:ignore_protected_attrs", len(ignoreProtectedAttrs) > 0). InstanceSet("gorm:update_interface", values). callCallbacks(s.parent.callbacks.updates).db @@ -390,7 +390,7 @@ func (s *DB) UpdateColumn(attrs ...interface{}) *DB { // UpdateColumns update attributes without callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update func (s *DB) UpdateColumns(values interface{}) *DB { - return s.clone().NewScope(s.Value). + return s.NewScope(s.Value). Set("gorm:update_column", true). Set("gorm:save_associations", false). InstanceSet("gorm:update_interface", values). @@ -399,7 +399,7 @@ func (s *DB) UpdateColumns(values interface{}) *DB { // Save update value in database, if the value doesn't have primary key, will insert it func (s *DB) Save(value interface{}) *DB { - scope := s.clone().NewScope(value) + scope := s.NewScope(value) if !scope.PrimaryKeyZero() { newDB := scope.callCallbacks(s.parent.callbacks.updates).db if newDB.Error == nil && newDB.RowsAffected == 0 { @@ -412,13 +412,13 @@ func (s *DB) Save(value interface{}) *DB { // Create insert the value into database func (s *DB) Create(value interface{}) *DB { - scope := s.clone().NewScope(value) + scope := s.NewScope(value) return scope.callCallbacks(s.parent.callbacks.creates).db } // Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition func (s *DB) Delete(value interface{}, where ...interface{}) *DB { - return s.clone().NewScope(value).inlineCondition(where...).callCallbacks(s.parent.callbacks.deletes).db + return s.NewScope(value).inlineCondition(where...).callCallbacks(s.parent.callbacks.deletes).db } // Raw use raw sql as conditions, won't run it unless invoked by other methods @@ -429,7 +429,7 @@ func (s *DB) Raw(sql string, values ...interface{}) *DB { // Exec execute raw sql func (s *DB) Exec(sql string, values ...interface{}) *DB { - scope := s.clone().NewScope(nil) + scope := s.NewScope(nil) generatedSQL := scope.buildWhereCondition(map[string]interface{}{"query": sql, "args": values}) generatedSQL = strings.TrimSuffix(strings.TrimPrefix(generatedSQL, "("), ")") scope.Raw(generatedSQL) @@ -495,7 +495,7 @@ func (s *DB) Rollback() *DB { // NewRecord check if value's primary key is blank func (s *DB) NewRecord(value interface{}) bool { - return s.clone().NewScope(value).PrimaryKeyZero() + return s.NewScope(value).PrimaryKeyZero() } // RecordNotFound check if returning ErrRecordNotFound error @@ -544,7 +544,7 @@ func (s *DB) DropTableIfExists(values ...interface{}) *DB { // HasTable check has table or not func (s *DB) HasTable(value interface{}) bool { var ( - scope = s.clone().NewScope(value) + scope = s.NewScope(value) tableName string ) @@ -570,14 +570,14 @@ func (s *DB) AutoMigrate(values ...interface{}) *DB { // ModifyColumn modify column to type func (s *DB) ModifyColumn(column string, typ string) *DB { - scope := s.clone().NewScope(s.Value) + scope := s.NewScope(s.Value) scope.modifyColumn(column, typ) return scope.db } // DropColumn drop a column func (s *DB) DropColumn(column string) *DB { - scope := s.clone().NewScope(s.Value) + scope := s.NewScope(s.Value) scope.dropColumn(column) return scope.db } @@ -598,7 +598,7 @@ func (s *DB) AddUniqueIndex(indexName string, columns ...string) *DB { // RemoveIndex remove index with name func (s *DB) RemoveIndex(indexName string) *DB { - scope := s.clone().NewScope(s.Value) + scope := s.NewScope(s.Value) scope.removeIndex(indexName) return scope.db } @@ -606,7 +606,7 @@ func (s *DB) RemoveIndex(indexName string) *DB { // AddForeignKey Add foreign key to the given scope, e.g: // db.Model(&User{}).AddForeignKey("city_id", "cities(id)", "RESTRICT", "RESTRICT") func (s *DB) AddForeignKey(field string, dest string, onDelete string, onUpdate string) *DB { - scope := s.clone().NewScope(s.Value) + scope := s.NewScope(s.Value) scope.addForeignKey(field, dest, onDelete, onUpdate) return scope.db } diff --git a/migration_test.go b/migration_test.go index 6b4470a6..d58e1fb5 100644 --- a/migration_test.go +++ b/migration_test.go @@ -33,6 +33,7 @@ type User struct { CompanyID *int Company Company Role Role + Password EncryptedData PasswordHash []byte IgnoreMe int64 `sql:"-"` IgnoreStringSlice []string `sql:"-"` @@ -116,6 +117,31 @@ type Company struct { Owner *User `sql:"-"` } +type EncryptedData []byte + +func (data *EncryptedData) Scan(value interface{}) error { + if b, ok := value.([]byte); ok { + if len(b) < 3 || b[0] != '*' || b[1] != '*' || b[2] != '*' { + return errors.New("Too short") + } + + *data = b[3:] + return nil + } + + return errors.New("Bytes expected") +} + +func (data EncryptedData) Value() (driver.Value, error) { + if len(data) > 0 && data[0] == 'x' { + //needed to test failures + return nil, errors.New("Should not start with 'x'") + } + + //prepend asterisks + return append([]byte("***"), data...), nil +} + type Role struct { Name string `gorm:"size:256"` } diff --git a/model_struct.go b/model_struct.go index 315028c4..f571e2e8 100644 --- a/model_struct.go +++ b/model_struct.go @@ -249,11 +249,13 @@ func (scope *Scope) GetModelStruct() *ModelStruct { ) if foreignKey := field.TagSettings["FOREIGNKEY"]; foreignKey != "" { - foreignKeys = strings.Split(field.TagSettings["FOREIGNKEY"], ",") + foreignKeys = strings.Split(foreignKey, ",") } - if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" { - associationForeignKeys = strings.Split(field.TagSettings["ASSOCIATIONFOREIGNKEY"], ",") + if foreignKey := field.TagSettings["ASSOCIATION_FOREIGNKEY"]; foreignKey != "" { + associationForeignKeys = strings.Split(foreignKey, ",") + } else if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" { + associationForeignKeys = strings.Split(foreignKey, ",") } for elemType.Kind() == reflect.Slice || elemType.Kind() == reflect.Ptr { @@ -264,37 +266,65 @@ func (scope *Scope) GetModelStruct() *ModelStruct { if many2many := field.TagSettings["MANY2MANY"]; many2many != "" { relationship.Kind = "many_to_many" - // if no foreign keys defined with tag - if len(foreignKeys) == 0 { - for _, field := range modelStruct.PrimaryFields { - foreignKeys = append(foreignKeys, field.DBName) + { // Foreign Keys for Source + joinTableDBNames := []string{} + + if foreignKey := field.TagSettings["JOINTABLE_FOREIGNKEY"]; foreignKey != "" { + joinTableDBNames = strings.Split(foreignKey, ",") + } + + // if no foreign keys defined with tag + if len(foreignKeys) == 0 { + for _, field := range modelStruct.PrimaryFields { + foreignKeys = append(foreignKeys, field.DBName) + } + } + + for idx, foreignKey := range foreignKeys { + if foreignField := getForeignField(foreignKey, modelStruct.StructFields); foreignField != nil { + // source foreign keys (db names) + relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.DBName) + + // setup join table foreign keys for source + if len(joinTableDBNames) > idx { + // if defined join table's foreign key + relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBNames[idx]) + } else { + defaultJointableForeignKey := ToDBName(reflectType.Name()) + "_" + foreignField.DBName + relationship.ForeignDBNames = append(relationship.ForeignDBNames, defaultJointableForeignKey) + } + } } } - for _, foreignKey := range foreignKeys { - if foreignField := getForeignField(foreignKey, modelStruct.StructFields); foreignField != nil { - // source foreign keys (db names) - relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.DBName) - // join table foreign keys for source - joinTableDBName := ToDBName(reflectType.Name()) + "_" + foreignField.DBName - relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBName) - } - } + { // Foreign Keys for Association (Destination) + associationJoinTableDBNames := []string{} - // if no association foreign keys defined with tag - if len(associationForeignKeys) == 0 { - for _, field := range toScope.PrimaryFields() { - associationForeignKeys = append(associationForeignKeys, field.DBName) + if foreignKey := field.TagSettings["ASSOCIATION_JOINTABLE_FOREIGNKEY"]; foreignKey != "" { + associationJoinTableDBNames = strings.Split(foreignKey, ",") } - } - for _, name := range associationForeignKeys { - if field, ok := toScope.FieldByName(name); ok { - // association foreign keys (db names) - relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.DBName) - // join table foreign keys for association - joinTableDBName := ToDBName(elemType.Name()) + "_" + field.DBName - relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, joinTableDBName) + // if no association foreign keys defined with tag + if len(associationForeignKeys) == 0 { + for _, field := range toScope.PrimaryFields() { + associationForeignKeys = append(associationForeignKeys, field.DBName) + } + } + + for idx, name := range associationForeignKeys { + if field, ok := toScope.FieldByName(name); ok { + // association foreign keys (db names) + relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.DBName) + + // setup join table foreign keys for association + if len(associationJoinTableDBNames) > idx { + relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationJoinTableDBNames[idx]) + } else { + // join table foreign keys for association + joinTableDBName := ToDBName(elemType.Name()) + "_" + field.DBName + relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, joinTableDBName) + } + } } } @@ -399,11 +429,13 @@ func (scope *Scope) GetModelStruct() *ModelStruct { ) if foreignKey := field.TagSettings["FOREIGNKEY"]; foreignKey != "" { - tagForeignKeys = strings.Split(field.TagSettings["FOREIGNKEY"], ",") + tagForeignKeys = strings.Split(foreignKey, ",") } - if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" { - tagAssociationForeignKeys = strings.Split(field.TagSettings["ASSOCIATIONFOREIGNKEY"], ",") + if foreignKey := field.TagSettings["ASSOCIATION_FOREIGNKEY"]; foreignKey != "" { + tagAssociationForeignKeys = strings.Split(foreignKey, ",") + } else if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" { + tagAssociationForeignKeys = strings.Split(foreignKey, ",") } if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" { diff --git a/query_test.go b/query_test.go index def84e04..135805a7 100644 --- a/query_test.go +++ b/query_test.go @@ -2,6 +2,7 @@ package gorm_test import ( "fmt" + "os" "reflect" "github.com/jinzhu/gorm" @@ -389,7 +390,7 @@ func TestOffset(t *testing.T) { DB.Save(&User{Name: fmt.Sprintf("OffsetUser%v", i)}) } var users1, users2, users3, users4 []User - DB.Limit(100).Order("age desc").Find(&users1).Offset(3).Find(&users2).Offset(5).Find(&users3).Offset(-1).Find(&users4) + DB.Limit(100).Where("name like ?", "OffsetUser%").Order("age desc").Find(&users1).Offset(3).Find(&users2).Offset(5).Find(&users3).Offset(-1).Find(&users4) if (len(users1) != len(users4)) || (len(users1)-len(users2) != 3) || (len(users1)-len(users3) != 5) { t.Errorf("Offset should work") @@ -430,6 +431,15 @@ func TestCount(t *testing.T) { if count1 != 1 || count2 != 3 { t.Errorf("Multiple count in chain") } + + var count3 int + if err := DB.Model(&User{}).Where("name in (?)", []string{user2.Name, user2.Name, user3.Name}).Group("id").Count(&count3).Error; err != nil { + t.Errorf("Not error should happen, but got %v", err) + } + + if count3 != 2 { + t.Errorf("Should get correct count, but got %v", count3) + } } func TestNot(t *testing.T) { @@ -665,3 +675,39 @@ func TestSelectWithArrayInput(t *testing.T) { t.Errorf("Should have selected both age and name") } } + +func TestPluckWithSelect(t *testing.T) { + var ( + user = User{Name: "matematik7_pluck_with_select", Age: 25} + combinedName = fmt.Sprintf("%v%v", user.Name, user.Age) + combineUserAgeSQL = fmt.Sprintf("concat(%v, %v)", DB.Dialect().Quote("name"), DB.Dialect().Quote("age")) + ) + + if dialect := os.Getenv("GORM_DIALECT"); dialect == "sqlite" { + combineUserAgeSQL = fmt.Sprintf("(%v || %v)", DB.Dialect().Quote("name"), DB.Dialect().Quote("age")) + } + + DB.Save(&user) + + selectStr := combineUserAgeSQL + " as user_age" + var userAges []string + err := DB.Model(&User{}).Where("age = ?", 25).Select(selectStr).Pluck("user_age", &userAges).Error + if err != nil { + t.Error(err) + } + + if len(userAges) != 1 || userAges[0] != combinedName { + t.Errorf("Should correctly pluck with select, got: %s", userAges) + } + + selectStr = combineUserAgeSQL + fmt.Sprintf(" as %v", DB.Dialect().Quote("user_age")) + userAges = userAges[:0] + err = DB.Model(&User{}).Where("age = ?", 25).Select(selectStr).Pluck("user_age", &userAges).Error + if err != nil { + t.Error(err) + } + + if len(userAges) != 1 || userAges[0] != combinedName { + t.Errorf("Should correctly pluck with select, got: %s", userAges) + } +} diff --git a/scope.go b/scope.go index 0850de00..9b5da837 100644 --- a/scope.go +++ b/scope.go @@ -558,9 +558,13 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri replacements := []string{} args := clause["args"].([]interface{}) for _, arg := range args { + var err error switch reflect.ValueOf(arg).Kind() { case reflect.Slice: // For where("id in (?)", []int64{1,2}) - if bytes, ok := arg.([]byte); ok { + if scanner, ok := interface{}(arg).(driver.Valuer); ok { + arg, err = scanner.Value() + replacements = append(replacements, scope.AddToVars(arg)) + } else if bytes, ok := arg.([]byte); ok { replacements = append(replacements, scope.AddToVars(bytes)) } else if values := reflect.ValueOf(arg); values.Len() > 0 { var tempMarks []string @@ -573,11 +577,14 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri } default: if valuer, ok := interface{}(arg).(driver.Valuer); ok { - arg, _ = valuer.Value() + arg, err = valuer.Value() } replacements = append(replacements, scope.AddToVars(arg)) } + if err != nil { + scope.Err(err) + } } buff := gobytes.NewBuffer([]byte{}) @@ -644,9 +651,13 @@ func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string args := clause["args"].([]interface{}) for _, arg := range args { + var err error switch reflect.ValueOf(arg).Kind() { case reflect.Slice: // For where("id in (?)", []int64{1,2}) - if bytes, ok := arg.([]byte); ok { + if scanner, ok := interface{}(arg).(driver.Valuer); ok { + arg, err = scanner.Value() + str = strings.Replace(str, "?", scope.AddToVars(arg), 1) + } else if bytes, ok := arg.([]byte); ok { str = strings.Replace(str, "?", scope.AddToVars(bytes), 1) } else if values := reflect.ValueOf(arg); values.Len() > 0 { var tempMarks []string @@ -659,10 +670,13 @@ func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string } default: if scanner, ok := interface{}(arg).(driver.Valuer); ok { - arg, _ = scanner.Value() + arg, err = scanner.Value() } str = strings.Replace(notEqualSQL, "?", scope.AddToVars(arg), 1) } + if err != nil { + scope.Err(err) + } } return } @@ -954,14 +968,34 @@ func (scope *Scope) initialize() *Scope { return scope } +func (scope *Scope) isQueryForColumn(query interface{}, column string) bool { + queryStr := strings.ToLower(fmt.Sprint(query)) + if queryStr == column { + return true + } + + if strings.HasSuffix(queryStr, "as "+column) { + return true + } + + if strings.HasSuffix(queryStr, "as "+scope.Quote(column)) { + return true + } + + return false +} + func (scope *Scope) pluck(column string, value interface{}) *Scope { dest := reflect.Indirect(reflect.ValueOf(value)) - scope.Search.Select(column) if dest.Kind() != reflect.Slice { scope.Err(fmt.Errorf("results should be a slice, not %s", dest.Kind())) return scope } + if query, ok := scope.Search.selects["query"]; !ok || !scope.isQueryForColumn(query, column) { + scope.Search.Select(column) + } + rows, err := scope.rows() if scope.Err(err) == nil { defer rows.Close() @@ -980,7 +1014,12 @@ func (scope *Scope) pluck(column string, value interface{}) *Scope { func (scope *Scope) count(value interface{}) *Scope { if query, ok := scope.Search.selects["query"]; !ok || !countingQueryRegexp.MatchString(fmt.Sprint(query)) { - scope.Search.Select("count(*)") + if len(scope.Search.group) != 0 { + scope.Search.Select("count(*) FROM ( SELECT count(*) as name ") + scope.Search.group += " ) AS count_table" + } else { + scope.Search.Select("count(*)") + } } scope.Search.ignoreOrderQuery = true scope.Err(scope.row().Scan(value)) @@ -1023,18 +1062,6 @@ func (scope *Scope) changeableField(field *Field) bool { return true } -func (scope *Scope) shouldSaveAssociations() bool { - if saveAssociations, ok := scope.Get("gorm:save_associations"); ok { - if v, ok := saveAssociations.(bool); ok && !v { - return false - } - if v, ok := saveAssociations.(string); ok && (v != "skip") { - return false - } - } - return true && !scope.HasError() -} - func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope { toScope := scope.db.NewScope(value) tx := scope.db.Set("gorm:association:source", scope.Value) diff --git a/scope_test.go b/scope_test.go index 42458995..3018f350 100644 --- a/scope_test.go +++ b/scope_test.go @@ -1,8 +1,12 @@ package gorm_test import ( - "github.com/jinzhu/gorm" + "encoding/hex" + "math/rand" + "strings" "testing" + + "github.com/jinzhu/gorm" ) func NameIn1And2(d *gorm.DB) *gorm.DB { @@ -41,3 +45,36 @@ func TestScopes(t *testing.T) { t.Errorf("Should found two users's name in 1, 3") } } + +func randName() string { + data := make([]byte, 8) + rand.Read(data) + + return "n-" + hex.EncodeToString(data) +} + +func TestValuer(t *testing.T) { + name := randName() + + origUser := User{Name: name, Age: 1, Password: EncryptedData("pass1"), PasswordHash: []byte("abc")} + if err := DB.Save(&origUser).Error; err != nil { + t.Errorf("No error should happen when saving user, but got %v", err) + } + + var user2 User + if err := DB.Where("name = ? AND password = ? AND password_hash = ?", name, EncryptedData("pass1"), []byte("abc")).First(&user2).Error; err != nil { + t.Errorf("No error should happen when querying user with valuer, but got %v", err) + } +} + +func TestFailedValuer(t *testing.T) { + name := randName() + + err := DB.Exec("INSERT INTO users(name, password) VALUES(?, ?)", name, EncryptedData("xpass1")).Error + + if err == nil { + t.Errorf("There should be an error should happen when insert data") + } else if !strings.HasPrefix(err.Error(), "Should not start with") { + t.Errorf("The error should be returned from Valuer, but get %v", err) + } +}