diff --git a/clause/expression.go b/clause/expression.go index 92ac7f22..8d010522 100644 --- a/clause/expression.go +++ b/clause/expression.go @@ -246,15 +246,19 @@ func (eq Eq) Build(builder Builder) { switch eq.Value.(type) { case []string, []int, []int32, []int64, []uint, []uint32, []uint64, []interface{}: - builder.WriteString(" IN (") rv := reflect.ValueOf(eq.Value) - for i := 0; i < rv.Len(); i++ { - if i > 0 { - builder.WriteByte(',') + if rv.Len() == 0 { + builder.WriteString(" IN (NULL)") + } else { + builder.WriteString(" IN (") + for i := 0; i < rv.Len(); i++ { + if i > 0 { + builder.WriteByte(',') + } + builder.AddVar(builder, rv.Index(i).Interface()) } - builder.AddVar(builder, rv.Index(i).Interface()) + builder.WriteByte(')') } - builder.WriteByte(')') default: if eqNil(eq.Value) { builder.WriteString(" IS NULL") diff --git a/clause/expression_test.go b/clause/expression_test.go index aaede61c..b997bf11 100644 --- a/clause/expression_test.go +++ b/clause/expression_test.go @@ -199,6 +199,11 @@ func TestExpression(t *testing.T) { }, ExpectedVars: []interface{}{"a", "b"}, Result: "`column-name` NOT IN (?,?)", + }, { + Expressions: []clause.Expression{ + clause.Eq{Column: column, Value: []string{}}, + }, + Result: "`column-name` IN (NULL)", }, { Expressions: []clause.Expression{ clause.Eq{Column: clause.Expr{SQL: "SUM(?)", Vars: []interface{}{clause.Column{Name: "id"}}}, Value: 100}, diff --git a/go.sum b/go.sum index fb4240eb..bd6104c9 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,4 @@ github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= -github.com/jinzhu/now v1.1.4 h1:tHnRBy1i5F2Dh8BAFxqFzxKqqvezXrL2OW1TnX+Mlas= -github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= diff --git a/gorm.go b/gorm.go index 203527af..775cd3de 100644 --- a/gorm.go +++ b/gorm.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "fmt" + "reflect" "sort" "sync" "time" @@ -181,7 +182,7 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) { err = config.Dialector.Initialize(db) if err != nil { - if db, err := db.DB(); err == nil { + if db, _ := db.DB(); db != nil { _ = db.Close() } } @@ -374,9 +375,11 @@ func (db *DB) AddError(err error) error { // DB returns `*sql.DB` func (db *DB) DB() (*sql.DB, error) { connPool := db.ConnPool - - if connector, ok := connPool.(GetDBConnectorWithContext); ok && connector != nil { - return connector.GetDBConnWithContext(db) + if db.Statement != nil && db.Statement.ConnPool != nil { + connPool = db.Statement.ConnPool + } + if tx, ok := connPool.(*sql.Tx); ok && tx != nil { + return (*sql.DB)(reflect.ValueOf(tx).Elem().FieldByName("db").UnsafePointer()), nil } if dbConnector, ok := connPool.(GetDBConnector); ok && dbConnector != nil { diff --git a/interfaces.go b/interfaces.go index 1950d740..3bcc3d57 100644 --- a/interfaces.go +++ b/interfaces.go @@ -77,12 +77,6 @@ type GetDBConnector interface { GetDBConn() (*sql.DB, error) } -// GetDBConnectorWithContext represents SQL db connector which takes into -// account the current database context -type GetDBConnectorWithContext interface { - GetDBConnWithContext(db *DB) (*sql.DB, error) -} - // Rows rows interface type Rows interface { Columns() ([]string, error) diff --git a/logger/sql.go b/logger/sql.go index 35d9539a..b816bbbb 100644 --- a/logger/sql.go +++ b/logger/sql.go @@ -30,8 +30,10 @@ func isPrintable(s string) bool { return true } +// A list of Go types that should be converted to SQL primitives var convertibleTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeOf(false), reflect.TypeOf([]byte{})} +// RegEx matches only numeric values var numericPlaceholderRe = regexp.MustCompile(`\$\d+\$`) // ExplainSQL generate SQL string with given parameters, the generated SQL is expected to be used in logger, execute it might introduce a SQL injection vulnerability diff --git a/logger/sql_test.go b/logger/sql_test.go index dabc8974..55cb93be 100644 --- a/logger/sql_test.go +++ b/logger/sql_test.go @@ -102,6 +102,12 @@ func TestExplainSQL(t *testing.T) { Vars: []interface{}{"jinzhu", 1, 0.1753607109, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es}, Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 0.1753607109, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)), }, + { + SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + NumericRegexp: nil, + Vars: []interface{}{"jinzhu", 1, float32(999.99), true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es}, + Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)), + }, } for idx, r := range results { diff --git a/migrator/migrator.go b/migrator/migrator.go index de60f91c..b15a43ef 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -16,8 +16,17 @@ import ( "gorm.io/gorm/schema" ) +// This regular expression seeks to find a sequence of digits (\d+) among zero or more non-digit characters (\D*), +// with a possible trailing non-digit character (\D?). + +// For example, values that can pass this regular expression are: +// - "123" +// - "abc456" +// -"%$#@789" var regFullDataType = regexp.MustCompile(`\D*(\d+)\D?`) +// TODO:? Create const vars for raw sql queries ? + // Migrator m struct type Migrator struct { Config diff --git a/prepare_stmt.go b/prepare_stmt.go index 10fefc31..aa944624 100644 --- a/prepare_stmt.go +++ b/prepare_stmt.go @@ -31,14 +31,14 @@ func NewPreparedStmtDB(connPool ConnPool) *PreparedStmtDB { } func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) { - if dbConnector, ok := db.ConnPool.(GetDBConnector); ok && dbConnector != nil { - return dbConnector.GetDBConn() - } - if sqldb, ok := db.ConnPool.(*sql.DB); ok { return sqldb, nil } + if dbConnector, ok := db.ConnPool.(GetDBConnector); ok && dbConnector != nil { + return dbConnector.GetDBConn() + } + return nil, ErrInvalidDB } @@ -54,15 +54,15 @@ func (db *PreparedStmtDB) Close() { } } -func (db *PreparedStmtDB) Reset() { - db.Mux.Lock() - defer db.Mux.Unlock() +func (sdb *PreparedStmtDB) Reset() { + sdb.Mux.Lock() + defer sdb.Mux.Unlock() - for _, stmt := range db.Stmts { + for _, stmt := range sdb.Stmts { go stmt.Close() } - db.PreparedSQL = make([]string, 0, 100) - db.Stmts = make(map[string]*Stmt) + sdb.PreparedSQL = make([]string, 0, 100) + sdb.Stmts = make(map[string]*Stmt) } func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) { @@ -127,6 +127,19 @@ func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (Conn tx, err := beginner.BeginTx(ctx, opt) return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err } + + beginner, ok := db.ConnPool.(ConnPoolBeginner) + if !ok { + return nil, ErrInvalidTransaction + } + + connPool, err := beginner.BeginTx(ctx, opt) + if err != nil { + return nil, err + } + if tx, ok := connPool.(Tx); ok { + return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, nil + } return nil, ErrInvalidTransaction } @@ -172,6 +185,10 @@ type PreparedStmtTX struct { PreparedStmtDB *PreparedStmtDB } +func (db *PreparedStmtTX) GetDBConn() (*sql.DB, error) { + return db.PreparedStmtDB.GetDBConn() +} + func (tx *PreparedStmtTX) Commit() error { if tx.Tx != nil && !reflect.ValueOf(tx.Tx).IsNil() { return tx.Tx.Commit() diff --git a/schema/utils.go b/schema/utils.go index 65d012e5..7fdda185 100644 --- a/schema/utils.go +++ b/schema/utils.go @@ -115,6 +115,11 @@ func GetIdentityFieldValuesMap(ctx context.Context, reflectValue reflect.Value, notZero, zero bool ) + if reflectValue.Kind() == reflect.Ptr || + reflectValue.Kind() == reflect.Interface { + reflectValue = reflectValue.Elem() + } + switch reflectValue.Kind() { case reflect.Struct: results = [][]interface{}{make([]interface{}, len(fields))} diff --git a/tests/go.mod b/tests/go.mod index 147d0a79..7a89ee05 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -22,9 +22,9 @@ require ( github.com/jackc/pgx/v5 v5.4.2 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/mattn/go-sqlite3 v1.14.17 // indirect - github.com/microsoft/go-mssqldb v1.4.0 // indirect - golang.org/x/crypto v0.11.0 // indirect - golang.org/x/text v0.11.0 // indirect + github.com/microsoft/go-mssqldb v1.5.0 // indirect + golang.org/x/crypto v0.12.0 // indirect + golang.org/x/text v0.12.0 // indirect ) replace gorm.io/gorm => ../