Merge branch 'master' into create_on_confilct
# Conflicts: # tests/create_test.go
This commit is contained in:
commit
441092d82f
2
.github/workflows/labeler.yml
vendored
2
.github/workflows/labeler.yml
vendored
@ -11,7 +11,7 @@ jobs:
|
|||||||
name: Label issues and pull requests
|
name: Label issues and pull requests
|
||||||
steps:
|
steps:
|
||||||
- name: check out
|
- name: check out
|
||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: labeler
|
- name: labeler
|
||||||
uses: jinzhu/super-labeler-action@develop
|
uses: jinzhu/super-labeler-action@develop
|
||||||
|
2
.github/workflows/reviewdog.yml
vendored
2
.github/workflows/reviewdog.yml
vendored
@ -6,7 +6,7 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Check out code into the Go module directory
|
- name: Check out code into the Go module directory
|
||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v4
|
||||||
- name: golangci-lint
|
- name: golangci-lint
|
||||||
uses: reviewdog/action-golangci-lint@v2
|
uses: reviewdog/action-golangci-lint@v2
|
||||||
|
|
||||||
|
24
.github/workflows/tests.yml
vendored
24
.github/workflows/tests.yml
vendored
@ -16,7 +16,7 @@ jobs:
|
|||||||
sqlite:
|
sqlite:
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
go: ['1.19', '1.18']
|
go: ['1.21', '1.20', '1.19']
|
||||||
platform: [ubuntu-latest] # can not run in windows OS
|
platform: [ubuntu-latest] # can not run in windows OS
|
||||||
runs-on: ${{ matrix.platform }}
|
runs-on: ${{ matrix.platform }}
|
||||||
|
|
||||||
@ -27,7 +27,7 @@ jobs:
|
|||||||
go-version: ${{ matrix.go }}
|
go-version: ${{ matrix.go }}
|
||||||
|
|
||||||
- name: Check out code into the Go module directory
|
- name: Check out code into the Go module directory
|
||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: go mod package cache
|
- name: go mod package cache
|
||||||
uses: actions/cache@v3
|
uses: actions/cache@v3
|
||||||
@ -42,7 +42,7 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
dbversion: ['mysql:latest', 'mysql:5.7']
|
dbversion: ['mysql:latest', 'mysql:5.7']
|
||||||
go: ['1.19', '1.18']
|
go: ['1.21', '1.20', '1.19']
|
||||||
platform: [ubuntu-latest]
|
platform: [ubuntu-latest]
|
||||||
runs-on: ${{ matrix.platform }}
|
runs-on: ${{ matrix.platform }}
|
||||||
|
|
||||||
@ -70,7 +70,7 @@ jobs:
|
|||||||
go-version: ${{ matrix.go }}
|
go-version: ${{ matrix.go }}
|
||||||
|
|
||||||
- name: Check out code into the Go module directory
|
- name: Check out code into the Go module directory
|
||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: go mod package cache
|
- name: go mod package cache
|
||||||
uses: actions/cache@v3
|
uses: actions/cache@v3
|
||||||
@ -85,7 +85,7 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
dbversion: [ 'mariadb:latest' ]
|
dbversion: [ 'mariadb:latest' ]
|
||||||
go: [ '1.19', '1.18' ]
|
go: ['1.21', '1.20', '1.19']
|
||||||
platform: [ ubuntu-latest ]
|
platform: [ ubuntu-latest ]
|
||||||
runs-on: ${{ matrix.platform }}
|
runs-on: ${{ matrix.platform }}
|
||||||
|
|
||||||
@ -113,7 +113,7 @@ jobs:
|
|||||||
go-version: ${{ matrix.go }}
|
go-version: ${{ matrix.go }}
|
||||||
|
|
||||||
- name: Check out code into the Go module directory
|
- name: Check out code into the Go module directory
|
||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: go mod package cache
|
- name: go mod package cache
|
||||||
uses: actions/cache@v3
|
uses: actions/cache@v3
|
||||||
@ -128,7 +128,7 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
dbversion: ['postgres:latest', 'postgres:13', 'postgres:12', 'postgres:11', 'postgres:10']
|
dbversion: ['postgres:latest', 'postgres:13', 'postgres:12', 'postgres:11', 'postgres:10']
|
||||||
go: ['1.19', '1.18']
|
go: ['1.21', '1.20', '1.19']
|
||||||
platform: [ubuntu-latest] # can not run in macOS and Windows
|
platform: [ubuntu-latest] # can not run in macOS and Windows
|
||||||
runs-on: ${{ matrix.platform }}
|
runs-on: ${{ matrix.platform }}
|
||||||
|
|
||||||
@ -156,7 +156,7 @@ jobs:
|
|||||||
go-version: ${{ matrix.go }}
|
go-version: ${{ matrix.go }}
|
||||||
|
|
||||||
- name: Check out code into the Go module directory
|
- name: Check out code into the Go module directory
|
||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: go mod package cache
|
- name: go mod package cache
|
||||||
uses: actions/cache@v3
|
uses: actions/cache@v3
|
||||||
@ -170,7 +170,7 @@ jobs:
|
|||||||
sqlserver:
|
sqlserver:
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
go: ['1.19', '1.18']
|
go: ['1.21', '1.20', '1.19']
|
||||||
platform: [ubuntu-latest] # can not run test in macOS and windows
|
platform: [ubuntu-latest] # can not run test in macOS and windows
|
||||||
runs-on: ${{ matrix.platform }}
|
runs-on: ${{ matrix.platform }}
|
||||||
|
|
||||||
@ -199,7 +199,7 @@ jobs:
|
|||||||
go-version: ${{ matrix.go }}
|
go-version: ${{ matrix.go }}
|
||||||
|
|
||||||
- name: Check out code into the Go module directory
|
- name: Check out code into the Go module directory
|
||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: go mod package cache
|
- name: go mod package cache
|
||||||
uses: actions/cache@v3
|
uses: actions/cache@v3
|
||||||
@ -214,7 +214,7 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
dbversion: [ 'v6.5.0' ]
|
dbversion: [ 'v6.5.0' ]
|
||||||
go: [ '1.19', '1.18' ]
|
go: ['1.21', '1.20', '1.19']
|
||||||
platform: [ ubuntu-latest ]
|
platform: [ ubuntu-latest ]
|
||||||
runs-on: ${{ matrix.platform }}
|
runs-on: ${{ matrix.platform }}
|
||||||
|
|
||||||
@ -231,7 +231,7 @@ jobs:
|
|||||||
go-version: ${{ matrix.go }}
|
go-version: ${{ matrix.go }}
|
||||||
|
|
||||||
- name: Check out code into the Go module directory
|
- name: Check out code into the Go module directory
|
||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
|
||||||
- name: go mod package cache
|
- name: go mod package cache
|
||||||
|
@ -41,4 +41,4 @@ The fantastic ORM library for Golang, aims to be developer friendly.
|
|||||||
|
|
||||||
© Jinzhu, 2013~time.Now
|
© Jinzhu, 2013~time.Now
|
||||||
|
|
||||||
Released under the [MIT License](https://github.com/go-gorm/gorm/blob/master/License)
|
Released under the [MIT License](https://github.com/go-gorm/gorm/blob/master/LICENSE)
|
||||||
|
@ -102,13 +102,53 @@ func Create(config *Config) func(db *gorm.DB) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
db.RowsAffected, _ = result.RowsAffected()
|
db.RowsAffected, _ = result.RowsAffected()
|
||||||
if db.RowsAffected != 0 && db.Statement.Schema != nil &&
|
if db.RowsAffected == 0 {
|
||||||
db.Statement.Schema.PrioritizedPrimaryField != nil &&
|
return
|
||||||
db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue {
|
}
|
||||||
insertID, err := result.LastInsertId()
|
|
||||||
insertOk := err == nil && insertID > 0
|
var (
|
||||||
if !insertOk {
|
pkField *schema.Field
|
||||||
db.AddError(err)
|
pkFieldName = "@id"
|
||||||
|
)
|
||||||
|
if db.Statement.Schema != nil {
|
||||||
|
if db.Statement.Schema.PrioritizedPrimaryField == nil || !db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
pkField = db.Statement.Schema.PrioritizedPrimaryField
|
||||||
|
pkFieldName = db.Statement.Schema.PrioritizedPrimaryField.DBName
|
||||||
|
}
|
||||||
|
|
||||||
|
insertID, err := result.LastInsertId()
|
||||||
|
insertOk := err == nil && insertID > 0
|
||||||
|
if !insertOk {
|
||||||
|
db.AddError(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// append @id column with value for auto-increment primary key
|
||||||
|
// the @id value is correct, when: 1. without setting auto-increment primary key, 2. database AutoIncrementIncrement = 1
|
||||||
|
switch values := db.Statement.Dest.(type) {
|
||||||
|
case map[string]interface{}:
|
||||||
|
values[pkFieldName] = insertID
|
||||||
|
case *map[string]interface{}:
|
||||||
|
(*values)[pkFieldName] = insertID
|
||||||
|
case []map[string]interface{}, *[]map[string]interface{}:
|
||||||
|
mapValues, ok := values.([]map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
if v, ok := values.(*[]map[string]interface{}); ok {
|
||||||
|
if *v != nil {
|
||||||
|
mapValues = *v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, mapValue := range mapValues {
|
||||||
|
if mapValue != nil {
|
||||||
|
mapValue[pkFieldName] = insertID
|
||||||
|
}
|
||||||
|
insertID += schema.DefaultAutoIncrementIncrement
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if pkField == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -121,10 +161,10 @@ func Create(config *Config) func(db *gorm.DB) {
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv)
|
_, isZero := pkField.ValueOf(db.Statement.Context, rv)
|
||||||
if isZero {
|
if isZero {
|
||||||
db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, insertID))
|
db.AddError(pkField.Set(db.Statement.Context, rv, insertID))
|
||||||
insertID -= db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement
|
insertID -= pkField.AutoIncrementIncrement
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@ -134,16 +174,16 @@ func Create(config *Config) func(db *gorm.DB) {
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv); isZero {
|
if _, isZero := pkField.ValueOf(db.Statement.Context, rv); isZero {
|
||||||
db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, insertID))
|
db.AddError(pkField.Set(db.Statement.Context, rv, insertID))
|
||||||
insertID += db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement
|
insertID += pkField.AutoIncrementIncrement
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, db.Statement.ReflectValue)
|
_, isZero := pkField.ValueOf(db.Statement.Context, db.Statement.ReflectValue)
|
||||||
if isZero {
|
if isZero {
|
||||||
db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, db.Statement.ReflectValue, insertID))
|
db.AddError(pkField.Set(db.Statement.Context, db.Statement.ReflectValue, insertID))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -126,7 +126,7 @@ func (expr NamedExpr) Build(builder Builder) {
|
|||||||
for _, v := range []byte(expr.SQL) {
|
for _, v := range []byte(expr.SQL) {
|
||||||
if v == '@' && !inName {
|
if v == '@' && !inName {
|
||||||
inName = true
|
inName = true
|
||||||
name = []byte{}
|
name = name[:0]
|
||||||
} else if v == ' ' || v == ',' || v == ')' || v == '"' || v == '\'' || v == '`' || v == '\r' || v == '\n' || v == ';' {
|
} else if v == ' ' || v == ',' || v == ')' || v == '"' || v == '\'' || v == '`' || v == '\r' || v == '\n' || v == ';' {
|
||||||
if inName {
|
if inName {
|
||||||
if nv, ok := namedMap[string(name)]; ok {
|
if nv, ok := namedMap[string(name)]; ok {
|
||||||
@ -246,15 +246,19 @@ func (eq Eq) Build(builder Builder) {
|
|||||||
|
|
||||||
switch eq.Value.(type) {
|
switch eq.Value.(type) {
|
||||||
case []string, []int, []int32, []int64, []uint, []uint32, []uint64, []interface{}:
|
case []string, []int, []int32, []int64, []uint, []uint32, []uint64, []interface{}:
|
||||||
builder.WriteString(" IN (")
|
|
||||||
rv := reflect.ValueOf(eq.Value)
|
rv := reflect.ValueOf(eq.Value)
|
||||||
for i := 0; i < rv.Len(); i++ {
|
if rv.Len() == 0 {
|
||||||
if i > 0 {
|
builder.WriteString(" IN (NULL)")
|
||||||
builder.WriteByte(',')
|
} else {
|
||||||
|
builder.WriteString(" IN (")
|
||||||
|
for i := 0; i < rv.Len(); i++ {
|
||||||
|
if i > 0 {
|
||||||
|
builder.WriteByte(',')
|
||||||
|
}
|
||||||
|
builder.AddVar(builder, rv.Index(i).Interface())
|
||||||
}
|
}
|
||||||
builder.AddVar(builder, rv.Index(i).Interface())
|
builder.WriteByte(')')
|
||||||
}
|
}
|
||||||
builder.WriteByte(')')
|
|
||||||
default:
|
default:
|
||||||
if eqNil(eq.Value) {
|
if eqNil(eq.Value) {
|
||||||
builder.WriteString(" IS NULL")
|
builder.WriteString(" IS NULL")
|
||||||
|
@ -199,6 +199,11 @@ func TestExpression(t *testing.T) {
|
|||||||
},
|
},
|
||||||
ExpectedVars: []interface{}{"a", "b"},
|
ExpectedVars: []interface{}{"a", "b"},
|
||||||
Result: "`column-name` NOT IN (?,?)",
|
Result: "`column-name` NOT IN (?,?)",
|
||||||
|
}, {
|
||||||
|
Expressions: []clause.Expression{
|
||||||
|
clause.Eq{Column: column, Value: []string{}},
|
||||||
|
},
|
||||||
|
Result: "`column-name` IN (NULL)",
|
||||||
}, {
|
}, {
|
||||||
Expressions: []clause.Expression{
|
Expressions: []clause.Expression{
|
||||||
clause.Eq{Column: clause.Expr{SQL: "SUM(?)", Vars: []interface{}{clause.Column{Name: "id"}}}, Value: 100},
|
clause.Eq{Column: clause.Expr{SQL: "SUM(?)", Vars: []interface{}{clause.Column{Name: "id"}}}, Value: 100},
|
||||||
|
2
go.mod
2
go.mod
@ -1,6 +1,6 @@
|
|||||||
module gorm.io/gorm
|
module gorm.io/gorm
|
||||||
|
|
||||||
go 1.16
|
go 1.18
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/jinzhu/inflection v1.0.0
|
github.com/jinzhu/inflection v1.0.0
|
||||||
|
2
go.sum
2
go.sum
@ -1,6 +1,4 @@
|
|||||||
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
||||||
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
||||||
github.com/jinzhu/now v1.1.4 h1:tHnRBy1i5F2Dh8BAFxqFzxKqqvezXrL2OW1TnX+Mlas=
|
|
||||||
github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
|
||||||
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
|
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
|
||||||
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
||||||
|
22
gorm.go
22
gorm.go
@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"reflect"
|
||||||
"sort"
|
"sort"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
@ -181,7 +182,7 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) {
|
|||||||
err = config.Dialector.Initialize(db)
|
err = config.Dialector.Initialize(db)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if db, err := db.DB(); err == nil {
|
if db, _ := db.DB(); db != nil {
|
||||||
_ = db.Close()
|
_ = db.Close()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -374,9 +375,11 @@ func (db *DB) AddError(err error) error {
|
|||||||
// DB returns `*sql.DB`
|
// DB returns `*sql.DB`
|
||||||
func (db *DB) DB() (*sql.DB, error) {
|
func (db *DB) DB() (*sql.DB, error) {
|
||||||
connPool := db.ConnPool
|
connPool := db.ConnPool
|
||||||
|
if db.Statement != nil && db.Statement.ConnPool != nil {
|
||||||
if connector, ok := connPool.(GetDBConnectorWithContext); ok && connector != nil {
|
connPool = db.Statement.ConnPool
|
||||||
return connector.GetDBConnWithContext(db)
|
}
|
||||||
|
if tx, ok := connPool.(*sql.Tx); ok && tx != nil {
|
||||||
|
return (*sql.DB)(reflect.ValueOf(tx).Elem().FieldByName("db").UnsafePointer()), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if dbConnector, ok := connPool.(GetDBConnector); ok && dbConnector != nil {
|
if dbConnector, ok := connPool.(GetDBConnector); ok && dbConnector != nil {
|
||||||
@ -399,11 +402,12 @@ func (db *DB) getInstance() *DB {
|
|||||||
if db.clone == 1 {
|
if db.clone == 1 {
|
||||||
// clone with new statement
|
// clone with new statement
|
||||||
tx.Statement = &Statement{
|
tx.Statement = &Statement{
|
||||||
DB: tx,
|
DB: tx,
|
||||||
ConnPool: db.Statement.ConnPool,
|
ConnPool: db.Statement.ConnPool,
|
||||||
Context: db.Statement.Context,
|
Context: db.Statement.Context,
|
||||||
Clauses: map[string]clause.Clause{},
|
Clauses: map[string]clause.Clause{},
|
||||||
Vars: make([]interface{}, 0, 8),
|
Vars: make([]interface{}, 0, 8),
|
||||||
|
SkipHooks: db.Statement.SkipHooks,
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// with clone statement
|
// with clone statement
|
||||||
|
@ -77,12 +77,6 @@ type GetDBConnector interface {
|
|||||||
GetDBConn() (*sql.DB, error)
|
GetDBConn() (*sql.DB, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetDBConnectorWithContext represents SQL db connector which takes into
|
|
||||||
// account the current database context
|
|
||||||
type GetDBConnectorWithContext interface {
|
|
||||||
GetDBConnWithContext(db *DB) (*sql.DB, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Rows rows interface
|
// Rows rows interface
|
||||||
type Rows interface {
|
type Rows interface {
|
||||||
Columns() ([]string, error)
|
Columns() ([]string, error)
|
||||||
|
@ -69,7 +69,7 @@ type Interface interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
// Discard Discard logger will print any log to io.Discard
|
// Discard logger will print any log to io.Discard
|
||||||
Discard = New(log.New(io.Discard, "", log.LstdFlags), Config{})
|
Discard = New(log.New(io.Discard, "", log.LstdFlags), Config{})
|
||||||
// Default Default logger
|
// Default Default logger
|
||||||
Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{
|
Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{
|
||||||
@ -78,7 +78,7 @@ var (
|
|||||||
IgnoreRecordNotFoundError: false,
|
IgnoreRecordNotFoundError: false,
|
||||||
Colorful: true,
|
Colorful: true,
|
||||||
})
|
})
|
||||||
// Recorder Recorder logger records running SQL into a recorder instance
|
// Recorder logger records running SQL into a recorder instance
|
||||||
Recorder = traceRecorder{Interface: Default, BeginAt: time.Now()}
|
Recorder = traceRecorder{Interface: Default, BeginAt: time.Now()}
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -129,28 +129,30 @@ func (l *logger) LogMode(level LogLevel) Interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Info print info
|
// Info print info
|
||||||
func (l logger) Info(ctx context.Context, msg string, data ...interface{}) {
|
func (l *logger) Info(ctx context.Context, msg string, data ...interface{}) {
|
||||||
if l.LogLevel >= Info {
|
if l.LogLevel >= Info {
|
||||||
l.Printf(l.infoStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...)
|
l.Printf(l.infoStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Warn print warn messages
|
// Warn print warn messages
|
||||||
func (l logger) Warn(ctx context.Context, msg string, data ...interface{}) {
|
func (l *logger) Warn(ctx context.Context, msg string, data ...interface{}) {
|
||||||
if l.LogLevel >= Warn {
|
if l.LogLevel >= Warn {
|
||||||
l.Printf(l.warnStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...)
|
l.Printf(l.warnStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Error print error messages
|
// Error print error messages
|
||||||
func (l logger) Error(ctx context.Context, msg string, data ...interface{}) {
|
func (l *logger) Error(ctx context.Context, msg string, data ...interface{}) {
|
||||||
if l.LogLevel >= Error {
|
if l.LogLevel >= Error {
|
||||||
l.Printf(l.errStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...)
|
l.Printf(l.errStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Trace print sql message
|
// Trace print sql message
|
||||||
func (l logger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) {
|
//
|
||||||
|
//nolint:cyclop
|
||||||
|
func (l *logger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) {
|
||||||
if l.LogLevel <= Silent {
|
if l.LogLevel <= Silent {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -182,8 +184,8 @@ func (l logger) Trace(ctx context.Context, begin time.Time, fc func() (string, i
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Trace print sql message
|
// ParamsFilter filter params
|
||||||
func (l logger) ParamsFilter(ctx context.Context, sql string, params ...interface{}) (string, []interface{}) {
|
func (l *logger) ParamsFilter(ctx context.Context, sql string, params ...interface{}) (string, []interface{}) {
|
||||||
if l.Config.ParameterizedQueries {
|
if l.Config.ParameterizedQueries {
|
||||||
return sql, nil
|
return sql, nil
|
||||||
}
|
}
|
||||||
@ -198,8 +200,8 @@ type traceRecorder struct {
|
|||||||
Err error
|
Err error
|
||||||
}
|
}
|
||||||
|
|
||||||
// New new trace recorder
|
// New trace recorder
|
||||||
func (l traceRecorder) New() *traceRecorder {
|
func (l *traceRecorder) New() *traceRecorder {
|
||||||
return &traceRecorder{Interface: l.Interface, BeginAt: time.Now()}
|
return &traceRecorder{Interface: l.Interface, BeginAt: time.Now()}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -28,8 +28,10 @@ func isPrintable(s string) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// A list of Go types that should be converted to SQL primitives
|
||||||
var convertibleTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeOf(false), reflect.TypeOf([]byte{})}
|
var convertibleTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeOf(false), reflect.TypeOf([]byte{})}
|
||||||
|
|
||||||
|
// RegEx matches only numeric values
|
||||||
var numericPlaceholderRe = regexp.MustCompile(`\$\d+\$`)
|
var numericPlaceholderRe = regexp.MustCompile(`\$\d+\$`)
|
||||||
|
|
||||||
// ExplainSQL generate SQL string with given parameters, the generated SQL is expected to be used in logger, execute it might introduce a SQL injection vulnerability
|
// ExplainSQL generate SQL string with given parameters, the generated SQL is expected to be used in logger, execute it might introduce a SQL injection vulnerability
|
||||||
@ -93,8 +95,10 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a
|
|||||||
}
|
}
|
||||||
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
|
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
|
||||||
vars[idx] = utils.ToString(v)
|
vars[idx] = utils.ToString(v)
|
||||||
case float64, float32:
|
case float32:
|
||||||
vars[idx] = fmt.Sprintf("%.6f", v)
|
vars[idx] = strconv.FormatFloat(float64(v), 'f', -1, 32)
|
||||||
|
case float64:
|
||||||
|
vars[idx] = strconv.FormatFloat(v, 'f', -1, 64)
|
||||||
case string:
|
case string:
|
||||||
vars[idx] = escaper + strings.ReplaceAll(v, escaper, "\\"+escaper) + escaper
|
vars[idx] = escaper + strings.ReplaceAll(v, escaper, "\\"+escaper) + escaper
|
||||||
default:
|
default:
|
||||||
|
@ -57,43 +57,55 @@ func TestExplainSQL(t *testing.T) {
|
|||||||
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||||
NumericRegexp: nil,
|
NumericRegexp: nil,
|
||||||
Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd},
|
Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd},
|
||||||
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass")`,
|
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass")`,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||||
NumericRegexp: nil,
|
NumericRegexp: nil,
|
||||||
Vars: []interface{}{"jinzhu?", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd},
|
Vars: []interface{}{"jinzhu?", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd},
|
||||||
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu?", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass")`,
|
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu?", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass")`,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (@p1, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9, @p10, @p11)",
|
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (@p1, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9, @p10, @p11)",
|
||||||
NumericRegexp: regexp.MustCompile(`@p(\d+)`),
|
NumericRegexp: regexp.MustCompile(`@p(\d+)`),
|
||||||
Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.com", myrole, pwd},
|
Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.com", myrole, pwd},
|
||||||
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`,
|
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ($3, $4, $1, $2, $7, $8, $5, $6, $9, $10, $11)",
|
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ($3, $4, $1, $2, $7, $8, $5, $6, $9, $10, $11)",
|
||||||
NumericRegexp: regexp.MustCompile(`\$(\d+)`),
|
NumericRegexp: regexp.MustCompile(`\$(\d+)`),
|
||||||
Vars: []interface{}{999.99, true, "jinzhu", 1, &tt, nil, []byte("12345"), tt, "w@g.com", myrole, pwd},
|
Vars: []interface{}{999.99, true, "jinzhu", 1, &tt, nil, []byte("12345"), tt, "w@g.com", myrole, pwd},
|
||||||
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`,
|
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (@p1, @p11, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9, @p10)",
|
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (@p1, @p11, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9, @p10)",
|
||||||
NumericRegexp: regexp.MustCompile(`@p(\d+)`),
|
NumericRegexp: regexp.MustCompile(`@p(\d+)`),
|
||||||
Vars: []interface{}{"jinzhu", 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.com", myrole, pwd, 1},
|
Vars: []interface{}{"jinzhu", 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.com", myrole, pwd, 1},
|
||||||
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`,
|
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||||
NumericRegexp: nil,
|
NumericRegexp: nil,
|
||||||
Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, js, es},
|
Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, js, es},
|
||||||
Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
|
Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||||
NumericRegexp: nil,
|
NumericRegexp: nil,
|
||||||
Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es},
|
Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es},
|
||||||
Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
|
Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||||
|
NumericRegexp: nil,
|
||||||
|
Vars: []interface{}{"jinzhu", 1, 0.1753607109, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es},
|
||||||
|
Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 0.1753607109, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||||
|
NumericRegexp: nil,
|
||||||
|
Vars: []interface{}{"jinzhu", 1, float32(999.99), true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es},
|
||||||
|
Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -87,6 +87,8 @@ type Migrator interface {
|
|||||||
DropColumn(dst interface{}, field string) error
|
DropColumn(dst interface{}, field string) error
|
||||||
AlterColumn(dst interface{}, field string) error
|
AlterColumn(dst interface{}, field string) error
|
||||||
MigrateColumn(dst interface{}, field *schema.Field, columnType ColumnType) error
|
MigrateColumn(dst interface{}, field *schema.Field, columnType ColumnType) error
|
||||||
|
// MigrateColumnUnique migrate column's UNIQUE constraint, it's part of MigrateColumn.
|
||||||
|
MigrateColumnUnique(dst interface{}, field *schema.Field, columnType ColumnType) error
|
||||||
HasColumn(dst interface{}, field string) bool
|
HasColumn(dst interface{}, field string) bool
|
||||||
RenameColumn(dst interface{}, oldName, field string) error
|
RenameColumn(dst interface{}, oldName, field string) error
|
||||||
ColumnTypes(dst interface{}) ([]ColumnType, error)
|
ColumnTypes(dst interface{}) ([]ColumnType, error)
|
||||||
|
@ -16,8 +16,19 @@ import (
|
|||||||
"gorm.io/gorm/schema"
|
"gorm.io/gorm/schema"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// This regular expression seeks to find a sequence of digits (\d+) among zero or more non-digit characters (\D*),
|
||||||
|
// with a possible trailing non-digit character (\D?).
|
||||||
|
|
||||||
|
// For example, values that can pass this regular expression are:
|
||||||
|
// - "123"
|
||||||
|
// - "abc456"
|
||||||
|
// -"%$#@789"
|
||||||
var regFullDataType = regexp.MustCompile(`\D*(\d+)\D?`)
|
var regFullDataType = regexp.MustCompile(`\D*(\d+)\D?`)
|
||||||
|
|
||||||
|
// TODO:? Create const vars for raw sql queries ?
|
||||||
|
|
||||||
|
var _ gorm.Migrator = (*Migrator)(nil)
|
||||||
|
|
||||||
// Migrator m struct
|
// Migrator m struct
|
||||||
type Migrator struct {
|
type Migrator struct {
|
||||||
Config
|
Config
|
||||||
@ -208,7 +219,7 @@ func (m Migrator) CreateTable(values ...interface{}) error {
|
|||||||
field := stmt.Schema.FieldsByDBName[dbName]
|
field := stmt.Schema.FieldsByDBName[dbName]
|
||||||
if !field.IgnoreMigration {
|
if !field.IgnoreMigration {
|
||||||
createTableSQL += "? ?"
|
createTableSQL += "? ?"
|
||||||
hasPrimaryKeyInDataType = hasPrimaryKeyInDataType || strings.Contains(strings.ToUpper(string(field.DataType)), "PRIMARY KEY")
|
hasPrimaryKeyInDataType = hasPrimaryKeyInDataType || strings.Contains(strings.ToUpper(m.DataTypeOf(field)), "PRIMARY KEY")
|
||||||
values = append(values, clause.Column{Name: dbName}, m.DB.Migrator().FullDataTypeOf(field))
|
values = append(values, clause.Column{Name: dbName}, m.DB.Migrator().FullDataTypeOf(field))
|
||||||
createTableSQL += ","
|
createTableSQL += ","
|
||||||
}
|
}
|
||||||
@ -530,6 +541,26 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m Migrator) MigrateColumnUnique(value interface{}, field *schema.Field, columnType gorm.ColumnType) error {
|
||||||
|
unique, ok := columnType.Unique()
|
||||||
|
if !ok || field.PrimaryKey {
|
||||||
|
return nil // skip primary key
|
||||||
|
}
|
||||||
|
// By default, ColumnType's Unique is not affected by UniqueIndex, so we don't care about UniqueIndex.
|
||||||
|
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||||
|
// We're currently only receiving boolean values on `Unique` tag,
|
||||||
|
// so the UniqueConstraint name is fixed
|
||||||
|
constraint := m.DB.NamingStrategy.UniqueName(stmt.Table, field.DBName)
|
||||||
|
if unique && !field.Unique {
|
||||||
|
return m.DB.Migrator().DropConstraint(value, constraint)
|
||||||
|
}
|
||||||
|
if !unique && field.Unique {
|
||||||
|
return m.DB.Migrator().CreateConstraint(value, constraint)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// ColumnTypes return columnTypes []gorm.ColumnType and execErr error
|
// ColumnTypes return columnTypes []gorm.ColumnType and execErr error
|
||||||
func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
|
func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
|
||||||
columnTypes := make([]gorm.ColumnType, 0)
|
columnTypes := make([]gorm.ColumnType, 0)
|
||||||
|
@ -31,14 +31,14 @@ func NewPreparedStmtDB(connPool ConnPool) *PreparedStmtDB {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) {
|
func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) {
|
||||||
if dbConnector, ok := db.ConnPool.(GetDBConnector); ok && dbConnector != nil {
|
|
||||||
return dbConnector.GetDBConn()
|
|
||||||
}
|
|
||||||
|
|
||||||
if sqldb, ok := db.ConnPool.(*sql.DB); ok {
|
if sqldb, ok := db.ConnPool.(*sql.DB); ok {
|
||||||
return sqldb, nil
|
return sqldb, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if dbConnector, ok := db.ConnPool.(GetDBConnector); ok && dbConnector != nil {
|
||||||
|
return dbConnector.GetDBConn()
|
||||||
|
}
|
||||||
|
|
||||||
return nil, ErrInvalidDB
|
return nil, ErrInvalidDB
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -54,15 +54,15 @@ func (db *PreparedStmtDB) Close() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *PreparedStmtDB) Reset() {
|
func (sdb *PreparedStmtDB) Reset() {
|
||||||
db.Mux.Lock()
|
sdb.Mux.Lock()
|
||||||
defer db.Mux.Unlock()
|
defer sdb.Mux.Unlock()
|
||||||
|
|
||||||
for _, stmt := range db.Stmts {
|
for _, stmt := range sdb.Stmts {
|
||||||
go stmt.Close()
|
go stmt.Close()
|
||||||
}
|
}
|
||||||
db.PreparedSQL = make([]string, 0, 100)
|
sdb.PreparedSQL = make([]string, 0, 100)
|
||||||
db.Stmts = make(map[string]*Stmt)
|
sdb.Stmts = make(map[string]*Stmt)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) {
|
func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) {
|
||||||
@ -127,6 +127,19 @@ func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (Conn
|
|||||||
tx, err := beginner.BeginTx(ctx, opt)
|
tx, err := beginner.BeginTx(ctx, opt)
|
||||||
return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err
|
return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
beginner, ok := db.ConnPool.(ConnPoolBeginner)
|
||||||
|
if !ok {
|
||||||
|
return nil, ErrInvalidTransaction
|
||||||
|
}
|
||||||
|
|
||||||
|
connPool, err := beginner.BeginTx(ctx, opt)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if tx, ok := connPool.(Tx); ok {
|
||||||
|
return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, nil
|
||||||
|
}
|
||||||
return nil, ErrInvalidTransaction
|
return nil, ErrInvalidTransaction
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -172,6 +185,10 @@ type PreparedStmtTX struct {
|
|||||||
PreparedStmtDB *PreparedStmtDB
|
PreparedStmtDB *PreparedStmtDB
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (db *PreparedStmtTX) GetDBConn() (*sql.DB, error) {
|
||||||
|
return db.PreparedStmtDB.GetDBConn()
|
||||||
|
}
|
||||||
|
|
||||||
func (tx *PreparedStmtTX) Commit() error {
|
func (tx *PreparedStmtTX) Commit() error {
|
||||||
if tx.Tx != nil && !reflect.ValueOf(tx.Tx).IsNil() {
|
if tx.Tx != nil && !reflect.ValueOf(tx.Tx).IsNil() {
|
||||||
return tx.Tx.Commit()
|
return tx.Tx.Commit()
|
||||||
|
@ -49,6 +49,8 @@ const (
|
|||||||
Bytes DataType = "bytes"
|
Bytes DataType = "bytes"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const DefaultAutoIncrementIncrement int64 = 1
|
||||||
|
|
||||||
// Field is the representation of model schema's field
|
// Field is the representation of model schema's field
|
||||||
type Field struct {
|
type Field struct {
|
||||||
Name string
|
Name string
|
||||||
@ -119,7 +121,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
|||||||
NotNull: utils.CheckTruth(tagSetting["NOT NULL"], tagSetting["NOTNULL"]),
|
NotNull: utils.CheckTruth(tagSetting["NOT NULL"], tagSetting["NOTNULL"]),
|
||||||
Unique: utils.CheckTruth(tagSetting["UNIQUE"]),
|
Unique: utils.CheckTruth(tagSetting["UNIQUE"]),
|
||||||
Comment: tagSetting["COMMENT"],
|
Comment: tagSetting["COMMENT"],
|
||||||
AutoIncrementIncrement: 1,
|
AutoIncrementIncrement: DefaultAutoIncrementIncrement,
|
||||||
}
|
}
|
||||||
|
|
||||||
for field.IndirectFieldType.Kind() == reflect.Ptr {
|
for field.IndirectFieldType.Kind() == reflect.Ptr {
|
||||||
|
@ -19,6 +19,7 @@ type Namer interface {
|
|||||||
RelationshipFKName(Relationship) string
|
RelationshipFKName(Relationship) string
|
||||||
CheckerName(table, column string) string
|
CheckerName(table, column string) string
|
||||||
IndexName(table, column string) string
|
IndexName(table, column string) string
|
||||||
|
UniqueName(table, column string) string
|
||||||
}
|
}
|
||||||
|
|
||||||
// Replacer replacer interface like strings.Replacer
|
// Replacer replacer interface like strings.Replacer
|
||||||
@ -26,6 +27,8 @@ type Replacer interface {
|
|||||||
Replace(name string) string
|
Replace(name string) string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var _ Namer = (*NamingStrategy)(nil)
|
||||||
|
|
||||||
// NamingStrategy tables, columns naming strategy
|
// NamingStrategy tables, columns naming strategy
|
||||||
type NamingStrategy struct {
|
type NamingStrategy struct {
|
||||||
TablePrefix string
|
TablePrefix string
|
||||||
@ -85,6 +88,11 @@ func (ns NamingStrategy) IndexName(table, column string) string {
|
|||||||
return ns.formatName("idx", table, ns.toDBName(column))
|
return ns.formatName("idx", table, ns.toDBName(column))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UniqueName generate unique constraint name
|
||||||
|
func (ns NamingStrategy) UniqueName(table, column string) string {
|
||||||
|
return ns.formatName("uni", table, ns.toDBName(column))
|
||||||
|
}
|
||||||
|
|
||||||
func (ns NamingStrategy) formatName(prefix, table, name string) string {
|
func (ns NamingStrategy) formatName(prefix, table, name string) string {
|
||||||
formattedName := strings.ReplaceAll(strings.Join([]string{
|
formattedName := strings.ReplaceAll(strings.Join([]string{
|
||||||
prefix, table, name,
|
prefix, table, name,
|
||||||
|
@ -13,6 +13,20 @@ import (
|
|||||||
"gorm.io/gorm/logger"
|
"gorm.io/gorm/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type callbackType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
callbackTypeBeforeCreate callbackType = "BeforeCreate"
|
||||||
|
callbackTypeBeforeUpdate callbackType = "BeforeUpdate"
|
||||||
|
callbackTypeAfterCreate callbackType = "AfterCreate"
|
||||||
|
callbackTypeAfterUpdate callbackType = "AfterUpdate"
|
||||||
|
callbackTypeBeforeSave callbackType = "BeforeSave"
|
||||||
|
callbackTypeAfterSave callbackType = "AfterSave"
|
||||||
|
callbackTypeBeforeDelete callbackType = "BeforeDelete"
|
||||||
|
callbackTypeAfterDelete callbackType = "AfterDelete"
|
||||||
|
callbackTypeAfterFind callbackType = "AfterFind"
|
||||||
|
)
|
||||||
|
|
||||||
// ErrUnsupportedDataType unsupported data type
|
// ErrUnsupportedDataType unsupported data type
|
||||||
var ErrUnsupportedDataType = errors.New("unsupported data type")
|
var ErrUnsupportedDataType = errors.New("unsupported data type")
|
||||||
|
|
||||||
@ -288,14 +302,20 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
callbacks := []string{"BeforeCreate", "AfterCreate", "BeforeUpdate", "AfterUpdate", "BeforeSave", "AfterSave", "BeforeDelete", "AfterDelete", "AfterFind"}
|
callbackTypes := []callbackType{
|
||||||
for _, name := range callbacks {
|
callbackTypeBeforeCreate, callbackTypeAfterCreate,
|
||||||
if methodValue := modelValue.MethodByName(name); methodValue.IsValid() {
|
callbackTypeBeforeUpdate, callbackTypeAfterUpdate,
|
||||||
|
callbackTypeBeforeSave, callbackTypeAfterSave,
|
||||||
|
callbackTypeBeforeDelete, callbackTypeAfterDelete,
|
||||||
|
callbackTypeAfterFind,
|
||||||
|
}
|
||||||
|
for _, cbName := range callbackTypes {
|
||||||
|
if methodValue := callBackToMethodValue(modelValue, cbName); methodValue.IsValid() {
|
||||||
switch methodValue.Type().String() {
|
switch methodValue.Type().String() {
|
||||||
case "func(*gorm.DB) error": // TODO hack
|
case "func(*gorm.DB) error": // TODO hack
|
||||||
reflect.Indirect(reflect.ValueOf(schema)).FieldByName(name).SetBool(true)
|
reflect.Indirect(reflect.ValueOf(schema)).FieldByName(string(cbName)).SetBool(true)
|
||||||
default:
|
default:
|
||||||
logger.Default.Warn(context.Background(), "Model %v don't match %vInterface, should be `%v(*gorm.DB) error`. Please see https://gorm.io/docs/hooks.html", schema, name, name)
|
logger.Default.Warn(context.Background(), "Model %v don't match %vInterface, should be `%v(*gorm.DB) error`. Please see https://gorm.io/docs/hooks.html", schema, cbName, cbName)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -349,6 +369,39 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
|
|||||||
return schema, schema.err
|
return schema, schema.err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// This unrolling is needed to show to the compiler the exact set of methods
|
||||||
|
// that can be used on the modelType.
|
||||||
|
// Prior to go1.22 any use of MethodByName would cause the linker to
|
||||||
|
// abandon dead code elimination for the entire binary.
|
||||||
|
// As of go1.22 the compiler supports one special case of a string constant
|
||||||
|
// being passed to MethodByName. For enterprise customers or those building
|
||||||
|
// large binaries, this gives a significant reduction in binary size.
|
||||||
|
// https://github.com/golang/go/issues/62257
|
||||||
|
func callBackToMethodValue(modelType reflect.Value, cbType callbackType) reflect.Value {
|
||||||
|
switch cbType {
|
||||||
|
case callbackTypeBeforeCreate:
|
||||||
|
return modelType.MethodByName(string(callbackTypeBeforeCreate))
|
||||||
|
case callbackTypeAfterCreate:
|
||||||
|
return modelType.MethodByName(string(callbackTypeAfterCreate))
|
||||||
|
case callbackTypeBeforeUpdate:
|
||||||
|
return modelType.MethodByName(string(callbackTypeBeforeUpdate))
|
||||||
|
case callbackTypeAfterUpdate:
|
||||||
|
return modelType.MethodByName(string(callbackTypeAfterUpdate))
|
||||||
|
case callbackTypeBeforeSave:
|
||||||
|
return modelType.MethodByName(string(callbackTypeBeforeSave))
|
||||||
|
case callbackTypeAfterSave:
|
||||||
|
return modelType.MethodByName(string(callbackTypeAfterSave))
|
||||||
|
case callbackTypeBeforeDelete:
|
||||||
|
return modelType.MethodByName(string(callbackTypeBeforeDelete))
|
||||||
|
case callbackTypeAfterDelete:
|
||||||
|
return modelType.MethodByName(string(callbackTypeAfterDelete))
|
||||||
|
case callbackTypeAfterFind:
|
||||||
|
return modelType.MethodByName(string(callbackTypeAfterFind))
|
||||||
|
default:
|
||||||
|
return reflect.ValueOf(nil)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func getOrParse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) {
|
func getOrParse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) {
|
||||||
modelType := reflect.ValueOf(dest).Type()
|
modelType := reflect.ValueOf(dest).Type()
|
||||||
for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr {
|
for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr {
|
||||||
|
@ -115,6 +115,11 @@ func GetIdentityFieldValuesMap(ctx context.Context, reflectValue reflect.Value,
|
|||||||
notZero, zero bool
|
notZero, zero bool
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if reflectValue.Kind() == reflect.Ptr ||
|
||||||
|
reflectValue.Kind() == reflect.Interface {
|
||||||
|
reflectValue = reflectValue.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
switch reflectValue.Kind() {
|
switch reflectValue.Kind() {
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
results = [][]interface{}{make([]interface{}, len(fields))}
|
results = [][]interface{}{make([]interface{}, len(fields))}
|
||||||
|
@ -278,8 +278,6 @@ func TestBelongsToAssociationUnscoped(t *testing.T) {
|
|||||||
t.Fatalf("failed to create items, got error: %v", err)
|
t.Fatalf("failed to create items, got error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
tx = tx.Debug()
|
|
||||||
|
|
||||||
// test replace
|
// test replace
|
||||||
if err := tx.Model(&item).Association("ItemParent").Unscoped().Replace(&ItemParent{
|
if err := tx.Model(&item).Association("ItemParent").Unscoped().Replace(&ItemParent{
|
||||||
Logo: "updated logo",
|
Logo: "updated logo",
|
||||||
|
@ -29,7 +29,7 @@ func TestCountWithGroup(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var count2 int64
|
var count2 int64
|
||||||
if err := DB.Debug().Model(&Company{}).Where("name in ?", []string{"company_count_group_b", "company_count_group_c"}).Group("name").Count(&count2).Error; err != nil {
|
if err := DB.Model(&Company{}).Where("name in ?", []string{"company_count_group_b", "company_count_group_c"}).Group("name").Count(&count2).Error; err != nil {
|
||||||
t.Errorf(fmt.Sprintf("Count should work, but got err %v", err))
|
t.Errorf(fmt.Sprintf("Count should work, but got err %v", err))
|
||||||
}
|
}
|
||||||
if count2 != 2 {
|
if count2 != 2 {
|
||||||
|
@ -2,6 +2,7 @@ package tests_test
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"regexp"
|
"regexp"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@ -648,3 +649,180 @@ func TestCreateOnConflictWithDefaultJSON(t *testing.T) {
|
|||||||
AssertEqual(t, err, nil)
|
AssertEqual(t, err, nil)
|
||||||
AssertEqual(t, v2.Params, datatypes.JSONMap{"foo": "new-bar"})
|
AssertEqual(t, v2.Params, datatypes.JSONMap{"foo": "new-bar"})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCreateFromMapWithoutPK(t *testing.T) {
|
||||||
|
if !isMysql() {
|
||||||
|
t.Skipf("This test case skipped, because of only supportting for mysql")
|
||||||
|
}
|
||||||
|
|
||||||
|
// case 1: one record, create from map[string]interface{}
|
||||||
|
mapValue1 := map[string]interface{}{"name": "create_from_map_with_schema1", "age": 1}
|
||||||
|
if err := DB.Model(&User{}).Create(mapValue1).Error; err != nil {
|
||||||
|
t.Fatalf("failed to create data from map, got error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := mapValue1["id"]; !ok {
|
||||||
|
t.Fatal("failed to create data from map with table, returning map has no primary key")
|
||||||
|
}
|
||||||
|
|
||||||
|
var result1 User
|
||||||
|
if err := DB.Where("name = ?", "create_from_map_with_schema1").First(&result1).Error; err != nil || result1.Age != 1 {
|
||||||
|
t.Fatalf("failed to create from map, got error %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var idVal int64
|
||||||
|
_, ok := mapValue1["id"].(uint)
|
||||||
|
if ok {
|
||||||
|
t.Skipf("This test case skipped, because the db supports returning")
|
||||||
|
}
|
||||||
|
|
||||||
|
idVal, ok = mapValue1["id"].(int64)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("ret result missing id")
|
||||||
|
}
|
||||||
|
|
||||||
|
if int64(result1.ID) != idVal {
|
||||||
|
t.Fatal("failed to create data from map with table, @id != id")
|
||||||
|
}
|
||||||
|
|
||||||
|
// case2: one record, create from *map[string]interface{}
|
||||||
|
mapValue2 := map[string]interface{}{"name": "create_from_map_with_schema2", "age": 1}
|
||||||
|
if err := DB.Model(&User{}).Create(&mapValue2).Error; err != nil {
|
||||||
|
t.Fatalf("failed to create data from map, got error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := mapValue2["id"]; !ok {
|
||||||
|
t.Fatal("failed to create data from map with table, returning map has no primary key")
|
||||||
|
}
|
||||||
|
|
||||||
|
var result2 User
|
||||||
|
if err := DB.Where("name = ?", "create_from_map_with_schema2").First(&result2).Error; err != nil || result2.Age != 1 {
|
||||||
|
t.Fatalf("failed to create from map, got error %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, ok = mapValue2["id"].(uint)
|
||||||
|
if ok {
|
||||||
|
t.Skipf("This test case skipped, because the db supports returning")
|
||||||
|
}
|
||||||
|
|
||||||
|
idVal, ok = mapValue2["id"].(int64)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("ret result missing id")
|
||||||
|
}
|
||||||
|
|
||||||
|
if int64(result2.ID) != idVal {
|
||||||
|
t.Fatal("failed to create data from map with table, @id != id")
|
||||||
|
}
|
||||||
|
|
||||||
|
// case 3: records
|
||||||
|
values := []map[string]interface{}{
|
||||||
|
{"name": "create_from_map_with_schema11", "age": 1}, {"name": "create_from_map_with_schema12", "age": 1},
|
||||||
|
}
|
||||||
|
|
||||||
|
beforeLen := len(values)
|
||||||
|
if err := DB.Model(&User{}).Create(&values).Error; err != nil {
|
||||||
|
t.Fatalf("failed to create data from map, got error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// mariadb with returning, values will be appended with id map
|
||||||
|
if len(values) == beforeLen*2 {
|
||||||
|
t.Skipf("This test case skipped, because the db supports returning")
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range values {
|
||||||
|
v, ok := values[i]["id"]
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("failed to create data from map with table, returning map has no primary key")
|
||||||
|
}
|
||||||
|
|
||||||
|
var result User
|
||||||
|
if err := DB.Where("name = ?", fmt.Sprintf("create_from_map_with_schema1%d", i+1)).First(&result).Error; err != nil || result.Age != 1 {
|
||||||
|
t.Fatalf("failed to create from map, got error %v", err)
|
||||||
|
}
|
||||||
|
if int64(result.ID) != v.(int64) {
|
||||||
|
t.Fatal("failed to create data from map with table, @id != id")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateFromMapWithTable(t *testing.T) {
|
||||||
|
if !isMysql() {
|
||||||
|
t.Skipf("This test case skipped, because of only supportting for mysql")
|
||||||
|
}
|
||||||
|
tableDB := DB.Table("`users`")
|
||||||
|
|
||||||
|
// case 1: create from map[string]interface{}
|
||||||
|
record := map[string]interface{}{"`name`": "create_from_map_with_table", "`age`": 18}
|
||||||
|
if err := tableDB.Create(record).Error; err != nil {
|
||||||
|
t.Fatalf("failed to create data from map with table, got error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := record["@id"]; !ok {
|
||||||
|
t.Fatal("failed to create data from map with table, returning map has no key '@id'")
|
||||||
|
}
|
||||||
|
|
||||||
|
var res map[string]interface{}
|
||||||
|
if err := tableDB.Select([]string{"id", "name", "age"}).Where("name = ?", "create_from_map_with_table").Find(&res).Error; err != nil || res["age"] != int64(18) {
|
||||||
|
t.Fatalf("failed to create from map, got error %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if int64(res["id"].(uint64)) != record["@id"] {
|
||||||
|
t.Fatal("failed to create data from map with table, @id != id")
|
||||||
|
}
|
||||||
|
|
||||||
|
// case 2: create from *map[string]interface{}
|
||||||
|
record1 := map[string]interface{}{"name": "create_from_map_with_table_1", "age": 18}
|
||||||
|
tableDB2 := DB.Table("users")
|
||||||
|
if err := tableDB2.Create(&record1).Error; err != nil {
|
||||||
|
t.Fatalf("failed to create data from map, got error: %v", err)
|
||||||
|
}
|
||||||
|
if _, ok := record1["@id"]; !ok {
|
||||||
|
t.Fatal("failed to create data from map with table, returning map has no key '@id'")
|
||||||
|
}
|
||||||
|
|
||||||
|
var res1 map[string]interface{}
|
||||||
|
if err := tableDB2.Select([]string{"id", "name", "age"}).Where("name = ?", "create_from_map_with_table_1").Find(&res1).Error; err != nil || res1["age"] != int64(18) {
|
||||||
|
t.Fatalf("failed to create from map, got error %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if int64(res1["id"].(uint64)) != record1["@id"] {
|
||||||
|
t.Fatal("failed to create data from map with table, @id != id")
|
||||||
|
}
|
||||||
|
|
||||||
|
// case 3: create from []map[string]interface{}
|
||||||
|
records := []map[string]interface{}{
|
||||||
|
{"name": "create_from_map_with_table_2", "age": 19},
|
||||||
|
{"name": "create_from_map_with_table_3", "age": 20},
|
||||||
|
}
|
||||||
|
|
||||||
|
tableDB = DB.Table("users")
|
||||||
|
if err := tableDB.Create(&records).Error; err != nil {
|
||||||
|
t.Fatalf("failed to create data from slice of map, got error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := records[0]["@id"]; !ok {
|
||||||
|
t.Fatal("failed to create data from map with table, returning map has no key '@id'")
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := records[1]["@id"]; !ok {
|
||||||
|
t.Fatal("failed to create data from map with table, returning map has no key '@id'")
|
||||||
|
}
|
||||||
|
|
||||||
|
var res2 map[string]interface{}
|
||||||
|
if err := tableDB.Select([]string{"id", "name", "age"}).Where("name = ?", "create_from_map_with_table_2").Find(&res2).Error; err != nil || res2["age"] != int64(19) {
|
||||||
|
t.Fatalf("failed to query data after create from slice of map, got error %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var res3 map[string]interface{}
|
||||||
|
if err := DB.Table("users").Select([]string{"id", "name", "age"}).Where("name = ?", "create_from_map_with_table_3").Find(&res3).Error; err != nil || res3["age"] != int64(20) {
|
||||||
|
t.Fatalf("failed to query data after create from slice of map, got error %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if int64(res2["id"].(uint64)) != records[0]["@id"] {
|
||||||
|
t.Fatal("failed to create data from map with table, @id != id")
|
||||||
|
}
|
||||||
|
|
||||||
|
if int64(res3["id"].(uint64)) != records[1]["@id"] {
|
||||||
|
t.Fatal("failed to create data from map with table, @id != id")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -206,9 +206,9 @@ func TestDeleteSliceWithAssociations(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// only sqlite, postgres support returning
|
// only sqlite, postgres, sqlserver support returning
|
||||||
func TestSoftDeleteReturning(t *testing.T) {
|
func TestSoftDeleteReturning(t *testing.T) {
|
||||||
if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" {
|
if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" && DB.Dialector.Name() != "sqlserver" {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -233,7 +233,7 @@ func TestSoftDeleteReturning(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestDeleteReturning(t *testing.T) {
|
func TestDeleteReturning(t *testing.T) {
|
||||||
if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" {
|
if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" && DB.Dialector.Name() != "sqlserver" {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -236,8 +236,15 @@ func TestEmbeddedScanValuer(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestEmbeddedRelations(t *testing.T) {
|
func TestEmbeddedRelations(t *testing.T) {
|
||||||
|
type EmbUser struct {
|
||||||
|
gorm.Model
|
||||||
|
Name string
|
||||||
|
Age uint
|
||||||
|
Languages []Language `gorm:"many2many:EmbUserSpeak;"`
|
||||||
|
}
|
||||||
|
|
||||||
type AdvancedUser struct {
|
type AdvancedUser struct {
|
||||||
User `gorm:"embedded"`
|
EmbUser `gorm:"embedded"`
|
||||||
Advanced bool
|
Advanced bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
30
tests/go.mod
30
tests/go.mod
@ -1,16 +1,32 @@
|
|||||||
module gorm.io/gorm/tests
|
module gorm.io/gorm/tests
|
||||||
|
|
||||||
go 1.16
|
go 1.18
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/google/uuid v1.3.0
|
github.com/google/uuid v1.4.0
|
||||||
github.com/jinzhu/now v1.1.5
|
github.com/jinzhu/now v1.1.5
|
||||||
github.com/lib/pq v1.10.9
|
github.com/lib/pq v1.10.9
|
||||||
gorm.io/driver/mysql v1.5.2-0.20230612053416-48b6526a21f0
|
gorm.io/driver/mysql v1.5.2
|
||||||
gorm.io/driver/postgres v1.5.3-0.20230607070428-18bc84b75196
|
gorm.io/driver/postgres v1.5.4
|
||||||
gorm.io/driver/sqlite v1.5.2
|
gorm.io/driver/sqlite v1.5.4
|
||||||
gorm.io/driver/sqlserver v1.5.2-0.20230613072041-6e2cde390b0a
|
gorm.io/driver/sqlserver v1.5.2
|
||||||
gorm.io/gorm v1.25.2-0.20230610234218-206613868439
|
gorm.io/gorm v1.25.5
|
||||||
|
)
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/go-sql-driver/mysql v1.7.1 // indirect
|
||||||
|
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect
|
||||||
|
github.com/golang-sql/sqlexp v0.1.0 // indirect
|
||||||
|
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||||
|
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
|
||||||
|
github.com/jackc/pgx/v5 v5.5.0 // indirect
|
||||||
|
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||||
|
github.com/mattn/go-sqlite3 v1.14.18 // indirect
|
||||||
|
github.com/microsoft/go-mssqldb v1.6.0 // indirect
|
||||||
|
golang.org/x/crypto v0.15.0 // indirect
|
||||||
|
golang.org/x/text v0.14.0 // indirect
|
||||||
)
|
)
|
||||||
|
|
||||||
replace gorm.io/gorm => ../
|
replace gorm.io/gorm => ../
|
||||||
|
|
||||||
|
replace github.com/jackc/pgx/v5 => github.com/jackc/pgx/v5 v5.4.3
|
||||||
|
@ -265,6 +265,10 @@ func isTiDB() bool {
|
|||||||
return os.Getenv("GORM_DIALECT") == "tidb"
|
return os.Getenv("GORM_DIALECT") == "tidb"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isMysql() bool {
|
||||||
|
return os.Getenv("GORM_DIALECT") == "mysql"
|
||||||
|
}
|
||||||
|
|
||||||
func db(unscoped bool) *gorm.DB {
|
func db(unscoped bool) *gorm.DB {
|
||||||
if unscoped {
|
if unscoped {
|
||||||
return DB.Unscoped()
|
return DB.Unscoped()
|
||||||
|
@ -862,6 +862,48 @@ func TestMigrateWithSpecialName(t *testing.T) {
|
|||||||
AssertEqual(t, true, DB.Migrator().HasTable("coupon_product_2"))
|
AssertEqual(t, true, DB.Migrator().HasTable("coupon_product_2"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// https://github.com/go-gorm/gorm/issues/4760
|
||||||
|
func TestMigrateAutoIncrement(t *testing.T) {
|
||||||
|
type AutoIncrementStruct struct {
|
||||||
|
ID int64 `gorm:"primarykey;autoIncrement"`
|
||||||
|
Field1 uint32 `gorm:"column:field1"`
|
||||||
|
Field2 float32 `gorm:"column:field2"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.AutoMigrate(&AutoIncrementStruct{}); err != nil {
|
||||||
|
t.Fatalf("AutoMigrate err: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
const ROWS = 10
|
||||||
|
for idx := 0; idx < ROWS; idx++ {
|
||||||
|
if err := DB.Create(&AutoIncrementStruct{}).Error; err != nil {
|
||||||
|
t.Fatalf("create auto_increment_struct fail, err: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
rows := make([]*AutoIncrementStruct, 0, ROWS)
|
||||||
|
if err := DB.Order("id ASC").Find(&rows).Error; err != nil {
|
||||||
|
t.Fatalf("find auto_increment_struct fail, err: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ids := make([]int64, 0, len(rows))
|
||||||
|
for _, row := range rows {
|
||||||
|
ids = append(ids, row.ID)
|
||||||
|
}
|
||||||
|
lastID := ids[len(ids)-1]
|
||||||
|
|
||||||
|
if err := DB.Where("id IN (?)", ids).Delete(&AutoIncrementStruct{}).Error; err != nil {
|
||||||
|
t.Fatalf("delete auto_increment_struct fail, err: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
newRow := &AutoIncrementStruct{}
|
||||||
|
if err := DB.Create(newRow).Error; err != nil {
|
||||||
|
t.Fatalf("create auto_increment_struct fail, err: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
AssertEqual(t, newRow.ID, lastID+1)
|
||||||
|
}
|
||||||
|
|
||||||
// https://github.com/go-gorm/gorm/issues/5320
|
// https://github.com/go-gorm/gorm/issues/5320
|
||||||
func TestPrimarykeyID(t *testing.T) {
|
func TestPrimarykeyID(t *testing.T) {
|
||||||
if DB.Dialector.Name() != "postgres" {
|
if DB.Dialector.Name() != "postgres" {
|
||||||
@ -1598,3 +1640,48 @@ func TestMigrateExistingBoolColumnPG(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestTableType(t *testing.T) {
|
||||||
|
// currently it is only supported for mysql driver
|
||||||
|
if !isMysql() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
const tblName = "cities"
|
||||||
|
const tblSchema = "gorm"
|
||||||
|
const tblType = "BASE TABLE"
|
||||||
|
const tblComment = "foobar comment"
|
||||||
|
|
||||||
|
type City struct {
|
||||||
|
gorm.Model
|
||||||
|
Name string `gorm:"unique"`
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Migrator().DropTable(&City{})
|
||||||
|
|
||||||
|
if err := DB.Set("gorm:table_options", fmt.Sprintf("ENGINE InnoDB COMMENT '%s'", tblComment)).AutoMigrate(&City{}); err != nil {
|
||||||
|
t.Fatalf("failed to migrate cities tables, got error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tableType, err := DB.Table("cities").Migrator().TableType(&City{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to get table type, got error %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if tableType.Schema() != tblSchema {
|
||||||
|
t.Fatalf("expected tblSchema to be %s but got %s", tblSchema, tableType.Schema())
|
||||||
|
}
|
||||||
|
|
||||||
|
if tableType.Name() != tblName {
|
||||||
|
t.Fatalf("expected table name to be %s but got %s", tblName, tableType.Name())
|
||||||
|
}
|
||||||
|
|
||||||
|
if tableType.Type() != tblType {
|
||||||
|
t.Fatalf("expected table type to be %s but got %s", tblType, tableType.Type())
|
||||||
|
}
|
||||||
|
|
||||||
|
comment, ok := tableType.Comment()
|
||||||
|
if !ok || comment != tblComment {
|
||||||
|
t.Fatalf("expected comment %s got %s", tblComment, comment)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -429,7 +429,6 @@ func TestEmbedPreload(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
DB = DB.Debug()
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
t.Run(test.name, func(t *testing.T) {
|
t.Run(test.name, func(t *testing.T) {
|
||||||
actual := Org{}
|
actual := Org{}
|
||||||
|
@ -2,6 +2,7 @@ package tests_test
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"database/sql/driver"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"regexp"
|
"regexp"
|
||||||
@ -658,6 +659,18 @@ func TestOrWithAllFields(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type Int64 int64
|
||||||
|
|
||||||
|
func (v Int64) Value() (driver.Value, error) {
|
||||||
|
return v - 1, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Int64) Scan(v interface{}) error {
|
||||||
|
y := v.(int64)
|
||||||
|
*f = Int64(y + 1)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func TestPluck(t *testing.T) {
|
func TestPluck(t *testing.T) {
|
||||||
users := []*User{
|
users := []*User{
|
||||||
GetUser("pluck-user1", Config{}),
|
GetUser("pluck-user1", Config{}),
|
||||||
@ -685,6 +698,11 @@ func TestPluck(t *testing.T) {
|
|||||||
t.Errorf("got error when pluck id: %v", err)
|
t.Errorf("got error when pluck id: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var ids2 []Int64
|
||||||
|
if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Pluck("id", &ids2).Error; err != nil {
|
||||||
|
t.Errorf("got error when pluck id: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
for idx, name := range names {
|
for idx, name := range names {
|
||||||
if name != users[idx].Name {
|
if name != users[idx].Name {
|
||||||
t.Errorf("Unexpected result on pluck name, got %+v", names)
|
t.Errorf("Unexpected result on pluck name, got %+v", names)
|
||||||
@ -697,6 +715,12 @@ func TestPluck(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for idx, id := range ids2 {
|
||||||
|
if int(id) != int(users[idx].ID+1) {
|
||||||
|
t.Errorf("Unexpected result on pluck id, got %+v", ids)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
var times []time.Time
|
var times []time.Time
|
||||||
if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Pluck("created_at", ×).Error; err != nil {
|
if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Pluck("created_at", ×).Error; err != nil {
|
||||||
t.Errorf("got error when pluck time: %v", err)
|
t.Errorf("got error when pluck time: %v", err)
|
||||||
|
@ -43,9 +43,6 @@ func init() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
RunMigrations()
|
RunMigrations()
|
||||||
if DB.Dialector.Name() == "sqlite" {
|
|
||||||
DB.Exec("PRAGMA foreign_keys = ON")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -89,7 +86,10 @@ func OpenTestConnection(cfg *gorm.Config) (db *gorm.DB, err error) {
|
|||||||
db, err = gorm.Open(mysql.Open(dbDSN), cfg)
|
db, err = gorm.Open(mysql.Open(dbDSN), cfg)
|
||||||
default:
|
default:
|
||||||
log.Println("testing sqlite3...")
|
log.Println("testing sqlite3...")
|
||||||
db, err = gorm.Open(sqlite.Open(filepath.Join(os.TempDir(), "gorm.db?_foreign_keys=on")), cfg)
|
db, err = gorm.Open(sqlite.Open(filepath.Join(os.TempDir(), "gorm.db")), cfg)
|
||||||
|
if err == nil {
|
||||||
|
db.Exec("PRAGMA foreign_keys = ON")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -765,9 +765,9 @@ func TestSaveWithPrimaryValue(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// only sqlite, postgres support returning
|
// only sqlite, postgres, sqlserver support returning
|
||||||
func TestUpdateReturning(t *testing.T) {
|
func TestUpdateReturning(t *testing.T) {
|
||||||
if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" {
|
if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" && DB.Dialector.Name() != "sqlserver" {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -838,7 +838,7 @@ func TestSaveWithHooks(t *testing.T) {
|
|||||||
saveTokenOwner := func(owner *TokenOwner) (*TokenOwner, error) {
|
saveTokenOwner := func(owner *TokenOwner) (*TokenOwner, error) {
|
||||||
var newOwner TokenOwner
|
var newOwner TokenOwner
|
||||||
if err := DB.Transaction(func(tx *gorm.DB) error {
|
if err := DB.Transaction(func(tx *gorm.DB) error {
|
||||||
if err := tx.Debug().Session(&gorm.Session{FullSaveAssociations: true}).Save(owner).Error; err != nil {
|
if err := tx.Session(&gorm.Session{FullSaveAssociations: true}).Save(owner).Error; err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := tx.Preload("Token").First(&newOwner, owner.ID).Error; err != nil {
|
if err := tx.Preload("Token").First(&newOwner, owner.ID).Error; err != nil {
|
||||||
|
@ -89,19 +89,28 @@ func Contains(elems []string, elem string) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func AssertEqual(src, dst interface{}) bool {
|
func AssertEqual(x, y interface{}) bool {
|
||||||
if !reflect.DeepEqual(src, dst) {
|
if reflect.DeepEqual(x, y) {
|
||||||
if valuer, ok := src.(driver.Valuer); ok {
|
return true
|
||||||
src, _ = valuer.Value()
|
|
||||||
}
|
|
||||||
|
|
||||||
if valuer, ok := dst.(driver.Valuer); ok {
|
|
||||||
dst, _ = valuer.Value()
|
|
||||||
}
|
|
||||||
|
|
||||||
return reflect.DeepEqual(src, dst)
|
|
||||||
}
|
}
|
||||||
return true
|
if x == nil || y == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
xval := reflect.ValueOf(x)
|
||||||
|
yval := reflect.ValueOf(y)
|
||||||
|
if xval.Kind() == reflect.Ptr && xval.IsNil() ||
|
||||||
|
yval.Kind() == reflect.Ptr && yval.IsNil() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if valuer, ok := x.(driver.Valuer); ok {
|
||||||
|
x, _ = valuer.Value()
|
||||||
|
}
|
||||||
|
if valuer, ok := y.(driver.Valuer); ok {
|
||||||
|
y, _ = valuer.Value()
|
||||||
|
}
|
||||||
|
return reflect.DeepEqual(x, y)
|
||||||
}
|
}
|
||||||
|
|
||||||
func ToString(value interface{}) string {
|
func ToString(value interface{}) string {
|
||||||
|
@ -98,6 +98,7 @@ func TestAssertEqual(t *testing.T) {
|
|||||||
{"error not equal", errors.New("1"), errors.New("2"), false},
|
{"error not equal", errors.New("1"), errors.New("2"), false},
|
||||||
{"driver.Valuer equal", ModifyAt{Time: now, Valid: true}, ModifyAt{Time: now, Valid: true}, true},
|
{"driver.Valuer equal", ModifyAt{Time: now, Valid: true}, ModifyAt{Time: now, Valid: true}, true},
|
||||||
{"driver.Valuer not equal", ModifyAt{Time: now, Valid: true}, ModifyAt{Time: now.Add(time.Second), Valid: true}, false},
|
{"driver.Valuer not equal", ModifyAt{Time: now, Valid: true}, ModifyAt{Time: now.Add(time.Second), Valid: true}, false},
|
||||||
|
{"driver.Valuer equal (ptr to nil ptr)", (*ModifyAt)(nil), &ModifyAt{}, false},
|
||||||
}
|
}
|
||||||
for _, test := range assertEqualTests {
|
for _, test := range assertEqualTests {
|
||||||
t.Run(test.name, func(t *testing.T) {
|
t.Run(test.name, func(t *testing.T) {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user