Allow to omit fields in associations, close #3752
This commit is contained in:
		
							parent
							
								
									50df9da6a1
								
							
						
					
					
						commit
						54b80b18bc
					
				| @ -2,6 +2,7 @@ package callbacks | |||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"reflect" | 	"reflect" | ||||||
|  | 	"strings" | ||||||
| 
 | 
 | ||||||
| 	"gorm.io/gorm" | 	"gorm.io/gorm" | ||||||
| 	"gorm.io/gorm/clause" | 	"gorm.io/gorm/clause" | ||||||
| @ -66,7 +67,7 @@ func SaveBeforeAssociations(db *gorm.DB) { | |||||||
| 				} | 				} | ||||||
| 
 | 
 | ||||||
| 				if elems.Len() > 0 { | 				if elems.Len() > 0 { | ||||||
| 					if db.AddError(db.Session(&gorm.Session{NewDB: true}).Clauses(onConflictOption(db.Statement, rel.FieldSchema, nil)).Create(elems.Interface()).Error) == nil { | 					if saveAssociations(db, selectColumns, rel.Name, onConflictOption(db.Statement, rel.FieldSchema, nil), elems.Interface()) == nil { | ||||||
| 						for i := 0; i < elems.Len(); i++ { | 						for i := 0; i < elems.Len(); i++ { | ||||||
| 							setupReferences(objs[i], elems.Index(i)) | 							setupReferences(objs[i], elems.Index(i)) | ||||||
| 						} | 						} | ||||||
| @ -79,7 +80,7 @@ func SaveBeforeAssociations(db *gorm.DB) { | |||||||
| 						rv = rv.Addr() | 						rv = rv.Addr() | ||||||
| 					} | 					} | ||||||
| 
 | 
 | ||||||
| 					if db.AddError(db.Session(&gorm.Session{NewDB: true}).Clauses(onConflictOption(db.Statement, rel.FieldSchema, nil)).Create(rv.Interface()).Error) == nil { | 					if saveAssociations(db, selectColumns, rel.Name, onConflictOption(db.Statement, rel.FieldSchema, nil), rv.Interface()) == nil { | ||||||
| 						setupReferences(db.Statement.ReflectValue, rv) | 						setupReferences(db.Statement.ReflectValue, rv) | ||||||
| 					} | 					} | ||||||
| 				} | 				} | ||||||
| @ -141,9 +142,7 @@ func SaveAfterAssociations(db *gorm.DB) { | |||||||
| 						assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) | 						assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) | ||||||
| 					} | 					} | ||||||
| 
 | 
 | ||||||
| 					db.AddError(db.Session(&gorm.Session{NewDB: true}).Clauses( | 					saveAssociations(db, selectColumns, rel.Name, onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns), elems.Interface()) | ||||||
| 						onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns), |  | ||||||
| 					).Create(elems.Interface()).Error) |  | ||||||
| 				} | 				} | ||||||
| 			case reflect.Struct: | 			case reflect.Struct: | ||||||
| 				if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero { | 				if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero { | ||||||
| @ -163,9 +162,7 @@ func SaveAfterAssociations(db *gorm.DB) { | |||||||
| 						assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) | 						assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) | ||||||
| 					} | 					} | ||||||
| 
 | 
 | ||||||
| 					db.AddError(db.Session(&gorm.Session{NewDB: true}).Clauses( | 					saveAssociations(db, selectColumns, rel.Name, onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns), f.Interface()) | ||||||
| 						onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns), |  | ||||||
| 					).Create(f.Interface()).Error) |  | ||||||
| 				} | 				} | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| @ -224,9 +221,7 @@ func SaveAfterAssociations(db *gorm.DB) { | |||||||
| 					assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) | 					assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) | ||||||
| 				} | 				} | ||||||
| 
 | 
 | ||||||
| 				db.AddError(db.Session(&gorm.Session{NewDB: true}).Clauses( | 				saveAssociations(db, selectColumns, rel.Name, onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns), elems.Interface()) | ||||||
| 					onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns), |  | ||||||
| 				).Create(elems.Interface()).Error) |  | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| @ -292,7 +287,7 @@ func SaveAfterAssociations(db *gorm.DB) { | |||||||
| 
 | 
 | ||||||
| 			if elems.Len() > 0 { | 			if elems.Len() > 0 { | ||||||
| 				if v, ok := selectColumns[rel.Name+".*"]; !ok || v { | 				if v, ok := selectColumns[rel.Name+".*"]; !ok || v { | ||||||
| 					db.AddError(db.Session(&gorm.Session{NewDB: true}).Clauses(onConflictOption(db.Statement, rel.FieldSchema, nil)).Create(elems.Interface()).Error) | 					saveAssociations(db, selectColumns, rel.Name, onConflictOption(db.Statement, rel.FieldSchema, nil), elems.Interface()) | ||||||
| 				} | 				} | ||||||
| 
 | 
 | ||||||
| 				for i := 0; i < elems.Len(); i++ { | 				for i := 0; i < elems.Len(); i++ { | ||||||
| @ -335,3 +330,37 @@ func onConflictOption(stmt *gorm.Statement, s *schema.Schema, defaultUpdatingCol | |||||||
| 
 | 
 | ||||||
| 	return clause.OnConflict{DoNothing: true} | 	return clause.OnConflict{DoNothing: true} | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | func saveAssociations(db *gorm.DB, selectColumns map[string]bool, refName string, onConflict clause.OnConflict, values interface{}) error { | ||||||
|  | 	var selects, omits []string | ||||||
|  | 	refName = refName + "." | ||||||
|  | 
 | ||||||
|  | 	for name, ok := range selectColumns { | ||||||
|  | 		columnName := "" | ||||||
|  | 		if strings.HasPrefix(name, refName) { | ||||||
|  | 			columnName = strings.TrimPrefix(name, refName) | ||||||
|  | 		} else if strings.HasPrefix(name, clause.Associations) { | ||||||
|  | 			columnName = name | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if columnName != "" { | ||||||
|  | 			if ok { | ||||||
|  | 				selects = append(selects, columnName) | ||||||
|  | 			} else { | ||||||
|  | 				omits = append(omits, columnName) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	tx := db.Session(&gorm.Session{NewDB: true}).Clauses(onConflict) | ||||||
|  | 
 | ||||||
|  | 	if len(selects) > 0 { | ||||||
|  | 		tx = tx.Select(selects) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if len(omits) > 0 { | ||||||
|  | 		tx = tx.Omit(omits...) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return db.AddError(tx.Create(values).Error) | ||||||
|  | } | ||||||
|  | |||||||
| @ -83,6 +83,20 @@ func TestHasOneAssociation(t *testing.T) { | |||||||
| 	AssertAssociationCount(t, user2, "Account", 0, "after clear") | 	AssertAssociationCount(t, user2, "Account", 0, "after clear") | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func TestHasOneAssociationWithSelect(t *testing.T) { | ||||||
|  | 	var user = *GetUser("hasone", Config{Account: true}) | ||||||
|  | 
 | ||||||
|  | 	DB.Omit("Account.Number").Create(&user) | ||||||
|  | 
 | ||||||
|  | 	AssertAssociationCount(t, user, "Account", 1, "") | ||||||
|  | 
 | ||||||
|  | 	var account Account | ||||||
|  | 	DB.Model(&user).Association("Account").Find(&account) | ||||||
|  | 	if account.Number != "" { | ||||||
|  | 		t.Errorf("account's number should not be saved") | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func TestHasOneAssociationForSlice(t *testing.T) { | func TestHasOneAssociationForSlice(t *testing.T) { | ||||||
| 	var users = []User{ | 	var users = []User{ | ||||||
| 		*GetUser("slice-hasone-1", Config{Account: true}), | 		*GetUser("slice-hasone-1", Config{Account: true}), | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu