Add Slice Association for BelongsTo
This commit is contained in:
		
							parent
							
								
									91a695893c
								
							
						
					
					
						commit
						2db33730b6
					
				| @ -366,6 +366,11 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ | ||||
| 			if clear && len(values) == 0 { | ||||
| 				for i := 0; i < reflectValue.Len(); i++ { | ||||
| 					association.Relationship.Field.Set(reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType).Interface()) | ||||
| 					for _, ref := range association.Relationship.References { | ||||
| 						if !ref.OwnPrimaryKey { | ||||
| 							ref.ForeignKey.Set(reflectValue.Index(i), reflect.Zero(ref.ForeignKey.FieldType).Interface()) | ||||
| 						} | ||||
| 					} | ||||
| 				} | ||||
| 				break | ||||
| 			} | ||||
| @ -382,6 +387,11 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ | ||||
| 	case reflect.Struct: | ||||
| 		if clear && len(values) == 0 { | ||||
| 			association.Relationship.Field.Set(reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType).Interface()) | ||||
| 			for _, ref := range association.Relationship.References { | ||||
| 				if !ref.OwnPrimaryKey { | ||||
| 					ref.ForeignKey.Set(reflectValue, reflect.Zero(ref.ForeignKey.FieldType).Interface()) | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		for idx, value := range values { | ||||
| @ -392,10 +402,12 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ | ||||
| 		_, hasZero = association.DB.Statement.Schema.PrioritizedPrimaryField.ValueOf(reflectValue) | ||||
| 	} | ||||
| 
 | ||||
| 	if len(values) > 0 { | ||||
| 		if hasZero { | ||||
| 		association.DB.Save(reflectValue.Addr().Interface()) | ||||
| 			association.DB.Create(reflectValue.Addr().Interface()) | ||||
| 		} else { | ||||
| 		association.DB.Select(selectedColumns).Save(reflectValue.Addr().Interface()) | ||||
| 			association.DB.Select(selectedColumns).Model(nil).Save(reflectValue.Addr().Interface()) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	for _, assignBack := range assignBacks { | ||||
|  | ||||
| @ -173,12 +173,30 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { | ||||
| 	} | ||||
| 
 | ||||
| 	if stmt.Dest != stmt.Model { | ||||
| 		reflectValue := reflect.ValueOf(stmt.Model) | ||||
| 		reflectValue := reflect.Indirect(reflect.ValueOf(stmt.Model)) | ||||
| 		switch reflectValue.Kind() { | ||||
| 		case reflect.Slice, reflect.Array: | ||||
| 			var priamryKeyExprs []clause.Expression | ||||
| 			for i := 0; i < reflectValue.Len(); i++ { | ||||
| 				var exprs = make([]clause.Expression, len(stmt.Schema.PrimaryFields)) | ||||
| 				var notZero bool | ||||
| 				for idx, field := range stmt.Schema.PrimaryFields { | ||||
| 					value, isZero := field.ValueOf(reflectValue.Index(i)) | ||||
| 					exprs[idx] = clause.Eq{Column: field.DBName, Value: value} | ||||
| 					notZero = notZero || !isZero | ||||
| 				} | ||||
| 				if notZero { | ||||
| 					priamryKeyExprs = append(priamryKeyExprs, clause.And(exprs...)) | ||||
| 				} | ||||
| 			} | ||||
| 			stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(priamryKeyExprs...)}}) | ||||
| 		case reflect.Struct: | ||||
| 			for _, field := range stmt.Schema.PrimaryFields { | ||||
| 				if value, isZero := field.ValueOf(reflectValue); !isZero { | ||||
| 					stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}}) | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
|  | ||||
| @ -19,4 +19,6 @@ var ( | ||||
| 	ErrMissingWhereClause = errors.New("missing WHERE clause while deleting") | ||||
| 	// ErrUnsupportedRelation unsupported relations
 | ||||
| 	ErrUnsupportedRelation = errors.New("unsupported relations") | ||||
| 	// ErrPtrStructSupported only ptr of struct supported
 | ||||
| 	ErrPtrStructSupported = errors.New("only ptr of struct supported") | ||||
| ) | ||||
|  | ||||
| @ -23,7 +23,11 @@ func (db *DB) Save(value interface{}) (tx *DB) { | ||||
| 
 | ||||
| 	if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil { | ||||
| 		where := clause.Where{Exprs: make([]clause.Expression, len(tx.Statement.Schema.PrimaryFields))} | ||||
| 		reflectValue := reflect.ValueOf(value) | ||||
| 		reflectValue := reflect.Indirect(reflect.ValueOf(value)) | ||||
| 		switch reflectValue.Kind() { | ||||
| 		case reflect.Slice, reflect.Array: | ||||
| 			tx.AddError(ErrPtrStructSupported) | ||||
| 		case reflect.Struct: | ||||
| 			for idx, pf := range tx.Statement.Schema.PrimaryFields { | ||||
| 				if pv, isZero := pf.ValueOf(reflectValue); isZero { | ||||
| 					tx.callbacks.Create().Execute(tx) | ||||
| @ -31,6 +35,7 @@ func (db *DB) Save(value interface{}) (tx *DB) { | ||||
| 					return | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		tx.Statement.AddClause(where) | ||||
| 	} | ||||
|  | ||||
| @ -6,7 +6,26 @@ import ( | ||||
| 	. "github.com/jinzhu/gorm/tests" | ||||
| ) | ||||
| 
 | ||||
| func TestAssociationForBelongsTo(t *testing.T) { | ||||
| func AssertAssociationCount(t *testing.T, data interface{}, name string, result int64, reason string) { | ||||
| 	if count := DB.Model(data).Association(name).Count(); count != result { | ||||
| 		t.Errorf("invalid %v count %v, expects: %v got %v", name, reason, result, count) | ||||
| 	} | ||||
| 
 | ||||
| 	var newUser User | ||||
| 	if user, ok := data.(User); ok { | ||||
| 		DB.Find(&newUser, "id = ?", user.ID) | ||||
| 	} else if user, ok := data.(*User); ok { | ||||
| 		DB.Find(&newUser, "id = ?", user.ID) | ||||
| 	} | ||||
| 
 | ||||
| 	if newUser.ID != 0 { | ||||
| 		if count := DB.Model(&newUser).Association(name).Count(); count != result { | ||||
| 			t.Errorf("invalid %v count %v, expects: %v got %v", name, reason, result, count) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestBelongsToAssociation(t *testing.T) { | ||||
| 	var user = *GetUser("belongs-to", Config{Company: true, Manager: true}) | ||||
| 
 | ||||
| 	if err := DB.Create(&user).Error; err != nil { | ||||
| @ -24,13 +43,8 @@ func TestAssociationForBelongsTo(t *testing.T) { | ||||
| 	CheckUser(t, user2, user) | ||||
| 
 | ||||
| 	// Count
 | ||||
| 	if count := DB.Model(&user).Association("Company").Count(); count != 1 { | ||||
| 		t.Errorf("invalid company count, got %v", count) | ||||
| 	} | ||||
| 
 | ||||
| 	if count := DB.Model(&user).Association("Manager").Count(); count != 1 { | ||||
| 		t.Errorf("invalid manager count, got %v", count) | ||||
| 	} | ||||
| 	AssertAssociationCount(t, user, "Company", 1, "") | ||||
| 	AssertAssociationCount(t, user, "Manager", 1, "") | ||||
| 
 | ||||
| 	// Append
 | ||||
| 	var company = Company{Name: "company-belongs-to-append"} | ||||
| @ -58,6 +72,9 @@ func TestAssociationForBelongsTo(t *testing.T) { | ||||
| 	user.ManagerID = &manager.ID | ||||
| 	CheckUser(t, user2, user) | ||||
| 
 | ||||
| 	AssertAssociationCount(t, user2, "Company", 1, "AfterAppend") | ||||
| 	AssertAssociationCount(t, user2, "Manager", 1, "AfterAppend") | ||||
| 
 | ||||
| 	// Replace
 | ||||
| 	var company2 = Company{Name: "company-belongs-to-replace"} | ||||
| 	var manager2 = GetUser("manager-belongs-to-replace", Config{}) | ||||
| @ -84,40 +101,31 @@ func TestAssociationForBelongsTo(t *testing.T) { | ||||
| 	user.ManagerID = &manager2.ID | ||||
| 	CheckUser(t, user2, user) | ||||
| 
 | ||||
| 	AssertAssociationCount(t, user2, "Company", 1, "AfterReplace") | ||||
| 	AssertAssociationCount(t, user2, "Manager", 1, "AfterReplace") | ||||
| 
 | ||||
| 	// Delete
 | ||||
| 	if err := DB.Model(&user2).Association("Company").Delete(&Company{}); err != nil { | ||||
| 		t.Fatalf("Error happened when delete Company, got %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if count := DB.Model(&user2).Association("Company").Count(); count != 1 { | ||||
| 		t.Errorf("Invalid company count after delete non-existing association, got %v", count) | ||||
| 	} | ||||
| 	AssertAssociationCount(t, user2, "Company", 1, "after delete non-existing data") | ||||
| 
 | ||||
| 	if err := DB.Model(&user2).Association("Company").Delete(&company2); err != nil { | ||||
| 		t.Fatalf("Error happened when delete Company, got %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if count := DB.Model(&user2).Association("Company").Count(); count != 0 { | ||||
| 		t.Errorf("Invalid company count after delete, got %v", count) | ||||
| 	} | ||||
| 	AssertAssociationCount(t, user2, "Company", 0, "after delete") | ||||
| 
 | ||||
| 	if err := DB.Model(&user2).Association("Manager").Delete(&User{}); err != nil { | ||||
| 		t.Fatalf("Error happened when delete Manager, got %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if count := DB.Model(&user2).Association("Manager").Count(); count != 1 { | ||||
| 		t.Errorf("Invalid manager count after delete non-existing association, got %v", count) | ||||
| 	} | ||||
| 	AssertAssociationCount(t, user2, "Manager", 1, "after delete non-existing data") | ||||
| 
 | ||||
| 	if err := DB.Model(&user2).Association("Manager").Delete(manager2); err != nil { | ||||
| 		t.Fatalf("Error happened when delete Manager, got %v", err) | ||||
| 	} | ||||
| 	AssertAssociationCount(t, user2, "Manager", 0, "after delete") | ||||
| 
 | ||||
| 	if count := DB.Model(&user2).Association("Manager").Count(); count != 0 { | ||||
| 		t.Errorf("Invalid manager count after delete, got %v", count) | ||||
| 	} | ||||
| 
 | ||||
| 	// Prepare Data
 | ||||
| 	// Prepare Data for Clear
 | ||||
| 	if err := DB.Model(&user2).Association("Company").Append(&company); err != nil { | ||||
| 		t.Fatalf("Error happened when append Company, got %v", err) | ||||
| 	} | ||||
| @ -126,13 +134,8 @@ func TestAssociationForBelongsTo(t *testing.T) { | ||||
| 		t.Fatalf("Error happened when append Manager, got %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if count := DB.Model(&user2).Association("Company").Count(); count != 1 { | ||||
| 		t.Errorf("Invalid company count after append, got %v", count) | ||||
| 	} | ||||
| 
 | ||||
| 	if count := DB.Model(&user2).Association("Manager").Count(); count != 1 { | ||||
| 		t.Errorf("Invalid manager count after append, got %v", count) | ||||
| 	} | ||||
| 	AssertAssociationCount(t, user2, "Company", 1, "after prepare data") | ||||
| 	AssertAssociationCount(t, user2, "Manager", 1, "after prepare data") | ||||
| 
 | ||||
| 	// Clear
 | ||||
| 	if err := DB.Model(&user2).Association("Company").Clear(); err != nil { | ||||
| @ -143,11 +146,43 @@ func TestAssociationForBelongsTo(t *testing.T) { | ||||
| 		t.Errorf("Error happened when clear Manager, got %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if count := DB.Model(&user2).Association("Company").Count(); count != 0 { | ||||
| 		t.Errorf("Invalid company count after clear, got %v", count) | ||||
| 	AssertAssociationCount(t, user2, "Company", 0, "after clear") | ||||
| 	AssertAssociationCount(t, user2, "Manager", 0, "after clear") | ||||
| } | ||||
| 
 | ||||
| func TestBelongsToAssociationForSlice(t *testing.T) { | ||||
| 	var users = []User{ | ||||
| 		*GetUser("slice-belongs-to-1", Config{Company: true, Manager: true}), | ||||
| 		*GetUser("slice-belongs-to-2", Config{Company: true, Manager: false}), | ||||
| 		*GetUser("slice-belongs-to-3", Config{Company: true, Manager: true}), | ||||
| 	} | ||||
| 
 | ||||
| 	if count := DB.Model(&user2).Association("Manager").Count(); count != 0 { | ||||
| 		t.Errorf("Invalid manager count after clear, got %v", count) | ||||
| 	DB.Create(&users) | ||||
| 
 | ||||
| 	AssertAssociationCount(t, users, "Company", 3, "") | ||||
| 	AssertAssociationCount(t, users, "Manager", 2, "") | ||||
| 
 | ||||
| 	// Find
 | ||||
| 	var companies []Company | ||||
| 	if DB.Model(users).Association("Company").Find(&companies); len(companies) != 3 { | ||||
| 		t.Errorf("companies count should be %v, but got %v", 3, len(companies)) | ||||
| 	} | ||||
| 
 | ||||
| 	var managers []User | ||||
| 	if DB.Model(users).Association("Manager").Find(&managers); len(managers) != 2 { | ||||
| 		t.Errorf("managers count should be %v, but got %v", 2, len(managers)) | ||||
| 	} | ||||
| 
 | ||||
| 	// Append
 | ||||
| 
 | ||||
| 	// Replace
 | ||||
| 
 | ||||
| 	// Delete
 | ||||
| 
 | ||||
| 	// Clear
 | ||||
| 	DB.Model(&users).Association("Company").Clear() | ||||
| 	AssertAssociationCount(t, users, "Company", 0, "After Clear") | ||||
| 
 | ||||
| 	DB.Model(&users).Association("Manager").Clear() | ||||
| 	AssertAssociationCount(t, users, "Manager", 0, "After Clear") | ||||
| } | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu