From bd7e42ec651f66539009371675bff38645b9b6b8 Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Wed, 27 Apr 2022 21:13:48 +0800 Subject: [PATCH 001/231] fix: AutoMigrate with special table name (#5301) * fix: AutoMigrate with special table name * test: migrate with special table name --- migrator/migrator.go | 3 ++- tests/migrate_test.go | 11 +++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/migrator/migrator.go b/migrator/migrator.go index 93f4c5d0..d4989410 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -759,7 +759,8 @@ func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []i Statement: &gorm.Statement{DB: m.DB, Dest: value}, } beDependedOn := map[*schema.Schema]bool{} - if err := dep.Parse(value); err != nil { + // support for special table name + if err := dep.ParseWithSpecialTableName(value, m.DB.Statement.Table); err != nil { m.DB.Logger.Error(context.Background(), "failed to parse value %#v, got error %v", value, err) } if _, ok := parsedSchemas[dep.Statement.Schema]; ok { diff --git a/tests/migrate_test.go b/tests/migrate_test.go index d6a6c4db..6576a2bd 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -636,3 +636,14 @@ func TestMigrateSerialColumn(t *testing.T) { AssertEqual(t, v.ID, v.UID) } } + +// https://github.com/go-gorm/gorm/issues/5300 +func TestMigrateWithSpecialName(t *testing.T) { + DB.AutoMigrate(&Coupon{}) + DB.Table("coupon_product_1").AutoMigrate(&CouponProduct{}) + DB.Table("coupon_product_2").AutoMigrate(&CouponProduct{}) + + AssertEqual(t, true, DB.Migrator().HasTable("coupons")) + AssertEqual(t, true, DB.Migrator().HasTable("coupon_product_1")) + AssertEqual(t, true, DB.Migrator().HasTable("coupon_product_2")) +} From d3488ae6bcee8ccbb1e463a42a048e1958c4c90f Mon Sep 17 00:00:00 2001 From: Heliner <32272517+Heliner@users.noreply.github.com> Date: Sat, 30 Apr 2022 09:50:53 +0800 Subject: [PATCH 002/231] fix: add judge result of auto_migrate (#5306) Co-authored-by: fredhan --- tests/migrate_test.go | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 6576a2bd..28ee28cb 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -639,9 +639,19 @@ func TestMigrateSerialColumn(t *testing.T) { // https://github.com/go-gorm/gorm/issues/5300 func TestMigrateWithSpecialName(t *testing.T) { - DB.AutoMigrate(&Coupon{}) - DB.Table("coupon_product_1").AutoMigrate(&CouponProduct{}) - DB.Table("coupon_product_2").AutoMigrate(&CouponProduct{}) + var err error + err = DB.AutoMigrate(&Coupon{}) + if err != nil { + t.Fatalf("AutoMigrate err:%v", err) + } + err = DB.Table("coupon_product_1").AutoMigrate(&CouponProduct{}) + if err != nil { + t.Fatalf("AutoMigrate err:%v", err) + } + err = DB.Table("coupon_product_2").AutoMigrate(&CouponProduct{}) + if err != nil { + t.Fatalf("AutoMigrate err:%v", err) + } AssertEqual(t, true, DB.Migrator().HasTable("coupons")) AssertEqual(t, true, DB.Migrator().HasTable("coupon_product_1")) From b0104943edf50bba6072d18ca91e949ff8d4e3a2 Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Sat, 30 Apr 2022 09:57:16 +0800 Subject: [PATCH 003/231] fix: callbcak sort when using multiple plugin (#5304) --- callbacks.go | 8 +++++++- tests/callbacks_test.go | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 42 insertions(+), 1 deletion(-) diff --git a/callbacks.go b/callbacks.go index f344649e..c060ea70 100644 --- a/callbacks.go +++ b/callbacks.go @@ -246,7 +246,13 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) { sortCallback func(*callback) error ) sort.Slice(cs, func(i, j int) bool { - return cs[j].before == "*" || cs[j].after == "*" + if cs[j].before == "*" && cs[i].before != "*" { + return true + } + if cs[j].after == "*" && cs[i].after != "*" { + return true + } + return false }) for _, c := range cs { diff --git a/tests/callbacks_test.go b/tests/callbacks_test.go index 02765b8c..2bf9496b 100644 --- a/tests/callbacks_test.go +++ b/tests/callbacks_test.go @@ -38,6 +38,7 @@ func c2(*gorm.DB) {} func c3(*gorm.DB) {} func c4(*gorm.DB) {} func c5(*gorm.DB) {} +func c6(*gorm.DB) {} func TestCallbacks(t *testing.T) { type callback struct { @@ -168,3 +169,37 @@ func TestCallbacks(t *testing.T) { } } } + +func TestPluginCallbacks(t *testing.T) { + db, _ := gorm.Open(nil, nil) + createCallback := db.Callback().Create() + + createCallback.Before("*").Register("plugin_1_fn1", c1) + createCallback.After("*").Register("plugin_1_fn2", c2) + + if ok, msg := assertCallbacks(createCallback, []string{"c1", "c2"}); !ok { + t.Errorf("callbacks tests failed, got %v", msg) + } + + // plugin 2 + createCallback.Before("*").Register("plugin_2_fn1", c3) + if ok, msg := assertCallbacks(createCallback, []string{"c3", "c1", "c2"}); !ok { + t.Errorf("callbacks tests failed, got %v", msg) + } + + createCallback.After("*").Register("plugin_2_fn2", c4) + if ok, msg := assertCallbacks(createCallback, []string{"c3", "c1", "c2", "c4"}); !ok { + t.Errorf("callbacks tests failed, got %v", msg) + } + + // plugin 3 + createCallback.Before("*").Register("plugin_3_fn1", c5) + if ok, msg := assertCallbacks(createCallback, []string{"c5", "c3", "c1", "c2", "c4"}); !ok { + t.Errorf("callbacks tests failed, got %v", msg) + } + + createCallback.After("*").Register("plugin_3_fn2", c6) + if ok, msg := assertCallbacks(createCallback, []string{"c5", "c3", "c1", "c2", "c4", "c6"}); !ok { + t.Errorf("callbacks tests failed, got %v", msg) + } +} From 19b8d37ae8155667d76021e4ca3314bb571756be Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Wed, 4 May 2022 18:57:53 +0800 Subject: [PATCH 004/231] fix: preload with skip hooks (#5310) --- callbacks/query.go | 2 +- tests/hooks_test.go | 19 +++++++++++++++++-- 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/callbacks/query.go b/callbacks/query.go index fb2bb37a..26ee8c34 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -252,7 +252,7 @@ func Preload(db *gorm.DB) { for _, name := range preloadNames { if rel := preloadDB.Statement.Schema.Relationships.Relations[name]; rel != nil { - db.AddError(preload(preloadDB.Table("").Session(&gorm.Session{}), rel, append(db.Statement.Preloads[name], db.Statement.Preloads[clause.Associations]...), preloadMap[name])) + db.AddError(preload(preloadDB.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks}), rel, append(db.Statement.Preloads[name], db.Statement.Preloads[clause.Associations]...), preloadMap[name])) } else { db.AddError(fmt.Errorf("%s: %w for schema %s", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name)) } diff --git a/tests/hooks_test.go b/tests/hooks_test.go index 20e8dc18..8e964fd8 100644 --- a/tests/hooks_test.go +++ b/tests/hooks_test.go @@ -466,8 +466,9 @@ type Product4 struct { type ProductItem struct { gorm.Model - Code string - Product4ID uint + Code string + Product4ID uint + AfterFindCallTimes int } func (pi ProductItem) BeforeCreate(*gorm.DB) error { @@ -477,6 +478,11 @@ func (pi ProductItem) BeforeCreate(*gorm.DB) error { return nil } +func (pi *ProductItem) AfterFind(*gorm.DB) error { + pi.AfterFindCallTimes = pi.AfterFindCallTimes + 1 + return nil +} + func TestFailedToSaveAssociationShouldRollback(t *testing.T) { DB.Migrator().DropTable(&Product4{}, &ProductItem{}) DB.AutoMigrate(&Product4{}, &ProductItem{}) @@ -498,4 +504,13 @@ func TestFailedToSaveAssociationShouldRollback(t *testing.T) { if err := DB.First(&Product4{}, "name = ?", product.Name).Error; err != nil { t.Errorf("should find product, but got error %v", err) } + + var productWithItem Product4 + if err := DB.Session(&gorm.Session{SkipHooks: true}).Preload("Item").First(&productWithItem, "name = ?", product.Name).Error; err != nil { + t.Errorf("should find product, but got error %v", err) + } + + if productWithItem.Item.AfterFindCallTimes != 0 { + t.Fatalf("AfterFind should not be called times:%d", productWithItem.Item.AfterFindCallTimes) + } } From 373bcf7aca01ef76c8ba5c3bc1ff191b020afc7b Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Mon, 9 May 2022 10:07:18 +0800 Subject: [PATCH 005/231] fix: many2many auto migrate (#5322) * fix: many2many auto migrate * fix: uuid ossp --- schema/relationship.go | 6 ++++-- schema/utils.go | 9 +++++++++ tests/migrate_test.go | 36 ++++++++++++++++++++++++++++++++++++ 3 files changed, 49 insertions(+), 2 deletions(-) diff --git a/schema/relationship.go b/schema/relationship.go index b5100897..0aa33e51 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -235,7 +235,8 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel Name: joinFieldName, PkgPath: ownField.StructField.PkgPath, Type: ownField.StructField.Type, - Tag: removeSettingFromTag(ownField.StructField.Tag, "column", "autoincrement", "index", "unique", "uniqueindex"), + Tag: removeSettingFromTag(appendSettingFromTag(ownField.StructField.Tag, "primaryKey"), + "column", "autoincrement", "index", "unique", "uniqueindex"), }) } @@ -258,7 +259,8 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel Name: joinFieldName, PkgPath: relField.StructField.PkgPath, Type: relField.StructField.Type, - Tag: removeSettingFromTag(relField.StructField.Tag, "column", "autoincrement", "index", "unique", "uniqueindex"), + Tag: removeSettingFromTag(appendSettingFromTag(relField.StructField.Tag, "primaryKey"), + "column", "autoincrement", "index", "unique", "uniqueindex"), }) } diff --git a/schema/utils.go b/schema/utils.go index 2720c530..acf1a739 100644 --- a/schema/utils.go +++ b/schema/utils.go @@ -2,6 +2,7 @@ package schema import ( "context" + "fmt" "reflect" "regexp" "strings" @@ -59,6 +60,14 @@ func removeSettingFromTag(tag reflect.StructTag, names ...string) reflect.Struct return tag } +func appendSettingFromTag(tag reflect.StructTag, value string) reflect.StructTag { + t := tag.Get("gorm") + if strings.Contains(t, value) { + return tag + } + return reflect.StructTag(fmt.Sprintf(`gorm:"%s;%s"`, value, t)) +} + // GetRelationsValues get relations's values from a reflect value func GetRelationsValues(ctx context.Context, reflectValue reflect.Value, rels []*Relationship) (reflectResults reflect.Value) { for _, rel := range rels { diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 28ee28cb..f862eda0 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -657,3 +657,39 @@ func TestMigrateWithSpecialName(t *testing.T) { AssertEqual(t, true, DB.Migrator().HasTable("coupon_product_1")) AssertEqual(t, true, DB.Migrator().HasTable("coupon_product_2")) } + +// https://github.com/go-gorm/gorm/issues/5320 +func TestPrimarykeyID(t *testing.T) { + if DB.Dialector.Name() != "postgres" { + return + } + + type MissPKLanguage struct { + ID string `gorm:"type:uuid;default:uuid_generate_v4()"` + Name string + } + + type MissPKUser struct { + ID string `gorm:"type:uuid;default:uuid_generate_v4()"` + MissPKLanguages []MissPKLanguage `gorm:"many2many:miss_pk_user_languages;"` + } + + var err error + err = DB.Migrator().DropTable(&MissPKUser{}, &MissPKLanguage{}) + if err != nil { + t.Fatalf("DropTable err:%v", err) + } + + DB.Exec(`CREATE EXTENSION IF NOT EXISTS "uuid-ossp";`) + + err = DB.AutoMigrate(&MissPKUser{}, &MissPKLanguage{}) + if err != nil { + t.Fatalf("AutoMigrate err:%v", err) + } + + // patch + err = DB.AutoMigrate(&MissPKUser{}, &MissPKLanguage{}) + if err != nil { + t.Fatalf("AutoMigrate err:%v", err) + } +} From f5e77aab2fd3886f8743d6c9da87d5171f31a521 Mon Sep 17 00:00:00 2001 From: black-06 Date: Tue, 17 May 2022 10:59:53 +0800 Subject: [PATCH 006/231] fix: quote index when creating table (#5331) --- migrator/migrator.go | 2 +- tests/migrate_test.go | 19 +++++++++++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/migrator/migrator.go b/migrator/migrator.go index d4989410..757ab949 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -223,7 +223,7 @@ func (m Migrator) CreateTable(values ...interface{}) error { } createTableSQL += "," - values = append(values, clause.Expr{SQL: idx.Name}, tx.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt)) + values = append(values, clause.Column{Name: idx.Name}, tx.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt)) } } diff --git a/tests/migrate_test.go b/tests/migrate_test.go index f862eda0..12eb8ed0 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -262,6 +262,25 @@ func TestMigrateTable(t *testing.T) { } } +func TestMigrateWithQuotedIndex(t *testing.T) { + if DB.Dialector.Name() != "mysql" { + t.Skip() + } + + type QuotedIndexStruct struct { + gorm.Model + Name string `gorm:"size:255;index:AS"` // AS is one of MySQL reserved words + } + + if err := DB.Migrator().DropTable(&QuotedIndexStruct{}); err != nil { + t.Fatalf("Failed to drop table, got error %v", err) + } + + if err := DB.AutoMigrate(&QuotedIndexStruct{}); err != nil { + t.Fatalf("Failed to auto migrate, but got error %v", err) + } +} + func TestMigrateIndexes(t *testing.T) { type IndexStruct struct { gorm.Model From 7496c3a56eb4a26679a0a47db092e51379a98ff5 Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Tue, 17 May 2022 14:13:41 +0800 Subject: [PATCH 007/231] fix: trx in hooks clone stmt (#5338) * fix: trx in hooks * chore: format by gofumpt --- finisher_api.go | 3 +-- tests/transaction_test.go | 30 ++++++++++++++++++++++++++++++ 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/finisher_api.go b/finisher_api.go index 663d532b..da4ef8f7 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -589,8 +589,7 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er } }() } - - err = fc(db.Session(&Session{})) + err = fc(db.Session(&Session{NewDB: db.clone == 1})) } else { tx := db.Begin(opts...) if tx.Error != nil { diff --git a/tests/transaction_test.go b/tests/transaction_test.go index 4e4b6149..0ac04a04 100644 --- a/tests/transaction_test.go +++ b/tests/transaction_test.go @@ -367,3 +367,33 @@ func TestTransactionOnClosedConn(t *testing.T) { t.Errorf("should returns error when commit with closed conn, got error %v", err) } } + +func TestTransactionWithHooks(t *testing.T) { + user := GetUser("tTestTransactionWithHooks", Config{Account: true}) + DB.Create(&user) + + var err error + err = DB.Transaction(func(tx *gorm.DB) error { + return tx.Model(&User{}).Limit(1).Transaction(func(tx2 *gorm.DB) error { + return tx2.Scan(&User{}).Error + }) + }) + + if err != nil { + t.Error(err) + } + + // method with hooks + err = DB.Transaction(func(tx1 *gorm.DB) error { + // callMethod do + tx2 := tx1.Find(&User{}).Session(&gorm.Session{NewDB: true}) + // trx in hooks + return tx2.Transaction(func(tx3 *gorm.DB) error { + return tx3.Where("user_id", user.ID).Delete(&Account{}).Error + }) + }) + + if err != nil { + t.Error(err) + } +} From 540fb49bcbe07ee56c7a8a449a5504f40f50abc1 Mon Sep 17 00:00:00 2001 From: Clark McCauley Date: Sun, 22 May 2022 01:16:01 -0600 Subject: [PATCH 008/231] Fixed #5355 - Named variables don't work when followed by Windows CRLF line endings (#5356) * Fixed #5355. * Fixed unit test to test both CRLF and CR line endings --- clause/expression.go | 2 +- clause/expression_test.go | 10 ++++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/clause/expression.go b/clause/expression.go index dde00b1d..92ac7f22 100644 --- a/clause/expression.go +++ b/clause/expression.go @@ -127,7 +127,7 @@ func (expr NamedExpr) Build(builder Builder) { if v == '@' && !inName { inName = true name = []byte{} - } else if v == ' ' || v == ',' || v == ')' || v == '"' || v == '\'' || v == '`' || v == '\n' || v == ';' { + } else if v == ' ' || v == ',' || v == ')' || v == '"' || v == '\'' || v == '`' || v == '\r' || v == '\n' || v == ';' { if inName { if nv, ok := namedMap[string(name)]; ok { builder.AddVar(builder, nv) diff --git a/clause/expression_test.go b/clause/expression_test.go index 4826db38..aaede61c 100644 --- a/clause/expression_test.go +++ b/clause/expression_test.go @@ -94,6 +94,16 @@ func TestNamedExpr(t *testing.T) { Vars: []interface{}{sql.Named("name", "jinzhu")}, Result: "name1 = ? AND name2 = ?;", ExpectedVars: []interface{}{"jinzhu", "jinzhu"}, + }, { + SQL: "name1 = @name1\r\n AND name2 = @name2", + Vars: []interface{}{map[string]interface{}{"name1": "jinzhu", "name2": "jinzhu"}}, + Result: "name1 = ?\r\n AND name2 = ?", + ExpectedVars: []interface{}{"jinzhu", "jinzhu"}, + }, { + SQL: "name1 = @name1\r AND name2 = @name2", + Vars: []interface{}{map[string]interface{}{"name1": "jinzhu", "name2": "jinzhu"}}, + Result: "name1 = ?\r AND name2 = ?", + ExpectedVars: []interface{}{"jinzhu", "jinzhu"}, }, { SQL: "?", Vars: []interface{}{clause.Column{Table: "table", Name: "col"}}, From 7d1a92d60e7df38fdc2f3e42ff1cc7842aefdf18 Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Sun, 22 May 2022 16:12:28 +0800 Subject: [PATCH 009/231] test: test for skip prepared when auto migrate (#5350) --- tests/migrate_test.go | 36 ++++++++++++++++++++++++++++++++++++ tests/tests_test.go | 11 ++++++++--- 2 files changed, 44 insertions(+), 3 deletions(-) diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 12eb8ed0..2b5d7ecd 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -7,6 +7,7 @@ import ( "testing" "time" + "gorm.io/driver/postgres" "gorm.io/gorm" "gorm.io/gorm/schema" . "gorm.io/gorm/utils/tests" @@ -712,3 +713,38 @@ func TestPrimarykeyID(t *testing.T) { t.Fatalf("AutoMigrate err:%v", err) } } + +func TestInvalidCachedPlan(t *testing.T) { + if DB.Dialector.Name() != "postgres" { + return + } + + db, err := gorm.Open(postgres.Open(postgresDSN), &gorm.Config{}) + if err != nil { + t.Errorf("Open err:%v", err) + } + + type Object1 struct{} + type Object2 struct { + Field1 string + } + type Object3 struct { + Field2 string + } + db.Migrator().DropTable("objects") + + err = db.Table("objects").AutoMigrate(&Object1{}) + if err != nil { + t.Errorf("AutoMigrate err:%v", err) + } + + err = db.Table("objects").AutoMigrate(&Object2{}) + if err != nil { + t.Errorf("AutoMigrate err:%v", err) + } + + err = db.Table("objects").AutoMigrate(&Object3{}) + if err != nil { + t.Errorf("AutoMigrate err:%v", err) + } +} diff --git a/tests/tests_test.go b/tests/tests_test.go index 08f4f193..dcba3cbf 100644 --- a/tests/tests_test.go +++ b/tests/tests_test.go @@ -17,6 +17,11 @@ import ( ) var DB *gorm.DB +var ( + mysqlDSN = "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True&loc=Local" + postgresDSN = "user=gorm password=gorm dbname=gorm host=localhost port=9920 sslmode=disable TimeZone=Asia/Shanghai" + sqlserverDSN = "sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm" +) func init() { var err error @@ -49,13 +54,13 @@ func OpenTestConnection() (db *gorm.DB, err error) { case "mysql": log.Println("testing mysql...") if dbDSN == "" { - dbDSN = "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True&loc=Local" + dbDSN = mysqlDSN } db, err = gorm.Open(mysql.Open(dbDSN), &gorm.Config{}) case "postgres": log.Println("testing postgres...") if dbDSN == "" { - dbDSN = "user=gorm password=gorm dbname=gorm host=localhost port=9920 sslmode=disable TimeZone=Asia/Shanghai" + dbDSN = postgresDSN } db, err = gorm.Open(postgres.New(postgres.Config{ DSN: dbDSN, @@ -72,7 +77,7 @@ func OpenTestConnection() (db *gorm.DB, err error) { // GO log.Println("testing sqlserver...") if dbDSN == "" { - dbDSN = "sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm" + dbDSN = sqlserverDSN } db, err = gorm.Open(sqlserver.Open(dbDSN), &gorm.Config{}) default: From 7e13b03bd4e57a554d3daa2774d3f58102ac30d9 Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Sat, 28 May 2022 22:18:07 +0800 Subject: [PATCH 010/231] fix: duplicate column scan (#5369) * fix: duplicate column scan * fix: dup filed in inconsistent schema and database * chore[ci skip]: gofumpt style * chore[ci skip]: fix typo --- scan.go | 17 ++++++++++++----- tests/scan_test.go | 25 +++++++++++++++++++++++++ 2 files changed, 37 insertions(+), 5 deletions(-) diff --git a/scan.go b/scan.go index ad3734d8..a611a9ce 100644 --- a/scan.go +++ b/scan.go @@ -193,14 +193,21 @@ func Scan(rows Rows, db *DB, mode ScanMode) { // Not Pluck if sch != nil { + schFieldsCount := len(sch.Fields) for idx, column := range columns { if field := sch.LookUpField(column); field != nil && field.Readable { if curIndex, ok := selectedColumnsMap[column]; ok { - for fieldIndex, selectField := range sch.Fields[curIndex+1:] { - if selectField.DBName == column && selectField.Readable { - selectedColumnsMap[column] = curIndex + fieldIndex + 1 - fields[idx] = selectField - break + fields[idx] = field // handle duplicate fields + offset := curIndex + 1 + // handle sch inconsistent with database + // like Raw(`...`).Scan + if schFieldsCount > offset { + for fieldIndex, selectField := range sch.Fields[offset:] { + if selectField.DBName == column && selectField.Readable { + selectedColumnsMap[column] = curIndex + fieldIndex + 1 + fields[idx] = selectField + break + } } } } else { diff --git a/tests/scan_test.go b/tests/scan_test.go index 425c0a29..6f2e9f54 100644 --- a/tests/scan_test.go +++ b/tests/scan_test.go @@ -214,4 +214,29 @@ func TestScanToEmbedded(t *testing.T) { if !addressMatched { t.Errorf("Failed, no address matched") } + + personDupField := Person{ID: person1.ID} + if err := DB.Select("people.id, people.*"). + First(&personDupField).Error; err != nil { + t.Errorf("Failed to run join query, got error: %v", err) + } + AssertEqual(t, person1, personDupField) + + user := User{ + Name: "TestScanToEmbedded_1", + Manager: &User{ + Name: "TestScanToEmbedded_1_m1", + Manager: &User{Name: "TestScanToEmbedded_1_m1_m1"}, + }, + } + DB.Create(&user) + + type UserScan struct { + ID uint + Name string + ManagerID *uint + } + var user2 UserScan + err := DB.Raw("SELECT * FROM users INNER JOIN users Manager ON users.manager_id = Manager.id WHERE users.id = ?", user.ID).Scan(&user2).Error + AssertEqual(t, err, nil) } From dc1ae394f329340cb4475b037fe9f98bdbf7176d Mon Sep 17 00:00:00 2001 From: "t-inagaki@hum_op" Date: Sat, 28 May 2022 23:18:43 +0900 Subject: [PATCH 011/231] fixed FirstOrCreate not handled error when table is not exists (#5367) * fixed FirstOrCreate not handled error when table is not exists * delete useless part --- finisher_api.go | 4 ++-- tests/create_test.go | 7 +++++++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/finisher_api.go b/finisher_api.go index da4ef8f7..7a3f27ba 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -351,9 +351,9 @@ func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) { } return tx.Model(dest).Updates(assigns) - } else { - tx.Error = result.Error } + } else { + tx.Error = result.Error } return tx } diff --git a/tests/create_test.go b/tests/create_test.go index 3730172f..274a7f48 100644 --- a/tests/create_test.go +++ b/tests/create_test.go @@ -476,6 +476,13 @@ func TestOmitWithCreate(t *testing.T) { CheckUser(t, result2, user2) } +func TestFirstOrCreateNotExistsTable(t *testing.T) { + company := Company{Name: "first_or_create_if_not_exists_table"} + if err := DB.Table("not_exists").FirstOrCreate(&company).Error; err == nil { + t.Errorf("not exists table, but err is nil") + } +} + func TestFirstOrCreateWithPrimaryKey(t *testing.T) { company := Company{ID: 100, Name: "company100_with_primarykey"} DB.FirstOrCreate(&company) From 93986de8e43bc9af6864621c9a4855f0f860cde2 Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Sat, 28 May 2022 23:09:13 +0800 Subject: [PATCH 012/231] fix: migrate column default value (#5359) Co-authored-by: Jinzhu --- migrator/migrator.go | 16 ++++- tests/migrate_test.go | 136 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 149 insertions(+), 3 deletions(-) diff --git a/migrator/migrator.go b/migrator/migrator.go index 757ab949..4acc9df6 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -448,10 +448,20 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy } // check default value - if v, ok := columnType.DefaultValue(); ok && v != field.DefaultValue { - // not primary key - if !field.PrimaryKey { + if !field.PrimaryKey { + dv, dvNotNull := columnType.DefaultValue() + if dvNotNull && field.DefaultValueInterface == nil { + // defalut value -> null alterColumn = true + } else if !dvNotNull && field.DefaultValueInterface != nil { + // null -> default value + alterColumn = true + } else if dv != field.DefaultValue { + // default value not equal + // not both null + if !(field.DefaultValueInterface == nil && !dvNotNull) { + alterColumn = true + } } } diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 2b5d7ecd..9e7caec9 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -1,6 +1,7 @@ package tests_test import ( + "fmt" "math/rand" "reflect" "strings" @@ -714,6 +715,141 @@ func TestPrimarykeyID(t *testing.T) { } } +func TestUniqueColumn(t *testing.T) { + if DB.Dialector.Name() != "mysql" { + return + } + + type UniqueTest struct { + ID string `gorm:"primary_key"` + Name string `gorm:"unique"` + } + + type UniqueTest2 struct { + ID string `gorm:"primary_key"` + Name string `gorm:"unique;default:NULL"` + } + + type UniqueTest3 struct { + ID string `gorm:"primary_key"` + Name string `gorm:"unique;default:''"` + } + + type UniqueTest4 struct { + ID string `gorm:"primary_key"` + Name string `gorm:"unique;default:'123'"` + } + + var err error + err = DB.Migrator().DropTable(&UniqueTest{}) + if err != nil { + t.Errorf("DropTable err:%v", err) + } + + err = DB.AutoMigrate(&UniqueTest{}) + if err != nil { + t.Fatalf("AutoMigrate err:%v", err) + } + + // null -> null + err = DB.AutoMigrate(&UniqueTest{}) + if err != nil { + t.Fatalf("AutoMigrate err:%v", err) + } + + ct, err := findColumnType(&UniqueTest{}, "name") + if err != nil { + t.Fatalf("findColumnType err:%v", err) + } + + value, ok := ct.DefaultValue() + AssertEqual(t, "", value) + AssertEqual(t, false, ok) + + // null -> null + err = DB.Table("unique_tests").AutoMigrate(&UniqueTest2{}) + if err != nil { + t.Fatalf("AutoMigrate err:%v", err) + } + + // not trigger alert column + AssertEqual(t, true, DB.Migrator().HasIndex(&UniqueTest{}, "name")) + AssertEqual(t, false, DB.Migrator().HasIndex(&UniqueTest{}, "name_1")) + AssertEqual(t, false, DB.Migrator().HasIndex(&UniqueTest{}, "name_2")) + + ct, err = findColumnType(&UniqueTest{}, "name") + if err != nil { + t.Fatalf("findColumnType err:%v", err) + } + + value, ok = ct.DefaultValue() + AssertEqual(t, "", value) + AssertEqual(t, false, ok) + + // null -> empty string + err = DB.Table("unique_tests").AutoMigrate(&UniqueTest3{}) + if err != nil { + t.Fatalf("AutoMigrate err:%v", err) + } + + ct, err = findColumnType(&UniqueTest{}, "name") + if err != nil { + t.Fatalf("findColumnType err:%v", err) + } + + value, ok = ct.DefaultValue() + AssertEqual(t, "", value) + AssertEqual(t, true, ok) + + // empty string -> 123 + err = DB.Table("unique_tests").AutoMigrate(&UniqueTest4{}) + if err != nil { + t.Fatalf("AutoMigrate err:%v", err) + } + + ct, err = findColumnType(&UniqueTest{}, "name") + if err != nil { + t.Fatalf("findColumnType err:%v", err) + } + + value, ok = ct.DefaultValue() + AssertEqual(t, "123", value) + AssertEqual(t, true, ok) + + // 123 -> null + err = DB.Table("unique_tests").AutoMigrate(&UniqueTest2{}) + if err != nil { + t.Fatalf("AutoMigrate err:%v", err) + } + + ct, err = findColumnType(&UniqueTest{}, "name") + if err != nil { + t.Fatalf("findColumnType err:%v", err) + } + + value, ok = ct.DefaultValue() + AssertEqual(t, "", value) + AssertEqual(t, false, ok) + +} + +func findColumnType(dest interface{}, columnName string) ( + foundColumn gorm.ColumnType, err error) { + columnTypes, err := DB.Migrator().ColumnTypes(dest) + if err != nil { + err = fmt.Errorf("ColumnTypes err:%v", err) + return + } + + for _, c := range columnTypes { + if c.Name() == columnName { + foundColumn = c + break + } + } + return +} + func TestInvalidCachedPlan(t *testing.T) { if DB.Dialector.Name() != "postgres" { return From f4e9904b02dab5c2f675d9c661ae1c1a8654a768 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 1 Jun 2022 10:26:09 +0800 Subject: [PATCH 013/231] chore(deps): bump gorm.io/driver/mysql from 1.3.3 to 1.3.4 in /tests (#5385) Bumps [gorm.io/driver/mysql](https://github.com/go-gorm/mysql) from 1.3.3 to 1.3.4. - [Release notes](https://github.com/go-gorm/mysql/releases) - [Commits](https://github.com/go-gorm/mysql/compare/v1.3.3...v1.3.4) --- updated-dependencies: - dependency-name: gorm.io/driver/mysql dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- tests/go.mod | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/go.mod b/tests/go.mod index 6a2cf22f..bd668420 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -8,7 +8,7 @@ require ( github.com/jinzhu/now v1.1.5 github.com/lib/pq v1.10.5 golang.org/x/crypto v0.0.0-20220411220226-7b82a4e95df4 // indirect - gorm.io/driver/mysql v1.3.3 + gorm.io/driver/mysql v1.3.4 gorm.io/driver/postgres v1.3.5 gorm.io/driver/sqlite v1.3.2 gorm.io/driver/sqlserver v1.3.2 From d01de7232b46987e239ef19a89d9ab192f453894 Mon Sep 17 00:00:00 2001 From: Bexanderthebex Date: Wed, 1 Jun 2022 11:50:57 +0800 Subject: [PATCH 014/231] enhancement: Avoid calling reflect.New() when passing in slice of values to `Scan()` (#5388) * fix: reduce allocations when slice of values * chore[test]: Add benchmark for scan * chore[test]: add bench for scan slice * chore[test]: add bench for slice pointer and improve tests * chore[test]: make sure database is empty when doing slice tests * fix[test]: correct sql delete statement * enhancement: skip new if rows affected = 0 --- scan.go | 7 ++++++- tests/benchmark_test.go | 40 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 1 deletion(-) diff --git a/scan.go b/scan.go index a611a9ce..1bb51560 100644 --- a/scan.go +++ b/scan.go @@ -237,6 +237,7 @@ func Scan(rows Rows, db *DB, mode ScanMode) { switch reflectValue.Kind() { case reflect.Slice, reflect.Array: var elem reflect.Value + recyclableStruct := reflect.New(reflectValueType) if !update || reflectValue.Len() == 0 { update = false @@ -261,7 +262,11 @@ func Scan(rows Rows, db *DB, mode ScanMode) { } } } else { - elem = reflect.New(reflectValueType) + if isPtr && db.RowsAffected > 0 { + elem = reflect.New(reflectValueType) + } else { + elem = recyclableStruct + } } db.scanIntoStruct(rows, elem, values, fields, joinFields) diff --git a/tests/benchmark_test.go b/tests/benchmark_test.go index d897a634..22d15898 100644 --- a/tests/benchmark_test.go +++ b/tests/benchmark_test.go @@ -1,6 +1,7 @@ package tests_test import ( + "fmt" "testing" . "gorm.io/gorm/utils/tests" @@ -24,6 +25,45 @@ func BenchmarkFind(b *testing.B) { } } +func BenchmarkScan(b *testing.B) { + user := *GetUser("scan", Config{}) + DB.Create(&user) + + var u User + b.ResetTimer() + for x := 0; x < b.N; x++ { + DB.Raw("select * from users where id = ?", user.ID).Scan(&u) + } +} + +func BenchmarkScanSlice(b *testing.B) { + DB.Exec("delete from users") + for i := 0; i < 10_000; i++ { + user := *GetUser(fmt.Sprintf("scan-%d", i), Config{}) + DB.Create(&user) + } + + var u []User + b.ResetTimer() + for x := 0; x < b.N; x++ { + DB.Raw("select * from users").Scan(&u) + } +} + +func BenchmarkScanSlicePointer(b *testing.B) { + DB.Exec("delete from users") + for i := 0; i < 10_000; i++ { + user := *GetUser(fmt.Sprintf("scan-%d", i), Config{}) + DB.Create(&user) + } + + var u []*User + b.ResetTimer() + for x := 0; x < b.N; x++ { + DB.Raw("select * from users").Scan(&u) + } +} + func BenchmarkUpdate(b *testing.B) { user := *GetUser("find", Config{}) DB.Create(&user) From 8d457146283e0a4197c26a559bedb1938767b78e Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Tue, 14 Jun 2022 13:48:50 +0800 Subject: [PATCH 015/231] fix: reset null value in slice (#5417) * fix: reset null value in slice * fix: can not set field in-place in join --- scan.go | 17 ++++++---- schema/field.go | 10 ++++++ tests/query_test.go | 77 +++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 98 insertions(+), 6 deletions(-) diff --git a/scan.go b/scan.go index 1bb51560..6250fb57 100644 --- a/scan.go +++ b/scan.go @@ -66,18 +66,23 @@ func (db *DB) scanIntoStruct(rows Rows, reflectValue reflect.Value, values []int db.RowsAffected++ db.AddError(rows.Scan(values...)) + joinedSchemaMap := make(map[*schema.Field]interface{}, 0) for idx, field := range fields { if field != nil { if len(joinFields) == 0 || joinFields[idx][0] == nil { db.AddError(field.Set(db.Statement.Context, reflectValue, values[idx])) } else { - relValue := joinFields[idx][0].ReflectValueOf(db.Statement.Context, reflectValue) - if relValue.Kind() == reflect.Ptr && relValue.IsNil() { - if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() { - continue - } + joinSchema := joinFields[idx][0] + relValue := joinSchema.ReflectValueOf(db.Statement.Context, reflectValue) + if relValue.Kind() == reflect.Ptr { + if _, ok := joinedSchemaMap[joinSchema]; !ok { + if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() { + continue + } - relValue.Set(reflect.New(relValue.Type().Elem())) + relValue.Set(reflect.New(relValue.Type().Elem())) + joinedSchemaMap[joinSchema] = nil + } } db.AddError(joinFields[idx][1].Set(db.Statement.Context, relValue, values[idx])) } diff --git a/schema/field.go b/schema/field.go index d6df6596..981f56f2 100644 --- a/schema/field.go +++ b/schema/field.go @@ -587,6 +587,8 @@ func (field *Field) setupValuerAndSetter() { case **bool: if data != nil && *data != nil { field.ReflectValueOf(ctx, value).SetBool(**data) + } else { + field.ReflectValueOf(ctx, value).SetBool(false) } case bool: field.ReflectValueOf(ctx, value).SetBool(data) @@ -606,6 +608,8 @@ func (field *Field) setupValuerAndSetter() { case **int64: if data != nil && *data != nil { field.ReflectValueOf(ctx, value).SetInt(**data) + } else { + field.ReflectValueOf(ctx, value).SetInt(0) } case int64: field.ReflectValueOf(ctx, value).SetInt(data) @@ -670,6 +674,8 @@ func (field *Field) setupValuerAndSetter() { case **uint64: if data != nil && *data != nil { field.ReflectValueOf(ctx, value).SetUint(**data) + } else { + field.ReflectValueOf(ctx, value).SetUint(0) } case uint64: field.ReflectValueOf(ctx, value).SetUint(data) @@ -722,6 +728,8 @@ func (field *Field) setupValuerAndSetter() { case **float64: if data != nil && *data != nil { field.ReflectValueOf(ctx, value).SetFloat(**data) + } else { + field.ReflectValueOf(ctx, value).SetFloat(0) } case float64: field.ReflectValueOf(ctx, value).SetFloat(data) @@ -766,6 +774,8 @@ func (field *Field) setupValuerAndSetter() { case **string: if data != nil && *data != nil { field.ReflectValueOf(ctx, value).SetString(**data) + } else { + field.ReflectValueOf(ctx, value).SetString("") } case string: field.ReflectValueOf(ctx, value).SetString(data) diff --git a/tests/query_test.go b/tests/query_test.go index f66cf83a..253d8409 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -1258,3 +1258,80 @@ func TestQueryScannerWithSingleColumn(t *testing.T) { AssertEqual(t, result2.data, 20) } + +func TestQueryResetNullValue(t *testing.T) { + type QueryResetItem struct { + ID string `gorm:"type:varchar(5)"` + Name string + } + + type QueryResetNullValue struct { + ID int + Name string `gorm:"default:NULL"` + Flag bool `gorm:"default:NULL"` + Number1 int64 `gorm:"default:NULL"` + Number2 uint64 `gorm:"default:NULL"` + Number3 float64 `gorm:"default:NULL"` + Now *time.Time `gorm:"defalut:NULL"` + Item1Id string + Item1 *QueryResetItem `gorm:"references:ID"` + Item2Id string + Item2 *QueryResetItem `gorm:"references:ID"` + } + + DB.Migrator().DropTable(&QueryResetNullValue{}, &QueryResetItem{}) + DB.AutoMigrate(&QueryResetNullValue{}, &QueryResetItem{}) + + now := time.Now() + q1 := QueryResetNullValue{ + Name: "name", + Flag: true, + Number1: 100, + Number2: 200, + Number3: 300.1, + Now: &now, + Item1: &QueryResetItem{ + ID: "u_1_1", + Name: "item_1_1", + }, + Item2: &QueryResetItem{ + ID: "u_1_2", + Name: "item_1_2", + }, + } + + q2 := QueryResetNullValue{ + Item1: &QueryResetItem{ + ID: "u_2_1", + Name: "item_2_1", + }, + Item2: &QueryResetItem{ + ID: "u_2_2", + Name: "item_2_2", + }, + } + + var err error + err = DB.Create(&q1).Error + if err != nil { + t.Errorf("failed to create:%v", err) + } + + err = DB.Create(&q2).Error + if err != nil { + t.Errorf("failed to create:%v", err) + } + + var qs []QueryResetNullValue + err = DB.Joins("Item1").Joins("Item2").Find(&qs).Error + if err != nil { + t.Errorf("failed to find:%v", err) + } + + if len(qs) != 2 { + t.Fatalf("find count not equal:%d", len(qs)) + } + + AssertEqual(t, q1, qs[0]) + AssertEqual(t, q2, qs[1]) +} From 1305f637f834baa13c514df915157a51d86b4f28 Mon Sep 17 00:00:00 2001 From: qqxhb <30866940+qqxhb@users.noreply.github.com> Date: Fri, 17 Jun 2022 11:00:57 +0800 Subject: [PATCH 016/231] feat: add method GetIndexes (#5436) * feat: add method GetIndexes * feat: add default impl for Index interface * feat: fmt --- migrator.go | 10 ++++++++++ migrator/index.go | 43 +++++++++++++++++++++++++++++++++++++++++++ migrator/migrator.go | 6 ++++++ 3 files changed, 59 insertions(+) create mode 100644 migrator/index.go diff --git a/migrator.go b/migrator.go index 52443877..34e888f2 100644 --- a/migrator.go +++ b/migrator.go @@ -51,6 +51,15 @@ type ColumnType interface { DefaultValue() (value string, ok bool) } +type Index interface { + Table() string + Name() string + Columns() []string + PrimaryKey() (isPrimaryKey bool, ok bool) + Unique() (unique bool, ok bool) + Option() string +} + // Migrator migrator interface type Migrator interface { // AutoMigrate @@ -90,4 +99,5 @@ type Migrator interface { DropIndex(dst interface{}, name string) error HasIndex(dst interface{}, name string) bool RenameIndex(dst interface{}, oldName, newName string) error + GetIndexes(dst interface{}) ([]Index, error) } diff --git a/migrator/index.go b/migrator/index.go new file mode 100644 index 00000000..fe686e5a --- /dev/null +++ b/migrator/index.go @@ -0,0 +1,43 @@ +package migrator + +import "database/sql" + +// Index implements gorm.Index interface +type Index struct { + TableName string + NameValue string + ColumnList []string + PrimaryKeyValue sql.NullBool + UniqueValue sql.NullBool + OptionValue string +} + +// Table return the table name of the index. +func (idx Index) Table() string { + return idx.TableName +} + +// Name return the name of the index. +func (idx Index) Name() string { + return idx.NameValue +} + +// Columns return the columns fo the index +func (idx Index) Columns() []string { + return idx.ColumnList +} + +// PrimaryKey returns the index is primary key or not. +func (idx Index) PrimaryKey() (isPrimaryKey bool, ok bool) { + return idx.PrimaryKeyValue.Bool, idx.PrimaryKeyValue.Valid +} + +// Unique returns whether the index is unique or not. +func (idx Index) Unique() (unique bool, ok bool) { + return idx.UniqueValue.Bool, idx.UniqueValue.Valid +} + +// Option return the optional attribute fo the index +func (idx Index) Option() string { + return idx.OptionValue +} diff --git a/migrator/migrator.go b/migrator/migrator.go index 4acc9df6..f20bf513 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -3,6 +3,7 @@ package migrator import ( "context" "database/sql" + "errors" "fmt" "reflect" "regexp" @@ -854,3 +855,8 @@ func (m Migrator) CurrentTable(stmt *gorm.Statement) interface{} { } return clause.Table{Name: stmt.Table} } + +// GetIndexes return Indexes []gorm.Index and execErr error +func (m Migrator) GetIndexes(dst interface{}) ([]gorm.Index, error) { + return nil, errors.New("not support") +} From a70af2a4c0d7bd66d76999f142a9babb438e53d7 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 20 Jun 2022 15:35:29 +0800 Subject: [PATCH 017/231] Fix Select with digits in column name --- statement.go | 2 +- statement_test.go | 12 ++++++++---- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/statement.go b/statement.go index ed3e8716..850af6cb 100644 --- a/statement.go +++ b/statement.go @@ -650,7 +650,7 @@ func (stmt *Statement) Changed(fields ...string) bool { return false } -var nameMatcher = regexp.MustCompile(`^[\W]?(?:[a-z_]+?)[\W]?\.[\W]?([a-z_]+?)[\W]?$`) +var nameMatcher = regexp.MustCompile(`^[\W]?(?:[a-z_0-9]+?)[\W]?\.[\W]?([a-z_0-9]+?)[\W]?$`) // SelectAndOmitColumns get select and omit columns, select -> true, omit -> false func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (map[string]bool, bool) { diff --git a/statement_test.go b/statement_test.go index 3f099d61..a89cc7d2 100644 --- a/statement_test.go +++ b/statement_test.go @@ -37,10 +37,14 @@ func TestWhereCloneCorruption(t *testing.T) { func TestNameMatcher(t *testing.T) { for k, v := range map[string]string{ - "table.name": "name", - "`table`.`name`": "name", - "'table'.'name'": "name", - "'table'.name": "name", + "table.name": "name", + "`table`.`name`": "name", + "'table'.'name'": "name", + "'table'.name": "name", + "table1.name_23": "name_23", + "`table_1`.`name23`": "name23", + "'table23'.'name_1'": "name_1", + "'table23'.name1": "name1", } { if matches := nameMatcher.FindStringSubmatch(k); len(matches) < 2 || matches[1] != v { t.Errorf("failed to match value: %v, got %v, expect: %v", k, matches, v) From 93f28bc116526ba4decdd969a7b2b0b245ad70f1 Mon Sep 17 00:00:00 2001 From: Joe Date: Fri, 24 Jun 2022 10:33:39 +0800 Subject: [PATCH 018/231] use callback to handle transaction - make transaction have before and after hooks, so plugin can have hack before or after transaction --- callbacks.go | 37 +++++++++++++++++++++++++++++++------ finisher_api.go | 16 +--------------- 2 files changed, 32 insertions(+), 21 deletions(-) diff --git a/callbacks.go b/callbacks.go index c060ea70..1b4e58ea 100644 --- a/callbacks.go +++ b/callbacks.go @@ -2,6 +2,7 @@ package gorm import ( "context" + "database/sql" "errors" "fmt" "reflect" @@ -15,12 +16,13 @@ import ( func initializeCallbacks(db *DB) *callbacks { return &callbacks{ processors: map[string]*processor{ - "create": {db: db}, - "query": {db: db}, - "update": {db: db}, - "delete": {db: db}, - "row": {db: db}, - "raw": {db: db}, + "create": {db: db}, + "query": {db: db}, + "update": {db: db}, + "delete": {db: db}, + "row": {db: db}, + "raw": {db: db}, + "transaction": {db: db}, }, } } @@ -72,6 +74,29 @@ func (cs *callbacks) Raw() *processor { return cs.processors["raw"] } +func (cs *callbacks) Transaction() *processor { + return cs.processors["transaction"] +} + +func (p *processor) Begin(tx *DB, opt *sql.TxOptions) *DB { + var err error + + switch beginner := tx.Statement.ConnPool.(type) { + case TxBeginner: + tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) + case ConnPoolBeginner: + tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) + default: + err = ErrInvalidTransaction + } + + if err != nil { + tx.AddError(err) + } + + return tx +} + func (p *processor) Execute(db *DB) *DB { // call scopes for len(db.Statement.scopes) > 0 { diff --git a/finisher_api.go b/finisher_api.go index 7a3f27ba..3e406c1c 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -619,27 +619,13 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB { // clone statement tx = db.getInstance().Session(&Session{Context: db.Statement.Context, NewDB: db.clone == 1}) opt *sql.TxOptions - err error ) if len(opts) > 0 { opt = opts[0] } - switch beginner := tx.Statement.ConnPool.(type) { - case TxBeginner: - tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) - case ConnPoolBeginner: - tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) - default: - err = ErrInvalidTransaction - } - - if err != nil { - tx.AddError(err) - } - - return tx + return tx.callbacks.Transaction().Begin(tx, opt) } // Commit commit a transaction From 3e6ab990431c48a816676c9efbe1d0952ffb4a28 Mon Sep 17 00:00:00 2001 From: wws <32982278+wuweishuo@users.noreply.github.com> Date: Sat, 25 Jun 2022 16:32:47 +0800 Subject: [PATCH 019/231] fix:serializer contain field panic (#5461) --- schema/field.go | 2 +- tests/serializer_test.go | 43 +++++++++++++++++++++++++++++++++------- 2 files changed, 37 insertions(+), 8 deletions(-) diff --git a/schema/field.go b/schema/field.go index 981f56f2..d4dfbd6f 100644 --- a/schema/field.go +++ b/schema/field.go @@ -950,7 +950,7 @@ func (field *Field) setupNewValuePool() { New: func() interface{} { return &serializer{ Field: field, - Serializer: reflect.New(reflect.Indirect(reflect.ValueOf(field.Serializer)).Type()).Interface().(SerializerInterface), + Serializer: field.Serializer, } }, } diff --git a/tests/serializer_test.go b/tests/serializer_test.go index ee14841a..80e015ff 100644 --- a/tests/serializer_test.go +++ b/tests/serializer_test.go @@ -16,13 +16,14 @@ import ( type SerializerStruct struct { gorm.Model - Name []byte `gorm:"json"` - Roles Roles `gorm:"serializer:json"` - Contracts map[string]interface{} `gorm:"serializer:json"` - JobInfo Job `gorm:"type:bytes;serializer:gob"` - CreatedTime int64 `gorm:"serializer:unixtime;type:time"` // store time in db, use int as field type - UpdatedTime *int64 `gorm:"serializer:unixtime;type:time"` // store time in db, use int as field type - EncryptedString EncryptedString + Name []byte `gorm:"json"` + Roles Roles `gorm:"serializer:json"` + Contracts map[string]interface{} `gorm:"serializer:json"` + JobInfo Job `gorm:"type:bytes;serializer:gob"` + CreatedTime int64 `gorm:"serializer:unixtime;type:time"` // store time in db, use int as field type + UpdatedTime *int64 `gorm:"serializer:unixtime;type:time"` // store time in db, use int as field type + CustomSerializerString string `gorm:"serializer:custom"` + EncryptedString EncryptedString } type Roles []string @@ -52,7 +53,32 @@ func (es EncryptedString) Value(ctx context.Context, field *schema.Field, dst re return "hello" + string(es), nil } +type CustomSerializer struct { + prefix []byte +} + +func NewCustomSerializer(prefix string) *CustomSerializer { + return &CustomSerializer{prefix: []byte(prefix)} +} + +func (c *CustomSerializer) Scan(ctx context.Context, field *schema.Field, dst reflect.Value, dbValue interface{}) (err error) { + switch value := dbValue.(type) { + case []byte: + err = field.Set(ctx, dst, bytes.TrimPrefix(value, c.prefix)) + case string: + err = field.Set(ctx, dst, strings.TrimPrefix(value, string(c.prefix))) + default: + err = fmt.Errorf("unsupported data %#v", dbValue) + } + return err +} + +func (c *CustomSerializer) Value(ctx context.Context, field *schema.Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) { + return fmt.Sprintf("%s%s", c.prefix, fieldValue), nil +} + func TestSerializer(t *testing.T) { + schema.RegisterSerializer("custom", NewCustomSerializer("hello")) DB.Migrator().DropTable(&SerializerStruct{}) if err := DB.Migrator().AutoMigrate(&SerializerStruct{}); err != nil { t.Fatalf("no error should happen when migrate scanner, valuer struct, got error %v", err) @@ -74,6 +100,7 @@ func TestSerializer(t *testing.T) { Location: "Kenmawr", IsIntern: false, }, + CustomSerializerString: "world", } if err := DB.Create(&data).Error; err != nil { @@ -90,6 +117,7 @@ func TestSerializer(t *testing.T) { } func TestSerializerAssignFirstOrCreate(t *testing.T) { + schema.RegisterSerializer("custom", NewCustomSerializer("hello")) DB.Migrator().DropTable(&SerializerStruct{}) if err := DB.Migrator().AutoMigrate(&SerializerStruct{}); err != nil { t.Fatalf("no error should happen when migrate scanner, valuer struct, got error %v", err) @@ -109,6 +137,7 @@ func TestSerializerAssignFirstOrCreate(t *testing.T) { Location: "Shadyside", IsIntern: false, }, + CustomSerializerString: "world", } // first time insert record From 235c093bb97d37cdfa34103b59eabacfde9b2a42 Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Wed, 29 Jun 2022 10:07:42 +0800 Subject: [PATCH 020/231] fix(MigrateColumn):declared different type without length (#5465) --- migrator/migrator.go | 11 +++++++---- tests/migrate_test.go | 39 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 4 deletions(-) diff --git a/migrator/migrator.go b/migrator/migrator.go index f20bf513..87ac7745 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -15,7 +15,6 @@ import ( ) var ( - regRealDataType = regexp.MustCompile(`[^\d](\d+)[^\d]?`) regFullDataType = regexp.MustCompile(`[^\d]*(\d+)[^\d]?`) ) @@ -404,11 +403,16 @@ func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error // MigrateColumn migrate column func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnType gorm.ColumnType) error { // found, smart migrate - fullDataType := strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL) + fullDataType := strings.TrimSpace(strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL)) realDataType := strings.ToLower(columnType.DatabaseTypeName()) alterColumn := false + // check type + if !field.PrimaryKey && !strings.HasPrefix(fullDataType, realDataType) { + alterColumn = true + } + // check size if length, ok := columnType.Length(); length != int64(field.Size) { if length > 0 && field.Size > 0 { @@ -416,9 +420,8 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy } else { // has size in data type and not equal // Since the following code is frequently called in the for loop, reg optimization is needed here - matches := regRealDataType.FindAllStringSubmatch(realDataType, -1) matches2 := regFullDataType.FindAllStringSubmatch(fullDataType, -1) - if (len(matches) == 1 && matches[0][1] != fmt.Sprint(field.Size) || !field.PrimaryKey) && + if !field.PrimaryKey && (len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length) && ok) { alterColumn = true } diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 9e7caec9..0bbef382 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -884,3 +884,42 @@ func TestInvalidCachedPlan(t *testing.T) { t.Errorf("AutoMigrate err:%v", err) } } + +func TestDifferentTypeWithoutDeclaredLength(t *testing.T) { + type DiffType struct { + ID uint + Name string `gorm:"type:varchar(20)"` + } + + type DiffType1 struct { + ID uint + Name string `gorm:"type:text"` + } + + var err error + DB.Migrator().DropTable(&DiffType{}) + + err = DB.AutoMigrate(&DiffType{}) + if err != nil { + t.Errorf("AutoMigrate err:%v", err) + } + + ct, err := findColumnType(&DiffType{}, "name") + if err != nil { + t.Errorf("findColumnType err:%v", err) + } + + AssertEqual(t, "varchar", strings.ToLower(ct.DatabaseTypeName())) + + err = DB.Table("diff_types").AutoMigrate(&DiffType1{}) + if err != nil { + t.Errorf("AutoMigrate err:%v", err) + } + + ct, err = findColumnType(&DiffType{}, "name") + if err != nil { + t.Errorf("findColumnType err:%v", err) + } + + AssertEqual(t, "text", strings.ToLower(ct.DatabaseTypeName())) +} From 2cb4088456eaa845d6e89eeb69fb57d565a72cc2 Mon Sep 17 00:00:00 2001 From: Joe Date: Fri, 1 Jul 2022 14:37:38 +0800 Subject: [PATCH 021/231] ignore AddError return error --- callbacks.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/callbacks.go b/callbacks.go index 1b4e58ea..f835e504 100644 --- a/callbacks.go +++ b/callbacks.go @@ -91,7 +91,7 @@ func (p *processor) Begin(tx *DB, opt *sql.TxOptions) *DB { } if err != nil { - tx.AddError(err) + _ = tx.AddError(err) } return tx From c74bc57add435a4fa0de1cd0eb65f11f62fe1dfd Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Fri, 1 Jul 2022 15:12:15 +0800 Subject: [PATCH 022/231] fix: association many2many duplicate elem (#5473) * fix: association many2many duplicate elem * chore: gofumpt style --- callbacks/associations.go | 29 ++++++++++++++++++++-------- tests/associations_many2many_test.go | 27 ++++++++++++++++++++++++++ tests/migrate_test.go | 4 ++-- tests/serializer_test.go | 3 +-- 4 files changed, 51 insertions(+), 12 deletions(-) diff --git a/callbacks/associations.go b/callbacks/associations.go index fd3141cf..4a50e6c2 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -253,6 +253,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { fieldType = reflect.PtrTo(fieldType) } elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) + distinctElems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) joins := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(rel.JoinTable.ModelType)), 0, 10) objs := []reflect.Value{} @@ -272,19 +273,31 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { joins = reflect.Append(joins, joinValue) } + identityMap := map[string]bool{} appendToElems := func(v reflect.Value) { if _, zero := rel.Field.ValueOf(db.Statement.Context, v); !zero { f := reflect.Indirect(rel.Field.ReflectValueOf(db.Statement.Context, v)) - for i := 0; i < f.Len(); i++ { elem := f.Index(i) - - objs = append(objs, v) - if isPtr { - elems = reflect.Append(elems, elem) - } else { - elems = reflect.Append(elems, elem.Addr()) + if !isPtr { + elem = elem.Addr() } + objs = append(objs, v) + elems = reflect.Append(elems, elem) + + relPrimaryValues := make([]interface{}, 0, len(rel.FieldSchema.PrimaryFields)) + for _, pf := range rel.FieldSchema.PrimaryFields { + if pfv, ok := pf.ValueOf(db.Statement.Context, elem); !ok { + relPrimaryValues = append(relPrimaryValues, pfv) + } + } + + cacheKey := utils.ToStringKey(relPrimaryValues) + if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] { + identityMap[cacheKey] = true + distinctElems = reflect.Append(distinctElems, elem) + } + } } } @@ -304,7 +317,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { // optimize elems of reflect value length if elemLen := elems.Len(); elemLen > 0 { if v, ok := selectColumns[rel.Name+".*"]; !ok || v { - saveAssociations(db, rel, elems, selectColumns, restricted, nil) + saveAssociations(db, rel, distinctElems, selectColumns, restricted, nil) } for i := 0; i < elemLen; i++ { diff --git a/tests/associations_many2many_test.go b/tests/associations_many2many_test.go index 28b441bd..7b45befb 100644 --- a/tests/associations_many2many_test.go +++ b/tests/associations_many2many_test.go @@ -3,6 +3,7 @@ package tests_test import ( "testing" + "gorm.io/gorm" . "gorm.io/gorm/utils/tests" ) @@ -324,3 +325,29 @@ func TestSingleTableMany2ManyAssociationForSlice(t *testing.T) { DB.Model(&users).Association("Team").Clear() AssertAssociationCount(t, users, "Team", 0, "After Clear") } + +func TestDuplicateMany2ManyAssociation(t *testing.T) { + user1 := User{Name: "TestDuplicateMany2ManyAssociation-1", Languages: []Language{ + {Code: "TestDuplicateMany2ManyAssociation-language-1"}, + {Code: "TestDuplicateMany2ManyAssociation-language-2"}, + }} + + user2 := User{Name: "TestDuplicateMany2ManyAssociation-1", Languages: []Language{ + {Code: "TestDuplicateMany2ManyAssociation-language-1"}, + {Code: "TestDuplicateMany2ManyAssociation-language-3"}, + }} + users := []*User{&user1, &user2} + var err error + err = DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(users).Error + AssertEqual(t, nil, err) + + var findUser1 User + err = DB.Preload("Languages").Where("id = ?", user1.ID).First(&findUser1).Error + AssertEqual(t, nil, err) + AssertEqual(t, user1, findUser1) + + var findUser2 User + err = DB.Preload("Languages").Where("id = ?", user2.ID).First(&findUser2).Error + AssertEqual(t, nil, err) + AssertEqual(t, user2, findUser2) +} diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 0bbef382..3d6a7858 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -830,11 +830,11 @@ func TestUniqueColumn(t *testing.T) { value, ok = ct.DefaultValue() AssertEqual(t, "", value) AssertEqual(t, false, ok) - } func findColumnType(dest interface{}, columnName string) ( - foundColumn gorm.ColumnType, err error) { + foundColumn gorm.ColumnType, err error, +) { columnTypes, err := DB.Migrator().ColumnTypes(dest) if err != nil { err = fmt.Errorf("ColumnTypes err:%v", err) diff --git a/tests/serializer_test.go b/tests/serializer_test.go index 80e015ff..7232f9df 100644 --- a/tests/serializer_test.go +++ b/tests/serializer_test.go @@ -113,7 +113,6 @@ func TestSerializer(t *testing.T) { } AssertEqual(t, result, data) - } func TestSerializerAssignFirstOrCreate(t *testing.T) { @@ -152,7 +151,7 @@ func TestSerializerAssignFirstOrCreate(t *testing.T) { } AssertEqual(t, result, out) - //update record + // update record data.Roles = append(data.Roles, "r3") data.JobInfo.Location = "Gates Hillman Complex" if err := DB.Assign(data).FirstOrCreate(&out).Error; err != nil { From 46bce170cae701615e2b2f8b2448b54524be9648 Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Mon, 4 Jul 2022 16:42:27 +0800 Subject: [PATCH 023/231] test: pg array type (#5480) --- tests/migrate_test.go | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 3d6a7858..0b5bc5eb 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -923,3 +923,39 @@ func TestDifferentTypeWithoutDeclaredLength(t *testing.T) { AssertEqual(t, "text", strings.ToLower(ct.DatabaseTypeName())) } + +func TestMigrateArrayTypeModel(t *testing.T) { + if DB.Dialector.Name() != "postgres" { + return + } + + type ArrayTypeModel struct { + ID uint + Number string `gorm:"type:varchar(51);NOT NULL"` + TextArray []string `gorm:"type:text[];NOT NULL"` + NestedTextArray [][]string `gorm:"type:text[][]"` + NestedIntArray [][]int64 `gorm:"type:integer[3][3]"` + } + + var err error + DB.Migrator().DropTable(&ArrayTypeModel{}) + + err = DB.AutoMigrate(&ArrayTypeModel{}) + AssertEqual(t, nil, err) + + ct, err := findColumnType(&ArrayTypeModel{}, "number") + AssertEqual(t, nil, err) + AssertEqual(t, "varchar", ct.DatabaseTypeName()) + + ct, err = findColumnType(&ArrayTypeModel{}, "text_array") + AssertEqual(t, nil, err) + AssertEqual(t, "text[]", ct.DatabaseTypeName()) + + ct, err = findColumnType(&ArrayTypeModel{}, "nested_text_array") + AssertEqual(t, nil, err) + AssertEqual(t, "text[]", ct.DatabaseTypeName()) + + ct, err = findColumnType(&ArrayTypeModel{}, "nested_int_array") + AssertEqual(t, nil, err) + AssertEqual(t, "integer[]", ct.DatabaseTypeName()) +} From fe01e1b9f43070e3814817b4b762dfd08a3ced30 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 7 Jul 2022 14:43:33 +0800 Subject: [PATCH 024/231] Fix Model with slice data --- callbacks/update.go | 2 +- tests/go.mod | 12 +++++++----- tests/update_test.go | 8 ++++++++ 3 files changed, 16 insertions(+), 6 deletions(-) diff --git a/callbacks/update.go b/callbacks/update.go index 01f40509..42ffe2f6 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -172,7 +172,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { } } - stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(primaryKeyExprs...)}}) + stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.And(clause.Or(primaryKeyExprs...))}}) } case reflect.Struct: for _, field := range stmt.Schema.PrimaryFields { diff --git a/tests/go.mod b/tests/go.mod index bd668420..f3e9d260 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -3,16 +3,18 @@ module gorm.io/gorm/tests go 1.14 require ( + github.com/denisenkom/go-mssqldb v0.12.2 // indirect github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect github.com/google/uuid v1.3.0 github.com/jinzhu/now v1.1.5 - github.com/lib/pq v1.10.5 - golang.org/x/crypto v0.0.0-20220411220226-7b82a4e95df4 // indirect + github.com/lib/pq v1.10.6 + github.com/mattn/go-sqlite3 v1.14.14 // indirect + golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d // indirect gorm.io/driver/mysql v1.3.4 - gorm.io/driver/postgres v1.3.5 - gorm.io/driver/sqlite v1.3.2 + gorm.io/driver/postgres v1.3.8 + gorm.io/driver/sqlite v1.3.6 gorm.io/driver/sqlserver v1.3.2 - gorm.io/gorm v1.23.4 + gorm.io/gorm v1.23.7 ) replace gorm.io/gorm => ../ diff --git a/tests/update_test.go b/tests/update_test.go index 41ea5d27..0fc89a93 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -122,6 +122,14 @@ func TestUpdate(t *testing.T) { } else { CheckUser(t, result4, *user) } + + if rowsAffected := DB.Model([]User{result4}).Where("age > 0").Update("name", "jinzhu").RowsAffected; rowsAffected != 1 { + t.Errorf("should only update one record, but got %v", rowsAffected) + } + + if rowsAffected := DB.Model(users).Where("age > 0").Update("name", "jinzhu").RowsAffected; rowsAffected != 3 { + t.Errorf("should only update one record, but got %v", rowsAffected) + } } func TestUpdates(t *testing.T) { From 9fd73ae4f1f638e4c49ae4e6fab8beb9863adabc Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 7 Jul 2022 15:06:48 +0800 Subject: [PATCH 025/231] Revert "use callback to handle transaction" This reverts commit 93f28bc116526ba4decdd969a7b2b0b245ad70f1. --- callbacks.go | 37 ++++++------------------------------- finisher_api.go | 16 +++++++++++++++- 2 files changed, 21 insertions(+), 32 deletions(-) diff --git a/callbacks.go b/callbacks.go index f835e504..c060ea70 100644 --- a/callbacks.go +++ b/callbacks.go @@ -2,7 +2,6 @@ package gorm import ( "context" - "database/sql" "errors" "fmt" "reflect" @@ -16,13 +15,12 @@ import ( func initializeCallbacks(db *DB) *callbacks { return &callbacks{ processors: map[string]*processor{ - "create": {db: db}, - "query": {db: db}, - "update": {db: db}, - "delete": {db: db}, - "row": {db: db}, - "raw": {db: db}, - "transaction": {db: db}, + "create": {db: db}, + "query": {db: db}, + "update": {db: db}, + "delete": {db: db}, + "row": {db: db}, + "raw": {db: db}, }, } } @@ -74,29 +72,6 @@ func (cs *callbacks) Raw() *processor { return cs.processors["raw"] } -func (cs *callbacks) Transaction() *processor { - return cs.processors["transaction"] -} - -func (p *processor) Begin(tx *DB, opt *sql.TxOptions) *DB { - var err error - - switch beginner := tx.Statement.ConnPool.(type) { - case TxBeginner: - tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) - case ConnPoolBeginner: - tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) - default: - err = ErrInvalidTransaction - } - - if err != nil { - _ = tx.AddError(err) - } - - return tx -} - func (p *processor) Execute(db *DB) *DB { // call scopes for len(db.Statement.scopes) > 0 { diff --git a/finisher_api.go b/finisher_api.go index 3e406c1c..7a3f27ba 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -619,13 +619,27 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB { // clone statement tx = db.getInstance().Session(&Session{Context: db.Statement.Context, NewDB: db.clone == 1}) opt *sql.TxOptions + err error ) if len(opts) > 0 { opt = opts[0] } - return tx.callbacks.Transaction().Begin(tx, opt) + switch beginner := tx.Statement.ConnPool.(type) { + case TxBeginner: + tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) + case ConnPoolBeginner: + tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) + default: + err = ErrInvalidTransaction + } + + if err != nil { + tx.AddError(err) + } + + return tx } // Commit commit a transaction From b13d1757fab7093d769afc02573ee3c359faeb26 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 7 Jul 2022 15:39:29 +0800 Subject: [PATCH 026/231] Refactor Model with slice data --- callbacks/update.go | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/callbacks/update.go b/callbacks/update.go index 42ffe2f6..48c61bf4 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -158,21 +158,21 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { switch stmt.ReflectValue.Kind() { case reflect.Slice, reflect.Array: if size := stmt.ReflectValue.Len(); size > 0 { - var primaryKeyExprs []clause.Expression + var isZero bool for i := 0; i < size; i++ { - exprs := make([]clause.Expression, len(stmt.Schema.PrimaryFields)) - var notZero bool - for idx, field := range stmt.Schema.PrimaryFields { - value, isZero := field.ValueOf(stmt.Context, stmt.ReflectValue.Index(i)) - exprs[idx] = clause.Eq{Column: field.DBName, Value: value} - notZero = notZero || !isZero - } - if notZero { - primaryKeyExprs = append(primaryKeyExprs, clause.And(exprs...)) + for _, field := range stmt.Schema.PrimaryFields { + _, isZero = field.ValueOf(stmt.Context, stmt.ReflectValue.Index(i)) + if !isZero { + break + } } } - stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.And(clause.Or(primaryKeyExprs...))}}) + if !isZero { + _, primaryValues := schema.GetIdentityFieldValuesMap(stmt.Context, stmt.ReflectValue, stmt.Schema.PrimaryFields) + column, values := schema.ToQueryValues("", stmt.Schema.PrimaryFieldDBNames, primaryValues) + stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}}) + } } case reflect.Struct: for _, field := range stmt.Schema.PrimaryFields { From 62fdc2bb3b4f991a8ed1ec2fdb47571a64fd18ef Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 11 Jul 2022 11:51:05 +0800 Subject: [PATCH 027/231] Fix serializer with empty string --- schema/serializer.go | 10 +++++++--- tests/go.mod | 4 ++-- tests/serializer_test.go | 8 ++++++++ 3 files changed, 17 insertions(+), 5 deletions(-) diff --git a/schema/serializer.go b/schema/serializer.go index 758a6421..21be3c35 100644 --- a/schema/serializer.go +++ b/schema/serializer.go @@ -88,7 +88,9 @@ func (JSONSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, return fmt.Errorf("failed to unmarshal JSONB value: %#v", dbValue) } - err = json.Unmarshal(bytes, fieldValue.Interface()) + if len(bytes) > 0 { + err = json.Unmarshal(bytes, fieldValue.Interface()) + } } field.ReflectValueOf(ctx, dst).Set(fieldValue.Elem()) @@ -142,8 +144,10 @@ func (GobSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, default: return fmt.Errorf("failed to unmarshal gob value: %#v", dbValue) } - decoder := gob.NewDecoder(bytes.NewBuffer(bytesValue)) - err = decoder.Decode(fieldValue.Interface()) + if len(bytesValue) > 0 { + decoder := gob.NewDecoder(bytes.NewBuffer(bytesValue)) + err = decoder.Decode(fieldValue.Interface()) + } } field.ReflectValueOf(ctx, dst).Set(fieldValue.Elem()) return diff --git a/tests/go.mod b/tests/go.mod index f3e9d260..7a788a43 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -10,11 +10,11 @@ require ( github.com/lib/pq v1.10.6 github.com/mattn/go-sqlite3 v1.14.14 // indirect golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d // indirect - gorm.io/driver/mysql v1.3.4 + gorm.io/driver/mysql v1.3.5 gorm.io/driver/postgres v1.3.8 gorm.io/driver/sqlite v1.3.6 gorm.io/driver/sqlserver v1.3.2 - gorm.io/gorm v1.23.7 + gorm.io/gorm v1.23.8 ) replace gorm.io/gorm => ../ diff --git a/tests/serializer_test.go b/tests/serializer_test.go index 7232f9df..95d25699 100644 --- a/tests/serializer_test.go +++ b/tests/serializer_test.go @@ -113,6 +113,14 @@ func TestSerializer(t *testing.T) { } AssertEqual(t, result, data) + + if err := DB.Model(&result).Update("roles", "").Error; err != nil { + t.Fatalf("failed to update data's roles, got error %v", err) + } + + if err := DB.First(&result, data.ID).Error; err != nil { + t.Fatalf("failed to query data, got error %v", err) + } } func TestSerializerAssignFirstOrCreate(t *testing.T) { From 08f6d06e47b2ee6285577d726c59e5e2c3ff99ac Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 13 Jul 2022 17:21:19 +0800 Subject: [PATCH 028/231] Fix select with quoted column name --- statement.go | 2 +- statement_test.go | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/statement.go b/statement.go index 850af6cb..79e29915 100644 --- a/statement.go +++ b/statement.go @@ -650,7 +650,7 @@ func (stmt *Statement) Changed(fields ...string) bool { return false } -var nameMatcher = regexp.MustCompile(`^[\W]?(?:[a-z_0-9]+?)[\W]?\.[\W]?([a-z_0-9]+?)[\W]?$`) +var nameMatcher = regexp.MustCompile(`^(?:[\W]?(?:[a-z_0-9]+?)[\W]?\.)?[\W]?([a-z_0-9]+?)[\W]?$`) // SelectAndOmitColumns get select and omit columns, select -> true, omit -> false func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (map[string]bool, bool) { diff --git a/statement_test.go b/statement_test.go index a89cc7d2..4432cda4 100644 --- a/statement_test.go +++ b/statement_test.go @@ -45,6 +45,8 @@ func TestNameMatcher(t *testing.T) { "`table_1`.`name23`": "name23", "'table23'.'name_1'": "name_1", "'table23'.name1": "name1", + "'name1'": "name1", + "`name_1`": "name_1", } { if matches := nameMatcher.FindStringSubmatch(k); len(matches) < 2 || matches[1] != v { t.Errorf("failed to match value: %v, got %v, expect: %v", k, matches, v) From a7063848efe743166ad9fae460e8c2acc1b14a6d Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 13 Jul 2022 17:44:14 +0800 Subject: [PATCH 029/231] Fix select with uppercase column name --- statement.go | 2 +- statement_test.go | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/statement.go b/statement.go index 79e29915..aa5c2993 100644 --- a/statement.go +++ b/statement.go @@ -650,7 +650,7 @@ func (stmt *Statement) Changed(fields ...string) bool { return false } -var nameMatcher = regexp.MustCompile(`^(?:[\W]?(?:[a-z_0-9]+?)[\W]?\.)?[\W]?([a-z_0-9]+?)[\W]?$`) +var nameMatcher = regexp.MustCompile(`^(?:[\W]?(?:[A-Za-z_0-9]+?)[\W]?\.)?[\W]?([A-Za-z_0-9]+?)[\W]?$`) // SelectAndOmitColumns get select and omit columns, select -> true, omit -> false func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (map[string]bool, bool) { diff --git a/statement_test.go b/statement_test.go index 4432cda4..19ab38f7 100644 --- a/statement_test.go +++ b/statement_test.go @@ -47,6 +47,8 @@ func TestNameMatcher(t *testing.T) { "'table23'.name1": "name1", "'name1'": "name1", "`name_1`": "name_1", + "`Name_1`": "Name_1", + "`Table`.`nAme`": "nAme", } { if matches := nameMatcher.FindStringSubmatch(k); len(matches) < 2 || matches[1] != v { t.Errorf("failed to match value: %v, got %v, expect: %v", k, matches, v) From cae30e9a50cb9260b805310062059853927d488c Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 13 Jul 2022 18:02:11 +0800 Subject: [PATCH 030/231] Fix select with association column --- statement.go | 6 +++--- statement_test.go | 28 ++++++++++++++-------------- 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/statement.go b/statement.go index aa5c2993..9a621179 100644 --- a/statement.go +++ b/statement.go @@ -650,7 +650,7 @@ func (stmt *Statement) Changed(fields ...string) bool { return false } -var nameMatcher = regexp.MustCompile(`^(?:[\W]?(?:[A-Za-z_0-9]+?)[\W]?\.)?[\W]?([A-Za-z_0-9]+?)[\W]?$`) +var nameMatcher = regexp.MustCompile(`^(?:[\W]?([A-Za-z_0-9]+?)[\W]?\.)?[\W]?([A-Za-z_0-9]+?)[\W]?$`) // SelectAndOmitColumns get select and omit columns, select -> true, omit -> false func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (map[string]bool, bool) { @@ -672,8 +672,8 @@ func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) ( } } else if field := stmt.Schema.LookUpField(column); field != nil && field.DBName != "" { results[field.DBName] = true - } else if matches := nameMatcher.FindStringSubmatch(column); len(matches) == 2 { - results[matches[1]] = true + } else if matches := nameMatcher.FindStringSubmatch(column); len(matches) == 3 && matches[1] == stmt.Table { + results[matches[2]] = true } else { results[column] = true } diff --git a/statement_test.go b/statement_test.go index 19ab38f7..a537c7be 100644 --- a/statement_test.go +++ b/statement_test.go @@ -36,21 +36,21 @@ func TestWhereCloneCorruption(t *testing.T) { } func TestNameMatcher(t *testing.T) { - for k, v := range map[string]string{ - "table.name": "name", - "`table`.`name`": "name", - "'table'.'name'": "name", - "'table'.name": "name", - "table1.name_23": "name_23", - "`table_1`.`name23`": "name23", - "'table23'.'name_1'": "name_1", - "'table23'.name1": "name1", - "'name1'": "name1", - "`name_1`": "name_1", - "`Name_1`": "Name_1", - "`Table`.`nAme`": "nAme", + for k, v := range map[string][]string{ + "table.name": []string{"table", "name"}, + "`table`.`name`": []string{"table", "name"}, + "'table'.'name'": []string{"table", "name"}, + "'table'.name": []string{"table", "name"}, + "table1.name_23": []string{"table1", "name_23"}, + "`table_1`.`name23`": []string{"table_1", "name23"}, + "'table23'.'name_1'": []string{"table23", "name_1"}, + "'table23'.name1": []string{"table23", "name1"}, + "'name1'": []string{"", "name1"}, + "`name_1`": []string{"", "name_1"}, + "`Name_1`": []string{"", "Name_1"}, + "`Table`.`nAme`": []string{"Table", "nAme"}, } { - if matches := nameMatcher.FindStringSubmatch(k); len(matches) < 2 || matches[1] != v { + if matches := nameMatcher.FindStringSubmatch(k); len(matches) < 3 || matches[1] != v[0] || matches[2] != v[1] { t.Errorf("failed to match value: %v, got %v, expect: %v", k, matches, v) } } From 3262daf8d46818395a7b01778e8f813afc0dc3d2 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 13 Jul 2022 18:26:35 +0800 Subject: [PATCH 031/231] Fix select with association column --- statement.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/statement.go b/statement.go index 9a621179..12687810 100644 --- a/statement.go +++ b/statement.go @@ -672,7 +672,7 @@ func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) ( } } else if field := stmt.Schema.LookUpField(column); field != nil && field.DBName != "" { results[field.DBName] = true - } else if matches := nameMatcher.FindStringSubmatch(column); len(matches) == 3 && matches[1] == stmt.Table { + } else if matches := nameMatcher.FindStringSubmatch(column); len(matches) == 3 && (matches[1] == stmt.Table || matches[1] == "") { results[matches[2]] = true } else { results[column] = true From 4d40e34734289137d9ca8fc2b69bf8de98a7448c Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 14 Jul 2022 14:39:43 +0800 Subject: [PATCH 032/231] Update select tests --- tests/helper_test.go | 2 ++ tests/update_belongs_to_test.go | 15 +++++++++++++++ tests/update_has_one_test.go | 10 +++++++--- tests/update_test.go | 2 ++ 4 files changed, 26 insertions(+), 3 deletions(-) diff --git a/tests/helper_test.go b/tests/helper_test.go index 7ee2a576..d1af0739 100644 --- a/tests/helper_test.go +++ b/tests/helper_test.go @@ -80,6 +80,7 @@ func CheckPet(t *testing.T, pet Pet, expect Pet) { t.Fatalf("errors happened when query: %v", err) } else { AssertObjEqual(t, newPet, pet, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "UserID", "Name") + AssertObjEqual(t, newPet, expect, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "UserID", "Name") } } @@ -174,6 +175,7 @@ func CheckUser(t *testing.T, user User, expect User) { var manager User DB.First(&manager, "id = ?", *user.ManagerID) AssertObjEqual(t, manager, user.Manager, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active") + AssertObjEqual(t, manager, expect.Manager, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active") } } else if user.ManagerID != nil { t.Errorf("Manager should not be created for zero value, got: %+v", user.ManagerID) diff --git a/tests/update_belongs_to_test.go b/tests/update_belongs_to_test.go index 8fe0f289..4e94cfd5 100644 --- a/tests/update_belongs_to_test.go +++ b/tests/update_belongs_to_test.go @@ -41,4 +41,19 @@ func TestUpdateBelongsTo(t *testing.T) { var user4 User DB.Preload("Company").Preload("Manager").Find(&user4, "id = ?", user.ID) CheckUser(t, user4, user) + + user.Company.Name += "new2" + user.Manager.Name += "new2" + if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Select("`Company`").Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user5 User + DB.Preload("Company").Preload("Manager").Find(&user5, "id = ?", user.ID) + if user5.Manager.Name != user4.Manager.Name { + t.Errorf("should not update user's manager") + } else { + user.Manager.Name = user4.Manager.Name + } + CheckUser(t, user, user5) } diff --git a/tests/update_has_one_test.go b/tests/update_has_one_test.go index c926fbcf..40af6ae7 100644 --- a/tests/update_has_one_test.go +++ b/tests/update_has_one_test.go @@ -90,8 +90,9 @@ func TestUpdateHasOne(t *testing.T) { t.Run("Restriction", func(t *testing.T) { type CustomizeAccount struct { gorm.Model - UserID sql.NullInt64 - Number string `gorm:"<-:create"` + UserID sql.NullInt64 + Number string `gorm:"<-:create"` + Number2 string } type CustomizeUser struct { @@ -114,7 +115,8 @@ func TestUpdateHasOne(t *testing.T) { cusUser := CustomizeUser{ Name: "update-has-one-associations", Account: CustomizeAccount{ - Number: number, + Number: number, + Number2: number, }, } @@ -122,6 +124,7 @@ func TestUpdateHasOne(t *testing.T) { t.Fatalf("errors happened when create: %v", err) } cusUser.Account.Number += "-update" + cusUser.Account.Number2 += "-update" if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Updates(&cusUser).Error; err != nil { t.Fatalf("errors happened when create: %v", err) } @@ -129,5 +132,6 @@ func TestUpdateHasOne(t *testing.T) { var account2 CustomizeAccount DB.Find(&account2, "user_id = ?", cusUser.ID) AssertEqual(t, account2.Number, number) + AssertEqual(t, account2.Number2, cusUser.Account.Number2) }) } diff --git a/tests/update_test.go b/tests/update_test.go index 0fc89a93..d7634580 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -307,6 +307,8 @@ func TestSelectWithUpdate(t *testing.T) { if utils.AssertEqual(result.UpdatedAt, user.UpdatedAt) { t.Fatalf("Update struct should update UpdatedAt, was %+v, got %+v", result.UpdatedAt, user.UpdatedAt) } + + AssertObjEqual(t, result, User{Name: "update_with_select"}, "Name", "Age") } func TestSelectWithUpdateWithMap(t *testing.T) { From 099813bf11dc1c4e614d73daee5766f4963136cf Mon Sep 17 00:00:00 2001 From: alingse Date: Thu, 14 Jul 2022 20:05:22 +0800 Subject: [PATCH 033/231] Adjust ToStringKey use unpack params, fix pass []any as any in variadic function (#5500) * fix pass []any as any in variadic function * add .vscode to gitignore --- .gitignore | 3 ++- callbacks/associations.go | 4 ++-- utils/utils_test.go | 17 +++++++++++++++++ 3 files changed, 21 insertions(+), 3 deletions(-) diff --git a/.gitignore b/.gitignore index 45505cc9..72733326 100644 --- a/.gitignore +++ b/.gitignore @@ -3,4 +3,5 @@ documents coverage.txt _book .idea -vendor \ No newline at end of file +vendor +.vscode diff --git a/callbacks/associations.go b/callbacks/associations.go index 4a50e6c2..00e00fcc 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -206,7 +206,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { } } - cacheKey := utils.ToStringKey(relPrimaryValues) + cacheKey := utils.ToStringKey(relPrimaryValues...) if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] { identityMap[cacheKey] = true if isPtr { @@ -292,7 +292,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { } } - cacheKey := utils.ToStringKey(relPrimaryValues) + cacheKey := utils.ToStringKey(relPrimaryValues...) if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] { identityMap[cacheKey] = true distinctElems = reflect.Append(distinctElems, elem) diff --git a/utils/utils_test.go b/utils/utils_test.go index 5737c511..27dfee16 100644 --- a/utils/utils_test.go +++ b/utils/utils_test.go @@ -12,3 +12,20 @@ func TestIsValidDBNameChar(t *testing.T) { } } } + +func TestToStringKey(t *testing.T) { + cases := []struct { + values []interface{} + key string + }{ + {[]interface{}{"a"}, "a"}, + {[]interface{}{1, 2, 3}, "1_2_3"}, + {[]interface{}{[]interface{}{1, 2, 3}}, "[1 2 3]"}, + {[]interface{}{[]interface{}{"1", "2", "3"}}, "[1 2 3]"}, + } + for _, c := range cases { + if key := ToStringKey(c.values...); key != c.key { + t.Errorf("%v: expected %v, got %v", c.values, c.key, key) + } + } +} From 2ba599e8b7d2197739669970fa88d591423f0cae Mon Sep 17 00:00:00 2001 From: Goxiaoy Date: Fri, 15 Jul 2022 11:15:18 +0800 Subject: [PATCH 034/231] fix empty QueryClauses in association (#5502) (#5503) * fix empty QueryClauses in association (#5502) * test: empty QueryClauses in association (#5502) * style: empty QueryClauses in association (#5502) * style: empty QueryClauses in association (#5502) --- association.go | 4 ++- tests/associations_test.go | 64 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 67 insertions(+), 1 deletion(-) diff --git a/association.go b/association.go index 35e10ddd..06229caa 100644 --- a/association.go +++ b/association.go @@ -507,7 +507,9 @@ func (association *Association) buildCondition() *DB { joinStmt.AddClause(queryClause) } joinStmt.Build("WHERE") - tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars}) + if len(joinStmt.SQL.String()) > 0 { + tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars}) + } } tx = tx.Session(&Session{QueryFields: true}).Clauses(clause.From{Joins: []clause.Join{{ diff --git a/tests/associations_test.go b/tests/associations_test.go index e729e979..42b32afc 100644 --- a/tests/associations_test.go +++ b/tests/associations_test.go @@ -4,6 +4,8 @@ import ( "testing" "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" . "gorm.io/gorm/utils/tests" ) @@ -284,3 +286,65 @@ func TestAssociationError(t *testing.T) { err = DB.Model(&emptyUser).Association("Languages").Delete(&user1.Languages) AssertEqual(t, err, gorm.ErrPrimaryKeyRequired) } + +type ( + myType string + emptyQueryClause struct { + Field *schema.Field + } +) + +func (myType) QueryClauses(f *schema.Field) []clause.Interface { + return []clause.Interface{emptyQueryClause{Field: f}} +} + +func (sd emptyQueryClause) Name() string { + return "empty" +} + +func (sd emptyQueryClause) Build(clause.Builder) { +} + +func (sd emptyQueryClause) MergeClause(*clause.Clause) { +} + +func (sd emptyQueryClause) ModifyStatement(stmt *gorm.Statement) { + // do nothing +} + +func TestAssociationEmptyQueryClause(t *testing.T) { + type Organization struct { + gorm.Model + Name string + } + type Region struct { + gorm.Model + Name string + Organizations []Organization `gorm:"many2many:region_orgs;"` + } + type RegionOrg struct { + RegionId uint + OrganizationId uint + Empty myType + } + if err := DB.SetupJoinTable(&Region{}, "Organizations", &RegionOrg{}); err != nil { + t.Fatalf("Failed to set up join table, got error: %s", err) + } + if err := DB.Migrator().DropTable(&Organization{}, &Region{}); err != nil { + t.Fatalf("Failed to migrate, got error: %s", err) + } + if err := DB.AutoMigrate(&Organization{}, &Region{}); err != nil { + t.Fatalf("Failed to migrate, got error: %v", err) + } + region := &Region{Name: "Region1"} + if err := DB.Create(region).Error; err != nil { + t.Fatalf("fail to create region %v", err) + } + var orgs []Organization + + if err := DB.Model(&Region{}).Association("Organizations").Find(&orgs); err != nil { + t.Fatalf("fail to find region organizations %v", err) + } else { + AssertEqual(t, len(orgs), 0) + } +} From 75720099b5540a38fa9f7c26d8237df2cd1570a9 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 18 Jul 2022 18:06:45 +0800 Subject: [PATCH 035/231] Create a new db in FindInBatches --- finisher_api.go | 4 +++- gorm.go | 3 ++- tests/query_test.go | 4 ++-- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/finisher_api.go b/finisher_api.go index 7a3f27ba..af9afb63 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -202,7 +202,9 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat batch++ if result.Error == nil && result.RowsAffected != 0 { - tx.AddError(fc(result, batch)) + fcTx := result.Session(&Session{NewDB: true}) + fcTx.RowsAffected = result.RowsAffected + tx.AddError(fc(fcTx, batch)) } else if result.Error != nil { tx.AddError(result.Error) } diff --git a/gorm.go b/gorm.go index 6a6bb032..c852e60c 100644 --- a/gorm.go +++ b/gorm.go @@ -300,7 +300,8 @@ func (db *DB) WithContext(ctx context.Context) *DB { // Debug start debug mode func (db *DB) Debug() (tx *DB) { - return db.Session(&Session{ + tx = db.getInstance() + return tx.Session(&Session{ Logger: db.Logger.LogMode(logger.Info), }) } diff --git a/tests/query_test.go b/tests/query_test.go index 253d8409..4569fe1a 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -257,7 +257,7 @@ func TestFindInBatches(t *testing.T) { totalBatch int ) - if result := DB.Where("name = ?", users[0].Name).FindInBatches(&results, 2, func(tx *gorm.DB, batch int) error { + if result := DB.Table("users as u").Where("name = ?", users[0].Name).FindInBatches(&results, 2, func(tx *gorm.DB, batch int) error { totalBatch += batch if tx.RowsAffected != 2 { @@ -273,7 +273,7 @@ func TestFindInBatches(t *testing.T) { } if err := tx.Save(results).Error; err != nil { - t.Errorf("failed to save users, got error %v", err) + t.Fatalf("failed to save users, got error %v", err) } return nil From bab3cd1724cb111961d931f514e1bda316de8572 Mon Sep 17 00:00:00 2001 From: Xudong Zhang Date: Mon, 18 Jul 2022 20:47:00 +0800 Subject: [PATCH 036/231] fix bad logging performance of bulk create (#5520) (#5521) --- logger/sql.go | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/logger/sql.go b/logger/sql.go index c8b194c3..bcacc7cf 100644 --- a/logger/sql.go +++ b/logger/sql.go @@ -30,6 +30,8 @@ func isPrintable(s string) bool { var convertibleTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeOf(false), reflect.TypeOf([]byte{})} +var numericPlaceholderRe = regexp.MustCompile(`\$\d+\$`) + // ExplainSQL generate SQL string with given parameters, the generated SQL is expected to be used in logger, execute it might introduce a SQL injection vulnerability func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, avars ...interface{}) string { var ( @@ -138,9 +140,18 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a sql = newSQL.String() } else { sql = numericPlaceholder.ReplaceAllString(sql, "$$$1$$") - for idx, v := range vars { - sql = strings.Replace(sql, "$"+strconv.Itoa(idx+1)+"$", v, 1) - } + + sql = numericPlaceholderRe.ReplaceAllStringFunc(sql, func(v string) string { + num := v[1 : len(v)-1] + n, _ := strconv.Atoi(num) + + // position var start from 1 ($1, $2) + n -= 1 + if n >= 0 && n <= len(vars)-1 { + return vars[n] + } + return v + }) } return sql From 06e174e24ddc3a49716ccd877aac221ca2469331 Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Mon, 25 Jul 2022 14:10:30 +0800 Subject: [PATCH 037/231] fix: embedded default value (#5540) --- schema/field.go | 8 ++------ tests/embedded_struct_test.go | 26 ++++++++++++++++++++++++++ 2 files changed, 28 insertions(+), 6 deletions(-) diff --git a/schema/field.go b/schema/field.go index d4dfbd6f..47f3994f 100644 --- a/schema/field.go +++ b/schema/field.go @@ -403,18 +403,14 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } if ef.PrimaryKey { - if val, ok := ef.TagSettings["PRIMARYKEY"]; ok && utils.CheckTruth(val) { - ef.PrimaryKey = true - } else if val, ok := ef.TagSettings["PRIMARY_KEY"]; ok && utils.CheckTruth(val) { - ef.PrimaryKey = true - } else { + if !utils.CheckTruth(ef.TagSettings["PRIMARYKEY"], ef.TagSettings["PRIMARY_KEY"]) { ef.PrimaryKey = false if val, ok := ef.TagSettings["AUTOINCREMENT"]; !ok || !utils.CheckTruth(val) { ef.AutoIncrement = false } - if ef.DefaultValue == "" { + if !ef.AutoIncrement && ef.DefaultValue == "" { ef.HasDefaultValue = false } } diff --git a/tests/embedded_struct_test.go b/tests/embedded_struct_test.go index 312a5c37..e309d06c 100644 --- a/tests/embedded_struct_test.go +++ b/tests/embedded_struct_test.go @@ -168,3 +168,29 @@ func TestEmbeddedRelations(t *testing.T) { } } } + +func TestEmbeddedTagSetting(t *testing.T) { + type Tag1 struct { + Id int64 `gorm:"autoIncrement"` + } + type Tag2 struct { + Id int64 + } + + type EmbeddedTag struct { + Tag1 Tag1 `gorm:"Embedded;"` + Tag2 Tag2 `gorm:"Embedded;EmbeddedPrefix:t2_"` + Name string + } + + DB.Migrator().DropTable(&EmbeddedTag{}) + err := DB.Migrator().AutoMigrate(&EmbeddedTag{}) + AssertEqual(t, err, nil) + + t1 := EmbeddedTag{Name: "embedded_tag"} + err = DB.Save(&t1).Error + AssertEqual(t, err, nil) + if t1.Tag1.Id == 0 { + t.Errorf("embedded struct's primary field should be rewrited") + } +} From 3c6eb14c92679e34cd49de53ef0b3d327f4dd06a Mon Sep 17 00:00:00 2001 From: MJrocker <1725014728@qq.com> Date: Tue, 26 Jul 2022 20:01:20 +0800 Subject: [PATCH 038/231] Fixed some typos in the code comment --- schema/schema.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/schema/schema.go b/schema/schema.go index eca113e9..3791237d 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -112,7 +112,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam schemaCacheKey = modelType } - // Load exist schmema cache, return if exists + // Load exist schema cache, return if exists if v, ok := cacheStore.Load(schemaCacheKey); ok { s := v.(*Schema) // Wait for the initialization of other goroutines to complete @@ -146,7 +146,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam // When the schema initialization is completed, the channel will be closed defer close(schema.initialized) - // Load exist schmema cache, return if exists + // Load exist schema cache, return if exists if v, ok := cacheStore.Load(schemaCacheKey); ok { s := v.(*Schema) // Wait for the initialization of other goroutines to complete From 6e03b97e266f30db994d8bc24bca2afd74a106b9 Mon Sep 17 00:00:00 2001 From: "hjwblog.com" Date: Wed, 27 Jul 2022 13:59:47 +0800 Subject: [PATCH 039/231] fix: empty serilizer err #5524 (#5525) * fix: empty serilizer err #5524 * feat: fix UnixSecondSerializer return nil * feat: split type case Co-authored-by: huanjiawei --- schema/field.go | 5 +---- schema/serializer.go | 10 ++++++++-- tests/go.mod | 1 - tests/serializer_test.go | 29 +++++++++++++++++++++++++++++ 4 files changed, 38 insertions(+), 7 deletions(-) diff --git a/schema/field.go b/schema/field.go index 47f3994f..1589d984 100644 --- a/schema/field.go +++ b/schema/field.go @@ -468,9 +468,6 @@ func (field *Field) setupValuerAndSetter() { oldValuerOf := field.ValueOf field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) { value, zero := oldValuerOf(ctx, v) - if zero { - return value, zero - } s, ok := value.(SerializerValuerInterface) if !ok { @@ -483,7 +480,7 @@ func (field *Field) setupValuerAndSetter() { Destination: v, Context: ctx, fieldValue: value, - }, false + }, zero } } diff --git a/schema/serializer.go b/schema/serializer.go index 21be3c35..00a4f85f 100644 --- a/schema/serializer.go +++ b/schema/serializer.go @@ -119,9 +119,15 @@ func (UnixSecondSerializer) Scan(ctx context.Context, field *Field, dst reflect. // Value implements serializer interface func (UnixSecondSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (result interface{}, err error) { + rv := reflect.ValueOf(fieldValue) switch v := fieldValue.(type) { - case int64, int, uint, uint64, int32, uint32, int16, uint16, *int64, *int, *uint, *uint64, *int32, *uint32, *int16, *uint16: - result = time.Unix(reflect.Indirect(reflect.ValueOf(v)).Int(), 0) + case int64, int, uint, uint64, int32, uint32, int16, uint16: + result = time.Unix(reflect.Indirect(rv).Int(), 0) + case *int64, *int, *uint, *uint64, *int32, *uint32, *int16, *uint16: + if rv.IsZero() { + return nil, nil + } + result = time.Unix(reflect.Indirect(rv).Int(), 0) default: err = fmt.Errorf("invalid field type %#v for UnixSecondSerializer, only int, uint supported", v) } diff --git a/tests/go.mod b/tests/go.mod index 7a788a43..eb8f336d 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -3,7 +3,6 @@ module gorm.io/gorm/tests go 1.14 require ( - github.com/denisenkom/go-mssqldb v0.12.2 // indirect github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect github.com/google/uuid v1.3.0 github.com/jinzhu/now v1.1.5 diff --git a/tests/serializer_test.go b/tests/serializer_test.go index 95d25699..946536bf 100644 --- a/tests/serializer_test.go +++ b/tests/serializer_test.go @@ -123,6 +123,35 @@ func TestSerializer(t *testing.T) { } } +func TestSerializerZeroValue(t *testing.T) { + schema.RegisterSerializer("custom", NewCustomSerializer("hello")) + DB.Migrator().DropTable(&SerializerStruct{}) + if err := DB.Migrator().AutoMigrate(&SerializerStruct{}); err != nil { + t.Fatalf("no error should happen when migrate scanner, valuer struct, got error %v", err) + } + + data := SerializerStruct{} + + if err := DB.Create(&data).Error; err != nil { + t.Fatalf("failed to create data, got error %v", err) + } + + var result SerializerStruct + if err := DB.First(&result, data.ID).Error; err != nil { + t.Fatalf("failed to query data, got error %v", err) + } + + AssertEqual(t, result, data) + + if err := DB.Model(&result).Update("roles", "").Error; err != nil { + t.Fatalf("failed to update data's roles, got error %v", err) + } + + if err := DB.First(&result, data.ID).Error; err != nil { + t.Fatalf("failed to query data, got error %v", err) + } +} + func TestSerializerAssignFirstOrCreate(t *testing.T) { schema.RegisterSerializer("custom", NewCustomSerializer("hello")) DB.Migrator().DropTable(&SerializerStruct{}) From f22327938485f1673eab443949ae92367293c566 Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Wed, 10 Aug 2022 11:03:42 +0800 Subject: [PATCH 040/231] chore: fix gorm tag (#5577) --- utils/tests/models.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/utils/tests/models.go b/utils/tests/models.go index 22e8e659..ec1651a3 100644 --- a/utils/tests/models.go +++ b/utils/tests/models.go @@ -64,8 +64,8 @@ type Language struct { type Coupon struct { ID int `gorm:"primarykey; size:255"` AppliesToProduct []*CouponProduct `gorm:"foreignKey:CouponId;constraint:OnDelete:CASCADE"` - AmountOff uint32 `gorm:"amount_off"` - PercentOff float32 `gorm:"percent_off"` + AmountOff uint32 `gorm:"column:amount_off"` + PercentOff float32 `gorm:"column:percent_off"` } type CouponProduct struct { From a35883590b7f9467bedf43b9611b2c0d0ff30ffd Mon Sep 17 00:00:00 2001 From: Bruce MacKenzie Date: Wed, 10 Aug 2022 23:38:04 -0400 Subject: [PATCH 041/231] update Delete Godoc to describe soft delete behaviour (#5554) --- finisher_api.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/finisher_api.go b/finisher_api.go index af9afb63..bdf0437d 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -388,7 +388,9 @@ func (db *DB) UpdateColumns(values interface{}) (tx *DB) { return tx.callbacks.Update().Execute(tx) } -// Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition +// Delete deletes value matching given conditions. If value contains primary key it is included in the conditions. +// If value includes a deleted_at field, then Delete performs a soft delete instead by setting deleted_at with the current +// time if null. func (db *DB) Delete(value interface{}, conds ...interface{}) (tx *DB) { tx = db.getInstance() if len(conds) > 0 { From 573b9fa536050c156968b4d228cab05a119d78df Mon Sep 17 00:00:00 2001 From: enwawerueli Date: Fri, 12 Aug 2022 16:46:18 +0300 Subject: [PATCH 042/231] fix: correct grammar --- gorm.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gorm.go b/gorm.go index c852e60c..1f1dac21 100644 --- a/gorm.go +++ b/gorm.go @@ -413,7 +413,7 @@ func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interfac relation, ok := modelSchema.Relationships.Relations[field] isRelation := ok && relation.JoinTable != nil if !isRelation { - return fmt.Errorf("failed to found relation: %s", field) + return fmt.Errorf("failed to find relation: %s", field) } for _, ref := range relation.References { From ba227e8939d05f249a3ede8901193801d8da8603 Mon Sep 17 00:00:00 2001 From: Aoang Date: Mon, 15 Aug 2022 10:46:57 +0800 Subject: [PATCH 043/231] Add Go 1.19 Support (#5608) --- .github/workflows/tests.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index b97da3f4..367f4ccd 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -16,7 +16,7 @@ jobs: sqlite: strategy: matrix: - go: ['1.18', '1.17', '1.16'] + go: ['1.19', '1.18', '1.17', '1.16'] platform: [ubuntu-latest] # can not run in windows OS runs-on: ${{ matrix.platform }} @@ -42,7 +42,7 @@ jobs: strategy: matrix: dbversion: ['mysql:latest', 'mysql:5.7', 'mariadb:latest'] - go: ['1.18', '1.17', '1.16'] + go: ['1.19', '1.18', '1.17', '1.16'] platform: [ubuntu-latest] runs-on: ${{ matrix.platform }} @@ -86,7 +86,7 @@ jobs: strategy: matrix: dbversion: ['postgres:latest', 'postgres:13', 'postgres:12', 'postgres:11', 'postgres:10'] - go: ['1.18', '1.17', '1.16'] + go: ['1.19', '1.18', '1.17', '1.16'] platform: [ubuntu-latest] # can not run in macOS and Windows runs-on: ${{ matrix.platform }} @@ -128,7 +128,7 @@ jobs: sqlserver: strategy: matrix: - go: ['1.18', '1.17', '1.16'] + go: ['1.19', '1.18', '1.17', '1.16'] platform: [ubuntu-latest] # can not run test in macOS and windows runs-on: ${{ matrix.platform }} From 3f92b9b0df84736750d6645e074596a7383ae089 Mon Sep 17 00:00:00 2001 From: Shunsuke Otani Date: Mon, 15 Aug 2022 11:47:26 +0900 Subject: [PATCH 044/231] Refactor: redundant type from composite literal (#5604) --- statement_test.go | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/statement_test.go b/statement_test.go index a537c7be..761daf37 100644 --- a/statement_test.go +++ b/statement_test.go @@ -37,18 +37,18 @@ func TestWhereCloneCorruption(t *testing.T) { func TestNameMatcher(t *testing.T) { for k, v := range map[string][]string{ - "table.name": []string{"table", "name"}, - "`table`.`name`": []string{"table", "name"}, - "'table'.'name'": []string{"table", "name"}, - "'table'.name": []string{"table", "name"}, - "table1.name_23": []string{"table1", "name_23"}, - "`table_1`.`name23`": []string{"table_1", "name23"}, - "'table23'.'name_1'": []string{"table23", "name_1"}, - "'table23'.name1": []string{"table23", "name1"}, - "'name1'": []string{"", "name1"}, - "`name_1`": []string{"", "name_1"}, - "`Name_1`": []string{"", "Name_1"}, - "`Table`.`nAme`": []string{"Table", "nAme"}, + "table.name": {"table", "name"}, + "`table`.`name`": {"table", "name"}, + "'table'.'name'": {"table", "name"}, + "'table'.name": {"table", "name"}, + "table1.name_23": {"table1", "name_23"}, + "`table_1`.`name23`": {"table_1", "name23"}, + "'table23'.'name_1'": {"table23", "name_1"}, + "'table23'.name1": {"table23", "name1"}, + "'name1'": {"", "name1"}, + "`name_1`": {"", "name_1"}, + "`Name_1`": {"", "Name_1"}, + "`Table`.`nAme`": {"Table", "nAme"}, } { if matches := nameMatcher.FindStringSubmatch(k); len(matches) < 3 || matches[1] != v[0] || matches[2] != v[1] { t.Errorf("failed to match value: %v, got %v, expect: %v", k, matches, v) From 8c3018b96aea241a35b769291de6edd2a3378b44 Mon Sep 17 00:00:00 2001 From: Shunsuke Otani Date: Mon, 15 Aug 2022 11:50:06 +0900 Subject: [PATCH 045/231] Replace `ioutil.Discard` with `io.Discard` (#5603) --- go.mod | 2 +- logger/logger.go | 6 +++--- tests/go.mod | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/go.mod b/go.mod index 57362745..03f84379 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module gorm.io/gorm -go 1.14 +go 1.16 require ( github.com/jinzhu/inflection v1.0.0 diff --git a/logger/logger.go b/logger/logger.go index 2ffd28d5..ce088561 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -4,7 +4,7 @@ import ( "context" "errors" "fmt" - "io/ioutil" + "io" "log" "os" "time" @@ -68,8 +68,8 @@ type Interface interface { } var ( - // Discard Discard logger will print any log to ioutil.Discard - Discard = New(log.New(ioutil.Discard, "", log.LstdFlags), Config{}) + // Discard Discard logger will print any log to io.Discard + Discard = New(log.New(io.Discard, "", log.LstdFlags), Config{}) // Default Default logger Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{ SlowThreshold: 200 * time.Millisecond, diff --git a/tests/go.mod b/tests/go.mod index eb8f336d..19280434 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -1,6 +1,6 @@ module gorm.io/gorm/tests -go 1.14 +go 1.16 require ( github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect From d71caef7d9d08287971a129bc19068eb1f48ed8f Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Sat, 3 Sep 2022 20:00:21 +0800 Subject: [PATCH 046/231] fix: remove uuid autoincrement (#5620) --- tests/postgres_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/postgres_test.go b/tests/postgres_test.go index 66b988c3..97af6db3 100644 --- a/tests/postgres_test.go +++ b/tests/postgres_test.go @@ -63,13 +63,13 @@ func TestPostgres(t *testing.T) { } type Post struct { - ID uuid.UUID `gorm:"primary_key;type:uuid;default:uuid_generate_v4();autoincrement"` + ID uuid.UUID `gorm:"primary_key;type:uuid;default:uuid_generate_v4();"` Title string Categories []*Category `gorm:"Many2Many:post_categories"` } type Category struct { - ID uuid.UUID `gorm:"primary_key;type:uuid;default:uuid_generate_v4();autoincrement"` + ID uuid.UUID `gorm:"primary_key;type:uuid;default:uuid_generate_v4();"` Title string Posts []*Post `gorm:"Many2Many:post_categories"` } From f78f635fae6f332a76e8f3e38d939864d1f5c209 Mon Sep 17 00:00:00 2001 From: "jesse.tang" <1430482733@qq.com> Date: Mon, 5 Sep 2022 15:34:33 +0800 Subject: [PATCH 047/231] Optimize: code logic db.scanIntoStruct() (#5633) --- scan.go | 46 ++++++++++++++++++++++++---------------------- 1 file changed, 24 insertions(+), 22 deletions(-) diff --git a/scan.go b/scan.go index 6250fb57..2db43160 100644 --- a/scan.go +++ b/scan.go @@ -66,30 +66,32 @@ func (db *DB) scanIntoStruct(rows Rows, reflectValue reflect.Value, values []int db.RowsAffected++ db.AddError(rows.Scan(values...)) - joinedSchemaMap := make(map[*schema.Field]interface{}, 0) + joinedSchemaMap := make(map[*schema.Field]interface{}) for idx, field := range fields { - if field != nil { - if len(joinFields) == 0 || joinFields[idx][0] == nil { - db.AddError(field.Set(db.Statement.Context, reflectValue, values[idx])) - } else { - joinSchema := joinFields[idx][0] - relValue := joinSchema.ReflectValueOf(db.Statement.Context, reflectValue) - if relValue.Kind() == reflect.Ptr { - if _, ok := joinedSchemaMap[joinSchema]; !ok { - if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() { - continue - } - - relValue.Set(reflect.New(relValue.Type().Elem())) - joinedSchemaMap[joinSchema] = nil - } - } - db.AddError(joinFields[idx][1].Set(db.Statement.Context, relValue, values[idx])) - } - - // release data to pool - field.NewValuePool.Put(values[idx]) + if field == nil { + continue } + + if len(joinFields) == 0 || joinFields[idx][0] == nil { + db.AddError(field.Set(db.Statement.Context, reflectValue, values[idx])) + } else { + joinSchema := joinFields[idx][0] + relValue := joinSchema.ReflectValueOf(db.Statement.Context, reflectValue) + if relValue.Kind() == reflect.Ptr { + if _, ok := joinedSchemaMap[joinSchema]; !ok { + if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() { + continue + } + + relValue.Set(reflect.New(relValue.Type().Elem())) + joinedSchemaMap[joinSchema] = nil + } + } + db.AddError(joinFields[idx][1].Set(db.Statement.Context, relValue, values[idx])) + } + + // release data to pool + field.NewValuePool.Put(values[idx]) } } From b3eb1c8c512430c1600f720a96b2af777c91d1da Mon Sep 17 00:00:00 2001 From: Jiepeng Cao Date: Mon, 5 Sep 2022 15:39:19 +0800 Subject: [PATCH 048/231] simplified regexp (#5677) --- migrator/migrator.go | 2 +- statement.go | 2 +- tests/upsert_test.go | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/migrator/migrator.go b/migrator/migrator.go index 87ac7745..c1d7e0e7 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -15,7 +15,7 @@ import ( ) var ( - regFullDataType = regexp.MustCompile(`[^\d]*(\d+)[^\d]?`) + regFullDataType = regexp.MustCompile(`\D*(\d+)\D?`) ) // Migrator m struct diff --git a/statement.go b/statement.go index 12687810..cc26fe37 100644 --- a/statement.go +++ b/statement.go @@ -650,7 +650,7 @@ func (stmt *Statement) Changed(fields ...string) bool { return false } -var nameMatcher = regexp.MustCompile(`^(?:[\W]?([A-Za-z_0-9]+?)[\W]?\.)?[\W]?([A-Za-z_0-9]+?)[\W]?$`) +var nameMatcher = regexp.MustCompile(`^(?:\W?(\w+?)\W?\.)?\W?(\w+?)\W?$`) // SelectAndOmitColumns get select and omit columns, select -> true, omit -> false func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (map[string]bool, bool) { diff --git a/tests/upsert_test.go b/tests/upsert_test.go index f90c4518..e84dc14a 100644 --- a/tests/upsert_test.go +++ b/tests/upsert_test.go @@ -62,7 +62,7 @@ func TestUpsert(t *testing.T) { } r := DB.Session(&gorm.Session{DryRun: true}).Clauses(clause.OnConflict{UpdateAll: true}).Create(&RestrictedLanguage{Code: "upsert_code", Name: "upsert_name", Lang: "upsert_lang"}) - if !regexp.MustCompile(`INTO .restricted_languages. .*\(.code.,.name.,.lang.\) .* (SET|UPDATE) .name.=.*.name.[^\w]*$`).MatchString(r.Statement.SQL.String()) { + if !regexp.MustCompile(`INTO .restricted_languages. .*\(.code.,.name.,.lang.\) .* (SET|UPDATE) .name.=.*.name.\W*$`).MatchString(r.Statement.SQL.String()) { t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) } } From f29afdd3297d94b3e789e1f8d0ab8c823325eba5 Mon Sep 17 00:00:00 2001 From: Bruce MacKenzie Date: Thu, 8 Sep 2022 23:16:41 -0400 Subject: [PATCH 049/231] Rewrite of finisher_api Godocs (#5618) --- finisher_api.go | 49 +++++++++++++++++++++++++++---------------------- 1 file changed, 27 insertions(+), 22 deletions(-) diff --git a/finisher_api.go b/finisher_api.go index bdf0437d..835a6984 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -13,7 +13,7 @@ import ( "gorm.io/gorm/utils" ) -// Create insert the value into database +// Create inserts value, returning the inserted data's primary key in value's id func (db *DB) Create(value interface{}) (tx *DB) { if db.CreateBatchSize > 0 { return db.CreateInBatches(value, db.CreateBatchSize) @@ -24,7 +24,7 @@ func (db *DB) Create(value interface{}) (tx *DB) { return tx.callbacks.Create().Execute(tx) } -// CreateInBatches insert the value in batches into database +// CreateInBatches inserts value in batches of batchSize func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) { reflectValue := reflect.Indirect(reflect.ValueOf(value)) @@ -68,7 +68,7 @@ func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) { return } -// Save update value in database, if the value doesn't have primary key, will insert it +// Save updates value in database. If value doesn't contain a matching primary key, value is inserted. func (db *DB) Save(value interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.Dest = value @@ -114,7 +114,7 @@ func (db *DB) Save(value interface{}) (tx *DB) { return } -// First find first record that match given conditions, order by primary key +// First finds the first record ordered by primary key, matching given conditions conds func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) { tx = db.Limit(1).Order(clause.OrderByColumn{ Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, @@ -129,7 +129,7 @@ func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) { return tx.callbacks.Query().Execute(tx) } -// Take return a record that match given conditions, the order will depend on the database implementation +// Take finds the first record returned by the database in no specified order, matching given conditions conds func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) { tx = db.Limit(1) if len(conds) > 0 { @@ -142,7 +142,7 @@ func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) { return tx.callbacks.Query().Execute(tx) } -// Last find last record that match given conditions, order by primary key +// Last finds the last record ordered by primary key, matching given conditions conds func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) { tx = db.Limit(1).Order(clause.OrderByColumn{ Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, @@ -158,7 +158,7 @@ func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) { return tx.callbacks.Query().Execute(tx) } -// Find find records that match given conditions +// Find finds all records matching given conditions conds func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) { tx = db.getInstance() if len(conds) > 0 { @@ -170,7 +170,7 @@ func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) { return tx.callbacks.Query().Execute(tx) } -// FindInBatches find records in batches +// FindInBatches finds all records in batches of batchSize func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, batch int) error) *DB { var ( tx = db.Order(clause.OrderByColumn{ @@ -286,7 +286,8 @@ func (db *DB) assignInterfacesToValue(values ...interface{}) { } } -// FirstOrInit gets the first matched record or initialize a new instance with given conditions (only works with struct or map conditions) +// FirstOrInit finds the first matching record, otherwise if not found initializes a new instance with given conds. +// Each conds must be a struct or map. func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { queryTx := db.Limit(1).Order(clause.OrderByColumn{ Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, @@ -312,7 +313,8 @@ func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { return } -// FirstOrCreate gets the first matched record or create a new one with given conditions (only works with struct, map conditions) +// FirstOrCreate finds the first matching record, otherwise if not found creates a new instance with given conds. +// Each conds must be a struct or map. func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) { tx = db.getInstance() queryTx := db.Session(&Session{}).Limit(1).Order(clause.OrderByColumn{ @@ -360,14 +362,14 @@ func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) { return tx } -// Update update attributes with callbacks, refer: https://gorm.io/docs/update.html#Update-Changed-Fields +// Update updates column with value using callbacks. Reference: https://gorm.io/docs/update.html#Update-Changed-Fields func (db *DB) Update(column string, value interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.Dest = map[string]interface{}{column: value} return tx.callbacks.Update().Execute(tx) } -// Updates update attributes with callbacks, refer: https://gorm.io/docs/update.html#Update-Changed-Fields +// Updates updates attributes using callbacks. values must be a struct or map. Reference: https://gorm.io/docs/update.html#Update-Changed-Fields func (db *DB) Updates(values interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.Dest = values @@ -388,8 +390,8 @@ func (db *DB) UpdateColumns(values interface{}) (tx *DB) { return tx.callbacks.Update().Execute(tx) } -// Delete deletes value matching given conditions. If value contains primary key it is included in the conditions. -// If value includes a deleted_at field, then Delete performs a soft delete instead by setting deleted_at with the current +// Delete deletes value matching given conditions. If value contains primary key it is included in the conditions. If +// value includes a deleted_at field, then Delete performs a soft delete instead by setting deleted_at with the current // time if null. func (db *DB) Delete(value interface{}, conds ...interface{}) (tx *DB) { tx = db.getInstance() @@ -484,7 +486,7 @@ func (db *DB) Rows() (*sql.Rows, error) { return rows, tx.Error } -// Scan scan value to a struct +// Scan scans selected value to the struct dest func (db *DB) Scan(dest interface{}) (tx *DB) { config := *db.Config currentLogger, newLogger := config.Logger, logger.Recorder.New() @@ -509,7 +511,7 @@ func (db *DB) Scan(dest interface{}) (tx *DB) { return } -// Pluck used to query single column from a model as a map +// Pluck queries a single column from a model, returning in the slice dest. E.g.: // var ages []int64 // db.Model(&users).Pluck("age", &ages) func (db *DB) Pluck(column string, dest interface{}) (tx *DB) { @@ -552,7 +554,8 @@ func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error { return tx.Error } -// Connection use a db conn to execute Multiple commands,this conn will put conn pool after it is executed. +// Connection uses a db connection to execute an arbitrary number of commands in fc. When finished, the connection is +// returned to the connection pool. func (db *DB) Connection(fc func(tx *DB) error) (err error) { if db.Error != nil { return db.Error @@ -574,7 +577,9 @@ func (db *DB) Connection(fc func(tx *DB) error) (err error) { return fc(tx) } -// Transaction start a transaction as a block, return error will rollback, otherwise to commit. +// Transaction start a transaction as a block, return error will rollback, otherwise to commit. Transaction executes an +// arbitrary number of commands in fc within a transaction. On success the changes are committed; if an error occurs +// they are rolled back. func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err error) { panicked := true @@ -617,7 +622,7 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er return } -// Begin begins a transaction +// Begin begins a transaction with any transaction options opts func (db *DB) Begin(opts ...*sql.TxOptions) *DB { var ( // clone statement @@ -646,7 +651,7 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB { return tx } -// Commit commit a transaction +// Commit commits the changes in a transaction func (db *DB) Commit() *DB { if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil && !reflect.ValueOf(committer).IsNil() { db.AddError(committer.Commit()) @@ -656,7 +661,7 @@ func (db *DB) Commit() *DB { return db } -// Rollback rollback a transaction +// Rollback rollbacks the changes in a transaction func (db *DB) Rollback() *DB { if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil { if !reflect.ValueOf(committer).IsNil() { @@ -686,7 +691,7 @@ func (db *DB) RollbackTo(name string) *DB { return db } -// Exec execute raw sql +// Exec executes raw sql func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.SQL = strings.Builder{} From edb00c10adff38445c4350c0cb524faa6ec2d592 Mon Sep 17 00:00:00 2001 From: Googol Lee Date: Wed, 14 Sep 2022 04:26:51 +0200 Subject: [PATCH 050/231] AutoMigrate() should always migrate checks, even there is no relationship constraints. (#5644) * fix: remove uuid autoincrement * AutoMigrate() should always migrate checks, even there is no relationship constranits. Co-authored-by: a631807682 <631807682@qq.com> --- migrator/migrator.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/migrator/migrator.go b/migrator/migrator.go index c1d7e0e7..e6782a13 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -135,12 +135,12 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { } } } + } - for _, chk := range stmt.Schema.ParseCheckConstraints() { - if !tx.Migrator().HasConstraint(value, chk.Name) { - if err := tx.Migrator().CreateConstraint(value, chk.Name); err != nil { - return err - } + for _, chk := range stmt.Schema.ParseCheckConstraints() { + if !tx.Migrator().HasConstraint(value, chk.Name) { + if err := tx.Migrator().CreateConstraint(value, chk.Name); err != nil { + return err } } } From 490625981a1c3474eeca7f2e4fde791cd94c84fa Mon Sep 17 00:00:00 2001 From: qqxhb <30866940+qqxhb@users.noreply.github.com> Date: Fri, 16 Sep 2022 15:02:44 +0800 Subject: [PATCH 051/231] fix: update omit (#5699) --- callbacks/update.go | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/callbacks/update.go b/callbacks/update.go index 48c61bf4..b596df9a 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -70,10 +70,12 @@ func Update(config *Config) func(db *gorm.DB) { if db.Statement.SQL.Len() == 0 { db.Statement.SQL.Grow(180) db.Statement.AddClauseIfNotExists(clause.Update{}) - if set := ConvertToAssignments(db.Statement); len(set) != 0 { - db.Statement.AddClause(set) - } else if _, ok := db.Statement.Clauses["SET"]; !ok { - return + if _, ok := db.Statement.Clauses["SET"]; !ok { + if set := ConvertToAssignments(db.Statement); len(set) != 0 { + db.Statement.AddClause(set) + } else { + return + } } db.Statement.Build(db.Statement.BuildClauses...) From 5ed7b1a65e2aeeb92bb12f2b1ebcac2e4d3402fe Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Thu, 22 Sep 2022 11:25:03 +0800 Subject: [PATCH 052/231] fix: same embedded filed name (#5705) --- migrator/migrator.go | 2 +- tests/migrate_test.go | 38 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 1 deletion(-) diff --git a/migrator/migrator.go b/migrator/migrator.go index e6782a13..d7ebf276 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -478,7 +478,7 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy } if alterColumn && !field.IgnoreMigration { - return m.DB.Migrator().AlterColumn(value, field.Name) + return m.DB.Migrator().AlterColumn(value, field.DBName) } return nil diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 0b5bc5eb..32e84e77 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -959,3 +959,41 @@ func TestMigrateArrayTypeModel(t *testing.T) { AssertEqual(t, nil, err) AssertEqual(t, "integer[]", ct.DatabaseTypeName()) } + +func TestMigrateSameEmbeddedFieldName(t *testing.T) { + type UserStat struct { + GroundDestroyCount int + } + + type GameUser struct { + gorm.Model + StatAb UserStat `gorm:"embedded;embeddedPrefix:stat_ab_"` + } + + type UserStat1 struct { + GroundDestroyCount string + } + + type GroundRate struct { + GroundDestroyCount int + } + + type GameUser1 struct { + gorm.Model + StatAb UserStat1 `gorm:"embedded;embeddedPrefix:stat_ab_"` + GroundRateRb GroundRate `gorm:"embedded;embeddedPrefix:rate_ground_rb_"` + } + + DB.Migrator().DropTable(&GameUser{}) + err := DB.AutoMigrate(&GameUser{}) + AssertEqual(t, nil, err) + + err = DB.Table("game_users").AutoMigrate(&GameUser1{}) + AssertEqual(t, nil, err) + + _, err = findColumnType(&GameUser{}, "stat_ab_ground_destory_count") + AssertEqual(t, nil, err) + + _, err = findColumnType(&GameUser{}, "rate_ground_rb_ground_destory_count") + AssertEqual(t, nil, err) +} From 1f634c39377f914187ae9efb1bc1bdbc94e97028 Mon Sep 17 00:00:00 2001 From: "jesse.tang" <1430482733@qq.com> Date: Thu, 22 Sep 2022 14:50:35 +0800 Subject: [PATCH 053/231] support scan assign slice cap (#5634) * support scan assign slice cap * fix --- scan.go | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/scan.go b/scan.go index 2db43160..df5a3714 100644 --- a/scan.go +++ b/scan.go @@ -248,7 +248,13 @@ func Scan(rows Rows, db *DB, mode ScanMode) { if !update || reflectValue.Len() == 0 { update = false - db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20)) + // if the slice cap is externally initialized, the externally initialized slice is directly used here + if reflectValue.Cap() == 0 { + db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20)) + } else { + reflectValue.SetLen(0) + db.Statement.ReflectValue.Set(reflectValue) + } } for initialized || rows.Next() { From 3a72ba102ec1ce729f703be4ac00e0049b82b0e2 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 21 Sep 2022 17:29:38 +0800 Subject: [PATCH 054/231] Allow shared foreign key for many2many jointable --- schema/relationship.go | 60 ++++++++++++++++++++++--------------- schema/relationship_test.go | 29 +++++++++++++++++- tests/go.mod | 13 ++++---- 3 files changed, 71 insertions(+), 31 deletions(-) diff --git a/schema/relationship.go b/schema/relationship.go index 0aa33e51..bb8aeb64 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -191,7 +191,8 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel err error joinTableFields []reflect.StructField fieldsMap = map[string]*Field{} - ownFieldsMap = map[string]bool{} // fix self join many2many + ownFieldsMap = map[string]*Field{} // fix self join many2many + referFieldsMap = map[string]*Field{} joinForeignKeys = toColumns(field.TagSettings["JOINFOREIGNKEY"]) joinReferences = toColumns(field.TagSettings["JOINREFERENCES"]) ) @@ -229,7 +230,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel joinFieldName = strings.Title(joinForeignKeys[idx]) } - ownFieldsMap[joinFieldName] = true + ownFieldsMap[joinFieldName] = ownField fieldsMap[joinFieldName] = ownField joinTableFields = append(joinTableFields, reflect.StructField{ Name: joinFieldName, @@ -242,9 +243,6 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel for idx, relField := range refForeignFields { joinFieldName := strings.Title(relation.FieldSchema.Name) + relField.Name - if len(joinReferences) > idx { - joinFieldName = strings.Title(joinReferences[idx]) - } if _, ok := ownFieldsMap[joinFieldName]; ok { if field.Name != relation.FieldSchema.Name { @@ -254,14 +252,22 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel } } - fieldsMap[joinFieldName] = relField - joinTableFields = append(joinTableFields, reflect.StructField{ - Name: joinFieldName, - PkgPath: relField.StructField.PkgPath, - Type: relField.StructField.Type, - Tag: removeSettingFromTag(appendSettingFromTag(relField.StructField.Tag, "primaryKey"), - "column", "autoincrement", "index", "unique", "uniqueindex"), - }) + if len(joinReferences) > idx { + joinFieldName = strings.Title(joinReferences[idx]) + } + + referFieldsMap[joinFieldName] = relField + + if _, ok := fieldsMap[joinFieldName]; !ok { + fieldsMap[joinFieldName] = relField + joinTableFields = append(joinTableFields, reflect.StructField{ + Name: joinFieldName, + PkgPath: relField.StructField.PkgPath, + Type: relField.StructField.Type, + Tag: removeSettingFromTag(appendSettingFromTag(relField.StructField.Tag, "primaryKey"), + "column", "autoincrement", "index", "unique", "uniqueindex"), + }) + } } joinTableFields = append(joinTableFields, reflect.StructField{ @@ -317,31 +323,37 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel f.Size = fieldsMap[f.Name].Size } relation.JoinTable.PrimaryFields = append(relation.JoinTable.PrimaryFields, f) - ownPrimaryField := schema == fieldsMap[f.Name].Schema && ownFieldsMap[f.Name] - if ownPrimaryField { + if of, ok := ownFieldsMap[f.Name]; ok { joinRel := relation.JoinTable.Relationships.Relations[relName] joinRel.Field = relation.Field joinRel.References = append(joinRel.References, &Reference{ - PrimaryKey: fieldsMap[f.Name], + PrimaryKey: of, ForeignKey: f, }) - } else { + + relation.References = append(relation.References, &Reference{ + PrimaryKey: of, + ForeignKey: f, + OwnPrimaryKey: true, + }) + } + + if rf, ok := referFieldsMap[f.Name]; ok { joinRefRel := relation.JoinTable.Relationships.Relations[relRefName] if joinRefRel.Field == nil { joinRefRel.Field = relation.Field } joinRefRel.References = append(joinRefRel.References, &Reference{ - PrimaryKey: fieldsMap[f.Name], + PrimaryKey: rf, + ForeignKey: f, + }) + + relation.References = append(relation.References, &Reference{ + PrimaryKey: rf, ForeignKey: f, }) } - - relation.References = append(relation.References, &Reference{ - PrimaryKey: fieldsMap[f.Name], - ForeignKey: f, - OwnPrimaryKey: ownPrimaryField, - }) } } } diff --git a/schema/relationship_test.go b/schema/relationship_test.go index 6fffbfcb..85c45589 100644 --- a/schema/relationship_test.go +++ b/schema/relationship_test.go @@ -10,7 +10,7 @@ import ( func checkStructRelation(t *testing.T, data interface{}, relations ...Relation) { if s, err := schema.Parse(data, &sync.Map{}, schema.NamingStrategy{}); err != nil { - t.Errorf("Failed to parse schema") + t.Errorf("Failed to parse schema, got error %v", err) } else { for _, rel := range relations { checkSchemaRelation(t, s, rel) @@ -305,6 +305,33 @@ func TestMany2ManyOverrideForeignKey(t *testing.T) { }) } +func TestMany2ManySharedForeignKey(t *testing.T) { + type Profile struct { + gorm.Model + Name string + Kind string + ProfileRefer uint + } + + type User struct { + gorm.Model + Profiles []Profile `gorm:"many2many:user_profiles;foreignKey:Refer,Kind;joinForeignKey:UserRefer,Kind;References:ProfileRefer,Kind;joinReferences:ProfileR,Kind"` + Kind string + Refer uint + } + + checkStructRelation(t, &User{}, Relation{ + Name: "Profiles", Type: schema.Many2Many, Schema: "User", FieldSchema: "Profile", + JoinTable: JoinTable{Name: "user_profiles", Table: "user_profiles"}, + References: []Reference{ + {"Refer", "User", "UserRefer", "user_profiles", "", true}, + {"Kind", "User", "Kind", "user_profiles", "", true}, + {"ProfileRefer", "Profile", "ProfileR", "user_profiles", "", false}, + {"Kind", "Profile", "Kind", "user_profiles", "", false}, + }, + }) +} + func TestMany2ManyOverrideJoinForeignKey(t *testing.T) { type Profile struct { gorm.Model diff --git a/tests/go.mod b/tests/go.mod index 19280434..ebebabc0 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -3,17 +3,18 @@ module gorm.io/gorm/tests go 1.16 require ( + github.com/denisenkom/go-mssqldb v0.12.2 // indirect github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect github.com/google/uuid v1.3.0 github.com/jinzhu/now v1.1.5 - github.com/lib/pq v1.10.6 - github.com/mattn/go-sqlite3 v1.14.14 // indirect - golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d // indirect - gorm.io/driver/mysql v1.3.5 - gorm.io/driver/postgres v1.3.8 + github.com/lib/pq v1.10.7 + github.com/mattn/go-sqlite3 v1.14.15 // indirect + golang.org/x/crypto v0.0.0-20220919173607-35f4265a4bc0 // indirect + gorm.io/driver/mysql v1.3.6 + gorm.io/driver/postgres v1.3.10 gorm.io/driver/sqlite v1.3.6 gorm.io/driver/sqlserver v1.3.2 - gorm.io/gorm v1.23.8 + gorm.io/gorm v1.23.9 ) replace gorm.io/gorm => ../ From 101a7c789fa2c41f409da439056806756fd8ce22 Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Thu, 22 Sep 2022 15:51:47 +0800 Subject: [PATCH 055/231] fix: scan array (#5624) Co-authored-by: Jinzhu --- scan.go | 22 +++++++++++++++------- tests/query_test.go | 24 ++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 7 deletions(-) diff --git a/scan.go b/scan.go index df5a3714..70cd4284 100644 --- a/scan.go +++ b/scan.go @@ -243,15 +243,18 @@ func Scan(rows Rows, db *DB, mode ScanMode) { switch reflectValue.Kind() { case reflect.Slice, reflect.Array: - var elem reflect.Value - recyclableStruct := reflect.New(reflectValueType) + var ( + elem reflect.Value + recyclableStruct = reflect.New(reflectValueType) + isArrayKind = reflectValue.Kind() == reflect.Array + ) if !update || reflectValue.Len() == 0 { update = false // if the slice cap is externally initialized, the externally initialized slice is directly used here if reflectValue.Cap() == 0 { db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20)) - } else { + } else if !isArrayKind { reflectValue.SetLen(0) db.Statement.ReflectValue.Set(reflectValue) } @@ -285,10 +288,15 @@ func Scan(rows Rows, db *DB, mode ScanMode) { db.scanIntoStruct(rows, elem, values, fields, joinFields) if !update { - if isPtr { - reflectValue = reflect.Append(reflectValue, elem) + if !isPtr { + elem = elem.Elem() + } + if isArrayKind { + if reflectValue.Len() >= int(db.RowsAffected) { + reflectValue.Index(int(db.RowsAffected - 1)).Set(elem) + } } else { - reflectValue = reflect.Append(reflectValue, elem.Elem()) + reflectValue = reflect.Append(reflectValue, elem) } } } @@ -312,4 +320,4 @@ func Scan(rows Rows, db *DB, mode ScanMode) { if db.RowsAffected == 0 && db.Statement.RaiseErrorOnNotFound && db.Error == nil { db.AddError(ErrRecordNotFound) } -} +} \ No newline at end of file diff --git a/tests/query_test.go b/tests/query_test.go index 4569fe1a..eccf0133 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -216,6 +216,30 @@ func TestFind(t *testing.T) { } } + // test array + var models2 [3]User + if err := DB.Where("name in (?)", []string{"find"}).Find(&models2).Error; err != nil || len(models2) != 3 { + t.Errorf("errors happened when query find with in clause: %v, length: %v", err, len(models2)) + } else { + for idx, user := range users { + t.Run("FindWithInClause#"+strconv.Itoa(idx+1), func(t *testing.T) { + CheckUser(t, models2[idx], user) + }) + } + } + + // test smaller array + var models3 [2]User + if err := DB.Where("name in (?)", []string{"find"}).Find(&models3).Error; err != nil || len(models3) != 2 { + t.Errorf("errors happened when query find with in clause: %v, length: %v", err, len(models3)) + } else { + for idx, user := range users[:2] { + t.Run("FindWithInClause#"+strconv.Itoa(idx+1), func(t *testing.T) { + CheckUser(t, models3[idx], user) + }) + } + } + var none []User if err := DB.Where("name in (?)", []string{}).Find(&none).Error; err != nil || len(none) != 0 { t.Errorf("errors happened when query find with in clause and zero length parameter: %v, length: %v", err, len(none)) From 73bc53f061ee1f54b9ef562a3466b5e3c5438aea Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Thu, 22 Sep 2022 15:56:32 +0800 Subject: [PATCH 056/231] feat: migrator support type aliases (#5627) * feat: migrator support type aliases * perf: check type --- migrator.go | 1 + migrator/migrator.go | 29 ++++++++++++++++++++++++++--- 2 files changed, 27 insertions(+), 3 deletions(-) diff --git a/migrator.go b/migrator.go index 34e888f2..882fc4cc 100644 --- a/migrator.go +++ b/migrator.go @@ -68,6 +68,7 @@ type Migrator interface { // Database CurrentDatabase() string FullDataTypeOf(*schema.Field) clause.Expr + GetTypeAliases(databaseTypeName string) []string // Tables CreateTable(dst ...interface{}) error diff --git a/migrator/migrator.go b/migrator/migrator.go index d7ebf276..29c0c00c 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -408,9 +408,27 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy alterColumn := false - // check type - if !field.PrimaryKey && !strings.HasPrefix(fullDataType, realDataType) { - alterColumn = true + if !field.PrimaryKey { + // check type + var isSameType bool + if strings.HasPrefix(fullDataType, realDataType) { + isSameType = true + } + + // check type aliases + if !isSameType { + aliases := m.DB.Migrator().GetTypeAliases(realDataType) + for _, alias := range aliases { + if strings.HasPrefix(fullDataType, alias) { + isSameType = true + break + } + } + } + + if !isSameType { + alterColumn = true + } } // check size @@ -863,3 +881,8 @@ func (m Migrator) CurrentTable(stmt *gorm.Statement) interface{} { func (m Migrator) GetIndexes(dst interface{}) ([]gorm.Index, error) { return nil, errors.New("not support") } + +// GetTypeAliases return database type aliases +func (m Migrator) GetTypeAliases(databaseTypeName string) []string { + return nil +} From 12237454ed695461eb750aee9fca6bac7faa8b8b Mon Sep 17 00:00:00 2001 From: kinggo Date: Thu, 22 Sep 2022 16:47:31 +0800 Subject: [PATCH 057/231] fix: use preparestmt in trasaction will use new conn, close #5508 --- gorm.go | 16 ++++++++++++---- tests/prepared_stmt_test.go | 17 +++++++++++++++++ 2 files changed, 29 insertions(+), 4 deletions(-) diff --git a/gorm.go b/gorm.go index 1f1dac21..81b6e2af 100644 --- a/gorm.go +++ b/gorm.go @@ -248,10 +248,18 @@ func (db *DB) Session(config *Session) *DB { if config.PrepareStmt { if v, ok := db.cacheStore.Load(preparedStmtDBKey); ok { preparedStmt := v.(*PreparedStmtDB) - tx.Statement.ConnPool = &PreparedStmtDB{ - ConnPool: db.Config.ConnPool, - Mux: preparedStmt.Mux, - Stmts: preparedStmt.Stmts, + switch t := tx.Statement.ConnPool.(type) { + case Tx: + tx.Statement.ConnPool = &PreparedStmtTX{ + Tx: t, + PreparedStmtDB: preparedStmt, + } + default: + tx.Statement.ConnPool = &PreparedStmtDB{ + ConnPool: db.Config.ConnPool, + Mux: preparedStmt.Mux, + Stmts: preparedStmt.Stmts, + } } txConfig.ConnPool = tx.Statement.ConnPool txConfig.PrepareStmt = true diff --git a/tests/prepared_stmt_test.go b/tests/prepared_stmt_test.go index 8730e547..86e3630d 100644 --- a/tests/prepared_stmt_test.go +++ b/tests/prepared_stmt_test.go @@ -2,6 +2,7 @@ package tests_test import ( "context" + "errors" "testing" "time" @@ -88,3 +89,19 @@ func TestPreparedStmtFromTransaction(t *testing.T) { } tx2.Commit() } + +func TestPreparedStmtInTransaction(t *testing.T) { + user := User{Name: "jinzhu"} + + if err := DB.Transaction(func(tx *gorm.DB) error { + tx.Session(&gorm.Session{PrepareStmt: true}).Create(&user) + return errors.New("test") + }); err == nil { + t.Error(err) + } + + var result User + if err := DB.First(&result, user.ID).Error; err == nil { + t.Errorf("Failed, got error: %v", err) + } +} From 328f3019825c95be6264cc94d3b4c32fe3cf61d1 Mon Sep 17 00:00:00 2001 From: Nguyen Huu Tuan <54979794+nohattee@users.noreply.github.com> Date: Thu, 22 Sep 2022 17:35:21 +0700 Subject: [PATCH 058/231] add some test case which related the logic (#5477) --- schema/schema.go | 8 +++++++ tests/postgres_test.go | 50 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 58 insertions(+) diff --git a/schema/schema.go b/schema/schema.go index 3791237d..42ff5c45 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -239,6 +239,14 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam field.HasDefaultValue = true field.AutoIncrement = true } + case String: + if _, ok := field.TagSettings["PRIMARYKEY"]; !ok { + if !field.HasDefaultValue || field.DefaultValueInterface != nil { + schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field) + } + + field.HasDefaultValue = true + } } } diff --git a/tests/postgres_test.go b/tests/postgres_test.go index 97af6db3..b5b672a9 100644 --- a/tests/postgres_test.go +++ b/tests/postgres_test.go @@ -9,6 +9,56 @@ import ( "gorm.io/gorm" ) +func TestPostgresReturningIDWhichHasStringType(t *testing.T) { + if DB.Dialector.Name() != "postgres" { + t.Skip() + } + + type Yasuo struct { + ID string `gorm:"default:gen_random_uuid()"` + Name string + CreatedAt time.Time `gorm:"type:TIMESTAMP WITHOUT TIME ZONE"` + UpdatedAt time.Time `gorm:"type:TIMESTAMP WITHOUT TIME ZONE;default:current_timestamp"` + } + + if err := DB.Exec("CREATE EXTENSION IF NOT EXISTS pgcrypto;").Error; err != nil { + t.Errorf("Failed to create extension pgcrypto, got error %v", err) + } + + DB.Migrator().DropTable(&Yasuo{}) + + if err := DB.AutoMigrate(&Yasuo{}); err != nil { + t.Fatalf("Failed to migrate for uuid default value, got error: %v", err) + } + + yasuo := Yasuo{Name: "jinzhu"} + if err := DB.Create(&yasuo).Error; err != nil { + t.Fatalf("should be able to create data, but got %v", err) + } + + if yasuo.ID == "" { + t.Fatal("should be able to has ID, but got zero value") + } + + var result Yasuo + if err := DB.First(&result, "id = ?", yasuo.ID).Error; err != nil || yasuo.Name != "jinzhu" { + t.Errorf("No error should happen, but got %v", err) + } + + if err := DB.Where("id = $1", yasuo.ID).First(&Yasuo{}).Error; err != nil || yasuo.Name != "jinzhu" { + t.Errorf("No error should happen, but got %v", err) + } + + yasuo.Name = "jinzhu1" + if err := DB.Save(&yasuo).Error; err != nil { + t.Errorf("Failed to update date, got error %v", err) + } + + if err := DB.First(&result, "id = ?", yasuo.ID).Error; err != nil || yasuo.Name != "jinzhu1" { + t.Errorf("No error should happen, but got %v", err) + } +} + func TestPostgres(t *testing.T) { if DB.Dialector.Name() != "postgres" { t.Skip() From e1dd0dcbc41741e94702d0973df88f4a7afd98e1 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 30 Sep 2022 11:13:01 +0800 Subject: [PATCH 059/231] chore(deps): bump actions/stale from 5 to 6 (#5717) Bumps [actions/stale](https://github.com/actions/stale) from 5 to 6. - [Release notes](https://github.com/actions/stale/releases) - [Changelog](https://github.com/actions/stale/blob/main/CHANGELOG.md) - [Commits](https://github.com/actions/stale/compare/v5...v6) --- updated-dependencies: - dependency-name: actions/stale dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/invalid_question.yml | 2 +- .github/workflows/missing_playground.yml | 2 +- .github/workflows/stale.yml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/invalid_question.yml b/.github/workflows/invalid_question.yml index aa1812d4..bc4487ae 100644 --- a/.github/workflows/invalid_question.yml +++ b/.github/workflows/invalid_question.yml @@ -16,7 +16,7 @@ jobs: ACTIONS_STEP_DEBUG: true steps: - name: Close Stale Issues - uses: actions/stale@v5 + uses: actions/stale@v6 with: repo-token: ${{ secrets.GITHUB_TOKEN }} stale-issue-message: "This issue has been marked as invalid question, please give more information by following the `Question` template, if you believe there is a bug of GORM, please create a pull request that could reproduce the issue on [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground), the issue will be closed in 30 days if no further activity occurs. most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨" diff --git a/.github/workflows/missing_playground.yml b/.github/workflows/missing_playground.yml index c3c92beb..f9f51aa0 100644 --- a/.github/workflows/missing_playground.yml +++ b/.github/workflows/missing_playground.yml @@ -16,7 +16,7 @@ jobs: ACTIONS_STEP_DEBUG: true steps: - name: Close Stale Issues - uses: actions/stale@v5 + uses: actions/stale@v6 with: repo-token: ${{ secrets.GITHUB_TOKEN }} stale-issue-message: "The issue has been automatically marked as stale as it missing playground pull request link, which is important to help others understand your issue effectively and make sure the issue hasn't been fixed on latest master, checkout [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground) for details. it will be closed in 30 days if no further activity occurs. if you are asking question, please use the `Question` template, most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨" diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index af8d3636..a9aff43a 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -16,7 +16,7 @@ jobs: ACTIONS_STEP_DEBUG: true steps: - name: Close Stale Issues - uses: actions/stale@v5 + uses: actions/stale@v6 with: repo-token: ${{ secrets.GITHUB_TOKEN }} stale-issue-message: "This issue has been automatically marked as stale because it has been open 360 days with no activity. Remove stale label or comment or this will be closed in 180 days" From be440e75122de5f7c19e2242a59246a92ce8edfe Mon Sep 17 00:00:00 2001 From: "jesse.tang" <1430482733@qq.com> Date: Fri, 30 Sep 2022 11:14:34 +0800 Subject: [PATCH 060/231] fix possible nil panic in tests (#5720) * fix maybe nil panic * reset code --- tests/callbacks_test.go | 3 +++ tests/transaction_test.go | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/callbacks_test.go b/tests/callbacks_test.go index 2bf9496b..4479da4c 100644 --- a/tests/callbacks_test.go +++ b/tests/callbacks_test.go @@ -113,6 +113,9 @@ func TestCallbacks(t *testing.T) { for idx, data := range datas { db, err := gorm.Open(nil, nil) + if err != nil { + t.Fatal(err) + } callbacks := db.Callback() for _, c := range data.callbacks { diff --git a/tests/transaction_test.go b/tests/transaction_test.go index 0ac04a04..5872da94 100644 --- a/tests/transaction_test.go +++ b/tests/transaction_test.go @@ -102,7 +102,7 @@ func TestTransactionWithBlock(t *testing.T) { return errors.New("the error message") }) - if err.Error() != "the error message" { + if err != nil && err.Error() != "the error message" { t.Fatalf("Transaction return error will equal the block returns error") } From a3cc6c6088c1e2aa8cbd174f4714e7fc6d0acd59 Mon Sep 17 00:00:00 2001 From: Stephano George Date: Fri, 30 Sep 2022 17:18:42 +0800 Subject: [PATCH 061/231] Fix: wrong value when Find with Join with same column name, close #5723, #5711 --- scan.go | 31 ++++++++++++++----------------- tests/go.mod | 4 ++-- tests/joins_test.go | 31 +++++++++++++++++++++++++++++++ 3 files changed, 47 insertions(+), 19 deletions(-) diff --git a/scan.go b/scan.go index 70cd4284..3a753dca 100644 --- a/scan.go +++ b/scan.go @@ -163,11 +163,10 @@ func Scan(rows Rows, db *DB, mode ScanMode) { } default: var ( - fields = make([]*schema.Field, len(columns)) - selectedColumnsMap = make(map[string]int, len(columns)) - joinFields [][2]*schema.Field - sch = db.Statement.Schema - reflectValue = db.Statement.ReflectValue + fields = make([]*schema.Field, len(columns)) + joinFields [][2]*schema.Field + sch = db.Statement.Schema + reflectValue = db.Statement.ReflectValue ) if reflectValue.Kind() == reflect.Interface { @@ -200,26 +199,24 @@ func Scan(rows Rows, db *DB, mode ScanMode) { // Not Pluck if sch != nil { - schFieldsCount := len(sch.Fields) + matchedFieldCount := make(map[string]int, len(columns)) for idx, column := range columns { if field := sch.LookUpField(column); field != nil && field.Readable { - if curIndex, ok := selectedColumnsMap[column]; ok { - fields[idx] = field // handle duplicate fields - offset := curIndex + 1 - // handle sch inconsistent with database - // like Raw(`...`).Scan - if schFieldsCount > offset { - for fieldIndex, selectField := range sch.Fields[offset:] { - if selectField.DBName == column && selectField.Readable { - selectedColumnsMap[column] = curIndex + fieldIndex + 1 + fields[idx] = field + if count, ok := matchedFieldCount[column]; ok { + // handle duplicate fields + for _, selectField := range sch.Fields { + if selectField.DBName == column && selectField.Readable { + if count == 0 { + matchedFieldCount[column]++ fields[idx] = selectField break } + count-- } } } else { - fields[idx] = field - selectedColumnsMap[column] = idx + matchedFieldCount[column] = 1 } } else if names := strings.Split(column, "__"); len(names) > 1 { if rel, ok := sch.Relationships.Relations[names[0]]; ok { diff --git a/tests/go.mod b/tests/go.mod index ebebabc0..c1e1e0ce 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -9,12 +9,12 @@ require ( github.com/jinzhu/now v1.1.5 github.com/lib/pq v1.10.7 github.com/mattn/go-sqlite3 v1.14.15 // indirect - golang.org/x/crypto v0.0.0-20220919173607-35f4265a4bc0 // indirect + golang.org/x/crypto v0.0.0-20220926161630-eccd6366d1be // indirect gorm.io/driver/mysql v1.3.6 gorm.io/driver/postgres v1.3.10 gorm.io/driver/sqlite v1.3.6 gorm.io/driver/sqlserver v1.3.2 - gorm.io/gorm v1.23.9 + gorm.io/gorm v1.23.10 ) replace gorm.io/gorm => ../ diff --git a/tests/joins_test.go b/tests/joins_test.go index 4908e5ba..7519db82 100644 --- a/tests/joins_test.go +++ b/tests/joins_test.go @@ -229,3 +229,34 @@ func TestJoinWithSoftDeleted(t *testing.T) { t.Fatalf("joins NamedPet and Account should not empty:%v", user2) } } + +func TestJoinWithSameColumnName(t *testing.T) { + user := GetUser("TestJoinWithSameColumnName", Config{ + Languages: 1, + Pets: 1, + }) + DB.Create(user) + type UserSpeak struct { + UserID uint + LanguageCode string + } + type Result struct { + User + UserSpeak + Language + Pet + } + + results := make([]Result, 0, 1) + DB.Select("users.*, user_speaks.*, languages.*, pets.*").Table("users").Joins("JOIN user_speaks ON user_speaks.user_id = users.id"). + Joins("JOIN languages ON languages.code = user_speaks.language_code"). + Joins("LEFT OUTER JOIN pets ON pets.user_id = users.id").Find(&results) + + if len(results) == 0 { + t.Fatalf("no record find") + } else if results[0].Pet.UserID == nil || *(results[0].Pet.UserID) != user.ID { + t.Fatalf("wrong user id in pet") + } else if results[0].Pet.Name != user.Pets[0].Name { + t.Fatalf("wrong pet name") + } +} From 0b7113b618584edd76d74e7a73eecc2a28a4d17a Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Fri, 30 Sep 2022 18:13:36 +0800 Subject: [PATCH 062/231] fix: prepare deadlock (#5568) * fix: prepare deadlock * chore[ci skip]: code style * chore[ci skip]: test remove unnecessary params * fix: prepare deadlock * fix: double check prepare * test: more goroutines * chore[ci skip]: improve code comments Co-authored-by: Jinzhu --- gorm.go | 2 +- prepare_stmt.go | 54 ++++++++++++++++++++++++------- tests/prepared_stmt_test.go | 63 +++++++++++++++++++++++++++++++++++++ 3 files changed, 107 insertions(+), 12 deletions(-) diff --git a/gorm.go b/gorm.go index 81b6e2af..589fc4ff 100644 --- a/gorm.go +++ b/gorm.go @@ -179,7 +179,7 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) { preparedStmt := &PreparedStmtDB{ ConnPool: db.ConnPool, - Stmts: map[string]Stmt{}, + Stmts: map[string](*Stmt){}, Mux: &sync.RWMutex{}, PreparedSQL: make([]string, 0, 100), } diff --git a/prepare_stmt.go b/prepare_stmt.go index b062b0d6..3934bb97 100644 --- a/prepare_stmt.go +++ b/prepare_stmt.go @@ -9,10 +9,12 @@ import ( type Stmt struct { *sql.Stmt Transaction bool + prepared chan struct{} + prepareErr error } type PreparedStmtDB struct { - Stmts map[string]Stmt + Stmts map[string]*Stmt PreparedSQL []string Mux *sync.RWMutex ConnPool @@ -46,27 +48,57 @@ func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransact db.Mux.RLock() if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) { db.Mux.RUnlock() - return stmt, nil + // wait for other goroutines prepared + <-stmt.prepared + if stmt.prepareErr != nil { + return Stmt{}, stmt.prepareErr + } + + return *stmt, nil } db.Mux.RUnlock() db.Mux.Lock() - defer db.Mux.Unlock() - // double check if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) { - return stmt, nil - } else if ok { - go stmt.Close() + db.Mux.Unlock() + // wait for other goroutines prepared + <-stmt.prepared + if stmt.prepareErr != nil { + return Stmt{}, stmt.prepareErr + } + + return *stmt, nil } + // cache preparing stmt first + cacheStmt := Stmt{Transaction: isTransaction, prepared: make(chan struct{})} + db.Stmts[query] = &cacheStmt + db.Mux.Unlock() + + // prepare completed + defer close(cacheStmt.prepared) + + // Reason why cannot lock conn.PrepareContext + // suppose the maxopen is 1, g1 is creating record and g2 is querying record. + // 1. g1 begin tx, g1 is requeued because of waiting for the system call, now `db.ConnPool` db.numOpen == 1. + // 2. g2 select lock `conn.PrepareContext(ctx, query)`, now db.numOpen == db.maxOpen , wait for release. + // 3. g1 tx exec insert, wait for unlock `conn.PrepareContext(ctx, query)` to finish tx and release. stmt, err := conn.PrepareContext(ctx, query) - if err == nil { - db.Stmts[query] = Stmt{Stmt: stmt, Transaction: isTransaction} - db.PreparedSQL = append(db.PreparedSQL, query) + if err != nil { + cacheStmt.prepareErr = err + db.Mux.Lock() + delete(db.Stmts, query) + db.Mux.Unlock() + return Stmt{}, err } - return db.Stmts[query], err + db.Mux.Lock() + cacheStmt.Stmt = stmt + db.PreparedSQL = append(db.PreparedSQL, query) + db.Mux.Unlock() + + return cacheStmt, nil } func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (ConnPool, error) { diff --git a/tests/prepared_stmt_test.go b/tests/prepared_stmt_test.go index 86e3630d..c7f251f2 100644 --- a/tests/prepared_stmt_test.go +++ b/tests/prepared_stmt_test.go @@ -2,6 +2,7 @@ package tests_test import ( "context" + "sync" "errors" "testing" "time" @@ -90,6 +91,68 @@ func TestPreparedStmtFromTransaction(t *testing.T) { tx2.Commit() } +func TestPreparedStmtDeadlock(t *testing.T) { + tx, err := OpenTestConnection() + AssertEqual(t, err, nil) + + sqlDB, _ := tx.DB() + sqlDB.SetMaxOpenConns(1) + + tx = tx.Session(&gorm.Session{PrepareStmt: true}) + + wg := sync.WaitGroup{} + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + user := User{Name: "jinzhu"} + tx.Create(&user) + + var result User + tx.First(&result) + wg.Done() + }() + } + wg.Wait() + + conn, ok := tx.ConnPool.(*gorm.PreparedStmtDB) + AssertEqual(t, ok, true) + AssertEqual(t, len(conn.Stmts), 2) + for _, stmt := range conn.Stmts { + if stmt == nil { + t.Fatalf("stmt cannot bee nil") + } + } + + AssertEqual(t, sqlDB.Stats().InUse, 0) +} + +func TestPreparedStmtError(t *testing.T) { + tx, err := OpenTestConnection() + AssertEqual(t, err, nil) + + sqlDB, _ := tx.DB() + sqlDB.SetMaxOpenConns(1) + + tx = tx.Session(&gorm.Session{PrepareStmt: true}) + + wg := sync.WaitGroup{} + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + // err prepare + tag := Tag{Locale: "zh"} + tx.Table("users").Find(&tag) + wg.Done() + }() + } + wg.Wait() + + conn, ok := tx.ConnPool.(*gorm.PreparedStmtDB) + AssertEqual(t, ok, true) + AssertEqual(t, len(conn.Stmts), 0) + AssertEqual(t, sqlDB.Stats().InUse, 0) +} + func TestPreparedStmtInTransaction(t *testing.T) { user := User{Name: "jinzhu"} From 9564b82975844e9e944aefc936968225d9857b86 Mon Sep 17 00:00:00 2001 From: Wen Sun Date: Fri, 7 Oct 2022 14:46:20 +0900 Subject: [PATCH 063/231] Fix OnConstraint builder (#5738) --- clause/on_conflict.go | 34 ++++++++++++++-------------- tests/postgres_test.go | 51 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 68 insertions(+), 17 deletions(-) diff --git a/clause/on_conflict.go b/clause/on_conflict.go index 309c5fcd..032bf4a1 100644 --- a/clause/on_conflict.go +++ b/clause/on_conflict.go @@ -16,27 +16,27 @@ func (OnConflict) Name() string { // Build build onConflict clause func (onConflict OnConflict) Build(builder Builder) { - if len(onConflict.Columns) > 0 { - builder.WriteByte('(') - for idx, column := range onConflict.Columns { - if idx > 0 { - builder.WriteByte(',') - } - builder.WriteQuoted(column) - } - builder.WriteString(`) `) - } - - if len(onConflict.TargetWhere.Exprs) > 0 { - builder.WriteString(" WHERE ") - onConflict.TargetWhere.Build(builder) - builder.WriteByte(' ') - } - if onConflict.OnConstraint != "" { builder.WriteString("ON CONSTRAINT ") builder.WriteString(onConflict.OnConstraint) builder.WriteByte(' ') + } else { + if len(onConflict.Columns) > 0 { + builder.WriteByte('(') + for idx, column := range onConflict.Columns { + if idx > 0 { + builder.WriteByte(',') + } + builder.WriteQuoted(column) + } + builder.WriteString(`) `) + } + + if len(onConflict.TargetWhere.Exprs) > 0 { + builder.WriteString(" WHERE ") + onConflict.TargetWhere.Build(builder) + builder.WriteByte(' ') + } } if onConflict.DoNothing { diff --git a/tests/postgres_test.go b/tests/postgres_test.go index b5b672a9..f45b2618 100644 --- a/tests/postgres_test.go +++ b/tests/postgres_test.go @@ -7,6 +7,7 @@ import ( "github.com/google/uuid" "github.com/lib/pq" "gorm.io/gorm" + "gorm.io/gorm/clause" ) func TestPostgresReturningIDWhichHasStringType(t *testing.T) { @@ -148,3 +149,53 @@ func TestMany2ManyWithDefaultValueUUID(t *testing.T) { t.Errorf("Failed, got error: %v", err) } } + +func TestPostgresOnConstraint(t *testing.T) { + if DB.Dialector.Name() != "postgres" { + t.Skip() + } + + type Thing struct { + gorm.Model + SomeID string + OtherID string + Data string + } + + DB.Migrator().DropTable(&Thing{}) + DB.Migrator().CreateTable(&Thing{}) + if err := DB.Exec("ALTER TABLE things ADD CONSTRAINT some_id_other_id_unique UNIQUE (some_id, other_id)").Error; err != nil { + t.Error(err) + } + + thing := Thing{ + SomeID: "1234", + OtherID: "1234", + Data: "something", + } + + DB.Create(&thing) + + thing2 := Thing{ + SomeID: "1234", + OtherID: "1234", + Data: "something else", + } + + result := DB.Clauses(clause.OnConflict{ + OnConstraint: "some_id_other_id_unique", + UpdateAll: true, + }).Create(&thing2) + if result.Error != nil { + t.Errorf("creating second thing: %v", result.Error) + } + + var things []Thing + if err := DB.Find(&things).Error; err != nil { + t.Errorf("Failed, got error: %v", err) + } + + if len(things) > 1 { + t.Errorf("expected 1 thing got more") + } +} From 4b22a55a752d4284a72545a1611d651b364b3482 Mon Sep 17 00:00:00 2001 From: "jesse.tang" <1430482733@qq.com> Date: Fri, 7 Oct 2022 18:29:28 +0800 Subject: [PATCH 064/231] fix: primaryFields are overwritten (#5721) --- schema/relationship.go | 24 +++++++++++------------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/schema/relationship.go b/schema/relationship.go index bb8aeb64..9436f283 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -403,33 +403,30 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu case guessBelongs: primarySchema, foreignSchema = relation.FieldSchema, schema case guessEmbeddedBelongs: - if field.OwnerSchema != nil { - primarySchema, foreignSchema = relation.FieldSchema, field.OwnerSchema - } else { + if field.OwnerSchema == nil { reguessOrErr() return } + primarySchema, foreignSchema = relation.FieldSchema, field.OwnerSchema case guessHas: case guessEmbeddedHas: - if field.OwnerSchema != nil { - primarySchema, foreignSchema = field.OwnerSchema, relation.FieldSchema - } else { + if field.OwnerSchema == nil { reguessOrErr() return } + primarySchema, foreignSchema = field.OwnerSchema, relation.FieldSchema } if len(relation.foreignKeys) > 0 { for _, foreignKey := range relation.foreignKeys { - if f := foreignSchema.LookUpField(foreignKey); f != nil { - foreignFields = append(foreignFields, f) - } else { + f := foreignSchema.LookUpField(foreignKey) + if f == nil { reguessOrErr() return } + foreignFields = append(foreignFields, f) } } else { - var primaryFields []*Field var primarySchemaName = primarySchema.Name if primarySchemaName == "" { primarySchemaName = relation.FieldSchema.Name @@ -466,10 +463,11 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu } } - if len(foreignFields) == 0 { + switch { + case len(foreignFields) == 0: reguessOrErr() return - } else if len(relation.primaryKeys) > 0 { + case len(relation.primaryKeys) > 0: for idx, primaryKey := range relation.primaryKeys { if f := primarySchema.LookUpField(primaryKey); f != nil { if len(primaryFields) < idx+1 { @@ -483,7 +481,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu return } } - } else if len(primaryFields) == 0 { + case len(primaryFields) == 0: if len(foreignFields) == 1 && primarySchema.PrioritizedPrimaryField != nil { primaryFields = append(primaryFields, primarySchema.PrioritizedPrimaryField) } else if len(primarySchema.PrimaryFields) == len(foreignFields) { From e8f48b5c155b6fbf2e1fe6a554e2280f62af21a7 Mon Sep 17 00:00:00 2001 From: robhafner Date: Fri, 7 Oct 2022 08:14:14 -0400 Subject: [PATCH 065/231] fix: limit=0 results (#5735) (#5736) --- chainable_api.go | 2 +- clause/benchmarks_test.go | 3 ++- clause/limit.go | 10 +++++----- clause/limit_test.go | 20 ++++++++++++++------ finisher_api.go | 4 +++- 5 files changed, 25 insertions(+), 14 deletions(-) diff --git a/chainable_api.go b/chainable_api.go index 68b4d1aa..ab3a1a32 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -244,7 +244,7 @@ func (db *DB) Order(value interface{}) (tx *DB) { // Limit specify the number of records to be retrieved func (db *DB) Limit(limit int) (tx *DB) { tx = db.getInstance() - tx.Statement.AddClause(clause.Limit{Limit: limit}) + tx.Statement.AddClause(clause.Limit{Limit: &limit}) return } diff --git a/clause/benchmarks_test.go b/clause/benchmarks_test.go index e08677ac..34d5df41 100644 --- a/clause/benchmarks_test.go +++ b/clause/benchmarks_test.go @@ -29,6 +29,7 @@ func BenchmarkSelect(b *testing.B) { func BenchmarkComplexSelect(b *testing.B) { user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) + limit10 := 10 for i := 0; i < b.N; i++ { stmt := gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}} clauses := []clause.Interface{ @@ -43,7 +44,7 @@ func BenchmarkComplexSelect(b *testing.B) { clause.Or(clause.Gt{Column: "score", Value: 100}, clause.Like{Column: "name", Value: "%linus%"}), }}, clause.GroupBy{Columns: []clause.Column{{Name: "role"}}, Having: []clause.Expression{clause.Eq{"role", "admin"}}}, - clause.Limit{Limit: 10, Offset: 20}, + clause.Limit{Limit: &limit10, Offset: 20}, clause.OrderBy{Columns: []clause.OrderByColumn{{Column: clause.PrimaryColumn, Desc: true}}}, } diff --git a/clause/limit.go b/clause/limit.go index 184f6025..3ede7385 100644 --- a/clause/limit.go +++ b/clause/limit.go @@ -4,7 +4,7 @@ import "strconv" // Limit limit clause type Limit struct { - Limit int + Limit *int Offset int } @@ -15,12 +15,12 @@ func (limit Limit) Name() string { // Build build where clause func (limit Limit) Build(builder Builder) { - if limit.Limit > 0 { + if limit.Limit != nil && *limit.Limit >= 0 { builder.WriteString("LIMIT ") - builder.WriteString(strconv.Itoa(limit.Limit)) + builder.WriteString(strconv.Itoa(*limit.Limit)) } if limit.Offset > 0 { - if limit.Limit > 0 { + if limit.Limit != nil && *limit.Limit >= 0 { builder.WriteByte(' ') } builder.WriteString("OFFSET ") @@ -33,7 +33,7 @@ func (limit Limit) MergeClause(clause *Clause) { clause.Name = "" if v, ok := clause.Expression.(Limit); ok { - if limit.Limit == 0 && v.Limit != 0 { + if (limit.Limit == nil || *limit.Limit == 0) && (v.Limit != nil && *v.Limit != 0) { limit.Limit = v.Limit } diff --git a/clause/limit_test.go b/clause/limit_test.go index c26294aa..79065ab6 100644 --- a/clause/limit_test.go +++ b/clause/limit_test.go @@ -8,6 +8,10 @@ import ( ) func TestLimit(t *testing.T) { + limit0 := 0 + limit10 := 10 + limit50 := 50 + limitNeg10 := -10 results := []struct { Clauses []clause.Interface Result string @@ -15,11 +19,15 @@ func TestLimit(t *testing.T) { }{ { []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{ - Limit: 10, + Limit: &limit10, Offset: 20, }}, "SELECT * FROM `users` LIMIT 10 OFFSET 20", nil, }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit0}}, + "SELECT * FROM `users` LIMIT 0", nil, + }, { []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}}, "SELECT * FROM `users` OFFSET 20", nil, @@ -29,23 +37,23 @@ func TestLimit(t *testing.T) { "SELECT * FROM `users` OFFSET 30", nil, }, { - []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}, clause.Limit{Limit: 10}}, + []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}, clause.Limit{Limit: &limit10}}, "SELECT * FROM `users` LIMIT 10 OFFSET 20", nil, }, { - []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}}, + []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}}, "SELECT * FROM `users` LIMIT 10 OFFSET 30", nil, }, { - []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Offset: -10}}, + []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Offset: -10}}, "SELECT * FROM `users` LIMIT 10", nil, }, { - []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: -10}}, + []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: &limitNeg10}}, "SELECT * FROM `users` OFFSET 30", nil, }, { - []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: 50}}, + []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: &limit50}}, "SELECT * FROM `users` LIMIT 50 OFFSET 30", nil, }, } diff --git a/finisher_api.go b/finisher_api.go index 835a6984..5516c0a1 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -185,7 +185,9 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat var totalSize int if c, ok := tx.Statement.Clauses["LIMIT"]; ok { if limit, ok := c.Expression.(clause.Limit); ok { - totalSize = limit.Limit + if limit.Limit != nil { + totalSize = *limit.Limit + } if totalSize > 0 && batchSize > totalSize { batchSize = totalSize From 34fbe84580290c32ba006b714669bb356224cb07 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 7 Oct 2022 21:18:37 +0800 Subject: [PATCH 066/231] Add TableName with NamingStrategy support, close #5726 --- schema/schema.go | 7 +++++++ tests/go.mod | 12 +++++------- tests/table_test.go | 26 ++++++++++++++++++++++++++ utils/tests/dummy_dialecter.go | 10 +++++++++- 4 files changed, 47 insertions(+), 8 deletions(-) diff --git a/schema/schema.go b/schema/schema.go index 42ff5c45..9b3d30f6 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -71,6 +71,10 @@ type Tabler interface { TableName() string } +type TablerWithNamer interface { + TableName(Namer) string +} + // Parse get data type from dialector func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) { return ParseWithSpecialTableName(dest, cacheStore, namer, "") @@ -125,6 +129,9 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam if tabler, ok := modelValue.Interface().(Tabler); ok { tableName = tabler.TableName() } + if tabler, ok := modelValue.Interface().(TablerWithNamer); ok { + tableName = tabler.TableName(namer) + } if en, ok := namer.(embeddedNamer); ok { tableName = en.Table } diff --git a/tests/go.mod b/tests/go.mod index c1e1e0ce..d28c4bb9 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -3,17 +3,15 @@ module gorm.io/gorm/tests go 1.16 require ( - github.com/denisenkom/go-mssqldb v0.12.2 // indirect github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect github.com/google/uuid v1.3.0 github.com/jinzhu/now v1.1.5 github.com/lib/pq v1.10.7 - github.com/mattn/go-sqlite3 v1.14.15 // indirect - golang.org/x/crypto v0.0.0-20220926161630-eccd6366d1be // indirect - gorm.io/driver/mysql v1.3.6 - gorm.io/driver/postgres v1.3.10 - gorm.io/driver/sqlite v1.3.6 - gorm.io/driver/sqlserver v1.3.2 + golang.org/x/crypto v0.0.0-20221005025214-4161e89ecf1b // indirect + gorm.io/driver/mysql v1.4.0 + gorm.io/driver/postgres v1.4.1 + gorm.io/driver/sqlite v1.4.1 + gorm.io/driver/sqlserver v1.4.0 gorm.io/gorm v1.23.10 ) diff --git a/tests/table_test.go b/tests/table_test.go index 0289b7b8..f538c691 100644 --- a/tests/table_test.go +++ b/tests/table_test.go @@ -5,6 +5,8 @@ import ( "testing" "gorm.io/gorm" + "gorm.io/gorm/schema" + "gorm.io/gorm/utils/tests" . "gorm.io/gorm/utils/tests" ) @@ -145,3 +147,27 @@ func TestTableWithAllFields(t *testing.T) { AssertEqual(t, r.Statement.Vars, []interface{}{2, 4, 1, 3}) } + +type UserWithTableNamer struct { + gorm.Model + Name string +} + +func (UserWithTableNamer) TableName(namer schema.Namer) string { + return namer.TableName("user") +} + +func TestTableWithNamer(t *testing.T) { + var db, _ = gorm.Open(tests.DummyDialector{}, &gorm.Config{ + NamingStrategy: schema.NamingStrategy{ + TablePrefix: "t_", + }}) + + sql := db.ToSQL(func(tx *gorm.DB) *gorm.DB { + return tx.Model(&UserWithTableNamer{}).Find(&UserWithTableNamer{}) + }) + + if !regexp.MustCompile("SELECT \\* FROM `t_users`").MatchString(sql) { + t.Errorf("Table with namer, got %v", sql) + } +} diff --git a/utils/tests/dummy_dialecter.go b/utils/tests/dummy_dialecter.go index 2990c20f..c89b944a 100644 --- a/utils/tests/dummy_dialecter.go +++ b/utils/tests/dummy_dialecter.go @@ -2,6 +2,7 @@ package tests import ( "gorm.io/gorm" + "gorm.io/gorm/callbacks" "gorm.io/gorm/clause" "gorm.io/gorm/logger" "gorm.io/gorm/schema" @@ -13,7 +14,14 @@ func (DummyDialector) Name() string { return "dummy" } -func (DummyDialector) Initialize(*gorm.DB) error { +func (DummyDialector) Initialize(db *gorm.DB) error { + callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{ + CreateClauses: []string{"INSERT", "VALUES", "ON CONFLICT", "RETURNING"}, + UpdateClauses: []string{"UPDATE", "SET", "WHERE", "RETURNING"}, + DeleteClauses: []string{"DELETE", "FROM", "WHERE", "RETURNING"}, + LastInsertIDReversed: true, + }) + return nil } From 983e96f14253c071b8ab3fb96b4c9f103ad39e1c Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 8 Oct 2022 16:04:57 +0800 Subject: [PATCH 067/231] Add tests for alter column type --- tests/go.mod | 4 ++-- tests/migrate_test.go | 2 +- tests/postgres_test.go | 16 ++++++++++++++++ 3 files changed, 19 insertions(+), 3 deletions(-) diff --git a/tests/go.mod b/tests/go.mod index d28c4bb9..3919a838 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -9,10 +9,10 @@ require ( github.com/lib/pq v1.10.7 golang.org/x/crypto v0.0.0-20221005025214-4161e89ecf1b // indirect gorm.io/driver/mysql v1.4.0 - gorm.io/driver/postgres v1.4.1 + gorm.io/driver/postgres v1.4.3 gorm.io/driver/sqlite v1.4.1 gorm.io/driver/sqlserver v1.4.0 - gorm.io/gorm v1.23.10 + gorm.io/gorm v1.24.0 ) replace gorm.io/gorm => ../ diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 32e84e77..b918b4b5 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -400,7 +400,7 @@ func TestMigrateColumns(t *testing.T) { t.Fatalf("column code unique should be correct, name: %v, column: %#v", columnType.Name(), columnType) } if v, ok := columnType.DefaultValue(); !sqlserver && (!ok || v != "hello") { - t.Fatalf("column code default value should be correct, name: %v, column: %#v", columnType.Name(), columnType) + t.Fatalf("column code default value should be correct, name: %v, column: %#v, default value: %v", columnType.Name(), columnType, v) } if v, ok := columnType.Comment(); !sqlite && !sqlserver && (!ok || v != "my code2") { t.Fatalf("column code comment should be correct, name: %v, column: %#v", columnType.Name(), columnType) diff --git a/tests/postgres_test.go b/tests/postgres_test.go index f45b2618..794ab8f7 100644 --- a/tests/postgres_test.go +++ b/tests/postgres_test.go @@ -8,6 +8,7 @@ import ( "github.com/lib/pq" "gorm.io/gorm" "gorm.io/gorm/clause" + . "gorm.io/gorm/utils/tests" ) func TestPostgresReturningIDWhichHasStringType(t *testing.T) { @@ -199,3 +200,18 @@ func TestPostgresOnConstraint(t *testing.T) { t.Errorf("expected 1 thing got more") } } + +type CompanyNew struct { + ID int + Name int +} + +func TestAlterColumnDataType(t *testing.T) { + DB.AutoMigrate(Company{}) + + if err := DB.Table("companies").Migrator().AlterColumn(CompanyNew{}, "name"); err != nil { + t.Fatalf("failed to alter column from string to int, got error %v", err) + } + + DB.AutoMigrate(Company{}) +} From e93dc3426e8cb0a99091e2267ef2adf1cc86b4b5 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 8 Oct 2022 17:16:32 +0800 Subject: [PATCH 068/231] Test postgres autoincrement check --- tests/go.mod | 2 +- tests/postgres_test.go | 39 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/tests/go.mod b/tests/go.mod index 3919a838..0160b2a6 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -9,7 +9,7 @@ require ( github.com/lib/pq v1.10.7 golang.org/x/crypto v0.0.0-20221005025214-4161e89ecf1b // indirect gorm.io/driver/mysql v1.4.0 - gorm.io/driver/postgres v1.4.3 + gorm.io/driver/postgres v1.4.4 gorm.io/driver/sqlite v1.4.1 gorm.io/driver/sqlserver v1.4.0 gorm.io/gorm v1.24.0 diff --git a/tests/postgres_test.go b/tests/postgres_test.go index 794ab8f7..44cac6bf 100644 --- a/tests/postgres_test.go +++ b/tests/postgres_test.go @@ -112,6 +112,45 @@ func TestPostgres(t *testing.T) { if err := DB.First(&result, "id = ?", harumph.ID).Error; err != nil || harumph.Name != "jinzhu1" { t.Errorf("No error should happen, but got %v", err) } + + DB.Migrator().DropTable("log_usage") + + if err := DB.Exec(` +CREATE TABLE public.log_usage ( + log_id bigint NOT NULL +); + +ALTER TABLE public.log_usage ALTER COLUMN log_id ADD GENERATED BY DEFAULT AS IDENTITY ( + SEQUENCE NAME public.log_usage_log_id_seq + START WITH 1 + INCREMENT BY 1 + NO MINVALUE + NO MAXVALUE + CACHE 1 +); + `).Error; err != nil { + t.Fatalf("failed to create table, got error %v", err) + } + + columns, err := DB.Migrator().ColumnTypes("log_usage") + if err != nil { + t.Fatalf("failed to get columns, got error %v", err) + } + + hasLogID := false + for _, column := range columns { + if column.Name() == "log_id" { + hasLogID = true + autoIncrement, ok := column.AutoIncrement() + if !ok || !autoIncrement { + t.Fatalf("column log_id should be auto incrementment") + } + } + } + + if !hasLogID { + t.Fatalf("failed to found column log_id") + } } type Post struct { From 2c56954cb12dd33fc8f1875a735091d61daff702 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 8 Oct 2022 20:48:22 +0800 Subject: [PATCH 069/231] tests mariadb with returning support --- scan.go | 2 +- tests/connpool_test.go | 2 +- tests/go.mod | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/scan.go b/scan.go index 3a753dca..0a26ce4b 100644 --- a/scan.go +++ b/scan.go @@ -317,4 +317,4 @@ func Scan(rows Rows, db *DB, mode ScanMode) { if db.RowsAffected == 0 && db.Statement.RaiseErrorOnNotFound && db.Error == nil { db.AddError(ErrRecordNotFound) } -} \ No newline at end of file +} diff --git a/tests/connpool_test.go b/tests/connpool_test.go index fbae2294..42e029bc 100644 --- a/tests/connpool_test.go +++ b/tests/connpool_test.go @@ -116,7 +116,7 @@ func TestConnPoolWrapper(t *testing.T) { } }() - db, err := gorm.Open(mysql.New(mysql.Config{Conn: conn})) + db, err := gorm.Open(mysql.New(mysql.Config{Conn: conn, DisableWithReturning: true})) if err != nil { t.Fatalf("Should open db success, but got %v", err) } diff --git a/tests/go.mod b/tests/go.mod index 0160b2a6..bf59e8d2 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -8,7 +8,7 @@ require ( github.com/jinzhu/now v1.1.5 github.com/lib/pq v1.10.7 golang.org/x/crypto v0.0.0-20221005025214-4161e89ecf1b // indirect - gorm.io/driver/mysql v1.4.0 + gorm.io/driver/mysql v1.4.1 gorm.io/driver/postgres v1.4.4 gorm.io/driver/sqlite v1.4.1 gorm.io/driver/sqlserver v1.4.0 From 08aa2f9888dcd3c950943d09d0d7aaef1b1dcc33 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 14 Oct 2022 20:30:28 +0800 Subject: [PATCH 070/231] Update README --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 312a3a59..5bb1be37 100644 --- a/README.md +++ b/README.md @@ -30,7 +30,7 @@ The fantastic ORM library for Golang, aims to be developer friendly. ## Getting Started * GORM Guides [https://gorm.io](https://gorm.io) -* GORM Gen [gorm/gen](https://github.com/go-gorm/gen#gormgen) +* Gen Guides [https://gorm.io/gen/index.html](https://gorm.io/gen/index.html) ## Contributing From aa4312ee74db5a23d459d487b43a4a79d341c936 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 13 Oct 2022 15:57:10 +0800 Subject: [PATCH 071/231] Don't display any GORM related package path as source --- utils/utils.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils/utils.go b/utils/utils.go index 296917b9..90b4c8ea 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -16,7 +16,7 @@ var gormSourceDir string func init() { _, file, _, _ := runtime.Caller(0) // compatible solution to get gorm source directory with various operating systems - gormSourceDir = regexp.MustCompile(`utils.utils\.go`).ReplaceAllString(file, "") + gormSourceDir = regexp.MustCompile(`gorm.utils.utils\.go`).ReplaceAllString(file, "") } // FileWithLineNum return the file name and line number of the current file From 2a788fb20c3cbc73e96aa422b7477fe62d23964a Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 17 Oct 2022 17:01:00 +0800 Subject: [PATCH 072/231] Upgrade tests go.mod --- tests/go.mod | 10 +++++----- tests/sql_builder_test.go | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/go.mod b/tests/go.mod index bf59e8d2..2fef9d97 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -3,15 +3,15 @@ module gorm.io/gorm/tests go 1.16 require ( - github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect github.com/google/uuid v1.3.0 github.com/jinzhu/now v1.1.5 github.com/lib/pq v1.10.7 - golang.org/x/crypto v0.0.0-20221005025214-4161e89ecf1b // indirect - gorm.io/driver/mysql v1.4.1 + golang.org/x/crypto v0.0.0-20221012134737-56aed061732a // indirect + golang.org/x/text v0.3.8 // indirect + gorm.io/driver/mysql v1.4.3 gorm.io/driver/postgres v1.4.4 - gorm.io/driver/sqlite v1.4.1 - gorm.io/driver/sqlserver v1.4.0 + gorm.io/driver/sqlite v1.4.2 + gorm.io/driver/sqlserver v1.4.1 gorm.io/gorm v1.24.0 ) diff --git a/tests/sql_builder_test.go b/tests/sql_builder_test.go index a9b920dc..b10142fa 100644 --- a/tests/sql_builder_test.go +++ b/tests/sql_builder_test.go @@ -367,7 +367,7 @@ func TestToSQL(t *testing.T) { t.Skip("Skip SQL Server for this test, because it too difference with other dialects.") } - date, _ := time.Parse("2006-01-02", "2021-10-18") + date, _ := time.ParseInLocation("2006-01-02", "2021-10-18", time.Local) // find sql := DB.ToSQL(func(tx *gorm.DB) *gorm.DB { From 186e8a9e14578c63715444d217294065be072805 Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Tue, 18 Oct 2022 11:58:42 +0800 Subject: [PATCH 073/231] fix: association without pks (#5779) --- callbacks/associations.go | 10 +++++++-- tests/associations_test.go | 42 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+), 2 deletions(-) diff --git a/callbacks/associations.go b/callbacks/associations.go index 00e00fcc..9d7c1412 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -208,7 +208,10 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { cacheKey := utils.ToStringKey(relPrimaryValues...) if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] { - identityMap[cacheKey] = true + if cacheKey != "" { // has primary fields + identityMap[cacheKey] = true + } + if isPtr { elems = reflect.Append(elems, elem) } else { @@ -294,7 +297,10 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { cacheKey := utils.ToStringKey(relPrimaryValues...) if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] { - identityMap[cacheKey] = true + if cacheKey != "" { // has primary fields + identityMap[cacheKey] = true + } + distinctElems = reflect.Append(distinctElems, elem) } diff --git a/tests/associations_test.go b/tests/associations_test.go index 42b32afc..4c9076da 100644 --- a/tests/associations_test.go +++ b/tests/associations_test.go @@ -348,3 +348,45 @@ func TestAssociationEmptyQueryClause(t *testing.T) { AssertEqual(t, len(orgs), 0) } } + +type AssociationEmptyUser struct { + ID uint + Name string + Pets []AssociationEmptyPet +} + +type AssociationEmptyPet struct { + AssociationEmptyUserID *uint `gorm:"uniqueIndex:uniq_user_id_name"` + Name string `gorm:"uniqueIndex:uniq_user_id_name;size:256"` +} + +func TestAssociationEmptyPrimaryKey(t *testing.T) { + if DB.Dialector.Name() != "mysql" { + t.Skip() + } + DB.Migrator().DropTable(&AssociationEmptyUser{}, &AssociationEmptyPet{}) + DB.AutoMigrate(&AssociationEmptyUser{}, &AssociationEmptyPet{}) + + id := uint(100) + user := AssociationEmptyUser{ + ID: id, + Name: "jinzhu", + Pets: []AssociationEmptyPet{ + {AssociationEmptyUserID: &id, Name: "bar"}, + {AssociationEmptyUserID: &id, Name: "foo"}, + }, + } + + err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Create(&user).Error + if err != nil { + t.Fatalf("Failed to create, got error: %v", err) + } + + var result AssociationEmptyUser + err = DB.Preload("Pets").First(&result, &id).Error + if err != nil { + t.Fatalf("Failed to find, got error: %v", err) + } + + AssertEqual(t, result, user) +} From ab5f80a8d81c1955e92224b24dfc9bc8c7d387a0 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 18 Oct 2022 15:44:47 +0800 Subject: [PATCH 074/231] Save as NULL for nil object serialized into json --- schema/serializer.go | 3 +++ tests/go.mod | 4 ++-- tests/serializer_test.go | 3 ++- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/schema/serializer.go b/schema/serializer.go index 00a4f85f..fef39d9b 100644 --- a/schema/serializer.go +++ b/schema/serializer.go @@ -100,6 +100,9 @@ func (JSONSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, // Value implements serializer interface func (JSONSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) { result, err := json.Marshal(fieldValue) + if string(result) == "null" { + return nil, err + } return string(result), err } diff --git a/tests/go.mod b/tests/go.mod index 2fef9d97..9c87ca34 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -7,10 +7,10 @@ require ( github.com/jinzhu/now v1.1.5 github.com/lib/pq v1.10.7 golang.org/x/crypto v0.0.0-20221012134737-56aed061732a // indirect - golang.org/x/text v0.3.8 // indirect + golang.org/x/text v0.4.0 // indirect gorm.io/driver/mysql v1.4.3 gorm.io/driver/postgres v1.4.4 - gorm.io/driver/sqlite v1.4.2 + gorm.io/driver/sqlite v1.4.3 gorm.io/driver/sqlserver v1.4.1 gorm.io/gorm v1.24.0 ) diff --git a/tests/serializer_test.go b/tests/serializer_test.go index 946536bf..17bfefe2 100644 --- a/tests/serializer_test.go +++ b/tests/serializer_test.go @@ -18,6 +18,7 @@ type SerializerStruct struct { gorm.Model Name []byte `gorm:"json"` Roles Roles `gorm:"serializer:json"` + Roles2 *Roles `gorm:"serializer:json"` Contracts map[string]interface{} `gorm:"serializer:json"` JobInfo Job `gorm:"type:bytes;serializer:gob"` CreatedTime int64 `gorm:"serializer:unixtime;type:time"` // store time in db, use int as field type @@ -108,7 +109,7 @@ func TestSerializer(t *testing.T) { } var result SerializerStruct - if err := DB.First(&result, data.ID).Error; err != nil { + if err := DB.Where("roles2 IS NULL").First(&result, data.ID).Error; err != nil { t.Fatalf("failed to query data, got error %v", err) } From a0f4d3f7d207b2103b5f91e9758b1ac6a94056ba Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 18 Oct 2022 16:25:39 +0800 Subject: [PATCH 075/231] Save as empty string for not nullable nil field serialized into json --- schema/serializer.go | 3 +++ tests/serializer_test.go | 3 ++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/schema/serializer.go b/schema/serializer.go index fef39d9b..9a6aa4fc 100644 --- a/schema/serializer.go +++ b/schema/serializer.go @@ -101,6 +101,9 @@ func (JSONSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, func (JSONSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) { result, err := json.Marshal(fieldValue) if string(result) == "null" { + if field.TagSettings["NOT NULL"] != "" { + return "", nil + } return nil, err } return string(result), err diff --git a/tests/serializer_test.go b/tests/serializer_test.go index 17bfefe2..a040a4db 100644 --- a/tests/serializer_test.go +++ b/tests/serializer_test.go @@ -19,6 +19,7 @@ type SerializerStruct struct { Name []byte `gorm:"json"` Roles Roles `gorm:"serializer:json"` Roles2 *Roles `gorm:"serializer:json"` + Roles3 *Roles `gorm:"serializer:json;not null"` Contracts map[string]interface{} `gorm:"serializer:json"` JobInfo Job `gorm:"type:bytes;serializer:gob"` CreatedTime int64 `gorm:"serializer:unixtime;type:time"` // store time in db, use int as field type @@ -109,7 +110,7 @@ func TestSerializer(t *testing.T) { } var result SerializerStruct - if err := DB.Where("roles2 IS NULL").First(&result, data.ID).Error; err != nil { + if err := DB.Where("roles2 IS NULL AND roles3 = ?", "").First(&result, data.ID).Error; err != nil { t.Fatalf("failed to query data, got error %v", err) } From 62593cfad03ebf1e6cae30bac010655b4a28ff67 Mon Sep 17 00:00:00 2001 From: viatoriche / Maxim Panfilov Date: Tue, 18 Oct 2022 17:28:06 +0800 Subject: [PATCH 076/231] add test: TestAutoMigrateInt8PG: shouldn't execute ALTER COLUMN TYPE smallint, close #5762 --- migrator/migrator.go | 55 +++++++++++++++++++++---------------------- tests/migrate_test.go | 40 +++++++++++++++++++++++++++++++ tests/tracer_test.go | 34 ++++++++++++++++++++++++++ 3 files changed, 101 insertions(+), 28 deletions(-) create mode 100644 tests/tracer_test.go diff --git a/migrator/migrator.go b/migrator/migrator.go index 29c0c00c..9f8e3db8 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -406,17 +406,14 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy fullDataType := strings.TrimSpace(strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL)) realDataType := strings.ToLower(columnType.DatabaseTypeName()) - alterColumn := false + var ( + alterColumn, isSameType bool + ) if !field.PrimaryKey { // check type - var isSameType bool - if strings.HasPrefix(fullDataType, realDataType) { - isSameType = true - } - - // check type aliases - if !isSameType { + if !strings.HasPrefix(fullDataType, realDataType) { + // check type aliases aliases := m.DB.Migrator().GetTypeAliases(realDataType) for _, alias := range aliases { if strings.HasPrefix(fullDataType, alias) { @@ -424,32 +421,34 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy break } } - } - if !isSameType { - alterColumn = true - } - } - - // check size - if length, ok := columnType.Length(); length != int64(field.Size) { - if length > 0 && field.Size > 0 { - alterColumn = true - } else { - // has size in data type and not equal - // Since the following code is frequently called in the for loop, reg optimization is needed here - matches2 := regFullDataType.FindAllStringSubmatch(fullDataType, -1) - if !field.PrimaryKey && - (len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length) && ok) { + if !isSameType { alterColumn = true } } } - // check precision - if precision, _, ok := columnType.DecimalSize(); ok && int64(field.Precision) != precision { - if regexp.MustCompile(fmt.Sprintf("[^0-9]%d[^0-9]", field.Precision)).MatchString(m.DataTypeOf(field)) { - alterColumn = true + if !isSameType { + // check size + if length, ok := columnType.Length(); length != int64(field.Size) { + if length > 0 && field.Size > 0 { + alterColumn = true + } else { + // has size in data type and not equal + // Since the following code is frequently called in the for loop, reg optimization is needed here + matches2 := regFullDataType.FindAllStringSubmatch(fullDataType, -1) + if !field.PrimaryKey && + (len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length) && ok) { + alterColumn = true + } + } + } + + // check precision + if precision, _, ok := columnType.DecimalSize(); ok && int64(field.Precision) != precision { + if regexp.MustCompile(fmt.Sprintf("[^0-9]%d[^0-9]", field.Precision)).MatchString(m.DataTypeOf(field)) { + alterColumn = true + } } } diff --git a/tests/migrate_test.go b/tests/migrate_test.go index b918b4b5..8718aa57 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -1,6 +1,7 @@ package tests_test import ( + "context" "fmt" "math/rand" "reflect" @@ -9,6 +10,7 @@ import ( "time" "gorm.io/driver/postgres" + "gorm.io/gorm" "gorm.io/gorm/schema" . "gorm.io/gorm/utils/tests" @@ -72,6 +74,44 @@ func TestMigrate(t *testing.T) { t.Fatalf("Failed to find index for many2many for %v %v", indexes[0], indexes[1]) } } + +} + +func TestAutoMigrateInt8PG(t *testing.T) { + if DB.Dialector.Name() != "postgres" { + return + } + + type Smallint int8 + + type MigrateInt struct { + Int8 Smallint + } + + tracer := Tracer{ + Logger: DB.Config.Logger, + Test: func(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) { + sql, _ := fc() + if strings.HasPrefix(sql, "ALTER TABLE \"migrate_ints\" ALTER COLUMN \"int8\" TYPE smallint") { + t.Fatalf("shouldn't execute ALTER COLUMN TYPE if such type is already existed in DB schema: sql: %s", sql) + } + }, + } + + DB.Migrator().DropTable(&MigrateInt{}) + + // The first AutoMigrate to make table with field with correct type + if err := DB.AutoMigrate(&MigrateInt{}); err != nil { + t.Fatalf("Failed to auto migrate: error: %v", err) + } + + // make new session to set custom logger tracer + session := DB.Session(&gorm.Session{Logger: tracer}) + + // The second AutoMigrate to catch an error + if err := session.AutoMigrate(&MigrateInt{}); err != nil { + t.Fatalf("Failed to auto migrate: error: %v", err) + } } func TestAutoMigrateSelfReferential(t *testing.T) { diff --git a/tests/tracer_test.go b/tests/tracer_test.go new file mode 100644 index 00000000..3e9a4052 --- /dev/null +++ b/tests/tracer_test.go @@ -0,0 +1,34 @@ +package tests_test + +import ( + "context" + "time" + + "gorm.io/gorm/logger" +) + +type Tracer struct { + Logger logger.Interface + Test func(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) +} + +func (S Tracer) LogMode(level logger.LogLevel) logger.Interface { + return S.Logger.LogMode(level) +} + +func (S Tracer) Info(ctx context.Context, s string, i ...interface{}) { + S.Logger.Info(ctx, s, i...) +} + +func (S Tracer) Warn(ctx context.Context, s string, i ...interface{}) { + S.Logger.Warn(ctx, s, i...) +} + +func (S Tracer) Error(ctx context.Context, s string, i ...interface{}) { + S.Logger.Error(ctx, s, i...) +} + +func (S Tracer) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) { + S.Logger.Trace(ctx, begin, fc, err) + S.Test(ctx, begin, fc, err) +} From 3f20a543fad5f57016ef7a6c342536b0fcce6016 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 18 Oct 2022 18:01:55 +0800 Subject: [PATCH 077/231] Support use clause.Interface as query params --- statement.go | 4 ++++ tests/sql_builder_test.go | 8 ++++++++ 2 files changed, 12 insertions(+) diff --git a/statement.go b/statement.go index cc26fe37..d05d299e 100644 --- a/statement.go +++ b/statement.go @@ -179,6 +179,10 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) { } else { stmt.AddVar(writer, v.GormValue(stmt.Context, stmt.DB)) } + case clause.Interface: + c := clause.Clause{Name: v.Name()} + v.MergeClause(&c) + c.Build(stmt) case clause.Expression: v.Build(stmt) case driver.Valuer: diff --git a/tests/sql_builder_test.go b/tests/sql_builder_test.go index b10142fa..0fbd6118 100644 --- a/tests/sql_builder_test.go +++ b/tests/sql_builder_test.go @@ -445,6 +445,14 @@ func TestToSQL(t *testing.T) { if DB.Statement.DryRun || DB.DryRun { t.Fatal("Failed expect DB.DryRun and DB.Statement.ToSQL to be false") } + + // UpdateColumns + sql = DB.ToSQL(func(tx *gorm.DB) *gorm.DB { + return tx.Raw("SELECT * FROM users ?", clause.OrderBy{ + Columns: []clause.OrderByColumn{{Column: clause.Column{Name: "id", Raw: true}, Desc: true}}, + }) + }) + assertEqualSQL(t, `SELECT * FROM users ORDER BY id DESC`, sql) } // assertEqualSQL for assert that the sql is equal, this method will ignore quote, and dialect specials. From 5dd2bb482755f5e8eb5ecaff39e675fb62f19a20 Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Wed, 19 Oct 2022 14:46:59 +0800 Subject: [PATCH 078/231] feat(PreparedStmtDB): support reset (#5782) * feat(PreparedStmtDB): support reset * fix: close all stmt * test: fix test * fix: delete one by one --- prepare_stmt.go | 12 ++++++++++++ tests/prepared_stmt_test.go | 28 +++++++++++++++++++++++++++- 2 files changed, 39 insertions(+), 1 deletion(-) diff --git a/prepare_stmt.go b/prepare_stmt.go index 3934bb97..7591e533 100644 --- a/prepare_stmt.go +++ b/prepare_stmt.go @@ -44,6 +44,18 @@ func (db *PreparedStmtDB) Close() { } } +func (db *PreparedStmtDB) Reset() { + db.Mux.Lock() + defer db.Mux.Unlock() + for query, stmt := range db.Stmts { + delete(db.Stmts, query) + go stmt.Close() + } + + db.PreparedSQL = make([]string, 0, 100) + db.Stmts = map[string](*Stmt){} +} + func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) { db.Mux.RLock() if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) { diff --git a/tests/prepared_stmt_test.go b/tests/prepared_stmt_test.go index c7f251f2..64baa01b 100644 --- a/tests/prepared_stmt_test.go +++ b/tests/prepared_stmt_test.go @@ -2,8 +2,8 @@ package tests_test import ( "context" - "sync" "errors" + "sync" "testing" "time" @@ -168,3 +168,29 @@ func TestPreparedStmtInTransaction(t *testing.T) { t.Errorf("Failed, got error: %v", err) } } + +func TestPreparedStmtReset(t *testing.T) { + tx := DB.Session(&gorm.Session{PrepareStmt: true}) + + user := *GetUser("prepared_stmt_reset", Config{}) + tx = tx.Create(&user) + + pdb, ok := tx.ConnPool.(*gorm.PreparedStmtDB) + if !ok { + t.Fatalf("should assign PreparedStatement Manager back to database when using PrepareStmt mode") + } + + pdb.Mux.Lock() + if len(pdb.Stmts) == 0 { + pdb.Mux.Unlock() + t.Fatalf("prepared stmt can not be empty") + } + pdb.Mux.Unlock() + + pdb.Reset() + pdb.Mux.Lock() + defer pdb.Mux.Unlock() + if len(pdb.Stmts) != 0 { + t.Fatalf("prepared stmt should be empty") + } +} From 9d82aa56734999bb28e0c4d60fba69ae7cde66d5 Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Thu, 20 Oct 2022 14:10:47 +0800 Subject: [PATCH 079/231] test: invalid cache plan with prepare stmt (#5778) * test: invalid cache plan with prepare stmt * test: more test cases * test: drop and rename column --- tests/migrate_test.go | 99 ++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 98 insertions(+), 1 deletion(-) diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 8718aa57..96b1d0e4 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "math/rand" + "os" "reflect" "strings" "testing" @@ -12,6 +13,7 @@ import ( "gorm.io/driver/postgres" "gorm.io/gorm" + "gorm.io/gorm/logger" "gorm.io/gorm/schema" . "gorm.io/gorm/utils/tests" ) @@ -890,7 +892,7 @@ func findColumnType(dest interface{}, columnName string) ( return } -func TestInvalidCachedPlan(t *testing.T) { +func TestInvalidCachedPlanSimpleProtocol(t *testing.T) { if DB.Dialector.Name() != "postgres" { return } @@ -925,6 +927,101 @@ func TestInvalidCachedPlan(t *testing.T) { } } +func TestInvalidCachedPlanPrepareStmt(t *testing.T) { + if DB.Dialector.Name() != "postgres" { + return + } + + db, err := gorm.Open(postgres.Open(postgresDSN), &gorm.Config{PrepareStmt: true}) + if err != nil { + t.Errorf("Open err:%v", err) + } + if debug := os.Getenv("DEBUG"); debug == "true" { + db.Logger = db.Logger.LogMode(logger.Info) + } else if debug == "false" { + db.Logger = db.Logger.LogMode(logger.Silent) + } + + type Object1 struct { + ID uint + } + type Object2 struct { + ID uint + Field1 int `gorm:"type:int8"` + } + type Object3 struct { + ID uint + Field1 int `gorm:"type:int4"` + } + type Object4 struct { + ID uint + Field2 int + } + db.Migrator().DropTable("objects") + + err = db.Table("objects").AutoMigrate(&Object1{}) + if err != nil { + t.Errorf("AutoMigrate err:%v", err) + } + err = db.Table("objects").Create(&Object1{}).Error + if err != nil { + t.Errorf("create err:%v", err) + } + + // AddColumn + err = db.Table("objects").AutoMigrate(&Object2{}) + if err != nil { + t.Errorf("AutoMigrate err:%v", err) + } + + err = db.Table("objects").Take(&Object2{}).Error + if err != nil { + t.Errorf("take err:%v", err) + } + + // AlterColumn + err = db.Table("objects").AutoMigrate(&Object3{}) + if err != nil { + t.Errorf("AutoMigrate err:%v", err) + } + + err = db.Table("objects").Take(&Object3{}).Error + if err != nil { + t.Errorf("take err:%v", err) + } + + // AddColumn + err = db.Table("objects").AutoMigrate(&Object4{}) + if err != nil { + t.Errorf("AutoMigrate err:%v", err) + } + + err = db.Table("objects").Take(&Object4{}).Error + if err != nil { + t.Errorf("take err:%v", err) + } + + db.Table("objects").Migrator().RenameColumn(&Object4{}, "field2", "field3") + if err != nil { + t.Errorf("RenameColumn err:%v", err) + } + + err = db.Table("objects").Take(&Object4{}).Error + if err != nil { + t.Errorf("take err:%v", err) + } + + db.Table("objects").Migrator().DropColumn(&Object4{}, "field3") + if err != nil { + t.Errorf("RenameColumn err:%v", err) + } + + err = db.Table("objects").Take(&Object4{}).Error + if err != nil { + t.Errorf("take err:%v", err) + } +} + func TestDifferentTypeWithoutDeclaredLength(t *testing.T) { type DiffType struct { ID uint From b2f42528a48aeed9612d43e19cdf4fe8e87a27a3 Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Wed, 2 Nov 2022 10:28:00 +0800 Subject: [PATCH 080/231] fix(Joins): args with select and omit (#5790) * fix(Joins): args with select and omit * chore: gofumpt style --- callbacks/query.go | 18 ++++++++++++----- chainable_api.go | 49 ++++++++++++++++++++++++++------------------- statement.go | 13 +++++++----- tests/joins_test.go | 43 +++++++++++++++++++++++++++++++++++++++ 4 files changed, 92 insertions(+), 31 deletions(-) diff --git a/callbacks/query.go b/callbacks/query.go index 26ee8c34..67936766 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -117,12 +117,20 @@ func BuildQuerySQL(db *gorm.DB) { } else if relation, ok := db.Statement.Schema.Relationships.Relations[join.Name]; ok { tableAliasName := relation.Name + columnStmt := gorm.Statement{ + Table: tableAliasName, DB: db, Schema: relation.FieldSchema, + Selects: join.Selects, Omits: join.Omits, + } + + selectColumns, restricted := columnStmt.SelectAndOmitColumns(false, false) for _, s := range relation.FieldSchema.DBNames { - clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ - Table: tableAliasName, - Name: s, - Alias: tableAliasName + "__" + s, - }) + if v, ok := selectColumns[s]; (ok && v) || (!ok && !restricted) { + clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ + Table: tableAliasName, + Name: s, + Alias: tableAliasName + "__" + s, + }) + } } exprs := make([]clause.Expression, len(relation.References)) diff --git a/chainable_api.go b/chainable_api.go index ab3a1a32..6d48d56b 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -10,10 +10,11 @@ import ( ) // Model specify the model you would like to run db operations -// // update all users's name to `hello` -// db.Model(&User{}).Update("name", "hello") -// // if user's primary key is non-blank, will use it as condition, then will only update the user's name to `hello` -// db.Model(&user).Update("name", "hello") +// +// // update all users's name to `hello` +// db.Model(&User{}).Update("name", "hello") +// // if user's primary key is non-blank, will use it as condition, then will only update the user's name to `hello` +// db.Model(&user).Update("name", "hello") func (db *DB) Model(value interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.Model = value @@ -179,18 +180,21 @@ func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) { } // Joins specify Joins conditions -// db.Joins("Account").Find(&user) -// db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user) -// db.Joins("Account", DB.Select("id").Where("user_id = users.id AND name = ?", "someName").Model(&Account{})) +// +// db.Joins("Account").Find(&user) +// db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user) +// db.Joins("Account", DB.Select("id").Where("user_id = users.id AND name = ?", "someName").Model(&Account{})) func (db *DB) Joins(query string, args ...interface{}) (tx *DB) { tx = db.getInstance() if len(args) == 1 { if db, ok := args[0].(*DB); ok { + j := join{Name: query, Conds: args, Selects: db.Statement.Selects, Omits: db.Statement.Omits} if where, ok := db.Statement.Clauses["WHERE"].Expression.(clause.Where); ok { - tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args, On: &where}) - return + j.On = &where } + tx.Statement.Joins = append(tx.Statement.Joins, j) + return } } @@ -219,8 +223,9 @@ func (db *DB) Having(query interface{}, args ...interface{}) (tx *DB) { } // Order specify order when retrieve records from database -// db.Order("name DESC") -// db.Order(clause.OrderByColumn{Column: clause.Column{Name: "name"}, Desc: true}) +// +// db.Order("name DESC") +// db.Order(clause.OrderByColumn{Column: clause.Column{Name: "name"}, Desc: true}) func (db *DB) Order(value interface{}) (tx *DB) { tx = db.getInstance() @@ -256,17 +261,18 @@ func (db *DB) Offset(offset int) (tx *DB) { } // Scopes pass current database connection to arguments `func(DB) DB`, which could be used to add conditions dynamically -// func AmountGreaterThan1000(db *gorm.DB) *gorm.DB { -// return db.Where("amount > ?", 1000) -// } // -// func OrderStatus(status []string) func (db *gorm.DB) *gorm.DB { -// return func (db *gorm.DB) *gorm.DB { -// return db.Scopes(AmountGreaterThan1000).Where("status in (?)", status) -// } -// } +// func AmountGreaterThan1000(db *gorm.DB) *gorm.DB { +// return db.Where("amount > ?", 1000) +// } // -// db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders) +// func OrderStatus(status []string) func (db *gorm.DB) *gorm.DB { +// return func (db *gorm.DB) *gorm.DB { +// return db.Scopes(AmountGreaterThan1000).Where("status in (?)", status) +// } +// } +// +// db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders) func (db *DB) Scopes(funcs ...func(*DB) *DB) (tx *DB) { tx = db.getInstance() tx.Statement.scopes = append(tx.Statement.scopes, funcs...) @@ -274,7 +280,8 @@ func (db *DB) Scopes(funcs ...func(*DB) *DB) (tx *DB) { } // Preload preload associations with given conditions -// db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users) +// +// db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users) func (db *DB) Preload(query string, args ...interface{}) (tx *DB) { tx = db.getInstance() if tx.Statement.Preloads == nil { diff --git a/statement.go b/statement.go index d05d299e..d4d20cbf 100644 --- a/statement.go +++ b/statement.go @@ -49,9 +49,11 @@ type Statement struct { } type join struct { - Name string - Conds []interface{} - On *clause.Where + Name string + Conds []interface{} + On *clause.Where + Selects []string + Omits []string } // StatementModifier statement modifier interface @@ -544,8 +546,9 @@ func (stmt *Statement) clone() *Statement { } // SetColumn set column's value -// stmt.SetColumn("Name", "jinzhu") // Hooks Method -// stmt.SetColumn("Name", "jinzhu", true) // Callbacks Method +// +// stmt.SetColumn("Name", "jinzhu") // Hooks Method +// stmt.SetColumn("Name", "jinzhu", true) // Callbacks Method func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks ...bool) { if v, ok := stmt.Dest.(map[string]interface{}); ok { v[name] = value diff --git a/tests/joins_test.go b/tests/joins_test.go index 7519db82..091fb986 100644 --- a/tests/joins_test.go +++ b/tests/joins_test.go @@ -260,3 +260,46 @@ func TestJoinWithSameColumnName(t *testing.T) { t.Fatalf("wrong pet name") } } + +func TestJoinArgsWithDB(t *testing.T) { + user := *GetUser("joins-args-db", Config{Pets: 2}) + DB.Save(&user) + + // test where + var user1 User + onQuery := DB.Where(&Pet{Name: "joins-args-db_pet_2"}) + if err := DB.Joins("NamedPet", onQuery).Where("users.name = ?", user.Name).First(&user1).Error; err != nil { + t.Fatalf("Failed to load with joins on, got error: %v", err) + } + + AssertEqual(t, user1.NamedPet.Name, "joins-args-db_pet_2") + + // test where and omit + onQuery2 := DB.Where(&Pet{Name: "joins-args-db_pet_2"}).Omit("Name") + var user2 User + if err := DB.Joins("NamedPet", onQuery2).Where("users.name = ?", user.Name).First(&user2).Error; err != nil { + t.Fatalf("Failed to load with joins on, got error: %v", err) + } + AssertEqual(t, user2.NamedPet.ID, user1.NamedPet.ID) + AssertEqual(t, user2.NamedPet.Name, "") + + // test where and select + onQuery3 := DB.Where(&Pet{Name: "joins-args-db_pet_2"}).Select("Name") + var user3 User + if err := DB.Joins("NamedPet", onQuery3).Where("users.name = ?", user.Name).First(&user3).Error; err != nil { + t.Fatalf("Failed to load with joins on, got error: %v", err) + } + AssertEqual(t, user3.NamedPet.ID, 0) + AssertEqual(t, user3.NamedPet.Name, "joins-args-db_pet_2") + + // test select + onQuery4 := DB.Select("ID") + var user4 User + if err := DB.Joins("NamedPet", onQuery4).Where("users.name = ?", user.Name).First(&user4).Error; err != nil { + t.Fatalf("Failed to load with joins on, got error: %v", err) + } + if user4.NamedPet.ID == 0 { + t.Fatal("Pet ID can not be empty") + } + AssertEqual(t, user4.NamedPet.Name, "") +} From f82e9cfdbed051e8e397e2fd1f7ab62c17ff8a4f Mon Sep 17 00:00:00 2001 From: jessetang <1430482733@qq.com> Date: Thu, 3 Nov 2022 21:03:13 +0800 Subject: [PATCH 081/231] test(clause/joins): add join unit test (#5832) --- clause/joins.go | 2 +- clause/joins_test.go | 101 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 102 insertions(+), 1 deletion(-) create mode 100644 clause/joins_test.go diff --git a/clause/joins.go b/clause/joins.go index f3e373f2..879892be 100644 --- a/clause/joins.go +++ b/clause/joins.go @@ -9,7 +9,7 @@ const ( RightJoin JoinType = "RIGHT" ) -// Join join clause for from +// Join clause for from type Join struct { Type JoinType Table Table diff --git a/clause/joins_test.go b/clause/joins_test.go new file mode 100644 index 00000000..f1f20ec3 --- /dev/null +++ b/clause/joins_test.go @@ -0,0 +1,101 @@ +package clause_test + +import ( + "sync" + "testing" + + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" + "gorm.io/gorm/utils/tests" +) + +func TestJoin(t *testing.T) { + results := []struct { + name string + join clause.Join + sql string + }{ + { + name: "LEFT JOIN", + join: clause.Join{ + Type: clause.LeftJoin, + Table: clause.Table{Name: "user"}, + ON: clause.Where{ + Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}}, + }, + }, + sql: "LEFT JOIN `user` ON `user_info`.`user_id` = `users`.`id`", + }, + { + name: "RIGHT JOIN", + join: clause.Join{ + Type: clause.RightJoin, + Table: clause.Table{Name: "user"}, + ON: clause.Where{ + Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}}, + }, + }, + sql: "RIGHT JOIN `user` ON `user_info`.`user_id` = `users`.`id`", + }, + { + name: "INNER JOIN", + join: clause.Join{ + Type: clause.InnerJoin, + Table: clause.Table{Name: "user"}, + ON: clause.Where{ + Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}}, + }, + }, + sql: "INNER JOIN `user` ON `user_info`.`user_id` = `users`.`id`", + }, + { + name: "CROSS JOIN", + join: clause.Join{ + Type: clause.CrossJoin, + Table: clause.Table{Name: "user"}, + ON: clause.Where{ + Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}}, + }, + }, + sql: "CROSS JOIN `user` ON `user_info`.`user_id` = `users`.`id`", + }, + { + name: "USING", + join: clause.Join{ + Type: clause.InnerJoin, + Table: clause.Table{Name: "user"}, + Using: []string{"id"}, + }, + sql: "INNER JOIN `user` USING (`id`)", + }, + { + name: "Expression", + join: clause.Join{ + // Invalid + Type: clause.LeftJoin, + Table: clause.Table{Name: "user"}, + ON: clause.Where{ + Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}}, + }, + // Valid + Expression: clause.Join{ + Type: clause.InnerJoin, + Table: clause.Table{Name: "user"}, + Using: []string{"id"}, + }, + }, + sql: "INNER JOIN `user` USING (`id`)", + }, + } + for _, result := range results { + t.Run(result.name, func(t *testing.T) { + user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) + stmt := &gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}} + result.join.Build(stmt) + if result.sql != stmt.SQL.String() { + t.Errorf("want: %s, got: %s", result.sql, stmt.SQL.String()) + } + }) + } +} From 5c8ecc3a2ad2aa570ecc0bb947138539a1bad9cf Mon Sep 17 00:00:00 2001 From: jessetang <1430482733@qq.com> Date: Sat, 5 Nov 2022 08:37:37 +0800 Subject: [PATCH 082/231] feat: golangci add goimports and whitespace (#5835) --- .golangci.yml | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/.golangci.yml b/.golangci.yml index 16903ed6..b88bf672 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -9,3 +9,12 @@ linters: - prealloc - unconvert - unparam + - goimports + - whitespace + +linters-settings: + whitespace: + multi-func: true + goimports: + local-prefixes: gorm.io/gorm + From fb640cf7daee5a4c6b738299a711612624112de7 Mon Sep 17 00:00:00 2001 From: jessetang <1430482733@qq.com> Date: Sat, 5 Nov 2022 08:38:14 +0800 Subject: [PATCH 083/231] test(utils): add utils unit test (#5834) --- utils/utils_test.go | 106 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 106 insertions(+) diff --git a/utils/utils_test.go b/utils/utils_test.go index 27dfee16..71eef964 100644 --- a/utils/utils_test.go +++ b/utils/utils_test.go @@ -1,8 +1,13 @@ package utils import ( + "database/sql" + "database/sql/driver" + "errors" + "math" "strings" "testing" + "time" ) func TestIsValidDBNameChar(t *testing.T) { @@ -13,6 +18,29 @@ func TestIsValidDBNameChar(t *testing.T) { } } +func TestCheckTruth(t *testing.T) { + checkTruthTests := []struct { + v string + out bool + }{ + {"123", true}, + {"true", true}, + {"", false}, + {"false", false}, + {"False", false}, + {"FALSE", false}, + {"\u0046alse", false}, + } + + for _, test := range checkTruthTests { + t.Run(test.v, func(t *testing.T) { + if out := CheckTruth(test.v); out != test.out { + t.Errorf("CheckTruth(%s) want: %t, got: %t", test.v, test.out, out) + } + }) + } +} + func TestToStringKey(t *testing.T) { cases := []struct { values []interface{} @@ -29,3 +57,81 @@ func TestToStringKey(t *testing.T) { } } } + +func TestContains(t *testing.T) { + containsTests := []struct { + name string + elems []string + elem string + out bool + }{ + {"exists", []string{"1", "2", "3"}, "1", true}, + {"not exists", []string{"1", "2", "3"}, "4", false}, + } + for _, test := range containsTests { + t.Run(test.name, func(t *testing.T) { + if out := Contains(test.elems, test.elem); test.out != out { + t.Errorf("Contains(%v, %s) want: %t, got: %t", test.elems, test.elem, test.out, out) + } + }) + } +} + +type ModifyAt sql.NullTime + +// Value return a Unix time. +func (n ModifyAt) Value() (driver.Value, error) { + if !n.Valid { + return nil, nil + } + return n.Time.Unix(), nil +} + +func TestAssertEqual(t *testing.T) { + now := time.Now() + assertEqualTests := []struct { + name string + src, dst interface{} + out bool + }{ + {"error equal", errors.New("1"), errors.New("1"), true}, + {"error not equal", errors.New("1"), errors.New("2"), false}, + {"driver.Valuer equal", ModifyAt{Time: now, Valid: true}, ModifyAt{Time: now, Valid: true}, true}, + {"driver.Valuer not equal", ModifyAt{Time: now, Valid: true}, ModifyAt{Time: now.Add(time.Second), Valid: true}, false}, + } + for _, test := range assertEqualTests { + t.Run(test.name, func(t *testing.T) { + if out := AssertEqual(test.src, test.dst); test.out != out { + t.Errorf("AssertEqual(%v, %v) want: %t, got: %t", test.src, test.dst, test.out, out) + } + }) + } +} + +func TestToString(t *testing.T) { + tests := []struct { + name string + in interface{} + out string + }{ + {"int", math.MaxInt64, "9223372036854775807"}, + {"int8", int8(math.MaxInt8), "127"}, + {"int16", int16(math.MaxInt16), "32767"}, + {"int32", int32(math.MaxInt32), "2147483647"}, + {"int64", int64(math.MaxInt64), "9223372036854775807"}, + {"uint", uint(math.MaxUint64), "18446744073709551615"}, + {"uint8", uint8(math.MaxUint8), "255"}, + {"uint16", uint16(math.MaxUint16), "65535"}, + {"uint32", uint32(math.MaxUint32), "4294967295"}, + {"uint64", uint64(math.MaxUint64), "18446744073709551615"}, + {"string", "abc", "abc"}, + {"other", true, ""}, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if out := ToString(test.in); test.out != out { + t.Fatalf("ToString(%v) want: %s, got: %s", test.in, test.out, out) + } + }) + } +} From 871f1de6b93835b069b6ef1bcbd823047a47c7a9 Mon Sep 17 00:00:00 2001 From: kvii <56432636+kvii@users.noreply.github.com> Date: Sat, 5 Nov 2022 11:52:08 +0800 Subject: [PATCH 084/231] fix logger path bug (#5836) --- utils/utils.go | 15 +++++++++++++-- utils/utils_unix_test.go | 33 +++++++++++++++++++++++++++++++++ utils/utils_windows_test.go | 33 +++++++++++++++++++++++++++++++++ 3 files changed, 79 insertions(+), 2 deletions(-) create mode 100644 utils/utils_unix_test.go create mode 100644 utils/utils_windows_test.go diff --git a/utils/utils.go b/utils/utils.go index 90b4c8ea..2d87f4c2 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -3,8 +3,8 @@ package utils import ( "database/sql/driver" "fmt" + "path/filepath" "reflect" - "regexp" "runtime" "strconv" "strings" @@ -16,7 +16,18 @@ var gormSourceDir string func init() { _, file, _, _ := runtime.Caller(0) // compatible solution to get gorm source directory with various operating systems - gormSourceDir = regexp.MustCompile(`gorm.utils.utils\.go`).ReplaceAllString(file, "") + gormSourceDir = sourceDir(file) +} + +func sourceDir(file string) string { + dir := filepath.Dir(file) + dir = filepath.Dir(dir) + + s := filepath.Dir(dir) + if filepath.Base(s) != "gorm.io" { + s = dir + } + return s + string(filepath.Separator) } // FileWithLineNum return the file name and line number of the current file diff --git a/utils/utils_unix_test.go b/utils/utils_unix_test.go new file mode 100644 index 00000000..da97aa2c --- /dev/null +++ b/utils/utils_unix_test.go @@ -0,0 +1,33 @@ +package utils + +import "testing" + +func TestSourceDir(t *testing.T) { + cases := []struct { + file string + want string + }{ + { + file: "/Users/name/go/pkg/mod/gorm.io/gorm@v1.2.3/utils/utils.go", + want: "/Users/name/go/pkg/mod/gorm.io/", + }, + { + file: "/go/work/proj/gorm/utils/utils.go", + want: "/go/work/proj/gorm/", + }, + { + file: "/go/work/proj/gorm_alias/utils/utils.go", + want: "/go/work/proj/gorm_alias/", + }, + { + file: "/go/work/proj/my.gorm.io/gorm@v1.2.3/utils/utils.go", + want: "/go/work/proj/my.gorm.io/gorm@v1.2.3/", + }, + } + for _, c := range cases { + s := sourceDir(c.file) + if s != c.want { + t.Fatalf("%s: expected %s, got %s", c.file, c.want, s) + } + } +} diff --git a/utils/utils_windows_test.go b/utils/utils_windows_test.go new file mode 100644 index 00000000..d1734e0e --- /dev/null +++ b/utils/utils_windows_test.go @@ -0,0 +1,33 @@ +package utils + +import "testing" + +func TestSourceDir(t *testing.T) { + cases := []struct { + file string + want string + }{ + { + file: `C:\Users\name\go\pkg\mod\gorm.io\gorm@v1.20.8\utils\utils.go`, + want: `C:\Users\name\go\pkg\mod\gorm.io`, + }, + { + file: `C:\go\work\proj\gorm\utils\utils.go`, + want: `C:\go\work\proj\gorm`, + }, + { + file: `C:\go\work\proj\gorm_alias\utils\utils.go`, + want: `C:\go\work\proj\gorm_alias`, + }, + { + file: `C:\go\work\proj\my.gorm.io\gorm\utils\utils.go`, + want: `C:\go\work\proj\my.gorm.io\gorm`, + }, + } + for _, c := range cases { + s := sourceDir(c.file) + if s != c.want { + t.Fatalf("%s: expected %s, got %s", c.file, c.want, s) + } + } +} From 1b9cd56c5336ba6e22936c289e586261b75d7b35 Mon Sep 17 00:00:00 2001 From: jessetang <1430482733@qq.com> Date: Thu, 10 Nov 2022 16:30:32 +0800 Subject: [PATCH 085/231] doc(README.md): add contributors (#5847) --- README.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/README.md b/README.md index 5bb1be37..68fa6603 100644 --- a/README.md +++ b/README.md @@ -36,6 +36,12 @@ The fantastic ORM library for Golang, aims to be developer friendly. [You can help to deliver a better GORM, check out things you can do](https://gorm.io/contribute.html) +## Contributors + +Thank you for contributing to the GORM framework! + +[![Contributors](https://contrib.rocks/image?repo=go-gorm/gorm)](https://github.com/go-gorm/gorm/graphs/contributors) + ## License © Jinzhu, 2013~time.Now From cef3de694d9615c574e82dfa0b50fc7ea2816f3e Mon Sep 17 00:00:00 2001 From: jessetang <1430482733@qq.com> Date: Sun, 13 Nov 2022 11:12:09 +0800 Subject: [PATCH 086/231] cleanup(prepare_stmt.go): unnecessary map delete (#5849) --- gorm.go | 2 +- prepare_stmt.go | 9 ++++----- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/gorm.go b/gorm.go index 589fc4ff..89488b75 100644 --- a/gorm.go +++ b/gorm.go @@ -179,7 +179,7 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) { preparedStmt := &PreparedStmtDB{ ConnPool: db.ConnPool, - Stmts: map[string](*Stmt){}, + Stmts: make(map[string]*Stmt), Mux: &sync.RWMutex{}, PreparedSQL: make([]string, 0, 100), } diff --git a/prepare_stmt.go b/prepare_stmt.go index 7591e533..e09fe814 100644 --- a/prepare_stmt.go +++ b/prepare_stmt.go @@ -47,13 +47,12 @@ func (db *PreparedStmtDB) Close() { func (db *PreparedStmtDB) Reset() { db.Mux.Lock() defer db.Mux.Unlock() - for query, stmt := range db.Stmts { - delete(db.Stmts, query) + + for _, stmt := range db.Stmts { go stmt.Close() } - db.PreparedSQL = make([]string, 0, 100) - db.Stmts = map[string](*Stmt){} + db.Stmts = make(map[string]*Stmt) } func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) { @@ -93,7 +92,7 @@ func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransact // Reason why cannot lock conn.PrepareContext // suppose the maxopen is 1, g1 is creating record and g2 is querying record. - // 1. g1 begin tx, g1 is requeued because of waiting for the system call, now `db.ConnPool` db.numOpen == 1. + // 1. g1 begin tx, g1 is requeue because of waiting for the system call, now `db.ConnPool` db.numOpen == 1. // 2. g2 select lock `conn.PrepareContext(ctx, query)`, now db.numOpen == db.maxOpen , wait for release. // 3. g1 tx exec insert, wait for unlock `conn.PrepareContext(ctx, query)` to finish tx and release. stmt, err := conn.PrepareContext(ctx, query) From b6836c2d3ee91c0f0114736084d033f2b0a96748 Mon Sep 17 00:00:00 2001 From: kvii <56432636+kvii@users.noreply.github.com> Date: Mon, 21 Nov 2022 10:48:13 +0800 Subject: [PATCH 087/231] fix bug in windows (#5844) * fix bug in windows * fix file name bug * test in unix like platform --- utils/utils.go | 2 +- utils/utils_unix_test.go | 7 ++++++- utils/utils_windows_test.go | 20 +++++++++++--------- 3 files changed, 18 insertions(+), 11 deletions(-) diff --git a/utils/utils.go b/utils/utils.go index 2d87f4c2..e08533cd 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -27,7 +27,7 @@ func sourceDir(file string) string { if filepath.Base(s) != "gorm.io" { s = dir } - return s + string(filepath.Separator) + return filepath.ToSlash(s) + "/" } // FileWithLineNum return the file name and line number of the current file diff --git a/utils/utils_unix_test.go b/utils/utils_unix_test.go index da97aa2c..450cbe2a 100644 --- a/utils/utils_unix_test.go +++ b/utils/utils_unix_test.go @@ -1,6 +1,11 @@ +//go:build unix +// +build unix + package utils -import "testing" +import ( + "testing" +) func TestSourceDir(t *testing.T) { cases := []struct { diff --git a/utils/utils_windows_test.go b/utils/utils_windows_test.go index d1734e0e..8b1c519d 100644 --- a/utils/utils_windows_test.go +++ b/utils/utils_windows_test.go @@ -1,6 +1,8 @@ package utils -import "testing" +import ( + "testing" +) func TestSourceDir(t *testing.T) { cases := []struct { @@ -8,20 +10,20 @@ func TestSourceDir(t *testing.T) { want string }{ { - file: `C:\Users\name\go\pkg\mod\gorm.io\gorm@v1.20.8\utils\utils.go`, - want: `C:\Users\name\go\pkg\mod\gorm.io`, + file: `C:/Users/name/go/pkg/mod/gorm.io/gorm@v1.2.3/utils/utils.go`, + want: `C:/Users/name/go/pkg/mod/gorm.io/`, }, { - file: `C:\go\work\proj\gorm\utils\utils.go`, - want: `C:\go\work\proj\gorm`, + file: `C:/go/work/proj/gorm/utils/utils.go`, + want: `C:/go/work/proj/gorm/`, }, { - file: `C:\go\work\proj\gorm_alias\utils\utils.go`, - want: `C:\go\work\proj\gorm_alias`, + file: `C:/go/work/proj/gorm_alias/utils/utils.go`, + want: `C:/go/work/proj/gorm_alias/`, }, { - file: `C:\go\work\proj\my.gorm.io\gorm\utils\utils.go`, - want: `C:\go\work\proj\my.gorm.io\gorm`, + file: `C:/go/work/proj/my.gorm.io/gorm@v1.2.3/utils/utils.go`, + want: `C:/go/work/proj/my.gorm.io/gorm@v1.2.3/`, }, } for _, c := range cases { From 342310fba4fc56decf3d417925326db483734d7e Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Mon, 21 Nov 2022 10:49:27 +0800 Subject: [PATCH 088/231] fix(FindInBatches): throw err if pk not exists (#5868) --- finisher_api.go | 11 ++++++++--- tests/query_test.go | 7 +++++++ 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/finisher_api.go b/finisher_api.go index 5516c0a1..cc07a126 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -231,7 +231,11 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat break } - primaryValue, _ := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1)) + primaryValue, zero := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1)) + if zero { + tx.AddError(ErrPrimaryKeyRequired) + break + } queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue}) } @@ -514,8 +518,9 @@ func (db *DB) Scan(dest interface{}) (tx *DB) { } // Pluck queries a single column from a model, returning in the slice dest. E.g.: -// var ages []int64 -// db.Model(&users).Pluck("age", &ages) +// +// var ages []int64 +// db.Model(&users).Pluck("age", &ages) func (db *DB) Pluck(column string, dest interface{}) (tx *DB) { tx = db.getInstance() if tx.Statement.Model != nil { diff --git a/tests/query_test.go b/tests/query_test.go index eccf0133..fa8f09e8 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -408,6 +408,13 @@ func TestFindInBatchesWithError(t *testing.T) { if totalBatch != 0 { t.Fatalf("incorrect total batch, expected: %v, got: %v", 0, totalBatch) } + + if result := DB.Omit("id").Where("name = ?", users[0].Name).FindInBatches(&results, 2, func(tx *gorm.DB, batch int) error { + totalBatch += batch + return nil + }); result.Error != gorm.ErrPrimaryKeyRequired { + t.Fatal("expected errors to have occurred, but nothing happened") + } } func TestFillSmallerStruct(t *testing.T) { From f91313436abcfe7a28a488d5d6777b31a94f24fb Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 21 Nov 2022 11:10:56 +0800 Subject: [PATCH 089/231] Fix group by with count logic --- finisher_api.go | 2 +- tests/count_test.go | 2 +- tests/go.mod | 10 +++++----- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/finisher_api.go b/finisher_api.go index cc07a126..33d7a5a6 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -465,7 +465,7 @@ func (db *DB) Count(count *int64) (tx *DB) { tx.Statement.Dest = count tx = tx.callbacks.Query().Execute(tx) - if _, ok := db.Statement.Clauses["GROUP BY"]; ok || tx.RowsAffected != 1 { + if tx.RowsAffected != 1 { *count = tx.RowsAffected } diff --git a/tests/count_test.go b/tests/count_test.go index b71e3de5..2199dc6d 100644 --- a/tests/count_test.go +++ b/tests/count_test.go @@ -141,7 +141,7 @@ func TestCount(t *testing.T) { } DB.Create(sameUsers) - if err := DB.Model(&User{}).Where("name = ?", "count-4").Group("name").Count(&count11).Error; err != nil || count11 != 1 { + if err := DB.Model(&User{}).Where("name = ?", "count-4").Group("name").Count(&count11).Error; err != nil || count11 != int64(len(sameUsers)) { t.Fatalf("Count should be 3, but got count: %v err %v", count11, err) } diff --git a/tests/go.mod b/tests/go.mod index 9c87ca34..23fc2cad 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -6,13 +6,13 @@ require ( github.com/google/uuid v1.3.0 github.com/jinzhu/now v1.1.5 github.com/lib/pq v1.10.7 - golang.org/x/crypto v0.0.0-20221012134737-56aed061732a // indirect - golang.org/x/text v0.4.0 // indirect - gorm.io/driver/mysql v1.4.3 - gorm.io/driver/postgres v1.4.4 + github.com/mattn/go-sqlite3 v1.14.16 // indirect + golang.org/x/crypto v0.3.0 // indirect + gorm.io/driver/mysql v1.4.4 + gorm.io/driver/postgres v1.4.5 gorm.io/driver/sqlite v1.4.3 gorm.io/driver/sqlserver v1.4.1 - gorm.io/gorm v1.24.0 + gorm.io/gorm v1.24.2 ) replace gorm.io/gorm => ../ From f931def33d23c9fd3c23ccb276e0f8bc17f8337f Mon Sep 17 00:00:00 2001 From: wjw1758548031 <46154774+wjw1758548031@users.noreply.github.com> Date: Thu, 1 Dec 2022 20:25:53 +0800 Subject: [PATCH 090/231] clear code syntax (#5889) * clear code syntax * clear code syntax --- finisher_api.go | 75 +++++++++++++++++++++++++------------------------ 1 file changed, 39 insertions(+), 36 deletions(-) diff --git a/finisher_api.go b/finisher_api.go index 33d7a5a6..b30ca24d 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -326,45 +326,48 @@ func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) { queryTx := db.Session(&Session{}).Limit(1).Order(clause.OrderByColumn{ Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, }) - if result := queryTx.Find(dest, conds...); result.Error == nil { - if result.RowsAffected == 0 { - if c, ok := result.Statement.Clauses["WHERE"]; ok { - if where, ok := c.Expression.(clause.Where); ok { - result.assignInterfacesToValue(where.Exprs) - } - } - // initialize with attrs, conds - if len(db.Statement.attrs) > 0 { - result.assignInterfacesToValue(db.Statement.attrs...) - } - - // initialize with attrs, conds - if len(db.Statement.assigns) > 0 { - result.assignInterfacesToValue(db.Statement.assigns...) - } - - return tx.Create(dest) - } else if len(db.Statement.assigns) > 0 { - exprs := tx.Statement.BuildCondition(db.Statement.assigns[0], db.Statement.assigns[1:]...) - assigns := map[string]interface{}{} - for _, expr := range exprs { - if eq, ok := expr.(clause.Eq); ok { - switch column := eq.Column.(type) { - case string: - assigns[column] = eq.Value - case clause.Column: - assigns[column.Name] = eq.Value - default: - } - } - } - - return tx.Model(dest).Updates(assigns) - } - } else { + result := queryTx.Find(dest, conds...) + if result.Error != nil { tx.Error = result.Error + return tx } + + if result.RowsAffected == 0 { + if c, ok := result.Statement.Clauses["WHERE"]; ok { + if where, ok := c.Expression.(clause.Where); ok { + result.assignInterfacesToValue(where.Exprs) + } + } + + // initialize with attrs, conds + if len(db.Statement.attrs) > 0 { + result.assignInterfacesToValue(db.Statement.attrs...) + } + + // initialize with attrs, conds + if len(db.Statement.assigns) > 0 { + result.assignInterfacesToValue(db.Statement.assigns...) + } + + return tx.Create(dest) + } else if len(db.Statement.assigns) > 0 { + exprs := tx.Statement.BuildCondition(db.Statement.assigns[0], db.Statement.assigns[1:]...) + assigns := map[string]interface{}{} + for _, expr := range exprs { + if eq, ok := expr.(clause.Eq); ok { + switch column := eq.Column.(type) { + case string: + assigns[column] = eq.Value + case clause.Column: + assigns[column.Name] = eq.Value + } + } + } + + return tx.Model(dest).Updates(assigns) + } + return tx } From d9525d4da45d343cdfb8641a72735330b9e86c88 Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Thu, 1 Dec 2022 20:26:59 +0800 Subject: [PATCH 091/231] fix: skip append relation field to default db value (#5885) * fix: relation field returning * chore: gofumpt style --- schema/schema.go | 2 +- tests/associations_belongs_to_test.go | 26 ++++++++++++++++++++++++++ 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/schema/schema.go b/schema/schema.go index 9b3d30f6..21e71c21 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -230,7 +230,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam } for _, field := range schema.Fields { - if field.HasDefaultValue && field.DefaultValueInterface == nil { + if field.DataType != "" && field.HasDefaultValue && field.DefaultValueInterface == nil { schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field) } } diff --git a/tests/associations_belongs_to_test.go b/tests/associations_belongs_to_test.go index f74799ce..a1f014d9 100644 --- a/tests/associations_belongs_to_test.go +++ b/tests/associations_belongs_to_test.go @@ -3,6 +3,7 @@ package tests_test import ( "testing" + "gorm.io/gorm" . "gorm.io/gorm/utils/tests" ) @@ -224,3 +225,28 @@ func TestBelongsToAssociationForSlice(t *testing.T) { AssertAssociationCount(t, users[0], "Company", 0, "After Delete") AssertAssociationCount(t, users[1], "Company", 1, "After other user Delete") } + +func TestBelongsToDefaultValue(t *testing.T) { + type Org struct { + ID string + } + type BelongsToUser struct { + OrgID string + Org Org `gorm:"default:NULL"` + } + + tx := DB.Session(&gorm.Session{}) + tx.Config.DisableForeignKeyConstraintWhenMigrating = true + AssertEqual(t, DB.Config.DisableForeignKeyConstraintWhenMigrating, false) + + tx.Migrator().DropTable(&BelongsToUser{}, &Org{}) + tx.AutoMigrate(&BelongsToUser{}, &Org{}) + + user := &BelongsToUser{ + Org: Org{ + ID: "BelongsToUser_Org_1", + }, + } + err := DB.Create(&user).Error + AssertEqual(t, err, nil) +} From 4ec73c9bf46662bfef7a87d766e9c34661846385 Mon Sep 17 00:00:00 2001 From: Edward McFarlane <3036610+emcfarlane@users.noreply.github.com> Date: Mon, 19 Dec 2022 04:49:05 +0100 Subject: [PATCH 092/231] Add test case for embedded value selects (#5901) * Add test case for embedded value selects * Revert recycle struct optimisation to avoid pointer overwrites --- scan.go | 12 +++--------- tests/embedded_struct_test.go | 18 ++++++++++++++++-- 2 files changed, 19 insertions(+), 11 deletions(-) diff --git a/scan.go b/scan.go index 0a26ce4b..12a77862 100644 --- a/scan.go +++ b/scan.go @@ -65,7 +65,6 @@ func (db *DB) scanIntoStruct(rows Rows, reflectValue reflect.Value, values []int db.RowsAffected++ db.AddError(rows.Scan(values...)) - joinedSchemaMap := make(map[*schema.Field]interface{}) for idx, field := range fields { if field == nil { @@ -241,9 +240,8 @@ func Scan(rows Rows, db *DB, mode ScanMode) { switch reflectValue.Kind() { case reflect.Slice, reflect.Array: var ( - elem reflect.Value - recyclableStruct = reflect.New(reflectValueType) - isArrayKind = reflectValue.Kind() == reflect.Array + elem reflect.Value + isArrayKind = reflectValue.Kind() == reflect.Array ) if !update || reflectValue.Len() == 0 { @@ -275,11 +273,7 @@ func Scan(rows Rows, db *DB, mode ScanMode) { } } } else { - if isPtr && db.RowsAffected > 0 { - elem = reflect.New(reflectValueType) - } else { - elem = recyclableStruct - } + elem = reflect.New(reflectValueType) } db.scanIntoStruct(rows, elem, values, fields, joinFields) diff --git a/tests/embedded_struct_test.go b/tests/embedded_struct_test.go index e309d06c..ae69baca 100644 --- a/tests/embedded_struct_test.go +++ b/tests/embedded_struct_test.go @@ -36,7 +36,7 @@ func TestEmbeddedStruct(t *testing.T) { type EngadgetPost struct { BasePost BasePost `gorm:"Embedded"` - Author Author `gorm:"Embedded;EmbeddedPrefix:author_"` // Embedded struct + Author *Author `gorm:"Embedded;EmbeddedPrefix:author_"` // Embedded struct ImageUrl string } @@ -74,13 +74,27 @@ func TestEmbeddedStruct(t *testing.T) { t.Errorf("embedded struct's value should be scanned correctly") } - DB.Save(&EngadgetPost{BasePost: BasePost{Title: "engadget_news"}}) + DB.Save(&EngadgetPost{BasePost: BasePost{Title: "engadget_news"}, Author: &Author{Name: "Edward"}}) + DB.Save(&EngadgetPost{BasePost: BasePost{Title: "engadget_article"}, Author: &Author{Name: "George"}}) var egNews EngadgetPost if err := DB.First(&egNews, "title = ?", "engadget_news").Error; err != nil { t.Errorf("no error should happen when query with embedded struct, but got %v", err) } else if egNews.BasePost.Title != "engadget_news" { t.Errorf("embedded struct's value should be scanned correctly") } + + var egPosts []EngadgetPost + if err := DB.Order("author_name asc").Find(&egPosts).Error; err != nil { + t.Fatalf("no error should happen when query with embedded struct, but got %v", err) + } + expectAuthors := []string{"Edward", "George"} + for i, post := range egPosts { + t.Log(i, post.Author) + if want := expectAuthors[i]; post.Author.Name != want { + t.Errorf("expected author %s got %s", want, post.Author.Name) + } + } + } func TestEmbeddedPointerTypeStruct(t *testing.T) { From f3c6fc253356919e8ebbcf7bc50e8c7fe88802aa Mon Sep 17 00:00:00 2001 From: Nate Armstrong Date: Fri, 23 Dec 2022 00:51:01 -0800 Subject: [PATCH 093/231] Update func comments in chainable_api and FirstOr_ (#5935) Add comments to functions in chainable_api. Depending on the method, these comments add some additional context or details that are relevant when reading the function, link to the actual docs at gorm.io/docs, or provide examples of use. These comments should make GORM much more pleasant to use with an IDE that provides hoverable comments, and are minimal examples. Also add in-code documentation to FirstOrInit and FirstOrCreate. Almost all examples are directly pulled from the docs, with short comments explaining the code. Most examples omit the `db.Model(&User{})` for brevity, and would not actually work. Co-authored-by: Nate Armstrong --- chainable_api.go | 104 ++++++++++++++++++++++++++++++++++++++++++++++- finisher_api.go | 22 ++++++++++ 2 files changed, 124 insertions(+), 2 deletions(-) diff --git a/chainable_api.go b/chainable_api.go index 6d48d56b..68ec7a67 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -13,7 +13,7 @@ import ( // // // update all users's name to `hello` // db.Model(&User{}).Update("name", "hello") -// // if user's primary key is non-blank, will use it as condition, then will only update the user's name to `hello` +// // if user's primary key is non-blank, will use it as condition, then will only update that user's name to `hello` // db.Model(&user).Update("name", "hello") func (db *DB) Model(value interface{}) (tx *DB) { tx = db.getInstance() @@ -22,6 +22,19 @@ func (db *DB) Model(value interface{}) (tx *DB) { } // Clauses Add clauses +// +// This supports both standard clauses (clause.OrderBy, clause.Limit, clause.Where) and more +// advanced techniques like specifying lock strength and optimizer hints. See the +// [docs] for more depth. +// +// // add a simple limit clause +// db.Clauses(clause.Limit{Limit: 1}).Find(&User{}) +// // tell the optimizer to use the `idx_user_name` index +// db.Clauses(hints.UseIndex("idx_user_name")).Find(&User{}) +// // specify the lock strength to UPDATE +// db.Clauses(clause.Locking{Strength: "UPDATE"}).Find(&users) +// +// [docs]: https://gorm.io/docs/sql_builder.html#Clauses func (db *DB) Clauses(conds ...clause.Expression) (tx *DB) { tx = db.getInstance() var whereConds []interface{} @@ -45,6 +58,9 @@ func (db *DB) Clauses(conds ...clause.Expression) (tx *DB) { var tableRegexp = regexp.MustCompile(`(?i).+? AS (\w+)\s*(?:$|,)`) // Table specify the table you would like to run db operations +// +// // Get a user +// db.Table("users").take(&result) func (db *DB) Table(name string, args ...interface{}) (tx *DB) { tx = db.getInstance() if strings.Contains(name, " ") || strings.Contains(name, "`") || len(args) > 0 { @@ -66,6 +82,11 @@ func (db *DB) Table(name string, args ...interface{}) (tx *DB) { } // Distinct specify distinct fields that you want querying +// +// // Select distinct names of users +// db.Distinct("name").Find(&results) +// // Select distinct name/age pairs from users +// db.Distinct("name", "age").Find(&results) func (db *DB) Distinct(args ...interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.Distinct = true @@ -76,6 +97,14 @@ func (db *DB) Distinct(args ...interface{}) (tx *DB) { } // Select specify fields that you want when querying, creating, updating +// +// Use Select when you only want a subset of the fields. By default, GORM will select all fields. +// Select accepts both string arguments and arrays. +// +// // Select name and age of user using multiple arguments +// db.Select("name", "age").Find(&users) +// // Select name and age of user using an array +// db.Select([]string{"name", "age"}).Find(&users) func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) { tx = db.getInstance() @@ -153,6 +182,17 @@ func (db *DB) Omit(columns ...string) (tx *DB) { } // Where add conditions +// +// See the [docs] for details on the various formats that where clauses can take. By default, where clauses chain with AND. +// +// // Find the first user with name jinzhu +// db.Where("name = ?", "jinzhu").First(&user) +// // Find the first user with name jinzhu and age 20 +// db.Where(&User{Name: "jinzhu", Age: 20}).First(&user) +// // Find the first user with name jinzhu and age not equal to 20 +// db.Where("name = ?", "jinzhu").Where("age <> ?", "20").First(&user) +// +// [docs]: https://gorm.io/docs/query.html#Conditions func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) { tx = db.getInstance() if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 { @@ -162,6 +202,11 @@ func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) { } // Not add NOT conditions +// +// Not works similarly to where, and has the same syntax. +// +// // Find the first user with name not equal to jinzhu +// db.Not("name = ?", "jinzhu").First(&user) func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) { tx = db.getInstance() if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 { @@ -171,6 +216,11 @@ func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) { } // Or add OR conditions +// +// Or is used to chain together queries with an OR. +// +// // Find the first user with name equal to jinzhu or john +// db.Where("name = ?", "jinzhu").Or("name = ?", "john").First(&user) func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) { tx = db.getInstance() if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 { @@ -203,6 +253,9 @@ func (db *DB) Joins(query string, args ...interface{}) (tx *DB) { } // Group specify the group method on the find +// +// // Select the sum age of users with given names +// db.Model(&User{}).Select("name, sum(age) as total").Group("name").Find(&results) func (db *DB) Group(name string) (tx *DB) { tx = db.getInstance() @@ -214,6 +267,9 @@ func (db *DB) Group(name string) (tx *DB) { } // Having specify HAVING conditions for GROUP BY +// +// // Select the sum age of users with name jinzhu +// db.Model(&User{}).Select("name, sum(age) as total").Group("name").Having("name = ?", "jinzhu").Find(&result) func (db *DB) Having(query interface{}, args ...interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.AddClause(clause.GroupBy{ @@ -222,7 +278,7 @@ func (db *DB) Having(query interface{}, args ...interface{}) (tx *DB) { return } -// Order specify order when retrieve records from database +// Order specify order when retrieving records from database // // db.Order("name DESC") // db.Order(clause.OrderByColumn{Column: clause.Column{Name: "name"}, Desc: true}) @@ -247,6 +303,13 @@ func (db *DB) Order(value interface{}) (tx *DB) { } // Limit specify the number of records to be retrieved +// +// Limit conditions can be cancelled by using `Limit(-1)`. +// +// // retrieve 3 users +// db.Limit(3).Find(&users) +// // retrieve 3 users into users1, and all users into users2 +// db.Limit(3).Find(&users1).Limit(-1).Find(&users2) func (db *DB) Limit(limit int) (tx *DB) { tx = db.getInstance() tx.Statement.AddClause(clause.Limit{Limit: &limit}) @@ -254,6 +317,13 @@ func (db *DB) Limit(limit int) (tx *DB) { } // Offset specify the number of records to skip before starting to return the records +// +// Offset conditions can be cancelled by using `Offset(-1)`. +// +// // select the third user +// db.Offset(2).First(&user) +// // select the first user by cancelling an earlier chained offset +// db.Offset(5).Offset(-1).First(&user) func (db *DB) Offset(offset int) (tx *DB) { tx = db.getInstance() tx.Statement.AddClause(clause.Limit{Offset: offset}) @@ -281,6 +351,7 @@ func (db *DB) Scopes(funcs ...func(*DB) *DB) (tx *DB) { // Preload preload associations with given conditions // +// // get all users, and preload all non-cancelled orders // db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users) func (db *DB) Preload(query string, args ...interface{}) (tx *DB) { tx = db.getInstance() @@ -291,12 +362,41 @@ func (db *DB) Preload(query string, args ...interface{}) (tx *DB) { return } +// Attrs provide attributes used in [FirstOrCreate] or [FirstOrInit] +// +// Attrs only adds attributes if the record is not found. +// +// // assign an email if the record is not found +// db.Where(User{Name: "non_existing"}).Attrs(User{Email: "fake@fake.org"}).FirstOrInit(&user) +// // user -> User{Name: "non_existing", Email: "fake@fake.org"} +// +// // assign an email if the record is not found, otherwise ignore provided email +// db.Where(User{Name: "jinzhu"}).Attrs(User{Email: "fake@fake.org"}).FirstOrInit(&user) +// // user -> User{Name: "jinzhu", Age: 20} +// +// [FirstOrCreate]: https://gorm.io/docs/advanced_query.html#FirstOrCreate +// [FirstOrInit]: https://gorm.io/docs/advanced_query.html#FirstOrInit func (db *DB) Attrs(attrs ...interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.attrs = attrs return } +// Assign provide attributes used in [FirstOrCreate] or [FirstOrInit] +// +// Assign adds attributes even if the record is found. If using FirstOrCreate, this means that +// records will be updated even if they are found. +// +// // assign an email regardless of if the record is not found +// db.Where(User{Name: "non_existing"}).Assign(User{Email: "fake@fake.org"}).FirstOrInit(&user) +// // user -> User{Name: "non_existing", Email: "fake@fake.org"} +// +// // assign email regardless of if record is found +// db.Where(User{Name: "jinzhu"}).Assign(User{Email: "fake@fake.org"}).FirstOrInit(&user) +// // user -> User{Name: "jinzhu", Age: 20, Email: "fake@fake.org"} +// +// [FirstOrCreate]: https://gorm.io/docs/advanced_query.html#FirstOrCreate +// [FirstOrInit]: https://gorm.io/docs/advanced_query.html#FirstOrInit func (db *DB) Assign(attrs ...interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.assigns = attrs diff --git a/finisher_api.go b/finisher_api.go index b30ca24d..39d9fca3 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -294,6 +294,16 @@ func (db *DB) assignInterfacesToValue(values ...interface{}) { // FirstOrInit finds the first matching record, otherwise if not found initializes a new instance with given conds. // Each conds must be a struct or map. +// +// FirstOrInit never modifies the database. It is often used with Assign and Attrs. +// +// // assign an email if the record is not found +// db.Where(User{Name: "non_existing"}).Attrs(User{Email: "fake@fake.org"}).FirstOrInit(&user) +// // user -> User{Name: "non_existing", Email: "fake@fake.org"} +// +// // assign email regardless of if record is found +// db.Where(User{Name: "jinzhu"}).Assign(User{Email: "fake@fake.org"}).FirstOrInit(&user) +// // user -> User{Name: "jinzhu", Age: 20, Email: "fake@fake.org"} func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { queryTx := db.Limit(1).Order(clause.OrderByColumn{ Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, @@ -321,6 +331,18 @@ func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { // FirstOrCreate finds the first matching record, otherwise if not found creates a new instance with given conds. // Each conds must be a struct or map. +// +// Using FirstOrCreate in conjunction with Assign will result in an update to the database even if the record exists. +// +// // assign an email if the record is not found +// result := db.Where(User{Name: "non_existing"}).Attrs(User{Email: "fake@fake.org"}).FirstOrCreate(&user) +// // user -> User{Name: "non_existing", Email: "fake@fake.org"} +// // result.RowsAffected -> 1 +// +// // assign email regardless of if record is found +// result := db.Where(User{Name: "jinzhu"}).Assign(User{Email: "fake@fake.org"}).FirstOrCreate(&user) +// // user -> User{Name: "jinzhu", Age: 20, Email: "fake@fake.org"} +// // result.RowsAffected -> 1 func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) { tx = db.getInstance() queryTx := db.Session(&Session{}).Limit(1).Order(clause.OrderByColumn{ From bbd2bbe5217f7d3d3df5835748954f3cae6ebb68 Mon Sep 17 00:00:00 2001 From: Ning Date: Sat, 24 Dec 2022 11:02:11 +0800 Subject: [PATCH 094/231] fix:Issue migrating field with CURRENT_TIMESTAMP (#5906) Co-authored-by: ningfei --- migrator/migrator.go | 10 ++++++---- tests/migrate_test.go | 26 ++++++++++++++++++++++++++ 2 files changed, 32 insertions(+), 4 deletions(-) diff --git a/migrator/migrator.go b/migrator/migrator.go index 9f8e3db8..b113b398 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -470,17 +470,19 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy // check default value if !field.PrimaryKey { + currentDefaultNotNull := field.HasDefaultValue && !strings.EqualFold(field.DefaultValue, "NULL") dv, dvNotNull := columnType.DefaultValue() - if dvNotNull && field.DefaultValueInterface == nil { + if dvNotNull && !currentDefaultNotNull { // defalut value -> null alterColumn = true - } else if !dvNotNull && field.DefaultValueInterface != nil { + } else if !dvNotNull && currentDefaultNotNull { // null -> default value alterColumn = true - } else if dv != field.DefaultValue { + } else if (field.GORMDataType != schema.Time && dv != field.DefaultValue) || + (field.GORMDataType == schema.Time && !strings.EqualFold(strings.TrimSuffix(dv, "()"), strings.TrimSuffix(field.DefaultValue, "()"))) { // default value not equal // not both null - if !(field.DefaultValueInterface == nil && !dvNotNull) { + if currentDefaultNotNull || dvNotNull { alterColumn = true } } diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 96b1d0e4..5f7e0749 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -757,6 +757,32 @@ func TestPrimarykeyID(t *testing.T) { } } +func TestCurrentTimestamp(t *testing.T) { + if DB.Dialector.Name() != "mysql" { + return + } + type CurrentTimestampTest struct { + ID string `gorm:"primary_key"` + TimeAt *time.Time `gorm:"type:datetime;not null;default:CURRENT_TIMESTAMP;unique"` + } + var err error + err = DB.Migrator().DropTable(&CurrentTimestampTest{}) + if err != nil { + t.Errorf("DropTable err:%v", err) + } + err = DB.AutoMigrate(&CurrentTimestampTest{}) + if err != nil { + t.Fatalf("AutoMigrate err:%v", err) + } + + err = DB.AutoMigrate(&CurrentTimestampTest{}) + if err != nil { + t.Fatalf("AutoMigrate err:%v", err) + } + AssertEqual(t, true, DB.Migrator().HasIndex(&CurrentTimestampTest{}, "time_at")) + AssertEqual(t, false, DB.Migrator().HasIndex(&CurrentTimestampTest{}, "time_at_2")) +} + func TestUniqueColumn(t *testing.T) { if DB.Dialector.Name() != "mysql" { return From 775fa70af5a727f15ded94761fce5a1076603ca6 Mon Sep 17 00:00:00 2001 From: Defoo Li Date: Sat, 24 Dec 2022 12:14:23 +0800 Subject: [PATCH 095/231] DryRun for migrator (#5689) * DryRun for migrator * Update migrator.go * Update migrator.go Co-authored-by: Jinzhu --- migrator/migrator.go | 41 +++++++++++++++++++++++++++++------------ 1 file changed, 29 insertions(+), 12 deletions(-) diff --git a/migrator/migrator.go b/migrator/migrator.go index b113b398..eafe7bb2 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -8,9 +8,11 @@ import ( "reflect" "regexp" "strings" + "time" "gorm.io/gorm" "gorm.io/gorm/clause" + "gorm.io/gorm/logger" "gorm.io/gorm/schema" ) @@ -30,6 +32,16 @@ type Config struct { gorm.Dialector } +type printSQLLogger struct { + logger.Interface +} + +func (l *printSQLLogger) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) { + sql, _ := fc() + fmt.Println(sql + ";") + l.Interface.Trace(ctx, begin, fc, err) +} + // GormDataTypeInterface gorm data type interface type GormDataTypeInterface interface { GormDBDataType(*gorm.DB, *schema.Field) string @@ -92,14 +104,19 @@ func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) { // AutoMigrate auto migrate values func (m Migrator) AutoMigrate(values ...interface{}) error { for _, value := range m.ReorderModels(values, true) { - tx := m.DB.Session(&gorm.Session{}) - if !tx.Migrator().HasTable(value) { - if err := tx.Migrator().CreateTable(value); err != nil { + queryTx := m.DB.Session(&gorm.Session{}) + execTx := queryTx + if m.DB.DryRun { + queryTx.DryRun = false + execTx = m.DB.Session(&gorm.Session{Logger: &printSQLLogger{Interface: m.DB.Logger}}) + } + if !queryTx.Migrator().HasTable(value) { + if err := execTx.Migrator().CreateTable(value); err != nil { return err } } else { if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) { - columnTypes, err := m.DB.Migrator().ColumnTypes(value) + columnTypes, err := queryTx.Migrator().ColumnTypes(value) if err != nil { return err } @@ -117,10 +134,10 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { if foundColumn == nil { // not found, add column - if err := tx.Migrator().AddColumn(value, dbName); err != nil { + if err := execTx.Migrator().AddColumn(value, dbName); err != nil { return err } - } else if err := m.DB.Migrator().MigrateColumn(value, field, foundColumn); err != nil { + } else if err := execTx.Migrator().MigrateColumn(value, field, foundColumn); err != nil { // found, smart migrate return err } @@ -129,8 +146,8 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { for _, rel := range stmt.Schema.Relationships.Relations { if !m.DB.Config.DisableForeignKeyConstraintWhenMigrating { if constraint := rel.ParseConstraint(); constraint != nil && - constraint.Schema == stmt.Schema && !tx.Migrator().HasConstraint(value, constraint.Name) { - if err := tx.Migrator().CreateConstraint(value, constraint.Name); err != nil { + constraint.Schema == stmt.Schema && !queryTx.Migrator().HasConstraint(value, constraint.Name) { + if err := execTx.Migrator().CreateConstraint(value, constraint.Name); err != nil { return err } } @@ -138,16 +155,16 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { } for _, chk := range stmt.Schema.ParseCheckConstraints() { - if !tx.Migrator().HasConstraint(value, chk.Name) { - if err := tx.Migrator().CreateConstraint(value, chk.Name); err != nil { + if !queryTx.Migrator().HasConstraint(value, chk.Name) { + if err := execTx.Migrator().CreateConstraint(value, chk.Name); err != nil { return err } } } for _, idx := range stmt.Schema.ParseIndexes() { - if !tx.Migrator().HasIndex(value, idx.Name) { - if err := tx.Migrator().CreateIndex(value, idx.Name); err != nil { + if !queryTx.Migrator().HasIndex(value, idx.Name) { + if err := execTx.Migrator().CreateIndex(value, idx.Name); err != nil { return err } } From 1935eb0adbd1a05c8eee127fd410b1e5477e1931 Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Sat, 24 Dec 2022 12:27:38 +0800 Subject: [PATCH 096/231] feat: support inner join (#5583) * feat: support inner join * test: mixed inner join and left join * chore: code comment * Update statement.go Co-authored-by: Jinzhu --- callbacks/query.go | 2 +- chainable_api.go | 12 +++++++++++- statement.go | 11 ++++++----- tests/joins_test.go | 22 ++++++++++++++++++++++ 4 files changed, 40 insertions(+), 7 deletions(-) diff --git a/callbacks/query.go b/callbacks/query.go index 67936766..97fe8a49 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -185,7 +185,7 @@ func BuildQuerySQL(db *gorm.DB) { } fromClause.Joins = append(fromClause.Joins, clause.Join{ - Type: clause.LeftJoin, + Type: join.JoinType, Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName}, ON: clause.Where{Exprs: exprs}, }) diff --git a/chainable_api.go b/chainable_api.go index 68ec7a67..8a92a9e3 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -235,6 +235,16 @@ func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) { // db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user) // db.Joins("Account", DB.Select("id").Where("user_id = users.id AND name = ?", "someName").Model(&Account{})) func (db *DB) Joins(query string, args ...interface{}) (tx *DB) { + return joins(db, clause.LeftJoin, query, args...) +} + +// InnerJoins specify inner joins conditions +// db.InnerJoins("Account").Find(&user) +func (db *DB) InnerJoins(query string, args ...interface{}) (tx *DB) { + return joins(db, clause.InnerJoin, query, args...) +} + +func joins(db *DB, joinType clause.JoinType, query string, args ...interface{}) (tx *DB) { tx = db.getInstance() if len(args) == 1 { @@ -248,7 +258,7 @@ func (db *DB) Joins(query string, args ...interface{}) (tx *DB) { } } - tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args}) + tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args, JoinType: joinType}) return } diff --git a/statement.go b/statement.go index d4d20cbf..9f49d584 100644 --- a/statement.go +++ b/statement.go @@ -49,11 +49,12 @@ type Statement struct { } type join struct { - Name string - Conds []interface{} - On *clause.Where - Selects []string - Omits []string + Name string + Conds []interface{} + On *clause.Where + Selects []string + Omits []string + JoinType clause.JoinType } // StatementModifier statement modifier interface diff --git a/tests/joins_test.go b/tests/joins_test.go index 091fb986..057ad333 100644 --- a/tests/joins_test.go +++ b/tests/joins_test.go @@ -230,6 +230,28 @@ func TestJoinWithSoftDeleted(t *testing.T) { } } +func TestInnerJoins(t *testing.T) { + user := *GetUser("inner-joins-1", Config{Company: true, Manager: true, Account: true, NamedPet: false}) + + DB.Create(&user) + + var user2 User + var err error + err = DB.InnerJoins("Company").InnerJoins("Manager").InnerJoins("Account").First(&user2, "users.name = ?", user.Name).Error + AssertEqual(t, err, nil) + CheckUser(t, user2, user) + + // inner join and NamedPet is nil + err = DB.InnerJoins("NamedPet").InnerJoins("Company").InnerJoins("Manager").InnerJoins("Account").First(&user2, "users.name = ?", user.Name).Error + AssertEqual(t, err, gorm.ErrRecordNotFound) + + // mixed inner join and left join + var user3 User + err = DB.Joins("NamedPet").InnerJoins("Company").InnerJoins("Manager").InnerJoins("Account").First(&user3, "users.name = ?", user.Name).Error + AssertEqual(t, err, nil) + CheckUser(t, user3, user) +} + func TestJoinWithSameColumnName(t *testing.T) { user := GetUser("TestJoinWithSameColumnName", Config{ Languages: 1, From 794edad60e14692e6716217f73cc989e45b35115 Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Sat, 24 Dec 2022 17:42:16 +0800 Subject: [PATCH 097/231] test(MigrateColumn): mock alter column to improve field compare (#5499) * test(MigrateColumn): mock alter column to improve field compare * Update migrate_test.go * Update migrate_test.go * Update migrate_test.go Co-authored-by: Jinzhu --- tests/migrate_test.go | 47 +++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 45 insertions(+), 2 deletions(-) diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 5f7e0749..9df626fd 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -11,7 +11,6 @@ import ( "time" "gorm.io/driver/postgres" - "gorm.io/gorm" "gorm.io/gorm/logger" "gorm.io/gorm/schema" @@ -29,7 +28,7 @@ func TestMigrate(t *testing.T) { } if err := DB.AutoMigrate(allModels...); err != nil { - t.Fatalf("Failed to auto migrate, but got error %v", err) + t.Fatalf("Failed to auto migrate, got error %v", err) } if tables, err := DB.Migrator().GetTables(); err != nil { @@ -1123,6 +1122,50 @@ func TestMigrateArrayTypeModel(t *testing.T) { AssertEqual(t, "integer[]", ct.DatabaseTypeName()) } +type mockMigrator struct { + gorm.Migrator +} + +func (mm mockMigrator) AlterColumn(dst interface{}, field string) error { + err := mm.Migrator.AlterColumn(dst, field) + if err != nil { + return err + } + return fmt.Errorf("trigger alter column error, field: %s", field) +} + +func TestMigrateDonotAlterColumn(t *testing.T) { + var wrapMockMigrator = func(m gorm.Migrator) mockMigrator { + return mockMigrator{ + Migrator: m, + } + } + m := DB.Migrator() + mockM := wrapMockMigrator(m) + + type NotTriggerUpdate struct { + ID uint + F1 uint16 + F2 uint32 + F3 int + F4 int64 + F5 string + F6 float32 + F7 float64 + F8 time.Time + F9 bool + F10 []byte + } + + var err error + err = mockM.DropTable(&NotTriggerUpdate{}) + AssertEqual(t, err, nil) + err = mockM.AutoMigrate(&NotTriggerUpdate{}) + AssertEqual(t, err, nil) + err = mockM.AutoMigrate(&NotTriggerUpdate{}) + AssertEqual(t, err, nil) +} + func TestMigrateSameEmbeddedFieldName(t *testing.T) { type UserStat struct { GroundDestroyCount int From ddd3cc2502eb0a0193e10ec6360d5e83d19493a8 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 25 Dec 2022 11:37:23 +0800 Subject: [PATCH 098/231] Add ParameterizedQueries option support for logger, close #5288 --- callbacks.go | 6 +++++- gorm.go | 12 ++++++------ interfaces.go | 4 ++++ logger/logger.go | 10 ++++++++++ tests/go.mod | 5 ++++- 5 files changed, 29 insertions(+), 8 deletions(-) diff --git a/callbacks.go b/callbacks.go index c060ea70..ebebf79d 100644 --- a/callbacks.go +++ b/callbacks.go @@ -132,7 +132,11 @@ func (p *processor) Execute(db *DB) *DB { if stmt.SQL.Len() > 0 { db.Logger.Trace(stmt.Context, curTime, func() (string, int64) { - return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...), db.RowsAffected + sql, vars := stmt.SQL.String(), stmt.Vars + if filter, ok := db.Logger.(ParamsFilter); ok { + sql, vars = filter.ParamsFilter(stmt.Context, stmt.SQL.String(), stmt.Vars...) + } + return db.Dialector.Explain(sql, vars...), db.RowsAffected }, db.Error) } diff --git a/gorm.go b/gorm.go index 89488b75..65c9e228 100644 --- a/gorm.go +++ b/gorm.go @@ -464,12 +464,12 @@ func (db *DB) Use(plugin Plugin) error { // ToSQL for generate SQL string. // -// db.ToSQL(func(tx *gorm.DB) *gorm.DB { -// return tx.Model(&User{}).Where(&User{Name: "foo", Age: 20}) -// .Limit(10).Offset(5) -// .Order("name ASC") -// .First(&User{}) -// }) +// db.ToSQL(func(tx *gorm.DB) *gorm.DB { +// return tx.Model(&User{}).Where(&User{Name: "foo", Age: 20}) +// .Limit(10).Offset(5) +// .Order("name ASC") +// .First(&User{}) +// }) func (db *DB) ToSQL(queryFn func(tx *DB) *DB) string { tx := queryFn(db.Session(&Session{DryRun: true, SkipDefaultTransaction: true})) stmt := tx.Statement diff --git a/interfaces.go b/interfaces.go index 32d49605..cf9e07b9 100644 --- a/interfaces.go +++ b/interfaces.go @@ -26,6 +26,10 @@ type Plugin interface { Initialize(*DB) error } +type ParamsFilter interface { + ParamsFilter(ctx context.Context, sql string, params ...interface{}) (string, []interface{}) +} + // ConnPool db conns pool interface type ConnPool interface { PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) diff --git a/logger/logger.go b/logger/logger.go index ce088561..29027205 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -55,6 +55,7 @@ type Config struct { SlowThreshold time.Duration Colorful bool IgnoreRecordNotFoundError bool + ParameterizedQueries bool LogLevel LogLevel } @@ -75,6 +76,7 @@ var ( SlowThreshold: 200 * time.Millisecond, LogLevel: Warn, IgnoreRecordNotFoundError: false, + ParameterizedQueries: true, Colorful: true, }) // Recorder Recorder logger records running SQL into a recorder instance @@ -181,6 +183,14 @@ func (l logger) Trace(ctx context.Context, begin time.Time, fc func() (string, i } } +// Trace print sql message +func (l logger) ParamsFilter(ctx context.Context, sql string, params ...interface{}) (string, []interface{}) { + if l.Config.ParameterizedQueries { + return sql, nil + } + return sql, params +} + type traceRecorder struct { Interface BeginAt time.Time diff --git a/tests/go.mod b/tests/go.mod index 23fc2cad..3929b334 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -3,11 +3,14 @@ module gorm.io/gorm/tests go 1.16 require ( + github.com/go-sql-driver/mysql v1.7.0 // indirect github.com/google/uuid v1.3.0 + github.com/jackc/pgtype v1.13.0 // indirect github.com/jinzhu/now v1.1.5 github.com/lib/pq v1.10.7 github.com/mattn/go-sqlite3 v1.14.16 // indirect - golang.org/x/crypto v0.3.0 // indirect + github.com/microsoft/go-mssqldb v0.19.0 // indirect + golang.org/x/crypto v0.4.0 // indirect gorm.io/driver/mysql v1.4.4 gorm.io/driver/postgres v1.4.5 gorm.io/driver/sqlite v1.4.3 From 7da24d1d52be944fe5058792f8bdcf9572b48a1f Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 27 Dec 2022 08:47:17 +0800 Subject: [PATCH 099/231] chore(deps): bump actions/stale from 6 to 7 (#5945) Bumps [actions/stale](https://github.com/actions/stale) from 6 to 7. - [Release notes](https://github.com/actions/stale/releases) - [Changelog](https://github.com/actions/stale/blob/main/CHANGELOG.md) - [Commits](https://github.com/actions/stale/compare/v6...v7) --- updated-dependencies: - dependency-name: actions/stale dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/invalid_question.yml | 2 +- .github/workflows/missing_playground.yml | 2 +- .github/workflows/stale.yml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/invalid_question.yml b/.github/workflows/invalid_question.yml index bc4487ae..77b26abe 100644 --- a/.github/workflows/invalid_question.yml +++ b/.github/workflows/invalid_question.yml @@ -16,7 +16,7 @@ jobs: ACTIONS_STEP_DEBUG: true steps: - name: Close Stale Issues - uses: actions/stale@v6 + uses: actions/stale@v7 with: repo-token: ${{ secrets.GITHUB_TOKEN }} stale-issue-message: "This issue has been marked as invalid question, please give more information by following the `Question` template, if you believe there is a bug of GORM, please create a pull request that could reproduce the issue on [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground), the issue will be closed in 30 days if no further activity occurs. most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨" diff --git a/.github/workflows/missing_playground.yml b/.github/workflows/missing_playground.yml index f9f51aa0..1efa3611 100644 --- a/.github/workflows/missing_playground.yml +++ b/.github/workflows/missing_playground.yml @@ -16,7 +16,7 @@ jobs: ACTIONS_STEP_DEBUG: true steps: - name: Close Stale Issues - uses: actions/stale@v6 + uses: actions/stale@v7 with: repo-token: ${{ secrets.GITHUB_TOKEN }} stale-issue-message: "The issue has been automatically marked as stale as it missing playground pull request link, which is important to help others understand your issue effectively and make sure the issue hasn't been fixed on latest master, checkout [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground) for details. it will be closed in 30 days if no further activity occurs. if you are asking question, please use the `Question` template, most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨" diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index a9aff43a..43f2f730 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -16,7 +16,7 @@ jobs: ACTIONS_STEP_DEBUG: true steps: - name: Close Stale Issues - uses: actions/stale@v6 + uses: actions/stale@v7 with: repo-token: ${{ secrets.GITHUB_TOKEN }} stale-issue-message: "This issue has been automatically marked as stale because it has been open 360 days with no activity. Remove stale label or comment or this will be closed in 180 days" From da2b2861de47900edc6d0b1898bbdd5d5381b412 Mon Sep 17 00:00:00 2001 From: Haibo Date: Sun, 1 Jan 2023 19:54:28 +0800 Subject: [PATCH 100/231] fix(migrator): ignore relationships when migrating #5913 (#5946) --- gorm.go | 2 ++ migrator/migrator.go | 55 ++++++++++++++++++++++--------------- tests/migrate_test.go | 64 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 99 insertions(+), 22 deletions(-) diff --git a/gorm.go b/gorm.go index 65c9e228..37595ddd 100644 --- a/gorm.go +++ b/gorm.go @@ -37,6 +37,8 @@ type Config struct { DisableAutomaticPing bool // DisableForeignKeyConstraintWhenMigrating DisableForeignKeyConstraintWhenMigrating bool + // IgnoreRelationshipsWhenMigrating + IgnoreRelationshipsWhenMigrating bool // DisableNestedTransaction disable nested transaction DisableNestedTransaction bool // AllowGlobalUpdate allow global update diff --git a/migrator/migrator.go b/migrator/migrator.go index eafe7bb2..ebd9bc12 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -143,8 +143,11 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { } } - for _, rel := range stmt.Schema.Relationships.Relations { - if !m.DB.Config.DisableForeignKeyConstraintWhenMigrating { + if !m.DB.DisableForeignKeyConstraintWhenMigrating && !m.DB.IgnoreRelationshipsWhenMigrating { + for _, rel := range stmt.Schema.Relationships.Relations { + if rel.Field.IgnoreMigration { + continue + } if constraint := rel.ParseConstraint(); constraint != nil && constraint.Schema == stmt.Schema && !queryTx.Migrator().HasConstraint(value, constraint.Name) { if err := execTx.Migrator().CreateConstraint(value, constraint.Name); err != nil { @@ -244,8 +247,11 @@ func (m Migrator) CreateTable(values ...interface{}) error { } } - for _, rel := range stmt.Schema.Relationships.Relations { - if !m.DB.DisableForeignKeyConstraintWhenMigrating { + if !m.DB.DisableForeignKeyConstraintWhenMigrating && !m.DB.IgnoreRelationshipsWhenMigrating { + for _, rel := range stmt.Schema.Relationships.Relations { + if rel.Field.IgnoreMigration { + continue + } if constraint := rel.ParseConstraint(); constraint != nil { if constraint.Schema == stmt.Schema { sql, vars := buildConstraint(constraint) @@ -818,26 +824,31 @@ func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []i } parsedSchemas[dep.Statement.Schema] = true - for _, rel := range dep.Schema.Relationships.Relations { - if c := rel.ParseConstraint(); c != nil && c.Schema == dep.Statement.Schema && c.Schema != c.ReferenceSchema { - dep.Depends = append(dep.Depends, c.ReferenceSchema) - } + if !m.DB.IgnoreRelationshipsWhenMigrating { + for _, rel := range dep.Schema.Relationships.Relations { + if rel.Field.IgnoreMigration { + continue + } + if c := rel.ParseConstraint(); c != nil && c.Schema == dep.Statement.Schema && c.Schema != c.ReferenceSchema { + dep.Depends = append(dep.Depends, c.ReferenceSchema) + } - if rel.Type == schema.HasOne || rel.Type == schema.HasMany { - beDependedOn[rel.FieldSchema] = true - } + if rel.Type == schema.HasOne || rel.Type == schema.HasMany { + beDependedOn[rel.FieldSchema] = true + } - if rel.JoinTable != nil { - // append join value - defer func(rel *schema.Relationship, joinValue interface{}) { - if !beDependedOn[rel.FieldSchema] { - dep.Depends = append(dep.Depends, rel.FieldSchema) - } else { - fieldValue := reflect.New(rel.FieldSchema.ModelType).Interface() - parseDependence(fieldValue, autoAdd) - } - parseDependence(joinValue, autoAdd) - }(rel, reflect.New(rel.JoinTable.ModelType).Interface()) + if rel.JoinTable != nil { + // append join value + defer func(rel *schema.Relationship, joinValue interface{}) { + if !beDependedOn[rel.FieldSchema] { + dep.Depends = append(dep.Depends, rel.FieldSchema) + } else { + fieldValue := reflect.New(rel.FieldSchema.ModelType).Interface() + parseDependence(fieldValue, autoAdd) + } + parseDependence(joinValue, autoAdd) + }(rel, reflect.New(rel.JoinTable.ModelType).Interface()) + } } } diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 9df626fd..d5d129a8 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -1203,3 +1203,67 @@ func TestMigrateSameEmbeddedFieldName(t *testing.T) { _, err = findColumnType(&GameUser{}, "rate_ground_rb_ground_destory_count") AssertEqual(t, nil, err) } + +func TestMigrateIgnoreRelations(t *testing.T) { + type RelationModel1 struct { + ID uint + } + type RelationModel2 struct { + ID uint + } + type RelationModel3 struct { + ID uint + RelationModel1ID uint + RelationModel1 *RelationModel1 + RelationModel2ID uint + RelationModel2 *RelationModel2 `gorm:"-:migration"` + } + + var err error + _ = DB.Migrator().DropTable(&RelationModel1{}, &RelationModel2{}, &RelationModel3{}) + + tx := DB.Session(&gorm.Session{}) + tx.IgnoreRelationshipsWhenMigrating = true + + err = tx.AutoMigrate(&RelationModel3{}) + if err != nil { + t.Errorf("AutoMigrate err:%v", err) + } + + // RelationModel3 should be existed + _, err = findColumnType(&RelationModel3{}, "id") + AssertEqual(t, nil, err) + + // RelationModel1 should not be existed + _, err = findColumnType(&RelationModel1{}, "id") + if err == nil { + t.Errorf("RelationModel1 should not be migrated") + } + + // RelationModel2 should not be existed + _, err = findColumnType(&RelationModel2{}, "id") + if err == nil { + t.Errorf("RelationModel2 should not be migrated") + } + + tx.IgnoreRelationshipsWhenMigrating = false + + err = tx.AutoMigrate(&RelationModel3{}) + if err != nil { + t.Errorf("AutoMigrate err:%v", err) + } + + // RelationModel3 should be existed + _, err = findColumnType(&RelationModel3{}, "id") + AssertEqual(t, nil, err) + + // RelationModel1 should be existed + _, err = findColumnType(&RelationModel1{}, "id") + AssertEqual(t, nil, err) + + // RelationModel2 should not be existed + _, err = findColumnType(&RelationModel2{}, "id") + if err == nil { + t.Errorf("RelationModel2 should not be migrated") + } +} From 16a272209adc54ef6623824dccde90b9f843a4d0 Mon Sep 17 00:00:00 2001 From: Haibo Date: Sun, 1 Jan 2023 22:14:28 +0800 Subject: [PATCH 101/231] fix(migrator): Tag default:'null' always causes field migration #5953 (#5954) * fix(migrator): Tag default:'null' always causes field migration #5953 * Update migrate_test.go * Update migrate_test.go * Update migrate_test.go Co-authored-by: Jinzhu --- migrator/migrator.go | 2 +- tests/migrate_test.go | 66 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 67 insertions(+), 1 deletion(-) diff --git a/migrator/migrator.go b/migrator/migrator.go index ebd9bc12..90fbb461 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -493,7 +493,7 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy // check default value if !field.PrimaryKey { - currentDefaultNotNull := field.HasDefaultValue && !strings.EqualFold(field.DefaultValue, "NULL") + currentDefaultNotNull := field.HasDefaultValue && (field.DefaultValueInterface != nil || !strings.EqualFold(field.DefaultValue, "NULL")) dv, dvNotNull := columnType.DefaultValue() if dvNotNull && !currentDefaultNotNull { // defalut value -> null diff --git a/tests/migrate_test.go b/tests/migrate_test.go index d5d129a8..7560faca 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -1204,6 +1204,72 @@ func TestMigrateSameEmbeddedFieldName(t *testing.T) { AssertEqual(t, nil, err) } +func TestMigrateDefaultNullString(t *testing.T) { + if DB.Dialector.Name() == "sqlserver" { + // sqlserver driver treats NULL and 'NULL' the same + t.Skip("skip sqlserver") + } + + type NullModel struct { + ID uint + Content string `gorm:"default:null"` + } + + type NullStringModel struct { + ID uint + Content string `gorm:"default:'null'"` + } + + tableName := "null_string_model" + + DB.Migrator().DropTable(tableName) + + err := DB.Table(tableName).AutoMigrate(&NullModel{}) + AssertEqual(t, err, nil) + + // default null -> 'null' + err = DB.Table(tableName).AutoMigrate(&NullStringModel{}) + AssertEqual(t, err, nil) + + columnType, err := findColumnType(tableName, "content") + AssertEqual(t, err, nil) + + defVal, ok := columnType.DefaultValue() + AssertEqual(t, defVal, "null") + AssertEqual(t, ok, true) + + // default 'null' -> 'null' + session := DB.Session(&gorm.Session{Logger: Tracer{ + Logger: DB.Config.Logger, + Test: func(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) { + sql, _ := fc() + if strings.HasPrefix(sql, "ALTER TABLE") { + t.Errorf("shouldn't execute: sql=%s", sql) + } + }, + }}) + err = session.Table(tableName).AutoMigrate(&NullStringModel{}) + AssertEqual(t, err, nil) + + columnType, err = findColumnType(tableName, "content") + AssertEqual(t, err, nil) + + defVal, ok = columnType.DefaultValue() + AssertEqual(t, defVal, "null") + AssertEqual(t, ok, true) + + // default 'null' -> null + err = DB.Table(tableName).AutoMigrate(&NullModel{}) + AssertEqual(t, err, nil) + + columnType, err = findColumnType(tableName, "content") + AssertEqual(t, err, nil) + + defVal, ok = columnType.DefaultValue() + AssertEqual(t, defVal, "") + AssertEqual(t, ok, false) +} + func TestMigrateIgnoreRelations(t *testing.T) { type RelationModel1 struct { ID uint From 4b768c8aff4335eec41b0b393a7978bda1e6194d Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 1 Jan 2023 22:22:08 +0800 Subject: [PATCH 102/231] Upgrade tests deps --- logger/logger.go | 1 - tests/go.mod | 6 ++---- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/logger/logger.go b/logger/logger.go index 29027205..aa0060bc 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -76,7 +76,6 @@ var ( SlowThreshold: 200 * time.Millisecond, LogLevel: Warn, IgnoreRecordNotFoundError: false, - ParameterizedQueries: true, Colorful: true, }) // Recorder Recorder logger records running SQL into a recorder instance diff --git a/tests/go.mod b/tests/go.mod index 3929b334..6ad6dd06 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -5,15 +5,13 @@ go 1.16 require ( github.com/go-sql-driver/mysql v1.7.0 // indirect github.com/google/uuid v1.3.0 - github.com/jackc/pgtype v1.13.0 // indirect github.com/jinzhu/now v1.1.5 github.com/lib/pq v1.10.7 github.com/mattn/go-sqlite3 v1.14.16 // indirect github.com/microsoft/go-mssqldb v0.19.0 // indirect - golang.org/x/crypto v0.4.0 // indirect gorm.io/driver/mysql v1.4.4 - gorm.io/driver/postgres v1.4.5 - gorm.io/driver/sqlite v1.4.3 + gorm.io/driver/postgres v1.4.6 + gorm.io/driver/sqlite v1.4.4 gorm.io/driver/sqlserver v1.4.1 gorm.io/gorm v1.24.2 ) From b0e13d95b486299d62f0e04d62cea154fb9ec051 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 1 Jan 2023 22:27:49 +0800 Subject: [PATCH 103/231] update github tests action --- .github/workflows/tests.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 367f4ccd..5e9a1e63 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -16,7 +16,7 @@ jobs: sqlite: strategy: matrix: - go: ['1.19', '1.18', '1.17', '1.16'] + go: ['1.19', '1.18'] platform: [ubuntu-latest] # can not run in windows OS runs-on: ${{ matrix.platform }} @@ -42,7 +42,7 @@ jobs: strategy: matrix: dbversion: ['mysql:latest', 'mysql:5.7', 'mariadb:latest'] - go: ['1.19', '1.18', '1.17', '1.16'] + go: ['1.19', '1.18'] platform: [ubuntu-latest] runs-on: ${{ matrix.platform }} @@ -86,7 +86,7 @@ jobs: strategy: matrix: dbversion: ['postgres:latest', 'postgres:13', 'postgres:12', 'postgres:11', 'postgres:10'] - go: ['1.19', '1.18', '1.17', '1.16'] + go: ['1.19', '1.18'] platform: [ubuntu-latest] # can not run in macOS and Windows runs-on: ${{ matrix.platform }} @@ -128,7 +128,7 @@ jobs: sqlserver: strategy: matrix: - go: ['1.19', '1.18', '1.17', '1.16'] + go: ['1.19', '1.18'] platform: [ubuntu-latest] # can not run test in macOS and windows runs-on: ${{ matrix.platform }} From 3d91802b1d1bd5ad175ac43fac062fd9f8de98be Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 2 Jan 2023 20:52:44 +0800 Subject: [PATCH 104/231] Fix unexpected alter table in auto migration, close #5942, #5943 --- migrator/migrator.go | 12 ++++++++---- schema/index.go | 1 + tests/go.mod | 2 +- tests/migrate_test.go | 32 ++++++++++++++++++++++++++++++++ 4 files changed, 42 insertions(+), 5 deletions(-) diff --git a/migrator/migrator.go b/migrator/migrator.go index 90fbb461..b8aaef2b 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -120,7 +120,10 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { if err != nil { return err } - + var ( + parseIndexes = stmt.Schema.ParseIndexes() + parseCheckConstraints = stmt.Schema.ParseCheckConstraints() + ) for _, dbName := range stmt.Schema.DBNames { field := stmt.Schema.FieldsByDBName[dbName] var foundColumn gorm.ColumnType @@ -157,7 +160,7 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { } } - for _, chk := range stmt.Schema.ParseCheckConstraints() { + for _, chk := range parseCheckConstraints { if !queryTx.Migrator().HasConstraint(value, chk.Name) { if err := execTx.Migrator().CreateConstraint(value, chk.Name); err != nil { return err @@ -165,7 +168,7 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { } } - for _, idx := range stmt.Schema.ParseIndexes() { + for _, idx := range parseIndexes { if !queryTx.Migrator().HasIndex(value, idx.Name) { if err := execTx.Migrator().CreateIndex(value, idx.Name); err != nil { return err @@ -430,7 +433,8 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy realDataType := strings.ToLower(columnType.DatabaseTypeName()) var ( - alterColumn, isSameType bool + alterColumn bool + isSameType = fullDataType == realDataType ) if !field.PrimaryKey { diff --git a/schema/index.go b/schema/index.go index 5003c742..c29623ad 100644 --- a/schema/index.go +++ b/schema/index.go @@ -129,6 +129,7 @@ func parseFieldIndexes(field *Field) (indexes []Index, err error) { } if (k == "UNIQUEINDEX") || settings["UNIQUE"] != "" { + field.Unique = true settings["CLASS"] = "UNIQUE" } diff --git a/tests/go.mod b/tests/go.mod index 6ad6dd06..efa597a2 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -13,7 +13,7 @@ require ( gorm.io/driver/postgres v1.4.6 gorm.io/driver/sqlite v1.4.4 gorm.io/driver/sqlserver v1.4.1 - gorm.io/gorm v1.24.2 + gorm.io/gorm v1.24.3 ) replace gorm.io/gorm => ../ diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 7560faca..fcd0b5bd 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -1270,6 +1270,38 @@ func TestMigrateDefaultNullString(t *testing.T) { AssertEqual(t, ok, false) } +func TestMigrateMySQLWithCustomizedTypes(t *testing.T) { + if DB.Dialector.Name() != "mysql" { + t.Skip() + } + + type MyTable struct { + Def string `gorm:"size:512;index:idx_def,unique"` + Abc string `gorm:"size:65000000"` + } + + DB.Migrator().DropTable("my_tables") + + sql := "CREATE TABLE `my_tables` (`def` varchar(512),`abc` longtext,UNIQUE INDEX `idx_def` (`def`))" + if err := DB.Exec(sql).Error; err != nil { + t.Errorf("Failed, got error: %v", err) + } + + session := DB.Session(&gorm.Session{Logger: Tracer{ + Logger: DB.Config.Logger, + Test: func(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) { + sql, _ := fc() + if strings.HasPrefix(sql, "ALTER TABLE") { + t.Errorf("shouldn't execute: sql=%s", sql) + } + }, + }}) + + if err := session.AutoMigrate(&MyTable{}); err != nil { + t.Errorf("Failed, got error: %v", err) + } +} + func TestMigrateIgnoreRelations(t *testing.T) { type RelationModel1 struct { ID uint From 2bc913787b6d194aa4f72c8e4ddc64d62602ef21 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 2 Jan 2023 21:46:27 +0800 Subject: [PATCH 105/231] support implicit table alias, close #5840 #5940 --- chainable_api.go | 10 +++++++--- tests/go.mod | 3 +-- tests/soft_delete_test.go | 5 +++++ 3 files changed, 13 insertions(+), 5 deletions(-) diff --git a/chainable_api.go b/chainable_api.go index 8a92a9e3..676fe914 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -55,7 +55,7 @@ func (db *DB) Clauses(conds ...clause.Expression) (tx *DB) { return } -var tableRegexp = regexp.MustCompile(`(?i).+? AS (\w+)\s*(?:$|,)`) +var tableRegexp = regexp.MustCompile(`(?i)(?:.+? AS (\w+)\s*(?:$|,)|^\w+\s+(\w+)$)`) // Table specify the table you would like to run db operations // @@ -65,8 +65,12 @@ func (db *DB) Table(name string, args ...interface{}) (tx *DB) { tx = db.getInstance() if strings.Contains(name, " ") || strings.Contains(name, "`") || len(args) > 0 { tx.Statement.TableExpr = &clause.Expr{SQL: name, Vars: args} - if results := tableRegexp.FindStringSubmatch(name); len(results) == 2 { - tx.Statement.Table = results[1] + if results := tableRegexp.FindStringSubmatch(name); len(results) == 3 { + if results[1] != "" { + tx.Statement.Table = results[1] + } else { + tx.Statement.Table = results[2] + } } } else if tables := strings.Split(name, "."); len(tables) == 2 { tx.Statement.TableExpr = &clause.Expr{SQL: tx.Statement.Quote(name)} diff --git a/tests/go.mod b/tests/go.mod index efa597a2..2ba97179 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -3,13 +3,12 @@ module gorm.io/gorm/tests go 1.16 require ( - github.com/go-sql-driver/mysql v1.7.0 // indirect github.com/google/uuid v1.3.0 github.com/jinzhu/now v1.1.5 github.com/lib/pq v1.10.7 github.com/mattn/go-sqlite3 v1.14.16 // indirect github.com/microsoft/go-mssqldb v0.19.0 // indirect - gorm.io/driver/mysql v1.4.4 + gorm.io/driver/mysql v1.4.5 gorm.io/driver/postgres v1.4.6 gorm.io/driver/sqlite v1.4.4 gorm.io/driver/sqlserver v1.4.1 diff --git a/tests/soft_delete_test.go b/tests/soft_delete_test.go index 9ac8da10..1f9a4786 100644 --- a/tests/soft_delete_test.go +++ b/tests/soft_delete_test.go @@ -39,6 +39,11 @@ func TestSoftDelete(t *testing.T) { t.Fatalf("invalid sql generated, got %v", sql) } + sql = DB.Session(&gorm.Session{DryRun: true}).Table("user u").Select("name").Find(&User{}).Statement.SQL.String() + if !regexp.MustCompile(`SELECT .name. FROM user u WHERE .u.\..deleted_at. IS NULL`).MatchString(sql) { + t.Errorf("Table with escape character, got %v", sql) + } + if DB.First(&User{}, "name = ?", user.Name).Error == nil { t.Errorf("Can't find a soft deleted record") } From baf1afa1fcb45b69a7c64c3fb82da7a0dd32bcfc Mon Sep 17 00:00:00 2001 From: Haibo Date: Wed, 11 Jan 2023 14:05:39 +0800 Subject: [PATCH 106/231] fix(schema): field is only unique when there is one unique index (#5974) --- schema/index.go | 7 +++++-- schema/index_test.go | 11 +++++++---- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/schema/index.go b/schema/index.go index c29623ad..f5ac5dd2 100644 --- a/schema/index.go +++ b/schema/index.go @@ -65,7 +65,11 @@ func (schema *Schema) ParseIndexes() map[string]Index { } } } - + for _, index := range indexes { + if index.Class == "UNIQUE" && len(index.Fields) == 1 { + index.Fields[0].Field.Unique = true + } + } return indexes } @@ -129,7 +133,6 @@ func parseFieldIndexes(field *Field) (indexes []Index, err error) { } if (k == "UNIQUEINDEX") || settings["UNIQUE"] != "" { - field.Unique = true settings["CLASS"] = "UNIQUE" } diff --git a/schema/index_test.go b/schema/index_test.go index 1fe31cc1..890327de 100644 --- a/schema/index_test.go +++ b/schema/index_test.go @@ -65,7 +65,7 @@ func TestParseIndex(t *testing.T) { "idx_name": { Name: "idx_name", Class: "UNIQUE", - Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name2"}}}, + Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name2", Unique: true}}}, }, "idx_user_indices_name3": { Name: "idx_user_indices_name3", @@ -81,7 +81,7 @@ func TestParseIndex(t *testing.T) { "idx_user_indices_name4": { Name: "idx_user_indices_name4", Class: "UNIQUE", - Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name4"}}}, + Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name4", Unique: true}}}, }, "idx_user_indices_name5": { Name: "idx_user_indices_name5", @@ -102,12 +102,12 @@ func TestParseIndex(t *testing.T) { }, "idx_id": { Name: "idx_id", - Fields: []schema.IndexOption{{Field: &schema.Field{Name: "MemberNumber"}}, {Field: &schema.Field{Name: "OID"}}}, + Fields: []schema.IndexOption{{Field: &schema.Field{Name: "MemberNumber"}}, {Field: &schema.Field{Name: "OID", Unique: true}}}, }, "idx_oid": { Name: "idx_oid", Class: "UNIQUE", - Fields: []schema.IndexOption{{Field: &schema.Field{Name: "OID"}}}, + Fields: []schema.IndexOption{{Field: &schema.Field{Name: "OID", Unique: true}}}, }, "type": { Name: "type", @@ -168,6 +168,9 @@ func TestParseIndex(t *testing.T) { if rf.Field.Name != ef.Field.Name { t.Fatalf("index field should equal, expects %v, got %v", rf.Field.Name, ef.Field.Name) } + if rf.Field.Unique != ef.Field.Unique { + t.Fatalf("index field '%s' should equal, expects %v, got %v", rf.Field.Name, rf.Field.Unique, ef.Field.Unique) + } for _, name := range []string{"Expression", "Sort", "Collate", "Length"} { if reflect.ValueOf(ef).FieldByName(name).Interface() != reflect.ValueOf(rf).FieldByName(name).Interface() { From 3d35ddba55c5777bd4867a50daff1e626d8fdb4a Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 12 Jan 2023 16:52:17 +0800 Subject: [PATCH 107/231] Fix use table.* as select/omit columns --- README.md | 3 --- statement.go | 48 +++++++++++++++++++++--------------------------- tests/go.mod | 1 + 3 files changed, 22 insertions(+), 30 deletions(-) diff --git a/README.md b/README.md index 68fa6603..0c9ab74e 100644 --- a/README.md +++ b/README.md @@ -4,9 +4,6 @@ The fantastic ORM library for Golang, aims to be developer friendly. [![go report card](https://goreportcard.com/badge/github.com/go-gorm/gorm "go report card")](https://goreportcard.com/report/github.com/go-gorm/gorm) [![test status](https://github.com/go-gorm/gorm/workflows/tests/badge.svg?branch=master "test status")](https://github.com/go-gorm/gorm/actions) -[![Join the chat at https://gitter.im/jinzhu/gorm](https://img.shields.io/gitter/room/jinzhu/gorm.svg)](https://gitter.im/jinzhu/gorm?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) -[![Open Collective Backer](https://opencollective.com/gorm/tiers/backer/badge.svg?label=backer&color=brightgreen "Open Collective Backer")](https://opencollective.com/gorm) -[![Open Collective Sponsor](https://opencollective.com/gorm/tiers/sponsor/badge.svg?label=sponsor&color=brightgreen "Open Collective Sponsor")](https://opencollective.com/gorm) [![MIT license](https://img.shields.io/badge/license-MIT-brightgreen.svg)](https://opensource.org/licenses/MIT) [![Go.Dev reference](https://img.shields.io/badge/go.dev-reference-blue?logo=go&logoColor=white)](https://pkg.go.dev/gorm.io/gorm?tab=doc) diff --git a/statement.go b/statement.go index 9f49d584..b99648fa 100644 --- a/statement.go +++ b/statement.go @@ -665,47 +665,41 @@ func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) ( results := map[string]bool{} notRestricted := false - // select columns - for _, column := range stmt.Selects { + processColumn := func(column string, result bool) { if stmt.Schema == nil { - results[column] = true + results[column] = result } else if column == "*" { - notRestricted = true + notRestricted = result for _, dbName := range stmt.Schema.DBNames { - results[dbName] = true + results[dbName] = result } } else if column == clause.Associations { for _, rel := range stmt.Schema.Relationships.Relations { - results[rel.Name] = true + results[rel.Name] = result } } else if field := stmt.Schema.LookUpField(column); field != nil && field.DBName != "" { - results[field.DBName] = true + results[field.DBName] = result } else if matches := nameMatcher.FindStringSubmatch(column); len(matches) == 3 && (matches[1] == stmt.Table || matches[1] == "") { - results[matches[2]] = true + if matches[2] == "*" { + for _, dbName := range stmt.Schema.DBNames { + results[dbName] = result + } + } else { + results[matches[2]] = result + } } else { - results[column] = true + results[column] = result } } + // select columns + for _, column := range stmt.Selects { + processColumn(column, true) + } + // omit columns - for _, omit := range stmt.Omits { - if stmt.Schema == nil { - results[omit] = false - } else if omit == "*" { - for _, dbName := range stmt.Schema.DBNames { - results[dbName] = false - } - } else if omit == clause.Associations { - for _, rel := range stmt.Schema.Relationships.Relations { - results[rel.Name] = false - } - } else if field := stmt.Schema.LookUpField(omit); field != nil && field.DBName != "" { - results[field.DBName] = false - } else if matches := nameMatcher.FindStringSubmatch(omit); len(matches) == 2 { - results[matches[1]] = false - } else { - results[omit] = false - } + for _, column := range stmt.Omits { + processColumn(column, false) } if stmt.Schema != nil { diff --git a/tests/go.mod b/tests/go.mod index 2ba97179..acc0cf0e 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -8,6 +8,7 @@ require ( github.com/lib/pq v1.10.7 github.com/mattn/go-sqlite3 v1.14.16 // indirect github.com/microsoft/go-mssqldb v0.19.0 // indirect + golang.org/x/crypto v0.5.0 // indirect gorm.io/driver/mysql v1.4.5 gorm.io/driver/postgres v1.4.6 gorm.io/driver/sqlite v1.4.4 From d834dd60b715422dc2a900fb2744f9c278a9830f Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 19 Jan 2023 15:22:13 +0800 Subject: [PATCH 108/231] Remove unnecessary code --- schema/schema.go | 8 -------- tests/go.mod | 3 +-- 2 files changed, 1 insertion(+), 10 deletions(-) diff --git a/schema/schema.go b/schema/schema.go index 21e71c21..b34383bd 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -246,14 +246,6 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam field.HasDefaultValue = true field.AutoIncrement = true } - case String: - if _, ok := field.TagSettings["PRIMARYKEY"]; !ok { - if !field.HasDefaultValue || field.DefaultValueInterface != nil { - schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field) - } - - field.HasDefaultValue = true - } } } diff --git a/tests/go.mod b/tests/go.mod index acc0cf0e..251aabb3 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -7,12 +7,11 @@ require ( github.com/jinzhu/now v1.1.5 github.com/lib/pq v1.10.7 github.com/mattn/go-sqlite3 v1.14.16 // indirect - github.com/microsoft/go-mssqldb v0.19.0 // indirect golang.org/x/crypto v0.5.0 // indirect gorm.io/driver/mysql v1.4.5 gorm.io/driver/postgres v1.4.6 gorm.io/driver/sqlite v1.4.4 - gorm.io/driver/sqlserver v1.4.1 + gorm.io/driver/sqlserver v1.4.2 gorm.io/gorm v1.24.3 ) From cfbcedbf036931d134a030b5ccc2de7f48f1a7c3 Mon Sep 17 00:00:00 2001 From: qiankunli Date: Wed, 1 Feb 2023 14:40:55 +0800 Subject: [PATCH 109/231] fix: support zeroValue tag on DeletedAt (#6011) * fix: support zeroValue tag on DeletedAt Signed-off-by: qiankunli * Update soft_delete_test.go * Update tests_test.go * Update soft_delete.go --------- Signed-off-by: qiankunli Co-authored-by: Jinzhu --- soft_delete.go | 27 +++++++++++---- tests/soft_delete_test.go | 69 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 89 insertions(+), 7 deletions(-) diff --git a/soft_delete.go b/soft_delete.go index 6d646288..5673d3b8 100644 --- a/soft_delete.go +++ b/soft_delete.go @@ -6,6 +6,7 @@ import ( "encoding/json" "reflect" + "github.com/jinzhu/now" "gorm.io/gorm/clause" "gorm.io/gorm/schema" ) @@ -45,11 +46,21 @@ func (n *DeletedAt) UnmarshalJSON(b []byte) error { } func (DeletedAt) QueryClauses(f *schema.Field) []clause.Interface { - return []clause.Interface{SoftDeleteQueryClause{Field: f}} + return []clause.Interface{SoftDeleteQueryClause{Field: f, ZeroValue: parseZeroValueTag(f)}} +} + +func parseZeroValueTag(f *schema.Field) sql.NullString { + if v, ok := f.TagSettings["ZEROVALUE"]; ok { + if _, err := now.Parse(v); err == nil { + return sql.NullString{String: v, Valid: true} + } + } + return sql.NullString{Valid: false} } type SoftDeleteQueryClause struct { - Field *schema.Field + ZeroValue sql.NullString + Field *schema.Field } func (sd SoftDeleteQueryClause) Name() string { @@ -78,18 +89,19 @@ func (sd SoftDeleteQueryClause) ModifyStatement(stmt *Statement) { } stmt.AddClause(clause.Where{Exprs: []clause.Expression{ - clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: sd.Field.DBName}, Value: nil}, + clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: sd.Field.DBName}, Value: sd.ZeroValue}, }}) stmt.Clauses["soft_delete_enabled"] = clause.Clause{} } } func (DeletedAt) UpdateClauses(f *schema.Field) []clause.Interface { - return []clause.Interface{SoftDeleteUpdateClause{Field: f}} + return []clause.Interface{SoftDeleteUpdateClause{Field: f, ZeroValue: parseZeroValueTag(f)}} } type SoftDeleteUpdateClause struct { - Field *schema.Field + ZeroValue sql.NullString + Field *schema.Field } func (sd SoftDeleteUpdateClause) Name() string { @@ -109,11 +121,12 @@ func (sd SoftDeleteUpdateClause) ModifyStatement(stmt *Statement) { } func (DeletedAt) DeleteClauses(f *schema.Field) []clause.Interface { - return []clause.Interface{SoftDeleteDeleteClause{Field: f}} + return []clause.Interface{SoftDeleteDeleteClause{Field: f, ZeroValue: parseZeroValueTag(f)}} } type SoftDeleteDeleteClause struct { - Field *schema.Field + ZeroValue sql.NullString + Field *schema.Field } func (sd SoftDeleteDeleteClause) Name() string { diff --git a/tests/soft_delete_test.go b/tests/soft_delete_test.go index 1f9a4786..179ae426 100644 --- a/tests/soft_delete_test.go +++ b/tests/soft_delete_test.go @@ -7,6 +7,7 @@ import ( "regexp" "testing" + "github.com/jinzhu/now" "gorm.io/gorm" . "gorm.io/gorm/utils/tests" ) @@ -98,3 +99,71 @@ func TestDeletedAtOneOr(t *testing.T) { t.Fatalf("invalid sql generated, got %v", actualSQL) } } + +func TestSoftDeleteZeroValue(t *testing.T) { + type SoftDeleteBook struct { + ID uint + Name string + Pages uint + DeletedAt gorm.DeletedAt `gorm:"zeroValue:'1970-01-01 00:00:01'"` + } + DB.Migrator().DropTable(&SoftDeleteBook{}) + if err := DB.AutoMigrate(&SoftDeleteBook{}); err != nil { + t.Fatalf("failed to auto migrate soft delete table") + } + + book := SoftDeleteBook{Name: "jinzhu", Pages: 10} + DB.Save(&book) + + var count int64 + if DB.Model(&SoftDeleteBook{}).Where("name = ?", book.Name).Count(&count).Error != nil || count != 1 { + t.Errorf("Count soft deleted record, expects: %v, got: %v", 1, count) + } + + var pages uint + if DB.Model(&SoftDeleteBook{}).Select("pages").Where("name = ?", book.Name).Scan(&pages).Error != nil || pages != book.Pages { + t.Errorf("Pages soft deleted record, expects: %v, got: %v", 0, pages) + } + + if err := DB.Delete(&book).Error; err != nil { + t.Fatalf("No error should happen when soft delete user, but got %v", err) + } + + zeroTime, _ := now.Parse("1970-01-01 00:00:01") + if book.DeletedAt.Time.Equal(zeroTime) { + t.Errorf("book's deleted at should not be zero, DeletedAt: %v", book.DeletedAt) + } + + if DB.First(&SoftDeleteBook{}, "name = ?", book.Name).Error == nil { + t.Errorf("Can't find a soft deleted record") + } + + count = 0 + if DB.Model(&SoftDeleteBook{}).Where("name = ?", book.Name).Count(&count).Error != nil || count != 0 { + t.Errorf("Count soft deleted record, expects: %v, got: %v", 0, count) + } + + pages = 0 + if err := DB.Model(&SoftDeleteBook{}).Select("pages").Where("name = ?", book.Name).Scan(&pages).Error; err != nil || pages != 0 { + t.Fatalf("Age soft deleted record, expects: %v, got: %v, err %v", 0, pages, err) + } + + if err := DB.Unscoped().First(&SoftDeleteBook{}, "name = ?", book.Name).Error; err != nil { + t.Errorf("Should find soft deleted record with Unscoped, but got err %s", err) + } + + count = 0 + if DB.Unscoped().Model(&SoftDeleteBook{}).Where("name = ?", book.Name).Count(&count).Error != nil || count != 1 { + t.Errorf("Count soft deleted record, expects: %v, count: %v", 1, count) + } + + pages = 0 + if DB.Unscoped().Model(&SoftDeleteBook{}).Select("pages").Where("name = ?", book.Name).Scan(&pages).Error != nil || pages != book.Pages { + t.Errorf("Age soft deleted record, expects: %v, got: %v", 0, pages) + } + + DB.Unscoped().Delete(&book) + if err := DB.Unscoped().First(&SoftDeleteBook{}, "name = ?", book.Name).Error; !errors.Is(err, gorm.ErrRecordNotFound) { + t.Errorf("Can't find permanently deleted record") + } +} From 4d6b70ec88dbff3d4a5e43b284c7b5b624915844 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 2 Feb 2023 17:15:08 +0800 Subject: [PATCH 110/231] Allow modify statement from dest --- callbacks.go | 4 ++++ clause/clause.go | 1 + 2 files changed, 5 insertions(+) diff --git a/callbacks.go b/callbacks.go index ebebf79d..de979e45 100644 --- a/callbacks.go +++ b/callbacks.go @@ -93,6 +93,10 @@ func (p *processor) Execute(db *DB) *DB { resetBuildClauses = true } + if optimizer, ok := db.Statement.Dest.(StatementModifier); ok { + optimizer.ModifyStatement(stmt) + } + // assign model values if stmt.Model == nil { stmt.Model = stmt.Dest diff --git a/clause/clause.go b/clause/clause.go index de19f2e3..1354fc05 100644 --- a/clause/clause.go +++ b/clause/clause.go @@ -20,6 +20,7 @@ type Builder interface { Writer WriteQuoted(field interface{}) AddVar(Writer, ...interface{}) + AddError(error) error } // Clause From e1f46eb802e7a73c9cc04241c3077dbe9021cd51 Mon Sep 17 00:00:00 2001 From: chyroc Date: Thu, 2 Feb 2023 17:54:51 +0800 Subject: [PATCH 111/231] fix: ignore nil query (#6021) --- statement.go | 3 +++ statement_test.go | 7 +++++++ 2 files changed, 10 insertions(+) diff --git a/statement.go b/statement.go index b99648fa..08165293 100644 --- a/statement.go +++ b/statement.go @@ -311,6 +311,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] conds := make([]clause.Expression, 0, 4) args = append([]interface{}{query}, args...) for idx, arg := range args { + if arg == nil { + continue + } if valuer, ok := arg.(driver.Valuer); ok { arg, _ = valuer.Value() } diff --git a/statement_test.go b/statement_test.go index 761daf37..648bc875 100644 --- a/statement_test.go +++ b/statement_test.go @@ -35,6 +35,13 @@ func TestWhereCloneCorruption(t *testing.T) { } } +func TestNilCondition(t *testing.T) { + s := new(Statement) + if len(s.BuildCondition(nil)) != 0 { + t.Errorf("Nil condition should be empty") + } +} + func TestNameMatcher(t *testing.T) { for k, v := range map[string][]string{ "table.name": {"table", "name"}, From 878ac51e983858bce556877fa72227cb76643155 Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Wed, 8 Feb 2023 13:40:41 +0800 Subject: [PATCH 112/231] fix:throw model value required error (#6031) * fix:throw model value required error * chore:ingore typecheck * chore:ingore errcheck * refactor: use other error * chore: gofumpt style --- callbacks/row.go | 2 +- errors.go | 2 ++ statement.go | 2 ++ tests/query_test.go | 14 ++++++++++++++ 4 files changed, 19 insertions(+), 1 deletion(-) diff --git a/callbacks/row.go b/callbacks/row.go index 56be742e..beaa189e 100644 --- a/callbacks/row.go +++ b/callbacks/row.go @@ -7,7 +7,7 @@ import ( func RowQuery(db *gorm.DB) { if db.Error == nil { BuildQuerySQL(db) - if db.DryRun { + if db.DryRun || db.Error != nil { return } diff --git a/errors.go b/errors.go index 49cbfe64..0f486c5e 100644 --- a/errors.go +++ b/errors.go @@ -21,6 +21,8 @@ var ( ErrPrimaryKeyRequired = errors.New("primary key required") // ErrModelValueRequired model value required ErrModelValueRequired = errors.New("model value required") + // ErrModelAccessibleFieldsRequired model accessible fields required + ErrModelAccessibleFieldsRequired = errors.New("model accessible fields required") // ErrInvalidData unsupported data ErrInvalidData = errors.New("unsupported data") // ErrUnsupportedDriver unsupported driver diff --git a/statement.go b/statement.go index 08165293..bc959f0b 100644 --- a/statement.go +++ b/statement.go @@ -120,6 +120,8 @@ func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) { write(v.Raw, stmt.Schema.PrioritizedPrimaryField.DBName) } else if len(stmt.Schema.DBNames) > 0 { write(v.Raw, stmt.Schema.DBNames[0]) + } else { + stmt.DB.AddError(ErrModelAccessibleFieldsRequired) //nolint:typecheck,errcheck } } else { write(v.Raw, v.Name) diff --git a/tests/query_test.go b/tests/query_test.go index fa8f09e8..88e93c77 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -1366,3 +1366,17 @@ func TestQueryResetNullValue(t *testing.T) { AssertEqual(t, q1, qs[0]) AssertEqual(t, q2, qs[1]) } + +func TestQueryError(t *testing.T) { + type P struct{} + var p1 P + err := DB.Take(&p1, 1).Error + AssertEqual(t, err, gorm.ErrModelAccessibleFieldsRequired) + + var p2 interface{} + + err = DB.Table("ps").Clauses(clause.Eq{Column: clause.Column{ + Table: clause.CurrentTable, Name: clause.PrimaryKey, + }, Value: 1}).Scan(&p2).Error + AssertEqual(t, err, gorm.ErrModelValueRequired) +} From 02b7e26f6b5dcdc49797cc44c26a255a69f3aff3 Mon Sep 17 00:00:00 2001 From: Cheese Date: Wed, 8 Feb 2023 16:29:09 +0800 Subject: [PATCH 113/231] feat: add tidb integration test cases (#6014) * feat: support tidb integration test * feat: update the mysql driver version to test --- .github/workflows/tests.yml | 33 +++++++ tests/associations_belongs_to_test.go | 1 + tests/associations_many2many_test.go | 2 + tests/associations_test.go | 4 + tests/docker-compose.yml | 5 + tests/go.mod | 4 +- tests/helper_test.go | 11 +++ tests/migrate_test.go | 132 ++++++++++++++++++++++++++ tests/sql_builder_test.go | 2 +- tests/tests_all.sh | 2 +- tests/tests_test.go | 7 ++ 11 files changed, 199 insertions(+), 4 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 5e9a1e63..cfe8e56f 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -167,3 +167,36 @@ jobs: - name: Tests run: GITHUB_ACTION=true GORM_DIALECT=sqlserver GORM_DSN="sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm" ./tests/tests_all.sh + + tidb: + strategy: + matrix: + dbversion: [ 'v6.5.0' ] + go: [ '1.19', '1.18' ] + platform: [ ubuntu-latest ] + runs-on: ${{ matrix.platform }} + + steps: + - name: Setup TiDB + uses: Icemap/tidb-action@main + with: + port: 9940 + version: ${{matrix.dbversion}} + + - name: Set up Go 1.x + uses: actions/setup-go@v3 + with: + go-version: ${{ matrix.go }} + + - name: Check out code into the Go module directory + uses: actions/checkout@v3 + + + - name: go mod package cache + uses: actions/cache@v3 + with: + path: ~/go/pkg/mod + key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} + + - name: Tests + run: GITHUB_ACTION=true GORM_DIALECT=tidb GORM_DSN="root:@tcp(localhost:9940)/test?charset=utf8&parseTime=True&loc=Local" ./tests/tests_all.sh diff --git a/tests/associations_belongs_to_test.go b/tests/associations_belongs_to_test.go index a1f014d9..99e8aa79 100644 --- a/tests/associations_belongs_to_test.go +++ b/tests/associations_belongs_to_test.go @@ -138,6 +138,7 @@ func TestBelongsToAssociation(t *testing.T) { unexistCompanyID := company.ID + 9999999 user = User{Name: "invalid-user-with-invalid-belongs-to-foreign-key", CompanyID: &unexistCompanyID} if err := DB.Create(&user).Error; err == nil { + tidbSkip(t, "not support the foreign key feature") t.Errorf("should have gotten foreign key violation error") } } diff --git a/tests/associations_many2many_test.go b/tests/associations_many2many_test.go index 7b45befb..4ba31f90 100644 --- a/tests/associations_many2many_test.go +++ b/tests/associations_many2many_test.go @@ -95,6 +95,8 @@ func TestMany2ManyAssociation(t *testing.T) { } func TestMany2ManyOmitAssociations(t *testing.T) { + tidbSkip(t, "not support the foreign key feature") + user := *GetUser("many2many_omit_associations", Config{Languages: 2}) if err := DB.Omit("Languages.*").Create(&user).Error; err == nil { diff --git a/tests/associations_test.go b/tests/associations_test.go index 4c9076da..4e8862e5 100644 --- a/tests/associations_test.go +++ b/tests/associations_test.go @@ -71,6 +71,8 @@ func TestAssociationNotNullClear(t *testing.T) { } func TestForeignKeyConstraints(t *testing.T) { + tidbSkip(t, "not support the foreign key feature") + type Profile struct { ID uint Name string @@ -126,6 +128,8 @@ func TestForeignKeyConstraints(t *testing.T) { } func TestForeignKeyConstraintsBelongsTo(t *testing.T) { + tidbSkip(t, "not support the foreign key feature") + type Profile struct { ID uint Name string diff --git a/tests/docker-compose.yml b/tests/docker-compose.yml index 9ab4ddb6..0e5673fb 100644 --- a/tests/docker-compose.yml +++ b/tests/docker-compose.yml @@ -29,3 +29,8 @@ services: - MSSQL_DB=gorm - MSSQL_USER=gorm - MSSQL_PASSWORD=LoremIpsum86 + tidb: + image: 'pingcap/tidb:v6.5.0' + ports: + - 9940:4000 + command: /tidb-server -store unistore -path "" -lease 0s > tidb.log 2>&1 & diff --git a/tests/go.mod b/tests/go.mod index 251aabb3..69d6cf87 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -8,11 +8,11 @@ require ( github.com/lib/pq v1.10.7 github.com/mattn/go-sqlite3 v1.14.16 // indirect golang.org/x/crypto v0.5.0 // indirect - gorm.io/driver/mysql v1.4.5 + gorm.io/driver/mysql v1.4.6 gorm.io/driver/postgres v1.4.6 gorm.io/driver/sqlite v1.4.4 gorm.io/driver/sqlserver v1.4.2 - gorm.io/gorm v1.24.3 + gorm.io/gorm v1.24.5 ) replace gorm.io/gorm => ../ diff --git a/tests/helper_test.go b/tests/helper_test.go index d1af0739..d40fa5ce 100644 --- a/tests/helper_test.go +++ b/tests/helper_test.go @@ -1,6 +1,7 @@ package tests_test import ( + "os" "sort" "strconv" "strings" @@ -235,3 +236,13 @@ func CheckUser(t *testing.T, user User, expect User) { } }) } + +func tidbSkip(t *testing.T, reason string) { + if isTiDB() { + t.Skipf("This test case skipped, because of TiDB '%s'", reason) + } +} + +func isTiDB() bool { + return os.Getenv("GORM_DIALECT") == "tidb" +} diff --git a/tests/migrate_test.go b/tests/migrate_test.go index fcd0b5bd..489da976 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -374,7 +374,137 @@ func TestMigrateIndexes(t *testing.T) { } } +func TestTiDBMigrateColumns(t *testing.T) { + if !isTiDB() { + t.Skip() + } + + // TiDB can't change column constraint and has auto_random feature + type ColumnStruct struct { + ID int `gorm:"primarykey;default:auto_random()"` + Name string + Age int `gorm:"default:18;comment:my age"` + Code string `gorm:"unique;comment:my code;"` + Code2 string + Code3 string `gorm:"unique"` + } + + DB.Migrator().DropTable(&ColumnStruct{}) + + if err := DB.AutoMigrate(&ColumnStruct{}); err != nil { + t.Errorf("Failed to migrate, got %v", err) + } + + type ColumnStruct2 struct { + ID int `gorm:"primarykey;default:auto_random()"` + Name string `gorm:"size:100"` + Code string `gorm:"unique;comment:my code2;default:hello"` + Code2 string `gorm:"comment:my code2;default:hello"` + } + + if err := DB.Table("column_structs").Migrator().AlterColumn(&ColumnStruct{}, "Name"); err != nil { + t.Fatalf("no error should happened when alter column, but got %v", err) + } + + if err := DB.Table("column_structs").AutoMigrate(&ColumnStruct2{}); err != nil { + t.Fatalf("no error should happened when auto migrate column, but got %v", err) + } + + if columnTypes, err := DB.Migrator().ColumnTypes(&ColumnStruct{}); err != nil { + t.Fatalf("no error should returns for ColumnTypes") + } else { + stmt := &gorm.Statement{DB: DB} + stmt.Parse(&ColumnStruct2{}) + + for _, columnType := range columnTypes { + switch columnType.Name() { + case "id": + if v, ok := columnType.PrimaryKey(); !ok || !v { + t.Fatalf("column id primary key should be correct, name: %v, column: %#v", columnType.Name(), columnType) + } + case "name": + dataType := DB.Dialector.DataTypeOf(stmt.Schema.LookUpField(columnType.Name())) + if !strings.Contains(strings.ToUpper(dataType), strings.ToUpper(columnType.DatabaseTypeName())) { + t.Fatalf("column name type should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), columnType.DatabaseTypeName(), dataType, columnType) + } + if length, ok := columnType.Length(); !ok || length != 100 { + t.Fatalf("column name length should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), length, 100, columnType) + } + case "age": + if v, ok := columnType.DefaultValue(); !ok || v != "18" { + t.Fatalf("column age default value should be correct, name: %v, column: %#v", columnType.Name(), columnType) + } + if v, ok := columnType.Comment(); !ok || v != "my age" { + t.Fatalf("column age comment should be correct, name: %v, column: %#v", columnType.Name(), columnType) + } + case "code": + if v, ok := columnType.Unique(); !ok || !v { + t.Fatalf("column code unique should be correct, name: %v, column: %#v", columnType.Name(), columnType) + } + if v, ok := columnType.DefaultValue(); !ok || v != "hello" { + t.Fatalf("column code default value should be correct, name: %v, column: %#v, default value: %v", columnType.Name(), columnType, v) + } + if v, ok := columnType.Comment(); !ok || v != "my code2" { + t.Fatalf("column code comment should be correct, name: %v, column: %#v", columnType.Name(), columnType) + } + case "code2": + // Code2 string `gorm:"comment:my code2;default:hello"` + if v, ok := columnType.DefaultValue(); !ok || v != "hello" { + t.Fatalf("column code default value should be correct, name: %v, column: %#v, default value: %v", columnType.Name(), columnType, v) + } + if v, ok := columnType.Comment(); !ok || v != "my code2" { + t.Fatalf("column code comment should be correct, name: %v, column: %#v", columnType.Name(), columnType) + } + } + } + } + + type NewColumnStruct struct { + gorm.Model + Name string + NewName string + } + + if err := DB.Table("column_structs").Migrator().AddColumn(&NewColumnStruct{}, "NewName"); err != nil { + t.Fatalf("Failed to add column, got %v", err) + } + + if !DB.Table("column_structs").Migrator().HasColumn(&NewColumnStruct{}, "NewName") { + t.Fatalf("Failed to find added column") + } + + if err := DB.Table("column_structs").Migrator().DropColumn(&NewColumnStruct{}, "NewName"); err != nil { + t.Fatalf("Failed to add column, got %v", err) + } + + if DB.Table("column_structs").Migrator().HasColumn(&NewColumnStruct{}, "NewName") { + t.Fatalf("Found deleted column") + } + + if err := DB.Table("column_structs").Migrator().AddColumn(&NewColumnStruct{}, "NewName"); err != nil { + t.Fatalf("Failed to add column, got %v", err) + } + + if err := DB.Table("column_structs").Migrator().RenameColumn(&NewColumnStruct{}, "NewName", "new_new_name"); err != nil { + t.Fatalf("Failed to add column, got %v", err) + } + + if !DB.Table("column_structs").Migrator().HasColumn(&NewColumnStruct{}, "new_new_name") { + t.Fatalf("Failed to found renamed column") + } + + if err := DB.Table("column_structs").Migrator().DropColumn(&NewColumnStruct{}, "new_new_name"); err != nil { + t.Fatalf("Failed to add column, got %v", err) + } + + if DB.Table("column_structs").Migrator().HasColumn(&NewColumnStruct{}, "new_new_name") { + t.Fatalf("Found deleted column") + } +} + func TestMigrateColumns(t *testing.T) { + tidbSkip(t, "use another test case") + sqlite := DB.Dialector.Name() == "sqlite" sqlserver := DB.Dialector.Name() == "sqlserver" @@ -853,6 +983,8 @@ func TestUniqueColumn(t *testing.T) { AssertEqual(t, "", value) AssertEqual(t, false, ok) + tidbSkip(t, "can't change column constraint") + // null -> empty string err = DB.Table("unique_tests").AutoMigrate(&UniqueTest3{}) if err != nil { diff --git a/tests/sql_builder_test.go b/tests/sql_builder_test.go index 0fbd6118..022e0495 100644 --- a/tests/sql_builder_test.go +++ b/tests/sql_builder_test.go @@ -29,7 +29,7 @@ func TestRow(t *testing.T) { } table := "gorm.users" - if DB.Dialector.Name() != "mysql" { + if DB.Dialector.Name() != "mysql" || isTiDB() { table = "users" // other databases doesn't support select with `database.table` } diff --git a/tests/tests_all.sh b/tests/tests_all.sh index 5b9bae97..ee9e7675 100755 --- a/tests/tests_all.sh +++ b/tests/tests_all.sh @@ -1,6 +1,6 @@ #!/bin/bash -e -dialects=("sqlite" "mysql" "postgres" "sqlserver") +dialects=("sqlite" "mysql" "postgres" "sqlserver" "tidb") if [[ $(pwd) == *"gorm/tests"* ]]; then cd .. diff --git a/tests/tests_test.go b/tests/tests_test.go index dcba3cbf..90eb847f 100644 --- a/tests/tests_test.go +++ b/tests/tests_test.go @@ -21,6 +21,7 @@ var ( mysqlDSN = "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True&loc=Local" postgresDSN = "user=gorm password=gorm dbname=gorm host=localhost port=9920 sslmode=disable TimeZone=Asia/Shanghai" sqlserverDSN = "sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm" + tidbDSN = "root:@tcp(localhost:9940)/test?charset=utf8&parseTime=True&loc=Local" ) func init() { @@ -80,6 +81,12 @@ func OpenTestConnection() (db *gorm.DB, err error) { dbDSN = sqlserverDSN } db, err = gorm.Open(sqlserver.Open(dbDSN), &gorm.Config{}) + case "tidb": + log.Println("testing tidb...") + if dbDSN == "" { + dbDSN = tidbDSN + } + db, err = gorm.Open(mysql.Open(dbDSN), &gorm.Config{}) default: log.Println("testing sqlite3...") db, err = gorm.Open(sqlite.Open(filepath.Join(os.TempDir(), "gorm.db")), &gorm.Config{}) From 532e9cf4ccce927249bcb102c09e4a9093aae4fe Mon Sep 17 00:00:00 2001 From: Michael Anstis Date: Sat, 18 Feb 2023 01:06:43 +0000 Subject: [PATCH 114/231] Issue 6054: Unscoped not working with PreLoad on Joins (#6058) * Issue 6054: Unscoped not working with PreLoad on Joins * Formatting --------- Co-authored-by: Michael Anstis --- callbacks/query.go | 1 + clause/select_test.go | 12 +++++++----- migrator/migrator.go | 4 +--- model.go | 7 ++++--- schema/field.go | 2 +- schema/relationship.go | 23 +++++++++++----------- schema/serializer.go | 9 +++------ tests/connpool_test.go | 8 +++++--- tests/embedded_struct_test.go | 1 - tests/helper_test.go | 36 +++++++++++++++++++++++++++++----- tests/migrate_test.go | 3 +-- tests/preload_test.go | 37 +++++++++++++++++++++++++++++++++++ tests/table_test.go | 5 +++-- 13 files changed, 106 insertions(+), 42 deletions(-) diff --git a/callbacks/query.go b/callbacks/query.go index 97fe8a49..9a6d4f4a 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -257,6 +257,7 @@ func Preload(db *gorm.DB) { return } preloadDB.Statement.ReflectValue = db.Statement.ReflectValue + preloadDB.Statement.Unscoped = db.Statement.Unscoped for _, name := range preloadNames { if rel := preloadDB.Statement.Schema.Relationships.Relations[name]; rel != nil { diff --git a/clause/select_test.go b/clause/select_test.go index 18bc2693..9c11b90d 100644 --- a/clause/select_test.go +++ b/clause/select_test.go @@ -49,16 +49,18 @@ func TestSelect(t *testing.T) { Exprs: []clause.Expression{ clause.Expr{ SQL: "? as name", - Vars: []interface{}{clause.Eq{ - Column: clause.Column{Name: "age"}, - Value: 18, - }, + Vars: []interface{}{ + clause.Eq{ + Column: clause.Column{Name: "age"}, + Value: 18, + }, }, }, }, }, }, clause.From{}}, - "SELECT `age` = ? as name FROM `users`", []interface{}{18}, + "SELECT `age` = ? as name FROM `users`", + []interface{}{18}, }, } diff --git a/migrator/migrator.go b/migrator/migrator.go index b8aaef2b..12c2df46 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -16,9 +16,7 @@ import ( "gorm.io/gorm/schema" ) -var ( - regFullDataType = regexp.MustCompile(`\D*(\d+)\D?`) -) +var regFullDataType = regexp.MustCompile(`\D*(\d+)\D?`) // Migrator m struct type Migrator struct { diff --git a/model.go b/model.go index 3334d17c..fa705df1 100644 --- a/model.go +++ b/model.go @@ -4,9 +4,10 @@ import "time" // Model a basic GoLang struct which includes the following fields: ID, CreatedAt, UpdatedAt, DeletedAt // It may be embedded into your model or you may build your own model without it -// type User struct { -// gorm.Model -// } +// +// type User struct { +// gorm.Model +// } type Model struct { ID uint `gorm:"primarykey"` CreatedAt time.Time diff --git a/schema/field.go b/schema/field.go index 1589d984..59151878 100644 --- a/schema/field.go +++ b/schema/field.go @@ -174,7 +174,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { field.DataType = String field.Serializer = v } else { - var serializerName = field.TagSettings["JSON"] + serializerName := field.TagSettings["JSON"] if serializerName == "" { serializerName = field.TagSettings["SERIALIZER"] } diff --git a/schema/relationship.go b/schema/relationship.go index 9436f283..b33b94a7 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -123,16 +123,17 @@ func (schema *Schema) parseRelation(field *Field) *Relationship { } // User has many Toys, its `Polymorphic` is `Owner`, Pet has one Toy, its `Polymorphic` is `Owner` -// type User struct { -// Toys []Toy `gorm:"polymorphic:Owner;"` -// } -// type Pet struct { -// Toy Toy `gorm:"polymorphic:Owner;"` -// } -// type Toy struct { -// OwnerID int -// OwnerType string -// } +// +// type User struct { +// Toys []Toy `gorm:"polymorphic:Owner;"` +// } +// type Pet struct { +// Toy Toy `gorm:"polymorphic:Owner;"` +// } +// type Toy struct { +// OwnerID int +// OwnerType string +// } func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Field, polymorphic string) { relation.Polymorphic = &Polymorphic{ Value: schema.Table, @@ -427,7 +428,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu foreignFields = append(foreignFields, f) } } else { - var primarySchemaName = primarySchema.Name + primarySchemaName := primarySchema.Name if primarySchemaName == "" { primarySchemaName = relation.FieldSchema.Name } diff --git a/schema/serializer.go b/schema/serializer.go index 9a6aa4fc..397edff0 100644 --- a/schema/serializer.go +++ b/schema/serializer.go @@ -70,8 +70,7 @@ type SerializerValuerInterface interface { } // JSONSerializer json serializer -type JSONSerializer struct { -} +type JSONSerializer struct{} // Scan implements serializer interface func (JSONSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) (err error) { @@ -110,8 +109,7 @@ func (JSONSerializer) Value(ctx context.Context, field *Field, dst reflect.Value } // UnixSecondSerializer json serializer -type UnixSecondSerializer struct { -} +type UnixSecondSerializer struct{} // Scan implements serializer interface func (UnixSecondSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) (err error) { @@ -141,8 +139,7 @@ func (UnixSecondSerializer) Value(ctx context.Context, field *Field, dst reflect } // GobSerializer gob serializer -type GobSerializer struct { -} +type GobSerializer struct{} // Scan implements serializer interface func (GobSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) (err error) { diff --git a/tests/connpool_test.go b/tests/connpool_test.go index 42e029bc..e0e1c771 100644 --- a/tests/connpool_test.go +++ b/tests/connpool_test.go @@ -48,9 +48,11 @@ func (c *wrapperConnPool) Ping() error { } // If you use BeginTx returned *sql.Tx as shown below then you can't record queries in a transaction. -// func (c *wrapperConnPool) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) { -// return c.db.BeginTx(ctx, opts) -// } +// +// func (c *wrapperConnPool) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) { +// return c.db.BeginTx(ctx, opts) +// } +// // You should use BeginTx returned gorm.Tx which could wrap *sql.Tx then you can record all queries. func (c *wrapperConnPool) BeginTx(ctx context.Context, opts *sql.TxOptions) (gorm.ConnPool, error) { tx, err := c.db.BeginTx(ctx, opts) diff --git a/tests/embedded_struct_test.go b/tests/embedded_struct_test.go index ae69baca..63ec53ee 100644 --- a/tests/embedded_struct_test.go +++ b/tests/embedded_struct_test.go @@ -94,7 +94,6 @@ func TestEmbeddedStruct(t *testing.T) { t.Errorf("expected author %s got %s", want, post.Author.Name) } } - } func TestEmbeddedPointerTypeStruct(t *testing.T) { diff --git a/tests/helper_test.go b/tests/helper_test.go index d40fa5ce..c34e357c 100644 --- a/tests/helper_test.go +++ b/tests/helper_test.go @@ -8,6 +8,8 @@ import ( "testing" "time" + "gorm.io/gorm" + . "gorm.io/gorm/utils/tests" ) @@ -74,10 +76,18 @@ func GetUser(name string, config Config) *User { return &user } +func CheckPetUnscoped(t *testing.T, pet Pet, expect Pet) { + doCheckPet(t, pet, expect, true) +} + func CheckPet(t *testing.T, pet Pet, expect Pet) { + doCheckPet(t, pet, expect, false) +} + +func doCheckPet(t *testing.T, pet Pet, expect Pet, unscoped bool) { if pet.ID != 0 { var newPet Pet - if err := DB.Where("id = ?", pet.ID).First(&newPet).Error; err != nil { + if err := db(unscoped).Where("id = ?", pet.ID).First(&newPet).Error; err != nil { t.Fatalf("errors happened when query: %v", err) } else { AssertObjEqual(t, newPet, pet, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "UserID", "Name") @@ -94,10 +104,18 @@ func CheckPet(t *testing.T, pet Pet, expect Pet) { } } +func CheckUserUnscoped(t *testing.T, user User, expect User) { + doCheckUser(t, user, expect, true) +} + func CheckUser(t *testing.T, user User, expect User) { + doCheckUser(t, user, expect, false) +} + +func doCheckUser(t *testing.T, user User, expect User, unscoped bool) { if user.ID != 0 { var newUser User - if err := DB.Where("id = ?", user.ID).First(&newUser).Error; err != nil { + if err := db(unscoped).Where("id = ?", user.ID).First(&newUser).Error; err != nil { t.Fatalf("errors happened when query: %v", err) } else { AssertObjEqual(t, newUser, user, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active") @@ -114,7 +132,7 @@ func CheckUser(t *testing.T, user User, expect User) { t.Errorf("Account's foreign key should be saved") } else { var account Account - DB.First(&account, "user_id = ?", user.ID) + db(unscoped).First(&account, "user_id = ?", user.ID) AssertObjEqual(t, account, user.Account, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "UserID", "Number") } } @@ -137,7 +155,7 @@ func CheckUser(t *testing.T, user User, expect User) { if pet == nil || expect.Pets[idx] == nil { t.Errorf("pets#%v should equal, expect: %v, got %v", idx, expect.Pets[idx], pet) } else { - CheckPet(t, *pet, *expect.Pets[idx]) + doCheckPet(t, *pet, *expect.Pets[idx], unscoped) } } }) @@ -174,7 +192,7 @@ func CheckUser(t *testing.T, user User, expect User) { t.Errorf("Manager's foreign key should be saved") } else { var manager User - DB.First(&manager, "id = ?", *user.ManagerID) + db(unscoped).First(&manager, "id = ?", *user.ManagerID) AssertObjEqual(t, manager, user.Manager, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active") AssertObjEqual(t, manager, expect.Manager, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active") } @@ -246,3 +264,11 @@ func tidbSkip(t *testing.T, reason string) { func isTiDB() bool { return os.Getenv("GORM_DIALECT") == "tidb" } + +func db(unscoped bool) *gorm.DB { + if unscoped { + return DB.Unscoped() + } else { + return DB + } +} diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 489da976..8794ccba 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -75,7 +75,6 @@ func TestMigrate(t *testing.T) { t.Fatalf("Failed to find index for many2many for %v %v", indexes[0], indexes[1]) } } - } func TestAutoMigrateInt8PG(t *testing.T) { @@ -1267,7 +1266,7 @@ func (mm mockMigrator) AlterColumn(dst interface{}, field string) error { } func TestMigrateDonotAlterColumn(t *testing.T) { - var wrapMockMigrator = func(m gorm.Migrator) mockMigrator { + wrapMockMigrator := func(m gorm.Migrator) mockMigrator { return mockMigrator{ Migrator: m, } diff --git a/tests/preload_test.go b/tests/preload_test.go index cb4343ec..e7223b3e 100644 --- a/tests/preload_test.go +++ b/tests/preload_test.go @@ -269,3 +269,40 @@ func TestPreloadWithDiffModel(t *testing.T) { CheckUser(t, user, result.User) } + +func TestNestedPreloadWithUnscoped(t *testing.T) { + user := *GetUser("nested_preload", Config{Pets: 1}) + pet := user.Pets[0] + pet.Toy = Toy{Name: "toy_nested_preload_" + strconv.Itoa(1)} + pet.Toy = Toy{Name: "toy_nested_preload_" + strconv.Itoa(2)} + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + var user2 User + DB.Preload("Pets.Toy").Find(&user2, "id = ?", user.ID) + CheckUser(t, user2, user) + + DB.Delete(&pet) + + var user3 User + DB.Preload(clause.Associations+"."+clause.Associations).Find(&user3, "id = ?", user.ID) + if len(user3.Pets) != 0 { + t.Fatalf("User.Pet[0] was deleted and should not exist.") + } + + var user4 *User + DB.Preload("Pets.Toy").Find(&user4, "id = ?", user.ID) + if len(user4.Pets) != 0 { + t.Fatalf("User.Pet[0] was deleted and should not exist.") + } + + var user5 User + DB.Unscoped().Preload(clause.Associations+"."+clause.Associations).Find(&user5, "id = ?", user.ID) + CheckUserUnscoped(t, user5, user) + + var user6 *User + DB.Unscoped().Preload("Pets.Toy").Find(&user6, "id = ?", user.ID) + CheckUserUnscoped(t, *user6, user) +} diff --git a/tests/table_test.go b/tests/table_test.go index f538c691..fa569d32 100644 --- a/tests/table_test.go +++ b/tests/table_test.go @@ -158,10 +158,11 @@ func (UserWithTableNamer) TableName(namer schema.Namer) string { } func TestTableWithNamer(t *testing.T) { - var db, _ = gorm.Open(tests.DummyDialector{}, &gorm.Config{ + db, _ := gorm.Open(tests.DummyDialector{}, &gorm.Config{ NamingStrategy: schema.NamingStrategy{ TablePrefix: "t_", - }}) + }, + }) sql := db.ToSQL(func(tx *gorm.DB) *gorm.DB { return tx.Model(&UserWithTableNamer{}).Find(&UserWithTableNamer{}) From aa89736db2fd175391d23ef02406414125d21067 Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Sat, 18 Feb 2023 09:13:36 +0800 Subject: [PATCH 115/231] fix: miss join type (#6056) --- chainable_api.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/chainable_api.go b/chainable_api.go index 676fe914..a85235e0 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -253,7 +253,10 @@ func joins(db *DB, joinType clause.JoinType, query string, args ...interface{}) if len(args) == 1 { if db, ok := args[0].(*DB); ok { - j := join{Name: query, Conds: args, Selects: db.Statement.Selects, Omits: db.Statement.Omits} + j := join{ + Name: query, Conds: args, Selects: db.Statement.Selects, + Omits: db.Statement.Omits, JoinType: joinType, + } if where, ok := db.Statement.Clauses["WHERE"].Expression.(clause.Where); ok { j.On = &where } From 42fc75cb2ced9a27b8baecb08ec33976096007c0 Mon Sep 17 00:00:00 2001 From: black-06 Date: Sat, 18 Feb 2023 09:19:24 +0800 Subject: [PATCH 116/231] fix: association concurrently appending (#6044) * fix: association concurrently appending * fix: fix unit test * fix: fix gofumpt --- association.go | 8 ++++-- tests/associations_many2many_test.go | 40 ++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 2 deletions(-) diff --git a/association.go b/association.go index 06229caa..6719a1d0 100644 --- a/association.go +++ b/association.go @@ -353,9 +353,13 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ } case schema.HasMany, schema.Many2Many: elemType := association.Relationship.Field.IndirectFieldType.Elem() - fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(association.DB.Statement.Context, source)) + oldFieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(association.DB.Statement.Context, source)) + var fieldValue reflect.Value if clear { - fieldValue = reflect.New(association.Relationship.Field.IndirectFieldType).Elem() + fieldValue = reflect.MakeSlice(oldFieldValue.Type(), 0, oldFieldValue.Cap()) + } else { + fieldValue = reflect.MakeSlice(oldFieldValue.Type(), oldFieldValue.Len(), oldFieldValue.Cap()) + reflect.Copy(fieldValue, oldFieldValue) } appendToFieldValues := func(ev reflect.Value) { diff --git a/tests/associations_many2many_test.go b/tests/associations_many2many_test.go index 4ba31f90..845c16af 100644 --- a/tests/associations_many2many_test.go +++ b/tests/associations_many2many_test.go @@ -1,9 +1,12 @@ package tests_test import ( + "fmt" + "sync" "testing" "gorm.io/gorm" + "gorm.io/gorm/clause" . "gorm.io/gorm/utils/tests" ) @@ -353,3 +356,40 @@ func TestDuplicateMany2ManyAssociation(t *testing.T) { AssertEqual(t, nil, err) AssertEqual(t, user2, findUser2) } + +func TestConcurrentMany2ManyAssociation(t *testing.T) { + db, err := OpenTestConnection() + if err != nil { + t.Fatalf("open test connection failed, err: %+v", err) + } + + count := 3 + + var languages []Language + for i := 0; i < count; i++ { + language := Language{Code: fmt.Sprintf("consurrent %d", i)} + db.Create(&language) + languages = append(languages, language) + } + + user := User{} + db.Create(&user) + db.Preload("Languages").FirstOrCreate(&user) + + var wg sync.WaitGroup + for i := 0; i < count; i++ { + wg.Add(1) + go func(user User, language Language) { + err := db.Model(&user).Association("Languages").Append(&language) + AssertEqual(t, err, nil) + + wg.Done() + }(user, languages[i]) + } + wg.Wait() + + var find User + err = db.Preload(clause.Associations).Where("id = ?", user.ID).First(&find).Error + AssertEqual(t, err, nil) + AssertAssociationCount(t, find, "Languages", int64(count), "after concurrent append") +} From e66a059b823218ec6d7efc765f67d788bb900f75 Mon Sep 17 00:00:00 2001 From: black-06 Date: Sat, 18 Feb 2023 09:20:29 +0800 Subject: [PATCH 117/231] fix: update panic if model is not ptr (#6037) * fix: update panic if model is not ptr * fix: update panic if model is not ptr * fix: update panic if model is not ptr * fix: raise an error if the value is not addressable * fix: return --- callbacks/callmethod.go | 13 +++++++++-- callbacks/update.go | 4 +++- schema/utils.go | 2 +- tests/hooks_test.go | 52 +++++++++++++++++++++++++++++++++++++++++ 4 files changed, 67 insertions(+), 4 deletions(-) diff --git a/callbacks/callmethod.go b/callbacks/callmethod.go index bcaa03f3..fb900037 100644 --- a/callbacks/callmethod.go +++ b/callbacks/callmethod.go @@ -13,11 +13,20 @@ func callMethod(db *gorm.DB, fc func(value interface{}, tx *gorm.DB) bool) { case reflect.Slice, reflect.Array: db.Statement.CurDestIndex = 0 for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - fc(reflect.Indirect(db.Statement.ReflectValue.Index(i)).Addr().Interface(), tx) + if value := reflect.Indirect(db.Statement.ReflectValue.Index(i)); value.CanAddr() { + fc(value.Addr().Interface(), tx) + } else { + db.AddError(gorm.ErrInvalidValue) + return + } db.Statement.CurDestIndex++ } case reflect.Struct: - fc(db.Statement.ReflectValue.Addr().Interface(), tx) + if db.Statement.ReflectValue.CanAddr() { + fc(db.Statement.ReflectValue.Addr().Interface(), tx) + } else { + db.AddError(gorm.ErrInvalidValue) + } } } } diff --git a/callbacks/update.go b/callbacks/update.go index b596df9a..fe6f0994 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -137,7 +137,9 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { case reflect.Slice, reflect.Array: assignValue = func(field *schema.Field, value interface{}) { for i := 0; i < stmt.ReflectValue.Len(); i++ { - field.Set(stmt.Context, stmt.ReflectValue.Index(i), value) + if stmt.ReflectValue.CanAddr() { + field.Set(stmt.Context, stmt.ReflectValue.Index(i), value) + } } } case reflect.Struct: diff --git a/schema/utils.go b/schema/utils.go index acf1a739..65d012e5 100644 --- a/schema/utils.go +++ b/schema/utils.go @@ -133,7 +133,7 @@ func GetIdentityFieldValuesMap(ctx context.Context, reflectValue reflect.Value, for i := 0; i < reflectValue.Len(); i++ { elem := reflectValue.Index(i) elemKey := elem.Interface() - if elem.Kind() != reflect.Ptr { + if elem.Kind() != reflect.Ptr && elem.CanAddr() { elemKey = elem.Addr().Interface() } diff --git a/tests/hooks_test.go b/tests/hooks_test.go index 8e964fd8..0753dd0b 100644 --- a/tests/hooks_test.go +++ b/tests/hooks_test.go @@ -514,3 +514,55 @@ func TestFailedToSaveAssociationShouldRollback(t *testing.T) { t.Fatalf("AfterFind should not be called times:%d", productWithItem.Item.AfterFindCallTimes) } } + +type Product5 struct { + gorm.Model + Name string +} + +var beforeUpdateCall int + +func (p *Product5) BeforeUpdate(*gorm.DB) error { + beforeUpdateCall = beforeUpdateCall + 1 + return nil +} + +func TestUpdateCallbacks(t *testing.T) { + DB.Migrator().DropTable(&Product5{}) + DB.AutoMigrate(&Product5{}) + + p := Product5{Name: "unique_code"} + DB.Model(&Product5{}).Create(&p) + + err := DB.Model(&Product5{}).Where("id", p.ID).Update("name", "update_name_1").Error + if err != nil { + t.Fatalf("should update success, but got err %v", err) + } + if beforeUpdateCall != 1 { + t.Fatalf("before update should be called") + } + + err = DB.Model(Product5{}).Where("id", p.ID).Update("name", "update_name_2").Error + if !errors.Is(err, gorm.ErrInvalidValue) { + t.Fatalf("should got RecordNotFound, but got %v", err) + } + if beforeUpdateCall != 1 { + t.Fatalf("before update should not be called") + } + + err = DB.Model([1]*Product5{&p}).Update("name", "update_name_3").Error + if err != nil { + t.Fatalf("should update success, but got err %v", err) + } + if beforeUpdateCall != 2 { + t.Fatalf("before update should be called") + } + + err = DB.Model([1]Product5{p}).Update("name", "update_name_4").Error + if !errors.Is(err, gorm.ErrInvalidValue) { + t.Fatalf("should got RecordNotFound, but got %v", err) + } + if beforeUpdateCall != 2 { + t.Fatalf("before update should not be called") + } +} From 04cbd956ebed5fec1b61a819a3f7494c00d276b3 Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Sat, 18 Feb 2023 09:21:07 +0800 Subject: [PATCH 118/231] test: pgsql migrate unique index (#6028) --- tests/migrate_test.go | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 8794ccba..5a220ca4 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -256,9 +256,10 @@ func TestMigrateWithIndexComment(t *testing.T) { func TestMigrateWithUniqueIndex(t *testing.T) { type UserWithUniqueIndex struct { - ID int - Name string `gorm:"size:20;index:idx_name,unique"` - Date time.Time `gorm:"index:idx_name,unique"` + ID int + Name string `gorm:"size:20;index:idx_name,unique"` + Date time.Time `gorm:"index:idx_name,unique"` + UName string `gorm:"uniqueIndex;size:255"` } DB.Migrator().DropTable(&UserWithUniqueIndex{}) @@ -269,6 +270,18 @@ func TestMigrateWithUniqueIndex(t *testing.T) { if !DB.Migrator().HasIndex(&UserWithUniqueIndex{}, "idx_name") { t.Errorf("Failed to find created index") } + + if !DB.Migrator().HasIndex(&UserWithUniqueIndex{}, "idx_user_with_unique_indices_u_name") { + t.Errorf("Failed to find created index") + } + + if err := DB.AutoMigrate(&UserWithUniqueIndex{}); err != nil { + t.Fatalf("failed to migrate, got %v", err) + } + + if !DB.Migrator().HasIndex(&UserWithUniqueIndex{}, "idx_user_with_unique_indices_u_name") { + t.Errorf("Failed to find created index") + } } func TestMigrateTable(t *testing.T) { From 391c961c7fafcf89cf89e904a97b01493411bfa0 Mon Sep 17 00:00:00 2001 From: Jiepeng Cao Date: Mon, 27 Feb 2023 15:39:02 +0800 Subject: [PATCH 119/231] quotes on docker-compose.yml ports (#6089) --- tests/docker-compose.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/docker-compose.yml b/tests/docker-compose.yml index 0e5673fb..866a4d62 100644 --- a/tests/docker-compose.yml +++ b/tests/docker-compose.yml @@ -4,7 +4,7 @@ services: mysql: image: 'mysql/mysql-server:latest' ports: - - 9910:3306 + - "9910:3306" environment: - MYSQL_DATABASE=gorm - MYSQL_USER=gorm @@ -13,7 +13,7 @@ services: postgres: image: 'postgres:latest' ports: - - 9920:5432 + - "9920:5432" environment: - TZ=Asia/Shanghai - POSTGRES_DB=gorm @@ -22,7 +22,7 @@ services: mssql: image: '${MSSQL_IMAGE:-mcmoe/mssqldocker}:latest' ports: - - 9930:1433 + - "9930:1433" environment: - ACCEPT_EULA=Y - SA_PASSWORD=LoremIpsum86 @@ -32,5 +32,5 @@ services: tidb: image: 'pingcap/tidb:v6.5.0' ports: - - 9940:4000 + - "9940:4000" command: /tidb-server -store unistore -path "" -lease 0s > tidb.log 2>&1 & From a80707de9e33dffb5c136a16be837209c6502215 Mon Sep 17 00:00:00 2001 From: black-06 Date: Mon, 27 Feb 2023 15:43:10 +0800 Subject: [PATCH 120/231] Create and drop view (#6097) * create view * add comment * fix test * check param and add comment --- errors.go | 2 ++ migrator.go | 6 +++--- migrator/migrator.go | 36 +++++++++++++++++++++++++++++++++--- tests/migrate_test.go | 33 +++++++++++++++++++++++++++++++++ 4 files changed, 71 insertions(+), 6 deletions(-) diff --git a/errors.go b/errors.go index 0f486c5e..5bfd0f82 100644 --- a/errors.go +++ b/errors.go @@ -23,6 +23,8 @@ var ( ErrModelValueRequired = errors.New("model value required") // ErrModelAccessibleFieldsRequired model accessible fields required ErrModelAccessibleFieldsRequired = errors.New("model accessible fields required") + // ErrSubQueryRequired sub query required + ErrSubQueryRequired = errors.New("sub query required") // ErrInvalidData unsupported data ErrInvalidData = errors.New("unsupported data") // ErrUnsupportedDriver unsupported driver diff --git a/migrator.go b/migrator.go index 882fc4cc..9c7cc2c4 100644 --- a/migrator.go +++ b/migrator.go @@ -30,9 +30,9 @@ func (db *DB) AutoMigrate(dst ...interface{}) error { // ViewOption view option type ViewOption struct { - Replace bool - CheckOption string - Query *DB + Replace bool // If true, exec `CREATE`. If false, exec `CREATE OR REPLACE` + CheckOption string // optional. e.g. `WITH [ CASCADED | LOCAL ] CHECK OPTION` + Query *DB // required subquery. } // ColumnType column type interface diff --git a/migrator/migrator.go b/migrator/migrator.go index 12c2df46..389ce008 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -557,14 +557,44 @@ func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) { return columnTypes, execErr } -// CreateView create view +// CreateView create view from Query in gorm.ViewOption. +// Query in gorm.ViewOption is a [subquery] +// +// // CREATE VIEW `user_view` AS SELECT * FROM `users` WHERE age > 20 +// q := DB.Model(&User{}).Where("age > ?", 20) +// DB.Debug().Migrator().CreateView("user_view", gorm.ViewOption{Query: q}) +// +// // CREATE OR REPLACE VIEW `users_view` AS SELECT * FROM `users` WITH CHECK OPTION +// q := DB.Model(&User{}) +// DB.Debug().Migrator().CreateView("user_view", gorm.ViewOption{Query: q, Replace: true, CheckOption: "WITH CHECK OPTION"}) +// +// [subquery]: https://gorm.io/docs/advanced_query.html#SubQuery func (m Migrator) CreateView(name string, option gorm.ViewOption) error { - return gorm.ErrNotImplemented + if option.Query == nil { + return gorm.ErrSubQueryRequired + } + + sql := new(strings.Builder) + sql.WriteString("CREATE ") + if option.Replace { + sql.WriteString("OR REPLACE ") + } + sql.WriteString("VIEW ") + m.QuoteTo(sql, name) + sql.WriteString(" AS ") + + m.DB.Statement.AddVar(sql, option.Query) + + if option.CheckOption != "" { + sql.WriteString(" ") + sql.WriteString(option.CheckOption) + } + return m.DB.Exec(m.Explain(sql.String(), m.DB.Statement.Vars...)).Error } // DropView drop view func (m Migrator) DropView(name string) error { - return gorm.ErrNotImplemented + return m.DB.Exec("DROP VIEW IF EXISTS ?", clause.Table{Name: name}).Error } func buildConstraint(constraint *schema.Constraint) (sql string, results []interface{}) { diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 5a220ca4..11a0afda 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -1509,3 +1509,36 @@ func TestMigrateIgnoreRelations(t *testing.T) { t.Errorf("RelationModel2 should not be migrated") } } + +func TestMigrateView(t *testing.T) { + DB.Save(GetUser("joins-args-db", Config{Pets: 2})) + + if err := DB.Migrator().CreateView("invalid_users_pets", gorm.ViewOption{Query: nil}); err != gorm.ErrSubQueryRequired { + t.Fatalf("no view should be created, got %v", err) + } + + query := DB.Model(&User{}). + Select("users.id as users_id, users.name as users_name, pets.id as pets_id, pets.name as pets_name"). + Joins("inner join pets on pets.user_id = users.id") + + if err := DB.Migrator().CreateView("users_pets", gorm.ViewOption{Query: query}); err != nil { + t.Fatalf("Failed to crate view, got %v", err) + } + + var count int64 + if err := DB.Table("users_pets").Count(&count).Error; err != nil { + t.Fatalf("should found created view") + } + + if err := DB.Migrator().DropView("users_pets"); err != nil { + t.Fatalf("Failed to drop view, got %v", err) + } + + query = DB.Model(&User{}).Where("age > ?", 20) + if err := DB.Migrator().CreateView("users_view", gorm.ViewOption{Query: query}); err != nil { + t.Fatalf("Failed to crate view, got %v", err) + } + if err := DB.Migrator().DropView("users_view"); err != nil { + t.Fatalf("Failed to drop view, got %v", err) + } +} From 877cc9148f95552f51891d45d588af799033ceb8 Mon Sep 17 00:00:00 2001 From: Jiepeng Cao Date: Mon, 27 Feb 2023 15:44:35 +0800 Subject: [PATCH 121/231] Remove redundant code (#6087) --- tests/query_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/query_test.go b/tests/query_test.go index 88e93c77..b6bd0736 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -218,7 +218,7 @@ func TestFind(t *testing.T) { // test array var models2 [3]User - if err := DB.Where("name in (?)", []string{"find"}).Find(&models2).Error; err != nil || len(models2) != 3 { + if err := DB.Where("name in (?)", []string{"find"}).Find(&models2).Error; err != nil { t.Errorf("errors happened when query find with in clause: %v, length: %v", err, len(models2)) } else { for idx, user := range users { @@ -230,7 +230,7 @@ func TestFind(t *testing.T) { // test smaller array var models3 [2]User - if err := DB.Where("name in (?)", []string{"find"}).Find(&models3).Error; err != nil || len(models3) != 2 { + if err := DB.Where("name in (?)", []string{"find"}).Find(&models3).Error; err != nil { t.Errorf("errors happened when query find with in clause: %v, length: %v", err, len(models3)) } else { for idx, user := range users[:2] { From f3874339efd829d9841ad8fb6b50d7c2059153d2 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 2 Mar 2023 17:22:42 +0800 Subject: [PATCH 122/231] Fix Save with stress tests --- finisher_api.go | 11 +++++------ go.mod | 2 +- go.sum | 2 ++ tests/go.mod | 7 ++++--- 4 files changed, 12 insertions(+), 10 deletions(-) diff --git a/finisher_api.go b/finisher_api.go index 39d9fca3..f16d4f43 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -101,14 +101,13 @@ func (db *DB) Save(value interface{}) (tx *DB) { tx.Statement.Selects = append(tx.Statement.Selects, "*") } - tx = tx.callbacks.Update().Execute(tx) + updateTx := tx.callbacks.Update().Execute(tx.Session(&Session{Initialized: true})) - if tx.Error == nil && tx.RowsAffected == 0 && !tx.DryRun && !selectedUpdate { - result := reflect.New(tx.Statement.Schema.ModelType).Interface() - if result := tx.Session(&Session{}).Limit(1).Find(result); result.RowsAffected == 0 { - return tx.Create(value) - } + if updateTx.Error == nil && updateTx.RowsAffected == 0 && !updateTx.DryRun && !selectedUpdate { + return tx.Create(value) } + + return updateTx } return diff --git a/go.mod b/go.mod index 03f84379..85e4242a 100644 --- a/go.mod +++ b/go.mod @@ -4,5 +4,5 @@ go 1.16 require ( github.com/jinzhu/inflection v1.0.0 - github.com/jinzhu/now v1.1.4 + github.com/jinzhu/now v1.1.5 ) diff --git a/go.sum b/go.sum index 50fbba2f..fb4240eb 100644 --- a/go.sum +++ b/go.sum @@ -2,3 +2,5 @@ github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jinzhu/now v1.1.4 h1:tHnRBy1i5F2Dh8BAFxqFzxKqqvezXrL2OW1TnX+Mlas= github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= +github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= +github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= diff --git a/tests/go.mod b/tests/go.mod index 69d6cf87..b2d5ca97 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -4,12 +4,13 @@ go 1.16 require ( github.com/google/uuid v1.3.0 + github.com/jackc/pgx/v5 v5.3.1 // indirect github.com/jinzhu/now v1.1.5 github.com/lib/pq v1.10.7 github.com/mattn/go-sqlite3 v1.14.16 // indirect - golang.org/x/crypto v0.5.0 // indirect - gorm.io/driver/mysql v1.4.6 - gorm.io/driver/postgres v1.4.6 + github.com/microsoft/go-mssqldb v0.20.0 // indirect + gorm.io/driver/mysql v1.4.7 + gorm.io/driver/postgres v1.4.8 gorm.io/driver/sqlite v1.4.4 gorm.io/driver/sqlserver v1.4.2 gorm.io/gorm v1.24.5 From 85eaf9eeda11e4c4c9aa24bf660325e364ca6e6b Mon Sep 17 00:00:00 2001 From: Saeid Kanishka Date: Mon, 6 Mar 2023 07:03:31 +0100 Subject: [PATCH 123/231] feat: Unique Constraint Violation error translator for different drivers (#6004) * feat: duplicated key error translator for different drivers * test: removed the dependency * test: fixed broken tests * refactor: added ErrorTransltor interface * style: applied styler --------- Co-authored-by: Saeid Saeidee --- errors.go | 2 ++ gorm.go | 4 ++++ interfaces.go | 4 ++++ tests/error_translator_test.go | 19 +++++++++++++++++++ utils/tests/dummy_dialecter.go | 8 +++++++- 5 files changed, 36 insertions(+), 1 deletion(-) create mode 100644 tests/error_translator_test.go diff --git a/errors.go b/errors.go index 5bfd0f82..57e3fc5e 100644 --- a/errors.go +++ b/errors.go @@ -45,4 +45,6 @@ var ( ErrInvalidValueOfLength = errors.New("invalid association values, length doesn't match") // ErrPreloadNotAllowed preload is not allowed when count is used ErrPreloadNotAllowed = errors.New("preload is not allowed when count is used") + // ErrDuplicatedKey occurs when there is a unique key constraint violation + ErrDuplicatedKey = errors.New("duplicated key not allowed") ) diff --git a/gorm.go b/gorm.go index 37595ddd..b5d98196 100644 --- a/gorm.go +++ b/gorm.go @@ -347,6 +347,10 @@ func (db *DB) Callback() *callbacks { // AddError add error to db func (db *DB) AddError(err error) error { + if errTranslator, ok := db.Dialector.(ErrorTranslator); ok { + err = errTranslator.Translate(err) + } + if db.Error == nil { db.Error = err } else if err != nil { diff --git a/interfaces.go b/interfaces.go index cf9e07b9..3bcc3d57 100644 --- a/interfaces.go +++ b/interfaces.go @@ -86,3 +86,7 @@ type Rows interface { Err() error Close() error } + +type ErrorTranslator interface { + Translate(err error) error +} diff --git a/tests/error_translator_test.go b/tests/error_translator_test.go new file mode 100644 index 00000000..2e472e34 --- /dev/null +++ b/tests/error_translator_test.go @@ -0,0 +1,19 @@ +package tests_test + +import ( + "errors" + "testing" + + "gorm.io/gorm" + "gorm.io/gorm/utils/tests" +) + +func TestDialectorWithErrorTranslatorSupport(t *testing.T) { + translatedErr := errors.New("translated error") + db, _ := gorm.Open(tests.DummyDialector{TranslatedErr: translatedErr}) + + err := db.AddError(errors.New("some random error")) + if !errors.Is(err, translatedErr) { + t.Fatalf("expected err: %v got err: %v", translatedErr, err) + } +} diff --git a/utils/tests/dummy_dialecter.go b/utils/tests/dummy_dialecter.go index c89b944a..a2d9c33d 100644 --- a/utils/tests/dummy_dialecter.go +++ b/utils/tests/dummy_dialecter.go @@ -8,7 +8,9 @@ import ( "gorm.io/gorm/schema" ) -type DummyDialector struct{} +type DummyDialector struct { + TranslatedErr error +} func (DummyDialector) Name() string { return "dummy" @@ -92,3 +94,7 @@ func (DummyDialector) Explain(sql string, vars ...interface{}) string { func (DummyDialector) DataTypeOf(*schema.Field) string { return "" } + +func (d DummyDialector) Translate(err error) error { + return d.TranslatedErr +} From e9f25c73ee6afd560880db4537edf9ca24f2bc4a Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Fri, 10 Mar 2023 16:35:26 +0800 Subject: [PATCH 124/231] fix: on confilct with default null (#6129) * fix: on confilct with default null * Update create.go --------- Co-authored-by: Jinzhu --- callbacks/create.go | 4 +++- tests/create_test.go | 36 ++++++++++++++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 1 deletion(-) diff --git a/callbacks/create.go b/callbacks/create.go index 0fe1dc93..f0b78139 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -3,6 +3,7 @@ package callbacks import ( "fmt" "reflect" + "strings" "gorm.io/gorm" "gorm.io/gorm/clause" @@ -302,7 +303,8 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { for _, column := range values.Columns { if field := stmt.Schema.LookUpField(column.Name); field != nil { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { - if !field.PrimaryKey && (!field.HasDefaultValue || field.DefaultValueInterface != nil) && field.AutoCreateTime == 0 { + if !field.PrimaryKey && (!field.HasDefaultValue || field.DefaultValueInterface != nil || + strings.EqualFold(field.DefaultValue, "NULL")) && field.AutoCreateTime == 0 { if field.AutoUpdateTime > 0 { assignment := clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: curTime} switch field.AutoUpdateTime { diff --git a/tests/create_test.go b/tests/create_test.go index 274a7f48..e8da91ff 100644 --- a/tests/create_test.go +++ b/tests/create_test.go @@ -547,3 +547,39 @@ func TestFirstOrCreateRowsAffected(t *testing.T) { t.Fatalf("first or create rows affect err:%v rows:%d", res.Error, res.RowsAffected) } } + +func TestCreateOnConfilctWithDefalutNull(t *testing.T) { + type OnConfilctUser struct { + ID string + Name string `gorm:"default:null"` + Email string + Mobile string `gorm:"default:'133xxxx'"` + } + + err := DB.Migrator().DropTable(&OnConfilctUser{}) + AssertEqual(t, err, nil) + err = DB.AutoMigrate(&OnConfilctUser{}) + AssertEqual(t, err, nil) + + u := OnConfilctUser{ + ID: "on-confilct-user-id", + Name: "on-confilct-user-name", + Email: "on-confilct-user-email", + Mobile: "on-confilct-user-mobile", + } + err = DB.Create(&u).Error + AssertEqual(t, err, nil) + + u.Name = "on-confilct-user-name-2" + u.Email = "on-confilct-user-email-2" + u.Mobile = "" + err = DB.Clauses(clause.OnConflict{UpdateAll: true}).Create(&u).Error + AssertEqual(t, err, nil) + + var u2 OnConfilctUser + err = DB.Where("id = ?", u.ID).First(&u2).Error + AssertEqual(t, err, nil) + AssertEqual(t, u2.Name, "on-confilct-user-name-2") + AssertEqual(t, u2.Email, "on-confilct-user-email-2") + AssertEqual(t, u2.Mobile, "133xxxx") +} From 1643a36260cbc5bcc6e4abab6489325b64c57e7a Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 10 Mar 2023 10:48:14 +0800 Subject: [PATCH 125/231] Fix possible concurrency problem for serializer --- schema/field.go | 13 ++++++++++--- tests/go.mod | 3 ++- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/schema/field.go b/schema/field.go index 59151878..00beb067 100644 --- a/schema/field.go +++ b/schema/field.go @@ -916,6 +916,8 @@ func (field *Field) setupValuerAndSetter() { sameElemType = field.FieldType == reflect.ValueOf(field.Serializer).Type().Elem() } + serializerValue := reflect.Indirect(reflect.ValueOf(field.Serializer)) + serializerType := serializerValue.Type() field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) { if s, ok := v.(*serializer); ok { if s.fieldValue != nil { @@ -923,11 +925,12 @@ func (field *Field) setupValuerAndSetter() { } else if err = s.Serializer.Scan(ctx, field, value, s.value); err == nil { if sameElemType { field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(s.Serializer).Elem()) - s.Serializer = reflect.New(reflect.Indirect(reflect.ValueOf(field.Serializer)).Type()).Interface().(SerializerInterface) } else if sameType { field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(s.Serializer)) - s.Serializer = reflect.New(reflect.Indirect(reflect.ValueOf(field.Serializer)).Type()).Interface().(SerializerInterface) } + si := reflect.New(serializerType) + si.Elem().Set(serializerValue) + s.Serializer = si.Interface().(SerializerInterface) } } else { err = oldFieldSetter(ctx, value, v) @@ -939,11 +942,15 @@ func (field *Field) setupValuerAndSetter() { func (field *Field) setupNewValuePool() { if field.Serializer != nil { + serializerValue := reflect.Indirect(reflect.ValueOf(field.Serializer)) + serializerType := serializerValue.Type() field.NewValuePool = &sync.Pool{ New: func() interface{} { + si := reflect.New(serializerType) + si.Elem().Set(serializerValue) return &serializer{ Field: field, - Serializer: field.Serializer, + Serializer: si.Interface().(SerializerInterface), } }, } diff --git a/tests/go.mod b/tests/go.mod index b2d5ca97..e970c9f5 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -9,11 +9,12 @@ require ( github.com/lib/pq v1.10.7 github.com/mattn/go-sqlite3 v1.14.16 // indirect github.com/microsoft/go-mssqldb v0.20.0 // indirect + golang.org/x/crypto v0.7.0 // indirect gorm.io/driver/mysql v1.4.7 gorm.io/driver/postgres v1.4.8 gorm.io/driver/sqlite v1.4.4 gorm.io/driver/sqlserver v1.4.2 - gorm.io/gorm v1.24.5 + gorm.io/gorm v1.24.6 ) replace gorm.io/gorm => ../ From ed474152b16789d61e535df336af2526c016629c Mon Sep 17 00:00:00 2001 From: Truong Nguyen Date: Fri, 10 Mar 2023 17:50:03 +0900 Subject: [PATCH 126/231] Fix: Composite primary key with auto-increment value returns 0 after insert (#6127) * Fix #4930 workaround for databases that support auto-increment in composite primary key. * Add test for composite key with auto-increment. * schema.go: use field.AutoIncrement instead of field.TagSettings["AUTOINCREMENT"], add test to check autoincrement:false create_test.go: remove unused code: drop table CompositeKeyProduct --------- Co-authored-by: Jinzhu --- schema/schema.go | 14 ++++++++++++-- schema/schema_test.go | 41 +++++++++++++++++++++++++++++++++++++++++ tests/create_test.go | 29 +++++++++++++++++++++++++++++ 3 files changed, 82 insertions(+), 2 deletions(-) diff --git a/schema/schema.go b/schema/schema.go index b34383bd..17bdb25e 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -221,8 +221,18 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam } } - if schema.PrioritizedPrimaryField == nil && len(schema.PrimaryFields) == 1 { - schema.PrioritizedPrimaryField = schema.PrimaryFields[0] + if schema.PrioritizedPrimaryField == nil { + if len(schema.PrimaryFields) == 1 { + schema.PrioritizedPrimaryField = schema.PrimaryFields[0] + } else if len(schema.PrimaryFields) > 1 { + // If there are multiple primary keys, the AUTOINCREMENT field is prioritized + for _, field := range schema.PrimaryFields { + if field.AutoIncrement { + schema.PrioritizedPrimaryField = field + break + } + } + } } for _, field := range schema.PrimaryFields { diff --git a/schema/schema_test.go b/schema/schema_test.go index 8a752fb7..5bc0fb83 100644 --- a/schema/schema_test.go +++ b/schema/schema_test.go @@ -293,3 +293,44 @@ func TestEmbeddedStructForCustomizedNamingStrategy(t *testing.T) { }) } } + +func TestCompositePrimaryKeyWithAutoIncrement(t *testing.T) { + type Product struct { + ProductID uint `gorm:"primaryKey;autoIncrement"` + LanguageCode uint `gorm:"primaryKey"` + Code string + Name string + } + type ProductNonAutoIncrement struct { + ProductID uint `gorm:"primaryKey;autoIncrement:false"` + LanguageCode uint `gorm:"primaryKey"` + Code string + Name string + } + + product, err := schema.Parse(&Product{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("failed to parse product struct with composite primary key, got error %v", err) + } + + prioritizedPrimaryField := schema.Field{ + Name: "ProductID", DBName: "product_id", BindNames: []string{"ProductID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, HasDefaultValue: true, AutoIncrement: true, TagSettings: map[string]string{"PRIMARYKEY": "PRIMARYKEY", "AUTOINCREMENT": "AUTOINCREMENT"}, + } + + product.Fields = []*schema.Field{product.PrioritizedPrimaryField} + + checkSchemaField(t, product, &prioritizedPrimaryField, func(f *schema.Field) { + f.Creatable = true + f.Updatable = true + f.Readable = true + }) + + productNonAutoIncrement, err := schema.Parse(&ProductNonAutoIncrement{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("failed to parse productNonAutoIncrement struct with composite primary key, got error %v", err) + } + + if productNonAutoIncrement.PrioritizedPrimaryField != nil { + t.Fatalf("PrioritizedPrimaryField of non autoincrement composite key should be nil") + } +} diff --git a/tests/create_test.go b/tests/create_test.go index e8da91ff..75aa8cba 100644 --- a/tests/create_test.go +++ b/tests/create_test.go @@ -548,6 +548,35 @@ func TestFirstOrCreateRowsAffected(t *testing.T) { } } +func TestCreateWithAutoIncrementCompositeKey(t *testing.T) { + type CompositeKeyProduct struct { + ProductID int `gorm:"primaryKey;autoIncrement:true;"` // primary key + LanguageCode int `gorm:"primaryKey;"` // primary key + Code string + Name string + } + + if err := DB.AutoMigrate(&CompositeKeyProduct{}); err != nil { + t.Fatalf("failed to migrate, got error %v", err) + } + + prod := &CompositeKeyProduct{ + LanguageCode: 56, + Code: "Code56", + Name: "ProductName56", + } + if err := DB.Create(&prod).Error; err != nil { + t.Fatalf("failed to create, got error %v", err) + } + + newProd := &CompositeKeyProduct{} + if err := DB.First(&newProd).Error; err != nil { + t.Fatalf("errors happened when query: %v", err) + } else { + AssertObjEqual(t, newProd, prod, "ProductID", "LanguageCode", "Code", "Name") + } +} + func TestCreateOnConfilctWithDefalutNull(t *testing.T) { type OnConfilctUser struct { ID string From 707d70a542e55f354341e9bd0b925976d24e0a82 Mon Sep 17 00:00:00 2001 From: Saeid Kanishka Date: Fri, 10 Mar 2023 09:51:27 +0100 Subject: [PATCH 127/231] refactor: translate error only when it is not nil (#6133) * refactor: translate error only when it is not nil * refactor: fix the error flow * refactor: update the error if checks * Update gorm.go --------- Co-authored-by: Saeid Saeidee Co-authored-by: Jinzhu --- gorm.go | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/gorm.go b/gorm.go index b5d98196..9a70c3d2 100644 --- a/gorm.go +++ b/gorm.go @@ -347,14 +347,16 @@ func (db *DB) Callback() *callbacks { // AddError add error to db func (db *DB) AddError(err error) error { - if errTranslator, ok := db.Dialector.(ErrorTranslator); ok { - err = errTranslator.Translate(err) - } + if err != nil { + if errTranslator, ok := db.Dialector.(ErrorTranslator); ok { + err = errTranslator.Translate(err) + } - if db.Error == nil { - db.Error = err - } else if err != nil { - db.Error = fmt.Errorf("%v; %w", db.Error, err) + if db.Error == nil { + db.Error = err + } else { + db.Error = fmt.Errorf("%v; %w", db.Error, err) + } } return db.Error } From b62192456fdeb98e67497c97fe3309e135d11fd1 Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Fri, 10 Mar 2023 17:04:54 +0800 Subject: [PATCH 128/231] fix: diff schema update assign value (#6096) --- callbacks/update.go | 10 +++++++++- tests/update_test.go | 13 +++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/callbacks/update.go b/callbacks/update.go index fe6f0994..4eb75788 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -245,11 +245,13 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { } default: updatingSchema := stmt.Schema + var isDiffSchema bool if !updatingValue.CanAddr() || stmt.Dest != stmt.Model { // different schema updatingStmt := &gorm.Statement{DB: stmt.DB} if err := updatingStmt.Parse(stmt.Dest); err == nil { updatingSchema = updatingStmt.Schema + isDiffSchema = true } } @@ -276,7 +278,13 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { if (ok || !isZero) && field.Updatable { set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value}) - assignValue(field, value) + assignField := field + if isDiffSchema { + if originField := stmt.Schema.LookUpField(dbName); originField != nil { + assignField = originField + } + } + assignValue(assignField, value) } } } else { diff --git a/tests/update_test.go b/tests/update_test.go index d7634580..b2da11c6 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -773,3 +773,16 @@ func TestUpdateReturning(t *testing.T) { t.Errorf("failed to return updated age column") } } + +func TestUpdateWithDiffSchema(t *testing.T) { + user := GetUser("update-diff-schema-1", Config{}) + DB.Create(&user) + + type UserTemp struct { + Name string + } + + err := DB.Model(&user).Updates(&UserTemp{Name: "update-diff-schema-2"}).Error + AssertEqual(t, err, nil) + AssertEqual(t, "update-diff-schema-2", user.Name) +} From 654b5f20066737fd7a7e62662b12bdf9cedba178 Mon Sep 17 00:00:00 2001 From: Jeffry Luqman Date: Fri, 10 Mar 2023 16:11:56 +0700 Subject: [PATCH 129/231] test: pgsql alter column from smallint or string to boolean (#6107) * test: pgsql alter column from smallint to boolean * test: pgsql alter column from string to boolean --- tests/migrate_test.go | 56 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 56 insertions(+) diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 11a0afda..69f86412 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -1542,3 +1542,59 @@ func TestMigrateView(t *testing.T) { t.Fatalf("Failed to drop view, got %v", err) } } + +func TestMigrateExistingBoolColumnPG(t *testing.T) { + if DB.Dialector.Name() != "postgres" { + return + } + + type ColumnStruct struct { + gorm.Model + Name string + StringBool string + SmallintBool int `gorm:"type:smallint"` + } + + type ColumnStruct2 struct { + gorm.Model + Name string + StringBool bool // change existing boolean column from string to boolean + SmallintBool bool // change existing boolean column from smallint or other to boolean + } + + DB.Migrator().DropTable(&ColumnStruct{}) + + if err := DB.AutoMigrate(&ColumnStruct{}); err != nil { + t.Errorf("Failed to migrate, got %v", err) + } + + if err := DB.Table("column_structs").AutoMigrate(&ColumnStruct2{}); err != nil { + t.Fatalf("no error should happened when auto migrate column, but got %v", err) + } + + if columnTypes, err := DB.Migrator().ColumnTypes(&ColumnStruct{}); err != nil { + t.Fatalf("no error should returns for ColumnTypes") + } else { + stmt := &gorm.Statement{DB: DB} + stmt.Parse(&ColumnStruct2{}) + + for _, columnType := range columnTypes { + switch columnType.Name() { + case "id": + if v, ok := columnType.PrimaryKey(); !ok || !v { + t.Fatalf("column id primary key should be correct, name: %v, column: %#v", columnType.Name(), columnType) + } + case "string_bool": + dataType := DB.Dialector.DataTypeOf(stmt.Schema.LookUpField(columnType.Name())) + if !strings.Contains(strings.ToUpper(dataType), strings.ToUpper(columnType.DatabaseTypeName())) { + t.Fatalf("column name type should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), columnType.DatabaseTypeName(), dataType, columnType) + } + case "smallint_bool": + dataType := DB.Dialector.DataTypeOf(stmt.Schema.LookUpField(columnType.Name())) + if !strings.Contains(strings.ToUpper(dataType), strings.ToUpper(columnType.DatabaseTypeName())) { + t.Fatalf("column name type should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), columnType.DatabaseTypeName(), dataType, columnType) + } + } + } + } +} From 8bf1f269cf752cf0a89f086f1a71d29aac75c14c Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Fri, 10 Mar 2023 17:21:56 +0800 Subject: [PATCH 130/231] feat: support nested join (#6067) * feat: support nested join * fix: empty rel value --- callbacks/query.go | 186 +++++++++++++++++++++++++++---------------- scan.go | 68 +++++++++++----- tests/joins_test.go | 63 +++++++++++++++ utils/tests/utils.go | 10 ++- utils/utils.go | 17 ++++ 5 files changed, 255 insertions(+), 89 deletions(-) diff --git a/callbacks/query.go b/callbacks/query.go index 9a6d4f4a..c87f17bc 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -8,6 +8,8 @@ import ( "gorm.io/gorm" "gorm.io/gorm/clause" + "gorm.io/gorm/schema" + "gorm.io/gorm/utils" ) func Query(db *gorm.DB) { @@ -109,86 +111,136 @@ func BuildQuerySQL(db *gorm.DB) { } } + specifiedRelationsName := make(map[string]interface{}) for _, join := range db.Statement.Joins { - if db.Statement.Schema == nil { - fromClause.Joins = append(fromClause.Joins, clause.Join{ - Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds}, - }) - } else if relation, ok := db.Statement.Schema.Relationships.Relations[join.Name]; ok { - tableAliasName := relation.Name - - columnStmt := gorm.Statement{ - Table: tableAliasName, DB: db, Schema: relation.FieldSchema, - Selects: join.Selects, Omits: join.Omits, - } - - selectColumns, restricted := columnStmt.SelectAndOmitColumns(false, false) - for _, s := range relation.FieldSchema.DBNames { - if v, ok := selectColumns[s]; (ok && v) || (!ok && !restricted) { - clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ - Table: tableAliasName, - Name: s, - Alias: tableAliasName + "__" + s, - }) - } - } - - exprs := make([]clause.Expression, len(relation.References)) - for idx, ref := range relation.References { - if ref.OwnPrimaryKey { - exprs[idx] = clause.Eq{ - Column: clause.Column{Table: clause.CurrentTable, Name: ref.PrimaryKey.DBName}, - Value: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName}, + if db.Statement.Schema != nil { + var isRelations bool // is relations or raw sql + var relations []*schema.Relationship + relation, ok := db.Statement.Schema.Relationships.Relations[join.Name] + if ok { + isRelations = true + relations = append(relations, relation) + } else { + // handle nested join like "Manager.Company" + nestedJoinNames := strings.Split(join.Name, ".") + if len(nestedJoinNames) > 1 { + isNestedJoin := true + gussNestedRelations := make([]*schema.Relationship, 0, len(nestedJoinNames)) + currentRelations := db.Statement.Schema.Relationships.Relations + for _, relname := range nestedJoinNames { + // incomplete match, only treated as raw sql + if relation, ok = currentRelations[relname]; ok { + gussNestedRelations = append(gussNestedRelations, relation) + currentRelations = relation.FieldSchema.Relationships.Relations + } else { + isNestedJoin = false + break + } } - } else { - if ref.PrimaryValue == "" { - exprs[idx] = clause.Eq{ - Column: clause.Column{Table: clause.CurrentTable, Name: ref.ForeignKey.DBName}, - Value: clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName}, - } - } else { - exprs[idx] = clause.Eq{ - Column: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName}, - Value: ref.PrimaryValue, - } + + if isNestedJoin { + isRelations = true + relations = gussNestedRelations } } } - { - onStmt := gorm.Statement{Table: tableAliasName, DB: db, Clauses: map[string]clause.Clause{}} - for _, c := range relation.FieldSchema.QueryClauses { - onStmt.AddClause(c) - } + if isRelations { + genJoinClause := func(joinType clause.JoinType, parentTableName string, relation *schema.Relationship) clause.Join { + tableAliasName := relation.Name + if parentTableName != clause.CurrentTable { + tableAliasName = utils.NestedRelationName(parentTableName, tableAliasName) + } - if join.On != nil { - onStmt.AddClause(join.On) - } + columnStmt := gorm.Statement{ + Table: tableAliasName, DB: db, Schema: relation.FieldSchema, + Selects: join.Selects, Omits: join.Omits, + } - if cs, ok := onStmt.Clauses["WHERE"]; ok { - if where, ok := cs.Expression.(clause.Where); ok { - where.Build(&onStmt) + selectColumns, restricted := columnStmt.SelectAndOmitColumns(false, false) + for _, s := range relation.FieldSchema.DBNames { + if v, ok := selectColumns[s]; (ok && v) || (!ok && !restricted) { + clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ + Table: tableAliasName, + Name: s, + Alias: utils.NestedRelationName(tableAliasName, s), + }) + } + } - if onSQL := onStmt.SQL.String(); onSQL != "" { - vars := onStmt.Vars - for idx, v := range vars { - bindvar := strings.Builder{} - onStmt.Vars = vars[0 : idx+1] - db.Dialector.BindVarTo(&bindvar, &onStmt, v) - onSQL = strings.Replace(onSQL, bindvar.String(), "?", 1) + exprs := make([]clause.Expression, len(relation.References)) + for idx, ref := range relation.References { + if ref.OwnPrimaryKey { + exprs[idx] = clause.Eq{ + Column: clause.Column{Table: parentTableName, Name: ref.PrimaryKey.DBName}, + Value: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName}, + } + } else { + if ref.PrimaryValue == "" { + exprs[idx] = clause.Eq{ + Column: clause.Column{Table: parentTableName, Name: ref.ForeignKey.DBName}, + Value: clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName}, + } + } else { + exprs[idx] = clause.Eq{ + Column: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName}, + Value: ref.PrimaryValue, + } } - - exprs = append(exprs, clause.Expr{SQL: onSQL, Vars: vars}) } } - } - } - fromClause.Joins = append(fromClause.Joins, clause.Join{ - Type: join.JoinType, - Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName}, - ON: clause.Where{Exprs: exprs}, - }) + { + onStmt := gorm.Statement{Table: tableAliasName, DB: db, Clauses: map[string]clause.Clause{}} + for _, c := range relation.FieldSchema.QueryClauses { + onStmt.AddClause(c) + } + + if join.On != nil { + onStmt.AddClause(join.On) + } + + if cs, ok := onStmt.Clauses["WHERE"]; ok { + if where, ok := cs.Expression.(clause.Where); ok { + where.Build(&onStmt) + + if onSQL := onStmt.SQL.String(); onSQL != "" { + vars := onStmt.Vars + for idx, v := range vars { + bindvar := strings.Builder{} + onStmt.Vars = vars[0 : idx+1] + db.Dialector.BindVarTo(&bindvar, &onStmt, v) + onSQL = strings.Replace(onSQL, bindvar.String(), "?", 1) + } + + exprs = append(exprs, clause.Expr{SQL: onSQL, Vars: vars}) + } + } + } + } + + return clause.Join{ + Type: joinType, + Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName}, + ON: clause.Where{Exprs: exprs}, + } + } + + parentTableName := clause.CurrentTable + for _, rel := range relations { + // joins table alias like "Manager, Company, Manager__Company" + nestedAlias := utils.NestedRelationName(parentTableName, rel.Name) + if _, ok := specifiedRelationsName[nestedAlias]; !ok { + fromClause.Joins = append(fromClause.Joins, genJoinClause(join.JoinType, parentTableName, rel)) + specifiedRelationsName[nestedAlias] = nil + } + parentTableName = rel.Name + } + } else { + fromClause.Joins = append(fromClause.Joins, clause.Join{ + Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds}, + }) + } } else { fromClause.Joins = append(fromClause.Joins, clause.Join{ Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds}, diff --git a/scan.go b/scan.go index 12a77862..736db4d3 100644 --- a/scan.go +++ b/scan.go @@ -4,10 +4,10 @@ import ( "database/sql" "database/sql/driver" "reflect" - "strings" "time" "gorm.io/gorm/schema" + "gorm.io/gorm/utils" ) // prepareValues prepare values slice @@ -50,7 +50,7 @@ func scanIntoMap(mapValue map[string]interface{}, values []interface{}, columns } } -func (db *DB) scanIntoStruct(rows Rows, reflectValue reflect.Value, values []interface{}, fields []*schema.Field, joinFields [][2]*schema.Field) { +func (db *DB) scanIntoStruct(rows Rows, reflectValue reflect.Value, values []interface{}, fields []*schema.Field, joinFields [][]*schema.Field) { for idx, field := range fields { if field != nil { values[idx] = field.NewValuePool.Get() @@ -65,28 +65,45 @@ func (db *DB) scanIntoStruct(rows Rows, reflectValue reflect.Value, values []int db.RowsAffected++ db.AddError(rows.Scan(values...)) - joinedSchemaMap := make(map[*schema.Field]interface{}) + joinedNestedSchemaMap := make(map[string]interface{}) for idx, field := range fields { if field == nil { continue } - if len(joinFields) == 0 || joinFields[idx][0] == nil { + if len(joinFields) == 0 || len(joinFields[idx]) == 0 { db.AddError(field.Set(db.Statement.Context, reflectValue, values[idx])) - } else { - joinSchema := joinFields[idx][0] - relValue := joinSchema.ReflectValueOf(db.Statement.Context, reflectValue) - if relValue.Kind() == reflect.Ptr { - if _, ok := joinedSchemaMap[joinSchema]; !ok { - if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() { - continue - } + } else { // joinFields count is larger than 2 when using join + var isNilPtrValue bool + var relValue reflect.Value + // does not contain raw dbname + nestedJoinSchemas := joinFields[idx][:len(joinFields[idx])-1] + // current reflect value + currentReflectValue := reflectValue + fullRels := make([]string, 0, len(nestedJoinSchemas)) + for _, joinSchema := range nestedJoinSchemas { + fullRels = append(fullRels, joinSchema.Name) + relValue = joinSchema.ReflectValueOf(db.Statement.Context, currentReflectValue) + if relValue.Kind() == reflect.Ptr { + fullRelsName := utils.JoinNestedRelationNames(fullRels) + // same nested structure + if _, ok := joinedNestedSchemaMap[fullRelsName]; !ok { + if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() { + isNilPtrValue = true + break + } - relValue.Set(reflect.New(relValue.Type().Elem())) - joinedSchemaMap[joinSchema] = nil + relValue.Set(reflect.New(relValue.Type().Elem())) + joinedNestedSchemaMap[fullRelsName] = nil + } } + currentReflectValue = relValue + } + + if !isNilPtrValue { // ignore if value is nil + f := joinFields[idx][len(joinFields[idx])-1] + db.AddError(f.Set(db.Statement.Context, relValue, values[idx])) } - db.AddError(joinFields[idx][1].Set(db.Statement.Context, relValue, values[idx])) } // release data to pool @@ -163,7 +180,7 @@ func Scan(rows Rows, db *DB, mode ScanMode) { default: var ( fields = make([]*schema.Field, len(columns)) - joinFields [][2]*schema.Field + joinFields [][]*schema.Field sch = db.Statement.Schema reflectValue = db.Statement.ReflectValue ) @@ -217,15 +234,26 @@ func Scan(rows Rows, db *DB, mode ScanMode) { } else { matchedFieldCount[column] = 1 } - } else if names := strings.Split(column, "__"); len(names) > 1 { + } else if names := utils.SplitNestedRelationName(column); len(names) > 1 { // has nested relation if rel, ok := sch.Relationships.Relations[names[0]]; ok { - if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { + subNameCount := len(names) + // nested relation fields + relFields := make([]*schema.Field, 0, subNameCount-1) + relFields = append(relFields, rel.Field) + for _, name := range names[1 : subNameCount-1] { + rel = rel.FieldSchema.Relationships.Relations[name] + relFields = append(relFields, rel.Field) + } + // lastest name is raw dbname + dbName := names[subNameCount-1] + if field := rel.FieldSchema.LookUpField(dbName); field != nil && field.Readable { fields[idx] = field if len(joinFields) == 0 { - joinFields = make([][2]*schema.Field, len(columns)) + joinFields = make([][]*schema.Field, len(columns)) } - joinFields[idx] = [2]*schema.Field{rel.Field, field} + relFields = append(relFields, field) + joinFields[idx] = relFields continue } } diff --git a/tests/joins_test.go b/tests/joins_test.go index 057ad333..e6715bbe 100644 --- a/tests/joins_test.go +++ b/tests/joins_test.go @@ -325,3 +325,66 @@ func TestJoinArgsWithDB(t *testing.T) { } AssertEqual(t, user4.NamedPet.Name, "") } + +func TestNestedJoins(t *testing.T) { + users := []User{ + { + Name: "nested-joins-1", + Manager: GetUser("nested-joins-manager-1", Config{Company: true, NamedPet: true}), + NamedPet: &Pet{Name: "nested-joins-namepet-1", Toy: Toy{Name: "nested-joins-namepet-toy-1"}}, + }, + { + Name: "nested-joins-2", + Manager: GetUser("nested-joins-manager-2", Config{Company: true, NamedPet: true}), + NamedPet: &Pet{Name: "nested-joins-namepet-2", Toy: Toy{Name: "nested-joins-namepet-toy-2"}}, + }, + } + + DB.Create(&users) + + var userIDs []uint + for _, user := range users { + userIDs = append(userIDs, user.ID) + } + + var users2 []User + if err := DB. + Joins("Manager"). + Joins("Manager.Company"). + Joins("Manager.NamedPet"). + Joins("NamedPet"). + Joins("NamedPet.Toy"). + Find(&users2, "users.id IN ?", userIDs).Error; err != nil { + t.Fatalf("Failed to load with joins, got error: %v", err) + } else if len(users2) != len(users) { + t.Fatalf("Failed to load join users, got: %v, expect: %v", len(users2), len(users)) + } + + sort.Slice(users2, func(i, j int) bool { + return users2[i].ID > users2[j].ID + }) + + sort.Slice(users, func(i, j int) bool { + return users[i].ID > users[j].ID + }) + + for idx, user := range users { + // user + CheckUser(t, user, users2[idx]) + if users2[idx].Manager == nil { + t.Fatalf("Failed to load Manager") + } + // manager + CheckUser(t, *user.Manager, *users2[idx].Manager) + // user pet + if users2[idx].NamedPet == nil { + t.Fatalf("Failed to load NamedPet") + } + CheckPet(t, *user.NamedPet, *users2[idx].NamedPet) + // manager pet + if users2[idx].Manager.NamedPet == nil { + t.Fatalf("Failed to load NamedPet") + } + CheckPet(t, *user.Manager.NamedPet, *users2[idx].Manager.NamedPet) + } +} diff --git a/utils/tests/utils.go b/utils/tests/utils.go index 661d727f..49d01f2e 100644 --- a/utils/tests/utils.go +++ b/utils/tests/utils.go @@ -13,8 +13,14 @@ import ( func AssertObjEqual(t *testing.T, r, e interface{}, names ...string) { for _, name := range names { - got := reflect.Indirect(reflect.ValueOf(r)).FieldByName(name).Interface() - expect := reflect.Indirect(reflect.ValueOf(e)).FieldByName(name).Interface() + rv := reflect.Indirect(reflect.ValueOf(r)) + ev := reflect.Indirect(reflect.ValueOf(e)) + if rv.IsValid() != ev.IsValid() { + t.Errorf("%v: expect: %+v, got %+v", utils.FileWithLineNum(), r, e) + return + } + got := rv.FieldByName(name).Interface() + expect := ev.FieldByName(name).Interface() t.Run(name, func(t *testing.T) { AssertEqual(t, got, expect) }) diff --git a/utils/utils.go b/utils/utils.go index e08533cd..ddbca60a 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -131,3 +131,20 @@ func ToString(value interface{}) string { } return "" } + +const nestedRelationSplit = "__" + +// NestedRelationName nested relationships like `Manager__Company` +func NestedRelationName(prefix, name string) string { + return prefix + nestedRelationSplit + name +} + +// SplitNestedRelationName Split nested relationships to `[]string{"Manager","Company"}` +func SplitNestedRelationName(name string) []string { + return strings.Split(name, nestedRelationSplit) +} + +// JoinNestedRelationNames nested relationships like `Manager__Company` +func JoinNestedRelationNames(relationNames []string) string { + return strings.Join(relationNames, nestedRelationSplit) +} From cc2d46e5be425300e064a39868cfdb333f24e4ac Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 10 Mar 2023 17:42:38 +0800 Subject: [PATCH 131/231] reuse name for savepoints from nested transaction, close #6060 --- finisher_api.go | 17 +++++++++++++++-- tests/go.mod | 4 ++-- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/finisher_api.go b/finisher_api.go index f16d4f43..e6fe4666 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -6,6 +6,8 @@ import ( "fmt" "reflect" "strings" + "sync" + "sync/atomic" "gorm.io/gorm/clause" "gorm.io/gorm/logger" @@ -608,6 +610,15 @@ func (db *DB) Connection(fc func(tx *DB) error) (err error) { return fc(tx) } +var ( + savepointIdx int64 + savepointNamePool = &sync.Pool{ + New: func() interface{} { + return fmt.Sprintf("gorm_%d", atomic.AddInt64(&savepointIdx, 1)) + }, + } +) + // Transaction start a transaction as a block, return error will rollback, otherwise to commit. Transaction executes an // arbitrary number of commands in fc within a transaction. On success the changes are committed; if an error occurs // they are rolled back. @@ -617,7 +628,9 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil { // nested transaction if !db.DisableNestedTransaction { - err = db.SavePoint(fmt.Sprintf("sp%p", fc)).Error + poolName := savepointNamePool.Get() + defer savepointNamePool.Put(poolName) + err = db.SavePoint(poolName.(string)).Error if err != nil { return } @@ -625,7 +638,7 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er defer func() { // Make sure to rollback when panic, Block error or Commit error if panicked || err != nil { - db.RollbackTo(fmt.Sprintf("sp%p", fc)) + db.RollbackTo(poolName.(string)) } }() } diff --git a/tests/go.mod b/tests/go.mod index e970c9f5..306a530e 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -11,10 +11,10 @@ require ( github.com/microsoft/go-mssqldb v0.20.0 // indirect golang.org/x/crypto v0.7.0 // indirect gorm.io/driver/mysql v1.4.7 - gorm.io/driver/postgres v1.4.8 + gorm.io/driver/postgres v1.5.0 gorm.io/driver/sqlite v1.4.4 gorm.io/driver/sqlserver v1.4.2 - gorm.io/gorm v1.24.6 + gorm.io/gorm v1.24.7-0.20230306060331-85eaf9eeda11 ) replace gorm.io/gorm => ../ From d2dd0ce4a73a368a77deb1d5494fe425246fb0e4 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 23 Mar 2023 11:18:02 +0800 Subject: [PATCH 132/231] chore(deps): bump actions/setup-go from 3 to 4 (#6165) Bumps [actions/setup-go](https://github.com/actions/setup-go) from 3 to 4. - [Release notes](https://github.com/actions/setup-go/releases) - [Commits](https://github.com/actions/setup-go/compare/v3...v4) --- updated-dependencies: - dependency-name: actions/setup-go dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/tests.yml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index cfe8e56f..bf225d42 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -22,7 +22,7 @@ jobs: steps: - name: Set up Go 1.x - uses: actions/setup-go@v3 + uses: actions/setup-go@v4 with: go-version: ${{ matrix.go }} @@ -65,7 +65,7 @@ jobs: steps: - name: Set up Go 1.x - uses: actions/setup-go@v3 + uses: actions/setup-go@v4 with: go-version: ${{ matrix.go }} @@ -109,7 +109,7 @@ jobs: steps: - name: Set up Go 1.x - uses: actions/setup-go@v3 + uses: actions/setup-go@v4 with: go-version: ${{ matrix.go }} @@ -152,7 +152,7 @@ jobs: steps: - name: Set up Go 1.x - uses: actions/setup-go@v3 + uses: actions/setup-go@v4 with: go-version: ${{ matrix.go }} @@ -184,7 +184,7 @@ jobs: version: ${{matrix.dbversion}} - name: Set up Go 1.x - uses: actions/setup-go@v3 + uses: actions/setup-go@v4 with: go-version: ${{ matrix.go }} From 0c7e575f19451921e1124d92847b6cf1a723a724 Mon Sep 17 00:00:00 2001 From: black-06 Date: Thu, 23 Mar 2023 11:18:57 +0800 Subject: [PATCH 133/231] save should be idempotent #6139 (#6149) --- finisher_api.go | 2 +- tests/update_test.go | 19 +++++++++++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/finisher_api.go b/finisher_api.go index e6fe4666..d647cf64 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -106,7 +106,7 @@ func (db *DB) Save(value interface{}) (tx *DB) { updateTx := tx.callbacks.Update().Execute(tx.Session(&Session{Initialized: true})) if updateTx.Error == nil && updateTx.RowsAffected == 0 && !updateTx.DryRun && !selectedUpdate { - return tx.Create(value) + return tx.Clauses(clause.OnConflict{UpdateAll: true}).Create(value) } return updateTx diff --git a/tests/update_test.go b/tests/update_test.go index b2da11c6..36ffa6a0 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -610,6 +610,25 @@ func TestUpdateFromSubQuery(t *testing.T) { } } +func TestIdempotentSave(t *testing.T) { + create := Company{ + Name: "company_idempotent", + } + DB.Create(&create) + + var company Company + if err := DB.Find(&company, "id = ?", create.ID).Error; err != nil { + t.Fatalf("failed to find created company, got err: %v", err) + } + + if err := DB.Save(&company).Error; err != nil || company.ID != create.ID { + t.Errorf("failed to save company, got err: %v", err) + } + if err := DB.Save(&company).Error; err != nil || company.ID != create.ID { + t.Errorf("failed to save company, got err: %v", err) + } +} + func TestSave(t *testing.T) { user := *GetUser("save", Config{}) DB.Create(&user) From 1a7ea98ac51af189177e382a7a083b11a2b9b3c2 Mon Sep 17 00:00:00 2001 From: black-06 Date: Thu, 23 Mar 2023 11:19:53 +0800 Subject: [PATCH 134/231] fix: count with group (#6157) (#6160) * fix: count with group (#6157) * add an easy-to-understand ut --- finisher_api.go | 2 +- tests/count_test.go | 30 ++++++++++++++++++++++++++++-- 2 files changed, 29 insertions(+), 3 deletions(-) diff --git a/finisher_api.go b/finisher_api.go index d647cf64..0e3c2876 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -491,7 +491,7 @@ func (db *DB) Count(count *int64) (tx *DB) { tx.Statement.Dest = count tx = tx.callbacks.Query().Execute(tx) - if tx.RowsAffected != 1 { + if _, ok := db.Statement.Clauses["GROUP BY"]; ok || tx.RowsAffected != 1 { *count = tx.RowsAffected } diff --git a/tests/count_test.go b/tests/count_test.go index 2199dc6d..b0dfb0b5 100644 --- a/tests/count_test.go +++ b/tests/count_test.go @@ -11,6 +11,32 @@ import ( . "gorm.io/gorm/utils/tests" ) +func TestCountWithGroup(t *testing.T) { + DB.Create([]Company{ + {Name: "company_count_group_a"}, + {Name: "company_count_group_a"}, + {Name: "company_count_group_a"}, + {Name: "company_count_group_b"}, + {Name: "company_count_group_c"}, + }) + + var count1 int64 + if err := DB.Model(&Company{}).Where("name = ?", "company_count_group_a").Group("name").Count(&count1).Error; err != nil { + t.Errorf(fmt.Sprintf("Count should work, but got err %v", err)) + } + if count1 != 1 { + t.Errorf("Count with group should be 1, but got count: %v", count1) + } + + var count2 int64 + if err := DB.Debug().Model(&Company{}).Where("name in ?", []string{"company_count_group_b", "company_count_group_c"}).Group("name").Count(&count2).Error; err != nil { + t.Errorf(fmt.Sprintf("Count should work, but got err %v", err)) + } + if count2 != 2 { + t.Errorf("Count with group should be 2, but got count: %v", count2) + } +} + func TestCount(t *testing.T) { var ( user1 = *GetUser("count-1", Config{}) @@ -141,8 +167,8 @@ func TestCount(t *testing.T) { } DB.Create(sameUsers) - if err := DB.Model(&User{}).Where("name = ?", "count-4").Group("name").Count(&count11).Error; err != nil || count11 != int64(len(sameUsers)) { - t.Fatalf("Count should be 3, but got count: %v err %v", count11, err) + if err := DB.Model(&User{}).Where("name = ?", "count-4").Group("name").Count(&count11).Error; err != nil || count11 != 1 { + t.Fatalf("Count should be 1, but got count: %v err %v", count11, err) } var count12 int64 From 5d1cdfef2e6c24e71518609e2f668a516abf7284 Mon Sep 17 00:00:00 2001 From: cyhone Date: Thu, 23 Mar 2023 14:02:35 +0800 Subject: [PATCH 135/231] avoid starting a transaction when performing only one insert operation in CreateInBatches function (#6174) --- finisher_api.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/finisher_api.go b/finisher_api.go index 0e3c2876..0e26f181 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -35,9 +35,10 @@ func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) { var rowsAffected int64 tx = db.getInstance() + // the reflection length judgment of the optimized value + reflectLen := reflectValue.Len() + callFc := func(tx *DB) error { - // the reflection length judgment of the optimized value - reflectLen := reflectValue.Len() for i := 0; i < reflectLen; i += batchSize { ends := i + batchSize if ends > reflectLen { @@ -55,7 +56,7 @@ func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) { return nil } - if tx.SkipDefaultTransaction { + if tx.SkipDefaultTransaction || reflectLen <= batchSize { tx.AddError(callFc(tx.Session(&Session{}))) } else { tx.AddError(tx.Transaction(callFc)) From b444011d094db7444f87f442c33860365f55770a Mon Sep 17 00:00:00 2001 From: Saeid Kanishka Date: Fri, 24 Mar 2023 03:07:05 +0100 Subject: [PATCH 136/231] refactor: translatorError flag added for backward compatibility (#6178) Co-authored-by: Saeid Saeidee --- gorm.go | 8 ++++++-- tests/error_translator_test.go | 12 +++++++++++- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/gorm.go b/gorm.go index 9a70c3d2..4402a2df 100644 --- a/gorm.go +++ b/gorm.go @@ -47,6 +47,8 @@ type Config struct { QueryFields bool // CreateBatchSize default create batch size CreateBatchSize int + // TranslateError enabling error translation + TranslateError bool // ClauseBuilders clause builder ClauseBuilders map[string]clause.ClauseBuilder @@ -348,8 +350,10 @@ func (db *DB) Callback() *callbacks { // AddError add error to db func (db *DB) AddError(err error) error { if err != nil { - if errTranslator, ok := db.Dialector.(ErrorTranslator); ok { - err = errTranslator.Translate(err) + if db.Config.TranslateError { + if errTranslator, ok := db.Dialector.(ErrorTranslator); ok { + err = errTranslator.Translate(err) + } } if db.Error == nil { diff --git a/tests/error_translator_test.go b/tests/error_translator_test.go index 2e472e34..ead26fce 100644 --- a/tests/error_translator_test.go +++ b/tests/error_translator_test.go @@ -9,10 +9,20 @@ import ( ) func TestDialectorWithErrorTranslatorSupport(t *testing.T) { + // it shouldn't translate error when the TranslateError flag is false translatedErr := errors.New("translated error") + untranslatedErr := errors.New("some random error") db, _ := gorm.Open(tests.DummyDialector{TranslatedErr: translatedErr}) - err := db.AddError(errors.New("some random error")) + err := db.AddError(untranslatedErr) + if errors.Is(err, translatedErr) { + t.Fatalf("expected err: %v got err: %v", translatedErr, err) + } + + // it should translate error when the TranslateError flag is true + db, _ = gorm.Open(tests.DummyDialector{TranslatedErr: translatedErr}, &gorm.Config{TranslateError: true}) + + err = db.AddError(untranslatedErr) if !errors.Is(err, translatedErr) { t.Fatalf("expected err: %v got err: %v", translatedErr, err) } From f0360dccbf699e3bc4fa32c0a4e29bd24b5c47f0 Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Tue, 11 Apr 2023 10:13:25 +0800 Subject: [PATCH 137/231] fix: embedded should be nil if not exists (#6219) --- schema/field.go | 10 ---------- tests/embedded_struct_test.go | 11 +++++++++++ 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/schema/field.go b/schema/field.go index 00beb067..15edab93 100644 --- a/schema/field.go +++ b/schema/field.go @@ -580,8 +580,6 @@ func (field *Field) setupValuerAndSetter() { case **bool: if data != nil && *data != nil { field.ReflectValueOf(ctx, value).SetBool(**data) - } else { - field.ReflectValueOf(ctx, value).SetBool(false) } case bool: field.ReflectValueOf(ctx, value).SetBool(data) @@ -601,8 +599,6 @@ func (field *Field) setupValuerAndSetter() { case **int64: if data != nil && *data != nil { field.ReflectValueOf(ctx, value).SetInt(**data) - } else { - field.ReflectValueOf(ctx, value).SetInt(0) } case int64: field.ReflectValueOf(ctx, value).SetInt(data) @@ -667,8 +663,6 @@ func (field *Field) setupValuerAndSetter() { case **uint64: if data != nil && *data != nil { field.ReflectValueOf(ctx, value).SetUint(**data) - } else { - field.ReflectValueOf(ctx, value).SetUint(0) } case uint64: field.ReflectValueOf(ctx, value).SetUint(data) @@ -721,8 +715,6 @@ func (field *Field) setupValuerAndSetter() { case **float64: if data != nil && *data != nil { field.ReflectValueOf(ctx, value).SetFloat(**data) - } else { - field.ReflectValueOf(ctx, value).SetFloat(0) } case float64: field.ReflectValueOf(ctx, value).SetFloat(data) @@ -767,8 +759,6 @@ func (field *Field) setupValuerAndSetter() { case **string: if data != nil && *data != nil { field.ReflectValueOf(ctx, value).SetString(**data) - } else { - field.ReflectValueOf(ctx, value).SetString("") } case string: field.ReflectValueOf(ctx, value).SetString(data) diff --git a/tests/embedded_struct_test.go b/tests/embedded_struct_test.go index 63ec53ee..0d240fd8 100644 --- a/tests/embedded_struct_test.go +++ b/tests/embedded_struct_test.go @@ -103,9 +103,16 @@ func TestEmbeddedPointerTypeStruct(t *testing.T) { URL string } + type Author struct { + ID string + Name string + Email string + } + type HNPost struct { *BasePost Upvotes int32 + *Author `gorm:"EmbeddedPrefix:user_"` // Embedded struct } DB.Migrator().DropTable(&HNPost{}) @@ -123,6 +130,10 @@ func TestEmbeddedPointerTypeStruct(t *testing.T) { if hnPost.Title != "embedded_pointer_type" { t.Errorf("Should find correct value for embedded pointer type") } + + if hnPost.Author != nil { + t.Errorf("Expected to get back a nil Author but got: %v", hnPost.Author) + } } type Content struct { From 59ca46db3ce53014f1e176ddbc744bfa10da917a Mon Sep 17 00:00:00 2001 From: hanwn <30523763+Hanwn@users.noreply.github.com> Date: Tue, 11 Apr 2023 10:25:47 +0800 Subject: [PATCH 138/231] fix: `limit(0).offset(0)` return all data (#6191) Co-authored-by: hanwang --- clause/limit.go | 2 +- clause/limit_test.go | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/clause/limit.go b/clause/limit.go index 3ede7385..abda0055 100644 --- a/clause/limit.go +++ b/clause/limit.go @@ -33,7 +33,7 @@ func (limit Limit) MergeClause(clause *Clause) { clause.Name = "" if v, ok := clause.Expression.(Limit); ok { - if (limit.Limit == nil || *limit.Limit == 0) && (v.Limit != nil && *v.Limit != 0) { + if (limit.Limit == nil || *limit.Limit == 0) && v.Limit != nil { limit.Limit = v.Limit } diff --git a/clause/limit_test.go b/clause/limit_test.go index 79065ab6..a9fd4e24 100644 --- a/clause/limit_test.go +++ b/clause/limit_test.go @@ -28,6 +28,10 @@ func TestLimit(t *testing.T) { []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit0}}, "SELECT * FROM `users` LIMIT 0", nil, }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit0}, clause.Limit{Offset: 0}}, + "SELECT * FROM `users` LIMIT 0", nil, + }, { []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}}, "SELECT * FROM `users` OFFSET 20", nil, From 1d9f4b0f5578b068210bdd3f31b57b6db92556f2 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 11 Apr 2023 10:27:05 +0800 Subject: [PATCH 139/231] chore(deps): bump actions/stale from 7 to 8 (#6190) Bumps [actions/stale](https://github.com/actions/stale) from 7 to 8. - [Release notes](https://github.com/actions/stale/releases) - [Changelog](https://github.com/actions/stale/blob/main/CHANGELOG.md) - [Commits](https://github.com/actions/stale/compare/v7...v8) --- updated-dependencies: - dependency-name: actions/stale dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/invalid_question.yml | 2 +- .github/workflows/missing_playground.yml | 2 +- .github/workflows/stale.yml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/invalid_question.yml b/.github/workflows/invalid_question.yml index 77b26abe..fbebfc12 100644 --- a/.github/workflows/invalid_question.yml +++ b/.github/workflows/invalid_question.yml @@ -16,7 +16,7 @@ jobs: ACTIONS_STEP_DEBUG: true steps: - name: Close Stale Issues - uses: actions/stale@v7 + uses: actions/stale@v8 with: repo-token: ${{ secrets.GITHUB_TOKEN }} stale-issue-message: "This issue has been marked as invalid question, please give more information by following the `Question` template, if you believe there is a bug of GORM, please create a pull request that could reproduce the issue on [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground), the issue will be closed in 30 days if no further activity occurs. most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨" diff --git a/.github/workflows/missing_playground.yml b/.github/workflows/missing_playground.yml index 1efa3611..b23a5bf9 100644 --- a/.github/workflows/missing_playground.yml +++ b/.github/workflows/missing_playground.yml @@ -16,7 +16,7 @@ jobs: ACTIONS_STEP_DEBUG: true steps: - name: Close Stale Issues - uses: actions/stale@v7 + uses: actions/stale@v8 with: repo-token: ${{ secrets.GITHUB_TOKEN }} stale-issue-message: "The issue has been automatically marked as stale as it missing playground pull request link, which is important to help others understand your issue effectively and make sure the issue hasn't been fixed on latest master, checkout [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground) for details. it will be closed in 30 days if no further activity occurs. if you are asking question, please use the `Question` template, most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨" diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index 43f2f730..c9752883 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -16,7 +16,7 @@ jobs: ACTIONS_STEP_DEBUG: true steps: - name: Close Stale Issues - uses: actions/stale@v7 + uses: actions/stale@v8 with: repo-token: ${{ secrets.GITHUB_TOKEN }} stale-issue-message: "This issue has been automatically marked as stale because it has been open 360 days with no activity. Remove stale label or comment or this will be closed in 180 days" From 05bb9d6106f43fbc115d5e4739fdd8b76a21d792 Mon Sep 17 00:00:00 2001 From: jessetang <1430482733@qq.com> Date: Tue, 11 Apr 2023 10:32:46 +0800 Subject: [PATCH 140/231] refactor(migrator): non-standard codes (#6180) --- migrator/index.go | 6 +++--- migrator/migrator.go | 28 +++++++++++++++------------- 2 files changed, 18 insertions(+), 16 deletions(-) diff --git a/migrator/index.go b/migrator/index.go index fe686e5a..8845da95 100644 --- a/migrator/index.go +++ b/migrator/index.go @@ -17,12 +17,12 @@ func (idx Index) Table() string { return idx.TableName } -// Name return the name of the index. +// Name return the name of the index. func (idx Index) Name() string { return idx.NameValue } -// Columns return the columns fo the index +// Columns return the columns of the index func (idx Index) Columns() []string { return idx.ColumnList } @@ -37,7 +37,7 @@ func (idx Index) Unique() (unique bool, ok bool) { return idx.UniqueValue.Bool, idx.UniqueValue.Valid } -// Option return the optional attribute fo the index +// Option return the optional attribute of the index func (idx Index) Option() string { return idx.OptionValue } diff --git a/migrator/migrator.go b/migrator/migrator.go index 389ce008..32c6a059 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -113,7 +113,7 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { return err } } else { - if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) { + if err := m.RunWithValue(value, func(stmt *gorm.Statement) error { columnTypes, err := queryTx.Migrator().ColumnTypes(value) if err != nil { return err @@ -123,7 +123,6 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { parseCheckConstraints = stmt.Schema.ParseCheckConstraints() ) for _, dbName := range stmt.Schema.DBNames { - field := stmt.Schema.FieldsByDBName[dbName] var foundColumn gorm.ColumnType for _, columnType := range columnTypes { @@ -135,12 +134,15 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { if foundColumn == nil { // not found, add column - if err := execTx.Migrator().AddColumn(value, dbName); err != nil { + if err = execTx.Migrator().AddColumn(value, dbName); err != nil { + return err + } + } else { + // found, smartly migrate + field := stmt.Schema.FieldsByDBName[dbName] + if err = execTx.Migrator().MigrateColumn(value, field, foundColumn); err != nil { return err } - } else if err := execTx.Migrator().MigrateColumn(value, field, foundColumn); err != nil { - // found, smart migrate - return err } } @@ -195,7 +197,7 @@ func (m Migrator) GetTables() (tableList []string, err error) { func (m Migrator) CreateTable(values ...interface{}) error { for _, value := range m.ReorderModels(values, false) { tx := m.DB.Session(&gorm.Session{}) - if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) { + if err := m.RunWithValue(value, func(stmt *gorm.Statement) (err error) { var ( createTableSQL = "CREATE TABLE ? (" values = []interface{}{m.CurrentTable(stmt)} @@ -214,7 +216,7 @@ func (m Migrator) CreateTable(values ...interface{}) error { if !hasPrimaryKeyInDataType && len(stmt.Schema.PrimaryFields) > 0 { createTableSQL += "PRIMARY KEY ?," - primaryKeys := []interface{}{} + primaryKeys := make([]interface{}, 0, len(stmt.Schema.PrimaryFields)) for _, field := range stmt.Schema.PrimaryFields { primaryKeys = append(primaryKeys, clause.Column{Name: field.DBName}) } @@ -225,8 +227,8 @@ func (m Migrator) CreateTable(values ...interface{}) error { for _, idx := range stmt.Schema.ParseIndexes() { if m.CreateIndexAfterCreateTable { defer func(value interface{}, name string) { - if errr == nil { - errr = tx.Migrator().CreateIndex(value, name) + if err == nil { + err = tx.Migrator().CreateIndex(value, name) } }(value, idx.Name) } else { @@ -276,8 +278,8 @@ func (m Migrator) CreateTable(values ...interface{}) error { createTableSQL += fmt.Sprint(tableOption) } - errr = tx.Exec(createTableSQL, values...).Error - return errr + err = tx.Exec(createTableSQL, values...).Error + return err }); err != nil { return err } @@ -498,7 +500,7 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy currentDefaultNotNull := field.HasDefaultValue && (field.DefaultValueInterface != nil || !strings.EqualFold(field.DefaultValue, "NULL")) dv, dvNotNull := columnType.DefaultValue() if dvNotNull && !currentDefaultNotNull { - // defalut value -> null + // default value -> null alterColumn = true } else if !dvNotNull && currentDefaultNotNull { // null -> default value From ccc3cb758a1ca4ccab61ec8572bf5ac1afcaeb5f Mon Sep 17 00:00:00 2001 From: bsmith-auth0 <89545504+bsmith-auth0@users.noreply.github.com> Date: Mon, 10 Apr 2023 20:06:13 -0700 Subject: [PATCH 141/231] fix: many2many association with duplicate belongs to elem (#6206) --- callbacks/associations.go | 27 +++++++++++++++++++------ tests/associations_many2many_test.go | 30 ++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+), 6 deletions(-) diff --git a/callbacks/associations.go b/callbacks/associations.go index 9d7c1412..f3cd464a 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -51,25 +51,40 @@ func SaveBeforeAssociations(create bool) func(db *gorm.DB) { } elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) + distinctElems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) + identityMap := map[string]bool{} for i := 0; i < rValLen; i++ { obj := db.Statement.ReflectValue.Index(i) if reflect.Indirect(obj).Kind() != reflect.Struct { break } - if _, zero := rel.Field.ValueOf(db.Statement.Context, obj); !zero { // check belongs to relation value rv := rel.Field.ReflectValueOf(db.Statement.Context, obj) // relation reflect value + if !isPtr { + rv = rv.Addr() + } objs = append(objs, obj) - if isPtr { - elems = reflect.Append(elems, rv) - } else { - elems = reflect.Append(elems, rv.Addr()) + elems = reflect.Append(elems, rv) + + relPrimaryValues := make([]interface{}, 0, len(rel.FieldSchema.PrimaryFields)) + for _, pf := range rel.FieldSchema.PrimaryFields { + if pfv, ok := pf.ValueOf(db.Statement.Context, rv); !ok { + relPrimaryValues = append(relPrimaryValues, pfv) + } + } + cacheKey := utils.ToStringKey(relPrimaryValues...) + if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] { + if cacheKey != "" { // has primary fields + identityMap[cacheKey] = true + } + + distinctElems = reflect.Append(distinctElems, rv) } } } if elems.Len() > 0 { - if saveAssociations(db, rel, elems, selectColumns, restricted, nil) == nil { + if saveAssociations(db, rel, distinctElems, selectColumns, restricted, nil) == nil { for i := 0; i < elems.Len(); i++ { setupReferences(objs[i], elems.Index(i)) } diff --git a/tests/associations_many2many_test.go b/tests/associations_many2many_test.go index 845c16af..b69d668a 100644 --- a/tests/associations_many2many_test.go +++ b/tests/associations_many2many_test.go @@ -393,3 +393,33 @@ func TestConcurrentMany2ManyAssociation(t *testing.T) { AssertEqual(t, err, nil) AssertAssociationCount(t, find, "Languages", int64(count), "after concurrent append") } + +func TestMany2ManyDuplicateBelongsToAssociation(t *testing.T) { + user1 := User{Name: "TestMany2ManyDuplicateBelongsToAssociation-1", Friends: []*User{ + {Name: "TestMany2ManyDuplicateBelongsToAssociation-friend-1", Company: Company{ + ID: 1, + Name: "Test-company-1", + }}, + }} + + user2 := User{Name: "TestMany2ManyDuplicateBelongsToAssociation-2", Friends: []*User{ + {Name: "TestMany2ManyDuplicateBelongsToAssociation-friend-2", Company: Company{ + ID: 1, + Name: "Test-company-1", + }}, + }} + users := []*User{&user1, &user2} + var err error + err = DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(users).Error + AssertEqual(t, nil, err) + + var findUser1 User + err = DB.Preload("Friends.Company").Where("id = ?", user1.ID).First(&findUser1).Error + AssertEqual(t, nil, err) + AssertEqual(t, user1, findUser1) + + var findUser2 User + err = DB.Preload("Friends.Company").Where("id = ?", user2.ID).First(&findUser2).Error + AssertEqual(t, nil, err) + AssertEqual(t, user2, findUser2) +} From 4b0da0e97a15979820790dd14023f47acc1848d0 Mon Sep 17 00:00:00 2001 From: black-06 Date: Tue, 11 Apr 2023 12:01:23 +0800 Subject: [PATCH 142/231] fix cond in scopes (#6152) * fix cond in scopes * replace quote * fix execute scopes --- callbacks.go | 6 +----- chainable_api.go | 30 ++++++++++++++++++++++++++ migrator.go | 6 +----- statement.go | 12 ++++++----- tests/scopes_test.go | 51 ++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 90 insertions(+), 15 deletions(-) diff --git a/callbacks.go b/callbacks.go index de979e45..ca6b6d50 100644 --- a/callbacks.go +++ b/callbacks.go @@ -75,11 +75,7 @@ func (cs *callbacks) Raw() *processor { func (p *processor) Execute(db *DB) *DB { // call scopes for len(db.Statement.scopes) > 0 { - scopes := db.Statement.scopes - db.Statement.scopes = nil - for _, scope := range scopes { - db = scope(db) - } + db = db.executeScopes() } var ( diff --git a/chainable_api.go b/chainable_api.go index a85235e0..19d405cc 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -366,6 +366,36 @@ func (db *DB) Scopes(funcs ...func(*DB) *DB) (tx *DB) { return tx } +func (db *DB) executeScopes() (tx *DB) { + tx = db.getInstance() + scopes := db.Statement.scopes + if len(scopes) == 0 { + return tx + } + tx.Statement.scopes = nil + + conditions := make([]clause.Interface, 0, 4) + if cs, ok := tx.Statement.Clauses["WHERE"]; ok && cs.Expression != nil { + conditions = append(conditions, cs.Expression.(clause.Interface)) + cs.Expression = nil + tx.Statement.Clauses["WHERE"] = cs + } + + for _, scope := range scopes { + tx = scope(tx) + if cs, ok := tx.Statement.Clauses["WHERE"]; ok && cs.Expression != nil { + conditions = append(conditions, cs.Expression.(clause.Interface)) + cs.Expression = nil + tx.Statement.Clauses["WHERE"] = cs + } + } + + for _, condition := range conditions { + tx.Statement.AddClause(condition) + } + return tx +} + // Preload preload associations with given conditions // // // get all users, and preload all non-cancelled orders diff --git a/migrator.go b/migrator.go index 9c7cc2c4..037afc35 100644 --- a/migrator.go +++ b/migrator.go @@ -13,11 +13,7 @@ func (db *DB) Migrator() Migrator { // apply scopes to migrator for len(tx.Statement.scopes) > 0 { - scopes := tx.Statement.scopes - tx.Statement.scopes = nil - for _, scope := range scopes { - tx = scope(tx) - } + tx = tx.executeScopes() } return tx.Dialector.Migrator(tx.Session(&Session{})) diff --git a/statement.go b/statement.go index bc959f0b..59c0b772 100644 --- a/statement.go +++ b/statement.go @@ -324,11 +324,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] case clause.Expression: conds = append(conds, v) case *DB: - for _, scope := range v.Statement.scopes { - v = scope(v) - } + v.executeScopes() - if cs, ok := v.Statement.Clauses["WHERE"]; ok { + if cs, ok := v.Statement.Clauses["WHERE"]; ok && cs.Expression != nil { if where, ok := cs.Expression.(clause.Where); ok { if len(where.Exprs) == 1 { if orConds, ok := where.Exprs[0].(clause.OrConditions); ok { @@ -336,9 +334,13 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] } } conds = append(conds, clause.And(where.Exprs...)) - } else if cs.Expression != nil { + } else { conds = append(conds, cs.Expression) } + if v.Statement == stmt { + cs.Expression = nil + stmt.Statement.Clauses["WHERE"] = cs + } } case map[interface{}]interface{}: for i, j := range v { diff --git a/tests/scopes_test.go b/tests/scopes_test.go index ab3807ea..52c6b37b 100644 --- a/tests/scopes_test.go +++ b/tests/scopes_test.go @@ -72,3 +72,54 @@ func TestScopes(t *testing.T) { t.Errorf("select max(id)") } } + +func TestComplexScopes(t *testing.T) { + tests := []struct { + name string + queryFn func(tx *gorm.DB) *gorm.DB + expected string + }{ + { + name: "depth_1", + queryFn: func(tx *gorm.DB) *gorm.DB { + return tx.Scopes( + func(d *gorm.DB) *gorm.DB { return d.Where("a = 1") }, + func(d *gorm.DB) *gorm.DB { return d.Where(d.Or("b = 2").Or("c = 3")) }, + ).Find(&Language{}) + }, + expected: `SELECT * FROM "languages" WHERE a = 1 AND (b = 2 OR c = 3)`, + }, { + name: "depth_1_pre_cond", + queryFn: func(tx *gorm.DB) *gorm.DB { + return tx.Where("z = 0").Scopes( + func(d *gorm.DB) *gorm.DB { return d.Where("a = 1") }, + func(d *gorm.DB) *gorm.DB { return d.Or(d.Where("b = 2").Or("c = 3")) }, + ).Find(&Language{}) + }, + expected: `SELECT * FROM "languages" WHERE z = 0 AND a = 1 OR (b = 2 OR c = 3)`, + }, { + name: "depth_2", + queryFn: func(tx *gorm.DB) *gorm.DB { + return tx.Scopes( + func(d *gorm.DB) *gorm.DB { return d.Model(&Language{}) }, + func(d *gorm.DB) *gorm.DB { + return d. + Or(d.Scopes( + func(d *gorm.DB) *gorm.DB { return d.Where("a = 1") }, + func(d *gorm.DB) *gorm.DB { return d.Where("b = 2") }, + )). + Or("c = 3") + }, + func(d *gorm.DB) *gorm.DB { return d.Where("d = 4") }, + ).Find(&Language{}) + }, + expected: `SELECT * FROM "languages" WHERE d = 4 OR c = 3 OR (a = 1 AND b = 2)`, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + assertEqualSQL(t, test.expected, DB.ToSQL(test.queryFn)) + }) + } +} From 828e22b17fd1ef614f433ee2b8e7be2a4e1c6b1d Mon Sep 17 00:00:00 2001 From: black-06 Date: Tue, 11 Apr 2023 13:10:38 +0800 Subject: [PATCH 143/231] feat: support embedded preload (#6137) * feat: support embedded preload * fix lint and test * fix test... --- callbacks/preload.go | 93 +++++++++++++++++++++++ callbacks/query.go | 31 +------- schema/field.go | 4 + schema/relationship.go | 52 ++++++++++++- schema/relationship_test.go | 126 ++++++++++++++++++++++++++++++++ schema/schema.go | 48 +++++++++--- schema/schema_helper_test.go | 31 ++++++++ tests/preload_test.go | 138 +++++++++++++++++++++++++++++++++++ 8 files changed, 485 insertions(+), 38 deletions(-) diff --git a/callbacks/preload.go b/callbacks/preload.go index ea2570ba..15669c84 100644 --- a/callbacks/preload.go +++ b/callbacks/preload.go @@ -3,6 +3,7 @@ package callbacks import ( "fmt" "reflect" + "strings" "gorm.io/gorm" "gorm.io/gorm/clause" @@ -10,6 +11,98 @@ import ( "gorm.io/gorm/utils" ) +// parsePreloadMap extracts nested preloads. e.g. +// +// // schema has a "k0" relation and a "k7.k8" embedded relation +// parsePreloadMap(schema, map[string][]interface{}{ +// clause.Associations: {"arg1"}, +// "k1": {"arg2"}, +// "k2.k3": {"arg3"}, +// "k4.k5.k6": {"arg4"}, +// }) +// // preloadMap is +// map[string]map[string][]interface{}{ +// "k0": {}, +// "k7": { +// "k8": {}, +// }, +// "k1": {}, +// "k2": { +// "k3": {"arg3"}, +// }, +// "k4": { +// "k5.k6": {"arg4"}, +// }, +// } +func parsePreloadMap(s *schema.Schema, preloads map[string][]interface{}) map[string]map[string][]interface{} { + preloadMap := map[string]map[string][]interface{}{} + setPreloadMap := func(name, value string, args []interface{}) { + if _, ok := preloadMap[name]; !ok { + preloadMap[name] = map[string][]interface{}{} + } + if value != "" { + preloadMap[name][value] = args + } + } + + for name, args := range preloads { + preloadFields := strings.Split(name, ".") + value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), ".") + if preloadFields[0] == clause.Associations { + for _, relation := range s.Relationships.Relations { + if relation.Schema == s { + setPreloadMap(relation.Name, value, args) + } + } + + for embedded, embeddedRelations := range s.Relationships.EmbeddedRelations { + for _, value := range embeddedValues(embeddedRelations) { + setPreloadMap(embedded, value, args) + } + } + } else { + setPreloadMap(preloadFields[0], value, args) + } + } + return preloadMap +} + +func embeddedValues(embeddedRelations *schema.Relationships) []string { + if embeddedRelations == nil { + return nil + } + names := make([]string, 0, len(embeddedRelations.Relations)+len(embeddedRelations.EmbeddedRelations)) + for _, relation := range embeddedRelations.Relations { + // skip first struct name + names = append(names, strings.Join(relation.Field.BindNames[1:], ".")) + } + for _, relations := range embeddedRelations.EmbeddedRelations { + names = append(names, embeddedValues(relations)...) + } + return names +} + +func preloadEmbedded(tx *gorm.DB, relationships *schema.Relationships, s *schema.Schema, preloads map[string][]interface{}, as []interface{}) error { + if relationships == nil { + return nil + } + preloadMap := parsePreloadMap(s, preloads) + for name := range preloadMap { + if embeddedRelations := relationships.EmbeddedRelations[name]; embeddedRelations != nil { + if err := preloadEmbedded(tx, embeddedRelations, s, preloadMap[name], as); err != nil { + return err + } + } else if rel := relationships.Relations[name]; rel != nil { + if err := preload(tx, rel, append(preloads[name], as), preloadMap[name]); err != nil { + return err + } + } else { + return fmt.Errorf("%s: %w (embedded) for schema %s", name, gorm.ErrUnsupportedRelation, s.Name) + } + } + return nil +} + func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) error { var ( reflectValue = tx.Statement.ReflectValue diff --git a/callbacks/query.go b/callbacks/query.go index c87f17bc..95db1f0a 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -267,32 +267,7 @@ func Preload(db *gorm.DB) { return } - preloadMap := map[string]map[string][]interface{}{} - for name := range db.Statement.Preloads { - preloadFields := strings.Split(name, ".") - if preloadFields[0] == clause.Associations { - for _, rel := range db.Statement.Schema.Relationships.Relations { - if rel.Schema == db.Statement.Schema { - if _, ok := preloadMap[rel.Name]; !ok { - preloadMap[rel.Name] = map[string][]interface{}{} - } - - if value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), "."); value != "" { - preloadMap[rel.Name][value] = db.Statement.Preloads[name] - } - } - } - } else { - if _, ok := preloadMap[preloadFields[0]]; !ok { - preloadMap[preloadFields[0]] = map[string][]interface{}{} - } - - if value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), "."); value != "" { - preloadMap[preloadFields[0]][value] = db.Statement.Preloads[name] - } - } - } - + preloadMap := parsePreloadMap(db.Statement.Schema, db.Statement.Preloads) preloadNames := make([]string, 0, len(preloadMap)) for key := range preloadMap { preloadNames = append(preloadNames, key) @@ -312,7 +287,9 @@ func Preload(db *gorm.DB) { preloadDB.Statement.Unscoped = db.Statement.Unscoped for _, name := range preloadNames { - if rel := preloadDB.Statement.Schema.Relationships.Relations[name]; rel != nil { + if relations := preloadDB.Statement.Schema.Relationships.EmbeddedRelations[name]; relations != nil { + db.AddError(preloadEmbedded(preloadDB.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks}), relations, db.Statement.Schema, preloadMap[name], db.Statement.Preloads[clause.Associations])) + } else if rel := preloadDB.Statement.Schema.Relationships.Relations[name]; rel != nil { db.AddError(preload(preloadDB.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks}), rel, append(db.Statement.Preloads[name], db.Statement.Preloads[clause.Associations]...), preloadMap[name])) } else { db.AddError(fmt.Errorf("%s: %w for schema %s", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name)) diff --git a/schema/field.go b/schema/field.go index 15edab93..b5103d53 100644 --- a/schema/field.go +++ b/schema/field.go @@ -89,6 +89,10 @@ type Field struct { NewValuePool FieldNewValuePool } +func (field *Field) BindName() string { + return strings.Join(field.BindNames, ".") +} + // ParseField parses reflect.StructField to Field func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { var ( diff --git a/schema/relationship.go b/schema/relationship.go index b33b94a7..e03dcc52 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -27,6 +27,8 @@ type Relationships struct { HasMany []*Relationship Many2Many []*Relationship Relations map[string]*Relationship + + EmbeddedRelations map[string]*Relationships } type Relationship struct { @@ -106,7 +108,7 @@ func (schema *Schema) parseRelation(field *Field) *Relationship { } if schema.err == nil { - schema.Relationships.Relations[relation.Name] = relation + schema.setRelation(relation) switch relation.Type { case HasOne: schema.Relationships.HasOne = append(schema.Relationships.HasOne, relation) @@ -122,6 +124,39 @@ func (schema *Schema) parseRelation(field *Field) *Relationship { return relation } +func (schema *Schema) setRelation(relation *Relationship) { + // set non-embedded relation + if rel := schema.Relationships.Relations[relation.Name]; rel != nil { + if len(rel.Field.BindNames) > 1 { + schema.Relationships.Relations[relation.Name] = relation + } + } else { + schema.Relationships.Relations[relation.Name] = relation + } + + // set embedded relation + if len(relation.Field.BindNames) <= 1 { + return + } + relationships := &schema.Relationships + for i, name := range relation.Field.BindNames { + if i < len(relation.Field.BindNames)-1 { + if relationships.EmbeddedRelations == nil { + relationships.EmbeddedRelations = map[string]*Relationships{} + } + if r := relationships.EmbeddedRelations[name]; r == nil { + relationships.EmbeddedRelations[name] = &Relationships{} + } + relationships = relationships.EmbeddedRelations[name] + } else { + if relationships.Relations == nil { + relationships.Relations = map[string]*Relationship{} + } + relationships.Relations[relation.Name] = relation + } + } +} + // User has many Toys, its `Polymorphic` is `Owner`, Pet has one Toy, its `Polymorphic` is `Owner` // // type User struct { @@ -166,6 +201,11 @@ func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Fi } } + if primaryKeyField == nil { + schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing primaryKey field", relation.FieldSchema, schema, field.Name) + return + } + // use same data type for foreign keys if copyableDataType(primaryKeyField.DataType) { relation.Polymorphic.PolymorphicID.DataType = primaryKeyField.DataType @@ -443,6 +483,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu primaryFields = primarySchema.PrimaryFields } + primaryFieldLoop: for _, primaryField := range primaryFields { lookUpName := primarySchemaName + primaryField.Name if gl == guessBelongs { @@ -454,11 +495,18 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu lookUpNames = append(lookUpNames, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID", strings.TrimSuffix(lookUpName, primaryField.Name)+"Id", schema.namer.ColumnName(foreignSchema.Table, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID")) } + for _, name := range lookUpNames { + if f := foreignSchema.LookUpFieldByBindName(field.BindNames, name); f != nil { + foreignFields = append(foreignFields, f) + primaryFields = append(primaryFields, primaryField) + continue primaryFieldLoop + } + } for _, name := range lookUpNames { if f := foreignSchema.LookUpField(name); f != nil { foreignFields = append(foreignFields, f) primaryFields = append(primaryFields, primaryField) - break + continue primaryFieldLoop } } } diff --git a/schema/relationship_test.go b/schema/relationship_test.go index 85c45589..732f6f75 100644 --- a/schema/relationship_test.go +++ b/schema/relationship_test.go @@ -518,6 +518,132 @@ func TestEmbeddedRelation(t *testing.T) { } } +func TestEmbeddedHas(t *testing.T) { + type Toy struct { + ID int + Name string + OwnerID int + OwnerType string + } + type User struct { + ID int + Cat struct { + Name string + Toy Toy `gorm:"polymorphic:Owner;"` + Toys []Toy `gorm:"polymorphic:Owner;"` + } `gorm:"embedded;embeddedPrefix:cat_"` + Dog struct { + ID int + Name string + UserID int + Toy Toy `gorm:"polymorphic:Owner;"` + Toys []Toy `gorm:"polymorphic:Owner;"` + } + Toys []Toy `gorm:"polymorphic:Owner;"` + } + + s, err := schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("Failed to parse schema, got error %v", err) + } + + checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{ + "Cat": { + Relations: map[string]Relation{ + "Toy": { + Name: "Toy", + Type: schema.HasOne, + Schema: "User", + FieldSchema: "Toy", + Polymorphic: Polymorphic{ID: "OwnerID", Type: "OwnerType", Value: "users"}, + References: []Reference{ + {ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"}, + {ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"}, + }, + }, + "Toys": { + Name: "Toys", + Type: schema.HasMany, + Schema: "User", + FieldSchema: "Toy", + Polymorphic: Polymorphic{ID: "OwnerID", Type: "OwnerType", Value: "users"}, + References: []Reference{ + {ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"}, + {ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"}, + }, + }, + }, + }, + }) +} + +func TestEmbeddedBelongsTo(t *testing.T) { + type Country struct { + ID int `gorm:"primaryKey"` + Name string + } + type Address struct { + CountryID int + Country Country + } + type NestedAddress struct { + Address + } + type Org struct { + ID int + PostalAddress Address `gorm:"embedded;embeddedPrefix:postal_address_"` + VisitingAddress Address `gorm:"embedded;embeddedPrefix:visiting_address_"` + AddressID int + Address struct { + ID int + Address + } + NestedAddress *NestedAddress `gorm:"embedded;embeddedPrefix:nested_address_"` + } + + s, err := schema.Parse(&Org{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Errorf("Failed to parse schema, got error %v", err) + } + + checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{ + "PostalAddress": { + Relations: map[string]Relation{ + "Country": { + Name: "Country", Type: schema.BelongsTo, Schema: "Org", FieldSchema: "Country", + References: []Reference{ + {PrimaryKey: "ID", PrimarySchema: "Country", ForeignKey: "CountryID", ForeignSchema: "Org"}, + }, + }, + }, + }, + "VisitingAddress": { + Relations: map[string]Relation{ + "Country": { + Name: "Country", Type: schema.BelongsTo, Schema: "Org", FieldSchema: "Country", + References: []Reference{ + {PrimaryKey: "ID", PrimarySchema: "Country", ForeignKey: "CountryID", ForeignSchema: "Org"}, + }, + }, + }, + }, + "NestedAddress": { + EmbeddedRelations: map[string]EmbeddedRelations{ + "Address": { + Relations: map[string]Relation{ + "Country": { + Name: "Country", Type: schema.BelongsTo, Schema: "Org", FieldSchema: "Country", + References: []Reference{ + {PrimaryKey: "ID", PrimarySchema: "Country", ForeignKey: "CountryID", ForeignSchema: "Org"}, + }, + }, + }, + }, + }, + }, + }) +} + func TestVariableRelation(t *testing.T) { var result struct { User diff --git a/schema/schema.go b/schema/schema.go index 17bdb25e..e13a5ed1 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -6,6 +6,7 @@ import ( "fmt" "go/ast" "reflect" + "strings" "sync" "gorm.io/gorm/clause" @@ -25,6 +26,7 @@ type Schema struct { PrimaryFieldDBNames []string Fields []*Field FieldsByName map[string]*Field + FieldsByBindName map[string]*Field // embedded fields is 'Embed.Field' FieldsByDBName map[string]*Field FieldsWithDefaultDBValue []*Field // fields with default value assigned by database Relationships Relationships @@ -67,6 +69,27 @@ func (schema Schema) LookUpField(name string) *Field { return nil } +// LookUpFieldByBindName looks for the closest field in the embedded struct. +// +// type Struct struct { +// Embedded struct { +// ID string // is selected by LookUpFieldByBindName([]string{"Embedded", "ID"}, "ID") +// } +// ID string // is selected by LookUpFieldByBindName([]string{"ID"}, "ID") +// } +func (schema Schema) LookUpFieldByBindName(bindNames []string, name string) *Field { + if len(bindNames) == 0 { + return nil + } + for i := len(bindNames) - 1; i >= 0; i-- { + find := strings.Join(bindNames[:i], ".") + "." + name + if field, ok := schema.FieldsByBindName[find]; ok { + return field + } + } + return nil +} + type Tabler interface { TableName() string } @@ -140,15 +163,16 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam } schema := &Schema{ - Name: modelType.Name(), - ModelType: modelType, - Table: tableName, - FieldsByName: map[string]*Field{}, - FieldsByDBName: map[string]*Field{}, - Relationships: Relationships{Relations: map[string]*Relationship{}}, - cacheStore: cacheStore, - namer: namer, - initialized: make(chan struct{}), + Name: modelType.Name(), + ModelType: modelType, + Table: tableName, + FieldsByName: map[string]*Field{}, + FieldsByBindName: map[string]*Field{}, + FieldsByDBName: map[string]*Field{}, + Relationships: Relationships{Relations: map[string]*Relationship{}}, + cacheStore: cacheStore, + namer: namer, + initialized: make(chan struct{}), } // When the schema initialization is completed, the channel will be closed defer close(schema.initialized) @@ -176,6 +200,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam field.DBName = namer.ColumnName(schema.Table, field.Name) } + bindName := field.BindName() if field.DBName != "" { // nonexistence or shortest path or first appear prioritized if has permission if v, ok := schema.FieldsByDBName[field.DBName]; !ok || ((field.Creatable || field.Updatable || field.Readable) && len(field.BindNames) < len(v.BindNames)) { @@ -184,6 +209,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam } schema.FieldsByDBName[field.DBName] = field schema.FieldsByName[field.Name] = field + schema.FieldsByBindName[bindName] = field if v != nil && v.PrimaryKey { for idx, f := range schema.PrimaryFields { @@ -202,6 +228,9 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam if of, ok := schema.FieldsByName[field.Name]; !ok || of.TagSettings["-"] == "-" { schema.FieldsByName[field.Name] = field } + if of, ok := schema.FieldsByBindName[bindName]; !ok || of.TagSettings["-"] == "-" { + schema.FieldsByBindName[bindName] = field + } field.setupValuerAndSetter() } @@ -293,6 +322,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam return schema, schema.err } else { schema.FieldsByName[field.Name] = field + schema.FieldsByBindName[field.BindName()] = field } } diff --git a/schema/schema_helper_test.go b/schema/schema_helper_test.go index 9abaecba..605aa03a 100644 --- a/schema/schema_helper_test.go +++ b/schema/schema_helper_test.go @@ -201,6 +201,37 @@ func checkSchemaRelation(t *testing.T, s *schema.Schema, relation Relation) { }) } +type EmbeddedRelations struct { + Relations map[string]Relation + EmbeddedRelations map[string]EmbeddedRelations +} + +func checkEmbeddedRelations(t *testing.T, actual map[string]*schema.Relationships, expected map[string]EmbeddedRelations) { + for name, relations := range actual { + rs := expected[name] + t.Run("CheckEmbeddedRelations/"+name, func(t *testing.T) { + if len(relations.Relations) != len(rs.Relations) { + t.Errorf("schema relations count don't match, expects %d, got %d", len(rs.Relations), len(relations.Relations)) + } + if len(relations.EmbeddedRelations) != len(rs.EmbeddedRelations) { + t.Errorf("schema embedded relations count don't match, expects %d, got %d", len(rs.EmbeddedRelations), len(relations.EmbeddedRelations)) + } + for n, rel := range relations.Relations { + if r, ok := rs.Relations[n]; !ok { + t.Errorf("failed to find relation by name %s", n) + } else { + checkSchemaRelation(t, &schema.Schema{ + Relationships: schema.Relationships{ + Relations: map[string]*schema.Relationship{n: rel}, + }, + }, r) + } + } + checkEmbeddedRelations(t, relations.EmbeddedRelations, rs.EmbeddedRelations) + }) + } +} + func checkField(t *testing.T, s *schema.Schema, value reflect.Value, values map[string]interface{}) { for k, v := range values { t.Run("CheckField/"+k, func(t *testing.T) { diff --git a/tests/preload_test.go b/tests/preload_test.go index e7223b3e..7304e350 100644 --- a/tests/preload_test.go +++ b/tests/preload_test.go @@ -306,3 +306,141 @@ func TestNestedPreloadWithUnscoped(t *testing.T) { DB.Unscoped().Preload("Pets.Toy").Find(&user6, "id = ?", user.ID) CheckUserUnscoped(t, *user6, user) } + +func TestEmbedPreload(t *testing.T) { + type Country struct { + ID int `gorm:"primaryKey"` + Name string + } + type EmbeddedAddress struct { + ID int + Name string + CountryID *int + Country *Country + } + type NestedAddress struct { + EmbeddedAddress + } + type Org struct { + ID int + PostalAddress EmbeddedAddress `gorm:"embedded;embeddedPrefix:postal_address_"` + VisitingAddress EmbeddedAddress `gorm:"embedded;embeddedPrefix:visiting_address_"` + AddressID *int + Address *EmbeddedAddress + NestedAddress NestedAddress `gorm:"embedded;embeddedPrefix:nested_address_"` + } + + DB.Migrator().DropTable(&Org{}, &EmbeddedAddress{}, &Country{}) + DB.AutoMigrate(&Org{}, &EmbeddedAddress{}, &Country{}) + + org := Org{ + PostalAddress: EmbeddedAddress{Name: "a1", Country: &Country{Name: "c1"}}, + VisitingAddress: EmbeddedAddress{Name: "a2", Country: &Country{Name: "c2"}}, + Address: &EmbeddedAddress{Name: "a3", Country: &Country{Name: "c3"}}, + NestedAddress: NestedAddress{ + EmbeddedAddress: EmbeddedAddress{Name: "a4", Country: &Country{Name: "c4"}}, + }, + } + if err := DB.Create(&org).Error; err != nil { + t.Errorf("failed to create org, got err: %v", err) + } + + tests := []struct { + name string + preloads map[string][]interface{} + expect Org + }{ + { + name: "address country", + preloads: map[string][]interface{}{"Address.Country": {}}, + expect: Org{ + ID: org.ID, + PostalAddress: EmbeddedAddress{ + ID: org.PostalAddress.ID, + Name: org.PostalAddress.Name, + CountryID: org.PostalAddress.CountryID, + Country: nil, + }, + VisitingAddress: EmbeddedAddress{ + ID: org.VisitingAddress.ID, + Name: org.VisitingAddress.Name, + CountryID: org.VisitingAddress.CountryID, + Country: nil, + }, + AddressID: org.AddressID, + Address: org.Address, + NestedAddress: NestedAddress{EmbeddedAddress{ + ID: org.NestedAddress.ID, + Name: org.NestedAddress.Name, + CountryID: org.NestedAddress.CountryID, + Country: nil, + }}, + }, + }, { + name: "postal address country", + preloads: map[string][]interface{}{"PostalAddress.Country": {}}, + expect: Org{ + ID: org.ID, + PostalAddress: org.PostalAddress, + VisitingAddress: EmbeddedAddress{ + ID: org.VisitingAddress.ID, + Name: org.VisitingAddress.Name, + CountryID: org.VisitingAddress.CountryID, + Country: nil, + }, + AddressID: org.AddressID, + Address: nil, + NestedAddress: NestedAddress{EmbeddedAddress{ + ID: org.NestedAddress.ID, + Name: org.NestedAddress.Name, + CountryID: org.NestedAddress.CountryID, + Country: nil, + }}, + }, + }, { + name: "nested address country", + preloads: map[string][]interface{}{"NestedAddress.EmbeddedAddress.Country": {}}, + expect: Org{ + ID: org.ID, + PostalAddress: EmbeddedAddress{ + ID: org.PostalAddress.ID, + Name: org.PostalAddress.Name, + CountryID: org.PostalAddress.CountryID, + Country: nil, + }, + VisitingAddress: EmbeddedAddress{ + ID: org.VisitingAddress.ID, + Name: org.VisitingAddress.Name, + CountryID: org.VisitingAddress.CountryID, + Country: nil, + }, + AddressID: org.AddressID, + Address: nil, + NestedAddress: org.NestedAddress, + }, + }, { + name: "associations", + preloads: map[string][]interface{}{ + clause.Associations: {}, + // clause.Associations won’t preload nested associations + "Address.Country": {}, + }, + expect: org, + }, + } + + DB = DB.Debug() + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + actual := Org{} + tx := DB.Where("id = ?", org.ID).Session(&gorm.Session{}) + for name, args := range test.preloads { + tx = tx.Preload(name, args...) + } + if err := tx.Find(&actual).Error; err != nil { + t.Errorf("failed to find org, got err: %v", err) + } + AssertEqual(t, actual, test.expect) + }) + } +} From e9637024d3780dba4de755e6f5879150f43e8390 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 11 Apr 2023 13:16:25 +0800 Subject: [PATCH 144/231] Update README --- README.md | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/README.md b/README.md index 0c9ab74e..85ad3050 100644 --- a/README.md +++ b/README.md @@ -35,9 +35,7 @@ The fantastic ORM library for Golang, aims to be developer friendly. ## Contributors -Thank you for contributing to the GORM framework! - -[![Contributors](https://contrib.rocks/image?repo=go-gorm/gorm)](https://github.com/go-gorm/gorm/graphs/contributors) +[Thank you](https://github.com/go-gorm/gorm/graphs/contributors) for contributing to the GORM framework! ## License From ac20d9e222400d7ad1963251b4aa2c589afe6901 Mon Sep 17 00:00:00 2001 From: black-06 Date: Fri, 21 Apr 2023 22:09:38 +0800 Subject: [PATCH 145/231] fix: unit test (#6250) * fix: unit test * fix create test https://github.com/go-gorm/gorm/pull/6127#discussion_r1171214125 * style: rename to adaptorSerializerModel --- tests/create_test.go | 3 +++ tests/go.mod | 13 ++++++------- tests/serializer_test.go | 40 ++++++++++++++++++++++++++++++++-------- 3 files changed, 41 insertions(+), 15 deletions(-) diff --git a/tests/create_test.go b/tests/create_test.go index 75aa8cba..02613b72 100644 --- a/tests/create_test.go +++ b/tests/create_test.go @@ -556,6 +556,9 @@ func TestCreateWithAutoIncrementCompositeKey(t *testing.T) { Name string } + if err := DB.Migrator().DropTable(&CompositeKeyProduct{}); err != nil { + t.Fatalf("failed to migrate, got error %v", err) + } if err := DB.AutoMigrate(&CompositeKeyProduct{}); err != nil { t.Fatalf("failed to migrate, got error %v", err) } diff --git a/tests/go.mod b/tests/go.mod index 306a530e..f47d175f 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -6,15 +6,14 @@ require ( github.com/google/uuid v1.3.0 github.com/jackc/pgx/v5 v5.3.1 // indirect github.com/jinzhu/now v1.1.5 - github.com/lib/pq v1.10.7 + github.com/lib/pq v1.10.8 github.com/mattn/go-sqlite3 v1.14.16 // indirect - github.com/microsoft/go-mssqldb v0.20.0 // indirect - golang.org/x/crypto v0.7.0 // indirect - gorm.io/driver/mysql v1.4.7 + golang.org/x/crypto v0.8.0 // indirect + gorm.io/driver/mysql v1.5.0 gorm.io/driver/postgres v1.5.0 - gorm.io/driver/sqlite v1.4.4 - gorm.io/driver/sqlserver v1.4.2 - gorm.io/gorm v1.24.7-0.20230306060331-85eaf9eeda11 + gorm.io/driver/sqlite v1.5.0 + gorm.io/driver/sqlserver v1.4.3 + gorm.io/gorm v1.25.0 ) replace gorm.io/gorm => ../ diff --git a/tests/serializer_test.go b/tests/serializer_test.go index a040a4db..f1b8a336 100644 --- a/tests/serializer_test.go +++ b/tests/serializer_test.go @@ -22,12 +22,36 @@ type SerializerStruct struct { Roles3 *Roles `gorm:"serializer:json;not null"` Contracts map[string]interface{} `gorm:"serializer:json"` JobInfo Job `gorm:"type:bytes;serializer:gob"` - CreatedTime int64 `gorm:"serializer:unixtime;type:time"` // store time in db, use int as field type - UpdatedTime *int64 `gorm:"serializer:unixtime;type:time"` // store time in db, use int as field type + CreatedTime int64 `gorm:"serializer:unixtime;type:datetime"` // store time in db, use int as field type + UpdatedTime *int64 `gorm:"serializer:unixtime;type:datetime"` // store time in db, use int as field type CustomSerializerString string `gorm:"serializer:custom"` EncryptedString EncryptedString } +type SerializerPostgresStruct struct { + gorm.Model + Name []byte `gorm:"json"` + Roles Roles `gorm:"serializer:json"` + Roles2 *Roles `gorm:"serializer:json"` + Roles3 *Roles `gorm:"serializer:json;not null"` + Contracts map[string]interface{} `gorm:"serializer:json"` + JobInfo Job `gorm:"type:bytes;serializer:gob"` + CreatedTime int64 `gorm:"serializer:unixtime;type:timestamptz"` // store time in db, use int as field type + UpdatedTime *int64 `gorm:"serializer:unixtime;type:timestamptz"` // store time in db, use int as field type + CustomSerializerString string `gorm:"serializer:custom"` + EncryptedString EncryptedString +} + +func (*SerializerPostgresStruct) TableName() string { return "serializer_structs" } + +func adaptorSerializerModel(s *SerializerStruct) interface{} { + if DB.Dialector.Name() == "postgres" { + sps := SerializerPostgresStruct(*s) + return &sps + } + return s +} + type Roles []string type Job struct { @@ -81,8 +105,8 @@ func (c *CustomSerializer) Value(ctx context.Context, field *schema.Field, dst r func TestSerializer(t *testing.T) { schema.RegisterSerializer("custom", NewCustomSerializer("hello")) - DB.Migrator().DropTable(&SerializerStruct{}) - if err := DB.Migrator().AutoMigrate(&SerializerStruct{}); err != nil { + DB.Migrator().DropTable(adaptorSerializerModel(&SerializerStruct{})) + if err := DB.Migrator().AutoMigrate(adaptorSerializerModel(&SerializerStruct{})); err != nil { t.Fatalf("no error should happen when migrate scanner, valuer struct, got error %v", err) } @@ -127,8 +151,8 @@ func TestSerializer(t *testing.T) { func TestSerializerZeroValue(t *testing.T) { schema.RegisterSerializer("custom", NewCustomSerializer("hello")) - DB.Migrator().DropTable(&SerializerStruct{}) - if err := DB.Migrator().AutoMigrate(&SerializerStruct{}); err != nil { + DB.Migrator().DropTable(adaptorSerializerModel(&SerializerStruct{})) + if err := DB.Migrator().AutoMigrate(adaptorSerializerModel(&SerializerStruct{})); err != nil { t.Fatalf("no error should happen when migrate scanner, valuer struct, got error %v", err) } @@ -156,8 +180,8 @@ func TestSerializerZeroValue(t *testing.T) { func TestSerializerAssignFirstOrCreate(t *testing.T) { schema.RegisterSerializer("custom", NewCustomSerializer("hello")) - DB.Migrator().DropTable(&SerializerStruct{}) - if err := DB.Migrator().AutoMigrate(&SerializerStruct{}); err != nil { + DB.Migrator().DropTable(adaptorSerializerModel(&SerializerStruct{})) + if err := DB.Migrator().AutoMigrate(adaptorSerializerModel(&SerializerStruct{})); err != nil { t.Fatalf("no error should happen when migrate scanner, valuer struct, got error %v", err) } From 32fc2015543c41557a364d45213ca6c710b478bd Mon Sep 17 00:00:00 2001 From: Zhiheng Lin Date: Fri, 21 Apr 2023 22:17:21 +0800 Subject: [PATCH 146/231] fix: avoid coroutine leaks when the dialecter initialization fails. (#6249) Co-authored-by: Kevin Lin --- gorm.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/gorm.go b/gorm.go index 4402a2df..07a913fc 100644 --- a/gorm.go +++ b/gorm.go @@ -179,6 +179,12 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) { if config.Dialector != nil { err = config.Dialector.Initialize(db) + + if err != nil { + if db, err := db.DB(); err == nil { + _ = db.Close() + } + } } preparedStmt := &PreparedStmtDB{ From 1f763c81cb3ec1c2f2dfada9f42455278e33298c Mon Sep 17 00:00:00 2001 From: yikakia <59830508+yikakia@users.noreply.github.com> Date: Wed, 26 Apr 2023 22:19:06 +0800 Subject: [PATCH 147/231] fix typo chainable_api.go (#6266) --- chainable_api.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chainable_api.go b/chainable_api.go index 19d405cc..3dc7256e 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -60,7 +60,7 @@ var tableRegexp = regexp.MustCompile(`(?i)(?:.+? AS (\w+)\s*(?:$|,)|^\w+\s+(\w+) // Table specify the table you would like to run db operations // // // Get a user -// db.Table("users").take(&result) +// db.Table("users").Take(&result) func (db *DB) Table(name string, args ...interface{}) (tx *DB) { tx = db.getInstance() if strings.Contains(name, " ") || strings.Contains(name, "`") || len(args) > 0 { From 407bedae0a529f8512b44522b319aa8434249dee Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Wed, 26 Apr 2023 22:19:32 +0800 Subject: [PATCH 148/231] fix: nested joins alias (#6265) --- callbacks/query.go | 7 ++++++- tests/joins_test.go | 16 ++++++++++++++-- 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/callbacks/query.go b/callbacks/query.go index 95db1f0a..e89dd199 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -234,7 +234,12 @@ func BuildQuerySQL(db *gorm.DB) { fromClause.Joins = append(fromClause.Joins, genJoinClause(join.JoinType, parentTableName, rel)) specifiedRelationsName[nestedAlias] = nil } - parentTableName = rel.Name + + if parentTableName != clause.CurrentTable { + parentTableName = utils.NestedRelationName(parentTableName, rel.Name) + } else { + parentTableName = rel.Name + } } } else { fromClause.Joins = append(fromClause.Joins, clause.Join{ diff --git a/tests/joins_test.go b/tests/joins_test.go index e6715bbe..786fc37e 100644 --- a/tests/joins_test.go +++ b/tests/joins_test.go @@ -329,8 +329,19 @@ func TestJoinArgsWithDB(t *testing.T) { func TestNestedJoins(t *testing.T) { users := []User{ { - Name: "nested-joins-1", - Manager: GetUser("nested-joins-manager-1", Config{Company: true, NamedPet: true}), + Name: "nested-joins-1", + Manager: &User{ + Name: "nested-joins-manager-1", + Company: Company{ + Name: "nested-joins-manager-company-1", + }, + NamedPet: &Pet{ + Name: "nested-joins-manager-namepet-1", + Toy: Toy{ + Name: "nested-joins-manager-namepet-toy-1", + }, + }, + }, NamedPet: &Pet{Name: "nested-joins-namepet-1", Toy: Toy{Name: "nested-joins-namepet-toy-1"}}, }, { @@ -352,6 +363,7 @@ func TestNestedJoins(t *testing.T) { Joins("Manager"). Joins("Manager.Company"). Joins("Manager.NamedPet"). + Joins("Manager.NamedPet.Toy"). Joins("NamedPet"). Joins("NamedPet.Toy"). Find(&users2, "users.id IN ?", userIDs).Error; err != nil { From aeb298635b04ac7063b545badceeaf77c0eb6ef0 Mon Sep 17 00:00:00 2001 From: hanwn <30523763+Hanwn@users.noreply.github.com> Date: Wed, 26 Apr 2023 22:19:46 +0800 Subject: [PATCH 149/231] debug: use slice Stale sort (#6263) Co-authored-by: hanwang --- callbacks.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/callbacks.go b/callbacks.go index ca6b6d50..195d1720 100644 --- a/callbacks.go +++ b/callbacks.go @@ -249,7 +249,7 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) { names, sorted []string sortCallback func(*callback) error ) - sort.Slice(cs, func(i, j int) bool { + sort.SliceStable(cs, func(i, j int) bool { if cs[j].before == "*" && cs[i].before != "*" { return true } From 67642abfff798c25aade7f29c76654ab18e209c4 Mon Sep 17 00:00:00 2001 From: hykuan <33409123+hykuan@users.noreply.github.com> Date: Thu, 4 May 2023 19:29:31 +0800 Subject: [PATCH 150/231] =?UTF-8?q?fix:=20=F0=9F=90=9B=20numeric=20types?= =?UTF-8?q?=20in=20pointer=20embedded=20struct=20test=20failed=20(#6293)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- schema/field.go | 36 +++++++++++++++++++++++++++++++++++ tests/embedded_struct_test.go | 1 + 2 files changed, 37 insertions(+) diff --git a/schema/field.go b/schema/field.go index b5103d53..7d1a1789 100644 --- a/schema/field.go +++ b/schema/field.go @@ -604,6 +604,22 @@ func (field *Field) setupValuerAndSetter() { if data != nil && *data != nil { field.ReflectValueOf(ctx, value).SetInt(**data) } + case **int: + if data != nil && *data != nil { + field.ReflectValueOf(ctx, value).SetInt(int64(**data)) + } + case **int8: + if data != nil && *data != nil { + field.ReflectValueOf(ctx, value).SetInt(int64(**data)) + } + case **int16: + if data != nil && *data != nil { + field.ReflectValueOf(ctx, value).SetInt(int64(**data)) + } + case **int32: + if data != nil && *data != nil { + field.ReflectValueOf(ctx, value).SetInt(int64(**data)) + } case int64: field.ReflectValueOf(ctx, value).SetInt(data) case int: @@ -668,6 +684,22 @@ func (field *Field) setupValuerAndSetter() { if data != nil && *data != nil { field.ReflectValueOf(ctx, value).SetUint(**data) } + case **uint: + if data != nil && *data != nil { + field.ReflectValueOf(ctx, value).SetUint(uint64(**data)) + } + case **uint8: + if data != nil && *data != nil { + field.ReflectValueOf(ctx, value).SetUint(uint64(**data)) + } + case **uint16: + if data != nil && *data != nil { + field.ReflectValueOf(ctx, value).SetUint(uint64(**data)) + } + case **uint32: + if data != nil && *data != nil { + field.ReflectValueOf(ctx, value).SetUint(uint64(**data)) + } case uint64: field.ReflectValueOf(ctx, value).SetUint(data) case uint: @@ -720,6 +752,10 @@ func (field *Field) setupValuerAndSetter() { if data != nil && *data != nil { field.ReflectValueOf(ctx, value).SetFloat(**data) } + case **float32: + if data != nil && *data != nil { + field.ReflectValueOf(ctx, value).SetFloat(float64(**data)) + } case float64: field.ReflectValueOf(ctx, value).SetFloat(data) case float32: diff --git a/tests/embedded_struct_test.go b/tests/embedded_struct_test.go index 0d240fd8..3747dad9 100644 --- a/tests/embedded_struct_test.go +++ b/tests/embedded_struct_test.go @@ -107,6 +107,7 @@ func TestEmbeddedPointerTypeStruct(t *testing.T) { ID string Name string Email string + Age int } type HNPost struct { From 32045fdd7d7a298f09f7ffdca286c3097cfda293 Mon Sep 17 00:00:00 2001 From: black-06 Date: Thu, 4 May 2023 19:30:45 +0800 Subject: [PATCH 151/231] feat: unscoped association (#5899) (#6246) * feat: unscoped association (#5899) * modify name because mysql character is latin1 * work only on has association * format * Unscoped on belongs_to association --- association.go | 63 ++++++++++++++++++++--- tests/associations_belongs_to_test.go | 55 ++++++++++++++++++++ tests/associations_has_many_test.go | 74 +++++++++++++++++++++++++++ 3 files changed, 186 insertions(+), 6 deletions(-) diff --git a/association.go b/association.go index 6719a1d0..7c93ebea 100644 --- a/association.go +++ b/association.go @@ -14,6 +14,7 @@ import ( type Association struct { DB *DB Relationship *schema.Relationship + Unscope bool Error error } @@ -40,6 +41,15 @@ func (db *DB) Association(column string) *Association { return association } +func (association *Association) Unscoped() *Association { + return &Association{ + DB: association.DB, + Relationship: association.Relationship, + Error: association.Error, + Unscope: true, + } +} + func (association *Association) Find(out interface{}, conds ...interface{}) error { if association.Error == nil { association.Error = association.buildCondition().Find(out, conds...).Error @@ -64,14 +74,30 @@ func (association *Association) Append(values ...interface{}) error { func (association *Association) Replace(values ...interface{}) error { if association.Error == nil { + reflectValue := association.DB.Statement.ReflectValue + rel := association.Relationship + + var oldBelongsToExpr clause.Expression + // we have to record the old BelongsTo value + if association.Unscope && rel.Type == schema.BelongsTo { + var foreignFields []*schema.Field + for _, ref := range rel.References { + if !ref.OwnPrimaryKey { + foreignFields = append(foreignFields, ref.ForeignKey) + } + } + if _, fvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, foreignFields); len(fvs) > 0 { + column, values := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, fvs) + oldBelongsToExpr = clause.IN{Column: column, Values: values} + } + } + // save associations if association.saveAssociation( /*clear*/ true, values...); association.Error != nil { return association.Error } // set old associations's foreign key to null - reflectValue := association.DB.Statement.ReflectValue - rel := association.Relationship switch rel.Type { case schema.BelongsTo: if len(values) == 0 { @@ -91,6 +117,9 @@ func (association *Association) Replace(values ...interface{}) error { association.Error = association.DB.UpdateColumns(updateMap).Error } + if association.Unscope && oldBelongsToExpr != nil { + association.Error = association.DB.Model(nil).Where(oldBelongsToExpr).Delete(reflect.New(rel.FieldSchema.ModelType).Interface()).Error + } case schema.HasOne, schema.HasMany: var ( primaryFields []*schema.Field @@ -119,7 +148,11 @@ func (association *Association) Replace(values ...interface{}) error { if _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields); len(pvs) > 0 { column, values := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs) - association.Error = tx.Where(clause.IN{Column: column, Values: values}).UpdateColumns(updateMap).Error + if association.Unscope { + association.Error = tx.Where(clause.IN{Column: column, Values: values}).Delete(modelValue).Error + } else { + association.Error = tx.Where(clause.IN{Column: column, Values: values}).UpdateColumns(updateMap).Error + } } case schema.Many2Many: var ( @@ -184,7 +217,8 @@ func (association *Association) Delete(values ...interface{}) error { switch rel.Type { case schema.BelongsTo: - tx := association.DB.Model(reflect.New(rel.Schema.ModelType).Interface()) + associationDB := association.DB.Session(&Session{}) + tx := associationDB.Model(reflect.New(rel.Schema.ModelType).Interface()) _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, rel.Schema.PrimaryFields) if pcolumn, pvalues := schema.ToQueryValues(rel.Schema.Table, rel.Schema.PrimaryFieldDBNames, pvs); len(pvalues) > 0 { @@ -198,8 +232,21 @@ func (association *Association) Delete(values ...interface{}) error { conds = append(conds, clause.IN{Column: relColumn, Values: relValues}) association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error + if association.Unscope { + var foreignFields []*schema.Field + for _, ref := range rel.References { + if !ref.OwnPrimaryKey { + foreignFields = append(foreignFields, ref.ForeignKey) + } + } + if _, fvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, foreignFields); len(fvs) > 0 { + column, values := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, fvs) + association.Error = associationDB.Model(nil).Where(clause.IN{Column: column, Values: values}).Delete(reflect.New(rel.FieldSchema.ModelType).Interface()).Error + } + } case schema.HasOne, schema.HasMany: - tx := association.DB.Model(reflect.New(rel.FieldSchema.ModelType).Interface()) + model := reflect.New(rel.FieldSchema.ModelType).Interface() + tx := association.DB.Model(model) _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields) if pcolumn, pvalues := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs); len(pvalues) > 0 { @@ -212,7 +259,11 @@ func (association *Association) Delete(values ...interface{}) error { relColumn, relValues := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs) conds = append(conds, clause.IN{Column: relColumn, Values: relValues}) - association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error + if association.Unscope { + association.Error = tx.Clauses(conds...).Delete(model).Error + } else { + association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error + } case schema.Many2Many: var ( primaryFields, relPrimaryFields []*schema.Field diff --git a/tests/associations_belongs_to_test.go b/tests/associations_belongs_to_test.go index 99e8aa79..6befb5f2 100644 --- a/tests/associations_belongs_to_test.go +++ b/tests/associations_belongs_to_test.go @@ -251,3 +251,58 @@ func TestBelongsToDefaultValue(t *testing.T) { err := DB.Create(&user).Error AssertEqual(t, err, nil) } + +func TestBelongsToAssociationUnscoped(t *testing.T) { + type ItemParent struct { + gorm.Model + Logo string `gorm:"not null;type:varchar(50)"` + } + type ItemChild struct { + gorm.Model + Name string `gorm:"type:varchar(50)"` + ItemParentID uint + ItemParent ItemParent + } + + tx := DB.Session(&gorm.Session{}) + tx.Migrator().DropTable(&ItemParent{}, &ItemChild{}) + tx.AutoMigrate(&ItemParent{}, &ItemChild{}) + + item := ItemChild{ + Name: "name", + ItemParent: ItemParent{ + Logo: "logo", + }, + } + if err := tx.Create(&item).Error; err != nil { + t.Fatalf("failed to create items, got error: %v", err) + } + + tx = tx.Debug() + + // test replace + if err := tx.Model(&item).Association("ItemParent").Unscoped().Replace(&ItemParent{ + Logo: "updated logo", + }); err != nil { + t.Errorf("failed to replace item parent, got error: %v", err) + } + + var parents []ItemParent + if err := tx.Find(&parents).Error; err != nil { + t.Errorf("failed to find item parent, got error: %v", err) + } + if len(parents) != 1 { + t.Errorf("expected %d parents, got %d", 1, len(parents)) + } + + // test delete + if err := tx.Model(&item).Association("ItemParent").Unscoped().Delete(&parents); err != nil { + t.Errorf("failed to delete item parent, got error: %v", err) + } + if err := tx.Find(&parents).Error; err != nil { + t.Errorf("failed to find item parent, got error: %v", err) + } + if len(parents) != 0 { + t.Errorf("expected %d parents, got %d", 0, len(parents)) + } +} diff --git a/tests/associations_has_many_test.go b/tests/associations_has_many_test.go index 002ae636..c31c4b40 100644 --- a/tests/associations_has_many_test.go +++ b/tests/associations_has_many_test.go @@ -3,6 +3,7 @@ package tests_test import ( "testing" + "gorm.io/gorm" . "gorm.io/gorm/utils/tests" ) @@ -471,3 +472,76 @@ func TestPolymorphicHasManyAssociationForSlice(t *testing.T) { DB.Model(&users).Association("Toys").Clear() AssertAssociationCount(t, users, "Toys", 0, "After Clear") } + +func TestHasManyAssociationUnscoped(t *testing.T) { + type ItemContent struct { + gorm.Model + ItemID uint `gorm:"not null"` + Name string `gorm:"not null;type:varchar(50)"` + LanguageCode string `gorm:"not null;type:varchar(2)"` + } + type Item struct { + gorm.Model + Logo string `gorm:"not null;type:varchar(50)"` + Contents []ItemContent `gorm:"foreignKey:ItemID"` + } + + tx := DB.Session(&gorm.Session{}) + tx.Migrator().DropTable(&ItemContent{}, &Item{}) + tx.AutoMigrate(&ItemContent{}, &Item{}) + + item := Item{ + Logo: "logo", + Contents: []ItemContent{ + {Name: "name", LanguageCode: "en"}, + {Name: "ar name", LanguageCode: "ar"}, + }, + } + if err := tx.Create(&item).Error; err != nil { + t.Fatalf("failed to create items, got error: %v", err) + } + + // test Replace + if err := tx.Model(&item).Association("Contents").Unscoped().Replace([]ItemContent{ + {Name: "updated name", LanguageCode: "en"}, + {Name: "ar updated name", LanguageCode: "ar"}, + {Name: "le nom", LanguageCode: "fr"}, + }); err != nil { + t.Errorf("failed to replace item content, got error: %v", err) + } + + if count := tx.Model(&item).Association("Contents").Count(); count != 3 { + t.Errorf("expected %d contents, got %d", 3, count) + } + + var contents []ItemContent + if err := tx.Find(&contents).Error; err != nil { + t.Errorf("failed to find contents, got error: %v", err) + } + if len(contents) != 3 { + t.Errorf("expected %d contents, got %d", 3, len(contents)) + } + + // test delete + if err := tx.Model(&item).Association("Contents").Unscoped().Delete(&contents[0]); err != nil { + t.Errorf("failed to delete Contents, got error: %v", err) + } + if count := tx.Model(&item).Association("Contents").Count(); count != 2 { + t.Errorf("expected %d contents, got %d", 2, count) + } + + // test clear + if err := tx.Model(&item).Association("Contents").Unscoped().Clear(); err != nil { + t.Errorf("failed to clear contents association, got error: %v", err) + } + if count := tx.Model(&item).Association("Contents").Count(); count != 0 { + t.Errorf("expected %d contents, got %d", 0, count) + } + + if err := tx.Find(&contents).Error; err != nil { + t.Errorf("failed to find contents, got error: %v", err) + } + if len(contents) != 0 { + t.Errorf("expected %d contents, got %d", 0, len(contents)) + } +} From e61b98d69677b8871d832baf2489942d79054a4a Mon Sep 17 00:00:00 2001 From: John Mai Date: Fri, 5 May 2023 15:58:27 +0800 Subject: [PATCH 152/231] feat: migrator support table comment (#6225) * feat: migrator support table comment * feat: migrator support tableType.It like ColumnTypes * Avoid updating the go.mod file. * Update tests_all.sh * Update migrator.go * remove Catalog() & Engine() methods. * remove CatalogValue & EngineValue. --------- Co-authored-by: Jinzhu --- migrator.go | 9 +++++++++ migrator/migrator.go | 5 +++++ migrator/table_type.go | 33 +++++++++++++++++++++++++++++++++ 3 files changed, 47 insertions(+) create mode 100644 migrator/table_type.go diff --git a/migrator.go b/migrator.go index 037afc35..0e01f567 100644 --- a/migrator.go +++ b/migrator.go @@ -56,6 +56,14 @@ type Index interface { Option() string } +// TableType table type interface +type TableType interface { + Schema() string + Name() string + Type() string + Comment() (comment string, ok bool) +} + // Migrator migrator interface type Migrator interface { // AutoMigrate @@ -72,6 +80,7 @@ type Migrator interface { HasTable(dst interface{}) bool RenameTable(oldName, newName interface{}) error GetTables() (tableList []string, err error) + TableType(dst interface{}) (TableType, error) // Columns AddColumn(dst interface{}, field string) error diff --git a/migrator/migrator.go b/migrator/migrator.go index 32c6a059..de60f91c 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -949,3 +949,8 @@ func (m Migrator) GetIndexes(dst interface{}) ([]gorm.Index, error) { func (m Migrator) GetTypeAliases(databaseTypeName string) []string { return nil } + +// TableType return tableType gorm.TableType and execErr error +func (m Migrator) TableType(dst interface{}) (gorm.TableType, error) { + return nil, errors.New("not support") +} diff --git a/migrator/table_type.go b/migrator/table_type.go new file mode 100644 index 00000000..ed6e42a0 --- /dev/null +++ b/migrator/table_type.go @@ -0,0 +1,33 @@ +package migrator + +import ( + "database/sql" +) + +// TableType table type implements TableType interface +type TableType struct { + SchemaValue string + NameValue string + TypeValue string + CommentValue sql.NullString +} + +// Schema returns the schema of the table. +func (ct TableType) Schema() string { + return ct.SchemaValue +} + +// Name returns the name of the table. +func (ct TableType) Name() string { + return ct.NameValue +} + +// Type returns the type of the table. +func (ct TableType) Type() string { + return ct.TypeValue +} + +// Comment returns the comment of current table. +func (ct TableType) Comment() (comment string, ok bool) { + return ct.CommentValue.String, ct.CommentValue.Valid +} From 63534145fda9a2ac9ba703650b1a44da6a03e45e Mon Sep 17 00:00:00 2001 From: aclich <71011237+aclich@users.noreply.github.com> Date: Mon, 15 May 2023 09:59:26 +0800 Subject: [PATCH 153/231] =?UTF-8?q?fix:=20=F0=9F=90=9B=20embedded=20struct?= =?UTF-8?q?=20test=20failed=20with=20custom=20datatypes=20(#6311)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: 🐛 embedded struct test failed with custom datatypes Fix the pointer embedded struct within custom datatypes and *time.time should be nil issue. * fix: 🐛 change test case to avoid mssql driver issue change test cases from bytes to string to avoid mssql driver issue --- schema/field.go | 18 +++----- tests/embedded_struct_test.go | 80 +++++++++++++++++++++++++++++------ 2 files changed, 75 insertions(+), 23 deletions(-) diff --git a/schema/field.go b/schema/field.go index 7d1a1789..dd08e056 100644 --- a/schema/field.go +++ b/schema/field.go @@ -846,7 +846,7 @@ func (field *Field) setupValuerAndSetter() { field.Set = func(ctx context.Context, value reflect.Value, v interface{}) error { switch data := v.(type) { case **time.Time: - if data != nil { + if data != nil && *data != nil { field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(*data)) } case time.Time: @@ -882,14 +882,12 @@ func (field *Field) setupValuerAndSetter() { reflectV := reflect.ValueOf(v) if !reflectV.IsValid() { field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem()) + } else if reflectV.Kind() == reflect.Ptr && reflectV.IsNil() { + return } else if reflectV.Type().AssignableTo(field.FieldType) { field.ReflectValueOf(ctx, value).Set(reflectV) } else if reflectV.Kind() == reflect.Ptr { - if reflectV.IsNil() || !reflectV.IsValid() { - field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem()) - } else { - return field.Set(ctx, value, reflectV.Elem().Interface()) - } + return field.Set(ctx, value, reflectV.Elem().Interface()) } else { fieldValue := field.ReflectValueOf(ctx, value) if fieldValue.IsNil() { @@ -910,14 +908,12 @@ func (field *Field) setupValuerAndSetter() { reflectV := reflect.ValueOf(v) if !reflectV.IsValid() { field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem()) + } else if reflectV.Kind() == reflect.Ptr && reflectV.IsNil() { + return } else if reflectV.Type().AssignableTo(field.FieldType) { field.ReflectValueOf(ctx, value).Set(reflectV) } else if reflectV.Kind() == reflect.Ptr { - if reflectV.IsNil() || !reflectV.IsValid() { - field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem()) - } else { - return field.Set(ctx, value, reflectV.Elem().Interface()) - } + return field.Set(ctx, value, reflectV.Elem().Interface()) } else { if valuer, ok := v.(driver.Valuer); ok { v, _ = valuer.Value() diff --git a/tests/embedded_struct_test.go b/tests/embedded_struct_test.go index 3747dad9..4314f88c 100644 --- a/tests/embedded_struct_test.go +++ b/tests/embedded_struct_test.go @@ -4,7 +4,9 @@ import ( "database/sql/driver" "encoding/json" "errors" + "reflect" "testing" + "time" "gorm.io/gorm" . "gorm.io/gorm/utils/tests" @@ -104,10 +106,14 @@ func TestEmbeddedPointerTypeStruct(t *testing.T) { } type Author struct { - ID string - Name string - Email string - Age int + ID string + Name string + Email string + Age int + Content Content + ContentPtr *Content + Birthday time.Time + BirthdayPtr *time.Time } type HNPost struct { @@ -135,6 +141,48 @@ func TestEmbeddedPointerTypeStruct(t *testing.T) { if hnPost.Author != nil { t.Errorf("Expected to get back a nil Author but got: %v", hnPost.Author) } + + now := time.Now().Round(time.Second) + NewPost := HNPost{ + BasePost: &BasePost{Title: "embedded_pointer_type2"}, + Author: &Author{ + Name: "test", + Content: Content{"test"}, + ContentPtr: nil, + Birthday: now, + BirthdayPtr: nil, + }, + } + DB.Create(&NewPost) + + hnPost = HNPost{} + if err := DB.First(&hnPost, "title = ?", NewPost.Title).Error; err != nil { + t.Errorf("No error should happen when find embedded pointer type, but got %v", err) + } + + if hnPost.Title != NewPost.Title { + t.Errorf("Should find correct value for embedded pointer type") + } + + if hnPost.Author.Name != NewPost.Author.Name { + t.Errorf("Expected to get Author name %v but got: %v", NewPost.Author.Name, hnPost.Author.Name) + } + + if !reflect.DeepEqual(NewPost.Author.Content, hnPost.Author.Content) { + t.Errorf("Expected to get Author content %v but got: %v", NewPost.Author.Content, hnPost.Author.Content) + } + + if hnPost.Author.ContentPtr != nil { + t.Errorf("Expected to get nil Author contentPtr but got: %v", hnPost.Author.ContentPtr) + } + + if NewPost.Author.Birthday.UnixMilli() != hnPost.Author.Birthday.UnixMilli() { + t.Errorf("Expected to get Author birthday with %+v but got: %+v", NewPost.Author.Birthday, hnPost.Author.Birthday) + } + + if hnPost.Author.BirthdayPtr != nil { + t.Errorf("Expected to get nil Author birthdayPtr but got: %+v", hnPost.Author.BirthdayPtr) + } } type Content struct { @@ -142,18 +190,26 @@ type Content struct { } func (c Content) Value() (driver.Value, error) { - return json.Marshal(c) + // mssql driver with issue on handling null bytes https://github.com/denisenkom/go-mssqldb/issues/530, + b, err := json.Marshal(c) + return string(b[:]), err } func (c *Content) Scan(src interface{}) error { - b, ok := src.([]byte) - if !ok { - return errors.New("Embedded.Scan byte assertion failed") - } - var value Content - if err := json.Unmarshal(b, &value); err != nil { - return err + str, ok := src.(string) + if !ok { + byt, ok := src.([]byte) + if !ok { + return errors.New("Embedded.Scan byte assertion failed") + } + if err := json.Unmarshal(byt, &value); err != nil { + return err + } + } else { + if err := json.Unmarshal([]byte(str), &value); err != nil { + return err + } } *c = value From c3d7d08b9a9f861e53e8b194fcc6b7cedc4191e1 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 15 May 2023 15:43:44 +0800 Subject: [PATCH 154/231] Clear SET clause after build SQL --- callbacks/update.go | 1 + tests/update_test.go | 6 +++++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/callbacks/update.go b/callbacks/update.go index 4eb75788..ff075dcf 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -72,6 +72,7 @@ func Update(config *Config) func(db *gorm.DB) { db.Statement.AddClauseIfNotExists(clause.Update{}) if _, ok := db.Statement.Clauses["SET"]; !ok { if set := ConvertToAssignments(db.Statement); len(set) != 0 { + defer delete(db.Statement.Clauses, "SET") db.Statement.AddClause(set) } else { return diff --git a/tests/update_test.go b/tests/update_test.go index 36ffa6a0..f7c36d74 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -208,13 +208,17 @@ func TestUpdateColumn(t *testing.T) { CheckUser(t, user1, *users[0]) CheckUser(t, user2, *users[1]) - DB.Model(users[1]).UpdateColumn("name", "update_column_02_newnew") + DB.Model(users[1]).UpdateColumn("name", "update_column_02_newnew").UpdateColumn("age", 19) AssertEqual(t, lastUpdatedAt.UnixNano(), users[1].UpdatedAt.UnixNano()) if users[1].Name != "update_column_02_newnew" { t.Errorf("user 2's name should be updated, but got %v", users[1].Name) } + if users[1].Age != 19 { + t.Errorf("user 2's name should be updated, but got %v", users[1].Age) + } + DB.Model(users[1]).UpdateColumn("age", gorm.Expr("age + 100 - 50")) var user3 User DB.First(&user3, users[1].ID) From f5837deef3d0c8edc881ca24b992689c71a5cc06 Mon Sep 17 00:00:00 2001 From: 201430098137 <1850396756@qq.com> Date: Wed, 17 May 2023 10:15:41 +0800 Subject: [PATCH 155/231] fix:clickhouse error not capture(#6277) (#6321) Co-authored-by: zhuangg --- finisher_api.go | 1 + 1 file changed, 1 insertion(+) diff --git a/finisher_api.go b/finisher_api.go index 0e26f181..ad14e298 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -533,6 +533,7 @@ func (db *DB) Scan(dest interface{}) (tx *DB) { tx.ScanRows(rows, dest) } else { tx.RowsAffected = 0 + tx.AddError(rows.Err()) } tx.AddError(rows.Close()) } From 6698ba709e1b08d72e6799c0fc5d0cb5a28cce4b Mon Sep 17 00:00:00 2001 From: Avinaba Bhattacharjee Date: Sun, 21 May 2023 18:54:00 +0530 Subject: [PATCH 156/231] renamed License to LICENSE (#6336) --- License => LICENSE | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename License => LICENSE (100%) diff --git a/License b/LICENSE similarity index 100% rename from License rename to LICENSE From 001738be49d26d341f88377d6006356573c9d045 Mon Sep 17 00:00:00 2001 From: Muhammad Amir Ejaz <37077032+codingamir@users.noreply.github.com> Date: Sun, 21 May 2023 18:27:22 +0500 Subject: [PATCH 157/231] Added support of "Violates Foreign Key Constraint" (#6329) * Added support of "Violates Foreign Key Constraint" Updated the translator and added the support of "foreign key constraint violation". For this, this error type is needed here. * changed the description of ErrForeignKeyViolated --- errors.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/errors.go b/errors.go index 57e3fc5e..cd76f1f5 100644 --- a/errors.go +++ b/errors.go @@ -47,4 +47,6 @@ var ( ErrPreloadNotAllowed = errors.New("preload is not allowed when count is used") // ErrDuplicatedKey occurs when there is a unique key constraint violation ErrDuplicatedKey = errors.New("duplicated key not allowed") + // ErrForeignKeyViolated occurs when there is a foreign key constraint violation + ErrForeignKeyViolated = errors.New("violates foreign key constraint") ) From 8197c00def30911fea0e987d03835fbc6cbbbb59 Mon Sep 17 00:00:00 2001 From: Saeid Date: Thu, 25 May 2023 05:10:00 +0200 Subject: [PATCH 158/231] refactor: error translator test (#6350) Co-authored-by: Saeid Saeidee --- tests/error_translator_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/error_translator_test.go b/tests/error_translator_test.go index ead26fce..ca985a09 100644 --- a/tests/error_translator_test.go +++ b/tests/error_translator_test.go @@ -15,8 +15,8 @@ func TestDialectorWithErrorTranslatorSupport(t *testing.T) { db, _ := gorm.Open(tests.DummyDialector{TranslatedErr: translatedErr}) err := db.AddError(untranslatedErr) - if errors.Is(err, translatedErr) { - t.Fatalf("expected err: %v got err: %v", translatedErr, err) + if !errors.Is(err, untranslatedErr) { + t.Fatalf("expected err: %v got err: %v", untranslatedErr, err) } // it should translate error when the TranslateError flag is true From 812bb20c34f0a67e6799269ce5ef635e78c0e0cf Mon Sep 17 00:00:00 2001 From: wangliuyang <54885906+wangliuyang520@users.noreply.github.com> Date: Fri, 26 May 2023 10:24:28 +0800 Subject: [PATCH 159/231] fix(nested transaction): SavePoint SQL Statement not support in Prepared Statements (#6220) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * test: add nested transaction and prepareStmt coexist test case note: please test in the MySQL environment Change-Id: I0db32adc5f74b0d443e98943d3b182236583b959 Signed-off-by: 王柳洋 * fix(nested transaction): SavePoint SQL Statement not support in Prepared Statements 1. SavetPoint SQL Statement not support in Prepared Statements e.g. see mysql8.0 doc: https://dev.mysql.com/doc/refman/8.0/en/sql-prepared-statements.html Change-Id: I082012db9b140e8ec69764c633724665cc802692 Signed-off-by: 王柳洋 * revert(transaction_api): remove savepoint name pool,meaningless Change-Id: I84aa9924fc54612005a81c83d66fdf8968ee56ad Signed-off-by: 王柳洋 --------- Signed-off-by: 王柳洋 Co-authored-by: 王柳洋 --- finisher_api.go | 46 +++++++++++++++++++++++++-------------- tests/transaction_test.go | 13 +++++++++++ 2 files changed, 43 insertions(+), 16 deletions(-) diff --git a/finisher_api.go b/finisher_api.go index ad14e298..6d0b4cd2 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -6,8 +6,6 @@ import ( "fmt" "reflect" "strings" - "sync" - "sync/atomic" "gorm.io/gorm/clause" "gorm.io/gorm/logger" @@ -612,15 +610,6 @@ func (db *DB) Connection(fc func(tx *DB) error) (err error) { return fc(tx) } -var ( - savepointIdx int64 - savepointNamePool = &sync.Pool{ - New: func() interface{} { - return fmt.Sprintf("gorm_%d", atomic.AddInt64(&savepointIdx, 1)) - }, - } -) - // Transaction start a transaction as a block, return error will rollback, otherwise to commit. Transaction executes an // arbitrary number of commands in fc within a transaction. On success the changes are committed; if an error occurs // they are rolled back. @@ -630,17 +619,14 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil { // nested transaction if !db.DisableNestedTransaction { - poolName := savepointNamePool.Get() - defer savepointNamePool.Put(poolName) - err = db.SavePoint(poolName.(string)).Error + err = db.SavePoint(fmt.Sprintf("sp%p", fc)).Error if err != nil { return } - defer func() { // Make sure to rollback when panic, Block error or Commit error if panicked || err != nil { - db.RollbackTo(poolName.(string)) + db.RollbackTo(fmt.Sprintf("sp%p", fc)) } }() } @@ -721,7 +707,21 @@ func (db *DB) Rollback() *DB { func (db *DB) SavePoint(name string) *DB { if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok { + // close prepared statement, because SavePoint not support prepared statement. + // e.g. mysql8.0 doc: https://dev.mysql.com/doc/refman/8.0/en/sql-prepared-statements.html + var ( + preparedStmtTx *PreparedStmtTX + isPreparedStmtTx bool + ) + // close prepared statement, because SavePoint not support prepared statement. + if preparedStmtTx, isPreparedStmtTx = db.Statement.ConnPool.(*PreparedStmtTX); isPreparedStmtTx { + db.Statement.ConnPool = preparedStmtTx.Tx + } db.AddError(savePointer.SavePoint(db, name)) + // restore prepared statement + if isPreparedStmtTx { + db.Statement.ConnPool = preparedStmtTx + } } else { db.AddError(ErrUnsupportedDriver) } @@ -730,7 +730,21 @@ func (db *DB) SavePoint(name string) *DB { func (db *DB) RollbackTo(name string) *DB { if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok { + // close prepared statement, because RollbackTo not support prepared statement. + // e.g. mysql8.0 doc: https://dev.mysql.com/doc/refman/8.0/en/sql-prepared-statements.html + var ( + preparedStmtTx *PreparedStmtTX + isPreparedStmtTx bool + ) + // close prepared statement, because SavePoint not support prepared statement. + if preparedStmtTx, isPreparedStmtTx = db.Statement.ConnPool.(*PreparedStmtTX); isPreparedStmtTx { + db.Statement.ConnPool = preparedStmtTx.Tx + } db.AddError(savePointer.RollbackTo(db, name)) + // restore prepared statement + if isPreparedStmtTx { + db.Statement.ConnPool = preparedStmtTx + } } else { db.AddError(ErrUnsupportedDriver) } diff --git a/tests/transaction_test.go b/tests/transaction_test.go index 5872da94..bfbd8699 100644 --- a/tests/transaction_test.go +++ b/tests/transaction_test.go @@ -57,6 +57,19 @@ func TestTransaction(t *testing.T) { if err := DB.First(&User{}, "name = ?", "transaction-2").Error; err != nil { t.Fatalf("Should be able to find committed record, but got %v", err) } + + t.Run("this is test nested transaction and prepareStmt coexist case", func(t *testing.T) { + // enable prepare statement + tx3 := DB.Session(&gorm.Session{PrepareStmt: true}) + if err := tx3.Transaction(func(tx4 *gorm.DB) error { + // nested transaction + return tx4.Transaction(func(tx5 *gorm.DB) error { + return tx5.First(&User{}, "name = ?", "transaction-2").Error + }) + }); err != nil { + t.Fatalf("prepare statement and nested transcation coexist" + err.Error()) + } + }) } func TestCancelTransaction(t *testing.T) { From 11fdf46a9fcc393e604ea6df22ce0355b2fe1afb Mon Sep 17 00:00:00 2001 From: black-06 Date: Fri, 26 May 2023 10:28:02 +0800 Subject: [PATCH 160/231] fix: save with hook (#6285) (#6294) --- finisher_api.go | 2 +- tests/update_test.go | 73 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 74 insertions(+), 1 deletion(-) diff --git a/finisher_api.go b/finisher_api.go index 6d0b4cd2..f80aa6c0 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -105,7 +105,7 @@ func (db *DB) Save(value interface{}) (tx *DB) { updateTx := tx.callbacks.Update().Execute(tx.Session(&Session{Initialized: true})) if updateTx.Error == nil && updateTx.RowsAffected == 0 && !updateTx.DryRun && !selectedUpdate { - return tx.Clauses(clause.OnConflict{UpdateAll: true}).Create(value) + return tx.Session(&Session{SkipHooks: true}).Clauses(clause.OnConflict{UpdateAll: true}).Create(value) } return updateTx diff --git a/tests/update_test.go b/tests/update_test.go index f7c36d74..c03d2d47 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -809,3 +809,76 @@ func TestUpdateWithDiffSchema(t *testing.T) { AssertEqual(t, err, nil) AssertEqual(t, "update-diff-schema-2", user.Name) } + +type TokenOwner struct { + ID int + Name string + Token Token `gorm:"foreignKey:UserID"` +} + +func (t *TokenOwner) BeforeSave(tx *gorm.DB) error { + t.Name += "_name" + return nil +} + +type Token struct { + UserID int `gorm:"primary_key"` + Content string `gorm:"type:varchar(100)"` +} + +func (t *Token) BeforeSave(tx *gorm.DB) error { + t.Content += "_encrypted" + return nil +} + +func TestSaveWithHooks(t *testing.T) { + DB.Migrator().DropTable(&Token{}, &TokenOwner{}) + DB.AutoMigrate(&Token{}, &TokenOwner{}) + + saveTokenOwner := func(owner *TokenOwner) (*TokenOwner, error) { + var newOwner TokenOwner + if err := DB.Transaction(func(tx *gorm.DB) error { + if err := tx.Debug().Session(&gorm.Session{FullSaveAssociations: true}).Save(owner).Error; err != nil { + return err + } + if err := tx.Preload("Token").First(&newOwner, owner.ID).Error; err != nil { + return err + } + return nil + }); err != nil { + return nil, err + } + return &newOwner, nil + } + + owner := TokenOwner{ + Name: "user", + Token: Token{Content: "token"}, + } + o1, err := saveTokenOwner(&owner) + if err != nil { + t.Errorf("failed to save token owner, got error: %v", err) + } + if o1.Name != "user_name" { + t.Errorf(`owner name should be "user_name", but got: "%s"`, o1.Name) + } + if o1.Token.Content != "token_encrypted" { + t.Errorf(`token content should be "token_encrypted", but got: "%s"`, o1.Token.Content) + } + + owner = TokenOwner{ + ID: owner.ID, + Name: "user", + Token: Token{Content: "token2"}, + } + o2, err := saveTokenOwner(&owner) + if err != nil { + t.Errorf("failed to save token owner, got error: %v", err) + } + if o2.Name != "user_name" { + t.Errorf(`owner name should be "user_name", but got: "%s"`, o2.Name) + } + if o2.Token.Content != "token2_encrypted" { + t.Errorf(`token content should be "token2_encrypted", but got: "%s"`, o2.Token.Content) + } +} From 26663ab9bf55c603ca69a97504a1dfe7d53a1bb3 Mon Sep 17 00:00:00 2001 From: mohammad ali <2018cs92@student.uet.edu.pk> Date: Mon, 29 May 2023 19:00:48 -0700 Subject: [PATCH 161/231] max identifier length changed to 63 (#6337) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * max identifier length changed to 63 * default maxIdentifierLength is 64 * renamed License to LICENSE (#6336) * Added support of "Violates Foreign Key Constraint" (#6329) * Added support of "Violates Foreign Key Constraint" Updated the translator and added the support of "foreign key constraint violation". For this, this error type is needed here. * changed the description of ErrForeignKeyViolated * refactor: error translator test (#6350) Co-authored-by: Saeid Saeidee * fix(nested transaction): SavePoint SQL Statement not support in Prepared Statements (#6220) * test: add nested transaction and prepareStmt coexist test case note: please test in the MySQL environment Change-Id: I0db32adc5f74b0d443e98943d3b182236583b959 Signed-off-by: 王柳洋 * fix(nested transaction): SavePoint SQL Statement not support in Prepared Statements 1. SavetPoint SQL Statement not support in Prepared Statements e.g. see mysql8.0 doc: https://dev.mysql.com/doc/refman/8.0/en/sql-prepared-statements.html Change-Id: I082012db9b140e8ec69764c633724665cc802692 Signed-off-by: 王柳洋 * revert(transaction_api): remove savepoint name pool,meaningless Change-Id: I84aa9924fc54612005a81c83d66fdf8968ee56ad Signed-off-by: 王柳洋 --------- Signed-off-by: 王柳洋 Co-authored-by: 王柳洋 * fix: save with hook (#6285) (#6294) --------- Signed-off-by: 王柳洋 Co-authored-by: Avinaba Bhattacharjee Co-authored-by: Muhammad Amir Ejaz <37077032+codingamir@users.noreply.github.com> Co-authored-by: Saeid Co-authored-by: Saeid Saeidee Co-authored-by: wangliuyang <54885906+wangliuyang520@users.noreply.github.com> Co-authored-by: 王柳洋 Co-authored-by: black-06 --- gorm.go | 2 +- schema/naming.go | 17 +++++++++++------ schema/naming_test.go | 11 ++++++++++- schema/relationship_test.go | 2 +- 4 files changed, 23 insertions(+), 9 deletions(-) diff --git a/gorm.go b/gorm.go index 07a913fc..84d4b433 100644 --- a/gorm.go +++ b/gorm.go @@ -146,7 +146,7 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) { } if config.NamingStrategy == nil { - config.NamingStrategy = schema.NamingStrategy{} + config.NamingStrategy = schema.NamingStrategy{IdentifierMaxLength: 64} // Default Identifier length is 64 } if config.Logger == nil { diff --git a/schema/naming.go b/schema/naming.go index a258beed..a2a0150a 100644 --- a/schema/naming.go +++ b/schema/naming.go @@ -28,10 +28,11 @@ type Replacer interface { // NamingStrategy tables, columns naming strategy type NamingStrategy struct { - TablePrefix string - SingularTable bool - NameReplacer Replacer - NoLowerCase bool + TablePrefix string + SingularTable bool + NameReplacer Replacer + NoLowerCase bool + IdentifierMaxLength int } // TableName convert string to table name @@ -89,12 +90,16 @@ func (ns NamingStrategy) formatName(prefix, table, name string) string { prefix, table, name, }, "_"), ".", "_") - if utf8.RuneCountInString(formattedName) > 64 { + if ns.IdentifierMaxLength == 0 { + ns.IdentifierMaxLength = 64 + } + + if utf8.RuneCountInString(formattedName) > ns.IdentifierMaxLength { h := sha1.New() h.Write([]byte(formattedName)) bs := h.Sum(nil) - formattedName = formattedName[0:56] + hex.EncodeToString(bs)[:8] + formattedName = formattedName[0:ns.IdentifierMaxLength-8] + hex.EncodeToString(bs)[:8] } return formattedName } diff --git a/schema/naming_test.go b/schema/naming_test.go index 3f598c33..ab7a5e31 100644 --- a/schema/naming_test.go +++ b/schema/naming_test.go @@ -189,8 +189,17 @@ func TestCustomReplacerWithNoLowerCase(t *testing.T) { } } +func TestFormatNameWithStringLongerThan63Characters(t *testing.T) { + ns := NamingStrategy{IdentifierMaxLength: 63} + + formattedName := ns.formatName("prefix", "table", "thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVeryLongString") + if formattedName != "prefix_table_thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVer180f2c67" { + t.Errorf("invalid formatted name generated, got %v", formattedName) + } +} + func TestFormatNameWithStringLongerThan64Characters(t *testing.T) { - ns := NamingStrategy{} + ns := NamingStrategy{IdentifierMaxLength: 64} formattedName := ns.formatName("prefix", "table", "thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVeryLongString") if formattedName != "prefix_table_thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVery180f2c67" { diff --git a/schema/relationship_test.go b/schema/relationship_test.go index 732f6f75..1eb66bb4 100644 --- a/schema/relationship_test.go +++ b/schema/relationship_test.go @@ -768,7 +768,7 @@ func TestParseConstraintNameWithSchemaQualifiedLongTableName(t *testing.T) { s, err := schema.Parse( &Book{}, &sync.Map{}, - schema.NamingStrategy{}, + schema.NamingStrategy{IdentifierMaxLength: 64}, ) if err != nil { t.Fatalf("Failed to parse schema") From 740f2be4535f7156752a5e6ce4bc94db59948a10 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=9C=E6=96=B9=E4=B8=8A=E4=BA=BA?= Date: Wed, 31 May 2023 19:21:51 +0800 Subject: [PATCH 162/231] fix: begin transaction fail, rollback panic (#6365) --- prepare_stmt.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/prepare_stmt.go b/prepare_stmt.go index e09fe814..4b3551c6 100644 --- a/prepare_stmt.go +++ b/prepare_stmt.go @@ -3,6 +3,7 @@ package gorm import ( "context" "database/sql" + "reflect" "sync" ) @@ -163,14 +164,14 @@ type PreparedStmtTX struct { } func (tx *PreparedStmtTX) Commit() error { - if tx.Tx != nil { + if tx.Tx != nil && !reflect.ValueOf(tx.Tx).IsNil() { return tx.Tx.Commit() } return ErrInvalidTransaction } func (tx *PreparedStmtTX) Rollback() error { - if tx.Tx != nil { + if tx.Tx != nil && !reflect.ValueOf(tx.Tx).IsNil() { return tx.Tx.Rollback() } return ErrInvalidTransaction From c1ea73036715018a1bb55cdb8690441044e13a76 Mon Sep 17 00:00:00 2001 From: black Date: Thu, 1 Jun 2023 10:54:36 +0800 Subject: [PATCH 163/231] fix: avoid panic when open fails --- gorm.go | 2 +- tests/gorm_test.go | 10 ++++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/gorm.go b/gorm.go index 84d4b433..46d1843d 100644 --- a/gorm.go +++ b/gorm.go @@ -181,7 +181,7 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) { err = config.Dialector.Initialize(db) if err != nil { - if db, err := db.DB(); err == nil { + if db, _ := db.DB(); db != nil { _ = db.Close() } } diff --git a/tests/gorm_test.go b/tests/gorm_test.go index 9827465c..4c31b88b 100644 --- a/tests/gorm_test.go +++ b/tests/gorm_test.go @@ -3,9 +3,19 @@ package tests_test import ( "testing" + "gorm.io/driver/mysql" + "gorm.io/gorm" ) +func TestOpen(t *testing.T) { + dsn := "gorm:gorm@tcp(localhost:9910)/gorm?loc=Asia%2FHongKong" // invalid loc + _, err := gorm.Open(mysql.Open(dsn), &gorm.Config{}) + if err == nil { + t.Fatalf("should returns error but got nil") + } +} + func TestReturningWithNullToZeroValues(t *testing.T) { dialect := DB.Dialector.Name() switch dialect { From 7a76c042e67a22e226179864dd0bdfedce0d53b9 Mon Sep 17 00:00:00 2001 From: Lev Zakharov Date: Mon, 5 Jun 2023 11:23:17 +0300 Subject: [PATCH 164/231] refactor: remove unnecessary prepared statement allocation (#6374) --- gorm.go | 48 ++++++++++++++++++++++++------------------------ prepare_stmt.go | 9 +++++++++ 2 files changed, 33 insertions(+), 24 deletions(-) diff --git a/gorm.go b/gorm.go index 46d1843d..ecdb700b 100644 --- a/gorm.go +++ b/gorm.go @@ -187,15 +187,9 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) { } } - preparedStmt := &PreparedStmtDB{ - ConnPool: db.ConnPool, - Stmts: make(map[string]*Stmt), - Mux: &sync.RWMutex{}, - PreparedSQL: make([]string, 0, 100), - } - db.cacheStore.Store(preparedStmtDBKey, preparedStmt) - if config.PrepareStmt { + preparedStmt := NewPreparedStmtDB(db.ConnPool) + db.cacheStore.Store(preparedStmtDBKey, preparedStmt) db.ConnPool = preparedStmt } @@ -256,24 +250,30 @@ func (db *DB) Session(config *Session) *DB { } if config.PrepareStmt { + var preparedStmt *PreparedStmtDB + if v, ok := db.cacheStore.Load(preparedStmtDBKey); ok { - preparedStmt := v.(*PreparedStmtDB) - switch t := tx.Statement.ConnPool.(type) { - case Tx: - tx.Statement.ConnPool = &PreparedStmtTX{ - Tx: t, - PreparedStmtDB: preparedStmt, - } - default: - tx.Statement.ConnPool = &PreparedStmtDB{ - ConnPool: db.Config.ConnPool, - Mux: preparedStmt.Mux, - Stmts: preparedStmt.Stmts, - } - } - txConfig.ConnPool = tx.Statement.ConnPool - txConfig.PrepareStmt = true + preparedStmt = v.(*PreparedStmtDB) + } else { + preparedStmt = NewPreparedStmtDB(db.ConnPool) + db.cacheStore.Store(preparedStmtDBKey, preparedStmt) } + + switch t := tx.Statement.ConnPool.(type) { + case Tx: + tx.Statement.ConnPool = &PreparedStmtTX{ + Tx: t, + PreparedStmtDB: preparedStmt, + } + default: + tx.Statement.ConnPool = &PreparedStmtDB{ + ConnPool: db.Config.ConnPool, + Mux: preparedStmt.Mux, + Stmts: preparedStmt.Stmts, + } + } + txConfig.ConnPool = tx.Statement.ConnPool + txConfig.PrepareStmt = true } if config.SkipHooks { diff --git a/prepare_stmt.go b/prepare_stmt.go index 4b3551c6..10fefc31 100644 --- a/prepare_stmt.go +++ b/prepare_stmt.go @@ -21,6 +21,15 @@ type PreparedStmtDB struct { ConnPool } +func NewPreparedStmtDB(connPool ConnPool) *PreparedStmtDB { + return &PreparedStmtDB{ + ConnPool: connPool, + Stmts: make(map[string]*Stmt), + Mux: &sync.RWMutex{}, + PreparedSQL: make([]string, 0, 100), + } +} + func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) { if dbConnector, ok := db.ConnPool.(GetDBConnector); ok && dbConnector != nil { return dbConnector.GetDBConn() From 5eaccaa624773441da797f93950117a44863df12 Mon Sep 17 00:00:00 2001 From: KantaHasegawa <66783124+KantaHasegawa@users.noreply.github.com> Date: Mon, 5 Jun 2023 17:24:00 +0900 Subject: [PATCH 165/231] reafactor: add nil detection when sqldb return (#6373) * reafactor: add null detection when sqldb return * refactor: Detecting nil in dbConnector.GetDBConn() * refactor: Revert partial code from c1ea73036715018a1bb55cdb8690441044e13a76 * fix: fix if statement --- gorm.go | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/gorm.go b/gorm.go index ecdb700b..21b289db 100644 --- a/gorm.go +++ b/gorm.go @@ -181,7 +181,7 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) { err = config.Dialector.Initialize(db) if err != nil { - if db, _ := db.DB(); db != nil { + if db, err := db.DB(); err == nil { _ = db.Close() } } @@ -376,10 +376,12 @@ func (db *DB) DB() (*sql.DB, error) { connPool := db.ConnPool if dbConnector, ok := connPool.(GetDBConnector); ok && dbConnector != nil { - return dbConnector.GetDBConn() + if sqldb, err := dbConnector.GetDBConn(); sqldb != nil || err != nil { + return sqldb, err + } } - if sqldb, ok := connPool.(*sql.DB); ok { + if sqldb, ok := connPool.(*sql.DB); ok && sqldb != nil { return sqldb, nil } From 661781a3d7a36b33d8a30d97499d7b428939a9ec Mon Sep 17 00:00:00 2001 From: Lev Zakharov Date: Mon, 5 Jun 2023 11:25:05 +0300 Subject: [PATCH 166/231] feat: add *sql.DB connector that uses database context (#6366) * feat: add SQLConnector * rename --- gorm.go | 4 ++++ interfaces.go | 6 ++++++ 2 files changed, 10 insertions(+) diff --git a/gorm.go b/gorm.go index 21b289db..9297850e 100644 --- a/gorm.go +++ b/gorm.go @@ -375,6 +375,10 @@ func (db *DB) AddError(err error) error { func (db *DB) DB() (*sql.DB, error) { connPool := db.ConnPool + if connector, ok := connPool.(GetDBConnectorWithContext); ok && connector != nil { + return connector.GetDBConnWithContext(db) + } + if dbConnector, ok := connPool.(GetDBConnector); ok && dbConnector != nil { if sqldb, err := dbConnector.GetDBConn(); sqldb != nil || err != nil { return sqldb, err diff --git a/interfaces.go b/interfaces.go index 3bcc3d57..1950d740 100644 --- a/interfaces.go +++ b/interfaces.go @@ -77,6 +77,12 @@ type GetDBConnector interface { GetDBConn() (*sql.DB, error) } +// GetDBConnectorWithContext represents SQL db connector which takes into +// account the current database context +type GetDBConnectorWithContext interface { + GetDBConnWithContext(db *DB) (*sql.DB, error) +} + // Rows rows interface type Rows interface { Columns() ([]string, error) From 7157b7e375c58851da2245c9f4a2b496f3b0ae71 Mon Sep 17 00:00:00 2001 From: Nuno Cruces Date: Wed, 7 Jun 2023 08:02:07 +0100 Subject: [PATCH 167/231] fix: database/sql.Scanner should not retain references (#6380) --- tests/scanner_valuer_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/scanner_valuer_test.go b/tests/scanner_valuer_test.go index 14121699..472434b4 100644 --- a/tests/scanner_valuer_test.go +++ b/tests/scanner_valuer_test.go @@ -170,10 +170,10 @@ func (data *EncryptedData) Scan(value interface{}) error { return errors.New("Too short") } - *data = b[3:] + *data = append((*data)[0:], b[3:]...) return nil } else if s, ok := value.(string); ok { - *data = []byte(s)[3:] + *data = []byte(s[3:]) return nil } From 7dd702d379bca5e28a2b65278b5a373f326c72ca Mon Sep 17 00:00:00 2001 From: Johannes Riecken Date: Wed, 7 Jun 2023 09:02:30 +0200 Subject: [PATCH 168/231] Fix incorrect documentation comment (has many -> has one) (#6382) --- utils/tests/models.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils/tests/models.go b/utils/tests/models.go index ec1651a3..a4bad2fc 100644 --- a/utils/tests/models.go +++ b/utils/tests/models.go @@ -11,7 +11,7 @@ import ( // He works in a Company (belongs to), he has a Manager (belongs to - single-table), and also managed a Team (has many - single-table) // He speaks many languages (many to many) and has many friends (many to many - single-table) // His pet also has one Toy (has one - polymorphic) -// NamedPet is a reference to a Named `Pets` (has many) +// NamedPet is a reference to a named `Pet` (has one) type User struct { gorm.Model Name string From c2d571cbc8ba7925a7438181cc3dcca10a89f81f Mon Sep 17 00:00:00 2001 From: Saeid Date: Sat, 10 Jun 2023 15:05:19 +0200 Subject: [PATCH 169/231] test: coverage for duplicated key err (#6389) * test: ErrDuplicatedKey coverage added * test: updated sqlserver version * test: removed sqlserver * test: support added for sqlserver --------- Co-authored-by: Saeid Saeidee --- tests/associations_many2many_test.go | 2 +- tests/error_translator_test.go | 31 ++++++++++++++++++++++++++++ tests/go.mod | 5 ++--- tests/prepared_stmt_test.go | 4 ++-- tests/tests_test.go | 14 ++++++------- tests/transaction_test.go | 2 +- 6 files changed, 44 insertions(+), 14 deletions(-) diff --git a/tests/associations_many2many_test.go b/tests/associations_many2many_test.go index b69d668a..39410aed 100644 --- a/tests/associations_many2many_test.go +++ b/tests/associations_many2many_test.go @@ -358,7 +358,7 @@ func TestDuplicateMany2ManyAssociation(t *testing.T) { } func TestConcurrentMany2ManyAssociation(t *testing.T) { - db, err := OpenTestConnection() + db, err := OpenTestConnection(&gorm.Config{}) if err != nil { t.Fatalf("open test connection failed, err: %+v", err) } diff --git a/tests/error_translator_test.go b/tests/error_translator_test.go index ca985a09..f6c70677 100644 --- a/tests/error_translator_test.go +++ b/tests/error_translator_test.go @@ -27,3 +27,34 @@ func TestDialectorWithErrorTranslatorSupport(t *testing.T) { t.Fatalf("expected err: %v got err: %v", translatedErr, err) } } + +func TestSupportedDialectorWithErrDuplicatedKey(t *testing.T) { + type City struct { + gorm.Model + Name string `gorm:"unique"` + } + + db, err := OpenTestConnection(&gorm.Config{TranslateError: true}) + if err != nil { + t.Fatalf("failed to connect database, got error %v", err) + } + + dialectors := map[string]bool{"sqlite": true, "postgres": true, "mysql": true, "sqlserver": true} + if supported, found := dialectors[db.Dialector.Name()]; !(found && supported) { + return + } + + if err = db.AutoMigrate(&City{}); err != nil { + t.Fatalf("failed to migrate cities table, got error: %v", err) + } + + err = db.Create(&City{Name: "Kabul"}).Error + if err != nil { + t.Fatalf("failed to create record: %v", err) + } + + err = db.Create(&City{Name: "Kabul"}).Error + if !errors.Is(err, gorm.ErrDuplicatedKey) { + t.Fatalf("expected err: %v got err: %v", gorm.ErrDuplicatedKey, err) + } +} diff --git a/tests/go.mod b/tests/go.mod index f47d175f..0b38b9d0 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -8,12 +8,11 @@ require ( github.com/jinzhu/now v1.1.5 github.com/lib/pq v1.10.8 github.com/mattn/go-sqlite3 v1.14.16 // indirect - golang.org/x/crypto v0.8.0 // indirect gorm.io/driver/mysql v1.5.0 gorm.io/driver/postgres v1.5.0 gorm.io/driver/sqlite v1.5.0 - gorm.io/driver/sqlserver v1.4.3 - gorm.io/gorm v1.25.0 + gorm.io/driver/sqlserver v1.5.1 + gorm.io/gorm v1.25.1 ) replace gorm.io/gorm => ../ diff --git a/tests/prepared_stmt_test.go b/tests/prepared_stmt_test.go index 64baa01b..b234c8bf 100644 --- a/tests/prepared_stmt_test.go +++ b/tests/prepared_stmt_test.go @@ -92,7 +92,7 @@ func TestPreparedStmtFromTransaction(t *testing.T) { } func TestPreparedStmtDeadlock(t *testing.T) { - tx, err := OpenTestConnection() + tx, err := OpenTestConnection(&gorm.Config{}) AssertEqual(t, err, nil) sqlDB, _ := tx.DB() @@ -127,7 +127,7 @@ func TestPreparedStmtDeadlock(t *testing.T) { } func TestPreparedStmtError(t *testing.T) { - tx, err := OpenTestConnection() + tx, err := OpenTestConnection(&gorm.Config{}) AssertEqual(t, err, nil) sqlDB, _ := tx.DB() diff --git a/tests/tests_test.go b/tests/tests_test.go index 90eb847f..0167d406 100644 --- a/tests/tests_test.go +++ b/tests/tests_test.go @@ -26,7 +26,7 @@ var ( func init() { var err error - if DB, err = OpenTestConnection(); err != nil { + if DB, err = OpenTestConnection(&gorm.Config{}); err != nil { log.Printf("failed to connect database, got error %v", err) os.Exit(1) } else { @@ -49,7 +49,7 @@ func init() { } } -func OpenTestConnection() (db *gorm.DB, err error) { +func OpenTestConnection(cfg *gorm.Config) (db *gorm.DB, err error) { dbDSN := os.Getenv("GORM_DSN") switch os.Getenv("GORM_DIALECT") { case "mysql": @@ -57,7 +57,7 @@ func OpenTestConnection() (db *gorm.DB, err error) { if dbDSN == "" { dbDSN = mysqlDSN } - db, err = gorm.Open(mysql.Open(dbDSN), &gorm.Config{}) + db, err = gorm.Open(mysql.Open(dbDSN), cfg) case "postgres": log.Println("testing postgres...") if dbDSN == "" { @@ -66,7 +66,7 @@ func OpenTestConnection() (db *gorm.DB, err error) { db, err = gorm.Open(postgres.New(postgres.Config{ DSN: dbDSN, PreferSimpleProtocol: true, - }), &gorm.Config{}) + }), cfg) case "sqlserver": // go install github.com/microsoft/go-sqlcmd/cmd/sqlcmd@latest // SQLCMDPASSWORD=LoremIpsum86 sqlcmd -U sa -S localhost:9930 @@ -80,16 +80,16 @@ func OpenTestConnection() (db *gorm.DB, err error) { if dbDSN == "" { dbDSN = sqlserverDSN } - db, err = gorm.Open(sqlserver.Open(dbDSN), &gorm.Config{}) + db, err = gorm.Open(sqlserver.Open(dbDSN), cfg) case "tidb": log.Println("testing tidb...") if dbDSN == "" { dbDSN = tidbDSN } - db, err = gorm.Open(mysql.Open(dbDSN), &gorm.Config{}) + db, err = gorm.Open(mysql.Open(dbDSN), cfg) default: log.Println("testing sqlite3...") - db, err = gorm.Open(sqlite.Open(filepath.Join(os.TempDir(), "gorm.db")), &gorm.Config{}) + db, err = gorm.Open(sqlite.Open(filepath.Join(os.TempDir(), "gorm.db")), cfg) } if err != nil { diff --git a/tests/transaction_test.go b/tests/transaction_test.go index bfbd8699..126ccb23 100644 --- a/tests/transaction_test.go +++ b/tests/transaction_test.go @@ -361,7 +361,7 @@ func TestDisabledNestedTransaction(t *testing.T) { } func TestTransactionOnClosedConn(t *testing.T) { - DB, err := OpenTestConnection() + DB, err := OpenTestConnection(&gorm.Config{}) if err != nil { t.Fatalf("failed to connect database, got error %v", err) } From 206613868439c5ee7e62e116a46503eddf55a548 Mon Sep 17 00:00:00 2001 From: Saeid Date: Sun, 11 Jun 2023 01:42:18 +0200 Subject: [PATCH 170/231] ci: fix mariadb mysqladmin (#6401) Co-authored-by: Saeid Saeidee --- .github/workflows/tests.yml | 46 +++++++++++++++++++++++++++++++++++-- 1 file changed, 44 insertions(+), 2 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index bf225d42..1191a8ea 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -41,7 +41,7 @@ jobs: mysql: strategy: matrix: - dbversion: ['mysql:latest', 'mysql:5.7', 'mariadb:latest'] + dbversion: ['mysql:latest', 'mysql:5.7'] go: ['1.19', '1.18'] platform: [ubuntu-latest] runs-on: ${{ matrix.platform }} @@ -72,7 +72,6 @@ jobs: - name: Check out code into the Go module directory uses: actions/checkout@v3 - - name: go mod package cache uses: actions/cache@v3 with: @@ -82,6 +81,49 @@ jobs: - name: Tests run: GITHUB_ACTION=true GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True" ./tests/tests_all.sh + mariadb: + strategy: + matrix: + dbversion: [ 'mariadb:latest' ] + go: [ '1.19', '1.18' ] + platform: [ ubuntu-latest ] + runs-on: ${{ matrix.platform }} + + services: + mysql: + image: ${{ matrix.dbversion }} + env: + MYSQL_DATABASE: gorm + MYSQL_USER: gorm + MYSQL_PASSWORD: gorm + MYSQL_RANDOM_ROOT_PASSWORD: "yes" + ports: + - 9910:3306 + options: >- + --health-cmd "mariadb-admin ping -ugorm -pgorm" + --health-interval 10s + --health-start-period 10s + --health-timeout 5s + --health-retries 10 + + steps: + - name: Set up Go 1.x + uses: actions/setup-go@v4 + with: + go-version: ${{ matrix.go }} + + - name: Check out code into the Go module directory + uses: actions/checkout@v3 + + - name: go mod package cache + uses: actions/cache@v3 + with: + path: ~/go/pkg/mod + key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} + + - name: Tests + run: GITHUB_ACTION=true GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True" ./tests/tests_all.sh + postgres: strategy: matrix: From c10f807d3c0b2b3204c80d62cf7650e21d7e3316 Mon Sep 17 00:00:00 2001 From: Saeid Date: Wed, 12 Jul 2023 16:21:22 +0300 Subject: [PATCH 171/231] test: coverage for foreign key violation err (#6403) * test: coverage for foreign key violation err * test: enabled foreign keys constraint for sqlite * test: enabled mysql& mssql for ErrForeignKeyViolate * test: disabled mysql & updated sqlserver driver version * test: skipped tidb --------- Co-authored-by: Saeid Saeidee --- tests/error_translator_test.go | 51 ++++++++++++++++++++++++++++++++++ tests/go.mod | 14 ++++------ tests/tests_test.go | 2 +- 3 files changed, 58 insertions(+), 9 deletions(-) diff --git a/tests/error_translator_test.go b/tests/error_translator_test.go index f6c70677..ee54300e 100644 --- a/tests/error_translator_test.go +++ b/tests/error_translator_test.go @@ -44,6 +44,8 @@ func TestSupportedDialectorWithErrDuplicatedKey(t *testing.T) { return } + DB.Migrator().DropTable(&City{}) + if err = db.AutoMigrate(&City{}); err != nil { t.Fatalf("failed to migrate cities table, got error: %v", err) } @@ -58,3 +60,52 @@ func TestSupportedDialectorWithErrDuplicatedKey(t *testing.T) { t.Fatalf("expected err: %v got err: %v", gorm.ErrDuplicatedKey, err) } } + +func TestSupportedDialectorWithErrForeignKeyViolated(t *testing.T) { + tidbSkip(t, "not support the foreign key feature") + + type City struct { + gorm.Model + Name string `gorm:"unique"` + } + + type Museum struct { + gorm.Model + Name string `gorm:"unique"` + CityID uint + City City `gorm:"Constraint:OnUpdate:CASCADE,OnDelete:CASCADE;FOREIGNKEY:CityID;References:ID"` + } + + db, err := OpenTestConnection(&gorm.Config{TranslateError: true}) + if err != nil { + t.Fatalf("failed to connect database, got error %v", err) + } + + dialectors := map[string]bool{"sqlite": true, "postgres": true, "mysql": true, "sqlserver": true} + if supported, found := dialectors[db.Dialector.Name()]; !(found && supported) { + return + } + + DB.Migrator().DropTable(&City{}, &Museum{}) + + if err = db.AutoMigrate(&City{}, &Museum{}); err != nil { + t.Fatalf("failed to migrate countries & cities tables, got error: %v", err) + } + + city := City{Name: "Amsterdam"} + + err = db.Create(&city).Error + if err != nil { + t.Fatalf("failed to create city: %v", err) + } + + err = db.Create(&Museum{Name: "Eye Filmmuseum", CityID: city.ID}).Error + if err != nil { + t.Fatalf("failed to create museum: %v", err) + } + + err = db.Create(&Museum{Name: "Dungeon", CityID: 123}).Error + if !errors.Is(err, gorm.ErrForeignKeyViolated) { + t.Fatalf("expected err: %v got err: %v", gorm.ErrForeignKeyViolated, err) + } +} diff --git a/tests/go.mod b/tests/go.mod index 0b38b9d0..aebe5a06 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -4,15 +4,13 @@ go 1.16 require ( github.com/google/uuid v1.3.0 - github.com/jackc/pgx/v5 v5.3.1 // indirect github.com/jinzhu/now v1.1.5 - github.com/lib/pq v1.10.8 - github.com/mattn/go-sqlite3 v1.14.16 // indirect - gorm.io/driver/mysql v1.5.0 - gorm.io/driver/postgres v1.5.0 - gorm.io/driver/sqlite v1.5.0 - gorm.io/driver/sqlserver v1.5.1 - gorm.io/gorm v1.25.1 + github.com/lib/pq v1.10.9 + gorm.io/driver/mysql v1.5.2-0.20230612053416-48b6526a21f0 + gorm.io/driver/postgres v1.5.3-0.20230607070428-18bc84b75196 + gorm.io/driver/sqlite v1.5.2 + gorm.io/driver/sqlserver v1.5.2-0.20230613072041-6e2cde390b0a + gorm.io/gorm v1.25.2-0.20230610234218-206613868439 ) replace gorm.io/gorm => ../ diff --git a/tests/tests_test.go b/tests/tests_test.go index 0167d406..47c2a7c1 100644 --- a/tests/tests_test.go +++ b/tests/tests_test.go @@ -89,7 +89,7 @@ func OpenTestConnection(cfg *gorm.Config) (db *gorm.DB, err error) { db, err = gorm.Open(mysql.Open(dbDSN), cfg) default: log.Println("testing sqlite3...") - db, err = gorm.Open(sqlite.Open(filepath.Join(os.TempDir(), "gorm.db")), cfg) + db, err = gorm.Open(sqlite.Open(filepath.Join(os.TempDir(), "gorm.db?_foreign_keys=on")), cfg) } if err != nil { From a7f01bd1b22ec7131c420de62abe5f7e85573277 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 25 Jul 2023 10:47:19 +0800 Subject: [PATCH 172/231] Test Pluck with customized type --- tests/go.mod | 18 ++++++++++++++++-- tests/query_test.go | 24 ++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 2 deletions(-) diff --git a/tests/go.mod b/tests/go.mod index aebe5a06..147d0a79 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -1,6 +1,6 @@ module gorm.io/gorm/tests -go 1.16 +go 1.18 require ( github.com/google/uuid v1.3.0 @@ -10,7 +10,21 @@ require ( gorm.io/driver/postgres v1.5.3-0.20230607070428-18bc84b75196 gorm.io/driver/sqlite v1.5.2 gorm.io/driver/sqlserver v1.5.2-0.20230613072041-6e2cde390b0a - gorm.io/gorm v1.25.2-0.20230610234218-206613868439 + gorm.io/gorm v1.25.2 +) + +require ( + github.com/go-sql-driver/mysql v1.7.1 // indirect + github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect + github.com/golang-sql/sqlexp v0.1.0 // indirect + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect + github.com/jackc/pgx/v5 v5.4.2 // indirect + github.com/jinzhu/inflection v1.0.0 // indirect + github.com/mattn/go-sqlite3 v1.14.17 // indirect + github.com/microsoft/go-mssqldb v1.4.0 // indirect + golang.org/x/crypto v0.11.0 // indirect + golang.org/x/text v0.11.0 // indirect ) replace gorm.io/gorm => ../ diff --git a/tests/query_test.go b/tests/query_test.go index b6bd0736..5728378d 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -2,6 +2,7 @@ package tests_test import ( "database/sql" + "database/sql/driver" "fmt" "reflect" "regexp" @@ -658,6 +659,18 @@ func TestOrWithAllFields(t *testing.T) { } } +type Int64 int64 + +func (v Int64) Value() (driver.Value, error) { + return v - 1, nil +} + +func (f *Int64) Scan(v interface{}) error { + y := v.(int64) + *f = Int64(y + 1) + return nil +} + func TestPluck(t *testing.T) { users := []*User{ GetUser("pluck-user1", Config{}), @@ -685,6 +698,11 @@ func TestPluck(t *testing.T) { t.Errorf("got error when pluck id: %v", err) } + var ids2 []Int64 + if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Pluck("id", &ids2).Error; err != nil { + t.Errorf("got error when pluck id: %v", err) + } + for idx, name := range names { if name != users[idx].Name { t.Errorf("Unexpected result on pluck name, got %+v", names) @@ -697,6 +715,12 @@ func TestPluck(t *testing.T) { } } + for idx, id := range ids2 { + if int(id) != int(users[idx].ID+1) { + t.Errorf("Unexpected result on pluck id, got %+v", ids) + } + } + var times []time.Time if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Pluck("created_at", ×).Error; err != nil { t.Errorf("got error when pluck time: %v", err) From 1fb26ac90e1959a3fb08a4878c00f26ca5284604 Mon Sep 17 00:00:00 2001 From: Saeid Date: Fri, 4 Aug 2023 04:30:07 +0200 Subject: [PATCH 173/231] test: coverage for tabletype added (#6496) * test: coverage for tabletype added * test: tidb exclueded --------- Co-authored-by: Saeid Saeidee --- tests/helper_test.go | 4 ++++ tests/migrate_test.go | 45 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+) diff --git a/tests/helper_test.go b/tests/helper_test.go index c34e357c..1a4874ee 100644 --- a/tests/helper_test.go +++ b/tests/helper_test.go @@ -265,6 +265,10 @@ func isTiDB() bool { return os.Getenv("GORM_DIALECT") == "tidb" } +func isMysql() bool { + return os.Getenv("GORM_DIALECT") == "mysql" +} + func db(unscoped bool) *gorm.DB { if unscoped { return DB.Unscoped() diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 69f86412..849e2b7b 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -1598,3 +1598,48 @@ func TestMigrateExistingBoolColumnPG(t *testing.T) { } } } + +func TestTableType(t *testing.T) { + // currently it is only supported for mysql driver + if !isMysql() { + return + } + + const tblName = "cities" + const tblSchema = "gorm" + const tblType = "BASE TABLE" + const tblComment = "foobar comment" + + type City struct { + gorm.Model + Name string `gorm:"unique"` + } + + DB.Migrator().DropTable(&City{}) + + if err := DB.Set("gorm:table_options", fmt.Sprintf("ENGINE InnoDB COMMENT '%s'", tblComment)).AutoMigrate(&City{}); err != nil { + t.Fatalf("failed to migrate cities tables, got error: %v", err) + } + + tableType, err := DB.Table("cities").Migrator().TableType(&City{}) + if err != nil { + t.Fatalf("failed to get table type, got error %v", err) + } + + if tableType.Schema() != tblSchema { + t.Fatalf("expected tblSchema to be %s but got %s", tblSchema, tableType.Schema()) + } + + if tableType.Name() != tblName { + t.Fatalf("expected table name to be %s but got %s", tblName, tableType.Name()) + } + + if tableType.Type() != tblType { + t.Fatalf("expected table type to be %s but got %s", tblType, tableType.Type()) + } + + comment, ok := tableType.Comment() + if !ok || comment != tblComment { + t.Fatalf("expected comment %s got %s", tblComment, comment) + } +} From 193c454cf48f3e65a98abc0a09e04fe6f5d49c0a Mon Sep 17 00:00:00 2001 From: San Ye Date: Fri, 4 Aug 2023 10:31:18 +0800 Subject: [PATCH 174/231] keep float precision in ExplainSQL (#6495) --- logger/sql.go | 6 ++++-- logger/sql_test.go | 21 ++++++++++++++------- 2 files changed, 18 insertions(+), 9 deletions(-) diff --git a/logger/sql.go b/logger/sql.go index bcacc7cf..1521c1fd 100644 --- a/logger/sql.go +++ b/logger/sql.go @@ -93,8 +93,10 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a } case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: vars[idx] = utils.ToString(v) - case float64, float32: - vars[idx] = fmt.Sprintf("%.6f", v) + case float32: + vars[idx] = strconv.FormatFloat(float64(v), 'f', -1, 32) + case float64: + vars[idx] = strconv.FormatFloat(v, 'f', -1, 64) case string: vars[idx] = escaper + strings.ReplaceAll(v, escaper, "\\"+escaper) + escaper default: diff --git a/logger/sql_test.go b/logger/sql_test.go index c5b181a9..e4a72748 100644 --- a/logger/sql_test.go +++ b/logger/sql_test.go @@ -57,44 +57,51 @@ func TestExplainSQL(t *testing.T) { SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", NumericRegexp: nil, Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd}, - Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass")`, + Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass")`, }, { SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", NumericRegexp: nil, Vars: []interface{}{"jinzhu?", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd}, - Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu?", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass")`, + Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu?", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass")`, }, { SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (@p1, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9, @p10, @p11)", NumericRegexp: regexp.MustCompile(`@p(\d+)`), Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.com", myrole, pwd}, - Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`, + Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`, }, { SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ($3, $4, $1, $2, $7, $8, $5, $6, $9, $10, $11)", NumericRegexp: regexp.MustCompile(`\$(\d+)`), Vars: []interface{}{999.99, true, "jinzhu", 1, &tt, nil, []byte("12345"), tt, "w@g.com", myrole, pwd}, - Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`, + Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`, }, { SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (@p1, @p11, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9, @p10)", NumericRegexp: regexp.MustCompile(`@p(\d+)`), Vars: []interface{}{"jinzhu", 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.com", myrole, pwd, 1}, - Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`, + Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`, }, { SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", NumericRegexp: nil, Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, js, es}, - Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)), + Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)), }, { SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", NumericRegexp: nil, Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es}, - Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)), + Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)), }, + { + SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + NumericRegexp: nil, + Vars: []interface{}{"jinzhu", 1, 0.1753607109, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es}, + Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 0.1753607109, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)), + }, + } for idx, r := range results { From f47376181317ccbf08dcdb28b0c0171dc5d61fda Mon Sep 17 00:00:00 2001 From: Aayush Acharya <33954116+aayushacharya@users.noreply.github.com> Date: Fri, 4 Aug 2023 08:20:59 +0545 Subject: [PATCH 175/231] fix: added `SkipHooks` in db `getInstance()` (#6484) --- gorm.go | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/gorm.go b/gorm.go index 9297850e..203527af 100644 --- a/gorm.go +++ b/gorm.go @@ -399,11 +399,12 @@ func (db *DB) getInstance() *DB { if db.clone == 1 { // clone with new statement tx.Statement = &Statement{ - DB: tx, - ConnPool: db.Statement.ConnPool, - Context: db.Statement.Context, - Clauses: map[string]clause.Clause{}, - Vars: make([]interface{}, 0, 8), + DB: tx, + ConnPool: db.Statement.ConnPool, + Context: db.Statement.Context, + Clauses: map[string]clause.Clause{}, + Vars: make([]interface{}, 0, 8), + SkipHooks: db.Statement.SkipHooks, } } else { // with clone statement From 3c34bc2f59fd080dce5e7a829a8f178b3f4de194 Mon Sep 17 00:00:00 2001 From: fayvori <80601865+fayvori@users.noreply.github.com> Date: Mon, 7 Aug 2023 11:35:19 +0300 Subject: [PATCH 176/231] refactor: Regex description (#6507) * Mirror cleanup * Regex description --------- Co-authored-by: Ignat Belousov --- logger/sql.go | 2 ++ logger/sql_test.go | 1 - migrator/migrator.go | 9 +++++++++ tests/go.mod | 6 +++--- 4 files changed, 14 insertions(+), 4 deletions(-) diff --git a/logger/sql.go b/logger/sql.go index 1521c1fd..13e5d957 100644 --- a/logger/sql.go +++ b/logger/sql.go @@ -28,8 +28,10 @@ func isPrintable(s string) bool { return true } +// A list of Go types that should be converted to SQL primitives var convertibleTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeOf(false), reflect.TypeOf([]byte{})} +// RegEx matches only numeric values var numericPlaceholderRe = regexp.MustCompile(`\$\d+\$`) // ExplainSQL generate SQL string with given parameters, the generated SQL is expected to be used in logger, execute it might introduce a SQL injection vulnerability diff --git a/logger/sql_test.go b/logger/sql_test.go index e4a72748..d9afe393 100644 --- a/logger/sql_test.go +++ b/logger/sql_test.go @@ -101,7 +101,6 @@ func TestExplainSQL(t *testing.T) { Vars: []interface{}{"jinzhu", 1, 0.1753607109, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es}, Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 0.1753607109, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)), }, - } for idx, r := range results { diff --git a/migrator/migrator.go b/migrator/migrator.go index de60f91c..b15a43ef 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -16,8 +16,17 @@ import ( "gorm.io/gorm/schema" ) +// This regular expression seeks to find a sequence of digits (\d+) among zero or more non-digit characters (\D*), +// with a possible trailing non-digit character (\D?). + +// For example, values that can pass this regular expression are: +// - "123" +// - "abc456" +// -"%$#@789" var regFullDataType = regexp.MustCompile(`\D*(\d+)\D?`) +// TODO:? Create const vars for raw sql queries ? + // Migrator m struct type Migrator struct { Config diff --git a/tests/go.mod b/tests/go.mod index 147d0a79..7a89ee05 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -22,9 +22,9 @@ require ( github.com/jackc/pgx/v5 v5.4.2 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/mattn/go-sqlite3 v1.14.17 // indirect - github.com/microsoft/go-mssqldb v1.4.0 // indirect - golang.org/x/crypto v0.11.0 // indirect - golang.org/x/text v0.11.0 // indirect + github.com/microsoft/go-mssqldb v1.5.0 // indirect + golang.org/x/crypto v0.12.0 // indirect + golang.org/x/text v0.12.0 // indirect ) replace gorm.io/gorm => ../ From 15162afaf2a1cd1ee8c63ebc0dc14b8baa0613f7 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 10 Aug 2023 13:30:48 +0800 Subject: [PATCH 177/231] Support GetDBConnWithContext PreparedStmtDB --- prepare_stmt.go | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/prepare_stmt.go b/prepare_stmt.go index 10fefc31..9d98c86e 100644 --- a/prepare_stmt.go +++ b/prepare_stmt.go @@ -30,15 +30,19 @@ func NewPreparedStmtDB(connPool ConnPool) *PreparedStmtDB { } } -func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) { - if dbConnector, ok := db.ConnPool.(GetDBConnector); ok && dbConnector != nil { - return dbConnector.GetDBConn() - } - +func (db *PreparedStmtDB) GetDBConnWithContext(gormdb *DB) (*sql.DB, error) { if sqldb, ok := db.ConnPool.(*sql.DB); ok { return sqldb, nil } + if connector, ok := db.ConnPool.(GetDBConnectorWithContext); ok && connector != nil { + return connector.GetDBConnWithContext(gormdb) + } + + if dbConnector, ok := db.ConnPool.(GetDBConnector); ok && dbConnector != nil { + return dbConnector.GetDBConn() + } + return nil, ErrInvalidDB } @@ -54,15 +58,15 @@ func (db *PreparedStmtDB) Close() { } } -func (db *PreparedStmtDB) Reset() { - db.Mux.Lock() - defer db.Mux.Unlock() +func (sdb *PreparedStmtDB) Reset() { + sdb.Mux.Lock() + defer sdb.Mux.Unlock() - for _, stmt := range db.Stmts { + for _, stmt := range sdb.Stmts { go stmt.Close() } - db.PreparedSQL = make([]string, 0, 100) - db.Stmts = make(map[string]*Stmt) + sdb.PreparedSQL = make([]string, 0, 100) + sdb.Stmts = make(map[string]*Stmt) } func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) { From bae684b3639dff3e35d0ed330bc82c12e8282110 Mon Sep 17 00:00:00 2001 From: weih Date: Thu, 10 Aug 2023 13:34:33 +0800 Subject: [PATCH 178/231] fix(clause): when the value of clause.Eq is an empty array, the SQL should be IN (NULL) (#6503) --- clause/expression.go | 16 ++++++++++------ clause/expression_test.go | 5 +++++ 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/clause/expression.go b/clause/expression.go index 92ac7f22..8d010522 100644 --- a/clause/expression.go +++ b/clause/expression.go @@ -246,15 +246,19 @@ func (eq Eq) Build(builder Builder) { switch eq.Value.(type) { case []string, []int, []int32, []int64, []uint, []uint32, []uint64, []interface{}: - builder.WriteString(" IN (") rv := reflect.ValueOf(eq.Value) - for i := 0; i < rv.Len(); i++ { - if i > 0 { - builder.WriteByte(',') + if rv.Len() == 0 { + builder.WriteString(" IN (NULL)") + } else { + builder.WriteString(" IN (") + for i := 0; i < rv.Len(); i++ { + if i > 0 { + builder.WriteByte(',') + } + builder.AddVar(builder, rv.Index(i).Interface()) } - builder.AddVar(builder, rv.Index(i).Interface()) + builder.WriteByte(')') } - builder.WriteByte(')') default: if eqNil(eq.Value) { builder.WriteString(" IS NULL") diff --git a/clause/expression_test.go b/clause/expression_test.go index aaede61c..b997bf11 100644 --- a/clause/expression_test.go +++ b/clause/expression_test.go @@ -199,6 +199,11 @@ func TestExpression(t *testing.T) { }, ExpectedVars: []interface{}{"a", "b"}, Result: "`column-name` NOT IN (?,?)", + }, { + Expressions: []clause.Expression{ + clause.Eq{Column: column, Value: []string{}}, + }, + Result: "`column-name` IN (NULL)", }, { Expressions: []clause.Expression{ clause.Eq{Column: clause.Expr{SQL: "SUM(?)", Vars: []interface{}{clause.Column{Name: "id"}}}, Value: 100}, From fef42941ba87bff8dad5d48b057a2c2056345984 Mon Sep 17 00:00:00 2001 From: qqxhb <30866940+qqxhb@users.noreply.github.com> Date: Sat, 19 Aug 2023 21:33:31 +0800 Subject: [PATCH 179/231] feat: rm GetDBConnWithContext method (#6535) * feat: rm contextconnpool method * feat: nil --- go.sum | 2 -- gorm.go | 9 ++++++--- interfaces.go | 6 ------ prepare_stmt.go | 23 ++++++++++++++++++----- 4 files changed, 24 insertions(+), 16 deletions(-) diff --git a/go.sum b/go.sum index fb4240eb..bd6104c9 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,4 @@ github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= -github.com/jinzhu/now v1.1.4 h1:tHnRBy1i5F2Dh8BAFxqFzxKqqvezXrL2OW1TnX+Mlas= -github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= diff --git a/gorm.go b/gorm.go index 203527af..32193870 100644 --- a/gorm.go +++ b/gorm.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "fmt" + "reflect" "sort" "sync" "time" @@ -374,9 +375,11 @@ func (db *DB) AddError(err error) error { // DB returns `*sql.DB` func (db *DB) DB() (*sql.DB, error) { connPool := db.ConnPool - - if connector, ok := connPool.(GetDBConnectorWithContext); ok && connector != nil { - return connector.GetDBConnWithContext(db) + if db.Statement != nil && db.Statement.ConnPool != nil { + connPool = db.Statement.ConnPool + } + if tx, ok := connPool.(*sql.Tx); ok && tx != nil { + return (*sql.DB)(reflect.ValueOf(tx).Elem().FieldByName("db").UnsafePointer()), nil } if dbConnector, ok := connPool.(GetDBConnector); ok && dbConnector != nil { diff --git a/interfaces.go b/interfaces.go index 1950d740..3bcc3d57 100644 --- a/interfaces.go +++ b/interfaces.go @@ -77,12 +77,6 @@ type GetDBConnector interface { GetDBConn() (*sql.DB, error) } -// GetDBConnectorWithContext represents SQL db connector which takes into -// account the current database context -type GetDBConnectorWithContext interface { - GetDBConnWithContext(db *DB) (*sql.DB, error) -} - // Rows rows interface type Rows interface { Columns() ([]string, error) diff --git a/prepare_stmt.go b/prepare_stmt.go index 9d98c86e..aa944624 100644 --- a/prepare_stmt.go +++ b/prepare_stmt.go @@ -30,15 +30,11 @@ func NewPreparedStmtDB(connPool ConnPool) *PreparedStmtDB { } } -func (db *PreparedStmtDB) GetDBConnWithContext(gormdb *DB) (*sql.DB, error) { +func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) { if sqldb, ok := db.ConnPool.(*sql.DB); ok { return sqldb, nil } - if connector, ok := db.ConnPool.(GetDBConnectorWithContext); ok && connector != nil { - return connector.GetDBConnWithContext(gormdb) - } - if dbConnector, ok := db.ConnPool.(GetDBConnector); ok && dbConnector != nil { return dbConnector.GetDBConn() } @@ -131,6 +127,19 @@ func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (Conn tx, err := beginner.BeginTx(ctx, opt) return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err } + + beginner, ok := db.ConnPool.(ConnPoolBeginner) + if !ok { + return nil, ErrInvalidTransaction + } + + connPool, err := beginner.BeginTx(ctx, opt) + if err != nil { + return nil, err + } + if tx, ok := connPool.(Tx); ok { + return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, nil + } return nil, ErrInvalidTransaction } @@ -176,6 +185,10 @@ type PreparedStmtTX struct { PreparedStmtDB *PreparedStmtDB } +func (db *PreparedStmtTX) GetDBConn() (*sql.DB, error) { + return db.PreparedStmtDB.GetDBConn() +} + func (tx *PreparedStmtTX) Commit() error { if tx.Tx != nil && !reflect.ValueOf(tx.Tx).IsNil() { return tx.Tx.Commit() From 2c2089760c5a35b3884c7a949621ce0e790e7835 Mon Sep 17 00:00:00 2001 From: Heliner <32272517+Heliner@users.noreply.github.com> Date: Sat, 19 Aug 2023 21:33:57 +0800 Subject: [PATCH 180/231] add float32 test case (#6530) --- logger/sql_test.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/logger/sql_test.go b/logger/sql_test.go index d9afe393..a82fa546 100644 --- a/logger/sql_test.go +++ b/logger/sql_test.go @@ -101,6 +101,12 @@ func TestExplainSQL(t *testing.T) { Vars: []interface{}{"jinzhu", 1, 0.1753607109, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es}, Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 0.1753607109, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)), }, + { + SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + NumericRegexp: nil, + Vars: []interface{}{"jinzhu", 1, float32(999.99), true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es}, + Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)), + }, } for idx, r := range results { From 7e44f73ad3b657a86bbdc881787b03c25ab789a5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=BE=9A=E4=B8=80=E6=B6=9B?= Date: Sat, 19 Aug 2023 21:35:14 +0800 Subject: [PATCH 181/231] fix schema GetIdentityFieldValuesMap interface or ptr (#6417) Co-authored-by: uptutu --- schema/utils.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/schema/utils.go b/schema/utils.go index 65d012e5..7fdda185 100644 --- a/schema/utils.go +++ b/schema/utils.go @@ -115,6 +115,11 @@ func GetIdentityFieldValuesMap(ctx context.Context, reflectValue reflect.Value, notZero, zero bool ) + if reflectValue.Kind() == reflect.Ptr || + reflectValue.Kind() == reflect.Interface { + reflectValue = reflectValue.Elem() + } + switch reflectValue.Kind() { case reflect.Struct: results = [][]interface{}{make([]interface{}, len(fields))} From ac07543962994da4c6994ba3907417d7835a2619 Mon Sep 17 00:00:00 2001 From: Rataj Date: Sun, 20 Aug 2023 13:46:56 +0200 Subject: [PATCH 182/231] Fixed error message when dialector fails to initialize (#6509) Let's say we have a problem with DSN which leads to dialector initialize error. However DB connection is not created and for some reason line 184 error provides even though "db" doesn't exist. Previously, this code leads to: panic: runtime error: invalid memory address or nil pointer dereference This fix now doesn't attempt to close non-existant database connection and instead continues, so the proper error is shown. In my case: [error] failed to initialize database, got error default addr for network 'localhost' unknown --- gorm.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gorm.go b/gorm.go index 32193870..775cd3de 100644 --- a/gorm.go +++ b/gorm.go @@ -182,7 +182,7 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) { err = config.Dialector.Initialize(db) if err != nil { - if db, err := db.DB(); err == nil { + if db, _ := db.DB(); db != nil { _ = db.Close() } } From 653732e1c33858f5743a34f9fbfe66428d041760 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 24 Aug 2023 20:19:29 +0800 Subject: [PATCH 183/231] Update go testing versions --- .github/workflows/tests.yml | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 1191a8ea..e98a17d6 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -16,7 +16,7 @@ jobs: sqlite: strategy: matrix: - go: ['1.19', '1.18'] + go: ['1.21', '1.20', '1.19'] platform: [ubuntu-latest] # can not run in windows OS runs-on: ${{ matrix.platform }} @@ -42,7 +42,7 @@ jobs: strategy: matrix: dbversion: ['mysql:latest', 'mysql:5.7'] - go: ['1.19', '1.18'] + go: ['1.21', '1.20', '1.19'] platform: [ubuntu-latest] runs-on: ${{ matrix.platform }} @@ -85,7 +85,7 @@ jobs: strategy: matrix: dbversion: [ 'mariadb:latest' ] - go: [ '1.19', '1.18' ] + go: ['1.21', '1.20', '1.19'] platform: [ ubuntu-latest ] runs-on: ${{ matrix.platform }} @@ -128,7 +128,7 @@ jobs: strategy: matrix: dbversion: ['postgres:latest', 'postgres:13', 'postgres:12', 'postgres:11', 'postgres:10'] - go: ['1.19', '1.18'] + go: ['1.21', '1.20', '1.19'] platform: [ubuntu-latest] # can not run in macOS and Windows runs-on: ${{ matrix.platform }} @@ -170,7 +170,7 @@ jobs: sqlserver: strategy: matrix: - go: ['1.19', '1.18'] + go: ['1.21', '1.20', '1.19'] platform: [ubuntu-latest] # can not run test in macOS and windows runs-on: ${{ matrix.platform }} @@ -214,7 +214,7 @@ jobs: strategy: matrix: dbversion: [ 'v6.5.0' ] - go: [ '1.19', '1.18' ] + go: ['1.21', '1.20', '1.19'] platform: [ ubuntu-latest ] runs-on: ${{ matrix.platform }} From e57e5d8884d801caa4ce0307bcd081f7e889e514 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 27 Aug 2023 15:40:54 +0800 Subject: [PATCH 184/231] Update go.mod --- go.mod | 2 +- tests/go.mod | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/go.mod b/go.mod index 85e4242a..deb61b74 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module gorm.io/gorm -go 1.16 +go 1.18 require ( github.com/jinzhu/inflection v1.0.0 diff --git a/tests/go.mod b/tests/go.mod index 7a89ee05..aef26e3e 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -3,14 +3,14 @@ module gorm.io/gorm/tests go 1.18 require ( - github.com/google/uuid v1.3.0 + github.com/google/uuid v1.3.1 github.com/jinzhu/now v1.1.5 github.com/lib/pq v1.10.9 gorm.io/driver/mysql v1.5.2-0.20230612053416-48b6526a21f0 gorm.io/driver/postgres v1.5.3-0.20230607070428-18bc84b75196 - gorm.io/driver/sqlite v1.5.2 + gorm.io/driver/sqlite v1.5.3 gorm.io/driver/sqlserver v1.5.2-0.20230613072041-6e2cde390b0a - gorm.io/gorm v1.25.2 + gorm.io/gorm v1.25.4 ) require ( @@ -19,7 +19,7 @@ require ( github.com/golang-sql/sqlexp v0.1.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect - github.com/jackc/pgx/v5 v5.4.2 // indirect + github.com/jackc/pgx/v5 v5.4.3 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/mattn/go-sqlite3 v1.14.17 // indirect github.com/microsoft/go-mssqldb v1.5.0 // indirect From 2095d42b4c15de8d0cdaf64fd75e306bec40d9c4 Mon Sep 17 00:00:00 2001 From: Samuel N Cui Date: Mon, 9 Oct 2023 17:26:27 +0800 Subject: [PATCH 185/231] fix: sqlite dialector cannot apply `PRIMARY KEY AUTOINCREMENT` type (#6624) * fix: sqlite dialector cannot apply `PRIMARY KEY AUTOINCREMENT` type fix #4760 * feat: add auto increment test * feat: update sqlite * feat: update tests deps sqlite to v1.5.4 --- migrator/migrator.go | 2 +- tests/go.mod | 8 ++++---- tests/migrate_test.go | 42 ++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 47 insertions(+), 5 deletions(-) diff --git a/migrator/migrator.go b/migrator/migrator.go index b15a43ef..49bc9371 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -217,7 +217,7 @@ func (m Migrator) CreateTable(values ...interface{}) error { field := stmt.Schema.FieldsByDBName[dbName] if !field.IgnoreMigration { createTableSQL += "? ?" - hasPrimaryKeyInDataType = hasPrimaryKeyInDataType || strings.Contains(strings.ToUpper(string(field.DataType)), "PRIMARY KEY") + hasPrimaryKeyInDataType = hasPrimaryKeyInDataType || strings.Contains(strings.ToUpper(m.DataTypeOf(field)), "PRIMARY KEY") values = append(values, clause.Column{Name: dbName}, m.DB.Migrator().FullDataTypeOf(field)) createTableSQL += "," } diff --git a/tests/go.mod b/tests/go.mod index aef26e3e..5a0aeddd 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -8,7 +8,7 @@ require ( github.com/lib/pq v1.10.9 gorm.io/driver/mysql v1.5.2-0.20230612053416-48b6526a21f0 gorm.io/driver/postgres v1.5.3-0.20230607070428-18bc84b75196 - gorm.io/driver/sqlite v1.5.3 + gorm.io/driver/sqlite v1.5.4 gorm.io/driver/sqlserver v1.5.2-0.20230613072041-6e2cde390b0a gorm.io/gorm v1.25.4 ) @@ -22,9 +22,9 @@ require ( github.com/jackc/pgx/v5 v5.4.3 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/mattn/go-sqlite3 v1.14.17 // indirect - github.com/microsoft/go-mssqldb v1.5.0 // indirect - golang.org/x/crypto v0.12.0 // indirect - golang.org/x/text v0.12.0 // indirect + github.com/microsoft/go-mssqldb v1.6.0 // indirect + golang.org/x/crypto v0.14.0 // indirect + golang.org/x/text v0.13.0 // indirect ) replace gorm.io/gorm => ../ diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 849e2b7b..cfd3e0ac 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -862,6 +862,48 @@ func TestMigrateWithSpecialName(t *testing.T) { AssertEqual(t, true, DB.Migrator().HasTable("coupon_product_2")) } +// https://github.com/go-gorm/gorm/issues/4760 +func TestMigrateAutoIncrement(t *testing.T) { + type AutoIncrementStruct struct { + ID int64 `gorm:"primarykey;autoIncrement"` + Field1 uint32 `gorm:"column:field1"` + Field2 float32 `gorm:"column:field2"` + } + + if err := DB.AutoMigrate(&AutoIncrementStruct{}); err != nil { + t.Fatalf("AutoMigrate err: %v", err) + } + + const ROWS = 10 + for idx := 0; idx < ROWS; idx++ { + if err := DB.Create(&AutoIncrementStruct{}).Error; err != nil { + t.Fatalf("create auto_increment_struct fail, err: %v", err) + } + } + + rows := make([]*AutoIncrementStruct, 0, ROWS) + if err := DB.Order("id ASC").Find(&rows).Error; err != nil { + t.Fatalf("find auto_increment_struct fail, err: %v", err) + } + + ids := make([]int64, 0, len(rows)) + for _, row := range rows { + ids = append(ids, row.ID) + } + lastID := ids[len(ids)-1] + + if err := DB.Where("id IN (?)", ids).Delete(&AutoIncrementStruct{}).Error; err != nil { + t.Fatalf("delete auto_increment_struct fail, err: %v", err) + } + + newRow := &AutoIncrementStruct{} + if err := DB.Create(newRow).Error; err != nil { + t.Fatalf("create auto_increment_struct fail, err: %v", err) + } + + AssertEqual(t, newRow.ID, lastID+1) +} + // https://github.com/go-gorm/gorm/issues/5320 func TestPrimarykeyID(t *testing.T) { if DB.Dialector.Name() != "postgres" { From 9d8a5bb208f5616638cbaad878a12d5ac73970d3 Mon Sep 17 00:00:00 2001 From: "hjwblog.com" Date: Tue, 10 Oct 2023 14:45:48 +0800 Subject: [PATCH 186/231] feat: reuse name (#6626) --- clause/expression.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/clause/expression.go b/clause/expression.go index 8d010522..3140846e 100644 --- a/clause/expression.go +++ b/clause/expression.go @@ -126,7 +126,7 @@ func (expr NamedExpr) Build(builder Builder) { for _, v := range []byte(expr.SQL) { if v == '@' && !inName { inName = true - name = []byte{} + name = name[:0] } else if v == ' ' || v == ',' || v == ')' || v == '"' || v == '\'' || v == '`' || v == '\r' || v == '\n' || v == ';' { if inName { if nv, ok := namedMap[string(name)]; ok { From 12ba285a52fb25c3422e16226666ba791f376c0b Mon Sep 17 00:00:00 2001 From: Mathias Zeller <62462901+matoubidou@users.noreply.github.com> Date: Tue, 10 Oct 2023 08:46:32 +0200 Subject: [PATCH 187/231] *datatypes.JSON in model causes panic on tx.Statement.Changed (#6611) * do not panic on nil * more explanation in comments * get things compact --- utils/utils.go | 33 +++++++++++++++++++++------------ utils/utils_test.go | 1 + 2 files changed, 22 insertions(+), 12 deletions(-) diff --git a/utils/utils.go b/utils/utils.go index ddbca60a..c8fec5b0 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -89,19 +89,28 @@ func Contains(elems []string, elem string) bool { return false } -func AssertEqual(src, dst interface{}) bool { - if !reflect.DeepEqual(src, dst) { - if valuer, ok := src.(driver.Valuer); ok { - src, _ = valuer.Value() - } - - if valuer, ok := dst.(driver.Valuer); ok { - dst, _ = valuer.Value() - } - - return reflect.DeepEqual(src, dst) +func AssertEqual(x, y interface{}) bool { + if reflect.DeepEqual(x, y) { + return true } - return true + if x == nil || y == nil { + return false + } + + xval := reflect.ValueOf(x) + yval := reflect.ValueOf(y) + if xval.Kind() == reflect.Ptr && xval.IsNil() || + yval.Kind() == reflect.Ptr && yval.IsNil() { + return false + } + + if valuer, ok := x.(driver.Valuer); ok { + x, _ = valuer.Value() + } + if valuer, ok := y.(driver.Valuer); ok { + y, _ = valuer.Value() + } + return reflect.DeepEqual(x, y) } func ToString(value interface{}) string { diff --git a/utils/utils_test.go b/utils/utils_test.go index 71eef964..d0486822 100644 --- a/utils/utils_test.go +++ b/utils/utils_test.go @@ -98,6 +98,7 @@ func TestAssertEqual(t *testing.T) { {"error not equal", errors.New("1"), errors.New("2"), false}, {"driver.Valuer equal", ModifyAt{Time: now, Valid: true}, ModifyAt{Time: now, Valid: true}, true}, {"driver.Valuer not equal", ModifyAt{Time: now, Valid: true}, ModifyAt{Time: now.Add(time.Second), Valid: true}, false}, + {"driver.Valuer equal (ptr to nil ptr)", (*ModifyAt)(nil), &ModifyAt{}, false}, } for _, test := range assertEqualTests { t.Run(test.name, func(t *testing.T) { From 8c18714462de07fa3392b99eda089f2f9e3b6042 Mon Sep 17 00:00:00 2001 From: Jeremy Quirke Date: Mon, 9 Oct 2023 23:50:29 -0700 Subject: [PATCH 188/231] Don't call MethodByName with a variable arg (#6602) Go 1.22 goes somewhat toward addressing the issue using reflect MethodByName disabling linker deadcode elimination (DCE) and the resultant large increase in binary size because the linker cannot prune unused code because it might be reached via reflection. Go Issue golang/go#62257 reduces the number of incidences of this problem by leveraging a compiler assist to avoid marking functions containing calls to MethodByName as ReflectMethods as long as the arguments are constants. An analysis of Uber Technologies code base however shows that a number of transitive imports still contain calls to MethodByName with a variable argument, including GORM. In the case of GORM, the solution we are proposing is because the number of possible methods is finite, we will "unroll" this. This demonstrably shows that GORM is not longer a problem for DCE. Before ``` % go version go version devel go1.22-2f3458a8ce Sat Sep 16 16:26:48 2023 -0700 darwin/arm64 % go test ./... -ldflags=-dumpdep 2> >(grep -i -e '->.*') gorm.io/gorm.(*Statement).BuildCondition -> gorm.io/gorm/schema.ParseWithSpecialTableName type:reflect.Value -> reflect.(*Value).Method type:reflect.Value -> reflect.(*Value).MethodByName ok gorm.io/gorm (cached) ok gorm.io/gorm/callbacks (cached) gorm.io/gorm/clause_test.BenchmarkComplexSelect -> gorm.io/gorm/schema.ParseWithSpecialTableName type:reflect.Value -> reflect.(*Value).Method type:reflect.Value -> reflect.(*Value).MethodByName ? gorm.io/gorm/migrator [no test files] ok gorm.io/gorm/clause (cached) ok gorm.io/gorm/logger (cached) gorm.io/gorm/schema_test.TestAdvancedDataTypeValuerAndSetter -> gorm.io/gorm/schema.ParseWithSpecialTableName type:reflect.Value -> reflect.(*Value).Method type:reflect.Value -> reflect.(*Value).MethodByName ? gorm.io/gorm/utils/tests [no test files] ok gorm.io/gorm/schema (cached) ok gorm.io/gorm/utils (cached) ``` After ``` %go version go version devel go1.22-2f3458a8ce Sat Sep 16 16:26:48 2023 -0700 darwin/arm64 %go test ./... -ldflags=-dumpdep 2> >(grep -i -e '->.*') ok gorm.io/gorm (cached) ok gorm.io/gorm/callbacks (cached) ? gorm.io/gorm/migrator [no test files] ? gorm.io/gorm/utils/tests [no test files] ok gorm.io/gorm/clause (cached) ok gorm.io/gorm/logger (cached) ok gorm.io/gorm/schema (cached) ok gorm.io/gorm/utils (cached) ``` --- schema/schema.go | 63 ++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 58 insertions(+), 5 deletions(-) diff --git a/schema/schema.go b/schema/schema.go index e13a5ed1..3e7459ce 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -13,6 +13,20 @@ import ( "gorm.io/gorm/logger" ) +type callbackType string + +const ( + callbackTypeBeforeCreate callbackType = "BeforeCreate" + callbackTypeBeforeUpdate callbackType = "BeforeUpdate" + callbackTypeAfterCreate callbackType = "AfterCreate" + callbackTypeAfterUpdate callbackType = "AfterUpdate" + callbackTypeBeforeSave callbackType = "BeforeSave" + callbackTypeAfterSave callbackType = "AfterSave" + callbackTypeBeforeDelete callbackType = "BeforeDelete" + callbackTypeAfterDelete callbackType = "AfterDelete" + callbackTypeAfterFind callbackType = "AfterFind" +) + // ErrUnsupportedDataType unsupported data type var ErrUnsupportedDataType = errors.New("unsupported data type") @@ -288,14 +302,20 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam } } - callbacks := []string{"BeforeCreate", "AfterCreate", "BeforeUpdate", "AfterUpdate", "BeforeSave", "AfterSave", "BeforeDelete", "AfterDelete", "AfterFind"} - for _, name := range callbacks { - if methodValue := modelValue.MethodByName(name); methodValue.IsValid() { + callbackTypes := []callbackType{ + callbackTypeBeforeCreate, callbackTypeAfterCreate, + callbackTypeBeforeUpdate, callbackTypeAfterUpdate, + callbackTypeBeforeSave, callbackTypeAfterSave, + callbackTypeBeforeDelete, callbackTypeAfterDelete, + callbackTypeAfterFind, + } + for _, cbName := range callbackTypes { + if methodValue := callBackToMethodValue(modelValue, cbName); methodValue.IsValid() { switch methodValue.Type().String() { case "func(*gorm.DB) error": // TODO hack - reflect.Indirect(reflect.ValueOf(schema)).FieldByName(name).SetBool(true) + reflect.Indirect(reflect.ValueOf(schema)).FieldByName(string(cbName)).SetBool(true) default: - logger.Default.Warn(context.Background(), "Model %v don't match %vInterface, should be `%v(*gorm.DB) error`. Please see https://gorm.io/docs/hooks.html", schema, name, name) + logger.Default.Warn(context.Background(), "Model %v don't match %vInterface, should be `%v(*gorm.DB) error`. Please see https://gorm.io/docs/hooks.html", schema, cbName, cbName) } } } @@ -349,6 +369,39 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam return schema, schema.err } +// This unrolling is needed to show to the compiler the exact set of methods +// that can be used on the modelType. +// Prior to go1.22 any use of MethodByName would cause the linker to +// abandon dead code elimination for the entire binary. +// As of go1.22 the compiler supports one special case of a string constant +// being passed to MethodByName. For enterprise customers or those building +// large binaries, this gives a significant reduction in binary size. +// https://github.com/golang/go/issues/62257 +func callBackToMethodValue(modelType reflect.Value, cbType callbackType) reflect.Value { + switch cbType { + case callbackTypeBeforeCreate: + return modelType.MethodByName(string(callbackTypeBeforeCreate)) + case callbackTypeAfterCreate: + return modelType.MethodByName(string(callbackTypeAfterCreate)) + case callbackTypeBeforeUpdate: + return modelType.MethodByName(string(callbackTypeBeforeUpdate)) + case callbackTypeAfterUpdate: + return modelType.MethodByName(string(callbackTypeAfterUpdate)) + case callbackTypeBeforeSave: + return modelType.MethodByName(string(callbackTypeBeforeSave)) + case callbackTypeAfterSave: + return modelType.MethodByName(string(callbackTypeAfterSave)) + case callbackTypeBeforeDelete: + return modelType.MethodByName(string(callbackTypeBeforeDelete)) + case callbackTypeAfterDelete: + return modelType.MethodByName(string(callbackTypeAfterDelete)) + case callbackTypeAfterFind: + return modelType.MethodByName(string(callbackTypeAfterFind)) + default: + return reflect.ValueOf(nil) + } +} + func getOrParse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) { modelType := reflect.ValueOf(dest).Type() for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr { From 1b240810106fd68f84cfe73bcacaf91a8e4ce1dd Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 10 Oct 2023 14:50:45 +0800 Subject: [PATCH 189/231] chore(deps): bump actions/checkout from 3 to 4 (#6586) Bumps [actions/checkout](https://github.com/actions/checkout) from 3 to 4. - [Release notes](https://github.com/actions/checkout/releases) - [Changelog](https://github.com/actions/checkout/blob/main/CHANGELOG.md) - [Commits](https://github.com/actions/checkout/compare/v3...v4) --- updated-dependencies: - dependency-name: actions/checkout dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/labeler.yml | 2 +- .github/workflows/reviewdog.yml | 2 +- .github/workflows/tests.yml | 12 ++++++------ 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/.github/workflows/labeler.yml b/.github/workflows/labeler.yml index 0e8aaa60..ef852765 100644 --- a/.github/workflows/labeler.yml +++ b/.github/workflows/labeler.yml @@ -11,7 +11,7 @@ jobs: name: Label issues and pull requests steps: - name: check out - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: labeler uses: jinzhu/super-labeler-action@develop diff --git a/.github/workflows/reviewdog.yml b/.github/workflows/reviewdog.yml index a6542d57..3a65f0bc 100644 --- a/.github/workflows/reviewdog.yml +++ b/.github/workflows/reviewdog.yml @@ -6,7 +6,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Check out code into the Go module directory - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: golangci-lint uses: reviewdog/action-golangci-lint@v2 diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index e98a17d6..380231b9 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -27,7 +27,7 @@ jobs: go-version: ${{ matrix.go }} - name: Check out code into the Go module directory - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: go mod package cache uses: actions/cache@v3 @@ -70,7 +70,7 @@ jobs: go-version: ${{ matrix.go }} - name: Check out code into the Go module directory - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: go mod package cache uses: actions/cache@v3 @@ -113,7 +113,7 @@ jobs: go-version: ${{ matrix.go }} - name: Check out code into the Go module directory - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: go mod package cache uses: actions/cache@v3 @@ -156,7 +156,7 @@ jobs: go-version: ${{ matrix.go }} - name: Check out code into the Go module directory - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: go mod package cache uses: actions/cache@v3 @@ -199,7 +199,7 @@ jobs: go-version: ${{ matrix.go }} - name: Check out code into the Go module directory - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: go mod package cache uses: actions/cache@v3 @@ -231,7 +231,7 @@ jobs: go-version: ${{ matrix.go }} - name: Check out code into the Go module directory - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: go mod package cache From 6bef318891b98263f3568c13093b5860245d2c52 Mon Sep 17 00:00:00 2001 From: Franco Liberali Date: Tue, 10 Oct 2023 09:03:34 +0200 Subject: [PATCH 190/231] add support for returning in sqlserver (#6585) --- tests/delete_test.go | 6 +++--- tests/update_test.go | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/delete_test.go b/tests/delete_test.go index 5cb4b91e..5d112b4e 100644 --- a/tests/delete_test.go +++ b/tests/delete_test.go @@ -206,9 +206,9 @@ func TestDeleteSliceWithAssociations(t *testing.T) { } } -// only sqlite, postgres support returning +// only sqlite, postgres, sqlserver support returning func TestSoftDeleteReturning(t *testing.T) { - if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" { + if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" && DB.Dialector.Name() != "sqlserver" { return } @@ -233,7 +233,7 @@ func TestSoftDeleteReturning(t *testing.T) { } func TestDeleteReturning(t *testing.T) { - if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" { + if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" && DB.Dialector.Name() != "sqlserver" { return } diff --git a/tests/update_test.go b/tests/update_test.go index c03d2d47..a3fb7015 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -765,9 +765,9 @@ func TestSaveWithPrimaryValue(t *testing.T) { } } -// only sqlite, postgres support returning +// only sqlite, postgres, sqlserver support returning func TestUpdateReturning(t *testing.T) { - if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" { + if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" && DB.Dialector.Name() != "sqlserver" { return } From 78e905919fc253332fb032d0f4a76e7753e437e4 Mon Sep 17 00:00:00 2001 From: gleb <47985861+glebarez@users.noreply.github.com> Date: Thu, 26 Oct 2023 06:54:15 +0300 Subject: [PATCH 191/231] tests/sqilte: enable FOREIGN_KEYS inside OpenTestConnection (#6641) --- tests/tests_test.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/tests_test.go b/tests/tests_test.go index 47c2a7c1..f9c6cab5 100644 --- a/tests/tests_test.go +++ b/tests/tests_test.go @@ -43,9 +43,6 @@ func init() { } RunMigrations() - if DB.Dialector.Name() == "sqlite" { - DB.Exec("PRAGMA foreign_keys = ON") - } } } @@ -89,7 +86,10 @@ func OpenTestConnection(cfg *gorm.Config) (db *gorm.DB, err error) { db, err = gorm.Open(mysql.Open(dbDSN), cfg) default: log.Println("testing sqlite3...") - db, err = gorm.Open(sqlite.Open(filepath.Join(os.TempDir(), "gorm.db?_foreign_keys=on")), cfg) + db, err = gorm.Open(sqlite.Open(filepath.Join(os.TempDir(), "gorm.db")), cfg) + if err == nil { + db.Exec("PRAGMA foreign_keys = ON") + } } if err != nil { From 5adc0ce5f6c8cf97f1f6b9e835750406612c2fe0 Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Thu, 26 Oct 2023 11:58:13 +0800 Subject: [PATCH 192/231] test: fix TestEmbeddedRelations (#6639) --- tests/embedded_struct_test.go | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/tests/embedded_struct_test.go b/tests/embedded_struct_test.go index 4314f88c..873bba2a 100644 --- a/tests/embedded_struct_test.go +++ b/tests/embedded_struct_test.go @@ -236,8 +236,15 @@ func TestEmbeddedScanValuer(t *testing.T) { } func TestEmbeddedRelations(t *testing.T) { + type EmbUser struct { + gorm.Model + Name string + Age uint + Languages []Language `gorm:"many2many:EmbUserSpeak;"` + } + type AdvancedUser struct { - User `gorm:"embedded"` + EmbUser `gorm:"embedded"` Advanced bool } From 9fea15ae75fb9ff2bd86dcaa167673c8ed77394f Mon Sep 17 00:00:00 2001 From: black-06 Date: Mon, 30 Oct 2023 17:15:49 +0800 Subject: [PATCH 193/231] feat: add MigrateColumnUnique (#6640) * feat: add MigrateColumnUnique * feat: define new methods * delete debug in test --- migrator.go | 2 ++ migrator/migrator.go | 22 ++++++++++++++++++++++ schema/naming.go | 8 ++++++++ tests/associations_belongs_to_test.go | 2 -- tests/count_test.go | 2 +- tests/preload_test.go | 1 - tests/update_test.go | 2 +- 7 files changed, 34 insertions(+), 5 deletions(-) diff --git a/migrator.go b/migrator.go index 0e01f567..3d2b032b 100644 --- a/migrator.go +++ b/migrator.go @@ -87,6 +87,8 @@ type Migrator interface { DropColumn(dst interface{}, field string) error AlterColumn(dst interface{}, field string) error MigrateColumn(dst interface{}, field *schema.Field, columnType ColumnType) error + // MigrateColumnUnique migrate column's UNIQUE constraint, it's part of MigrateColumn. + MigrateColumnUnique(dst interface{}, field *schema.Field, columnType ColumnType) error HasColumn(dst interface{}, field string) bool RenameColumn(dst interface{}, oldName, field string) error ColumnTypes(dst interface{}) ([]ColumnType, error) diff --git a/migrator/migrator.go b/migrator/migrator.go index 49bc9371..64a5a4b5 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -27,6 +27,8 @@ var regFullDataType = regexp.MustCompile(`\D*(\d+)\D?`) // TODO:? Create const vars for raw sql queries ? +var _ gorm.Migrator = (*Migrator)(nil) + // Migrator m struct type Migrator struct { Config @@ -539,6 +541,26 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy return nil } +func (m Migrator) MigrateColumnUnique(value interface{}, field *schema.Field, columnType gorm.ColumnType) error { + unique, ok := columnType.Unique() + if !ok || field.PrimaryKey { + return nil // skip primary key + } + // By default, ColumnType's Unique is not affected by UniqueIndex, so we don't care about UniqueIndex. + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + // We're currently only receiving boolean values on `Unique` tag, + // so the UniqueConstraint name is fixed + constraint := m.DB.NamingStrategy.UniqueName(stmt.Table, field.DBName) + if unique && !field.Unique { + return m.DB.Migrator().DropConstraint(value, constraint) + } + if !unique && field.Unique { + return m.DB.Migrator().CreateConstraint(value, constraint) + } + return nil + }) +} + // ColumnTypes return columnTypes []gorm.ColumnType and execErr error func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) { columnTypes := make([]gorm.ColumnType, 0) diff --git a/schema/naming.go b/schema/naming.go index a2a0150a..e6fb81b2 100644 --- a/schema/naming.go +++ b/schema/naming.go @@ -19,6 +19,7 @@ type Namer interface { RelationshipFKName(Relationship) string CheckerName(table, column string) string IndexName(table, column string) string + UniqueName(table, column string) string } // Replacer replacer interface like strings.Replacer @@ -26,6 +27,8 @@ type Replacer interface { Replace(name string) string } +var _ Namer = (*NamingStrategy)(nil) + // NamingStrategy tables, columns naming strategy type NamingStrategy struct { TablePrefix string @@ -85,6 +88,11 @@ func (ns NamingStrategy) IndexName(table, column string) string { return ns.formatName("idx", table, ns.toDBName(column)) } +// UniqueName generate unique constraint name +func (ns NamingStrategy) UniqueName(table, column string) string { + return ns.formatName("uni", table, ns.toDBName(column)) +} + func (ns NamingStrategy) formatName(prefix, table, name string) string { formattedName := strings.ReplaceAll(strings.Join([]string{ prefix, table, name, diff --git a/tests/associations_belongs_to_test.go b/tests/associations_belongs_to_test.go index 6befb5f2..103da032 100644 --- a/tests/associations_belongs_to_test.go +++ b/tests/associations_belongs_to_test.go @@ -278,8 +278,6 @@ func TestBelongsToAssociationUnscoped(t *testing.T) { t.Fatalf("failed to create items, got error: %v", err) } - tx = tx.Debug() - // test replace if err := tx.Model(&item).Association("ItemParent").Unscoped().Replace(&ItemParent{ Logo: "updated logo", diff --git a/tests/count_test.go b/tests/count_test.go index b0dfb0b5..4449515b 100644 --- a/tests/count_test.go +++ b/tests/count_test.go @@ -29,7 +29,7 @@ func TestCountWithGroup(t *testing.T) { } var count2 int64 - if err := DB.Debug().Model(&Company{}).Where("name in ?", []string{"company_count_group_b", "company_count_group_c"}).Group("name").Count(&count2).Error; err != nil { + if err := DB.Model(&Company{}).Where("name in ?", []string{"company_count_group_b", "company_count_group_c"}).Group("name").Count(&count2).Error; err != nil { t.Errorf(fmt.Sprintf("Count should work, but got err %v", err)) } if count2 != 2 { diff --git a/tests/preload_test.go b/tests/preload_test.go index 7304e350..3ff86492 100644 --- a/tests/preload_test.go +++ b/tests/preload_test.go @@ -429,7 +429,6 @@ func TestEmbedPreload(t *testing.T) { }, } - DB = DB.Debug() for _, test := range tests { t.Run(test.name, func(t *testing.T) { actual := Org{} diff --git a/tests/update_test.go b/tests/update_test.go index a3fb7015..b719cc45 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -838,7 +838,7 @@ func TestSaveWithHooks(t *testing.T) { saveTokenOwner := func(owner *TokenOwner) (*TokenOwner, error) { var newOwner TokenOwner if err := DB.Transaction(func(tx *gorm.DB) error { - if err := tx.Debug().Session(&gorm.Session{FullSaveAssociations: true}).Save(owner).Error; err != nil { + if err := tx.Session(&gorm.Session{FullSaveAssociations: true}).Save(owner).Error; err != nil { return err } if err := tx.Preload("Token").First(&newOwner, owner.ID).Error; err != nil { From d2fb7a942b8d44f9ad7f6f5bc6f9f99ddcebc95a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Flc=E3=82=9B?= Date: Tue, 7 Nov 2023 10:19:41 +0800 Subject: [PATCH 194/231] chore(logger): optimize (#6675) * chore(logger): optimize * chore(logger): optimize --- logger/logger.go | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/logger/logger.go b/logger/logger.go index aa0060bc..253f0325 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -69,7 +69,7 @@ type Interface interface { } var ( - // Discard Discard logger will print any log to io.Discard + // Discard logger will print any log to io.Discard Discard = New(log.New(io.Discard, "", log.LstdFlags), Config{}) // Default Default logger Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{ @@ -78,7 +78,7 @@ var ( IgnoreRecordNotFoundError: false, Colorful: true, }) - // Recorder Recorder logger records running SQL into a recorder instance + // Recorder logger records running SQL into a recorder instance Recorder = traceRecorder{Interface: Default, BeginAt: time.Now()} ) @@ -129,28 +129,30 @@ func (l *logger) LogMode(level LogLevel) Interface { } // Info print info -func (l logger) Info(ctx context.Context, msg string, data ...interface{}) { +func (l *logger) Info(ctx context.Context, msg string, data ...interface{}) { if l.LogLevel >= Info { l.Printf(l.infoStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...) } } // Warn print warn messages -func (l logger) Warn(ctx context.Context, msg string, data ...interface{}) { +func (l *logger) Warn(ctx context.Context, msg string, data ...interface{}) { if l.LogLevel >= Warn { l.Printf(l.warnStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...) } } // Error print error messages -func (l logger) Error(ctx context.Context, msg string, data ...interface{}) { +func (l *logger) Error(ctx context.Context, msg string, data ...interface{}) { if l.LogLevel >= Error { l.Printf(l.errStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...) } } // Trace print sql message -func (l logger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) { +// +//nolint:cyclop +func (l *logger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) { if l.LogLevel <= Silent { return } @@ -182,8 +184,8 @@ func (l logger) Trace(ctx context.Context, begin time.Time, fc func() (string, i } } -// Trace print sql message -func (l logger) ParamsFilter(ctx context.Context, sql string, params ...interface{}) (string, []interface{}) { +// ParamsFilter filter params +func (l *logger) ParamsFilter(ctx context.Context, sql string, params ...interface{}) (string, []interface{}) { if l.Config.ParameterizedQueries { return sql, nil } @@ -198,8 +200,8 @@ type traceRecorder struct { Err error } -// New new trace recorder -func (l traceRecorder) New() *traceRecorder { +// New trace recorder +func (l *traceRecorder) New() *traceRecorder { return &traceRecorder{Interface: l.Interface, BeginAt: time.Now()} } From 40f4afe8c21d96db63174bd501fb61d6e73c5587 Mon Sep 17 00:00:00 2001 From: Kijima Daigo Date: Tue, 7 Nov 2023 11:20:06 +0900 Subject: [PATCH 195/231] docs: fix broken link (#6673) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 85ad3050..745dad60 100644 --- a/README.md +++ b/README.md @@ -41,4 +41,4 @@ The fantastic ORM library for Golang, aims to be developer friendly. © Jinzhu, 2013~time.Now -Released under the [MIT License](https://github.com/go-gorm/gorm/blob/master/License) +Released under the [MIT License](https://github.com/go-gorm/gorm/blob/master/LICENSE) From c1e911f6ed8d3d929aebbd39985a33c9ebe3bad7 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 9 Nov 2023 18:46:39 +0800 Subject: [PATCH 196/231] Update tests/go.mod --- tests/go.mod | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/tests/go.mod b/tests/go.mod index 5a0aeddd..71079050 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -3,14 +3,14 @@ module gorm.io/gorm/tests go 1.18 require ( - github.com/google/uuid v1.3.1 + github.com/google/uuid v1.4.0 github.com/jinzhu/now v1.1.5 github.com/lib/pq v1.10.9 - gorm.io/driver/mysql v1.5.2-0.20230612053416-48b6526a21f0 - gorm.io/driver/postgres v1.5.3-0.20230607070428-18bc84b75196 + gorm.io/driver/mysql v1.5.2 + gorm.io/driver/postgres v1.5.4 gorm.io/driver/sqlite v1.5.4 - gorm.io/driver/sqlserver v1.5.2-0.20230613072041-6e2cde390b0a - gorm.io/gorm v1.25.4 + gorm.io/driver/sqlserver v1.5.2 + gorm.io/gorm v1.25.5 ) require ( @@ -19,12 +19,14 @@ require ( github.com/golang-sql/sqlexp v0.1.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect - github.com/jackc/pgx/v5 v5.4.3 // indirect + github.com/jackc/pgx/v5 v5.5.0 // indirect github.com/jinzhu/inflection v1.0.0 // indirect - github.com/mattn/go-sqlite3 v1.14.17 // indirect + github.com/mattn/go-sqlite3 v1.14.18 // indirect github.com/microsoft/go-mssqldb v1.6.0 // indirect - golang.org/x/crypto v0.14.0 // indirect - golang.org/x/text v0.13.0 // indirect + golang.org/x/crypto v0.15.0 // indirect + golang.org/x/text v0.14.0 // indirect ) replace gorm.io/gorm => ../ + +replace github.com/jackc/pgx/v5 => github.com/jackc/pgx/v5 v5.4.3 From 3207ad6033aad5e76c6c9d578ef663032765e484 Mon Sep 17 00:00:00 2001 From: FangSqing <148066072+FangSqing@users.noreply.github.com> Date: Wed, 15 Nov 2023 21:32:56 +0800 Subject: [PATCH 197/231] map insert support return increment id (#6662) --- callbacks/create.go | 70 +++++++++++++---- schema/field.go | 4 +- tests/create_test.go | 180 ++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 237 insertions(+), 17 deletions(-) diff --git a/callbacks/create.go b/callbacks/create.go index f0b78139..b1488b08 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -103,13 +103,53 @@ func Create(config *Config) func(db *gorm.DB) { } db.RowsAffected, _ = result.RowsAffected() - if db.RowsAffected != 0 && db.Statement.Schema != nil && - db.Statement.Schema.PrioritizedPrimaryField != nil && - db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue { - insertID, err := result.LastInsertId() - insertOk := err == nil && insertID > 0 - if !insertOk { - db.AddError(err) + if db.RowsAffected == 0 { + return + } + + var ( + pkField *schema.Field + pkFieldName = "@id" + ) + if db.Statement.Schema != nil { + if db.Statement.Schema.PrioritizedPrimaryField == nil || !db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue { + return + } + pkField = db.Statement.Schema.PrioritizedPrimaryField + pkFieldName = db.Statement.Schema.PrioritizedPrimaryField.DBName + } + + insertID, err := result.LastInsertId() + insertOk := err == nil && insertID > 0 + if !insertOk { + db.AddError(err) + return + } + + // append @id column with value for auto-increment primary key + // the @id value is correct, when: 1. without setting auto-increment primary key, 2. database AutoIncrementIncrement = 1 + switch values := db.Statement.Dest.(type) { + case map[string]interface{}: + values[pkFieldName] = insertID + case *map[string]interface{}: + (*values)[pkFieldName] = insertID + case []map[string]interface{}, *[]map[string]interface{}: + mapValues, ok := values.([]map[string]interface{}) + if !ok { + if v, ok := values.(*[]map[string]interface{}); ok { + if *v != nil { + mapValues = *v + } + } + } + for _, mapValue := range mapValues { + if mapValue != nil { + mapValue[pkFieldName] = insertID + } + insertID += schema.DefaultAutoIncrementIncrement + } + default: + if pkField == nil { return } @@ -122,10 +162,10 @@ func Create(config *Config) func(db *gorm.DB) { break } - _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv) + _, isZero := pkField.ValueOf(db.Statement.Context, rv) if isZero { - db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, insertID)) - insertID -= db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement + db.AddError(pkField.Set(db.Statement.Context, rv, insertID)) + insertID -= pkField.AutoIncrementIncrement } } } else { @@ -135,16 +175,16 @@ func Create(config *Config) func(db *gorm.DB) { break } - if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv); isZero { - db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, insertID)) - insertID += db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement + if _, isZero := pkField.ValueOf(db.Statement.Context, rv); isZero { + db.AddError(pkField.Set(db.Statement.Context, rv, insertID)) + insertID += pkField.AutoIncrementIncrement } } } case reflect.Struct: - _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, db.Statement.ReflectValue) + _, isZero := pkField.ValueOf(db.Statement.Context, db.Statement.ReflectValue) if isZero { - db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, db.Statement.ReflectValue, insertID)) + db.AddError(pkField.Set(db.Statement.Context, db.Statement.ReflectValue, insertID)) } } } diff --git a/schema/field.go b/schema/field.go index dd08e056..657e0a4b 100644 --- a/schema/field.go +++ b/schema/field.go @@ -49,6 +49,8 @@ const ( Bytes DataType = "bytes" ) +const DefaultAutoIncrementIncrement int64 = 1 + // Field is the representation of model schema's field type Field struct { Name string @@ -119,7 +121,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { NotNull: utils.CheckTruth(tagSetting["NOT NULL"], tagSetting["NOTNULL"]), Unique: utils.CheckTruth(tagSetting["UNIQUE"]), Comment: tagSetting["COMMENT"], - AutoIncrementIncrement: 1, + AutoIncrementIncrement: DefaultAutoIncrementIncrement, } for field.IndirectFieldType.Kind() == reflect.Ptr { diff --git a/tests/create_test.go b/tests/create_test.go index 02613b72..d9b54b7f 100644 --- a/tests/create_test.go +++ b/tests/create_test.go @@ -2,6 +2,7 @@ package tests_test import ( "errors" + "fmt" "regexp" "testing" "time" @@ -580,7 +581,7 @@ func TestCreateWithAutoIncrementCompositeKey(t *testing.T) { } } -func TestCreateOnConfilctWithDefalutNull(t *testing.T) { +func TestCreateOnConflictWithDefaultNull(t *testing.T) { type OnConfilctUser struct { ID string Name string `gorm:"default:null"` @@ -615,3 +616,180 @@ func TestCreateOnConfilctWithDefalutNull(t *testing.T) { AssertEqual(t, u2.Email, "on-confilct-user-email-2") AssertEqual(t, u2.Mobile, "133xxxx") } + +func TestCreateFromMapWithoutPK(t *testing.T) { + if !isMysql() { + t.Skipf("This test case skipped, because of only supportting for mysql") + } + + // case 1: one record, create from map[string]interface{} + mapValue1 := map[string]interface{}{"name": "create_from_map_with_schema1", "age": 1} + if err := DB.Model(&User{}).Create(mapValue1).Error; err != nil { + t.Fatalf("failed to create data from map, got error: %v", err) + } + + if _, ok := mapValue1["id"]; !ok { + t.Fatal("failed to create data from map with table, returning map has no primary key") + } + + var result1 User + if err := DB.Where("name = ?", "create_from_map_with_schema1").First(&result1).Error; err != nil || result1.Age != 1 { + t.Fatalf("failed to create from map, got error %v", err) + } + + var idVal int64 + _, ok := mapValue1["id"].(uint) + if ok { + t.Skipf("This test case skipped, because the db supports returning") + } + + idVal, ok = mapValue1["id"].(int64) + if !ok { + t.Fatal("ret result missing id") + } + + if int64(result1.ID) != idVal { + t.Fatal("failed to create data from map with table, @id != id") + } + + // case2: one record, create from *map[string]interface{} + mapValue2 := map[string]interface{}{"name": "create_from_map_with_schema2", "age": 1} + if err := DB.Model(&User{}).Create(&mapValue2).Error; err != nil { + t.Fatalf("failed to create data from map, got error: %v", err) + } + + if _, ok := mapValue2["id"]; !ok { + t.Fatal("failed to create data from map with table, returning map has no primary key") + } + + var result2 User + if err := DB.Where("name = ?", "create_from_map_with_schema2").First(&result2).Error; err != nil || result2.Age != 1 { + t.Fatalf("failed to create from map, got error %v", err) + } + + _, ok = mapValue2["id"].(uint) + if ok { + t.Skipf("This test case skipped, because the db supports returning") + } + + idVal, ok = mapValue2["id"].(int64) + if !ok { + t.Fatal("ret result missing id") + } + + if int64(result2.ID) != idVal { + t.Fatal("failed to create data from map with table, @id != id") + } + + // case 3: records + values := []map[string]interface{}{ + {"name": "create_from_map_with_schema11", "age": 1}, {"name": "create_from_map_with_schema12", "age": 1}, + } + + beforeLen := len(values) + if err := DB.Model(&User{}).Create(&values).Error; err != nil { + t.Fatalf("failed to create data from map, got error: %v", err) + } + + // mariadb with returning, values will be appended with id map + if len(values) == beforeLen*2 { + t.Skipf("This test case skipped, because the db supports returning") + } + + for i := range values { + v, ok := values[i]["id"] + if !ok { + t.Fatal("failed to create data from map with table, returning map has no primary key") + } + + var result User + if err := DB.Where("name = ?", fmt.Sprintf("create_from_map_with_schema1%d", i+1)).First(&result).Error; err != nil || result.Age != 1 { + t.Fatalf("failed to create from map, got error %v", err) + } + if int64(result.ID) != v.(int64) { + t.Fatal("failed to create data from map with table, @id != id") + } + } +} + +func TestCreateFromMapWithTable(t *testing.T) { + if !isMysql() { + t.Skipf("This test case skipped, because of only supportting for mysql") + } + tableDB := DB.Table("`users`") + + // case 1: create from map[string]interface{} + record := map[string]interface{}{"`name`": "create_from_map_with_table", "`age`": 18} + if err := tableDB.Create(record).Error; err != nil { + t.Fatalf("failed to create data from map with table, got error: %v", err) + } + + if _, ok := record["@id"]; !ok { + t.Fatal("failed to create data from map with table, returning map has no key '@id'") + } + + var res map[string]interface{} + if err := tableDB.Select([]string{"id", "name", "age"}).Where("name = ?", "create_from_map_with_table").Find(&res).Error; err != nil || res["age"] != int64(18) { + t.Fatalf("failed to create from map, got error %v", err) + } + + if int64(res["id"].(uint64)) != record["@id"] { + t.Fatal("failed to create data from map with table, @id != id") + } + + // case 2: create from *map[string]interface{} + record1 := map[string]interface{}{"name": "create_from_map_with_table_1", "age": 18} + tableDB2 := DB.Table("users") + if err := tableDB2.Create(&record1).Error; err != nil { + t.Fatalf("failed to create data from map, got error: %v", err) + } + if _, ok := record1["@id"]; !ok { + t.Fatal("failed to create data from map with table, returning map has no key '@id'") + } + + var res1 map[string]interface{} + if err := tableDB2.Select([]string{"id", "name", "age"}).Where("name = ?", "create_from_map_with_table_1").Find(&res1).Error; err != nil || res1["age"] != int64(18) { + t.Fatalf("failed to create from map, got error %v", err) + } + + if int64(res1["id"].(uint64)) != record1["@id"] { + t.Fatal("failed to create data from map with table, @id != id") + } + + // case 3: create from []map[string]interface{} + records := []map[string]interface{}{ + {"name": "create_from_map_with_table_2", "age": 19}, + {"name": "create_from_map_with_table_3", "age": 20}, + } + + tableDB = DB.Table("users") + if err := tableDB.Create(&records).Error; err != nil { + t.Fatalf("failed to create data from slice of map, got error: %v", err) + } + + if _, ok := records[0]["@id"]; !ok { + t.Fatal("failed to create data from map with table, returning map has no key '@id'") + } + + if _, ok := records[1]["@id"]; !ok { + t.Fatal("failed to create data from map with table, returning map has no key '@id'") + } + + var res2 map[string]interface{} + if err := tableDB.Select([]string{"id", "name", "age"}).Where("name = ?", "create_from_map_with_table_2").Find(&res2).Error; err != nil || res2["age"] != int64(19) { + t.Fatalf("failed to query data after create from slice of map, got error %v", err) + } + + var res3 map[string]interface{} + if err := DB.Table("users").Select([]string{"id", "name", "age"}).Where("name = ?", "create_from_map_with_table_3").Find(&res3).Error; err != nil || res3["age"] != int64(20) { + t.Fatalf("failed to query data after create from slice of map, got error %v", err) + } + + if int64(res2["id"].(uint64)) != records[0]["@id"] { + t.Fatal("failed to create data from map with table, @id != id") + } + + if int64(res3["id"].(uint64)) != records[1]["@id"] { + t.Fatal("failed to create data from map with table, @id != id") + } +} From f0af94cd167c03af78b80586f7628bc8fb470577 Mon Sep 17 00:00:00 2001 From: Franco Liberali Date: Wed, 6 Sep 2023 16:13:47 +0200 Subject: [PATCH 198/231] add test to show that update from works --- tests/update_test.go | 49 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/tests/update_test.go b/tests/update_test.go index b719cc45..9eb9dbfc 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -882,3 +882,52 @@ func TestSaveWithHooks(t *testing.T) { t.Errorf(`token content should be "token2_encrypted", but got: "%s"`, o2.Token.Content) } } + +// only postgres, sqlserver, sqlite support update from +func TestUpdateFrom(t *testing.T) { + if DB.Dialector.Name() != "postgres" && DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "sqlserver" { + return + } + + users := []*User{ + GetUser("update-from-1", Config{Account: true}), + GetUser("update-from-2", Config{Account: true}), + GetUser("update-from-3", Config{}), + } + + if err := DB.Create(&users).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } else if users[0].ID == 0 { + t.Fatalf("user's primary value should not zero, %v", users[0].ID) + } else if users[0].UpdatedAt.IsZero() { + t.Fatalf("user's updated at should not zero, %v", users[0].UpdatedAt) + } + + if rowsAffected := DB.Model(&User{}).Clauses(clause.From{Tables: []clause.Table{{Name: "accounts"}}}).Where("accounts.user_id = users.id AND accounts.number = ? AND accounts.deleted_at IS NULL", users[0].Account.Number).Update("name", "franco").RowsAffected; rowsAffected != 1 { + t.Errorf("should only update one record, but got %v", rowsAffected) + } + + var result User + if err := DB.Where("id = ?", users[0].ID).First(&result).Error; err != nil { + t.Errorf("errors happened when query before user: %v", err) + } else if result.UpdatedAt.UnixNano() == users[0].UpdatedAt.UnixNano() { + t.Errorf("user's updated at should be changed, but got %v, was %v", result.UpdatedAt, users[0].UpdatedAt) + } else if result.Name != "franco" { + t.Errorf("user's name should be updated") + } + + if rowsAffected := DB.Model(&User{}).Clauses(clause.From{Tables: []clause.Table{{Name: "accounts"}}}).Where("accounts.user_id = users.id AND accounts.number IN ? AND accounts.deleted_at IS NULL", []string{users[0].Account.Number, users[1].Account.Number}).Update("name", gorm.Expr("accounts.number")).RowsAffected; rowsAffected != 2 { + t.Errorf("should update two records, but got %v", rowsAffected) + } + + var results []User + if err := DB.Preload("Account").Find(&results, []uint{users[0].ID, users[1].ID}).Error; err != nil { + t.Errorf("Not error should happen when finding users, but got %v", err) + } + + for _, user := range results { + if user.Name != user.Account.Number { + t.Errorf("user's name should be equal to the account's number %v, but got %v", user.Account.Number, user.Name) + } + } +} From 2fb4928aa873bef89d016e0af2c0724b44725043 Mon Sep 17 00:00:00 2001 From: BugKillerPro <43431429+BugKillerPro@users.noreply.github.com> Date: Fri, 15 Dec 2023 16:31:23 +0800 Subject: [PATCH 199/231] refactor: Resolve implicit memory aliasing in for loop (#6730) --- schema/schema_helper_test.go | 4 ++-- schema/schema_test.go | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/schema/schema_helper_test.go b/schema/schema_helper_test.go index 605aa03a..bc326686 100644 --- a/schema/schema_helper_test.go +++ b/schema/schema_helper_test.go @@ -163,8 +163,8 @@ func checkSchemaRelation(t *testing.T, s *schema.Schema, relation Relation) { t.Errorf("schema %v relation's join table tablename expects %v, but got %v", s, relation.JoinTable.Table, r.JoinTable.Table) } - for _, f := range relation.JoinTable.Fields { - checkSchemaField(t, r.JoinTable, &f, nil) + for i := range relation.JoinTable.Fields { + checkSchemaField(t, r.JoinTable, &relation.JoinTable.Fields[i], nil) } } diff --git a/schema/schema_test.go b/schema/schema_test.go index 5bc0fb83..45e152e9 100644 --- a/schema/schema_test.go +++ b/schema/schema_test.go @@ -46,8 +46,8 @@ func checkUserSchema(t *testing.T, user *schema.Schema) { {Name: "Active", DBName: "active", BindNames: []string{"Active"}, DataType: schema.Bool}, } - for _, f := range fields { - checkSchemaField(t, user, &f, func(f *schema.Field) { + for i := range fields { + checkSchemaField(t, user, &fields[i], func(f *schema.Field) { f.Creatable = true f.Updatable = true f.Readable = true @@ -136,8 +136,8 @@ func TestParseSchemaWithAdvancedDataType(t *testing.T) { {Name: "Admin", DBName: "admin", BindNames: []string{"Admin"}, DataType: schema.Bool}, } - for _, f := range fields { - checkSchemaField(t, user, &f, func(f *schema.Field) { + for i := range fields { + checkSchemaField(t, user, &fields[i], func(f *schema.Field) { f.Creatable = true f.Updatable = true f.Readable = true From b9ebdb13c777029c03aa5c6c1581f579cdf1f79a Mon Sep 17 00:00:00 2001 From: Maciej Laskowski Date: Fri, 15 Dec 2023 16:32:56 +0800 Subject: [PATCH 200/231] Making locking parameters more intuitive (#6719) * Making locking parameters more intuitive * remove dedicated type --- clause/locking.go | 7 +++++++ clause/locking_test.go | 10 +++++++--- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/clause/locking.go b/clause/locking.go index 290aac92..2bc48ceb 100644 --- a/clause/locking.go +++ b/clause/locking.go @@ -1,5 +1,12 @@ package clause +const ( + LockingStrengthUpdate = "UPDATE" + LockingStrengthShare = "SHARE" + LockingOptionsSkipLocked = "SKIP LOCKED" + LockingOptionsNoWait = "NOWAIT" +) + type Locking struct { Strength string Table Table diff --git a/clause/locking_test.go b/clause/locking_test.go index 0e607312..e45c8e7d 100644 --- a/clause/locking_test.go +++ b/clause/locking_test.go @@ -14,17 +14,21 @@ func TestLocking(t *testing.T) { Vars []interface{} }{ { - []clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: "UPDATE"}}, + []clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: clause.LockingStrengthUpdate}}, "SELECT * FROM `users` FOR UPDATE", nil, }, { - []clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: "SHARE", Table: clause.Table{Name: clause.CurrentTable}}}, + []clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: clause.LockingStrengthShare, Table: clause.Table{Name: clause.CurrentTable}}}, "SELECT * FROM `users` FOR SHARE OF `users`", nil, }, { - []clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: "UPDATE"}, clause.Locking{Strength: "UPDATE", Options: "NOWAIT"}}, + []clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: clause.LockingStrengthUpdate, Options: clause.LockingOptionsNoWait}}, "SELECT * FROM `users` FOR UPDATE NOWAIT", nil, }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: clause.LockingStrengthUpdate, Options: clause.LockingOptionsSkipLocked}}, + "SELECT * FROM `users` FOR UPDATE SKIP LOCKED", nil, + }, } for idx, result := range results { From a2cac75218c9844d0b43832e2d2a3c35f9700406 Mon Sep 17 00:00:00 2001 From: Alexis Viscogliosi Date: Fri, 15 Dec 2023 09:36:08 +0100 Subject: [PATCH 201/231] feature: bring custom type and id column name to polymorphism (#6716) * feature: bring custom type and id column name to polymorphism * relationship: better returns for hasPolymorphicRelation * fix: tests --- schema/relationship.go | 67 +++++++--- schema/relationship_test.go | 187 ++++++++++++++++++++++++++++ tests/associations_has_many_test.go | 11 +- tests/helper_test.go | 26 ++-- tests/migrate_test.go | 85 ++++++++----- tests/tests_test.go | 2 +- utils/tests/models.go | 10 +- 7 files changed, 335 insertions(+), 53 deletions(-) diff --git a/schema/relationship.go b/schema/relationship.go index e03dcc52..57167859 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -76,8 +76,8 @@ func (schema *Schema) parseRelation(field *Field) *Relationship { return nil } - if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" { - schema.buildPolymorphicRelation(relation, field, polymorphic) + if hasPolymorphicRelation(field.TagSettings) { + schema.buildPolymorphicRelation(relation, field) } else if many2many := field.TagSettings["MANY2MANY"]; many2many != "" { schema.buildMany2ManyRelation(relation, field, many2many) } else if belongsTo := field.TagSettings["BELONGSTO"]; belongsTo != "" { @@ -89,7 +89,8 @@ func (schema *Schema) parseRelation(field *Field) *Relationship { case reflect.Slice: schema.guessRelation(relation, field, guessHas) default: - schema.err = fmt.Errorf("unsupported data type %v for %v on field %s", relation.FieldSchema, schema, field.Name) + schema.err = fmt.Errorf("unsupported data type %v for %v on field %s", relation.FieldSchema, schema, + field.Name) } } @@ -124,6 +125,20 @@ func (schema *Schema) parseRelation(field *Field) *Relationship { return relation } +// hasPolymorphicRelation check if has polymorphic relation +// 1. `POLYMORPHIC` tag +// 2. `POLYMORPHICTYPE` and `POLYMORPHICID` tag +func hasPolymorphicRelation(tagSettings map[string]string) bool { + if _, ok := tagSettings["POLYMORPHIC"]; ok { + return true + } + + _, hasType := tagSettings["POLYMORPHICTYPE"] + _, hasId := tagSettings["POLYMORPHICID"] + + return hasType && hasId +} + func (schema *Schema) setRelation(relation *Relationship) { // set non-embedded relation if rel := schema.Relationships.Relations[relation.Name]; rel != nil { @@ -169,23 +184,41 @@ func (schema *Schema) setRelation(relation *Relationship) { // OwnerID int // OwnerType string // } -func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Field, polymorphic string) { +func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Field) { + polymorphic := field.TagSettings["POLYMORPHIC"] + relation.Polymorphic = &Polymorphic{ - Value: schema.Table, - PolymorphicType: relation.FieldSchema.FieldsByName[polymorphic+"Type"], - PolymorphicID: relation.FieldSchema.FieldsByName[polymorphic+"ID"], + Value: schema.Table, } + var ( + typeName = polymorphic + "Type" + typeId = polymorphic + "ID" + ) + + if value, ok := field.TagSettings["POLYMORPHICTYPE"]; ok { + typeName = strings.TrimSpace(value) + } + + if value, ok := field.TagSettings["POLYMORPHICID"]; ok { + typeId = strings.TrimSpace(value) + } + + relation.Polymorphic.PolymorphicType = relation.FieldSchema.FieldsByName[typeName] + relation.Polymorphic.PolymorphicID = relation.FieldSchema.FieldsByName[typeId] + if value, ok := field.TagSettings["POLYMORPHICVALUE"]; ok { relation.Polymorphic.Value = strings.TrimSpace(value) } if relation.Polymorphic.PolymorphicType == nil { - schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s", relation.FieldSchema, schema, field.Name, polymorphic+"Type") + schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s", + relation.FieldSchema, schema, field.Name, polymorphic+"Type") } if relation.Polymorphic.PolymorphicID == nil { - schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s", relation.FieldSchema, schema, field.Name, polymorphic+"ID") + schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s", + relation.FieldSchema, schema, field.Name, polymorphic+"ID") } if schema.err == nil { @@ -197,12 +230,14 @@ func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Fi primaryKeyField := schema.PrioritizedPrimaryField if len(relation.foreignKeys) > 0 { if primaryKeyField = schema.LookUpField(relation.foreignKeys[0]); primaryKeyField == nil || len(relation.foreignKeys) > 1 { - schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %s", relation.foreignKeys, schema, field.Name) + schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %s", relation.foreignKeys, + schema, field.Name) } } if primaryKeyField == nil { - schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing primaryKey field", relation.FieldSchema, schema, field.Name) + schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing primaryKey field", + relation.FieldSchema, schema, field.Name) return } @@ -317,7 +352,8 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel Tag: `gorm:"-"`, }) - if relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, schema.namer); err != nil { + if relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, + schema.namer); err != nil { schema.err = err } relation.JoinTable.Name = many2many @@ -436,7 +472,8 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu schema.guessRelation(relation, field, guessEmbeddedHas) // case guessEmbeddedHas: default: - schema.err = fmt.Errorf("invalid field found for struct %v's field %s: define a valid foreign key for relations or implement the Valuer/Scanner interface", schema, field.Name) + schema.err = fmt.Errorf("invalid field found for struct %v's field %s: define a valid foreign key for relations or implement the Valuer/Scanner interface", + schema, field.Name) } } @@ -492,7 +529,9 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu lookUpNames := []string{lookUpName} if len(primaryFields) == 1 { - lookUpNames = append(lookUpNames, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID", strings.TrimSuffix(lookUpName, primaryField.Name)+"Id", schema.namer.ColumnName(foreignSchema.Table, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID")) + lookUpNames = append(lookUpNames, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID", + strings.TrimSuffix(lookUpName, primaryField.Name)+"Id", schema.namer.ColumnName(foreignSchema.Table, + strings.TrimSuffix(lookUpName, primaryField.Name)+"ID")) } for _, name := range lookUpNames { diff --git a/schema/relationship_test.go b/schema/relationship_test.go index 1eb66bb4..23d79bbb 100644 --- a/schema/relationship_test.go +++ b/schema/relationship_test.go @@ -577,6 +577,193 @@ func TestEmbeddedHas(t *testing.T) { }) } +func TestPolymorphic(t *testing.T) { + t.Run("has one", func(t *testing.T) { + type Toy struct { + ID int + Name string + OwnerID int + OwnerType string + } + + type Cat struct { + ID int + Name string + Toy Toy `gorm:"polymorphic:Owner;"` + } + + s, err := schema.Parse(&Cat{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("Failed to parse schema, got error %v", err) + } + + checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{ + "Cat": { + Relations: map[string]Relation{ + "Toy": { + Name: "Toy", + Type: schema.HasOne, + Schema: "User", + FieldSchema: "Toy", + Polymorphic: Polymorphic{ID: "OwnerID", Type: "OwnerType", Value: "users"}, + References: []Reference{ + {ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"}, + }, + }, + }, + }, + }) + }) + + t.Run("has one with custom polymorphic type and id", func(t *testing.T) { + type Toy struct { + ID int + Name string + RefId int + Type string + } + + type Cat struct { + ID int + Name string + Toy Toy `gorm:"polymorphic:Owner;polymorphicType:Type;polymorphicId:RefId"` + } + + s, err := schema.Parse(&Cat{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("Failed to parse schema, got error %v", err) + } + + checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{ + "Cat": { + Relations: map[string]Relation{ + "Toy": { + Name: "Toy", + Type: schema.HasOne, + Schema: "User", + FieldSchema: "Toy", + Polymorphic: Polymorphic{ID: "ref_id", Type: "Type", Value: "users"}, + References: []Reference{ + {ForeignKey: "Type", ForeignSchema: "Toy", PrimaryValue: "users"}, + }, + }, + }, + }, + }) + }) + + t.Run("has one with only polymorphic type", func(t *testing.T) { + type Toy struct { + ID int + Name string + OwnerID int + Type string + } + + type Cat struct { + ID int + Name string + Toy Toy `gorm:"polymorphic:Owner;polymorphicType:Type"` + } + + s, err := schema.Parse(&Cat{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("Failed to parse schema, got error %v", err) + } + + checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{ + "Cat": { + Relations: map[string]Relation{ + "Toy": { + Name: "Toy", + Type: schema.HasOne, + Schema: "User", + FieldSchema: "Toy", + Polymorphic: Polymorphic{ID: "owner_id", Type: "Type", Value: "users"}, + References: []Reference{ + {ForeignKey: "Type", ForeignSchema: "Toy", PrimaryValue: "users"}, + }, + }, + }, + }, + }) + }) + + t.Run("has many", func(t *testing.T) { + type Toy struct { + ID int + Name string + OwnerID int + OwnerType string + } + + type Cat struct { + ID int + Name string + Toys []Toy `gorm:"polymorphic:Owner;"` + } + + s, err := schema.Parse(&Cat{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("Failed to parse schema, got error %v", err) + } + + checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{ + "Cat": { + Relations: map[string]Relation{ + "Toys": { + Name: "Toys", + Type: schema.HasMany, + Schema: "User", + FieldSchema: "Toy", + Polymorphic: Polymorphic{ID: "OwnerID", Type: "OwnerType", Value: "users"}, + References: []Reference{ + {ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"}, + }, + }, + }, + }, + }) + }) + + t.Run("has many with custom polymorphic type and id", func(t *testing.T) { + type Toy struct { + ID int + Name string + RefId int + Type string + } + + type Cat struct { + ID int + Name string + Toys []Toy `gorm:"polymorphicType:Type;polymorphicId:RefId"` + } + + s, err := schema.Parse(&Cat{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("Failed to parse schema, got error %v", err) + } + + checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{ + "Cat": { + Relations: map[string]Relation{ + "Toys": { + Name: "Toys", + Type: schema.HasMany, + Schema: "User", + FieldSchema: "Toy", + Polymorphic: Polymorphic{ID: "ref_id", Type: "Type", Value: "users"}, + References: []Reference{ + {ForeignKey: "Type", ForeignSchema: "Toy", PrimaryValue: "users"}, + }, + }, + }, + }, + }) + }) +} + func TestEmbeddedBelongsTo(t *testing.T) { type Country struct { ID int `gorm:"primaryKey"` diff --git a/tests/associations_has_many_test.go b/tests/associations_has_many_test.go index c31c4b40..b8e8ff5e 100644 --- a/tests/associations_has_many_test.go +++ b/tests/associations_has_many_test.go @@ -422,7 +422,7 @@ func TestPolymorphicHasManyAssociation(t *testing.T) { func TestPolymorphicHasManyAssociationForSlice(t *testing.T) { users := []User{ *GetUser("slice-hasmany-1", Config{Toys: 2}), - *GetUser("slice-hasmany-2", Config{Toys: 0}), + *GetUser("slice-hasmany-2", Config{Toys: 0, Tools: 2}), *GetUser("slice-hasmany-3", Config{Toys: 4}), } @@ -430,6 +430,7 @@ func TestPolymorphicHasManyAssociationForSlice(t *testing.T) { // Count AssertAssociationCount(t, users, "Toys", 6, "") + AssertAssociationCount(t, users, "Tools", 2, "") // Find var toys []Toy @@ -437,6 +438,14 @@ func TestPolymorphicHasManyAssociationForSlice(t *testing.T) { t.Errorf("toys count should be %v, but got %v", 6, len(toys)) } + // Find Tools (polymorphic with custom type and id) + var tools []Tools + DB.Model(&users).Association("Tools").Find(&tools) + + if len(tools) != 2 { + t.Errorf("tools count should be %v, but got %v", 2, len(tools)) + } + // Append DB.Model(&users).Association("Toys").Append( &Toy{Name: "toy-slice-append-1"}, diff --git a/tests/helper_test.go b/tests/helper_test.go index 1a4874ee..feb67f9e 100644 --- a/tests/helper_test.go +++ b/tests/helper_test.go @@ -23,6 +23,7 @@ type Config struct { Languages int Friends int NamedPet bool + Tools int } func GetUser(name string, config Config) *User { @@ -47,6 +48,10 @@ func GetUser(name string, config Config) *User { user.Toys = append(user.Toys, Toy{Name: name + "_toy_" + strconv.Itoa(i+1)}) } + for i := 0; i < config.Tools; i++ { + user.Tools = append(user.Tools, Tools{Name: name + "_tool_" + strconv.Itoa(i+1)}) + } + if config.Company { user.Company = Company{Name: "company-" + name} } @@ -118,11 +123,13 @@ func doCheckUser(t *testing.T, user User, expect User, unscoped bool) { if err := db(unscoped).Where("id = ?", user.ID).First(&newUser).Error; err != nil { t.Fatalf("errors happened when query: %v", err) } else { - AssertObjEqual(t, newUser, user, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active") + AssertObjEqual(t, newUser, user, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", + "CompanyID", "ManagerID", "Active") } } - AssertObjEqual(t, user, expect, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active") + AssertObjEqual(t, user, expect, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", + "ManagerID", "Active") t.Run("Account", func(t *testing.T) { AssertObjEqual(t, user.Account, expect.Account, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "UserID", "Number") @@ -133,7 +140,8 @@ func doCheckUser(t *testing.T, user User, expect User, unscoped bool) { } else { var account Account db(unscoped).First(&account, "user_id = ?", user.ID) - AssertObjEqual(t, account, user.Account, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "UserID", "Number") + AssertObjEqual(t, account, user.Account, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "UserID", + "Number") } } }) @@ -193,8 +201,10 @@ func doCheckUser(t *testing.T, user User, expect User, unscoped bool) { } else { var manager User db(unscoped).First(&manager, "id = ?", *user.ManagerID) - AssertObjEqual(t, manager, user.Manager, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active") - AssertObjEqual(t, manager, expect.Manager, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active") + AssertObjEqual(t, manager, user.Manager, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", + "Birthday", "CompanyID", "ManagerID", "Active") + AssertObjEqual(t, manager, expect.Manager, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", + "Birthday", "CompanyID", "ManagerID", "Active") } } else if user.ManagerID != nil { t.Errorf("Manager should not be created for zero value, got: %+v", user.ManagerID) @@ -215,7 +225,8 @@ func doCheckUser(t *testing.T, user User, expect User, unscoped bool) { }) for idx, team := range user.Team { - AssertObjEqual(t, team, expect.Team[idx], "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active") + AssertObjEqual(t, team, expect.Team[idx], "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", + "Birthday", "CompanyID", "ManagerID", "Active") } }) @@ -250,7 +261,8 @@ func doCheckUser(t *testing.T, user User, expect User, unscoped bool) { }) for idx, friend := range user.Friends { - AssertObjEqual(t, friend, expect.Friends[idx], "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active") + AssertObjEqual(t, friend, expect.Friends[idx], "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", + "Birthday", "CompanyID", "ManagerID", "Active") } }) } diff --git a/tests/migrate_test.go b/tests/migrate_test.go index cfd3e0ac..28fa315b 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -18,7 +18,7 @@ import ( ) func TestMigrate(t *testing.T) { - allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}} + allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Tools{}} rand.Seed(time.Now().UnixNano()) rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] }) DB.Migrator().DropTable("user_speaks", "user_friends", "ccc") @@ -34,7 +34,7 @@ func TestMigrate(t *testing.T) { if tables, err := DB.Migrator().GetTables(); err != nil { t.Fatalf("Failed to get database all tables, but got error %v", err) } else { - for _, t1 := range []string{"users", "accounts", "pets", "companies", "toys", "languages"} { + for _, t1 := range []string{"users", "accounts", "pets", "companies", "toys", "languages", "tools"} { hasTable := false for _, t2 := range tables { if t2 == t1 { @@ -93,7 +93,8 @@ func TestAutoMigrateInt8PG(t *testing.T) { Test: func(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) { sql, _ := fc() if strings.HasPrefix(sql, "ALTER TABLE \"migrate_ints\" ALTER COLUMN \"int8\" TYPE smallint") { - t.Fatalf("shouldn't execute ALTER COLUMN TYPE if such type is already existed in DB schema: sql: %s", sql) + t.Fatalf("shouldn't execute ALTER COLUMN TYPE if such type is already existed in DB schema: sql: %s", + sql) } }, } @@ -432,40 +433,50 @@ func TestTiDBMigrateColumns(t *testing.T) { switch columnType.Name() { case "id": if v, ok := columnType.PrimaryKey(); !ok || !v { - t.Fatalf("column id primary key should be correct, name: %v, column: %#v", columnType.Name(), columnType) + t.Fatalf("column id primary key should be correct, name: %v, column: %#v", columnType.Name(), + columnType) } case "name": dataType := DB.Dialector.DataTypeOf(stmt.Schema.LookUpField(columnType.Name())) if !strings.Contains(strings.ToUpper(dataType), strings.ToUpper(columnType.DatabaseTypeName())) { - t.Fatalf("column name type should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), columnType.DatabaseTypeName(), dataType, columnType) + t.Fatalf("column name type should be correct, name: %v, length: %v, expects: %v, column: %#v", + columnType.Name(), columnType.DatabaseTypeName(), dataType, columnType) } if length, ok := columnType.Length(); !ok || length != 100 { - t.Fatalf("column name length should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), length, 100, columnType) + t.Fatalf("column name length should be correct, name: %v, length: %v, expects: %v, column: %#v", + columnType.Name(), length, 100, columnType) } case "age": if v, ok := columnType.DefaultValue(); !ok || v != "18" { - t.Fatalf("column age default value should be correct, name: %v, column: %#v", columnType.Name(), columnType) + t.Fatalf("column age default value should be correct, name: %v, column: %#v", columnType.Name(), + columnType) } if v, ok := columnType.Comment(); !ok || v != "my age" { - t.Fatalf("column age comment should be correct, name: %v, column: %#v", columnType.Name(), columnType) + t.Fatalf("column age comment should be correct, name: %v, column: %#v", columnType.Name(), + columnType) } case "code": if v, ok := columnType.Unique(); !ok || !v { - t.Fatalf("column code unique should be correct, name: %v, column: %#v", columnType.Name(), columnType) + t.Fatalf("column code unique should be correct, name: %v, column: %#v", columnType.Name(), + columnType) } if v, ok := columnType.DefaultValue(); !ok || v != "hello" { - t.Fatalf("column code default value should be correct, name: %v, column: %#v, default value: %v", columnType.Name(), columnType, v) + t.Fatalf("column code default value should be correct, name: %v, column: %#v, default value: %v", + columnType.Name(), columnType, v) } if v, ok := columnType.Comment(); !ok || v != "my code2" { - t.Fatalf("column code comment should be correct, name: %v, column: %#v", columnType.Name(), columnType) + t.Fatalf("column code comment should be correct, name: %v, column: %#v", columnType.Name(), + columnType) } case "code2": // Code2 string `gorm:"comment:my code2;default:hello"` if v, ok := columnType.DefaultValue(); !ok || v != "hello" { - t.Fatalf("column code default value should be correct, name: %v, column: %#v, default value: %v", columnType.Name(), columnType, v) + t.Fatalf("column code default value should be correct, name: %v, column: %#v, default value: %v", + columnType.Name(), columnType, v) } if v, ok := columnType.Comment(); !ok || v != "my code2" { - t.Fatalf("column code comment should be correct, name: %v, column: %#v", columnType.Name(), columnType) + t.Fatalf("column code comment should be correct, name: %v, column: %#v", columnType.Name(), + columnType) } } } @@ -497,7 +508,8 @@ func TestTiDBMigrateColumns(t *testing.T) { t.Fatalf("Failed to add column, got %v", err) } - if err := DB.Table("column_structs").Migrator().RenameColumn(&NewColumnStruct{}, "NewName", "new_new_name"); err != nil { + if err := DB.Table("column_structs").Migrator().RenameColumn(&NewColumnStruct{}, "NewName", + "new_new_name"); err != nil { t.Fatalf("Failed to add column, got %v", err) } @@ -561,36 +573,45 @@ func TestMigrateColumns(t *testing.T) { switch columnType.Name() { case "id": if v, ok := columnType.PrimaryKey(); !ok || !v { - t.Fatalf("column id primary key should be correct, name: %v, column: %#v", columnType.Name(), columnType) + t.Fatalf("column id primary key should be correct, name: %v, column: %#v", columnType.Name(), + columnType) } case "name": dataType := DB.Dialector.DataTypeOf(stmt.Schema.LookUpField(columnType.Name())) if !strings.Contains(strings.ToUpper(dataType), strings.ToUpper(columnType.DatabaseTypeName())) { - t.Fatalf("column name type should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), columnType.DatabaseTypeName(), dataType, columnType) + t.Fatalf("column name type should be correct, name: %v, length: %v, expects: %v, column: %#v", + columnType.Name(), columnType.DatabaseTypeName(), dataType, columnType) } if length, ok := columnType.Length(); !sqlite && (!ok || length != 100) { - t.Fatalf("column name length should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), length, 100, columnType) + t.Fatalf("column name length should be correct, name: %v, length: %v, expects: %v, column: %#v", + columnType.Name(), length, 100, columnType) } case "age": if v, ok := columnType.DefaultValue(); !ok || v != "18" { - t.Fatalf("column age default value should be correct, name: %v, column: %#v", columnType.Name(), columnType) + t.Fatalf("column age default value should be correct, name: %v, column: %#v", columnType.Name(), + columnType) } if v, ok := columnType.Comment(); !sqlite && !sqlserver && (!ok || v != "my age") { - t.Fatalf("column age comment should be correct, name: %v, column: %#v", columnType.Name(), columnType) + t.Fatalf("column age comment should be correct, name: %v, column: %#v", columnType.Name(), + columnType) } case "code": if v, ok := columnType.Unique(); !ok || !v { - t.Fatalf("column code unique should be correct, name: %v, column: %#v", columnType.Name(), columnType) + t.Fatalf("column code unique should be correct, name: %v, column: %#v", columnType.Name(), + columnType) } if v, ok := columnType.DefaultValue(); !sqlserver && (!ok || v != "hello") { - t.Fatalf("column code default value should be correct, name: %v, column: %#v, default value: %v", columnType.Name(), columnType, v) + t.Fatalf("column code default value should be correct, name: %v, column: %#v, default value: %v", + columnType.Name(), columnType, v) } if v, ok := columnType.Comment(); !sqlite && !sqlserver && (!ok || v != "my code2") { - t.Fatalf("column code comment should be correct, name: %v, column: %#v", columnType.Name(), columnType) + t.Fatalf("column code comment should be correct, name: %v, column: %#v", columnType.Name(), + columnType) } case "code2": if v, ok := columnType.Unique(); !sqlserver && (!ok || !v) { - t.Fatalf("column code2 unique should be correct, name: %v, column: %#v", columnType.Name(), columnType) + t.Fatalf("column code2 unique should be correct, name: %v, column: %#v", columnType.Name(), + columnType) } case "code3": // TODO @@ -627,7 +648,8 @@ func TestMigrateColumns(t *testing.T) { t.Fatalf("Failed to add column, got %v", err) } - if err := DB.Table("column_structs").Migrator().RenameColumn(&NewColumnStruct{}, "NewName", "new_new_name"); err != nil { + if err := DB.Table("column_structs").Migrator().RenameColumn(&NewColumnStruct{}, "NewName", + "new_new_name"); err != nil { t.Fatalf("Failed to add column, got %v", err) } @@ -1555,7 +1577,8 @@ func TestMigrateIgnoreRelations(t *testing.T) { func TestMigrateView(t *testing.T) { DB.Save(GetUser("joins-args-db", Config{Pets: 2})) - if err := DB.Migrator().CreateView("invalid_users_pets", gorm.ViewOption{Query: nil}); err != gorm.ErrSubQueryRequired { + if err := DB.Migrator().CreateView("invalid_users_pets", + gorm.ViewOption{Query: nil}); err != gorm.ErrSubQueryRequired { t.Fatalf("no view should be created, got %v", err) } @@ -1624,17 +1647,20 @@ func TestMigrateExistingBoolColumnPG(t *testing.T) { switch columnType.Name() { case "id": if v, ok := columnType.PrimaryKey(); !ok || !v { - t.Fatalf("column id primary key should be correct, name: %v, column: %#v", columnType.Name(), columnType) + t.Fatalf("column id primary key should be correct, name: %v, column: %#v", columnType.Name(), + columnType) } case "string_bool": dataType := DB.Dialector.DataTypeOf(stmt.Schema.LookUpField(columnType.Name())) if !strings.Contains(strings.ToUpper(dataType), strings.ToUpper(columnType.DatabaseTypeName())) { - t.Fatalf("column name type should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), columnType.DatabaseTypeName(), dataType, columnType) + t.Fatalf("column name type should be correct, name: %v, length: %v, expects: %v, column: %#v", + columnType.Name(), columnType.DatabaseTypeName(), dataType, columnType) } case "smallint_bool": dataType := DB.Dialector.DataTypeOf(stmt.Schema.LookUpField(columnType.Name())) if !strings.Contains(strings.ToUpper(dataType), strings.ToUpper(columnType.DatabaseTypeName())) { - t.Fatalf("column name type should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), columnType.DatabaseTypeName(), dataType, columnType) + t.Fatalf("column name type should be correct, name: %v, length: %v, expects: %v, column: %#v", + columnType.Name(), columnType.DatabaseTypeName(), dataType, columnType) } } } @@ -1659,7 +1685,8 @@ func TestTableType(t *testing.T) { DB.Migrator().DropTable(&City{}) - if err := DB.Set("gorm:table_options", fmt.Sprintf("ENGINE InnoDB COMMENT '%s'", tblComment)).AutoMigrate(&City{}); err != nil { + if err := DB.Set("gorm:table_options", + fmt.Sprintf("ENGINE InnoDB COMMENT '%s'", tblComment)).AutoMigrate(&City{}); err != nil { t.Fatalf("failed to migrate cities tables, got error: %v", err) } diff --git a/tests/tests_test.go b/tests/tests_test.go index f9c6cab5..a127734e 100644 --- a/tests/tests_test.go +++ b/tests/tests_test.go @@ -107,7 +107,7 @@ func OpenTestConnection(cfg *gorm.Config) (db *gorm.DB, err error) { func RunMigrations() { var err error - allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Coupon{}, &CouponProduct{}, &Order{}, &Parent{}, &Child{}} + allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Coupon{}, &CouponProduct{}, &Order{}, &Parent{}, &Child{}, &Tools{}} rand.Seed(time.Now().UnixNano()) rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] }) diff --git a/utils/tests/models.go b/utils/tests/models.go index a4bad2fc..f9f4f50e 100644 --- a/utils/tests/models.go +++ b/utils/tests/models.go @@ -20,7 +20,8 @@ type User struct { Account Account Pets []*Pet NamedPet *Pet - Toys []Toy `gorm:"polymorphic:Owner"` + Toys []Toy `gorm:"polymorphic:Owner"` + Tools []Tools `gorm:"polymorphicType:Type;polymorphicId:CustomID"` CompanyID *int Company Company ManagerID *uint @@ -51,6 +52,13 @@ type Toy struct { OwnerType string } +type Tools struct { + gorm.Model + Name string + CustomID string + Type string +} + type Company struct { ID int Name string From 436cca753cd784969a19477f022db4eb3d84f2ec Mon Sep 17 00:00:00 2001 From: Stephano George Date: Sat, 23 Dec 2023 21:19:41 +0800 Subject: [PATCH 202/231] fix: join and select mytable.* not working (#6761) * fix: select mytable.* not working * fix: select mytable.*: will not match `mytable."*"`. feat: increase readability of code matching table name column name --- statement.go | 22 ++++++++++++++++++---- statement_test.go | 10 ++++++++-- 2 files changed, 26 insertions(+), 6 deletions(-) diff --git a/statement.go b/statement.go index 59c0b772..b24228b2 100644 --- a/statement.go +++ b/statement.go @@ -665,7 +665,21 @@ func (stmt *Statement) Changed(fields ...string) bool { return false } -var nameMatcher = regexp.MustCompile(`^(?:\W?(\w+?)\W?\.)?\W?(\w+?)\W?$`) +var matchName = func() func(tableColumn string) (table, column string) { + nameMatcher := regexp.MustCompile(`^(?:\W?(\w+?)\W?\.)?(?:(\*)|\W?(\w+?)\W?)$`) + return func(tableColumn string) (table, column string) { + if matches := nameMatcher.FindStringSubmatch(tableColumn); len(matches) == 4 { + table = matches[1] + star := matches[2] + columnName := matches[3] + if star != "" { + return table, star + } + return table, columnName + } + return "", "" + } +}() // SelectAndOmitColumns get select and omit columns, select -> true, omit -> false func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (map[string]bool, bool) { @@ -686,13 +700,13 @@ func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) ( } } else if field := stmt.Schema.LookUpField(column); field != nil && field.DBName != "" { results[field.DBName] = result - } else if matches := nameMatcher.FindStringSubmatch(column); len(matches) == 3 && (matches[1] == stmt.Table || matches[1] == "") { - if matches[2] == "*" { + } else if table, col := matchName(column); col != "" && (table == stmt.Table || table == "") { + if col == "*" { for _, dbName := range stmt.Schema.DBNames { results[dbName] = result } } else { - results[matches[2]] = result + results[col] = result } } else { results[column] = result diff --git a/statement_test.go b/statement_test.go index 648bc875..0995d547 100644 --- a/statement_test.go +++ b/statement_test.go @@ -56,9 +56,15 @@ func TestNameMatcher(t *testing.T) { "`name_1`": {"", "name_1"}, "`Name_1`": {"", "Name_1"}, "`Table`.`nAme`": {"Table", "nAme"}, + "my_table.*": {"my_table", "*"}, + "`my_table`.*": {"my_table", "*"}, + "User__Company.*": {"User__Company", "*"}, + "`User__Company`.*": {"User__Company", "*"}, + `"User__Company".*`: {"User__Company", "*"}, + `"table"."*"`: {"", ""}, } { - if matches := nameMatcher.FindStringSubmatch(k); len(matches) < 3 || matches[1] != v[0] || matches[2] != v[1] { - t.Errorf("failed to match value: %v, got %v, expect: %v", k, matches, v) + if table, column := matchName(k); table != v[0] || column != v[1] { + t.Errorf("failed to match value: %v, got %v, expect: %v", k, []string{table, column}, v) } } } From 87decced23be0ce21929fe393fc4fa3a936b1ec8 Mon Sep 17 00:00:00 2001 From: iTanken <23544702+iTanken@users.noreply.github.com> Date: Thu, 28 Dec 2023 19:53:36 +0800 Subject: [PATCH 203/231] fix: ExplainSQL using consecutive pairs of escaper in SQL string represents an escaper (#6766) Preventing it from being interpreted as the string terminator. This is a widely used escape mechanism in SQL standards and is applicable in most relational databases. --- logger/sql.go | 10 +++++----- logger/sql_test.go | 16 ++++++++-------- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/logger/sql.go b/logger/sql.go index 13e5d957..8ce8d8b1 100644 --- a/logger/sql.go +++ b/logger/sql.go @@ -79,17 +79,17 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a case reflect.Bool: vars[idx] = fmt.Sprintf("%t", reflectValue.Interface()) case reflect.String: - vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, "\\"+escaper) + escaper + vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, escaper+escaper) + escaper default: if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) { - vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, "\\"+escaper) + escaper + vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, escaper+escaper) + escaper } else { vars[idx] = nullStr } } case []byte: if s := string(v); isPrintable(s) { - vars[idx] = escaper + strings.ReplaceAll(s, escaper, "\\"+escaper) + escaper + vars[idx] = escaper + strings.ReplaceAll(s, escaper, escaper+escaper) + escaper } else { vars[idx] = escaper + "" + escaper } @@ -100,7 +100,7 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a case float64: vars[idx] = strconv.FormatFloat(v, 'f', -1, 64) case string: - vars[idx] = escaper + strings.ReplaceAll(v, escaper, "\\"+escaper) + escaper + vars[idx] = escaper + strings.ReplaceAll(v, escaper, escaper+escaper) + escaper default: rv := reflect.ValueOf(v) if v == nil || !rv.IsValid() || rv.Kind() == reflect.Ptr && rv.IsNil() { @@ -117,7 +117,7 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a return } } - vars[idx] = escaper + strings.ReplaceAll(fmt.Sprint(v), escaper, "\\"+escaper) + escaper + vars[idx] = escaper + strings.ReplaceAll(fmt.Sprint(v), escaper, escaper+escaper) + escaper } } } diff --git a/logger/sql_test.go b/logger/sql_test.go index a82fa546..036ef3a4 100644 --- a/logger/sql_test.go +++ b/logger/sql_test.go @@ -31,7 +31,7 @@ func (s ExampleStruct) Value() (driver.Value, error) { } func format(v []byte, escaper string) string { - return escaper + strings.ReplaceAll(string(v), escaper, "\\"+escaper) + escaper + return escaper + strings.ReplaceAll(string(v), escaper, escaper+escaper) + escaper } func TestExplainSQL(t *testing.T) { @@ -40,7 +40,7 @@ func TestExplainSQL(t *testing.T) { var ( tt = now.MustParse("2020-02-23 11:10:10") myrole = role("admin") - pwd = password([]byte("pass")) + pwd = password("pass") jsVal = []byte(`{"Name":"test","Val":"test"}`) js = JSON(jsVal) esVal = []byte(`{"Name":"test","Val":"test"}`) @@ -57,13 +57,13 @@ func TestExplainSQL(t *testing.T) { SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", NumericRegexp: nil, Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd}, - Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass")`, + Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass")`, }, { SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", NumericRegexp: nil, Vars: []interface{}{"jinzhu?", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd}, - Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu?", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass")`, + Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu?", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass")`, }, { SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (@p1, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9, @p10, @p11)", @@ -87,25 +87,25 @@ func TestExplainSQL(t *testing.T) { SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", NumericRegexp: nil, Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, js, es}, - Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)), + Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)), }, { SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", NumericRegexp: nil, Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es}, - Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)), + Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)), }, { SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", NumericRegexp: nil, Vars: []interface{}{"jinzhu", 1, 0.1753607109, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es}, - Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 0.1753607109, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)), + Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 0.1753607109, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)), }, { SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", NumericRegexp: nil, Vars: []interface{}{"jinzhu", 1, float32(999.99), true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es}, - Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)), + Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)), }, } From 940358e0ddd1e6d5e893184943a4fa71a0c28521 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 12 Jan 2024 16:42:21 +0800 Subject: [PATCH 204/231] Fix tests doesn't follow https://gorm.io/docs/method_chaining.html convention --- chainable_api.go | 27 +++------------------------ clause/where.go | 11 +++++++++++ clause/where_test.go | 4 ++-- finisher_api.go | 8 ++++++-- statement.go | 16 ++++++++-------- tests/go.mod | 10 +++++----- tests/query_test.go | 4 ++-- tests/scopes_test.go | 10 +++++++--- tests/sql_builder_test.go | 2 +- 9 files changed, 45 insertions(+), 47 deletions(-) diff --git a/chainable_api.go b/chainable_api.go index 3dc7256e..1ec9b865 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -367,33 +367,12 @@ func (db *DB) Scopes(funcs ...func(*DB) *DB) (tx *DB) { } func (db *DB) executeScopes() (tx *DB) { - tx = db.getInstance() scopes := db.Statement.scopes - if len(scopes) == 0 { - return tx - } - tx.Statement.scopes = nil - - conditions := make([]clause.Interface, 0, 4) - if cs, ok := tx.Statement.Clauses["WHERE"]; ok && cs.Expression != nil { - conditions = append(conditions, cs.Expression.(clause.Interface)) - cs.Expression = nil - tx.Statement.Clauses["WHERE"] = cs - } - + db.Statement.scopes = nil for _, scope := range scopes { - tx = scope(tx) - if cs, ok := tx.Statement.Clauses["WHERE"]; ok && cs.Expression != nil { - conditions = append(conditions, cs.Expression.(clause.Interface)) - cs.Expression = nil - tx.Statement.Clauses["WHERE"] = cs - } + db = scope(db) } - - for _, condition := range conditions { - tx.Statement.AddClause(condition) - } - return tx + return db } // Preload preload associations with given conditions diff --git a/clause/where.go b/clause/where.go index a29401cf..46d0b319 100644 --- a/clause/where.go +++ b/clause/where.go @@ -21,6 +21,12 @@ func (where Where) Name() string { // Build build where clause func (where Where) Build(builder Builder) { + if len(where.Exprs) == 1 { + if andCondition, ok := where.Exprs[0].(AndConditions); ok { + where.Exprs = andCondition.Exprs + } + } + // Switch position if the first query expression is a single Or condition for idx, expr := range where.Exprs { if v, ok := expr.(OrConditions); !ok || len(v.Exprs) > 1 { @@ -147,6 +153,11 @@ func Not(exprs ...Expression) Expression { if len(exprs) == 0 { return nil } + if len(exprs) == 1 { + if andCondition, ok := exprs[0].(AndConditions); ok { + exprs = andCondition.Exprs + } + } return NotConditions{Exprs: exprs} } diff --git a/clause/where_test.go b/clause/where_test.go index 35e3dbee..aa9d06eb 100644 --- a/clause/where_test.go +++ b/clause/where_test.go @@ -63,7 +63,7 @@ func TestWhere(t *testing.T) { []clause.Interface{clause.Select{}, clause.From{}, clause.Where{ Exprs: []clause.Expression{clause.And(clause.Eq{Column: "age", Value: 18}, clause.Or(clause.Neq{Column: "name", Value: "jinzhu"}))}, }}, - "SELECT * FROM `users` WHERE (`age` = ? OR `name` <> ?)", + "SELECT * FROM `users` WHERE `age` = ? OR `name` <> ?", []interface{}{18, "jinzhu"}, }, { @@ -94,7 +94,7 @@ func TestWhere(t *testing.T) { clause.And(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false})), }, }}, - "SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `score` <= ?)", + "SELECT * FROM `users` WHERE `users`.`id` <> ? AND `score` <= ?", []interface{}{"1", 100}, }, { diff --git a/finisher_api.go b/finisher_api.go index f80aa6c0..f97571ed 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -376,8 +376,12 @@ func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) { } else if len(db.Statement.assigns) > 0 { exprs := tx.Statement.BuildCondition(db.Statement.assigns[0], db.Statement.assigns[1:]...) assigns := map[string]interface{}{} - for _, expr := range exprs { - if eq, ok := expr.(clause.Eq); ok { + for i := 0; i < len(exprs); i++ { + expr := exprs[i] + + if eq, ok := expr.(clause.AndConditions); ok { + exprs = append(exprs, eq.Exprs...) + } else if eq, ok := expr.(clause.Eq); ok { switch column := eq.Column.(type) { case string: assigns[column] = eq.Value diff --git a/statement.go b/statement.go index b24228b2..ae79aa32 100644 --- a/statement.go +++ b/statement.go @@ -326,7 +326,7 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] case *DB: v.executeScopes() - if cs, ok := v.Statement.Clauses["WHERE"]; ok && cs.Expression != nil { + if cs, ok := v.Statement.Clauses["WHERE"]; ok { if where, ok := cs.Expression.(clause.Where); ok { if len(where.Exprs) == 1 { if orConds, ok := where.Exprs[0].(clause.OrConditions); ok { @@ -334,13 +334,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] } } conds = append(conds, clause.And(where.Exprs...)) - } else { + } else if cs.Expression != nil { conds = append(conds, cs.Expression) } - if v.Statement == stmt { - cs.Expression = nil - stmt.Statement.Clauses["WHERE"] = cs - } } case map[interface{}]interface{}: for i, j := range v { @@ -451,8 +447,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] if len(values) > 0 { conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: values}) + return []clause.Expression{clause.And(conds...)} } - return conds + return nil } } @@ -461,7 +458,10 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] } } - return conds + if len(conds) > 0 { + return []clause.Expression{clause.And(conds...)} + } + return nil } // Build build sql with clauses names diff --git a/tests/go.mod b/tests/go.mod index 71079050..07fedc45 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -3,7 +3,7 @@ module gorm.io/gorm/tests go 1.18 require ( - github.com/google/uuid v1.4.0 + github.com/google/uuid v1.5.0 github.com/jinzhu/now v1.1.5 github.com/lib/pq v1.10.9 gorm.io/driver/mysql v1.5.2 @@ -18,12 +18,12 @@ require ( github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect github.com/golang-sql/sqlexp v0.1.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect - github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect - github.com/jackc/pgx/v5 v5.5.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 // indirect + github.com/jackc/pgx/v5 v5.5.1 // indirect github.com/jinzhu/inflection v1.0.0 // indirect - github.com/mattn/go-sqlite3 v1.14.18 // indirect + github.com/mattn/go-sqlite3 v1.14.19 // indirect github.com/microsoft/go-mssqldb v1.6.0 // indirect - golang.org/x/crypto v0.15.0 // indirect + golang.org/x/crypto v0.18.0 // indirect golang.org/x/text v0.14.0 // indirect ) diff --git a/tests/query_test.go b/tests/query_test.go index 5728378d..cadf7164 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -1118,12 +1118,12 @@ func TestSearchWithStruct(t *testing.T) { } result = dryRunDB.Where(User{Name: "jinzhu", Age: 18}).Find(&User{}) - if !regexp.MustCompile(`WHERE .users.\..name. = .{1,3} AND .users.\..age. = .{1,3} AND .users.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) { + if !regexp.MustCompile(`WHERE \(.users.\..name. = .{1,3} AND .users.\..age. = .{1,3}\) AND .users.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) { t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String()) } result = dryRunDB.Where(User{Name: "jinzhu"}, "name", "Age").Find(&User{}) - if !regexp.MustCompile(`WHERE .users.\..name. = .{1,3} AND .users.\..age. = .{1,3} AND .users.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) { + if !regexp.MustCompile(`WHERE \(.users.\..name. = .{1,3} AND .users.\..age. = .{1,3}\) AND .users.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) { t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String()) } diff --git a/tests/scopes_test.go b/tests/scopes_test.go index 52c6b37b..84aeb990 100644 --- a/tests/scopes_test.go +++ b/tests/scopes_test.go @@ -84,7 +84,9 @@ func TestComplexScopes(t *testing.T) { queryFn: func(tx *gorm.DB) *gorm.DB { return tx.Scopes( func(d *gorm.DB) *gorm.DB { return d.Where("a = 1") }, - func(d *gorm.DB) *gorm.DB { return d.Where(d.Or("b = 2").Or("c = 3")) }, + func(d *gorm.DB) *gorm.DB { + return d.Where(DB.Or("b = 2").Or("c = 3")) + }, ).Find(&Language{}) }, expected: `SELECT * FROM "languages" WHERE a = 1 AND (b = 2 OR c = 3)`, @@ -93,7 +95,9 @@ func TestComplexScopes(t *testing.T) { queryFn: func(tx *gorm.DB) *gorm.DB { return tx.Where("z = 0").Scopes( func(d *gorm.DB) *gorm.DB { return d.Where("a = 1") }, - func(d *gorm.DB) *gorm.DB { return d.Or(d.Where("b = 2").Or("c = 3")) }, + func(d *gorm.DB) *gorm.DB { + return d.Or(DB.Where("b = 2").Or("c = 3")) + }, ).Find(&Language{}) }, expected: `SELECT * FROM "languages" WHERE z = 0 AND a = 1 OR (b = 2 OR c = 3)`, @@ -104,7 +108,7 @@ func TestComplexScopes(t *testing.T) { func(d *gorm.DB) *gorm.DB { return d.Model(&Language{}) }, func(d *gorm.DB) *gorm.DB { return d. - Or(d.Scopes( + Or(DB.Scopes( func(d *gorm.DB) *gorm.DB { return d.Where("a = 1") }, func(d *gorm.DB) *gorm.DB { return d.Where("b = 2") }, )). diff --git a/tests/sql_builder_test.go b/tests/sql_builder_test.go index 022e0495..0c204db4 100644 --- a/tests/sql_builder_test.go +++ b/tests/sql_builder_test.go @@ -388,7 +388,7 @@ func TestToSQL(t *testing.T) { sql = DB.ToSQL(func(tx *gorm.DB) *gorm.DB { return tx.Model(&User{}).Where(&User{Name: "foo", Age: 20}).Limit(10).Offset(5).Order("name ASC").First(&User{}) }) - assertEqualSQL(t, `SELECT * FROM "users" WHERE "users"."name" = 'foo' AND "users"."age" = 20 AND "users"."deleted_at" IS NULL ORDER BY name ASC,"users"."id" LIMIT 1 OFFSET 5`, sql) + assertEqualSQL(t, `SELECT * FROM "users" WHERE ("users"."name" = 'foo' AND "users"."age" = 20) AND "users"."deleted_at" IS NULL ORDER BY name ASC,"users"."id" LIMIT 1 OFFSET 5`, sql) // last and unscoped sql = DB.ToSQL(func(tx *gorm.DB) *gorm.DB { From 0123dd45094295fade41e13550cd305eb5e3a848 Mon Sep 17 00:00:00 2001 From: Jacky Date: Fri, 12 Jan 2024 18:09:22 +0900 Subject: [PATCH 205/231] fix: ignore .gen.go suffix in logger to get the real caller when using gen #6697 (#6785) --- utils/utils.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/utils/utils.go b/utils/utils.go index c8fec5b0..a4d8ac25 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -35,7 +35,8 @@ func FileWithLineNum() string { // the second caller usually from gorm internal, so set i start from 2 for i := 2; i < 15; i++ { _, file, line, ok := runtime.Caller(i) - if ok && (!strings.HasPrefix(file, gormSourceDir) || strings.HasSuffix(file, "_test.go")) { + if ok && (!strings.HasPrefix(file, gormSourceDir) || strings.HasSuffix(file, "_test.go")) && + !strings.HasSuffix(file, ".gen.go") { return file + ":" + strconv.FormatInt(int64(line), 10) } } From e043924fe79a38d0f3bf5df81fd795836c809415 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 29 Jan 2024 10:34:20 +0800 Subject: [PATCH 206/231] chore(deps): bump actions/cache from 3 to 4 (#6802) Bumps [actions/cache](https://github.com/actions/cache) from 3 to 4. - [Release notes](https://github.com/actions/cache/releases) - [Changelog](https://github.com/actions/cache/blob/main/RELEASES.md) - [Commits](https://github.com/actions/cache/compare/v3...v4) --- updated-dependencies: - dependency-name: actions/cache dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/tests.yml | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 380231b9..af471d20 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -30,7 +30,7 @@ jobs: uses: actions/checkout@v4 - name: go mod package cache - uses: actions/cache@v3 + uses: actions/cache@v4 with: path: ~/go/pkg/mod key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} @@ -73,7 +73,7 @@ jobs: uses: actions/checkout@v4 - name: go mod package cache - uses: actions/cache@v3 + uses: actions/cache@v4 with: path: ~/go/pkg/mod key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} @@ -116,7 +116,7 @@ jobs: uses: actions/checkout@v4 - name: go mod package cache - uses: actions/cache@v3 + uses: actions/cache@v4 with: path: ~/go/pkg/mod key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} @@ -159,7 +159,7 @@ jobs: uses: actions/checkout@v4 - name: go mod package cache - uses: actions/cache@v3 + uses: actions/cache@v4 with: path: ~/go/pkg/mod key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} @@ -202,7 +202,7 @@ jobs: uses: actions/checkout@v4 - name: go mod package cache - uses: actions/cache@v3 + uses: actions/cache@v4 with: path: ~/go/pkg/mod key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} @@ -235,7 +235,7 @@ jobs: - name: go mod package cache - uses: actions/cache@v3 + uses: actions/cache@v4 with: path: ~/go/pkg/mod key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} From 418ee3fc1939d87a05bbb8ac6d7c7223e2c4571f Mon Sep 17 00:00:00 2001 From: black-06 Date: Mon, 29 Jan 2024 11:34:57 +0800 Subject: [PATCH 207/231] fix: preload shouldn't overwrite the value of join (#6771) * fix: preload shouldn't overwrite the value of join * fix lint * fix: join may automatically add nested query --- callbacks/preload.go | 74 +++++++++++++++++++++++++++++++++++++------ callbacks/query.go | 33 +++++-------------- tests/preload_test.go | 57 +++++++++++++++++++++++++++++++++ 3 files changed, 129 insertions(+), 35 deletions(-) diff --git a/callbacks/preload.go b/callbacks/preload.go index 15669c84..25ecfe76 100644 --- a/callbacks/preload.go +++ b/callbacks/preload.go @@ -3,6 +3,7 @@ package callbacks import ( "fmt" "reflect" + "sort" "strings" "gorm.io/gorm" @@ -82,27 +83,80 @@ func embeddedValues(embeddedRelations *schema.Relationships) []string { return names } -func preloadEmbedded(tx *gorm.DB, relationships *schema.Relationships, s *schema.Schema, preloads map[string][]interface{}, as []interface{}) error { - if relationships == nil { - return nil +// preloadEntryPoint enters layer by layer. It will call real preload if it finds the right entry point. +// If the current relationship is embedded or joined, current query will be ignored. +// +//nolint:cyclop +func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relationships, preloads map[string][]interface{}, associationsConds []interface{}) error { + preloadMap := parsePreloadMap(db.Statement.Schema, preloads) + + // avoid random traversal of the map + preloadNames := make([]string, 0, len(preloadMap)) + for key := range preloadMap { + preloadNames = append(preloadNames, key) } - preloadMap := parsePreloadMap(s, preloads) - for name := range preloadMap { - if embeddedRelations := relationships.EmbeddedRelations[name]; embeddedRelations != nil { - if err := preloadEmbedded(tx, embeddedRelations, s, preloadMap[name], as); err != nil { + sort.Strings(preloadNames) + + isJoined := func(name string) (joined bool, nestedJoins []string) { + for _, join := range joins { + if _, ok := relationships.Relations[join]; ok && name == join { + joined = true + continue + } + joinNames := strings.SplitN(join, ".", 2) + if len(joinNames) == 2 { + if _, ok := relationships.Relations[joinNames[0]]; ok && name == joinNames[0] { + joined = true + nestedJoins = append(nestedJoins, joinNames[1]) + } + } + } + return joined, nestedJoins + } + + for _, name := range preloadNames { + if relations := relationships.EmbeddedRelations[name]; relations != nil { + if err := preloadEntryPoint(db, joins, relations, preloadMap[name], associationsConds); err != nil { return err } } else if rel := relationships.Relations[name]; rel != nil { - if err := preload(tx, rel, append(preloads[name], as), preloadMap[name]); err != nil { - return err + if joined, nestedJoins := isJoined(name); joined { + reflectValue := rel.Field.ReflectValueOf(db.Statement.Context, db.Statement.ReflectValue) + tx := preloadDB(db, reflectValue, reflectValue.Interface()) + if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil { + return err + } + } else { + tx := db.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks}) + tx.Statement.ReflectValue = db.Statement.ReflectValue + tx.Statement.Unscoped = db.Statement.Unscoped + if err := preload(tx, rel, append(preloads[name], associationsConds...), preloadMap[name]); err != nil { + return err + } } } else { - return fmt.Errorf("%s: %w (embedded) for schema %s", name, gorm.ErrUnsupportedRelation, s.Name) + return fmt.Errorf("%s: %w for schema %s", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name) } } return nil } +func preloadDB(db *gorm.DB, reflectValue reflect.Value, dest interface{}) *gorm.DB { + tx := db.Session(&gorm.Session{Context: db.Statement.Context, NewDB: true, SkipHooks: db.Statement.SkipHooks, Initialized: true}) + db.Statement.Settings.Range(func(k, v interface{}) bool { + tx.Statement.Settings.Store(k, v) + return true + }) + + if err := tx.Statement.Parse(dest); err != nil { + tx.AddError(err) + return tx + } + tx.Statement.ReflectValue = reflectValue + tx.Statement.Unscoped = db.Statement.Unscoped + return tx +} + func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) error { var ( reflectValue = tx.Statement.ReflectValue diff --git a/callbacks/query.go b/callbacks/query.go index e89dd199..2a82eaba 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -3,7 +3,6 @@ package callbacks import ( "fmt" "reflect" - "sort" "strings" "gorm.io/gorm" @@ -254,7 +253,6 @@ func BuildQuerySQL(db *gorm.DB) { } db.Statement.AddClause(fromClause) - db.Statement.Joins = nil } else { db.Statement.AddClauseIfNotExists(clause.From{}) } @@ -272,38 +270,23 @@ func Preload(db *gorm.DB) { return } - preloadMap := parsePreloadMap(db.Statement.Schema, db.Statement.Preloads) - preloadNames := make([]string, 0, len(preloadMap)) - for key := range preloadMap { - preloadNames = append(preloadNames, key) + joins := make([]string, 0, len(db.Statement.Joins)) + for _, join := range db.Statement.Joins { + joins = append(joins, join.Name) } - sort.Strings(preloadNames) - preloadDB := db.Session(&gorm.Session{Context: db.Statement.Context, NewDB: true, SkipHooks: db.Statement.SkipHooks, Initialized: true}) - db.Statement.Settings.Range(func(k, v interface{}) bool { - preloadDB.Statement.Settings.Store(k, v) - return true - }) - - if err := preloadDB.Statement.Parse(db.Statement.Dest); err != nil { + tx := preloadDB(db, db.Statement.ReflectValue, db.Statement.Dest) + if tx.Error != nil { return } - preloadDB.Statement.ReflectValue = db.Statement.ReflectValue - preloadDB.Statement.Unscoped = db.Statement.Unscoped - for _, name := range preloadNames { - if relations := preloadDB.Statement.Schema.Relationships.EmbeddedRelations[name]; relations != nil { - db.AddError(preloadEmbedded(preloadDB.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks}), relations, db.Statement.Schema, preloadMap[name], db.Statement.Preloads[clause.Associations])) - } else if rel := preloadDB.Statement.Schema.Relationships.Relations[name]; rel != nil { - db.AddError(preload(preloadDB.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks}), rel, append(db.Statement.Preloads[name], db.Statement.Preloads[clause.Associations]...), preloadMap[name])) - } else { - db.AddError(fmt.Errorf("%s: %w for schema %s", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name)) - } - } + db.AddError(preloadEntryPoint(tx, joins, &tx.Statement.Schema.Relationships, db.Statement.Preloads, db.Statement.Preloads[clause.Associations])) } } func AfterQuery(db *gorm.DB) { + // clear the joins after query because preload need it + db.Statement.Joins = nil if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.AfterFind && db.RowsAffected > 0 { callMethod(db, func(value interface{}, tx *gorm.DB) bool { if i, ok := value.(AfterFindInterface); ok { diff --git a/tests/preload_test.go b/tests/preload_test.go index 3ff86492..26b08d7d 100644 --- a/tests/preload_test.go +++ b/tests/preload_test.go @@ -307,6 +307,63 @@ func TestNestedPreloadWithUnscoped(t *testing.T) { CheckUserUnscoped(t, *user6, user) } +func TestNestedPreloadWithNestedJoin(t *testing.T) { + type ( + Preload struct { + ID uint + Value string + NestedID uint + } + Join struct { + ID uint + Value string + NestedID uint + } + Nested struct { + ID uint + Preloads []*Preload + Join Join + ValueID uint + } + Value struct { + ID uint + Name string + Nested Nested + } + ) + + DB.Migrator().DropTable(&Preload{}, &Join{}, &Nested{}, &Value{}) + DB.Migrator().AutoMigrate(&Preload{}, &Join{}, &Nested{}, &Value{}) + + value := Value{ + Name: "value", + Nested: Nested{ + Preloads: []*Preload{ + {Value: "p1"}, {Value: "p2"}, + }, + Join: Join{Value: "j1"}, + }, + } + if err := DB.Create(&value).Error; err != nil { + t.Errorf("failed to create value, got err: %v", err) + } + + var find1 Value + err := DB.Joins("Nested").Joins("Nested.Join").Preload("Nested.Preloads").First(&find1).Error + if err != nil { + t.Errorf("failed to find value, got err: %v", err) + } + AssertEqual(t, find1, value) + + var find2 Value + // Joins will automatically add Nested queries. + err = DB.Joins("Nested.Join").Preload("Nested.Preloads").First(&find2).Error + if err != nil { + t.Errorf("failed to find value, got err: %v", err) + } + AssertEqual(t, find2, value) +} + func TestEmbedPreload(t *testing.T) { type Country struct { ID int `gorm:"primaryKey"` From 46816ad31dde63292233d94ebeb4fd188d299555 Mon Sep 17 00:00:00 2001 From: black-06 Date: Sun, 4 Feb 2024 15:49:19 +0800 Subject: [PATCH 208/231] refactor: distinguish between Unique and UniqueIndex (#6386) * refactor: distinguish between UniqueIndex and Index * add test * add ParseIndex test * modify unique to constraint * modify unique to constraint * fix MigrateColumnUnique * fix test * fix unit test * update test mod * add MigrateColumnUnique to Migrator interface * fix format lint * add comment * go mod tidy * revert: revert MigrateColumn * resolve conflicts --- migrator/migrator.go | 114 +++++++------- schema/check.go | 35 ----- schema/constraint.go | 66 ++++++++ schema/{check_test.go => constraint_test.go} | 31 +++- schema/field.go | 6 + schema/index.go | 6 +- schema/index_test.go | 149 ++++++++++++++----- schema/interfaces.go | 6 + schema/relationship.go | 26 ++++ 9 files changed, 312 insertions(+), 127 deletions(-) delete mode 100644 schema/check.go create mode 100644 schema/constraint.go rename schema/{check_test.go => constraint_test.go} (59%) diff --git a/migrator/migrator.go b/migrator/migrator.go index 64a5a4b5..d97fbf35 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -110,15 +110,20 @@ func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) { return } +func (m Migrator) GetQueryAndExecTx() (queryTx, execTx *gorm.DB) { + queryTx = m.DB.Session(&gorm.Session{}) + execTx = queryTx + if m.DB.DryRun { + queryTx.DryRun = false + execTx = m.DB.Session(&gorm.Session{Logger: &printSQLLogger{Interface: m.DB.Logger}}) + } + return queryTx, execTx +} + // AutoMigrate auto migrate values func (m Migrator) AutoMigrate(values ...interface{}) error { for _, value := range m.ReorderModels(values, true) { - queryTx := m.DB.Session(&gorm.Session{}) - execTx := queryTx - if m.DB.DryRun { - queryTx.DryRun = false - execTx = m.DB.Session(&gorm.Session{Logger: &printSQLLogger{Interface: m.DB.Logger}}) - } + queryTx, execTx := m.GetQueryAndExecTx() if !queryTx.Migrator().HasTable(value) { if err := execTx.Migrator().CreateTable(value); err != nil { return err @@ -268,7 +273,7 @@ func (m Migrator) CreateTable(values ...interface{}) error { } if constraint := rel.ParseConstraint(); constraint != nil { if constraint.Schema == stmt.Schema { - sql, vars := buildConstraint(constraint) + sql, vars := constraint.Build() createTableSQL += sql + "," values = append(values, vars...) } @@ -276,6 +281,11 @@ func (m Migrator) CreateTable(values ...interface{}) error { } } + for _, uni := range stmt.Schema.ParseUniqueConstraints() { + createTableSQL += "CONSTRAINT ? UNIQUE (?)," + values = append(values, clause.Column{Name: uni.Name}, clause.Expr{SQL: stmt.Quote(uni.Field.DBName)}) + } + for _, chk := range stmt.Schema.ParseCheckConstraints() { createTableSQL += "CONSTRAINT ? CHECK (?)," values = append(values, clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}) @@ -439,6 +449,10 @@ func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error // MigrateColumn migrate column func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnType gorm.ColumnType) error { + if field.IgnoreMigration { + return nil + } + // found, smart migrate fullDataType := strings.TrimSpace(strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL)) realDataType := strings.ToLower(columnType.DatabaseTypeName()) @@ -499,7 +513,7 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy } // check unique - if unique, ok := columnType.Unique(); ok && unique != field.Unique { + if unique, ok := columnType.Unique(); ok && unique != (field.Unique || field.UniqueIndex != "") { // not primary key if !field.PrimaryKey { alterColumn = true @@ -630,37 +644,36 @@ func (m Migrator) DropView(name string) error { return m.DB.Exec("DROP VIEW IF EXISTS ?", clause.Table{Name: name}).Error } -func buildConstraint(constraint *schema.Constraint) (sql string, results []interface{}) { - sql = "CONSTRAINT ? FOREIGN KEY ? REFERENCES ??" - if constraint.OnDelete != "" { - sql += " ON DELETE " + constraint.OnDelete +// GuessConstraintAndTable guess statement's constraint and it's table based on name +// +// Deprecated: use GuessConstraintInterfaceAndTable instead. +func (m Migrator) GuessConstraintAndTable(stmt *gorm.Statement, name string) (*schema.Constraint, *schema.CheckConstraint, string) { + constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name) + switch c := constraint.(type) { + case *schema.Constraint: + return c, nil, table + case *schema.CheckConstraint: + return nil, c, table + default: + return nil, nil, table } - - if constraint.OnUpdate != "" { - sql += " ON UPDATE " + constraint.OnUpdate - } - - var foreignKeys, references []interface{} - for _, field := range constraint.ForeignKeys { - foreignKeys = append(foreignKeys, clause.Column{Name: field.DBName}) - } - - for _, field := range constraint.References { - references = append(references, clause.Column{Name: field.DBName}) - } - results = append(results, clause.Table{Name: constraint.Name}, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references) - return } -// GuessConstraintAndTable guess statement's constraint and it's table based on name -func (m Migrator) GuessConstraintAndTable(stmt *gorm.Statement, name string) (_ *schema.Constraint, _ *schema.Check, table string) { +// GuessConstraintInterfaceAndTable guess statement's constraint and it's table based on name +// nolint:cyclop +func (m Migrator) GuessConstraintInterfaceAndTable(stmt *gorm.Statement, name string) (_ schema.ConstraintInterface, table string) { if stmt.Schema == nil { - return nil, nil, stmt.Table + return nil, stmt.Table } checkConstraints := stmt.Schema.ParseCheckConstraints() if chk, ok := checkConstraints[name]; ok { - return nil, &chk, stmt.Table + return &chk, stmt.Table + } + + uniqueConstraints := stmt.Schema.ParseUniqueConstraints() + if uni, ok := uniqueConstraints[name]; ok { + return &uni, stmt.Table } getTable := func(rel *schema.Relationship) string { @@ -675,7 +688,7 @@ func (m Migrator) GuessConstraintAndTable(stmt *gorm.Statement, name string) (_ for _, rel := range stmt.Schema.Relationships.Relations { if constraint := rel.ParseConstraint(); constraint != nil && constraint.Name == name { - return constraint, nil, getTable(rel) + return constraint, getTable(rel) } } @@ -683,40 +696,39 @@ func (m Migrator) GuessConstraintAndTable(stmt *gorm.Statement, name string) (_ for k := range checkConstraints { if checkConstraints[k].Field == field { v := checkConstraints[k] - return nil, &v, stmt.Table + return &v, stmt.Table + } + } + + for k := range uniqueConstraints { + if uniqueConstraints[k].Field == field { + v := uniqueConstraints[k] + return &v, stmt.Table } } for _, rel := range stmt.Schema.Relationships.Relations { if constraint := rel.ParseConstraint(); constraint != nil && rel.Field == field { - return constraint, nil, getTable(rel) + return constraint, getTable(rel) } } } - return nil, nil, stmt.Schema.Table + return nil, stmt.Schema.Table } // CreateConstraint create constraint func (m Migrator) CreateConstraint(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { - constraint, chk, table := m.GuessConstraintAndTable(stmt, name) - if chk != nil { - return m.DB.Exec( - "ALTER TABLE ? ADD CONSTRAINT ? CHECK (?)", - m.CurrentTable(stmt), clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}, - ).Error - } - + constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name) if constraint != nil { vars := []interface{}{clause.Table{Name: table}} if stmt.TableExpr != nil { vars[0] = stmt.TableExpr } - sql, values := buildConstraint(constraint) + sql, values := constraint.Build() return m.DB.Exec("ALTER TABLE ? ADD "+sql, append(vars, values...)...).Error } - return nil }) } @@ -724,11 +736,9 @@ func (m Migrator) CreateConstraint(value interface{}, name string) error { // DropConstraint drop constraint func (m Migrator) DropConstraint(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { - constraint, chk, table := m.GuessConstraintAndTable(stmt, name) + constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name) if constraint != nil { - name = constraint.Name - } else if chk != nil { - name = chk.Name + name = constraint.GetName() } return m.DB.Exec("ALTER TABLE ? DROP CONSTRAINT ?", clause.Table{Name: table}, clause.Column{Name: name}).Error }) @@ -739,11 +749,9 @@ func (m Migrator) HasConstraint(value interface{}, name string) bool { var count int64 m.RunWithValue(value, func(stmt *gorm.Statement) error { currentDatabase := m.DB.Migrator().CurrentDatabase() - constraint, chk, table := m.GuessConstraintAndTable(stmt, name) + constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name) if constraint != nil { - name = constraint.Name - } else if chk != nil { - name = chk.Name + name = constraint.GetName() } return m.DB.Raw( diff --git a/schema/check.go b/schema/check.go deleted file mode 100644 index 89e732d3..00000000 --- a/schema/check.go +++ /dev/null @@ -1,35 +0,0 @@ -package schema - -import ( - "regexp" - "strings" -) - -// reg match english letters and midline -var regEnLetterAndMidline = regexp.MustCompile("^[A-Za-z-_]+$") - -type Check struct { - Name string - Constraint string // length(phone) >= 10 - *Field -} - -// ParseCheckConstraints parse schema check constraints -func (schema *Schema) ParseCheckConstraints() map[string]Check { - checks := map[string]Check{} - for _, field := range schema.FieldsByDBName { - if chk := field.TagSettings["CHECK"]; chk != "" { - names := strings.Split(chk, ",") - if len(names) > 1 && regEnLetterAndMidline.MatchString(names[0]) { - checks[names[0]] = Check{Name: names[0], Constraint: strings.Join(names[1:], ","), Field: field} - } else { - if names[0] == "" { - chk = strings.Join(names[1:], ",") - } - name := schema.namer.CheckerName(schema.Table, field.DBName) - checks[name] = Check{Name: name, Constraint: chk, Field: field} - } - } - } - return checks -} diff --git a/schema/constraint.go b/schema/constraint.go new file mode 100644 index 00000000..5f6beb89 --- /dev/null +++ b/schema/constraint.go @@ -0,0 +1,66 @@ +package schema + +import ( + "regexp" + "strings" + + "gorm.io/gorm/clause" +) + +// reg match english letters and midline +var regEnLetterAndMidline = regexp.MustCompile("^[A-Za-z-_]+$") + +type CheckConstraint struct { + Name string + Constraint string // length(phone) >= 10 + *Field +} + +func (chk *CheckConstraint) GetName() string { return chk.Name } + +func (chk *CheckConstraint) Build() (sql string, vars []interface{}) { + return "CONSTRAINT ? CHECK (?)", []interface{}{clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}} +} + +// ParseCheckConstraints parse schema check constraints +func (schema *Schema) ParseCheckConstraints() map[string]CheckConstraint { + checks := map[string]CheckConstraint{} + for _, field := range schema.FieldsByDBName { + if chk := field.TagSettings["CHECK"]; chk != "" { + names := strings.Split(chk, ",") + if len(names) > 1 && regEnLetterAndMidline.MatchString(names[0]) { + checks[names[0]] = CheckConstraint{Name: names[0], Constraint: strings.Join(names[1:], ","), Field: field} + } else { + if names[0] == "" { + chk = strings.Join(names[1:], ",") + } + name := schema.namer.CheckerName(schema.Table, field.DBName) + checks[name] = CheckConstraint{Name: name, Constraint: chk, Field: field} + } + } + } + return checks +} + +type UniqueConstraint struct { + Name string + Field *Field +} + +func (uni *UniqueConstraint) GetName() string { return uni.Name } + +func (uni *UniqueConstraint) Build() (sql string, vars []interface{}) { + return "CONSTRAINT ? UNIQUE (?)", []interface{}{clause.Column{Name: uni.Name}, clause.Column{Name: uni.Field.DBName}} +} + +// ParseUniqueConstraints parse schema unique constraints +func (schema *Schema) ParseUniqueConstraints() map[string]UniqueConstraint { + uniques := make(map[string]UniqueConstraint) + for _, field := range schema.Fields { + if field.Unique { + name := schema.namer.UniqueName(schema.Table, field.DBName) + uniques[name] = UniqueConstraint{Name: name, Field: field} + } + } + return uniques +} diff --git a/schema/check_test.go b/schema/constraint_test.go similarity index 59% rename from schema/check_test.go rename to schema/constraint_test.go index eda043b7..6fcb1b85 100644 --- a/schema/check_test.go +++ b/schema/constraint_test.go @@ -6,6 +6,7 @@ import ( "testing" "gorm.io/gorm/schema" + "gorm.io/gorm/utils/tests" ) type UserCheck struct { @@ -20,7 +21,7 @@ func TestParseCheck(t *testing.T) { t.Fatalf("failed to parse user check, got error %v", err) } - results := map[string]schema.Check{ + results := map[string]schema.CheckConstraint{ "name_checker": { Name: "name_checker", Constraint: "name <> 'jinzhu'", @@ -53,3 +54,31 @@ func TestParseCheck(t *testing.T) { } } } + +func TestParseUniqueConstraints(t *testing.T) { + type UserUnique struct { + Name1 string `gorm:"unique"` + Name2 string `gorm:"uniqueIndex"` + } + + user, err := schema.Parse(&UserUnique{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("failed to parse user unique, got error %v", err) + } + constraints := user.ParseUniqueConstraints() + + results := map[string]schema.UniqueConstraint{ + "uni_user_uniques_name1": { + Name: "uni_user_uniques_name1", + Field: &schema.Field{Name: "Name1", Unique: true}, + }, + } + for k, result := range results { + v, ok := constraints[k] + if !ok { + t.Errorf("Failed to found unique constraint %v from parsed constraints %+v", k, constraints) + } + tests.AssertObjEqual(t, result, v, "Name") + tests.AssertObjEqual(t, result.Field, v.Field, "Name", "Unique", "UniqueIndex") + } +} diff --git a/schema/field.go b/schema/field.go index 657e0a4b..91e4c0ab 100644 --- a/schema/field.go +++ b/schema/field.go @@ -89,6 +89,12 @@ type Field struct { Set func(context.Context, reflect.Value, interface{}) error Serializer SerializerInterface NewValuePool FieldNewValuePool + + // In some db (e.g. MySQL), Unique and UniqueIndex are indistinguishable. + // When a column has a (not Mul) UniqueIndex, Migrator always reports its gorm.ColumnType is Unique. + // It causes field unnecessarily migration. + // Therefore, we need to record the UniqueIndex on this column (exclude Mul UniqueIndex) for MigrateColumnUnique. + UniqueIndex string } func (field *Field) BindName() string { diff --git a/schema/index.go b/schema/index.go index f5ac5dd2..f4f36751 100644 --- a/schema/index.go +++ b/schema/index.go @@ -13,8 +13,8 @@ type Index struct { Type string // btree, hash, gist, spgist, gin, and brin Where string Comment string - Option string // WITH PARSER parser_name - Fields []IndexOption + Option string // WITH PARSER parser_name + Fields []IndexOption // Note: IndexOption's Field maybe the same } type IndexOption struct { @@ -67,7 +67,7 @@ func (schema *Schema) ParseIndexes() map[string]Index { } for _, index := range indexes { if index.Class == "UNIQUE" && len(index.Fields) == 1 { - index.Fields[0].Field.Unique = true + index.Fields[0].Field.UniqueIndex = index.Name } } return indexes diff --git a/schema/index_test.go b/schema/index_test.go index 890327de..2f1e36af 100644 --- a/schema/index_test.go +++ b/schema/index_test.go @@ -1,11 +1,11 @@ package schema_test import ( - "reflect" "sync" "testing" "gorm.io/gorm/schema" + "gorm.io/gorm/utils/tests" ) type UserIndex struct { @@ -19,6 +19,7 @@ type UserIndex struct { OID int64 `gorm:"index:idx_id;index:idx_oid,unique"` MemberNumber string `gorm:"index:idx_id,priority:1"` Name7 string `gorm:"index:type"` + Name8 string `gorm:"index:,length:10;index:,collate:utf8"` // Composite Index: Flattened structure. Data0A string `gorm:"index:,composite:comp_id0"` @@ -65,7 +66,7 @@ func TestParseIndex(t *testing.T) { "idx_name": { Name: "idx_name", Class: "UNIQUE", - Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name2", Unique: true}}}, + Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name2", UniqueIndex: "idx_name"}}}, }, "idx_user_indices_name3": { Name: "idx_user_indices_name3", @@ -81,7 +82,7 @@ func TestParseIndex(t *testing.T) { "idx_user_indices_name4": { Name: "idx_user_indices_name4", Class: "UNIQUE", - Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name4", Unique: true}}}, + Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name4", UniqueIndex: "idx_user_indices_name4"}}}, }, "idx_user_indices_name5": { Name: "idx_user_indices_name5", @@ -102,18 +103,27 @@ func TestParseIndex(t *testing.T) { }, "idx_id": { Name: "idx_id", - Fields: []schema.IndexOption{{Field: &schema.Field{Name: "MemberNumber"}}, {Field: &schema.Field{Name: "OID", Unique: true}}}, + Fields: []schema.IndexOption{{Field: &schema.Field{Name: "MemberNumber"}}, {Field: &schema.Field{Name: "OID", UniqueIndex: "idx_oid"}}}, }, "idx_oid": { Name: "idx_oid", Class: "UNIQUE", - Fields: []schema.IndexOption{{Field: &schema.Field{Name: "OID", Unique: true}}}, + Fields: []schema.IndexOption{{Field: &schema.Field{Name: "OID", UniqueIndex: "idx_oid"}}}, }, "type": { Name: "type", Type: "", Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name7"}}}, }, + "idx_user_indices_name8": { + Name: "idx_user_indices_name8", + Type: "", + Fields: []schema.IndexOption{ + {Field: &schema.Field{Name: "Name8"}, Length: 10}, + // Note: Duplicate Columns + {Field: &schema.Field{Name: "Name8"}, Collate: "utf8"}, + }, + }, "idx_user_indices_comp_id0": { Name: "idx_user_indices_comp_id0", Type: "", @@ -146,40 +156,109 @@ func TestParseIndex(t *testing.T) { }, } - indices := user.ParseIndexes() + CheckIndices(t, results, user.ParseIndexes()) +} - for k, result := range results { - v, ok := indices[k] - if !ok { - t.Fatalf("Failed to found index %v from parsed indices %+v", k, indices) - } +func TestParseIndexWithUniqueIndexAndUnique(t *testing.T) { + type IndexTest struct { + FieldA string `gorm:"unique;index"` // unique and index + FieldB string `gorm:"unique"` // unique - for _, name := range []string{"Name", "Class", "Type", "Where", "Comment", "Option"} { - if reflect.ValueOf(result).FieldByName(name).Interface() != reflect.ValueOf(v).FieldByName(name).Interface() { - t.Errorf( - "index %v %v should equal, expects %v, got %v", - k, name, reflect.ValueOf(result).FieldByName(name).Interface(), reflect.ValueOf(v).FieldByName(name).Interface(), - ) - } - } + FieldC string `gorm:"index:,unique"` // uniqueIndex + FieldD string `gorm:"uniqueIndex;index"` // uniqueIndex and index - for idx, ef := range result.Fields { - rf := v.Fields[idx] - if rf.Field.Name != ef.Field.Name { - t.Fatalf("index field should equal, expects %v, got %v", rf.Field.Name, ef.Field.Name) - } - if rf.Field.Unique != ef.Field.Unique { - t.Fatalf("index field '%s' should equal, expects %v, got %v", rf.Field.Name, rf.Field.Unique, ef.Field.Unique) - } + FieldE1 string `gorm:"uniqueIndex:uniq_field_e1_e2"` // mul uniqueIndex + FieldE2 string `gorm:"uniqueIndex:uniq_field_e1_e2"` - for _, name := range []string{"Expression", "Sort", "Collate", "Length"} { - if reflect.ValueOf(ef).FieldByName(name).Interface() != reflect.ValueOf(rf).FieldByName(name).Interface() { - t.Errorf( - "index %v field #%v's %v should equal, expects %v, got %v", k, idx+1, name, - reflect.ValueOf(ef).FieldByName(name).Interface(), reflect.ValueOf(rf).FieldByName(name).Interface(), - ) - } + FieldF1 string `gorm:"uniqueIndex:uniq_field_f1_f2;index"` // mul uniqueIndex and index + FieldF2 string `gorm:"uniqueIndex:uniq_field_f1_f2;"` + + FieldG string `gorm:"unique;uniqueIndex"` // unique and uniqueIndex + + FieldH1 string `gorm:"unique;uniqueIndex:uniq_field_h1_h2"` // unique and mul uniqueIndex + FieldH2 string `gorm:"uniqueIndex:uniq_field_h1_h2"` // unique and mul uniqueIndex + } + indexSchema, err := schema.Parse(&IndexTest{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("failed to parse user index, got error %v", err) + } + indices := indexSchema.ParseIndexes() + CheckIndices(t, map[string]schema.Index{ + "idx_index_tests_field_a": { + Name: "idx_index_tests_field_a", + Fields: []schema.IndexOption{{Field: &schema.Field{Name: "FieldA", Unique: true}}}, + }, + "idx_index_tests_field_c": { + Name: "idx_index_tests_field_c", + Class: "UNIQUE", + Fields: []schema.IndexOption{{Field: &schema.Field{Name: "FieldC", UniqueIndex: "idx_index_tests_field_c"}}}, + }, + "idx_index_tests_field_d": { + Name: "idx_index_tests_field_d", + Class: "UNIQUE", + Fields: []schema.IndexOption{ + {Field: &schema.Field{Name: "FieldD"}}, + // Note: Duplicate Columns + {Field: &schema.Field{Name: "FieldD"}}, + }, + }, + "uniq_field_e1_e2": { + Name: "uniq_field_e1_e2", + Class: "UNIQUE", + Fields: []schema.IndexOption{ + {Field: &schema.Field{Name: "FieldE1"}}, + {Field: &schema.Field{Name: "FieldE2"}}, + }, + }, + "idx_index_tests_field_f1": { + Name: "idx_index_tests_field_f1", + Fields: []schema.IndexOption{{Field: &schema.Field{Name: "FieldF1"}}}, + }, + "uniq_field_f1_f2": { + Name: "uniq_field_f1_f2", + Class: "UNIQUE", + Fields: []schema.IndexOption{ + {Field: &schema.Field{Name: "FieldF1"}}, + {Field: &schema.Field{Name: "FieldF2"}}, + }, + }, + "idx_index_tests_field_g": { + Name: "idx_index_tests_field_g", + Class: "UNIQUE", + Fields: []schema.IndexOption{{Field: &schema.Field{Name: "FieldG", Unique: true, UniqueIndex: "idx_index_tests_field_g"}}}, + }, + "uniq_field_h1_h2": { + Name: "uniq_field_h1_h2", + Class: "UNIQUE", + Fields: []schema.IndexOption{ + {Field: &schema.Field{Name: "FieldH1", Unique: true}}, + {Field: &schema.Field{Name: "FieldH2"}}, + }, + }, + }, indices) +} + +func CheckIndices(t *testing.T, expected, actual map[string]schema.Index) { + for k, ei := range expected { + t.Run(k, func(t *testing.T) { + ai, ok := actual[k] + if !ok { + t.Errorf("expected index %q but actual missing", k) + return } - } + tests.AssertObjEqual(t, ai, ei, "Name", "Class", "Type", "Where", "Comment", "Option") + if len(ei.Fields) != len(ai.Fields) { + t.Errorf("expected index %q field length is %d but actual %d", k, len(ei.Fields), len(ai.Fields)) + return + } + for i, ef := range ei.Fields { + af := ai.Fields[i] + tests.AssertObjEqual(t, af, ef, "Name", "Unique", "UniqueIndex", "Expression", "Sort", "Collate", "Length") + } + }) + delete(actual, k) + } + for k := range actual { + t.Errorf("unexpected index %q", k) } } diff --git a/schema/interfaces.go b/schema/interfaces.go index a75a33c0..306d4f4e 100644 --- a/schema/interfaces.go +++ b/schema/interfaces.go @@ -4,6 +4,12 @@ import ( "gorm.io/gorm/clause" ) +// ConstraintInterface database constraint interface +type ConstraintInterface interface { + GetName() string + Build() (sql string, vars []interface{}) +} + // GormDataTypeInterface gorm data type interface type GormDataTypeInterface interface { GormDataType() string diff --git a/schema/relationship.go b/schema/relationship.go index 57167859..2e94fc2c 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -605,6 +605,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu } } +// Constraint is ForeignKey Constraint type Constraint struct { Name string Field *Field @@ -616,6 +617,31 @@ type Constraint struct { OnUpdate string } +func (constraint *Constraint) GetName() string { return constraint.Name } + +func (constraint *Constraint) Build() (sql string, vars []interface{}) { + sql = "CONSTRAINT ? FOREIGN KEY ? REFERENCES ??" + if constraint.OnDelete != "" { + sql += " ON DELETE " + constraint.OnDelete + } + + if constraint.OnUpdate != "" { + sql += " ON UPDATE " + constraint.OnUpdate + } + + foreignKeys := make([]interface{}, 0, len(constraint.ForeignKeys)) + for _, field := range constraint.ForeignKeys { + foreignKeys = append(foreignKeys, clause.Column{Name: field.DBName}) + } + + references := make([]interface{}, 0, len(constraint.References)) + for _, field := range constraint.References { + references = append(references, clause.Column{Name: field.DBName}) + } + vars = append(vars, clause.Table{Name: constraint.Name}, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references) + return +} + func (rel *Relationship) ParseConstraint() *Constraint { str := rel.Field.TagSettings["CONSTRAINT"] if str == "-" { From 9514d5f9e677e0c94d3d4e48d4f08c649f092d9a Mon Sep 17 00:00:00 2001 From: jasonchuan Date: Tue, 6 Feb 2024 10:54:40 +0800 Subject: [PATCH 209/231] let limit and offset use bind parameter (#6806) * let limit and offset use bind parameter * format * format limt_test * try again * fix test case fro connpool * adding driverName for postgres ,if not to do so, the stmt vars will be added a wrong one called pgx.QueryExecModeSimpleProtocol , causing the SQL with limit problem need 1 parameter ,but given two. * delete trunk files * restore the test_test.go * restore test_test.go * driver/postgres->v1.5.5 * change postgres version rollback to 1.5.4 --------- Co-authored-by: chenchuan Co-authored-by: jason_chuan --- clause/limit.go | 6 ++---- clause/limit_test.go | 30 ++++++++++++++++++++---------- tests/connpool_test.go | 10 +++++----- 3 files changed, 27 insertions(+), 19 deletions(-) diff --git a/clause/limit.go b/clause/limit.go index abda0055..3edde434 100644 --- a/clause/limit.go +++ b/clause/limit.go @@ -1,7 +1,5 @@ package clause -import "strconv" - // Limit limit clause type Limit struct { Limit *int @@ -17,14 +15,14 @@ func (limit Limit) Name() string { func (limit Limit) Build(builder Builder) { if limit.Limit != nil && *limit.Limit >= 0 { builder.WriteString("LIMIT ") - builder.WriteString(strconv.Itoa(*limit.Limit)) + builder.AddVar(builder, *limit.Limit) } if limit.Offset > 0 { if limit.Limit != nil && *limit.Limit >= 0 { builder.WriteByte(' ') } builder.WriteString("OFFSET ") - builder.WriteString(strconv.Itoa(limit.Offset)) + builder.AddVar(builder, limit.Offset) } } diff --git a/clause/limit_test.go b/clause/limit_test.go index a9fd4e24..96a7e7e6 100644 --- a/clause/limit_test.go +++ b/clause/limit_test.go @@ -22,43 +22,53 @@ func TestLimit(t *testing.T) { Limit: &limit10, Offset: 20, }}, - "SELECT * FROM `users` LIMIT 10 OFFSET 20", nil, + "SELECT * FROM `users` LIMIT ? OFFSET ?", + []interface{}{limit10, 20}, }, { []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit0}}, - "SELECT * FROM `users` LIMIT 0", nil, + "SELECT * FROM `users` LIMIT ?", + []interface{}{limit0}, }, { []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit0}, clause.Limit{Offset: 0}}, - "SELECT * FROM `users` LIMIT 0", nil, + "SELECT * FROM `users` LIMIT ?", + []interface{}{limit0}, }, { []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}}, - "SELECT * FROM `users` OFFSET 20", nil, + "SELECT * FROM `users` OFFSET ?", + []interface{}{20}, }, { []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}, clause.Limit{Offset: 30}}, - "SELECT * FROM `users` OFFSET 30", nil, + "SELECT * FROM `users` OFFSET ?", + []interface{}{30}, }, { []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}, clause.Limit{Limit: &limit10}}, - "SELECT * FROM `users` LIMIT 10 OFFSET 20", nil, + "SELECT * FROM `users` LIMIT ? OFFSET ?", + []interface{}{limit10, 20}, }, { []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}}, - "SELECT * FROM `users` LIMIT 10 OFFSET 30", nil, + "SELECT * FROM `users` LIMIT ? OFFSET ?", + []interface{}{limit10, 30}, }, { []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Offset: -10}}, - "SELECT * FROM `users` LIMIT 10", nil, + "SELECT * FROM `users` LIMIT ?", + []interface{}{limit10}, }, { []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: &limitNeg10}}, - "SELECT * FROM `users` OFFSET 30", nil, + "SELECT * FROM `users` OFFSET ?", + []interface{}{30}, }, { []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: &limit50}}, - "SELECT * FROM `users` LIMIT 50 OFFSET 30", nil, + "SELECT * FROM `users` LIMIT ? OFFSET ?", + []interface{}{limit50, 30}, }, } diff --git a/tests/connpool_test.go b/tests/connpool_test.go index e0e1c771..21a2bad0 100644 --- a/tests/connpool_test.go +++ b/tests/connpool_test.go @@ -102,13 +102,13 @@ func TestConnPoolWrapper(t *testing.T) { expect: []string{ "SELECT VERSION()", "INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)", - "SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1", + "SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT ?", "INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)", - "SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1", - "SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1", + "SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT ?", + "SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT ?", "INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)", - "SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1", - "SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1", + "SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT ?", + "SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT ?", }, } From 8fb9a317756bc07dcaaa82d43613ddcf9295c1ad Mon Sep 17 00:00:00 2001 From: black-06 Date: Tue, 6 Feb 2024 19:48:40 +0800 Subject: [PATCH 210/231] refactor: part 2 of distinguish between Unique and UniqueIndex (#6822) --- migrator/migrator.go | 22 ++-- tests/go.mod | 22 ++-- tests/migrate_test.go | 227 +++++++++++++++++++++++++++++++++++++++++- 3 files changed, 247 insertions(+), 24 deletions(-) diff --git a/migrator/migrator.go b/migrator/migrator.go index d97fbf35..ae82f769 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -93,10 +93,6 @@ func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) { expr.SQL += " NOT NULL" } - if field.Unique { - expr.SQL += " UNIQUE" - } - if field.HasDefaultValue && (field.DefaultValueInterface != nil || field.DefaultValue != "") { if field.DefaultValueInterface != nil { defaultStmt := &gorm.Statement{Vars: []interface{}{field.DefaultValueInterface}} @@ -512,14 +508,6 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy } } - // check unique - if unique, ok := columnType.Unique(); ok && unique != (field.Unique || field.UniqueIndex != "") { - // not primary key - if !field.PrimaryKey { - alterColumn = true - } - } - // check default value if !field.PrimaryKey { currentDefaultNotNull := field.HasDefaultValue && (field.DefaultValueInterface != nil || !strings.EqualFold(field.DefaultValue, "NULL")) @@ -548,8 +536,14 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy } } - if alterColumn && !field.IgnoreMigration { - return m.DB.Migrator().AlterColumn(value, field.DBName) + if alterColumn { + if err := m.DB.Migrator().AlterColumn(value, field.DBName); err != nil { + return err + } + } + + if err := m.DB.Migrator().MigrateColumnUnique(value, field, columnType); err != nil { + return err } return nil diff --git a/tests/go.mod b/tests/go.mod index 07fedc45..136667b7 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -3,28 +3,34 @@ module gorm.io/gorm/tests go 1.18 require ( - github.com/google/uuid v1.5.0 + github.com/google/uuid v1.6.0 github.com/jinzhu/now v1.1.5 github.com/lib/pq v1.10.9 - gorm.io/driver/mysql v1.5.2 - gorm.io/driver/postgres v1.5.4 - gorm.io/driver/sqlite v1.5.4 - gorm.io/driver/sqlserver v1.5.2 - gorm.io/gorm v1.25.5 + github.com/stretchr/testify v1.8.4 + gorm.io/driver/mysql v1.5.4 + gorm.io/driver/postgres v1.5.6 + gorm.io/driver/sqlite v1.5.5 + gorm.io/driver/sqlserver v1.5.3 + gorm.io/gorm v1.25.7-0.20240204074919-46816ad31dde ) require ( + github.com/davecgh/go-spew v1.1.1 // indirect github.com/go-sql-driver/mysql v1.7.1 // indirect github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect github.com/golang-sql/sqlexp v0.1.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 // indirect - github.com/jackc/pgx/v5 v5.5.1 // indirect + github.com/jackc/pgx/v5 v5.5.3 // indirect github.com/jinzhu/inflection v1.0.0 // indirect - github.com/mattn/go-sqlite3 v1.14.19 // indirect + github.com/kr/text v0.2.0 // indirect + github.com/mattn/go-sqlite3 v1.14.22 // indirect github.com/microsoft/go-mssqldb v1.6.0 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/rogpeppe/go-internal v1.12.0 // indirect golang.org/x/crypto v0.18.0 // indirect golang.org/x/text v0.14.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) replace gorm.io/gorm => ../ diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 28fa315b..837d92c1 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -2,6 +2,7 @@ package tests_test import ( "context" + "database/sql" "fmt" "math/rand" "os" @@ -10,10 +11,15 @@ import ( "testing" "time" + "github.com/stretchr/testify/assert" "gorm.io/driver/postgres" + "gorm.io/gorm" + "gorm.io/gorm/clause" "gorm.io/gorm/logger" + "gorm.io/gorm/migrator" "gorm.io/gorm/schema" + "gorm.io/gorm/utils" . "gorm.io/gorm/utils/tests" ) @@ -984,7 +990,8 @@ func TestCurrentTimestamp(t *testing.T) { if err != nil { t.Fatalf("AutoMigrate err:%v", err) } - AssertEqual(t, true, DB.Migrator().HasIndex(&CurrentTimestampTest{}, "time_at")) + AssertEqual(t, true, DB.Migrator().HasConstraint(&CurrentTimestampTest{}, "uni_current_timestamp_tests_time_at")) + AssertEqual(t, false, DB.Migrator().HasIndex(&CurrentTimestampTest{}, "time_at")) AssertEqual(t, false, DB.Migrator().HasIndex(&CurrentTimestampTest{}, "time_at_2")) } @@ -1046,7 +1053,8 @@ func TestUniqueColumn(t *testing.T) { } // not trigger alert column - AssertEqual(t, true, DB.Migrator().HasIndex(&UniqueTest{}, "name")) + AssertEqual(t, true, DB.Migrator().HasConstraint(&UniqueTest{}, "uni_unique_tests_name")) + AssertEqual(t, false, DB.Migrator().HasIndex(&UniqueTest{}, "name")) AssertEqual(t, false, DB.Migrator().HasIndex(&UniqueTest{}, "name_1")) AssertEqual(t, false, DB.Migrator().HasIndex(&UniqueTest{}, "name_2")) @@ -1712,3 +1720,218 @@ func TestTableType(t *testing.T) { t.Fatalf("expected comment %s got %s", tblComment, comment) } } + +func TestMigrateWithUniqueIndexAndUnique(t *testing.T) { + const table = "unique_struct" + + checkField := func(model interface{}, fieldName string, unique bool, uniqueIndex string) { + stmt := &gorm.Statement{DB: DB} + err := stmt.Parse(model) + if err != nil { + t.Fatalf("%v: failed to parse schema, got error: %v", utils.FileWithLineNum(), err) + } + _ = stmt.Schema.ParseIndexes() + field := stmt.Schema.LookUpField(fieldName) + if field == nil { + t.Fatalf("%v: failed to find column %q", utils.FileWithLineNum(), fieldName) + } + if field.Unique != unique { + t.Fatalf("%v: %q column %q unique should be %v but got %v", utils.FileWithLineNum(), stmt.Schema.Table, fieldName, unique, field.Unique) + } + if field.UniqueIndex != uniqueIndex { + t.Fatalf("%v: %q column %q uniqueIndex should be %v but got %v", utils.FileWithLineNum(), stmt.Schema, fieldName, uniqueIndex, field.UniqueIndex) + } + } + + type ( // not unique + UniqueStruct1 struct { + Name string `gorm:"size:10"` + } + UniqueStruct2 struct { + Name string `gorm:"size:20"` + } + ) + checkField(&UniqueStruct1{}, "name", false, "") + checkField(&UniqueStruct2{}, "name", false, "") + + type ( // unique + UniqueStruct3 struct { + Name string `gorm:"size:30;unique"` + } + UniqueStruct4 struct { + Name string `gorm:"size:40;unique"` + } + ) + checkField(&UniqueStruct3{}, "name", true, "") + checkField(&UniqueStruct4{}, "name", true, "") + + type ( // uniqueIndex + UniqueStruct5 struct { + Name string `gorm:"size:50;uniqueIndex"` + } + UniqueStruct6 struct { + Name string `gorm:"size:60;uniqueIndex"` + } + UniqueStruct7 struct { + Name string `gorm:"size:70;uniqueIndex:idx_us6_all_names"` + NickName string `gorm:"size:70;uniqueIndex:idx_us6_all_names"` + } + ) + checkField(&UniqueStruct5{}, "name", false, "idx_unique_struct5_name") + checkField(&UniqueStruct6{}, "name", false, "idx_unique_struct6_name") + + checkField(&UniqueStruct7{}, "name", false, "") + checkField(&UniqueStruct7{}, "nick_name", false, "") + checkField(&UniqueStruct7{}, "nick_name", false, "") + + type UniqueStruct8 struct { // unique and uniqueIndex + Name string `gorm:"size:60;unique;index:my_us8_index,unique;"` + } + checkField(&UniqueStruct8{}, "name", true, "my_us8_index") + + type TestCase struct { + name string + from, to interface{} + checkFunc func(t *testing.T) + } + + checkColumnType := func(t *testing.T, fieldName string, unique bool) { + columnTypes, err := DB.Migrator().ColumnTypes(table) + if err != nil { + t.Fatalf("%v: failed to get column types, got error: %v", utils.FileWithLineNum(), err) + } + var found gorm.ColumnType + for _, columnType := range columnTypes { + if columnType.Name() == fieldName { + found = columnType + } + } + if found == nil { + t.Fatalf("%v: failed to find column type %q", utils.FileWithLineNum(), fieldName) + } + if actualUnique, ok := found.Unique(); !ok || actualUnique != unique { + t.Fatalf("%v: column %q unique should be %v but got %v", utils.FileWithLineNum(), fieldName, unique, actualUnique) + } + } + + checkIndex := func(t *testing.T, expected []gorm.Index) { + indexes, err := DB.Migrator().GetIndexes(table) + if err != nil { + t.Fatalf("%v: failed to get indexes, got error: %v", utils.FileWithLineNum(), err) + } + assert.ElementsMatch(t, expected, indexes) + } + + uniqueIndex := &migrator.Index{TableName: table, NameValue: DB.Config.NamingStrategy.IndexName(table, "name"), ColumnList: []string{"name"}, PrimaryKeyValue: sql.NullBool{Bool: false, Valid: true}, UniqueValue: sql.NullBool{Bool: true, Valid: true}} + myIndex := &migrator.Index{TableName: table, NameValue: "my_us8_index", ColumnList: []string{"name"}, PrimaryKeyValue: sql.NullBool{Bool: false, Valid: true}, UniqueValue: sql.NullBool{Bool: true, Valid: true}} + mulIndex := &migrator.Index{TableName: table, NameValue: "idx_us6_all_names", ColumnList: []string{"name", "nick_name"}, PrimaryKeyValue: sql.NullBool{Bool: false, Valid: true}, UniqueValue: sql.NullBool{Bool: true, Valid: true}} + + var checkNotUnique, checkUnique, checkUniqueIndex, checkMyIndex, checkMulIndex func(t *testing.T) + // UniqueAffectedByUniqueIndex is true + if DB.Dialector.Name() == "mysql" { + uniqueConstraintIndex := &migrator.Index{TableName: table, NameValue: DB.Config.NamingStrategy.UniqueName(table, "name"), ColumnList: []string{"name"}, PrimaryKeyValue: sql.NullBool{Bool: false, Valid: true}, UniqueValue: sql.NullBool{Bool: true, Valid: true}} + checkNotUnique = func(t *testing.T) { + checkColumnType(t, "name", false) + checkIndex(t, nil) + } + checkUnique = func(t *testing.T) { + checkColumnType(t, "name", true) + checkIndex(t, []gorm.Index{uniqueConstraintIndex}) + } + checkUniqueIndex = func(t *testing.T) { + checkColumnType(t, "name", true) + checkIndex(t, []gorm.Index{uniqueIndex}) + } + checkMyIndex = func(t *testing.T) { + checkColumnType(t, "name", true) + checkIndex(t, []gorm.Index{uniqueConstraintIndex, myIndex}) + } + checkMulIndex = func(t *testing.T) { + checkColumnType(t, "name", false) + checkColumnType(t, "nick_name", false) + checkIndex(t, []gorm.Index{mulIndex}) + } + } else { + checkNotUnique = func(t *testing.T) { checkColumnType(t, "name", false) } + checkUnique = func(t *testing.T) { checkColumnType(t, "name", true) } + checkUniqueIndex = func(t *testing.T) { + checkColumnType(t, "name", false) + checkIndex(t, []gorm.Index{uniqueIndex}) + } + checkMyIndex = func(t *testing.T) { + checkColumnType(t, "name", true) + if !DB.Migrator().HasIndex(table, myIndex.Name()) { + t.Errorf("%v: should has index %s but not", utils.FileWithLineNum(), myIndex.Name()) + } + } + checkMulIndex = func(t *testing.T) { + checkColumnType(t, "name", false) + checkColumnType(t, "nick_name", false) + if !DB.Migrator().HasIndex(table, mulIndex.Name()) { + t.Errorf("%v: should has index %s but not", utils.FileWithLineNum(), mulIndex.Name()) + } + } + } + + tests := []TestCase{ + {name: "notUnique to notUnique", from: &UniqueStruct1{}, to: &UniqueStruct2{}, checkFunc: checkNotUnique}, + {name: "notUnique to unique", from: &UniqueStruct1{}, to: &UniqueStruct3{}, checkFunc: checkUnique}, + {name: "notUnique to uniqueIndex", from: &UniqueStruct1{}, to: &UniqueStruct5{}, checkFunc: checkUniqueIndex}, + {name: "notUnique to uniqueAndUniqueIndex", from: &UniqueStruct1{}, to: &UniqueStruct8{}, checkFunc: checkMyIndex}, + {name: "unique to unique", from: &UniqueStruct3{}, to: &UniqueStruct4{}, checkFunc: checkUnique}, + {name: "unique to uniqueIndex", from: &UniqueStruct3{}, to: &UniqueStruct5{}, checkFunc: checkUniqueIndex}, + {name: "unique to uniqueAndUniqueIndex", from: &UniqueStruct3{}, to: &UniqueStruct8{}, checkFunc: checkMyIndex}, + {name: "uniqueIndex to uniqueIndex", from: &UniqueStruct5{}, to: &UniqueStruct6{}, checkFunc: checkUniqueIndex}, + {name: "uniqueIndex to uniqueAndUniqueIndex", from: &UniqueStruct5{}, to: &UniqueStruct8{}, checkFunc: checkMyIndex}, + {name: "uniqueIndex to multi uniqueIndex", from: &UniqueStruct5{}, to: &UniqueStruct7{}, checkFunc: checkMulIndex}, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if err := DB.Migrator().DropTable(table); err != nil { + t.Fatalf("failed to drop table, got error: %v", err) + } + if err := DB.Table(table).AutoMigrate(test.from); err != nil { + t.Fatalf("failed to migrate table, got error: %v", err) + } + if err := DB.Table(table).AutoMigrate(test.to); err != nil { + t.Fatalf("failed to migrate table, got error: %v", err) + } + test.checkFunc(t) + }) + } + + if DB.Dialector.Name() != "sqlserver" { + // In SQLServer, If an index or constraint depends on the column, + // this column will not be able to run ALTER + // see https://stackoverflow.com/questions/19460912/the-object-df-is-dependent-on-column-changing-int-to-double/19461205#19461205 + // may we need to create another PR to fix it, see https://github.com/go-gorm/sqlserver/pull/106 + tests = []TestCase{ + {name: "unique to notUnique", from: &UniqueStruct3{}, to: &UniqueStruct1{}, checkFunc: checkNotUnique}, + {name: "uniqueIndex to notUnique", from: &UniqueStruct5{}, to: &UniqueStruct2{}, checkFunc: checkNotUnique}, + {name: "uniqueIndex to unique", from: &UniqueStruct5{}, to: &UniqueStruct3{}, checkFunc: checkUnique}, + } + } + + if DB.Dialector.Name() == "mysql" { + compatibilityTests := []TestCase{ + {name: "oldUnique to notUnique", to: UniqueStruct1{}, checkFunc: checkNotUnique}, + {name: "oldUnique to unique", to: UniqueStruct3{}, checkFunc: checkUnique}, + {name: "oldUnique to uniqueIndex", to: UniqueStruct5{}, checkFunc: checkUniqueIndex}, + {name: "oldUnique to uniqueAndUniqueIndex", to: UniqueStruct8{}, checkFunc: checkMyIndex}, + } + for _, test := range compatibilityTests { + t.Run(test.name, func(t *testing.T) { + if err := DB.Migrator().DropTable(table); err != nil { + t.Fatalf("failed to drop table, got error: %v", err) + } + if err := DB.Exec("CREATE TABLE ? (`name` varchar(10) UNIQUE)", clause.Table{Name: table}).Error; err != nil { + t.Fatalf("failed to create table, got error: %v", err) + } + if err := DB.Table(table).AutoMigrate(test.to); err != nil { + t.Fatalf("failed to migrate table, got error: %v", err) + } + test.checkFunc(t) + }) + } + } +} From d81ae6f701e09214ac550a3308428f8434174f5f Mon Sep 17 00:00:00 2001 From: M Dmitry Date: Mon, 19 Feb 2024 03:42:25 +0000 Subject: [PATCH 211/231] Fixed: panic on nullable value with multiple foreign key usage (#6839) See: https://github.com/go-gorm/playground/pull/537 --- utils/utils.go | 6 +++++- utils/utils_test.go | 2 ++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/utils/utils.go b/utils/utils.go index a4d8ac25..347a331f 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -74,7 +74,11 @@ func ToStringKey(values ...interface{}) string { case uint: results[idx] = strconv.FormatUint(uint64(v), 10) default: - results[idx] = fmt.Sprint(reflect.Indirect(reflect.ValueOf(v)).Interface()) + results[idx] = "nil" + vv := reflect.ValueOf(v) + if vv.IsValid() && !vv.IsZero() { + results[idx] = fmt.Sprint(reflect.Indirect(vv).Interface()) + } } } diff --git a/utils/utils_test.go b/utils/utils_test.go index d0486822..8ff42af8 100644 --- a/utils/utils_test.go +++ b/utils/utils_test.go @@ -48,8 +48,10 @@ func TestToStringKey(t *testing.T) { }{ {[]interface{}{"a"}, "a"}, {[]interface{}{1, 2, 3}, "1_2_3"}, + {[]interface{}{1, nil, 3}, "1_nil_3"}, {[]interface{}{[]interface{}{1, 2, 3}}, "[1 2 3]"}, {[]interface{}{[]interface{}{"1", "2", "3"}}, "[1 2 3]"}, + {[]interface{}{[]interface{}{"1", nil, "3"}}, "[1 3]"}, } for _, c := range cases { if key := ToStringKey(c.values...); key != c.key { From 52404cddbb5f5b18253cebe12fde20b577b1c86d Mon Sep 17 00:00:00 2001 From: Chef <130948588+naruchet@users.noreply.github.com> Date: Tue, 27 Feb 2024 09:48:04 +0700 Subject: [PATCH 212/231] CHORE add unittest test function ConvertMapToValueForCreate (#6846) * CHORE add unittest test function ConvertMapToValueForCreate * CHORE move the test cases located in the files convert_map_test.go and visit_map_test.go into the file helper_test.go. --- callbacks/helper_test.go | 97 +++++++++++++++++++++++++++++++++++++ callbacks/visit_map_test.go | 36 -------------- 2 files changed, 97 insertions(+), 36 deletions(-) create mode 100644 callbacks/helper_test.go delete mode 100644 callbacks/visit_map_test.go diff --git a/callbacks/helper_test.go b/callbacks/helper_test.go new file mode 100644 index 00000000..6b76a415 --- /dev/null +++ b/callbacks/helper_test.go @@ -0,0 +1,97 @@ +package callbacks + +import ( + "reflect" + "testing" + + "gorm.io/gorm" + "gorm.io/gorm/clause" +) + +func TestLoadOrStoreVisitMap(t *testing.T) { + var vm visitMap + var loaded bool + type testM struct { + Name string + } + + t1 := testM{Name: "t1"} + t2 := testM{Name: "t2"} + t3 := testM{Name: "t3"} + + vm = make(visitMap) + if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf(&t1)); loaded { + t.Fatalf("loaded should be false") + } + + if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf(&t1)); !loaded { + t.Fatalf("loaded should be true") + } + + // t1 already exist but t2 not + if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf([]*testM{&t1, &t2, &t3})); loaded { + t.Fatalf("loaded should be false") + } + + if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf([]*testM{&t2, &t3})); !loaded { + t.Fatalf("loaded should be true") + } +} + +func TestConvertMapToValuesForCreate(t *testing.T) { + testCase := []struct { + name string + input map[string]interface{} + expect clause.Values + }{ + { + name: "Test convert string value", + input: map[string]interface{}{ + "name": "my name", + }, + expect: clause.Values{ + Columns: []clause.Column{{Name: "name"}}, + Values: [][]interface{}{{"my name"}}, + }, + }, + { + name: "Test convert int value", + input: map[string]interface{}{ + "age": 18, + }, + expect: clause.Values{ + Columns: []clause.Column{{Name: "age"}}, + Values: [][]interface{}{{18}}, + }, + }, + { + name: "Test convert float value", + input: map[string]interface{}{ + "score": 99.5, + }, + expect: clause.Values{ + Columns: []clause.Column{{Name: "score"}}, + Values: [][]interface{}{{99.5}}, + }, + }, + { + name: "Test convert bool value", + input: map[string]interface{}{ + "active": true, + }, + expect: clause.Values{ + Columns: []clause.Column{{Name: "active"}}, + Values: [][]interface{}{{true}}, + }, + }, + } + + for _, tc := range testCase { + t.Run(tc.name, func(t *testing.T) { + actual := ConvertMapToValuesForCreate(&gorm.Statement{}, tc.input) + if !reflect.DeepEqual(actual, tc.expect) { + t.Errorf("expect %v got %v", tc.expect, actual) + } + }) + } +} diff --git a/callbacks/visit_map_test.go b/callbacks/visit_map_test.go deleted file mode 100644 index b1fb86db..00000000 --- a/callbacks/visit_map_test.go +++ /dev/null @@ -1,36 +0,0 @@ -package callbacks - -import ( - "reflect" - "testing" -) - -func TestLoadOrStoreVisitMap(t *testing.T) { - var vm visitMap - var loaded bool - type testM struct { - Name string - } - - t1 := testM{Name: "t1"} - t2 := testM{Name: "t2"} - t3 := testM{Name: "t3"} - - vm = make(visitMap) - if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf(&t1)); loaded { - t.Fatalf("loaded should be false") - } - - if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf(&t1)); !loaded { - t.Fatalf("loaded should be true") - } - - // t1 already exist but t2 not - if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf([]*testM{&t1, &t2, &t3})); loaded { - t.Fatalf("loaded should be false") - } - - if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf([]*testM{&t2, &t3})); !loaded { - t.Fatalf("loaded should be true") - } -} From f118e55db5c14f524b8fba3c9ac2924b7db95870 Mon Sep 17 00:00:00 2001 From: Chef <130948588+naruchet@users.noreply.github.com> Date: Tue, 5 Mar 2024 09:22:57 +0700 Subject: [PATCH 213/231] Add unittest test helper function ConvertSliceOfMapToValuesForCreate (#6854) --- callbacks/helper_test.go | 60 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) diff --git a/callbacks/helper_test.go b/callbacks/helper_test.go index 6b76a415..08f94e20 100644 --- a/callbacks/helper_test.go +++ b/callbacks/helper_test.go @@ -95,3 +95,63 @@ func TestConvertMapToValuesForCreate(t *testing.T) { }) } } + +func TestConvertSliceOfMapToValuesForCreate(t *testing.T) { + testCase := []struct { + name string + input []map[string]interface{} + expect clause.Values + }{ + { + name: "Test convert slice of string value", + input: []map[string]interface{}{ + {"name": "my name"}, + }, + expect: clause.Values{ + Columns: []clause.Column{{Name: "name"}}, + Values: [][]interface{}{{"my name"}}, + }, + }, + { + name: "Test convert slice of int value", + input: []map[string]interface{}{ + {"age": 18}, + }, + expect: clause.Values{ + Columns: []clause.Column{{Name: "age"}}, + Values: [][]interface{}{{18}}, + }, + }, + { + name: "Test convert slice of float value", + input: []map[string]interface{}{ + {"score": 99.5}, + }, + expect: clause.Values{ + Columns: []clause.Column{{Name: "score"}}, + Values: [][]interface{}{{99.5}}, + }, + }, + { + name: "Test convert slice of bool value", + input: []map[string]interface{}{ + {"active": true}, + }, + expect: clause.Values{ + Columns: []clause.Column{{Name: "active"}}, + Values: [][]interface{}{{true}}, + }, + }, + } + + for _, tc := range testCase { + t.Run(tc.name, func(t *testing.T) { + actual := ConvertSliceOfMapToValuesForCreate(&gorm.Statement{}, tc.input) + + if !reflect.DeepEqual(actual, tc.expect) { + t.Errorf("expected %v but got %v", tc.expect, actual) + } + }) + } + +} From 3e2c4fc446f0a601124771d012ee167cbb0d7de0 Mon Sep 17 00:00:00 2001 From: tsuba3 Date: Tue, 5 Mar 2024 11:23:51 +0900 Subject: [PATCH 214/231] Fix regression in db.Not introduced in v1.25.6. (#6844) * Fix regression in db.Not introduced in 940358e. * Fix --- clause/where.go | 73 +++++++++++++++++++++++++++++++++----------- clause/where_test.go | 8 +++++ tests/query_test.go | 5 +++ 3 files changed, 69 insertions(+), 17 deletions(-) diff --git a/clause/where.go b/clause/where.go index 46d0b319..9ac78578 100644 --- a/clause/where.go +++ b/clause/where.go @@ -21,11 +21,11 @@ func (where Where) Name() string { // Build build where clause func (where Where) Build(builder Builder) { - if len(where.Exprs) == 1 { - if andCondition, ok := where.Exprs[0].(AndConditions); ok { - where.Exprs = andCondition.Exprs - } - } + if len(where.Exprs) == 1 { + if andCondition, ok := where.Exprs[0].(AndConditions); ok { + where.Exprs = andCondition.Exprs + } + } // Switch position if the first query expression is a single Or condition for idx, expr := range where.Exprs { @@ -166,19 +166,58 @@ type NotConditions struct { } func (not NotConditions) Build(builder Builder) { - if len(not.Exprs) > 1 { - builder.WriteByte('(') + anyNegationBuilder := false + for _, c := range not.Exprs { + if _, ok := c.(NegationExpressionBuilder); ok { + anyNegationBuilder = true + break + } } - for idx, c := range not.Exprs { - if idx > 0 { - builder.WriteString(AndWithSpace) + if anyNegationBuilder { + if len(not.Exprs) > 1 { + builder.WriteByte('(') } - if negationBuilder, ok := c.(NegationExpressionBuilder); ok { - negationBuilder.NegationBuild(builder) - } else { - builder.WriteString("NOT ") + for idx, c := range not.Exprs { + if idx > 0 { + builder.WriteString(AndWithSpace) + } + + if negationBuilder, ok := c.(NegationExpressionBuilder); ok { + negationBuilder.NegationBuild(builder) + } else { + builder.WriteString("NOT ") + e, wrapInParentheses := c.(Expr) + if wrapInParentheses { + sql := strings.ToUpper(e.SQL) + if wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace); wrapInParentheses { + builder.WriteByte('(') + } + } + + c.Build(builder) + + if wrapInParentheses { + builder.WriteByte(')') + } + } + } + + if len(not.Exprs) > 1 { + builder.WriteByte(')') + } + } else { + builder.WriteString("NOT ") + if len(not.Exprs) > 1 { + builder.WriteByte('(') + } + + for idx, c := range not.Exprs { + if idx > 0 { + builder.WriteString(AndWithSpace) + } + e, wrapInParentheses := c.(Expr) if wrapInParentheses { sql := strings.ToUpper(e.SQL) @@ -193,9 +232,9 @@ func (not NotConditions) Build(builder Builder) { builder.WriteByte(')') } } - } - if len(not.Exprs) > 1 { - builder.WriteByte(')') + if len(not.Exprs) > 1 { + builder.WriteByte(')') + } } } diff --git a/clause/where_test.go b/clause/where_test.go index aa9d06eb..7d5aca1f 100644 --- a/clause/where_test.go +++ b/clause/where_test.go @@ -105,6 +105,14 @@ func TestWhere(t *testing.T) { "SELECT * FROM `users` WHERE (`users`.`id` <> ? AND NOT `score` <= ?)", []interface{}{"1", 100}, }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Where{ + Exprs: []clause.Expression{clause.Not(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}}, + clause.Expr{SQL: "`age` <= ?", Vars: []interface{}{60}})}, + }}, + "SELECT * FROM `users` WHERE NOT (`score` <= ? AND `age` <= ?)", + []interface{}{100, 60}, + }, } for idx, result := range results { diff --git a/tests/query_test.go b/tests/query_test.go index cadf7164..e780e3bf 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -554,6 +554,11 @@ func TestNot(t *testing.T) { if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*users.*..*name.* <> .+ AND .*users.*..*age.* <> .+").MatchString(result.Statement.SQL.String()) { t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) } + + result = dryDB.Not(DB.Where("manager IS NULL").Where("age >= ?", 20)).Find(&User{}) + if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE NOT \\(manager IS NULL AND age >= .+\\) AND .users.\\..deleted_at. IS NULL").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) + } } func TestNotWithAllFields(t *testing.T) { From f17a75242e3522820ede9678aba40b29d83b6301 Mon Sep 17 00:00:00 2001 From: hishope Date: Thu, 7 Mar 2024 14:18:31 +0800 Subject: [PATCH 215/231] Signed-off-by: hishope fix some typos in tests Signed-off-by: hishope --- tests/create_test.go | 28 ++++++++++++++-------------- tests/migrate_test.go | 4 ++-- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/tests/create_test.go b/tests/create_test.go index d9b54b7f..5e97a542 100644 --- a/tests/create_test.go +++ b/tests/create_test.go @@ -582,44 +582,44 @@ func TestCreateWithAutoIncrementCompositeKey(t *testing.T) { } func TestCreateOnConflictWithDefaultNull(t *testing.T) { - type OnConfilctUser struct { + type OnConflictUser struct { ID string Name string `gorm:"default:null"` Email string Mobile string `gorm:"default:'133xxxx'"` } - err := DB.Migrator().DropTable(&OnConfilctUser{}) + err := DB.Migrator().DropTable(&OnConflictUser{}) AssertEqual(t, err, nil) - err = DB.AutoMigrate(&OnConfilctUser{}) + err = DB.AutoMigrate(&OnConflictUser{}) AssertEqual(t, err, nil) - u := OnConfilctUser{ - ID: "on-confilct-user-id", - Name: "on-confilct-user-name", - Email: "on-confilct-user-email", - Mobile: "on-confilct-user-mobile", + u := OnConflictUser{ + ID: "on-conflict-user-id", + Name: "on-conflict-user-name", + Email: "on-conflict-user-email", + Mobile: "on-conflict-user-mobile", } err = DB.Create(&u).Error AssertEqual(t, err, nil) - u.Name = "on-confilct-user-name-2" - u.Email = "on-confilct-user-email-2" + u.Name = "on-conflict-user-name-2" + u.Email = "on-conflict-user-email-2" u.Mobile = "" err = DB.Clauses(clause.OnConflict{UpdateAll: true}).Create(&u).Error AssertEqual(t, err, nil) - var u2 OnConfilctUser + var u2 OnConflictUser err = DB.Where("id = ?", u.ID).First(&u2).Error AssertEqual(t, err, nil) - AssertEqual(t, u2.Name, "on-confilct-user-name-2") - AssertEqual(t, u2.Email, "on-confilct-user-email-2") + AssertEqual(t, u2.Name, "on-conflict-user-name-2") + AssertEqual(t, u2.Email, "on-conflict-user-email-2") AssertEqual(t, u2.Mobile, "133xxxx") } func TestCreateFromMapWithoutPK(t *testing.T) { if !isMysql() { - t.Skipf("This test case skipped, because of only supportting for mysql") + t.Skipf("This test case skipped, because of only supporting for mysql") } // case 1: one record, create from map[string]interface{} diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 837d92c1..b25b9da6 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -1413,10 +1413,10 @@ func TestMigrateSameEmbeddedFieldName(t *testing.T) { err = DB.Table("game_users").AutoMigrate(&GameUser1{}) AssertEqual(t, nil, err) - _, err = findColumnType(&GameUser{}, "stat_ab_ground_destory_count") + _, err = findColumnType(&GameUser{}, "stat_ab_ground_destroy_count") AssertEqual(t, nil, err) - _, err = findColumnType(&GameUser{}, "rate_ground_rb_ground_destory_count") + _, err = findColumnType(&GameUser{}, "rate_ground_rb_ground_destroy_count") AssertEqual(t, nil, err) } From 9efae659cb4712c6446c2e63299f2c0c25d1cca5 Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Sat, 9 Mar 2024 17:31:28 +0800 Subject: [PATCH 216/231] test: namer identifier lenght (#6872) --- tests/table_test.go | 87 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 87 insertions(+) diff --git a/tests/table_test.go b/tests/table_test.go index fa569d32..0d44a15b 100644 --- a/tests/table_test.go +++ b/tests/table_test.go @@ -2,8 +2,10 @@ package tests_test import ( "regexp" + "sync" "testing" + "gorm.io/driver/postgres" "gorm.io/gorm" "gorm.io/gorm/schema" "gorm.io/gorm/utils/tests" @@ -172,3 +174,88 @@ func TestTableWithNamer(t *testing.T) { t.Errorf("Table with namer, got %v", sql) } } + +func TestPostgresTableWithIdentifierLength(t *testing.T) { + if DB.Dialector.Name() != "postgres" { + return + } + + type LongString struct { + ThisIsAVeryVeryVeryVeryVeryVeryVeryVeryVeryLongString string `gorm:"unique"` + } + + t.Run("default", func(t *testing.T) { + db, _ := gorm.Open(postgres.Open(postgresDSN), &gorm.Config{}) + user, err := schema.Parse(&LongString{}, &sync.Map{}, db.Config.NamingStrategy) + if err != nil { + t.Fatalf("failed to parse user unique, got error %v", err) + } + + constraints := user.ParseUniqueConstraints() + if len(constraints) != 1 { + t.Fatalf("failed to find unique constraint, got %v", constraints) + } + + for key := range constraints { + if len(key) != 63 { + t.Errorf("failed to find unique constraint, got %v", constraints) + } + } + }) + + t.Run("naming strategy", func(t *testing.T) { + db, _ := gorm.Open(postgres.Open(postgresDSN), &gorm.Config{ + NamingStrategy: schema.NamingStrategy{}, + }) + + user, err := schema.Parse(&LongString{}, &sync.Map{}, db.Config.NamingStrategy) + if err != nil { + t.Fatalf("failed to parse user unique, got error %v", err) + } + + constraints := user.ParseUniqueConstraints() + if len(constraints) != 1 { + t.Fatalf("failed to find unique constraint, got %v", constraints) + } + + for key := range constraints { + if len(key) != 63 { + t.Errorf("failed to find unique constraint, got %v", constraints) + } + } + }) + + t.Run("namer", func(t *testing.T) { + uname := "custom_unique_name" + db, _ := gorm.Open(postgres.Open(postgresDSN), &gorm.Config{ + NamingStrategy: mockUniqueNamingStrategy{ + UName: uname, + }, + }) + + user, err := schema.Parse(&LongString{}, &sync.Map{}, db.Config.NamingStrategy) + if err != nil { + t.Fatalf("failed to parse user unique, got error %v", err) + } + + constraints := user.ParseUniqueConstraints() + if len(constraints) != 1 { + t.Fatalf("failed to find unique constraint, got %v", constraints) + } + + for key := range constraints { + if key != uname { + t.Errorf("failed to find unique constraint, got %v", constraints) + } + } + }) +} + +type mockUniqueNamingStrategy struct { + UName string + schema.NamingStrategy +} + +func (a mockUniqueNamingStrategy) UniqueName(table, column string) string { + return a.UName +} From c4c9aa45e32efaf5d63cb7ac9ce66fcca5fc7c00 Mon Sep 17 00:00:00 2001 From: jessetang <1430482733@qq.com> Date: Sat, 9 Mar 2024 17:39:01 +0800 Subject: [PATCH 217/231] fix(scan.go): reflect.MakeSlice passes in the reflect.Array type (#6880) --- scan.go | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/scan.go b/scan.go index 736db4d3..54cd6769 100644 --- a/scan.go +++ b/scan.go @@ -274,12 +274,14 @@ func Scan(rows Rows, db *DB, mode ScanMode) { if !update || reflectValue.Len() == 0 { update = false - // if the slice cap is externally initialized, the externally initialized slice is directly used here - if reflectValue.Cap() == 0 { - db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20)) - } else if !isArrayKind { - reflectValue.SetLen(0) - db.Statement.ReflectValue.Set(reflectValue) + if !isArrayKind { + // if the slice cap is externally initialized, the externally initialized slice is directly used here + if reflectValue.Cap() == 0 { + db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20)) + } else { + reflectValue.SetLen(0) + db.Statement.ReflectValue.Set(reflectValue) + } } } From e4e23d26d2d16cbfb575da0bfa9a400e2ad29b92 Mon Sep 17 00:00:00 2001 From: black-06 Date: Sat, 9 Mar 2024 21:27:19 +0800 Subject: [PATCH 218/231] fix: nested preload with join panic when find (#6877) --- callbacks/preload.go | 21 +++++++++++++++++---- tests/preload_test.go | 10 ++++++++++ 2 files changed, 27 insertions(+), 4 deletions(-) diff --git a/callbacks/preload.go b/callbacks/preload.go index 25ecfe76..cf7a0d2b 100644 --- a/callbacks/preload.go +++ b/callbacks/preload.go @@ -121,10 +121,23 @@ func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relati } } else if rel := relationships.Relations[name]; rel != nil { if joined, nestedJoins := isJoined(name); joined { - reflectValue := rel.Field.ReflectValueOf(db.Statement.Context, db.Statement.ReflectValue) - tx := preloadDB(db, reflectValue, reflectValue.Interface()) - if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil { - return err + switch rv := db.Statement.ReflectValue; rv.Kind() { + case reflect.Slice, reflect.Array: + for i := 0; i < rv.Len(); i++ { + reflectValue := rel.Field.ReflectValueOf(db.Statement.Context, rv.Index(i)) + tx := preloadDB(db, reflectValue, reflectValue.Interface()) + if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil { + return err + } + } + case reflect.Struct: + reflectValue := rel.Field.ReflectValueOf(db.Statement.Context, rv) + tx := preloadDB(db, reflectValue, reflectValue.Interface()) + if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil { + return err + } + default: + return gorm.ErrInvalidData } } else { tx := db.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks}) diff --git a/tests/preload_test.go b/tests/preload_test.go index 26b08d7d..14f94139 100644 --- a/tests/preload_test.go +++ b/tests/preload_test.go @@ -8,6 +8,8 @@ import ( "sync" "testing" + "github.com/stretchr/testify/require" + "gorm.io/gorm" "gorm.io/gorm/clause" . "gorm.io/gorm/utils/tests" @@ -362,6 +364,14 @@ func TestNestedPreloadWithNestedJoin(t *testing.T) { t.Errorf("failed to find value, got err: %v", err) } AssertEqual(t, find2, value) + + var finds []Value + err = DB.Joins("Nested.Join").Joins("Nested").Preload("Nested.Preloads").Find(&finds).Error + if err != nil { + t.Errorf("failed to find value, got err: %v", err) + } + require.Len(t, finds, 1) + AssertEqual(t, finds[0], value) } func TestEmbedPreload(t *testing.T) { From 7b1fb0bd732448adad0ea097983fc663700c1978 Mon Sep 17 00:00:00 2001 From: jessetang <1430482733@qq.com> Date: Fri, 15 Mar 2024 14:14:48 +0800 Subject: [PATCH 219/231] fix(scan): array element is set to a zero value (#6890) * fix(scan): array element is set to a zero value * add test * fix test * optimization --- scan.go | 4 +++- tests/query_test.go | 19 +++++++++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/scan.go b/scan.go index 54cd6769..415b9f0d 100644 --- a/scan.go +++ b/scan.go @@ -274,7 +274,9 @@ func Scan(rows Rows, db *DB, mode ScanMode) { if !update || reflectValue.Len() == 0 { update = false - if !isArrayKind { + if isArrayKind { + db.Statement.ReflectValue.Set(reflect.Zero(reflectValue.Type())) + } else { // if the slice cap is externally initialized, the externally initialized slice is directly used here if reflectValue.Cap() == 0 { db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20)) diff --git a/tests/query_test.go b/tests/query_test.go index e780e3bf..c0259a14 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -1409,3 +1409,22 @@ func TestQueryError(t *testing.T) { }, Value: 1}).Scan(&p2).Error AssertEqual(t, err, gorm.ErrModelValueRequired) } + +func TestQueryScanToArray(t *testing.T) { + err := DB.Create(&User{Name: "testname1", Age: 10}).Error + if err != nil { + t.Fatal(err) + } + + users := [2]*User{{Name: "1"}, {Name: "2"}} + err = DB.Model(&User{}).Where("name = ?", "testname1").Find(&users).Error + if err != nil { + t.Fatal(err) + } + if users[0] == nil || users[0].Name != "testname1" { + t.Error("users[0] not covere") + } + if users[1] != nil { + t.Error("users[1] should be empty") + } +} From 281f3e369a644a9754bb6ba59cf2d554b3d98d19 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 18 Mar 2024 11:32:30 +0800 Subject: [PATCH 220/231] Fix constraint name regexp --- schema/constraint.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/schema/constraint.go b/schema/constraint.go index 5f6beb89..0ed1eab0 100644 --- a/schema/constraint.go +++ b/schema/constraint.go @@ -8,7 +8,7 @@ import ( ) // reg match english letters and midline -var regEnLetterAndMidline = regexp.MustCompile("^[A-Za-z-_]+$") +var regEnLetterAndMidline = regexp.MustCompile("^[0-9A-Za-z-_]+$") type CheckConstraint struct { Name string From ab89d54d877ffee80a6681536e157b7f4a0fdf47 Mon Sep 17 00:00:00 2001 From: jessetang <1430482733@qq.com> Date: Mon, 18 Mar 2024 13:44:55 +0800 Subject: [PATCH 221/231] chore: UnixNano convert to UnixMilli (#6907) --- callbacks/create.go | 2 +- callbacks/update.go | 4 ++-- schema/field.go | 6 +++--- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/callbacks/create.go b/callbacks/create.go index b1488b08..210a46f7 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -351,7 +351,7 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { case schema.UnixNanosecond: assignment.Value = curTime.UnixNano() case schema.UnixMillisecond: - assignment.Value = curTime.UnixNano() / 1e6 + assignment.Value = curTime.UnixMilli() case schema.UnixSecond: assignment.Value = curTime.Unix() } diff --git a/callbacks/update.go b/callbacks/update.go index ff075dcf..7cde7f61 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -234,7 +234,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { if field.AutoUpdateTime == schema.UnixNanosecond { set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano()}) } else if field.AutoUpdateTime == schema.UnixMillisecond { - set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano() / 1e6}) + set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixMilli()}) } else if field.AutoUpdateTime == schema.UnixSecond { set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.Unix()}) } else { @@ -268,7 +268,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { if field.AutoUpdateTime == schema.UnixNanosecond { value = stmt.DB.NowFunc().UnixNano() } else if field.AutoUpdateTime == schema.UnixMillisecond { - value = stmt.DB.NowFunc().UnixNano() / 1e6 + value = stmt.DB.NowFunc().UnixMilli() } else if field.AutoUpdateTime == schema.UnixSecond { value = stmt.DB.NowFunc().Unix() } else { diff --git a/schema/field.go b/schema/field.go index 91e4c0ab..ca2e1148 100644 --- a/schema/field.go +++ b/schema/field.go @@ -664,7 +664,7 @@ func (field *Field) setupValuerAndSetter() { if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond { field.ReflectValueOf(ctx, value).SetInt(data.UnixNano()) } else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond { - field.ReflectValueOf(ctx, value).SetInt(data.UnixNano() / 1e6) + field.ReflectValueOf(ctx, value).SetInt(data.UnixMilli()) } else { field.ReflectValueOf(ctx, value).SetInt(data.Unix()) } @@ -673,7 +673,7 @@ func (field *Field) setupValuerAndSetter() { if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond { field.ReflectValueOf(ctx, value).SetInt(data.UnixNano()) } else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond { - field.ReflectValueOf(ctx, value).SetInt(data.UnixNano() / 1e6) + field.ReflectValueOf(ctx, value).SetInt(data.UnixMilli()) } else { field.ReflectValueOf(ctx, value).SetInt(data.Unix()) } @@ -738,7 +738,7 @@ func (field *Field) setupValuerAndSetter() { if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond { field.ReflectValueOf(ctx, value).SetUint(uint64(data.UnixNano())) } else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond { - field.ReflectValueOf(ctx, value).SetUint(uint64(data.UnixNano() / 1e6)) + field.ReflectValueOf(ctx, value).SetUint(uint64(data.UnixMilli())) } else { field.ReflectValueOf(ctx, value).SetUint(uint64(data.Unix())) } From f7ebf049dac5c1ff1e1648207071f91a0e950427 Mon Sep 17 00:00:00 2001 From: Jinghao Lu Date: Mon, 18 Mar 2024 13:48:42 +0800 Subject: [PATCH 222/231] fix(create): fix insert column order (#6855) * fix(create): fix insert column order * chore: add ConvertToCreateValues ut for Slice case * fix: remvoe testify dependency --------- Co-authored-by: lujinghao --- callbacks/create.go | 16 +++++---- callbacks/create_test.go | 71 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 80 insertions(+), 7 deletions(-) create mode 100644 callbacks/create_test.go diff --git a/callbacks/create.go b/callbacks/create.go index 210a46f7..d930e922 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -293,13 +293,15 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { } } - for field, vs := range defaultValueFieldsHavingValue { - values.Columns = append(values.Columns, clause.Column{Name: field.DBName}) - for idx := range values.Values { - if vs[idx] == nil { - values.Values[idx] = append(values.Values[idx], stmt.Dialector.DefaultValueOf(field)) - } else { - values.Values[idx] = append(values.Values[idx], vs[idx]) + for _, field := range stmt.Schema.FieldsWithDefaultDBValue { + if vs, ok := defaultValueFieldsHavingValue[field]; ok { + values.Columns = append(values.Columns, clause.Column{Name: field.DBName}) + for idx := range values.Values { + if vs[idx] == nil { + values.Values[idx] = append(values.Values[idx], stmt.Dialector.DefaultValueOf(field)) + } else { + values.Values[idx] = append(values.Values[idx], vs[idx]) + } } } } diff --git a/callbacks/create_test.go b/callbacks/create_test.go new file mode 100644 index 00000000..da6b172b --- /dev/null +++ b/callbacks/create_test.go @@ -0,0 +1,71 @@ +package callbacks + +import ( + "reflect" + "sync" + "testing" + "time" + + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" +) + +var schemaCache = &sync.Map{} + +func TestConvertToCreateValues_DestType_Slice(t *testing.T) { + type user struct { + ID int `gorm:"primaryKey"` + Name string + Email string `gorm:"default:(-)"` + Age int `gorm:"default:(-)"` + } + + s, err := schema.Parse(&user{}, schemaCache, schema.NamingStrategy{}) + if err != nil { + t.Errorf("parse schema error: %v, is not expected", err) + return + } + dest := []*user{ + { + ID: 1, + Name: "alice", + Email: "email", + Age: 18, + }, + { + ID: 2, + Name: "bob", + Email: "email", + Age: 19, + }, + } + stmt := &gorm.Statement{ + DB: &gorm.DB{ + Config: &gorm.Config{ + NowFunc: func() time.Time { return time.Time{} }, + }, + Statement: &gorm.Statement{ + Settings: sync.Map{}, + Schema: s, + }, + }, + ReflectValue: reflect.ValueOf(dest), + Dest: dest, + } + + stmt.Schema = s + + values := ConvertToCreateValues(stmt) + expected := clause.Values{ + // column has value + defaultValue column has value (which should have a stable order) + Columns: []clause.Column{{Name: "name"}, {Name: "email"}, {Name: "age"}, {Name: "id"}}, + Values: [][]interface{}{ + {"alice", "email", 18, 1}, + {"bob", "email", 19, 2}, + }, + } + if !reflect.DeepEqual(expected, values) { + t.Errorf("expected: %v got %v", expected, values) + } +} From 303de6e7c89942b88f16d21830ec963de8ba53c3 Mon Sep 17 00:00:00 2001 From: jessetang <1430482733@qq.com> Date: Mon, 18 Mar 2024 15:33:54 +0800 Subject: [PATCH 223/231] chore: optimize `regEnLetterAndMidline` regular (#6908) * chore: optimize regular * fix --- schema/constraint.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/schema/constraint.go b/schema/constraint.go index 0ed1eab0..80a743a8 100644 --- a/schema/constraint.go +++ b/schema/constraint.go @@ -8,7 +8,7 @@ import ( ) // reg match english letters and midline -var regEnLetterAndMidline = regexp.MustCompile("^[0-9A-Za-z-_]+$") +var regEnLetterAndMidline = regexp.MustCompile(`^[\w-]+$`) type CheckConstraint struct { Name string From e0c3be03fba966a0128f4d215fa3743d0e8e9a6e Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 18 Mar 2024 16:28:46 +0800 Subject: [PATCH 224/231] Fix tests in local --- schema/serializer.go | 4 ++-- tests/docker-compose.yml | 1 + tests/go.mod | 17 +++++++++-------- 3 files changed, 12 insertions(+), 10 deletions(-) diff --git a/schema/serializer.go b/schema/serializer.go index 397edff0..f500521e 100644 --- a/schema/serializer.go +++ b/schema/serializer.go @@ -126,12 +126,12 @@ func (UnixSecondSerializer) Value(ctx context.Context, field *Field, dst reflect rv := reflect.ValueOf(fieldValue) switch v := fieldValue.(type) { case int64, int, uint, uint64, int32, uint32, int16, uint16: - result = time.Unix(reflect.Indirect(rv).Int(), 0) + result = time.Unix(reflect.Indirect(rv).Int(), 0).UTC() case *int64, *int, *uint, *uint64, *int32, *uint32, *int16, *uint16: if rv.IsZero() { return nil, nil } - result = time.Unix(reflect.Indirect(rv).Int(), 0) + result = time.Unix(reflect.Indirect(rv).Int(), 0).UTC() default: err = fmt.Errorf("invalid field type %#v for UnixSecondSerializer, only int, uint supported", v) } diff --git a/tests/docker-compose.yml b/tests/docker-compose.yml index 866a4d62..8abd4d0f 100644 --- a/tests/docker-compose.yml +++ b/tests/docker-compose.yml @@ -24,6 +24,7 @@ services: ports: - "9930:1433" environment: + - TZ=Asia/Shanghai - ACCEPT_EULA=Y - SA_PASSWORD=LoremIpsum86 - MSSQL_DB=gorm diff --git a/tests/go.mod b/tests/go.mod index 136667b7..350152d3 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -6,29 +6,30 @@ require ( github.com/google/uuid v1.6.0 github.com/jinzhu/now v1.1.5 github.com/lib/pq v1.10.9 - github.com/stretchr/testify v1.8.4 - gorm.io/driver/mysql v1.5.4 - gorm.io/driver/postgres v1.5.6 + github.com/stretchr/testify v1.9.0 + gorm.io/driver/mysql v1.5.5 + gorm.io/driver/postgres v1.5.7 gorm.io/driver/sqlite v1.5.5 gorm.io/driver/sqlserver v1.5.3 - gorm.io/gorm v1.25.7-0.20240204074919-46816ad31dde + gorm.io/gorm v1.25.7 ) require ( + filippo.io/edwards25519 v1.1.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect - github.com/go-sql-driver/mysql v1.7.1 // indirect + github.com/go-sql-driver/mysql v1.8.0 // indirect github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect github.com/golang-sql/sqlexp v0.1.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 // indirect - github.com/jackc/pgx/v5 v5.5.3 // indirect + github.com/jackc/pgx/v5 v5.5.5 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/kr/text v0.2.0 // indirect github.com/mattn/go-sqlite3 v1.14.22 // indirect - github.com/microsoft/go-mssqldb v1.6.0 // indirect + github.com/microsoft/go-mssqldb v1.7.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rogpeppe/go-internal v1.12.0 // indirect - golang.org/x/crypto v0.18.0 // indirect + golang.org/x/crypto v0.21.0 // indirect golang.org/x/text v0.14.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) From 1b0aa802dfa6df27be0c3f26e12d359e4afcd2ef Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 18 Mar 2024 19:24:16 +0800 Subject: [PATCH 225/231] Fix AutoMigrate for bool fields with default value --- migrator/migrator.go | 19 +++++++++++++------ tests/migrate_test.go | 12 +++++++++++- 2 files changed, 24 insertions(+), 7 deletions(-) diff --git a/migrator/migrator.go b/migrator/migrator.go index ae82f769..acce5df2 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -7,6 +7,7 @@ import ( "fmt" "reflect" "regexp" + "strconv" "strings" "time" @@ -518,12 +519,18 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy } else if !dvNotNull && currentDefaultNotNull { // null -> default value alterColumn = true - } else if (field.GORMDataType != schema.Time && dv != field.DefaultValue) || - (field.GORMDataType == schema.Time && !strings.EqualFold(strings.TrimSuffix(dv, "()"), strings.TrimSuffix(field.DefaultValue, "()"))) { - // default value not equal - // not both null - if currentDefaultNotNull || dvNotNull { - alterColumn = true + } else if currentDefaultNotNull || dvNotNull { + switch field.GORMDataType { + case schema.Time: + if !strings.EqualFold(strings.TrimSuffix(dv, "()"), strings.TrimSuffix(field.DefaultValue, "()")) { + alterColumn = true + } + case schema.Bool: + v1, _ := strconv.ParseBool(dv) + v2, _ := strconv.ParseBool(field.DefaultValue) + alterColumn = v1 != v2 + default: + alterColumn = dv != field.DefaultValue } } } diff --git a/tests/migrate_test.go b/tests/migrate_test.go index b25b9da6..d955c8d7 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -7,6 +7,7 @@ import ( "math/rand" "os" "reflect" + "strconv" "strings" "testing" "time" @@ -1420,7 +1421,7 @@ func TestMigrateSameEmbeddedFieldName(t *testing.T) { AssertEqual(t, nil, err) } -func TestMigrateDefaultNullString(t *testing.T) { +func TestMigrateWithDefaultValue(t *testing.T) { if DB.Dialector.Name() == "sqlserver" { // sqlserver driver treats NULL and 'NULL' the same t.Skip("skip sqlserver") @@ -1434,6 +1435,7 @@ func TestMigrateDefaultNullString(t *testing.T) { type NullStringModel struct { ID uint Content string `gorm:"default:'null'"` + Active bool `gorm:"default:false"` } tableName := "null_string_model" @@ -1454,6 +1456,14 @@ func TestMigrateDefaultNullString(t *testing.T) { AssertEqual(t, defVal, "null") AssertEqual(t, ok, true) + columnType2, err := findColumnType(tableName, "active") + AssertEqual(t, err, nil) + + defVal, ok = columnType2.DefaultValue() + bv, _ := strconv.ParseBool(defVal) + AssertEqual(t, bv, false) + AssertEqual(t, ok, true) + // default 'null' -> 'null' session := DB.Session(&gorm.Session{Logger: Tracer{ Logger: DB.Config.Logger, From 81536f823c055ba293dfbb7a8e90ebf93d32b431 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 19 Mar 2024 11:50:28 +0800 Subject: [PATCH 226/231] Fix insert id into map results, fix #6812 --- callbacks/create.go | 23 ++++++++++++++++------- tests/create_test.go | 30 ++++++++++++++---------------- tests/go.mod | 2 +- tests/helper_test.go | 4 ++++ 4 files changed, 35 insertions(+), 24 deletions(-) diff --git a/callbacks/create.go b/callbacks/create.go index d930e922..afea2cca 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -111,6 +111,17 @@ func Create(config *Config) func(db *gorm.DB) { pkField *schema.Field pkFieldName = "@id" ) + + insertID, err := result.LastInsertId() + insertOk := err == nil && insertID > 0 + + if !insertOk { + if !supportReturning { + db.AddError(err) + } + return + } + if db.Statement.Schema != nil { if db.Statement.Schema.PrioritizedPrimaryField == nil || !db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue { return @@ -119,13 +130,6 @@ func Create(config *Config) func(db *gorm.DB) { pkFieldName = db.Statement.Schema.PrioritizedPrimaryField.DBName } - insertID, err := result.LastInsertId() - insertOk := err == nil && insertID > 0 - if !insertOk { - db.AddError(err) - return - } - // append @id column with value for auto-increment primary key // the @id value is correct, when: 1. without setting auto-increment primary key, 2. database AutoIncrementIncrement = 1 switch values := db.Statement.Dest.(type) { @@ -142,6 +146,11 @@ func Create(config *Config) func(db *gorm.DB) { } } } + + if config.LastInsertIDReversed { + insertID -= int64(len(mapValues)-1) * schema.DefaultAutoIncrementIncrement + } + for _, mapValue := range mapValues { if mapValue != nil { mapValue[pkFieldName] = insertID diff --git a/tests/create_test.go b/tests/create_test.go index 5e97a542..abb82472 100644 --- a/tests/create_test.go +++ b/tests/create_test.go @@ -713,18 +713,16 @@ func TestCreateFromMapWithoutPK(t *testing.T) { } func TestCreateFromMapWithTable(t *testing.T) { - if !isMysql() { - t.Skipf("This test case skipped, because of only supportting for mysql") - } - tableDB := DB.Table("`users`") + tableDB := DB.Table("users") + supportLastInsertID := isMysql() || isSqlite() // case 1: create from map[string]interface{} - record := map[string]interface{}{"`name`": "create_from_map_with_table", "`age`": 18} + record := map[string]interface{}{"name": "create_from_map_with_table", "age": 18} if err := tableDB.Create(record).Error; err != nil { t.Fatalf("failed to create data from map with table, got error: %v", err) } - if _, ok := record["@id"]; !ok { + if _, ok := record["@id"]; !ok && supportLastInsertID { t.Fatal("failed to create data from map with table, returning map has no key '@id'") } @@ -733,8 +731,8 @@ func TestCreateFromMapWithTable(t *testing.T) { t.Fatalf("failed to create from map, got error %v", err) } - if int64(res["id"].(uint64)) != record["@id"] { - t.Fatal("failed to create data from map with table, @id != id") + if _, ok := record["@id"]; ok && fmt.Sprint(res["id"]) != fmt.Sprint(record["@id"]) { + t.Fatalf("failed to create data from map with table, @id != id, got %v, expect %v", res["id"], record["@id"]) } // case 2: create from *map[string]interface{} @@ -743,7 +741,7 @@ func TestCreateFromMapWithTable(t *testing.T) { if err := tableDB2.Create(&record1).Error; err != nil { t.Fatalf("failed to create data from map, got error: %v", err) } - if _, ok := record1["@id"]; !ok { + if _, ok := record1["@id"]; !ok && supportLastInsertID { t.Fatal("failed to create data from map with table, returning map has no key '@id'") } @@ -752,7 +750,7 @@ func TestCreateFromMapWithTable(t *testing.T) { t.Fatalf("failed to create from map, got error %v", err) } - if int64(res1["id"].(uint64)) != record1["@id"] { + if _, ok := record1["@id"]; ok && fmt.Sprint(res1["id"]) != fmt.Sprint(record1["@id"]) { t.Fatal("failed to create data from map with table, @id != id") } @@ -767,11 +765,11 @@ func TestCreateFromMapWithTable(t *testing.T) { t.Fatalf("failed to create data from slice of map, got error: %v", err) } - if _, ok := records[0]["@id"]; !ok { + if _, ok := records[0]["@id"]; !ok && supportLastInsertID { t.Fatal("failed to create data from map with table, returning map has no key '@id'") } - if _, ok := records[1]["@id"]; !ok { + if _, ok := records[1]["@id"]; !ok && supportLastInsertID { t.Fatal("failed to create data from map with table, returning map has no key '@id'") } @@ -785,11 +783,11 @@ func TestCreateFromMapWithTable(t *testing.T) { t.Fatalf("failed to query data after create from slice of map, got error %v", err) } - if int64(res2["id"].(uint64)) != records[0]["@id"] { - t.Fatal("failed to create data from map with table, @id != id") + if _, ok := records[0]["@id"]; ok && fmt.Sprint(res2["id"]) != fmt.Sprint(records[0]["@id"]) { + t.Errorf("failed to create data from map with table, @id != id, got %v, expect %v", res2["id"], records[0]["@id"]) } - if int64(res3["id"].(uint64)) != records[1]["@id"] { - t.Fatal("failed to create data from map with table, @id != id") + if _, ok := records[1]["id"]; ok && fmt.Sprint(res3["id"]) != fmt.Sprint(records[1]["@id"]) { + t.Errorf("failed to create data from map with table, @id != id") } } diff --git a/tests/go.mod b/tests/go.mod index 350152d3..5616ebb3 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -11,7 +11,7 @@ require ( gorm.io/driver/postgres v1.5.7 gorm.io/driver/sqlite v1.5.5 gorm.io/driver/sqlserver v1.5.3 - gorm.io/gorm v1.25.7 + gorm.io/gorm v1.25.8 ) require ( diff --git a/tests/helper_test.go b/tests/helper_test.go index feb67f9e..dc250b7c 100644 --- a/tests/helper_test.go +++ b/tests/helper_test.go @@ -281,6 +281,10 @@ func isMysql() bool { return os.Getenv("GORM_DIALECT") == "mysql" } +func isSqlite() bool { + return os.Getenv("GORM_DIALECT") == "sqlite" +} + func db(unscoped bool) *gorm.DB { if unscoped { return DB.Unscoped() From 57603882ea79f0f8bc1433ef77f94af51ed7cbb7 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 20 Mar 2024 19:47:20 +0800 Subject: [PATCH 227/231] Only close bad conn prepared stmt --- prepare_stmt.go | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/prepare_stmt.go b/prepare_stmt.go index aa944624..c60b5db7 100644 --- a/prepare_stmt.go +++ b/prepare_stmt.go @@ -3,6 +3,8 @@ package gorm import ( "context" "database/sql" + "database/sql/driver" + "errors" "reflect" "sync" ) @@ -147,7 +149,7 @@ func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args .. stmt, err := db.prepare(ctx, db.ConnPool, false, query) if err == nil { result, err = stmt.ExecContext(ctx, args...) - if err != nil { + if errors.Is(err, driver.ErrBadConn) { db.Mux.Lock() defer db.Mux.Unlock() go stmt.Close() @@ -161,7 +163,7 @@ func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args . stmt, err := db.prepare(ctx, db.ConnPool, false, query) if err == nil { rows, err = stmt.QueryContext(ctx, args...) - if err != nil { + if errors.Is(err, driver.ErrBadConn) { db.Mux.Lock() defer db.Mux.Unlock() @@ -207,7 +209,7 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args .. stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query) if err == nil { result, err = tx.Tx.StmtContext(ctx, stmt.Stmt).ExecContext(ctx, args...) - if err != nil { + if errors.Is(err, driver.ErrBadConn) { tx.PreparedStmtDB.Mux.Lock() defer tx.PreparedStmtDB.Mux.Unlock() @@ -222,7 +224,7 @@ func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args . stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query) if err == nil { rows, err = tx.Tx.StmtContext(ctx, stmt.Stmt).QueryContext(ctx, args...) - if err != nil { + if errors.Is(err, driver.ErrBadConn) { tx.PreparedStmtDB.Mux.Lock() defer tx.PreparedStmtDB.Mux.Unlock() From 0d6c5345f3e4a60ffc28777cb3f4e7c6e94f9249 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 21 Mar 2024 15:55:43 +0800 Subject: [PATCH 228/231] Don't close prepared stmt for normal db error --- tests/go.mod | 2 +- tests/prepared_stmt_test.go | 27 --------------------------- 2 files changed, 1 insertion(+), 28 deletions(-) diff --git a/tests/go.mod b/tests/go.mod index 5616ebb3..3d3901d9 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -7,7 +7,7 @@ require ( github.com/jinzhu/now v1.1.5 github.com/lib/pq v1.10.9 github.com/stretchr/testify v1.9.0 - gorm.io/driver/mysql v1.5.5 + gorm.io/driver/mysql v1.5.6 gorm.io/driver/postgres v1.5.7 gorm.io/driver/sqlite v1.5.5 gorm.io/driver/sqlserver v1.5.3 diff --git a/tests/prepared_stmt_test.go b/tests/prepared_stmt_test.go index b234c8bf..b86bc3d6 100644 --- a/tests/prepared_stmt_test.go +++ b/tests/prepared_stmt_test.go @@ -126,33 +126,6 @@ func TestPreparedStmtDeadlock(t *testing.T) { AssertEqual(t, sqlDB.Stats().InUse, 0) } -func TestPreparedStmtError(t *testing.T) { - tx, err := OpenTestConnection(&gorm.Config{}) - AssertEqual(t, err, nil) - - sqlDB, _ := tx.DB() - sqlDB.SetMaxOpenConns(1) - - tx = tx.Session(&gorm.Session{PrepareStmt: true}) - - wg := sync.WaitGroup{} - for i := 0; i < 10; i++ { - wg.Add(1) - go func() { - // err prepare - tag := Tag{Locale: "zh"} - tx.Table("users").Find(&tag) - wg.Done() - }() - } - wg.Wait() - - conn, ok := tx.ConnPool.(*gorm.PreparedStmtDB) - AssertEqual(t, ok, true) - AssertEqual(t, len(conn.Stmts), 0) - AssertEqual(t, sqlDB.Stats().InUse, 0) -} - func TestPreparedStmtInTransaction(t *testing.T) { user := User{Name: "jinzhu"} From 956f7ce84309ccd5631b28335c814f4063eac58e Mon Sep 17 00:00:00 2001 From: givemeafish <981819494@qq.com> Date: Thu, 21 Mar 2024 16:00:02 +0800 Subject: [PATCH 229/231] fix: 'type XXXX int' will print wrong sql to terminal (#6917) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: 王泽平 --- logger/sql.go | 19 +++++++++++++++++++ logger/sql_test.go | 30 +++++++++++++++++++++++------- 2 files changed, 42 insertions(+), 7 deletions(-) diff --git a/logger/sql.go b/logger/sql.go index 8ce8d8b1..ad478795 100644 --- a/logger/sql.go +++ b/logger/sql.go @@ -34,6 +34,19 @@ var convertibleTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeO // RegEx matches only numeric values var numericPlaceholderRe = regexp.MustCompile(`\$\d+\$`) +func isNumeric(k reflect.Kind) bool { + switch k { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return true + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return true + case reflect.Float32, reflect.Float64: + return true + default: + return false + } +} + // ExplainSQL generate SQL string with given parameters, the generated SQL is expected to be used in logger, execute it might introduce a SQL injection vulnerability func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, avars ...interface{}) string { var ( @@ -110,6 +123,12 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a convertParams(v, idx) } else if rv.Kind() == reflect.Ptr && !rv.IsZero() { convertParams(reflect.Indirect(rv).Interface(), idx) + } else if isNumeric(rv.Kind()) { + if rv.CanInt() || rv.CanUint() { + vars[idx] = fmt.Sprintf("%d", rv.Interface()) + } else { + vars[idx] = fmt.Sprintf("%.6f", rv.Interface()) + } } else { for _, t := range convertibleTypes { if rv.Type().ConvertibleTo(t) { diff --git a/logger/sql_test.go b/logger/sql_test.go index 036ef3a4..9002a7eb 100644 --- a/logger/sql_test.go +++ b/logger/sql_test.go @@ -37,14 +37,18 @@ func format(v []byte, escaper string) string { func TestExplainSQL(t *testing.T) { type role string type password []byte + type intType int + type floatType float64 var ( - tt = now.MustParse("2020-02-23 11:10:10") - myrole = role("admin") - pwd = password("pass") - jsVal = []byte(`{"Name":"test","Val":"test"}`) - js = JSON(jsVal) - esVal = []byte(`{"Name":"test","Val":"test"}`) - es = ExampleStruct{Name: "test", Val: "test"} + tt = now.MustParse("2020-02-23 11:10:10") + myrole = role("admin") + pwd = password("pass") + jsVal = []byte(`{"Name":"test","Val":"test"}`) + js = JSON(jsVal) + esVal = []byte(`{"Name":"test","Val":"test"}`) + es = ExampleStruct{Name: "test", Val: "test"} + intVal intType = 1 + floatVal floatType = 1.23 ) results := []struct { @@ -107,6 +111,18 @@ func TestExplainSQL(t *testing.T) { Vars: []interface{}{"jinzhu", 1, float32(999.99), true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es}, Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)), }, + { + SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, int_val) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + NumericRegexp: nil, + Vars: []interface{}{"jinzhu?", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, intVal}, + Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, int_val) values ("jinzhu?", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", 1)`, + }, + { + SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, float_val) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + NumericRegexp: nil, + Vars: []interface{}{"jinzhu?", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, floatVal}, + Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, float_val) values ("jinzhu?", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", 1.230000)`, + }, } for idx, r := range results { From 26195e6d16cbb086423303d8178b78852ef12e2a Mon Sep 17 00:00:00 2001 From: snackmgmg <16898622+snackmgmg@users.noreply.github.com> Date: Tue, 26 Mar 2024 12:33:36 +0900 Subject: [PATCH 230/231] fix: remove `callback` from `callbacks` if `Remove()` called (#6916) * fix: remove callback from callbacks if Remove() called * reduce number of loops * remove unnecessary blank line --- callbacks.go | 19 ++++++++++++++++ tests/callbacks_test.go | 48 ++++++++++++++++++++++++++++++++++++++++- 2 files changed, 66 insertions(+), 1 deletion(-) diff --git a/callbacks.go b/callbacks.go index 195d1720..50b5b0e9 100644 --- a/callbacks.go +++ b/callbacks.go @@ -187,10 +187,18 @@ func (p *processor) Replace(name string, fn func(*DB)) error { func (p *processor) compile() (err error) { var callbacks []*callback + removedMap := map[string]bool{} for _, callback := range p.callbacks { if callback.match == nil || callback.match(p.db) { callbacks = append(callbacks, callback) } + if callback.remove { + removedMap[callback.name] = true + } + } + + if len(removedMap) > 0 { + callbacks = removeCallbacks(callbacks, removedMap) } p.callbacks = callbacks @@ -339,3 +347,14 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) { return } + +func removeCallbacks(cs []*callback, nameMap map[string]bool) []*callback { + callbacks := make([]*callback, 0, len(cs)) + for _, callback := range cs { + if nameMap[callback.name] { + continue + } + callbacks = append(callbacks, callback) + } + return callbacks +} diff --git a/tests/callbacks_test.go b/tests/callbacks_test.go index 4479da4c..f77209f1 100644 --- a/tests/callbacks_test.go +++ b/tests/callbacks_test.go @@ -91,7 +91,7 @@ func TestCallbacks(t *testing.T) { }, { callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3}, {h: c4}, {h: c5}, {h: c2, remove: true}}, - results: []string{"c1", "c5", "c3", "c4"}, + results: []string{"c1", "c3", "c4", "c5"}, }, { callbacks: []callback{{h: c1}, {name: "c", h: c2}, {h: c3}, {name: "c", h: c4, replace: true}}, @@ -206,3 +206,49 @@ func TestPluginCallbacks(t *testing.T) { t.Errorf("callbacks tests failed, got %v", msg) } } + +func TestCallbacksGet(t *testing.T) { + db, _ := gorm.Open(nil, nil) + createCallback := db.Callback().Create() + + createCallback.Before("*").Register("c1", c1) + if cb := createCallback.Get("c1"); reflect.DeepEqual(cb, c1) { + t.Errorf("callbacks tests failed, got: %p, want: %p", cb, c1) + } + + createCallback.Remove("c1") + if cb := createCallback.Get("c2"); cb != nil { + t.Errorf("callbacks test failed. got: %p, want: nil", cb) + } +} + +func TestCallbacksRemove(t *testing.T) { + db, _ := gorm.Open(nil, nil) + createCallback := db.Callback().Create() + + createCallback.Before("*").Register("c1", c1) + createCallback.After("*").Register("c2", c2) + createCallback.Before("c4").Register("c3", c3) + createCallback.After("c2").Register("c4", c4) + + // callbacks: []string{"c1", "c3", "c4", "c2"} + createCallback.Remove("c1") + if ok, msg := assertCallbacks(createCallback, []string{"c3", "c4", "c2"}); !ok { + t.Errorf("callbacks tests failed, got %v", msg) + } + + createCallback.Remove("c4") + if ok, msg := assertCallbacks(createCallback, []string{"c3", "c2"}); !ok { + t.Errorf("callbacks tests failed, got %v", msg) + } + + createCallback.Remove("c2") + if ok, msg := assertCallbacks(createCallback, []string{"c3"}); !ok { + t.Errorf("callbacks tests failed, got %v", msg) + } + + createCallback.Remove("c3") + if ok, msg := assertCallbacks(createCallback, []string{}); !ok { + t.Errorf("callbacks tests failed, got %v", msg) + } +} From 1b48aa072d1210c2ba315aeea18b57fddb634875 Mon Sep 17 00:00:00 2001 From: "hjwblog.com" Date: Thu, 28 Mar 2024 16:47:39 +0800 Subject: [PATCH 231/231] feat: prepare_stmt support ping (#6924) * feat: prepare_stmt support ping * feat: prepare_stmt tx support ping --- prepare_stmt.go | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/prepare_stmt.go b/prepare_stmt.go index c60b5db7..4d533885 100644 --- a/prepare_stmt.go +++ b/prepare_stmt.go @@ -182,6 +182,14 @@ func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string, arg return &sql.Row{} } +func (db *PreparedStmtDB) Ping() error { + conn, err := db.GetDBConn() + if err != nil { + return err + } + return conn.Ping() +} + type PreparedStmtTX struct { Tx PreparedStmtDB *PreparedStmtDB @@ -242,3 +250,11 @@ func (tx *PreparedStmtTX) QueryRowContext(ctx context.Context, query string, arg } return &sql.Row{} } + +func (tx *PreparedStmtTX) Ping() error { + conn, err := tx.GetDBConn() + if err != nil { + return err + } + return conn.Ping() +}