Merge branch 'master' into detached
Commit 31d2ea8ca4
.github/workflows/invalid_question.yml (vendored, 2 changed lines)
@@ -16,7 +16,7 @@ jobs:
       ACTIONS_STEP_DEBUG: true
     steps:
       - name: Close Stale Issues
-        uses: actions/stale@v5
+        uses: actions/stale@v8
         with:
           repo-token: ${{ secrets.GITHUB_TOKEN }}
           stale-issue-message: "This issue has been marked as invalid question, please give more information by following the `Question` template, if you believe there is a bug of GORM, please create a pull request that could reproduce the issue on [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground), the issue will be closed in 30 days if no further activity occurs. most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨"
.github/workflows/labeler.yml (vendored, 2 changed lines)
@@ -11,7 +11,7 @@ jobs:
     name: Label issues and pull requests
     steps:
       - name: check out
-        uses: actions/checkout@v3
+        uses: actions/checkout@v4

       - name: labeler
        uses: jinzhu/super-labeler-action@develop
.github/workflows/missing_playground.yml (vendored, 2 changed lines)
@@ -16,7 +16,7 @@ jobs:
       ACTIONS_STEP_DEBUG: true
     steps:
       - name: Close Stale Issues
-        uses: actions/stale@v5
+        uses: actions/stale@v8
         with:
           repo-token: ${{ secrets.GITHUB_TOKEN }}
           stale-issue-message: "The issue has been automatically marked as stale as it missing playground pull request link, which is important to help others understand your issue effectively and make sure the issue hasn't been fixed on latest master, checkout [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground) for details. it will be closed in 30 days if no further activity occurs. if you are asking question, please use the `Question` template, most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨"
.github/workflows/reviewdog.yml (vendored, 2 changed lines)
@@ -6,7 +6,7 @@ jobs:
     runs-on: ubuntu-latest
     steps:
       - name: Check out code into the Go module directory
-        uses: actions/checkout@v3
+        uses: actions/checkout@v4
       - name: golangci-lint
        uses: reviewdog/action-golangci-lint@v2

.github/workflows/stale.yml (vendored, 2 changed lines)
@@ -16,7 +16,7 @@ jobs:
       ACTIONS_STEP_DEBUG: true
     steps:
       - name: Close Stale Issues
-        uses: actions/stale@v5
+        uses: actions/stale@v8
         with:
           repo-token: ${{ secrets.GITHUB_TOKEN }}
           stale-issue-message: "This issue has been automatically marked as stale because it has been open 360 days with no activity. Remove stale label or comment or this will be closed in 180 days"
.github/workflows/tests.yml (vendored, 111 changed lines)
@@ -16,21 +16,21 @@ jobs:
   sqlite:
     strategy:
       matrix:
-        go: ['1.18', '1.17', '1.16']
+        go: ['1.21', '1.20', '1.19']
         platform: [ubuntu-latest] # can not run in windows OS
     runs-on: ${{ matrix.platform }}

     steps:
       - name: Set up Go 1.x
-        uses: actions/setup-go@v3
+        uses: actions/setup-go@v4
         with:
           go-version: ${{ matrix.go }}

       - name: Check out code into the Go module directory
-        uses: actions/checkout@v3
+        uses: actions/checkout@v4

       - name: go mod package cache
-        uses: actions/cache@v3
+        uses: actions/cache@v4
         with:
           path: ~/go/pkg/mod
           key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}

@@ -41,8 +41,8 @@ jobs:
   mysql:
     strategy:
       matrix:
-        dbversion: ['mysql:latest', 'mysql:5.7', 'mariadb:latest']
+        dbversion: ['mysql:latest', 'mysql:5.7']
-        go: ['1.18', '1.17', '1.16']
+        go: ['1.21', '1.20', '1.19']
         platform: [ubuntu-latest]
     runs-on: ${{ matrix.platform }}

@@ -65,16 +65,15 @@ jobs:

     steps:
       - name: Set up Go 1.x
-        uses: actions/setup-go@v3
+        uses: actions/setup-go@v4
         with:
           go-version: ${{ matrix.go }}

       - name: Check out code into the Go module directory
-        uses: actions/checkout@v3
+        uses: actions/checkout@v4

-
       - name: go mod package cache
-        uses: actions/cache@v3
+        uses: actions/cache@v4
         with:
           path: ~/go/pkg/mod
           key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}

@@ -82,11 +81,54 @@ jobs:
       - name: Tests
         run: GITHUB_ACTION=true GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True" ./tests/tests_all.sh

+  mariadb:
+    strategy:
+      matrix:
+        dbversion: [ 'mariadb:latest' ]
+        go: ['1.21', '1.20', '1.19']
+        platform: [ ubuntu-latest ]
+    runs-on: ${{ matrix.platform }}
+
+    services:
+      mysql:
+        image: ${{ matrix.dbversion }}
+        env:
+          MYSQL_DATABASE: gorm
+          MYSQL_USER: gorm
+          MYSQL_PASSWORD: gorm
+          MYSQL_RANDOM_ROOT_PASSWORD: "yes"
+        ports:
+          - 9910:3306
+        options: >-
+          --health-cmd "mariadb-admin ping -ugorm -pgorm"
+          --health-interval 10s
+          --health-start-period 10s
+          --health-timeout 5s
+          --health-retries 10
+
+    steps:
+      - name: Set up Go 1.x
+        uses: actions/setup-go@v4
+        with:
+          go-version: ${{ matrix.go }}
+
+      - name: Check out code into the Go module directory
+        uses: actions/checkout@v4
+
+      - name: go mod package cache
+        uses: actions/cache@v4
+        with:
+          path: ~/go/pkg/mod
+          key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
+
+      - name: Tests
+        run: GITHUB_ACTION=true GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True" ./tests/tests_all.sh
+
   postgres:
     strategy:
       matrix:
         dbversion: ['postgres:latest', 'postgres:13', 'postgres:12', 'postgres:11', 'postgres:10']
-        go: ['1.18', '1.17', '1.16']
+        go: ['1.21', '1.20', '1.19']
         platform: [ubuntu-latest] # can not run in macOS and Windows
     runs-on: ${{ matrix.platform }}

@@ -109,15 +151,15 @@ jobs:

     steps:
       - name: Set up Go 1.x
-        uses: actions/setup-go@v3
+        uses: actions/setup-go@v4
         with:
           go-version: ${{ matrix.go }}

       - name: Check out code into the Go module directory
-        uses: actions/checkout@v3
+        uses: actions/checkout@v4

       - name: go mod package cache
-        uses: actions/cache@v3
+        uses: actions/cache@v4
         with:
           path: ~/go/pkg/mod
           key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}

@@ -128,7 +170,7 @@ jobs:
   sqlserver:
     strategy:
       matrix:
-        go: ['1.18', '1.17', '1.16']
+        go: ['1.21', '1.20', '1.19']
         platform: [ubuntu-latest] # can not run test in macOS and windows
     runs-on: ${{ matrix.platform }}

@@ -152,18 +194,51 @@ jobs:

     steps:
       - name: Set up Go 1.x
-        uses: actions/setup-go@v3
+        uses: actions/setup-go@v4
         with:
           go-version: ${{ matrix.go }}

       - name: Check out code into the Go module directory
-        uses: actions/checkout@v3
+        uses: actions/checkout@v4

       - name: go mod package cache
-        uses: actions/cache@v3
+        uses: actions/cache@v4
         with:
           path: ~/go/pkg/mod
           key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}

       - name: Tests
         run: GITHUB_ACTION=true GORM_DIALECT=sqlserver GORM_DSN="sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm" ./tests/tests_all.sh

+  tidb:
+    strategy:
+      matrix:
+        dbversion: [ 'v6.5.0' ]
+        go: ['1.21', '1.20', '1.19']
+        platform: [ ubuntu-latest ]
+    runs-on: ${{ matrix.platform }}
+
+    steps:
+      - name: Setup TiDB
+        uses: Icemap/tidb-action@main
+        with:
+          port: 9940
+          version: ${{matrix.dbversion}}
+
+      - name: Set up Go 1.x
+        uses: actions/setup-go@v4
+        with:
+          go-version: ${{ matrix.go }}
+
+      - name: Check out code into the Go module directory
+        uses: actions/checkout@v4
+
+
+      - name: go mod package cache
+        uses: actions/cache@v4
+        with:
+          path: ~/go/pkg/mod
+          key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
+
+      - name: Tests
+        run: GITHUB_ACTION=true GORM_DIALECT=tidb GORM_DSN="root:@tcp(localhost:9940)/test?charset=utf8&parseTime=True&loc=Local" ./tests/tests_all.sh
.gitignore (vendored, 1 changed line)
@@ -4,3 +4,4 @@ coverage.txt
 _book
 .idea
 vendor
+.vscode
@@ -9,3 +9,12 @@ linters:
   - prealloc
   - unconvert
   - unparam
+  - goimports
+  - whitespace
+
+linters-settings:
+  whitespace:
+    multi-func: true
+  goimports:
+    local-prefixes: gorm.io/gorm
+
README.md (11 changed lines)
@@ -4,9 +4,6 @@ The fantastic ORM library for Golang, aims to be developer friendly.

 [](https://goreportcard.com/report/github.com/go-gorm/gorm)
 [](https://github.com/go-gorm/gorm/actions)
-[](https://gitter.im/jinzhu/gorm?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)
-[](https://opencollective.com/gorm)
-[](https://opencollective.com/gorm)
 [](https://opensource.org/licenses/MIT)
 [](https://pkg.go.dev/gorm.io/gorm?tab=doc)

@@ -30,14 +27,18 @@ The fantastic ORM library for Golang, aims to be developer friendly.
 ## Getting Started

 * GORM Guides [https://gorm.io](https://gorm.io)
-* GORM Gen [gorm/gen](https://github.com/go-gorm/gen#gormgen)
+* Gen Guides [https://gorm.io/gen/index.html](https://gorm.io/gen/index.html)

 ## Contributing

 [You can help to deliver a better GORM, check out things you can do](https://gorm.io/contribute.html)

+## Contributors
+
+[Thank you](https://github.com/go-gorm/gorm/graphs/contributors) for contributing to the GORM framework!
+
 ## License

 © Jinzhu, 2013~time.Now

-Released under the [MIT License](https://github.com/go-gorm/gorm/blob/master/License)
+Released under the [MIT License](https://github.com/go-gorm/gorm/blob/master/LICENSE)
@@ -14,6 +14,7 @@ import (
 type Association struct {
     DB           *DB
     Relationship *schema.Relationship
+    Unscope      bool
     Error        error
 }

@@ -40,6 +41,15 @@ func (db *DB) Association(column string) *Association {
     return association
 }

+func (association *Association) Unscoped() *Association {
+    return &Association{
+        DB:           association.DB,
+        Relationship: association.Relationship,
+        Error:        association.Error,
+        Unscope:      true,
+    }
+}
+
 func (association *Association) Find(out interface{}, conds ...interface{}) error {
     if association.Error == nil {
         association.Error = association.buildCondition().Find(out, conds...).Error

@@ -64,14 +74,30 @@ func (association *Association) Append(values ...interface{}) error {

 func (association *Association) Replace(values ...interface{}) error {
     if association.Error == nil {
+        reflectValue := association.DB.Statement.ReflectValue
+        rel := association.Relationship
+
+        var oldBelongsToExpr clause.Expression
+        // we have to record the old BelongsTo value
+        if association.Unscope && rel.Type == schema.BelongsTo {
+            var foreignFields []*schema.Field
+            for _, ref := range rel.References {
+                if !ref.OwnPrimaryKey {
+                    foreignFields = append(foreignFields, ref.ForeignKey)
+                }
+            }
+            if _, fvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, foreignFields); len(fvs) > 0 {
+                column, values := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, fvs)
+                oldBelongsToExpr = clause.IN{Column: column, Values: values}
+            }
+        }
+
         // save associations
         if association.saveAssociation( /*clear*/ true, values...); association.Error != nil {
             return association.Error
         }

         // set old associations's foreign key to null
-        reflectValue := association.DB.Statement.ReflectValue
-        rel := association.Relationship
         switch rel.Type {
         case schema.BelongsTo:
             if len(values) == 0 {

@@ -91,6 +117,9 @@ func (association *Association) Replace(values ...interface{}) error {

                 association.Error = association.DB.UpdateColumns(updateMap).Error
             }
+            if association.Unscope && oldBelongsToExpr != nil {
+                association.Error = association.DB.Model(nil).Where(oldBelongsToExpr).Delete(reflect.New(rel.FieldSchema.ModelType).Interface()).Error
+            }
         case schema.HasOne, schema.HasMany:
             var (
                 primaryFields []*schema.Field

@@ -119,7 +148,11 @@ func (association *Association) Replace(values ...interface{}) error {

             if _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields); len(pvs) > 0 {
                 column, values := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs)
-                association.Error = tx.Where(clause.IN{Column: column, Values: values}).UpdateColumns(updateMap).Error
+                if association.Unscope {
+                    association.Error = tx.Where(clause.IN{Column: column, Values: values}).Delete(modelValue).Error
+                } else {
+                    association.Error = tx.Where(clause.IN{Column: column, Values: values}).UpdateColumns(updateMap).Error
+                }
             }
         case schema.Many2Many:
             var (

@@ -184,7 +217,8 @@ func (association *Association) Delete(values ...interface{}) error {

         switch rel.Type {
         case schema.BelongsTo:
-            tx := association.DB.Model(reflect.New(rel.Schema.ModelType).Interface())
+            associationDB := association.DB.Session(&Session{})
+            tx := associationDB.Model(reflect.New(rel.Schema.ModelType).Interface())

             _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, rel.Schema.PrimaryFields)
             if pcolumn, pvalues := schema.ToQueryValues(rel.Schema.Table, rel.Schema.PrimaryFieldDBNames, pvs); len(pvalues) > 0 {

@@ -198,8 +232,21 @@ func (association *Association) Delete(values ...interface{}) error {
             conds = append(conds, clause.IN{Column: relColumn, Values: relValues})

             association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error
+            if association.Unscope {
+                var foreignFields []*schema.Field
+                for _, ref := range rel.References {
+                    if !ref.OwnPrimaryKey {
+                        foreignFields = append(foreignFields, ref.ForeignKey)
+                    }
+                }
+                if _, fvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, foreignFields); len(fvs) > 0 {
+                    column, values := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, fvs)
+                    association.Error = associationDB.Model(nil).Where(clause.IN{Column: column, Values: values}).Delete(reflect.New(rel.FieldSchema.ModelType).Interface()).Error
+                }
+            }
         case schema.HasOne, schema.HasMany:
-            tx := association.DB.Model(reflect.New(rel.FieldSchema.ModelType).Interface())
+            model := reflect.New(rel.FieldSchema.ModelType).Interface()
+            tx := association.DB.Model(model)

             _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields)
             if pcolumn, pvalues := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs); len(pvalues) > 0 {

@@ -212,7 +259,11 @@ func (association *Association) Delete(values ...interface{}) error {
             relColumn, relValues := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs)
             conds = append(conds, clause.IN{Column: relColumn, Values: relValues})

-            association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error
+            if association.Unscope {
+                association.Error = tx.Clauses(conds...).Delete(model).Error
+            } else {
+                association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error
+            }
         case schema.Many2Many:
             var (
                 primaryFields, relPrimaryFields []*schema.Field

@@ -353,9 +404,13 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
         }
     case schema.HasMany, schema.Many2Many:
         elemType := association.Relationship.Field.IndirectFieldType.Elem()
-        fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(association.DB.Statement.Context, source))
+        oldFieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(association.DB.Statement.Context, source))
+        var fieldValue reflect.Value
         if clear {
-            fieldValue = reflect.New(association.Relationship.Field.IndirectFieldType).Elem()
+            fieldValue = reflect.MakeSlice(oldFieldValue.Type(), 0, oldFieldValue.Cap())
+        } else {
+            fieldValue = reflect.MakeSlice(oldFieldValue.Type(), oldFieldValue.Len(), oldFieldValue.Cap())
+            reflect.Copy(fieldValue, oldFieldValue)
         }

         appendToFieldValues := func(ev reflect.Value) {

@@ -507,7 +562,9 @@ func (association *Association) buildCondition() *DB {
             joinStmt.AddClause(queryClause)
         }
         joinStmt.Build("WHERE")
-        tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars})
+        if len(joinStmt.SQL.String()) > 0 {
+            tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars})
+        }
     }

     tx = tx.Session(&Session{QueryFields: true}).Clauses(clause.From{Joins: []clause.Join{{
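Reviewer note: the new Association.Unscoped() mode changes what Replace and Delete do with detached rows — with Unscope set, the old associated records are deleted instead of having their foreign keys set to NULL (or, for belongs-to, the previously referenced row is removed). A minimal usage sketch; the User/Pet models and field names are illustrative assumptions, not part of this commit:

	// Assumed models: type User struct { ID uint; Pets []Pet }
	//                 type Pet struct { ID uint; UserID uint; Name string }
	var user User
	db.First(&user)

	// Default behaviour: removed pets keep their rows, UserID is nulled out.
	db.Model(&user).Association("Pets").Delete(&oldPet)

	// With Unscoped(): replaced/removed pets are deleted from the pets table.
	db.Model(&user).Association("Pets").Unscoped().Replace(&newPet)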
callbacks.go (45 changed lines)
@@ -75,11 +75,7 @@ func (cs *callbacks) Raw() *processor {
 func (p *processor) Execute(db *DB) *DB {
     // call scopes
     for len(db.Statement.scopes) > 0 {
-        scopes := db.Statement.scopes
-        db.Statement.scopes = nil
-        for _, scope := range scopes {
-            db = scope(db)
-        }
+        db = db.executeScopes()
     }

     var (

@@ -93,6 +89,10 @@ func (p *processor) Execute(db *DB) *DB {
         resetBuildClauses = true
     }

+    if optimizer, ok := db.Statement.Dest.(StatementModifier); ok {
+        optimizer.ModifyStatement(stmt)
+    }
+
     // assign model values
     if stmt.Model == nil {
         stmt.Model = stmt.Dest

@@ -132,7 +132,11 @@ func (p *processor) Execute(db *DB) *DB {

     if stmt.SQL.Len() > 0 {
         db.Logger.Trace(stmt.Context, curTime, func() (string, int64) {
-            return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...), db.RowsAffected
+            sql, vars := stmt.SQL.String(), stmt.Vars
+            if filter, ok := db.Logger.(ParamsFilter); ok {
+                sql, vars = filter.ParamsFilter(stmt.Context, stmt.SQL.String(), stmt.Vars...)
+            }
+            return db.Dialector.Explain(sql, vars...), db.RowsAffected
         }, db.Error)
     }

@@ -183,10 +187,18 @@ func (p *processor) Replace(name string, fn func(*DB)) error {

 func (p *processor) compile() (err error) {
     var callbacks []*callback
+    removedMap := map[string]bool{}
     for _, callback := range p.callbacks {
         if callback.match == nil || callback.match(p.db) {
             callbacks = append(callbacks, callback)
         }
+        if callback.remove {
+            removedMap[callback.name] = true
+        }
+    }
+
+    if len(removedMap) > 0 {
+        callbacks = removeCallbacks(callbacks, removedMap)
     }
     p.callbacks = callbacks

@@ -245,8 +257,14 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {
         names, sorted []string
         sortCallback  func(*callback) error
     )
-    sort.Slice(cs, func(i, j int) bool {
-        return cs[j].before == "*" || cs[j].after == "*"
+    sort.SliceStable(cs, func(i, j int) bool {
+        if cs[j].before == "*" && cs[i].before != "*" {
+            return true
+        }
+        if cs[j].after == "*" && cs[i].after != "*" {
+            return true
+        }
+        return false
     })

     for _, c := range cs {

@@ -329,3 +347,14 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {

     return
 }
+
+func removeCallbacks(cs []*callback, nameMap map[string]bool) []*callback {
+    callbacks := make([]*callback, 0, len(cs))
+    for _, callback := range cs {
+        if nameMap[callback.name] {
+            continue
+        }
+        callbacks = append(callbacks, callback)
+    }
+    return callbacks
+}
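Reviewer note: Execute now consults two optional interfaces — a Dest implementing StatementModifier can adjust the statement before the callbacks run, and a logger implementing ParamsFilter can rewrite the SQL and bind variables before they reach the trace log. A rough sketch of a parameter-redacting logger, inferring the method signature from the call site above (wrapping logger.Default and this exact wiring are assumptions, not part of the diff):

	import (
		"context"

		"gorm.io/gorm"
		"gorm.io/gorm/logger"
	)

	// RedactingLogger hides bind parameters from trace output while keeping
	// the rest of the logger behaviour from the embedded logger.Interface.
	type RedactingLogger struct {
		logger.Interface
	}

	// ParamsFilter matches the interface checked in processor.Execute above:
	// it receives the raw SQL and vars and returns what should be logged.
	func (l RedactingLogger) ParamsFilter(ctx context.Context, sql string, params ...interface{}) (string, []interface{}) {
		// Drop the parameters entirely; the logged SQL keeps its placeholders.
		return sql, nil
	}

	// Usage sketch: gorm.Open(dialector, &gorm.Config{Logger: RedactingLogger{logger.Default}})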
@@ -51,25 +51,40 @@ func SaveBeforeAssociations(create bool) func(db *gorm.DB) {
             }

             elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
+            distinctElems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
+            identityMap := map[string]bool{}
             for i := 0; i < rValLen; i++ {
                 obj := db.Statement.ReflectValue.Index(i)
                 if reflect.Indirect(obj).Kind() != reflect.Struct {
                     break
                 }

                 if _, zero := rel.Field.ValueOf(db.Statement.Context, obj); !zero { // check belongs to relation value
                     rv := rel.Field.ReflectValueOf(db.Statement.Context, obj) // relation reflect value
+                    if !isPtr {
+                        rv = rv.Addr()
+                    }
                     objs = append(objs, obj)
-                    if isPtr {
-                        elems = reflect.Append(elems, rv)
-                    } else {
-                        elems = reflect.Append(elems, rv.Addr())
-                    }
+                    elems = reflect.Append(elems, rv)
+
+                    relPrimaryValues := make([]interface{}, 0, len(rel.FieldSchema.PrimaryFields))
+                    for _, pf := range rel.FieldSchema.PrimaryFields {
+                        if pfv, ok := pf.ValueOf(db.Statement.Context, rv); !ok {
+                            relPrimaryValues = append(relPrimaryValues, pfv)
+                        }
+                    }
+                    cacheKey := utils.ToStringKey(relPrimaryValues...)
+                    if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] {
+                        if cacheKey != "" { // has primary fields
+                            identityMap[cacheKey] = true
+                        }
+
+                        distinctElems = reflect.Append(distinctElems, rv)
+                    }
                 }
             }

             if elems.Len() > 0 {
-                if saveAssociations(db, rel, elems, selectColumns, restricted, nil) == nil {
+                if saveAssociations(db, rel, distinctElems, selectColumns, restricted, nil) == nil {
                     for i := 0; i < elems.Len(); i++ {
                         setupReferences(objs[i], elems.Index(i))
                     }

@@ -206,9 +221,12 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
                     }
                 }

-                cacheKey := utils.ToStringKey(relPrimaryValues)
+                cacheKey := utils.ToStringKey(relPrimaryValues...)
                 if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] {
-                    identityMap[cacheKey] = true
+                    if cacheKey != "" { // has primary fields
+                        identityMap[cacheKey] = true
+                    }
+
                     if isPtr {
                         elems = reflect.Append(elems, elem)
                     } else {

@@ -253,6 +271,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
                 fieldType = reflect.PtrTo(fieldType)
             }
             elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
+            distinctElems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
             joins := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(rel.JoinTable.ModelType)), 0, 10)
             objs := []reflect.Value{}

@@ -272,19 +291,34 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
                 joins = reflect.Append(joins, joinValue)
             }

+            identityMap := map[string]bool{}
             appendToElems := func(v reflect.Value) {
                 if _, zero := rel.Field.ValueOf(db.Statement.Context, v); !zero {
                     f := reflect.Indirect(rel.Field.ReflectValueOf(db.Statement.Context, v))

                     for i := 0; i < f.Len(); i++ {
                         elem := f.Index(i)
+                        if !isPtr {
+                            elem = elem.Addr()
+                        }
                         objs = append(objs, v)
-                        if isPtr {
-                            elems = reflect.Append(elems, elem)
-                        } else {
-                            elems = reflect.Append(elems, elem.Addr())
-                        }
+                        elems = reflect.Append(elems, elem)
+
+                        relPrimaryValues := make([]interface{}, 0, len(rel.FieldSchema.PrimaryFields))
+                        for _, pf := range rel.FieldSchema.PrimaryFields {
+                            if pfv, ok := pf.ValueOf(db.Statement.Context, elem); !ok {
+                                relPrimaryValues = append(relPrimaryValues, pfv)
+                            }
+                        }
+
+                        cacheKey := utils.ToStringKey(relPrimaryValues...)
+                        if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] {
+                            if cacheKey != "" { // has primary fields
+                                identityMap[cacheKey] = true
+                            }
+
+                            distinctElems = reflect.Append(distinctElems, elem)
+                        }
                     }
                 }
             }

@@ -304,7 +338,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
             // optimize elems of reflect value length
             if elemLen := elems.Len(); elemLen > 0 {
                 if v, ok := selectColumns[rel.Name+".*"]; !ok || v {
-                    saveAssociations(db, rel, elems, selectColumns, restricted, nil)
+                    saveAssociations(db, rel, distinctElems, selectColumns, restricted, nil)
                 }

                 for i := 0; i < elemLen; i++ {
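Reviewer note: both save paths now collect distinctElems keyed by the association's primary-key values (via utils.ToStringKey), so only one copy of a shared association is written while the original elems slice still drives the per-row reference setup. A small illustration of the case this targets; the Company/User models are assumptions, not part of the diff:

	// Both users reference the same Company{ID: 1}. With the identityMap the
	// company row is saved once, while both users still get CompanyID = 1.
	company := Company{ID: 1, Name: "acme"}
	users := []User{
		{Name: "a", Company: company},
		{Name: "b", Company: company},
	}
	db.Create(&users)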
@@ -13,11 +13,20 @@ func callMethod(db *gorm.DB, fc func(value interface{}, tx *gorm.DB) bool) {
         case reflect.Slice, reflect.Array:
             db.Statement.CurDestIndex = 0
             for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
-                fc(reflect.Indirect(db.Statement.ReflectValue.Index(i)).Addr().Interface(), tx)
+                if value := reflect.Indirect(db.Statement.ReflectValue.Index(i)); value.CanAddr() {
+                    fc(value.Addr().Interface(), tx)
+                } else {
+                    db.AddError(gorm.ErrInvalidValue)
+                    return
+                }
                 db.Statement.CurDestIndex++
             }
         case reflect.Struct:
-            fc(db.Statement.ReflectValue.Addr().Interface(), tx)
+            if db.Statement.ReflectValue.CanAddr() {
+                fc(db.Statement.ReflectValue.Addr().Interface(), tx)
+            } else {
+                db.AddError(gorm.ErrInvalidValue)
+            }
         }
     }
 }
@@ -3,6 +3,7 @@ package callbacks
 import (
     "fmt"
     "reflect"
+    "strings"

     "gorm.io/gorm"
     "gorm.io/gorm/clause"

@@ -102,13 +103,62 @@ func Create(config *Config) func(db *gorm.DB) {
             }

             db.RowsAffected, _ = result.RowsAffected()
-            if db.RowsAffected != 0 && db.Statement.Schema != nil &&
-                db.Statement.Schema.PrioritizedPrimaryField != nil &&
-                db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue {
-                insertID, err := result.LastInsertId()
-                insertOk := err == nil && insertID > 0
-                if !insertOk {
+            if db.RowsAffected == 0 {
+                return
+            }
+
+            var (
+                pkField     *schema.Field
+                pkFieldName = "@id"
+            )
+
+            insertID, err := result.LastInsertId()
+            insertOk := err == nil && insertID > 0
+
+            if !insertOk {
+                if !supportReturning {
                     db.AddError(err)
+                }
+                return
+            }
+
+            if db.Statement.Schema != nil {
+                if db.Statement.Schema.PrioritizedPrimaryField == nil || !db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue {
+                    return
+                }
+                pkField = db.Statement.Schema.PrioritizedPrimaryField
+                pkFieldName = db.Statement.Schema.PrioritizedPrimaryField.DBName
+            }
+
+            // append @id column with value for auto-increment primary key
+            // the @id value is correct, when: 1. without setting auto-increment primary key, 2. database AutoIncrementIncrement = 1
+            switch values := db.Statement.Dest.(type) {
+            case map[string]interface{}:
+                values[pkFieldName] = insertID
+            case *map[string]interface{}:
+                (*values)[pkFieldName] = insertID
+            case []map[string]interface{}, *[]map[string]interface{}:
+                mapValues, ok := values.([]map[string]interface{})
+                if !ok {
+                    if v, ok := values.(*[]map[string]interface{}); ok {
+                        if *v != nil {
+                            mapValues = *v
+                        }
+                    }
+                }
+
+                if config.LastInsertIDReversed {
+                    insertID -= int64(len(mapValues)-1) * schema.DefaultAutoIncrementIncrement
+                }
+
+                for _, mapValue := range mapValues {
+                    if mapValue != nil {
+                        mapValue[pkFieldName] = insertID
+                    }
+                    insertID += schema.DefaultAutoIncrementIncrement
+                }
+            default:
+                if pkField == nil {
                     return
                 }

@@ -121,10 +171,10 @@ func Create(config *Config) func(db *gorm.DB) {
                         break
                     }

-                    _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv)
+                    _, isZero := pkField.ValueOf(db.Statement.Context, rv)
                     if isZero {
-                        db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, insertID))
-                        insertID -= db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement
+                        db.AddError(pkField.Set(db.Statement.Context, rv, insertID))
+                        insertID -= pkField.AutoIncrementIncrement
                     }
                 }
             } else {

@@ -134,16 +184,16 @@ func Create(config *Config) func(db *gorm.DB) {
                         break
                     }

-                    if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv); isZero {
-                        db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, insertID))
-                        insertID += db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement
+                    if _, isZero := pkField.ValueOf(db.Statement.Context, rv); isZero {
+                        db.AddError(pkField.Set(db.Statement.Context, rv, insertID))
+                        insertID += pkField.AutoIncrementIncrement
                     }
                 }
             }
         case reflect.Struct:
-            _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, db.Statement.ReflectValue)
+            _, isZero := pkField.ValueOf(db.Statement.Context, db.Statement.ReflectValue)
             if isZero {
-                db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, db.Statement.ReflectValue, insertID))
+                db.AddError(pkField.Set(db.Statement.Context, db.Statement.ReflectValue, insertID))
             }
         }
     }

@@ -252,13 +302,15 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
         }
     }

-    for field, vs := range defaultValueFieldsHavingValue {
-        values.Columns = append(values.Columns, clause.Column{Name: field.DBName})
-        for idx := range values.Values {
-            if vs[idx] == nil {
-                values.Values[idx] = append(values.Values[idx], stmt.Dialector.DefaultValueOf(field))
-            } else {
-                values.Values[idx] = append(values.Values[idx], vs[idx])
+    for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
+        if vs, ok := defaultValueFieldsHavingValue[field]; ok {
+            values.Columns = append(values.Columns, clause.Column{Name: field.DBName})
+            for idx := range values.Values {
+                if vs[idx] == nil {
+                    values.Values[idx] = append(values.Values[idx], stmt.Dialector.DefaultValueOf(field))
+                } else {
+                    values.Values[idx] = append(values.Values[idx], vs[idx])
+                }
             }
         }
     }

@@ -302,14 +354,15 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
         for _, column := range values.Columns {
             if field := stmt.Schema.LookUpField(column.Name); field != nil {
                 if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
-                    if !field.PrimaryKey && (!field.HasDefaultValue || field.DefaultValueInterface != nil) && field.AutoCreateTime == 0 {
+                    if !field.PrimaryKey && (!field.HasDefaultValue || field.DefaultValueInterface != nil ||
+                        strings.EqualFold(field.DefaultValue, "NULL")) && field.AutoCreateTime == 0 {
                         if field.AutoUpdateTime > 0 {
                             assignment := clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: curTime}
                             switch field.AutoUpdateTime {
                             case schema.UnixNanosecond:
                                 assignment.Value = curTime.UnixNano()
                             case schema.UnixMillisecond:
-                                assignment.Value = curTime.UnixNano() / 1e6
+                                assignment.Value = curTime.UnixMilli()
                             case schema.UnixSecond:
                                 assignment.Value = curTime.Unix()
                             }
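Reviewer note: Create now writes the auto-increment ID back even when the destination is a map (or slice of maps) rather than a struct; the key is the prioritized primary field's DB name when a schema is available, otherwise the literal "@id". A rough sketch of the effect on dialects that report LastInsertId; the model, table, and column names are assumptions, not part of the diff:

	row := map[string]interface{}{"name": "jinzhu"}
	db.Model(&User{}).Create(row)
	// After the insert, row["id"] holds the generated primary key.

	// Without a parsed schema (e.g. db.Table("users").Create(row)),
	// the fallback key "@id" is populated instead.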
callbacks/create_test.go (new file, 71 lines)
@@ -0,0 +1,71 @@
+package callbacks
+
+import (
+    "reflect"
+    "sync"
+    "testing"
+    "time"
+
+    "gorm.io/gorm"
+    "gorm.io/gorm/clause"
+    "gorm.io/gorm/schema"
+)
+
+var schemaCache = &sync.Map{}
+
+func TestConvertToCreateValues_DestType_Slice(t *testing.T) {
+    type user struct {
+        ID    int `gorm:"primaryKey"`
+        Name  string
+        Email string `gorm:"default:(-)"`
+        Age   int    `gorm:"default:(-)"`
+    }
+
+    s, err := schema.Parse(&user{}, schemaCache, schema.NamingStrategy{})
+    if err != nil {
+        t.Errorf("parse schema error: %v, is not expected", err)
+        return
+    }
+    dest := []*user{
+        {
+            ID:    1,
+            Name:  "alice",
+            Email: "email",
+            Age:   18,
+        },
+        {
+            ID:    2,
+            Name:  "bob",
+            Email: "email",
+            Age:   19,
+        },
+    }
+    stmt := &gorm.Statement{
+        DB: &gorm.DB{
+            Config: &gorm.Config{
+                NowFunc: func() time.Time { return time.Time{} },
+            },
+            Statement: &gorm.Statement{
+                Settings: sync.Map{},
+                Schema:   s,
+            },
+        },
+        ReflectValue: reflect.ValueOf(dest),
+        Dest:         dest,
+    }
+
+    stmt.Schema = s
+
+    values := ConvertToCreateValues(stmt)
+    expected := clause.Values{
+        // column has value + defaultValue column has value (which should have a stable order)
+        Columns: []clause.Column{{Name: "name"}, {Name: "email"}, {Name: "age"}, {Name: "id"}},
+        Values: [][]interface{}{
+            {"alice", "email", 18, 1},
+            {"bob", "email", 19, 2},
+        },
+    }
+    if !reflect.DeepEqual(expected, values) {
+        t.Errorf("expected: %v got %v", expected, values)
+    }
+}
callbacks/helper_test.go (new file, 157 lines)
@@ -0,0 +1,157 @@
+package callbacks
+
+import (
+    "reflect"
+    "testing"
+
+    "gorm.io/gorm"
+    "gorm.io/gorm/clause"
+)
+
+func TestLoadOrStoreVisitMap(t *testing.T) {
+    var vm visitMap
+    var loaded bool
+    type testM struct {
+        Name string
+    }
+
+    t1 := testM{Name: "t1"}
+    t2 := testM{Name: "t2"}
+    t3 := testM{Name: "t3"}
+
+    vm = make(visitMap)
+    if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf(&t1)); loaded {
+        t.Fatalf("loaded should be false")
+    }
+
+    if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf(&t1)); !loaded {
+        t.Fatalf("loaded should be true")
+    }
+
+    // t1 already exist but t2 not
+    if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf([]*testM{&t1, &t2, &t3})); loaded {
+        t.Fatalf("loaded should be false")
+    }
+
+    if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf([]*testM{&t2, &t3})); !loaded {
+        t.Fatalf("loaded should be true")
+    }
+}
+
+func TestConvertMapToValuesForCreate(t *testing.T) {
+    testCase := []struct {
+        name   string
+        input  map[string]interface{}
+        expect clause.Values
+    }{
+        {
+            name: "Test convert string value",
+            input: map[string]interface{}{
+                "name": "my name",
+            },
+            expect: clause.Values{
+                Columns: []clause.Column{{Name: "name"}},
+                Values:  [][]interface{}{{"my name"}},
+            },
+        },
+        {
+            name: "Test convert int value",
+            input: map[string]interface{}{
+                "age": 18,
+            },
+            expect: clause.Values{
+                Columns: []clause.Column{{Name: "age"}},
+                Values:  [][]interface{}{{18}},
+            },
+        },
+        {
+            name: "Test convert float value",
+            input: map[string]interface{}{
+                "score": 99.5,
+            },
+            expect: clause.Values{
+                Columns: []clause.Column{{Name: "score"}},
+                Values:  [][]interface{}{{99.5}},
+            },
+        },
+        {
+            name: "Test convert bool value",
+            input: map[string]interface{}{
+                "active": true,
+            },
+            expect: clause.Values{
+                Columns: []clause.Column{{Name: "active"}},
+                Values:  [][]interface{}{{true}},
+            },
+        },
+    }
+
+    for _, tc := range testCase {
+        t.Run(tc.name, func(t *testing.T) {
+            actual := ConvertMapToValuesForCreate(&gorm.Statement{}, tc.input)
+            if !reflect.DeepEqual(actual, tc.expect) {
+                t.Errorf("expect %v got %v", tc.expect, actual)
+            }
+        })
+    }
+}
+
+func TestConvertSliceOfMapToValuesForCreate(t *testing.T) {
+    testCase := []struct {
+        name   string
+        input  []map[string]interface{}
+        expect clause.Values
+    }{
+        {
+            name: "Test convert slice of string value",
+            input: []map[string]interface{}{
+                {"name": "my name"},
+            },
+            expect: clause.Values{
+                Columns: []clause.Column{{Name: "name"}},
+                Values:  [][]interface{}{{"my name"}},
+            },
+        },
+        {
+            name: "Test convert slice of int value",
+            input: []map[string]interface{}{
+                {"age": 18},
+            },
+            expect: clause.Values{
+                Columns: []clause.Column{{Name: "age"}},
+                Values:  [][]interface{}{{18}},
+            },
+        },
+        {
+            name: "Test convert slice of float value",
+            input: []map[string]interface{}{
+                {"score": 99.5},
+            },
+            expect: clause.Values{
+                Columns: []clause.Column{{Name: "score"}},
+                Values:  [][]interface{}{{99.5}},
+            },
+        },
+        {
+            name: "Test convert slice of bool value",
+            input: []map[string]interface{}{
+                {"active": true},
+            },
+            expect: clause.Values{
+                Columns: []clause.Column{{Name: "active"}},
+                Values:  [][]interface{}{{true}},
+            },
+        },
+    }
+
+    for _, tc := range testCase {
+        t.Run(tc.name, func(t *testing.T) {
+            actual := ConvertSliceOfMapToValuesForCreate(&gorm.Statement{}, tc.input)
+
+            if !reflect.DeepEqual(actual, tc.expect) {
+                t.Errorf("expected %v but got %v", tc.expect, actual)
+            }
+        })
+    }
+}
@ -3,6 +3,8 @@ package callbacks
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"gorm.io/gorm/clause"
|
"gorm.io/gorm/clause"
|
||||||
@ -10,6 +12,164 @@ import (
|
|||||||
"gorm.io/gorm/utils"
|
"gorm.io/gorm/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// parsePreloadMap extracts nested preloads. e.g.
|
||||||
|
//
|
||||||
|
// // schema has a "k0" relation and a "k7.k8" embedded relation
|
||||||
|
// parsePreloadMap(schema, map[string][]interface{}{
|
||||||
|
// clause.Associations: {"arg1"},
|
||||||
|
// "k1": {"arg2"},
|
||||||
|
// "k2.k3": {"arg3"},
|
||||||
|
// "k4.k5.k6": {"arg4"},
|
||||||
|
// })
|
// // preloadMap is
// map[string]map[string][]interface{}{
// 	"k0": {},
// 	"k7": {
// 		"k8": {},
// 	},
// 	"k1": {},
// 	"k2": {
// 		"k3": {"arg3"},
// 	},
// 	"k4": {
// 		"k5.k6": {"arg4"},
// 	},
// }
func parsePreloadMap(s *schema.Schema, preloads map[string][]interface{}) map[string]map[string][]interface{} {
	preloadMap := map[string]map[string][]interface{}{}
	setPreloadMap := func(name, value string, args []interface{}) {
		if _, ok := preloadMap[name]; !ok {
			preloadMap[name] = map[string][]interface{}{}
		}
		if value != "" {
			preloadMap[name][value] = args
		}
	}

	for name, args := range preloads {
		preloadFields := strings.Split(name, ".")
		value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), ".")
		if preloadFields[0] == clause.Associations {
			for _, relation := range s.Relationships.Relations {
				if relation.Schema == s {
					setPreloadMap(relation.Name, value, args)
				}
			}

			for embedded, embeddedRelations := range s.Relationships.EmbeddedRelations {
				for _, value := range embeddedValues(embeddedRelations) {
					setPreloadMap(embedded, value, args)
				}
			}
		} else {
			setPreloadMap(preloadFields[0], value, args)
		}
	}
	return preloadMap
}

func embeddedValues(embeddedRelations *schema.Relationships) []string {
	if embeddedRelations == nil {
		return nil
	}
	names := make([]string, 0, len(embeddedRelations.Relations)+len(embeddedRelations.EmbeddedRelations))
	for _, relation := range embeddedRelations.Relations {
		// skip first struct name
		names = append(names, strings.Join(relation.Field.BindNames[1:], "."))
	}
	for _, relations := range embeddedRelations.EmbeddedRelations {
		names = append(names, embeddedValues(relations)...)
	}
	return names
}

// preloadEntryPoint enters layer by layer. It will call real preload if it finds the right entry point.
// If the current relationship is embedded or joined, current query will be ignored.
//
//nolint:cyclop
func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relationships, preloads map[string][]interface{}, associationsConds []interface{}) error {
	preloadMap := parsePreloadMap(db.Statement.Schema, preloads)

	// avoid random traversal of the map
	preloadNames := make([]string, 0, len(preloadMap))
	for key := range preloadMap {
		preloadNames = append(preloadNames, key)
	}
	sort.Strings(preloadNames)

	isJoined := func(name string) (joined bool, nestedJoins []string) {
		for _, join := range joins {
			if _, ok := relationships.Relations[join]; ok && name == join {
				joined = true
				continue
			}
			joinNames := strings.SplitN(join, ".", 2)
			if len(joinNames) == 2 {
				if _, ok := relationships.Relations[joinNames[0]]; ok && name == joinNames[0] {
					joined = true
					nestedJoins = append(nestedJoins, joinNames[1])
				}
			}
		}
		return joined, nestedJoins
	}

	for _, name := range preloadNames {
		if relations := relationships.EmbeddedRelations[name]; relations != nil {
			if err := preloadEntryPoint(db, joins, relations, preloadMap[name], associationsConds); err != nil {
				return err
			}
		} else if rel := relationships.Relations[name]; rel != nil {
			if joined, nestedJoins := isJoined(name); joined {
				switch rv := db.Statement.ReflectValue; rv.Kind() {
				case reflect.Slice, reflect.Array:
					for i := 0; i < rv.Len(); i++ {
						reflectValue := rel.Field.ReflectValueOf(db.Statement.Context, rv.Index(i))
						tx := preloadDB(db, reflectValue, reflectValue.Interface())
						if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil {
							return err
						}
					}
				case reflect.Struct:
					reflectValue := rel.Field.ReflectValueOf(db.Statement.Context, rv)
					tx := preloadDB(db, reflectValue, reflectValue.Interface())
					if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil {
						return err
					}
				default:
					return gorm.ErrInvalidData
				}
			} else {
				tx := db.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks})
				tx.Statement.ReflectValue = db.Statement.ReflectValue
				tx.Statement.Unscoped = db.Statement.Unscoped
				if err := preload(tx, rel, append(preloads[name], associationsConds...), preloadMap[name]); err != nil {
					return err
				}
			}
		} else {
			return fmt.Errorf("%s: %w for schema %s", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name)
		}
	}
	return nil
}

func preloadDB(db *gorm.DB, reflectValue reflect.Value, dest interface{}) *gorm.DB {
	tx := db.Session(&gorm.Session{Context: db.Statement.Context, NewDB: true, SkipHooks: db.Statement.SkipHooks, Initialized: true})
	db.Statement.Settings.Range(func(k, v interface{}) bool {
		tx.Statement.Settings.Store(k, v)
		return true
	})

	if err := tx.Statement.Parse(dest); err != nil {
		tx.AddError(err)
		return tx
	}
	tx.Statement.ReflectValue = reflectValue
	tx.Statement.Unscoped = db.Statement.Unscoped
	return tx
}
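The functions above drive the joined-load path added in this change. A minimal, hedged usage sketch (the User/Manager/Team/Orders models are illustrative, not part of this diff): a relation that is already loaded through Joins is handled by recursing into the joined value, while relations that are not joined fall back to the regular preload query.

	// Manager is fetched via an SQL join, so preloadEntryPoint recurses into each
	// loaded Manager value and only preloads Manager.Team there; Orders is not
	// joined, so it is preloaded with a separate query as before.
	db.Joins("Manager").
		Preload("Manager.Team").
		Preload("Orders").
		Find(&users)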
func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) error {
	var (
		reflectValue = tx.Statement.ReflectValue
@ -3,11 +3,12 @@ package callbacks

import (
	"fmt"
	"reflect"
	"strings"

	"gorm.io/gorm"
	"gorm.io/gorm/clause"
	"gorm.io/gorm/schema"
	"gorm.io/gorm/utils"
)

func Query(db *gorm.DB) {

@ -109,78 +110,141 @@ func BuildQuerySQL(db *gorm.DB) {
			}
		}
		specifiedRelationsName := make(map[string]interface{})
		for _, join := range db.Statement.Joins {
			if db.Statement.Schema != nil {
				var isRelations bool // is relations or raw sql
				var relations []*schema.Relationship
				relation, ok := db.Statement.Schema.Relationships.Relations[join.Name]
				if ok {
					isRelations = true
					relations = append(relations, relation)
				} else {
					// handle nested join like "Manager.Company"
					nestedJoinNames := strings.Split(join.Name, ".")
					if len(nestedJoinNames) > 1 {
						isNestedJoin := true
						gussNestedRelations := make([]*schema.Relationship, 0, len(nestedJoinNames))
						currentRelations := db.Statement.Schema.Relationships.Relations
						for _, relname := range nestedJoinNames {
							// incomplete match, only treated as raw sql
							if relation, ok = currentRelations[relname]; ok {
								gussNestedRelations = append(gussNestedRelations, relation)
								currentRelations = relation.FieldSchema.Relationships.Relations
							} else {
								isNestedJoin = false
								break
							}
						}

						if isNestedJoin {
							isRelations = true
							relations = gussNestedRelations
						}
					}
				}

				if isRelations {
					genJoinClause := func(joinType clause.JoinType, parentTableName string, relation *schema.Relationship) clause.Join {
						tableAliasName := relation.Name
						if parentTableName != clause.CurrentTable {
							tableAliasName = utils.NestedRelationName(parentTableName, tableAliasName)
						}

						columnStmt := gorm.Statement{
							Table: tableAliasName, DB: db, Schema: relation.FieldSchema,
							Selects: join.Selects, Omits: join.Omits,
						}

						selectColumns, restricted := columnStmt.SelectAndOmitColumns(false, false)
						for _, s := range relation.FieldSchema.DBNames {
							if v, ok := selectColumns[s]; (ok && v) || (!ok && !restricted) {
								clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{
									Table: tableAliasName,
									Name:  s,
									Alias: utils.NestedRelationName(tableAliasName, s),
								})
							}
						}

						exprs := make([]clause.Expression, len(relation.References))
						for idx, ref := range relation.References {
							if ref.OwnPrimaryKey {
								exprs[idx] = clause.Eq{
									Column: clause.Column{Table: parentTableName, Name: ref.PrimaryKey.DBName},
									Value:  clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
								}
							} else {
								if ref.PrimaryValue == "" {
									exprs[idx] = clause.Eq{
										Column: clause.Column{Table: parentTableName, Name: ref.ForeignKey.DBName},
										Value:  clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName},
									}
								} else {
									exprs[idx] = clause.Eq{
										Column: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
										Value:  ref.PrimaryValue,
									}
								}
							}
						}

						{
							onStmt := gorm.Statement{Table: tableAliasName, DB: db, Clauses: map[string]clause.Clause{}}
							for _, c := range relation.FieldSchema.QueryClauses {
								onStmt.AddClause(c)
							}

							if join.On != nil {
								onStmt.AddClause(join.On)
							}

							if cs, ok := onStmt.Clauses["WHERE"]; ok {
								if where, ok := cs.Expression.(clause.Where); ok {
									where.Build(&onStmt)

									if onSQL := onStmt.SQL.String(); onSQL != "" {
										vars := onStmt.Vars
										for idx, v := range vars {
											bindvar := strings.Builder{}
											onStmt.Vars = vars[0 : idx+1]
											db.Dialector.BindVarTo(&bindvar, &onStmt, v)
											onSQL = strings.Replace(onSQL, bindvar.String(), "?", 1)
										}

										exprs = append(exprs, clause.Expr{SQL: onSQL, Vars: vars})
									}
								}
							}
						}

						return clause.Join{
							Type:  joinType,
							Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName},
							ON:    clause.Where{Exprs: exprs},
						}
					}

					parentTableName := clause.CurrentTable
					for _, rel := range relations {
						// joins table alias like "Manager, Company, Manager__Company"
						nestedAlias := utils.NestedRelationName(parentTableName, rel.Name)
						if _, ok := specifiedRelationsName[nestedAlias]; !ok {
							fromClause.Joins = append(fromClause.Joins, genJoinClause(join.JoinType, parentTableName, rel))
							specifiedRelationsName[nestedAlias] = nil
						}

						if parentTableName != clause.CurrentTable {
							parentTableName = utils.NestedRelationName(parentTableName, rel.Name)
						} else {
							parentTableName = rel.Name
						}
					}
				} else {
					fromClause.Joins = append(fromClause.Joins, clause.Join{
						Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds},
					})
				}
			} else {
				fromClause.Joins = append(fromClause.Joins, clause.Join{
					Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds},
				})
			}
		}

@ -189,7 +253,6 @@ func BuildQuerySQL(db *gorm.DB) {
		}

		db.Statement.AddClause(fromClause)
	} else {
		db.Statement.AddClauseIfNotExists(clause.From{})
	}
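A hedged sketch of what the nested-join resolution above enables (the Manager and Company relations are assumptions for illustration): a join name such as "Manager.Company" is resolved relation by relation, each level receives an alias built with utils.NestedRelationName, and already-specified relations are skipped via specifiedRelationsName.

	// Roughly: LEFT JOIN managers `Manager` ON ... LEFT JOIN companies `Manager__Company` ON ...
	// with the company columns selected as `Manager__Company__<column>`.
	db.Joins("Manager.Company").Find(&users)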
@ -207,60 +270,23 @@ func Preload(db *gorm.DB) {
			return
		}

		joins := make([]string, 0, len(db.Statement.Joins))
		for _, join := range db.Statement.Joins {
			joins = append(joins, join.Name)
		}

		tx := preloadDB(db, db.Statement.ReflectValue, db.Statement.Dest)
		if tx.Error != nil {
			return
		}

		db.AddError(preloadEntryPoint(tx, joins, &tx.Statement.Schema.Relationships, db.Statement.Preloads, db.Statement.Preloads[clause.Associations]))
	}
}

func AfterQuery(db *gorm.DB) {
	// clear the joins after query because preload need it
	db.Statement.Joins = nil
	if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.AfterFind && db.RowsAffected > 0 {
		callMethod(db, func(value interface{}, tx *gorm.DB) bool {
			if i, ok := value.(AfterFindInterface); ok {
@ -7,7 +7,7 @@ import (
func RowQuery(db *gorm.DB) {
	if db.Error == nil {
		BuildQuerySQL(db)
		if db.DryRun || db.Error != nil {
			return
		}
@ -70,10 +70,13 @@ func Update(config *Config) func(db *gorm.DB) {
	if db.Statement.SQL.Len() == 0 {
		db.Statement.SQL.Grow(180)
		db.Statement.AddClauseIfNotExists(clause.Update{})
		if _, ok := db.Statement.Clauses["SET"]; !ok {
			if set := ConvertToAssignments(db.Statement); len(set) != 0 {
				defer delete(db.Statement.Clauses, "SET")
				db.Statement.AddClause(set)
			} else {
				return
			}
		}

		db.Statement.Build(db.Statement.BuildClauses...)

@ -135,7 +138,9 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
	case reflect.Slice, reflect.Array:
		assignValue = func(field *schema.Field, value interface{}) {
			for i := 0; i < stmt.ReflectValue.Len(); i++ {
				if stmt.ReflectValue.CanAddr() {
					field.Set(stmt.Context, stmt.ReflectValue.Index(i), value)
				}
			}
		}
	case reflect.Struct:
@ -158,21 +163,21 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
	switch stmt.ReflectValue.Kind() {
	case reflect.Slice, reflect.Array:
		if size := stmt.ReflectValue.Len(); size > 0 {
			var isZero bool
			for i := 0; i < size; i++ {
				for _, field := range stmt.Schema.PrimaryFields {
					_, isZero = field.ValueOf(stmt.Context, stmt.ReflectValue.Index(i))
					if !isZero {
						break
					}
				}
			}

			if !isZero {
				_, primaryValues := schema.GetIdentityFieldValuesMap(stmt.Context, stmt.ReflectValue, stmt.Schema.PrimaryFields)
				column, values := schema.ToQueryValues("", stmt.Schema.PrimaryFieldDBNames, primaryValues)
				stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
			}
		}
	case reflect.Struct:
		for _, field := range stmt.Schema.PrimaryFields {
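A hedged illustration of the batch-update condition built here (the User model and values are assumptions): when the elements of the slice carry non-zero primary keys, the statement now gets a single IN condition over the collected primary key values.

	users := []User{{ID: 1}, {ID: 2}, {ID: 3}}
	// Roughly: UPDATE `users` SET `role`='admin' WHERE `users`.`id` IN (1,2,3)
	db.Model(&users).Update("role", "admin")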
@ -229,7 +234,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
	if field.AutoUpdateTime == schema.UnixNanosecond {
		set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano()})
	} else if field.AutoUpdateTime == schema.UnixMillisecond {
		set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixMilli()})
	} else if field.AutoUpdateTime == schema.UnixSecond {
		set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.Unix()})
	} else {

@ -241,11 +246,13 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
		}
	default:
		updatingSchema := stmt.Schema
		var isDiffSchema bool
		if !updatingValue.CanAddr() || stmt.Dest != stmt.Model {
			// different schema
			updatingStmt := &gorm.Statement{DB: stmt.DB}
			if err := updatingStmt.Parse(stmt.Dest); err == nil {
				updatingSchema = updatingStmt.Schema
				isDiffSchema = true
			}
		}

@ -261,7 +268,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
	if field.AutoUpdateTime == schema.UnixNanosecond {
		value = stmt.DB.NowFunc().UnixNano()
	} else if field.AutoUpdateTime == schema.UnixMillisecond {
		value = stmt.DB.NowFunc().UnixMilli()
	} else if field.AutoUpdateTime == schema.UnixSecond {
		value = stmt.DB.NowFunc().Unix()
	} else {

@ -272,7 +279,13 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {

	if (ok || !isZero) && field.Updatable {
		set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value})
		assignField := field
		if isDiffSchema {
			if originField := stmt.Schema.LookUpField(dbName); originField != nil {
				assignField = originField
			}
		}
		assignValue(assignField, value)
	}
	}
} else {
@ -1,36 +0,0 @@
package callbacks

import (
	"reflect"
	"testing"
)

func TestLoadOrStoreVisitMap(t *testing.T) {
	var vm visitMap
	var loaded bool
	type testM struct {
		Name string
	}

	t1 := testM{Name: "t1"}
	t2 := testM{Name: "t2"}
	t3 := testM{Name: "t3"}

	vm = make(visitMap)
	if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf(&t1)); loaded {
		t.Fatalf("loaded should be false")
	}

	if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf(&t1)); !loaded {
		t.Fatalf("loaded should be true")
	}

	// t1 already exist but t2 not
	if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf([]*testM{&t1, &t2, &t3})); loaded {
		t.Fatalf("loaded should be false")
	}

	if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf([]*testM{&t2, &t3})); !loaded {
		t.Fatalf("loaded should be true")
	}
}
189	chainable_api.go
@ -10,10 +10,11 @@ import (
)

// Model specify the model you would like to run db operations
//
//	// update all users's name to `hello`
//	db.Model(&User{}).Update("name", "hello")
//	// if user's primary key is non-blank, will use it as condition, then will only update that user's name to `hello`
//	db.Model(&user).Update("name", "hello")
func (db *DB) Model(value interface{}) (tx *DB) {
	tx = db.getInstance()
	tx.Statement.Model = value

@ -21,6 +22,19 @@ func (db *DB) Model(value interface{}) (tx *DB) {
}

// Clauses Add clauses
//
// This supports both standard clauses (clause.OrderBy, clause.Limit, clause.Where) and more
// advanced techniques like specifying lock strength and optimizer hints. See the
// [docs] for more depth.
//
//	// add a simple limit clause
//	db.Clauses(clause.Limit{Limit: 1}).Find(&User{})
//	// tell the optimizer to use the `idx_user_name` index
//	db.Clauses(hints.UseIndex("idx_user_name")).Find(&User{})
//	// specify the lock strength to UPDATE
//	db.Clauses(clause.Locking{Strength: "UPDATE"}).Find(&users)
//
// [docs]: https://gorm.io/docs/sql_builder.html#Clauses
func (db *DB) Clauses(conds ...clause.Expression) (tx *DB) {
	tx = db.getInstance()
	var whereConds []interface{}
@ -41,15 +55,22 @@ func (db *DB) Clauses(conds ...clause.Expression) (tx *DB) {
	return
}

var tableRegexp = regexp.MustCompile(`(?i)(?:.+? AS (\w+)\s*(?:$|,)|^\w+\s+(\w+)$)`)

// Table specify the table you would like to run db operations
//
//	// Get a user
//	db.Table("users").Take(&result)
func (db *DB) Table(name string, args ...interface{}) (tx *DB) {
	tx = db.getInstance()
	if strings.Contains(name, " ") || strings.Contains(name, "`") || len(args) > 0 {
		tx.Statement.TableExpr = &clause.Expr{SQL: name, Vars: args}
		if results := tableRegexp.FindStringSubmatch(name); len(results) == 3 {
			if results[1] != "" {
				tx.Statement.Table = results[1]
			} else {
				tx.Statement.Table = results[2]
			}
		}
	} else if tables := strings.Split(name, "."); len(tables) == 2 {
		tx.Statement.TableExpr = &clause.Expr{SQL: tx.Statement.Quote(name)}
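The widened tableRegexp also recognizes a bare alias written without the AS keyword; a small hedged sketch (table and alias names are illustrative):

	db.Table("users AS u").Take(&result) // Statement.Table = "u", matched by the first capture group
	db.Table("users u").Take(&result)    // Statement.Table = "u", matched by the new second capture group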
@ -65,6 +86,11 @@ func (db *DB) Table(name string, args ...interface{}) (tx *DB) {
}

// Distinct specify distinct fields that you want querying
//
//	// Select distinct names of users
//	db.Distinct("name").Find(&results)
//	// Select distinct name/age pairs from users
//	db.Distinct("name", "age").Find(&results)
func (db *DB) Distinct(args ...interface{}) (tx *DB) {
	tx = db.getInstance()
	tx.Statement.Distinct = true

@ -75,6 +101,14 @@ func (db *DB) Distinct(args ...interface{}) (tx *DB) {
}

// Select specify fields that you want when querying, creating, updating
//
// Use Select when you only want a subset of the fields. By default, GORM will select all fields.
// Select accepts both string arguments and arrays.
//
//	// Select name and age of user using multiple arguments
//	db.Select("name", "age").Find(&users)
//	// Select name and age of user using an array
//	db.Select([]string{"name", "age"}).Find(&users)
func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) {
	tx = db.getInstance()
@ -152,6 +186,17 @@ func (db *DB) Omit(columns ...string) (tx *DB) {
}

// Where add conditions
//
// See the [docs] for details on the various formats that where clauses can take. By default, where clauses chain with AND.
//
//	// Find the first user with name jinzhu
//	db.Where("name = ?", "jinzhu").First(&user)
//	// Find the first user with name jinzhu and age 20
//	db.Where(&User{Name: "jinzhu", Age: 20}).First(&user)
//	// Find the first user with name jinzhu and age not equal to 20
//	db.Where("name = ?", "jinzhu").Where("age <> ?", "20").First(&user)
//
// [docs]: https://gorm.io/docs/query.html#Conditions
func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) {
	tx = db.getInstance()
	if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 {

@ -161,6 +206,11 @@ func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) {
}

// Not add NOT conditions
//
// Not works similarly to where, and has the same syntax.
//
//	// Find the first user with name not equal to jinzhu
//	db.Not("name = ?", "jinzhu").First(&user)
func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) {
	tx = db.getInstance()
	if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 {

@ -170,6 +220,11 @@ func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) {
}

// Or add OR conditions
//
// Or is used to chain together queries with an OR.
//
//	// Find the first user with name equal to jinzhu or john
//	db.Where("name = ?", "jinzhu").Or("name = ?", "john").First(&user)
func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) {
	tx = db.getInstance()
	if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 {
@ -179,26 +234,45 @@ func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) {
}

// Joins specify Joins conditions
//
//	db.Joins("Account").Find(&user)
//	db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user)
//	db.Joins("Account", DB.Select("id").Where("user_id = users.id AND name = ?", "someName").Model(&Account{}))
func (db *DB) Joins(query string, args ...interface{}) (tx *DB) {
	return joins(db, clause.LeftJoin, query, args...)
}

// InnerJoins specify inner joins conditions
//	db.InnerJoins("Account").Find(&user)
func (db *DB) InnerJoins(query string, args ...interface{}) (tx *DB) {
	return joins(db, clause.InnerJoin, query, args...)
}

func joins(db *DB, joinType clause.JoinType, query string, args ...interface{}) (tx *DB) {
	tx = db.getInstance()

	if len(args) == 1 {
		if db, ok := args[0].(*DB); ok {
			j := join{
				Name: query, Conds: args, Selects: db.Statement.Selects,
				Omits: db.Statement.Omits, JoinType: joinType,
			}
			if where, ok := db.Statement.Clauses["WHERE"].Expression.(clause.Where); ok {
				j.On = &where
			}
			tx.Statement.Joins = append(tx.Statement.Joins, j)
			return
		}
	}

	tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args, JoinType: joinType})
	return
}

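A short, hedged sketch of the split introduced by the joins helper (the Account relation is an assumption): Joins keeps its historical LEFT JOIN behaviour, while the new InnerJoins issues an INNER JOIN for the same relation.

	db.Joins("Account").Find(&users)      // ... LEFT JOIN `accounts` `Account` ON ...
	db.InnerJoins("Account").Find(&users) // ... INNER JOIN `accounts` `Account` ON ...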
// Group specify the group method on the find
//
//	// Select the sum age of users with given names
//	db.Model(&User{}).Select("name, sum(age) as total").Group("name").Find(&results)
func (db *DB) Group(name string) (tx *DB) {
	tx = db.getInstance()

@ -210,6 +284,9 @@ func (db *DB) Group(name string) (tx *DB) {
}

// Having specify HAVING conditions for GROUP BY
//
//	// Select the sum age of users with name jinzhu
//	db.Model(&User{}).Select("name, sum(age) as total").Group("name").Having("name = ?", "jinzhu").Find(&result)
func (db *DB) Having(query interface{}, args ...interface{}) (tx *DB) {
	tx = db.getInstance()
	tx.Statement.AddClause(clause.GroupBy{

@ -218,9 +295,10 @@ func (db *DB) Having(query interface{}, args ...interface{}) (tx *DB) {
	return
}

// Order specify order when retrieving records from database
//
//	db.Order("name DESC")
//	db.Order(clause.OrderByColumn{Column: clause.Column{Name: "name"}, Desc: true})
func (db *DB) Order(value interface{}) (tx *DB) {
	tx = db.getInstance()

@ -242,13 +320,27 @@ func (db *DB) Order(value interface{}) (tx *DB) {
}

// Limit specify the number of records to be retrieved
//
// Limit conditions can be cancelled by using `Limit(-1)`.
//
//	// retrieve 3 users
//	db.Limit(3).Find(&users)
//	// retrieve 3 users into users1, and all users into users2
//	db.Limit(3).Find(&users1).Limit(-1).Find(&users2)
func (db *DB) Limit(limit int) (tx *DB) {
	tx = db.getInstance()
	tx.Statement.AddClause(clause.Limit{Limit: &limit})
	return
}

// Offset specify the number of records to skip before starting to return the records
//
// Offset conditions can be cancelled by using `Offset(-1)`.
//
//	// select the third user
//	db.Offset(2).First(&user)
//	// select the first user by cancelling an earlier chained offset
//	db.Offset(5).Offset(-1).First(&user)
func (db *DB) Offset(offset int) (tx *DB) {
	tx = db.getInstance()
	tx.Statement.AddClause(clause.Limit{Offset: offset})

@ -256,25 +348,37 @@ func (db *DB) Offset(offset int) (tx *DB) {
}

// Scopes pass current database connection to arguments `func(DB) DB`, which could be used to add conditions dynamically
//
//	func AmountGreaterThan1000(db *gorm.DB) *gorm.DB {
//		return db.Where("amount > ?", 1000)
//	}
//
//	func OrderStatus(status []string) func (db *gorm.DB) *gorm.DB {
//		return func (db *gorm.DB) *gorm.DB {
//			return db.Scopes(AmountGreaterThan1000).Where("status in (?)", status)
//		}
//	}
//
//	db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders)
func (db *DB) Scopes(funcs ...func(*DB) *DB) (tx *DB) {
	tx = db.getInstance()
	tx.Statement.scopes = append(tx.Statement.scopes, funcs...)
	return tx
}

func (db *DB) executeScopes() (tx *DB) {
	scopes := db.Statement.scopes
	db.Statement.scopes = nil
	for _, scope := range scopes {
		db = scope(db)
	}
	return db
}

// Preload preload associations with given conditions
//
//	// get all users, and preload all non-cancelled orders
//	db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users)
func (db *DB) Preload(query string, args ...interface{}) (tx *DB) {
	tx = db.getInstance()
	if tx.Statement.Preloads == nil {

@ -284,12 +388,41 @@ func (db *DB) Preload(query string, args ...interface{}) (tx *DB) {
	return
}

// Attrs provide attributes used in [FirstOrCreate] or [FirstOrInit]
//
// Attrs only adds attributes if the record is not found.
//
//	// assign an email if the record is not found
//	db.Where(User{Name: "non_existing"}).Attrs(User{Email: "fake@fake.org"}).FirstOrInit(&user)
//	// user -> User{Name: "non_existing", Email: "fake@fake.org"}
//
//	// assign an email if the record is not found, otherwise ignore provided email
//	db.Where(User{Name: "jinzhu"}).Attrs(User{Email: "fake@fake.org"}).FirstOrInit(&user)
//	// user -> User{Name: "jinzhu", Age: 20}
//
// [FirstOrCreate]: https://gorm.io/docs/advanced_query.html#FirstOrCreate
// [FirstOrInit]: https://gorm.io/docs/advanced_query.html#FirstOrInit
func (db *DB) Attrs(attrs ...interface{}) (tx *DB) {
	tx = db.getInstance()
	tx.Statement.attrs = attrs
	return
}

// Assign provide attributes used in [FirstOrCreate] or [FirstOrInit]
//
// Assign adds attributes even if the record is found. If using FirstOrCreate, this means that
// records will be updated even if they are found.
//
//	// assign an email regardless of if the record is not found
//	db.Where(User{Name: "non_existing"}).Assign(User{Email: "fake@fake.org"}).FirstOrInit(&user)
//	// user -> User{Name: "non_existing", Email: "fake@fake.org"}
//
//	// assign email regardless of if record is found
//	db.Where(User{Name: "jinzhu"}).Assign(User{Email: "fake@fake.org"}).FirstOrInit(&user)
//	// user -> User{Name: "jinzhu", Age: 20, Email: "fake@fake.org"}
//
// [FirstOrCreate]: https://gorm.io/docs/advanced_query.html#FirstOrCreate
// [FirstOrInit]: https://gorm.io/docs/advanced_query.html#FirstOrInit
func (db *DB) Assign(attrs ...interface{}) (tx *DB) {
	tx = db.getInstance()
	tx.Statement.assigns = attrs
@ -29,6 +29,7 @@ func BenchmarkSelect(b *testing.B) {
func BenchmarkComplexSelect(b *testing.B) {
	user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)

	limit10 := 10
	for i := 0; i < b.N; i++ {
		stmt := gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}}
		clauses := []clause.Interface{

@ -43,7 +44,7 @@ func BenchmarkComplexSelect(b *testing.B) {
				clause.Or(clause.Gt{Column: "score", Value: 100}, clause.Like{Column: "name", Value: "%linus%"}),
			}},
			clause.GroupBy{Columns: []clause.Column{{Name: "role"}}, Having: []clause.Expression{clause.Eq{"role", "admin"}}},
			clause.Limit{Limit: &limit10, Offset: 20},
			clause.OrderBy{Columns: []clause.OrderByColumn{{Column: clause.PrimaryColumn, Desc: true}}},
		}
@ -20,6 +20,7 @@ type Builder interface {
	Writer
	WriteQuoted(field interface{})
	AddVar(Writer, ...interface{})
	AddError(error) error
}

// Clause
|
|||||||
for _, v := range []byte(expr.SQL) {
|
for _, v := range []byte(expr.SQL) {
|
||||||
if v == '@' && !inName {
|
if v == '@' && !inName {
|
||||||
inName = true
|
inName = true
|
||||||
name = []byte{}
|
name = name[:0]
|
||||||
} else if v == ' ' || v == ',' || v == ')' || v == '"' || v == '\'' || v == '`' || v == '\n' || v == ';' {
|
} else if v == ' ' || v == ',' || v == ')' || v == '"' || v == '\'' || v == '`' || v == '\r' || v == '\n' || v == ';' {
|
||||||
if inName {
|
if inName {
|
||||||
if nv, ok := namedMap[string(name)]; ok {
|
if nv, ok := namedMap[string(name)]; ok {
|
||||||
builder.AddVar(builder, nv)
|
builder.AddVar(builder, nv)
|
||||||
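For context, a hedged example of the named-argument syntax this builder parses; with this change a carriage return now terminates a name just like a newline (models and values are illustrative, sql.Named comes from database/sql):

	db.Where("name1 = @name AND name2 = @name", sql.Named("name", "jinzhu")).Find(&users)
	db.Where("name = @name", map[string]interface{}{"name": "jinzhu"}).First(&user)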
@ -246,15 +246,19 @@ func (eq Eq) Build(builder Builder) {

	switch eq.Value.(type) {
	case []string, []int, []int32, []int64, []uint, []uint32, []uint64, []interface{}:
		rv := reflect.ValueOf(eq.Value)
		if rv.Len() == 0 {
			builder.WriteString(" IN (NULL)")
		} else {
			builder.WriteString(" IN (")
			for i := 0; i < rv.Len(); i++ {
				if i > 0 {
					builder.WriteByte(',')
				}
				builder.AddVar(builder, rv.Index(i).Interface())
			}
			builder.WriteByte(')')
		}
	default:
		if eqNil(eq.Value) {
			builder.WriteString(" IS NULL")
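A small sketch of the new empty-slice behaviour (the column is illustrative): an empty IN list would be invalid SQL, so the builder now emits IN (NULL), which matches no rows.

	// Builds: ... WHERE `id` IN (NULL)
	db.Where(clause.Eq{Column: "id", Value: []int{}}).Find(&users)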
@ -94,6 +94,16 @@ func TestNamedExpr(t *testing.T) {
		Vars:         []interface{}{sql.Named("name", "jinzhu")},
		Result:       "name1 = ? AND name2 = ?;",
		ExpectedVars: []interface{}{"jinzhu", "jinzhu"},
	}, {
		SQL:          "name1 = @name1\r\n AND name2 = @name2",
		Vars:         []interface{}{map[string]interface{}{"name1": "jinzhu", "name2": "jinzhu"}},
		Result:       "name1 = ?\r\n AND name2 = ?",
		ExpectedVars: []interface{}{"jinzhu", "jinzhu"},
	}, {
		SQL:          "name1 = @name1\r AND name2 = @name2",
		Vars:         []interface{}{map[string]interface{}{"name1": "jinzhu", "name2": "jinzhu"}},
		Result:       "name1 = ?\r AND name2 = ?",
		ExpectedVars: []interface{}{"jinzhu", "jinzhu"},
	}, {
		SQL:  "?",
		Vars: []interface{}{clause.Column{Table: "table", Name: "col"}},

@ -189,6 +199,11 @@ func TestExpression(t *testing.T) {
		},
		ExpectedVars: []interface{}{"a", "b"},
		Result:       "`column-name` NOT IN (?,?)",
	}, {
		Expressions: []clause.Expression{
			clause.Eq{Column: column, Value: []string{}},
		},
		Result: "`column-name` IN (NULL)",
	}, {
		Expressions: []clause.Expression{
			clause.Eq{Column: clause.Expr{SQL: "SUM(?)", Vars: []interface{}{clause.Column{Name: "id"}}}, Value: 100},
@ -9,7 +9,7 @@ const (
	RightJoin JoinType = "RIGHT"
)

// Join clause for from
type Join struct {
	Type  JoinType
	Table Table
101	clause/joins_test.go	Normal file
@ -0,0 +1,101 @@
package clause_test

import (
	"sync"
	"testing"

	"gorm.io/gorm"
	"gorm.io/gorm/clause"
	"gorm.io/gorm/schema"
	"gorm.io/gorm/utils/tests"
)

func TestJoin(t *testing.T) {
	results := []struct {
		name string
		join clause.Join
		sql  string
	}{
		{
			name: "LEFT JOIN",
			join: clause.Join{
				Type:  clause.LeftJoin,
				Table: clause.Table{Name: "user"},
				ON: clause.Where{
					Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}},
				},
			},
			sql: "LEFT JOIN `user` ON `user_info`.`user_id` = `users`.`id`",
		},
		{
			name: "RIGHT JOIN",
			join: clause.Join{
				Type:  clause.RightJoin,
				Table: clause.Table{Name: "user"},
				ON: clause.Where{
					Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}},
				},
			},
			sql: "RIGHT JOIN `user` ON `user_info`.`user_id` = `users`.`id`",
		},
		{
			name: "INNER JOIN",
			join: clause.Join{
				Type:  clause.InnerJoin,
				Table: clause.Table{Name: "user"},
				ON: clause.Where{
					Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}},
				},
			},
			sql: "INNER JOIN `user` ON `user_info`.`user_id` = `users`.`id`",
		},
		{
			name: "CROSS JOIN",
			join: clause.Join{
				Type:  clause.CrossJoin,
				Table: clause.Table{Name: "user"},
				ON: clause.Where{
					Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}},
				},
			},
			sql: "CROSS JOIN `user` ON `user_info`.`user_id` = `users`.`id`",
		},
		{
			name: "USING",
			join: clause.Join{
				Type:  clause.InnerJoin,
				Table: clause.Table{Name: "user"},
				Using: []string{"id"},
			},
			sql: "INNER JOIN `user` USING (`id`)",
		},
		{
			name: "Expression",
			join: clause.Join{
				// Invalid
				Type:  clause.LeftJoin,
				Table: clause.Table{Name: "user"},
				ON: clause.Where{
					Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}},
				},
				// Valid
				Expression: clause.Join{
					Type:  clause.InnerJoin,
					Table: clause.Table{Name: "user"},
					Using: []string{"id"},
				},
			},
			sql: "INNER JOIN `user` USING (`id`)",
		},
	}
	for _, result := range results {
		t.Run(result.name, func(t *testing.T) {
			user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
			stmt := &gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}}
			result.join.Build(stmt)
			if result.sql != stmt.SQL.String() {
				t.Errorf("want: %s, got: %s", result.sql, stmt.SQL.String())
			}
		})
	}
}
@ -1,10 +1,8 @@
package clause

// Limit limit clause
type Limit struct {
	Limit  *int
	Offset int
}

@ -15,16 +13,16 @@ func (limit Limit) Name() string {

// Build build where clause
func (limit Limit) Build(builder Builder) {
	if limit.Limit != nil && *limit.Limit >= 0 {
		builder.WriteString("LIMIT ")
		builder.AddVar(builder, *limit.Limit)
	}
	if limit.Offset > 0 {
		if limit.Limit != nil && *limit.Limit >= 0 {
			builder.WriteByte(' ')
		}
		builder.WriteString("OFFSET ")
		builder.AddVar(builder, limit.Offset)
	}
}

@ -33,7 +31,7 @@ func (limit Limit) MergeClause(clause *Clause) {
	clause.Name = ""

	if v, ok := clause.Expression.(Limit); ok {
		if (limit.Limit == nil || *limit.Limit == 0) && v.Limit != nil {
			limit.Limit = v.Limit
		}
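Because Limit is now a pointer, a nil value means "no limit" while a pointer to 0 really emits LIMIT 0; a hedged sketch of the resulting behaviour (variable names are illustrative):

	limit := 0
	db.Clauses(clause.Limit{Limit: &limit}).Find(&users) // ... LIMIT 0
	db.Limit(10).Find(&users1).Limit(-1).Find(&users2)   // Limit(-1) cancels the earlier limit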
@ -8,6 +8,10 @@ import (
)

func TestLimit(t *testing.T) {
	limit0 := 0
	limit10 := 10
	limit50 := 50
	limitNeg10 := -10
	results := []struct {
		Clauses []clause.Interface
		Result  string

@ -15,38 +19,56 @@ func TestLimit(t *testing.T) {
	}{
		{
			[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{
				Limit:  &limit10,
				Offset: 20,
			}},
			"SELECT * FROM `users` LIMIT ? OFFSET ?",
			[]interface{}{limit10, 20},
		},
		{
			[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit0}},
			"SELECT * FROM `users` LIMIT ?",
			[]interface{}{limit0},
		},
		{
			[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit0}, clause.Limit{Offset: 0}},
			"SELECT * FROM `users` LIMIT ?",
			[]interface{}{limit0},
		},
		{
			[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}},
			"SELECT * FROM `users` OFFSET ?",
			[]interface{}{20},
		},
		{
			[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}, clause.Limit{Offset: 30}},
			"SELECT * FROM `users` OFFSET ?",
			[]interface{}{30},
		},
		{
			[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}, clause.Limit{Limit: &limit10}},
			"SELECT * FROM `users` LIMIT ? OFFSET ?",
			[]interface{}{limit10, 20},
		},
		{
			[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}},
			"SELECT * FROM `users` LIMIT ? OFFSET ?",
			[]interface{}{limit10, 30},
		},
		{
			[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Offset: -10}},
			"SELECT * FROM `users` LIMIT ?",
			[]interface{}{limit10},
		},
		{
			[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: &limitNeg10}},
			"SELECT * FROM `users` OFFSET ?",
			[]interface{}{30},
		},
		{
			[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: &limit50}},
			"SELECT * FROM `users` LIMIT ? OFFSET ?",
			[]interface{}{limit50, 30},
		},
	}
@@ -1,5 +1,12 @@
 package clause
 
+const (
+	LockingStrengthUpdate    = "UPDATE"
+	LockingStrengthShare     = "SHARE"
+	LockingOptionsSkipLocked = "SKIP LOCKED"
+	LockingOptionsNoWait     = "NOWAIT"
+)
+
 type Locking struct {
 	Strength string
 	Table    Table
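The exported locking constants replace the raw strings used in the tests below. A usage sketch (db and users are assumed to be defined elsewhere):

// SELECT * FROM `users` FOR UPDATE SKIP LOCKED
db.Clauses(clause.Locking{
	Strength: clause.LockingStrengthUpdate,
	Options:  clause.LockingOptionsSkipLocked,
}).Find(&users)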
@@ -14,17 +14,21 @@ func TestLocking(t *testing.T) {
 		Vars []interface{}
 	}{
 		{
-			[]clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: "UPDATE"}},
+			[]clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: clause.LockingStrengthUpdate}},
 			"SELECT * FROM `users` FOR UPDATE", nil,
 		},
 		{
-			[]clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: "SHARE", Table: clause.Table{Name: clause.CurrentTable}}},
+			[]clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: clause.LockingStrengthShare, Table: clause.Table{Name: clause.CurrentTable}}},
 			"SELECT * FROM `users` FOR SHARE OF `users`", nil,
 		},
 		{
-			[]clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: "UPDATE"}, clause.Locking{Strength: "UPDATE", Options: "NOWAIT"}},
+			[]clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: clause.LockingStrengthUpdate, Options: clause.LockingOptionsNoWait}},
 			"SELECT * FROM `users` FOR UPDATE NOWAIT", nil,
 		},
+		{
+			[]clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: clause.LockingStrengthUpdate, Options: clause.LockingOptionsSkipLocked}},
+			"SELECT * FROM `users` FOR UPDATE SKIP LOCKED", nil,
+		},
 	}
 
 	for idx, result := range results {
@@ -16,27 +16,27 @@ func (OnConflict) Name() string {
 
 // Build build onConflict clause
 func (onConflict OnConflict) Build(builder Builder) {
-	if len(onConflict.Columns) > 0 {
-		builder.WriteByte('(')
-		for idx, column := range onConflict.Columns {
-			if idx > 0 {
-				builder.WriteByte(',')
-			}
-			builder.WriteQuoted(column)
-		}
-		builder.WriteString(`) `)
-	}
-
-	if len(onConflict.TargetWhere.Exprs) > 0 {
-		builder.WriteString(" WHERE ")
-		onConflict.TargetWhere.Build(builder)
-		builder.WriteByte(' ')
-	}
-
 	if onConflict.OnConstraint != "" {
 		builder.WriteString("ON CONSTRAINT ")
 		builder.WriteString(onConflict.OnConstraint)
 		builder.WriteByte(' ')
+	} else {
+		if len(onConflict.Columns) > 0 {
+			builder.WriteByte('(')
+			for idx, column := range onConflict.Columns {
+				if idx > 0 {
+					builder.WriteByte(',')
+				}
+				builder.WriteQuoted(column)
+			}
+			builder.WriteString(`) `)
+		}
+
+		if len(onConflict.TargetWhere.Exprs) > 0 {
+			builder.WriteString(" WHERE ")
+			onConflict.TargetWhere.Build(builder)
+			builder.WriteByte(' ')
+		}
 	}
 
 	if onConflict.DoNothing {
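With the Build rewrite above, a named constraint wins: the conflict column list and the target WHERE are rendered only when OnConstraint is empty. Sketch (db, user and the constraint name are assumptions):

// Constraint-based target; Columns would be ignored here.
db.Clauses(clause.OnConflict{
	OnConstraint: "uni_users_email", // hypothetical constraint name
	DoNothing:    true,
}).Create(&user)

// Column-based target, used because OnConstraint is empty.
db.Clauses(clause.OnConflict{
	Columns:   []clause.Column{{Name: "email"}},
	DoUpdates: clause.AssignmentColumns([]string{"name"}),
}).Create(&user)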
@@ -49,16 +49,18 @@ func TestSelect(t *testing.T) {
 				Exprs: []clause.Expression{
 					clause.Expr{
 						SQL: "? as name",
-						Vars: []interface{}{clause.Eq{
-							Column: clause.Column{Name: "age"},
-							Value:  18,
-						},
+						Vars: []interface{}{
+							clause.Eq{
+								Column: clause.Column{Name: "age"},
+								Value:  18,
+							},
 						},
 					},
 				},
 			},
 		}, clause.From{}},
-			"SELECT `age` = ? as name FROM `users`", []interface{}{18},
+			"SELECT `age` = ? as name FROM `users`",
+			[]interface{}{18},
 		},
 	}
 
|||||||
@ -21,6 +21,12 @@ func (where Where) Name() string {
|
|||||||
|
|
||||||
// Build build where clause
|
// Build build where clause
|
||||||
func (where Where) Build(builder Builder) {
|
func (where Where) Build(builder Builder) {
|
||||||
|
if len(where.Exprs) == 1 {
|
||||||
|
if andCondition, ok := where.Exprs[0].(AndConditions); ok {
|
||||||
|
where.Exprs = andCondition.Exprs
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Switch position if the first query expression is a single Or condition
|
// Switch position if the first query expression is a single Or condition
|
||||||
for idx, expr := range where.Exprs {
|
for idx, expr := range where.Exprs {
|
||||||
if v, ok := expr.(OrConditions); !ok || len(v.Exprs) > 1 {
|
if v, ok := expr.(OrConditions); !ok || len(v.Exprs) > 1 {
|
||||||
@ -147,6 +153,11 @@ func Not(exprs ...Expression) Expression {
|
|||||||
if len(exprs) == 0 {
|
if len(exprs) == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
if len(exprs) == 1 {
|
||||||
|
if andCondition, ok := exprs[0].(AndConditions); ok {
|
||||||
|
exprs = andCondition.Exprs
|
||||||
|
}
|
||||||
|
}
|
||||||
return NotConditions{Exprs: exprs}
|
return NotConditions{Exprs: exprs}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -155,19 +166,58 @@ type NotConditions struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (not NotConditions) Build(builder Builder) {
|
func (not NotConditions) Build(builder Builder) {
|
||||||
if len(not.Exprs) > 1 {
|
anyNegationBuilder := false
|
||||||
builder.WriteByte('(')
|
for _, c := range not.Exprs {
|
||||||
|
if _, ok := c.(NegationExpressionBuilder); ok {
|
||||||
|
anyNegationBuilder = true
|
||||||
|
break
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for idx, c := range not.Exprs {
|
if anyNegationBuilder {
|
||||||
if idx > 0 {
|
if len(not.Exprs) > 1 {
|
||||||
builder.WriteString(AndWithSpace)
|
builder.WriteByte('(')
|
||||||
}
|
}
|
||||||
|
|
||||||
if negationBuilder, ok := c.(NegationExpressionBuilder); ok {
|
for idx, c := range not.Exprs {
|
||||||
negationBuilder.NegationBuild(builder)
|
if idx > 0 {
|
||||||
} else {
|
builder.WriteString(AndWithSpace)
|
||||||
builder.WriteString("NOT ")
|
}
|
||||||
|
|
||||||
|
if negationBuilder, ok := c.(NegationExpressionBuilder); ok {
|
||||||
|
negationBuilder.NegationBuild(builder)
|
||||||
|
} else {
|
||||||
|
builder.WriteString("NOT ")
|
||||||
|
e, wrapInParentheses := c.(Expr)
|
||||||
|
if wrapInParentheses {
|
||||||
|
sql := strings.ToUpper(e.SQL)
|
||||||
|
if wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace); wrapInParentheses {
|
||||||
|
builder.WriteByte('(')
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Build(builder)
|
||||||
|
|
||||||
|
if wrapInParentheses {
|
||||||
|
builder.WriteByte(')')
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(not.Exprs) > 1 {
|
||||||
|
builder.WriteByte(')')
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
builder.WriteString("NOT ")
|
||||||
|
if len(not.Exprs) > 1 {
|
||||||
|
builder.WriteByte('(')
|
||||||
|
}
|
||||||
|
|
||||||
|
for idx, c := range not.Exprs {
|
||||||
|
if idx > 0 {
|
||||||
|
builder.WriteString(AndWithSpace)
|
||||||
|
}
|
||||||
|
|
||||||
e, wrapInParentheses := c.(Expr)
|
e, wrapInParentheses := c.(Expr)
|
||||||
if wrapInParentheses {
|
if wrapInParentheses {
|
||||||
sql := strings.ToUpper(e.SQL)
|
sql := strings.ToUpper(e.SQL)
|
||||||
@ -182,9 +232,9 @@ func (not NotConditions) Build(builder Builder) {
|
|||||||
builder.WriteByte(')')
|
builder.WriteByte(')')
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
if len(not.Exprs) > 1 {
|
if len(not.Exprs) > 1 {
|
||||||
builder.WriteByte(')')
|
builder.WriteByte(')')
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
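The Where/Not changes above unwrap a lone AndConditions and group multi-expression negations, which is what the adjusted expectations in where_test.go below exercise. A sketch of the resulting SQL:

// A single And(...) no longer adds an extra pair of parentheses:
//   WHERE `age` = ? OR `name` <> ?
cond := clause.And(clause.Eq{Column: "age", Value: 18},
	clause.Or(clause.Neq{Column: "name", Value: "jinzhu"}))

// Not(...) over several expressions becomes one negated group:
//   WHERE NOT (`score` <= ? AND `age` <= ?)
notCond := clause.Not(
	clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}},
	clause.Expr{SQL: "`age` <= ?", Vars: []interface{}{60}},
)
_, _ = cond, notCond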
|||||||
@ -63,7 +63,7 @@ func TestWhere(t *testing.T) {
|
|||||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
|
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
|
||||||
Exprs: []clause.Expression{clause.And(clause.Eq{Column: "age", Value: 18}, clause.Or(clause.Neq{Column: "name", Value: "jinzhu"}))},
|
Exprs: []clause.Expression{clause.And(clause.Eq{Column: "age", Value: 18}, clause.Or(clause.Neq{Column: "name", Value: "jinzhu"}))},
|
||||||
}},
|
}},
|
||||||
"SELECT * FROM `users` WHERE (`age` = ? OR `name` <> ?)",
|
"SELECT * FROM `users` WHERE `age` = ? OR `name` <> ?",
|
||||||
[]interface{}{18, "jinzhu"},
|
[]interface{}{18, "jinzhu"},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -94,7 +94,7 @@ func TestWhere(t *testing.T) {
|
|||||||
clause.And(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false})),
|
clause.And(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false})),
|
||||||
},
|
},
|
||||||
}},
|
}},
|
||||||
"SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `score` <= ?)",
|
"SELECT * FROM `users` WHERE `users`.`id` <> ? AND `score` <= ?",
|
||||||
[]interface{}{"1", 100},
|
[]interface{}{"1", 100},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -105,6 +105,14 @@ func TestWhere(t *testing.T) {
|
|||||||
"SELECT * FROM `users` WHERE (`users`.`id` <> ? AND NOT `score` <= ?)",
|
"SELECT * FROM `users` WHERE (`users`.`id` <> ? AND NOT `score` <= ?)",
|
||||||
[]interface{}{"1", 100},
|
[]interface{}{"1", 100},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
|
||||||
|
Exprs: []clause.Expression{clause.Not(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}},
|
||||||
|
clause.Expr{SQL: "`age` <= ?", Vars: []interface{}{60}})},
|
||||||
|
}},
|
||||||
|
"SELECT * FROM `users` WHERE NOT (`score` <= ? AND `age` <= ?)",
|
||||||
|
[]interface{}{100, 60},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for idx, result := range results {
|
for idx, result := range results {
|
||||||
|
|||||||
@@ -21,6 +21,10 @@ var (
 	ErrPrimaryKeyRequired = errors.New("primary key required")
 	// ErrModelValueRequired model value required
 	ErrModelValueRequired = errors.New("model value required")
+	// ErrModelAccessibleFieldsRequired model accessible fields required
+	ErrModelAccessibleFieldsRequired = errors.New("model accessible fields required")
+	// ErrSubQueryRequired sub query required
+	ErrSubQueryRequired = errors.New("sub query required")
 	// ErrInvalidData unsupported data
 	ErrInvalidData = errors.New("unsupported data")
 	// ErrUnsupportedDriver unsupported driver
@@ -41,4 +45,8 @@ var (
 	ErrInvalidValueOfLength = errors.New("invalid association values, length doesn't match")
 	// ErrPreloadNotAllowed preload is not allowed when count is used
 	ErrPreloadNotAllowed = errors.New("preload is not allowed when count is used")
+	// ErrDuplicatedKey occurs when there is a unique key constraint violation
+	ErrDuplicatedKey = errors.New("duplicated key not allowed")
+	// ErrForeignKeyViolated occurs when there is a foreign key constraint violation
+	ErrForeignKeyViolated = errors.New("violates foreign key constraint")
 )
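ErrDuplicatedKey and ErrForeignKeyViolated only surface when Config.TranslateError is enabled and the dialector implements the new ErrorTranslator interface (see interfaces.go further down). Hedged usage sketch, with dialector and the models assumed:

db, err := gorm.Open(dialector, &gorm.Config{TranslateError: true})
if err != nil {
	panic(err)
}
if err := db.Create(&User{Email: "dup@example.com"}).Error; errors.Is(err, gorm.ErrDuplicatedKey) {
	// unique constraint violation, translated by the dialector
}
if err := db.Create(&Order{UserID: 0}).Error; errors.Is(err, gorm.ErrForeignKeyViolated) {
	// foreign key constraint violation
}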
|||||||
222
finisher_api.go
@ -13,7 +13,7 @@ import (
|
|||||||
"gorm.io/gorm/utils"
|
"gorm.io/gorm/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Create insert the value into database
|
// Create inserts value, returning the inserted data's primary key in value's id
|
||||||
func (db *DB) Create(value interface{}) (tx *DB) {
|
func (db *DB) Create(value interface{}) (tx *DB) {
|
||||||
if db.CreateBatchSize > 0 {
|
if db.CreateBatchSize > 0 {
|
||||||
return db.CreateInBatches(value, db.CreateBatchSize)
|
return db.CreateInBatches(value, db.CreateBatchSize)
|
||||||
@ -24,7 +24,7 @@ func (db *DB) Create(value interface{}) (tx *DB) {
|
|||||||
return tx.callbacks.Create().Execute(tx)
|
return tx.callbacks.Create().Execute(tx)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateInBatches insert the value in batches into database
|
// CreateInBatches inserts value in batches of batchSize
|
||||||
func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) {
|
func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) {
|
||||||
reflectValue := reflect.Indirect(reflect.ValueOf(value))
|
reflectValue := reflect.Indirect(reflect.ValueOf(value))
|
||||||
|
|
||||||
@ -33,9 +33,10 @@ func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) {
|
|||||||
var rowsAffected int64
|
var rowsAffected int64
|
||||||
tx = db.getInstance()
|
tx = db.getInstance()
|
||||||
|
|
||||||
|
// the reflection length judgment of the optimized value
|
||||||
|
reflectLen := reflectValue.Len()
|
||||||
|
|
||||||
callFc := func(tx *DB) error {
|
callFc := func(tx *DB) error {
|
||||||
// the reflection length judgment of the optimized value
|
|
||||||
reflectLen := reflectValue.Len()
|
|
||||||
for i := 0; i < reflectLen; i += batchSize {
|
for i := 0; i < reflectLen; i += batchSize {
|
||||||
ends := i + batchSize
|
ends := i + batchSize
|
||||||
if ends > reflectLen {
|
if ends > reflectLen {
|
||||||
@ -53,7 +54,7 @@ func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if tx.SkipDefaultTransaction {
|
if tx.SkipDefaultTransaction || reflectLen <= batchSize {
|
||||||
tx.AddError(callFc(tx.Session(&Session{})))
|
tx.AddError(callFc(tx.Session(&Session{})))
|
||||||
} else {
|
} else {
|
||||||
tx.AddError(tx.Transaction(callFc))
|
tx.AddError(tx.Transaction(callFc))
|
||||||
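A note on the hunk above: reflectLen is now computed once, and when the whole slice fits into a single batch (or SkipDefaultTransaction is set) the wrapping transaction is skipped. Illustrative sketch, assuming a User model:

users := make([]User, 80)
db.CreateInBatches(&users, 100) // one batch, no outer transaction opened

more := make([]User, 250)
db.CreateInBatches(&more, 100) // three batches inside one transaction, so a failure rolls back all of them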
@ -68,7 +69,7 @@ func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Save update value in database, if the value doesn't have primary key, will insert it
|
// Save updates value in database. If value doesn't contain a matching primary key, value is inserted.
|
||||||
func (db *DB) Save(value interface{}) (tx *DB) {
|
func (db *DB) Save(value interface{}) (tx *DB) {
|
||||||
tx = db.getInstance()
|
tx = db.getInstance()
|
||||||
tx.Statement.Dest = value
|
tx.Statement.Dest = value
|
||||||
@ -101,20 +102,19 @@ func (db *DB) Save(value interface{}) (tx *DB) {
|
|||||||
tx.Statement.Selects = append(tx.Statement.Selects, "*")
|
tx.Statement.Selects = append(tx.Statement.Selects, "*")
|
||||||
}
|
}
|
||||||
|
|
||||||
tx = tx.callbacks.Update().Execute(tx)
|
updateTx := tx.callbacks.Update().Execute(tx.Session(&Session{Initialized: true}))
|
||||||
|
|
||||||
if tx.Error == nil && tx.RowsAffected == 0 && !tx.DryRun && !selectedUpdate {
|
if updateTx.Error == nil && updateTx.RowsAffected == 0 && !updateTx.DryRun && !selectedUpdate {
|
||||||
result := reflect.New(tx.Statement.Schema.ModelType).Interface()
|
return tx.Session(&Session{SkipHooks: true}).Clauses(clause.OnConflict{UpdateAll: true}).Create(value)
|
||||||
if result := tx.Session(&Session{}).Limit(1).Find(result); result.RowsAffected == 0 {
|
|
||||||
return tx.Create(value)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return updateTx
|
||||||
}
|
}
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
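With the rewrite above, Save runs the update through an initialized session and, when no row was affected, falls back to an upsert instead of the old Find-then-Create round trip. Sketch (User is an assumed model):

user := User{ID: 1, Name: "jinzhu"}
// UPDATE first; if RowsAffected == 0 GORM issues a create with
// clause.OnConflict{UpdateAll: true} and hooks skipped.
db.Save(&user)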
// First find first record that match given conditions, order by primary key
|
// First finds the first record ordered by primary key, matching given conditions conds
|
||||||
func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) {
|
func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) {
|
||||||
tx = db.Limit(1).Order(clause.OrderByColumn{
|
tx = db.Limit(1).Order(clause.OrderByColumn{
|
||||||
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
|
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
|
||||||
@ -129,7 +129,7 @@ func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) {
|
|||||||
return tx.callbacks.Query().Execute(tx)
|
return tx.callbacks.Query().Execute(tx)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Take return a record that match given conditions, the order will depend on the database implementation
|
// Take finds the first record returned by the database in no specified order, matching given conditions conds
|
||||||
func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) {
|
func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) {
|
||||||
tx = db.Limit(1)
|
tx = db.Limit(1)
|
||||||
if len(conds) > 0 {
|
if len(conds) > 0 {
|
||||||
@ -142,7 +142,7 @@ func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) {
|
|||||||
return tx.callbacks.Query().Execute(tx)
|
return tx.callbacks.Query().Execute(tx)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Last find last record that match given conditions, order by primary key
|
// Last finds the last record ordered by primary key, matching given conditions conds
|
||||||
func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) {
|
func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) {
|
||||||
tx = db.Limit(1).Order(clause.OrderByColumn{
|
tx = db.Limit(1).Order(clause.OrderByColumn{
|
||||||
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
|
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
|
||||||
@ -158,7 +158,7 @@ func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) {
|
|||||||
return tx.callbacks.Query().Execute(tx)
|
return tx.callbacks.Query().Execute(tx)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Find find records that match given conditions
|
// Find finds all records matching given conditions conds
|
||||||
func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) {
|
func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) {
|
||||||
tx = db.getInstance()
|
tx = db.getInstance()
|
||||||
if len(conds) > 0 {
|
if len(conds) > 0 {
|
||||||
@ -170,7 +170,7 @@ func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) {
|
|||||||
return tx.callbacks.Query().Execute(tx)
|
return tx.callbacks.Query().Execute(tx)
|
||||||
}
|
}
|
||||||
|
|
||||||
// FindInBatches find records in batches
|
// FindInBatches finds all records in batches of batchSize
|
||||||
func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, batch int) error) *DB {
|
func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, batch int) error) *DB {
|
||||||
var (
|
var (
|
||||||
tx = db.Order(clause.OrderByColumn{
|
tx = db.Order(clause.OrderByColumn{
|
||||||
@ -185,7 +185,9 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat
|
|||||||
var totalSize int
|
var totalSize int
|
||||||
if c, ok := tx.Statement.Clauses["LIMIT"]; ok {
|
if c, ok := tx.Statement.Clauses["LIMIT"]; ok {
|
||||||
if limit, ok := c.Expression.(clause.Limit); ok {
|
if limit, ok := c.Expression.(clause.Limit); ok {
|
||||||
totalSize = limit.Limit
|
if limit.Limit != nil {
|
||||||
|
totalSize = *limit.Limit
|
||||||
|
}
|
||||||
|
|
||||||
if totalSize > 0 && batchSize > totalSize {
|
if totalSize > 0 && batchSize > totalSize {
|
||||||
batchSize = totalSize
|
batchSize = totalSize
|
||||||
@ -202,7 +204,9 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat
|
|||||||
batch++
|
batch++
|
||||||
|
|
||||||
if result.Error == nil && result.RowsAffected != 0 {
|
if result.Error == nil && result.RowsAffected != 0 {
|
||||||
tx.AddError(fc(result, batch))
|
fcTx := result.Session(&Session{NewDB: true})
|
||||||
|
fcTx.RowsAffected = result.RowsAffected
|
||||||
|
tx.AddError(fc(fcTx, batch))
|
||||||
} else if result.Error != nil {
|
} else if result.Error != nil {
|
||||||
tx.AddError(result.Error)
|
tx.AddError(result.Error)
|
||||||
}
|
}
|
||||||
@ -227,7 +231,11 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
primaryValue, _ := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1))
|
primaryValue, zero := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1))
|
||||||
|
if zero {
|
||||||
|
tx.AddError(ErrPrimaryKeyRequired)
|
||||||
|
break
|
||||||
|
}
|
||||||
queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue})
|
queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue})
|
||||||
}
|
}
|
||||||
|
|
||||||
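The FindInBatches changes above read the total size from the pointer-based limit, pass each callback a fresh session whose RowsAffected mirrors the batch, and abort with ErrPrimaryKeyRequired when the last row's primary key is zero. Usage sketch (User with a Processed field is assumed):

var users []User
result := db.Where("processed = ?", false).FindInBatches(&users, 100,
	func(tx *gorm.DB, batch int) error {
		for i := range users {
			users[i].Processed = true
		}
		return tx.Save(&users).Error // tx is a clean session scoped to this batch
	})
// result.Error is gorm.ErrPrimaryKeyRequired if a batch row carried a zero primary key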
@ -284,7 +292,18 @@ func (db *DB) assignInterfacesToValue(values ...interface{}) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// FirstOrInit gets the first matched record or initialize a new instance with given conditions (only works with struct or map conditions)
|
// FirstOrInit finds the first matching record, otherwise if not found initializes a new instance with given conds.
|
||||||
|
// Each conds must be a struct or map.
|
||||||
|
//
|
||||||
|
// FirstOrInit never modifies the database. It is often used with Assign and Attrs.
|
||||||
|
//
|
||||||
|
// // assign an email if the record is not found
|
||||||
|
// db.Where(User{Name: "non_existing"}).Attrs(User{Email: "fake@fake.org"}).FirstOrInit(&user)
|
||||||
|
// // user -> User{Name: "non_existing", Email: "fake@fake.org"}
|
||||||
|
//
|
||||||
|
// // assign email regardless of if record is found
|
||||||
|
// db.Where(User{Name: "jinzhu"}).Assign(User{Email: "fake@fake.org"}).FirstOrInit(&user)
|
||||||
|
// // user -> User{Name: "jinzhu", Age: 20, Email: "fake@fake.org"}
|
||||||
func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) {
|
func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) {
|
||||||
queryTx := db.Limit(1).Order(clause.OrderByColumn{
|
queryTx := db.Limit(1).Order(clause.OrderByColumn{
|
||||||
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
|
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
|
||||||
@ -310,62 +329,82 @@ func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// FirstOrCreate gets the first matched record or create a new one with given conditions (only works with struct, map conditions)
|
// FirstOrCreate finds the first matching record, otherwise if not found creates a new instance with given conds.
|
||||||
|
// Each conds must be a struct or map.
|
||||||
|
//
|
||||||
|
// Using FirstOrCreate in conjunction with Assign will result in an update to the database even if the record exists.
|
||||||
|
//
|
||||||
|
// // assign an email if the record is not found
|
||||||
|
// result := db.Where(User{Name: "non_existing"}).Attrs(User{Email: "fake@fake.org"}).FirstOrCreate(&user)
|
||||||
|
// // user -> User{Name: "non_existing", Email: "fake@fake.org"}
|
||||||
|
// // result.RowsAffected -> 1
|
||||||
|
//
|
||||||
|
// // assign email regardless of if record is found
|
||||||
|
// result := db.Where(User{Name: "jinzhu"}).Assign(User{Email: "fake@fake.org"}).FirstOrCreate(&user)
|
||||||
|
// // user -> User{Name: "jinzhu", Age: 20, Email: "fake@fake.org"}
|
||||||
|
// // result.RowsAffected -> 1
|
||||||
func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) {
|
func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) {
|
||||||
tx = db.getInstance()
|
tx = db.getInstance()
|
||||||
queryTx := db.Session(&Session{}).Limit(1).Order(clause.OrderByColumn{
|
queryTx := db.Session(&Session{}).Limit(1).Order(clause.OrderByColumn{
|
||||||
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
|
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
|
||||||
})
|
})
|
||||||
if result := queryTx.Find(dest, conds...); result.Error == nil {
|
|
||||||
if result.RowsAffected == 0 {
|
|
||||||
if c, ok := result.Statement.Clauses["WHERE"]; ok {
|
|
||||||
if where, ok := c.Expression.(clause.Where); ok {
|
|
||||||
result.assignInterfacesToValue(where.Exprs)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// initialize with attrs, conds
|
result := queryTx.Find(dest, conds...)
|
||||||
if len(db.Statement.attrs) > 0 {
|
if result.Error != nil {
|
||||||
result.assignInterfacesToValue(db.Statement.attrs...)
|
tx.Error = result.Error
|
||||||
}
|
return tx
|
||||||
|
|
||||||
// initialize with attrs, conds
|
|
||||||
if len(db.Statement.assigns) > 0 {
|
|
||||||
result.assignInterfacesToValue(db.Statement.assigns...)
|
|
||||||
}
|
|
||||||
|
|
||||||
return tx.Create(dest)
|
|
||||||
} else if len(db.Statement.assigns) > 0 {
|
|
||||||
exprs := tx.Statement.BuildCondition(db.Statement.assigns[0], db.Statement.assigns[1:]...)
|
|
||||||
assigns := map[string]interface{}{}
|
|
||||||
for _, expr := range exprs {
|
|
||||||
if eq, ok := expr.(clause.Eq); ok {
|
|
||||||
switch column := eq.Column.(type) {
|
|
||||||
case string:
|
|
||||||
assigns[column] = eq.Value
|
|
||||||
case clause.Column:
|
|
||||||
assigns[column.Name] = eq.Value
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return tx.Model(dest).Updates(assigns)
|
|
||||||
} else {
|
|
||||||
tx.Error = result.Error
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if result.RowsAffected == 0 {
|
||||||
|
if c, ok := result.Statement.Clauses["WHERE"]; ok {
|
||||||
|
if where, ok := c.Expression.(clause.Where); ok {
|
||||||
|
result.assignInterfacesToValue(where.Exprs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// initialize with attrs, conds
|
||||||
|
if len(db.Statement.attrs) > 0 {
|
||||||
|
result.assignInterfacesToValue(db.Statement.attrs...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// initialize with attrs, conds
|
||||||
|
if len(db.Statement.assigns) > 0 {
|
||||||
|
result.assignInterfacesToValue(db.Statement.assigns...)
|
||||||
|
}
|
||||||
|
|
||||||
|
return tx.Create(dest)
|
||||||
|
} else if len(db.Statement.assigns) > 0 {
|
||||||
|
exprs := tx.Statement.BuildCondition(db.Statement.assigns[0], db.Statement.assigns[1:]...)
|
||||||
|
assigns := map[string]interface{}{}
|
||||||
|
for i := 0; i < len(exprs); i++ {
|
||||||
|
expr := exprs[i]
|
||||||
|
|
||||||
|
if eq, ok := expr.(clause.AndConditions); ok {
|
||||||
|
exprs = append(exprs, eq.Exprs...)
|
||||||
|
} else if eq, ok := expr.(clause.Eq); ok {
|
||||||
|
switch column := eq.Column.(type) {
|
||||||
|
case string:
|
||||||
|
assigns[column] = eq.Value
|
||||||
|
case clause.Column:
|
||||||
|
assigns[column.Name] = eq.Value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return tx.Model(dest).Updates(assigns)
|
||||||
|
}
|
||||||
|
|
||||||
return tx
|
return tx
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update update attributes with callbacks, refer: https://gorm.io/docs/update.html#Update-Changed-Fields
|
// Update updates column with value using callbacks. Reference: https://gorm.io/docs/update.html#Update-Changed-Fields
|
||||||
func (db *DB) Update(column string, value interface{}) (tx *DB) {
|
func (db *DB) Update(column string, value interface{}) (tx *DB) {
|
||||||
tx = db.getInstance()
|
tx = db.getInstance()
|
||||||
tx.Statement.Dest = map[string]interface{}{column: value}
|
tx.Statement.Dest = map[string]interface{}{column: value}
|
||||||
return tx.callbacks.Update().Execute(tx)
|
return tx.callbacks.Update().Execute(tx)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Updates update attributes with callbacks, refer: https://gorm.io/docs/update.html#Update-Changed-Fields
|
// Updates updates attributes using callbacks. values must be a struct or map. Reference: https://gorm.io/docs/update.html#Update-Changed-Fields
|
||||||
func (db *DB) Updates(values interface{}) (tx *DB) {
|
func (db *DB) Updates(values interface{}) (tx *DB) {
|
||||||
tx = db.getInstance()
|
tx = db.getInstance()
|
||||||
tx.Statement.Dest = values
|
tx.Statement.Dest = values
|
||||||
@ -386,7 +425,9 @@ func (db *DB) UpdateColumns(values interface{}) (tx *DB) {
|
|||||||
return tx.callbacks.Update().Execute(tx)
|
return tx.callbacks.Update().Execute(tx)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition
|
// Delete deletes value matching given conditions. If value contains primary key it is included in the conditions. If
|
||||||
|
// value includes a deleted_at field, then Delete performs a soft delete instead by setting deleted_at with the current
|
||||||
|
// time if null.
|
||||||
func (db *DB) Delete(value interface{}, conds ...interface{}) (tx *DB) {
|
func (db *DB) Delete(value interface{}, conds ...interface{}) (tx *DB) {
|
||||||
tx = db.getInstance()
|
tx = db.getInstance()
|
||||||
if len(conds) > 0 {
|
if len(conds) > 0 {
|
||||||
@ -480,7 +521,7 @@ func (db *DB) Rows() (*sql.Rows, error) {
|
|||||||
return rows, tx.Error
|
return rows, tx.Error
|
||||||
}
|
}
|
||||||
|
|
||||||
// Scan scan value to a struct
|
// Scan scans selected value to the struct dest
|
||||||
func (db *DB) Scan(dest interface{}) (tx *DB) {
|
func (db *DB) Scan(dest interface{}) (tx *DB) {
|
||||||
config := *db.Config
|
config := *db.Config
|
||||||
currentLogger, newLogger := config.Logger, logger.Recorder.New()
|
currentLogger, newLogger := config.Logger, logger.Recorder.New()
|
||||||
@ -494,6 +535,7 @@ func (db *DB) Scan(dest interface{}) (tx *DB) {
|
|||||||
tx.ScanRows(rows, dest)
|
tx.ScanRows(rows, dest)
|
||||||
} else {
|
} else {
|
||||||
tx.RowsAffected = 0
|
tx.RowsAffected = 0
|
||||||
|
tx.AddError(rows.Err())
|
||||||
}
|
}
|
||||||
tx.AddError(rows.Close())
|
tx.AddError(rows.Close())
|
||||||
}
|
}
|
||||||
@ -505,9 +547,10 @@ func (db *DB) Scan(dest interface{}) (tx *DB) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Pluck used to query single column from a model as a map
|
// Pluck queries a single column from a model, returning in the slice dest. E.g.:
|
||||||
// var ages []int64
|
//
|
||||||
// db.Model(&users).Pluck("age", &ages)
|
// var ages []int64
|
||||||
|
// db.Model(&users).Pluck("age", &ages)
|
||||||
func (db *DB) Pluck(column string, dest interface{}) (tx *DB) {
|
func (db *DB) Pluck(column string, dest interface{}) (tx *DB) {
|
||||||
tx = db.getInstance()
|
tx = db.getInstance()
|
||||||
if tx.Statement.Model != nil {
|
if tx.Statement.Model != nil {
|
||||||
@ -548,7 +591,8 @@ func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error {
|
|||||||
return tx.Error
|
return tx.Error
|
||||||
}
|
}
|
||||||
|
|
||||||
// Connection use a db conn to execute Multiple commands,this conn will put conn pool after it is executed.
|
// Connection uses a db connection to execute an arbitrary number of commands in fc. When finished, the connection is
|
||||||
|
// returned to the connection pool.
|
||||||
func (db *DB) Connection(fc func(tx *DB) error) (err error) {
|
func (db *DB) Connection(fc func(tx *DB) error) (err error) {
|
||||||
if db.Error != nil {
|
if db.Error != nil {
|
||||||
return db.Error
|
return db.Error
|
||||||
@ -570,7 +614,9 @@ func (db *DB) Connection(fc func(tx *DB) error) (err error) {
|
|||||||
return fc(tx)
|
return fc(tx)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Transaction start a transaction as a block, return error will rollback, otherwise to commit.
|
// Transaction start a transaction as a block, return error will rollback, otherwise to commit. Transaction executes an
|
||||||
|
// arbitrary number of commands in fc within a transaction. On success the changes are committed; if an error occurs
|
||||||
|
// they are rolled back.
|
||||||
func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err error) {
|
func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err error) {
|
||||||
panicked := true
|
panicked := true
|
||||||
|
|
||||||
@ -581,7 +627,6 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
// Make sure to rollback when panic, Block error or Commit error
|
// Make sure to rollback when panic, Block error or Commit error
|
||||||
if panicked || err != nil {
|
if panicked || err != nil {
|
||||||
@ -589,8 +634,7 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
err = fc(db.Session(&Session{NewDB: db.clone == 1}))
|
||||||
err = fc(db.Session(&Session{}))
|
|
||||||
} else {
|
} else {
|
||||||
tx := db.Begin(opts...)
|
tx := db.Begin(opts...)
|
||||||
if tx.Error != nil {
|
if tx.Error != nil {
|
||||||
@ -614,7 +658,7 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Begin begins a transaction
|
// Begin begins a transaction with any transaction options opts
|
||||||
func (db *DB) Begin(opts ...*sql.TxOptions) *DB {
|
func (db *DB) Begin(opts ...*sql.TxOptions) *DB {
|
||||||
var (
|
var (
|
||||||
// clone statement
|
// clone statement
|
||||||
@ -643,7 +687,7 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB {
|
|||||||
return tx
|
return tx
|
||||||
}
|
}
|
||||||
|
|
||||||
// Commit commit a transaction
|
// Commit commits the changes in a transaction
|
||||||
func (db *DB) Commit() *DB {
|
func (db *DB) Commit() *DB {
|
||||||
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil && !reflect.ValueOf(committer).IsNil() {
|
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil && !reflect.ValueOf(committer).IsNil() {
|
||||||
db.AddError(committer.Commit())
|
db.AddError(committer.Commit())
|
||||||
@ -653,7 +697,7 @@ func (db *DB) Commit() *DB {
|
|||||||
return db
|
return db
|
||||||
}
|
}
|
||||||
|
|
||||||
// Rollback rollback a transaction
|
// Rollback rollbacks the changes in a transaction
|
||||||
func (db *DB) Rollback() *DB {
|
func (db *DB) Rollback() *DB {
|
||||||
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
|
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
|
||||||
if !reflect.ValueOf(committer).IsNil() {
|
if !reflect.ValueOf(committer).IsNil() {
|
||||||
@ -667,7 +711,21 @@ func (db *DB) Rollback() *DB {
|
|||||||
|
|
||||||
func (db *DB) SavePoint(name string) *DB {
|
func (db *DB) SavePoint(name string) *DB {
|
||||||
if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok {
|
if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok {
|
||||||
|
// close prepared statement, because SavePoint not support prepared statement.
|
||||||
|
// e.g. mysql8.0 doc: https://dev.mysql.com/doc/refman/8.0/en/sql-prepared-statements.html
|
||||||
|
var (
|
||||||
|
preparedStmtTx *PreparedStmtTX
|
||||||
|
isPreparedStmtTx bool
|
||||||
|
)
|
||||||
|
// close prepared statement, because SavePoint not support prepared statement.
|
||||||
|
if preparedStmtTx, isPreparedStmtTx = db.Statement.ConnPool.(*PreparedStmtTX); isPreparedStmtTx {
|
||||||
|
db.Statement.ConnPool = preparedStmtTx.Tx
|
||||||
|
}
|
||||||
db.AddError(savePointer.SavePoint(db, name))
|
db.AddError(savePointer.SavePoint(db, name))
|
||||||
|
// restore prepared statement
|
||||||
|
if isPreparedStmtTx {
|
||||||
|
db.Statement.ConnPool = preparedStmtTx
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
db.AddError(ErrUnsupportedDriver)
|
db.AddError(ErrUnsupportedDriver)
|
||||||
}
|
}
|
||||||
@ -676,14 +734,28 @@ func (db *DB) SavePoint(name string) *DB {
|
|||||||
|
|
||||||
func (db *DB) RollbackTo(name string) *DB {
|
func (db *DB) RollbackTo(name string) *DB {
|
||||||
if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok {
|
if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok {
|
||||||
|
// close prepared statement, because RollbackTo not support prepared statement.
|
||||||
|
// e.g. mysql8.0 doc: https://dev.mysql.com/doc/refman/8.0/en/sql-prepared-statements.html
|
||||||
|
var (
|
||||||
|
preparedStmtTx *PreparedStmtTX
|
||||||
|
isPreparedStmtTx bool
|
||||||
|
)
|
||||||
|
// close prepared statement, because SavePoint not support prepared statement.
|
||||||
|
if preparedStmtTx, isPreparedStmtTx = db.Statement.ConnPool.(*PreparedStmtTX); isPreparedStmtTx {
|
||||||
|
db.Statement.ConnPool = preparedStmtTx.Tx
|
||||||
|
}
|
||||||
db.AddError(savePointer.RollbackTo(db, name))
|
db.AddError(savePointer.RollbackTo(db, name))
|
||||||
|
// restore prepared statement
|
||||||
|
if isPreparedStmtTx {
|
||||||
|
db.Statement.ConnPool = preparedStmtTx
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
db.AddError(ErrUnsupportedDriver)
|
db.AddError(ErrUnsupportedDriver)
|
||||||
}
|
}
|
||||||
return db
|
return db
|
||||||
}
|
}
|
||||||
|
|
||||||
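SavePoint and RollbackTo now temporarily swap the prepared-statement pool for the underlying *sql.Tx, since SAVEPOINT cannot run through a prepared statement (see the MySQL link in the comments). Sketch under the assumption that PrepareStmt is enabled:

db, _ := gorm.Open(dialector, &gorm.Config{PrepareStmt: true})
db.Transaction(func(tx *gorm.DB) error {
	tx.SavePoint("sp1") // executed on the raw transaction, pool restored afterwards
	if err := tx.Create(&user).Error; err != nil {
		tx.RollbackTo("sp1")
	}
	return nil
})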
// Exec execute raw sql
|
// Exec executes raw sql
|
||||||
func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) {
|
func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) {
|
||||||
tx = db.getInstance()
|
tx = db.getInstance()
|
||||||
tx.Statement.SQL = strings.Builder{}
|
tx.Statement.SQL = strings.Builder{}
|
||||||
|
|||||||
4 go.mod
@@ -1,8 +1,8 @@
 module gorm.io/gorm
 
-go 1.14
+go 1.18
 
 require (
 	github.com/jinzhu/inflection v1.0.0
-	github.com/jinzhu/now v1.1.4
+	github.com/jinzhu/now v1.1.5
 )
4 go.sum
@@ -1,4 +1,4 @@
 github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
 github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
-github.com/jinzhu/now v1.1.4 h1:tHnRBy1i5F2Dh8BAFxqFzxKqqvezXrL2OW1TnX+Mlas=
-github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
+github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
+github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|||||||
101
gorm.go
@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"reflect"
|
||||||
"sort"
|
"sort"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
@ -37,6 +38,8 @@ type Config struct {
|
|||||||
DisableAutomaticPing bool
|
DisableAutomaticPing bool
|
||||||
// DisableForeignKeyConstraintWhenMigrating
|
// DisableForeignKeyConstraintWhenMigrating
|
||||||
DisableForeignKeyConstraintWhenMigrating bool
|
DisableForeignKeyConstraintWhenMigrating bool
|
||||||
|
// IgnoreRelationshipsWhenMigrating
|
||||||
|
IgnoreRelationshipsWhenMigrating bool
|
||||||
// DisableNestedTransaction disable nested transaction
|
// DisableNestedTransaction disable nested transaction
|
||||||
DisableNestedTransaction bool
|
DisableNestedTransaction bool
|
||||||
// AllowGlobalUpdate allow global update
|
// AllowGlobalUpdate allow global update
|
||||||
@ -45,6 +48,8 @@ type Config struct {
|
|||||||
QueryFields bool
|
QueryFields bool
|
||||||
// CreateBatchSize default create batch size
|
// CreateBatchSize default create batch size
|
||||||
CreateBatchSize int
|
CreateBatchSize int
|
||||||
|
// TranslateError enabling error translation
|
||||||
|
TranslateError bool
|
||||||
|
|
||||||
// ClauseBuilders clause builder
|
// ClauseBuilders clause builder
|
||||||
ClauseBuilders map[string]clause.ClauseBuilder
|
ClauseBuilders map[string]clause.ClauseBuilder
|
||||||
@ -142,7 +147,7 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if config.NamingStrategy == nil {
|
if config.NamingStrategy == nil {
|
||||||
config.NamingStrategy = schema.NamingStrategy{}
|
config.NamingStrategy = schema.NamingStrategy{IdentifierMaxLength: 64} // Default Identifier length is 64
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.Logger == nil {
|
if config.Logger == nil {
|
||||||
@ -175,17 +180,17 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) {
|
|||||||
|
|
||||||
if config.Dialector != nil {
|
if config.Dialector != nil {
|
||||||
err = config.Dialector.Initialize(db)
|
err = config.Dialector.Initialize(db)
|
||||||
}
|
|
||||||
|
|
||||||
preparedStmt := &PreparedStmtDB{
|
if err != nil {
|
||||||
ConnPool: db.ConnPool,
|
if db, _ := db.DB(); db != nil {
|
||||||
Stmts: map[string]Stmt{},
|
_ = db.Close()
|
||||||
Mux: &sync.RWMutex{},
|
}
|
||||||
PreparedSQL: make([]string, 0, 100),
|
}
|
||||||
}
|
}
|
||||||
db.cacheStore.Store(preparedStmtDBKey, preparedStmt)
|
|
||||||
|
|
||||||
if config.PrepareStmt {
|
if config.PrepareStmt {
|
||||||
|
preparedStmt := NewPreparedStmtDB(db.ConnPool)
|
||||||
|
db.cacheStore.Store(preparedStmtDBKey, preparedStmt)
|
||||||
db.ConnPool = preparedStmt
|
db.ConnPool = preparedStmt
|
||||||
}
|
}
|
||||||
|
|
||||||
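The new Config fields and the reworked Open above can be exercised together; the values below are illustrative and schema refers to gorm.io/gorm/schema:

db, err := gorm.Open(dialector, &gorm.Config{
	IgnoreRelationshipsWhenMigrating: true, // skip relationships during migration
	TranslateError:                   true, // let the dialector map driver errors
	PrepareStmt:                      true, // prepared-statement pool is now created lazily
	NamingStrategy:                   schema.NamingStrategy{IdentifierMaxLength: 64},
})
if err != nil {
	// on a dialector Initialize failure, Open now closes the underlying *sql.DB
	log.Fatal(err)
}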
@ -246,16 +251,30 @@ func (db *DB) Session(config *Session) *DB {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if config.PrepareStmt {
|
if config.PrepareStmt {
|
||||||
|
var preparedStmt *PreparedStmtDB
|
||||||
|
|
||||||
if v, ok := db.cacheStore.Load(preparedStmtDBKey); ok {
|
if v, ok := db.cacheStore.Load(preparedStmtDBKey); ok {
|
||||||
preparedStmt := v.(*PreparedStmtDB)
|
preparedStmt = v.(*PreparedStmtDB)
|
||||||
|
} else {
|
||||||
|
preparedStmt = NewPreparedStmtDB(db.ConnPool)
|
||||||
|
db.cacheStore.Store(preparedStmtDBKey, preparedStmt)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch t := tx.Statement.ConnPool.(type) {
|
||||||
|
case Tx:
|
||||||
|
tx.Statement.ConnPool = &PreparedStmtTX{
|
||||||
|
Tx: t,
|
||||||
|
PreparedStmtDB: preparedStmt,
|
||||||
|
}
|
||||||
|
default:
|
||||||
tx.Statement.ConnPool = &PreparedStmtDB{
|
tx.Statement.ConnPool = &PreparedStmtDB{
|
||||||
ConnPool: db.Config.ConnPool,
|
ConnPool: db.Config.ConnPool,
|
||||||
Mux: preparedStmt.Mux,
|
Mux: preparedStmt.Mux,
|
||||||
Stmts: preparedStmt.Stmts,
|
Stmts: preparedStmt.Stmts,
|
||||||
}
|
}
|
||||||
txConfig.ConnPool = tx.Statement.ConnPool
|
|
||||||
txConfig.PrepareStmt = true
|
|
||||||
}
|
}
|
||||||
|
txConfig.ConnPool = tx.Statement.ConnPool
|
||||||
|
txConfig.PrepareStmt = true
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.SkipHooks {
|
if config.SkipHooks {
|
||||||
@ -300,7 +319,8 @@ func (db *DB) WithContext(ctx context.Context) *DB {
|
|||||||
|
|
||||||
// Debug start debug mode
|
// Debug start debug mode
|
||||||
func (db *DB) Debug() (tx *DB) {
|
func (db *DB) Debug() (tx *DB) {
|
||||||
return db.Session(&Session{
|
tx = db.getInstance()
|
||||||
|
return tx.Session(&Session{
|
||||||
Logger: db.Logger.LogMode(logger.Info),
|
Logger: db.Logger.LogMode(logger.Info),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -336,10 +356,18 @@ func (db *DB) Callback() *callbacks {
|
|||||||
|
|
||||||
// AddError add error to db
|
// AddError add error to db
|
||||||
func (db *DB) AddError(err error) error {
|
func (db *DB) AddError(err error) error {
|
||||||
if db.Error == nil {
|
if err != nil {
|
||||||
db.Error = err
|
if db.Config.TranslateError {
|
||||||
} else if err != nil {
|
if errTranslator, ok := db.Dialector.(ErrorTranslator); ok {
|
||||||
db.Error = fmt.Errorf("%v; %w", db.Error, err)
|
err = errTranslator.Translate(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if db.Error == nil {
|
||||||
|
db.Error = err
|
||||||
|
} else {
|
||||||
|
db.Error = fmt.Errorf("%v; %w", db.Error, err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return db.Error
|
return db.Error
|
||||||
}
|
}
|
||||||
@ -347,12 +375,20 @@ func (db *DB) AddError(err error) error {
|
|||||||
// DB returns `*sql.DB`
|
// DB returns `*sql.DB`
|
||||||
func (db *DB) DB() (*sql.DB, error) {
|
func (db *DB) DB() (*sql.DB, error) {
|
||||||
connPool := db.ConnPool
|
connPool := db.ConnPool
|
||||||
|
if db.Statement != nil && db.Statement.ConnPool != nil {
|
||||||
if dbConnector, ok := connPool.(GetDBConnector); ok && dbConnector != nil {
|
connPool = db.Statement.ConnPool
|
||||||
return dbConnector.GetDBConn()
|
}
|
||||||
|
if tx, ok := connPool.(*sql.Tx); ok && tx != nil {
|
||||||
|
return (*sql.DB)(reflect.ValueOf(tx).Elem().FieldByName("db").UnsafePointer()), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if sqldb, ok := connPool.(*sql.DB); ok {
|
if dbConnector, ok := connPool.(GetDBConnector); ok && dbConnector != nil {
|
||||||
|
if sqldb, err := dbConnector.GetDBConn(); sqldb != nil || err != nil {
|
||||||
|
return sqldb, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if sqldb, ok := connPool.(*sql.DB); ok && sqldb != nil {
|
||||||
return sqldb, nil
|
return sqldb, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -366,11 +402,12 @@ func (db *DB) getInstance() *DB {
|
|||||||
if db.clone == 1 {
|
if db.clone == 1 {
|
||||||
// clone with new statement
|
// clone with new statement
|
||||||
tx.Statement = &Statement{
|
tx.Statement = &Statement{
|
||||||
DB: tx,
|
DB: tx,
|
||||||
ConnPool: db.Statement.ConnPool,
|
ConnPool: db.Statement.ConnPool,
|
||||||
Context: db.Statement.Context,
|
Context: db.Statement.Context,
|
||||||
Clauses: map[string]clause.Clause{},
|
Clauses: map[string]clause.Clause{},
|
||||||
Vars: make([]interface{}, 0, 8),
|
Vars: make([]interface{}, 0, 8),
|
||||||
|
SkipHooks: db.Statement.SkipHooks,
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// with clone statement
|
// with clone statement
|
||||||
@ -412,7 +449,7 @@ func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interfac
|
|||||||
relation, ok := modelSchema.Relationships.Relations[field]
|
relation, ok := modelSchema.Relationships.Relations[field]
|
||||||
isRelation := ok && relation.JoinTable != nil
|
isRelation := ok && relation.JoinTable != nil
|
||||||
if !isRelation {
|
if !isRelation {
|
||||||
return fmt.Errorf("failed to found relation: %s", field)
|
return fmt.Errorf("failed to find relation: %s", field)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, ref := range relation.References {
|
for _, ref := range relation.References {
|
||||||
@ -455,12 +492,12 @@ func (db *DB) Use(plugin Plugin) error {
|
|||||||
|
|
||||||
// ToSQL for generate SQL string.
|
// ToSQL for generate SQL string.
|
||||||
//
|
//
|
||||||
// db.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
// db.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
||||||
// return tx.Model(&User{}).Where(&User{Name: "foo", Age: 20})
|
// return tx.Model(&User{}).Where(&User{Name: "foo", Age: 20})
|
||||||
// .Limit(10).Offset(5)
|
// .Limit(10).Offset(5)
|
||||||
// .Order("name ASC")
|
// .Order("name ASC")
|
||||||
// .First(&User{})
|
// .First(&User{})
|
||||||
// })
|
// })
|
||||||
func (db *DB) ToSQL(queryFn func(tx *DB) *DB) string {
|
func (db *DB) ToSQL(queryFn func(tx *DB) *DB) string {
|
||||||
tx := queryFn(db.Session(&Session{DryRun: true, SkipDefaultTransaction: true}))
|
tx := queryFn(db.Session(&Session{DryRun: true, SkipDefaultTransaction: true}))
|
||||||
stmt := tx.Statement
|
stmt := tx.Statement
|
||||||
|
|||||||
@ -26,6 +26,10 @@ type Plugin interface {
|
|||||||
Initialize(*DB) error
|
Initialize(*DB) error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ParamsFilter interface {
|
||||||
|
ParamsFilter(ctx context.Context, sql string, params ...interface{}) (string, []interface{})
|
||||||
|
}
|
||||||
|
|
||||||
// ConnPool db conns pool interface
|
// ConnPool db conns pool interface
|
||||||
type ConnPool interface {
|
type ConnPool interface {
|
||||||
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
|
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
|
||||||
@ -82,3 +86,7 @@ type Rows interface {
|
|||||||
Err() error
|
Err() error
|
||||||
Close() error
|
Close() error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ErrorTranslator interface {
|
||||||
|
Translate(err error) error
|
||||||
|
}
|
||||||
|
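A dialector opts into the AddError/TranslateError path by satisfying the new ErrorTranslator interface. Minimal sketch; the driver-specific check is a hypothetical placeholder:

type translatingDialector struct {
	gorm.Dialector
}

// Translate maps driver errors onto GORM's sentinel errors.
func (d translatingDialector) Translate(err error) error {
	if isUniqueViolation(err) { // isUniqueViolation is a hypothetical helper
		return gorm.ErrDuplicatedKey
	}
	return err
}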
|||||||
@ -4,7 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"os"
|
"os"
|
||||||
"time"
|
"time"
|
||||||
@ -55,6 +55,7 @@ type Config struct {
|
|||||||
SlowThreshold time.Duration
|
SlowThreshold time.Duration
|
||||||
Colorful bool
|
Colorful bool
|
||||||
IgnoreRecordNotFoundError bool
|
IgnoreRecordNotFoundError bool
|
||||||
|
ParameterizedQueries bool
|
||||||
LogLevel LogLevel
|
LogLevel LogLevel
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -68,8 +69,8 @@ type Interface interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
// Discard Discard logger will print any log to ioutil.Discard
|
// Discard logger will print any log to io.Discard
|
||||||
Discard = New(log.New(ioutil.Discard, "", log.LstdFlags), Config{})
|
Discard = New(log.New(io.Discard, "", log.LstdFlags), Config{})
|
||||||
// Default Default logger
|
// Default Default logger
|
||||||
Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{
|
Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{
|
||||||
SlowThreshold: 200 * time.Millisecond,
|
SlowThreshold: 200 * time.Millisecond,
|
||||||
@ -77,7 +78,7 @@ var (
|
|||||||
IgnoreRecordNotFoundError: false,
|
IgnoreRecordNotFoundError: false,
|
||||||
Colorful: true,
|
Colorful: true,
|
||||||
})
|
})
|
||||||
// Recorder Recorder logger records running SQL into a recorder instance
|
// Recorder logger records running SQL into a recorder instance
|
||||||
Recorder = traceRecorder{Interface: Default, BeginAt: time.Now()}
|
Recorder = traceRecorder{Interface: Default, BeginAt: time.Now()}
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -128,28 +129,30 @@ func (l *logger) LogMode(level LogLevel) Interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Info print info
|
// Info print info
|
||||||
func (l logger) Info(ctx context.Context, msg string, data ...interface{}) {
|
func (l *logger) Info(ctx context.Context, msg string, data ...interface{}) {
|
||||||
if l.LogLevel >= Info {
|
if l.LogLevel >= Info {
|
||||||
l.Printf(l.infoStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...)
|
l.Printf(l.infoStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Warn print warn messages
|
// Warn print warn messages
|
||||||
func (l logger) Warn(ctx context.Context, msg string, data ...interface{}) {
|
func (l *logger) Warn(ctx context.Context, msg string, data ...interface{}) {
|
||||||
if l.LogLevel >= Warn {
|
if l.LogLevel >= Warn {
|
||||||
l.Printf(l.warnStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...)
|
l.Printf(l.warnStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Error print error messages
|
// Error print error messages
|
||||||
func (l logger) Error(ctx context.Context, msg string, data ...interface{}) {
|
func (l *logger) Error(ctx context.Context, msg string, data ...interface{}) {
|
||||||
if l.LogLevel >= Error {
|
if l.LogLevel >= Error {
|
||||||
l.Printf(l.errStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...)
|
l.Printf(l.errStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Trace print sql message
|
// Trace print sql message
|
||||||
func (l logger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) {
|
//
|
||||||
|
//nolint:cyclop
|
||||||
|
func (l *logger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) {
|
||||||
if l.LogLevel <= Silent {
|
if l.LogLevel <= Silent {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -181,6 +184,14 @@ func (l logger) Trace(ctx context.Context, begin time.Time, fc func() (string, i
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ParamsFilter filter params
|
||||||
|
func (l *logger) ParamsFilter(ctx context.Context, sql string, params ...interface{}) (string, []interface{}) {
|
||||||
|
if l.Config.ParameterizedQueries {
|
||||||
|
return sql, nil
|
||||||
|
}
|
||||||
|
return sql, params
|
||||||
|
}
|
||||||
|
|
||||||
type traceRecorder struct {
|
type traceRecorder struct {
|
||||||
Interface
|
Interface
|
||||||
BeginAt time.Time
|
BeginAt time.Time
|
||||||
@ -189,8 +200,8 @@ type traceRecorder struct {
|
|||||||
Err error
|
Err error
|
||||||
}
|
}
|
||||||
|
|
||||||
// New new trace recorder
|
// New trace recorder
|
||||||
func (l traceRecorder) New() *traceRecorder {
|
func (l *traceRecorder) New() *traceRecorder {
|
||||||
return &traceRecorder{Interface: l.Interface, BeginAt: time.Now()}
|
return &traceRecorder{Interface: l.Interface, BeginAt: time.Now()}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
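The logger changes above move the methods to pointer receivers, add the ParameterizedQueries option, and expose ParamsFilter so bind values can be stripped from traces. Configuration sketch (dialector assumed):

newLogger := logger.New(log.New(os.Stdout, "\r\n", log.LstdFlags), logger.Config{
	SlowThreshold:             200 * time.Millisecond,
	LogLevel:                  logger.Info,
	IgnoreRecordNotFoundError: true,
	ParameterizedQueries:      true, // traced SQL keeps its placeholders, params are dropped
	Colorful:                  false,
})
db, err := gorm.Open(dialector, &gorm.Config{Logger: newLogger})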
|||||||
@ -28,8 +28,25 @@ func isPrintable(s string) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// A list of Go types that should be converted to SQL primitives
|
||||||
var convertibleTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeOf(false), reflect.TypeOf([]byte{})}
|
var convertibleTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeOf(false), reflect.TypeOf([]byte{})}
|
||||||
|
|
||||||
|
// RegEx matches only numeric values
|
||||||
|
var numericPlaceholderRe = regexp.MustCompile(`\$\d+\$`)
|
||||||
|
|
||||||
|
func isNumeric(k reflect.Kind) bool {
|
||||||
|
switch k {
|
||||||
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||||
|
return true
|
||||||
|
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||||
|
return true
|
||||||
|
case reflect.Float32, reflect.Float64:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
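For intuition, a small self-contained check (not part of the diff) of the two helpers above: numericPlaceholderRe matches the `$n$` markers that ExplainSQL inserts before substitution, and isNumeric reports the reflect kinds handled numerically. Both helpers are unexported, so the sketch re-declares them locally with the same bodies.

package main

import (
    "fmt"
    "reflect"
    "regexp"
)

// Local copies of the unexported helpers, for illustration only.
var numericPlaceholderRe = regexp.MustCompile(`\$\d+\$`)

func isNumeric(k reflect.Kind) bool {
    switch k {
    case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
        reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
        reflect.Float32, reflect.Float64:
        return true
    default:
        return false
    }
}

func main() {
    fmt.Println(numericPlaceholderRe.FindAllString("SELECT * FROM users WHERE id = $1$ AND age > $2$", -1))
    // Output: [$1$ $2$]
    fmt.Println(isNumeric(reflect.TypeOf(int32(1)).Kind()), isNumeric(reflect.TypeOf("x").Kind()))
    // Output: true false
}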
 // ExplainSQL generate SQL string with given parameters, the generated SQL is expected to be used in logger, execute it might introduce a SQL injection vulnerability
 func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, avars ...interface{}) string {
     var (
@@ -75,24 +92,26 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a
             case reflect.Bool:
                 vars[idx] = fmt.Sprintf("%t", reflectValue.Interface())
             case reflect.String:
-                vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, "\\"+escaper) + escaper
+                vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, escaper+escaper) + escaper
             default:
                 if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) {
-                    vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, "\\"+escaper) + escaper
+                    vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, escaper+escaper) + escaper
                 } else {
                     vars[idx] = nullStr
                 }
             }
         case []byte:
             if s := string(v); isPrintable(s) {
-                vars[idx] = escaper + strings.ReplaceAll(s, escaper, "\\"+escaper) + escaper
+                vars[idx] = escaper + strings.ReplaceAll(s, escaper, escaper+escaper) + escaper
             } else {
                 vars[idx] = escaper + "<binary>" + escaper
             }
         case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
             vars[idx] = utils.ToString(v)
-        case float64, float32:
-            vars[idx] = fmt.Sprintf("%.6f", v)
+        case float32:
+            vars[idx] = strconv.FormatFloat(float64(v), 'f', -1, 32)
+        case float64:
+            vars[idx] = strconv.FormatFloat(v, 'f', -1, 64)
         case string:
             v = strings.ReplaceAll(v, "\\"+escaper, "")

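The switch from %.6f to strconv.FormatFloat with precision -1 keeps the shortest representation that round-trips, which is what the updated test expectations further down (999.99 instead of 999.990000) rely on. A quick standalone comparison, using only the standard library:

package main

import (
    "fmt"
    "strconv"
)

func main() {
    v := 999.99
    fmt.Printf("%.6f\n", v)                          // 999.990000 (old behaviour)
    fmt.Println(strconv.FormatFloat(v, 'f', -1, 64)) // 999.99     (new behaviour)

    f := float32(999.99)
    // float32 values are formatted with 32-bit precision to avoid
    // artifacts such as 999.989990234375.
    fmt.Println(strconv.FormatFloat(float64(f), 'f', -1, 32)) // 999.99
}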
@@ -110,6 +129,12 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a
                 convertParams(v, idx)
             } else if rv.Kind() == reflect.Ptr && !rv.IsZero() {
                 convertParams(reflect.Indirect(rv).Interface(), idx)
+            } else if isNumeric(rv.Kind()) {
+                if rv.CanInt() || rv.CanUint() {
+                    vars[idx] = fmt.Sprintf("%d", rv.Interface())
+                } else {
+                    vars[idx] = fmt.Sprintf("%.6f", rv.Interface())
+                }
             } else {
                 for _, t := range convertibleTypes {
                     if rv.Type().ConvertibleTo(t) {
@@ -117,7 +142,7 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a
                         return
                     }
                 }
-                vars[idx] = escaper + strings.ReplaceAll(fmt.Sprint(v), escaper, "\\"+escaper) + escaper
+                vars[idx] = escaper + strings.ReplaceAll(fmt.Sprint(v), escaper, escaper+escaper) + escaper
             }
         }
     }
@@ -144,9 +169,18 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a
         sql = newSQL.String()
     } else {
         sql = numericPlaceholder.ReplaceAllString(sql, "$$$1$$")
-        for idx, v := range vars {
-            sql = strings.Replace(sql, "$"+strconv.Itoa(idx+1)+"$", v, 1)
-        }
+
+        sql = numericPlaceholderRe.ReplaceAllStringFunc(sql, func(v string) string {
+            num := v[1 : len(v)-1]
+            n, _ := strconv.Atoi(num)
+
+            // position var start from 1 ($1, $2)
+            n -= 1
+            if n >= 0 && n <= len(vars)-1 {
+                return vars[n]
+            }
+            return v
+        })
     }

     return sql
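A hedged usage sketch of the exported ExplainSQL with PostgreSQL-style numeric placeholders, showing the `$n$` round trip above; the query and values are illustrative.

package main

import (
    "fmt"
    "regexp"

    "gorm.io/gorm/logger"
)

func main() {
    numericPlaceholder := regexp.MustCompile(`\$(\d+)`)

    // Placeholders are first rewritten to $1$, $2$, ... and then replaced
    // with the corresponding vars via numericPlaceholderRe.ReplaceAllStringFunc.
    sql := logger.ExplainSQL(
        "SELECT * FROM users WHERE name = $1 AND age > $2",
        numericPlaceholder, `'`, "jinzhu", 18,
    )
    fmt.Println(sql)
    // SELECT * FROM users WHERE name = 'jinzhu' AND age > 18
}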
|
|||||||
@ -31,20 +31,24 @@ func (s ExampleStruct) Value() (driver.Value, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func format(v []byte, escaper string) string {
|
func format(v []byte, escaper string) string {
|
||||||
return escaper + strings.ReplaceAll(string(v), escaper, "\\"+escaper) + escaper
|
return escaper + strings.ReplaceAll(string(v), escaper, escaper+escaper) + escaper
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestExplainSQL(t *testing.T) {
|
func TestExplainSQL(t *testing.T) {
|
||||||
type role string
|
type role string
|
||||||
type password []byte
|
type password []byte
|
||||||
|
type intType int
|
||||||
|
type floatType float64
|
||||||
var (
|
var (
|
||||||
tt = now.MustParse("2020-02-23 11:10:10")
|
tt = now.MustParse("2020-02-23 11:10:10")
|
||||||
myrole = role("admin")
|
myrole = role("admin")
|
||||||
pwd = password([]byte("pass"))
|
pwd = password("pass")
|
||||||
jsVal = []byte(`{"Name":"test","Val":"test"}`)
|
jsVal = []byte(`{"Name":"test","Val":"test"}`)
|
||||||
js = JSON(jsVal)
|
js = JSON(jsVal)
|
||||||
esVal = []byte(`{"Name":"test","Val":"test"}`)
|
esVal = []byte(`{"Name":"test","Val":"test"}`)
|
||||||
es = ExampleStruct{Name: "test", Val: "test"}
|
es = ExampleStruct{Name: "test", Val: "test"}
|
||||||
|
intVal intType = 1
|
||||||
|
floatVal floatType = 1.23
|
||||||
)
|
)
|
||||||
|
|
||||||
results := []struct {
|
results := []struct {
|
||||||
@ -69,19 +73,19 @@ func TestExplainSQL(t *testing.T) {
|
|||||||
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (@p1, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9, @p10, @p11)",
|
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (@p1, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9, @p10, @p11)",
|
||||||
NumericRegexp: regexp.MustCompile(`@p(\d+)`),
|
NumericRegexp: regexp.MustCompile(`@p(\d+)`),
|
||||||
Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.com", myrole, pwd},
|
Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.com", myrole, pwd},
|
||||||
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`,
|
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ($3, $4, $1, $2, $7, $8, $5, $6, $9, $10, $11)",
|
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ($3, $4, $1, $2, $7, $8, $5, $6, $9, $10, $11)",
|
||||||
NumericRegexp: regexp.MustCompile(`\$(\d+)`),
|
NumericRegexp: regexp.MustCompile(`\$(\d+)`),
|
||||||
Vars: []interface{}{999.99, true, "jinzhu", 1, &tt, nil, []byte("12345"), tt, "w@g.com", myrole, pwd},
|
Vars: []interface{}{999.99, true, "jinzhu", 1, &tt, nil, []byte("12345"), tt, "w@g.com", myrole, pwd},
|
||||||
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`,
|
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (@p1, @p11, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9, @p10)",
|
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (@p1, @p11, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9, @p10)",
|
||||||
NumericRegexp: regexp.MustCompile(`@p(\d+)`),
|
NumericRegexp: regexp.MustCompile(`@p(\d+)`),
|
||||||
Vars: []interface{}{"jinzhu", 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.com", myrole, pwd, 1},
|
Vars: []interface{}{"jinzhu", 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.com", myrole, pwd, 1},
|
||||||
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`,
|
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||||
@ -95,6 +99,30 @@ func TestExplainSQL(t *testing.T) {
|
|||||||
Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es},
|
Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es},
|
||||||
Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, E"w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
|
Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, E"w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||||
|
NumericRegexp: nil,
|
||||||
|
Vars: []interface{}{"jinzhu", 1, 0.1753607109, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es},
|
||||||
|
Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 0.1753607109, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||||
|
NumericRegexp: nil,
|
||||||
|
Vars: []interface{}{"jinzhu", 1, float32(999.99), true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es},
|
||||||
|
Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, int_val) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||||
|
NumericRegexp: nil,
|
||||||
|
Vars: []interface{}{"jinzhu?", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, intVal},
|
||||||
|
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, int_val) values ("jinzhu?", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", 1)`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, float_val) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||||
|
NumericRegexp: nil,
|
||||||
|
Vars: []interface{}{"jinzhu?", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, floatVal},
|
||||||
|
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, float_val) values ("jinzhu?", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", 1.230000)`,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for idx, r := range results {
|
for idx, r := range results {
|
||||||
34  migrator.go
@@ -13,11 +13,7 @@ func (db *DB) Migrator() Migrator {

     // apply scopes to migrator
     for len(tx.Statement.scopes) > 0 {
-        scopes := tx.Statement.scopes
-        tx.Statement.scopes = nil
-        for _, scope := range scopes {
-            tx = scope(tx)
-        }
+        tx = tx.executeScopes()
     }

     return tx.Dialector.Migrator(tx.Session(&Session{}))
@@ -30,9 +26,9 @@ func (db *DB) AutoMigrate(dst ...interface{}) error {

 // ViewOption view option
 type ViewOption struct {
-    Replace     bool
-    CheckOption string
-    Query       *DB
+    Replace     bool   // If true, exec `CREATE OR REPLACE`. If false, exec `CREATE`
+    CheckOption string // optional. e.g. `WITH [ CASCADED | LOCAL ] CHECK OPTION`
+    Query       *DB    // required subquery.
 }

 // ColumnType column type interface
@@ -51,6 +47,23 @@ type ColumnType interface {
     DefaultValue() (value string, ok bool)
 }

+type Index interface {
+    Table() string
+    Name() string
+    Columns() []string
+    PrimaryKey() (isPrimaryKey bool, ok bool)
+    Unique() (unique bool, ok bool)
+    Option() string
+}
+
+// TableType table type interface
+type TableType interface {
+    Schema() string
+    Name() string
+    Type() string
+    Comment() (comment string, ok bool)
+}
+
 // Migrator migrator interface
 type Migrator interface {
     // AutoMigrate
@@ -59,6 +72,7 @@ type Migrator interface {
     // Database
     CurrentDatabase() string
     FullDataTypeOf(*schema.Field) clause.Expr
+    GetTypeAliases(databaseTypeName string) []string

     // Tables
     CreateTable(dst ...interface{}) error
@@ -66,12 +80,15 @@ type Migrator interface {
     HasTable(dst interface{}) bool
     RenameTable(oldName, newName interface{}) error
     GetTables() (tableList []string, err error)
+    TableType(dst interface{}) (TableType, error)

     // Columns
     AddColumn(dst interface{}, field string) error
     DropColumn(dst interface{}, field string) error
     AlterColumn(dst interface{}, field string) error
     MigrateColumn(dst interface{}, field *schema.Field, columnType ColumnType) error
+    // MigrateColumnUnique migrate column's UNIQUE constraint, it's part of MigrateColumn.
+    MigrateColumnUnique(dst interface{}, field *schema.Field, columnType ColumnType) error
     HasColumn(dst interface{}, field string) bool
     RenameColumn(dst interface{}, oldName, field string) error
     ColumnTypes(dst interface{}) ([]ColumnType, error)
@@ -90,4 +107,5 @@ type Migrator interface {
     DropIndex(dst interface{}, name string) error
     HasIndex(dst interface{}, name string) bool
    RenameIndex(dst interface{}, oldName, newName string) error
+    GetIndexes(dst interface{}) ([]Index, error)
 }
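A hedged sketch of what the new interface methods enable for callers, assuming a GORM version that includes them and a dialect that implements GetIndexes; the model and DSN are illustrative.

package main

import (
    "fmt"

    "gorm.io/driver/sqlite"
    "gorm.io/gorm"
)

type User struct {
    ID   uint
    Name string `gorm:"index"`
}

func main() {
    db, err := gorm.Open(sqlite.Open("example.db"), &gorm.Config{})
    if err != nil {
        panic(err)
    }
    if err := db.AutoMigrate(&User{}); err != nil {
        panic(err)
    }

    // GetIndexes returns the table's indexes as gorm.Index values; dialects
    // that don't implement it fall back to the default "not support" error.
    indexes, err := db.Migrator().GetIndexes(&User{})
    if err != nil {
        panic(err)
    }
    for _, idx := range indexes {
        unique, ok := idx.Unique()
        fmt.Println(idx.Name(), idx.Columns(), unique, ok)
    }
}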
43  migrator/index.go
@@ -0,0 +1,43 @@
+package migrator
+
+import "database/sql"
+
+// Index implements gorm.Index interface
+type Index struct {
+    TableName       string
+    NameValue       string
+    ColumnList      []string
+    PrimaryKeyValue sql.NullBool
+    UniqueValue     sql.NullBool
+    OptionValue     string
+}
+
+// Table return the table name of the index.
+func (idx Index) Table() string {
+    return idx.TableName
+}
+
+// Name return the name of the index.
+func (idx Index) Name() string {
+    return idx.NameValue
+}
+
+// Columns return the columns of the index
+func (idx Index) Columns() []string {
+    return idx.ColumnList
+}
+
+// PrimaryKey returns the index is primary key or not.
+func (idx Index) PrimaryKey() (isPrimaryKey bool, ok bool) {
+    return idx.PrimaryKeyValue.Bool, idx.PrimaryKeyValue.Valid
+}
+
+// Unique returns whether the index is unique or not.
+func (idx Index) Unique() (unique bool, ok bool) {
+    return idx.UniqueValue.Bool, idx.UniqueValue.Valid
+}
+
+// Option return the optional attribute of the index
+func (idx Index) Option() string {
+    return idx.OptionValue
+}
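For dialect authors, a hedged sketch of how these value types are typically filled in and returned through the gorm.Index interface; the column values are illustrative, not taken from any particular dialect.

package main

import (
    "database/sql"
    "fmt"

    "gorm.io/gorm"
    "gorm.io/gorm/migrator"
)

func main() {
    // A dialect's GetIndexes would build values like this from its
    // information-schema query and return them as []gorm.Index.
    idx := migrator.Index{
        TableName:       "users",
        NameValue:       "idx_users_name",
        ColumnList:      []string{"name"},
        PrimaryKeyValue: sql.NullBool{Bool: false, Valid: true},
        UniqueValue:     sql.NullBool{Bool: true, Valid: true},
    }

    var i gorm.Index = idx // value receivers, so the value satisfies gorm.Index
    unique, ok := i.Unique()
    fmt.Println(i.Table(), i.Name(), i.Columns(), unique, ok)
}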
@ -3,20 +3,32 @@ package migrator
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"regexp"
|
"regexp"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"gorm.io/gorm/clause"
|
"gorm.io/gorm/clause"
|
||||||
|
"gorm.io/gorm/logger"
|
||||||
"gorm.io/gorm/schema"
|
"gorm.io/gorm/schema"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
// This regular expression seeks to find a sequence of digits (\d+) among zero or more non-digit characters (\D*),
|
||||||
regRealDataType = regexp.MustCompile(`[^\d](\d+)[^\d]?`)
|
// with a possible trailing non-digit character (\D?).
|
||||||
regFullDataType = regexp.MustCompile(`[^\d]*(\d+)[^\d]?`)
|
|
||||||
)
|
// For example, values that can pass this regular expression are:
|
||||||
|
// - "123"
|
||||||
|
// - "abc456"
|
||||||
|
// -"%$#@789"
|
||||||
|
var regFullDataType = regexp.MustCompile(`\D*(\d+)\D?`)
|
||||||
|
|
||||||
|
// TODO:? Create const vars for raw sql queries ?
|
||||||
|
|
||||||
|
var _ gorm.Migrator = (*Migrator)(nil)
|
||||||
|
|
||||||
// Migrator m struct
|
// Migrator m struct
|
||||||
type Migrator struct {
|
type Migrator struct {
|
||||||
@ -30,6 +42,16 @@ type Config struct {
|
|||||||
gorm.Dialector
|
gorm.Dialector
|
||||||
}
|
}
|
||||||
+
+type printSQLLogger struct {
+    logger.Interface
+}
+
+func (l *printSQLLogger) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) {
+    sql, _ := fc()
+    fmt.Println(sql + ";")
+    l.Interface.Trace(ctx, begin, fc, err)
+}
+
// GormDataTypeInterface gorm data type interface
|
// GormDataTypeInterface gorm data type interface
|
||||||
type GormDataTypeInterface interface {
|
type GormDataTypeInterface interface {
|
||||||
GormDBDataType(*gorm.DB, *schema.Field) string
|
GormDBDataType(*gorm.DB, *schema.Field) string
|
||||||
@ -72,10 +94,6 @@ func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) {
|
|||||||
expr.SQL += " NOT NULL"
|
expr.SQL += " NOT NULL"
|
||||||
}
|
}
|
||||||
|
|
||||||
if field.Unique {
|
|
||||||
expr.SQL += " UNIQUE"
|
|
||||||
}
|
|
||||||
|
|
||||||
if field.HasDefaultValue && (field.DefaultValueInterface != nil || field.DefaultValue != "") {
|
if field.HasDefaultValue && (field.DefaultValueInterface != nil || field.DefaultValue != "") {
|
||||||
if field.DefaultValueInterface != nil {
|
if field.DefaultValueInterface != nil {
|
||||||
defaultStmt := &gorm.Statement{Vars: []interface{}{field.DefaultValueInterface}}
|
defaultStmt := &gorm.Statement{Vars: []interface{}{field.DefaultValueInterface}}
|
||||||
@ -89,23 +107,35 @@ func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
+
+func (m Migrator) GetQueryAndExecTx() (queryTx, execTx *gorm.DB) {
+    queryTx = m.DB.Session(&gorm.Session{})
+    execTx = queryTx
+    if m.DB.DryRun {
+        queryTx.DryRun = false
+        execTx = m.DB.Session(&gorm.Session{Logger: &printSQLLogger{Interface: m.DB.Logger}})
+    }
+    return queryTx, execTx
+}
+
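A hedged sketch of the behaviour GetQueryAndExecTx enables: in a DryRun session the schema is still inspected normally, while the DDL that would be executed is only printed through printSQLLogger. The driver and model are illustrative assumptions.

package main

import (
    "gorm.io/driver/sqlite"
    "gorm.io/gorm"
)

type User struct {
    ID   uint
    Name string
}

func main() {
    db, err := gorm.Open(sqlite.Open("example.db"), &gorm.Config{})
    if err != nil {
        panic(err)
    }

    // With DryRun enabled, AutoMigrate prints statements such as
    // "CREATE TABLE `users` (...);" instead of running them.
    if err := db.Session(&gorm.Session{DryRun: true}).AutoMigrate(&User{}); err != nil {
        panic(err)
    }
}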
// AutoMigrate auto migrate values
|
// AutoMigrate auto migrate values
|
||||||
func (m Migrator) AutoMigrate(values ...interface{}) error {
|
func (m Migrator) AutoMigrate(values ...interface{}) error {
|
||||||
for _, value := range m.ReorderModels(values, true) {
|
for _, value := range m.ReorderModels(values, true) {
|
||||||
tx := m.DB.Session(&gorm.Session{})
|
queryTx, execTx := m.GetQueryAndExecTx()
|
||||||
if !tx.Migrator().HasTable(value) {
|
if !queryTx.Migrator().HasTable(value) {
|
||||||
if err := tx.Migrator().CreateTable(value); err != nil {
|
if err := execTx.Migrator().CreateTable(value); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) {
|
if err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||||
columnTypes, err := m.DB.Migrator().ColumnTypes(value)
|
columnTypes, err := queryTx.Migrator().ColumnTypes(value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
var (
|
||||||
|
parseIndexes = stmt.Schema.ParseIndexes()
|
||||||
|
parseCheckConstraints = stmt.Schema.ParseCheckConstraints()
|
||||||
|
)
|
||||||
for _, dbName := range stmt.Schema.DBNames {
|
for _, dbName := range stmt.Schema.DBNames {
|
||||||
field := stmt.Schema.FieldsByDBName[dbName]
|
|
||||||
var foundColumn gorm.ColumnType
|
var foundColumn gorm.ColumnType
|
||||||
|
|
||||||
for _, columnType := range columnTypes {
|
for _, columnType := range columnTypes {
|
||||||
@ -117,37 +147,43 @@ func (m Migrator) AutoMigrate(values ...interface{}) error {
|
|||||||
|
|
||||||
if foundColumn == nil {
|
if foundColumn == nil {
|
||||||
// not found, add column
|
// not found, add column
|
||||||
if err := tx.Migrator().AddColumn(value, dbName); err != nil {
|
if err = execTx.Migrator().AddColumn(value, dbName); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// found, smartly migrate
|
||||||
|
field := stmt.Schema.FieldsByDBName[dbName]
|
||||||
|
if err = execTx.Migrator().MigrateColumn(value, field, foundColumn); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
} else if err := m.DB.Migrator().MigrateColumn(value, field, foundColumn); err != nil {
|
|
||||||
// found, smart migrate
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, rel := range stmt.Schema.Relationships.Relations {
|
if !m.DB.DisableForeignKeyConstraintWhenMigrating && !m.DB.IgnoreRelationshipsWhenMigrating {
|
||||||
if !m.DB.Config.DisableForeignKeyConstraintWhenMigrating {
|
for _, rel := range stmt.Schema.Relationships.Relations {
|
||||||
|
if rel.Field.IgnoreMigration {
|
||||||
|
continue
|
||||||
|
}
|
||||||
if constraint := rel.ParseConstraint(); constraint != nil &&
|
if constraint := rel.ParseConstraint(); constraint != nil &&
|
||||||
constraint.Schema == stmt.Schema && !tx.Migrator().HasConstraint(value, constraint.Name) {
|
constraint.Schema == stmt.Schema && !queryTx.Migrator().HasConstraint(value, constraint.Name) {
|
||||||
if err := tx.Migrator().CreateConstraint(value, constraint.Name); err != nil {
|
if err := execTx.Migrator().CreateConstraint(value, constraint.Name); err != nil {
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, chk := range stmt.Schema.ParseCheckConstraints() {
|
|
||||||
if !tx.Migrator().HasConstraint(value, chk.Name) {
|
|
||||||
if err := tx.Migrator().CreateConstraint(value, chk.Name); err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, idx := range stmt.Schema.ParseIndexes() {
|
for _, chk := range parseCheckConstraints {
|
||||||
if !tx.Migrator().HasIndex(value, idx.Name) {
|
if !queryTx.Migrator().HasConstraint(value, chk.Name) {
|
||||||
if err := tx.Migrator().CreateIndex(value, idx.Name); err != nil {
|
if err := execTx.Migrator().CreateConstraint(value, chk.Name); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, idx := range parseIndexes {
|
||||||
|
if !queryTx.Migrator().HasIndex(value, idx.Name) {
|
||||||
|
if err := execTx.Migrator().CreateIndex(value, idx.Name); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -174,7 +210,7 @@ func (m Migrator) GetTables() (tableList []string, err error) {
|
|||||||
func (m Migrator) CreateTable(values ...interface{}) error {
|
func (m Migrator) CreateTable(values ...interface{}) error {
|
||||||
for _, value := range m.ReorderModels(values, false) {
|
for _, value := range m.ReorderModels(values, false) {
|
||||||
tx := m.DB.Session(&gorm.Session{})
|
tx := m.DB.Session(&gorm.Session{})
|
||||||
if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) {
|
if err := m.RunWithValue(value, func(stmt *gorm.Statement) (err error) {
|
||||||
var (
|
var (
|
||||||
createTableSQL = "CREATE TABLE ? ("
|
createTableSQL = "CREATE TABLE ? ("
|
||||||
values = []interface{}{m.CurrentTable(stmt)}
|
values = []interface{}{m.CurrentTable(stmt)}
|
||||||
@ -185,7 +221,7 @@ func (m Migrator) CreateTable(values ...interface{}) error {
|
|||||||
field := stmt.Schema.FieldsByDBName[dbName]
|
field := stmt.Schema.FieldsByDBName[dbName]
|
||||||
if !field.IgnoreMigration {
|
if !field.IgnoreMigration {
|
||||||
createTableSQL += "? ?"
|
createTableSQL += "? ?"
|
||||||
hasPrimaryKeyInDataType = hasPrimaryKeyInDataType || strings.Contains(strings.ToUpper(string(field.DataType)), "PRIMARY KEY")
|
hasPrimaryKeyInDataType = hasPrimaryKeyInDataType || strings.Contains(strings.ToUpper(m.DataTypeOf(field)), "PRIMARY KEY")
|
||||||
values = append(values, clause.Column{Name: dbName}, m.DB.Migrator().FullDataTypeOf(field))
|
values = append(values, clause.Column{Name: dbName}, m.DB.Migrator().FullDataTypeOf(field))
|
||||||
createTableSQL += ","
|
createTableSQL += ","
|
||||||
}
|
}
|
||||||
@ -193,7 +229,7 @@ func (m Migrator) CreateTable(values ...interface{}) error {
|
|||||||
|
|
||||||
if !hasPrimaryKeyInDataType && len(stmt.Schema.PrimaryFields) > 0 {
|
if !hasPrimaryKeyInDataType && len(stmt.Schema.PrimaryFields) > 0 {
|
||||||
createTableSQL += "PRIMARY KEY ?,"
|
createTableSQL += "PRIMARY KEY ?,"
|
||||||
primaryKeys := []interface{}{}
|
primaryKeys := make([]interface{}, 0, len(stmt.Schema.PrimaryFields))
|
||||||
for _, field := range stmt.Schema.PrimaryFields {
|
for _, field := range stmt.Schema.PrimaryFields {
|
||||||
primaryKeys = append(primaryKeys, clause.Column{Name: field.DBName})
|
primaryKeys = append(primaryKeys, clause.Column{Name: field.DBName})
|
||||||
}
|
}
|
||||||
@ -204,8 +240,8 @@ func (m Migrator) CreateTable(values ...interface{}) error {
|
|||||||
for _, idx := range stmt.Schema.ParseIndexes() {
|
for _, idx := range stmt.Schema.ParseIndexes() {
|
||||||
if m.CreateIndexAfterCreateTable {
|
if m.CreateIndexAfterCreateTable {
|
||||||
defer func(value interface{}, name string) {
|
defer func(value interface{}, name string) {
|
||||||
if errr == nil {
|
if err == nil {
|
||||||
errr = tx.Migrator().CreateIndex(value, name)
|
err = tx.Migrator().CreateIndex(value, name)
|
||||||
}
|
}
|
||||||
}(value, idx.Name)
|
}(value, idx.Name)
|
||||||
} else {
|
} else {
|
||||||
@ -223,15 +259,18 @@ func (m Migrator) CreateTable(values ...interface{}) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
createTableSQL += ","
|
createTableSQL += ","
|
||||||
values = append(values, clause.Expr{SQL: idx.Name}, tx.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt))
|
values = append(values, clause.Column{Name: idx.Name}, tx.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, rel := range stmt.Schema.Relationships.Relations {
|
if !m.DB.DisableForeignKeyConstraintWhenMigrating && !m.DB.IgnoreRelationshipsWhenMigrating {
|
||||||
if !m.DB.DisableForeignKeyConstraintWhenMigrating {
|
for _, rel := range stmt.Schema.Relationships.Relations {
|
||||||
|
if rel.Field.IgnoreMigration {
|
||||||
|
continue
|
||||||
|
}
|
||||||
if constraint := rel.ParseConstraint(); constraint != nil {
|
if constraint := rel.ParseConstraint(); constraint != nil {
|
||||||
if constraint.Schema == stmt.Schema {
|
if constraint.Schema == stmt.Schema {
|
||||||
sql, vars := buildConstraint(constraint)
|
sql, vars := constraint.Build()
|
||||||
createTableSQL += sql + ","
|
createTableSQL += sql + ","
|
||||||
values = append(values, vars...)
|
values = append(values, vars...)
|
||||||
}
|
}
|
||||||
@ -239,6 +278,11 @@ func (m Migrator) CreateTable(values ...interface{}) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
+
+            for _, uni := range stmt.Schema.ParseUniqueConstraints() {
+                createTableSQL += "CONSTRAINT ? UNIQUE (?),"
+                values = append(values, clause.Column{Name: uni.Name}, clause.Expr{SQL: stmt.Quote(uni.Field.DBName)})
+            }
+
for _, chk := range stmt.Schema.ParseCheckConstraints() {
|
for _, chk := range stmt.Schema.ParseCheckConstraints() {
|
||||||
createTableSQL += "CONSTRAINT ? CHECK (?),"
|
createTableSQL += "CONSTRAINT ? CHECK (?),"
|
||||||
values = append(values, clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint})
|
values = append(values, clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint})
|
||||||
@ -252,8 +296,8 @@ func (m Migrator) CreateTable(values ...interface{}) error {
|
|||||||
createTableSQL += fmt.Sprint(tableOption)
|
createTableSQL += fmt.Sprint(tableOption)
|
||||||
}
|
}
|
||||||
|
|
||||||
errr = tx.Exec(createTableSQL, values...).Error
|
err = tx.Exec(createTableSQL, values...).Error
|
||||||
return errr
|
return err
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -402,32 +446,58 @@ func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error
|
|||||||
|
|
||||||
// MigrateColumn migrate column
|
// MigrateColumn migrate column
|
||||||
func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnType gorm.ColumnType) error {
|
func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnType gorm.ColumnType) error {
|
||||||
|
if field.IgnoreMigration {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// found, smart migrate
|
// found, smart migrate
|
||||||
fullDataType := strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL)
|
fullDataType := strings.TrimSpace(strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL))
|
||||||
realDataType := strings.ToLower(columnType.DatabaseTypeName())
|
realDataType := strings.ToLower(columnType.DatabaseTypeName())
|
||||||
|
|
||||||
alterColumn := false
|
var (
|
||||||
|
alterColumn bool
|
||||||
|
isSameType = fullDataType == realDataType
|
||||||
|
)
|
||||||
|
|
||||||
// check size
|
if !field.PrimaryKey {
|
||||||
if length, ok := columnType.Length(); length != int64(field.Size) {
|
// check type
|
||||||
if length > 0 && field.Size > 0 {
|
if !strings.HasPrefix(fullDataType, realDataType) {
|
||||||
alterColumn = true
|
// check type aliases
|
||||||
} else {
|
aliases := m.DB.Migrator().GetTypeAliases(realDataType)
|
||||||
// has size in data type and not equal
|
for _, alias := range aliases {
|
||||||
// Since the following code is frequently called in the for loop, reg optimization is needed here
|
if strings.HasPrefix(fullDataType, alias) {
|
||||||
matches := regRealDataType.FindAllStringSubmatch(realDataType, -1)
|
isSameType = true
|
||||||
matches2 := regFullDataType.FindAllStringSubmatch(fullDataType, -1)
|
break
|
||||||
if (len(matches) == 1 && matches[0][1] != fmt.Sprint(field.Size) || !field.PrimaryKey) &&
|
}
|
||||||
(len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length) && ok) {
|
}
|
||||||
|
|
||||||
|
if !isSameType {
|
||||||
alterColumn = true
|
alterColumn = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// check precision
|
if !isSameType {
|
||||||
if precision, _, ok := columnType.DecimalSize(); ok && int64(field.Precision) != precision {
|
// check size
|
||||||
if regexp.MustCompile(fmt.Sprintf("[^0-9]%d[^0-9]", field.Precision)).MatchString(m.DataTypeOf(field)) {
|
if length, ok := columnType.Length(); length != int64(field.Size) {
|
||||||
alterColumn = true
|
if length > 0 && field.Size > 0 {
|
||||||
|
alterColumn = true
|
||||||
|
} else {
|
||||||
|
// has size in data type and not equal
|
||||||
|
// Since the following code is frequently called in the for loop, reg optimization is needed here
|
||||||
|
matches2 := regFullDataType.FindAllStringSubmatch(fullDataType, -1)
|
||||||
|
if !field.PrimaryKey &&
|
||||||
|
(len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length) && ok) {
|
||||||
|
alterColumn = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// check precision
|
||||||
|
if precision, _, ok := columnType.DecimalSize(); ok && int64(field.Precision) != precision {
|
||||||
|
if regexp.MustCompile(fmt.Sprintf("[^0-9]%d[^0-9]", field.Precision)).MatchString(m.DataTypeOf(field)) {
|
||||||
|
alterColumn = true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -439,19 +509,29 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// check unique
|
|
||||||
if unique, ok := columnType.Unique(); ok && unique != field.Unique {
|
|
||||||
// not primary key
|
|
||||||
if !field.PrimaryKey {
|
|
||||||
alterColumn = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// check default value
|
// check default value
|
||||||
if v, ok := columnType.DefaultValue(); ok && v != field.DefaultValue {
|
if !field.PrimaryKey {
|
||||||
// not primary key
|
currentDefaultNotNull := field.HasDefaultValue && (field.DefaultValueInterface != nil || !strings.EqualFold(field.DefaultValue, "NULL"))
|
||||||
if !field.PrimaryKey {
|
dv, dvNotNull := columnType.DefaultValue()
|
||||||
|
if dvNotNull && !currentDefaultNotNull {
|
||||||
|
// default value -> null
|
||||||
alterColumn = true
|
alterColumn = true
|
||||||
|
} else if !dvNotNull && currentDefaultNotNull {
|
||||||
|
// null -> default value
|
||||||
|
alterColumn = true
|
||||||
|
} else if currentDefaultNotNull || dvNotNull {
|
||||||
|
switch field.GORMDataType {
|
||||||
|
case schema.Time:
|
||||||
|
if !strings.EqualFold(strings.TrimSuffix(dv, "()"), strings.TrimSuffix(field.DefaultValue, "()")) {
|
||||||
|
alterColumn = true
|
||||||
|
}
|
||||||
|
case schema.Bool:
|
||||||
|
v1, _ := strconv.ParseBool(dv)
|
||||||
|
v2, _ := strconv.ParseBool(field.DefaultValue)
|
||||||
|
alterColumn = v1 != v2
|
||||||
|
default:
|
||||||
|
alterColumn = dv != field.DefaultValue
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -463,13 +543,39 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if alterColumn && !field.IgnoreMigration {
|
if alterColumn {
|
||||||
return m.DB.Migrator().AlterColumn(value, field.Name)
|
if err := m.DB.Migrator().AlterColumn(value, field.DBName); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.DB.Migrator().MigrateColumnUnique(value, field, columnType); err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
+
+func (m Migrator) MigrateColumnUnique(value interface{}, field *schema.Field, columnType gorm.ColumnType) error {
+    unique, ok := columnType.Unique()
+    if !ok || field.PrimaryKey {
+        return nil // skip primary key
+    }
+    // By default, ColumnType's Unique is not affected by UniqueIndex, so we don't care about UniqueIndex.
+    return m.RunWithValue(value, func(stmt *gorm.Statement) error {
+        // We're currently only receiving boolean values on `Unique` tag,
+        // so the UniqueConstraint name is fixed
+        constraint := m.DB.NamingStrategy.UniqueName(stmt.Table, field.DBName)
+        if unique && !field.Unique {
+            return m.DB.Migrator().DropConstraint(value, constraint)
+        }
+        if !unique && field.Unique {
+            return m.DB.Migrator().CreateConstraint(value, constraint)
+        }
+        return nil
+    })
+}
+
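A hedged sketch of when MigrateColumnUnique comes into play: adding or removing a `unique` tag on an existing column makes AutoMigrate create or drop the named unique constraint instead of rewriting the column. It assumes a dialect that can add and drop unique constraints (for example Postgres); the DSN, model and generated constraint name (NamingStrategy.UniqueName, e.g. uni_users_name) are illustrative.

package main

import (
    "gorm.io/driver/postgres"
    "gorm.io/gorm"
)

// Before: Name string
// After:  Name string `gorm:"unique"`
type User struct {
    ID   uint
    Name string `gorm:"unique"`
}

func main() {
    db, err := gorm.Open(postgres.Open("host=localhost user=gorm dbname=gorm"), &gorm.Config{})
    if err != nil {
        panic(err)
    }

    // MigrateColumn delegates to MigrateColumnUnique, which compares the
    // column's reported uniqueness with the field's `unique` tag and issues
    // CreateConstraint / DropConstraint for the generated constraint name.
    if err := db.AutoMigrate(&User{}); err != nil {
        panic(err)
    }
}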
// ColumnTypes return columnTypes []gorm.ColumnType and execErr error
|
// ColumnTypes return columnTypes []gorm.ColumnType and execErr error
|
||||||
func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
|
func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
|
||||||
columnTypes := make([]gorm.ColumnType, 0)
|
columnTypes := make([]gorm.ColumnType, 0)
|
||||||
@ -499,47 +605,76 @@ func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
|
|||||||
return columnTypes, execErr
|
return columnTypes, execErr
|
||||||
}
|
}
|
||||||
|
|
||||||
-// CreateView create view
+// CreateView create view from Query in gorm.ViewOption.
+// Query in gorm.ViewOption is a [subquery]
+//
+//    // CREATE VIEW `user_view` AS SELECT * FROM `users` WHERE age > 20
+//    q := DB.Model(&User{}).Where("age > ?", 20)
+//    DB.Debug().Migrator().CreateView("user_view", gorm.ViewOption{Query: q})
+//
+//    // CREATE OR REPLACE VIEW `users_view` AS SELECT * FROM `users` WITH CHECK OPTION
+//    q := DB.Model(&User{})
+//    DB.Debug().Migrator().CreateView("user_view", gorm.ViewOption{Query: q, Replace: true, CheckOption: "WITH CHECK OPTION"})
+//
+// [subquery]: https://gorm.io/docs/advanced_query.html#SubQuery
 func (m Migrator) CreateView(name string, option gorm.ViewOption) error {
-    return gorm.ErrNotImplemented
+    if option.Query == nil {
+        return gorm.ErrSubQueryRequired
+    }
+
+    sql := new(strings.Builder)
+    sql.WriteString("CREATE ")
+    if option.Replace {
+        sql.WriteString("OR REPLACE ")
+    }
+    sql.WriteString("VIEW ")
+    m.QuoteTo(sql, name)
+    sql.WriteString(" AS ")
+
+    m.DB.Statement.AddVar(sql, option.Query)
+
+    if option.CheckOption != "" {
+        sql.WriteString(" ")
+        sql.WriteString(option.CheckOption)
+    }
+    return m.DB.Exec(m.Explain(sql.String(), m.DB.Statement.Vars...)).Error
 }
|
|
||||||
 // DropView drop view
 func (m Migrator) DropView(name string) error {
-    return gorm.ErrNotImplemented
-}
-
-func buildConstraint(constraint *schema.Constraint) (sql string, results []interface{}) {
-    sql = "CONSTRAINT ? FOREIGN KEY ? REFERENCES ??"
-    if constraint.OnDelete != "" {
-        sql += " ON DELETE " + constraint.OnDelete
-    }
-
-    if constraint.OnUpdate != "" {
-        sql += " ON UPDATE " + constraint.OnUpdate
-    }
-
-    var foreignKeys, references []interface{}
-    for _, field := range constraint.ForeignKeys {
-        foreignKeys = append(foreignKeys, clause.Column{Name: field.DBName})
-    }
-
-    for _, field := range constraint.References {
-        references = append(references, clause.Column{Name: field.DBName})
-    }
-    results = append(results, clause.Table{Name: constraint.Name}, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references)
-    return
+    return m.DB.Exec("DROP VIEW IF EXISTS ?", clause.Table{Name: name}).Error
 }

// GuessConstraintAndTable guess statement's constraint and it's table based on name
|
// GuessConstraintAndTable guess statement's constraint and it's table based on name
|
||||||
func (m Migrator) GuessConstraintAndTable(stmt *gorm.Statement, name string) (_ *schema.Constraint, _ *schema.Check, table string) {
|
//
|
||||||
|
// Deprecated: use GuessConstraintInterfaceAndTable instead.
|
||||||
|
func (m Migrator) GuessConstraintAndTable(stmt *gorm.Statement, name string) (*schema.Constraint, *schema.CheckConstraint, string) {
|
||||||
|
constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name)
|
||||||
|
switch c := constraint.(type) {
|
||||||
|
case *schema.Constraint:
|
||||||
|
return c, nil, table
|
||||||
|
case *schema.CheckConstraint:
|
||||||
|
return nil, c, table
|
||||||
|
default:
|
||||||
|
return nil, nil, table
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GuessConstraintInterfaceAndTable guess statement's constraint and it's table based on name
|
||||||
|
// nolint:cyclop
|
||||||
|
func (m Migrator) GuessConstraintInterfaceAndTable(stmt *gorm.Statement, name string) (_ schema.ConstraintInterface, table string) {
|
||||||
if stmt.Schema == nil {
|
if stmt.Schema == nil {
|
||||||
return nil, nil, stmt.Table
|
return nil, stmt.Table
|
||||||
}
|
}
|
||||||
|
|
||||||
checkConstraints := stmt.Schema.ParseCheckConstraints()
|
checkConstraints := stmt.Schema.ParseCheckConstraints()
|
||||||
if chk, ok := checkConstraints[name]; ok {
|
if chk, ok := checkConstraints[name]; ok {
|
||||||
return nil, &chk, stmt.Table
|
return &chk, stmt.Table
|
||||||
|
}
|
||||||
|
|
||||||
|
uniqueConstraints := stmt.Schema.ParseUniqueConstraints()
|
||||||
|
if uni, ok := uniqueConstraints[name]; ok {
|
||||||
|
return &uni, stmt.Table
|
||||||
}
|
}
|
||||||
|
|
||||||
getTable := func(rel *schema.Relationship) string {
|
getTable := func(rel *schema.Relationship) string {
|
||||||
@ -554,7 +689,7 @@ func (m Migrator) GuessConstraintAndTable(stmt *gorm.Statement, name string) (_
|
|||||||
|
|
||||||
for _, rel := range stmt.Schema.Relationships.Relations {
|
for _, rel := range stmt.Schema.Relationships.Relations {
|
||||||
if constraint := rel.ParseConstraint(); constraint != nil && constraint.Name == name {
|
if constraint := rel.ParseConstraint(); constraint != nil && constraint.Name == name {
|
||||||
return constraint, nil, getTable(rel)
|
return constraint, getTable(rel)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -562,40 +697,39 @@ func (m Migrator) GuessConstraintAndTable(stmt *gorm.Statement, name string) (_
|
|||||||
for k := range checkConstraints {
|
for k := range checkConstraints {
|
||||||
if checkConstraints[k].Field == field {
|
if checkConstraints[k].Field == field {
|
||||||
v := checkConstraints[k]
|
v := checkConstraints[k]
|
||||||
return nil, &v, stmt.Table
|
return &v, stmt.Table
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for k := range uniqueConstraints {
|
||||||
|
if uniqueConstraints[k].Field == field {
|
||||||
|
v := uniqueConstraints[k]
|
||||||
|
return &v, stmt.Table
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, rel := range stmt.Schema.Relationships.Relations {
|
for _, rel := range stmt.Schema.Relationships.Relations {
|
||||||
if constraint := rel.ParseConstraint(); constraint != nil && rel.Field == field {
|
if constraint := rel.ParseConstraint(); constraint != nil && rel.Field == field {
|
||||||
return constraint, nil, getTable(rel)
|
return constraint, getTable(rel)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, nil, stmt.Schema.Table
|
return nil, stmt.Schema.Table
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateConstraint create constraint
|
// CreateConstraint create constraint
|
||||||
func (m Migrator) CreateConstraint(value interface{}, name string) error {
|
func (m Migrator) CreateConstraint(value interface{}, name string) error {
|
||||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||||
constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
|
constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name)
|
||||||
if chk != nil {
|
|
||||||
return m.DB.Exec(
|
|
||||||
"ALTER TABLE ? ADD CONSTRAINT ? CHECK (?)",
|
|
||||||
m.CurrentTable(stmt), clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint},
|
|
||||||
).Error
|
|
||||||
}
|
|
||||||
|
|
||||||
if constraint != nil {
|
if constraint != nil {
|
||||||
vars := []interface{}{clause.Table{Name: table}}
|
vars := []interface{}{clause.Table{Name: table}}
|
||||||
if stmt.TableExpr != nil {
|
if stmt.TableExpr != nil {
|
||||||
vars[0] = stmt.TableExpr
|
vars[0] = stmt.TableExpr
|
||||||
}
|
}
|
||||||
sql, values := buildConstraint(constraint)
|
sql, values := constraint.Build()
|
||||||
return m.DB.Exec("ALTER TABLE ? ADD "+sql, append(vars, values...)...).Error
|
return m.DB.Exec("ALTER TABLE ? ADD "+sql, append(vars, values...)...).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -603,11 +737,9 @@ func (m Migrator) CreateConstraint(value interface{}, name string) error {
|
|||||||
// DropConstraint drop constraint
|
// DropConstraint drop constraint
|
||||||
func (m Migrator) DropConstraint(value interface{}, name string) error {
|
func (m Migrator) DropConstraint(value interface{}, name string) error {
|
||||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||||
constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
|
constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name)
|
||||||
if constraint != nil {
|
if constraint != nil {
|
||||||
name = constraint.Name
|
name = constraint.GetName()
|
||||||
} else if chk != nil {
|
|
||||||
name = chk.Name
|
|
||||||
}
|
}
|
||||||
return m.DB.Exec("ALTER TABLE ? DROP CONSTRAINT ?", clause.Table{Name: table}, clause.Column{Name: name}).Error
|
return m.DB.Exec("ALTER TABLE ? DROP CONSTRAINT ?", clause.Table{Name: table}, clause.Column{Name: name}).Error
|
||||||
})
|
})
|
||||||
@ -618,11 +750,9 @@ func (m Migrator) HasConstraint(value interface{}, name string) bool {
|
|||||||
var count int64
|
var count int64
|
||||||
m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||||
currentDatabase := m.DB.Migrator().CurrentDatabase()
|
currentDatabase := m.DB.Migrator().CurrentDatabase()
|
||||||
constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
|
constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name)
|
||||||
if constraint != nil {
|
if constraint != nil {
|
||||||
name = constraint.Name
|
name = constraint.GetName()
|
||||||
} else if chk != nil {
|
|
||||||
name = chk.Name
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return m.DB.Raw(
|
return m.DB.Raw(
|
||||||
@ -759,7 +889,8 @@ func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []i
|
|||||||
Statement: &gorm.Statement{DB: m.DB, Dest: value},
|
Statement: &gorm.Statement{DB: m.DB, Dest: value},
|
||||||
}
|
}
|
||||||
beDependedOn := map[*schema.Schema]bool{}
|
beDependedOn := map[*schema.Schema]bool{}
|
||||||
if err := dep.Parse(value); err != nil {
|
// support for special table name
|
||||||
|
if err := dep.ParseWithSpecialTableName(value, m.DB.Statement.Table); err != nil {
|
||||||
m.DB.Logger.Error(context.Background(), "failed to parse value %#v, got error %v", value, err)
|
m.DB.Logger.Error(context.Background(), "failed to parse value %#v, got error %v", value, err)
|
||||||
}
|
}
|
||||||
if _, ok := parsedSchemas[dep.Statement.Schema]; ok {
|
if _, ok := parsedSchemas[dep.Statement.Schema]; ok {
|
||||||
@ -767,26 +898,31 @@ func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []i
|
|||||||
}
|
}
|
||||||
parsedSchemas[dep.Statement.Schema] = true
|
parsedSchemas[dep.Statement.Schema] = true
|
||||||
|
|
||||||
for _, rel := range dep.Schema.Relationships.Relations {
|
if !m.DB.IgnoreRelationshipsWhenMigrating {
|
||||||
if c := rel.ParseConstraint(); c != nil && c.Schema == dep.Statement.Schema && c.Schema != c.ReferenceSchema {
|
for _, rel := range dep.Schema.Relationships.Relations {
|
||||||
dep.Depends = append(dep.Depends, c.ReferenceSchema)
|
if rel.Field.IgnoreMigration {
|
||||||
}
|
continue
|
||||||
|
}
|
||||||
|
if c := rel.ParseConstraint(); c != nil && c.Schema == dep.Statement.Schema && c.Schema != c.ReferenceSchema {
|
||||||
|
dep.Depends = append(dep.Depends, c.ReferenceSchema)
|
||||||
|
}
|
||||||
|
|
||||||
if rel.Type == schema.HasOne || rel.Type == schema.HasMany {
|
if rel.Type == schema.HasOne || rel.Type == schema.HasMany {
|
||||||
beDependedOn[rel.FieldSchema] = true
|
beDependedOn[rel.FieldSchema] = true
|
||||||
}
|
}
|
||||||
|
|
||||||
if rel.JoinTable != nil {
|
if rel.JoinTable != nil {
|
||||||
// append join value
|
// append join value
|
||||||
defer func(rel *schema.Relationship, joinValue interface{}) {
|
defer func(rel *schema.Relationship, joinValue interface{}) {
|
||||||
if !beDependedOn[rel.FieldSchema] {
|
if !beDependedOn[rel.FieldSchema] {
|
||||||
dep.Depends = append(dep.Depends, rel.FieldSchema)
|
dep.Depends = append(dep.Depends, rel.FieldSchema)
|
||||||
} else {
|
} else {
|
||||||
fieldValue := reflect.New(rel.FieldSchema.ModelType).Interface()
|
fieldValue := reflect.New(rel.FieldSchema.ModelType).Interface()
|
||||||
parseDependence(fieldValue, autoAdd)
|
parseDependence(fieldValue, autoAdd)
|
||||||
}
|
}
|
||||||
parseDependence(joinValue, autoAdd)
|
parseDependence(joinValue, autoAdd)
|
||||||
}(rel, reflect.New(rel.JoinTable.ModelType).Interface())
|
}(rel, reflect.New(rel.JoinTable.ModelType).Interface())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -843,3 +979,18 @@ func (m Migrator) CurrentTable(stmt *gorm.Statement) interface{} {
|
|||||||
}
|
}
|
||||||
return clause.Table{Name: stmt.Table}
|
return clause.Table{Name: stmt.Table}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetIndexes return Indexes []gorm.Index and execErr error
|
||||||
|
func (m Migrator) GetIndexes(dst interface{}) ([]gorm.Index, error) {
|
||||||
|
return nil, errors.New("not support")
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetTypeAliases return database type aliases
|
||||||
|
func (m Migrator) GetTypeAliases(databaseTypeName string) []string {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TableType return tableType gorm.TableType and execErr error
|
||||||
|
func (m Migrator) TableType(dst interface{}) (gorm.TableType, error) {
|
||||||
|
return nil, errors.New("not support")
|
||||||
|
}
|
||||||
|
|||||||
33
migrator/table_type.go
Normal file
33
migrator/table_type.go
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
package migrator
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TableType table type implements TableType interface
|
||||||
|
type TableType struct {
|
||||||
|
SchemaValue string
|
||||||
|
NameValue string
|
||||||
|
TypeValue string
|
||||||
|
CommentValue sql.NullString
|
||||||
|
}
|
||||||
|
|
||||||
|
// Schema returns the schema of the table.
|
||||||
|
func (ct TableType) Schema() string {
|
||||||
|
return ct.SchemaValue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Name returns the name of the table.
|
||||||
|
func (ct TableType) Name() string {
|
||||||
|
return ct.NameValue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Type returns the type of the table.
|
||||||
|
func (ct TableType) Type() string {
|
||||||
|
return ct.TypeValue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Comment returns the comment of current table.
|
||||||
|
func (ct TableType) Comment() (comment string, ok bool) {
|
||||||
|
return ct.CommentValue.String, ct.CommentValue.Valid
|
||||||
|
}
|
||||||
7  model.go

@@ -4,9 +4,10 @@ import "time"

 // Model a basic GoLang struct which includes the following fields: ID, CreatedAt, UpdatedAt, DeletedAt
 // It may be embedded into your model or you may build your own model without it
-// type User struct {
-//   gorm.Model
-// }
+//
+//    type User struct {
+//      gorm.Model
+//    }
 type Model struct {
     ID        uint `gorm:"primarykey"`
     CreatedAt time.Time
128  prepare_stmt.go
@@ -3,30 +3,44 @@ package gorm
 import (
     "context"
     "database/sql"
+    "database/sql/driver"
+    "errors"
+    "reflect"
     "sync"
 )

 type Stmt struct {
     *sql.Stmt
     Transaction bool
+    prepared    chan struct{}
+    prepareErr  error
 }

 type PreparedStmtDB struct {
-    Stmts       map[string]Stmt
+    Stmts       map[string]*Stmt
     PreparedSQL []string
     Mux         *sync.RWMutex
     ConnPool
 }

-func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) {
-    if dbConnector, ok := db.ConnPool.(GetDBConnector); ok && dbConnector != nil {
-        return dbConnector.GetDBConn()
+func NewPreparedStmtDB(connPool ConnPool) *PreparedStmtDB {
+    return &PreparedStmtDB{
+        ConnPool:    connPool,
+        Stmts:       make(map[string]*Stmt),
+        Mux:         &sync.RWMutex{},
+        PreparedSQL: make([]string, 0, 100),
     }
+}
+
+func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) {
     if sqldb, ok := db.ConnPool.(*sql.DB); ok {
         return sqldb, nil
     }
+
+    if dbConnector, ok := db.ConnPool.(GetDBConnector); ok && dbConnector != nil {
+        return dbConnector.GetDBConn()
+    }
+
     return nil, ErrInvalidDB
 }

@ -42,31 +56,72 @@ func (db *PreparedStmtDB) Close() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (sdb *PreparedStmtDB) Reset() {
|
||||||
|
sdb.Mux.Lock()
|
||||||
|
defer sdb.Mux.Unlock()
|
||||||
|
|
||||||
|
for _, stmt := range sdb.Stmts {
|
||||||
|
go stmt.Close()
|
||||||
|
}
|
||||||
|
sdb.PreparedSQL = make([]string, 0, 100)
|
||||||
|
sdb.Stmts = make(map[string]*Stmt)
|
||||||
|
}
|
||||||
|
|
||||||
func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) {
|
func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) {
|
||||||
db.Mux.RLock()
|
db.Mux.RLock()
|
||||||
if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) {
|
if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) {
|
||||||
db.Mux.RUnlock()
|
db.Mux.RUnlock()
|
||||||
return stmt, nil
|
// wait for other goroutines prepared
|
||||||
|
<-stmt.prepared
|
||||||
|
if stmt.prepareErr != nil {
|
||||||
|
return Stmt{}, stmt.prepareErr
|
||||||
|
}
|
||||||
|
|
||||||
|
return *stmt, nil
|
||||||
}
|
}
|
||||||
db.Mux.RUnlock()
|
db.Mux.RUnlock()
|
||||||
|
|
||||||
db.Mux.Lock()
|
db.Mux.Lock()
|
||||||
defer db.Mux.Unlock()
|
|
||||||
|
|
||||||
// double check
|
// double check
|
||||||
if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) {
|
if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) {
|
||||||
return stmt, nil
|
db.Mux.Unlock()
|
||||||
} else if ok {
|
// wait for other goroutines prepared
|
||||||
go stmt.Close()
|
<-stmt.prepared
|
||||||
|
if stmt.prepareErr != nil {
|
||||||
|
return Stmt{}, stmt.prepareErr
|
||||||
|
}
|
||||||
|
|
||||||
|
return *stmt, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// cache preparing stmt first
|
||||||
|
cacheStmt := Stmt{Transaction: isTransaction, prepared: make(chan struct{})}
|
||||||
|
db.Stmts[query] = &cacheStmt
|
||||||
|
db.Mux.Unlock()
|
||||||
|
|
||||||
|
// prepare completed
|
||||||
|
defer close(cacheStmt.prepared)
|
||||||
|
|
||||||
|
// Reason why cannot lock conn.PrepareContext
|
||||||
|
// suppose the maxopen is 1, g1 is creating record and g2 is querying record.
|
||||||
|
// 1. g1 begin tx, g1 is requeue because of waiting for the system call, now `db.ConnPool` db.numOpen == 1.
|
||||||
|
// 2. g2 select lock `conn.PrepareContext(ctx, query)`, now db.numOpen == db.maxOpen , wait for release.
|
||||||
|
// 3. g1 tx exec insert, wait for unlock `conn.PrepareContext(ctx, query)` to finish tx and release.
|
||||||
stmt, err := conn.PrepareContext(ctx, query)
|
stmt, err := conn.PrepareContext(ctx, query)
|
||||||
if err == nil {
|
if err != nil {
|
||||||
db.Stmts[query] = Stmt{Stmt: stmt, Transaction: isTransaction}
|
cacheStmt.prepareErr = err
|
||||||
db.PreparedSQL = append(db.PreparedSQL, query)
|
db.Mux.Lock()
|
||||||
|
delete(db.Stmts, query)
|
||||||
|
db.Mux.Unlock()
|
||||||
|
return Stmt{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return db.Stmts[query], err
|
db.Mux.Lock()
|
||||||
|
cacheStmt.Stmt = stmt
|
||||||
|
db.PreparedSQL = append(db.PreparedSQL, query)
|
||||||
|
db.Mux.Unlock()
|
||||||
|
|
||||||
|
return cacheStmt, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (ConnPool, error) {
|
func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (ConnPool, error) {
|
||||||
@ -74,6 +129,19 @@ func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (Conn
|
|||||||
tx, err := beginner.BeginTx(ctx, opt)
|
tx, err := beginner.BeginTx(ctx, opt)
|
||||||
return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err
|
return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
beginner, ok := db.ConnPool.(ConnPoolBeginner)
|
||||||
|
if !ok {
|
||||||
|
return nil, ErrInvalidTransaction
|
||||||
|
}
|
||||||
|
|
||||||
|
connPool, err := beginner.BeginTx(ctx, opt)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if tx, ok := connPool.(Tx); ok {
|
||||||
|
return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, nil
|
||||||
|
}
|
||||||
return nil, ErrInvalidTransaction
|
return nil, ErrInvalidTransaction
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -81,7 +149,7 @@ func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args ..
|
|||||||
stmt, err := db.prepare(ctx, db.ConnPool, false, query)
|
stmt, err := db.prepare(ctx, db.ConnPool, false, query)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
result, err = stmt.ExecContext(ctx, args...)
|
result, err = stmt.ExecContext(ctx, args...)
|
||||||
if err != nil {
|
if errors.Is(err, driver.ErrBadConn) {
|
||||||
db.Mux.Lock()
|
db.Mux.Lock()
|
||||||
defer db.Mux.Unlock()
|
defer db.Mux.Unlock()
|
||||||
go stmt.Close()
|
go stmt.Close()
|
||||||
@ -95,7 +163,7 @@ func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args .
|
|||||||
stmt, err := db.prepare(ctx, db.ConnPool, false, query)
|
stmt, err := db.prepare(ctx, db.ConnPool, false, query)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
rows, err = stmt.QueryContext(ctx, args...)
|
rows, err = stmt.QueryContext(ctx, args...)
|
||||||
if err != nil {
|
if errors.Is(err, driver.ErrBadConn) {
|
||||||
db.Mux.Lock()
|
db.Mux.Lock()
|
||||||
defer db.Mux.Unlock()
|
defer db.Mux.Unlock()
|
||||||
|
|
||||||
@ -114,20 +182,32 @@ func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string, arg
|
|||||||
return &sql.Row{}
|
return &sql.Row{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (db *PreparedStmtDB) Ping() error {
|
||||||
|
conn, err := db.GetDBConn()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return conn.Ping()
|
||||||
|
}
|
||||||
|
|
||||||
type PreparedStmtTX struct {
|
type PreparedStmtTX struct {
|
||||||
Tx
|
Tx
|
||||||
PreparedStmtDB *PreparedStmtDB
|
PreparedStmtDB *PreparedStmtDB
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (db *PreparedStmtTX) GetDBConn() (*sql.DB, error) {
|
||||||
|
return db.PreparedStmtDB.GetDBConn()
|
||||||
|
}
|
||||||
|
|
||||||
func (tx *PreparedStmtTX) Commit() error {
|
func (tx *PreparedStmtTX) Commit() error {
|
||||||
if tx.Tx != nil {
|
if tx.Tx != nil && !reflect.ValueOf(tx.Tx).IsNil() {
|
||||||
return tx.Tx.Commit()
|
return tx.Tx.Commit()
|
||||||
}
|
}
|
||||||
return ErrInvalidTransaction
|
return ErrInvalidTransaction
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tx *PreparedStmtTX) Rollback() error {
|
func (tx *PreparedStmtTX) Rollback() error {
|
||||||
if tx.Tx != nil {
|
if tx.Tx != nil && !reflect.ValueOf(tx.Tx).IsNil() {
|
||||||
return tx.Tx.Rollback()
|
return tx.Tx.Rollback()
|
||||||
}
|
}
|
||||||
return ErrInvalidTransaction
|
return ErrInvalidTransaction
|
||||||
@ -137,7 +217,7 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ..
|
|||||||
stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
|
stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
result, err = tx.Tx.StmtContext(ctx, stmt.Stmt).ExecContext(ctx, args...)
|
result, err = tx.Tx.StmtContext(ctx, stmt.Stmt).ExecContext(ctx, args...)
|
||||||
if err != nil {
|
if errors.Is(err, driver.ErrBadConn) {
|
||||||
tx.PreparedStmtDB.Mux.Lock()
|
tx.PreparedStmtDB.Mux.Lock()
|
||||||
defer tx.PreparedStmtDB.Mux.Unlock()
|
defer tx.PreparedStmtDB.Mux.Unlock()
|
||||||
|
|
||||||
@ -152,7 +232,7 @@ func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args .
|
|||||||
stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
|
stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
rows, err = tx.Tx.StmtContext(ctx, stmt.Stmt).QueryContext(ctx, args...)
|
rows, err = tx.Tx.StmtContext(ctx, stmt.Stmt).QueryContext(ctx, args...)
|
||||||
if err != nil {
|
if errors.Is(err, driver.ErrBadConn) {
|
||||||
tx.PreparedStmtDB.Mux.Lock()
|
tx.PreparedStmtDB.Mux.Lock()
|
||||||
defer tx.PreparedStmtDB.Mux.Unlock()
|
defer tx.PreparedStmtDB.Mux.Unlock()
|
||||||
|
|
||||||
@ -170,3 +250,11 @@ func (tx *PreparedStmtTX) QueryRowContext(ctx context.Context, query string, arg
|
|||||||
}
|
}
|
||||||
return &sql.Row{}
|
return &sql.Row{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (tx *PreparedStmtTX) Ping() error {
|
||||||
|
conn, err := tx.GetDBConn()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return conn.Ping()
|
||||||
|
}
|
||||||
|
|||||||
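Illustrative usage, not part of the diff: with the statement cache enabled, concurrent goroutines that build the same SQL now share one cached `*Stmt`; the first caller prepares it and later callers block on the `prepared` channel instead of racing on `PrepareContext`. The SQLite driver and the in-memory DSN below are assumptions chosen only to keep the sketch runnable.

```go
package main

import (
	"sync"

	"gorm.io/driver/sqlite" // driver choice is an assumption for this example
	"gorm.io/gorm"
)

type User struct {
	ID   uint
	Name string
}

func main() {
	db, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{PrepareStmt: true})
	if err != nil {
		panic(err)
	}
	db.AutoMigrate(&User{})

	var wg sync.WaitGroup
	for i := 0; i < 8; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			var users []User
			// identical query text, so all goroutines reuse one entry in PreparedStmtDB.Stmts
			db.Where("name = ?", "jinzhu").Find(&users)
		}()
	}
	wg.Wait()
}
```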
130	scan.go
@@ -4,10 +4,10 @@ import (
 	"database/sql"
 	"database/sql/driver"
 	"reflect"
-	"strings"
 	"time"
 
 	"gorm.io/gorm/schema"
+	"gorm.io/gorm/utils"
 )
 
 // prepareValues prepare values slice
@@ -50,7 +50,7 @@ func scanIntoMap(mapValue map[string]interface{}, values []interface{}, columns
 	}
 }
 
-func (db *DB) scanIntoStruct(rows Rows, reflectValue reflect.Value, values []interface{}, fields []*schema.Field, joinFields [][2]*schema.Field) {
+func (db *DB) scanIntoStruct(rows Rows, reflectValue reflect.Value, values []interface{}, fields []*schema.Field, joinFields [][]*schema.Field) {
 	for idx, field := range fields {
 		if field != nil {
 			values[idx] = field.NewValuePool.Get()
@@ -65,26 +65,49 @@ func (db *DB) scanIntoStruct(rows Rows, reflectValue reflect.Value, values []int
 
 	db.RowsAffected++
 	db.AddError(rows.Scan(values...))
+	joinedNestedSchemaMap := make(map[string]interface{})
 	for idx, field := range fields {
-		if field != nil {
-			if len(joinFields) == 0 || joinFields[idx][0] == nil {
-				db.AddError(field.Set(db.Statement.Context, reflectValue, values[idx]))
-			} else {
-				relValue := joinFields[idx][0].ReflectValueOf(db.Statement.Context, reflectValue)
-				if relValue.Kind() == reflect.Ptr && relValue.IsNil() {
-					if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() {
-						continue
-					}
-
-					relValue.Set(reflect.New(relValue.Type().Elem()))
-				}
-				db.AddError(joinFields[idx][1].Set(db.Statement.Context, relValue, values[idx]))
-			}
-
-			// release data to pool
-			field.NewValuePool.Put(values[idx])
+		if field == nil {
+			continue
 		}
+
+		if len(joinFields) == 0 || len(joinFields[idx]) == 0 {
+			db.AddError(field.Set(db.Statement.Context, reflectValue, values[idx]))
+		} else { // joinFields count is larger than 2 when using join
+			var isNilPtrValue bool
+			var relValue reflect.Value
+			// does not contain raw dbname
+			nestedJoinSchemas := joinFields[idx][:len(joinFields[idx])-1]
+			// current reflect value
+			currentReflectValue := reflectValue
+			fullRels := make([]string, 0, len(nestedJoinSchemas))
+			for _, joinSchema := range nestedJoinSchemas {
+				fullRels = append(fullRels, joinSchema.Name)
+				relValue = joinSchema.ReflectValueOf(db.Statement.Context, currentReflectValue)
+				if relValue.Kind() == reflect.Ptr {
+					fullRelsName := utils.JoinNestedRelationNames(fullRels)
+					// same nested structure
+					if _, ok := joinedNestedSchemaMap[fullRelsName]; !ok {
+						if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() {
+							isNilPtrValue = true
+							break
+						}
+
+						relValue.Set(reflect.New(relValue.Type().Elem()))
+						joinedNestedSchemaMap[fullRelsName] = nil
+					}
+				}
+				currentReflectValue = relValue
+			}
+
+			if !isNilPtrValue { // ignore if value is nil
+				f := joinFields[idx][len(joinFields[idx])-1]
+				db.AddError(f.Set(db.Statement.Context, relValue, values[idx]))
+			}
+		}
+
+		// release data to pool
+		field.NewValuePool.Put(values[idx])
 	}
 }
 
@@ -156,11 +179,10 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
 		}
 	default:
 		var (
 			fields       = make([]*schema.Field, len(columns))
-			selectedColumnsMap = make(map[string]int, len(columns))
-			joinFields   [][2]*schema.Field
+			joinFields   [][]*schema.Field
 			sch          = db.Statement.Schema
 			reflectValue = db.Statement.ReflectValue
 		)
 
 		if reflectValue.Kind() == reflect.Interface {
@@ -193,29 +215,45 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
 
 		// Not Pluck
 		if sch != nil {
+			matchedFieldCount := make(map[string]int, len(columns))
 			for idx, column := range columns {
 				if field := sch.LookUpField(column); field != nil && field.Readable {
-					if curIndex, ok := selectedColumnsMap[column]; ok {
-						for fieldIndex, selectField := range sch.Fields[curIndex+1:] {
+					fields[idx] = field
+					if count, ok := matchedFieldCount[column]; ok {
+						// handle duplicate fields
+						for _, selectField := range sch.Fields {
 							if selectField.DBName == column && selectField.Readable {
-								selectedColumnsMap[column] = curIndex + fieldIndex + 1
-								fields[idx] = selectField
-								break
+								if count == 0 {
+									matchedFieldCount[column]++
+									fields[idx] = selectField
+									break
+								}
+								count--
 							}
 						}
 					} else {
-						fields[idx] = field
-						selectedColumnsMap[column] = idx
+						matchedFieldCount[column] = 1
 					}
-				} else if names := strings.Split(column, "__"); len(names) > 1 {
+				} else if names := utils.SplitNestedRelationName(column); len(names) > 1 { // has nested relation
 					if rel, ok := sch.Relationships.Relations[names[0]]; ok {
-						if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
+						subNameCount := len(names)
+						// nested relation fields
+						relFields := make([]*schema.Field, 0, subNameCount-1)
+						relFields = append(relFields, rel.Field)
+						for _, name := range names[1 : subNameCount-1] {
+							rel = rel.FieldSchema.Relationships.Relations[name]
+							relFields = append(relFields, rel.Field)
+						}
+						// lastest name is raw dbname
+						dbName := names[subNameCount-1]
+						if field := rel.FieldSchema.LookUpField(dbName); field != nil && field.Readable {
 							fields[idx] = field
 
 							if len(joinFields) == 0 {
-								joinFields = make([][2]*schema.Field, len(columns))
+								joinFields = make([][]*schema.Field, len(columns))
 							}
-							joinFields[idx] = [2]*schema.Field{rel.Field, field}
+							relFields = append(relFields, field)
+							joinFields[idx] = relFields
 							continue
 						}
 					}
@@ -229,11 +267,24 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
 
 		switch reflectValue.Kind() {
 		case reflect.Slice, reflect.Array:
-			var elem reflect.Value
+			var (
+				elem        reflect.Value
+				isArrayKind = reflectValue.Kind() == reflect.Array
+			)
+
 			if !update || reflectValue.Len() == 0 {
 				update = false
-				db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20))
+				if isArrayKind {
+					db.Statement.ReflectValue.Set(reflect.Zero(reflectValue.Type()))
+				} else {
+					// if the slice cap is externally initialized, the externally initialized slice is directly used here
+					if reflectValue.Cap() == 0 {
+						db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20))
+					} else {
+						reflectValue.SetLen(0)
+						db.Statement.ReflectValue.Set(reflectValue)
+					}
+				}
 			}
 
 			for initialized || rows.Next() {
@@ -260,10 +311,15 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
 				db.scanIntoStruct(rows, elem, values, fields, joinFields)
 
 				if !update {
-					if isPtr {
-						reflectValue = reflect.Append(reflectValue, elem)
+					if !isPtr {
+						elem = elem.Elem()
+					}
+					if isArrayKind {
+						if reflectValue.Len() >= int(db.RowsAffected) {
+							reflectValue.Index(int(db.RowsAffected - 1)).Set(elem)
+						}
 					} else {
-						reflectValue = reflect.Append(reflectValue, elem.Elem())
+						reflectValue = reflect.Append(reflectValue, elem)
 					}
 				}
 			}
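A hedged, reader-facing sketch (not part of the diff): moving from a fixed `[2]*schema.Field` pair to a `[]*schema.Field` chain per column is what allows a joined column alias such as `Manager__Company__name` to be walked through more than one relation while scanning. The model names and the `Joins` calls below are illustrative assumptions, not code from this commit.

```go
package example

import "gorm.io/gorm"

type Company struct {
	ID   uint
	Name string
}

type Manager struct {
	ID        uint
	Name      string
	CompanyID uint
	Company   Company
}

type Employee struct {
	ID        uint
	Name      string
	ManagerID uint
	Manager   Manager
}

func loadEmployees(db *gorm.DB) ([]Employee, error) {
	var employees []Employee
	// nested join: each selected column is mapped to a chain of relation fields
	// (Manager, then Company) plus the raw column name by the new scan code path
	err := db.Joins("Manager").Joins("Manager.Company").Find(&employees).Error
	return employees, err
}
```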
@@ -1,35 +0,0 @@
-package schema
-
-import (
-	"regexp"
-	"strings"
-)
-
-// reg match english letters and midline
-var regEnLetterAndMidline = regexp.MustCompile("^[A-Za-z-_]+$")
-
-type Check struct {
-	Name       string
-	Constraint string // length(phone) >= 10
-	*Field
-}
-
-// ParseCheckConstraints parse schema check constraints
-func (schema *Schema) ParseCheckConstraints() map[string]Check {
-	checks := map[string]Check{}
-	for _, field := range schema.FieldsByDBName {
-		if chk := field.TagSettings["CHECK"]; chk != "" {
-			names := strings.Split(chk, ",")
-			if len(names) > 1 && regEnLetterAndMidline.MatchString(names[0]) {
-				checks[names[0]] = Check{Name: names[0], Constraint: strings.Join(names[1:], ","), Field: field}
-			} else {
-				if names[0] == "" {
-					chk = strings.Join(names[1:], ",")
-				}
-				name := schema.namer.CheckerName(schema.Table, field.DBName)
-				checks[name] = Check{Name: name, Constraint: chk, Field: field}
-			}
-		}
-	}
-	return checks
-}
66	schema/constraint.go	Normal file
@@ -0,0 +1,66 @@
+package schema
+
+import (
+	"regexp"
+	"strings"
+
+	"gorm.io/gorm/clause"
+)
+
+// reg match english letters and midline
+var regEnLetterAndMidline = regexp.MustCompile(`^[\w-]+$`)
+
+type CheckConstraint struct {
+	Name       string
+	Constraint string // length(phone) >= 10
+	*Field
+}
+
+func (chk *CheckConstraint) GetName() string { return chk.Name }
+
+func (chk *CheckConstraint) Build() (sql string, vars []interface{}) {
+	return "CONSTRAINT ? CHECK (?)", []interface{}{clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}}
+}
+
+// ParseCheckConstraints parse schema check constraints
+func (schema *Schema) ParseCheckConstraints() map[string]CheckConstraint {
+	checks := map[string]CheckConstraint{}
+	for _, field := range schema.FieldsByDBName {
+		if chk := field.TagSettings["CHECK"]; chk != "" {
+			names := strings.Split(chk, ",")
+			if len(names) > 1 && regEnLetterAndMidline.MatchString(names[0]) {
+				checks[names[0]] = CheckConstraint{Name: names[0], Constraint: strings.Join(names[1:], ","), Field: field}
+			} else {
+				if names[0] == "" {
+					chk = strings.Join(names[1:], ",")
+				}
+				name := schema.namer.CheckerName(schema.Table, field.DBName)
+				checks[name] = CheckConstraint{Name: name, Constraint: chk, Field: field}
+			}
+		}
+	}
+	return checks
+}
+
+type UniqueConstraint struct {
+	Name  string
+	Field *Field
+}
+
+func (uni *UniqueConstraint) GetName() string { return uni.Name }
+
+func (uni *UniqueConstraint) Build() (sql string, vars []interface{}) {
+	return "CONSTRAINT ? UNIQUE (?)", []interface{}{clause.Column{Name: uni.Name}, clause.Column{Name: uni.Field.DBName}}
+}
+
+// ParseUniqueConstraints parse schema unique constraints
+func (schema *Schema) ParseUniqueConstraints() map[string]UniqueConstraint {
+	uniques := make(map[string]UniqueConstraint)
+	for _, field := range schema.Fields {
+		if field.Unique {
+			name := schema.namer.UniqueName(schema.Table, field.DBName)
+			uniques[name] = UniqueConstraint{Name: name, Field: field}
+		}
+	}
+	return uniques
+}
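A short illustration of how model tags feed the two parsers added in schema/constraint.go. The `Product` model and its tags are example assumptions; with the default NamingStrategy the `unique` column below would yield a `UniqueConstraint` named `uni_products_code` via `UniqueName`, while the named check keeps its own name and the unnamed check gets a name from `CheckerName`.

```go
package example

type Product struct {
	ID    uint
	Code  string `gorm:"unique"`                        // picked up by ParseUniqueConstraints
	Name  string `gorm:"check:name_checker,name <> ''"` // named check constraint
	Price uint   `gorm:"check:price > 0"`               // unnamed check, name derived via CheckerName
}
```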
@@ -6,6 +6,7 @@ import (
 	"testing"
 
 	"gorm.io/gorm/schema"
+	"gorm.io/gorm/utils/tests"
 )
 
 type UserCheck struct {
@@ -20,7 +21,7 @@ func TestParseCheck(t *testing.T) {
 		t.Fatalf("failed to parse user check, got error %v", err)
 	}
 
-	results := map[string]schema.Check{
+	results := map[string]schema.CheckConstraint{
 		"name_checker": {
 			Name:       "name_checker",
 			Constraint: "name <> 'jinzhu'",
@@ -53,3 +54,31 @@ func TestParseCheck(t *testing.T) {
 		}
 	}
 }
+
+func TestParseUniqueConstraints(t *testing.T) {
+	type UserUnique struct {
+		Name1 string `gorm:"unique"`
+		Name2 string `gorm:"uniqueIndex"`
+	}
+
+	user, err := schema.Parse(&UserUnique{}, &sync.Map{}, schema.NamingStrategy{})
+	if err != nil {
+		t.Fatalf("failed to parse user unique, got error %v", err)
+	}
+	constraints := user.ParseUniqueConstraints()
+
+	results := map[string]schema.UniqueConstraint{
+		"uni_user_uniques_name1": {
+			Name:  "uni_user_uniques_name1",
+			Field: &schema.Field{Name: "Name1", Unique: true},
+		},
+	}
+	for k, result := range results {
+		v, ok := constraints[k]
+		if !ok {
+			t.Errorf("Failed to found unique constraint %v from parsed constraints %+v", k, constraints)
+		}
+		tests.AssertObjEqual(t, result, v, "Name")
+		tests.AssertObjEqual(t, result.Field, v.Field, "Name", "Unique", "UniqueIndex")
+	}
+}
102	schema/field.go
@@ -49,6 +49,8 @@ const (
 	Bytes  DataType = "bytes"
 )
 
+const DefaultAutoIncrementIncrement int64 = 1
+
 // Field is the representation of model schema's field
 type Field struct {
 	Name string
@@ -87,6 +89,16 @@ type Field struct {
 	Set          func(context.Context, reflect.Value, interface{}) error
 	Serializer   SerializerInterface
 	NewValuePool FieldNewValuePool
+
+	// In some db (e.g. MySQL), Unique and UniqueIndex are indistinguishable.
+	// When a column has a (not Mul) UniqueIndex, Migrator always reports its gorm.ColumnType is Unique.
+	// It causes field unnecessarily migration.
+	// Therefore, we need to record the UniqueIndex on this column (exclude Mul UniqueIndex) for MigrateColumnUnique.
+	UniqueIndex string
+}
+
+func (field *Field) BindName() string {
+	return strings.Join(field.BindNames, ".")
 }
 
 // ParseField parses reflect.StructField to Field
@@ -115,7 +127,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
 		NotNull: utils.CheckTruth(tagSetting["NOT NULL"], tagSetting["NOTNULL"]),
 		Unique:  utils.CheckTruth(tagSetting["UNIQUE"]),
 		Comment: tagSetting["COMMENT"],
-		AutoIncrementIncrement: 1,
+		AutoIncrementIncrement: DefaultAutoIncrementIncrement,
 	}
 
 	for field.IndirectFieldType.Kind() == reflect.Ptr {
@@ -174,7 +186,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
 		field.DataType = String
 		field.Serializer = v
 	} else {
-		var serializerName = field.TagSettings["JSON"]
+		serializerName := field.TagSettings["JSON"]
 		if serializerName == "" {
 			serializerName = field.TagSettings["SERIALIZER"]
 		}
@@ -403,18 +415,14 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
 			}
 
 			if ef.PrimaryKey {
-				if val, ok := ef.TagSettings["PRIMARYKEY"]; ok && utils.CheckTruth(val) {
-					ef.PrimaryKey = true
-				} else if val, ok := ef.TagSettings["PRIMARY_KEY"]; ok && utils.CheckTruth(val) {
-					ef.PrimaryKey = true
-				} else {
+				if !utils.CheckTruth(ef.TagSettings["PRIMARYKEY"], ef.TagSettings["PRIMARY_KEY"]) {
 					ef.PrimaryKey = false
 
 					if val, ok := ef.TagSettings["AUTOINCREMENT"]; !ok || !utils.CheckTruth(val) {
 						ef.AutoIncrement = false
 					}
 
-					if ef.DefaultValue == "" {
+					if !ef.AutoIncrement && ef.DefaultValue == "" {
 						ef.HasDefaultValue = false
 					}
 				}
@@ -472,9 +480,6 @@ func (field *Field) setupValuerAndSetter() {
 		oldValuerOf := field.ValueOf
 		field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) {
 			value, zero := oldValuerOf(ctx, v)
-			if zero {
-				return value, zero
-			}
 
 			s, ok := value.(SerializerValuerInterface)
 			if !ok {
@@ -487,7 +492,7 @@ func (field *Field) setupValuerAndSetter() {
 				Destination: v,
 				Context:     ctx,
 				fieldValue:  value,
-			}, false
+			}, zero
 		}
 	}
 
@@ -607,6 +612,22 @@ func (field *Field) setupValuerAndSetter() {
 			if data != nil && *data != nil {
 				field.ReflectValueOf(ctx, value).SetInt(**data)
 			}
+		case **int:
+			if data != nil && *data != nil {
+				field.ReflectValueOf(ctx, value).SetInt(int64(**data))
+			}
+		case **int8:
+			if data != nil && *data != nil {
+				field.ReflectValueOf(ctx, value).SetInt(int64(**data))
+			}
+		case **int16:
+			if data != nil && *data != nil {
+				field.ReflectValueOf(ctx, value).SetInt(int64(**data))
+			}
+		case **int32:
+			if data != nil && *data != nil {
+				field.ReflectValueOf(ctx, value).SetInt(int64(**data))
+			}
 		case int64:
 			field.ReflectValueOf(ctx, value).SetInt(data)
 		case int:
@@ -643,7 +664,7 @@ func (field *Field) setupValuerAndSetter() {
 			if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond {
 				field.ReflectValueOf(ctx, value).SetInt(data.UnixNano())
 			} else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond {
-				field.ReflectValueOf(ctx, value).SetInt(data.UnixNano() / 1e6)
+				field.ReflectValueOf(ctx, value).SetInt(data.UnixMilli())
 			} else {
 				field.ReflectValueOf(ctx, value).SetInt(data.Unix())
 			}
@@ -652,7 +673,7 @@ func (field *Field) setupValuerAndSetter() {
 			if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond {
 				field.ReflectValueOf(ctx, value).SetInt(data.UnixNano())
 			} else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond {
-				field.ReflectValueOf(ctx, value).SetInt(data.UnixNano() / 1e6)
+				field.ReflectValueOf(ctx, value).SetInt(data.UnixMilli())
 			} else {
 				field.ReflectValueOf(ctx, value).SetInt(data.Unix())
 			}
@@ -671,6 +692,22 @@ func (field *Field) setupValuerAndSetter() {
 			if data != nil && *data != nil {
 				field.ReflectValueOf(ctx, value).SetUint(**data)
 			}
+		case **uint:
+			if data != nil && *data != nil {
+				field.ReflectValueOf(ctx, value).SetUint(uint64(**data))
+			}
+		case **uint8:
+			if data != nil && *data != nil {
+				field.ReflectValueOf(ctx, value).SetUint(uint64(**data))
+			}
+		case **uint16:
+			if data != nil && *data != nil {
+				field.ReflectValueOf(ctx, value).SetUint(uint64(**data))
+			}
+		case **uint32:
+			if data != nil && *data != nil {
+				field.ReflectValueOf(ctx, value).SetUint(uint64(**data))
+			}
 		case uint64:
 			field.ReflectValueOf(ctx, value).SetUint(data)
 		case uint:
@@ -701,7 +738,7 @@ func (field *Field) setupValuerAndSetter() {
 			if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond {
 				field.ReflectValueOf(ctx, value).SetUint(uint64(data.UnixNano()))
 			} else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond {
-				field.ReflectValueOf(ctx, value).SetUint(uint64(data.UnixNano() / 1e6))
+				field.ReflectValueOf(ctx, value).SetUint(uint64(data.UnixMilli()))
 			} else {
 				field.ReflectValueOf(ctx, value).SetUint(uint64(data.Unix()))
 			}
@@ -723,6 +760,10 @@ func (field *Field) setupValuerAndSetter() {
 			if data != nil && *data != nil {
 				field.ReflectValueOf(ctx, value).SetFloat(**data)
 			}
+		case **float32:
+			if data != nil && *data != nil {
+				field.ReflectValueOf(ctx, value).SetFloat(float64(**data))
+			}
 		case float64:
 			field.ReflectValueOf(ctx, value).SetFloat(data)
 		case float32:
@@ -813,7 +854,7 @@ func (field *Field) setupValuerAndSetter() {
 		field.Set = func(ctx context.Context, value reflect.Value, v interface{}) error {
 			switch data := v.(type) {
 			case **time.Time:
-				if data != nil {
+				if data != nil && *data != nil {
 					field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(*data))
 				}
 			case time.Time:
@@ -849,14 +890,12 @@ func (field *Field) setupValuerAndSetter() {
 			reflectV := reflect.ValueOf(v)
 			if !reflectV.IsValid() {
 				field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem())
+			} else if reflectV.Kind() == reflect.Ptr && reflectV.IsNil() {
+				return
 			} else if reflectV.Type().AssignableTo(field.FieldType) {
 				field.ReflectValueOf(ctx, value).Set(reflectV)
 			} else if reflectV.Kind() == reflect.Ptr {
-				if reflectV.IsNil() || !reflectV.IsValid() {
-					field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem())
-				} else {
-					return field.Set(ctx, value, reflectV.Elem().Interface())
-				}
+				return field.Set(ctx, value, reflectV.Elem().Interface())
 			} else {
 				fieldValue := field.ReflectValueOf(ctx, value)
 				if fieldValue.IsNil() {
@@ -877,14 +916,12 @@ func (field *Field) setupValuerAndSetter() {
 			reflectV := reflect.ValueOf(v)
 			if !reflectV.IsValid() {
 				field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem())
+			} else if reflectV.Kind() == reflect.Ptr && reflectV.IsNil() {
+				return
 			} else if reflectV.Type().AssignableTo(field.FieldType) {
 				field.ReflectValueOf(ctx, value).Set(reflectV)
 			} else if reflectV.Kind() == reflect.Ptr {
-				if reflectV.IsNil() || !reflectV.IsValid() {
-					field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem())
-				} else {
-					return field.Set(ctx, value, reflectV.Elem().Interface())
-				}
+				return field.Set(ctx, value, reflectV.Elem().Interface())
 			} else {
 				if valuer, ok := v.(driver.Valuer); ok {
 					v, _ = valuer.Value()
@@ -913,6 +950,8 @@ func (field *Field) setupValuerAndSetter() {
 			sameElemType = field.FieldType == reflect.ValueOf(field.Serializer).Type().Elem()
 		}
 
+		serializerValue := reflect.Indirect(reflect.ValueOf(field.Serializer))
+		serializerType := serializerValue.Type()
 		field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) {
 			if s, ok := v.(*serializer); ok {
 				if s.fieldValue != nil {
@@ -920,11 +959,12 @@ func (field *Field) setupValuerAndSetter() {
 				} else if err = s.Serializer.Scan(ctx, field, value, s.value); err == nil {
 					if sameElemType {
 						field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(s.Serializer).Elem())
-						s.Serializer = reflect.New(reflect.Indirect(reflect.ValueOf(field.Serializer)).Type()).Interface().(SerializerInterface)
 					} else if sameType {
 						field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(s.Serializer))
-						s.Serializer = reflect.New(reflect.Indirect(reflect.ValueOf(field.Serializer)).Type()).Interface().(SerializerInterface)
 					}
+					si := reflect.New(serializerType)
+					si.Elem().Set(serializerValue)
+					s.Serializer = si.Interface().(SerializerInterface)
 				}
 			} else {
 				err = oldFieldSetter(ctx, value, v)
@@ -936,11 +976,15 @@ func (field *Field) setupValuerAndSetter() {
 
 func (field *Field) setupNewValuePool() {
 	if field.Serializer != nil {
+		serializerValue := reflect.Indirect(reflect.ValueOf(field.Serializer))
+		serializerType := serializerValue.Type()
 		field.NewValuePool = &sync.Pool{
 			New: func() interface{} {
+				si := reflect.New(serializerType)
+				si.Elem().Set(serializerValue)
 				return &serializer{
 					Field:      field,
-					Serializer: reflect.New(reflect.Indirect(reflect.ValueOf(field.Serializer)).Type()).Interface().(SerializerInterface),
+					Serializer: si.Interface().(SerializerInterface),
 				}
 			},
 		}
@@ -13,8 +13,8 @@ type Index struct {
 	Type    string // btree, hash, gist, spgist, gin, and brin
 	Where   string
 	Comment string
 	Option  string // WITH PARSER parser_name
-	Fields  []IndexOption
+	Fields  []IndexOption // Note: IndexOption's Field maybe the same
 }
 
 type IndexOption struct {
@@ -65,7 +65,11 @@ func (schema *Schema) ParseIndexes() map[string]Index {
 			}
 		}
 	}
+	for _, index := range indexes {
+		if index.Class == "UNIQUE" && len(index.Fields) == 1 {
+			index.Fields[0].Field.UniqueIndex = index.Name
+		}
+	}
 	return indexes
 }
 
@@ -1,11 +1,11 @@
 package schema_test
 
 import (
-	"reflect"
 	"sync"
 	"testing"
 
 	"gorm.io/gorm/schema"
+	"gorm.io/gorm/utils/tests"
 )
 
 type UserIndex struct {
@@ -19,6 +19,7 @@ type UserIndex struct {
 	OID          int64  `gorm:"index:idx_id;index:idx_oid,unique"`
 	MemberNumber string `gorm:"index:idx_id,priority:1"`
 	Name7        string `gorm:"index:type"`
+	Name8        string `gorm:"index:,length:10;index:,collate:utf8"`
 
 	// Composite Index: Flattened structure.
 	Data0A string `gorm:"index:,composite:comp_id0"`
@@ -65,7 +66,7 @@ func TestParseIndex(t *testing.T) {
 		"idx_name": {
 			Name:   "idx_name",
 			Class:  "UNIQUE",
-			Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name2"}}},
+			Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name2", UniqueIndex: "idx_name"}}},
 		},
 		"idx_user_indices_name3": {
 			Name: "idx_user_indices_name3",
@@ -81,7 +82,7 @@ func TestParseIndex(t *testing.T) {
 		"idx_user_indices_name4": {
 			Name:   "idx_user_indices_name4",
 			Class:  "UNIQUE",
-			Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name4"}}},
+			Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name4", UniqueIndex: "idx_user_indices_name4"}}},
 		},
 		"idx_user_indices_name5": {
 			Name: "idx_user_indices_name5",
@@ -102,18 +103,27 @@ func TestParseIndex(t *testing.T) {
 		},
 		"idx_id": {
 			Name:   "idx_id",
-			Fields: []schema.IndexOption{{Field: &schema.Field{Name: "MemberNumber"}}, {Field: &schema.Field{Name: "OID"}}},
+			Fields: []schema.IndexOption{{Field: &schema.Field{Name: "MemberNumber"}}, {Field: &schema.Field{Name: "OID", UniqueIndex: "idx_oid"}}},
 		},
 		"idx_oid": {
 			Name:   "idx_oid",
 			Class:  "UNIQUE",
-			Fields: []schema.IndexOption{{Field: &schema.Field{Name: "OID"}}},
+			Fields: []schema.IndexOption{{Field: &schema.Field{Name: "OID", UniqueIndex: "idx_oid"}}},
 		},
 		"type": {
 			Name:   "type",
 			Type:   "",
 			Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name7"}}},
 		},
+		"idx_user_indices_name8": {
+			Name: "idx_user_indices_name8",
+			Type: "",
+			Fields: []schema.IndexOption{
+				{Field: &schema.Field{Name: "Name8"}, Length: 10},
+				// Note: Duplicate Columns
+				{Field: &schema.Field{Name: "Name8"}, Collate: "utf8"},
+			},
+		},
 		"idx_user_indices_comp_id0": {
 			Name: "idx_user_indices_comp_id0",
 			Type: "",
@@ -146,37 +156,109 @@ func TestParseIndex(t *testing.T) {
 	}
 
-	indices := user.ParseIndexes()
-
-	for k, result := range results {
-		v, ok := indices[k]
-		if !ok {
-			t.Fatalf("Failed to found index %v from parsed indices %+v", k, indices)
-		}
-
-		for _, name := range []string{"Name", "Class", "Type", "Where", "Comment", "Option"} {
-			if reflect.ValueOf(result).FieldByName(name).Interface() != reflect.ValueOf(v).FieldByName(name).Interface() {
-				t.Errorf(
-					"index %v %v should equal, expects %v, got %v",
-					k, name, reflect.ValueOf(result).FieldByName(name).Interface(), reflect.ValueOf(v).FieldByName(name).Interface(),
-				)
-			}
-		}
-
-		for idx, ef := range result.Fields {
-			rf := v.Fields[idx]
-			if rf.Field.Name != ef.Field.Name {
-				t.Fatalf("index field should equal, expects %v, got %v", rf.Field.Name, ef.Field.Name)
-			}
-
-			for _, name := range []string{"Expression", "Sort", "Collate", "Length"} {
-				if reflect.ValueOf(ef).FieldByName(name).Interface() != reflect.ValueOf(rf).FieldByName(name).Interface() {
-					t.Errorf(
-						"index %v field #%v's %v should equal, expects %v, got %v", k, idx+1, name,
-						reflect.ValueOf(ef).FieldByName(name).Interface(), reflect.ValueOf(rf).FieldByName(name).Interface(),
-					)
-				}
-			}
-		}
-	}
+	CheckIndices(t, results, user.ParseIndexes())
+}
+
+func TestParseIndexWithUniqueIndexAndUnique(t *testing.T) {
+	type IndexTest struct {
+		FieldA string `gorm:"unique;index"` // unique and index
+		FieldB string `gorm:"unique"`       // unique
+
+		FieldC string `gorm:"index:,unique"`     // uniqueIndex
+		FieldD string `gorm:"uniqueIndex;index"` // uniqueIndex and index
+
+		FieldE1 string `gorm:"uniqueIndex:uniq_field_e1_e2"` // mul uniqueIndex
+		FieldE2 string `gorm:"uniqueIndex:uniq_field_e1_e2"`
+
+		FieldF1 string `gorm:"uniqueIndex:uniq_field_f1_f2;index"` // mul uniqueIndex and index
+		FieldF2 string `gorm:"uniqueIndex:uniq_field_f1_f2;"`
+
+		FieldG string `gorm:"unique;uniqueIndex"` // unique and uniqueIndex
+
+		FieldH1 string `gorm:"unique;uniqueIndex:uniq_field_h1_h2"` // unique and mul uniqueIndex
+		FieldH2 string `gorm:"uniqueIndex:uniq_field_h1_h2"`        // unique and mul uniqueIndex
+	}
+	indexSchema, err := schema.Parse(&IndexTest{}, &sync.Map{}, schema.NamingStrategy{})
+	if err != nil {
+		t.Fatalf("failed to parse user index, got error %v", err)
+	}
+	indices := indexSchema.ParseIndexes()
+	CheckIndices(t, map[string]schema.Index{
+		"idx_index_tests_field_a": {
+			Name:   "idx_index_tests_field_a",
+			Fields: []schema.IndexOption{{Field: &schema.Field{Name: "FieldA", Unique: true}}},
+		},
+		"idx_index_tests_field_c": {
+			Name:   "idx_index_tests_field_c",
+			Class:  "UNIQUE",
+			Fields: []schema.IndexOption{{Field: &schema.Field{Name: "FieldC", UniqueIndex: "idx_index_tests_field_c"}}},
+		},
+		"idx_index_tests_field_d": {
+			Name:  "idx_index_tests_field_d",
+			Class: "UNIQUE",
+			Fields: []schema.IndexOption{
+				{Field: &schema.Field{Name: "FieldD"}},
+				// Note: Duplicate Columns
+				{Field: &schema.Field{Name: "FieldD"}},
+			},
+		},
+		"uniq_field_e1_e2": {
+			Name:  "uniq_field_e1_e2",
+			Class: "UNIQUE",
+			Fields: []schema.IndexOption{
+				{Field: &schema.Field{Name: "FieldE1"}},
+				{Field: &schema.Field{Name: "FieldE2"}},
+			},
+		},
+		"idx_index_tests_field_f1": {
+			Name:   "idx_index_tests_field_f1",
+			Fields: []schema.IndexOption{{Field: &schema.Field{Name: "FieldF1"}}},
+		},
+		"uniq_field_f1_f2": {
+			Name:  "uniq_field_f1_f2",
+			Class: "UNIQUE",
+			Fields: []schema.IndexOption{
+				{Field: &schema.Field{Name: "FieldF1"}},
+				{Field: &schema.Field{Name: "FieldF2"}},
+			},
+		},
+		"idx_index_tests_field_g": {
+			Name:   "idx_index_tests_field_g",
+			Class:  "UNIQUE",
+			Fields: []schema.IndexOption{{Field: &schema.Field{Name: "FieldG", Unique: true, UniqueIndex: "idx_index_tests_field_g"}}},
+		},
+		"uniq_field_h1_h2": {
+			Name:  "uniq_field_h1_h2",
+			Class: "UNIQUE",
+			Fields: []schema.IndexOption{
+				{Field: &schema.Field{Name: "FieldH1", Unique: true}},
+				{Field: &schema.Field{Name: "FieldH2"}},
+			},
+		},
+	}, indices)
+}
+
+func CheckIndices(t *testing.T, expected, actual map[string]schema.Index) {
+	for k, ei := range expected {
+		t.Run(k, func(t *testing.T) {
+			ai, ok := actual[k]
+			if !ok {
+				t.Errorf("expected index %q but actual missing", k)
+				return
+			}
+			tests.AssertObjEqual(t, ai, ei, "Name", "Class", "Type", "Where", "Comment", "Option")
+			if len(ei.Fields) != len(ai.Fields) {
+				t.Errorf("expected index %q field length is %d but actual %d", k, len(ei.Fields), len(ai.Fields))
+				return
+			}
+			for i, ef := range ei.Fields {
+				af := ai.Fields[i]
+				tests.AssertObjEqual(t, af, ef, "Name", "Unique", "UniqueIndex", "Expression", "Sort", "Collate", "Length")
+			}
+		})
+		delete(actual, k)
+	}
+	for k := range actual {
+		t.Errorf("unexpected index %q", k)
+	}
 }
@@ -4,6 +4,12 @@ import (
 	"gorm.io/gorm/clause"
 )
 
+// ConstraintInterface database constraint interface
+type ConstraintInterface interface {
+	GetName() string
+	Build() (sql string, vars []interface{})
+}
+
 // GormDataTypeInterface gorm data type interface
 type GormDataTypeInterface interface {
 	GormDataType() string
@@ -19,6 +19,7 @@ type Namer interface {
 	RelationshipFKName(Relationship) string
 	CheckerName(table, column string) string
 	IndexName(table, column string) string
+	UniqueName(table, column string) string
 }
 
 // Replacer replacer interface like strings.Replacer
@@ -26,12 +27,15 @@ type Replacer interface {
 	Replace(name string) string
 }
 
+var _ Namer = (*NamingStrategy)(nil)
+
 // NamingStrategy tables, columns naming strategy
 type NamingStrategy struct {
 	TablePrefix         string
 	SingularTable       bool
 	NameReplacer        Replacer
 	NoLowerCase         bool
+	IdentifierMaxLength int
 }
 
 // TableName convert string to table name
@@ -84,17 +88,26 @@ func (ns NamingStrategy) IndexName(table, column string) string {
 	return ns.formatName("idx", table, ns.toDBName(column))
 }
 
+// UniqueName generate unique constraint name
+func (ns NamingStrategy) UniqueName(table, column string) string {
+	return ns.formatName("uni", table, ns.toDBName(column))
+}
+
 func (ns NamingStrategy) formatName(prefix, table, name string) string {
 	formattedName := strings.ReplaceAll(strings.Join([]string{
 		prefix, table, name,
 	}, "_"), ".", "_")
 
-	if utf8.RuneCountInString(formattedName) > 64 {
+	if ns.IdentifierMaxLength == 0 {
+		ns.IdentifierMaxLength = 64
+	}
+
+	if utf8.RuneCountInString(formattedName) > ns.IdentifierMaxLength {
 		h := sha1.New()
 		h.Write([]byte(formattedName))
 		bs := h.Sum(nil)
 
-		formattedName = formattedName[0:56] + hex.EncodeToString(bs)[:8]
+		formattedName = formattedName[0:ns.IdentifierMaxLength-8] + hex.EncodeToString(bs)[:8]
 	}
 	return formattedName
 }
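A brief usage sketch of the new `IdentifierMaxLength` knob (the postgres driver and the DSN placeholder below are assumptions for illustration). It tunes the truncation threshold used by `formatName`; 63 matches PostgreSQL's identifier limit, and leaving it at 0 keeps the previous hard-coded limit of 64.

```go
package example

import (
	"gorm.io/driver/postgres" // driver choice and dsn are placeholders
	"gorm.io/gorm"
	"gorm.io/gorm/schema"
)

func open(dsn string) (*gorm.DB, error) {
	return gorm.Open(postgres.Open(dsn), &gorm.Config{
		NamingStrategy: schema.NamingStrategy{
			// long index/constraint names get hashed down to at most 63 runes
			IdentifierMaxLength: 63,
		},
	})
}
```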
@@ -189,8 +189,17 @@ func TestCustomReplacerWithNoLowerCase(t *testing.T) {
 	}
 }
 
+func TestFormatNameWithStringLongerThan63Characters(t *testing.T) {
+	ns := NamingStrategy{IdentifierMaxLength: 63}
+
+	formattedName := ns.formatName("prefix", "table", "thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVeryLongString")
+	if formattedName != "prefix_table_thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVer180f2c67" {
+		t.Errorf("invalid formatted name generated, got %v", formattedName)
+	}
+}
+
 func TestFormatNameWithStringLongerThan64Characters(t *testing.T) {
-	ns := NamingStrategy{}
+	ns := NamingStrategy{IdentifierMaxLength: 64}
 
 	formattedName := ns.formatName("prefix", "table", "thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVeryLongString")
 	if formattedName != "prefix_table_thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVery180f2c67" {
|
|||||||
HasMany []*Relationship
|
HasMany []*Relationship
|
||||||
Many2Many []*Relationship
|
Many2Many []*Relationship
|
||||||
Relations map[string]*Relationship
|
Relations map[string]*Relationship
|
||||||
|
|
||||||
|
EmbeddedRelations map[string]*Relationships
|
||||||
}
|
}
|
||||||
|
|
||||||
type Relationship struct {
|
type Relationship struct {
|
||||||
@ -74,8 +76,8 @@ func (schema *Schema) parseRelation(field *Field) *Relationship {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" {
|
if hasPolymorphicRelation(field.TagSettings) {
|
||||||
schema.buildPolymorphicRelation(relation, field, polymorphic)
|
schema.buildPolymorphicRelation(relation, field)
|
||||||
} else if many2many := field.TagSettings["MANY2MANY"]; many2many != "" {
|
} else if many2many := field.TagSettings["MANY2MANY"]; many2many != "" {
|
||||||
schema.buildMany2ManyRelation(relation, field, many2many)
|
schema.buildMany2ManyRelation(relation, field, many2many)
|
||||||
} else if belongsTo := field.TagSettings["BELONGSTO"]; belongsTo != "" {
|
} else if belongsTo := field.TagSettings["BELONGSTO"]; belongsTo != "" {
|
||||||
@ -87,7 +89,8 @@ func (schema *Schema) parseRelation(field *Field) *Relationship {
|
|||||||
case reflect.Slice:
|
case reflect.Slice:
|
||||||
schema.guessRelation(relation, field, guessHas)
|
schema.guessRelation(relation, field, guessHas)
|
||||||
default:
|
default:
|
||||||
schema.err = fmt.Errorf("unsupported data type %v for %v on field %s", relation.FieldSchema, schema, field.Name)
|
schema.err = fmt.Errorf("unsupported data type %v for %v on field %s", relation.FieldSchema, schema,
|
||||||
|
field.Name)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -106,7 +109,7 @@ func (schema *Schema) parseRelation(field *Field) *Relationship {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if schema.err == nil {
|
if schema.err == nil {
|
||||||
schema.Relationships.Relations[relation.Name] = relation
|
schema.setRelation(relation)
|
||||||
switch relation.Type {
|
switch relation.Type {
|
||||||
case HasOne:
|
case HasOne:
|
||||||
schema.Relationships.HasOne = append(schema.Relationships.HasOne, relation)
|
schema.Relationships.HasOne = append(schema.Relationships.HasOne, relation)
|
||||||
@@ -122,34 +125,100 @@ func (schema *Schema) parseRelation(field *Field) *Relationship {
 	return relation
 }

-// User has many Toys, its `Polymorphic` is `Owner`, Pet has one Toy, its `Polymorphic` is `Owner`
-// type User struct {
-//   Toys []Toy `gorm:"polymorphic:Owner;"`
-// }
-// type Pet struct {
-//   Toy Toy `gorm:"polymorphic:Owner;"`
-// }
-// type Toy struct {
-//   OwnerID   int
-//   OwnerType string
-// }
-func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Field, polymorphic string) {
-	relation.Polymorphic = &Polymorphic{
-		Value:           schema.Table,
-		PolymorphicType: relation.FieldSchema.FieldsByName[polymorphic+"Type"],
-		PolymorphicID:   relation.FieldSchema.FieldsByName[polymorphic+"ID"],
-	}
+// hasPolymorphicRelation check if has polymorphic relation
+// 1. `POLYMORPHIC` tag
+// 2. `POLYMORPHICTYPE` and `POLYMORPHICID` tag
+func hasPolymorphicRelation(tagSettings map[string]string) bool {
+	if _, ok := tagSettings["POLYMORPHIC"]; ok {
+		return true
+	}
+
+	_, hasType := tagSettings["POLYMORPHICTYPE"]
+	_, hasId := tagSettings["POLYMORPHICID"]
+
+	return hasType && hasId
+}
+
+func (schema *Schema) setRelation(relation *Relationship) {
+	// set non-embedded relation
+	if rel := schema.Relationships.Relations[relation.Name]; rel != nil {
+		if len(rel.Field.BindNames) > 1 {
+			schema.Relationships.Relations[relation.Name] = relation
+		}
+	} else {
+		schema.Relationships.Relations[relation.Name] = relation
+	}
+
+	// set embedded relation
+	if len(relation.Field.BindNames) <= 1 {
+		return
+	}
+	relationships := &schema.Relationships
+	for i, name := range relation.Field.BindNames {
+		if i < len(relation.Field.BindNames)-1 {
+			if relationships.EmbeddedRelations == nil {
+				relationships.EmbeddedRelations = map[string]*Relationships{}
+			}
+			if r := relationships.EmbeddedRelations[name]; r == nil {
+				relationships.EmbeddedRelations[name] = &Relationships{}
+			}
+			relationships = relationships.EmbeddedRelations[name]
+		} else {
+			if relationships.Relations == nil {
+				relationships.Relations = map[string]*Relationship{}
+			}
+			relationships.Relations[relation.Name] = relation
+		}
+	}
+}
+
+// User has many Toys, its `Polymorphic` is `Owner`, Pet has one Toy, its `Polymorphic` is `Owner`
+//
+//	type User struct {
+//	  Toys []Toy `gorm:"polymorphic:Owner;"`
+//	}
+//	type Pet struct {
+//	  Toy Toy `gorm:"polymorphic:Owner;"`
+//	}
+//	type Toy struct {
+//	  OwnerID   int
+//	  OwnerType string
+//	}
+func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Field) {
+	polymorphic := field.TagSettings["POLYMORPHIC"]
+
+	relation.Polymorphic = &Polymorphic{
+		Value: schema.Table,
+	}
+
+	var (
+		typeName = polymorphic + "Type"
+		typeId   = polymorphic + "ID"
+	)
+
+	if value, ok := field.TagSettings["POLYMORPHICTYPE"]; ok {
+		typeName = strings.TrimSpace(value)
+	}
+
+	if value, ok := field.TagSettings["POLYMORPHICID"]; ok {
+		typeId = strings.TrimSpace(value)
+	}
+
+	relation.Polymorphic.PolymorphicType = relation.FieldSchema.FieldsByName[typeName]
+	relation.Polymorphic.PolymorphicID = relation.FieldSchema.FieldsByName[typeId]

 	if value, ok := field.TagSettings["POLYMORPHICVALUE"]; ok {
 		relation.Polymorphic.Value = strings.TrimSpace(value)
 	}

 	if relation.Polymorphic.PolymorphicType == nil {
-		schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s", relation.FieldSchema, schema, field.Name, polymorphic+"Type")
+		schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s",
+			relation.FieldSchema, schema, field.Name, polymorphic+"Type")
 	}

 	if relation.Polymorphic.PolymorphicID == nil {
-		schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s", relation.FieldSchema, schema, field.Name, polymorphic+"ID")
+		schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s",
+			relation.FieldSchema, schema, field.Name, polymorphic+"ID")
 	}

 	if schema.err == nil {
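With hasPolymorphicRelation, a relation now also counts as polymorphic when only the polymorphicType/polymorphicId tags are present, and buildPolymorphicRelation resolves the overridden column names itself. A minimal model sketch of the tag combination this enables (mirroring the tests added later in this diff; the type and field names are illustrative):

package model

type Toy struct {
	ID    int
	Name  string
	RefId int    // custom polymorphic id column
	Type  string // custom polymorphic type column
}

type Cat struct {
	ID   int
	Name string
	// Same behaviour as the default Owner{ID,Type} pairing, but with the id and
	// type columns overridden via polymorphicId/polymorphicType.
	Toy Toy `gorm:"polymorphic:Owner;polymorphicType:Type;polymorphicId:RefId"`
}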
@@ -161,10 +230,17 @@ func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Field, polymorphic string) {
 	primaryKeyField := schema.PrioritizedPrimaryField
 	if len(relation.foreignKeys) > 0 {
 		if primaryKeyField = schema.LookUpField(relation.foreignKeys[0]); primaryKeyField == nil || len(relation.foreignKeys) > 1 {
-			schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %s", relation.foreignKeys, schema, field.Name)
+			schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %s", relation.foreignKeys,
+				schema, field.Name)
 		}
 	}
+
+	if primaryKeyField == nil {
+		schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing primaryKey field",
+			relation.FieldSchema, schema, field.Name)
+		return
+	}
+
 	// use same data type for foreign keys
 	if copyableDataType(primaryKeyField.DataType) {
 		relation.Polymorphic.PolymorphicID.DataType = primaryKeyField.DataType
@@ -191,7 +267,8 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Field, many2many string) {
 		err             error
 		joinTableFields []reflect.StructField
 		fieldsMap       = map[string]*Field{}
-		ownFieldsMap    = map[string]bool{} // fix self join many2many
+		ownFieldsMap    = map[string]*Field{} // fix self join many2many
+		referFieldsMap  = map[string]*Field{}
 		joinForeignKeys = toColumns(field.TagSettings["JOINFOREIGNKEY"])
 		joinReferences  = toColumns(field.TagSettings["JOINREFERENCES"])
 	)

@@ -229,21 +306,19 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Field, many2many string) {
 			joinFieldName = strings.Title(joinForeignKeys[idx])
 		}

-		ownFieldsMap[joinFieldName] = true
+		ownFieldsMap[joinFieldName] = ownField
 		fieldsMap[joinFieldName] = ownField
 		joinTableFields = append(joinTableFields, reflect.StructField{
 			Name:    joinFieldName,
 			PkgPath: ownField.StructField.PkgPath,
 			Type:    ownField.StructField.Type,
-			Tag:     removeSettingFromTag(ownField.StructField.Tag, "column", "autoincrement", "index", "unique", "uniqueindex"),
+			Tag: removeSettingFromTag(appendSettingFromTag(ownField.StructField.Tag, "primaryKey"),
+				"column", "autoincrement", "index", "unique", "uniqueindex"),
 		})
 	}

 	for idx, relField := range refForeignFields {
 		joinFieldName := strings.Title(relation.FieldSchema.Name) + relField.Name
-		if len(joinReferences) > idx {
-			joinFieldName = strings.Title(joinReferences[idx])
-		}

 		if _, ok := ownFieldsMap[joinFieldName]; ok {
 			if field.Name != relation.FieldSchema.Name {

@@ -253,13 +328,22 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Field, many2many string) {
 			}
 		}

-		fieldsMap[joinFieldName] = relField
-		joinTableFields = append(joinTableFields, reflect.StructField{
-			Name:    joinFieldName,
-			PkgPath: relField.StructField.PkgPath,
-			Type:    relField.StructField.Type,
-			Tag:     removeSettingFromTag(relField.StructField.Tag, "column", "autoincrement", "index", "unique", "uniqueindex"),
-		})
+		if len(joinReferences) > idx {
+			joinFieldName = strings.Title(joinReferences[idx])
+		}
+
+		referFieldsMap[joinFieldName] = relField
+
+		if _, ok := fieldsMap[joinFieldName]; !ok {
+			fieldsMap[joinFieldName] = relField
+			joinTableFields = append(joinTableFields, reflect.StructField{
+				Name:    joinFieldName,
+				PkgPath: relField.StructField.PkgPath,
+				Type:    relField.StructField.Type,
+				Tag: removeSettingFromTag(appendSettingFromTag(relField.StructField.Tag, "primaryKey"),
+					"column", "autoincrement", "index", "unique", "uniqueindex"),
+			})
+		}
 	}

 	joinTableFields = append(joinTableFields, reflect.StructField{

@@ -268,7 +352,8 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Field, many2many string) {
 		Tag:     `gorm:"-"`,
 	})

-	if relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, schema.namer); err != nil {
+	if relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore,
+		schema.namer); err != nil {
 		schema.err = err
 	}
 	relation.JoinTable.Name = many2many
@@ -315,31 +400,37 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Field, many2many string) {
 				f.Size = fieldsMap[f.Name].Size
 			}
 			relation.JoinTable.PrimaryFields = append(relation.JoinTable.PrimaryFields, f)
-			ownPrimaryField := schema == fieldsMap[f.Name].Schema && ownFieldsMap[f.Name]

-			if ownPrimaryField {
+			if of, ok := ownFieldsMap[f.Name]; ok {
 				joinRel := relation.JoinTable.Relationships.Relations[relName]
 				joinRel.Field = relation.Field
 				joinRel.References = append(joinRel.References, &Reference{
-					PrimaryKey: fieldsMap[f.Name],
+					PrimaryKey: of,
 					ForeignKey: f,
 				})
-			} else {
+
+				relation.References = append(relation.References, &Reference{
+					PrimaryKey:    of,
+					ForeignKey:    f,
+					OwnPrimaryKey: true,
+				})
+			}
+
+			if rf, ok := referFieldsMap[f.Name]; ok {
 				joinRefRel := relation.JoinTable.Relationships.Relations[relRefName]
 				if joinRefRel.Field == nil {
 					joinRefRel.Field = relation.Field
 				}
 				joinRefRel.References = append(joinRefRel.References, &Reference{
-					PrimaryKey: fieldsMap[f.Name],
+					PrimaryKey: rf,
+					ForeignKey: f,
+				})
+
+				relation.References = append(relation.References, &Reference{
+					PrimaryKey: rf,
 					ForeignKey: f,
 				})
 			}

-			relation.References = append(relation.References, &Reference{
-				PrimaryKey:    fieldsMap[f.Name],
-				ForeignKey:    f,
-				OwnPrimaryKey: ownPrimaryField,
-			})
 		}
 	}
 }
@@ -381,7 +472,8 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl guessLevel) {
 			schema.guessRelation(relation, field, guessEmbeddedHas)
 		// case guessEmbeddedHas:
 		default:
-			schema.err = fmt.Errorf("invalid field found for struct %v's field %s: define a valid foreign key for relations or implement the Valuer/Scanner interface", schema, field.Name)
+			schema.err = fmt.Errorf("invalid field found for struct %v's field %s: define a valid foreign key for relations or implement the Valuer/Scanner interface",
+				schema, field.Name)
 		}
 	}

@@ -389,34 +481,31 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl guessLevel) {
 	case guessBelongs:
 		primarySchema, foreignSchema = relation.FieldSchema, schema
 	case guessEmbeddedBelongs:
-		if field.OwnerSchema != nil {
-			primarySchema, foreignSchema = relation.FieldSchema, field.OwnerSchema
-		} else {
+		if field.OwnerSchema == nil {
 			reguessOrErr()
 			return
 		}
+		primarySchema, foreignSchema = relation.FieldSchema, field.OwnerSchema
 	case guessHas:
 	case guessEmbeddedHas:
-		if field.OwnerSchema != nil {
-			primarySchema, foreignSchema = field.OwnerSchema, relation.FieldSchema
-		} else {
+		if field.OwnerSchema == nil {
 			reguessOrErr()
 			return
 		}
+		primarySchema, foreignSchema = field.OwnerSchema, relation.FieldSchema
 	}

 	if len(relation.foreignKeys) > 0 {
 		for _, foreignKey := range relation.foreignKeys {
-			if f := foreignSchema.LookUpField(foreignKey); f != nil {
-				foreignFields = append(foreignFields, f)
-			} else {
+			f := foreignSchema.LookUpField(foreignKey)
+			if f == nil {
 				reguessOrErr()
 				return
 			}
+			foreignFields = append(foreignFields, f)
 		}
 	} else {
-		var primaryFields []*Field
-		var primarySchemaName = primarySchema.Name
+		primarySchemaName := primarySchema.Name
 		if primarySchemaName == "" {
 			primarySchemaName = relation.FieldSchema.Name
 		}

@@ -431,6 +520,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl guessLevel) {
 			primaryFields = primarySchema.PrimaryFields
 		}

+	primaryFieldLoop:
 		for _, primaryField := range primaryFields {
 			lookUpName := primarySchemaName + primaryField.Name
 			if gl == guessBelongs {

@@ -439,23 +529,33 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl guessLevel) {

 			lookUpNames := []string{lookUpName}
 			if len(primaryFields) == 1 {
-				lookUpNames = append(lookUpNames, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID", strings.TrimSuffix(lookUpName, primaryField.Name)+"Id", schema.namer.ColumnName(foreignSchema.Table, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID"))
+				lookUpNames = append(lookUpNames, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID",
+					strings.TrimSuffix(lookUpName, primaryField.Name)+"Id", schema.namer.ColumnName(foreignSchema.Table,
+						strings.TrimSuffix(lookUpName, primaryField.Name)+"ID"))
 			}

+			for _, name := range lookUpNames {
+				if f := foreignSchema.LookUpFieldByBindName(field.BindNames, name); f != nil {
+					foreignFields = append(foreignFields, f)
+					primaryFields = append(primaryFields, primaryField)
+					continue primaryFieldLoop
+				}
+			}
 			for _, name := range lookUpNames {
 				if f := foreignSchema.LookUpField(name); f != nil {
 					foreignFields = append(foreignFields, f)
 					primaryFields = append(primaryFields, primaryField)
-					break
+					continue primaryFieldLoop
 				}
 			}
 		}
 	}

-	if len(foreignFields) == 0 {
+	switch {
+	case len(foreignFields) == 0:
 		reguessOrErr()
 		return
-	} else if len(relation.primaryKeys) > 0 {
+	case len(relation.primaryKeys) > 0:
 		for idx, primaryKey := range relation.primaryKeys {
 			if f := primarySchema.LookUpField(primaryKey); f != nil {
 				if len(primaryFields) < idx+1 {

@@ -469,7 +569,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl guessLevel) {
 				return
 			}
 		}
-	} else if len(primaryFields) == 0 {
+	case len(primaryFields) == 0:
 		if len(foreignFields) == 1 && primarySchema.PrioritizedPrimaryField != nil {
 			primaryFields = append(primaryFields, primarySchema.PrioritizedPrimaryField)
 		} else if len(primarySchema.PrimaryFields) == len(foreignFields) {
@@ -505,6 +605,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl guessLevel) {
 	}
 }

+// Constraint is ForeignKey Constraint
 type Constraint struct {
 	Name  string
 	Field *Field

@@ -516,6 +617,31 @@ type Constraint struct {
 	OnUpdate string
 }

+func (constraint *Constraint) GetName() string { return constraint.Name }
+
+func (constraint *Constraint) Build() (sql string, vars []interface{}) {
+	sql = "CONSTRAINT ? FOREIGN KEY ? REFERENCES ??"
+	if constraint.OnDelete != "" {
+		sql += " ON DELETE " + constraint.OnDelete
+	}
+
+	if constraint.OnUpdate != "" {
+		sql += " ON UPDATE " + constraint.OnUpdate
+	}
+
+	foreignKeys := make([]interface{}, 0, len(constraint.ForeignKeys))
+	for _, field := range constraint.ForeignKeys {
+		foreignKeys = append(foreignKeys, clause.Column{Name: field.DBName})
+	}
+
+	references := make([]interface{}, 0, len(constraint.References))
+	for _, field := range constraint.References {
+		references = append(references, clause.Column{Name: field.DBName})
+	}
+	vars = append(vars, clause.Table{Name: constraint.Name}, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references)
+	return
+}
+
 func (rel *Relationship) ParseConstraint() *Constraint {
 	str := rel.Field.TagSettings["CONSTRAINT"]
 	if str == "-" {
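Build turns a parsed relationship constraint into a SQL fragment plus bind variables that the dialect expands (? for names, ?? for the referenced table). A rough sketch of what it renders to, using a hypothetical, hand-assembled constraint; in practice ParseConstraint produces it from the relation's constraint tag during schema parsing:

package main

import (
	"fmt"

	"gorm.io/gorm/schema"
)

func main() {
	// Illustrative constraint with no key columns filled in.
	c := schema.Constraint{
		Name:            "fk_pets_owner",
		OnDelete:        "CASCADE",
		ReferenceSchema: &schema.Schema{Table: "users"},
	}

	sql, vars := c.Build()
	fmt.Println(sql)  // CONSTRAINT ? FOREIGN KEY ? REFERENCES ?? ON DELETE CASCADE
	fmt.Println(vars) // table/column placeholders the dialect substitutes into the ?s
}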
@@ -10,7 +10,7 @@ import (

 func checkStructRelation(t *testing.T, data interface{}, relations ...Relation) {
 	if s, err := schema.Parse(data, &sync.Map{}, schema.NamingStrategy{}); err != nil {
-		t.Errorf("Failed to parse schema")
+		t.Errorf("Failed to parse schema, got error %v", err)
 	} else {
 		for _, rel := range relations {
 			checkSchemaRelation(t, s, rel)

@@ -305,6 +305,33 @@ func TestMany2ManyOverrideForeignKey(t *testing.T) {
 	})
 }

+func TestMany2ManySharedForeignKey(t *testing.T) {
+	type Profile struct {
+		gorm.Model
+		Name         string
+		Kind         string
+		ProfileRefer uint
+	}
+
+	type User struct {
+		gorm.Model
+		Profiles []Profile `gorm:"many2many:user_profiles;foreignKey:Refer,Kind;joinForeignKey:UserRefer,Kind;References:ProfileRefer,Kind;joinReferences:ProfileR,Kind"`
+		Kind     string
+		Refer    uint
+	}
+
+	checkStructRelation(t, &User{}, Relation{
+		Name: "Profiles", Type: schema.Many2Many, Schema: "User", FieldSchema: "Profile",
+		JoinTable: JoinTable{Name: "user_profiles", Table: "user_profiles"},
+		References: []Reference{
+			{"Refer", "User", "UserRefer", "user_profiles", "", true},
+			{"Kind", "User", "Kind", "user_profiles", "", true},
+			{"ProfileRefer", "Profile", "ProfileR", "user_profiles", "", false},
+			{"Kind", "Profile", "Kind", "user_profiles", "", false},
+		},
+	})
+}
+
 func TestMany2ManyOverrideJoinForeignKey(t *testing.T) {
 	type Profile struct {
 		gorm.Model
@@ -491,6 +518,319 @@ func TestEmbeddedRelation(t *testing.T) {
 	}
 }

+func TestEmbeddedHas(t *testing.T) {
+	type Toy struct {
+		ID        int
+		Name      string
+		OwnerID   int
+		OwnerType string
+	}
+	type User struct {
+		ID  int
+		Cat struct {
+			Name string
+			Toy  Toy   `gorm:"polymorphic:Owner;"`
+			Toys []Toy `gorm:"polymorphic:Owner;"`
+		} `gorm:"embedded;embeddedPrefix:cat_"`
+		Dog struct {
+			ID     int
+			Name   string
+			UserID int
+			Toy    Toy   `gorm:"polymorphic:Owner;"`
+			Toys   []Toy `gorm:"polymorphic:Owner;"`
+		}
+		Toys []Toy `gorm:"polymorphic:Owner;"`
+	}
+
+	s, err := schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{})
+	if err != nil {
+		t.Fatalf("Failed to parse schema, got error %v", err)
+	}
+
+	checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
+		"Cat": {
+			Relations: map[string]Relation{
+				"Toy": {
+					Name:        "Toy",
+					Type:        schema.HasOne,
+					Schema:      "User",
+					FieldSchema: "Toy",
+					Polymorphic: Polymorphic{ID: "OwnerID", Type: "OwnerType", Value: "users"},
+					References: []Reference{
+						{ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"},
+						{ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"},
+					},
+				},
+				"Toys": {
+					Name:        "Toys",
+					Type:        schema.HasMany,
+					Schema:      "User",
+					FieldSchema: "Toy",
+					Polymorphic: Polymorphic{ID: "OwnerID", Type: "OwnerType", Value: "users"},
+					References: []Reference{
+						{ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"},
+						{ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"},
+					},
+				},
+			},
+		},
+	})
+}
+
+func TestPolymorphic(t *testing.T) {
+	t.Run("has one", func(t *testing.T) {
+		type Toy struct {
+			ID        int
+			Name      string
+			OwnerID   int
+			OwnerType string
+		}
+
+		type Cat struct {
+			ID   int
+			Name string
+			Toy  Toy `gorm:"polymorphic:Owner;"`
+		}
+
+		s, err := schema.Parse(&Cat{}, &sync.Map{}, schema.NamingStrategy{})
+		if err != nil {
+			t.Fatalf("Failed to parse schema, got error %v", err)
+		}
+
+		checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
+			"Cat": {
+				Relations: map[string]Relation{
+					"Toy": {
+						Name:        "Toy",
+						Type:        schema.HasOne,
+						Schema:      "User",
+						FieldSchema: "Toy",
+						Polymorphic: Polymorphic{ID: "OwnerID", Type: "OwnerType", Value: "users"},
+						References: []Reference{
+							{ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"},
+						},
+					},
+				},
+			},
+		})
+	})
+
+	t.Run("has one with custom polymorphic type and id", func(t *testing.T) {
+		type Toy struct {
+			ID    int
+			Name  string
+			RefId int
+			Type  string
+		}
+
+		type Cat struct {
+			ID   int
+			Name string
+			Toy  Toy `gorm:"polymorphic:Owner;polymorphicType:Type;polymorphicId:RefId"`
+		}
+
+		s, err := schema.Parse(&Cat{}, &sync.Map{}, schema.NamingStrategy{})
+		if err != nil {
+			t.Fatalf("Failed to parse schema, got error %v", err)
+		}
+
+		checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
+			"Cat": {
+				Relations: map[string]Relation{
+					"Toy": {
+						Name:        "Toy",
+						Type:        schema.HasOne,
+						Schema:      "User",
+						FieldSchema: "Toy",
+						Polymorphic: Polymorphic{ID: "ref_id", Type: "Type", Value: "users"},
+						References: []Reference{
+							{ForeignKey: "Type", ForeignSchema: "Toy", PrimaryValue: "users"},
+						},
+					},
+				},
+			},
+		})
+	})
+
+	t.Run("has one with only polymorphic type", func(t *testing.T) {
+		type Toy struct {
+			ID      int
+			Name    string
+			OwnerID int
+			Type    string
+		}
+
+		type Cat struct {
+			ID   int
+			Name string
+			Toy  Toy `gorm:"polymorphic:Owner;polymorphicType:Type"`
+		}
+
+		s, err := schema.Parse(&Cat{}, &sync.Map{}, schema.NamingStrategy{})
+		if err != nil {
+			t.Fatalf("Failed to parse schema, got error %v", err)
+		}
+
+		checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
+			"Cat": {
+				Relations: map[string]Relation{
+					"Toy": {
+						Name:        "Toy",
+						Type:        schema.HasOne,
+						Schema:      "User",
+						FieldSchema: "Toy",
+						Polymorphic: Polymorphic{ID: "owner_id", Type: "Type", Value: "users"},
+						References: []Reference{
+							{ForeignKey: "Type", ForeignSchema: "Toy", PrimaryValue: "users"},
+						},
+					},
+				},
+			},
+		})
+	})
+
+	t.Run("has many", func(t *testing.T) {
+		type Toy struct {
+			ID        int
+			Name      string
+			OwnerID   int
+			OwnerType string
+		}
+
+		type Cat struct {
+			ID   int
+			Name string
+			Toys []Toy `gorm:"polymorphic:Owner;"`
+		}
+
+		s, err := schema.Parse(&Cat{}, &sync.Map{}, schema.NamingStrategy{})
+		if err != nil {
+			t.Fatalf("Failed to parse schema, got error %v", err)
+		}
+
+		checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
+			"Cat": {
+				Relations: map[string]Relation{
+					"Toys": {
+						Name:        "Toys",
+						Type:        schema.HasMany,
+						Schema:      "User",
+						FieldSchema: "Toy",
+						Polymorphic: Polymorphic{ID: "OwnerID", Type: "OwnerType", Value: "users"},
+						References: []Reference{
+							{ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"},
+						},
+					},
+				},
+			},
+		})
+	})
+
+	t.Run("has many with custom polymorphic type and id", func(t *testing.T) {
+		type Toy struct {
+			ID    int
+			Name  string
+			RefId int
+			Type  string
+		}
+
+		type Cat struct {
+			ID   int
+			Name string
+			Toys []Toy `gorm:"polymorphicType:Type;polymorphicId:RefId"`
+		}
+
+		s, err := schema.Parse(&Cat{}, &sync.Map{}, schema.NamingStrategy{})
+		if err != nil {
+			t.Fatalf("Failed to parse schema, got error %v", err)
+		}
+
+		checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
+			"Cat": {
+				Relations: map[string]Relation{
+					"Toys": {
+						Name:        "Toys",
+						Type:        schema.HasMany,
+						Schema:      "User",
+						FieldSchema: "Toy",
+						Polymorphic: Polymorphic{ID: "ref_id", Type: "Type", Value: "users"},
+						References: []Reference{
+							{ForeignKey: "Type", ForeignSchema: "Toy", PrimaryValue: "users"},
+						},
+					},
+				},
+			},
+		})
+	})
+}
+
+func TestEmbeddedBelongsTo(t *testing.T) {
+	type Country struct {
+		ID   int `gorm:"primaryKey"`
+		Name string
+	}
+	type Address struct {
+		CountryID int
+		Country   Country
+	}
+	type NestedAddress struct {
+		Address
+	}
+	type Org struct {
+		ID              int
+		PostalAddress   Address `gorm:"embedded;embeddedPrefix:postal_address_"`
+		VisitingAddress Address `gorm:"embedded;embeddedPrefix:visiting_address_"`
+		AddressID       int
+		Address         struct {
+			ID int
+			Address
+		}
+		NestedAddress *NestedAddress `gorm:"embedded;embeddedPrefix:nested_address_"`
+	}
+
+	s, err := schema.Parse(&Org{}, &sync.Map{}, schema.NamingStrategy{})
+	if err != nil {
+		t.Errorf("Failed to parse schema, got error %v", err)
+	}
+
+	checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
+		"PostalAddress": {
+			Relations: map[string]Relation{
+				"Country": {
+					Name: "Country", Type: schema.BelongsTo, Schema: "Org", FieldSchema: "Country",
+					References: []Reference{
+						{PrimaryKey: "ID", PrimarySchema: "Country", ForeignKey: "CountryID", ForeignSchema: "Org"},
+					},
+				},
+			},
+		},
+		"VisitingAddress": {
+			Relations: map[string]Relation{
+				"Country": {
+					Name: "Country", Type: schema.BelongsTo, Schema: "Org", FieldSchema: "Country",
+					References: []Reference{
+						{PrimaryKey: "ID", PrimarySchema: "Country", ForeignKey: "CountryID", ForeignSchema: "Org"},
+					},
+				},
+			},
+		},
+		"NestedAddress": {
+			EmbeddedRelations: map[string]EmbeddedRelations{
+				"Address": {
+					Relations: map[string]Relation{
+						"Country": {
+							Name: "Country", Type: schema.BelongsTo, Schema: "Org", FieldSchema: "Country",
+							References: []Reference{
+								{PrimaryKey: "ID", PrimarySchema: "Country", ForeignKey: "CountryID", ForeignSchema: "Org"},
+							},
+						},
+					},
+				},
+			},
+		},
+	})
+}
+
 func TestVariableRelation(t *testing.T) {
 	var result struct {
 		User
@@ -615,7 +955,7 @@ func TestParseConstraintNameWithSchemaQualifiedLongTableName(t *testing.T) {
 	s, err := schema.Parse(
 		&Book{},
 		&sync.Map{},
-		schema.NamingStrategy{},
+		schema.NamingStrategy{IdentifierMaxLength: 64},
 	)
 	if err != nil {
 		t.Fatalf("Failed to parse schema")
138 schema/schema.go
@@ -6,12 +6,27 @@ import (
 	"fmt"
 	"go/ast"
 	"reflect"
+	"strings"
 	"sync"

 	"gorm.io/gorm/clause"
 	"gorm.io/gorm/logger"
 )

+type callbackType string
+
+const (
+	callbackTypeBeforeCreate callbackType = "BeforeCreate"
+	callbackTypeBeforeUpdate callbackType = "BeforeUpdate"
+	callbackTypeAfterCreate  callbackType = "AfterCreate"
+	callbackTypeAfterUpdate  callbackType = "AfterUpdate"
+	callbackTypeBeforeSave   callbackType = "BeforeSave"
+	callbackTypeAfterSave    callbackType = "AfterSave"
+	callbackTypeBeforeDelete callbackType = "BeforeDelete"
+	callbackTypeAfterDelete  callbackType = "AfterDelete"
+	callbackTypeAfterFind    callbackType = "AfterFind"
+)
+
 // ErrUnsupportedDataType unsupported data type
 var ErrUnsupportedDataType = errors.New("unsupported data type")
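These typed constants replace the ad-hoc string slice used further down when a model is scanned for hook methods; a hook is still just a method with the func(*gorm.DB) error signature. A minimal sketch of a model hook that this scan would pick up (model and field names are illustrative):

package main

import "gorm.io/gorm"

type User struct {
	ID   uint
	Name string
}

// BeforeCreate matches the expected `func(*gorm.DB) error` hook signature,
// so schema parsing marks the BeforeCreate callback for this model.
func (u *User) BeforeCreate(tx *gorm.DB) error {
	if u.Name == "" {
		return gorm.ErrInvalidData
	}
	return nil
}

func main() {}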
@@ -25,6 +40,7 @@ type Schema struct {
 	PrimaryFieldDBNames      []string
 	Fields                   []*Field
 	FieldsByName             map[string]*Field
+	FieldsByBindName         map[string]*Field // embedded fields is 'Embed.Field'
 	FieldsByDBName           map[string]*Field
 	FieldsWithDefaultDBValue []*Field // fields with default value assigned by database
 	Relationships            Relationships
@@ -67,10 +83,35 @@ func (schema Schema) LookUpField(name string) *Field {
 	return nil
 }

+// LookUpFieldByBindName looks for the closest field in the embedded struct.
+//
+//	type Struct struct {
+//		Embedded struct {
+//			ID string // is selected by LookUpFieldByBindName([]string{"Embedded", "ID"}, "ID")
+//		}
+//		ID string // is selected by LookUpFieldByBindName([]string{"ID"}, "ID")
+//	}
+func (schema Schema) LookUpFieldByBindName(bindNames []string, name string) *Field {
+	if len(bindNames) == 0 {
+		return nil
+	}
+	for i := len(bindNames) - 1; i >= 0; i-- {
+		find := strings.Join(bindNames[:i], ".") + "." + name
+		if field, ok := schema.FieldsByBindName[find]; ok {
+			return field
+		}
+	}
+	return nil
+}
+
 type Tabler interface {
 	TableName() string
 }

+type TablerWithNamer interface {
+	TableName(Namer) string
+}
+
 // Parse get data type from dialector
 func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) {
 	return ParseWithSpecialTableName(dest, cacheStore, namer, "")
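TablerWithNamer lets a model compute its table name through the configured naming strategy instead of hard-coding it; ParseWithSpecialTableName checks for it right after the plain Tabler interface, as shown in the next hunk. A small sketch (the legacy_ prefix is made up for illustration):

package main

import "gorm.io/gorm/schema"

type Order struct {
	ID uint
}

// TableName(Namer) implements schema.TablerWithNamer: the table name is derived
// via the active naming strategy rather than being a fixed string.
func (Order) TableName(namer schema.Namer) string {
	if namer == nil {
		return "orders"
	}
	return "legacy_" + namer.TableName("Order")
}

func main() {}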
@@ -112,7 +153,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Namer, specialTableName string) (*Schema, error) {
 		schemaCacheKey = modelType
 	}

-	// Load exist schmema cache, return if exists
+	// Load exist schema cache, return if exists
 	if v, ok := cacheStore.Load(schemaCacheKey); ok {
 		s := v.(*Schema)
 		// Wait for the initialization of other goroutines to complete

@@ -125,6 +166,9 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Namer, specialTableName string) (*Schema, error) {
 	if tabler, ok := modelValue.Interface().(Tabler); ok {
 		tableName = tabler.TableName()
 	}
+	if tabler, ok := modelValue.Interface().(TablerWithNamer); ok {
+		tableName = tabler.TableName(namer)
+	}
 	if en, ok := namer.(embeddedNamer); ok {
 		tableName = en.Table
 	}

@@ -133,20 +177,21 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Namer, specialTableName string) (*Schema, error) {
 	}

 	schema := &Schema{
 		Name:             modelType.Name(),
 		ModelType:        modelType,
 		Table:            tableName,
 		FieldsByName:     map[string]*Field{},
+		FieldsByBindName: map[string]*Field{},
 		FieldsByDBName:   map[string]*Field{},
 		Relationships:    Relationships{Relations: map[string]*Relationship{}},
 		cacheStore:       cacheStore,
 		namer:            namer,
 		initialized:      make(chan struct{}),
 	}
 	// When the schema initialization is completed, the channel will be closed
 	defer close(schema.initialized)

-	// Load exist schmema cache, return if exists
+	// Load exist schema cache, return if exists
 	if v, ok := cacheStore.Load(schemaCacheKey); ok {
 		s := v.(*Schema)
 		// Wait for the initialization of other goroutines to complete

@@ -169,6 +214,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Namer, specialTableName string) (*Schema, error) {
 			field.DBName = namer.ColumnName(schema.Table, field.Name)
 		}

+		bindName := field.BindName()
 		if field.DBName != "" {
 			// nonexistence or shortest path or first appear prioritized if has permission
 			if v, ok := schema.FieldsByDBName[field.DBName]; !ok || ((field.Creatable || field.Updatable || field.Readable) && len(field.BindNames) < len(v.BindNames)) {

@@ -177,6 +223,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Namer, specialTableName string) (*Schema, error) {
 			}
 				schema.FieldsByDBName[field.DBName] = field
 				schema.FieldsByName[field.Name] = field
+				schema.FieldsByBindName[bindName] = field

 				if v != nil && v.PrimaryKey {
 					for idx, f := range schema.PrimaryFields {

@@ -195,6 +242,9 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Namer, specialTableName string) (*Schema, error) {
 		if of, ok := schema.FieldsByName[field.Name]; !ok || of.TagSettings["-"] == "-" {
 			schema.FieldsByName[field.Name] = field
 		}
+		if of, ok := schema.FieldsByBindName[bindName]; !ok || of.TagSettings["-"] == "-" {
+			schema.FieldsByBindName[bindName] = field
+		}

 		field.setupValuerAndSetter()
 	}

@@ -214,8 +264,18 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Namer, specialTableName string) (*Schema, error) {
 		}
 	}

-	if schema.PrioritizedPrimaryField == nil && len(schema.PrimaryFields) == 1 {
-		schema.PrioritizedPrimaryField = schema.PrimaryFields[0]
+	if schema.PrioritizedPrimaryField == nil {
+		if len(schema.PrimaryFields) == 1 {
+			schema.PrioritizedPrimaryField = schema.PrimaryFields[0]
+		} else if len(schema.PrimaryFields) > 1 {
+			// If there are multiple primary keys, the AUTOINCREMENT field is prioritized
+			for _, field := range schema.PrimaryFields {
+				if field.AutoIncrement {
+					schema.PrioritizedPrimaryField = field
+					break
+				}
+			}
+		}
 	}

 	for _, field := range schema.PrimaryFields {

@@ -223,7 +283,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Namer, specialTableName string) (*Schema, error) {
 	}

 	for _, field := range schema.Fields {
-		if field.HasDefaultValue && field.DefaultValueInterface == nil {
+		if field.DataType != "" && field.HasDefaultValue && field.DefaultValueInterface == nil {
 			schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field)
 		}
 	}

@@ -242,14 +302,20 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Namer, specialTableName string) (*Schema, error) {
 		}
 	}

-	callbacks := []string{"BeforeCreate", "AfterCreate", "BeforeUpdate", "AfterUpdate", "BeforeSave", "AfterSave", "BeforeDelete", "AfterDelete", "AfterFind"}
-	for _, name := range callbacks {
-		if methodValue := modelValue.MethodByName(name); methodValue.IsValid() {
+	callbackTypes := []callbackType{
+		callbackTypeBeforeCreate, callbackTypeAfterCreate,
+		callbackTypeBeforeUpdate, callbackTypeAfterUpdate,
+		callbackTypeBeforeSave, callbackTypeAfterSave,
+		callbackTypeBeforeDelete, callbackTypeAfterDelete,
+		callbackTypeAfterFind,
+	}
+	for _, cbName := range callbackTypes {
+		if methodValue := callBackToMethodValue(modelValue, cbName); methodValue.IsValid() {
 			switch methodValue.Type().String() {
 			case "func(*gorm.DB) error": // TODO hack
-				reflect.Indirect(reflect.ValueOf(schema)).FieldByName(name).SetBool(true)
+				reflect.Indirect(reflect.ValueOf(schema)).FieldByName(string(cbName)).SetBool(true)
 			default:
-				logger.Default.Warn(context.Background(), "Model %v don't match %vInterface, should be `%v(*gorm.DB) error`. Please see https://gorm.io/docs/hooks.html", schema, name, name)
+				logger.Default.Warn(context.Background(), "Model %v don't match %vInterface, should be `%v(*gorm.DB) error`. Please see https://gorm.io/docs/hooks.html", schema, cbName, cbName)
 			}
 		}
 	}

@@ -276,6 +342,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Namer, specialTableName string) (*Schema, error) {
 					return schema, schema.err
 				} else {
 					schema.FieldsByName[field.Name] = field
+					schema.FieldsByBindName[field.BindName()] = field
 				}
 			}

@@ -302,6 +369,39 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Namer, specialTableName string) (*Schema, error) {
 	return schema, schema.err
 }

+// This unrolling is needed to show to the compiler the exact set of methods
+// that can be used on the modelType.
+// Prior to go1.22 any use of MethodByName would cause the linker to
+// abandon dead code elimination for the entire binary.
+// As of go1.22 the compiler supports one special case of a string constant
+// being passed to MethodByName. For enterprise customers or those building
+// large binaries, this gives a significant reduction in binary size.
+// https://github.com/golang/go/issues/62257
+func callBackToMethodValue(modelType reflect.Value, cbType callbackType) reflect.Value {
+	switch cbType {
+	case callbackTypeBeforeCreate:
+		return modelType.MethodByName(string(callbackTypeBeforeCreate))
+	case callbackTypeAfterCreate:
+		return modelType.MethodByName(string(callbackTypeAfterCreate))
+	case callbackTypeBeforeUpdate:
+		return modelType.MethodByName(string(callbackTypeBeforeUpdate))
+	case callbackTypeAfterUpdate:
+		return modelType.MethodByName(string(callbackTypeAfterUpdate))
+	case callbackTypeBeforeSave:
+		return modelType.MethodByName(string(callbackTypeBeforeSave))
+	case callbackTypeAfterSave:
+		return modelType.MethodByName(string(callbackTypeAfterSave))
+	case callbackTypeBeforeDelete:
+		return modelType.MethodByName(string(callbackTypeBeforeDelete))
+	case callbackTypeAfterDelete:
+		return modelType.MethodByName(string(callbackTypeAfterDelete))
+	case callbackTypeAfterFind:
+		return modelType.MethodByName(string(callbackTypeAfterFind))
+	default:
+		return reflect.ValueOf(nil)
+	}
+}
+
 func getOrParse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) {
 	modelType := reflect.ValueOf(dest).Type()
 	for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr {
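The comment on callBackToMethodValue explains why the hook lookup is unrolled: reflect's MethodByName normally disables linker dead code elimination, but Go 1.22 special-cases calls whose argument is a string constant. A tiny standalone illustration of the difference (not GORM code):

package main

import (
	"fmt"
	"reflect"
)

type Model struct{}

func (Model) BeforeCreate() { fmt.Println("hook") }

func main() {
	v := reflect.ValueOf(Model{})

	// Constant-string lookup: since Go 1.22 the linker can see exactly which
	// method is requested, so dead code elimination stays enabled.
	if m := v.MethodByName("BeforeCreate"); m.IsValid() {
		m.Call(nil)
	}

	// Dynamic lookup: the name is only known at run time, so the linker must
	// keep every method of every type, disabling dead code elimination.
	name := fmt.Sprintf("Before%s", "Create")
	_ = v.MethodByName(name)
}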
@@ -163,8 +163,8 @@ func checkSchemaRelation(t *testing.T, s *schema.Schema, relation Relation) {
 				t.Errorf("schema %v relation's join table tablename expects %v, but got %v", s, relation.JoinTable.Table, r.JoinTable.Table)
 			}

-			for _, f := range relation.JoinTable.Fields {
-				checkSchemaField(t, r.JoinTable, &f, nil)
+			for i := range relation.JoinTable.Fields {
+				checkSchemaField(t, r.JoinTable, &relation.JoinTable.Fields[i], nil)
 			}
 		}

@@ -201,6 +201,37 @@ func checkSchemaRelation(t *testing.T, s *schema.Schema, relation Relation) {
 	})
 }

+type EmbeddedRelations struct {
+	Relations         map[string]Relation
+	EmbeddedRelations map[string]EmbeddedRelations
+}
+
+func checkEmbeddedRelations(t *testing.T, actual map[string]*schema.Relationships, expected map[string]EmbeddedRelations) {
+	for name, relations := range actual {
+		rs := expected[name]
+		t.Run("CheckEmbeddedRelations/"+name, func(t *testing.T) {
+			if len(relations.Relations) != len(rs.Relations) {
+				t.Errorf("schema relations count don't match, expects %d, got %d", len(rs.Relations), len(relations.Relations))
+			}
+			if len(relations.EmbeddedRelations) != len(rs.EmbeddedRelations) {
+				t.Errorf("schema embedded relations count don't match, expects %d, got %d", len(rs.EmbeddedRelations), len(relations.EmbeddedRelations))
+			}
+			for n, rel := range relations.Relations {
+				if r, ok := rs.Relations[n]; !ok {
+					t.Errorf("failed to find relation by name %s", n)
+				} else {
+					checkSchemaRelation(t, &schema.Schema{
+						Relationships: schema.Relationships{
+							Relations: map[string]*schema.Relationship{n: rel},
+						},
+					}, r)
+				}
+			}
+			checkEmbeddedRelations(t, relations.EmbeddedRelations, rs.EmbeddedRelations)
+		})
+	}
+}
+
 func checkField(t *testing.T, s *schema.Schema, value reflect.Value, values map[string]interface{}) {
 	for k, v := range values {
 		t.Run("CheckField/"+k, func(t *testing.T) {
@@ -46,8 +46,8 @@ func checkUserSchema(t *testing.T, user *schema.Schema) {
 		{Name: "Active", DBName: "active", BindNames: []string{"Active"}, DataType: schema.Bool},
 	}

-	for _, f := range fields {
-		checkSchemaField(t, user, &f, func(f *schema.Field) {
+	for i := range fields {
+		checkSchemaField(t, user, &fields[i], func(f *schema.Field) {
 			f.Creatable = true
 			f.Updatable = true
 			f.Readable = true

@@ -136,8 +136,8 @@ func TestParseSchemaWithAdvancedDataType(t *testing.T) {
 		{Name: "Admin", DBName: "admin", BindNames: []string{"Admin"}, DataType: schema.Bool},
 	}

-	for _, f := range fields {
-		checkSchemaField(t, user, &f, func(f *schema.Field) {
+	for i := range fields {
+		checkSchemaField(t, user, &fields[i], func(f *schema.Field) {
 			f.Creatable = true
 			f.Updatable = true
 			f.Readable = true

@@ -293,3 +293,44 @@ func TestEmbeddedStructForCustomizedNamingStrategy(t *testing.T) {
 		})
 	}
 }
+
+func TestCompositePrimaryKeyWithAutoIncrement(t *testing.T) {
+	type Product struct {
+		ProductID    uint `gorm:"primaryKey;autoIncrement"`
+		LanguageCode uint `gorm:"primaryKey"`
+		Code         string
+		Name         string
+	}
+	type ProductNonAutoIncrement struct {
+		ProductID    uint `gorm:"primaryKey;autoIncrement:false"`
+		LanguageCode uint `gorm:"primaryKey"`
+		Code         string
+		Name         string
+	}
+
+	product, err := schema.Parse(&Product{}, &sync.Map{}, schema.NamingStrategy{})
+	if err != nil {
+		t.Fatalf("failed to parse product struct with composite primary key, got error %v", err)
+	}
+
+	prioritizedPrimaryField := schema.Field{
+		Name: "ProductID", DBName: "product_id", BindNames: []string{"ProductID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, HasDefaultValue: true, AutoIncrement: true, TagSettings: map[string]string{"PRIMARYKEY": "PRIMARYKEY", "AUTOINCREMENT": "AUTOINCREMENT"},
+	}
+
+	product.Fields = []*schema.Field{product.PrioritizedPrimaryField}
+
+	checkSchemaField(t, product, &prioritizedPrimaryField, func(f *schema.Field) {
+		f.Creatable = true
+		f.Updatable = true
+		f.Readable = true
+	})
+
+	productNonAutoIncrement, err := schema.Parse(&ProductNonAutoIncrement{}, &sync.Map{}, schema.NamingStrategy{})
+	if err != nil {
+		t.Fatalf("failed to parse productNonAutoIncrement struct with composite primary key, got error %v", err)
+	}
+
+	if productNonAutoIncrement.PrioritizedPrimaryField != nil {
+		t.Fatalf("PrioritizedPrimaryField of non autoincrement composite key should be nil")
+	}
+}
@ -70,8 +70,8 @@ type SerializerValuerInterface interface {
}

// JSONSerializer json serializer
-type JSONSerializer struct {
-}
+type JSONSerializer struct{}

// Scan implements serializer interface
func (JSONSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) (err error) {
@ -88,7 +87,9 @@ func (JSONSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value,
return fmt.Errorf("failed to unmarshal JSONB value: %#v", dbValue)
}

-err = json.Unmarshal(bytes, fieldValue.Interface())
+if len(bytes) > 0 {
+err = json.Unmarshal(bytes, fieldValue.Interface())
+}
}

field.ReflectValueOf(ctx, dst).Set(fieldValue.Elem())
@ -98,12 +99,17 @@ func (JSONSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value,
// Value implements serializer interface
func (JSONSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) {
result, err := json.Marshal(fieldValue)
+if string(result) == "null" {
+if field.TagSettings["NOT NULL"] != "" {
+return "", nil
+}
+return nil, err
+}
return string(result), err
}

// UnixSecondSerializer json serializer
-type UnixSecondSerializer struct {
-}
+type UnixSecondSerializer struct{}

// Scan implements serializer interface
func (UnixSecondSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) (err error) {
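Note on the JSONSerializer change above: a value that marshals to "null" is now written as an empty string when the column carries a NOT NULL tag setting, and as SQL NULL otherwise. A minimal model-side sketch, assuming the serializer is registered under GORM's built-in name "json":

    type Settings struct {
        Theme string `json:"theme"`
    }

    type Profile struct {
        ID uint
        // a nil pointer marshals to "null"; with `not null` the serializer now stores "" instead of NULL
        Settings *Settings `gorm:"serializer:json;not null"`
    }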
@ -117,9 +123,15 @@ func (UnixSecondSerializer) Scan(ctx context.Context, field *Field, dst reflect.

// Value implements serializer interface
func (UnixSecondSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (result interface{}, err error) {
+rv := reflect.ValueOf(fieldValue)
switch v := fieldValue.(type) {
-case int64, int, uint, uint64, int32, uint32, int16, uint16, *int64, *int, *uint, *uint64, *int32, *uint32, *int16, *uint16:
+case int64, int, uint, uint64, int32, uint32, int16, uint16:
-result = time.Unix(reflect.Indirect(reflect.ValueOf(v)).Int(), 0)
+result = time.Unix(reflect.Indirect(rv).Int(), 0).UTC()
+case *int64, *int, *uint, *uint64, *int32, *uint32, *int16, *uint16:
+if rv.IsZero() {
+return nil, nil
+}
+result = time.Unix(reflect.Indirect(rv).Int(), 0).UTC()
default:
err = fmt.Errorf("invalid field type %#v for UnixSecondSerializer, only int, uint supported", v)
}
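Note on the UnixSecondSerializer change above: plain integer fields are converted with time.Unix(...).UTC(), while nil pointer integer fields now produce a NULL value instead of the zero epoch. A sketch, assuming the serializer is registered under the built-in name "unixtime":

    type Event struct {
        ID       uint
        // written as a timestamp derived from the unix-seconds integer
        HappenAt int64 `gorm:"serializer:unixtime;type:datetime"`
        // a nil pointer is now stored as NULL rather than 1970-01-01 00:00:00 UTC
        EndAt *int64 `gorm:"serializer:unixtime;type:datetime"`
    }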
@ -127,8 +139,7 @@ func (UnixSecondSerializer) Value(ctx context.Context, field *Field, dst reflect
}

// GobSerializer gob serializer
-type GobSerializer struct {
-}
+type GobSerializer struct{}

// Scan implements serializer interface
func (GobSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) (err error) {

@ -142,8 +153,10 @@ func (GobSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value,
default:
return fmt.Errorf("failed to unmarshal gob value: %#v", dbValue)
}
-decoder := gob.NewDecoder(bytes.NewBuffer(bytesValue))
-err = decoder.Decode(fieldValue.Interface())
+if len(bytesValue) > 0 {
+decoder := gob.NewDecoder(bytes.NewBuffer(bytesValue))
+err = decoder.Decode(fieldValue.Interface())
+}
}
field.ReflectValueOf(ctx, dst).Set(fieldValue.Elem())
return
@ -2,6 +2,7 @@ package schema

import (
"context"
+"fmt"
"reflect"
"regexp"
"strings"
@ -59,6 +60,14 @@ func removeSettingFromTag(tag reflect.StructTag, names ...string) reflect.StructTag {
return tag
}

+func appendSettingFromTag(tag reflect.StructTag, value string) reflect.StructTag {
+t := tag.Get("gorm")
+if strings.Contains(t, value) {
+return tag
+}
+return reflect.StructTag(fmt.Sprintf(`gorm:"%s;%s"`, value, t))
+}
+
// GetRelationsValues get relations's values from a reflect value
func GetRelationsValues(ctx context.Context, reflectValue reflect.Value, rels []*Relationship) (reflectResults reflect.Value) {
for _, rel := range rels {
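Illustration of the helper added above (it is unexported, so only callable from inside package schema); the tag values below are only for demonstration:

    tag := reflect.StructTag(`gorm:"column:name"`)
    tag = appendSettingFromTag(tag, "embedded")
    // tag is now `gorm:"embedded;column:name"`; a second call with "embedded" returns the tag unchanged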
@ -106,6 +115,11 @@ func GetIdentityFieldValuesMap(ctx context.Context, reflectValue reflect.Value,
notZero, zero bool
)

+if reflectValue.Kind() == reflect.Ptr ||
+reflectValue.Kind() == reflect.Interface {
+reflectValue = reflectValue.Elem()
+}
+
switch reflectValue.Kind() {
case reflect.Struct:
results = [][]interface{}{make([]interface{}, len(fields))}

@ -124,7 +138,7 @@ func GetIdentityFieldValuesMap(ctx context.Context, reflectValue reflect.Value,
for i := 0; i < reflectValue.Len(); i++ {
elem := reflectValue.Index(i)
elemKey := elem.Interface()
-if elem.Kind() != reflect.Ptr {
+if elem.Kind() != reflect.Ptr && elem.CanAddr() {
elemKey = elem.Addr().Interface()
}
@ -6,6 +6,7 @@ import (
"encoding/json"
"reflect"

+"github.com/jinzhu/now"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
)
@ -45,11 +46,21 @@ func (n *DeletedAt) UnmarshalJSON(b []byte) error {
}

func (DeletedAt) QueryClauses(f *schema.Field) []clause.Interface {
-return []clause.Interface{SoftDeleteQueryClause{Field: f}}
+return []clause.Interface{SoftDeleteQueryClause{Field: f, ZeroValue: parseZeroValueTag(f)}}
+}
+
+func parseZeroValueTag(f *schema.Field) sql.NullString {
+if v, ok := f.TagSettings["ZEROVALUE"]; ok {
+if _, err := now.Parse(v); err == nil {
+return sql.NullString{String: v, Valid: true}
+}
+}
+return sql.NullString{Valid: false}
}

type SoftDeleteQueryClause struct {
-Field *schema.Field
+ZeroValue sql.NullString
+Field *schema.Field
}

func (sd SoftDeleteQueryClause) Name() string {
@ -78,18 +89,19 @@ func (sd SoftDeleteQueryClause) ModifyStatement(stmt *Statement) {
}

stmt.AddClause(clause.Where{Exprs: []clause.Expression{
-clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: sd.Field.DBName}, Value: nil},
+clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: sd.Field.DBName}, Value: sd.ZeroValue},
}})
stmt.Clauses["soft_delete_enabled"] = clause.Clause{}
}
}

func (DeletedAt) UpdateClauses(f *schema.Field) []clause.Interface {
-return []clause.Interface{SoftDeleteUpdateClause{Field: f}}
+return []clause.Interface{SoftDeleteUpdateClause{Field: f, ZeroValue: parseZeroValueTag(f)}}
}

type SoftDeleteUpdateClause struct {
-Field *schema.Field
+ZeroValue sql.NullString
+Field *schema.Field
}

func (sd SoftDeleteUpdateClause) Name() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (DeletedAt) DeleteClauses(f *schema.Field) []clause.Interface {
|
func (DeletedAt) DeleteClauses(f *schema.Field) []clause.Interface {
|
||||||
return []clause.Interface{SoftDeleteDeleteClause{Field: f}}
|
return []clause.Interface{SoftDeleteDeleteClause{Field: f, ZeroValue: parseZeroValueTag(f)}}
|
||||||
}
|
}
|
||||||
|
|
||||||
type SoftDeleteDeleteClause struct {
|
type SoftDeleteDeleteClause struct {
|
||||||
Field *schema.Field
|
ZeroValue sql.NullString
|
||||||
|
Field *schema.Field
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sd SoftDeleteDeleteClause) Name() string {
|
func (sd SoftDeleteDeleteClause) Name() string {
|
||||||
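Note on the soft-delete changes above: all three clauses now carry a ZeroValue parsed from the field's ZEROVALUE tag setting, so a model can use a sentinel timestamp instead of NULL to mean "not deleted". A sketch, assuming the tag is spelled zeroValue:... so it lands in TagSettings["ZEROVALUE"] and the value parses with github.com/jinzhu/now:

    type Order struct {
        ID        uint
        DeletedAt gorm.DeletedAt `gorm:"zeroValue:1970-01-01 00:00:01"`
    }

    // queries for this model would then filter on
    // `orders`.`deleted_at` = '1970-01-01 00:00:01' instead of `deleted_at` IS NULL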
101 statement.go
@ -49,9 +49,12 @@ type Statement struct {
}

type join struct {
Name string
Conds []interface{}
On *clause.Where
+Selects []string
+Omits []string
+JoinType clause.JoinType
}

// StatementModifier statement modifier interface
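Note on the join struct above: it now records per-join Selects, Omits and a JoinType. A hedged sketch of the chainable call that would presumably set the join type (InnerJoins is assumed here; it is not shown in this diff):

    db.InnerJoins("Company").Find(&users) // association join emitted as an INNER JOIN instead of a LEFT JOIN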
@ -117,6 +120,8 @@ func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) {
write(v.Raw, stmt.Schema.PrioritizedPrimaryField.DBName)
} else if len(stmt.Schema.DBNames) > 0 {
write(v.Raw, stmt.Schema.DBNames[0])
+} else {
+stmt.DB.AddError(ErrModelAccessibleFieldsRequired) //nolint:typecheck,errcheck
}
} else {
write(v.Raw, v.Name)
@ -179,6 +184,10 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) {
} else {
stmt.AddVar(writer, v.GormValue(stmt.Context, stmt.DB))
}
+case clause.Interface:
+c := clause.Clause{Name: v.Name()}
+v.MergeClause(&c)
+c.Build(stmt)
case clause.Expression:
v.Build(stmt)
case driver.Valuer:
@ -304,6 +313,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
conds := make([]clause.Expression, 0, 4)
args = append([]interface{}{query}, args...)
for idx, arg := range args {
+if arg == nil {
+continue
+}
if valuer, ok := arg.(driver.Valuer); ok {
arg, _ = valuer.Value()
}
@ -312,9 +324,7 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
case clause.Expression:
conds = append(conds, v)
case *DB:
-for _, scope := range v.Statement.scopes {
-v = scope(v)
-}
+v.executeScopes()

if cs, ok := v.Statement.Clauses["WHERE"]; ok {
if where, ok := cs.Expression.(clause.Where); ok {
@ -437,8 +447,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []

if len(values) > 0 {
conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: values})
+return []clause.Expression{clause.And(conds...)}
}
-return conds
+return nil
}
}

@ -447,7 +458,10 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
}
}

-return conds
+if len(conds) > 0 {
+return []clause.Expression{clause.And(conds...)}
+}
+return nil
}

// Build build sql with clauses names
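Note on the two BuildCondition changes above: multiple expressions are now wrapped in clause.And before returning, so composed conditions stay grouped in parentheses. For example, with the public group-conditions pattern:

    db.Where(
        db.Where("pizza = ?", "pepperoni").Or("pizza = ?", "hawaiian"),
    ).Where("size = ?", "large").Find(&pizzas)
    // SELECT ... WHERE (pizza = 'pepperoni' OR pizza = 'hawaiian') AND size = 'large'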
@ -540,8 +554,9 @@ func (stmt *Statement) clone() *Statement {
}

// SetColumn set column's value
-// stmt.SetColumn("Name", "jinzhu") // Hooks Method
-// stmt.SetColumn("Name", "jinzhu", true) // Callbacks Method
+//
+// stmt.SetColumn("Name", "jinzhu") // Hooks Method
+// stmt.SetColumn("Name", "jinzhu", true) // Callbacks Method
func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks ...bool) {
if v, ok := stmt.Dest.(map[string]interface{}); ok {
v[name] = value
@ -650,54 +665,62 @@ func (stmt *Statement) Changed(fields ...string) bool {
return false
}

-var nameMatcher = regexp.MustCompile(`^[\W]?(?:[a-z_]+?)[\W]?\.[\W]?([a-z_]+?)[\W]?$`)
+var matchName = func() func(tableColumn string) (table, column string) {
+nameMatcher := regexp.MustCompile(`^(?:\W?(\w+?)\W?\.)?(?:(\*)|\W?(\w+?)\W?)$`)
+return func(tableColumn string) (table, column string) {
+if matches := nameMatcher.FindStringSubmatch(tableColumn); len(matches) == 4 {
+table = matches[1]
+star := matches[2]
+columnName := matches[3]
+if star != "" {
+return table, star
+}
+return table, columnName
+}
+return "", ""
+}
+}()

// SelectAndOmitColumns get select and omit columns, select -> true, omit -> false
func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (map[string]bool, bool) {
results := map[string]bool{}
notRestricted := false

-// select columns
-for _, column := range stmt.Selects {
+processColumn := func(column string, result bool) {
if stmt.Schema == nil {
-results[column] = true
+results[column] = result
} else if column == "*" {
-notRestricted = true
+notRestricted = result
for _, dbName := range stmt.Schema.DBNames {
-results[dbName] = true
+results[dbName] = result
}
} else if column == clause.Associations {
for _, rel := range stmt.Schema.Relationships.Relations {
-results[rel.Name] = true
+results[rel.Name] = result
}
} else if field := stmt.Schema.LookUpField(column); field != nil && field.DBName != "" {
-results[field.DBName] = true
+results[field.DBName] = result
-} else if matches := nameMatcher.FindStringSubmatch(column); len(matches) == 2 {
+} else if table, col := matchName(column); col != "" && (table == stmt.Table || table == "") {
-results[matches[1]] = true
+if col == "*" {
+for _, dbName := range stmt.Schema.DBNames {
+results[dbName] = result
+}
+} else {
+results[col] = result
+}
} else {
-results[column] = true
+results[column] = result
}
}

+// select columns
+for _, column := range stmt.Selects {
+processColumn(column, true)
+}
+
// omit columns
-for _, omit := range stmt.Omits {
-if stmt.Schema == nil {
-results[omit] = false
-} else if omit == "*" {
-for _, dbName := range stmt.Schema.DBNames {
-results[dbName] = false
-}
-} else if omit == clause.Associations {
-for _, rel := range stmt.Schema.Relationships.Relations {
-results[rel.Name] = false
-}
-} else if field := stmt.Schema.LookUpField(omit); field != nil && field.DBName != "" {
-results[field.DBName] = false
-} else if matches := nameMatcher.FindStringSubmatch(omit); len(matches) == 2 {
-results[matches[1]] = false
-} else {
-results[omit] = false
-}
+for _, column := range stmt.Omits {
+processColumn(column, false)
}

if stmt.Schema != nil {
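Note on the refactor above: Select and Omit now share processColumn, and matchName understands an optional table qualifier plus a trailing *, so table-scoped wildcards resolve against the schema's columns. A sketch using the public API (a User model mapped to the users table is assumed):

    db.Model(&User{}).Select("users.*").Find(&users)   // expands to every column of users
    db.Model(&User{}).Omit("users.name").Find(&users)  // omits the table-qualified column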
@ -35,15 +35,36 @@ func TestWhereCloneCorruption(t *testing.T) {
}
}

+func TestNilCondition(t *testing.T) {
+s := new(Statement)
+if len(s.BuildCondition(nil)) != 0 {
+t.Errorf("Nil condition should be empty")
+}
+}
+
func TestNameMatcher(t *testing.T) {
-for k, v := range map[string]string{
+for k, v := range map[string][]string{
-"table.name": "name",
+"table.name": {"table", "name"},
-"`table`.`name`": "name",
+"`table`.`name`": {"table", "name"},
-"'table'.'name'": "name",
+"'table'.'name'": {"table", "name"},
-"'table'.name": "name",
+"'table'.name": {"table", "name"},
+"table1.name_23": {"table1", "name_23"},
+"`table_1`.`name23`": {"table_1", "name23"},
+"'table23'.'name_1'": {"table23", "name_1"},
+"'table23'.name1": {"table23", "name1"},
+"'name1'": {"", "name1"},
+"`name_1`": {"", "name_1"},
+"`Name_1`": {"", "Name_1"},
+"`Table`.`nAme`": {"Table", "nAme"},
+"my_table.*": {"my_table", "*"},
+"`my_table`.*": {"my_table", "*"},
+"User__Company.*": {"User__Company", "*"},
+"`User__Company`.*": {"User__Company", "*"},
+`"User__Company".*`: {"User__Company", "*"},
+`"table"."*"`: {"", ""},
} {
-if matches := nameMatcher.FindStringSubmatch(k); len(matches) < 2 || matches[1] != v {
+if table, column := matchName(k); table != v[0] || column != v[1] {
-t.Errorf("failed to match value: %v, got %v, expect: %v", k, matches, v)
+t.Errorf("failed to match value: %v, got %v, expect: %v", k, []string{table, column}, v)
}
}
}
|
|||||||
@ -3,6 +3,7 @@ package tests_test
import (
"testing"

+"gorm.io/gorm"
. "gorm.io/gorm/utils/tests"
)

||||||
@ -137,6 +138,7 @@ func TestBelongsToAssociation(t *testing.T) {
unexistCompanyID := company.ID + 9999999
user = User{Name: "invalid-user-with-invalid-belongs-to-foreign-key", CompanyID: &unexistCompanyID}
if err := DB.Create(&user).Error; err == nil {
+tidbSkip(t, "not support the foreign key feature")
t.Errorf("should have gotten foreign key violation error")
}
}
||||||
@ -224,3 +226,81 @@ func TestBelongsToAssociationForSlice(t *testing.T) {
|
|||||||
AssertAssociationCount(t, users[0], "Company", 0, "After Delete")
|
AssertAssociationCount(t, users[0], "Company", 0, "After Delete")
|
||||||
AssertAssociationCount(t, users[1], "Company", 1, "After other user Delete")
|
AssertAssociationCount(t, users[1], "Company", 1, "After other user Delete")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestBelongsToDefaultValue(t *testing.T) {
|
||||||
|
type Org struct {
|
||||||
|
ID string
|
||||||
|
}
|
||||||
|
type BelongsToUser struct {
|
||||||
|
OrgID string
|
||||||
|
Org Org `gorm:"default:NULL"`
|
||||||
|
}
|
||||||
|
|
||||||
|
tx := DB.Session(&gorm.Session{})
|
||||||
|
tx.Config.DisableForeignKeyConstraintWhenMigrating = true
|
||||||
|
AssertEqual(t, DB.Config.DisableForeignKeyConstraintWhenMigrating, false)
|
||||||
|
|
||||||
|
tx.Migrator().DropTable(&BelongsToUser{}, &Org{})
|
||||||
|
tx.AutoMigrate(&BelongsToUser{}, &Org{})
|
||||||
|
|
||||||
|
user := &BelongsToUser{
|
||||||
|
Org: Org{
|
||||||
|
ID: "BelongsToUser_Org_1",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
err := DB.Create(&user).Error
|
||||||
|
AssertEqual(t, err, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBelongsToAssociationUnscoped(t *testing.T) {
|
||||||
|
type ItemParent struct {
|
||||||
|
gorm.Model
|
||||||
|
Logo string `gorm:"not null;type:varchar(50)"`
|
||||||
|
}
|
||||||
|
type ItemChild struct {
|
||||||
|
gorm.Model
|
||||||
|
Name string `gorm:"type:varchar(50)"`
|
||||||
|
ItemParentID uint
|
||||||
|
ItemParent ItemParent
|
||||||
|
}
|
||||||
|
|
||||||
|
tx := DB.Session(&gorm.Session{})
|
||||||
|
tx.Migrator().DropTable(&ItemParent{}, &ItemChild{})
|
||||||
|
tx.AutoMigrate(&ItemParent{}, &ItemChild{})
|
||||||
|
|
||||||
|
item := ItemChild{
|
||||||
|
Name: "name",
|
||||||
|
ItemParent: ItemParent{
|
||||||
|
Logo: "logo",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if err := tx.Create(&item).Error; err != nil {
|
||||||
|
t.Fatalf("failed to create items, got error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// test replace
|
||||||
|
if err := tx.Model(&item).Association("ItemParent").Unscoped().Replace(&ItemParent{
|
||||||
|
Logo: "updated logo",
|
||||||
|
}); err != nil {
|
||||||
|
t.Errorf("failed to replace item parent, got error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var parents []ItemParent
|
||||||
|
if err := tx.Find(&parents).Error; err != nil {
|
||||||
|
t.Errorf("failed to find item parent, got error: %v", err)
|
||||||
|
}
|
||||||
|
if len(parents) != 1 {
|
||||||
|
t.Errorf("expected %d parents, got %d", 1, len(parents))
|
||||||
|
}
|
||||||
|
|
||||||
|
// test delete
|
||||||
|
if err := tx.Model(&item).Association("ItemParent").Unscoped().Delete(&parents); err != nil {
|
||||||
|
t.Errorf("failed to delete item parent, got error: %v", err)
|
||||||
|
}
|
||||||
|
if err := tx.Find(&parents).Error; err != nil {
|
||||||
|
t.Errorf("failed to find item parent, got error: %v", err)
|
||||||
|
}
|
||||||
|
if len(parents) != 0 {
|
||||||
|
t.Errorf("expected %d parents, got %d", 0, len(parents))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@ -3,6 +3,7 @@ package tests_test
|
|||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
. "gorm.io/gorm/utils/tests"
|
. "gorm.io/gorm/utils/tests"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -421,7 +422,7 @@ func TestPolymorphicHasManyAssociation(t *testing.T) {
func TestPolymorphicHasManyAssociationForSlice(t *testing.T) {
users := []User{
*GetUser("slice-hasmany-1", Config{Toys: 2}),
-*GetUser("slice-hasmany-2", Config{Toys: 0}),
+*GetUser("slice-hasmany-2", Config{Toys: 0, Tools: 2}),
*GetUser("slice-hasmany-3", Config{Toys: 4}),
}

||||||
@ -429,6 +430,7 @@ func TestPolymorphicHasManyAssociationForSlice(t *testing.T) {

// Count
AssertAssociationCount(t, users, "Toys", 6, "")
+AssertAssociationCount(t, users, "Tools", 2, "")

// Find
var toys []Toy
||||||
@ -436,6 +438,14 @@ func TestPolymorphicHasManyAssociationForSlice(t *testing.T) {
t.Errorf("toys count should be %v, but got %v", 6, len(toys))
}

+// Find Tools (polymorphic with custom type and id)
+var tools []Tools
+DB.Model(&users).Association("Tools").Find(&tools)
+
+if len(tools) != 2 {
+t.Errorf("tools count should be %v, but got %v", 2, len(tools))
+}
+
// Append
DB.Model(&users).Association("Toys").Append(
&Toy{Name: "toy-slice-append-1"},
||||||
@ -471,3 +481,76 @@ func TestPolymorphicHasManyAssociationForSlice(t *testing.T) {
|
|||||||
DB.Model(&users).Association("Toys").Clear()
|
DB.Model(&users).Association("Toys").Clear()
|
||||||
AssertAssociationCount(t, users, "Toys", 0, "After Clear")
|
AssertAssociationCount(t, users, "Toys", 0, "After Clear")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestHasManyAssociationUnscoped(t *testing.T) {
|
||||||
|
type ItemContent struct {
|
||||||
|
gorm.Model
|
||||||
|
ItemID uint `gorm:"not null"`
|
||||||
|
Name string `gorm:"not null;type:varchar(50)"`
|
||||||
|
LanguageCode string `gorm:"not null;type:varchar(2)"`
|
||||||
|
}
|
||||||
|
type Item struct {
|
||||||
|
gorm.Model
|
||||||
|
Logo string `gorm:"not null;type:varchar(50)"`
|
||||||
|
Contents []ItemContent `gorm:"foreignKey:ItemID"`
|
||||||
|
}
|
||||||
|
|
||||||
|
tx := DB.Session(&gorm.Session{})
|
||||||
|
tx.Migrator().DropTable(&ItemContent{}, &Item{})
|
||||||
|
tx.AutoMigrate(&ItemContent{}, &Item{})
|
||||||
|
|
||||||
|
item := Item{
|
||||||
|
Logo: "logo",
|
||||||
|
Contents: []ItemContent{
|
||||||
|
{Name: "name", LanguageCode: "en"},
|
||||||
|
{Name: "ar name", LanguageCode: "ar"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if err := tx.Create(&item).Error; err != nil {
|
||||||
|
t.Fatalf("failed to create items, got error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// test Replace
|
||||||
|
if err := tx.Model(&item).Association("Contents").Unscoped().Replace([]ItemContent{
|
||||||
|
{Name: "updated name", LanguageCode: "en"},
|
||||||
|
{Name: "ar updated name", LanguageCode: "ar"},
|
||||||
|
{Name: "le nom", LanguageCode: "fr"},
|
||||||
|
}); err != nil {
|
||||||
|
t.Errorf("failed to replace item content, got error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if count := tx.Model(&item).Association("Contents").Count(); count != 3 {
|
||||||
|
t.Errorf("expected %d contents, got %d", 3, count)
|
||||||
|
}
|
||||||
|
|
||||||
|
var contents []ItemContent
|
||||||
|
if err := tx.Find(&contents).Error; err != nil {
|
||||||
|
t.Errorf("failed to find contents, got error: %v", err)
|
||||||
|
}
|
||||||
|
if len(contents) != 3 {
|
||||||
|
t.Errorf("expected %d contents, got %d", 3, len(contents))
|
||||||
|
}
|
||||||
|
|
||||||
|
// test delete
|
||||||
|
if err := tx.Model(&item).Association("Contents").Unscoped().Delete(&contents[0]); err != nil {
|
||||||
|
t.Errorf("failed to delete Contents, got error: %v", err)
|
||||||
|
}
|
||||||
|
if count := tx.Model(&item).Association("Contents").Count(); count != 2 {
|
||||||
|
t.Errorf("expected %d contents, got %d", 2, count)
|
||||||
|
}
|
||||||
|
|
||||||
|
// test clear
|
||||||
|
if err := tx.Model(&item).Association("Contents").Unscoped().Clear(); err != nil {
|
||||||
|
t.Errorf("failed to clear contents association, got error: %v", err)
|
||||||
|
}
|
||||||
|
if count := tx.Model(&item).Association("Contents").Count(); count != 0 {
|
||||||
|
t.Errorf("expected %d contents, got %d", 0, count)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx.Find(&contents).Error; err != nil {
|
||||||
|
t.Errorf("failed to find contents, got error: %v", err)
|
||||||
|
}
|
||||||
|
if len(contents) != 0 {
|
||||||
|
t.Errorf("expected %d contents, got %d", 0, len(contents))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@ -1,8 +1,12 @@
package tests_test

import (
+"fmt"
+"sync"
"testing"

+"gorm.io/gorm"
+"gorm.io/gorm/clause"
. "gorm.io/gorm/utils/tests"
)

||||||
@ -94,6 +98,8 @@ func TestMany2ManyAssociation(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestMany2ManyOmitAssociations(t *testing.T) {
|
func TestMany2ManyOmitAssociations(t *testing.T) {
|
||||||
|
tidbSkip(t, "not support the foreign key feature")
|
||||||
|
|
||||||
user := *GetUser("many2many_omit_associations", Config{Languages: 2})
|
user := *GetUser("many2many_omit_associations", Config{Languages: 2})
|
||||||
|
|
||||||
if err := DB.Omit("Languages.*").Create(&user).Error; err == nil {
|
if err := DB.Omit("Languages.*").Create(&user).Error; err == nil {
|
||||||
@ -324,3 +330,96 @@ func TestSingleTableMany2ManyAssociationForSlice(t *testing.T) {
|
|||||||
DB.Model(&users).Association("Team").Clear()
|
DB.Model(&users).Association("Team").Clear()
|
||||||
AssertAssociationCount(t, users, "Team", 0, "After Clear")
|
AssertAssociationCount(t, users, "Team", 0, "After Clear")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDuplicateMany2ManyAssociation(t *testing.T) {
|
||||||
|
user1 := User{Name: "TestDuplicateMany2ManyAssociation-1", Languages: []Language{
|
||||||
|
{Code: "TestDuplicateMany2ManyAssociation-language-1"},
|
||||||
|
{Code: "TestDuplicateMany2ManyAssociation-language-2"},
|
||||||
|
}}
|
||||||
|
|
||||||
|
user2 := User{Name: "TestDuplicateMany2ManyAssociation-1", Languages: []Language{
|
||||||
|
{Code: "TestDuplicateMany2ManyAssociation-language-1"},
|
||||||
|
{Code: "TestDuplicateMany2ManyAssociation-language-3"},
|
||||||
|
}}
|
||||||
|
users := []*User{&user1, &user2}
|
||||||
|
var err error
|
||||||
|
err = DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(users).Error
|
||||||
|
AssertEqual(t, nil, err)
|
||||||
|
|
||||||
|
var findUser1 User
|
||||||
|
err = DB.Preload("Languages").Where("id = ?", user1.ID).First(&findUser1).Error
|
||||||
|
AssertEqual(t, nil, err)
|
||||||
|
AssertEqual(t, user1, findUser1)
|
||||||
|
|
||||||
|
var findUser2 User
|
||||||
|
err = DB.Preload("Languages").Where("id = ?", user2.ID).First(&findUser2).Error
|
||||||
|
AssertEqual(t, nil, err)
|
||||||
|
AssertEqual(t, user2, findUser2)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConcurrentMany2ManyAssociation(t *testing.T) {
|
||||||
|
db, err := OpenTestConnection(&gorm.Config{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("open test connection failed, err: %+v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
count := 3
|
||||||
|
|
||||||
|
var languages []Language
|
||||||
|
for i := 0; i < count; i++ {
|
||||||
|
language := Language{Code: fmt.Sprintf("consurrent %d", i)}
|
||||||
|
db.Create(&language)
|
||||||
|
languages = append(languages, language)
|
||||||
|
}
|
||||||
|
|
||||||
|
user := User{}
|
||||||
|
db.Create(&user)
|
||||||
|
db.Preload("Languages").FirstOrCreate(&user)
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for i := 0; i < count; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(user User, language Language) {
|
||||||
|
err := db.Model(&user).Association("Languages").Append(&language)
|
||||||
|
AssertEqual(t, err, nil)
|
||||||
|
|
||||||
|
wg.Done()
|
||||||
|
}(user, languages[i])
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
var find User
|
||||||
|
err = db.Preload(clause.Associations).Where("id = ?", user.ID).First(&find).Error
|
||||||
|
AssertEqual(t, err, nil)
|
||||||
|
AssertAssociationCount(t, find, "Languages", int64(count), "after concurrent append")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMany2ManyDuplicateBelongsToAssociation(t *testing.T) {
|
||||||
|
user1 := User{Name: "TestMany2ManyDuplicateBelongsToAssociation-1", Friends: []*User{
|
||||||
|
{Name: "TestMany2ManyDuplicateBelongsToAssociation-friend-1", Company: Company{
|
||||||
|
ID: 1,
|
||||||
|
Name: "Test-company-1",
|
||||||
|
}},
|
||||||
|
}}
|
||||||
|
|
||||||
|
user2 := User{Name: "TestMany2ManyDuplicateBelongsToAssociation-2", Friends: []*User{
|
||||||
|
{Name: "TestMany2ManyDuplicateBelongsToAssociation-friend-2", Company: Company{
|
||||||
|
ID: 1,
|
||||||
|
Name: "Test-company-1",
|
||||||
|
}},
|
||||||
|
}}
|
||||||
|
users := []*User{&user1, &user2}
|
||||||
|
var err error
|
||||||
|
err = DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(users).Error
|
||||||
|
AssertEqual(t, nil, err)
|
||||||
|
|
||||||
|
var findUser1 User
|
||||||
|
err = DB.Preload("Friends.Company").Where("id = ?", user1.ID).First(&findUser1).Error
|
||||||
|
AssertEqual(t, nil, err)
|
||||||
|
AssertEqual(t, user1, findUser1)
|
||||||
|
|
||||||
|
var findUser2 User
|
||||||
|
err = DB.Preload("Friends.Company").Where("id = ?", user2.ID).First(&findUser2).Error
|
||||||
|
AssertEqual(t, nil, err)
|
||||||
|
AssertEqual(t, user2, findUser2)
|
||||||
|
}
|
||||||
|
|||||||
@ -4,6 +4,8 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
|
"gorm.io/gorm/clause"
|
||||||
|
"gorm.io/gorm/schema"
|
||||||
. "gorm.io/gorm/utils/tests"
|
. "gorm.io/gorm/utils/tests"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -69,6 +71,8 @@ func TestAssociationNotNullClear(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestForeignKeyConstraints(t *testing.T) {
|
func TestForeignKeyConstraints(t *testing.T) {
|
||||||
|
tidbSkip(t, "not support the foreign key feature")
|
||||||
|
|
||||||
type Profile struct {
|
type Profile struct {
|
||||||
ID uint
|
ID uint
|
||||||
Name string
|
Name string
|
||||||
@ -124,6 +128,8 @@ func TestForeignKeyConstraints(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestForeignKeyConstraintsBelongsTo(t *testing.T) {
|
func TestForeignKeyConstraintsBelongsTo(t *testing.T) {
|
||||||
|
tidbSkip(t, "not support the foreign key feature")
|
||||||
|
|
||||||
type Profile struct {
|
type Profile struct {
|
||||||
ID uint
|
ID uint
|
||||||
Name string
|
Name string
|
||||||
@ -284,3 +290,107 @@ func TestAssociationError(t *testing.T) {
|
|||||||
err = DB.Model(&emptyUser).Association("Languages").Delete(&user1.Languages)
|
err = DB.Model(&emptyUser).Association("Languages").Delete(&user1.Languages)
|
||||||
AssertEqual(t, err, gorm.ErrPrimaryKeyRequired)
|
AssertEqual(t, err, gorm.ErrPrimaryKeyRequired)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type (
|
||||||
|
myType string
|
||||||
|
emptyQueryClause struct {
|
||||||
|
Field *schema.Field
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
func (myType) QueryClauses(f *schema.Field) []clause.Interface {
|
||||||
|
return []clause.Interface{emptyQueryClause{Field: f}}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sd emptyQueryClause) Name() string {
|
||||||
|
return "empty"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sd emptyQueryClause) Build(clause.Builder) {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sd emptyQueryClause) MergeClause(*clause.Clause) {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sd emptyQueryClause) ModifyStatement(stmt *gorm.Statement) {
|
||||||
|
// do nothing
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAssociationEmptyQueryClause(t *testing.T) {
|
||||||
|
type Organization struct {
|
||||||
|
gorm.Model
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
type Region struct {
|
||||||
|
gorm.Model
|
||||||
|
Name string
|
||||||
|
Organizations []Organization `gorm:"many2many:region_orgs;"`
|
||||||
|
}
|
||||||
|
type RegionOrg struct {
|
||||||
|
RegionId uint
|
||||||
|
OrganizationId uint
|
||||||
|
Empty myType
|
||||||
|
}
|
||||||
|
if err := DB.SetupJoinTable(&Region{}, "Organizations", &RegionOrg{}); err != nil {
|
||||||
|
t.Fatalf("Failed to set up join table, got error: %s", err)
|
||||||
|
}
|
||||||
|
if err := DB.Migrator().DropTable(&Organization{}, &Region{}); err != nil {
|
||||||
|
t.Fatalf("Failed to migrate, got error: %s", err)
|
||||||
|
}
|
||||||
|
if err := DB.AutoMigrate(&Organization{}, &Region{}); err != nil {
|
||||||
|
t.Fatalf("Failed to migrate, got error: %v", err)
|
||||||
|
}
|
||||||
|
region := &Region{Name: "Region1"}
|
||||||
|
if err := DB.Create(region).Error; err != nil {
|
||||||
|
t.Fatalf("fail to create region %v", err)
|
||||||
|
}
|
||||||
|
var orgs []Organization
|
||||||
|
|
||||||
|
if err := DB.Model(&Region{}).Association("Organizations").Find(&orgs); err != nil {
|
||||||
|
t.Fatalf("fail to find region organizations %v", err)
|
||||||
|
} else {
|
||||||
|
AssertEqual(t, len(orgs), 0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type AssociationEmptyUser struct {
|
||||||
|
ID uint
|
||||||
|
Name string
|
||||||
|
Pets []AssociationEmptyPet
|
||||||
|
}
|
||||||
|
|
||||||
|
type AssociationEmptyPet struct {
|
||||||
|
AssociationEmptyUserID *uint `gorm:"uniqueIndex:uniq_user_id_name"`
|
||||||
|
Name string `gorm:"uniqueIndex:uniq_user_id_name;size:256"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAssociationEmptyPrimaryKey(t *testing.T) {
|
||||||
|
if DB.Dialector.Name() != "mysql" {
|
||||||
|
t.Skip()
|
||||||
|
}
|
||||||
|
DB.Migrator().DropTable(&AssociationEmptyUser{}, &AssociationEmptyPet{})
|
||||||
|
DB.AutoMigrate(&AssociationEmptyUser{}, &AssociationEmptyPet{})
|
||||||
|
|
||||||
|
id := uint(100)
|
||||||
|
user := AssociationEmptyUser{
|
||||||
|
ID: id,
|
||||||
|
Name: "jinzhu",
|
||||||
|
Pets: []AssociationEmptyPet{
|
||||||
|
{AssociationEmptyUserID: &id, Name: "bar"},
|
||||||
|
{AssociationEmptyUserID: &id, Name: "foo"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Create(&user).Error
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create, got error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result AssociationEmptyUser
|
||||||
|
err = DB.Preload("Pets").First(&result, &id).Error
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to find, got error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
AssertEqual(t, result, user)
|
||||||
|
}
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
package tests_test
|
package tests_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
. "gorm.io/gorm/utils/tests"
|
. "gorm.io/gorm/utils/tests"
|
||||||
@ -24,6 +25,45 @@ func BenchmarkFind(b *testing.B) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func BenchmarkScan(b *testing.B) {
|
||||||
|
user := *GetUser("scan", Config{})
|
||||||
|
DB.Create(&user)
|
||||||
|
|
||||||
|
var u User
|
||||||
|
b.ResetTimer()
|
||||||
|
for x := 0; x < b.N; x++ {
|
||||||
|
DB.Raw("select * from users where id = ?", user.ID).Scan(&u)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkScanSlice(b *testing.B) {
|
||||||
|
DB.Exec("delete from users")
|
||||||
|
for i := 0; i < 10_000; i++ {
|
||||||
|
user := *GetUser(fmt.Sprintf("scan-%d", i), Config{})
|
||||||
|
DB.Create(&user)
|
||||||
|
}
|
||||||
|
|
||||||
|
var u []User
|
||||||
|
b.ResetTimer()
|
||||||
|
for x := 0; x < b.N; x++ {
|
||||||
|
DB.Raw("select * from users").Scan(&u)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkScanSlicePointer(b *testing.B) {
|
||||||
|
DB.Exec("delete from users")
|
||||||
|
for i := 0; i < 10_000; i++ {
|
||||||
|
user := *GetUser(fmt.Sprintf("scan-%d", i), Config{})
|
||||||
|
DB.Create(&user)
|
||||||
|
}
|
||||||
|
|
||||||
|
var u []*User
|
||||||
|
b.ResetTimer()
|
||||||
|
for x := 0; x < b.N; x++ {
|
||||||
|
DB.Raw("select * from users").Scan(&u)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func BenchmarkUpdate(b *testing.B) {
|
func BenchmarkUpdate(b *testing.B) {
|
||||||
user := *GetUser("find", Config{})
|
user := *GetUser("find", Config{})
|
||||||
DB.Create(&user)
|
DB.Create(&user)
|
||||||
|
|||||||
@ -38,6 +38,7 @@ func c2(*gorm.DB) {}
|
|||||||
func c3(*gorm.DB) {}
|
func c3(*gorm.DB) {}
|
||||||
func c4(*gorm.DB) {}
|
func c4(*gorm.DB) {}
|
||||||
func c5(*gorm.DB) {}
|
func c5(*gorm.DB) {}
|
||||||
|
func c6(*gorm.DB) {}
|
||||||
|
|
||||||
func TestCallbacks(t *testing.T) {
|
func TestCallbacks(t *testing.T) {
|
||||||
type callback struct {
|
type callback struct {
|
||||||
@ -90,7 +91,7 @@ func TestCallbacks(t *testing.T) {
},
{
callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3}, {h: c4}, {h: c5}, {h: c2, remove: true}},
-results: []string{"c1", "c5", "c3", "c4"},
+results: []string{"c1", "c3", "c4", "c5"},
},
{
callbacks: []callback{{h: c1}, {name: "c", h: c2}, {h: c3}, {name: "c", h: c4, replace: true}},
||||||
@ -112,6 +113,9 @@ func TestCallbacks(t *testing.T) {
|
|||||||
|
|
||||||
for idx, data := range datas {
|
for idx, data := range datas {
|
||||||
db, err := gorm.Open(nil, nil)
|
db, err := gorm.Open(nil, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
callbacks := db.Callback()
|
callbacks := db.Callback()
|
||||||
|
|
||||||
for _, c := range data.callbacks {
|
for _, c := range data.callbacks {
|
||||||
@ -168,3 +172,83 @@ func TestCallbacks(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestPluginCallbacks(t *testing.T) {
|
||||||
|
db, _ := gorm.Open(nil, nil)
|
||||||
|
createCallback := db.Callback().Create()
|
||||||
|
|
||||||
|
createCallback.Before("*").Register("plugin_1_fn1", c1)
|
||||||
|
createCallback.After("*").Register("plugin_1_fn2", c2)
|
||||||
|
|
||||||
|
if ok, msg := assertCallbacks(createCallback, []string{"c1", "c2"}); !ok {
|
||||||
|
t.Errorf("callbacks tests failed, got %v", msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
// plugin 2
|
||||||
|
createCallback.Before("*").Register("plugin_2_fn1", c3)
|
||||||
|
if ok, msg := assertCallbacks(createCallback, []string{"c3", "c1", "c2"}); !ok {
|
||||||
|
t.Errorf("callbacks tests failed, got %v", msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
createCallback.After("*").Register("plugin_2_fn2", c4)
|
||||||
|
if ok, msg := assertCallbacks(createCallback, []string{"c3", "c1", "c2", "c4"}); !ok {
|
||||||
|
t.Errorf("callbacks tests failed, got %v", msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
// plugin 3
|
||||||
|
createCallback.Before("*").Register("plugin_3_fn1", c5)
|
||||||
|
if ok, msg := assertCallbacks(createCallback, []string{"c5", "c3", "c1", "c2", "c4"}); !ok {
|
||||||
|
t.Errorf("callbacks tests failed, got %v", msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
createCallback.After("*").Register("plugin_3_fn2", c6)
|
||||||
|
if ok, msg := assertCallbacks(createCallback, []string{"c5", "c3", "c1", "c2", "c4", "c6"}); !ok {
|
||||||
|
t.Errorf("callbacks tests failed, got %v", msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCallbacksGet(t *testing.T) {
|
||||||
|
db, _ := gorm.Open(nil, nil)
|
||||||
|
createCallback := db.Callback().Create()
|
||||||
|
|
||||||
|
createCallback.Before("*").Register("c1", c1)
|
||||||
|
if cb := createCallback.Get("c1"); reflect.DeepEqual(cb, c1) {
|
||||||
|
t.Errorf("callbacks tests failed, got: %p, want: %p", cb, c1)
|
||||||
|
}
|
||||||
|
|
||||||
|
createCallback.Remove("c1")
|
||||||
|
if cb := createCallback.Get("c2"); cb != nil {
|
||||||
|
t.Errorf("callbacks test failed. got: %p, want: nil", cb)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCallbacksRemove(t *testing.T) {
|
||||||
|
db, _ := gorm.Open(nil, nil)
|
||||||
|
createCallback := db.Callback().Create()
|
||||||
|
|
||||||
|
createCallback.Before("*").Register("c1", c1)
|
||||||
|
createCallback.After("*").Register("c2", c2)
|
||||||
|
createCallback.Before("c4").Register("c3", c3)
|
||||||
|
createCallback.After("c2").Register("c4", c4)
|
||||||
|
|
||||||
|
// callbacks: []string{"c1", "c3", "c4", "c2"}
|
||||||
|
createCallback.Remove("c1")
|
||||||
|
if ok, msg := assertCallbacks(createCallback, []string{"c3", "c4", "c2"}); !ok {
|
||||||
|
t.Errorf("callbacks tests failed, got %v", msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
createCallback.Remove("c4")
|
||||||
|
if ok, msg := assertCallbacks(createCallback, []string{"c3", "c2"}); !ok {
|
||||||
|
t.Errorf("callbacks tests failed, got %v", msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
createCallback.Remove("c2")
|
||||||
|
if ok, msg := assertCallbacks(createCallback, []string{"c3"}); !ok {
|
||||||
|
t.Errorf("callbacks tests failed, got %v", msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
createCallback.Remove("c3")
|
||||||
|
if ok, msg := assertCallbacks(createCallback, []string{}); !ok {
|
||||||
|
t.Errorf("callbacks tests failed, got %v", msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@ -48,9 +48,11 @@ func (c *wrapperConnPool) Ping() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// If you use BeginTx returned *sql.Tx as shown below then you can't record queries in a transaction.
|
// If you use BeginTx returned *sql.Tx as shown below then you can't record queries in a transaction.
|
||||||
// func (c *wrapperConnPool) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) {
|
//
|
||||||
// return c.db.BeginTx(ctx, opts)
|
// func (c *wrapperConnPool) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) {
|
||||||
// }
|
// return c.db.BeginTx(ctx, opts)
|
||||||
|
// }
|
||||||
|
//
|
||||||
// You should use BeginTx returned gorm.Tx which could wrap *sql.Tx then you can record all queries.
|
// You should use BeginTx returned gorm.Tx which could wrap *sql.Tx then you can record all queries.
|
||||||
func (c *wrapperConnPool) BeginTx(ctx context.Context, opts *sql.TxOptions) (gorm.ConnPool, error) {
|
func (c *wrapperConnPool) BeginTx(ctx context.Context, opts *sql.TxOptions) (gorm.ConnPool, error) {
|
||||||
tx, err := c.db.BeginTx(ctx, opts)
|
tx, err := c.db.BeginTx(ctx, opts)
|
||||||
@ -100,13 +102,13 @@ func TestConnPoolWrapper(t *testing.T) {
expect: []string{
"SELECT VERSION()",
"INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)",
-"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1",
+"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT ?",
"INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)",
-"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1",
+"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT ?",
-"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1",
+"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT ?",
"INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)",
-"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1",
+"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT ?",
-"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1",
+"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT ?",
},
}

||||||
@ -116,7 +118,7 @@ func TestConnPoolWrapper(t *testing.T) {
}
}()

-db, err := gorm.Open(mysql.New(mysql.Config{Conn: conn}))
+db, err := gorm.Open(mysql.New(mysql.Config{Conn: conn, DisableWithReturning: true}))
if err != nil {
t.Fatalf("Should open db success, but got %v", err)
}
|||||||
@ -11,6 +11,32 @@ import (
|
|||||||
. "gorm.io/gorm/utils/tests"
|
. "gorm.io/gorm/utils/tests"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func TestCountWithGroup(t *testing.T) {
|
||||||
|
DB.Create([]Company{
|
||||||
|
{Name: "company_count_group_a"},
|
||||||
|
{Name: "company_count_group_a"},
|
||||||
|
{Name: "company_count_group_a"},
|
||||||
|
{Name: "company_count_group_b"},
|
||||||
|
{Name: "company_count_group_c"},
|
||||||
|
})
|
||||||
|
|
||||||
|
var count1 int64
|
||||||
|
if err := DB.Model(&Company{}).Where("name = ?", "company_count_group_a").Group("name").Count(&count1).Error; err != nil {
|
||||||
|
t.Errorf(fmt.Sprintf("Count should work, but got err %v", err))
|
||||||
|
}
|
||||||
|
if count1 != 1 {
|
||||||
|
t.Errorf("Count with group should be 1, but got count: %v", count1)
|
||||||
|
}
|
||||||
|
|
||||||
|
var count2 int64
|
||||||
|
if err := DB.Model(&Company{}).Where("name in ?", []string{"company_count_group_b", "company_count_group_c"}).Group("name").Count(&count2).Error; err != nil {
|
||||||
|
t.Errorf(fmt.Sprintf("Count should work, but got err %v", err))
|
||||||
|
}
|
||||||
|
if count2 != 2 {
|
||||||
|
t.Errorf("Count with group should be 2, but got count: %v", count2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestCount(t *testing.T) {
|
func TestCount(t *testing.T) {
|
||||||
var (
|
var (
|
||||||
user1 = *GetUser("count-1", Config{})
|
user1 = *GetUser("count-1", Config{})
|
||||||
@ -142,7 +168,7 @@ func TestCount(t *testing.T) {
DB.Create(sameUsers)

if err := DB.Model(&User{}).Where("name = ?", "count-4").Group("name").Count(&count11).Error; err != nil || count11 != 1 {
-t.Fatalf("Count should be 3, but got count: %v err %v", count11, err)
+t.Fatalf("Count should be 1, but got count: %v err %v", count11, err)
}

var count12 int64
|||||||
@ -2,6 +2,7 @@ package tests_test
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"regexp"
|
"regexp"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@ -476,6 +477,13 @@ func TestOmitWithCreate(t *testing.T) {
|
|||||||
CheckUser(t, result2, user2)
|
CheckUser(t, result2, user2)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestFirstOrCreateNotExistsTable(t *testing.T) {
|
||||||
|
company := Company{Name: "first_or_create_if_not_exists_table"}
|
||||||
|
if err := DB.Table("not_exists").FirstOrCreate(&company).Error; err == nil {
|
||||||
|
t.Errorf("not exists table, but err is nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestFirstOrCreateWithPrimaryKey(t *testing.T) {
|
func TestFirstOrCreateWithPrimaryKey(t *testing.T) {
|
||||||
company := Company{ID: 100, Name: "company100_with_primarykey"}
|
company := Company{ID: 100, Name: "company100_with_primarykey"}
|
||||||
DB.FirstOrCreate(&company)
|
DB.FirstOrCreate(&company)
|
||||||
@ -540,3 +548,246 @@ func TestFirstOrCreateRowsAffected(t *testing.T) {
 		t.Fatalf("first or create rows affect err:%v rows:%d", res.Error, res.RowsAffected)
 	}
 }
+
+func TestCreateWithAutoIncrementCompositeKey(t *testing.T) {
+	type CompositeKeyProduct struct {
+		ProductID    int `gorm:"primaryKey;autoIncrement:true;"` // primary key
+		LanguageCode int `gorm:"primaryKey;"` // primary key
+		Code         string
+		Name         string
+	}
+
+	if err := DB.Migrator().DropTable(&CompositeKeyProduct{}); err != nil {
+		t.Fatalf("failed to migrate, got error %v", err)
+	}
+	if err := DB.AutoMigrate(&CompositeKeyProduct{}); err != nil {
+		t.Fatalf("failed to migrate, got error %v", err)
+	}
+
+	prod := &CompositeKeyProduct{
+		LanguageCode: 56,
+		Code:         "Code56",
+		Name:         "ProductName56",
+	}
+	if err := DB.Create(&prod).Error; err != nil {
+		t.Fatalf("failed to create, got error %v", err)
+	}
+
+	newProd := &CompositeKeyProduct{}
+	if err := DB.First(&newProd).Error; err != nil {
+		t.Fatalf("errors happened when query: %v", err)
+	} else {
+		AssertObjEqual(t, newProd, prod, "ProductID", "LanguageCode", "Code", "Name")
+	}
+}
+
+func TestCreateOnConflictWithDefaultNull(t *testing.T) {
+	type OnConflictUser struct {
+		ID     string
+		Name   string `gorm:"default:null"`
+		Email  string
+		Mobile string `gorm:"default:'133xxxx'"`
+	}
+
+	err := DB.Migrator().DropTable(&OnConflictUser{})
+	AssertEqual(t, err, nil)
+	err = DB.AutoMigrate(&OnConflictUser{})
+	AssertEqual(t, err, nil)
+
+	u := OnConflictUser{
+		ID:     "on-conflict-user-id",
+		Name:   "on-conflict-user-name",
+		Email:  "on-conflict-user-email",
+		Mobile: "on-conflict-user-mobile",
+	}
+	err = DB.Create(&u).Error
+	AssertEqual(t, err, nil)
+
+	u.Name = "on-conflict-user-name-2"
+	u.Email = "on-conflict-user-email-2"
+	u.Mobile = ""
+	err = DB.Clauses(clause.OnConflict{UpdateAll: true}).Create(&u).Error
+	AssertEqual(t, err, nil)
+
+	var u2 OnConflictUser
+	err = DB.Where("id = ?", u.ID).First(&u2).Error
+	AssertEqual(t, err, nil)
+	AssertEqual(t, u2.Name, "on-conflict-user-name-2")
+	AssertEqual(t, u2.Email, "on-conflict-user-email-2")
+	AssertEqual(t, u2.Mobile, "133xxxx")
+}
+
+func TestCreateFromMapWithoutPK(t *testing.T) {
+	if !isMysql() {
+		t.Skipf("This test case skipped, because of only supporting for mysql")
+	}
+
+	// case 1: one record, create from map[string]interface{}
+	mapValue1 := map[string]interface{}{"name": "create_from_map_with_schema1", "age": 1}
+	if err := DB.Model(&User{}).Create(mapValue1).Error; err != nil {
+		t.Fatalf("failed to create data from map, got error: %v", err)
+	}
+
+	if _, ok := mapValue1["id"]; !ok {
+		t.Fatal("failed to create data from map with table, returning map has no primary key")
+	}
+
+	var result1 User
+	if err := DB.Where("name = ?", "create_from_map_with_schema1").First(&result1).Error; err != nil || result1.Age != 1 {
+		t.Fatalf("failed to create from map, got error %v", err)
+	}
+
+	var idVal int64
+	_, ok := mapValue1["id"].(uint)
+	if ok {
+		t.Skipf("This test case skipped, because the db supports returning")
+	}
+
+	idVal, ok = mapValue1["id"].(int64)
+	if !ok {
+		t.Fatal("ret result missing id")
+	}
+
+	if int64(result1.ID) != idVal {
+		t.Fatal("failed to create data from map with table, @id != id")
+	}
+
+	// case2: one record, create from *map[string]interface{}
+	mapValue2 := map[string]interface{}{"name": "create_from_map_with_schema2", "age": 1}
+	if err := DB.Model(&User{}).Create(&mapValue2).Error; err != nil {
+		t.Fatalf("failed to create data from map, got error: %v", err)
+	}
+
+	if _, ok := mapValue2["id"]; !ok {
+		t.Fatal("failed to create data from map with table, returning map has no primary key")
+	}
+
+	var result2 User
+	if err := DB.Where("name = ?", "create_from_map_with_schema2").First(&result2).Error; err != nil || result2.Age != 1 {
+		t.Fatalf("failed to create from map, got error %v", err)
+	}
+
+	_, ok = mapValue2["id"].(uint)
+	if ok {
+		t.Skipf("This test case skipped, because the db supports returning")
+	}
+
+	idVal, ok = mapValue2["id"].(int64)
+	if !ok {
+		t.Fatal("ret result missing id")
+	}
+
+	if int64(result2.ID) != idVal {
+		t.Fatal("failed to create data from map with table, @id != id")
+	}
+
+	// case 3: records
+	values := []map[string]interface{}{
+		{"name": "create_from_map_with_schema11", "age": 1}, {"name": "create_from_map_with_schema12", "age": 1},
+	}
+
+	beforeLen := len(values)
+	if err := DB.Model(&User{}).Create(&values).Error; err != nil {
+		t.Fatalf("failed to create data from map, got error: %v", err)
+	}
+
+	// mariadb with returning, values will be appended with id map
+	if len(values) == beforeLen*2 {
+		t.Skipf("This test case skipped, because the db supports returning")
+	}
+
+	for i := range values {
+		v, ok := values[i]["id"]
+		if !ok {
+			t.Fatal("failed to create data from map with table, returning map has no primary key")
+		}
+
+		var result User
+		if err := DB.Where("name = ?", fmt.Sprintf("create_from_map_with_schema1%d", i+1)).First(&result).Error; err != nil || result.Age != 1 {
+			t.Fatalf("failed to create from map, got error %v", err)
+		}
+		if int64(result.ID) != v.(int64) {
+			t.Fatal("failed to create data from map with table, @id != id")
+		}
+	}
+}
+
+func TestCreateFromMapWithTable(t *testing.T) {
+	tableDB := DB.Table("users")
+	supportLastInsertID := isMysql() || isSqlite()
+
+	// case 1: create from map[string]interface{}
+	record := map[string]interface{}{"name": "create_from_map_with_table", "age": 18}
+	if err := tableDB.Create(record).Error; err != nil {
+		t.Fatalf("failed to create data from map with table, got error: %v", err)
+	}
+
+	if _, ok := record["@id"]; !ok && supportLastInsertID {
+		t.Fatal("failed to create data from map with table, returning map has no key '@id'")
+	}
+
+	var res map[string]interface{}
+	if err := tableDB.Select([]string{"id", "name", "age"}).Where("name = ?", "create_from_map_with_table").Find(&res).Error; err != nil || res["age"] != int64(18) {
+		t.Fatalf("failed to create from map, got error %v", err)
+	}
+
+	if _, ok := record["@id"]; ok && fmt.Sprint(res["id"]) != fmt.Sprint(record["@id"]) {
+		t.Fatalf("failed to create data from map with table, @id != id, got %v, expect %v", res["id"], record["@id"])
+	}
+
+	// case 2: create from *map[string]interface{}
+	record1 := map[string]interface{}{"name": "create_from_map_with_table_1", "age": 18}
+	tableDB2 := DB.Table("users")
+	if err := tableDB2.Create(&record1).Error; err != nil {
+		t.Fatalf("failed to create data from map, got error: %v", err)
+	}
+	if _, ok := record1["@id"]; !ok && supportLastInsertID {
+		t.Fatal("failed to create data from map with table, returning map has no key '@id'")
+	}
+
+	var res1 map[string]interface{}
+	if err := tableDB2.Select([]string{"id", "name", "age"}).Where("name = ?", "create_from_map_with_table_1").Find(&res1).Error; err != nil || res1["age"] != int64(18) {
+		t.Fatalf("failed to create from map, got error %v", err)
+	}
+
+	if _, ok := record1["@id"]; ok && fmt.Sprint(res1["id"]) != fmt.Sprint(record1["@id"]) {
+		t.Fatal("failed to create data from map with table, @id != id")
+	}
+
+	// case 3: create from []map[string]interface{}
+	records := []map[string]interface{}{
+		{"name": "create_from_map_with_table_2", "age": 19},
+		{"name": "create_from_map_with_table_3", "age": 20},
+	}
+
+	tableDB = DB.Table("users")
+	if err := tableDB.Create(&records).Error; err != nil {
+		t.Fatalf("failed to create data from slice of map, got error: %v", err)
+	}
+
+	if _, ok := records[0]["@id"]; !ok && supportLastInsertID {
+		t.Fatal("failed to create data from map with table, returning map has no key '@id'")
+	}
+
+	if _, ok := records[1]["@id"]; !ok && supportLastInsertID {
+		t.Fatal("failed to create data from map with table, returning map has no key '@id'")
+	}
+
+	var res2 map[string]interface{}
+	if err := tableDB.Select([]string{"id", "name", "age"}).Where("name = ?", "create_from_map_with_table_2").Find(&res2).Error; err != nil || res2["age"] != int64(19) {
+		t.Fatalf("failed to query data after create from slice of map, got error %v", err)
+	}
+
+	var res3 map[string]interface{}
+	if err := DB.Table("users").Select([]string{"id", "name", "age"}).Where("name = ?", "create_from_map_with_table_3").Find(&res3).Error; err != nil || res3["age"] != int64(20) {
+		t.Fatalf("failed to query data after create from slice of map, got error %v", err)
+	}
+
+	if _, ok := records[0]["@id"]; ok && fmt.Sprint(res2["id"]) != fmt.Sprint(records[0]["@id"]) {
+		t.Errorf("failed to create data from map with table, @id != id, got %v, expect %v", res2["id"], records[0]["@id"])
+	}
+
+	if _, ok := records[1]["id"]; ok && fmt.Sprint(res3["id"]) != fmt.Sprint(records[1]["@id"]) {
+		t.Errorf("failed to create data from map with table, @id != id")
+	}
+}
@ -206,9 +206,9 @@ func TestDeleteSliceWithAssociations(t *testing.T) {
 	}
 }

-// only sqlite, postgres support returning
+// only sqlite, postgres, sqlserver support returning
 func TestSoftDeleteReturning(t *testing.T) {
-	if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" {
+	if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" && DB.Dialector.Name() != "sqlserver" {
 		return
 	}

@ -233,7 +233,7 @@ func TestSoftDeleteReturning(t *testing.T) {
 }

 func TestDeleteReturning(t *testing.T) {
-	if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" {
+	if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" && DB.Dialector.Name() != "sqlserver" {
 		return
 	}

@ -4,7 +4,7 @@ services:
   mysql:
     image: 'mysql/mysql-server:latest'
     ports:
-      - 9910:3306
+      - "9910:3306"
     environment:
       - MYSQL_DATABASE=gorm
       - MYSQL_USER=gorm
@ -13,7 +13,7 @@ services:
   postgres:
     image: 'postgres:latest'
     ports:
-      - 9920:5432
+      - "9920:5432"
     environment:
       - TZ=Asia/Shanghai
       - POSTGRES_DB=gorm
@ -22,10 +22,16 @@ services:
   mssql:
     image: '${MSSQL_IMAGE:-mcmoe/mssqldocker}:latest'
     ports:
-      - 9930:1433
+      - "9930:1433"
     environment:
+      - TZ=Asia/Shanghai
       - ACCEPT_EULA=Y
       - SA_PASSWORD=LoremIpsum86
       - MSSQL_DB=gorm
       - MSSQL_USER=gorm
       - MSSQL_PASSWORD=LoremIpsum86
+  tidb:
+    image: 'pingcap/tidb:v6.5.0'
+    ports:
+      - "9940:4000"
+    command: /tidb-server -store unistore -path "" -lease 0s > tidb.log 2>&1 &
@ -4,7 +4,9 @@ import (
 	"database/sql/driver"
 	"encoding/json"
 	"errors"
+	"reflect"
 	"testing"
+	"time"

 	"gorm.io/gorm"
 	. "gorm.io/gorm/utils/tests"
@ -36,7 +38,7 @@ func TestEmbeddedStruct(t *testing.T) {

 	type EngadgetPost struct {
 		BasePost BasePost `gorm:"Embedded"`
-		Author Author `gorm:"Embedded;EmbeddedPrefix:author_"` // Embedded struct
+		Author *Author `gorm:"Embedded;EmbeddedPrefix:author_"` // Embedded struct
 		ImageUrl string
 	}

@ -74,13 +76,26 @@ func TestEmbeddedStruct(t *testing.T) {
 		t.Errorf("embedded struct's value should be scanned correctly")
 	}

-	DB.Save(&EngadgetPost{BasePost: BasePost{Title: "engadget_news"}})
+	DB.Save(&EngadgetPost{BasePost: BasePost{Title: "engadget_news"}, Author: &Author{Name: "Edward"}})
+	DB.Save(&EngadgetPost{BasePost: BasePost{Title: "engadget_article"}, Author: &Author{Name: "George"}})
 	var egNews EngadgetPost
 	if err := DB.First(&egNews, "title = ?", "engadget_news").Error; err != nil {
 		t.Errorf("no error should happen when query with embedded struct, but got %v", err)
 	} else if egNews.BasePost.Title != "engadget_news" {
 		t.Errorf("embedded struct's value should be scanned correctly")
 	}
+
+	var egPosts []EngadgetPost
+	if err := DB.Order("author_name asc").Find(&egPosts).Error; err != nil {
+		t.Fatalf("no error should happen when query with embedded struct, but got %v", err)
+	}
+	expectAuthors := []string{"Edward", "George"}
+	for i, post := range egPosts {
+		t.Log(i, post.Author)
+		if want := expectAuthors[i]; post.Author.Name != want {
+			t.Errorf("expected author %s got %s", want, post.Author.Name)
+		}
+	}
 }

 func TestEmbeddedPointerTypeStruct(t *testing.T) {
@ -90,9 +105,21 @@ func TestEmbeddedPointerTypeStruct(t *testing.T) {
 		URL string
 	}

+	type Author struct {
+		ID string
+		Name string
+		Email string
+		Age int
+		Content Content
+		ContentPtr *Content
+		Birthday time.Time
+		BirthdayPtr *time.Time
+	}
+
 	type HNPost struct {
 		*BasePost
 		Upvotes int32
+		*Author `gorm:"EmbeddedPrefix:user_"` // Embedded struct
 	}

 	DB.Migrator().DropTable(&HNPost{})
@ -110,6 +137,52 @@ func TestEmbeddedPointerTypeStruct(t *testing.T) {
 	if hnPost.Title != "embedded_pointer_type" {
 		t.Errorf("Should find correct value for embedded pointer type")
 	}
+
+	if hnPost.Author != nil {
+		t.Errorf("Expected to get back a nil Author but got: %v", hnPost.Author)
+	}
+
+	now := time.Now().Round(time.Second)
+	NewPost := HNPost{
+		BasePost: &BasePost{Title: "embedded_pointer_type2"},
+		Author: &Author{
+			Name: "test",
+			Content: Content{"test"},
+			ContentPtr: nil,
+			Birthday: now,
+			BirthdayPtr: nil,
+		},
+	}
+	DB.Create(&NewPost)
+
+	hnPost = HNPost{}
+	if err := DB.First(&hnPost, "title = ?", NewPost.Title).Error; err != nil {
+		t.Errorf("No error should happen when find embedded pointer type, but got %v", err)
+	}
+
+	if hnPost.Title != NewPost.Title {
+		t.Errorf("Should find correct value for embedded pointer type")
+	}
+
+	if hnPost.Author.Name != NewPost.Author.Name {
+		t.Errorf("Expected to get Author name %v but got: %v", NewPost.Author.Name, hnPost.Author.Name)
+	}
+
+	if !reflect.DeepEqual(NewPost.Author.Content, hnPost.Author.Content) {
+		t.Errorf("Expected to get Author content %v but got: %v", NewPost.Author.Content, hnPost.Author.Content)
+	}
+
+	if hnPost.Author.ContentPtr != nil {
+		t.Errorf("Expected to get nil Author contentPtr but got: %v", hnPost.Author.ContentPtr)
+	}
+
+	if NewPost.Author.Birthday.UnixMilli() != hnPost.Author.Birthday.UnixMilli() {
+		t.Errorf("Expected to get Author birthday with %+v but got: %+v", NewPost.Author.Birthday, hnPost.Author.Birthday)
+	}
+
+	if hnPost.Author.BirthdayPtr != nil {
+		t.Errorf("Expected to get nil Author birthdayPtr but got: %+v", hnPost.Author.BirthdayPtr)
+	}
 }

 type Content struct {
@ -117,18 +190,26 @@ type Content struct {
 }

 func (c Content) Value() (driver.Value, error) {
-	return json.Marshal(c)
+	// mssql driver with issue on handling null bytes https://github.com/denisenkom/go-mssqldb/issues/530,
+	b, err := json.Marshal(c)
+	return string(b[:]), err
 }

 func (c *Content) Scan(src interface{}) error {
-	b, ok := src.([]byte)
-	if !ok {
-		return errors.New("Embedded.Scan byte assertion failed")
-	}
-
 	var value Content
-	if err := json.Unmarshal(b, &value); err != nil {
-		return err
+	str, ok := src.(string)
+	if !ok {
+		byt, ok := src.([]byte)
+		if !ok {
+			return errors.New("Embedded.Scan byte assertion failed")
+		}
+		if err := json.Unmarshal(byt, &value); err != nil {
+			return err
+		}
+	} else {
+		if err := json.Unmarshal([]byte(str), &value); err != nil {
+			return err
+		}
 	}

 	*c = value
@ -155,8 +236,15 @@ func TestEmbeddedScanValuer(t *testing.T) {
 }

 func TestEmbeddedRelations(t *testing.T) {
+	type EmbUser struct {
+		gorm.Model
+		Name string
+		Age uint
+		Languages []Language `gorm:"many2many:EmbUserSpeak;"`
+	}
+
 	type AdvancedUser struct {
-		User `gorm:"embedded"`
+		EmbUser `gorm:"embedded"`
 		Advanced bool
 	}

@ -168,3 +256,29 @@ func TestEmbeddedRelations(t *testing.T) {
 		}
 	}
 }
+
+func TestEmbeddedTagSetting(t *testing.T) {
+	type Tag1 struct {
+		Id int64 `gorm:"autoIncrement"`
+	}
+	type Tag2 struct {
+		Id int64
+	}
+
+	type EmbeddedTag struct {
+		Tag1 Tag1 `gorm:"Embedded;"`
+		Tag2 Tag2 `gorm:"Embedded;EmbeddedPrefix:t2_"`
+		Name string
+	}
+
+	DB.Migrator().DropTable(&EmbeddedTag{})
+	err := DB.Migrator().AutoMigrate(&EmbeddedTag{})
+	AssertEqual(t, err, nil)
+
+	t1 := EmbeddedTag{Name: "embedded_tag"}
+	err = DB.Save(&t1).Error
+	AssertEqual(t, err, nil)
+	if t1.Tag1.Id == 0 {
+		t.Errorf("embedded struct's primary field should be rewrited")
+	}
+}
tests/error_translator_test.go (new file, 111 lines)
@ -0,0 +1,111 @@
+package tests_test
+
+import (
+	"errors"
+	"testing"
+
+	"gorm.io/gorm"
+	"gorm.io/gorm/utils/tests"
+)
+
+func TestDialectorWithErrorTranslatorSupport(t *testing.T) {
+	// it shouldn't translate error when the TranslateError flag is false
+	translatedErr := errors.New("translated error")
+	untranslatedErr := errors.New("some random error")
+	db, _ := gorm.Open(tests.DummyDialector{TranslatedErr: translatedErr})
+
+	err := db.AddError(untranslatedErr)
+	if !errors.Is(err, untranslatedErr) {
+		t.Fatalf("expected err: %v got err: %v", untranslatedErr, err)
+	}
+
+	// it should translate error when the TranslateError flag is true
+	db, _ = gorm.Open(tests.DummyDialector{TranslatedErr: translatedErr}, &gorm.Config{TranslateError: true})
+
+	err = db.AddError(untranslatedErr)
+	if !errors.Is(err, translatedErr) {
+		t.Fatalf("expected err: %v got err: %v", translatedErr, err)
+	}
+}
+
+func TestSupportedDialectorWithErrDuplicatedKey(t *testing.T) {
+	type City struct {
+		gorm.Model
+		Name string `gorm:"unique"`
+	}
+
+	db, err := OpenTestConnection(&gorm.Config{TranslateError: true})
+	if err != nil {
+		t.Fatalf("failed to connect database, got error %v", err)
+	}
+
+	dialectors := map[string]bool{"sqlite": true, "postgres": true, "mysql": true, "sqlserver": true}
+	if supported, found := dialectors[db.Dialector.Name()]; !(found && supported) {
+		return
+	}
+
+	DB.Migrator().DropTable(&City{})
+
+	if err = db.AutoMigrate(&City{}); err != nil {
+		t.Fatalf("failed to migrate cities table, got error: %v", err)
+	}
+
+	err = db.Create(&City{Name: "Kabul"}).Error
+	if err != nil {
+		t.Fatalf("failed to create record: %v", err)
+	}
+
+	err = db.Create(&City{Name: "Kabul"}).Error
+	if !errors.Is(err, gorm.ErrDuplicatedKey) {
+		t.Fatalf("expected err: %v got err: %v", gorm.ErrDuplicatedKey, err)
+	}
+}
+
+func TestSupportedDialectorWithErrForeignKeyViolated(t *testing.T) {
+	tidbSkip(t, "not support the foreign key feature")
+
+	type City struct {
+		gorm.Model
+		Name string `gorm:"unique"`
+	}
+
+	type Museum struct {
+		gorm.Model
+		Name string `gorm:"unique"`
+		CityID uint
+		City City `gorm:"Constraint:OnUpdate:CASCADE,OnDelete:CASCADE;FOREIGNKEY:CityID;References:ID"`
+	}
+
+	db, err := OpenTestConnection(&gorm.Config{TranslateError: true})
+	if err != nil {
+		t.Fatalf("failed to connect database, got error %v", err)
+	}
+
+	dialectors := map[string]bool{"sqlite": true, "postgres": true, "mysql": true, "sqlserver": true}
+	if supported, found := dialectors[db.Dialector.Name()]; !(found && supported) {
+		return
+	}
+
+	DB.Migrator().DropTable(&City{}, &Museum{})
+
+	if err = db.AutoMigrate(&City{}, &Museum{}); err != nil {
+		t.Fatalf("failed to migrate countries & cities tables, got error: %v", err)
+	}
+
+	city := City{Name: "Amsterdam"}
+
+	err = db.Create(&city).Error
+	if err != nil {
+		t.Fatalf("failed to create city: %v", err)
+	}
+
+	err = db.Create(&Museum{Name: "Eye Filmmuseum", CityID: city.ID}).Error
+	if err != nil {
+		t.Fatalf("failed to create museum: %v", err)
+	}
+
+	err = db.Create(&Museum{Name: "Dungeon", CityID: 123}).Error
+	if !errors.Is(err, gorm.ErrForeignKeyViolated) {
+		t.Fatalf("expected err: %v got err: %v", gorm.ErrForeignKeyViolated, err)
+	}
+}
tests/go.mod (41 changed lines)
@ -1,18 +1,39 @@
 module gorm.io/gorm/tests

-go 1.14
+go 1.18

 require (
-	github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect
-	github.com/google/uuid v1.3.0
+	github.com/google/uuid v1.6.0
 	github.com/jinzhu/now v1.1.5
-	github.com/lib/pq v1.10.5
-	golang.org/x/crypto v0.0.0-20220411220226-7b82a4e95df4 // indirect
-	gorm.io/driver/mysql v1.3.3
-	gorm.io/driver/postgres v1.3.5
-	gorm.io/driver/sqlite v1.3.2
-	gorm.io/driver/sqlserver v1.3.2
-	gorm.io/gorm v1.23.4
+	github.com/lib/pq v1.10.9
+	github.com/stretchr/testify v1.9.0
+	gorm.io/driver/mysql v1.5.6
+	gorm.io/driver/postgres v1.5.7
+	gorm.io/driver/sqlite v1.5.5
+	gorm.io/driver/sqlserver v1.5.3
+	gorm.io/gorm v1.25.8
+)
+
+require (
+	filippo.io/edwards25519 v1.1.0 // indirect
+	github.com/davecgh/go-spew v1.1.1 // indirect
+	github.com/go-sql-driver/mysql v1.8.0 // indirect
+	github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect
+	github.com/golang-sql/sqlexp v0.1.0 // indirect
+	github.com/jackc/pgpassfile v1.0.0 // indirect
+	github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 // indirect
+	github.com/jackc/pgx/v5 v5.5.5 // indirect
+	github.com/jinzhu/inflection v1.0.0 // indirect
+	github.com/kr/text v0.2.0 // indirect
+	github.com/mattn/go-sqlite3 v1.14.22 // indirect
+	github.com/microsoft/go-mssqldb v1.7.0 // indirect
+	github.com/pmezard/go-difflib v1.0.0 // indirect
+	github.com/rogpeppe/go-internal v1.12.0 // indirect
+	golang.org/x/crypto v0.21.0 // indirect
+	golang.org/x/text v0.14.0 // indirect
+	gopkg.in/yaml.v3 v3.0.1 // indirect
 )

 replace gorm.io/gorm => ../
+
+replace github.com/jackc/pgx/v5 => github.com/jackc/pgx/v5 v5.4.3
@ -3,9 +3,19 @@ package tests_test
 import (
 	"testing"

+	"gorm.io/driver/mysql"
+
 	"gorm.io/gorm"
 )

+func TestOpen(t *testing.T) {
+	dsn := "gorm:gorm@tcp(localhost:9910)/gorm?loc=Asia%2FHongKong" // invalid loc
+	_, err := gorm.Open(mysql.Open(dsn), &gorm.Config{})
+	if err == nil {
+		t.Fatalf("should returns error but got nil")
+	}
+}
+
 func TestReturningWithNullToZeroValues(t *testing.T) {
 	dialect := DB.Dialector.Name()
 	switch dialect {
@ -1,12 +1,15 @@
 package tests_test

 import (
+	"os"
 	"sort"
 	"strconv"
 	"strings"
 	"testing"
 	"time"

+	"gorm.io/gorm"
+
 	. "gorm.io/gorm/utils/tests"
 )

@ -20,6 +23,7 @@ type Config struct {
 	Languages int
 	Friends int
 	NamedPet bool
+	Tools int
 }

 func GetUser(name string, config Config) *User {
@ -44,6 +48,10 @@ func GetUser(name string, config Config) *User {
 		user.Toys = append(user.Toys, Toy{Name: name + "_toy_" + strconv.Itoa(i+1)})
 	}

+	for i := 0; i < config.Tools; i++ {
+		user.Tools = append(user.Tools, Tools{Name: name + "_tool_" + strconv.Itoa(i+1)})
+	}
+
 	if config.Company {
 		user.Company = Company{Name: "company-" + name}
 	}
@ -73,13 +81,22 @@ func GetUser(name string, config Config) *User {
 	return &user
 }

+func CheckPetUnscoped(t *testing.T, pet Pet, expect Pet) {
+	doCheckPet(t, pet, expect, true)
+}
+
 func CheckPet(t *testing.T, pet Pet, expect Pet) {
+	doCheckPet(t, pet, expect, false)
+}
+
+func doCheckPet(t *testing.T, pet Pet, expect Pet, unscoped bool) {
 	if pet.ID != 0 {
 		var newPet Pet
-		if err := DB.Where("id = ?", pet.ID).First(&newPet).Error; err != nil {
+		if err := db(unscoped).Where("id = ?", pet.ID).First(&newPet).Error; err != nil {
 			t.Fatalf("errors happened when query: %v", err)
 		} else {
 			AssertObjEqual(t, newPet, pet, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "UserID", "Name")
+			AssertObjEqual(t, newPet, expect, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "UserID", "Name")
 		}
 	}

@ -92,17 +109,27 @@ func CheckPet(t *testing.T, pet Pet, expect Pet) {
 	}
 }

+func CheckUserUnscoped(t *testing.T, user User, expect User) {
+	doCheckUser(t, user, expect, true)
+}
+
 func CheckUser(t *testing.T, user User, expect User) {
+	doCheckUser(t, user, expect, false)
+}
+
+func doCheckUser(t *testing.T, user User, expect User, unscoped bool) {
 	if user.ID != 0 {
 		var newUser User
-		if err := DB.Where("id = ?", user.ID).First(&newUser).Error; err != nil {
+		if err := db(unscoped).Where("id = ?", user.ID).First(&newUser).Error; err != nil {
 			t.Fatalf("errors happened when query: %v", err)
 		} else {
-			AssertObjEqual(t, newUser, user, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active")
+			AssertObjEqual(t, newUser, user, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday",
+				"CompanyID", "ManagerID", "Active")
 		}
 	}

-	AssertObjEqual(t, user, expect, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active")
+	AssertObjEqual(t, user, expect, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID",
+		"ManagerID", "Active")

 	t.Run("Account", func(t *testing.T) {
 		AssertObjEqual(t, user.Account, expect.Account, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "UserID", "Number")
@ -112,8 +139,9 @@ func CheckUser(t *testing.T, user User, expect User) {
 			t.Errorf("Account's foreign key should be saved")
 		} else {
 			var account Account
-			DB.First(&account, "user_id = ?", user.ID)
-			AssertObjEqual(t, account, user.Account, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "UserID", "Number")
+			db(unscoped).First(&account, "user_id = ?", user.ID)
+			AssertObjEqual(t, account, user.Account, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "UserID",
+				"Number")
 		}
 	}
 	})
@ -135,7 +163,7 @@ func CheckUser(t *testing.T, user User, expect User) {
 		if pet == nil || expect.Pets[idx] == nil {
 			t.Errorf("pets#%v should equal, expect: %v, got %v", idx, expect.Pets[idx], pet)
 		} else {
-			CheckPet(t, *pet, *expect.Pets[idx])
+			doCheckPet(t, *pet, *expect.Pets[idx], unscoped)
 		}
 	}
 	})
@ -172,8 +200,11 @@ func CheckUser(t *testing.T, user User, expect User) {
 			t.Errorf("Manager's foreign key should be saved")
 		} else {
 			var manager User
-			DB.First(&manager, "id = ?", *user.ManagerID)
-			AssertObjEqual(t, manager, user.Manager, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active")
+			db(unscoped).First(&manager, "id = ?", *user.ManagerID)
+			AssertObjEqual(t, manager, user.Manager, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age",
+				"Birthday", "CompanyID", "ManagerID", "Active")
+			AssertObjEqual(t, manager, expect.Manager, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age",
+				"Birthday", "CompanyID", "ManagerID", "Active")
 		}
 	} else if user.ManagerID != nil {
 		t.Errorf("Manager should not be created for zero value, got: %+v", user.ManagerID)
@ -194,7 +225,8 @@ func CheckUser(t *testing.T, user User, expect User) {
 	})

 	for idx, team := range user.Team {
-		AssertObjEqual(t, team, expect.Team[idx], "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active")
+		AssertObjEqual(t, team, expect.Team[idx], "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age",
+			"Birthday", "CompanyID", "ManagerID", "Active")
 	}
 	})

@ -229,7 +261,34 @@ func CheckUser(t *testing.T, user User, expect User) {
 	})

 	for idx, friend := range user.Friends {
-		AssertObjEqual(t, friend, expect.Friends[idx], "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active")
+		AssertObjEqual(t, friend, expect.Friends[idx], "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age",
+			"Birthday", "CompanyID", "ManagerID", "Active")
 	}
 	})
 }
+
+func tidbSkip(t *testing.T, reason string) {
+	if isTiDB() {
+		t.Skipf("This test case skipped, because of TiDB '%s'", reason)
+	}
+}
+
+func isTiDB() bool {
+	return os.Getenv("GORM_DIALECT") == "tidb"
+}
+
+func isMysql() bool {
+	return os.Getenv("GORM_DIALECT") == "mysql"
+}
+
+func isSqlite() bool {
+	return os.Getenv("GORM_DIALECT") == "sqlite"
+}
+
+func db(unscoped bool) *gorm.DB {
+	if unscoped {
+		return DB.Unscoped()
+	} else {
+		return DB
+	}
+}
@ -466,8 +466,9 @@ type Product4 struct {
|
|||||||
|
|
||||||
type ProductItem struct {
|
type ProductItem struct {
|
||||||
gorm.Model
|
gorm.Model
|
||||||
Code string
|
Code string
|
||||||
Product4ID uint
|
Product4ID uint
|
||||||
|
AfterFindCallTimes int
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pi ProductItem) BeforeCreate(*gorm.DB) error {
|
func (pi ProductItem) BeforeCreate(*gorm.DB) error {
|
||||||
@ -477,6 +478,11 @@ func (pi ProductItem) BeforeCreate(*gorm.DB) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (pi *ProductItem) AfterFind(*gorm.DB) error {
|
||||||
|
pi.AfterFindCallTimes = pi.AfterFindCallTimes + 1
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func TestFailedToSaveAssociationShouldRollback(t *testing.T) {
|
func TestFailedToSaveAssociationShouldRollback(t *testing.T) {
|
||||||
DB.Migrator().DropTable(&Product4{}, &ProductItem{})
|
DB.Migrator().DropTable(&Product4{}, &ProductItem{})
|
||||||
DB.AutoMigrate(&Product4{}, &ProductItem{})
|
DB.AutoMigrate(&Product4{}, &ProductItem{})
|
||||||
@ -498,4 +504,65 @@ func TestFailedToSaveAssociationShouldRollback(t *testing.T) {
|
|||||||
if err := DB.First(&Product4{}, "name = ?", product.Name).Error; err != nil {
|
if err := DB.First(&Product4{}, "name = ?", product.Name).Error; err != nil {
|
||||||
t.Errorf("should find product, but got error %v", err)
|
t.Errorf("should find product, but got error %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var productWithItem Product4
|
||||||
|
if err := DB.Session(&gorm.Session{SkipHooks: true}).Preload("Item").First(&productWithItem, "name = ?", product.Name).Error; err != nil {
|
||||||
|
t.Errorf("should find product, but got error %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if productWithItem.Item.AfterFindCallTimes != 0 {
|
||||||
|
t.Fatalf("AfterFind should not be called times:%d", productWithItem.Item.AfterFindCallTimes)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type Product5 struct {
|
||||||
|
gorm.Model
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
|
||||||
|
var beforeUpdateCall int
|
||||||
|
|
||||||
|
func (p *Product5) BeforeUpdate(*gorm.DB) error {
|
||||||
|
beforeUpdateCall = beforeUpdateCall + 1
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateCallbacks(t *testing.T) {
|
||||||
|
DB.Migrator().DropTable(&Product5{})
|
||||||
|
DB.AutoMigrate(&Product5{})
|
||||||
|
|
||||||
|
p := Product5{Name: "unique_code"}
|
||||||
|
DB.Model(&Product5{}).Create(&p)
|
||||||
|
|
||||||
|
err := DB.Model(&Product5{}).Where("id", p.ID).Update("name", "update_name_1").Error
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("should update success, but got err %v", err)
|
||||||
|
}
|
||||||
|
if beforeUpdateCall != 1 {
|
||||||
|
t.Fatalf("before update should be called")
|
||||||
|
}
|
||||||
|
|
||||||
|
err = DB.Model(Product5{}).Where("id", p.ID).Update("name", "update_name_2").Error
|
||||||
|
if !errors.Is(err, gorm.ErrInvalidValue) {
|
||||||
|
t.Fatalf("should got RecordNotFound, but got %v", err)
|
||||||
|
}
|
||||||
|
if beforeUpdateCall != 1 {
|
||||||
|
t.Fatalf("before update should not be called")
|
||||||
|
}
|
||||||
|
|
||||||
|
err = DB.Model([1]*Product5{&p}).Update("name", "update_name_3").Error
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("should update success, but got err %v", err)
|
||||||
|
}
|
||||||
|
if beforeUpdateCall != 2 {
|
||||||
|
t.Fatalf("before update should be called")
|
||||||
|
}
|
||||||
|
|
||||||
|
err = DB.Model([1]Product5{p}).Update("name", "update_name_4").Error
|
||||||
|
if !errors.Is(err, gorm.ErrInvalidValue) {
|
||||||
|
t.Fatalf("should got RecordNotFound, but got %v", err)
|
||||||
|
}
|
||||||
|
if beforeUpdateCall != 2 {
|
||||||
|
t.Fatalf("before update should not be called")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -229,3 +229,174 @@ func TestJoinWithSoftDeleted(t *testing.T) {
|
|||||||
t.Fatalf("joins NamedPet and Account should not empty:%v", user2)
|
t.Fatalf("joins NamedPet and Account should not empty:%v", user2)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestInnerJoins(t *testing.T) {
|
||||||
|
user := *GetUser("inner-joins-1", Config{Company: true, Manager: true, Account: true, NamedPet: false})
|
||||||
|
|
||||||
|
DB.Create(&user)
|
||||||
|
|
||||||
|
var user2 User
|
||||||
|
var err error
|
||||||
|
err = DB.InnerJoins("Company").InnerJoins("Manager").InnerJoins("Account").First(&user2, "users.name = ?", user.Name).Error
|
||||||
|
AssertEqual(t, err, nil)
|
||||||
|
CheckUser(t, user2, user)
|
||||||
|
|
||||||
|
// inner join and NamedPet is nil
|
||||||
|
err = DB.InnerJoins("NamedPet").InnerJoins("Company").InnerJoins("Manager").InnerJoins("Account").First(&user2, "users.name = ?", user.Name).Error
|
||||||
|
AssertEqual(t, err, gorm.ErrRecordNotFound)
|
||||||
|
|
||||||
|
// mixed inner join and left join
|
||||||
|
var user3 User
|
||||||
|
err = DB.Joins("NamedPet").InnerJoins("Company").InnerJoins("Manager").InnerJoins("Account").First(&user3, "users.name = ?", user.Name).Error
|
||||||
|
AssertEqual(t, err, nil)
|
||||||
|
CheckUser(t, user3, user)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJoinWithSameColumnName(t *testing.T) {
|
||||||
|
user := GetUser("TestJoinWithSameColumnName", Config{
|
||||||
|
Languages: 1,
|
||||||
|
Pets: 1,
|
||||||
|
})
|
||||||
|
DB.Create(user)
|
||||||
|
type UserSpeak struct {
|
||||||
|
UserID uint
|
||||||
|
LanguageCode string
|
||||||
|
}
|
||||||
|
type Result struct {
|
||||||
|
User
|
||||||
|
UserSpeak
|
||||||
|
Language
|
||||||
|
Pet
|
||||||
|
}
|
||||||
|
|
||||||
|
results := make([]Result, 0, 1)
|
||||||
|
DB.Select("users.*, user_speaks.*, languages.*, pets.*").Table("users").Joins("JOIN user_speaks ON user_speaks.user_id = users.id").
|
||||||
|
Joins("JOIN languages ON languages.code = user_speaks.language_code").
|
||||||
|
Joins("LEFT OUTER JOIN pets ON pets.user_id = users.id").Find(&results)
|
||||||
|
|
||||||
|
if len(results) == 0 {
|
||||||
|
t.Fatalf("no record find")
|
||||||
|
} else if results[0].Pet.UserID == nil || *(results[0].Pet.UserID) != user.ID {
|
||||||
|
t.Fatalf("wrong user id in pet")
|
||||||
|
} else if results[0].Pet.Name != user.Pets[0].Name {
|
||||||
|
t.Fatalf("wrong pet name")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJoinArgsWithDB(t *testing.T) {
|
||||||
|
user := *GetUser("joins-args-db", Config{Pets: 2})
|
||||||
|
DB.Save(&user)
|
||||||
|
|
||||||
|
// test where
|
||||||
|
var user1 User
|
||||||
|
onQuery := DB.Where(&Pet{Name: "joins-args-db_pet_2"})
|
||||||
|
if err := DB.Joins("NamedPet", onQuery).Where("users.name = ?", user.Name).First(&user1).Error; err != nil {
|
||||||
|
t.Fatalf("Failed to load with joins on, got error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
AssertEqual(t, user1.NamedPet.Name, "joins-args-db_pet_2")
|
||||||
|
|
||||||
|
// test where and omit
|
||||||
|
onQuery2 := DB.Where(&Pet{Name: "joins-args-db_pet_2"}).Omit("Name")
|
||||||
|
var user2 User
|
||||||
|
if err := DB.Joins("NamedPet", onQuery2).Where("users.name = ?", user.Name).First(&user2).Error; err != nil {
|
||||||
|
t.Fatalf("Failed to load with joins on, got error: %v", err)
|
||||||
|
}
|
||||||
|
AssertEqual(t, user2.NamedPet.ID, user1.NamedPet.ID)
|
||||||
|
AssertEqual(t, user2.NamedPet.Name, "")
|
||||||
|
|
||||||
|
// test where and select
|
||||||
|
onQuery3 := DB.Where(&Pet{Name: "joins-args-db_pet_2"}).Select("Name")
|
||||||
|
var user3 User
|
||||||
|
if err := DB.Joins("NamedPet", onQuery3).Where("users.name = ?", user.Name).First(&user3).Error; err != nil {
|
||||||
|
t.Fatalf("Failed to load with joins on, got error: %v", err)
|
||||||
|
}
|
||||||
|
AssertEqual(t, user3.NamedPet.ID, 0)
|
||||||
|
AssertEqual(t, user3.NamedPet.Name, "joins-args-db_pet_2")
|
||||||
|
|
||||||
|
// test select
|
||||||
|
onQuery4 := DB.Select("ID")
|
||||||
|
var user4 User
|
||||||
|
if err := DB.Joins("NamedPet", onQuery4).Where("users.name = ?", user.Name).First(&user4).Error; err != nil {
|
||||||
|
t.Fatalf("Failed to load with joins on, got error: %v", err)
|
||||||
|
}
|
||||||
|
if user4.NamedPet.ID == 0 {
|
||||||
|
t.Fatal("Pet ID can not be empty")
|
||||||
|
}
|
||||||
|
AssertEqual(t, user4.NamedPet.Name, "")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNestedJoins(t *testing.T) {
|
||||||
|
users := []User{
|
||||||
|
{
|
||||||
|
Name: "nested-joins-1",
|
||||||
|
Manager: &User{
|
||||||
|
Name: "nested-joins-manager-1",
|
||||||
|
Company: Company{
|
||||||
|
Name: "nested-joins-manager-company-1",
|
||||||
|
},
|
||||||
|
NamedPet: &Pet{
|
||||||
|
Name: "nested-joins-manager-namepet-1",
|
||||||
|
Toy: Toy{
|
||||||
|
Name: "nested-joins-manager-namepet-toy-1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
NamedPet: &Pet{Name: "nested-joins-namepet-1", Toy: Toy{Name: "nested-joins-namepet-toy-1"}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "nested-joins-2",
|
||||||
|
Manager: GetUser("nested-joins-manager-2", Config{Company: true, NamedPet: true}),
|
||||||
|
NamedPet: &Pet{Name: "nested-joins-namepet-2", Toy: Toy{Name: "nested-joins-namepet-toy-2"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Create(&users)
|
||||||
|
|
||||||
|
var userIDs []uint
|
||||||
|
for _, user := range users {
|
||||||
|
userIDs = append(userIDs, user.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
var users2 []User
|
||||||
|
if err := DB.
|
||||||
|
Joins("Manager").
|
||||||
|
Joins("Manager.Company").
|
||||||
|
Joins("Manager.NamedPet").
|
||||||
|
Joins("Manager.NamedPet.Toy").
|
||||||
|
Joins("NamedPet").
|
||||||
|
Joins("NamedPet.Toy").
|
||||||
|
Find(&users2, "users.id IN ?", userIDs).Error; err != nil {
|
||||||
|
t.Fatalf("Failed to load with joins, got error: %v", err)
|
||||||
|
} else if len(users2) != len(users) {
|
||||||
|
t.Fatalf("Failed to load join users, got: %v, expect: %v", len(users2), len(users))
|
||||||
|
}
|
||||||
|
|
||||||
|
sort.Slice(users2, func(i, j int) bool {
|
||||||
|
return users2[i].ID > users2[j].ID
|
||||||
|
})
|
||||||
|
|
||||||
|
sort.Slice(users, func(i, j int) bool {
|
||||||
|
return users[i].ID > users[j].ID
|
||||||
|
})
|
||||||
|
|
||||||
|
for idx, user := range users {
|
||||||
|
// user
|
||||||
|
CheckUser(t, user, users2[idx])
|
||||||
|
if users2[idx].Manager == nil {
|
||||||
|
t.Fatalf("Failed to load Manager")
|
||||||
|
}
|
||||||
|
// manager
|
||||||
|
CheckUser(t, *user.Manager, *users2[idx].Manager)
|
||||||
|
// user pet
|
||||||
|
if users2[idx].NamedPet == nil {
|
||||||
|
t.Fatalf("Failed to load NamedPet")
|
||||||
|
}
|
||||||
|
CheckPet(t, *user.NamedPet, *users2[idx].NamedPet)
|
||||||
|
// manager pet
|
||||||
|
if users2[idx].Manager.NamedPet == nil {
|
||||||
|
t.Fatalf("Failed to load NamedPet")
|
||||||
|
}
|
||||||
|
CheckPet(t, *user.Manager.NamedPet, *users2[idx].Manager.NamedPet)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@ -7,8 +7,60 @@ import (
|
|||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/lib/pq"
|
"github.com/lib/pq"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
|
"gorm.io/gorm/clause"
|
||||||
|
. "gorm.io/gorm/utils/tests"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func TestPostgresReturningIDWhichHasStringType(t *testing.T) {
|
||||||
|
if DB.Dialector.Name() != "postgres" {
|
||||||
|
t.Skip()
|
||||||
|
}
|
||||||
|
|
||||||
|
type Yasuo struct {
|
||||||
|
ID string `gorm:"default:gen_random_uuid()"`
|
||||||
|
Name string
|
||||||
|
CreatedAt time.Time `gorm:"type:TIMESTAMP WITHOUT TIME ZONE"`
|
||||||
|
UpdatedAt time.Time `gorm:"type:TIMESTAMP WITHOUT TIME ZONE;default:current_timestamp"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.Exec("CREATE EXTENSION IF NOT EXISTS pgcrypto;").Error; err != nil {
|
||||||
|
t.Errorf("Failed to create extension pgcrypto, got error %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Migrator().DropTable(&Yasuo{})
|
||||||
|
|
||||||
|
if err := DB.AutoMigrate(&Yasuo{}); err != nil {
|
||||||
|
t.Fatalf("Failed to migrate for uuid default value, got error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
yasuo := Yasuo{Name: "jinzhu"}
|
||||||
|
if err := DB.Create(&yasuo).Error; err != nil {
|
||||||
|
t.Fatalf("should be able to create data, but got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if yasuo.ID == "" {
|
||||||
|
t.Fatal("should be able to has ID, but got zero value")
|
||||||
|
}
|
||||||
|
|
||||||
|
var result Yasuo
|
||||||
|
if err := DB.First(&result, "id = ?", yasuo.ID).Error; err != nil || yasuo.Name != "jinzhu" {
|
||||||
|
t.Errorf("No error should happen, but got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.Where("id = $1", yasuo.ID).First(&Yasuo{}).Error; err != nil || yasuo.Name != "jinzhu" {
|
||||||
|
t.Errorf("No error should happen, but got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
yasuo.Name = "jinzhu1"
|
||||||
|
if err := DB.Save(&yasuo).Error; err != nil {
|
||||||
|
t.Errorf("Failed to update date, got error %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.First(&result, "id = ?", yasuo.ID).Error; err != nil || yasuo.Name != "jinzhu1" {
|
||||||
|
t.Errorf("No error should happen, but got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestPostgres(t *testing.T) {
|
func TestPostgres(t *testing.T) {
|
||||||
if DB.Dialector.Name() != "postgres" {
|
if DB.Dialector.Name() != "postgres" {
|
||||||
t.Skip()
|
t.Skip()
|
||||||
@ -60,16 +112,55 @@ func TestPostgres(t *testing.T) {
|
|||||||
if err := DB.First(&result, "id = ?", harumph.ID).Error; err != nil || harumph.Name != "jinzhu1" {
|
if err := DB.First(&result, "id = ?", harumph.ID).Error; err != nil || harumph.Name != "jinzhu1" {
|
||||||
t.Errorf("No error should happen, but got %v", err)
|
t.Errorf("No error should happen, but got %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
DB.Migrator().DropTable("log_usage")
|
||||||
|
|
||||||
|
if err := DB.Exec(`
|
||||||
|
CREATE TABLE public.log_usage (
|
||||||
|
log_id bigint NOT NULL
|
||||||
|
);
|
||||||
|
|
||||||
|
ALTER TABLE public.log_usage ALTER COLUMN log_id ADD GENERATED BY DEFAULT AS IDENTITY (
|
||||||
|
SEQUENCE NAME public.log_usage_log_id_seq
|
||||||
|
START WITH 1
|
||||||
|
INCREMENT BY 1
|
||||||
|
NO MINVALUE
|
||||||
|
NO MAXVALUE
|
||||||
|
CACHE 1
|
||||||
|
);
|
||||||
|
`).Error; err != nil {
|
||||||
|
t.Fatalf("failed to create table, got error %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
columns, err := DB.Migrator().ColumnTypes("log_usage")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to get columns, got error %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
hasLogID := false
|
||||||
|
for _, column := range columns {
|
||||||
|
if column.Name() == "log_id" {
|
||||||
|
hasLogID = true
|
||||||
|
autoIncrement, ok := column.AutoIncrement()
|
||||||
|
if !ok || !autoIncrement {
|
||||||
|
t.Fatalf("column log_id should be auto incrementment")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !hasLogID {
|
||||||
|
t.Fatalf("failed to found column log_id")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type Post struct {
|
type Post struct {
|
||||||
ID uuid.UUID `gorm:"primary_key;type:uuid;default:uuid_generate_v4();autoincrement"`
|
ID uuid.UUID `gorm:"primary_key;type:uuid;default:uuid_generate_v4();"`
|
||||||
Title string
|
Title string
|
||||||
Categories []*Category `gorm:"Many2Many:post_categories"`
|
Categories []*Category `gorm:"Many2Many:post_categories"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type Category struct {
|
type Category struct {
|
||||||
ID uuid.UUID `gorm:"primary_key;type:uuid;default:uuid_generate_v4();autoincrement"`
|
ID uuid.UUID `gorm:"primary_key;type:uuid;default:uuid_generate_v4();"`
|
||||||
Title string
|
Title string
|
||||||
Posts []*Post `gorm:"Many2Many:post_categories"`
|
Posts []*Post `gorm:"Many2Many:post_categories"`
|
||||||
}
|
}
|
||||||
@ -98,3 +189,68 @@ func TestMany2ManyWithDefaultValueUUID(t *testing.T) {
|
|||||||
t.Errorf("Failed, got error: %v", err)
|
t.Errorf("Failed, got error: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestPostgresOnConstraint(t *testing.T) {
|
||||||
|
if DB.Dialector.Name() != "postgres" {
|
||||||
|
t.Skip()
|
||||||
|
}
|
||||||
|
|
||||||
|
type Thing struct {
|
||||||
|
gorm.Model
|
||||||
|
SomeID string
|
||||||
|
OtherID string
|
||||||
|
Data string
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Migrator().DropTable(&Thing{})
|
||||||
|
DB.Migrator().CreateTable(&Thing{})
|
||||||
|
if err := DB.Exec("ALTER TABLE things ADD CONSTRAINT some_id_other_id_unique UNIQUE (some_id, other_id)").Error; err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
thing := Thing{
|
||||||
|
SomeID: "1234",
|
||||||
|
OtherID: "1234",
|
||||||
|
Data: "something",
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Create(&thing)
|
||||||
|
|
||||||
|
thing2 := Thing{
|
||||||
|
SomeID: "1234",
|
||||||
|
OtherID: "1234",
|
||||||
|
Data: "something else",
|
||||||
|
}
|
||||||
|
|
||||||
|
result := DB.Clauses(clause.OnConflict{
|
||||||
|
OnConstraint: "some_id_other_id_unique",
|
||||||
|
UpdateAll: true,
|
||||||
|
}).Create(&thing2)
|
||||||
|
if result.Error != nil {
|
||||||
|
t.Errorf("creating second thing: %v", result.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
var things []Thing
|
||||||
|
if err := DB.Find(&things).Error; err != nil {
|
||||||
|
t.Errorf("Failed, got error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(things) > 1 {
|
||||||
|
t.Errorf("expected 1 thing got more")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type CompanyNew struct {
|
||||||
|
ID int
|
||||||
|
Name int
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAlterColumnDataType(t *testing.T) {
|
||||||
|
DB.AutoMigrate(Company{})
|
||||||
|
|
||||||
|
if err := DB.Table("companies").Migrator().AlterColumn(CompanyNew{}, "name"); err != nil {
|
||||||
|
t.Fatalf("failed to alter column from string to int, got error %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.AutoMigrate(Company{})
|
||||||
|
}
|
||||||
|
|||||||
@ -8,6 +8,8 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"gorm.io/gorm/clause"
|
"gorm.io/gorm/clause"
|
||||||
. "gorm.io/gorm/utils/tests"
|
. "gorm.io/gorm/utils/tests"
|
||||||
@@ -269,3 +271,242 @@ func TestPreloadWithDiffModel(t *testing.T) {
    CheckUser(t, user, result.User)
}

func TestNestedPreloadWithUnscoped(t *testing.T) {
    user := *GetUser("nested_preload", Config{Pets: 1})
    pet := user.Pets[0]
    pet.Toy = Toy{Name: "toy_nested_preload_" + strconv.Itoa(1)}
    pet.Toy = Toy{Name: "toy_nested_preload_" + strconv.Itoa(2)}

    if err := DB.Create(&user).Error; err != nil {
        t.Fatalf("errors happened when create: %v", err)
    }

    var user2 User
    DB.Preload("Pets.Toy").Find(&user2, "id = ?", user.ID)
    CheckUser(t, user2, user)

    DB.Delete(&pet)

    var user3 User
    DB.Preload(clause.Associations+"."+clause.Associations).Find(&user3, "id = ?", user.ID)
    if len(user3.Pets) != 0 {
        t.Fatalf("User.Pet[0] was deleted and should not exist.")
    }

    var user4 *User
    DB.Preload("Pets.Toy").Find(&user4, "id = ?", user.ID)
    if len(user4.Pets) != 0 {
        t.Fatalf("User.Pet[0] was deleted and should not exist.")
    }

    var user5 User
    DB.Unscoped().Preload(clause.Associations+"."+clause.Associations).Find(&user5, "id = ?", user.ID)
    CheckUserUnscoped(t, user5, user)

    var user6 *User
    DB.Unscoped().Preload("Pets.Toy").Find(&user6, "id = ?", user.ID)
    CheckUserUnscoped(t, *user6, user)
}

func TestNestedPreloadWithNestedJoin(t *testing.T) {
    type (
        Preload struct {
            ID       uint
            Value    string
            NestedID uint
        }
        Join struct {
            ID       uint
            Value    string
            NestedID uint
        }
        Nested struct {
            ID       uint
            Preloads []*Preload
            Join     Join
            ValueID  uint
        }
        Value struct {
            ID     uint
            Name   string
            Nested Nested
        }
    )

    DB.Migrator().DropTable(&Preload{}, &Join{}, &Nested{}, &Value{})
    DB.Migrator().AutoMigrate(&Preload{}, &Join{}, &Nested{}, &Value{})

    value := Value{
        Name: "value",
        Nested: Nested{
            Preloads: []*Preload{
                {Value: "p1"}, {Value: "p2"},
            },
            Join: Join{Value: "j1"},
        },
    }
    if err := DB.Create(&value).Error; err != nil {
        t.Errorf("failed to create value, got err: %v", err)
    }

    var find1 Value
    err := DB.Joins("Nested").Joins("Nested.Join").Preload("Nested.Preloads").First(&find1).Error
    if err != nil {
        t.Errorf("failed to find value, got err: %v", err)
    }
    AssertEqual(t, find1, value)

    var find2 Value
    // Joins will automatically add Nested queries.
    err = DB.Joins("Nested.Join").Preload("Nested.Preloads").First(&find2).Error
    if err != nil {
        t.Errorf("failed to find value, got err: %v", err)
    }
    AssertEqual(t, find2, value)

    var finds []Value
    err = DB.Joins("Nested.Join").Joins("Nested").Preload("Nested.Preloads").Find(&finds).Error
    if err != nil {
        t.Errorf("failed to find value, got err: %v", err)
    }
    require.Len(t, finds, 1)
    AssertEqual(t, finds[0], value)
}

func TestEmbedPreload(t *testing.T) {
    type Country struct {
        ID   int `gorm:"primaryKey"`
        Name string
    }
    type EmbeddedAddress struct {
        ID        int
        Name      string
        CountryID *int
        Country   *Country
    }
    type NestedAddress struct {
        EmbeddedAddress
    }
    type Org struct {
        ID              int
        PostalAddress   EmbeddedAddress `gorm:"embedded;embeddedPrefix:postal_address_"`
        VisitingAddress EmbeddedAddress `gorm:"embedded;embeddedPrefix:visiting_address_"`
        AddressID       *int
        Address         *EmbeddedAddress
        NestedAddress   NestedAddress `gorm:"embedded;embeddedPrefix:nested_address_"`
    }

    DB.Migrator().DropTable(&Org{}, &EmbeddedAddress{}, &Country{})
    DB.AutoMigrate(&Org{}, &EmbeddedAddress{}, &Country{})

    org := Org{
        PostalAddress:   EmbeddedAddress{Name: "a1", Country: &Country{Name: "c1"}},
        VisitingAddress: EmbeddedAddress{Name: "a2", Country: &Country{Name: "c2"}},
        Address:         &EmbeddedAddress{Name: "a3", Country: &Country{Name: "c3"}},
        NestedAddress: NestedAddress{
            EmbeddedAddress: EmbeddedAddress{Name: "a4", Country: &Country{Name: "c4"}},
        },
    }
    if err := DB.Create(&org).Error; err != nil {
        t.Errorf("failed to create org, got err: %v", err)
    }

    tests := []struct {
        name     string
        preloads map[string][]interface{}
        expect   Org
    }{
        {
            name:     "address country",
            preloads: map[string][]interface{}{"Address.Country": {}},
            expect: Org{
                ID: org.ID,
                PostalAddress: EmbeddedAddress{
                    ID:        org.PostalAddress.ID,
                    Name:      org.PostalAddress.Name,
                    CountryID: org.PostalAddress.CountryID,
                    Country:   nil,
                },
                VisitingAddress: EmbeddedAddress{
                    ID:        org.VisitingAddress.ID,
                    Name:      org.VisitingAddress.Name,
                    CountryID: org.VisitingAddress.CountryID,
                    Country:   nil,
                },
                AddressID: org.AddressID,
                Address:   org.Address,
                NestedAddress: NestedAddress{EmbeddedAddress{
                    ID:        org.NestedAddress.ID,
                    Name:      org.NestedAddress.Name,
                    CountryID: org.NestedAddress.CountryID,
                    Country:   nil,
                }},
            },
        }, {
            name:     "postal address country",
            preloads: map[string][]interface{}{"PostalAddress.Country": {}},
            expect: Org{
                ID:            org.ID,
                PostalAddress: org.PostalAddress,
                VisitingAddress: EmbeddedAddress{
                    ID:        org.VisitingAddress.ID,
                    Name:      org.VisitingAddress.Name,
                    CountryID: org.VisitingAddress.CountryID,
                    Country:   nil,
                },
                AddressID: org.AddressID,
                Address:   nil,
                NestedAddress: NestedAddress{EmbeddedAddress{
                    ID:        org.NestedAddress.ID,
                    Name:      org.NestedAddress.Name,
                    CountryID: org.NestedAddress.CountryID,
                    Country:   nil,
                }},
            },
        }, {
            name:     "nested address country",
            preloads: map[string][]interface{}{"NestedAddress.EmbeddedAddress.Country": {}},
            expect: Org{
                ID: org.ID,
                PostalAddress: EmbeddedAddress{
                    ID:        org.PostalAddress.ID,
                    Name:      org.PostalAddress.Name,
                    CountryID: org.PostalAddress.CountryID,
                    Country:   nil,
                },
                VisitingAddress: EmbeddedAddress{
                    ID:        org.VisitingAddress.ID,
                    Name:      org.VisitingAddress.Name,
                    CountryID: org.VisitingAddress.CountryID,
                    Country:   nil,
                },
                AddressID:     org.AddressID,
                Address:       nil,
                NestedAddress: org.NestedAddress,
            },
        }, {
            name: "associations",
            preloads: map[string][]interface{}{
                clause.Associations: {},
                // clause.Associations won’t preload nested associations
                "Address.Country": {},
            },
            expect: org,
        },
    }

    for _, test := range tests {
        t.Run(test.name, func(t *testing.T) {
            actual := Org{}
            tx := DB.Where("id = ?", org.ID).Session(&gorm.Session{})
            for name, args := range test.preloads {
                tx = tx.Preload(name, args...)
            }
            if err := tx.Find(&actual).Error; err != nil {
                t.Errorf("failed to find org, got err: %v", err)
            }
            AssertEqual(t, actual, test.expect)
        })
    }
}
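The preload tests added here combine dotted paths ("Pets.Toy"), the clause.Associations constant, and Joins; as the comment in TestEmbedPreload notes, clause.Associations does not descend into nested associations. A compact, illustrative sketch of the two preload forms, assuming the User/Pet/Toy models this suite defines (the helper is not part of the diff):

    package tests_test

    import (
        "gorm.io/gorm"
        "gorm.io/gorm/clause"
    )

    // loadUsers contrasts the two preload styles used in the tests above.
    func loadUsers(db *gorm.DB) (direct, nested []User, err error) {
        // clause.Associations preloads every direct association, one level deep.
        if err = db.Preload(clause.Associations).Find(&direct).Error; err != nil {
            return
        }
        // A dotted path reaches a nested association explicitly.
        err = db.Preload("Pets.Toy").Find(&nested).Error
        return
    }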
@@ -2,6 +2,8 @@ package tests_test

import (
    "context"
    "errors"
    "sync"
    "testing"
    "time"
@@ -88,3 +90,80 @@ func TestPreparedStmtFromTransaction(t *testing.T) {
    }
    tx2.Commit()
}

func TestPreparedStmtDeadlock(t *testing.T) {
    tx, err := OpenTestConnection(&gorm.Config{})
    AssertEqual(t, err, nil)

    sqlDB, _ := tx.DB()
    sqlDB.SetMaxOpenConns(1)

    tx = tx.Session(&gorm.Session{PrepareStmt: true})

    wg := sync.WaitGroup{}
    for i := 0; i < 100; i++ {
        wg.Add(1)
        go func() {
            user := User{Name: "jinzhu"}
            tx.Create(&user)

            var result User
            tx.First(&result)
            wg.Done()
        }()
    }
    wg.Wait()

    conn, ok := tx.ConnPool.(*gorm.PreparedStmtDB)
    AssertEqual(t, ok, true)
    AssertEqual(t, len(conn.Stmts), 2)
    for _, stmt := range conn.Stmts {
        if stmt == nil {
            t.Fatalf("stmt cannot bee nil")
        }
    }

    AssertEqual(t, sqlDB.Stats().InUse, 0)
}

func TestPreparedStmtInTransaction(t *testing.T) {
    user := User{Name: "jinzhu"}

    if err := DB.Transaction(func(tx *gorm.DB) error {
        tx.Session(&gorm.Session{PrepareStmt: true}).Create(&user)
        return errors.New("test")
    }); err == nil {
        t.Error(err)
    }

    var result User
    if err := DB.First(&result, user.ID).Error; err == nil {
        t.Errorf("Failed, got error: %v", err)
    }
}

func TestPreparedStmtReset(t *testing.T) {
    tx := DB.Session(&gorm.Session{PrepareStmt: true})

    user := *GetUser("prepared_stmt_reset", Config{})
    tx = tx.Create(&user)

    pdb, ok := tx.ConnPool.(*gorm.PreparedStmtDB)
    if !ok {
        t.Fatalf("should assign PreparedStatement Manager back to database when using PrepareStmt mode")
    }

    pdb.Mux.Lock()
    if len(pdb.Stmts) == 0 {
        pdb.Mux.Unlock()
        t.Fatalf("prepared stmt can not be empty")
    }
    pdb.Mux.Unlock()

    pdb.Reset()
    pdb.Mux.Lock()
    defer pdb.Mux.Unlock()
    if len(pdb.Stmts) != 0 {
        t.Fatalf("prepared stmt should be empty")
    }
}
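The deadlock test above deliberately caps the pool at a single connection to make statement reuse observable. Outside of tests, the prepared-statement cache is normally switched on once, either globally via gorm.Config or per unit of work via a Session. A brief sketch under those assumptions (the sqlite dialector and DSN are placeholders, not taken from the diff):

    package main

    import (
        "gorm.io/driver/sqlite"
        "gorm.io/gorm"
    )

    func main() {
        // Cache prepared statements for every query issued on this *gorm.DB.
        db, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{PrepareStmt: true})
        if err != nil {
            panic(err)
        }

        // Or opt in for a single session only.
        tx := db.Session(&gorm.Session{PrepareStmt: true})
        _ = tx
    }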
@@ -2,6 +2,7 @@ package tests_test

import (
    "database/sql"
    "database/sql/driver"
    "fmt"
    "reflect"
    "regexp"
@@ -216,6 +217,30 @@ func TestFind(t *testing.T) {
        }
    }

    // test array
    var models2 [3]User
    if err := DB.Where("name in (?)", []string{"find"}).Find(&models2).Error; err != nil {
        t.Errorf("errors happened when query find with in clause: %v, length: %v", err, len(models2))
    } else {
        for idx, user := range users {
            t.Run("FindWithInClause#"+strconv.Itoa(idx+1), func(t *testing.T) {
                CheckUser(t, models2[idx], user)
            })
        }
    }

    // test smaller array
    var models3 [2]User
    if err := DB.Where("name in (?)", []string{"find"}).Find(&models3).Error; err != nil {
        t.Errorf("errors happened when query find with in clause: %v, length: %v", err, len(models3))
    } else {
        for idx, user := range users[:2] {
            t.Run("FindWithInClause#"+strconv.Itoa(idx+1), func(t *testing.T) {
                CheckUser(t, models3[idx], user)
            })
        }
    }

    var none []User
    if err := DB.Where("name in (?)", []string{}).Find(&none).Error; err != nil || len(none) != 0 {
        t.Errorf("errors happened when query find with in clause and zero length parameter: %v, length: %v", err, len(none))
@@ -257,7 +282,7 @@ func TestFindInBatches(t *testing.T) {
        totalBatch int
    )

-   if result := DB.Where("name = ?", users[0].Name).FindInBatches(&results, 2, func(tx *gorm.DB, batch int) error {
+   if result := DB.Table("users as u").Where("name = ?", users[0].Name).FindInBatches(&results, 2, func(tx *gorm.DB, batch int) error {
        totalBatch += batch

        if tx.RowsAffected != 2 {
@@ -273,7 +298,7 @@ func TestFindInBatches(t *testing.T) {
        }

        if err := tx.Save(results).Error; err != nil {
-           t.Errorf("failed to save users, got error %v", err)
+           t.Fatalf("failed to save users, got error %v", err)
        }

        return nil
@@ -384,6 +409,13 @@ func TestFindInBatchesWithError(t *testing.T) {
    if totalBatch != 0 {
        t.Fatalf("incorrect total batch, expected: %v, got: %v", 0, totalBatch)
    }

    if result := DB.Omit("id").Where("name = ?", users[0].Name).FindInBatches(&results, 2, func(tx *gorm.DB, batch int) error {
        totalBatch += batch
        return nil
    }); result.Error != gorm.ErrPrimaryKeyRequired {
        t.Fatal("expected errors to have occurred, but nothing happened")
    }
}

func TestFillSmallerStruct(t *testing.T) {
@@ -522,6 +554,11 @@ func TestNot(t *testing.T) {
    if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*users.*..*name.* <> .+ AND .*users.*..*age.* <> .+").MatchString(result.Statement.SQL.String()) {
        t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String())
    }

    result = dryDB.Not(DB.Where("manager IS NULL").Where("age >= ?", 20)).Find(&User{})
    if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE NOT \\(manager IS NULL AND age >= .+\\) AND .users.\\..deleted_at. IS NULL").MatchString(result.Statement.SQL.String()) {
        t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String())
    }
}

func TestNotWithAllFields(t *testing.T) {
@@ -627,6 +664,18 @@ func TestOrWithAllFields(t *testing.T) {
    }
}

type Int64 int64

func (v Int64) Value() (driver.Value, error) {
    return v - 1, nil
}

func (f *Int64) Scan(v interface{}) error {
    y := v.(int64)
    *f = Int64(y + 1)
    return nil
}

func TestPluck(t *testing.T) {
    users := []*User{
        GetUser("pluck-user1", Config{}),
@@ -654,6 +703,11 @@ func TestPluck(t *testing.T) {
        t.Errorf("got error when pluck id: %v", err)
    }

    var ids2 []Int64
    if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Pluck("id", &ids2).Error; err != nil {
        t.Errorf("got error when pluck id: %v", err)
    }

    for idx, name := range names {
        if name != users[idx].Name {
            t.Errorf("Unexpected result on pluck name, got %+v", names)
@@ -666,6 +720,12 @@ func TestPluck(t *testing.T) {
        }
    }

    for idx, id := range ids2 {
        if int(id) != int(users[idx].ID+1) {
            t.Errorf("Unexpected result on pluck id, got %+v", ids)
        }
    }

    var times []time.Time
    if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Pluck("created_at", &times).Error; err != nil {
        t.Errorf("got error when pluck time: %v", err)
@@ -1063,12 +1123,12 @@ func TestSearchWithStruct(t *testing.T) {
    }

    result = dryRunDB.Where(User{Name: "jinzhu", Age: 18}).Find(&User{})
-   if !regexp.MustCompile(`WHERE .users.\..name. = .{1,3} AND .users.\..age. = .{1,3} AND .users.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) {
+   if !regexp.MustCompile(`WHERE \(.users.\..name. = .{1,3} AND .users.\..age. = .{1,3}\) AND .users.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) {
        t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String())
    }

    result = dryRunDB.Where(User{Name: "jinzhu"}, "name", "Age").Find(&User{})
-   if !regexp.MustCompile(`WHERE .users.\..name. = .{1,3} AND .users.\..age. = .{1,3} AND .users.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) {
+   if !regexp.MustCompile(`WHERE \(.users.\..name. = .{1,3} AND .users.\..age. = .{1,3}\) AND .users.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) {
        t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String())
    }
@@ -1258,3 +1318,113 @@ func TestQueryScannerWithSingleColumn(t *testing.T) {

    AssertEqual(t, result2.data, 20)
}

func TestQueryResetNullValue(t *testing.T) {
    type QueryResetItem struct {
        ID   string `gorm:"type:varchar(5)"`
        Name string
    }

    type QueryResetNullValue struct {
        ID      int
        Name    string     `gorm:"default:NULL"`
        Flag    bool       `gorm:"default:NULL"`
        Number1 int64      `gorm:"default:NULL"`
        Number2 uint64     `gorm:"default:NULL"`
        Number3 float64    `gorm:"default:NULL"`
        Now     *time.Time `gorm:"defalut:NULL"`
        Item1Id string
        Item1   *QueryResetItem `gorm:"references:ID"`
        Item2Id string
        Item2   *QueryResetItem `gorm:"references:ID"`
    }

    DB.Migrator().DropTable(&QueryResetNullValue{}, &QueryResetItem{})
    DB.AutoMigrate(&QueryResetNullValue{}, &QueryResetItem{})

    now := time.Now()
    q1 := QueryResetNullValue{
        Name:    "name",
        Flag:    true,
        Number1: 100,
        Number2: 200,
        Number3: 300.1,
        Now:     &now,
        Item1: &QueryResetItem{
            ID:   "u_1_1",
            Name: "item_1_1",
        },
        Item2: &QueryResetItem{
            ID:   "u_1_2",
            Name: "item_1_2",
        },
    }

    q2 := QueryResetNullValue{
        Item1: &QueryResetItem{
            ID:   "u_2_1",
            Name: "item_2_1",
        },
        Item2: &QueryResetItem{
            ID:   "u_2_2",
            Name: "item_2_2",
        },
    }

    var err error
    err = DB.Create(&q1).Error
    if err != nil {
        t.Errorf("failed to create:%v", err)
    }

    err = DB.Create(&q2).Error
    if err != nil {
        t.Errorf("failed to create:%v", err)
    }

    var qs []QueryResetNullValue
    err = DB.Joins("Item1").Joins("Item2").Find(&qs).Error
    if err != nil {
        t.Errorf("failed to find:%v", err)
    }

    if len(qs) != 2 {
        t.Fatalf("find count not equal:%d", len(qs))
    }

    AssertEqual(t, q1, qs[0])
    AssertEqual(t, q2, qs[1])
}

func TestQueryError(t *testing.T) {
    type P struct{}
    var p1 P
    err := DB.Take(&p1, 1).Error
    AssertEqual(t, err, gorm.ErrModelAccessibleFieldsRequired)

    var p2 interface{}

    err = DB.Table("ps").Clauses(clause.Eq{Column: clause.Column{
        Table: clause.CurrentTable, Name: clause.PrimaryKey,
    }, Value: 1}).Scan(&p2).Error
    AssertEqual(t, err, gorm.ErrModelValueRequired)
}

func TestQueryScanToArray(t *testing.T) {
    err := DB.Create(&User{Name: "testname1", Age: 10}).Error
    if err != nil {
        t.Fatal(err)
    }

    users := [2]*User{{Name: "1"}, {Name: "2"}}
    err = DB.Model(&User{}).Where("name = ?", "testname1").Find(&users).Error
    if err != nil {
        t.Fatal(err)
    }
    if users[0] == nil || users[0].Name != "testname1" {
        t.Error("users[0] not covere")
    }
    if users[1] != nil {
        t.Error("users[1] should be empty")
    }
}
@@ -214,4 +214,29 @@ func TestScanToEmbedded(t *testing.T) {
    if !addressMatched {
        t.Errorf("Failed, no address matched")
    }

    personDupField := Person{ID: person1.ID}
    if err := DB.Select("people.id, people.*").
        First(&personDupField).Error; err != nil {
        t.Errorf("Failed to run join query, got error: %v", err)
    }
    AssertEqual(t, person1, personDupField)

    user := User{
        Name: "TestScanToEmbedded_1",
        Manager: &User{
            Name:    "TestScanToEmbedded_1_m1",
            Manager: &User{Name: "TestScanToEmbedded_1_m1_m1"},
        },
    }
    DB.Create(&user)

    type UserScan struct {
        ID        uint
        Name      string
        ManagerID *uint
    }
    var user2 UserScan
    err := DB.Raw("SELECT * FROM users INNER JOIN users Manager ON users.manager_id = Manager.id WHERE users.id = ?", user.ID).Scan(&user2).Error
    AssertEqual(t, err, nil)
}
@@ -170,10 +170,10 @@ func (data *EncryptedData) Scan(value interface{}) error {
            return errors.New("Too short")
        }

-       *data = b[3:]
+       *data = append((*data)[0:], b[3:]...)
        return nil
    } else if s, ok := value.(string); ok {
-       *data = []byte(s)[3:]
+       *data = []byte(s[3:])
        return nil
    }
@@ -72,3 +72,58 @@ func TestScopes(t *testing.T) {
        t.Errorf("select max(id)")
    }
}

func TestComplexScopes(t *testing.T) {
    tests := []struct {
        name     string
        queryFn  func(tx *gorm.DB) *gorm.DB
        expected string
    }{
        {
            name: "depth_1",
            queryFn: func(tx *gorm.DB) *gorm.DB {
                return tx.Scopes(
                    func(d *gorm.DB) *gorm.DB { return d.Where("a = 1") },
                    func(d *gorm.DB) *gorm.DB {
                        return d.Where(DB.Or("b = 2").Or("c = 3"))
                    },
                ).Find(&Language{})
            },
            expected: `SELECT * FROM "languages" WHERE a = 1 AND (b = 2 OR c = 3)`,
        }, {
            name: "depth_1_pre_cond",
            queryFn: func(tx *gorm.DB) *gorm.DB {
                return tx.Where("z = 0").Scopes(
                    func(d *gorm.DB) *gorm.DB { return d.Where("a = 1") },
                    func(d *gorm.DB) *gorm.DB {
                        return d.Or(DB.Where("b = 2").Or("c = 3"))
                    },
                ).Find(&Language{})
            },
            expected: `SELECT * FROM "languages" WHERE z = 0 AND a = 1 OR (b = 2 OR c = 3)`,
        }, {
            name: "depth_2",
            queryFn: func(tx *gorm.DB) *gorm.DB {
                return tx.Scopes(
                    func(d *gorm.DB) *gorm.DB { return d.Model(&Language{}) },
                    func(d *gorm.DB) *gorm.DB {
                        return d.
                            Or(DB.Scopes(
                                func(d *gorm.DB) *gorm.DB { return d.Where("a = 1") },
                                func(d *gorm.DB) *gorm.DB { return d.Where("b = 2") },
                            )).
                            Or("c = 3")
                    },
                    func(d *gorm.DB) *gorm.DB { return d.Where("d = 4") },
                ).Find(&Language{})
            },
            expected: `SELECT * FROM "languages" WHERE d = 4 OR c = 3 OR (a = 1 AND b = 2)`,
        },
    }

    for _, test := range tests {
        t.Run(test.name, func(t *testing.T) {
            assertEqualSQL(t, test.expected, DB.ToSQL(test.queryFn))
        })
    }
}
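TestComplexScopes builds its conditions from inline closures; in application code the same Scopes mechanism is usually wrapped in named helpers that can be chained and reused. An illustrative sketch (the User model is assumed from this suite; the helpers below are not part of the diff):

    package tests_test

    import "gorm.io/gorm"

    // A scope is any func(*gorm.DB) *gorm.DB, so it composes with Where/Or
    // exactly like the anonymous functions in the test above.
    func active(d *gorm.DB) *gorm.DB { return d.Where("active = ?", true) }

    func olderThan(age int) func(*gorm.DB) *gorm.DB {
        return func(d *gorm.DB) *gorm.DB { return d.Where("age > ?", age) }
    }

    func findActiveSeniors(db *gorm.DB, dst *[]User) error {
        return db.Scopes(active, olderThan(60)).Find(dst).Error
    }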
@@ -16,13 +16,40 @@ import (

 type SerializerStruct struct {
    gorm.Model
    Name  []byte `gorm:"json"`
    Roles Roles  `gorm:"serializer:json"`
-   Contracts   map[string]interface{} `gorm:"serializer:json"`
-   JobInfo     Job                    `gorm:"type:bytes;serializer:gob"`
-   CreatedTime int64                  `gorm:"serializer:unixtime;type:time"` // store time in db, use int as field type
-   UpdatedTime *int64                 `gorm:"serializer:unixtime;type:time"` // store time in db, use int as field type
-   EncryptedString EncryptedString
+   Roles2                 *Roles                 `gorm:"serializer:json"`
+   Roles3                 *Roles                 `gorm:"serializer:json;not null"`
+   Contracts              map[string]interface{} `gorm:"serializer:json"`
+   JobInfo                Job                    `gorm:"type:bytes;serializer:gob"`
+   CreatedTime            int64                  `gorm:"serializer:unixtime;type:datetime"` // store time in db, use int as field type
+   UpdatedTime            *int64                 `gorm:"serializer:unixtime;type:datetime"` // store time in db, use int as field type
+   CustomSerializerString string                 `gorm:"serializer:custom"`
+   EncryptedString        EncryptedString
+}
+
+type SerializerPostgresStruct struct {
+   gorm.Model
+   Name                   []byte                 `gorm:"json"`
+   Roles                  Roles                  `gorm:"serializer:json"`
+   Roles2                 *Roles                 `gorm:"serializer:json"`
+   Roles3                 *Roles                 `gorm:"serializer:json;not null"`
+   Contracts              map[string]interface{} `gorm:"serializer:json"`
+   JobInfo                Job                    `gorm:"type:bytes;serializer:gob"`
+   CreatedTime            int64                  `gorm:"serializer:unixtime;type:timestamptz"` // store time in db, use int as field type
+   UpdatedTime            *int64                 `gorm:"serializer:unixtime;type:timestamptz"` // store time in db, use int as field type
+   CustomSerializerString string                 `gorm:"serializer:custom"`
+   EncryptedString        EncryptedString
+}
+
+func (*SerializerPostgresStruct) TableName() string { return "serializer_structs" }
+
+func adaptorSerializerModel(s *SerializerStruct) interface{} {
+   if DB.Dialector.Name() == "postgres" {
+       sps := SerializerPostgresStruct(*s)
+       return &sps
+   }
+   return s
 }

 type Roles []string
@@ -52,9 +79,34 @@ func (es EncryptedString) Value(ctx context.Context, field *schema.Field, dst re
    return "hello" + string(es), nil
}

type CustomSerializer struct {
    prefix []byte
}

func NewCustomSerializer(prefix string) *CustomSerializer {
    return &CustomSerializer{prefix: []byte(prefix)}
}

func (c *CustomSerializer) Scan(ctx context.Context, field *schema.Field, dst reflect.Value, dbValue interface{}) (err error) {
    switch value := dbValue.(type) {
    case []byte:
        err = field.Set(ctx, dst, bytes.TrimPrefix(value, c.prefix))
    case string:
        err = field.Set(ctx, dst, strings.TrimPrefix(value, string(c.prefix)))
    default:
        err = fmt.Errorf("unsupported data %#v", dbValue)
    }
    return err
}

func (c *CustomSerializer) Value(ctx context.Context, field *schema.Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) {
    return fmt.Sprintf("%s%s", c.prefix, fieldValue), nil
}

func TestSerializer(t *testing.T) {
-   DB.Migrator().DropTable(&SerializerStruct{})
-   if err := DB.Migrator().AutoMigrate(&SerializerStruct{}); err != nil {
+   schema.RegisterSerializer("custom", NewCustomSerializer("hello"))
+   DB.Migrator().DropTable(adaptorSerializerModel(&SerializerStruct{}))
+   if err := DB.Migrator().AutoMigrate(adaptorSerializerModel(&SerializerStruct{})); err != nil {
        t.Fatalf("no error should happen when migrate scanner, valuer struct, got error %v", err)
    }
@@ -74,12 +126,42 @@ func TestSerializer(t *testing.T) {
            Location: "Kenmawr",
            IsIntern: false,
        },
        CustomSerializerString: "world",
    }

    if err := DB.Create(&data).Error; err != nil {
        t.Fatalf("failed to create data, got error %v", err)
    }

    var result SerializerStruct
    if err := DB.Where("roles2 IS NULL AND roles3 = ?", "").First(&result, data.ID).Error; err != nil {
        t.Fatalf("failed to query data, got error %v", err)
    }

    AssertEqual(t, result, data)

    if err := DB.Model(&result).Update("roles", "").Error; err != nil {
        t.Fatalf("failed to update data's roles, got error %v", err)
    }

    if err := DB.First(&result, data.ID).Error; err != nil {
        t.Fatalf("failed to query data, got error %v", err)
    }
}

func TestSerializerZeroValue(t *testing.T) {
    schema.RegisterSerializer("custom", NewCustomSerializer("hello"))
    DB.Migrator().DropTable(adaptorSerializerModel(&SerializerStruct{}))
    if err := DB.Migrator().AutoMigrate(adaptorSerializerModel(&SerializerStruct{})); err != nil {
        t.Fatalf("no error should happen when migrate scanner, valuer struct, got error %v", err)
    }

    data := SerializerStruct{}

    if err := DB.Create(&data).Error; err != nil {
        t.Fatalf("failed to create data, got error %v", err)
    }

    var result SerializerStruct
    if err := DB.First(&result, data.ID).Error; err != nil {
        t.Fatalf("failed to query data, got error %v", err)
@@ -87,11 +169,19 @@ func TestSerializer(t *testing.T) {

    AssertEqual(t, result, data)

    if err := DB.Model(&result).Update("roles", "").Error; err != nil {
        t.Fatalf("failed to update data's roles, got error %v", err)
    }

    if err := DB.First(&result, data.ID).Error; err != nil {
        t.Fatalf("failed to query data, got error %v", err)
    }
}

func TestSerializerAssignFirstOrCreate(t *testing.T) {
-   DB.Migrator().DropTable(&SerializerStruct{})
-   if err := DB.Migrator().AutoMigrate(&SerializerStruct{}); err != nil {
+   schema.RegisterSerializer("custom", NewCustomSerializer("hello"))
+   DB.Migrator().DropTable(adaptorSerializerModel(&SerializerStruct{}))
+   if err := DB.Migrator().AutoMigrate(adaptorSerializerModel(&SerializerStruct{})); err != nil {
        t.Fatalf("no error should happen when migrate scanner, valuer struct, got error %v", err)
    }

@@ -109,6 +199,7 @@ func TestSerializerAssignFirstOrCreate(t *testing.T) {
            Location: "Shadyside",
            IsIntern: false,
        },
        CustomSerializerString: "world",
    }

    // first time insert record
@@ -123,7 +214,7 @@ func TestSerializerAssignFirstOrCreate(t *testing.T) {
    }
    AssertEqual(t, result, out)

-   //update record
+   // update record
    data.Roles = append(data.Roles, "r3")
    data.JobInfo.Location = "Gates Hillman Complex"
    if err := DB.Assign(data).FirstOrCreate(&out).Error; err != nil {
@@ -7,6 +7,7 @@ import (
    "regexp"
    "testing"

    "github.com/jinzhu/now"
    "gorm.io/gorm"
    . "gorm.io/gorm/utils/tests"
)
@@ -39,6 +40,11 @@ func TestSoftDelete(t *testing.T) {
        t.Fatalf("invalid sql generated, got %v", sql)
    }

    sql = DB.Session(&gorm.Session{DryRun: true}).Table("user u").Select("name").Find(&User{}).Statement.SQL.String()
    if !regexp.MustCompile(`SELECT .name. FROM user u WHERE .u.\..deleted_at. IS NULL`).MatchString(sql) {
        t.Errorf("Table with escape character, got %v", sql)
    }

    if DB.First(&User{}, "name = ?", user.Name).Error == nil {
        t.Errorf("Can't find a soft deleted record")
    }
@@ -93,3 +99,71 @@ func TestDeletedAtOneOr(t *testing.T) {
        t.Fatalf("invalid sql generated, got %v", actualSQL)
    }
}

func TestSoftDeleteZeroValue(t *testing.T) {
    type SoftDeleteBook struct {
        ID        uint
        Name      string
        Pages     uint
        DeletedAt gorm.DeletedAt `gorm:"zeroValue:'1970-01-01 00:00:01'"`
    }
    DB.Migrator().DropTable(&SoftDeleteBook{})
    if err := DB.AutoMigrate(&SoftDeleteBook{}); err != nil {
        t.Fatalf("failed to auto migrate soft delete table")
    }

    book := SoftDeleteBook{Name: "jinzhu", Pages: 10}
    DB.Save(&book)

    var count int64
    if DB.Model(&SoftDeleteBook{}).Where("name = ?", book.Name).Count(&count).Error != nil || count != 1 {
        t.Errorf("Count soft deleted record, expects: %v, got: %v", 1, count)
    }

    var pages uint
    if DB.Model(&SoftDeleteBook{}).Select("pages").Where("name = ?", book.Name).Scan(&pages).Error != nil || pages != book.Pages {
        t.Errorf("Pages soft deleted record, expects: %v, got: %v", 0, pages)
    }

    if err := DB.Delete(&book).Error; err != nil {
        t.Fatalf("No error should happen when soft delete user, but got %v", err)
    }

    zeroTime, _ := now.Parse("1970-01-01 00:00:01")
    if book.DeletedAt.Time.Equal(zeroTime) {
        t.Errorf("book's deleted at should not be zero, DeletedAt: %v", book.DeletedAt)
    }

    if DB.First(&SoftDeleteBook{}, "name = ?", book.Name).Error == nil {
        t.Errorf("Can't find a soft deleted record")
    }

    count = 0
    if DB.Model(&SoftDeleteBook{}).Where("name = ?", book.Name).Count(&count).Error != nil || count != 0 {
        t.Errorf("Count soft deleted record, expects: %v, got: %v", 0, count)
    }

    pages = 0
    if err := DB.Model(&SoftDeleteBook{}).Select("pages").Where("name = ?", book.Name).Scan(&pages).Error; err != nil || pages != 0 {
        t.Fatalf("Age soft deleted record, expects: %v, got: %v, err %v", 0, pages, err)
    }

    if err := DB.Unscoped().First(&SoftDeleteBook{}, "name = ?", book.Name).Error; err != nil {
        t.Errorf("Should find soft deleted record with Unscoped, but got err %s", err)
    }

    count = 0
    if DB.Unscoped().Model(&SoftDeleteBook{}).Where("name = ?", book.Name).Count(&count).Error != nil || count != 1 {
        t.Errorf("Count soft deleted record, expects: %v, count: %v", 1, count)
    }

    pages = 0
    if DB.Unscoped().Model(&SoftDeleteBook{}).Select("pages").Where("name = ?", book.Name).Scan(&pages).Error != nil || pages != book.Pages {
        t.Errorf("Age soft deleted record, expects: %v, got: %v", 0, pages)
    }

    DB.Unscoped().Delete(&book)
    if err := DB.Unscoped().First(&SoftDeleteBook{}, "name = ?", book.Name).Error; !errors.Is(err, gorm.ErrRecordNotFound) {
        t.Errorf("Can't find permanently deleted record")
    }
}
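The zeroValue tag exercised above only changes what GORM writes into deleted_at; the basic soft-delete flow needs nothing more than a gorm.DeletedAt field. A small sketch under that assumption (the Book model and helper are illustrative, not part of the diff):

    package tests_test

    import "gorm.io/gorm"

    // Embedding gorm.DeletedAt gives Book soft-delete behaviour: Delete sets
    // the timestamp, normal queries skip the row, and Unscoped() is required
    // to read it back or remove it permanently.
    type Book struct {
        ID        uint
        Name      string
        DeletedAt gorm.DeletedAt
    }

    func purgeBook(db *gorm.DB, id uint) error {
        if err := db.Delete(&Book{}, id).Error; err != nil { // soft delete
            return err
        }
        return db.Unscoped().Delete(&Book{}, id).Error // permanent delete
    }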