Merge branch 'master' into customize-logger-time-format
This commit is contained in:
commit
9b5bc567e5
@ -246,15 +246,19 @@ func (eq Eq) Build(builder Builder) {
|
|||||||
|
|
||||||
switch eq.Value.(type) {
|
switch eq.Value.(type) {
|
||||||
case []string, []int, []int32, []int64, []uint, []uint32, []uint64, []interface{}:
|
case []string, []int, []int32, []int64, []uint, []uint32, []uint64, []interface{}:
|
||||||
builder.WriteString(" IN (")
|
|
||||||
rv := reflect.ValueOf(eq.Value)
|
rv := reflect.ValueOf(eq.Value)
|
||||||
for i := 0; i < rv.Len(); i++ {
|
if rv.Len() == 0 {
|
||||||
if i > 0 {
|
builder.WriteString(" IN (NULL)")
|
||||||
builder.WriteByte(',')
|
} else {
|
||||||
|
builder.WriteString(" IN (")
|
||||||
|
for i := 0; i < rv.Len(); i++ {
|
||||||
|
if i > 0 {
|
||||||
|
builder.WriteByte(',')
|
||||||
|
}
|
||||||
|
builder.AddVar(builder, rv.Index(i).Interface())
|
||||||
}
|
}
|
||||||
builder.AddVar(builder, rv.Index(i).Interface())
|
builder.WriteByte(')')
|
||||||
}
|
}
|
||||||
builder.WriteByte(')')
|
|
||||||
default:
|
default:
|
||||||
if eqNil(eq.Value) {
|
if eqNil(eq.Value) {
|
||||||
builder.WriteString(" IS NULL")
|
builder.WriteString(" IS NULL")
|
||||||
|
@ -199,6 +199,11 @@ func TestExpression(t *testing.T) {
|
|||||||
},
|
},
|
||||||
ExpectedVars: []interface{}{"a", "b"},
|
ExpectedVars: []interface{}{"a", "b"},
|
||||||
Result: "`column-name` NOT IN (?,?)",
|
Result: "`column-name` NOT IN (?,?)",
|
||||||
|
}, {
|
||||||
|
Expressions: []clause.Expression{
|
||||||
|
clause.Eq{Column: column, Value: []string{}},
|
||||||
|
},
|
||||||
|
Result: "`column-name` IN (NULL)",
|
||||||
}, {
|
}, {
|
||||||
Expressions: []clause.Expression{
|
Expressions: []clause.Expression{
|
||||||
clause.Eq{Column: clause.Expr{SQL: "SUM(?)", Vars: []interface{}{clause.Column{Name: "id"}}}, Value: 100},
|
clause.Eq{Column: clause.Expr{SQL: "SUM(?)", Vars: []interface{}{clause.Column{Name: "id"}}}, Value: 100},
|
||||||
|
2
go.sum
2
go.sum
@ -1,6 +1,4 @@
|
|||||||
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
||||||
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
||||||
github.com/jinzhu/now v1.1.4 h1:tHnRBy1i5F2Dh8BAFxqFzxKqqvezXrL2OW1TnX+Mlas=
|
|
||||||
github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
|
||||||
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
|
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
|
||||||
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
||||||
|
11
gorm.go
11
gorm.go
@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"reflect"
|
||||||
"sort"
|
"sort"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
@ -181,7 +182,7 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) {
|
|||||||
err = config.Dialector.Initialize(db)
|
err = config.Dialector.Initialize(db)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if db, err := db.DB(); err == nil {
|
if db, _ := db.DB(); db != nil {
|
||||||
_ = db.Close()
|
_ = db.Close()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -374,9 +375,11 @@ func (db *DB) AddError(err error) error {
|
|||||||
// DB returns `*sql.DB`
|
// DB returns `*sql.DB`
|
||||||
func (db *DB) DB() (*sql.DB, error) {
|
func (db *DB) DB() (*sql.DB, error) {
|
||||||
connPool := db.ConnPool
|
connPool := db.ConnPool
|
||||||
|
if db.Statement != nil && db.Statement.ConnPool != nil {
|
||||||
if connector, ok := connPool.(GetDBConnectorWithContext); ok && connector != nil {
|
connPool = db.Statement.ConnPool
|
||||||
return connector.GetDBConnWithContext(db)
|
}
|
||||||
|
if tx, ok := connPool.(*sql.Tx); ok && tx != nil {
|
||||||
|
return (*sql.DB)(reflect.ValueOf(tx).Elem().FieldByName("db").UnsafePointer()), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if dbConnector, ok := connPool.(GetDBConnector); ok && dbConnector != nil {
|
if dbConnector, ok := connPool.(GetDBConnector); ok && dbConnector != nil {
|
||||||
|
@ -77,12 +77,6 @@ type GetDBConnector interface {
|
|||||||
GetDBConn() (*sql.DB, error)
|
GetDBConn() (*sql.DB, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetDBConnectorWithContext represents SQL db connector which takes into
|
|
||||||
// account the current database context
|
|
||||||
type GetDBConnectorWithContext interface {
|
|
||||||
GetDBConnWithContext(db *DB) (*sql.DB, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Rows rows interface
|
// Rows rows interface
|
||||||
type Rows interface {
|
type Rows interface {
|
||||||
Columns() ([]string, error)
|
Columns() ([]string, error)
|
||||||
|
@ -30,8 +30,10 @@ func isPrintable(s string) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// A list of Go types that should be converted to SQL primitives
|
||||||
var convertibleTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeOf(false), reflect.TypeOf([]byte{})}
|
var convertibleTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeOf(false), reflect.TypeOf([]byte{})}
|
||||||
|
|
||||||
|
// RegEx matches only numeric values
|
||||||
var numericPlaceholderRe = regexp.MustCompile(`\$\d+\$`)
|
var numericPlaceholderRe = regexp.MustCompile(`\$\d+\$`)
|
||||||
|
|
||||||
// ExplainSQL generate SQL string with given parameters, the generated SQL is expected to be used in logger, execute it might introduce a SQL injection vulnerability
|
// ExplainSQL generate SQL string with given parameters, the generated SQL is expected to be used in logger, execute it might introduce a SQL injection vulnerability
|
||||||
|
@ -102,6 +102,12 @@ func TestExplainSQL(t *testing.T) {
|
|||||||
Vars: []interface{}{"jinzhu", 1, 0.1753607109, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es},
|
Vars: []interface{}{"jinzhu", 1, 0.1753607109, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es},
|
||||||
Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 0.1753607109, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
|
Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 0.1753607109, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||||
|
NumericRegexp: nil,
|
||||||
|
Vars: []interface{}{"jinzhu", 1, float32(999.99), true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es},
|
||||||
|
Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for idx, r := range results {
|
for idx, r := range results {
|
||||||
|
@ -16,8 +16,17 @@ import (
|
|||||||
"gorm.io/gorm/schema"
|
"gorm.io/gorm/schema"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// This regular expression seeks to find a sequence of digits (\d+) among zero or more non-digit characters (\D*),
|
||||||
|
// with a possible trailing non-digit character (\D?).
|
||||||
|
|
||||||
|
// For example, values that can pass this regular expression are:
|
||||||
|
// - "123"
|
||||||
|
// - "abc456"
|
||||||
|
// -"%$#@789"
|
||||||
var regFullDataType = regexp.MustCompile(`\D*(\d+)\D?`)
|
var regFullDataType = regexp.MustCompile(`\D*(\d+)\D?`)
|
||||||
|
|
||||||
|
// TODO:? Create const vars for raw sql queries ?
|
||||||
|
|
||||||
// Migrator m struct
|
// Migrator m struct
|
||||||
type Migrator struct {
|
type Migrator struct {
|
||||||
Config
|
Config
|
||||||
|
@ -31,14 +31,14 @@ func NewPreparedStmtDB(connPool ConnPool) *PreparedStmtDB {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) {
|
func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) {
|
||||||
if dbConnector, ok := db.ConnPool.(GetDBConnector); ok && dbConnector != nil {
|
|
||||||
return dbConnector.GetDBConn()
|
|
||||||
}
|
|
||||||
|
|
||||||
if sqldb, ok := db.ConnPool.(*sql.DB); ok {
|
if sqldb, ok := db.ConnPool.(*sql.DB); ok {
|
||||||
return sqldb, nil
|
return sqldb, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if dbConnector, ok := db.ConnPool.(GetDBConnector); ok && dbConnector != nil {
|
||||||
|
return dbConnector.GetDBConn()
|
||||||
|
}
|
||||||
|
|
||||||
return nil, ErrInvalidDB
|
return nil, ErrInvalidDB
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -54,15 +54,15 @@ func (db *PreparedStmtDB) Close() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *PreparedStmtDB) Reset() {
|
func (sdb *PreparedStmtDB) Reset() {
|
||||||
db.Mux.Lock()
|
sdb.Mux.Lock()
|
||||||
defer db.Mux.Unlock()
|
defer sdb.Mux.Unlock()
|
||||||
|
|
||||||
for _, stmt := range db.Stmts {
|
for _, stmt := range sdb.Stmts {
|
||||||
go stmt.Close()
|
go stmt.Close()
|
||||||
}
|
}
|
||||||
db.PreparedSQL = make([]string, 0, 100)
|
sdb.PreparedSQL = make([]string, 0, 100)
|
||||||
db.Stmts = make(map[string]*Stmt)
|
sdb.Stmts = make(map[string]*Stmt)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) {
|
func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) {
|
||||||
@ -127,6 +127,19 @@ func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (Conn
|
|||||||
tx, err := beginner.BeginTx(ctx, opt)
|
tx, err := beginner.BeginTx(ctx, opt)
|
||||||
return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err
|
return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
beginner, ok := db.ConnPool.(ConnPoolBeginner)
|
||||||
|
if !ok {
|
||||||
|
return nil, ErrInvalidTransaction
|
||||||
|
}
|
||||||
|
|
||||||
|
connPool, err := beginner.BeginTx(ctx, opt)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if tx, ok := connPool.(Tx); ok {
|
||||||
|
return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, nil
|
||||||
|
}
|
||||||
return nil, ErrInvalidTransaction
|
return nil, ErrInvalidTransaction
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -172,6 +185,10 @@ type PreparedStmtTX struct {
|
|||||||
PreparedStmtDB *PreparedStmtDB
|
PreparedStmtDB *PreparedStmtDB
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (db *PreparedStmtTX) GetDBConn() (*sql.DB, error) {
|
||||||
|
return db.PreparedStmtDB.GetDBConn()
|
||||||
|
}
|
||||||
|
|
||||||
func (tx *PreparedStmtTX) Commit() error {
|
func (tx *PreparedStmtTX) Commit() error {
|
||||||
if tx.Tx != nil && !reflect.ValueOf(tx.Tx).IsNil() {
|
if tx.Tx != nil && !reflect.ValueOf(tx.Tx).IsNil() {
|
||||||
return tx.Tx.Commit()
|
return tx.Tx.Commit()
|
||||||
|
@ -115,6 +115,11 @@ func GetIdentityFieldValuesMap(ctx context.Context, reflectValue reflect.Value,
|
|||||||
notZero, zero bool
|
notZero, zero bool
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if reflectValue.Kind() == reflect.Ptr ||
|
||||||
|
reflectValue.Kind() == reflect.Interface {
|
||||||
|
reflectValue = reflectValue.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
switch reflectValue.Kind() {
|
switch reflectValue.Kind() {
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
results = [][]interface{}{make([]interface{}, len(fields))}
|
results = [][]interface{}{make([]interface{}, len(fields))}
|
||||||
|
@ -22,9 +22,9 @@ require (
|
|||||||
github.com/jackc/pgx/v5 v5.4.2 // indirect
|
github.com/jackc/pgx/v5 v5.4.2 // indirect
|
||||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||||
github.com/mattn/go-sqlite3 v1.14.17 // indirect
|
github.com/mattn/go-sqlite3 v1.14.17 // indirect
|
||||||
github.com/microsoft/go-mssqldb v1.4.0 // indirect
|
github.com/microsoft/go-mssqldb v1.5.0 // indirect
|
||||||
golang.org/x/crypto v0.11.0 // indirect
|
golang.org/x/crypto v0.12.0 // indirect
|
||||||
golang.org/x/text v0.11.0 // indirect
|
golang.org/x/text v0.12.0 // indirect
|
||||||
)
|
)
|
||||||
|
|
||||||
replace gorm.io/gorm => ../
|
replace gorm.io/gorm => ../
|
||||||
|
Loading…
x
Reference in New Issue
Block a user