From f22327938485f1673eab443949ae92367293c566 Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Wed, 10 Aug 2022 11:03:42 +0800 Subject: [PATCH 01/56] chore: fix gorm tag (#5577) --- utils/tests/models.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/utils/tests/models.go b/utils/tests/models.go index 22e8e659..ec1651a3 100644 --- a/utils/tests/models.go +++ b/utils/tests/models.go @@ -64,8 +64,8 @@ type Language struct { type Coupon struct { ID int `gorm:"primarykey; size:255"` AppliesToProduct []*CouponProduct `gorm:"foreignKey:CouponId;constraint:OnDelete:CASCADE"` - AmountOff uint32 `gorm:"amount_off"` - PercentOff float32 `gorm:"percent_off"` + AmountOff uint32 `gorm:"column:amount_off"` + PercentOff float32 `gorm:"column:percent_off"` } type CouponProduct struct { From a35883590b7f9467bedf43b9611b2c0d0ff30ffd Mon Sep 17 00:00:00 2001 From: Bruce MacKenzie Date: Wed, 10 Aug 2022 23:38:04 -0400 Subject: [PATCH 02/56] update Delete Godoc to describe soft delete behaviour (#5554) --- finisher_api.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/finisher_api.go b/finisher_api.go index af9afb63..bdf0437d 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -388,7 +388,9 @@ func (db *DB) UpdateColumns(values interface{}) (tx *DB) { return tx.callbacks.Update().Execute(tx) } -// Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition +// Delete deletes value matching given conditions. If value contains primary key it is included in the conditions. +// If value includes a deleted_at field, then Delete performs a soft delete instead by setting deleted_at with the current +// time if null. func (db *DB) Delete(value interface{}, conds ...interface{}) (tx *DB) { tx = db.getInstance() if len(conds) > 0 { From 573b9fa536050c156968b4d228cab05a119d78df Mon Sep 17 00:00:00 2001 From: enwawerueli Date: Fri, 12 Aug 2022 16:46:18 +0300 Subject: [PATCH 03/56] fix: correct grammar --- gorm.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gorm.go b/gorm.go index c852e60c..1f1dac21 100644 --- a/gorm.go +++ b/gorm.go @@ -413,7 +413,7 @@ func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interfac relation, ok := modelSchema.Relationships.Relations[field] isRelation := ok && relation.JoinTable != nil if !isRelation { - return fmt.Errorf("failed to found relation: %s", field) + return fmt.Errorf("failed to find relation: %s", field) } for _, ref := range relation.References { From ba227e8939d05f249a3ede8901193801d8da8603 Mon Sep 17 00:00:00 2001 From: Aoang Date: Mon, 15 Aug 2022 10:46:57 +0800 Subject: [PATCH 04/56] Add Go 1.19 Support (#5608) --- .github/workflows/tests.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index b97da3f4..367f4ccd 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -16,7 +16,7 @@ jobs: sqlite: strategy: matrix: - go: ['1.18', '1.17', '1.16'] + go: ['1.19', '1.18', '1.17', '1.16'] platform: [ubuntu-latest] # can not run in windows OS runs-on: ${{ matrix.platform }} @@ -42,7 +42,7 @@ jobs: strategy: matrix: dbversion: ['mysql:latest', 'mysql:5.7', 'mariadb:latest'] - go: ['1.18', '1.17', '1.16'] + go: ['1.19', '1.18', '1.17', '1.16'] platform: [ubuntu-latest] runs-on: ${{ matrix.platform }} @@ -86,7 +86,7 @@ jobs: strategy: matrix: dbversion: ['postgres:latest', 'postgres:13', 'postgres:12', 'postgres:11', 'postgres:10'] - go: ['1.18', '1.17', '1.16'] + go: ['1.19', '1.18', '1.17', '1.16'] platform: [ubuntu-latest] # can not run in macOS and Windows runs-on: ${{ matrix.platform }} @@ -128,7 +128,7 @@ jobs: sqlserver: strategy: matrix: - go: ['1.18', '1.17', '1.16'] + go: ['1.19', '1.18', '1.17', '1.16'] platform: [ubuntu-latest] # can not run test in macOS and windows runs-on: ${{ matrix.platform }} From 3f92b9b0df84736750d6645e074596a7383ae089 Mon Sep 17 00:00:00 2001 From: Shunsuke Otani Date: Mon, 15 Aug 2022 11:47:26 +0900 Subject: [PATCH 05/56] Refactor: redundant type from composite literal (#5604) --- statement_test.go | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/statement_test.go b/statement_test.go index a537c7be..761daf37 100644 --- a/statement_test.go +++ b/statement_test.go @@ -37,18 +37,18 @@ func TestWhereCloneCorruption(t *testing.T) { func TestNameMatcher(t *testing.T) { for k, v := range map[string][]string{ - "table.name": []string{"table", "name"}, - "`table`.`name`": []string{"table", "name"}, - "'table'.'name'": []string{"table", "name"}, - "'table'.name": []string{"table", "name"}, - "table1.name_23": []string{"table1", "name_23"}, - "`table_1`.`name23`": []string{"table_1", "name23"}, - "'table23'.'name_1'": []string{"table23", "name_1"}, - "'table23'.name1": []string{"table23", "name1"}, - "'name1'": []string{"", "name1"}, - "`name_1`": []string{"", "name_1"}, - "`Name_1`": []string{"", "Name_1"}, - "`Table`.`nAme`": []string{"Table", "nAme"}, + "table.name": {"table", "name"}, + "`table`.`name`": {"table", "name"}, + "'table'.'name'": {"table", "name"}, + "'table'.name": {"table", "name"}, + "table1.name_23": {"table1", "name_23"}, + "`table_1`.`name23`": {"table_1", "name23"}, + "'table23'.'name_1'": {"table23", "name_1"}, + "'table23'.name1": {"table23", "name1"}, + "'name1'": {"", "name1"}, + "`name_1`": {"", "name_1"}, + "`Name_1`": {"", "Name_1"}, + "`Table`.`nAme`": {"Table", "nAme"}, } { if matches := nameMatcher.FindStringSubmatch(k); len(matches) < 3 || matches[1] != v[0] || matches[2] != v[1] { t.Errorf("failed to match value: %v, got %v, expect: %v", k, matches, v) From 8c3018b96aea241a35b769291de6edd2a3378b44 Mon Sep 17 00:00:00 2001 From: Shunsuke Otani Date: Mon, 15 Aug 2022 11:50:06 +0900 Subject: [PATCH 06/56] Replace `ioutil.Discard` with `io.Discard` (#5603) --- go.mod | 2 +- logger/logger.go | 6 +++--- tests/go.mod | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/go.mod b/go.mod index 57362745..03f84379 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module gorm.io/gorm -go 1.14 +go 1.16 require ( github.com/jinzhu/inflection v1.0.0 diff --git a/logger/logger.go b/logger/logger.go index 2ffd28d5..ce088561 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -4,7 +4,7 @@ import ( "context" "errors" "fmt" - "io/ioutil" + "io" "log" "os" "time" @@ -68,8 +68,8 @@ type Interface interface { } var ( - // Discard Discard logger will print any log to ioutil.Discard - Discard = New(log.New(ioutil.Discard, "", log.LstdFlags), Config{}) + // Discard Discard logger will print any log to io.Discard + Discard = New(log.New(io.Discard, "", log.LstdFlags), Config{}) // Default Default logger Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{ SlowThreshold: 200 * time.Millisecond, diff --git a/tests/go.mod b/tests/go.mod index eb8f336d..19280434 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -1,6 +1,6 @@ module gorm.io/gorm/tests -go 1.14 +go 1.16 require ( github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect From d71caef7d9d08287971a129bc19068eb1f48ed8f Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Sat, 3 Sep 2022 20:00:21 +0800 Subject: [PATCH 07/56] fix: remove uuid autoincrement (#5620) --- tests/postgres_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/postgres_test.go b/tests/postgres_test.go index 66b988c3..97af6db3 100644 --- a/tests/postgres_test.go +++ b/tests/postgres_test.go @@ -63,13 +63,13 @@ func TestPostgres(t *testing.T) { } type Post struct { - ID uuid.UUID `gorm:"primary_key;type:uuid;default:uuid_generate_v4();autoincrement"` + ID uuid.UUID `gorm:"primary_key;type:uuid;default:uuid_generate_v4();"` Title string Categories []*Category `gorm:"Many2Many:post_categories"` } type Category struct { - ID uuid.UUID `gorm:"primary_key;type:uuid;default:uuid_generate_v4();autoincrement"` + ID uuid.UUID `gorm:"primary_key;type:uuid;default:uuid_generate_v4();"` Title string Posts []*Post `gorm:"Many2Many:post_categories"` } From f78f635fae6f332a76e8f3e38d939864d1f5c209 Mon Sep 17 00:00:00 2001 From: "jesse.tang" <1430482733@qq.com> Date: Mon, 5 Sep 2022 15:34:33 +0800 Subject: [PATCH 08/56] Optimize: code logic db.scanIntoStruct() (#5633) --- scan.go | 46 ++++++++++++++++++++++++---------------------- 1 file changed, 24 insertions(+), 22 deletions(-) diff --git a/scan.go b/scan.go index 6250fb57..2db43160 100644 --- a/scan.go +++ b/scan.go @@ -66,30 +66,32 @@ func (db *DB) scanIntoStruct(rows Rows, reflectValue reflect.Value, values []int db.RowsAffected++ db.AddError(rows.Scan(values...)) - joinedSchemaMap := make(map[*schema.Field]interface{}, 0) + joinedSchemaMap := make(map[*schema.Field]interface{}) for idx, field := range fields { - if field != nil { - if len(joinFields) == 0 || joinFields[idx][0] == nil { - db.AddError(field.Set(db.Statement.Context, reflectValue, values[idx])) - } else { - joinSchema := joinFields[idx][0] - relValue := joinSchema.ReflectValueOf(db.Statement.Context, reflectValue) - if relValue.Kind() == reflect.Ptr { - if _, ok := joinedSchemaMap[joinSchema]; !ok { - if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() { - continue - } - - relValue.Set(reflect.New(relValue.Type().Elem())) - joinedSchemaMap[joinSchema] = nil - } - } - db.AddError(joinFields[idx][1].Set(db.Statement.Context, relValue, values[idx])) - } - - // release data to pool - field.NewValuePool.Put(values[idx]) + if field == nil { + continue } + + if len(joinFields) == 0 || joinFields[idx][0] == nil { + db.AddError(field.Set(db.Statement.Context, reflectValue, values[idx])) + } else { + joinSchema := joinFields[idx][0] + relValue := joinSchema.ReflectValueOf(db.Statement.Context, reflectValue) + if relValue.Kind() == reflect.Ptr { + if _, ok := joinedSchemaMap[joinSchema]; !ok { + if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() { + continue + } + + relValue.Set(reflect.New(relValue.Type().Elem())) + joinedSchemaMap[joinSchema] = nil + } + } + db.AddError(joinFields[idx][1].Set(db.Statement.Context, relValue, values[idx])) + } + + // release data to pool + field.NewValuePool.Put(values[idx]) } } From b3eb1c8c512430c1600f720a96b2af777c91d1da Mon Sep 17 00:00:00 2001 From: Jiepeng Cao Date: Mon, 5 Sep 2022 15:39:19 +0800 Subject: [PATCH 09/56] simplified regexp (#5677) --- migrator/migrator.go | 2 +- statement.go | 2 +- tests/upsert_test.go | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/migrator/migrator.go b/migrator/migrator.go index 87ac7745..c1d7e0e7 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -15,7 +15,7 @@ import ( ) var ( - regFullDataType = regexp.MustCompile(`[^\d]*(\d+)[^\d]?`) + regFullDataType = regexp.MustCompile(`\D*(\d+)\D?`) ) // Migrator m struct diff --git a/statement.go b/statement.go index 12687810..cc26fe37 100644 --- a/statement.go +++ b/statement.go @@ -650,7 +650,7 @@ func (stmt *Statement) Changed(fields ...string) bool { return false } -var nameMatcher = regexp.MustCompile(`^(?:[\W]?([A-Za-z_0-9]+?)[\W]?\.)?[\W]?([A-Za-z_0-9]+?)[\W]?$`) +var nameMatcher = regexp.MustCompile(`^(?:\W?(\w+?)\W?\.)?\W?(\w+?)\W?$`) // SelectAndOmitColumns get select and omit columns, select -> true, omit -> false func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (map[string]bool, bool) { diff --git a/tests/upsert_test.go b/tests/upsert_test.go index f90c4518..e84dc14a 100644 --- a/tests/upsert_test.go +++ b/tests/upsert_test.go @@ -62,7 +62,7 @@ func TestUpsert(t *testing.T) { } r := DB.Session(&gorm.Session{DryRun: true}).Clauses(clause.OnConflict{UpdateAll: true}).Create(&RestrictedLanguage{Code: "upsert_code", Name: "upsert_name", Lang: "upsert_lang"}) - if !regexp.MustCompile(`INTO .restricted_languages. .*\(.code.,.name.,.lang.\) .* (SET|UPDATE) .name.=.*.name.[^\w]*$`).MatchString(r.Statement.SQL.String()) { + if !regexp.MustCompile(`INTO .restricted_languages. .*\(.code.,.name.,.lang.\) .* (SET|UPDATE) .name.=.*.name.\W*$`).MatchString(r.Statement.SQL.String()) { t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) } } From f29afdd3297d94b3e789e1f8d0ab8c823325eba5 Mon Sep 17 00:00:00 2001 From: Bruce MacKenzie Date: Thu, 8 Sep 2022 23:16:41 -0400 Subject: [PATCH 10/56] Rewrite of finisher_api Godocs (#5618) --- finisher_api.go | 49 +++++++++++++++++++++++++++---------------------- 1 file changed, 27 insertions(+), 22 deletions(-) diff --git a/finisher_api.go b/finisher_api.go index bdf0437d..835a6984 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -13,7 +13,7 @@ import ( "gorm.io/gorm/utils" ) -// Create insert the value into database +// Create inserts value, returning the inserted data's primary key in value's id func (db *DB) Create(value interface{}) (tx *DB) { if db.CreateBatchSize > 0 { return db.CreateInBatches(value, db.CreateBatchSize) @@ -24,7 +24,7 @@ func (db *DB) Create(value interface{}) (tx *DB) { return tx.callbacks.Create().Execute(tx) } -// CreateInBatches insert the value in batches into database +// CreateInBatches inserts value in batches of batchSize func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) { reflectValue := reflect.Indirect(reflect.ValueOf(value)) @@ -68,7 +68,7 @@ func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) { return } -// Save update value in database, if the value doesn't have primary key, will insert it +// Save updates value in database. If value doesn't contain a matching primary key, value is inserted. func (db *DB) Save(value interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.Dest = value @@ -114,7 +114,7 @@ func (db *DB) Save(value interface{}) (tx *DB) { return } -// First find first record that match given conditions, order by primary key +// First finds the first record ordered by primary key, matching given conditions conds func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) { tx = db.Limit(1).Order(clause.OrderByColumn{ Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, @@ -129,7 +129,7 @@ func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) { return tx.callbacks.Query().Execute(tx) } -// Take return a record that match given conditions, the order will depend on the database implementation +// Take finds the first record returned by the database in no specified order, matching given conditions conds func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) { tx = db.Limit(1) if len(conds) > 0 { @@ -142,7 +142,7 @@ func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) { return tx.callbacks.Query().Execute(tx) } -// Last find last record that match given conditions, order by primary key +// Last finds the last record ordered by primary key, matching given conditions conds func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) { tx = db.Limit(1).Order(clause.OrderByColumn{ Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, @@ -158,7 +158,7 @@ func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) { return tx.callbacks.Query().Execute(tx) } -// Find find records that match given conditions +// Find finds all records matching given conditions conds func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) { tx = db.getInstance() if len(conds) > 0 { @@ -170,7 +170,7 @@ func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) { return tx.callbacks.Query().Execute(tx) } -// FindInBatches find records in batches +// FindInBatches finds all records in batches of batchSize func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, batch int) error) *DB { var ( tx = db.Order(clause.OrderByColumn{ @@ -286,7 +286,8 @@ func (db *DB) assignInterfacesToValue(values ...interface{}) { } } -// FirstOrInit gets the first matched record or initialize a new instance with given conditions (only works with struct or map conditions) +// FirstOrInit finds the first matching record, otherwise if not found initializes a new instance with given conds. +// Each conds must be a struct or map. func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { queryTx := db.Limit(1).Order(clause.OrderByColumn{ Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, @@ -312,7 +313,8 @@ func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { return } -// FirstOrCreate gets the first matched record or create a new one with given conditions (only works with struct, map conditions) +// FirstOrCreate finds the first matching record, otherwise if not found creates a new instance with given conds. +// Each conds must be a struct or map. func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) { tx = db.getInstance() queryTx := db.Session(&Session{}).Limit(1).Order(clause.OrderByColumn{ @@ -360,14 +362,14 @@ func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) { return tx } -// Update update attributes with callbacks, refer: https://gorm.io/docs/update.html#Update-Changed-Fields +// Update updates column with value using callbacks. Reference: https://gorm.io/docs/update.html#Update-Changed-Fields func (db *DB) Update(column string, value interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.Dest = map[string]interface{}{column: value} return tx.callbacks.Update().Execute(tx) } -// Updates update attributes with callbacks, refer: https://gorm.io/docs/update.html#Update-Changed-Fields +// Updates updates attributes using callbacks. values must be a struct or map. Reference: https://gorm.io/docs/update.html#Update-Changed-Fields func (db *DB) Updates(values interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.Dest = values @@ -388,8 +390,8 @@ func (db *DB) UpdateColumns(values interface{}) (tx *DB) { return tx.callbacks.Update().Execute(tx) } -// Delete deletes value matching given conditions. If value contains primary key it is included in the conditions. -// If value includes a deleted_at field, then Delete performs a soft delete instead by setting deleted_at with the current +// Delete deletes value matching given conditions. If value contains primary key it is included in the conditions. If +// value includes a deleted_at field, then Delete performs a soft delete instead by setting deleted_at with the current // time if null. func (db *DB) Delete(value interface{}, conds ...interface{}) (tx *DB) { tx = db.getInstance() @@ -484,7 +486,7 @@ func (db *DB) Rows() (*sql.Rows, error) { return rows, tx.Error } -// Scan scan value to a struct +// Scan scans selected value to the struct dest func (db *DB) Scan(dest interface{}) (tx *DB) { config := *db.Config currentLogger, newLogger := config.Logger, logger.Recorder.New() @@ -509,7 +511,7 @@ func (db *DB) Scan(dest interface{}) (tx *DB) { return } -// Pluck used to query single column from a model as a map +// Pluck queries a single column from a model, returning in the slice dest. E.g.: // var ages []int64 // db.Model(&users).Pluck("age", &ages) func (db *DB) Pluck(column string, dest interface{}) (tx *DB) { @@ -552,7 +554,8 @@ func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error { return tx.Error } -// Connection use a db conn to execute Multiple commands,this conn will put conn pool after it is executed. +// Connection uses a db connection to execute an arbitrary number of commands in fc. When finished, the connection is +// returned to the connection pool. func (db *DB) Connection(fc func(tx *DB) error) (err error) { if db.Error != nil { return db.Error @@ -574,7 +577,9 @@ func (db *DB) Connection(fc func(tx *DB) error) (err error) { return fc(tx) } -// Transaction start a transaction as a block, return error will rollback, otherwise to commit. +// Transaction start a transaction as a block, return error will rollback, otherwise to commit. Transaction executes an +// arbitrary number of commands in fc within a transaction. On success the changes are committed; if an error occurs +// they are rolled back. func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err error) { panicked := true @@ -617,7 +622,7 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er return } -// Begin begins a transaction +// Begin begins a transaction with any transaction options opts func (db *DB) Begin(opts ...*sql.TxOptions) *DB { var ( // clone statement @@ -646,7 +651,7 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB { return tx } -// Commit commit a transaction +// Commit commits the changes in a transaction func (db *DB) Commit() *DB { if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil && !reflect.ValueOf(committer).IsNil() { db.AddError(committer.Commit()) @@ -656,7 +661,7 @@ func (db *DB) Commit() *DB { return db } -// Rollback rollback a transaction +// Rollback rollbacks the changes in a transaction func (db *DB) Rollback() *DB { if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil { if !reflect.ValueOf(committer).IsNil() { @@ -686,7 +691,7 @@ func (db *DB) RollbackTo(name string) *DB { return db } -// Exec execute raw sql +// Exec executes raw sql func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.SQL = strings.Builder{} From edb00c10adff38445c4350c0cb524faa6ec2d592 Mon Sep 17 00:00:00 2001 From: Googol Lee Date: Wed, 14 Sep 2022 04:26:51 +0200 Subject: [PATCH 11/56] AutoMigrate() should always migrate checks, even there is no relationship constraints. (#5644) * fix: remove uuid autoincrement * AutoMigrate() should always migrate checks, even there is no relationship constranits. Co-authored-by: a631807682 <631807682@qq.com> --- migrator/migrator.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/migrator/migrator.go b/migrator/migrator.go index c1d7e0e7..e6782a13 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -135,12 +135,12 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { } } } + } - for _, chk := range stmt.Schema.ParseCheckConstraints() { - if !tx.Migrator().HasConstraint(value, chk.Name) { - if err := tx.Migrator().CreateConstraint(value, chk.Name); err != nil { - return err - } + for _, chk := range stmt.Schema.ParseCheckConstraints() { + if !tx.Migrator().HasConstraint(value, chk.Name) { + if err := tx.Migrator().CreateConstraint(value, chk.Name); err != nil { + return err } } } From 490625981a1c3474eeca7f2e4fde791cd94c84fa Mon Sep 17 00:00:00 2001 From: qqxhb <30866940+qqxhb@users.noreply.github.com> Date: Fri, 16 Sep 2022 15:02:44 +0800 Subject: [PATCH 12/56] fix: update omit (#5699) --- callbacks/update.go | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/callbacks/update.go b/callbacks/update.go index 48c61bf4..b596df9a 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -70,10 +70,12 @@ func Update(config *Config) func(db *gorm.DB) { if db.Statement.SQL.Len() == 0 { db.Statement.SQL.Grow(180) db.Statement.AddClauseIfNotExists(clause.Update{}) - if set := ConvertToAssignments(db.Statement); len(set) != 0 { - db.Statement.AddClause(set) - } else if _, ok := db.Statement.Clauses["SET"]; !ok { - return + if _, ok := db.Statement.Clauses["SET"]; !ok { + if set := ConvertToAssignments(db.Statement); len(set) != 0 { + db.Statement.AddClause(set) + } else { + return + } } db.Statement.Build(db.Statement.BuildClauses...) From 5ed7b1a65e2aeeb92bb12f2b1ebcac2e4d3402fe Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Thu, 22 Sep 2022 11:25:03 +0800 Subject: [PATCH 13/56] fix: same embedded filed name (#5705) --- migrator/migrator.go | 2 +- tests/migrate_test.go | 38 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 1 deletion(-) diff --git a/migrator/migrator.go b/migrator/migrator.go index e6782a13..d7ebf276 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -478,7 +478,7 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy } if alterColumn && !field.IgnoreMigration { - return m.DB.Migrator().AlterColumn(value, field.Name) + return m.DB.Migrator().AlterColumn(value, field.DBName) } return nil diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 0b5bc5eb..32e84e77 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -959,3 +959,41 @@ func TestMigrateArrayTypeModel(t *testing.T) { AssertEqual(t, nil, err) AssertEqual(t, "integer[]", ct.DatabaseTypeName()) } + +func TestMigrateSameEmbeddedFieldName(t *testing.T) { + type UserStat struct { + GroundDestroyCount int + } + + type GameUser struct { + gorm.Model + StatAb UserStat `gorm:"embedded;embeddedPrefix:stat_ab_"` + } + + type UserStat1 struct { + GroundDestroyCount string + } + + type GroundRate struct { + GroundDestroyCount int + } + + type GameUser1 struct { + gorm.Model + StatAb UserStat1 `gorm:"embedded;embeddedPrefix:stat_ab_"` + GroundRateRb GroundRate `gorm:"embedded;embeddedPrefix:rate_ground_rb_"` + } + + DB.Migrator().DropTable(&GameUser{}) + err := DB.AutoMigrate(&GameUser{}) + AssertEqual(t, nil, err) + + err = DB.Table("game_users").AutoMigrate(&GameUser1{}) + AssertEqual(t, nil, err) + + _, err = findColumnType(&GameUser{}, "stat_ab_ground_destory_count") + AssertEqual(t, nil, err) + + _, err = findColumnType(&GameUser{}, "rate_ground_rb_ground_destory_count") + AssertEqual(t, nil, err) +} From 1f634c39377f914187ae9efb1bc1bdbc94e97028 Mon Sep 17 00:00:00 2001 From: "jesse.tang" <1430482733@qq.com> Date: Thu, 22 Sep 2022 14:50:35 +0800 Subject: [PATCH 14/56] support scan assign slice cap (#5634) * support scan assign slice cap * fix --- scan.go | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/scan.go b/scan.go index 2db43160..df5a3714 100644 --- a/scan.go +++ b/scan.go @@ -248,7 +248,13 @@ func Scan(rows Rows, db *DB, mode ScanMode) { if !update || reflectValue.Len() == 0 { update = false - db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20)) + // if the slice cap is externally initialized, the externally initialized slice is directly used here + if reflectValue.Cap() == 0 { + db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20)) + } else { + reflectValue.SetLen(0) + db.Statement.ReflectValue.Set(reflectValue) + } } for initialized || rows.Next() { From 3a72ba102ec1ce729f703be4ac00e0049b82b0e2 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 21 Sep 2022 17:29:38 +0800 Subject: [PATCH 15/56] Allow shared foreign key for many2many jointable --- schema/relationship.go | 60 ++++++++++++++++++++++--------------- schema/relationship_test.go | 29 +++++++++++++++++- tests/go.mod | 13 ++++---- 3 files changed, 71 insertions(+), 31 deletions(-) diff --git a/schema/relationship.go b/schema/relationship.go index 0aa33e51..bb8aeb64 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -191,7 +191,8 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel err error joinTableFields []reflect.StructField fieldsMap = map[string]*Field{} - ownFieldsMap = map[string]bool{} // fix self join many2many + ownFieldsMap = map[string]*Field{} // fix self join many2many + referFieldsMap = map[string]*Field{} joinForeignKeys = toColumns(field.TagSettings["JOINFOREIGNKEY"]) joinReferences = toColumns(field.TagSettings["JOINREFERENCES"]) ) @@ -229,7 +230,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel joinFieldName = strings.Title(joinForeignKeys[idx]) } - ownFieldsMap[joinFieldName] = true + ownFieldsMap[joinFieldName] = ownField fieldsMap[joinFieldName] = ownField joinTableFields = append(joinTableFields, reflect.StructField{ Name: joinFieldName, @@ -242,9 +243,6 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel for idx, relField := range refForeignFields { joinFieldName := strings.Title(relation.FieldSchema.Name) + relField.Name - if len(joinReferences) > idx { - joinFieldName = strings.Title(joinReferences[idx]) - } if _, ok := ownFieldsMap[joinFieldName]; ok { if field.Name != relation.FieldSchema.Name { @@ -254,14 +252,22 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel } } - fieldsMap[joinFieldName] = relField - joinTableFields = append(joinTableFields, reflect.StructField{ - Name: joinFieldName, - PkgPath: relField.StructField.PkgPath, - Type: relField.StructField.Type, - Tag: removeSettingFromTag(appendSettingFromTag(relField.StructField.Tag, "primaryKey"), - "column", "autoincrement", "index", "unique", "uniqueindex"), - }) + if len(joinReferences) > idx { + joinFieldName = strings.Title(joinReferences[idx]) + } + + referFieldsMap[joinFieldName] = relField + + if _, ok := fieldsMap[joinFieldName]; !ok { + fieldsMap[joinFieldName] = relField + joinTableFields = append(joinTableFields, reflect.StructField{ + Name: joinFieldName, + PkgPath: relField.StructField.PkgPath, + Type: relField.StructField.Type, + Tag: removeSettingFromTag(appendSettingFromTag(relField.StructField.Tag, "primaryKey"), + "column", "autoincrement", "index", "unique", "uniqueindex"), + }) + } } joinTableFields = append(joinTableFields, reflect.StructField{ @@ -317,31 +323,37 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel f.Size = fieldsMap[f.Name].Size } relation.JoinTable.PrimaryFields = append(relation.JoinTable.PrimaryFields, f) - ownPrimaryField := schema == fieldsMap[f.Name].Schema && ownFieldsMap[f.Name] - if ownPrimaryField { + if of, ok := ownFieldsMap[f.Name]; ok { joinRel := relation.JoinTable.Relationships.Relations[relName] joinRel.Field = relation.Field joinRel.References = append(joinRel.References, &Reference{ - PrimaryKey: fieldsMap[f.Name], + PrimaryKey: of, ForeignKey: f, }) - } else { + + relation.References = append(relation.References, &Reference{ + PrimaryKey: of, + ForeignKey: f, + OwnPrimaryKey: true, + }) + } + + if rf, ok := referFieldsMap[f.Name]; ok { joinRefRel := relation.JoinTable.Relationships.Relations[relRefName] if joinRefRel.Field == nil { joinRefRel.Field = relation.Field } joinRefRel.References = append(joinRefRel.References, &Reference{ - PrimaryKey: fieldsMap[f.Name], + PrimaryKey: rf, + ForeignKey: f, + }) + + relation.References = append(relation.References, &Reference{ + PrimaryKey: rf, ForeignKey: f, }) } - - relation.References = append(relation.References, &Reference{ - PrimaryKey: fieldsMap[f.Name], - ForeignKey: f, - OwnPrimaryKey: ownPrimaryField, - }) } } } diff --git a/schema/relationship_test.go b/schema/relationship_test.go index 6fffbfcb..85c45589 100644 --- a/schema/relationship_test.go +++ b/schema/relationship_test.go @@ -10,7 +10,7 @@ import ( func checkStructRelation(t *testing.T, data interface{}, relations ...Relation) { if s, err := schema.Parse(data, &sync.Map{}, schema.NamingStrategy{}); err != nil { - t.Errorf("Failed to parse schema") + t.Errorf("Failed to parse schema, got error %v", err) } else { for _, rel := range relations { checkSchemaRelation(t, s, rel) @@ -305,6 +305,33 @@ func TestMany2ManyOverrideForeignKey(t *testing.T) { }) } +func TestMany2ManySharedForeignKey(t *testing.T) { + type Profile struct { + gorm.Model + Name string + Kind string + ProfileRefer uint + } + + type User struct { + gorm.Model + Profiles []Profile `gorm:"many2many:user_profiles;foreignKey:Refer,Kind;joinForeignKey:UserRefer,Kind;References:ProfileRefer,Kind;joinReferences:ProfileR,Kind"` + Kind string + Refer uint + } + + checkStructRelation(t, &User{}, Relation{ + Name: "Profiles", Type: schema.Many2Many, Schema: "User", FieldSchema: "Profile", + JoinTable: JoinTable{Name: "user_profiles", Table: "user_profiles"}, + References: []Reference{ + {"Refer", "User", "UserRefer", "user_profiles", "", true}, + {"Kind", "User", "Kind", "user_profiles", "", true}, + {"ProfileRefer", "Profile", "ProfileR", "user_profiles", "", false}, + {"Kind", "Profile", "Kind", "user_profiles", "", false}, + }, + }) +} + func TestMany2ManyOverrideJoinForeignKey(t *testing.T) { type Profile struct { gorm.Model diff --git a/tests/go.mod b/tests/go.mod index 19280434..ebebabc0 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -3,17 +3,18 @@ module gorm.io/gorm/tests go 1.16 require ( + github.com/denisenkom/go-mssqldb v0.12.2 // indirect github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect github.com/google/uuid v1.3.0 github.com/jinzhu/now v1.1.5 - github.com/lib/pq v1.10.6 - github.com/mattn/go-sqlite3 v1.14.14 // indirect - golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d // indirect - gorm.io/driver/mysql v1.3.5 - gorm.io/driver/postgres v1.3.8 + github.com/lib/pq v1.10.7 + github.com/mattn/go-sqlite3 v1.14.15 // indirect + golang.org/x/crypto v0.0.0-20220919173607-35f4265a4bc0 // indirect + gorm.io/driver/mysql v1.3.6 + gorm.io/driver/postgres v1.3.10 gorm.io/driver/sqlite v1.3.6 gorm.io/driver/sqlserver v1.3.2 - gorm.io/gorm v1.23.8 + gorm.io/gorm v1.23.9 ) replace gorm.io/gorm => ../ From 101a7c789fa2c41f409da439056806756fd8ce22 Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Thu, 22 Sep 2022 15:51:47 +0800 Subject: [PATCH 16/56] fix: scan array (#5624) Co-authored-by: Jinzhu --- scan.go | 22 +++++++++++++++------- tests/query_test.go | 24 ++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 7 deletions(-) diff --git a/scan.go b/scan.go index df5a3714..70cd4284 100644 --- a/scan.go +++ b/scan.go @@ -243,15 +243,18 @@ func Scan(rows Rows, db *DB, mode ScanMode) { switch reflectValue.Kind() { case reflect.Slice, reflect.Array: - var elem reflect.Value - recyclableStruct := reflect.New(reflectValueType) + var ( + elem reflect.Value + recyclableStruct = reflect.New(reflectValueType) + isArrayKind = reflectValue.Kind() == reflect.Array + ) if !update || reflectValue.Len() == 0 { update = false // if the slice cap is externally initialized, the externally initialized slice is directly used here if reflectValue.Cap() == 0 { db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20)) - } else { + } else if !isArrayKind { reflectValue.SetLen(0) db.Statement.ReflectValue.Set(reflectValue) } @@ -285,10 +288,15 @@ func Scan(rows Rows, db *DB, mode ScanMode) { db.scanIntoStruct(rows, elem, values, fields, joinFields) if !update { - if isPtr { - reflectValue = reflect.Append(reflectValue, elem) + if !isPtr { + elem = elem.Elem() + } + if isArrayKind { + if reflectValue.Len() >= int(db.RowsAffected) { + reflectValue.Index(int(db.RowsAffected - 1)).Set(elem) + } } else { - reflectValue = reflect.Append(reflectValue, elem.Elem()) + reflectValue = reflect.Append(reflectValue, elem) } } } @@ -312,4 +320,4 @@ func Scan(rows Rows, db *DB, mode ScanMode) { if db.RowsAffected == 0 && db.Statement.RaiseErrorOnNotFound && db.Error == nil { db.AddError(ErrRecordNotFound) } -} +} \ No newline at end of file diff --git a/tests/query_test.go b/tests/query_test.go index 4569fe1a..eccf0133 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -216,6 +216,30 @@ func TestFind(t *testing.T) { } } + // test array + var models2 [3]User + if err := DB.Where("name in (?)", []string{"find"}).Find(&models2).Error; err != nil || len(models2) != 3 { + t.Errorf("errors happened when query find with in clause: %v, length: %v", err, len(models2)) + } else { + for idx, user := range users { + t.Run("FindWithInClause#"+strconv.Itoa(idx+1), func(t *testing.T) { + CheckUser(t, models2[idx], user) + }) + } + } + + // test smaller array + var models3 [2]User + if err := DB.Where("name in (?)", []string{"find"}).Find(&models3).Error; err != nil || len(models3) != 2 { + t.Errorf("errors happened when query find with in clause: %v, length: %v", err, len(models3)) + } else { + for idx, user := range users[:2] { + t.Run("FindWithInClause#"+strconv.Itoa(idx+1), func(t *testing.T) { + CheckUser(t, models3[idx], user) + }) + } + } + var none []User if err := DB.Where("name in (?)", []string{}).Find(&none).Error; err != nil || len(none) != 0 { t.Errorf("errors happened when query find with in clause and zero length parameter: %v, length: %v", err, len(none)) From 73bc53f061ee1f54b9ef562a3466b5e3c5438aea Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Thu, 22 Sep 2022 15:56:32 +0800 Subject: [PATCH 17/56] feat: migrator support type aliases (#5627) * feat: migrator support type aliases * perf: check type --- migrator.go | 1 + migrator/migrator.go | 29 ++++++++++++++++++++++++++--- 2 files changed, 27 insertions(+), 3 deletions(-) diff --git a/migrator.go b/migrator.go index 34e888f2..882fc4cc 100644 --- a/migrator.go +++ b/migrator.go @@ -68,6 +68,7 @@ type Migrator interface { // Database CurrentDatabase() string FullDataTypeOf(*schema.Field) clause.Expr + GetTypeAliases(databaseTypeName string) []string // Tables CreateTable(dst ...interface{}) error diff --git a/migrator/migrator.go b/migrator/migrator.go index d7ebf276..29c0c00c 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -408,9 +408,27 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy alterColumn := false - // check type - if !field.PrimaryKey && !strings.HasPrefix(fullDataType, realDataType) { - alterColumn = true + if !field.PrimaryKey { + // check type + var isSameType bool + if strings.HasPrefix(fullDataType, realDataType) { + isSameType = true + } + + // check type aliases + if !isSameType { + aliases := m.DB.Migrator().GetTypeAliases(realDataType) + for _, alias := range aliases { + if strings.HasPrefix(fullDataType, alias) { + isSameType = true + break + } + } + } + + if !isSameType { + alterColumn = true + } } // check size @@ -863,3 +881,8 @@ func (m Migrator) CurrentTable(stmt *gorm.Statement) interface{} { func (m Migrator) GetIndexes(dst interface{}) ([]gorm.Index, error) { return nil, errors.New("not support") } + +// GetTypeAliases return database type aliases +func (m Migrator) GetTypeAliases(databaseTypeName string) []string { + return nil +} From 12237454ed695461eb750aee9fca6bac7faa8b8b Mon Sep 17 00:00:00 2001 From: kinggo Date: Thu, 22 Sep 2022 16:47:31 +0800 Subject: [PATCH 18/56] fix: use preparestmt in trasaction will use new conn, close #5508 --- gorm.go | 16 ++++++++++++---- tests/prepared_stmt_test.go | 17 +++++++++++++++++ 2 files changed, 29 insertions(+), 4 deletions(-) diff --git a/gorm.go b/gorm.go index 1f1dac21..81b6e2af 100644 --- a/gorm.go +++ b/gorm.go @@ -248,10 +248,18 @@ func (db *DB) Session(config *Session) *DB { if config.PrepareStmt { if v, ok := db.cacheStore.Load(preparedStmtDBKey); ok { preparedStmt := v.(*PreparedStmtDB) - tx.Statement.ConnPool = &PreparedStmtDB{ - ConnPool: db.Config.ConnPool, - Mux: preparedStmt.Mux, - Stmts: preparedStmt.Stmts, + switch t := tx.Statement.ConnPool.(type) { + case Tx: + tx.Statement.ConnPool = &PreparedStmtTX{ + Tx: t, + PreparedStmtDB: preparedStmt, + } + default: + tx.Statement.ConnPool = &PreparedStmtDB{ + ConnPool: db.Config.ConnPool, + Mux: preparedStmt.Mux, + Stmts: preparedStmt.Stmts, + } } txConfig.ConnPool = tx.Statement.ConnPool txConfig.PrepareStmt = true diff --git a/tests/prepared_stmt_test.go b/tests/prepared_stmt_test.go index 8730e547..86e3630d 100644 --- a/tests/prepared_stmt_test.go +++ b/tests/prepared_stmt_test.go @@ -2,6 +2,7 @@ package tests_test import ( "context" + "errors" "testing" "time" @@ -88,3 +89,19 @@ func TestPreparedStmtFromTransaction(t *testing.T) { } tx2.Commit() } + +func TestPreparedStmtInTransaction(t *testing.T) { + user := User{Name: "jinzhu"} + + if err := DB.Transaction(func(tx *gorm.DB) error { + tx.Session(&gorm.Session{PrepareStmt: true}).Create(&user) + return errors.New("test") + }); err == nil { + t.Error(err) + } + + var result User + if err := DB.First(&result, user.ID).Error; err == nil { + t.Errorf("Failed, got error: %v", err) + } +} From 328f3019825c95be6264cc94d3b4c32fe3cf61d1 Mon Sep 17 00:00:00 2001 From: Nguyen Huu Tuan <54979794+nohattee@users.noreply.github.com> Date: Thu, 22 Sep 2022 17:35:21 +0700 Subject: [PATCH 19/56] add some test case which related the logic (#5477) --- schema/schema.go | 8 +++++++ tests/postgres_test.go | 50 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 58 insertions(+) diff --git a/schema/schema.go b/schema/schema.go index 3791237d..42ff5c45 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -239,6 +239,14 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam field.HasDefaultValue = true field.AutoIncrement = true } + case String: + if _, ok := field.TagSettings["PRIMARYKEY"]; !ok { + if !field.HasDefaultValue || field.DefaultValueInterface != nil { + schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field) + } + + field.HasDefaultValue = true + } } } diff --git a/tests/postgres_test.go b/tests/postgres_test.go index 97af6db3..b5b672a9 100644 --- a/tests/postgres_test.go +++ b/tests/postgres_test.go @@ -9,6 +9,56 @@ import ( "gorm.io/gorm" ) +func TestPostgresReturningIDWhichHasStringType(t *testing.T) { + if DB.Dialector.Name() != "postgres" { + t.Skip() + } + + type Yasuo struct { + ID string `gorm:"default:gen_random_uuid()"` + Name string + CreatedAt time.Time `gorm:"type:TIMESTAMP WITHOUT TIME ZONE"` + UpdatedAt time.Time `gorm:"type:TIMESTAMP WITHOUT TIME ZONE;default:current_timestamp"` + } + + if err := DB.Exec("CREATE EXTENSION IF NOT EXISTS pgcrypto;").Error; err != nil { + t.Errorf("Failed to create extension pgcrypto, got error %v", err) + } + + DB.Migrator().DropTable(&Yasuo{}) + + if err := DB.AutoMigrate(&Yasuo{}); err != nil { + t.Fatalf("Failed to migrate for uuid default value, got error: %v", err) + } + + yasuo := Yasuo{Name: "jinzhu"} + if err := DB.Create(&yasuo).Error; err != nil { + t.Fatalf("should be able to create data, but got %v", err) + } + + if yasuo.ID == "" { + t.Fatal("should be able to has ID, but got zero value") + } + + var result Yasuo + if err := DB.First(&result, "id = ?", yasuo.ID).Error; err != nil || yasuo.Name != "jinzhu" { + t.Errorf("No error should happen, but got %v", err) + } + + if err := DB.Where("id = $1", yasuo.ID).First(&Yasuo{}).Error; err != nil || yasuo.Name != "jinzhu" { + t.Errorf("No error should happen, but got %v", err) + } + + yasuo.Name = "jinzhu1" + if err := DB.Save(&yasuo).Error; err != nil { + t.Errorf("Failed to update date, got error %v", err) + } + + if err := DB.First(&result, "id = ?", yasuo.ID).Error; err != nil || yasuo.Name != "jinzhu1" { + t.Errorf("No error should happen, but got %v", err) + } +} + func TestPostgres(t *testing.T) { if DB.Dialector.Name() != "postgres" { t.Skip() From e1dd0dcbc41741e94702d0973df88f4a7afd98e1 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 30 Sep 2022 11:13:01 +0800 Subject: [PATCH 20/56] chore(deps): bump actions/stale from 5 to 6 (#5717) Bumps [actions/stale](https://github.com/actions/stale) from 5 to 6. - [Release notes](https://github.com/actions/stale/releases) - [Changelog](https://github.com/actions/stale/blob/main/CHANGELOG.md) - [Commits](https://github.com/actions/stale/compare/v5...v6) --- updated-dependencies: - dependency-name: actions/stale dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/invalid_question.yml | 2 +- .github/workflows/missing_playground.yml | 2 +- .github/workflows/stale.yml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/invalid_question.yml b/.github/workflows/invalid_question.yml index aa1812d4..bc4487ae 100644 --- a/.github/workflows/invalid_question.yml +++ b/.github/workflows/invalid_question.yml @@ -16,7 +16,7 @@ jobs: ACTIONS_STEP_DEBUG: true steps: - name: Close Stale Issues - uses: actions/stale@v5 + uses: actions/stale@v6 with: repo-token: ${{ secrets.GITHUB_TOKEN }} stale-issue-message: "This issue has been marked as invalid question, please give more information by following the `Question` template, if you believe there is a bug of GORM, please create a pull request that could reproduce the issue on [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground), the issue will be closed in 30 days if no further activity occurs. most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨" diff --git a/.github/workflows/missing_playground.yml b/.github/workflows/missing_playground.yml index c3c92beb..f9f51aa0 100644 --- a/.github/workflows/missing_playground.yml +++ b/.github/workflows/missing_playground.yml @@ -16,7 +16,7 @@ jobs: ACTIONS_STEP_DEBUG: true steps: - name: Close Stale Issues - uses: actions/stale@v5 + uses: actions/stale@v6 with: repo-token: ${{ secrets.GITHUB_TOKEN }} stale-issue-message: "The issue has been automatically marked as stale as it missing playground pull request link, which is important to help others understand your issue effectively and make sure the issue hasn't been fixed on latest master, checkout [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground) for details. it will be closed in 30 days if no further activity occurs. if you are asking question, please use the `Question` template, most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨" diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index af8d3636..a9aff43a 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -16,7 +16,7 @@ jobs: ACTIONS_STEP_DEBUG: true steps: - name: Close Stale Issues - uses: actions/stale@v5 + uses: actions/stale@v6 with: repo-token: ${{ secrets.GITHUB_TOKEN }} stale-issue-message: "This issue has been automatically marked as stale because it has been open 360 days with no activity. Remove stale label or comment or this will be closed in 180 days" From be440e75122de5f7c19e2242a59246a92ce8edfe Mon Sep 17 00:00:00 2001 From: "jesse.tang" <1430482733@qq.com> Date: Fri, 30 Sep 2022 11:14:34 +0800 Subject: [PATCH 21/56] fix possible nil panic in tests (#5720) * fix maybe nil panic * reset code --- tests/callbacks_test.go | 3 +++ tests/transaction_test.go | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/callbacks_test.go b/tests/callbacks_test.go index 2bf9496b..4479da4c 100644 --- a/tests/callbacks_test.go +++ b/tests/callbacks_test.go @@ -113,6 +113,9 @@ func TestCallbacks(t *testing.T) { for idx, data := range datas { db, err := gorm.Open(nil, nil) + if err != nil { + t.Fatal(err) + } callbacks := db.Callback() for _, c := range data.callbacks { diff --git a/tests/transaction_test.go b/tests/transaction_test.go index 0ac04a04..5872da94 100644 --- a/tests/transaction_test.go +++ b/tests/transaction_test.go @@ -102,7 +102,7 @@ func TestTransactionWithBlock(t *testing.T) { return errors.New("the error message") }) - if err.Error() != "the error message" { + if err != nil && err.Error() != "the error message" { t.Fatalf("Transaction return error will equal the block returns error") } From a3cc6c6088c1e2aa8cbd174f4714e7fc6d0acd59 Mon Sep 17 00:00:00 2001 From: Stephano George Date: Fri, 30 Sep 2022 17:18:42 +0800 Subject: [PATCH 22/56] Fix: wrong value when Find with Join with same column name, close #5723, #5711 --- scan.go | 31 ++++++++++++++----------------- tests/go.mod | 4 ++-- tests/joins_test.go | 31 +++++++++++++++++++++++++++++++ 3 files changed, 47 insertions(+), 19 deletions(-) diff --git a/scan.go b/scan.go index 70cd4284..3a753dca 100644 --- a/scan.go +++ b/scan.go @@ -163,11 +163,10 @@ func Scan(rows Rows, db *DB, mode ScanMode) { } default: var ( - fields = make([]*schema.Field, len(columns)) - selectedColumnsMap = make(map[string]int, len(columns)) - joinFields [][2]*schema.Field - sch = db.Statement.Schema - reflectValue = db.Statement.ReflectValue + fields = make([]*schema.Field, len(columns)) + joinFields [][2]*schema.Field + sch = db.Statement.Schema + reflectValue = db.Statement.ReflectValue ) if reflectValue.Kind() == reflect.Interface { @@ -200,26 +199,24 @@ func Scan(rows Rows, db *DB, mode ScanMode) { // Not Pluck if sch != nil { - schFieldsCount := len(sch.Fields) + matchedFieldCount := make(map[string]int, len(columns)) for idx, column := range columns { if field := sch.LookUpField(column); field != nil && field.Readable { - if curIndex, ok := selectedColumnsMap[column]; ok { - fields[idx] = field // handle duplicate fields - offset := curIndex + 1 - // handle sch inconsistent with database - // like Raw(`...`).Scan - if schFieldsCount > offset { - for fieldIndex, selectField := range sch.Fields[offset:] { - if selectField.DBName == column && selectField.Readable { - selectedColumnsMap[column] = curIndex + fieldIndex + 1 + fields[idx] = field + if count, ok := matchedFieldCount[column]; ok { + // handle duplicate fields + for _, selectField := range sch.Fields { + if selectField.DBName == column && selectField.Readable { + if count == 0 { + matchedFieldCount[column]++ fields[idx] = selectField break } + count-- } } } else { - fields[idx] = field - selectedColumnsMap[column] = idx + matchedFieldCount[column] = 1 } } else if names := strings.Split(column, "__"); len(names) > 1 { if rel, ok := sch.Relationships.Relations[names[0]]; ok { diff --git a/tests/go.mod b/tests/go.mod index ebebabc0..c1e1e0ce 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -9,12 +9,12 @@ require ( github.com/jinzhu/now v1.1.5 github.com/lib/pq v1.10.7 github.com/mattn/go-sqlite3 v1.14.15 // indirect - golang.org/x/crypto v0.0.0-20220919173607-35f4265a4bc0 // indirect + golang.org/x/crypto v0.0.0-20220926161630-eccd6366d1be // indirect gorm.io/driver/mysql v1.3.6 gorm.io/driver/postgres v1.3.10 gorm.io/driver/sqlite v1.3.6 gorm.io/driver/sqlserver v1.3.2 - gorm.io/gorm v1.23.9 + gorm.io/gorm v1.23.10 ) replace gorm.io/gorm => ../ diff --git a/tests/joins_test.go b/tests/joins_test.go index 4908e5ba..7519db82 100644 --- a/tests/joins_test.go +++ b/tests/joins_test.go @@ -229,3 +229,34 @@ func TestJoinWithSoftDeleted(t *testing.T) { t.Fatalf("joins NamedPet and Account should not empty:%v", user2) } } + +func TestJoinWithSameColumnName(t *testing.T) { + user := GetUser("TestJoinWithSameColumnName", Config{ + Languages: 1, + Pets: 1, + }) + DB.Create(user) + type UserSpeak struct { + UserID uint + LanguageCode string + } + type Result struct { + User + UserSpeak + Language + Pet + } + + results := make([]Result, 0, 1) + DB.Select("users.*, user_speaks.*, languages.*, pets.*").Table("users").Joins("JOIN user_speaks ON user_speaks.user_id = users.id"). + Joins("JOIN languages ON languages.code = user_speaks.language_code"). + Joins("LEFT OUTER JOIN pets ON pets.user_id = users.id").Find(&results) + + if len(results) == 0 { + t.Fatalf("no record find") + } else if results[0].Pet.UserID == nil || *(results[0].Pet.UserID) != user.ID { + t.Fatalf("wrong user id in pet") + } else if results[0].Pet.Name != user.Pets[0].Name { + t.Fatalf("wrong pet name") + } +} From 0b7113b618584edd76d74e7a73eecc2a28a4d17a Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Fri, 30 Sep 2022 18:13:36 +0800 Subject: [PATCH 23/56] fix: prepare deadlock (#5568) * fix: prepare deadlock * chore[ci skip]: code style * chore[ci skip]: test remove unnecessary params * fix: prepare deadlock * fix: double check prepare * test: more goroutines * chore[ci skip]: improve code comments Co-authored-by: Jinzhu --- gorm.go | 2 +- prepare_stmt.go | 54 ++++++++++++++++++++++++------- tests/prepared_stmt_test.go | 63 +++++++++++++++++++++++++++++++++++++ 3 files changed, 107 insertions(+), 12 deletions(-) diff --git a/gorm.go b/gorm.go index 81b6e2af..589fc4ff 100644 --- a/gorm.go +++ b/gorm.go @@ -179,7 +179,7 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) { preparedStmt := &PreparedStmtDB{ ConnPool: db.ConnPool, - Stmts: map[string]Stmt{}, + Stmts: map[string](*Stmt){}, Mux: &sync.RWMutex{}, PreparedSQL: make([]string, 0, 100), } diff --git a/prepare_stmt.go b/prepare_stmt.go index b062b0d6..3934bb97 100644 --- a/prepare_stmt.go +++ b/prepare_stmt.go @@ -9,10 +9,12 @@ import ( type Stmt struct { *sql.Stmt Transaction bool + prepared chan struct{} + prepareErr error } type PreparedStmtDB struct { - Stmts map[string]Stmt + Stmts map[string]*Stmt PreparedSQL []string Mux *sync.RWMutex ConnPool @@ -46,27 +48,57 @@ func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransact db.Mux.RLock() if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) { db.Mux.RUnlock() - return stmt, nil + // wait for other goroutines prepared + <-stmt.prepared + if stmt.prepareErr != nil { + return Stmt{}, stmt.prepareErr + } + + return *stmt, nil } db.Mux.RUnlock() db.Mux.Lock() - defer db.Mux.Unlock() - // double check if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) { - return stmt, nil - } else if ok { - go stmt.Close() + db.Mux.Unlock() + // wait for other goroutines prepared + <-stmt.prepared + if stmt.prepareErr != nil { + return Stmt{}, stmt.prepareErr + } + + return *stmt, nil } + // cache preparing stmt first + cacheStmt := Stmt{Transaction: isTransaction, prepared: make(chan struct{})} + db.Stmts[query] = &cacheStmt + db.Mux.Unlock() + + // prepare completed + defer close(cacheStmt.prepared) + + // Reason why cannot lock conn.PrepareContext + // suppose the maxopen is 1, g1 is creating record and g2 is querying record. + // 1. g1 begin tx, g1 is requeued because of waiting for the system call, now `db.ConnPool` db.numOpen == 1. + // 2. g2 select lock `conn.PrepareContext(ctx, query)`, now db.numOpen == db.maxOpen , wait for release. + // 3. g1 tx exec insert, wait for unlock `conn.PrepareContext(ctx, query)` to finish tx and release. stmt, err := conn.PrepareContext(ctx, query) - if err == nil { - db.Stmts[query] = Stmt{Stmt: stmt, Transaction: isTransaction} - db.PreparedSQL = append(db.PreparedSQL, query) + if err != nil { + cacheStmt.prepareErr = err + db.Mux.Lock() + delete(db.Stmts, query) + db.Mux.Unlock() + return Stmt{}, err } - return db.Stmts[query], err + db.Mux.Lock() + cacheStmt.Stmt = stmt + db.PreparedSQL = append(db.PreparedSQL, query) + db.Mux.Unlock() + + return cacheStmt, nil } func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (ConnPool, error) { diff --git a/tests/prepared_stmt_test.go b/tests/prepared_stmt_test.go index 86e3630d..c7f251f2 100644 --- a/tests/prepared_stmt_test.go +++ b/tests/prepared_stmt_test.go @@ -2,6 +2,7 @@ package tests_test import ( "context" + "sync" "errors" "testing" "time" @@ -90,6 +91,68 @@ func TestPreparedStmtFromTransaction(t *testing.T) { tx2.Commit() } +func TestPreparedStmtDeadlock(t *testing.T) { + tx, err := OpenTestConnection() + AssertEqual(t, err, nil) + + sqlDB, _ := tx.DB() + sqlDB.SetMaxOpenConns(1) + + tx = tx.Session(&gorm.Session{PrepareStmt: true}) + + wg := sync.WaitGroup{} + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + user := User{Name: "jinzhu"} + tx.Create(&user) + + var result User + tx.First(&result) + wg.Done() + }() + } + wg.Wait() + + conn, ok := tx.ConnPool.(*gorm.PreparedStmtDB) + AssertEqual(t, ok, true) + AssertEqual(t, len(conn.Stmts), 2) + for _, stmt := range conn.Stmts { + if stmt == nil { + t.Fatalf("stmt cannot bee nil") + } + } + + AssertEqual(t, sqlDB.Stats().InUse, 0) +} + +func TestPreparedStmtError(t *testing.T) { + tx, err := OpenTestConnection() + AssertEqual(t, err, nil) + + sqlDB, _ := tx.DB() + sqlDB.SetMaxOpenConns(1) + + tx = tx.Session(&gorm.Session{PrepareStmt: true}) + + wg := sync.WaitGroup{} + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + // err prepare + tag := Tag{Locale: "zh"} + tx.Table("users").Find(&tag) + wg.Done() + }() + } + wg.Wait() + + conn, ok := tx.ConnPool.(*gorm.PreparedStmtDB) + AssertEqual(t, ok, true) + AssertEqual(t, len(conn.Stmts), 0) + AssertEqual(t, sqlDB.Stats().InUse, 0) +} + func TestPreparedStmtInTransaction(t *testing.T) { user := User{Name: "jinzhu"} From 9564b82975844e9e944aefc936968225d9857b86 Mon Sep 17 00:00:00 2001 From: Wen Sun Date: Fri, 7 Oct 2022 14:46:20 +0900 Subject: [PATCH 24/56] Fix OnConstraint builder (#5738) --- clause/on_conflict.go | 34 ++++++++++++++-------------- tests/postgres_test.go | 51 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 68 insertions(+), 17 deletions(-) diff --git a/clause/on_conflict.go b/clause/on_conflict.go index 309c5fcd..032bf4a1 100644 --- a/clause/on_conflict.go +++ b/clause/on_conflict.go @@ -16,27 +16,27 @@ func (OnConflict) Name() string { // Build build onConflict clause func (onConflict OnConflict) Build(builder Builder) { - if len(onConflict.Columns) > 0 { - builder.WriteByte('(') - for idx, column := range onConflict.Columns { - if idx > 0 { - builder.WriteByte(',') - } - builder.WriteQuoted(column) - } - builder.WriteString(`) `) - } - - if len(onConflict.TargetWhere.Exprs) > 0 { - builder.WriteString(" WHERE ") - onConflict.TargetWhere.Build(builder) - builder.WriteByte(' ') - } - if onConflict.OnConstraint != "" { builder.WriteString("ON CONSTRAINT ") builder.WriteString(onConflict.OnConstraint) builder.WriteByte(' ') + } else { + if len(onConflict.Columns) > 0 { + builder.WriteByte('(') + for idx, column := range onConflict.Columns { + if idx > 0 { + builder.WriteByte(',') + } + builder.WriteQuoted(column) + } + builder.WriteString(`) `) + } + + if len(onConflict.TargetWhere.Exprs) > 0 { + builder.WriteString(" WHERE ") + onConflict.TargetWhere.Build(builder) + builder.WriteByte(' ') + } } if onConflict.DoNothing { diff --git a/tests/postgres_test.go b/tests/postgres_test.go index b5b672a9..f45b2618 100644 --- a/tests/postgres_test.go +++ b/tests/postgres_test.go @@ -7,6 +7,7 @@ import ( "github.com/google/uuid" "github.com/lib/pq" "gorm.io/gorm" + "gorm.io/gorm/clause" ) func TestPostgresReturningIDWhichHasStringType(t *testing.T) { @@ -148,3 +149,53 @@ func TestMany2ManyWithDefaultValueUUID(t *testing.T) { t.Errorf("Failed, got error: %v", err) } } + +func TestPostgresOnConstraint(t *testing.T) { + if DB.Dialector.Name() != "postgres" { + t.Skip() + } + + type Thing struct { + gorm.Model + SomeID string + OtherID string + Data string + } + + DB.Migrator().DropTable(&Thing{}) + DB.Migrator().CreateTable(&Thing{}) + if err := DB.Exec("ALTER TABLE things ADD CONSTRAINT some_id_other_id_unique UNIQUE (some_id, other_id)").Error; err != nil { + t.Error(err) + } + + thing := Thing{ + SomeID: "1234", + OtherID: "1234", + Data: "something", + } + + DB.Create(&thing) + + thing2 := Thing{ + SomeID: "1234", + OtherID: "1234", + Data: "something else", + } + + result := DB.Clauses(clause.OnConflict{ + OnConstraint: "some_id_other_id_unique", + UpdateAll: true, + }).Create(&thing2) + if result.Error != nil { + t.Errorf("creating second thing: %v", result.Error) + } + + var things []Thing + if err := DB.Find(&things).Error; err != nil { + t.Errorf("Failed, got error: %v", err) + } + + if len(things) > 1 { + t.Errorf("expected 1 thing got more") + } +} From 4b22a55a752d4284a72545a1611d651b364b3482 Mon Sep 17 00:00:00 2001 From: "jesse.tang" <1430482733@qq.com> Date: Fri, 7 Oct 2022 18:29:28 +0800 Subject: [PATCH 25/56] fix: primaryFields are overwritten (#5721) --- schema/relationship.go | 24 +++++++++++------------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/schema/relationship.go b/schema/relationship.go index bb8aeb64..9436f283 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -403,33 +403,30 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu case guessBelongs: primarySchema, foreignSchema = relation.FieldSchema, schema case guessEmbeddedBelongs: - if field.OwnerSchema != nil { - primarySchema, foreignSchema = relation.FieldSchema, field.OwnerSchema - } else { + if field.OwnerSchema == nil { reguessOrErr() return } + primarySchema, foreignSchema = relation.FieldSchema, field.OwnerSchema case guessHas: case guessEmbeddedHas: - if field.OwnerSchema != nil { - primarySchema, foreignSchema = field.OwnerSchema, relation.FieldSchema - } else { + if field.OwnerSchema == nil { reguessOrErr() return } + primarySchema, foreignSchema = field.OwnerSchema, relation.FieldSchema } if len(relation.foreignKeys) > 0 { for _, foreignKey := range relation.foreignKeys { - if f := foreignSchema.LookUpField(foreignKey); f != nil { - foreignFields = append(foreignFields, f) - } else { + f := foreignSchema.LookUpField(foreignKey) + if f == nil { reguessOrErr() return } + foreignFields = append(foreignFields, f) } } else { - var primaryFields []*Field var primarySchemaName = primarySchema.Name if primarySchemaName == "" { primarySchemaName = relation.FieldSchema.Name @@ -466,10 +463,11 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu } } - if len(foreignFields) == 0 { + switch { + case len(foreignFields) == 0: reguessOrErr() return - } else if len(relation.primaryKeys) > 0 { + case len(relation.primaryKeys) > 0: for idx, primaryKey := range relation.primaryKeys { if f := primarySchema.LookUpField(primaryKey); f != nil { if len(primaryFields) < idx+1 { @@ -483,7 +481,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu return } } - } else if len(primaryFields) == 0 { + case len(primaryFields) == 0: if len(foreignFields) == 1 && primarySchema.PrioritizedPrimaryField != nil { primaryFields = append(primaryFields, primarySchema.PrioritizedPrimaryField) } else if len(primarySchema.PrimaryFields) == len(foreignFields) { From e8f48b5c155b6fbf2e1fe6a554e2280f62af21a7 Mon Sep 17 00:00:00 2001 From: robhafner Date: Fri, 7 Oct 2022 08:14:14 -0400 Subject: [PATCH 26/56] fix: limit=0 results (#5735) (#5736) --- chainable_api.go | 2 +- clause/benchmarks_test.go | 3 ++- clause/limit.go | 10 +++++----- clause/limit_test.go | 20 ++++++++++++++------ finisher_api.go | 4 +++- 5 files changed, 25 insertions(+), 14 deletions(-) diff --git a/chainable_api.go b/chainable_api.go index 68b4d1aa..ab3a1a32 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -244,7 +244,7 @@ func (db *DB) Order(value interface{}) (tx *DB) { // Limit specify the number of records to be retrieved func (db *DB) Limit(limit int) (tx *DB) { tx = db.getInstance() - tx.Statement.AddClause(clause.Limit{Limit: limit}) + tx.Statement.AddClause(clause.Limit{Limit: &limit}) return } diff --git a/clause/benchmarks_test.go b/clause/benchmarks_test.go index e08677ac..34d5df41 100644 --- a/clause/benchmarks_test.go +++ b/clause/benchmarks_test.go @@ -29,6 +29,7 @@ func BenchmarkSelect(b *testing.B) { func BenchmarkComplexSelect(b *testing.B) { user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) + limit10 := 10 for i := 0; i < b.N; i++ { stmt := gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}} clauses := []clause.Interface{ @@ -43,7 +44,7 @@ func BenchmarkComplexSelect(b *testing.B) { clause.Or(clause.Gt{Column: "score", Value: 100}, clause.Like{Column: "name", Value: "%linus%"}), }}, clause.GroupBy{Columns: []clause.Column{{Name: "role"}}, Having: []clause.Expression{clause.Eq{"role", "admin"}}}, - clause.Limit{Limit: 10, Offset: 20}, + clause.Limit{Limit: &limit10, Offset: 20}, clause.OrderBy{Columns: []clause.OrderByColumn{{Column: clause.PrimaryColumn, Desc: true}}}, } diff --git a/clause/limit.go b/clause/limit.go index 184f6025..3ede7385 100644 --- a/clause/limit.go +++ b/clause/limit.go @@ -4,7 +4,7 @@ import "strconv" // Limit limit clause type Limit struct { - Limit int + Limit *int Offset int } @@ -15,12 +15,12 @@ func (limit Limit) Name() string { // Build build where clause func (limit Limit) Build(builder Builder) { - if limit.Limit > 0 { + if limit.Limit != nil && *limit.Limit >= 0 { builder.WriteString("LIMIT ") - builder.WriteString(strconv.Itoa(limit.Limit)) + builder.WriteString(strconv.Itoa(*limit.Limit)) } if limit.Offset > 0 { - if limit.Limit > 0 { + if limit.Limit != nil && *limit.Limit >= 0 { builder.WriteByte(' ') } builder.WriteString("OFFSET ") @@ -33,7 +33,7 @@ func (limit Limit) MergeClause(clause *Clause) { clause.Name = "" if v, ok := clause.Expression.(Limit); ok { - if limit.Limit == 0 && v.Limit != 0 { + if (limit.Limit == nil || *limit.Limit == 0) && (v.Limit != nil && *v.Limit != 0) { limit.Limit = v.Limit } diff --git a/clause/limit_test.go b/clause/limit_test.go index c26294aa..79065ab6 100644 --- a/clause/limit_test.go +++ b/clause/limit_test.go @@ -8,6 +8,10 @@ import ( ) func TestLimit(t *testing.T) { + limit0 := 0 + limit10 := 10 + limit50 := 50 + limitNeg10 := -10 results := []struct { Clauses []clause.Interface Result string @@ -15,11 +19,15 @@ func TestLimit(t *testing.T) { }{ { []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{ - Limit: 10, + Limit: &limit10, Offset: 20, }}, "SELECT * FROM `users` LIMIT 10 OFFSET 20", nil, }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit0}}, + "SELECT * FROM `users` LIMIT 0", nil, + }, { []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}}, "SELECT * FROM `users` OFFSET 20", nil, @@ -29,23 +37,23 @@ func TestLimit(t *testing.T) { "SELECT * FROM `users` OFFSET 30", nil, }, { - []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}, clause.Limit{Limit: 10}}, + []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}, clause.Limit{Limit: &limit10}}, "SELECT * FROM `users` LIMIT 10 OFFSET 20", nil, }, { - []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}}, + []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}}, "SELECT * FROM `users` LIMIT 10 OFFSET 30", nil, }, { - []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Offset: -10}}, + []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Offset: -10}}, "SELECT * FROM `users` LIMIT 10", nil, }, { - []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: -10}}, + []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: &limitNeg10}}, "SELECT * FROM `users` OFFSET 30", nil, }, { - []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: 50}}, + []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: &limit50}}, "SELECT * FROM `users` LIMIT 50 OFFSET 30", nil, }, } diff --git a/finisher_api.go b/finisher_api.go index 835a6984..5516c0a1 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -185,7 +185,9 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat var totalSize int if c, ok := tx.Statement.Clauses["LIMIT"]; ok { if limit, ok := c.Expression.(clause.Limit); ok { - totalSize = limit.Limit + if limit.Limit != nil { + totalSize = *limit.Limit + } if totalSize > 0 && batchSize > totalSize { batchSize = totalSize From 34fbe84580290c32ba006b714669bb356224cb07 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 7 Oct 2022 21:18:37 +0800 Subject: [PATCH 27/56] Add TableName with NamingStrategy support, close #5726 --- schema/schema.go | 7 +++++++ tests/go.mod | 12 +++++------- tests/table_test.go | 26 ++++++++++++++++++++++++++ utils/tests/dummy_dialecter.go | 10 +++++++++- 4 files changed, 47 insertions(+), 8 deletions(-) diff --git a/schema/schema.go b/schema/schema.go index 42ff5c45..9b3d30f6 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -71,6 +71,10 @@ type Tabler interface { TableName() string } +type TablerWithNamer interface { + TableName(Namer) string +} + // Parse get data type from dialector func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) { return ParseWithSpecialTableName(dest, cacheStore, namer, "") @@ -125,6 +129,9 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam if tabler, ok := modelValue.Interface().(Tabler); ok { tableName = tabler.TableName() } + if tabler, ok := modelValue.Interface().(TablerWithNamer); ok { + tableName = tabler.TableName(namer) + } if en, ok := namer.(embeddedNamer); ok { tableName = en.Table } diff --git a/tests/go.mod b/tests/go.mod index c1e1e0ce..d28c4bb9 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -3,17 +3,15 @@ module gorm.io/gorm/tests go 1.16 require ( - github.com/denisenkom/go-mssqldb v0.12.2 // indirect github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect github.com/google/uuid v1.3.0 github.com/jinzhu/now v1.1.5 github.com/lib/pq v1.10.7 - github.com/mattn/go-sqlite3 v1.14.15 // indirect - golang.org/x/crypto v0.0.0-20220926161630-eccd6366d1be // indirect - gorm.io/driver/mysql v1.3.6 - gorm.io/driver/postgres v1.3.10 - gorm.io/driver/sqlite v1.3.6 - gorm.io/driver/sqlserver v1.3.2 + golang.org/x/crypto v0.0.0-20221005025214-4161e89ecf1b // indirect + gorm.io/driver/mysql v1.4.0 + gorm.io/driver/postgres v1.4.1 + gorm.io/driver/sqlite v1.4.1 + gorm.io/driver/sqlserver v1.4.0 gorm.io/gorm v1.23.10 ) diff --git a/tests/table_test.go b/tests/table_test.go index 0289b7b8..f538c691 100644 --- a/tests/table_test.go +++ b/tests/table_test.go @@ -5,6 +5,8 @@ import ( "testing" "gorm.io/gorm" + "gorm.io/gorm/schema" + "gorm.io/gorm/utils/tests" . "gorm.io/gorm/utils/tests" ) @@ -145,3 +147,27 @@ func TestTableWithAllFields(t *testing.T) { AssertEqual(t, r.Statement.Vars, []interface{}{2, 4, 1, 3}) } + +type UserWithTableNamer struct { + gorm.Model + Name string +} + +func (UserWithTableNamer) TableName(namer schema.Namer) string { + return namer.TableName("user") +} + +func TestTableWithNamer(t *testing.T) { + var db, _ = gorm.Open(tests.DummyDialector{}, &gorm.Config{ + NamingStrategy: schema.NamingStrategy{ + TablePrefix: "t_", + }}) + + sql := db.ToSQL(func(tx *gorm.DB) *gorm.DB { + return tx.Model(&UserWithTableNamer{}).Find(&UserWithTableNamer{}) + }) + + if !regexp.MustCompile("SELECT \\* FROM `t_users`").MatchString(sql) { + t.Errorf("Table with namer, got %v", sql) + } +} diff --git a/utils/tests/dummy_dialecter.go b/utils/tests/dummy_dialecter.go index 2990c20f..c89b944a 100644 --- a/utils/tests/dummy_dialecter.go +++ b/utils/tests/dummy_dialecter.go @@ -2,6 +2,7 @@ package tests import ( "gorm.io/gorm" + "gorm.io/gorm/callbacks" "gorm.io/gorm/clause" "gorm.io/gorm/logger" "gorm.io/gorm/schema" @@ -13,7 +14,14 @@ func (DummyDialector) Name() string { return "dummy" } -func (DummyDialector) Initialize(*gorm.DB) error { +func (DummyDialector) Initialize(db *gorm.DB) error { + callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{ + CreateClauses: []string{"INSERT", "VALUES", "ON CONFLICT", "RETURNING"}, + UpdateClauses: []string{"UPDATE", "SET", "WHERE", "RETURNING"}, + DeleteClauses: []string{"DELETE", "FROM", "WHERE", "RETURNING"}, + LastInsertIDReversed: true, + }) + return nil } From 983e96f14253c071b8ab3fb96b4c9f103ad39e1c Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 8 Oct 2022 16:04:57 +0800 Subject: [PATCH 28/56] Add tests for alter column type --- tests/go.mod | 4 ++-- tests/migrate_test.go | 2 +- tests/postgres_test.go | 16 ++++++++++++++++ 3 files changed, 19 insertions(+), 3 deletions(-) diff --git a/tests/go.mod b/tests/go.mod index d28c4bb9..3919a838 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -9,10 +9,10 @@ require ( github.com/lib/pq v1.10.7 golang.org/x/crypto v0.0.0-20221005025214-4161e89ecf1b // indirect gorm.io/driver/mysql v1.4.0 - gorm.io/driver/postgres v1.4.1 + gorm.io/driver/postgres v1.4.3 gorm.io/driver/sqlite v1.4.1 gorm.io/driver/sqlserver v1.4.0 - gorm.io/gorm v1.23.10 + gorm.io/gorm v1.24.0 ) replace gorm.io/gorm => ../ diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 32e84e77..b918b4b5 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -400,7 +400,7 @@ func TestMigrateColumns(t *testing.T) { t.Fatalf("column code unique should be correct, name: %v, column: %#v", columnType.Name(), columnType) } if v, ok := columnType.DefaultValue(); !sqlserver && (!ok || v != "hello") { - t.Fatalf("column code default value should be correct, name: %v, column: %#v", columnType.Name(), columnType) + t.Fatalf("column code default value should be correct, name: %v, column: %#v, default value: %v", columnType.Name(), columnType, v) } if v, ok := columnType.Comment(); !sqlite && !sqlserver && (!ok || v != "my code2") { t.Fatalf("column code comment should be correct, name: %v, column: %#v", columnType.Name(), columnType) diff --git a/tests/postgres_test.go b/tests/postgres_test.go index f45b2618..794ab8f7 100644 --- a/tests/postgres_test.go +++ b/tests/postgres_test.go @@ -8,6 +8,7 @@ import ( "github.com/lib/pq" "gorm.io/gorm" "gorm.io/gorm/clause" + . "gorm.io/gorm/utils/tests" ) func TestPostgresReturningIDWhichHasStringType(t *testing.T) { @@ -199,3 +200,18 @@ func TestPostgresOnConstraint(t *testing.T) { t.Errorf("expected 1 thing got more") } } + +type CompanyNew struct { + ID int + Name int +} + +func TestAlterColumnDataType(t *testing.T) { + DB.AutoMigrate(Company{}) + + if err := DB.Table("companies").Migrator().AlterColumn(CompanyNew{}, "name"); err != nil { + t.Fatalf("failed to alter column from string to int, got error %v", err) + } + + DB.AutoMigrate(Company{}) +} From e93dc3426e8cb0a99091e2267ef2adf1cc86b4b5 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 8 Oct 2022 17:16:32 +0800 Subject: [PATCH 29/56] Test postgres autoincrement check --- tests/go.mod | 2 +- tests/postgres_test.go | 39 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/tests/go.mod b/tests/go.mod index 3919a838..0160b2a6 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -9,7 +9,7 @@ require ( github.com/lib/pq v1.10.7 golang.org/x/crypto v0.0.0-20221005025214-4161e89ecf1b // indirect gorm.io/driver/mysql v1.4.0 - gorm.io/driver/postgres v1.4.3 + gorm.io/driver/postgres v1.4.4 gorm.io/driver/sqlite v1.4.1 gorm.io/driver/sqlserver v1.4.0 gorm.io/gorm v1.24.0 diff --git a/tests/postgres_test.go b/tests/postgres_test.go index 794ab8f7..44cac6bf 100644 --- a/tests/postgres_test.go +++ b/tests/postgres_test.go @@ -112,6 +112,45 @@ func TestPostgres(t *testing.T) { if err := DB.First(&result, "id = ?", harumph.ID).Error; err != nil || harumph.Name != "jinzhu1" { t.Errorf("No error should happen, but got %v", err) } + + DB.Migrator().DropTable("log_usage") + + if err := DB.Exec(` +CREATE TABLE public.log_usage ( + log_id bigint NOT NULL +); + +ALTER TABLE public.log_usage ALTER COLUMN log_id ADD GENERATED BY DEFAULT AS IDENTITY ( + SEQUENCE NAME public.log_usage_log_id_seq + START WITH 1 + INCREMENT BY 1 + NO MINVALUE + NO MAXVALUE + CACHE 1 +); + `).Error; err != nil { + t.Fatalf("failed to create table, got error %v", err) + } + + columns, err := DB.Migrator().ColumnTypes("log_usage") + if err != nil { + t.Fatalf("failed to get columns, got error %v", err) + } + + hasLogID := false + for _, column := range columns { + if column.Name() == "log_id" { + hasLogID = true + autoIncrement, ok := column.AutoIncrement() + if !ok || !autoIncrement { + t.Fatalf("column log_id should be auto incrementment") + } + } + } + + if !hasLogID { + t.Fatalf("failed to found column log_id") + } } type Post struct { From 2c56954cb12dd33fc8f1875a735091d61daff702 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 8 Oct 2022 20:48:22 +0800 Subject: [PATCH 30/56] tests mariadb with returning support --- scan.go | 2 +- tests/connpool_test.go | 2 +- tests/go.mod | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/scan.go b/scan.go index 3a753dca..0a26ce4b 100644 --- a/scan.go +++ b/scan.go @@ -317,4 +317,4 @@ func Scan(rows Rows, db *DB, mode ScanMode) { if db.RowsAffected == 0 && db.Statement.RaiseErrorOnNotFound && db.Error == nil { db.AddError(ErrRecordNotFound) } -} \ No newline at end of file +} diff --git a/tests/connpool_test.go b/tests/connpool_test.go index fbae2294..42e029bc 100644 --- a/tests/connpool_test.go +++ b/tests/connpool_test.go @@ -116,7 +116,7 @@ func TestConnPoolWrapper(t *testing.T) { } }() - db, err := gorm.Open(mysql.New(mysql.Config{Conn: conn})) + db, err := gorm.Open(mysql.New(mysql.Config{Conn: conn, DisableWithReturning: true})) if err != nil { t.Fatalf("Should open db success, but got %v", err) } diff --git a/tests/go.mod b/tests/go.mod index 0160b2a6..bf59e8d2 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -8,7 +8,7 @@ require ( github.com/jinzhu/now v1.1.5 github.com/lib/pq v1.10.7 golang.org/x/crypto v0.0.0-20221005025214-4161e89ecf1b // indirect - gorm.io/driver/mysql v1.4.0 + gorm.io/driver/mysql v1.4.1 gorm.io/driver/postgres v1.4.4 gorm.io/driver/sqlite v1.4.1 gorm.io/driver/sqlserver v1.4.0 From 08aa2f9888dcd3c950943d09d0d7aaef1b1dcc33 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 14 Oct 2022 20:30:28 +0800 Subject: [PATCH 31/56] Update README --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 312a3a59..5bb1be37 100644 --- a/README.md +++ b/README.md @@ -30,7 +30,7 @@ The fantastic ORM library for Golang, aims to be developer friendly. ## Getting Started * GORM Guides [https://gorm.io](https://gorm.io) -* GORM Gen [gorm/gen](https://github.com/go-gorm/gen#gormgen) +* Gen Guides [https://gorm.io/gen/index.html](https://gorm.io/gen/index.html) ## Contributing From aa4312ee74db5a23d459d487b43a4a79d341c936 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 13 Oct 2022 15:57:10 +0800 Subject: [PATCH 32/56] Don't display any GORM related package path as source --- utils/utils.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils/utils.go b/utils/utils.go index 296917b9..90b4c8ea 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -16,7 +16,7 @@ var gormSourceDir string func init() { _, file, _, _ := runtime.Caller(0) // compatible solution to get gorm source directory with various operating systems - gormSourceDir = regexp.MustCompile(`utils.utils\.go`).ReplaceAllString(file, "") + gormSourceDir = regexp.MustCompile(`gorm.utils.utils\.go`).ReplaceAllString(file, "") } // FileWithLineNum return the file name and line number of the current file From 2a788fb20c3cbc73e96aa422b7477fe62d23964a Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 17 Oct 2022 17:01:00 +0800 Subject: [PATCH 33/56] Upgrade tests go.mod --- tests/go.mod | 10 +++++----- tests/sql_builder_test.go | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/go.mod b/tests/go.mod index bf59e8d2..2fef9d97 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -3,15 +3,15 @@ module gorm.io/gorm/tests go 1.16 require ( - github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect github.com/google/uuid v1.3.0 github.com/jinzhu/now v1.1.5 github.com/lib/pq v1.10.7 - golang.org/x/crypto v0.0.0-20221005025214-4161e89ecf1b // indirect - gorm.io/driver/mysql v1.4.1 + golang.org/x/crypto v0.0.0-20221012134737-56aed061732a // indirect + golang.org/x/text v0.3.8 // indirect + gorm.io/driver/mysql v1.4.3 gorm.io/driver/postgres v1.4.4 - gorm.io/driver/sqlite v1.4.1 - gorm.io/driver/sqlserver v1.4.0 + gorm.io/driver/sqlite v1.4.2 + gorm.io/driver/sqlserver v1.4.1 gorm.io/gorm v1.24.0 ) diff --git a/tests/sql_builder_test.go b/tests/sql_builder_test.go index a9b920dc..b10142fa 100644 --- a/tests/sql_builder_test.go +++ b/tests/sql_builder_test.go @@ -367,7 +367,7 @@ func TestToSQL(t *testing.T) { t.Skip("Skip SQL Server for this test, because it too difference with other dialects.") } - date, _ := time.Parse("2006-01-02", "2021-10-18") + date, _ := time.ParseInLocation("2006-01-02", "2021-10-18", time.Local) // find sql := DB.ToSQL(func(tx *gorm.DB) *gorm.DB { From 186e8a9e14578c63715444d217294065be072805 Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Tue, 18 Oct 2022 11:58:42 +0800 Subject: [PATCH 34/56] fix: association without pks (#5779) --- callbacks/associations.go | 10 +++++++-- tests/associations_test.go | 42 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+), 2 deletions(-) diff --git a/callbacks/associations.go b/callbacks/associations.go index 00e00fcc..9d7c1412 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -208,7 +208,10 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { cacheKey := utils.ToStringKey(relPrimaryValues...) if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] { - identityMap[cacheKey] = true + if cacheKey != "" { // has primary fields + identityMap[cacheKey] = true + } + if isPtr { elems = reflect.Append(elems, elem) } else { @@ -294,7 +297,10 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { cacheKey := utils.ToStringKey(relPrimaryValues...) if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] { - identityMap[cacheKey] = true + if cacheKey != "" { // has primary fields + identityMap[cacheKey] = true + } + distinctElems = reflect.Append(distinctElems, elem) } diff --git a/tests/associations_test.go b/tests/associations_test.go index 42b32afc..4c9076da 100644 --- a/tests/associations_test.go +++ b/tests/associations_test.go @@ -348,3 +348,45 @@ func TestAssociationEmptyQueryClause(t *testing.T) { AssertEqual(t, len(orgs), 0) } } + +type AssociationEmptyUser struct { + ID uint + Name string + Pets []AssociationEmptyPet +} + +type AssociationEmptyPet struct { + AssociationEmptyUserID *uint `gorm:"uniqueIndex:uniq_user_id_name"` + Name string `gorm:"uniqueIndex:uniq_user_id_name;size:256"` +} + +func TestAssociationEmptyPrimaryKey(t *testing.T) { + if DB.Dialector.Name() != "mysql" { + t.Skip() + } + DB.Migrator().DropTable(&AssociationEmptyUser{}, &AssociationEmptyPet{}) + DB.AutoMigrate(&AssociationEmptyUser{}, &AssociationEmptyPet{}) + + id := uint(100) + user := AssociationEmptyUser{ + ID: id, + Name: "jinzhu", + Pets: []AssociationEmptyPet{ + {AssociationEmptyUserID: &id, Name: "bar"}, + {AssociationEmptyUserID: &id, Name: "foo"}, + }, + } + + err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Create(&user).Error + if err != nil { + t.Fatalf("Failed to create, got error: %v", err) + } + + var result AssociationEmptyUser + err = DB.Preload("Pets").First(&result, &id).Error + if err != nil { + t.Fatalf("Failed to find, got error: %v", err) + } + + AssertEqual(t, result, user) +} From ab5f80a8d81c1955e92224b24dfc9bc8c7d387a0 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 18 Oct 2022 15:44:47 +0800 Subject: [PATCH 35/56] Save as NULL for nil object serialized into json --- schema/serializer.go | 3 +++ tests/go.mod | 4 ++-- tests/serializer_test.go | 3 ++- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/schema/serializer.go b/schema/serializer.go index 00a4f85f..fef39d9b 100644 --- a/schema/serializer.go +++ b/schema/serializer.go @@ -100,6 +100,9 @@ func (JSONSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, // Value implements serializer interface func (JSONSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) { result, err := json.Marshal(fieldValue) + if string(result) == "null" { + return nil, err + } return string(result), err } diff --git a/tests/go.mod b/tests/go.mod index 2fef9d97..9c87ca34 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -7,10 +7,10 @@ require ( github.com/jinzhu/now v1.1.5 github.com/lib/pq v1.10.7 golang.org/x/crypto v0.0.0-20221012134737-56aed061732a // indirect - golang.org/x/text v0.3.8 // indirect + golang.org/x/text v0.4.0 // indirect gorm.io/driver/mysql v1.4.3 gorm.io/driver/postgres v1.4.4 - gorm.io/driver/sqlite v1.4.2 + gorm.io/driver/sqlite v1.4.3 gorm.io/driver/sqlserver v1.4.1 gorm.io/gorm v1.24.0 ) diff --git a/tests/serializer_test.go b/tests/serializer_test.go index 946536bf..17bfefe2 100644 --- a/tests/serializer_test.go +++ b/tests/serializer_test.go @@ -18,6 +18,7 @@ type SerializerStruct struct { gorm.Model Name []byte `gorm:"json"` Roles Roles `gorm:"serializer:json"` + Roles2 *Roles `gorm:"serializer:json"` Contracts map[string]interface{} `gorm:"serializer:json"` JobInfo Job `gorm:"type:bytes;serializer:gob"` CreatedTime int64 `gorm:"serializer:unixtime;type:time"` // store time in db, use int as field type @@ -108,7 +109,7 @@ func TestSerializer(t *testing.T) { } var result SerializerStruct - if err := DB.First(&result, data.ID).Error; err != nil { + if err := DB.Where("roles2 IS NULL").First(&result, data.ID).Error; err != nil { t.Fatalf("failed to query data, got error %v", err) } From a0f4d3f7d207b2103b5f91e9758b1ac6a94056ba Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 18 Oct 2022 16:25:39 +0800 Subject: [PATCH 36/56] Save as empty string for not nullable nil field serialized into json --- schema/serializer.go | 3 +++ tests/serializer_test.go | 3 ++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/schema/serializer.go b/schema/serializer.go index fef39d9b..9a6aa4fc 100644 --- a/schema/serializer.go +++ b/schema/serializer.go @@ -101,6 +101,9 @@ func (JSONSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, func (JSONSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) { result, err := json.Marshal(fieldValue) if string(result) == "null" { + if field.TagSettings["NOT NULL"] != "" { + return "", nil + } return nil, err } return string(result), err diff --git a/tests/serializer_test.go b/tests/serializer_test.go index 17bfefe2..a040a4db 100644 --- a/tests/serializer_test.go +++ b/tests/serializer_test.go @@ -19,6 +19,7 @@ type SerializerStruct struct { Name []byte `gorm:"json"` Roles Roles `gorm:"serializer:json"` Roles2 *Roles `gorm:"serializer:json"` + Roles3 *Roles `gorm:"serializer:json;not null"` Contracts map[string]interface{} `gorm:"serializer:json"` JobInfo Job `gorm:"type:bytes;serializer:gob"` CreatedTime int64 `gorm:"serializer:unixtime;type:time"` // store time in db, use int as field type @@ -109,7 +110,7 @@ func TestSerializer(t *testing.T) { } var result SerializerStruct - if err := DB.Where("roles2 IS NULL").First(&result, data.ID).Error; err != nil { + if err := DB.Where("roles2 IS NULL AND roles3 = ?", "").First(&result, data.ID).Error; err != nil { t.Fatalf("failed to query data, got error %v", err) } From 62593cfad03ebf1e6cae30bac010655b4a28ff67 Mon Sep 17 00:00:00 2001 From: viatoriche / Maxim Panfilov Date: Tue, 18 Oct 2022 17:28:06 +0800 Subject: [PATCH 37/56] add test: TestAutoMigrateInt8PG: shouldn't execute ALTER COLUMN TYPE smallint, close #5762 --- migrator/migrator.go | 55 +++++++++++++++++++++---------------------- tests/migrate_test.go | 40 +++++++++++++++++++++++++++++++ tests/tracer_test.go | 34 ++++++++++++++++++++++++++ 3 files changed, 101 insertions(+), 28 deletions(-) create mode 100644 tests/tracer_test.go diff --git a/migrator/migrator.go b/migrator/migrator.go index 29c0c00c..9f8e3db8 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -406,17 +406,14 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy fullDataType := strings.TrimSpace(strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL)) realDataType := strings.ToLower(columnType.DatabaseTypeName()) - alterColumn := false + var ( + alterColumn, isSameType bool + ) if !field.PrimaryKey { // check type - var isSameType bool - if strings.HasPrefix(fullDataType, realDataType) { - isSameType = true - } - - // check type aliases - if !isSameType { + if !strings.HasPrefix(fullDataType, realDataType) { + // check type aliases aliases := m.DB.Migrator().GetTypeAliases(realDataType) for _, alias := range aliases { if strings.HasPrefix(fullDataType, alias) { @@ -424,32 +421,34 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy break } } - } - if !isSameType { - alterColumn = true - } - } - - // check size - if length, ok := columnType.Length(); length != int64(field.Size) { - if length > 0 && field.Size > 0 { - alterColumn = true - } else { - // has size in data type and not equal - // Since the following code is frequently called in the for loop, reg optimization is needed here - matches2 := regFullDataType.FindAllStringSubmatch(fullDataType, -1) - if !field.PrimaryKey && - (len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length) && ok) { + if !isSameType { alterColumn = true } } } - // check precision - if precision, _, ok := columnType.DecimalSize(); ok && int64(field.Precision) != precision { - if regexp.MustCompile(fmt.Sprintf("[^0-9]%d[^0-9]", field.Precision)).MatchString(m.DataTypeOf(field)) { - alterColumn = true + if !isSameType { + // check size + if length, ok := columnType.Length(); length != int64(field.Size) { + if length > 0 && field.Size > 0 { + alterColumn = true + } else { + // has size in data type and not equal + // Since the following code is frequently called in the for loop, reg optimization is needed here + matches2 := regFullDataType.FindAllStringSubmatch(fullDataType, -1) + if !field.PrimaryKey && + (len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length) && ok) { + alterColumn = true + } + } + } + + // check precision + if precision, _, ok := columnType.DecimalSize(); ok && int64(field.Precision) != precision { + if regexp.MustCompile(fmt.Sprintf("[^0-9]%d[^0-9]", field.Precision)).MatchString(m.DataTypeOf(field)) { + alterColumn = true + } } } diff --git a/tests/migrate_test.go b/tests/migrate_test.go index b918b4b5..8718aa57 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -1,6 +1,7 @@ package tests_test import ( + "context" "fmt" "math/rand" "reflect" @@ -9,6 +10,7 @@ import ( "time" "gorm.io/driver/postgres" + "gorm.io/gorm" "gorm.io/gorm/schema" . "gorm.io/gorm/utils/tests" @@ -72,6 +74,44 @@ func TestMigrate(t *testing.T) { t.Fatalf("Failed to find index for many2many for %v %v", indexes[0], indexes[1]) } } + +} + +func TestAutoMigrateInt8PG(t *testing.T) { + if DB.Dialector.Name() != "postgres" { + return + } + + type Smallint int8 + + type MigrateInt struct { + Int8 Smallint + } + + tracer := Tracer{ + Logger: DB.Config.Logger, + Test: func(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) { + sql, _ := fc() + if strings.HasPrefix(sql, "ALTER TABLE \"migrate_ints\" ALTER COLUMN \"int8\" TYPE smallint") { + t.Fatalf("shouldn't execute ALTER COLUMN TYPE if such type is already existed in DB schema: sql: %s", sql) + } + }, + } + + DB.Migrator().DropTable(&MigrateInt{}) + + // The first AutoMigrate to make table with field with correct type + if err := DB.AutoMigrate(&MigrateInt{}); err != nil { + t.Fatalf("Failed to auto migrate: error: %v", err) + } + + // make new session to set custom logger tracer + session := DB.Session(&gorm.Session{Logger: tracer}) + + // The second AutoMigrate to catch an error + if err := session.AutoMigrate(&MigrateInt{}); err != nil { + t.Fatalf("Failed to auto migrate: error: %v", err) + } } func TestAutoMigrateSelfReferential(t *testing.T) { diff --git a/tests/tracer_test.go b/tests/tracer_test.go new file mode 100644 index 00000000..3e9a4052 --- /dev/null +++ b/tests/tracer_test.go @@ -0,0 +1,34 @@ +package tests_test + +import ( + "context" + "time" + + "gorm.io/gorm/logger" +) + +type Tracer struct { + Logger logger.Interface + Test func(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) +} + +func (S Tracer) LogMode(level logger.LogLevel) logger.Interface { + return S.Logger.LogMode(level) +} + +func (S Tracer) Info(ctx context.Context, s string, i ...interface{}) { + S.Logger.Info(ctx, s, i...) +} + +func (S Tracer) Warn(ctx context.Context, s string, i ...interface{}) { + S.Logger.Warn(ctx, s, i...) +} + +func (S Tracer) Error(ctx context.Context, s string, i ...interface{}) { + S.Logger.Error(ctx, s, i...) +} + +func (S Tracer) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) { + S.Logger.Trace(ctx, begin, fc, err) + S.Test(ctx, begin, fc, err) +} From 3f20a543fad5f57016ef7a6c342536b0fcce6016 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 18 Oct 2022 18:01:55 +0800 Subject: [PATCH 38/56] Support use clause.Interface as query params --- statement.go | 4 ++++ tests/sql_builder_test.go | 8 ++++++++ 2 files changed, 12 insertions(+) diff --git a/statement.go b/statement.go index cc26fe37..d05d299e 100644 --- a/statement.go +++ b/statement.go @@ -179,6 +179,10 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) { } else { stmt.AddVar(writer, v.GormValue(stmt.Context, stmt.DB)) } + case clause.Interface: + c := clause.Clause{Name: v.Name()} + v.MergeClause(&c) + c.Build(stmt) case clause.Expression: v.Build(stmt) case driver.Valuer: diff --git a/tests/sql_builder_test.go b/tests/sql_builder_test.go index b10142fa..0fbd6118 100644 --- a/tests/sql_builder_test.go +++ b/tests/sql_builder_test.go @@ -445,6 +445,14 @@ func TestToSQL(t *testing.T) { if DB.Statement.DryRun || DB.DryRun { t.Fatal("Failed expect DB.DryRun and DB.Statement.ToSQL to be false") } + + // UpdateColumns + sql = DB.ToSQL(func(tx *gorm.DB) *gorm.DB { + return tx.Raw("SELECT * FROM users ?", clause.OrderBy{ + Columns: []clause.OrderByColumn{{Column: clause.Column{Name: "id", Raw: true}, Desc: true}}, + }) + }) + assertEqualSQL(t, `SELECT * FROM users ORDER BY id DESC`, sql) } // assertEqualSQL for assert that the sql is equal, this method will ignore quote, and dialect specials. From 5dd2bb482755f5e8eb5ecaff39e675fb62f19a20 Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Wed, 19 Oct 2022 14:46:59 +0800 Subject: [PATCH 39/56] feat(PreparedStmtDB): support reset (#5782) * feat(PreparedStmtDB): support reset * fix: close all stmt * test: fix test * fix: delete one by one --- prepare_stmt.go | 12 ++++++++++++ tests/prepared_stmt_test.go | 28 +++++++++++++++++++++++++++- 2 files changed, 39 insertions(+), 1 deletion(-) diff --git a/prepare_stmt.go b/prepare_stmt.go index 3934bb97..7591e533 100644 --- a/prepare_stmt.go +++ b/prepare_stmt.go @@ -44,6 +44,18 @@ func (db *PreparedStmtDB) Close() { } } +func (db *PreparedStmtDB) Reset() { + db.Mux.Lock() + defer db.Mux.Unlock() + for query, stmt := range db.Stmts { + delete(db.Stmts, query) + go stmt.Close() + } + + db.PreparedSQL = make([]string, 0, 100) + db.Stmts = map[string](*Stmt){} +} + func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) { db.Mux.RLock() if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) { diff --git a/tests/prepared_stmt_test.go b/tests/prepared_stmt_test.go index c7f251f2..64baa01b 100644 --- a/tests/prepared_stmt_test.go +++ b/tests/prepared_stmt_test.go @@ -2,8 +2,8 @@ package tests_test import ( "context" - "sync" "errors" + "sync" "testing" "time" @@ -168,3 +168,29 @@ func TestPreparedStmtInTransaction(t *testing.T) { t.Errorf("Failed, got error: %v", err) } } + +func TestPreparedStmtReset(t *testing.T) { + tx := DB.Session(&gorm.Session{PrepareStmt: true}) + + user := *GetUser("prepared_stmt_reset", Config{}) + tx = tx.Create(&user) + + pdb, ok := tx.ConnPool.(*gorm.PreparedStmtDB) + if !ok { + t.Fatalf("should assign PreparedStatement Manager back to database when using PrepareStmt mode") + } + + pdb.Mux.Lock() + if len(pdb.Stmts) == 0 { + pdb.Mux.Unlock() + t.Fatalf("prepared stmt can not be empty") + } + pdb.Mux.Unlock() + + pdb.Reset() + pdb.Mux.Lock() + defer pdb.Mux.Unlock() + if len(pdb.Stmts) != 0 { + t.Fatalf("prepared stmt should be empty") + } +} From 9d82aa56734999bb28e0c4d60fba69ae7cde66d5 Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Thu, 20 Oct 2022 14:10:47 +0800 Subject: [PATCH 40/56] test: invalid cache plan with prepare stmt (#5778) * test: invalid cache plan with prepare stmt * test: more test cases * test: drop and rename column --- tests/migrate_test.go | 99 ++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 98 insertions(+), 1 deletion(-) diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 8718aa57..96b1d0e4 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "math/rand" + "os" "reflect" "strings" "testing" @@ -12,6 +13,7 @@ import ( "gorm.io/driver/postgres" "gorm.io/gorm" + "gorm.io/gorm/logger" "gorm.io/gorm/schema" . "gorm.io/gorm/utils/tests" ) @@ -890,7 +892,7 @@ func findColumnType(dest interface{}, columnName string) ( return } -func TestInvalidCachedPlan(t *testing.T) { +func TestInvalidCachedPlanSimpleProtocol(t *testing.T) { if DB.Dialector.Name() != "postgres" { return } @@ -925,6 +927,101 @@ func TestInvalidCachedPlan(t *testing.T) { } } +func TestInvalidCachedPlanPrepareStmt(t *testing.T) { + if DB.Dialector.Name() != "postgres" { + return + } + + db, err := gorm.Open(postgres.Open(postgresDSN), &gorm.Config{PrepareStmt: true}) + if err != nil { + t.Errorf("Open err:%v", err) + } + if debug := os.Getenv("DEBUG"); debug == "true" { + db.Logger = db.Logger.LogMode(logger.Info) + } else if debug == "false" { + db.Logger = db.Logger.LogMode(logger.Silent) + } + + type Object1 struct { + ID uint + } + type Object2 struct { + ID uint + Field1 int `gorm:"type:int8"` + } + type Object3 struct { + ID uint + Field1 int `gorm:"type:int4"` + } + type Object4 struct { + ID uint + Field2 int + } + db.Migrator().DropTable("objects") + + err = db.Table("objects").AutoMigrate(&Object1{}) + if err != nil { + t.Errorf("AutoMigrate err:%v", err) + } + err = db.Table("objects").Create(&Object1{}).Error + if err != nil { + t.Errorf("create err:%v", err) + } + + // AddColumn + err = db.Table("objects").AutoMigrate(&Object2{}) + if err != nil { + t.Errorf("AutoMigrate err:%v", err) + } + + err = db.Table("objects").Take(&Object2{}).Error + if err != nil { + t.Errorf("take err:%v", err) + } + + // AlterColumn + err = db.Table("objects").AutoMigrate(&Object3{}) + if err != nil { + t.Errorf("AutoMigrate err:%v", err) + } + + err = db.Table("objects").Take(&Object3{}).Error + if err != nil { + t.Errorf("take err:%v", err) + } + + // AddColumn + err = db.Table("objects").AutoMigrate(&Object4{}) + if err != nil { + t.Errorf("AutoMigrate err:%v", err) + } + + err = db.Table("objects").Take(&Object4{}).Error + if err != nil { + t.Errorf("take err:%v", err) + } + + db.Table("objects").Migrator().RenameColumn(&Object4{}, "field2", "field3") + if err != nil { + t.Errorf("RenameColumn err:%v", err) + } + + err = db.Table("objects").Take(&Object4{}).Error + if err != nil { + t.Errorf("take err:%v", err) + } + + db.Table("objects").Migrator().DropColumn(&Object4{}, "field3") + if err != nil { + t.Errorf("RenameColumn err:%v", err) + } + + err = db.Table("objects").Take(&Object4{}).Error + if err != nil { + t.Errorf("take err:%v", err) + } +} + func TestDifferentTypeWithoutDeclaredLength(t *testing.T) { type DiffType struct { ID uint From b2f42528a48aeed9612d43e19cdf4fe8e87a27a3 Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Wed, 2 Nov 2022 10:28:00 +0800 Subject: [PATCH 41/56] fix(Joins): args with select and omit (#5790) * fix(Joins): args with select and omit * chore: gofumpt style --- callbacks/query.go | 18 ++++++++++++----- chainable_api.go | 49 ++++++++++++++++++++++++++------------------- statement.go | 13 +++++++----- tests/joins_test.go | 43 +++++++++++++++++++++++++++++++++++++++ 4 files changed, 92 insertions(+), 31 deletions(-) diff --git a/callbacks/query.go b/callbacks/query.go index 26ee8c34..67936766 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -117,12 +117,20 @@ func BuildQuerySQL(db *gorm.DB) { } else if relation, ok := db.Statement.Schema.Relationships.Relations[join.Name]; ok { tableAliasName := relation.Name + columnStmt := gorm.Statement{ + Table: tableAliasName, DB: db, Schema: relation.FieldSchema, + Selects: join.Selects, Omits: join.Omits, + } + + selectColumns, restricted := columnStmt.SelectAndOmitColumns(false, false) for _, s := range relation.FieldSchema.DBNames { - clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ - Table: tableAliasName, - Name: s, - Alias: tableAliasName + "__" + s, - }) + if v, ok := selectColumns[s]; (ok && v) || (!ok && !restricted) { + clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ + Table: tableAliasName, + Name: s, + Alias: tableAliasName + "__" + s, + }) + } } exprs := make([]clause.Expression, len(relation.References)) diff --git a/chainable_api.go b/chainable_api.go index ab3a1a32..6d48d56b 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -10,10 +10,11 @@ import ( ) // Model specify the model you would like to run db operations -// // update all users's name to `hello` -// db.Model(&User{}).Update("name", "hello") -// // if user's primary key is non-blank, will use it as condition, then will only update the user's name to `hello` -// db.Model(&user).Update("name", "hello") +// +// // update all users's name to `hello` +// db.Model(&User{}).Update("name", "hello") +// // if user's primary key is non-blank, will use it as condition, then will only update the user's name to `hello` +// db.Model(&user).Update("name", "hello") func (db *DB) Model(value interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.Model = value @@ -179,18 +180,21 @@ func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) { } // Joins specify Joins conditions -// db.Joins("Account").Find(&user) -// db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user) -// db.Joins("Account", DB.Select("id").Where("user_id = users.id AND name = ?", "someName").Model(&Account{})) +// +// db.Joins("Account").Find(&user) +// db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user) +// db.Joins("Account", DB.Select("id").Where("user_id = users.id AND name = ?", "someName").Model(&Account{})) func (db *DB) Joins(query string, args ...interface{}) (tx *DB) { tx = db.getInstance() if len(args) == 1 { if db, ok := args[0].(*DB); ok { + j := join{Name: query, Conds: args, Selects: db.Statement.Selects, Omits: db.Statement.Omits} if where, ok := db.Statement.Clauses["WHERE"].Expression.(clause.Where); ok { - tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args, On: &where}) - return + j.On = &where } + tx.Statement.Joins = append(tx.Statement.Joins, j) + return } } @@ -219,8 +223,9 @@ func (db *DB) Having(query interface{}, args ...interface{}) (tx *DB) { } // Order specify order when retrieve records from database -// db.Order("name DESC") -// db.Order(clause.OrderByColumn{Column: clause.Column{Name: "name"}, Desc: true}) +// +// db.Order("name DESC") +// db.Order(clause.OrderByColumn{Column: clause.Column{Name: "name"}, Desc: true}) func (db *DB) Order(value interface{}) (tx *DB) { tx = db.getInstance() @@ -256,17 +261,18 @@ func (db *DB) Offset(offset int) (tx *DB) { } // Scopes pass current database connection to arguments `func(DB) DB`, which could be used to add conditions dynamically -// func AmountGreaterThan1000(db *gorm.DB) *gorm.DB { -// return db.Where("amount > ?", 1000) -// } // -// func OrderStatus(status []string) func (db *gorm.DB) *gorm.DB { -// return func (db *gorm.DB) *gorm.DB { -// return db.Scopes(AmountGreaterThan1000).Where("status in (?)", status) -// } -// } +// func AmountGreaterThan1000(db *gorm.DB) *gorm.DB { +// return db.Where("amount > ?", 1000) +// } // -// db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders) +// func OrderStatus(status []string) func (db *gorm.DB) *gorm.DB { +// return func (db *gorm.DB) *gorm.DB { +// return db.Scopes(AmountGreaterThan1000).Where("status in (?)", status) +// } +// } +// +// db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders) func (db *DB) Scopes(funcs ...func(*DB) *DB) (tx *DB) { tx = db.getInstance() tx.Statement.scopes = append(tx.Statement.scopes, funcs...) @@ -274,7 +280,8 @@ func (db *DB) Scopes(funcs ...func(*DB) *DB) (tx *DB) { } // Preload preload associations with given conditions -// db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users) +// +// db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users) func (db *DB) Preload(query string, args ...interface{}) (tx *DB) { tx = db.getInstance() if tx.Statement.Preloads == nil { diff --git a/statement.go b/statement.go index d05d299e..d4d20cbf 100644 --- a/statement.go +++ b/statement.go @@ -49,9 +49,11 @@ type Statement struct { } type join struct { - Name string - Conds []interface{} - On *clause.Where + Name string + Conds []interface{} + On *clause.Where + Selects []string + Omits []string } // StatementModifier statement modifier interface @@ -544,8 +546,9 @@ func (stmt *Statement) clone() *Statement { } // SetColumn set column's value -// stmt.SetColumn("Name", "jinzhu") // Hooks Method -// stmt.SetColumn("Name", "jinzhu", true) // Callbacks Method +// +// stmt.SetColumn("Name", "jinzhu") // Hooks Method +// stmt.SetColumn("Name", "jinzhu", true) // Callbacks Method func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks ...bool) { if v, ok := stmt.Dest.(map[string]interface{}); ok { v[name] = value diff --git a/tests/joins_test.go b/tests/joins_test.go index 7519db82..091fb986 100644 --- a/tests/joins_test.go +++ b/tests/joins_test.go @@ -260,3 +260,46 @@ func TestJoinWithSameColumnName(t *testing.T) { t.Fatalf("wrong pet name") } } + +func TestJoinArgsWithDB(t *testing.T) { + user := *GetUser("joins-args-db", Config{Pets: 2}) + DB.Save(&user) + + // test where + var user1 User + onQuery := DB.Where(&Pet{Name: "joins-args-db_pet_2"}) + if err := DB.Joins("NamedPet", onQuery).Where("users.name = ?", user.Name).First(&user1).Error; err != nil { + t.Fatalf("Failed to load with joins on, got error: %v", err) + } + + AssertEqual(t, user1.NamedPet.Name, "joins-args-db_pet_2") + + // test where and omit + onQuery2 := DB.Where(&Pet{Name: "joins-args-db_pet_2"}).Omit("Name") + var user2 User + if err := DB.Joins("NamedPet", onQuery2).Where("users.name = ?", user.Name).First(&user2).Error; err != nil { + t.Fatalf("Failed to load with joins on, got error: %v", err) + } + AssertEqual(t, user2.NamedPet.ID, user1.NamedPet.ID) + AssertEqual(t, user2.NamedPet.Name, "") + + // test where and select + onQuery3 := DB.Where(&Pet{Name: "joins-args-db_pet_2"}).Select("Name") + var user3 User + if err := DB.Joins("NamedPet", onQuery3).Where("users.name = ?", user.Name).First(&user3).Error; err != nil { + t.Fatalf("Failed to load with joins on, got error: %v", err) + } + AssertEqual(t, user3.NamedPet.ID, 0) + AssertEqual(t, user3.NamedPet.Name, "joins-args-db_pet_2") + + // test select + onQuery4 := DB.Select("ID") + var user4 User + if err := DB.Joins("NamedPet", onQuery4).Where("users.name = ?", user.Name).First(&user4).Error; err != nil { + t.Fatalf("Failed to load with joins on, got error: %v", err) + } + if user4.NamedPet.ID == 0 { + t.Fatal("Pet ID can not be empty") + } + AssertEqual(t, user4.NamedPet.Name, "") +} From f82e9cfdbed051e8e397e2fd1f7ab62c17ff8a4f Mon Sep 17 00:00:00 2001 From: jessetang <1430482733@qq.com> Date: Thu, 3 Nov 2022 21:03:13 +0800 Subject: [PATCH 42/56] test(clause/joins): add join unit test (#5832) --- clause/joins.go | 2 +- clause/joins_test.go | 101 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 102 insertions(+), 1 deletion(-) create mode 100644 clause/joins_test.go diff --git a/clause/joins.go b/clause/joins.go index f3e373f2..879892be 100644 --- a/clause/joins.go +++ b/clause/joins.go @@ -9,7 +9,7 @@ const ( RightJoin JoinType = "RIGHT" ) -// Join join clause for from +// Join clause for from type Join struct { Type JoinType Table Table diff --git a/clause/joins_test.go b/clause/joins_test.go new file mode 100644 index 00000000..f1f20ec3 --- /dev/null +++ b/clause/joins_test.go @@ -0,0 +1,101 @@ +package clause_test + +import ( + "sync" + "testing" + + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" + "gorm.io/gorm/utils/tests" +) + +func TestJoin(t *testing.T) { + results := []struct { + name string + join clause.Join + sql string + }{ + { + name: "LEFT JOIN", + join: clause.Join{ + Type: clause.LeftJoin, + Table: clause.Table{Name: "user"}, + ON: clause.Where{ + Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}}, + }, + }, + sql: "LEFT JOIN `user` ON `user_info`.`user_id` = `users`.`id`", + }, + { + name: "RIGHT JOIN", + join: clause.Join{ + Type: clause.RightJoin, + Table: clause.Table{Name: "user"}, + ON: clause.Where{ + Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}}, + }, + }, + sql: "RIGHT JOIN `user` ON `user_info`.`user_id` = `users`.`id`", + }, + { + name: "INNER JOIN", + join: clause.Join{ + Type: clause.InnerJoin, + Table: clause.Table{Name: "user"}, + ON: clause.Where{ + Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}}, + }, + }, + sql: "INNER JOIN `user` ON `user_info`.`user_id` = `users`.`id`", + }, + { + name: "CROSS JOIN", + join: clause.Join{ + Type: clause.CrossJoin, + Table: clause.Table{Name: "user"}, + ON: clause.Where{ + Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}}, + }, + }, + sql: "CROSS JOIN `user` ON `user_info`.`user_id` = `users`.`id`", + }, + { + name: "USING", + join: clause.Join{ + Type: clause.InnerJoin, + Table: clause.Table{Name: "user"}, + Using: []string{"id"}, + }, + sql: "INNER JOIN `user` USING (`id`)", + }, + { + name: "Expression", + join: clause.Join{ + // Invalid + Type: clause.LeftJoin, + Table: clause.Table{Name: "user"}, + ON: clause.Where{ + Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}}, + }, + // Valid + Expression: clause.Join{ + Type: clause.InnerJoin, + Table: clause.Table{Name: "user"}, + Using: []string{"id"}, + }, + }, + sql: "INNER JOIN `user` USING (`id`)", + }, + } + for _, result := range results { + t.Run(result.name, func(t *testing.T) { + user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) + stmt := &gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}} + result.join.Build(stmt) + if result.sql != stmt.SQL.String() { + t.Errorf("want: %s, got: %s", result.sql, stmt.SQL.String()) + } + }) + } +} From 5c8ecc3a2ad2aa570ecc0bb947138539a1bad9cf Mon Sep 17 00:00:00 2001 From: jessetang <1430482733@qq.com> Date: Sat, 5 Nov 2022 08:37:37 +0800 Subject: [PATCH 43/56] feat: golangci add goimports and whitespace (#5835) --- .golangci.yml | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/.golangci.yml b/.golangci.yml index 16903ed6..b88bf672 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -9,3 +9,12 @@ linters: - prealloc - unconvert - unparam + - goimports + - whitespace + +linters-settings: + whitespace: + multi-func: true + goimports: + local-prefixes: gorm.io/gorm + From fb640cf7daee5a4c6b738299a711612624112de7 Mon Sep 17 00:00:00 2001 From: jessetang <1430482733@qq.com> Date: Sat, 5 Nov 2022 08:38:14 +0800 Subject: [PATCH 44/56] test(utils): add utils unit test (#5834) --- utils/utils_test.go | 106 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 106 insertions(+) diff --git a/utils/utils_test.go b/utils/utils_test.go index 27dfee16..71eef964 100644 --- a/utils/utils_test.go +++ b/utils/utils_test.go @@ -1,8 +1,13 @@ package utils import ( + "database/sql" + "database/sql/driver" + "errors" + "math" "strings" "testing" + "time" ) func TestIsValidDBNameChar(t *testing.T) { @@ -13,6 +18,29 @@ func TestIsValidDBNameChar(t *testing.T) { } } +func TestCheckTruth(t *testing.T) { + checkTruthTests := []struct { + v string + out bool + }{ + {"123", true}, + {"true", true}, + {"", false}, + {"false", false}, + {"False", false}, + {"FALSE", false}, + {"\u0046alse", false}, + } + + for _, test := range checkTruthTests { + t.Run(test.v, func(t *testing.T) { + if out := CheckTruth(test.v); out != test.out { + t.Errorf("CheckTruth(%s) want: %t, got: %t", test.v, test.out, out) + } + }) + } +} + func TestToStringKey(t *testing.T) { cases := []struct { values []interface{} @@ -29,3 +57,81 @@ func TestToStringKey(t *testing.T) { } } } + +func TestContains(t *testing.T) { + containsTests := []struct { + name string + elems []string + elem string + out bool + }{ + {"exists", []string{"1", "2", "3"}, "1", true}, + {"not exists", []string{"1", "2", "3"}, "4", false}, + } + for _, test := range containsTests { + t.Run(test.name, func(t *testing.T) { + if out := Contains(test.elems, test.elem); test.out != out { + t.Errorf("Contains(%v, %s) want: %t, got: %t", test.elems, test.elem, test.out, out) + } + }) + } +} + +type ModifyAt sql.NullTime + +// Value return a Unix time. +func (n ModifyAt) Value() (driver.Value, error) { + if !n.Valid { + return nil, nil + } + return n.Time.Unix(), nil +} + +func TestAssertEqual(t *testing.T) { + now := time.Now() + assertEqualTests := []struct { + name string + src, dst interface{} + out bool + }{ + {"error equal", errors.New("1"), errors.New("1"), true}, + {"error not equal", errors.New("1"), errors.New("2"), false}, + {"driver.Valuer equal", ModifyAt{Time: now, Valid: true}, ModifyAt{Time: now, Valid: true}, true}, + {"driver.Valuer not equal", ModifyAt{Time: now, Valid: true}, ModifyAt{Time: now.Add(time.Second), Valid: true}, false}, + } + for _, test := range assertEqualTests { + t.Run(test.name, func(t *testing.T) { + if out := AssertEqual(test.src, test.dst); test.out != out { + t.Errorf("AssertEqual(%v, %v) want: %t, got: %t", test.src, test.dst, test.out, out) + } + }) + } +} + +func TestToString(t *testing.T) { + tests := []struct { + name string + in interface{} + out string + }{ + {"int", math.MaxInt64, "9223372036854775807"}, + {"int8", int8(math.MaxInt8), "127"}, + {"int16", int16(math.MaxInt16), "32767"}, + {"int32", int32(math.MaxInt32), "2147483647"}, + {"int64", int64(math.MaxInt64), "9223372036854775807"}, + {"uint", uint(math.MaxUint64), "18446744073709551615"}, + {"uint8", uint8(math.MaxUint8), "255"}, + {"uint16", uint16(math.MaxUint16), "65535"}, + {"uint32", uint32(math.MaxUint32), "4294967295"}, + {"uint64", uint64(math.MaxUint64), "18446744073709551615"}, + {"string", "abc", "abc"}, + {"other", true, ""}, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if out := ToString(test.in); test.out != out { + t.Fatalf("ToString(%v) want: %s, got: %s", test.in, test.out, out) + } + }) + } +} From 871f1de6b93835b069b6ef1bcbd823047a47c7a9 Mon Sep 17 00:00:00 2001 From: kvii <56432636+kvii@users.noreply.github.com> Date: Sat, 5 Nov 2022 11:52:08 +0800 Subject: [PATCH 45/56] fix logger path bug (#5836) --- utils/utils.go | 15 +++++++++++++-- utils/utils_unix_test.go | 33 +++++++++++++++++++++++++++++++++ utils/utils_windows_test.go | 33 +++++++++++++++++++++++++++++++++ 3 files changed, 79 insertions(+), 2 deletions(-) create mode 100644 utils/utils_unix_test.go create mode 100644 utils/utils_windows_test.go diff --git a/utils/utils.go b/utils/utils.go index 90b4c8ea..2d87f4c2 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -3,8 +3,8 @@ package utils import ( "database/sql/driver" "fmt" + "path/filepath" "reflect" - "regexp" "runtime" "strconv" "strings" @@ -16,7 +16,18 @@ var gormSourceDir string func init() { _, file, _, _ := runtime.Caller(0) // compatible solution to get gorm source directory with various operating systems - gormSourceDir = regexp.MustCompile(`gorm.utils.utils\.go`).ReplaceAllString(file, "") + gormSourceDir = sourceDir(file) +} + +func sourceDir(file string) string { + dir := filepath.Dir(file) + dir = filepath.Dir(dir) + + s := filepath.Dir(dir) + if filepath.Base(s) != "gorm.io" { + s = dir + } + return s + string(filepath.Separator) } // FileWithLineNum return the file name and line number of the current file diff --git a/utils/utils_unix_test.go b/utils/utils_unix_test.go new file mode 100644 index 00000000..da97aa2c --- /dev/null +++ b/utils/utils_unix_test.go @@ -0,0 +1,33 @@ +package utils + +import "testing" + +func TestSourceDir(t *testing.T) { + cases := []struct { + file string + want string + }{ + { + file: "/Users/name/go/pkg/mod/gorm.io/gorm@v1.2.3/utils/utils.go", + want: "/Users/name/go/pkg/mod/gorm.io/", + }, + { + file: "/go/work/proj/gorm/utils/utils.go", + want: "/go/work/proj/gorm/", + }, + { + file: "/go/work/proj/gorm_alias/utils/utils.go", + want: "/go/work/proj/gorm_alias/", + }, + { + file: "/go/work/proj/my.gorm.io/gorm@v1.2.3/utils/utils.go", + want: "/go/work/proj/my.gorm.io/gorm@v1.2.3/", + }, + } + for _, c := range cases { + s := sourceDir(c.file) + if s != c.want { + t.Fatalf("%s: expected %s, got %s", c.file, c.want, s) + } + } +} diff --git a/utils/utils_windows_test.go b/utils/utils_windows_test.go new file mode 100644 index 00000000..d1734e0e --- /dev/null +++ b/utils/utils_windows_test.go @@ -0,0 +1,33 @@ +package utils + +import "testing" + +func TestSourceDir(t *testing.T) { + cases := []struct { + file string + want string + }{ + { + file: `C:\Users\name\go\pkg\mod\gorm.io\gorm@v1.20.8\utils\utils.go`, + want: `C:\Users\name\go\pkg\mod\gorm.io`, + }, + { + file: `C:\go\work\proj\gorm\utils\utils.go`, + want: `C:\go\work\proj\gorm`, + }, + { + file: `C:\go\work\proj\gorm_alias\utils\utils.go`, + want: `C:\go\work\proj\gorm_alias`, + }, + { + file: `C:\go\work\proj\my.gorm.io\gorm\utils\utils.go`, + want: `C:\go\work\proj\my.gorm.io\gorm`, + }, + } + for _, c := range cases { + s := sourceDir(c.file) + if s != c.want { + t.Fatalf("%s: expected %s, got %s", c.file, c.want, s) + } + } +} From 1b9cd56c5336ba6e22936c289e586261b75d7b35 Mon Sep 17 00:00:00 2001 From: jessetang <1430482733@qq.com> Date: Thu, 10 Nov 2022 16:30:32 +0800 Subject: [PATCH 46/56] doc(README.md): add contributors (#5847) --- README.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/README.md b/README.md index 5bb1be37..68fa6603 100644 --- a/README.md +++ b/README.md @@ -36,6 +36,12 @@ The fantastic ORM library for Golang, aims to be developer friendly. [You can help to deliver a better GORM, check out things you can do](https://gorm.io/contribute.html) +## Contributors + +Thank you for contributing to the GORM framework! + +[![Contributors](https://contrib.rocks/image?repo=go-gorm/gorm)](https://github.com/go-gorm/gorm/graphs/contributors) + ## License © Jinzhu, 2013~time.Now From cef3de694d9615c574e82dfa0b50fc7ea2816f3e Mon Sep 17 00:00:00 2001 From: jessetang <1430482733@qq.com> Date: Sun, 13 Nov 2022 11:12:09 +0800 Subject: [PATCH 47/56] cleanup(prepare_stmt.go): unnecessary map delete (#5849) --- gorm.go | 2 +- prepare_stmt.go | 9 ++++----- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/gorm.go b/gorm.go index 589fc4ff..89488b75 100644 --- a/gorm.go +++ b/gorm.go @@ -179,7 +179,7 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) { preparedStmt := &PreparedStmtDB{ ConnPool: db.ConnPool, - Stmts: map[string](*Stmt){}, + Stmts: make(map[string]*Stmt), Mux: &sync.RWMutex{}, PreparedSQL: make([]string, 0, 100), } diff --git a/prepare_stmt.go b/prepare_stmt.go index 7591e533..e09fe814 100644 --- a/prepare_stmt.go +++ b/prepare_stmt.go @@ -47,13 +47,12 @@ func (db *PreparedStmtDB) Close() { func (db *PreparedStmtDB) Reset() { db.Mux.Lock() defer db.Mux.Unlock() - for query, stmt := range db.Stmts { - delete(db.Stmts, query) + + for _, stmt := range db.Stmts { go stmt.Close() } - db.PreparedSQL = make([]string, 0, 100) - db.Stmts = map[string](*Stmt){} + db.Stmts = make(map[string]*Stmt) } func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) { @@ -93,7 +92,7 @@ func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransact // Reason why cannot lock conn.PrepareContext // suppose the maxopen is 1, g1 is creating record and g2 is querying record. - // 1. g1 begin tx, g1 is requeued because of waiting for the system call, now `db.ConnPool` db.numOpen == 1. + // 1. g1 begin tx, g1 is requeue because of waiting for the system call, now `db.ConnPool` db.numOpen == 1. // 2. g2 select lock `conn.PrepareContext(ctx, query)`, now db.numOpen == db.maxOpen , wait for release. // 3. g1 tx exec insert, wait for unlock `conn.PrepareContext(ctx, query)` to finish tx and release. stmt, err := conn.PrepareContext(ctx, query) From b6836c2d3ee91c0f0114736084d033f2b0a96748 Mon Sep 17 00:00:00 2001 From: kvii <56432636+kvii@users.noreply.github.com> Date: Mon, 21 Nov 2022 10:48:13 +0800 Subject: [PATCH 48/56] fix bug in windows (#5844) * fix bug in windows * fix file name bug * test in unix like platform --- utils/utils.go | 2 +- utils/utils_unix_test.go | 7 ++++++- utils/utils_windows_test.go | 20 +++++++++++--------- 3 files changed, 18 insertions(+), 11 deletions(-) diff --git a/utils/utils.go b/utils/utils.go index 2d87f4c2..e08533cd 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -27,7 +27,7 @@ func sourceDir(file string) string { if filepath.Base(s) != "gorm.io" { s = dir } - return s + string(filepath.Separator) + return filepath.ToSlash(s) + "/" } // FileWithLineNum return the file name and line number of the current file diff --git a/utils/utils_unix_test.go b/utils/utils_unix_test.go index da97aa2c..450cbe2a 100644 --- a/utils/utils_unix_test.go +++ b/utils/utils_unix_test.go @@ -1,6 +1,11 @@ +//go:build unix +// +build unix + package utils -import "testing" +import ( + "testing" +) func TestSourceDir(t *testing.T) { cases := []struct { diff --git a/utils/utils_windows_test.go b/utils/utils_windows_test.go index d1734e0e..8b1c519d 100644 --- a/utils/utils_windows_test.go +++ b/utils/utils_windows_test.go @@ -1,6 +1,8 @@ package utils -import "testing" +import ( + "testing" +) func TestSourceDir(t *testing.T) { cases := []struct { @@ -8,20 +10,20 @@ func TestSourceDir(t *testing.T) { want string }{ { - file: `C:\Users\name\go\pkg\mod\gorm.io\gorm@v1.20.8\utils\utils.go`, - want: `C:\Users\name\go\pkg\mod\gorm.io`, + file: `C:/Users/name/go/pkg/mod/gorm.io/gorm@v1.2.3/utils/utils.go`, + want: `C:/Users/name/go/pkg/mod/gorm.io/`, }, { - file: `C:\go\work\proj\gorm\utils\utils.go`, - want: `C:\go\work\proj\gorm`, + file: `C:/go/work/proj/gorm/utils/utils.go`, + want: `C:/go/work/proj/gorm/`, }, { - file: `C:\go\work\proj\gorm_alias\utils\utils.go`, - want: `C:\go\work\proj\gorm_alias`, + file: `C:/go/work/proj/gorm_alias/utils/utils.go`, + want: `C:/go/work/proj/gorm_alias/`, }, { - file: `C:\go\work\proj\my.gorm.io\gorm\utils\utils.go`, - want: `C:\go\work\proj\my.gorm.io\gorm`, + file: `C:/go/work/proj/my.gorm.io/gorm@v1.2.3/utils/utils.go`, + want: `C:/go/work/proj/my.gorm.io/gorm@v1.2.3/`, }, } for _, c := range cases { From 342310fba4fc56decf3d417925326db483734d7e Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Mon, 21 Nov 2022 10:49:27 +0800 Subject: [PATCH 49/56] fix(FindInBatches): throw err if pk not exists (#5868) --- finisher_api.go | 11 ++++++++--- tests/query_test.go | 7 +++++++ 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/finisher_api.go b/finisher_api.go index 5516c0a1..cc07a126 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -231,7 +231,11 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat break } - primaryValue, _ := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1)) + primaryValue, zero := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1)) + if zero { + tx.AddError(ErrPrimaryKeyRequired) + break + } queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue}) } @@ -514,8 +518,9 @@ func (db *DB) Scan(dest interface{}) (tx *DB) { } // Pluck queries a single column from a model, returning in the slice dest. E.g.: -// var ages []int64 -// db.Model(&users).Pluck("age", &ages) +// +// var ages []int64 +// db.Model(&users).Pluck("age", &ages) func (db *DB) Pluck(column string, dest interface{}) (tx *DB) { tx = db.getInstance() if tx.Statement.Model != nil { diff --git a/tests/query_test.go b/tests/query_test.go index eccf0133..fa8f09e8 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -408,6 +408,13 @@ func TestFindInBatchesWithError(t *testing.T) { if totalBatch != 0 { t.Fatalf("incorrect total batch, expected: %v, got: %v", 0, totalBatch) } + + if result := DB.Omit("id").Where("name = ?", users[0].Name).FindInBatches(&results, 2, func(tx *gorm.DB, batch int) error { + totalBatch += batch + return nil + }); result.Error != gorm.ErrPrimaryKeyRequired { + t.Fatal("expected errors to have occurred, but nothing happened") + } } func TestFillSmallerStruct(t *testing.T) { From f91313436abcfe7a28a488d5d6777b31a94f24fb Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 21 Nov 2022 11:10:56 +0800 Subject: [PATCH 50/56] Fix group by with count logic --- finisher_api.go | 2 +- tests/count_test.go | 2 +- tests/go.mod | 10 +++++----- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/finisher_api.go b/finisher_api.go index cc07a126..33d7a5a6 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -465,7 +465,7 @@ func (db *DB) Count(count *int64) (tx *DB) { tx.Statement.Dest = count tx = tx.callbacks.Query().Execute(tx) - if _, ok := db.Statement.Clauses["GROUP BY"]; ok || tx.RowsAffected != 1 { + if tx.RowsAffected != 1 { *count = tx.RowsAffected } diff --git a/tests/count_test.go b/tests/count_test.go index b71e3de5..2199dc6d 100644 --- a/tests/count_test.go +++ b/tests/count_test.go @@ -141,7 +141,7 @@ func TestCount(t *testing.T) { } DB.Create(sameUsers) - if err := DB.Model(&User{}).Where("name = ?", "count-4").Group("name").Count(&count11).Error; err != nil || count11 != 1 { + if err := DB.Model(&User{}).Where("name = ?", "count-4").Group("name").Count(&count11).Error; err != nil || count11 != int64(len(sameUsers)) { t.Fatalf("Count should be 3, but got count: %v err %v", count11, err) } diff --git a/tests/go.mod b/tests/go.mod index 9c87ca34..23fc2cad 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -6,13 +6,13 @@ require ( github.com/google/uuid v1.3.0 github.com/jinzhu/now v1.1.5 github.com/lib/pq v1.10.7 - golang.org/x/crypto v0.0.0-20221012134737-56aed061732a // indirect - golang.org/x/text v0.4.0 // indirect - gorm.io/driver/mysql v1.4.3 - gorm.io/driver/postgres v1.4.4 + github.com/mattn/go-sqlite3 v1.14.16 // indirect + golang.org/x/crypto v0.3.0 // indirect + gorm.io/driver/mysql v1.4.4 + gorm.io/driver/postgres v1.4.5 gorm.io/driver/sqlite v1.4.3 gorm.io/driver/sqlserver v1.4.1 - gorm.io/gorm v1.24.0 + gorm.io/gorm v1.24.2 ) replace gorm.io/gorm => ../ From f931def33d23c9fd3c23ccb276e0f8bc17f8337f Mon Sep 17 00:00:00 2001 From: wjw1758548031 <46154774+wjw1758548031@users.noreply.github.com> Date: Thu, 1 Dec 2022 20:25:53 +0800 Subject: [PATCH 51/56] clear code syntax (#5889) * clear code syntax * clear code syntax --- finisher_api.go | 75 +++++++++++++++++++++++++------------------------ 1 file changed, 39 insertions(+), 36 deletions(-) diff --git a/finisher_api.go b/finisher_api.go index 33d7a5a6..b30ca24d 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -326,45 +326,48 @@ func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) { queryTx := db.Session(&Session{}).Limit(1).Order(clause.OrderByColumn{ Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, }) - if result := queryTx.Find(dest, conds...); result.Error == nil { - if result.RowsAffected == 0 { - if c, ok := result.Statement.Clauses["WHERE"]; ok { - if where, ok := c.Expression.(clause.Where); ok { - result.assignInterfacesToValue(where.Exprs) - } - } - // initialize with attrs, conds - if len(db.Statement.attrs) > 0 { - result.assignInterfacesToValue(db.Statement.attrs...) - } - - // initialize with attrs, conds - if len(db.Statement.assigns) > 0 { - result.assignInterfacesToValue(db.Statement.assigns...) - } - - return tx.Create(dest) - } else if len(db.Statement.assigns) > 0 { - exprs := tx.Statement.BuildCondition(db.Statement.assigns[0], db.Statement.assigns[1:]...) - assigns := map[string]interface{}{} - for _, expr := range exprs { - if eq, ok := expr.(clause.Eq); ok { - switch column := eq.Column.(type) { - case string: - assigns[column] = eq.Value - case clause.Column: - assigns[column.Name] = eq.Value - default: - } - } - } - - return tx.Model(dest).Updates(assigns) - } - } else { + result := queryTx.Find(dest, conds...) + if result.Error != nil { tx.Error = result.Error + return tx } + + if result.RowsAffected == 0 { + if c, ok := result.Statement.Clauses["WHERE"]; ok { + if where, ok := c.Expression.(clause.Where); ok { + result.assignInterfacesToValue(where.Exprs) + } + } + + // initialize with attrs, conds + if len(db.Statement.attrs) > 0 { + result.assignInterfacesToValue(db.Statement.attrs...) + } + + // initialize with attrs, conds + if len(db.Statement.assigns) > 0 { + result.assignInterfacesToValue(db.Statement.assigns...) + } + + return tx.Create(dest) + } else if len(db.Statement.assigns) > 0 { + exprs := tx.Statement.BuildCondition(db.Statement.assigns[0], db.Statement.assigns[1:]...) + assigns := map[string]interface{}{} + for _, expr := range exprs { + if eq, ok := expr.(clause.Eq); ok { + switch column := eq.Column.(type) { + case string: + assigns[column] = eq.Value + case clause.Column: + assigns[column.Name] = eq.Value + } + } + } + + return tx.Model(dest).Updates(assigns) + } + return tx } From d9525d4da45d343cdfb8641a72735330b9e86c88 Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Thu, 1 Dec 2022 20:26:59 +0800 Subject: [PATCH 52/56] fix: skip append relation field to default db value (#5885) * fix: relation field returning * chore: gofumpt style --- schema/schema.go | 2 +- tests/associations_belongs_to_test.go | 26 ++++++++++++++++++++++++++ 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/schema/schema.go b/schema/schema.go index 9b3d30f6..21e71c21 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -230,7 +230,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam } for _, field := range schema.Fields { - if field.HasDefaultValue && field.DefaultValueInterface == nil { + if field.DataType != "" && field.HasDefaultValue && field.DefaultValueInterface == nil { schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field) } } diff --git a/tests/associations_belongs_to_test.go b/tests/associations_belongs_to_test.go index f74799ce..a1f014d9 100644 --- a/tests/associations_belongs_to_test.go +++ b/tests/associations_belongs_to_test.go @@ -3,6 +3,7 @@ package tests_test import ( "testing" + "gorm.io/gorm" . "gorm.io/gorm/utils/tests" ) @@ -224,3 +225,28 @@ func TestBelongsToAssociationForSlice(t *testing.T) { AssertAssociationCount(t, users[0], "Company", 0, "After Delete") AssertAssociationCount(t, users[1], "Company", 1, "After other user Delete") } + +func TestBelongsToDefaultValue(t *testing.T) { + type Org struct { + ID string + } + type BelongsToUser struct { + OrgID string + Org Org `gorm:"default:NULL"` + } + + tx := DB.Session(&gorm.Session{}) + tx.Config.DisableForeignKeyConstraintWhenMigrating = true + AssertEqual(t, DB.Config.DisableForeignKeyConstraintWhenMigrating, false) + + tx.Migrator().DropTable(&BelongsToUser{}, &Org{}) + tx.AutoMigrate(&BelongsToUser{}, &Org{}) + + user := &BelongsToUser{ + Org: Org{ + ID: "BelongsToUser_Org_1", + }, + } + err := DB.Create(&user).Error + AssertEqual(t, err, nil) +} From 4ec73c9bf46662bfef7a87d766e9c34661846385 Mon Sep 17 00:00:00 2001 From: Edward McFarlane <3036610+emcfarlane@users.noreply.github.com> Date: Mon, 19 Dec 2022 04:49:05 +0100 Subject: [PATCH 53/56] Add test case for embedded value selects (#5901) * Add test case for embedded value selects * Revert recycle struct optimisation to avoid pointer overwrites --- scan.go | 12 +++--------- tests/embedded_struct_test.go | 18 ++++++++++++++++-- 2 files changed, 19 insertions(+), 11 deletions(-) diff --git a/scan.go b/scan.go index 0a26ce4b..12a77862 100644 --- a/scan.go +++ b/scan.go @@ -65,7 +65,6 @@ func (db *DB) scanIntoStruct(rows Rows, reflectValue reflect.Value, values []int db.RowsAffected++ db.AddError(rows.Scan(values...)) - joinedSchemaMap := make(map[*schema.Field]interface{}) for idx, field := range fields { if field == nil { @@ -241,9 +240,8 @@ func Scan(rows Rows, db *DB, mode ScanMode) { switch reflectValue.Kind() { case reflect.Slice, reflect.Array: var ( - elem reflect.Value - recyclableStruct = reflect.New(reflectValueType) - isArrayKind = reflectValue.Kind() == reflect.Array + elem reflect.Value + isArrayKind = reflectValue.Kind() == reflect.Array ) if !update || reflectValue.Len() == 0 { @@ -275,11 +273,7 @@ func Scan(rows Rows, db *DB, mode ScanMode) { } } } else { - if isPtr && db.RowsAffected > 0 { - elem = reflect.New(reflectValueType) - } else { - elem = recyclableStruct - } + elem = reflect.New(reflectValueType) } db.scanIntoStruct(rows, elem, values, fields, joinFields) diff --git a/tests/embedded_struct_test.go b/tests/embedded_struct_test.go index e309d06c..ae69baca 100644 --- a/tests/embedded_struct_test.go +++ b/tests/embedded_struct_test.go @@ -36,7 +36,7 @@ func TestEmbeddedStruct(t *testing.T) { type EngadgetPost struct { BasePost BasePost `gorm:"Embedded"` - Author Author `gorm:"Embedded;EmbeddedPrefix:author_"` // Embedded struct + Author *Author `gorm:"Embedded;EmbeddedPrefix:author_"` // Embedded struct ImageUrl string } @@ -74,13 +74,27 @@ func TestEmbeddedStruct(t *testing.T) { t.Errorf("embedded struct's value should be scanned correctly") } - DB.Save(&EngadgetPost{BasePost: BasePost{Title: "engadget_news"}}) + DB.Save(&EngadgetPost{BasePost: BasePost{Title: "engadget_news"}, Author: &Author{Name: "Edward"}}) + DB.Save(&EngadgetPost{BasePost: BasePost{Title: "engadget_article"}, Author: &Author{Name: "George"}}) var egNews EngadgetPost if err := DB.First(&egNews, "title = ?", "engadget_news").Error; err != nil { t.Errorf("no error should happen when query with embedded struct, but got %v", err) } else if egNews.BasePost.Title != "engadget_news" { t.Errorf("embedded struct's value should be scanned correctly") } + + var egPosts []EngadgetPost + if err := DB.Order("author_name asc").Find(&egPosts).Error; err != nil { + t.Fatalf("no error should happen when query with embedded struct, but got %v", err) + } + expectAuthors := []string{"Edward", "George"} + for i, post := range egPosts { + t.Log(i, post.Author) + if want := expectAuthors[i]; post.Author.Name != want { + t.Errorf("expected author %s got %s", want, post.Author.Name) + } + } + } func TestEmbeddedPointerTypeStruct(t *testing.T) { From f3c6fc253356919e8ebbcf7bc50e8c7fe88802aa Mon Sep 17 00:00:00 2001 From: Nate Armstrong Date: Fri, 23 Dec 2022 00:51:01 -0800 Subject: [PATCH 54/56] Update func comments in chainable_api and FirstOr_ (#5935) Add comments to functions in chainable_api. Depending on the method, these comments add some additional context or details that are relevant when reading the function, link to the actual docs at gorm.io/docs, or provide examples of use. These comments should make GORM much more pleasant to use with an IDE that provides hoverable comments, and are minimal examples. Also add in-code documentation to FirstOrInit and FirstOrCreate. Almost all examples are directly pulled from the docs, with short comments explaining the code. Most examples omit the `db.Model(&User{})` for brevity, and would not actually work. Co-authored-by: Nate Armstrong --- chainable_api.go | 104 ++++++++++++++++++++++++++++++++++++++++++++++- finisher_api.go | 22 ++++++++++ 2 files changed, 124 insertions(+), 2 deletions(-) diff --git a/chainable_api.go b/chainable_api.go index 6d48d56b..68ec7a67 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -13,7 +13,7 @@ import ( // // // update all users's name to `hello` // db.Model(&User{}).Update("name", "hello") -// // if user's primary key is non-blank, will use it as condition, then will only update the user's name to `hello` +// // if user's primary key is non-blank, will use it as condition, then will only update that user's name to `hello` // db.Model(&user).Update("name", "hello") func (db *DB) Model(value interface{}) (tx *DB) { tx = db.getInstance() @@ -22,6 +22,19 @@ func (db *DB) Model(value interface{}) (tx *DB) { } // Clauses Add clauses +// +// This supports both standard clauses (clause.OrderBy, clause.Limit, clause.Where) and more +// advanced techniques like specifying lock strength and optimizer hints. See the +// [docs] for more depth. +// +// // add a simple limit clause +// db.Clauses(clause.Limit{Limit: 1}).Find(&User{}) +// // tell the optimizer to use the `idx_user_name` index +// db.Clauses(hints.UseIndex("idx_user_name")).Find(&User{}) +// // specify the lock strength to UPDATE +// db.Clauses(clause.Locking{Strength: "UPDATE"}).Find(&users) +// +// [docs]: https://gorm.io/docs/sql_builder.html#Clauses func (db *DB) Clauses(conds ...clause.Expression) (tx *DB) { tx = db.getInstance() var whereConds []interface{} @@ -45,6 +58,9 @@ func (db *DB) Clauses(conds ...clause.Expression) (tx *DB) { var tableRegexp = regexp.MustCompile(`(?i).+? AS (\w+)\s*(?:$|,)`) // Table specify the table you would like to run db operations +// +// // Get a user +// db.Table("users").take(&result) func (db *DB) Table(name string, args ...interface{}) (tx *DB) { tx = db.getInstance() if strings.Contains(name, " ") || strings.Contains(name, "`") || len(args) > 0 { @@ -66,6 +82,11 @@ func (db *DB) Table(name string, args ...interface{}) (tx *DB) { } // Distinct specify distinct fields that you want querying +// +// // Select distinct names of users +// db.Distinct("name").Find(&results) +// // Select distinct name/age pairs from users +// db.Distinct("name", "age").Find(&results) func (db *DB) Distinct(args ...interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.Distinct = true @@ -76,6 +97,14 @@ func (db *DB) Distinct(args ...interface{}) (tx *DB) { } // Select specify fields that you want when querying, creating, updating +// +// Use Select when you only want a subset of the fields. By default, GORM will select all fields. +// Select accepts both string arguments and arrays. +// +// // Select name and age of user using multiple arguments +// db.Select("name", "age").Find(&users) +// // Select name and age of user using an array +// db.Select([]string{"name", "age"}).Find(&users) func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) { tx = db.getInstance() @@ -153,6 +182,17 @@ func (db *DB) Omit(columns ...string) (tx *DB) { } // Where add conditions +// +// See the [docs] for details on the various formats that where clauses can take. By default, where clauses chain with AND. +// +// // Find the first user with name jinzhu +// db.Where("name = ?", "jinzhu").First(&user) +// // Find the first user with name jinzhu and age 20 +// db.Where(&User{Name: "jinzhu", Age: 20}).First(&user) +// // Find the first user with name jinzhu and age not equal to 20 +// db.Where("name = ?", "jinzhu").Where("age <> ?", "20").First(&user) +// +// [docs]: https://gorm.io/docs/query.html#Conditions func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) { tx = db.getInstance() if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 { @@ -162,6 +202,11 @@ func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) { } // Not add NOT conditions +// +// Not works similarly to where, and has the same syntax. +// +// // Find the first user with name not equal to jinzhu +// db.Not("name = ?", "jinzhu").First(&user) func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) { tx = db.getInstance() if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 { @@ -171,6 +216,11 @@ func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) { } // Or add OR conditions +// +// Or is used to chain together queries with an OR. +// +// // Find the first user with name equal to jinzhu or john +// db.Where("name = ?", "jinzhu").Or("name = ?", "john").First(&user) func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) { tx = db.getInstance() if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 { @@ -203,6 +253,9 @@ func (db *DB) Joins(query string, args ...interface{}) (tx *DB) { } // Group specify the group method on the find +// +// // Select the sum age of users with given names +// db.Model(&User{}).Select("name, sum(age) as total").Group("name").Find(&results) func (db *DB) Group(name string) (tx *DB) { tx = db.getInstance() @@ -214,6 +267,9 @@ func (db *DB) Group(name string) (tx *DB) { } // Having specify HAVING conditions for GROUP BY +// +// // Select the sum age of users with name jinzhu +// db.Model(&User{}).Select("name, sum(age) as total").Group("name").Having("name = ?", "jinzhu").Find(&result) func (db *DB) Having(query interface{}, args ...interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.AddClause(clause.GroupBy{ @@ -222,7 +278,7 @@ func (db *DB) Having(query interface{}, args ...interface{}) (tx *DB) { return } -// Order specify order when retrieve records from database +// Order specify order when retrieving records from database // // db.Order("name DESC") // db.Order(clause.OrderByColumn{Column: clause.Column{Name: "name"}, Desc: true}) @@ -247,6 +303,13 @@ func (db *DB) Order(value interface{}) (tx *DB) { } // Limit specify the number of records to be retrieved +// +// Limit conditions can be cancelled by using `Limit(-1)`. +// +// // retrieve 3 users +// db.Limit(3).Find(&users) +// // retrieve 3 users into users1, and all users into users2 +// db.Limit(3).Find(&users1).Limit(-1).Find(&users2) func (db *DB) Limit(limit int) (tx *DB) { tx = db.getInstance() tx.Statement.AddClause(clause.Limit{Limit: &limit}) @@ -254,6 +317,13 @@ func (db *DB) Limit(limit int) (tx *DB) { } // Offset specify the number of records to skip before starting to return the records +// +// Offset conditions can be cancelled by using `Offset(-1)`. +// +// // select the third user +// db.Offset(2).First(&user) +// // select the first user by cancelling an earlier chained offset +// db.Offset(5).Offset(-1).First(&user) func (db *DB) Offset(offset int) (tx *DB) { tx = db.getInstance() tx.Statement.AddClause(clause.Limit{Offset: offset}) @@ -281,6 +351,7 @@ func (db *DB) Scopes(funcs ...func(*DB) *DB) (tx *DB) { // Preload preload associations with given conditions // +// // get all users, and preload all non-cancelled orders // db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users) func (db *DB) Preload(query string, args ...interface{}) (tx *DB) { tx = db.getInstance() @@ -291,12 +362,41 @@ func (db *DB) Preload(query string, args ...interface{}) (tx *DB) { return } +// Attrs provide attributes used in [FirstOrCreate] or [FirstOrInit] +// +// Attrs only adds attributes if the record is not found. +// +// // assign an email if the record is not found +// db.Where(User{Name: "non_existing"}).Attrs(User{Email: "fake@fake.org"}).FirstOrInit(&user) +// // user -> User{Name: "non_existing", Email: "fake@fake.org"} +// +// // assign an email if the record is not found, otherwise ignore provided email +// db.Where(User{Name: "jinzhu"}).Attrs(User{Email: "fake@fake.org"}).FirstOrInit(&user) +// // user -> User{Name: "jinzhu", Age: 20} +// +// [FirstOrCreate]: https://gorm.io/docs/advanced_query.html#FirstOrCreate +// [FirstOrInit]: https://gorm.io/docs/advanced_query.html#FirstOrInit func (db *DB) Attrs(attrs ...interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.attrs = attrs return } +// Assign provide attributes used in [FirstOrCreate] or [FirstOrInit] +// +// Assign adds attributes even if the record is found. If using FirstOrCreate, this means that +// records will be updated even if they are found. +// +// // assign an email regardless of if the record is not found +// db.Where(User{Name: "non_existing"}).Assign(User{Email: "fake@fake.org"}).FirstOrInit(&user) +// // user -> User{Name: "non_existing", Email: "fake@fake.org"} +// +// // assign email regardless of if record is found +// db.Where(User{Name: "jinzhu"}).Assign(User{Email: "fake@fake.org"}).FirstOrInit(&user) +// // user -> User{Name: "jinzhu", Age: 20, Email: "fake@fake.org"} +// +// [FirstOrCreate]: https://gorm.io/docs/advanced_query.html#FirstOrCreate +// [FirstOrInit]: https://gorm.io/docs/advanced_query.html#FirstOrInit func (db *DB) Assign(attrs ...interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.assigns = attrs diff --git a/finisher_api.go b/finisher_api.go index b30ca24d..39d9fca3 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -294,6 +294,16 @@ func (db *DB) assignInterfacesToValue(values ...interface{}) { // FirstOrInit finds the first matching record, otherwise if not found initializes a new instance with given conds. // Each conds must be a struct or map. +// +// FirstOrInit never modifies the database. It is often used with Assign and Attrs. +// +// // assign an email if the record is not found +// db.Where(User{Name: "non_existing"}).Attrs(User{Email: "fake@fake.org"}).FirstOrInit(&user) +// // user -> User{Name: "non_existing", Email: "fake@fake.org"} +// +// // assign email regardless of if record is found +// db.Where(User{Name: "jinzhu"}).Assign(User{Email: "fake@fake.org"}).FirstOrInit(&user) +// // user -> User{Name: "jinzhu", Age: 20, Email: "fake@fake.org"} func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { queryTx := db.Limit(1).Order(clause.OrderByColumn{ Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, @@ -321,6 +331,18 @@ func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { // FirstOrCreate finds the first matching record, otherwise if not found creates a new instance with given conds. // Each conds must be a struct or map. +// +// Using FirstOrCreate in conjunction with Assign will result in an update to the database even if the record exists. +// +// // assign an email if the record is not found +// result := db.Where(User{Name: "non_existing"}).Attrs(User{Email: "fake@fake.org"}).FirstOrCreate(&user) +// // user -> User{Name: "non_existing", Email: "fake@fake.org"} +// // result.RowsAffected -> 1 +// +// // assign email regardless of if record is found +// result := db.Where(User{Name: "jinzhu"}).Assign(User{Email: "fake@fake.org"}).FirstOrCreate(&user) +// // user -> User{Name: "jinzhu", Age: 20, Email: "fake@fake.org"} +// // result.RowsAffected -> 1 func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) { tx = db.getInstance() queryTx := db.Session(&Session{}).Limit(1).Order(clause.OrderByColumn{ From bbd2bbe5217f7d3d3df5835748954f3cae6ebb68 Mon Sep 17 00:00:00 2001 From: Ning Date: Sat, 24 Dec 2022 11:02:11 +0800 Subject: [PATCH 55/56] fix:Issue migrating field with CURRENT_TIMESTAMP (#5906) Co-authored-by: ningfei --- migrator/migrator.go | 10 ++++++---- tests/migrate_test.go | 26 ++++++++++++++++++++++++++ 2 files changed, 32 insertions(+), 4 deletions(-) diff --git a/migrator/migrator.go b/migrator/migrator.go index 9f8e3db8..b113b398 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -470,17 +470,19 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy // check default value if !field.PrimaryKey { + currentDefaultNotNull := field.HasDefaultValue && !strings.EqualFold(field.DefaultValue, "NULL") dv, dvNotNull := columnType.DefaultValue() - if dvNotNull && field.DefaultValueInterface == nil { + if dvNotNull && !currentDefaultNotNull { // defalut value -> null alterColumn = true - } else if !dvNotNull && field.DefaultValueInterface != nil { + } else if !dvNotNull && currentDefaultNotNull { // null -> default value alterColumn = true - } else if dv != field.DefaultValue { + } else if (field.GORMDataType != schema.Time && dv != field.DefaultValue) || + (field.GORMDataType == schema.Time && !strings.EqualFold(strings.TrimSuffix(dv, "()"), strings.TrimSuffix(field.DefaultValue, "()"))) { // default value not equal // not both null - if !(field.DefaultValueInterface == nil && !dvNotNull) { + if currentDefaultNotNull || dvNotNull { alterColumn = true } } diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 96b1d0e4..5f7e0749 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -757,6 +757,32 @@ func TestPrimarykeyID(t *testing.T) { } } +func TestCurrentTimestamp(t *testing.T) { + if DB.Dialector.Name() != "mysql" { + return + } + type CurrentTimestampTest struct { + ID string `gorm:"primary_key"` + TimeAt *time.Time `gorm:"type:datetime;not null;default:CURRENT_TIMESTAMP;unique"` + } + var err error + err = DB.Migrator().DropTable(&CurrentTimestampTest{}) + if err != nil { + t.Errorf("DropTable err:%v", err) + } + err = DB.AutoMigrate(&CurrentTimestampTest{}) + if err != nil { + t.Fatalf("AutoMigrate err:%v", err) + } + + err = DB.AutoMigrate(&CurrentTimestampTest{}) + if err != nil { + t.Fatalf("AutoMigrate err:%v", err) + } + AssertEqual(t, true, DB.Migrator().HasIndex(&CurrentTimestampTest{}, "time_at")) + AssertEqual(t, false, DB.Migrator().HasIndex(&CurrentTimestampTest{}, "time_at_2")) +} + func TestUniqueColumn(t *testing.T) { if DB.Dialector.Name() != "mysql" { return From 775fa70af5a727f15ded94761fce5a1076603ca6 Mon Sep 17 00:00:00 2001 From: Defoo Li Date: Sat, 24 Dec 2022 12:14:23 +0800 Subject: [PATCH 56/56] DryRun for migrator (#5689) * DryRun for migrator * Update migrator.go * Update migrator.go Co-authored-by: Jinzhu --- migrator/migrator.go | 41 +++++++++++++++++++++++++++++------------ 1 file changed, 29 insertions(+), 12 deletions(-) diff --git a/migrator/migrator.go b/migrator/migrator.go index b113b398..eafe7bb2 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -8,9 +8,11 @@ import ( "reflect" "regexp" "strings" + "time" "gorm.io/gorm" "gorm.io/gorm/clause" + "gorm.io/gorm/logger" "gorm.io/gorm/schema" ) @@ -30,6 +32,16 @@ type Config struct { gorm.Dialector } +type printSQLLogger struct { + logger.Interface +} + +func (l *printSQLLogger) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) { + sql, _ := fc() + fmt.Println(sql + ";") + l.Interface.Trace(ctx, begin, fc, err) +} + // GormDataTypeInterface gorm data type interface type GormDataTypeInterface interface { GormDBDataType(*gorm.DB, *schema.Field) string @@ -92,14 +104,19 @@ func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) { // AutoMigrate auto migrate values func (m Migrator) AutoMigrate(values ...interface{}) error { for _, value := range m.ReorderModels(values, true) { - tx := m.DB.Session(&gorm.Session{}) - if !tx.Migrator().HasTable(value) { - if err := tx.Migrator().CreateTable(value); err != nil { + queryTx := m.DB.Session(&gorm.Session{}) + execTx := queryTx + if m.DB.DryRun { + queryTx.DryRun = false + execTx = m.DB.Session(&gorm.Session{Logger: &printSQLLogger{Interface: m.DB.Logger}}) + } + if !queryTx.Migrator().HasTable(value) { + if err := execTx.Migrator().CreateTable(value); err != nil { return err } } else { if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) { - columnTypes, err := m.DB.Migrator().ColumnTypes(value) + columnTypes, err := queryTx.Migrator().ColumnTypes(value) if err != nil { return err } @@ -117,10 +134,10 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { if foundColumn == nil { // not found, add column - if err := tx.Migrator().AddColumn(value, dbName); err != nil { + if err := execTx.Migrator().AddColumn(value, dbName); err != nil { return err } - } else if err := m.DB.Migrator().MigrateColumn(value, field, foundColumn); err != nil { + } else if err := execTx.Migrator().MigrateColumn(value, field, foundColumn); err != nil { // found, smart migrate return err } @@ -129,8 +146,8 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { for _, rel := range stmt.Schema.Relationships.Relations { if !m.DB.Config.DisableForeignKeyConstraintWhenMigrating { if constraint := rel.ParseConstraint(); constraint != nil && - constraint.Schema == stmt.Schema && !tx.Migrator().HasConstraint(value, constraint.Name) { - if err := tx.Migrator().CreateConstraint(value, constraint.Name); err != nil { + constraint.Schema == stmt.Schema && !queryTx.Migrator().HasConstraint(value, constraint.Name) { + if err := execTx.Migrator().CreateConstraint(value, constraint.Name); err != nil { return err } } @@ -138,16 +155,16 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { } for _, chk := range stmt.Schema.ParseCheckConstraints() { - if !tx.Migrator().HasConstraint(value, chk.Name) { - if err := tx.Migrator().CreateConstraint(value, chk.Name); err != nil { + if !queryTx.Migrator().HasConstraint(value, chk.Name) { + if err := execTx.Migrator().CreateConstraint(value, chk.Name); err != nil { return err } } } for _, idx := range stmt.Schema.ParseIndexes() { - if !tx.Migrator().HasIndex(value, idx.Name) { - if err := tx.Migrator().CreateIndex(value, idx.Name); err != nil { + if !queryTx.Migrator().HasIndex(value, idx.Name) { + if err := execTx.Migrator().CreateIndex(value, idx.Name); err != nil { return err } }