Fix failed to guess relations for embedded types, close #3224
This commit is contained in:
		
							parent
							
								
									c11c939b95
								
							
						
					
					
						commit
						ff985b90cc
					
				| @ -120,6 +120,7 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { | ||||
| 				} | ||||
| 				return nil | ||||
| 			}); err != nil { | ||||
| 				fmt.Println(err) | ||||
| 				return err | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| @ -62,6 +62,7 @@ type Field struct { | ||||
| 	TagSettings           map[string]string | ||||
| 	Schema                *Schema | ||||
| 	EmbeddedSchema        *Schema | ||||
| 	OwnerSchema           *Schema | ||||
| 	ReflectValueOf        func(reflect.Value) reflect.Value | ||||
| 	ValueOf               func(reflect.Value) (value interface{}, zero bool) | ||||
| 	Set                   func(reflect.Value, interface{}) error | ||||
| @ -321,6 +322,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { | ||||
| 			} | ||||
| 			for _, ef := range field.EmbeddedSchema.Fields { | ||||
| 				ef.Schema = schema | ||||
| 				ef.OwnerSchema = field.EmbeddedSchema | ||||
| 				ef.BindNames = append([]string{fieldStruct.Name}, ef.BindNames...) | ||||
| 				// index is negative means is pointer
 | ||||
| 				if field.FieldType.Kind() == reflect.Struct { | ||||
|  | ||||
| @ -5,6 +5,7 @@ import ( | ||||
| 	"reflect" | ||||
| 	"regexp" | ||||
| 	"strings" | ||||
| 	"sync" | ||||
| 
 | ||||
| 	"github.com/jinzhu/inflection" | ||||
| 	"gorm.io/gorm/clause" | ||||
| @ -66,10 +67,17 @@ func (schema *Schema) parseRelation(field *Field) { | ||||
| 		} | ||||
| 	) | ||||
| 
 | ||||
| 	if field.OwnerSchema != nil { | ||||
| 		if relation.FieldSchema, err = Parse(fieldValue, &sync.Map{}, schema.namer); err != nil { | ||||
| 			schema.err = err | ||||
| 			return | ||||
| 		} | ||||
| 	} else { | ||||
| 		if relation.FieldSchema, err = Parse(fieldValue, schema.cacheStore, schema.namer); err != nil { | ||||
| 			schema.err = err | ||||
| 			return | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" { | ||||
| 		schema.buildPolymorphicRelation(relation, field, polymorphic) | ||||
| @ -78,7 +86,7 @@ func (schema *Schema) parseRelation(field *Field) { | ||||
| 	} else { | ||||
| 		switch field.IndirectFieldType.Kind() { | ||||
| 		case reflect.Struct, reflect.Slice: | ||||
| 			schema.guessRelation(relation, field, true) | ||||
| 			schema.guessRelation(relation, field, guessHas) | ||||
| 		default: | ||||
| 			schema.err = fmt.Errorf("unsupported data type %v for %v on field %v", relation.FieldSchema, schema, field.Name) | ||||
| 		} | ||||
| @ -316,21 +324,50 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (schema *Schema) guessRelation(relation *Relationship, field *Field, guessHas bool) { | ||||
| type guessLevel int | ||||
| 
 | ||||
| const ( | ||||
| 	guessHas guessLevel = iota | ||||
| 	guessEmbeddedHas | ||||
| 	guessBelongs | ||||
| 	guessEmbeddedBelongs | ||||
| ) | ||||
| 
 | ||||
| func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl guessLevel) { | ||||
| 	var ( | ||||
| 		primaryFields, foreignFields []*Field | ||||
| 		primarySchema, foreignSchema = schema, relation.FieldSchema | ||||
| 	) | ||||
| 
 | ||||
| 	if !guessHas { | ||||
| 		primarySchema, foreignSchema = relation.FieldSchema, schema | ||||
| 	reguessOrErr := func(err string, args ...interface{}) { | ||||
| 		switch gl { | ||||
| 		case guessHas: | ||||
| 			schema.guessRelation(relation, field, guessEmbeddedHas) | ||||
| 		case guessEmbeddedHas: | ||||
| 			schema.guessRelation(relation, field, guessBelongs) | ||||
| 		case guessBelongs: | ||||
| 			schema.guessRelation(relation, field, guessEmbeddedBelongs) | ||||
| 		default: | ||||
| 			schema.err = fmt.Errorf(err, args...) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	reguessOrErr := func(err string, args ...interface{}) { | ||||
| 		if guessHas { | ||||
| 			schema.guessRelation(relation, field, false) | ||||
| 	switch gl { | ||||
| 	case guessEmbeddedHas: | ||||
| 		if field.OwnerSchema != nil { | ||||
| 			primarySchema, foreignSchema = field.OwnerSchema, relation.FieldSchema | ||||
| 		} else { | ||||
| 			schema.err = fmt.Errorf(err, args...) | ||||
| 			reguessOrErr("failed to guess %v's relations with %v's field %v, guess level: %v", relation.FieldSchema, schema, field.Name, gl) | ||||
| 			return | ||||
| 		} | ||||
| 	case guessBelongs: | ||||
| 		primarySchema, foreignSchema = relation.FieldSchema, schema | ||||
| 	case guessEmbeddedBelongs: | ||||
| 		if field.OwnerSchema != nil { | ||||
| 			primarySchema, foreignSchema = relation.FieldSchema, field.OwnerSchema | ||||
| 		} else { | ||||
| 			reguessOrErr("failed to guess %v's relations with %v's field %v, guess level: %v", relation.FieldSchema, schema, field.Name, gl) | ||||
| 			return | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| @ -345,8 +382,8 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, guessH | ||||
| 		} | ||||
| 	} else { | ||||
| 		for _, primaryField := range primarySchema.PrimaryFields { | ||||
| 			lookUpName := schema.Name + primaryField.Name | ||||
| 			if !guessHas { | ||||
| 			lookUpName := primarySchema.Name + primaryField.Name | ||||
| 			if gl == guessBelongs { | ||||
| 				lookUpName = field.Name + primaryField.Name | ||||
| 			} | ||||
| 
 | ||||
| @ -358,7 +395,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, guessH | ||||
| 	} | ||||
| 
 | ||||
| 	if len(foreignFields) == 0 { | ||||
| 		reguessOrErr("failed to guess %v's relations with %v's field %v 1 g %v", relation.FieldSchema, schema, field.Name, guessHas) | ||||
| 		reguessOrErr("failed to guess %v's relations with %v's field %v, guess level: %v", relation.FieldSchema, schema, field.Name, gl) | ||||
| 		return | ||||
| 	} else if len(relation.primaryKeys) > 0 { | ||||
| 		for idx, primaryKey := range relation.primaryKeys { | ||||
| @ -394,11 +431,11 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, guessH | ||||
| 		relation.References = append(relation.References, &Reference{ | ||||
| 			PrimaryKey:    primaryFields[idx], | ||||
| 			ForeignKey:    foreignField, | ||||
| 			OwnPrimaryKey: schema == primarySchema && guessHas, | ||||
| 			OwnPrimaryKey: (schema == primarySchema && gl == guessHas) || (field.OwnerSchema == primarySchema && gl == guessEmbeddedHas), | ||||
| 		}) | ||||
| 	} | ||||
| 
 | ||||
| 	if guessHas { | ||||
| 	if gl == guessHas || gl == guessEmbeddedHas { | ||||
| 		relation.Type = "has" | ||||
| 	} else { | ||||
| 		relation.Type = BelongsTo | ||||
|  | ||||
| @ -101,8 +101,12 @@ func TestCallbacks(t *testing.T) { | ||||
| 			results:   []string{"c5", "c1", "c2", "c3", "c4"}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3, after: "*"}, {h: c4}, {h: c5, before: "*"}}, | ||||
| 			results:   []string{"c5", "c1", "c2", "c4", "c3"}, | ||||
| 			callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3, before: "*"}, {h: c4}, {h: c5, before: "*"}}, | ||||
| 			results:   []string{"c3", "c5", "c1", "c2", "c4"}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3, before: "c4", after: "*"}, {h: c4, after: "*"}, {h: c5, before: "*"}}, | ||||
| 			results:   []string{"c5", "c1", "c2", "c3", "c4"}, | ||||
| 		}, | ||||
| 	} | ||||
| 
 | ||||
|  | ||||
| @ -7,6 +7,7 @@ import ( | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"gorm.io/gorm" | ||||
| 	. "gorm.io/gorm/utils/tests" | ||||
| ) | ||||
| 
 | ||||
| func TestEmbeddedStruct(t *testing.T) { | ||||
| @ -152,3 +153,16 @@ func TestEmbeddedScanValuer(t *testing.T) { | ||||
| 		t.Errorf("Failed to create got error %v", err) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestEmbeddedRelations(t *testing.T) { | ||||
| 	type AdvancedUser struct { | ||||
| 		User     `gorm:"embedded"` | ||||
| 		Advanced bool | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Debug().Migrator().DropTable(&AdvancedUser{}) | ||||
| 
 | ||||
| 	if err := DB.Debug().AutoMigrate(&AdvancedUser{}); err != nil { | ||||
| 		t.Errorf("Failed to auto migrate advanced user, got error %v", err) | ||||
| 	} | ||||
| } | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu