diff --git a/callbacks/query.go b/callbacks/query.go index 26ee8c34..a0e75a7e 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -177,7 +177,7 @@ func BuildQuerySQL(db *gorm.DB) { } fromClause.Joins = append(fromClause.Joins, clause.Join{ - Type: clause.LeftJoin, + Type: join.JoinType, Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName}, ON: clause.Where{Exprs: exprs}, }) diff --git a/chainable_api.go b/chainable_api.go index 68b4d1aa..1fbefc0e 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -183,18 +183,26 @@ func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) { // db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user) // db.Joins("Account", DB.Select("id").Where("user_id = users.id AND name = ?", "someName").Model(&Account{})) func (db *DB) Joins(query string, args ...interface{}) (tx *DB) { + return joins(db, clause.LeftJoin, query, args...) +} + +func (db *DB) InnerJoins(query string, args ...interface{}) (tx *DB) { + return joins(db, clause.InnerJoin, query, args...) +} + +func joins(db *DB, joinType clause.JoinType, query string, args ...interface{}) (tx *DB) { tx = db.getInstance() if len(args) == 1 { if db, ok := args[0].(*DB); ok { if where, ok := db.Statement.Clauses["WHERE"].Expression.(clause.Where); ok { - tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args, On: &where}) + tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args, On: &where, JoinType: joinType}) return } } } - tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args}) + tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args, JoinType: joinType}) return } diff --git a/statement.go b/statement.go index 12687810..aa3937ae 100644 --- a/statement.go +++ b/statement.go @@ -49,9 +49,10 @@ type Statement struct { } type join struct { - Name string - Conds []interface{} - On *clause.Where + Name string + Conds []interface{} + On *clause.Where + JoinType clause.JoinType } // StatementModifier statement modifier interface diff --git a/tests/joins_test.go b/tests/joins_test.go index 4908e5ba..3fbbf17c 100644 --- a/tests/joins_test.go +++ b/tests/joins_test.go @@ -229,3 +229,19 @@ func TestJoinWithSoftDeleted(t *testing.T) { t.Fatalf("joins NamedPet and Account should not empty:%v", user2) } } + +func TestInnerJoins(t *testing.T) { + user := *GetUser("inner-joins-1", Config{Company: true, Manager: true, Account: true, NamedPet: false}) + + DB.Create(&user) + + var user2 User + var err error + err = DB.InnerJoins("Company").InnerJoins("Manager").InnerJoins("Account").First(&user2, "users.name = ?", user.Name).Error + AssertEqual(t, err, nil) + CheckUser(t, user2, user) + + // NamedPet is nil + err = DB.InnerJoins("NamedPet").InnerJoins("Company").InnerJoins("Manager").InnerJoins("Account").First(&user2, "users.name = ?", user.Name).Error + AssertEqual(t, err, gorm.ErrRecordNotFound) +}