From a7f01bd1b22ec7131c420de62abe5f7e85573277 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 25 Jul 2023 10:47:19 +0800 Subject: [PATCH 01/26] Test Pluck with customized type --- tests/go.mod | 18 ++++++++++++++++-- tests/query_test.go | 24 ++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 2 deletions(-) diff --git a/tests/go.mod b/tests/go.mod index aebe5a06..147d0a79 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -1,6 +1,6 @@ module gorm.io/gorm/tests -go 1.16 +go 1.18 require ( github.com/google/uuid v1.3.0 @@ -10,7 +10,21 @@ require ( gorm.io/driver/postgres v1.5.3-0.20230607070428-18bc84b75196 gorm.io/driver/sqlite v1.5.2 gorm.io/driver/sqlserver v1.5.2-0.20230613072041-6e2cde390b0a - gorm.io/gorm v1.25.2-0.20230610234218-206613868439 + gorm.io/gorm v1.25.2 +) + +require ( + github.com/go-sql-driver/mysql v1.7.1 // indirect + github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect + github.com/golang-sql/sqlexp v0.1.0 // indirect + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect + github.com/jackc/pgx/v5 v5.4.2 // indirect + github.com/jinzhu/inflection v1.0.0 // indirect + github.com/mattn/go-sqlite3 v1.14.17 // indirect + github.com/microsoft/go-mssqldb v1.4.0 // indirect + golang.org/x/crypto v0.11.0 // indirect + golang.org/x/text v0.11.0 // indirect ) replace gorm.io/gorm => ../ diff --git a/tests/query_test.go b/tests/query_test.go index b6bd0736..5728378d 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -2,6 +2,7 @@ package tests_test import ( "database/sql" + "database/sql/driver" "fmt" "reflect" "regexp" @@ -658,6 +659,18 @@ func TestOrWithAllFields(t *testing.T) { } } +type Int64 int64 + +func (v Int64) Value() (driver.Value, error) { + return v - 1, nil +} + +func (f *Int64) Scan(v interface{}) error { + y := v.(int64) + *f = Int64(y + 1) + return nil +} + func TestPluck(t *testing.T) { users := []*User{ GetUser("pluck-user1", Config{}), @@ -685,6 +698,11 @@ func TestPluck(t *testing.T) { t.Errorf("got error when pluck id: %v", err) } + var ids2 []Int64 + if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Pluck("id", &ids2).Error; err != nil { + t.Errorf("got error when pluck id: %v", err) + } + for idx, name := range names { if name != users[idx].Name { t.Errorf("Unexpected result on pluck name, got %+v", names) @@ -697,6 +715,12 @@ func TestPluck(t *testing.T) { } } + for idx, id := range ids2 { + if int(id) != int(users[idx].ID+1) { + t.Errorf("Unexpected result on pluck id, got %+v", ids) + } + } + var times []time.Time if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Pluck("created_at", ×).Error; err != nil { t.Errorf("got error when pluck time: %v", err) From 1fb26ac90e1959a3fb08a4878c00f26ca5284604 Mon Sep 17 00:00:00 2001 From: Saeid Date: Fri, 4 Aug 2023 04:30:07 +0200 Subject: [PATCH 02/26] test: coverage for tabletype added (#6496) * test: coverage for tabletype added * test: tidb exclueded --------- Co-authored-by: Saeid Saeidee --- tests/helper_test.go | 4 ++++ tests/migrate_test.go | 45 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+) diff --git a/tests/helper_test.go b/tests/helper_test.go index c34e357c..1a4874ee 100644 --- a/tests/helper_test.go +++ b/tests/helper_test.go @@ -265,6 +265,10 @@ func isTiDB() bool { return os.Getenv("GORM_DIALECT") == "tidb" } +func isMysql() bool { + return os.Getenv("GORM_DIALECT") == "mysql" +} + func db(unscoped bool) *gorm.DB { if unscoped { return DB.Unscoped() diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 69f86412..849e2b7b 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -1598,3 +1598,48 @@ func TestMigrateExistingBoolColumnPG(t *testing.T) { } } } + +func TestTableType(t *testing.T) { + // currently it is only supported for mysql driver + if !isMysql() { + return + } + + const tblName = "cities" + const tblSchema = "gorm" + const tblType = "BASE TABLE" + const tblComment = "foobar comment" + + type City struct { + gorm.Model + Name string `gorm:"unique"` + } + + DB.Migrator().DropTable(&City{}) + + if err := DB.Set("gorm:table_options", fmt.Sprintf("ENGINE InnoDB COMMENT '%s'", tblComment)).AutoMigrate(&City{}); err != nil { + t.Fatalf("failed to migrate cities tables, got error: %v", err) + } + + tableType, err := DB.Table("cities").Migrator().TableType(&City{}) + if err != nil { + t.Fatalf("failed to get table type, got error %v", err) + } + + if tableType.Schema() != tblSchema { + t.Fatalf("expected tblSchema to be %s but got %s", tblSchema, tableType.Schema()) + } + + if tableType.Name() != tblName { + t.Fatalf("expected table name to be %s but got %s", tblName, tableType.Name()) + } + + if tableType.Type() != tblType { + t.Fatalf("expected table type to be %s but got %s", tblType, tableType.Type()) + } + + comment, ok := tableType.Comment() + if !ok || comment != tblComment { + t.Fatalf("expected comment %s got %s", tblComment, comment) + } +} From 193c454cf48f3e65a98abc0a09e04fe6f5d49c0a Mon Sep 17 00:00:00 2001 From: San Ye Date: Fri, 4 Aug 2023 10:31:18 +0800 Subject: [PATCH 03/26] keep float precision in ExplainSQL (#6495) --- logger/sql.go | 6 ++++-- logger/sql_test.go | 21 ++++++++++++++------- 2 files changed, 18 insertions(+), 9 deletions(-) diff --git a/logger/sql.go b/logger/sql.go index bcacc7cf..1521c1fd 100644 --- a/logger/sql.go +++ b/logger/sql.go @@ -93,8 +93,10 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a } case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: vars[idx] = utils.ToString(v) - case float64, float32: - vars[idx] = fmt.Sprintf("%.6f", v) + case float32: + vars[idx] = strconv.FormatFloat(float64(v), 'f', -1, 32) + case float64: + vars[idx] = strconv.FormatFloat(v, 'f', -1, 64) case string: vars[idx] = escaper + strings.ReplaceAll(v, escaper, "\\"+escaper) + escaper default: diff --git a/logger/sql_test.go b/logger/sql_test.go index c5b181a9..e4a72748 100644 --- a/logger/sql_test.go +++ b/logger/sql_test.go @@ -57,44 +57,51 @@ func TestExplainSQL(t *testing.T) { SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", NumericRegexp: nil, Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd}, - Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass")`, + Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass")`, }, { SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", NumericRegexp: nil, Vars: []interface{}{"jinzhu?", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd}, - Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu?", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass")`, + Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu?", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass")`, }, { SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (@p1, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9, @p10, @p11)", NumericRegexp: regexp.MustCompile(`@p(\d+)`), Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.com", myrole, pwd}, - Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`, + Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`, }, { SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ($3, $4, $1, $2, $7, $8, $5, $6, $9, $10, $11)", NumericRegexp: regexp.MustCompile(`\$(\d+)`), Vars: []interface{}{999.99, true, "jinzhu", 1, &tt, nil, []byte("12345"), tt, "w@g.com", myrole, pwd}, - Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`, + Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`, }, { SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (@p1, @p11, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9, @p10)", NumericRegexp: regexp.MustCompile(`@p(\d+)`), Vars: []interface{}{"jinzhu", 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.com", myrole, pwd, 1}, - Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`, + Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`, }, { SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", NumericRegexp: nil, Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, js, es}, - Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)), + Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)), }, { SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", NumericRegexp: nil, Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es}, - Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)), + Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)), }, + { + SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + NumericRegexp: nil, + Vars: []interface{}{"jinzhu", 1, 0.1753607109, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es}, + Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 0.1753607109, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)), + }, + } for idx, r := range results { From f47376181317ccbf08dcdb28b0c0171dc5d61fda Mon Sep 17 00:00:00 2001 From: Aayush Acharya <33954116+aayushacharya@users.noreply.github.com> Date: Fri, 4 Aug 2023 08:20:59 +0545 Subject: [PATCH 04/26] fix: added `SkipHooks` in db `getInstance()` (#6484) --- gorm.go | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/gorm.go b/gorm.go index 9297850e..203527af 100644 --- a/gorm.go +++ b/gorm.go @@ -399,11 +399,12 @@ func (db *DB) getInstance() *DB { if db.clone == 1 { // clone with new statement tx.Statement = &Statement{ - DB: tx, - ConnPool: db.Statement.ConnPool, - Context: db.Statement.Context, - Clauses: map[string]clause.Clause{}, - Vars: make([]interface{}, 0, 8), + DB: tx, + ConnPool: db.Statement.ConnPool, + Context: db.Statement.Context, + Clauses: map[string]clause.Clause{}, + Vars: make([]interface{}, 0, 8), + SkipHooks: db.Statement.SkipHooks, } } else { // with clone statement From 3c34bc2f59fd080dce5e7a829a8f178b3f4de194 Mon Sep 17 00:00:00 2001 From: fayvori <80601865+fayvori@users.noreply.github.com> Date: Mon, 7 Aug 2023 11:35:19 +0300 Subject: [PATCH 05/26] refactor: Regex description (#6507) * Mirror cleanup * Regex description --------- Co-authored-by: Ignat Belousov --- logger/sql.go | 2 ++ logger/sql_test.go | 1 - migrator/migrator.go | 9 +++++++++ tests/go.mod | 6 +++--- 4 files changed, 14 insertions(+), 4 deletions(-) diff --git a/logger/sql.go b/logger/sql.go index 1521c1fd..13e5d957 100644 --- a/logger/sql.go +++ b/logger/sql.go @@ -28,8 +28,10 @@ func isPrintable(s string) bool { return true } +// A list of Go types that should be converted to SQL primitives var convertibleTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeOf(false), reflect.TypeOf([]byte{})} +// RegEx matches only numeric values var numericPlaceholderRe = regexp.MustCompile(`\$\d+\$`) // ExplainSQL generate SQL string with given parameters, the generated SQL is expected to be used in logger, execute it might introduce a SQL injection vulnerability diff --git a/logger/sql_test.go b/logger/sql_test.go index e4a72748..d9afe393 100644 --- a/logger/sql_test.go +++ b/logger/sql_test.go @@ -101,7 +101,6 @@ func TestExplainSQL(t *testing.T) { Vars: []interface{}{"jinzhu", 1, 0.1753607109, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es}, Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 0.1753607109, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)), }, - } for idx, r := range results { diff --git a/migrator/migrator.go b/migrator/migrator.go index de60f91c..b15a43ef 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -16,8 +16,17 @@ import ( "gorm.io/gorm/schema" ) +// This regular expression seeks to find a sequence of digits (\d+) among zero or more non-digit characters (\D*), +// with a possible trailing non-digit character (\D?). + +// For example, values that can pass this regular expression are: +// - "123" +// - "abc456" +// -"%$#@789" var regFullDataType = regexp.MustCompile(`\D*(\d+)\D?`) +// TODO:? Create const vars for raw sql queries ? + // Migrator m struct type Migrator struct { Config diff --git a/tests/go.mod b/tests/go.mod index 147d0a79..7a89ee05 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -22,9 +22,9 @@ require ( github.com/jackc/pgx/v5 v5.4.2 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/mattn/go-sqlite3 v1.14.17 // indirect - github.com/microsoft/go-mssqldb v1.4.0 // indirect - golang.org/x/crypto v0.11.0 // indirect - golang.org/x/text v0.11.0 // indirect + github.com/microsoft/go-mssqldb v1.5.0 // indirect + golang.org/x/crypto v0.12.0 // indirect + golang.org/x/text v0.12.0 // indirect ) replace gorm.io/gorm => ../ From 15162afaf2a1cd1ee8c63ebc0dc14b8baa0613f7 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 10 Aug 2023 13:30:48 +0800 Subject: [PATCH 06/26] Support GetDBConnWithContext PreparedStmtDB --- prepare_stmt.go | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/prepare_stmt.go b/prepare_stmt.go index 10fefc31..9d98c86e 100644 --- a/prepare_stmt.go +++ b/prepare_stmt.go @@ -30,15 +30,19 @@ func NewPreparedStmtDB(connPool ConnPool) *PreparedStmtDB { } } -func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) { - if dbConnector, ok := db.ConnPool.(GetDBConnector); ok && dbConnector != nil { - return dbConnector.GetDBConn() - } - +func (db *PreparedStmtDB) GetDBConnWithContext(gormdb *DB) (*sql.DB, error) { if sqldb, ok := db.ConnPool.(*sql.DB); ok { return sqldb, nil } + if connector, ok := db.ConnPool.(GetDBConnectorWithContext); ok && connector != nil { + return connector.GetDBConnWithContext(gormdb) + } + + if dbConnector, ok := db.ConnPool.(GetDBConnector); ok && dbConnector != nil { + return dbConnector.GetDBConn() + } + return nil, ErrInvalidDB } @@ -54,15 +58,15 @@ func (db *PreparedStmtDB) Close() { } } -func (db *PreparedStmtDB) Reset() { - db.Mux.Lock() - defer db.Mux.Unlock() +func (sdb *PreparedStmtDB) Reset() { + sdb.Mux.Lock() + defer sdb.Mux.Unlock() - for _, stmt := range db.Stmts { + for _, stmt := range sdb.Stmts { go stmt.Close() } - db.PreparedSQL = make([]string, 0, 100) - db.Stmts = make(map[string]*Stmt) + sdb.PreparedSQL = make([]string, 0, 100) + sdb.Stmts = make(map[string]*Stmt) } func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) { From bae684b3639dff3e35d0ed330bc82c12e8282110 Mon Sep 17 00:00:00 2001 From: weih Date: Thu, 10 Aug 2023 13:34:33 +0800 Subject: [PATCH 07/26] fix(clause): when the value of clause.Eq is an empty array, the SQL should be IN (NULL) (#6503) --- clause/expression.go | 16 ++++++++++------ clause/expression_test.go | 5 +++++ 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/clause/expression.go b/clause/expression.go index 92ac7f22..8d010522 100644 --- a/clause/expression.go +++ b/clause/expression.go @@ -246,15 +246,19 @@ func (eq Eq) Build(builder Builder) { switch eq.Value.(type) { case []string, []int, []int32, []int64, []uint, []uint32, []uint64, []interface{}: - builder.WriteString(" IN (") rv := reflect.ValueOf(eq.Value) - for i := 0; i < rv.Len(); i++ { - if i > 0 { - builder.WriteByte(',') + if rv.Len() == 0 { + builder.WriteString(" IN (NULL)") + } else { + builder.WriteString(" IN (") + for i := 0; i < rv.Len(); i++ { + if i > 0 { + builder.WriteByte(',') + } + builder.AddVar(builder, rv.Index(i).Interface()) } - builder.AddVar(builder, rv.Index(i).Interface()) + builder.WriteByte(')') } - builder.WriteByte(')') default: if eqNil(eq.Value) { builder.WriteString(" IS NULL") diff --git a/clause/expression_test.go b/clause/expression_test.go index aaede61c..b997bf11 100644 --- a/clause/expression_test.go +++ b/clause/expression_test.go @@ -199,6 +199,11 @@ func TestExpression(t *testing.T) { }, ExpectedVars: []interface{}{"a", "b"}, Result: "`column-name` NOT IN (?,?)", + }, { + Expressions: []clause.Expression{ + clause.Eq{Column: column, Value: []string{}}, + }, + Result: "`column-name` IN (NULL)", }, { Expressions: []clause.Expression{ clause.Eq{Column: clause.Expr{SQL: "SUM(?)", Vars: []interface{}{clause.Column{Name: "id"}}}, Value: 100}, From fef42941ba87bff8dad5d48b057a2c2056345984 Mon Sep 17 00:00:00 2001 From: qqxhb <30866940+qqxhb@users.noreply.github.com> Date: Sat, 19 Aug 2023 21:33:31 +0800 Subject: [PATCH 08/26] feat: rm GetDBConnWithContext method (#6535) * feat: rm contextconnpool method * feat: nil --- go.sum | 2 -- gorm.go | 9 ++++++--- interfaces.go | 6 ------ prepare_stmt.go | 23 ++++++++++++++++++----- 4 files changed, 24 insertions(+), 16 deletions(-) diff --git a/go.sum b/go.sum index fb4240eb..bd6104c9 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,4 @@ github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= -github.com/jinzhu/now v1.1.4 h1:tHnRBy1i5F2Dh8BAFxqFzxKqqvezXrL2OW1TnX+Mlas= -github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= diff --git a/gorm.go b/gorm.go index 203527af..32193870 100644 --- a/gorm.go +++ b/gorm.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "fmt" + "reflect" "sort" "sync" "time" @@ -374,9 +375,11 @@ func (db *DB) AddError(err error) error { // DB returns `*sql.DB` func (db *DB) DB() (*sql.DB, error) { connPool := db.ConnPool - - if connector, ok := connPool.(GetDBConnectorWithContext); ok && connector != nil { - return connector.GetDBConnWithContext(db) + if db.Statement != nil && db.Statement.ConnPool != nil { + connPool = db.Statement.ConnPool + } + if tx, ok := connPool.(*sql.Tx); ok && tx != nil { + return (*sql.DB)(reflect.ValueOf(tx).Elem().FieldByName("db").UnsafePointer()), nil } if dbConnector, ok := connPool.(GetDBConnector); ok && dbConnector != nil { diff --git a/interfaces.go b/interfaces.go index 1950d740..3bcc3d57 100644 --- a/interfaces.go +++ b/interfaces.go @@ -77,12 +77,6 @@ type GetDBConnector interface { GetDBConn() (*sql.DB, error) } -// GetDBConnectorWithContext represents SQL db connector which takes into -// account the current database context -type GetDBConnectorWithContext interface { - GetDBConnWithContext(db *DB) (*sql.DB, error) -} - // Rows rows interface type Rows interface { Columns() ([]string, error) diff --git a/prepare_stmt.go b/prepare_stmt.go index 9d98c86e..aa944624 100644 --- a/prepare_stmt.go +++ b/prepare_stmt.go @@ -30,15 +30,11 @@ func NewPreparedStmtDB(connPool ConnPool) *PreparedStmtDB { } } -func (db *PreparedStmtDB) GetDBConnWithContext(gormdb *DB) (*sql.DB, error) { +func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) { if sqldb, ok := db.ConnPool.(*sql.DB); ok { return sqldb, nil } - if connector, ok := db.ConnPool.(GetDBConnectorWithContext); ok && connector != nil { - return connector.GetDBConnWithContext(gormdb) - } - if dbConnector, ok := db.ConnPool.(GetDBConnector); ok && dbConnector != nil { return dbConnector.GetDBConn() } @@ -131,6 +127,19 @@ func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (Conn tx, err := beginner.BeginTx(ctx, opt) return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err } + + beginner, ok := db.ConnPool.(ConnPoolBeginner) + if !ok { + return nil, ErrInvalidTransaction + } + + connPool, err := beginner.BeginTx(ctx, opt) + if err != nil { + return nil, err + } + if tx, ok := connPool.(Tx); ok { + return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, nil + } return nil, ErrInvalidTransaction } @@ -176,6 +185,10 @@ type PreparedStmtTX struct { PreparedStmtDB *PreparedStmtDB } +func (db *PreparedStmtTX) GetDBConn() (*sql.DB, error) { + return db.PreparedStmtDB.GetDBConn() +} + func (tx *PreparedStmtTX) Commit() error { if tx.Tx != nil && !reflect.ValueOf(tx.Tx).IsNil() { return tx.Tx.Commit() From 2c2089760c5a35b3884c7a949621ce0e790e7835 Mon Sep 17 00:00:00 2001 From: Heliner <32272517+Heliner@users.noreply.github.com> Date: Sat, 19 Aug 2023 21:33:57 +0800 Subject: [PATCH 09/26] add float32 test case (#6530) --- logger/sql_test.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/logger/sql_test.go b/logger/sql_test.go index d9afe393..a82fa546 100644 --- a/logger/sql_test.go +++ b/logger/sql_test.go @@ -101,6 +101,12 @@ func TestExplainSQL(t *testing.T) { Vars: []interface{}{"jinzhu", 1, 0.1753607109, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es}, Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 0.1753607109, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)), }, + { + SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + NumericRegexp: nil, + Vars: []interface{}{"jinzhu", 1, float32(999.99), true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es}, + Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)), + }, } for idx, r := range results { From 7e44f73ad3b657a86bbdc881787b03c25ab789a5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=BE=9A=E4=B8=80=E6=B6=9B?= Date: Sat, 19 Aug 2023 21:35:14 +0800 Subject: [PATCH 10/26] fix schema GetIdentityFieldValuesMap interface or ptr (#6417) Co-authored-by: uptutu --- schema/utils.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/schema/utils.go b/schema/utils.go index 65d012e5..7fdda185 100644 --- a/schema/utils.go +++ b/schema/utils.go @@ -115,6 +115,11 @@ func GetIdentityFieldValuesMap(ctx context.Context, reflectValue reflect.Value, notZero, zero bool ) + if reflectValue.Kind() == reflect.Ptr || + reflectValue.Kind() == reflect.Interface { + reflectValue = reflectValue.Elem() + } + switch reflectValue.Kind() { case reflect.Struct: results = [][]interface{}{make([]interface{}, len(fields))} From ac07543962994da4c6994ba3907417d7835a2619 Mon Sep 17 00:00:00 2001 From: Rataj Date: Sun, 20 Aug 2023 13:46:56 +0200 Subject: [PATCH 11/26] Fixed error message when dialector fails to initialize (#6509) Let's say we have a problem with DSN which leads to dialector initialize error. However DB connection is not created and for some reason line 184 error provides even though "db" doesn't exist. Previously, this code leads to: panic: runtime error: invalid memory address or nil pointer dereference This fix now doesn't attempt to close non-existant database connection and instead continues, so the proper error is shown. In my case: [error] failed to initialize database, got error default addr for network 'localhost' unknown --- gorm.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gorm.go b/gorm.go index 32193870..775cd3de 100644 --- a/gorm.go +++ b/gorm.go @@ -182,7 +182,7 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) { err = config.Dialector.Initialize(db) if err != nil { - if db, err := db.DB(); err == nil { + if db, _ := db.DB(); db != nil { _ = db.Close() } } From 653732e1c33858f5743a34f9fbfe66428d041760 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 24 Aug 2023 20:19:29 +0800 Subject: [PATCH 12/26] Update go testing versions --- .github/workflows/tests.yml | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 1191a8ea..e98a17d6 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -16,7 +16,7 @@ jobs: sqlite: strategy: matrix: - go: ['1.19', '1.18'] + go: ['1.21', '1.20', '1.19'] platform: [ubuntu-latest] # can not run in windows OS runs-on: ${{ matrix.platform }} @@ -42,7 +42,7 @@ jobs: strategy: matrix: dbversion: ['mysql:latest', 'mysql:5.7'] - go: ['1.19', '1.18'] + go: ['1.21', '1.20', '1.19'] platform: [ubuntu-latest] runs-on: ${{ matrix.platform }} @@ -85,7 +85,7 @@ jobs: strategy: matrix: dbversion: [ 'mariadb:latest' ] - go: [ '1.19', '1.18' ] + go: ['1.21', '1.20', '1.19'] platform: [ ubuntu-latest ] runs-on: ${{ matrix.platform }} @@ -128,7 +128,7 @@ jobs: strategy: matrix: dbversion: ['postgres:latest', 'postgres:13', 'postgres:12', 'postgres:11', 'postgres:10'] - go: ['1.19', '1.18'] + go: ['1.21', '1.20', '1.19'] platform: [ubuntu-latest] # can not run in macOS and Windows runs-on: ${{ matrix.platform }} @@ -170,7 +170,7 @@ jobs: sqlserver: strategy: matrix: - go: ['1.19', '1.18'] + go: ['1.21', '1.20', '1.19'] platform: [ubuntu-latest] # can not run test in macOS and windows runs-on: ${{ matrix.platform }} @@ -214,7 +214,7 @@ jobs: strategy: matrix: dbversion: [ 'v6.5.0' ] - go: [ '1.19', '1.18' ] + go: ['1.21', '1.20', '1.19'] platform: [ ubuntu-latest ] runs-on: ${{ matrix.platform }} From e57e5d8884d801caa4ce0307bcd081f7e889e514 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 27 Aug 2023 15:40:54 +0800 Subject: [PATCH 13/26] Update go.mod --- go.mod | 2 +- tests/go.mod | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/go.mod b/go.mod index 85e4242a..deb61b74 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module gorm.io/gorm -go 1.16 +go 1.18 require ( github.com/jinzhu/inflection v1.0.0 diff --git a/tests/go.mod b/tests/go.mod index 7a89ee05..aef26e3e 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -3,14 +3,14 @@ module gorm.io/gorm/tests go 1.18 require ( - github.com/google/uuid v1.3.0 + github.com/google/uuid v1.3.1 github.com/jinzhu/now v1.1.5 github.com/lib/pq v1.10.9 gorm.io/driver/mysql v1.5.2-0.20230612053416-48b6526a21f0 gorm.io/driver/postgres v1.5.3-0.20230607070428-18bc84b75196 - gorm.io/driver/sqlite v1.5.2 + gorm.io/driver/sqlite v1.5.3 gorm.io/driver/sqlserver v1.5.2-0.20230613072041-6e2cde390b0a - gorm.io/gorm v1.25.2 + gorm.io/gorm v1.25.4 ) require ( @@ -19,7 +19,7 @@ require ( github.com/golang-sql/sqlexp v0.1.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect - github.com/jackc/pgx/v5 v5.4.2 // indirect + github.com/jackc/pgx/v5 v5.4.3 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/mattn/go-sqlite3 v1.14.17 // indirect github.com/microsoft/go-mssqldb v1.5.0 // indirect From 2095d42b4c15de8d0cdaf64fd75e306bec40d9c4 Mon Sep 17 00:00:00 2001 From: Samuel N Cui Date: Mon, 9 Oct 2023 17:26:27 +0800 Subject: [PATCH 14/26] fix: sqlite dialector cannot apply `PRIMARY KEY AUTOINCREMENT` type (#6624) * fix: sqlite dialector cannot apply `PRIMARY KEY AUTOINCREMENT` type fix #4760 * feat: add auto increment test * feat: update sqlite * feat: update tests deps sqlite to v1.5.4 --- migrator/migrator.go | 2 +- tests/go.mod | 8 ++++---- tests/migrate_test.go | 42 ++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 47 insertions(+), 5 deletions(-) diff --git a/migrator/migrator.go b/migrator/migrator.go index b15a43ef..49bc9371 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -217,7 +217,7 @@ func (m Migrator) CreateTable(values ...interface{}) error { field := stmt.Schema.FieldsByDBName[dbName] if !field.IgnoreMigration { createTableSQL += "? ?" - hasPrimaryKeyInDataType = hasPrimaryKeyInDataType || strings.Contains(strings.ToUpper(string(field.DataType)), "PRIMARY KEY") + hasPrimaryKeyInDataType = hasPrimaryKeyInDataType || strings.Contains(strings.ToUpper(m.DataTypeOf(field)), "PRIMARY KEY") values = append(values, clause.Column{Name: dbName}, m.DB.Migrator().FullDataTypeOf(field)) createTableSQL += "," } diff --git a/tests/go.mod b/tests/go.mod index aef26e3e..5a0aeddd 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -8,7 +8,7 @@ require ( github.com/lib/pq v1.10.9 gorm.io/driver/mysql v1.5.2-0.20230612053416-48b6526a21f0 gorm.io/driver/postgres v1.5.3-0.20230607070428-18bc84b75196 - gorm.io/driver/sqlite v1.5.3 + gorm.io/driver/sqlite v1.5.4 gorm.io/driver/sqlserver v1.5.2-0.20230613072041-6e2cde390b0a gorm.io/gorm v1.25.4 ) @@ -22,9 +22,9 @@ require ( github.com/jackc/pgx/v5 v5.4.3 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/mattn/go-sqlite3 v1.14.17 // indirect - github.com/microsoft/go-mssqldb v1.5.0 // indirect - golang.org/x/crypto v0.12.0 // indirect - golang.org/x/text v0.12.0 // indirect + github.com/microsoft/go-mssqldb v1.6.0 // indirect + golang.org/x/crypto v0.14.0 // indirect + golang.org/x/text v0.13.0 // indirect ) replace gorm.io/gorm => ../ diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 849e2b7b..cfd3e0ac 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -862,6 +862,48 @@ func TestMigrateWithSpecialName(t *testing.T) { AssertEqual(t, true, DB.Migrator().HasTable("coupon_product_2")) } +// https://github.com/go-gorm/gorm/issues/4760 +func TestMigrateAutoIncrement(t *testing.T) { + type AutoIncrementStruct struct { + ID int64 `gorm:"primarykey;autoIncrement"` + Field1 uint32 `gorm:"column:field1"` + Field2 float32 `gorm:"column:field2"` + } + + if err := DB.AutoMigrate(&AutoIncrementStruct{}); err != nil { + t.Fatalf("AutoMigrate err: %v", err) + } + + const ROWS = 10 + for idx := 0; idx < ROWS; idx++ { + if err := DB.Create(&AutoIncrementStruct{}).Error; err != nil { + t.Fatalf("create auto_increment_struct fail, err: %v", err) + } + } + + rows := make([]*AutoIncrementStruct, 0, ROWS) + if err := DB.Order("id ASC").Find(&rows).Error; err != nil { + t.Fatalf("find auto_increment_struct fail, err: %v", err) + } + + ids := make([]int64, 0, len(rows)) + for _, row := range rows { + ids = append(ids, row.ID) + } + lastID := ids[len(ids)-1] + + if err := DB.Where("id IN (?)", ids).Delete(&AutoIncrementStruct{}).Error; err != nil { + t.Fatalf("delete auto_increment_struct fail, err: %v", err) + } + + newRow := &AutoIncrementStruct{} + if err := DB.Create(newRow).Error; err != nil { + t.Fatalf("create auto_increment_struct fail, err: %v", err) + } + + AssertEqual(t, newRow.ID, lastID+1) +} + // https://github.com/go-gorm/gorm/issues/5320 func TestPrimarykeyID(t *testing.T) { if DB.Dialector.Name() != "postgres" { From 9d8a5bb208f5616638cbaad878a12d5ac73970d3 Mon Sep 17 00:00:00 2001 From: "hjwblog.com" Date: Tue, 10 Oct 2023 14:45:48 +0800 Subject: [PATCH 15/26] feat: reuse name (#6626) --- clause/expression.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/clause/expression.go b/clause/expression.go index 8d010522..3140846e 100644 --- a/clause/expression.go +++ b/clause/expression.go @@ -126,7 +126,7 @@ func (expr NamedExpr) Build(builder Builder) { for _, v := range []byte(expr.SQL) { if v == '@' && !inName { inName = true - name = []byte{} + name = name[:0] } else if v == ' ' || v == ',' || v == ')' || v == '"' || v == '\'' || v == '`' || v == '\r' || v == '\n' || v == ';' { if inName { if nv, ok := namedMap[string(name)]; ok { From 12ba285a52fb25c3422e16226666ba791f376c0b Mon Sep 17 00:00:00 2001 From: Mathias Zeller <62462901+matoubidou@users.noreply.github.com> Date: Tue, 10 Oct 2023 08:46:32 +0200 Subject: [PATCH 16/26] *datatypes.JSON in model causes panic on tx.Statement.Changed (#6611) * do not panic on nil * more explanation in comments * get things compact --- utils/utils.go | 33 +++++++++++++++++++++------------ utils/utils_test.go | 1 + 2 files changed, 22 insertions(+), 12 deletions(-) diff --git a/utils/utils.go b/utils/utils.go index ddbca60a..c8fec5b0 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -89,19 +89,28 @@ func Contains(elems []string, elem string) bool { return false } -func AssertEqual(src, dst interface{}) bool { - if !reflect.DeepEqual(src, dst) { - if valuer, ok := src.(driver.Valuer); ok { - src, _ = valuer.Value() - } - - if valuer, ok := dst.(driver.Valuer); ok { - dst, _ = valuer.Value() - } - - return reflect.DeepEqual(src, dst) +func AssertEqual(x, y interface{}) bool { + if reflect.DeepEqual(x, y) { + return true } - return true + if x == nil || y == nil { + return false + } + + xval := reflect.ValueOf(x) + yval := reflect.ValueOf(y) + if xval.Kind() == reflect.Ptr && xval.IsNil() || + yval.Kind() == reflect.Ptr && yval.IsNil() { + return false + } + + if valuer, ok := x.(driver.Valuer); ok { + x, _ = valuer.Value() + } + if valuer, ok := y.(driver.Valuer); ok { + y, _ = valuer.Value() + } + return reflect.DeepEqual(x, y) } func ToString(value interface{}) string { diff --git a/utils/utils_test.go b/utils/utils_test.go index 71eef964..d0486822 100644 --- a/utils/utils_test.go +++ b/utils/utils_test.go @@ -98,6 +98,7 @@ func TestAssertEqual(t *testing.T) { {"error not equal", errors.New("1"), errors.New("2"), false}, {"driver.Valuer equal", ModifyAt{Time: now, Valid: true}, ModifyAt{Time: now, Valid: true}, true}, {"driver.Valuer not equal", ModifyAt{Time: now, Valid: true}, ModifyAt{Time: now.Add(time.Second), Valid: true}, false}, + {"driver.Valuer equal (ptr to nil ptr)", (*ModifyAt)(nil), &ModifyAt{}, false}, } for _, test := range assertEqualTests { t.Run(test.name, func(t *testing.T) { From 8c18714462de07fa3392b99eda089f2f9e3b6042 Mon Sep 17 00:00:00 2001 From: Jeremy Quirke Date: Mon, 9 Oct 2023 23:50:29 -0700 Subject: [PATCH 17/26] Don't call MethodByName with a variable arg (#6602) Go 1.22 goes somewhat toward addressing the issue using reflect MethodByName disabling linker deadcode elimination (DCE) and the resultant large increase in binary size because the linker cannot prune unused code because it might be reached via reflection. Go Issue golang/go#62257 reduces the number of incidences of this problem by leveraging a compiler assist to avoid marking functions containing calls to MethodByName as ReflectMethods as long as the arguments are constants. An analysis of Uber Technologies code base however shows that a number of transitive imports still contain calls to MethodByName with a variable argument, including GORM. In the case of GORM, the solution we are proposing is because the number of possible methods is finite, we will "unroll" this. This demonstrably shows that GORM is not longer a problem for DCE. Before ``` % go version go version devel go1.22-2f3458a8ce Sat Sep 16 16:26:48 2023 -0700 darwin/arm64 % go test ./... -ldflags=-dumpdep 2> >(grep -i -e '->.*') gorm.io/gorm.(*Statement).BuildCondition -> gorm.io/gorm/schema.ParseWithSpecialTableName type:reflect.Value -> reflect.(*Value).Method type:reflect.Value -> reflect.(*Value).MethodByName ok gorm.io/gorm (cached) ok gorm.io/gorm/callbacks (cached) gorm.io/gorm/clause_test.BenchmarkComplexSelect -> gorm.io/gorm/schema.ParseWithSpecialTableName type:reflect.Value -> reflect.(*Value).Method type:reflect.Value -> reflect.(*Value).MethodByName ? gorm.io/gorm/migrator [no test files] ok gorm.io/gorm/clause (cached) ok gorm.io/gorm/logger (cached) gorm.io/gorm/schema_test.TestAdvancedDataTypeValuerAndSetter -> gorm.io/gorm/schema.ParseWithSpecialTableName type:reflect.Value -> reflect.(*Value).Method type:reflect.Value -> reflect.(*Value).MethodByName ? gorm.io/gorm/utils/tests [no test files] ok gorm.io/gorm/schema (cached) ok gorm.io/gorm/utils (cached) ``` After ``` %go version go version devel go1.22-2f3458a8ce Sat Sep 16 16:26:48 2023 -0700 darwin/arm64 %go test ./... -ldflags=-dumpdep 2> >(grep -i -e '->.*') ok gorm.io/gorm (cached) ok gorm.io/gorm/callbacks (cached) ? gorm.io/gorm/migrator [no test files] ? gorm.io/gorm/utils/tests [no test files] ok gorm.io/gorm/clause (cached) ok gorm.io/gorm/logger (cached) ok gorm.io/gorm/schema (cached) ok gorm.io/gorm/utils (cached) ``` --- schema/schema.go | 63 ++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 58 insertions(+), 5 deletions(-) diff --git a/schema/schema.go b/schema/schema.go index e13a5ed1..3e7459ce 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -13,6 +13,20 @@ import ( "gorm.io/gorm/logger" ) +type callbackType string + +const ( + callbackTypeBeforeCreate callbackType = "BeforeCreate" + callbackTypeBeforeUpdate callbackType = "BeforeUpdate" + callbackTypeAfterCreate callbackType = "AfterCreate" + callbackTypeAfterUpdate callbackType = "AfterUpdate" + callbackTypeBeforeSave callbackType = "BeforeSave" + callbackTypeAfterSave callbackType = "AfterSave" + callbackTypeBeforeDelete callbackType = "BeforeDelete" + callbackTypeAfterDelete callbackType = "AfterDelete" + callbackTypeAfterFind callbackType = "AfterFind" +) + // ErrUnsupportedDataType unsupported data type var ErrUnsupportedDataType = errors.New("unsupported data type") @@ -288,14 +302,20 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam } } - callbacks := []string{"BeforeCreate", "AfterCreate", "BeforeUpdate", "AfterUpdate", "BeforeSave", "AfterSave", "BeforeDelete", "AfterDelete", "AfterFind"} - for _, name := range callbacks { - if methodValue := modelValue.MethodByName(name); methodValue.IsValid() { + callbackTypes := []callbackType{ + callbackTypeBeforeCreate, callbackTypeAfterCreate, + callbackTypeBeforeUpdate, callbackTypeAfterUpdate, + callbackTypeBeforeSave, callbackTypeAfterSave, + callbackTypeBeforeDelete, callbackTypeAfterDelete, + callbackTypeAfterFind, + } + for _, cbName := range callbackTypes { + if methodValue := callBackToMethodValue(modelValue, cbName); methodValue.IsValid() { switch methodValue.Type().String() { case "func(*gorm.DB) error": // TODO hack - reflect.Indirect(reflect.ValueOf(schema)).FieldByName(name).SetBool(true) + reflect.Indirect(reflect.ValueOf(schema)).FieldByName(string(cbName)).SetBool(true) default: - logger.Default.Warn(context.Background(), "Model %v don't match %vInterface, should be `%v(*gorm.DB) error`. Please see https://gorm.io/docs/hooks.html", schema, name, name) + logger.Default.Warn(context.Background(), "Model %v don't match %vInterface, should be `%v(*gorm.DB) error`. Please see https://gorm.io/docs/hooks.html", schema, cbName, cbName) } } } @@ -349,6 +369,39 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam return schema, schema.err } +// This unrolling is needed to show to the compiler the exact set of methods +// that can be used on the modelType. +// Prior to go1.22 any use of MethodByName would cause the linker to +// abandon dead code elimination for the entire binary. +// As of go1.22 the compiler supports one special case of a string constant +// being passed to MethodByName. For enterprise customers or those building +// large binaries, this gives a significant reduction in binary size. +// https://github.com/golang/go/issues/62257 +func callBackToMethodValue(modelType reflect.Value, cbType callbackType) reflect.Value { + switch cbType { + case callbackTypeBeforeCreate: + return modelType.MethodByName(string(callbackTypeBeforeCreate)) + case callbackTypeAfterCreate: + return modelType.MethodByName(string(callbackTypeAfterCreate)) + case callbackTypeBeforeUpdate: + return modelType.MethodByName(string(callbackTypeBeforeUpdate)) + case callbackTypeAfterUpdate: + return modelType.MethodByName(string(callbackTypeAfterUpdate)) + case callbackTypeBeforeSave: + return modelType.MethodByName(string(callbackTypeBeforeSave)) + case callbackTypeAfterSave: + return modelType.MethodByName(string(callbackTypeAfterSave)) + case callbackTypeBeforeDelete: + return modelType.MethodByName(string(callbackTypeBeforeDelete)) + case callbackTypeAfterDelete: + return modelType.MethodByName(string(callbackTypeAfterDelete)) + case callbackTypeAfterFind: + return modelType.MethodByName(string(callbackTypeAfterFind)) + default: + return reflect.ValueOf(nil) + } +} + func getOrParse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) { modelType := reflect.ValueOf(dest).Type() for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr { From 1b240810106fd68f84cfe73bcacaf91a8e4ce1dd Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 10 Oct 2023 14:50:45 +0800 Subject: [PATCH 18/26] chore(deps): bump actions/checkout from 3 to 4 (#6586) Bumps [actions/checkout](https://github.com/actions/checkout) from 3 to 4. - [Release notes](https://github.com/actions/checkout/releases) - [Changelog](https://github.com/actions/checkout/blob/main/CHANGELOG.md) - [Commits](https://github.com/actions/checkout/compare/v3...v4) --- updated-dependencies: - dependency-name: actions/checkout dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/labeler.yml | 2 +- .github/workflows/reviewdog.yml | 2 +- .github/workflows/tests.yml | 12 ++++++------ 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/.github/workflows/labeler.yml b/.github/workflows/labeler.yml index 0e8aaa60..ef852765 100644 --- a/.github/workflows/labeler.yml +++ b/.github/workflows/labeler.yml @@ -11,7 +11,7 @@ jobs: name: Label issues and pull requests steps: - name: check out - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: labeler uses: jinzhu/super-labeler-action@develop diff --git a/.github/workflows/reviewdog.yml b/.github/workflows/reviewdog.yml index a6542d57..3a65f0bc 100644 --- a/.github/workflows/reviewdog.yml +++ b/.github/workflows/reviewdog.yml @@ -6,7 +6,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Check out code into the Go module directory - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: golangci-lint uses: reviewdog/action-golangci-lint@v2 diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index e98a17d6..380231b9 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -27,7 +27,7 @@ jobs: go-version: ${{ matrix.go }} - name: Check out code into the Go module directory - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: go mod package cache uses: actions/cache@v3 @@ -70,7 +70,7 @@ jobs: go-version: ${{ matrix.go }} - name: Check out code into the Go module directory - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: go mod package cache uses: actions/cache@v3 @@ -113,7 +113,7 @@ jobs: go-version: ${{ matrix.go }} - name: Check out code into the Go module directory - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: go mod package cache uses: actions/cache@v3 @@ -156,7 +156,7 @@ jobs: go-version: ${{ matrix.go }} - name: Check out code into the Go module directory - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: go mod package cache uses: actions/cache@v3 @@ -199,7 +199,7 @@ jobs: go-version: ${{ matrix.go }} - name: Check out code into the Go module directory - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: go mod package cache uses: actions/cache@v3 @@ -231,7 +231,7 @@ jobs: go-version: ${{ matrix.go }} - name: Check out code into the Go module directory - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: go mod package cache From 6bef318891b98263f3568c13093b5860245d2c52 Mon Sep 17 00:00:00 2001 From: Franco Liberali Date: Tue, 10 Oct 2023 09:03:34 +0200 Subject: [PATCH 19/26] add support for returning in sqlserver (#6585) --- tests/delete_test.go | 6 +++--- tests/update_test.go | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/delete_test.go b/tests/delete_test.go index 5cb4b91e..5d112b4e 100644 --- a/tests/delete_test.go +++ b/tests/delete_test.go @@ -206,9 +206,9 @@ func TestDeleteSliceWithAssociations(t *testing.T) { } } -// only sqlite, postgres support returning +// only sqlite, postgres, sqlserver support returning func TestSoftDeleteReturning(t *testing.T) { - if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" { + if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" && DB.Dialector.Name() != "sqlserver" { return } @@ -233,7 +233,7 @@ func TestSoftDeleteReturning(t *testing.T) { } func TestDeleteReturning(t *testing.T) { - if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" { + if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" && DB.Dialector.Name() != "sqlserver" { return } diff --git a/tests/update_test.go b/tests/update_test.go index c03d2d47..a3fb7015 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -765,9 +765,9 @@ func TestSaveWithPrimaryValue(t *testing.T) { } } -// only sqlite, postgres support returning +// only sqlite, postgres, sqlserver support returning func TestUpdateReturning(t *testing.T) { - if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" { + if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" && DB.Dialector.Name() != "sqlserver" { return } From 78e905919fc253332fb032d0f4a76e7753e437e4 Mon Sep 17 00:00:00 2001 From: gleb <47985861+glebarez@users.noreply.github.com> Date: Thu, 26 Oct 2023 06:54:15 +0300 Subject: [PATCH 20/26] tests/sqilte: enable FOREIGN_KEYS inside OpenTestConnection (#6641) --- tests/tests_test.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/tests_test.go b/tests/tests_test.go index 47c2a7c1..f9c6cab5 100644 --- a/tests/tests_test.go +++ b/tests/tests_test.go @@ -43,9 +43,6 @@ func init() { } RunMigrations() - if DB.Dialector.Name() == "sqlite" { - DB.Exec("PRAGMA foreign_keys = ON") - } } } @@ -89,7 +86,10 @@ func OpenTestConnection(cfg *gorm.Config) (db *gorm.DB, err error) { db, err = gorm.Open(mysql.Open(dbDSN), cfg) default: log.Println("testing sqlite3...") - db, err = gorm.Open(sqlite.Open(filepath.Join(os.TempDir(), "gorm.db?_foreign_keys=on")), cfg) + db, err = gorm.Open(sqlite.Open(filepath.Join(os.TempDir(), "gorm.db")), cfg) + if err == nil { + db.Exec("PRAGMA foreign_keys = ON") + } } if err != nil { From 5adc0ce5f6c8cf97f1f6b9e835750406612c2fe0 Mon Sep 17 00:00:00 2001 From: Cr <631807682@qq.com> Date: Thu, 26 Oct 2023 11:58:13 +0800 Subject: [PATCH 21/26] test: fix TestEmbeddedRelations (#6639) --- tests/embedded_struct_test.go | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/tests/embedded_struct_test.go b/tests/embedded_struct_test.go index 4314f88c..873bba2a 100644 --- a/tests/embedded_struct_test.go +++ b/tests/embedded_struct_test.go @@ -236,8 +236,15 @@ func TestEmbeddedScanValuer(t *testing.T) { } func TestEmbeddedRelations(t *testing.T) { + type EmbUser struct { + gorm.Model + Name string + Age uint + Languages []Language `gorm:"many2many:EmbUserSpeak;"` + } + type AdvancedUser struct { - User `gorm:"embedded"` + EmbUser `gorm:"embedded"` Advanced bool } From 9fea15ae75fb9ff2bd86dcaa167673c8ed77394f Mon Sep 17 00:00:00 2001 From: black-06 Date: Mon, 30 Oct 2023 17:15:49 +0800 Subject: [PATCH 22/26] feat: add MigrateColumnUnique (#6640) * feat: add MigrateColumnUnique * feat: define new methods * delete debug in test --- migrator.go | 2 ++ migrator/migrator.go | 22 ++++++++++++++++++++++ schema/naming.go | 8 ++++++++ tests/associations_belongs_to_test.go | 2 -- tests/count_test.go | 2 +- tests/preload_test.go | 1 - tests/update_test.go | 2 +- 7 files changed, 34 insertions(+), 5 deletions(-) diff --git a/migrator.go b/migrator.go index 0e01f567..3d2b032b 100644 --- a/migrator.go +++ b/migrator.go @@ -87,6 +87,8 @@ type Migrator interface { DropColumn(dst interface{}, field string) error AlterColumn(dst interface{}, field string) error MigrateColumn(dst interface{}, field *schema.Field, columnType ColumnType) error + // MigrateColumnUnique migrate column's UNIQUE constraint, it's part of MigrateColumn. + MigrateColumnUnique(dst interface{}, field *schema.Field, columnType ColumnType) error HasColumn(dst interface{}, field string) bool RenameColumn(dst interface{}, oldName, field string) error ColumnTypes(dst interface{}) ([]ColumnType, error) diff --git a/migrator/migrator.go b/migrator/migrator.go index 49bc9371..64a5a4b5 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -27,6 +27,8 @@ var regFullDataType = regexp.MustCompile(`\D*(\d+)\D?`) // TODO:? Create const vars for raw sql queries ? +var _ gorm.Migrator = (*Migrator)(nil) + // Migrator m struct type Migrator struct { Config @@ -539,6 +541,26 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy return nil } +func (m Migrator) MigrateColumnUnique(value interface{}, field *schema.Field, columnType gorm.ColumnType) error { + unique, ok := columnType.Unique() + if !ok || field.PrimaryKey { + return nil // skip primary key + } + // By default, ColumnType's Unique is not affected by UniqueIndex, so we don't care about UniqueIndex. + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + // We're currently only receiving boolean values on `Unique` tag, + // so the UniqueConstraint name is fixed + constraint := m.DB.NamingStrategy.UniqueName(stmt.Table, field.DBName) + if unique && !field.Unique { + return m.DB.Migrator().DropConstraint(value, constraint) + } + if !unique && field.Unique { + return m.DB.Migrator().CreateConstraint(value, constraint) + } + return nil + }) +} + // ColumnTypes return columnTypes []gorm.ColumnType and execErr error func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) { columnTypes := make([]gorm.ColumnType, 0) diff --git a/schema/naming.go b/schema/naming.go index a2a0150a..e6fb81b2 100644 --- a/schema/naming.go +++ b/schema/naming.go @@ -19,6 +19,7 @@ type Namer interface { RelationshipFKName(Relationship) string CheckerName(table, column string) string IndexName(table, column string) string + UniqueName(table, column string) string } // Replacer replacer interface like strings.Replacer @@ -26,6 +27,8 @@ type Replacer interface { Replace(name string) string } +var _ Namer = (*NamingStrategy)(nil) + // NamingStrategy tables, columns naming strategy type NamingStrategy struct { TablePrefix string @@ -85,6 +88,11 @@ func (ns NamingStrategy) IndexName(table, column string) string { return ns.formatName("idx", table, ns.toDBName(column)) } +// UniqueName generate unique constraint name +func (ns NamingStrategy) UniqueName(table, column string) string { + return ns.formatName("uni", table, ns.toDBName(column)) +} + func (ns NamingStrategy) formatName(prefix, table, name string) string { formattedName := strings.ReplaceAll(strings.Join([]string{ prefix, table, name, diff --git a/tests/associations_belongs_to_test.go b/tests/associations_belongs_to_test.go index 6befb5f2..103da032 100644 --- a/tests/associations_belongs_to_test.go +++ b/tests/associations_belongs_to_test.go @@ -278,8 +278,6 @@ func TestBelongsToAssociationUnscoped(t *testing.T) { t.Fatalf("failed to create items, got error: %v", err) } - tx = tx.Debug() - // test replace if err := tx.Model(&item).Association("ItemParent").Unscoped().Replace(&ItemParent{ Logo: "updated logo", diff --git a/tests/count_test.go b/tests/count_test.go index b0dfb0b5..4449515b 100644 --- a/tests/count_test.go +++ b/tests/count_test.go @@ -29,7 +29,7 @@ func TestCountWithGroup(t *testing.T) { } var count2 int64 - if err := DB.Debug().Model(&Company{}).Where("name in ?", []string{"company_count_group_b", "company_count_group_c"}).Group("name").Count(&count2).Error; err != nil { + if err := DB.Model(&Company{}).Where("name in ?", []string{"company_count_group_b", "company_count_group_c"}).Group("name").Count(&count2).Error; err != nil { t.Errorf(fmt.Sprintf("Count should work, but got err %v", err)) } if count2 != 2 { diff --git a/tests/preload_test.go b/tests/preload_test.go index 7304e350..3ff86492 100644 --- a/tests/preload_test.go +++ b/tests/preload_test.go @@ -429,7 +429,6 @@ func TestEmbedPreload(t *testing.T) { }, } - DB = DB.Debug() for _, test := range tests { t.Run(test.name, func(t *testing.T) { actual := Org{} diff --git a/tests/update_test.go b/tests/update_test.go index a3fb7015..b719cc45 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -838,7 +838,7 @@ func TestSaveWithHooks(t *testing.T) { saveTokenOwner := func(owner *TokenOwner) (*TokenOwner, error) { var newOwner TokenOwner if err := DB.Transaction(func(tx *gorm.DB) error { - if err := tx.Debug().Session(&gorm.Session{FullSaveAssociations: true}).Save(owner).Error; err != nil { + if err := tx.Session(&gorm.Session{FullSaveAssociations: true}).Save(owner).Error; err != nil { return err } if err := tx.Preload("Token").First(&newOwner, owner.ID).Error; err != nil { From d2fb7a942b8d44f9ad7f6f5bc6f9f99ddcebc95a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Flc=E3=82=9B?= Date: Tue, 7 Nov 2023 10:19:41 +0800 Subject: [PATCH 23/26] chore(logger): optimize (#6675) * chore(logger): optimize * chore(logger): optimize --- logger/logger.go | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/logger/logger.go b/logger/logger.go index aa0060bc..253f0325 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -69,7 +69,7 @@ type Interface interface { } var ( - // Discard Discard logger will print any log to io.Discard + // Discard logger will print any log to io.Discard Discard = New(log.New(io.Discard, "", log.LstdFlags), Config{}) // Default Default logger Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{ @@ -78,7 +78,7 @@ var ( IgnoreRecordNotFoundError: false, Colorful: true, }) - // Recorder Recorder logger records running SQL into a recorder instance + // Recorder logger records running SQL into a recorder instance Recorder = traceRecorder{Interface: Default, BeginAt: time.Now()} ) @@ -129,28 +129,30 @@ func (l *logger) LogMode(level LogLevel) Interface { } // Info print info -func (l logger) Info(ctx context.Context, msg string, data ...interface{}) { +func (l *logger) Info(ctx context.Context, msg string, data ...interface{}) { if l.LogLevel >= Info { l.Printf(l.infoStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...) } } // Warn print warn messages -func (l logger) Warn(ctx context.Context, msg string, data ...interface{}) { +func (l *logger) Warn(ctx context.Context, msg string, data ...interface{}) { if l.LogLevel >= Warn { l.Printf(l.warnStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...) } } // Error print error messages -func (l logger) Error(ctx context.Context, msg string, data ...interface{}) { +func (l *logger) Error(ctx context.Context, msg string, data ...interface{}) { if l.LogLevel >= Error { l.Printf(l.errStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...) } } // Trace print sql message -func (l logger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) { +// +//nolint:cyclop +func (l *logger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) { if l.LogLevel <= Silent { return } @@ -182,8 +184,8 @@ func (l logger) Trace(ctx context.Context, begin time.Time, fc func() (string, i } } -// Trace print sql message -func (l logger) ParamsFilter(ctx context.Context, sql string, params ...interface{}) (string, []interface{}) { +// ParamsFilter filter params +func (l *logger) ParamsFilter(ctx context.Context, sql string, params ...interface{}) (string, []interface{}) { if l.Config.ParameterizedQueries { return sql, nil } @@ -198,8 +200,8 @@ type traceRecorder struct { Err error } -// New new trace recorder -func (l traceRecorder) New() *traceRecorder { +// New trace recorder +func (l *traceRecorder) New() *traceRecorder { return &traceRecorder{Interface: l.Interface, BeginAt: time.Now()} } From 40f4afe8c21d96db63174bd501fb61d6e73c5587 Mon Sep 17 00:00:00 2001 From: Kijima Daigo Date: Tue, 7 Nov 2023 11:20:06 +0900 Subject: [PATCH 24/26] docs: fix broken link (#6673) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 85ad3050..745dad60 100644 --- a/README.md +++ b/README.md @@ -41,4 +41,4 @@ The fantastic ORM library for Golang, aims to be developer friendly. © Jinzhu, 2013~time.Now -Released under the [MIT License](https://github.com/go-gorm/gorm/blob/master/License) +Released under the [MIT License](https://github.com/go-gorm/gorm/blob/master/LICENSE) From c1e911f6ed8d3d929aebbd39985a33c9ebe3bad7 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 9 Nov 2023 18:46:39 +0800 Subject: [PATCH 25/26] Update tests/go.mod --- tests/go.mod | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/tests/go.mod b/tests/go.mod index 5a0aeddd..71079050 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -3,14 +3,14 @@ module gorm.io/gorm/tests go 1.18 require ( - github.com/google/uuid v1.3.1 + github.com/google/uuid v1.4.0 github.com/jinzhu/now v1.1.5 github.com/lib/pq v1.10.9 - gorm.io/driver/mysql v1.5.2-0.20230612053416-48b6526a21f0 - gorm.io/driver/postgres v1.5.3-0.20230607070428-18bc84b75196 + gorm.io/driver/mysql v1.5.2 + gorm.io/driver/postgres v1.5.4 gorm.io/driver/sqlite v1.5.4 - gorm.io/driver/sqlserver v1.5.2-0.20230613072041-6e2cde390b0a - gorm.io/gorm v1.25.4 + gorm.io/driver/sqlserver v1.5.2 + gorm.io/gorm v1.25.5 ) require ( @@ -19,12 +19,14 @@ require ( github.com/golang-sql/sqlexp v0.1.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect - github.com/jackc/pgx/v5 v5.4.3 // indirect + github.com/jackc/pgx/v5 v5.5.0 // indirect github.com/jinzhu/inflection v1.0.0 // indirect - github.com/mattn/go-sqlite3 v1.14.17 // indirect + github.com/mattn/go-sqlite3 v1.14.18 // indirect github.com/microsoft/go-mssqldb v1.6.0 // indirect - golang.org/x/crypto v0.14.0 // indirect - golang.org/x/text v0.13.0 // indirect + golang.org/x/crypto v0.15.0 // indirect + golang.org/x/text v0.14.0 // indirect ) replace gorm.io/gorm => ../ + +replace github.com/jackc/pgx/v5 => github.com/jackc/pgx/v5 v5.4.3 From 3207ad6033aad5e76c6c9d578ef663032765e484 Mon Sep 17 00:00:00 2001 From: FangSqing <148066072+FangSqing@users.noreply.github.com> Date: Wed, 15 Nov 2023 21:32:56 +0800 Subject: [PATCH 26/26] map insert support return increment id (#6662) --- callbacks/create.go | 70 +++++++++++++---- schema/field.go | 4 +- tests/create_test.go | 180 ++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 237 insertions(+), 17 deletions(-) diff --git a/callbacks/create.go b/callbacks/create.go index f0b78139..b1488b08 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -103,13 +103,53 @@ func Create(config *Config) func(db *gorm.DB) { } db.RowsAffected, _ = result.RowsAffected() - if db.RowsAffected != 0 && db.Statement.Schema != nil && - db.Statement.Schema.PrioritizedPrimaryField != nil && - db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue { - insertID, err := result.LastInsertId() - insertOk := err == nil && insertID > 0 - if !insertOk { - db.AddError(err) + if db.RowsAffected == 0 { + return + } + + var ( + pkField *schema.Field + pkFieldName = "@id" + ) + if db.Statement.Schema != nil { + if db.Statement.Schema.PrioritizedPrimaryField == nil || !db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue { + return + } + pkField = db.Statement.Schema.PrioritizedPrimaryField + pkFieldName = db.Statement.Schema.PrioritizedPrimaryField.DBName + } + + insertID, err := result.LastInsertId() + insertOk := err == nil && insertID > 0 + if !insertOk { + db.AddError(err) + return + } + + // append @id column with value for auto-increment primary key + // the @id value is correct, when: 1. without setting auto-increment primary key, 2. database AutoIncrementIncrement = 1 + switch values := db.Statement.Dest.(type) { + case map[string]interface{}: + values[pkFieldName] = insertID + case *map[string]interface{}: + (*values)[pkFieldName] = insertID + case []map[string]interface{}, *[]map[string]interface{}: + mapValues, ok := values.([]map[string]interface{}) + if !ok { + if v, ok := values.(*[]map[string]interface{}); ok { + if *v != nil { + mapValues = *v + } + } + } + for _, mapValue := range mapValues { + if mapValue != nil { + mapValue[pkFieldName] = insertID + } + insertID += schema.DefaultAutoIncrementIncrement + } + default: + if pkField == nil { return } @@ -122,10 +162,10 @@ func Create(config *Config) func(db *gorm.DB) { break } - _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv) + _, isZero := pkField.ValueOf(db.Statement.Context, rv) if isZero { - db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, insertID)) - insertID -= db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement + db.AddError(pkField.Set(db.Statement.Context, rv, insertID)) + insertID -= pkField.AutoIncrementIncrement } } } else { @@ -135,16 +175,16 @@ func Create(config *Config) func(db *gorm.DB) { break } - if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv); isZero { - db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, insertID)) - insertID += db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement + if _, isZero := pkField.ValueOf(db.Statement.Context, rv); isZero { + db.AddError(pkField.Set(db.Statement.Context, rv, insertID)) + insertID += pkField.AutoIncrementIncrement } } } case reflect.Struct: - _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, db.Statement.ReflectValue) + _, isZero := pkField.ValueOf(db.Statement.Context, db.Statement.ReflectValue) if isZero { - db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, db.Statement.ReflectValue, insertID)) + db.AddError(pkField.Set(db.Statement.Context, db.Statement.ReflectValue, insertID)) } } } diff --git a/schema/field.go b/schema/field.go index dd08e056..657e0a4b 100644 --- a/schema/field.go +++ b/schema/field.go @@ -49,6 +49,8 @@ const ( Bytes DataType = "bytes" ) +const DefaultAutoIncrementIncrement int64 = 1 + // Field is the representation of model schema's field type Field struct { Name string @@ -119,7 +121,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { NotNull: utils.CheckTruth(tagSetting["NOT NULL"], tagSetting["NOTNULL"]), Unique: utils.CheckTruth(tagSetting["UNIQUE"]), Comment: tagSetting["COMMENT"], - AutoIncrementIncrement: 1, + AutoIncrementIncrement: DefaultAutoIncrementIncrement, } for field.IndirectFieldType.Kind() == reflect.Ptr { diff --git a/tests/create_test.go b/tests/create_test.go index 02613b72..d9b54b7f 100644 --- a/tests/create_test.go +++ b/tests/create_test.go @@ -2,6 +2,7 @@ package tests_test import ( "errors" + "fmt" "regexp" "testing" "time" @@ -580,7 +581,7 @@ func TestCreateWithAutoIncrementCompositeKey(t *testing.T) { } } -func TestCreateOnConfilctWithDefalutNull(t *testing.T) { +func TestCreateOnConflictWithDefaultNull(t *testing.T) { type OnConfilctUser struct { ID string Name string `gorm:"default:null"` @@ -615,3 +616,180 @@ func TestCreateOnConfilctWithDefalutNull(t *testing.T) { AssertEqual(t, u2.Email, "on-confilct-user-email-2") AssertEqual(t, u2.Mobile, "133xxxx") } + +func TestCreateFromMapWithoutPK(t *testing.T) { + if !isMysql() { + t.Skipf("This test case skipped, because of only supportting for mysql") + } + + // case 1: one record, create from map[string]interface{} + mapValue1 := map[string]interface{}{"name": "create_from_map_with_schema1", "age": 1} + if err := DB.Model(&User{}).Create(mapValue1).Error; err != nil { + t.Fatalf("failed to create data from map, got error: %v", err) + } + + if _, ok := mapValue1["id"]; !ok { + t.Fatal("failed to create data from map with table, returning map has no primary key") + } + + var result1 User + if err := DB.Where("name = ?", "create_from_map_with_schema1").First(&result1).Error; err != nil || result1.Age != 1 { + t.Fatalf("failed to create from map, got error %v", err) + } + + var idVal int64 + _, ok := mapValue1["id"].(uint) + if ok { + t.Skipf("This test case skipped, because the db supports returning") + } + + idVal, ok = mapValue1["id"].(int64) + if !ok { + t.Fatal("ret result missing id") + } + + if int64(result1.ID) != idVal { + t.Fatal("failed to create data from map with table, @id != id") + } + + // case2: one record, create from *map[string]interface{} + mapValue2 := map[string]interface{}{"name": "create_from_map_with_schema2", "age": 1} + if err := DB.Model(&User{}).Create(&mapValue2).Error; err != nil { + t.Fatalf("failed to create data from map, got error: %v", err) + } + + if _, ok := mapValue2["id"]; !ok { + t.Fatal("failed to create data from map with table, returning map has no primary key") + } + + var result2 User + if err := DB.Where("name = ?", "create_from_map_with_schema2").First(&result2).Error; err != nil || result2.Age != 1 { + t.Fatalf("failed to create from map, got error %v", err) + } + + _, ok = mapValue2["id"].(uint) + if ok { + t.Skipf("This test case skipped, because the db supports returning") + } + + idVal, ok = mapValue2["id"].(int64) + if !ok { + t.Fatal("ret result missing id") + } + + if int64(result2.ID) != idVal { + t.Fatal("failed to create data from map with table, @id != id") + } + + // case 3: records + values := []map[string]interface{}{ + {"name": "create_from_map_with_schema11", "age": 1}, {"name": "create_from_map_with_schema12", "age": 1}, + } + + beforeLen := len(values) + if err := DB.Model(&User{}).Create(&values).Error; err != nil { + t.Fatalf("failed to create data from map, got error: %v", err) + } + + // mariadb with returning, values will be appended with id map + if len(values) == beforeLen*2 { + t.Skipf("This test case skipped, because the db supports returning") + } + + for i := range values { + v, ok := values[i]["id"] + if !ok { + t.Fatal("failed to create data from map with table, returning map has no primary key") + } + + var result User + if err := DB.Where("name = ?", fmt.Sprintf("create_from_map_with_schema1%d", i+1)).First(&result).Error; err != nil || result.Age != 1 { + t.Fatalf("failed to create from map, got error %v", err) + } + if int64(result.ID) != v.(int64) { + t.Fatal("failed to create data from map with table, @id != id") + } + } +} + +func TestCreateFromMapWithTable(t *testing.T) { + if !isMysql() { + t.Skipf("This test case skipped, because of only supportting for mysql") + } + tableDB := DB.Table("`users`") + + // case 1: create from map[string]interface{} + record := map[string]interface{}{"`name`": "create_from_map_with_table", "`age`": 18} + if err := tableDB.Create(record).Error; err != nil { + t.Fatalf("failed to create data from map with table, got error: %v", err) + } + + if _, ok := record["@id"]; !ok { + t.Fatal("failed to create data from map with table, returning map has no key '@id'") + } + + var res map[string]interface{} + if err := tableDB.Select([]string{"id", "name", "age"}).Where("name = ?", "create_from_map_with_table").Find(&res).Error; err != nil || res["age"] != int64(18) { + t.Fatalf("failed to create from map, got error %v", err) + } + + if int64(res["id"].(uint64)) != record["@id"] { + t.Fatal("failed to create data from map with table, @id != id") + } + + // case 2: create from *map[string]interface{} + record1 := map[string]interface{}{"name": "create_from_map_with_table_1", "age": 18} + tableDB2 := DB.Table("users") + if err := tableDB2.Create(&record1).Error; err != nil { + t.Fatalf("failed to create data from map, got error: %v", err) + } + if _, ok := record1["@id"]; !ok { + t.Fatal("failed to create data from map with table, returning map has no key '@id'") + } + + var res1 map[string]interface{} + if err := tableDB2.Select([]string{"id", "name", "age"}).Where("name = ?", "create_from_map_with_table_1").Find(&res1).Error; err != nil || res1["age"] != int64(18) { + t.Fatalf("failed to create from map, got error %v", err) + } + + if int64(res1["id"].(uint64)) != record1["@id"] { + t.Fatal("failed to create data from map with table, @id != id") + } + + // case 3: create from []map[string]interface{} + records := []map[string]interface{}{ + {"name": "create_from_map_with_table_2", "age": 19}, + {"name": "create_from_map_with_table_3", "age": 20}, + } + + tableDB = DB.Table("users") + if err := tableDB.Create(&records).Error; err != nil { + t.Fatalf("failed to create data from slice of map, got error: %v", err) + } + + if _, ok := records[0]["@id"]; !ok { + t.Fatal("failed to create data from map with table, returning map has no key '@id'") + } + + if _, ok := records[1]["@id"]; !ok { + t.Fatal("failed to create data from map with table, returning map has no key '@id'") + } + + var res2 map[string]interface{} + if err := tableDB.Select([]string{"id", "name", "age"}).Where("name = ?", "create_from_map_with_table_2").Find(&res2).Error; err != nil || res2["age"] != int64(19) { + t.Fatalf("failed to query data after create from slice of map, got error %v", err) + } + + var res3 map[string]interface{} + if err := DB.Table("users").Select([]string{"id", "name", "age"}).Where("name = ?", "create_from_map_with_table_3").Find(&res3).Error; err != nil || res3["age"] != int64(20) { + t.Fatalf("failed to query data after create from slice of map, got error %v", err) + } + + if int64(res2["id"].(uint64)) != records[0]["@id"] { + t.Fatal("failed to create data from map with table, @id != id") + } + + if int64(res3["id"].(uint64)) != records[1]["@id"] { + t.Fatal("failed to create data from map with table, @id != id") + } +}