From 9564b82975844e9e944aefc936968225d9857b86 Mon Sep 17 00:00:00 2001 From: Wen Sun Date: Fri, 7 Oct 2022 14:46:20 +0900 Subject: [PATCH 01/19] Fix OnConstraint builder (#5738) --- clause/on_conflict.go | 34 ++++++++++++++-------------- tests/postgres_test.go | 51 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 68 insertions(+), 17 deletions(-) diff --git a/clause/on_conflict.go b/clause/on_conflict.go index 309c5fcd..032bf4a1 100644 --- a/clause/on_conflict.go +++ b/clause/on_conflict.go @@ -16,27 +16,27 @@ func (OnConflict) Name() string { // Build build onConflict clause func (onConflict OnConflict) Build(builder Builder) { - if len(onConflict.Columns) > 0 { - builder.WriteByte('(') - for idx, column := range onConflict.Columns { - if idx > 0 { - builder.WriteByte(',') - } - builder.WriteQuoted(column) - } - builder.WriteString(`) `) - } - - if len(onConflict.TargetWhere.Exprs) > 0 { - builder.WriteString(" WHERE ") - onConflict.TargetWhere.Build(builder) - builder.WriteByte(' ') - } - if onConflict.OnConstraint != "" { builder.WriteString("ON CONSTRAINT ") builder.WriteString(onConflict.OnConstraint) builder.WriteByte(' ') + } else { + if len(onConflict.Columns) > 0 { + builder.WriteByte('(') + for idx, column := range onConflict.Columns { + if idx > 0 { + builder.WriteByte(',') + } + builder.WriteQuoted(column) + } + builder.WriteString(`) `) + } + + if len(onConflict.TargetWhere.Exprs) > 0 { + builder.WriteString(" WHERE ") + onConflict.TargetWhere.Build(builder) + builder.WriteByte(' ') + } } if onConflict.DoNothing { diff --git a/tests/postgres_test.go b/tests/postgres_test.go index b5b672a9..f45b2618 100644 --- a/tests/postgres_test.go +++ b/tests/postgres_test.go @@ -7,6 +7,7 @@ import ( "github.com/google/uuid" "github.com/lib/pq" "gorm.io/gorm" + "gorm.io/gorm/clause" ) func TestPostgresReturningIDWhichHasStringType(t *testing.T) { @@ -148,3 +149,53 @@ func TestMany2ManyWithDefaultValueUUID(t *testing.T) { t.Errorf("Failed, got error: %v", err) } } + +func TestPostgresOnConstraint(t *testing.T) { + if DB.Dialector.Name() != "postgres" { + t.Skip() + } + + type Thing struct { + gorm.Model + SomeID string + OtherID string + Data string + } + + DB.Migrator().DropTable(&Thing{}) + DB.Migrator().CreateTable(&Thing{}) + if err := DB.Exec("ALTER TABLE things ADD CONSTRAINT some_id_other_id_unique UNIQUE (some_id, other_id)").Error; err != nil { + t.Error(err) + } + + thing := Thing{ + SomeID: "1234", + OtherID: "1234", + Data: "something", + } + + DB.Create(&thing) + + thing2 := Thing{ + SomeID: "1234", + OtherID: "1234", + Data: "something else", + } + + result := DB.Clauses(clause.OnConflict{ + OnConstraint: "some_id_other_id_unique", + UpdateAll: true, + }).Create(&thing2) + if result.Error != nil { + t.Errorf("creating second thing: %v", result.Error) + } + + var things []Thing + if err := DB.Find(&things).Error; err != nil { + t.Errorf("Failed, got error: %v", err) + } + + if len(things) > 1 { + t.Errorf("expected 1 thing got more") + } +} From 4b22a55a752d4284a72545a1611d651b364b3482 Mon Sep 17 00:00:00 2001 From: "jesse.tang" <1430482733@qq.com> Date: Fri, 7 Oct 2022 18:29:28 +0800 Subject: [PATCH 02/19] fix: primaryFields are overwritten (#5721) --- schema/relationship.go | 24 +++++++++++------------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/schema/relationship.go b/schema/relationship.go index bb8aeb64..9436f283 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -403,33 +403,30 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu case guessBelongs: primarySchema, foreignSchema = relation.FieldSchema, schema case guessEmbeddedBelongs: - if field.OwnerSchema != nil { - primarySchema, foreignSchema = relation.FieldSchema, field.OwnerSchema - } else { + if field.OwnerSchema == nil { reguessOrErr() return } + primarySchema, foreignSchema = relation.FieldSchema, field.OwnerSchema case guessHas: case guessEmbeddedHas: - if field.OwnerSchema != nil { - primarySchema, foreignSchema = field.OwnerSchema, relation.FieldSchema - } else { + if field.OwnerSchema == nil { reguessOrErr() return } + primarySchema, foreignSchema = field.OwnerSchema, relation.FieldSchema } if len(relation.foreignKeys) > 0 { for _, foreignKey := range relation.foreignKeys { - if f := foreignSchema.LookUpField(foreignKey); f != nil { - foreignFields = append(foreignFields, f) - } else { + f := foreignSchema.LookUpField(foreignKey) + if f == nil { reguessOrErr() return } + foreignFields = append(foreignFields, f) } } else { - var primaryFields []*Field var primarySchemaName = primarySchema.Name if primarySchemaName == "" { primarySchemaName = relation.FieldSchema.Name @@ -466,10 +463,11 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu } } - if len(foreignFields) == 0 { + switch { + case len(foreignFields) == 0: reguessOrErr() return - } else if len(relation.primaryKeys) > 0 { + case len(relation.primaryKeys) > 0: for idx, primaryKey := range relation.primaryKeys { if f := primarySchema.LookUpField(primaryKey); f != nil { if len(primaryFields) < idx+1 { @@ -483,7 +481,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu return } } - } else if len(primaryFields) == 0 { + case len(primaryFields) == 0: if len(foreignFields) == 1 && primarySchema.PrioritizedPrimaryField != nil { primaryFields = append(primaryFields, primarySchema.PrioritizedPrimaryField) } else if len(primarySchema.PrimaryFields) == len(foreignFields) { From e8f48b5c155b6fbf2e1fe6a554e2280f62af21a7 Mon Sep 17 00:00:00 2001 From: robhafner Date: Fri, 7 Oct 2022 08:14:14 -0400 Subject: [PATCH 03/19] fix: limit=0 results (#5735) (#5736) --- chainable_api.go | 2 +- clause/benchmarks_test.go | 3 ++- clause/limit.go | 10 +++++----- clause/limit_test.go | 20 ++++++++++++++------ finisher_api.go | 4 +++- 5 files changed, 25 insertions(+), 14 deletions(-) diff --git a/chainable_api.go b/chainable_api.go index 68b4d1aa..ab3a1a32 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -244,7 +244,7 @@ func (db *DB) Order(value interface{}) (tx *DB) { // Limit specify the number of records to be retrieved func (db *DB) Limit(limit int) (tx *DB) { tx = db.getInstance() - tx.Statement.AddClause(clause.Limit{Limit: limit}) + tx.Statement.AddClause(clause.Limit{Limit: &limit}) return } diff --git a/clause/benchmarks_test.go b/clause/benchmarks_test.go index e08677ac..34d5df41 100644 --- a/clause/benchmarks_test.go +++ b/clause/benchmarks_test.go @@ -29,6 +29,7 @@ func BenchmarkSelect(b *testing.B) { func BenchmarkComplexSelect(b *testing.B) { user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) + limit10 := 10 for i := 0; i < b.N; i++ { stmt := gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}} clauses := []clause.Interface{ @@ -43,7 +44,7 @@ func BenchmarkComplexSelect(b *testing.B) { clause.Or(clause.Gt{Column: "score", Value: 100}, clause.Like{Column: "name", Value: "%linus%"}), }}, clause.GroupBy{Columns: []clause.Column{{Name: "role"}}, Having: []clause.Expression{clause.Eq{"role", "admin"}}}, - clause.Limit{Limit: 10, Offset: 20}, + clause.Limit{Limit: &limit10, Offset: 20}, clause.OrderBy{Columns: []clause.OrderByColumn{{Column: clause.PrimaryColumn, Desc: true}}}, } diff --git a/clause/limit.go b/clause/limit.go index 184f6025..3ede7385 100644 --- a/clause/limit.go +++ b/clause/limit.go @@ -4,7 +4,7 @@ import "strconv" // Limit limit clause type Limit struct { - Limit int + Limit *int Offset int } @@ -15,12 +15,12 @@ func (limit Limit) Name() string { // Build build where clause func (limit Limit) Build(builder Builder) { - if limit.Limit > 0 { + if limit.Limit != nil && *limit.Limit >= 0 { builder.WriteString("LIMIT ") - builder.WriteString(strconv.Itoa(limit.Limit)) + builder.WriteString(strconv.Itoa(*limit.Limit)) } if limit.Offset > 0 { - if limit.Limit > 0 { + if limit.Limit != nil && *limit.Limit >= 0 { builder.WriteByte(' ') } builder.WriteString("OFFSET ") @@ -33,7 +33,7 @@ func (limit Limit) MergeClause(clause *Clause) { clause.Name = "" if v, ok := clause.Expression.(Limit); ok { - if limit.Limit == 0 && v.Limit != 0 { + if (limit.Limit == nil || *limit.Limit == 0) && (v.Limit != nil && *v.Limit != 0) { limit.Limit = v.Limit } diff --git a/clause/limit_test.go b/clause/limit_test.go index c26294aa..79065ab6 100644 --- a/clause/limit_test.go +++ b/clause/limit_test.go @@ -8,6 +8,10 @@ import ( ) func TestLimit(t *testing.T) { + limit0 := 0 + limit10 := 10 + limit50 := 50 + limitNeg10 := -10 results := []struct { Clauses []clause.Interface Result string @@ -15,11 +19,15 @@ func TestLimit(t *testing.T) { }{ { []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{ - Limit: 10, + Limit: &limit10, Offset: 20, }}, "SELECT * FROM `users` LIMIT 10 OFFSET 20", nil, }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit0}}, + "SELECT * FROM `users` LIMIT 0", nil, + }, { []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}}, "SELECT * FROM `users` OFFSET 20", nil, @@ -29,23 +37,23 @@ func TestLimit(t *testing.T) { "SELECT * FROM `users` OFFSET 30", nil, }, { - []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}, clause.Limit{Limit: 10}}, + []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}, clause.Limit{Limit: &limit10}}, "SELECT * FROM `users` LIMIT 10 OFFSET 20", nil, }, { - []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}}, + []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}}, "SELECT * FROM `users` LIMIT 10 OFFSET 30", nil, }, { - []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Offset: -10}}, + []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Offset: -10}}, "SELECT * FROM `users` LIMIT 10", nil, }, { - []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: -10}}, + []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: &limitNeg10}}, "SELECT * FROM `users` OFFSET 30", nil, }, { - []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: 50}}, + []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: &limit50}}, "SELECT * FROM `users` LIMIT 50 OFFSET 30", nil, }, } diff --git a/finisher_api.go b/finisher_api.go index 835a6984..5516c0a1 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -185,7 +185,9 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat var totalSize int if c, ok := tx.Statement.Clauses["LIMIT"]; ok { if limit, ok := c.Expression.(clause.Limit); ok { - totalSize = limit.Limit + if limit.Limit != nil { + totalSize = *limit.Limit + } if totalSize > 0 && batchSize > totalSize { batchSize = totalSize From 34fbe84580290c32ba006b714669bb356224cb07 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 7 Oct 2022 21:18:37 +0800 Subject: [PATCH 04/19] Add TableName with NamingStrategy support, close #5726 --- schema/schema.go | 7 +++++++ tests/go.mod | 12 +++++------- tests/table_test.go | 26 ++++++++++++++++++++++++++ utils/tests/dummy_dialecter.go | 10 +++++++++- 4 files changed, 47 insertions(+), 8 deletions(-) diff --git a/schema/schema.go b/schema/schema.go index 42ff5c45..9b3d30f6 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -71,6 +71,10 @@ type Tabler interface { TableName() string } +type TablerWithNamer interface { + TableName(Namer) string +} + // Parse get data type from dialector func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) { return ParseWithSpecialTableName(dest, cacheStore, namer, "") @@ -125,6 +129,9 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam if tabler, ok := modelValue.Interface().(Tabler); ok { tableName = tabler.TableName() } + if tabler, ok := modelValue.Interface().(TablerWithNamer); ok { + tableName = tabler.TableName(namer) + } if en, ok := namer.(embeddedNamer); ok { tableName = en.Table } diff --git a/tests/go.mod b/tests/go.mod index c1e1e0ce..d28c4bb9 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -3,17 +3,15 @@ module gorm.io/gorm/tests go 1.16 require ( - github.com/denisenkom/go-mssqldb v0.12.2 // indirect github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect github.com/google/uuid v1.3.0 github.com/jinzhu/now v1.1.5 github.com/lib/pq v1.10.7 - github.com/mattn/go-sqlite3 v1.14.15 // indirect - golang.org/x/crypto v0.0.0-20220926161630-eccd6366d1be // indirect - gorm.io/driver/mysql v1.3.6 - gorm.io/driver/postgres v1.3.10 - gorm.io/driver/sqlite v1.3.6 - gorm.io/driver/sqlserver v1.3.2 + golang.org/x/crypto v0.0.0-20221005025214-4161e89ecf1b // indirect + gorm.io/driver/mysql v1.4.0 + gorm.io/driver/postgres v1.4.1 + gorm.io/driver/sqlite v1.4.1 + gorm.io/driver/sqlserver v1.4.0 gorm.io/gorm v1.23.10 ) diff --git a/tests/table_test.go b/tests/table_test.go index 0289b7b8..f538c691 100644 --- a/tests/table_test.go +++ b/tests/table_test.go @@ -5,6 +5,8 @@ import ( "testing" "gorm.io/gorm" + "gorm.io/gorm/schema" + "gorm.io/gorm/utils/tests" . "gorm.io/gorm/utils/tests" ) @@ -145,3 +147,27 @@ func TestTableWithAllFields(t *testing.T) { AssertEqual(t, r.Statement.Vars, []interface{}{2, 4, 1, 3}) } + +type UserWithTableNamer struct { + gorm.Model + Name string +} + +func (UserWithTableNamer) TableName(namer schema.Namer) string { + return namer.TableName("user") +} + +func TestTableWithNamer(t *testing.T) { + var db, _ = gorm.Open(tests.DummyDialector{}, &gorm.Config{ + NamingStrategy: schema.NamingStrategy{ + TablePrefix: "t_", + }}) + + sql := db.ToSQL(func(tx *gorm.DB) *gorm.DB { + return tx.Model(&UserWithTableNamer{}).Find(&UserWithTableNamer{}) + }) + + if !regexp.MustCompile("SELECT \\* FROM `t_users`").MatchString(sql) { + t.Errorf("Table with namer, got %v", sql) + } +} diff --git a/utils/tests/dummy_dialecter.go b/utils/tests/dummy_dialecter.go index 2990c20f..c89b944a 100644 --- a/utils/tests/dummy_dialecter.go +++ b/utils/tests/dummy_dialecter.go @@ -2,6 +2,7 @@ package tests import ( "gorm.io/gorm" + "gorm.io/gorm/callbacks" "gorm.io/gorm/clause" "gorm.io/gorm/logger" "gorm.io/gorm/schema" @@ -13,7 +14,14 @@ func (DummyDialector) Name() string { return "dummy" } -func (DummyDialector) Initialize(*gorm.DB) error { +func (DummyDialector) Initialize(db *gorm.DB) error { + callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{ + CreateClauses: []string{"INSERT", "VALUES", "ON CONFLICT", "RETURNING"}, + UpdateClauses: []string{"UPDATE", "SET", "WHERE", "RETURNING"}, + DeleteClauses: []string{"DELETE", "FROM", "WHERE", "RETURNING"}, + LastInsertIDReversed: true, + }) + return nil } From 983e96f14253c071b8ab3fb96b4c9f103ad39e1c Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 8 Oct 2022 16:04:57 +0800 Subject: [PATCH 05/19] Add tests for alter column type --- tests/go.mod | 4 ++-- tests/migrate_test.go | 2 +- tests/postgres_test.go | 16 ++++++++++++++++ 3 files changed, 19 insertions(+), 3 deletions(-) diff --git a/tests/go.mod b/tests/go.mod index d28c4bb9..3919a838 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -9,10 +9,10 @@ require ( github.com/lib/pq v1.10.7 golang.org/x/crypto v0.0.0-20221005025214-4161e89ecf1b // indirect gorm.io/driver/mysql v1.4.0 - gorm.io/driver/postgres v1.4.1 + gorm.io/driver/postgres v1.4.3 gorm.io/driver/sqlite v1.4.1 gorm.io/driver/sqlserver v1.4.0 - gorm.io/gorm v1.23.10 + gorm.io/gorm v1.24.0 ) replace gorm.io/gorm => ../ diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 32e84e77..b918b4b5 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -400,7 +400,7 @@ func TestMigrateColumns(t *testing.T) { t.Fatalf("column code unique should be correct, name: %v, column: %#v", columnType.Name(), columnType) } if v, ok := columnType.DefaultValue(); !sqlserver && (!ok || v != "hello") { - t.Fatalf("column code default value should be correct, name: %v, column: %#v", columnType.Name(), columnType) + t.Fatalf("column code default value should be correct, name: %v, column: %#v, default value: %v", columnType.Name(), columnType, v) } if v, ok := columnType.Comment(); !sqlite && !sqlserver && (!ok || v != "my code2") { t.Fatalf("column code comment should be correct, name: %v, column: %#v", columnType.Name(), columnType) diff --git a/tests/postgres_test.go b/tests/postgres_test.go index f45b2618..794ab8f7 100644 --- a/tests/postgres_test.go +++ b/tests/postgres_test.go @@ -8,6 +8,7 @@ import ( "github.com/lib/pq" "gorm.io/gorm" "gorm.io/gorm/clause" + . "gorm.io/gorm/utils/tests" ) func TestPostgresReturningIDWhichHasStringType(t *testing.T) { @@ -199,3 +200,18 @@ func TestPostgresOnConstraint(t *testing.T) { t.Errorf("expected 1 thing got more") } } + +type CompanyNew struct { + ID int + Name int +} + +func TestAlterColumnDataType(t *testing.T) { + DB.AutoMigrate(Company{}) + + if err := DB.Table("companies").Migrator().AlterColumn(CompanyNew{}, "name"); err != nil { + t.Fatalf("failed to alter column from string to int, got error %v", err) + } + + DB.AutoMigrate(Company{}) +} From e93dc3426e8cb0a99091e2267ef2adf1cc86b4b5 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 8 Oct 2022 17:16:32 +0800 Subject: [PATCH 06/19] Test postgres autoincrement check --- tests/go.mod | 2 +- tests/postgres_test.go | 39 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/tests/go.mod b/tests/go.mod index 3919a838..0160b2a6 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -9,7 +9,7 @@ require ( github.com/lib/pq v1.10.7 golang.org/x/crypto v0.0.0-20221005025214-4161e89ecf1b // indirect gorm.io/driver/mysql v1.4.0 - gorm.io/driver/postgres v1.4.3 + gorm.io/driver/postgres v1.4.4 gorm.io/driver/sqlite v1.4.1 gorm.io/driver/sqlserver v1.4.0 gorm.io/gorm v1.24.0 diff --git a/tests/postgres_test.go b/tests/postgres_test.go index 794ab8f7..44cac6bf 100644 --- a/tests/postgres_test.go +++ b/tests/postgres_test.go @@ -112,6 +112,45 @@ func TestPostgres(t *testing.T) { if err := DB.First(&result, "id = ?", harumph.ID).Error; err != nil || harumph.Name != "jinzhu1" { t.Errorf("No error should happen, but got %v", err) } + + DB.Migrator().DropTable("log_usage") + + if err := DB.Exec(` +CREATE TABLE public.log_usage ( + log_id bigint NOT NULL +); + +ALTER TABLE public.log_usage ALTER COLUMN log_id ADD GENERATED BY DEFAULT AS IDENTITY ( + SEQUENCE NAME public.log_usage_log_id_seq + START WITH 1 + INCREMENT BY 1 + NO MINVALUE + NO MAXVALUE + CACHE 1 +); + `).Error; err != nil { + t.Fatalf("failed to create table, got error %v", err) + } + + columns, err := DB.Migrator().ColumnTypes("log_usage") + if err != nil { + t.Fatalf("failed to get columns, got error %v", err) + } + + hasLogID := false + for _, column := range columns { + if column.Name() == "log_id" { + hasLogID = true + autoIncrement, ok := column.AutoIncrement() + if !ok || !autoIncrement { + t.Fatalf("column log_id should be auto incrementment") + } + } + } + + if !hasLogID { + t.Fatalf("failed to found column log_id") + } } type Post struct { From 2c56954cb12dd33fc8f1875a735091d61daff702 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 8 Oct 2022 20:48:22 +0800 Subject: [PATCH 07/19] tests mariadb with returning support --- scan.go | 2 +- tests/connpool_test.go | 2 +- tests/go.mod | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/scan.go b/scan.go index 3a753dca..0a26ce4b 100644 --- a/scan.go +++ b/scan.go @@ -317,4 +317,4 @@ func Scan(rows Rows, db *DB, mode ScanMode) { if db.RowsAffected == 0 && db.Statement.RaiseErrorOnNotFound && db.Error == nil { db.AddError(ErrRecordNotFound) } -} \ No newline at end of file +} diff --git a/tests/connpool_test.go b/tests/connpool_test.go index fbae2294..42e029bc 100644 --- a/tests/connpool_test.go +++ b/tests/connpool_test.go @@ -116,7 +116,7 @@ func TestConnPoolWrapper(t *testing.T) { } }() - db, err := gorm.Open(mysql.New(mysql.Config{Conn: conn})) + db, err := gorm.Open(mysql.New(mysql.Config{Conn: conn, DisableWithReturning: true})) if err != nil { t.Fatalf("Should open db success, but got %v", err) } diff --git a/tests/go.mod b/tests/go.mod index 0160b2a6..bf59e8d2 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -8,7 +8,7 @@ require ( github.com/jinzhu/now v1.1.5 github.com/lib/pq v1.10.7 golang.org/x/crypto v0.0.0-20221005025214-4161e89ecf1b // indirect - gorm.io/driver/mysql v1.4.0 + gorm.io/driver/mysql v1.4.1 gorm.io/driver/postgres v1.4.4 gorm.io/driver/sqlite v1.4.1 gorm.io/driver/sqlserver v1.4.0 From 08aa2f9888dcd3c950943d09d0d7aaef1b1dcc33 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 14 Oct 2022 20:30:28 +0800 Subject: [PATCH 08/19] Update README --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 312a3a59..5bb1be37 100644 --- a/README.md +++ b/README.md @@ -30,7 +30,7 @@ The fantastic ORM library for Golang, aims to be developer friendly. ## Getting Started * GORM Guides [https://gorm.io](https://gorm.io) -* GORM Gen [gorm/gen](https://github.com/go-gorm/gen#gormgen) +* Gen Guides [https://gorm.io/gen/index.html](https://gorm.io/gen/index.html) ## Contributing From aa4312ee74db5a23d459d487b43a4a79d341c936 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 13 Oct 2022 15:57:10 +0800 Subject: [PATCH 09/19] Don't display any GORM related package path as source --- utils/utils.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils/utils.go b/utils/utils.go index 296917b9..90b4c8ea 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -16,7 +16,7 @@ var gormSourceDir string func init() { _, file, _, _ := runtime.Caller(0) // compatible solution to get gorm source directory with various operating systems - gormSourceDir = regexp.MustCompile(`utils.utils\.go`).ReplaceAllString(file, "") + gormSourceDir = regexp.MustCompile(`gorm.utils.utils\.go`).ReplaceAllString(file, "") } // FileWithLineNum return the file name and line number of the current file From 2a788fb20c3cbc73e96aa422b7477fe62d23964a Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 17 Oct 2022 17:01:00 +0800 Subject: [PATCH 10/19] Upgrade tests go.mod --- tests/go.mod | 10 +++++----- tests/sql_builder_test.go | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/go.mod b/tests/go.mod index bf59e8d2..2fef9d97 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -3,15 +3,15 @@ module gorm.io/gorm/tests go 1.16 require ( - github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect github.com/google/uuid v1.3.0 github.com/jinzhu/now v1.1.5 github.com/lib/pq v1.10.7 - golang.org/x/crypto v0.0.0-20221005025214-4161e89ecf1b // indirect - gorm.io/driver/mysql v1.4.1 + golang.org/x/crypto v0.0.0-20221012134737-56aed061732a // indirect + golang.org/x/text v0.3.8 // indirect + gorm.io/driver/mysql v1.4.3 gorm.io/driver/postgres v1.4.4 - gorm.io/driver/sqlite v1.4.1 - gorm.io/driver/sqlserver v1.4.0 + gorm.io/driver/sqlite v1.4.2 + gorm.io/driver/sqlserver v1.4.1 gorm.io/gorm v1.24.0 ) diff --git a/tests/sql_builder_test.go b/tests/sql_builder_test.go index a9b920dc..b10142fa 100644 --- a/tests/sql_builder_test.go +++ b/tests/sql_builder_test.go @@ -367,7 +367,7 @@ func TestToSQL(t *testing.T) { t.Skip("Skip SQL Server for this test, because it too difference with other dialects.") } - date, _ := time.Parse("2006-01-02", "2021-10-18") + date, _ := time.ParseInLocation("2006-01-02", "2021-10-18", time.Local) // find sql := DB.ToSQL(func(tx *gorm.DB) *gorm.DB { From 186e8a9e14578c63715444d217294065be072805 Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Tue, 18 Oct 2022 11:58:42 +0800 Subject: [PATCH 11/19] fix: association without pks (#5779) --- callbacks/associations.go | 10 +++++++-- tests/associations_test.go | 42 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+), 2 deletions(-) diff --git a/callbacks/associations.go b/callbacks/associations.go index 00e00fcc..9d7c1412 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -208,7 +208,10 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { cacheKey := utils.ToStringKey(relPrimaryValues...) if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] { - identityMap[cacheKey] = true + if cacheKey != "" { // has primary fields + identityMap[cacheKey] = true + } + if isPtr { elems = reflect.Append(elems, elem) } else { @@ -294,7 +297,10 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { cacheKey := utils.ToStringKey(relPrimaryValues...) if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] { - identityMap[cacheKey] = true + if cacheKey != "" { // has primary fields + identityMap[cacheKey] = true + } + distinctElems = reflect.Append(distinctElems, elem) } diff --git a/tests/associations_test.go b/tests/associations_test.go index 42b32afc..4c9076da 100644 --- a/tests/associations_test.go +++ b/tests/associations_test.go @@ -348,3 +348,45 @@ func TestAssociationEmptyQueryClause(t *testing.T) { AssertEqual(t, len(orgs), 0) } } + +type AssociationEmptyUser struct { + ID uint + Name string + Pets []AssociationEmptyPet +} + +type AssociationEmptyPet struct { + AssociationEmptyUserID *uint `gorm:"uniqueIndex:uniq_user_id_name"` + Name string `gorm:"uniqueIndex:uniq_user_id_name;size:256"` +} + +func TestAssociationEmptyPrimaryKey(t *testing.T) { + if DB.Dialector.Name() != "mysql" { + t.Skip() + } + DB.Migrator().DropTable(&AssociationEmptyUser{}, &AssociationEmptyPet{}) + DB.AutoMigrate(&AssociationEmptyUser{}, &AssociationEmptyPet{}) + + id := uint(100) + user := AssociationEmptyUser{ + ID: id, + Name: "jinzhu", + Pets: []AssociationEmptyPet{ + {AssociationEmptyUserID: &id, Name: "bar"}, + {AssociationEmptyUserID: &id, Name: "foo"}, + }, + } + + err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Create(&user).Error + if err != nil { + t.Fatalf("Failed to create, got error: %v", err) + } + + var result AssociationEmptyUser + err = DB.Preload("Pets").First(&result, &id).Error + if err != nil { + t.Fatalf("Failed to find, got error: %v", err) + } + + AssertEqual(t, result, user) +} From ab5f80a8d81c1955e92224b24dfc9bc8c7d387a0 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 18 Oct 2022 15:44:47 +0800 Subject: [PATCH 12/19] Save as NULL for nil object serialized into json --- schema/serializer.go | 3 +++ tests/go.mod | 4 ++-- tests/serializer_test.go | 3 ++- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/schema/serializer.go b/schema/serializer.go index 00a4f85f..fef39d9b 100644 --- a/schema/serializer.go +++ b/schema/serializer.go @@ -100,6 +100,9 @@ func (JSONSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, // Value implements serializer interface func (JSONSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) { result, err := json.Marshal(fieldValue) + if string(result) == "null" { + return nil, err + } return string(result), err } diff --git a/tests/go.mod b/tests/go.mod index 2fef9d97..9c87ca34 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -7,10 +7,10 @@ require ( github.com/jinzhu/now v1.1.5 github.com/lib/pq v1.10.7 golang.org/x/crypto v0.0.0-20221012134737-56aed061732a // indirect - golang.org/x/text v0.3.8 // indirect + golang.org/x/text v0.4.0 // indirect gorm.io/driver/mysql v1.4.3 gorm.io/driver/postgres v1.4.4 - gorm.io/driver/sqlite v1.4.2 + gorm.io/driver/sqlite v1.4.3 gorm.io/driver/sqlserver v1.4.1 gorm.io/gorm v1.24.0 ) diff --git a/tests/serializer_test.go b/tests/serializer_test.go index 946536bf..17bfefe2 100644 --- a/tests/serializer_test.go +++ b/tests/serializer_test.go @@ -18,6 +18,7 @@ type SerializerStruct struct { gorm.Model Name []byte `gorm:"json"` Roles Roles `gorm:"serializer:json"` + Roles2 *Roles `gorm:"serializer:json"` Contracts map[string]interface{} `gorm:"serializer:json"` JobInfo Job `gorm:"type:bytes;serializer:gob"` CreatedTime int64 `gorm:"serializer:unixtime;type:time"` // store time in db, use int as field type @@ -108,7 +109,7 @@ func TestSerializer(t *testing.T) { } var result SerializerStruct - if err := DB.First(&result, data.ID).Error; err != nil { + if err := DB.Where("roles2 IS NULL").First(&result, data.ID).Error; err != nil { t.Fatalf("failed to query data, got error %v", err) } From a0f4d3f7d207b2103b5f91e9758b1ac6a94056ba Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 18 Oct 2022 16:25:39 +0800 Subject: [PATCH 13/19] Save as empty string for not nullable nil field serialized into json --- schema/serializer.go | 3 +++ tests/serializer_test.go | 3 ++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/schema/serializer.go b/schema/serializer.go index fef39d9b..9a6aa4fc 100644 --- a/schema/serializer.go +++ b/schema/serializer.go @@ -101,6 +101,9 @@ func (JSONSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, func (JSONSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) { result, err := json.Marshal(fieldValue) if string(result) == "null" { + if field.TagSettings["NOT NULL"] != "" { + return "", nil + } return nil, err } return string(result), err diff --git a/tests/serializer_test.go b/tests/serializer_test.go index 17bfefe2..a040a4db 100644 --- a/tests/serializer_test.go +++ b/tests/serializer_test.go @@ -19,6 +19,7 @@ type SerializerStruct struct { Name []byte `gorm:"json"` Roles Roles `gorm:"serializer:json"` Roles2 *Roles `gorm:"serializer:json"` + Roles3 *Roles `gorm:"serializer:json;not null"` Contracts map[string]interface{} `gorm:"serializer:json"` JobInfo Job `gorm:"type:bytes;serializer:gob"` CreatedTime int64 `gorm:"serializer:unixtime;type:time"` // store time in db, use int as field type @@ -109,7 +110,7 @@ func TestSerializer(t *testing.T) { } var result SerializerStruct - if err := DB.Where("roles2 IS NULL").First(&result, data.ID).Error; err != nil { + if err := DB.Where("roles2 IS NULL AND roles3 = ?", "").First(&result, data.ID).Error; err != nil { t.Fatalf("failed to query data, got error %v", err) } From 62593cfad03ebf1e6cae30bac010655b4a28ff67 Mon Sep 17 00:00:00 2001 From: viatoriche / Maxim Panfilov Date: Tue, 18 Oct 2022 17:28:06 +0800 Subject: [PATCH 14/19] add test: TestAutoMigrateInt8PG: shouldn't execute ALTER COLUMN TYPE smallint, close #5762 --- migrator/migrator.go | 55 +++++++++++++++++++++---------------------- tests/migrate_test.go | 40 +++++++++++++++++++++++++++++++ tests/tracer_test.go | 34 ++++++++++++++++++++++++++ 3 files changed, 101 insertions(+), 28 deletions(-) create mode 100644 tests/tracer_test.go diff --git a/migrator/migrator.go b/migrator/migrator.go index 29c0c00c..9f8e3db8 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -406,17 +406,14 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy fullDataType := strings.TrimSpace(strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL)) realDataType := strings.ToLower(columnType.DatabaseTypeName()) - alterColumn := false + var ( + alterColumn, isSameType bool + ) if !field.PrimaryKey { // check type - var isSameType bool - if strings.HasPrefix(fullDataType, realDataType) { - isSameType = true - } - - // check type aliases - if !isSameType { + if !strings.HasPrefix(fullDataType, realDataType) { + // check type aliases aliases := m.DB.Migrator().GetTypeAliases(realDataType) for _, alias := range aliases { if strings.HasPrefix(fullDataType, alias) { @@ -424,32 +421,34 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy break } } - } - if !isSameType { - alterColumn = true - } - } - - // check size - if length, ok := columnType.Length(); length != int64(field.Size) { - if length > 0 && field.Size > 0 { - alterColumn = true - } else { - // has size in data type and not equal - // Since the following code is frequently called in the for loop, reg optimization is needed here - matches2 := regFullDataType.FindAllStringSubmatch(fullDataType, -1) - if !field.PrimaryKey && - (len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length) && ok) { + if !isSameType { alterColumn = true } } } - // check precision - if precision, _, ok := columnType.DecimalSize(); ok && int64(field.Precision) != precision { - if regexp.MustCompile(fmt.Sprintf("[^0-9]%d[^0-9]", field.Precision)).MatchString(m.DataTypeOf(field)) { - alterColumn = true + if !isSameType { + // check size + if length, ok := columnType.Length(); length != int64(field.Size) { + if length > 0 && field.Size > 0 { + alterColumn = true + } else { + // has size in data type and not equal + // Since the following code is frequently called in the for loop, reg optimization is needed here + matches2 := regFullDataType.FindAllStringSubmatch(fullDataType, -1) + if !field.PrimaryKey && + (len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length) && ok) { + alterColumn = true + } + } + } + + // check precision + if precision, _, ok := columnType.DecimalSize(); ok && int64(field.Precision) != precision { + if regexp.MustCompile(fmt.Sprintf("[^0-9]%d[^0-9]", field.Precision)).MatchString(m.DataTypeOf(field)) { + alterColumn = true + } } } diff --git a/tests/migrate_test.go b/tests/migrate_test.go index b918b4b5..8718aa57 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -1,6 +1,7 @@ package tests_test import ( + "context" "fmt" "math/rand" "reflect" @@ -9,6 +10,7 @@ import ( "time" "gorm.io/driver/postgres" + "gorm.io/gorm" "gorm.io/gorm/schema" . "gorm.io/gorm/utils/tests" @@ -72,6 +74,44 @@ func TestMigrate(t *testing.T) { t.Fatalf("Failed to find index for many2many for %v %v", indexes[0], indexes[1]) } } + +} + +func TestAutoMigrateInt8PG(t *testing.T) { + if DB.Dialector.Name() != "postgres" { + return + } + + type Smallint int8 + + type MigrateInt struct { + Int8 Smallint + } + + tracer := Tracer{ + Logger: DB.Config.Logger, + Test: func(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) { + sql, _ := fc() + if strings.HasPrefix(sql, "ALTER TABLE \"migrate_ints\" ALTER COLUMN \"int8\" TYPE smallint") { + t.Fatalf("shouldn't execute ALTER COLUMN TYPE if such type is already existed in DB schema: sql: %s", sql) + } + }, + } + + DB.Migrator().DropTable(&MigrateInt{}) + + // The first AutoMigrate to make table with field with correct type + if err := DB.AutoMigrate(&MigrateInt{}); err != nil { + t.Fatalf("Failed to auto migrate: error: %v", err) + } + + // make new session to set custom logger tracer + session := DB.Session(&gorm.Session{Logger: tracer}) + + // The second AutoMigrate to catch an error + if err := session.AutoMigrate(&MigrateInt{}); err != nil { + t.Fatalf("Failed to auto migrate: error: %v", err) + } } func TestAutoMigrateSelfReferential(t *testing.T) { diff --git a/tests/tracer_test.go b/tests/tracer_test.go new file mode 100644 index 00000000..3e9a4052 --- /dev/null +++ b/tests/tracer_test.go @@ -0,0 +1,34 @@ +package tests_test + +import ( + "context" + "time" + + "gorm.io/gorm/logger" +) + +type Tracer struct { + Logger logger.Interface + Test func(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) +} + +func (S Tracer) LogMode(level logger.LogLevel) logger.Interface { + return S.Logger.LogMode(level) +} + +func (S Tracer) Info(ctx context.Context, s string, i ...interface{}) { + S.Logger.Info(ctx, s, i...) +} + +func (S Tracer) Warn(ctx context.Context, s string, i ...interface{}) { + S.Logger.Warn(ctx, s, i...) +} + +func (S Tracer) Error(ctx context.Context, s string, i ...interface{}) { + S.Logger.Error(ctx, s, i...) +} + +func (S Tracer) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) { + S.Logger.Trace(ctx, begin, fc, err) + S.Test(ctx, begin, fc, err) +} From 3f20a543fad5f57016ef7a6c342536b0fcce6016 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 18 Oct 2022 18:01:55 +0800 Subject: [PATCH 15/19] Support use clause.Interface as query params --- statement.go | 4 ++++ tests/sql_builder_test.go | 8 ++++++++ 2 files changed, 12 insertions(+) diff --git a/statement.go b/statement.go index cc26fe37..d05d299e 100644 --- a/statement.go +++ b/statement.go @@ -179,6 +179,10 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) { } else { stmt.AddVar(writer, v.GormValue(stmt.Context, stmt.DB)) } + case clause.Interface: + c := clause.Clause{Name: v.Name()} + v.MergeClause(&c) + c.Build(stmt) case clause.Expression: v.Build(stmt) case driver.Valuer: diff --git a/tests/sql_builder_test.go b/tests/sql_builder_test.go index b10142fa..0fbd6118 100644 --- a/tests/sql_builder_test.go +++ b/tests/sql_builder_test.go @@ -445,6 +445,14 @@ func TestToSQL(t *testing.T) { if DB.Statement.DryRun || DB.DryRun { t.Fatal("Failed expect DB.DryRun and DB.Statement.ToSQL to be false") } + + // UpdateColumns + sql = DB.ToSQL(func(tx *gorm.DB) *gorm.DB { + return tx.Raw("SELECT * FROM users ?", clause.OrderBy{ + Columns: []clause.OrderByColumn{{Column: clause.Column{Name: "id", Raw: true}, Desc: true}}, + }) + }) + assertEqualSQL(t, `SELECT * FROM users ORDER BY id DESC`, sql) } // assertEqualSQL for assert that the sql is equal, this method will ignore quote, and dialect specials. From 5dd2bb482755f5e8eb5ecaff39e675fb62f19a20 Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Wed, 19 Oct 2022 14:46:59 +0800 Subject: [PATCH 16/19] feat(PreparedStmtDB): support reset (#5782) * feat(PreparedStmtDB): support reset * fix: close all stmt * test: fix test * fix: delete one by one --- prepare_stmt.go | 12 ++++++++++++ tests/prepared_stmt_test.go | 28 +++++++++++++++++++++++++++- 2 files changed, 39 insertions(+), 1 deletion(-) diff --git a/prepare_stmt.go b/prepare_stmt.go index 3934bb97..7591e533 100644 --- a/prepare_stmt.go +++ b/prepare_stmt.go @@ -44,6 +44,18 @@ func (db *PreparedStmtDB) Close() { } } +func (db *PreparedStmtDB) Reset() { + db.Mux.Lock() + defer db.Mux.Unlock() + for query, stmt := range db.Stmts { + delete(db.Stmts, query) + go stmt.Close() + } + + db.PreparedSQL = make([]string, 0, 100) + db.Stmts = map[string](*Stmt){} +} + func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) { db.Mux.RLock() if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) { diff --git a/tests/prepared_stmt_test.go b/tests/prepared_stmt_test.go index c7f251f2..64baa01b 100644 --- a/tests/prepared_stmt_test.go +++ b/tests/prepared_stmt_test.go @@ -2,8 +2,8 @@ package tests_test import ( "context" - "sync" "errors" + "sync" "testing" "time" @@ -168,3 +168,29 @@ func TestPreparedStmtInTransaction(t *testing.T) { t.Errorf("Failed, got error: %v", err) } } + +func TestPreparedStmtReset(t *testing.T) { + tx := DB.Session(&gorm.Session{PrepareStmt: true}) + + user := *GetUser("prepared_stmt_reset", Config{}) + tx = tx.Create(&user) + + pdb, ok := tx.ConnPool.(*gorm.PreparedStmtDB) + if !ok { + t.Fatalf("should assign PreparedStatement Manager back to database when using PrepareStmt mode") + } + + pdb.Mux.Lock() + if len(pdb.Stmts) == 0 { + pdb.Mux.Unlock() + t.Fatalf("prepared stmt can not be empty") + } + pdb.Mux.Unlock() + + pdb.Reset() + pdb.Mux.Lock() + defer pdb.Mux.Unlock() + if len(pdb.Stmts) != 0 { + t.Fatalf("prepared stmt should be empty") + } +} From 9d82aa56734999bb28e0c4d60fba69ae7cde66d5 Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Thu, 20 Oct 2022 14:10:47 +0800 Subject: [PATCH 17/19] test: invalid cache plan with prepare stmt (#5778) * test: invalid cache plan with prepare stmt * test: more test cases * test: drop and rename column --- tests/migrate_test.go | 99 ++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 98 insertions(+), 1 deletion(-) diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 8718aa57..96b1d0e4 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "math/rand" + "os" "reflect" "strings" "testing" @@ -12,6 +13,7 @@ import ( "gorm.io/driver/postgres" "gorm.io/gorm" + "gorm.io/gorm/logger" "gorm.io/gorm/schema" . "gorm.io/gorm/utils/tests" ) @@ -890,7 +892,7 @@ func findColumnType(dest interface{}, columnName string) ( return } -func TestInvalidCachedPlan(t *testing.T) { +func TestInvalidCachedPlanSimpleProtocol(t *testing.T) { if DB.Dialector.Name() != "postgres" { return } @@ -925,6 +927,101 @@ func TestInvalidCachedPlan(t *testing.T) { } } +func TestInvalidCachedPlanPrepareStmt(t *testing.T) { + if DB.Dialector.Name() != "postgres" { + return + } + + db, err := gorm.Open(postgres.Open(postgresDSN), &gorm.Config{PrepareStmt: true}) + if err != nil { + t.Errorf("Open err:%v", err) + } + if debug := os.Getenv("DEBUG"); debug == "true" { + db.Logger = db.Logger.LogMode(logger.Info) + } else if debug == "false" { + db.Logger = db.Logger.LogMode(logger.Silent) + } + + type Object1 struct { + ID uint + } + type Object2 struct { + ID uint + Field1 int `gorm:"type:int8"` + } + type Object3 struct { + ID uint + Field1 int `gorm:"type:int4"` + } + type Object4 struct { + ID uint + Field2 int + } + db.Migrator().DropTable("objects") + + err = db.Table("objects").AutoMigrate(&Object1{}) + if err != nil { + t.Errorf("AutoMigrate err:%v", err) + } + err = db.Table("objects").Create(&Object1{}).Error + if err != nil { + t.Errorf("create err:%v", err) + } + + // AddColumn + err = db.Table("objects").AutoMigrate(&Object2{}) + if err != nil { + t.Errorf("AutoMigrate err:%v", err) + } + + err = db.Table("objects").Take(&Object2{}).Error + if err != nil { + t.Errorf("take err:%v", err) + } + + // AlterColumn + err = db.Table("objects").AutoMigrate(&Object3{}) + if err != nil { + t.Errorf("AutoMigrate err:%v", err) + } + + err = db.Table("objects").Take(&Object3{}).Error + if err != nil { + t.Errorf("take err:%v", err) + } + + // AddColumn + err = db.Table("objects").AutoMigrate(&Object4{}) + if err != nil { + t.Errorf("AutoMigrate err:%v", err) + } + + err = db.Table("objects").Take(&Object4{}).Error + if err != nil { + t.Errorf("take err:%v", err) + } + + db.Table("objects").Migrator().RenameColumn(&Object4{}, "field2", "field3") + if err != nil { + t.Errorf("RenameColumn err:%v", err) + } + + err = db.Table("objects").Take(&Object4{}).Error + if err != nil { + t.Errorf("take err:%v", err) + } + + db.Table("objects").Migrator().DropColumn(&Object4{}, "field3") + if err != nil { + t.Errorf("RenameColumn err:%v", err) + } + + err = db.Table("objects").Take(&Object4{}).Error + if err != nil { + t.Errorf("take err:%v", err) + } +} + func TestDifferentTypeWithoutDeclaredLength(t *testing.T) { type DiffType struct { ID uint From b2f42528a48aeed9612d43e19cdf4fe8e87a27a3 Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Wed, 2 Nov 2022 10:28:00 +0800 Subject: [PATCH 18/19] fix(Joins): args with select and omit (#5790) * fix(Joins): args with select and omit * chore: gofumpt style --- callbacks/query.go | 18 ++++++++++++----- chainable_api.go | 49 ++++++++++++++++++++++++++------------------- statement.go | 13 +++++++----- tests/joins_test.go | 43 +++++++++++++++++++++++++++++++++++++++ 4 files changed, 92 insertions(+), 31 deletions(-) diff --git a/callbacks/query.go b/callbacks/query.go index 26ee8c34..67936766 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -117,12 +117,20 @@ func BuildQuerySQL(db *gorm.DB) { } else if relation, ok := db.Statement.Schema.Relationships.Relations[join.Name]; ok { tableAliasName := relation.Name + columnStmt := gorm.Statement{ + Table: tableAliasName, DB: db, Schema: relation.FieldSchema, + Selects: join.Selects, Omits: join.Omits, + } + + selectColumns, restricted := columnStmt.SelectAndOmitColumns(false, false) for _, s := range relation.FieldSchema.DBNames { - clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ - Table: tableAliasName, - Name: s, - Alias: tableAliasName + "__" + s, - }) + if v, ok := selectColumns[s]; (ok && v) || (!ok && !restricted) { + clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ + Table: tableAliasName, + Name: s, + Alias: tableAliasName + "__" + s, + }) + } } exprs := make([]clause.Expression, len(relation.References)) diff --git a/chainable_api.go b/chainable_api.go index ab3a1a32..6d48d56b 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -10,10 +10,11 @@ import ( ) // Model specify the model you would like to run db operations -// // update all users's name to `hello` -// db.Model(&User{}).Update("name", "hello") -// // if user's primary key is non-blank, will use it as condition, then will only update the user's name to `hello` -// db.Model(&user).Update("name", "hello") +// +// // update all users's name to `hello` +// db.Model(&User{}).Update("name", "hello") +// // if user's primary key is non-blank, will use it as condition, then will only update the user's name to `hello` +// db.Model(&user).Update("name", "hello") func (db *DB) Model(value interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.Model = value @@ -179,18 +180,21 @@ func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) { } // Joins specify Joins conditions -// db.Joins("Account").Find(&user) -// db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user) -// db.Joins("Account", DB.Select("id").Where("user_id = users.id AND name = ?", "someName").Model(&Account{})) +// +// db.Joins("Account").Find(&user) +// db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user) +// db.Joins("Account", DB.Select("id").Where("user_id = users.id AND name = ?", "someName").Model(&Account{})) func (db *DB) Joins(query string, args ...interface{}) (tx *DB) { tx = db.getInstance() if len(args) == 1 { if db, ok := args[0].(*DB); ok { + j := join{Name: query, Conds: args, Selects: db.Statement.Selects, Omits: db.Statement.Omits} if where, ok := db.Statement.Clauses["WHERE"].Expression.(clause.Where); ok { - tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args, On: &where}) - return + j.On = &where } + tx.Statement.Joins = append(tx.Statement.Joins, j) + return } } @@ -219,8 +223,9 @@ func (db *DB) Having(query interface{}, args ...interface{}) (tx *DB) { } // Order specify order when retrieve records from database -// db.Order("name DESC") -// db.Order(clause.OrderByColumn{Column: clause.Column{Name: "name"}, Desc: true}) +// +// db.Order("name DESC") +// db.Order(clause.OrderByColumn{Column: clause.Column{Name: "name"}, Desc: true}) func (db *DB) Order(value interface{}) (tx *DB) { tx = db.getInstance() @@ -256,17 +261,18 @@ func (db *DB) Offset(offset int) (tx *DB) { } // Scopes pass current database connection to arguments `func(DB) DB`, which could be used to add conditions dynamically -// func AmountGreaterThan1000(db *gorm.DB) *gorm.DB { -// return db.Where("amount > ?", 1000) -// } // -// func OrderStatus(status []string) func (db *gorm.DB) *gorm.DB { -// return func (db *gorm.DB) *gorm.DB { -// return db.Scopes(AmountGreaterThan1000).Where("status in (?)", status) -// } -// } +// func AmountGreaterThan1000(db *gorm.DB) *gorm.DB { +// return db.Where("amount > ?", 1000) +// } // -// db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders) +// func OrderStatus(status []string) func (db *gorm.DB) *gorm.DB { +// return func (db *gorm.DB) *gorm.DB { +// return db.Scopes(AmountGreaterThan1000).Where("status in (?)", status) +// } +// } +// +// db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders) func (db *DB) Scopes(funcs ...func(*DB) *DB) (tx *DB) { tx = db.getInstance() tx.Statement.scopes = append(tx.Statement.scopes, funcs...) @@ -274,7 +280,8 @@ func (db *DB) Scopes(funcs ...func(*DB) *DB) (tx *DB) { } // Preload preload associations with given conditions -// db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users) +// +// db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users) func (db *DB) Preload(query string, args ...interface{}) (tx *DB) { tx = db.getInstance() if tx.Statement.Preloads == nil { diff --git a/statement.go b/statement.go index d05d299e..d4d20cbf 100644 --- a/statement.go +++ b/statement.go @@ -49,9 +49,11 @@ type Statement struct { } type join struct { - Name string - Conds []interface{} - On *clause.Where + Name string + Conds []interface{} + On *clause.Where + Selects []string + Omits []string } // StatementModifier statement modifier interface @@ -544,8 +546,9 @@ func (stmt *Statement) clone() *Statement { } // SetColumn set column's value -// stmt.SetColumn("Name", "jinzhu") // Hooks Method -// stmt.SetColumn("Name", "jinzhu", true) // Callbacks Method +// +// stmt.SetColumn("Name", "jinzhu") // Hooks Method +// stmt.SetColumn("Name", "jinzhu", true) // Callbacks Method func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks ...bool) { if v, ok := stmt.Dest.(map[string]interface{}); ok { v[name] = value diff --git a/tests/joins_test.go b/tests/joins_test.go index 7519db82..091fb986 100644 --- a/tests/joins_test.go +++ b/tests/joins_test.go @@ -260,3 +260,46 @@ func TestJoinWithSameColumnName(t *testing.T) { t.Fatalf("wrong pet name") } } + +func TestJoinArgsWithDB(t *testing.T) { + user := *GetUser("joins-args-db", Config{Pets: 2}) + DB.Save(&user) + + // test where + var user1 User + onQuery := DB.Where(&Pet{Name: "joins-args-db_pet_2"}) + if err := DB.Joins("NamedPet", onQuery).Where("users.name = ?", user.Name).First(&user1).Error; err != nil { + t.Fatalf("Failed to load with joins on, got error: %v", err) + } + + AssertEqual(t, user1.NamedPet.Name, "joins-args-db_pet_2") + + // test where and omit + onQuery2 := DB.Where(&Pet{Name: "joins-args-db_pet_2"}).Omit("Name") + var user2 User + if err := DB.Joins("NamedPet", onQuery2).Where("users.name = ?", user.Name).First(&user2).Error; err != nil { + t.Fatalf("Failed to load with joins on, got error: %v", err) + } + AssertEqual(t, user2.NamedPet.ID, user1.NamedPet.ID) + AssertEqual(t, user2.NamedPet.Name, "") + + // test where and select + onQuery3 := DB.Where(&Pet{Name: "joins-args-db_pet_2"}).Select("Name") + var user3 User + if err := DB.Joins("NamedPet", onQuery3).Where("users.name = ?", user.Name).First(&user3).Error; err != nil { + t.Fatalf("Failed to load with joins on, got error: %v", err) + } + AssertEqual(t, user3.NamedPet.ID, 0) + AssertEqual(t, user3.NamedPet.Name, "joins-args-db_pet_2") + + // test select + onQuery4 := DB.Select("ID") + var user4 User + if err := DB.Joins("NamedPet", onQuery4).Where("users.name = ?", user.Name).First(&user4).Error; err != nil { + t.Fatalf("Failed to load with joins on, got error: %v", err) + } + if user4.NamedPet.ID == 0 { + t.Fatal("Pet ID can not be empty") + } + AssertEqual(t, user4.NamedPet.Name, "") +} From f82e9cfdbed051e8e397e2fd1f7ab62c17ff8a4f Mon Sep 17 00:00:00 2001 From: jessetang <1430482733@qq.com> Date: Thu, 3 Nov 2022 21:03:13 +0800 Subject: [PATCH 19/19] test(clause/joins): add join unit test (#5832) --- clause/joins.go | 2 +- clause/joins_test.go | 101 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 102 insertions(+), 1 deletion(-) create mode 100644 clause/joins_test.go diff --git a/clause/joins.go b/clause/joins.go index f3e373f2..879892be 100644 --- a/clause/joins.go +++ b/clause/joins.go @@ -9,7 +9,7 @@ const ( RightJoin JoinType = "RIGHT" ) -// Join join clause for from +// Join clause for from type Join struct { Type JoinType Table Table diff --git a/clause/joins_test.go b/clause/joins_test.go new file mode 100644 index 00000000..f1f20ec3 --- /dev/null +++ b/clause/joins_test.go @@ -0,0 +1,101 @@ +package clause_test + +import ( + "sync" + "testing" + + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" + "gorm.io/gorm/utils/tests" +) + +func TestJoin(t *testing.T) { + results := []struct { + name string + join clause.Join + sql string + }{ + { + name: "LEFT JOIN", + join: clause.Join{ + Type: clause.LeftJoin, + Table: clause.Table{Name: "user"}, + ON: clause.Where{ + Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}}, + }, + }, + sql: "LEFT JOIN `user` ON `user_info`.`user_id` = `users`.`id`", + }, + { + name: "RIGHT JOIN", + join: clause.Join{ + Type: clause.RightJoin, + Table: clause.Table{Name: "user"}, + ON: clause.Where{ + Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}}, + }, + }, + sql: "RIGHT JOIN `user` ON `user_info`.`user_id` = `users`.`id`", + }, + { + name: "INNER JOIN", + join: clause.Join{ + Type: clause.InnerJoin, + Table: clause.Table{Name: "user"}, + ON: clause.Where{ + Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}}, + }, + }, + sql: "INNER JOIN `user` ON `user_info`.`user_id` = `users`.`id`", + }, + { + name: "CROSS JOIN", + join: clause.Join{ + Type: clause.CrossJoin, + Table: clause.Table{Name: "user"}, + ON: clause.Where{ + Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}}, + }, + }, + sql: "CROSS JOIN `user` ON `user_info`.`user_id` = `users`.`id`", + }, + { + name: "USING", + join: clause.Join{ + Type: clause.InnerJoin, + Table: clause.Table{Name: "user"}, + Using: []string{"id"}, + }, + sql: "INNER JOIN `user` USING (`id`)", + }, + { + name: "Expression", + join: clause.Join{ + // Invalid + Type: clause.LeftJoin, + Table: clause.Table{Name: "user"}, + ON: clause.Where{ + Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}}, + }, + // Valid + Expression: clause.Join{ + Type: clause.InnerJoin, + Table: clause.Table{Name: "user"}, + Using: []string{"id"}, + }, + }, + sql: "INNER JOIN `user` USING (`id`)", + }, + } + for _, result := range results { + t.Run(result.name, func(t *testing.T) { + user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) + stmt := &gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}} + result.join.Build(stmt) + if result.sql != stmt.SQL.String() { + t.Errorf("want: %s, got: %s", result.sql, stmt.SQL.String()) + } + }) + } +}