Add Slice Association for BelongsTo
This commit is contained in:
		
							parent
							
								
									91a695893c
								
							
						
					
					
						commit
						2db33730b6
					
				| @ -366,6 +366,11 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ | |||||||
| 			if clear && len(values) == 0 { | 			if clear && len(values) == 0 { | ||||||
| 				for i := 0; i < reflectValue.Len(); i++ { | 				for i := 0; i < reflectValue.Len(); i++ { | ||||||
| 					association.Relationship.Field.Set(reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType).Interface()) | 					association.Relationship.Field.Set(reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType).Interface()) | ||||||
|  | 					for _, ref := range association.Relationship.References { | ||||||
|  | 						if !ref.OwnPrimaryKey { | ||||||
|  | 							ref.ForeignKey.Set(reflectValue.Index(i), reflect.Zero(ref.ForeignKey.FieldType).Interface()) | ||||||
|  | 						} | ||||||
|  | 					} | ||||||
| 				} | 				} | ||||||
| 				break | 				break | ||||||
| 			} | 			} | ||||||
| @ -382,6 +387,11 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ | |||||||
| 	case reflect.Struct: | 	case reflect.Struct: | ||||||
| 		if clear && len(values) == 0 { | 		if clear && len(values) == 0 { | ||||||
| 			association.Relationship.Field.Set(reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType).Interface()) | 			association.Relationship.Field.Set(reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType).Interface()) | ||||||
|  | 			for _, ref := range association.Relationship.References { | ||||||
|  | 				if !ref.OwnPrimaryKey { | ||||||
|  | 					ref.ForeignKey.Set(reflectValue, reflect.Zero(ref.ForeignKey.FieldType).Interface()) | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		for idx, value := range values { | 		for idx, value := range values { | ||||||
| @ -392,10 +402,12 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ | |||||||
| 		_, hasZero = association.DB.Statement.Schema.PrioritizedPrimaryField.ValueOf(reflectValue) | 		_, hasZero = association.DB.Statement.Schema.PrioritizedPrimaryField.ValueOf(reflectValue) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if hasZero { | 	if len(values) > 0 { | ||||||
| 		association.DB.Save(reflectValue.Addr().Interface()) | 		if hasZero { | ||||||
| 	} else { | 			association.DB.Create(reflectValue.Addr().Interface()) | ||||||
| 		association.DB.Select(selectedColumns).Save(reflectValue.Addr().Interface()) | 		} else { | ||||||
|  | 			association.DB.Select(selectedColumns).Model(nil).Save(reflectValue.Addr().Interface()) | ||||||
|  | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	for _, assignBack := range assignBacks { | 	for _, assignBack := range assignBacks { | ||||||
|  | |||||||
| @ -173,10 +173,28 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { | |||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if stmt.Dest != stmt.Model { | 	if stmt.Dest != stmt.Model { | ||||||
| 		reflectValue := reflect.ValueOf(stmt.Model) | 		reflectValue := reflect.Indirect(reflect.ValueOf(stmt.Model)) | ||||||
| 		for _, field := range stmt.Schema.PrimaryFields { | 		switch reflectValue.Kind() { | ||||||
| 			if value, isZero := field.ValueOf(reflectValue); !isZero { | 		case reflect.Slice, reflect.Array: | ||||||
| 				stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}}) | 			var priamryKeyExprs []clause.Expression | ||||||
|  | 			for i := 0; i < reflectValue.Len(); i++ { | ||||||
|  | 				var exprs = make([]clause.Expression, len(stmt.Schema.PrimaryFields)) | ||||||
|  | 				var notZero bool | ||||||
|  | 				for idx, field := range stmt.Schema.PrimaryFields { | ||||||
|  | 					value, isZero := field.ValueOf(reflectValue.Index(i)) | ||||||
|  | 					exprs[idx] = clause.Eq{Column: field.DBName, Value: value} | ||||||
|  | 					notZero = notZero || !isZero | ||||||
|  | 				} | ||||||
|  | 				if notZero { | ||||||
|  | 					priamryKeyExprs = append(priamryKeyExprs, clause.And(exprs...)) | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 			stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(priamryKeyExprs...)}}) | ||||||
|  | 		case reflect.Struct: | ||||||
|  | 			for _, field := range stmt.Schema.PrimaryFields { | ||||||
|  | 				if value, isZero := field.ValueOf(reflectValue); !isZero { | ||||||
|  | 					stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}}) | ||||||
|  | 				} | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  | |||||||
| @ -19,4 +19,6 @@ var ( | |||||||
| 	ErrMissingWhereClause = errors.New("missing WHERE clause while deleting") | 	ErrMissingWhereClause = errors.New("missing WHERE clause while deleting") | ||||||
| 	// ErrUnsupportedRelation unsupported relations
 | 	// ErrUnsupportedRelation unsupported relations
 | ||||||
| 	ErrUnsupportedRelation = errors.New("unsupported relations") | 	ErrUnsupportedRelation = errors.New("unsupported relations") | ||||||
|  | 	// ErrPtrStructSupported only ptr of struct supported
 | ||||||
|  | 	ErrPtrStructSupported = errors.New("only ptr of struct supported") | ||||||
| ) | ) | ||||||
|  | |||||||
| @ -23,12 +23,17 @@ func (db *DB) Save(value interface{}) (tx *DB) { | |||||||
| 
 | 
 | ||||||
| 	if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil { | 	if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil { | ||||||
| 		where := clause.Where{Exprs: make([]clause.Expression, len(tx.Statement.Schema.PrimaryFields))} | 		where := clause.Where{Exprs: make([]clause.Expression, len(tx.Statement.Schema.PrimaryFields))} | ||||||
| 		reflectValue := reflect.ValueOf(value) | 		reflectValue := reflect.Indirect(reflect.ValueOf(value)) | ||||||
| 		for idx, pf := range tx.Statement.Schema.PrimaryFields { | 		switch reflectValue.Kind() { | ||||||
| 			if pv, isZero := pf.ValueOf(reflectValue); isZero { | 		case reflect.Slice, reflect.Array: | ||||||
| 				tx.callbacks.Create().Execute(tx) | 			tx.AddError(ErrPtrStructSupported) | ||||||
| 				where.Exprs[idx] = clause.Eq{Column: pf.DBName, Value: pv} | 		case reflect.Struct: | ||||||
| 				return | 			for idx, pf := range tx.Statement.Schema.PrimaryFields { | ||||||
|  | 				if pv, isZero := pf.ValueOf(reflectValue); isZero { | ||||||
|  | 					tx.callbacks.Create().Execute(tx) | ||||||
|  | 					where.Exprs[idx] = clause.Eq{Column: pf.DBName, Value: pv} | ||||||
|  | 					return | ||||||
|  | 				} | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -6,7 +6,26 @@ import ( | |||||||
| 	. "github.com/jinzhu/gorm/tests" | 	. "github.com/jinzhu/gorm/tests" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func TestAssociationForBelongsTo(t *testing.T) { | func AssertAssociationCount(t *testing.T, data interface{}, name string, result int64, reason string) { | ||||||
|  | 	if count := DB.Model(data).Association(name).Count(); count != result { | ||||||
|  | 		t.Errorf("invalid %v count %v, expects: %v got %v", name, reason, result, count) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	var newUser User | ||||||
|  | 	if user, ok := data.(User); ok { | ||||||
|  | 		DB.Find(&newUser, "id = ?", user.ID) | ||||||
|  | 	} else if user, ok := data.(*User); ok { | ||||||
|  | 		DB.Find(&newUser, "id = ?", user.ID) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if newUser.ID != 0 { | ||||||
|  | 		if count := DB.Model(&newUser).Association(name).Count(); count != result { | ||||||
|  | 			t.Errorf("invalid %v count %v, expects: %v got %v", name, reason, result, count) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestBelongsToAssociation(t *testing.T) { | ||||||
| 	var user = *GetUser("belongs-to", Config{Company: true, Manager: true}) | 	var user = *GetUser("belongs-to", Config{Company: true, Manager: true}) | ||||||
| 
 | 
 | ||||||
| 	if err := DB.Create(&user).Error; err != nil { | 	if err := DB.Create(&user).Error; err != nil { | ||||||
| @ -24,13 +43,8 @@ func TestAssociationForBelongsTo(t *testing.T) { | |||||||
| 	CheckUser(t, user2, user) | 	CheckUser(t, user2, user) | ||||||
| 
 | 
 | ||||||
| 	// Count
 | 	// Count
 | ||||||
| 	if count := DB.Model(&user).Association("Company").Count(); count != 1 { | 	AssertAssociationCount(t, user, "Company", 1, "") | ||||||
| 		t.Errorf("invalid company count, got %v", count) | 	AssertAssociationCount(t, user, "Manager", 1, "") | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	if count := DB.Model(&user).Association("Manager").Count(); count != 1 { |  | ||||||
| 		t.Errorf("invalid manager count, got %v", count) |  | ||||||
| 	} |  | ||||||
| 
 | 
 | ||||||
| 	// Append
 | 	// Append
 | ||||||
| 	var company = Company{Name: "company-belongs-to-append"} | 	var company = Company{Name: "company-belongs-to-append"} | ||||||
| @ -58,6 +72,9 @@ func TestAssociationForBelongsTo(t *testing.T) { | |||||||
| 	user.ManagerID = &manager.ID | 	user.ManagerID = &manager.ID | ||||||
| 	CheckUser(t, user2, user) | 	CheckUser(t, user2, user) | ||||||
| 
 | 
 | ||||||
|  | 	AssertAssociationCount(t, user2, "Company", 1, "AfterAppend") | ||||||
|  | 	AssertAssociationCount(t, user2, "Manager", 1, "AfterAppend") | ||||||
|  | 
 | ||||||
| 	// Replace
 | 	// Replace
 | ||||||
| 	var company2 = Company{Name: "company-belongs-to-replace"} | 	var company2 = Company{Name: "company-belongs-to-replace"} | ||||||
| 	var manager2 = GetUser("manager-belongs-to-replace", Config{}) | 	var manager2 = GetUser("manager-belongs-to-replace", Config{}) | ||||||
| @ -84,40 +101,31 @@ func TestAssociationForBelongsTo(t *testing.T) { | |||||||
| 	user.ManagerID = &manager2.ID | 	user.ManagerID = &manager2.ID | ||||||
| 	CheckUser(t, user2, user) | 	CheckUser(t, user2, user) | ||||||
| 
 | 
 | ||||||
|  | 	AssertAssociationCount(t, user2, "Company", 1, "AfterReplace") | ||||||
|  | 	AssertAssociationCount(t, user2, "Manager", 1, "AfterReplace") | ||||||
|  | 
 | ||||||
| 	// Delete
 | 	// Delete
 | ||||||
| 	if err := DB.Model(&user2).Association("Company").Delete(&Company{}); err != nil { | 	if err := DB.Model(&user2).Association("Company").Delete(&Company{}); err != nil { | ||||||
| 		t.Fatalf("Error happened when delete Company, got %v", err) | 		t.Fatalf("Error happened when delete Company, got %v", err) | ||||||
| 	} | 	} | ||||||
| 
 | 	AssertAssociationCount(t, user2, "Company", 1, "after delete non-existing data") | ||||||
| 	if count := DB.Model(&user2).Association("Company").Count(); count != 1 { |  | ||||||
| 		t.Errorf("Invalid company count after delete non-existing association, got %v", count) |  | ||||||
| 	} |  | ||||||
| 
 | 
 | ||||||
| 	if err := DB.Model(&user2).Association("Company").Delete(&company2); err != nil { | 	if err := DB.Model(&user2).Association("Company").Delete(&company2); err != nil { | ||||||
| 		t.Fatalf("Error happened when delete Company, got %v", err) | 		t.Fatalf("Error happened when delete Company, got %v", err) | ||||||
| 	} | 	} | ||||||
| 
 | 	AssertAssociationCount(t, user2, "Company", 0, "after delete") | ||||||
| 	if count := DB.Model(&user2).Association("Company").Count(); count != 0 { |  | ||||||
| 		t.Errorf("Invalid company count after delete, got %v", count) |  | ||||||
| 	} |  | ||||||
| 
 | 
 | ||||||
| 	if err := DB.Model(&user2).Association("Manager").Delete(&User{}); err != nil { | 	if err := DB.Model(&user2).Association("Manager").Delete(&User{}); err != nil { | ||||||
| 		t.Fatalf("Error happened when delete Manager, got %v", err) | 		t.Fatalf("Error happened when delete Manager, got %v", err) | ||||||
| 	} | 	} | ||||||
| 
 | 	AssertAssociationCount(t, user2, "Manager", 1, "after delete non-existing data") | ||||||
| 	if count := DB.Model(&user2).Association("Manager").Count(); count != 1 { |  | ||||||
| 		t.Errorf("Invalid manager count after delete non-existing association, got %v", count) |  | ||||||
| 	} |  | ||||||
| 
 | 
 | ||||||
| 	if err := DB.Model(&user2).Association("Manager").Delete(manager2); err != nil { | 	if err := DB.Model(&user2).Association("Manager").Delete(manager2); err != nil { | ||||||
| 		t.Fatalf("Error happened when delete Manager, got %v", err) | 		t.Fatalf("Error happened when delete Manager, got %v", err) | ||||||
| 	} | 	} | ||||||
|  | 	AssertAssociationCount(t, user2, "Manager", 0, "after delete") | ||||||
| 
 | 
 | ||||||
| 	if count := DB.Model(&user2).Association("Manager").Count(); count != 0 { | 	// Prepare Data for Clear
 | ||||||
| 		t.Errorf("Invalid manager count after delete, got %v", count) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	// Prepare Data
 |  | ||||||
| 	if err := DB.Model(&user2).Association("Company").Append(&company); err != nil { | 	if err := DB.Model(&user2).Association("Company").Append(&company); err != nil { | ||||||
| 		t.Fatalf("Error happened when append Company, got %v", err) | 		t.Fatalf("Error happened when append Company, got %v", err) | ||||||
| 	} | 	} | ||||||
| @ -126,13 +134,8 @@ func TestAssociationForBelongsTo(t *testing.T) { | |||||||
| 		t.Fatalf("Error happened when append Manager, got %v", err) | 		t.Fatalf("Error happened when append Manager, got %v", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if count := DB.Model(&user2).Association("Company").Count(); count != 1 { | 	AssertAssociationCount(t, user2, "Company", 1, "after prepare data") | ||||||
| 		t.Errorf("Invalid company count after append, got %v", count) | 	AssertAssociationCount(t, user2, "Manager", 1, "after prepare data") | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	if count := DB.Model(&user2).Association("Manager").Count(); count != 1 { |  | ||||||
| 		t.Errorf("Invalid manager count after append, got %v", count) |  | ||||||
| 	} |  | ||||||
| 
 | 
 | ||||||
| 	// Clear
 | 	// Clear
 | ||||||
| 	if err := DB.Model(&user2).Association("Company").Clear(); err != nil { | 	if err := DB.Model(&user2).Association("Company").Clear(); err != nil { | ||||||
| @ -143,11 +146,43 @@ func TestAssociationForBelongsTo(t *testing.T) { | |||||||
| 		t.Errorf("Error happened when clear Manager, got %v", err) | 		t.Errorf("Error happened when clear Manager, got %v", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if count := DB.Model(&user2).Association("Company").Count(); count != 0 { | 	AssertAssociationCount(t, user2, "Company", 0, "after clear") | ||||||
| 		t.Errorf("Invalid company count after clear, got %v", count) | 	AssertAssociationCount(t, user2, "Manager", 0, "after clear") | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestBelongsToAssociationForSlice(t *testing.T) { | ||||||
|  | 	var users = []User{ | ||||||
|  | 		*GetUser("slice-belongs-to-1", Config{Company: true, Manager: true}), | ||||||
|  | 		*GetUser("slice-belongs-to-2", Config{Company: true, Manager: false}), | ||||||
|  | 		*GetUser("slice-belongs-to-3", Config{Company: true, Manager: true}), | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if count := DB.Model(&user2).Association("Manager").Count(); count != 0 { | 	DB.Create(&users) | ||||||
| 		t.Errorf("Invalid manager count after clear, got %v", count) | 
 | ||||||
|  | 	AssertAssociationCount(t, users, "Company", 3, "") | ||||||
|  | 	AssertAssociationCount(t, users, "Manager", 2, "") | ||||||
|  | 
 | ||||||
|  | 	// Find
 | ||||||
|  | 	var companies []Company | ||||||
|  | 	if DB.Model(users).Association("Company").Find(&companies); len(companies) != 3 { | ||||||
|  | 		t.Errorf("companies count should be %v, but got %v", 3, len(companies)) | ||||||
| 	} | 	} | ||||||
|  | 
 | ||||||
|  | 	var managers []User | ||||||
|  | 	if DB.Model(users).Association("Manager").Find(&managers); len(managers) != 2 { | ||||||
|  | 		t.Errorf("managers count should be %v, but got %v", 2, len(managers)) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Append
 | ||||||
|  | 
 | ||||||
|  | 	// Replace
 | ||||||
|  | 
 | ||||||
|  | 	// Delete
 | ||||||
|  | 
 | ||||||
|  | 	// Clear
 | ||||||
|  | 	DB.Model(&users).Association("Company").Clear() | ||||||
|  | 	AssertAssociationCount(t, users, "Company", 0, "After Clear") | ||||||
|  | 
 | ||||||
|  | 	DB.Model(&users).Association("Manager").Clear() | ||||||
|  | 	AssertAssociationCount(t, users, "Manager", 0, "After Clear") | ||||||
| } | } | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu