Test inner joins
This commit is contained in:
		
							parent
							
								
									85246682c8
								
							
						
					
					
						commit
						9dfed613db
					
				| @ -28,7 +28,8 @@ func Query(db *gorm.DB) { | |||||||
| 			if len(db.Statement.Selects) == 0 { | 			if len(db.Statement.Selects) == 0 { | ||||||
| 				for _, dbName := range db.Statement.Schema.DBNames { | 				for _, dbName := range db.Statement.Schema.DBNames { | ||||||
| 					clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ | 					clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ | ||||||
| 						Name: dbName, | 						Table: db.Statement.Table, | ||||||
|  | 						Name:  dbName, | ||||||
| 					}) | 					}) | ||||||
| 				} | 				} | ||||||
| 			} | 			} | ||||||
| @ -37,8 +38,9 @@ func Query(db *gorm.DB) { | |||||||
| 				if relation, ok := db.Statement.Schema.Relationships.Relations[name]; ok { | 				if relation, ok := db.Statement.Schema.Relationships.Relations[name]; ok { | ||||||
| 					for _, s := range relation.FieldSchema.DBNames { | 					for _, s := range relation.FieldSchema.DBNames { | ||||||
| 						clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ | 						clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ | ||||||
| 							Table: relation.FieldSchema.Table, | 							Table: relation.Name, | ||||||
| 							Name:  s, | 							Name:  s, | ||||||
|  | 							Alias: relation.Name + "__" + s, | ||||||
| 						}) | 						}) | ||||||
| 					} | 					} | ||||||
| 
 | 
 | ||||||
