Add more updates test
This commit is contained in:
		
							parent
							
								
									dffc2713f0
								
							
						
					
					
						commit
						1559fe24e5
					
				| @ -86,6 +86,14 @@ func (association *Association) Replace(values ...interface{}) error { | ||||
| 		case schema.BelongsTo: | ||||
| 			if len(values) == 0 { | ||||
| 				updateMap := map[string]interface{}{} | ||||
| 				switch reflectValue.Kind() { | ||||
| 				case reflect.Slice, reflect.Array: | ||||
| 					for i := 0; i < reflectValue.Len(); i++ { | ||||
| 						rel.Field.Set(reflectValue.Index(i), reflect.Zero(rel.Field.FieldType).Interface()) | ||||
| 					} | ||||
| 				case reflect.Struct: | ||||
| 					rel.Field.Set(reflectValue, reflect.Zero(rel.Field.FieldType).Interface()) | ||||
| 				} | ||||
| 
 | ||||
| 				for _, ref := range rel.References { | ||||
| 					updateMap[ref.ForeignKey.DBName] = nil | ||||
|  | ||||
| @ -24,6 +24,13 @@ func SaveBeforeAssociations(db *gorm.DB) { | ||||
| 					if !ref.OwnPrimaryKey { | ||||
| 						pv, _ := ref.PrimaryKey.ValueOf(elem) | ||||
| 						ref.ForeignKey.Set(obj, pv) | ||||
| 
 | ||||
| 						if dest, ok := db.Statement.Dest.(map[string]interface{}); ok { | ||||
| 							dest[ref.ForeignKey.DBName] = pv | ||||
| 							if _, ok := dest[rel.Name]; ok { | ||||
| 								dest[rel.Name] = elem.Interface() | ||||
| 							} | ||||
| 						} | ||||
| 					} | ||||
| 				} | ||||
| 			} | ||||
|  | ||||
| @ -37,6 +37,7 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) { | ||||
| 
 | ||||
| 	updateCallback := db.Callback().Update() | ||||
| 	updateCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction) | ||||
| 	updateCallback.Register("gorm:setup_reflect_value", SetupUpdateReflectValue) | ||||
| 	updateCallback.Register("gorm:before_update", BeforeUpdate) | ||||
| 	updateCallback.Register("gorm:save_before_associations", SaveBeforeAssociations) | ||||
| 	updateCallback.Register("gorm:update", Update) | ||||
|  | ||||
| @ -37,6 +37,19 @@ func Query(db *gorm.DB) { | ||||
| func BuildQuerySQL(db *gorm.DB) { | ||||
| 	clauseSelect := clause.Select{} | ||||
| 
 | ||||
| 	if db.Statement.ReflectValue.Kind() == reflect.Struct { | ||||
| 		var conds []clause.Expression | ||||
| 		for _, primaryField := range db.Statement.Schema.PrimaryFields { | ||||
| 			if v, isZero := primaryField.ValueOf(db.Statement.ReflectValue); !isZero { | ||||
| 				conds = append(conds, clause.Eq{Column: clause.Column{Table: db.Statement.Table, Name: primaryField.DBName}, Value: v}) | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		if len(conds) > 0 { | ||||
| 			db.Statement.AddClause(clause.Where{Exprs: conds}) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	if len(db.Statement.Selects) > 0 { | ||||
| 		for _, name := range db.Statement.Selects { | ||||
| 			if db.Statement.Schema == nil { | ||||
|  | ||||
| @ -9,6 +9,25 @@ import ( | ||||
| 	"github.com/jinzhu/gorm/schema" | ||||
| ) | ||||
| 
 | ||||
| func SetupUpdateReflectValue(db *gorm.DB) { | ||||
| 	if db.Error == nil { | ||||
| 		if !db.Statement.ReflectValue.CanAddr() || db.Statement.Model != db.Statement.Dest { | ||||
| 			db.Statement.ReflectValue = reflect.ValueOf(db.Statement.Model) | ||||
| 			for db.Statement.ReflectValue.Kind() == reflect.Ptr { | ||||
| 				db.Statement.ReflectValue = db.Statement.ReflectValue.Elem() | ||||
| 			} | ||||
| 
 | ||||
| 			if dest, ok := db.Statement.Dest.(map[string]interface{}); ok { | ||||
| 				for _, rel := range db.Statement.Schema.Relationships.BelongsTo { | ||||
| 					if _, ok := dest[rel.Name]; ok { | ||||
| 						rel.Field.Set(db.Statement.ReflectValue, dest[rel.Name]) | ||||
| 					} | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func BeforeUpdate(db *gorm.DB) { | ||||
| 	if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) { | ||||
| 		tx := db.Session(&gorm.Session{}) | ||||
| @ -114,21 +133,20 @@ func AfterUpdate(db *gorm.DB) { | ||||
| func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { | ||||
| 	var ( | ||||
| 		selectColumns, restricted = SelectAndOmitColumns(stmt, false, true) | ||||
| 		reflectModelValue         = reflect.Indirect(reflect.ValueOf(stmt.Model)) | ||||
| 		assignValue               func(field *schema.Field, value interface{}) | ||||
| 	) | ||||
| 
 | ||||
| 	switch reflectModelValue.Kind() { | ||||
| 	switch stmt.ReflectValue.Kind() { | ||||
| 	case reflect.Slice, reflect.Array: | ||||
| 		assignValue = func(field *schema.Field, value interface{}) { | ||||
| 			for i := 0; i < reflectModelValue.Len(); i++ { | ||||
| 				field.Set(reflectModelValue.Index(i), value) | ||||
| 			for i := 0; i < stmt.ReflectValue.Len(); i++ { | ||||
| 				field.Set(stmt.ReflectValue.Index(i), value) | ||||
| 			} | ||||
| 		} | ||||
| 	case reflect.Struct: | ||||
| 		assignValue = func(field *schema.Field, value interface{}) { | ||||
| 			if reflectModelValue.CanAddr() { | ||||
| 				field.Set(reflectModelValue, value) | ||||
| 			if stmt.ReflectValue.CanAddr() { | ||||
| 				field.Set(stmt.ReflectValue, value) | ||||
| 			} | ||||
| 		} | ||||
| 	default: | ||||
| @ -136,7 +154,12 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	switch value := stmt.Dest.(type) { | ||||
| 	updatingValue := reflect.ValueOf(stmt.Dest) | ||||
| 	for updatingValue.Kind() == reflect.Ptr { | ||||
| 		updatingValue = updatingValue.Elem() | ||||
| 	} | ||||
| 
 | ||||
| 	switch value := updatingValue.Interface().(type) { | ||||
| 	case map[string]interface{}: | ||||
| 		set = make([]clause.Assignment, 0, len(value)) | ||||
| 
 | ||||
| @ -148,8 +171,12 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { | ||||
| 
 | ||||
| 		for _, k := range keys { | ||||
| 			if field := stmt.Schema.LookUpField(k); field != nil { | ||||
| 				if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { | ||||
| 					set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value[k]}) | ||||
| 				if field.DBName != "" { | ||||
| 					if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { | ||||
| 						set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value[k]}) | ||||
| 						assignValue(field, value[k]) | ||||
| 					} | ||||
| 				} else if v, ok := selectColumns[field.Name]; (ok && v) || (!ok && !restricted) { | ||||
| 					assignValue(field, value[k]) | ||||
| 				} | ||||
| 			} else if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) { | ||||
| @ -167,13 +194,13 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { | ||||
| 			} | ||||
| 		} | ||||
| 	default: | ||||
| 		switch stmt.ReflectValue.Kind() { | ||||
| 		switch updatingValue.Kind() { | ||||
| 		case reflect.Struct: | ||||
| 			set = make([]clause.Assignment, 0, len(stmt.Schema.FieldsByDBName)) | ||||
| 			for _, field := range stmt.Schema.FieldsByDBName { | ||||
| 				if !field.PrimaryKey || (!stmt.ReflectValue.CanAddr() || stmt.Dest != stmt.Model) { | ||||
| 				if !field.PrimaryKey || (!updatingValue.CanAddr() || stmt.Dest != stmt.Model) { | ||||
| 					if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { | ||||
| 						value, isZero := field.ValueOf(stmt.ReflectValue) | ||||
| 						value, isZero := field.ValueOf(updatingValue) | ||||
| 						if !stmt.DisableUpdateTime { | ||||
| 							if field.AutoUpdateTime > 0 { | ||||
| 								value = stmt.DB.NowFunc() | ||||
| @ -187,7 +214,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { | ||||
| 						} | ||||
| 					} | ||||
| 				} else { | ||||
| 					if value, isZero := field.ValueOf(stmt.ReflectValue); !isZero { | ||||
| 					if value, isZero := field.ValueOf(updatingValue); !isZero { | ||||
| 						stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}}) | ||||
| 					} | ||||
| 				} | ||||
| @ -195,16 +222,15 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	if !stmt.ReflectValue.CanAddr() || stmt.Dest != stmt.Model { | ||||
| 		reflectValue := reflect.Indirect(reflect.ValueOf(stmt.Model)) | ||||
| 		switch reflectValue.Kind() { | ||||
| 	if !updatingValue.CanAddr() || stmt.Dest != stmt.Model { | ||||
| 		switch stmt.ReflectValue.Kind() { | ||||
| 		case reflect.Slice, reflect.Array: | ||||
| 			var priamryKeyExprs []clause.Expression | ||||
| 			for i := 0; i < reflectValue.Len(); i++ { | ||||
| 			for i := 0; i < stmt.ReflectValue.Len(); i++ { | ||||
| 				var exprs = make([]clause.Expression, len(stmt.Schema.PrimaryFields)) | ||||
| 				var notZero bool | ||||
| 				for idx, field := range stmt.Schema.PrimaryFields { | ||||
| 					value, isZero := field.ValueOf(reflectValue.Index(i)) | ||||
| 					value, isZero := field.ValueOf(stmt.ReflectValue.Index(i)) | ||||
| 					exprs[idx] = clause.Eq{Column: field.DBName, Value: value} | ||||
| 					notZero = notZero || !isZero | ||||
| 				} | ||||
| @ -215,7 +241,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { | ||||
| 			stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(priamryKeyExprs...)}}) | ||||
| 		case reflect.Struct: | ||||
| 			for _, field := range stmt.Schema.PrimaryFields { | ||||
| 				if value, isZero := field.ValueOf(reflectValue); !isZero { | ||||
| 				if value, isZero := field.ValueOf(stmt.ReflectValue); !isZero { | ||||
| 					stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}}) | ||||
| 				} | ||||
| 			} | ||||
|  | ||||
| @ -347,6 +347,8 @@ func (field *Field) setupValuerAndSetter() { | ||||
| 					if v.Type().Elem().Kind() == reflect.Struct { | ||||
| 						if !v.IsNil() { | ||||
| 							v = v.Elem() | ||||
| 						} else { | ||||
| 							return nil, true | ||||
| 						} | ||||
| 					} else { | ||||
| 						return nil, true | ||||
|  | ||||
| @ -8,7 +8,7 @@ import ( | ||||
| 
 | ||||
| func AssertAssociationCount(t *testing.T, data interface{}, name string, result int64, reason string) { | ||||
| 	if count := DB.Model(data).Association(name).Count(); count != result { | ||||
| 		t.Errorf("invalid %v count %v, expects: %v got %v", name, reason, result, count) | ||||
| 		t.Fatalf("invalid %v count %v, expects: %v got %v", name, reason, result, count) | ||||
| 	} | ||||
| 
 | ||||
| 	var newUser User | ||||
| @ -20,7 +20,7 @@ func AssertAssociationCount(t *testing.T, data interface{}, name string, result | ||||
| 
 | ||||
| 	if newUser.ID != 0 { | ||||
| 		if count := DB.Model(&newUser).Association(name).Count(); count != result { | ||||
| 			t.Errorf("invalid %v count %v, expects: %v got %v", name, reason, result, count) | ||||
| 			t.Fatalf("invalid %v count %v, expects: %v got %v", name, reason, result, count) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| @ -28,6 +28,6 @@ func AssertAssociationCount(t *testing.T, data interface{}, name string, result | ||||
| func TestInvalidAssociation(t *testing.T) { | ||||
| 	var user = *GetUser("invalid", Config{Company: true, Manager: true}) | ||||
| 	if err := DB.Model(&user).Association("Invalid").Find(&user.Company).Error; err == nil { | ||||
| 		t.Errorf("should return errors for invalid association, but got nil") | ||||
| 		t.Fatalf("should return errors for invalid association, but got nil") | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @ -31,12 +31,14 @@ func TestDelete(t *testing.T) { | ||||
| 	} | ||||
| 
 | ||||
| 	for _, user := range []User{users[0], users[2]} { | ||||
| 		result = User{} | ||||
| 		if err := DB.Where("id = ?", user.ID).First(&result).Error; err != nil { | ||||
| 			t.Errorf("no error should returns when query %v, but got %v", user.ID, err) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	for _, user := range []User{users[0], users[2]} { | ||||
| 		result = User{} | ||||
| 		if err := DB.Where("id = ?", user.ID).First(&result).Error; err != nil { | ||||
| 			t.Errorf("no error should returns when query %v, but got %v", user.ID, err) | ||||
| 		} | ||||
|  | ||||
| @ -264,10 +264,12 @@ func TestSearchWithEmptyChain(t *testing.T) { | ||||
| 		t.Errorf("Should not raise any error if searching with empty strings") | ||||
| 	} | ||||
| 
 | ||||
| 	result = User{} | ||||
| 	if DB.Where(&User{}).Where("name = ?", user.Name).First(&result).Error != nil { | ||||
| 		t.Errorf("Should not raise any error if searching with empty struct") | ||||
| 	} | ||||
| 
 | ||||
| 	result = User{} | ||||
| 	if DB.Where(map[string]interface{}{}).Where("name = ?", user.Name).First(&result).Error != nil { | ||||
| 		t.Errorf("Should not raise any error if searching with empty map") | ||||
| 	} | ||||
| @ -319,6 +321,7 @@ func TestSearchWithMap(t *testing.T) { | ||||
| 	DB.First(&user, map[string]interface{}{"name": users[0].Name}) | ||||
| 	CheckUser(t, user, users[0]) | ||||
| 
 | ||||
| 	user = User{} | ||||
| 	DB.Where(map[string]interface{}{"name": users[1].Name}).First(&user) | ||||
| 	CheckUser(t, user, users[1]) | ||||
| 
 | ||||
|  | ||||
| @ -2,6 +2,8 @@ package tests_test | ||||
| 
 | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"sort" | ||||
| 	"strings" | ||||
| 	"testing" | ||||
| 	"time" | ||||
| 
 | ||||
| @ -218,3 +220,304 @@ func TestBlockGlobalUpdate(t *testing.T) { | ||||
| 		t.Errorf("should returns missing WHERE clause while updating error, got err %v", err) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestSelectWithUpdate(t *testing.T) { | ||||
| 	user := *GetUser("select_update", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4}) | ||||
| 	DB.Create(&user) | ||||
| 
 | ||||
| 	var result User | ||||
| 	DB.First(&result, user.ID) | ||||
| 
 | ||||
| 	user2 := *GetUser("select_update_new", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4}) | ||||
| 	result.Name = user2.Name | ||||
| 	result.Age = 50 | ||||
| 	result.Account = user2.Account | ||||
| 	result.Pets = user2.Pets | ||||
| 	result.Toys = user2.Toys | ||||
| 	result.Company = user2.Company | ||||
| 	result.Manager = user2.Manager | ||||
| 	result.Team = user2.Team | ||||
| 	result.Languages = user2.Languages | ||||
| 	result.Friends = user2.Friends | ||||
| 
 | ||||
| 	DB.Select("Name", "Account", "Toys", "Manager", "ManagerID", "Languages").Save(&result) | ||||
| 
 | ||||
| 	var result2 User | ||||
| 	DB.Preload("Account").Preload("Pets").Preload("Toys").Preload("Company").Preload("Manager").Preload("Team").Preload("Languages").Preload("Friends").First(&result2, user.ID) | ||||
| 
 | ||||
| 	result.Languages = append(user.Languages, result.Languages...) | ||||
| 	result.Toys = append(user.Toys, result.Toys...) | ||||
| 
 | ||||
| 	sort.Slice(result.Languages, func(i, j int) bool { | ||||
| 		return strings.Compare(result.Languages[i].Code, result.Languages[j].Code) > 0 | ||||
| 	}) | ||||
| 
 | ||||
| 	sort.Slice(result.Toys, func(i, j int) bool { | ||||
| 		return result.Toys[i].ID < result.Toys[j].ID | ||||
| 	}) | ||||
| 
 | ||||
| 	sort.Slice(result2.Languages, func(i, j int) bool { | ||||
| 		return strings.Compare(result2.Languages[i].Code, result2.Languages[j].Code) > 0 | ||||
| 	}) | ||||
| 
 | ||||
| 	sort.Slice(result2.Toys, func(i, j int) bool { | ||||
| 		return result2.Toys[i].ID < result2.Toys[j].ID | ||||
| 	}) | ||||
| 
 | ||||
| 	AssertObjEqual(t, result2, result, "Name", "Account", "Toys", "Manager", "ManagerID", "Languages") | ||||
| } | ||||
| 
 | ||||
| func TestSelectWithUpdateWithMap(t *testing.T) { | ||||
| 	user := *GetUser("select_update_map", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4}) | ||||
| 	DB.Create(&user) | ||||
| 
 | ||||
| 	var result User | ||||
| 	DB.First(&result, user.ID) | ||||
| 
 | ||||
| 	user2 := *GetUser("select_update_map_new", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4}) | ||||
| 	updateValues := map[string]interface{}{ | ||||
| 		"Name":      user2.Name, | ||||
| 		"Age":       50, | ||||
| 		"Account":   user2.Account, | ||||
| 		"Pets":      user2.Pets, | ||||
| 		"Toys":      user2.Toys, | ||||
| 		"Company":   user2.Company, | ||||
| 		"Manager":   user2.Manager, | ||||
| 		"Team":      user2.Team, | ||||
| 		"Languages": user2.Languages, | ||||
| 		"Friends":   user2.Friends, | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Model(&result).Select("Name", "Account", "Toys", "Manager", "ManagerID", "Languages").Updates(updateValues) | ||||
| 
 | ||||
| 	var result2 User | ||||
| 	DB.Preload("Account").Preload("Pets").Preload("Toys").Preload("Company").Preload("Manager").Preload("Team").Preload("Languages").Preload("Friends").First(&result2, user.ID) | ||||
| 
 | ||||
| 	result.Languages = append(user.Languages, result.Languages...) | ||||
| 	result.Toys = append(user.Toys, result.Toys...) | ||||
| 
 | ||||
| 	sort.Slice(result.Languages, func(i, j int) bool { | ||||
| 		return strings.Compare(result.Languages[i].Code, result.Languages[j].Code) > 0 | ||||
| 	}) | ||||
| 
 | ||||
| 	sort.Slice(result.Toys, func(i, j int) bool { | ||||
| 		return result.Toys[i].ID < result.Toys[j].ID | ||||
| 	}) | ||||
| 
 | ||||
| 	sort.Slice(result2.Languages, func(i, j int) bool { | ||||
| 		return strings.Compare(result2.Languages[i].Code, result2.Languages[j].Code) > 0 | ||||
| 	}) | ||||
| 
 | ||||
| 	sort.Slice(result2.Toys, func(i, j int) bool { | ||||
| 		return result2.Toys[i].ID < result2.Toys[j].ID | ||||
| 	}) | ||||
| 
 | ||||
| 	AssertObjEqual(t, result2, result, "Name", "Account", "Toys", "Manager", "ManagerID", "Languages") | ||||
| } | ||||
| 
 | ||||
| func TestOmitWithUpdate(t *testing.T) { | ||||
| 	user := *GetUser("omit_update", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4}) | ||||
| 	DB.Create(&user) | ||||
| 
 | ||||
| 	var result User | ||||
| 	DB.First(&result, user.ID) | ||||
| 
 | ||||
| 	user2 := *GetUser("omit_update_new", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4}) | ||||
| 	result.Name = user2.Name | ||||
| 	result.Age = 50 | ||||
| 	result.Account = user2.Account | ||||
| 	result.Pets = user2.Pets | ||||
| 	result.Toys = user2.Toys | ||||
| 	result.Company = user2.Company | ||||
| 	result.Manager = user2.Manager | ||||
| 	result.Team = user2.Team | ||||
| 	result.Languages = user2.Languages | ||||
| 	result.Friends = user2.Friends | ||||
| 
 | ||||
| 	DB.Omit("Name", "Account", "Toys", "Manager", "ManagerID", "Languages").Save(&result) | ||||
| 
 | ||||
| 	var result2 User | ||||
| 	DB.Preload("Account").Preload("Pets").Preload("Toys").Preload("Company").Preload("Manager").Preload("Team").Preload("Languages").Preload("Friends").First(&result2, user.ID) | ||||
| 
 | ||||
| 	result.Pets = append(user.Pets, result.Pets...) | ||||
| 	result.Team = append(user.Team, result.Team...) | ||||
| 	result.Friends = append(user.Friends, result.Friends...) | ||||
| 
 | ||||
| 	sort.Slice(result.Pets, func(i, j int) bool { | ||||
| 		return result.Pets[i].ID < result.Pets[j].ID | ||||
| 	}) | ||||
| 	sort.Slice(result.Team, func(i, j int) bool { | ||||
| 		return result.Team[i].ID < result.Team[j].ID | ||||
| 	}) | ||||
| 	sort.Slice(result.Friends, func(i, j int) bool { | ||||
| 		return result.Friends[i].ID < result.Friends[j].ID | ||||
| 	}) | ||||
| 	sort.Slice(result2.Pets, func(i, j int) bool { | ||||
| 		return result2.Pets[i].ID < result2.Pets[j].ID | ||||
| 	}) | ||||
| 	sort.Slice(result2.Team, func(i, j int) bool { | ||||
| 		return result2.Team[i].ID < result2.Team[j].ID | ||||
| 	}) | ||||
| 	sort.Slice(result2.Friends, func(i, j int) bool { | ||||
| 		return result2.Friends[i].ID < result2.Friends[j].ID | ||||
| 	}) | ||||
| 
 | ||||
| 	AssertObjEqual(t, result2, result, "Age", "Pets", "Company", "CompanyID", "Team", "Friends") | ||||
| } | ||||
| 
 | ||||
| func TestOmitWithUpdateWithMap(t *testing.T) { | ||||
| 	user := *GetUser("omit_update_map", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4}) | ||||
| 	DB.Create(&user) | ||||
| 
 | ||||
| 	var result User | ||||
| 	DB.First(&result, user.ID) | ||||
| 
 | ||||
| 	user2 := *GetUser("omit_update_map_new", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4}) | ||||
| 	updateValues := map[string]interface{}{ | ||||
| 		"Name":      user2.Name, | ||||
| 		"Age":       50, | ||||
| 		"Account":   user2.Account, | ||||
| 		"Pets":      user2.Pets, | ||||
| 		"Toys":      user2.Toys, | ||||
| 		"Company":   user2.Company, | ||||
| 		"Manager":   user2.Manager, | ||||
| 		"Team":      user2.Team, | ||||
| 		"Languages": user2.Languages, | ||||
| 		"Friends":   user2.Friends, | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Model(&result).Omit("Name", "Account", "Toys", "Manager", "ManagerID", "Languages").Updates(updateValues) | ||||
| 
 | ||||
| 	var result2 User | ||||
| 	DB.Preload("Account").Preload("Pets").Preload("Toys").Preload("Company").Preload("Manager").Preload("Team").Preload("Languages").Preload("Friends").First(&result2, user.ID) | ||||
| 
 | ||||
| 	result.Pets = append(user.Pets, result.Pets...) | ||||
| 	result.Team = append(user.Team, result.Team...) | ||||
| 	result.Friends = append(user.Friends, result.Friends...) | ||||
| 
 | ||||
| 	sort.Slice(result.Pets, func(i, j int) bool { | ||||
| 		return result.Pets[i].ID < result.Pets[j].ID | ||||
| 	}) | ||||
| 	sort.Slice(result.Team, func(i, j int) bool { | ||||
| 		return result.Team[i].ID < result.Team[j].ID | ||||
| 	}) | ||||
| 	sort.Slice(result.Friends, func(i, j int) bool { | ||||
| 		return result.Friends[i].ID < result.Friends[j].ID | ||||
| 	}) | ||||
| 	sort.Slice(result2.Pets, func(i, j int) bool { | ||||
| 		return result2.Pets[i].ID < result2.Pets[j].ID | ||||
| 	}) | ||||
| 	sort.Slice(result2.Team, func(i, j int) bool { | ||||
| 		return result2.Team[i].ID < result2.Team[j].ID | ||||
| 	}) | ||||
| 	sort.Slice(result2.Friends, func(i, j int) bool { | ||||
| 		return result2.Friends[i].ID < result2.Friends[j].ID | ||||
| 	}) | ||||
| 
 | ||||
| 	AssertObjEqual(t, result2, result, "Age", "Pets", "Company", "CompanyID", "Team", "Friends") | ||||
| } | ||||
| 
 | ||||
| func TestSelectWithUpdateColumn(t *testing.T) { | ||||
| 	user := *GetUser("select_with_update_column", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4}) | ||||
| 	DB.Create(&user) | ||||
| 
 | ||||
| 	updateValues := map[string]interface{}{"Name": "new_name", "Age": 50} | ||||
| 
 | ||||
| 	var result User | ||||
| 	DB.First(&result, user.ID) | ||||
| 	DB.Model(&result).Select("Name").UpdateColumns(updateValues) | ||||
| 
 | ||||
| 	var result2 User | ||||
| 	DB.First(&result2, user.ID) | ||||
| 
 | ||||
| 	if result2.Name == user.Name || result2.Age != user.Age { | ||||
| 		t.Errorf("Should only update users with name column") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestOmitWithUpdateColumn(t *testing.T) { | ||||
| 	user := *GetUser("omit_with_update_column", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4}) | ||||
| 	DB.Create(&user) | ||||
| 
 | ||||
| 	updateValues := map[string]interface{}{"Name": "new_name", "Age": 50} | ||||
| 
 | ||||
| 	var result User | ||||
| 	DB.First(&result, user.ID) | ||||
| 	DB.Model(&result).Omit("Name").UpdateColumns(updateValues) | ||||
| 
 | ||||
| 	var result2 User | ||||
| 	DB.First(&result2, user.ID) | ||||
| 
 | ||||
| 	if result2.Name != user.Name || result2.Age == user.Age { | ||||
| 		t.Errorf("Should only update users with name column") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestUpdateColumnsSkipsAssociations(t *testing.T) { | ||||
| 	user := *GetUser("update_column_skips_association", Config{}) | ||||
| 	DB.Create(&user) | ||||
| 
 | ||||
| 	// Update a single field of the user and verify that the changed address is not stored.
 | ||||
| 	newAge := uint(100) | ||||
| 	user.Account.Number = "new_account_number" | ||||
| 	db := DB.Model(&user).UpdateColumns(User{Age: newAge}) | ||||
| 
 | ||||
| 	if db.RowsAffected != 1 { | ||||
| 		t.Errorf("Expected RowsAffected=1 but instead RowsAffected=%v", db.RowsAffected) | ||||
| 	} | ||||
| 
 | ||||
| 	// Verify that Age now=`newAge`.
 | ||||
| 	result := &User{} | ||||
| 	result.ID = user.ID | ||||
| 	DB.Preload("Account").First(result) | ||||
| 
 | ||||
| 	if result.Age != newAge { | ||||
| 		t.Errorf("Expected freshly queried user to have Age=%v but instead found Age=%v", newAge, result.Age) | ||||
| 	} | ||||
| 
 | ||||
| 	if result.Account.Number != user.Account.Number { | ||||
| 		t.Errorf("account number should not been changed, expects: %v, got %v", user.Account.Number, result.Account.Number) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestUpdatesWithBlankValues(t *testing.T) { | ||||
| 	user := *GetUser("updates_with_blank_value", Config{}) | ||||
| 	DB.Save(&user) | ||||
| 
 | ||||
| 	var user2 User | ||||
| 	user2.ID = user.ID | ||||
| 	DB.Model(&user2).Updates(&User{Age: 100}) | ||||
| 
 | ||||
| 	var result User | ||||
| 	DB.First(&result, user.ID) | ||||
| 
 | ||||
| 	if result.Name != user.Name || result.Age != 100 { | ||||
| 		t.Errorf("user's name should not be updated") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestUpdatesTableWithIgnoredValues(t *testing.T) { | ||||
| 	type ElementWithIgnoredField struct { | ||||
| 		Id           int64 | ||||
| 		Value        string | ||||
| 		IgnoredField int64 `gorm:"-"` | ||||
| 	} | ||||
| 	DB.Migrator().DropTable(&ElementWithIgnoredField{}) | ||||
| 	DB.AutoMigrate(&ElementWithIgnoredField{}) | ||||
| 
 | ||||
| 	elem := ElementWithIgnoredField{Value: "foo", IgnoredField: 10} | ||||
| 	DB.Save(&elem) | ||||
| 
 | ||||
| 	DB.Model(&ElementWithIgnoredField{}). | ||||
| 		Where("id = ?", elem.Id). | ||||
| 		Updates(&ElementWithIgnoredField{Value: "bar", IgnoredField: 100}) | ||||
| 
 | ||||
| 	var result ElementWithIgnoredField | ||||
| 	if err := DB.First(&result, elem.Id).Error; err != nil { | ||||
| 		t.Errorf("error getting an element from database: %s", err.Error()) | ||||
| 	} | ||||
| 
 | ||||
| 	if result.IgnoredField != 0 { | ||||
| 		t.Errorf("element's ignored field should not be updated") | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @ -3,6 +3,7 @@ package tests | ||||
| import ( | ||||
| 	"database/sql/driver" | ||||
| 	"fmt" | ||||
| 	"go/ast" | ||||
| 	"reflect" | ||||
| 	"sort" | ||||
| 	"strconv" | ||||
| @ -126,6 +127,37 @@ func AssertEqual(t *testing.T, got, expect interface{}) { | ||||
| 			return | ||||
| 		} | ||||
| 
 | ||||
| 		if reflect.ValueOf(got).Kind() == reflect.Slice { | ||||
| 			if reflect.ValueOf(expect).Kind() == reflect.Slice { | ||||
| 				if reflect.ValueOf(got).Len() == reflect.ValueOf(expect).Len() { | ||||
| 					for i := 0; i < reflect.ValueOf(got).Len(); i++ { | ||||
| 						name := fmt.Sprintf(reflect.ValueOf(got).Type().Name()+" #%v", i) | ||||
| 						t.Run(name, func(t *testing.T) { | ||||
| 							AssertEqual(t, reflect.ValueOf(got).Index(i).Interface(), reflect.ValueOf(expect).Index(i).Interface()) | ||||
| 						}) | ||||
| 					} | ||||
| 				} else { | ||||
| 					name := reflect.ValueOf(got).Type().Elem().Name() | ||||
| 					t.Errorf("%v expects length: %v, got %v", name, reflect.ValueOf(expect).Len(), reflect.ValueOf(got).Len()) | ||||
| 				} | ||||
| 				return | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		if reflect.ValueOf(got).Kind() == reflect.Struct { | ||||
| 			if reflect.ValueOf(got).NumField() == reflect.ValueOf(expect).NumField() { | ||||
| 				for i := 0; i < reflect.ValueOf(got).NumField(); i++ { | ||||
| 					if fieldStruct := reflect.ValueOf(got).Type().Field(i); ast.IsExported(fieldStruct.Name) { | ||||
| 						field := reflect.ValueOf(got).Field(i) | ||||
| 						t.Run(fieldStruct.Name, func(t *testing.T) { | ||||
| 							AssertEqual(t, field.Interface(), reflect.ValueOf(expect).Field(i).Interface()) | ||||
| 						}) | ||||
| 					} | ||||
| 				} | ||||
| 				return | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		if reflect.ValueOf(got).Type().ConvertibleTo(reflect.ValueOf(expect).Type()) { | ||||
| 			got = reflect.ValueOf(got).Convert(reflect.ValueOf(expect).Type()).Interface() | ||||
| 			isEqual() | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu