fix: preload panic when model and dest different close #5130
commit e8307b5ef5273519a32cd8e4fd29250d1c277f6e
Author: Jinzhu <wosmvp@gmail.com>
Date:   Fri Mar 18 13:37:22 2022 +0800
    Refactor #5130
commit 40cbba49f374c9bae54f80daee16697ae45e905b
Author: chenrui <chenrui@jingdaka.com>
Date:   Sat Mar 5 17:36:56 2022 +0800
    test: fix test fail
commit 66d3f078291102a30532b6a9d97c757228a9b543
Author: chenrui <chenrui@jingdaka.com>
Date:   Sat Mar 5 17:29:09 2022 +0800
    test: drop table and auto migrate
commit 7cbf019a930019476a97ac7ac0f5fc186e8d5b42
Author: chenrui <chenrui@jingdaka.com>
Date:   Sat Mar 5 15:27:45 2022 +0800
    fix: preload panic when model and dest different
			
			
This commit is contained in:
		
							parent
							
								
									c2e36ebe62
								
							
						
					
					
						commit
						5431da8caf
					
				| @ -10,10 +10,9 @@ import ( | ||||
| 	"gorm.io/gorm/utils" | ||||
| ) | ||||
| 
 | ||||
| func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) { | ||||
| func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) error { | ||||
| 	var ( | ||||
| 		reflectValue     = db.Statement.ReflectValue | ||||
| 		tx               = db.Session(&gorm.Session{NewDB: true}).Model(nil).Session(&gorm.Session{SkipHooks: db.Statement.SkipHooks}) | ||||
| 		reflectValue     = tx.Statement.ReflectValue | ||||
| 		relForeignKeys   []string | ||||
| 		relForeignFields []*schema.Field | ||||
| 		foreignFields    []*schema.Field | ||||
| @ -22,11 +21,6 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload | ||||
| 		inlineConds      []interface{} | ||||
| 	) | ||||
| 
 | ||||
| 	db.Statement.Settings.Range(func(k, v interface{}) bool { | ||||
| 		tx.Statement.Settings.Store(k, v) | ||||
| 		return true | ||||
| 	}) | ||||
| 
 | ||||
| 	if rel.JoinTable != nil { | ||||
| 		var ( | ||||
| 			joinForeignFields    = make([]*schema.Field, 0, len(rel.References)) | ||||
| @ -48,14 +42,16 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		joinIdentityMap, joinForeignValues := schema.GetIdentityFieldValuesMap(db.Statement.Context, reflectValue, foreignFields) | ||||
| 		joinIdentityMap, joinForeignValues := schema.GetIdentityFieldValuesMap(tx.Statement.Context, reflectValue, foreignFields) | ||||
| 		if len(joinForeignValues) == 0 { | ||||
| 			return | ||||
| 			return nil | ||||
| 		} | ||||
| 
 | ||||
| 		joinResults := rel.JoinTable.MakeSlice().Elem() | ||||
| 		column, values := schema.ToQueryValues(clause.CurrentTable, joinForeignKeys, joinForeignValues) | ||||
| 		db.AddError(tx.Where(clause.IN{Column: column, Values: values}).Find(joinResults.Addr().Interface()).Error) | ||||
| 		if err := tx.Where(clause.IN{Column: column, Values: values}).Find(joinResults.Addr().Interface()).Error; err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 
 | ||||
| 		// convert join identity map to relation identity map
 | ||||
| 		fieldValues := make([]interface{}, len(joinForeignFields)) | ||||
| @ -63,11 +59,11 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload | ||||
| 		for i := 0; i < joinResults.Len(); i++ { | ||||
| 			joinIndexValue := joinResults.Index(i) | ||||
| 			for idx, field := range joinForeignFields { | ||||
| 				fieldValues[idx], _ = field.ValueOf(db.Statement.Context, joinIndexValue) | ||||
| 				fieldValues[idx], _ = field.ValueOf(tx.Statement.Context, joinIndexValue) | ||||
| 			} | ||||
| 
 | ||||
| 			for idx, field := range joinRelForeignFields { | ||||
| 				joinFieldValues[idx], _ = field.ValueOf(db.Statement.Context, joinIndexValue) | ||||
| 				joinFieldValues[idx], _ = field.ValueOf(tx.Statement.Context, joinIndexValue) | ||||
| 			} | ||||
| 
 | ||||
| 			if results, ok := joinIdentityMap[utils.ToStringKey(fieldValues...)]; ok { | ||||
| @ -76,7 +72,7 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		_, foreignValues = schema.GetIdentityFieldValuesMap(db.Statement.Context, joinResults, joinRelForeignFields) | ||||
| 		_, foreignValues = schema.GetIdentityFieldValuesMap(tx.Statement.Context, joinResults, joinRelForeignFields) | ||||
| 	} else { | ||||
| 		for _, ref := range rel.References { | ||||
| 			if ref.OwnPrimaryKey { | ||||
| @ -92,9 +88,9 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		identityMap, foreignValues = schema.GetIdentityFieldValuesMap(db.Statement.Context, reflectValue, foreignFields) | ||||
| 		identityMap, foreignValues = schema.GetIdentityFieldValuesMap(tx.Statement.Context, reflectValue, foreignFields) | ||||
| 		if len(foreignValues) == 0 { | ||||
| 			return | ||||
| 			return nil | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| @ -115,7 +111,9 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		db.AddError(tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), inlineConds...).Error) | ||||
| 		if err := tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), inlineConds...).Error; err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	fieldValues := make([]interface{}, len(relForeignFields)) | ||||
| @ -125,17 +123,17 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload | ||||
| 	case reflect.Struct: | ||||
| 		switch rel.Type { | ||||
| 		case schema.HasMany, schema.Many2Many: | ||||
| 			rel.Field.Set(db.Statement.Context, reflectValue, reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface()) | ||||
| 			rel.Field.Set(tx.Statement.Context, reflectValue, reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface()) | ||||
| 		default: | ||||
| 			rel.Field.Set(db.Statement.Context, reflectValue, reflect.New(rel.Field.FieldType).Interface()) | ||||
| 			rel.Field.Set(tx.Statement.Context, reflectValue, reflect.New(rel.Field.FieldType).Interface()) | ||||
| 		} | ||||
| 	case reflect.Slice, reflect.Array: | ||||
| 		for i := 0; i < reflectValue.Len(); i++ { | ||||
| 			switch rel.Type { | ||||
| 			case schema.HasMany, schema.Many2Many: | ||||
| 				rel.Field.Set(db.Statement.Context, reflectValue.Index(i), reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface()) | ||||
| 				rel.Field.Set(tx.Statement.Context, reflectValue.Index(i), reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface()) | ||||
| 			default: | ||||
| 				rel.Field.Set(db.Statement.Context, reflectValue.Index(i), reflect.New(rel.Field.FieldType).Interface()) | ||||
| 				rel.Field.Set(tx.Statement.Context, reflectValue.Index(i), reflect.New(rel.Field.FieldType).Interface()) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| @ -143,18 +141,16 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload | ||||
| 	for i := 0; i < reflectResults.Len(); i++ { | ||||
| 		elem := reflectResults.Index(i) | ||||
| 		for idx, field := range relForeignFields { | ||||
| 			fieldValues[idx], _ = field.ValueOf(db.Statement.Context, elem) | ||||
| 			fieldValues[idx], _ = field.ValueOf(tx.Statement.Context, elem) | ||||
| 		} | ||||
| 
 | ||||
| 		datas, ok := identityMap[utils.ToStringKey(fieldValues...)] | ||||
| 		if !ok { | ||||
| 			db.AddError(fmt.Errorf("failed to assign association %#v, make sure foreign fields exists", | ||||
| 				elem.Interface())) | ||||
| 			continue | ||||
| 			return fmt.Errorf("failed to assign association %#v, make sure foreign fields exists", elem.Interface()) | ||||
| 		} | ||||
| 
 | ||||
| 		for _, data := range datas { | ||||
| 			reflectFieldValue := rel.Field.ReflectValueOf(db.Statement.Context, data) | ||||
| 			reflectFieldValue := rel.Field.ReflectValueOf(tx.Statement.Context, data) | ||||
| 			if reflectFieldValue.Kind() == reflect.Ptr && reflectFieldValue.IsNil() { | ||||
| 				reflectFieldValue.Set(reflect.New(rel.Field.FieldType.Elem())) | ||||
| 			} | ||||
| @ -162,14 +158,16 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload | ||||
| 			reflectFieldValue = reflect.Indirect(reflectFieldValue) | ||||
| 			switch reflectFieldValue.Kind() { | ||||
| 			case reflect.Struct: | ||||
| 				rel.Field.Set(db.Statement.Context, data, elem.Interface()) | ||||
| 				rel.Field.Set(tx.Statement.Context, data, elem.Interface()) | ||||
| 			case reflect.Slice, reflect.Array: | ||||
| 				if reflectFieldValue.Type().Elem().Kind() == reflect.Ptr { | ||||
| 					rel.Field.Set(db.Statement.Context, data, reflect.Append(reflectFieldValue, elem).Interface()) | ||||
| 					rel.Field.Set(tx.Statement.Context, data, reflect.Append(reflectFieldValue, elem).Interface()) | ||||
| 				} else { | ||||
| 					rel.Field.Set(db.Statement.Context, data, reflect.Append(reflectFieldValue, elem.Elem()).Interface()) | ||||
| 					rel.Field.Set(tx.Statement.Context, data, reflect.Append(reflectFieldValue, elem.Elem()).Interface()) | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return tx.Error | ||||
| } | ||||
|  | ||||
| @ -237,9 +237,20 @@ func Preload(db *gorm.DB) { | ||||
| 		} | ||||
| 		sort.Strings(preloadNames) | ||||
| 
 | ||||
| 		preloadDB := db.Session(&gorm.Session{Context: db.Statement.Context, NewDB: true, SkipHooks: db.Statement.SkipHooks, Initialized: true}) | ||||
| 		db.Statement.Settings.Range(func(k, v interface{}) bool { | ||||
| 			preloadDB.Statement.Settings.Store(k, v) | ||||
| 			return true | ||||
| 		}) | ||||
| 
 | ||||
| 		if err := preloadDB.Statement.Parse(db.Statement.Dest); err != nil { | ||||
| 			return | ||||
| 		} | ||||
| 		preloadDB.Statement.ReflectValue = db.Statement.ReflectValue | ||||
| 
 | ||||
| 		for _, name := range preloadNames { | ||||
| 			if rel := db.Statement.Schema.Relationships.Relations[name]; rel != nil { | ||||
| 				preload(db, rel, append(db.Statement.Preloads[name], db.Statement.Preloads[clause.Associations]...), preloadMap[name]) | ||||
| 			if rel := preloadDB.Statement.Schema.Relationships.Relations[name]; rel != nil { | ||||
| 				db.AddError(preload(preloadDB.Table("").Session(&gorm.Session{}), rel, append(db.Statement.Preloads[name], db.Statement.Preloads[clause.Associations]...), preloadMap[name])) | ||||
| 			} else { | ||||
| 				db.AddError(fmt.Errorf("%s: %w for schema %s", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name)) | ||||
| 			} | ||||
|  | ||||
| @ -54,9 +54,12 @@ func (db *DB) Table(name string, args ...interface{}) (tx *DB) { | ||||
| 	} else if tables := strings.Split(name, "."); len(tables) == 2 { | ||||
| 		tx.Statement.TableExpr = &clause.Expr{SQL: tx.Statement.Quote(name)} | ||||
| 		tx.Statement.Table = tables[1] | ||||
| 	} else { | ||||
| 	} else if name != "" { | ||||
| 		tx.Statement.TableExpr = &clause.Expr{SQL: tx.Statement.Quote(name)} | ||||
| 		tx.Statement.Table = name | ||||
| 	} else { | ||||
| 		tx.Statement.TableExpr = nil | ||||
| 		tx.Statement.Table = "" | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
|  | ||||
| @ -1335,7 +1335,7 @@ func TestNilPointerSlice(t *testing.T) { | ||||
| 	} | ||||
| 
 | ||||
| 	if !reflect.DeepEqual(got[0], want) && !reflect.DeepEqual(got[1], want) { | ||||
| 		t.Errorf("got %s; want array containing %s", toJSONString(got), toJSONString(want)) | ||||
| 		t.Fatalf("got %s; want array containing %s", toJSONString(got), toJSONString(want)) | ||||
| 	} | ||||
| 
 | ||||
| 	if !reflect.DeepEqual(got[0], want2) && !reflect.DeepEqual(got[1], want2) { | ||||
|  | ||||
| @ -251,3 +251,21 @@ func TestPreloadGoroutine(t *testing.T) { | ||||
| 	} | ||||
| 	wg.Wait() | ||||
| } | ||||
| 
 | ||||
| func TestPreloadWithDiffModel(t *testing.T) { | ||||
| 	user := *GetUser("preload_with_diff_model", Config{Account: true}) | ||||
| 
 | ||||
| 	if err := DB.Create(&user).Error; err != nil { | ||||
| 		t.Fatalf("errors happened when create: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	var result struct { | ||||
| 		Something string | ||||
| 		User | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Model(User{}).Preload("Account", clause.Eq{Column: "number", Value: user.Account.Number}).Select( | ||||
| 		"users.*, 'yo' as something").First(&result, "name = ?", user.Name) | ||||
| 
 | ||||
| 	CheckUser(t, user, result.User) | ||||
| } | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 chenrui
						chenrui