Merge branch 'master' into distinguish_unique

# Conflicts:
#	tests/go.mod
#	tests/migrate_test.go
This commit is contained in:
black 2023-10-13 18:11:31 +08:00
commit cf41030029
25 changed files with 385 additions and 93 deletions

View File

@ -11,7 +11,7 @@ jobs:
name: Label issues and pull requests
steps:
- name: check out
uses: actions/checkout@v3
uses: actions/checkout@v4
- name: labeler
uses: jinzhu/super-labeler-action@develop

View File

@ -6,7 +6,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Check out code into the Go module directory
uses: actions/checkout@v3
uses: actions/checkout@v4
- name: golangci-lint
uses: reviewdog/action-golangci-lint@v2

View File

@ -16,7 +16,7 @@ jobs:
sqlite:
strategy:
matrix:
go: ['1.19', '1.18']
go: ['1.21', '1.20', '1.19']
platform: [ubuntu-latest] # can not run in windows OS
runs-on: ${{ matrix.platform }}
@ -27,7 +27,7 @@ jobs:
go-version: ${{ matrix.go }}
- name: Check out code into the Go module directory
uses: actions/checkout@v3
uses: actions/checkout@v4
- name: go mod package cache
uses: actions/cache@v3
@ -42,7 +42,7 @@ jobs:
strategy:
matrix:
dbversion: ['mysql:latest', 'mysql:5.7']
go: ['1.19', '1.18']
go: ['1.21', '1.20', '1.19']
platform: [ubuntu-latest]
runs-on: ${{ matrix.platform }}
@ -70,7 +70,7 @@ jobs:
go-version: ${{ matrix.go }}
- name: Check out code into the Go module directory
uses: actions/checkout@v3
uses: actions/checkout@v4
- name: go mod package cache
uses: actions/cache@v3
@ -85,7 +85,7 @@ jobs:
strategy:
matrix:
dbversion: [ 'mariadb:latest' ]
go: [ '1.19', '1.18' ]
go: ['1.21', '1.20', '1.19']
platform: [ ubuntu-latest ]
runs-on: ${{ matrix.platform }}
@ -113,7 +113,7 @@ jobs:
go-version: ${{ matrix.go }}
- name: Check out code into the Go module directory
uses: actions/checkout@v3
uses: actions/checkout@v4
- name: go mod package cache
uses: actions/cache@v3
@ -128,7 +128,7 @@ jobs:
strategy:
matrix:
dbversion: ['postgres:latest', 'postgres:13', 'postgres:12', 'postgres:11', 'postgres:10']
go: ['1.19', '1.18']
go: ['1.21', '1.20', '1.19']
platform: [ubuntu-latest] # can not run in macOS and Windows
runs-on: ${{ matrix.platform }}
@ -156,7 +156,7 @@ jobs:
go-version: ${{ matrix.go }}
- name: Check out code into the Go module directory
uses: actions/checkout@v3
uses: actions/checkout@v4
- name: go mod package cache
uses: actions/cache@v3
@ -170,7 +170,7 @@ jobs:
sqlserver:
strategy:
matrix:
go: ['1.19', '1.18']
go: ['1.21', '1.20', '1.19']
platform: [ubuntu-latest] # can not run test in macOS and windows
runs-on: ${{ matrix.platform }}
@ -199,7 +199,7 @@ jobs:
go-version: ${{ matrix.go }}
- name: Check out code into the Go module directory
uses: actions/checkout@v3
uses: actions/checkout@v4
- name: go mod package cache
uses: actions/cache@v3
@ -214,7 +214,7 @@ jobs:
strategy:
matrix:
dbversion: [ 'v6.5.0' ]
go: [ '1.19', '1.18' ]
go: ['1.21', '1.20', '1.19']
platform: [ ubuntu-latest ]
runs-on: ${{ matrix.platform }}
@ -231,7 +231,7 @@ jobs:
go-version: ${{ matrix.go }}
- name: Check out code into the Go module directory
uses: actions/checkout@v3
uses: actions/checkout@v4
- name: go mod package cache

View File

@ -126,7 +126,7 @@ func (expr NamedExpr) Build(builder Builder) {
for _, v := range []byte(expr.SQL) {
if v == '@' && !inName {
inName = true
name = []byte{}
name = name[:0]
} else if v == ' ' || v == ',' || v == ')' || v == '"' || v == '\'' || v == '`' || v == '\r' || v == '\n' || v == ';' {
if inName {
if nv, ok := namedMap[string(name)]; ok {
@ -246,8 +246,11 @@ func (eq Eq) Build(builder Builder) {
switch eq.Value.(type) {
case []string, []int, []int32, []int64, []uint, []uint32, []uint64, []interface{}:
builder.WriteString(" IN (")
rv := reflect.ValueOf(eq.Value)
if rv.Len() == 0 {
builder.WriteString(" IN (NULL)")
} else {
builder.WriteString(" IN (")
for i := 0; i < rv.Len(); i++ {
if i > 0 {
builder.WriteByte(',')
@ -255,6 +258,7 @@ func (eq Eq) Build(builder Builder) {
builder.AddVar(builder, rv.Index(i).Interface())
}
builder.WriteByte(')')
}
default:
if eqNil(eq.Value) {
builder.WriteString(" IS NULL")

View File

@ -199,6 +199,11 @@ func TestExpression(t *testing.T) {
},
ExpectedVars: []interface{}{"a", "b"},
Result: "`column-name` NOT IN (?,?)",
}, {
Expressions: []clause.Expression{
clause.Eq{Column: column, Value: []string{}},
},
Result: "`column-name` IN (NULL)",
}, {
Expressions: []clause.Expression{
clause.Eq{Column: clause.Expr{SQL: "SUM(?)", Vars: []interface{}{clause.Column{Name: "id"}}}, Value: 100},

2
go.mod
View File

@ -1,6 +1,6 @@
module gorm.io/gorm
go 1.16
go 1.18
require (
github.com/jinzhu/inflection v1.0.0

2
go.sum
View File

@ -1,6 +1,4 @@
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.4 h1:tHnRBy1i5F2Dh8BAFxqFzxKqqvezXrL2OW1TnX+Mlas=
github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=

12
gorm.go
View File

@ -4,6 +4,7 @@ import (
"context"
"database/sql"
"fmt"
"reflect"
"sort"
"sync"
"time"
@ -181,7 +182,7 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) {
err = config.Dialector.Initialize(db)
if err != nil {
if db, err := db.DB(); err == nil {
if db, _ := db.DB(); db != nil {
_ = db.Close()
}
}
@ -374,9 +375,11 @@ func (db *DB) AddError(err error) error {
// DB returns `*sql.DB`
func (db *DB) DB() (*sql.DB, error) {
connPool := db.ConnPool
if connector, ok := connPool.(GetDBConnectorWithContext); ok && connector != nil {
return connector.GetDBConnWithContext(db)
if db.Statement != nil && db.Statement.ConnPool != nil {
connPool = db.Statement.ConnPool
}
if tx, ok := connPool.(*sql.Tx); ok && tx != nil {
return (*sql.DB)(reflect.ValueOf(tx).Elem().FieldByName("db").UnsafePointer()), nil
}
if dbConnector, ok := connPool.(GetDBConnector); ok && dbConnector != nil {
@ -404,6 +407,7 @@ func (db *DB) getInstance() *DB {
Context: db.Statement.Context,
Clauses: map[string]clause.Clause{},
Vars: make([]interface{}, 0, 8),
SkipHooks: db.Statement.SkipHooks,
}
} else {
// with clone statement

View File

@ -77,12 +77,6 @@ type GetDBConnector interface {
GetDBConn() (*sql.DB, error)
}
// GetDBConnectorWithContext represents SQL db connector which takes into
// account the current database context
type GetDBConnectorWithContext interface {
GetDBConnWithContext(db *DB) (*sql.DB, error)
}
// Rows rows interface
type Rows interface {
Columns() ([]string, error)

View File

@ -28,8 +28,10 @@ func isPrintable(s string) bool {
return true
}
// A list of Go types that should be converted to SQL primitives
var convertibleTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeOf(false), reflect.TypeOf([]byte{})}
// RegEx matches only numeric values
var numericPlaceholderRe = regexp.MustCompile(`\$\d+\$`)
// ExplainSQL generate SQL string with given parameters, the generated SQL is expected to be used in logger, execute it might introduce a SQL injection vulnerability
@ -93,8 +95,10 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a
}
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
vars[idx] = utils.ToString(v)
case float64, float32:
vars[idx] = fmt.Sprintf("%.6f", v)
case float32:
vars[idx] = strconv.FormatFloat(float64(v), 'f', -1, 32)
case float64:
vars[idx] = strconv.FormatFloat(v, 'f', -1, 64)
case string:
vars[idx] = escaper + strings.ReplaceAll(v, escaper, "\\"+escaper) + escaper
default:

View File

@ -57,43 +57,55 @@ func TestExplainSQL(t *testing.T) {
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
NumericRegexp: nil,
Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd},
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass")`,
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass")`,
},
{
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
NumericRegexp: nil,
Vars: []interface{}{"jinzhu?", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd},
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu?", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass")`,
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu?", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass")`,
},
{
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (@p1, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9, @p10, @p11)",
NumericRegexp: regexp.MustCompile(`@p(\d+)`),
Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.com", myrole, pwd},
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`,
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`,
},
{
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ($3, $4, $1, $2, $7, $8, $5, $6, $9, $10, $11)",
NumericRegexp: regexp.MustCompile(`\$(\d+)`),
Vars: []interface{}{999.99, true, "jinzhu", 1, &tt, nil, []byte("12345"), tt, "w@g.com", myrole, pwd},
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`,
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`,
},
{
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (@p1, @p11, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9, @p10)",
NumericRegexp: regexp.MustCompile(`@p(\d+)`),
Vars: []interface{}{"jinzhu", 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.com", myrole, pwd, 1},
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`,
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`,
},
{
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
NumericRegexp: nil,
Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, js, es},
Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
},
{
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
NumericRegexp: nil,
Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es},
Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
},
{
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
NumericRegexp: nil,
Vars: []interface{}{"jinzhu", 1, 0.1753607109, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es},
Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 0.1753607109, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
},
{
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
NumericRegexp: nil,
Vars: []interface{}{"jinzhu", 1, float32(999.99), true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es},
Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
},
}

View File

@ -16,8 +16,17 @@ import (
"gorm.io/gorm/schema"
)
// This regular expression seeks to find a sequence of digits (\d+) among zero or more non-digit characters (\D*),
// with a possible trailing non-digit character (\D?).
// For example, values that can pass this regular expression are:
// - "123"
// - "abc456"
// -"%$#@789"
var regFullDataType = regexp.MustCompile(`\D*(\d+)\D?`)
// TODO:? Create const vars for raw sql queries ?
// Migrator m struct
type Migrator struct {
Config
@ -209,7 +218,7 @@ func (m Migrator) CreateTable(values ...interface{}) error {
field := stmt.Schema.FieldsByDBName[dbName]
if !field.IgnoreMigration {
createTableSQL += "? ?"
hasPrimaryKeyInDataType = hasPrimaryKeyInDataType || strings.Contains(strings.ToUpper(string(field.DataType)), "PRIMARY KEY")
hasPrimaryKeyInDataType = hasPrimaryKeyInDataType || strings.Contains(strings.ToUpper(m.DataTypeOf(field)), "PRIMARY KEY")
values = append(values, clause.Column{Name: dbName}, m.DB.Migrator().FullDataTypeOf(field))
createTableSQL += ","
}

View File

@ -31,14 +31,14 @@ func NewPreparedStmtDB(connPool ConnPool) *PreparedStmtDB {
}
func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) {
if dbConnector, ok := db.ConnPool.(GetDBConnector); ok && dbConnector != nil {
return dbConnector.GetDBConn()
}
if sqldb, ok := db.ConnPool.(*sql.DB); ok {
return sqldb, nil
}
if dbConnector, ok := db.ConnPool.(GetDBConnector); ok && dbConnector != nil {
return dbConnector.GetDBConn()
}
return nil, ErrInvalidDB
}
@ -54,15 +54,15 @@ func (db *PreparedStmtDB) Close() {
}
}
func (db *PreparedStmtDB) Reset() {
db.Mux.Lock()
defer db.Mux.Unlock()
func (sdb *PreparedStmtDB) Reset() {
sdb.Mux.Lock()
defer sdb.Mux.Unlock()
for _, stmt := range db.Stmts {
for _, stmt := range sdb.Stmts {
go stmt.Close()
}
db.PreparedSQL = make([]string, 0, 100)
db.Stmts = make(map[string]*Stmt)
sdb.PreparedSQL = make([]string, 0, 100)
sdb.Stmts = make(map[string]*Stmt)
}
func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) {
@ -127,6 +127,19 @@ func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (Conn
tx, err := beginner.BeginTx(ctx, opt)
return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err
}
beginner, ok := db.ConnPool.(ConnPoolBeginner)
if !ok {
return nil, ErrInvalidTransaction
}
connPool, err := beginner.BeginTx(ctx, opt)
if err != nil {
return nil, err
}
if tx, ok := connPool.(Tx); ok {
return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, nil
}
return nil, ErrInvalidTransaction
}
@ -172,6 +185,10 @@ type PreparedStmtTX struct {
PreparedStmtDB *PreparedStmtDB
}
func (db *PreparedStmtTX) GetDBConn() (*sql.DB, error) {
return db.PreparedStmtDB.GetDBConn()
}
func (tx *PreparedStmtTX) Commit() error {
if tx.Tx != nil && !reflect.ValueOf(tx.Tx).IsNil() {
return tx.Tx.Commit()

View File

@ -13,6 +13,20 @@ import (
"gorm.io/gorm/logger"
)
type callbackType string
const (
callbackTypeBeforeCreate callbackType = "BeforeCreate"
callbackTypeBeforeUpdate callbackType = "BeforeUpdate"
callbackTypeAfterCreate callbackType = "AfterCreate"
callbackTypeAfterUpdate callbackType = "AfterUpdate"
callbackTypeBeforeSave callbackType = "BeforeSave"
callbackTypeAfterSave callbackType = "AfterSave"
callbackTypeBeforeDelete callbackType = "BeforeDelete"
callbackTypeAfterDelete callbackType = "AfterDelete"
callbackTypeAfterFind callbackType = "AfterFind"
)
// ErrUnsupportedDataType unsupported data type
var ErrUnsupportedDataType = errors.New("unsupported data type")
@ -288,14 +302,20 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
}
}
callbacks := []string{"BeforeCreate", "AfterCreate", "BeforeUpdate", "AfterUpdate", "BeforeSave", "AfterSave", "BeforeDelete", "AfterDelete", "AfterFind"}
for _, name := range callbacks {
if methodValue := modelValue.MethodByName(name); methodValue.IsValid() {
callbackTypes := []callbackType{
callbackTypeBeforeCreate, callbackTypeAfterCreate,
callbackTypeBeforeUpdate, callbackTypeAfterUpdate,
callbackTypeBeforeSave, callbackTypeAfterSave,
callbackTypeBeforeDelete, callbackTypeAfterDelete,
callbackTypeAfterFind,
}
for _, cbName := range callbackTypes {
if methodValue := callBackToMethodValue(modelValue, cbName); methodValue.IsValid() {
switch methodValue.Type().String() {
case "func(*gorm.DB) error": // TODO hack
reflect.Indirect(reflect.ValueOf(schema)).FieldByName(name).SetBool(true)
reflect.Indirect(reflect.ValueOf(schema)).FieldByName(string(cbName)).SetBool(true)
default:
logger.Default.Warn(context.Background(), "Model %v don't match %vInterface, should be `%v(*gorm.DB) error`. Please see https://gorm.io/docs/hooks.html", schema, name, name)
logger.Default.Warn(context.Background(), "Model %v don't match %vInterface, should be `%v(*gorm.DB) error`. Please see https://gorm.io/docs/hooks.html", schema, cbName, cbName)
}
}
}
@ -349,6 +369,39 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
return schema, schema.err
}
// This unrolling is needed to show to the compiler the exact set of methods
// that can be used on the modelType.
// Prior to go1.22 any use of MethodByName would cause the linker to
// abandon dead code elimination for the entire binary.
// As of go1.22 the compiler supports one special case of a string constant
// being passed to MethodByName. For enterprise customers or those building
// large binaries, this gives a significant reduction in binary size.
// https://github.com/golang/go/issues/62257
func callBackToMethodValue(modelType reflect.Value, cbType callbackType) reflect.Value {
switch cbType {
case callbackTypeBeforeCreate:
return modelType.MethodByName(string(callbackTypeBeforeCreate))
case callbackTypeAfterCreate:
return modelType.MethodByName(string(callbackTypeAfterCreate))
case callbackTypeBeforeUpdate:
return modelType.MethodByName(string(callbackTypeBeforeUpdate))
case callbackTypeAfterUpdate:
return modelType.MethodByName(string(callbackTypeAfterUpdate))
case callbackTypeBeforeSave:
return modelType.MethodByName(string(callbackTypeBeforeSave))
case callbackTypeAfterSave:
return modelType.MethodByName(string(callbackTypeAfterSave))
case callbackTypeBeforeDelete:
return modelType.MethodByName(string(callbackTypeBeforeDelete))
case callbackTypeAfterDelete:
return modelType.MethodByName(string(callbackTypeAfterDelete))
case callbackTypeAfterFind:
return modelType.MethodByName(string(callbackTypeAfterFind))
default:
return reflect.ValueOf(nil)
}
}
func getOrParse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) {
modelType := reflect.ValueOf(dest).Type()
for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr {

View File

@ -115,6 +115,11 @@ func GetIdentityFieldValuesMap(ctx context.Context, reflectValue reflect.Value,
notZero, zero bool
)
if reflectValue.Kind() == reflect.Ptr ||
reflectValue.Kind() == reflect.Interface {
reflectValue = reflectValue.Elem()
}
switch reflectValue.Kind() {
case reflect.Struct:
results = [][]interface{}{make([]interface{}, len(fields))}

View File

@ -206,9 +206,9 @@ func TestDeleteSliceWithAssociations(t *testing.T) {
}
}
// only sqlite, postgres support returning
// only sqlite, postgres, sqlserver support returning
func TestSoftDeleteReturning(t *testing.T) {
if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" {
if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" && DB.Dialector.Name() != "sqlserver" {
return
}
@ -233,7 +233,7 @@ func TestSoftDeleteReturning(t *testing.T) {
}
func TestDeleteReturning(t *testing.T) {
if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" {
if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" && DB.Dialector.Name() != "sqlserver" {
return
}

View File

@ -44,6 +44,8 @@ func TestSupportedDialectorWithErrDuplicatedKey(t *testing.T) {
return
}
DB.Migrator().DropTable(&City{})
if err = db.AutoMigrate(&City{}); err != nil {
t.Fatalf("failed to migrate cities table, got error: %v", err)
}
@ -58,3 +60,52 @@ func TestSupportedDialectorWithErrDuplicatedKey(t *testing.T) {
t.Fatalf("expected err: %v got err: %v", gorm.ErrDuplicatedKey, err)
}
}
func TestSupportedDialectorWithErrForeignKeyViolated(t *testing.T) {
tidbSkip(t, "not support the foreign key feature")
type City struct {
gorm.Model
Name string `gorm:"unique"`
}
type Museum struct {
gorm.Model
Name string `gorm:"unique"`
CityID uint
City City `gorm:"Constraint:OnUpdate:CASCADE,OnDelete:CASCADE;FOREIGNKEY:CityID;References:ID"`
}
db, err := OpenTestConnection(&gorm.Config{TranslateError: true})
if err != nil {
t.Fatalf("failed to connect database, got error %v", err)
}
dialectors := map[string]bool{"sqlite": true, "postgres": true, "mysql": true, "sqlserver": true}
if supported, found := dialectors[db.Dialector.Name()]; !(found && supported) {
return
}
DB.Migrator().DropTable(&City{}, &Museum{})
if err = db.AutoMigrate(&City{}, &Museum{}); err != nil {
t.Fatalf("failed to migrate countries & cities tables, got error: %v", err)
}
city := City{Name: "Amsterdam"}
err = db.Create(&city).Error
if err != nil {
t.Fatalf("failed to create city: %v", err)
}
err = db.Create(&Museum{Name: "Eye Filmmuseum", CityID: city.ID}).Error
if err != nil {
t.Fatalf("failed to create museum: %v", err)
}
err = db.Create(&Museum{Name: "Dungeon", CityID: 123}).Error
if !errors.Is(err, gorm.ErrForeignKeyViolated) {
t.Fatalf("expected err: %v got err: %v", gorm.ErrForeignKeyViolated, err)
}
}

View File

@ -1,19 +1,30 @@
module gorm.io/gorm/tests
go 1.16
go 1.18
require (
github.com/google/uuid v1.3.0
github.com/jackc/pgx/v5 v5.3.1 // indirect
github.com/google/uuid v1.3.1
github.com/jinzhu/now v1.1.5
github.com/lib/pq v1.10.8
github.com/mattn/go-sqlite3 v1.14.16 // indirect
github.com/stretchr/testify v1.8.1
gorm.io/driver/mysql v1.5.0
gorm.io/driver/postgres v1.5.0
gorm.io/driver/sqlite v1.5.0
gorm.io/driver/sqlserver v1.5.1
gorm.io/gorm v1.25.1
github.com/lib/pq v1.10.9
gorm.io/driver/mysql v1.5.2-0.20230612053416-48b6526a21f0
gorm.io/driver/postgres v1.5.3-0.20230607070428-18bc84b75196
gorm.io/driver/sqlite v1.5.4
gorm.io/driver/sqlserver v1.5.2-0.20230613072041-6e2cde390b0a
gorm.io/gorm v1.25.4
)
require (
github.com/go-sql-driver/mysql v1.7.1 // indirect
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect
github.com/golang-sql/sqlexp v0.1.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
github.com/jackc/pgx/v5 v5.4.3 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/mattn/go-sqlite3 v1.14.17 // indirect
github.com/microsoft/go-mssqldb v1.6.0 // indirect
golang.org/x/crypto v0.14.0 // indirect
golang.org/x/text v0.13.0 // indirect
)
replace gorm.io/gorm => ../

View File

@ -265,6 +265,10 @@ func isTiDB() bool {
return os.Getenv("GORM_DIALECT") == "tidb"
}
func isMysql() bool {
return os.Getenv("GORM_DIALECT") == "mysql"
}
func db(unscoped bool) *gorm.DB {
if unscoped {
return DB.Unscoped()

View File

@ -868,6 +868,48 @@ func TestMigrateWithSpecialName(t *testing.T) {
AssertEqual(t, true, DB.Migrator().HasTable("coupon_product_2"))
}
// https://github.com/go-gorm/gorm/issues/4760
func TestMigrateAutoIncrement(t *testing.T) {
type AutoIncrementStruct struct {
ID int64 `gorm:"primarykey;autoIncrement"`
Field1 uint32 `gorm:"column:field1"`
Field2 float32 `gorm:"column:field2"`
}
if err := DB.AutoMigrate(&AutoIncrementStruct{}); err != nil {
t.Fatalf("AutoMigrate err: %v", err)
}
const ROWS = 10
for idx := 0; idx < ROWS; idx++ {
if err := DB.Create(&AutoIncrementStruct{}).Error; err != nil {
t.Fatalf("create auto_increment_struct fail, err: %v", err)
}
}
rows := make([]*AutoIncrementStruct, 0, ROWS)
if err := DB.Order("id ASC").Find(&rows).Error; err != nil {
t.Fatalf("find auto_increment_struct fail, err: %v", err)
}
ids := make([]int64, 0, len(rows))
for _, row := range rows {
ids = append(ids, row.ID)
}
lastID := ids[len(ids)-1]
if err := DB.Where("id IN (?)", ids).Delete(&AutoIncrementStruct{}).Error; err != nil {
t.Fatalf("delete auto_increment_struct fail, err: %v", err)
}
newRow := &AutoIncrementStruct{}
if err := DB.Create(newRow).Error; err != nil {
t.Fatalf("create auto_increment_struct fail, err: %v", err)
}
AssertEqual(t, newRow.ID, lastID+1)
}
// https://github.com/go-gorm/gorm/issues/5320
func TestPrimarykeyID(t *testing.T) {
if DB.Dialector.Name() != "postgres" {
@ -1607,6 +1649,51 @@ func TestMigrateExistingBoolColumnPG(t *testing.T) {
}
}
func TestTableType(t *testing.T) {
// currently it is only supported for mysql driver
if !isMysql() {
return
}
const tblName = "cities"
const tblSchema = "gorm"
const tblType = "BASE TABLE"
const tblComment = "foobar comment"
type City struct {
gorm.Model
Name string `gorm:"unique"`
}
DB.Migrator().DropTable(&City{})
if err := DB.Set("gorm:table_options", fmt.Sprintf("ENGINE InnoDB COMMENT '%s'", tblComment)).AutoMigrate(&City{}); err != nil {
t.Fatalf("failed to migrate cities tables, got error: %v", err)
}
tableType, err := DB.Table("cities").Migrator().TableType(&City{})
if err != nil {
t.Fatalf("failed to get table type, got error %v", err)
}
if tableType.Schema() != tblSchema {
t.Fatalf("expected tblSchema to be %s but got %s", tblSchema, tableType.Schema())
}
if tableType.Name() != tblName {
t.Fatalf("expected table name to be %s but got %s", tblName, tableType.Name())
}
if tableType.Type() != tblType {
t.Fatalf("expected table type to be %s but got %s", tblType, tableType.Type())
}
comment, ok := tableType.Comment()
if !ok || comment != tblComment {
t.Fatalf("expected comment %s got %s", tblComment, comment)
}
}
func TestMigrateWithUniqueIndexAndUnique(t *testing.T) {
const table = "unique_struct"

View File

@ -2,6 +2,7 @@ package tests_test
import (
"database/sql"
"database/sql/driver"
"fmt"
"reflect"
"regexp"
@ -658,6 +659,18 @@ func TestOrWithAllFields(t *testing.T) {
}
}
type Int64 int64
func (v Int64) Value() (driver.Value, error) {
return v - 1, nil
}
func (f *Int64) Scan(v interface{}) error {
y := v.(int64)
*f = Int64(y + 1)
return nil
}
func TestPluck(t *testing.T) {
users := []*User{
GetUser("pluck-user1", Config{}),
@ -685,6 +698,11 @@ func TestPluck(t *testing.T) {
t.Errorf("got error when pluck id: %v", err)
}
var ids2 []Int64
if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Pluck("id", &ids2).Error; err != nil {
t.Errorf("got error when pluck id: %v", err)
}
for idx, name := range names {
if name != users[idx].Name {
t.Errorf("Unexpected result on pluck name, got %+v", names)
@ -697,6 +715,12 @@ func TestPluck(t *testing.T) {
}
}
for idx, id := range ids2 {
if int(id) != int(users[idx].ID+1) {
t.Errorf("Unexpected result on pluck id, got %+v", ids)
}
}
var times []time.Time
if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Pluck("created_at", &times).Error; err != nil {
t.Errorf("got error when pluck time: %v", err)

View File

@ -89,7 +89,7 @@ func OpenTestConnection(cfg *gorm.Config) (db *gorm.DB, err error) {
db, err = gorm.Open(mysql.Open(dbDSN), cfg)
default:
log.Println("testing sqlite3...")
db, err = gorm.Open(sqlite.Open(filepath.Join(os.TempDir(), "gorm.db")), cfg)
db, err = gorm.Open(sqlite.Open(filepath.Join(os.TempDir(), "gorm.db?_foreign_keys=on")), cfg)
}
if err != nil {

View File

@ -765,9 +765,9 @@ func TestSaveWithPrimaryValue(t *testing.T) {
}
}
// only sqlite, postgres support returning
// only sqlite, postgres, sqlserver support returning
func TestUpdateReturning(t *testing.T) {
if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" {
if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" && DB.Dialector.Name() != "sqlserver" {
return
}

View File

@ -89,20 +89,29 @@ func Contains(elems []string, elem string) bool {
return false
}
func AssertEqual(src, dst interface{}) bool {
if !reflect.DeepEqual(src, dst) {
if valuer, ok := src.(driver.Valuer); ok {
src, _ = valuer.Value()
}
if valuer, ok := dst.(driver.Valuer); ok {
dst, _ = valuer.Value()
}
return reflect.DeepEqual(src, dst)
}
func AssertEqual(x, y interface{}) bool {
if reflect.DeepEqual(x, y) {
return true
}
if x == nil || y == nil {
return false
}
xval := reflect.ValueOf(x)
yval := reflect.ValueOf(y)
if xval.Kind() == reflect.Ptr && xval.IsNil() ||
yval.Kind() == reflect.Ptr && yval.IsNil() {
return false
}
if valuer, ok := x.(driver.Valuer); ok {
x, _ = valuer.Value()
}
if valuer, ok := y.(driver.Valuer); ok {
y, _ = valuer.Value()
}
return reflect.DeepEqual(x, y)
}
func ToString(value interface{}) string {
switch v := value.(type) {

View File

@ -98,6 +98,7 @@ func TestAssertEqual(t *testing.T) {
{"error not equal", errors.New("1"), errors.New("2"), false},
{"driver.Valuer equal", ModifyAt{Time: now, Valid: true}, ModifyAt{Time: now, Valid: true}, true},
{"driver.Valuer not equal", ModifyAt{Time: now, Valid: true}, ModifyAt{Time: now.Add(time.Second), Valid: true}, false},
{"driver.Valuer equal (ptr to nil ptr)", (*ModifyAt)(nil), &ModifyAt{}, false},
}
for _, test := range assertEqualTests {
t.Run(test.name, func(t *testing.T) {