diff --git a/callbacks/query.go b/callbacks/query.go index 3299d015..1cfd618c 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -147,6 +147,21 @@ func BuildQuerySQL(db *gorm.DB) { } } + if join.On != nil { + onStmt := gorm.Statement{Table: tableAliasName, DB: db} + join.On.Build(&onStmt) + onSQL := onStmt.SQL.String() + vars := onStmt.Vars + for idx, v := range onStmt.Vars { + bindvar := strings.Builder{} + onStmt.Vars = vars[0 : idx+1] + db.Dialector.BindVarTo(&bindvar, &onStmt, v) + onSQL = strings.Replace(onSQL, bindvar.String(), "?", 1) + } + + exprs = append(exprs, clause.Expr{SQL: onSQL, Vars: vars}) + } + joins = append(joins, clause.Join{ Type: clause.LeftJoin, Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName}, diff --git a/chainable_api.go b/chainable_api.go index 88279044..01ab2597 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -171,8 +171,19 @@ func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) { // Joins specify Joins conditions // db.Joins("Account").Find(&user) // db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user) +// db.Joins("Account", DB.Select("id").Where("user_id = users.id AND name = ?", "someName").Model(&Account{})) func (db *DB) Joins(query string, args ...interface{}) (tx *DB) { tx = db.getInstance() + + if len(args) > 0 { + if db, ok := args[0].(*DB); ok { + if where, ok := db.Statement.Clauses["WHERE"].Expression.(clause.Where); ok { + tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args[1:], On: &where}) + } + return + } + } + tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args}) return } diff --git a/schema/schema.go b/schema/schema.go index 0e0501d4..faba2e21 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -119,20 +119,13 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) // When the schema initialization is completed, the channel will be closed defer close(schema.initialized) - if v, loaded := cacheStore.LoadOrStore(modelType, schema); loaded { + if v, loaded := cacheStore.Load(modelType); loaded { s := v.(*Schema) // Wait for the initialization of other goroutines to complete <-s.initialized return s, s.err } - defer func() { - if schema.err != nil { - logger.Default.Error(context.Background(), schema.err.Error()) - cacheStore.Delete(modelType) - } - }() - for i := 0; i < modelType.NumField(); i++ { if fieldStruct := modelType.Field(i); ast.IsExported(fieldStruct.Name) { if field := schema.ParseField(fieldStruct); field.EmbeddedSchema != nil { @@ -233,6 +226,20 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) } } + if v, loaded := cacheStore.LoadOrStore(modelType, schema); loaded { + s := v.(*Schema) + // Wait for the initialization of other goroutines to complete + <-s.initialized + return s, s.err + } + + defer func() { + if schema.err != nil { + logger.Default.Error(context.Background(), schema.err.Error()) + cacheStore.Delete(modelType) + } + }() + if _, embedded := schema.cacheStore.Load(embeddedCacheKey); !embedded { for _, field := range schema.Fields { if field.DataType == "" && (field.Creatable || field.Updatable || field.Readable) { diff --git a/statement.go b/statement.go index 93b78c12..38363443 100644 --- a/statement.go +++ b/statement.go @@ -50,6 +50,7 @@ type Statement struct { type join struct { Name string Conds []interface{} + On *clause.Where } // StatementModifier statement modifier interface diff --git a/tests/go.mod b/tests/go.mod index a1033a60..d7ab65ad 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -9,7 +9,7 @@ require ( gorm.io/driver/mysql v1.1.2 gorm.io/driver/postgres v1.1.0 gorm.io/driver/sqlite v1.1.4 - gorm.io/driver/sqlserver v1.0.8 + gorm.io/driver/sqlserver v1.0.9 gorm.io/gorm v1.21.14 ) diff --git a/tests/joins_test.go b/tests/joins_test.go index 46611f5f..e560f38a 100644 --- a/tests/joins_test.go +++ b/tests/joins_test.go @@ -104,6 +104,27 @@ func TestJoinConds(t *testing.T) { } } +func TestJoinOn(t *testing.T) { + var user = *GetUser("joins-on", Config{Pets: 2}) + DB.Save(&user) + + var user1 User + onQuery := DB.Where(&Pet{Name: "joins-on_pet_1"}) + + if err := DB.Joins("NamedPet", onQuery).Where("users.name = ?", user.Name).First(&user1).Error; err != nil { + t.Fatalf("Failed to load with joins on, got error: %v", err) + } + + AssertEqual(t, user1.NamedPet.Name, "joins-on_pet_1") + + onQuery2 := DB.Where(&Pet{Name: "joins-on_pet_2"}) + var user2 User + if err := DB.Joins("NamedPet", onQuery2).Where("users.name = ?", user.Name).First(&user2).Error; err != nil { + t.Fatalf("Failed to load with joins on, got error: %v", err) + } + AssertEqual(t, user2.NamedPet.Name, "joins-on_pet_2") +} + func TestJoinsWithSelect(t *testing.T) { type result struct { ID uint diff --git a/tests/update_test.go b/tests/update_test.go index 5ad1bb39..9e5b630e 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -642,6 +642,40 @@ func TestSave(t *testing.T) { if !regexp.MustCompile("WHERE .id. = [^ ]+$").MatchString(stmt.SQL.String()) { t.Fatalf("invalid updating SQL, got %v", stmt.SQL.String()) } + + user3 := *GetUser("save3", Config{}) + DB.Create(&user3) + + if err := DB.First(&User{}, "name = ?", "save3").Error; err != nil { + t.Fatalf("failed to find created user") + } + + user3.Name = "save3_" + if err := DB.Model(User{Model: user3.Model}).Save(&user3).Error; err != nil { + t.Fatalf("failed to save user, got %v", err) + } + + var result2 User + if err := DB.First(&result2, "name = ?", "save3_").Error; err != nil || result2.ID != user3.ID { + t.Fatalf("failed to find updated user, got %v", err) + } + + if err := DB.Model(User{Model: user3.Model}).Save(&struct { + gorm.Model + Placeholder string + Name string + }{ + Model: user3.Model, + Placeholder: "placeholder", + Name: "save3__", + }).Error; err != nil { + t.Fatalf("failed to update user, got %v", err) + } + + var result3 User + if err := DB.First(&result3, "name = ?", "save3__").Error; err != nil || result3.ID != user3.ID { + t.Fatalf("failed to find updated user") + } } func TestSaveWithPrimaryValue(t *testing.T) { diff --git a/utils/tests/models.go b/utils/tests/models.go index 2c5e71c0..8e833c93 100644 --- a/utils/tests/models.go +++ b/utils/tests/models.go @@ -11,6 +11,7 @@ import ( // He works in a Company (belongs to), he has a Manager (belongs to - single-table), and also managed a Team (has many - single-table) // He speaks many languages (many to many) and has many friends (many to many - single-table) // His pet also has one Toy (has one - polymorphic) +// NamedPet is a reference to a Named `Pets` (has many) type User struct { gorm.Model Name string @@ -18,6 +19,7 @@ type User struct { Birthday *time.Time Account Account Pets []*Pet + NamedPet *Pet Toys []Toy `gorm:"polymorphic:Owner"` CompanyID *int Company Company