From cec0d32aecc8d5068873304abe7f85e9409d4b10 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 28 Jan 2022 18:48:32 +0800 Subject: [PATCH 01/87] Support use clause.Expression as argument --- clause/select_test.go | 17 +++++++++++++++++ statement.go | 2 ++ tests/go.mod | 4 +++- 3 files changed, 22 insertions(+), 1 deletion(-) diff --git a/clause/select_test.go b/clause/select_test.go index 9fce0783..18bc2693 100644 --- a/clause/select_test.go +++ b/clause/select_test.go @@ -43,6 +43,23 @@ func TestSelect(t *testing.T) { }, clause.From{}}, "SELECT `id`, `name`, LENGTH(`mobile`) FROM `users`", nil, }, + { + []clause.Interface{clause.Select{ + Expression: clause.CommaExpression{ + Exprs: []clause.Expression{ + clause.Expr{ + SQL: "? as name", + Vars: []interface{}{clause.Eq{ + Column: clause.Column{Name: "age"}, + Value: 18, + }, + }, + }, + }, + }, + }, clause.From{}}, + "SELECT `age` = ? as name FROM `users`", []interface{}{18}, + }, } for idx, result := range results { diff --git a/statement.go b/statement.go index 146722a9..72359da2 100644 --- a/statement.go +++ b/statement.go @@ -183,6 +183,8 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) { v.Build(stmt) case *clause.Expr: v.Build(stmt) + case clause.Expression: + v.Build(stmt) case driver.Valuer: stmt.Vars = append(stmt.Vars, v) stmt.DB.Dialector.BindVarTo(writer, stmt, v) diff --git a/tests/go.mod b/tests/go.mod index 3233ea95..5415cf74 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -3,11 +3,13 @@ module gorm.io/gorm/tests go 1.14 require ( + github.com/denisenkom/go-mssqldb v0.12.0 // indirect github.com/google/uuid v1.3.0 github.com/jackc/pgx/v4 v4.14.1 // indirect github.com/jinzhu/now v1.1.4 github.com/lib/pq v1.10.4 - golang.org/x/crypto v0.0.0-20211209193657-4570a0811e8b // indirect + github.com/mattn/go-sqlite3 v1.14.10 // indirect + golang.org/x/crypto v0.0.0-20220126234351-aa10faf2a1f8 // indirect gorm.io/driver/mysql v1.2.3 gorm.io/driver/postgres v1.2.3 gorm.io/driver/sqlite v1.2.6 From 98c4b78e4dcceea93eaaabd051f8c021e645e017 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 28 Jan 2022 19:26:10 +0800 Subject: [PATCH 02/87] Add Session Initialized option --- gorm.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/gorm.go b/gorm.go index fc70f684..a982bee4 100644 --- a/gorm.go +++ b/gorm.go @@ -96,6 +96,7 @@ type Session struct { DryRun bool PrepareStmt bool NewDB bool + Initialized bool SkipHooks bool SkipDefaultTransaction bool DisableNestedTransaction bool @@ -282,6 +283,10 @@ func (db *DB) Session(config *Session) *DB { tx.Config.NowFunc = config.NowFunc } + if config.Initialized { + tx = tx.getInstance() + } + return tx } From c0bea447b9eb707cfc1712d2d423f43309e247a2 Mon Sep 17 00:00:00 2001 From: li-jin-gou <97824201+li-jin-gou@users.noreply.github.com> Date: Fri, 28 Jan 2022 22:16:42 +0800 Subject: [PATCH 03/87] fix: omit not work when use join (#5034) --- callbacks/query.go | 2 +- tests/connection_test.go | 3 +-- tests/joins_test.go | 16 ++++++++++++++++ 3 files changed, 18 insertions(+), 3 deletions(-) diff --git a/callbacks/query.go b/callbacks/query.go index c2bbf5f9..49086354 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -100,7 +100,7 @@ func BuildQuerySQL(db *gorm.DB) { } if len(db.Statement.Joins) != 0 || len(joins) != 0 { - if len(db.Statement.Selects) == 0 && db.Statement.Schema != nil { + if len(db.Statement.Selects) == 0 && len(db.Statement.Omits) == 0 && db.Statement.Schema != nil { clauseSelect.Columns = make([]clause.Column, len(db.Statement.Schema.DBNames)) for idx, dbName := range db.Statement.Schema.DBNames { clauseSelect.Columns[idx] = clause.Column{Table: db.Statement.Table, Name: dbName} diff --git a/tests/connection_test.go b/tests/connection_test.go index 92b13dd6..7bc23009 100644 --- a/tests/connection_test.go +++ b/tests/connection_test.go @@ -9,7 +9,7 @@ import ( ) func TestWithSingleConnection(t *testing.T) { - var expectedName = "test" + expectedName := "test" var actualName string setSQL, getSQL := getSetSQL(DB.Dialector.Name()) @@ -27,7 +27,6 @@ func TestWithSingleConnection(t *testing.T) { } return nil }) - if err != nil { t.Errorf(fmt.Sprintf("WithSingleConnection should work, but got err %v", err)) } diff --git a/tests/joins_test.go b/tests/joins_test.go index e276a74a..4c9cffae 100644 --- a/tests/joins_test.go +++ b/tests/joins_test.go @@ -158,6 +158,22 @@ func TestJoinsWithSelect(t *testing.T) { } } +func TestJoinWithOmit(t *testing.T) { + user := *GetUser("joins_with_omit", Config{Pets: 2}) + DB.Save(&user) + + results := make([]*User, 0) + + if err := DB.Table("users").Omit("name").Where("users.name = ?", "joins_with_omit").Joins("left join pets on pets.user_id = users.id").Find(&results).Error; err != nil { + return + } + + if len(results) != 2 || results[0].Name != "" || results[1].Name != "" { + t.Errorf("Should find all two pets with Join omit and should not find user's name, got %+v", results) + return + } +} + func TestJoinCount(t *testing.T) { companyA := Company{Name: "A"} companyB := Company{Name: "B"} From 8c3673286dc6091967e2349687f0dbbaa55d66f8 Mon Sep 17 00:00:00 2001 From: Ning Date: Sun, 30 Jan 2022 18:17:06 +0800 Subject: [PATCH 04/87] preoload not allowd before count (#5023) Co-authored-by: ningfei --- errors.go | 2 ++ finisher_api.go | 4 ++++ tests/count_test.go | 10 ++++++++++ 3 files changed, 16 insertions(+) diff --git a/errors.go b/errors.go index 145614d9..49cbfe64 100644 --- a/errors.go +++ b/errors.go @@ -39,4 +39,6 @@ var ( ErrInvalidValue = errors.New("invalid value, should be pointer to struct or slice") // ErrInvalidValueOfLength invalid values do not match length ErrInvalidValueOfLength = errors.New("invalid association values, length doesn't match") + // ErrPreloadNotAllowed preload is not allowed when count is used + ErrPreloadNotAllowed = errors.New("preload is not allowed when count is used") ) diff --git a/finisher_api.go b/finisher_api.go index 355d89bd..cbbd48cb 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -367,6 +367,10 @@ func (db *DB) Delete(value interface{}, conds ...interface{}) (tx *DB) { func (db *DB) Count(count *int64) (tx *DB) { tx = db.getInstance() + if len(tx.Statement.Preloads) > 0 { + tx.AddError(ErrPreloadNotAllowed) + return + } if tx.Statement.Model == nil { tx.Statement.Model = tx.Statement.Dest defer func() { diff --git a/tests/count_test.go b/tests/count_test.go index 27d7ee60..b63a55fc 100644 --- a/tests/count_test.go +++ b/tests/count_test.go @@ -144,4 +144,14 @@ func TestCount(t *testing.T) { if err := DB.Model(&User{}).Where("name = ?", "count-4").Group("name").Count(&count11).Error; err != nil || count11 != 1 { t.Fatalf("Count should be 3, but got count: %v err %v", count11, err) } + + var count12 int64 + if err := DB.Table("users"). + Where("name in ?", []string{user1.Name, user2.Name, user3.Name}). + Preload("Toys", func(db *gorm.DB) *gorm.DB { + return db.Table("toys").Select("name") + }).Count(&count12).Error; err != gorm.ErrPreloadNotAllowed { + t.Errorf("should returns preload not allowed error, but got %v", err) + } + } From 8d293d44dd7e4e6f61d759cb6c9a5be2c6523c5e Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 30 Jan 2022 22:00:56 +0800 Subject: [PATCH 05/87] Fix docker-compose test env for Mac M1 --- tests/docker-compose.yml | 4 ++-- tests/go.mod | 6 +++--- tests/tests_all.sh | 17 +++++++++++++++++ tests/tests_test.go | 11 ++++++----- 4 files changed, 28 insertions(+), 10 deletions(-) diff --git a/tests/docker-compose.yml b/tests/docker-compose.yml index 05e0956e..9ab4ddb6 100644 --- a/tests/docker-compose.yml +++ b/tests/docker-compose.yml @@ -2,7 +2,7 @@ version: '3' services: mysql: - image: 'mysql:latest' + image: 'mysql/mysql-server:latest' ports: - 9910:3306 environment: @@ -20,7 +20,7 @@ services: - POSTGRES_USER=gorm - POSTGRES_PASSWORD=gorm mssql: - image: 'mcmoe/mssqldocker:latest' + image: '${MSSQL_IMAGE:-mcmoe/mssqldocker}:latest' ports: - 9930:1433 environment: diff --git a/tests/go.mod b/tests/go.mod index 5415cf74..f2addaa1 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -8,13 +8,13 @@ require ( github.com/jackc/pgx/v4 v4.14.1 // indirect github.com/jinzhu/now v1.1.4 github.com/lib/pq v1.10.4 - github.com/mattn/go-sqlite3 v1.14.10 // indirect - golang.org/x/crypto v0.0.0-20220126234351-aa10faf2a1f8 // indirect + github.com/mattn/go-sqlite3 v1.14.11 // indirect + golang.org/x/crypto v0.0.0-20220128200615-198e4374d7ed // indirect gorm.io/driver/mysql v1.2.3 gorm.io/driver/postgres v1.2.3 gorm.io/driver/sqlite v1.2.6 gorm.io/driver/sqlserver v1.2.1 - gorm.io/gorm v1.22.4 + gorm.io/gorm v1.22.5 ) replace gorm.io/gorm => ../ diff --git a/tests/tests_all.sh b/tests/tests_all.sh index 79e0b5b7..e1f394e5 100755 --- a/tests/tests_all.sh +++ b/tests/tests_all.sh @@ -15,6 +15,23 @@ then cd .. fi +# SqlServer for Mac M1 +if [ -d tests ] +then + cd tests + if [[ $(uname -a) == *" arm64" ]]; then + MSSQL_IMAGE=mcr.microsoft.com/azure-sql-edge docker-compose start + go install github.com/microsoft/go-sqlcmd/cmd/sqlcmd@latest + SQLCMDPASSWORD=LoremIpsum86 sqlcmd -U sa -S localhost:9930 -Q "IF DB_ID('gorm') IS NULL CREATE DATABASE gorm" > /dev/null + SQLCMDPASSWORD=LoremIpsum86 sqlcmd -U sa -S localhost:9930 -Q "IF SUSER_ID (N'gorm') IS NULL CREATE LOGIN gorm WITH PASSWORD = 'LoremIpsum86';" > /dev/null + SQLCMDPASSWORD=LoremIpsum86 sqlcmd -U sa -S localhost:9930 -Q "IF USER_ID (N'gorm') IS NULL CREATE USER gorm FROM LOGIN gorm; ALTER SERVER ROLE sysadmin ADD MEMBER [gorm];" > /dev/null + else + docker-compose start + fi + cd .. +fi + + for dialect in "${dialects[@]}" ; do if [ "$GORM_DIALECT" = "" ] || [ "$GORM_DIALECT" = "${dialect}" ] then diff --git a/tests/tests_test.go b/tests/tests_test.go index e26f358d..11b6f067 100644 --- a/tests/tests_test.go +++ b/tests/tests_test.go @@ -62,13 +62,14 @@ func OpenTestConnection() (db *gorm.DB, err error) { PreferSimpleProtocol: true, }), &gorm.Config{}) case "sqlserver": - // CREATE LOGIN gorm WITH PASSWORD = 'LoremIpsum86'; + // go install github.com/microsoft/go-sqlcmd/cmd/sqlcmd@latest + // SQLCMDPASSWORD=LoremIpsum86 sqlcmd -U sa -S localhost:9930 // CREATE DATABASE gorm; - // USE gorm; + // GO + // CREATE LOGIN gorm WITH PASSWORD = 'LoremIpsum86'; // CREATE USER gorm FROM LOGIN gorm; - // sp_changedbowner 'gorm'; - // npm install -g sql-cli - // mssql -u gorm -p LoremIpsum86 -d gorm -o 9930 + // ALTER SERVER ROLE sysadmin ADD MEMBER [gorm]; + // GO log.Println("testing sqlserver...") if dbDSN == "" { dbDSN = "sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm" From f19b84d104a2659af7b32c1cacd92a35efa33d34 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 30 Jan 2022 22:32:34 +0800 Subject: [PATCH 06/87] Fix github action --- .github/workflows/tests.yml | 8 ++++---- tests/tests_all.sh | 26 ++++++++++++++------------ 2 files changed, 18 insertions(+), 16 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 700af759..91a0abc9 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -33,7 +33,7 @@ jobs: key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} - name: Tests - run: GORM_DIALECT=sqlite ./tests/tests_all.sh + run: GITHUB_ACTION=true GORM_DIALECT=sqlite ./tests/tests_all.sh mysql: strategy: @@ -77,7 +77,7 @@ jobs: key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} - name: Tests - run: GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True" ./tests/tests_all.sh + run: GITHUB_ACTION=true GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True" ./tests/tests_all.sh postgres: strategy: @@ -120,7 +120,7 @@ jobs: key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} - name: Tests - run: GORM_DIALECT=postgres GORM_DSN="user=gorm password=gorm dbname=gorm host=localhost port=9920 sslmode=disable TimeZone=Asia/Shanghai" ./tests/tests_all.sh + run: GITHUB_ACTION=true GORM_DIALECT=postgres GORM_DSN="user=gorm password=gorm dbname=gorm host=localhost port=9920 sslmode=disable TimeZone=Asia/Shanghai" ./tests/tests_all.sh sqlserver: strategy: @@ -163,4 +163,4 @@ jobs: key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} - name: Tests - run: GORM_DIALECT=sqlserver GORM_DSN="sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm" ./tests/tests_all.sh + run: GITHUB_ACTION=true GORM_DIALECT=sqlserver GORM_DSN="sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm" ./tests/tests_all.sh diff --git a/tests/tests_all.sh b/tests/tests_all.sh index e1f394e5..5b9bae97 100755 --- a/tests/tests_all.sh +++ b/tests/tests_all.sh @@ -16,19 +16,21 @@ then fi # SqlServer for Mac M1 -if [ -d tests ] -then - cd tests - if [[ $(uname -a) == *" arm64" ]]; then - MSSQL_IMAGE=mcr.microsoft.com/azure-sql-edge docker-compose start - go install github.com/microsoft/go-sqlcmd/cmd/sqlcmd@latest - SQLCMDPASSWORD=LoremIpsum86 sqlcmd -U sa -S localhost:9930 -Q "IF DB_ID('gorm') IS NULL CREATE DATABASE gorm" > /dev/null - SQLCMDPASSWORD=LoremIpsum86 sqlcmd -U sa -S localhost:9930 -Q "IF SUSER_ID (N'gorm') IS NULL CREATE LOGIN gorm WITH PASSWORD = 'LoremIpsum86';" > /dev/null - SQLCMDPASSWORD=LoremIpsum86 sqlcmd -U sa -S localhost:9930 -Q "IF USER_ID (N'gorm') IS NULL CREATE USER gorm FROM LOGIN gorm; ALTER SERVER ROLE sysadmin ADD MEMBER [gorm];" > /dev/null - else - docker-compose start +if [[ -z $GITHUB_ACTION ]]; then + if [ -d tests ] + then + cd tests + if [[ $(uname -a) == *" arm64" ]]; then + MSSQL_IMAGE=mcr.microsoft.com/azure-sql-edge docker-compose start || true + go install github.com/microsoft/go-sqlcmd/cmd/sqlcmd@latest || true + SQLCMDPASSWORD=LoremIpsum86 sqlcmd -U sa -S localhost:9930 -Q "IF DB_ID('gorm') IS NULL CREATE DATABASE gorm" > /dev/null || true + SQLCMDPASSWORD=LoremIpsum86 sqlcmd -U sa -S localhost:9930 -Q "IF SUSER_ID (N'gorm') IS NULL CREATE LOGIN gorm WITH PASSWORD = 'LoremIpsum86';" > /dev/null || true + SQLCMDPASSWORD=LoremIpsum86 sqlcmd -U sa -S localhost:9930 -Q "IF USER_ID (N'gorm') IS NULL CREATE USER gorm FROM LOGIN gorm; ALTER SERVER ROLE sysadmin ADD MEMBER [gorm];" > /dev/null || true + else + docker-compose start + fi + cd .. fi - cd .. fi From 581a879bf1ff1af7fcb361f0c6e4b201dbed75f0 Mon Sep 17 00:00:00 2001 From: Saurabh Thakre Date: Mon, 31 Jan 2022 17:26:28 +0530 Subject: [PATCH 07/87] Added comments to existing methods Added two comments to describe FirstOrInit and FirstOrCreate methods. --- finisher_api.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/finisher_api.go b/finisher_api.go index cbbd48cb..3a179977 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -255,7 +255,7 @@ func (tx *DB) assignInterfacesToValue(values ...interface{}) { } } } - +// FirstOrInit gets the first matched record or initialize a new instance with given conditions (only works with struct or map conditions) func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { queryTx := db.Limit(1).Order(clause.OrderByColumn{ Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, @@ -281,6 +281,7 @@ func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { return } +// FirstOrCreate gets the first matched record or create a new one with given conditions (only works with struct, map conditions) func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) { queryTx := db.Limit(1).Order(clause.OrderByColumn{ Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, From 416c4d0653ce6e0569e6c868963a6c3cc769c2fb Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 8 Feb 2022 16:31:24 +0800 Subject: [PATCH 08/87] Test query with Or and soft delete --- tests/go.mod | 4 ++-- tests/query_test.go | 8 +++++++- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/tests/go.mod b/tests/go.mod index f2addaa1..5488c17e 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -5,11 +5,11 @@ go 1.14 require ( github.com/denisenkom/go-mssqldb v0.12.0 // indirect github.com/google/uuid v1.3.0 - github.com/jackc/pgx/v4 v4.14.1 // indirect + github.com/jackc/pgx/v4 v4.15.0 // indirect github.com/jinzhu/now v1.1.4 github.com/lib/pq v1.10.4 github.com/mattn/go-sqlite3 v1.14.11 // indirect - golang.org/x/crypto v0.0.0-20220128200615-198e4374d7ed // indirect + golang.org/x/crypto v0.0.0-20220208050332-20e1d8d225ab // indirect gorm.io/driver/mysql v1.2.3 gorm.io/driver/postgres v1.2.3 gorm.io/driver/sqlite v1.2.6 diff --git a/tests/query_test.go b/tests/query_test.go index c99214b6..d10df180 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -512,7 +512,13 @@ func TestNotWithAllFields(t *testing.T) { func TestOr(t *testing.T) { dryDB := DB.Session(&gorm.Session{DryRun: true}) - result := dryDB.Where("role = ?", "admin").Where(DB.Or("role = ?", "super_admin")).Find(&User{}) + var count int64 + result := dryDB.Model(&User{}).Or("role = ?", "admin").Count(&count) + if !regexp.MustCompile("SELECT count\\(\\*\\) FROM .*users.* WHERE role = .+ AND .*users.*\\..*deleted_at.* IS NULL").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build OR condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Where("role = ?", "admin").Where(DB.Or("role = ?", "super_admin")).Find(&User{}) if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*role.* = .+ AND .*role.* = .+").MatchString(result.Statement.SQL.String()) { t.Fatalf("Build OR condition, but got %v", result.Statement.SQL.String()) } From d22215129ee4747f9a9dd5b089d9f6920efc91ad Mon Sep 17 00:00:00 2001 From: li-jin-gou <97824201+li-jin-gou@users.noreply.github.com> Date: Tue, 8 Feb 2022 17:06:10 +0800 Subject: [PATCH 09/87] fix: replace empty table name result in panic (#5048) * fix: replace empty name result in panic * fix: replace empty table name result in panic --- schema/naming.go | 8 +++++++- schema/naming_test.go | 11 +++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/schema/naming.go b/schema/naming.go index 8407bffa..a4e3a75b 100644 --- a/schema/naming.go +++ b/schema/naming.go @@ -120,7 +120,13 @@ func (ns NamingStrategy) toDBName(name string) string { } if ns.NameReplacer != nil { - name = ns.NameReplacer.Replace(name) + tmpName := ns.NameReplacer.Replace(name) + + if tmpName == "" { + return name + } + + name = tmpName } if ns.NoLowerCase { diff --git a/schema/naming_test.go b/schema/naming_test.go index c3e6bf92..1fdab9a0 100644 --- a/schema/naming_test.go +++ b/schema/naming_test.go @@ -197,3 +197,14 @@ func TestFormatNameWithStringLongerThan64Characters(t *testing.T) { t.Errorf("invalid formatted name generated, got %v", formattedName) } } + +func TestReplaceEmptyTableName(t *testing.T) { + ns := NamingStrategy{ + SingularTable: true, + NameReplacer: strings.NewReplacer("Model", ""), + } + tableName := ns.TableName("Model") + if tableName != "Model" { + t.Errorf("invalid table name generated, got %v", tableName) + } +} From 4eeb839ceabb983b634f9cf9fffa1dd773b6803d Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 9 Feb 2022 15:17:19 +0800 Subject: [PATCH 10/87] Better support Stringer when explain SQL --- logger/logger.go | 14 ++++++++++- logger/sql.go | 24 ++++++++++++++---- tests/go.mod | 2 +- tests/sql_builder_test.go | 53 +++++++++++++++++++++++++++++++++++++++ 4 files changed, 86 insertions(+), 7 deletions(-) diff --git a/logger/logger.go b/logger/logger.go index 0c4ca4a0..2ffd28d5 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -12,6 +12,7 @@ import ( "gorm.io/gorm/utils" ) +// ErrRecordNotFound record not found error var ErrRecordNotFound = errors.New("record not found") // Colors @@ -30,13 +31,17 @@ const ( YellowBold = "\033[33;1m" ) -// LogLevel +// LogLevel log level type LogLevel int const ( + // Silent silent log level Silent LogLevel = iota + 1 + // Error error log level Error + // Warn warn log level Warn + // Info info log level Info ) @@ -45,6 +50,7 @@ type Writer interface { Printf(string, ...interface{}) } +// Config logger config type Config struct { SlowThreshold time.Duration Colorful bool @@ -62,16 +68,20 @@ type Interface interface { } var ( + // Discard Discard logger will print any log to ioutil.Discard Discard = New(log.New(ioutil.Discard, "", log.LstdFlags), Config{}) + // Default Default logger Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{ SlowThreshold: 200 * time.Millisecond, LogLevel: Warn, IgnoreRecordNotFoundError: false, Colorful: true, }) + // Recorder Recorder logger records running SQL into a recorder instance Recorder = traceRecorder{Interface: Default, BeginAt: time.Now()} ) +// New initialize logger func New(writer Writer, config Config) Interface { var ( infoStr = "%s\n[info] " @@ -179,10 +189,12 @@ type traceRecorder struct { Err error } +// New new trace recorder func (l traceRecorder) New() *traceRecorder { return &traceRecorder{Interface: l.Interface, BeginAt: time.Now()} } +// Trace implement logger interface func (l *traceRecorder) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) { l.BeginAt = begin l.SQL, l.RowsAffected = fc() diff --git a/logger/sql.go b/logger/sql.go index 5ecb0ae2..e0be57c0 100644 --- a/logger/sql.go +++ b/logger/sql.go @@ -30,9 +30,12 @@ func isPrintable(s []byte) bool { var convertibleTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeOf(false), reflect.TypeOf([]byte{})} +// ExplainSQL generate SQL string with given parameters, the generated SQL is expected to be used in logger, execute it might introduce a SQL injection vulnerability func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, avars ...interface{}) string { - var convertParams func(interface{}, int) - vars := make([]string, len(avars)) + var ( + convertParams func(interface{}, int) + vars = make([]string, len(avars)) + ) convertParams = func(v interface{}, idx int) { switch v := v.(type) { @@ -64,10 +67,21 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a } case fmt.Stringer: reflectValue := reflect.ValueOf(v) - if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) { + switch reflectValue.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + vars[idx] = fmt.Sprintf("%d", reflectValue.Interface()) + case reflect.Float32, reflect.Float64: + vars[idx] = fmt.Sprintf("%.6f", reflectValue.Interface()) + case reflect.Bool: + vars[idx] = fmt.Sprintf("%t", reflectValue.Interface()) + case reflect.String: vars[idx] = escaper + strings.Replace(fmt.Sprintf("%v", v), escaper, "\\"+escaper, -1) + escaper - } else { - vars[idx] = nullStr + default: + if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) { + vars[idx] = escaper + strings.Replace(fmt.Sprintf("%v", v), escaper, "\\"+escaper, -1) + escaper + } else { + vars[idx] = nullStr + } } case []byte: if isPrintable(v) { diff --git a/tests/go.mod b/tests/go.mod index 5488c17e..3453f77b 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -9,7 +9,7 @@ require ( github.com/jinzhu/now v1.1.4 github.com/lib/pq v1.10.4 github.com/mattn/go-sqlite3 v1.14.11 // indirect - golang.org/x/crypto v0.0.0-20220208050332-20e1d8d225ab // indirect + golang.org/x/crypto v0.0.0-20220208233918-bba287dce954 // indirect gorm.io/driver/mysql v1.2.3 gorm.io/driver/postgres v1.2.3 gorm.io/driver/sqlite v1.2.6 diff --git a/tests/sql_builder_test.go b/tests/sql_builder_test.go index 237d807b..897f687f 100644 --- a/tests/sql_builder_test.go +++ b/tests/sql_builder_test.go @@ -168,6 +168,59 @@ func TestDryRun(t *testing.T) { } } +type ageInt int8 + +func (ageInt) String() string { + return "age" +} + +type ageBool bool + +func (ageBool) String() string { + return "age" +} + +type ageUint64 uint64 + +func (ageUint64) String() string { + return "age" +} + +type ageFloat float64 + +func (ageFloat) String() string { + return "age" +} + +func TestExplainSQL(t *testing.T) { + user := *GetUser("explain-sql", Config{}) + dryRunDB := DB.Session(&gorm.Session{DryRun: true}) + + stmt := dryRunDB.Model(&user).Where("id = ?", 1).Updates(map[string]interface{}{"age": ageInt(8)}).Statement + sql := DB.Dialector.Explain(stmt.SQL.String(), stmt.Vars...) + if !regexp.MustCompile(`.*age.*=8,`).MatchString(sql) { + t.Errorf("Failed to generate sql, got %v", sql) + } + + stmt = dryRunDB.Model(&user).Where("id = ?", 1).Updates(map[string]interface{}{"age": ageUint64(10241024)}).Statement + sql = DB.Dialector.Explain(stmt.SQL.String(), stmt.Vars...) + if !regexp.MustCompile(`.*age.*=10241024,`).MatchString(sql) { + t.Errorf("Failed to generate sql, got %v", sql) + } + + stmt = dryRunDB.Model(&user).Where("id = ?", 1).Updates(map[string]interface{}{"age": ageBool(false)}).Statement + sql = DB.Dialector.Explain(stmt.SQL.String(), stmt.Vars...) + if !regexp.MustCompile(`.*age.*=false,`).MatchString(sql) { + t.Errorf("Failed to generate sql, got %v", sql) + } + + stmt = dryRunDB.Model(&user).Where("id = ?", 1).Updates(map[string]interface{}{"age": ageFloat(0.12345678)}).Statement + sql = DB.Dialector.Explain(stmt.SQL.String(), stmt.Vars...) + if !regexp.MustCompile(`.*age.*=0.123457,`).MatchString(sql) { + t.Errorf("Failed to generate sql, got %v", sql) + } +} + func TestGroupConditions(t *testing.T) { type Pizza struct { ID uint From df2365057bb6c809b03d470323238262a93a9685 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 9 Feb 2022 17:23:16 +0800 Subject: [PATCH 11/87] Remove uncessary switch case --- statement.go | 4 ---- 1 file changed, 4 deletions(-) diff --git a/statement.go b/statement.go index 72359da2..23212642 100644 --- a/statement.go +++ b/statement.go @@ -179,10 +179,6 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) { } else { stmt.AddVar(writer, v.GormValue(stmt.Context, stmt.DB)) } - case clause.Expr: - v.Build(stmt) - case *clause.Expr: - v.Build(stmt) case clause.Expression: v.Build(stmt) case driver.Valuer: From a0aceeb33e7eabbecae5b7fd2eef874b1a77b086 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 9 Feb 2022 17:39:01 +0800 Subject: [PATCH 12/87] Migrator AlterColumn with full data type --- gorm.go | 6 ++++++ migrator/migrator.go | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/gorm.go b/gorm.go index a982bee4..7967b094 100644 --- a/gorm.go +++ b/gorm.go @@ -59,6 +59,7 @@ type Config struct { cacheStore *sync.Map } +// Apply update config to new config func (c *Config) Apply(config *Config) error { if config != c { *config = *c @@ -66,6 +67,7 @@ func (c *Config) Apply(config *Config) error { return nil } +// AfterInitialize initialize plugins after db connected func (c *Config) AfterInitialize(db *DB) error { if db != nil { for _, plugin := range c.Plugins { @@ -77,6 +79,7 @@ func (c *Config) AfterInitialize(db *DB) error { return nil } +// Option gorm option interface type Option interface { Apply(*Config) error AfterInitialize(*DB) error @@ -381,10 +384,12 @@ func (db *DB) getInstance() *DB { return db } +// Expr returns clause.Expr, which can be used to pass SQL expression as params func Expr(expr string, args ...interface{}) clause.Expr { return clause.Expr{SQL: expr, Vars: args} } +// SetupJoinTable setup join table schema func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interface{}) error { var ( tx = db.getInstance() @@ -435,6 +440,7 @@ func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interfac return nil } +// Use use plugin func (db *DB) Use(plugin Plugin) error { name := plugin.Name() if _, ok := db.Plugins[name]; ok { diff --git a/migrator/migrator.go b/migrator/migrator.go index 138917fb..80c4e2b3 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -337,7 +337,7 @@ func (m Migrator) DropColumn(value interface{}, name string) error { func (m Migrator) AlterColumn(value interface{}, field string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { if field := stmt.Schema.LookUpField(field); field != nil { - fileType := clause.Expr{SQL: m.DataTypeOf(field)} + fileType := m.FullDataTypeOf(field) return m.DB.Exec( "ALTER TABLE ? ALTER COLUMN ? TYPE ?", m.CurrentTable(stmt), clause.Column{Name: field.DBName}, fileType, From 19ac396a22668e2cdbd77a262de84478787989d0 Mon Sep 17 00:00:00 2001 From: li-jin-gou <97824201+li-jin-gou@users.noreply.github.com> Date: Tue, 15 Feb 2022 20:32:03 +0800 Subject: [PATCH 13/87] fix: isPrintable incorrect (#5076) * fix: isPrintable incorrect * fix: isPrintable incorrect * style: use ReplaceAll instead of Replace --- logger/sql.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/logger/sql.go b/logger/sql.go index e0be57c0..04a2dbd4 100644 --- a/logger/sql.go +++ b/logger/sql.go @@ -19,9 +19,9 @@ const ( nullStr = "NULL" ) -func isPrintable(s []byte) bool { +func isPrintable(s string) bool { for _, r := range s { - if !unicode.IsPrint(rune(r)) { + if !unicode.IsPrint(r) { return false } } @@ -84,8 +84,8 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a } } case []byte: - if isPrintable(v) { - vars[idx] = escaper + strings.Replace(string(v), escaper, "\\"+escaper, -1) + escaper + if s := string(v); isPrintable(s) { + vars[idx] = escaper + strings.ReplaceAll(s, escaper, "\\"+escaper) + escaper } else { vars[idx] = escaper + "" + escaper } From 39d84cba5f7403dd60aee6f7aa2cb0b6bb48f82b Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 16 Feb 2022 15:30:43 +0800 Subject: [PATCH 14/87] Add serializer support (#5078) * Update context * Update GormFieldValuer * Add Serializer * Add Serializer Interface * Refactor gorm field * Refactor setter, valuer * Add sync.Pool * Fix test * Add pool manager * Fix pool manager * Add poolInitializer * Add Serializer Scan support * Add Serializer Value method * Add serializer test * Finish Serializer * Fix JSONSerializer for postgres * Fix JSONSerializer for sqlserver * Test serializer tag * Add unixtime serializer * Update go.mod --- association.go | 64 ++-- callbacks/associations.go | 58 ++-- callbacks/create.go | 40 +-- callbacks/delete.go | 8 +- callbacks/preload.go | 28 +- callbacks/query.go | 2 +- callbacks/update.go | 14 +- finisher_api.go | 12 +- interfaces.go | 4 + scan.go | 32 +- schema/field.go | 564 ++++++++++++++++++++--------------- schema/field_test.go | 13 +- schema/interfaces.go | 11 + schema/pool.go | 62 ++++ schema/relationship.go | 5 +- schema/schema_helper_test.go | 3 +- schema/serializer.go | 125 ++++++++ schema/utils.go | 17 +- soft_delete.go | 4 +- statement.go | 16 +- tests/create_test.go | 2 +- tests/go.mod | 2 +- tests/serializer_test.go | 71 +++++ utils/utils.go | 17 +- 24 files changed, 773 insertions(+), 401 deletions(-) create mode 100644 schema/pool.go create mode 100644 schema/serializer.go create mode 100644 tests/serializer_test.go diff --git a/association.go b/association.go index 62c25b71..09e79ca6 100644 --- a/association.go +++ b/association.go @@ -79,10 +79,10 @@ func (association *Association) Replace(values ...interface{}) error { switch reflectValue.Kind() { case reflect.Slice, reflect.Array: for i := 0; i < reflectValue.Len(); i++ { - association.Error = rel.Field.Set(reflectValue.Index(i), reflect.Zero(rel.Field.FieldType).Interface()) + association.Error = rel.Field.Set(association.DB.Statement.Context, reflectValue.Index(i), reflect.Zero(rel.Field.FieldType).Interface()) } case reflect.Struct: - association.Error = rel.Field.Set(reflectValue, reflect.Zero(rel.Field.FieldType).Interface()) + association.Error = rel.Field.Set(association.DB.Statement.Context, reflectValue, reflect.Zero(rel.Field.FieldType).Interface()) } for _, ref := range rel.References { @@ -96,12 +96,12 @@ func (association *Association) Replace(values ...interface{}) error { primaryFields []*schema.Field foreignKeys []string updateMap = map[string]interface{}{} - relValues = schema.GetRelationsValues(reflectValue, []*schema.Relationship{rel}) + relValues = schema.GetRelationsValues(association.DB.Statement.Context, reflectValue, []*schema.Relationship{rel}) modelValue = reflect.New(rel.FieldSchema.ModelType).Interface() tx = association.DB.Model(modelValue) ) - if _, rvs := schema.GetIdentityFieldValuesMap(relValues, rel.FieldSchema.PrimaryFields); len(rvs) > 0 { + if _, rvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, relValues, rel.FieldSchema.PrimaryFields); len(rvs) > 0 { if column, values := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs); len(values) > 0 { tx.Not(clause.IN{Column: column, Values: values}) } @@ -117,7 +117,7 @@ func (association *Association) Replace(values ...interface{}) error { } } - if _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields); len(pvs) > 0 { + if _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields); len(pvs) > 0 { column, values := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs) association.Error = tx.Where(clause.IN{Column: column, Values: values}).UpdateColumns(updateMap).Error } @@ -143,14 +143,14 @@ func (association *Association) Replace(values ...interface{}) error { } } - _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields) + _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields) if column, values := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs); len(values) > 0 { tx.Where(clause.IN{Column: column, Values: values}) } else { return ErrPrimaryKeyRequired } - _, rvs := schema.GetIdentityFieldValuesMapFromValues(values, relPrimaryFields) + _, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, relPrimaryFields) if relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs); len(relValues) > 0 { tx.Where(clause.Not(clause.IN{Column: relColumn, Values: relValues})) } @@ -186,11 +186,11 @@ func (association *Association) Delete(values ...interface{}) error { case schema.BelongsTo: tx := association.DB.Model(reflect.New(rel.Schema.ModelType).Interface()) - _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, rel.Schema.PrimaryFields) + _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, rel.Schema.PrimaryFields) pcolumn, pvalues := schema.ToQueryValues(rel.Schema.Table, rel.Schema.PrimaryFieldDBNames, pvs) conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues}) - _, rvs := schema.GetIdentityFieldValuesMapFromValues(values, primaryFields) + _, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, primaryFields) relColumn, relValues := schema.ToQueryValues(rel.Schema.Table, foreignKeys, rvs) conds = append(conds, clause.IN{Column: relColumn, Values: relValues}) @@ -198,11 +198,11 @@ func (association *Association) Delete(values ...interface{}) error { case schema.HasOne, schema.HasMany: tx := association.DB.Model(reflect.New(rel.FieldSchema.ModelType).Interface()) - _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields) + _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields) pcolumn, pvalues := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs) conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues}) - _, rvs := schema.GetIdentityFieldValuesMapFromValues(values, rel.FieldSchema.PrimaryFields) + _, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, rel.FieldSchema.PrimaryFields) relColumn, relValues := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs) conds = append(conds, clause.IN{Column: relColumn, Values: relValues}) @@ -228,11 +228,11 @@ func (association *Association) Delete(values ...interface{}) error { } } - _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields) + _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields) pcolumn, pvalues := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs) conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues}) - _, rvs := schema.GetIdentityFieldValuesMapFromValues(values, relPrimaryFields) + _, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, relPrimaryFields) relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs) conds = append(conds, clause.IN{Column: relColumn, Values: relValues}) @@ -241,11 +241,11 @@ func (association *Association) Delete(values ...interface{}) error { if association.Error == nil { // clean up deleted values's foreign key - relValuesMap, _ := schema.GetIdentityFieldValuesMapFromValues(values, rel.FieldSchema.PrimaryFields) + relValuesMap, _ := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, rel.FieldSchema.PrimaryFields) cleanUpDeletedRelations := func(data reflect.Value) { - if _, zero := rel.Field.ValueOf(data); !zero { - fieldValue := reflect.Indirect(rel.Field.ReflectValueOf(data)) + if _, zero := rel.Field.ValueOf(association.DB.Statement.Context, data); !zero { + fieldValue := reflect.Indirect(rel.Field.ReflectValueOf(association.DB.Statement.Context, data)) primaryValues := make([]interface{}, len(rel.FieldSchema.PrimaryFields)) switch fieldValue.Kind() { @@ -253,7 +253,7 @@ func (association *Association) Delete(values ...interface{}) error { validFieldValues := reflect.Zero(rel.Field.IndirectFieldType) for i := 0; i < fieldValue.Len(); i++ { for idx, field := range rel.FieldSchema.PrimaryFields { - primaryValues[idx], _ = field.ValueOf(fieldValue.Index(i)) + primaryValues[idx], _ = field.ValueOf(association.DB.Statement.Context, fieldValue.Index(i)) } if _, ok := relValuesMap[utils.ToStringKey(primaryValues...)]; !ok { @@ -261,23 +261,23 @@ func (association *Association) Delete(values ...interface{}) error { } } - association.Error = rel.Field.Set(data, validFieldValues.Interface()) + association.Error = rel.Field.Set(association.DB.Statement.Context, data, validFieldValues.Interface()) case reflect.Struct: for idx, field := range rel.FieldSchema.PrimaryFields { - primaryValues[idx], _ = field.ValueOf(fieldValue) + primaryValues[idx], _ = field.ValueOf(association.DB.Statement.Context, fieldValue) } if _, ok := relValuesMap[utils.ToStringKey(primaryValues...)]; ok { - if association.Error = rel.Field.Set(data, reflect.Zero(rel.FieldSchema.ModelType).Interface()); association.Error != nil { + if association.Error = rel.Field.Set(association.DB.Statement.Context, data, reflect.Zero(rel.FieldSchema.ModelType).Interface()); association.Error != nil { break } if rel.JoinTable == nil { for _, ref := range rel.References { if ref.OwnPrimaryKey || ref.PrimaryValue != "" { - association.Error = ref.ForeignKey.Set(fieldValue, reflect.Zero(ref.ForeignKey.FieldType).Interface()) + association.Error = ref.ForeignKey.Set(association.DB.Statement.Context, fieldValue, reflect.Zero(ref.ForeignKey.FieldType).Interface()) } else { - association.Error = ref.ForeignKey.Set(data, reflect.Zero(ref.ForeignKey.FieldType).Interface()) + association.Error = ref.ForeignKey.Set(association.DB.Statement.Context, data, reflect.Zero(ref.ForeignKey.FieldType).Interface()) } } } @@ -329,14 +329,14 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ switch rv.Kind() { case reflect.Slice, reflect.Array: if rv.Len() > 0 { - association.Error = association.Relationship.Field.Set(source, rv.Index(0).Addr().Interface()) + association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, source, rv.Index(0).Addr().Interface()) if association.Relationship.Field.FieldType.Kind() == reflect.Struct { assignBacks = append(assignBacks, assignBack{Source: source, Dest: rv.Index(0)}) } } case reflect.Struct: - association.Error = association.Relationship.Field.Set(source, rv.Addr().Interface()) + association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, source, rv.Addr().Interface()) if association.Relationship.Field.FieldType.Kind() == reflect.Struct { assignBacks = append(assignBacks, assignBack{Source: source, Dest: rv}) @@ -344,7 +344,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ } case schema.HasMany, schema.Many2Many: elemType := association.Relationship.Field.IndirectFieldType.Elem() - fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(source)) + fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(association.DB.Statement.Context, source)) if clear { fieldValue = reflect.New(association.Relationship.Field.IndirectFieldType).Elem() } @@ -373,7 +373,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ } if association.Error == nil { - association.Error = association.Relationship.Field.Set(source, fieldValue.Interface()) + association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, source, fieldValue.Interface()) } } } @@ -421,7 +421,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ // clear old data if clear && len(values) == 0 { for i := 0; i < reflectValue.Len(); i++ { - if err := association.Relationship.Field.Set(reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType).Interface()); err != nil { + if err := association.Relationship.Field.Set(association.DB.Statement.Context, reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType).Interface()); err != nil { association.Error = err break } @@ -429,7 +429,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ if association.Relationship.JoinTable == nil { for _, ref := range association.Relationship.References { if !ref.OwnPrimaryKey && ref.PrimaryValue == "" { - if err := ref.ForeignKey.Set(reflectValue.Index(i), reflect.Zero(ref.ForeignKey.FieldType).Interface()); err != nil { + if err := ref.ForeignKey.Set(association.DB.Statement.Context, reflectValue.Index(i), reflect.Zero(ref.ForeignKey.FieldType).Interface()); err != nil { association.Error = err break } @@ -453,12 +453,12 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ case reflect.Struct: // clear old data if clear && len(values) == 0 { - association.Error = association.Relationship.Field.Set(reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType).Interface()) + association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType).Interface()) if association.Relationship.JoinTable == nil && association.Error == nil { for _, ref := range association.Relationship.References { if !ref.OwnPrimaryKey && ref.PrimaryValue == "" { - association.Error = ref.ForeignKey.Set(reflectValue, reflect.Zero(ref.ForeignKey.FieldType).Interface()) + association.Error = ref.ForeignKey.Set(association.DB.Statement.Context, reflectValue, reflect.Zero(ref.ForeignKey.FieldType).Interface()) } } } @@ -475,7 +475,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ } for _, assignBack := range assignBacks { - fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(assignBack.Source)) + fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(association.DB.Statement.Context, assignBack.Source)) if assignBack.Index > 0 { reflect.Indirect(assignBack.Dest).Set(fieldValue.Index(assignBack.Index - 1)) } else { @@ -486,7 +486,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ func (association *Association) buildCondition() *DB { var ( - queryConds = association.Relationship.ToQueryConditions(association.DB.Statement.ReflectValue) + queryConds = association.Relationship.ToQueryConditions(association.DB.Statement.Context, association.DB.Statement.ReflectValue) modelValue = reflect.New(association.Relationship.FieldSchema.ModelType).Interface() tx = association.DB.Model(modelValue) ) diff --git a/callbacks/associations.go b/callbacks/associations.go index 75bd6c6a..d6fd21de 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -24,8 +24,8 @@ func SaveBeforeAssociations(create bool) func(db *gorm.DB) { setupReferences := func(obj reflect.Value, elem reflect.Value) { for _, ref := range rel.References { if !ref.OwnPrimaryKey { - pv, _ := ref.PrimaryKey.ValueOf(elem) - db.AddError(ref.ForeignKey.Set(obj, pv)) + pv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, elem) + db.AddError(ref.ForeignKey.Set(db.Statement.Context, obj, pv)) if dest, ok := db.Statement.Dest.(map[string]interface{}); ok { dest[ref.ForeignKey.DBName] = pv @@ -57,8 +57,8 @@ func SaveBeforeAssociations(create bool) func(db *gorm.DB) { break } - if _, zero := rel.Field.ValueOf(obj); !zero { // check belongs to relation value - rv := rel.Field.ReflectValueOf(obj) // relation reflect value + if _, zero := rel.Field.ValueOf(db.Statement.Context, obj); !zero { // check belongs to relation value + rv := rel.Field.ReflectValueOf(db.Statement.Context, obj) // relation reflect value objs = append(objs, obj) if isPtr { elems = reflect.Append(elems, rv) @@ -76,8 +76,8 @@ func SaveBeforeAssociations(create bool) func(db *gorm.DB) { } } case reflect.Struct: - if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero { - rv := rel.Field.ReflectValueOf(db.Statement.ReflectValue) // relation reflect value + if _, zero := rel.Field.ValueOf(db.Statement.Context, db.Statement.ReflectValue); !zero { + rv := rel.Field.ReflectValueOf(db.Statement.Context, db.Statement.ReflectValue) // relation reflect value if rv.Kind() != reflect.Ptr { rv = rv.Addr() } @@ -120,18 +120,18 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { obj := db.Statement.ReflectValue.Index(i) if reflect.Indirect(obj).Kind() == reflect.Struct { - if _, zero := rel.Field.ValueOf(obj); !zero { - rv := rel.Field.ReflectValueOf(obj) + if _, zero := rel.Field.ValueOf(db.Statement.Context, obj); !zero { + rv := rel.Field.ReflectValueOf(db.Statement.Context, obj) if rv.Kind() != reflect.Ptr { rv = rv.Addr() } for _, ref := range rel.References { if ref.OwnPrimaryKey { - fv, _ := ref.PrimaryKey.ValueOf(obj) - db.AddError(ref.ForeignKey.Set(rv, fv)) + fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, obj) + db.AddError(ref.ForeignKey.Set(db.Statement.Context, rv, fv)) } else if ref.PrimaryValue != "" { - db.AddError(ref.ForeignKey.Set(rv, ref.PrimaryValue)) + db.AddError(ref.ForeignKey.Set(db.Statement.Context, rv, ref.PrimaryValue)) } } @@ -149,8 +149,8 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, assignmentColumns) } case reflect.Struct: - if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero { - f := rel.Field.ReflectValueOf(db.Statement.ReflectValue) + if _, zero := rel.Field.ValueOf(db.Statement.Context, db.Statement.ReflectValue); !zero { + f := rel.Field.ReflectValueOf(db.Statement.Context, db.Statement.ReflectValue) if f.Kind() != reflect.Ptr { f = f.Addr() } @@ -158,10 +158,10 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { assignmentColumns := make([]string, 0, len(rel.References)) for _, ref := range rel.References { if ref.OwnPrimaryKey { - fv, _ := ref.PrimaryKey.ValueOf(db.Statement.ReflectValue) - ref.ForeignKey.Set(f, fv) + fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, db.Statement.ReflectValue) + ref.ForeignKey.Set(db.Statement.Context, f, fv) } else if ref.PrimaryValue != "" { - ref.ForeignKey.Set(f, ref.PrimaryValue) + ref.ForeignKey.Set(db.Statement.Context, f, ref.PrimaryValue) } assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) } @@ -185,23 +185,23 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) identityMap := map[string]bool{} appendToElems := func(v reflect.Value) { - if _, zero := rel.Field.ValueOf(v); !zero { - f := reflect.Indirect(rel.Field.ReflectValueOf(v)) + if _, zero := rel.Field.ValueOf(db.Statement.Context, v); !zero { + f := reflect.Indirect(rel.Field.ReflectValueOf(db.Statement.Context, v)) for i := 0; i < f.Len(); i++ { elem := f.Index(i) for _, ref := range rel.References { if ref.OwnPrimaryKey { - pv, _ := ref.PrimaryKey.ValueOf(v) - ref.ForeignKey.Set(elem, pv) + pv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, v) + ref.ForeignKey.Set(db.Statement.Context, elem, pv) } else if ref.PrimaryValue != "" { - ref.ForeignKey.Set(elem, ref.PrimaryValue) + ref.ForeignKey.Set(db.Statement.Context, elem, ref.PrimaryValue) } } relPrimaryValues := make([]interface{}, 0, len(rel.FieldSchema.PrimaryFields)) for _, pf := range rel.FieldSchema.PrimaryFields { - if pfv, ok := pf.ValueOf(elem); !ok { + if pfv, ok := pf.ValueOf(db.Statement.Context, elem); !ok { relPrimaryValues = append(relPrimaryValues, pfv) } } @@ -260,21 +260,21 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { joinValue := reflect.New(rel.JoinTable.ModelType) for _, ref := range rel.References { if ref.OwnPrimaryKey { - fv, _ := ref.PrimaryKey.ValueOf(obj) - ref.ForeignKey.Set(joinValue, fv) + fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, obj) + ref.ForeignKey.Set(db.Statement.Context, joinValue, fv) } else if ref.PrimaryValue != "" { - ref.ForeignKey.Set(joinValue, ref.PrimaryValue) + ref.ForeignKey.Set(db.Statement.Context, joinValue, ref.PrimaryValue) } else { - fv, _ := ref.PrimaryKey.ValueOf(elem) - ref.ForeignKey.Set(joinValue, fv) + fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, elem) + ref.ForeignKey.Set(db.Statement.Context, joinValue, fv) } } joins = reflect.Append(joins, joinValue) } appendToElems := func(v reflect.Value) { - if _, zero := rel.Field.ValueOf(v); !zero { - f := reflect.Indirect(rel.Field.ReflectValueOf(v)) + if _, zero := rel.Field.ValueOf(db.Statement.Context, v); !zero { + f := reflect.Indirect(rel.Field.ReflectValueOf(db.Statement.Context, v)) for i := 0; i < f.Len(); i++ { elem := f.Index(i) diff --git a/callbacks/create.go b/callbacks/create.go index 29113128..b0964e2b 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -117,9 +117,9 @@ func Create(config *Config) func(db *gorm.DB) { break } - _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv) + _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv) if isZero { - db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID) + db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, insertID) insertID -= db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement } } @@ -130,16 +130,16 @@ func Create(config *Config) func(db *gorm.DB) { break } - if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv); isZero { - db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID) + if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv); isZero { + db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, insertID) insertID += db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement } } } case reflect.Struct: - _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue) + _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, db.Statement.ReflectValue) if isZero { - db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID) + db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, db.Statement.ReflectValue, insertID) } } } @@ -219,23 +219,23 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { values.Values[i] = make([]interface{}, len(values.Columns)) for idx, column := range values.Columns { field := stmt.Schema.FieldsByDBName[column.Name] - if values.Values[i][idx], isZero = field.ValueOf(rv); isZero { + if values.Values[i][idx], isZero = field.ValueOf(stmt.Context, rv); isZero { if field.DefaultValueInterface != nil { values.Values[i][idx] = field.DefaultValueInterface - field.Set(rv, field.DefaultValueInterface) + field.Set(stmt.Context, rv, field.DefaultValueInterface) } else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 { - field.Set(rv, curTime) - values.Values[i][idx], _ = field.ValueOf(rv) + field.Set(stmt.Context, rv, curTime) + values.Values[i][idx], _ = field.ValueOf(stmt.Context, rv) } } else if field.AutoUpdateTime > 0 && updateTrackTime { - field.Set(rv, curTime) - values.Values[i][idx], _ = field.ValueOf(rv) + field.Set(stmt.Context, rv, curTime) + values.Values[i][idx], _ = field.ValueOf(stmt.Context, rv) } } for _, field := range stmt.Schema.FieldsWithDefaultDBValue { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { - if rvOfvalue, isZero := field.ValueOf(rv); !isZero { + if rvOfvalue, isZero := field.ValueOf(stmt.Context, rv); !isZero { if len(defaultValueFieldsHavingValue[field]) == 0 { defaultValueFieldsHavingValue[field] = make([]interface{}, rValLen) } @@ -259,23 +259,23 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { values.Values = [][]interface{}{make([]interface{}, len(values.Columns))} for idx, column := range values.Columns { field := stmt.Schema.FieldsByDBName[column.Name] - if values.Values[0][idx], isZero = field.ValueOf(stmt.ReflectValue); isZero { + if values.Values[0][idx], isZero = field.ValueOf(stmt.Context, stmt.ReflectValue); isZero { if field.DefaultValueInterface != nil { values.Values[0][idx] = field.DefaultValueInterface - field.Set(stmt.ReflectValue, field.DefaultValueInterface) + field.Set(stmt.Context, stmt.ReflectValue, field.DefaultValueInterface) } else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 { - field.Set(stmt.ReflectValue, curTime) - values.Values[0][idx], _ = field.ValueOf(stmt.ReflectValue) + field.Set(stmt.Context, stmt.ReflectValue, curTime) + values.Values[0][idx], _ = field.ValueOf(stmt.Context, stmt.ReflectValue) } } else if field.AutoUpdateTime > 0 && updateTrackTime { - field.Set(stmt.ReflectValue, curTime) - values.Values[0][idx], _ = field.ValueOf(stmt.ReflectValue) + field.Set(stmt.Context, stmt.ReflectValue, curTime) + values.Values[0][idx], _ = field.ValueOf(stmt.Context, stmt.ReflectValue) } } for _, field := range stmt.Schema.FieldsWithDefaultDBValue { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { - if rvOfvalue, isZero := field.ValueOf(stmt.ReflectValue); !isZero { + if rvOfvalue, isZero := field.ValueOf(stmt.Context, stmt.ReflectValue); !isZero { values.Columns = append(values.Columns, clause.Column{Name: field.DBName}) values.Values[0] = append(values.Values[0], rvOfvalue) } diff --git a/callbacks/delete.go b/callbacks/delete.go index 7f1e09ce..1fb5261c 100644 --- a/callbacks/delete.go +++ b/callbacks/delete.go @@ -42,7 +42,7 @@ func DeleteBeforeAssociations(db *gorm.DB) { switch rel.Type { case schema.HasOne, schema.HasMany: - queryConds := rel.ToQueryConditions(db.Statement.ReflectValue) + queryConds := rel.ToQueryConditions(db.Statement.Context, db.Statement.ReflectValue) modelValue := reflect.New(rel.FieldSchema.ModelType).Interface() tx := db.Session(&gorm.Session{NewDB: true}).Model(modelValue) withoutConditions := false @@ -97,7 +97,7 @@ func DeleteBeforeAssociations(db *gorm.DB) { } } - _, foreignValues := schema.GetIdentityFieldValuesMap(db.Statement.ReflectValue, foreignFields) + _, foreignValues := schema.GetIdentityFieldValuesMap(db.Statement.Context, db.Statement.ReflectValue, foreignFields) column, values := schema.ToQueryValues(table, relForeignKeys, foreignValues) queryConds = append(queryConds, clause.IN{Column: column, Values: values}) @@ -123,7 +123,7 @@ func Delete(config *Config) func(db *gorm.DB) { db.Statement.AddClauseIfNotExists(clause.Delete{}) if db.Statement.Schema != nil { - _, queryValues := schema.GetIdentityFieldValuesMap(db.Statement.ReflectValue, db.Statement.Schema.PrimaryFields) + _, queryValues := schema.GetIdentityFieldValuesMap(db.Statement.Context, db.Statement.ReflectValue, db.Statement.Schema.PrimaryFields) column, values := schema.ToQueryValues(db.Statement.Table, db.Statement.Schema.PrimaryFieldDBNames, queryValues) if len(values) > 0 { @@ -131,7 +131,7 @@ func Delete(config *Config) func(db *gorm.DB) { } if db.Statement.ReflectValue.CanAddr() && db.Statement.Dest != db.Statement.Model && db.Statement.Model != nil { - _, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(db.Statement.Model), db.Statement.Schema.PrimaryFields) + _, queryValues = schema.GetIdentityFieldValuesMap(db.Statement.Context, reflect.ValueOf(db.Statement.Model), db.Statement.Schema.PrimaryFields) column, values = schema.ToQueryValues(db.Statement.Table, db.Statement.Schema.PrimaryFieldDBNames, queryValues) if len(values) > 0 { diff --git a/callbacks/preload.go b/callbacks/preload.go index 41405a22..2363a8ca 100644 --- a/callbacks/preload.go +++ b/callbacks/preload.go @@ -48,7 +48,7 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload } } - joinIdentityMap, joinForeignValues := schema.GetIdentityFieldValuesMap(reflectValue, foreignFields) + joinIdentityMap, joinForeignValues := schema.GetIdentityFieldValuesMap(db.Statement.Context, reflectValue, foreignFields) if len(joinForeignValues) == 0 { return } @@ -63,11 +63,11 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload for i := 0; i < joinResults.Len(); i++ { joinIndexValue := joinResults.Index(i) for idx, field := range joinForeignFields { - fieldValues[idx], _ = field.ValueOf(joinIndexValue) + fieldValues[idx], _ = field.ValueOf(db.Statement.Context, joinIndexValue) } for idx, field := range joinRelForeignFields { - joinFieldValues[idx], _ = field.ValueOf(joinIndexValue) + joinFieldValues[idx], _ = field.ValueOf(db.Statement.Context, joinIndexValue) } if results, ok := joinIdentityMap[utils.ToStringKey(fieldValues...)]; ok { @@ -76,7 +76,7 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload } } - _, foreignValues = schema.GetIdentityFieldValuesMap(joinResults, joinRelForeignFields) + _, foreignValues = schema.GetIdentityFieldValuesMap(db.Statement.Context, joinResults, joinRelForeignFields) } else { for _, ref := range rel.References { if ref.OwnPrimaryKey { @@ -92,7 +92,7 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload } } - identityMap, foreignValues = schema.GetIdentityFieldValuesMap(reflectValue, foreignFields) + identityMap, foreignValues = schema.GetIdentityFieldValuesMap(db.Statement.Context, reflectValue, foreignFields) if len(foreignValues) == 0 { return } @@ -125,17 +125,17 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload case reflect.Struct: switch rel.Type { case schema.HasMany, schema.Many2Many: - rel.Field.Set(reflectValue, reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface()) + rel.Field.Set(db.Statement.Context, reflectValue, reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface()) default: - rel.Field.Set(reflectValue, reflect.New(rel.Field.FieldType).Interface()) + rel.Field.Set(db.Statement.Context, reflectValue, reflect.New(rel.Field.FieldType).Interface()) } case reflect.Slice, reflect.Array: for i := 0; i < reflectValue.Len(); i++ { switch rel.Type { case schema.HasMany, schema.Many2Many: - rel.Field.Set(reflectValue.Index(i), reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface()) + rel.Field.Set(db.Statement.Context, reflectValue.Index(i), reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface()) default: - rel.Field.Set(reflectValue.Index(i), reflect.New(rel.Field.FieldType).Interface()) + rel.Field.Set(db.Statement.Context, reflectValue.Index(i), reflect.New(rel.Field.FieldType).Interface()) } } } @@ -143,7 +143,7 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload for i := 0; i < reflectResults.Len(); i++ { elem := reflectResults.Index(i) for idx, field := range relForeignFields { - fieldValues[idx], _ = field.ValueOf(elem) + fieldValues[idx], _ = field.ValueOf(db.Statement.Context, elem) } datas, ok := identityMap[utils.ToStringKey(fieldValues...)] @@ -154,7 +154,7 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload } for _, data := range datas { - reflectFieldValue := rel.Field.ReflectValueOf(data) + reflectFieldValue := rel.Field.ReflectValueOf(db.Statement.Context, data) if reflectFieldValue.Kind() == reflect.Ptr && reflectFieldValue.IsNil() { reflectFieldValue.Set(reflect.New(rel.Field.FieldType.Elem())) } @@ -162,12 +162,12 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload reflectFieldValue = reflect.Indirect(reflectFieldValue) switch reflectFieldValue.Kind() { case reflect.Struct: - rel.Field.Set(data, elem.Interface()) + rel.Field.Set(db.Statement.Context, data, elem.Interface()) case reflect.Slice, reflect.Array: if reflectFieldValue.Type().Elem().Kind() == reflect.Ptr { - rel.Field.Set(data, reflect.Append(reflectFieldValue, elem).Interface()) + rel.Field.Set(db.Statement.Context, data, reflect.Append(reflectFieldValue, elem).Interface()) } else { - rel.Field.Set(data, reflect.Append(reflectFieldValue, elem.Elem()).Interface()) + rel.Field.Set(db.Statement.Context, data, reflect.Append(reflectFieldValue, elem.Elem()).Interface()) } } } diff --git a/callbacks/query.go b/callbacks/query.go index 49086354..03798859 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -40,7 +40,7 @@ func BuildQuerySQL(db *gorm.DB) { if db.Statement.ReflectValue.Kind() == reflect.Struct && db.Statement.ReflectValue.Type() == db.Statement.Schema.ModelType { var conds []clause.Expression for _, primaryField := range db.Statement.Schema.PrimaryFields { - if v, isZero := primaryField.ValueOf(db.Statement.ReflectValue); !isZero { + if v, isZero := primaryField.ValueOf(db.Statement.Context, db.Statement.ReflectValue); !isZero { conds = append(conds, clause.Eq{Column: clause.Column{Table: db.Statement.Table, Name: primaryField.DBName}, Value: v}) } } diff --git a/callbacks/update.go b/callbacks/update.go index 511e994e..4f07ca30 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -21,7 +21,7 @@ func SetupUpdateReflectValue(db *gorm.DB) { if dest, ok := db.Statement.Dest.(map[string]interface{}); ok { for _, rel := range db.Statement.Schema.Relationships.BelongsTo { if _, ok := dest[rel.Name]; ok { - rel.Field.Set(db.Statement.ReflectValue, dest[rel.Name]) + rel.Field.Set(db.Statement.Context, db.Statement.ReflectValue, dest[rel.Name]) } } } @@ -137,13 +137,13 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { case reflect.Slice, reflect.Array: assignValue = func(field *schema.Field, value interface{}) { for i := 0; i < stmt.ReflectValue.Len(); i++ { - field.Set(stmt.ReflectValue.Index(i), value) + field.Set(stmt.Context, stmt.ReflectValue.Index(i), value) } } case reflect.Struct: assignValue = func(field *schema.Field, value interface{}) { if stmt.ReflectValue.CanAddr() { - field.Set(stmt.ReflectValue, value) + field.Set(stmt.Context, stmt.ReflectValue, value) } } default: @@ -165,7 +165,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { exprs := make([]clause.Expression, len(stmt.Schema.PrimaryFields)) var notZero bool for idx, field := range stmt.Schema.PrimaryFields { - value, isZero := field.ValueOf(stmt.ReflectValue.Index(i)) + value, isZero := field.ValueOf(stmt.Context, stmt.ReflectValue.Index(i)) exprs[idx] = clause.Eq{Column: field.DBName, Value: value} notZero = notZero || !isZero } @@ -178,7 +178,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { } case reflect.Struct: for _, field := range stmt.Schema.PrimaryFields { - if value, isZero := field.ValueOf(stmt.ReflectValue); !isZero { + if value, isZero := field.ValueOf(stmt.Context, stmt.ReflectValue); !isZero { stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}}) } } @@ -258,7 +258,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { if field := updatingSchema.LookUpField(dbName); field != nil { if !field.PrimaryKey || !updatingValue.CanAddr() || stmt.Dest != stmt.Model { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && (!restricted || (!stmt.SkipHooks && field.AutoUpdateTime > 0))) { - value, isZero := field.ValueOf(updatingValue) + value, isZero := field.ValueOf(stmt.Context, updatingValue) if !stmt.SkipHooks && field.AutoUpdateTime > 0 { if field.AutoUpdateTime == schema.UnixNanosecond { value = stmt.DB.NowFunc().UnixNano() @@ -278,7 +278,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { } } } else { - if value, isZero := field.ValueOf(updatingValue); !isZero { + if value, isZero := field.ValueOf(stmt.Context, updatingValue); !isZero { stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}}) } } diff --git a/finisher_api.go b/finisher_api.go index 3a179977..d2a8b981 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -83,7 +83,7 @@ func (db *DB) Save(value interface{}) (tx *DB) { case reflect.Struct: if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil { for _, pf := range tx.Statement.Schema.PrimaryFields { - if _, isZero := pf.ValueOf(reflectValue); isZero { + if _, isZero := pf.ValueOf(tx.Statement.Context, reflectValue); isZero { return tx.callbacks.Create().Execute(tx) } } @@ -199,7 +199,7 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat break } - primaryValue, _ := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(resultsValue.Index(resultsValue.Len() - 1)) + primaryValue, _ := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1)) queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue}) } @@ -216,11 +216,11 @@ func (tx *DB) assignInterfacesToValue(values ...interface{}) { switch column := eq.Column.(type) { case string: if field := tx.Statement.Schema.LookUpField(column); field != nil { - tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value)) + tx.AddError(field.Set(tx.Statement.Context, tx.Statement.ReflectValue, eq.Value)) } case clause.Column: if field := tx.Statement.Schema.LookUpField(column.Name); field != nil { - tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value)) + tx.AddError(field.Set(tx.Statement.Context, tx.Statement.ReflectValue, eq.Value)) } } } else if andCond, ok := expr.(clause.AndConditions); ok { @@ -238,9 +238,9 @@ func (tx *DB) assignInterfacesToValue(values ...interface{}) { case reflect.Struct: for _, f := range s.Fields { if f.Readable { - if v, isZero := f.ValueOf(reflectValue); !isZero { + if v, isZero := f.ValueOf(tx.Statement.Context, reflectValue); !isZero { if field := tx.Statement.Schema.LookUpField(f.Name); field != nil { - tx.AddError(field.Set(tx.Statement.ReflectValue, v)) + tx.AddError(field.Set(tx.Statement.Context, tx.Statement.ReflectValue, v)) } } } diff --git a/interfaces.go b/interfaces.go index 44b2fced..ff0ca60a 100644 --- a/interfaces.go +++ b/interfaces.go @@ -40,14 +40,17 @@ type SavePointerDialectorInterface interface { RollbackTo(tx *DB, name string) error } +// TxBeginner tx beginner type TxBeginner interface { BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) } +// ConnPoolBeginner conn pool beginner type ConnPoolBeginner interface { BeginTx(ctx context.Context, opts *sql.TxOptions) (ConnPool, error) } +// TxCommitter tx commiter type TxCommitter interface { Commit() error Rollback() error @@ -58,6 +61,7 @@ type Valuer interface { GormValue(context.Context, *DB) clause.Expr } +// GetDBConnector SQL db connector type GetDBConnector interface { GetDBConn() (*sql.DB, error) } diff --git a/scan.go b/scan.go index b03b79b4..0da12daf 100644 --- a/scan.go +++ b/scan.go @@ -10,6 +10,7 @@ import ( "gorm.io/gorm/schema" ) +// prepareValues prepare values slice func prepareValues(values []interface{}, db *DB, columnTypes []*sql.ColumnType, columns []string) { if db.Statement.Schema != nil { for idx, name := range columns { @@ -54,11 +55,13 @@ func (db *DB) scanIntoStruct(sch *schema.Schema, rows *sql.Rows, reflectValue re if sch == nil { values[idx] = reflectValue.Interface() } else if field := sch.LookUpField(column); field != nil && field.Readable { - values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface() + values[idx] = field.NewValuePool.Get() + defer field.NewValuePool.Put(values[idx]) } else if names := strings.Split(column, "__"); len(names) > 1 { if rel, ok := sch.Relationships.Relations[names[0]]; ok { if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { - values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface() + values[idx] = field.NewValuePool.Get() + defer field.NewValuePool.Put(values[idx]) continue } } @@ -77,21 +80,21 @@ func (db *DB) scanIntoStruct(sch *schema.Schema, rows *sql.Rows, reflectValue re if sch != nil { for idx, column := range columns { if field := sch.LookUpField(column); field != nil && field.Readable { - field.Set(reflectValue, values[idx]) + field.Set(db.Statement.Context, reflectValue, values[idx]) } else if names := strings.Split(column, "__"); len(names) > 1 { if rel, ok := sch.Relationships.Relations[names[0]]; ok { if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { - relValue := rel.Field.ReflectValueOf(reflectValue) - value := reflect.ValueOf(values[idx]).Elem() + relValue := rel.Field.ReflectValueOf(db.Statement.Context, reflectValue) if relValue.Kind() == reflect.Ptr && relValue.IsNil() { - if value.IsNil() { + if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() { continue } + relValue.Set(reflect.New(relValue.Type().Elem())) } - field.Set(relValue, values[idx]) + field.Set(db.Statement.Context, relValue, values[idx]) } } } @@ -99,14 +102,17 @@ func (db *DB) scanIntoStruct(sch *schema.Schema, rows *sql.Rows, reflectValue re } } +// ScanMode scan data mode type ScanMode uint8 +// scan modes const ( ScanInitialized ScanMode = 1 << 0 // 1 ScanUpdate ScanMode = 1 << 1 // 2 ScanOnConflictDoNothing ScanMode = 1 << 2 // 4 ) +// Scan scan rows into db statement func Scan(rows *sql.Rows, db *DB, mode ScanMode) { var ( columns, _ = rows.Columns() @@ -138,7 +144,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) { } scanIntoMap(mapValue, values, columns) } - case *[]map[string]interface{}, []map[string]interface{}: + case *[]map[string]interface{}: columnTypes, _ := rows.ColumnTypes() for initialized || rows.Next() { prepareValues(values, db, columnTypes, columns) @@ -149,11 +155,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) { mapValue := map[string]interface{}{} scanIntoMap(mapValue, values, columns) - if values, ok := dest.([]map[string]interface{}); ok { - values = append(values, mapValue) - } else if values, ok := dest.(*[]map[string]interface{}); ok { - *values = append(*values, mapValue) - } + *dest = append(*dest, mapValue) } case *int, *int8, *int16, *int32, *int64, *uint, *uint8, *uint16, *uint32, *uint64, *uintptr, @@ -174,7 +176,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) { reflectValue = db.Statement.ReflectValue ) - if reflectValue.Kind() == reflect.Interface { + for reflectValue.Kind() == reflect.Interface { reflectValue = reflectValue.Elem() } @@ -244,7 +246,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) { elem = reflectValue.Index(int(db.RowsAffected)) if onConflictDonothing { for _, field := range fields { - if _, ok := field.ValueOf(elem); !ok { + if _, ok := field.ValueOf(db.Statement.Context, elem); !ok { db.RowsAffected++ goto BEGIN } diff --git a/schema/field.go b/schema/field.go index 485bbdf3..319f3693 100644 --- a/schema/field.go +++ b/schema/field.go @@ -1,6 +1,7 @@ package schema import ( + "context" "database/sql" "database/sql/driver" "fmt" @@ -14,12 +15,21 @@ import ( "gorm.io/gorm/utils" ) -type DataType string +// special types' reflect type +var ( + TimeReflectType = reflect.TypeOf(time.Time{}) + TimePtrReflectType = reflect.TypeOf(&time.Time{}) + ByteReflectType = reflect.TypeOf(uint8(0)) +) -type TimeType int64 - -var TimeReflectType = reflect.TypeOf(time.Time{}) +type ( + // DataType GORM data type + DataType string + // TimeType GORM time type + TimeType int64 +) +// GORM time types const ( UnixTime TimeType = 1 UnixSecond TimeType = 2 @@ -27,6 +37,7 @@ const ( UnixNanosecond TimeType = 4 ) +// GORM fields types const ( Bool DataType = "bool" Int DataType = "int" @@ -37,6 +48,7 @@ const ( Bytes DataType = "bytes" ) +// Field is the representation of model schema's field type Field struct { Name string DBName string @@ -49,9 +61,9 @@ type Field struct { Creatable bool Updatable bool Readable bool - HasDefaultValue bool AutoCreateTime TimeType AutoUpdateTime TimeType + HasDefaultValue bool DefaultValue string DefaultValueInterface interface{} NotNull bool @@ -60,6 +72,7 @@ type Field struct { Size int Precision int Scale int + IgnoreMigration bool FieldType reflect.Type IndirectFieldType reflect.Type StructField reflect.StructField @@ -68,27 +81,39 @@ type Field struct { Schema *Schema EmbeddedSchema *Schema OwnerSchema *Schema - ReflectValueOf func(reflect.Value) reflect.Value - ValueOf func(reflect.Value) (value interface{}, zero bool) - Set func(reflect.Value, interface{}) error - IgnoreMigration bool + ReflectValueOf func(context.Context, reflect.Value) reflect.Value + ValueOf func(context.Context, reflect.Value) (value interface{}, zero bool) + Set func(context.Context, reflect.Value, interface{}) error + Serializer SerializerInterface + NewValuePool FieldNewValuePool } +// ParseField parses reflect.StructField to Field func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { - var err error + var ( + err error + tagSetting = ParseTagSetting(fieldStruct.Tag.Get("gorm"), ";") + ) field := &Field{ Name: fieldStruct.Name, + DBName: tagSetting["COLUMN"], BindNames: []string{fieldStruct.Name}, FieldType: fieldStruct.Type, IndirectFieldType: fieldStruct.Type, StructField: fieldStruct, + Tag: fieldStruct.Tag, + TagSettings: tagSetting, + Schema: schema, Creatable: true, Updatable: true, Readable: true, - Tag: fieldStruct.Tag, - TagSettings: ParseTagSetting(fieldStruct.Tag.Get("gorm"), ";"), - Schema: schema, + PrimaryKey: utils.CheckTruth(tagSetting["PRIMARYKEY"], tagSetting["PRIMARY_KEY"]), + AutoIncrement: utils.CheckTruth(tagSetting["AUTOINCREMENT"]), + HasDefaultValue: utils.CheckTruth(tagSetting["AUTOINCREMENT"]), + NotNull: utils.CheckTruth(tagSetting["NOT NULL"], tagSetting["NOTNULL"]), + Unique: utils.CheckTruth(tagSetting["UNIQUE"]), + Comment: tagSetting["COMMENT"], AutoIncrementIncrement: 1, } @@ -97,7 +122,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } fieldValue := reflect.New(field.IndirectFieldType) - // if field is valuer, used its value or first fields as data type + // if field is valuer, used its value or first field as data type valuer, isValuer := fieldValue.Interface().(driver.Valuer) if isValuer { if _, ok := fieldValue.Interface().(GormDataTypeInterface); !ok { @@ -105,31 +130,37 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { fieldValue = reflect.ValueOf(v) } + // Use the field struct's first field type as data type, e.g: use `string` for sql.NullString var getRealFieldValue func(reflect.Value) getRealFieldValue = func(v reflect.Value) { - rv := reflect.Indirect(v) - if rv.Kind() == reflect.Struct && !rv.Type().ConvertibleTo(TimeReflectType) { - for i := 0; i < rv.Type().NumField(); i++ { - newFieldType := rv.Type().Field(i).Type + var ( + rv = reflect.Indirect(v) + rvType = rv.Type() + ) + + if rv.Kind() == reflect.Struct && !rvType.ConvertibleTo(TimeReflectType) { + for i := 0; i < rvType.NumField(); i++ { + for key, value := range ParseTagSetting(rvType.Field(i).Tag.Get("gorm"), ";") { + if _, ok := field.TagSettings[key]; !ok { + field.TagSettings[key] = value + } + } + } + + for i := 0; i < rvType.NumField(); i++ { + newFieldType := rvType.Field(i).Type for newFieldType.Kind() == reflect.Ptr { newFieldType = newFieldType.Elem() } fieldValue = reflect.New(newFieldType) - - if rv.Type() != reflect.Indirect(fieldValue).Type() { + if rvType != reflect.Indirect(fieldValue).Type() { getRealFieldValue(fieldValue) } if fieldValue.IsValid() { return } - - for key, value := range ParseTagSetting(field.IndirectFieldType.Field(i).Tag.Get("gorm"), ";") { - if _, ok := field.TagSettings[key]; !ok { - field.TagSettings[key] = value - } - } } } } @@ -138,19 +169,23 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } } - if dbName, ok := field.TagSettings["COLUMN"]; ok { - field.DBName = dbName - } - - if val, ok := field.TagSettings["PRIMARYKEY"]; ok && utils.CheckTruth(val) { - field.PrimaryKey = true - } else if val, ok := field.TagSettings["PRIMARY_KEY"]; ok && utils.CheckTruth(val) { - field.PrimaryKey = true - } - - if val, ok := field.TagSettings["AUTOINCREMENT"]; ok && utils.CheckTruth(val) { - field.AutoIncrement = true - field.HasDefaultValue = true + if v, isSerializer := fieldValue.Interface().(SerializerInterface); isSerializer { + field.DataType = String + field.Serializer = v + } else { + var serializerName = field.TagSettings["JSON"] + if serializerName == "" { + serializerName = field.TagSettings["SERIALIZER"] + } + if serializerName != "" { + if serializer, ok := GetSerializer(serializerName); ok { + // Set default data type to string for serializer + field.DataType = String + field.Serializer = serializer + } else { + schema.err = fmt.Errorf("invalid serializer type %v", serializerName) + } + } } if num, ok := field.TagSettings["AUTOINCREMENTINCREMENT"]; ok { @@ -176,20 +211,6 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { field.Scale, _ = strconv.Atoi(s) } - if val, ok := field.TagSettings["NOT NULL"]; ok && utils.CheckTruth(val) { - field.NotNull = true - } else if val, ok := field.TagSettings["NOTNULL"]; ok && utils.CheckTruth(val) { - field.NotNull = true - } - - if val, ok := field.TagSettings["UNIQUE"]; ok && utils.CheckTruth(val) { - field.Unique = true - } - - if val, ok := field.TagSettings["COMMENT"]; ok { - field.Comment = val - } - // default value is function or null or blank (primary keys) field.DefaultValue = strings.TrimSpace(field.DefaultValue) skipParseDefaultValue := strings.Contains(field.DefaultValue, "(") && @@ -225,7 +246,6 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } case reflect.String: field.DataType = String - if field.HasDefaultValue && !skipParseDefaultValue { field.DefaultValue = strings.Trim(field.DefaultValue, "'") field.DefaultValue = strings.Trim(field.DefaultValue, `"`) @@ -236,17 +256,15 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { field.DataType = Time } else if fieldValue.Type().ConvertibleTo(TimeReflectType) { field.DataType = Time - } else if fieldValue.Type().ConvertibleTo(reflect.TypeOf(&time.Time{})) { + } else if fieldValue.Type().ConvertibleTo(TimePtrReflectType) { field.DataType = Time } case reflect.Array, reflect.Slice: - if reflect.Indirect(fieldValue).Type().Elem() == reflect.TypeOf(uint8(0)) { + if reflect.Indirect(fieldValue).Type().Elem() == ByteReflectType && field.DataType == "" { field.DataType = Bytes } } - field.GORMDataType = field.DataType - if dataTyper, ok := fieldValue.Interface().(GormDataTypeInterface); ok { field.DataType = DataType(dataTyper.GormDataType()) } @@ -346,8 +364,9 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } } - if _, ok := field.TagSettings["EMBEDDED"]; field.GORMDataType != Time && field.GORMDataType != Bytes && - (ok || (fieldStruct.Anonymous && !isValuer && (field.Creatable || field.Updatable || field.Readable))) { + // Normal anonymous field or having `EMBEDDED` tag + if _, ok := field.TagSettings["EMBEDDED"]; ok || (field.GORMDataType != Time && field.GORMDataType != Bytes && !isValuer && + fieldStruct.Anonymous && (field.Creatable || field.Updatable || field.Readable)) { kind := reflect.Indirect(fieldValue).Kind() switch kind { case reflect.Struct: @@ -410,95 +429,122 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { // create valuer, setter when parse struct func (field *Field) setupValuerAndSetter() { - // ValueOf - switch { - case len(field.StructField.Index) == 1: - field.ValueOf = func(value reflect.Value) (interface{}, bool) { - fieldValue := reflect.Indirect(value).Field(field.StructField.Index[0]) - return fieldValue.Interface(), fieldValue.IsZero() - } - case len(field.StructField.Index) == 2 && field.StructField.Index[0] >= 0: - field.ValueOf = func(value reflect.Value) (interface{}, bool) { - fieldValue := reflect.Indirect(value).Field(field.StructField.Index[0]).Field(field.StructField.Index[1]) - return fieldValue.Interface(), fieldValue.IsZero() - } - default: - field.ValueOf = func(value reflect.Value) (interface{}, bool) { - v := reflect.Indirect(value) - - for _, idx := range field.StructField.Index { - if idx >= 0 { - v = v.Field(idx) - } else { - v = v.Field(-idx - 1) - - if v.Type().Elem().Kind() != reflect.Struct { - return nil, true - } - - if !v.IsNil() { - v = v.Elem() - } else { - return nil, true - } + // Setup NewValuePool + var fieldValue = reflect.New(field.FieldType).Interface() + if field.Serializer != nil { + field.NewValuePool = &sync.Pool{ + New: func() interface{} { + return &serializer{ + Field: field, + Serializer: reflect.New(reflect.Indirect(reflect.ValueOf(field.Serializer)).Type()).Interface().(SerializerInterface), } + }, + } + } else if _, ok := fieldValue.(sql.Scanner); !ok { + // set default NewValuePool + switch field.IndirectFieldType.Kind() { + case reflect.String: + field.NewValuePool = stringPool + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + field.NewValuePool = intPool + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + field.NewValuePool = uintPool + case reflect.Float32, reflect.Float64: + field.NewValuePool = floatPool + case reflect.Bool: + field.NewValuePool = boolPool + default: + if field.IndirectFieldType == TimeReflectType { + field.NewValuePool = timePool } - return v.Interface(), v.IsZero() } } - // ReflectValueOf - switch { - case len(field.StructField.Index) == 1: - field.ReflectValueOf = func(value reflect.Value) reflect.Value { - return reflect.Indirect(value).Field(field.StructField.Index[0]) - } - case len(field.StructField.Index) == 2 && field.StructField.Index[0] >= 0 && field.FieldType.Kind() != reflect.Ptr: - field.ReflectValueOf = func(value reflect.Value) reflect.Value { - return reflect.Indirect(value).Field(field.StructField.Index[0]).Field(field.StructField.Index[1]) - } - default: - field.ReflectValueOf = func(value reflect.Value) reflect.Value { - v := reflect.Indirect(value) - for idx, fieldIdx := range field.StructField.Index { - if fieldIdx >= 0 { - v = v.Field(fieldIdx) + if field.NewValuePool == nil { + field.NewValuePool = poolInitializer(reflect.PtrTo(field.IndirectFieldType)) + } + + // ValueOf returns field's value and if it is zero + field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) { + v = reflect.Indirect(v) + for _, fieldIdx := range field.StructField.Index { + if fieldIdx >= 0 { + v = v.Field(fieldIdx) + } else { + v = v.Field(-fieldIdx - 1) + + if !v.IsNil() { + v = v.Elem() } else { - v = v.Field(-fieldIdx - 1) - } - - if v.Kind() == reflect.Ptr { - if v.Type().Elem().Kind() == reflect.Struct { - if v.IsNil() { - v.Set(reflect.New(v.Type().Elem())) - } - } - - if idx < len(field.StructField.Index)-1 { - v = v.Elem() - } + return nil, true } } - return v + } + + fv, zero := v.Interface(), v.IsZero() + return fv, zero + } + + if field.Serializer != nil { + oldValuerOf := field.ValueOf + field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) { + value, zero := oldValuerOf(ctx, v) + if zero { + return value, zero + } + + s, ok := value.(SerializerValuerInterface) + if !ok { + s = field.Serializer + } + + return serializer{ + Field: field, + SerializeValuer: s, + Destination: v, + Context: ctx, + fieldValue: value, + }, false } } - fallbackSetter := func(value reflect.Value, v interface{}, setter func(reflect.Value, interface{}) error) (err error) { + // ReflectValueOf returns field's reflect value + field.ReflectValueOf = func(ctx context.Context, v reflect.Value) reflect.Value { + v = reflect.Indirect(v) + for idx, fieldIdx := range field.StructField.Index { + if fieldIdx >= 0 { + v = v.Field(fieldIdx) + } else { + v = v.Field(-fieldIdx - 1) + + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) + } + + if idx < len(field.StructField.Index)-1 { + v = v.Elem() + } + } + } + return v + } + + fallbackSetter := func(ctx context.Context, value reflect.Value, v interface{}, setter func(context.Context, reflect.Value, interface{}) error) (err error) { if v == nil { - field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) + field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem()) } else { reflectV := reflect.ValueOf(v) // Optimal value type acquisition for v reflectValType := reflectV.Type() if reflectValType.AssignableTo(field.FieldType) { - field.ReflectValueOf(value).Set(reflectV) + field.ReflectValueOf(ctx, value).Set(reflectV) return } else if reflectValType.ConvertibleTo(field.FieldType) { - field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType)) + field.ReflectValueOf(ctx, value).Set(reflectV.Convert(field.FieldType)) return } else if field.FieldType.Kind() == reflect.Ptr { - fieldValue := field.ReflectValueOf(value) + fieldValue := field.ReflectValueOf(ctx, value) fieldType := field.FieldType.Elem() if reflectValType.AssignableTo(fieldType) { @@ -521,13 +567,16 @@ func (field *Field) setupValuerAndSetter() { if reflectV.Kind() == reflect.Ptr { if reflectV.IsNil() { - field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) + field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem()) + } else if reflectV.Type().Elem().AssignableTo(field.FieldType) { + field.ReflectValueOf(ctx, value).Set(reflectV.Elem()) + return } else { - err = setter(value, reflectV.Elem().Interface()) + err = setter(ctx, value, reflectV.Elem().Interface()) } } else if valuer, ok := v.(driver.Valuer); ok { if v, err = valuer.Value(); err == nil { - err = setter(value, v) + err = setter(ctx, value, v) } } else { return fmt.Errorf("failed to set value %+v to field %s", v, field.Name) @@ -540,191 +589,201 @@ func (field *Field) setupValuerAndSetter() { // Set switch field.FieldType.Kind() { case reflect.Bool: - field.Set = func(value reflect.Value, v interface{}) error { + field.Set = func(ctx context.Context, value reflect.Value, v interface{}) error { switch data := v.(type) { + case **bool: + if data != nil && *data != nil { + field.ReflectValueOf(ctx, value).SetBool(**data) + } case bool: - field.ReflectValueOf(value).SetBool(data) - case *bool: - if data != nil { - field.ReflectValueOf(value).SetBool(*data) - } else { - field.ReflectValueOf(value).SetBool(false) - } + field.ReflectValueOf(ctx, value).SetBool(data) case int64: - if data > 0 { - field.ReflectValueOf(value).SetBool(true) - } else { - field.ReflectValueOf(value).SetBool(false) - } + field.ReflectValueOf(ctx, value).SetBool(data > 0) case string: b, _ := strconv.ParseBool(data) - field.ReflectValueOf(value).SetBool(b) + field.ReflectValueOf(ctx, value).SetBool(b) default: - return fallbackSetter(value, v, field.Set) + return fallbackSetter(ctx, value, v, field.Set) } return nil } case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - field.Set = func(value reflect.Value, v interface{}) (err error) { + field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) { switch data := v.(type) { + case **int64: + if data != nil && *data != nil { + field.ReflectValueOf(ctx, value).SetInt(**data) + } case int64: - field.ReflectValueOf(value).SetInt(data) + field.ReflectValueOf(ctx, value).SetInt(data) case int: - field.ReflectValueOf(value).SetInt(int64(data)) + field.ReflectValueOf(ctx, value).SetInt(int64(data)) case int8: - field.ReflectValueOf(value).SetInt(int64(data)) + field.ReflectValueOf(ctx, value).SetInt(int64(data)) case int16: - field.ReflectValueOf(value).SetInt(int64(data)) + field.ReflectValueOf(ctx, value).SetInt(int64(data)) case int32: - field.ReflectValueOf(value).SetInt(int64(data)) + field.ReflectValueOf(ctx, value).SetInt(int64(data)) case uint: - field.ReflectValueOf(value).SetInt(int64(data)) + field.ReflectValueOf(ctx, value).SetInt(int64(data)) case uint8: - field.ReflectValueOf(value).SetInt(int64(data)) + field.ReflectValueOf(ctx, value).SetInt(int64(data)) case uint16: - field.ReflectValueOf(value).SetInt(int64(data)) + field.ReflectValueOf(ctx, value).SetInt(int64(data)) case uint32: - field.ReflectValueOf(value).SetInt(int64(data)) + field.ReflectValueOf(ctx, value).SetInt(int64(data)) case uint64: - field.ReflectValueOf(value).SetInt(int64(data)) + field.ReflectValueOf(ctx, value).SetInt(int64(data)) case float32: - field.ReflectValueOf(value).SetInt(int64(data)) + field.ReflectValueOf(ctx, value).SetInt(int64(data)) case float64: - field.ReflectValueOf(value).SetInt(int64(data)) + field.ReflectValueOf(ctx, value).SetInt(int64(data)) case []byte: - return field.Set(value, string(data)) + return field.Set(ctx, value, string(data)) case string: if i, err := strconv.ParseInt(data, 0, 64); err == nil { - field.ReflectValueOf(value).SetInt(i) + field.ReflectValueOf(ctx, value).SetInt(i) } else { return err } case time.Time: if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond { - field.ReflectValueOf(value).SetInt(data.UnixNano()) + field.ReflectValueOf(ctx, value).SetInt(data.UnixNano()) } else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond { - field.ReflectValueOf(value).SetInt(data.UnixNano() / 1e6) + field.ReflectValueOf(ctx, value).SetInt(data.UnixNano() / 1e6) } else { - field.ReflectValueOf(value).SetInt(data.Unix()) + field.ReflectValueOf(ctx, value).SetInt(data.Unix()) } case *time.Time: if data != nil { if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond { - field.ReflectValueOf(value).SetInt(data.UnixNano()) + field.ReflectValueOf(ctx, value).SetInt(data.UnixNano()) } else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond { - field.ReflectValueOf(value).SetInt(data.UnixNano() / 1e6) + field.ReflectValueOf(ctx, value).SetInt(data.UnixNano() / 1e6) } else { - field.ReflectValueOf(value).SetInt(data.Unix()) + field.ReflectValueOf(ctx, value).SetInt(data.Unix()) } } else { - field.ReflectValueOf(value).SetInt(0) + field.ReflectValueOf(ctx, value).SetInt(0) } default: - return fallbackSetter(value, v, field.Set) + return fallbackSetter(ctx, value, v, field.Set) } return err } case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - field.Set = func(value reflect.Value, v interface{}) (err error) { + field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) { switch data := v.(type) { + case **uint64: + if data != nil && *data != nil { + field.ReflectValueOf(ctx, value).SetUint(**data) + } case uint64: - field.ReflectValueOf(value).SetUint(data) + field.ReflectValueOf(ctx, value).SetUint(data) case uint: - field.ReflectValueOf(value).SetUint(uint64(data)) + field.ReflectValueOf(ctx, value).SetUint(uint64(data)) case uint8: - field.ReflectValueOf(value).SetUint(uint64(data)) + field.ReflectValueOf(ctx, value).SetUint(uint64(data)) case uint16: - field.ReflectValueOf(value).SetUint(uint64(data)) + field.ReflectValueOf(ctx, value).SetUint(uint64(data)) case uint32: - field.ReflectValueOf(value).SetUint(uint64(data)) + field.ReflectValueOf(ctx, value).SetUint(uint64(data)) case int64: - field.ReflectValueOf(value).SetUint(uint64(data)) + field.ReflectValueOf(ctx, value).SetUint(uint64(data)) case int: - field.ReflectValueOf(value).SetUint(uint64(data)) + field.ReflectValueOf(ctx, value).SetUint(uint64(data)) case int8: - field.ReflectValueOf(value).SetUint(uint64(data)) + field.ReflectValueOf(ctx, value).SetUint(uint64(data)) case int16: - field.ReflectValueOf(value).SetUint(uint64(data)) + field.ReflectValueOf(ctx, value).SetUint(uint64(data)) case int32: - field.ReflectValueOf(value).SetUint(uint64(data)) + field.ReflectValueOf(ctx, value).SetUint(uint64(data)) case float32: - field.ReflectValueOf(value).SetUint(uint64(data)) + field.ReflectValueOf(ctx, value).SetUint(uint64(data)) case float64: - field.ReflectValueOf(value).SetUint(uint64(data)) + field.ReflectValueOf(ctx, value).SetUint(uint64(data)) case []byte: - return field.Set(value, string(data)) + return field.Set(ctx, value, string(data)) case time.Time: if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond { - field.ReflectValueOf(value).SetUint(uint64(data.UnixNano())) + field.ReflectValueOf(ctx, value).SetUint(uint64(data.UnixNano())) } else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond { - field.ReflectValueOf(value).SetUint(uint64(data.UnixNano() / 1e6)) + field.ReflectValueOf(ctx, value).SetUint(uint64(data.UnixNano() / 1e6)) } else { - field.ReflectValueOf(value).SetUint(uint64(data.Unix())) + field.ReflectValueOf(ctx, value).SetUint(uint64(data.Unix())) } case string: if i, err := strconv.ParseUint(data, 0, 64); err == nil { - field.ReflectValueOf(value).SetUint(i) + field.ReflectValueOf(ctx, value).SetUint(i) } else { return err } default: - return fallbackSetter(value, v, field.Set) + return fallbackSetter(ctx, value, v, field.Set) } return err } case reflect.Float32, reflect.Float64: - field.Set = func(value reflect.Value, v interface{}) (err error) { + field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) { switch data := v.(type) { + case **float64: + if data != nil && *data != nil { + field.ReflectValueOf(ctx, value).SetFloat(**data) + } case float64: - field.ReflectValueOf(value).SetFloat(data) + field.ReflectValueOf(ctx, value).SetFloat(data) case float32: - field.ReflectValueOf(value).SetFloat(float64(data)) + field.ReflectValueOf(ctx, value).SetFloat(float64(data)) case int64: - field.ReflectValueOf(value).SetFloat(float64(data)) + field.ReflectValueOf(ctx, value).SetFloat(float64(data)) case int: - field.ReflectValueOf(value).SetFloat(float64(data)) + field.ReflectValueOf(ctx, value).SetFloat(float64(data)) case int8: - field.ReflectValueOf(value).SetFloat(float64(data)) + field.ReflectValueOf(ctx, value).SetFloat(float64(data)) case int16: - field.ReflectValueOf(value).SetFloat(float64(data)) + field.ReflectValueOf(ctx, value).SetFloat(float64(data)) case int32: - field.ReflectValueOf(value).SetFloat(float64(data)) + field.ReflectValueOf(ctx, value).SetFloat(float64(data)) case uint: - field.ReflectValueOf(value).SetFloat(float64(data)) + field.ReflectValueOf(ctx, value).SetFloat(float64(data)) case uint8: - field.ReflectValueOf(value).SetFloat(float64(data)) + field.ReflectValueOf(ctx, value).SetFloat(float64(data)) case uint16: - field.ReflectValueOf(value).SetFloat(float64(data)) + field.ReflectValueOf(ctx, value).SetFloat(float64(data)) case uint32: - field.ReflectValueOf(value).SetFloat(float64(data)) + field.ReflectValueOf(ctx, value).SetFloat(float64(data)) case uint64: - field.ReflectValueOf(value).SetFloat(float64(data)) + field.ReflectValueOf(ctx, value).SetFloat(float64(data)) case []byte: - return field.Set(value, string(data)) + return field.Set(ctx, value, string(data)) case string: if i, err := strconv.ParseFloat(data, 64); err == nil { - field.ReflectValueOf(value).SetFloat(i) + field.ReflectValueOf(ctx, value).SetFloat(i) } else { return err } default: - return fallbackSetter(value, v, field.Set) + return fallbackSetter(ctx, value, v, field.Set) } return err } case reflect.String: - field.Set = func(value reflect.Value, v interface{}) (err error) { + field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) { switch data := v.(type) { + case **string: + if data != nil && *data != nil { + field.ReflectValueOf(ctx, value).SetString(**data) + } case string: - field.ReflectValueOf(value).SetString(data) + field.ReflectValueOf(ctx, value).SetString(data) case []byte: - field.ReflectValueOf(value).SetString(string(data)) + field.ReflectValueOf(ctx, value).SetString(string(data)) case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: - field.ReflectValueOf(value).SetString(utils.ToString(data)) + field.ReflectValueOf(ctx, value).SetString(utils.ToString(data)) case float64, float32: - field.ReflectValueOf(value).SetString(fmt.Sprintf("%."+strconv.Itoa(field.Precision)+"f", data)) + field.ReflectValueOf(ctx, value).SetString(fmt.Sprintf("%."+strconv.Itoa(field.Precision)+"f", data)) default: - return fallbackSetter(value, v, field.Set) + return fallbackSetter(ctx, value, v, field.Set) } return err } @@ -732,41 +791,49 @@ func (field *Field) setupValuerAndSetter() { fieldValue := reflect.New(field.FieldType) switch fieldValue.Elem().Interface().(type) { case time.Time: - field.Set = func(value reflect.Value, v interface{}) error { + field.Set = func(ctx context.Context, value reflect.Value, v interface{}) error { switch data := v.(type) { + case **time.Time: + if data != nil && *data != nil { + field.Set(ctx, value, *data) + } case time.Time: - field.ReflectValueOf(value).Set(reflect.ValueOf(v)) + field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(v)) case *time.Time: if data != nil { - field.ReflectValueOf(value).Set(reflect.ValueOf(data).Elem()) + field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(data).Elem()) } else { - field.ReflectValueOf(value).Set(reflect.ValueOf(time.Time{})) + field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(time.Time{})) } case string: if t, err := now.Parse(data); err == nil { - field.ReflectValueOf(value).Set(reflect.ValueOf(t)) + field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(t)) } else { return fmt.Errorf("failed to set string %v to time.Time field %s, failed to parse it as time, got error %v", v, field.Name, err) } default: - return fallbackSetter(value, v, field.Set) + return fallbackSetter(ctx, value, v, field.Set) } return nil } case *time.Time: - field.Set = func(value reflect.Value, v interface{}) error { + field.Set = func(ctx context.Context, value reflect.Value, v interface{}) error { switch data := v.(type) { + case **time.Time: + if data != nil { + field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(*data)) + } case time.Time: - fieldValue := field.ReflectValueOf(value) + fieldValue := field.ReflectValueOf(ctx, value) if fieldValue.IsNil() { fieldValue.Set(reflect.New(field.FieldType.Elem())) } fieldValue.Elem().Set(reflect.ValueOf(v)) case *time.Time: - field.ReflectValueOf(value).Set(reflect.ValueOf(v)) + field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(v)) case string: if t, err := now.Parse(data); err == nil { - fieldValue := field.ReflectValueOf(value) + fieldValue := field.ReflectValueOf(ctx, value) if fieldValue.IsNil() { if v == "" { return nil @@ -778,27 +845,27 @@ func (field *Field) setupValuerAndSetter() { return fmt.Errorf("failed to set string %v to time.Time field %s, failed to parse it as time, got error %v", v, field.Name, err) } default: - return fallbackSetter(value, v, field.Set) + return fallbackSetter(ctx, value, v, field.Set) } return nil } default: if _, ok := fieldValue.Elem().Interface().(sql.Scanner); ok { // pointer scanner - field.Set = func(value reflect.Value, v interface{}) (err error) { + field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) { reflectV := reflect.ValueOf(v) if !reflectV.IsValid() { - field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) + field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem()) } else if reflectV.Type().AssignableTo(field.FieldType) { - field.ReflectValueOf(value).Set(reflectV) + field.ReflectValueOf(ctx, value).Set(reflectV) } else if reflectV.Kind() == reflect.Ptr { if reflectV.IsNil() || !reflectV.IsValid() { - field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) + field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem()) } else { - return field.Set(value, reflectV.Elem().Interface()) + return field.Set(ctx, value, reflectV.Elem().Interface()) } } else { - fieldValue := field.ReflectValueOf(value) + fieldValue := field.ReflectValueOf(ctx, value) if fieldValue.IsNil() { fieldValue.Set(reflect.New(field.FieldType.Elem())) } @@ -813,32 +880,61 @@ func (field *Field) setupValuerAndSetter() { } } else if _, ok := fieldValue.Interface().(sql.Scanner); ok { // struct scanner - field.Set = func(value reflect.Value, v interface{}) (err error) { + field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) { reflectV := reflect.ValueOf(v) if !reflectV.IsValid() { - field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) + field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem()) } else if reflectV.Type().AssignableTo(field.FieldType) { - field.ReflectValueOf(value).Set(reflectV) + field.ReflectValueOf(ctx, value).Set(reflectV) } else if reflectV.Kind() == reflect.Ptr { if reflectV.IsNil() || !reflectV.IsValid() { - field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) + field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem()) } else { - return field.Set(value, reflectV.Elem().Interface()) + return field.Set(ctx, value, reflectV.Elem().Interface()) } } else { if valuer, ok := v.(driver.Valuer); ok { v, _ = valuer.Value() } - err = field.ReflectValueOf(value).Addr().Interface().(sql.Scanner).Scan(v) + err = field.ReflectValueOf(ctx, value).Addr().Interface().(sql.Scanner).Scan(v) } return } } else { - field.Set = func(value reflect.Value, v interface{}) (err error) { - return fallbackSetter(value, v, field.Set) + field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) { + return fallbackSetter(ctx, value, v, field.Set) } } } } + + if field.Serializer != nil { + var ( + oldFieldSetter = field.Set + sameElemType bool + sameType = field.FieldType == reflect.ValueOf(field.Serializer).Type() + ) + + if reflect.ValueOf(field.Serializer).Kind() == reflect.Ptr { + sameElemType = field.FieldType == reflect.ValueOf(field.Serializer).Type().Elem() + } + + field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) { + if s, ok := v.(*serializer); ok { + if err = s.Serializer.Scan(ctx, field, value, s.value); err == nil { + if sameElemType { + field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(s.Serializer).Elem()) + s.Serializer = reflect.New(reflect.Indirect(reflect.ValueOf(field.Serializer)).Type()).Interface().(SerializerInterface) + } else if sameType { + field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(s.Serializer)) + s.Serializer = reflect.New(reflect.Indirect(reflect.ValueOf(field.Serializer)).Type()).Interface().(SerializerInterface) + } + } + } else { + err = oldFieldSetter(ctx, value, v) + } + return + } + } } diff --git a/schema/field_test.go b/schema/field_test.go index 8fa46b87..300e375b 100644 --- a/schema/field_test.go +++ b/schema/field_test.go @@ -1,6 +1,7 @@ package schema_test import ( + "context" "database/sql" "reflect" "sync" @@ -57,7 +58,7 @@ func TestFieldValuerAndSetter(t *testing.T) { } for k, v := range newValues { - if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil { + if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil { t.Errorf("no error should happen when assign value to field %v, but got %v", k, err) } } @@ -80,7 +81,7 @@ func TestFieldValuerAndSetter(t *testing.T) { } for k, v := range newValues2 { - if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil { + if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil { t.Errorf("no error should happen when assign value to field %v, but got %v", k, err) } } @@ -132,7 +133,7 @@ func TestPointerFieldValuerAndSetter(t *testing.T) { } for k, v := range newValues { - if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil { + if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil { t.Errorf("no error should happen when assign value to field %v, but got %v", k, err) } } @@ -151,7 +152,7 @@ func TestPointerFieldValuerAndSetter(t *testing.T) { } for k, v := range newValues2 { - if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil { + if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil { t.Errorf("no error should happen when assign value to field %v, but got %v", k, err) } } @@ -202,7 +203,7 @@ func TestAdvancedDataTypeValuerAndSetter(t *testing.T) { } for k, v := range newValues { - if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil { + if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil { t.Errorf("no error should happen when assign value to field %v, but got %v", k, err) } } @@ -219,7 +220,7 @@ func TestAdvancedDataTypeValuerAndSetter(t *testing.T) { } for k, v := range newValues2 { - if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil { + if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil { t.Errorf("no error should happen when assign value to field %v, but got %v", k, err) } } diff --git a/schema/interfaces.go b/schema/interfaces.go index 98abffbd..a75a33c0 100644 --- a/schema/interfaces.go +++ b/schema/interfaces.go @@ -4,22 +4,33 @@ import ( "gorm.io/gorm/clause" ) +// GormDataTypeInterface gorm data type interface type GormDataTypeInterface interface { GormDataType() string } +// FieldNewValuePool field new scan value pool +type FieldNewValuePool interface { + Get() interface{} + Put(interface{}) +} + +// CreateClausesInterface create clauses interface type CreateClausesInterface interface { CreateClauses(*Field) []clause.Interface } +// QueryClausesInterface query clauses interface type QueryClausesInterface interface { QueryClauses(*Field) []clause.Interface } +// UpdateClausesInterface update clauses interface type UpdateClausesInterface interface { UpdateClauses(*Field) []clause.Interface } +// DeleteClausesInterface delete clauses interface type DeleteClausesInterface interface { DeleteClauses(*Field) []clause.Interface } diff --git a/schema/pool.go b/schema/pool.go new file mode 100644 index 00000000..f5c73153 --- /dev/null +++ b/schema/pool.go @@ -0,0 +1,62 @@ +package schema + +import ( + "reflect" + "sync" + "time" +) + +// sync pools +var ( + normalPool sync.Map + stringPool = &sync.Pool{ + New: func() interface{} { + var v string + ptrV := &v + return &ptrV + }, + } + intPool = &sync.Pool{ + New: func() interface{} { + var v int64 + ptrV := &v + return &ptrV + }, + } + uintPool = &sync.Pool{ + New: func() interface{} { + var v uint64 + ptrV := &v + return &ptrV + }, + } + floatPool = &sync.Pool{ + New: func() interface{} { + var v float64 + ptrV := &v + return &ptrV + }, + } + boolPool = &sync.Pool{ + New: func() interface{} { + var v bool + ptrV := &v + return &ptrV + }, + } + timePool = &sync.Pool{ + New: func() interface{} { + var v time.Time + ptrV := &v + return &ptrV + }, + } + poolInitializer = func(reflectType reflect.Type) FieldNewValuePool { + v, _ := normalPool.LoadOrStore(reflectType, &sync.Pool{ + New: func() interface{} { + return reflect.New(reflectType).Interface() + }, + }) + return v.(FieldNewValuePool) + } +) diff --git a/schema/relationship.go b/schema/relationship.go index c5d3dcad..eae8ab0b 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -1,6 +1,7 @@ package schema import ( + "context" "fmt" "reflect" "strings" @@ -576,7 +577,7 @@ func (rel *Relationship) ParseConstraint() *Constraint { return &constraint } -func (rel *Relationship) ToQueryConditions(reflectValue reflect.Value) (conds []clause.Expression) { +func (rel *Relationship) ToQueryConditions(ctx context.Context, reflectValue reflect.Value) (conds []clause.Expression) { table := rel.FieldSchema.Table foreignFields := []*Field{} relForeignKeys := []string{} @@ -616,7 +617,7 @@ func (rel *Relationship) ToQueryConditions(reflectValue reflect.Value) (conds [] } } - _, foreignValues := GetIdentityFieldValuesMap(reflectValue, foreignFields) + _, foreignValues := GetIdentityFieldValuesMap(ctx, reflectValue, foreignFields) column, values := ToQueryValues(table, relForeignKeys, foreignValues) conds = append(conds, clause.IN{Column: column, Values: values}) diff --git a/schema/schema_helper_test.go b/schema/schema_helper_test.go index 6d2bc664..9abaecba 100644 --- a/schema/schema_helper_test.go +++ b/schema/schema_helper_test.go @@ -1,6 +1,7 @@ package schema_test import ( + "context" "fmt" "reflect" "strings" @@ -203,7 +204,7 @@ func checkSchemaRelation(t *testing.T, s *schema.Schema, relation Relation) { func checkField(t *testing.T, s *schema.Schema, value reflect.Value, values map[string]interface{}) { for k, v := range values { t.Run("CheckField/"+k, func(t *testing.T) { - fv, _ := s.FieldsByDBName[k].ValueOf(value) + fv, _ := s.FieldsByDBName[k].ValueOf(context.Background(), value) tests.AssertEqual(t, v, fv) }) } diff --git a/schema/serializer.go b/schema/serializer.go new file mode 100644 index 00000000..68597538 --- /dev/null +++ b/schema/serializer.go @@ -0,0 +1,125 @@ +package schema + +import ( + "context" + "database/sql" + "database/sql/driver" + "encoding/json" + "errors" + "fmt" + "reflect" + "strings" + "sync" + "time" +) + +var serializerMap = sync.Map{} + +// RegisterSerializer register serializer +func RegisterSerializer(name string, serializer SerializerInterface) { + serializerMap.Store(strings.ToLower(name), serializer) +} + +// GetSerializer get serializer +func GetSerializer(name string) (serializer SerializerInterface, ok bool) { + v, ok := serializerMap.Load(strings.ToLower(name)) + if ok { + serializer, ok = v.(SerializerInterface) + } + return serializer, ok +} + +func init() { + RegisterSerializer("json", JSONSerializer{}) + RegisterSerializer("unixtime", UnixSecondSerializer{}) +} + +// Serializer field value serializer +type serializer struct { + Field *Field + Serializer SerializerInterface + SerializeValuer SerializerValuerInterface + Destination reflect.Value + Context context.Context + value interface{} + fieldValue interface{} +} + +// Scan implements sql.Scanner interface +func (s *serializer) Scan(value interface{}) error { + s.value = value + return nil +} + +// Value implements driver.Valuer interface +func (s serializer) Value() (driver.Value, error) { + return s.SerializeValuer.Value(s.Context, s.Field, s.Destination, s.fieldValue) +} + +// SerializerInterface serializer interface +type SerializerInterface interface { + Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) error + SerializerValuerInterface +} + +// SerializerValuerInterface serializer valuer interface +type SerializerValuerInterface interface { + Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) +} + +// JSONSerializer json serializer +type JSONSerializer struct { +} + +// Scan implements serializer interface +func (JSONSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) (err error) { + fieldValue := reflect.New(field.FieldType) + + if dbValue != nil { + var bytes []byte + switch v := dbValue.(type) { + case []byte: + bytes = v + case string: + bytes = []byte(v) + default: + return errors.New(fmt.Sprint("Failed to unmarshal JSONB value:", dbValue)) + } + + err = json.Unmarshal(bytes, fieldValue.Interface()) + } + + field.ReflectValueOf(ctx, dst).Set(fieldValue.Elem()) + return +} + +// Value implements serializer interface +func (JSONSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) { + result, err := json.Marshal(fieldValue) + return string(result), err +} + +// UnixSecondSerializer json serializer +type UnixSecondSerializer struct { +} + +// Scan implements serializer interface +func (UnixSecondSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) (err error) { + t := sql.NullTime{} + if err = t.Scan(dbValue); err == nil { + err = field.Set(ctx, dst, t.Time) + } + + return +} + +// Value implements serializer interface +func (UnixSecondSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (result interface{}, err error) { + switch v := fieldValue.(type) { + case int64, int, uint, uint64, int32, uint32, int16, uint16: + result = time.Unix(reflect.ValueOf(v).Int(), 0) + default: + err = fmt.Errorf("invalid field type %#v for UnixSecondSerializer, only int, uint supported", v) + } + return +} diff --git a/schema/utils.go b/schema/utils.go index e005cc74..2720c530 100644 --- a/schema/utils.go +++ b/schema/utils.go @@ -1,6 +1,7 @@ package schema import ( + "context" "reflect" "regexp" "strings" @@ -59,13 +60,13 @@ func removeSettingFromTag(tag reflect.StructTag, names ...string) reflect.Struct } // GetRelationsValues get relations's values from a reflect value -func GetRelationsValues(reflectValue reflect.Value, rels []*Relationship) (reflectResults reflect.Value) { +func GetRelationsValues(ctx context.Context, reflectValue reflect.Value, rels []*Relationship) (reflectResults reflect.Value) { for _, rel := range rels { reflectResults = reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(rel.FieldSchema.ModelType)), 0, 1) appendToResults := func(value reflect.Value) { - if _, isZero := rel.Field.ValueOf(value); !isZero { - result := reflect.Indirect(rel.Field.ReflectValueOf(value)) + if _, isZero := rel.Field.ValueOf(ctx, value); !isZero { + result := reflect.Indirect(rel.Field.ReflectValueOf(ctx, value)) switch result.Kind() { case reflect.Struct: reflectResults = reflect.Append(reflectResults, result.Addr()) @@ -97,7 +98,7 @@ func GetRelationsValues(reflectValue reflect.Value, rels []*Relationship) (refle } // GetIdentityFieldValuesMap get identity map from fields -func GetIdentityFieldValuesMap(reflectValue reflect.Value, fields []*Field) (map[string][]reflect.Value, [][]interface{}) { +func GetIdentityFieldValuesMap(ctx context.Context, reflectValue reflect.Value, fields []*Field) (map[string][]reflect.Value, [][]interface{}) { var ( results = [][]interface{}{} dataResults = map[string][]reflect.Value{} @@ -110,7 +111,7 @@ func GetIdentityFieldValuesMap(reflectValue reflect.Value, fields []*Field) (map results = [][]interface{}{make([]interface{}, len(fields))} for idx, field := range fields { - results[0][idx], zero = field.ValueOf(reflectValue) + results[0][idx], zero = field.ValueOf(ctx, reflectValue) notZero = notZero || !zero } @@ -135,7 +136,7 @@ func GetIdentityFieldValuesMap(reflectValue reflect.Value, fields []*Field) (map fieldValues := make([]interface{}, len(fields)) notZero = false for idx, field := range fields { - fieldValues[idx], zero = field.ValueOf(elem) + fieldValues[idx], zero = field.ValueOf(ctx, elem) notZero = notZero || !zero } @@ -155,12 +156,12 @@ func GetIdentityFieldValuesMap(reflectValue reflect.Value, fields []*Field) (map } // GetIdentityFieldValuesMapFromValues get identity map from fields -func GetIdentityFieldValuesMapFromValues(values []interface{}, fields []*Field) (map[string][]reflect.Value, [][]interface{}) { +func GetIdentityFieldValuesMapFromValues(ctx context.Context, values []interface{}, fields []*Field) (map[string][]reflect.Value, [][]interface{}) { resultsMap := map[string][]reflect.Value{} results := [][]interface{}{} for _, v := range values { - rm, rs := GetIdentityFieldValuesMap(reflect.Indirect(reflect.ValueOf(v)), fields) + rm, rs := GetIdentityFieldValuesMap(ctx, reflect.Indirect(reflect.ValueOf(v)), fields) for k, v := range rm { resultsMap[k] = append(resultsMap[k], v...) } diff --git a/soft_delete.go b/soft_delete.go index 4582161d..ba6d2118 100644 --- a/soft_delete.go +++ b/soft_delete.go @@ -135,7 +135,7 @@ func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) { stmt.SetColumn(sd.Field.DBName, curTime, true) if stmt.Schema != nil { - _, queryValues := schema.GetIdentityFieldValuesMap(stmt.ReflectValue, stmt.Schema.PrimaryFields) + _, queryValues := schema.GetIdentityFieldValuesMap(stmt.Context, stmt.ReflectValue, stmt.Schema.PrimaryFields) column, values := schema.ToQueryValues(stmt.Table, stmt.Schema.PrimaryFieldDBNames, queryValues) if len(values) > 0 { @@ -143,7 +143,7 @@ func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) { } if stmt.ReflectValue.CanAddr() && stmt.Dest != stmt.Model && stmt.Model != nil { - _, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(stmt.Model), stmt.Schema.PrimaryFields) + _, queryValues = schema.GetIdentityFieldValuesMap(stmt.Context, reflect.ValueOf(stmt.Model), stmt.Schema.PrimaryFields) column, values = schema.ToQueryValues(stmt.Table, stmt.Schema.PrimaryFieldDBNames, queryValues) if len(values) > 0 { diff --git a/statement.go b/statement.go index 23212642..cb471776 100644 --- a/statement.go +++ b/statement.go @@ -389,7 +389,7 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] for _, field := range s.Fields { selected := selectedColumns[field.DBName] || selectedColumns[field.Name] if selected || (!restricted && field.Readable) { - if v, isZero := field.ValueOf(reflectValue); !isZero || selected { + if v, isZero := field.ValueOf(stmt.Context, reflectValue); !isZero || selected { if field.DBName != "" { conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v}) } else if field.DataType != "" { @@ -403,7 +403,7 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] for _, field := range s.Fields { selected := selectedColumns[field.DBName] || selectedColumns[field.Name] if selected || (!restricted && field.Readable) { - if v, isZero := field.ValueOf(reflectValue.Index(i)); !isZero || selected { + if v, isZero := field.ValueOf(stmt.Context, reflectValue.Index(i)); !isZero || selected { if field.DBName != "" { conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v}) } else if field.DataType != "" { @@ -562,7 +562,7 @@ func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks . switch destValue.Kind() { case reflect.Struct: - field.Set(destValue, value) + field.Set(stmt.Context, destValue, value) default: stmt.AddError(ErrInvalidData) } @@ -572,10 +572,10 @@ func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks . case reflect.Slice, reflect.Array: if len(fromCallbacks) > 0 { for i := 0; i < stmt.ReflectValue.Len(); i++ { - field.Set(stmt.ReflectValue.Index(i), value) + field.Set(stmt.Context, stmt.ReflectValue.Index(i), value) } } else { - field.Set(stmt.ReflectValue.Index(stmt.CurDestIndex), value) + field.Set(stmt.Context, stmt.ReflectValue.Index(stmt.CurDestIndex), value) } case reflect.Struct: if !stmt.ReflectValue.CanAddr() { @@ -583,7 +583,7 @@ func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks . return } - field.Set(stmt.ReflectValue, value) + field.Set(stmt.Context, stmt.ReflectValue, value) } } else { stmt.AddError(ErrInvalidField) @@ -603,7 +603,7 @@ func (stmt *Statement) Changed(fields ...string) bool { selectColumns, restricted := stmt.SelectAndOmitColumns(false, true) changed := func(field *schema.Field) bool { - fieldValue, _ := field.ValueOf(modelValue) + fieldValue, _ := field.ValueOf(stmt.Context, modelValue) if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { if v, ok := stmt.Dest.(map[string]interface{}); ok { if fv, ok := v[field.Name]; ok { @@ -617,7 +617,7 @@ func (stmt *Statement) Changed(fields ...string) bool { destValue = destValue.Elem() } - changedValue, zero := field.ValueOf(destValue) + changedValue, zero := field.ValueOf(stmt.Context, destValue) return !zero && !utils.AssertEqual(changedValue, fieldValue) } } diff --git a/tests/create_test.go b/tests/create_test.go index af2abdb0..2b23d440 100644 --- a/tests/create_test.go +++ b/tests/create_test.go @@ -123,7 +123,7 @@ func TestCreateFromMap(t *testing.T) { {"name": "create_from_map_3", "Age": 20}, } - if err := DB.Model(&User{}).Create(datas).Error; err != nil { + if err := DB.Model(&User{}).Create(&datas).Error; err != nil { t.Fatalf("failed to create data from slice of map, got error: %v", err) } diff --git a/tests/go.mod b/tests/go.mod index 3453f77b..35db92e6 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -9,7 +9,7 @@ require ( github.com/jinzhu/now v1.1.4 github.com/lib/pq v1.10.4 github.com/mattn/go-sqlite3 v1.14.11 // indirect - golang.org/x/crypto v0.0.0-20220208233918-bba287dce954 // indirect + golang.org/x/crypto v0.0.0-20220214200702-86341886e292 // indirect gorm.io/driver/mysql v1.2.3 gorm.io/driver/postgres v1.2.3 gorm.io/driver/sqlite v1.2.6 diff --git a/tests/serializer_test.go b/tests/serializer_test.go new file mode 100644 index 00000000..3ed733d9 --- /dev/null +++ b/tests/serializer_test.go @@ -0,0 +1,71 @@ +package tests_test + +import ( + "bytes" + "context" + "fmt" + "reflect" + "strings" + "testing" + "time" + + "gorm.io/gorm" + "gorm.io/gorm/schema" + . "gorm.io/gorm/utils/tests" +) + +type SerializerStruct struct { + gorm.Model + Name []byte `gorm:"json"` + Roles Roles `gorm:"serializer:json"` + Contracts map[string]interface{} `gorm:"serializer:json"` + CreatedTime int64 `gorm:"serializer:unixtime;type:time"` // store time in db, use int as field type + EncryptedString EncryptedString +} + +type Roles []string +type EncryptedString string + +func (es *EncryptedString) Scan(ctx context.Context, field *schema.Field, dst reflect.Value, dbValue interface{}) (err error) { + switch value := dbValue.(type) { + case []byte: + *es = EncryptedString(bytes.TrimPrefix(value, []byte("hello"))) + case string: + *es = EncryptedString(strings.TrimPrefix(value, "hello")) + default: + return fmt.Errorf("unsupported data %v", dbValue) + } + return nil +} + +func (es EncryptedString) Value(ctx context.Context, field *schema.Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) { + return "hello" + string(es), nil +} + +func TestSerializer(t *testing.T) { + DB.Migrator().DropTable(&SerializerStruct{}) + if err := DB.Migrator().AutoMigrate(&SerializerStruct{}); err != nil { + t.Fatalf("no error should happen when migrate scanner, valuer struct, got error %v", err) + } + + createdAt := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC) + + data := SerializerStruct{ + Name: []byte("jinzhu"), + Roles: []string{"r1", "r2"}, + Contracts: map[string]interface{}{"name": "jinzhu", "age": 10}, + EncryptedString: EncryptedString("pass"), + CreatedTime: createdAt.Unix(), + } + + if err := DB.Create(&data).Error; err != nil { + t.Fatalf("failed to create data, got error %v", err) + } + + var result SerializerStruct + if err := DB.First(&result, data.ID).Error; err != nil { + t.Fatalf("failed to query data, got error %v", err) + } + + AssertEqual(t, result, data) +} diff --git a/utils/utils.go b/utils/utils.go index f00f92ba..28ca0daf 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -36,17 +36,14 @@ func IsValidDBNameChar(c rune) bool { return !unicode.IsLetter(c) && !unicode.IsNumber(c) && c != '.' && c != '*' && c != '_' && c != '$' && c != '@' } -func CheckTruth(val interface{}) bool { - if v, ok := val.(bool); ok { - return v +// CheckTruth check string true or not +func CheckTruth(vals ...string) bool { + for _, val := range vals { + if !strings.EqualFold(val, "false") && val != "" { + return true + } } - - if v, ok := val.(string); ok { - v = strings.ToLower(v) - return v != "false" - } - - return !reflect.ValueOf(val).IsZero() + return false } func ToStringKey(values ...interface{}) string { From 0af95f509a3284bb94393946e0a83aeaf954f304 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 19 Feb 2022 16:59:22 +0800 Subject: [PATCH 15/87] Enhance migrator Columntype interface (#5088) * Update Migrator ColumnType interface * Update MigrateColumn Test * Upgrade test drivers * Fix typo --- migrator.go | 13 ++++- migrator/column_type.go | 107 ++++++++++++++++++++++++++++++++++++++++ migrator/migrator.go | 39 +++++++++++++-- tests/go.mod | 9 ++-- tests/migrate_test.go | 31 ++++++++++-- 5 files changed, 185 insertions(+), 14 deletions(-) create mode 100644 migrator/column_type.go diff --git a/migrator.go b/migrator.go index 2a8b4254..52443877 100644 --- a/migrator.go +++ b/migrator.go @@ -1,6 +1,8 @@ package gorm import ( + "reflect" + "gorm.io/gorm/clause" "gorm.io/gorm/schema" ) @@ -33,14 +35,23 @@ type ViewOption struct { Query *DB } +// ColumnType column type interface type ColumnType interface { Name() string - DatabaseTypeName() string + DatabaseTypeName() string // varchar + ColumnType() (columnType string, ok bool) // varchar(64) + PrimaryKey() (isPrimaryKey bool, ok bool) + AutoIncrement() (isAutoIncrement bool, ok bool) Length() (length int64, ok bool) DecimalSize() (precision int64, scale int64, ok bool) Nullable() (nullable bool, ok bool) + Unique() (unique bool, ok bool) + ScanType() reflect.Type + Comment() (value string, ok bool) + DefaultValue() (value string, ok bool) } +// Migrator migrator interface type Migrator interface { // AutoMigrate AutoMigrate(dst ...interface{}) error diff --git a/migrator/column_type.go b/migrator/column_type.go new file mode 100644 index 00000000..eb8d1b7f --- /dev/null +++ b/migrator/column_type.go @@ -0,0 +1,107 @@ +package migrator + +import ( + "database/sql" + "reflect" +) + +// ColumnType column type implements ColumnType interface +type ColumnType struct { + SQLColumnType *sql.ColumnType + NameValue sql.NullString + DataTypeValue sql.NullString + ColumnTypeValue sql.NullString + PrimayKeyValue sql.NullBool + UniqueValue sql.NullBool + AutoIncrementValue sql.NullBool + LengthValue sql.NullInt64 + DecimalSizeValue sql.NullInt64 + ScaleValue sql.NullInt64 + NullableValue sql.NullBool + ScanTypeValue reflect.Type + CommentValue sql.NullString + DefaultValueValue sql.NullString +} + +// Name returns the name or alias of the column. +func (ct ColumnType) Name() string { + if ct.NameValue.Valid { + return ct.NameValue.String + } + return ct.SQLColumnType.Name() +} + +// DatabaseTypeName returns the database system name of the column type. If an empty +// string is returned, then the driver type name is not supported. +// Consult your driver documentation for a list of driver data types. Length specifiers +// are not included. +// Common type names include "VARCHAR", "TEXT", "NVARCHAR", "DECIMAL", "BOOL", +// "INT", and "BIGINT". +func (ct ColumnType) DatabaseTypeName() string { + if ct.DataTypeValue.Valid { + return ct.DataTypeValue.String + } + return ct.SQLColumnType.DatabaseTypeName() +} + +// ColumnType returns the database type of the column. lke `varchar(16)` +func (ct ColumnType) ColumnType() (columnType string, ok bool) { + return ct.ColumnTypeValue.String, ct.ColumnTypeValue.Valid +} + +// PrimaryKey returns the column is primary key or not. +func (ct ColumnType) PrimaryKey() (isPrimaryKey bool, ok bool) { + return ct.PrimayKeyValue.Bool, ct.PrimayKeyValue.Valid +} + +// AutoIncrement returns the column is auto increment or not. +func (ct ColumnType) AutoIncrement() (isAutoIncrement bool, ok bool) { + return ct.AutoIncrementValue.Bool, ct.AutoIncrementValue.Valid +} + +// Length returns the column type length for variable length column types +func (ct ColumnType) Length() (length int64, ok bool) { + if ct.LengthValue.Valid { + return ct.LengthValue.Int64, true + } + return ct.SQLColumnType.Length() +} + +// DecimalSize returns the scale and precision of a decimal type. +func (ct ColumnType) DecimalSize() (precision int64, scale int64, ok bool) { + if ct.DecimalSizeValue.Valid { + return ct.DecimalSizeValue.Int64, ct.ScaleValue.Int64, true + } + return ct.SQLColumnType.DecimalSize() +} + +// Nullable reports whether the column may be null. +func (ct ColumnType) Nullable() (nullable bool, ok bool) { + if ct.NullableValue.Valid { + return ct.NullableValue.Bool, true + } + return ct.SQLColumnType.Nullable() +} + +// Unique reports whether the column may be unique. +func (ct ColumnType) Unique() (unique bool, ok bool) { + return ct.UniqueValue.Bool, ct.UniqueValue.Valid +} + +// ScanType returns a Go type suitable for scanning into using Rows.Scan. +func (ct ColumnType) ScanType() reflect.Type { + if ct.ScanTypeValue != nil { + return ct.ScanTypeValue + } + return ct.SQLColumnType.ScanType() +} + +// Comment returns the comment of current column. +func (ct ColumnType) Comment() (value string, ok bool) { + return ct.CommentValue.String, ct.CommentValue.Valid +} + +// DefaultValue returns the default value of current column. +func (ct ColumnType) DefaultValue() (value string, ok bool) { + return ct.DefaultValueValue.String, ct.DefaultValueValue.Valid +} diff --git a/migrator/migrator.go b/migrator/migrator.go index 80c4e2b3..9695f312 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -30,10 +30,12 @@ type Config struct { gorm.Dialector } +// GormDataTypeInterface gorm data type interface type GormDataTypeInterface interface { GormDBDataType(*gorm.DB, *schema.Field) string } +// RunWithValue run migration with statement value func (m Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error) error { stmt := &gorm.Statement{DB: m.DB} if m.DB.Statement != nil { @@ -50,6 +52,7 @@ func (m Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error return fc(stmt) } +// DataTypeOf return field's db data type func (m Migrator) DataTypeOf(field *schema.Field) string { fieldValue := reflect.New(field.IndirectFieldType) if dataTyper, ok := fieldValue.Interface().(GormDataTypeInterface); ok { @@ -61,6 +64,7 @@ func (m Migrator) DataTypeOf(field *schema.Field) string { return m.Dialector.DataTypeOf(field) } +// FullDataTypeOf returns field's db full data type func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) { expr.SQL = m.DataTypeOf(field) @@ -85,7 +89,7 @@ func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) { return } -// AutoMigrate +// AutoMigrate auto migrate values func (m Migrator) AutoMigrate(values ...interface{}) error { for _, value := range m.ReorderModels(values, true) { tx := m.DB.Session(&gorm.Session{}) @@ -156,12 +160,14 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { return nil } +// GetTables returns tables func (m Migrator) GetTables() (tableList []string, err error) { err = m.DB.Raw("SELECT TABLE_NAME FROM information_schema.tables where TABLE_SCHEMA=?", m.CurrentDatabase()). Scan(&tableList).Error return } +// CreateTable create table in database for values func (m Migrator) CreateTable(values ...interface{}) error { for _, value := range m.ReorderModels(values, false) { tx := m.DB.Session(&gorm.Session{}) @@ -252,6 +258,7 @@ func (m Migrator) CreateTable(values ...interface{}) error { return nil } +// DropTable drop table for values func (m Migrator) DropTable(values ...interface{}) error { values = m.ReorderModels(values, false) for i := len(values) - 1; i >= 0; i-- { @@ -265,6 +272,7 @@ func (m Migrator) DropTable(values ...interface{}) error { return nil } +// HasTable returns table exists or not for value, value could be a struct or string func (m Migrator) HasTable(value interface{}) bool { var count int64 @@ -276,6 +284,7 @@ func (m Migrator) HasTable(value interface{}) bool { return count > 0 } +// RenameTable rename table from oldName to newName func (m Migrator) RenameTable(oldName, newName interface{}) error { var oldTable, newTable interface{} if v, ok := oldName.(string); ok { @@ -303,12 +312,13 @@ func (m Migrator) RenameTable(oldName, newName interface{}) error { return m.DB.Exec("ALTER TABLE ? RENAME TO ?", oldTable, newTable).Error } -func (m Migrator) AddColumn(value interface{}, field string) error { +// AddColumn create `name` column for value +func (m Migrator) AddColumn(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { // avoid using the same name field - f := stmt.Schema.LookUpField(field) + f := stmt.Schema.LookUpField(name) if f == nil { - return fmt.Errorf("failed to look up field with name: %s", field) + return fmt.Errorf("failed to look up field with name: %s", name) } if !f.IgnoreMigration { @@ -322,6 +332,7 @@ func (m Migrator) AddColumn(value interface{}, field string) error { }) } +// DropColumn drop value's `name` column func (m Migrator) DropColumn(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { if field := stmt.Schema.LookUpField(name); field != nil { @@ -334,6 +345,7 @@ func (m Migrator) DropColumn(value interface{}, name string) error { }) } +// AlterColumn alter value's `field` column' type based on schema definition func (m Migrator) AlterColumn(value interface{}, field string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { if field := stmt.Schema.LookUpField(field); field != nil { @@ -348,6 +360,7 @@ func (m Migrator) AlterColumn(value interface{}, field string) error { }) } +// HasColumn check has column `field` for value or not func (m Migrator) HasColumn(value interface{}, field string) bool { var count int64 m.RunWithValue(value, func(stmt *gorm.Statement) error { @@ -366,6 +379,7 @@ func (m Migrator) HasColumn(value interface{}, field string) bool { return count > 0 } +// RenameColumn rename value's field name from oldName to newName func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { if field := stmt.Schema.LookUpField(oldName); field != nil { @@ -383,6 +397,7 @@ func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error }) } +// MigrateColumn migrate column func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnType gorm.ColumnType) error { // found, smart migrate fullDataType := strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL) @@ -448,7 +463,7 @@ func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) { } for _, c := range rawColumnTypes { - columnTypes = append(columnTypes, c) + columnTypes = append(columnTypes, ColumnType{SQLColumnType: c}) } return @@ -457,10 +472,12 @@ func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) { return columnTypes, execErr } +// CreateView create view func (m Migrator) CreateView(name string, option gorm.ViewOption) error { return gorm.ErrNotImplemented } +// DropView drop view func (m Migrator) DropView(name string) error { return gorm.ErrNotImplemented } @@ -487,6 +504,7 @@ func buildConstraint(constraint *schema.Constraint) (sql string, results []inter return } +// GuessConstraintAndTable guess statement's constraint and it's table based on name func (m Migrator) GuessConstraintAndTable(stmt *gorm.Statement, name string) (_ *schema.Constraint, _ *schema.Check, table string) { if stmt.Schema == nil { return nil, nil, stmt.Table @@ -531,6 +549,7 @@ func (m Migrator) GuessConstraintAndTable(stmt *gorm.Statement, name string) (_ return nil, nil, stmt.Schema.Table } +// CreateConstraint create constraint func (m Migrator) CreateConstraint(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { constraint, chk, table := m.GuessConstraintAndTable(stmt, name) @@ -554,6 +573,7 @@ func (m Migrator) CreateConstraint(value interface{}, name string) error { }) } +// DropConstraint drop constraint func (m Migrator) DropConstraint(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { constraint, chk, table := m.GuessConstraintAndTable(stmt, name) @@ -566,6 +586,7 @@ func (m Migrator) DropConstraint(value interface{}, name string) error { }) } +// HasConstraint check has constraint or not func (m Migrator) HasConstraint(value interface{}, name string) bool { var count int64 m.RunWithValue(value, func(stmt *gorm.Statement) error { @@ -586,6 +607,7 @@ func (m Migrator) HasConstraint(value interface{}, name string) bool { return count > 0 } +// BuildIndexOptions build index options func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) { for _, opt := range opts { str := stmt.Quote(opt.DBName) @@ -607,10 +629,12 @@ func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statem return } +// BuildIndexOptionsInterface build index options interface type BuildIndexOptionsInterface interface { BuildIndexOptions([]schema.IndexOption, *gorm.Statement) []interface{} } +// CreateIndex create index `name` func (m Migrator) CreateIndex(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { if idx := stmt.Schema.LookIndex(name); idx != nil { @@ -642,6 +666,7 @@ func (m Migrator) CreateIndex(value interface{}, name string) error { }) } +// DropIndex drop index `name` func (m Migrator) DropIndex(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { if idx := stmt.Schema.LookIndex(name); idx != nil { @@ -652,6 +677,7 @@ func (m Migrator) DropIndex(value interface{}, name string) error { }) } +// HasIndex check has index `name` or not func (m Migrator) HasIndex(value interface{}, name string) bool { var count int64 m.RunWithValue(value, func(stmt *gorm.Statement) error { @@ -669,6 +695,7 @@ func (m Migrator) HasIndex(value interface{}, name string) bool { return count > 0 } +// RenameIndex rename index from oldName to newName func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { return m.DB.Exec( @@ -678,6 +705,7 @@ func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error }) } +// CurrentDatabase returns current database name func (m Migrator) CurrentDatabase() (name string) { m.DB.Raw("SELECT DATABASE()").Row().Scan(&name) return @@ -781,6 +809,7 @@ func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []i return } +// CurrentTable returns current statement's table expression func (m Migrator) CurrentTable(stmt *gorm.Statement) interface{} { if stmt.TableExpr != nil { return *stmt.TableExpr diff --git a/tests/go.mod b/tests/go.mod index 35db92e6..0cd03637 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -3,17 +3,16 @@ module gorm.io/gorm/tests go 1.14 require ( - github.com/denisenkom/go-mssqldb v0.12.0 // indirect github.com/google/uuid v1.3.0 github.com/jackc/pgx/v4 v4.15.0 // indirect github.com/jinzhu/now v1.1.4 github.com/lib/pq v1.10.4 github.com/mattn/go-sqlite3 v1.14.11 // indirect golang.org/x/crypto v0.0.0-20220214200702-86341886e292 // indirect - gorm.io/driver/mysql v1.2.3 - gorm.io/driver/postgres v1.2.3 - gorm.io/driver/sqlite v1.2.6 - gorm.io/driver/sqlserver v1.2.1 + gorm.io/driver/mysql v1.3.0 + gorm.io/driver/postgres v1.3.0 + gorm.io/driver/sqlite v1.3.0 + gorm.io/driver/sqlserver v1.3.0 gorm.io/gorm v1.22.5 ) diff --git a/tests/migrate_test.go b/tests/migrate_test.go index aa0a84ab..5e9c01fa 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -92,7 +92,7 @@ func TestAutoMigrateSelfReferential(t *testing.T) { } func TestSmartMigrateColumn(t *testing.T) { - fullSupported := map[string]bool{"mysql": true}[DB.Dialector.Name()] + fullSupported := map[string]bool{"mysql": true, "postgres": true}[DB.Dialector.Name()] type UserMigrateColumn struct { ID uint @@ -313,9 +313,15 @@ func TestMigrateIndexes(t *testing.T) { } func TestMigrateColumns(t *testing.T) { + fullSupported := map[string]bool{"sqlite": true, "mysql": true, "postgres": true, "sqlserver": true}[DB.Dialector.Name()] + sqlite := DB.Dialector.Name() == "sqlite" + sqlserver := DB.Dialector.Name() == "sqlserver" + type ColumnStruct struct { gorm.Model Name string + Age int `gorm:"default:18;comment:my age"` + Code string `gorm:"unique"` } DB.Migrator().DropTable(&ColumnStruct{}) @@ -340,10 +346,29 @@ func TestMigrateColumns(t *testing.T) { stmt.Parse(&ColumnStruct2{}) for _, columnType := range columnTypes { - if columnType.Name() == "name" { + switch columnType.Name() { + case "id": + if v, ok := columnType.PrimaryKey(); (fullSupported || ok) && !v { + t.Fatalf("column id primary key should be correct, name: %v, column: %#v", columnType.Name(), columnType) + } + case "name": dataType := DB.Dialector.DataTypeOf(stmt.Schema.LookUpField(columnType.Name())) if !strings.Contains(strings.ToUpper(dataType), strings.ToUpper(columnType.DatabaseTypeName())) { - t.Errorf("column type should be correct, name: %v, length: %v, expects: %v", columnType.Name(), columnType.DatabaseTypeName(), dataType) + t.Fatalf("column name type should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), columnType.DatabaseTypeName(), dataType, columnType) + } + if length, ok := columnType.Length(); ((fullSupported && !sqlite) || ok) && length != 100 { + t.Fatalf("column name length should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), length, 100, columnType) + } + case "age": + if v, ok := columnType.DefaultValue(); (fullSupported || ok) && v != "18" { + t.Fatalf("column age default value should be correct, name: %v, column: %#v", columnType.Name(), columnType) + } + if v, ok := columnType.Comment(); ((fullSupported && !sqlite && !sqlserver) || ok) && v != "my age" { + t.Fatalf("column age comment should be correct, name: %v, column: %#v", columnType.Name(), columnType) + } + case "code": + if v, ok := columnType.Unique(); (fullSupported || ok) && !v { + t.Fatalf("column code unique should be correct, name: %v, column: %#v", columnType.Name(), columnType) } } } From e0b4e0ec8f938ac055e99c5b37e0cdb9bf6e2ad5 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 19 Feb 2022 17:08:11 +0800 Subject: [PATCH 16/87] Update auto stale days --- .github/workflows/invalid_question.yml | 4 ++-- .github/workflows/missing_playground.yml | 4 ++-- .github/workflows/stale.yml | 6 +++--- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/.github/workflows/invalid_question.yml b/.github/workflows/invalid_question.yml index dfd2ddd9..868bcc34 100644 --- a/.github/workflows/invalid_question.yml +++ b/.github/workflows/invalid_question.yml @@ -13,10 +13,10 @@ jobs: uses: actions/stale@v4 with: repo-token: ${{ secrets.GITHUB_TOKEN }} - stale-issue-message: "This issue has been marked as invalid question, please give more information by following the `Question` template, if you believe there is a bug of GORM, please create a pull request that could reproduce the issue on [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground), the issue will be closed in 2 days if no further activity occurs. most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨" + stale-issue-message: "This issue has been marked as invalid question, please give more information by following the `Question` template, if you believe there is a bug of GORM, please create a pull request that could reproduce the issue on [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground), the issue will be closed in 30 days if no further activity occurs. most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨" stale-issue-label: "status:stale" days-before-stale: 0 - days-before-close: 2 + days-before-close: 30 remove-stale-when-updated: true only-labels: "type:invalid question" diff --git a/.github/workflows/missing_playground.yml b/.github/workflows/missing_playground.yml index cdb097de..3efc90f7 100644 --- a/.github/workflows/missing_playground.yml +++ b/.github/workflows/missing_playground.yml @@ -13,9 +13,9 @@ jobs: uses: actions/stale@v4 with: repo-token: ${{ secrets.GITHUB_TOKEN }} - stale-issue-message: "The issue has been automatically marked as stale as it missing playground pull request link, which is important to help others understand your issue effectively and make sure the issue hasn't been fixed on latest master, checkout [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground) for details. it will be closed in 2 days if no further activity occurs. if you are asking question, please use the `Question` template, most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨" + stale-issue-message: "The issue has been automatically marked as stale as it missing playground pull request link, which is important to help others understand your issue effectively and make sure the issue hasn't been fixed on latest master, checkout [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground) for details. it will be closed in 30 days if no further activity occurs. if you are asking question, please use the `Question` template, most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨" stale-issue-label: "status:stale" days-before-stale: 0 - days-before-close: 2 + days-before-close: 30 remove-stale-when-updated: true only-labels: "type:missing reproduction steps" diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index d5419295..e0be186f 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -13,9 +13,9 @@ jobs: uses: actions/stale@v4 with: repo-token: ${{ secrets.GITHUB_TOKEN }} - stale-issue-message: "This issue has been automatically marked as stale because it has been open 60 days with no activity. Remove stale label or comment or this will be closed in 30 days" - days-before-stale: 60 - days-before-close: 30 + stale-issue-message: "This issue has been automatically marked as stale because it has been open 360 days with no activity. Remove stale label or comment or this will be closed in 180 days" + days-before-stale: 360 + days-before-close: 180 stale-issue-label: "status:stale" exempt-issue-labels: 'type:feature,type:with reproduction steps,type:has pull request' stale-pr-label: 'status:stale' From 48ced75d1d8d8aab844ab29787ae97337095b8e1 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 19 Feb 2022 23:42:20 +0800 Subject: [PATCH 17/87] Improve support for AutoMigrate --- migrator/column_type.go | 4 ++-- migrator/migrator.go | 24 +++++++++++++++++++++ tests/go.mod | 10 ++++----- tests/migrate_test.go | 47 ++++++++++++++++++++++++++++++----------- 4 files changed, 66 insertions(+), 19 deletions(-) diff --git a/migrator/column_type.go b/migrator/column_type.go index eb8d1b7f..cc1331b9 100644 --- a/migrator/column_type.go +++ b/migrator/column_type.go @@ -11,7 +11,7 @@ type ColumnType struct { NameValue sql.NullString DataTypeValue sql.NullString ColumnTypeValue sql.NullString - PrimayKeyValue sql.NullBool + PrimaryKeyValue sql.NullBool UniqueValue sql.NullBool AutoIncrementValue sql.NullBool LengthValue sql.NullInt64 @@ -51,7 +51,7 @@ func (ct ColumnType) ColumnType() (columnType string, ok bool) { // PrimaryKey returns the column is primary key or not. func (ct ColumnType) PrimaryKey() (isPrimaryKey bool, ok bool) { - return ct.PrimayKeyValue.Bool, ct.PrimayKeyValue.Valid + return ct.PrimaryKeyValue.Bool, ct.PrimaryKeyValue.Valid } // AutoIncrement returns the column is auto increment or not. diff --git a/migrator/migrator.go b/migrator/migrator.go index 9695f312..a50bb3ff 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -436,6 +436,30 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy } } + // check unique + if unique, ok := columnType.Unique(); ok && unique != field.Unique { + // not primary key + if !field.PrimaryKey { + alterColumn = true + } + } + + // check default value + if v, ok := columnType.DefaultValue(); ok && v != field.DefaultValue { + // not primary key + if !field.PrimaryKey { + alterColumn = true + } + } + + // check comment + if comment, ok := columnType.Comment(); ok && comment != field.Comment { + // not primary key + if !field.PrimaryKey { + alterColumn = true + } + } + if alterColumn && !field.IgnoreMigration { return m.DB.Migrator().AlterColumn(value, field.Name) } diff --git a/tests/go.mod b/tests/go.mod index 0cd03637..1c1fb238 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -9,11 +9,11 @@ require ( github.com/lib/pq v1.10.4 github.com/mattn/go-sqlite3 v1.14.11 // indirect golang.org/x/crypto v0.0.0-20220214200702-86341886e292 // indirect - gorm.io/driver/mysql v1.3.0 - gorm.io/driver/postgres v1.3.0 - gorm.io/driver/sqlite v1.3.0 - gorm.io/driver/sqlserver v1.3.0 - gorm.io/gorm v1.22.5 + gorm.io/driver/mysql v1.3.1 + gorm.io/driver/postgres v1.3.1 + gorm.io/driver/sqlite v1.3.1 + gorm.io/driver/sqlserver v1.3.1 + gorm.io/gorm v1.23.0 ) replace gorm.io/gorm => ../ diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 5e9c01fa..94f562b4 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -45,7 +45,7 @@ func TestMigrate(t *testing.T) { for _, m := range allModels { if !DB.Migrator().HasTable(m) { - t.Fatalf("Failed to create table for %#v---", m) + t.Fatalf("Failed to create table for %#v", m) } } @@ -313,15 +313,16 @@ func TestMigrateIndexes(t *testing.T) { } func TestMigrateColumns(t *testing.T) { - fullSupported := map[string]bool{"sqlite": true, "mysql": true, "postgres": true, "sqlserver": true}[DB.Dialector.Name()] sqlite := DB.Dialector.Name() == "sqlite" sqlserver := DB.Dialector.Name() == "sqlserver" type ColumnStruct struct { gorm.Model - Name string - Age int `gorm:"default:18;comment:my age"` - Code string `gorm:"unique"` + Name string + Age int `gorm:"default:18;comment:my age"` + Code string `gorm:"unique;comment:my code;"` + Code2 string + Code3 string `gorm:"unique"` } DB.Migrator().DropTable(&ColumnStruct{}) @@ -332,13 +333,20 @@ func TestMigrateColumns(t *testing.T) { type ColumnStruct2 struct { gorm.Model - Name string `gorm:"size:100"` + Name string `gorm:"size:100"` + Code string `gorm:"unique;comment:my code2;default:hello"` + Code2 string `gorm:"unique"` + // Code3 string } - if err := DB.Table("column_structs").Migrator().AlterColumn(&ColumnStruct2{}, "Name"); err != nil { + if err := DB.Table("column_structs").Migrator().AlterColumn(&ColumnStruct{}, "Name"); err != nil { t.Fatalf("no error should happened when alter column, but got %v", err) } + if err := DB.Table("column_structs").AutoMigrate(&ColumnStruct2{}); err != nil { + t.Fatalf("no error should happened when auto migrate column, but got %v", err) + } + if columnTypes, err := DB.Migrator().ColumnTypes(&ColumnStruct{}); err != nil { t.Fatalf("no error should returns for ColumnTypes") } else { @@ -348,7 +356,7 @@ func TestMigrateColumns(t *testing.T) { for _, columnType := range columnTypes { switch columnType.Name() { case "id": - if v, ok := columnType.PrimaryKey(); (fullSupported || ok) && !v { + if v, ok := columnType.PrimaryKey(); !ok || !v { t.Fatalf("column id primary key should be correct, name: %v, column: %#v", columnType.Name(), columnType) } case "name": @@ -356,20 +364,35 @@ func TestMigrateColumns(t *testing.T) { if !strings.Contains(strings.ToUpper(dataType), strings.ToUpper(columnType.DatabaseTypeName())) { t.Fatalf("column name type should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), columnType.DatabaseTypeName(), dataType, columnType) } - if length, ok := columnType.Length(); ((fullSupported && !sqlite) || ok) && length != 100 { + if length, ok := columnType.Length(); !sqlite && (!ok || length != 100) { t.Fatalf("column name length should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), length, 100, columnType) } case "age": - if v, ok := columnType.DefaultValue(); (fullSupported || ok) && v != "18" { + if v, ok := columnType.DefaultValue(); !ok || v != "18" { t.Fatalf("column age default value should be correct, name: %v, column: %#v", columnType.Name(), columnType) } - if v, ok := columnType.Comment(); ((fullSupported && !sqlite && !sqlserver) || ok) && v != "my age" { + if v, ok := columnType.Comment(); !sqlite && !sqlserver && (!ok || v != "my age") { t.Fatalf("column age comment should be correct, name: %v, column: %#v", columnType.Name(), columnType) } case "code": - if v, ok := columnType.Unique(); (fullSupported || ok) && !v { + if v, ok := columnType.Unique(); !ok || !v { t.Fatalf("column code unique should be correct, name: %v, column: %#v", columnType.Name(), columnType) } + if v, ok := columnType.DefaultValue(); !sqlserver && (!ok || v != "hello") { + t.Fatalf("column code default value should be correct, name: %v, column: %#v", columnType.Name(), columnType) + } + if v, ok := columnType.Comment(); !sqlite && !sqlserver && (!ok || v != "my code2") { + t.Fatalf("column code comment should be correct, name: %v, column: %#v", columnType.Name(), columnType) + } + case "code2": + if v, ok := columnType.Unique(); !sqlserver && (!ok || !v) { + t.Fatalf("column code2 unique should be correct, name: %v, column: %#v", columnType.Name(), columnType) + } + case "code3": + // TODO + // if v, ok := columnType.Unique(); !ok || v { + // t.Fatalf("column code unique should be correct, name: %v, column: %#v", columnType.Name(), columnType) + // } } } } From 5edc78116fe46a7d001db52d80a78f97756ac1ad Mon Sep 17 00:00:00 2001 From: sammyrnycreal Date: Mon, 14 Feb 2022 14:13:26 -0500 Subject: [PATCH 18/87] Fixed the use of "or" to be " OR ", to account for words that contain "or" or "and" (e.g., 'score', 'band') in a sql statement as the name of a field. --- clause/where.go | 39 ++++++++++++++++++++++----------------- clause/where_test.go | 39 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 61 insertions(+), 17 deletions(-) diff --git a/clause/where.go b/clause/where.go index 20a01136..10b6df85 100644 --- a/clause/where.go +++ b/clause/where.go @@ -4,6 +4,11 @@ import ( "strings" ) +const ( + AndWithSpace = " AND " + OrWithSpace = " OR " +) + // Where where clause type Where struct { Exprs []Expression @@ -26,7 +31,7 @@ func (where Where) Build(builder Builder) { } } - buildExprs(where.Exprs, builder, " AND ") + buildExprs(where.Exprs, builder, AndWithSpace) } func buildExprs(exprs []Expression, builder Builder, joinCond string) { @@ -35,7 +40,7 @@ func buildExprs(exprs []Expression, builder Builder, joinCond string) { for idx, expr := range exprs { if idx > 0 { if v, ok := expr.(OrConditions); ok && len(v.Exprs) == 1 { - builder.WriteString(" OR ") + builder.WriteString(OrWithSpace) } else { builder.WriteString(joinCond) } @@ -46,23 +51,23 @@ func buildExprs(exprs []Expression, builder Builder, joinCond string) { case OrConditions: if len(v.Exprs) == 1 { if e, ok := v.Exprs[0].(Expr); ok { - sql := strings.ToLower(e.SQL) - wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or") + sql := strings.ToUpper(e.SQL) + wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace) } } case AndConditions: if len(v.Exprs) == 1 { if e, ok := v.Exprs[0].(Expr); ok { - sql := strings.ToLower(e.SQL) - wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or") + sql := strings.ToUpper(e.SQL) + wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace) } } case Expr: - sql := strings.ToLower(v.SQL) - wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or") + sql := strings.ToUpper(v.SQL) + wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace) case NamedExpr: - sql := strings.ToLower(v.SQL) - wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or") + sql := strings.ToUpper(v.SQL) + wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace) } } @@ -110,10 +115,10 @@ type AndConditions struct { func (and AndConditions) Build(builder Builder) { if len(and.Exprs) > 1 { builder.WriteByte('(') - buildExprs(and.Exprs, builder, " AND ") + buildExprs(and.Exprs, builder, AndWithSpace) builder.WriteByte(')') } else { - buildExprs(and.Exprs, builder, " AND ") + buildExprs(and.Exprs, builder, AndWithSpace) } } @@ -131,10 +136,10 @@ type OrConditions struct { func (or OrConditions) Build(builder Builder) { if len(or.Exprs) > 1 { builder.WriteByte('(') - buildExprs(or.Exprs, builder, " OR ") + buildExprs(or.Exprs, builder, OrWithSpace) builder.WriteByte(')') } else { - buildExprs(or.Exprs, builder, " OR ") + buildExprs(or.Exprs, builder, OrWithSpace) } } @@ -156,7 +161,7 @@ func (not NotConditions) Build(builder Builder) { for idx, c := range not.Exprs { if idx > 0 { - builder.WriteString(" AND ") + builder.WriteString(AndWithSpace) } if negationBuilder, ok := c.(NegationExpressionBuilder); ok { @@ -165,8 +170,8 @@ func (not NotConditions) Build(builder Builder) { builder.WriteString("NOT ") e, wrapInParentheses := c.(Expr) if wrapInParentheses { - sql := strings.ToLower(e.SQL) - if wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or"); wrapInParentheses { + sql := strings.ToUpper(e.SQL) + if wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace); wrapInParentheses { builder.WriteByte('(') } } diff --git a/clause/where_test.go b/clause/where_test.go index 272c7b76..35e3dbee 100644 --- a/clause/where_test.go +++ b/clause/where_test.go @@ -66,6 +66,45 @@ func TestWhere(t *testing.T) { "SELECT * FROM `users` WHERE (`age` = ? OR `name` <> ?)", []interface{}{18, "jinzhu"}, }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Where{ + Exprs: []clause.Expression{clause.Not(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}), clause.And(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false})}, + }}, + "SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `age` <= ?) AND `score` <= ?", + []interface{}{"1", 18, 100}, + }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Where{ + Exprs: []clause.Expression{clause.Not(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}), clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false}}, + }}, + "SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `age` <= ?) AND `score` <= ?", + []interface{}{"1", 18, 100}, + }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Where{ + Exprs: []clause.Expression{clause.Not(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}), clause.Or(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false})}, + }}, + "SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `age` <= ?) OR `score` <= ?", + []interface{}{"1", 18, 100}, + }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Where{ + Exprs: []clause.Expression{ + clause.And(clause.Not(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}), + clause.And(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false})), + }, + }}, + "SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `score` <= ?)", + []interface{}{"1", 100}, + }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Where{ + Exprs: []clause.Expression{clause.Not(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, + clause.And(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false}))}, + }}, + "SELECT * FROM `users` WHERE (`users`.`id` <> ? AND NOT `score` <= ?)", + []interface{}{"1", 100}, + }, } for idx, result := range results { From f3547e00cc786e0b07206c775f3b7fe19164f56f Mon Sep 17 00:00:00 2001 From: Gilad Weiss Date: Sun, 20 Feb 2022 02:33:12 +0200 Subject: [PATCH 19/87] Inherit clone flag (NewDB) on transaction creation (#5012) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Inherit clone flag (NewDB) on transaction creation I find it very reassuring to know that after a finisher API, I get a clean db object for my next queries. If you look at the example in https://gorm.io/docs i’d see many queries running one after the other.. but in reality they wouldn’t work as the they are portrayed and that’s because in default mode NewDB is false and will make all the clauses stay even after a finisher API. My solution is just to have the value of the clone flag in the “parent” db object, be injected to its children transactions. * Fix typo --- finisher_api.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/finisher_api.go b/finisher_api.go index d2a8b981..f994ec31 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -590,7 +590,7 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er func (db *DB) Begin(opts ...*sql.TxOptions) *DB { var ( // clone statement - tx = db.getInstance().Session(&Session{Context: db.Statement.Context}) + tx = db.getInstance().Session(&Session{Context: db.Statement.Context, NewDB: db.clone == 1}) opt *sql.TxOptions err error ) From 664c5fb7672863b38080bb2147403b5d67f2593c Mon Sep 17 00:00:00 2001 From: codingxh <94290868+codingxh@users.noreply.github.com> Date: Sun, 20 Feb 2022 19:55:04 +0800 Subject: [PATCH 20/87] strings.replace -> strings.replaceAll (#5095) Co-authored-by: huquan --- logger/sql.go | 8 ++++---- logger/sql_test.go | 2 +- schema/naming.go | 2 +- tests/sql_builder_test.go | 8 ++++---- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/logger/sql.go b/logger/sql.go index 04a2dbd4..c8b194c3 100644 --- a/logger/sql.go +++ b/logger/sql.go @@ -75,10 +75,10 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a case reflect.Bool: vars[idx] = fmt.Sprintf("%t", reflectValue.Interface()) case reflect.String: - vars[idx] = escaper + strings.Replace(fmt.Sprintf("%v", v), escaper, "\\"+escaper, -1) + escaper + vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, "\\"+escaper) + escaper default: if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) { - vars[idx] = escaper + strings.Replace(fmt.Sprintf("%v", v), escaper, "\\"+escaper, -1) + escaper + vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, "\\"+escaper) + escaper } else { vars[idx] = nullStr } @@ -94,7 +94,7 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a case float64, float32: vars[idx] = fmt.Sprintf("%.6f", v) case string: - vars[idx] = escaper + strings.Replace(v, escaper, "\\"+escaper, -1) + escaper + vars[idx] = escaper + strings.ReplaceAll(v, escaper, "\\"+escaper) + escaper default: rv := reflect.ValueOf(v) if v == nil || !rv.IsValid() || rv.Kind() == reflect.Ptr && rv.IsNil() { @@ -111,7 +111,7 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a return } } - vars[idx] = escaper + strings.Replace(fmt.Sprint(v), escaper, "\\"+escaper, -1) + escaper + vars[idx] = escaper + strings.ReplaceAll(fmt.Sprint(v), escaper, "\\"+escaper) + escaper } } } diff --git a/logger/sql_test.go b/logger/sql_test.go index 71aa841a..c5b181a9 100644 --- a/logger/sql_test.go +++ b/logger/sql_test.go @@ -31,7 +31,7 @@ func (s ExampleStruct) Value() (driver.Value, error) { } func format(v []byte, escaper string) string { - return escaper + strings.Replace(string(v), escaper, "\\"+escaper, -1) + escaper + return escaper + strings.ReplaceAll(string(v), escaper, "\\"+escaper) + escaper } func TestExplainSQL(t *testing.T) { diff --git a/schema/naming.go b/schema/naming.go index a4e3a75b..125094bc 100644 --- a/schema/naming.go +++ b/schema/naming.go @@ -174,7 +174,7 @@ func (ns NamingStrategy) toDBName(name string) string { } func (ns NamingStrategy) toSchemaName(name string) string { - result := strings.Replace(strings.Title(strings.Replace(name, "_", " ", -1)), " ", "", -1) + result := strings.ReplaceAll(strings.Title(strings.ReplaceAll(name, "_", " ")), " ", "") for _, initialism := range commonInitialisms { result = regexp.MustCompile(strings.Title(strings.ToLower(initialism))+"([A-Z]|$|_)").ReplaceAllString(result, initialism+"$1") } diff --git a/tests/sql_builder_test.go b/tests/sql_builder_test.go index 897f687f..bc917c32 100644 --- a/tests/sql_builder_test.go +++ b/tests/sql_builder_test.go @@ -460,16 +460,16 @@ func assertEqualSQL(t *testing.T, expected string, actually string) { func replaceQuoteInSQL(sql string) string { // convert single quote into double quote - sql = strings.Replace(sql, `'`, `"`, -1) + sql = strings.ReplaceAll(sql, `'`, `"`) // convert dialect speical quote into double quote switch DB.Dialector.Name() { case "postgres": - sql = strings.Replace(sql, `"`, `"`, -1) + sql = strings.ReplaceAll(sql, `"`, `"`) case "mysql", "sqlite": - sql = strings.Replace(sql, "`", `"`, -1) + sql = strings.ReplaceAll(sql, "`", `"`) case "sqlserver": - sql = strings.Replace(sql, `'`, `"`, -1) + sql = strings.ReplaceAll(sql, `'`, `"`) } return sql From 7837fb6fa001ef78bc76e66b48445dee7b2db37b Mon Sep 17 00:00:00 2001 From: Qt Date: Sun, 20 Feb 2022 21:19:15 +0800 Subject: [PATCH 21/87] fix typo in TxCommitter interface comment & improve CheckTruth, chek val empty first (#5094) * fix typo in TxCommitter interface comment * improve CheckTruth, chek val empty first --- interfaces.go | 2 +- utils/utils.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/interfaces.go b/interfaces.go index ff0ca60a..44a85cb5 100644 --- a/interfaces.go +++ b/interfaces.go @@ -50,7 +50,7 @@ type ConnPoolBeginner interface { BeginTx(ctx context.Context, opts *sql.TxOptions) (ConnPool, error) } -// TxCommitter tx commiter +// TxCommitter tx committer type TxCommitter interface { Commit() error Rollback() error diff --git a/utils/utils.go b/utils/utils.go index 28ca0daf..296917b9 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -39,7 +39,7 @@ func IsValidDBNameChar(c rune) bool { // CheckTruth check string true or not func CheckTruth(vals ...string) bool { for _, val := range vals { - if !strings.EqualFold(val, "false") && val != "" { + if val != "" && !strings.EqualFold(val, "false") { return true } } From b1201fce4efa60b464a1b260869a24d809607f53 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 23 Feb 2022 17:48:13 +0800 Subject: [PATCH 22/87] Fix update with customized time type, close #5101 --- callbacks/update.go | 12 ++++++------ schema/field.go | 8 ++++---- tests/go.mod | 4 ++-- tests/postgres_test.go | 18 +++++++++++++++--- 4 files changed, 27 insertions(+), 15 deletions(-) diff --git a/callbacks/update.go b/callbacks/update.go index 4f07ca30..4a2e5c79 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -232,10 +232,10 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano()}) } else if field.AutoUpdateTime == schema.UnixMillisecond { set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano() / 1e6}) - } else if field.GORMDataType == schema.Time { - set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now}) - } else { + } else if field.AutoUpdateTime == schema.UnixSecond { set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.Unix()}) + } else { + set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now}) } } } @@ -264,10 +264,10 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { value = stmt.DB.NowFunc().UnixNano() } else if field.AutoUpdateTime == schema.UnixMillisecond { value = stmt.DB.NowFunc().UnixNano() / 1e6 - } else if field.GORMDataType == schema.Time { - value = stmt.DB.NowFunc() - } else { + } else if field.AutoUpdateTime == schema.UnixSecond { value = stmt.DB.NowFunc().Unix() + } else { + value = stmt.DB.NowFunc() } isZero = false } diff --git a/schema/field.go b/schema/field.go index 319f3693..8c793f93 100644 --- a/schema/field.go +++ b/schema/field.go @@ -293,6 +293,10 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } } + if field.GORMDataType == "" { + field.GORMDataType = field.DataType + } + if val, ok := field.TagSettings["TYPE"]; ok { switch DataType(strings.ToLower(val)) { case Bool, Int, Uint, Float, String, Time, Bytes: @@ -302,10 +306,6 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } } - if field.GORMDataType == "" { - field.GORMDataType = field.DataType - } - if field.Size == 0 { switch reflect.Indirect(fieldValue).Kind() { case reflect.Int, reflect.Int64, reflect.Uint, reflect.Uint64, reflect.Float64: diff --git a/tests/go.mod b/tests/go.mod index 1c1fb238..cefe6f96 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -9,11 +9,11 @@ require ( github.com/lib/pq v1.10.4 github.com/mattn/go-sqlite3 v1.14.11 // indirect golang.org/x/crypto v0.0.0-20220214200702-86341886e292 // indirect - gorm.io/driver/mysql v1.3.1 + gorm.io/driver/mysql v1.3.2 gorm.io/driver/postgres v1.3.1 gorm.io/driver/sqlite v1.3.1 gorm.io/driver/sqlserver v1.3.1 - gorm.io/gorm v1.23.0 + gorm.io/gorm v1.23.1 ) replace gorm.io/gorm => ../ diff --git a/tests/postgres_test.go b/tests/postgres_test.go index 85671864..418b713e 100644 --- a/tests/postgres_test.go +++ b/tests/postgres_test.go @@ -2,6 +2,7 @@ package tests_test import ( "testing" + "time" "github.com/google/uuid" "github.com/lib/pq" @@ -15,9 +16,11 @@ func TestPostgres(t *testing.T) { type Harumph struct { gorm.Model - Name string `gorm:"check:name_checker,name <> ''"` - Test uuid.UUID `gorm:"type:uuid;not null;default:gen_random_uuid()"` - Things pq.StringArray `gorm:"type:text[]"` + Name string `gorm:"check:name_checker,name <> ''"` + Test uuid.UUID `gorm:"type:uuid;not null;default:gen_random_uuid()"` + CreatedAt time.Time `gorm:"type:TIMESTAMP WITHOUT TIME ZONE"` + UpdatedAt time.Time `gorm:"type:TIMESTAMP WITHOUT TIME ZONE"` + Things pq.StringArray `gorm:"type:text[]"` } if err := DB.Exec("CREATE EXTENSION IF NOT EXISTS pgcrypto;").Error; err != nil { @@ -48,6 +51,15 @@ func TestPostgres(t *testing.T) { if err := DB.Where("id = $1", harumph.ID).First(&Harumph{}).Error; err != nil || harumph.Name != "jinzhu" { t.Errorf("No error should happen, but got %v", err) } + + harumph.Name = "jinzhu1" + if err := DB.Save(&harumph).Error; err != nil { + t.Errorf("Failed to update date, got error %v", err) + } + + if err := DB.First(&result, "id = ?", harumph.ID).Error; err != nil || harumph.Name != "jinzhu1" { + t.Errorf("No error should happen, but got %v", err) + } } type Post struct { From 45ef1da7e4853441e59af06800ed7c672f15bc7c Mon Sep 17 00:00:00 2001 From: Michael Nussbaum Date: Wed, 23 Feb 2022 21:10:20 -0500 Subject: [PATCH 23/87] Fix naming longer then 64 chars with dots in table (#5045) Ensures that foreign key relationships and indexes are given syntactically valid names when their name length exceeds 64 characters and they contained dot characters within the name. This is most often relevant when a Postgres table name is fully qualified by including its schema as part of its name --- schema/naming.go | 3 +-- schema/naming_test.go | 2 +- schema/relationship_test.go | 36 ++++++++++++++++++++++++++++++++++++ 3 files changed, 38 insertions(+), 3 deletions(-) diff --git a/schema/naming.go b/schema/naming.go index 125094bc..47a2b363 100644 --- a/schema/naming.go +++ b/schema/naming.go @@ -3,7 +3,6 @@ package schema import ( "crypto/sha1" "encoding/hex" - "fmt" "regexp" "strings" "unicode/utf8" @@ -95,7 +94,7 @@ func (ns NamingStrategy) formatName(prefix, table, name string) string { h.Write([]byte(formattedName)) bs := h.Sum(nil) - formattedName = fmt.Sprintf("%v%v%v", prefix, table, name)[0:56] + hex.EncodeToString(bs)[:8] + formattedName = formattedName[0:56] + hex.EncodeToString(bs)[:8] } return formattedName } diff --git a/schema/naming_test.go b/schema/naming_test.go index 1fdab9a0..3f598c33 100644 --- a/schema/naming_test.go +++ b/schema/naming_test.go @@ -193,7 +193,7 @@ func TestFormatNameWithStringLongerThan64Characters(t *testing.T) { ns := NamingStrategy{} formattedName := ns.formatName("prefix", "table", "thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVeryLongString") - if formattedName != "prefixtablethisIsAVeryVeryVeryVeryVeryVeryVeryVeryVeryLo180f2c67" { + if formattedName != "prefix_table_thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVery180f2c67" { t.Errorf("invalid formatted name generated, got %v", formattedName) } } diff --git a/schema/relationship_test.go b/schema/relationship_test.go index e2cf11a9..40ffc324 100644 --- a/schema/relationship_test.go +++ b/schema/relationship_test.go @@ -576,3 +576,39 @@ func TestHasManySameForeignKey(t *testing.T) { References: []Reference{{"ID", "User", "UserRefer", "Profile", "", true}}, }) } + +type Author struct { + gorm.Model +} + +type Book struct { + gorm.Model + Author Author + AuthorID uint +} + +func (Book) TableName() string { + return "my_schema.a_very_very_very_very_very_very_very_very_long_table_name" +} + +func TestParseConstraintNameWithSchemaQualifiedLongTableName(t *testing.T) { + s, err := schema.Parse( + &Book{}, + &sync.Map{}, + schema.NamingStrategy{}, + ) + if err != nil { + t.Fatalf("Failed to parse schema") + } + + expectedConstraintName := "fk_my_schema_a_very_very_very_very_very_very_very_very_l4db13eec" + constraint := s.Relationships.Relations["Author"].ParseConstraint() + + if constraint.Name != expectedConstraintName { + t.Fatalf( + "expected constraint name %s, got %s", + expectedConstraintName, + constraint.Name, + ) + } +} From 3741f258d053c0ac145392b5669c0cc62ddc0f15 Mon Sep 17 00:00:00 2001 From: jing1 Date: Thu, 24 Feb 2022 10:21:27 +0800 Subject: [PATCH 24/87] feat: support gob serialize (#5108) --- schema/serializer.go | 36 ++++++++++++++++++++++++++++++++++-- tests/serializer_test.go | 15 +++++++++++++++ 2 files changed, 49 insertions(+), 2 deletions(-) diff --git a/schema/serializer.go b/schema/serializer.go index 68597538..09da6d9e 100644 --- a/schema/serializer.go +++ b/schema/serializer.go @@ -1,11 +1,12 @@ package schema import ( + "bytes" "context" "database/sql" "database/sql/driver" + "encoding/gob" "encoding/json" - "errors" "fmt" "reflect" "strings" @@ -32,6 +33,7 @@ func GetSerializer(name string) (serializer SerializerInterface, ok bool) { func init() { RegisterSerializer("json", JSONSerializer{}) RegisterSerializer("unixtime", UnixSecondSerializer{}) + RegisterSerializer("gob", GobSerializer{}) } // Serializer field value serializer @@ -83,7 +85,7 @@ func (JSONSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, case string: bytes = []byte(v) default: - return errors.New(fmt.Sprint("Failed to unmarshal JSONB value:", dbValue)) + return fmt.Errorf("failed to unmarshal JSONB value: %#v", dbValue) } err = json.Unmarshal(bytes, fieldValue.Interface()) @@ -123,3 +125,33 @@ func (UnixSecondSerializer) Value(ctx context.Context, field *Field, dst reflect } return } + +// GobSerializer gob serializer +type GobSerializer struct { +} + +// Scan implements serializer interface +func (GobSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) (err error) { + fieldValue := reflect.New(field.FieldType) + + if dbValue != nil { + var bytesValue []byte + switch v := dbValue.(type) { + case []byte: + bytesValue = v + default: + return fmt.Errorf("failed to unmarshal gob value: %#v", dbValue) + } + decoder := gob.NewDecoder(bytes.NewBuffer(bytesValue)) + err = decoder.Decode(fieldValue.Interface()) + } + field.ReflectValueOf(ctx, dst).Set(fieldValue.Elem()) + return +} + +// Value implements serializer interface +func (GobSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) { + buf := new(bytes.Buffer) + err := gob.NewEncoder(buf).Encode(fieldValue) + return buf.Bytes(), err +} diff --git a/tests/serializer_test.go b/tests/serializer_test.go index 3ed733d9..a8a4e28f 100644 --- a/tests/serializer_test.go +++ b/tests/serializer_test.go @@ -19,11 +19,20 @@ type SerializerStruct struct { Name []byte `gorm:"json"` Roles Roles `gorm:"serializer:json"` Contracts map[string]interface{} `gorm:"serializer:json"` + JobInfo Job `gorm:"type:bytes;serializer:gob"` CreatedTime int64 `gorm:"serializer:unixtime;type:time"` // store time in db, use int as field type EncryptedString EncryptedString } type Roles []string + +type Job struct { + Title string + Number int + Location string + IsIntern bool +} + type EncryptedString string func (es *EncryptedString) Scan(ctx context.Context, field *schema.Field, dst reflect.Value, dbValue interface{}) (err error) { @@ -56,6 +65,12 @@ func TestSerializer(t *testing.T) { Contracts: map[string]interface{}{"name": "jinzhu", "age": 10}, EncryptedString: EncryptedString("pass"), CreatedTime: createdAt.Unix(), + JobInfo: Job{ + Title: "programmer", + Number: 9920, + Location: "Kenmawr", + IsIntern: false, + }, } if err := DB.Create(&data).Error; err != nil { From 6a18a15c93e17d513687993294e045574117266a Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 25 Feb 2022 10:48:23 +0800 Subject: [PATCH 25/87] Refactor check missing where condition --- callbacks/delete.go | 19 +++++++------------ callbacks/helper.go | 16 ++++++++++++++++ callbacks/update.go | 20 +++++++------------- soft_delete.go | 11 ++--------- tests/update_test.go | 2 +- 5 files changed, 33 insertions(+), 35 deletions(-) diff --git a/callbacks/delete.go b/callbacks/delete.go index 1fb5261c..84f446a3 100644 --- a/callbacks/delete.go +++ b/callbacks/delete.go @@ -118,6 +118,12 @@ func Delete(config *Config) func(db *gorm.DB) { return } + if db.Statement.Schema != nil { + for _, c := range db.Statement.Schema.DeleteClauses { + db.Statement.AddClause(c) + } + } + if db.Statement.SQL.Len() == 0 { db.Statement.SQL.Grow(100) db.Statement.AddClauseIfNotExists(clause.Delete{}) @@ -141,22 +147,11 @@ func Delete(config *Config) func(db *gorm.DB) { } db.Statement.AddClauseIfNotExists(clause.From{}) - } - if db.Statement.Schema != nil { - for _, c := range db.Statement.Schema.DeleteClauses { - db.Statement.AddClause(c) - } - } - - if db.Statement.SQL.Len() == 0 { db.Statement.Build(db.Statement.BuildClauses...) } - if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok && db.Error == nil { - db.AddError(gorm.ErrMissingWhereClause) - return - } + checkMissingWhereConditions(db) if !db.DryRun && db.Error == nil { ok, mode := hasReturning(db, supportReturning) diff --git a/callbacks/helper.go b/callbacks/helper.go index a59e1880..a5eb047e 100644 --- a/callbacks/helper.go +++ b/callbacks/helper.go @@ -104,3 +104,19 @@ func hasReturning(tx *gorm.DB, supportReturning bool) (bool, gorm.ScanMode) { } return false, 0 } + +func checkMissingWhereConditions(db *gorm.DB) { + if !db.AllowGlobalUpdate && db.Error == nil { + where, withCondition := db.Statement.Clauses["WHERE"] + if withCondition { + if _, withSoftDelete := db.Statement.Clauses["soft_delete_enabled"]; withSoftDelete { + whereClause, _ := where.Expression.(clause.Where) + withCondition = len(whereClause.Exprs) > 1 + } + } + if !withCondition { + db.AddError(gorm.ErrMissingWhereClause) + } + return + } +} diff --git a/callbacks/update.go b/callbacks/update.go index 4a2e5c79..da03261e 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -59,6 +59,12 @@ func Update(config *Config) func(db *gorm.DB) { return } + if db.Statement.Schema != nil { + for _, c := range db.Statement.Schema.UpdateClauses { + db.Statement.AddClause(c) + } + } + if db.Statement.SQL.Len() == 0 { db.Statement.SQL.Grow(180) db.Statement.AddClauseIfNotExists(clause.Update{}) @@ -68,22 +74,10 @@ func Update(config *Config) func(db *gorm.DB) { return } - } - - if db.Statement.Schema != nil { - for _, c := range db.Statement.Schema.UpdateClauses { - db.Statement.AddClause(c) - } - } - - if db.Statement.SQL.Len() == 0 { db.Statement.Build(db.Statement.BuildClauses...) } - if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok { - db.AddError(gorm.ErrMissingWhereClause) - return - } + checkMissingWhereConditions(db) if !db.DryRun && db.Error == nil { if ok, mode := hasReturning(db, supportReturning); ok { diff --git a/soft_delete.go b/soft_delete.go index ba6d2118..6d646288 100644 --- a/soft_delete.go +++ b/soft_delete.go @@ -104,9 +104,7 @@ func (sd SoftDeleteUpdateClause) MergeClause(*clause.Clause) { func (sd SoftDeleteUpdateClause) ModifyStatement(stmt *Statement) { if stmt.SQL.Len() == 0 && !stmt.Statement.Unscoped { - if _, ok := stmt.Clauses["WHERE"]; stmt.DB.AllowGlobalUpdate || ok { - SoftDeleteQueryClause(sd).ModifyStatement(stmt) - } + SoftDeleteQueryClause(sd).ModifyStatement(stmt) } } @@ -152,12 +150,7 @@ func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) { } } - if _, ok := stmt.Clauses["WHERE"]; !stmt.DB.AllowGlobalUpdate && !ok { - stmt.DB.AddError(ErrMissingWhereClause) - } else { - SoftDeleteQueryClause(sd).ModifyStatement(stmt) - } - + SoftDeleteQueryClause(sd).ModifyStatement(stmt) stmt.AddClauseIfNotExists(clause.Update{}) stmt.Build(stmt.DB.Callback().Update().Clauses...) } diff --git a/tests/update_test.go b/tests/update_test.go index b471ba9b..41ea5d27 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -645,7 +645,7 @@ func TestSave(t *testing.T) { dryDB := DB.Session(&gorm.Session{DryRun: true}) stmt := dryDB.Save(&user).Statement - if !regexp.MustCompile(`.id. = .* AND .users.\..deleted_at. IS NULL`).MatchString(stmt.SQL.String()) { + if !regexp.MustCompile(`.users.\..deleted_at. IS NULL`).MatchString(stmt.SQL.String()) { t.Fatalf("invalid updating SQL, got %v", stmt.SQL.String()) } From 397b583b8ecc5a31c838db5822fe1003b53a91ef Mon Sep 17 00:00:00 2001 From: chenrui Date: Fri, 25 Feb 2022 22:38:48 +0800 Subject: [PATCH 26/87] fix: query scanner in single column --- scan.go | 12 +++++++++++- tests/query_test.go | 36 ++++++++++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 1 deletion(-) diff --git a/scan.go b/scan.go index 0da12daf..a1cb582e 100644 --- a/scan.go +++ b/scan.go @@ -272,7 +272,17 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) { } case reflect.Struct, reflect.Ptr: if initialized || rows.Next() { - db.scanIntoStruct(sch, rows, reflectValue, values, columns, fields, joinFields) + if update { + db.scanIntoStruct(sch, rows, reflectValue, values, columns, fields, joinFields) + } else { + elem := reflect.New(reflectValueType) + db.scanIntoStruct(sch, rows, elem, values, columns, fields, joinFields) + if isPtr { + db.Statement.ReflectValue.Set(elem) + } else { + db.Statement.ReflectValue.Set(elem.Elem()) + } + } } default: db.AddError(rows.Scan(dest)) diff --git a/tests/query_test.go b/tests/query_test.go index d10df180..6542774a 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -1158,3 +1158,39 @@ func TestQueryWithTableAndConditionsAndAllFields(t *testing.T) { t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String()) } } + +type DoubleInt64 struct { + data int64 +} + +func (t *DoubleInt64) Scan(val interface{}) error { + switch v := val.(type) { + case int64: + t.data = v * 2 + return nil + default: + return fmt.Errorf("DoubleInt64 cant not scan with:%v", v) + } +} + +// https://github.com/go-gorm/gorm/issues/5091 +func TestQueryScannerWithSingleColumn(t *testing.T) { + user := User{Name: "scanner_raw_1", Age: 10} + DB.Create(&user) + + var result1 DoubleInt64 + if err := DB.Model(&User{}).Where("name LIKE ?", "scanner_raw_%").Limit(1).Pluck( + "age", &result1).Error; err != nil { + t.Errorf("Failed, got error: %v", err) + } + + AssertEqual(t, result1.data, 20) + + var result2 DoubleInt64 + if err := DB.Model(&User{}).Where("name LIKE ?", "scanner_raw_%").Limit(1).Select( + "age").Scan(&result2).Error; err != nil { + t.Errorf("Failed, got error: %v", err) + } + + AssertEqual(t, result2.data, 20) +} From f2edda50e11728e7aee6b1d4c961d575f7afbb2d Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 25 Feb 2022 10:48:23 +0800 Subject: [PATCH 27/87] Refactor check missing where condition --- callbacks/delete.go | 19 +++++++------------ callbacks/helper.go | 16 ++++++++++++++++ callbacks/update.go | 20 +++++++------------- soft_delete.go | 11 ++--------- tests/update_test.go | 2 +- 5 files changed, 33 insertions(+), 35 deletions(-) diff --git a/callbacks/delete.go b/callbacks/delete.go index 1fb5261c..84f446a3 100644 --- a/callbacks/delete.go +++ b/callbacks/delete.go @@ -118,6 +118,12 @@ func Delete(config *Config) func(db *gorm.DB) { return } + if db.Statement.Schema != nil { + for _, c := range db.Statement.Schema.DeleteClauses { + db.Statement.AddClause(c) + } + } + if db.Statement.SQL.Len() == 0 { db.Statement.SQL.Grow(100) db.Statement.AddClauseIfNotExists(clause.Delete{}) @@ -141,22 +147,11 @@ func Delete(config *Config) func(db *gorm.DB) { } db.Statement.AddClauseIfNotExists(clause.From{}) - } - if db.Statement.Schema != nil { - for _, c := range db.Statement.Schema.DeleteClauses { - db.Statement.AddClause(c) - } - } - - if db.Statement.SQL.Len() == 0 { db.Statement.Build(db.Statement.BuildClauses...) } - if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok && db.Error == nil { - db.AddError(gorm.ErrMissingWhereClause) - return - } + checkMissingWhereConditions(db) if !db.DryRun && db.Error == nil { ok, mode := hasReturning(db, supportReturning) diff --git a/callbacks/helper.go b/callbacks/helper.go index a59e1880..a5eb047e 100644 --- a/callbacks/helper.go +++ b/callbacks/helper.go @@ -104,3 +104,19 @@ func hasReturning(tx *gorm.DB, supportReturning bool) (bool, gorm.ScanMode) { } return false, 0 } + +func checkMissingWhereConditions(db *gorm.DB) { + if !db.AllowGlobalUpdate && db.Error == nil { + where, withCondition := db.Statement.Clauses["WHERE"] + if withCondition { + if _, withSoftDelete := db.Statement.Clauses["soft_delete_enabled"]; withSoftDelete { + whereClause, _ := where.Expression.(clause.Where) + withCondition = len(whereClause.Exprs) > 1 + } + } + if !withCondition { + db.AddError(gorm.ErrMissingWhereClause) + } + return + } +} diff --git a/callbacks/update.go b/callbacks/update.go index 4a2e5c79..da03261e 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -59,6 +59,12 @@ func Update(config *Config) func(db *gorm.DB) { return } + if db.Statement.Schema != nil { + for _, c := range db.Statement.Schema.UpdateClauses { + db.Statement.AddClause(c) + } + } + if db.Statement.SQL.Len() == 0 { db.Statement.SQL.Grow(180) db.Statement.AddClauseIfNotExists(clause.Update{}) @@ -68,22 +74,10 @@ func Update(config *Config) func(db *gorm.DB) { return } - } - - if db.Statement.Schema != nil { - for _, c := range db.Statement.Schema.UpdateClauses { - db.Statement.AddClause(c) - } - } - - if db.Statement.SQL.Len() == 0 { db.Statement.Build(db.Statement.BuildClauses...) } - if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok { - db.AddError(gorm.ErrMissingWhereClause) - return - } + checkMissingWhereConditions(db) if !db.DryRun && db.Error == nil { if ok, mode := hasReturning(db, supportReturning); ok { diff --git a/soft_delete.go b/soft_delete.go index ba6d2118..6d646288 100644 --- a/soft_delete.go +++ b/soft_delete.go @@ -104,9 +104,7 @@ func (sd SoftDeleteUpdateClause) MergeClause(*clause.Clause) { func (sd SoftDeleteUpdateClause) ModifyStatement(stmt *Statement) { if stmt.SQL.Len() == 0 && !stmt.Statement.Unscoped { - if _, ok := stmt.Clauses["WHERE"]; stmt.DB.AllowGlobalUpdate || ok { - SoftDeleteQueryClause(sd).ModifyStatement(stmt) - } + SoftDeleteQueryClause(sd).ModifyStatement(stmt) } } @@ -152,12 +150,7 @@ func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) { } } - if _, ok := stmt.Clauses["WHERE"]; !stmt.DB.AllowGlobalUpdate && !ok { - stmt.DB.AddError(ErrMissingWhereClause) - } else { - SoftDeleteQueryClause(sd).ModifyStatement(stmt) - } - + SoftDeleteQueryClause(sd).ModifyStatement(stmt) stmt.AddClauseIfNotExists(clause.Update{}) stmt.Build(stmt.DB.Callback().Update().Clauses...) } diff --git a/tests/update_test.go b/tests/update_test.go index b471ba9b..41ea5d27 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -645,7 +645,7 @@ func TestSave(t *testing.T) { dryDB := DB.Session(&gorm.Session{DryRun: true}) stmt := dryDB.Save(&user).Statement - if !regexp.MustCompile(`.id. = .* AND .users.\..deleted_at. IS NULL`).MatchString(stmt.SQL.String()) { + if !regexp.MustCompile(`.users.\..deleted_at. IS NULL`).MatchString(stmt.SQL.String()) { t.Fatalf("invalid updating SQL, got %v", stmt.SQL.String()) } From 68bb5379d91a7f7fae4dc65205db66004f515d0c Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 27 Feb 2022 09:09:29 +0800 Subject: [PATCH 28/87] Refactor scan into struct --- scan.go | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/scan.go b/scan.go index a1cb582e..e83390ca 100644 --- a/scan.go +++ b/scan.go @@ -68,7 +68,11 @@ func (db *DB) scanIntoStruct(sch *schema.Schema, rows *sql.Rows, reflectValue re values[idx] = &sql.RawBytes{} } else if len(columns) == 1 { sch = nil - values[idx] = reflectValue.Interface() + if reflectValue.CanAddr() { + values[idx] = reflectValue.Addr().Interface() + } else { + values[idx] = reflectValue.Interface() + } } else { values[idx] = &sql.RawBytes{} } @@ -272,17 +276,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) { } case reflect.Struct, reflect.Ptr: if initialized || rows.Next() { - if update { - db.scanIntoStruct(sch, rows, reflectValue, values, columns, fields, joinFields) - } else { - elem := reflect.New(reflectValueType) - db.scanIntoStruct(sch, rows, elem, values, columns, fields, joinFields) - if isPtr { - db.Statement.ReflectValue.Set(elem) - } else { - db.Statement.ReflectValue.Set(elem.Elem()) - } - } + db.scanIntoStruct(sch, rows, reflectValue, values, columns, fields, joinFields) } default: db.AddError(rows.Scan(dest)) From 530b0a12b4c63bb2dc7abef2934dc8406f1d0f13 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 27 Feb 2022 22:10:17 +0800 Subject: [PATCH 29/87] Add fast path for ValueOf, ReflectValueOf --- schema/field.go | 70 ++++++++++++++++++++++++++++++------------------- tests/go.mod | 1 + 2 files changed, 44 insertions(+), 27 deletions(-) diff --git a/schema/field.go b/schema/field.go index 8c793f93..826680c5 100644 --- a/schema/field.go +++ b/schema/field.go @@ -465,24 +465,33 @@ func (field *Field) setupValuerAndSetter() { } // ValueOf returns field's value and if it is zero - field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) { - v = reflect.Indirect(v) - for _, fieldIdx := range field.StructField.Index { - if fieldIdx >= 0 { - v = v.Field(fieldIdx) - } else { - v = v.Field(-fieldIdx - 1) - - if !v.IsNil() { - v = v.Elem() + fieldIndex := field.StructField.Index[0] + switch { + case len(field.StructField.Index) == 1 && fieldIndex > 0: + field.ValueOf = func(ctx context.Context, value reflect.Value) (interface{}, bool) { + fieldValue := reflect.Indirect(value).Field(fieldIndex) + return fieldValue.Interface(), fieldValue.IsZero() + } + default: + field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) { + v = reflect.Indirect(v) + for _, fieldIdx := range field.StructField.Index { + if fieldIdx >= 0 { + v = v.Field(fieldIdx) } else { - return nil, true + v = v.Field(-fieldIdx - 1) + + if !v.IsNil() { + v = v.Elem() + } else { + return nil, true + } } } - } - fv, zero := v.Interface(), v.IsZero() - return fv, zero + fv, zero := v.Interface(), v.IsZero() + return fv, zero + } } if field.Serializer != nil { @@ -509,24 +518,31 @@ func (field *Field) setupValuerAndSetter() { } // ReflectValueOf returns field's reflect value - field.ReflectValueOf = func(ctx context.Context, v reflect.Value) reflect.Value { - v = reflect.Indirect(v) - for idx, fieldIdx := range field.StructField.Index { - if fieldIdx >= 0 { - v = v.Field(fieldIdx) - } else { - v = v.Field(-fieldIdx - 1) + switch { + case len(field.StructField.Index) == 1 && fieldIndex > 0: + field.ReflectValueOf = func(ctx context.Context, value reflect.Value) reflect.Value { + return reflect.Indirect(value).Field(fieldIndex) + } + default: + field.ReflectValueOf = func(ctx context.Context, v reflect.Value) reflect.Value { + v = reflect.Indirect(v) + for idx, fieldIdx := range field.StructField.Index { + if fieldIdx >= 0 { + v = v.Field(fieldIdx) + } else { + v = v.Field(-fieldIdx - 1) - if v.IsNil() { - v.Set(reflect.New(v.Type().Elem())) - } + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) + } - if idx < len(field.StructField.Index)-1 { - v = v.Elem() + if idx < len(field.StructField.Index)-1 { + v = v.Elem() + } } } + return v } - return v } fallbackSetter := func(ctx context.Context, value reflect.Value, v interface{}, setter func(context.Context, reflect.Value, interface{}) error) (err error) { diff --git a/tests/go.mod b/tests/go.mod index cefe6f96..9e3453b7 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -3,6 +3,7 @@ module gorm.io/gorm/tests go 1.14 require ( + github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect github.com/google/uuid v1.3.0 github.com/jackc/pgx/v4 v4.15.0 // indirect github.com/jinzhu/now v1.1.4 From 43a72b369e670bd91e32784d063608931a59a66e Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 27 Feb 2022 22:54:43 +0800 Subject: [PATCH 30/87] Refactor Scan --- scan.go | 104 +++++++++++++++++++++++--------------------------------- 1 file changed, 43 insertions(+), 61 deletions(-) diff --git a/scan.go b/scan.go index e83390ca..d7b58e03 100644 --- a/scan.go +++ b/scan.go @@ -50,58 +50,37 @@ func scanIntoMap(mapValue map[string]interface{}, values []interface{}, columns } } -func (db *DB) scanIntoStruct(sch *schema.Schema, rows *sql.Rows, reflectValue reflect.Value, values []interface{}, columns []string, fields []*schema.Field, joinFields [][2]*schema.Field) { - for idx, column := range columns { - if sch == nil { - values[idx] = reflectValue.Interface() - } else if field := sch.LookUpField(column); field != nil && field.Readable { +func (db *DB) scanIntoStruct(rows *sql.Rows, reflectValue reflect.Value, values []interface{}, fields []*schema.Field, joinFields [][2]*schema.Field) { + for idx, field := range fields { + if field != nil { values[idx] = field.NewValuePool.Get() defer field.NewValuePool.Put(values[idx]) - } else if names := strings.Split(column, "__"); len(names) > 1 { - if rel, ok := sch.Relationships.Relations[names[0]]; ok { - if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { - values[idx] = field.NewValuePool.Get() - defer field.NewValuePool.Put(values[idx]) - continue - } + if len(joinFields) == 0 || joinFields[idx][0] == nil { + defer field.Set(db.Statement.Context, reflectValue, values[idx]) } - values[idx] = &sql.RawBytes{} - } else if len(columns) == 1 { - sch = nil + } else if len(fields) == 1 { if reflectValue.CanAddr() { values[idx] = reflectValue.Addr().Interface() } else { values[idx] = reflectValue.Interface() } - } else { - values[idx] = &sql.RawBytes{} } } db.RowsAffected++ db.AddError(rows.Scan(values...)) - if sch != nil { - for idx, column := range columns { - if field := sch.LookUpField(column); field != nil && field.Readable { - field.Set(db.Statement.Context, reflectValue, values[idx]) - } else if names := strings.Split(column, "__"); len(names) > 1 { - if rel, ok := sch.Relationships.Relations[names[0]]; ok { - if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { - relValue := rel.Field.ReflectValueOf(db.Statement.Context, reflectValue) - - if relValue.Kind() == reflect.Ptr && relValue.IsNil() { - if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() { - continue - } - - relValue.Set(reflect.New(relValue.Type().Elem())) - } - - field.Set(db.Statement.Context, relValue, values[idx]) - } + for idx, joinField := range joinFields { + if joinField[0] != nil { + relValue := joinField[0].ReflectValueOf(db.Statement.Context, reflectValue) + if relValue.Kind() == reflect.Ptr && relValue.IsNil() { + if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() { + return } + + relValue.Set(reflect.New(relValue.Type().Elem())) } + joinField[1].Set(db.Statement.Context, relValue, values[idx]) } } } @@ -180,7 +159,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) { reflectValue = db.Statement.ReflectValue ) - for reflectValue.Kind() == reflect.Interface { + if reflectValue.Kind() == reflect.Interface { reflectValue = reflectValue.Elem() } @@ -199,35 +178,38 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) { sch, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy) } - for idx, column := range columns { - if field := sch.LookUpField(column); field != nil && field.Readable { - fields[idx] = field - } else if names := strings.Split(column, "__"); len(names) > 1 { - if rel, ok := sch.Relationships.Relations[names[0]]; ok { - if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { - fields[idx] = field - - if len(joinFields) == 0 { - joinFields = make([][2]*schema.Field, len(columns)) - } - joinFields[idx] = [2]*schema.Field{rel.Field, field} - continue - } - } - values[idx] = &sql.RawBytes{} - } else { - values[idx] = &sql.RawBytes{} - } - } - if len(columns) == 1 { - // isPluck + // Is Pluck if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); (reflectValueType != sch.ModelType && ok) || // is scanner reflectValueType.Kind() != reflect.Struct || // is not struct sch.ModelType.ConvertibleTo(schema.TimeReflectType) { // is time sch = nil } } + + // Not Pluck + if sch != nil { + for idx, column := range columns { + if field := sch.LookUpField(column); field != nil && field.Readable { + fields[idx] = field + } else if names := strings.Split(column, "__"); len(names) > 1 { + if rel, ok := sch.Relationships.Relations[names[0]]; ok { + if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { + fields[idx] = field + + if len(joinFields) == 0 { + joinFields = make([][2]*schema.Field, len(columns)) + } + joinFields[idx] = [2]*schema.Field{rel.Field, field} + continue + } + } + values[idx] = &sql.RawBytes{} + } else { + values[idx] = &sql.RawBytes{} + } + } + } } switch reflectValue.Kind() { @@ -260,7 +242,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) { elem = reflect.New(reflectValueType) } - db.scanIntoStruct(sch, rows, elem, values, columns, fields, joinFields) + db.scanIntoStruct(rows, elem, values, fields, joinFields) if !update { if isPtr { @@ -276,7 +258,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) { } case reflect.Struct, reflect.Ptr: if initialized || rows.Next() { - db.scanIntoStruct(sch, rows, reflectValue, values, columns, fields, joinFields) + db.scanIntoStruct(rows, reflectValue, values, fields, joinFields) } default: db.AddError(rows.Scan(dest)) From e2e802b837a234ede6dc122dbb26de965e35e55f Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 28 Feb 2022 09:28:19 +0800 Subject: [PATCH 31/87] Refactor Scan --- callbacks/create.go | 6 ++++-- scan.go | 29 ++++++++++++++++------------- tests/go.mod | 2 +- 3 files changed, 21 insertions(+), 16 deletions(-) diff --git a/callbacks/create.go b/callbacks/create.go index b0964e2b..6e2883f7 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -201,13 +201,15 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { switch stmt.ReflectValue.Kind() { case reflect.Slice, reflect.Array: rValLen := stmt.ReflectValue.Len() - stmt.SQL.Grow(rValLen * 18) - values.Values = make([][]interface{}, rValLen) if rValLen == 0 { stmt.AddError(gorm.ErrEmptySlice) return } + stmt.SQL.Grow(rValLen * 18) + stmt.Vars = make([]interface{}, 0, rValLen*len(values.Columns)) + values.Values = make([][]interface{}, rValLen) + defaultValueFieldsHavingValue := map[*schema.Field][]interface{}{} for i := 0; i < rValLen; i++ { rv := reflect.Indirect(stmt.ReflectValue.Index(i)) diff --git a/scan.go b/scan.go index d7b58e03..a4243d12 100644 --- a/scan.go +++ b/scan.go @@ -54,10 +54,6 @@ func (db *DB) scanIntoStruct(rows *sql.Rows, reflectValue reflect.Value, values for idx, field := range fields { if field != nil { values[idx] = field.NewValuePool.Get() - defer field.NewValuePool.Put(values[idx]) - if len(joinFields) == 0 || joinFields[idx][0] == nil { - defer field.Set(db.Statement.Context, reflectValue, values[idx]) - } } else if len(fields) == 1 { if reflectValue.CanAddr() { values[idx] = reflectValue.Addr().Interface() @@ -70,17 +66,24 @@ func (db *DB) scanIntoStruct(rows *sql.Rows, reflectValue reflect.Value, values db.RowsAffected++ db.AddError(rows.Scan(values...)) - for idx, joinField := range joinFields { - if joinField[0] != nil { - relValue := joinField[0].ReflectValueOf(db.Statement.Context, reflectValue) - if relValue.Kind() == reflect.Ptr && relValue.IsNil() { - if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() { - return - } + for idx, field := range fields { + if field != nil { + if len(joinFields) == 0 || joinFields[idx][0] == nil { + field.Set(db.Statement.Context, reflectValue, values[idx]) + } else { + relValue := joinFields[idx][0].ReflectValueOf(db.Statement.Context, reflectValue) + if relValue.Kind() == reflect.Ptr && relValue.IsNil() { + if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() { + return + } - relValue.Set(reflect.New(relValue.Type().Elem())) + relValue.Set(reflect.New(relValue.Type().Elem())) + } + joinFields[idx][1].Set(db.Statement.Context, relValue, values[idx]) } - joinField[1].Set(db.Statement.Context, relValue, values[idx]) + + // release data to pool + field.NewValuePool.Put(values[idx]) } } } diff --git a/tests/go.mod b/tests/go.mod index 9e3453b7..c65ea953 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -8,7 +8,7 @@ require ( github.com/jackc/pgx/v4 v4.15.0 // indirect github.com/jinzhu/now v1.1.4 github.com/lib/pq v1.10.4 - github.com/mattn/go-sqlite3 v1.14.11 // indirect + github.com/mattn/go-sqlite3 v1.14.12 // indirect golang.org/x/crypto v0.0.0-20220214200702-86341886e292 // indirect gorm.io/driver/mysql v1.3.2 gorm.io/driver/postgres v1.3.1 From 996b96e81268335b22faf694dfb4674f84177f17 Mon Sep 17 00:00:00 2001 From: lianghuan Date: Mon, 28 Feb 2022 17:12:09 +0800 Subject: [PATCH 32/87] Add TxConnPoolBeginner and Tx interface --- .gitignore | 1 + finisher_api.go | 3 + interfaces.go | 13 +++ prepare_stmt.go | 7 +- tests/connpool_test.go | 181 +++++++++++++++++++++++++++++++++++++++++ 5 files changed, 203 insertions(+), 2 deletions(-) create mode 100644 tests/connpool_test.go diff --git a/.gitignore b/.gitignore index e1b9ecea..45505cc9 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,4 @@ documents coverage.txt _book .idea +vendor \ No newline at end of file diff --git a/finisher_api.go b/finisher_api.go index f994ec31..5d49ddf9 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -255,6 +255,7 @@ func (tx *DB) assignInterfacesToValue(values ...interface{}) { } } } + // FirstOrInit gets the first matched record or initialize a new instance with given conditions (only works with struct or map conditions) func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { queryTx := db.Limit(1).Order(clause.OrderByColumn{ @@ -603,6 +604,8 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB { tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) } else if beginner, ok := tx.Statement.ConnPool.(ConnPoolBeginner); ok { tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) + } else if beginner, ok := tx.Statement.ConnPool.(TxConnPoolBeginner); ok { + tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) } else { err = ErrInvalidTransaction } diff --git a/interfaces.go b/interfaces.go index 44a85cb5..ed7112f2 100644 --- a/interfaces.go +++ b/interfaces.go @@ -50,12 +50,25 @@ type ConnPoolBeginner interface { BeginTx(ctx context.Context, opts *sql.TxOptions) (ConnPool, error) } +// TxConnPoolBeginner tx conn pool beginner +type TxConnPoolBeginner interface { + BeginTx(ctx context.Context, opts *sql.TxOptions) (Tx, error) +} + // TxCommitter tx committer type TxCommitter interface { Commit() error Rollback() error } +// Tx sql.Tx interface +type Tx interface { + ConnPool + Commit() error + Rollback() error + StmtContext(ctx context.Context, stmt *sql.Stmt) *sql.Stmt +} + // Valuer gorm valuer interface type Valuer interface { GormValue(context.Context, *DB) clause.Expr diff --git a/prepare_stmt.go b/prepare_stmt.go index 88bec4e9..94282fad 100644 --- a/prepare_stmt.go +++ b/prepare_stmt.go @@ -73,6 +73,9 @@ func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (Conn if beginner, ok := db.ConnPool.(TxBeginner); ok { tx, err := beginner.BeginTx(ctx, opt) return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err + } else if beginner, ok := db.ConnPool.(TxConnPoolBeginner); ok { + tx, err := beginner.BeginTx(ctx, opt) + return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err } return nil, ErrInvalidTransaction } @@ -115,7 +118,7 @@ func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string, arg } type PreparedStmtTX struct { - *sql.Tx + Tx PreparedStmtDB *PreparedStmtDB } @@ -151,7 +154,7 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args .. func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) { stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query) if err == nil { - rows, err = tx.Tx.Stmt(stmt.Stmt).QueryContext(ctx, args...) + rows, err = tx.Tx.StmtContext(ctx, stmt.Stmt).QueryContext(ctx, args...) if err != nil { tx.PreparedStmtDB.Mux.Lock() defer tx.PreparedStmtDB.Mux.Unlock() diff --git a/tests/connpool_test.go b/tests/connpool_test.go new file mode 100644 index 00000000..3713ad7c --- /dev/null +++ b/tests/connpool_test.go @@ -0,0 +1,181 @@ +package tests_test + +import ( + "context" + "database/sql" + "log" + "os" + "reflect" + "testing" + "time" + + "gorm.io/driver/mysql" + "gorm.io/gorm" + "gorm.io/gorm/logger" + . "gorm.io/gorm/utils/tests" +) + +type wrapperTx struct { + *sql.Tx + conn *wrapperConnPool +} + +func (c *wrapperTx) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) { + c.conn.got = append(c.conn.got, query) + return c.Tx.PrepareContext(ctx, query) +} + +func (c *wrapperTx) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + c.conn.got = append(c.conn.got, query) + return c.Tx.ExecContext(ctx, query, args...) +} + +func (c *wrapperTx) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { + c.conn.got = append(c.conn.got, query) + return c.Tx.QueryContext(ctx, query, args...) +} + +func (c *wrapperTx) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { + c.conn.got = append(c.conn.got, query) + return c.Tx.QueryRowContext(ctx, query, args...) +} + +type wrapperConnPool struct { + db *sql.DB + got []string + expect []string +} + +func (c *wrapperConnPool) Ping() error { + return c.db.Ping() +} + +// If you use BeginTx returned *sql.Tx as shown below then you can't record queries in a transaction. +// func (c *wrapperConnPool) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) { +// return c.db.BeginTx(ctx, opts) +// } +// You should use BeginTx returned gorm.Tx which could wrap *sql.Tx then you can record all queries. +func (c *wrapperConnPool) BeginTx(ctx context.Context, opts *sql.TxOptions) (gorm.Tx, error) { + tx, err := c.db.BeginTx(ctx, opts) + if err != nil { + return nil, err + } + return &wrapperTx{Tx: tx, conn: c}, nil +} + +func (c *wrapperConnPool) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) { + c.got = append(c.got, query) + return c.db.PrepareContext(ctx, query) +} + +func (c *wrapperConnPool) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + c.got = append(c.got, query) + return c.db.ExecContext(ctx, query, args...) +} + +func (c *wrapperConnPool) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { + c.got = append(c.got, query) + return c.db.QueryContext(ctx, query, args...) +} + +func (c *wrapperConnPool) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { + c.got = append(c.got, query) + return c.db.QueryRowContext(ctx, query, args...) +} + +func TestConnPoolWrapper(t *testing.T) { + dialect := os.Getenv("GORM_DIALECT") + if dialect != "mysql" { + t.SkipNow() + } + + dbDSN := os.Getenv("GORM_DSN") + if dbDSN == "" { + dbDSN = "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True&loc=Local" + } + nativeDB, err := sql.Open("mysql", dbDSN) + if err != nil { + t.Fatalf("Should open db success, but got %v", err) + } + + conn := &wrapperConnPool{ + db: nativeDB, + expect: []string{ + "SELECT VERSION()", + "INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)", + "SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1", + "INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)", + "SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1", + "SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1", + "INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)", + "SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1", + "SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1", + }, + } + + defer func() { + if !reflect.DeepEqual(conn.got, conn.expect) { + t.Errorf("expect %#v but got %#v", conn.expect, conn.got) + } + }() + + l := logger.New(log.New(os.Stdout, "\r\n", log.LstdFlags), logger.Config{ + SlowThreshold: 200 * time.Millisecond, + LogLevel: logger.Info, + IgnoreRecordNotFoundError: false, + Colorful: true, + }) + + db, err := gorm.Open(mysql.New(mysql.Config{Conn: conn}), &gorm.Config{Logger: l}) + if err != nil { + t.Fatalf("Should open db success, but got %v", err) + } + + tx := db.Begin() + user := *GetUser("transaction", Config{}) + + if err = tx.Save(&user).Error; err != nil { + t.Fatalf("No error should raise, but got %v", err) + } + + if err = tx.First(&User{}, "name = ?", "transaction").Error; err != nil { + t.Fatalf("Should find saved record, but got %v", err) + } + + user1 := *GetUser("transaction1-1", Config{}) + + if err = tx.Save(&user1).Error; err != nil { + t.Fatalf("No error should raise, but got %v", err) + } + + if err = tx.First(&User{}, "name = ?", user1.Name).Error; err != nil { + t.Fatalf("Should find saved record, but got %v", err) + } + + if sqlTx, ok := tx.Statement.ConnPool.(gorm.TxCommitter); !ok || sqlTx == nil { + t.Fatalf("Should return the underlying sql.Tx") + } + + tx.Rollback() + + if err = db.First(&User{}, "name = ?", "transaction").Error; err == nil { + t.Fatalf("Should not find record after rollback, but got %v", err) + } + + txDB := db.Where("fake_name = ?", "fake_name") + tx2 := txDB.Session(&gorm.Session{NewDB: true}).Begin() + user2 := *GetUser("transaction-2", Config{}) + if err = tx2.Save(&user2).Error; err != nil { + t.Fatalf("No error should raise, but got %v", err) + } + + if err = tx2.First(&User{}, "name = ?", "transaction-2").Error; err != nil { + t.Fatalf("Should find saved record, but got %v", err) + } + + tx2.Commit() + + if err = db.First(&User{}, "name = ?", "transaction-2").Error; err != nil { + t.Fatalf("Should be able to find committed record, but got %v", err) + } +} From 4e523499d191d02e032b126774efd26daa8697a8 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 1 Mar 2022 16:48:46 +0800 Subject: [PATCH 33/87] Refactor Tx interface --- finisher_api.go | 9 ++++----- interfaces.go | 8 +------- prepare_stmt.go | 3 --- tests/connpool_test.go | 14 ++------------ 4 files changed, 7 insertions(+), 27 deletions(-) diff --git a/finisher_api.go b/finisher_api.go index 5d49ddf9..4b428a59 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -600,13 +600,12 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB { opt = opts[0] } - if beginner, ok := tx.Statement.ConnPool.(TxBeginner); ok { + switch beginner := tx.Statement.ConnPool.(type) { + case TxBeginner: tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) - } else if beginner, ok := tx.Statement.ConnPool.(ConnPoolBeginner); ok { + case ConnPoolBeginner: tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) - } else if beginner, ok := tx.Statement.ConnPool.(TxConnPoolBeginner); ok { - tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) - } else { + default: err = ErrInvalidTransaction } diff --git a/interfaces.go b/interfaces.go index ed7112f2..84dc94bb 100644 --- a/interfaces.go +++ b/interfaces.go @@ -50,11 +50,6 @@ type ConnPoolBeginner interface { BeginTx(ctx context.Context, opts *sql.TxOptions) (ConnPool, error) } -// TxConnPoolBeginner tx conn pool beginner -type TxConnPoolBeginner interface { - BeginTx(ctx context.Context, opts *sql.TxOptions) (Tx, error) -} - // TxCommitter tx committer type TxCommitter interface { Commit() error @@ -64,8 +59,7 @@ type TxCommitter interface { // Tx sql.Tx interface type Tx interface { ConnPool - Commit() error - Rollback() error + TxCommitter StmtContext(ctx context.Context, stmt *sql.Stmt) *sql.Stmt } diff --git a/prepare_stmt.go b/prepare_stmt.go index 94282fad..b062b0d6 100644 --- a/prepare_stmt.go +++ b/prepare_stmt.go @@ -73,9 +73,6 @@ func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (Conn if beginner, ok := db.ConnPool.(TxBeginner); ok { tx, err := beginner.BeginTx(ctx, opt) return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err - } else if beginner, ok := db.ConnPool.(TxConnPoolBeginner); ok { - tx, err := beginner.BeginTx(ctx, opt) - return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err } return nil, ErrInvalidTransaction } diff --git a/tests/connpool_test.go b/tests/connpool_test.go index 3713ad7c..fbae2294 100644 --- a/tests/connpool_test.go +++ b/tests/connpool_test.go @@ -3,15 +3,12 @@ package tests_test import ( "context" "database/sql" - "log" "os" "reflect" "testing" - "time" "gorm.io/driver/mysql" "gorm.io/gorm" - "gorm.io/gorm/logger" . "gorm.io/gorm/utils/tests" ) @@ -55,7 +52,7 @@ func (c *wrapperConnPool) Ping() error { // return c.db.BeginTx(ctx, opts) // } // You should use BeginTx returned gorm.Tx which could wrap *sql.Tx then you can record all queries. -func (c *wrapperConnPool) BeginTx(ctx context.Context, opts *sql.TxOptions) (gorm.Tx, error) { +func (c *wrapperConnPool) BeginTx(ctx context.Context, opts *sql.TxOptions) (gorm.ConnPool, error) { tx, err := c.db.BeginTx(ctx, opts) if err != nil { return nil, err @@ -119,14 +116,7 @@ func TestConnPoolWrapper(t *testing.T) { } }() - l := logger.New(log.New(os.Stdout, "\r\n", log.LstdFlags), logger.Config{ - SlowThreshold: 200 * time.Millisecond, - LogLevel: logger.Info, - IgnoreRecordNotFoundError: false, - Colorful: true, - }) - - db, err := gorm.Open(mysql.New(mysql.Config{Conn: conn}), &gorm.Config{Logger: l}) + db, err := gorm.Open(mysql.New(mysql.Config{Conn: conn})) if err != nil { t.Fatalf("Should open db success, but got %v", err) } From 29a8557384b060bf5d99b4b8824cb75c8a8b9917 Mon Sep 17 00:00:00 2001 From: Cao Manh Dat Date: Thu, 3 Mar 2022 09:17:29 +0700 Subject: [PATCH 34/87] ToSQL should enable SkipDefaultTransaction by default --- gorm.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gorm.go b/gorm.go index 7967b094..aca7cb5e 100644 --- a/gorm.go +++ b/gorm.go @@ -462,7 +462,7 @@ func (db *DB) Use(plugin Plugin) error { // .First(&User{}) // }) func (db *DB) ToSQL(queryFn func(tx *DB) *DB) string { - tx := queryFn(db.Session(&Session{DryRun: true})) + tx := queryFn(db.Session(&Session{DryRun: true, SkipDefaultTransaction: true})) stmt := tx.Statement return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...) From f961bf1c147113527e486595b0ce342f3c5ba3dd Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Sat, 12 Mar 2022 22:28:18 +0800 Subject: [PATCH 35/87] chore(deps): bump actions/checkout from 2 to 3 (#5133) Bumps [actions/checkout](https://github.com/actions/checkout) from 2 to 3. - [Release notes](https://github.com/actions/checkout/releases) - [Changelog](https://github.com/actions/checkout/blob/main/CHANGELOG.md) - [Commits](https://github.com/actions/checkout/compare/v2...v3) --- updated-dependencies: - dependency-name: actions/checkout dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/labeler.yml | 2 +- .github/workflows/reviewdog.yml | 2 +- .github/workflows/tests.yml | 8 ++++---- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/workflows/labeler.yml b/.github/workflows/labeler.yml index bc1add53..0e8aaa60 100644 --- a/.github/workflows/labeler.yml +++ b/.github/workflows/labeler.yml @@ -11,7 +11,7 @@ jobs: name: Label issues and pull requests steps: - name: check out - uses: actions/checkout@v2 + uses: actions/checkout@v3 - name: labeler uses: jinzhu/super-labeler-action@develop diff --git a/.github/workflows/reviewdog.yml b/.github/workflows/reviewdog.yml index b252dd7a..a6542d57 100644 --- a/.github/workflows/reviewdog.yml +++ b/.github/workflows/reviewdog.yml @@ -6,7 +6,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Check out code into the Go module directory - uses: actions/checkout@v2 + uses: actions/checkout@v3 - name: golangci-lint uses: reviewdog/action-golangci-lint@v2 diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 91a0abc9..3e15427c 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -24,7 +24,7 @@ jobs: go-version: ${{ matrix.go }} - name: Check out code into the Go module directory - uses: actions/checkout@v2 + uses: actions/checkout@v3 - name: go mod package cache uses: actions/cache@v2 @@ -67,7 +67,7 @@ jobs: go-version: ${{ matrix.go }} - name: Check out code into the Go module directory - uses: actions/checkout@v2 + uses: actions/checkout@v3 - name: go mod package cache @@ -111,7 +111,7 @@ jobs: go-version: ${{ matrix.go }} - name: Check out code into the Go module directory - uses: actions/checkout@v2 + uses: actions/checkout@v3 - name: go mod package cache uses: actions/cache@v2 @@ -154,7 +154,7 @@ jobs: go-version: ${{ matrix.go }} - name: Check out code into the Go module directory - uses: actions/checkout@v2 + uses: actions/checkout@v3 - name: go mod package cache uses: actions/cache@v2 From 61b4c31236a8f9792c94240ddb4e236f21bbb9ff Mon Sep 17 00:00:00 2001 From: labulakalia Date: Mon, 14 Mar 2022 21:47:59 +0800 Subject: [PATCH 36/87] fix when index name is "type", parseFieldIndexes will set index TYPE is "TYPE" (#5155) * fix index name is type, parseFieldIndexes will set index TYPE is "TYPE" * check TYPE empty --- schema/index.go | 11 ++++++----- schema/index_test.go | 6 ++++++ 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/schema/index.go b/schema/index.go index 5f775f30..16d096b7 100644 --- a/schema/index.go +++ b/schema/index.go @@ -89,11 +89,12 @@ func parseFieldIndexes(field *Field) (indexes []Index) { k := strings.TrimSpace(strings.ToUpper(v[0])) if k == "INDEX" || k == "UNIQUEINDEX" { var ( - name string - tag = strings.Join(v[1:], ":") - idx = strings.Index(tag, ",") - settings = ParseTagSetting(tag, ",") - length, _ = strconv.Atoi(settings["LENGTH"]) + name string + tag = strings.Join(v[1:], ":") + idx = strings.Index(tag, ",") + tagSetting = strings.Join(strings.Split(tag, ",")[1:], ",") + settings = ParseTagSetting(tagSetting, ",") + length, _ = strconv.Atoi(settings["LENGTH"]) ) if idx == -1 { diff --git a/schema/index_test.go b/schema/index_test.go index bc6bb8b6..3c4582bb 100644 --- a/schema/index_test.go +++ b/schema/index_test.go @@ -18,6 +18,7 @@ type UserIndex struct { Age int64 `gorm:"index:profile,expression:ABS(age),option:WITH PARSER parser_name"` OID int64 `gorm:"index:idx_id;index:idx_oid,unique"` MemberNumber string `gorm:"index:idx_id,priority:1"` + Name7 string `gorm:"index:type"` } func TestParseIndex(t *testing.T) { @@ -78,6 +79,11 @@ func TestParseIndex(t *testing.T) { Class: "UNIQUE", Fields: []schema.IndexOption{{Field: &schema.Field{Name: "OID"}}}, }, + "type": { + Name: "type", + Type: "", + Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name7"}}}, + }, } indices := user.ParseIndexes() From 6befa0c947e0107f241663e4312a74bddd0a4ffe Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 17 Mar 2022 11:22:25 +0800 Subject: [PATCH 37/87] Refactor preload error check --- callbacks/query.go | 5 +++++ finisher_api.go | 4 ---- tests/count_test.go | 14 +++++++++++--- tests/go.mod | 2 +- 4 files changed, 17 insertions(+), 8 deletions(-) diff --git a/callbacks/query.go b/callbacks/query.go index 03798859..04f35c7e 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -186,6 +186,11 @@ func BuildQuerySQL(db *gorm.DB) { func Preload(db *gorm.DB) { if db.Error == nil && len(db.Statement.Preloads) > 0 { + if db.Statement.Schema == nil { + db.AddError(fmt.Errorf("%w when using preload", gorm.ErrModelValueRequired)) + return + } + preloadMap := map[string]map[string][]interface{}{} for name := range db.Statement.Preloads { preloadFields := strings.Split(name, ".") diff --git a/finisher_api.go b/finisher_api.go index 4b428a59..b4d29b71 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -369,10 +369,6 @@ func (db *DB) Delete(value interface{}, conds ...interface{}) (tx *DB) { func (db *DB) Count(count *int64) (tx *DB) { tx = db.getInstance() - if len(tx.Statement.Preloads) > 0 { - tx.AddError(ErrPreloadNotAllowed) - return - } if tx.Statement.Model == nil { tx.Statement.Model = tx.Statement.Dest defer func() { diff --git a/tests/count_test.go b/tests/count_test.go index b63a55fc..b71e3de5 100644 --- a/tests/count_test.go +++ b/tests/count_test.go @@ -150,8 +150,16 @@ func TestCount(t *testing.T) { Where("name in ?", []string{user1.Name, user2.Name, user3.Name}). Preload("Toys", func(db *gorm.DB) *gorm.DB { return db.Table("toys").Select("name") - }).Count(&count12).Error; err != gorm.ErrPreloadNotAllowed { - t.Errorf("should returns preload not allowed error, but got %v", err) + }).Count(&count12).Error; err == nil { + t.Errorf("error should raise when using preload without schema") + } + + var count13 int64 + if err := DB.Model(User{}). + Where("name in ?", []string{user1.Name, user2.Name, user3.Name}). + Preload("Toys", func(db *gorm.DB) *gorm.DB { + return db.Table("toys").Select("name") + }).Count(&count13).Error; err != nil { + t.Errorf("no error should raise when using count with preload, but got %v", err) } - } diff --git a/tests/go.mod b/tests/go.mod index c65ea953..4ef7fbe2 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -9,7 +9,7 @@ require ( github.com/jinzhu/now v1.1.4 github.com/lib/pq v1.10.4 github.com/mattn/go-sqlite3 v1.14.12 // indirect - golang.org/x/crypto v0.0.0-20220214200702-86341886e292 // indirect + golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd // indirect gorm.io/driver/mysql v1.3.2 gorm.io/driver/postgres v1.3.1 gorm.io/driver/sqlite v1.3.1 From 63ac66b56988e1a22c8a3b41d4f1fbf9a8f5d0bc Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 17 Mar 2022 11:34:27 +0800 Subject: [PATCH 38/87] Support default tag for time.Time --- schema/field.go | 5 +++++ tests/default_value_test.go | 18 ++++++++++-------- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/schema/field.go b/schema/field.go index 826680c5..0d7085a9 100644 --- a/schema/field.go +++ b/schema/field.go @@ -259,6 +259,11 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } else if fieldValue.Type().ConvertibleTo(TimePtrReflectType) { field.DataType = Time } + if field.HasDefaultValue && !skipParseDefaultValue && field.DataType == Time { + if field.DefaultValueInterface, err = now.Parse(field.DefaultValue); err != nil { + schema.err = fmt.Errorf("failed to parse default value `%v` for field %v", field.DefaultValue, field.Name) + } + } case reflect.Array, reflect.Slice: if reflect.Indirect(fieldValue).Type().Elem() == ByteReflectType && field.DataType == "" { field.DataType = Bytes diff --git a/tests/default_value_test.go b/tests/default_value_test.go index 5e00b154..918f0796 100644 --- a/tests/default_value_test.go +++ b/tests/default_value_test.go @@ -2,6 +2,7 @@ package tests_test import ( "testing" + "time" "gorm.io/gorm" ) @@ -9,12 +10,13 @@ import ( func TestDefaultValue(t *testing.T) { type Harumph struct { gorm.Model - Email string `gorm:"not null;index:,unique"` - Name string `gorm:"notNull;default:foo"` - Name2 string `gorm:"size:233;not null;default:'foo'"` - Name3 string `gorm:"size:233;notNull;default:''"` - Age int `gorm:"default:18"` - Enabled bool `gorm:"default:true"` + Email string `gorm:"not null;index:,unique"` + Name string `gorm:"notNull;default:foo"` + Name2 string `gorm:"size:233;not null;default:'foo'"` + Name3 string `gorm:"size:233;notNull;default:''"` + Age int `gorm:"default:18"` + Created time.Time `gorm:"default:2000-01-02"` + Enabled bool `gorm:"default:true"` } DB.Migrator().DropTable(&Harumph{}) @@ -26,14 +28,14 @@ func TestDefaultValue(t *testing.T) { harumph := Harumph{Email: "hello@gorm.io"} if err := DB.Create(&harumph).Error; err != nil { t.Fatalf("Failed to create data with default value, got error: %v", err) - } else if harumph.Name != "foo" || harumph.Name2 != "foo" || harumph.Name3 != "" || harumph.Age != 18 || !harumph.Enabled { + } else if harumph.Name != "foo" || harumph.Name2 != "foo" || harumph.Name3 != "" || harumph.Age != 18 || !harumph.Enabled || harumph.Created.Format("20060102") != "20000102" { t.Fatalf("Failed to create data with default value, got: %+v", harumph) } var result Harumph if err := DB.First(&result, "email = ?", "hello@gorm.io").Error; err != nil { t.Fatalf("Failed to find created data, got error: %v", err) - } else if result.Name != "foo" || result.Name2 != "foo" || result.Name3 != "" || result.Age != 18 || !result.Enabled { + } else if result.Name != "foo" || result.Name2 != "foo" || result.Name3 != "" || result.Age != 18 || !result.Enabled || result.Created.Format("20060102") != "20000102" { t.Fatalf("Failed to find created data with default data, got %+v", result) } } From f3e2da5ba359f0d672249fc52f54ae41c5a66d3a Mon Sep 17 00:00:00 2001 From: Hasan Date: Thu, 17 Mar 2022 22:51:56 +0800 Subject: [PATCH 39/87] Added offset when scanning the result back to struct, close #5143 commit 9a2058164d44c98d7b586b87bed1757f89d6fad7 Author: Jinzhu Date: Thu Mar 17 22:34:19 2022 +0800 Refactor #5143 commit c259de21768936428c9d89f7b31afb95b8acb36a Author: Hasan Date: Mon Mar 14 20:04:01 2022 +0545 Update scan_test.go commit 09f127b49151a52fbb8b354a03e6610d4f70262f Author: Hasan Date: Mon Mar 14 19:23:47 2022 +0545 Added test for scanning embedded data into structs commit aeaca493cf412def7813d36fd6a68acc832bf79f Author: Hasan Date: Tue Mar 8 04:08:16 2022 +0600 Added offset when scanning the result back to struct --- scan.go | 22 +++++++++++++++++----- tests/go.mod | 2 +- tests/scan_test.go | 36 ++++++++++++++++++++++++++++++++++++ 3 files changed, 54 insertions(+), 6 deletions(-) diff --git a/scan.go b/scan.go index a4243d12..89d92354 100644 --- a/scan.go +++ b/scan.go @@ -156,10 +156,11 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) { } default: var ( - fields = make([]*schema.Field, len(columns)) - joinFields [][2]*schema.Field - sch = db.Statement.Schema - reflectValue = db.Statement.ReflectValue + fields = make([]*schema.Field, len(columns)) + selectedColumnsMap = make(map[string]int, len(columns)) + joinFields [][2]*schema.Field + sch = db.Statement.Schema + reflectValue = db.Statement.ReflectValue ) if reflectValue.Kind() == reflect.Interface { @@ -194,7 +195,18 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) { if sch != nil { for idx, column := range columns { if field := sch.LookUpField(column); field != nil && field.Readable { - fields[idx] = field + if curIndex, ok := selectedColumnsMap[column]; ok { + for fieldIndex, selectField := range sch.Fields[curIndex:] { + if selectField.DBName == column && selectField.Readable { + selectedColumnsMap[column] = curIndex + fieldIndex + 1 + fields[idx] = selectField + break + } + } + } else { + fields[idx] = field + selectedColumnsMap[column] = idx + } } else if names := strings.Split(column, "__"); len(names) > 1 { if rel, ok := sch.Relationships.Relations[names[0]]; ok { if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { diff --git a/tests/go.mod b/tests/go.mod index 4ef7fbe2..9dfa26ff 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -6,7 +6,7 @@ require ( github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect github.com/google/uuid v1.3.0 github.com/jackc/pgx/v4 v4.15.0 // indirect - github.com/jinzhu/now v1.1.4 + github.com/jinzhu/now v1.1.5 github.com/lib/pq v1.10.4 github.com/mattn/go-sqlite3 v1.14.12 // indirect golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd // indirect diff --git a/tests/scan_test.go b/tests/scan_test.go index 1a188fac..ec1e652f 100644 --- a/tests/scan_test.go +++ b/tests/scan_test.go @@ -10,6 +10,11 @@ import ( . "gorm.io/gorm/utils/tests" ) +type PersonAddressInfo struct { + Person *Person `gorm:"embedded"` + Address *Address `gorm:"embedded"` +} + func TestScan(t *testing.T) { user1 := User{Name: "ScanUser1", Age: 1} user2 := User{Name: "ScanUser2", Age: 10} @@ -156,3 +161,34 @@ func TestScanRows(t *testing.T) { t.Fatalf("failed to scan ages, got error %v, ages: %v", err, name) } } + +func TestScanToEmbedded(t *testing.T) { + person1 := Person{Name: "person 1"} + person2 := Person{Name: "person 2"} + DB.Save(&person1).Save(&person2) + + address1 := Address{Name: "address 1"} + address2 := Address{Name: "address 2"} + DB.Save(&address1).Save(&address2) + + DB.Create(&PersonAddress{PersonID: person1.ID, AddressID: int(address1.ID)}) + DB.Create(&PersonAddress{PersonID: person1.ID, AddressID: int(address2.ID)}) + DB.Create(&PersonAddress{PersonID: person2.ID, AddressID: int(address1.ID)}) + + var personAddressInfoList []*PersonAddressInfo + if err := DB.Select("people.*, addresses.*"). + Table("people"). + Joins("inner join person_addresses on people.id = person_addresses.person_id"). + Joins("inner join addresses on person_addresses.address_id = addresses.id"). + Find(&personAddressInfoList).Error; err != nil { + t.Errorf("Failed to run join query, got error: %v", err) + } + + for _, info := range personAddressInfoList { + if info.Person != nil { + if info.Person.ID == person1.ID && info.Person.Name != person1.Name { + t.Errorf("Failed, expected %v, got %v", person1.Name, info.Person.Name) + } + } + } +} From 2990790fbc4c1a3b38a3a7bde15620623264461d Mon Sep 17 00:00:00 2001 From: Mikhail Faraponov <11322032+moredure@users.noreply.github.com> Date: Thu, 17 Mar 2022 16:54:30 +0200 Subject: [PATCH 40/87] Use WriteByte for single byte operations (#5167) Co-authored-by: Mikhail Faraponov --- clause/limit.go | 2 +- clause/where.go | 4 ++-- statement.go | 4 ++-- utils/tests/dummy_dialecter.go | 4 ++-- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/clause/limit.go b/clause/limit.go index 2082f4d9..184f6025 100644 --- a/clause/limit.go +++ b/clause/limit.go @@ -21,7 +21,7 @@ func (limit Limit) Build(builder Builder) { } if limit.Offset > 0 { if limit.Limit > 0 { - builder.WriteString(" ") + builder.WriteByte(' ') } builder.WriteString("OFFSET ") builder.WriteString(strconv.Itoa(limit.Offset)) diff --git a/clause/where.go b/clause/where.go index 10b6df85..a29401cf 100644 --- a/clause/where.go +++ b/clause/where.go @@ -72,9 +72,9 @@ func buildExprs(exprs []Expression, builder Builder, joinCond string) { } if wrapInParentheses { - builder.WriteString(`(`) + builder.WriteByte('(') expr.Build(builder) - builder.WriteString(`)`) + builder.WriteByte(')') wrapInParentheses = false } else { expr.Build(builder) diff --git a/statement.go b/statement.go index cb471776..abf646b8 100644 --- a/statement.go +++ b/statement.go @@ -130,7 +130,7 @@ func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) { writer.WriteByte('(') for idx, d := range v { if idx > 0 { - writer.WriteString(",") + writer.WriteByte(',') } stmt.QuoteTo(writer, d) } @@ -143,7 +143,7 @@ func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) { writer.WriteByte('(') for idx, d := range v { if idx > 0 { - writer.WriteString(",") + writer.WriteByte(',') } stmt.DB.Dialector.QuoteTo(writer, d) } diff --git a/utils/tests/dummy_dialecter.go b/utils/tests/dummy_dialecter.go index 9543f750..2990c20f 100644 --- a/utils/tests/dummy_dialecter.go +++ b/utils/tests/dummy_dialecter.go @@ -49,7 +49,7 @@ func (DummyDialector) QuoteTo(writer clause.Writer, str string) { shiftDelimiter = 0 underQuoted = false continuousBacktick = 0 - writer.WriteString("`") + writer.WriteByte('`') } writer.WriteByte(v) continue @@ -74,7 +74,7 @@ func (DummyDialector) QuoteTo(writer clause.Writer, str string) { if continuousBacktick > 0 && !selfQuoted { writer.WriteString("``") } - writer.WriteString("`") + writer.WriteByte('`') } func (DummyDialector) Explain(sql string, vars ...interface{}) string { From 9b9ae325bb1fe6e209823d576e70e5e8e6ceccb2 Mon Sep 17 00:00:00 2001 From: chenrui <631807682@qq.com> Date: Thu, 17 Mar 2022 23:53:31 +0800 Subject: [PATCH 41/87] fix: circular reference save, close #5140 commit 2ac099a37ac7bd74f0a98a6fdc42cc8527404144 Author: Jinzhu Date: Thu Mar 17 23:49:21 2022 +0800 Refactor #5140 commit 6e3ca2d1aa09943dcfb5d9a4b93bea28212f71be Author: a631807682 <631807682@qq.com> Date: Sun Mar 13 12:52:08 2022 +0800 test: add test for LoadOrStoreVisitMap commit 9d5c68e41000fd15dea124797dd5f2656bf6b304 Author: chenrui Date: Thu Mar 10 20:33:47 2022 +0800 chore: add more comment commit bfffefb179c883389b72bef8f04469c0a8418043 Author: chenrui Date: Thu Mar 10 20:28:48 2022 +0800 fix: should check values has been saved instead of rel.Name commit e55cdfa4b3fbcf8b80baf009e8ddb2e40d471494 Author: chenrui Date: Tue Mar 8 17:48:01 2022 +0800 chore: go lint commit fe4715c5bd4ac28950c97dded9848710d8becb88 Author: chenrui Date: Tue Mar 8 17:27:24 2022 +0800 chore: add test comment commit 326862f3f8980482a09d7d1a7f4d1011bb8a7c59 Author: chenrui Date: Tue Mar 8 17:22:33 2022 +0800 fix: circular reference save --- callbacks/associations.go | 41 ++++++++++++++++++++++++++++++------- callbacks/helper.go | 30 +++++++++++++++++++++++++++ callbacks/visit_map_test.go | 36 ++++++++++++++++++++++++++++++++ tests/associations_test.go | 41 +++++++++++++++++++++++++++++++++++++ tests/tests_test.go | 2 +- utils/tests/models.go | 14 +++++++++++++ 6 files changed, 156 insertions(+), 8 deletions(-) create mode 100644 callbacks/visit_map_test.go diff --git a/callbacks/associations.go b/callbacks/associations.go index d6fd21de..3b204ab6 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -69,7 +69,7 @@ func SaveBeforeAssociations(create bool) func(db *gorm.DB) { } if elems.Len() > 0 { - if saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, nil) == nil { + if saveAssociations(db, rel, elems, selectColumns, restricted, nil) == nil { for i := 0; i < elems.Len(); i++ { setupReferences(objs[i], elems.Index(i)) } @@ -82,7 +82,7 @@ func SaveBeforeAssociations(create bool) func(db *gorm.DB) { rv = rv.Addr() } - if saveAssociations(db, rel, rv.Interface(), selectColumns, restricted, nil) == nil { + if saveAssociations(db, rel, rv, selectColumns, restricted, nil) == nil { setupReferences(db.Statement.ReflectValue, rv) } } @@ -146,7 +146,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) } - saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, assignmentColumns) + saveAssociations(db, rel, elems, selectColumns, restricted, assignmentColumns) } case reflect.Struct: if _, zero := rel.Field.ValueOf(db.Statement.Context, db.Statement.ReflectValue); !zero { @@ -166,7 +166,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) } - saveAssociations(db, rel, f.Interface(), selectColumns, restricted, assignmentColumns) + saveAssociations(db, rel, f, selectColumns, restricted, assignmentColumns) } } } @@ -237,7 +237,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) } - saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, assignmentColumns) + saveAssociations(db, rel, elems, selectColumns, restricted, assignmentColumns) } } @@ -304,7 +304,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { // optimize elems of reflect value length if elemLen := elems.Len(); elemLen > 0 { if v, ok := selectColumns[rel.Name+".*"]; !ok || v { - saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, nil) + saveAssociations(db, rel, elems, selectColumns, restricted, nil) } for i := 0; i < elemLen; i++ { @@ -341,11 +341,17 @@ func onConflictOption(stmt *gorm.Statement, s *schema.Schema, selectColumns map[ return } -func saveAssociations(db *gorm.DB, rel *schema.Relationship, values interface{}, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) error { +func saveAssociations(db *gorm.DB, rel *schema.Relationship, rValues reflect.Value, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) error { + // stop save association loop + if checkAssociationsSaved(db, rValues) { + return nil + } + var ( selects, omits []string onConflict = onConflictOption(db.Statement, rel.FieldSchema, selectColumns, restricted, defaultUpdatingColumns) refName = rel.Name + "." + values = rValues.Interface() ) for name, ok := range selectColumns { @@ -390,3 +396,24 @@ func saveAssociations(db *gorm.DB, rel *schema.Relationship, values interface{}, return db.AddError(tx.Create(values).Error) } + +// check association values has been saved +// if values kind is Struct, check it has been saved +// if values kind is Slice/Array, check all items have been saved +var visitMapStoreKey = "gorm:saved_association_map" + +func checkAssociationsSaved(db *gorm.DB, values reflect.Value) bool { + if visit, ok := db.Get(visitMapStoreKey); ok { + if v, ok := visit.(*visitMap); ok { + if loadOrStoreVisitMap(v, values) { + return true + } + } + } else { + vistMap := make(visitMap) + loadOrStoreVisitMap(&vistMap, values) + db.Set(visitMapStoreKey, &vistMap) + } + + return false +} diff --git a/callbacks/helper.go b/callbacks/helper.go index a5eb047e..71b67de5 100644 --- a/callbacks/helper.go +++ b/callbacks/helper.go @@ -1,6 +1,7 @@ package callbacks import ( + "reflect" "sort" "gorm.io/gorm" @@ -120,3 +121,32 @@ func checkMissingWhereConditions(db *gorm.DB) { return } } + +type visitMap = map[reflect.Value]bool + +// Check if circular values, return true if loaded +func loadOrStoreVisitMap(vistMap *visitMap, v reflect.Value) (loaded bool) { + if v.Kind() == reflect.Ptr { + v = v.Elem() + } + + switch v.Kind() { + case reflect.Slice, reflect.Array: + loaded = true + for i := 0; i < v.Len(); i++ { + if !loadOrStoreVisitMap(vistMap, v.Index(i)) { + loaded = false + } + } + case reflect.Struct, reflect.Interface: + if v.CanAddr() { + p := v.Addr() + if _, ok := (*vistMap)[p]; ok { + return true + } + (*vistMap)[p] = true + } + } + + return +} diff --git a/callbacks/visit_map_test.go b/callbacks/visit_map_test.go new file mode 100644 index 00000000..b1fb86db --- /dev/null +++ b/callbacks/visit_map_test.go @@ -0,0 +1,36 @@ +package callbacks + +import ( + "reflect" + "testing" +) + +func TestLoadOrStoreVisitMap(t *testing.T) { + var vm visitMap + var loaded bool + type testM struct { + Name string + } + + t1 := testM{Name: "t1"} + t2 := testM{Name: "t2"} + t3 := testM{Name: "t3"} + + vm = make(visitMap) + if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf(&t1)); loaded { + t.Fatalf("loaded should be false") + } + + if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf(&t1)); !loaded { + t.Fatalf("loaded should be true") + } + + // t1 already exist but t2 not + if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf([]*testM{&t1, &t2, &t3})); loaded { + t.Fatalf("loaded should be false") + } + + if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf([]*testM{&t2, &t3})); !loaded { + t.Fatalf("loaded should be true") + } +} diff --git a/tests/associations_test.go b/tests/associations_test.go index 5ce98c7d..32f6525b 100644 --- a/tests/associations_test.go +++ b/tests/associations_test.go @@ -220,3 +220,44 @@ func TestFullSaveAssociations(t *testing.T) { t.Errorf("Failed to preload AppliesToProduct") } } + +func TestSaveBelongsCircularReference(t *testing.T) { + parent := Parent{} + DB.Create(&parent) + + child := Child{ParentID: &parent.ID, Parent: &parent} + DB.Create(&child) + + parent.FavChildID = child.ID + parent.FavChild = &child + DB.Save(&parent) + + var parent1 Parent + DB.First(&parent1, parent.ID) + AssertObjEqual(t, parent, parent1, "ID", "FavChildID") + + // Save and Updates is the same + DB.Updates(&parent) + DB.First(&parent1, parent.ID) + AssertObjEqual(t, parent, parent1, "ID", "FavChildID") +} + +func TestSaveHasManyCircularReference(t *testing.T) { + parent := Parent{} + DB.Create(&parent) + + child := Child{ParentID: &parent.ID, Parent: &parent, Name: "HasManyCircularReference"} + child1 := Child{ParentID: &parent.ID, Parent: &parent, Name: "HasManyCircularReference1"} + + parent.Children = []*Child{&child, &child1} + DB.Save(&parent) + + var children []*Child + DB.Where("parent_id = ?", parent.ID).Find(&children) + if len(children) != len(parent.Children) || + children[0].ID != parent.Children[0].ID || + children[1].ID != parent.Children[1].ID { + t.Errorf("circular reference children save not equal children:%v parent.Children:%v", + children, parent.Children) + } +} diff --git a/tests/tests_test.go b/tests/tests_test.go index 11b6f067..08f4f193 100644 --- a/tests/tests_test.go +++ b/tests/tests_test.go @@ -95,7 +95,7 @@ func OpenTestConnection() (db *gorm.DB, err error) { func RunMigrations() { var err error - allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Coupon{}, &CouponProduct{}, &Order{}} + allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Coupon{}, &CouponProduct{}, &Order{}, &Parent{}, &Child{}} rand.Seed(time.Now().UnixNano()) rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] }) diff --git a/utils/tests/models.go b/utils/tests/models.go index c84f9cae..22e8e659 100644 --- a/utils/tests/models.go +++ b/utils/tests/models.go @@ -80,3 +80,17 @@ type Order struct { Coupon *Coupon CouponID string } + +type Parent struct { + gorm.Model + FavChildID uint + FavChild *Child + Children []*Child +} + +type Child struct { + gorm.Model + Name string + ParentID *uint + Parent *Parent +} From c2e36ebe62a0e79649aff1a539b39ace86bc6bab Mon Sep 17 00:00:00 2001 From: chenrui <631807682@qq.com> Date: Fri, 18 Mar 2022 01:07:49 +0800 Subject: [PATCH 42/87] fix: soft delete for join, close #5132 commit a83023bdfc0dc6eaccc6704b64ff6436c2fe7725 Author: Jinzhu Date: Fri Mar 18 01:05:25 2022 +0800 Refactor #5132 commit 8559f51102c01be6c19913c0bc3a5771721ff1f5 Author: chenrui Date: Mon Mar 7 20:33:12 2022 +0800 fix: should add deleted_at exprs for every joins commit 2b7a1bdcf3eff9d23253173d21e73c1f056f9be4 Author: chenrui Date: Mon Mar 7 14:46:48 2022 +0800 test: move debug flag commit ce13a2a7bc50d2c23678806acf65dbd589827c77 Author: chenrui Date: Mon Mar 7 14:39:56 2022 +0800 fix: soft delete for join.on --- callbacks/query.go | 38 ++++++++++++++++++++++++++------------ tests/helper_test.go | 5 +++++ tests/joins_test.go | 31 +++++++++++++++++++++++++++++++ 3 files changed, 62 insertions(+), 12 deletions(-) diff --git a/callbacks/query.go b/callbacks/query.go index 04f35c7e..c4c80406 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -145,19 +145,33 @@ func BuildQuerySQL(db *gorm.DB) { } } - if join.On != nil { - onStmt := gorm.Statement{Table: tableAliasName, DB: db} - join.On.Build(&onStmt) - onSQL := onStmt.SQL.String() - vars := onStmt.Vars - for idx, v := range onStmt.Vars { - bindvar := strings.Builder{} - onStmt.Vars = vars[0 : idx+1] - db.Dialector.BindVarTo(&bindvar, &onStmt, v) - onSQL = strings.Replace(onSQL, bindvar.String(), "?", 1) + { + onStmt := gorm.Statement{Table: tableAliasName, DB: db, Clauses: map[string]clause.Clause{}} + for _, c := range relation.FieldSchema.QueryClauses { + onStmt.AddClause(c) } - exprs = append(exprs, clause.Expr{SQL: onSQL, Vars: vars}) + if join.On != nil { + onStmt.AddClause(join.On) + } + + if cs, ok := onStmt.Clauses["WHERE"]; ok { + if where, ok := cs.Expression.(clause.Where); ok { + where.Build(&onStmt) + + if onSQL := onStmt.SQL.String(); onSQL != "" { + vars := onStmt.Vars + for idx, v := range vars { + bindvar := strings.Builder{} + onStmt.Vars = vars[0 : idx+1] + db.Dialector.BindVarTo(&bindvar, &onStmt, v) + onSQL = strings.Replace(onSQL, bindvar.String(), "?", 1) + } + + exprs = append(exprs, clause.Expr{SQL: onSQL, Vars: vars}) + } + } + } } joins = append(joins, clause.Join{ @@ -172,8 +186,8 @@ func BuildQuerySQL(db *gorm.DB) { } } - db.Statement.Joins = nil db.Statement.AddClause(clause.From{Joins: joins}) + db.Statement.Joins = nil } else { db.Statement.AddClauseIfNotExists(clause.From{}) } diff --git a/tests/helper_test.go b/tests/helper_test.go index eee34e99..7ee2a576 100644 --- a/tests/helper_test.go +++ b/tests/helper_test.go @@ -19,6 +19,7 @@ type Config struct { Team int Languages int Friends int + NamedPet bool } func GetUser(name string, config Config) *User { @@ -65,6 +66,10 @@ func GetUser(name string, config Config) *User { user.Friends = append(user.Friends, GetUser(name+"_friend_"+strconv.Itoa(i+1), Config{})) } + if config.NamedPet { + user.NamedPet = &Pet{Name: name + "_namepet"} + } + return &user } diff --git a/tests/joins_test.go b/tests/joins_test.go index 4c9cffae..0f02f3f9 100644 --- a/tests/joins_test.go +++ b/tests/joins_test.go @@ -200,3 +200,34 @@ func TestJoinCount(t *testing.T) { t.Fatalf("result's id, %d, doesn't match user's id, %d", result.ID, user.ID) } } + +func TestJoinWithSoftDeleted(t *testing.T) { + DB = DB.Debug() + + user := GetUser("TestJoinWithSoftDeletedUser", Config{Account: true, NamedPet: true}) + DB.Create(&user) + + var user1 User + DB.Model(&User{}).Joins("NamedPet").Joins("Account").First(&user1, user.ID) + if user1.NamedPet == nil || user1.Account.ID == 0 { + t.Fatalf("joins NamedPet and Account should not empty:%v", user1) + } + + // Account should empty + DB.Delete(&user1.Account) + + var user2 User + DB.Model(&User{}).Joins("NamedPet").Joins("Account").First(&user2, user.ID) + if user2.NamedPet == nil || user2.Account.ID != 0 { + t.Fatalf("joins Account should not empty:%v", user2) + } + + // NamedPet should empty + DB.Delete(&user1.NamedPet) + + var user3 User + DB.Model(&User{}).Joins("NamedPet").Joins("Account").First(&user3, user.ID) + if user3.NamedPet != nil || user2.Account.ID != 0 { + t.Fatalf("joins NamedPet and Account should not empty:%v", user2) + } +} From 5431da8caf09ad19256170df17e2e75eb541f4a5 Mon Sep 17 00:00:00 2001 From: chenrui <631807682@qq.com> Date: Fri, 18 Mar 2022 13:38:46 +0800 Subject: [PATCH 43/87] fix: preload panic when model and dest different close #5130 commit e8307b5ef5273519a32cd8e4fd29250d1c277f6e Author: Jinzhu Date: Fri Mar 18 13:37:22 2022 +0800 Refactor #5130 commit 40cbba49f374c9bae54f80daee16697ae45e905b Author: chenrui Date: Sat Mar 5 17:36:56 2022 +0800 test: fix test fail commit 66d3f078291102a30532b6a9d97c757228a9b543 Author: chenrui Date: Sat Mar 5 17:29:09 2022 +0800 test: drop table and auto migrate commit 7cbf019a930019476a97ac7ac0f5fc186e8d5b42 Author: chenrui Date: Sat Mar 5 15:27:45 2022 +0800 fix: preload panic when model and dest different --- callbacks/preload.go | 56 ++++++++++++++++++------------------- callbacks/query.go | 15 ++++++++-- chainable_api.go | 5 +++- tests/preload_suits_test.go | 2 +- tests/preload_test.go | 18 ++++++++++++ 5 files changed, 63 insertions(+), 33 deletions(-) diff --git a/callbacks/preload.go b/callbacks/preload.go index 2363a8ca..888f832d 100644 --- a/callbacks/preload.go +++ b/callbacks/preload.go @@ -10,10 +10,9 @@ import ( "gorm.io/gorm/utils" ) -func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) { +func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) error { var ( - reflectValue = db.Statement.ReflectValue - tx = db.Session(&gorm.Session{NewDB: true}).Model(nil).Session(&gorm.Session{SkipHooks: db.Statement.SkipHooks}) + reflectValue = tx.Statement.ReflectValue relForeignKeys []string relForeignFields []*schema.Field foreignFields []*schema.Field @@ -22,11 +21,6 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload inlineConds []interface{} ) - db.Statement.Settings.Range(func(k, v interface{}) bool { - tx.Statement.Settings.Store(k, v) - return true - }) - if rel.JoinTable != nil { var ( joinForeignFields = make([]*schema.Field, 0, len(rel.References)) @@ -48,14 +42,16 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload } } - joinIdentityMap, joinForeignValues := schema.GetIdentityFieldValuesMap(db.Statement.Context, reflectValue, foreignFields) + joinIdentityMap, joinForeignValues := schema.GetIdentityFieldValuesMap(tx.Statement.Context, reflectValue, foreignFields) if len(joinForeignValues) == 0 { - return + return nil } joinResults := rel.JoinTable.MakeSlice().Elem() column, values := schema.ToQueryValues(clause.CurrentTable, joinForeignKeys, joinForeignValues) - db.AddError(tx.Where(clause.IN{Column: column, Values: values}).Find(joinResults.Addr().Interface()).Error) + if err := tx.Where(clause.IN{Column: column, Values: values}).Find(joinResults.Addr().Interface()).Error; err != nil { + return err + } // convert join identity map to relation identity map fieldValues := make([]interface{}, len(joinForeignFields)) @@ -63,11 +59,11 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload for i := 0; i < joinResults.Len(); i++ { joinIndexValue := joinResults.Index(i) for idx, field := range joinForeignFields { - fieldValues[idx], _ = field.ValueOf(db.Statement.Context, joinIndexValue) + fieldValues[idx], _ = field.ValueOf(tx.Statement.Context, joinIndexValue) } for idx, field := range joinRelForeignFields { - joinFieldValues[idx], _ = field.ValueOf(db.Statement.Context, joinIndexValue) + joinFieldValues[idx], _ = field.ValueOf(tx.Statement.Context, joinIndexValue) } if results, ok := joinIdentityMap[utils.ToStringKey(fieldValues...)]; ok { @@ -76,7 +72,7 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload } } - _, foreignValues = schema.GetIdentityFieldValuesMap(db.Statement.Context, joinResults, joinRelForeignFields) + _, foreignValues = schema.GetIdentityFieldValuesMap(tx.Statement.Context, joinResults, joinRelForeignFields) } else { for _, ref := range rel.References { if ref.OwnPrimaryKey { @@ -92,9 +88,9 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload } } - identityMap, foreignValues = schema.GetIdentityFieldValuesMap(db.Statement.Context, reflectValue, foreignFields) + identityMap, foreignValues = schema.GetIdentityFieldValuesMap(tx.Statement.Context, reflectValue, foreignFields) if len(foreignValues) == 0 { - return + return nil } } @@ -115,7 +111,9 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload } } - db.AddError(tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), inlineConds...).Error) + if err := tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), inlineConds...).Error; err != nil { + return err + } } fieldValues := make([]interface{}, len(relForeignFields)) @@ -125,17 +123,17 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload case reflect.Struct: switch rel.Type { case schema.HasMany, schema.Many2Many: - rel.Field.Set(db.Statement.Context, reflectValue, reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface()) + rel.Field.Set(tx.Statement.Context, reflectValue, reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface()) default: - rel.Field.Set(db.Statement.Context, reflectValue, reflect.New(rel.Field.FieldType).Interface()) + rel.Field.Set(tx.Statement.Context, reflectValue, reflect.New(rel.Field.FieldType).Interface()) } case reflect.Slice, reflect.Array: for i := 0; i < reflectValue.Len(); i++ { switch rel.Type { case schema.HasMany, schema.Many2Many: - rel.Field.Set(db.Statement.Context, reflectValue.Index(i), reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface()) + rel.Field.Set(tx.Statement.Context, reflectValue.Index(i), reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface()) default: - rel.Field.Set(db.Statement.Context, reflectValue.Index(i), reflect.New(rel.Field.FieldType).Interface()) + rel.Field.Set(tx.Statement.Context, reflectValue.Index(i), reflect.New(rel.Field.FieldType).Interface()) } } } @@ -143,18 +141,16 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload for i := 0; i < reflectResults.Len(); i++ { elem := reflectResults.Index(i) for idx, field := range relForeignFields { - fieldValues[idx], _ = field.ValueOf(db.Statement.Context, elem) + fieldValues[idx], _ = field.ValueOf(tx.Statement.Context, elem) } datas, ok := identityMap[utils.ToStringKey(fieldValues...)] if !ok { - db.AddError(fmt.Errorf("failed to assign association %#v, make sure foreign fields exists", - elem.Interface())) - continue + return fmt.Errorf("failed to assign association %#v, make sure foreign fields exists", elem.Interface()) } for _, data := range datas { - reflectFieldValue := rel.Field.ReflectValueOf(db.Statement.Context, data) + reflectFieldValue := rel.Field.ReflectValueOf(tx.Statement.Context, data) if reflectFieldValue.Kind() == reflect.Ptr && reflectFieldValue.IsNil() { reflectFieldValue.Set(reflect.New(rel.Field.FieldType.Elem())) } @@ -162,14 +158,16 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload reflectFieldValue = reflect.Indirect(reflectFieldValue) switch reflectFieldValue.Kind() { case reflect.Struct: - rel.Field.Set(db.Statement.Context, data, elem.Interface()) + rel.Field.Set(tx.Statement.Context, data, elem.Interface()) case reflect.Slice, reflect.Array: if reflectFieldValue.Type().Elem().Kind() == reflect.Ptr { - rel.Field.Set(db.Statement.Context, data, reflect.Append(reflectFieldValue, elem).Interface()) + rel.Field.Set(tx.Statement.Context, data, reflect.Append(reflectFieldValue, elem).Interface()) } else { - rel.Field.Set(db.Statement.Context, data, reflect.Append(reflectFieldValue, elem.Elem()).Interface()) + rel.Field.Set(tx.Statement.Context, data, reflect.Append(reflectFieldValue, elem.Elem()).Interface()) } } } } + + return tx.Error } diff --git a/callbacks/query.go b/callbacks/query.go index c4c80406..6ba3dd38 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -237,9 +237,20 @@ func Preload(db *gorm.DB) { } sort.Strings(preloadNames) + preloadDB := db.Session(&gorm.Session{Context: db.Statement.Context, NewDB: true, SkipHooks: db.Statement.SkipHooks, Initialized: true}) + db.Statement.Settings.Range(func(k, v interface{}) bool { + preloadDB.Statement.Settings.Store(k, v) + return true + }) + + if err := preloadDB.Statement.Parse(db.Statement.Dest); err != nil { + return + } + preloadDB.Statement.ReflectValue = db.Statement.ReflectValue + for _, name := range preloadNames { - if rel := db.Statement.Schema.Relationships.Relations[name]; rel != nil { - preload(db, rel, append(db.Statement.Preloads[name], db.Statement.Preloads[clause.Associations]...), preloadMap[name]) + if rel := preloadDB.Statement.Schema.Relationships.Relations[name]; rel != nil { + db.AddError(preload(preloadDB.Table("").Session(&gorm.Session{}), rel, append(db.Statement.Preloads[name], db.Statement.Preloads[clause.Associations]...), preloadMap[name])) } else { db.AddError(fmt.Errorf("%s: %w for schema %s", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name)) } diff --git a/chainable_api.go b/chainable_api.go index 173479d3..38ad5cde 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -54,9 +54,12 @@ func (db *DB) Table(name string, args ...interface{}) (tx *DB) { } else if tables := strings.Split(name, "."); len(tables) == 2 { tx.Statement.TableExpr = &clause.Expr{SQL: tx.Statement.Quote(name)} tx.Statement.Table = tables[1] - } else { + } else if name != "" { tx.Statement.TableExpr = &clause.Expr{SQL: tx.Statement.Quote(name)} tx.Statement.Table = name + } else { + tx.Statement.TableExpr = nil + tx.Statement.Table = "" } return } diff --git a/tests/preload_suits_test.go b/tests/preload_suits_test.go index 0ef8890b..b5b6a70f 100644 --- a/tests/preload_suits_test.go +++ b/tests/preload_suits_test.go @@ -1335,7 +1335,7 @@ func TestNilPointerSlice(t *testing.T) { } if !reflect.DeepEqual(got[0], want) && !reflect.DeepEqual(got[1], want) { - t.Errorf("got %s; want array containing %s", toJSONString(got), toJSONString(want)) + t.Fatalf("got %s; want array containing %s", toJSONString(got), toJSONString(want)) } if !reflect.DeepEqual(got[0], want2) && !reflect.DeepEqual(got[1], want2) { diff --git a/tests/preload_test.go b/tests/preload_test.go index adb54ee1..cb4343ec 100644 --- a/tests/preload_test.go +++ b/tests/preload_test.go @@ -251,3 +251,21 @@ func TestPreloadGoroutine(t *testing.T) { } wg.Wait() } + +func TestPreloadWithDiffModel(t *testing.T) { + user := *GetUser("preload_with_diff_model", Config{Account: true}) + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + var result struct { + Something string + User + } + + DB.Model(User{}).Preload("Account", clause.Eq{Column: "number", Value: user.Account.Number}).Select( + "users.*, 'yo' as something").First(&result, "name = ?", user.Name) + + CheckUser(t, user, result.User) +} From e6f7da0e0dbc193df883f799a4650d0a86507376 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 18 Mar 2022 14:30:30 +0800 Subject: [PATCH 44/87] Support Variable Relation --- schema/relationship.go | 6 +++++- schema/relationship_test.go | 20 ++++++++++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/schema/relationship.go b/schema/relationship.go index eae8ab0b..b5100897 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -416,6 +416,10 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu } } else { var primaryFields []*Field + var primarySchemaName = primarySchema.Name + if primarySchemaName == "" { + primarySchemaName = relation.FieldSchema.Name + } if len(relation.primaryKeys) > 0 { for _, primaryKey := range relation.primaryKeys { @@ -428,7 +432,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu } for _, primaryField := range primaryFields { - lookUpName := primarySchema.Name + primaryField.Name + lookUpName := primarySchemaName + primaryField.Name if gl == guessBelongs { lookUpName = field.Name + primaryField.Name } diff --git a/schema/relationship_test.go b/schema/relationship_test.go index 40ffc324..6fffbfcb 100644 --- a/schema/relationship_test.go +++ b/schema/relationship_test.go @@ -491,6 +491,26 @@ func TestEmbeddedRelation(t *testing.T) { } } +func TestVariableRelation(t *testing.T) { + var result struct { + User + } + + checkStructRelation(t, &result, Relation{ + Name: "Account", Type: schema.HasOne, Schema: "", FieldSchema: "Account", + References: []Reference{ + {"ID", "", "UserID", "Account", "", true}, + }, + }) + + checkStructRelation(t, &result, Relation{ + Name: "Company", Type: schema.BelongsTo, Schema: "", FieldSchema: "Company", + References: []Reference{ + {"ID", "Company", "CompanyID", "", "", false}, + }, + }) +} + func TestSameForeignKey(t *testing.T) { type UserAux struct { gorm.Model From 3c00980e01a6a16095b9fafddedd3217ad4b7357 Mon Sep 17 00:00:00 2001 From: ag9920 Date: Fri, 18 Mar 2022 17:12:17 +0800 Subject: [PATCH 45/87] fix: serializer use default valueOf in assignInterfacesToValue, close #5168 commit 58e1b2bffbc216f2862d040fb545a8a486e473b6 Author: Jinzhu Date: Fri Mar 18 17:06:43 2022 +0800 Refactor #5168 commit fb9233011d209174e8223e970f0f732412852908 Author: ag9920 Date: Thu Mar 17 21:23:28 2022 +0800 fix: serializer use default valueOf in assignInterfacesToValue --- schema/field.go | 80 ++++++++++++++++++++++------------------ tests/joins_test.go | 2 - tests/serializer_test.go | 51 ++++++++++++++++++++++++- 3 files changed, 95 insertions(+), 38 deletions(-) diff --git a/schema/field.go b/schema/field.go index 0d7085a9..45ec66e1 100644 --- a/schema/field.go +++ b/schema/field.go @@ -435,39 +435,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { // create valuer, setter when parse struct func (field *Field) setupValuerAndSetter() { // Setup NewValuePool - var fieldValue = reflect.New(field.FieldType).Interface() - if field.Serializer != nil { - field.NewValuePool = &sync.Pool{ - New: func() interface{} { - return &serializer{ - Field: field, - Serializer: reflect.New(reflect.Indirect(reflect.ValueOf(field.Serializer)).Type()).Interface().(SerializerInterface), - } - }, - } - } else if _, ok := fieldValue.(sql.Scanner); !ok { - // set default NewValuePool - switch field.IndirectFieldType.Kind() { - case reflect.String: - field.NewValuePool = stringPool - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - field.NewValuePool = intPool - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - field.NewValuePool = uintPool - case reflect.Float32, reflect.Float64: - field.NewValuePool = floatPool - case reflect.Bool: - field.NewValuePool = boolPool - default: - if field.IndirectFieldType == TimeReflectType { - field.NewValuePool = timePool - } - } - } - - if field.NewValuePool == nil { - field.NewValuePool = poolInitializer(reflect.PtrTo(field.IndirectFieldType)) - } + field.setupNewValuePool() // ValueOf returns field's value and if it is zero fieldIndex := field.StructField.Index[0] @@ -512,7 +480,7 @@ func (field *Field) setupValuerAndSetter() { s = field.Serializer } - return serializer{ + return &serializer{ Field: field, SerializeValuer: s, Destination: v, @@ -943,7 +911,9 @@ func (field *Field) setupValuerAndSetter() { field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) { if s, ok := v.(*serializer); ok { - if err = s.Serializer.Scan(ctx, field, value, s.value); err == nil { + if s.fieldValue != nil { + err = oldFieldSetter(ctx, value, s.fieldValue) + } else if err = s.Serializer.Scan(ctx, field, value, s.value); err == nil { if sameElemType { field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(s.Serializer).Elem()) s.Serializer = reflect.New(reflect.Indirect(reflect.ValueOf(field.Serializer)).Type()).Interface().(SerializerInterface) @@ -959,3 +929,43 @@ func (field *Field) setupValuerAndSetter() { } } } + +func (field *Field) setupNewValuePool() { + var fieldValue = reflect.New(field.FieldType).Interface() + if field.Serializer != nil { + field.NewValuePool = &sync.Pool{ + New: func() interface{} { + return &serializer{ + Field: field, + Serializer: reflect.New(reflect.Indirect(reflect.ValueOf(field.Serializer)).Type()).Interface().(SerializerInterface), + } + }, + } + } else if _, ok := fieldValue.(sql.Scanner); !ok { + field.setupDefaultNewValuePool() + } + + if field.NewValuePool == nil { + field.NewValuePool = poolInitializer(reflect.PtrTo(field.IndirectFieldType)) + } +} + +func (field *Field) setupDefaultNewValuePool() { + // set default NewValuePool + switch field.IndirectFieldType.Kind() { + case reflect.String: + field.NewValuePool = stringPool + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + field.NewValuePool = intPool + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + field.NewValuePool = uintPool + case reflect.Float32, reflect.Float64: + field.NewValuePool = floatPool + case reflect.Bool: + field.NewValuePool = boolPool + default: + if field.IndirectFieldType == TimeReflectType { + field.NewValuePool = timePool + } + } +} diff --git a/tests/joins_test.go b/tests/joins_test.go index 0f02f3f9..bb5352ef 100644 --- a/tests/joins_test.go +++ b/tests/joins_test.go @@ -202,8 +202,6 @@ func TestJoinCount(t *testing.T) { } func TestJoinWithSoftDeleted(t *testing.T) { - DB = DB.Debug() - user := GetUser("TestJoinWithSoftDeletedUser", Config{Account: true, NamedPet: true}) DB.Create(&user) diff --git a/tests/serializer_test.go b/tests/serializer_test.go index a8a4e28f..ce60280e 100644 --- a/tests/serializer_test.go +++ b/tests/serializer_test.go @@ -42,7 +42,7 @@ func (es *EncryptedString) Scan(ctx context.Context, field *schema.Field, dst re case string: *es = EncryptedString(strings.TrimPrefix(value, "hello")) default: - return fmt.Errorf("unsupported data %v", dbValue) + return fmt.Errorf("unsupported data %#v", dbValue) } return nil } @@ -83,4 +83,53 @@ func TestSerializer(t *testing.T) { } AssertEqual(t, result, data) + +} + +func TestSerializerAssignFirstOrCreate(t *testing.T) { + DB.Migrator().DropTable(&SerializerStruct{}) + if err := DB.Migrator().AutoMigrate(&SerializerStruct{}); err != nil { + t.Fatalf("no error should happen when migrate scanner, valuer struct, got error %v", err) + } + + createdAt := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC) + + data := SerializerStruct{ + Name: []byte("ag9920"), + Roles: []string{"r1", "r2"}, + Contracts: map[string]interface{}{"name": "jing1", "age": 11}, + EncryptedString: EncryptedString("pass"), + CreatedTime: createdAt.Unix(), + JobInfo: Job{ + Title: "programmer", + Number: 9920, + Location: "Shadyside", + IsIntern: false, + }, + } + + // first time insert record + out := SerializerStruct{} + if err := DB.Assign(data).FirstOrCreate(&out).Error; err != nil { + t.Fatalf("failed to FirstOrCreate Assigned data, got error %v", err) + } + + var result SerializerStruct + if err := DB.First(&result, out.ID).Error; err != nil { + t.Fatalf("failed to query data, got error %v", err) + } + AssertEqual(t, result, out) + + //update record + data.Roles = append(data.Roles, "r3") + data.JobInfo.Location = "Gates Hillman Complex" + if err := DB.Assign(data).FirstOrCreate(&out).Error; err != nil { + t.Fatalf("failed to FirstOrCreate Assigned data, got error %v", err) + } + if err := DB.First(&result, out.ID).Error; err != nil { + t.Fatalf("failed to query data, got error %v", err) + } + + AssertEqual(t, result.Roles, data.Roles) + AssertEqual(t, result.JobInfo.Location, data.JobInfo.Location) } From d402765f694ade8fd3a0da1b7a2f9d2fa4453957 Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Fri, 18 Mar 2022 20:11:23 +0800 Subject: [PATCH 46/87] test: fix utils.AssertEqual (#5172) --- tests/query_test.go | 4 +++- utils/tests/utils.go | 29 +++++++++++++++++------------ 2 files changed, 20 insertions(+), 13 deletions(-) diff --git a/tests/query_test.go b/tests/query_test.go index 6542774a..af2b8d4b 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -583,7 +583,9 @@ func TestPluck(t *testing.T) { if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Order("name desc").Pluck("name", &names2).Error; err != nil { t.Errorf("got error when pluck name: %v", err) } - AssertEqual(t, names, sort.Reverse(sort.StringSlice(names2))) + + sort.Slice(names2, func(i, j int) bool { return names2[i] < names2[j] }) + AssertEqual(t, names, names2) var ids []int if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Pluck("id", &ids).Error; err != nil { diff --git a/utils/tests/utils.go b/utils/tests/utils.go index 817e4b0b..661d727f 100644 --- a/utils/tests/utils.go +++ b/utils/tests/utils.go @@ -83,20 +83,22 @@ func AssertEqual(t *testing.T, got, expect interface{}) { } if reflect.ValueOf(got).Kind() == reflect.Struct { - if reflect.ValueOf(got).NumField() == reflect.ValueOf(expect).NumField() { - exported := false - for i := 0; i < reflect.ValueOf(got).NumField(); i++ { - if fieldStruct := reflect.ValueOf(got).Type().Field(i); ast.IsExported(fieldStruct.Name) { - exported = true - field := reflect.ValueOf(got).Field(i) - t.Run(fieldStruct.Name, func(t *testing.T) { - AssertEqual(t, field.Interface(), reflect.ValueOf(expect).Field(i).Interface()) - }) + if reflect.ValueOf(expect).Kind() == reflect.Struct { + if reflect.ValueOf(got).NumField() == reflect.ValueOf(expect).NumField() { + exported := false + for i := 0; i < reflect.ValueOf(got).NumField(); i++ { + if fieldStruct := reflect.ValueOf(got).Type().Field(i); ast.IsExported(fieldStruct.Name) { + exported = true + field := reflect.ValueOf(got).Field(i) + t.Run(fieldStruct.Name, func(t *testing.T) { + AssertEqual(t, field.Interface(), reflect.ValueOf(expect).Field(i).Interface()) + }) + } } - } - if exported { - return + if exported { + return + } } } } @@ -107,6 +109,9 @@ func AssertEqual(t *testing.T, got, expect interface{}) { } else if reflect.ValueOf(expect).Type().ConvertibleTo(reflect.ValueOf(got).Type()) { expect = reflect.ValueOf(got).Convert(reflect.ValueOf(got).Type()).Interface() isEqual() + } else { + t.Errorf("%v: expect: %+v, got %+v", utils.FileWithLineNum(), expect, got) + return } } } From 540b47571a2c74134c2a8eb02d5a8ef70b0bf8d6 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 18 Mar 2022 20:57:33 +0800 Subject: [PATCH 47/87] Fix update select clause with before/after expressions, close #5164 --- chainable_api.go | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/chainable_api.go b/chainable_api.go index 38ad5cde..68b4d1aa 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -93,7 +93,11 @@ func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) { return } } - delete(tx.Statement.Clauses, "SELECT") + + if clause, ok := tx.Statement.Clauses["SELECT"]; ok { + clause.Expression = nil + tx.Statement.Clauses["SELECT"] = clause + } case string: if strings.Count(v, "?") >= len(args) && len(args) > 0 { tx.Statement.AddClause(clause.Select{ @@ -123,7 +127,10 @@ func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) { } } - delete(tx.Statement.Clauses, "SELECT") + if clause, ok := tx.Statement.Clauses["SELECT"]; ok { + clause.Expression = nil + tx.Statement.Clauses["SELECT"] = clause + } } default: tx.AddError(fmt.Errorf("unsupported select args %v %v", query, args)) From 0097b39a77b9573d63f89c22f3cea0aae103a77f Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 20 Mar 2022 08:55:08 +0800 Subject: [PATCH 48/87] Should ignore error when parsing default value for time, close #5176 --- schema/field.go | 4 ++-- tests/go.mod | 2 +- tests/postgres_test.go | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/schema/field.go b/schema/field.go index 45ec66e1..96291816 100644 --- a/schema/field.go +++ b/schema/field.go @@ -260,8 +260,8 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { field.DataType = Time } if field.HasDefaultValue && !skipParseDefaultValue && field.DataType == Time { - if field.DefaultValueInterface, err = now.Parse(field.DefaultValue); err != nil { - schema.err = fmt.Errorf("failed to parse default value `%v` for field %v", field.DefaultValue, field.Name) + if t, err := now.Parse(field.DefaultValue); err == nil { + field.DefaultValueInterface = t } } case reflect.Array, reflect.Slice: diff --git a/tests/go.mod b/tests/go.mod index 9dfa26ff..17e5d350 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -14,7 +14,7 @@ require ( gorm.io/driver/postgres v1.3.1 gorm.io/driver/sqlite v1.3.1 gorm.io/driver/sqlserver v1.3.1 - gorm.io/gorm v1.23.1 + gorm.io/gorm v1.23.3 ) replace gorm.io/gorm => ../ diff --git a/tests/postgres_test.go b/tests/postgres_test.go index 418b713e..66b988c3 100644 --- a/tests/postgres_test.go +++ b/tests/postgres_test.go @@ -19,7 +19,7 @@ func TestPostgres(t *testing.T) { Name string `gorm:"check:name_checker,name <> ''"` Test uuid.UUID `gorm:"type:uuid;not null;default:gen_random_uuid()"` CreatedAt time.Time `gorm:"type:TIMESTAMP WITHOUT TIME ZONE"` - UpdatedAt time.Time `gorm:"type:TIMESTAMP WITHOUT TIME ZONE"` + UpdatedAt time.Time `gorm:"type:TIMESTAMP WITHOUT TIME ZONE;default:current_timestamp"` Things pq.StringArray `gorm:"type:text[]"` } From 2d5cb997ed4d0e8f53fa1662111ad2cb053caf9c Mon Sep 17 00:00:00 2001 From: Jin Date: Sun, 20 Mar 2022 09:02:45 +0800 Subject: [PATCH 49/87] style: fix linter check for NamingStrategy and onConflictOption (#5174) --- callbacks/associations.go | 4 ++-- schema/naming.go | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/callbacks/associations.go b/callbacks/associations.go index 3b204ab6..644ef185 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -323,7 +323,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { } } -func onConflictOption(stmt *gorm.Statement, s *schema.Schema, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) (onConflict clause.OnConflict) { +func onConflictOption(stmt *gorm.Statement, s *schema.Schema, defaultUpdatingColumns []string) (onConflict clause.OnConflict) { if len(defaultUpdatingColumns) > 0 || stmt.DB.FullSaveAssociations { onConflict.Columns = make([]clause.Column, 0, len(s.PrimaryFieldDBNames)) for _, dbName := range s.PrimaryFieldDBNames { @@ -349,7 +349,7 @@ func saveAssociations(db *gorm.DB, rel *schema.Relationship, rValues reflect.Val var ( selects, omits []string - onConflict = onConflictOption(db.Statement, rel.FieldSchema, selectColumns, restricted, defaultUpdatingColumns) + onConflict = onConflictOption(db.Statement, rel.FieldSchema, defaultUpdatingColumns) refName = rel.Name + "." values = rValues.Interface() ) diff --git a/schema/naming.go b/schema/naming.go index 47a2b363..a258beed 100644 --- a/schema/naming.go +++ b/schema/naming.go @@ -85,9 +85,9 @@ func (ns NamingStrategy) IndexName(table, column string) string { } func (ns NamingStrategy) formatName(prefix, table, name string) string { - formattedName := strings.Replace(strings.Join([]string{ + formattedName := strings.ReplaceAll(strings.Join([]string{ prefix, table, name, - }, "_"), ".", "_", -1) + }, "_"), ".", "_") if utf8.RuneCountInString(formattedName) > 64 { h := sha1.New() From d66f37ad322cbda02bb873b5b2f1093296672b49 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 21 Mar 2022 10:50:14 +0800 Subject: [PATCH 50/87] Add Go 1.18 --- .github/workflows/tests.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 3e15427c..ad4c9917 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -13,7 +13,7 @@ jobs: sqlite: strategy: matrix: - go: ['1.17', '1.16'] + go: ['1.18', '1.17', '1.16'] platform: [ubuntu-latest] # can not run in windows OS runs-on: ${{ matrix.platform }} @@ -39,7 +39,7 @@ jobs: strategy: matrix: dbversion: ['mysql:latest', 'mysql:5.7', 'mariadb:latest'] - go: ['1.17', '1.16'] + go: ['1.18', '1.17', '1.16'] platform: [ubuntu-latest] runs-on: ${{ matrix.platform }} @@ -83,7 +83,7 @@ jobs: strategy: matrix: dbversion: ['postgres:latest', 'postgres:13', 'postgres:12', 'postgres:11', 'postgres:10'] - go: ['1.17', '1.16'] + go: ['1.18', '1.17', '1.16'] platform: [ubuntu-latest] # can not run in macOS and Windows runs-on: ${{ matrix.platform }} @@ -125,7 +125,7 @@ jobs: sqlserver: strategy: matrix: - go: ['1.17', '1.16'] + go: ['1.18', '1.17', '1.16'] platform: [ubuntu-latest] # can not run test in macOS and windows runs-on: ${{ matrix.platform }} From a7b3b5956fad0ae536147a19e89300af0462d74d Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 22 Mar 2022 22:42:36 +0800 Subject: [PATCH 51/87] Fix hooks order, close https://github.com/go-gorm/gorm.io/pull/519 --- callbacks/create.go | 17 ++++++++++------- callbacks/update.go | 16 ++++++++++------ 2 files changed, 20 insertions(+), 13 deletions(-) diff --git a/callbacks/create.go b/callbacks/create.go index 6e2883f7..0a43cacb 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -10,6 +10,7 @@ import ( "gorm.io/gorm/utils" ) +// BeforeCreate before create hooks func BeforeCreate(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) { callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { @@ -31,6 +32,7 @@ func BeforeCreate(db *gorm.DB) { } } +// Create create hook func Create(config *Config) func(db *gorm.DB) { supportReturning := utils.Contains(config.CreateClauses, "RETURNING") @@ -146,22 +148,23 @@ func Create(config *Config) func(db *gorm.DB) { } } +// AfterCreate after create hooks func AfterCreate(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) { callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { - if db.Statement.Schema.AfterSave { - if i, ok := value.(AfterSaveInterface); ok { - called = true - db.AddError(i.AfterSave(tx)) - } - } - if db.Statement.Schema.AfterCreate { if i, ok := value.(AfterCreateInterface); ok { called = true db.AddError(i.AfterCreate(tx)) } } + + if db.Statement.Schema.AfterSave { + if i, ok := value.(AfterSaveInterface); ok { + called = true + db.AddError(i.AfterSave(tx)) + } + } return called }) } diff --git a/callbacks/update.go b/callbacks/update.go index da03261e..1964973b 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -29,6 +29,7 @@ func SetupUpdateReflectValue(db *gorm.DB) { } } +// BeforeUpdate before update hooks func BeforeUpdate(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) { callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { @@ -51,6 +52,7 @@ func BeforeUpdate(db *gorm.DB) { } } +// Update update hook func Update(config *Config) func(db *gorm.DB) { supportReturning := utils.Contains(config.UpdateClauses, "RETURNING") @@ -99,9 +101,17 @@ func Update(config *Config) func(db *gorm.DB) { } } +// AfterUpdate after update hooks func AfterUpdate(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) { callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { + if db.Statement.Schema.AfterUpdate { + if i, ok := value.(AfterUpdateInterface); ok { + called = true + db.AddError(i.AfterUpdate(tx)) + } + } + if db.Statement.Schema.AfterSave { if i, ok := value.(AfterSaveInterface); ok { called = true @@ -109,12 +119,6 @@ func AfterUpdate(db *gorm.DB) { } } - if db.Statement.Schema.AfterUpdate { - if i, ok := value.(AfterUpdateInterface); ok { - called = true - db.AddError(i.AfterUpdate(tx)) - } - } return called }) } From f92e6747cb12d5a5bc2bf7e0d76cb8e5f69cd637 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 23 Mar 2022 17:24:25 +0800 Subject: [PATCH 52/87] Handle field set value error --- callbacks/associations.go | 14 +++++++------- callbacks/create.go | 18 +++++++++--------- callbacks/preload.go | 14 +++++++------- callbacks/update.go | 2 +- scan.go | 4 ++-- schema/field.go | 5 +++-- statement.go | 8 ++++---- tests/go.mod | 2 +- 8 files changed, 34 insertions(+), 33 deletions(-) diff --git a/callbacks/associations.go b/callbacks/associations.go index 644ef185..fd3141cf 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -159,9 +159,9 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { for _, ref := range rel.References { if ref.OwnPrimaryKey { fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, db.Statement.ReflectValue) - ref.ForeignKey.Set(db.Statement.Context, f, fv) + db.AddError(ref.ForeignKey.Set(db.Statement.Context, f, fv)) } else if ref.PrimaryValue != "" { - ref.ForeignKey.Set(db.Statement.Context, f, ref.PrimaryValue) + db.AddError(ref.ForeignKey.Set(db.Statement.Context, f, ref.PrimaryValue)) } assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) } @@ -193,9 +193,9 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { for _, ref := range rel.References { if ref.OwnPrimaryKey { pv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, v) - ref.ForeignKey.Set(db.Statement.Context, elem, pv) + db.AddError(ref.ForeignKey.Set(db.Statement.Context, elem, pv)) } else if ref.PrimaryValue != "" { - ref.ForeignKey.Set(db.Statement.Context, elem, ref.PrimaryValue) + db.AddError(ref.ForeignKey.Set(db.Statement.Context, elem, ref.PrimaryValue)) } } @@ -261,12 +261,12 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { for _, ref := range rel.References { if ref.OwnPrimaryKey { fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, obj) - ref.ForeignKey.Set(db.Statement.Context, joinValue, fv) + db.AddError(ref.ForeignKey.Set(db.Statement.Context, joinValue, fv)) } else if ref.PrimaryValue != "" { - ref.ForeignKey.Set(db.Statement.Context, joinValue, ref.PrimaryValue) + db.AddError(ref.ForeignKey.Set(db.Statement.Context, joinValue, ref.PrimaryValue)) } else { fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, elem) - ref.ForeignKey.Set(db.Statement.Context, joinValue, fv) + db.AddError(ref.ForeignKey.Set(db.Statement.Context, joinValue, fv)) } } joins = reflect.Append(joins, joinValue) diff --git a/callbacks/create.go b/callbacks/create.go index 0a43cacb..e94b7eca 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -121,7 +121,7 @@ func Create(config *Config) func(db *gorm.DB) { _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv) if isZero { - db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, insertID) + db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, insertID)) insertID -= db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement } } @@ -133,7 +133,7 @@ func Create(config *Config) func(db *gorm.DB) { } if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv); isZero { - db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, insertID) + db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, insertID)) insertID += db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement } } @@ -141,7 +141,7 @@ func Create(config *Config) func(db *gorm.DB) { case reflect.Struct: _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, db.Statement.ReflectValue) if isZero { - db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, db.Statement.ReflectValue, insertID) + db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, db.Statement.ReflectValue, insertID)) } } } @@ -227,13 +227,13 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { if values.Values[i][idx], isZero = field.ValueOf(stmt.Context, rv); isZero { if field.DefaultValueInterface != nil { values.Values[i][idx] = field.DefaultValueInterface - field.Set(stmt.Context, rv, field.DefaultValueInterface) + stmt.AddError(field.Set(stmt.Context, rv, field.DefaultValueInterface)) } else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 { - field.Set(stmt.Context, rv, curTime) + stmt.AddError(field.Set(stmt.Context, rv, curTime)) values.Values[i][idx], _ = field.ValueOf(stmt.Context, rv) } } else if field.AutoUpdateTime > 0 && updateTrackTime { - field.Set(stmt.Context, rv, curTime) + stmt.AddError(field.Set(stmt.Context, rv, curTime)) values.Values[i][idx], _ = field.ValueOf(stmt.Context, rv) } } @@ -267,13 +267,13 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { if values.Values[0][idx], isZero = field.ValueOf(stmt.Context, stmt.ReflectValue); isZero { if field.DefaultValueInterface != nil { values.Values[0][idx] = field.DefaultValueInterface - field.Set(stmt.Context, stmt.ReflectValue, field.DefaultValueInterface) + stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, field.DefaultValueInterface)) } else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 { - field.Set(stmt.Context, stmt.ReflectValue, curTime) + stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, curTime)) values.Values[0][idx], _ = field.ValueOf(stmt.Context, stmt.ReflectValue) } } else if field.AutoUpdateTime > 0 && updateTrackTime { - field.Set(stmt.Context, stmt.ReflectValue, curTime) + stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, curTime)) values.Values[0][idx], _ = field.ValueOf(stmt.Context, stmt.ReflectValue) } } diff --git a/callbacks/preload.go b/callbacks/preload.go index 888f832d..ea2570ba 100644 --- a/callbacks/preload.go +++ b/callbacks/preload.go @@ -123,17 +123,17 @@ func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preload case reflect.Struct: switch rel.Type { case schema.HasMany, schema.Many2Many: - rel.Field.Set(tx.Statement.Context, reflectValue, reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface()) + tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue, reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface())) default: - rel.Field.Set(tx.Statement.Context, reflectValue, reflect.New(rel.Field.FieldType).Interface()) + tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue, reflect.New(rel.Field.FieldType).Interface())) } case reflect.Slice, reflect.Array: for i := 0; i < reflectValue.Len(); i++ { switch rel.Type { case schema.HasMany, schema.Many2Many: - rel.Field.Set(tx.Statement.Context, reflectValue.Index(i), reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface()) + tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue.Index(i), reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface())) default: - rel.Field.Set(tx.Statement.Context, reflectValue.Index(i), reflect.New(rel.Field.FieldType).Interface()) + tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue.Index(i), reflect.New(rel.Field.FieldType).Interface())) } } } @@ -158,12 +158,12 @@ func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preload reflectFieldValue = reflect.Indirect(reflectFieldValue) switch reflectFieldValue.Kind() { case reflect.Struct: - rel.Field.Set(tx.Statement.Context, data, elem.Interface()) + tx.AddError(rel.Field.Set(tx.Statement.Context, data, elem.Interface())) case reflect.Slice, reflect.Array: if reflectFieldValue.Type().Elem().Kind() == reflect.Ptr { - rel.Field.Set(tx.Statement.Context, data, reflect.Append(reflectFieldValue, elem).Interface()) + tx.AddError(rel.Field.Set(tx.Statement.Context, data, reflect.Append(reflectFieldValue, elem).Interface())) } else { - rel.Field.Set(tx.Statement.Context, data, reflect.Append(reflectFieldValue, elem.Elem()).Interface()) + tx.AddError(rel.Field.Set(tx.Statement.Context, data, reflect.Append(reflectFieldValue, elem.Elem()).Interface())) } } } diff --git a/callbacks/update.go b/callbacks/update.go index 1964973b..01f40509 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -21,7 +21,7 @@ func SetupUpdateReflectValue(db *gorm.DB) { if dest, ok := db.Statement.Dest.(map[string]interface{}); ok { for _, rel := range db.Statement.Schema.Relationships.BelongsTo { if _, ok := dest[rel.Name]; ok { - rel.Field.Set(db.Statement.Context, db.Statement.ReflectValue, dest[rel.Name]) + db.AddError(rel.Field.Set(db.Statement.Context, db.Statement.ReflectValue, dest[rel.Name])) } } } diff --git a/scan.go b/scan.go index 89d92354..42642ec6 100644 --- a/scan.go +++ b/scan.go @@ -69,7 +69,7 @@ func (db *DB) scanIntoStruct(rows *sql.Rows, reflectValue reflect.Value, values for idx, field := range fields { if field != nil { if len(joinFields) == 0 || joinFields[idx][0] == nil { - field.Set(db.Statement.Context, reflectValue, values[idx]) + db.AddError(field.Set(db.Statement.Context, reflectValue, values[idx])) } else { relValue := joinFields[idx][0].ReflectValueOf(db.Statement.Context, reflectValue) if relValue.Kind() == reflect.Ptr && relValue.IsNil() { @@ -79,7 +79,7 @@ func (db *DB) scanIntoStruct(rows *sql.Rows, reflectValue reflect.Value, values relValue.Set(reflect.New(relValue.Type().Elem())) } - joinFields[idx][1].Set(db.Statement.Context, relValue, values[idx]) + db.AddError(joinFields[idx][1].Set(db.Statement.Context, relValue, values[idx])) } // release data to pool diff --git a/schema/field.go b/schema/field.go index 96291816..3b5cc5c5 100644 --- a/schema/field.go +++ b/schema/field.go @@ -12,6 +12,7 @@ import ( "time" "github.com/jinzhu/now" + "gorm.io/gorm/clause" "gorm.io/gorm/utils" ) @@ -567,8 +568,8 @@ func (field *Field) setupValuerAndSetter() { if v, err = valuer.Value(); err == nil { err = setter(ctx, value, v) } - } else { - return fmt.Errorf("failed to set value %+v to field %s", v, field.Name) + } else if _, ok := v.(clause.Expr); !ok { + return fmt.Errorf("failed to set value %#v to field %s", v, field.Name) } } diff --git a/statement.go b/statement.go index abf646b8..9fcee09c 100644 --- a/statement.go +++ b/statement.go @@ -562,7 +562,7 @@ func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks . switch destValue.Kind() { case reflect.Struct: - field.Set(stmt.Context, destValue, value) + stmt.AddError(field.Set(stmt.Context, destValue, value)) default: stmt.AddError(ErrInvalidData) } @@ -572,10 +572,10 @@ func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks . case reflect.Slice, reflect.Array: if len(fromCallbacks) > 0 { for i := 0; i < stmt.ReflectValue.Len(); i++ { - field.Set(stmt.Context, stmt.ReflectValue.Index(i), value) + stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue.Index(i), value)) } } else { - field.Set(stmt.Context, stmt.ReflectValue.Index(stmt.CurDestIndex), value) + stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue.Index(stmt.CurDestIndex), value)) } case reflect.Struct: if !stmt.ReflectValue.CanAddr() { @@ -583,7 +583,7 @@ func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks . return } - field.Set(stmt.Context, stmt.ReflectValue, value) + stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, value)) } } else { stmt.AddError(ErrInvalidField) diff --git a/tests/go.mod b/tests/go.mod index 17e5d350..b85ebdad 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -9,7 +9,7 @@ require ( github.com/jinzhu/now v1.1.5 github.com/lib/pq v1.10.4 github.com/mattn/go-sqlite3 v1.14.12 // indirect - golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd // indirect + golang.org/x/crypto v0.0.0-20220321153916-2c7772ba3064 // indirect gorm.io/driver/mysql v1.3.2 gorm.io/driver/postgres v1.3.1 gorm.io/driver/sqlite v1.3.1 From 9a4d10be64738f0c1f7a86841d56e2fe3165e3f0 Mon Sep 17 00:00:00 2001 From: Jin Date: Thu, 24 Mar 2022 09:31:58 +0800 Subject: [PATCH 53/87] style: fix coding typo (#5184) --- migrator/column_type.go | 2 +- tests/main_test.go | 6 ++---- tests/migrate_test.go | 2 +- tests/sql_builder_test.go | 10 +++++----- tests/upsert_test.go | 2 +- 5 files changed, 10 insertions(+), 12 deletions(-) diff --git a/migrator/column_type.go b/migrator/column_type.go index cc1331b9..c6fdd6b2 100644 --- a/migrator/column_type.go +++ b/migrator/column_type.go @@ -44,7 +44,7 @@ func (ct ColumnType) DatabaseTypeName() string { return ct.SQLColumnType.DatabaseTypeName() } -// ColumnType returns the database type of the column. lke `varchar(16)` +// ColumnType returns the database type of the column. like `varchar(16)` func (ct ColumnType) ColumnType() (columnType string, ok bool) { return ct.ColumnTypeValue.String, ct.ColumnTypeValue.Valid } diff --git a/tests/main_test.go b/tests/main_test.go index 5b8c7dbb..997714b9 100644 --- a/tests/main_test.go +++ b/tests/main_test.go @@ -43,10 +43,8 @@ func TestExceptionsWithInvalidSql(t *testing.T) { func TestSetAndGet(t *testing.T) { if value, ok := DB.Set("hello", "world").Get("hello"); !ok { t.Errorf("Should be able to get setting after set") - } else { - if value.(string) != "world" { - t.Errorf("Setted value should not be changed") - } + } else if value.(string) != "world" { + t.Errorf("Set value should not be changed") } if _, ok := DB.Get("non_existing"); ok { diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 94f562b4..f72c4c08 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -258,7 +258,7 @@ func TestMigrateTable(t *testing.T) { DB.Migrator().DropTable("new_table_structs") if DB.Migrator().HasTable(&NewTableStruct{}) { - t.Fatal("should not found droped table") + t.Fatal("should not found dropped table") } } diff --git a/tests/sql_builder_test.go b/tests/sql_builder_test.go index bc917c32..a7630271 100644 --- a/tests/sql_builder_test.go +++ b/tests/sql_builder_test.go @@ -360,7 +360,7 @@ func TestToSQL(t *testing.T) { }) assertEqualSQL(t, `SELECT * FROM "users" WHERE id = 100 AND "users"."deleted_at" IS NULL ORDER BY age desc LIMIT 10`, sql) - // after model chagned + // after model changed if DB.Statement.DryRun || DB.DryRun { t.Fatal("Failed expect DB.DryRun and DB.Statement.ToSQL to be false") } @@ -426,13 +426,13 @@ func TestToSQL(t *testing.T) { }) assertEqualSQL(t, `UPDATE "users" SET "name"='Foo',"age"=100 WHERE id = 100 AND "users"."deleted_at" IS NULL`, sql) - // after model chagned + // after model changed if DB.Statement.DryRun || DB.DryRun { t.Fatal("Failed expect DB.DryRun and DB.Statement.ToSQL to be false") } } -// assertEqualSQL for assert that the sql is equal, this method will ignore quote, and dialect speicals. +// assertEqualSQL for assert that the sql is equal, this method will ignore quote, and dialect specials. func assertEqualSQL(t *testing.T, expected string, actually string) { t.Helper() @@ -440,7 +440,7 @@ func assertEqualSQL(t *testing.T, expected string, actually string) { expected = replaceQuoteInSQL(expected) actually = replaceQuoteInSQL(actually) - // ignore updated_at value, becase it's generated in Gorm inernal, can't to mock value on update. + // ignore updated_at value, because it's generated in Gorm internal, can't to mock value on update. updatedAtRe := regexp.MustCompile(`(?i)"updated_at"=".+?"`) actually = updatedAtRe.ReplaceAllString(actually, `"updated_at"=?`) expected = updatedAtRe.ReplaceAllString(expected, `"updated_at"=?`) @@ -462,7 +462,7 @@ func replaceQuoteInSQL(sql string) string { // convert single quote into double quote sql = strings.ReplaceAll(sql, `'`, `"`) - // convert dialect speical quote into double quote + // convert dialect special quote into double quote switch DB.Dialector.Name() { case "postgres": sql = strings.ReplaceAll(sql, `"`, `"`) diff --git a/tests/upsert_test.go b/tests/upsert_test.go index c5d19605..f90c4518 100644 --- a/tests/upsert_test.go +++ b/tests/upsert_test.go @@ -319,7 +319,7 @@ func TestUpdateWithMissWhere(t *testing.T) { tx := DB.Session(&gorm.Session{DryRun: true}).Save(&user) if err := tx.Error; err != nil { - t.Fatalf("failed to update user,missing where condtion,err=%+v", err) + t.Fatalf("failed to update user,missing where condition,err=%+v", err) } if !regexp.MustCompile("WHERE .id. = [^ ]+$").MatchString(tx.Statement.SQL.String()) { From 3d7019a7c236890aae9716335c7d5b6dae116d17 Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Thu, 24 Mar 2022 09:34:06 +0800 Subject: [PATCH 54/87] fix: throw err if association model miss primary key (#5187) --- association.go | 21 +++++++++++++++------ tests/associations_test.go | 24 ++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 6 deletions(-) diff --git a/association.go b/association.go index 09e79ca6..dc731ff8 100644 --- a/association.go +++ b/association.go @@ -187,8 +187,11 @@ func (association *Association) Delete(values ...interface{}) error { tx := association.DB.Model(reflect.New(rel.Schema.ModelType).Interface()) _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, rel.Schema.PrimaryFields) - pcolumn, pvalues := schema.ToQueryValues(rel.Schema.Table, rel.Schema.PrimaryFieldDBNames, pvs) - conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues}) + if pcolumn, pvalues := schema.ToQueryValues(rel.Schema.Table, rel.Schema.PrimaryFieldDBNames, pvs); len(pvalues) > 0 { + conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues}) + } else { + return ErrPrimaryKeyRequired + } _, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, primaryFields) relColumn, relValues := schema.ToQueryValues(rel.Schema.Table, foreignKeys, rvs) @@ -199,8 +202,11 @@ func (association *Association) Delete(values ...interface{}) error { tx := association.DB.Model(reflect.New(rel.FieldSchema.ModelType).Interface()) _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields) - pcolumn, pvalues := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs) - conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues}) + if pcolumn, pvalues := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs); len(pvalues) > 0 { + conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues}) + } else { + return ErrPrimaryKeyRequired + } _, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, rel.FieldSchema.PrimaryFields) relColumn, relValues := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs) @@ -229,8 +235,11 @@ func (association *Association) Delete(values ...interface{}) error { } _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields) - pcolumn, pvalues := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs) - conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues}) + if pcolumn, pvalues := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs); len(pvalues) > 0 { + conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues}) + } else { + return ErrPrimaryKeyRequired + } _, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, relPrimaryFields) relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs) diff --git a/tests/associations_test.go b/tests/associations_test.go index 32f6525b..bc3dac55 100644 --- a/tests/associations_test.go +++ b/tests/associations_test.go @@ -261,3 +261,27 @@ func TestSaveHasManyCircularReference(t *testing.T) { children, parent.Children) } } + +func TestAssociationError(t *testing.T) { + DB = DB.Debug() + user := *GetUser("TestAssociationError", Config{Pets: 2, Company: true, Account: true, Languages: 2}) + DB.Create(&user) + + var user1 User + DB.Preload("Company").Preload("Pets").Preload("Account").Preload("Languages").First(&user1) + + var emptyUser User + var err error + // belongs to + err = DB.Model(&emptyUser).Association("Company").Delete(&user1.Company) + AssertEqual(t, err, gorm.ErrPrimaryKeyRequired) + // has many + err = DB.Model(&emptyUser).Association("Pets").Delete(&user1.Pets) + AssertEqual(t, err, gorm.ErrPrimaryKeyRequired) + // has one + err = DB.Model(&emptyUser).Association("Account").Delete(&user1.Account) + AssertEqual(t, err, gorm.ErrPrimaryKeyRequired) + // many to many + err = DB.Model(&emptyUser).Association("Languages").Delete(&user1.Languages) + AssertEqual(t, err, gorm.ErrPrimaryKeyRequired) +} From 6d40a8343249e208aa79b938a7b0939a631b6b74 Mon Sep 17 00:00:00 2001 From: qqxhb <30866940+qqxhb@users.noreply.github.com> Date: Thu, 24 Mar 2022 16:30:14 +0800 Subject: [PATCH 55/87] Update README.md add gorm gen --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index a3eabe39..312a3a59 100644 --- a/README.md +++ b/README.md @@ -30,6 +30,7 @@ The fantastic ORM library for Golang, aims to be developer friendly. ## Getting Started * GORM Guides [https://gorm.io](https://gorm.io) +* GORM Gen [gorm/gen](https://github.com/go-gorm/gen#gormgen) ## Contributing From 6c827ff2e3ffa0e8b7e4c598031f6af8124a7357 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 28 Mar 2022 19:55:05 +0800 Subject: [PATCH 56/87] chore(deps): bump actions/cache from 2 to 3 (#5196) Bumps [actions/cache](https://github.com/actions/cache) from 2 to 3. - [Release notes](https://github.com/actions/cache/releases) - [Commits](https://github.com/actions/cache/compare/v2...v3) --- updated-dependencies: - dependency-name: actions/cache dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/tests.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index ad4c9917..8194e609 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -27,7 +27,7 @@ jobs: uses: actions/checkout@v3 - name: go mod package cache - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: ~/go/pkg/mod key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} @@ -71,7 +71,7 @@ jobs: - name: go mod package cache - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: ~/go/pkg/mod key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} @@ -114,7 +114,7 @@ jobs: uses: actions/checkout@v3 - name: go mod package cache - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: ~/go/pkg/mod key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} @@ -157,7 +157,7 @@ jobs: uses: actions/checkout@v3 - name: go mod package cache - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: ~/go/pkg/mod key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} From 9dd6ed9c65bcf95e4a4298bcdf1f26670778ba76 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 29 Mar 2022 18:14:29 +0800 Subject: [PATCH 57/87] Scan with Rows interface --- interfaces.go | 10 ++++++++++ scan.go | 4 ++-- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/interfaces.go b/interfaces.go index 84dc94bb..32d49605 100644 --- a/interfaces.go +++ b/interfaces.go @@ -72,3 +72,13 @@ type Valuer interface { type GetDBConnector interface { GetDBConn() (*sql.DB, error) } + +// Rows rows interface +type Rows interface { + Columns() ([]string, error) + ColumnTypes() ([]*sql.ColumnType, error) + Next() bool + Scan(dest ...interface{}) error + Err() error + Close() error +} diff --git a/scan.go b/scan.go index 42642ec6..c8da13da 100644 --- a/scan.go +++ b/scan.go @@ -50,7 +50,7 @@ func scanIntoMap(mapValue map[string]interface{}, values []interface{}, columns } } -func (db *DB) scanIntoStruct(rows *sql.Rows, reflectValue reflect.Value, values []interface{}, fields []*schema.Field, joinFields [][2]*schema.Field) { +func (db *DB) scanIntoStruct(rows Rows, reflectValue reflect.Value, values []interface{}, fields []*schema.Field, joinFields [][2]*schema.Field) { for idx, field := range fields { if field != nil { values[idx] = field.NewValuePool.Get() @@ -99,7 +99,7 @@ const ( ) // Scan scan rows into db statement -func Scan(rows *sql.Rows, db *DB, mode ScanMode) { +func Scan(rows Rows, db *DB, mode ScanMode) { var ( columns, _ = rows.Columns() values = make([]interface{}, len(columns)) From ea8509b77704b152380f8097c59e5ae3b57428bb Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 29 Mar 2022 18:48:06 +0800 Subject: [PATCH 58/87] Use defer to close rows to avoid scan panic leak rows --- callbacks/create.go | 4 +++- callbacks/query.go | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/callbacks/create.go b/callbacks/create.go index e94b7eca..0fe1dc93 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -84,8 +84,10 @@ func Create(config *Config) func(db *gorm.DB) { db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars..., ) if db.AddError(err) == nil { + defer func() { + db.AddError(rows.Close()) + }() gorm.Scan(rows, db, mode) - db.AddError(rows.Close()) } return diff --git a/callbacks/query.go b/callbacks/query.go index 6ba3dd38..6eda52ef 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -20,8 +20,10 @@ func Query(db *gorm.DB) { db.AddError(err) return } + defer func() { + db.AddError(rows.Close()) + }() gorm.Scan(rows, db, 0) - db.AddError(rows.Close()) } } } From 8333844f7112192ebd203992a67adf01b51ee8a0 Mon Sep 17 00:00:00 2001 From: ZhangShenao <15201440436@163.com> Date: Thu, 31 Mar 2022 20:57:20 +0800 Subject: [PATCH 59/87] fix variable shadowing (#5212) Co-authored-by: Shenao Zhang --- gorm.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gorm.go b/gorm.go index aca7cb5e..6a6bb032 100644 --- a/gorm.go +++ b/gorm.go @@ -124,8 +124,8 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) { for _, opt := range opts { if opt != nil { - if err := opt.Apply(config); err != nil { - return nil, err + if applyErr := opt.Apply(config); applyErr != nil { + return nil, applyErr } defer func(opt Option) { if errr := opt.AfterInitialize(db); errr != nil { From cd0315334b0fe555500d6f1870c566093d7daa33 Mon Sep 17 00:00:00 2001 From: Goxiaoy Date: Fri, 1 Apr 2022 08:33:39 +0800 Subject: [PATCH 60/87] fix: context missing in association (#5214) --- association.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/association.go b/association.go index dc731ff8..35e10ddd 100644 --- a/association.go +++ b/association.go @@ -502,7 +502,7 @@ func (association *Association) buildCondition() *DB { if association.Relationship.JoinTable != nil { if !tx.Statement.Unscoped && len(association.Relationship.JoinTable.QueryClauses) > 0 { - joinStmt := Statement{DB: tx, Schema: association.Relationship.JoinTable, Table: association.Relationship.JoinTable.Table, Clauses: map[string]clause.Clause{}} + joinStmt := Statement{DB: tx, Context: tx.Statement.Context, Schema: association.Relationship.JoinTable, Table: association.Relationship.JoinTable.Table, Clauses: map[string]clause.Clause{}} for _, queryClause := range association.Relationship.JoinTable.QueryClauses { joinStmt.AddClause(queryClause) } From f7b52bb649ba803ec149a06fec9e9da7b311d36e Mon Sep 17 00:00:00 2001 From: ZhangShenao <15201440436@163.com> Date: Fri, 1 Apr 2022 08:35:16 +0800 Subject: [PATCH 61/87] unify db receiver name (#5215) Co-authored-by: Shenao Zhang --- finisher_api.go | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/finisher_api.go b/finisher_api.go index b4d29b71..aa8e2b5a 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -207,7 +207,7 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat return tx } -func (tx *DB) assignInterfacesToValue(values ...interface{}) { +func (db *DB) assignInterfacesToValue(values ...interface{}) { for _, value := range values { switch v := value.(type) { case []clause.Expression: @@ -215,40 +215,40 @@ func (tx *DB) assignInterfacesToValue(values ...interface{}) { if eq, ok := expr.(clause.Eq); ok { switch column := eq.Column.(type) { case string: - if field := tx.Statement.Schema.LookUpField(column); field != nil { - tx.AddError(field.Set(tx.Statement.Context, tx.Statement.ReflectValue, eq.Value)) + if field := db.Statement.Schema.LookUpField(column); field != nil { + db.AddError(field.Set(db.Statement.Context, db.Statement.ReflectValue, eq.Value)) } case clause.Column: - if field := tx.Statement.Schema.LookUpField(column.Name); field != nil { - tx.AddError(field.Set(tx.Statement.Context, tx.Statement.ReflectValue, eq.Value)) + if field := db.Statement.Schema.LookUpField(column.Name); field != nil { + db.AddError(field.Set(db.Statement.Context, db.Statement.ReflectValue, eq.Value)) } } } else if andCond, ok := expr.(clause.AndConditions); ok { - tx.assignInterfacesToValue(andCond.Exprs) + db.assignInterfacesToValue(andCond.Exprs) } } case clause.Expression, map[string]string, map[interface{}]interface{}, map[string]interface{}: - if exprs := tx.Statement.BuildCondition(value); len(exprs) > 0 { - tx.assignInterfacesToValue(exprs) + if exprs := db.Statement.BuildCondition(value); len(exprs) > 0 { + db.assignInterfacesToValue(exprs) } default: - if s, err := schema.Parse(value, tx.cacheStore, tx.NamingStrategy); err == nil { + if s, err := schema.Parse(value, db.cacheStore, db.NamingStrategy); err == nil { reflectValue := reflect.Indirect(reflect.ValueOf(value)) switch reflectValue.Kind() { case reflect.Struct: for _, f := range s.Fields { if f.Readable { - if v, isZero := f.ValueOf(tx.Statement.Context, reflectValue); !isZero { - if field := tx.Statement.Schema.LookUpField(f.Name); field != nil { - tx.AddError(field.Set(tx.Statement.Context, tx.Statement.ReflectValue, v)) + if v, isZero := f.ValueOf(db.Statement.Context, reflectValue); !isZero { + if field := db.Statement.Schema.LookUpField(f.Name); field != nil { + db.AddError(field.Set(db.Statement.Context, db.Statement.ReflectValue, v)) } } } } } } else if len(values) > 0 { - if exprs := tx.Statement.BuildCondition(values[0], values[1:]...); len(exprs) > 0 { - tx.assignInterfacesToValue(exprs) + if exprs := db.Statement.BuildCondition(values[0], values[1:]...); len(exprs) > 0 { + db.assignInterfacesToValue(exprs) } return } From 9144969c83829d2f14049a6e4882f785a90b6cf9 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 2 Apr 2022 17:17:47 +0800 Subject: [PATCH 62/87] Allow to use tag to disable auto create/update time --- schema/field.go | 4 ++-- tests/associations_test.go | 1 - tests/go.mod | 2 +- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/schema/field.go b/schema/field.go index 3b5cc5c5..77521ad3 100644 --- a/schema/field.go +++ b/schema/field.go @@ -275,7 +275,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { field.DataType = DataType(dataTyper.GormDataType()) } - if v, ok := field.TagSettings["AUTOCREATETIME"]; ok || (field.Name == "CreatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) { + if v, ok := field.TagSettings["AUTOCREATETIME"]; (ok && utils.CheckTruth(v)) || (!ok && field.Name == "CreatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) { if field.DataType == Time { field.AutoCreateTime = UnixTime } else if strings.ToUpper(v) == "NANO" { @@ -287,7 +287,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } } - if v, ok := field.TagSettings["AUTOUPDATETIME"]; ok || (field.Name == "UpdatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) { + if v, ok := field.TagSettings["AUTOUPDATETIME"]; (ok && utils.CheckTruth(v)) || (!ok && field.Name == "UpdatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) { if field.DataType == Time { field.AutoUpdateTime = UnixTime } else if strings.ToUpper(v) == "NANO" { diff --git a/tests/associations_test.go b/tests/associations_test.go index bc3dac55..e729e979 100644 --- a/tests/associations_test.go +++ b/tests/associations_test.go @@ -263,7 +263,6 @@ func TestSaveHasManyCircularReference(t *testing.T) { } func TestAssociationError(t *testing.T) { - DB = DB.Debug() user := *GetUser("TestAssociationError", Config{Pets: 2, Company: true, Account: true, Languages: 2}) DB.Create(&user) diff --git a/tests/go.mod b/tests/go.mod index b85ebdad..fc6600b7 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -9,7 +9,7 @@ require ( github.com/jinzhu/now v1.1.5 github.com/lib/pq v1.10.4 github.com/mattn/go-sqlite3 v1.14.12 // indirect - golang.org/x/crypto v0.0.0-20220321153916-2c7772ba3064 // indirect + golang.org/x/crypto v0.0.0-20220331220935-ae2d96664a29 // indirect gorm.io/driver/mysql v1.3.2 gorm.io/driver/postgres v1.3.1 gorm.io/driver/sqlite v1.3.1 From 38a24606da3cd1e312644ef5f8d71e4d0d35554a Mon Sep 17 00:00:00 2001 From: huangcheng1 Date: Sat, 2 Apr 2022 17:27:53 +0800 Subject: [PATCH 63/87] fix: tables lost when joins exists in from clause, close #5218 commit 7f6a603afa26820e187489b5203f93adc513687c Author: Jinzhu Date: Sat Apr 2 17:26:48 2022 +0800 Refactor #5218 commit 95d00e6ff2668233f3eca98aa4917291e3d869bd Author: huangcheng1 Date: Fri Apr 1 16:30:27 2022 +0800 fix: tables lost when joins exists in from clause --- callbacks/query.go | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/callbacks/query.go b/callbacks/query.go index 6eda52ef..fb2bb37a 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -96,12 +96,12 @@ func BuildQuerySQL(db *gorm.DB) { } // inline joins - joins := []clause.Join{} - if fromClause, ok := db.Statement.Clauses["FROM"].Expression.(clause.From); ok { - joins = fromClause.Joins + fromClause := clause.From{} + if v, ok := db.Statement.Clauses["FROM"].Expression.(clause.From); ok { + fromClause = v } - if len(db.Statement.Joins) != 0 || len(joins) != 0 { + if len(db.Statement.Joins) != 0 || len(fromClause.Joins) != 0 { if len(db.Statement.Selects) == 0 && len(db.Statement.Omits) == 0 && db.Statement.Schema != nil { clauseSelect.Columns = make([]clause.Column, len(db.Statement.Schema.DBNames)) for idx, dbName := range db.Statement.Schema.DBNames { @@ -111,7 +111,7 @@ func BuildQuerySQL(db *gorm.DB) { for _, join := range db.Statement.Joins { if db.Statement.Schema == nil { - joins = append(joins, clause.Join{ + fromClause.Joins = append(fromClause.Joins, clause.Join{ Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds}, }) } else if relation, ok := db.Statement.Schema.Relationships.Relations[join.Name]; ok { @@ -176,19 +176,19 @@ func BuildQuerySQL(db *gorm.DB) { } } - joins = append(joins, clause.Join{ + fromClause.Joins = append(fromClause.Joins, clause.Join{ Type: clause.LeftJoin, Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName}, ON: clause.Where{Exprs: exprs}, }) } else { - joins = append(joins, clause.Join{ + fromClause.Joins = append(fromClause.Joins, clause.Join{ Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds}, }) } } - db.Statement.AddClause(clause.From{Joins: joins}) + db.Statement.AddClause(fromClause) db.Statement.Joins = nil } else { db.Statement.AddClauseIfNotExists(clause.From{}) From 81c4024232c35c3d49907f3ae77c2857a1dd7f63 Mon Sep 17 00:00:00 2001 From: Hasan Date: Thu, 7 Apr 2022 21:56:41 +0600 Subject: [PATCH 64/87] Offset issue resolved for scanning results back into struct (#5227) --- scan.go | 2 +- tests/scan_test.go | 27 +++++++++++++++++++++++++-- 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/scan.go b/scan.go index c8da13da..2ce6bd28 100644 --- a/scan.go +++ b/scan.go @@ -196,7 +196,7 @@ func Scan(rows Rows, db *DB, mode ScanMode) { for idx, column := range columns { if field := sch.LookUpField(column); field != nil && field.Readable { if curIndex, ok := selectedColumnsMap[column]; ok { - for fieldIndex, selectField := range sch.Fields[curIndex:] { + for fieldIndex, selectField := range sch.Fields[curIndex+1:] { if selectField.DBName == column && selectField.Readable { selectedColumnsMap[column] = curIndex + fieldIndex + 1 fields[idx] = selectField diff --git a/tests/scan_test.go b/tests/scan_test.go index ec1e652f..425c0a29 100644 --- a/tests/scan_test.go +++ b/tests/scan_test.go @@ -184,11 +184,34 @@ func TestScanToEmbedded(t *testing.T) { t.Errorf("Failed to run join query, got error: %v", err) } + personMatched := false + addressMatched := false + for _, info := range personAddressInfoList { - if info.Person != nil { - if info.Person.ID == person1.ID && info.Person.Name != person1.Name { + if info.Person == nil { + t.Fatalf("Failed, expected not nil, got person nil") + } + if info.Address == nil { + t.Fatalf("Failed, expected not nil, got address nil") + } + if info.Person.ID == person1.ID { + personMatched = true + if info.Person.Name != person1.Name { t.Errorf("Failed, expected %v, got %v", person1.Name, info.Person.Name) } } + if info.Address.ID == address1.ID { + addressMatched = true + if info.Address.Name != address1.Name { + t.Errorf("Failed, expected %v, got %v", address1.Name, info.Address.Name) + } + } + } + + if !personMatched { + t.Errorf("Failed, no person matched") + } + if !addressMatched { + t.Errorf("Failed, no address matched") } } From 0729261b627d0f73ab0e9bccc5b548d5e55fae88 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 8 Apr 2022 14:23:25 +0800 Subject: [PATCH 65/87] Support double ptr for Save --- finisher_api.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/finisher_api.go b/finisher_api.go index aa8e2b5a..5e4c3c5a 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -74,6 +74,10 @@ func (db *DB) Save(value interface{}) (tx *DB) { tx.Statement.Dest = value reflectValue := reflect.Indirect(reflect.ValueOf(value)) + for reflectValue.Kind() == reflect.Ptr || reflectValue.Kind() == reflect.Interface { + reflectValue = reflect.Indirect(reflectValue) + } + switch reflectValue.Kind() { case reflect.Slice, reflect.Array: if _, ok := tx.Statement.Clauses["ON CONFLICT"]; !ok { From 5c9ef9a8435334236662009c21d95c4bcc15a532 Mon Sep 17 00:00:00 2001 From: Naveen <172697+naveensrinivasan@users.noreply.github.com> Date: Sat, 9 Apr 2022 20:38:43 -0500 Subject: [PATCH 66/87] Set permissions for GitHub actions (#5237) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Restrict the GitHub token permissions only to the required ones; this way, even if the attackers will succeed in compromising your workflow, they won’t be able to do much. - Included permissions for the action. https://github.com/ossf/scorecard/blob/main/docs/checks.md#token-permissions https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#permissions https://docs.github.com/en/actions/using-jobs/assigning-permissions-to-jobs [Keeping your GitHub Actions and workflows secure Part 1: Preventing pwn requests](https://securitylab.github.com/research/github-actions-preventing-pwn-requests/) Signed-off-by: naveensrinivasan <172697+naveensrinivasan@users.noreply.github.com> --- .github/workflows/invalid_question.yml | 6 ++++++ .github/workflows/missing_playground.yml | 6 ++++++ .github/workflows/stale.yml | 6 ++++++ .github/workflows/tests.yml | 3 +++ 4 files changed, 21 insertions(+) diff --git a/.github/workflows/invalid_question.yml b/.github/workflows/invalid_question.yml index 868bcc34..327a70f6 100644 --- a/.github/workflows/invalid_question.yml +++ b/.github/workflows/invalid_question.yml @@ -3,8 +3,14 @@ on: schedule: - cron: "*/10 * * * *" +permissions: + contents: read + jobs: stale: + permissions: + issues: write # for actions/stale to close stale issues + pull-requests: write # for actions/stale to close stale PRs runs-on: ubuntu-latest env: ACTIONS_STEP_DEBUG: true diff --git a/.github/workflows/missing_playground.yml b/.github/workflows/missing_playground.yml index 3efc90f7..15d3850f 100644 --- a/.github/workflows/missing_playground.yml +++ b/.github/workflows/missing_playground.yml @@ -3,8 +3,14 @@ on: schedule: - cron: "*/10 * * * *" +permissions: + contents: read + jobs: stale: + permissions: + issues: write # for actions/stale to close stale issues + pull-requests: write # for actions/stale to close stale PRs runs-on: ubuntu-latest env: ACTIONS_STEP_DEBUG: true diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index e0be186f..c5e0d7ab 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -3,8 +3,14 @@ on: schedule: - cron: "0 2 * * *" +permissions: + contents: read + jobs: stale: + permissions: + issues: write # for actions/stale to close stale issues + pull-requests: write # for actions/stale to close stale PRs runs-on: ubuntu-latest env: ACTIONS_STEP_DEBUG: true diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 8194e609..8bfb2332 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -8,6 +8,9 @@ on: branches-ignore: - 'gh-pages' +permissions: + contents: read + jobs: # Label of the container job sqlite: From 41bef26f137fb1633b937482011c2266b4123a41 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 11 Apr 2022 21:37:02 +0800 Subject: [PATCH 67/87] Remove shared sync pool for Scanner compatibility --- schema/field.go | 23 ----------------------- schema/pool.go | 45 +-------------------------------------------- tests/go.mod | 11 +++++------ 3 files changed, 6 insertions(+), 73 deletions(-) diff --git a/schema/field.go b/schema/field.go index 77521ad3..fd8b2e6a 100644 --- a/schema/field.go +++ b/schema/field.go @@ -932,7 +932,6 @@ func (field *Field) setupValuerAndSetter() { } func (field *Field) setupNewValuePool() { - var fieldValue = reflect.New(field.FieldType).Interface() if field.Serializer != nil { field.NewValuePool = &sync.Pool{ New: func() interface{} { @@ -942,31 +941,9 @@ func (field *Field) setupNewValuePool() { } }, } - } else if _, ok := fieldValue.(sql.Scanner); !ok { - field.setupDefaultNewValuePool() } if field.NewValuePool == nil { field.NewValuePool = poolInitializer(reflect.PtrTo(field.IndirectFieldType)) } } - -func (field *Field) setupDefaultNewValuePool() { - // set default NewValuePool - switch field.IndirectFieldType.Kind() { - case reflect.String: - field.NewValuePool = stringPool - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - field.NewValuePool = intPool - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - field.NewValuePool = uintPool - case reflect.Float32, reflect.Float64: - field.NewValuePool = floatPool - case reflect.Bool: - field.NewValuePool = boolPool - default: - if field.IndirectFieldType == TimeReflectType { - field.NewValuePool = timePool - } - } -} diff --git a/schema/pool.go b/schema/pool.go index f5c73153..fa62fe22 100644 --- a/schema/pool.go +++ b/schema/pool.go @@ -3,54 +3,11 @@ package schema import ( "reflect" "sync" - "time" ) // sync pools var ( - normalPool sync.Map - stringPool = &sync.Pool{ - New: func() interface{} { - var v string - ptrV := &v - return &ptrV - }, - } - intPool = &sync.Pool{ - New: func() interface{} { - var v int64 - ptrV := &v - return &ptrV - }, - } - uintPool = &sync.Pool{ - New: func() interface{} { - var v uint64 - ptrV := &v - return &ptrV - }, - } - floatPool = &sync.Pool{ - New: func() interface{} { - var v float64 - ptrV := &v - return &ptrV - }, - } - boolPool = &sync.Pool{ - New: func() interface{} { - var v bool - ptrV := &v - return &ptrV - }, - } - timePool = &sync.Pool{ - New: func() interface{} { - var v time.Time - ptrV := &v - return &ptrV - }, - } + normalPool sync.Map poolInitializer = func(reflectType reflect.Type) FieldNewValuePool { v, _ := normalPool.LoadOrStore(reflectType, &sync.Pool{ New: func() interface{} { diff --git a/tests/go.mod b/tests/go.mod index fc6600b7..3ac4633e 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -5,15 +5,14 @@ go 1.14 require ( github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect github.com/google/uuid v1.3.0 - github.com/jackc/pgx/v4 v4.15.0 // indirect github.com/jinzhu/now v1.1.5 - github.com/lib/pq v1.10.4 + github.com/lib/pq v1.10.5 github.com/mattn/go-sqlite3 v1.14.12 // indirect - golang.org/x/crypto v0.0.0-20220331220935-ae2d96664a29 // indirect - gorm.io/driver/mysql v1.3.2 - gorm.io/driver/postgres v1.3.1 + golang.org/x/crypto v0.0.0-20220408190544-5352b0902921 // indirect + gorm.io/driver/mysql v1.3.3 + gorm.io/driver/postgres v1.3.4 gorm.io/driver/sqlite v1.3.1 - gorm.io/driver/sqlserver v1.3.1 + gorm.io/driver/sqlserver v1.3.2 gorm.io/gorm v1.23.3 ) From 74e07b049c446bd0f1102c9f7c164558648850bd Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 11 Apr 2022 22:07:40 +0800 Subject: [PATCH 68/87] Serializer unixtime support ptr of int --- schema/serializer.go | 8 ++++---- tests/serializer_test.go | 3 +++ 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/schema/serializer.go b/schema/serializer.go index 09da6d9e..758a6421 100644 --- a/schema/serializer.go +++ b/schema/serializer.go @@ -108,8 +108,8 @@ type UnixSecondSerializer struct { // Scan implements serializer interface func (UnixSecondSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) (err error) { t := sql.NullTime{} - if err = t.Scan(dbValue); err == nil { - err = field.Set(ctx, dst, t.Time) + if err = t.Scan(dbValue); err == nil && t.Valid { + err = field.Set(ctx, dst, t.Time.Unix()) } return @@ -118,8 +118,8 @@ func (UnixSecondSerializer) Scan(ctx context.Context, field *Field, dst reflect. // Value implements serializer interface func (UnixSecondSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (result interface{}, err error) { switch v := fieldValue.(type) { - case int64, int, uint, uint64, int32, uint32, int16, uint16: - result = time.Unix(reflect.ValueOf(v).Int(), 0) + case int64, int, uint, uint64, int32, uint32, int16, uint16, *int64, *int, *uint, *uint64, *int32, *uint32, *int16, *uint16: + result = time.Unix(reflect.Indirect(reflect.ValueOf(v)).Int(), 0) default: err = fmt.Errorf("invalid field type %#v for UnixSecondSerializer, only int, uint supported", v) } diff --git a/tests/serializer_test.go b/tests/serializer_test.go index ce60280e..ee14841a 100644 --- a/tests/serializer_test.go +++ b/tests/serializer_test.go @@ -21,6 +21,7 @@ type SerializerStruct struct { Contracts map[string]interface{} `gorm:"serializer:json"` JobInfo Job `gorm:"type:bytes;serializer:gob"` CreatedTime int64 `gorm:"serializer:unixtime;type:time"` // store time in db, use int as field type + UpdatedTime *int64 `gorm:"serializer:unixtime;type:time"` // store time in db, use int as field type EncryptedString EncryptedString } @@ -58,6 +59,7 @@ func TestSerializer(t *testing.T) { } createdAt := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC) + updatedAt := createdAt.Unix() data := SerializerStruct{ Name: []byte("jinzhu"), @@ -65,6 +67,7 @@ func TestSerializer(t *testing.T) { Contracts: map[string]interface{}{"name": "jinzhu", "age": 10}, EncryptedString: EncryptedString("pass"), CreatedTime: createdAt.Unix(), + UpdatedTime: &updatedAt, JobInfo: Job{ Title: "programmer", Number: 9920, From 6aa6d37fc47a433510ac05e2f01eb33e57d7cb6c Mon Sep 17 00:00:00 2001 From: Filippo Del Moro Date: Wed, 13 Apr 2022 09:47:04 +0200 Subject: [PATCH 69/87] Fix scanIntoStruct (#5241) * Reproduces error case * Fix scanIntoStruct Co-authored-by: Filippo Del Moro --- scan.go | 2 +- tests/joins_test.go | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/scan.go b/scan.go index 2ce6bd28..ad3734d8 100644 --- a/scan.go +++ b/scan.go @@ -74,7 +74,7 @@ func (db *DB) scanIntoStruct(rows Rows, reflectValue reflect.Value, values []int relValue := joinFields[idx][0].ReflectValueOf(db.Statement.Context, reflectValue) if relValue.Kind() == reflect.Ptr && relValue.IsNil() { if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() { - return + continue } relValue.Set(reflect.New(relValue.Type().Elem())) diff --git a/tests/joins_test.go b/tests/joins_test.go index bb5352ef..4908e5ba 100644 --- a/tests/joins_test.go +++ b/tests/joins_test.go @@ -10,12 +10,12 @@ import ( ) func TestJoins(t *testing.T) { - user := *GetUser("joins-1", Config{Company: true, Manager: true, Account: true}) + user := *GetUser("joins-1", Config{Company: true, Manager: true, Account: true, NamedPet: false}) DB.Create(&user) var user2 User - if err := DB.Joins("Company").Joins("Manager").Joins("Account").First(&user2, "users.name = ?", user.Name).Error; err != nil { + if err := DB.Joins("NamedPet").Joins("Company").Joins("Manager").Joins("Account").First(&user2, "users.name = ?", user.Name).Error; err != nil { t.Fatalf("Failed to load with joins, got error: %v", err) } From a65912c5887f850f6262dca68ca8d0dc10ca1bcc Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Wed, 13 Apr 2022 15:52:07 +0800 Subject: [PATCH 70/87] fix: FirstOrCreate RowsAffected (#5250) --- finisher_api.go | 3 +++ tests/create_test.go | 14 ++++++++++++++ 2 files changed, 17 insertions(+) diff --git a/finisher_api.go b/finisher_api.go index 5e4c3c5a..d35456a6 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -326,6 +326,9 @@ func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) { } return tx.Model(dest).Updates(assigns) + } else { + // can not use Find RowsAffected + tx.RowsAffected = 0 } } return tx diff --git a/tests/create_test.go b/tests/create_test.go index 2b23d440..3730172f 100644 --- a/tests/create_test.go +++ b/tests/create_test.go @@ -526,3 +526,17 @@ func TestCreateNilPointer(t *testing.T) { t.Fatalf("it is not ErrInvalidValue") } } + +func TestFirstOrCreateRowsAffected(t *testing.T) { + user := User{Name: "TestFirstOrCreateRowsAffected"} + + res := DB.FirstOrCreate(&user, "name = ?", user.Name) + if res.Error != nil || res.RowsAffected != 1 { + t.Fatalf("first or create rows affect err:%v rows:%d", res.Error, res.RowsAffected) + } + + res = DB.FirstOrCreate(&user, "name = ?", user.Name) + if res.Error != nil || res.RowsAffected != 0 { + t.Fatalf("first or create rows affect err:%v rows:%d", res.Error, res.RowsAffected) + } +} From 771cbed755b0b61c9b5c00eea54c92b7774a17fc Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 13 Apr 2022 15:52:40 +0800 Subject: [PATCH 71/87] chore(deps): bump actions/stale from 4 to 5 (#5244) Bumps [actions/stale](https://github.com/actions/stale) from 4 to 5. - [Release notes](https://github.com/actions/stale/releases) - [Changelog](https://github.com/actions/stale/blob/main/CHANGELOG.md) - [Commits](https://github.com/actions/stale/compare/v4...v5) --- updated-dependencies: - dependency-name: actions/stale dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/invalid_question.yml | 2 +- .github/workflows/missing_playground.yml | 2 +- .github/workflows/stale.yml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/invalid_question.yml b/.github/workflows/invalid_question.yml index 327a70f6..aa1812d4 100644 --- a/.github/workflows/invalid_question.yml +++ b/.github/workflows/invalid_question.yml @@ -16,7 +16,7 @@ jobs: ACTIONS_STEP_DEBUG: true steps: - name: Close Stale Issues - uses: actions/stale@v4 + uses: actions/stale@v5 with: repo-token: ${{ secrets.GITHUB_TOKEN }} stale-issue-message: "This issue has been marked as invalid question, please give more information by following the `Question` template, if you believe there is a bug of GORM, please create a pull request that could reproduce the issue on [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground), the issue will be closed in 30 days if no further activity occurs. most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨" diff --git a/.github/workflows/missing_playground.yml b/.github/workflows/missing_playground.yml index 15d3850f..c3c92beb 100644 --- a/.github/workflows/missing_playground.yml +++ b/.github/workflows/missing_playground.yml @@ -16,7 +16,7 @@ jobs: ACTIONS_STEP_DEBUG: true steps: - name: Close Stale Issues - uses: actions/stale@v4 + uses: actions/stale@v5 with: repo-token: ${{ secrets.GITHUB_TOKEN }} stale-issue-message: "The issue has been automatically marked as stale as it missing playground pull request link, which is important to help others understand your issue effectively and make sure the issue hasn't been fixed on latest master, checkout [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground) for details. it will be closed in 30 days if no further activity occurs. if you are asking question, please use the `Question` template, most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨" diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index c5e0d7ab..af8d3636 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -16,7 +16,7 @@ jobs: ACTIONS_STEP_DEBUG: true steps: - name: Close Stale Issues - uses: actions/stale@v4 + uses: actions/stale@v5 with: repo-token: ${{ secrets.GITHUB_TOKEN }} stale-issue-message: "This issue has been automatically marked as stale because it has been open 360 days with no activity. Remove stale label or comment or this will be closed in 180 days" From ce53ea53ee064d57c8a23eb4c7b5f2deed0eb410 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 13 Apr 2022 15:53:12 +0800 Subject: [PATCH 72/87] chore(deps): bump actions/setup-go from 2 to 3 (#5243) Bumps [actions/setup-go](https://github.com/actions/setup-go) from 2 to 3. - [Release notes](https://github.com/actions/setup-go/releases) - [Commits](https://github.com/actions/setup-go/compare/v2...v3) --- updated-dependencies: - dependency-name: actions/setup-go dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/tests.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 8bfb2332..b97da3f4 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -22,7 +22,7 @@ jobs: steps: - name: Set up Go 1.x - uses: actions/setup-go@v2 + uses: actions/setup-go@v3 with: go-version: ${{ matrix.go }} @@ -65,7 +65,7 @@ jobs: steps: - name: Set up Go 1.x - uses: actions/setup-go@v2 + uses: actions/setup-go@v3 with: go-version: ${{ matrix.go }} @@ -109,7 +109,7 @@ jobs: steps: - name: Set up Go 1.x - uses: actions/setup-go@v2 + uses: actions/setup-go@v3 with: go-version: ${{ matrix.go }} @@ -152,7 +152,7 @@ jobs: steps: - name: Set up Go 1.x - uses: actions/setup-go@v2 + uses: actions/setup-go@v3 with: go-version: ${{ matrix.go }} From d421c67ef59259dc65737a639bee75b568ad5c17 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 14 Apr 2022 10:51:39 +0800 Subject: [PATCH 73/87] Remove ErrRecordNotFound error from log when using Save --- finisher_api.go | 2 +- tests/go.mod | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/finisher_api.go b/finisher_api.go index d35456a6..cbe927bf 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -105,7 +105,7 @@ func (db *DB) Save(value interface{}) (tx *DB) { if tx.Error == nil && tx.RowsAffected == 0 && !tx.DryRun && !selectedUpdate { result := reflect.New(tx.Statement.Schema.ModelType).Interface() - if err := tx.Session(&Session{}).Take(result).Error; errors.Is(err, ErrRecordNotFound) { + if result := tx.Session(&Session{}).Limit(1).Find(result); result.RowsAffected == 0 { return tx.Create(value) } } diff --git a/tests/go.mod b/tests/go.mod index 3ac4633e..0a3f85f9 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -8,7 +8,7 @@ require ( github.com/jinzhu/now v1.1.5 github.com/lib/pq v1.10.5 github.com/mattn/go-sqlite3 v1.14.12 // indirect - golang.org/x/crypto v0.0.0-20220408190544-5352b0902921 // indirect + golang.org/x/crypto v0.0.0-20220411220226-7b82a4e95df4 // indirect gorm.io/driver/mysql v1.3.3 gorm.io/driver/postgres v1.3.4 gorm.io/driver/sqlite v1.3.1 From e0ed3ce400c8cb774ad03bd6c1a5028e6c425988 Mon Sep 17 00:00:00 2001 From: ZhangShenao <15201440436@163.com> Date: Thu, 14 Apr 2022 20:32:57 +0800 Subject: [PATCH 74/87] fix spelling mistake (#5256) Co-authored-by: Shenao Zhang --- callbacks/helper.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/callbacks/helper.go b/callbacks/helper.go index 71b67de5..ae9fd8c5 100644 --- a/callbacks/helper.go +++ b/callbacks/helper.go @@ -125,7 +125,7 @@ func checkMissingWhereConditions(db *gorm.DB) { type visitMap = map[reflect.Value]bool // Check if circular values, return true if loaded -func loadOrStoreVisitMap(vistMap *visitMap, v reflect.Value) (loaded bool) { +func loadOrStoreVisitMap(visitMap *visitMap, v reflect.Value) (loaded bool) { if v.Kind() == reflect.Ptr { v = v.Elem() } @@ -134,17 +134,17 @@ func loadOrStoreVisitMap(vistMap *visitMap, v reflect.Value) (loaded bool) { case reflect.Slice, reflect.Array: loaded = true for i := 0; i < v.Len(); i++ { - if !loadOrStoreVisitMap(vistMap, v.Index(i)) { + if !loadOrStoreVisitMap(visitMap, v.Index(i)) { loaded = false } } case reflect.Struct, reflect.Interface: if v.CanAddr() { p := v.Addr() - if _, ok := (*vistMap)[p]; ok { + if _, ok := (*visitMap)[p]; ok { return true } - (*vistMap)[p] = true + (*visitMap)[p] = true } } From b49ae84780b212f2460938c74ee41a43a46b1834 Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Sun, 17 Apr 2022 09:58:33 +0800 Subject: [PATCH 75/87] fix: FindInBatches with offset limit (#5255) * fix: FindInBatches with offset limit * fix: break first * fix: FindInBatches Limit zero --- finisher_api.go | 24 ++++++++++++++++++ tests/query_test.go | 62 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 86 insertions(+) diff --git a/finisher_api.go b/finisher_api.go index cbe927bf..0bd8f7d9 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -181,6 +181,21 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat batch int ) + // user specified offset or limit + var totalSize int + if c, ok := tx.Statement.Clauses["LIMIT"]; ok { + if limit, ok := c.Expression.(clause.Limit); ok { + totalSize = limit.Limit + + if totalSize > 0 && batchSize > totalSize { + batchSize = totalSize + } + + // reset to offset to 0 in next batch + tx = tx.Offset(-1).Session(&Session{}) + } + } + for { result := queryDB.Limit(batchSize).Find(dest) rowsAffected += result.RowsAffected @@ -196,6 +211,15 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat break } + if totalSize > 0 { + if totalSize <= int(rowsAffected) { + break + } + if totalSize/batchSize == batch { + batchSize = totalSize % batchSize + } + } + // Optimize for-break resultsValue := reflect.Indirect(reflect.ValueOf(dest)) if result.Statement.Schema.PrioritizedPrimaryField == nil { diff --git a/tests/query_test.go b/tests/query_test.go index af2b8d4b..f66cf83a 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -292,6 +292,68 @@ func TestFindInBatches(t *testing.T) { } } +func TestFindInBatchesWithOffsetLimit(t *testing.T) { + users := []User{ + *GetUser("find_in_batches_with_offset_limit", Config{}), + *GetUser("find_in_batches_with_offset_limit", Config{}), + *GetUser("find_in_batches_with_offset_limit", Config{}), + *GetUser("find_in_batches_with_offset_limit", Config{}), + *GetUser("find_in_batches_with_offset_limit", Config{}), + *GetUser("find_in_batches_with_offset_limit", Config{}), + *GetUser("find_in_batches_with_offset_limit", Config{}), + *GetUser("find_in_batches_with_offset_limit", Config{}), + *GetUser("find_in_batches_with_offset_limit", Config{}), + *GetUser("find_in_batches_with_offset_limit", Config{}), + } + + DB.Create(&users) + + var ( + sub, results []User + lastBatch int + ) + + // offset limit + if result := DB.Offset(3).Limit(5).Where("name = ?", users[0].Name).FindInBatches(&sub, 2, func(tx *gorm.DB, batch int) error { + results = append(results, sub...) + lastBatch = batch + return nil + }); result.Error != nil || result.RowsAffected != 5 { + t.Errorf("Failed to batch find, got error %v, rows affected: %v", result.Error, result.RowsAffected) + } + if lastBatch != 3 { + t.Fatalf("incorrect last batch, expected: %v, got: %v", 3, lastBatch) + } + + targetUsers := users[3:8] + for i := 0; i < len(targetUsers); i++ { + AssertEqual(t, results[i], targetUsers[i]) + } + + var sub1 []User + // limit < batchSize + if result := DB.Limit(5).Where("name = ?", users[0].Name).FindInBatches(&sub1, 10, func(tx *gorm.DB, batch int) error { + return nil + }); result.Error != nil || result.RowsAffected != 5 { + t.Errorf("Failed to batch find, got error %v, rows affected: %v", result.Error, result.RowsAffected) + } + + var sub2 []User + // only offset + if result := DB.Offset(3).Where("name = ?", users[0].Name).FindInBatches(&sub2, 2, func(tx *gorm.DB, batch int) error { + return nil + }); result.Error != nil || result.RowsAffected != 7 { + t.Errorf("Failed to batch find, got error %v, rows affected: %v", result.Error, result.RowsAffected) + } + + var sub3 []User + if result := DB.Limit(4).Where("name = ?", users[0].Name).FindInBatches(&sub3, 2, func(tx *gorm.DB, batch int) error { + return nil + }); result.Error != nil || result.RowsAffected != 4 { + t.Errorf("Failed to batch find, got error %v, rows affected: %v", result.Error, result.RowsAffected) + } +} + func TestFindInBatchesWithError(t *testing.T) { if name := DB.Dialector.Name(); name == "sqlserver" { t.Skip("skip sqlserver due to it will raise data race for invalid sql") From 88c26b62ee63863932e001be21e05a4ef43d03c2 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 20 Apr 2022 17:21:38 +0800 Subject: [PATCH 76/87] Support Scopes in group conditions --- statement.go | 4 ++++ tests/sql_builder_test.go | 15 +++++++++++++++ 2 files changed, 19 insertions(+) diff --git a/statement.go b/statement.go index 9fcee09c..d0c691d8 100644 --- a/statement.go +++ b/statement.go @@ -312,6 +312,10 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] case clause.Expression: conds = append(conds, v) case *DB: + for _, scope := range v.Statement.scopes { + v = scope(v) + } + if cs, ok := v.Statement.Clauses["WHERE"]; ok { if where, ok := cs.Expression.(clause.Where); ok { if len(where.Exprs) == 1 { diff --git a/tests/sql_builder_test.go b/tests/sql_builder_test.go index a7630271..a9b920dc 100644 --- a/tests/sql_builder_test.go +++ b/tests/sql_builder_test.go @@ -243,6 +243,21 @@ func TestGroupConditions(t *testing.T) { if !strings.HasSuffix(result, expects) { t.Errorf("expects: %v, got %v", expects, result) } + + stmt2 := dryRunDB.Where( + DB.Scopes(NameIn1And2), + ).Or( + DB.Where("pizza = ?", "hawaiian").Where("size = ?", "xlarge"), + ).Find(&Pizza{}).Statement + + execStmt2 := dryRunDB.Exec(`WHERE name in ? OR (pizza = ? AND size = ?)`, []string{"ScopeUser1", "ScopeUser2"}, "hawaiian", "xlarge").Statement + + result2 := DB.Dialector.Explain(stmt2.SQL.String(), stmt2.Vars...) + expects2 := DB.Dialector.Explain(execStmt2.SQL.String(), execStmt2.Vars...) + + if !strings.HasSuffix(result2, expects2) { + t.Errorf("expects: %v, got %v", expects2, result2) + } } func TestCombineStringConditions(t *testing.T) { From 395606ac7ce6c1fcd9bd9c79c16b73cb1bc13bc8 Mon Sep 17 00:00:00 2001 From: glebarez <47985861+glebarez@users.noreply.github.com> Date: Fri, 22 Apr 2022 06:19:33 +0300 Subject: [PATCH 77/87] fix missing error-check in AutoMigrate (#5283) --- migrator/migrator.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/migrator/migrator.go b/migrator/migrator.go index a50bb3ff..93f4c5d0 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -99,7 +99,10 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { } } else { if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) { - columnTypes, _ := m.DB.Migrator().ColumnTypes(value) + columnTypes, err := m.DB.Migrator().ColumnTypes(value) + if err != nil { + return err + } for _, dbName := range stmt.Schema.DBNames { field := stmt.Schema.FieldsByDBName[dbName] From 9b80fe9e96e6d9132f935a944a150777a3ffdf03 Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Sun, 24 Apr 2022 09:08:52 +0800 Subject: [PATCH 78/87] fix: stmt.Changed zero value filed behavior (#5281) * fix: stmt.Changed zero value filed behavior * chore: rename var --- statement.go | 9 ++++++--- tests/hooks_test.go | 10 ++++++++-- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/statement.go b/statement.go index d0c691d8..ed3e8716 100644 --- a/statement.go +++ b/statement.go @@ -609,10 +609,10 @@ func (stmt *Statement) Changed(fields ...string) bool { changed := func(field *schema.Field) bool { fieldValue, _ := field.ValueOf(stmt.Context, modelValue) if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { - if v, ok := stmt.Dest.(map[string]interface{}); ok { - if fv, ok := v[field.Name]; ok { + if mv, mok := stmt.Dest.(map[string]interface{}); mok { + if fv, ok := mv[field.Name]; ok { return !utils.AssertEqual(fv, fieldValue) - } else if fv, ok := v[field.DBName]; ok { + } else if fv, ok := mv[field.DBName]; ok { return !utils.AssertEqual(fv, fieldValue) } } else { @@ -622,6 +622,9 @@ func (stmt *Statement) Changed(fields ...string) bool { } changedValue, zero := field.ValueOf(stmt.Context, destValue) + if v { + return !utils.AssertEqual(changedValue, fieldValue) + } return !zero && !utils.AssertEqual(changedValue, fieldValue) } } diff --git a/tests/hooks_test.go b/tests/hooks_test.go index 0e6ab2fe..20e8dc18 100644 --- a/tests/hooks_test.go +++ b/tests/hooks_test.go @@ -375,13 +375,19 @@ func TestSetColumn(t *testing.T) { t.Errorf("invalid data after update, got %+v", product) } + // Code changed, price should changed + DB.Model(&product).Select("Name", "Code", "Price").Updates(Product3{Name: "Product New4", Code: ""}) + if product.Name != "Product New4" || product.Price != 320 || product.Code != "" { + t.Errorf("invalid data after update, got %+v", product) + } + DB.Model(&product).UpdateColumns(Product3{Code: "L1215"}) - if product.Price != 270 || product.Code != "L1215" { + if product.Price != 320 || product.Code != "L1215" { t.Errorf("invalid data after update, got %+v", product) } DB.Model(&product).Session(&gorm.Session{SkipHooks: true}).Updates(Product3{Code: "L1216"}) - if product.Price != 270 || product.Code != "L1216" { + if product.Price != 320 || product.Code != "L1216" { t.Errorf("invalid data after update, got %+v", product) } From 3643f856a3edeaa4db7ede87a4bc2928d2aadc09 Mon Sep 17 00:00:00 2001 From: aelmel <5629597+aelmel@users.noreply.github.com> Date: Sun, 24 Apr 2022 04:10:36 +0300 Subject: [PATCH 79/87] check for pointer to pointer value (#5278) * check for pointer to pointer value * revert to Ptr Co-authored-by: Alexei Melnic --- schema/field.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/schema/field.go b/schema/field.go index fd8b2e6a..d6df6596 100644 --- a/schema/field.go +++ b/schema/field.go @@ -528,6 +528,9 @@ func (field *Field) setupValuerAndSetter() { reflectValType := reflectV.Type() if reflectValType.AssignableTo(field.FieldType) { + if reflectV.Kind() == reflect.Ptr && reflectV.Elem().Kind() == reflect.Ptr { + reflectV = reflect.Indirect(reflectV) + } field.ReflectValueOf(ctx, value).Set(reflectV) return } else if reflectValType.ConvertibleTo(field.FieldType) { From a0cc631272f44a18597c87b7910b660df729303e Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Sun, 24 Apr 2022 12:13:27 +0800 Subject: [PATCH 80/87] test: test for postgrs serial column (#5234) * test: test for postgrs sercial column * test: only for postgres * chore: spelling mistake * test: for drop sequence --- tests/migrate_test.go | 62 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 62 insertions(+) diff --git a/tests/migrate_test.go b/tests/migrate_test.go index f72c4c08..d6a6c4db 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -574,3 +574,65 @@ func TestMigrateColumnOrder(t *testing.T) { } } } + +// https://github.com/go-gorm/gorm/issues/5047 +func TestMigrateSerialColumn(t *testing.T) { + if DB.Dialector.Name() != "postgres" { + return + } + + type Event struct { + ID uint `gorm:"primarykey"` + UID uint32 + } + + type Event1 struct { + ID uint `gorm:"primarykey"` + UID uint32 `gorm:"not null;autoIncrement"` + } + + type Event2 struct { + ID uint `gorm:"primarykey"` + UID uint16 `gorm:"not null;autoIncrement"` + } + + var err error + err = DB.Migrator().DropTable(&Event{}) + if err != nil { + t.Errorf("DropTable err:%v", err) + } + + // create sequence + err = DB.Table("events").AutoMigrate(&Event1{}) + if err != nil { + t.Errorf("AutoMigrate err:%v", err) + } + + // delete sequence + err = DB.Table("events").AutoMigrate(&Event{}) + if err != nil { + t.Errorf("AutoMigrate err:%v", err) + } + + // update sequence + err = DB.Table("events").AutoMigrate(&Event1{}) + if err != nil { + t.Errorf("AutoMigrate err:%v", err) + } + err = DB.Table("events").AutoMigrate(&Event2{}) + if err != nil { + t.Errorf("AutoMigrate err:%v", err) + } + + DB.Table("events").Save(&Event2{}) + DB.Table("events").Save(&Event2{}) + DB.Table("events").Save(&Event2{}) + + events := make([]*Event, 0) + DB.Table("events").Find(&events) + + AssertEqual(t, 3, len(events)) + for _, v := range events { + AssertEqual(t, v.ID, v.UID) + } +} From 0211ac91a2e2cbde5d6212e5f74a7344cb9795db Mon Sep 17 00:00:00 2001 From: Chiung-Ming Huang Date: Mon, 25 Apr 2022 11:39:23 +0800 Subject: [PATCH 81/87] index: add composite id (#5269) * index: add composite id * index: add test cases of composite id * index: improve the comments for the test cases of composite id --- schema/index.go | 26 ++++++++++++++++--- schema/index_test.go | 60 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 83 insertions(+), 3 deletions(-) diff --git a/schema/index.go b/schema/index.go index 16d096b7..5003c742 100644 --- a/schema/index.go +++ b/schema/index.go @@ -1,6 +1,7 @@ package schema import ( + "fmt" "sort" "strconv" "strings" @@ -31,7 +32,12 @@ func (schema *Schema) ParseIndexes() map[string]Index { for _, field := range schema.Fields { if field.TagSettings["INDEX"] != "" || field.TagSettings["UNIQUEINDEX"] != "" { - for _, index := range parseFieldIndexes(field) { + fieldIndexes, err := parseFieldIndexes(field) + if err != nil { + schema.err = err + break + } + for _, index := range fieldIndexes { idx := indexes[index.Name] idx.Name = index.Name if idx.Class == "" { @@ -82,7 +88,7 @@ func (schema *Schema) LookIndex(name string) *Index { return nil } -func parseFieldIndexes(field *Field) (indexes []Index) { +func parseFieldIndexes(field *Field) (indexes []Index, err error) { for _, value := range strings.Split(field.Tag.Get("gorm"), ";") { if value != "" { v := strings.Split(value, ":") @@ -106,7 +112,20 @@ func parseFieldIndexes(field *Field) (indexes []Index) { } if name == "" { - name = field.Schema.namer.IndexName(field.Schema.Table, field.Name) + subName := field.Name + const key = "COMPOSITE" + if composite, found := settings[key]; found { + if len(composite) == 0 || composite == key { + err = fmt.Errorf( + "The composite tag of %s.%s cannot be empty", + field.Schema.Name, + field.Name) + return + } + subName = composite + } + name = field.Schema.namer.IndexName( + field.Schema.Table, subName) } if (k == "UNIQUEINDEX") || settings["UNIQUE"] != "" { @@ -138,5 +157,6 @@ func parseFieldIndexes(field *Field) (indexes []Index) { } } + err = nil return } diff --git a/schema/index_test.go b/schema/index_test.go index 3c4582bb..1fe31cc1 100644 --- a/schema/index_test.go +++ b/schema/index_test.go @@ -19,6 +19,36 @@ type UserIndex struct { OID int64 `gorm:"index:idx_id;index:idx_oid,unique"` MemberNumber string `gorm:"index:idx_id,priority:1"` Name7 string `gorm:"index:type"` + + // Composite Index: Flattened structure. + Data0A string `gorm:"index:,composite:comp_id0"` + Data0B string `gorm:"index:,composite:comp_id0"` + + // Composite Index: Nested structure. + Data1A string `gorm:"index:,composite:comp_id1"` + CompIdxLevel1C + + // Composite Index: Unique and priority. + Data2A string `gorm:"index:,unique,composite:comp_id2,priority:2"` + CompIdxLevel2C +} + +type CompIdxLevel1C struct { + CompIdxLevel1B + Data1C string `gorm:"index:,composite:comp_id1"` +} + +type CompIdxLevel1B struct { + Data1B string `gorm:"index:,composite:comp_id1"` +} + +type CompIdxLevel2C struct { + CompIdxLevel2B + Data2C string `gorm:"index:,unique,composite:comp_id2,priority:1"` +} + +type CompIdxLevel2B struct { + Data2B string `gorm:"index:,unique,composite:comp_id2,priority:3"` } func TestParseIndex(t *testing.T) { @@ -84,6 +114,36 @@ func TestParseIndex(t *testing.T) { Type: "", Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name7"}}}, }, + "idx_user_indices_comp_id0": { + Name: "idx_user_indices_comp_id0", + Type: "", + Fields: []schema.IndexOption{{ + Field: &schema.Field{Name: "Data0A"}, + }, { + Field: &schema.Field{Name: "Data0B"}, + }}, + }, + "idx_user_indices_comp_id1": { + Name: "idx_user_indices_comp_id1", + Fields: []schema.IndexOption{{ + Field: &schema.Field{Name: "Data1A"}, + }, { + Field: &schema.Field{Name: "Data1B"}, + }, { + Field: &schema.Field{Name: "Data1C"}, + }}, + }, + "idx_user_indices_comp_id2": { + Name: "idx_user_indices_comp_id2", + Class: "UNIQUE", + Fields: []schema.IndexOption{{ + Field: &schema.Field{Name: "Data2C"}, + }, { + Field: &schema.Field{Name: "Data2A"}, + }, { + Field: &schema.Field{Name: "Data2B"}, + }}, + }, } indices := user.ParseIndexes() From 6a6dfdae72574e931ea4f0737637308ef2c34b8f Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 26 Apr 2022 17:16:48 +0800 Subject: [PATCH 82/87] Refactor FirstOrCreate, FirstOrInit --- finisher_api.go | 24 ++++++++++++------------ tests/go.mod | 7 +++---- 2 files changed, 15 insertions(+), 16 deletions(-) diff --git a/finisher_api.go b/finisher_api.go index 0bd8f7d9..663d532b 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -290,7 +290,7 @@ func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, }) - if tx = queryTx.Find(dest, conds...); queryTx.RowsAffected == 0 { + if tx = queryTx.Find(dest, conds...); tx.RowsAffected == 0 { if c, ok := tx.Statement.Clauses["WHERE"]; ok { if where, ok := c.Expression.(clause.Where); ok { tx.assignInterfacesToValue(where.Exprs) @@ -312,25 +312,26 @@ func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { // FirstOrCreate gets the first matched record or create a new one with given conditions (only works with struct, map conditions) func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) { - queryTx := db.Limit(1).Order(clause.OrderByColumn{ + tx = db.getInstance() + queryTx := db.Session(&Session{}).Limit(1).Order(clause.OrderByColumn{ Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, }) - if tx = queryTx.Find(dest, conds...); tx.Error == nil { - if tx.RowsAffected == 0 { - if c, ok := tx.Statement.Clauses["WHERE"]; ok { + if result := queryTx.Find(dest, conds...); result.Error == nil { + if result.RowsAffected == 0 { + if c, ok := result.Statement.Clauses["WHERE"]; ok { if where, ok := c.Expression.(clause.Where); ok { - tx.assignInterfacesToValue(where.Exprs) + result.assignInterfacesToValue(where.Exprs) } } // initialize with attrs, conds - if len(tx.Statement.attrs) > 0 { - tx.assignInterfacesToValue(tx.Statement.attrs...) + if len(db.Statement.attrs) > 0 { + result.assignInterfacesToValue(db.Statement.attrs...) } // initialize with attrs, conds - if len(tx.Statement.assigns) > 0 { - tx.assignInterfacesToValue(tx.Statement.assigns...) + if len(db.Statement.assigns) > 0 { + result.assignInterfacesToValue(db.Statement.assigns...) } return tx.Create(dest) @@ -351,8 +352,7 @@ func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) { return tx.Model(dest).Updates(assigns) } else { - // can not use Find RowsAffected - tx.RowsAffected = 0 + tx.Error = result.Error } } return tx diff --git a/tests/go.mod b/tests/go.mod index 0a3f85f9..6a2cf22f 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -7,13 +7,12 @@ require ( github.com/google/uuid v1.3.0 github.com/jinzhu/now v1.1.5 github.com/lib/pq v1.10.5 - github.com/mattn/go-sqlite3 v1.14.12 // indirect golang.org/x/crypto v0.0.0-20220411220226-7b82a4e95df4 // indirect gorm.io/driver/mysql v1.3.3 - gorm.io/driver/postgres v1.3.4 - gorm.io/driver/sqlite v1.3.1 + gorm.io/driver/postgres v1.3.5 + gorm.io/driver/sqlite v1.3.2 gorm.io/driver/sqlserver v1.3.2 - gorm.io/gorm v1.23.3 + gorm.io/gorm v1.23.4 ) replace gorm.io/gorm => ../ From bd7e42ec651f66539009371675bff38645b9b6b8 Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Wed, 27 Apr 2022 21:13:48 +0800 Subject: [PATCH 83/87] fix: AutoMigrate with special table name (#5301) * fix: AutoMigrate with special table name * test: migrate with special table name --- migrator/migrator.go | 3 ++- tests/migrate_test.go | 11 +++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/migrator/migrator.go b/migrator/migrator.go index 93f4c5d0..d4989410 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -759,7 +759,8 @@ func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []i Statement: &gorm.Statement{DB: m.DB, Dest: value}, } beDependedOn := map[*schema.Schema]bool{} - if err := dep.Parse(value); err != nil { + // support for special table name + if err := dep.ParseWithSpecialTableName(value, m.DB.Statement.Table); err != nil { m.DB.Logger.Error(context.Background(), "failed to parse value %#v, got error %v", value, err) } if _, ok := parsedSchemas[dep.Statement.Schema]; ok { diff --git a/tests/migrate_test.go b/tests/migrate_test.go index d6a6c4db..6576a2bd 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -636,3 +636,14 @@ func TestMigrateSerialColumn(t *testing.T) { AssertEqual(t, v.ID, v.UID) } } + +// https://github.com/go-gorm/gorm/issues/5300 +func TestMigrateWithSpecialName(t *testing.T) { + DB.AutoMigrate(&Coupon{}) + DB.Table("coupon_product_1").AutoMigrate(&CouponProduct{}) + DB.Table("coupon_product_2").AutoMigrate(&CouponProduct{}) + + AssertEqual(t, true, DB.Migrator().HasTable("coupons")) + AssertEqual(t, true, DB.Migrator().HasTable("coupon_product_1")) + AssertEqual(t, true, DB.Migrator().HasTable("coupon_product_2")) +} From d3488ae6bcee8ccbb1e463a42a048e1958c4c90f Mon Sep 17 00:00:00 2001 From: Heliner <32272517+Heliner@users.noreply.github.com> Date: Sat, 30 Apr 2022 09:50:53 +0800 Subject: [PATCH 84/87] fix: add judge result of auto_migrate (#5306) Co-authored-by: fredhan --- tests/migrate_test.go | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 6576a2bd..28ee28cb 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -639,9 +639,19 @@ func TestMigrateSerialColumn(t *testing.T) { // https://github.com/go-gorm/gorm/issues/5300 func TestMigrateWithSpecialName(t *testing.T) { - DB.AutoMigrate(&Coupon{}) - DB.Table("coupon_product_1").AutoMigrate(&CouponProduct{}) - DB.Table("coupon_product_2").AutoMigrate(&CouponProduct{}) + var err error + err = DB.AutoMigrate(&Coupon{}) + if err != nil { + t.Fatalf("AutoMigrate err:%v", err) + } + err = DB.Table("coupon_product_1").AutoMigrate(&CouponProduct{}) + if err != nil { + t.Fatalf("AutoMigrate err:%v", err) + } + err = DB.Table("coupon_product_2").AutoMigrate(&CouponProduct{}) + if err != nil { + t.Fatalf("AutoMigrate err:%v", err) + } AssertEqual(t, true, DB.Migrator().HasTable("coupons")) AssertEqual(t, true, DB.Migrator().HasTable("coupon_product_1")) From b0104943edf50bba6072d18ca91e949ff8d4e3a2 Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Sat, 30 Apr 2022 09:57:16 +0800 Subject: [PATCH 85/87] fix: callbcak sort when using multiple plugin (#5304) --- callbacks.go | 8 +++++++- tests/callbacks_test.go | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 42 insertions(+), 1 deletion(-) diff --git a/callbacks.go b/callbacks.go index f344649e..c060ea70 100644 --- a/callbacks.go +++ b/callbacks.go @@ -246,7 +246,13 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) { sortCallback func(*callback) error ) sort.Slice(cs, func(i, j int) bool { - return cs[j].before == "*" || cs[j].after == "*" + if cs[j].before == "*" && cs[i].before != "*" { + return true + } + if cs[j].after == "*" && cs[i].after != "*" { + return true + } + return false }) for _, c := range cs { diff --git a/tests/callbacks_test.go b/tests/callbacks_test.go index 02765b8c..2bf9496b 100644 --- a/tests/callbacks_test.go +++ b/tests/callbacks_test.go @@ -38,6 +38,7 @@ func c2(*gorm.DB) {} func c3(*gorm.DB) {} func c4(*gorm.DB) {} func c5(*gorm.DB) {} +func c6(*gorm.DB) {} func TestCallbacks(t *testing.T) { type callback struct { @@ -168,3 +169,37 @@ func TestCallbacks(t *testing.T) { } } } + +func TestPluginCallbacks(t *testing.T) { + db, _ := gorm.Open(nil, nil) + createCallback := db.Callback().Create() + + createCallback.Before("*").Register("plugin_1_fn1", c1) + createCallback.After("*").Register("plugin_1_fn2", c2) + + if ok, msg := assertCallbacks(createCallback, []string{"c1", "c2"}); !ok { + t.Errorf("callbacks tests failed, got %v", msg) + } + + // plugin 2 + createCallback.Before("*").Register("plugin_2_fn1", c3) + if ok, msg := assertCallbacks(createCallback, []string{"c3", "c1", "c2"}); !ok { + t.Errorf("callbacks tests failed, got %v", msg) + } + + createCallback.After("*").Register("plugin_2_fn2", c4) + if ok, msg := assertCallbacks(createCallback, []string{"c3", "c1", "c2", "c4"}); !ok { + t.Errorf("callbacks tests failed, got %v", msg) + } + + // plugin 3 + createCallback.Before("*").Register("plugin_3_fn1", c5) + if ok, msg := assertCallbacks(createCallback, []string{"c5", "c3", "c1", "c2", "c4"}); !ok { + t.Errorf("callbacks tests failed, got %v", msg) + } + + createCallback.After("*").Register("plugin_3_fn2", c6) + if ok, msg := assertCallbacks(createCallback, []string{"c5", "c3", "c1", "c2", "c4", "c6"}); !ok { + t.Errorf("callbacks tests failed, got %v", msg) + } +} From 19b8d37ae8155667d76021e4ca3314bb571756be Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Wed, 4 May 2022 18:57:53 +0800 Subject: [PATCH 86/87] fix: preload with skip hooks (#5310) --- callbacks/query.go | 2 +- tests/hooks_test.go | 19 +++++++++++++++++-- 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/callbacks/query.go b/callbacks/query.go index fb2bb37a..26ee8c34 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -252,7 +252,7 @@ func Preload(db *gorm.DB) { for _, name := range preloadNames { if rel := preloadDB.Statement.Schema.Relationships.Relations[name]; rel != nil { - db.AddError(preload(preloadDB.Table("").Session(&gorm.Session{}), rel, append(db.Statement.Preloads[name], db.Statement.Preloads[clause.Associations]...), preloadMap[name])) + db.AddError(preload(preloadDB.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks}), rel, append(db.Statement.Preloads[name], db.Statement.Preloads[clause.Associations]...), preloadMap[name])) } else { db.AddError(fmt.Errorf("%s: %w for schema %s", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name)) } diff --git a/tests/hooks_test.go b/tests/hooks_test.go index 20e8dc18..8e964fd8 100644 --- a/tests/hooks_test.go +++ b/tests/hooks_test.go @@ -466,8 +466,9 @@ type Product4 struct { type ProductItem struct { gorm.Model - Code string - Product4ID uint + Code string + Product4ID uint + AfterFindCallTimes int } func (pi ProductItem) BeforeCreate(*gorm.DB) error { @@ -477,6 +478,11 @@ func (pi ProductItem) BeforeCreate(*gorm.DB) error { return nil } +func (pi *ProductItem) AfterFind(*gorm.DB) error { + pi.AfterFindCallTimes = pi.AfterFindCallTimes + 1 + return nil +} + func TestFailedToSaveAssociationShouldRollback(t *testing.T) { DB.Migrator().DropTable(&Product4{}, &ProductItem{}) DB.AutoMigrate(&Product4{}, &ProductItem{}) @@ -498,4 +504,13 @@ func TestFailedToSaveAssociationShouldRollback(t *testing.T) { if err := DB.First(&Product4{}, "name = ?", product.Name).Error; err != nil { t.Errorf("should find product, but got error %v", err) } + + var productWithItem Product4 + if err := DB.Session(&gorm.Session{SkipHooks: true}).Preload("Item").First(&productWithItem, "name = ?", product.Name).Error; err != nil { + t.Errorf("should find product, but got error %v", err) + } + + if productWithItem.Item.AfterFindCallTimes != 0 { + t.Fatalf("AfterFind should not be called times:%d", productWithItem.Item.AfterFindCallTimes) + } } From 373bcf7aca01ef76c8ba5c3bc1ff191b020afc7b Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Mon, 9 May 2022 10:07:18 +0800 Subject: [PATCH 87/87] fix: many2many auto migrate (#5322) * fix: many2many auto migrate * fix: uuid ossp --- schema/relationship.go | 6 ++++-- schema/utils.go | 9 +++++++++ tests/migrate_test.go | 36 ++++++++++++++++++++++++++++++++++++ 3 files changed, 49 insertions(+), 2 deletions(-) diff --git a/schema/relationship.go b/schema/relationship.go index b5100897..0aa33e51 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -235,7 +235,8 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel Name: joinFieldName, PkgPath: ownField.StructField.PkgPath, Type: ownField.StructField.Type, - Tag: removeSettingFromTag(ownField.StructField.Tag, "column", "autoincrement", "index", "unique", "uniqueindex"), + Tag: removeSettingFromTag(appendSettingFromTag(ownField.StructField.Tag, "primaryKey"), + "column", "autoincrement", "index", "unique", "uniqueindex"), }) } @@ -258,7 +259,8 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel Name: joinFieldName, PkgPath: relField.StructField.PkgPath, Type: relField.StructField.Type, - Tag: removeSettingFromTag(relField.StructField.Tag, "column", "autoincrement", "index", "unique", "uniqueindex"), + Tag: removeSettingFromTag(appendSettingFromTag(relField.StructField.Tag, "primaryKey"), + "column", "autoincrement", "index", "unique", "uniqueindex"), }) } diff --git a/schema/utils.go b/schema/utils.go index 2720c530..acf1a739 100644 --- a/schema/utils.go +++ b/schema/utils.go @@ -2,6 +2,7 @@ package schema import ( "context" + "fmt" "reflect" "regexp" "strings" @@ -59,6 +60,14 @@ func removeSettingFromTag(tag reflect.StructTag, names ...string) reflect.Struct return tag } +func appendSettingFromTag(tag reflect.StructTag, value string) reflect.StructTag { + t := tag.Get("gorm") + if strings.Contains(t, value) { + return tag + } + return reflect.StructTag(fmt.Sprintf(`gorm:"%s;%s"`, value, t)) +} + // GetRelationsValues get relations's values from a reflect value func GetRelationsValues(ctx context.Context, reflectValue reflect.Value, rels []*Relationship) (reflectResults reflect.Value) { for _, rel := range rels { diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 28ee28cb..f862eda0 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -657,3 +657,39 @@ func TestMigrateWithSpecialName(t *testing.T) { AssertEqual(t, true, DB.Migrator().HasTable("coupon_product_1")) AssertEqual(t, true, DB.Migrator().HasTable("coupon_product_2")) } + +// https://github.com/go-gorm/gorm/issues/5320 +func TestPrimarykeyID(t *testing.T) { + if DB.Dialector.Name() != "postgres" { + return + } + + type MissPKLanguage struct { + ID string `gorm:"type:uuid;default:uuid_generate_v4()"` + Name string + } + + type MissPKUser struct { + ID string `gorm:"type:uuid;default:uuid_generate_v4()"` + MissPKLanguages []MissPKLanguage `gorm:"many2many:miss_pk_user_languages;"` + } + + var err error + err = DB.Migrator().DropTable(&MissPKUser{}, &MissPKLanguage{}) + if err != nil { + t.Fatalf("DropTable err:%v", err) + } + + DB.Exec(`CREATE EXTENSION IF NOT EXISTS "uuid-ossp";`) + + err = DB.AutoMigrate(&MissPKUser{}, &MissPKLanguage{}) + if err != nil { + t.Fatalf("AutoMigrate err:%v", err) + } + + // patch + err = DB.AutoMigrate(&MissPKUser{}, &MissPKLanguage{}) + if err != nil { + t.Fatalf("AutoMigrate err:%v", err) + } +}