Add tests for sub model
This commit is contained in:
		
							parent
							
								
									67de7a8af8
								
							
						
					
					
						commit
						4e34a6d21b
					
				| @ -458,20 +458,12 @@ func (field *Field) setupValuerAndSetter(modelType reflect.Type) { | |||||||
| 	case len(field.StructField.Index) == 1 && fieldIndex >= 0: | 	case len(field.StructField.Index) == 1 && fieldIndex >= 0: | ||||||
| 		field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) { | 		field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) { | ||||||
| 			v = reflect.Indirect(v) | 			v = reflect.Indirect(v) | ||||||
| 			if v.Type() != modelType { |  | ||||||
| 				fieldValue := v.FieldByName(field.Name) |  | ||||||
| 				return fieldValue.Interface(), fieldValue.IsZero() |  | ||||||
| 			} |  | ||||||
| 			fieldValue := v.Field(fieldIndex) | 			fieldValue := v.Field(fieldIndex) | ||||||
| 			return fieldValue.Interface(), fieldValue.IsZero() | 			return fieldValue.Interface(), fieldValue.IsZero() | ||||||
| 		} | 		} | ||||||
| 	default: | 	default: | ||||||
| 		field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) { | 		field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) { | ||||||
| 			v = reflect.Indirect(v) | 			v = reflect.Indirect(v) | ||||||
| 			if v.Type() != modelType { |  | ||||||
| 				fieldValue := v.FieldByName(field.Name) |  | ||||||
| 				return fieldValue.Interface(), fieldValue.IsZero() |  | ||||||
| 			} |  | ||||||
| 			for _, fieldIdx := range field.StructField.Index { | 			for _, fieldIdx := range field.StructField.Index { | ||||||
| 				if fieldIdx >= 0 { | 				if fieldIdx >= 0 { | ||||||
| 					v = v.Field(fieldIdx) | 					v = v.Field(fieldIdx) | ||||||
| @ -516,17 +508,11 @@ func (field *Field) setupValuerAndSetter(modelType reflect.Type) { | |||||||
| 	case len(field.StructField.Index) == 1 && fieldIndex >= 0: | 	case len(field.StructField.Index) == 1 && fieldIndex >= 0: | ||||||
| 		field.ReflectValueOf = func(ctx context.Context, v reflect.Value) reflect.Value { | 		field.ReflectValueOf = func(ctx context.Context, v reflect.Value) reflect.Value { | ||||||
| 			v = reflect.Indirect(v) | 			v = reflect.Indirect(v) | ||||||
| 			if v.Type() != modelType { |  | ||||||
| 				return v.FieldByName(field.Name) |  | ||||||
| 			} |  | ||||||
| 			return v.Field(fieldIndex) | 			return v.Field(fieldIndex) | ||||||
| 		} | 		} | ||||||
| 	default: | 	default: | ||||||
| 		field.ReflectValueOf = func(ctx context.Context, v reflect.Value) reflect.Value { | 		field.ReflectValueOf = func(ctx context.Context, v reflect.Value) reflect.Value { | ||||||
| 			v = reflect.Indirect(v) | 			v = reflect.Indirect(v) | ||||||
| 			if v.Type() != modelType { |  | ||||||
| 				return v.FieldByName(field.Name) |  | ||||||
| 			} |  | ||||||
| 			for idx, fieldIdx := range field.StructField.Index { | 			for idx, fieldIdx := range field.StructField.Index { | ||||||
| 				if fieldIdx >= 0 { | 				if fieldIdx >= 0 { | ||||||
| 					v = v.Field(fieldIdx) | 					v = v.Field(fieldIdx) | ||||||
|  | |||||||
							
								
								
									
										13
									
								
								statement.go
									
									
									
									
									
								
							
							
						
						
									
										13
									
								
								statement.go
									
									
									
									
									
								
							| @ -658,12 +658,15 @@ func (stmt *Statement) Changed(fields ...string) bool { | |||||||
| 				for destValue.Kind() == reflect.Ptr { | 				for destValue.Kind() == reflect.Ptr { | ||||||
| 					destValue = destValue.Elem() | 					destValue = destValue.Elem() | ||||||
| 				} | 				} | ||||||
| 
 | 				if descSchema, err := schema.Parse(stmt.Dest, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil { | ||||||
| 				changedValue, zero := field.ValueOf(stmt.Context, destValue) | 					if destField := descSchema.LookUpField(field.DBName); destField != nil { | ||||||
| 				if v { | 						changedValue, zero := destField.ValueOf(stmt.Context, destValue) | ||||||
| 					return !utils.AssertEqual(changedValue, fieldValue) | 						if v { | ||||||
|  | 							return !utils.AssertEqual(changedValue, fieldValue) | ||||||
|  | 						} | ||||||
|  | 						return !zero && !utils.AssertEqual(changedValue, fieldValue) | ||||||
|  | 					} | ||||||
| 				} | 				} | ||||||
| 				return !zero && !utils.AssertEqual(changedValue, fieldValue) |  | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 		return false | 		return false | ||||||
|  | |||||||
| @ -1,88 +0,0 @@ | |||||||
| package tests_test |  | ||||||
| 
 |  | ||||||
| import ( |  | ||||||
| 	"fmt" |  | ||||||
| 	"strings" |  | ||||||
| 	"testing" |  | ||||||
| 
 |  | ||||||
| 	"gorm.io/gorm" |  | ||||||
| ) |  | ||||||
| 
 |  | ||||||
| type Man struct { |  | ||||||
| 	ID     int |  | ||||||
| 	Age    int |  | ||||||
| 	Name   string |  | ||||||
| 	Detail string |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // Panic-safe BeforeUpdate hook that checks for Changed("age")
 |  | ||||||
| func (m *Man) BeforeUpdate(tx *gorm.DB) (err error) { |  | ||||||
| 	defer func() { |  | ||||||
| 		if r := recover(); r != nil { |  | ||||||
| 			err = fmt.Errorf("panic in BeforeUpdate: %v", r) |  | ||||||
| 		} |  | ||||||
| 	}() |  | ||||||
| 
 |  | ||||||
| 	if !tx.Statement.Changed("age") { |  | ||||||
| 		return nil |  | ||||||
| 	} |  | ||||||
| 	return nil |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (m *Man) update(data interface{}) error { |  | ||||||
| 	return DB.Set("data", data).Model(m).Where("id = ?", m.ID).Updates(data).Error |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func TestBeforeUpdateStatementChanged(t *testing.T) { |  | ||||||
| 	DB.AutoMigrate(&Man{}) |  | ||||||
| 	type TestCase struct { |  | ||||||
| 		BaseObjects Man |  | ||||||
| 		change      interface{} |  | ||||||
| 		expectError bool |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	testCases := []TestCase{ |  | ||||||
| 		{ |  | ||||||
| 			BaseObjects: Man{ID: 1, Age: 18, Name: "random-name"}, |  | ||||||
| 			change: struct { |  | ||||||
| 				Age int |  | ||||||
| 			}{Age: 20}, |  | ||||||
| 			expectError: false, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			BaseObjects: Man{ID: 2, Age: 18, Name: "random-name"}, |  | ||||||
| 			change: struct { |  | ||||||
| 				Name string |  | ||||||
| 			}{Name: "name-only"}, |  | ||||||
| 			expectError: true, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			BaseObjects: Man{ID: 2, Age: 18, Name: "random-name"}, |  | ||||||
| 			change: struct { |  | ||||||
| 				Name string |  | ||||||
| 				Age int |  | ||||||
| 			}{Name: "name-only", Age: 20}, |  | ||||||
| 			expectError: false, |  | ||||||
| 		}, |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	for _, test := range testCases { |  | ||||||
| 		DB.Create(&test.BaseObjects) |  | ||||||
| 
 |  | ||||||
| 		// below comment is stored for future reference
 |  | ||||||
| 		// err := DB.Set("data", test.change).Model(&test.BaseObjects).Where("id = ?", test.BaseObjects.ID).Updates(test.change).Error
 |  | ||||||
| 		err := test.BaseObjects.update(test.change) |  | ||||||
| 		if strings.Contains(fmt.Sprint(err), "panic in BeforeUpdate") { |  | ||||||
| 			if !test.expectError { |  | ||||||
| 				t.Errorf("unexpected panic in BeforeUpdate for input: %+v\nerror: %v", test.change, err) |  | ||||||
| 			} |  | ||||||
| 		} else { |  | ||||||
| 			if test.expectError { |  | ||||||
| 				t.Errorf("expected panic did not occur for input: %+v", test.change) |  | ||||||
| 			} |  | ||||||
| 			if err != nil { |  | ||||||
| 				t.Errorf("unexpected GORM error: %v", err) |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
| @ -24,7 +24,7 @@ import ( | |||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func TestMigrate(t *testing.T) { | func TestMigrate(t *testing.T) { | ||||||
| 	allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Tools{}} | 	allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Tools{}, &Man{}} | ||||||
| 	rand.Seed(time.Now().UnixNano()) | 	rand.Seed(time.Now().UnixNano()) | ||||||
| 	rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] }) | 	rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] }) | ||||||
| 	DB.Migrator().DropTable("user_speaks", "user_friends", "ccc") | 	DB.Migrator().DropTable("user_speaks", "user_friends", "ccc") | ||||||
|  | |||||||
							
								
								
									
										45
									
								
								tests/submodel_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										45
									
								
								tests/submodel_test.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,45 @@ | |||||||
|  | package tests_test | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"testing" | ||||||
|  | 	"gorm.io/gorm" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | type Man struct { | ||||||
|  | 	ID     int | ||||||
|  | 	Age    int | ||||||
|  | 	Name   string | ||||||
|  | 	Detail string | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Panic-safe BeforeUpdate hook that checks for Changed("age")
 | ||||||
|  | func (m *Man) BeforeUpdate(tx *gorm.DB) (err error) { | ||||||
|  | 	if !tx.Statement.Changed("age") { | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestSubModel(t *testing.T) { | ||||||
|  | 	man := Man{Age: 18, Name: "random-name"} | ||||||
|  | 	if err := DB.Create(&man).Error; err != nil { | ||||||
|  | 		t.Fatalf("unexpected error: %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if err := DB.Model(&man).Where("id = ?", man.ID).Updates(struct { | ||||||
|  | 		Age int | ||||||
|  | 	}{Age: 20}).Error; err != nil { | ||||||
|  | 		t.Fatalf("unexpected error: %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	var result = struct{ | ||||||
|  | 		ID int | ||||||
|  | 		Age int | ||||||
|  | 	}{} | ||||||
|  | 	if err := DB.Model(&man).Where("id = ?", man.ID).Find(&result).Error; err != nil { | ||||||
|  | 		t.Fatalf("unexpected error: %v", err) | ||||||
|  | 	} | ||||||
|  | 	if result.ID != man.ID || result.Age != 20 { | ||||||
|  | 		t.Fatalf("expected ID %d and Age 20, got ID %d and age", result.ID, result.Age) | ||||||
|  | 	} | ||||||
|  | } | ||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu