From 1d9e563023dfb03c829e6675e08536b429fd5c09 Mon Sep 17 00:00:00 2001 From: riverchu Date: Fri, 3 Sep 2021 23:09:20 +0800 Subject: [PATCH 01/12] style: prepose error judgement --- callbacks/update.go | 50 +++++++++++++++++++++++---------------------- 1 file changed, 26 insertions(+), 24 deletions(-) diff --git a/callbacks/update.go b/callbacks/update.go index 75bb02db..d85c4c22 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -51,37 +51,39 @@ func BeforeUpdate(db *gorm.DB) { } func Update(db *gorm.DB) { - if db.Error == nil { - if db.Statement.Schema != nil && !db.Statement.Unscoped { - for _, c := range db.Statement.Schema.UpdateClauses { - db.Statement.AddClause(c) - } - } + if db.Error != nil { + return + } - if db.Statement.SQL.String() == "" { - db.Statement.SQL.Grow(180) - db.Statement.AddClauseIfNotExists(clause.Update{}) - if set := ConvertToAssignments(db.Statement); len(set) != 0 { - db.Statement.AddClause(set) - } else { - return - } - db.Statement.Build(db.Statement.BuildClauses...) + if db.Statement.Schema != nil && !db.Statement.Unscoped { + for _, c := range db.Statement.Schema.UpdateClauses { + db.Statement.AddClause(c) } + } - if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok { - db.AddError(gorm.ErrMissingWhereClause) + if db.Statement.SQL.String() == "" { + db.Statement.SQL.Grow(180) + db.Statement.AddClauseIfNotExists(clause.Update{}) + if set := ConvertToAssignments(db.Statement); len(set) != 0 { + db.Statement.AddClause(set) + } else { return } + db.Statement.Build(db.Statement.BuildClauses...) + } - if !db.DryRun && db.Error == nil { - result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) + if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok { + db.AddError(gorm.ErrMissingWhereClause) + return + } - if err == nil { - db.RowsAffected, _ = result.RowsAffected() - } else { - db.AddError(err) - } + if !db.DryRun && db.Error == nil { + result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) + + if err == nil { + db.RowsAffected, _ = result.RowsAffected() + } else { + db.AddError(err) } } } From c89862279137298f794351ace2dad9c1e487b327 Mon Sep 17 00:00:00 2001 From: riverchu Date: Sun, 5 Sep 2021 11:10:48 +0800 Subject: [PATCH 02/12] test: add testcase in TestSave --- tests/update_test.go | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/tests/update_test.go b/tests/update_test.go index 5ad1bb39..869df769 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -642,6 +642,36 @@ func TestSave(t *testing.T) { if !regexp.MustCompile("WHERE .id. = [^ ]+$").MatchString(stmt.SQL.String()) { t.Fatalf("invalid updating SQL, got %v", stmt.SQL.String()) } + + user3 := *GetUser("save3", Config{}) + DB.Create(&user3) + + if err := DB.First(&User{}, "name = ?", "save3").Error; err != nil { + t.Fatalf("failed to find created user") + } + + user3.Name = "save3_" + DB.Model(User{}).Save(&user3) + + var result2 User + if err := DB.First(&result2, "name = ?", "save3_").Error; err != nil || result2.ID != user3.ID { + t.Fatalf("failed to find updated user") + } + + DB.Model(User{}).Save(&struct { + gorm.Model + Placeholder string + Name string + }{ + Model: user3.Model, + Placeholder: "placeholder", + Name: "save3__", + }) + + var result3 User + if err := DB.First(&result3, "name = ?", "save3__").Error; err != nil || result3.ID != user3.ID { + t.Fatalf("failed to find updated user") + } } func TestSaveWithPrimaryValue(t *testing.T) { From 4581e8b590a83d730dc490e8731990f467ba9e4f Mon Sep 17 00:00:00 2001 From: riverchu Date: Sun, 5 Sep 2021 23:07:28 +0800 Subject: [PATCH 03/12] test: update Save test --- tests/update_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/update_test.go b/tests/update_test.go index 869df769..2a747ce5 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -651,14 +651,14 @@ func TestSave(t *testing.T) { } user3.Name = "save3_" - DB.Model(User{}).Save(&user3) + DB.Model(User{Model: user3.Model}).Save(&user3) var result2 User if err := DB.First(&result2, "name = ?", "save3_").Error; err != nil || result2.ID != user3.ID { t.Fatalf("failed to find updated user") } - DB.Model(User{}).Save(&struct { + DB.Debug().Model(User{Model: user3.Model}).Save(&struct { gorm.Model Placeholder string Name string From eaa63d15e7ac3bab9ea2fd946b19e411ad261dc6 Mon Sep 17 00:00:00 2001 From: riverchu Date: Sun, 5 Sep 2021 23:12:24 +0800 Subject: [PATCH 04/12] feat: copy dest fields to model struct --- callbacks/update.go | 29 ++++++++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/callbacks/update.go b/callbacks/update.go index d85c4c22..ee60bcd7 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -23,11 +23,38 @@ func SetupUpdateReflectValue(db *gorm.DB) { rel.Field.Set(db.Statement.ReflectValue, dest[rel.Name]) } } + } else if modelType, destType := findType(db.Statement.Model), findType(db.Statement.Dest); modelType.Kind() == reflect.Struct && destType.Kind() == reflect.Struct { + db.Statement.Dest = transToModel(reflect.Indirect(reflect.ValueOf(db.Statement.Dest)), reflect.New(modelType).Elem()) } } } } +func findType(target interface{}) reflect.Type { + t := reflect.TypeOf(target) + if t.Kind() == reflect.Ptr { + return t.Elem() + } + return t +} + +func transToModel(from, to reflect.Value) interface{} { + if from.String() == to.String() { + return from.Interface() + } + + fromType := from.Type() + for i := 0; i < fromType.NumField(); i++ { + fieldName := fromType.Field(i).Name + fromField, toField := from.FieldByName(fieldName), to.FieldByName(fieldName) + if !toField.IsValid() || !toField.CanSet() || toField.Kind() != fromField.Kind() { + continue + } + toField.Set(fromField) + } + return to.Interface() +} + func BeforeUpdate(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) { callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { @@ -227,7 +254,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { set = make([]clause.Assignment, 0, len(stmt.Schema.FieldsByDBName)) for _, dbName := range stmt.Schema.DBNames { field := stmt.Schema.LookUpField(dbName) - if !field.PrimaryKey || (!updatingValue.CanAddr() || stmt.Dest != stmt.Model) { + if !field.PrimaryKey || !updatingValue.CanAddr() || stmt.Dest != stmt.Model { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && (!restricted || (!stmt.SkipHooks && field.AutoUpdateTime > 0))) { value, isZero := field.ValueOf(updatingValue) if !stmt.SkipHooks && field.AutoUpdateTime > 0 { From 895c1178a0d1d837cd986c45eac62f6b10a6add4 Mon Sep 17 00:00:00 2001 From: Adrien Carreira Date: Thu, 8 Jul 2021 10:04:40 +0200 Subject: [PATCH 05/12] Proposal, Add Specific on for Joins queries --- callbacks/query.go | 47 ++++++++++++++++++++++++++-------------------- chainable_api.go | 6 ++++++ statement.go | 1 + 3 files changed, 34 insertions(+), 20 deletions(-) diff --git a/callbacks/query.go b/callbacks/query.go index 3299d015..e5f1250c 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -125,33 +125,40 @@ func BuildQuerySQL(db *gorm.DB) { }) } - exprs := make([]clause.Expression, len(relation.References)) - for idx, ref := range relation.References { - if ref.OwnPrimaryKey { - exprs[idx] = clause.Eq{ - Column: clause.Column{Table: clause.CurrentTable, Name: ref.PrimaryKey.DBName}, - Value: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName}, - } - } else { - if ref.PrimaryValue == "" { + if join.On != nil { + exprs := make([]clause.Expression, len(relation.References)) + for idx, ref := range relation.References { + if ref.OwnPrimaryKey { exprs[idx] = clause.Eq{ - Column: clause.Column{Table: clause.CurrentTable, Name: ref.ForeignKey.DBName}, - Value: clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName}, + Column: clause.Column{Table: clause.CurrentTable, Name: ref.PrimaryKey.DBName}, + Value: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName}, } } else { - exprs[idx] = clause.Eq{ - Column: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName}, - Value: ref.PrimaryValue, + if ref.PrimaryValue == "" { + exprs[idx] = clause.Eq{ + Column: clause.Column{Table: clause.CurrentTable, Name: ref.ForeignKey.DBName}, + Value: clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName}, + } + } else { + exprs[idx] = clause.Eq{ + Column: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName}, + Value: ref.PrimaryValue, + } } } } + joins = append(joins, clause.Join{ + Type: clause.LeftJoin, + Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName}, + ON: clause.Where{Exprs: exprs}, + }) + } else { + joins = append(joins, clause.Join{ + Type: clause.LeftJoin, + Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName}, + ON: clause.Where{Exprs: []clause.Expression{join.On}}, + }) } - - joins = append(joins, clause.Join{ - Type: clause.LeftJoin, - Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName}, - ON: clause.Where{Exprs: exprs}, - }) } else { joins = append(joins, clause.Join{ Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds}, diff --git a/chainable_api.go b/chainable_api.go index 88279044..32943a83 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -177,6 +177,12 @@ func (db *DB) Joins(query string, args ...interface{}) (tx *DB) { return } +func (db *DB) JoinsOn(query string, on clause.Expression, args ...interface{}) (tx *DB) { + tx = db.getInstance() + tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args, On: on}) + return +} + // Group specify the group method on the find func (db *DB) Group(name string) (tx *DB) { tx = db.getInstance() diff --git a/statement.go b/statement.go index 93b78c12..89824bc1 100644 --- a/statement.go +++ b/statement.go @@ -50,6 +50,7 @@ type Statement struct { type join struct { Name string Conds []interface{} + On clause.Expression } // StatementModifier statement modifier interface From 52cc438d07cef6975b3407594c612f8e856b88af Mon Sep 17 00:00:00 2001 From: Adrien Carreira Date: Sat, 17 Jul 2021 15:45:15 +0200 Subject: [PATCH 06/12] JoinsOn unit test + use all primary keys --- callbacks/query.go | 10 ++++++++-- chainable_api.go | 2 +- statement.go | 2 +- tests/joins_test.go | 20 ++++++++++++++++++++ utils/tests/models.go | 2 ++ 5 files changed, 32 insertions(+), 4 deletions(-) diff --git a/callbacks/query.go b/callbacks/query.go index e5f1250c..570a85d0 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -125,7 +125,7 @@ func BuildQuerySQL(db *gorm.DB) { }) } - if join.On != nil { + if join.On == nil { exprs := make([]clause.Expression, len(relation.References)) for idx, ref := range relation.References { if ref.OwnPrimaryKey { @@ -153,10 +153,16 @@ func BuildQuerySQL(db *gorm.DB) { ON: clause.Where{Exprs: exprs}, }) } else { + primaryFields := make([]clause.Column, len(relation.FieldSchema.PrimaryFieldDBNames)) + for idx, ref := range relation.FieldSchema.PrimaryFieldDBNames { + primaryFields[idx] = clause.Column{Table: tableAliasName, Name: ref} + } + + exprs := db.Statement.BuildCondition("(?) = (?)", primaryFields, join.On) joins = append(joins, clause.Join{ Type: clause.LeftJoin, Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName}, - ON: clause.Where{Exprs: []clause.Expression{join.On}}, + ON: clause.Where{Exprs: exprs}, }) } } else { diff --git a/chainable_api.go b/chainable_api.go index 32943a83..184931ff 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -177,7 +177,7 @@ func (db *DB) Joins(query string, args ...interface{}) (tx *DB) { return } -func (db *DB) JoinsOn(query string, on clause.Expression, args ...interface{}) (tx *DB) { +func (db *DB) JoinsOn(query string, on interface{}, args ...interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args, On: on}) return diff --git a/statement.go b/statement.go index 89824bc1..b21b8854 100644 --- a/statement.go +++ b/statement.go @@ -50,7 +50,7 @@ type Statement struct { type join struct { Name string Conds []interface{} - On clause.Expression + On interface{} } // StatementModifier statement modifier interface diff --git a/tests/joins_test.go b/tests/joins_test.go index 46611f5f..0b46d69c 100644 --- a/tests/joins_test.go +++ b/tests/joins_test.go @@ -104,6 +104,26 @@ func TestJoinConds(t *testing.T) { } } +func TestJoinOn(t *testing.T) { + var user = *GetUser("joins-on", Config{Pets: 2}) + DB.Save(&user) + + var user1 User + onQuery := DB.Select("id").Where("user_id = users.id AND name = ?", "joins-on_pet_1").Model(&Pet{}) + + if err := DB.JoinsOn("NamedPet", onQuery).Where("users.name = ?", user.Name).First(&user1).Error; err != nil { + t.Fatalf("Failed to load with joins on, got error: %v", err) + } + AssertEqual(t, user1.NamedPet.Name, "joins-on_pet_1") + + onQuery2 := DB.Select("id").Where("user_id = users.id AND name = ?", "joins-on_pet_2").Model(&Pet{}) + var user2 User + if err := DB.JoinsOn("NamedPet", onQuery2).Where("users.name = ?", user.Name).First(&user2).Error; err != nil { + t.Fatalf("Failed to load with joins on, got error: %v", err) + } + AssertEqual(t, user2.NamedPet.Name, "joins-on_pet_2") +} + func TestJoinsWithSelect(t *testing.T) { type result struct { ID uint diff --git a/utils/tests/models.go b/utils/tests/models.go index 2c5e71c0..8e833c93 100644 --- a/utils/tests/models.go +++ b/utils/tests/models.go @@ -11,6 +11,7 @@ import ( // He works in a Company (belongs to), he has a Manager (belongs to - single-table), and also managed a Team (has many - single-table) // He speaks many languages (many to many) and has many friends (many to many - single-table) // His pet also has one Toy (has one - polymorphic) +// NamedPet is a reference to a Named `Pets` (has many) type User struct { gorm.Model Name string @@ -18,6 +19,7 @@ type User struct { Birthday *time.Time Account Account Pets []*Pet + NamedPet *Pet Toys []Toy `gorm:"polymorphic:Owner"` CompanyID *int Company Company From c301aeb524234036192ceaca1a7bee18ce1de4fa Mon Sep 17 00:00:00 2001 From: Adrien Carreira Date: Sun, 18 Jul 2021 12:04:18 +0200 Subject: [PATCH 07/12] Refactor for readability --- callbacks/query.go | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/callbacks/query.go b/callbacks/query.go index 570a85d0..a4093c63 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -125,7 +125,19 @@ func BuildQuerySQL(db *gorm.DB) { }) } - if join.On == nil { + if join.On != nil { + primaryFields := make([]clause.Column, len(relation.FieldSchema.PrimaryFieldDBNames)) + for idx, ref := range relation.FieldSchema.PrimaryFieldDBNames { + primaryFields[idx] = clause.Column{Table: tableAliasName, Name: ref} + } + + exprs := db.Statement.BuildCondition("(?) = (?)", primaryFields, join.On) + joins = append(joins, clause.Join{ + Type: clause.LeftJoin, + Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName}, + ON: clause.Where{Exprs: exprs}, + }) + } else { exprs := make([]clause.Expression, len(relation.References)) for idx, ref := range relation.References { if ref.OwnPrimaryKey { @@ -147,18 +159,7 @@ func BuildQuerySQL(db *gorm.DB) { } } } - joins = append(joins, clause.Join{ - Type: clause.LeftJoin, - Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName}, - ON: clause.Where{Exprs: exprs}, - }) - } else { - primaryFields := make([]clause.Column, len(relation.FieldSchema.PrimaryFieldDBNames)) - for idx, ref := range relation.FieldSchema.PrimaryFieldDBNames { - primaryFields[idx] = clause.Column{Table: tableAliasName, Name: ref} - } - exprs := db.Statement.BuildCondition("(?) = (?)", primaryFields, join.On) joins = append(joins, clause.Join{ Type: clause.LeftJoin, Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName}, From d047f854e66b669785cbe6be8227269807db1782 Mon Sep 17 00:00:00 2001 From: Adrien Carreira Date: Sat, 28 Aug 2021 10:27:19 +0200 Subject: [PATCH 08/12] PR Comments --- chainable_api.go | 15 +++++++++------ tests/joins_test.go | 4 ++-- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/chainable_api.go b/chainable_api.go index 184931ff..8fd7ee3c 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -171,15 +171,18 @@ func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) { // Joins specify Joins conditions // db.Joins("Account").Find(&user) // db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user) +// db.Joins("Account", DB.Select("id").Where("user_id = users.id AND name = ?", "someName").Model(&Account{})) func (db *DB) Joins(query string, args ...interface{}) (tx *DB) { tx = db.getInstance() - tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args}) - return -} -func (db *DB) JoinsOn(query string, on interface{}, args ...interface{}) (tx *DB) { - tx = db.getInstance() - tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args, On: on}) + if len(args) > 0 { + if db, ok := args[0].(*DB); ok { + tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args[1:], On: db}) + return + } + } + + tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args}) return } diff --git a/tests/joins_test.go b/tests/joins_test.go index 0b46d69c..21c73c19 100644 --- a/tests/joins_test.go +++ b/tests/joins_test.go @@ -111,14 +111,14 @@ func TestJoinOn(t *testing.T) { var user1 User onQuery := DB.Select("id").Where("user_id = users.id AND name = ?", "joins-on_pet_1").Model(&Pet{}) - if err := DB.JoinsOn("NamedPet", onQuery).Where("users.name = ?", user.Name).First(&user1).Error; err != nil { + if err := DB.Joins("NamedPet", onQuery).Where("users.name = ?", user.Name).First(&user1).Error; err != nil { t.Fatalf("Failed to load with joins on, got error: %v", err) } AssertEqual(t, user1.NamedPet.Name, "joins-on_pet_1") onQuery2 := DB.Select("id").Where("user_id = users.id AND name = ?", "joins-on_pet_2").Model(&Pet{}) var user2 User - if err := DB.JoinsOn("NamedPet", onQuery2).Where("users.name = ?", user.Name).First(&user2).Error; err != nil { + if err := DB.Joins("NamedPet", onQuery2).Where("users.name = ?", user.Name).First(&user2).Error; err != nil { t.Fatalf("Failed to load with joins on, got error: %v", err) } AssertEqual(t, user2.NamedPet.Name, "joins-on_pet_2") From 3b6a7c8aecd66eb78e0f22710cc203b7abe0c894 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 7 Sep 2021 12:01:19 +0800 Subject: [PATCH 09/12] Update sqlserver driver --- tests/go.mod | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/go.mod b/tests/go.mod index a1033a60..d7ab65ad 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -9,7 +9,7 @@ require ( gorm.io/driver/mysql v1.1.2 gorm.io/driver/postgres v1.1.0 gorm.io/driver/sqlite v1.1.4 - gorm.io/driver/sqlserver v1.0.8 + gorm.io/driver/sqlserver v1.0.9 gorm.io/gorm v1.21.14 ) From 6c94b07e98eca77e3ba1ca2e2341a5f5b75a0727 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 7 Sep 2021 15:30:14 +0800 Subject: [PATCH 10/12] try to fix fatal error: concurrent map read and map write --- schema/schema.go | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/schema/schema.go b/schema/schema.go index 0e0501d4..faba2e21 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -119,20 +119,13 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) // When the schema initialization is completed, the channel will be closed defer close(schema.initialized) - if v, loaded := cacheStore.LoadOrStore(modelType, schema); loaded { + if v, loaded := cacheStore.Load(modelType); loaded { s := v.(*Schema) // Wait for the initialization of other goroutines to complete <-s.initialized return s, s.err } - defer func() { - if schema.err != nil { - logger.Default.Error(context.Background(), schema.err.Error()) - cacheStore.Delete(modelType) - } - }() - for i := 0; i < modelType.NumField(); i++ { if fieldStruct := modelType.Field(i); ast.IsExported(fieldStruct.Name) { if field := schema.ParseField(fieldStruct); field.EmbeddedSchema != nil { @@ -233,6 +226,20 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) } } + if v, loaded := cacheStore.LoadOrStore(modelType, schema); loaded { + s := v.(*Schema) + // Wait for the initialization of other goroutines to complete + <-s.initialized + return s, s.err + } + + defer func() { + if schema.err != nil { + logger.Default.Error(context.Background(), schema.err.Error()) + cacheStore.Delete(modelType) + } + }() + if _, embedded := schema.cacheStore.Load(embeddedCacheKey); !embedded { for _, field := range schema.Fields { if field.DataType == "" && (field.Creatable || field.Updatable || field.Readable) { From ba16b2368f253572195de14fef62272a752595ef Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 7 Sep 2021 20:04:54 +0800 Subject: [PATCH 11/12] Refactor update record (#4679) --- callbacks/update.go | 81 +++++++++++++++++--------------------------- tests/update_test.go | 12 ++++--- 2 files changed, 40 insertions(+), 53 deletions(-) diff --git a/callbacks/update.go b/callbacks/update.go index ee60bcd7..7d5ea4a4 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -23,38 +23,11 @@ func SetupUpdateReflectValue(db *gorm.DB) { rel.Field.Set(db.Statement.ReflectValue, dest[rel.Name]) } } - } else if modelType, destType := findType(db.Statement.Model), findType(db.Statement.Dest); modelType.Kind() == reflect.Struct && destType.Kind() == reflect.Struct { - db.Statement.Dest = transToModel(reflect.Indirect(reflect.ValueOf(db.Statement.Dest)), reflect.New(modelType).Elem()) } } } } -func findType(target interface{}) reflect.Type { - t := reflect.TypeOf(target) - if t.Kind() == reflect.Ptr { - return t.Elem() - } - return t -} - -func transToModel(from, to reflect.Value) interface{} { - if from.String() == to.String() { - return from.Interface() - } - - fromType := from.Type() - for i := 0; i < fromType.NumField(); i++ { - fieldName := fromType.Field(i).Name - fromField, toField := from.FieldByName(fieldName), to.FieldByName(fieldName) - if !toField.IsValid() || !toField.CanSet() || toField.Kind() != fromField.Kind() { - continue - } - toField.Set(fromField) - } - return to.Interface() -} - func BeforeUpdate(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) { callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { @@ -249,35 +222,45 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { } } default: + var updatingSchema = stmt.Schema + if !updatingValue.CanAddr() || stmt.Dest != stmt.Model { + // different schema + updatingStmt := &gorm.Statement{DB: stmt.DB} + if err := updatingStmt.Parse(stmt.Dest); err == nil { + updatingSchema = updatingStmt.Schema + } + } + switch updatingValue.Kind() { case reflect.Struct: set = make([]clause.Assignment, 0, len(stmt.Schema.FieldsByDBName)) for _, dbName := range stmt.Schema.DBNames { - field := stmt.Schema.LookUpField(dbName) - if !field.PrimaryKey || !updatingValue.CanAddr() || stmt.Dest != stmt.Model { - if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && (!restricted || (!stmt.SkipHooks && field.AutoUpdateTime > 0))) { - value, isZero := field.ValueOf(updatingValue) - if !stmt.SkipHooks && field.AutoUpdateTime > 0 { - if field.AutoUpdateTime == schema.UnixNanosecond { - value = stmt.DB.NowFunc().UnixNano() - } else if field.AutoUpdateTime == schema.UnixMillisecond { - value = stmt.DB.NowFunc().UnixNano() / 1e6 - } else if field.GORMDataType == schema.Time { - value = stmt.DB.NowFunc() - } else { - value = stmt.DB.NowFunc().Unix() + if field := updatingSchema.LookUpField(dbName); field != nil && field.Updatable { + if !field.PrimaryKey || !updatingValue.CanAddr() || stmt.Dest != stmt.Model { + if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && (!restricted || (!stmt.SkipHooks && field.AutoUpdateTime > 0))) { + value, isZero := field.ValueOf(updatingValue) + if !stmt.SkipHooks && field.AutoUpdateTime > 0 { + if field.AutoUpdateTime == schema.UnixNanosecond { + value = stmt.DB.NowFunc().UnixNano() + } else if field.AutoUpdateTime == schema.UnixMillisecond { + value = stmt.DB.NowFunc().UnixNano() / 1e6 + } else if field.GORMDataType == schema.Time { + value = stmt.DB.NowFunc() + } else { + value = stmt.DB.NowFunc().Unix() + } + isZero = false } - isZero = false - } - if ok || !isZero { - set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value}) - assignValue(field, value) + if ok || !isZero { + set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value}) + assignValue(field, value) + } + } + } else { + if value, isZero := field.ValueOf(updatingValue); !isZero { + stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}}) } - } - } else { - if value, isZero := field.ValueOf(updatingValue); !isZero { - stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}}) } } } diff --git a/tests/update_test.go b/tests/update_test.go index 2a747ce5..9e5b630e 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -651,14 +651,16 @@ func TestSave(t *testing.T) { } user3.Name = "save3_" - DB.Model(User{Model: user3.Model}).Save(&user3) + if err := DB.Model(User{Model: user3.Model}).Save(&user3).Error; err != nil { + t.Fatalf("failed to save user, got %v", err) + } var result2 User if err := DB.First(&result2, "name = ?", "save3_").Error; err != nil || result2.ID != user3.ID { - t.Fatalf("failed to find updated user") + t.Fatalf("failed to find updated user, got %v", err) } - DB.Debug().Model(User{Model: user3.Model}).Save(&struct { + if err := DB.Model(User{Model: user3.Model}).Save(&struct { gorm.Model Placeholder string Name string @@ -666,7 +668,9 @@ func TestSave(t *testing.T) { Model: user3.Model, Placeholder: "placeholder", Name: "save3__", - }) + }).Error; err != nil { + t.Fatalf("failed to update user, got %v", err) + } var result3 User if err := DB.First(&result3, "name = ?", "save3__").Error; err != nil || result3.ID != user3.ID { From a16db07945e5f5acf348649debd2130dfcfeeb92 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 7 Sep 2021 21:21:44 +0800 Subject: [PATCH 12/12] Refactor Join ON --- callbacks/query.go | 69 +++++++++++++++++++++++---------------------- chainable_api.go | 4 ++- statement.go | 2 +- tests/joins_test.go | 5 ++-- 4 files changed, 42 insertions(+), 38 deletions(-) diff --git a/callbacks/query.go b/callbacks/query.go index a4093c63..1cfd618c 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -125,47 +125,48 @@ func BuildQuerySQL(db *gorm.DB) { }) } - if join.On != nil { - primaryFields := make([]clause.Column, len(relation.FieldSchema.PrimaryFieldDBNames)) - for idx, ref := range relation.FieldSchema.PrimaryFieldDBNames { - primaryFields[idx] = clause.Column{Table: tableAliasName, Name: ref} - } - - exprs := db.Statement.BuildCondition("(?) = (?)", primaryFields, join.On) - joins = append(joins, clause.Join{ - Type: clause.LeftJoin, - Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName}, - ON: clause.Where{Exprs: exprs}, - }) - } else { - exprs := make([]clause.Expression, len(relation.References)) - for idx, ref := range relation.References { - if ref.OwnPrimaryKey { + exprs := make([]clause.Expression, len(relation.References)) + for idx, ref := range relation.References { + if ref.OwnPrimaryKey { + exprs[idx] = clause.Eq{ + Column: clause.Column{Table: clause.CurrentTable, Name: ref.PrimaryKey.DBName}, + Value: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName}, + } + } else { + if ref.PrimaryValue == "" { exprs[idx] = clause.Eq{ - Column: clause.Column{Table: clause.CurrentTable, Name: ref.PrimaryKey.DBName}, - Value: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName}, + Column: clause.Column{Table: clause.CurrentTable, Name: ref.ForeignKey.DBName}, + Value: clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName}, } } else { - if ref.PrimaryValue == "" { - exprs[idx] = clause.Eq{ - Column: clause.Column{Table: clause.CurrentTable, Name: ref.ForeignKey.DBName}, - Value: clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName}, - } - } else { - exprs[idx] = clause.Eq{ - Column: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName}, - Value: ref.PrimaryValue, - } + exprs[idx] = clause.Eq{ + Column: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName}, + Value: ref.PrimaryValue, } } } - - joins = append(joins, clause.Join{ - Type: clause.LeftJoin, - Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName}, - ON: clause.Where{Exprs: exprs}, - }) } + + if join.On != nil { + onStmt := gorm.Statement{Table: tableAliasName, DB: db} + join.On.Build(&onStmt) + onSQL := onStmt.SQL.String() + vars := onStmt.Vars + for idx, v := range onStmt.Vars { + bindvar := strings.Builder{} + onStmt.Vars = vars[0 : idx+1] + db.Dialector.BindVarTo(&bindvar, &onStmt, v) + onSQL = strings.Replace(onSQL, bindvar.String(), "?", 1) + } + + exprs = append(exprs, clause.Expr{SQL: onSQL, Vars: vars}) + } + + joins = append(joins, clause.Join{ + Type: clause.LeftJoin, + Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName}, + ON: clause.Where{Exprs: exprs}, + }) } else { joins = append(joins, clause.Join{ Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds}, diff --git a/chainable_api.go b/chainable_api.go index 8fd7ee3c..01ab2597 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -177,7 +177,9 @@ func (db *DB) Joins(query string, args ...interface{}) (tx *DB) { if len(args) > 0 { if db, ok := args[0].(*DB); ok { - tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args[1:], On: db}) + if where, ok := db.Statement.Clauses["WHERE"].Expression.(clause.Where); ok { + tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args[1:], On: &where}) + } return } } diff --git a/statement.go b/statement.go index b21b8854..38363443 100644 --- a/statement.go +++ b/statement.go @@ -50,7 +50,7 @@ type Statement struct { type join struct { Name string Conds []interface{} - On interface{} + On *clause.Where } // StatementModifier statement modifier interface diff --git a/tests/joins_test.go b/tests/joins_test.go index 21c73c19..e560f38a 100644 --- a/tests/joins_test.go +++ b/tests/joins_test.go @@ -109,14 +109,15 @@ func TestJoinOn(t *testing.T) { DB.Save(&user) var user1 User - onQuery := DB.Select("id").Where("user_id = users.id AND name = ?", "joins-on_pet_1").Model(&Pet{}) + onQuery := DB.Where(&Pet{Name: "joins-on_pet_1"}) if err := DB.Joins("NamedPet", onQuery).Where("users.name = ?", user.Name).First(&user1).Error; err != nil { t.Fatalf("Failed to load with joins on, got error: %v", err) } + AssertEqual(t, user1.NamedPet.Name, "joins-on_pet_1") - onQuery2 := DB.Select("id").Where("user_id = users.id AND name = ?", "joins-on_pet_2").Model(&Pet{}) + onQuery2 := DB.Where(&Pet{Name: "joins-on_pet_2"}) var user2 User if err := DB.Joins("NamedPet", onQuery2).Where("users.name = ?", user.Name).First(&user2).Error; err != nil { t.Fatalf("Failed to load with joins on, got error: %v", err)