fix(Joins): args with select and omit
This commit is contained in:
		
							parent
							
								
									3f20a543fa
								
							
						
					
					
						commit
						bcaac9eb18
					
				@ -117,12 +117,20 @@ func BuildQuerySQL(db *gorm.DB) {
 | 
				
			|||||||
				} else if relation, ok := db.Statement.Schema.Relationships.Relations[join.Name]; ok {
 | 
									} else if relation, ok := db.Statement.Schema.Relationships.Relations[join.Name]; ok {
 | 
				
			||||||
					tableAliasName := relation.Name
 | 
										tableAliasName := relation.Name
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
										columnStmt := gorm.Statement{
 | 
				
			||||||
 | 
											Table: tableAliasName, DB: db, Schema: relation.FieldSchema,
 | 
				
			||||||
 | 
											Selects: join.Selects, Omits: join.Omits,
 | 
				
			||||||
 | 
										}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
										selectColumns, restricted := columnStmt.SelectAndOmitColumns(false, false)
 | 
				
			||||||
					for _, s := range relation.FieldSchema.DBNames {
 | 
										for _, s := range relation.FieldSchema.DBNames {
 | 
				
			||||||
						clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{
 | 
											if v, ok := selectColumns[s]; (ok && v) || (!ok && !restricted) {
 | 
				
			||||||
							Table: tableAliasName,
 | 
												clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{
 | 
				
			||||||
							Name:  s,
 | 
													Table: tableAliasName,
 | 
				
			||||||
							Alias: tableAliasName + "__" + s,
 | 
													Name:  s,
 | 
				
			||||||
						})
 | 
													Alias: tableAliasName + "__" + s,
 | 
				
			||||||
 | 
												})
 | 
				
			||||||
 | 
											}
 | 
				
			||||||
					}
 | 
										}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
					exprs := make([]clause.Expression, len(relation.References))
 | 
										exprs := make([]clause.Expression, len(relation.References))
 | 
				
			||||||
 | 
				
			|||||||
@ -10,10 +10,11 @@ import (
 | 
				
			|||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Model specify the model you would like to run db operations
 | 
					// Model specify the model you would like to run db operations
 | 
				
			||||||
//    // update all users's name to `hello`
 | 
					//
 | 
				
			||||||
//    db.Model(&User{}).Update("name", "hello")
 | 
					//	// update all users's name to `hello`
 | 
				
			||||||
//    // if user's primary key is non-blank, will use it as condition, then will only update the user's name to `hello`
 | 
					//	db.Model(&User{}).Update("name", "hello")
 | 
				
			||||||
//    db.Model(&user).Update("name", "hello")
 | 
					//	// if user's primary key is non-blank, will use it as condition, then will only update the user's name to `hello`
 | 
				
			||||||
 | 
					//	db.Model(&user).Update("name", "hello")
 | 
				
			||||||
func (db *DB) Model(value interface{}) (tx *DB) {
 | 
					func (db *DB) Model(value interface{}) (tx *DB) {
 | 
				
			||||||
	tx = db.getInstance()
 | 
						tx = db.getInstance()
 | 
				
			||||||
	tx.Statement.Model = value
 | 
						tx.Statement.Model = value
 | 
				
			||||||
@ -179,18 +180,21 @@ func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Joins specify Joins conditions
 | 
					// Joins specify Joins conditions
 | 
				
			||||||
//     db.Joins("Account").Find(&user)
 | 
					//
 | 
				
			||||||
//     db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user)
 | 
					//	db.Joins("Account").Find(&user)
 | 
				
			||||||
//     db.Joins("Account", DB.Select("id").Where("user_id = users.id AND name = ?", "someName").Model(&Account{}))
 | 
					//	db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user)
 | 
				
			||||||
 | 
					//	db.Joins("Account", DB.Select("id").Where("user_id = users.id AND name = ?", "someName").Model(&Account{}))
 | 
				
			||||||
func (db *DB) Joins(query string, args ...interface{}) (tx *DB) {
 | 
					func (db *DB) Joins(query string, args ...interface{}) (tx *DB) {
 | 
				
			||||||
	tx = db.getInstance()
 | 
						tx = db.getInstance()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if len(args) == 1 {
 | 
						if len(args) == 1 {
 | 
				
			||||||
		if db, ok := args[0].(*DB); ok {
 | 
							if db, ok := args[0].(*DB); ok {
 | 
				
			||||||
 | 
								j := join{Name: query, Conds: args, Selects: db.Statement.Selects, Omits: db.Statement.Omits}
 | 
				
			||||||
			if where, ok := db.Statement.Clauses["WHERE"].Expression.(clause.Where); ok {
 | 
								if where, ok := db.Statement.Clauses["WHERE"].Expression.(clause.Where); ok {
 | 
				
			||||||
				tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args, On: &where})
 | 
									j.On = &where
 | 
				
			||||||
				return
 | 
					 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
								tx.Statement.Joins = append(tx.Statement.Joins, j)
 | 
				
			||||||
 | 
								return
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -219,8 +223,9 @@ func (db *DB) Having(query interface{}, args ...interface{}) (tx *DB) {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Order specify order when retrieve records from database
 | 
					// Order specify order when retrieve records from database
 | 
				
			||||||
//     db.Order("name DESC")
 | 
					//
 | 
				
			||||||
//     db.Order(clause.OrderByColumn{Column: clause.Column{Name: "name"}, Desc: true})
 | 
					//	db.Order("name DESC")
 | 
				
			||||||
 | 
					//	db.Order(clause.OrderByColumn{Column: clause.Column{Name: "name"}, Desc: true})
 | 
				
			||||||
func (db *DB) Order(value interface{}) (tx *DB) {
 | 
					func (db *DB) Order(value interface{}) (tx *DB) {
 | 
				
			||||||
	tx = db.getInstance()
 | 
						tx = db.getInstance()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -256,17 +261,18 @@ func (db *DB) Offset(offset int) (tx *DB) {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Scopes pass current database connection to arguments `func(DB) DB`, which could be used to add conditions dynamically
 | 
					// Scopes pass current database connection to arguments `func(DB) DB`, which could be used to add conditions dynamically
 | 
				
			||||||
//     func AmountGreaterThan1000(db *gorm.DB) *gorm.DB {
 | 
					 | 
				
			||||||
//         return db.Where("amount > ?", 1000)
 | 
					 | 
				
			||||||
//     }
 | 
					 | 
				
			||||||
//
 | 
					//
 | 
				
			||||||
//     func OrderStatus(status []string) func (db *gorm.DB) *gorm.DB {
 | 
					//	func AmountGreaterThan1000(db *gorm.DB) *gorm.DB {
 | 
				
			||||||
//         return func (db *gorm.DB) *gorm.DB {
 | 
					//	    return db.Where("amount > ?", 1000)
 | 
				
			||||||
//             return db.Scopes(AmountGreaterThan1000).Where("status in (?)", status)
 | 
					//	}
 | 
				
			||||||
//         }
 | 
					 | 
				
			||||||
//     }
 | 
					 | 
				
			||||||
//
 | 
					//
 | 
				
			||||||
//     db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders)
 | 
					//	func OrderStatus(status []string) func (db *gorm.DB) *gorm.DB {
 | 
				
			||||||
 | 
					//	    return func (db *gorm.DB) *gorm.DB {
 | 
				
			||||||
 | 
					//	        return db.Scopes(AmountGreaterThan1000).Where("status in (?)", status)
 | 
				
			||||||
 | 
					//	    }
 | 
				
			||||||
 | 
					//	}
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					//	db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders)
 | 
				
			||||||
func (db *DB) Scopes(funcs ...func(*DB) *DB) (tx *DB) {
 | 
					func (db *DB) Scopes(funcs ...func(*DB) *DB) (tx *DB) {
 | 
				
			||||||
	tx = db.getInstance()
 | 
						tx = db.getInstance()
 | 
				
			||||||
	tx.Statement.scopes = append(tx.Statement.scopes, funcs...)
 | 
						tx.Statement.scopes = append(tx.Statement.scopes, funcs...)
 | 
				
			||||||
@ -274,7 +280,8 @@ func (db *DB) Scopes(funcs ...func(*DB) *DB) (tx *DB) {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Preload preload associations with given conditions
 | 
					// Preload preload associations with given conditions
 | 
				
			||||||
//    db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users)
 | 
					//
 | 
				
			||||||
 | 
					//	db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users)
 | 
				
			||||||
func (db *DB) Preload(query string, args ...interface{}) (tx *DB) {
 | 
					func (db *DB) Preload(query string, args ...interface{}) (tx *DB) {
 | 
				
			||||||
	tx = db.getInstance()
 | 
						tx = db.getInstance()
 | 
				
			||||||
	if tx.Statement.Preloads == nil {
 | 
						if tx.Statement.Preloads == nil {
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										13
									
								
								statement.go
									
									
									
									
									
								
							
							
						
						
									
										13
									
								
								statement.go
									
									
									
									
									
								
							@ -49,9 +49,11 @@ type Statement struct {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type join struct {
 | 
					type join struct {
 | 
				
			||||||
	Name  string
 | 
						Name    string
 | 
				
			||||||
	Conds []interface{}
 | 
						Conds   []interface{}
 | 
				
			||||||
	On    *clause.Where
 | 
						On      *clause.Where
 | 
				
			||||||
 | 
						Selects []string
 | 
				
			||||||
 | 
						Omits   []string
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// StatementModifier statement modifier interface
 | 
					// StatementModifier statement modifier interface
 | 
				
			||||||
@ -544,8 +546,9 @@ func (stmt *Statement) clone() *Statement {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// SetColumn set column's value
 | 
					// SetColumn set column's value
 | 
				
			||||||
//   stmt.SetColumn("Name", "jinzhu") // Hooks Method
 | 
					//
 | 
				
			||||||
//   stmt.SetColumn("Name", "jinzhu", true) // Callbacks Method
 | 
					//	stmt.SetColumn("Name", "jinzhu") // Hooks Method
 | 
				
			||||||
 | 
					//	stmt.SetColumn("Name", "jinzhu", true) // Callbacks Method
 | 
				
			||||||
func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks ...bool) {
 | 
					func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks ...bool) {
 | 
				
			||||||
	if v, ok := stmt.Dest.(map[string]interface{}); ok {
 | 
						if v, ok := stmt.Dest.(map[string]interface{}); ok {
 | 
				
			||||||
		v[name] = value
 | 
							v[name] = value
 | 
				
			||||||
 | 
				
			|||||||
@ -260,3 +260,47 @@ func TestJoinWithSameColumnName(t *testing.T) {
 | 
				
			|||||||
		t.Fatalf("wrong pet name")
 | 
							t.Fatalf("wrong pet name")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestJoinArgsWithDB(t *testing.T) {
 | 
				
			||||||
 | 
						user := *GetUser("joins-args-db", Config{Pets: 2})
 | 
				
			||||||
 | 
						DB.Save(&user)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// test where
 | 
				
			||||||
 | 
						var user1 User
 | 
				
			||||||
 | 
						onQuery := DB.Where(&Pet{Name: "joins-args-db_pet_2"})
 | 
				
			||||||
 | 
						if err := DB.Joins("NamedPet", onQuery).Where("users.name = ?", user.Name).First(&user1).Error; err != nil {
 | 
				
			||||||
 | 
							t.Fatalf("Failed to load with joins on, got error: %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						AssertEqual(t, user1.NamedPet.Name, "joins-args-db_pet_2")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// test where and omit
 | 
				
			||||||
 | 
						onQuery2 := DB.Where(&Pet{Name: "joins-args-db_pet_2"}).Omit("Name")
 | 
				
			||||||
 | 
						var user2 User
 | 
				
			||||||
 | 
						if err := DB.Joins("NamedPet", onQuery2).Where("users.name = ?", user.Name).First(&user2).Error; err != nil {
 | 
				
			||||||
 | 
							t.Fatalf("Failed to load with joins on, got error: %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						AssertEqual(t, user2.NamedPet.ID, user1.NamedPet.ID)
 | 
				
			||||||
 | 
						AssertEqual(t, user2.NamedPet.Name, "")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// test where and select
 | 
				
			||||||
 | 
						onQuery3 := DB.Where(&Pet{Name: "joins-args-db_pet_2"}).Select("Name")
 | 
				
			||||||
 | 
						var user3 User
 | 
				
			||||||
 | 
						if err := DB.Joins("NamedPet", onQuery3).Where("users.name = ?", user.Name).First(&user3).Error; err != nil {
 | 
				
			||||||
 | 
							t.Fatalf("Failed to load with joins on, got error: %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						AssertEqual(t, user3.NamedPet.ID, 0)
 | 
				
			||||||
 | 
						AssertEqual(t, user3.NamedPet.Name, "joins-args-db_pet_2")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// test select
 | 
				
			||||||
 | 
						onQuery4 := DB.Select("ID")
 | 
				
			||||||
 | 
						var user4 User
 | 
				
			||||||
 | 
						if err := DB.Joins("NamedPet", onQuery4).Where("users.name = ?", user.Name).First(&user4).Error; err != nil {
 | 
				
			||||||
 | 
							t.Fatalf("Failed to load with joins on, got error: %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if user4.NamedPet.ID == 0 {
 | 
				
			||||||
 | 
							t.Fatal("Pet ID can not be empty")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						AssertEqual(t, user4.NamedPet.Name, "")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user