Merge branch 'master' into detached
This commit is contained in:
commit
31d2ea8ca4
2
.github/workflows/invalid_question.yml
vendored
2
.github/workflows/invalid_question.yml
vendored
@ -16,7 +16,7 @@ jobs:
|
||||
ACTIONS_STEP_DEBUG: true
|
||||
steps:
|
||||
- name: Close Stale Issues
|
||||
uses: actions/stale@v5
|
||||
uses: actions/stale@v8
|
||||
with:
|
||||
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
stale-issue-message: "This issue has been marked as invalid question, please give more information by following the `Question` template, if you believe there is a bug of GORM, please create a pull request that could reproduce the issue on [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground), the issue will be closed in 30 days if no further activity occurs. most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨"
|
||||
|
2
.github/workflows/labeler.yml
vendored
2
.github/workflows/labeler.yml
vendored
@ -11,7 +11,7 @@ jobs:
|
||||
name: Label issues and pull requests
|
||||
steps:
|
||||
- name: check out
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: labeler
|
||||
uses: jinzhu/super-labeler-action@develop
|
||||
|
2
.github/workflows/missing_playground.yml
vendored
2
.github/workflows/missing_playground.yml
vendored
@ -16,7 +16,7 @@ jobs:
|
||||
ACTIONS_STEP_DEBUG: true
|
||||
steps:
|
||||
- name: Close Stale Issues
|
||||
uses: actions/stale@v5
|
||||
uses: actions/stale@v8
|
||||
with:
|
||||
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
stale-issue-message: "The issue has been automatically marked as stale as it missing playground pull request link, which is important to help others understand your issue effectively and make sure the issue hasn't been fixed on latest master, checkout [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground) for details. it will be closed in 30 days if no further activity occurs. if you are asking question, please use the `Question` template, most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨"
|
||||
|
2
.github/workflows/reviewdog.yml
vendored
2
.github/workflows/reviewdog.yml
vendored
@ -6,7 +6,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Check out code into the Go module directory
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v4
|
||||
- name: golangci-lint
|
||||
uses: reviewdog/action-golangci-lint@v2
|
||||
|
||||
|
2
.github/workflows/stale.yml
vendored
2
.github/workflows/stale.yml
vendored
@ -16,7 +16,7 @@ jobs:
|
||||
ACTIONS_STEP_DEBUG: true
|
||||
steps:
|
||||
- name: Close Stale Issues
|
||||
uses: actions/stale@v5
|
||||
uses: actions/stale@v8
|
||||
with:
|
||||
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
stale-issue-message: "This issue has been automatically marked as stale because it has been open 360 days with no activity. Remove stale label or comment or this will be closed in 180 days"
|
||||
|
111
.github/workflows/tests.yml
vendored
111
.github/workflows/tests.yml
vendored
@ -16,21 +16,21 @@ jobs:
|
||||
sqlite:
|
||||
strategy:
|
||||
matrix:
|
||||
go: ['1.18', '1.17', '1.16']
|
||||
go: ['1.21', '1.20', '1.19']
|
||||
platform: [ubuntu-latest] # can not run in windows OS
|
||||
runs-on: ${{ matrix.platform }}
|
||||
|
||||
steps:
|
||||
- name: Set up Go 1.x
|
||||
uses: actions/setup-go@v3
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: ${{ matrix.go }}
|
||||
|
||||
- name: Check out code into the Go module directory
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: go mod package cache
|
||||
uses: actions/cache@v3
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/go/pkg/mod
|
||||
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
|
||||
@ -41,8 +41,8 @@ jobs:
|
||||
mysql:
|
||||
strategy:
|
||||
matrix:
|
||||
dbversion: ['mysql:latest', 'mysql:5.7', 'mariadb:latest']
|
||||
go: ['1.18', '1.17', '1.16']
|
||||
dbversion: ['mysql:latest', 'mysql:5.7']
|
||||
go: ['1.21', '1.20', '1.19']
|
||||
platform: [ubuntu-latest]
|
||||
runs-on: ${{ matrix.platform }}
|
||||
|
||||
@ -65,16 +65,15 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Set up Go 1.x
|
||||
uses: actions/setup-go@v3
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: ${{ matrix.go }}
|
||||
|
||||
- name: Check out code into the Go module directory
|
||||
uses: actions/checkout@v3
|
||||
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: go mod package cache
|
||||
uses: actions/cache@v3
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/go/pkg/mod
|
||||
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
|
||||
@ -82,11 +81,54 @@ jobs:
|
||||
- name: Tests
|
||||
run: GITHUB_ACTION=true GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True" ./tests/tests_all.sh
|
||||
|
||||
mariadb:
|
||||
strategy:
|
||||
matrix:
|
||||
dbversion: [ 'mariadb:latest' ]
|
||||
go: ['1.21', '1.20', '1.19']
|
||||
platform: [ ubuntu-latest ]
|
||||
runs-on: ${{ matrix.platform }}
|
||||
|
||||
services:
|
||||
mysql:
|
||||
image: ${{ matrix.dbversion }}
|
||||
env:
|
||||
MYSQL_DATABASE: gorm
|
||||
MYSQL_USER: gorm
|
||||
MYSQL_PASSWORD: gorm
|
||||
MYSQL_RANDOM_ROOT_PASSWORD: "yes"
|
||||
ports:
|
||||
- 9910:3306
|
||||
options: >-
|
||||
--health-cmd "mariadb-admin ping -ugorm -pgorm"
|
||||
--health-interval 10s
|
||||
--health-start-period 10s
|
||||
--health-timeout 5s
|
||||
--health-retries 10
|
||||
|
||||
steps:
|
||||
- name: Set up Go 1.x
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: ${{ matrix.go }}
|
||||
|
||||
- name: Check out code into the Go module directory
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: go mod package cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/go/pkg/mod
|
||||
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
|
||||
|
||||
- name: Tests
|
||||
run: GITHUB_ACTION=true GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True" ./tests/tests_all.sh
|
||||
|
||||
postgres:
|
||||
strategy:
|
||||
matrix:
|
||||
dbversion: ['postgres:latest', 'postgres:13', 'postgres:12', 'postgres:11', 'postgres:10']
|
||||
go: ['1.18', '1.17', '1.16']
|
||||
go: ['1.21', '1.20', '1.19']
|
||||
platform: [ubuntu-latest] # can not run in macOS and Windows
|
||||
runs-on: ${{ matrix.platform }}
|
||||
|
||||
@ -109,15 +151,15 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Set up Go 1.x
|
||||
uses: actions/setup-go@v3
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: ${{ matrix.go }}
|
||||
|
||||
- name: Check out code into the Go module directory
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: go mod package cache
|
||||
uses: actions/cache@v3
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/go/pkg/mod
|
||||
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
|
||||
@ -128,7 +170,7 @@ jobs:
|
||||
sqlserver:
|
||||
strategy:
|
||||
matrix:
|
||||
go: ['1.18', '1.17', '1.16']
|
||||
go: ['1.21', '1.20', '1.19']
|
||||
platform: [ubuntu-latest] # can not run test in macOS and windows
|
||||
runs-on: ${{ matrix.platform }}
|
||||
|
||||
@ -152,18 +194,51 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Set up Go 1.x
|
||||
uses: actions/setup-go@v3
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: ${{ matrix.go }}
|
||||
|
||||
- name: Check out code into the Go module directory
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: go mod package cache
|
||||
uses: actions/cache@v3
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/go/pkg/mod
|
||||
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
|
||||
|
||||
- name: Tests
|
||||
run: GITHUB_ACTION=true GORM_DIALECT=sqlserver GORM_DSN="sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm" ./tests/tests_all.sh
|
||||
|
||||
tidb:
|
||||
strategy:
|
||||
matrix:
|
||||
dbversion: [ 'v6.5.0' ]
|
||||
go: ['1.21', '1.20', '1.19']
|
||||
platform: [ ubuntu-latest ]
|
||||
runs-on: ${{ matrix.platform }}
|
||||
|
||||
steps:
|
||||
- name: Setup TiDB
|
||||
uses: Icemap/tidb-action@main
|
||||
with:
|
||||
port: 9940
|
||||
version: ${{matrix.dbversion}}
|
||||
|
||||
- name: Set up Go 1.x
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: ${{ matrix.go }}
|
||||
|
||||
- name: Check out code into the Go module directory
|
||||
uses: actions/checkout@v4
|
||||
|
||||
|
||||
- name: go mod package cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/go/pkg/mod
|
||||
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
|
||||
|
||||
- name: Tests
|
||||
run: GITHUB_ACTION=true GORM_DIALECT=tidb GORM_DSN="root:@tcp(localhost:9940)/test?charset=utf8&parseTime=True&loc=Local" ./tests/tests_all.sh
|
||||
|
3
.gitignore
vendored
3
.gitignore
vendored
@ -3,4 +3,5 @@ documents
|
||||
coverage.txt
|
||||
_book
|
||||
.idea
|
||||
vendor
|
||||
vendor
|
||||
.vscode
|
||||
|
@ -9,3 +9,12 @@ linters:
|
||||
- prealloc
|
||||
- unconvert
|
||||
- unparam
|
||||
- goimports
|
||||
- whitespace
|
||||
|
||||
linters-settings:
|
||||
whitespace:
|
||||
multi-func: true
|
||||
goimports:
|
||||
local-prefixes: gorm.io/gorm
|
||||
|
||||
|
11
README.md
11
README.md
@ -4,9 +4,6 @@ The fantastic ORM library for Golang, aims to be developer friendly.
|
||||
|
||||
[](https://goreportcard.com/report/github.com/go-gorm/gorm)
|
||||
[](https://github.com/go-gorm/gorm/actions)
|
||||
[](https://gitter.im/jinzhu/gorm?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)
|
||||
[](https://opencollective.com/gorm)
|
||||
[](https://opencollective.com/gorm)
|
||||
[](https://opensource.org/licenses/MIT)
|
||||
[](https://pkg.go.dev/gorm.io/gorm?tab=doc)
|
||||
|
||||
@ -30,14 +27,18 @@ The fantastic ORM library for Golang, aims to be developer friendly.
|
||||
## Getting Started
|
||||
|
||||
* GORM Guides [https://gorm.io](https://gorm.io)
|
||||
* GORM Gen [gorm/gen](https://github.com/go-gorm/gen#gormgen)
|
||||
* Gen Guides [https://gorm.io/gen/index.html](https://gorm.io/gen/index.html)
|
||||
|
||||
## Contributing
|
||||
|
||||
[You can help to deliver a better GORM, check out things you can do](https://gorm.io/contribute.html)
|
||||
|
||||
## Contributors
|
||||
|
||||
[Thank you](https://github.com/go-gorm/gorm/graphs/contributors) for contributing to the GORM framework!
|
||||
|
||||
## License
|
||||
|
||||
© Jinzhu, 2013~time.Now
|
||||
|
||||
Released under the [MIT License](https://github.com/go-gorm/gorm/blob/master/License)
|
||||
Released under the [MIT License](https://github.com/go-gorm/gorm/blob/master/LICENSE)
|
||||
|
@ -14,6 +14,7 @@ import (
|
||||
type Association struct {
|
||||
DB *DB
|
||||
Relationship *schema.Relationship
|
||||
Unscope bool
|
||||
Error error
|
||||
}
|
||||
|
||||
@ -40,6 +41,15 @@ func (db *DB) Association(column string) *Association {
|
||||
return association
|
||||
}
|
||||
|
||||
func (association *Association) Unscoped() *Association {
|
||||
return &Association{
|
||||
DB: association.DB,
|
||||
Relationship: association.Relationship,
|
||||
Error: association.Error,
|
||||
Unscope: true,
|
||||
}
|
||||
}
|
||||
|
||||
func (association *Association) Find(out interface{}, conds ...interface{}) error {
|
||||
if association.Error == nil {
|
||||
association.Error = association.buildCondition().Find(out, conds...).Error
|
||||
@ -64,14 +74,30 @@ func (association *Association) Append(values ...interface{}) error {
|
||||
|
||||
func (association *Association) Replace(values ...interface{}) error {
|
||||
if association.Error == nil {
|
||||
reflectValue := association.DB.Statement.ReflectValue
|
||||
rel := association.Relationship
|
||||
|
||||
var oldBelongsToExpr clause.Expression
|
||||
// we have to record the old BelongsTo value
|
||||
if association.Unscope && rel.Type == schema.BelongsTo {
|
||||
var foreignFields []*schema.Field
|
||||
for _, ref := range rel.References {
|
||||
if !ref.OwnPrimaryKey {
|
||||
foreignFields = append(foreignFields, ref.ForeignKey)
|
||||
}
|
||||
}
|
||||
if _, fvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, foreignFields); len(fvs) > 0 {
|
||||
column, values := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, fvs)
|
||||
oldBelongsToExpr = clause.IN{Column: column, Values: values}
|
||||
}
|
||||
}
|
||||
|
||||
// save associations
|
||||
if association.saveAssociation( /*clear*/ true, values...); association.Error != nil {
|
||||
return association.Error
|
||||
}
|
||||
|
||||
// set old associations's foreign key to null
|
||||
reflectValue := association.DB.Statement.ReflectValue
|
||||
rel := association.Relationship
|
||||
switch rel.Type {
|
||||
case schema.BelongsTo:
|
||||
if len(values) == 0 {
|
||||
@ -91,6 +117,9 @@ func (association *Association) Replace(values ...interface{}) error {
|
||||
|
||||
association.Error = association.DB.UpdateColumns(updateMap).Error
|
||||
}
|
||||
if association.Unscope && oldBelongsToExpr != nil {
|
||||
association.Error = association.DB.Model(nil).Where(oldBelongsToExpr).Delete(reflect.New(rel.FieldSchema.ModelType).Interface()).Error
|
||||
}
|
||||
case schema.HasOne, schema.HasMany:
|
||||
var (
|
||||
primaryFields []*schema.Field
|
||||
@ -119,7 +148,11 @@ func (association *Association) Replace(values ...interface{}) error {
|
||||
|
||||
if _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields); len(pvs) > 0 {
|
||||
column, values := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs)
|
||||
association.Error = tx.Where(clause.IN{Column: column, Values: values}).UpdateColumns(updateMap).Error
|
||||
if association.Unscope {
|
||||
association.Error = tx.Where(clause.IN{Column: column, Values: values}).Delete(modelValue).Error
|
||||
} else {
|
||||
association.Error = tx.Where(clause.IN{Column: column, Values: values}).UpdateColumns(updateMap).Error
|
||||
}
|
||||
}
|
||||
case schema.Many2Many:
|
||||
var (
|
||||
@ -184,7 +217,8 @@ func (association *Association) Delete(values ...interface{}) error {
|
||||
|
||||
switch rel.Type {
|
||||
case schema.BelongsTo:
|
||||
tx := association.DB.Model(reflect.New(rel.Schema.ModelType).Interface())
|
||||
associationDB := association.DB.Session(&Session{})
|
||||
tx := associationDB.Model(reflect.New(rel.Schema.ModelType).Interface())
|
||||
|
||||
_, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, rel.Schema.PrimaryFields)
|
||||
if pcolumn, pvalues := schema.ToQueryValues(rel.Schema.Table, rel.Schema.PrimaryFieldDBNames, pvs); len(pvalues) > 0 {
|
||||
@ -198,8 +232,21 @@ func (association *Association) Delete(values ...interface{}) error {
|
||||
conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
|
||||
|
||||
association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error
|
||||
if association.Unscope {
|
||||
var foreignFields []*schema.Field
|
||||
for _, ref := range rel.References {
|
||||
if !ref.OwnPrimaryKey {
|
||||
foreignFields = append(foreignFields, ref.ForeignKey)
|
||||
}
|
||||
}
|
||||
if _, fvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, foreignFields); len(fvs) > 0 {
|
||||
column, values := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, fvs)
|
||||
association.Error = associationDB.Model(nil).Where(clause.IN{Column: column, Values: values}).Delete(reflect.New(rel.FieldSchema.ModelType).Interface()).Error
|
||||
}
|
||||
}
|
||||
case schema.HasOne, schema.HasMany:
|
||||
tx := association.DB.Model(reflect.New(rel.FieldSchema.ModelType).Interface())
|
||||
model := reflect.New(rel.FieldSchema.ModelType).Interface()
|
||||
tx := association.DB.Model(model)
|
||||
|
||||
_, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields)
|
||||
if pcolumn, pvalues := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs); len(pvalues) > 0 {
|
||||
@ -212,7 +259,11 @@ func (association *Association) Delete(values ...interface{}) error {
|
||||
relColumn, relValues := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs)
|
||||
conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
|
||||
|
||||
association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error
|
||||
if association.Unscope {
|
||||
association.Error = tx.Clauses(conds...).Delete(model).Error
|
||||
} else {
|
||||
association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error
|
||||
}
|
||||
case schema.Many2Many:
|
||||
var (
|
||||
primaryFields, relPrimaryFields []*schema.Field
|
||||
@ -353,9 +404,13 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
|
||||
}
|
||||
case schema.HasMany, schema.Many2Many:
|
||||
elemType := association.Relationship.Field.IndirectFieldType.Elem()
|
||||
fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(association.DB.Statement.Context, source))
|
||||
oldFieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(association.DB.Statement.Context, source))
|
||||
var fieldValue reflect.Value
|
||||
if clear {
|
||||
fieldValue = reflect.New(association.Relationship.Field.IndirectFieldType).Elem()
|
||||
fieldValue = reflect.MakeSlice(oldFieldValue.Type(), 0, oldFieldValue.Cap())
|
||||
} else {
|
||||
fieldValue = reflect.MakeSlice(oldFieldValue.Type(), oldFieldValue.Len(), oldFieldValue.Cap())
|
||||
reflect.Copy(fieldValue, oldFieldValue)
|
||||
}
|
||||
|
||||
appendToFieldValues := func(ev reflect.Value) {
|
||||
@ -507,7 +562,9 @@ func (association *Association) buildCondition() *DB {
|
||||
joinStmt.AddClause(queryClause)
|
||||
}
|
||||
joinStmt.Build("WHERE")
|
||||
tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars})
|
||||
if len(joinStmt.SQL.String()) > 0 {
|
||||
tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars})
|
||||
}
|
||||
}
|
||||
|
||||
tx = tx.Session(&Session{QueryFields: true}).Clauses(clause.From{Joins: []clause.Join{{
|
||||
|
45
callbacks.go
45
callbacks.go
@ -75,11 +75,7 @@ func (cs *callbacks) Raw() *processor {
|
||||
func (p *processor) Execute(db *DB) *DB {
|
||||
// call scopes
|
||||
for len(db.Statement.scopes) > 0 {
|
||||
scopes := db.Statement.scopes
|
||||
db.Statement.scopes = nil
|
||||
for _, scope := range scopes {
|
||||
db = scope(db)
|
||||
}
|
||||
db = db.executeScopes()
|
||||
}
|
||||
|
||||
var (
|
||||
@ -93,6 +89,10 @@ func (p *processor) Execute(db *DB) *DB {
|
||||
resetBuildClauses = true
|
||||
}
|
||||
|
||||
if optimizer, ok := db.Statement.Dest.(StatementModifier); ok {
|
||||
optimizer.ModifyStatement(stmt)
|
||||
}
|
||||
|
||||
// assign model values
|
||||
if stmt.Model == nil {
|
||||
stmt.Model = stmt.Dest
|
||||
@ -132,7 +132,11 @@ func (p *processor) Execute(db *DB) *DB {
|
||||
|
||||
if stmt.SQL.Len() > 0 {
|
||||
db.Logger.Trace(stmt.Context, curTime, func() (string, int64) {
|
||||
return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...), db.RowsAffected
|
||||
sql, vars := stmt.SQL.String(), stmt.Vars
|
||||
if filter, ok := db.Logger.(ParamsFilter); ok {
|
||||
sql, vars = filter.ParamsFilter(stmt.Context, stmt.SQL.String(), stmt.Vars...)
|
||||
}
|
||||
return db.Dialector.Explain(sql, vars...), db.RowsAffected
|
||||
}, db.Error)
|
||||
}
|
||||
|
||||
@ -183,10 +187,18 @@ func (p *processor) Replace(name string, fn func(*DB)) error {
|
||||
|
||||
func (p *processor) compile() (err error) {
|
||||
var callbacks []*callback
|
||||
removedMap := map[string]bool{}
|
||||
for _, callback := range p.callbacks {
|
||||
if callback.match == nil || callback.match(p.db) {
|
||||
callbacks = append(callbacks, callback)
|
||||
}
|
||||
if callback.remove {
|
||||
removedMap[callback.name] = true
|
||||
}
|
||||
}
|
||||
|
||||
if len(removedMap) > 0 {
|
||||
callbacks = removeCallbacks(callbacks, removedMap)
|
||||
}
|
||||
p.callbacks = callbacks
|
||||
|
||||
@ -245,8 +257,14 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {
|
||||
names, sorted []string
|
||||
sortCallback func(*callback) error
|
||||
)
|
||||
sort.Slice(cs, func(i, j int) bool {
|
||||
return cs[j].before == "*" || cs[j].after == "*"
|
||||
sort.SliceStable(cs, func(i, j int) bool {
|
||||
if cs[j].before == "*" && cs[i].before != "*" {
|
||||
return true
|
||||
}
|
||||
if cs[j].after == "*" && cs[i].after != "*" {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
})
|
||||
|
||||
for _, c := range cs {
|
||||
@ -329,3 +347,14 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func removeCallbacks(cs []*callback, nameMap map[string]bool) []*callback {
|
||||
callbacks := make([]*callback, 0, len(cs))
|
||||
for _, callback := range cs {
|
||||
if nameMap[callback.name] {
|
||||
continue
|
||||
}
|
||||
callbacks = append(callbacks, callback)
|
||||
}
|
||||
return callbacks
|
||||
}
|
||||
|
@ -51,25 +51,40 @@ func SaveBeforeAssociations(create bool) func(db *gorm.DB) {
|
||||
}
|
||||
|
||||
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
|
||||
distinctElems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
|
||||
identityMap := map[string]bool{}
|
||||
for i := 0; i < rValLen; i++ {
|
||||
obj := db.Statement.ReflectValue.Index(i)
|
||||
if reflect.Indirect(obj).Kind() != reflect.Struct {
|
||||
break
|
||||
}
|
||||
|
||||
if _, zero := rel.Field.ValueOf(db.Statement.Context, obj); !zero { // check belongs to relation value
|
||||
rv := rel.Field.ReflectValueOf(db.Statement.Context, obj) // relation reflect value
|
||||
if !isPtr {
|
||||
rv = rv.Addr()
|
||||
}
|
||||
objs = append(objs, obj)
|
||||
if isPtr {
|
||||
elems = reflect.Append(elems, rv)
|
||||
} else {
|
||||
elems = reflect.Append(elems, rv.Addr())
|
||||
elems = reflect.Append(elems, rv)
|
||||
|
||||
relPrimaryValues := make([]interface{}, 0, len(rel.FieldSchema.PrimaryFields))
|
||||
for _, pf := range rel.FieldSchema.PrimaryFields {
|
||||
if pfv, ok := pf.ValueOf(db.Statement.Context, rv); !ok {
|
||||
relPrimaryValues = append(relPrimaryValues, pfv)
|
||||
}
|
||||
}
|
||||
cacheKey := utils.ToStringKey(relPrimaryValues...)
|
||||
if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] {
|
||||
if cacheKey != "" { // has primary fields
|
||||
identityMap[cacheKey] = true
|
||||
}
|
||||
|
||||
distinctElems = reflect.Append(distinctElems, rv)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if elems.Len() > 0 {
|
||||
if saveAssociations(db, rel, elems, selectColumns, restricted, nil) == nil {
|
||||
if saveAssociations(db, rel, distinctElems, selectColumns, restricted, nil) == nil {
|
||||
for i := 0; i < elems.Len(); i++ {
|
||||
setupReferences(objs[i], elems.Index(i))
|
||||
}
|
||||
@ -206,9 +221,12 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
|
||||
}
|
||||
}
|
||||
|
||||
cacheKey := utils.ToStringKey(relPrimaryValues)
|
||||
cacheKey := utils.ToStringKey(relPrimaryValues...)
|
||||
if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] {
|
||||
identityMap[cacheKey] = true
|
||||
if cacheKey != "" { // has primary fields
|
||||
identityMap[cacheKey] = true
|
||||
}
|
||||
|
||||
if isPtr {
|
||||
elems = reflect.Append(elems, elem)
|
||||
} else {
|
||||
@ -253,6 +271,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
|
||||
fieldType = reflect.PtrTo(fieldType)
|
||||
}
|
||||
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
|
||||
distinctElems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
|
||||
joins := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(rel.JoinTable.ModelType)), 0, 10)
|
||||
objs := []reflect.Value{}
|
||||
|
||||
@ -272,19 +291,34 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
|
||||
joins = reflect.Append(joins, joinValue)
|
||||
}
|
||||
|
||||
identityMap := map[string]bool{}
|
||||
appendToElems := func(v reflect.Value) {
|
||||
if _, zero := rel.Field.ValueOf(db.Statement.Context, v); !zero {
|
||||
f := reflect.Indirect(rel.Field.ReflectValueOf(db.Statement.Context, v))
|
||||
|
||||
for i := 0; i < f.Len(); i++ {
|
||||
elem := f.Index(i)
|
||||
|
||||
objs = append(objs, v)
|
||||
if isPtr {
|
||||
elems = reflect.Append(elems, elem)
|
||||
} else {
|
||||
elems = reflect.Append(elems, elem.Addr())
|
||||
if !isPtr {
|
||||
elem = elem.Addr()
|
||||
}
|
||||
objs = append(objs, v)
|
||||
elems = reflect.Append(elems, elem)
|
||||
|
||||
relPrimaryValues := make([]interface{}, 0, len(rel.FieldSchema.PrimaryFields))
|
||||
for _, pf := range rel.FieldSchema.PrimaryFields {
|
||||
if pfv, ok := pf.ValueOf(db.Statement.Context, elem); !ok {
|
||||
relPrimaryValues = append(relPrimaryValues, pfv)
|
||||
}
|
||||
}
|
||||
|
||||
cacheKey := utils.ToStringKey(relPrimaryValues...)
|
||||
if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] {
|
||||
if cacheKey != "" { // has primary fields
|
||||
identityMap[cacheKey] = true
|
||||
}
|
||||
|
||||
distinctElems = reflect.Append(distinctElems, elem)
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -304,7 +338,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
|
||||
// optimize elems of reflect value length
|
||||
if elemLen := elems.Len(); elemLen > 0 {
|
||||
if v, ok := selectColumns[rel.Name+".*"]; !ok || v {
|
||||
saveAssociations(db, rel, elems, selectColumns, restricted, nil)
|
||||
saveAssociations(db, rel, distinctElems, selectColumns, restricted, nil)
|
||||
}
|
||||
|
||||
for i := 0; i < elemLen; i++ {
|
||||
|
@ -13,11 +13,20 @@ func callMethod(db *gorm.DB, fc func(value interface{}, tx *gorm.DB) bool) {
|
||||
case reflect.Slice, reflect.Array:
|
||||
db.Statement.CurDestIndex = 0
|
||||
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
|
||||
fc(reflect.Indirect(db.Statement.ReflectValue.Index(i)).Addr().Interface(), tx)
|
||||
if value := reflect.Indirect(db.Statement.ReflectValue.Index(i)); value.CanAddr() {
|
||||
fc(value.Addr().Interface(), tx)
|
||||
} else {
|
||||
db.AddError(gorm.ErrInvalidValue)
|
||||
return
|
||||
}
|
||||
db.Statement.CurDestIndex++
|
||||
}
|
||||
case reflect.Struct:
|
||||
fc(db.Statement.ReflectValue.Addr().Interface(), tx)
|
||||
if db.Statement.ReflectValue.CanAddr() {
|
||||
fc(db.Statement.ReflectValue.Addr().Interface(), tx)
|
||||
} else {
|
||||
db.AddError(gorm.ErrInvalidValue)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -3,6 +3,7 @@ package callbacks
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
@ -102,13 +103,62 @@ func Create(config *Config) func(db *gorm.DB) {
|
||||
}
|
||||
|
||||
db.RowsAffected, _ = result.RowsAffected()
|
||||
if db.RowsAffected != 0 && db.Statement.Schema != nil &&
|
||||
db.Statement.Schema.PrioritizedPrimaryField != nil &&
|
||||
db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue {
|
||||
insertID, err := result.LastInsertId()
|
||||
insertOk := err == nil && insertID > 0
|
||||
if !insertOk {
|
||||
if db.RowsAffected == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
var (
|
||||
pkField *schema.Field
|
||||
pkFieldName = "@id"
|
||||
)
|
||||
|
||||
insertID, err := result.LastInsertId()
|
||||
insertOk := err == nil && insertID > 0
|
||||
|
||||
if !insertOk {
|
||||
if !supportReturning {
|
||||
db.AddError(err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if db.Statement.Schema != nil {
|
||||
if db.Statement.Schema.PrioritizedPrimaryField == nil || !db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue {
|
||||
return
|
||||
}
|
||||
pkField = db.Statement.Schema.PrioritizedPrimaryField
|
||||
pkFieldName = db.Statement.Schema.PrioritizedPrimaryField.DBName
|
||||
}
|
||||
|
||||
// append @id column with value for auto-increment primary key
|
||||
// the @id value is correct, when: 1. without setting auto-increment primary key, 2. database AutoIncrementIncrement = 1
|
||||
switch values := db.Statement.Dest.(type) {
|
||||
case map[string]interface{}:
|
||||
values[pkFieldName] = insertID
|
||||
case *map[string]interface{}:
|
||||
(*values)[pkFieldName] = insertID
|
||||
case []map[string]interface{}, *[]map[string]interface{}:
|
||||
mapValues, ok := values.([]map[string]interface{})
|
||||
if !ok {
|
||||
if v, ok := values.(*[]map[string]interface{}); ok {
|
||||
if *v != nil {
|
||||
mapValues = *v
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if config.LastInsertIDReversed {
|
||||
insertID -= int64(len(mapValues)-1) * schema.DefaultAutoIncrementIncrement
|
||||
}
|
||||
|
||||
for _, mapValue := range mapValues {
|
||||
if mapValue != nil {
|
||||
mapValue[pkFieldName] = insertID
|
||||
}
|
||||
insertID += schema.DefaultAutoIncrementIncrement
|
||||
}
|
||||
default:
|
||||
if pkField == nil {
|
||||
return
|
||||
}
|
||||
|
||||
@ -121,10 +171,10 @@ func Create(config *Config) func(db *gorm.DB) {
|
||||
break
|
||||
}
|
||||
|
||||
_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv)
|
||||
_, isZero := pkField.ValueOf(db.Statement.Context, rv)
|
||||
if isZero {
|
||||
db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, insertID))
|
||||
insertID -= db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement
|
||||
db.AddError(pkField.Set(db.Statement.Context, rv, insertID))
|
||||
insertID -= pkField.AutoIncrementIncrement
|
||||
}
|
||||
}
|
||||
} else {
|
||||
@ -134,16 +184,16 @@ func Create(config *Config) func(db *gorm.DB) {
|
||||
break
|
||||
}
|
||||
|
||||
if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv); isZero {
|
||||
db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, insertID))
|
||||
insertID += db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement
|
||||
if _, isZero := pkField.ValueOf(db.Statement.Context, rv); isZero {
|
||||
db.AddError(pkField.Set(db.Statement.Context, rv, insertID))
|
||||
insertID += pkField.AutoIncrementIncrement
|
||||
}
|
||||
}
|
||||
}
|
||||
case reflect.Struct:
|
||||
_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, db.Statement.ReflectValue)
|
||||
_, isZero := pkField.ValueOf(db.Statement.Context, db.Statement.ReflectValue)
|
||||
if isZero {
|
||||
db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, db.Statement.ReflectValue, insertID))
|
||||
db.AddError(pkField.Set(db.Statement.Context, db.Statement.ReflectValue, insertID))
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -252,13 +302,15 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
|
||||
}
|
||||
}
|
||||
|
||||
for field, vs := range defaultValueFieldsHavingValue {
|
||||
values.Columns = append(values.Columns, clause.Column{Name: field.DBName})
|
||||
for idx := range values.Values {
|
||||
if vs[idx] == nil {
|
||||
values.Values[idx] = append(values.Values[idx], stmt.Dialector.DefaultValueOf(field))
|
||||
} else {
|
||||
values.Values[idx] = append(values.Values[idx], vs[idx])
|
||||
for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
|
||||
if vs, ok := defaultValueFieldsHavingValue[field]; ok {
|
||||
values.Columns = append(values.Columns, clause.Column{Name: field.DBName})
|
||||
for idx := range values.Values {
|
||||
if vs[idx] == nil {
|
||||
values.Values[idx] = append(values.Values[idx], stmt.Dialector.DefaultValueOf(field))
|
||||
} else {
|
||||
values.Values[idx] = append(values.Values[idx], vs[idx])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -302,14 +354,15 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
|
||||
for _, column := range values.Columns {
|
||||
if field := stmt.Schema.LookUpField(column.Name); field != nil {
|
||||
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
|
||||
if !field.PrimaryKey && (!field.HasDefaultValue || field.DefaultValueInterface != nil) && field.AutoCreateTime == 0 {
|
||||
if !field.PrimaryKey && (!field.HasDefaultValue || field.DefaultValueInterface != nil ||
|
||||
strings.EqualFold(field.DefaultValue, "NULL")) && field.AutoCreateTime == 0 {
|
||||
if field.AutoUpdateTime > 0 {
|
||||
assignment := clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: curTime}
|
||||
switch field.AutoUpdateTime {
|
||||
case schema.UnixNanosecond:
|
||||
assignment.Value = curTime.UnixNano()
|
||||
case schema.UnixMillisecond:
|
||||
assignment.Value = curTime.UnixNano() / 1e6
|
||||
assignment.Value = curTime.UnixMilli()
|
||||
case schema.UnixSecond:
|
||||
assignment.Value = curTime.Unix()
|
||||
}
|
||||
|
71
callbacks/create_test.go
Normal file
71
callbacks/create_test.go
Normal file
@ -0,0 +1,71 @@
|
||||
package callbacks
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/schema"
|
||||
)
|
||||
|
||||
var schemaCache = &sync.Map{}
|
||||
|
||||
func TestConvertToCreateValues_DestType_Slice(t *testing.T) {
|
||||
type user struct {
|
||||
ID int `gorm:"primaryKey"`
|
||||
Name string
|
||||
Email string `gorm:"default:(-)"`
|
||||
Age int `gorm:"default:(-)"`
|
||||
}
|
||||
|
||||
s, err := schema.Parse(&user{}, schemaCache, schema.NamingStrategy{})
|
||||
if err != nil {
|
||||
t.Errorf("parse schema error: %v, is not expected", err)
|
||||
return
|
||||
}
|
||||
dest := []*user{
|
||||
{
|
||||
ID: 1,
|
||||
Name: "alice",
|
||||
Email: "email",
|
||||
Age: 18,
|
||||
},
|
||||
{
|
||||
ID: 2,
|
||||
Name: "bob",
|
||||
Email: "email",
|
||||
Age: 19,
|
||||
},
|
||||
}
|
||||
stmt := &gorm.Statement{
|
||||
DB: &gorm.DB{
|
||||
Config: &gorm.Config{
|
||||
NowFunc: func() time.Time { return time.Time{} },
|
||||
},
|
||||
Statement: &gorm.Statement{
|
||||
Settings: sync.Map{},
|
||||
Schema: s,
|
||||
},
|
||||
},
|
||||
ReflectValue: reflect.ValueOf(dest),
|
||||
Dest: dest,
|
||||
}
|
||||
|
||||
stmt.Schema = s
|
||||
|
||||
values := ConvertToCreateValues(stmt)
|
||||
expected := clause.Values{
|
||||
// column has value + defaultValue column has value (which should have a stable order)
|
||||
Columns: []clause.Column{{Name: "name"}, {Name: "email"}, {Name: "age"}, {Name: "id"}},
|
||||
Values: [][]interface{}{
|
||||
{"alice", "email", 18, 1},
|
||||
{"bob", "email", 19, 2},
|
||||
},
|
||||
}
|
||||
if !reflect.DeepEqual(expected, values) {
|
||||
t.Errorf("expected: %v got %v", expected, values)
|
||||
}
|
||||
}
|
157
callbacks/helper_test.go
Normal file
157
callbacks/helper_test.go
Normal file
@ -0,0 +1,157 @@
|
||||
package callbacks
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
func TestLoadOrStoreVisitMap(t *testing.T) {
|
||||
var vm visitMap
|
||||
var loaded bool
|
||||
type testM struct {
|
||||
Name string
|
||||
}
|
||||
|
||||
t1 := testM{Name: "t1"}
|
||||
t2 := testM{Name: "t2"}
|
||||
t3 := testM{Name: "t3"}
|
||||
|
||||
vm = make(visitMap)
|
||||
if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf(&t1)); loaded {
|
||||
t.Fatalf("loaded should be false")
|
||||
}
|
||||
|
||||
if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf(&t1)); !loaded {
|
||||
t.Fatalf("loaded should be true")
|
||||
}
|
||||
|
||||
// t1 already exist but t2 not
|
||||
if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf([]*testM{&t1, &t2, &t3})); loaded {
|
||||
t.Fatalf("loaded should be false")
|
||||
}
|
||||
|
||||
if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf([]*testM{&t2, &t3})); !loaded {
|
||||
t.Fatalf("loaded should be true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertMapToValuesForCreate(t *testing.T) {
|
||||
testCase := []struct {
|
||||
name string
|
||||
input map[string]interface{}
|
||||
expect clause.Values
|
||||
}{
|
||||
{
|
||||
name: "Test convert string value",
|
||||
input: map[string]interface{}{
|
||||
"name": "my name",
|
||||
},
|
||||
expect: clause.Values{
|
||||
Columns: []clause.Column{{Name: "name"}},
|
||||
Values: [][]interface{}{{"my name"}},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Test convert int value",
|
||||
input: map[string]interface{}{
|
||||
"age": 18,
|
||||
},
|
||||
expect: clause.Values{
|
||||
Columns: []clause.Column{{Name: "age"}},
|
||||
Values: [][]interface{}{{18}},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Test convert float value",
|
||||
input: map[string]interface{}{
|
||||
"score": 99.5,
|
||||
},
|
||||
expect: clause.Values{
|
||||
Columns: []clause.Column{{Name: "score"}},
|
||||
Values: [][]interface{}{{99.5}},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Test convert bool value",
|
||||
input: map[string]interface{}{
|
||||
"active": true,
|
||||
},
|
||||
expect: clause.Values{
|
||||
Columns: []clause.Column{{Name: "active"}},
|
||||
Values: [][]interface{}{{true}},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCase {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
actual := ConvertMapToValuesForCreate(&gorm.Statement{}, tc.input)
|
||||
if !reflect.DeepEqual(actual, tc.expect) {
|
||||
t.Errorf("expect %v got %v", tc.expect, actual)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertSliceOfMapToValuesForCreate(t *testing.T) {
|
||||
testCase := []struct {
|
||||
name string
|
||||
input []map[string]interface{}
|
||||
expect clause.Values
|
||||
}{
|
||||
{
|
||||
name: "Test convert slice of string value",
|
||||
input: []map[string]interface{}{
|
||||
{"name": "my name"},
|
||||
},
|
||||
expect: clause.Values{
|
||||
Columns: []clause.Column{{Name: "name"}},
|
||||
Values: [][]interface{}{{"my name"}},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Test convert slice of int value",
|
||||
input: []map[string]interface{}{
|
||||
{"age": 18},
|
||||
},
|
||||
expect: clause.Values{
|
||||
Columns: []clause.Column{{Name: "age"}},
|
||||
Values: [][]interface{}{{18}},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Test convert slice of float value",
|
||||
input: []map[string]interface{}{
|
||||
{"score": 99.5},
|
||||
},
|
||||
expect: clause.Values{
|
||||
Columns: []clause.Column{{Name: "score"}},
|
||||
Values: [][]interface{}{{99.5}},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Test convert slice of bool value",
|
||||
input: []map[string]interface{}{
|
||||
{"active": true},
|
||||
},
|
||||
expect: clause.Values{
|
||||
Columns: []clause.Column{{Name: "active"}},
|
||||
Values: [][]interface{}{{true}},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCase {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
actual := ConvertSliceOfMapToValuesForCreate(&gorm.Statement{}, tc.input)
|
||||
|
||||
if !reflect.DeepEqual(actual, tc.expect) {
|
||||
t.Errorf("expected %v but got %v", tc.expect, actual)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
@ -3,6 +3,8 @@ package callbacks
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
@ -10,6 +12,164 @@ import (
|
||||
"gorm.io/gorm/utils"
|
||||
)
|
||||
|
||||
// parsePreloadMap extracts nested preloads. e.g.
|
||||
//
|
||||
// // schema has a "k0" relation and a "k7.k8" embedded relation
|
||||
// parsePreloadMap(schema, map[string][]interface{}{
|
||||
// clause.Associations: {"arg1"},
|
||||
// "k1": {"arg2"},
|
||||
// "k2.k3": {"arg3"},
|
||||
// "k4.k5.k6": {"arg4"},
|
||||
// })
|
||||
// // preloadMap is
|
||||
// map[string]map[string][]interface{}{
|
||||
// "k0": {},
|
||||
// "k7": {
|
||||
// "k8": {},
|
||||
// },
|
||||
// "k1": {},
|
||||
// "k2": {
|
||||
// "k3": {"arg3"},
|
||||
// },
|
||||
// "k4": {
|
||||
// "k5.k6": {"arg4"},
|
||||
// },
|
||||
// }
|
||||
func parsePreloadMap(s *schema.Schema, preloads map[string][]interface{}) map[string]map[string][]interface{} {
|
||||
preloadMap := map[string]map[string][]interface{}{}
|
||||
setPreloadMap := func(name, value string, args []interface{}) {
|
||||
if _, ok := preloadMap[name]; !ok {
|
||||
preloadMap[name] = map[string][]interface{}{}
|
||||
}
|
||||
if value != "" {
|
||||
preloadMap[name][value] = args
|
||||
}
|
||||
}
|
||||
|
||||
for name, args := range preloads {
|
||||
preloadFields := strings.Split(name, ".")
|
||||
value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), ".")
|
||||
if preloadFields[0] == clause.Associations {
|
||||
for _, relation := range s.Relationships.Relations {
|
||||
if relation.Schema == s {
|
||||
setPreloadMap(relation.Name, value, args)
|
||||
}
|
||||
}
|
||||
|
||||
for embedded, embeddedRelations := range s.Relationships.EmbeddedRelations {
|
||||
for _, value := range embeddedValues(embeddedRelations) {
|
||||
setPreloadMap(embedded, value, args)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
setPreloadMap(preloadFields[0], value, args)
|
||||
}
|
||||
}
|
||||
return preloadMap
|
||||
}
|
||||
|
||||
func embeddedValues(embeddedRelations *schema.Relationships) []string {
|
||||
if embeddedRelations == nil {
|
||||
return nil
|
||||
}
|
||||
names := make([]string, 0, len(embeddedRelations.Relations)+len(embeddedRelations.EmbeddedRelations))
|
||||
for _, relation := range embeddedRelations.Relations {
|
||||
// skip first struct name
|
||||
names = append(names, strings.Join(relation.Field.BindNames[1:], "."))
|
||||
}
|
||||
for _, relations := range embeddedRelations.EmbeddedRelations {
|
||||
names = append(names, embeddedValues(relations)...)
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
// preloadEntryPoint enters layer by layer. It will call real preload if it finds the right entry point.
|
||||
// If the current relationship is embedded or joined, current query will be ignored.
|
||||
//
|
||||
//nolint:cyclop
|
||||
func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relationships, preloads map[string][]interface{}, associationsConds []interface{}) error {
|
||||
preloadMap := parsePreloadMap(db.Statement.Schema, preloads)
|
||||
|
||||
// avoid random traversal of the map
|
||||
preloadNames := make([]string, 0, len(preloadMap))
|
||||
for key := range preloadMap {
|
||||
preloadNames = append(preloadNames, key)
|
||||
}
|
||||
sort.Strings(preloadNames)
|
||||
|
||||
isJoined := func(name string) (joined bool, nestedJoins []string) {
|
||||
for _, join := range joins {
|
||||
if _, ok := relationships.Relations[join]; ok && name == join {
|
||||
joined = true
|
||||
continue
|
||||
}
|
||||
joinNames := strings.SplitN(join, ".", 2)
|
||||
if len(joinNames) == 2 {
|
||||
if _, ok := relationships.Relations[joinNames[0]]; ok && name == joinNames[0] {
|
||||
joined = true
|
||||
nestedJoins = append(nestedJoins, joinNames[1])
|
||||
}
|
||||
}
|
||||
}
|
||||
return joined, nestedJoins
|
||||
}
|
||||
|
||||
for _, name := range preloadNames {
|
||||
if relations := relationships.EmbeddedRelations[name]; relations != nil {
|
||||
if err := preloadEntryPoint(db, joins, relations, preloadMap[name], associationsConds); err != nil {
|
||||
return err
|
||||
}
|
||||
} else if rel := relationships.Relations[name]; rel != nil {
|
||||
if joined, nestedJoins := isJoined(name); joined {
|
||||
switch rv := db.Statement.ReflectValue; rv.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
for i := 0; i < rv.Len(); i++ {
|
||||
reflectValue := rel.Field.ReflectValueOf(db.Statement.Context, rv.Index(i))
|
||||
tx := preloadDB(db, reflectValue, reflectValue.Interface())
|
||||
if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
case reflect.Struct:
|
||||
reflectValue := rel.Field.ReflectValueOf(db.Statement.Context, rv)
|
||||
tx := preloadDB(db, reflectValue, reflectValue.Interface())
|
||||
if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
return gorm.ErrInvalidData
|
||||
}
|
||||
} else {
|
||||
tx := db.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks})
|
||||
tx.Statement.ReflectValue = db.Statement.ReflectValue
|
||||
tx.Statement.Unscoped = db.Statement.Unscoped
|
||||
if err := preload(tx, rel, append(preloads[name], associationsConds...), preloadMap[name]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("%s: %w for schema %s", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func preloadDB(db *gorm.DB, reflectValue reflect.Value, dest interface{}) *gorm.DB {
|
||||
tx := db.Session(&gorm.Session{Context: db.Statement.Context, NewDB: true, SkipHooks: db.Statement.SkipHooks, Initialized: true})
|
||||
db.Statement.Settings.Range(func(k, v interface{}) bool {
|
||||
tx.Statement.Settings.Store(k, v)
|
||||
return true
|
||||
})
|
||||
|
||||
if err := tx.Statement.Parse(dest); err != nil {
|
||||
tx.AddError(err)
|
||||
return tx
|
||||
}
|
||||
tx.Statement.ReflectValue = reflectValue
|
||||
tx.Statement.Unscoped = db.Statement.Unscoped
|
||||
return tx
|
||||
}
|
||||
|
||||
func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) error {
|
||||
var (
|
||||
reflectValue = tx.Statement.ReflectValue
|
||||
|
@ -3,11 +3,12 @@ package callbacks
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/schema"
|
||||
"gorm.io/gorm/utils"
|
||||
)
|
||||
|
||||
func Query(db *gorm.DB) {
|
||||
@ -109,78 +110,141 @@ func BuildQuerySQL(db *gorm.DB) {
|
||||
}
|
||||
}
|
||||
|
||||
specifiedRelationsName := make(map[string]interface{})
|
||||
for _, join := range db.Statement.Joins {
|
||||
if db.Statement.Schema == nil {
|
||||
fromClause.Joins = append(fromClause.Joins, clause.Join{
|
||||
Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds},
|
||||
})
|
||||
} else if relation, ok := db.Statement.Schema.Relationships.Relations[join.Name]; ok {
|
||||
tableAliasName := relation.Name
|
||||
if db.Statement.Schema != nil {
|
||||
var isRelations bool // is relations or raw sql
|
||||
var relations []*schema.Relationship
|
||||
relation, ok := db.Statement.Schema.Relationships.Relations[join.Name]
|
||||
if ok {
|
||||
isRelations = true
|
||||
relations = append(relations, relation)
|
||||
} else {
|
||||
// handle nested join like "Manager.Company"
|
||||
nestedJoinNames := strings.Split(join.Name, ".")
|
||||
if len(nestedJoinNames) > 1 {
|
||||
isNestedJoin := true
|
||||
gussNestedRelations := make([]*schema.Relationship, 0, len(nestedJoinNames))
|
||||
currentRelations := db.Statement.Schema.Relationships.Relations
|
||||
for _, relname := range nestedJoinNames {
|
||||
// incomplete match, only treated as raw sql
|
||||
if relation, ok = currentRelations[relname]; ok {
|
||||
gussNestedRelations = append(gussNestedRelations, relation)
|
||||
currentRelations = relation.FieldSchema.Relationships.Relations
|
||||
} else {
|
||||
isNestedJoin = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
for _, s := range relation.FieldSchema.DBNames {
|
||||
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{
|
||||
Table: tableAliasName,
|
||||
Name: s,
|
||||
Alias: tableAliasName + "__" + s,
|
||||
if isNestedJoin {
|
||||
isRelations = true
|
||||
relations = gussNestedRelations
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if isRelations {
|
||||
genJoinClause := func(joinType clause.JoinType, parentTableName string, relation *schema.Relationship) clause.Join {
|
||||
tableAliasName := relation.Name
|
||||
if parentTableName != clause.CurrentTable {
|
||||
tableAliasName = utils.NestedRelationName(parentTableName, tableAliasName)
|
||||
}
|
||||
|
||||
columnStmt := gorm.Statement{
|
||||
Table: tableAliasName, DB: db, Schema: relation.FieldSchema,
|
||||
Selects: join.Selects, Omits: join.Omits,
|
||||
}
|
||||
|
||||
selectColumns, restricted := columnStmt.SelectAndOmitColumns(false, false)
|
||||
for _, s := range relation.FieldSchema.DBNames {
|
||||
if v, ok := selectColumns[s]; (ok && v) || (!ok && !restricted) {
|
||||
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{
|
||||
Table: tableAliasName,
|
||||
Name: s,
|
||||
Alias: utils.NestedRelationName(tableAliasName, s),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
exprs := make([]clause.Expression, len(relation.References))
|
||||
for idx, ref := range relation.References {
|
||||
if ref.OwnPrimaryKey {
|
||||
exprs[idx] = clause.Eq{
|
||||
Column: clause.Column{Table: parentTableName, Name: ref.PrimaryKey.DBName},
|
||||
Value: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
|
||||
}
|
||||
} else {
|
||||
if ref.PrimaryValue == "" {
|
||||
exprs[idx] = clause.Eq{
|
||||
Column: clause.Column{Table: parentTableName, Name: ref.ForeignKey.DBName},
|
||||
Value: clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName},
|
||||
}
|
||||
} else {
|
||||
exprs[idx] = clause.Eq{
|
||||
Column: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
|
||||
Value: ref.PrimaryValue,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
onStmt := gorm.Statement{Table: tableAliasName, DB: db, Clauses: map[string]clause.Clause{}}
|
||||
for _, c := range relation.FieldSchema.QueryClauses {
|
||||
onStmt.AddClause(c)
|
||||
}
|
||||
|
||||
if join.On != nil {
|
||||
onStmt.AddClause(join.On)
|
||||
}
|
||||
|
||||
if cs, ok := onStmt.Clauses["WHERE"]; ok {
|
||||
if where, ok := cs.Expression.(clause.Where); ok {
|
||||
where.Build(&onStmt)
|
||||
|
||||
if onSQL := onStmt.SQL.String(); onSQL != "" {
|
||||
vars := onStmt.Vars
|
||||
for idx, v := range vars {
|
||||
bindvar := strings.Builder{}
|
||||
onStmt.Vars = vars[0 : idx+1]
|
||||
db.Dialector.BindVarTo(&bindvar, &onStmt, v)
|
||||
onSQL = strings.Replace(onSQL, bindvar.String(), "?", 1)
|
||||
}
|
||||
|
||||
exprs = append(exprs, clause.Expr{SQL: onSQL, Vars: vars})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return clause.Join{
|
||||
Type: joinType,
|
||||
Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName},
|
||||
ON: clause.Where{Exprs: exprs},
|
||||
}
|
||||
}
|
||||
|
||||
parentTableName := clause.CurrentTable
|
||||
for _, rel := range relations {
|
||||
// joins table alias like "Manager, Company, Manager__Company"
|
||||
nestedAlias := utils.NestedRelationName(parentTableName, rel.Name)
|
||||
if _, ok := specifiedRelationsName[nestedAlias]; !ok {
|
||||
fromClause.Joins = append(fromClause.Joins, genJoinClause(join.JoinType, parentTableName, rel))
|
||||
specifiedRelationsName[nestedAlias] = nil
|
||||
}
|
||||
|
||||
if parentTableName != clause.CurrentTable {
|
||||
parentTableName = utils.NestedRelationName(parentTableName, rel.Name)
|
||||
} else {
|
||||
parentTableName = rel.Name
|
||||
}
|
||||
}
|
||||
} else {
|
||||
fromClause.Joins = append(fromClause.Joins, clause.Join{
|
||||
Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds},
|
||||
})
|
||||
}
|
||||
|
||||
exprs := make([]clause.Expression, len(relation.References))
|
||||
for idx, ref := range relation.References {
|
||||
if ref.OwnPrimaryKey {
|
||||
exprs[idx] = clause.Eq{
|
||||
Column: clause.Column{Table: clause.CurrentTable, Name: ref.PrimaryKey.DBName},
|
||||
Value: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
|
||||
}
|
||||
} else {
|
||||
if ref.PrimaryValue == "" {
|
||||
exprs[idx] = clause.Eq{
|
||||
Column: clause.Column{Table: clause.CurrentTable, Name: ref.ForeignKey.DBName},
|
||||
Value: clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName},
|
||||
}
|
||||
} else {
|
||||
exprs[idx] = clause.Eq{
|
||||
Column: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
|
||||
Value: ref.PrimaryValue,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
onStmt := gorm.Statement{Table: tableAliasName, DB: db, Clauses: map[string]clause.Clause{}}
|
||||
for _, c := range relation.FieldSchema.QueryClauses {
|
||||
onStmt.AddClause(c)
|
||||
}
|
||||
|
||||
if join.On != nil {
|
||||
onStmt.AddClause(join.On)
|
||||
}
|
||||
|
||||
if cs, ok := onStmt.Clauses["WHERE"]; ok {
|
||||
if where, ok := cs.Expression.(clause.Where); ok {
|
||||
where.Build(&onStmt)
|
||||
|
||||
if onSQL := onStmt.SQL.String(); onSQL != "" {
|
||||
vars := onStmt.Vars
|
||||
for idx, v := range vars {
|
||||
bindvar := strings.Builder{}
|
||||
onStmt.Vars = vars[0 : idx+1]
|
||||
db.Dialector.BindVarTo(&bindvar, &onStmt, v)
|
||||
onSQL = strings.Replace(onSQL, bindvar.String(), "?", 1)
|
||||
}
|
||||
|
||||
exprs = append(exprs, clause.Expr{SQL: onSQL, Vars: vars})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fromClause.Joins = append(fromClause.Joins, clause.Join{
|
||||
Type: clause.LeftJoin,
|
||||
Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName},
|
||||
ON: clause.Where{Exprs: exprs},
|
||||
})
|
||||
} else {
|
||||
fromClause.Joins = append(fromClause.Joins, clause.Join{
|
||||
Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds},
|
||||
@ -189,7 +253,6 @@ func BuildQuerySQL(db *gorm.DB) {
|
||||
}
|
||||
|
||||
db.Statement.AddClause(fromClause)
|
||||
db.Statement.Joins = nil
|
||||
} else {
|
||||
db.Statement.AddClauseIfNotExists(clause.From{})
|
||||
}
|
||||
@ -207,60 +270,23 @@ func Preload(db *gorm.DB) {
|
||||
return
|
||||
}
|
||||
|
||||
preloadMap := map[string]map[string][]interface{}{}
|
||||
for name := range db.Statement.Preloads {
|
||||
preloadFields := strings.Split(name, ".")
|
||||
if preloadFields[0] == clause.Associations {
|
||||
for _, rel := range db.Statement.Schema.Relationships.Relations {
|
||||
if rel.Schema == db.Statement.Schema {
|
||||
if _, ok := preloadMap[rel.Name]; !ok {
|
||||
preloadMap[rel.Name] = map[string][]interface{}{}
|
||||
}
|
||||
|
||||
if value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), "."); value != "" {
|
||||
preloadMap[rel.Name][value] = db.Statement.Preloads[name]
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if _, ok := preloadMap[preloadFields[0]]; !ok {
|
||||
preloadMap[preloadFields[0]] = map[string][]interface{}{}
|
||||
}
|
||||
|
||||
if value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), "."); value != "" {
|
||||
preloadMap[preloadFields[0]][value] = db.Statement.Preloads[name]
|
||||
}
|
||||
}
|
||||
joins := make([]string, 0, len(db.Statement.Joins))
|
||||
for _, join := range db.Statement.Joins {
|
||||
joins = append(joins, join.Name)
|
||||
}
|
||||
|
||||
preloadNames := make([]string, 0, len(preloadMap))
|
||||
for key := range preloadMap {
|
||||
preloadNames = append(preloadNames, key)
|
||||
}
|
||||
sort.Strings(preloadNames)
|
||||
|
||||
preloadDB := db.Session(&gorm.Session{Context: db.Statement.Context, NewDB: true, SkipHooks: db.Statement.SkipHooks, Initialized: true})
|
||||
db.Statement.Settings.Range(func(k, v interface{}) bool {
|
||||
preloadDB.Statement.Settings.Store(k, v)
|
||||
return true
|
||||
})
|
||||
|
||||
if err := preloadDB.Statement.Parse(db.Statement.Dest); err != nil {
|
||||
tx := preloadDB(db, db.Statement.ReflectValue, db.Statement.Dest)
|
||||
if tx.Error != nil {
|
||||
return
|
||||
}
|
||||
preloadDB.Statement.ReflectValue = db.Statement.ReflectValue
|
||||
|
||||
for _, name := range preloadNames {
|
||||
if rel := preloadDB.Statement.Schema.Relationships.Relations[name]; rel != nil {
|
||||
db.AddError(preload(preloadDB.Table("").Session(&gorm.Session{}), rel, append(db.Statement.Preloads[name], db.Statement.Preloads[clause.Associations]...), preloadMap[name]))
|
||||
} else {
|
||||
db.AddError(fmt.Errorf("%s: %w for schema %s", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name))
|
||||
}
|
||||
}
|
||||
db.AddError(preloadEntryPoint(tx, joins, &tx.Statement.Schema.Relationships, db.Statement.Preloads, db.Statement.Preloads[clause.Associations]))
|
||||
}
|
||||
}
|
||||
|
||||
func AfterQuery(db *gorm.DB) {
|
||||
// clear the joins after query because preload need it
|
||||
db.Statement.Joins = nil
|
||||
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.AfterFind && db.RowsAffected > 0 {
|
||||
callMethod(db, func(value interface{}, tx *gorm.DB) bool {
|
||||
if i, ok := value.(AfterFindInterface); ok {
|
||||
|
@ -7,7 +7,7 @@ import (
|
||||
func RowQuery(db *gorm.DB) {
|
||||
if db.Error == nil {
|
||||
BuildQuerySQL(db)
|
||||
if db.DryRun {
|
||||
if db.DryRun || db.Error != nil {
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -70,10 +70,13 @@ func Update(config *Config) func(db *gorm.DB) {
|
||||
if db.Statement.SQL.Len() == 0 {
|
||||
db.Statement.SQL.Grow(180)
|
||||
db.Statement.AddClauseIfNotExists(clause.Update{})
|
||||
if set := ConvertToAssignments(db.Statement); len(set) != 0 {
|
||||
db.Statement.AddClause(set)
|
||||
} else if _, ok := db.Statement.Clauses["SET"]; !ok {
|
||||
return
|
||||
if _, ok := db.Statement.Clauses["SET"]; !ok {
|
||||
if set := ConvertToAssignments(db.Statement); len(set) != 0 {
|
||||
defer delete(db.Statement.Clauses, "SET")
|
||||
db.Statement.AddClause(set)
|
||||
} else {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
db.Statement.Build(db.Statement.BuildClauses...)
|
||||
@ -135,7 +138,9 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
|
||||
case reflect.Slice, reflect.Array:
|
||||
assignValue = func(field *schema.Field, value interface{}) {
|
||||
for i := 0; i < stmt.ReflectValue.Len(); i++ {
|
||||
field.Set(stmt.Context, stmt.ReflectValue.Index(i), value)
|
||||
if stmt.ReflectValue.CanAddr() {
|
||||
field.Set(stmt.Context, stmt.ReflectValue.Index(i), value)
|
||||
}
|
||||
}
|
||||
}
|
||||
case reflect.Struct:
|
||||
@ -158,21 +163,21 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
|
||||
switch stmt.ReflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
if size := stmt.ReflectValue.Len(); size > 0 {
|
||||
var primaryKeyExprs []clause.Expression
|
||||
var isZero bool
|
||||
for i := 0; i < size; i++ {
|
||||
exprs := make([]clause.Expression, len(stmt.Schema.PrimaryFields))
|
||||
var notZero bool
|
||||
for idx, field := range stmt.Schema.PrimaryFields {
|
||||
value, isZero := field.ValueOf(stmt.Context, stmt.ReflectValue.Index(i))
|
||||
exprs[idx] = clause.Eq{Column: field.DBName, Value: value}
|
||||
notZero = notZero || !isZero
|
||||
}
|
||||
if notZero {
|
||||
primaryKeyExprs = append(primaryKeyExprs, clause.And(exprs...))
|
||||
for _, field := range stmt.Schema.PrimaryFields {
|
||||
_, isZero = field.ValueOf(stmt.Context, stmt.ReflectValue.Index(i))
|
||||
if !isZero {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(primaryKeyExprs...)}})
|
||||
if !isZero {
|
||||
_, primaryValues := schema.GetIdentityFieldValuesMap(stmt.Context, stmt.ReflectValue, stmt.Schema.PrimaryFields)
|
||||
column, values := schema.ToQueryValues("", stmt.Schema.PrimaryFieldDBNames, primaryValues)
|
||||
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
|
||||
}
|
||||
}
|
||||
case reflect.Struct:
|
||||
for _, field := range stmt.Schema.PrimaryFields {
|
||||
@ -229,7 +234,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
|
||||
if field.AutoUpdateTime == schema.UnixNanosecond {
|
||||
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano()})
|
||||
} else if field.AutoUpdateTime == schema.UnixMillisecond {
|
||||
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano() / 1e6})
|
||||
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixMilli()})
|
||||
} else if field.AutoUpdateTime == schema.UnixSecond {
|
||||
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.Unix()})
|
||||
} else {
|
||||
@ -241,11 +246,13 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
|
||||
}
|
||||
default:
|
||||
updatingSchema := stmt.Schema
|
||||
var isDiffSchema bool
|
||||
if !updatingValue.CanAddr() || stmt.Dest != stmt.Model {
|
||||
// different schema
|
||||
updatingStmt := &gorm.Statement{DB: stmt.DB}
|
||||
if err := updatingStmt.Parse(stmt.Dest); err == nil {
|
||||
updatingSchema = updatingStmt.Schema
|
||||
isDiffSchema = true
|
||||
}
|
||||
}
|
||||
|
||||
@ -261,7 +268,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
|
||||
if field.AutoUpdateTime == schema.UnixNanosecond {
|
||||
value = stmt.DB.NowFunc().UnixNano()
|
||||
} else if field.AutoUpdateTime == schema.UnixMillisecond {
|
||||
value = stmt.DB.NowFunc().UnixNano() / 1e6
|
||||
value = stmt.DB.NowFunc().UnixMilli()
|
||||
} else if field.AutoUpdateTime == schema.UnixSecond {
|
||||
value = stmt.DB.NowFunc().Unix()
|
||||
} else {
|
||||
@ -272,7 +279,13 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
|
||||
|
||||
if (ok || !isZero) && field.Updatable {
|
||||
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value})
|
||||
assignValue(field, value)
|
||||
assignField := field
|
||||
if isDiffSchema {
|
||||
if originField := stmt.Schema.LookUpField(dbName); originField != nil {
|
||||
assignField = originField
|
||||
}
|
||||
}
|
||||
assignValue(assignField, value)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
|
@ -1,36 +0,0 @@
|
||||
package callbacks
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestLoadOrStoreVisitMap(t *testing.T) {
|
||||
var vm visitMap
|
||||
var loaded bool
|
||||
type testM struct {
|
||||
Name string
|
||||
}
|
||||
|
||||
t1 := testM{Name: "t1"}
|
||||
t2 := testM{Name: "t2"}
|
||||
t3 := testM{Name: "t3"}
|
||||
|
||||
vm = make(visitMap)
|
||||
if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf(&t1)); loaded {
|
||||
t.Fatalf("loaded should be false")
|
||||
}
|
||||
|
||||
if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf(&t1)); !loaded {
|
||||
t.Fatalf("loaded should be true")
|
||||
}
|
||||
|
||||
// t1 already exist but t2 not
|
||||
if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf([]*testM{&t1, &t2, &t3})); loaded {
|
||||
t.Fatalf("loaded should be false")
|
||||
}
|
||||
|
||||
if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf([]*testM{&t2, &t3})); !loaded {
|
||||
t.Fatalf("loaded should be true")
|
||||
}
|
||||
}
|
189
chainable_api.go
189
chainable_api.go
@ -10,10 +10,11 @@ import (
|
||||
)
|
||||
|
||||
// Model specify the model you would like to run db operations
|
||||
// // update all users's name to `hello`
|
||||
// db.Model(&User{}).Update("name", "hello")
|
||||
// // if user's primary key is non-blank, will use it as condition, then will only update the user's name to `hello`
|
||||
// db.Model(&user).Update("name", "hello")
|
||||
//
|
||||
// // update all users's name to `hello`
|
||||
// db.Model(&User{}).Update("name", "hello")
|
||||
// // if user's primary key is non-blank, will use it as condition, then will only update that user's name to `hello`
|
||||
// db.Model(&user).Update("name", "hello")
|
||||
func (db *DB) Model(value interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.Model = value
|
||||
@ -21,6 +22,19 @@ func (db *DB) Model(value interface{}) (tx *DB) {
|
||||
}
|
||||
|
||||
// Clauses Add clauses
|
||||
//
|
||||
// This supports both standard clauses (clause.OrderBy, clause.Limit, clause.Where) and more
|
||||
// advanced techniques like specifying lock strength and optimizer hints. See the
|
||||
// [docs] for more depth.
|
||||
//
|
||||
// // add a simple limit clause
|
||||
// db.Clauses(clause.Limit{Limit: 1}).Find(&User{})
|
||||
// // tell the optimizer to use the `idx_user_name` index
|
||||
// db.Clauses(hints.UseIndex("idx_user_name")).Find(&User{})
|
||||
// // specify the lock strength to UPDATE
|
||||
// db.Clauses(clause.Locking{Strength: "UPDATE"}).Find(&users)
|
||||
//
|
||||
// [docs]: https://gorm.io/docs/sql_builder.html#Clauses
|
||||
func (db *DB) Clauses(conds ...clause.Expression) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
var whereConds []interface{}
|
||||
@ -41,15 +55,22 @@ func (db *DB) Clauses(conds ...clause.Expression) (tx *DB) {
|
||||
return
|
||||
}
|
||||
|
||||
var tableRegexp = regexp.MustCompile(`(?i).+? AS (\w+)\s*(?:$|,)`)
|
||||
var tableRegexp = regexp.MustCompile(`(?i)(?:.+? AS (\w+)\s*(?:$|,)|^\w+\s+(\w+)$)`)
|
||||
|
||||
// Table specify the table you would like to run db operations
|
||||
//
|
||||
// // Get a user
|
||||
// db.Table("users").Take(&result)
|
||||
func (db *DB) Table(name string, args ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
if strings.Contains(name, " ") || strings.Contains(name, "`") || len(args) > 0 {
|
||||
tx.Statement.TableExpr = &clause.Expr{SQL: name, Vars: args}
|
||||
if results := tableRegexp.FindStringSubmatch(name); len(results) == 2 {
|
||||
tx.Statement.Table = results[1]
|
||||
if results := tableRegexp.FindStringSubmatch(name); len(results) == 3 {
|
||||
if results[1] != "" {
|
||||
tx.Statement.Table = results[1]
|
||||
} else {
|
||||
tx.Statement.Table = results[2]
|
||||
}
|
||||
}
|
||||
} else if tables := strings.Split(name, "."); len(tables) == 2 {
|
||||
tx.Statement.TableExpr = &clause.Expr{SQL: tx.Statement.Quote(name)}
|
||||
@ -65,6 +86,11 @@ func (db *DB) Table(name string, args ...interface{}) (tx *DB) {
|
||||
}
|
||||
|
||||
// Distinct specify distinct fields that you want querying
|
||||
//
|
||||
// // Select distinct names of users
|
||||
// db.Distinct("name").Find(&results)
|
||||
// // Select distinct name/age pairs from users
|
||||
// db.Distinct("name", "age").Find(&results)
|
||||
func (db *DB) Distinct(args ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.Distinct = true
|
||||
@ -75,6 +101,14 @@ func (db *DB) Distinct(args ...interface{}) (tx *DB) {
|
||||
}
|
||||
|
||||
// Select specify fields that you want when querying, creating, updating
|
||||
//
|
||||
// Use Select when you only want a subset of the fields. By default, GORM will select all fields.
|
||||
// Select accepts both string arguments and arrays.
|
||||
//
|
||||
// // Select name and age of user using multiple arguments
|
||||
// db.Select("name", "age").Find(&users)
|
||||
// // Select name and age of user using an array
|
||||
// db.Select([]string{"name", "age"}).Find(&users)
|
||||
func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
|
||||
@ -152,6 +186,17 @@ func (db *DB) Omit(columns ...string) (tx *DB) {
|
||||
}
|
||||
|
||||
// Where add conditions
|
||||
//
|
||||
// See the [docs] for details on the various formats that where clauses can take. By default, where clauses chain with AND.
|
||||
//
|
||||
// // Find the first user with name jinzhu
|
||||
// db.Where("name = ?", "jinzhu").First(&user)
|
||||
// // Find the first user with name jinzhu and age 20
|
||||
// db.Where(&User{Name: "jinzhu", Age: 20}).First(&user)
|
||||
// // Find the first user with name jinzhu and age not equal to 20
|
||||
// db.Where("name = ?", "jinzhu").Where("age <> ?", "20").First(&user)
|
||||
//
|
||||
// [docs]: https://gorm.io/docs/query.html#Conditions
|
||||
func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 {
|
||||
@ -161,6 +206,11 @@ func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) {
|
||||
}
|
||||
|
||||
// Not add NOT conditions
|
||||
//
|
||||
// Not works similarly to where, and has the same syntax.
|
||||
//
|
||||
// // Find the first user with name not equal to jinzhu
|
||||
// db.Not("name = ?", "jinzhu").First(&user)
|
||||
func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 {
|
||||
@ -170,6 +220,11 @@ func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) {
|
||||
}
|
||||
|
||||
// Or add OR conditions
|
||||
//
|
||||
// Or is used to chain together queries with an OR.
|
||||
//
|
||||
// // Find the first user with name equal to jinzhu or john
|
||||
// db.Where("name = ?", "jinzhu").Or("name = ?", "john").First(&user)
|
||||
func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 {
|
||||
@ -179,26 +234,45 @@ func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) {
|
||||
}
|
||||
|
||||
// Joins specify Joins conditions
|
||||
// db.Joins("Account").Find(&user)
|
||||
// db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user)
|
||||
// db.Joins("Account", DB.Select("id").Where("user_id = users.id AND name = ?", "someName").Model(&Account{}))
|
||||
//
|
||||
// db.Joins("Account").Find(&user)
|
||||
// db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user)
|
||||
// db.Joins("Account", DB.Select("id").Where("user_id = users.id AND name = ?", "someName").Model(&Account{}))
|
||||
func (db *DB) Joins(query string, args ...interface{}) (tx *DB) {
|
||||
return joins(db, clause.LeftJoin, query, args...)
|
||||
}
|
||||
|
||||
// InnerJoins specify inner joins conditions
|
||||
// db.InnerJoins("Account").Find(&user)
|
||||
func (db *DB) InnerJoins(query string, args ...interface{}) (tx *DB) {
|
||||
return joins(db, clause.InnerJoin, query, args...)
|
||||
}
|
||||
|
||||
func joins(db *DB, joinType clause.JoinType, query string, args ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
|
||||
if len(args) == 1 {
|
||||
if db, ok := args[0].(*DB); ok {
|
||||
if where, ok := db.Statement.Clauses["WHERE"].Expression.(clause.Where); ok {
|
||||
tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args, On: &where})
|
||||
return
|
||||
j := join{
|
||||
Name: query, Conds: args, Selects: db.Statement.Selects,
|
||||
Omits: db.Statement.Omits, JoinType: joinType,
|
||||
}
|
||||
if where, ok := db.Statement.Clauses["WHERE"].Expression.(clause.Where); ok {
|
||||
j.On = &where
|
||||
}
|
||||
tx.Statement.Joins = append(tx.Statement.Joins, j)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args})
|
||||
tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args, JoinType: joinType})
|
||||
return
|
||||
}
|
||||
|
||||
// Group specify the group method on the find
|
||||
//
|
||||
// // Select the sum age of users with given names
|
||||
// db.Model(&User{}).Select("name, sum(age) as total").Group("name").Find(&results)
|
||||
func (db *DB) Group(name string) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
|
||||
@ -210,6 +284,9 @@ func (db *DB) Group(name string) (tx *DB) {
|
||||
}
|
||||
|
||||
// Having specify HAVING conditions for GROUP BY
|
||||
//
|
||||
// // Select the sum age of users with name jinzhu
|
||||
// db.Model(&User{}).Select("name, sum(age) as total").Group("name").Having("name = ?", "jinzhu").Find(&result)
|
||||
func (db *DB) Having(query interface{}, args ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.AddClause(clause.GroupBy{
|
||||
@ -218,9 +295,10 @@ func (db *DB) Having(query interface{}, args ...interface{}) (tx *DB) {
|
||||
return
|
||||
}
|
||||
|
||||
// Order specify order when retrieve records from database
|
||||
// db.Order("name DESC")
|
||||
// db.Order(clause.OrderByColumn{Column: clause.Column{Name: "name"}, Desc: true})
|
||||
// Order specify order when retrieving records from database
|
||||
//
|
||||
// db.Order("name DESC")
|
||||
// db.Order(clause.OrderByColumn{Column: clause.Column{Name: "name"}, Desc: true})
|
||||
func (db *DB) Order(value interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
|
||||
@ -242,13 +320,27 @@ func (db *DB) Order(value interface{}) (tx *DB) {
|
||||
}
|
||||
|
||||
// Limit specify the number of records to be retrieved
|
||||
//
|
||||
// Limit conditions can be cancelled by using `Limit(-1)`.
|
||||
//
|
||||
// // retrieve 3 users
|
||||
// db.Limit(3).Find(&users)
|
||||
// // retrieve 3 users into users1, and all users into users2
|
||||
// db.Limit(3).Find(&users1).Limit(-1).Find(&users2)
|
||||
func (db *DB) Limit(limit int) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.AddClause(clause.Limit{Limit: limit})
|
||||
tx.Statement.AddClause(clause.Limit{Limit: &limit})
|
||||
return
|
||||
}
|
||||
|
||||
// Offset specify the number of records to skip before starting to return the records
|
||||
//
|
||||
// Offset conditions can be cancelled by using `Offset(-1)`.
|
||||
//
|
||||
// // select the third user
|
||||
// db.Offset(2).First(&user)
|
||||
// // select the first user by cancelling an earlier chained offset
|
||||
// db.Offset(5).Offset(-1).First(&user)
|
||||
func (db *DB) Offset(offset int) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.AddClause(clause.Limit{Offset: offset})
|
||||
@ -256,25 +348,37 @@ func (db *DB) Offset(offset int) (tx *DB) {
|
||||
}
|
||||
|
||||
// Scopes pass current database connection to arguments `func(DB) DB`, which could be used to add conditions dynamically
|
||||
// func AmountGreaterThan1000(db *gorm.DB) *gorm.DB {
|
||||
// return db.Where("amount > ?", 1000)
|
||||
// }
|
||||
//
|
||||
// func OrderStatus(status []string) func (db *gorm.DB) *gorm.DB {
|
||||
// return func (db *gorm.DB) *gorm.DB {
|
||||
// return db.Scopes(AmountGreaterThan1000).Where("status in (?)", status)
|
||||
// }
|
||||
// }
|
||||
// func AmountGreaterThan1000(db *gorm.DB) *gorm.DB {
|
||||
// return db.Where("amount > ?", 1000)
|
||||
// }
|
||||
//
|
||||
// db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders)
|
||||
// func OrderStatus(status []string) func (db *gorm.DB) *gorm.DB {
|
||||
// return func (db *gorm.DB) *gorm.DB {
|
||||
// return db.Scopes(AmountGreaterThan1000).Where("status in (?)", status)
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders)
|
||||
func (db *DB) Scopes(funcs ...func(*DB) *DB) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.scopes = append(tx.Statement.scopes, funcs...)
|
||||
return tx
|
||||
}
|
||||
|
||||
func (db *DB) executeScopes() (tx *DB) {
|
||||
scopes := db.Statement.scopes
|
||||
db.Statement.scopes = nil
|
||||
for _, scope := range scopes {
|
||||
db = scope(db)
|
||||
}
|
||||
return db
|
||||
}
|
||||
|
||||
// Preload preload associations with given conditions
|
||||
// db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users)
|
||||
//
|
||||
// // get all users, and preload all non-cancelled orders
|
||||
// db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users)
|
||||
func (db *DB) Preload(query string, args ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
if tx.Statement.Preloads == nil {
|
||||
@ -284,12 +388,41 @@ func (db *DB) Preload(query string, args ...interface{}) (tx *DB) {
|
||||
return
|
||||
}
|
||||
|
||||
// Attrs provide attributes used in [FirstOrCreate] or [FirstOrInit]
|
||||
//
|
||||
// Attrs only adds attributes if the record is not found.
|
||||
//
|
||||
// // assign an email if the record is not found
|
||||
// db.Where(User{Name: "non_existing"}).Attrs(User{Email: "fake@fake.org"}).FirstOrInit(&user)
|
||||
// // user -> User{Name: "non_existing", Email: "fake@fake.org"}
|
||||
//
|
||||
// // assign an email if the record is not found, otherwise ignore provided email
|
||||
// db.Where(User{Name: "jinzhu"}).Attrs(User{Email: "fake@fake.org"}).FirstOrInit(&user)
|
||||
// // user -> User{Name: "jinzhu", Age: 20}
|
||||
//
|
||||
// [FirstOrCreate]: https://gorm.io/docs/advanced_query.html#FirstOrCreate
|
||||
// [FirstOrInit]: https://gorm.io/docs/advanced_query.html#FirstOrInit
|
||||
func (db *DB) Attrs(attrs ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.attrs = attrs
|
||||
return
|
||||
}
|
||||
|
||||
// Assign provide attributes used in [FirstOrCreate] or [FirstOrInit]
|
||||
//
|
||||
// Assign adds attributes even if the record is found. If using FirstOrCreate, this means that
|
||||
// records will be updated even if they are found.
|
||||
//
|
||||
// // assign an email regardless of if the record is not found
|
||||
// db.Where(User{Name: "non_existing"}).Assign(User{Email: "fake@fake.org"}).FirstOrInit(&user)
|
||||
// // user -> User{Name: "non_existing", Email: "fake@fake.org"}
|
||||
//
|
||||
// // assign email regardless of if record is found
|
||||
// db.Where(User{Name: "jinzhu"}).Assign(User{Email: "fake@fake.org"}).FirstOrInit(&user)
|
||||
// // user -> User{Name: "jinzhu", Age: 20, Email: "fake@fake.org"}
|
||||
//
|
||||
// [FirstOrCreate]: https://gorm.io/docs/advanced_query.html#FirstOrCreate
|
||||
// [FirstOrInit]: https://gorm.io/docs/advanced_query.html#FirstOrInit
|
||||
func (db *DB) Assign(attrs ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.assigns = attrs
|
||||
|
@ -29,6 +29,7 @@ func BenchmarkSelect(b *testing.B) {
|
||||
func BenchmarkComplexSelect(b *testing.B) {
|
||||
user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
|
||||
|
||||
limit10 := 10
|
||||
for i := 0; i < b.N; i++ {
|
||||
stmt := gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}}
|
||||
clauses := []clause.Interface{
|
||||
@ -43,7 +44,7 @@ func BenchmarkComplexSelect(b *testing.B) {
|
||||
clause.Or(clause.Gt{Column: "score", Value: 100}, clause.Like{Column: "name", Value: "%linus%"}),
|
||||
}},
|
||||
clause.GroupBy{Columns: []clause.Column{{Name: "role"}}, Having: []clause.Expression{clause.Eq{"role", "admin"}}},
|
||||
clause.Limit{Limit: 10, Offset: 20},
|
||||
clause.Limit{Limit: &limit10, Offset: 20},
|
||||
clause.OrderBy{Columns: []clause.OrderByColumn{{Column: clause.PrimaryColumn, Desc: true}}},
|
||||
}
|
||||
|
||||
|
@ -20,6 +20,7 @@ type Builder interface {
|
||||
Writer
|
||||
WriteQuoted(field interface{})
|
||||
AddVar(Writer, ...interface{})
|
||||
AddError(error) error
|
||||
}
|
||||
|
||||
// Clause
|
||||
|
@ -126,8 +126,8 @@ func (expr NamedExpr) Build(builder Builder) {
|
||||
for _, v := range []byte(expr.SQL) {
|
||||
if v == '@' && !inName {
|
||||
inName = true
|
||||
name = []byte{}
|
||||
} else if v == ' ' || v == ',' || v == ')' || v == '"' || v == '\'' || v == '`' || v == '\n' || v == ';' {
|
||||
name = name[:0]
|
||||
} else if v == ' ' || v == ',' || v == ')' || v == '"' || v == '\'' || v == '`' || v == '\r' || v == '\n' || v == ';' {
|
||||
if inName {
|
||||
if nv, ok := namedMap[string(name)]; ok {
|
||||
builder.AddVar(builder, nv)
|
||||
@ -246,15 +246,19 @@ func (eq Eq) Build(builder Builder) {
|
||||
|
||||
switch eq.Value.(type) {
|
||||
case []string, []int, []int32, []int64, []uint, []uint32, []uint64, []interface{}:
|
||||
builder.WriteString(" IN (")
|
||||
rv := reflect.ValueOf(eq.Value)
|
||||
for i := 0; i < rv.Len(); i++ {
|
||||
if i > 0 {
|
||||
builder.WriteByte(',')
|
||||
if rv.Len() == 0 {
|
||||
builder.WriteString(" IN (NULL)")
|
||||
} else {
|
||||
builder.WriteString(" IN (")
|
||||
for i := 0; i < rv.Len(); i++ {
|
||||
if i > 0 {
|
||||
builder.WriteByte(',')
|
||||
}
|
||||
builder.AddVar(builder, rv.Index(i).Interface())
|
||||
}
|
||||
builder.AddVar(builder, rv.Index(i).Interface())
|
||||
builder.WriteByte(')')
|
||||
}
|
||||
builder.WriteByte(')')
|
||||
default:
|
||||
if eqNil(eq.Value) {
|
||||
builder.WriteString(" IS NULL")
|
||||
|
@ -94,6 +94,16 @@ func TestNamedExpr(t *testing.T) {
|
||||
Vars: []interface{}{sql.Named("name", "jinzhu")},
|
||||
Result: "name1 = ? AND name2 = ?;",
|
||||
ExpectedVars: []interface{}{"jinzhu", "jinzhu"},
|
||||
}, {
|
||||
SQL: "name1 = @name1\r\n AND name2 = @name2",
|
||||
Vars: []interface{}{map[string]interface{}{"name1": "jinzhu", "name2": "jinzhu"}},
|
||||
Result: "name1 = ?\r\n AND name2 = ?",
|
||||
ExpectedVars: []interface{}{"jinzhu", "jinzhu"},
|
||||
}, {
|
||||
SQL: "name1 = @name1\r AND name2 = @name2",
|
||||
Vars: []interface{}{map[string]interface{}{"name1": "jinzhu", "name2": "jinzhu"}},
|
||||
Result: "name1 = ?\r AND name2 = ?",
|
||||
ExpectedVars: []interface{}{"jinzhu", "jinzhu"},
|
||||
}, {
|
||||
SQL: "?",
|
||||
Vars: []interface{}{clause.Column{Table: "table", Name: "col"}},
|
||||
@ -189,6 +199,11 @@ func TestExpression(t *testing.T) {
|
||||
},
|
||||
ExpectedVars: []interface{}{"a", "b"},
|
||||
Result: "`column-name` NOT IN (?,?)",
|
||||
}, {
|
||||
Expressions: []clause.Expression{
|
||||
clause.Eq{Column: column, Value: []string{}},
|
||||
},
|
||||
Result: "`column-name` IN (NULL)",
|
||||
}, {
|
||||
Expressions: []clause.Expression{
|
||||
clause.Eq{Column: clause.Expr{SQL: "SUM(?)", Vars: []interface{}{clause.Column{Name: "id"}}}, Value: 100},
|
||||
|
@ -9,7 +9,7 @@ const (
|
||||
RightJoin JoinType = "RIGHT"
|
||||
)
|
||||
|
||||
// Join join clause for from
|
||||
// Join clause for from
|
||||
type Join struct {
|
||||
Type JoinType
|
||||
Table Table
|
||||
|
101
clause/joins_test.go
Normal file
101
clause/joins_test.go
Normal file
@ -0,0 +1,101 @@
|
||||
package clause_test
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/schema"
|
||||
"gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
func TestJoin(t *testing.T) {
|
||||
results := []struct {
|
||||
name string
|
||||
join clause.Join
|
||||
sql string
|
||||
}{
|
||||
{
|
||||
name: "LEFT JOIN",
|
||||
join: clause.Join{
|
||||
Type: clause.LeftJoin,
|
||||
Table: clause.Table{Name: "user"},
|
||||
ON: clause.Where{
|
||||
Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}},
|
||||
},
|
||||
},
|
||||
sql: "LEFT JOIN `user` ON `user_info`.`user_id` = `users`.`id`",
|
||||
},
|
||||
{
|
||||
name: "RIGHT JOIN",
|
||||
join: clause.Join{
|
||||
Type: clause.RightJoin,
|
||||
Table: clause.Table{Name: "user"},
|
||||
ON: clause.Where{
|
||||
Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}},
|
||||
},
|
||||
},
|
||||
sql: "RIGHT JOIN `user` ON `user_info`.`user_id` = `users`.`id`",
|
||||
},
|
||||
{
|
||||
name: "INNER JOIN",
|
||||
join: clause.Join{
|
||||
Type: clause.InnerJoin,
|
||||
Table: clause.Table{Name: "user"},
|
||||
ON: clause.Where{
|
||||
Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}},
|
||||
},
|
||||
},
|
||||
sql: "INNER JOIN `user` ON `user_info`.`user_id` = `users`.`id`",
|
||||
},
|
||||
{
|
||||
name: "CROSS JOIN",
|
||||
join: clause.Join{
|
||||
Type: clause.CrossJoin,
|
||||
Table: clause.Table{Name: "user"},
|
||||
ON: clause.Where{
|
||||
Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}},
|
||||
},
|
||||
},
|
||||
sql: "CROSS JOIN `user` ON `user_info`.`user_id` = `users`.`id`",
|
||||
},
|
||||
{
|
||||
name: "USING",
|
||||
join: clause.Join{
|
||||
Type: clause.InnerJoin,
|
||||
Table: clause.Table{Name: "user"},
|
||||
Using: []string{"id"},
|
||||
},
|
||||
sql: "INNER JOIN `user` USING (`id`)",
|
||||
},
|
||||
{
|
||||
name: "Expression",
|
||||
join: clause.Join{
|
||||
// Invalid
|
||||
Type: clause.LeftJoin,
|
||||
Table: clause.Table{Name: "user"},
|
||||
ON: clause.Where{
|
||||
Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}},
|
||||
},
|
||||
// Valid
|
||||
Expression: clause.Join{
|
||||
Type: clause.InnerJoin,
|
||||
Table: clause.Table{Name: "user"},
|
||||
Using: []string{"id"},
|
||||
},
|
||||
},
|
||||
sql: "INNER JOIN `user` USING (`id`)",
|
||||
},
|
||||
}
|
||||
for _, result := range results {
|
||||
t.Run(result.name, func(t *testing.T) {
|
||||
user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
|
||||
stmt := &gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}}
|
||||
result.join.Build(stmt)
|
||||
if result.sql != stmt.SQL.String() {
|
||||
t.Errorf("want: %s, got: %s", result.sql, stmt.SQL.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@ -1,10 +1,8 @@
|
||||
package clause
|
||||
|
||||
import "strconv"
|
||||
|
||||
// Limit limit clause
|
||||
type Limit struct {
|
||||
Limit int
|
||||
Limit *int
|
||||
Offset int
|
||||
}
|
||||
|
||||
@ -15,16 +13,16 @@ func (limit Limit) Name() string {
|
||||
|
||||
// Build build where clause
|
||||
func (limit Limit) Build(builder Builder) {
|
||||
if limit.Limit > 0 {
|
||||
if limit.Limit != nil && *limit.Limit >= 0 {
|
||||
builder.WriteString("LIMIT ")
|
||||
builder.WriteString(strconv.Itoa(limit.Limit))
|
||||
builder.AddVar(builder, *limit.Limit)
|
||||
}
|
||||
if limit.Offset > 0 {
|
||||
if limit.Limit > 0 {
|
||||
if limit.Limit != nil && *limit.Limit >= 0 {
|
||||
builder.WriteByte(' ')
|
||||
}
|
||||
builder.WriteString("OFFSET ")
|
||||
builder.WriteString(strconv.Itoa(limit.Offset))
|
||||
builder.AddVar(builder, limit.Offset)
|
||||
}
|
||||
}
|
||||
|
||||
@ -33,7 +31,7 @@ func (limit Limit) MergeClause(clause *Clause) {
|
||||
clause.Name = ""
|
||||
|
||||
if v, ok := clause.Expression.(Limit); ok {
|
||||
if limit.Limit == 0 && v.Limit != 0 {
|
||||
if (limit.Limit == nil || *limit.Limit == 0) && v.Limit != nil {
|
||||
limit.Limit = v.Limit
|
||||
}
|
||||
|
||||
|
@ -8,6 +8,10 @@ import (
|
||||
)
|
||||
|
||||
func TestLimit(t *testing.T) {
|
||||
limit0 := 0
|
||||
limit10 := 10
|
||||
limit50 := 50
|
||||
limitNeg10 := -10
|
||||
results := []struct {
|
||||
Clauses []clause.Interface
|
||||
Result string
|
||||
@ -15,38 +19,56 @@ func TestLimit(t *testing.T) {
|
||||
}{
|
||||
{
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{
|
||||
Limit: 10,
|
||||
Limit: &limit10,
|
||||
Offset: 20,
|
||||
}},
|
||||
"SELECT * FROM `users` LIMIT 10 OFFSET 20", nil,
|
||||
"SELECT * FROM `users` LIMIT ? OFFSET ?",
|
||||
[]interface{}{limit10, 20},
|
||||
},
|
||||
{
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit0}},
|
||||
"SELECT * FROM `users` LIMIT ?",
|
||||
[]interface{}{limit0},
|
||||
},
|
||||
{
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit0}, clause.Limit{Offset: 0}},
|
||||
"SELECT * FROM `users` LIMIT ?",
|
||||
[]interface{}{limit0},
|
||||
},
|
||||
{
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}},
|
||||
"SELECT * FROM `users` OFFSET 20", nil,
|
||||
"SELECT * FROM `users` OFFSET ?",
|
||||
[]interface{}{20},
|
||||
},
|
||||
{
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}, clause.Limit{Offset: 30}},
|
||||
"SELECT * FROM `users` OFFSET 30", nil,
|
||||
"SELECT * FROM `users` OFFSET ?",
|
||||
[]interface{}{30},
|
||||
},
|
||||
{
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}, clause.Limit{Limit: 10}},
|
||||
"SELECT * FROM `users` LIMIT 10 OFFSET 20", nil,
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}, clause.Limit{Limit: &limit10}},
|
||||
"SELECT * FROM `users` LIMIT ? OFFSET ?",
|
||||
[]interface{}{limit10, 20},
|
||||
},
|
||||
{
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}},
|
||||
"SELECT * FROM `users` LIMIT 10 OFFSET 30", nil,
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}},
|
||||
"SELECT * FROM `users` LIMIT ? OFFSET ?",
|
||||
[]interface{}{limit10, 30},
|
||||
},
|
||||
{
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Offset: -10}},
|
||||
"SELECT * FROM `users` LIMIT 10", nil,
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Offset: -10}},
|
||||
"SELECT * FROM `users` LIMIT ?",
|
||||
[]interface{}{limit10},
|
||||
},
|
||||
{
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: -10}},
|
||||
"SELECT * FROM `users` OFFSET 30", nil,
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: &limitNeg10}},
|
||||
"SELECT * FROM `users` OFFSET ?",
|
||||
[]interface{}{30},
|
||||
},
|
||||
{
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: 50}},
|
||||
"SELECT * FROM `users` LIMIT 50 OFFSET 30", nil,
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: &limit50}},
|
||||
"SELECT * FROM `users` LIMIT ? OFFSET ?",
|
||||
[]interface{}{limit50, 30},
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -1,5 +1,12 @@
|
||||
package clause
|
||||
|
||||
const (
|
||||
LockingStrengthUpdate = "UPDATE"
|
||||
LockingStrengthShare = "SHARE"
|
||||
LockingOptionsSkipLocked = "SKIP LOCKED"
|
||||
LockingOptionsNoWait = "NOWAIT"
|
||||
)
|
||||
|
||||
type Locking struct {
|
||||
Strength string
|
||||
Table Table
|
||||
|
@ -14,17 +14,21 @@ func TestLocking(t *testing.T) {
|
||||
Vars []interface{}
|
||||
}{
|
||||
{
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: "UPDATE"}},
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: clause.LockingStrengthUpdate}},
|
||||
"SELECT * FROM `users` FOR UPDATE", nil,
|
||||
},
|
||||
{
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: "SHARE", Table: clause.Table{Name: clause.CurrentTable}}},
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: clause.LockingStrengthShare, Table: clause.Table{Name: clause.CurrentTable}}},
|
||||
"SELECT * FROM `users` FOR SHARE OF `users`", nil,
|
||||
},
|
||||
{
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: "UPDATE"}, clause.Locking{Strength: "UPDATE", Options: "NOWAIT"}},
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: clause.LockingStrengthUpdate, Options: clause.LockingOptionsNoWait}},
|
||||
"SELECT * FROM `users` FOR UPDATE NOWAIT", nil,
|
||||
},
|
||||
{
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: clause.LockingStrengthUpdate, Options: clause.LockingOptionsSkipLocked}},
|
||||
"SELECT * FROM `users` FOR UPDATE SKIP LOCKED", nil,
|
||||
},
|
||||
}
|
||||
|
||||
for idx, result := range results {
|
||||
|
@ -16,27 +16,27 @@ func (OnConflict) Name() string {
|
||||
|
||||
// Build build onConflict clause
|
||||
func (onConflict OnConflict) Build(builder Builder) {
|
||||
if len(onConflict.Columns) > 0 {
|
||||
builder.WriteByte('(')
|
||||
for idx, column := range onConflict.Columns {
|
||||
if idx > 0 {
|
||||
builder.WriteByte(',')
|
||||
}
|
||||
builder.WriteQuoted(column)
|
||||
}
|
||||
builder.WriteString(`) `)
|
||||
}
|
||||
|
||||
if len(onConflict.TargetWhere.Exprs) > 0 {
|
||||
builder.WriteString(" WHERE ")
|
||||
onConflict.TargetWhere.Build(builder)
|
||||
builder.WriteByte(' ')
|
||||
}
|
||||
|
||||
if onConflict.OnConstraint != "" {
|
||||
builder.WriteString("ON CONSTRAINT ")
|
||||
builder.WriteString(onConflict.OnConstraint)
|
||||
builder.WriteByte(' ')
|
||||
} else {
|
||||
if len(onConflict.Columns) > 0 {
|
||||
builder.WriteByte('(')
|
||||
for idx, column := range onConflict.Columns {
|
||||
if idx > 0 {
|
||||
builder.WriteByte(',')
|
||||
}
|
||||
builder.WriteQuoted(column)
|
||||
}
|
||||
builder.WriteString(`) `)
|
||||
}
|
||||
|
||||
if len(onConflict.TargetWhere.Exprs) > 0 {
|
||||
builder.WriteString(" WHERE ")
|
||||
onConflict.TargetWhere.Build(builder)
|
||||
builder.WriteByte(' ')
|
||||
}
|
||||
}
|
||||
|
||||
if onConflict.DoNothing {
|
||||
|
@ -49,16 +49,18 @@ func TestSelect(t *testing.T) {
|
||||
Exprs: []clause.Expression{
|
||||
clause.Expr{
|
||||
SQL: "? as name",
|
||||
Vars: []interface{}{clause.Eq{
|
||||
Column: clause.Column{Name: "age"},
|
||||
Value: 18,
|
||||
},
|
||||
Vars: []interface{}{
|
||||
clause.Eq{
|
||||
Column: clause.Column{Name: "age"},
|
||||
Value: 18,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}, clause.From{}},
|
||||
"SELECT `age` = ? as name FROM `users`", []interface{}{18},
|
||||
"SELECT `age` = ? as name FROM `users`",
|
||||
[]interface{}{18},
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -21,6 +21,12 @@ func (where Where) Name() string {
|
||||
|
||||
// Build build where clause
|
||||
func (where Where) Build(builder Builder) {
|
||||
if len(where.Exprs) == 1 {
|
||||
if andCondition, ok := where.Exprs[0].(AndConditions); ok {
|
||||
where.Exprs = andCondition.Exprs
|
||||
}
|
||||
}
|
||||
|
||||
// Switch position if the first query expression is a single Or condition
|
||||
for idx, expr := range where.Exprs {
|
||||
if v, ok := expr.(OrConditions); !ok || len(v.Exprs) > 1 {
|
||||
@ -147,6 +153,11 @@ func Not(exprs ...Expression) Expression {
|
||||
if len(exprs) == 0 {
|
||||
return nil
|
||||
}
|
||||
if len(exprs) == 1 {
|
||||
if andCondition, ok := exprs[0].(AndConditions); ok {
|
||||
exprs = andCondition.Exprs
|
||||
}
|
||||
}
|
||||
return NotConditions{Exprs: exprs}
|
||||
}
|
||||
|
||||
@ -155,19 +166,58 @@ type NotConditions struct {
|
||||
}
|
||||
|
||||
func (not NotConditions) Build(builder Builder) {
|
||||
if len(not.Exprs) > 1 {
|
||||
builder.WriteByte('(')
|
||||
anyNegationBuilder := false
|
||||
for _, c := range not.Exprs {
|
||||
if _, ok := c.(NegationExpressionBuilder); ok {
|
||||
anyNegationBuilder = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
for idx, c := range not.Exprs {
|
||||
if idx > 0 {
|
||||
builder.WriteString(AndWithSpace)
|
||||
if anyNegationBuilder {
|
||||
if len(not.Exprs) > 1 {
|
||||
builder.WriteByte('(')
|
||||
}
|
||||
|
||||
if negationBuilder, ok := c.(NegationExpressionBuilder); ok {
|
||||
negationBuilder.NegationBuild(builder)
|
||||
} else {
|
||||
builder.WriteString("NOT ")
|
||||
for idx, c := range not.Exprs {
|
||||
if idx > 0 {
|
||||
builder.WriteString(AndWithSpace)
|
||||
}
|
||||
|
||||
if negationBuilder, ok := c.(NegationExpressionBuilder); ok {
|
||||
negationBuilder.NegationBuild(builder)
|
||||
} else {
|
||||
builder.WriteString("NOT ")
|
||||
e, wrapInParentheses := c.(Expr)
|
||||
if wrapInParentheses {
|
||||
sql := strings.ToUpper(e.SQL)
|
||||
if wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace); wrapInParentheses {
|
||||
builder.WriteByte('(')
|
||||
}
|
||||
}
|
||||
|
||||
c.Build(builder)
|
||||
|
||||
if wrapInParentheses {
|
||||
builder.WriteByte(')')
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(not.Exprs) > 1 {
|
||||
builder.WriteByte(')')
|
||||
}
|
||||
} else {
|
||||
builder.WriteString("NOT ")
|
||||
if len(not.Exprs) > 1 {
|
||||
builder.WriteByte('(')
|
||||
}
|
||||
|
||||
for idx, c := range not.Exprs {
|
||||
if idx > 0 {
|
||||
builder.WriteString(AndWithSpace)
|
||||
}
|
||||
|
||||
e, wrapInParentheses := c.(Expr)
|
||||
if wrapInParentheses {
|
||||
sql := strings.ToUpper(e.SQL)
|
||||
@ -182,9 +232,9 @@ func (not NotConditions) Build(builder Builder) {
|
||||
builder.WriteByte(')')
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(not.Exprs) > 1 {
|
||||
builder.WriteByte(')')
|
||||
if len(not.Exprs) > 1 {
|
||||
builder.WriteByte(')')
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -63,7 +63,7 @@ func TestWhere(t *testing.T) {
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
|
||||
Exprs: []clause.Expression{clause.And(clause.Eq{Column: "age", Value: 18}, clause.Or(clause.Neq{Column: "name", Value: "jinzhu"}))},
|
||||
}},
|
||||
"SELECT * FROM `users` WHERE (`age` = ? OR `name` <> ?)",
|
||||
"SELECT * FROM `users` WHERE `age` = ? OR `name` <> ?",
|
||||
[]interface{}{18, "jinzhu"},
|
||||
},
|
||||
{
|
||||
@ -94,7 +94,7 @@ func TestWhere(t *testing.T) {
|
||||
clause.And(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false})),
|
||||
},
|
||||
}},
|
||||
"SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `score` <= ?)",
|
||||
"SELECT * FROM `users` WHERE `users`.`id` <> ? AND `score` <= ?",
|
||||
[]interface{}{"1", 100},
|
||||
},
|
||||
{
|
||||
@ -105,6 +105,14 @@ func TestWhere(t *testing.T) {
|
||||
"SELECT * FROM `users` WHERE (`users`.`id` <> ? AND NOT `score` <= ?)",
|
||||
[]interface{}{"1", 100},
|
||||
},
|
||||
{
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
|
||||
Exprs: []clause.Expression{clause.Not(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}},
|
||||
clause.Expr{SQL: "`age` <= ?", Vars: []interface{}{60}})},
|
||||
}},
|
||||
"SELECT * FROM `users` WHERE NOT (`score` <= ? AND `age` <= ?)",
|
||||
[]interface{}{100, 60},
|
||||
},
|
||||
}
|
||||
|
||||
for idx, result := range results {
|
||||
|
@ -21,6 +21,10 @@ var (
|
||||
ErrPrimaryKeyRequired = errors.New("primary key required")
|
||||
// ErrModelValueRequired model value required
|
||||
ErrModelValueRequired = errors.New("model value required")
|
||||
// ErrModelAccessibleFieldsRequired model accessible fields required
|
||||
ErrModelAccessibleFieldsRequired = errors.New("model accessible fields required")
|
||||
// ErrSubQueryRequired sub query required
|
||||
ErrSubQueryRequired = errors.New("sub query required")
|
||||
// ErrInvalidData unsupported data
|
||||
ErrInvalidData = errors.New("unsupported data")
|
||||
// ErrUnsupportedDriver unsupported driver
|
||||
@ -41,4 +45,8 @@ var (
|
||||
ErrInvalidValueOfLength = errors.New("invalid association values, length doesn't match")
|
||||
// ErrPreloadNotAllowed preload is not allowed when count is used
|
||||
ErrPreloadNotAllowed = errors.New("preload is not allowed when count is used")
|
||||
// ErrDuplicatedKey occurs when there is a unique key constraint violation
|
||||
ErrDuplicatedKey = errors.New("duplicated key not allowed")
|
||||
// ErrForeignKeyViolated occurs when there is a foreign key constraint violation
|
||||
ErrForeignKeyViolated = errors.New("violates foreign key constraint")
|
||||
)
|
||||
|
222
finisher_api.go
222
finisher_api.go
@ -13,7 +13,7 @@ import (
|
||||
"gorm.io/gorm/utils"
|
||||
)
|
||||
|
||||
// Create insert the value into database
|
||||
// Create inserts value, returning the inserted data's primary key in value's id
|
||||
func (db *DB) Create(value interface{}) (tx *DB) {
|
||||
if db.CreateBatchSize > 0 {
|
||||
return db.CreateInBatches(value, db.CreateBatchSize)
|
||||
@ -24,7 +24,7 @@ func (db *DB) Create(value interface{}) (tx *DB) {
|
||||
return tx.callbacks.Create().Execute(tx)
|
||||
}
|
||||
|
||||
// CreateInBatches insert the value in batches into database
|
||||
// CreateInBatches inserts value in batches of batchSize
|
||||
func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) {
|
||||
reflectValue := reflect.Indirect(reflect.ValueOf(value))
|
||||
|
||||
@ -33,9 +33,10 @@ func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) {
|
||||
var rowsAffected int64
|
||||
tx = db.getInstance()
|
||||
|
||||
// the reflection length judgment of the optimized value
|
||||
reflectLen := reflectValue.Len()
|
||||
|
||||
callFc := func(tx *DB) error {
|
||||
// the reflection length judgment of the optimized value
|
||||
reflectLen := reflectValue.Len()
|
||||
for i := 0; i < reflectLen; i += batchSize {
|
||||
ends := i + batchSize
|
||||
if ends > reflectLen {
|
||||
@ -53,7 +54,7 @@ func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if tx.SkipDefaultTransaction {
|
||||
if tx.SkipDefaultTransaction || reflectLen <= batchSize {
|
||||
tx.AddError(callFc(tx.Session(&Session{})))
|
||||
} else {
|
||||
tx.AddError(tx.Transaction(callFc))
|
||||
@ -68,7 +69,7 @@ func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) {
|
||||
return
|
||||
}
|
||||
|
||||
// Save update value in database, if the value doesn't have primary key, will insert it
|
||||
// Save updates value in database. If value doesn't contain a matching primary key, value is inserted.
|
||||
func (db *DB) Save(value interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.Dest = value
|
||||
@ -101,20 +102,19 @@ func (db *DB) Save(value interface{}) (tx *DB) {
|
||||
tx.Statement.Selects = append(tx.Statement.Selects, "*")
|
||||
}
|
||||
|
||||
tx = tx.callbacks.Update().Execute(tx)
|
||||
updateTx := tx.callbacks.Update().Execute(tx.Session(&Session{Initialized: true}))
|
||||
|
||||
if tx.Error == nil && tx.RowsAffected == 0 && !tx.DryRun && !selectedUpdate {
|
||||
result := reflect.New(tx.Statement.Schema.ModelType).Interface()
|
||||
if result := tx.Session(&Session{}).Limit(1).Find(result); result.RowsAffected == 0 {
|
||||
return tx.Create(value)
|
||||
}
|
||||
if updateTx.Error == nil && updateTx.RowsAffected == 0 && !updateTx.DryRun && !selectedUpdate {
|
||||
return tx.Session(&Session{SkipHooks: true}).Clauses(clause.OnConflict{UpdateAll: true}).Create(value)
|
||||
}
|
||||
|
||||
return updateTx
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// First find first record that match given conditions, order by primary key
|
||||
// First finds the first record ordered by primary key, matching given conditions conds
|
||||
func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) {
|
||||
tx = db.Limit(1).Order(clause.OrderByColumn{
|
||||
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
|
||||
@ -129,7 +129,7 @@ func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) {
|
||||
return tx.callbacks.Query().Execute(tx)
|
||||
}
|
||||
|
||||
// Take return a record that match given conditions, the order will depend on the database implementation
|
||||
// Take finds the first record returned by the database in no specified order, matching given conditions conds
|
||||
func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) {
|
||||
tx = db.Limit(1)
|
||||
if len(conds) > 0 {
|
||||
@ -142,7 +142,7 @@ func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) {
|
||||
return tx.callbacks.Query().Execute(tx)
|
||||
}
|
||||
|
||||
// Last find last record that match given conditions, order by primary key
|
||||
// Last finds the last record ordered by primary key, matching given conditions conds
|
||||
func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) {
|
||||
tx = db.Limit(1).Order(clause.OrderByColumn{
|
||||
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
|
||||
@ -158,7 +158,7 @@ func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) {
|
||||
return tx.callbacks.Query().Execute(tx)
|
||||
}
|
||||
|
||||
// Find find records that match given conditions
|
||||
// Find finds all records matching given conditions conds
|
||||
func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
if len(conds) > 0 {
|
||||
@ -170,7 +170,7 @@ func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) {
|
||||
return tx.callbacks.Query().Execute(tx)
|
||||
}
|
||||
|
||||
// FindInBatches find records in batches
|
||||
// FindInBatches finds all records in batches of batchSize
|
||||
func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, batch int) error) *DB {
|
||||
var (
|
||||
tx = db.Order(clause.OrderByColumn{
|
||||
@ -185,7 +185,9 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat
|
||||
var totalSize int
|
||||
if c, ok := tx.Statement.Clauses["LIMIT"]; ok {
|
||||
if limit, ok := c.Expression.(clause.Limit); ok {
|
||||
totalSize = limit.Limit
|
||||
if limit.Limit != nil {
|
||||
totalSize = *limit.Limit
|
||||
}
|
||||
|
||||
if totalSize > 0 && batchSize > totalSize {
|
||||
batchSize = totalSize
|
||||
@ -202,7 +204,9 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat
|
||||
batch++
|
||||
|
||||
if result.Error == nil && result.RowsAffected != 0 {
|
||||
tx.AddError(fc(result, batch))
|
||||
fcTx := result.Session(&Session{NewDB: true})
|
||||
fcTx.RowsAffected = result.RowsAffected
|
||||
tx.AddError(fc(fcTx, batch))
|
||||
} else if result.Error != nil {
|
||||
tx.AddError(result.Error)
|
||||
}
|
||||
@ -227,7 +231,11 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat
|
||||
break
|
||||
}
|
||||
|
||||
primaryValue, _ := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1))
|
||||
primaryValue, zero := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1))
|
||||
if zero {
|
||||
tx.AddError(ErrPrimaryKeyRequired)
|
||||
break
|
||||
}
|
||||
queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue})
|
||||
}
|
||||
|
||||
@ -284,7 +292,18 @@ func (db *DB) assignInterfacesToValue(values ...interface{}) {
|
||||
}
|
||||
}
|
||||
|
||||
// FirstOrInit gets the first matched record or initialize a new instance with given conditions (only works with struct or map conditions)
|
||||
// FirstOrInit finds the first matching record, otherwise if not found initializes a new instance with given conds.
|
||||
// Each conds must be a struct or map.
|
||||
//
|
||||
// FirstOrInit never modifies the database. It is often used with Assign and Attrs.
|
||||
//
|
||||
// // assign an email if the record is not found
|
||||
// db.Where(User{Name: "non_existing"}).Attrs(User{Email: "fake@fake.org"}).FirstOrInit(&user)
|
||||
// // user -> User{Name: "non_existing", Email: "fake@fake.org"}
|
||||
//
|
||||
// // assign email regardless of if record is found
|
||||
// db.Where(User{Name: "jinzhu"}).Assign(User{Email: "fake@fake.org"}).FirstOrInit(&user)
|
||||
// // user -> User{Name: "jinzhu", Age: 20, Email: "fake@fake.org"}
|
||||
func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) {
|
||||
queryTx := db.Limit(1).Order(clause.OrderByColumn{
|
||||
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
|
||||
@ -310,62 +329,82 @@ func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) {
|
||||
return
|
||||
}
|
||||
|
||||
// FirstOrCreate gets the first matched record or create a new one with given conditions (only works with struct, map conditions)
|
||||
// FirstOrCreate finds the first matching record, otherwise if not found creates a new instance with given conds.
|
||||
// Each conds must be a struct or map.
|
||||
//
|
||||
// Using FirstOrCreate in conjunction with Assign will result in an update to the database even if the record exists.
|
||||
//
|
||||
// // assign an email if the record is not found
|
||||
// result := db.Where(User{Name: "non_existing"}).Attrs(User{Email: "fake@fake.org"}).FirstOrCreate(&user)
|
||||
// // user -> User{Name: "non_existing", Email: "fake@fake.org"}
|
||||
// // result.RowsAffected -> 1
|
||||
//
|
||||
// // assign email regardless of if record is found
|
||||
// result := db.Where(User{Name: "jinzhu"}).Assign(User{Email: "fake@fake.org"}).FirstOrCreate(&user)
|
||||
// // user -> User{Name: "jinzhu", Age: 20, Email: "fake@fake.org"}
|
||||
// // result.RowsAffected -> 1
|
||||
func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
queryTx := db.Session(&Session{}).Limit(1).Order(clause.OrderByColumn{
|
||||
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
|
||||
})
|
||||
if result := queryTx.Find(dest, conds...); result.Error == nil {
|
||||
if result.RowsAffected == 0 {
|
||||
if c, ok := result.Statement.Clauses["WHERE"]; ok {
|
||||
if where, ok := c.Expression.(clause.Where); ok {
|
||||
result.assignInterfacesToValue(where.Exprs)
|
||||
}
|
||||
}
|
||||
|
||||
// initialize with attrs, conds
|
||||
if len(db.Statement.attrs) > 0 {
|
||||
result.assignInterfacesToValue(db.Statement.attrs...)
|
||||
}
|
||||
|
||||
// initialize with attrs, conds
|
||||
if len(db.Statement.assigns) > 0 {
|
||||
result.assignInterfacesToValue(db.Statement.assigns...)
|
||||
}
|
||||
|
||||
return tx.Create(dest)
|
||||
} else if len(db.Statement.assigns) > 0 {
|
||||
exprs := tx.Statement.BuildCondition(db.Statement.assigns[0], db.Statement.assigns[1:]...)
|
||||
assigns := map[string]interface{}{}
|
||||
for _, expr := range exprs {
|
||||
if eq, ok := expr.(clause.Eq); ok {
|
||||
switch column := eq.Column.(type) {
|
||||
case string:
|
||||
assigns[column] = eq.Value
|
||||
case clause.Column:
|
||||
assigns[column.Name] = eq.Value
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return tx.Model(dest).Updates(assigns)
|
||||
} else {
|
||||
tx.Error = result.Error
|
||||
}
|
||||
result := queryTx.Find(dest, conds...)
|
||||
if result.Error != nil {
|
||||
tx.Error = result.Error
|
||||
return tx
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
if c, ok := result.Statement.Clauses["WHERE"]; ok {
|
||||
if where, ok := c.Expression.(clause.Where); ok {
|
||||
result.assignInterfacesToValue(where.Exprs)
|
||||
}
|
||||
}
|
||||
|
||||
// initialize with attrs, conds
|
||||
if len(db.Statement.attrs) > 0 {
|
||||
result.assignInterfacesToValue(db.Statement.attrs...)
|
||||
}
|
||||
|
||||
// initialize with attrs, conds
|
||||
if len(db.Statement.assigns) > 0 {
|
||||
result.assignInterfacesToValue(db.Statement.assigns...)
|
||||
}
|
||||
|
||||
return tx.Create(dest)
|
||||
} else if len(db.Statement.assigns) > 0 {
|
||||
exprs := tx.Statement.BuildCondition(db.Statement.assigns[0], db.Statement.assigns[1:]...)
|
||||
assigns := map[string]interface{}{}
|
||||
for i := 0; i < len(exprs); i++ {
|
||||
expr := exprs[i]
|
||||
|
||||
if eq, ok := expr.(clause.AndConditions); ok {
|
||||
exprs = append(exprs, eq.Exprs...)
|
||||
} else if eq, ok := expr.(clause.Eq); ok {
|
||||
switch column := eq.Column.(type) {
|
||||
case string:
|
||||
assigns[column] = eq.Value
|
||||
case clause.Column:
|
||||
assigns[column.Name] = eq.Value
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return tx.Model(dest).Updates(assigns)
|
||||
}
|
||||
|
||||
return tx
|
||||
}
|
||||
|
||||
// Update update attributes with callbacks, refer: https://gorm.io/docs/update.html#Update-Changed-Fields
|
||||
// Update updates column with value using callbacks. Reference: https://gorm.io/docs/update.html#Update-Changed-Fields
|
||||
func (db *DB) Update(column string, value interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.Dest = map[string]interface{}{column: value}
|
||||
return tx.callbacks.Update().Execute(tx)
|
||||
}
|
||||
|
||||
// Updates update attributes with callbacks, refer: https://gorm.io/docs/update.html#Update-Changed-Fields
|
||||
// Updates updates attributes using callbacks. values must be a struct or map. Reference: https://gorm.io/docs/update.html#Update-Changed-Fields
|
||||
func (db *DB) Updates(values interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.Dest = values
|
||||
@ -386,7 +425,9 @@ func (db *DB) UpdateColumns(values interface{}) (tx *DB) {
|
||||
return tx.callbacks.Update().Execute(tx)
|
||||
}
|
||||
|
||||
// Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition
|
||||
// Delete deletes value matching given conditions. If value contains primary key it is included in the conditions. If
|
||||
// value includes a deleted_at field, then Delete performs a soft delete instead by setting deleted_at with the current
|
||||
// time if null.
|
||||
func (db *DB) Delete(value interface{}, conds ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
if len(conds) > 0 {
|
||||
@ -480,7 +521,7 @@ func (db *DB) Rows() (*sql.Rows, error) {
|
||||
return rows, tx.Error
|
||||
}
|
||||
|
||||
// Scan scan value to a struct
|
||||
// Scan scans selected value to the struct dest
|
||||
func (db *DB) Scan(dest interface{}) (tx *DB) {
|
||||
config := *db.Config
|
||||
currentLogger, newLogger := config.Logger, logger.Recorder.New()
|
||||
@ -494,6 +535,7 @@ func (db *DB) Scan(dest interface{}) (tx *DB) {
|
||||
tx.ScanRows(rows, dest)
|
||||
} else {
|
||||
tx.RowsAffected = 0
|
||||
tx.AddError(rows.Err())
|
||||
}
|
||||
tx.AddError(rows.Close())
|
||||
}
|
||||
@ -505,9 +547,10 @@ func (db *DB) Scan(dest interface{}) (tx *DB) {
|
||||
return
|
||||
}
|
||||
|
||||
// Pluck used to query single column from a model as a map
|
||||
// var ages []int64
|
||||
// db.Model(&users).Pluck("age", &ages)
|
||||
// Pluck queries a single column from a model, returning in the slice dest. E.g.:
|
||||
//
|
||||
// var ages []int64
|
||||
// db.Model(&users).Pluck("age", &ages)
|
||||
func (db *DB) Pluck(column string, dest interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
if tx.Statement.Model != nil {
|
||||
@ -548,7 +591,8 @@ func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error {
|
||||
return tx.Error
|
||||
}
|
||||
|
||||
// Connection use a db conn to execute Multiple commands,this conn will put conn pool after it is executed.
|
||||
// Connection uses a db connection to execute an arbitrary number of commands in fc. When finished, the connection is
|
||||
// returned to the connection pool.
|
||||
func (db *DB) Connection(fc func(tx *DB) error) (err error) {
|
||||
if db.Error != nil {
|
||||
return db.Error
|
||||
@ -570,7 +614,9 @@ func (db *DB) Connection(fc func(tx *DB) error) (err error) {
|
||||
return fc(tx)
|
||||
}
|
||||
|
||||
// Transaction start a transaction as a block, return error will rollback, otherwise to commit.
|
||||
// Transaction start a transaction as a block, return error will rollback, otherwise to commit. Transaction executes an
|
||||
// arbitrary number of commands in fc within a transaction. On success the changes are committed; if an error occurs
|
||||
// they are rolled back.
|
||||
func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err error) {
|
||||
panicked := true
|
||||
|
||||
@ -581,7 +627,6 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
// Make sure to rollback when panic, Block error or Commit error
|
||||
if panicked || err != nil {
|
||||
@ -589,8 +634,7 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
err = fc(db.Session(&Session{}))
|
||||
err = fc(db.Session(&Session{NewDB: db.clone == 1}))
|
||||
} else {
|
||||
tx := db.Begin(opts...)
|
||||
if tx.Error != nil {
|
||||
@ -614,7 +658,7 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er
|
||||
return
|
||||
}
|
||||
|
||||
// Begin begins a transaction
|
||||
// Begin begins a transaction with any transaction options opts
|
||||
func (db *DB) Begin(opts ...*sql.TxOptions) *DB {
|
||||
var (
|
||||
// clone statement
|
||||
@ -643,7 +687,7 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB {
|
||||
return tx
|
||||
}
|
||||
|
||||
// Commit commit a transaction
|
||||
// Commit commits the changes in a transaction
|
||||
func (db *DB) Commit() *DB {
|
||||
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil && !reflect.ValueOf(committer).IsNil() {
|
||||
db.AddError(committer.Commit())
|
||||
@ -653,7 +697,7 @@ func (db *DB) Commit() *DB {
|
||||
return db
|
||||
}
|
||||
|
||||
// Rollback rollback a transaction
|
||||
// Rollback rollbacks the changes in a transaction
|
||||
func (db *DB) Rollback() *DB {
|
||||
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
|
||||
if !reflect.ValueOf(committer).IsNil() {
|
||||
@ -667,7 +711,21 @@ func (db *DB) Rollback() *DB {
|
||||
|
||||
func (db *DB) SavePoint(name string) *DB {
|
||||
if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok {
|
||||
// close prepared statement, because SavePoint not support prepared statement.
|
||||
// e.g. mysql8.0 doc: https://dev.mysql.com/doc/refman/8.0/en/sql-prepared-statements.html
|
||||
var (
|
||||
preparedStmtTx *PreparedStmtTX
|
||||
isPreparedStmtTx bool
|
||||
)
|
||||
// close prepared statement, because SavePoint not support prepared statement.
|
||||
if preparedStmtTx, isPreparedStmtTx = db.Statement.ConnPool.(*PreparedStmtTX); isPreparedStmtTx {
|
||||
db.Statement.ConnPool = preparedStmtTx.Tx
|
||||
}
|
||||
db.AddError(savePointer.SavePoint(db, name))
|
||||
// restore prepared statement
|
||||
if isPreparedStmtTx {
|
||||
db.Statement.ConnPool = preparedStmtTx
|
||||
}
|
||||
} else {
|
||||
db.AddError(ErrUnsupportedDriver)
|
||||
}
|
||||
@ -676,14 +734,28 @@ func (db *DB) SavePoint(name string) *DB {
|
||||
|
||||
func (db *DB) RollbackTo(name string) *DB {
|
||||
if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok {
|
||||
// close prepared statement, because RollbackTo not support prepared statement.
|
||||
// e.g. mysql8.0 doc: https://dev.mysql.com/doc/refman/8.0/en/sql-prepared-statements.html
|
||||
var (
|
||||
preparedStmtTx *PreparedStmtTX
|
||||
isPreparedStmtTx bool
|
||||
)
|
||||
// close prepared statement, because SavePoint not support prepared statement.
|
||||
if preparedStmtTx, isPreparedStmtTx = db.Statement.ConnPool.(*PreparedStmtTX); isPreparedStmtTx {
|
||||
db.Statement.ConnPool = preparedStmtTx.Tx
|
||||
}
|
||||
db.AddError(savePointer.RollbackTo(db, name))
|
||||
// restore prepared statement
|
||||
if isPreparedStmtTx {
|
||||
db.Statement.ConnPool = preparedStmtTx
|
||||
}
|
||||
} else {
|
||||
db.AddError(ErrUnsupportedDriver)
|
||||
}
|
||||
return db
|
||||
}
|
||||
|
||||
// Exec execute raw sql
|
||||
// Exec executes raw sql
|
||||
func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.SQL = strings.Builder{}
|
||||
|
4
go.mod
4
go.mod
@ -1,8 +1,8 @@
|
||||
module gorm.io/gorm
|
||||
|
||||
go 1.14
|
||||
go 1.18
|
||||
|
||||
require (
|
||||
github.com/jinzhu/inflection v1.0.0
|
||||
github.com/jinzhu/now v1.1.4
|
||||
github.com/jinzhu/now v1.1.5
|
||||
)
|
||||
|
4
go.sum
4
go.sum
@ -1,4 +1,4 @@
|
||||
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
||||
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
||||
github.com/jinzhu/now v1.1.4 h1:tHnRBy1i5F2Dh8BAFxqFzxKqqvezXrL2OW1TnX+Mlas=
|
||||
github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
||||
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
|
||||
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
||||
|
101
gorm.go
101
gorm.go
@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
@ -37,6 +38,8 @@ type Config struct {
|
||||
DisableAutomaticPing bool
|
||||
// DisableForeignKeyConstraintWhenMigrating
|
||||
DisableForeignKeyConstraintWhenMigrating bool
|
||||
// IgnoreRelationshipsWhenMigrating
|
||||
IgnoreRelationshipsWhenMigrating bool
|
||||
// DisableNestedTransaction disable nested transaction
|
||||
DisableNestedTransaction bool
|
||||
// AllowGlobalUpdate allow global update
|
||||
@ -45,6 +48,8 @@ type Config struct {
|
||||
QueryFields bool
|
||||
// CreateBatchSize default create batch size
|
||||
CreateBatchSize int
|
||||
// TranslateError enabling error translation
|
||||
TranslateError bool
|
||||
|
||||
// ClauseBuilders clause builder
|
||||
ClauseBuilders map[string]clause.ClauseBuilder
|
||||
@ -142,7 +147,7 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) {
|
||||
}
|
||||
|
||||
if config.NamingStrategy == nil {
|
||||
config.NamingStrategy = schema.NamingStrategy{}
|
||||
config.NamingStrategy = schema.NamingStrategy{IdentifierMaxLength: 64} // Default Identifier length is 64
|
||||
}
|
||||
|
||||
if config.Logger == nil {
|
||||
@ -175,17 +180,17 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) {
|
||||
|
||||
if config.Dialector != nil {
|
||||
err = config.Dialector.Initialize(db)
|
||||
}
|
||||
|
||||
preparedStmt := &PreparedStmtDB{
|
||||
ConnPool: db.ConnPool,
|
||||
Stmts: map[string]Stmt{},
|
||||
Mux: &sync.RWMutex{},
|
||||
PreparedSQL: make([]string, 0, 100),
|
||||
if err != nil {
|
||||
if db, _ := db.DB(); db != nil {
|
||||
_ = db.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
db.cacheStore.Store(preparedStmtDBKey, preparedStmt)
|
||||
|
||||
if config.PrepareStmt {
|
||||
preparedStmt := NewPreparedStmtDB(db.ConnPool)
|
||||
db.cacheStore.Store(preparedStmtDBKey, preparedStmt)
|
||||
db.ConnPool = preparedStmt
|
||||
}
|
||||
|
||||
@ -246,16 +251,30 @@ func (db *DB) Session(config *Session) *DB {
|
||||
}
|
||||
|
||||
if config.PrepareStmt {
|
||||
var preparedStmt *PreparedStmtDB
|
||||
|
||||
if v, ok := db.cacheStore.Load(preparedStmtDBKey); ok {
|
||||
preparedStmt := v.(*PreparedStmtDB)
|
||||
preparedStmt = v.(*PreparedStmtDB)
|
||||
} else {
|
||||
preparedStmt = NewPreparedStmtDB(db.ConnPool)
|
||||
db.cacheStore.Store(preparedStmtDBKey, preparedStmt)
|
||||
}
|
||||
|
||||
switch t := tx.Statement.ConnPool.(type) {
|
||||
case Tx:
|
||||
tx.Statement.ConnPool = &PreparedStmtTX{
|
||||
Tx: t,
|
||||
PreparedStmtDB: preparedStmt,
|
||||
}
|
||||
default:
|
||||
tx.Statement.ConnPool = &PreparedStmtDB{
|
||||
ConnPool: db.Config.ConnPool,
|
||||
Mux: preparedStmt.Mux,
|
||||
Stmts: preparedStmt.Stmts,
|
||||
}
|
||||
txConfig.ConnPool = tx.Statement.ConnPool
|
||||
txConfig.PrepareStmt = true
|
||||
}
|
||||
txConfig.ConnPool = tx.Statement.ConnPool
|
||||
txConfig.PrepareStmt = true
|
||||
}
|
||||
|
||||
if config.SkipHooks {
|
||||
@ -300,7 +319,8 @@ func (db *DB) WithContext(ctx context.Context) *DB {
|
||||
|
||||
// Debug start debug mode
|
||||
func (db *DB) Debug() (tx *DB) {
|
||||
return db.Session(&Session{
|
||||
tx = db.getInstance()
|
||||
return tx.Session(&Session{
|
||||
Logger: db.Logger.LogMode(logger.Info),
|
||||
})
|
||||
}
|
||||
@ -336,10 +356,18 @@ func (db *DB) Callback() *callbacks {
|
||||
|
||||
// AddError add error to db
|
||||
func (db *DB) AddError(err error) error {
|
||||
if db.Error == nil {
|
||||
db.Error = err
|
||||
} else if err != nil {
|
||||
db.Error = fmt.Errorf("%v; %w", db.Error, err)
|
||||
if err != nil {
|
||||
if db.Config.TranslateError {
|
||||
if errTranslator, ok := db.Dialector.(ErrorTranslator); ok {
|
||||
err = errTranslator.Translate(err)
|
||||
}
|
||||
}
|
||||
|
||||
if db.Error == nil {
|
||||
db.Error = err
|
||||
} else {
|
||||
db.Error = fmt.Errorf("%v; %w", db.Error, err)
|
||||
}
|
||||
}
|
||||
return db.Error
|
||||
}
|
||||
@ -347,12 +375,20 @@ func (db *DB) AddError(err error) error {
|
||||
// DB returns `*sql.DB`
|
||||
func (db *DB) DB() (*sql.DB, error) {
|
||||
connPool := db.ConnPool
|
||||
|
||||
if dbConnector, ok := connPool.(GetDBConnector); ok && dbConnector != nil {
|
||||
return dbConnector.GetDBConn()
|
||||
if db.Statement != nil && db.Statement.ConnPool != nil {
|
||||
connPool = db.Statement.ConnPool
|
||||
}
|
||||
if tx, ok := connPool.(*sql.Tx); ok && tx != nil {
|
||||
return (*sql.DB)(reflect.ValueOf(tx).Elem().FieldByName("db").UnsafePointer()), nil
|
||||
}
|
||||
|
||||
if sqldb, ok := connPool.(*sql.DB); ok {
|
||||
if dbConnector, ok := connPool.(GetDBConnector); ok && dbConnector != nil {
|
||||
if sqldb, err := dbConnector.GetDBConn(); sqldb != nil || err != nil {
|
||||
return sqldb, err
|
||||
}
|
||||
}
|
||||
|
||||
if sqldb, ok := connPool.(*sql.DB); ok && sqldb != nil {
|
||||
return sqldb, nil
|
||||
}
|
||||
|
||||
@ -366,11 +402,12 @@ func (db *DB) getInstance() *DB {
|
||||
if db.clone == 1 {
|
||||
// clone with new statement
|
||||
tx.Statement = &Statement{
|
||||
DB: tx,
|
||||
ConnPool: db.Statement.ConnPool,
|
||||
Context: db.Statement.Context,
|
||||
Clauses: map[string]clause.Clause{},
|
||||
Vars: make([]interface{}, 0, 8),
|
||||
DB: tx,
|
||||
ConnPool: db.Statement.ConnPool,
|
||||
Context: db.Statement.Context,
|
||||
Clauses: map[string]clause.Clause{},
|
||||
Vars: make([]interface{}, 0, 8),
|
||||
SkipHooks: db.Statement.SkipHooks,
|
||||
}
|
||||
} else {
|
||||
// with clone statement
|
||||
@ -412,7 +449,7 @@ func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interfac
|
||||
relation, ok := modelSchema.Relationships.Relations[field]
|
||||
isRelation := ok && relation.JoinTable != nil
|
||||
if !isRelation {
|
||||
return fmt.Errorf("failed to found relation: %s", field)
|
||||
return fmt.Errorf("failed to find relation: %s", field)
|
||||
}
|
||||
|
||||
for _, ref := range relation.References {
|
||||
@ -455,12 +492,12 @@ func (db *DB) Use(plugin Plugin) error {
|
||||
|
||||
// ToSQL for generate SQL string.
|
||||
//
|
||||
// db.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
||||
// return tx.Model(&User{}).Where(&User{Name: "foo", Age: 20})
|
||||
// .Limit(10).Offset(5)
|
||||
// .Order("name ASC")
|
||||
// .First(&User{})
|
||||
// })
|
||||
// db.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
||||
// return tx.Model(&User{}).Where(&User{Name: "foo", Age: 20})
|
||||
// .Limit(10).Offset(5)
|
||||
// .Order("name ASC")
|
||||
// .First(&User{})
|
||||
// })
|
||||
func (db *DB) ToSQL(queryFn func(tx *DB) *DB) string {
|
||||
tx := queryFn(db.Session(&Session{DryRun: true, SkipDefaultTransaction: true}))
|
||||
stmt := tx.Statement
|
||||
|
@ -26,6 +26,10 @@ type Plugin interface {
|
||||
Initialize(*DB) error
|
||||
}
|
||||
|
||||
type ParamsFilter interface {
|
||||
ParamsFilter(ctx context.Context, sql string, params ...interface{}) (string, []interface{})
|
||||
}
|
||||
|
||||
// ConnPool db conns pool interface
|
||||
type ConnPool interface {
|
||||
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
|
||||
@ -82,3 +86,7 @@ type Rows interface {
|
||||
Err() error
|
||||
Close() error
|
||||
}
|
||||
|
||||
type ErrorTranslator interface {
|
||||
Translate(err error) error
|
||||
}
|
||||
|
@ -4,7 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
"time"
|
||||
@ -55,6 +55,7 @@ type Config struct {
|
||||
SlowThreshold time.Duration
|
||||
Colorful bool
|
||||
IgnoreRecordNotFoundError bool
|
||||
ParameterizedQueries bool
|
||||
LogLevel LogLevel
|
||||
}
|
||||
|
||||
@ -68,8 +69,8 @@ type Interface interface {
|
||||
}
|
||||
|
||||
var (
|
||||
// Discard Discard logger will print any log to ioutil.Discard
|
||||
Discard = New(log.New(ioutil.Discard, "", log.LstdFlags), Config{})
|
||||
// Discard logger will print any log to io.Discard
|
||||
Discard = New(log.New(io.Discard, "", log.LstdFlags), Config{})
|
||||
// Default Default logger
|
||||
Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{
|
||||
SlowThreshold: 200 * time.Millisecond,
|
||||
@ -77,7 +78,7 @@ var (
|
||||
IgnoreRecordNotFoundError: false,
|
||||
Colorful: true,
|
||||
})
|
||||
// Recorder Recorder logger records running SQL into a recorder instance
|
||||
// Recorder logger records running SQL into a recorder instance
|
||||
Recorder = traceRecorder{Interface: Default, BeginAt: time.Now()}
|
||||
)
|
||||
|
||||
@ -128,28 +129,30 @@ func (l *logger) LogMode(level LogLevel) Interface {
|
||||
}
|
||||
|
||||
// Info print info
|
||||
func (l logger) Info(ctx context.Context, msg string, data ...interface{}) {
|
||||
func (l *logger) Info(ctx context.Context, msg string, data ...interface{}) {
|
||||
if l.LogLevel >= Info {
|
||||
l.Printf(l.infoStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...)
|
||||
}
|
||||
}
|
||||
|
||||
// Warn print warn messages
|
||||
func (l logger) Warn(ctx context.Context, msg string, data ...interface{}) {
|
||||
func (l *logger) Warn(ctx context.Context, msg string, data ...interface{}) {
|
||||
if l.LogLevel >= Warn {
|
||||
l.Printf(l.warnStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...)
|
||||
}
|
||||
}
|
||||
|
||||
// Error print error messages
|
||||
func (l logger) Error(ctx context.Context, msg string, data ...interface{}) {
|
||||
func (l *logger) Error(ctx context.Context, msg string, data ...interface{}) {
|
||||
if l.LogLevel >= Error {
|
||||
l.Printf(l.errStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...)
|
||||
}
|
||||
}
|
||||
|
||||
// Trace print sql message
|
||||
func (l logger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) {
|
||||
//
|
||||
//nolint:cyclop
|
||||
func (l *logger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) {
|
||||
if l.LogLevel <= Silent {
|
||||
return
|
||||
}
|
||||
@ -181,6 +184,14 @@ func (l logger) Trace(ctx context.Context, begin time.Time, fc func() (string, i
|
||||
}
|
||||
}
|
||||
|
||||
// ParamsFilter filter params
|
||||
func (l *logger) ParamsFilter(ctx context.Context, sql string, params ...interface{}) (string, []interface{}) {
|
||||
if l.Config.ParameterizedQueries {
|
||||
return sql, nil
|
||||
}
|
||||
return sql, params
|
||||
}
|
||||
|
||||
type traceRecorder struct {
|
||||
Interface
|
||||
BeginAt time.Time
|
||||
@ -189,8 +200,8 @@ type traceRecorder struct {
|
||||
Err error
|
||||
}
|
||||
|
||||
// New new trace recorder
|
||||
func (l traceRecorder) New() *traceRecorder {
|
||||
// New trace recorder
|
||||
func (l *traceRecorder) New() *traceRecorder {
|
||||
return &traceRecorder{Interface: l.Interface, BeginAt: time.Now()}
|
||||
}
|
||||
|
||||
|
@ -28,8 +28,25 @@ func isPrintable(s string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// A list of Go types that should be converted to SQL primitives
|
||||
var convertibleTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeOf(false), reflect.TypeOf([]byte{})}
|
||||
|
||||
// RegEx matches only numeric values
|
||||
var numericPlaceholderRe = regexp.MustCompile(`\$\d+\$`)
|
||||
|
||||
func isNumeric(k reflect.Kind) bool {
|
||||
switch k {
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
return true
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
return true
|
||||
case reflect.Float32, reflect.Float64:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// ExplainSQL generate SQL string with given parameters, the generated SQL is expected to be used in logger, execute it might introduce a SQL injection vulnerability
|
||||
func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, avars ...interface{}) string {
|
||||
var (
|
||||
@ -75,24 +92,26 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a
|
||||
case reflect.Bool:
|
||||
vars[idx] = fmt.Sprintf("%t", reflectValue.Interface())
|
||||
case reflect.String:
|
||||
vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, "\\"+escaper) + escaper
|
||||
vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, escaper+escaper) + escaper
|
||||
default:
|
||||
if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) {
|
||||
vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, "\\"+escaper) + escaper
|
||||
vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, escaper+escaper) + escaper
|
||||
} else {
|
||||
vars[idx] = nullStr
|
||||
}
|
||||
}
|
||||
case []byte:
|
||||
if s := string(v); isPrintable(s) {
|
||||
vars[idx] = escaper + strings.ReplaceAll(s, escaper, "\\"+escaper) + escaper
|
||||
vars[idx] = escaper + strings.ReplaceAll(s, escaper, escaper+escaper) + escaper
|
||||
} else {
|
||||
vars[idx] = escaper + "<binary>" + escaper
|
||||
}
|
||||
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
|
||||
vars[idx] = utils.ToString(v)
|
||||
case float64, float32:
|
||||
vars[idx] = fmt.Sprintf("%.6f", v)
|
||||
case float32:
|
||||
vars[idx] = strconv.FormatFloat(float64(v), 'f', -1, 32)
|
||||
case float64:
|
||||
vars[idx] = strconv.FormatFloat(v, 'f', -1, 64)
|
||||
case string:
|
||||
v = strings.ReplaceAll(v, "\\"+escaper, "")
|
||||
|
||||
@ -110,6 +129,12 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a
|
||||
convertParams(v, idx)
|
||||
} else if rv.Kind() == reflect.Ptr && !rv.IsZero() {
|
||||
convertParams(reflect.Indirect(rv).Interface(), idx)
|
||||
} else if isNumeric(rv.Kind()) {
|
||||
if rv.CanInt() || rv.CanUint() {
|
||||
vars[idx] = fmt.Sprintf("%d", rv.Interface())
|
||||
} else {
|
||||
vars[idx] = fmt.Sprintf("%.6f", rv.Interface())
|
||||
}
|
||||
} else {
|
||||
for _, t := range convertibleTypes {
|
||||
if rv.Type().ConvertibleTo(t) {
|
||||
@ -117,7 +142,7 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a
|
||||
return
|
||||
}
|
||||
}
|
||||
vars[idx] = escaper + strings.ReplaceAll(fmt.Sprint(v), escaper, "\\"+escaper) + escaper
|
||||
vars[idx] = escaper + strings.ReplaceAll(fmt.Sprint(v), escaper, escaper+escaper) + escaper
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -144,9 +169,18 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a
|
||||
sql = newSQL.String()
|
||||
} else {
|
||||
sql = numericPlaceholder.ReplaceAllString(sql, "$$$1$$")
|
||||
for idx, v := range vars {
|
||||
sql = strings.Replace(sql, "$"+strconv.Itoa(idx+1)+"$", v, 1)
|
||||
}
|
||||
|
||||
sql = numericPlaceholderRe.ReplaceAllStringFunc(sql, func(v string) string {
|
||||
num := v[1 : len(v)-1]
|
||||
n, _ := strconv.Atoi(num)
|
||||
|
||||
// position var start from 1 ($1, $2)
|
||||
n -= 1
|
||||
if n >= 0 && n <= len(vars)-1 {
|
||||
return vars[n]
|
||||
}
|
||||
return v
|
||||
})
|
||||
}
|
||||
|
||||
return sql
|
||||
|
@ -31,20 +31,24 @@ func (s ExampleStruct) Value() (driver.Value, error) {
|
||||
}
|
||||
|
||||
func format(v []byte, escaper string) string {
|
||||
return escaper + strings.ReplaceAll(string(v), escaper, "\\"+escaper) + escaper
|
||||
return escaper + strings.ReplaceAll(string(v), escaper, escaper+escaper) + escaper
|
||||
}
|
||||
|
||||
func TestExplainSQL(t *testing.T) {
|
||||
type role string
|
||||
type password []byte
|
||||
type intType int
|
||||
type floatType float64
|
||||
var (
|
||||
tt = now.MustParse("2020-02-23 11:10:10")
|
||||
myrole = role("admin")
|
||||
pwd = password([]byte("pass"))
|
||||
jsVal = []byte(`{"Name":"test","Val":"test"}`)
|
||||
js = JSON(jsVal)
|
||||
esVal = []byte(`{"Name":"test","Val":"test"}`)
|
||||
es = ExampleStruct{Name: "test", Val: "test"}
|
||||
tt = now.MustParse("2020-02-23 11:10:10")
|
||||
myrole = role("admin")
|
||||
pwd = password("pass")
|
||||
jsVal = []byte(`{"Name":"test","Val":"test"}`)
|
||||
js = JSON(jsVal)
|
||||
esVal = []byte(`{"Name":"test","Val":"test"}`)
|
||||
es = ExampleStruct{Name: "test", Val: "test"}
|
||||
intVal intType = 1
|
||||
floatVal floatType = 1.23
|
||||
)
|
||||
|
||||
results := []struct {
|
||||
@ -69,19 +73,19 @@ func TestExplainSQL(t *testing.T) {
|
||||
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (@p1, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9, @p10, @p11)",
|
||||
NumericRegexp: regexp.MustCompile(`@p(\d+)`),
|
||||
Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.com", myrole, pwd},
|
||||
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`,
|
||||
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`,
|
||||
},
|
||||
{
|
||||
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ($3, $4, $1, $2, $7, $8, $5, $6, $9, $10, $11)",
|
||||
NumericRegexp: regexp.MustCompile(`\$(\d+)`),
|
||||
Vars: []interface{}{999.99, true, "jinzhu", 1, &tt, nil, []byte("12345"), tt, "w@g.com", myrole, pwd},
|
||||
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`,
|
||||
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`,
|
||||
},
|
||||
{
|
||||
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (@p1, @p11, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9, @p10)",
|
||||
NumericRegexp: regexp.MustCompile(`@p(\d+)`),
|
||||
Vars: []interface{}{"jinzhu", 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.com", myrole, pwd, 1},
|
||||
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`,
|
||||
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`,
|
||||
},
|
||||
{
|
||||
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
@ -95,6 +99,30 @@ func TestExplainSQL(t *testing.T) {
|
||||
Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es},
|
||||
Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, E"w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
|
||||
},
|
||||
{
|
||||
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
NumericRegexp: nil,
|
||||
Vars: []interface{}{"jinzhu", 1, 0.1753607109, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es},
|
||||
Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 0.1753607109, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
|
||||
},
|
||||
{
|
||||
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
NumericRegexp: nil,
|
||||
Vars: []interface{}{"jinzhu", 1, float32(999.99), true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es},
|
||||
Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
|
||||
},
|
||||
{
|
||||
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, int_val) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
NumericRegexp: nil,
|
||||
Vars: []interface{}{"jinzhu?", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, intVal},
|
||||
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, int_val) values ("jinzhu?", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", 1)`,
|
||||
},
|
||||
{
|
||||
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, float_val) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
NumericRegexp: nil,
|
||||
Vars: []interface{}{"jinzhu?", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, floatVal},
|
||||
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, float_val) values ("jinzhu?", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", 1.230000)`,
|
||||
},
|
||||
}
|
||||
|
||||
for idx, r := range results {
|
||||
|
34
migrator.go
34
migrator.go
@ -13,11 +13,7 @@ func (db *DB) Migrator() Migrator {
|
||||
|
||||
// apply scopes to migrator
|
||||
for len(tx.Statement.scopes) > 0 {
|
||||
scopes := tx.Statement.scopes
|
||||
tx.Statement.scopes = nil
|
||||
for _, scope := range scopes {
|
||||
tx = scope(tx)
|
||||
}
|
||||
tx = tx.executeScopes()
|
||||
}
|
||||
|
||||
return tx.Dialector.Migrator(tx.Session(&Session{}))
|
||||
@ -30,9 +26,9 @@ func (db *DB) AutoMigrate(dst ...interface{}) error {
|
||||
|
||||
// ViewOption view option
|
||||
type ViewOption struct {
|
||||
Replace bool
|
||||
CheckOption string
|
||||
Query *DB
|
||||
Replace bool // If true, exec `CREATE`. If false, exec `CREATE OR REPLACE`
|
||||
CheckOption string // optional. e.g. `WITH [ CASCADED | LOCAL ] CHECK OPTION`
|
||||
Query *DB // required subquery.
|
||||
}
|
||||
|
||||
// ColumnType column type interface
|
||||
@ -51,6 +47,23 @@ type ColumnType interface {
|
||||
DefaultValue() (value string, ok bool)
|
||||
}
|
||||
|
||||
type Index interface {
|
||||
Table() string
|
||||
Name() string
|
||||
Columns() []string
|
||||
PrimaryKey() (isPrimaryKey bool, ok bool)
|
||||
Unique() (unique bool, ok bool)
|
||||
Option() string
|
||||
}
|
||||
|
||||
// TableType table type interface
|
||||
type TableType interface {
|
||||
Schema() string
|
||||
Name() string
|
||||
Type() string
|
||||
Comment() (comment string, ok bool)
|
||||
}
|
||||
|
||||
// Migrator migrator interface
|
||||
type Migrator interface {
|
||||
// AutoMigrate
|
||||
@ -59,6 +72,7 @@ type Migrator interface {
|
||||
// Database
|
||||
CurrentDatabase() string
|
||||
FullDataTypeOf(*schema.Field) clause.Expr
|
||||
GetTypeAliases(databaseTypeName string) []string
|
||||
|
||||
// Tables
|
||||
CreateTable(dst ...interface{}) error
|
||||
@ -66,12 +80,15 @@ type Migrator interface {
|
||||
HasTable(dst interface{}) bool
|
||||
RenameTable(oldName, newName interface{}) error
|
||||
GetTables() (tableList []string, err error)
|
||||
TableType(dst interface{}) (TableType, error)
|
||||
|
||||
// Columns
|
||||
AddColumn(dst interface{}, field string) error
|
||||
DropColumn(dst interface{}, field string) error
|
||||
AlterColumn(dst interface{}, field string) error
|
||||
MigrateColumn(dst interface{}, field *schema.Field, columnType ColumnType) error
|
||||
// MigrateColumnUnique migrate column's UNIQUE constraint, it's part of MigrateColumn.
|
||||
MigrateColumnUnique(dst interface{}, field *schema.Field, columnType ColumnType) error
|
||||
HasColumn(dst interface{}, field string) bool
|
||||
RenameColumn(dst interface{}, oldName, field string) error
|
||||
ColumnTypes(dst interface{}) ([]ColumnType, error)
|
||||
@ -90,4 +107,5 @@ type Migrator interface {
|
||||
DropIndex(dst interface{}, name string) error
|
||||
HasIndex(dst interface{}, name string) bool
|
||||
RenameIndex(dst interface{}, oldName, newName string) error
|
||||
GetIndexes(dst interface{}) ([]Index, error)
|
||||
}
|
||||
|
43
migrator/index.go
Normal file
43
migrator/index.go
Normal file
@ -0,0 +1,43 @@
|
||||
package migrator
|
||||
|
||||
import "database/sql"
|
||||
|
||||
// Index implements gorm.Index interface
|
||||
type Index struct {
|
||||
TableName string
|
||||
NameValue string
|
||||
ColumnList []string
|
||||
PrimaryKeyValue sql.NullBool
|
||||
UniqueValue sql.NullBool
|
||||
OptionValue string
|
||||
}
|
||||
|
||||
// Table return the table name of the index.
|
||||
func (idx Index) Table() string {
|
||||
return idx.TableName
|
||||
}
|
||||
|
||||
// Name return the name of the index.
|
||||
func (idx Index) Name() string {
|
||||
return idx.NameValue
|
||||
}
|
||||
|
||||
// Columns return the columns of the index
|
||||
func (idx Index) Columns() []string {
|
||||
return idx.ColumnList
|
||||
}
|
||||
|
||||
// PrimaryKey returns the index is primary key or not.
|
||||
func (idx Index) PrimaryKey() (isPrimaryKey bool, ok bool) {
|
||||
return idx.PrimaryKeyValue.Bool, idx.PrimaryKeyValue.Valid
|
||||
}
|
||||
|
||||
// Unique returns whether the index is unique or not.
|
||||
func (idx Index) Unique() (unique bool, ok bool) {
|
||||
return idx.UniqueValue.Bool, idx.UniqueValue.Valid
|
||||
}
|
||||
|
||||
// Option return the optional attribute of the index
|
||||
func (idx Index) Option() string {
|
||||
return idx.OptionValue
|
||||
}
|
@ -3,20 +3,32 @@ package migrator
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/logger"
|
||||
"gorm.io/gorm/schema"
|
||||
)
|
||||
|
||||
var (
|
||||
regRealDataType = regexp.MustCompile(`[^\d](\d+)[^\d]?`)
|
||||
regFullDataType = regexp.MustCompile(`[^\d]*(\d+)[^\d]?`)
|
||||
)
|
||||
// This regular expression seeks to find a sequence of digits (\d+) among zero or more non-digit characters (\D*),
|
||||
// with a possible trailing non-digit character (\D?).
|
||||
|
||||
// For example, values that can pass this regular expression are:
|
||||
// - "123"
|
||||
// - "abc456"
|
||||
// -"%$#@789"
|
||||
var regFullDataType = regexp.MustCompile(`\D*(\d+)\D?`)
|
||||
|
||||
// TODO:? Create const vars for raw sql queries ?
|
||||
|
||||
var _ gorm.Migrator = (*Migrator)(nil)
|
||||
|
||||
// Migrator m struct
|
||||
type Migrator struct {
|
||||
@ -30,6 +42,16 @@ type Config struct {
|
||||
gorm.Dialector
|
||||
}
|
||||
|
||||
type printSQLLogger struct {
|
||||
logger.Interface
|
||||
}
|
||||
|
||||
func (l *printSQLLogger) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) {
|
||||
sql, _ := fc()
|
||||
fmt.Println(sql + ";")
|
||||
l.Interface.Trace(ctx, begin, fc, err)
|
||||
}
|
||||
|
||||
// GormDataTypeInterface gorm data type interface
|
||||
type GormDataTypeInterface interface {
|
||||
GormDBDataType(*gorm.DB, *schema.Field) string
|
||||
@ -72,10 +94,6 @@ func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) {
|
||||
expr.SQL += " NOT NULL"
|
||||
}
|
||||
|
||||
if field.Unique {
|
||||
expr.SQL += " UNIQUE"
|
||||
}
|
||||
|
||||
if field.HasDefaultValue && (field.DefaultValueInterface != nil || field.DefaultValue != "") {
|
||||
if field.DefaultValueInterface != nil {
|
||||
defaultStmt := &gorm.Statement{Vars: []interface{}{field.DefaultValueInterface}}
|
||||
@ -89,23 +107,35 @@ func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) {
|
||||
return
|
||||
}
|
||||
|
||||
func (m Migrator) GetQueryAndExecTx() (queryTx, execTx *gorm.DB) {
|
||||
queryTx = m.DB.Session(&gorm.Session{})
|
||||
execTx = queryTx
|
||||
if m.DB.DryRun {
|
||||
queryTx.DryRun = false
|
||||
execTx = m.DB.Session(&gorm.Session{Logger: &printSQLLogger{Interface: m.DB.Logger}})
|
||||
}
|
||||
return queryTx, execTx
|
||||
}
|
||||
|
||||
// AutoMigrate auto migrate values
|
||||
func (m Migrator) AutoMigrate(values ...interface{}) error {
|
||||
for _, value := range m.ReorderModels(values, true) {
|
||||
tx := m.DB.Session(&gorm.Session{})
|
||||
if !tx.Migrator().HasTable(value) {
|
||||
if err := tx.Migrator().CreateTable(value); err != nil {
|
||||
queryTx, execTx := m.GetQueryAndExecTx()
|
||||
if !queryTx.Migrator().HasTable(value) {
|
||||
if err := execTx.Migrator().CreateTable(value); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) {
|
||||
columnTypes, err := m.DB.Migrator().ColumnTypes(value)
|
||||
if err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
columnTypes, err := queryTx.Migrator().ColumnTypes(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var (
|
||||
parseIndexes = stmt.Schema.ParseIndexes()
|
||||
parseCheckConstraints = stmt.Schema.ParseCheckConstraints()
|
||||
)
|
||||
for _, dbName := range stmt.Schema.DBNames {
|
||||
field := stmt.Schema.FieldsByDBName[dbName]
|
||||
var foundColumn gorm.ColumnType
|
||||
|
||||
for _, columnType := range columnTypes {
|
||||
@ -117,37 +147,43 @@ func (m Migrator) AutoMigrate(values ...interface{}) error {
|
||||
|
||||
if foundColumn == nil {
|
||||
// not found, add column
|
||||
if err := tx.Migrator().AddColumn(value, dbName); err != nil {
|
||||
if err = execTx.Migrator().AddColumn(value, dbName); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// found, smartly migrate
|
||||
field := stmt.Schema.FieldsByDBName[dbName]
|
||||
if err = execTx.Migrator().MigrateColumn(value, field, foundColumn); err != nil {
|
||||
return err
|
||||
}
|
||||
} else if err := m.DB.Migrator().MigrateColumn(value, field, foundColumn); err != nil {
|
||||
// found, smart migrate
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
for _, rel := range stmt.Schema.Relationships.Relations {
|
||||
if !m.DB.Config.DisableForeignKeyConstraintWhenMigrating {
|
||||
if !m.DB.DisableForeignKeyConstraintWhenMigrating && !m.DB.IgnoreRelationshipsWhenMigrating {
|
||||
for _, rel := range stmt.Schema.Relationships.Relations {
|
||||
if rel.Field.IgnoreMigration {
|
||||
continue
|
||||
}
|
||||
if constraint := rel.ParseConstraint(); constraint != nil &&
|
||||
constraint.Schema == stmt.Schema && !tx.Migrator().HasConstraint(value, constraint.Name) {
|
||||
if err := tx.Migrator().CreateConstraint(value, constraint.Name); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, chk := range stmt.Schema.ParseCheckConstraints() {
|
||||
if !tx.Migrator().HasConstraint(value, chk.Name) {
|
||||
if err := tx.Migrator().CreateConstraint(value, chk.Name); err != nil {
|
||||
constraint.Schema == stmt.Schema && !queryTx.Migrator().HasConstraint(value, constraint.Name) {
|
||||
if err := execTx.Migrator().CreateConstraint(value, constraint.Name); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, idx := range stmt.Schema.ParseIndexes() {
|
||||
if !tx.Migrator().HasIndex(value, idx.Name) {
|
||||
if err := tx.Migrator().CreateIndex(value, idx.Name); err != nil {
|
||||
for _, chk := range parseCheckConstraints {
|
||||
if !queryTx.Migrator().HasConstraint(value, chk.Name) {
|
||||
if err := execTx.Migrator().CreateConstraint(value, chk.Name); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, idx := range parseIndexes {
|
||||
if !queryTx.Migrator().HasIndex(value, idx.Name) {
|
||||
if err := execTx.Migrator().CreateIndex(value, idx.Name); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@ -174,7 +210,7 @@ func (m Migrator) GetTables() (tableList []string, err error) {
|
||||
func (m Migrator) CreateTable(values ...interface{}) error {
|
||||
for _, value := range m.ReorderModels(values, false) {
|
||||
tx := m.DB.Session(&gorm.Session{})
|
||||
if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) {
|
||||
if err := m.RunWithValue(value, func(stmt *gorm.Statement) (err error) {
|
||||
var (
|
||||
createTableSQL = "CREATE TABLE ? ("
|
||||
values = []interface{}{m.CurrentTable(stmt)}
|
||||
@ -185,7 +221,7 @@ func (m Migrator) CreateTable(values ...interface{}) error {
|
||||
field := stmt.Schema.FieldsByDBName[dbName]
|
||||
if !field.IgnoreMigration {
|
||||
createTableSQL += "? ?"
|
||||
hasPrimaryKeyInDataType = hasPrimaryKeyInDataType || strings.Contains(strings.ToUpper(string(field.DataType)), "PRIMARY KEY")
|
||||
hasPrimaryKeyInDataType = hasPrimaryKeyInDataType || strings.Contains(strings.ToUpper(m.DataTypeOf(field)), "PRIMARY KEY")
|
||||
values = append(values, clause.Column{Name: dbName}, m.DB.Migrator().FullDataTypeOf(field))
|
||||
createTableSQL += ","
|
||||
}
|
||||
@ -193,7 +229,7 @@ func (m Migrator) CreateTable(values ...interface{}) error {
|
||||
|
||||
if !hasPrimaryKeyInDataType && len(stmt.Schema.PrimaryFields) > 0 {
|
||||
createTableSQL += "PRIMARY KEY ?,"
|
||||
primaryKeys := []interface{}{}
|
||||
primaryKeys := make([]interface{}, 0, len(stmt.Schema.PrimaryFields))
|
||||
for _, field := range stmt.Schema.PrimaryFields {
|
||||
primaryKeys = append(primaryKeys, clause.Column{Name: field.DBName})
|
||||
}
|
||||
@ -204,8 +240,8 @@ func (m Migrator) CreateTable(values ...interface{}) error {
|
||||
for _, idx := range stmt.Schema.ParseIndexes() {
|
||||
if m.CreateIndexAfterCreateTable {
|
||||
defer func(value interface{}, name string) {
|
||||
if errr == nil {
|
||||
errr = tx.Migrator().CreateIndex(value, name)
|
||||
if err == nil {
|
||||
err = tx.Migrator().CreateIndex(value, name)
|
||||
}
|
||||
}(value, idx.Name)
|
||||
} else {
|
||||
@ -223,15 +259,18 @@ func (m Migrator) CreateTable(values ...interface{}) error {
|
||||
}
|
||||
|
||||
createTableSQL += ","
|
||||
values = append(values, clause.Expr{SQL: idx.Name}, tx.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt))
|
||||
values = append(values, clause.Column{Name: idx.Name}, tx.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt))
|
||||
}
|
||||
}
|
||||
|
||||
for _, rel := range stmt.Schema.Relationships.Relations {
|
||||
if !m.DB.DisableForeignKeyConstraintWhenMigrating {
|
||||
if !m.DB.DisableForeignKeyConstraintWhenMigrating && !m.DB.IgnoreRelationshipsWhenMigrating {
|
||||
for _, rel := range stmt.Schema.Relationships.Relations {
|
||||
if rel.Field.IgnoreMigration {
|
||||
continue
|
||||
}
|
||||
if constraint := rel.ParseConstraint(); constraint != nil {
|
||||
if constraint.Schema == stmt.Schema {
|
||||
sql, vars := buildConstraint(constraint)
|
||||
sql, vars := constraint.Build()
|
||||
createTableSQL += sql + ","
|
||||
values = append(values, vars...)
|
||||
}
|
||||
@ -239,6 +278,11 @@ func (m Migrator) CreateTable(values ...interface{}) error {
|
||||
}
|
||||
}
|
||||
|
||||
for _, uni := range stmt.Schema.ParseUniqueConstraints() {
|
||||
createTableSQL += "CONSTRAINT ? UNIQUE (?),"
|
||||
values = append(values, clause.Column{Name: uni.Name}, clause.Expr{SQL: stmt.Quote(uni.Field.DBName)})
|
||||
}
|
||||
|
||||
for _, chk := range stmt.Schema.ParseCheckConstraints() {
|
||||
createTableSQL += "CONSTRAINT ? CHECK (?),"
|
||||
values = append(values, clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint})
|
||||
@ -252,8 +296,8 @@ func (m Migrator) CreateTable(values ...interface{}) error {
|
||||
createTableSQL += fmt.Sprint(tableOption)
|
||||
}
|
||||
|
||||
errr = tx.Exec(createTableSQL, values...).Error
|
||||
return errr
|
||||
err = tx.Exec(createTableSQL, values...).Error
|
||||
return err
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
@ -402,32 +446,58 @@ func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error
|
||||
|
||||
// MigrateColumn migrate column
|
||||
func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnType gorm.ColumnType) error {
|
||||
if field.IgnoreMigration {
|
||||
return nil
|
||||
}
|
||||
|
||||
// found, smart migrate
|
||||
fullDataType := strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL)
|
||||
fullDataType := strings.TrimSpace(strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL))
|
||||
realDataType := strings.ToLower(columnType.DatabaseTypeName())
|
||||
|
||||
alterColumn := false
|
||||
var (
|
||||
alterColumn bool
|
||||
isSameType = fullDataType == realDataType
|
||||
)
|
||||
|
||||
// check size
|
||||
if length, ok := columnType.Length(); length != int64(field.Size) {
|
||||
if length > 0 && field.Size > 0 {
|
||||
alterColumn = true
|
||||
} else {
|
||||
// has size in data type and not equal
|
||||
// Since the following code is frequently called in the for loop, reg optimization is needed here
|
||||
matches := regRealDataType.FindAllStringSubmatch(realDataType, -1)
|
||||
matches2 := regFullDataType.FindAllStringSubmatch(fullDataType, -1)
|
||||
if (len(matches) == 1 && matches[0][1] != fmt.Sprint(field.Size) || !field.PrimaryKey) &&
|
||||
(len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length) && ok) {
|
||||
if !field.PrimaryKey {
|
||||
// check type
|
||||
if !strings.HasPrefix(fullDataType, realDataType) {
|
||||
// check type aliases
|
||||
aliases := m.DB.Migrator().GetTypeAliases(realDataType)
|
||||
for _, alias := range aliases {
|
||||
if strings.HasPrefix(fullDataType, alias) {
|
||||
isSameType = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !isSameType {
|
||||
alterColumn = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// check precision
|
||||
if precision, _, ok := columnType.DecimalSize(); ok && int64(field.Precision) != precision {
|
||||
if regexp.MustCompile(fmt.Sprintf("[^0-9]%d[^0-9]", field.Precision)).MatchString(m.DataTypeOf(field)) {
|
||||
alterColumn = true
|
||||
if !isSameType {
|
||||
// check size
|
||||
if length, ok := columnType.Length(); length != int64(field.Size) {
|
||||
if length > 0 && field.Size > 0 {
|
||||
alterColumn = true
|
||||
} else {
|
||||
// has size in data type and not equal
|
||||
// Since the following code is frequently called in the for loop, reg optimization is needed here
|
||||
matches2 := regFullDataType.FindAllStringSubmatch(fullDataType, -1)
|
||||
if !field.PrimaryKey &&
|
||||
(len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length) && ok) {
|
||||
alterColumn = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// check precision
|
||||
if precision, _, ok := columnType.DecimalSize(); ok && int64(field.Precision) != precision {
|
||||
if regexp.MustCompile(fmt.Sprintf("[^0-9]%d[^0-9]", field.Precision)).MatchString(m.DataTypeOf(field)) {
|
||||
alterColumn = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -439,19 +509,29 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
|
||||
}
|
||||
}
|
||||
|
||||
// check unique
|
||||
if unique, ok := columnType.Unique(); ok && unique != field.Unique {
|
||||
// not primary key
|
||||
if !field.PrimaryKey {
|
||||
alterColumn = true
|
||||
}
|
||||
}
|
||||
|
||||
// check default value
|
||||
if v, ok := columnType.DefaultValue(); ok && v != field.DefaultValue {
|
||||
// not primary key
|
||||
if !field.PrimaryKey {
|
||||
if !field.PrimaryKey {
|
||||
currentDefaultNotNull := field.HasDefaultValue && (field.DefaultValueInterface != nil || !strings.EqualFold(field.DefaultValue, "NULL"))
|
||||
dv, dvNotNull := columnType.DefaultValue()
|
||||
if dvNotNull && !currentDefaultNotNull {
|
||||
// default value -> null
|
||||
alterColumn = true
|
||||
} else if !dvNotNull && currentDefaultNotNull {
|
||||
// null -> default value
|
||||
alterColumn = true
|
||||
} else if currentDefaultNotNull || dvNotNull {
|
||||
switch field.GORMDataType {
|
||||
case schema.Time:
|
||||
if !strings.EqualFold(strings.TrimSuffix(dv, "()"), strings.TrimSuffix(field.DefaultValue, "()")) {
|
||||
alterColumn = true
|
||||
}
|
||||
case schema.Bool:
|
||||
v1, _ := strconv.ParseBool(dv)
|
||||
v2, _ := strconv.ParseBool(field.DefaultValue)
|
||||
alterColumn = v1 != v2
|
||||
default:
|
||||
alterColumn = dv != field.DefaultValue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -463,13 +543,39 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
|
||||
}
|
||||
}
|
||||
|
||||
if alterColumn && !field.IgnoreMigration {
|
||||
return m.DB.Migrator().AlterColumn(value, field.Name)
|
||||
if alterColumn {
|
||||
if err := m.DB.Migrator().AlterColumn(value, field.DBName); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if err := m.DB.Migrator().MigrateColumnUnique(value, field, columnType); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m Migrator) MigrateColumnUnique(value interface{}, field *schema.Field, columnType gorm.ColumnType) error {
|
||||
unique, ok := columnType.Unique()
|
||||
if !ok || field.PrimaryKey {
|
||||
return nil // skip primary key
|
||||
}
|
||||
// By default, ColumnType's Unique is not affected by UniqueIndex, so we don't care about UniqueIndex.
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
// We're currently only receiving boolean values on `Unique` tag,
|
||||
// so the UniqueConstraint name is fixed
|
||||
constraint := m.DB.NamingStrategy.UniqueName(stmt.Table, field.DBName)
|
||||
if unique && !field.Unique {
|
||||
return m.DB.Migrator().DropConstraint(value, constraint)
|
||||
}
|
||||
if !unique && field.Unique {
|
||||
return m.DB.Migrator().CreateConstraint(value, constraint)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// ColumnTypes return columnTypes []gorm.ColumnType and execErr error
|
||||
func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
|
||||
columnTypes := make([]gorm.ColumnType, 0)
|
||||
@ -499,47 +605,76 @@ func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
|
||||
return columnTypes, execErr
|
||||
}
|
||||
|
||||
// CreateView create view
|
||||
// CreateView create view from Query in gorm.ViewOption.
|
||||
// Query in gorm.ViewOption is a [subquery]
|
||||
//
|
||||
// // CREATE VIEW `user_view` AS SELECT * FROM `users` WHERE age > 20
|
||||
// q := DB.Model(&User{}).Where("age > ?", 20)
|
||||
// DB.Debug().Migrator().CreateView("user_view", gorm.ViewOption{Query: q})
|
||||
//
|
||||
// // CREATE OR REPLACE VIEW `users_view` AS SELECT * FROM `users` WITH CHECK OPTION
|
||||
// q := DB.Model(&User{})
|
||||
// DB.Debug().Migrator().CreateView("user_view", gorm.ViewOption{Query: q, Replace: true, CheckOption: "WITH CHECK OPTION"})
|
||||
//
|
||||
// [subquery]: https://gorm.io/docs/advanced_query.html#SubQuery
|
||||
func (m Migrator) CreateView(name string, option gorm.ViewOption) error {
|
||||
return gorm.ErrNotImplemented
|
||||
if option.Query == nil {
|
||||
return gorm.ErrSubQueryRequired
|
||||
}
|
||||
|
||||
sql := new(strings.Builder)
|
||||
sql.WriteString("CREATE ")
|
||||
if option.Replace {
|
||||
sql.WriteString("OR REPLACE ")
|
||||
}
|
||||
sql.WriteString("VIEW ")
|
||||
m.QuoteTo(sql, name)
|
||||
sql.WriteString(" AS ")
|
||||
|
||||
m.DB.Statement.AddVar(sql, option.Query)
|
||||
|
||||
if option.CheckOption != "" {
|
||||
sql.WriteString(" ")
|
||||
sql.WriteString(option.CheckOption)
|
||||
}
|
||||
return m.DB.Exec(m.Explain(sql.String(), m.DB.Statement.Vars...)).Error
|
||||
}
|
||||
|
||||
// DropView drop view
|
||||
func (m Migrator) DropView(name string) error {
|
||||
return gorm.ErrNotImplemented
|
||||
}
|
||||
|
||||
func buildConstraint(constraint *schema.Constraint) (sql string, results []interface{}) {
|
||||
sql = "CONSTRAINT ? FOREIGN KEY ? REFERENCES ??"
|
||||
if constraint.OnDelete != "" {
|
||||
sql += " ON DELETE " + constraint.OnDelete
|
||||
}
|
||||
|
||||
if constraint.OnUpdate != "" {
|
||||
sql += " ON UPDATE " + constraint.OnUpdate
|
||||
}
|
||||
|
||||
var foreignKeys, references []interface{}
|
||||
for _, field := range constraint.ForeignKeys {
|
||||
foreignKeys = append(foreignKeys, clause.Column{Name: field.DBName})
|
||||
}
|
||||
|
||||
for _, field := range constraint.References {
|
||||
references = append(references, clause.Column{Name: field.DBName})
|
||||
}
|
||||
results = append(results, clause.Table{Name: constraint.Name}, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references)
|
||||
return
|
||||
return m.DB.Exec("DROP VIEW IF EXISTS ?", clause.Table{Name: name}).Error
|
||||
}
|
||||
|
||||
// GuessConstraintAndTable guess statement's constraint and it's table based on name
|
||||
func (m Migrator) GuessConstraintAndTable(stmt *gorm.Statement, name string) (_ *schema.Constraint, _ *schema.Check, table string) {
|
||||
//
|
||||
// Deprecated: use GuessConstraintInterfaceAndTable instead.
|
||||
func (m Migrator) GuessConstraintAndTable(stmt *gorm.Statement, name string) (*schema.Constraint, *schema.CheckConstraint, string) {
|
||||
constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name)
|
||||
switch c := constraint.(type) {
|
||||
case *schema.Constraint:
|
||||
return c, nil, table
|
||||
case *schema.CheckConstraint:
|
||||
return nil, c, table
|
||||
default:
|
||||
return nil, nil, table
|
||||
}
|
||||
}
|
||||
|
||||
// GuessConstraintInterfaceAndTable guess statement's constraint and it's table based on name
|
||||
// nolint:cyclop
|
||||
func (m Migrator) GuessConstraintInterfaceAndTable(stmt *gorm.Statement, name string) (_ schema.ConstraintInterface, table string) {
|
||||
if stmt.Schema == nil {
|
||||
return nil, nil, stmt.Table
|
||||
return nil, stmt.Table
|
||||
}
|
||||
|
||||
checkConstraints := stmt.Schema.ParseCheckConstraints()
|
||||
if chk, ok := checkConstraints[name]; ok {
|
||||
return nil, &chk, stmt.Table
|
||||
return &chk, stmt.Table
|
||||
}
|
||||
|
||||
uniqueConstraints := stmt.Schema.ParseUniqueConstraints()
|
||||
if uni, ok := uniqueConstraints[name]; ok {
|
||||
return &uni, stmt.Table
|
||||
}
|
||||
|
||||
getTable := func(rel *schema.Relationship) string {
|
||||
@ -554,7 +689,7 @@ func (m Migrator) GuessConstraintAndTable(stmt *gorm.Statement, name string) (_
|
||||
|
||||
for _, rel := range stmt.Schema.Relationships.Relations {
|
||||
if constraint := rel.ParseConstraint(); constraint != nil && constraint.Name == name {
|
||||
return constraint, nil, getTable(rel)
|
||||
return constraint, getTable(rel)
|
||||
}
|
||||
}
|
||||
|
||||
@ -562,40 +697,39 @@ func (m Migrator) GuessConstraintAndTable(stmt *gorm.Statement, name string) (_
|
||||
for k := range checkConstraints {
|
||||
if checkConstraints[k].Field == field {
|
||||
v := checkConstraints[k]
|
||||
return nil, &v, stmt.Table
|
||||
return &v, stmt.Table
|
||||
}
|
||||
}
|
||||
|
||||
for k := range uniqueConstraints {
|
||||
if uniqueConstraints[k].Field == field {
|
||||
v := uniqueConstraints[k]
|
||||
return &v, stmt.Table
|
||||
}
|
||||
}
|
||||
|
||||
for _, rel := range stmt.Schema.Relationships.Relations {
|
||||
if constraint := rel.ParseConstraint(); constraint != nil && rel.Field == field {
|
||||
return constraint, nil, getTable(rel)
|
||||
return constraint, getTable(rel)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil, nil, stmt.Schema.Table
|
||||
return nil, stmt.Schema.Table
|
||||
}
|
||||
|
||||
// CreateConstraint create constraint
|
||||
func (m Migrator) CreateConstraint(value interface{}, name string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
|
||||
if chk != nil {
|
||||
return m.DB.Exec(
|
||||
"ALTER TABLE ? ADD CONSTRAINT ? CHECK (?)",
|
||||
m.CurrentTable(stmt), clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint},
|
||||
).Error
|
||||
}
|
||||
|
||||
constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name)
|
||||
if constraint != nil {
|
||||
vars := []interface{}{clause.Table{Name: table}}
|
||||
if stmt.TableExpr != nil {
|
||||
vars[0] = stmt.TableExpr
|
||||
}
|
||||
sql, values := buildConstraint(constraint)
|
||||
sql, values := constraint.Build()
|
||||
return m.DB.Exec("ALTER TABLE ? ADD "+sql, append(vars, values...)...).Error
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
@ -603,11 +737,9 @@ func (m Migrator) CreateConstraint(value interface{}, name string) error {
|
||||
// DropConstraint drop constraint
|
||||
func (m Migrator) DropConstraint(value interface{}, name string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
|
||||
constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name)
|
||||
if constraint != nil {
|
||||
name = constraint.Name
|
||||
} else if chk != nil {
|
||||
name = chk.Name
|
||||
name = constraint.GetName()
|
||||
}
|
||||
return m.DB.Exec("ALTER TABLE ? DROP CONSTRAINT ?", clause.Table{Name: table}, clause.Column{Name: name}).Error
|
||||
})
|
||||
@ -618,11 +750,9 @@ func (m Migrator) HasConstraint(value interface{}, name string) bool {
|
||||
var count int64
|
||||
m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
currentDatabase := m.DB.Migrator().CurrentDatabase()
|
||||
constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
|
||||
constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name)
|
||||
if constraint != nil {
|
||||
name = constraint.Name
|
||||
} else if chk != nil {
|
||||
name = chk.Name
|
||||
name = constraint.GetName()
|
||||
}
|
||||
|
||||
return m.DB.Raw(
|
||||
@ -759,7 +889,8 @@ func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []i
|
||||
Statement: &gorm.Statement{DB: m.DB, Dest: value},
|
||||
}
|
||||
beDependedOn := map[*schema.Schema]bool{}
|
||||
if err := dep.Parse(value); err != nil {
|
||||
// support for special table name
|
||||
if err := dep.ParseWithSpecialTableName(value, m.DB.Statement.Table); err != nil {
|
||||
m.DB.Logger.Error(context.Background(), "failed to parse value %#v, got error %v", value, err)
|
||||
}
|
||||
if _, ok := parsedSchemas[dep.Statement.Schema]; ok {
|
||||
@ -767,26 +898,31 @@ func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []i
|
||||
}
|
||||
parsedSchemas[dep.Statement.Schema] = true
|
||||
|
||||
for _, rel := range dep.Schema.Relationships.Relations {
|
||||
if c := rel.ParseConstraint(); c != nil && c.Schema == dep.Statement.Schema && c.Schema != c.ReferenceSchema {
|
||||
dep.Depends = append(dep.Depends, c.ReferenceSchema)
|
||||
}
|
||||
if !m.DB.IgnoreRelationshipsWhenMigrating {
|
||||
for _, rel := range dep.Schema.Relationships.Relations {
|
||||
if rel.Field.IgnoreMigration {
|
||||
continue
|
||||
}
|
||||
if c := rel.ParseConstraint(); c != nil && c.Schema == dep.Statement.Schema && c.Schema != c.ReferenceSchema {
|
||||
dep.Depends = append(dep.Depends, c.ReferenceSchema)
|
||||
}
|
||||
|
||||
if rel.Type == schema.HasOne || rel.Type == schema.HasMany {
|
||||
beDependedOn[rel.FieldSchema] = true
|
||||
}
|
||||
if rel.Type == schema.HasOne || rel.Type == schema.HasMany {
|
||||
beDependedOn[rel.FieldSchema] = true
|
||||
}
|
||||
|
||||
if rel.JoinTable != nil {
|
||||
// append join value
|
||||
defer func(rel *schema.Relationship, joinValue interface{}) {
|
||||
if !beDependedOn[rel.FieldSchema] {
|
||||
dep.Depends = append(dep.Depends, rel.FieldSchema)
|
||||
} else {
|
||||
fieldValue := reflect.New(rel.FieldSchema.ModelType).Interface()
|
||||
parseDependence(fieldValue, autoAdd)
|
||||
}
|
||||
parseDependence(joinValue, autoAdd)
|
||||
}(rel, reflect.New(rel.JoinTable.ModelType).Interface())
|
||||
if rel.JoinTable != nil {
|
||||
// append join value
|
||||
defer func(rel *schema.Relationship, joinValue interface{}) {
|
||||
if !beDependedOn[rel.FieldSchema] {
|
||||
dep.Depends = append(dep.Depends, rel.FieldSchema)
|
||||
} else {
|
||||
fieldValue := reflect.New(rel.FieldSchema.ModelType).Interface()
|
||||
parseDependence(fieldValue, autoAdd)
|
||||
}
|
||||
parseDependence(joinValue, autoAdd)
|
||||
}(rel, reflect.New(rel.JoinTable.ModelType).Interface())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -843,3 +979,18 @@ func (m Migrator) CurrentTable(stmt *gorm.Statement) interface{} {
|
||||
}
|
||||
return clause.Table{Name: stmt.Table}
|
||||
}
|
||||
|
||||
// GetIndexes return Indexes []gorm.Index and execErr error
|
||||
func (m Migrator) GetIndexes(dst interface{}) ([]gorm.Index, error) {
|
||||
return nil, errors.New("not support")
|
||||
}
|
||||
|
||||
// GetTypeAliases return database type aliases
|
||||
func (m Migrator) GetTypeAliases(databaseTypeName string) []string {
|
||||
return nil
|
||||
}
|
||||
|
||||
// TableType return tableType gorm.TableType and execErr error
|
||||
func (m Migrator) TableType(dst interface{}) (gorm.TableType, error) {
|
||||
return nil, errors.New("not support")
|
||||
}
|
||||
|
33
migrator/table_type.go
Normal file
33
migrator/table_type.go
Normal file
@ -0,0 +1,33 @@
|
||||
package migrator
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
// TableType table type implements TableType interface
|
||||
type TableType struct {
|
||||
SchemaValue string
|
||||
NameValue string
|
||||
TypeValue string
|
||||
CommentValue sql.NullString
|
||||
}
|
||||
|
||||
// Schema returns the schema of the table.
|
||||
func (ct TableType) Schema() string {
|
||||
return ct.SchemaValue
|
||||
}
|
||||
|
||||
// Name returns the name of the table.
|
||||
func (ct TableType) Name() string {
|
||||
return ct.NameValue
|
||||
}
|
||||
|
||||
// Type returns the type of the table.
|
||||
func (ct TableType) Type() string {
|
||||
return ct.TypeValue
|
||||
}
|
||||
|
||||
// Comment returns the comment of current table.
|
||||
func (ct TableType) Comment() (comment string, ok bool) {
|
||||
return ct.CommentValue.String, ct.CommentValue.Valid
|
||||
}
|
7
model.go
7
model.go
@ -4,9 +4,10 @@ import "time"
|
||||
|
||||
// Model a basic GoLang struct which includes the following fields: ID, CreatedAt, UpdatedAt, DeletedAt
|
||||
// It may be embedded into your model or you may build your own model without it
|
||||
// type User struct {
|
||||
// gorm.Model
|
||||
// }
|
||||
//
|
||||
// type User struct {
|
||||
// gorm.Model
|
||||
// }
|
||||
type Model struct {
|
||||
ID uint `gorm:"primarykey"`
|
||||
CreatedAt time.Time
|
||||
|
128
prepare_stmt.go
128
prepare_stmt.go
@ -3,30 +3,44 @@ package gorm
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"reflect"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type Stmt struct {
|
||||
*sql.Stmt
|
||||
Transaction bool
|
||||
prepared chan struct{}
|
||||
prepareErr error
|
||||
}
|
||||
|
||||
type PreparedStmtDB struct {
|
||||
Stmts map[string]Stmt
|
||||
Stmts map[string]*Stmt
|
||||
PreparedSQL []string
|
||||
Mux *sync.RWMutex
|
||||
ConnPool
|
||||
}
|
||||
|
||||
func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) {
|
||||
if dbConnector, ok := db.ConnPool.(GetDBConnector); ok && dbConnector != nil {
|
||||
return dbConnector.GetDBConn()
|
||||
func NewPreparedStmtDB(connPool ConnPool) *PreparedStmtDB {
|
||||
return &PreparedStmtDB{
|
||||
ConnPool: connPool,
|
||||
Stmts: make(map[string]*Stmt),
|
||||
Mux: &sync.RWMutex{},
|
||||
PreparedSQL: make([]string, 0, 100),
|
||||
}
|
||||
}
|
||||
|
||||
func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) {
|
||||
if sqldb, ok := db.ConnPool.(*sql.DB); ok {
|
||||
return sqldb, nil
|
||||
}
|
||||
|
||||
if dbConnector, ok := db.ConnPool.(GetDBConnector); ok && dbConnector != nil {
|
||||
return dbConnector.GetDBConn()
|
||||
}
|
||||
|
||||
return nil, ErrInvalidDB
|
||||
}
|
||||
|
||||
@ -42,31 +56,72 @@ func (db *PreparedStmtDB) Close() {
|
||||
}
|
||||
}
|
||||
|
||||
func (sdb *PreparedStmtDB) Reset() {
|
||||
sdb.Mux.Lock()
|
||||
defer sdb.Mux.Unlock()
|
||||
|
||||
for _, stmt := range sdb.Stmts {
|
||||
go stmt.Close()
|
||||
}
|
||||
sdb.PreparedSQL = make([]string, 0, 100)
|
||||
sdb.Stmts = make(map[string]*Stmt)
|
||||
}
|
||||
|
||||
func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) {
|
||||
db.Mux.RLock()
|
||||
if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) {
|
||||
db.Mux.RUnlock()
|
||||
return stmt, nil
|
||||
// wait for other goroutines prepared
|
||||
<-stmt.prepared
|
||||
if stmt.prepareErr != nil {
|
||||
return Stmt{}, stmt.prepareErr
|
||||
}
|
||||
|
||||
return *stmt, nil
|
||||
}
|
||||
db.Mux.RUnlock()
|
||||
|
||||
db.Mux.Lock()
|
||||
defer db.Mux.Unlock()
|
||||
|
||||
// double check
|
||||
if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) {
|
||||
return stmt, nil
|
||||
} else if ok {
|
||||
go stmt.Close()
|
||||
db.Mux.Unlock()
|
||||
// wait for other goroutines prepared
|
||||
<-stmt.prepared
|
||||
if stmt.prepareErr != nil {
|
||||
return Stmt{}, stmt.prepareErr
|
||||
}
|
||||
|
||||
return *stmt, nil
|
||||
}
|
||||
|
||||
// cache preparing stmt first
|
||||
cacheStmt := Stmt{Transaction: isTransaction, prepared: make(chan struct{})}
|
||||
db.Stmts[query] = &cacheStmt
|
||||
db.Mux.Unlock()
|
||||
|
||||
// prepare completed
|
||||
defer close(cacheStmt.prepared)
|
||||
|
||||
// Reason why cannot lock conn.PrepareContext
|
||||
// suppose the maxopen is 1, g1 is creating record and g2 is querying record.
|
||||
// 1. g1 begin tx, g1 is requeue because of waiting for the system call, now `db.ConnPool` db.numOpen == 1.
|
||||
// 2. g2 select lock `conn.PrepareContext(ctx, query)`, now db.numOpen == db.maxOpen , wait for release.
|
||||
// 3. g1 tx exec insert, wait for unlock `conn.PrepareContext(ctx, query)` to finish tx and release.
|
||||
stmt, err := conn.PrepareContext(ctx, query)
|
||||
if err == nil {
|
||||
db.Stmts[query] = Stmt{Stmt: stmt, Transaction: isTransaction}
|
||||
db.PreparedSQL = append(db.PreparedSQL, query)
|
||||
if err != nil {
|
||||
cacheStmt.prepareErr = err
|
||||
db.Mux.Lock()
|
||||
delete(db.Stmts, query)
|
||||
db.Mux.Unlock()
|
||||
return Stmt{}, err
|
||||
}
|
||||
|
||||
return db.Stmts[query], err
|
||||
db.Mux.Lock()
|
||||
cacheStmt.Stmt = stmt
|
||||
db.PreparedSQL = append(db.PreparedSQL, query)
|
||||
db.Mux.Unlock()
|
||||
|
||||
return cacheStmt, nil
|
||||
}
|
||||
|
||||
func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (ConnPool, error) {
|
||||
@ -74,6 +129,19 @@ func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (Conn
|
||||
tx, err := beginner.BeginTx(ctx, opt)
|
||||
return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err
|
||||
}
|
||||
|
||||
beginner, ok := db.ConnPool.(ConnPoolBeginner)
|
||||
if !ok {
|
||||
return nil, ErrInvalidTransaction
|
||||
}
|
||||
|
||||
connPool, err := beginner.BeginTx(ctx, opt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if tx, ok := connPool.(Tx); ok {
|
||||
return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, nil
|
||||
}
|
||||
return nil, ErrInvalidTransaction
|
||||
}
|
||||
|
||||
@ -81,7 +149,7 @@ func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args ..
|
||||
stmt, err := db.prepare(ctx, db.ConnPool, false, query)
|
||||
if err == nil {
|
||||
result, err = stmt.ExecContext(ctx, args...)
|
||||
if err != nil {
|
||||
if errors.Is(err, driver.ErrBadConn) {
|
||||
db.Mux.Lock()
|
||||
defer db.Mux.Unlock()
|
||||
go stmt.Close()
|
||||
@ -95,7 +163,7 @@ func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args .
|
||||
stmt, err := db.prepare(ctx, db.ConnPool, false, query)
|
||||
if err == nil {
|
||||
rows, err = stmt.QueryContext(ctx, args...)
|
||||
if err != nil {
|
||||
if errors.Is(err, driver.ErrBadConn) {
|
||||
db.Mux.Lock()
|
||||
defer db.Mux.Unlock()
|
||||
|
||||
@ -114,20 +182,32 @@ func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string, arg
|
||||
return &sql.Row{}
|
||||
}
|
||||
|
||||
func (db *PreparedStmtDB) Ping() error {
|
||||
conn, err := db.GetDBConn()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return conn.Ping()
|
||||
}
|
||||
|
||||
type PreparedStmtTX struct {
|
||||
Tx
|
||||
PreparedStmtDB *PreparedStmtDB
|
||||
}
|
||||
|
||||
func (db *PreparedStmtTX) GetDBConn() (*sql.DB, error) {
|
||||
return db.PreparedStmtDB.GetDBConn()
|
||||
}
|
||||
|
||||
func (tx *PreparedStmtTX) Commit() error {
|
||||
if tx.Tx != nil {
|
||||
if tx.Tx != nil && !reflect.ValueOf(tx.Tx).IsNil() {
|
||||
return tx.Tx.Commit()
|
||||
}
|
||||
return ErrInvalidTransaction
|
||||
}
|
||||
|
||||
func (tx *PreparedStmtTX) Rollback() error {
|
||||
if tx.Tx != nil {
|
||||
if tx.Tx != nil && !reflect.ValueOf(tx.Tx).IsNil() {
|
||||
return tx.Tx.Rollback()
|
||||
}
|
||||
return ErrInvalidTransaction
|
||||
@ -137,7 +217,7 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ..
|
||||
stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
|
||||
if err == nil {
|
||||
result, err = tx.Tx.StmtContext(ctx, stmt.Stmt).ExecContext(ctx, args...)
|
||||
if err != nil {
|
||||
if errors.Is(err, driver.ErrBadConn) {
|
||||
tx.PreparedStmtDB.Mux.Lock()
|
||||
defer tx.PreparedStmtDB.Mux.Unlock()
|
||||
|
||||
@ -152,7 +232,7 @@ func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args .
|
||||
stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
|
||||
if err == nil {
|
||||
rows, err = tx.Tx.StmtContext(ctx, stmt.Stmt).QueryContext(ctx, args...)
|
||||
if err != nil {
|
||||
if errors.Is(err, driver.ErrBadConn) {
|
||||
tx.PreparedStmtDB.Mux.Lock()
|
||||
defer tx.PreparedStmtDB.Mux.Unlock()
|
||||
|
||||
@ -170,3 +250,11 @@ func (tx *PreparedStmtTX) QueryRowContext(ctx context.Context, query string, arg
|
||||
}
|
||||
return &sql.Row{}
|
||||
}
|
||||
|
||||
func (tx *PreparedStmtTX) Ping() error {
|
||||
conn, err := tx.GetDBConn()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return conn.Ping()
|
||||
}
|
||||
|
130
scan.go
130
scan.go
@ -4,10 +4,10 @@ import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm/schema"
|
||||
"gorm.io/gorm/utils"
|
||||
)
|
||||
|
||||
// prepareValues prepare values slice
|
||||
@ -50,7 +50,7 @@ func scanIntoMap(mapValue map[string]interface{}, values []interface{}, columns
|
||||
}
|
||||
}
|
||||
|
||||
func (db *DB) scanIntoStruct(rows Rows, reflectValue reflect.Value, values []interface{}, fields []*schema.Field, joinFields [][2]*schema.Field) {
|
||||
func (db *DB) scanIntoStruct(rows Rows, reflectValue reflect.Value, values []interface{}, fields []*schema.Field, joinFields [][]*schema.Field) {
|
||||
for idx, field := range fields {
|
||||
if field != nil {
|
||||
values[idx] = field.NewValuePool.Get()
|
||||
@ -65,26 +65,49 @@ func (db *DB) scanIntoStruct(rows Rows, reflectValue reflect.Value, values []int
|
||||
|
||||
db.RowsAffected++
|
||||
db.AddError(rows.Scan(values...))
|
||||
|
||||
joinedNestedSchemaMap := make(map[string]interface{})
|
||||
for idx, field := range fields {
|
||||
if field != nil {
|
||||
if len(joinFields) == 0 || joinFields[idx][0] == nil {
|
||||
db.AddError(field.Set(db.Statement.Context, reflectValue, values[idx]))
|
||||
} else {
|
||||
relValue := joinFields[idx][0].ReflectValueOf(db.Statement.Context, reflectValue)
|
||||
if relValue.Kind() == reflect.Ptr && relValue.IsNil() {
|
||||
if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() {
|
||||
continue
|
||||
}
|
||||
if field == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
relValue.Set(reflect.New(relValue.Type().Elem()))
|
||||
if len(joinFields) == 0 || len(joinFields[idx]) == 0 {
|
||||
db.AddError(field.Set(db.Statement.Context, reflectValue, values[idx]))
|
||||
} else { // joinFields count is larger than 2 when using join
|
||||
var isNilPtrValue bool
|
||||
var relValue reflect.Value
|
||||
// does not contain raw dbname
|
||||
nestedJoinSchemas := joinFields[idx][:len(joinFields[idx])-1]
|
||||
// current reflect value
|
||||
currentReflectValue := reflectValue
|
||||
fullRels := make([]string, 0, len(nestedJoinSchemas))
|
||||
for _, joinSchema := range nestedJoinSchemas {
|
||||
fullRels = append(fullRels, joinSchema.Name)
|
||||
relValue = joinSchema.ReflectValueOf(db.Statement.Context, currentReflectValue)
|
||||
if relValue.Kind() == reflect.Ptr {
|
||||
fullRelsName := utils.JoinNestedRelationNames(fullRels)
|
||||
// same nested structure
|
||||
if _, ok := joinedNestedSchemaMap[fullRelsName]; !ok {
|
||||
if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() {
|
||||
isNilPtrValue = true
|
||||
break
|
||||
}
|
||||
|
||||
relValue.Set(reflect.New(relValue.Type().Elem()))
|
||||
joinedNestedSchemaMap[fullRelsName] = nil
|
||||
}
|
||||
}
|
||||
db.AddError(joinFields[idx][1].Set(db.Statement.Context, relValue, values[idx]))
|
||||
currentReflectValue = relValue
|
||||
}
|
||||
|
||||
// release data to pool
|
||||
field.NewValuePool.Put(values[idx])
|
||||
if !isNilPtrValue { // ignore if value is nil
|
||||
f := joinFields[idx][len(joinFields[idx])-1]
|
||||
db.AddError(f.Set(db.Statement.Context, relValue, values[idx]))
|
||||
}
|
||||
}
|
||||
|
||||
// release data to pool
|
||||
field.NewValuePool.Put(values[idx])
|
||||
}
|
||||
}
|
||||
|
||||
@ -156,11 +179,10 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
|
||||
}
|
||||
default:
|
||||
var (
|
||||
fields = make([]*schema.Field, len(columns))
|
||||
selectedColumnsMap = make(map[string]int, len(columns))
|
||||
joinFields [][2]*schema.Field
|
||||
sch = db.Statement.Schema
|
||||
reflectValue = db.Statement.ReflectValue
|
||||
fields = make([]*schema.Field, len(columns))
|
||||
joinFields [][]*schema.Field
|
||||
sch = db.Statement.Schema
|
||||
reflectValue = db.Statement.ReflectValue
|
||||
)
|
||||
|
||||
if reflectValue.Kind() == reflect.Interface {
|
||||
@ -193,29 +215,45 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
|
||||
|
||||
// Not Pluck
|
||||
if sch != nil {
|
||||
matchedFieldCount := make(map[string]int, len(columns))
|
||||
for idx, column := range columns {
|
||||
if field := sch.LookUpField(column); field != nil && field.Readable {
|
||||
if curIndex, ok := selectedColumnsMap[column]; ok {
|
||||
for fieldIndex, selectField := range sch.Fields[curIndex+1:] {
|
||||
fields[idx] = field
|
||||
if count, ok := matchedFieldCount[column]; ok {
|
||||
// handle duplicate fields
|
||||
for _, selectField := range sch.Fields {
|
||||
if selectField.DBName == column && selectField.Readable {
|
||||
selectedColumnsMap[column] = curIndex + fieldIndex + 1
|
||||
fields[idx] = selectField
|
||||
break
|
||||
if count == 0 {
|
||||
matchedFieldCount[column]++
|
||||
fields[idx] = selectField
|
||||
break
|
||||
}
|
||||
count--
|
||||
}
|
||||
}
|
||||
} else {
|
||||
fields[idx] = field
|
||||
selectedColumnsMap[column] = idx
|
||||
matchedFieldCount[column] = 1
|
||||
}
|
||||
} else if names := strings.Split(column, "__"); len(names) > 1 {
|
||||
} else if names := utils.SplitNestedRelationName(column); len(names) > 1 { // has nested relation
|
||||
if rel, ok := sch.Relationships.Relations[names[0]]; ok {
|
||||
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
|
||||
subNameCount := len(names)
|
||||
// nested relation fields
|
||||
relFields := make([]*schema.Field, 0, subNameCount-1)
|
||||
relFields = append(relFields, rel.Field)
|
||||
for _, name := range names[1 : subNameCount-1] {
|
||||
rel = rel.FieldSchema.Relationships.Relations[name]
|
||||
relFields = append(relFields, rel.Field)
|
||||
}
|
||||
// lastest name is raw dbname
|
||||
dbName := names[subNameCount-1]
|
||||
if field := rel.FieldSchema.LookUpField(dbName); field != nil && field.Readable {
|
||||
fields[idx] = field
|
||||
|
||||
if len(joinFields) == 0 {
|
||||
joinFields = make([][2]*schema.Field, len(columns))
|
||||
joinFields = make([][]*schema.Field, len(columns))
|
||||
}
|
||||
joinFields[idx] = [2]*schema.Field{rel.Field, field}
|
||||
relFields = append(relFields, field)
|
||||
joinFields[idx] = relFields
|
||||
continue
|
||||
}
|
||||
}
|
||||
@ -229,11 +267,24 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
|
||||
|
||||
switch reflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
var elem reflect.Value
|
||||
var (
|
||||
elem reflect.Value
|
||||
isArrayKind = reflectValue.Kind() == reflect.Array
|
||||
)
|
||||
|
||||
if !update || reflectValue.Len() == 0 {
|
||||
update = false
|
||||
db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20))
|
||||
if isArrayKind {
|
||||
db.Statement.ReflectValue.Set(reflect.Zero(reflectValue.Type()))
|
||||
} else {
|
||||
// if the slice cap is externally initialized, the externally initialized slice is directly used here
|
||||
if reflectValue.Cap() == 0 {
|
||||
db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20))
|
||||
} else {
|
||||
reflectValue.SetLen(0)
|
||||
db.Statement.ReflectValue.Set(reflectValue)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for initialized || rows.Next() {
|
||||
@ -260,10 +311,15 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
|
||||
db.scanIntoStruct(rows, elem, values, fields, joinFields)
|
||||
|
||||
if !update {
|
||||
if isPtr {
|
||||
reflectValue = reflect.Append(reflectValue, elem)
|
||||
if !isPtr {
|
||||
elem = elem.Elem()
|
||||
}
|
||||
if isArrayKind {
|
||||
if reflectValue.Len() >= int(db.RowsAffected) {
|
||||
reflectValue.Index(int(db.RowsAffected - 1)).Set(elem)
|
||||
}
|
||||
} else {
|
||||
reflectValue = reflect.Append(reflectValue, elem.Elem())
|
||||
reflectValue = reflect.Append(reflectValue, elem)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,35 +0,0 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// reg match english letters and midline
|
||||
var regEnLetterAndMidline = regexp.MustCompile("^[A-Za-z-_]+$")
|
||||
|
||||
type Check struct {
|
||||
Name string
|
||||
Constraint string // length(phone) >= 10
|
||||
*Field
|
||||
}
|
||||
|
||||
// ParseCheckConstraints parse schema check constraints
|
||||
func (schema *Schema) ParseCheckConstraints() map[string]Check {
|
||||
checks := map[string]Check{}
|
||||
for _, field := range schema.FieldsByDBName {
|
||||
if chk := field.TagSettings["CHECK"]; chk != "" {
|
||||
names := strings.Split(chk, ",")
|
||||
if len(names) > 1 && regEnLetterAndMidline.MatchString(names[0]) {
|
||||
checks[names[0]] = Check{Name: names[0], Constraint: strings.Join(names[1:], ","), Field: field}
|
||||
} else {
|
||||
if names[0] == "" {
|
||||
chk = strings.Join(names[1:], ",")
|
||||
}
|
||||
name := schema.namer.CheckerName(schema.Table, field.DBName)
|
||||
checks[name] = Check{Name: name, Constraint: chk, Field: field}
|
||||
}
|
||||
}
|
||||
}
|
||||
return checks
|
||||
}
|
66
schema/constraint.go
Normal file
66
schema/constraint.go
Normal file
@ -0,0 +1,66 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
// reg match english letters and midline
|
||||
var regEnLetterAndMidline = regexp.MustCompile(`^[\w-]+$`)
|
||||
|
||||
type CheckConstraint struct {
|
||||
Name string
|
||||
Constraint string // length(phone) >= 10
|
||||
*Field
|
||||
}
|
||||
|
||||
func (chk *CheckConstraint) GetName() string { return chk.Name }
|
||||
|
||||
func (chk *CheckConstraint) Build() (sql string, vars []interface{}) {
|
||||
return "CONSTRAINT ? CHECK (?)", []interface{}{clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}}
|
||||
}
|
||||
|
||||
// ParseCheckConstraints parse schema check constraints
|
||||
func (schema *Schema) ParseCheckConstraints() map[string]CheckConstraint {
|
||||
checks := map[string]CheckConstraint{}
|
||||
for _, field := range schema.FieldsByDBName {
|
||||
if chk := field.TagSettings["CHECK"]; chk != "" {
|
||||
names := strings.Split(chk, ",")
|
||||
if len(names) > 1 && regEnLetterAndMidline.MatchString(names[0]) {
|
||||
checks[names[0]] = CheckConstraint{Name: names[0], Constraint: strings.Join(names[1:], ","), Field: field}
|
||||
} else {
|
||||
if names[0] == "" {
|
||||
chk = strings.Join(names[1:], ",")
|
||||
}
|
||||
name := schema.namer.CheckerName(schema.Table, field.DBName)
|
||||
checks[name] = CheckConstraint{Name: name, Constraint: chk, Field: field}
|
||||
}
|
||||
}
|
||||
}
|
||||
return checks
|
||||
}
|
||||
|
||||
type UniqueConstraint struct {
|
||||
Name string
|
||||
Field *Field
|
||||
}
|
||||
|
||||
func (uni *UniqueConstraint) GetName() string { return uni.Name }
|
||||
|
||||
func (uni *UniqueConstraint) Build() (sql string, vars []interface{}) {
|
||||
return "CONSTRAINT ? UNIQUE (?)", []interface{}{clause.Column{Name: uni.Name}, clause.Column{Name: uni.Field.DBName}}
|
||||
}
|
||||
|
||||
// ParseUniqueConstraints parse schema unique constraints
|
||||
func (schema *Schema) ParseUniqueConstraints() map[string]UniqueConstraint {
|
||||
uniques := make(map[string]UniqueConstraint)
|
||||
for _, field := range schema.Fields {
|
||||
if field.Unique {
|
||||
name := schema.namer.UniqueName(schema.Table, field.DBName)
|
||||
uniques[name] = UniqueConstraint{Name: name, Field: field}
|
||||
}
|
||||
}
|
||||
return uniques
|
||||
}
|
@ -6,6 +6,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"gorm.io/gorm/schema"
|
||||
"gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
type UserCheck struct {
|
||||
@ -20,7 +21,7 @@ func TestParseCheck(t *testing.T) {
|
||||
t.Fatalf("failed to parse user check, got error %v", err)
|
||||
}
|
||||
|
||||
results := map[string]schema.Check{
|
||||
results := map[string]schema.CheckConstraint{
|
||||
"name_checker": {
|
||||
Name: "name_checker",
|
||||
Constraint: "name <> 'jinzhu'",
|
||||
@ -53,3 +54,31 @@ func TestParseCheck(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseUniqueConstraints(t *testing.T) {
|
||||
type UserUnique struct {
|
||||
Name1 string `gorm:"unique"`
|
||||
Name2 string `gorm:"uniqueIndex"`
|
||||
}
|
||||
|
||||
user, err := schema.Parse(&UserUnique{}, &sync.Map{}, schema.NamingStrategy{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse user unique, got error %v", err)
|
||||
}
|
||||
constraints := user.ParseUniqueConstraints()
|
||||
|
||||
results := map[string]schema.UniqueConstraint{
|
||||
"uni_user_uniques_name1": {
|
||||
Name: "uni_user_uniques_name1",
|
||||
Field: &schema.Field{Name: "Name1", Unique: true},
|
||||
},
|
||||
}
|
||||
for k, result := range results {
|
||||
v, ok := constraints[k]
|
||||
if !ok {
|
||||
t.Errorf("Failed to found unique constraint %v from parsed constraints %+v", k, constraints)
|
||||
}
|
||||
tests.AssertObjEqual(t, result, v, "Name")
|
||||
tests.AssertObjEqual(t, result.Field, v.Field, "Name", "Unique", "UniqueIndex")
|
||||
}
|
||||
}
|
102
schema/field.go
102
schema/field.go
@ -49,6 +49,8 @@ const (
|
||||
Bytes DataType = "bytes"
|
||||
)
|
||||
|
||||
const DefaultAutoIncrementIncrement int64 = 1
|
||||
|
||||
// Field is the representation of model schema's field
|
||||
type Field struct {
|
||||
Name string
|
||||
@ -87,6 +89,16 @@ type Field struct {
|
||||
Set func(context.Context, reflect.Value, interface{}) error
|
||||
Serializer SerializerInterface
|
||||
NewValuePool FieldNewValuePool
|
||||
|
||||
// In some db (e.g. MySQL), Unique and UniqueIndex are indistinguishable.
|
||||
// When a column has a (not Mul) UniqueIndex, Migrator always reports its gorm.ColumnType is Unique.
|
||||
// It causes field unnecessarily migration.
|
||||
// Therefore, we need to record the UniqueIndex on this column (exclude Mul UniqueIndex) for MigrateColumnUnique.
|
||||
UniqueIndex string
|
||||
}
|
||||
|
||||
func (field *Field) BindName() string {
|
||||
return strings.Join(field.BindNames, ".")
|
||||
}
|
||||
|
||||
// ParseField parses reflect.StructField to Field
|
||||
@ -115,7 +127,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
||||
NotNull: utils.CheckTruth(tagSetting["NOT NULL"], tagSetting["NOTNULL"]),
|
||||
Unique: utils.CheckTruth(tagSetting["UNIQUE"]),
|
||||
Comment: tagSetting["COMMENT"],
|
||||
AutoIncrementIncrement: 1,
|
||||
AutoIncrementIncrement: DefaultAutoIncrementIncrement,
|
||||
}
|
||||
|
||||
for field.IndirectFieldType.Kind() == reflect.Ptr {
|
||||
@ -174,7 +186,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
||||
field.DataType = String
|
||||
field.Serializer = v
|
||||
} else {
|
||||
var serializerName = field.TagSettings["JSON"]
|
||||
serializerName := field.TagSettings["JSON"]
|
||||
if serializerName == "" {
|
||||
serializerName = field.TagSettings["SERIALIZER"]
|
||||
}
|
||||
@ -403,18 +415,14 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
||||
}
|
||||
|
||||
if ef.PrimaryKey {
|
||||
if val, ok := ef.TagSettings["PRIMARYKEY"]; ok && utils.CheckTruth(val) {
|
||||
ef.PrimaryKey = true
|
||||
} else if val, ok := ef.TagSettings["PRIMARY_KEY"]; ok && utils.CheckTruth(val) {
|
||||
ef.PrimaryKey = true
|
||||
} else {
|
||||
if !utils.CheckTruth(ef.TagSettings["PRIMARYKEY"], ef.TagSettings["PRIMARY_KEY"]) {
|
||||
ef.PrimaryKey = false
|
||||
|
||||
if val, ok := ef.TagSettings["AUTOINCREMENT"]; !ok || !utils.CheckTruth(val) {
|
||||
ef.AutoIncrement = false
|
||||
}
|
||||
|
||||
if ef.DefaultValue == "" {
|
||||
if !ef.AutoIncrement && ef.DefaultValue == "" {
|
||||
ef.HasDefaultValue = false
|
||||
}
|
||||
}
|
||||
@ -472,9 +480,6 @@ func (field *Field) setupValuerAndSetter() {
|
||||
oldValuerOf := field.ValueOf
|
||||
field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) {
|
||||
value, zero := oldValuerOf(ctx, v)
|
||||
if zero {
|
||||
return value, zero
|
||||
}
|
||||
|
||||
s, ok := value.(SerializerValuerInterface)
|
||||
if !ok {
|
||||
@ -487,7 +492,7 @@ func (field *Field) setupValuerAndSetter() {
|
||||
Destination: v,
|
||||
Context: ctx,
|
||||
fieldValue: value,
|
||||
}, false
|
||||
}, zero
|
||||
}
|
||||
}
|
||||
|
||||
@ -607,6 +612,22 @@ func (field *Field) setupValuerAndSetter() {
|
||||
if data != nil && *data != nil {
|
||||
field.ReflectValueOf(ctx, value).SetInt(**data)
|
||||
}
|
||||
case **int:
|
||||
if data != nil && *data != nil {
|
||||
field.ReflectValueOf(ctx, value).SetInt(int64(**data))
|
||||
}
|
||||
case **int8:
|
||||
if data != nil && *data != nil {
|
||||
field.ReflectValueOf(ctx, value).SetInt(int64(**data))
|
||||
}
|
||||
case **int16:
|
||||
if data != nil && *data != nil {
|
||||
field.ReflectValueOf(ctx, value).SetInt(int64(**data))
|
||||
}
|
||||
case **int32:
|
||||
if data != nil && *data != nil {
|
||||
field.ReflectValueOf(ctx, value).SetInt(int64(**data))
|
||||
}
|
||||
case int64:
|
||||
field.ReflectValueOf(ctx, value).SetInt(data)
|
||||
case int:
|
||||
@ -643,7 +664,7 @@ func (field *Field) setupValuerAndSetter() {
|
||||
if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond {
|
||||
field.ReflectValueOf(ctx, value).SetInt(data.UnixNano())
|
||||
} else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond {
|
||||
field.ReflectValueOf(ctx, value).SetInt(data.UnixNano() / 1e6)
|
||||
field.ReflectValueOf(ctx, value).SetInt(data.UnixMilli())
|
||||
} else {
|
||||
field.ReflectValueOf(ctx, value).SetInt(data.Unix())
|
||||
}
|
||||
@ -652,7 +673,7 @@ func (field *Field) setupValuerAndSetter() {
|
||||
if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond {
|
||||
field.ReflectValueOf(ctx, value).SetInt(data.UnixNano())
|
||||
} else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond {
|
||||
field.ReflectValueOf(ctx, value).SetInt(data.UnixNano() / 1e6)
|
||||
field.ReflectValueOf(ctx, value).SetInt(data.UnixMilli())
|
||||
} else {
|
||||
field.ReflectValueOf(ctx, value).SetInt(data.Unix())
|
||||
}
|
||||
@ -671,6 +692,22 @@ func (field *Field) setupValuerAndSetter() {
|
||||
if data != nil && *data != nil {
|
||||
field.ReflectValueOf(ctx, value).SetUint(**data)
|
||||
}
|
||||
case **uint:
|
||||
if data != nil && *data != nil {
|
||||
field.ReflectValueOf(ctx, value).SetUint(uint64(**data))
|
||||
}
|
||||
case **uint8:
|
||||
if data != nil && *data != nil {
|
||||
field.ReflectValueOf(ctx, value).SetUint(uint64(**data))
|
||||
}
|
||||
case **uint16:
|
||||
if data != nil && *data != nil {
|
||||
field.ReflectValueOf(ctx, value).SetUint(uint64(**data))
|
||||
}
|
||||
case **uint32:
|
||||
if data != nil && *data != nil {
|
||||
field.ReflectValueOf(ctx, value).SetUint(uint64(**data))
|
||||
}
|
||||
case uint64:
|
||||
field.ReflectValueOf(ctx, value).SetUint(data)
|
||||
case uint:
|
||||
@ -701,7 +738,7 @@ func (field *Field) setupValuerAndSetter() {
|
||||
if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond {
|
||||
field.ReflectValueOf(ctx, value).SetUint(uint64(data.UnixNano()))
|
||||
} else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond {
|
||||
field.ReflectValueOf(ctx, value).SetUint(uint64(data.UnixNano() / 1e6))
|
||||
field.ReflectValueOf(ctx, value).SetUint(uint64(data.UnixMilli()))
|
||||
} else {
|
||||
field.ReflectValueOf(ctx, value).SetUint(uint64(data.Unix()))
|
||||
}
|
||||
@ -723,6 +760,10 @@ func (field *Field) setupValuerAndSetter() {
|
||||
if data != nil && *data != nil {
|
||||
field.ReflectValueOf(ctx, value).SetFloat(**data)
|
||||
}
|
||||
case **float32:
|
||||
if data != nil && *data != nil {
|
||||
field.ReflectValueOf(ctx, value).SetFloat(float64(**data))
|
||||
}
|
||||
case float64:
|
||||
field.ReflectValueOf(ctx, value).SetFloat(data)
|
||||
case float32:
|
||||
@ -813,7 +854,7 @@ func (field *Field) setupValuerAndSetter() {
|
||||
field.Set = func(ctx context.Context, value reflect.Value, v interface{}) error {
|
||||
switch data := v.(type) {
|
||||
case **time.Time:
|
||||
if data != nil {
|
||||
if data != nil && *data != nil {
|
||||
field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(*data))
|
||||
}
|
||||
case time.Time:
|
||||
@ -849,14 +890,12 @@ func (field *Field) setupValuerAndSetter() {
|
||||
reflectV := reflect.ValueOf(v)
|
||||
if !reflectV.IsValid() {
|
||||
field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem())
|
||||
} else if reflectV.Kind() == reflect.Ptr && reflectV.IsNil() {
|
||||
return
|
||||
} else if reflectV.Type().AssignableTo(field.FieldType) {
|
||||
field.ReflectValueOf(ctx, value).Set(reflectV)
|
||||
} else if reflectV.Kind() == reflect.Ptr {
|
||||
if reflectV.IsNil() || !reflectV.IsValid() {
|
||||
field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem())
|
||||
} else {
|
||||
return field.Set(ctx, value, reflectV.Elem().Interface())
|
||||
}
|
||||
return field.Set(ctx, value, reflectV.Elem().Interface())
|
||||
} else {
|
||||
fieldValue := field.ReflectValueOf(ctx, value)
|
||||
if fieldValue.IsNil() {
|
||||
@ -877,14 +916,12 @@ func (field *Field) setupValuerAndSetter() {
|
||||
reflectV := reflect.ValueOf(v)
|
||||
if !reflectV.IsValid() {
|
||||
field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem())
|
||||
} else if reflectV.Kind() == reflect.Ptr && reflectV.IsNil() {
|
||||
return
|
||||
} else if reflectV.Type().AssignableTo(field.FieldType) {
|
||||
field.ReflectValueOf(ctx, value).Set(reflectV)
|
||||
} else if reflectV.Kind() == reflect.Ptr {
|
||||
if reflectV.IsNil() || !reflectV.IsValid() {
|
||||
field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem())
|
||||
} else {
|
||||
return field.Set(ctx, value, reflectV.Elem().Interface())
|
||||
}
|
||||
return field.Set(ctx, value, reflectV.Elem().Interface())
|
||||
} else {
|
||||
if valuer, ok := v.(driver.Valuer); ok {
|
||||
v, _ = valuer.Value()
|
||||
@ -913,6 +950,8 @@ func (field *Field) setupValuerAndSetter() {
|
||||
sameElemType = field.FieldType == reflect.ValueOf(field.Serializer).Type().Elem()
|
||||
}
|
||||
|
||||
serializerValue := reflect.Indirect(reflect.ValueOf(field.Serializer))
|
||||
serializerType := serializerValue.Type()
|
||||
field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) {
|
||||
if s, ok := v.(*serializer); ok {
|
||||
if s.fieldValue != nil {
|
||||
@ -920,11 +959,12 @@ func (field *Field) setupValuerAndSetter() {
|
||||
} else if err = s.Serializer.Scan(ctx, field, value, s.value); err == nil {
|
||||
if sameElemType {
|
||||
field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(s.Serializer).Elem())
|
||||
s.Serializer = reflect.New(reflect.Indirect(reflect.ValueOf(field.Serializer)).Type()).Interface().(SerializerInterface)
|
||||
} else if sameType {
|
||||
field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(s.Serializer))
|
||||
s.Serializer = reflect.New(reflect.Indirect(reflect.ValueOf(field.Serializer)).Type()).Interface().(SerializerInterface)
|
||||
}
|
||||
si := reflect.New(serializerType)
|
||||
si.Elem().Set(serializerValue)
|
||||
s.Serializer = si.Interface().(SerializerInterface)
|
||||
}
|
||||
} else {
|
||||
err = oldFieldSetter(ctx, value, v)
|
||||
@ -936,11 +976,15 @@ func (field *Field) setupValuerAndSetter() {
|
||||
|
||||
func (field *Field) setupNewValuePool() {
|
||||
if field.Serializer != nil {
|
||||
serializerValue := reflect.Indirect(reflect.ValueOf(field.Serializer))
|
||||
serializerType := serializerValue.Type()
|
||||
field.NewValuePool = &sync.Pool{
|
||||
New: func() interface{} {
|
||||
si := reflect.New(serializerType)
|
||||
si.Elem().Set(serializerValue)
|
||||
return &serializer{
|
||||
Field: field,
|
||||
Serializer: reflect.New(reflect.Indirect(reflect.ValueOf(field.Serializer)).Type()).Interface().(SerializerInterface),
|
||||
Serializer: si.Interface().(SerializerInterface),
|
||||
}
|
||||
},
|
||||
}
|
||||
|
@ -13,8 +13,8 @@ type Index struct {
|
||||
Type string // btree, hash, gist, spgist, gin, and brin
|
||||
Where string
|
||||
Comment string
|
||||
Option string // WITH PARSER parser_name
|
||||
Fields []IndexOption
|
||||
Option string // WITH PARSER parser_name
|
||||
Fields []IndexOption // Note: IndexOption's Field maybe the same
|
||||
}
|
||||
|
||||
type IndexOption struct {
|
||||
@ -65,7 +65,11 @@ func (schema *Schema) ParseIndexes() map[string]Index {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, index := range indexes {
|
||||
if index.Class == "UNIQUE" && len(index.Fields) == 1 {
|
||||
index.Fields[0].Field.UniqueIndex = index.Name
|
||||
}
|
||||
}
|
||||
return indexes
|
||||
}
|
||||
|
||||
|
@ -1,11 +1,11 @@
|
||||
package schema_test
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"gorm.io/gorm/schema"
|
||||
"gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
type UserIndex struct {
|
||||
@ -19,6 +19,7 @@ type UserIndex struct {
|
||||
OID int64 `gorm:"index:idx_id;index:idx_oid,unique"`
|
||||
MemberNumber string `gorm:"index:idx_id,priority:1"`
|
||||
Name7 string `gorm:"index:type"`
|
||||
Name8 string `gorm:"index:,length:10;index:,collate:utf8"`
|
||||
|
||||
// Composite Index: Flattened structure.
|
||||
Data0A string `gorm:"index:,composite:comp_id0"`
|
||||
@ -65,7 +66,7 @@ func TestParseIndex(t *testing.T) {
|
||||
"idx_name": {
|
||||
Name: "idx_name",
|
||||
Class: "UNIQUE",
|
||||
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name2"}}},
|
||||
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name2", UniqueIndex: "idx_name"}}},
|
||||
},
|
||||
"idx_user_indices_name3": {
|
||||
Name: "idx_user_indices_name3",
|
||||
@ -81,7 +82,7 @@ func TestParseIndex(t *testing.T) {
|
||||
"idx_user_indices_name4": {
|
||||
Name: "idx_user_indices_name4",
|
||||
Class: "UNIQUE",
|
||||
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name4"}}},
|
||||
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name4", UniqueIndex: "idx_user_indices_name4"}}},
|
||||
},
|
||||
"idx_user_indices_name5": {
|
||||
Name: "idx_user_indices_name5",
|
||||
@ -102,18 +103,27 @@ func TestParseIndex(t *testing.T) {
|
||||
},
|
||||
"idx_id": {
|
||||
Name: "idx_id",
|
||||
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "MemberNumber"}}, {Field: &schema.Field{Name: "OID"}}},
|
||||
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "MemberNumber"}}, {Field: &schema.Field{Name: "OID", UniqueIndex: "idx_oid"}}},
|
||||
},
|
||||
"idx_oid": {
|
||||
Name: "idx_oid",
|
||||
Class: "UNIQUE",
|
||||
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "OID"}}},
|
||||
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "OID", UniqueIndex: "idx_oid"}}},
|
||||
},
|
||||
"type": {
|
||||
Name: "type",
|
||||
Type: "",
|
||||
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name7"}}},
|
||||
},
|
||||
"idx_user_indices_name8": {
|
||||
Name: "idx_user_indices_name8",
|
||||
Type: "",
|
||||
Fields: []schema.IndexOption{
|
||||
{Field: &schema.Field{Name: "Name8"}, Length: 10},
|
||||
// Note: Duplicate Columns
|
||||
{Field: &schema.Field{Name: "Name8"}, Collate: "utf8"},
|
||||
},
|
||||
},
|
||||
"idx_user_indices_comp_id0": {
|
||||
Name: "idx_user_indices_comp_id0",
|
||||
Type: "",
|
||||
@ -146,37 +156,109 @@ func TestParseIndex(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
indices := user.ParseIndexes()
|
||||
CheckIndices(t, results, user.ParseIndexes())
|
||||
}
|
||||
|
||||
for k, result := range results {
|
||||
v, ok := indices[k]
|
||||
if !ok {
|
||||
t.Fatalf("Failed to found index %v from parsed indices %+v", k, indices)
|
||||
}
|
||||
func TestParseIndexWithUniqueIndexAndUnique(t *testing.T) {
|
||||
type IndexTest struct {
|
||||
FieldA string `gorm:"unique;index"` // unique and index
|
||||
FieldB string `gorm:"unique"` // unique
|
||||
|
||||
for _, name := range []string{"Name", "Class", "Type", "Where", "Comment", "Option"} {
|
||||
if reflect.ValueOf(result).FieldByName(name).Interface() != reflect.ValueOf(v).FieldByName(name).Interface() {
|
||||
t.Errorf(
|
||||
"index %v %v should equal, expects %v, got %v",
|
||||
k, name, reflect.ValueOf(result).FieldByName(name).Interface(), reflect.ValueOf(v).FieldByName(name).Interface(),
|
||||
)
|
||||
FieldC string `gorm:"index:,unique"` // uniqueIndex
|
||||
FieldD string `gorm:"uniqueIndex;index"` // uniqueIndex and index
|
||||
|
||||
FieldE1 string `gorm:"uniqueIndex:uniq_field_e1_e2"` // mul uniqueIndex
|
||||
FieldE2 string `gorm:"uniqueIndex:uniq_field_e1_e2"`
|
||||
|
||||
FieldF1 string `gorm:"uniqueIndex:uniq_field_f1_f2;index"` // mul uniqueIndex and index
|
||||
FieldF2 string `gorm:"uniqueIndex:uniq_field_f1_f2;"`
|
||||
|
||||
FieldG string `gorm:"unique;uniqueIndex"` // unique and uniqueIndex
|
||||
|
||||
FieldH1 string `gorm:"unique;uniqueIndex:uniq_field_h1_h2"` // unique and mul uniqueIndex
|
||||
FieldH2 string `gorm:"uniqueIndex:uniq_field_h1_h2"` // unique and mul uniqueIndex
|
||||
}
|
||||
indexSchema, err := schema.Parse(&IndexTest{}, &sync.Map{}, schema.NamingStrategy{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse user index, got error %v", err)
|
||||
}
|
||||
indices := indexSchema.ParseIndexes()
|
||||
CheckIndices(t, map[string]schema.Index{
|
||||
"idx_index_tests_field_a": {
|
||||
Name: "idx_index_tests_field_a",
|
||||
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "FieldA", Unique: true}}},
|
||||
},
|
||||
"idx_index_tests_field_c": {
|
||||
Name: "idx_index_tests_field_c",
|
||||
Class: "UNIQUE",
|
||||
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "FieldC", UniqueIndex: "idx_index_tests_field_c"}}},
|
||||
},
|
||||
"idx_index_tests_field_d": {
|
||||
Name: "idx_index_tests_field_d",
|
||||
Class: "UNIQUE",
|
||||
Fields: []schema.IndexOption{
|
||||
{Field: &schema.Field{Name: "FieldD"}},
|
||||
// Note: Duplicate Columns
|
||||
{Field: &schema.Field{Name: "FieldD"}},
|
||||
},
|
||||
},
|
||||
"uniq_field_e1_e2": {
|
||||
Name: "uniq_field_e1_e2",
|
||||
Class: "UNIQUE",
|
||||
Fields: []schema.IndexOption{
|
||||
{Field: &schema.Field{Name: "FieldE1"}},
|
||||
{Field: &schema.Field{Name: "FieldE2"}},
|
||||
},
|
||||
},
|
||||
"idx_index_tests_field_f1": {
|
||||
Name: "idx_index_tests_field_f1",
|
||||
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "FieldF1"}}},
|
||||
},
|
||||
"uniq_field_f1_f2": {
|
||||
Name: "uniq_field_f1_f2",
|
||||
Class: "UNIQUE",
|
||||
Fields: []schema.IndexOption{
|
||||
{Field: &schema.Field{Name: "FieldF1"}},
|
||||
{Field: &schema.Field{Name: "FieldF2"}},
|
||||
},
|
||||
},
|
||||
"idx_index_tests_field_g": {
|
||||
Name: "idx_index_tests_field_g",
|
||||
Class: "UNIQUE",
|
||||
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "FieldG", Unique: true, UniqueIndex: "idx_index_tests_field_g"}}},
|
||||
},
|
||||
"uniq_field_h1_h2": {
|
||||
Name: "uniq_field_h1_h2",
|
||||
Class: "UNIQUE",
|
||||
Fields: []schema.IndexOption{
|
||||
{Field: &schema.Field{Name: "FieldH1", Unique: true}},
|
||||
{Field: &schema.Field{Name: "FieldH2"}},
|
||||
},
|
||||
},
|
||||
}, indices)
|
||||
}
|
||||
|
||||
func CheckIndices(t *testing.T, expected, actual map[string]schema.Index) {
|
||||
for k, ei := range expected {
|
||||
t.Run(k, func(t *testing.T) {
|
||||
ai, ok := actual[k]
|
||||
if !ok {
|
||||
t.Errorf("expected index %q but actual missing", k)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
for idx, ef := range result.Fields {
|
||||
rf := v.Fields[idx]
|
||||
if rf.Field.Name != ef.Field.Name {
|
||||
t.Fatalf("index field should equal, expects %v, got %v", rf.Field.Name, ef.Field.Name)
|
||||
tests.AssertObjEqual(t, ai, ei, "Name", "Class", "Type", "Where", "Comment", "Option")
|
||||
if len(ei.Fields) != len(ai.Fields) {
|
||||
t.Errorf("expected index %q field length is %d but actual %d", k, len(ei.Fields), len(ai.Fields))
|
||||
return
|
||||
}
|
||||
|
||||
for _, name := range []string{"Expression", "Sort", "Collate", "Length"} {
|
||||
if reflect.ValueOf(ef).FieldByName(name).Interface() != reflect.ValueOf(rf).FieldByName(name).Interface() {
|
||||
t.Errorf(
|
||||
"index %v field #%v's %v should equal, expects %v, got %v", k, idx+1, name,
|
||||
reflect.ValueOf(ef).FieldByName(name).Interface(), reflect.ValueOf(rf).FieldByName(name).Interface(),
|
||||
)
|
||||
}
|
||||
for i, ef := range ei.Fields {
|
||||
af := ai.Fields[i]
|
||||
tests.AssertObjEqual(t, af, ef, "Name", "Unique", "UniqueIndex", "Expression", "Sort", "Collate", "Length")
|
||||
}
|
||||
}
|
||||
})
|
||||
delete(actual, k)
|
||||
}
|
||||
for k := range actual {
|
||||
t.Errorf("unexpected index %q", k)
|
||||
}
|
||||
}
|
||||
|
@ -4,6 +4,12 @@ import (
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
// ConstraintInterface database constraint interface
|
||||
type ConstraintInterface interface {
|
||||
GetName() string
|
||||
Build() (sql string, vars []interface{})
|
||||
}
|
||||
|
||||
// GormDataTypeInterface gorm data type interface
|
||||
type GormDataTypeInterface interface {
|
||||
GormDataType() string
|
||||
|
@ -19,6 +19,7 @@ type Namer interface {
|
||||
RelationshipFKName(Relationship) string
|
||||
CheckerName(table, column string) string
|
||||
IndexName(table, column string) string
|
||||
UniqueName(table, column string) string
|
||||
}
|
||||
|
||||
// Replacer replacer interface like strings.Replacer
|
||||
@ -26,12 +27,15 @@ type Replacer interface {
|
||||
Replace(name string) string
|
||||
}
|
||||
|
||||
var _ Namer = (*NamingStrategy)(nil)
|
||||
|
||||
// NamingStrategy tables, columns naming strategy
|
||||
type NamingStrategy struct {
|
||||
TablePrefix string
|
||||
SingularTable bool
|
||||
NameReplacer Replacer
|
||||
NoLowerCase bool
|
||||
TablePrefix string
|
||||
SingularTable bool
|
||||
NameReplacer Replacer
|
||||
NoLowerCase bool
|
||||
IdentifierMaxLength int
|
||||
}
|
||||
|
||||
// TableName convert string to table name
|
||||
@ -84,17 +88,26 @@ func (ns NamingStrategy) IndexName(table, column string) string {
|
||||
return ns.formatName("idx", table, ns.toDBName(column))
|
||||
}
|
||||
|
||||
// UniqueName generate unique constraint name
|
||||
func (ns NamingStrategy) UniqueName(table, column string) string {
|
||||
return ns.formatName("uni", table, ns.toDBName(column))
|
||||
}
|
||||
|
||||
func (ns NamingStrategy) formatName(prefix, table, name string) string {
|
||||
formattedName := strings.ReplaceAll(strings.Join([]string{
|
||||
prefix, table, name,
|
||||
}, "_"), ".", "_")
|
||||
|
||||
if utf8.RuneCountInString(formattedName) > 64 {
|
||||
if ns.IdentifierMaxLength == 0 {
|
||||
ns.IdentifierMaxLength = 64
|
||||
}
|
||||
|
||||
if utf8.RuneCountInString(formattedName) > ns.IdentifierMaxLength {
|
||||
h := sha1.New()
|
||||
h.Write([]byte(formattedName))
|
||||
bs := h.Sum(nil)
|
||||
|
||||
formattedName = formattedName[0:56] + hex.EncodeToString(bs)[:8]
|
||||
formattedName = formattedName[0:ns.IdentifierMaxLength-8] + hex.EncodeToString(bs)[:8]
|
||||
}
|
||||
return formattedName
|
||||
}
|
||||
|
@ -189,8 +189,17 @@ func TestCustomReplacerWithNoLowerCase(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatNameWithStringLongerThan63Characters(t *testing.T) {
|
||||
ns := NamingStrategy{IdentifierMaxLength: 63}
|
||||
|
||||
formattedName := ns.formatName("prefix", "table", "thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVeryLongString")
|
||||
if formattedName != "prefix_table_thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVer180f2c67" {
|
||||
t.Errorf("invalid formatted name generated, got %v", formattedName)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatNameWithStringLongerThan64Characters(t *testing.T) {
|
||||
ns := NamingStrategy{}
|
||||
ns := NamingStrategy{IdentifierMaxLength: 64}
|
||||
|
||||
formattedName := ns.formatName("prefix", "table", "thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVeryLongString")
|
||||
if formattedName != "prefix_table_thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVery180f2c67" {
|
||||
|
@ -27,6 +27,8 @@ type Relationships struct {
|
||||
HasMany []*Relationship
|
||||
Many2Many []*Relationship
|
||||
Relations map[string]*Relationship
|
||||
|
||||
EmbeddedRelations map[string]*Relationships
|
||||
}
|
||||
|
||||
type Relationship struct {
|
||||
@ -74,8 +76,8 @@ func (schema *Schema) parseRelation(field *Field) *Relationship {
|
||||
return nil
|
||||
}
|
||||
|
||||
if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" {
|
||||
schema.buildPolymorphicRelation(relation, field, polymorphic)
|
||||
if hasPolymorphicRelation(field.TagSettings) {
|
||||
schema.buildPolymorphicRelation(relation, field)
|
||||
} else if many2many := field.TagSettings["MANY2MANY"]; many2many != "" {
|
||||
schema.buildMany2ManyRelation(relation, field, many2many)
|
||||
} else if belongsTo := field.TagSettings["BELONGSTO"]; belongsTo != "" {
|
||||
@ -87,7 +89,8 @@ func (schema *Schema) parseRelation(field *Field) *Relationship {
|
||||
case reflect.Slice:
|
||||
schema.guessRelation(relation, field, guessHas)
|
||||
default:
|
||||
schema.err = fmt.Errorf("unsupported data type %v for %v on field %s", relation.FieldSchema, schema, field.Name)
|
||||
schema.err = fmt.Errorf("unsupported data type %v for %v on field %s", relation.FieldSchema, schema,
|
||||
field.Name)
|
||||
}
|
||||
}
|
||||
|
||||
@ -106,7 +109,7 @@ func (schema *Schema) parseRelation(field *Field) *Relationship {
|
||||
}
|
||||
|
||||
if schema.err == nil {
|
||||
schema.Relationships.Relations[relation.Name] = relation
|
||||
schema.setRelation(relation)
|
||||
switch relation.Type {
|
||||
case HasOne:
|
||||
schema.Relationships.HasOne = append(schema.Relationships.HasOne, relation)
|
||||
@ -122,34 +125,100 @@ func (schema *Schema) parseRelation(field *Field) *Relationship {
|
||||
return relation
|
||||
}
|
||||
|
||||
// User has many Toys, its `Polymorphic` is `Owner`, Pet has one Toy, its `Polymorphic` is `Owner`
|
||||
// type User struct {
|
||||
// Toys []Toy `gorm:"polymorphic:Owner;"`
|
||||
// }
|
||||
// type Pet struct {
|
||||
// Toy Toy `gorm:"polymorphic:Owner;"`
|
||||
// }
|
||||
// type Toy struct {
|
||||
// OwnerID int
|
||||
// OwnerType string
|
||||
// }
|
||||
func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Field, polymorphic string) {
|
||||
relation.Polymorphic = &Polymorphic{
|
||||
Value: schema.Table,
|
||||
PolymorphicType: relation.FieldSchema.FieldsByName[polymorphic+"Type"],
|
||||
PolymorphicID: relation.FieldSchema.FieldsByName[polymorphic+"ID"],
|
||||
// hasPolymorphicRelation check if has polymorphic relation
|
||||
// 1. `POLYMORPHIC` tag
|
||||
// 2. `POLYMORPHICTYPE` and `POLYMORPHICID` tag
|
||||
func hasPolymorphicRelation(tagSettings map[string]string) bool {
|
||||
if _, ok := tagSettings["POLYMORPHIC"]; ok {
|
||||
return true
|
||||
}
|
||||
|
||||
_, hasType := tagSettings["POLYMORPHICTYPE"]
|
||||
_, hasId := tagSettings["POLYMORPHICID"]
|
||||
|
||||
return hasType && hasId
|
||||
}
|
||||
|
||||
func (schema *Schema) setRelation(relation *Relationship) {
|
||||
// set non-embedded relation
|
||||
if rel := schema.Relationships.Relations[relation.Name]; rel != nil {
|
||||
if len(rel.Field.BindNames) > 1 {
|
||||
schema.Relationships.Relations[relation.Name] = relation
|
||||
}
|
||||
} else {
|
||||
schema.Relationships.Relations[relation.Name] = relation
|
||||
}
|
||||
|
||||
// set embedded relation
|
||||
if len(relation.Field.BindNames) <= 1 {
|
||||
return
|
||||
}
|
||||
relationships := &schema.Relationships
|
||||
for i, name := range relation.Field.BindNames {
|
||||
if i < len(relation.Field.BindNames)-1 {
|
||||
if relationships.EmbeddedRelations == nil {
|
||||
relationships.EmbeddedRelations = map[string]*Relationships{}
|
||||
}
|
||||
if r := relationships.EmbeddedRelations[name]; r == nil {
|
||||
relationships.EmbeddedRelations[name] = &Relationships{}
|
||||
}
|
||||
relationships = relationships.EmbeddedRelations[name]
|
||||
} else {
|
||||
if relationships.Relations == nil {
|
||||
relationships.Relations = map[string]*Relationship{}
|
||||
}
|
||||
relationships.Relations[relation.Name] = relation
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// User has many Toys, its `Polymorphic` is `Owner`, Pet has one Toy, its `Polymorphic` is `Owner`
|
||||
//
|
||||
// type User struct {
|
||||
// Toys []Toy `gorm:"polymorphic:Owner;"`
|
||||
// }
|
||||
// type Pet struct {
|
||||
// Toy Toy `gorm:"polymorphic:Owner;"`
|
||||
// }
|
||||
// type Toy struct {
|
||||
// OwnerID int
|
||||
// OwnerType string
|
||||
// }
|
||||
func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Field) {
|
||||
polymorphic := field.TagSettings["POLYMORPHIC"]
|
||||
|
||||
relation.Polymorphic = &Polymorphic{
|
||||
Value: schema.Table,
|
||||
}
|
||||
|
||||
var (
|
||||
typeName = polymorphic + "Type"
|
||||
typeId = polymorphic + "ID"
|
||||
)
|
||||
|
||||
if value, ok := field.TagSettings["POLYMORPHICTYPE"]; ok {
|
||||
typeName = strings.TrimSpace(value)
|
||||
}
|
||||
|
||||
if value, ok := field.TagSettings["POLYMORPHICID"]; ok {
|
||||
typeId = strings.TrimSpace(value)
|
||||
}
|
||||
|
||||
relation.Polymorphic.PolymorphicType = relation.FieldSchema.FieldsByName[typeName]
|
||||
relation.Polymorphic.PolymorphicID = relation.FieldSchema.FieldsByName[typeId]
|
||||
|
||||
if value, ok := field.TagSettings["POLYMORPHICVALUE"]; ok {
|
||||
relation.Polymorphic.Value = strings.TrimSpace(value)
|
||||
}
|
||||
|
||||
if relation.Polymorphic.PolymorphicType == nil {
|
||||
schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s", relation.FieldSchema, schema, field.Name, polymorphic+"Type")
|
||||
schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s",
|
||||
relation.FieldSchema, schema, field.Name, polymorphic+"Type")
|
||||
}
|
||||
|
||||
if relation.Polymorphic.PolymorphicID == nil {
|
||||
schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s", relation.FieldSchema, schema, field.Name, polymorphic+"ID")
|
||||
schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s",
|
||||
relation.FieldSchema, schema, field.Name, polymorphic+"ID")
|
||||
}
|
||||
|
||||
if schema.err == nil {
|
||||
@ -161,10 +230,17 @@ func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Fi
|
||||
primaryKeyField := schema.PrioritizedPrimaryField
|
||||
if len(relation.foreignKeys) > 0 {
|
||||
if primaryKeyField = schema.LookUpField(relation.foreignKeys[0]); primaryKeyField == nil || len(relation.foreignKeys) > 1 {
|
||||
schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %s", relation.foreignKeys, schema, field.Name)
|
||||
schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %s", relation.foreignKeys,
|
||||
schema, field.Name)
|
||||
}
|
||||
}
|
||||
|
||||
if primaryKeyField == nil {
|
||||
schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing primaryKey field",
|
||||
relation.FieldSchema, schema, field.Name)
|
||||
return
|
||||
}
|
||||
|
||||
// use same data type for foreign keys
|
||||
if copyableDataType(primaryKeyField.DataType) {
|
||||
relation.Polymorphic.PolymorphicID.DataType = primaryKeyField.DataType
|
||||
@ -191,7 +267,8 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
|
||||
err error
|
||||
joinTableFields []reflect.StructField
|
||||
fieldsMap = map[string]*Field{}
|
||||
ownFieldsMap = map[string]bool{} // fix self join many2many
|
||||
ownFieldsMap = map[string]*Field{} // fix self join many2many
|
||||
referFieldsMap = map[string]*Field{}
|
||||
joinForeignKeys = toColumns(field.TagSettings["JOINFOREIGNKEY"])
|
||||
joinReferences = toColumns(field.TagSettings["JOINREFERENCES"])
|
||||
)
|
||||
@ -229,21 +306,19 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
|
||||
joinFieldName = strings.Title(joinForeignKeys[idx])
|
||||
}
|
||||
|
||||
ownFieldsMap[joinFieldName] = true
|
||||
ownFieldsMap[joinFieldName] = ownField
|
||||
fieldsMap[joinFieldName] = ownField
|
||||
joinTableFields = append(joinTableFields, reflect.StructField{
|
||||
Name: joinFieldName,
|
||||
PkgPath: ownField.StructField.PkgPath,
|
||||
Type: ownField.StructField.Type,
|
||||
Tag: removeSettingFromTag(ownField.StructField.Tag, "column", "autoincrement", "index", "unique", "uniqueindex"),
|
||||
Tag: removeSettingFromTag(appendSettingFromTag(ownField.StructField.Tag, "primaryKey"),
|
||||
"column", "autoincrement", "index", "unique", "uniqueindex"),
|
||||
})
|
||||
}
|
||||
|
||||
for idx, relField := range refForeignFields {
|
||||
joinFieldName := strings.Title(relation.FieldSchema.Name) + relField.Name
|
||||
if len(joinReferences) > idx {
|
||||
joinFieldName = strings.Title(joinReferences[idx])
|
||||
}
|
||||
|
||||
if _, ok := ownFieldsMap[joinFieldName]; ok {
|
||||
if field.Name != relation.FieldSchema.Name {
|
||||
@ -253,13 +328,22 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
|
||||
}
|
||||
}
|
||||
|
||||
fieldsMap[joinFieldName] = relField
|
||||
joinTableFields = append(joinTableFields, reflect.StructField{
|
||||
Name: joinFieldName,
|
||||
PkgPath: relField.StructField.PkgPath,
|
||||
Type: relField.StructField.Type,
|
||||
Tag: removeSettingFromTag(relField.StructField.Tag, "column", "autoincrement", "index", "unique", "uniqueindex"),
|
||||
})
|
||||
if len(joinReferences) > idx {
|
||||
joinFieldName = strings.Title(joinReferences[idx])
|
||||
}
|
||||
|
||||
referFieldsMap[joinFieldName] = relField
|
||||
|
||||
if _, ok := fieldsMap[joinFieldName]; !ok {
|
||||
fieldsMap[joinFieldName] = relField
|
||||
joinTableFields = append(joinTableFields, reflect.StructField{
|
||||
Name: joinFieldName,
|
||||
PkgPath: relField.StructField.PkgPath,
|
||||
Type: relField.StructField.Type,
|
||||
Tag: removeSettingFromTag(appendSettingFromTag(relField.StructField.Tag, "primaryKey"),
|
||||
"column", "autoincrement", "index", "unique", "uniqueindex"),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
joinTableFields = append(joinTableFields, reflect.StructField{
|
||||
@ -268,7 +352,8 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
|
||||
Tag: `gorm:"-"`,
|
||||
})
|
||||
|
||||
if relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, schema.namer); err != nil {
|
||||
if relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore,
|
||||
schema.namer); err != nil {
|
||||
schema.err = err
|
||||
}
|
||||
relation.JoinTable.Name = many2many
|
||||
@ -315,31 +400,37 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
|
||||
f.Size = fieldsMap[f.Name].Size
|
||||
}
|
||||
relation.JoinTable.PrimaryFields = append(relation.JoinTable.PrimaryFields, f)
|
||||
ownPrimaryField := schema == fieldsMap[f.Name].Schema && ownFieldsMap[f.Name]
|
||||
|
||||
if ownPrimaryField {
|
||||
if of, ok := ownFieldsMap[f.Name]; ok {
|
||||
joinRel := relation.JoinTable.Relationships.Relations[relName]
|
||||
joinRel.Field = relation.Field
|
||||
joinRel.References = append(joinRel.References, &Reference{
|
||||
PrimaryKey: fieldsMap[f.Name],
|
||||
PrimaryKey: of,
|
||||
ForeignKey: f,
|
||||
})
|
||||
} else {
|
||||
|
||||
relation.References = append(relation.References, &Reference{
|
||||
PrimaryKey: of,
|
||||
ForeignKey: f,
|
||||
OwnPrimaryKey: true,
|
||||
})
|
||||
}
|
||||
|
||||
if rf, ok := referFieldsMap[f.Name]; ok {
|
||||
joinRefRel := relation.JoinTable.Relationships.Relations[relRefName]
|
||||
if joinRefRel.Field == nil {
|
||||
joinRefRel.Field = relation.Field
|
||||
}
|
||||
joinRefRel.References = append(joinRefRel.References, &Reference{
|
||||
PrimaryKey: fieldsMap[f.Name],
|
||||
PrimaryKey: rf,
|
||||
ForeignKey: f,
|
||||
})
|
||||
|
||||
relation.References = append(relation.References, &Reference{
|
||||
PrimaryKey: rf,
|
||||
ForeignKey: f,
|
||||
})
|
||||
}
|
||||
|
||||
relation.References = append(relation.References, &Reference{
|
||||
PrimaryKey: fieldsMap[f.Name],
|
||||
ForeignKey: f,
|
||||
OwnPrimaryKey: ownPrimaryField,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -381,7 +472,8 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu
|
||||
schema.guessRelation(relation, field, guessEmbeddedHas)
|
||||
// case guessEmbeddedHas:
|
||||
default:
|
||||
schema.err = fmt.Errorf("invalid field found for struct %v's field %s: define a valid foreign key for relations or implement the Valuer/Scanner interface", schema, field.Name)
|
||||
schema.err = fmt.Errorf("invalid field found for struct %v's field %s: define a valid foreign key for relations or implement the Valuer/Scanner interface",
|
||||
schema, field.Name)
|
||||
}
|
||||
}
|
||||
|
||||
@ -389,34 +481,31 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu
|
||||
case guessBelongs:
|
||||
primarySchema, foreignSchema = relation.FieldSchema, schema
|
||||
case guessEmbeddedBelongs:
|
||||
if field.OwnerSchema != nil {
|
||||
primarySchema, foreignSchema = relation.FieldSchema, field.OwnerSchema
|
||||
} else {
|
||||
if field.OwnerSchema == nil {
|
||||
reguessOrErr()
|
||||
return
|
||||
}
|
||||
primarySchema, foreignSchema = relation.FieldSchema, field.OwnerSchema
|
||||
case guessHas:
|
||||
case guessEmbeddedHas:
|
||||
if field.OwnerSchema != nil {
|
||||
primarySchema, foreignSchema = field.OwnerSchema, relation.FieldSchema
|
||||
} else {
|
||||
if field.OwnerSchema == nil {
|
||||
reguessOrErr()
|
||||
return
|
||||
}
|
||||
primarySchema, foreignSchema = field.OwnerSchema, relation.FieldSchema
|
||||
}
|
||||
|
||||
if len(relation.foreignKeys) > 0 {
|
||||
for _, foreignKey := range relation.foreignKeys {
|
||||
if f := foreignSchema.LookUpField(foreignKey); f != nil {
|
||||
foreignFields = append(foreignFields, f)
|
||||
} else {
|
||||
f := foreignSchema.LookUpField(foreignKey)
|
||||
if f == nil {
|
||||
reguessOrErr()
|
||||
return
|
||||
}
|
||||
foreignFields = append(foreignFields, f)
|
||||
}
|
||||
} else {
|
||||
var primaryFields []*Field
|
||||
var primarySchemaName = primarySchema.Name
|
||||
primarySchemaName := primarySchema.Name
|
||||
if primarySchemaName == "" {
|
||||
primarySchemaName = relation.FieldSchema.Name
|
||||
}
|
||||
@ -431,6 +520,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu
|
||||
primaryFields = primarySchema.PrimaryFields
|
||||
}
|
||||
|
||||
primaryFieldLoop:
|
||||
for _, primaryField := range primaryFields {
|
||||
lookUpName := primarySchemaName + primaryField.Name
|
||||
if gl == guessBelongs {
|
||||
@ -439,23 +529,33 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu
|
||||
|
||||
lookUpNames := []string{lookUpName}
|
||||
if len(primaryFields) == 1 {
|
||||
lookUpNames = append(lookUpNames, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID", strings.TrimSuffix(lookUpName, primaryField.Name)+"Id", schema.namer.ColumnName(foreignSchema.Table, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID"))
|
||||
lookUpNames = append(lookUpNames, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID",
|
||||
strings.TrimSuffix(lookUpName, primaryField.Name)+"Id", schema.namer.ColumnName(foreignSchema.Table,
|
||||
strings.TrimSuffix(lookUpName, primaryField.Name)+"ID"))
|
||||
}
|
||||
|
||||
for _, name := range lookUpNames {
|
||||
if f := foreignSchema.LookUpFieldByBindName(field.BindNames, name); f != nil {
|
||||
foreignFields = append(foreignFields, f)
|
||||
primaryFields = append(primaryFields, primaryField)
|
||||
continue primaryFieldLoop
|
||||
}
|
||||
}
|
||||
for _, name := range lookUpNames {
|
||||
if f := foreignSchema.LookUpField(name); f != nil {
|
||||
foreignFields = append(foreignFields, f)
|
||||
primaryFields = append(primaryFields, primaryField)
|
||||
break
|
||||
continue primaryFieldLoop
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(foreignFields) == 0 {
|
||||
switch {
|
||||
case len(foreignFields) == 0:
|
||||
reguessOrErr()
|
||||
return
|
||||
} else if len(relation.primaryKeys) > 0 {
|
||||
case len(relation.primaryKeys) > 0:
|
||||
for idx, primaryKey := range relation.primaryKeys {
|
||||
if f := primarySchema.LookUpField(primaryKey); f != nil {
|
||||
if len(primaryFields) < idx+1 {
|
||||
@ -469,7 +569,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu
|
||||
return
|
||||
}
|
||||
}
|
||||
} else if len(primaryFields) == 0 {
|
||||
case len(primaryFields) == 0:
|
||||
if len(foreignFields) == 1 && primarySchema.PrioritizedPrimaryField != nil {
|
||||
primaryFields = append(primaryFields, primarySchema.PrioritizedPrimaryField)
|
||||
} else if len(primarySchema.PrimaryFields) == len(foreignFields) {
|
||||
@ -505,6 +605,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu
|
||||
}
|
||||
}
|
||||
|
||||
// Constraint is ForeignKey Constraint
|
||||
type Constraint struct {
|
||||
Name string
|
||||
Field *Field
|
||||
@ -516,6 +617,31 @@ type Constraint struct {
|
||||
OnUpdate string
|
||||
}
|
||||
|
||||
func (constraint *Constraint) GetName() string { return constraint.Name }
|
||||
|
||||
func (constraint *Constraint) Build() (sql string, vars []interface{}) {
|
||||
sql = "CONSTRAINT ? FOREIGN KEY ? REFERENCES ??"
|
||||
if constraint.OnDelete != "" {
|
||||
sql += " ON DELETE " + constraint.OnDelete
|
||||
}
|
||||
|
||||
if constraint.OnUpdate != "" {
|
||||
sql += " ON UPDATE " + constraint.OnUpdate
|
||||
}
|
||||
|
||||
foreignKeys := make([]interface{}, 0, len(constraint.ForeignKeys))
|
||||
for _, field := range constraint.ForeignKeys {
|
||||
foreignKeys = append(foreignKeys, clause.Column{Name: field.DBName})
|
||||
}
|
||||
|
||||
references := make([]interface{}, 0, len(constraint.References))
|
||||
for _, field := range constraint.References {
|
||||
references = append(references, clause.Column{Name: field.DBName})
|
||||
}
|
||||
vars = append(vars, clause.Table{Name: constraint.Name}, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references)
|
||||
return
|
||||
}
|
||||
|
||||
func (rel *Relationship) ParseConstraint() *Constraint {
|
||||
str := rel.Field.TagSettings["CONSTRAINT"]
|
||||
if str == "-" {
|
||||
|
@ -10,7 +10,7 @@ import (
|
||||
|
||||
func checkStructRelation(t *testing.T, data interface{}, relations ...Relation) {
|
||||
if s, err := schema.Parse(data, &sync.Map{}, schema.NamingStrategy{}); err != nil {
|
||||
t.Errorf("Failed to parse schema")
|
||||
t.Errorf("Failed to parse schema, got error %v", err)
|
||||
} else {
|
||||
for _, rel := range relations {
|
||||
checkSchemaRelation(t, s, rel)
|
||||
@ -305,6 +305,33 @@ func TestMany2ManyOverrideForeignKey(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestMany2ManySharedForeignKey(t *testing.T) {
|
||||
type Profile struct {
|
||||
gorm.Model
|
||||
Name string
|
||||
Kind string
|
||||
ProfileRefer uint
|
||||
}
|
||||
|
||||
type User struct {
|
||||
gorm.Model
|
||||
Profiles []Profile `gorm:"many2many:user_profiles;foreignKey:Refer,Kind;joinForeignKey:UserRefer,Kind;References:ProfileRefer,Kind;joinReferences:ProfileR,Kind"`
|
||||
Kind string
|
||||
Refer uint
|
||||
}
|
||||
|
||||
checkStructRelation(t, &User{}, Relation{
|
||||
Name: "Profiles", Type: schema.Many2Many, Schema: "User", FieldSchema: "Profile",
|
||||
JoinTable: JoinTable{Name: "user_profiles", Table: "user_profiles"},
|
||||
References: []Reference{
|
||||
{"Refer", "User", "UserRefer", "user_profiles", "", true},
|
||||
{"Kind", "User", "Kind", "user_profiles", "", true},
|
||||
{"ProfileRefer", "Profile", "ProfileR", "user_profiles", "", false},
|
||||
{"Kind", "Profile", "Kind", "user_profiles", "", false},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func TestMany2ManyOverrideJoinForeignKey(t *testing.T) {
|
||||
type Profile struct {
|
||||
gorm.Model
|
||||
@ -491,6 +518,319 @@ func TestEmbeddedRelation(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestEmbeddedHas(t *testing.T) {
|
||||
type Toy struct {
|
||||
ID int
|
||||
Name string
|
||||
OwnerID int
|
||||
OwnerType string
|
||||
}
|
||||
type User struct {
|
||||
ID int
|
||||
Cat struct {
|
||||
Name string
|
||||
Toy Toy `gorm:"polymorphic:Owner;"`
|
||||
Toys []Toy `gorm:"polymorphic:Owner;"`
|
||||
} `gorm:"embedded;embeddedPrefix:cat_"`
|
||||
Dog struct {
|
||||
ID int
|
||||
Name string
|
||||
UserID int
|
||||
Toy Toy `gorm:"polymorphic:Owner;"`
|
||||
Toys []Toy `gorm:"polymorphic:Owner;"`
|
||||
}
|
||||
Toys []Toy `gorm:"polymorphic:Owner;"`
|
||||
}
|
||||
|
||||
s, err := schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse schema, got error %v", err)
|
||||
}
|
||||
|
||||
checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
|
||||
"Cat": {
|
||||
Relations: map[string]Relation{
|
||||
"Toy": {
|
||||
Name: "Toy",
|
||||
Type: schema.HasOne,
|
||||
Schema: "User",
|
||||
FieldSchema: "Toy",
|
||||
Polymorphic: Polymorphic{ID: "OwnerID", Type: "OwnerType", Value: "users"},
|
||||
References: []Reference{
|
||||
{ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"},
|
||||
{ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"},
|
||||
},
|
||||
},
|
||||
"Toys": {
|
||||
Name: "Toys",
|
||||
Type: schema.HasMany,
|
||||
Schema: "User",
|
||||
FieldSchema: "Toy",
|
||||
Polymorphic: Polymorphic{ID: "OwnerID", Type: "OwnerType", Value: "users"},
|
||||
References: []Reference{
|
||||
{ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"},
|
||||
{ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func TestPolymorphic(t *testing.T) {
|
||||
t.Run("has one", func(t *testing.T) {
|
||||
type Toy struct {
|
||||
ID int
|
||||
Name string
|
||||
OwnerID int
|
||||
OwnerType string
|
||||
}
|
||||
|
||||
type Cat struct {
|
||||
ID int
|
||||
Name string
|
||||
Toy Toy `gorm:"polymorphic:Owner;"`
|
||||
}
|
||||
|
||||
s, err := schema.Parse(&Cat{}, &sync.Map{}, schema.NamingStrategy{})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse schema, got error %v", err)
|
||||
}
|
||||
|
||||
checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
|
||||
"Cat": {
|
||||
Relations: map[string]Relation{
|
||||
"Toy": {
|
||||
Name: "Toy",
|
||||
Type: schema.HasOne,
|
||||
Schema: "User",
|
||||
FieldSchema: "Toy",
|
||||
Polymorphic: Polymorphic{ID: "OwnerID", Type: "OwnerType", Value: "users"},
|
||||
References: []Reference{
|
||||
{ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("has one with custom polymorphic type and id", func(t *testing.T) {
|
||||
type Toy struct {
|
||||
ID int
|
||||
Name string
|
||||
RefId int
|
||||
Type string
|
||||
}
|
||||
|
||||
type Cat struct {
|
||||
ID int
|
||||
Name string
|
||||
Toy Toy `gorm:"polymorphic:Owner;polymorphicType:Type;polymorphicId:RefId"`
|
||||
}
|
||||
|
||||
s, err := schema.Parse(&Cat{}, &sync.Map{}, schema.NamingStrategy{})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse schema, got error %v", err)
|
||||
}
|
||||
|
||||
checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
|
||||
"Cat": {
|
||||
Relations: map[string]Relation{
|
||||
"Toy": {
|
||||
Name: "Toy",
|
||||
Type: schema.HasOne,
|
||||
Schema: "User",
|
||||
FieldSchema: "Toy",
|
||||
Polymorphic: Polymorphic{ID: "ref_id", Type: "Type", Value: "users"},
|
||||
References: []Reference{
|
||||
{ForeignKey: "Type", ForeignSchema: "Toy", PrimaryValue: "users"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("has one with only polymorphic type", func(t *testing.T) {
|
||||
type Toy struct {
|
||||
ID int
|
||||
Name string
|
||||
OwnerID int
|
||||
Type string
|
||||
}
|
||||
|
||||
type Cat struct {
|
||||
ID int
|
||||
Name string
|
||||
Toy Toy `gorm:"polymorphic:Owner;polymorphicType:Type"`
|
||||
}
|
||||
|
||||
s, err := schema.Parse(&Cat{}, &sync.Map{}, schema.NamingStrategy{})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse schema, got error %v", err)
|
||||
}
|
||||
|
||||
checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
|
||||
"Cat": {
|
||||
Relations: map[string]Relation{
|
||||
"Toy": {
|
||||
Name: "Toy",
|
||||
Type: schema.HasOne,
|
||||
Schema: "User",
|
||||
FieldSchema: "Toy",
|
||||
Polymorphic: Polymorphic{ID: "owner_id", Type: "Type", Value: "users"},
|
||||
References: []Reference{
|
||||
{ForeignKey: "Type", ForeignSchema: "Toy", PrimaryValue: "users"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("has many", func(t *testing.T) {
|
||||
type Toy struct {
|
||||
ID int
|
||||
Name string
|
||||
OwnerID int
|
||||
OwnerType string
|
||||
}
|
||||
|
||||
type Cat struct {
|
||||
ID int
|
||||
Name string
|
||||
Toys []Toy `gorm:"polymorphic:Owner;"`
|
||||
}
|
||||
|
||||
s, err := schema.Parse(&Cat{}, &sync.Map{}, schema.NamingStrategy{})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse schema, got error %v", err)
|
||||
}
|
||||
|
||||
checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
|
||||
"Cat": {
|
||||
Relations: map[string]Relation{
|
||||
"Toys": {
|
||||
Name: "Toys",
|
||||
Type: schema.HasMany,
|
||||
Schema: "User",
|
||||
FieldSchema: "Toy",
|
||||
Polymorphic: Polymorphic{ID: "OwnerID", Type: "OwnerType", Value: "users"},
|
||||
References: []Reference{
|
||||
{ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("has many with custom polymorphic type and id", func(t *testing.T) {
|
||||
type Toy struct {
|
||||
ID int
|
||||
Name string
|
||||
RefId int
|
||||
Type string
|
||||
}
|
||||
|
||||
type Cat struct {
|
||||
ID int
|
||||
Name string
|
||||
Toys []Toy `gorm:"polymorphicType:Type;polymorphicId:RefId"`
|
||||
}
|
||||
|
||||
s, err := schema.Parse(&Cat{}, &sync.Map{}, schema.NamingStrategy{})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse schema, got error %v", err)
|
||||
}
|
||||
|
||||
checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
|
||||
"Cat": {
|
||||
Relations: map[string]Relation{
|
||||
"Toys": {
|
||||
Name: "Toys",
|
||||
Type: schema.HasMany,
|
||||
Schema: "User",
|
||||
FieldSchema: "Toy",
|
||||
Polymorphic: Polymorphic{ID: "ref_id", Type: "Type", Value: "users"},
|
||||
References: []Reference{
|
||||
{ForeignKey: "Type", ForeignSchema: "Toy", PrimaryValue: "users"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestEmbeddedBelongsTo(t *testing.T) {
|
||||
type Country struct {
|
||||
ID int `gorm:"primaryKey"`
|
||||
Name string
|
||||
}
|
||||
type Address struct {
|
||||
CountryID int
|
||||
Country Country
|
||||
}
|
||||
type NestedAddress struct {
|
||||
Address
|
||||
}
|
||||
type Org struct {
|
||||
ID int
|
||||
PostalAddress Address `gorm:"embedded;embeddedPrefix:postal_address_"`
|
||||
VisitingAddress Address `gorm:"embedded;embeddedPrefix:visiting_address_"`
|
||||
AddressID int
|
||||
Address struct {
|
||||
ID int
|
||||
Address
|
||||
}
|
||||
NestedAddress *NestedAddress `gorm:"embedded;embeddedPrefix:nested_address_"`
|
||||
}
|
||||
|
||||
s, err := schema.Parse(&Org{}, &sync.Map{}, schema.NamingStrategy{})
|
||||
if err != nil {
|
||||
t.Errorf("Failed to parse schema, got error %v", err)
|
||||
}
|
||||
|
||||
checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
|
||||
"PostalAddress": {
|
||||
Relations: map[string]Relation{
|
||||
"Country": {
|
||||
Name: "Country", Type: schema.BelongsTo, Schema: "Org", FieldSchema: "Country",
|
||||
References: []Reference{
|
||||
{PrimaryKey: "ID", PrimarySchema: "Country", ForeignKey: "CountryID", ForeignSchema: "Org"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"VisitingAddress": {
|
||||
Relations: map[string]Relation{
|
||||
"Country": {
|
||||
Name: "Country", Type: schema.BelongsTo, Schema: "Org", FieldSchema: "Country",
|
||||
References: []Reference{
|
||||
{PrimaryKey: "ID", PrimarySchema: "Country", ForeignKey: "CountryID", ForeignSchema: "Org"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"NestedAddress": {
|
||||
EmbeddedRelations: map[string]EmbeddedRelations{
|
||||
"Address": {
|
||||
Relations: map[string]Relation{
|
||||
"Country": {
|
||||
Name: "Country", Type: schema.BelongsTo, Schema: "Org", FieldSchema: "Country",
|
||||
References: []Reference{
|
||||
{PrimaryKey: "ID", PrimarySchema: "Country", ForeignKey: "CountryID", ForeignSchema: "Org"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func TestVariableRelation(t *testing.T) {
|
||||
var result struct {
|
||||
User
|
||||
@ -615,7 +955,7 @@ func TestParseConstraintNameWithSchemaQualifiedLongTableName(t *testing.T) {
|
||||
s, err := schema.Parse(
|
||||
&Book{},
|
||||
&sync.Map{},
|
||||
schema.NamingStrategy{},
|
||||
schema.NamingStrategy{IdentifierMaxLength: 64},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse schema")
|
||||
|
138
schema/schema.go
138
schema/schema.go
@ -6,12 +6,27 @@ import (
|
||||
"fmt"
|
||||
"go/ast"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
type callbackType string
|
||||
|
||||
const (
|
||||
callbackTypeBeforeCreate callbackType = "BeforeCreate"
|
||||
callbackTypeBeforeUpdate callbackType = "BeforeUpdate"
|
||||
callbackTypeAfterCreate callbackType = "AfterCreate"
|
||||
callbackTypeAfterUpdate callbackType = "AfterUpdate"
|
||||
callbackTypeBeforeSave callbackType = "BeforeSave"
|
||||
callbackTypeAfterSave callbackType = "AfterSave"
|
||||
callbackTypeBeforeDelete callbackType = "BeforeDelete"
|
||||
callbackTypeAfterDelete callbackType = "AfterDelete"
|
||||
callbackTypeAfterFind callbackType = "AfterFind"
|
||||
)
|
||||
|
||||
// ErrUnsupportedDataType unsupported data type
|
||||
var ErrUnsupportedDataType = errors.New("unsupported data type")
|
||||
|
||||
@ -25,6 +40,7 @@ type Schema struct {
|
||||
PrimaryFieldDBNames []string
|
||||
Fields []*Field
|
||||
FieldsByName map[string]*Field
|
||||
FieldsByBindName map[string]*Field // embedded fields is 'Embed.Field'
|
||||
FieldsByDBName map[string]*Field
|
||||
FieldsWithDefaultDBValue []*Field // fields with default value assigned by database
|
||||
Relationships Relationships
|
||||
@ -67,10 +83,35 @@ func (schema Schema) LookUpField(name string) *Field {
|
||||
return nil
|
||||
}
|
||||
|
||||
// LookUpFieldByBindName looks for the closest field in the embedded struct.
|
||||
//
|
||||
// type Struct struct {
|
||||
// Embedded struct {
|
||||
// ID string // is selected by LookUpFieldByBindName([]string{"Embedded", "ID"}, "ID")
|
||||
// }
|
||||
// ID string // is selected by LookUpFieldByBindName([]string{"ID"}, "ID")
|
||||
// }
|
||||
func (schema Schema) LookUpFieldByBindName(bindNames []string, name string) *Field {
|
||||
if len(bindNames) == 0 {
|
||||
return nil
|
||||
}
|
||||
for i := len(bindNames) - 1; i >= 0; i-- {
|
||||
find := strings.Join(bindNames[:i], ".") + "." + name
|
||||
if field, ok := schema.FieldsByBindName[find]; ok {
|
||||
return field
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type Tabler interface {
|
||||
TableName() string
|
||||
}
|
||||
|
||||
type TablerWithNamer interface {
|
||||
TableName(Namer) string
|
||||
}
|
||||
|
||||
// Parse get data type from dialector
|
||||
func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) {
|
||||
return ParseWithSpecialTableName(dest, cacheStore, namer, "")
|
||||
@ -112,7 +153,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
|
||||
schemaCacheKey = modelType
|
||||
}
|
||||
|
||||
// Load exist schmema cache, return if exists
|
||||
// Load exist schema cache, return if exists
|
||||
if v, ok := cacheStore.Load(schemaCacheKey); ok {
|
||||
s := v.(*Schema)
|
||||
// Wait for the initialization of other goroutines to complete
|
||||
@ -125,6 +166,9 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
|
||||
if tabler, ok := modelValue.Interface().(Tabler); ok {
|
||||
tableName = tabler.TableName()
|
||||
}
|
||||
if tabler, ok := modelValue.Interface().(TablerWithNamer); ok {
|
||||
tableName = tabler.TableName(namer)
|
||||
}
|
||||
if en, ok := namer.(embeddedNamer); ok {
|
||||
tableName = en.Table
|
||||
}
|
||||
@ -133,20 +177,21 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
|
||||
}
|
||||
|
||||
schema := &Schema{
|
||||
Name: modelType.Name(),
|
||||
ModelType: modelType,
|
||||
Table: tableName,
|
||||
FieldsByName: map[string]*Field{},
|
||||
FieldsByDBName: map[string]*Field{},
|
||||
Relationships: Relationships{Relations: map[string]*Relationship{}},
|
||||
cacheStore: cacheStore,
|
||||
namer: namer,
|
||||
initialized: make(chan struct{}),
|
||||
Name: modelType.Name(),
|
||||
ModelType: modelType,
|
||||
Table: tableName,
|
||||
FieldsByName: map[string]*Field{},
|
||||
FieldsByBindName: map[string]*Field{},
|
||||
FieldsByDBName: map[string]*Field{},
|
||||
Relationships: Relationships{Relations: map[string]*Relationship{}},
|
||||
cacheStore: cacheStore,
|
||||
namer: namer,
|
||||
initialized: make(chan struct{}),
|
||||
}
|
||||
// When the schema initialization is completed, the channel will be closed
|
||||
defer close(schema.initialized)
|
||||
|
||||
// Load exist schmema cache, return if exists
|
||||
// Load exist schema cache, return if exists
|
||||
if v, ok := cacheStore.Load(schemaCacheKey); ok {
|
||||
s := v.(*Schema)
|
||||
// Wait for the initialization of other goroutines to complete
|
||||
@ -169,6 +214,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
|
||||
field.DBName = namer.ColumnName(schema.Table, field.Name)
|
||||
}
|
||||
|
||||
bindName := field.BindName()
|
||||
if field.DBName != "" {
|
||||
// nonexistence or shortest path or first appear prioritized if has permission
|
||||
if v, ok := schema.FieldsByDBName[field.DBName]; !ok || ((field.Creatable || field.Updatable || field.Readable) && len(field.BindNames) < len(v.BindNames)) {
|
||||
@ -177,6 +223,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
|
||||
}
|
||||
schema.FieldsByDBName[field.DBName] = field
|
||||
schema.FieldsByName[field.Name] = field
|
||||
schema.FieldsByBindName[bindName] = field
|
||||
|
||||
if v != nil && v.PrimaryKey {
|
||||
for idx, f := range schema.PrimaryFields {
|
||||
@ -195,6 +242,9 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
|
||||
if of, ok := schema.FieldsByName[field.Name]; !ok || of.TagSettings["-"] == "-" {
|
||||
schema.FieldsByName[field.Name] = field
|
||||
}
|
||||
if of, ok := schema.FieldsByBindName[bindName]; !ok || of.TagSettings["-"] == "-" {
|
||||
schema.FieldsByBindName[bindName] = field
|
||||
}
|
||||
|
||||
field.setupValuerAndSetter()
|
||||
}
|
||||
@ -214,8 +264,18 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
|
||||
}
|
||||
}
|
||||
|
||||
if schema.PrioritizedPrimaryField == nil && len(schema.PrimaryFields) == 1 {
|
||||
schema.PrioritizedPrimaryField = schema.PrimaryFields[0]
|
||||
if schema.PrioritizedPrimaryField == nil {
|
||||
if len(schema.PrimaryFields) == 1 {
|
||||
schema.PrioritizedPrimaryField = schema.PrimaryFields[0]
|
||||
} else if len(schema.PrimaryFields) > 1 {
|
||||
// If there are multiple primary keys, the AUTOINCREMENT field is prioritized
|
||||
for _, field := range schema.PrimaryFields {
|
||||
if field.AutoIncrement {
|
||||
schema.PrioritizedPrimaryField = field
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, field := range schema.PrimaryFields {
|
||||
@ -223,7 +283,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
|
||||
}
|
||||
|
||||
for _, field := range schema.Fields {
|
||||
if field.HasDefaultValue && field.DefaultValueInterface == nil {
|
||||
if field.DataType != "" && field.HasDefaultValue && field.DefaultValueInterface == nil {
|
||||
schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field)
|
||||
}
|
||||
}
|
||||
@ -242,14 +302,20 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
|
||||
}
|
||||
}
|
||||
|
||||
callbacks := []string{"BeforeCreate", "AfterCreate", "BeforeUpdate", "AfterUpdate", "BeforeSave", "AfterSave", "BeforeDelete", "AfterDelete", "AfterFind"}
|
||||
for _, name := range callbacks {
|
||||
if methodValue := modelValue.MethodByName(name); methodValue.IsValid() {
|
||||
callbackTypes := []callbackType{
|
||||
callbackTypeBeforeCreate, callbackTypeAfterCreate,
|
||||
callbackTypeBeforeUpdate, callbackTypeAfterUpdate,
|
||||
callbackTypeBeforeSave, callbackTypeAfterSave,
|
||||
callbackTypeBeforeDelete, callbackTypeAfterDelete,
|
||||
callbackTypeAfterFind,
|
||||
}
|
||||
for _, cbName := range callbackTypes {
|
||||
if methodValue := callBackToMethodValue(modelValue, cbName); methodValue.IsValid() {
|
||||
switch methodValue.Type().String() {
|
||||
case "func(*gorm.DB) error": // TODO hack
|
||||
reflect.Indirect(reflect.ValueOf(schema)).FieldByName(name).SetBool(true)
|
||||
reflect.Indirect(reflect.ValueOf(schema)).FieldByName(string(cbName)).SetBool(true)
|
||||
default:
|
||||
logger.Default.Warn(context.Background(), "Model %v don't match %vInterface, should be `%v(*gorm.DB) error`. Please see https://gorm.io/docs/hooks.html", schema, name, name)
|
||||
logger.Default.Warn(context.Background(), "Model %v don't match %vInterface, should be `%v(*gorm.DB) error`. Please see https://gorm.io/docs/hooks.html", schema, cbName, cbName)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -276,6 +342,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
|
||||
return schema, schema.err
|
||||
} else {
|
||||
schema.FieldsByName[field.Name] = field
|
||||
schema.FieldsByBindName[field.BindName()] = field
|
||||
}
|
||||
}
|
||||
|
||||
@ -302,6 +369,39 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
|
||||
return schema, schema.err
|
||||
}
|
||||
|
||||
// This unrolling is needed to show to the compiler the exact set of methods
|
||||
// that can be used on the modelType.
|
||||
// Prior to go1.22 any use of MethodByName would cause the linker to
|
||||
// abandon dead code elimination for the entire binary.
|
||||
// As of go1.22 the compiler supports one special case of a string constant
|
||||
// being passed to MethodByName. For enterprise customers or those building
|
||||
// large binaries, this gives a significant reduction in binary size.
|
||||
// https://github.com/golang/go/issues/62257
|
||||
func callBackToMethodValue(modelType reflect.Value, cbType callbackType) reflect.Value {
|
||||
switch cbType {
|
||||
case callbackTypeBeforeCreate:
|
||||
return modelType.MethodByName(string(callbackTypeBeforeCreate))
|
||||
case callbackTypeAfterCreate:
|
||||
return modelType.MethodByName(string(callbackTypeAfterCreate))
|
||||
case callbackTypeBeforeUpdate:
|
||||
return modelType.MethodByName(string(callbackTypeBeforeUpdate))
|
||||
case callbackTypeAfterUpdate:
|
||||
return modelType.MethodByName(string(callbackTypeAfterUpdate))
|
||||
case callbackTypeBeforeSave:
|
||||
return modelType.MethodByName(string(callbackTypeBeforeSave))
|
||||
case callbackTypeAfterSave:
|
||||
return modelType.MethodByName(string(callbackTypeAfterSave))
|
||||
case callbackTypeBeforeDelete:
|
||||
return modelType.MethodByName(string(callbackTypeBeforeDelete))
|
||||
case callbackTypeAfterDelete:
|
||||
return modelType.MethodByName(string(callbackTypeAfterDelete))
|
||||
case callbackTypeAfterFind:
|
||||
return modelType.MethodByName(string(callbackTypeAfterFind))
|
||||
default:
|
||||
return reflect.ValueOf(nil)
|
||||
}
|
||||
}
|
||||
|
||||
func getOrParse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) {
|
||||
modelType := reflect.ValueOf(dest).Type()
|
||||
for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr {
|
||||
|
@ -163,8 +163,8 @@ func checkSchemaRelation(t *testing.T, s *schema.Schema, relation Relation) {
|
||||
t.Errorf("schema %v relation's join table tablename expects %v, but got %v", s, relation.JoinTable.Table, r.JoinTable.Table)
|
||||
}
|
||||
|
||||
for _, f := range relation.JoinTable.Fields {
|
||||
checkSchemaField(t, r.JoinTable, &f, nil)
|
||||
for i := range relation.JoinTable.Fields {
|
||||
checkSchemaField(t, r.JoinTable, &relation.JoinTable.Fields[i], nil)
|
||||
}
|
||||
}
|
||||
|
||||
@ -201,6 +201,37 @@ func checkSchemaRelation(t *testing.T, s *schema.Schema, relation Relation) {
|
||||
})
|
||||
}
|
||||
|
||||
type EmbeddedRelations struct {
|
||||
Relations map[string]Relation
|
||||
EmbeddedRelations map[string]EmbeddedRelations
|
||||
}
|
||||
|
||||
func checkEmbeddedRelations(t *testing.T, actual map[string]*schema.Relationships, expected map[string]EmbeddedRelations) {
|
||||
for name, relations := range actual {
|
||||
rs := expected[name]
|
||||
t.Run("CheckEmbeddedRelations/"+name, func(t *testing.T) {
|
||||
if len(relations.Relations) != len(rs.Relations) {
|
||||
t.Errorf("schema relations count don't match, expects %d, got %d", len(rs.Relations), len(relations.Relations))
|
||||
}
|
||||
if len(relations.EmbeddedRelations) != len(rs.EmbeddedRelations) {
|
||||
t.Errorf("schema embedded relations count don't match, expects %d, got %d", len(rs.EmbeddedRelations), len(relations.EmbeddedRelations))
|
||||
}
|
||||
for n, rel := range relations.Relations {
|
||||
if r, ok := rs.Relations[n]; !ok {
|
||||
t.Errorf("failed to find relation by name %s", n)
|
||||
} else {
|
||||
checkSchemaRelation(t, &schema.Schema{
|
||||
Relationships: schema.Relationships{
|
||||
Relations: map[string]*schema.Relationship{n: rel},
|
||||
},
|
||||
}, r)
|
||||
}
|
||||
}
|
||||
checkEmbeddedRelations(t, relations.EmbeddedRelations, rs.EmbeddedRelations)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func checkField(t *testing.T, s *schema.Schema, value reflect.Value, values map[string]interface{}) {
|
||||
for k, v := range values {
|
||||
t.Run("CheckField/"+k, func(t *testing.T) {
|
||||
|
@ -46,8 +46,8 @@ func checkUserSchema(t *testing.T, user *schema.Schema) {
|
||||
{Name: "Active", DBName: "active", BindNames: []string{"Active"}, DataType: schema.Bool},
|
||||
}
|
||||
|
||||
for _, f := range fields {
|
||||
checkSchemaField(t, user, &f, func(f *schema.Field) {
|
||||
for i := range fields {
|
||||
checkSchemaField(t, user, &fields[i], func(f *schema.Field) {
|
||||
f.Creatable = true
|
||||
f.Updatable = true
|
||||
f.Readable = true
|
||||
@ -136,8 +136,8 @@ func TestParseSchemaWithAdvancedDataType(t *testing.T) {
|
||||
{Name: "Admin", DBName: "admin", BindNames: []string{"Admin"}, DataType: schema.Bool},
|
||||
}
|
||||
|
||||
for _, f := range fields {
|
||||
checkSchemaField(t, user, &f, func(f *schema.Field) {
|
||||
for i := range fields {
|
||||
checkSchemaField(t, user, &fields[i], func(f *schema.Field) {
|
||||
f.Creatable = true
|
||||
f.Updatable = true
|
||||
f.Readable = true
|
||||
@ -293,3 +293,44 @@ func TestEmbeddedStructForCustomizedNamingStrategy(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompositePrimaryKeyWithAutoIncrement(t *testing.T) {
|
||||
type Product struct {
|
||||
ProductID uint `gorm:"primaryKey;autoIncrement"`
|
||||
LanguageCode uint `gorm:"primaryKey"`
|
||||
Code string
|
||||
Name string
|
||||
}
|
||||
type ProductNonAutoIncrement struct {
|
||||
ProductID uint `gorm:"primaryKey;autoIncrement:false"`
|
||||
LanguageCode uint `gorm:"primaryKey"`
|
||||
Code string
|
||||
Name string
|
||||
}
|
||||
|
||||
product, err := schema.Parse(&Product{}, &sync.Map{}, schema.NamingStrategy{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse product struct with composite primary key, got error %v", err)
|
||||
}
|
||||
|
||||
prioritizedPrimaryField := schema.Field{
|
||||
Name: "ProductID", DBName: "product_id", BindNames: []string{"ProductID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, HasDefaultValue: true, AutoIncrement: true, TagSettings: map[string]string{"PRIMARYKEY": "PRIMARYKEY", "AUTOINCREMENT": "AUTOINCREMENT"},
|
||||
}
|
||||
|
||||
product.Fields = []*schema.Field{product.PrioritizedPrimaryField}
|
||||
|
||||
checkSchemaField(t, product, &prioritizedPrimaryField, func(f *schema.Field) {
|
||||
f.Creatable = true
|
||||
f.Updatable = true
|
||||
f.Readable = true
|
||||
})
|
||||
|
||||
productNonAutoIncrement, err := schema.Parse(&ProductNonAutoIncrement{}, &sync.Map{}, schema.NamingStrategy{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse productNonAutoIncrement struct with composite primary key, got error %v", err)
|
||||
}
|
||||
|
||||
if productNonAutoIncrement.PrioritizedPrimaryField != nil {
|
||||
t.Fatalf("PrioritizedPrimaryField of non autoincrement composite key should be nil")
|
||||
}
|
||||
}
|
||||
|
@ -70,8 +70,7 @@ type SerializerValuerInterface interface {
|
||||
}
|
||||
|
||||
// JSONSerializer json serializer
|
||||
type JSONSerializer struct {
|
||||
}
|
||||
type JSONSerializer struct{}
|
||||
|
||||
// Scan implements serializer interface
|
||||
func (JSONSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) (err error) {
|
||||
@ -88,7 +87,9 @@ func (JSONSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value,
|
||||
return fmt.Errorf("failed to unmarshal JSONB value: %#v", dbValue)
|
||||
}
|
||||
|
||||
err = json.Unmarshal(bytes, fieldValue.Interface())
|
||||
if len(bytes) > 0 {
|
||||
err = json.Unmarshal(bytes, fieldValue.Interface())
|
||||
}
|
||||
}
|
||||
|
||||
field.ReflectValueOf(ctx, dst).Set(fieldValue.Elem())
|
||||
@ -98,12 +99,17 @@ func (JSONSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value,
|
||||
// Value implements serializer interface
|
||||
func (JSONSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) {
|
||||
result, err := json.Marshal(fieldValue)
|
||||
if string(result) == "null" {
|
||||
if field.TagSettings["NOT NULL"] != "" {
|
||||
return "", nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return string(result), err
|
||||
}
|
||||
|
||||
// UnixSecondSerializer json serializer
|
||||
type UnixSecondSerializer struct {
|
||||
}
|
||||
type UnixSecondSerializer struct{}
|
||||
|
||||
// Scan implements serializer interface
|
||||
func (UnixSecondSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) (err error) {
|
||||
@ -117,9 +123,15 @@ func (UnixSecondSerializer) Scan(ctx context.Context, field *Field, dst reflect.
|
||||
|
||||
// Value implements serializer interface
|
||||
func (UnixSecondSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (result interface{}, err error) {
|
||||
rv := reflect.ValueOf(fieldValue)
|
||||
switch v := fieldValue.(type) {
|
||||
case int64, int, uint, uint64, int32, uint32, int16, uint16, *int64, *int, *uint, *uint64, *int32, *uint32, *int16, *uint16:
|
||||
result = time.Unix(reflect.Indirect(reflect.ValueOf(v)).Int(), 0)
|
||||
case int64, int, uint, uint64, int32, uint32, int16, uint16:
|
||||
result = time.Unix(reflect.Indirect(rv).Int(), 0).UTC()
|
||||
case *int64, *int, *uint, *uint64, *int32, *uint32, *int16, *uint16:
|
||||
if rv.IsZero() {
|
||||
return nil, nil
|
||||
}
|
||||
result = time.Unix(reflect.Indirect(rv).Int(), 0).UTC()
|
||||
default:
|
||||
err = fmt.Errorf("invalid field type %#v for UnixSecondSerializer, only int, uint supported", v)
|
||||
}
|
||||
@ -127,8 +139,7 @@ func (UnixSecondSerializer) Value(ctx context.Context, field *Field, dst reflect
|
||||
}
|
||||
|
||||
// GobSerializer gob serializer
|
||||
type GobSerializer struct {
|
||||
}
|
||||
type GobSerializer struct{}
|
||||
|
||||
// Scan implements serializer interface
|
||||
func (GobSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) (err error) {
|
||||
@ -142,8 +153,10 @@ func (GobSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value,
|
||||
default:
|
||||
return fmt.Errorf("failed to unmarshal gob value: %#v", dbValue)
|
||||
}
|
||||
decoder := gob.NewDecoder(bytes.NewBuffer(bytesValue))
|
||||
err = decoder.Decode(fieldValue.Interface())
|
||||
if len(bytesValue) > 0 {
|
||||
decoder := gob.NewDecoder(bytes.NewBuffer(bytesValue))
|
||||
err = decoder.Decode(fieldValue.Interface())
|
||||
}
|
||||
}
|
||||
field.ReflectValueOf(ctx, dst).Set(fieldValue.Elem())
|
||||
return
|
||||
|
@ -2,6 +2,7 @@ package schema
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strings"
|
||||
@ -59,6 +60,14 @@ func removeSettingFromTag(tag reflect.StructTag, names ...string) reflect.Struct
|
||||
return tag
|
||||
}
|
||||
|
||||
func appendSettingFromTag(tag reflect.StructTag, value string) reflect.StructTag {
|
||||
t := tag.Get("gorm")
|
||||
if strings.Contains(t, value) {
|
||||
return tag
|
||||
}
|
||||
return reflect.StructTag(fmt.Sprintf(`gorm:"%s;%s"`, value, t))
|
||||
}
|
||||
|
||||
// GetRelationsValues get relations's values from a reflect value
|
||||
func GetRelationsValues(ctx context.Context, reflectValue reflect.Value, rels []*Relationship) (reflectResults reflect.Value) {
|
||||
for _, rel := range rels {
|
||||
@ -106,6 +115,11 @@ func GetIdentityFieldValuesMap(ctx context.Context, reflectValue reflect.Value,
|
||||
notZero, zero bool
|
||||
)
|
||||
|
||||
if reflectValue.Kind() == reflect.Ptr ||
|
||||
reflectValue.Kind() == reflect.Interface {
|
||||
reflectValue = reflectValue.Elem()
|
||||
}
|
||||
|
||||
switch reflectValue.Kind() {
|
||||
case reflect.Struct:
|
||||
results = [][]interface{}{make([]interface{}, len(fields))}
|
||||
@ -124,7 +138,7 @@ func GetIdentityFieldValuesMap(ctx context.Context, reflectValue reflect.Value,
|
||||
for i := 0; i < reflectValue.Len(); i++ {
|
||||
elem := reflectValue.Index(i)
|
||||
elemKey := elem.Interface()
|
||||
if elem.Kind() != reflect.Ptr {
|
||||
if elem.Kind() != reflect.Ptr && elem.CanAddr() {
|
||||
elemKey = elem.Addr().Interface()
|
||||
}
|
||||
|
||||
|
@ -6,6 +6,7 @@ import (
|
||||
"encoding/json"
|
||||
"reflect"
|
||||
|
||||
"github.com/jinzhu/now"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/schema"
|
||||
)
|
||||
@ -45,11 +46,21 @@ func (n *DeletedAt) UnmarshalJSON(b []byte) error {
|
||||
}
|
||||
|
||||
func (DeletedAt) QueryClauses(f *schema.Field) []clause.Interface {
|
||||
return []clause.Interface{SoftDeleteQueryClause{Field: f}}
|
||||
return []clause.Interface{SoftDeleteQueryClause{Field: f, ZeroValue: parseZeroValueTag(f)}}
|
||||
}
|
||||
|
||||
func parseZeroValueTag(f *schema.Field) sql.NullString {
|
||||
if v, ok := f.TagSettings["ZEROVALUE"]; ok {
|
||||
if _, err := now.Parse(v); err == nil {
|
||||
return sql.NullString{String: v, Valid: true}
|
||||
}
|
||||
}
|
||||
return sql.NullString{Valid: false}
|
||||
}
|
||||
|
||||
type SoftDeleteQueryClause struct {
|
||||
Field *schema.Field
|
||||
ZeroValue sql.NullString
|
||||
Field *schema.Field
|
||||
}
|
||||
|
||||
func (sd SoftDeleteQueryClause) Name() string {
|
||||
@ -78,18 +89,19 @@ func (sd SoftDeleteQueryClause) ModifyStatement(stmt *Statement) {
|
||||
}
|
||||
|
||||
stmt.AddClause(clause.Where{Exprs: []clause.Expression{
|
||||
clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: sd.Field.DBName}, Value: nil},
|
||||
clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: sd.Field.DBName}, Value: sd.ZeroValue},
|
||||
}})
|
||||
stmt.Clauses["soft_delete_enabled"] = clause.Clause{}
|
||||
}
|
||||
}
|
||||
|
||||
func (DeletedAt) UpdateClauses(f *schema.Field) []clause.Interface {
|
||||
return []clause.Interface{SoftDeleteUpdateClause{Field: f}}
|
||||
return []clause.Interface{SoftDeleteUpdateClause{Field: f, ZeroValue: parseZeroValueTag(f)}}
|
||||
}
|
||||
|
||||
type SoftDeleteUpdateClause struct {
|
||||
Field *schema.Field
|
||||
ZeroValue sql.NullString
|
||||
Field *schema.Field
|
||||
}
|
||||
|
||||
func (sd SoftDeleteUpdateClause) Name() string {
|
||||
@ -109,11 +121,12 @@ func (sd SoftDeleteUpdateClause) ModifyStatement(stmt *Statement) {
|
||||
}
|
||||
|
||||
func (DeletedAt) DeleteClauses(f *schema.Field) []clause.Interface {
|
||||
return []clause.Interface{SoftDeleteDeleteClause{Field: f}}
|
||||
return []clause.Interface{SoftDeleteDeleteClause{Field: f, ZeroValue: parseZeroValueTag(f)}}
|
||||
}
|
||||
|
||||
type SoftDeleteDeleteClause struct {
|
||||
Field *schema.Field
|
||||
ZeroValue sql.NullString
|
||||
Field *schema.Field
|
||||
}
|
||||
|
||||
func (sd SoftDeleteDeleteClause) Name() string {
|
||||
|
101
statement.go
101
statement.go
@ -49,9 +49,12 @@ type Statement struct {
|
||||
}
|
||||
|
||||
type join struct {
|
||||
Name string
|
||||
Conds []interface{}
|
||||
On *clause.Where
|
||||
Name string
|
||||
Conds []interface{}
|
||||
On *clause.Where
|
||||
Selects []string
|
||||
Omits []string
|
||||
JoinType clause.JoinType
|
||||
}
|
||||
|
||||
// StatementModifier statement modifier interface
|
||||
@ -117,6 +120,8 @@ func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) {
|
||||
write(v.Raw, stmt.Schema.PrioritizedPrimaryField.DBName)
|
||||
} else if len(stmt.Schema.DBNames) > 0 {
|
||||
write(v.Raw, stmt.Schema.DBNames[0])
|
||||
} else {
|
||||
stmt.DB.AddError(ErrModelAccessibleFieldsRequired) //nolint:typecheck,errcheck
|
||||
}
|
||||
} else {
|
||||
write(v.Raw, v.Name)
|
||||
@ -179,6 +184,10 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) {
|
||||
} else {
|
||||
stmt.AddVar(writer, v.GormValue(stmt.Context, stmt.DB))
|
||||
}
|
||||
case clause.Interface:
|
||||
c := clause.Clause{Name: v.Name()}
|
||||
v.MergeClause(&c)
|
||||
c.Build(stmt)
|
||||
case clause.Expression:
|
||||
v.Build(stmt)
|
||||
case driver.Valuer:
|
||||
@ -304,6 +313,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
|
||||
conds := make([]clause.Expression, 0, 4)
|
||||
args = append([]interface{}{query}, args...)
|
||||
for idx, arg := range args {
|
||||
if arg == nil {
|
||||
continue
|
||||
}
|
||||
if valuer, ok := arg.(driver.Valuer); ok {
|
||||
arg, _ = valuer.Value()
|
||||
}
|
||||
@ -312,9 +324,7 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
|
||||
case clause.Expression:
|
||||
conds = append(conds, v)
|
||||
case *DB:
|
||||
for _, scope := range v.Statement.scopes {
|
||||
v = scope(v)
|
||||
}
|
||||
v.executeScopes()
|
||||
|
||||
if cs, ok := v.Statement.Clauses["WHERE"]; ok {
|
||||
if where, ok := cs.Expression.(clause.Where); ok {
|
||||
@ -437,8 +447,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
|
||||
|
||||
if len(values) > 0 {
|
||||
conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: values})
|
||||
return []clause.Expression{clause.And(conds...)}
|
||||
}
|
||||
return conds
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
@ -447,7 +458,10 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
|
||||
}
|
||||
}
|
||||
|
||||
return conds
|
||||
if len(conds) > 0 {
|
||||
return []clause.Expression{clause.And(conds...)}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Build build sql with clauses names
|
||||
@ -540,8 +554,9 @@ func (stmt *Statement) clone() *Statement {
|
||||
}
|
||||
|
||||
// SetColumn set column's value
|
||||
// stmt.SetColumn("Name", "jinzhu") // Hooks Method
|
||||
// stmt.SetColumn("Name", "jinzhu", true) // Callbacks Method
|
||||
//
|
||||
// stmt.SetColumn("Name", "jinzhu") // Hooks Method
|
||||
// stmt.SetColumn("Name", "jinzhu", true) // Callbacks Method
|
||||
func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks ...bool) {
|
||||
if v, ok := stmt.Dest.(map[string]interface{}); ok {
|
||||
v[name] = value
|
||||
@ -650,54 +665,62 @@ func (stmt *Statement) Changed(fields ...string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
var nameMatcher = regexp.MustCompile(`^[\W]?(?:[a-z_]+?)[\W]?\.[\W]?([a-z_]+?)[\W]?$`)
|
||||
var matchName = func() func(tableColumn string) (table, column string) {
|
||||
nameMatcher := regexp.MustCompile(`^(?:\W?(\w+?)\W?\.)?(?:(\*)|\W?(\w+?)\W?)$`)
|
||||
return func(tableColumn string) (table, column string) {
|
||||
if matches := nameMatcher.FindStringSubmatch(tableColumn); len(matches) == 4 {
|
||||
table = matches[1]
|
||||
star := matches[2]
|
||||
columnName := matches[3]
|
||||
if star != "" {
|
||||
return table, star
|
||||
}
|
||||
return table, columnName
|
||||
}
|
||||
return "", ""
|
||||
}
|
||||
}()
|
||||
|
||||
// SelectAndOmitColumns get select and omit columns, select -> true, omit -> false
|
||||
func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (map[string]bool, bool) {
|
||||
results := map[string]bool{}
|
||||
notRestricted := false
|
||||
|
||||
// select columns
|
||||
for _, column := range stmt.Selects {
|
||||
processColumn := func(column string, result bool) {
|
||||
if stmt.Schema == nil {
|
||||
results[column] = true
|
||||
results[column] = result
|
||||
} else if column == "*" {
|
||||
notRestricted = true
|
||||
notRestricted = result
|
||||
for _, dbName := range stmt.Schema.DBNames {
|
||||
results[dbName] = true
|
||||
results[dbName] = result
|
||||
}
|
||||
} else if column == clause.Associations {
|
||||
for _, rel := range stmt.Schema.Relationships.Relations {
|
||||
results[rel.Name] = true
|
||||
results[rel.Name] = result
|
||||
}
|
||||
} else if field := stmt.Schema.LookUpField(column); field != nil && field.DBName != "" {
|
||||
results[field.DBName] = true
|
||||
} else if matches := nameMatcher.FindStringSubmatch(column); len(matches) == 2 {
|
||||
results[matches[1]] = true
|
||||
results[field.DBName] = result
|
||||
} else if table, col := matchName(column); col != "" && (table == stmt.Table || table == "") {
|
||||
if col == "*" {
|
||||
for _, dbName := range stmt.Schema.DBNames {
|
||||
results[dbName] = result
|
||||
}
|
||||
} else {
|
||||
results[col] = result
|
||||
}
|
||||
} else {
|
||||
results[column] = true
|
||||
results[column] = result
|
||||
}
|
||||
}
|
||||
|
||||
// select columns
|
||||
for _, column := range stmt.Selects {
|
||||
processColumn(column, true)
|
||||
}
|
||||
|
||||
// omit columns
|
||||
for _, omit := range stmt.Omits {
|
||||
if stmt.Schema == nil {
|
||||
results[omit] = false
|
||||
} else if omit == "*" {
|
||||
for _, dbName := range stmt.Schema.DBNames {
|
||||
results[dbName] = false
|
||||
}
|
||||
} else if omit == clause.Associations {
|
||||
for _, rel := range stmt.Schema.Relationships.Relations {
|
||||
results[rel.Name] = false
|
||||
}
|
||||
} else if field := stmt.Schema.LookUpField(omit); field != nil && field.DBName != "" {
|
||||
results[field.DBName] = false
|
||||
} else if matches := nameMatcher.FindStringSubmatch(omit); len(matches) == 2 {
|
||||
results[matches[1]] = false
|
||||
} else {
|
||||
results[omit] = false
|
||||
}
|
||||
for _, column := range stmt.Omits {
|
||||
processColumn(column, false)
|
||||
}
|
||||
|
||||
if stmt.Schema != nil {
|
||||
|
@ -35,15 +35,36 @@ func TestWhereCloneCorruption(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestNilCondition(t *testing.T) {
|
||||
s := new(Statement)
|
||||
if len(s.BuildCondition(nil)) != 0 {
|
||||
t.Errorf("Nil condition should be empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNameMatcher(t *testing.T) {
|
||||
for k, v := range map[string]string{
|
||||
"table.name": "name",
|
||||
"`table`.`name`": "name",
|
||||
"'table'.'name'": "name",
|
||||
"'table'.name": "name",
|
||||
for k, v := range map[string][]string{
|
||||
"table.name": {"table", "name"},
|
||||
"`table`.`name`": {"table", "name"},
|
||||
"'table'.'name'": {"table", "name"},
|
||||
"'table'.name": {"table", "name"},
|
||||
"table1.name_23": {"table1", "name_23"},
|
||||
"`table_1`.`name23`": {"table_1", "name23"},
|
||||
"'table23'.'name_1'": {"table23", "name_1"},
|
||||
"'table23'.name1": {"table23", "name1"},
|
||||
"'name1'": {"", "name1"},
|
||||
"`name_1`": {"", "name_1"},
|
||||
"`Name_1`": {"", "Name_1"},
|
||||
"`Table`.`nAme`": {"Table", "nAme"},
|
||||
"my_table.*": {"my_table", "*"},
|
||||
"`my_table`.*": {"my_table", "*"},
|
||||
"User__Company.*": {"User__Company", "*"},
|
||||
"`User__Company`.*": {"User__Company", "*"},
|
||||
`"User__Company".*`: {"User__Company", "*"},
|
||||
`"table"."*"`: {"", ""},
|
||||
} {
|
||||
if matches := nameMatcher.FindStringSubmatch(k); len(matches) < 2 || matches[1] != v {
|
||||
t.Errorf("failed to match value: %v, got %v, expect: %v", k, matches, v)
|
||||
if table, column := matchName(k); table != v[0] || column != v[1] {
|
||||
t.Errorf("failed to match value: %v, got %v, expect: %v", k, []string{table, column}, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -3,6 +3,7 @@ package tests_test
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"gorm.io/gorm"
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
@ -137,6 +138,7 @@ func TestBelongsToAssociation(t *testing.T) {
|
||||
unexistCompanyID := company.ID + 9999999
|
||||
user = User{Name: "invalid-user-with-invalid-belongs-to-foreign-key", CompanyID: &unexistCompanyID}
|
||||
if err := DB.Create(&user).Error; err == nil {
|
||||
tidbSkip(t, "not support the foreign key feature")
|
||||
t.Errorf("should have gotten foreign key violation error")
|
||||
}
|
||||
}
|
||||
@ -224,3 +226,81 @@ func TestBelongsToAssociationForSlice(t *testing.T) {
|
||||
AssertAssociationCount(t, users[0], "Company", 0, "After Delete")
|
||||
AssertAssociationCount(t, users[1], "Company", 1, "After other user Delete")
|
||||
}
|
||||
|
||||
func TestBelongsToDefaultValue(t *testing.T) {
|
||||
type Org struct {
|
||||
ID string
|
||||
}
|
||||
type BelongsToUser struct {
|
||||
OrgID string
|
||||
Org Org `gorm:"default:NULL"`
|
||||
}
|
||||
|
||||
tx := DB.Session(&gorm.Session{})
|
||||
tx.Config.DisableForeignKeyConstraintWhenMigrating = true
|
||||
AssertEqual(t, DB.Config.DisableForeignKeyConstraintWhenMigrating, false)
|
||||
|
||||
tx.Migrator().DropTable(&BelongsToUser{}, &Org{})
|
||||
tx.AutoMigrate(&BelongsToUser{}, &Org{})
|
||||
|
||||
user := &BelongsToUser{
|
||||
Org: Org{
|
||||
ID: "BelongsToUser_Org_1",
|
||||
},
|
||||
}
|
||||
err := DB.Create(&user).Error
|
||||
AssertEqual(t, err, nil)
|
||||
}
|
||||
|
||||
func TestBelongsToAssociationUnscoped(t *testing.T) {
|
||||
type ItemParent struct {
|
||||
gorm.Model
|
||||
Logo string `gorm:"not null;type:varchar(50)"`
|
||||
}
|
||||
type ItemChild struct {
|
||||
gorm.Model
|
||||
Name string `gorm:"type:varchar(50)"`
|
||||
ItemParentID uint
|
||||
ItemParent ItemParent
|
||||
}
|
||||
|
||||
tx := DB.Session(&gorm.Session{})
|
||||
tx.Migrator().DropTable(&ItemParent{}, &ItemChild{})
|
||||
tx.AutoMigrate(&ItemParent{}, &ItemChild{})
|
||||
|
||||
item := ItemChild{
|
||||
Name: "name",
|
||||
ItemParent: ItemParent{
|
||||
Logo: "logo",
|
||||
},
|
||||
}
|
||||
if err := tx.Create(&item).Error; err != nil {
|
||||
t.Fatalf("failed to create items, got error: %v", err)
|
||||
}
|
||||
|
||||
// test replace
|
||||
if err := tx.Model(&item).Association("ItemParent").Unscoped().Replace(&ItemParent{
|
||||
Logo: "updated logo",
|
||||
}); err != nil {
|
||||
t.Errorf("failed to replace item parent, got error: %v", err)
|
||||
}
|
||||
|
||||
var parents []ItemParent
|
||||
if err := tx.Find(&parents).Error; err != nil {
|
||||
t.Errorf("failed to find item parent, got error: %v", err)
|
||||
}
|
||||
if len(parents) != 1 {
|
||||
t.Errorf("expected %d parents, got %d", 1, len(parents))
|
||||
}
|
||||
|
||||
// test delete
|
||||
if err := tx.Model(&item).Association("ItemParent").Unscoped().Delete(&parents); err != nil {
|
||||
t.Errorf("failed to delete item parent, got error: %v", err)
|
||||
}
|
||||
if err := tx.Find(&parents).Error; err != nil {
|
||||
t.Errorf("failed to find item parent, got error: %v", err)
|
||||
}
|
||||
if len(parents) != 0 {
|
||||
t.Errorf("expected %d parents, got %d", 0, len(parents))
|
||||
}
|
||||
}
|
||||
|
@ -3,6 +3,7 @@ package tests_test
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"gorm.io/gorm"
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
@ -421,7 +422,7 @@ func TestPolymorphicHasManyAssociation(t *testing.T) {
|
||||
func TestPolymorphicHasManyAssociationForSlice(t *testing.T) {
|
||||
users := []User{
|
||||
*GetUser("slice-hasmany-1", Config{Toys: 2}),
|
||||
*GetUser("slice-hasmany-2", Config{Toys: 0}),
|
||||
*GetUser("slice-hasmany-2", Config{Toys: 0, Tools: 2}),
|
||||
*GetUser("slice-hasmany-3", Config{Toys: 4}),
|
||||
}
|
||||
|
||||
@ -429,6 +430,7 @@ func TestPolymorphicHasManyAssociationForSlice(t *testing.T) {
|
||||
|
||||
// Count
|
||||
AssertAssociationCount(t, users, "Toys", 6, "")
|
||||
AssertAssociationCount(t, users, "Tools", 2, "")
|
||||
|
||||
// Find
|
||||
var toys []Toy
|
||||
@ -436,6 +438,14 @@ func TestPolymorphicHasManyAssociationForSlice(t *testing.T) {
|
||||
t.Errorf("toys count should be %v, but got %v", 6, len(toys))
|
||||
}
|
||||
|
||||
// Find Tools (polymorphic with custom type and id)
|
||||
var tools []Tools
|
||||
DB.Model(&users).Association("Tools").Find(&tools)
|
||||
|
||||
if len(tools) != 2 {
|
||||
t.Errorf("tools count should be %v, but got %v", 2, len(tools))
|
||||
}
|
||||
|
||||
// Append
|
||||
DB.Model(&users).Association("Toys").Append(
|
||||
&Toy{Name: "toy-slice-append-1"},
|
||||
@ -471,3 +481,76 @@ func TestPolymorphicHasManyAssociationForSlice(t *testing.T) {
|
||||
DB.Model(&users).Association("Toys").Clear()
|
||||
AssertAssociationCount(t, users, "Toys", 0, "After Clear")
|
||||
}
|
||||
|
||||
func TestHasManyAssociationUnscoped(t *testing.T) {
|
||||
type ItemContent struct {
|
||||
gorm.Model
|
||||
ItemID uint `gorm:"not null"`
|
||||
Name string `gorm:"not null;type:varchar(50)"`
|
||||
LanguageCode string `gorm:"not null;type:varchar(2)"`
|
||||
}
|
||||
type Item struct {
|
||||
gorm.Model
|
||||
Logo string `gorm:"not null;type:varchar(50)"`
|
||||
Contents []ItemContent `gorm:"foreignKey:ItemID"`
|
||||
}
|
||||
|
||||
tx := DB.Session(&gorm.Session{})
|
||||
tx.Migrator().DropTable(&ItemContent{}, &Item{})
|
||||
tx.AutoMigrate(&ItemContent{}, &Item{})
|
||||
|
||||
item := Item{
|
||||
Logo: "logo",
|
||||
Contents: []ItemContent{
|
||||
{Name: "name", LanguageCode: "en"},
|
||||
{Name: "ar name", LanguageCode: "ar"},
|
||||
},
|
||||
}
|
||||
if err := tx.Create(&item).Error; err != nil {
|
||||
t.Fatalf("failed to create items, got error: %v", err)
|
||||
}
|
||||
|
||||
// test Replace
|
||||
if err := tx.Model(&item).Association("Contents").Unscoped().Replace([]ItemContent{
|
||||
{Name: "updated name", LanguageCode: "en"},
|
||||
{Name: "ar updated name", LanguageCode: "ar"},
|
||||
{Name: "le nom", LanguageCode: "fr"},
|
||||
}); err != nil {
|
||||
t.Errorf("failed to replace item content, got error: %v", err)
|
||||
}
|
||||
|
||||
if count := tx.Model(&item).Association("Contents").Count(); count != 3 {
|
||||
t.Errorf("expected %d contents, got %d", 3, count)
|
||||
}
|
||||
|
||||
var contents []ItemContent
|
||||
if err := tx.Find(&contents).Error; err != nil {
|
||||
t.Errorf("failed to find contents, got error: %v", err)
|
||||
}
|
||||
if len(contents) != 3 {
|
||||
t.Errorf("expected %d contents, got %d", 3, len(contents))
|
||||
}
|
||||
|
||||
// test delete
|
||||
if err := tx.Model(&item).Association("Contents").Unscoped().Delete(&contents[0]); err != nil {
|
||||
t.Errorf("failed to delete Contents, got error: %v", err)
|
||||
}
|
||||
if count := tx.Model(&item).Association("Contents").Count(); count != 2 {
|
||||
t.Errorf("expected %d contents, got %d", 2, count)
|
||||
}
|
||||
|
||||
// test clear
|
||||
if err := tx.Model(&item).Association("Contents").Unscoped().Clear(); err != nil {
|
||||
t.Errorf("failed to clear contents association, got error: %v", err)
|
||||
}
|
||||
if count := tx.Model(&item).Association("Contents").Count(); count != 0 {
|
||||
t.Errorf("expected %d contents, got %d", 0, count)
|
||||
}
|
||||
|
||||
if err := tx.Find(&contents).Error; err != nil {
|
||||
t.Errorf("failed to find contents, got error: %v", err)
|
||||
}
|
||||
if len(contents) != 0 {
|
||||
t.Errorf("expected %d contents, got %d", 0, len(contents))
|
||||
}
|
||||
}
|
||||
|
@ -1,8 +1,12 @@
|
||||
package tests_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
@ -94,6 +98,8 @@ func TestMany2ManyAssociation(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestMany2ManyOmitAssociations(t *testing.T) {
|
||||
tidbSkip(t, "not support the foreign key feature")
|
||||
|
||||
user := *GetUser("many2many_omit_associations", Config{Languages: 2})
|
||||
|
||||
if err := DB.Omit("Languages.*").Create(&user).Error; err == nil {
|
||||
@ -324,3 +330,96 @@ func TestSingleTableMany2ManyAssociationForSlice(t *testing.T) {
|
||||
DB.Model(&users).Association("Team").Clear()
|
||||
AssertAssociationCount(t, users, "Team", 0, "After Clear")
|
||||
}
|
||||
|
||||
func TestDuplicateMany2ManyAssociation(t *testing.T) {
|
||||
user1 := User{Name: "TestDuplicateMany2ManyAssociation-1", Languages: []Language{
|
||||
{Code: "TestDuplicateMany2ManyAssociation-language-1"},
|
||||
{Code: "TestDuplicateMany2ManyAssociation-language-2"},
|
||||
}}
|
||||
|
||||
user2 := User{Name: "TestDuplicateMany2ManyAssociation-1", Languages: []Language{
|
||||
{Code: "TestDuplicateMany2ManyAssociation-language-1"},
|
||||
{Code: "TestDuplicateMany2ManyAssociation-language-3"},
|
||||
}}
|
||||
users := []*User{&user1, &user2}
|
||||
var err error
|
||||
err = DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(users).Error
|
||||
AssertEqual(t, nil, err)
|
||||
|
||||
var findUser1 User
|
||||
err = DB.Preload("Languages").Where("id = ?", user1.ID).First(&findUser1).Error
|
||||
AssertEqual(t, nil, err)
|
||||
AssertEqual(t, user1, findUser1)
|
||||
|
||||
var findUser2 User
|
||||
err = DB.Preload("Languages").Where("id = ?", user2.ID).First(&findUser2).Error
|
||||
AssertEqual(t, nil, err)
|
||||
AssertEqual(t, user2, findUser2)
|
||||
}
|
||||
|
||||
func TestConcurrentMany2ManyAssociation(t *testing.T) {
|
||||
db, err := OpenTestConnection(&gorm.Config{})
|
||||
if err != nil {
|
||||
t.Fatalf("open test connection failed, err: %+v", err)
|
||||
}
|
||||
|
||||
count := 3
|
||||
|
||||
var languages []Language
|
||||
for i := 0; i < count; i++ {
|
||||
language := Language{Code: fmt.Sprintf("consurrent %d", i)}
|
||||
db.Create(&language)
|
||||
languages = append(languages, language)
|
||||
}
|
||||
|
||||
user := User{}
|
||||
db.Create(&user)
|
||||
db.Preload("Languages").FirstOrCreate(&user)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < count; i++ {
|
||||
wg.Add(1)
|
||||
go func(user User, language Language) {
|
||||
err := db.Model(&user).Association("Languages").Append(&language)
|
||||
AssertEqual(t, err, nil)
|
||||
|
||||
wg.Done()
|
||||
}(user, languages[i])
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
var find User
|
||||
err = db.Preload(clause.Associations).Where("id = ?", user.ID).First(&find).Error
|
||||
AssertEqual(t, err, nil)
|
||||
AssertAssociationCount(t, find, "Languages", int64(count), "after concurrent append")
|
||||
}
|
||||
|
||||
func TestMany2ManyDuplicateBelongsToAssociation(t *testing.T) {
|
||||
user1 := User{Name: "TestMany2ManyDuplicateBelongsToAssociation-1", Friends: []*User{
|
||||
{Name: "TestMany2ManyDuplicateBelongsToAssociation-friend-1", Company: Company{
|
||||
ID: 1,
|
||||
Name: "Test-company-1",
|
||||
}},
|
||||
}}
|
||||
|
||||
user2 := User{Name: "TestMany2ManyDuplicateBelongsToAssociation-2", Friends: []*User{
|
||||
{Name: "TestMany2ManyDuplicateBelongsToAssociation-friend-2", Company: Company{
|
||||
ID: 1,
|
||||
Name: "Test-company-1",
|
||||
}},
|
||||
}}
|
||||
users := []*User{&user1, &user2}
|
||||
var err error
|
||||
err = DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(users).Error
|
||||
AssertEqual(t, nil, err)
|
||||
|
||||
var findUser1 User
|
||||
err = DB.Preload("Friends.Company").Where("id = ?", user1.ID).First(&findUser1).Error
|
||||
AssertEqual(t, nil, err)
|
||||
AssertEqual(t, user1, findUser1)
|
||||
|
||||
var findUser2 User
|
||||
err = DB.Preload("Friends.Company").Where("id = ?", user2.ID).First(&findUser2).Error
|
||||
AssertEqual(t, nil, err)
|
||||
AssertEqual(t, user2, findUser2)
|
||||
}
|
||||
|
@ -4,6 +4,8 @@ import (
|
||||
"testing"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/schema"
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
@ -69,6 +71,8 @@ func TestAssociationNotNullClear(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestForeignKeyConstraints(t *testing.T) {
|
||||
tidbSkip(t, "not support the foreign key feature")
|
||||
|
||||
type Profile struct {
|
||||
ID uint
|
||||
Name string
|
||||
@ -124,6 +128,8 @@ func TestForeignKeyConstraints(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestForeignKeyConstraintsBelongsTo(t *testing.T) {
|
||||
tidbSkip(t, "not support the foreign key feature")
|
||||
|
||||
type Profile struct {
|
||||
ID uint
|
||||
Name string
|
||||
@ -284,3 +290,107 @@ func TestAssociationError(t *testing.T) {
|
||||
err = DB.Model(&emptyUser).Association("Languages").Delete(&user1.Languages)
|
||||
AssertEqual(t, err, gorm.ErrPrimaryKeyRequired)
|
||||
}
|
||||
|
||||
type (
|
||||
myType string
|
||||
emptyQueryClause struct {
|
||||
Field *schema.Field
|
||||
}
|
||||
)
|
||||
|
||||
func (myType) QueryClauses(f *schema.Field) []clause.Interface {
|
||||
return []clause.Interface{emptyQueryClause{Field: f}}
|
||||
}
|
||||
|
||||
func (sd emptyQueryClause) Name() string {
|
||||
return "empty"
|
||||
}
|
||||
|
||||
func (sd emptyQueryClause) Build(clause.Builder) {
|
||||
}
|
||||
|
||||
func (sd emptyQueryClause) MergeClause(*clause.Clause) {
|
||||
}
|
||||
|
||||
func (sd emptyQueryClause) ModifyStatement(stmt *gorm.Statement) {
|
||||
// do nothing
|
||||
}
|
||||
|
||||
func TestAssociationEmptyQueryClause(t *testing.T) {
|
||||
type Organization struct {
|
||||
gorm.Model
|
||||
Name string
|
||||
}
|
||||
type Region struct {
|
||||
gorm.Model
|
||||
Name string
|
||||
Organizations []Organization `gorm:"many2many:region_orgs;"`
|
||||
}
|
||||
type RegionOrg struct {
|
||||
RegionId uint
|
||||
OrganizationId uint
|
||||
Empty myType
|
||||
}
|
||||
if err := DB.SetupJoinTable(&Region{}, "Organizations", &RegionOrg{}); err != nil {
|
||||
t.Fatalf("Failed to set up join table, got error: %s", err)
|
||||
}
|
||||
if err := DB.Migrator().DropTable(&Organization{}, &Region{}); err != nil {
|
||||
t.Fatalf("Failed to migrate, got error: %s", err)
|
||||
}
|
||||
if err := DB.AutoMigrate(&Organization{}, &Region{}); err != nil {
|
||||
t.Fatalf("Failed to migrate, got error: %v", err)
|
||||
}
|
||||
region := &Region{Name: "Region1"}
|
||||
if err := DB.Create(region).Error; err != nil {
|
||||
t.Fatalf("fail to create region %v", err)
|
||||
}
|
||||
var orgs []Organization
|
||||
|
||||
if err := DB.Model(&Region{}).Association("Organizations").Find(&orgs); err != nil {
|
||||
t.Fatalf("fail to find region organizations %v", err)
|
||||
} else {
|
||||
AssertEqual(t, len(orgs), 0)
|
||||
}
|
||||
}
|
||||
|
||||
type AssociationEmptyUser struct {
|
||||
ID uint
|
||||
Name string
|
||||
Pets []AssociationEmptyPet
|
||||
}
|
||||
|
||||
type AssociationEmptyPet struct {
|
||||
AssociationEmptyUserID *uint `gorm:"uniqueIndex:uniq_user_id_name"`
|
||||
Name string `gorm:"uniqueIndex:uniq_user_id_name;size:256"`
|
||||
}
|
||||
|
||||
func TestAssociationEmptyPrimaryKey(t *testing.T) {
|
||||
if DB.Dialector.Name() != "mysql" {
|
||||
t.Skip()
|
||||
}
|
||||
DB.Migrator().DropTable(&AssociationEmptyUser{}, &AssociationEmptyPet{})
|
||||
DB.AutoMigrate(&AssociationEmptyUser{}, &AssociationEmptyPet{})
|
||||
|
||||
id := uint(100)
|
||||
user := AssociationEmptyUser{
|
||||
ID: id,
|
||||
Name: "jinzhu",
|
||||
Pets: []AssociationEmptyPet{
|
||||
{AssociationEmptyUserID: &id, Name: "bar"},
|
||||
{AssociationEmptyUserID: &id, Name: "foo"},
|
||||
},
|
||||
}
|
||||
|
||||
err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Create(&user).Error
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create, got error: %v", err)
|
||||
}
|
||||
|
||||
var result AssociationEmptyUser
|
||||
err = DB.Preload("Pets").First(&result, &id).Error
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to find, got error: %v", err)
|
||||
}
|
||||
|
||||
AssertEqual(t, result, user)
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package tests_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
@ -24,6 +25,45 @@ func BenchmarkFind(b *testing.B) {
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkScan(b *testing.B) {
|
||||
user := *GetUser("scan", Config{})
|
||||
DB.Create(&user)
|
||||
|
||||
var u User
|
||||
b.ResetTimer()
|
||||
for x := 0; x < b.N; x++ {
|
||||
DB.Raw("select * from users where id = ?", user.ID).Scan(&u)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkScanSlice(b *testing.B) {
|
||||
DB.Exec("delete from users")
|
||||
for i := 0; i < 10_000; i++ {
|
||||
user := *GetUser(fmt.Sprintf("scan-%d", i), Config{})
|
||||
DB.Create(&user)
|
||||
}
|
||||
|
||||
var u []User
|
||||
b.ResetTimer()
|
||||
for x := 0; x < b.N; x++ {
|
||||
DB.Raw("select * from users").Scan(&u)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkScanSlicePointer(b *testing.B) {
|
||||
DB.Exec("delete from users")
|
||||
for i := 0; i < 10_000; i++ {
|
||||
user := *GetUser(fmt.Sprintf("scan-%d", i), Config{})
|
||||
DB.Create(&user)
|
||||
}
|
||||
|
||||
var u []*User
|
||||
b.ResetTimer()
|
||||
for x := 0; x < b.N; x++ {
|
||||
DB.Raw("select * from users").Scan(&u)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkUpdate(b *testing.B) {
|
||||
user := *GetUser("find", Config{})
|
||||
DB.Create(&user)
|
||||
|
@ -38,6 +38,7 @@ func c2(*gorm.DB) {}
|
||||
func c3(*gorm.DB) {}
|
||||
func c4(*gorm.DB) {}
|
||||
func c5(*gorm.DB) {}
|
||||
func c6(*gorm.DB) {}
|
||||
|
||||
func TestCallbacks(t *testing.T) {
|
||||
type callback struct {
|
||||
@ -90,7 +91,7 @@ func TestCallbacks(t *testing.T) {
|
||||
},
|
||||
{
|
||||
callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3}, {h: c4}, {h: c5}, {h: c2, remove: true}},
|
||||
results: []string{"c1", "c5", "c3", "c4"},
|
||||
results: []string{"c1", "c3", "c4", "c5"},
|
||||
},
|
||||
{
|
||||
callbacks: []callback{{h: c1}, {name: "c", h: c2}, {h: c3}, {name: "c", h: c4, replace: true}},
|
||||
@ -112,6 +113,9 @@ func TestCallbacks(t *testing.T) {
|
||||
|
||||
for idx, data := range datas {
|
||||
db, err := gorm.Open(nil, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
callbacks := db.Callback()
|
||||
|
||||
for _, c := range data.callbacks {
|
||||
@ -168,3 +172,83 @@ func TestCallbacks(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPluginCallbacks(t *testing.T) {
|
||||
db, _ := gorm.Open(nil, nil)
|
||||
createCallback := db.Callback().Create()
|
||||
|
||||
createCallback.Before("*").Register("plugin_1_fn1", c1)
|
||||
createCallback.After("*").Register("plugin_1_fn2", c2)
|
||||
|
||||
if ok, msg := assertCallbacks(createCallback, []string{"c1", "c2"}); !ok {
|
||||
t.Errorf("callbacks tests failed, got %v", msg)
|
||||
}
|
||||
|
||||
// plugin 2
|
||||
createCallback.Before("*").Register("plugin_2_fn1", c3)
|
||||
if ok, msg := assertCallbacks(createCallback, []string{"c3", "c1", "c2"}); !ok {
|
||||
t.Errorf("callbacks tests failed, got %v", msg)
|
||||
}
|
||||
|
||||
createCallback.After("*").Register("plugin_2_fn2", c4)
|
||||
if ok, msg := assertCallbacks(createCallback, []string{"c3", "c1", "c2", "c4"}); !ok {
|
||||
t.Errorf("callbacks tests failed, got %v", msg)
|
||||
}
|
||||
|
||||
// plugin 3
|
||||
createCallback.Before("*").Register("plugin_3_fn1", c5)
|
||||
if ok, msg := assertCallbacks(createCallback, []string{"c5", "c3", "c1", "c2", "c4"}); !ok {
|
||||
t.Errorf("callbacks tests failed, got %v", msg)
|
||||
}
|
||||
|
||||
createCallback.After("*").Register("plugin_3_fn2", c6)
|
||||
if ok, msg := assertCallbacks(createCallback, []string{"c5", "c3", "c1", "c2", "c4", "c6"}); !ok {
|
||||
t.Errorf("callbacks tests failed, got %v", msg)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCallbacksGet(t *testing.T) {
|
||||
db, _ := gorm.Open(nil, nil)
|
||||
createCallback := db.Callback().Create()
|
||||
|
||||
createCallback.Before("*").Register("c1", c1)
|
||||
if cb := createCallback.Get("c1"); reflect.DeepEqual(cb, c1) {
|
||||
t.Errorf("callbacks tests failed, got: %p, want: %p", cb, c1)
|
||||
}
|
||||
|
||||
createCallback.Remove("c1")
|
||||
if cb := createCallback.Get("c2"); cb != nil {
|
||||
t.Errorf("callbacks test failed. got: %p, want: nil", cb)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCallbacksRemove(t *testing.T) {
|
||||
db, _ := gorm.Open(nil, nil)
|
||||
createCallback := db.Callback().Create()
|
||||
|
||||
createCallback.Before("*").Register("c1", c1)
|
||||
createCallback.After("*").Register("c2", c2)
|
||||
createCallback.Before("c4").Register("c3", c3)
|
||||
createCallback.After("c2").Register("c4", c4)
|
||||
|
||||
// callbacks: []string{"c1", "c3", "c4", "c2"}
|
||||
createCallback.Remove("c1")
|
||||
if ok, msg := assertCallbacks(createCallback, []string{"c3", "c4", "c2"}); !ok {
|
||||
t.Errorf("callbacks tests failed, got %v", msg)
|
||||
}
|
||||
|
||||
createCallback.Remove("c4")
|
||||
if ok, msg := assertCallbacks(createCallback, []string{"c3", "c2"}); !ok {
|
||||
t.Errorf("callbacks tests failed, got %v", msg)
|
||||
}
|
||||
|
||||
createCallback.Remove("c2")
|
||||
if ok, msg := assertCallbacks(createCallback, []string{"c3"}); !ok {
|
||||
t.Errorf("callbacks tests failed, got %v", msg)
|
||||
}
|
||||
|
||||
createCallback.Remove("c3")
|
||||
if ok, msg := assertCallbacks(createCallback, []string{}); !ok {
|
||||
t.Errorf("callbacks tests failed, got %v", msg)
|
||||
}
|
||||
}
|
||||
|
@ -48,9 +48,11 @@ func (c *wrapperConnPool) Ping() error {
|
||||
}
|
||||
|
||||
// If you use BeginTx returned *sql.Tx as shown below then you can't record queries in a transaction.
|
||||
// func (c *wrapperConnPool) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) {
|
||||
// return c.db.BeginTx(ctx, opts)
|
||||
// }
|
||||
//
|
||||
// func (c *wrapperConnPool) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) {
|
||||
// return c.db.BeginTx(ctx, opts)
|
||||
// }
|
||||
//
|
||||
// You should use BeginTx returned gorm.Tx which could wrap *sql.Tx then you can record all queries.
|
||||
func (c *wrapperConnPool) BeginTx(ctx context.Context, opts *sql.TxOptions) (gorm.ConnPool, error) {
|
||||
tx, err := c.db.BeginTx(ctx, opts)
|
||||
@ -100,13 +102,13 @@ func TestConnPoolWrapper(t *testing.T) {
|
||||
expect: []string{
|
||||
"SELECT VERSION()",
|
||||
"INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)",
|
||||
"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1",
|
||||
"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT ?",
|
||||
"INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)",
|
||||
"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1",
|
||||
"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1",
|
||||
"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT ?",
|
||||
"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT ?",
|
||||
"INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)",
|
||||
"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1",
|
||||
"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1",
|
||||
"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT ?",
|
||||
"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT ?",
|
||||
},
|
||||
}
|
||||
|
||||
@ -116,7 +118,7 @@ func TestConnPoolWrapper(t *testing.T) {
|
||||
}
|
||||
}()
|
||||
|
||||
db, err := gorm.Open(mysql.New(mysql.Config{Conn: conn}))
|
||||
db, err := gorm.Open(mysql.New(mysql.Config{Conn: conn, DisableWithReturning: true}))
|
||||
if err != nil {
|
||||
t.Fatalf("Should open db success, but got %v", err)
|
||||
}
|
||||
|
@ -11,6 +11,32 @@ import (
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
func TestCountWithGroup(t *testing.T) {
|
||||
DB.Create([]Company{
|
||||
{Name: "company_count_group_a"},
|
||||
{Name: "company_count_group_a"},
|
||||
{Name: "company_count_group_a"},
|
||||
{Name: "company_count_group_b"},
|
||||
{Name: "company_count_group_c"},
|
||||
})
|
||||
|
||||
var count1 int64
|
||||
if err := DB.Model(&Company{}).Where("name = ?", "company_count_group_a").Group("name").Count(&count1).Error; err != nil {
|
||||
t.Errorf(fmt.Sprintf("Count should work, but got err %v", err))
|
||||
}
|
||||
if count1 != 1 {
|
||||
t.Errorf("Count with group should be 1, but got count: %v", count1)
|
||||
}
|
||||
|
||||
var count2 int64
|
||||
if err := DB.Model(&Company{}).Where("name in ?", []string{"company_count_group_b", "company_count_group_c"}).Group("name").Count(&count2).Error; err != nil {
|
||||
t.Errorf(fmt.Sprintf("Count should work, but got err %v", err))
|
||||
}
|
||||
if count2 != 2 {
|
||||
t.Errorf("Count with group should be 2, but got count: %v", count2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCount(t *testing.T) {
|
||||
var (
|
||||
user1 = *GetUser("count-1", Config{})
|
||||
@ -142,7 +168,7 @@ func TestCount(t *testing.T) {
|
||||
DB.Create(sameUsers)
|
||||
|
||||
if err := DB.Model(&User{}).Where("name = ?", "count-4").Group("name").Count(&count11).Error; err != nil || count11 != 1 {
|
||||
t.Fatalf("Count should be 3, but got count: %v err %v", count11, err)
|
||||
t.Fatalf("Count should be 1, but got count: %v err %v", count11, err)
|
||||
}
|
||||
|
||||
var count12 int64
|
||||
|
@ -2,6 +2,7 @@ package tests_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"testing"
|
||||
"time"
|
||||
@ -476,6 +477,13 @@ func TestOmitWithCreate(t *testing.T) {
|
||||
CheckUser(t, result2, user2)
|
||||
}
|
||||
|
||||
func TestFirstOrCreateNotExistsTable(t *testing.T) {
|
||||
company := Company{Name: "first_or_create_if_not_exists_table"}
|
||||
if err := DB.Table("not_exists").FirstOrCreate(&company).Error; err == nil {
|
||||
t.Errorf("not exists table, but err is nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFirstOrCreateWithPrimaryKey(t *testing.T) {
|
||||
company := Company{ID: 100, Name: "company100_with_primarykey"}
|
||||
DB.FirstOrCreate(&company)
|
||||
@ -540,3 +548,246 @@ func TestFirstOrCreateRowsAffected(t *testing.T) {
|
||||
t.Fatalf("first or create rows affect err:%v rows:%d", res.Error, res.RowsAffected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateWithAutoIncrementCompositeKey(t *testing.T) {
|
||||
type CompositeKeyProduct struct {
|
||||
ProductID int `gorm:"primaryKey;autoIncrement:true;"` // primary key
|
||||
LanguageCode int `gorm:"primaryKey;"` // primary key
|
||||
Code string
|
||||
Name string
|
||||
}
|
||||
|
||||
if err := DB.Migrator().DropTable(&CompositeKeyProduct{}); err != nil {
|
||||
t.Fatalf("failed to migrate, got error %v", err)
|
||||
}
|
||||
if err := DB.AutoMigrate(&CompositeKeyProduct{}); err != nil {
|
||||
t.Fatalf("failed to migrate, got error %v", err)
|
||||
}
|
||||
|
||||
prod := &CompositeKeyProduct{
|
||||
LanguageCode: 56,
|
||||
Code: "Code56",
|
||||
Name: "ProductName56",
|
||||
}
|
||||
if err := DB.Create(&prod).Error; err != nil {
|
||||
t.Fatalf("failed to create, got error %v", err)
|
||||
}
|
||||
|
||||
newProd := &CompositeKeyProduct{}
|
||||
if err := DB.First(&newProd).Error; err != nil {
|
||||
t.Fatalf("errors happened when query: %v", err)
|
||||
} else {
|
||||
AssertObjEqual(t, newProd, prod, "ProductID", "LanguageCode", "Code", "Name")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateOnConflictWithDefaultNull(t *testing.T) {
|
||||
type OnConflictUser struct {
|
||||
ID string
|
||||
Name string `gorm:"default:null"`
|
||||
Email string
|
||||
Mobile string `gorm:"default:'133xxxx'"`
|
||||
}
|
||||
|
||||
err := DB.Migrator().DropTable(&OnConflictUser{})
|
||||
AssertEqual(t, err, nil)
|
||||
err = DB.AutoMigrate(&OnConflictUser{})
|
||||
AssertEqual(t, err, nil)
|
||||
|
||||
u := OnConflictUser{
|
||||
ID: "on-conflict-user-id",
|
||||
Name: "on-conflict-user-name",
|
||||
Email: "on-conflict-user-email",
|
||||
Mobile: "on-conflict-user-mobile",
|
||||
}
|
||||
err = DB.Create(&u).Error
|
||||
AssertEqual(t, err, nil)
|
||||
|
||||
u.Name = "on-conflict-user-name-2"
|
||||
u.Email = "on-conflict-user-email-2"
|
||||
u.Mobile = ""
|
||||
err = DB.Clauses(clause.OnConflict{UpdateAll: true}).Create(&u).Error
|
||||
AssertEqual(t, err, nil)
|
||||
|
||||
var u2 OnConflictUser
|
||||
err = DB.Where("id = ?", u.ID).First(&u2).Error
|
||||
AssertEqual(t, err, nil)
|
||||
AssertEqual(t, u2.Name, "on-conflict-user-name-2")
|
||||
AssertEqual(t, u2.Email, "on-conflict-user-email-2")
|
||||
AssertEqual(t, u2.Mobile, "133xxxx")
|
||||
}
|
||||
|
||||
func TestCreateFromMapWithoutPK(t *testing.T) {
|
||||
if !isMysql() {
|
||||
t.Skipf("This test case skipped, because of only supporting for mysql")
|
||||
}
|
||||
|
||||
// case 1: one record, create from map[string]interface{}
|
||||
mapValue1 := map[string]interface{}{"name": "create_from_map_with_schema1", "age": 1}
|
||||
if err := DB.Model(&User{}).Create(mapValue1).Error; err != nil {
|
||||
t.Fatalf("failed to create data from map, got error: %v", err)
|
||||
}
|
||||
|
||||
if _, ok := mapValue1["id"]; !ok {
|
||||
t.Fatal("failed to create data from map with table, returning map has no primary key")
|
||||
}
|
||||
|
||||
var result1 User
|
||||
if err := DB.Where("name = ?", "create_from_map_with_schema1").First(&result1).Error; err != nil || result1.Age != 1 {
|
||||
t.Fatalf("failed to create from map, got error %v", err)
|
||||
}
|
||||
|
||||
var idVal int64
|
||||
_, ok := mapValue1["id"].(uint)
|
||||
if ok {
|
||||
t.Skipf("This test case skipped, because the db supports returning")
|
||||
}
|
||||
|
||||
idVal, ok = mapValue1["id"].(int64)
|
||||
if !ok {
|
||||
t.Fatal("ret result missing id")
|
||||
}
|
||||
|
||||
if int64(result1.ID) != idVal {
|
||||
t.Fatal("failed to create data from map with table, @id != id")
|
||||
}
|
||||
|
||||
// case2: one record, create from *map[string]interface{}
|
||||
mapValue2 := map[string]interface{}{"name": "create_from_map_with_schema2", "age": 1}
|
||||
if err := DB.Model(&User{}).Create(&mapValue2).Error; err != nil {
|
||||
t.Fatalf("failed to create data from map, got error: %v", err)
|
||||
}
|
||||
|
||||
if _, ok := mapValue2["id"]; !ok {
|
||||
t.Fatal("failed to create data from map with table, returning map has no primary key")
|
||||
}
|
||||
|
||||
var result2 User
|
||||
if err := DB.Where("name = ?", "create_from_map_with_schema2").First(&result2).Error; err != nil || result2.Age != 1 {
|
||||
t.Fatalf("failed to create from map, got error %v", err)
|
||||
}
|
||||
|
||||
_, ok = mapValue2["id"].(uint)
|
||||
if ok {
|
||||
t.Skipf("This test case skipped, because the db supports returning")
|
||||
}
|
||||
|
||||
idVal, ok = mapValue2["id"].(int64)
|
||||
if !ok {
|
||||
t.Fatal("ret result missing id")
|
||||
}
|
||||
|
||||
if int64(result2.ID) != idVal {
|
||||
t.Fatal("failed to create data from map with table, @id != id")
|
||||
}
|
||||
|
||||
// case 3: records
|
||||
values := []map[string]interface{}{
|
||||
{"name": "create_from_map_with_schema11", "age": 1}, {"name": "create_from_map_with_schema12", "age": 1},
|
||||
}
|
||||
|
||||
beforeLen := len(values)
|
||||
if err := DB.Model(&User{}).Create(&values).Error; err != nil {
|
||||
t.Fatalf("failed to create data from map, got error: %v", err)
|
||||
}
|
||||
|
||||
// mariadb with returning, values will be appended with id map
|
||||
if len(values) == beforeLen*2 {
|
||||
t.Skipf("This test case skipped, because the db supports returning")
|
||||
}
|
||||
|
||||
for i := range values {
|
||||
v, ok := values[i]["id"]
|
||||
if !ok {
|
||||
t.Fatal("failed to create data from map with table, returning map has no primary key")
|
||||
}
|
||||
|
||||
var result User
|
||||
if err := DB.Where("name = ?", fmt.Sprintf("create_from_map_with_schema1%d", i+1)).First(&result).Error; err != nil || result.Age != 1 {
|
||||
t.Fatalf("failed to create from map, got error %v", err)
|
||||
}
|
||||
if int64(result.ID) != v.(int64) {
|
||||
t.Fatal("failed to create data from map with table, @id != id")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateFromMapWithTable(t *testing.T) {
|
||||
tableDB := DB.Table("users")
|
||||
supportLastInsertID := isMysql() || isSqlite()
|
||||
|
||||
// case 1: create from map[string]interface{}
|
||||
record := map[string]interface{}{"name": "create_from_map_with_table", "age": 18}
|
||||
if err := tableDB.Create(record).Error; err != nil {
|
||||
t.Fatalf("failed to create data from map with table, got error: %v", err)
|
||||
}
|
||||
|
||||
if _, ok := record["@id"]; !ok && supportLastInsertID {
|
||||
t.Fatal("failed to create data from map with table, returning map has no key '@id'")
|
||||
}
|
||||
|
||||
var res map[string]interface{}
|
||||
if err := tableDB.Select([]string{"id", "name", "age"}).Where("name = ?", "create_from_map_with_table").Find(&res).Error; err != nil || res["age"] != int64(18) {
|
||||
t.Fatalf("failed to create from map, got error %v", err)
|
||||
}
|
||||
|
||||
if _, ok := record["@id"]; ok && fmt.Sprint(res["id"]) != fmt.Sprint(record["@id"]) {
|
||||
t.Fatalf("failed to create data from map with table, @id != id, got %v, expect %v", res["id"], record["@id"])
|
||||
}
|
||||
|
||||
// case 2: create from *map[string]interface{}
|
||||
record1 := map[string]interface{}{"name": "create_from_map_with_table_1", "age": 18}
|
||||
tableDB2 := DB.Table("users")
|
||||
if err := tableDB2.Create(&record1).Error; err != nil {
|
||||
t.Fatalf("failed to create data from map, got error: %v", err)
|
||||
}
|
||||
if _, ok := record1["@id"]; !ok && supportLastInsertID {
|
||||
t.Fatal("failed to create data from map with table, returning map has no key '@id'")
|
||||
}
|
||||
|
||||
var res1 map[string]interface{}
|
||||
if err := tableDB2.Select([]string{"id", "name", "age"}).Where("name = ?", "create_from_map_with_table_1").Find(&res1).Error; err != nil || res1["age"] != int64(18) {
|
||||
t.Fatalf("failed to create from map, got error %v", err)
|
||||
}
|
||||
|
||||
if _, ok := record1["@id"]; ok && fmt.Sprint(res1["id"]) != fmt.Sprint(record1["@id"]) {
|
||||
t.Fatal("failed to create data from map with table, @id != id")
|
||||
}
|
||||
|
||||
// case 3: create from []map[string]interface{}
|
||||
records := []map[string]interface{}{
|
||||
{"name": "create_from_map_with_table_2", "age": 19},
|
||||
{"name": "create_from_map_with_table_3", "age": 20},
|
||||
}
|
||||
|
||||
tableDB = DB.Table("users")
|
||||
if err := tableDB.Create(&records).Error; err != nil {
|
||||
t.Fatalf("failed to create data from slice of map, got error: %v", err)
|
||||
}
|
||||
|
||||
if _, ok := records[0]["@id"]; !ok && supportLastInsertID {
|
||||
t.Fatal("failed to create data from map with table, returning map has no key '@id'")
|
||||
}
|
||||
|
||||
if _, ok := records[1]["@id"]; !ok && supportLastInsertID {
|
||||
t.Fatal("failed to create data from map with table, returning map has no key '@id'")
|
||||
}
|
||||
|
||||
var res2 map[string]interface{}
|
||||
if err := tableDB.Select([]string{"id", "name", "age"}).Where("name = ?", "create_from_map_with_table_2").Find(&res2).Error; err != nil || res2["age"] != int64(19) {
|
||||
t.Fatalf("failed to query data after create from slice of map, got error %v", err)
|
||||
}
|
||||
|
||||
var res3 map[string]interface{}
|
||||
if err := DB.Table("users").Select([]string{"id", "name", "age"}).Where("name = ?", "create_from_map_with_table_3").Find(&res3).Error; err != nil || res3["age"] != int64(20) {
|
||||
t.Fatalf("failed to query data after create from slice of map, got error %v", err)
|
||||
}
|
||||
|
||||
if _, ok := records[0]["@id"]; ok && fmt.Sprint(res2["id"]) != fmt.Sprint(records[0]["@id"]) {
|
||||
t.Errorf("failed to create data from map with table, @id != id, got %v, expect %v", res2["id"], records[0]["@id"])
|
||||
}
|
||||
|
||||
if _, ok := records[1]["id"]; ok && fmt.Sprint(res3["id"]) != fmt.Sprint(records[1]["@id"]) {
|
||||
t.Errorf("failed to create data from map with table, @id != id")
|
||||
}
|
||||
}
|
||||
|
@ -206,9 +206,9 @@ func TestDeleteSliceWithAssociations(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// only sqlite, postgres support returning
|
||||
// only sqlite, postgres, sqlserver support returning
|
||||
func TestSoftDeleteReturning(t *testing.T) {
|
||||
if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" {
|
||||
if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" && DB.Dialector.Name() != "sqlserver" {
|
||||
return
|
||||
}
|
||||
|
||||
@ -233,7 +233,7 @@ func TestSoftDeleteReturning(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDeleteReturning(t *testing.T) {
|
||||
if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" {
|
||||
if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" && DB.Dialector.Name() != "sqlserver" {
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -4,7 +4,7 @@ services:
|
||||
mysql:
|
||||
image: 'mysql/mysql-server:latest'
|
||||
ports:
|
||||
- 9910:3306
|
||||
- "9910:3306"
|
||||
environment:
|
||||
- MYSQL_DATABASE=gorm
|
||||
- MYSQL_USER=gorm
|
||||
@ -13,7 +13,7 @@ services:
|
||||
postgres:
|
||||
image: 'postgres:latest'
|
||||
ports:
|
||||
- 9920:5432
|
||||
- "9920:5432"
|
||||
environment:
|
||||
- TZ=Asia/Shanghai
|
||||
- POSTGRES_DB=gorm
|
||||
@ -22,10 +22,16 @@ services:
|
||||
mssql:
|
||||
image: '${MSSQL_IMAGE:-mcmoe/mssqldocker}:latest'
|
||||
ports:
|
||||
- 9930:1433
|
||||
- "9930:1433"
|
||||
environment:
|
||||
- TZ=Asia/Shanghai
|
||||
- ACCEPT_EULA=Y
|
||||
- SA_PASSWORD=LoremIpsum86
|
||||
- MSSQL_DB=gorm
|
||||
- MSSQL_USER=gorm
|
||||
- MSSQL_PASSWORD=LoremIpsum86
|
||||
tidb:
|
||||
image: 'pingcap/tidb:v6.5.0'
|
||||
ports:
|
||||
- "9940:4000"
|
||||
command: /tidb-server -store unistore -path "" -lease 0s > tidb.log 2>&1 &
|
||||
|
@ -4,7 +4,9 @@ import (
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
@ -36,7 +38,7 @@ func TestEmbeddedStruct(t *testing.T) {
|
||||
|
||||
type EngadgetPost struct {
|
||||
BasePost BasePost `gorm:"Embedded"`
|
||||
Author Author `gorm:"Embedded;EmbeddedPrefix:author_"` // Embedded struct
|
||||
Author *Author `gorm:"Embedded;EmbeddedPrefix:author_"` // Embedded struct
|
||||
ImageUrl string
|
||||
}
|
||||
|
||||
@ -74,13 +76,26 @@ func TestEmbeddedStruct(t *testing.T) {
|
||||
t.Errorf("embedded struct's value should be scanned correctly")
|
||||
}
|
||||
|
||||
DB.Save(&EngadgetPost{BasePost: BasePost{Title: "engadget_news"}})
|
||||
DB.Save(&EngadgetPost{BasePost: BasePost{Title: "engadget_news"}, Author: &Author{Name: "Edward"}})
|
||||
DB.Save(&EngadgetPost{BasePost: BasePost{Title: "engadget_article"}, Author: &Author{Name: "George"}})
|
||||
var egNews EngadgetPost
|
||||
if err := DB.First(&egNews, "title = ?", "engadget_news").Error; err != nil {
|
||||
t.Errorf("no error should happen when query with embedded struct, but got %v", err)
|
||||
} else if egNews.BasePost.Title != "engadget_news" {
|
||||
t.Errorf("embedded struct's value should be scanned correctly")
|
||||
}
|
||||
|
||||
var egPosts []EngadgetPost
|
||||
if err := DB.Order("author_name asc").Find(&egPosts).Error; err != nil {
|
||||
t.Fatalf("no error should happen when query with embedded struct, but got %v", err)
|
||||
}
|
||||
expectAuthors := []string{"Edward", "George"}
|
||||
for i, post := range egPosts {
|
||||
t.Log(i, post.Author)
|
||||
if want := expectAuthors[i]; post.Author.Name != want {
|
||||
t.Errorf("expected author %s got %s", want, post.Author.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEmbeddedPointerTypeStruct(t *testing.T) {
|
||||
@ -90,9 +105,21 @@ func TestEmbeddedPointerTypeStruct(t *testing.T) {
|
||||
URL string
|
||||
}
|
||||
|
||||
type Author struct {
|
||||
ID string
|
||||
Name string
|
||||
Email string
|
||||
Age int
|
||||
Content Content
|
||||
ContentPtr *Content
|
||||
Birthday time.Time
|
||||
BirthdayPtr *time.Time
|
||||
}
|
||||
|
||||
type HNPost struct {
|
||||
*BasePost
|
||||
Upvotes int32
|
||||
*Author `gorm:"EmbeddedPrefix:user_"` // Embedded struct
|
||||
}
|
||||
|
||||
DB.Migrator().DropTable(&HNPost{})
|
||||
@ -110,6 +137,52 @@ func TestEmbeddedPointerTypeStruct(t *testing.T) {
|
||||
if hnPost.Title != "embedded_pointer_type" {
|
||||
t.Errorf("Should find correct value for embedded pointer type")
|
||||
}
|
||||
|
||||
if hnPost.Author != nil {
|
||||
t.Errorf("Expected to get back a nil Author but got: %v", hnPost.Author)
|
||||
}
|
||||
|
||||
now := time.Now().Round(time.Second)
|
||||
NewPost := HNPost{
|
||||
BasePost: &BasePost{Title: "embedded_pointer_type2"},
|
||||
Author: &Author{
|
||||
Name: "test",
|
||||
Content: Content{"test"},
|
||||
ContentPtr: nil,
|
||||
Birthday: now,
|
||||
BirthdayPtr: nil,
|
||||
},
|
||||
}
|
||||
DB.Create(&NewPost)
|
||||
|
||||
hnPost = HNPost{}
|
||||
if err := DB.First(&hnPost, "title = ?", NewPost.Title).Error; err != nil {
|
||||
t.Errorf("No error should happen when find embedded pointer type, but got %v", err)
|
||||
}
|
||||
|
||||
if hnPost.Title != NewPost.Title {
|
||||
t.Errorf("Should find correct value for embedded pointer type")
|
||||
}
|
||||
|
||||
if hnPost.Author.Name != NewPost.Author.Name {
|
||||
t.Errorf("Expected to get Author name %v but got: %v", NewPost.Author.Name, hnPost.Author.Name)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(NewPost.Author.Content, hnPost.Author.Content) {
|
||||
t.Errorf("Expected to get Author content %v but got: %v", NewPost.Author.Content, hnPost.Author.Content)
|
||||
}
|
||||
|
||||
if hnPost.Author.ContentPtr != nil {
|
||||
t.Errorf("Expected to get nil Author contentPtr but got: %v", hnPost.Author.ContentPtr)
|
||||
}
|
||||
|
||||
if NewPost.Author.Birthday.UnixMilli() != hnPost.Author.Birthday.UnixMilli() {
|
||||
t.Errorf("Expected to get Author birthday with %+v but got: %+v", NewPost.Author.Birthday, hnPost.Author.Birthday)
|
||||
}
|
||||
|
||||
if hnPost.Author.BirthdayPtr != nil {
|
||||
t.Errorf("Expected to get nil Author birthdayPtr but got: %+v", hnPost.Author.BirthdayPtr)
|
||||
}
|
||||
}
|
||||
|
||||
type Content struct {
|
||||
@ -117,18 +190,26 @@ type Content struct {
|
||||
}
|
||||
|
||||
func (c Content) Value() (driver.Value, error) {
|
||||
return json.Marshal(c)
|
||||
// mssql driver with issue on handling null bytes https://github.com/denisenkom/go-mssqldb/issues/530,
|
||||
b, err := json.Marshal(c)
|
||||
return string(b[:]), err
|
||||
}
|
||||
|
||||
func (c *Content) Scan(src interface{}) error {
|
||||
b, ok := src.([]byte)
|
||||
if !ok {
|
||||
return errors.New("Embedded.Scan byte assertion failed")
|
||||
}
|
||||
|
||||
var value Content
|
||||
if err := json.Unmarshal(b, &value); err != nil {
|
||||
return err
|
||||
str, ok := src.(string)
|
||||
if !ok {
|
||||
byt, ok := src.([]byte)
|
||||
if !ok {
|
||||
return errors.New("Embedded.Scan byte assertion failed")
|
||||
}
|
||||
if err := json.Unmarshal(byt, &value); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if err := json.Unmarshal([]byte(str), &value); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
*c = value
|
||||
@ -155,8 +236,15 @@ func TestEmbeddedScanValuer(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestEmbeddedRelations(t *testing.T) {
|
||||
type EmbUser struct {
|
||||
gorm.Model
|
||||
Name string
|
||||
Age uint
|
||||
Languages []Language `gorm:"many2many:EmbUserSpeak;"`
|
||||
}
|
||||
|
||||
type AdvancedUser struct {
|
||||
User `gorm:"embedded"`
|
||||
EmbUser `gorm:"embedded"`
|
||||
Advanced bool
|
||||
}
|
||||
|
||||
@ -168,3 +256,29 @@ func TestEmbeddedRelations(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEmbeddedTagSetting(t *testing.T) {
|
||||
type Tag1 struct {
|
||||
Id int64 `gorm:"autoIncrement"`
|
||||
}
|
||||
type Tag2 struct {
|
||||
Id int64
|
||||
}
|
||||
|
||||
type EmbeddedTag struct {
|
||||
Tag1 Tag1 `gorm:"Embedded;"`
|
||||
Tag2 Tag2 `gorm:"Embedded;EmbeddedPrefix:t2_"`
|
||||
Name string
|
||||
}
|
||||
|
||||
DB.Migrator().DropTable(&EmbeddedTag{})
|
||||
err := DB.Migrator().AutoMigrate(&EmbeddedTag{})
|
||||
AssertEqual(t, err, nil)
|
||||
|
||||
t1 := EmbeddedTag{Name: "embedded_tag"}
|
||||
err = DB.Save(&t1).Error
|
||||
AssertEqual(t, err, nil)
|
||||
if t1.Tag1.Id == 0 {
|
||||
t.Errorf("embedded struct's primary field should be rewrited")
|
||||
}
|
||||
}
|
||||
|
111
tests/error_translator_test.go
Normal file
111
tests/error_translator_test.go
Normal file
@ -0,0 +1,111 @@
|
||||
package tests_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
func TestDialectorWithErrorTranslatorSupport(t *testing.T) {
|
||||
// it shouldn't translate error when the TranslateError flag is false
|
||||
translatedErr := errors.New("translated error")
|
||||
untranslatedErr := errors.New("some random error")
|
||||
db, _ := gorm.Open(tests.DummyDialector{TranslatedErr: translatedErr})
|
||||
|
||||
err := db.AddError(untranslatedErr)
|
||||
if !errors.Is(err, untranslatedErr) {
|
||||
t.Fatalf("expected err: %v got err: %v", untranslatedErr, err)
|
||||
}
|
||||
|
||||
// it should translate error when the TranslateError flag is true
|
||||
db, _ = gorm.Open(tests.DummyDialector{TranslatedErr: translatedErr}, &gorm.Config{TranslateError: true})
|
||||
|
||||
err = db.AddError(untranslatedErr)
|
||||
if !errors.Is(err, translatedErr) {
|
||||
t.Fatalf("expected err: %v got err: %v", translatedErr, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSupportedDialectorWithErrDuplicatedKey(t *testing.T) {
|
||||
type City struct {
|
||||
gorm.Model
|
||||
Name string `gorm:"unique"`
|
||||
}
|
||||
|
||||
db, err := OpenTestConnection(&gorm.Config{TranslateError: true})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect database, got error %v", err)
|
||||
}
|
||||
|
||||
dialectors := map[string]bool{"sqlite": true, "postgres": true, "mysql": true, "sqlserver": true}
|
||||
if supported, found := dialectors[db.Dialector.Name()]; !(found && supported) {
|
||||
return
|
||||
}
|
||||
|
||||
DB.Migrator().DropTable(&City{})
|
||||
|
||||
if err = db.AutoMigrate(&City{}); err != nil {
|
||||
t.Fatalf("failed to migrate cities table, got error: %v", err)
|
||||
}
|
||||
|
||||
err = db.Create(&City{Name: "Kabul"}).Error
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create record: %v", err)
|
||||
}
|
||||
|
||||
err = db.Create(&City{Name: "Kabul"}).Error
|
||||
if !errors.Is(err, gorm.ErrDuplicatedKey) {
|
||||
t.Fatalf("expected err: %v got err: %v", gorm.ErrDuplicatedKey, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSupportedDialectorWithErrForeignKeyViolated(t *testing.T) {
|
||||
tidbSkip(t, "not support the foreign key feature")
|
||||
|
||||
type City struct {
|
||||
gorm.Model
|
||||
Name string `gorm:"unique"`
|
||||
}
|
||||
|
||||
type Museum struct {
|
||||
gorm.Model
|
||||
Name string `gorm:"unique"`
|
||||
CityID uint
|
||||
City City `gorm:"Constraint:OnUpdate:CASCADE,OnDelete:CASCADE;FOREIGNKEY:CityID;References:ID"`
|
||||
}
|
||||
|
||||
db, err := OpenTestConnection(&gorm.Config{TranslateError: true})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect database, got error %v", err)
|
||||
}
|
||||
|
||||
dialectors := map[string]bool{"sqlite": true, "postgres": true, "mysql": true, "sqlserver": true}
|
||||
if supported, found := dialectors[db.Dialector.Name()]; !(found && supported) {
|
||||
return
|
||||
}
|
||||
|
||||
DB.Migrator().DropTable(&City{}, &Museum{})
|
||||
|
||||
if err = db.AutoMigrate(&City{}, &Museum{}); err != nil {
|
||||
t.Fatalf("failed to migrate countries & cities tables, got error: %v", err)
|
||||
}
|
||||
|
||||
city := City{Name: "Amsterdam"}
|
||||
|
||||
err = db.Create(&city).Error
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create city: %v", err)
|
||||
}
|
||||
|
||||
err = db.Create(&Museum{Name: "Eye Filmmuseum", CityID: city.ID}).Error
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create museum: %v", err)
|
||||
}
|
||||
|
||||
err = db.Create(&Museum{Name: "Dungeon", CityID: 123}).Error
|
||||
if !errors.Is(err, gorm.ErrForeignKeyViolated) {
|
||||
t.Fatalf("expected err: %v got err: %v", gorm.ErrForeignKeyViolated, err)
|
||||
}
|
||||
}
|
41
tests/go.mod
41
tests/go.mod
@ -1,18 +1,39 @@
|
||||
module gorm.io/gorm/tests
|
||||
|
||||
go 1.14
|
||||
go 1.18
|
||||
|
||||
require (
|
||||
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect
|
||||
github.com/google/uuid v1.3.0
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/jinzhu/now v1.1.5
|
||||
github.com/lib/pq v1.10.5
|
||||
golang.org/x/crypto v0.0.0-20220411220226-7b82a4e95df4 // indirect
|
||||
gorm.io/driver/mysql v1.3.3
|
||||
gorm.io/driver/postgres v1.3.5
|
||||
gorm.io/driver/sqlite v1.3.2
|
||||
gorm.io/driver/sqlserver v1.3.2
|
||||
gorm.io/gorm v1.23.4
|
||||
github.com/lib/pq v1.10.9
|
||||
github.com/stretchr/testify v1.9.0
|
||||
gorm.io/driver/mysql v1.5.6
|
||||
gorm.io/driver/postgres v1.5.7
|
||||
gorm.io/driver/sqlite v1.5.5
|
||||
gorm.io/driver/sqlserver v1.5.3
|
||||
gorm.io/gorm v1.25.8
|
||||
)
|
||||
|
||||
require (
|
||||
filippo.io/edwards25519 v1.1.0 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/go-sql-driver/mysql v1.8.0 // indirect
|
||||
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect
|
||||
github.com/golang-sql/sqlexp v0.1.0 // indirect
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 // indirect
|
||||
github.com/jackc/pgx/v5 v5.5.5 // indirect
|
||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||
github.com/kr/text v0.2.0 // indirect
|
||||
github.com/mattn/go-sqlite3 v1.14.22 // indirect
|
||||
github.com/microsoft/go-mssqldb v1.7.0 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/rogpeppe/go-internal v1.12.0 // indirect
|
||||
golang.org/x/crypto v0.21.0 // indirect
|
||||
golang.org/x/text v0.14.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
replace gorm.io/gorm => ../
|
||||
|
||||
replace github.com/jackc/pgx/v5 => github.com/jackc/pgx/v5 v5.4.3
|
||||
|
@ -3,9 +3,19 @@ package tests_test
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"gorm.io/driver/mysql"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func TestOpen(t *testing.T) {
|
||||
dsn := "gorm:gorm@tcp(localhost:9910)/gorm?loc=Asia%2FHongKong" // invalid loc
|
||||
_, err := gorm.Open(mysql.Open(dsn), &gorm.Config{})
|
||||
if err == nil {
|
||||
t.Fatalf("should returns error but got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReturningWithNullToZeroValues(t *testing.T) {
|
||||
dialect := DB.Dialector.Name()
|
||||
switch dialect {
|
||||
|
@ -1,12 +1,15 @@
|
||||
package tests_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
@ -20,6 +23,7 @@ type Config struct {
|
||||
Languages int
|
||||
Friends int
|
||||
NamedPet bool
|
||||
Tools int
|
||||
}
|
||||
|
||||
func GetUser(name string, config Config) *User {
|
||||
@ -44,6 +48,10 @@ func GetUser(name string, config Config) *User {
|
||||
user.Toys = append(user.Toys, Toy{Name: name + "_toy_" + strconv.Itoa(i+1)})
|
||||
}
|
||||
|
||||
for i := 0; i < config.Tools; i++ {
|
||||
user.Tools = append(user.Tools, Tools{Name: name + "_tool_" + strconv.Itoa(i+1)})
|
||||
}
|
||||
|
||||
if config.Company {
|
||||
user.Company = Company{Name: "company-" + name}
|
||||
}
|
||||
@ -73,13 +81,22 @@ func GetUser(name string, config Config) *User {
|
||||
return &user
|
||||
}
|
||||
|
||||
func CheckPetUnscoped(t *testing.T, pet Pet, expect Pet) {
|
||||
doCheckPet(t, pet, expect, true)
|
||||
}
|
||||
|
||||
func CheckPet(t *testing.T, pet Pet, expect Pet) {
|
||||
doCheckPet(t, pet, expect, false)
|
||||
}
|
||||
|
||||
func doCheckPet(t *testing.T, pet Pet, expect Pet, unscoped bool) {
|
||||
if pet.ID != 0 {
|
||||
var newPet Pet
|
||||
if err := DB.Where("id = ?", pet.ID).First(&newPet).Error; err != nil {
|
||||
if err := db(unscoped).Where("id = ?", pet.ID).First(&newPet).Error; err != nil {
|
||||
t.Fatalf("errors happened when query: %v", err)
|
||||
} else {
|
||||
AssertObjEqual(t, newPet, pet, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "UserID", "Name")
|
||||
AssertObjEqual(t, newPet, expect, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "UserID", "Name")
|
||||
}
|
||||
}
|
||||
|
||||
@ -92,17 +109,27 @@ func CheckPet(t *testing.T, pet Pet, expect Pet) {
|
||||
}
|
||||
}
|
||||
|
||||
func CheckUserUnscoped(t *testing.T, user User, expect User) {
|
||||
doCheckUser(t, user, expect, true)
|
||||
}
|
||||
|
||||
func CheckUser(t *testing.T, user User, expect User) {
|
||||
doCheckUser(t, user, expect, false)
|
||||
}
|
||||
|
||||
func doCheckUser(t *testing.T, user User, expect User, unscoped bool) {
|
||||
if user.ID != 0 {
|
||||
var newUser User
|
||||
if err := DB.Where("id = ?", user.ID).First(&newUser).Error; err != nil {
|
||||
if err := db(unscoped).Where("id = ?", user.ID).First(&newUser).Error; err != nil {
|
||||
t.Fatalf("errors happened when query: %v", err)
|
||||
} else {
|
||||
AssertObjEqual(t, newUser, user, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active")
|
||||
AssertObjEqual(t, newUser, user, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday",
|
||||
"CompanyID", "ManagerID", "Active")
|
||||
}
|
||||
}
|
||||
|
||||
AssertObjEqual(t, user, expect, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active")
|
||||
AssertObjEqual(t, user, expect, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID",
|
||||
"ManagerID", "Active")
|
||||
|
||||
t.Run("Account", func(t *testing.T) {
|
||||
AssertObjEqual(t, user.Account, expect.Account, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "UserID", "Number")
|
||||
@ -112,8 +139,9 @@ func CheckUser(t *testing.T, user User, expect User) {
|
||||
t.Errorf("Account's foreign key should be saved")
|
||||
} else {
|
||||
var account Account
|
||||
DB.First(&account, "user_id = ?", user.ID)
|
||||
AssertObjEqual(t, account, user.Account, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "UserID", "Number")
|
||||
db(unscoped).First(&account, "user_id = ?", user.ID)
|
||||
AssertObjEqual(t, account, user.Account, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "UserID",
|
||||
"Number")
|
||||
}
|
||||
}
|
||||
})
|
||||
@ -135,7 +163,7 @@ func CheckUser(t *testing.T, user User, expect User) {
|
||||
if pet == nil || expect.Pets[idx] == nil {
|
||||
t.Errorf("pets#%v should equal, expect: %v, got %v", idx, expect.Pets[idx], pet)
|
||||
} else {
|
||||
CheckPet(t, *pet, *expect.Pets[idx])
|
||||
doCheckPet(t, *pet, *expect.Pets[idx], unscoped)
|
||||
}
|
||||
}
|
||||
})
|
||||
@ -172,8 +200,11 @@ func CheckUser(t *testing.T, user User, expect User) {
|
||||
t.Errorf("Manager's foreign key should be saved")
|
||||
} else {
|
||||
var manager User
|
||||
DB.First(&manager, "id = ?", *user.ManagerID)
|
||||
AssertObjEqual(t, manager, user.Manager, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active")
|
||||
db(unscoped).First(&manager, "id = ?", *user.ManagerID)
|
||||
AssertObjEqual(t, manager, user.Manager, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age",
|
||||
"Birthday", "CompanyID", "ManagerID", "Active")
|
||||
AssertObjEqual(t, manager, expect.Manager, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age",
|
||||
"Birthday", "CompanyID", "ManagerID", "Active")
|
||||
}
|
||||
} else if user.ManagerID != nil {
|
||||
t.Errorf("Manager should not be created for zero value, got: %+v", user.ManagerID)
|
||||
@ -194,7 +225,8 @@ func CheckUser(t *testing.T, user User, expect User) {
|
||||
})
|
||||
|
||||
for idx, team := range user.Team {
|
||||
AssertObjEqual(t, team, expect.Team[idx], "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active")
|
||||
AssertObjEqual(t, team, expect.Team[idx], "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age",
|
||||
"Birthday", "CompanyID", "ManagerID", "Active")
|
||||
}
|
||||
})
|
||||
|
||||
@ -229,7 +261,34 @@ func CheckUser(t *testing.T, user User, expect User) {
|
||||
})
|
||||
|
||||
for idx, friend := range user.Friends {
|
||||
AssertObjEqual(t, friend, expect.Friends[idx], "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active")
|
||||
AssertObjEqual(t, friend, expect.Friends[idx], "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age",
|
||||
"Birthday", "CompanyID", "ManagerID", "Active")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func tidbSkip(t *testing.T, reason string) {
|
||||
if isTiDB() {
|
||||
t.Skipf("This test case skipped, because of TiDB '%s'", reason)
|
||||
}
|
||||
}
|
||||
|
||||
func isTiDB() bool {
|
||||
return os.Getenv("GORM_DIALECT") == "tidb"
|
||||
}
|
||||
|
||||
func isMysql() bool {
|
||||
return os.Getenv("GORM_DIALECT") == "mysql"
|
||||
}
|
||||
|
||||
func isSqlite() bool {
|
||||
return os.Getenv("GORM_DIALECT") == "sqlite"
|
||||
}
|
||||
|
||||
func db(unscoped bool) *gorm.DB {
|
||||
if unscoped {
|
||||
return DB.Unscoped()
|
||||
} else {
|
||||
return DB
|
||||
}
|
||||
}
|
||||
|
@ -466,8 +466,9 @@ type Product4 struct {
|
||||
|
||||
type ProductItem struct {
|
||||
gorm.Model
|
||||
Code string
|
||||
Product4ID uint
|
||||
Code string
|
||||
Product4ID uint
|
||||
AfterFindCallTimes int
|
||||
}
|
||||
|
||||
func (pi ProductItem) BeforeCreate(*gorm.DB) error {
|
||||
@ -477,6 +478,11 @@ func (pi ProductItem) BeforeCreate(*gorm.DB) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pi *ProductItem) AfterFind(*gorm.DB) error {
|
||||
pi.AfterFindCallTimes = pi.AfterFindCallTimes + 1
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestFailedToSaveAssociationShouldRollback(t *testing.T) {
|
||||
DB.Migrator().DropTable(&Product4{}, &ProductItem{})
|
||||
DB.AutoMigrate(&Product4{}, &ProductItem{})
|
||||
@ -498,4 +504,65 @@ func TestFailedToSaveAssociationShouldRollback(t *testing.T) {
|
||||
if err := DB.First(&Product4{}, "name = ?", product.Name).Error; err != nil {
|
||||
t.Errorf("should find product, but got error %v", err)
|
||||
}
|
||||
|
||||
var productWithItem Product4
|
||||
if err := DB.Session(&gorm.Session{SkipHooks: true}).Preload("Item").First(&productWithItem, "name = ?", product.Name).Error; err != nil {
|
||||
t.Errorf("should find product, but got error %v", err)
|
||||
}
|
||||
|
||||
if productWithItem.Item.AfterFindCallTimes != 0 {
|
||||
t.Fatalf("AfterFind should not be called times:%d", productWithItem.Item.AfterFindCallTimes)
|
||||
}
|
||||
}
|
||||
|
||||
type Product5 struct {
|
||||
gorm.Model
|
||||
Name string
|
||||
}
|
||||
|
||||
var beforeUpdateCall int
|
||||
|
||||
func (p *Product5) BeforeUpdate(*gorm.DB) error {
|
||||
beforeUpdateCall = beforeUpdateCall + 1
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestUpdateCallbacks(t *testing.T) {
|
||||
DB.Migrator().DropTable(&Product5{})
|
||||
DB.AutoMigrate(&Product5{})
|
||||
|
||||
p := Product5{Name: "unique_code"}
|
||||
DB.Model(&Product5{}).Create(&p)
|
||||
|
||||
err := DB.Model(&Product5{}).Where("id", p.ID).Update("name", "update_name_1").Error
|
||||
if err != nil {
|
||||
t.Fatalf("should update success, but got err %v", err)
|
||||
}
|
||||
if beforeUpdateCall != 1 {
|
||||
t.Fatalf("before update should be called")
|
||||
}
|
||||
|
||||
err = DB.Model(Product5{}).Where("id", p.ID).Update("name", "update_name_2").Error
|
||||
if !errors.Is(err, gorm.ErrInvalidValue) {
|
||||
t.Fatalf("should got RecordNotFound, but got %v", err)
|
||||
}
|
||||
if beforeUpdateCall != 1 {
|
||||
t.Fatalf("before update should not be called")
|
||||
}
|
||||
|
||||
err = DB.Model([1]*Product5{&p}).Update("name", "update_name_3").Error
|
||||
if err != nil {
|
||||
t.Fatalf("should update success, but got err %v", err)
|
||||
}
|
||||
if beforeUpdateCall != 2 {
|
||||
t.Fatalf("before update should be called")
|
||||
}
|
||||
|
||||
err = DB.Model([1]Product5{p}).Update("name", "update_name_4").Error
|
||||
if !errors.Is(err, gorm.ErrInvalidValue) {
|
||||
t.Fatalf("should got RecordNotFound, but got %v", err)
|
||||
}
|
||||
if beforeUpdateCall != 2 {
|
||||
t.Fatalf("before update should not be called")
|
||||
}
|
||||
}
|
||||
|
@ -229,3 +229,174 @@ func TestJoinWithSoftDeleted(t *testing.T) {
|
||||
t.Fatalf("joins NamedPet and Account should not empty:%v", user2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInnerJoins(t *testing.T) {
|
||||
user := *GetUser("inner-joins-1", Config{Company: true, Manager: true, Account: true, NamedPet: false})
|
||||
|
||||
DB.Create(&user)
|
||||
|
||||
var user2 User
|
||||
var err error
|
||||
err = DB.InnerJoins("Company").InnerJoins("Manager").InnerJoins("Account").First(&user2, "users.name = ?", user.Name).Error
|
||||
AssertEqual(t, err, nil)
|
||||
CheckUser(t, user2, user)
|
||||
|
||||
// inner join and NamedPet is nil
|
||||
err = DB.InnerJoins("NamedPet").InnerJoins("Company").InnerJoins("Manager").InnerJoins("Account").First(&user2, "users.name = ?", user.Name).Error
|
||||
AssertEqual(t, err, gorm.ErrRecordNotFound)
|
||||
|
||||
// mixed inner join and left join
|
||||
var user3 User
|
||||
err = DB.Joins("NamedPet").InnerJoins("Company").InnerJoins("Manager").InnerJoins("Account").First(&user3, "users.name = ?", user.Name).Error
|
||||
AssertEqual(t, err, nil)
|
||||
CheckUser(t, user3, user)
|
||||
}
|
||||
|
||||
func TestJoinWithSameColumnName(t *testing.T) {
|
||||
user := GetUser("TestJoinWithSameColumnName", Config{
|
||||
Languages: 1,
|
||||
Pets: 1,
|
||||
})
|
||||
DB.Create(user)
|
||||
type UserSpeak struct {
|
||||
UserID uint
|
||||
LanguageCode string
|
||||
}
|
||||
type Result struct {
|
||||
User
|
||||
UserSpeak
|
||||
Language
|
||||
Pet
|
||||
}
|
||||
|
||||
results := make([]Result, 0, 1)
|
||||
DB.Select("users.*, user_speaks.*, languages.*, pets.*").Table("users").Joins("JOIN user_speaks ON user_speaks.user_id = users.id").
|
||||
Joins("JOIN languages ON languages.code = user_speaks.language_code").
|
||||
Joins("LEFT OUTER JOIN pets ON pets.user_id = users.id").Find(&results)
|
||||
|
||||
if len(results) == 0 {
|
||||
t.Fatalf("no record find")
|
||||
} else if results[0].Pet.UserID == nil || *(results[0].Pet.UserID) != user.ID {
|
||||
t.Fatalf("wrong user id in pet")
|
||||
} else if results[0].Pet.Name != user.Pets[0].Name {
|
||||
t.Fatalf("wrong pet name")
|
||||
}
|
||||
}
|
||||
|
||||
func TestJoinArgsWithDB(t *testing.T) {
|
||||
user := *GetUser("joins-args-db", Config{Pets: 2})
|
||||
DB.Save(&user)
|
||||
|
||||
// test where
|
||||
var user1 User
|
||||
onQuery := DB.Where(&Pet{Name: "joins-args-db_pet_2"})
|
||||
if err := DB.Joins("NamedPet", onQuery).Where("users.name = ?", user.Name).First(&user1).Error; err != nil {
|
||||
t.Fatalf("Failed to load with joins on, got error: %v", err)
|
||||
}
|
||||
|
||||
AssertEqual(t, user1.NamedPet.Name, "joins-args-db_pet_2")
|
||||
|
||||
// test where and omit
|
||||
onQuery2 := DB.Where(&Pet{Name: "joins-args-db_pet_2"}).Omit("Name")
|
||||
var user2 User
|
||||
if err := DB.Joins("NamedPet", onQuery2).Where("users.name = ?", user.Name).First(&user2).Error; err != nil {
|
||||
t.Fatalf("Failed to load with joins on, got error: %v", err)
|
||||
}
|
||||
AssertEqual(t, user2.NamedPet.ID, user1.NamedPet.ID)
|
||||
AssertEqual(t, user2.NamedPet.Name, "")
|
||||
|
||||
// test where and select
|
||||
onQuery3 := DB.Where(&Pet{Name: "joins-args-db_pet_2"}).Select("Name")
|
||||
var user3 User
|
||||
if err := DB.Joins("NamedPet", onQuery3).Where("users.name = ?", user.Name).First(&user3).Error; err != nil {
|
||||
t.Fatalf("Failed to load with joins on, got error: %v", err)
|
||||
}
|
||||
AssertEqual(t, user3.NamedPet.ID, 0)
|
||||
AssertEqual(t, user3.NamedPet.Name, "joins-args-db_pet_2")
|
||||
|
||||
// test select
|
||||
onQuery4 := DB.Select("ID")
|
||||
var user4 User
|
||||
if err := DB.Joins("NamedPet", onQuery4).Where("users.name = ?", user.Name).First(&user4).Error; err != nil {
|
||||
t.Fatalf("Failed to load with joins on, got error: %v", err)
|
||||
}
|
||||
if user4.NamedPet.ID == 0 {
|
||||
t.Fatal("Pet ID can not be empty")
|
||||
}
|
||||
AssertEqual(t, user4.NamedPet.Name, "")
|
||||
}
|
||||
|
||||
func TestNestedJoins(t *testing.T) {
|
||||
users := []User{
|
||||
{
|
||||
Name: "nested-joins-1",
|
||||
Manager: &User{
|
||||
Name: "nested-joins-manager-1",
|
||||
Company: Company{
|
||||
Name: "nested-joins-manager-company-1",
|
||||
},
|
||||
NamedPet: &Pet{
|
||||
Name: "nested-joins-manager-namepet-1",
|
||||
Toy: Toy{
|
||||
Name: "nested-joins-manager-namepet-toy-1",
|
||||
},
|
||||
},
|
||||
},
|
||||
NamedPet: &Pet{Name: "nested-joins-namepet-1", Toy: Toy{Name: "nested-joins-namepet-toy-1"}},
|
||||
},
|
||||
{
|
||||
Name: "nested-joins-2",
|
||||
Manager: GetUser("nested-joins-manager-2", Config{Company: true, NamedPet: true}),
|
||||
NamedPet: &Pet{Name: "nested-joins-namepet-2", Toy: Toy{Name: "nested-joins-namepet-toy-2"}},
|
||||
},
|
||||
}
|
||||
|
||||
DB.Create(&users)
|
||||
|
||||
var userIDs []uint
|
||||
for _, user := range users {
|
||||
userIDs = append(userIDs, user.ID)
|
||||
}
|
||||
|
||||
var users2 []User
|
||||
if err := DB.
|
||||
Joins("Manager").
|
||||
Joins("Manager.Company").
|
||||
Joins("Manager.NamedPet").
|
||||
Joins("Manager.NamedPet.Toy").
|
||||
Joins("NamedPet").
|
||||
Joins("NamedPet.Toy").
|
||||
Find(&users2, "users.id IN ?", userIDs).Error; err != nil {
|
||||
t.Fatalf("Failed to load with joins, got error: %v", err)
|
||||
} else if len(users2) != len(users) {
|
||||
t.Fatalf("Failed to load join users, got: %v, expect: %v", len(users2), len(users))
|
||||
}
|
||||
|
||||
sort.Slice(users2, func(i, j int) bool {
|
||||
return users2[i].ID > users2[j].ID
|
||||
})
|
||||
|
||||
sort.Slice(users, func(i, j int) bool {
|
||||
return users[i].ID > users[j].ID
|
||||
})
|
||||
|
||||
for idx, user := range users {
|
||||
// user
|
||||
CheckUser(t, user, users2[idx])
|
||||
if users2[idx].Manager == nil {
|
||||
t.Fatalf("Failed to load Manager")
|
||||
}
|
||||
// manager
|
||||
CheckUser(t, *user.Manager, *users2[idx].Manager)
|
||||
// user pet
|
||||
if users2[idx].NamedPet == nil {
|
||||
t.Fatalf("Failed to load NamedPet")
|
||||
}
|
||||
CheckPet(t, *user.NamedPet, *users2[idx].NamedPet)
|
||||
// manager pet
|
||||
if users2[idx].Manager.NamedPet == nil {
|
||||
t.Fatalf("Failed to load NamedPet")
|
||||
}
|
||||
CheckPet(t, *user.Manager.NamedPet, *users2[idx].Manager.NamedPet)
|
||||
}
|
||||
}
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -7,8 +7,60 @@ import (
|
||||
"github.com/google/uuid"
|
||||
"github.com/lib/pq"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
func TestPostgresReturningIDWhichHasStringType(t *testing.T) {
|
||||
if DB.Dialector.Name() != "postgres" {
|
||||
t.Skip()
|
||||
}
|
||||
|
||||
type Yasuo struct {
|
||||
ID string `gorm:"default:gen_random_uuid()"`
|
||||
Name string
|
||||
CreatedAt time.Time `gorm:"type:TIMESTAMP WITHOUT TIME ZONE"`
|
||||
UpdatedAt time.Time `gorm:"type:TIMESTAMP WITHOUT TIME ZONE;default:current_timestamp"`
|
||||
}
|
||||
|
||||
if err := DB.Exec("CREATE EXTENSION IF NOT EXISTS pgcrypto;").Error; err != nil {
|
||||
t.Errorf("Failed to create extension pgcrypto, got error %v", err)
|
||||
}
|
||||
|
||||
DB.Migrator().DropTable(&Yasuo{})
|
||||
|
||||
if err := DB.AutoMigrate(&Yasuo{}); err != nil {
|
||||
t.Fatalf("Failed to migrate for uuid default value, got error: %v", err)
|
||||
}
|
||||
|
||||
yasuo := Yasuo{Name: "jinzhu"}
|
||||
if err := DB.Create(&yasuo).Error; err != nil {
|
||||
t.Fatalf("should be able to create data, but got %v", err)
|
||||
}
|
||||
|
||||
if yasuo.ID == "" {
|
||||
t.Fatal("should be able to has ID, but got zero value")
|
||||
}
|
||||
|
||||
var result Yasuo
|
||||
if err := DB.First(&result, "id = ?", yasuo.ID).Error; err != nil || yasuo.Name != "jinzhu" {
|
||||
t.Errorf("No error should happen, but got %v", err)
|
||||
}
|
||||
|
||||
if err := DB.Where("id = $1", yasuo.ID).First(&Yasuo{}).Error; err != nil || yasuo.Name != "jinzhu" {
|
||||
t.Errorf("No error should happen, but got %v", err)
|
||||
}
|
||||
|
||||
yasuo.Name = "jinzhu1"
|
||||
if err := DB.Save(&yasuo).Error; err != nil {
|
||||
t.Errorf("Failed to update date, got error %v", err)
|
||||
}
|
||||
|
||||
if err := DB.First(&result, "id = ?", yasuo.ID).Error; err != nil || yasuo.Name != "jinzhu1" {
|
||||
t.Errorf("No error should happen, but got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostgres(t *testing.T) {
|
||||
if DB.Dialector.Name() != "postgres" {
|
||||
t.Skip()
|
||||
@ -60,16 +112,55 @@ func TestPostgres(t *testing.T) {
|
||||
if err := DB.First(&result, "id = ?", harumph.ID).Error; err != nil || harumph.Name != "jinzhu1" {
|
||||
t.Errorf("No error should happen, but got %v", err)
|
||||
}
|
||||
|
||||
DB.Migrator().DropTable("log_usage")
|
||||
|
||||
if err := DB.Exec(`
|
||||
CREATE TABLE public.log_usage (
|
||||
log_id bigint NOT NULL
|
||||
);
|
||||
|
||||
ALTER TABLE public.log_usage ALTER COLUMN log_id ADD GENERATED BY DEFAULT AS IDENTITY (
|
||||
SEQUENCE NAME public.log_usage_log_id_seq
|
||||
START WITH 1
|
||||
INCREMENT BY 1
|
||||
NO MINVALUE
|
||||
NO MAXVALUE
|
||||
CACHE 1
|
||||
);
|
||||
`).Error; err != nil {
|
||||
t.Fatalf("failed to create table, got error %v", err)
|
||||
}
|
||||
|
||||
columns, err := DB.Migrator().ColumnTypes("log_usage")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get columns, got error %v", err)
|
||||
}
|
||||
|
||||
hasLogID := false
|
||||
for _, column := range columns {
|
||||
if column.Name() == "log_id" {
|
||||
hasLogID = true
|
||||
autoIncrement, ok := column.AutoIncrement()
|
||||
if !ok || !autoIncrement {
|
||||
t.Fatalf("column log_id should be auto incrementment")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !hasLogID {
|
||||
t.Fatalf("failed to found column log_id")
|
||||
}
|
||||
}
|
||||
|
||||
type Post struct {
|
||||
ID uuid.UUID `gorm:"primary_key;type:uuid;default:uuid_generate_v4();autoincrement"`
|
||||
ID uuid.UUID `gorm:"primary_key;type:uuid;default:uuid_generate_v4();"`
|
||||
Title string
|
||||
Categories []*Category `gorm:"Many2Many:post_categories"`
|
||||
}
|
||||
|
||||
type Category struct {
|
||||
ID uuid.UUID `gorm:"primary_key;type:uuid;default:uuid_generate_v4();autoincrement"`
|
||||
ID uuid.UUID `gorm:"primary_key;type:uuid;default:uuid_generate_v4();"`
|
||||
Title string
|
||||
Posts []*Post `gorm:"Many2Many:post_categories"`
|
||||
}
|
||||
@ -98,3 +189,68 @@ func TestMany2ManyWithDefaultValueUUID(t *testing.T) {
|
||||
t.Errorf("Failed, got error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostgresOnConstraint(t *testing.T) {
|
||||
if DB.Dialector.Name() != "postgres" {
|
||||
t.Skip()
|
||||
}
|
||||
|
||||
type Thing struct {
|
||||
gorm.Model
|
||||
SomeID string
|
||||
OtherID string
|
||||
Data string
|
||||
}
|
||||
|
||||
DB.Migrator().DropTable(&Thing{})
|
||||
DB.Migrator().CreateTable(&Thing{})
|
||||
if err := DB.Exec("ALTER TABLE things ADD CONSTRAINT some_id_other_id_unique UNIQUE (some_id, other_id)").Error; err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
thing := Thing{
|
||||
SomeID: "1234",
|
||||
OtherID: "1234",
|
||||
Data: "something",
|
||||
}
|
||||
|
||||
DB.Create(&thing)
|
||||
|
||||
thing2 := Thing{
|
||||
SomeID: "1234",
|
||||
OtherID: "1234",
|
||||
Data: "something else",
|
||||
}
|
||||
|
||||
result := DB.Clauses(clause.OnConflict{
|
||||
OnConstraint: "some_id_other_id_unique",
|
||||
UpdateAll: true,
|
||||
}).Create(&thing2)
|
||||
if result.Error != nil {
|
||||
t.Errorf("creating second thing: %v", result.Error)
|
||||
}
|
||||
|
||||
var things []Thing
|
||||
if err := DB.Find(&things).Error; err != nil {
|
||||
t.Errorf("Failed, got error: %v", err)
|
||||
}
|
||||
|
||||
if len(things) > 1 {
|
||||
t.Errorf("expected 1 thing got more")
|
||||
}
|
||||
}
|
||||
|
||||
type CompanyNew struct {
|
||||
ID int
|
||||
Name int
|
||||
}
|
||||
|
||||
func TestAlterColumnDataType(t *testing.T) {
|
||||
DB.AutoMigrate(Company{})
|
||||
|
||||
if err := DB.Table("companies").Migrator().AlterColumn(CompanyNew{}, "name"); err != nil {
|
||||
t.Fatalf("failed to alter column from string to int, got error %v", err)
|
||||
}
|
||||
|
||||
DB.AutoMigrate(Company{})
|
||||
}
|
||||
|
@ -8,6 +8,8 @@ import (
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
@ -269,3 +271,242 @@ func TestPreloadWithDiffModel(t *testing.T) {
|
||||
|
||||
CheckUser(t, user, result.User)
|
||||
}
|
||||
|
||||
func TestNestedPreloadWithUnscoped(t *testing.T) {
|
||||
user := *GetUser("nested_preload", Config{Pets: 1})
|
||||
pet := user.Pets[0]
|
||||
pet.Toy = Toy{Name: "toy_nested_preload_" + strconv.Itoa(1)}
|
||||
pet.Toy = Toy{Name: "toy_nested_preload_" + strconv.Itoa(2)}
|
||||
|
||||
if err := DB.Create(&user).Error; err != nil {
|
||||
t.Fatalf("errors happened when create: %v", err)
|
||||
}
|
||||
|
||||
var user2 User
|
||||
DB.Preload("Pets.Toy").Find(&user2, "id = ?", user.ID)
|
||||
CheckUser(t, user2, user)
|
||||
|
||||
DB.Delete(&pet)
|
||||
|
||||
var user3 User
|
||||
DB.Preload(clause.Associations+"."+clause.Associations).Find(&user3, "id = ?", user.ID)
|
||||
if len(user3.Pets) != 0 {
|
||||
t.Fatalf("User.Pet[0] was deleted and should not exist.")
|
||||
}
|
||||
|
||||
var user4 *User
|
||||
DB.Preload("Pets.Toy").Find(&user4, "id = ?", user.ID)
|
||||
if len(user4.Pets) != 0 {
|
||||
t.Fatalf("User.Pet[0] was deleted and should not exist.")
|
||||
}
|
||||
|
||||
var user5 User
|
||||
DB.Unscoped().Preload(clause.Associations+"."+clause.Associations).Find(&user5, "id = ?", user.ID)
|
||||
CheckUserUnscoped(t, user5, user)
|
||||
|
||||
var user6 *User
|
||||
DB.Unscoped().Preload("Pets.Toy").Find(&user6, "id = ?", user.ID)
|
||||
CheckUserUnscoped(t, *user6, user)
|
||||
}
|
||||
|
||||
func TestNestedPreloadWithNestedJoin(t *testing.T) {
|
||||
type (
|
||||
Preload struct {
|
||||
ID uint
|
||||
Value string
|
||||
NestedID uint
|
||||
}
|
||||
Join struct {
|
||||
ID uint
|
||||
Value string
|
||||
NestedID uint
|
||||
}
|
||||
Nested struct {
|
||||
ID uint
|
||||
Preloads []*Preload
|
||||
Join Join
|
||||
ValueID uint
|
||||
}
|
||||
Value struct {
|
||||
ID uint
|
||||
Name string
|
||||
Nested Nested
|
||||
}
|
||||
)
|
||||
|
||||
DB.Migrator().DropTable(&Preload{}, &Join{}, &Nested{}, &Value{})
|
||||
DB.Migrator().AutoMigrate(&Preload{}, &Join{}, &Nested{}, &Value{})
|
||||
|
||||
value := Value{
|
||||
Name: "value",
|
||||
Nested: Nested{
|
||||
Preloads: []*Preload{
|
||||
{Value: "p1"}, {Value: "p2"},
|
||||
},
|
||||
Join: Join{Value: "j1"},
|
||||
},
|
||||
}
|
||||
if err := DB.Create(&value).Error; err != nil {
|
||||
t.Errorf("failed to create value, got err: %v", err)
|
||||
}
|
||||
|
||||
var find1 Value
|
||||
err := DB.Joins("Nested").Joins("Nested.Join").Preload("Nested.Preloads").First(&find1).Error
|
||||
if err != nil {
|
||||
t.Errorf("failed to find value, got err: %v", err)
|
||||
}
|
||||
AssertEqual(t, find1, value)
|
||||
|
||||
var find2 Value
|
||||
// Joins will automatically add Nested queries.
|
||||
err = DB.Joins("Nested.Join").Preload("Nested.Preloads").First(&find2).Error
|
||||
if err != nil {
|
||||
t.Errorf("failed to find value, got err: %v", err)
|
||||
}
|
||||
AssertEqual(t, find2, value)
|
||||
|
||||
var finds []Value
|
||||
err = DB.Joins("Nested.Join").Joins("Nested").Preload("Nested.Preloads").Find(&finds).Error
|
||||
if err != nil {
|
||||
t.Errorf("failed to find value, got err: %v", err)
|
||||
}
|
||||
require.Len(t, finds, 1)
|
||||
AssertEqual(t, finds[0], value)
|
||||
}
|
||||
|
||||
func TestEmbedPreload(t *testing.T) {
|
||||
type Country struct {
|
||||
ID int `gorm:"primaryKey"`
|
||||
Name string
|
||||
}
|
||||
type EmbeddedAddress struct {
|
||||
ID int
|
||||
Name string
|
||||
CountryID *int
|
||||
Country *Country
|
||||
}
|
||||
type NestedAddress struct {
|
||||
EmbeddedAddress
|
||||
}
|
||||
type Org struct {
|
||||
ID int
|
||||
PostalAddress EmbeddedAddress `gorm:"embedded;embeddedPrefix:postal_address_"`
|
||||
VisitingAddress EmbeddedAddress `gorm:"embedded;embeddedPrefix:visiting_address_"`
|
||||
AddressID *int
|
||||
Address *EmbeddedAddress
|
||||
NestedAddress NestedAddress `gorm:"embedded;embeddedPrefix:nested_address_"`
|
||||
}
|
||||
|
||||
DB.Migrator().DropTable(&Org{}, &EmbeddedAddress{}, &Country{})
|
||||
DB.AutoMigrate(&Org{}, &EmbeddedAddress{}, &Country{})
|
||||
|
||||
org := Org{
|
||||
PostalAddress: EmbeddedAddress{Name: "a1", Country: &Country{Name: "c1"}},
|
||||
VisitingAddress: EmbeddedAddress{Name: "a2", Country: &Country{Name: "c2"}},
|
||||
Address: &EmbeddedAddress{Name: "a3", Country: &Country{Name: "c3"}},
|
||||
NestedAddress: NestedAddress{
|
||||
EmbeddedAddress: EmbeddedAddress{Name: "a4", Country: &Country{Name: "c4"}},
|
||||
},
|
||||
}
|
||||
if err := DB.Create(&org).Error; err != nil {
|
||||
t.Errorf("failed to create org, got err: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
preloads map[string][]interface{}
|
||||
expect Org
|
||||
}{
|
||||
{
|
||||
name: "address country",
|
||||
preloads: map[string][]interface{}{"Address.Country": {}},
|
||||
expect: Org{
|
||||
ID: org.ID,
|
||||
PostalAddress: EmbeddedAddress{
|
||||
ID: org.PostalAddress.ID,
|
||||
Name: org.PostalAddress.Name,
|
||||
CountryID: org.PostalAddress.CountryID,
|
||||
Country: nil,
|
||||
},
|
||||
VisitingAddress: EmbeddedAddress{
|
||||
ID: org.VisitingAddress.ID,
|
||||
Name: org.VisitingAddress.Name,
|
||||
CountryID: org.VisitingAddress.CountryID,
|
||||
Country: nil,
|
||||
},
|
||||
AddressID: org.AddressID,
|
||||
Address: org.Address,
|
||||
NestedAddress: NestedAddress{EmbeddedAddress{
|
||||
ID: org.NestedAddress.ID,
|
||||
Name: org.NestedAddress.Name,
|
||||
CountryID: org.NestedAddress.CountryID,
|
||||
Country: nil,
|
||||
}},
|
||||
},
|
||||
}, {
|
||||
name: "postal address country",
|
||||
preloads: map[string][]interface{}{"PostalAddress.Country": {}},
|
||||
expect: Org{
|
||||
ID: org.ID,
|
||||
PostalAddress: org.PostalAddress,
|
||||
VisitingAddress: EmbeddedAddress{
|
||||
ID: org.VisitingAddress.ID,
|
||||
Name: org.VisitingAddress.Name,
|
||||
CountryID: org.VisitingAddress.CountryID,
|
||||
Country: nil,
|
||||
},
|
||||
AddressID: org.AddressID,
|
||||
Address: nil,
|
||||
NestedAddress: NestedAddress{EmbeddedAddress{
|
||||
ID: org.NestedAddress.ID,
|
||||
Name: org.NestedAddress.Name,
|
||||
CountryID: org.NestedAddress.CountryID,
|
||||
Country: nil,
|
||||
}},
|
||||
},
|
||||
}, {
|
||||
name: "nested address country",
|
||||
preloads: map[string][]interface{}{"NestedAddress.EmbeddedAddress.Country": {}},
|
||||
expect: Org{
|
||||
ID: org.ID,
|
||||
PostalAddress: EmbeddedAddress{
|
||||
ID: org.PostalAddress.ID,
|
||||
Name: org.PostalAddress.Name,
|
||||
CountryID: org.PostalAddress.CountryID,
|
||||
Country: nil,
|
||||
},
|
||||
VisitingAddress: EmbeddedAddress{
|
||||
ID: org.VisitingAddress.ID,
|
||||
Name: org.VisitingAddress.Name,
|
||||
CountryID: org.VisitingAddress.CountryID,
|
||||
Country: nil,
|
||||
},
|
||||
AddressID: org.AddressID,
|
||||
Address: nil,
|
||||
NestedAddress: org.NestedAddress,
|
||||
},
|
||||
}, {
|
||||
name: "associations",
|
||||
preloads: map[string][]interface{}{
|
||||
clause.Associations: {},
|
||||
// clause.Associations won’t preload nested associations
|
||||
"Address.Country": {},
|
||||
},
|
||||
expect: org,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
actual := Org{}
|
||||
tx := DB.Where("id = ?", org.ID).Session(&gorm.Session{})
|
||||
for name, args := range test.preloads {
|
||||
tx = tx.Preload(name, args...)
|
||||
}
|
||||
if err := tx.Find(&actual).Error; err != nil {
|
||||
t.Errorf("failed to find org, got err: %v", err)
|
||||
}
|
||||
AssertEqual(t, actual, test.expect)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -2,6 +2,8 @@ package tests_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@ -88,3 +90,80 @@ func TestPreparedStmtFromTransaction(t *testing.T) {
|
||||
}
|
||||
tx2.Commit()
|
||||
}
|
||||
|
||||
func TestPreparedStmtDeadlock(t *testing.T) {
|
||||
tx, err := OpenTestConnection(&gorm.Config{})
|
||||
AssertEqual(t, err, nil)
|
||||
|
||||
sqlDB, _ := tx.DB()
|
||||
sqlDB.SetMaxOpenConns(1)
|
||||
|
||||
tx = tx.Session(&gorm.Session{PrepareStmt: true})
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
for i := 0; i < 100; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
user := User{Name: "jinzhu"}
|
||||
tx.Create(&user)
|
||||
|
||||
var result User
|
||||
tx.First(&result)
|
||||
wg.Done()
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
conn, ok := tx.ConnPool.(*gorm.PreparedStmtDB)
|
||||
AssertEqual(t, ok, true)
|
||||
AssertEqual(t, len(conn.Stmts), 2)
|
||||
for _, stmt := range conn.Stmts {
|
||||
if stmt == nil {
|
||||
t.Fatalf("stmt cannot bee nil")
|
||||
}
|
||||
}
|
||||
|
||||
AssertEqual(t, sqlDB.Stats().InUse, 0)
|
||||
}
|
||||
|
||||
func TestPreparedStmtInTransaction(t *testing.T) {
|
||||
user := User{Name: "jinzhu"}
|
||||
|
||||
if err := DB.Transaction(func(tx *gorm.DB) error {
|
||||
tx.Session(&gorm.Session{PrepareStmt: true}).Create(&user)
|
||||
return errors.New("test")
|
||||
}); err == nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
var result User
|
||||
if err := DB.First(&result, user.ID).Error; err == nil {
|
||||
t.Errorf("Failed, got error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPreparedStmtReset(t *testing.T) {
|
||||
tx := DB.Session(&gorm.Session{PrepareStmt: true})
|
||||
|
||||
user := *GetUser("prepared_stmt_reset", Config{})
|
||||
tx = tx.Create(&user)
|
||||
|
||||
pdb, ok := tx.ConnPool.(*gorm.PreparedStmtDB)
|
||||
if !ok {
|
||||
t.Fatalf("should assign PreparedStatement Manager back to database when using PrepareStmt mode")
|
||||
}
|
||||
|
||||
pdb.Mux.Lock()
|
||||
if len(pdb.Stmts) == 0 {
|
||||
pdb.Mux.Unlock()
|
||||
t.Fatalf("prepared stmt can not be empty")
|
||||
}
|
||||
pdb.Mux.Unlock()
|
||||
|
||||
pdb.Reset()
|
||||
pdb.Mux.Lock()
|
||||
defer pdb.Mux.Unlock()
|
||||
if len(pdb.Stmts) != 0 {
|
||||
t.Fatalf("prepared stmt should be empty")
|
||||
}
|
||||
}
|
||||
|
@ -2,6 +2,7 @@ package tests_test
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"regexp"
|
||||
@ -216,6 +217,30 @@ func TestFind(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// test array
|
||||
var models2 [3]User
|
||||
if err := DB.Where("name in (?)", []string{"find"}).Find(&models2).Error; err != nil {
|
||||
t.Errorf("errors happened when query find with in clause: %v, length: %v", err, len(models2))
|
||||
} else {
|
||||
for idx, user := range users {
|
||||
t.Run("FindWithInClause#"+strconv.Itoa(idx+1), func(t *testing.T) {
|
||||
CheckUser(t, models2[idx], user)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// test smaller array
|
||||
var models3 [2]User
|
||||
if err := DB.Where("name in (?)", []string{"find"}).Find(&models3).Error; err != nil {
|
||||
t.Errorf("errors happened when query find with in clause: %v, length: %v", err, len(models3))
|
||||
} else {
|
||||
for idx, user := range users[:2] {
|
||||
t.Run("FindWithInClause#"+strconv.Itoa(idx+1), func(t *testing.T) {
|
||||
CheckUser(t, models3[idx], user)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
var none []User
|
||||
if err := DB.Where("name in (?)", []string{}).Find(&none).Error; err != nil || len(none) != 0 {
|
||||
t.Errorf("errors happened when query find with in clause and zero length parameter: %v, length: %v", err, len(none))
|
||||
@ -257,7 +282,7 @@ func TestFindInBatches(t *testing.T) {
|
||||
totalBatch int
|
||||
)
|
||||
|
||||
if result := DB.Where("name = ?", users[0].Name).FindInBatches(&results, 2, func(tx *gorm.DB, batch int) error {
|
||||
if result := DB.Table("users as u").Where("name = ?", users[0].Name).FindInBatches(&results, 2, func(tx *gorm.DB, batch int) error {
|
||||
totalBatch += batch
|
||||
|
||||
if tx.RowsAffected != 2 {
|
||||
@ -273,7 +298,7 @@ func TestFindInBatches(t *testing.T) {
|
||||
}
|
||||
|
||||
if err := tx.Save(results).Error; err != nil {
|
||||
t.Errorf("failed to save users, got error %v", err)
|
||||
t.Fatalf("failed to save users, got error %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@ -384,6 +409,13 @@ func TestFindInBatchesWithError(t *testing.T) {
|
||||
if totalBatch != 0 {
|
||||
t.Fatalf("incorrect total batch, expected: %v, got: %v", 0, totalBatch)
|
||||
}
|
||||
|
||||
if result := DB.Omit("id").Where("name = ?", users[0].Name).FindInBatches(&results, 2, func(tx *gorm.DB, batch int) error {
|
||||
totalBatch += batch
|
||||
return nil
|
||||
}); result.Error != gorm.ErrPrimaryKeyRequired {
|
||||
t.Fatal("expected errors to have occurred, but nothing happened")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFillSmallerStruct(t *testing.T) {
|
||||
@ -522,6 +554,11 @@ func TestNot(t *testing.T) {
|
||||
if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*users.*..*name.* <> .+ AND .*users.*..*age.* <> .+").MatchString(result.Statement.SQL.String()) {
|
||||
t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String())
|
||||
}
|
||||
|
||||
result = dryDB.Not(DB.Where("manager IS NULL").Where("age >= ?", 20)).Find(&User{})
|
||||
if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE NOT \\(manager IS NULL AND age >= .+\\) AND .users.\\..deleted_at. IS NULL").MatchString(result.Statement.SQL.String()) {
|
||||
t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestNotWithAllFields(t *testing.T) {
|
||||
@ -627,6 +664,18 @@ func TestOrWithAllFields(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
type Int64 int64
|
||||
|
||||
func (v Int64) Value() (driver.Value, error) {
|
||||
return v - 1, nil
|
||||
}
|
||||
|
||||
func (f *Int64) Scan(v interface{}) error {
|
||||
y := v.(int64)
|
||||
*f = Int64(y + 1)
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestPluck(t *testing.T) {
|
||||
users := []*User{
|
||||
GetUser("pluck-user1", Config{}),
|
||||
@ -654,6 +703,11 @@ func TestPluck(t *testing.T) {
|
||||
t.Errorf("got error when pluck id: %v", err)
|
||||
}
|
||||
|
||||
var ids2 []Int64
|
||||
if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Pluck("id", &ids2).Error; err != nil {
|
||||
t.Errorf("got error when pluck id: %v", err)
|
||||
}
|
||||
|
||||
for idx, name := range names {
|
||||
if name != users[idx].Name {
|
||||
t.Errorf("Unexpected result on pluck name, got %+v", names)
|
||||
@ -666,6 +720,12 @@ func TestPluck(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
for idx, id := range ids2 {
|
||||
if int(id) != int(users[idx].ID+1) {
|
||||
t.Errorf("Unexpected result on pluck id, got %+v", ids)
|
||||
}
|
||||
}
|
||||
|
||||
var times []time.Time
|
||||
if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Pluck("created_at", ×).Error; err != nil {
|
||||
t.Errorf("got error when pluck time: %v", err)
|
||||
@ -1063,12 +1123,12 @@ func TestSearchWithStruct(t *testing.T) {
|
||||
}
|
||||
|
||||
result = dryRunDB.Where(User{Name: "jinzhu", Age: 18}).Find(&User{})
|
||||
if !regexp.MustCompile(`WHERE .users.\..name. = .{1,3} AND .users.\..age. = .{1,3} AND .users.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) {
|
||||
if !regexp.MustCompile(`WHERE \(.users.\..name. = .{1,3} AND .users.\..age. = .{1,3}\) AND .users.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) {
|
||||
t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String())
|
||||
}
|
||||
|
||||
result = dryRunDB.Where(User{Name: "jinzhu"}, "name", "Age").Find(&User{})
|
||||
if !regexp.MustCompile(`WHERE .users.\..name. = .{1,3} AND .users.\..age. = .{1,3} AND .users.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) {
|
||||
if !regexp.MustCompile(`WHERE \(.users.\..name. = .{1,3} AND .users.\..age. = .{1,3}\) AND .users.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) {
|
||||
t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String())
|
||||
}
|
||||
|
||||
@ -1258,3 +1318,113 @@ func TestQueryScannerWithSingleColumn(t *testing.T) {
|
||||
|
||||
AssertEqual(t, result2.data, 20)
|
||||
}
|
||||
|
||||
func TestQueryResetNullValue(t *testing.T) {
|
||||
type QueryResetItem struct {
|
||||
ID string `gorm:"type:varchar(5)"`
|
||||
Name string
|
||||
}
|
||||
|
||||
type QueryResetNullValue struct {
|
||||
ID int
|
||||
Name string `gorm:"default:NULL"`
|
||||
Flag bool `gorm:"default:NULL"`
|
||||
Number1 int64 `gorm:"default:NULL"`
|
||||
Number2 uint64 `gorm:"default:NULL"`
|
||||
Number3 float64 `gorm:"default:NULL"`
|
||||
Now *time.Time `gorm:"defalut:NULL"`
|
||||
Item1Id string
|
||||
Item1 *QueryResetItem `gorm:"references:ID"`
|
||||
Item2Id string
|
||||
Item2 *QueryResetItem `gorm:"references:ID"`
|
||||
}
|
||||
|
||||
DB.Migrator().DropTable(&QueryResetNullValue{}, &QueryResetItem{})
|
||||
DB.AutoMigrate(&QueryResetNullValue{}, &QueryResetItem{})
|
||||
|
||||
now := time.Now()
|
||||
q1 := QueryResetNullValue{
|
||||
Name: "name",
|
||||
Flag: true,
|
||||
Number1: 100,
|
||||
Number2: 200,
|
||||
Number3: 300.1,
|
||||
Now: &now,
|
||||
Item1: &QueryResetItem{
|
||||
ID: "u_1_1",
|
||||
Name: "item_1_1",
|
||||
},
|
||||
Item2: &QueryResetItem{
|
||||
ID: "u_1_2",
|
||||
Name: "item_1_2",
|
||||
},
|
||||
}
|
||||
|
||||
q2 := QueryResetNullValue{
|
||||
Item1: &QueryResetItem{
|
||||
ID: "u_2_1",
|
||||
Name: "item_2_1",
|
||||
},
|
||||
Item2: &QueryResetItem{
|
||||
ID: "u_2_2",
|
||||
Name: "item_2_2",
|
||||
},
|
||||
}
|
||||
|
||||
var err error
|
||||
err = DB.Create(&q1).Error
|
||||
if err != nil {
|
||||
t.Errorf("failed to create:%v", err)
|
||||
}
|
||||
|
||||
err = DB.Create(&q2).Error
|
||||
if err != nil {
|
||||
t.Errorf("failed to create:%v", err)
|
||||
}
|
||||
|
||||
var qs []QueryResetNullValue
|
||||
err = DB.Joins("Item1").Joins("Item2").Find(&qs).Error
|
||||
if err != nil {
|
||||
t.Errorf("failed to find:%v", err)
|
||||
}
|
||||
|
||||
if len(qs) != 2 {
|
||||
t.Fatalf("find count not equal:%d", len(qs))
|
||||
}
|
||||
|
||||
AssertEqual(t, q1, qs[0])
|
||||
AssertEqual(t, q2, qs[1])
|
||||
}
|
||||
|
||||
func TestQueryError(t *testing.T) {
|
||||
type P struct{}
|
||||
var p1 P
|
||||
err := DB.Take(&p1, 1).Error
|
||||
AssertEqual(t, err, gorm.ErrModelAccessibleFieldsRequired)
|
||||
|
||||
var p2 interface{}
|
||||
|
||||
err = DB.Table("ps").Clauses(clause.Eq{Column: clause.Column{
|
||||
Table: clause.CurrentTable, Name: clause.PrimaryKey,
|
||||
}, Value: 1}).Scan(&p2).Error
|
||||
AssertEqual(t, err, gorm.ErrModelValueRequired)
|
||||
}
|
||||
|
||||
func TestQueryScanToArray(t *testing.T) {
|
||||
err := DB.Create(&User{Name: "testname1", Age: 10}).Error
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
users := [2]*User{{Name: "1"}, {Name: "2"}}
|
||||
err = DB.Model(&User{}).Where("name = ?", "testname1").Find(&users).Error
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if users[0] == nil || users[0].Name != "testname1" {
|
||||
t.Error("users[0] not covere")
|
||||
}
|
||||
if users[1] != nil {
|
||||
t.Error("users[1] should be empty")
|
||||
}
|
||||
}
|
||||
|
@ -214,4 +214,29 @@ func TestScanToEmbedded(t *testing.T) {
|
||||
if !addressMatched {
|
||||
t.Errorf("Failed, no address matched")
|
||||
}
|
||||
|
||||
personDupField := Person{ID: person1.ID}
|
||||
if err := DB.Select("people.id, people.*").
|
||||
First(&personDupField).Error; err != nil {
|
||||
t.Errorf("Failed to run join query, got error: %v", err)
|
||||
}
|
||||
AssertEqual(t, person1, personDupField)
|
||||
|
||||
user := User{
|
||||
Name: "TestScanToEmbedded_1",
|
||||
Manager: &User{
|
||||
Name: "TestScanToEmbedded_1_m1",
|
||||
Manager: &User{Name: "TestScanToEmbedded_1_m1_m1"},
|
||||
},
|
||||
}
|
||||
DB.Create(&user)
|
||||
|
||||
type UserScan struct {
|
||||
ID uint
|
||||
Name string
|
||||
ManagerID *uint
|
||||
}
|
||||
var user2 UserScan
|
||||
err := DB.Raw("SELECT * FROM users INNER JOIN users Manager ON users.manager_id = Manager.id WHERE users.id = ?", user.ID).Scan(&user2).Error
|
||||
AssertEqual(t, err, nil)
|
||||
}
|
||||
|
@ -170,10 +170,10 @@ func (data *EncryptedData) Scan(value interface{}) error {
|
||||
return errors.New("Too short")
|
||||
}
|
||||
|
||||
*data = b[3:]
|
||||
*data = append((*data)[0:], b[3:]...)
|
||||
return nil
|
||||
} else if s, ok := value.(string); ok {
|
||||
*data = []byte(s)[3:]
|
||||
*data = []byte(s[3:])
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -72,3 +72,58 @@ func TestScopes(t *testing.T) {
|
||||
t.Errorf("select max(id)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestComplexScopes(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
queryFn func(tx *gorm.DB) *gorm.DB
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "depth_1",
|
||||
queryFn: func(tx *gorm.DB) *gorm.DB {
|
||||
return tx.Scopes(
|
||||
func(d *gorm.DB) *gorm.DB { return d.Where("a = 1") },
|
||||
func(d *gorm.DB) *gorm.DB {
|
||||
return d.Where(DB.Or("b = 2").Or("c = 3"))
|
||||
},
|
||||
).Find(&Language{})
|
||||
},
|
||||
expected: `SELECT * FROM "languages" WHERE a = 1 AND (b = 2 OR c = 3)`,
|
||||
}, {
|
||||
name: "depth_1_pre_cond",
|
||||
queryFn: func(tx *gorm.DB) *gorm.DB {
|
||||
return tx.Where("z = 0").Scopes(
|
||||
func(d *gorm.DB) *gorm.DB { return d.Where("a = 1") },
|
||||
func(d *gorm.DB) *gorm.DB {
|
||||
return d.Or(DB.Where("b = 2").Or("c = 3"))
|
||||
},
|
||||
).Find(&Language{})
|
||||
},
|
||||
expected: `SELECT * FROM "languages" WHERE z = 0 AND a = 1 OR (b = 2 OR c = 3)`,
|
||||
}, {
|
||||
name: "depth_2",
|
||||
queryFn: func(tx *gorm.DB) *gorm.DB {
|
||||
return tx.Scopes(
|
||||
func(d *gorm.DB) *gorm.DB { return d.Model(&Language{}) },
|
||||
func(d *gorm.DB) *gorm.DB {
|
||||
return d.
|
||||
Or(DB.Scopes(
|
||||
func(d *gorm.DB) *gorm.DB { return d.Where("a = 1") },
|
||||
func(d *gorm.DB) *gorm.DB { return d.Where("b = 2") },
|
||||
)).
|
||||
Or("c = 3")
|
||||
},
|
||||
func(d *gorm.DB) *gorm.DB { return d.Where("d = 4") },
|
||||
).Find(&Language{})
|
||||
},
|
||||
expected: `SELECT * FROM "languages" WHERE d = 4 OR c = 3 OR (a = 1 AND b = 2)`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
assertEqualSQL(t, test.expected, DB.ToSQL(test.queryFn))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -16,13 +16,40 @@ import (
|
||||
|
||||
type SerializerStruct struct {
|
||||
gorm.Model
|
||||
Name []byte `gorm:"json"`
|
||||
Roles Roles `gorm:"serializer:json"`
|
||||
Contracts map[string]interface{} `gorm:"serializer:json"`
|
||||
JobInfo Job `gorm:"type:bytes;serializer:gob"`
|
||||
CreatedTime int64 `gorm:"serializer:unixtime;type:time"` // store time in db, use int as field type
|
||||
UpdatedTime *int64 `gorm:"serializer:unixtime;type:time"` // store time in db, use int as field type
|
||||
EncryptedString EncryptedString
|
||||
Name []byte `gorm:"json"`
|
||||
Roles Roles `gorm:"serializer:json"`
|
||||
Roles2 *Roles `gorm:"serializer:json"`
|
||||
Roles3 *Roles `gorm:"serializer:json;not null"`
|
||||
Contracts map[string]interface{} `gorm:"serializer:json"`
|
||||
JobInfo Job `gorm:"type:bytes;serializer:gob"`
|
||||
CreatedTime int64 `gorm:"serializer:unixtime;type:datetime"` // store time in db, use int as field type
|
||||
UpdatedTime *int64 `gorm:"serializer:unixtime;type:datetime"` // store time in db, use int as field type
|
||||
CustomSerializerString string `gorm:"serializer:custom"`
|
||||
EncryptedString EncryptedString
|
||||
}
|
||||
|
||||
type SerializerPostgresStruct struct {
|
||||
gorm.Model
|
||||
Name []byte `gorm:"json"`
|
||||
Roles Roles `gorm:"serializer:json"`
|
||||
Roles2 *Roles `gorm:"serializer:json"`
|
||||
Roles3 *Roles `gorm:"serializer:json;not null"`
|
||||
Contracts map[string]interface{} `gorm:"serializer:json"`
|
||||
JobInfo Job `gorm:"type:bytes;serializer:gob"`
|
||||
CreatedTime int64 `gorm:"serializer:unixtime;type:timestamptz"` // store time in db, use int as field type
|
||||
UpdatedTime *int64 `gorm:"serializer:unixtime;type:timestamptz"` // store time in db, use int as field type
|
||||
CustomSerializerString string `gorm:"serializer:custom"`
|
||||
EncryptedString EncryptedString
|
||||
}
|
||||
|
||||
func (*SerializerPostgresStruct) TableName() string { return "serializer_structs" }
|
||||
|
||||
func adaptorSerializerModel(s *SerializerStruct) interface{} {
|
||||
if DB.Dialector.Name() == "postgres" {
|
||||
sps := SerializerPostgresStruct(*s)
|
||||
return &sps
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
type Roles []string
|
||||
@ -52,9 +79,34 @@ func (es EncryptedString) Value(ctx context.Context, field *schema.Field, dst re
|
||||
return "hello" + string(es), nil
|
||||
}
|
||||
|
||||
type CustomSerializer struct {
|
||||
prefix []byte
|
||||
}
|
||||
|
||||
func NewCustomSerializer(prefix string) *CustomSerializer {
|
||||
return &CustomSerializer{prefix: []byte(prefix)}
|
||||
}
|
||||
|
||||
func (c *CustomSerializer) Scan(ctx context.Context, field *schema.Field, dst reflect.Value, dbValue interface{}) (err error) {
|
||||
switch value := dbValue.(type) {
|
||||
case []byte:
|
||||
err = field.Set(ctx, dst, bytes.TrimPrefix(value, c.prefix))
|
||||
case string:
|
||||
err = field.Set(ctx, dst, strings.TrimPrefix(value, string(c.prefix)))
|
||||
default:
|
||||
err = fmt.Errorf("unsupported data %#v", dbValue)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *CustomSerializer) Value(ctx context.Context, field *schema.Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) {
|
||||
return fmt.Sprintf("%s%s", c.prefix, fieldValue), nil
|
||||
}
|
||||
|
||||
func TestSerializer(t *testing.T) {
|
||||
DB.Migrator().DropTable(&SerializerStruct{})
|
||||
if err := DB.Migrator().AutoMigrate(&SerializerStruct{}); err != nil {
|
||||
schema.RegisterSerializer("custom", NewCustomSerializer("hello"))
|
||||
DB.Migrator().DropTable(adaptorSerializerModel(&SerializerStruct{}))
|
||||
if err := DB.Migrator().AutoMigrate(adaptorSerializerModel(&SerializerStruct{})); err != nil {
|
||||
t.Fatalf("no error should happen when migrate scanner, valuer struct, got error %v", err)
|
||||
}
|
||||
|
||||
@ -74,12 +126,42 @@ func TestSerializer(t *testing.T) {
|
||||
Location: "Kenmawr",
|
||||
IsIntern: false,
|
||||
},
|
||||
CustomSerializerString: "world",
|
||||
}
|
||||
|
||||
if err := DB.Create(&data).Error; err != nil {
|
||||
t.Fatalf("failed to create data, got error %v", err)
|
||||
}
|
||||
|
||||
var result SerializerStruct
|
||||
if err := DB.Where("roles2 IS NULL AND roles3 = ?", "").First(&result, data.ID).Error; err != nil {
|
||||
t.Fatalf("failed to query data, got error %v", err)
|
||||
}
|
||||
|
||||
AssertEqual(t, result, data)
|
||||
|
||||
if err := DB.Model(&result).Update("roles", "").Error; err != nil {
|
||||
t.Fatalf("failed to update data's roles, got error %v", err)
|
||||
}
|
||||
|
||||
if err := DB.First(&result, data.ID).Error; err != nil {
|
||||
t.Fatalf("failed to query data, got error %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSerializerZeroValue(t *testing.T) {
|
||||
schema.RegisterSerializer("custom", NewCustomSerializer("hello"))
|
||||
DB.Migrator().DropTable(adaptorSerializerModel(&SerializerStruct{}))
|
||||
if err := DB.Migrator().AutoMigrate(adaptorSerializerModel(&SerializerStruct{})); err != nil {
|
||||
t.Fatalf("no error should happen when migrate scanner, valuer struct, got error %v", err)
|
||||
}
|
||||
|
||||
data := SerializerStruct{}
|
||||
|
||||
if err := DB.Create(&data).Error; err != nil {
|
||||
t.Fatalf("failed to create data, got error %v", err)
|
||||
}
|
||||
|
||||
var result SerializerStruct
|
||||
if err := DB.First(&result, data.ID).Error; err != nil {
|
||||
t.Fatalf("failed to query data, got error %v", err)
|
||||
@ -87,11 +169,19 @@ func TestSerializer(t *testing.T) {
|
||||
|
||||
AssertEqual(t, result, data)
|
||||
|
||||
if err := DB.Model(&result).Update("roles", "").Error; err != nil {
|
||||
t.Fatalf("failed to update data's roles, got error %v", err)
|
||||
}
|
||||
|
||||
if err := DB.First(&result, data.ID).Error; err != nil {
|
||||
t.Fatalf("failed to query data, got error %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSerializerAssignFirstOrCreate(t *testing.T) {
|
||||
DB.Migrator().DropTable(&SerializerStruct{})
|
||||
if err := DB.Migrator().AutoMigrate(&SerializerStruct{}); err != nil {
|
||||
schema.RegisterSerializer("custom", NewCustomSerializer("hello"))
|
||||
DB.Migrator().DropTable(adaptorSerializerModel(&SerializerStruct{}))
|
||||
if err := DB.Migrator().AutoMigrate(adaptorSerializerModel(&SerializerStruct{})); err != nil {
|
||||
t.Fatalf("no error should happen when migrate scanner, valuer struct, got error %v", err)
|
||||
}
|
||||
|
||||
@ -109,6 +199,7 @@ func TestSerializerAssignFirstOrCreate(t *testing.T) {
|
||||
Location: "Shadyside",
|
||||
IsIntern: false,
|
||||
},
|
||||
CustomSerializerString: "world",
|
||||
}
|
||||
|
||||
// first time insert record
|
||||
@ -123,7 +214,7 @@ func TestSerializerAssignFirstOrCreate(t *testing.T) {
|
||||
}
|
||||
AssertEqual(t, result, out)
|
||||
|
||||
//update record
|
||||
// update record
|
||||
data.Roles = append(data.Roles, "r3")
|
||||
data.JobInfo.Location = "Gates Hillman Complex"
|
||||
if err := DB.Assign(data).FirstOrCreate(&out).Error; err != nil {
|
||||
|
@ -7,6 +7,7 @@ import (
|
||||
"regexp"
|
||||
"testing"
|
||||
|
||||
"github.com/jinzhu/now"
|
||||
"gorm.io/gorm"
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
)
|
||||
@ -39,6 +40,11 @@ func TestSoftDelete(t *testing.T) {
|
||||
t.Fatalf("invalid sql generated, got %v", sql)
|
||||
}
|
||||
|
||||
sql = DB.Session(&gorm.Session{DryRun: true}).Table("user u").Select("name").Find(&User{}).Statement.SQL.String()
|
||||
if !regexp.MustCompile(`SELECT .name. FROM user u WHERE .u.\..deleted_at. IS NULL`).MatchString(sql) {
|
||||
t.Errorf("Table with escape character, got %v", sql)
|
||||
}
|
||||
|
||||
if DB.First(&User{}, "name = ?", user.Name).Error == nil {
|
||||
t.Errorf("Can't find a soft deleted record")
|
||||
}
|
||||
@ -93,3 +99,71 @@ func TestDeletedAtOneOr(t *testing.T) {
|
||||
t.Fatalf("invalid sql generated, got %v", actualSQL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSoftDeleteZeroValue(t *testing.T) {
|
||||
type SoftDeleteBook struct {
|
||||
ID uint
|
||||
Name string
|
||||
Pages uint
|
||||
DeletedAt gorm.DeletedAt `gorm:"zeroValue:'1970-01-01 00:00:01'"`
|
||||
}
|
||||
DB.Migrator().DropTable(&SoftDeleteBook{})
|
||||
if err := DB.AutoMigrate(&SoftDeleteBook{}); err != nil {
|
||||
t.Fatalf("failed to auto migrate soft delete table")
|
||||
}
|
||||
|
||||
book := SoftDeleteBook{Name: "jinzhu", Pages: 10}
|
||||
DB.Save(&book)
|
||||
|
||||
var count int64
|
||||
if DB.Model(&SoftDeleteBook{}).Where("name = ?", book.Name).Count(&count).Error != nil || count != 1 {
|
||||
t.Errorf("Count soft deleted record, expects: %v, got: %v", 1, count)
|
||||
}
|
||||
|
||||
var pages uint
|
||||
if DB.Model(&SoftDeleteBook{}).Select("pages").Where("name = ?", book.Name).Scan(&pages).Error != nil || pages != book.Pages {
|
||||
t.Errorf("Pages soft deleted record, expects: %v, got: %v", 0, pages)
|
||||
}
|
||||
|
||||
if err := DB.Delete(&book).Error; err != nil {
|
||||
t.Fatalf("No error should happen when soft delete user, but got %v", err)
|
||||
}
|
||||
|
||||
zeroTime, _ := now.Parse("1970-01-01 00:00:01")
|
||||
if book.DeletedAt.Time.Equal(zeroTime) {
|
||||
t.Errorf("book's deleted at should not be zero, DeletedAt: %v", book.DeletedAt)
|
||||
}
|
||||
|
||||
if DB.First(&SoftDeleteBook{}, "name = ?", book.Name).Error == nil {
|
||||
t.Errorf("Can't find a soft deleted record")
|
||||
}
|
||||
|
||||
count = 0
|
||||
if DB.Model(&SoftDeleteBook{}).Where("name = ?", book.Name).Count(&count).Error != nil || count != 0 {
|
||||
t.Errorf("Count soft deleted record, expects: %v, got: %v", 0, count)
|
||||
}
|
||||
|
||||
pages = 0
|
||||
if err := DB.Model(&SoftDeleteBook{}).Select("pages").Where("name = ?", book.Name).Scan(&pages).Error; err != nil || pages != 0 {
|
||||
t.Fatalf("Age soft deleted record, expects: %v, got: %v, err %v", 0, pages, err)
|
||||
}
|
||||
|
||||
if err := DB.Unscoped().First(&SoftDeleteBook{}, "name = ?", book.Name).Error; err != nil {
|
||||
t.Errorf("Should find soft deleted record with Unscoped, but got err %s", err)
|
||||
}
|
||||
|
||||
count = 0
|
||||
if DB.Unscoped().Model(&SoftDeleteBook{}).Where("name = ?", book.Name).Count(&count).Error != nil || count != 1 {
|
||||
t.Errorf("Count soft deleted record, expects: %v, count: %v", 1, count)
|
||||
}
|
||||
|
||||
pages = 0
|
||||
if DB.Unscoped().Model(&SoftDeleteBook{}).Select("pages").Where("name = ?", book.Name).Scan(&pages).Error != nil || pages != book.Pages {
|
||||
t.Errorf("Age soft deleted record, expects: %v, got: %v", 0, pages)
|
||||
}
|
||||
|
||||
DB.Unscoped().Delete(&book)
|
||||
if err := DB.Unscoped().First(&SoftDeleteBook{}, "name = ?", book.Name).Error; !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
t.Errorf("Can't find permanently deleted record")
|
||||
}
|
||||
}
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user