| @ -46,16 +48,16 @@ func Query(db *gorm.DB) { | |||||||
| 					for _, ref := range relation.References { | 					for _, ref := range relation.References { | ||||||
| 						if ref.OwnPrimaryKey { | 						if ref.OwnPrimaryKey { | ||||||
| 							exprs = append(exprs, clause.Expr{ | 							exprs = append(exprs, clause.Expr{ | ||||||
| 								SQL: fmt.Sprintf("%s.%s = %s.%s", db.Statement.Schema.Table, ref.PrimaryKey.DBName, relation.FieldSchema.Table, ref.ForeignKey.DBName), | 								SQL: fmt.Sprintf("%s.%s = %s.%s", db.Statement.Schema.Table, ref.PrimaryKey.DBName, relation.Name, ref.ForeignKey.DBName), | ||||||
| 							}) | 							}) | ||||||
| 						} else { | 						} else { | ||||||
| 							if ref.PrimaryValue == "" { | 							if ref.PrimaryValue == "" { | ||||||
| 								exprs = append(exprs, clause.Expr{ | 								exprs = append(exprs, clause.Expr{ | ||||||
| 									SQL: fmt.Sprintf("%s.%s = %s.%s", db.Statement.Schema.Table, ref.ForeignKey.DBName, relation.FieldSchema.Table, ref.PrimaryKey.DBName), | 									SQL: fmt.Sprintf("%s.%s = %s.%s", db.Statement.Schema.Table, ref.ForeignKey.DBName, relation.Name, ref.PrimaryKey.DBName), | ||||||
| 								}) | 								}) | ||||||
| 							} else { | 							} else { | ||||||
| 								exprs = append(exprs, clause.Expr{ | 								exprs = append(exprs, clause.Expr{ | ||||||
| 									SQL:  fmt.Sprintf("%s.%s = ?", relation.FieldSchema.Table, ref.PrimaryKey.DBName), | 									SQL:  fmt.Sprintf("%s.%s = ?", relation.Name, ref.PrimaryKey.DBName), | ||||||
| 									Vars: []interface{}{ref.PrimaryValue}, | 									Vars: []interface{}{ref.PrimaryValue}, | ||||||
| 								}) | 								}) | ||||||
| 							} | 							} | ||||||
| @ -64,7 +66,7 @@ func Query(db *gorm.DB) { | |||||||
| 
 | 
 | ||||||
| 					joins = append(joins, clause.Join{ | 					joins = append(joins, clause.Join{ | ||||||
| 						Type:  clause.LeftJoin, | 						Type:  clause.LeftJoin, | ||||||
| 						Table: clause.Table{Name: relation.FieldSchema.Table}, | 						Table: clause.Table{Name: relation.FieldSchema.Table, Alias: relation.Name}, | ||||||
| 						ON:    clause.Where{Exprs: exprs}, | 						ON:    clause.Where{Exprs: exprs}, | ||||||
| 					}) | 					}) | ||||||
| 				} else { | 				} else { | ||||||
|  | |||||||
| @ -3,6 +3,7 @@ package callbacks | |||||||
| import ( | import ( | ||||||
| 	"database/sql" | 	"database/sql" | ||||||
| 	"reflect" | 	"reflect" | ||||||
|  | 	"strings" | ||||||
| 
 | 
 | ||||||
| 	"github.com/jinzhu/gorm" | 	"github.com/jinzhu/gorm" | ||||||
| 	"github.com/jinzhu/gorm/schema" | 	"github.com/jinzhu/gorm/schema" | ||||||
| @ -54,12 +55,21 @@ func Scan(rows *sql.Rows, db *gorm.DB) { | |||||||
| 			isPtr := db.Statement.ReflectValue.Type().Elem().Kind() == reflect.Ptr | 			isPtr := db.Statement.ReflectValue.Type().Elem().Kind() == reflect.Ptr | ||||||
| 			db.Statement.ReflectValue.Set(reflect.MakeSlice(db.Statement.ReflectValue.Type(), 0, 0)) | 			db.Statement.ReflectValue.Set(reflect.MakeSlice(db.Statement.ReflectValue.Type(), 0, 0)) | ||||||
| 			fields := make([]*schema.Field, len(columns)) | 			fields := make([]*schema.Field, len(columns)) | ||||||
|  | 			joinFields := make([][2]*schema.Field, len(columns)) | ||||||
| 
 | 
 | ||||||
| 			for idx, column := range columns { | 			for idx, column := range columns { | ||||||
| 				if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable { | 				if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable { | ||||||
| 					fields[idx] = field | 					fields[idx] = field | ||||||
|  | 				} else if names := strings.Split(column, "__"); len(names) > 1 { | ||||||
|  | 					if rel, ok := db.Statement.Schema.Relationships.Relations[names[0]]; ok { | ||||||
|  | 						if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { | ||||||
|  | 							joinFields[idx] = [2]*schema.Field{rel.Field, field} | ||||||
|  | 							continue | ||||||
|  | 						} | ||||||
|  | 					} | ||||||
|  | 					values[idx] = &sql.RawBytes{} | ||||||
| 				} else { | 				} else { | ||||||
| 					values[idx] = sql.RawBytes{} | 					values[idx] = &sql.RawBytes{} | ||||||
| 				} | 				} | ||||||
| 			} | 			} | ||||||
| 
 | 
 | ||||||
| @ -68,6 +78,9 @@ func Scan(rows *sql.Rows, db *gorm.DB) { | |||||||
| 				for idx, field := range fields { | 				for idx, field := range fields { | ||||||
| 					if field != nil { | 					if field != nil { | ||||||
| 						values[idx] = field.ReflectValueOf(elem).Addr().Interface() | 						values[idx] = field.ReflectValueOf(elem).Addr().Interface() | ||||||
|  | 					} else if joinFields[idx][0] != nil { | ||||||
|  | 						relValue := joinFields[idx][0].ReflectValueOf(elem) | ||||||
|  | 						values[idx] = joinFields[idx][1].ReflectValueOf(relValue).Addr().Interface() | ||||||
| 					} | 					} | ||||||
| 				} | 				} | ||||||
| 
 | 
 | ||||||
| @ -86,8 +99,17 @@ func Scan(rows *sql.Rows, db *gorm.DB) { | |||||||
| 			for idx, column := range columns { | 			for idx, column := range columns { | ||||||
| 				if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable { | 				if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable { | ||||||
| 					values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface() | 					values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface() | ||||||
|  | 				} else if names := strings.Split(column, "__"); len(names) > 1 { | ||||||
|  | 					if rel, ok := db.Statement.Schema.Relationships.Relations[names[0]]; ok { | ||||||
|  | 						relValue := rel.Field.ReflectValueOf(db.Statement.ReflectValue) | ||||||
|  | 						if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { | ||||||
|  | 							values[idx] = field.ReflectValueOf(relValue).Addr().Interface() | ||||||
|  | 							continue | ||||||
|  | 						} | ||||||
|  | 					} | ||||||
|  | 					values[idx] = &sql.RawBytes{} | ||||||
| 				} else { | 				} else { | ||||||
| 					values[idx] = sql.RawBytes{} | 					values[idx] = &sql.RawBytes{} | ||||||
| 				} | 				} | ||||||
| 			} | 			} | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -7,9 +7,75 @@ import ( | |||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func TestJoins(t *testing.T, db *gorm.DB) { | func TestJoins(t *testing.T, db *gorm.DB) { | ||||||
| 	db.Migrator().DropTable(&User{}) | 	db.Migrator().DropTable(&User{}, &Account{}, &Company{}) | ||||||
| 	db.AutoMigrate(&User{}) | 	db.AutoMigrate(&User{}, &Account{}, &Company{}) | ||||||
|  | 
 | ||||||
|  | 	check := func(t *testing.T, oldUser, newUser User) { | ||||||
|  | 		if newUser.Company.ID != oldUser.Company.ID { | ||||||
|  | 			t.Errorf("Company is not equal when load with joins, loaded company id: %v", newUser.Company.ID) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if newUser.Manager == nil || newUser.Manager.ID != oldUser.Manager.ID { | ||||||
|  | 			t.Errorf("Manager is not equal when load with joins: loaded manager: %+v", newUser.Manager) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if newUser.Account.ID != oldUser.Account.ID { | ||||||
|  | 			t.Errorf("Account is not equal when load with joins, loaded account id: %v, expect: %v", newUser.Account.ID, oldUser.Account.ID) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
| 
 | 
 | ||||||
| 	t.Run("Joins", func(t *testing.T) { | 	t.Run("Joins", func(t *testing.T) { | ||||||
|  | 		user := User{ | ||||||
|  | 			Name:    "joins-1", | ||||||
|  | 			Company: Company{Name: "company"}, | ||||||
|  | 			Manager: &User{Name: "manager"}, | ||||||
|  | 			Account: Account{Number: "account-has-one-association"}, | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		db.Create(&user) | ||||||
|  | 
 | ||||||
|  | 		var user2 User | ||||||
|  | 		if err := db.Joins("Company").Joins("Manager").Joins("Account").First(&user2, "users.name = ?", user.Name).Error; err != nil { | ||||||
|  | 			t.Fatalf("Failed to load with joins, got error: %v", err) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		check(t, user, user2) | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	t.Run("JoinsForSlice", func(t *testing.T) { | ||||||
|  | 		users := []User{{ | ||||||
|  | 			Name:    "slice-joins-1", | ||||||
|  | 			Company: Company{Name: "company"}, | ||||||
|  | 			Manager: &User{Name: "manager"}, | ||||||
|  | 			Account: Account{Number: "account-has-one-association"}, | ||||||
|  | 		}, { | ||||||
|  | 			Name:    "slice-joins-2", | ||||||
|  | 			Company: Company{Name: "company2"}, | ||||||
|  | 			Manager: &User{Name: "manager2"}, | ||||||
|  | 			Account: Account{Number: "account-has-one-association2"}, | ||||||
|  | 		}, { | ||||||
|  | 			Name:    "slice-joins-3", | ||||||
|  | 			Company: Company{Name: "company3"}, | ||||||
|  | 			Manager: &User{Name: "manager3"}, | ||||||
|  | 			Account: Account{Number: "account-has-one-association3"}, | ||||||
|  | 		}} | ||||||
|  | 
 | ||||||
|  | 		db.Create(&users) | ||||||
|  | 
 | ||||||
|  | 		var users2 []User | ||||||
|  | 		if err := db.Joins("Company").Joins("Manager").Joins("Account").Find(&users2, "users.name LIKE ?", "slice-joins%").Error; err != nil { | ||||||
|  | 			t.Fatalf("Failed to load with joins, got error: %v", err) | ||||||
|  | 		} else if len(users2) != len(users) { | ||||||
|  | 			t.Fatalf("Failed to load join users, got: %v, expect: %v", len(users2), len(users)) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		for _, u2 := range users2 { | ||||||
|  | 			for _, u := range users { | ||||||
|  | 				if u.Name == u2.Name { | ||||||
|  | 					check(t, u, u2) | ||||||
|  | 					continue | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
| 	}) | 	}) | ||||||
| } | } | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu