Fix failed to guess relations for embedded types, close #3224
This commit is contained in:
		
							parent
							
								
									c11c939b95
								
							
						
					
					
						commit
						ff985b90cc
					
				| @ -120,6 +120,7 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { | |||||||
| 				} | 				} | ||||||
| 				return nil | 				return nil | ||||||
| 			}); err != nil { | 			}); err != nil { | ||||||
|  | 				fmt.Println(err) | ||||||
| 				return err | 				return err | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
|  | |||||||
| @ -62,6 +62,7 @@ type Field struct { | |||||||
| 	TagSettings           map[string]string | 	TagSettings           map[string]string | ||||||
| 	Schema                *Schema | 	Schema                *Schema | ||||||
| 	EmbeddedSchema        *Schema | 	EmbeddedSchema        *Schema | ||||||
|  | 	OwnerSchema           *Schema | ||||||
| 	ReflectValueOf        func(reflect.Value) reflect.Value | 	ReflectValueOf        func(reflect.Value) reflect.Value | ||||||
| 	ValueOf               func(reflect.Value) (value interface{}, zero bool) | 	ValueOf               func(reflect.Value) (value interface{}, zero bool) | ||||||
| 	Set                   func(reflect.Value, interface{}) error | 	Set                   func(reflect.Value, interface{}) error | ||||||
| @ -321,6 +322,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { | |||||||
| 			} | 			} | ||||||
| 			for _, ef := range field.EmbeddedSchema.Fields { | 			for _, ef := range field.EmbeddedSchema.Fields { | ||||||
| 				ef.Schema = schema | 				ef.Schema = schema | ||||||
|  | 				ef.OwnerSchema = field.EmbeddedSchema | ||||||
| 				ef.BindNames = append([]string{fieldStruct.Name}, ef.BindNames...) | 				ef.BindNames = append([]string{fieldStruct.Name}, ef.BindNames...) | ||||||
| 				// index is negative means is pointer
 | 				// index is negative means is pointer
 | ||||||
| 				if field.FieldType.Kind() == reflect.Struct { | 				if field.FieldType.Kind() == reflect.Struct { | ||||||
|  | |||||||
| @ -5,6 +5,7 @@ import ( | |||||||
| 	"reflect" | 	"reflect" | ||||||
| 	"regexp" | 	"regexp" | ||||||
| 	"strings" | 	"strings" | ||||||
|  | 	"sync" | ||||||
| 
 | 
 | ||||||
| 	"github.com/jinzhu/inflection" | 	"github.com/jinzhu/inflection" | ||||||
| 	"gorm.io/gorm/clause" | 	"gorm.io/gorm/clause" | ||||||
| @ -66,9 +67,16 @@ func (schema *Schema) parseRelation(field *Field) { | |||||||
| 		} | 		} | ||||||
| 	) | 	) | ||||||
| 
 | 
 | ||||||
| 	if relation.FieldSchema, err = Parse(fieldValue, schema.cacheStore, schema.namer); err != nil { | 	if field.OwnerSchema != nil { | ||||||
| 		schema.err = err | 		if relation.FieldSchema, err = Parse(fieldValue, &sync.Map{}, schema.namer); err != nil { | ||||||
| 		return | 			schema.err = err | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 	} else { | ||||||
|  | 		if relation.FieldSchema, err = Parse(fieldValue, schema.cacheStore, schema.namer); err != nil { | ||||||
|  | 			schema.err = err | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" { | 	if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" { | ||||||
| @ -78,7 +86,7 @@ func (schema *Schema) parseRelation(field *Field) { | |||||||
| 	} else { | 	} else { | ||||||
| 		switch field.IndirectFieldType.Kind() { | 		switch field.IndirectFieldType.Kind() { | ||||||
| 		case reflect.Struct, reflect.Slice: | 		case reflect.Struct, reflect.Slice: | ||||||
| 			schema.guessRelation(relation, field, true) | 			schema.guessRelation(relation, field, guessHas) | ||||||
| 		default: | 		default: | ||||||
| 			schema.err = fmt.Errorf("unsupported data type %v for %v on field %v", relation.FieldSchema, schema, field.Name) | 			schema.err = fmt.Errorf("unsupported data type %v for %v on field %v", relation.FieldSchema, schema, field.Name) | ||||||
| 		} | 		} | ||||||
| @ -316,21 +324,50 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (schema *Schema) guessRelation(relation *Relationship, field *Field, guessHas bool) { | type guessLevel int | ||||||
|  | 
 | ||||||
|  | const ( | ||||||
|  | 	guessHas guessLevel = iota | ||||||
|  | 	guessEmbeddedHas | ||||||
|  | 	guessBelongs | ||||||
|  | 	guessEmbeddedBelongs | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl guessLevel) { | ||||||
| 	var ( | 	var ( | ||||||
| 		primaryFields, foreignFields []*Field | 		primaryFields, foreignFields []*Field | ||||||
| 		primarySchema, foreignSchema = schema, relation.FieldSchema | 		primarySchema, foreignSchema = schema, relation.FieldSchema | ||||||
| 	) | 	) | ||||||
| 
 | 
 | ||||||
| 	if !guessHas { | 	reguessOrErr := func(err string, args ...interface{}) { | ||||||
| 		primarySchema, foreignSchema = relation.FieldSchema, schema | 		switch gl { | ||||||
|  | 		case guessHas: | ||||||
|  | 			schema.guessRelation(relation, field, guessEmbeddedHas) | ||||||
|  | 		case guessEmbeddedHas: | ||||||
|  | 			schema.guessRelation(relation, field, guessBelongs) | ||||||
|  | 		case guessBelongs: | ||||||
|  | 			schema.guessRelation(relation, field, guessEmbeddedBelongs) | ||||||
|  | 		default: | ||||||
|  | 			schema.err = fmt.Errorf(err, args...) | ||||||
|  | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	reguessOrErr := func(err string, args ...interface{}) { | 	switch gl { | ||||||
| 		if guessHas { | 	case guessEmbeddedHas: | ||||||
| 			schema.guessRelation(relation, field, false) | 		if field.OwnerSchema != nil { | ||||||
|  | 			primarySchema, foreignSchema = field.OwnerSchema, relation.FieldSchema | ||||||
| 		} else { | 		} else { | ||||||
| 			schema.err = fmt.Errorf(err, args...) | 			reguessOrErr("failed to guess %v's relations with %v's field %v, guess level: %v", relation.FieldSchema, schema, field.Name, gl) | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 	case guessBelongs: | ||||||
|  | 		primarySchema, foreignSchema = relation.FieldSchema, schema | ||||||
|  | 	case guessEmbeddedBelongs: | ||||||
|  | 		if field.OwnerSchema != nil { | ||||||
|  | 			primarySchema, foreignSchema = relation.FieldSchema, field.OwnerSchema | ||||||
|  | 		} else { | ||||||
|  | 			reguessOrErr("failed to guess %v's relations with %v's field %v, guess level: %v", relation.FieldSchema, schema, field.Name, gl) | ||||||
|  | 			return | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| @ -345,8 +382,8 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, guessH | |||||||
| 		} | 		} | ||||||
| 	} else { | 	} else { | ||||||
| 		for _, primaryField := range primarySchema.PrimaryFields { | 		for _, primaryField := range primarySchema.PrimaryFields { | ||||||
| 			lookUpName := schema.Name + primaryField.Name | 			lookUpName := primarySchema.Name + primaryField.Name | ||||||
| 			if !guessHas { | 			if gl == guessBelongs { | ||||||
| 				lookUpName = field.Name + primaryField.Name | 				lookUpName = field.Name + primaryField.Name | ||||||
| 			} | 			} | ||||||
| 
 | 
 | ||||||
| @ -358,7 +395,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, guessH | |||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if len(foreignFields) == 0 { | 	if len(foreignFields) == 0 { | ||||||
| 		reguessOrErr("failed to guess %v's relations with %v's field %v 1 g %v", relation.FieldSchema, schema, field.Name, guessHas) | 		reguessOrErr("failed to guess %v's relations with %v's field %v, guess level: %v", relation.FieldSchema, schema, field.Name, gl) | ||||||
| 		return | 		return | ||||||
| 	} else if len(relation.primaryKeys) > 0 { | 	} else if len(relation.primaryKeys) > 0 { | ||||||
| 		for idx, primaryKey := range relation.primaryKeys { | 		for idx, primaryKey := range relation.primaryKeys { | ||||||
| @ -394,11 +431,11 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, guessH | |||||||
| 		relation.References = append(relation.References, &Reference{ | 		relation.References = append(relation.References, &Reference{ | ||||||
| 			PrimaryKey:    primaryFields[idx], | 			PrimaryKey:    primaryFields[idx], | ||||||
| 			ForeignKey:    foreignField, | 			ForeignKey:    foreignField, | ||||||
| 			OwnPrimaryKey: schema == primarySchema && guessHas, | 			OwnPrimaryKey: (schema == primarySchema && gl == guessHas) || (field.OwnerSchema == primarySchema && gl == guessEmbeddedHas), | ||||||
| 		}) | 		}) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if guessHas { | 	if gl == guessHas || gl == guessEmbeddedHas { | ||||||
| 		relation.Type = "has" | 		relation.Type = "has" | ||||||
| 	} else { | 	} else { | ||||||
| 		relation.Type = BelongsTo | 		relation.Type = BelongsTo | ||||||
|  | |||||||
| @ -101,8 +101,12 @@ func TestCallbacks(t *testing.T) { | |||||||
| 			results:   []string{"c5", "c1", "c2", "c3", "c4"}, | 			results:   []string{"c5", "c1", "c2", "c3", "c4"}, | ||||||
| 		}, | 		}, | ||||||
| 		{ | 		{ | ||||||
| 			callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3, after: "*"}, {h: c4}, {h: c5, before: "*"}}, | 			callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3, before: "*"}, {h: c4}, {h: c5, before: "*"}}, | ||||||
| 			results:   []string{"c5", "c1", "c2", "c4", "c3"}, | 			results:   []string{"c3", "c5", "c1", "c2", "c4"}, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3, before: "c4", after: "*"}, {h: c4, after: "*"}, {h: c5, before: "*"}}, | ||||||
|  | 			results:   []string{"c5", "c1", "c2", "c3", "c4"}, | ||||||
| 		}, | 		}, | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -7,6 +7,7 @@ import ( | |||||||
| 	"testing" | 	"testing" | ||||||
| 
 | 
 | ||||||
| 	"gorm.io/gorm" | 	"gorm.io/gorm" | ||||||
|  | 	. "gorm.io/gorm/utils/tests" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func TestEmbeddedStruct(t *testing.T) { | func TestEmbeddedStruct(t *testing.T) { | ||||||
| @ -152,3 +153,16 @@ func TestEmbeddedScanValuer(t *testing.T) { | |||||||
| 		t.Errorf("Failed to create got error %v", err) | 		t.Errorf("Failed to create got error %v", err) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | func TestEmbeddedRelations(t *testing.T) { | ||||||
|  | 	type AdvancedUser struct { | ||||||
|  | 		User     `gorm:"embedded"` | ||||||
|  | 		Advanced bool | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	DB.Debug().Migrator().DropTable(&AdvancedUser{}) | ||||||
|  | 
 | ||||||
|  | 	if err := DB.Debug().AutoMigrate(&AdvancedUser{}); err != nil { | ||||||
|  | 		t.Errorf("Failed to auto migrate advanced user, got error %v", err) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu