From 3c34bc2f59fd080dce5e7a829a8f178b3f4de194 Mon Sep 17 00:00:00 2001 From: fayvori <80601865+fayvori@users.noreply.github.com> Date: Mon, 7 Aug 2023 11:35:19 +0300 Subject: [PATCH 1/7] refactor: Regex description (#6507) * Mirror cleanup * Regex description --------- Co-authored-by: Ignat Belousov --- logger/sql.go | 2 ++ logger/sql_test.go | 1 - migrator/migrator.go | 9 +++++++++ tests/go.mod | 6 +++--- 4 files changed, 14 insertions(+), 4 deletions(-) diff --git a/logger/sql.go b/logger/sql.go index 1521c1fd..13e5d957 100644 --- a/logger/sql.go +++ b/logger/sql.go @@ -28,8 +28,10 @@ func isPrintable(s string) bool { return true } +// A list of Go types that should be converted to SQL primitives var convertibleTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeOf(false), reflect.TypeOf([]byte{})} +// RegEx matches only numeric values var numericPlaceholderRe = regexp.MustCompile(`\$\d+\$`) // ExplainSQL generate SQL string with given parameters, the generated SQL is expected to be used in logger, execute it might introduce a SQL injection vulnerability diff --git a/logger/sql_test.go b/logger/sql_test.go index e4a72748..d9afe393 100644 --- a/logger/sql_test.go +++ b/logger/sql_test.go @@ -101,7 +101,6 @@ func TestExplainSQL(t *testing.T) { Vars: []interface{}{"jinzhu", 1, 0.1753607109, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es}, Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 0.1753607109, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)), }, - } for idx, r := range results { diff --git a/migrator/migrator.go b/migrator/migrator.go index de60f91c..b15a43ef 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -16,8 +16,17 @@ import ( "gorm.io/gorm/schema" ) +// This regular expression seeks to find a sequence of digits (\d+) among zero or more non-digit characters (\D*), +// with a possible trailing non-digit character (\D?). + +// For example, values that can pass this regular expression are: +// - "123" +// - "abc456" +// -"%$#@789" var regFullDataType = regexp.MustCompile(`\D*(\d+)\D?`) +// TODO:? Create const vars for raw sql queries ? + // Migrator m struct type Migrator struct { Config diff --git a/tests/go.mod b/tests/go.mod index 147d0a79..7a89ee05 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -22,9 +22,9 @@ require ( github.com/jackc/pgx/v5 v5.4.2 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/mattn/go-sqlite3 v1.14.17 // indirect - github.com/microsoft/go-mssqldb v1.4.0 // indirect - golang.org/x/crypto v0.11.0 // indirect - golang.org/x/text v0.11.0 // indirect + github.com/microsoft/go-mssqldb v1.5.0 // indirect + golang.org/x/crypto v0.12.0 // indirect + golang.org/x/text v0.12.0 // indirect ) replace gorm.io/gorm => ../ From 15162afaf2a1cd1ee8c63ebc0dc14b8baa0613f7 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 10 Aug 2023 13:30:48 +0800 Subject: [PATCH 2/7] Support GetDBConnWithContext PreparedStmtDB --- prepare_stmt.go | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/prepare_stmt.go b/prepare_stmt.go index 10fefc31..9d98c86e 100644 --- a/prepare_stmt.go +++ b/prepare_stmt.go @@ -30,15 +30,19 @@ func NewPreparedStmtDB(connPool ConnPool) *PreparedStmtDB { } } -func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) { - if dbConnector, ok := db.ConnPool.(GetDBConnector); ok && dbConnector != nil { - return dbConnector.GetDBConn() - } - +func (db *PreparedStmtDB) GetDBConnWithContext(gormdb *DB) (*sql.DB, error) { if sqldb, ok := db.ConnPool.(*sql.DB); ok { return sqldb, nil } + if connector, ok := db.ConnPool.(GetDBConnectorWithContext); ok && connector != nil { + return connector.GetDBConnWithContext(gormdb) + } + + if dbConnector, ok := db.ConnPool.(GetDBConnector); ok && dbConnector != nil { + return dbConnector.GetDBConn() + } + return nil, ErrInvalidDB } @@ -54,15 +58,15 @@ func (db *PreparedStmtDB) Close() { } } -func (db *PreparedStmtDB) Reset() { - db.Mux.Lock() - defer db.Mux.Unlock() +func (sdb *PreparedStmtDB) Reset() { + sdb.Mux.Lock() + defer sdb.Mux.Unlock() - for _, stmt := range db.Stmts { + for _, stmt := range sdb.Stmts { go stmt.Close() } - db.PreparedSQL = make([]string, 0, 100) - db.Stmts = make(map[string]*Stmt) + sdb.PreparedSQL = make([]string, 0, 100) + sdb.Stmts = make(map[string]*Stmt) } func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) { From bae684b3639dff3e35d0ed330bc82c12e8282110 Mon Sep 17 00:00:00 2001 From: weih Date: Thu, 10 Aug 2023 13:34:33 +0800 Subject: [PATCH 3/7] fix(clause): when the value of clause.Eq is an empty array, the SQL should be IN (NULL) (#6503) --- clause/expression.go | 16 ++++++++++------ clause/expression_test.go | 5 +++++ 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/clause/expression.go b/clause/expression.go index 92ac7f22..8d010522 100644 --- a/clause/expression.go +++ b/clause/expression.go @@ -246,15 +246,19 @@ func (eq Eq) Build(builder Builder) { switch eq.Value.(type) { case []string, []int, []int32, []int64, []uint, []uint32, []uint64, []interface{}: - builder.WriteString(" IN (") rv := reflect.ValueOf(eq.Value) - for i := 0; i < rv.Len(); i++ { - if i > 0 { - builder.WriteByte(',') + if rv.Len() == 0 { + builder.WriteString(" IN (NULL)") + } else { + builder.WriteString(" IN (") + for i := 0; i < rv.Len(); i++ { + if i > 0 { + builder.WriteByte(',') + } + builder.AddVar(builder, rv.Index(i).Interface()) } - builder.AddVar(builder, rv.Index(i).Interface()) + builder.WriteByte(')') } - builder.WriteByte(')') default: if eqNil(eq.Value) { builder.WriteString(" IS NULL") diff --git a/clause/expression_test.go b/clause/expression_test.go index aaede61c..b997bf11 100644 --- a/clause/expression_test.go +++ b/clause/expression_test.go @@ -199,6 +199,11 @@ func TestExpression(t *testing.T) { }, ExpectedVars: []interface{}{"a", "b"}, Result: "`column-name` NOT IN (?,?)", + }, { + Expressions: []clause.Expression{ + clause.Eq{Column: column, Value: []string{}}, + }, + Result: "`column-name` IN (NULL)", }, { Expressions: []clause.Expression{ clause.Eq{Column: clause.Expr{SQL: "SUM(?)", Vars: []interface{}{clause.Column{Name: "id"}}}, Value: 100}, From fef42941ba87bff8dad5d48b057a2c2056345984 Mon Sep 17 00:00:00 2001 From: qqxhb <30866940+qqxhb@users.noreply.github.com> Date: Sat, 19 Aug 2023 21:33:31 +0800 Subject: [PATCH 4/7] feat: rm GetDBConnWithContext method (#6535) * feat: rm contextconnpool method * feat: nil --- go.sum | 2 -- gorm.go | 9 ++++++--- interfaces.go | 6 ------ prepare_stmt.go | 23 ++++++++++++++++++----- 4 files changed, 24 insertions(+), 16 deletions(-) diff --git a/go.sum b/go.sum index fb4240eb..bd6104c9 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,4 @@ github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= -github.com/jinzhu/now v1.1.4 h1:tHnRBy1i5F2Dh8BAFxqFzxKqqvezXrL2OW1TnX+Mlas= -github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= diff --git a/gorm.go b/gorm.go index 203527af..32193870 100644 --- a/gorm.go +++ b/gorm.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "fmt" + "reflect" "sort" "sync" "time" @@ -374,9 +375,11 @@ func (db *DB) AddError(err error) error { // DB returns `*sql.DB` func (db *DB) DB() (*sql.DB, error) { connPool := db.ConnPool - - if connector, ok := connPool.(GetDBConnectorWithContext); ok && connector != nil { - return connector.GetDBConnWithContext(db) + if db.Statement != nil && db.Statement.ConnPool != nil { + connPool = db.Statement.ConnPool + } + if tx, ok := connPool.(*sql.Tx); ok && tx != nil { + return (*sql.DB)(reflect.ValueOf(tx).Elem().FieldByName("db").UnsafePointer()), nil } if dbConnector, ok := connPool.(GetDBConnector); ok && dbConnector != nil { diff --git a/interfaces.go b/interfaces.go index 1950d740..3bcc3d57 100644 --- a/interfaces.go +++ b/interfaces.go @@ -77,12 +77,6 @@ type GetDBConnector interface { GetDBConn() (*sql.DB, error) } -// GetDBConnectorWithContext represents SQL db connector which takes into -// account the current database context -type GetDBConnectorWithContext interface { - GetDBConnWithContext(db *DB) (*sql.DB, error) -} - // Rows rows interface type Rows interface { Columns() ([]string, error) diff --git a/prepare_stmt.go b/prepare_stmt.go index 9d98c86e..aa944624 100644 --- a/prepare_stmt.go +++ b/prepare_stmt.go @@ -30,15 +30,11 @@ func NewPreparedStmtDB(connPool ConnPool) *PreparedStmtDB { } } -func (db *PreparedStmtDB) GetDBConnWithContext(gormdb *DB) (*sql.DB, error) { +func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) { if sqldb, ok := db.ConnPool.(*sql.DB); ok { return sqldb, nil } - if connector, ok := db.ConnPool.(GetDBConnectorWithContext); ok && connector != nil { - return connector.GetDBConnWithContext(gormdb) - } - if dbConnector, ok := db.ConnPool.(GetDBConnector); ok && dbConnector != nil { return dbConnector.GetDBConn() } @@ -131,6 +127,19 @@ func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (Conn tx, err := beginner.BeginTx(ctx, opt) return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err } + + beginner, ok := db.ConnPool.(ConnPoolBeginner) + if !ok { + return nil, ErrInvalidTransaction + } + + connPool, err := beginner.BeginTx(ctx, opt) + if err != nil { + return nil, err + } + if tx, ok := connPool.(Tx); ok { + return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, nil + } return nil, ErrInvalidTransaction } @@ -176,6 +185,10 @@ type PreparedStmtTX struct { PreparedStmtDB *PreparedStmtDB } +func (db *PreparedStmtTX) GetDBConn() (*sql.DB, error) { + return db.PreparedStmtDB.GetDBConn() +} + func (tx *PreparedStmtTX) Commit() error { if tx.Tx != nil && !reflect.ValueOf(tx.Tx).IsNil() { return tx.Tx.Commit() From 2c2089760c5a35b3884c7a949621ce0e790e7835 Mon Sep 17 00:00:00 2001 From: Heliner <32272517+Heliner@users.noreply.github.com> Date: Sat, 19 Aug 2023 21:33:57 +0800 Subject: [PATCH 5/7] add float32 test case (#6530) --- logger/sql_test.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/logger/sql_test.go b/logger/sql_test.go index d9afe393..a82fa546 100644 --- a/logger/sql_test.go +++ b/logger/sql_test.go @@ -101,6 +101,12 @@ func TestExplainSQL(t *testing.T) { Vars: []interface{}{"jinzhu", 1, 0.1753607109, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es}, Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 0.1753607109, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)), }, + { + SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + NumericRegexp: nil, + Vars: []interface{}{"jinzhu", 1, float32(999.99), true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es}, + Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)), + }, } for idx, r := range results { From 7e44f73ad3b657a86bbdc881787b03c25ab789a5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=BE=9A=E4=B8=80=E6=B6=9B?= Date: Sat, 19 Aug 2023 21:35:14 +0800 Subject: [PATCH 6/7] fix schema GetIdentityFieldValuesMap interface or ptr (#6417) Co-authored-by: uptutu --- schema/utils.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/schema/utils.go b/schema/utils.go index 65d012e5..7fdda185 100644 --- a/schema/utils.go +++ b/schema/utils.go @@ -115,6 +115,11 @@ func GetIdentityFieldValuesMap(ctx context.Context, reflectValue reflect.Value, notZero, zero bool ) + if reflectValue.Kind() == reflect.Ptr || + reflectValue.Kind() == reflect.Interface { + reflectValue = reflectValue.Elem() + } + switch reflectValue.Kind() { case reflect.Struct: results = [][]interface{}{make([]interface{}, len(fields))} From ac07543962994da4c6994ba3907417d7835a2619 Mon Sep 17 00:00:00 2001 From: Rataj Date: Sun, 20 Aug 2023 13:46:56 +0200 Subject: [PATCH 7/7] Fixed error message when dialector fails to initialize (#6509) Let's say we have a problem with DSN which leads to dialector initialize error. However DB connection is not created and for some reason line 184 error provides even though "db" doesn't exist. Previously, this code leads to: panic: runtime error: invalid memory address or nil pointer dereference This fix now doesn't attempt to close non-existant database connection and instead continues, so the proper error is shown. In my case: [error] failed to initialize database, got error default addr for network 'localhost' unknown --- gorm.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gorm.go b/gorm.go index 32193870..775cd3de 100644 --- a/gorm.go +++ b/gorm.go @@ -182,7 +182,7 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) { err = config.Dialector.Initialize(db) if err != nil { - if db, err := db.DB(); err == nil { + if db, _ := db.DB(); db != nil { _ = db.Close() } }