Test Hooks For Slice
This commit is contained in:
		
							parent
							
								
									66dcd7e3ca
								
							
						
					
					
						commit
						3e4dbde920
					
				| @ -11,8 +11,10 @@ func callMethod(db *gorm.DB, fc func(value interface{}, tx *gorm.DB) bool) { | ||||
| 	if called := fc(db.Statement.Dest, tx); !called { | ||||
| 		switch db.Statement.ReflectValue.Kind() { | ||||
| 		case reflect.Slice, reflect.Array: | ||||
| 			db.Statement.CurDestIndex = 0 | ||||
| 			for i := 0; i < db.Statement.ReflectValue.Len(); i++ { | ||||
| 				fc(db.Statement.ReflectValue.Index(i).Addr().Interface(), tx) | ||||
| 				fc(reflect.Indirect(db.Statement.ReflectValue.Index(i)).Addr().Interface(), tx) | ||||
| 				db.Statement.CurDestIndex++ | ||||
| 			} | ||||
| 		case reflect.Struct: | ||||
| 			fc(db.Statement.ReflectValue.Addr().Interface(), tx) | ||||
|  | ||||
							
								
								
									
										17
									
								
								statement.go
									
									
									
									
									
								
							
							
						
						
									
										17
									
								
								statement.go
									
									
									
									
									
								
							| @ -38,6 +38,7 @@ type Statement struct { | ||||
| 	SQL                  strings.Builder | ||||
| 	Vars                 []interface{} | ||||
| 	NamedVars            []sql.NamedArg | ||||
| 	CurDestIndex         int | ||||
| 	attrs                []interface{} | ||||
| 	assigns              []interface{} | ||||
| } | ||||
| @ -379,7 +380,12 @@ func (stmt *Statement) SetColumn(name string, value interface{}) { | ||||
| 		v[name] = value | ||||
| 	} else if stmt.Schema != nil { | ||||
| 		if field := stmt.Schema.LookUpField(name); field != nil { | ||||
| 			field.Set(stmt.ReflectValue, value) | ||||
| 			switch stmt.ReflectValue.Kind() { | ||||
| 			case reflect.Slice, reflect.Array: | ||||
| 				field.Set(stmt.ReflectValue.Index(stmt.CurDestIndex), value) | ||||
| 			case reflect.Struct: | ||||
| 				field.Set(stmt.ReflectValue, value) | ||||
| 			} | ||||
| 		} else { | ||||
| 			stmt.AddError(ErrInvalidField) | ||||
| 		} | ||||
| @ -395,17 +401,20 @@ func (stmt *Statement) Changed(fields ...string) bool { | ||||
| 		modelValue = modelValue.Elem() | ||||
| 	} | ||||
| 
 | ||||
| 	switch modelValue.Kind() { | ||||
| 	case reflect.Slice, reflect.Array: | ||||
| 		modelValue = stmt.ReflectValue.Index(stmt.CurDestIndex) | ||||
| 	} | ||||
| 
 | ||||
| 	selectColumns, restricted := stmt.SelectAndOmitColumns(false, true) | ||||
| 	changed := func(field *schema.Field) bool { | ||||
| 		fieldValue, isZero := field.ValueOf(modelValue) | ||||
| 		fieldValue, _ := field.ValueOf(modelValue) | ||||
| 		if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { | ||||
| 			if v, ok := stmt.Dest.(map[string]interface{}); ok { | ||||
| 				if fv, ok := v[field.Name]; ok { | ||||
| 					return !utils.AssertEqual(fv, fieldValue) | ||||
| 				} else if fv, ok := v[field.DBName]; ok { | ||||
| 					return !utils.AssertEqual(fv, fieldValue) | ||||
| 				} else if isZero { | ||||
| 					return true | ||||
| 				} | ||||
| 			} else { | ||||
| 				changedValue, _ := field.ValueOf(stmt.ReflectValue) | ||||
|  | ||||
| @ -366,3 +366,51 @@ func TestSetColumn(t *testing.T) { | ||||
| 
 | ||||
| 	AssertEqual(t, result2, product) | ||||
| } | ||||
| 
 | ||||
| func TestHooksForSlice(t *testing.T) { | ||||
| 	products := []*Product3{ | ||||
| 		{Name: "Product-1", Price: 100}, | ||||
| 		{Name: "Product-2", Price: 200}, | ||||
| 		{Name: "Product-3", Price: 300}, | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Create(&products) | ||||
| 
 | ||||
| 	for idx, value := range []int64{200, 300, 400} { | ||||
| 		if products[idx].Price != value { | ||||
| 			t.Errorf("invalid price for product #%v, expects: %v, got %v", idx, value, products[idx].Price) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Model(&products).Update("Name", "product-name") | ||||
| 
 | ||||
| 	// will set all product's price to last product's price + 10
 | ||||
| 	for idx, value := range []int64{410, 410, 410} { | ||||
| 		if products[idx].Price != value { | ||||
| 			t.Errorf("invalid price for product #%v, expects: %v, got %v", idx, value, products[idx].Price) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	products2 := []Product3{ | ||||
| 		{Name: "Product-1", Price: 100}, | ||||
| 		{Name: "Product-2", Price: 200}, | ||||
| 		{Name: "Product-3", Price: 300}, | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Create(&products2) | ||||
| 
 | ||||
| 	for idx, value := range []int64{200, 300, 400} { | ||||
| 		if products2[idx].Price != value { | ||||
| 			t.Errorf("invalid price for product #%v, expects: %v, got %v", idx, value, products2[idx].Price) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Model(&products2).Update("Name", "product-name") | ||||
| 
 | ||||
| 	// will set all product's price to last product's price + 10
 | ||||
| 	for idx, value := range []int64{410, 410, 410} { | ||||
| 		if products2[idx].Price != value { | ||||
| 			t.Errorf("invalid price for product #%v, expects: %v, got %v", idx, value, products2[idx].Price) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu