feat: support embedded preload (#6137)
* feat: support embedded preload * fix lint and test * fix test...
This commit is contained in:
		
							parent
							
								
									4b0da0e97a
								
							
						
					
					
						commit
						828e22b17f
					
				| @ -3,6 +3,7 @@ package callbacks | |||||||
| import ( | import ( | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"reflect" | 	"reflect" | ||||||
|  | 	"strings" | ||||||
| 
 | 
 | ||||||
| 	"gorm.io/gorm" | 	"gorm.io/gorm" | ||||||
| 	"gorm.io/gorm/clause" | 	"gorm.io/gorm/clause" | ||||||
| @ -10,6 +11,98 @@ import ( | |||||||
| 	"gorm.io/gorm/utils" | 	"gorm.io/gorm/utils" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | // parsePreloadMap extracts nested preloads. e.g.
 | ||||||
|  | //
 | ||||||
|  | //	// schema has a "k0" relation and a "k7.k8" embedded relation
 | ||||||
|  | //	parsePreloadMap(schema, map[string][]interface{}{
 | ||||||
|  | //		clause.Associations: {"arg1"},
 | ||||||
|  | //		"k1":                {"arg2"},
 | ||||||
|  | //		"k2.k3":             {"arg3"},
 | ||||||
|  | //		"k4.k5.k6":          {"arg4"},
 | ||||||
|  | //	})
 | ||||||
|  | //	// preloadMap is
 | ||||||
|  | //	map[string]map[string][]interface{}{
 | ||||||
|  | //		"k0": {},
 | ||||||
|  | //		"k7": {
 | ||||||
|  | //			"k8": {},
 | ||||||
|  | //		},
 | ||||||
|  | //		"k1": {},
 | ||||||
|  | //		"k2": {
 | ||||||
|  | //			"k3": {"arg3"},
 | ||||||
|  | //		},
 | ||||||
|  | //		"k4": {
 | ||||||
|  | //			"k5.k6": {"arg4"},
 | ||||||
|  | //		},
 | ||||||
|  | //	}
 | ||||||
|  | func parsePreloadMap(s *schema.Schema, preloads map[string][]interface{}) map[string]map[string][]interface{} { | ||||||
|  | 	preloadMap := map[string]map[string][]interface{}{} | ||||||
|  | 	setPreloadMap := func(name, value string, args []interface{}) { | ||||||
|  | 		if _, ok := preloadMap[name]; !ok { | ||||||
|  | 			preloadMap[name] = map[string][]interface{}{} | ||||||
|  | 		} | ||||||
|  | 		if value != "" { | ||||||
|  | 			preloadMap[name][value] = args | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for name, args := range preloads { | ||||||
|  | 		preloadFields := strings.Split(name, ".") | ||||||
|  | 		value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), ".") | ||||||
|  | 		if preloadFields[0] == clause.Associations { | ||||||
|  | 			for _, relation := range s.Relationships.Relations { | ||||||
|  | 				if relation.Schema == s { | ||||||
|  | 					setPreloadMap(relation.Name, value, args) | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			for embedded, embeddedRelations := range s.Relationships.EmbeddedRelations { | ||||||
|  | 				for _, value := range embeddedValues(embeddedRelations) { | ||||||
|  | 					setPreloadMap(embedded, value, args) | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 		} else { | ||||||
|  | 			setPreloadMap(preloadFields[0], value, args) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return preloadMap | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func embeddedValues(embeddedRelations *schema.Relationships) []string { | ||||||
|  | 	if embeddedRelations == nil { | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  | 	names := make([]string, 0, len(embeddedRelations.Relations)+len(embeddedRelations.EmbeddedRelations)) | ||||||
|  | 	for _, relation := range embeddedRelations.Relations { | ||||||
|  | 		// skip first struct name
 | ||||||
|  | 		names = append(names, strings.Join(relation.Field.BindNames[1:], ".")) | ||||||
|  | 	} | ||||||
|  | 	for _, relations := range embeddedRelations.EmbeddedRelations { | ||||||
|  | 		names = append(names, embeddedValues(relations)...) | ||||||
|  | 	} | ||||||
|  | 	return names | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func preloadEmbedded(tx *gorm.DB, relationships *schema.Relationships, s *schema.Schema, preloads map[string][]interface{}, as []interface{}) error { | ||||||
|  | 	if relationships == nil { | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  | 	preloadMap := parsePreloadMap(s, preloads) | ||||||
|  | 	for name := range preloadMap { | ||||||
|  | 		if embeddedRelations := relationships.EmbeddedRelations[name]; embeddedRelations != nil { | ||||||
|  | 			if err := preloadEmbedded(tx, embeddedRelations, s, preloadMap[name], as); err != nil { | ||||||
|  | 				return err | ||||||
|  | 			} | ||||||
|  | 		} else if rel := relationships.Relations[name]; rel != nil { | ||||||
|  | 			if err := preload(tx, rel, append(preloads[name], as), preloadMap[name]); err != nil { | ||||||
|  | 				return err | ||||||
|  | 			} | ||||||
|  | 		} else { | ||||||
|  | 			return fmt.Errorf("%s: %w (embedded) for schema %s", name, gorm.ErrUnsupportedRelation, s.Name) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) error { | func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) error { | ||||||
| 	var ( | 	var ( | ||||||
| 		reflectValue     = tx.Statement.ReflectValue | 		reflectValue     = tx.Statement.ReflectValue | ||||||
|  | |||||||
| @ -267,32 +267,7 @@ func Preload(db *gorm.DB) { | |||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		preloadMap := map[string]map[string][]interface{}{} | 		preloadMap := parsePreloadMap(db.Statement.Schema, db.Statement.Preloads) | ||||||
| 		for name := range db.Statement.Preloads { |  | ||||||
| 			preloadFields := strings.Split(name, ".") |  | ||||||
| 			if preloadFields[0] == clause.Associations { |  | ||||||
| 				for _, rel := range db.Statement.Schema.Relationships.Relations { |  | ||||||
| 					if rel.Schema == db.Statement.Schema { |  | ||||||
| 						if _, ok := preloadMap[rel.Name]; !ok { |  | ||||||
| 							preloadMap[rel.Name] = map[string][]interface{}{} |  | ||||||
| 						} |  | ||||||
| 
 |  | ||||||
| 						if value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), "."); value != "" { |  | ||||||
| 							preloadMap[rel.Name][value] = db.Statement.Preloads[name] |  | ||||||
| 						} |  | ||||||
| 					} |  | ||||||
| 				} |  | ||||||
| 			} else { |  | ||||||
| 				if _, ok := preloadMap[preloadFields[0]]; !ok { |  | ||||||
| 					preloadMap[preloadFields[0]] = map[string][]interface{}{} |  | ||||||
| 				} |  | ||||||
| 
 |  | ||||||
| 				if value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), "."); value != "" { |  | ||||||
| 					preloadMap[preloadFields[0]][value] = db.Statement.Preloads[name] |  | ||||||
| 				} |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		preloadNames := make([]string, 0, len(preloadMap)) | 		preloadNames := make([]string, 0, len(preloadMap)) | ||||||
| 		for key := range preloadMap { | 		for key := range preloadMap { | ||||||
| 			preloadNames = append(preloadNames, key) | 			preloadNames = append(preloadNames, key) | ||||||
| @ -312,7 +287,9 @@ func Preload(db *gorm.DB) { | |||||||
| 		preloadDB.Statement.Unscoped = db.Statement.Unscoped | 		preloadDB.Statement.Unscoped = db.Statement.Unscoped | ||||||
| 
 | 
 | ||||||
| 		for _, name := range preloadNames { | 		for _, name := range preloadNames { | ||||||
| 			if rel := preloadDB.Statement.Schema.Relationships.Relations[name]; rel != nil { | 			if relations := preloadDB.Statement.Schema.Relationships.EmbeddedRelations[name]; relations != nil { | ||||||
|  | 				db.AddError(preloadEmbedded(preloadDB.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks}), relations, db.Statement.Schema, preloadMap[name], db.Statement.Preloads[clause.Associations])) | ||||||
|  | 			} else if rel := preloadDB.Statement.Schema.Relationships.Relations[name]; rel != nil { | ||||||
| 				db.AddError(preload(preloadDB.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks}), rel, append(db.Statement.Preloads[name], db.Statement.Preloads[clause.Associations]...), preloadMap[name])) | 				db.AddError(preload(preloadDB.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks}), rel, append(db.Statement.Preloads[name], db.Statement.Preloads[clause.Associations]...), preloadMap[name])) | ||||||
| 			} else { | 			} else { | ||||||
| 				db.AddError(fmt.Errorf("%s: %w for schema %s", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name)) | 				db.AddError(fmt.Errorf("%s: %w for schema %s", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name)) | ||||||
|  | |||||||
| @ -89,6 +89,10 @@ type Field struct { | |||||||
| 	NewValuePool           FieldNewValuePool | 	NewValuePool           FieldNewValuePool | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func (field *Field) BindName() string { | ||||||
|  | 	return strings.Join(field.BindNames, ".") | ||||||
|  | } | ||||||
|  | 
 | ||||||
| // ParseField parses reflect.StructField to Field
 | // ParseField parses reflect.StructField to Field
 | ||||||
| func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { | func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { | ||||||
| 	var ( | 	var ( | ||||||
|  | |||||||
| @ -27,6 +27,8 @@ type Relationships struct { | |||||||
| 	HasMany   []*Relationship | 	HasMany   []*Relationship | ||||||
| 	Many2Many []*Relationship | 	Many2Many []*Relationship | ||||||
| 	Relations map[string]*Relationship | 	Relations map[string]*Relationship | ||||||
|  | 
 | ||||||
|  | 	EmbeddedRelations map[string]*Relationships | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| type Relationship struct { | type Relationship struct { | ||||||
| @ -106,7 +108,7 @@ func (schema *Schema) parseRelation(field *Field) *Relationship { | |||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if schema.err == nil { | 	if schema.err == nil { | ||||||
| 		schema.Relationships.Relations[relation.Name] = relation | 		schema.setRelation(relation) | ||||||
| 		switch relation.Type { | 		switch relation.Type { | ||||||
| 		case HasOne: | 		case HasOne: | ||||||
| 			schema.Relationships.HasOne = append(schema.Relationships.HasOne, relation) | 			schema.Relationships.HasOne = append(schema.Relationships.HasOne, relation) | ||||||
| @ -122,6 +124,39 @@ func (schema *Schema) parseRelation(field *Field) *Relationship { | |||||||
| 	return relation | 	return relation | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func (schema *Schema) setRelation(relation *Relationship) { | ||||||
|  | 	// set non-embedded relation
 | ||||||
|  | 	if rel := schema.Relationships.Relations[relation.Name]; rel != nil { | ||||||
|  | 		if len(rel.Field.BindNames) > 1 { | ||||||
|  | 			schema.Relationships.Relations[relation.Name] = relation | ||||||
|  | 		} | ||||||
|  | 	} else { | ||||||
|  | 		schema.Relationships.Relations[relation.Name] = relation | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// set embedded relation
 | ||||||
|  | 	if len(relation.Field.BindNames) <= 1 { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	relationships := &schema.Relationships | ||||||
|  | 	for i, name := range relation.Field.BindNames { | ||||||
|  | 		if i < len(relation.Field.BindNames)-1 { | ||||||
|  | 			if relationships.EmbeddedRelations == nil { | ||||||
|  | 				relationships.EmbeddedRelations = map[string]*Relationships{} | ||||||
|  | 			} | ||||||
|  | 			if r := relationships.EmbeddedRelations[name]; r == nil { | ||||||
|  | 				relationships.EmbeddedRelations[name] = &Relationships{} | ||||||
|  | 			} | ||||||
|  | 			relationships = relationships.EmbeddedRelations[name] | ||||||
|  | 		} else { | ||||||
|  | 			if relationships.Relations == nil { | ||||||
|  | 				relationships.Relations = map[string]*Relationship{} | ||||||
|  | 			} | ||||||
|  | 			relationships.Relations[relation.Name] = relation | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
| // User has many Toys, its `Polymorphic` is `Owner`, Pet has one Toy, its `Polymorphic` is `Owner`
 | // User has many Toys, its `Polymorphic` is `Owner`, Pet has one Toy, its `Polymorphic` is `Owner`
 | ||||||
| //
 | //
 | ||||||
| //	type User struct {
 | //	type User struct {
 | ||||||
| @ -166,6 +201,11 @@ func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Fi | |||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
|  | 		if primaryKeyField == nil { | ||||||
|  | 			schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing primaryKey field", relation.FieldSchema, schema, field.Name) | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
| 		// use same data type for foreign keys
 | 		// use same data type for foreign keys
 | ||||||
| 		if copyableDataType(primaryKeyField.DataType) { | 		if copyableDataType(primaryKeyField.DataType) { | ||||||
| 			relation.Polymorphic.PolymorphicID.DataType = primaryKeyField.DataType | 			relation.Polymorphic.PolymorphicID.DataType = primaryKeyField.DataType | ||||||
| @ -443,6 +483,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu | |||||||
| 			primaryFields = primarySchema.PrimaryFields | 			primaryFields = primarySchema.PrimaryFields | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
|  | 	primaryFieldLoop: | ||||||
| 		for _, primaryField := range primaryFields { | 		for _, primaryField := range primaryFields { | ||||||
| 			lookUpName := primarySchemaName + primaryField.Name | 			lookUpName := primarySchemaName + primaryField.Name | ||||||
| 			if gl == guessBelongs { | 			if gl == guessBelongs { | ||||||
| @ -454,11 +495,18 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu | |||||||
| 				lookUpNames = append(lookUpNames, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID", strings.TrimSuffix(lookUpName, primaryField.Name)+"Id", schema.namer.ColumnName(foreignSchema.Table, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID")) | 				lookUpNames = append(lookUpNames, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID", strings.TrimSuffix(lookUpName, primaryField.Name)+"Id", schema.namer.ColumnName(foreignSchema.Table, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID")) | ||||||
| 			} | 			} | ||||||
| 
 | 
 | ||||||
|  | 			for _, name := range lookUpNames { | ||||||
|  | 				if f := foreignSchema.LookUpFieldByBindName(field.BindNames, name); f != nil { | ||||||
|  | 					foreignFields = append(foreignFields, f) | ||||||
|  | 					primaryFields = append(primaryFields, primaryField) | ||||||
|  | 					continue primaryFieldLoop | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
| 			for _, name := range lookUpNames { | 			for _, name := range lookUpNames { | ||||||
| 				if f := foreignSchema.LookUpField(name); f != nil { | 				if f := foreignSchema.LookUpField(name); f != nil { | ||||||
| 					foreignFields = append(foreignFields, f) | 					foreignFields = append(foreignFields, f) | ||||||
| 					primaryFields = append(primaryFields, primaryField) | 					primaryFields = append(primaryFields, primaryField) | ||||||
| 					break | 					continue primaryFieldLoop | ||||||
| 				} | 				} | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
|  | |||||||
| @ -518,6 +518,132 @@ func TestEmbeddedRelation(t *testing.T) { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func TestEmbeddedHas(t *testing.T) { | ||||||
|  | 	type Toy struct { | ||||||
|  | 		ID        int | ||||||
|  | 		Name      string | ||||||
|  | 		OwnerID   int | ||||||
|  | 		OwnerType string | ||||||
|  | 	} | ||||||
|  | 	type User struct { | ||||||
|  | 		ID  int | ||||||
|  | 		Cat struct { | ||||||
|  | 			Name string | ||||||
|  | 			Toy  Toy   `gorm:"polymorphic:Owner;"` | ||||||
|  | 			Toys []Toy `gorm:"polymorphic:Owner;"` | ||||||
|  | 		} `gorm:"embedded;embeddedPrefix:cat_"` | ||||||
|  | 		Dog struct { | ||||||
|  | 			ID     int | ||||||
|  | 			Name   string | ||||||
|  | 			UserID int | ||||||
|  | 			Toy    Toy   `gorm:"polymorphic:Owner;"` | ||||||
|  | 			Toys   []Toy `gorm:"polymorphic:Owner;"` | ||||||
|  | 		} | ||||||
|  | 		Toys []Toy `gorm:"polymorphic:Owner;"` | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	s, err := schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{}) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("Failed to parse schema, got error %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{ | ||||||
|  | 		"Cat": { | ||||||
|  | 			Relations: map[string]Relation{ | ||||||
|  | 				"Toy": { | ||||||
|  | 					Name:        "Toy", | ||||||
|  | 					Type:        schema.HasOne, | ||||||
|  | 					Schema:      "User", | ||||||
|  | 					FieldSchema: "Toy", | ||||||
|  | 					Polymorphic: Polymorphic{ID: "OwnerID", Type: "OwnerType", Value: "users"}, | ||||||
|  | 					References: []Reference{ | ||||||
|  | 						{ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"}, | ||||||
|  | 						{ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"}, | ||||||
|  | 					}, | ||||||
|  | 				}, | ||||||
|  | 				"Toys": { | ||||||
|  | 					Name:        "Toys", | ||||||
|  | 					Type:        schema.HasMany, | ||||||
|  | 					Schema:      "User", | ||||||
|  | 					FieldSchema: "Toy", | ||||||
|  | 					Polymorphic: Polymorphic{ID: "OwnerID", Type: "OwnerType", Value: "users"}, | ||||||
|  | 					References: []Reference{ | ||||||
|  | 						{ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"}, | ||||||
|  | 						{ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"}, | ||||||
|  | 					}, | ||||||
|  | 				}, | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestEmbeddedBelongsTo(t *testing.T) { | ||||||
|  | 	type Country struct { | ||||||
|  | 		ID   int `gorm:"primaryKey"` | ||||||
|  | 		Name string | ||||||
|  | 	} | ||||||
|  | 	type Address struct { | ||||||
|  | 		CountryID int | ||||||
|  | 		Country   Country | ||||||
|  | 	} | ||||||
|  | 	type NestedAddress struct { | ||||||
|  | 		Address | ||||||
|  | 	} | ||||||
|  | 	type Org struct { | ||||||
|  | 		ID              int | ||||||
|  | 		PostalAddress   Address `gorm:"embedded;embeddedPrefix:postal_address_"` | ||||||
|  | 		VisitingAddress Address `gorm:"embedded;embeddedPrefix:visiting_address_"` | ||||||
|  | 		AddressID       int | ||||||
|  | 		Address         struct { | ||||||
|  | 			ID int | ||||||
|  | 			Address | ||||||
|  | 		} | ||||||
|  | 		NestedAddress *NestedAddress `gorm:"embedded;embeddedPrefix:nested_address_"` | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	s, err := schema.Parse(&Org{}, &sync.Map{}, schema.NamingStrategy{}) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Errorf("Failed to parse schema, got error %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{ | ||||||
|  | 		"PostalAddress": { | ||||||
|  | 			Relations: map[string]Relation{ | ||||||
|  | 				"Country": { | ||||||
|  | 					Name: "Country", Type: schema.BelongsTo, Schema: "Org", FieldSchema: "Country", | ||||||
|  | 					References: []Reference{ | ||||||
|  | 						{PrimaryKey: "ID", PrimarySchema: "Country", ForeignKey: "CountryID", ForeignSchema: "Org"}, | ||||||
|  | 					}, | ||||||
|  | 				}, | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 		"VisitingAddress": { | ||||||
|  | 			Relations: map[string]Relation{ | ||||||
|  | 				"Country": { | ||||||
|  | 					Name: "Country", Type: schema.BelongsTo, Schema: "Org", FieldSchema: "Country", | ||||||
|  | 					References: []Reference{ | ||||||
|  | 						{PrimaryKey: "ID", PrimarySchema: "Country", ForeignKey: "CountryID", ForeignSchema: "Org"}, | ||||||
|  | 					}, | ||||||
|  | 				}, | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 		"NestedAddress": { | ||||||
|  | 			EmbeddedRelations: map[string]EmbeddedRelations{ | ||||||
|  | 				"Address": { | ||||||
|  | 					Relations: map[string]Relation{ | ||||||
|  | 						"Country": { | ||||||
|  | 							Name: "Country", Type: schema.BelongsTo, Schema: "Org", FieldSchema: "Country", | ||||||
|  | 							References: []Reference{ | ||||||
|  | 								{PrimaryKey: "ID", PrimarySchema: "Country", ForeignKey: "CountryID", ForeignSchema: "Org"}, | ||||||
|  | 							}, | ||||||
|  | 						}, | ||||||
|  | 					}, | ||||||
|  | 				}, | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func TestVariableRelation(t *testing.T) { | func TestVariableRelation(t *testing.T) { | ||||||
| 	var result struct { | 	var result struct { | ||||||
| 		User | 		User | ||||||
|  | |||||||
| @ -6,6 +6,7 @@ import ( | |||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"go/ast" | 	"go/ast" | ||||||
| 	"reflect" | 	"reflect" | ||||||
|  | 	"strings" | ||||||
| 	"sync" | 	"sync" | ||||||
| 
 | 
 | ||||||
| 	"gorm.io/gorm/clause" | 	"gorm.io/gorm/clause" | ||||||
| @ -25,6 +26,7 @@ type Schema struct { | |||||||
| 	PrimaryFieldDBNames       []string | 	PrimaryFieldDBNames       []string | ||||||
| 	Fields                    []*Field | 	Fields                    []*Field | ||||||
| 	FieldsByName              map[string]*Field | 	FieldsByName              map[string]*Field | ||||||
|  | 	FieldsByBindName          map[string]*Field // embedded fields is 'Embed.Field'
 | ||||||
| 	FieldsByDBName            map[string]*Field | 	FieldsByDBName            map[string]*Field | ||||||
| 	FieldsWithDefaultDBValue  []*Field // fields with default value assigned by database
 | 	FieldsWithDefaultDBValue  []*Field // fields with default value assigned by database
 | ||||||
| 	Relationships             Relationships | 	Relationships             Relationships | ||||||
| @ -67,6 +69,27 @@ func (schema Schema) LookUpField(name string) *Field { | |||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // LookUpFieldByBindName looks for the closest field in the embedded struct.
 | ||||||
|  | //
 | ||||||
|  | //	type Struct struct {
 | ||||||
|  | //		Embedded struct {
 | ||||||
|  | //			ID string // is selected by LookUpFieldByBindName([]string{"Embedded", "ID"}, "ID")
 | ||||||
|  | //		}
 | ||||||
|  | //		ID string // is selected by LookUpFieldByBindName([]string{"ID"}, "ID")
 | ||||||
|  | //	}
 | ||||||
|  | func (schema Schema) LookUpFieldByBindName(bindNames []string, name string) *Field { | ||||||
|  | 	if len(bindNames) == 0 { | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  | 	for i := len(bindNames) - 1; i >= 0; i-- { | ||||||
|  | 		find := strings.Join(bindNames[:i], ".") + "." + name | ||||||
|  | 		if field, ok := schema.FieldsByBindName[find]; ok { | ||||||
|  | 			return field | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
| type Tabler interface { | type Tabler interface { | ||||||
| 	TableName() string | 	TableName() string | ||||||
| } | } | ||||||
| @ -140,15 +163,16 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam | |||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	schema := &Schema{ | 	schema := &Schema{ | ||||||
| 		Name:           modelType.Name(), | 		Name:             modelType.Name(), | ||||||
| 		ModelType:      modelType, | 		ModelType:        modelType, | ||||||
| 		Table:          tableName, | 		Table:            tableName, | ||||||
| 		FieldsByName:   map[string]*Field{}, | 		FieldsByName:     map[string]*Field{}, | ||||||
| 		FieldsByDBName: map[string]*Field{}, | 		FieldsByBindName: map[string]*Field{}, | ||||||
| 		Relationships:  Relationships{Relations: map[string]*Relationship{}}, | 		FieldsByDBName:   map[string]*Field{}, | ||||||
| 		cacheStore:     cacheStore, | 		Relationships:    Relationships{Relations: map[string]*Relationship{}}, | ||||||
| 		namer:          namer, | 		cacheStore:       cacheStore, | ||||||
| 		initialized:    make(chan struct{}), | 		namer:            namer, | ||||||
|  | 		initialized:      make(chan struct{}), | ||||||
| 	} | 	} | ||||||
| 	// When the schema initialization is completed, the channel will be closed
 | 	// When the schema initialization is completed, the channel will be closed
 | ||||||
| 	defer close(schema.initialized) | 	defer close(schema.initialized) | ||||||
| @ -176,6 +200,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam | |||||||
| 			field.DBName = namer.ColumnName(schema.Table, field.Name) | 			field.DBName = namer.ColumnName(schema.Table, field.Name) | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
|  | 		bindName := field.BindName() | ||||||
| 		if field.DBName != "" { | 		if field.DBName != "" { | ||||||
| 			// nonexistence or shortest path or first appear prioritized if has permission
 | 			// nonexistence or shortest path or first appear prioritized if has permission
 | ||||||
| 			if v, ok := schema.FieldsByDBName[field.DBName]; !ok || ((field.Creatable || field.Updatable || field.Readable) && len(field.BindNames) < len(v.BindNames)) { | 			if v, ok := schema.FieldsByDBName[field.DBName]; !ok || ((field.Creatable || field.Updatable || field.Readable) && len(field.BindNames) < len(v.BindNames)) { | ||||||
| @ -184,6 +209,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam | |||||||
| 				} | 				} | ||||||
| 				schema.FieldsByDBName[field.DBName] = field | 				schema.FieldsByDBName[field.DBName] = field | ||||||
| 				schema.FieldsByName[field.Name] = field | 				schema.FieldsByName[field.Name] = field | ||||||
|  | 				schema.FieldsByBindName[bindName] = field | ||||||
| 
 | 
 | ||||||
| 				if v != nil && v.PrimaryKey { | 				if v != nil && v.PrimaryKey { | ||||||
| 					for idx, f := range schema.PrimaryFields { | 					for idx, f := range schema.PrimaryFields { | ||||||
| @ -202,6 +228,9 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam | |||||||
| 		if of, ok := schema.FieldsByName[field.Name]; !ok || of.TagSettings["-"] == "-" { | 		if of, ok := schema.FieldsByName[field.Name]; !ok || of.TagSettings["-"] == "-" { | ||||||
| 			schema.FieldsByName[field.Name] = field | 			schema.FieldsByName[field.Name] = field | ||||||
| 		} | 		} | ||||||
|  | 		if of, ok := schema.FieldsByBindName[bindName]; !ok || of.TagSettings["-"] == "-" { | ||||||
|  | 			schema.FieldsByBindName[bindName] = field | ||||||
|  | 		} | ||||||
| 
 | 
 | ||||||
| 		field.setupValuerAndSetter() | 		field.setupValuerAndSetter() | ||||||
| 	} | 	} | ||||||
| @ -293,6 +322,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam | |||||||
| 					return schema, schema.err | 					return schema, schema.err | ||||||
| 				} else { | 				} else { | ||||||
| 					schema.FieldsByName[field.Name] = field | 					schema.FieldsByName[field.Name] = field | ||||||
|  | 					schema.FieldsByBindName[field.BindName()] = field | ||||||
| 				} | 				} | ||||||
| 			} | 			} | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -201,6 +201,37 @@ func checkSchemaRelation(t *testing.T, s *schema.Schema, relation Relation) { | |||||||
| 	}) | 	}) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | type EmbeddedRelations struct { | ||||||
|  | 	Relations         map[string]Relation | ||||||
|  | 	EmbeddedRelations map[string]EmbeddedRelations | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func checkEmbeddedRelations(t *testing.T, actual map[string]*schema.Relationships, expected map[string]EmbeddedRelations) { | ||||||
|  | 	for name, relations := range actual { | ||||||
|  | 		rs := expected[name] | ||||||
|  | 		t.Run("CheckEmbeddedRelations/"+name, func(t *testing.T) { | ||||||
|  | 			if len(relations.Relations) != len(rs.Relations) { | ||||||
|  | 				t.Errorf("schema relations count don't match, expects %d, got %d", len(rs.Relations), len(relations.Relations)) | ||||||
|  | 			} | ||||||
|  | 			if len(relations.EmbeddedRelations) != len(rs.EmbeddedRelations) { | ||||||
|  | 				t.Errorf("schema embedded relations count don't match, expects %d, got %d", len(rs.EmbeddedRelations), len(relations.EmbeddedRelations)) | ||||||
|  | 			} | ||||||
|  | 			for n, rel := range relations.Relations { | ||||||
|  | 				if r, ok := rs.Relations[n]; !ok { | ||||||
|  | 					t.Errorf("failed to find relation by name %s", n) | ||||||
|  | 				} else { | ||||||
|  | 					checkSchemaRelation(t, &schema.Schema{ | ||||||
|  | 						Relationships: schema.Relationships{ | ||||||
|  | 							Relations: map[string]*schema.Relationship{n: rel}, | ||||||
|  | 						}, | ||||||
|  | 					}, r) | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 			checkEmbeddedRelations(t, relations.EmbeddedRelations, rs.EmbeddedRelations) | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func checkField(t *testing.T, s *schema.Schema, value reflect.Value, values map[string]interface{}) { | func checkField(t *testing.T, s *schema.Schema, value reflect.Value, values map[string]interface{}) { | ||||||
| 	for k, v := range values { | 	for k, v := range values { | ||||||
| 		t.Run("CheckField/"+k, func(t *testing.T) { | 		t.Run("CheckField/"+k, func(t *testing.T) { | ||||||
|  | |||||||
| @ -306,3 +306,141 @@ func TestNestedPreloadWithUnscoped(t *testing.T) { | |||||||
| 	DB.Unscoped().Preload("Pets.Toy").Find(&user6, "id = ?", user.ID) | 	DB.Unscoped().Preload("Pets.Toy").Find(&user6, "id = ?", user.ID) | ||||||
| 	CheckUserUnscoped(t, *user6, user) | 	CheckUserUnscoped(t, *user6, user) | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | func TestEmbedPreload(t *testing.T) { | ||||||
|  | 	type Country struct { | ||||||
|  | 		ID   int `gorm:"primaryKey"` | ||||||
|  | 		Name string | ||||||
|  | 	} | ||||||
|  | 	type EmbeddedAddress struct { | ||||||
|  | 		ID        int | ||||||
|  | 		Name      string | ||||||
|  | 		CountryID *int | ||||||
|  | 		Country   *Country | ||||||
|  | 	} | ||||||
|  | 	type NestedAddress struct { | ||||||
|  | 		EmbeddedAddress | ||||||
|  | 	} | ||||||
|  | 	type Org struct { | ||||||
|  | 		ID              int | ||||||
|  | 		PostalAddress   EmbeddedAddress `gorm:"embedded;embeddedPrefix:postal_address_"` | ||||||
|  | 		VisitingAddress EmbeddedAddress `gorm:"embedded;embeddedPrefix:visiting_address_"` | ||||||
|  | 		AddressID       *int | ||||||
|  | 		Address         *EmbeddedAddress | ||||||
|  | 		NestedAddress   NestedAddress `gorm:"embedded;embeddedPrefix:nested_address_"` | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	DB.Migrator().DropTable(&Org{}, &EmbeddedAddress{}, &Country{}) | ||||||
|  | 	DB.AutoMigrate(&Org{}, &EmbeddedAddress{}, &Country{}) | ||||||
|  | 
 | ||||||
|  | 	org := Org{ | ||||||
|  | 		PostalAddress:   EmbeddedAddress{Name: "a1", Country: &Country{Name: "c1"}}, | ||||||
|  | 		VisitingAddress: EmbeddedAddress{Name: "a2", Country: &Country{Name: "c2"}}, | ||||||
|  | 		Address:         &EmbeddedAddress{Name: "a3", Country: &Country{Name: "c3"}}, | ||||||
|  | 		NestedAddress: NestedAddress{ | ||||||
|  | 			EmbeddedAddress: EmbeddedAddress{Name: "a4", Country: &Country{Name: "c4"}}, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | 	if err := DB.Create(&org).Error; err != nil { | ||||||
|  | 		t.Errorf("failed to create org, got err: %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	tests := []struct { | ||||||
|  | 		name     string | ||||||
|  | 		preloads map[string][]interface{} | ||||||
|  | 		expect   Org | ||||||
|  | 	}{ | ||||||
|  | 		{ | ||||||
|  | 			name:     "address country", | ||||||
|  | 			preloads: map[string][]interface{}{"Address.Country": {}}, | ||||||
|  | 			expect: Org{ | ||||||
|  | 				ID: org.ID, | ||||||
|  | 				PostalAddress: EmbeddedAddress{ | ||||||
|  | 					ID:        org.PostalAddress.ID, | ||||||
|  | 					Name:      org.PostalAddress.Name, | ||||||
|  | 					CountryID: org.PostalAddress.CountryID, | ||||||
|  | 					Country:   nil, | ||||||
|  | 				}, | ||||||
|  | 				VisitingAddress: EmbeddedAddress{ | ||||||
|  | 					ID:        org.VisitingAddress.ID, | ||||||
|  | 					Name:      org.VisitingAddress.Name, | ||||||
|  | 					CountryID: org.VisitingAddress.CountryID, | ||||||
|  | 					Country:   nil, | ||||||
|  | 				}, | ||||||
|  | 				AddressID: org.AddressID, | ||||||
|  | 				Address:   org.Address, | ||||||
|  | 				NestedAddress: NestedAddress{EmbeddedAddress{ | ||||||
|  | 					ID:        org.NestedAddress.ID, | ||||||
|  | 					Name:      org.NestedAddress.Name, | ||||||
|  | 					CountryID: org.NestedAddress.CountryID, | ||||||
|  | 					Country:   nil, | ||||||
|  | 				}}, | ||||||
|  | 			}, | ||||||
|  | 		}, { | ||||||
|  | 			name:     "postal address country", | ||||||
|  | 			preloads: map[string][]interface{}{"PostalAddress.Country": {}}, | ||||||
|  | 			expect: Org{ | ||||||
|  | 				ID:            org.ID, | ||||||
|  | 				PostalAddress: org.PostalAddress, | ||||||
|  | 				VisitingAddress: EmbeddedAddress{ | ||||||
|  | 					ID:        org.VisitingAddress.ID, | ||||||
|  | 					Name:      org.VisitingAddress.Name, | ||||||
|  | 					CountryID: org.VisitingAddress.CountryID, | ||||||
|  | 					Country:   nil, | ||||||
|  | 				}, | ||||||
|  | 				AddressID: org.AddressID, | ||||||
|  | 				Address:   nil, | ||||||
|  | 				NestedAddress: NestedAddress{EmbeddedAddress{ | ||||||
|  | 					ID:        org.NestedAddress.ID, | ||||||
|  | 					Name:      org.NestedAddress.Name, | ||||||
|  | 					CountryID: org.NestedAddress.CountryID, | ||||||
|  | 					Country:   nil, | ||||||
|  | 				}}, | ||||||
|  | 			}, | ||||||
|  | 		}, { | ||||||
|  | 			name:     "nested address country", | ||||||
|  | 			preloads: map[string][]interface{}{"NestedAddress.EmbeddedAddress.Country": {}}, | ||||||
|  | 			expect: Org{ | ||||||
|  | 				ID: org.ID, | ||||||
|  | 				PostalAddress: EmbeddedAddress{ | ||||||
|  | 					ID:        org.PostalAddress.ID, | ||||||
|  | 					Name:      org.PostalAddress.Name, | ||||||
|  | 					CountryID: org.PostalAddress.CountryID, | ||||||
|  | 					Country:   nil, | ||||||
|  | 				}, | ||||||
|  | 				VisitingAddress: EmbeddedAddress{ | ||||||
|  | 					ID:        org.VisitingAddress.ID, | ||||||
|  | 					Name:      org.VisitingAddress.Name, | ||||||
|  | 					CountryID: org.VisitingAddress.CountryID, | ||||||
|  | 					Country:   nil, | ||||||
|  | 				}, | ||||||
|  | 				AddressID:     org.AddressID, | ||||||
|  | 				Address:       nil, | ||||||
|  | 				NestedAddress: org.NestedAddress, | ||||||
|  | 			}, | ||||||
|  | 		}, { | ||||||
|  | 			name: "associations", | ||||||
|  | 			preloads: map[string][]interface{}{ | ||||||
|  | 				clause.Associations: {}, | ||||||
|  | 				// clause.Associations won’t preload nested associations
 | ||||||
|  | 				"Address.Country": {}, | ||||||
|  | 			}, | ||||||
|  | 			expect: org, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	DB = DB.Debug() | ||||||
|  | 	for _, test := range tests { | ||||||
|  | 		t.Run(test.name, func(t *testing.T) { | ||||||
|  | 			actual := Org{} | ||||||
|  | 			tx := DB.Where("id = ?", org.ID).Session(&gorm.Session{}) | ||||||
|  | 			for name, args := range test.preloads { | ||||||
|  | 				tx = tx.Preload(name, args...) | ||||||
|  | 			} | ||||||
|  | 			if err := tx.Find(&actual).Error; err != nil { | ||||||
|  | 				t.Errorf("failed to find org, got err: %v", err) | ||||||
|  | 			} | ||||||
|  | 			AssertEqual(t, actual, test.expect) | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 black-06
						black-06