Merge branch 'master' into customize-logger-time-format
This commit is contained in:
commit
9b5bc567e5
@ -246,15 +246,19 @@ func (eq Eq) Build(builder Builder) {
|
||||
|
||||
switch eq.Value.(type) {
|
||||
case []string, []int, []int32, []int64, []uint, []uint32, []uint64, []interface{}:
|
||||
builder.WriteString(" IN (")
|
||||
rv := reflect.ValueOf(eq.Value)
|
||||
for i := 0; i < rv.Len(); i++ {
|
||||
if i > 0 {
|
||||
builder.WriteByte(',')
|
||||
if rv.Len() == 0 {
|
||||
builder.WriteString(" IN (NULL)")
|
||||
} else {
|
||||
builder.WriteString(" IN (")
|
||||
for i := 0; i < rv.Len(); i++ {
|
||||
if i > 0 {
|
||||
builder.WriteByte(',')
|
||||
}
|
||||
builder.AddVar(builder, rv.Index(i).Interface())
|
||||
}
|
||||
builder.AddVar(builder, rv.Index(i).Interface())
|
||||
builder.WriteByte(')')
|
||||
}
|
||||
builder.WriteByte(')')
|
||||
default:
|
||||
if eqNil(eq.Value) {
|
||||
builder.WriteString(" IS NULL")
|
||||
|
@ -199,6 +199,11 @@ func TestExpression(t *testing.T) {
|
||||
},
|
||||
ExpectedVars: []interface{}{"a", "b"},
|
||||
Result: "`column-name` NOT IN (?,?)",
|
||||
}, {
|
||||
Expressions: []clause.Expression{
|
||||
clause.Eq{Column: column, Value: []string{}},
|
||||
},
|
||||
Result: "`column-name` IN (NULL)",
|
||||
}, {
|
||||
Expressions: []clause.Expression{
|
||||
clause.Eq{Column: clause.Expr{SQL: "SUM(?)", Vars: []interface{}{clause.Column{Name: "id"}}}, Value: 100},
|
||||
|
2
go.sum
2
go.sum
@ -1,6 +1,4 @@
|
||||
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
||||
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
||||
github.com/jinzhu/now v1.1.4 h1:tHnRBy1i5F2Dh8BAFxqFzxKqqvezXrL2OW1TnX+Mlas=
|
||||
github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
||||
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
|
||||
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
||||
|
11
gorm.go
11
gorm.go
@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
@ -181,7 +182,7 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) {
|
||||
err = config.Dialector.Initialize(db)
|
||||
|
||||
if err != nil {
|
||||
if db, err := db.DB(); err == nil {
|
||||
if db, _ := db.DB(); db != nil {
|
||||
_ = db.Close()
|
||||
}
|
||||
}
|
||||
@ -374,9 +375,11 @@ func (db *DB) AddError(err error) error {
|
||||
// DB returns `*sql.DB`
|
||||
func (db *DB) DB() (*sql.DB, error) {
|
||||
connPool := db.ConnPool
|
||||
|
||||
if connector, ok := connPool.(GetDBConnectorWithContext); ok && connector != nil {
|
||||
return connector.GetDBConnWithContext(db)
|
||||
if db.Statement != nil && db.Statement.ConnPool != nil {
|
||||
connPool = db.Statement.ConnPool
|
||||
}
|
||||
if tx, ok := connPool.(*sql.Tx); ok && tx != nil {
|
||||
return (*sql.DB)(reflect.ValueOf(tx).Elem().FieldByName("db").UnsafePointer()), nil
|
||||
}
|
||||
|
||||
if dbConnector, ok := connPool.(GetDBConnector); ok && dbConnector != nil {
|
||||
|
@ -77,12 +77,6 @@ type GetDBConnector interface {
|
||||
GetDBConn() (*sql.DB, error)
|
||||
}
|
||||
|
||||
// GetDBConnectorWithContext represents SQL db connector which takes into
|
||||
// account the current database context
|
||||
type GetDBConnectorWithContext interface {
|
||||
GetDBConnWithContext(db *DB) (*sql.DB, error)
|
||||
}
|
||||
|
||||
// Rows rows interface
|
||||
type Rows interface {
|
||||
Columns() ([]string, error)
|
||||
|
@ -30,8 +30,10 @@ func isPrintable(s string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// A list of Go types that should be converted to SQL primitives
|
||||
var convertibleTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeOf(false), reflect.TypeOf([]byte{})}
|
||||
|
||||
// RegEx matches only numeric values
|
||||
var numericPlaceholderRe = regexp.MustCompile(`\$\d+\$`)
|
||||
|
||||
// ExplainSQL generate SQL string with given parameters, the generated SQL is expected to be used in logger, execute it might introduce a SQL injection vulnerability
|
||||
|
@ -102,6 +102,12 @@ func TestExplainSQL(t *testing.T) {
|
||||
Vars: []interface{}{"jinzhu", 1, 0.1753607109, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es},
|
||||
Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 0.1753607109, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
|
||||
},
|
||||
{
|
||||
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
NumericRegexp: nil,
|
||||
Vars: []interface{}{"jinzhu", 1, float32(999.99), true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es},
|
||||
Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
|
||||
},
|
||||
}
|
||||
|
||||
for idx, r := range results {
|
||||
|
@ -16,8 +16,17 @@ import (
|
||||
"gorm.io/gorm/schema"
|
||||
)
|
||||
|
||||
// This regular expression seeks to find a sequence of digits (\d+) among zero or more non-digit characters (\D*),
|
||||
// with a possible trailing non-digit character (\D?).
|
||||
|
||||
// For example, values that can pass this regular expression are:
|
||||
// - "123"
|
||||
// - "abc456"
|
||||
// -"%$#@789"
|
||||
var regFullDataType = regexp.MustCompile(`\D*(\d+)\D?`)
|
||||
|
||||
// TODO:? Create const vars for raw sql queries ?
|
||||
|
||||
// Migrator m struct
|
||||
type Migrator struct {
|
||||
Config
|
||||
|
@ -31,14 +31,14 @@ func NewPreparedStmtDB(connPool ConnPool) *PreparedStmtDB {
|
||||
}
|
||||
|
||||
func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) {
|
||||
if dbConnector, ok := db.ConnPool.(GetDBConnector); ok && dbConnector != nil {
|
||||
return dbConnector.GetDBConn()
|
||||
}
|
||||
|
||||
if sqldb, ok := db.ConnPool.(*sql.DB); ok {
|
||||
return sqldb, nil
|
||||
}
|
||||
|
||||
if dbConnector, ok := db.ConnPool.(GetDBConnector); ok && dbConnector != nil {
|
||||
return dbConnector.GetDBConn()
|
||||
}
|
||||
|
||||
return nil, ErrInvalidDB
|
||||
}
|
||||
|
||||
@ -54,15 +54,15 @@ func (db *PreparedStmtDB) Close() {
|
||||
}
|
||||
}
|
||||
|
||||
func (db *PreparedStmtDB) Reset() {
|
||||
db.Mux.Lock()
|
||||
defer db.Mux.Unlock()
|
||||
func (sdb *PreparedStmtDB) Reset() {
|
||||
sdb.Mux.Lock()
|
||||
defer sdb.Mux.Unlock()
|
||||
|
||||
for _, stmt := range db.Stmts {
|
||||
for _, stmt := range sdb.Stmts {
|
||||
go stmt.Close()
|
||||
}
|
||||
db.PreparedSQL = make([]string, 0, 100)
|
||||
db.Stmts = make(map[string]*Stmt)
|
||||
sdb.PreparedSQL = make([]string, 0, 100)
|
||||
sdb.Stmts = make(map[string]*Stmt)
|
||||
}
|
||||
|
||||
func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) {
|
||||
@ -127,6 +127,19 @@ func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (Conn
|
||||
tx, err := beginner.BeginTx(ctx, opt)
|
||||
return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err
|
||||
}
|
||||
|
||||
beginner, ok := db.ConnPool.(ConnPoolBeginner)
|
||||
if !ok {
|
||||
return nil, ErrInvalidTransaction
|
||||
}
|
||||
|
||||
connPool, err := beginner.BeginTx(ctx, opt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if tx, ok := connPool.(Tx); ok {
|
||||
return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, nil
|
||||
}
|
||||
return nil, ErrInvalidTransaction
|
||||
}
|
||||
|
||||
@ -172,6 +185,10 @@ type PreparedStmtTX struct {
|
||||
PreparedStmtDB *PreparedStmtDB
|
||||
}
|
||||
|
||||
func (db *PreparedStmtTX) GetDBConn() (*sql.DB, error) {
|
||||
return db.PreparedStmtDB.GetDBConn()
|
||||
}
|
||||
|
||||
func (tx *PreparedStmtTX) Commit() error {
|
||||
if tx.Tx != nil && !reflect.ValueOf(tx.Tx).IsNil() {
|
||||
return tx.Tx.Commit()
|
||||
|
@ -115,6 +115,11 @@ func GetIdentityFieldValuesMap(ctx context.Context, reflectValue reflect.Value,
|
||||
notZero, zero bool
|
||||
)
|
||||
|
||||
if reflectValue.Kind() == reflect.Ptr ||
|
||||
reflectValue.Kind() == reflect.Interface {
|
||||
reflectValue = reflectValue.Elem()
|
||||
}
|
||||
|
||||
switch reflectValue.Kind() {
|
||||
case reflect.Struct:
|
||||
results = [][]interface{}{make([]interface{}, len(fields))}
|
||||
|
@ -22,9 +22,9 @@ require (
|
||||
github.com/jackc/pgx/v5 v5.4.2 // indirect
|
||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||
github.com/mattn/go-sqlite3 v1.14.17 // indirect
|
||||
github.com/microsoft/go-mssqldb v1.4.0 // indirect
|
||||
golang.org/x/crypto v0.11.0 // indirect
|
||||
golang.org/x/text v0.11.0 // indirect
|
||||
github.com/microsoft/go-mssqldb v1.5.0 // indirect
|
||||
golang.org/x/crypto v0.12.0 // indirect
|
||||
golang.org/x/text v0.12.0 // indirect
|
||||
)
|
||||
|
||||
replace gorm.io/gorm => ../
|
||||
|
Loading…
x
Reference in New Issue
Block a user