Test Hooks For Slice
This commit is contained in:
		
							parent
							
								
									f5566288de
								
							
						
					
					
						commit
						929c0c576c
					
				| @ -11,8 +11,10 @@ func callMethod(db *gorm.DB, fc func(value interface{}, tx *gorm.DB) bool) { | |||||||
| 	if called := fc(db.Statement.Dest, tx); !called { | 	if called := fc(db.Statement.Dest, tx); !called { | ||||||
| 		switch db.Statement.ReflectValue.Kind() { | 		switch db.Statement.ReflectValue.Kind() { | ||||||
| 		case reflect.Slice, reflect.Array: | 		case reflect.Slice, reflect.Array: | ||||||
|  | 			db.Statement.CurDestIndex = 0 | ||||||
| 			for i := 0; i < db.Statement.ReflectValue.Len(); i++ { | 			for i := 0; i < db.Statement.ReflectValue.Len(); i++ { | ||||||
| 				fc(db.Statement.ReflectValue.Index(i).Addr().Interface(), tx) | 				fc(reflect.Indirect(db.Statement.ReflectValue.Index(i)).Addr().Interface(), tx) | ||||||
|  | 				db.Statement.CurDestIndex++ | ||||||
| 			} | 			} | ||||||
| 		case reflect.Struct: | 		case reflect.Struct: | ||||||
| 			fc(db.Statement.ReflectValue.Addr().Interface(), tx) | 			fc(db.Statement.ReflectValue.Addr().Interface(), tx) | ||||||
|  | |||||||
							
								
								
									
										17
									
								
								statement.go
									
									
									
									
									
								
							
							
						
						
									
										17
									
								
								statement.go
									
									
									
									
									
								
							| @ -38,6 +38,7 @@ type Statement struct { | |||||||
| 	SQL                  strings.Builder | 	SQL                  strings.Builder | ||||||
| 	Vars                 []interface{} | 	Vars                 []interface{} | ||||||
| 	NamedVars            []sql.NamedArg | 	NamedVars            []sql.NamedArg | ||||||
|  | 	CurDestIndex         int | ||||||
| 	attrs                []interface{} | 	attrs                []interface{} | ||||||
| 	assigns              []interface{} | 	assigns              []interface{} | ||||||
| } | } | ||||||
| @ -379,7 +380,12 @@ func (stmt *Statement) SetColumn(name string, value interface{}) { | |||||||
| 		v[name] = value | 		v[name] = value | ||||||
| 	} else if stmt.Schema != nil { | 	} else if stmt.Schema != nil { | ||||||
| 		if field := stmt.Schema.LookUpField(name); field != nil { | 		if field := stmt.Schema.LookUpField(name); field != nil { | ||||||
| 			field.Set(stmt.ReflectValue, value) | 			switch stmt.ReflectValue.Kind() { | ||||||
|  | 			case reflect.Slice, reflect.Array: | ||||||
|  | 				field.Set(stmt.ReflectValue.Index(stmt.CurDestIndex), value) | ||||||
|  | 			case reflect.Struct: | ||||||
|  | 				field.Set(stmt.ReflectValue, value) | ||||||
|  | 			} | ||||||
| 		} else { | 		} else { | ||||||
| 			stmt.AddError(ErrInvalidField) | 			stmt.AddError(ErrInvalidField) | ||||||
| 		} | 		} | ||||||
| @ -395,17 +401,20 @@ func (stmt *Statement) Changed(fields ...string) bool { | |||||||
| 		modelValue = modelValue.Elem() | 		modelValue = modelValue.Elem() | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	switch modelValue.Kind() { | ||||||
|  | 	case reflect.Slice, reflect.Array: | ||||||
|  | 		modelValue = stmt.ReflectValue.Index(stmt.CurDestIndex) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	selectColumns, restricted := stmt.SelectAndOmitColumns(false, true) | 	selectColumns, restricted := stmt.SelectAndOmitColumns(false, true) | ||||||
| 	changed := func(field *schema.Field) bool { | 	changed := func(field *schema.Field) bool { | ||||||
| 		fieldValue, isZero := field.ValueOf(modelValue) | 		fieldValue, _ := field.ValueOf(modelValue) | ||||||
| 		if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { | 		if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { | ||||||
| 			if v, ok := stmt.Dest.(map[string]interface{}); ok { | 			if v, ok := stmt.Dest.(map[string]interface{}); ok { | ||||||
| 				if fv, ok := v[field.Name]; ok { | 				if fv, ok := v[field.Name]; ok { | ||||||
| 					return !utils.AssertEqual(fv, fieldValue) | 					return !utils.AssertEqual(fv, fieldValue) | ||||||
| 				} else if fv, ok := v[field.DBName]; ok { | 				} else if fv, ok := v[field.DBName]; ok { | ||||||
| 					return !utils.AssertEqual(fv, fieldValue) | 					return !utils.AssertEqual(fv, fieldValue) | ||||||
| 				} else if isZero { |  | ||||||
| 					return true |  | ||||||
| 				} | 				} | ||||||
| 			} else { | 			} else { | ||||||
| 				changedValue, _ := field.ValueOf(stmt.ReflectValue) | 				changedValue, _ := field.ValueOf(stmt.ReflectValue) | ||||||
|  | |||||||
| @ -366,3 +366,51 @@ func TestSetColumn(t *testing.T) { | |||||||
| 
 | 
 | ||||||
| 	AssertEqual(t, result2, product) | 	AssertEqual(t, result2, product) | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | func TestHooksForSlice(t *testing.T) { | ||||||
|  | 	products := []*Product3{ | ||||||
|  | 		{Name: "Product-1", Price: 100}, | ||||||
|  | 		{Name: "Product-2", Price: 200}, | ||||||
|  | 		{Name: "Product-3", Price: 300}, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	DB.Create(&products) | ||||||
|  | 
 | ||||||
|  | 	for idx, value := range []int64{200, 300, 400} { | ||||||
|  | 		if products[idx].Price != value { | ||||||
|  | 			t.Errorf("invalid price for product #%v, expects: %v, got %v", idx, value, products[idx].Price) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	DB.Model(&products).Update("Name", "product-name") | ||||||
|  | 
 | ||||||
|  | 	// will set all product's price to last product's price + 10
 | ||||||
|  | 	for idx, value := range []int64{410, 410, 410} { | ||||||
|  | 		if products[idx].Price != value { | ||||||
|  | 			t.Errorf("invalid price for product #%v, expects: %v, got %v", idx, value, products[idx].Price) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	products2 := []Product3{ | ||||||
|  | 		{Name: "Product-1", Price: 100}, | ||||||
|  | 		{Name: "Product-2", Price: 200}, | ||||||
|  | 		{Name: "Product-3", Price: 300}, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	DB.Create(&products2) | ||||||
|  | 
 | ||||||
|  | 	for idx, value := range []int64{200, 300, 400} { | ||||||
|  | 		if products2[idx].Price != value { | ||||||
|  | 			t.Errorf("invalid price for product #%v, expects: %v, got %v", idx, value, products2[idx].Price) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	DB.Model(&products2).Update("Name", "product-name") | ||||||
|  | 
 | ||||||
|  | 	// will set all product's price to last product's price + 10
 | ||||||
|  | 	for idx, value := range []int64{410, 410, 410} { | ||||||
|  | 		if products2[idx].Price != value { | ||||||
|  | 			t.Errorf("invalid price for product #%v, expects: %v, got %v", idx, value, products2[idx].Price) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu