Merge branch 'go-gorm:master' into master
This commit is contained in:
commit
e0030749b6
12
.github/workflows/invalid_question.yml
vendored
12
.github/workflows/invalid_question.yml
vendored
@ -3,20 +3,26 @@ on:
|
|||||||
schedule:
|
schedule:
|
||||||
- cron: "*/10 * * * *"
|
- cron: "*/10 * * * *"
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
stale:
|
stale:
|
||||||
|
permissions:
|
||||||
|
issues: write # for actions/stale to close stale issues
|
||||||
|
pull-requests: write # for actions/stale to close stale PRs
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
env:
|
env:
|
||||||
ACTIONS_STEP_DEBUG: true
|
ACTIONS_STEP_DEBUG: true
|
||||||
steps:
|
steps:
|
||||||
- name: Close Stale Issues
|
- name: Close Stale Issues
|
||||||
uses: actions/stale@v4
|
uses: actions/stale@v5
|
||||||
with:
|
with:
|
||||||
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
||||||
stale-issue-message: "This issue has been marked as invalid question, please give more information by following the `Question` template, if you believe there is a bug of GORM, please create a pull request that could reproduce the issue on [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground), the issue will be closed in 2 days if no further activity occurs. most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨"
|
stale-issue-message: "This issue has been marked as invalid question, please give more information by following the `Question` template, if you believe there is a bug of GORM, please create a pull request that could reproduce the issue on [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground), the issue will be closed in 30 days if no further activity occurs. most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨"
|
||||||
stale-issue-label: "status:stale"
|
stale-issue-label: "status:stale"
|
||||||
days-before-stale: 0
|
days-before-stale: 0
|
||||||
days-before-close: 2
|
days-before-close: 30
|
||||||
remove-stale-when-updated: true
|
remove-stale-when-updated: true
|
||||||
only-labels: "type:invalid question"
|
only-labels: "type:invalid question"
|
||||||
|
|
||||||
|
2
.github/workflows/labeler.yml
vendored
2
.github/workflows/labeler.yml
vendored
@ -11,7 +11,7 @@ jobs:
|
|||||||
name: Label issues and pull requests
|
name: Label issues and pull requests
|
||||||
steps:
|
steps:
|
||||||
- name: check out
|
- name: check out
|
||||||
uses: actions/checkout@v2
|
uses: actions/checkout@v3
|
||||||
|
|
||||||
- name: labeler
|
- name: labeler
|
||||||
uses: jinzhu/super-labeler-action@develop
|
uses: jinzhu/super-labeler-action@develop
|
||||||
|
12
.github/workflows/missing_playground.yml
vendored
12
.github/workflows/missing_playground.yml
vendored
@ -3,19 +3,25 @@ on:
|
|||||||
schedule:
|
schedule:
|
||||||
- cron: "*/10 * * * *"
|
- cron: "*/10 * * * *"
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
stale:
|
stale:
|
||||||
|
permissions:
|
||||||
|
issues: write # for actions/stale to close stale issues
|
||||||
|
pull-requests: write # for actions/stale to close stale PRs
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
env:
|
env:
|
||||||
ACTIONS_STEP_DEBUG: true
|
ACTIONS_STEP_DEBUG: true
|
||||||
steps:
|
steps:
|
||||||
- name: Close Stale Issues
|
- name: Close Stale Issues
|
||||||
uses: actions/stale@v4
|
uses: actions/stale@v5
|
||||||
with:
|
with:
|
||||||
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
||||||
stale-issue-message: "The issue has been automatically marked as stale as it missing playground pull request link, which is important to help others understand your issue effectively and make sure the issue hasn't been fixed on latest master, checkout [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground) for details. it will be closed in 2 days if no further activity occurs. if you are asking question, please use the `Question` template, most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨"
|
stale-issue-message: "The issue has been automatically marked as stale as it missing playground pull request link, which is important to help others understand your issue effectively and make sure the issue hasn't been fixed on latest master, checkout [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground) for details. it will be closed in 30 days if no further activity occurs. if you are asking question, please use the `Question` template, most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨"
|
||||||
stale-issue-label: "status:stale"
|
stale-issue-label: "status:stale"
|
||||||
days-before-stale: 0
|
days-before-stale: 0
|
||||||
days-before-close: 2
|
days-before-close: 30
|
||||||
remove-stale-when-updated: true
|
remove-stale-when-updated: true
|
||||||
only-labels: "type:missing reproduction steps"
|
only-labels: "type:missing reproduction steps"
|
||||||
|
2
.github/workflows/reviewdog.yml
vendored
2
.github/workflows/reviewdog.yml
vendored
@ -6,7 +6,7 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Check out code into the Go module directory
|
- name: Check out code into the Go module directory
|
||||||
uses: actions/checkout@v2
|
uses: actions/checkout@v3
|
||||||
- name: golangci-lint
|
- name: golangci-lint
|
||||||
uses: reviewdog/action-golangci-lint@v2
|
uses: reviewdog/action-golangci-lint@v2
|
||||||
|
|
||||||
|
14
.github/workflows/stale.yml
vendored
14
.github/workflows/stale.yml
vendored
@ -3,19 +3,25 @@ on:
|
|||||||
schedule:
|
schedule:
|
||||||
- cron: "0 2 * * *"
|
- cron: "0 2 * * *"
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
stale:
|
stale:
|
||||||
|
permissions:
|
||||||
|
issues: write # for actions/stale to close stale issues
|
||||||
|
pull-requests: write # for actions/stale to close stale PRs
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
env:
|
env:
|
||||||
ACTIONS_STEP_DEBUG: true
|
ACTIONS_STEP_DEBUG: true
|
||||||
steps:
|
steps:
|
||||||
- name: Close Stale Issues
|
- name: Close Stale Issues
|
||||||
uses: actions/stale@v4
|
uses: actions/stale@v5
|
||||||
with:
|
with:
|
||||||
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
||||||
stale-issue-message: "This issue has been automatically marked as stale because it has been open 60 days with no activity. Remove stale label or comment or this will be closed in 30 days"
|
stale-issue-message: "This issue has been automatically marked as stale because it has been open 360 days with no activity. Remove stale label or comment or this will be closed in 180 days"
|
||||||
days-before-stale: 60
|
days-before-stale: 360
|
||||||
days-before-close: 30
|
days-before-close: 180
|
||||||
stale-issue-label: "status:stale"
|
stale-issue-label: "status:stale"
|
||||||
exempt-issue-labels: 'type:feature,type:with reproduction steps,type:has pull request'
|
exempt-issue-labels: 'type:feature,type:with reproduction steps,type:has pull request'
|
||||||
stale-pr-label: 'status:stale'
|
stale-pr-label: 'status:stale'
|
||||||
|
43
.github/workflows/tests.yml
vendored
43
.github/workflows/tests.yml
vendored
@ -8,38 +8,41 @@ on:
|
|||||||
branches-ignore:
|
branches-ignore:
|
||||||
- 'gh-pages'
|
- 'gh-pages'
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
# Label of the container job
|
# Label of the container job
|
||||||
sqlite:
|
sqlite:
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
go: ['1.17', '1.16']
|
go: ['1.18', '1.17', '1.16']
|
||||||
platform: [ubuntu-latest] # can not run in windows OS
|
platform: [ubuntu-latest] # can not run in windows OS
|
||||||
runs-on: ${{ matrix.platform }}
|
runs-on: ${{ matrix.platform }}
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Set up Go 1.x
|
- name: Set up Go 1.x
|
||||||
uses: actions/setup-go@v2
|
uses: actions/setup-go@v3
|
||||||
with:
|
with:
|
||||||
go-version: ${{ matrix.go }}
|
go-version: ${{ matrix.go }}
|
||||||
|
|
||||||
- name: Check out code into the Go module directory
|
- name: Check out code into the Go module directory
|
||||||
uses: actions/checkout@v2
|
uses: actions/checkout@v3
|
||||||
|
|
||||||
- name: go mod package cache
|
- name: go mod package cache
|
||||||
uses: actions/cache@v2
|
uses: actions/cache@v3
|
||||||
with:
|
with:
|
||||||
path: ~/go/pkg/mod
|
path: ~/go/pkg/mod
|
||||||
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
|
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
|
||||||
|
|
||||||
- name: Tests
|
- name: Tests
|
||||||
run: GORM_DIALECT=sqlite ./tests/tests_all.sh
|
run: GITHUB_ACTION=true GORM_DIALECT=sqlite ./tests/tests_all.sh
|
||||||
|
|
||||||
mysql:
|
mysql:
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
dbversion: ['mysql:latest', 'mysql:5.7', 'mariadb:latest']
|
dbversion: ['mysql:latest', 'mysql:5.7', 'mariadb:latest']
|
||||||
go: ['1.17', '1.16']
|
go: ['1.18', '1.17', '1.16']
|
||||||
platform: [ubuntu-latest]
|
platform: [ubuntu-latest]
|
||||||
runs-on: ${{ matrix.platform }}
|
runs-on: ${{ matrix.platform }}
|
||||||
|
|
||||||
@ -62,28 +65,28 @@ jobs:
|
|||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Set up Go 1.x
|
- name: Set up Go 1.x
|
||||||
uses: actions/setup-go@v2
|
uses: actions/setup-go@v3
|
||||||
with:
|
with:
|
||||||
go-version: ${{ matrix.go }}
|
go-version: ${{ matrix.go }}
|
||||||
|
|
||||||
- name: Check out code into the Go module directory
|
- name: Check out code into the Go module directory
|
||||||
uses: actions/checkout@v2
|
uses: actions/checkout@v3
|
||||||
|
|
||||||
|
|
||||||
- name: go mod package cache
|
- name: go mod package cache
|
||||||
uses: actions/cache@v2
|
uses: actions/cache@v3
|
||||||
with:
|
with:
|
||||||
path: ~/go/pkg/mod
|
path: ~/go/pkg/mod
|
||||||
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
|
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
|
||||||
|
|
||||||
- name: Tests
|
- name: Tests
|
||||||
run: GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True" ./tests/tests_all.sh
|
run: GITHUB_ACTION=true GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True" ./tests/tests_all.sh
|
||||||
|
|
||||||
postgres:
|
postgres:
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
dbversion: ['postgres:latest', 'postgres:13', 'postgres:12', 'postgres:11', 'postgres:10']
|
dbversion: ['postgres:latest', 'postgres:13', 'postgres:12', 'postgres:11', 'postgres:10']
|
||||||
go: ['1.17', '1.16']
|
go: ['1.18', '1.17', '1.16']
|
||||||
platform: [ubuntu-latest] # can not run in macOS and Windows
|
platform: [ubuntu-latest] # can not run in macOS and Windows
|
||||||
runs-on: ${{ matrix.platform }}
|
runs-on: ${{ matrix.platform }}
|
||||||
|
|
||||||
@ -106,26 +109,26 @@ jobs:
|
|||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Set up Go 1.x
|
- name: Set up Go 1.x
|
||||||
uses: actions/setup-go@v2
|
uses: actions/setup-go@v3
|
||||||
with:
|
with:
|
||||||
go-version: ${{ matrix.go }}
|
go-version: ${{ matrix.go }}
|
||||||
|
|
||||||
- name: Check out code into the Go module directory
|
- name: Check out code into the Go module directory
|
||||||
uses: actions/checkout@v2
|
uses: actions/checkout@v3
|
||||||
|
|
||||||
- name: go mod package cache
|
- name: go mod package cache
|
||||||
uses: actions/cache@v2
|
uses: actions/cache@v3
|
||||||
with:
|
with:
|
||||||
path: ~/go/pkg/mod
|
path: ~/go/pkg/mod
|
||||||
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
|
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
|
||||||
|
|
||||||
- name: Tests
|
- name: Tests
|
||||||
run: GORM_DIALECT=postgres GORM_DSN="user=gorm password=gorm dbname=gorm host=localhost port=9920 sslmode=disable TimeZone=Asia/Shanghai" ./tests/tests_all.sh
|
run: GITHUB_ACTION=true GORM_DIALECT=postgres GORM_DSN="user=gorm password=gorm dbname=gorm host=localhost port=9920 sslmode=disable TimeZone=Asia/Shanghai" ./tests/tests_all.sh
|
||||||
|
|
||||||
sqlserver:
|
sqlserver:
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
go: ['1.17', '1.16']
|
go: ['1.18', '1.17', '1.16']
|
||||||
platform: [ubuntu-latest] # can not run test in macOS and windows
|
platform: [ubuntu-latest] # can not run test in macOS and windows
|
||||||
runs-on: ${{ matrix.platform }}
|
runs-on: ${{ matrix.platform }}
|
||||||
|
|
||||||
@ -149,18 +152,18 @@ jobs:
|
|||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Set up Go 1.x
|
- name: Set up Go 1.x
|
||||||
uses: actions/setup-go@v2
|
uses: actions/setup-go@v3
|
||||||
with:
|
with:
|
||||||
go-version: ${{ matrix.go }}
|
go-version: ${{ matrix.go }}
|
||||||
|
|
||||||
- name: Check out code into the Go module directory
|
- name: Check out code into the Go module directory
|
||||||
uses: actions/checkout@v2
|
uses: actions/checkout@v3
|
||||||
|
|
||||||
- name: go mod package cache
|
- name: go mod package cache
|
||||||
uses: actions/cache@v2
|
uses: actions/cache@v3
|
||||||
with:
|
with:
|
||||||
path: ~/go/pkg/mod
|
path: ~/go/pkg/mod
|
||||||
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
|
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
|
||||||
|
|
||||||
- name: Tests
|
- name: Tests
|
||||||
run: GORM_DIALECT=sqlserver GORM_DSN="sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm" ./tests/tests_all.sh
|
run: GITHUB_ACTION=true GORM_DIALECT=sqlserver GORM_DSN="sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm" ./tests/tests_all.sh
|
||||||
|
1
.gitignore
vendored
1
.gitignore
vendored
@ -3,3 +3,4 @@ documents
|
|||||||
coverage.txt
|
coverage.txt
|
||||||
_book
|
_book
|
||||||
.idea
|
.idea
|
||||||
|
vendor
|
@ -30,6 +30,7 @@ The fantastic ORM library for Golang, aims to be developer friendly.
|
|||||||
## Getting Started
|
## Getting Started
|
||||||
|
|
||||||
* GORM Guides [https://gorm.io](https://gorm.io)
|
* GORM Guides [https://gorm.io](https://gorm.io)
|
||||||
|
* GORM Gen [gorm/gen](https://github.com/go-gorm/gen#gormgen)
|
||||||
|
|
||||||
## Contributing
|
## Contributing
|
||||||
|
|
||||||
|
@ -79,10 +79,10 @@ func (association *Association) Replace(values ...interface{}) error {
|
|||||||
switch reflectValue.Kind() {
|
switch reflectValue.Kind() {
|
||||||
case reflect.Slice, reflect.Array:
|
case reflect.Slice, reflect.Array:
|
||||||
for i := 0; i < reflectValue.Len(); i++ {
|
for i := 0; i < reflectValue.Len(); i++ {
|
||||||
association.Error = rel.Field.Set(reflectValue.Index(i), reflect.Zero(rel.Field.FieldType).Interface())
|
association.Error = rel.Field.Set(association.DB.Statement.Context, reflectValue.Index(i), reflect.Zero(rel.Field.FieldType).Interface())
|
||||||
}
|
}
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
association.Error = rel.Field.Set(reflectValue, reflect.Zero(rel.Field.FieldType).Interface())
|
association.Error = rel.Field.Set(association.DB.Statement.Context, reflectValue, reflect.Zero(rel.Field.FieldType).Interface())
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, ref := range rel.References {
|
for _, ref := range rel.References {
|
||||||
@ -96,12 +96,12 @@ func (association *Association) Replace(values ...interface{}) error {
|
|||||||
primaryFields []*schema.Field
|
primaryFields []*schema.Field
|
||||||
foreignKeys []string
|
foreignKeys []string
|
||||||
updateMap = map[string]interface{}{}
|
updateMap = map[string]interface{}{}
|
||||||
relValues = schema.GetRelationsValues(reflectValue, []*schema.Relationship{rel})
|
relValues = schema.GetRelationsValues(association.DB.Statement.Context, reflectValue, []*schema.Relationship{rel})
|
||||||
modelValue = reflect.New(rel.FieldSchema.ModelType).Interface()
|
modelValue = reflect.New(rel.FieldSchema.ModelType).Interface()
|
||||||
tx = association.DB.Model(modelValue)
|
tx = association.DB.Model(modelValue)
|
||||||
)
|
)
|
||||||
|
|
||||||
if _, rvs := schema.GetIdentityFieldValuesMap(relValues, rel.FieldSchema.PrimaryFields); len(rvs) > 0 {
|
if _, rvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, relValues, rel.FieldSchema.PrimaryFields); len(rvs) > 0 {
|
||||||
if column, values := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs); len(values) > 0 {
|
if column, values := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs); len(values) > 0 {
|
||||||
tx.Not(clause.IN{Column: column, Values: values})
|
tx.Not(clause.IN{Column: column, Values: values})
|
||||||
}
|
}
|
||||||
@ -117,7 +117,7 @@ func (association *Association) Replace(values ...interface{}) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields); len(pvs) > 0 {
|
if _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields); len(pvs) > 0 {
|
||||||
column, values := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs)
|
column, values := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs)
|
||||||
association.Error = tx.Where(clause.IN{Column: column, Values: values}).UpdateColumns(updateMap).Error
|
association.Error = tx.Where(clause.IN{Column: column, Values: values}).UpdateColumns(updateMap).Error
|
||||||
}
|
}
|
||||||
@ -143,14 +143,14 @@ func (association *Association) Replace(values ...interface{}) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
_, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields)
|
_, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields)
|
||||||
if column, values := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs); len(values) > 0 {
|
if column, values := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs); len(values) > 0 {
|
||||||
tx.Where(clause.IN{Column: column, Values: values})
|
tx.Where(clause.IN{Column: column, Values: values})
|
||||||
} else {
|
} else {
|
||||||
return ErrPrimaryKeyRequired
|
return ErrPrimaryKeyRequired
|
||||||
}
|
}
|
||||||
|
|
||||||
_, rvs := schema.GetIdentityFieldValuesMapFromValues(values, relPrimaryFields)
|
_, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, relPrimaryFields)
|
||||||
if relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs); len(relValues) > 0 {
|
if relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs); len(relValues) > 0 {
|
||||||
tx.Where(clause.Not(clause.IN{Column: relColumn, Values: relValues}))
|
tx.Where(clause.Not(clause.IN{Column: relColumn, Values: relValues}))
|
||||||
}
|
}
|
||||||
@ -186,11 +186,14 @@ func (association *Association) Delete(values ...interface{}) error {
|
|||||||
case schema.BelongsTo:
|
case schema.BelongsTo:
|
||||||
tx := association.DB.Model(reflect.New(rel.Schema.ModelType).Interface())
|
tx := association.DB.Model(reflect.New(rel.Schema.ModelType).Interface())
|
||||||
|
|
||||||
_, pvs := schema.GetIdentityFieldValuesMap(reflectValue, rel.Schema.PrimaryFields)
|
_, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, rel.Schema.PrimaryFields)
|
||||||
pcolumn, pvalues := schema.ToQueryValues(rel.Schema.Table, rel.Schema.PrimaryFieldDBNames, pvs)
|
if pcolumn, pvalues := schema.ToQueryValues(rel.Schema.Table, rel.Schema.PrimaryFieldDBNames, pvs); len(pvalues) > 0 {
|
||||||
conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues})
|
conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues})
|
||||||
|
} else {
|
||||||
|
return ErrPrimaryKeyRequired
|
||||||
|
}
|
||||||
|
|
||||||
_, rvs := schema.GetIdentityFieldValuesMapFromValues(values, primaryFields)
|
_, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, primaryFields)
|
||||||
relColumn, relValues := schema.ToQueryValues(rel.Schema.Table, foreignKeys, rvs)
|
relColumn, relValues := schema.ToQueryValues(rel.Schema.Table, foreignKeys, rvs)
|
||||||
conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
|
conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
|
||||||
|
|
||||||
@ -198,11 +201,14 @@ func (association *Association) Delete(values ...interface{}) error {
|
|||||||
case schema.HasOne, schema.HasMany:
|
case schema.HasOne, schema.HasMany:
|
||||||
tx := association.DB.Model(reflect.New(rel.FieldSchema.ModelType).Interface())
|
tx := association.DB.Model(reflect.New(rel.FieldSchema.ModelType).Interface())
|
||||||
|
|
||||||
_, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields)
|
_, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields)
|
||||||
pcolumn, pvalues := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs)
|
if pcolumn, pvalues := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs); len(pvalues) > 0 {
|
||||||
conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues})
|
conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues})
|
||||||
|
} else {
|
||||||
|
return ErrPrimaryKeyRequired
|
||||||
|
}
|
||||||
|
|
||||||
_, rvs := schema.GetIdentityFieldValuesMapFromValues(values, rel.FieldSchema.PrimaryFields)
|
_, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, rel.FieldSchema.PrimaryFields)
|
||||||
relColumn, relValues := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs)
|
relColumn, relValues := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs)
|
||||||
conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
|
conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
|
||||||
|
|
||||||
@ -228,11 +234,14 @@ func (association *Association) Delete(values ...interface{}) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
_, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields)
|
_, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields)
|
||||||
pcolumn, pvalues := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs)
|
if pcolumn, pvalues := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs); len(pvalues) > 0 {
|
||||||
conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues})
|
conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues})
|
||||||
|
} else {
|
||||||
|
return ErrPrimaryKeyRequired
|
||||||
|
}
|
||||||
|
|
||||||
_, rvs := schema.GetIdentityFieldValuesMapFromValues(values, relPrimaryFields)
|
_, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, relPrimaryFields)
|
||||||
relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs)
|
relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs)
|
||||||
conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
|
conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
|
||||||
|
|
||||||
@ -241,11 +250,11 @@ func (association *Association) Delete(values ...interface{}) error {
|
|||||||
|
|
||||||
if association.Error == nil {
|
if association.Error == nil {
|
||||||
// clean up deleted values's foreign key
|
// clean up deleted values's foreign key
|
||||||
relValuesMap, _ := schema.GetIdentityFieldValuesMapFromValues(values, rel.FieldSchema.PrimaryFields)
|
relValuesMap, _ := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, rel.FieldSchema.PrimaryFields)
|
||||||
|
|
||||||
cleanUpDeletedRelations := func(data reflect.Value) {
|
cleanUpDeletedRelations := func(data reflect.Value) {
|
||||||
if _, zero := rel.Field.ValueOf(data); !zero {
|
if _, zero := rel.Field.ValueOf(association.DB.Statement.Context, data); !zero {
|
||||||
fieldValue := reflect.Indirect(rel.Field.ReflectValueOf(data))
|
fieldValue := reflect.Indirect(rel.Field.ReflectValueOf(association.DB.Statement.Context, data))
|
||||||
primaryValues := make([]interface{}, len(rel.FieldSchema.PrimaryFields))
|
primaryValues := make([]interface{}, len(rel.FieldSchema.PrimaryFields))
|
||||||
|
|
||||||
switch fieldValue.Kind() {
|
switch fieldValue.Kind() {
|
||||||
@ -253,7 +262,7 @@ func (association *Association) Delete(values ...interface{}) error {
|
|||||||
validFieldValues := reflect.Zero(rel.Field.IndirectFieldType)
|
validFieldValues := reflect.Zero(rel.Field.IndirectFieldType)
|
||||||
for i := 0; i < fieldValue.Len(); i++ {
|
for i := 0; i < fieldValue.Len(); i++ {
|
||||||
for idx, field := range rel.FieldSchema.PrimaryFields {
|
for idx, field := range rel.FieldSchema.PrimaryFields {
|
||||||
primaryValues[idx], _ = field.ValueOf(fieldValue.Index(i))
|
primaryValues[idx], _ = field.ValueOf(association.DB.Statement.Context, fieldValue.Index(i))
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, ok := relValuesMap[utils.ToStringKey(primaryValues...)]; !ok {
|
if _, ok := relValuesMap[utils.ToStringKey(primaryValues...)]; !ok {
|
||||||
@ -261,23 +270,23 @@ func (association *Association) Delete(values ...interface{}) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
association.Error = rel.Field.Set(data, validFieldValues.Interface())
|
association.Error = rel.Field.Set(association.DB.Statement.Context, data, validFieldValues.Interface())
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
for idx, field := range rel.FieldSchema.PrimaryFields {
|
for idx, field := range rel.FieldSchema.PrimaryFields {
|
||||||
primaryValues[idx], _ = field.ValueOf(fieldValue)
|
primaryValues[idx], _ = field.ValueOf(association.DB.Statement.Context, fieldValue)
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, ok := relValuesMap[utils.ToStringKey(primaryValues...)]; ok {
|
if _, ok := relValuesMap[utils.ToStringKey(primaryValues...)]; ok {
|
||||||
if association.Error = rel.Field.Set(data, reflect.Zero(rel.FieldSchema.ModelType).Interface()); association.Error != nil {
|
if association.Error = rel.Field.Set(association.DB.Statement.Context, data, reflect.Zero(rel.FieldSchema.ModelType).Interface()); association.Error != nil {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
if rel.JoinTable == nil {
|
if rel.JoinTable == nil {
|
||||||
for _, ref := range rel.References {
|
for _, ref := range rel.References {
|
||||||
if ref.OwnPrimaryKey || ref.PrimaryValue != "" {
|
if ref.OwnPrimaryKey || ref.PrimaryValue != "" {
|
||||||
association.Error = ref.ForeignKey.Set(fieldValue, reflect.Zero(ref.ForeignKey.FieldType).Interface())
|
association.Error = ref.ForeignKey.Set(association.DB.Statement.Context, fieldValue, reflect.Zero(ref.ForeignKey.FieldType).Interface())
|
||||||
} else {
|
} else {
|
||||||
association.Error = ref.ForeignKey.Set(data, reflect.Zero(ref.ForeignKey.FieldType).Interface())
|
association.Error = ref.ForeignKey.Set(association.DB.Statement.Context, data, reflect.Zero(ref.ForeignKey.FieldType).Interface())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -329,14 +338,14 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
|
|||||||
switch rv.Kind() {
|
switch rv.Kind() {
|
||||||
case reflect.Slice, reflect.Array:
|
case reflect.Slice, reflect.Array:
|
||||||
if rv.Len() > 0 {
|
if rv.Len() > 0 {
|
||||||
association.Error = association.Relationship.Field.Set(source, rv.Index(0).Addr().Interface())
|
association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, source, rv.Index(0).Addr().Interface())
|
||||||
|
|
||||||
if association.Relationship.Field.FieldType.Kind() == reflect.Struct {
|
if association.Relationship.Field.FieldType.Kind() == reflect.Struct {
|
||||||
assignBacks = append(assignBacks, assignBack{Source: source, Dest: rv.Index(0)})
|
assignBacks = append(assignBacks, assignBack{Source: source, Dest: rv.Index(0)})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
association.Error = association.Relationship.Field.Set(source, rv.Addr().Interface())
|
association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, source, rv.Addr().Interface())
|
||||||
|
|
||||||
if association.Relationship.Field.FieldType.Kind() == reflect.Struct {
|
if association.Relationship.Field.FieldType.Kind() == reflect.Struct {
|
||||||
assignBacks = append(assignBacks, assignBack{Source: source, Dest: rv})
|
assignBacks = append(assignBacks, assignBack{Source: source, Dest: rv})
|
||||||
@ -344,7 +353,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
|
|||||||
}
|
}
|
||||||
case schema.HasMany, schema.Many2Many:
|
case schema.HasMany, schema.Many2Many:
|
||||||
elemType := association.Relationship.Field.IndirectFieldType.Elem()
|
elemType := association.Relationship.Field.IndirectFieldType.Elem()
|
||||||
fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(source))
|
fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(association.DB.Statement.Context, source))
|
||||||
if clear {
|
if clear {
|
||||||
fieldValue = reflect.New(association.Relationship.Field.IndirectFieldType).Elem()
|
fieldValue = reflect.New(association.Relationship.Field.IndirectFieldType).Elem()
|
||||||
}
|
}
|
||||||
@ -373,7 +382,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
|
|||||||
}
|
}
|
||||||
|
|
||||||
if association.Error == nil {
|
if association.Error == nil {
|
||||||
association.Error = association.Relationship.Field.Set(source, fieldValue.Interface())
|
association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, source, fieldValue.Interface())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -421,7 +430,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
|
|||||||
// clear old data
|
// clear old data
|
||||||
if clear && len(values) == 0 {
|
if clear && len(values) == 0 {
|
||||||
for i := 0; i < reflectValue.Len(); i++ {
|
for i := 0; i < reflectValue.Len(); i++ {
|
||||||
if err := association.Relationship.Field.Set(reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType).Interface()); err != nil {
|
if err := association.Relationship.Field.Set(association.DB.Statement.Context, reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType).Interface()); err != nil {
|
||||||
association.Error = err
|
association.Error = err
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@ -429,7 +438,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
|
|||||||
if association.Relationship.JoinTable == nil {
|
if association.Relationship.JoinTable == nil {
|
||||||
for _, ref := range association.Relationship.References {
|
for _, ref := range association.Relationship.References {
|
||||||
if !ref.OwnPrimaryKey && ref.PrimaryValue == "" {
|
if !ref.OwnPrimaryKey && ref.PrimaryValue == "" {
|
||||||
if err := ref.ForeignKey.Set(reflectValue.Index(i), reflect.Zero(ref.ForeignKey.FieldType).Interface()); err != nil {
|
if err := ref.ForeignKey.Set(association.DB.Statement.Context, reflectValue.Index(i), reflect.Zero(ref.ForeignKey.FieldType).Interface()); err != nil {
|
||||||
association.Error = err
|
association.Error = err
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@ -453,12 +462,12 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
|
|||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
// clear old data
|
// clear old data
|
||||||
if clear && len(values) == 0 {
|
if clear && len(values) == 0 {
|
||||||
association.Error = association.Relationship.Field.Set(reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType).Interface())
|
association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType).Interface())
|
||||||
|
|
||||||
if association.Relationship.JoinTable == nil && association.Error == nil {
|
if association.Relationship.JoinTable == nil && association.Error == nil {
|
||||||
for _, ref := range association.Relationship.References {
|
for _, ref := range association.Relationship.References {
|
||||||
if !ref.OwnPrimaryKey && ref.PrimaryValue == "" {
|
if !ref.OwnPrimaryKey && ref.PrimaryValue == "" {
|
||||||
association.Error = ref.ForeignKey.Set(reflectValue, reflect.Zero(ref.ForeignKey.FieldType).Interface())
|
association.Error = ref.ForeignKey.Set(association.DB.Statement.Context, reflectValue, reflect.Zero(ref.ForeignKey.FieldType).Interface())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -475,7 +484,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, assignBack := range assignBacks {
|
for _, assignBack := range assignBacks {
|
||||||
fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(assignBack.Source))
|
fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(association.DB.Statement.Context, assignBack.Source))
|
||||||
if assignBack.Index > 0 {
|
if assignBack.Index > 0 {
|
||||||
reflect.Indirect(assignBack.Dest).Set(fieldValue.Index(assignBack.Index - 1))
|
reflect.Indirect(assignBack.Dest).Set(fieldValue.Index(assignBack.Index - 1))
|
||||||
} else {
|
} else {
|
||||||
@ -486,14 +495,14 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
|
|||||||
|
|
||||||
func (association *Association) buildCondition() *DB {
|
func (association *Association) buildCondition() *DB {
|
||||||
var (
|
var (
|
||||||
queryConds = association.Relationship.ToQueryConditions(association.DB.Statement.ReflectValue)
|
queryConds = association.Relationship.ToQueryConditions(association.DB.Statement.Context, association.DB.Statement.ReflectValue)
|
||||||
modelValue = reflect.New(association.Relationship.FieldSchema.ModelType).Interface()
|
modelValue = reflect.New(association.Relationship.FieldSchema.ModelType).Interface()
|
||||||
tx = association.DB.Model(modelValue)
|
tx = association.DB.Model(modelValue)
|
||||||
)
|
)
|
||||||
|
|
||||||
if association.Relationship.JoinTable != nil {
|
if association.Relationship.JoinTable != nil {
|
||||||
if !tx.Statement.Unscoped && len(association.Relationship.JoinTable.QueryClauses) > 0 {
|
if !tx.Statement.Unscoped && len(association.Relationship.JoinTable.QueryClauses) > 0 {
|
||||||
joinStmt := Statement{DB: tx, Schema: association.Relationship.JoinTable, Table: association.Relationship.JoinTable.Table, Clauses: map[string]clause.Clause{}}
|
joinStmt := Statement{DB: tx, Context: tx.Statement.Context, Schema: association.Relationship.JoinTable, Table: association.Relationship.JoinTable.Table, Clauses: map[string]clause.Clause{}}
|
||||||
for _, queryClause := range association.Relationship.JoinTable.QueryClauses {
|
for _, queryClause := range association.Relationship.JoinTable.QueryClauses {
|
||||||
joinStmt.AddClause(queryClause)
|
joinStmt.AddClause(queryClause)
|
||||||
}
|
}
|
||||||
|
@ -246,7 +246,13 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {
|
|||||||
sortCallback func(*callback) error
|
sortCallback func(*callback) error
|
||||||
)
|
)
|
||||||
sort.Slice(cs, func(i, j int) bool {
|
sort.Slice(cs, func(i, j int) bool {
|
||||||
return cs[j].before == "*" || cs[j].after == "*"
|
if cs[j].before == "*" && cs[i].before != "*" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if cs[j].after == "*" && cs[i].after != "*" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
})
|
})
|
||||||
|
|
||||||
for _, c := range cs {
|
for _, c := range cs {
|
||||||
|
@ -24,8 +24,8 @@ func SaveBeforeAssociations(create bool) func(db *gorm.DB) {
|
|||||||
setupReferences := func(obj reflect.Value, elem reflect.Value) {
|
setupReferences := func(obj reflect.Value, elem reflect.Value) {
|
||||||
for _, ref := range rel.References {
|
for _, ref := range rel.References {
|
||||||
if !ref.OwnPrimaryKey {
|
if !ref.OwnPrimaryKey {
|
||||||
pv, _ := ref.PrimaryKey.ValueOf(elem)
|
pv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, elem)
|
||||||
db.AddError(ref.ForeignKey.Set(obj, pv))
|
db.AddError(ref.ForeignKey.Set(db.Statement.Context, obj, pv))
|
||||||
|
|
||||||
if dest, ok := db.Statement.Dest.(map[string]interface{}); ok {
|
if dest, ok := db.Statement.Dest.(map[string]interface{}); ok {
|
||||||
dest[ref.ForeignKey.DBName] = pv
|
dest[ref.ForeignKey.DBName] = pv
|
||||||
@ -57,8 +57,8 @@ func SaveBeforeAssociations(create bool) func(db *gorm.DB) {
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, zero := rel.Field.ValueOf(obj); !zero { // check belongs to relation value
|
if _, zero := rel.Field.ValueOf(db.Statement.Context, obj); !zero { // check belongs to relation value
|
||||||
rv := rel.Field.ReflectValueOf(obj) // relation reflect value
|
rv := rel.Field.ReflectValueOf(db.Statement.Context, obj) // relation reflect value
|
||||||
objs = append(objs, obj)
|
objs = append(objs, obj)
|
||||||
if isPtr {
|
if isPtr {
|
||||||
elems = reflect.Append(elems, rv)
|
elems = reflect.Append(elems, rv)
|
||||||
@ -69,20 +69,20 @@ func SaveBeforeAssociations(create bool) func(db *gorm.DB) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if elems.Len() > 0 {
|
if elems.Len() > 0 {
|
||||||
if saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, nil) == nil {
|
if saveAssociations(db, rel, elems, selectColumns, restricted, nil) == nil {
|
||||||
for i := 0; i < elems.Len(); i++ {
|
for i := 0; i < elems.Len(); i++ {
|
||||||
setupReferences(objs[i], elems.Index(i))
|
setupReferences(objs[i], elems.Index(i))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero {
|
if _, zero := rel.Field.ValueOf(db.Statement.Context, db.Statement.ReflectValue); !zero {
|
||||||
rv := rel.Field.ReflectValueOf(db.Statement.ReflectValue) // relation reflect value
|
rv := rel.Field.ReflectValueOf(db.Statement.Context, db.Statement.ReflectValue) // relation reflect value
|
||||||
if rv.Kind() != reflect.Ptr {
|
if rv.Kind() != reflect.Ptr {
|
||||||
rv = rv.Addr()
|
rv = rv.Addr()
|
||||||
}
|
}
|
||||||
|
|
||||||
if saveAssociations(db, rel, rv.Interface(), selectColumns, restricted, nil) == nil {
|
if saveAssociations(db, rel, rv, selectColumns, restricted, nil) == nil {
|
||||||
setupReferences(db.Statement.ReflectValue, rv)
|
setupReferences(db.Statement.ReflectValue, rv)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -120,18 +120,18 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
|
|||||||
obj := db.Statement.ReflectValue.Index(i)
|
obj := db.Statement.ReflectValue.Index(i)
|
||||||
|
|
||||||
if reflect.Indirect(obj).Kind() == reflect.Struct {
|
if reflect.Indirect(obj).Kind() == reflect.Struct {
|
||||||
if _, zero := rel.Field.ValueOf(obj); !zero {
|
if _, zero := rel.Field.ValueOf(db.Statement.Context, obj); !zero {
|
||||||
rv := rel.Field.ReflectValueOf(obj)
|
rv := rel.Field.ReflectValueOf(db.Statement.Context, obj)
|
||||||
if rv.Kind() != reflect.Ptr {
|
if rv.Kind() != reflect.Ptr {
|
||||||
rv = rv.Addr()
|
rv = rv.Addr()
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, ref := range rel.References {
|
for _, ref := range rel.References {
|
||||||
if ref.OwnPrimaryKey {
|
if ref.OwnPrimaryKey {
|
||||||
fv, _ := ref.PrimaryKey.ValueOf(obj)
|
fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, obj)
|
||||||
db.AddError(ref.ForeignKey.Set(rv, fv))
|
db.AddError(ref.ForeignKey.Set(db.Statement.Context, rv, fv))
|
||||||
} else if ref.PrimaryValue != "" {
|
} else if ref.PrimaryValue != "" {
|
||||||
db.AddError(ref.ForeignKey.Set(rv, ref.PrimaryValue))
|
db.AddError(ref.ForeignKey.Set(db.Statement.Context, rv, ref.PrimaryValue))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -146,11 +146,11 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
|
|||||||
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
|
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
|
||||||
}
|
}
|
||||||
|
|
||||||
saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, assignmentColumns)
|
saveAssociations(db, rel, elems, selectColumns, restricted, assignmentColumns)
|
||||||
}
|
}
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero {
|
if _, zero := rel.Field.ValueOf(db.Statement.Context, db.Statement.ReflectValue); !zero {
|
||||||
f := rel.Field.ReflectValueOf(db.Statement.ReflectValue)
|
f := rel.Field.ReflectValueOf(db.Statement.Context, db.Statement.ReflectValue)
|
||||||
if f.Kind() != reflect.Ptr {
|
if f.Kind() != reflect.Ptr {
|
||||||
f = f.Addr()
|
f = f.Addr()
|
||||||
}
|
}
|
||||||
@ -158,15 +158,15 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
|
|||||||
assignmentColumns := make([]string, 0, len(rel.References))
|
assignmentColumns := make([]string, 0, len(rel.References))
|
||||||
for _, ref := range rel.References {
|
for _, ref := range rel.References {
|
||||||
if ref.OwnPrimaryKey {
|
if ref.OwnPrimaryKey {
|
||||||
fv, _ := ref.PrimaryKey.ValueOf(db.Statement.ReflectValue)
|
fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, db.Statement.ReflectValue)
|
||||||
ref.ForeignKey.Set(f, fv)
|
db.AddError(ref.ForeignKey.Set(db.Statement.Context, f, fv))
|
||||||
} else if ref.PrimaryValue != "" {
|
} else if ref.PrimaryValue != "" {
|
||||||
ref.ForeignKey.Set(f, ref.PrimaryValue)
|
db.AddError(ref.ForeignKey.Set(db.Statement.Context, f, ref.PrimaryValue))
|
||||||
}
|
}
|
||||||
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
|
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
|
||||||
}
|
}
|
||||||
|
|
||||||
saveAssociations(db, rel, f.Interface(), selectColumns, restricted, assignmentColumns)
|
saveAssociations(db, rel, f, selectColumns, restricted, assignmentColumns)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -185,23 +185,23 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
|
|||||||
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
|
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
|
||||||
identityMap := map[string]bool{}
|
identityMap := map[string]bool{}
|
||||||
appendToElems := func(v reflect.Value) {
|
appendToElems := func(v reflect.Value) {
|
||||||
if _, zero := rel.Field.ValueOf(v); !zero {
|
if _, zero := rel.Field.ValueOf(db.Statement.Context, v); !zero {
|
||||||
f := reflect.Indirect(rel.Field.ReflectValueOf(v))
|
f := reflect.Indirect(rel.Field.ReflectValueOf(db.Statement.Context, v))
|
||||||
|
|
||||||
for i := 0; i < f.Len(); i++ {
|
for i := 0; i < f.Len(); i++ {
|
||||||
elem := f.Index(i)
|
elem := f.Index(i)
|
||||||
for _, ref := range rel.References {
|
for _, ref := range rel.References {
|
||||||
if ref.OwnPrimaryKey {
|
if ref.OwnPrimaryKey {
|
||||||
pv, _ := ref.PrimaryKey.ValueOf(v)
|
pv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, v)
|
||||||
ref.ForeignKey.Set(elem, pv)
|
db.AddError(ref.ForeignKey.Set(db.Statement.Context, elem, pv))
|
||||||
} else if ref.PrimaryValue != "" {
|
} else if ref.PrimaryValue != "" {
|
||||||
ref.ForeignKey.Set(elem, ref.PrimaryValue)
|
db.AddError(ref.ForeignKey.Set(db.Statement.Context, elem, ref.PrimaryValue))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
relPrimaryValues := make([]interface{}, 0, len(rel.FieldSchema.PrimaryFields))
|
relPrimaryValues := make([]interface{}, 0, len(rel.FieldSchema.PrimaryFields))
|
||||||
for _, pf := range rel.FieldSchema.PrimaryFields {
|
for _, pf := range rel.FieldSchema.PrimaryFields {
|
||||||
if pfv, ok := pf.ValueOf(elem); !ok {
|
if pfv, ok := pf.ValueOf(db.Statement.Context, elem); !ok {
|
||||||
relPrimaryValues = append(relPrimaryValues, pfv)
|
relPrimaryValues = append(relPrimaryValues, pfv)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -237,7 +237,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
|
|||||||
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
|
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
|
||||||
}
|
}
|
||||||
|
|
||||||
saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, assignmentColumns)
|
saveAssociations(db, rel, elems, selectColumns, restricted, assignmentColumns)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -260,21 +260,21 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
|
|||||||
joinValue := reflect.New(rel.JoinTable.ModelType)
|
joinValue := reflect.New(rel.JoinTable.ModelType)
|
||||||
for _, ref := range rel.References {
|
for _, ref := range rel.References {
|
||||||
if ref.OwnPrimaryKey {
|
if ref.OwnPrimaryKey {
|
||||||
fv, _ := ref.PrimaryKey.ValueOf(obj)
|
fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, obj)
|
||||||
ref.ForeignKey.Set(joinValue, fv)
|
db.AddError(ref.ForeignKey.Set(db.Statement.Context, joinValue, fv))
|
||||||
} else if ref.PrimaryValue != "" {
|
} else if ref.PrimaryValue != "" {
|
||||||
ref.ForeignKey.Set(joinValue, ref.PrimaryValue)
|
db.AddError(ref.ForeignKey.Set(db.Statement.Context, joinValue, ref.PrimaryValue))
|
||||||
} else {
|
} else {
|
||||||
fv, _ := ref.PrimaryKey.ValueOf(elem)
|
fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, elem)
|
||||||
ref.ForeignKey.Set(joinValue, fv)
|
db.AddError(ref.ForeignKey.Set(db.Statement.Context, joinValue, fv))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
joins = reflect.Append(joins, joinValue)
|
joins = reflect.Append(joins, joinValue)
|
||||||
}
|
}
|
||||||
|
|
||||||
appendToElems := func(v reflect.Value) {
|
appendToElems := func(v reflect.Value) {
|
||||||
if _, zero := rel.Field.ValueOf(v); !zero {
|
if _, zero := rel.Field.ValueOf(db.Statement.Context, v); !zero {
|
||||||
f := reflect.Indirect(rel.Field.ReflectValueOf(v))
|
f := reflect.Indirect(rel.Field.ReflectValueOf(db.Statement.Context, v))
|
||||||
|
|
||||||
for i := 0; i < f.Len(); i++ {
|
for i := 0; i < f.Len(); i++ {
|
||||||
elem := f.Index(i)
|
elem := f.Index(i)
|
||||||
@ -304,7 +304,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
|
|||||||
// optimize elems of reflect value length
|
// optimize elems of reflect value length
|
||||||
if elemLen := elems.Len(); elemLen > 0 {
|
if elemLen := elems.Len(); elemLen > 0 {
|
||||||
if v, ok := selectColumns[rel.Name+".*"]; !ok || v {
|
if v, ok := selectColumns[rel.Name+".*"]; !ok || v {
|
||||||
saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, nil)
|
saveAssociations(db, rel, elems, selectColumns, restricted, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; i < elemLen; i++ {
|
for i := 0; i < elemLen; i++ {
|
||||||
@ -323,7 +323,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func onConflictOption(stmt *gorm.Statement, s *schema.Schema, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) (onConflict clause.OnConflict) {
|
func onConflictOption(stmt *gorm.Statement, s *schema.Schema, defaultUpdatingColumns []string) (onConflict clause.OnConflict) {
|
||||||
if len(defaultUpdatingColumns) > 0 || stmt.DB.FullSaveAssociations {
|
if len(defaultUpdatingColumns) > 0 || stmt.DB.FullSaveAssociations {
|
||||||
onConflict.Columns = make([]clause.Column, 0, len(s.PrimaryFieldDBNames))
|
onConflict.Columns = make([]clause.Column, 0, len(s.PrimaryFieldDBNames))
|
||||||
for _, dbName := range s.PrimaryFieldDBNames {
|
for _, dbName := range s.PrimaryFieldDBNames {
|
||||||
@ -341,11 +341,17 @@ func onConflictOption(stmt *gorm.Statement, s *schema.Schema, selectColumns map[
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func saveAssociations(db *gorm.DB, rel *schema.Relationship, values interface{}, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) error {
|
func saveAssociations(db *gorm.DB, rel *schema.Relationship, rValues reflect.Value, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) error {
|
||||||
|
// stop save association loop
|
||||||
|
if checkAssociationsSaved(db, rValues) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
selects, omits []string
|
selects, omits []string
|
||||||
onConflict = onConflictOption(db.Statement, rel.FieldSchema, selectColumns, restricted, defaultUpdatingColumns)
|
onConflict = onConflictOption(db.Statement, rel.FieldSchema, defaultUpdatingColumns)
|
||||||
refName = rel.Name + "."
|
refName = rel.Name + "."
|
||||||
|
values = rValues.Interface()
|
||||||
)
|
)
|
||||||
|
|
||||||
for name, ok := range selectColumns {
|
for name, ok := range selectColumns {
|
||||||
@ -390,3 +396,24 @@ func saveAssociations(db *gorm.DB, rel *schema.Relationship, values interface{},
|
|||||||
|
|
||||||
return db.AddError(tx.Create(values).Error)
|
return db.AddError(tx.Create(values).Error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// check association values has been saved
|
||||||
|
// if values kind is Struct, check it has been saved
|
||||||
|
// if values kind is Slice/Array, check all items have been saved
|
||||||
|
var visitMapStoreKey = "gorm:saved_association_map"
|
||||||
|
|
||||||
|
func checkAssociationsSaved(db *gorm.DB, values reflect.Value) bool {
|
||||||
|
if visit, ok := db.Get(visitMapStoreKey); ok {
|
||||||
|
if v, ok := visit.(*visitMap); ok {
|
||||||
|
if loadOrStoreVisitMap(v, values) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
vistMap := make(visitMap)
|
||||||
|
loadOrStoreVisitMap(&vistMap, values)
|
||||||
|
db.Set(visitMapStoreKey, &vistMap)
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
@ -10,6 +10,7 @@ import (
|
|||||||
"gorm.io/gorm/utils"
|
"gorm.io/gorm/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// BeforeCreate before create hooks
|
||||||
func BeforeCreate(db *gorm.DB) {
|
func BeforeCreate(db *gorm.DB) {
|
||||||
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) {
|
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) {
|
||||||
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
|
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
|
||||||
@ -31,6 +32,7 @@ func BeforeCreate(db *gorm.DB) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Create create hook
|
||||||
func Create(config *Config) func(db *gorm.DB) {
|
func Create(config *Config) func(db *gorm.DB) {
|
||||||
supportReturning := utils.Contains(config.CreateClauses, "RETURNING")
|
supportReturning := utils.Contains(config.CreateClauses, "RETURNING")
|
||||||
|
|
||||||
@ -82,8 +84,10 @@ func Create(config *Config) func(db *gorm.DB) {
|
|||||||
db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...,
|
db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...,
|
||||||
)
|
)
|
||||||
if db.AddError(err) == nil {
|
if db.AddError(err) == nil {
|
||||||
gorm.Scan(rows, db, mode)
|
defer func() {
|
||||||
db.AddError(rows.Close())
|
db.AddError(rows.Close())
|
||||||
|
}()
|
||||||
|
gorm.Scan(rows, db, mode)
|
||||||
}
|
}
|
||||||
|
|
||||||
return
|
return
|
||||||
@ -117,9 +121,9 @@ func Create(config *Config) func(db *gorm.DB) {
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv)
|
_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv)
|
||||||
if isZero {
|
if isZero {
|
||||||
db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID)
|
db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, insertID))
|
||||||
insertID -= db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement
|
insertID -= db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -130,38 +134,39 @@ func Create(config *Config) func(db *gorm.DB) {
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv); isZero {
|
if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv); isZero {
|
||||||
db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID)
|
db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, insertID))
|
||||||
insertID += db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement
|
insertID += db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue)
|
_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, db.Statement.ReflectValue)
|
||||||
if isZero {
|
if isZero {
|
||||||
db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID)
|
db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, db.Statement.ReflectValue, insertID))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AfterCreate after create hooks
|
||||||
func AfterCreate(db *gorm.DB) {
|
func AfterCreate(db *gorm.DB) {
|
||||||
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) {
|
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) {
|
||||||
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
|
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
|
||||||
if db.Statement.Schema.AfterSave {
|
|
||||||
if i, ok := value.(AfterSaveInterface); ok {
|
|
||||||
called = true
|
|
||||||
db.AddError(i.AfterSave(tx))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if db.Statement.Schema.AfterCreate {
|
if db.Statement.Schema.AfterCreate {
|
||||||
if i, ok := value.(AfterCreateInterface); ok {
|
if i, ok := value.(AfterCreateInterface); ok {
|
||||||
called = true
|
called = true
|
||||||
db.AddError(i.AfterCreate(tx))
|
db.AddError(i.AfterCreate(tx))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if db.Statement.Schema.AfterSave {
|
||||||
|
if i, ok := value.(AfterSaveInterface); ok {
|
||||||
|
called = true
|
||||||
|
db.AddError(i.AfterSave(tx))
|
||||||
|
}
|
||||||
|
}
|
||||||
return called
|
return called
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -201,13 +206,15 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
|
|||||||
switch stmt.ReflectValue.Kind() {
|
switch stmt.ReflectValue.Kind() {
|
||||||
case reflect.Slice, reflect.Array:
|
case reflect.Slice, reflect.Array:
|
||||||
rValLen := stmt.ReflectValue.Len()
|
rValLen := stmt.ReflectValue.Len()
|
||||||
stmt.SQL.Grow(rValLen * 18)
|
|
||||||
values.Values = make([][]interface{}, rValLen)
|
|
||||||
if rValLen == 0 {
|
if rValLen == 0 {
|
||||||
stmt.AddError(gorm.ErrEmptySlice)
|
stmt.AddError(gorm.ErrEmptySlice)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
stmt.SQL.Grow(rValLen * 18)
|
||||||
|
stmt.Vars = make([]interface{}, 0, rValLen*len(values.Columns))
|
||||||
|
values.Values = make([][]interface{}, rValLen)
|
||||||
|
|
||||||
defaultValueFieldsHavingValue := map[*schema.Field][]interface{}{}
|
defaultValueFieldsHavingValue := map[*schema.Field][]interface{}{}
|
||||||
for i := 0; i < rValLen; i++ {
|
for i := 0; i < rValLen; i++ {
|
||||||
rv := reflect.Indirect(stmt.ReflectValue.Index(i))
|
rv := reflect.Indirect(stmt.ReflectValue.Index(i))
|
||||||
@ -219,23 +226,23 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
|
|||||||
values.Values[i] = make([]interface{}, len(values.Columns))
|
values.Values[i] = make([]interface{}, len(values.Columns))
|
||||||
for idx, column := range values.Columns {
|
for idx, column := range values.Columns {
|
||||||
field := stmt.Schema.FieldsByDBName[column.Name]
|
field := stmt.Schema.FieldsByDBName[column.Name]
|
||||||
if values.Values[i][idx], isZero = field.ValueOf(rv); isZero {
|
if values.Values[i][idx], isZero = field.ValueOf(stmt.Context, rv); isZero {
|
||||||
if field.DefaultValueInterface != nil {
|
if field.DefaultValueInterface != nil {
|
||||||
values.Values[i][idx] = field.DefaultValueInterface
|
values.Values[i][idx] = field.DefaultValueInterface
|
||||||
field.Set(rv, field.DefaultValueInterface)
|
stmt.AddError(field.Set(stmt.Context, rv, field.DefaultValueInterface))
|
||||||
} else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 {
|
} else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 {
|
||||||
field.Set(rv, curTime)
|
stmt.AddError(field.Set(stmt.Context, rv, curTime))
|
||||||
values.Values[i][idx], _ = field.ValueOf(rv)
|
values.Values[i][idx], _ = field.ValueOf(stmt.Context, rv)
|
||||||
}
|
}
|
||||||
} else if field.AutoUpdateTime > 0 && updateTrackTime {
|
} else if field.AutoUpdateTime > 0 && updateTrackTime {
|
||||||
field.Set(rv, curTime)
|
stmt.AddError(field.Set(stmt.Context, rv, curTime))
|
||||||
values.Values[i][idx], _ = field.ValueOf(rv)
|
values.Values[i][idx], _ = field.ValueOf(stmt.Context, rv)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
|
for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
|
||||||
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
|
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
|
||||||
if rvOfvalue, isZero := field.ValueOf(rv); !isZero {
|
if rvOfvalue, isZero := field.ValueOf(stmt.Context, rv); !isZero {
|
||||||
if len(defaultValueFieldsHavingValue[field]) == 0 {
|
if len(defaultValueFieldsHavingValue[field]) == 0 {
|
||||||
defaultValueFieldsHavingValue[field] = make([]interface{}, rValLen)
|
defaultValueFieldsHavingValue[field] = make([]interface{}, rValLen)
|
||||||
}
|
}
|
||||||
@ -259,23 +266,23 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
|
|||||||
values.Values = [][]interface{}{make([]interface{}, len(values.Columns))}
|
values.Values = [][]interface{}{make([]interface{}, len(values.Columns))}
|
||||||
for idx, column := range values.Columns {
|
for idx, column := range values.Columns {
|
||||||
field := stmt.Schema.FieldsByDBName[column.Name]
|
field := stmt.Schema.FieldsByDBName[column.Name]
|
||||||
if values.Values[0][idx], isZero = field.ValueOf(stmt.ReflectValue); isZero {
|
if values.Values[0][idx], isZero = field.ValueOf(stmt.Context, stmt.ReflectValue); isZero {
|
||||||
if field.DefaultValueInterface != nil {
|
if field.DefaultValueInterface != nil {
|
||||||
values.Values[0][idx] = field.DefaultValueInterface
|
values.Values[0][idx] = field.DefaultValueInterface
|
||||||
field.Set(stmt.ReflectValue, field.DefaultValueInterface)
|
stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, field.DefaultValueInterface))
|
||||||
} else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 {
|
} else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 {
|
||||||
field.Set(stmt.ReflectValue, curTime)
|
stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, curTime))
|
||||||
values.Values[0][idx], _ = field.ValueOf(stmt.ReflectValue)
|
values.Values[0][idx], _ = field.ValueOf(stmt.Context, stmt.ReflectValue)
|
||||||
}
|
}
|
||||||
} else if field.AutoUpdateTime > 0 && updateTrackTime {
|
} else if field.AutoUpdateTime > 0 && updateTrackTime {
|
||||||
field.Set(stmt.ReflectValue, curTime)
|
stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, curTime))
|
||||||
values.Values[0][idx], _ = field.ValueOf(stmt.ReflectValue)
|
values.Values[0][idx], _ = field.ValueOf(stmt.Context, stmt.ReflectValue)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
|
for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
|
||||||
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
|
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
|
||||||
if rvOfvalue, isZero := field.ValueOf(stmt.ReflectValue); !isZero {
|
if rvOfvalue, isZero := field.ValueOf(stmt.Context, stmt.ReflectValue); !isZero {
|
||||||
values.Columns = append(values.Columns, clause.Column{Name: field.DBName})
|
values.Columns = append(values.Columns, clause.Column{Name: field.DBName})
|
||||||
values.Values[0] = append(values.Values[0], rvOfvalue)
|
values.Values[0] = append(values.Values[0], rvOfvalue)
|
||||||
}
|
}
|
||||||
|
@ -42,7 +42,7 @@ func DeleteBeforeAssociations(db *gorm.DB) {
|
|||||||
|
|
||||||
switch rel.Type {
|
switch rel.Type {
|
||||||
case schema.HasOne, schema.HasMany:
|
case schema.HasOne, schema.HasMany:
|
||||||
queryConds := rel.ToQueryConditions(db.Statement.ReflectValue)
|
queryConds := rel.ToQueryConditions(db.Statement.Context, db.Statement.ReflectValue)
|
||||||
modelValue := reflect.New(rel.FieldSchema.ModelType).Interface()
|
modelValue := reflect.New(rel.FieldSchema.ModelType).Interface()
|
||||||
tx := db.Session(&gorm.Session{NewDB: true}).Model(modelValue)
|
tx := db.Session(&gorm.Session{NewDB: true}).Model(modelValue)
|
||||||
withoutConditions := false
|
withoutConditions := false
|
||||||
@ -97,7 +97,7 @@ func DeleteBeforeAssociations(db *gorm.DB) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
_, foreignValues := schema.GetIdentityFieldValuesMap(db.Statement.ReflectValue, foreignFields)
|
_, foreignValues := schema.GetIdentityFieldValuesMap(db.Statement.Context, db.Statement.ReflectValue, foreignFields)
|
||||||
column, values := schema.ToQueryValues(table, relForeignKeys, foreignValues)
|
column, values := schema.ToQueryValues(table, relForeignKeys, foreignValues)
|
||||||
queryConds = append(queryConds, clause.IN{Column: column, Values: values})
|
queryConds = append(queryConds, clause.IN{Column: column, Values: values})
|
||||||
|
|
||||||
@ -118,12 +118,18 @@ func Delete(config *Config) func(db *gorm.DB) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if db.Statement.Schema != nil {
|
||||||
|
for _, c := range db.Statement.Schema.DeleteClauses {
|
||||||
|
db.Statement.AddClause(c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if db.Statement.SQL.Len() == 0 {
|
if db.Statement.SQL.Len() == 0 {
|
||||||
db.Statement.SQL.Grow(100)
|
db.Statement.SQL.Grow(100)
|
||||||
db.Statement.AddClauseIfNotExists(clause.Delete{})
|
db.Statement.AddClauseIfNotExists(clause.Delete{})
|
||||||
|
|
||||||
if db.Statement.Schema != nil {
|
if db.Statement.Schema != nil {
|
||||||
_, queryValues := schema.GetIdentityFieldValuesMap(db.Statement.ReflectValue, db.Statement.Schema.PrimaryFields)
|
_, queryValues := schema.GetIdentityFieldValuesMap(db.Statement.Context, db.Statement.ReflectValue, db.Statement.Schema.PrimaryFields)
|
||||||
column, values := schema.ToQueryValues(db.Statement.Table, db.Statement.Schema.PrimaryFieldDBNames, queryValues)
|
column, values := schema.ToQueryValues(db.Statement.Table, db.Statement.Schema.PrimaryFieldDBNames, queryValues)
|
||||||
|
|
||||||
if len(values) > 0 {
|
if len(values) > 0 {
|
||||||
@ -131,7 +137,7 @@ func Delete(config *Config) func(db *gorm.DB) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if db.Statement.ReflectValue.CanAddr() && db.Statement.Dest != db.Statement.Model && db.Statement.Model != nil {
|
if db.Statement.ReflectValue.CanAddr() && db.Statement.Dest != db.Statement.Model && db.Statement.Model != nil {
|
||||||
_, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(db.Statement.Model), db.Statement.Schema.PrimaryFields)
|
_, queryValues = schema.GetIdentityFieldValuesMap(db.Statement.Context, reflect.ValueOf(db.Statement.Model), db.Statement.Schema.PrimaryFields)
|
||||||
column, values = schema.ToQueryValues(db.Statement.Table, db.Statement.Schema.PrimaryFieldDBNames, queryValues)
|
column, values = schema.ToQueryValues(db.Statement.Table, db.Statement.Schema.PrimaryFieldDBNames, queryValues)
|
||||||
|
|
||||||
if len(values) > 0 {
|
if len(values) > 0 {
|
||||||
@ -141,22 +147,11 @@ func Delete(config *Config) func(db *gorm.DB) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
db.Statement.AddClauseIfNotExists(clause.From{})
|
db.Statement.AddClauseIfNotExists(clause.From{})
|
||||||
}
|
|
||||||
|
|
||||||
if db.Statement.Schema != nil {
|
|
||||||
for _, c := range db.Statement.Schema.DeleteClauses {
|
|
||||||
db.Statement.AddClause(c)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if db.Statement.SQL.Len() == 0 {
|
|
||||||
db.Statement.Build(db.Statement.BuildClauses...)
|
db.Statement.Build(db.Statement.BuildClauses...)
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok && db.Error == nil {
|
checkMissingWhereConditions(db)
|
||||||
db.AddError(gorm.ErrMissingWhereClause)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if !db.DryRun && db.Error == nil {
|
if !db.DryRun && db.Error == nil {
|
||||||
ok, mode := hasReturning(db, supportReturning)
|
ok, mode := hasReturning(db, supportReturning)
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package callbacks
|
package callbacks
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"reflect"
|
||||||
"sort"
|
"sort"
|
||||||
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
@ -104,3 +105,48 @@ func hasReturning(tx *gorm.DB, supportReturning bool) (bool, gorm.ScanMode) {
|
|||||||
}
|
}
|
||||||
return false, 0
|
return false, 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func checkMissingWhereConditions(db *gorm.DB) {
|
||||||
|
if !db.AllowGlobalUpdate && db.Error == nil {
|
||||||
|
where, withCondition := db.Statement.Clauses["WHERE"]
|
||||||
|
if withCondition {
|
||||||
|
if _, withSoftDelete := db.Statement.Clauses["soft_delete_enabled"]; withSoftDelete {
|
||||||
|
whereClause, _ := where.Expression.(clause.Where)
|
||||||
|
withCondition = len(whereClause.Exprs) > 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !withCondition {
|
||||||
|
db.AddError(gorm.ErrMissingWhereClause)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type visitMap = map[reflect.Value]bool
|
||||||
|
|
||||||
|
// Check if circular values, return true if loaded
|
||||||
|
func loadOrStoreVisitMap(visitMap *visitMap, v reflect.Value) (loaded bool) {
|
||||||
|
if v.Kind() == reflect.Ptr {
|
||||||
|
v = v.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
switch v.Kind() {
|
||||||
|
case reflect.Slice, reflect.Array:
|
||||||
|
loaded = true
|
||||||
|
for i := 0; i < v.Len(); i++ {
|
||||||
|
if !loadOrStoreVisitMap(visitMap, v.Index(i)) {
|
||||||
|
loaded = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case reflect.Struct, reflect.Interface:
|
||||||
|
if v.CanAddr() {
|
||||||
|
p := v.Addr()
|
||||||
|
if _, ok := (*visitMap)[p]; ok {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
(*visitMap)[p] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
@ -10,10 +10,9 @@ import (
|
|||||||
"gorm.io/gorm/utils"
|
"gorm.io/gorm/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) {
|
func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) error {
|
||||||
var (
|
var (
|
||||||
reflectValue = db.Statement.ReflectValue
|
reflectValue = tx.Statement.ReflectValue
|
||||||
tx = db.Session(&gorm.Session{NewDB: true}).Model(nil).Session(&gorm.Session{SkipHooks: db.Statement.SkipHooks})
|
|
||||||
relForeignKeys []string
|
relForeignKeys []string
|
||||||
relForeignFields []*schema.Field
|
relForeignFields []*schema.Field
|
||||||
foreignFields []*schema.Field
|
foreignFields []*schema.Field
|
||||||
@ -22,11 +21,6 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
|
|||||||
inlineConds []interface{}
|
inlineConds []interface{}
|
||||||
)
|
)
|
||||||
|
|
||||||
db.Statement.Settings.Range(func(k, v interface{}) bool {
|
|
||||||
tx.Statement.Settings.Store(k, v)
|
|
||||||
return true
|
|
||||||
})
|
|
||||||
|
|
||||||
if rel.JoinTable != nil {
|
if rel.JoinTable != nil {
|
||||||
var (
|
var (
|
||||||
joinForeignFields = make([]*schema.Field, 0, len(rel.References))
|
joinForeignFields = make([]*schema.Field, 0, len(rel.References))
|
||||||
@ -48,14 +42,16 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
joinIdentityMap, joinForeignValues := schema.GetIdentityFieldValuesMap(reflectValue, foreignFields)
|
joinIdentityMap, joinForeignValues := schema.GetIdentityFieldValuesMap(tx.Statement.Context, reflectValue, foreignFields)
|
||||||
if len(joinForeignValues) == 0 {
|
if len(joinForeignValues) == 0 {
|
||||||
return
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
joinResults := rel.JoinTable.MakeSlice().Elem()
|
joinResults := rel.JoinTable.MakeSlice().Elem()
|
||||||
column, values := schema.ToQueryValues(clause.CurrentTable, joinForeignKeys, joinForeignValues)
|
column, values := schema.ToQueryValues(clause.CurrentTable, joinForeignKeys, joinForeignValues)
|
||||||
db.AddError(tx.Where(clause.IN{Column: column, Values: values}).Find(joinResults.Addr().Interface()).Error)
|
if err := tx.Where(clause.IN{Column: column, Values: values}).Find(joinResults.Addr().Interface()).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
// convert join identity map to relation identity map
|
// convert join identity map to relation identity map
|
||||||
fieldValues := make([]interface{}, len(joinForeignFields))
|
fieldValues := make([]interface{}, len(joinForeignFields))
|
||||||
@ -63,11 +59,11 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
|
|||||||
for i := 0; i < joinResults.Len(); i++ {
|
for i := 0; i < joinResults.Len(); i++ {
|
||||||
joinIndexValue := joinResults.Index(i)
|
joinIndexValue := joinResults.Index(i)
|
||||||
for idx, field := range joinForeignFields {
|
for idx, field := range joinForeignFields {
|
||||||
fieldValues[idx], _ = field.ValueOf(joinIndexValue)
|
fieldValues[idx], _ = field.ValueOf(tx.Statement.Context, joinIndexValue)
|
||||||
}
|
}
|
||||||
|
|
||||||
for idx, field := range joinRelForeignFields {
|
for idx, field := range joinRelForeignFields {
|
||||||
joinFieldValues[idx], _ = field.ValueOf(joinIndexValue)
|
joinFieldValues[idx], _ = field.ValueOf(tx.Statement.Context, joinIndexValue)
|
||||||
}
|
}
|
||||||
|
|
||||||
if results, ok := joinIdentityMap[utils.ToStringKey(fieldValues...)]; ok {
|
if results, ok := joinIdentityMap[utils.ToStringKey(fieldValues...)]; ok {
|
||||||
@ -76,7 +72,7 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
_, foreignValues = schema.GetIdentityFieldValuesMap(joinResults, joinRelForeignFields)
|
_, foreignValues = schema.GetIdentityFieldValuesMap(tx.Statement.Context, joinResults, joinRelForeignFields)
|
||||||
} else {
|
} else {
|
||||||
for _, ref := range rel.References {
|
for _, ref := range rel.References {
|
||||||
if ref.OwnPrimaryKey {
|
if ref.OwnPrimaryKey {
|
||||||
@ -92,9 +88,9 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
identityMap, foreignValues = schema.GetIdentityFieldValuesMap(reflectValue, foreignFields)
|
identityMap, foreignValues = schema.GetIdentityFieldValuesMap(tx.Statement.Context, reflectValue, foreignFields)
|
||||||
if len(foreignValues) == 0 {
|
if len(foreignValues) == 0 {
|
||||||
return
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -115,7 +111,9 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
db.AddError(tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), inlineConds...).Error)
|
if err := tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), inlineConds...).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fieldValues := make([]interface{}, len(relForeignFields))
|
fieldValues := make([]interface{}, len(relForeignFields))
|
||||||
@ -125,17 +123,17 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
|
|||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
switch rel.Type {
|
switch rel.Type {
|
||||||
case schema.HasMany, schema.Many2Many:
|
case schema.HasMany, schema.Many2Many:
|
||||||
rel.Field.Set(reflectValue, reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface())
|
tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue, reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface()))
|
||||||
default:
|
default:
|
||||||
rel.Field.Set(reflectValue, reflect.New(rel.Field.FieldType).Interface())
|
tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue, reflect.New(rel.Field.FieldType).Interface()))
|
||||||
}
|
}
|
||||||
case reflect.Slice, reflect.Array:
|
case reflect.Slice, reflect.Array:
|
||||||
for i := 0; i < reflectValue.Len(); i++ {
|
for i := 0; i < reflectValue.Len(); i++ {
|
||||||
switch rel.Type {
|
switch rel.Type {
|
||||||
case schema.HasMany, schema.Many2Many:
|
case schema.HasMany, schema.Many2Many:
|
||||||
rel.Field.Set(reflectValue.Index(i), reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface())
|
tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue.Index(i), reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface()))
|
||||||
default:
|
default:
|
||||||
rel.Field.Set(reflectValue.Index(i), reflect.New(rel.Field.FieldType).Interface())
|
tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue.Index(i), reflect.New(rel.Field.FieldType).Interface()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -143,18 +141,16 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
|
|||||||
for i := 0; i < reflectResults.Len(); i++ {
|
for i := 0; i < reflectResults.Len(); i++ {
|
||||||
elem := reflectResults.Index(i)
|
elem := reflectResults.Index(i)
|
||||||
for idx, field := range relForeignFields {
|
for idx, field := range relForeignFields {
|
||||||
fieldValues[idx], _ = field.ValueOf(elem)
|
fieldValues[idx], _ = field.ValueOf(tx.Statement.Context, elem)
|
||||||
}
|
}
|
||||||
|
|
||||||
datas, ok := identityMap[utils.ToStringKey(fieldValues...)]
|
datas, ok := identityMap[utils.ToStringKey(fieldValues...)]
|
||||||
if !ok {
|
if !ok {
|
||||||
db.AddError(fmt.Errorf("failed to assign association %#v, make sure foreign fields exists",
|
return fmt.Errorf("failed to assign association %#v, make sure foreign fields exists", elem.Interface())
|
||||||
elem.Interface()))
|
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, data := range datas {
|
for _, data := range datas {
|
||||||
reflectFieldValue := rel.Field.ReflectValueOf(data)
|
reflectFieldValue := rel.Field.ReflectValueOf(tx.Statement.Context, data)
|
||||||
if reflectFieldValue.Kind() == reflect.Ptr && reflectFieldValue.IsNil() {
|
if reflectFieldValue.Kind() == reflect.Ptr && reflectFieldValue.IsNil() {
|
||||||
reflectFieldValue.Set(reflect.New(rel.Field.FieldType.Elem()))
|
reflectFieldValue.Set(reflect.New(rel.Field.FieldType.Elem()))
|
||||||
}
|
}
|
||||||
@ -162,14 +158,16 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
|
|||||||
reflectFieldValue = reflect.Indirect(reflectFieldValue)
|
reflectFieldValue = reflect.Indirect(reflectFieldValue)
|
||||||
switch reflectFieldValue.Kind() {
|
switch reflectFieldValue.Kind() {
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
rel.Field.Set(data, elem.Interface())
|
tx.AddError(rel.Field.Set(tx.Statement.Context, data, elem.Interface()))
|
||||||
case reflect.Slice, reflect.Array:
|
case reflect.Slice, reflect.Array:
|
||||||
if reflectFieldValue.Type().Elem().Kind() == reflect.Ptr {
|
if reflectFieldValue.Type().Elem().Kind() == reflect.Ptr {
|
||||||
rel.Field.Set(data, reflect.Append(reflectFieldValue, elem).Interface())
|
tx.AddError(rel.Field.Set(tx.Statement.Context, data, reflect.Append(reflectFieldValue, elem).Interface()))
|
||||||
} else {
|
} else {
|
||||||
rel.Field.Set(data, reflect.Append(reflectFieldValue, elem.Elem()).Interface())
|
tx.AddError(rel.Field.Set(tx.Statement.Context, data, reflect.Append(reflectFieldValue, elem.Elem()).Interface()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return tx.Error
|
||||||
}
|
}
|
||||||
|
@ -20,8 +20,10 @@ func Query(db *gorm.DB) {
|
|||||||
db.AddError(err)
|
db.AddError(err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
gorm.Scan(rows, db, 0)
|
defer func() {
|
||||||
db.AddError(rows.Close())
|
db.AddError(rows.Close())
|
||||||
|
}()
|
||||||
|
gorm.Scan(rows, db, 0)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -40,7 +42,7 @@ func BuildQuerySQL(db *gorm.DB) {
|
|||||||
if db.Statement.ReflectValue.Kind() == reflect.Struct && db.Statement.ReflectValue.Type() == db.Statement.Schema.ModelType {
|
if db.Statement.ReflectValue.Kind() == reflect.Struct && db.Statement.ReflectValue.Type() == db.Statement.Schema.ModelType {
|
||||||
var conds []clause.Expression
|
var conds []clause.Expression
|
||||||
for _, primaryField := range db.Statement.Schema.PrimaryFields {
|
for _, primaryField := range db.Statement.Schema.PrimaryFields {
|
||||||
if v, isZero := primaryField.ValueOf(db.Statement.ReflectValue); !isZero {
|
if v, isZero := primaryField.ValueOf(db.Statement.Context, db.Statement.ReflectValue); !isZero {
|
||||||
conds = append(conds, clause.Eq{Column: clause.Column{Table: db.Statement.Table, Name: primaryField.DBName}, Value: v})
|
conds = append(conds, clause.Eq{Column: clause.Column{Table: db.Statement.Table, Name: primaryField.DBName}, Value: v})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -94,13 +96,13 @@ func BuildQuerySQL(db *gorm.DB) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// inline joins
|
// inline joins
|
||||||
joins := []clause.Join{}
|
fromClause := clause.From{}
|
||||||
if fromClause, ok := db.Statement.Clauses["FROM"].Expression.(clause.From); ok {
|
if v, ok := db.Statement.Clauses["FROM"].Expression.(clause.From); ok {
|
||||||
joins = fromClause.Joins
|
fromClause = v
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(db.Statement.Joins) != 0 || len(joins) != 0 {
|
if len(db.Statement.Joins) != 0 || len(fromClause.Joins) != 0 {
|
||||||
if len(db.Statement.Selects) == 0 && db.Statement.Schema != nil {
|
if len(db.Statement.Selects) == 0 && len(db.Statement.Omits) == 0 && db.Statement.Schema != nil {
|
||||||
clauseSelect.Columns = make([]clause.Column, len(db.Statement.Schema.DBNames))
|
clauseSelect.Columns = make([]clause.Column, len(db.Statement.Schema.DBNames))
|
||||||
for idx, dbName := range db.Statement.Schema.DBNames {
|
for idx, dbName := range db.Statement.Schema.DBNames {
|
||||||
clauseSelect.Columns[idx] = clause.Column{Table: db.Statement.Table, Name: dbName}
|
clauseSelect.Columns[idx] = clause.Column{Table: db.Statement.Table, Name: dbName}
|
||||||
@ -109,7 +111,7 @@ func BuildQuerySQL(db *gorm.DB) {
|
|||||||
|
|
||||||
for _, join := range db.Statement.Joins {
|
for _, join := range db.Statement.Joins {
|
||||||
if db.Statement.Schema == nil {
|
if db.Statement.Schema == nil {
|
||||||
joins = append(joins, clause.Join{
|
fromClause.Joins = append(fromClause.Joins, clause.Join{
|
||||||
Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds},
|
Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds},
|
||||||
})
|
})
|
||||||
} else if relation, ok := db.Statement.Schema.Relationships.Relations[join.Name]; ok {
|
} else if relation, ok := db.Statement.Schema.Relationships.Relations[join.Name]; ok {
|
||||||
@ -145,12 +147,23 @@ func BuildQuerySQL(db *gorm.DB) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
onStmt := gorm.Statement{Table: tableAliasName, DB: db, Clauses: map[string]clause.Clause{}}
|
||||||
|
for _, c := range relation.FieldSchema.QueryClauses {
|
||||||
|
onStmt.AddClause(c)
|
||||||
|
}
|
||||||
|
|
||||||
if join.On != nil {
|
if join.On != nil {
|
||||||
onStmt := gorm.Statement{Table: tableAliasName, DB: db}
|
onStmt.AddClause(join.On)
|
||||||
join.On.Build(&onStmt)
|
}
|
||||||
onSQL := onStmt.SQL.String()
|
|
||||||
|
if cs, ok := onStmt.Clauses["WHERE"]; ok {
|
||||||
|
if where, ok := cs.Expression.(clause.Where); ok {
|
||||||
|
where.Build(&onStmt)
|
||||||
|
|
||||||
|
if onSQL := onStmt.SQL.String(); onSQL != "" {
|
||||||
vars := onStmt.Vars
|
vars := onStmt.Vars
|
||||||
for idx, v := range onStmt.Vars {
|
for idx, v := range vars {
|
||||||
bindvar := strings.Builder{}
|
bindvar := strings.Builder{}
|
||||||
onStmt.Vars = vars[0 : idx+1]
|
onStmt.Vars = vars[0 : idx+1]
|
||||||
db.Dialector.BindVarTo(&bindvar, &onStmt, v)
|
db.Dialector.BindVarTo(&bindvar, &onStmt, v)
|
||||||
@ -159,21 +172,24 @@ func BuildQuerySQL(db *gorm.DB) {
|
|||||||
|
|
||||||
exprs = append(exprs, clause.Expr{SQL: onSQL, Vars: vars})
|
exprs = append(exprs, clause.Expr{SQL: onSQL, Vars: vars})
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
joins = append(joins, clause.Join{
|
fromClause.Joins = append(fromClause.Joins, clause.Join{
|
||||||
Type: clause.LeftJoin,
|
Type: clause.LeftJoin,
|
||||||
Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName},
|
Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName},
|
||||||
ON: clause.Where{Exprs: exprs},
|
ON: clause.Where{Exprs: exprs},
|
||||||
})
|
})
|
||||||
} else {
|
} else {
|
||||||
joins = append(joins, clause.Join{
|
fromClause.Joins = append(fromClause.Joins, clause.Join{
|
||||||
Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds},
|
Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
db.Statement.AddClause(fromClause)
|
||||||
db.Statement.Joins = nil
|
db.Statement.Joins = nil
|
||||||
db.Statement.AddClause(clause.From{Joins: joins})
|
|
||||||
} else {
|
} else {
|
||||||
db.Statement.AddClauseIfNotExists(clause.From{})
|
db.Statement.AddClauseIfNotExists(clause.From{})
|
||||||
}
|
}
|
||||||
@ -186,6 +202,11 @@ func BuildQuerySQL(db *gorm.DB) {
|
|||||||
|
|
||||||
func Preload(db *gorm.DB) {
|
func Preload(db *gorm.DB) {
|
||||||
if db.Error == nil && len(db.Statement.Preloads) > 0 {
|
if db.Error == nil && len(db.Statement.Preloads) > 0 {
|
||||||
|
if db.Statement.Schema == nil {
|
||||||
|
db.AddError(fmt.Errorf("%w when using preload", gorm.ErrModelValueRequired))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
preloadMap := map[string]map[string][]interface{}{}
|
preloadMap := map[string]map[string][]interface{}{}
|
||||||
for name := range db.Statement.Preloads {
|
for name := range db.Statement.Preloads {
|
||||||
preloadFields := strings.Split(name, ".")
|
preloadFields := strings.Split(name, ".")
|
||||||
@ -218,9 +239,20 @@ func Preload(db *gorm.DB) {
|
|||||||
}
|
}
|
||||||
sort.Strings(preloadNames)
|
sort.Strings(preloadNames)
|
||||||
|
|
||||||
|
preloadDB := db.Session(&gorm.Session{Context: db.Statement.Context, NewDB: true, SkipHooks: db.Statement.SkipHooks, Initialized: true})
|
||||||
|
db.Statement.Settings.Range(func(k, v interface{}) bool {
|
||||||
|
preloadDB.Statement.Settings.Store(k, v)
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
|
if err := preloadDB.Statement.Parse(db.Statement.Dest); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
preloadDB.Statement.ReflectValue = db.Statement.ReflectValue
|
||||||
|
|
||||||
for _, name := range preloadNames {
|
for _, name := range preloadNames {
|
||||||
if rel := db.Statement.Schema.Relationships.Relations[name]; rel != nil {
|
if rel := preloadDB.Statement.Schema.Relationships.Relations[name]; rel != nil {
|
||||||
preload(db, rel, append(db.Statement.Preloads[name], db.Statement.Preloads[clause.Associations]...), preloadMap[name])
|
db.AddError(preload(preloadDB.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks}), rel, append(db.Statement.Preloads[name], db.Statement.Preloads[clause.Associations]...), preloadMap[name]))
|
||||||
} else {
|
} else {
|
||||||
db.AddError(fmt.Errorf("%s: %w for schema %s", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name))
|
db.AddError(fmt.Errorf("%s: %w for schema %s", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name))
|
||||||
}
|
}
|
||||||
|
@ -21,7 +21,7 @@ func SetupUpdateReflectValue(db *gorm.DB) {
|
|||||||
if dest, ok := db.Statement.Dest.(map[string]interface{}); ok {
|
if dest, ok := db.Statement.Dest.(map[string]interface{}); ok {
|
||||||
for _, rel := range db.Statement.Schema.Relationships.BelongsTo {
|
for _, rel := range db.Statement.Schema.Relationships.BelongsTo {
|
||||||
if _, ok := dest[rel.Name]; ok {
|
if _, ok := dest[rel.Name]; ok {
|
||||||
rel.Field.Set(db.Statement.ReflectValue, dest[rel.Name])
|
db.AddError(rel.Field.Set(db.Statement.Context, db.Statement.ReflectValue, dest[rel.Name]))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -29,6 +29,7 @@ func SetupUpdateReflectValue(db *gorm.DB) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// BeforeUpdate before update hooks
|
||||||
func BeforeUpdate(db *gorm.DB) {
|
func BeforeUpdate(db *gorm.DB) {
|
||||||
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) {
|
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) {
|
||||||
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
|
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
|
||||||
@ -51,6 +52,7 @@ func BeforeUpdate(db *gorm.DB) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Update update hook
|
||||||
func Update(config *Config) func(db *gorm.DB) {
|
func Update(config *Config) func(db *gorm.DB) {
|
||||||
supportReturning := utils.Contains(config.UpdateClauses, "RETURNING")
|
supportReturning := utils.Contains(config.UpdateClauses, "RETURNING")
|
||||||
|
|
||||||
@ -59,6 +61,12 @@ func Update(config *Config) func(db *gorm.DB) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if db.Statement.Schema != nil {
|
||||||
|
for _, c := range db.Statement.Schema.UpdateClauses {
|
||||||
|
db.Statement.AddClause(c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if db.Statement.SQL.Len() == 0 {
|
if db.Statement.SQL.Len() == 0 {
|
||||||
db.Statement.SQL.Grow(180)
|
db.Statement.SQL.Grow(180)
|
||||||
db.Statement.AddClauseIfNotExists(clause.Update{})
|
db.Statement.AddClauseIfNotExists(clause.Update{})
|
||||||
@ -68,22 +76,10 @@ func Update(config *Config) func(db *gorm.DB) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
if db.Statement.Schema != nil {
|
|
||||||
for _, c := range db.Statement.Schema.UpdateClauses {
|
|
||||||
db.Statement.AddClause(c)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if db.Statement.SQL.Len() == 0 {
|
|
||||||
db.Statement.Build(db.Statement.BuildClauses...)
|
db.Statement.Build(db.Statement.BuildClauses...)
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok {
|
checkMissingWhereConditions(db)
|
||||||
db.AddError(gorm.ErrMissingWhereClause)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if !db.DryRun && db.Error == nil {
|
if !db.DryRun && db.Error == nil {
|
||||||
if ok, mode := hasReturning(db, supportReturning); ok {
|
if ok, mode := hasReturning(db, supportReturning); ok {
|
||||||
@ -105,9 +101,17 @@ func Update(config *Config) func(db *gorm.DB) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AfterUpdate after update hooks
|
||||||
func AfterUpdate(db *gorm.DB) {
|
func AfterUpdate(db *gorm.DB) {
|
||||||
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) {
|
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) {
|
||||||
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
|
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
|
||||||
|
if db.Statement.Schema.AfterUpdate {
|
||||||
|
if i, ok := value.(AfterUpdateInterface); ok {
|
||||||
|
called = true
|
||||||
|
db.AddError(i.AfterUpdate(tx))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if db.Statement.Schema.AfterSave {
|
if db.Statement.Schema.AfterSave {
|
||||||
if i, ok := value.(AfterSaveInterface); ok {
|
if i, ok := value.(AfterSaveInterface); ok {
|
||||||
called = true
|
called = true
|
||||||
@ -115,12 +119,6 @@ func AfterUpdate(db *gorm.DB) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if db.Statement.Schema.AfterUpdate {
|
|
||||||
if i, ok := value.(AfterUpdateInterface); ok {
|
|
||||||
called = true
|
|
||||||
db.AddError(i.AfterUpdate(tx))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return called
|
return called
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -137,13 +135,13 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
|
|||||||
case reflect.Slice, reflect.Array:
|
case reflect.Slice, reflect.Array:
|
||||||
assignValue = func(field *schema.Field, value interface{}) {
|
assignValue = func(field *schema.Field, value interface{}) {
|
||||||
for i := 0; i < stmt.ReflectValue.Len(); i++ {
|
for i := 0; i < stmt.ReflectValue.Len(); i++ {
|
||||||
field.Set(stmt.ReflectValue.Index(i), value)
|
field.Set(stmt.Context, stmt.ReflectValue.Index(i), value)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
assignValue = func(field *schema.Field, value interface{}) {
|
assignValue = func(field *schema.Field, value interface{}) {
|
||||||
if stmt.ReflectValue.CanAddr() {
|
if stmt.ReflectValue.CanAddr() {
|
||||||
field.Set(stmt.ReflectValue, value)
|
field.Set(stmt.Context, stmt.ReflectValue, value)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
@ -165,7 +163,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
|
|||||||
exprs := make([]clause.Expression, len(stmt.Schema.PrimaryFields))
|
exprs := make([]clause.Expression, len(stmt.Schema.PrimaryFields))
|
||||||
var notZero bool
|
var notZero bool
|
||||||
for idx, field := range stmt.Schema.PrimaryFields {
|
for idx, field := range stmt.Schema.PrimaryFields {
|
||||||
value, isZero := field.ValueOf(stmt.ReflectValue.Index(i))
|
value, isZero := field.ValueOf(stmt.Context, stmt.ReflectValue.Index(i))
|
||||||
exprs[idx] = clause.Eq{Column: field.DBName, Value: value}
|
exprs[idx] = clause.Eq{Column: field.DBName, Value: value}
|
||||||
notZero = notZero || !isZero
|
notZero = notZero || !isZero
|
||||||
}
|
}
|
||||||
@ -178,7 +176,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
|
|||||||
}
|
}
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
for _, field := range stmt.Schema.PrimaryFields {
|
for _, field := range stmt.Schema.PrimaryFields {
|
||||||
if value, isZero := field.ValueOf(stmt.ReflectValue); !isZero {
|
if value, isZero := field.ValueOf(stmt.Context, stmt.ReflectValue); !isZero {
|
||||||
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}})
|
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -232,10 +230,10 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
|
|||||||
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano()})
|
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano()})
|
||||||
} else if field.AutoUpdateTime == schema.UnixMillisecond {
|
} else if field.AutoUpdateTime == schema.UnixMillisecond {
|
||||||
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano() / 1e6})
|
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano() / 1e6})
|
||||||
} else if field.GORMDataType == schema.Time {
|
} else if field.AutoUpdateTime == schema.UnixSecond {
|
||||||
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now})
|
|
||||||
} else {
|
|
||||||
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.Unix()})
|
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.Unix()})
|
||||||
|
} else {
|
||||||
|
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -258,16 +256,16 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
|
|||||||
if field := updatingSchema.LookUpField(dbName); field != nil {
|
if field := updatingSchema.LookUpField(dbName); field != nil {
|
||||||
if !field.PrimaryKey || !updatingValue.CanAddr() || stmt.Dest != stmt.Model {
|
if !field.PrimaryKey || !updatingValue.CanAddr() || stmt.Dest != stmt.Model {
|
||||||
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && (!restricted || (!stmt.SkipHooks && field.AutoUpdateTime > 0))) {
|
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && (!restricted || (!stmt.SkipHooks && field.AutoUpdateTime > 0))) {
|
||||||
value, isZero := field.ValueOf(updatingValue)
|
value, isZero := field.ValueOf(stmt.Context, updatingValue)
|
||||||
if !stmt.SkipHooks && field.AutoUpdateTime > 0 {
|
if !stmt.SkipHooks && field.AutoUpdateTime > 0 {
|
||||||
if field.AutoUpdateTime == schema.UnixNanosecond {
|
if field.AutoUpdateTime == schema.UnixNanosecond {
|
||||||
value = stmt.DB.NowFunc().UnixNano()
|
value = stmt.DB.NowFunc().UnixNano()
|
||||||
} else if field.AutoUpdateTime == schema.UnixMillisecond {
|
} else if field.AutoUpdateTime == schema.UnixMillisecond {
|
||||||
value = stmt.DB.NowFunc().UnixNano() / 1e6
|
value = stmt.DB.NowFunc().UnixNano() / 1e6
|
||||||
} else if field.GORMDataType == schema.Time {
|
} else if field.AutoUpdateTime == schema.UnixSecond {
|
||||||
value = stmt.DB.NowFunc()
|
|
||||||
} else {
|
|
||||||
value = stmt.DB.NowFunc().Unix()
|
value = stmt.DB.NowFunc().Unix()
|
||||||
|
} else {
|
||||||
|
value = stmt.DB.NowFunc()
|
||||||
}
|
}
|
||||||
isZero = false
|
isZero = false
|
||||||
}
|
}
|
||||||
@ -278,7 +276,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if value, isZero := field.ValueOf(updatingValue); !isZero {
|
if value, isZero := field.ValueOf(stmt.Context, updatingValue); !isZero {
|
||||||
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}})
|
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
36
callbacks/visit_map_test.go
Normal file
36
callbacks/visit_map_test.go
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
package callbacks
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestLoadOrStoreVisitMap(t *testing.T) {
|
||||||
|
var vm visitMap
|
||||||
|
var loaded bool
|
||||||
|
type testM struct {
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
|
||||||
|
t1 := testM{Name: "t1"}
|
||||||
|
t2 := testM{Name: "t2"}
|
||||||
|
t3 := testM{Name: "t3"}
|
||||||
|
|
||||||
|
vm = make(visitMap)
|
||||||
|
if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf(&t1)); loaded {
|
||||||
|
t.Fatalf("loaded should be false")
|
||||||
|
}
|
||||||
|
|
||||||
|
if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf(&t1)); !loaded {
|
||||||
|
t.Fatalf("loaded should be true")
|
||||||
|
}
|
||||||
|
|
||||||
|
// t1 already exist but t2 not
|
||||||
|
if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf([]*testM{&t1, &t2, &t3})); loaded {
|
||||||
|
t.Fatalf("loaded should be false")
|
||||||
|
}
|
||||||
|
|
||||||
|
if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf([]*testM{&t2, &t3})); !loaded {
|
||||||
|
t.Fatalf("loaded should be true")
|
||||||
|
}
|
||||||
|
}
|
@ -54,9 +54,12 @@ func (db *DB) Table(name string, args ...interface{}) (tx *DB) {
|
|||||||
} else if tables := strings.Split(name, "."); len(tables) == 2 {
|
} else if tables := strings.Split(name, "."); len(tables) == 2 {
|
||||||
tx.Statement.TableExpr = &clause.Expr{SQL: tx.Statement.Quote(name)}
|
tx.Statement.TableExpr = &clause.Expr{SQL: tx.Statement.Quote(name)}
|
||||||
tx.Statement.Table = tables[1]
|
tx.Statement.Table = tables[1]
|
||||||
} else {
|
} else if name != "" {
|
||||||
tx.Statement.TableExpr = &clause.Expr{SQL: tx.Statement.Quote(name)}
|
tx.Statement.TableExpr = &clause.Expr{SQL: tx.Statement.Quote(name)}
|
||||||
tx.Statement.Table = name
|
tx.Statement.Table = name
|
||||||
|
} else {
|
||||||
|
tx.Statement.TableExpr = nil
|
||||||
|
tx.Statement.Table = ""
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -90,7 +93,11 @@ func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
delete(tx.Statement.Clauses, "SELECT")
|
|
||||||
|
if clause, ok := tx.Statement.Clauses["SELECT"]; ok {
|
||||||
|
clause.Expression = nil
|
||||||
|
tx.Statement.Clauses["SELECT"] = clause
|
||||||
|
}
|
||||||
case string:
|
case string:
|
||||||
if strings.Count(v, "?") >= len(args) && len(args) > 0 {
|
if strings.Count(v, "?") >= len(args) && len(args) > 0 {
|
||||||
tx.Statement.AddClause(clause.Select{
|
tx.Statement.AddClause(clause.Select{
|
||||||
@ -120,7 +127,10 @@ func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
delete(tx.Statement.Clauses, "SELECT")
|
if clause, ok := tx.Statement.Clauses["SELECT"]; ok {
|
||||||
|
clause.Expression = nil
|
||||||
|
tx.Statement.Clauses["SELECT"] = clause
|
||||||
|
}
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
tx.AddError(fmt.Errorf("unsupported select args %v %v", query, args))
|
tx.AddError(fmt.Errorf("unsupported select args %v %v", query, args))
|
||||||
|
@ -21,7 +21,7 @@ func (limit Limit) Build(builder Builder) {
|
|||||||
}
|
}
|
||||||
if limit.Offset > 0 {
|
if limit.Offset > 0 {
|
||||||
if limit.Limit > 0 {
|
if limit.Limit > 0 {
|
||||||
builder.WriteString(" ")
|
builder.WriteByte(' ')
|
||||||
}
|
}
|
||||||
builder.WriteString("OFFSET ")
|
builder.WriteString("OFFSET ")
|
||||||
builder.WriteString(strconv.Itoa(limit.Offset))
|
builder.WriteString(strconv.Itoa(limit.Offset))
|
||||||
|
@ -43,6 +43,23 @@ func TestSelect(t *testing.T) {
|
|||||||
}, clause.From{}},
|
}, clause.From{}},
|
||||||
"SELECT `id`, `name`, LENGTH(`mobile`) FROM `users`", nil,
|
"SELECT `id`, `name`, LENGTH(`mobile`) FROM `users`", nil,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
[]clause.Interface{clause.Select{
|
||||||
|
Expression: clause.CommaExpression{
|
||||||
|
Exprs: []clause.Expression{
|
||||||
|
clause.Expr{
|
||||||
|
SQL: "? as name",
|
||||||
|
Vars: []interface{}{clause.Eq{
|
||||||
|
Column: clause.Column{Name: "age"},
|
||||||
|
Value: 18,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}, clause.From{}},
|
||||||
|
"SELECT `age` = ? as name FROM `users`", []interface{}{18},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for idx, result := range results {
|
for idx, result := range results {
|
||||||
|
@ -4,6 +4,11 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
AndWithSpace = " AND "
|
||||||
|
OrWithSpace = " OR "
|
||||||
|
)
|
||||||
|
|
||||||
// Where where clause
|
// Where where clause
|
||||||
type Where struct {
|
type Where struct {
|
||||||
Exprs []Expression
|
Exprs []Expression
|
||||||
@ -26,7 +31,7 @@ func (where Where) Build(builder Builder) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
buildExprs(where.Exprs, builder, " AND ")
|
buildExprs(where.Exprs, builder, AndWithSpace)
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildExprs(exprs []Expression, builder Builder, joinCond string) {
|
func buildExprs(exprs []Expression, builder Builder, joinCond string) {
|
||||||
@ -35,7 +40,7 @@ func buildExprs(exprs []Expression, builder Builder, joinCond string) {
|
|||||||
for idx, expr := range exprs {
|
for idx, expr := range exprs {
|
||||||
if idx > 0 {
|
if idx > 0 {
|
||||||
if v, ok := expr.(OrConditions); ok && len(v.Exprs) == 1 {
|
if v, ok := expr.(OrConditions); ok && len(v.Exprs) == 1 {
|
||||||
builder.WriteString(" OR ")
|
builder.WriteString(OrWithSpace)
|
||||||
} else {
|
} else {
|
||||||
builder.WriteString(joinCond)
|
builder.WriteString(joinCond)
|
||||||
}
|
}
|
||||||
@ -46,30 +51,30 @@ func buildExprs(exprs []Expression, builder Builder, joinCond string) {
|
|||||||
case OrConditions:
|
case OrConditions:
|
||||||
if len(v.Exprs) == 1 {
|
if len(v.Exprs) == 1 {
|
||||||
if e, ok := v.Exprs[0].(Expr); ok {
|
if e, ok := v.Exprs[0].(Expr); ok {
|
||||||
sql := strings.ToLower(e.SQL)
|
sql := strings.ToUpper(e.SQL)
|
||||||
wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or")
|
wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
case AndConditions:
|
case AndConditions:
|
||||||
if len(v.Exprs) == 1 {
|
if len(v.Exprs) == 1 {
|
||||||
if e, ok := v.Exprs[0].(Expr); ok {
|
if e, ok := v.Exprs[0].(Expr); ok {
|
||||||
sql := strings.ToLower(e.SQL)
|
sql := strings.ToUpper(e.SQL)
|
||||||
wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or")
|
wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
case Expr:
|
case Expr:
|
||||||
sql := strings.ToLower(v.SQL)
|
sql := strings.ToUpper(v.SQL)
|
||||||
wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or")
|
wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace)
|
||||||
case NamedExpr:
|
case NamedExpr:
|
||||||
sql := strings.ToLower(v.SQL)
|
sql := strings.ToUpper(v.SQL)
|
||||||
wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or")
|
wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if wrapInParentheses {
|
if wrapInParentheses {
|
||||||
builder.WriteString(`(`)
|
builder.WriteByte('(')
|
||||||
expr.Build(builder)
|
expr.Build(builder)
|
||||||
builder.WriteString(`)`)
|
builder.WriteByte(')')
|
||||||
wrapInParentheses = false
|
wrapInParentheses = false
|
||||||
} else {
|
} else {
|
||||||
expr.Build(builder)
|
expr.Build(builder)
|
||||||
@ -110,10 +115,10 @@ type AndConditions struct {
|
|||||||
func (and AndConditions) Build(builder Builder) {
|
func (and AndConditions) Build(builder Builder) {
|
||||||
if len(and.Exprs) > 1 {
|
if len(and.Exprs) > 1 {
|
||||||
builder.WriteByte('(')
|
builder.WriteByte('(')
|
||||||
buildExprs(and.Exprs, builder, " AND ")
|
buildExprs(and.Exprs, builder, AndWithSpace)
|
||||||
builder.WriteByte(')')
|
builder.WriteByte(')')
|
||||||
} else {
|
} else {
|
||||||
buildExprs(and.Exprs, builder, " AND ")
|
buildExprs(and.Exprs, builder, AndWithSpace)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -131,10 +136,10 @@ type OrConditions struct {
|
|||||||
func (or OrConditions) Build(builder Builder) {
|
func (or OrConditions) Build(builder Builder) {
|
||||||
if len(or.Exprs) > 1 {
|
if len(or.Exprs) > 1 {
|
||||||
builder.WriteByte('(')
|
builder.WriteByte('(')
|
||||||
buildExprs(or.Exprs, builder, " OR ")
|
buildExprs(or.Exprs, builder, OrWithSpace)
|
||||||
builder.WriteByte(')')
|
builder.WriteByte(')')
|
||||||
} else {
|
} else {
|
||||||
buildExprs(or.Exprs, builder, " OR ")
|
buildExprs(or.Exprs, builder, OrWithSpace)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -156,7 +161,7 @@ func (not NotConditions) Build(builder Builder) {
|
|||||||
|
|
||||||
for idx, c := range not.Exprs {
|
for idx, c := range not.Exprs {
|
||||||
if idx > 0 {
|
if idx > 0 {
|
||||||
builder.WriteString(" AND ")
|
builder.WriteString(AndWithSpace)
|
||||||
}
|
}
|
||||||
|
|
||||||
if negationBuilder, ok := c.(NegationExpressionBuilder); ok {
|
if negationBuilder, ok := c.(NegationExpressionBuilder); ok {
|
||||||
@ -165,8 +170,8 @@ func (not NotConditions) Build(builder Builder) {
|
|||||||
builder.WriteString("NOT ")
|
builder.WriteString("NOT ")
|
||||||
e, wrapInParentheses := c.(Expr)
|
e, wrapInParentheses := c.(Expr)
|
||||||
if wrapInParentheses {
|
if wrapInParentheses {
|
||||||
sql := strings.ToLower(e.SQL)
|
sql := strings.ToUpper(e.SQL)
|
||||||
if wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or"); wrapInParentheses {
|
if wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace); wrapInParentheses {
|
||||||
builder.WriteByte('(')
|
builder.WriteByte('(')
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -66,6 +66,45 @@ func TestWhere(t *testing.T) {
|
|||||||
"SELECT * FROM `users` WHERE (`age` = ? OR `name` <> ?)",
|
"SELECT * FROM `users` WHERE (`age` = ? OR `name` <> ?)",
|
||||||
[]interface{}{18, "jinzhu"},
|
[]interface{}{18, "jinzhu"},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
|
||||||
|
Exprs: []clause.Expression{clause.Not(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}), clause.And(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false})},
|
||||||
|
}},
|
||||||
|
"SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `age` <= ?) AND `score` <= ?",
|
||||||
|
[]interface{}{"1", 18, 100},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
|
||||||
|
Exprs: []clause.Expression{clause.Not(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}), clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false}},
|
||||||
|
}},
|
||||||
|
"SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `age` <= ?) AND `score` <= ?",
|
||||||
|
[]interface{}{"1", 18, 100},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
|
||||||
|
Exprs: []clause.Expression{clause.Not(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}), clause.Or(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false})},
|
||||||
|
}},
|
||||||
|
"SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `age` <= ?) OR `score` <= ?",
|
||||||
|
[]interface{}{"1", 18, 100},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
|
||||||
|
Exprs: []clause.Expression{
|
||||||
|
clause.And(clause.Not(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}),
|
||||||
|
clause.And(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false})),
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
"SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `score` <= ?)",
|
||||||
|
[]interface{}{"1", 100},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
|
||||||
|
Exprs: []clause.Expression{clause.Not(clause.Eq{Column: clause.PrimaryColumn, Value: "1"},
|
||||||
|
clause.And(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false}))},
|
||||||
|
}},
|
||||||
|
"SELECT * FROM `users` WHERE (`users`.`id` <> ? AND NOT `score` <= ?)",
|
||||||
|
[]interface{}{"1", 100},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for idx, result := range results {
|
for idx, result := range results {
|
||||||
|
@ -39,4 +39,6 @@ var (
|
|||||||
ErrInvalidValue = errors.New("invalid value, should be pointer to struct or slice")
|
ErrInvalidValue = errors.New("invalid value, should be pointer to struct or slice")
|
||||||
// ErrInvalidValueOfLength invalid values do not match length
|
// ErrInvalidValueOfLength invalid values do not match length
|
||||||
ErrInvalidValueOfLength = errors.New("invalid association values, length doesn't match")
|
ErrInvalidValueOfLength = errors.New("invalid association values, length doesn't match")
|
||||||
|
// ErrPreloadNotAllowed preload is not allowed when count is used
|
||||||
|
ErrPreloadNotAllowed = errors.New("preload is not allowed when count is used")
|
||||||
)
|
)
|
||||||
|
@ -74,6 +74,10 @@ func (db *DB) Save(value interface{}) (tx *DB) {
|
|||||||
tx.Statement.Dest = value
|
tx.Statement.Dest = value
|
||||||
|
|
||||||
reflectValue := reflect.Indirect(reflect.ValueOf(value))
|
reflectValue := reflect.Indirect(reflect.ValueOf(value))
|
||||||
|
for reflectValue.Kind() == reflect.Ptr || reflectValue.Kind() == reflect.Interface {
|
||||||
|
reflectValue = reflect.Indirect(reflectValue)
|
||||||
|
}
|
||||||
|
|
||||||
switch reflectValue.Kind() {
|
switch reflectValue.Kind() {
|
||||||
case reflect.Slice, reflect.Array:
|
case reflect.Slice, reflect.Array:
|
||||||
if _, ok := tx.Statement.Clauses["ON CONFLICT"]; !ok {
|
if _, ok := tx.Statement.Clauses["ON CONFLICT"]; !ok {
|
||||||
@ -83,7 +87,7 @@ func (db *DB) Save(value interface{}) (tx *DB) {
|
|||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil {
|
if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil {
|
||||||
for _, pf := range tx.Statement.Schema.PrimaryFields {
|
for _, pf := range tx.Statement.Schema.PrimaryFields {
|
||||||
if _, isZero := pf.ValueOf(reflectValue); isZero {
|
if _, isZero := pf.ValueOf(tx.Statement.Context, reflectValue); isZero {
|
||||||
return tx.callbacks.Create().Execute(tx)
|
return tx.callbacks.Create().Execute(tx)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -101,7 +105,7 @@ func (db *DB) Save(value interface{}) (tx *DB) {
|
|||||||
|
|
||||||
if tx.Error == nil && tx.RowsAffected == 0 && !tx.DryRun && !selectedUpdate {
|
if tx.Error == nil && tx.RowsAffected == 0 && !tx.DryRun && !selectedUpdate {
|
||||||
result := reflect.New(tx.Statement.Schema.ModelType).Interface()
|
result := reflect.New(tx.Statement.Schema.ModelType).Interface()
|
||||||
if err := tx.Session(&Session{}).Take(result).Error; errors.Is(err, ErrRecordNotFound) {
|
if result := tx.Session(&Session{}).Limit(1).Find(result); result.RowsAffected == 0 {
|
||||||
return tx.Create(value)
|
return tx.Create(value)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -177,6 +181,21 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat
|
|||||||
batch int
|
batch int
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// user specified offset or limit
|
||||||
|
var totalSize int
|
||||||
|
if c, ok := tx.Statement.Clauses["LIMIT"]; ok {
|
||||||
|
if limit, ok := c.Expression.(clause.Limit); ok {
|
||||||
|
totalSize = limit.Limit
|
||||||
|
|
||||||
|
if totalSize > 0 && batchSize > totalSize {
|
||||||
|
batchSize = totalSize
|
||||||
|
}
|
||||||
|
|
||||||
|
// reset to offset to 0 in next batch
|
||||||
|
tx = tx.Offset(-1).Session(&Session{})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for {
|
for {
|
||||||
result := queryDB.Limit(batchSize).Find(dest)
|
result := queryDB.Limit(batchSize).Find(dest)
|
||||||
rowsAffected += result.RowsAffected
|
rowsAffected += result.RowsAffected
|
||||||
@ -192,6 +211,15 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if totalSize > 0 {
|
||||||
|
if totalSize <= int(rowsAffected) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if totalSize/batchSize == batch {
|
||||||
|
batchSize = totalSize % batchSize
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Optimize for-break
|
// Optimize for-break
|
||||||
resultsValue := reflect.Indirect(reflect.ValueOf(dest))
|
resultsValue := reflect.Indirect(reflect.ValueOf(dest))
|
||||||
if result.Statement.Schema.PrioritizedPrimaryField == nil {
|
if result.Statement.Schema.PrioritizedPrimaryField == nil {
|
||||||
@ -199,7 +227,7 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
primaryValue, _ := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(resultsValue.Index(resultsValue.Len() - 1))
|
primaryValue, _ := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1))
|
||||||
queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue})
|
queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -207,7 +235,7 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat
|
|||||||
return tx
|
return tx
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tx *DB) assignInterfacesToValue(values ...interface{}) {
|
func (db *DB) assignInterfacesToValue(values ...interface{}) {
|
||||||
for _, value := range values {
|
for _, value := range values {
|
||||||
switch v := value.(type) {
|
switch v := value.(type) {
|
||||||
case []clause.Expression:
|
case []clause.Expression:
|
||||||
@ -215,40 +243,40 @@ func (tx *DB) assignInterfacesToValue(values ...interface{}) {
|
|||||||
if eq, ok := expr.(clause.Eq); ok {
|
if eq, ok := expr.(clause.Eq); ok {
|
||||||
switch column := eq.Column.(type) {
|
switch column := eq.Column.(type) {
|
||||||
case string:
|
case string:
|
||||||
if field := tx.Statement.Schema.LookUpField(column); field != nil {
|
if field := db.Statement.Schema.LookUpField(column); field != nil {
|
||||||
tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value))
|
db.AddError(field.Set(db.Statement.Context, db.Statement.ReflectValue, eq.Value))
|
||||||
}
|
}
|
||||||
case clause.Column:
|
case clause.Column:
|
||||||
if field := tx.Statement.Schema.LookUpField(column.Name); field != nil {
|
if field := db.Statement.Schema.LookUpField(column.Name); field != nil {
|
||||||
tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value))
|
db.AddError(field.Set(db.Statement.Context, db.Statement.ReflectValue, eq.Value))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if andCond, ok := expr.(clause.AndConditions); ok {
|
} else if andCond, ok := expr.(clause.AndConditions); ok {
|
||||||
tx.assignInterfacesToValue(andCond.Exprs)
|
db.assignInterfacesToValue(andCond.Exprs)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
case clause.Expression, map[string]string, map[interface{}]interface{}, map[string]interface{}:
|
case clause.Expression, map[string]string, map[interface{}]interface{}, map[string]interface{}:
|
||||||
if exprs := tx.Statement.BuildCondition(value); len(exprs) > 0 {
|
if exprs := db.Statement.BuildCondition(value); len(exprs) > 0 {
|
||||||
tx.assignInterfacesToValue(exprs)
|
db.assignInterfacesToValue(exprs)
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
if s, err := schema.Parse(value, tx.cacheStore, tx.NamingStrategy); err == nil {
|
if s, err := schema.Parse(value, db.cacheStore, db.NamingStrategy); err == nil {
|
||||||
reflectValue := reflect.Indirect(reflect.ValueOf(value))
|
reflectValue := reflect.Indirect(reflect.ValueOf(value))
|
||||||
switch reflectValue.Kind() {
|
switch reflectValue.Kind() {
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
for _, f := range s.Fields {
|
for _, f := range s.Fields {
|
||||||
if f.Readable {
|
if f.Readable {
|
||||||
if v, isZero := f.ValueOf(reflectValue); !isZero {
|
if v, isZero := f.ValueOf(db.Statement.Context, reflectValue); !isZero {
|
||||||
if field := tx.Statement.Schema.LookUpField(f.Name); field != nil {
|
if field := db.Statement.Schema.LookUpField(f.Name); field != nil {
|
||||||
tx.AddError(field.Set(tx.Statement.ReflectValue, v))
|
db.AddError(field.Set(db.Statement.Context, db.Statement.ReflectValue, v))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if len(values) > 0 {
|
} else if len(values) > 0 {
|
||||||
if exprs := tx.Statement.BuildCondition(values[0], values[1:]...); len(exprs) > 0 {
|
if exprs := db.Statement.BuildCondition(values[0], values[1:]...); len(exprs) > 0 {
|
||||||
tx.assignInterfacesToValue(exprs)
|
db.assignInterfacesToValue(exprs)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -256,12 +284,13 @@ func (tx *DB) assignInterfacesToValue(values ...interface{}) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// FirstOrInit gets the first matched record or initialize a new instance with given conditions (only works with struct or map conditions)
|
||||||
func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) {
|
func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) {
|
||||||
queryTx := db.Limit(1).Order(clause.OrderByColumn{
|
queryTx := db.Limit(1).Order(clause.OrderByColumn{
|
||||||
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
|
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
|
||||||
})
|
})
|
||||||
|
|
||||||
if tx = queryTx.Find(dest, conds...); queryTx.RowsAffected == 0 {
|
if tx = queryTx.Find(dest, conds...); tx.RowsAffected == 0 {
|
||||||
if c, ok := tx.Statement.Clauses["WHERE"]; ok {
|
if c, ok := tx.Statement.Clauses["WHERE"]; ok {
|
||||||
if where, ok := c.Expression.(clause.Where); ok {
|
if where, ok := c.Expression.(clause.Where); ok {
|
||||||
tx.assignInterfacesToValue(where.Exprs)
|
tx.assignInterfacesToValue(where.Exprs)
|
||||||
@ -281,26 +310,28 @@ func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// FirstOrCreate gets the first matched record or create a new one with given conditions (only works with struct, map conditions)
|
||||||
func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) {
|
func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) {
|
||||||
queryTx := db.Limit(1).Order(clause.OrderByColumn{
|
tx = db.getInstance()
|
||||||
|
queryTx := db.Session(&Session{}).Limit(1).Order(clause.OrderByColumn{
|
||||||
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
|
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
|
||||||
})
|
})
|
||||||
if tx = queryTx.Find(dest, conds...); tx.Error == nil {
|
if result := queryTx.Find(dest, conds...); result.Error == nil {
|
||||||
if tx.RowsAffected == 0 {
|
if result.RowsAffected == 0 {
|
||||||
if c, ok := tx.Statement.Clauses["WHERE"]; ok {
|
if c, ok := result.Statement.Clauses["WHERE"]; ok {
|
||||||
if where, ok := c.Expression.(clause.Where); ok {
|
if where, ok := c.Expression.(clause.Where); ok {
|
||||||
tx.assignInterfacesToValue(where.Exprs)
|
result.assignInterfacesToValue(where.Exprs)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// initialize with attrs, conds
|
// initialize with attrs, conds
|
||||||
if len(tx.Statement.attrs) > 0 {
|
if len(db.Statement.attrs) > 0 {
|
||||||
tx.assignInterfacesToValue(tx.Statement.attrs...)
|
result.assignInterfacesToValue(db.Statement.attrs...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// initialize with attrs, conds
|
// initialize with attrs, conds
|
||||||
if len(tx.Statement.assigns) > 0 {
|
if len(db.Statement.assigns) > 0 {
|
||||||
tx.assignInterfacesToValue(tx.Statement.assigns...)
|
result.assignInterfacesToValue(db.Statement.assigns...)
|
||||||
}
|
}
|
||||||
|
|
||||||
return tx.Create(dest)
|
return tx.Create(dest)
|
||||||
@ -320,6 +351,8 @@ func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return tx.Model(dest).Updates(assigns)
|
return tx.Model(dest).Updates(assigns)
|
||||||
|
} else {
|
||||||
|
tx.Error = result.Error
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return tx
|
return tx
|
||||||
@ -585,7 +618,7 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er
|
|||||||
func (db *DB) Begin(opts ...*sql.TxOptions) *DB {
|
func (db *DB) Begin(opts ...*sql.TxOptions) *DB {
|
||||||
var (
|
var (
|
||||||
// clone statement
|
// clone statement
|
||||||
tx = db.getInstance().Session(&Session{Context: db.Statement.Context})
|
tx = db.getInstance().Session(&Session{Context: db.Statement.Context, NewDB: db.clone == 1})
|
||||||
opt *sql.TxOptions
|
opt *sql.TxOptions
|
||||||
err error
|
err error
|
||||||
)
|
)
|
||||||
@ -594,11 +627,12 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB {
|
|||||||
opt = opts[0]
|
opt = opts[0]
|
||||||
}
|
}
|
||||||
|
|
||||||
if beginner, ok := tx.Statement.ConnPool.(TxBeginner); ok {
|
switch beginner := tx.Statement.ConnPool.(type) {
|
||||||
|
case TxBeginner:
|
||||||
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
|
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
|
||||||
} else if beginner, ok := tx.Statement.ConnPool.(ConnPoolBeginner); ok {
|
case ConnPoolBeginner:
|
||||||
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
|
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
|
||||||
} else {
|
default:
|
||||||
err = ErrInvalidTransaction
|
err = ErrInvalidTransaction
|
||||||
}
|
}
|
||||||
|
|
||||||
|
17
gorm.go
17
gorm.go
@ -59,6 +59,7 @@ type Config struct {
|
|||||||
cacheStore *sync.Map
|
cacheStore *sync.Map
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Apply update config to new config
|
||||||
func (c *Config) Apply(config *Config) error {
|
func (c *Config) Apply(config *Config) error {
|
||||||
if config != c {
|
if config != c {
|
||||||
*config = *c
|
*config = *c
|
||||||
@ -66,6 +67,7 @@ func (c *Config) Apply(config *Config) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AfterInitialize initialize plugins after db connected
|
||||||
func (c *Config) AfterInitialize(db *DB) error {
|
func (c *Config) AfterInitialize(db *DB) error {
|
||||||
if db != nil {
|
if db != nil {
|
||||||
for _, plugin := range c.Plugins {
|
for _, plugin := range c.Plugins {
|
||||||
@ -77,6 +79,7 @@ func (c *Config) AfterInitialize(db *DB) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Option gorm option interface
|
||||||
type Option interface {
|
type Option interface {
|
||||||
Apply(*Config) error
|
Apply(*Config) error
|
||||||
AfterInitialize(*DB) error
|
AfterInitialize(*DB) error
|
||||||
@ -96,6 +99,7 @@ type Session struct {
|
|||||||
DryRun bool
|
DryRun bool
|
||||||
PrepareStmt bool
|
PrepareStmt bool
|
||||||
NewDB bool
|
NewDB bool
|
||||||
|
Initialized bool
|
||||||
SkipHooks bool
|
SkipHooks bool
|
||||||
SkipDefaultTransaction bool
|
SkipDefaultTransaction bool
|
||||||
DisableNestedTransaction bool
|
DisableNestedTransaction bool
|
||||||
@ -120,8 +124,8 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) {
|
|||||||
|
|
||||||
for _, opt := range opts {
|
for _, opt := range opts {
|
||||||
if opt != nil {
|
if opt != nil {
|
||||||
if err := opt.Apply(config); err != nil {
|
if applyErr := opt.Apply(config); applyErr != nil {
|
||||||
return nil, err
|
return nil, applyErr
|
||||||
}
|
}
|
||||||
defer func(opt Option) {
|
defer func(opt Option) {
|
||||||
if errr := opt.AfterInitialize(db); errr != nil {
|
if errr := opt.AfterInitialize(db); errr != nil {
|
||||||
@ -282,6 +286,10 @@ func (db *DB) Session(config *Session) *DB {
|
|||||||
tx.Config.NowFunc = config.NowFunc
|
tx.Config.NowFunc = config.NowFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if config.Initialized {
|
||||||
|
tx = tx.getInstance()
|
||||||
|
}
|
||||||
|
|
||||||
return tx
|
return tx
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -376,10 +384,12 @@ func (db *DB) getInstance() *DB {
|
|||||||
return db
|
return db
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Expr returns clause.Expr, which can be used to pass SQL expression as params
|
||||||
func Expr(expr string, args ...interface{}) clause.Expr {
|
func Expr(expr string, args ...interface{}) clause.Expr {
|
||||||
return clause.Expr{SQL: expr, Vars: args}
|
return clause.Expr{SQL: expr, Vars: args}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetupJoinTable setup join table schema
|
||||||
func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interface{}) error {
|
func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interface{}) error {
|
||||||
var (
|
var (
|
||||||
tx = db.getInstance()
|
tx = db.getInstance()
|
||||||
@ -430,6 +440,7 @@ func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interfac
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Use use plugin
|
||||||
func (db *DB) Use(plugin Plugin) error {
|
func (db *DB) Use(plugin Plugin) error {
|
||||||
name := plugin.Name()
|
name := plugin.Name()
|
||||||
if _, ok := db.Plugins[name]; ok {
|
if _, ok := db.Plugins[name]; ok {
|
||||||
@ -451,7 +462,7 @@ func (db *DB) Use(plugin Plugin) error {
|
|||||||
// .First(&User{})
|
// .First(&User{})
|
||||||
// })
|
// })
|
||||||
func (db *DB) ToSQL(queryFn func(tx *DB) *DB) string {
|
func (db *DB) ToSQL(queryFn func(tx *DB) *DB) string {
|
||||||
tx := queryFn(db.Session(&Session{DryRun: true}))
|
tx := queryFn(db.Session(&Session{DryRun: true, SkipDefaultTransaction: true}))
|
||||||
stmt := tx.Statement
|
stmt := tx.Statement
|
||||||
|
|
||||||
return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...)
|
return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...)
|
||||||
|
@ -40,24 +40,45 @@ type SavePointerDialectorInterface interface {
|
|||||||
RollbackTo(tx *DB, name string) error
|
RollbackTo(tx *DB, name string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TxBeginner tx beginner
|
||||||
type TxBeginner interface {
|
type TxBeginner interface {
|
||||||
BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
|
BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ConnPoolBeginner conn pool beginner
|
||||||
type ConnPoolBeginner interface {
|
type ConnPoolBeginner interface {
|
||||||
BeginTx(ctx context.Context, opts *sql.TxOptions) (ConnPool, error)
|
BeginTx(ctx context.Context, opts *sql.TxOptions) (ConnPool, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TxCommitter tx committer
|
||||||
type TxCommitter interface {
|
type TxCommitter interface {
|
||||||
Commit() error
|
Commit() error
|
||||||
Rollback() error
|
Rollback() error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Tx sql.Tx interface
|
||||||
|
type Tx interface {
|
||||||
|
ConnPool
|
||||||
|
TxCommitter
|
||||||
|
StmtContext(ctx context.Context, stmt *sql.Stmt) *sql.Stmt
|
||||||
|
}
|
||||||
|
|
||||||
// Valuer gorm valuer interface
|
// Valuer gorm valuer interface
|
||||||
type Valuer interface {
|
type Valuer interface {
|
||||||
GormValue(context.Context, *DB) clause.Expr
|
GormValue(context.Context, *DB) clause.Expr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetDBConnector SQL db connector
|
||||||
type GetDBConnector interface {
|
type GetDBConnector interface {
|
||||||
GetDBConn() (*sql.DB, error)
|
GetDBConn() (*sql.DB, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Rows rows interface
|
||||||
|
type Rows interface {
|
||||||
|
Columns() ([]string, error)
|
||||||
|
ColumnTypes() ([]*sql.ColumnType, error)
|
||||||
|
Next() bool
|
||||||
|
Scan(dest ...interface{}) error
|
||||||
|
Err() error
|
||||||
|
Close() error
|
||||||
|
}
|
||||||
|
@ -12,6 +12,7 @@ import (
|
|||||||
"gorm.io/gorm/utils"
|
"gorm.io/gorm/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// ErrRecordNotFound record not found error
|
||||||
var ErrRecordNotFound = errors.New("record not found")
|
var ErrRecordNotFound = errors.New("record not found")
|
||||||
|
|
||||||
// Colors
|
// Colors
|
||||||
@ -30,13 +31,17 @@ const (
|
|||||||
YellowBold = "\033[33;1m"
|
YellowBold = "\033[33;1m"
|
||||||
)
|
)
|
||||||
|
|
||||||
// LogLevel
|
// LogLevel log level
|
||||||
type LogLevel int
|
type LogLevel int
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
// Silent silent log level
|
||||||
Silent LogLevel = iota + 1
|
Silent LogLevel = iota + 1
|
||||||
|
// Error error log level
|
||||||
Error
|
Error
|
||||||
|
// Warn warn log level
|
||||||
Warn
|
Warn
|
||||||
|
// Info info log level
|
||||||
Info
|
Info
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -45,6 +50,7 @@ type Writer interface {
|
|||||||
Printf(string, ...interface{})
|
Printf(string, ...interface{})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Config logger config
|
||||||
type Config struct {
|
type Config struct {
|
||||||
SlowThreshold time.Duration
|
SlowThreshold time.Duration
|
||||||
Colorful bool
|
Colorful bool
|
||||||
@ -62,16 +68,20 @@ type Interface interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
// Discard Discard logger will print any log to ioutil.Discard
|
||||||
Discard = New(log.New(ioutil.Discard, "", log.LstdFlags), Config{})
|
Discard = New(log.New(ioutil.Discard, "", log.LstdFlags), Config{})
|
||||||
|
// Default Default logger
|
||||||
Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{
|
Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{
|
||||||
SlowThreshold: 200 * time.Millisecond,
|
SlowThreshold: 200 * time.Millisecond,
|
||||||
LogLevel: Warn,
|
LogLevel: Warn,
|
||||||
IgnoreRecordNotFoundError: false,
|
IgnoreRecordNotFoundError: false,
|
||||||
Colorful: true,
|
Colorful: true,
|
||||||
})
|
})
|
||||||
|
// Recorder Recorder logger records running SQL into a recorder instance
|
||||||
Recorder = traceRecorder{Interface: Default, BeginAt: time.Now()}
|
Recorder = traceRecorder{Interface: Default, BeginAt: time.Now()}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// New initialize logger
|
||||||
func New(writer Writer, config Config) Interface {
|
func New(writer Writer, config Config) Interface {
|
||||||
var (
|
var (
|
||||||
infoStr = "%s\n[info] "
|
infoStr = "%s\n[info] "
|
||||||
@ -179,10 +189,12 @@ type traceRecorder struct {
|
|||||||
Err error
|
Err error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// New new trace recorder
|
||||||
func (l traceRecorder) New() *traceRecorder {
|
func (l traceRecorder) New() *traceRecorder {
|
||||||
return &traceRecorder{Interface: l.Interface, BeginAt: time.Now()}
|
return &traceRecorder{Interface: l.Interface, BeginAt: time.Now()}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Trace implement logger interface
|
||||||
func (l *traceRecorder) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) {
|
func (l *traceRecorder) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) {
|
||||||
l.BeginAt = begin
|
l.BeginAt = begin
|
||||||
l.SQL, l.RowsAffected = fc()
|
l.SQL, l.RowsAffected = fc()
|
||||||
|
@ -19,9 +19,9 @@ const (
|
|||||||
nullStr = "NULL"
|
nullStr = "NULL"
|
||||||
)
|
)
|
||||||
|
|
||||||
func isPrintable(s []byte) bool {
|
func isPrintable(s string) bool {
|
||||||
for _, r := range s {
|
for _, r := range s {
|
||||||
if !unicode.IsPrint(rune(r)) {
|
if !unicode.IsPrint(r) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -30,9 +30,12 @@ func isPrintable(s []byte) bool {
|
|||||||
|
|
||||||
var convertibleTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeOf(false), reflect.TypeOf([]byte{})}
|
var convertibleTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeOf(false), reflect.TypeOf([]byte{})}
|
||||||
|
|
||||||
|
// ExplainSQL generate SQL string with given parameters, the generated SQL is expected to be used in logger, execute it might introduce a SQL injection vulnerability
|
||||||
func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, avars ...interface{}) string {
|
func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, avars ...interface{}) string {
|
||||||
var convertParams func(interface{}, int)
|
var (
|
||||||
vars := make([]string, len(avars))
|
convertParams func(interface{}, int)
|
||||||
|
vars = make([]string, len(avars))
|
||||||
|
)
|
||||||
|
|
||||||
convertParams = func(v interface{}, idx int) {
|
convertParams = func(v interface{}, idx int) {
|
||||||
switch v := v.(type) {
|
switch v := v.(type) {
|
||||||
@ -64,14 +67,25 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a
|
|||||||
}
|
}
|
||||||
case fmt.Stringer:
|
case fmt.Stringer:
|
||||||
reflectValue := reflect.ValueOf(v)
|
reflectValue := reflect.ValueOf(v)
|
||||||
|
switch reflectValue.Kind() {
|
||||||
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||||
|
vars[idx] = fmt.Sprintf("%d", reflectValue.Interface())
|
||||||
|
case reflect.Float32, reflect.Float64:
|
||||||
|
vars[idx] = fmt.Sprintf("%.6f", reflectValue.Interface())
|
||||||
|
case reflect.Bool:
|
||||||
|
vars[idx] = fmt.Sprintf("%t", reflectValue.Interface())
|
||||||
|
case reflect.String:
|
||||||
|
vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, "\\"+escaper) + escaper
|
||||||
|
default:
|
||||||
if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) {
|
if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) {
|
||||||
vars[idx] = escaper + strings.Replace(fmt.Sprintf("%v", v), escaper, "\\"+escaper, -1) + escaper
|
vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, "\\"+escaper) + escaper
|
||||||
} else {
|
} else {
|
||||||
vars[idx] = nullStr
|
vars[idx] = nullStr
|
||||||
}
|
}
|
||||||
|
}
|
||||||
case []byte:
|
case []byte:
|
||||||
if isPrintable(v) {
|
if s := string(v); isPrintable(s) {
|
||||||
vars[idx] = escaper + strings.Replace(string(v), escaper, "\\"+escaper, -1) + escaper
|
vars[idx] = escaper + strings.ReplaceAll(s, escaper, "\\"+escaper) + escaper
|
||||||
} else {
|
} else {
|
||||||
vars[idx] = escaper + "<binary>" + escaper
|
vars[idx] = escaper + "<binary>" + escaper
|
||||||
}
|
}
|
||||||
@ -80,7 +94,7 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a
|
|||||||
case float64, float32:
|
case float64, float32:
|
||||||
vars[idx] = fmt.Sprintf("%.6f", v)
|
vars[idx] = fmt.Sprintf("%.6f", v)
|
||||||
case string:
|
case string:
|
||||||
vars[idx] = escaper + strings.Replace(v, escaper, "\\"+escaper, -1) + escaper
|
vars[idx] = escaper + strings.ReplaceAll(v, escaper, "\\"+escaper) + escaper
|
||||||
default:
|
default:
|
||||||
rv := reflect.ValueOf(v)
|
rv := reflect.ValueOf(v)
|
||||||
if v == nil || !rv.IsValid() || rv.Kind() == reflect.Ptr && rv.IsNil() {
|
if v == nil || !rv.IsValid() || rv.Kind() == reflect.Ptr && rv.IsNil() {
|
||||||
@ -97,7 +111,7 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
vars[idx] = escaper + strings.Replace(fmt.Sprint(v), escaper, "\\"+escaper, -1) + escaper
|
vars[idx] = escaper + strings.ReplaceAll(fmt.Sprint(v), escaper, "\\"+escaper) + escaper
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -31,7 +31,7 @@ func (s ExampleStruct) Value() (driver.Value, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func format(v []byte, escaper string) string {
|
func format(v []byte, escaper string) string {
|
||||||
return escaper + strings.Replace(string(v), escaper, "\\"+escaper, -1) + escaper
|
return escaper + strings.ReplaceAll(string(v), escaper, "\\"+escaper) + escaper
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestExplainSQL(t *testing.T) {
|
func TestExplainSQL(t *testing.T) {
|
||||||
|
13
migrator.go
13
migrator.go
@ -1,6 +1,8 @@
|
|||||||
package gorm
|
package gorm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"reflect"
|
||||||
|
|
||||||
"gorm.io/gorm/clause"
|
"gorm.io/gorm/clause"
|
||||||
"gorm.io/gorm/schema"
|
"gorm.io/gorm/schema"
|
||||||
)
|
)
|
||||||
@ -33,14 +35,23 @@ type ViewOption struct {
|
|||||||
Query *DB
|
Query *DB
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ColumnType column type interface
|
||||||
type ColumnType interface {
|
type ColumnType interface {
|
||||||
Name() string
|
Name() string
|
||||||
DatabaseTypeName() string
|
DatabaseTypeName() string // varchar
|
||||||
|
ColumnType() (columnType string, ok bool) // varchar(64)
|
||||||
|
PrimaryKey() (isPrimaryKey bool, ok bool)
|
||||||
|
AutoIncrement() (isAutoIncrement bool, ok bool)
|
||||||
Length() (length int64, ok bool)
|
Length() (length int64, ok bool)
|
||||||
DecimalSize() (precision int64, scale int64, ok bool)
|
DecimalSize() (precision int64, scale int64, ok bool)
|
||||||
Nullable() (nullable bool, ok bool)
|
Nullable() (nullable bool, ok bool)
|
||||||
|
Unique() (unique bool, ok bool)
|
||||||
|
ScanType() reflect.Type
|
||||||
|
Comment() (value string, ok bool)
|
||||||
|
DefaultValue() (value string, ok bool)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Migrator migrator interface
|
||||||
type Migrator interface {
|
type Migrator interface {
|
||||||
// AutoMigrate
|
// AutoMigrate
|
||||||
AutoMigrate(dst ...interface{}) error
|
AutoMigrate(dst ...interface{}) error
|
||||||
|
107
migrator/column_type.go
Normal file
107
migrator/column_type.go
Normal file
@ -0,0 +1,107 @@
|
|||||||
|
package migrator
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"reflect"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ColumnType column type implements ColumnType interface
|
||||||
|
type ColumnType struct {
|
||||||
|
SQLColumnType *sql.ColumnType
|
||||||
|
NameValue sql.NullString
|
||||||
|
DataTypeValue sql.NullString
|
||||||
|
ColumnTypeValue sql.NullString
|
||||||
|
PrimaryKeyValue sql.NullBool
|
||||||
|
UniqueValue sql.NullBool
|
||||||
|
AutoIncrementValue sql.NullBool
|
||||||
|
LengthValue sql.NullInt64
|
||||||
|
DecimalSizeValue sql.NullInt64
|
||||||
|
ScaleValue sql.NullInt64
|
||||||
|
NullableValue sql.NullBool
|
||||||
|
ScanTypeValue reflect.Type
|
||||||
|
CommentValue sql.NullString
|
||||||
|
DefaultValueValue sql.NullString
|
||||||
|
}
|
||||||
|
|
||||||
|
// Name returns the name or alias of the column.
|
||||||
|
func (ct ColumnType) Name() string {
|
||||||
|
if ct.NameValue.Valid {
|
||||||
|
return ct.NameValue.String
|
||||||
|
}
|
||||||
|
return ct.SQLColumnType.Name()
|
||||||
|
}
|
||||||
|
|
||||||
|
// DatabaseTypeName returns the database system name of the column type. If an empty
|
||||||
|
// string is returned, then the driver type name is not supported.
|
||||||
|
// Consult your driver documentation for a list of driver data types. Length specifiers
|
||||||
|
// are not included.
|
||||||
|
// Common type names include "VARCHAR", "TEXT", "NVARCHAR", "DECIMAL", "BOOL",
|
||||||
|
// "INT", and "BIGINT".
|
||||||
|
func (ct ColumnType) DatabaseTypeName() string {
|
||||||
|
if ct.DataTypeValue.Valid {
|
||||||
|
return ct.DataTypeValue.String
|
||||||
|
}
|
||||||
|
return ct.SQLColumnType.DatabaseTypeName()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ColumnType returns the database type of the column. like `varchar(16)`
|
||||||
|
func (ct ColumnType) ColumnType() (columnType string, ok bool) {
|
||||||
|
return ct.ColumnTypeValue.String, ct.ColumnTypeValue.Valid
|
||||||
|
}
|
||||||
|
|
||||||
|
// PrimaryKey returns the column is primary key or not.
|
||||||
|
func (ct ColumnType) PrimaryKey() (isPrimaryKey bool, ok bool) {
|
||||||
|
return ct.PrimaryKeyValue.Bool, ct.PrimaryKeyValue.Valid
|
||||||
|
}
|
||||||
|
|
||||||
|
// AutoIncrement returns the column is auto increment or not.
|
||||||
|
func (ct ColumnType) AutoIncrement() (isAutoIncrement bool, ok bool) {
|
||||||
|
return ct.AutoIncrementValue.Bool, ct.AutoIncrementValue.Valid
|
||||||
|
}
|
||||||
|
|
||||||
|
// Length returns the column type length for variable length column types
|
||||||
|
func (ct ColumnType) Length() (length int64, ok bool) {
|
||||||
|
if ct.LengthValue.Valid {
|
||||||
|
return ct.LengthValue.Int64, true
|
||||||
|
}
|
||||||
|
return ct.SQLColumnType.Length()
|
||||||
|
}
|
||||||
|
|
||||||
|
// DecimalSize returns the scale and precision of a decimal type.
|
||||||
|
func (ct ColumnType) DecimalSize() (precision int64, scale int64, ok bool) {
|
||||||
|
if ct.DecimalSizeValue.Valid {
|
||||||
|
return ct.DecimalSizeValue.Int64, ct.ScaleValue.Int64, true
|
||||||
|
}
|
||||||
|
return ct.SQLColumnType.DecimalSize()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Nullable reports whether the column may be null.
|
||||||
|
func (ct ColumnType) Nullable() (nullable bool, ok bool) {
|
||||||
|
if ct.NullableValue.Valid {
|
||||||
|
return ct.NullableValue.Bool, true
|
||||||
|
}
|
||||||
|
return ct.SQLColumnType.Nullable()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unique reports whether the column may be unique.
|
||||||
|
func (ct ColumnType) Unique() (unique bool, ok bool) {
|
||||||
|
return ct.UniqueValue.Bool, ct.UniqueValue.Valid
|
||||||
|
}
|
||||||
|
|
||||||
|
// ScanType returns a Go type suitable for scanning into using Rows.Scan.
|
||||||
|
func (ct ColumnType) ScanType() reflect.Type {
|
||||||
|
if ct.ScanTypeValue != nil {
|
||||||
|
return ct.ScanTypeValue
|
||||||
|
}
|
||||||
|
return ct.SQLColumnType.ScanType()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Comment returns the comment of current column.
|
||||||
|
func (ct ColumnType) Comment() (value string, ok bool) {
|
||||||
|
return ct.CommentValue.String, ct.CommentValue.Valid
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultValue returns the default value of current column.
|
||||||
|
func (ct ColumnType) DefaultValue() (value string, ok bool) {
|
||||||
|
return ct.DefaultValueValue.String, ct.DefaultValueValue.Valid
|
||||||
|
}
|
@ -30,10 +30,12 @@ type Config struct {
|
|||||||
gorm.Dialector
|
gorm.Dialector
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GormDataTypeInterface gorm data type interface
|
||||||
type GormDataTypeInterface interface {
|
type GormDataTypeInterface interface {
|
||||||
GormDBDataType(*gorm.DB, *schema.Field) string
|
GormDBDataType(*gorm.DB, *schema.Field) string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RunWithValue run migration with statement value
|
||||||
func (m Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error) error {
|
func (m Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error) error {
|
||||||
stmt := &gorm.Statement{DB: m.DB}
|
stmt := &gorm.Statement{DB: m.DB}
|
||||||
if m.DB.Statement != nil {
|
if m.DB.Statement != nil {
|
||||||
@ -50,6 +52,7 @@ func (m Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error
|
|||||||
return fc(stmt)
|
return fc(stmt)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DataTypeOf return field's db data type
|
||||||
func (m Migrator) DataTypeOf(field *schema.Field) string {
|
func (m Migrator) DataTypeOf(field *schema.Field) string {
|
||||||
fieldValue := reflect.New(field.IndirectFieldType)
|
fieldValue := reflect.New(field.IndirectFieldType)
|
||||||
if dataTyper, ok := fieldValue.Interface().(GormDataTypeInterface); ok {
|
if dataTyper, ok := fieldValue.Interface().(GormDataTypeInterface); ok {
|
||||||
@ -61,6 +64,7 @@ func (m Migrator) DataTypeOf(field *schema.Field) string {
|
|||||||
return m.Dialector.DataTypeOf(field)
|
return m.Dialector.DataTypeOf(field)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// FullDataTypeOf returns field's db full data type
|
||||||
func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) {
|
func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) {
|
||||||
expr.SQL = m.DataTypeOf(field)
|
expr.SQL = m.DataTypeOf(field)
|
||||||
|
|
||||||
@ -85,7 +89,7 @@ func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// AutoMigrate
|
// AutoMigrate auto migrate values
|
||||||
func (m Migrator) AutoMigrate(values ...interface{}) error {
|
func (m Migrator) AutoMigrate(values ...interface{}) error {
|
||||||
for _, value := range m.ReorderModels(values, true) {
|
for _, value := range m.ReorderModels(values, true) {
|
||||||
tx := m.DB.Session(&gorm.Session{})
|
tx := m.DB.Session(&gorm.Session{})
|
||||||
@ -95,7 +99,10 @@ func (m Migrator) AutoMigrate(values ...interface{}) error {
|
|||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) {
|
if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) {
|
||||||
columnTypes, _ := m.DB.Migrator().ColumnTypes(value)
|
columnTypes, err := m.DB.Migrator().ColumnTypes(value)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
for _, dbName := range stmt.Schema.DBNames {
|
for _, dbName := range stmt.Schema.DBNames {
|
||||||
field := stmt.Schema.FieldsByDBName[dbName]
|
field := stmt.Schema.FieldsByDBName[dbName]
|
||||||
@ -156,12 +163,14 @@ func (m Migrator) AutoMigrate(values ...interface{}) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetTables returns tables
|
||||||
func (m Migrator) GetTables() (tableList []string, err error) {
|
func (m Migrator) GetTables() (tableList []string, err error) {
|
||||||
err = m.DB.Raw("SELECT TABLE_NAME FROM information_schema.tables where TABLE_SCHEMA=?", m.CurrentDatabase()).
|
err = m.DB.Raw("SELECT TABLE_NAME FROM information_schema.tables where TABLE_SCHEMA=?", m.CurrentDatabase()).
|
||||||
Scan(&tableList).Error
|
Scan(&tableList).Error
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CreateTable create table in database for values
|
||||||
func (m Migrator) CreateTable(values ...interface{}) error {
|
func (m Migrator) CreateTable(values ...interface{}) error {
|
||||||
for _, value := range m.ReorderModels(values, false) {
|
for _, value := range m.ReorderModels(values, false) {
|
||||||
tx := m.DB.Session(&gorm.Session{})
|
tx := m.DB.Session(&gorm.Session{})
|
||||||
@ -252,6 +261,7 @@ func (m Migrator) CreateTable(values ...interface{}) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DropTable drop table for values
|
||||||
func (m Migrator) DropTable(values ...interface{}) error {
|
func (m Migrator) DropTable(values ...interface{}) error {
|
||||||
values = m.ReorderModels(values, false)
|
values = m.ReorderModels(values, false)
|
||||||
for i := len(values) - 1; i >= 0; i-- {
|
for i := len(values) - 1; i >= 0; i-- {
|
||||||
@ -265,6 +275,7 @@ func (m Migrator) DropTable(values ...interface{}) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// HasTable returns table exists or not for value, value could be a struct or string
|
||||||
func (m Migrator) HasTable(value interface{}) bool {
|
func (m Migrator) HasTable(value interface{}) bool {
|
||||||
var count int64
|
var count int64
|
||||||
|
|
||||||
@ -276,6 +287,7 @@ func (m Migrator) HasTable(value interface{}) bool {
|
|||||||
return count > 0
|
return count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RenameTable rename table from oldName to newName
|
||||||
func (m Migrator) RenameTable(oldName, newName interface{}) error {
|
func (m Migrator) RenameTable(oldName, newName interface{}) error {
|
||||||
var oldTable, newTable interface{}
|
var oldTable, newTable interface{}
|
||||||
if v, ok := oldName.(string); ok {
|
if v, ok := oldName.(string); ok {
|
||||||
@ -303,12 +315,13 @@ func (m Migrator) RenameTable(oldName, newName interface{}) error {
|
|||||||
return m.DB.Exec("ALTER TABLE ? RENAME TO ?", oldTable, newTable).Error
|
return m.DB.Exec("ALTER TABLE ? RENAME TO ?", oldTable, newTable).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m Migrator) AddColumn(value interface{}, field string) error {
|
// AddColumn create `name` column for value
|
||||||
|
func (m Migrator) AddColumn(value interface{}, name string) error {
|
||||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||||
// avoid using the same name field
|
// avoid using the same name field
|
||||||
f := stmt.Schema.LookUpField(field)
|
f := stmt.Schema.LookUpField(name)
|
||||||
if f == nil {
|
if f == nil {
|
||||||
return fmt.Errorf("failed to look up field with name: %s", field)
|
return fmt.Errorf("failed to look up field with name: %s", name)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !f.IgnoreMigration {
|
if !f.IgnoreMigration {
|
||||||
@ -322,6 +335,7 @@ func (m Migrator) AddColumn(value interface{}, field string) error {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DropColumn drop value's `name` column
|
||||||
func (m Migrator) DropColumn(value interface{}, name string) error {
|
func (m Migrator) DropColumn(value interface{}, name string) error {
|
||||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||||
if field := stmt.Schema.LookUpField(name); field != nil {
|
if field := stmt.Schema.LookUpField(name); field != nil {
|
||||||
@ -334,10 +348,11 @@ func (m Migrator) DropColumn(value interface{}, name string) error {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AlterColumn alter value's `field` column' type based on schema definition
|
||||||
func (m Migrator) AlterColumn(value interface{}, field string) error {
|
func (m Migrator) AlterColumn(value interface{}, field string) error {
|
||||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||||
if field := stmt.Schema.LookUpField(field); field != nil {
|
if field := stmt.Schema.LookUpField(field); field != nil {
|
||||||
fileType := clause.Expr{SQL: m.DataTypeOf(field)}
|
fileType := m.FullDataTypeOf(field)
|
||||||
return m.DB.Exec(
|
return m.DB.Exec(
|
||||||
"ALTER TABLE ? ALTER COLUMN ? TYPE ?",
|
"ALTER TABLE ? ALTER COLUMN ? TYPE ?",
|
||||||
m.CurrentTable(stmt), clause.Column{Name: field.DBName}, fileType,
|
m.CurrentTable(stmt), clause.Column{Name: field.DBName}, fileType,
|
||||||
@ -348,6 +363,7 @@ func (m Migrator) AlterColumn(value interface{}, field string) error {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// HasColumn check has column `field` for value or not
|
||||||
func (m Migrator) HasColumn(value interface{}, field string) bool {
|
func (m Migrator) HasColumn(value interface{}, field string) bool {
|
||||||
var count int64
|
var count int64
|
||||||
m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||||
@ -366,6 +382,7 @@ func (m Migrator) HasColumn(value interface{}, field string) bool {
|
|||||||
return count > 0
|
return count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RenameColumn rename value's field name from oldName to newName
|
||||||
func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error {
|
func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error {
|
||||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||||
if field := stmt.Schema.LookUpField(oldName); field != nil {
|
if field := stmt.Schema.LookUpField(oldName); field != nil {
|
||||||
@ -383,6 +400,7 @@ func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MigrateColumn migrate column
|
||||||
func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnType gorm.ColumnType) error {
|
func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnType gorm.ColumnType) error {
|
||||||
// found, smart migrate
|
// found, smart migrate
|
||||||
fullDataType := strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL)
|
fullDataType := strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL)
|
||||||
@ -421,6 +439,30 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// check unique
|
||||||
|
if unique, ok := columnType.Unique(); ok && unique != field.Unique {
|
||||||
|
// not primary key
|
||||||
|
if !field.PrimaryKey {
|
||||||
|
alterColumn = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// check default value
|
||||||
|
if v, ok := columnType.DefaultValue(); ok && v != field.DefaultValue {
|
||||||
|
// not primary key
|
||||||
|
if !field.PrimaryKey {
|
||||||
|
alterColumn = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// check comment
|
||||||
|
if comment, ok := columnType.Comment(); ok && comment != field.Comment {
|
||||||
|
// not primary key
|
||||||
|
if !field.PrimaryKey {
|
||||||
|
alterColumn = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if alterColumn && !field.IgnoreMigration {
|
if alterColumn && !field.IgnoreMigration {
|
||||||
return m.DB.Migrator().AlterColumn(value, field.Name)
|
return m.DB.Migrator().AlterColumn(value, field.Name)
|
||||||
}
|
}
|
||||||
@ -448,7 +490,7 @@ func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, c := range rawColumnTypes {
|
for _, c := range rawColumnTypes {
|
||||||
columnTypes = append(columnTypes, c)
|
columnTypes = append(columnTypes, ColumnType{SQLColumnType: c})
|
||||||
}
|
}
|
||||||
|
|
||||||
return
|
return
|
||||||
@ -457,10 +499,12 @@ func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
|
|||||||
return columnTypes, execErr
|
return columnTypes, execErr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CreateView create view
|
||||||
func (m Migrator) CreateView(name string, option gorm.ViewOption) error {
|
func (m Migrator) CreateView(name string, option gorm.ViewOption) error {
|
||||||
return gorm.ErrNotImplemented
|
return gorm.ErrNotImplemented
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DropView drop view
|
||||||
func (m Migrator) DropView(name string) error {
|
func (m Migrator) DropView(name string) error {
|
||||||
return gorm.ErrNotImplemented
|
return gorm.ErrNotImplemented
|
||||||
}
|
}
|
||||||
@ -487,6 +531,7 @@ func buildConstraint(constraint *schema.Constraint) (sql string, results []inter
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GuessConstraintAndTable guess statement's constraint and it's table based on name
|
||||||
func (m Migrator) GuessConstraintAndTable(stmt *gorm.Statement, name string) (_ *schema.Constraint, _ *schema.Check, table string) {
|
func (m Migrator) GuessConstraintAndTable(stmt *gorm.Statement, name string) (_ *schema.Constraint, _ *schema.Check, table string) {
|
||||||
if stmt.Schema == nil {
|
if stmt.Schema == nil {
|
||||||
return nil, nil, stmt.Table
|
return nil, nil, stmt.Table
|
||||||
@ -531,6 +576,7 @@ func (m Migrator) GuessConstraintAndTable(stmt *gorm.Statement, name string) (_
|
|||||||
return nil, nil, stmt.Schema.Table
|
return nil, nil, stmt.Schema.Table
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CreateConstraint create constraint
|
||||||
func (m Migrator) CreateConstraint(value interface{}, name string) error {
|
func (m Migrator) CreateConstraint(value interface{}, name string) error {
|
||||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||||
constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
|
constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
|
||||||
@ -554,6 +600,7 @@ func (m Migrator) CreateConstraint(value interface{}, name string) error {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DropConstraint drop constraint
|
||||||
func (m Migrator) DropConstraint(value interface{}, name string) error {
|
func (m Migrator) DropConstraint(value interface{}, name string) error {
|
||||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||||
constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
|
constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
|
||||||
@ -566,6 +613,7 @@ func (m Migrator) DropConstraint(value interface{}, name string) error {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// HasConstraint check has constraint or not
|
||||||
func (m Migrator) HasConstraint(value interface{}, name string) bool {
|
func (m Migrator) HasConstraint(value interface{}, name string) bool {
|
||||||
var count int64
|
var count int64
|
||||||
m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||||
@ -586,6 +634,7 @@ func (m Migrator) HasConstraint(value interface{}, name string) bool {
|
|||||||
return count > 0
|
return count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// BuildIndexOptions build index options
|
||||||
func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) {
|
func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) {
|
||||||
for _, opt := range opts {
|
for _, opt := range opts {
|
||||||
str := stmt.Quote(opt.DBName)
|
str := stmt.Quote(opt.DBName)
|
||||||
@ -607,10 +656,12 @@ func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statem
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// BuildIndexOptionsInterface build index options interface
|
||||||
type BuildIndexOptionsInterface interface {
|
type BuildIndexOptionsInterface interface {
|
||||||
BuildIndexOptions([]schema.IndexOption, *gorm.Statement) []interface{}
|
BuildIndexOptions([]schema.IndexOption, *gorm.Statement) []interface{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CreateIndex create index `name`
|
||||||
func (m Migrator) CreateIndex(value interface{}, name string) error {
|
func (m Migrator) CreateIndex(value interface{}, name string) error {
|
||||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||||
if idx := stmt.Schema.LookIndex(name); idx != nil {
|
if idx := stmt.Schema.LookIndex(name); idx != nil {
|
||||||
@ -642,6 +693,7 @@ func (m Migrator) CreateIndex(value interface{}, name string) error {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DropIndex drop index `name`
|
||||||
func (m Migrator) DropIndex(value interface{}, name string) error {
|
func (m Migrator) DropIndex(value interface{}, name string) error {
|
||||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||||
if idx := stmt.Schema.LookIndex(name); idx != nil {
|
if idx := stmt.Schema.LookIndex(name); idx != nil {
|
||||||
@ -652,6 +704,7 @@ func (m Migrator) DropIndex(value interface{}, name string) error {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// HasIndex check has index `name` or not
|
||||||
func (m Migrator) HasIndex(value interface{}, name string) bool {
|
func (m Migrator) HasIndex(value interface{}, name string) bool {
|
||||||
var count int64
|
var count int64
|
||||||
m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||||
@ -669,6 +722,7 @@ func (m Migrator) HasIndex(value interface{}, name string) bool {
|
|||||||
return count > 0
|
return count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RenameIndex rename index from oldName to newName
|
||||||
func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error {
|
func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error {
|
||||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||||
return m.DB.Exec(
|
return m.DB.Exec(
|
||||||
@ -678,6 +732,7 @@ func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CurrentDatabase returns current database name
|
||||||
func (m Migrator) CurrentDatabase() (name string) {
|
func (m Migrator) CurrentDatabase() (name string) {
|
||||||
m.DB.Raw("SELECT DATABASE()").Row().Scan(&name)
|
m.DB.Raw("SELECT DATABASE()").Row().Scan(&name)
|
||||||
return
|
return
|
||||||
@ -704,7 +759,8 @@ func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []i
|
|||||||
Statement: &gorm.Statement{DB: m.DB, Dest: value},
|
Statement: &gorm.Statement{DB: m.DB, Dest: value},
|
||||||
}
|
}
|
||||||
beDependedOn := map[*schema.Schema]bool{}
|
beDependedOn := map[*schema.Schema]bool{}
|
||||||
if err := dep.Parse(value); err != nil {
|
// support for special table name
|
||||||
|
if err := dep.ParseWithSpecialTableName(value, m.DB.Statement.Table); err != nil {
|
||||||
m.DB.Logger.Error(context.Background(), "failed to parse value %#v, got error %v", value, err)
|
m.DB.Logger.Error(context.Background(), "failed to parse value %#v, got error %v", value, err)
|
||||||
}
|
}
|
||||||
if _, ok := parsedSchemas[dep.Statement.Schema]; ok {
|
if _, ok := parsedSchemas[dep.Statement.Schema]; ok {
|
||||||
@ -781,6 +837,7 @@ func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []i
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CurrentTable returns current statement's table expression
|
||||||
func (m Migrator) CurrentTable(stmt *gorm.Statement) interface{} {
|
func (m Migrator) CurrentTable(stmt *gorm.Statement) interface{} {
|
||||||
if stmt.TableExpr != nil {
|
if stmt.TableExpr != nil {
|
||||||
return *stmt.TableExpr
|
return *stmt.TableExpr
|
||||||
|
@ -115,7 +115,7 @@ func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string, arg
|
|||||||
}
|
}
|
||||||
|
|
||||||
type PreparedStmtTX struct {
|
type PreparedStmtTX struct {
|
||||||
*sql.Tx
|
Tx
|
||||||
PreparedStmtDB *PreparedStmtDB
|
PreparedStmtDB *PreparedStmtDB
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -151,7 +151,7 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ..
|
|||||||
func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) {
|
func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) {
|
||||||
stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
|
stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
rows, err = tx.Tx.Stmt(stmt.Stmt).QueryContext(ctx, args...)
|
rows, err = tx.Tx.StmtContext(ctx, stmt.Stmt).QueryContext(ctx, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tx.PreparedStmtDB.Mux.Lock()
|
tx.PreparedStmtDB.Mux.Lock()
|
||||||
defer tx.PreparedStmtDB.Mux.Unlock()
|
defer tx.PreparedStmtDB.Mux.Unlock()
|
||||||
|
105
scan.go
105
scan.go
@ -10,6 +10,7 @@ import (
|
|||||||
"gorm.io/gorm/schema"
|
"gorm.io/gorm/schema"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// prepareValues prepare values slice
|
||||||
func prepareValues(values []interface{}, db *DB, columnTypes []*sql.ColumnType, columns []string) {
|
func prepareValues(values []interface{}, db *DB, columnTypes []*sql.ColumnType, columns []string) {
|
||||||
if db.Statement.Schema != nil {
|
if db.Statement.Schema != nil {
|
||||||
for idx, name := range columns {
|
for idx, name := range columns {
|
||||||
@ -49,65 +50,56 @@ func scanIntoMap(mapValue map[string]interface{}, values []interface{}, columns
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *DB) scanIntoStruct(sch *schema.Schema, rows *sql.Rows, reflectValue reflect.Value, values []interface{}, columns []string, fields []*schema.Field, joinFields [][2]*schema.Field) {
|
func (db *DB) scanIntoStruct(rows Rows, reflectValue reflect.Value, values []interface{}, fields []*schema.Field, joinFields [][2]*schema.Field) {
|
||||||
for idx, column := range columns {
|
for idx, field := range fields {
|
||||||
if sch == nil {
|
if field != nil {
|
||||||
values[idx] = reflectValue.Interface()
|
values[idx] = field.NewValuePool.Get()
|
||||||
} else if field := sch.LookUpField(column); field != nil && field.Readable {
|
} else if len(fields) == 1 {
|
||||||
values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface()
|
if reflectValue.CanAddr() {
|
||||||
} else if names := strings.Split(column, "__"); len(names) > 1 {
|
values[idx] = reflectValue.Addr().Interface()
|
||||||
if rel, ok := sch.Relationships.Relations[names[0]]; ok {
|
|
||||||
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
|
|
||||||
values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface()
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
values[idx] = &sql.RawBytes{}
|
|
||||||
} else if len(columns) == 1 {
|
|
||||||
sch = nil
|
|
||||||
values[idx] = reflectValue.Interface()
|
|
||||||
} else {
|
} else {
|
||||||
values[idx] = &sql.RawBytes{}
|
values[idx] = reflectValue.Interface()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
db.RowsAffected++
|
db.RowsAffected++
|
||||||
db.AddError(rows.Scan(values...))
|
db.AddError(rows.Scan(values...))
|
||||||
|
|
||||||
if sch != nil {
|
for idx, field := range fields {
|
||||||
for idx, column := range columns {
|
if field != nil {
|
||||||
if field := sch.LookUpField(column); field != nil && field.Readable {
|
if len(joinFields) == 0 || joinFields[idx][0] == nil {
|
||||||
field.Set(reflectValue, values[idx])
|
db.AddError(field.Set(db.Statement.Context, reflectValue, values[idx]))
|
||||||
} else if names := strings.Split(column, "__"); len(names) > 1 {
|
} else {
|
||||||
if rel, ok := sch.Relationships.Relations[names[0]]; ok {
|
relValue := joinFields[idx][0].ReflectValueOf(db.Statement.Context, reflectValue)
|
||||||
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
|
|
||||||
relValue := rel.Field.ReflectValueOf(reflectValue)
|
|
||||||
value := reflect.ValueOf(values[idx]).Elem()
|
|
||||||
|
|
||||||
if relValue.Kind() == reflect.Ptr && relValue.IsNil() {
|
if relValue.Kind() == reflect.Ptr && relValue.IsNil() {
|
||||||
if value.IsNil() {
|
if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
relValue.Set(reflect.New(relValue.Type().Elem()))
|
relValue.Set(reflect.New(relValue.Type().Elem()))
|
||||||
}
|
}
|
||||||
|
db.AddError(joinFields[idx][1].Set(db.Statement.Context, relValue, values[idx]))
|
||||||
|
}
|
||||||
|
|
||||||
field.Set(relValue, values[idx])
|
// release data to pool
|
||||||
}
|
field.NewValuePool.Put(values[idx])
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ScanMode scan data mode
|
||||||
type ScanMode uint8
|
type ScanMode uint8
|
||||||
|
|
||||||
|
// scan modes
|
||||||
const (
|
const (
|
||||||
ScanInitialized ScanMode = 1 << 0 // 1
|
ScanInitialized ScanMode = 1 << 0 // 1
|
||||||
ScanUpdate ScanMode = 1 << 1 // 2
|
ScanUpdate ScanMode = 1 << 1 // 2
|
||||||
ScanOnConflictDoNothing ScanMode = 1 << 2 // 4
|
ScanOnConflictDoNothing ScanMode = 1 << 2 // 4
|
||||||
)
|
)
|
||||||
|
|
||||||
func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
|
// Scan scan rows into db statement
|
||||||
|
func Scan(rows Rows, db *DB, mode ScanMode) {
|
||||||
var (
|
var (
|
||||||
columns, _ = rows.Columns()
|
columns, _ = rows.Columns()
|
||||||
values = make([]interface{}, len(columns))
|
values = make([]interface{}, len(columns))
|
||||||
@ -138,7 +130,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
|
|||||||
}
|
}
|
||||||
scanIntoMap(mapValue, values, columns)
|
scanIntoMap(mapValue, values, columns)
|
||||||
}
|
}
|
||||||
case *[]map[string]interface{}, []map[string]interface{}:
|
case *[]map[string]interface{}:
|
||||||
columnTypes, _ := rows.ColumnTypes()
|
columnTypes, _ := rows.ColumnTypes()
|
||||||
for initialized || rows.Next() {
|
for initialized || rows.Next() {
|
||||||
prepareValues(values, db, columnTypes, columns)
|
prepareValues(values, db, columnTypes, columns)
|
||||||
@ -149,11 +141,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
|
|||||||
|
|
||||||
mapValue := map[string]interface{}{}
|
mapValue := map[string]interface{}{}
|
||||||
scanIntoMap(mapValue, values, columns)
|
scanIntoMap(mapValue, values, columns)
|
||||||
if values, ok := dest.([]map[string]interface{}); ok {
|
*dest = append(*dest, mapValue)
|
||||||
values = append(values, mapValue)
|
|
||||||
} else if values, ok := dest.(*[]map[string]interface{}); ok {
|
|
||||||
*values = append(*values, mapValue)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
case *int, *int8, *int16, *int32, *int64,
|
case *int, *int8, *int16, *int32, *int64,
|
||||||
*uint, *uint8, *uint16, *uint32, *uint64, *uintptr,
|
*uint, *uint8, *uint16, *uint32, *uint64, *uintptr,
|
||||||
@ -169,6 +157,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
|
|||||||
default:
|
default:
|
||||||
var (
|
var (
|
||||||
fields = make([]*schema.Field, len(columns))
|
fields = make([]*schema.Field, len(columns))
|
||||||
|
selectedColumnsMap = make(map[string]int, len(columns))
|
||||||
joinFields [][2]*schema.Field
|
joinFields [][2]*schema.Field
|
||||||
sch = db.Statement.Schema
|
sch = db.Statement.Schema
|
||||||
reflectValue = db.Statement.ReflectValue
|
reflectValue = db.Statement.ReflectValue
|
||||||
@ -193,9 +182,31 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
|
|||||||
sch, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy)
|
sch, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(columns) == 1 {
|
||||||
|
// Is Pluck
|
||||||
|
if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); (reflectValueType != sch.ModelType && ok) || // is scanner
|
||||||
|
reflectValueType.Kind() != reflect.Struct || // is not struct
|
||||||
|
sch.ModelType.ConvertibleTo(schema.TimeReflectType) { // is time
|
||||||
|
sch = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Not Pluck
|
||||||
|
if sch != nil {
|
||||||
for idx, column := range columns {
|
for idx, column := range columns {
|
||||||
if field := sch.LookUpField(column); field != nil && field.Readable {
|
if field := sch.LookUpField(column); field != nil && field.Readable {
|
||||||
|
if curIndex, ok := selectedColumnsMap[column]; ok {
|
||||||
|
for fieldIndex, selectField := range sch.Fields[curIndex+1:] {
|
||||||
|
if selectField.DBName == column && selectField.Readable {
|
||||||
|
selectedColumnsMap[column] = curIndex + fieldIndex + 1
|
||||||
|
fields[idx] = selectField
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
fields[idx] = field
|
fields[idx] = field
|
||||||
|
selectedColumnsMap[column] = idx
|
||||||
|
}
|
||||||
} else if names := strings.Split(column, "__"); len(names) > 1 {
|
} else if names := strings.Split(column, "__"); len(names) > 1 {
|
||||||
if rel, ok := sch.Relationships.Relations[names[0]]; ok {
|
if rel, ok := sch.Relationships.Relations[names[0]]; ok {
|
||||||
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
|
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
|
||||||
@ -213,14 +224,6 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
|
|||||||
values[idx] = &sql.RawBytes{}
|
values[idx] = &sql.RawBytes{}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(columns) == 1 {
|
|
||||||
// isPluck
|
|
||||||
if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); (reflectValueType != sch.ModelType && ok) || // is scanner
|
|
||||||
reflectValueType.Kind() != reflect.Struct || // is not struct
|
|
||||||
sch.ModelType.ConvertibleTo(schema.TimeReflectType) { // is time
|
|
||||||
sch = nil
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -244,7 +247,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
|
|||||||
elem = reflectValue.Index(int(db.RowsAffected))
|
elem = reflectValue.Index(int(db.RowsAffected))
|
||||||
if onConflictDonothing {
|
if onConflictDonothing {
|
||||||
for _, field := range fields {
|
for _, field := range fields {
|
||||||
if _, ok := field.ValueOf(elem); !ok {
|
if _, ok := field.ValueOf(db.Statement.Context, elem); !ok {
|
||||||
db.RowsAffected++
|
db.RowsAffected++
|
||||||
goto BEGIN
|
goto BEGIN
|
||||||
}
|
}
|
||||||
@ -254,7 +257,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
|
|||||||
elem = reflect.New(reflectValueType)
|
elem = reflect.New(reflectValueType)
|
||||||
}
|
}
|
||||||
|
|
||||||
db.scanIntoStruct(sch, rows, elem, values, columns, fields, joinFields)
|
db.scanIntoStruct(rows, elem, values, fields, joinFields)
|
||||||
|
|
||||||
if !update {
|
if !update {
|
||||||
if isPtr {
|
if isPtr {
|
||||||
@ -270,7 +273,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
|
|||||||
}
|
}
|
||||||
case reflect.Struct, reflect.Ptr:
|
case reflect.Struct, reflect.Ptr:
|
||||||
if initialized || rows.Next() {
|
if initialized || rows.Next() {
|
||||||
db.scanIntoStruct(sch, rows, reflectValue, values, columns, fields, joinFields)
|
db.scanIntoStruct(rows, reflectValue, values, fields, joinFields)
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
db.AddError(rows.Scan(dest))
|
db.AddError(rows.Scan(dest))
|
||||||
|
534
schema/field.go
534
schema/field.go
@ -1,6 +1,7 @@
|
|||||||
package schema
|
package schema
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
"fmt"
|
"fmt"
|
||||||
@ -11,15 +12,25 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/jinzhu/now"
|
"github.com/jinzhu/now"
|
||||||
|
"gorm.io/gorm/clause"
|
||||||
"gorm.io/gorm/utils"
|
"gorm.io/gorm/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
type DataType string
|
// special types' reflect type
|
||||||
|
var (
|
||||||
|
TimeReflectType = reflect.TypeOf(time.Time{})
|
||||||
|
TimePtrReflectType = reflect.TypeOf(&time.Time{})
|
||||||
|
ByteReflectType = reflect.TypeOf(uint8(0))
|
||||||
|
)
|
||||||
|
|
||||||
type TimeType int64
|
type (
|
||||||
|
// DataType GORM data type
|
||||||
var TimeReflectType = reflect.TypeOf(time.Time{})
|
DataType string
|
||||||
|
// TimeType GORM time type
|
||||||
|
TimeType int64
|
||||||
|
)
|
||||||
|
|
||||||
|
// GORM time types
|
||||||
const (
|
const (
|
||||||
UnixTime TimeType = 1
|
UnixTime TimeType = 1
|
||||||
UnixSecond TimeType = 2
|
UnixSecond TimeType = 2
|
||||||
@ -27,6 +38,7 @@ const (
|
|||||||
UnixNanosecond TimeType = 4
|
UnixNanosecond TimeType = 4
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// GORM fields types
|
||||||
const (
|
const (
|
||||||
Bool DataType = "bool"
|
Bool DataType = "bool"
|
||||||
Int DataType = "int"
|
Int DataType = "int"
|
||||||
@ -37,6 +49,7 @@ const (
|
|||||||
Bytes DataType = "bytes"
|
Bytes DataType = "bytes"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Field is the representation of model schema's field
|
||||||
type Field struct {
|
type Field struct {
|
||||||
Name string
|
Name string
|
||||||
DBName string
|
DBName string
|
||||||
@ -49,9 +62,9 @@ type Field struct {
|
|||||||
Creatable bool
|
Creatable bool
|
||||||
Updatable bool
|
Updatable bool
|
||||||
Readable bool
|
Readable bool
|
||||||
HasDefaultValue bool
|
|
||||||
AutoCreateTime TimeType
|
AutoCreateTime TimeType
|
||||||
AutoUpdateTime TimeType
|
AutoUpdateTime TimeType
|
||||||
|
HasDefaultValue bool
|
||||||
DefaultValue string
|
DefaultValue string
|
||||||
DefaultValueInterface interface{}
|
DefaultValueInterface interface{}
|
||||||
NotNull bool
|
NotNull bool
|
||||||
@ -60,6 +73,7 @@ type Field struct {
|
|||||||
Size int
|
Size int
|
||||||
Precision int
|
Precision int
|
||||||
Scale int
|
Scale int
|
||||||
|
IgnoreMigration bool
|
||||||
FieldType reflect.Type
|
FieldType reflect.Type
|
||||||
IndirectFieldType reflect.Type
|
IndirectFieldType reflect.Type
|
||||||
StructField reflect.StructField
|
StructField reflect.StructField
|
||||||
@ -68,27 +82,39 @@ type Field struct {
|
|||||||
Schema *Schema
|
Schema *Schema
|
||||||
EmbeddedSchema *Schema
|
EmbeddedSchema *Schema
|
||||||
OwnerSchema *Schema
|
OwnerSchema *Schema
|
||||||
ReflectValueOf func(reflect.Value) reflect.Value
|
ReflectValueOf func(context.Context, reflect.Value) reflect.Value
|
||||||
ValueOf func(reflect.Value) (value interface{}, zero bool)
|
ValueOf func(context.Context, reflect.Value) (value interface{}, zero bool)
|
||||||
Set func(reflect.Value, interface{}) error
|
Set func(context.Context, reflect.Value, interface{}) error
|
||||||
IgnoreMigration bool
|
Serializer SerializerInterface
|
||||||
|
NewValuePool FieldNewValuePool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ParseField parses reflect.StructField to Field
|
||||||
func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
||||||
var err error
|
var (
|
||||||
|
err error
|
||||||
|
tagSetting = ParseTagSetting(fieldStruct.Tag.Get("gorm"), ";")
|
||||||
|
)
|
||||||
|
|
||||||
field := &Field{
|
field := &Field{
|
||||||
Name: fieldStruct.Name,
|
Name: fieldStruct.Name,
|
||||||
|
DBName: tagSetting["COLUMN"],
|
||||||
BindNames: []string{fieldStruct.Name},
|
BindNames: []string{fieldStruct.Name},
|
||||||
FieldType: fieldStruct.Type,
|
FieldType: fieldStruct.Type,
|
||||||
IndirectFieldType: fieldStruct.Type,
|
IndirectFieldType: fieldStruct.Type,
|
||||||
StructField: fieldStruct,
|
StructField: fieldStruct,
|
||||||
|
Tag: fieldStruct.Tag,
|
||||||
|
TagSettings: tagSetting,
|
||||||
|
Schema: schema,
|
||||||
Creatable: true,
|
Creatable: true,
|
||||||
Updatable: true,
|
Updatable: true,
|
||||||
Readable: true,
|
Readable: true,
|
||||||
Tag: fieldStruct.Tag,
|
PrimaryKey: utils.CheckTruth(tagSetting["PRIMARYKEY"], tagSetting["PRIMARY_KEY"]),
|
||||||
TagSettings: ParseTagSetting(fieldStruct.Tag.Get("gorm"), ";"),
|
AutoIncrement: utils.CheckTruth(tagSetting["AUTOINCREMENT"]),
|
||||||
Schema: schema,
|
HasDefaultValue: utils.CheckTruth(tagSetting["AUTOINCREMENT"]),
|
||||||
|
NotNull: utils.CheckTruth(tagSetting["NOT NULL"], tagSetting["NOTNULL"]),
|
||||||
|
Unique: utils.CheckTruth(tagSetting["UNIQUE"]),
|
||||||
|
Comment: tagSetting["COMMENT"],
|
||||||
AutoIncrementIncrement: 1,
|
AutoIncrementIncrement: 1,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -97,7 +123,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fieldValue := reflect.New(field.IndirectFieldType)
|
fieldValue := reflect.New(field.IndirectFieldType)
|
||||||
// if field is valuer, used its value or first fields as data type
|
// if field is valuer, used its value or first field as data type
|
||||||
valuer, isValuer := fieldValue.Interface().(driver.Valuer)
|
valuer, isValuer := fieldValue.Interface().(driver.Valuer)
|
||||||
if isValuer {
|
if isValuer {
|
||||||
if _, ok := fieldValue.Interface().(GormDataTypeInterface); !ok {
|
if _, ok := fieldValue.Interface().(GormDataTypeInterface); !ok {
|
||||||
@ -105,31 +131,37 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
|||||||
fieldValue = reflect.ValueOf(v)
|
fieldValue = reflect.ValueOf(v)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Use the field struct's first field type as data type, e.g: use `string` for sql.NullString
|
||||||
var getRealFieldValue func(reflect.Value)
|
var getRealFieldValue func(reflect.Value)
|
||||||
getRealFieldValue = func(v reflect.Value) {
|
getRealFieldValue = func(v reflect.Value) {
|
||||||
rv := reflect.Indirect(v)
|
var (
|
||||||
if rv.Kind() == reflect.Struct && !rv.Type().ConvertibleTo(TimeReflectType) {
|
rv = reflect.Indirect(v)
|
||||||
for i := 0; i < rv.Type().NumField(); i++ {
|
rvType = rv.Type()
|
||||||
newFieldType := rv.Type().Field(i).Type
|
)
|
||||||
|
|
||||||
|
if rv.Kind() == reflect.Struct && !rvType.ConvertibleTo(TimeReflectType) {
|
||||||
|
for i := 0; i < rvType.NumField(); i++ {
|
||||||
|
for key, value := range ParseTagSetting(rvType.Field(i).Tag.Get("gorm"), ";") {
|
||||||
|
if _, ok := field.TagSettings[key]; !ok {
|
||||||
|
field.TagSettings[key] = value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < rvType.NumField(); i++ {
|
||||||
|
newFieldType := rvType.Field(i).Type
|
||||||
for newFieldType.Kind() == reflect.Ptr {
|
for newFieldType.Kind() == reflect.Ptr {
|
||||||
newFieldType = newFieldType.Elem()
|
newFieldType = newFieldType.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
fieldValue = reflect.New(newFieldType)
|
fieldValue = reflect.New(newFieldType)
|
||||||
|
if rvType != reflect.Indirect(fieldValue).Type() {
|
||||||
if rv.Type() != reflect.Indirect(fieldValue).Type() {
|
|
||||||
getRealFieldValue(fieldValue)
|
getRealFieldValue(fieldValue)
|
||||||
}
|
}
|
||||||
|
|
||||||
if fieldValue.IsValid() {
|
if fieldValue.IsValid() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
for key, value := range ParseTagSetting(field.IndirectFieldType.Field(i).Tag.Get("gorm"), ";") {
|
|
||||||
if _, ok := field.TagSettings[key]; !ok {
|
|
||||||
field.TagSettings[key] = value
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -138,19 +170,23 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if dbName, ok := field.TagSettings["COLUMN"]; ok {
|
if v, isSerializer := fieldValue.Interface().(SerializerInterface); isSerializer {
|
||||||
field.DBName = dbName
|
field.DataType = String
|
||||||
|
field.Serializer = v
|
||||||
|
} else {
|
||||||
|
var serializerName = field.TagSettings["JSON"]
|
||||||
|
if serializerName == "" {
|
||||||
|
serializerName = field.TagSettings["SERIALIZER"]
|
||||||
|
}
|
||||||
|
if serializerName != "" {
|
||||||
|
if serializer, ok := GetSerializer(serializerName); ok {
|
||||||
|
// Set default data type to string for serializer
|
||||||
|
field.DataType = String
|
||||||
|
field.Serializer = serializer
|
||||||
|
} else {
|
||||||
|
schema.err = fmt.Errorf("invalid serializer type %v", serializerName)
|
||||||
}
|
}
|
||||||
|
|
||||||
if val, ok := field.TagSettings["PRIMARYKEY"]; ok && utils.CheckTruth(val) {
|
|
||||||
field.PrimaryKey = true
|
|
||||||
} else if val, ok := field.TagSettings["PRIMARY_KEY"]; ok && utils.CheckTruth(val) {
|
|
||||||
field.PrimaryKey = true
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if val, ok := field.TagSettings["AUTOINCREMENT"]; ok && utils.CheckTruth(val) {
|
|
||||||
field.AutoIncrement = true
|
|
||||||
field.HasDefaultValue = true
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if num, ok := field.TagSettings["AUTOINCREMENTINCREMENT"]; ok {
|
if num, ok := field.TagSettings["AUTOINCREMENTINCREMENT"]; ok {
|
||||||
@ -176,20 +212,6 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
|||||||
field.Scale, _ = strconv.Atoi(s)
|
field.Scale, _ = strconv.Atoi(s)
|
||||||
}
|
}
|
||||||
|
|
||||||
if val, ok := field.TagSettings["NOT NULL"]; ok && utils.CheckTruth(val) {
|
|
||||||
field.NotNull = true
|
|
||||||
} else if val, ok := field.TagSettings["NOTNULL"]; ok && utils.CheckTruth(val) {
|
|
||||||
field.NotNull = true
|
|
||||||
}
|
|
||||||
|
|
||||||
if val, ok := field.TagSettings["UNIQUE"]; ok && utils.CheckTruth(val) {
|
|
||||||
field.Unique = true
|
|
||||||
}
|
|
||||||
|
|
||||||
if val, ok := field.TagSettings["COMMENT"]; ok {
|
|
||||||
field.Comment = val
|
|
||||||
}
|
|
||||||
|
|
||||||
// default value is function or null or blank (primary keys)
|
// default value is function or null or blank (primary keys)
|
||||||
field.DefaultValue = strings.TrimSpace(field.DefaultValue)
|
field.DefaultValue = strings.TrimSpace(field.DefaultValue)
|
||||||
skipParseDefaultValue := strings.Contains(field.DefaultValue, "(") &&
|
skipParseDefaultValue := strings.Contains(field.DefaultValue, "(") &&
|
||||||
@ -225,7 +247,6 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
|||||||
}
|
}
|
||||||
case reflect.String:
|
case reflect.String:
|
||||||
field.DataType = String
|
field.DataType = String
|
||||||
|
|
||||||
if field.HasDefaultValue && !skipParseDefaultValue {
|
if field.HasDefaultValue && !skipParseDefaultValue {
|
||||||
field.DefaultValue = strings.Trim(field.DefaultValue, "'")
|
field.DefaultValue = strings.Trim(field.DefaultValue, "'")
|
||||||
field.DefaultValue = strings.Trim(field.DefaultValue, `"`)
|
field.DefaultValue = strings.Trim(field.DefaultValue, `"`)
|
||||||
@ -236,22 +257,25 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
|||||||
field.DataType = Time
|
field.DataType = Time
|
||||||
} else if fieldValue.Type().ConvertibleTo(TimeReflectType) {
|
} else if fieldValue.Type().ConvertibleTo(TimeReflectType) {
|
||||||
field.DataType = Time
|
field.DataType = Time
|
||||||
} else if fieldValue.Type().ConvertibleTo(reflect.TypeOf(&time.Time{})) {
|
} else if fieldValue.Type().ConvertibleTo(TimePtrReflectType) {
|
||||||
field.DataType = Time
|
field.DataType = Time
|
||||||
}
|
}
|
||||||
|
if field.HasDefaultValue && !skipParseDefaultValue && field.DataType == Time {
|
||||||
|
if t, err := now.Parse(field.DefaultValue); err == nil {
|
||||||
|
field.DefaultValueInterface = t
|
||||||
|
}
|
||||||
|
}
|
||||||
case reflect.Array, reflect.Slice:
|
case reflect.Array, reflect.Slice:
|
||||||
if reflect.Indirect(fieldValue).Type().Elem() == reflect.TypeOf(uint8(0)) {
|
if reflect.Indirect(fieldValue).Type().Elem() == ByteReflectType && field.DataType == "" {
|
||||||
field.DataType = Bytes
|
field.DataType = Bytes
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
field.GORMDataType = field.DataType
|
|
||||||
|
|
||||||
if dataTyper, ok := fieldValue.Interface().(GormDataTypeInterface); ok {
|
if dataTyper, ok := fieldValue.Interface().(GormDataTypeInterface); ok {
|
||||||
field.DataType = DataType(dataTyper.GormDataType())
|
field.DataType = DataType(dataTyper.GormDataType())
|
||||||
}
|
}
|
||||||
|
|
||||||
if v, ok := field.TagSettings["AUTOCREATETIME"]; ok || (field.Name == "CreatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) {
|
if v, ok := field.TagSettings["AUTOCREATETIME"]; (ok && utils.CheckTruth(v)) || (!ok && field.Name == "CreatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) {
|
||||||
if field.DataType == Time {
|
if field.DataType == Time {
|
||||||
field.AutoCreateTime = UnixTime
|
field.AutoCreateTime = UnixTime
|
||||||
} else if strings.ToUpper(v) == "NANO" {
|
} else if strings.ToUpper(v) == "NANO" {
|
||||||
@ -263,7 +287,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if v, ok := field.TagSettings["AUTOUPDATETIME"]; ok || (field.Name == "UpdatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) {
|
if v, ok := field.TagSettings["AUTOUPDATETIME"]; (ok && utils.CheckTruth(v)) || (!ok && field.Name == "UpdatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) {
|
||||||
if field.DataType == Time {
|
if field.DataType == Time {
|
||||||
field.AutoUpdateTime = UnixTime
|
field.AutoUpdateTime = UnixTime
|
||||||
} else if strings.ToUpper(v) == "NANO" {
|
} else if strings.ToUpper(v) == "NANO" {
|
||||||
@ -275,6 +299,10 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if field.GORMDataType == "" {
|
||||||
|
field.GORMDataType = field.DataType
|
||||||
|
}
|
||||||
|
|
||||||
if val, ok := field.TagSettings["TYPE"]; ok {
|
if val, ok := field.TagSettings["TYPE"]; ok {
|
||||||
switch DataType(strings.ToLower(val)) {
|
switch DataType(strings.ToLower(val)) {
|
||||||
case Bool, Int, Uint, Float, String, Time, Bytes:
|
case Bool, Int, Uint, Float, String, Time, Bytes:
|
||||||
@ -284,10 +312,6 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if field.GORMDataType == "" {
|
|
||||||
field.GORMDataType = field.DataType
|
|
||||||
}
|
|
||||||
|
|
||||||
if field.Size == 0 {
|
if field.Size == 0 {
|
||||||
switch reflect.Indirect(fieldValue).Kind() {
|
switch reflect.Indirect(fieldValue).Kind() {
|
||||||
case reflect.Int, reflect.Int64, reflect.Uint, reflect.Uint64, reflect.Float64:
|
case reflect.Int, reflect.Int64, reflect.Uint, reflect.Uint64, reflect.Float64:
|
||||||
@ -346,8 +370,9 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, ok := field.TagSettings["EMBEDDED"]; field.GORMDataType != Time && field.GORMDataType != Bytes &&
|
// Normal anonymous field or having `EMBEDDED` tag
|
||||||
(ok || (fieldStruct.Anonymous && !isValuer && (field.Creatable || field.Updatable || field.Readable))) {
|
if _, ok := field.TagSettings["EMBEDDED"]; ok || (field.GORMDataType != Time && field.GORMDataType != Bytes && !isValuer &&
|
||||||
|
fieldStruct.Anonymous && (field.Creatable || field.Updatable || field.Readable)) {
|
||||||
kind := reflect.Indirect(fieldValue).Kind()
|
kind := reflect.Indirect(fieldValue).Kind()
|
||||||
switch kind {
|
switch kind {
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
@ -410,31 +435,25 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
|||||||
|
|
||||||
// create valuer, setter when parse struct
|
// create valuer, setter when parse struct
|
||||||
func (field *Field) setupValuerAndSetter() {
|
func (field *Field) setupValuerAndSetter() {
|
||||||
// ValueOf
|
// Setup NewValuePool
|
||||||
|
field.setupNewValuePool()
|
||||||
|
|
||||||
|
// ValueOf returns field's value and if it is zero
|
||||||
|
fieldIndex := field.StructField.Index[0]
|
||||||
switch {
|
switch {
|
||||||
case len(field.StructField.Index) == 1:
|
case len(field.StructField.Index) == 1 && fieldIndex > 0:
|
||||||
field.ValueOf = func(value reflect.Value) (interface{}, bool) {
|
field.ValueOf = func(ctx context.Context, value reflect.Value) (interface{}, bool) {
|
||||||
fieldValue := reflect.Indirect(value).Field(field.StructField.Index[0])
|
fieldValue := reflect.Indirect(value).Field(fieldIndex)
|
||||||
return fieldValue.Interface(), fieldValue.IsZero()
|
|
||||||
}
|
|
||||||
case len(field.StructField.Index) == 2 && field.StructField.Index[0] >= 0:
|
|
||||||
field.ValueOf = func(value reflect.Value) (interface{}, bool) {
|
|
||||||
fieldValue := reflect.Indirect(value).Field(field.StructField.Index[0]).Field(field.StructField.Index[1])
|
|
||||||
return fieldValue.Interface(), fieldValue.IsZero()
|
return fieldValue.Interface(), fieldValue.IsZero()
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
field.ValueOf = func(value reflect.Value) (interface{}, bool) {
|
field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) {
|
||||||
v := reflect.Indirect(value)
|
v = reflect.Indirect(v)
|
||||||
|
for _, fieldIdx := range field.StructField.Index {
|
||||||
for _, idx := range field.StructField.Index {
|
if fieldIdx >= 0 {
|
||||||
if idx >= 0 {
|
v = v.Field(fieldIdx)
|
||||||
v = v.Field(idx)
|
|
||||||
} else {
|
} else {
|
||||||
v = v.Field(-idx - 1)
|
v = v.Field(-fieldIdx - 1)
|
||||||
|
|
||||||
if v.Type().Elem().Kind() != reflect.Struct {
|
|
||||||
return nil, true
|
|
||||||
}
|
|
||||||
|
|
||||||
if !v.IsNil() {
|
if !v.IsNil() {
|
||||||
v = v.Elem()
|
v = v.Elem()
|
||||||
@ -443,36 +462,53 @@ func (field *Field) setupValuerAndSetter() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return v.Interface(), v.IsZero()
|
|
||||||
|
fv, zero := v.Interface(), v.IsZero()
|
||||||
|
return fv, zero
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ReflectValueOf
|
if field.Serializer != nil {
|
||||||
switch {
|
oldValuerOf := field.ValueOf
|
||||||
case len(field.StructField.Index) == 1:
|
field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) {
|
||||||
field.ReflectValueOf = func(value reflect.Value) reflect.Value {
|
value, zero := oldValuerOf(ctx, v)
|
||||||
return reflect.Indirect(value).Field(field.StructField.Index[0])
|
if zero {
|
||||||
|
return value, zero
|
||||||
}
|
}
|
||||||
case len(field.StructField.Index) == 2 && field.StructField.Index[0] >= 0 && field.FieldType.Kind() != reflect.Ptr:
|
|
||||||
field.ReflectValueOf = func(value reflect.Value) reflect.Value {
|
s, ok := value.(SerializerValuerInterface)
|
||||||
return reflect.Indirect(value).Field(field.StructField.Index[0]).Field(field.StructField.Index[1])
|
if !ok {
|
||||||
|
s = field.Serializer
|
||||||
|
}
|
||||||
|
|
||||||
|
return &serializer{
|
||||||
|
Field: field,
|
||||||
|
SerializeValuer: s,
|
||||||
|
Destination: v,
|
||||||
|
Context: ctx,
|
||||||
|
fieldValue: value,
|
||||||
|
}, false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReflectValueOf returns field's reflect value
|
||||||
|
switch {
|
||||||
|
case len(field.StructField.Index) == 1 && fieldIndex > 0:
|
||||||
|
field.ReflectValueOf = func(ctx context.Context, value reflect.Value) reflect.Value {
|
||||||
|
return reflect.Indirect(value).Field(fieldIndex)
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
field.ReflectValueOf = func(value reflect.Value) reflect.Value {
|
field.ReflectValueOf = func(ctx context.Context, v reflect.Value) reflect.Value {
|
||||||
v := reflect.Indirect(value)
|
v = reflect.Indirect(v)
|
||||||
for idx, fieldIdx := range field.StructField.Index {
|
for idx, fieldIdx := range field.StructField.Index {
|
||||||
if fieldIdx >= 0 {
|
if fieldIdx >= 0 {
|
||||||
v = v.Field(fieldIdx)
|
v = v.Field(fieldIdx)
|
||||||
} else {
|
} else {
|
||||||
v = v.Field(-fieldIdx - 1)
|
v = v.Field(-fieldIdx - 1)
|
||||||
}
|
|
||||||
|
|
||||||
if v.Kind() == reflect.Ptr {
|
|
||||||
if v.Type().Elem().Kind() == reflect.Struct {
|
|
||||||
if v.IsNil() {
|
if v.IsNil() {
|
||||||
v.Set(reflect.New(v.Type().Elem()))
|
v.Set(reflect.New(v.Type().Elem()))
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
if idx < len(field.StructField.Index)-1 {
|
if idx < len(field.StructField.Index)-1 {
|
||||||
v = v.Elem()
|
v = v.Elem()
|
||||||
@ -483,22 +519,25 @@ func (field *Field) setupValuerAndSetter() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fallbackSetter := func(value reflect.Value, v interface{}, setter func(reflect.Value, interface{}) error) (err error) {
|
fallbackSetter := func(ctx context.Context, value reflect.Value, v interface{}, setter func(context.Context, reflect.Value, interface{}) error) (err error) {
|
||||||
if v == nil {
|
if v == nil {
|
||||||
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
|
field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem())
|
||||||
} else {
|
} else {
|
||||||
reflectV := reflect.ValueOf(v)
|
reflectV := reflect.ValueOf(v)
|
||||||
// Optimal value type acquisition for v
|
// Optimal value type acquisition for v
|
||||||
reflectValType := reflectV.Type()
|
reflectValType := reflectV.Type()
|
||||||
|
|
||||||
if reflectValType.AssignableTo(field.FieldType) {
|
if reflectValType.AssignableTo(field.FieldType) {
|
||||||
field.ReflectValueOf(value).Set(reflectV)
|
if reflectV.Kind() == reflect.Ptr && reflectV.Elem().Kind() == reflect.Ptr {
|
||||||
|
reflectV = reflect.Indirect(reflectV)
|
||||||
|
}
|
||||||
|
field.ReflectValueOf(ctx, value).Set(reflectV)
|
||||||
return
|
return
|
||||||
} else if reflectValType.ConvertibleTo(field.FieldType) {
|
} else if reflectValType.ConvertibleTo(field.FieldType) {
|
||||||
field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType))
|
field.ReflectValueOf(ctx, value).Set(reflectV.Convert(field.FieldType))
|
||||||
return
|
return
|
||||||
} else if field.FieldType.Kind() == reflect.Ptr {
|
} else if field.FieldType.Kind() == reflect.Ptr {
|
||||||
fieldValue := field.ReflectValueOf(value)
|
fieldValue := field.ReflectValueOf(ctx, value)
|
||||||
fieldType := field.FieldType.Elem()
|
fieldType := field.FieldType.Elem()
|
||||||
|
|
||||||
if reflectValType.AssignableTo(fieldType) {
|
if reflectValType.AssignableTo(fieldType) {
|
||||||
@ -521,16 +560,19 @@ func (field *Field) setupValuerAndSetter() {
|
|||||||
|
|
||||||
if reflectV.Kind() == reflect.Ptr {
|
if reflectV.Kind() == reflect.Ptr {
|
||||||
if reflectV.IsNil() {
|
if reflectV.IsNil() {
|
||||||
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
|
field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem())
|
||||||
|
} else if reflectV.Type().Elem().AssignableTo(field.FieldType) {
|
||||||
|
field.ReflectValueOf(ctx, value).Set(reflectV.Elem())
|
||||||
|
return
|
||||||
} else {
|
} else {
|
||||||
err = setter(value, reflectV.Elem().Interface())
|
err = setter(ctx, value, reflectV.Elem().Interface())
|
||||||
}
|
}
|
||||||
} else if valuer, ok := v.(driver.Valuer); ok {
|
} else if valuer, ok := v.(driver.Valuer); ok {
|
||||||
if v, err = valuer.Value(); err == nil {
|
if v, err = valuer.Value(); err == nil {
|
||||||
err = setter(value, v)
|
err = setter(ctx, value, v)
|
||||||
}
|
}
|
||||||
} else {
|
} else if _, ok := v.(clause.Expr); !ok {
|
||||||
return fmt.Errorf("failed to set value %+v to field %s", v, field.Name)
|
return fmt.Errorf("failed to set value %#v to field %s", v, field.Name)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -540,191 +582,201 @@ func (field *Field) setupValuerAndSetter() {
|
|||||||
// Set
|
// Set
|
||||||
switch field.FieldType.Kind() {
|
switch field.FieldType.Kind() {
|
||||||
case reflect.Bool:
|
case reflect.Bool:
|
||||||
field.Set = func(value reflect.Value, v interface{}) error {
|
field.Set = func(ctx context.Context, value reflect.Value, v interface{}) error {
|
||||||
switch data := v.(type) {
|
switch data := v.(type) {
|
||||||
|
case **bool:
|
||||||
|
if data != nil && *data != nil {
|
||||||
|
field.ReflectValueOf(ctx, value).SetBool(**data)
|
||||||
|
}
|
||||||
case bool:
|
case bool:
|
||||||
field.ReflectValueOf(value).SetBool(data)
|
field.ReflectValueOf(ctx, value).SetBool(data)
|
||||||
case *bool:
|
|
||||||
if data != nil {
|
|
||||||
field.ReflectValueOf(value).SetBool(*data)
|
|
||||||
} else {
|
|
||||||
field.ReflectValueOf(value).SetBool(false)
|
|
||||||
}
|
|
||||||
case int64:
|
case int64:
|
||||||
if data > 0 {
|
field.ReflectValueOf(ctx, value).SetBool(data > 0)
|
||||||
field.ReflectValueOf(value).SetBool(true)
|
|
||||||
} else {
|
|
||||||
field.ReflectValueOf(value).SetBool(false)
|
|
||||||
}
|
|
||||||
case string:
|
case string:
|
||||||
b, _ := strconv.ParseBool(data)
|
b, _ := strconv.ParseBool(data)
|
||||||
field.ReflectValueOf(value).SetBool(b)
|
field.ReflectValueOf(ctx, value).SetBool(b)
|
||||||
default:
|
default:
|
||||||
return fallbackSetter(value, v, field.Set)
|
return fallbackSetter(ctx, value, v, field.Set)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||||
field.Set = func(value reflect.Value, v interface{}) (err error) {
|
field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) {
|
||||||
switch data := v.(type) {
|
switch data := v.(type) {
|
||||||
|
case **int64:
|
||||||
|
if data != nil && *data != nil {
|
||||||
|
field.ReflectValueOf(ctx, value).SetInt(**data)
|
||||||
|
}
|
||||||
case int64:
|
case int64:
|
||||||
field.ReflectValueOf(value).SetInt(data)
|
field.ReflectValueOf(ctx, value).SetInt(data)
|
||||||
case int:
|
case int:
|
||||||
field.ReflectValueOf(value).SetInt(int64(data))
|
field.ReflectValueOf(ctx, value).SetInt(int64(data))
|
||||||
case int8:
|
case int8:
|
||||||
field.ReflectValueOf(value).SetInt(int64(data))
|
field.ReflectValueOf(ctx, value).SetInt(int64(data))
|
||||||
case int16:
|
case int16:
|
||||||
field.ReflectValueOf(value).SetInt(int64(data))
|
field.ReflectValueOf(ctx, value).SetInt(int64(data))
|
||||||
case int32:
|
case int32:
|
||||||
field.ReflectValueOf(value).SetInt(int64(data))
|
field.ReflectValueOf(ctx, value).SetInt(int64(data))
|
||||||
case uint:
|
case uint:
|
||||||
field.ReflectValueOf(value).SetInt(int64(data))
|
field.ReflectValueOf(ctx, value).SetInt(int64(data))
|
||||||
case uint8:
|
case uint8:
|
||||||
field.ReflectValueOf(value).SetInt(int64(data))
|
field.ReflectValueOf(ctx, value).SetInt(int64(data))
|
||||||
case uint16:
|
case uint16:
|
||||||
field.ReflectValueOf(value).SetInt(int64(data))
|
field.ReflectValueOf(ctx, value).SetInt(int64(data))
|
||||||
case uint32:
|
case uint32:
|
||||||
field.ReflectValueOf(value).SetInt(int64(data))
|
field.ReflectValueOf(ctx, value).SetInt(int64(data))
|
||||||
case uint64:
|
case uint64:
|
||||||
field.ReflectValueOf(value).SetInt(int64(data))
|
field.ReflectValueOf(ctx, value).SetInt(int64(data))
|
||||||
case float32:
|
case float32:
|
||||||
field.ReflectValueOf(value).SetInt(int64(data))
|
field.ReflectValueOf(ctx, value).SetInt(int64(data))
|
||||||
case float64:
|
case float64:
|
||||||
field.ReflectValueOf(value).SetInt(int64(data))
|
field.ReflectValueOf(ctx, value).SetInt(int64(data))
|
||||||
case []byte:
|
case []byte:
|
||||||
return field.Set(value, string(data))
|
return field.Set(ctx, value, string(data))
|
||||||
case string:
|
case string:
|
||||||
if i, err := strconv.ParseInt(data, 0, 64); err == nil {
|
if i, err := strconv.ParseInt(data, 0, 64); err == nil {
|
||||||
field.ReflectValueOf(value).SetInt(i)
|
field.ReflectValueOf(ctx, value).SetInt(i)
|
||||||
} else {
|
} else {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
case time.Time:
|
case time.Time:
|
||||||
if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond {
|
if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond {
|
||||||
field.ReflectValueOf(value).SetInt(data.UnixNano())
|
field.ReflectValueOf(ctx, value).SetInt(data.UnixNano())
|
||||||
} else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond {
|
} else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond {
|
||||||
field.ReflectValueOf(value).SetInt(data.UnixNano() / 1e6)
|
field.ReflectValueOf(ctx, value).SetInt(data.UnixNano() / 1e6)
|
||||||
} else {
|
} else {
|
||||||
field.ReflectValueOf(value).SetInt(data.Unix())
|
field.ReflectValueOf(ctx, value).SetInt(data.Unix())
|
||||||
}
|
}
|
||||||
case *time.Time:
|
case *time.Time:
|
||||||
if data != nil {
|
if data != nil {
|
||||||
if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond {
|
if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond {
|
||||||
field.ReflectValueOf(value).SetInt(data.UnixNano())
|
field.ReflectValueOf(ctx, value).SetInt(data.UnixNano())
|
||||||
} else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond {
|
} else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond {
|
||||||
field.ReflectValueOf(value).SetInt(data.UnixNano() / 1e6)
|
field.ReflectValueOf(ctx, value).SetInt(data.UnixNano() / 1e6)
|
||||||
} else {
|
} else {
|
||||||
field.ReflectValueOf(value).SetInt(data.Unix())
|
field.ReflectValueOf(ctx, value).SetInt(data.Unix())
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
field.ReflectValueOf(value).SetInt(0)
|
field.ReflectValueOf(ctx, value).SetInt(0)
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
return fallbackSetter(value, v, field.Set)
|
return fallbackSetter(ctx, value, v, field.Set)
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||||
field.Set = func(value reflect.Value, v interface{}) (err error) {
|
field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) {
|
||||||
switch data := v.(type) {
|
switch data := v.(type) {
|
||||||
|
case **uint64:
|
||||||
|
if data != nil && *data != nil {
|
||||||
|
field.ReflectValueOf(ctx, value).SetUint(**data)
|
||||||
|
}
|
||||||
case uint64:
|
case uint64:
|
||||||
field.ReflectValueOf(value).SetUint(data)
|
field.ReflectValueOf(ctx, value).SetUint(data)
|
||||||
case uint:
|
case uint:
|
||||||
field.ReflectValueOf(value).SetUint(uint64(data))
|
field.ReflectValueOf(ctx, value).SetUint(uint64(data))
|
||||||
case uint8:
|
case uint8:
|
||||||
field.ReflectValueOf(value).SetUint(uint64(data))
|
field.ReflectValueOf(ctx, value).SetUint(uint64(data))
|
||||||
case uint16:
|
case uint16:
|
||||||
field.ReflectValueOf(value).SetUint(uint64(data))
|
field.ReflectValueOf(ctx, value).SetUint(uint64(data))
|
||||||
case uint32:
|
case uint32:
|
||||||
field.ReflectValueOf(value).SetUint(uint64(data))
|
field.ReflectValueOf(ctx, value).SetUint(uint64(data))
|
||||||
case int64:
|
case int64:
|
||||||
field.ReflectValueOf(value).SetUint(uint64(data))
|
field.ReflectValueOf(ctx, value).SetUint(uint64(data))
|
||||||
case int:
|
case int:
|
||||||
field.ReflectValueOf(value).SetUint(uint64(data))
|
field.ReflectValueOf(ctx, value).SetUint(uint64(data))
|
||||||
case int8:
|
case int8:
|
||||||
field.ReflectValueOf(value).SetUint(uint64(data))
|
field.ReflectValueOf(ctx, value).SetUint(uint64(data))
|
||||||
case int16:
|
case int16:
|
||||||
field.ReflectValueOf(value).SetUint(uint64(data))
|
field.ReflectValueOf(ctx, value).SetUint(uint64(data))
|
||||||
case int32:
|
case int32:
|
||||||
field.ReflectValueOf(value).SetUint(uint64(data))
|
field.ReflectValueOf(ctx, value).SetUint(uint64(data))
|
||||||
case float32:
|
case float32:
|
||||||
field.ReflectValueOf(value).SetUint(uint64(data))
|
field.ReflectValueOf(ctx, value).SetUint(uint64(data))
|
||||||
case float64:
|
case float64:
|
||||||
field.ReflectValueOf(value).SetUint(uint64(data))
|
field.ReflectValueOf(ctx, value).SetUint(uint64(data))
|
||||||
case []byte:
|
case []byte:
|
||||||
return field.Set(value, string(data))
|
return field.Set(ctx, value, string(data))
|
||||||
case time.Time:
|
case time.Time:
|
||||||
if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond {
|
if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond {
|
||||||
field.ReflectValueOf(value).SetUint(uint64(data.UnixNano()))
|
field.ReflectValueOf(ctx, value).SetUint(uint64(data.UnixNano()))
|
||||||
} else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond {
|
} else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond {
|
||||||
field.ReflectValueOf(value).SetUint(uint64(data.UnixNano() / 1e6))
|
field.ReflectValueOf(ctx, value).SetUint(uint64(data.UnixNano() / 1e6))
|
||||||
} else {
|
} else {
|
||||||
field.ReflectValueOf(value).SetUint(uint64(data.Unix()))
|
field.ReflectValueOf(ctx, value).SetUint(uint64(data.Unix()))
|
||||||
}
|
}
|
||||||
case string:
|
case string:
|
||||||
if i, err := strconv.ParseUint(data, 0, 64); err == nil {
|
if i, err := strconv.ParseUint(data, 0, 64); err == nil {
|
||||||
field.ReflectValueOf(value).SetUint(i)
|
field.ReflectValueOf(ctx, value).SetUint(i)
|
||||||
} else {
|
} else {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
return fallbackSetter(value, v, field.Set)
|
return fallbackSetter(ctx, value, v, field.Set)
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
case reflect.Float32, reflect.Float64:
|
case reflect.Float32, reflect.Float64:
|
||||||
field.Set = func(value reflect.Value, v interface{}) (err error) {
|
field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) {
|
||||||
switch data := v.(type) {
|
switch data := v.(type) {
|
||||||
|
case **float64:
|
||||||
|
if data != nil && *data != nil {
|
||||||
|
field.ReflectValueOf(ctx, value).SetFloat(**data)
|
||||||
|
}
|
||||||
case float64:
|
case float64:
|
||||||
field.ReflectValueOf(value).SetFloat(data)
|
field.ReflectValueOf(ctx, value).SetFloat(data)
|
||||||
case float32:
|
case float32:
|
||||||
field.ReflectValueOf(value).SetFloat(float64(data))
|
field.ReflectValueOf(ctx, value).SetFloat(float64(data))
|
||||||
case int64:
|
case int64:
|
||||||
field.ReflectValueOf(value).SetFloat(float64(data))
|
field.ReflectValueOf(ctx, value).SetFloat(float64(data))
|
||||||
case int:
|
case int:
|
||||||
field.ReflectValueOf(value).SetFloat(float64(data))
|
field.ReflectValueOf(ctx, value).SetFloat(float64(data))
|
||||||
case int8:
|
case int8:
|
||||||
field.ReflectValueOf(value).SetFloat(float64(data))
|
field.ReflectValueOf(ctx, value).SetFloat(float64(data))
|
||||||
case int16:
|
case int16:
|
||||||
field.ReflectValueOf(value).SetFloat(float64(data))
|
field.ReflectValueOf(ctx, value).SetFloat(float64(data))
|
||||||
case int32:
|
case int32:
|
||||||
field.ReflectValueOf(value).SetFloat(float64(data))
|
field.ReflectValueOf(ctx, value).SetFloat(float64(data))
|
||||||
case uint:
|
case uint:
|
||||||
field.ReflectValueOf(value).SetFloat(float64(data))
|
field.ReflectValueOf(ctx, value).SetFloat(float64(data))
|
||||||
case uint8:
|
case uint8:
|
||||||
field.ReflectValueOf(value).SetFloat(float64(data))
|
field.ReflectValueOf(ctx, value).SetFloat(float64(data))
|
||||||
case uint16:
|
case uint16:
|
||||||
field.ReflectValueOf(value).SetFloat(float64(data))
|
field.ReflectValueOf(ctx, value).SetFloat(float64(data))
|
||||||
case uint32:
|
case uint32:
|
||||||
field.ReflectValueOf(value).SetFloat(float64(data))
|
field.ReflectValueOf(ctx, value).SetFloat(float64(data))
|
||||||
case uint64:
|
case uint64:
|
||||||
field.ReflectValueOf(value).SetFloat(float64(data))
|
field.ReflectValueOf(ctx, value).SetFloat(float64(data))
|
||||||
case []byte:
|
case []byte:
|
||||||
return field.Set(value, string(data))
|
return field.Set(ctx, value, string(data))
|
||||||
case string:
|
case string:
|
||||||
if i, err := strconv.ParseFloat(data, 64); err == nil {
|
if i, err := strconv.ParseFloat(data, 64); err == nil {
|
||||||
field.ReflectValueOf(value).SetFloat(i)
|
field.ReflectValueOf(ctx, value).SetFloat(i)
|
||||||
} else {
|
} else {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
return fallbackSetter(value, v, field.Set)
|
return fallbackSetter(ctx, value, v, field.Set)
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
case reflect.String:
|
case reflect.String:
|
||||||
field.Set = func(value reflect.Value, v interface{}) (err error) {
|
field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) {
|
||||||
switch data := v.(type) {
|
switch data := v.(type) {
|
||||||
|
case **string:
|
||||||
|
if data != nil && *data != nil {
|
||||||
|
field.ReflectValueOf(ctx, value).SetString(**data)
|
||||||
|
}
|
||||||
case string:
|
case string:
|
||||||
field.ReflectValueOf(value).SetString(data)
|
field.ReflectValueOf(ctx, value).SetString(data)
|
||||||
case []byte:
|
case []byte:
|
||||||
field.ReflectValueOf(value).SetString(string(data))
|
field.ReflectValueOf(ctx, value).SetString(string(data))
|
||||||
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
|
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
|
||||||
field.ReflectValueOf(value).SetString(utils.ToString(data))
|
field.ReflectValueOf(ctx, value).SetString(utils.ToString(data))
|
||||||
case float64, float32:
|
case float64, float32:
|
||||||
field.ReflectValueOf(value).SetString(fmt.Sprintf("%."+strconv.Itoa(field.Precision)+"f", data))
|
field.ReflectValueOf(ctx, value).SetString(fmt.Sprintf("%."+strconv.Itoa(field.Precision)+"f", data))
|
||||||
default:
|
default:
|
||||||
return fallbackSetter(value, v, field.Set)
|
return fallbackSetter(ctx, value, v, field.Set)
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -732,41 +784,49 @@ func (field *Field) setupValuerAndSetter() {
|
|||||||
fieldValue := reflect.New(field.FieldType)
|
fieldValue := reflect.New(field.FieldType)
|
||||||
switch fieldValue.Elem().Interface().(type) {
|
switch fieldValue.Elem().Interface().(type) {
|
||||||
case time.Time:
|
case time.Time:
|
||||||
field.Set = func(value reflect.Value, v interface{}) error {
|
field.Set = func(ctx context.Context, value reflect.Value, v interface{}) error {
|
||||||
switch data := v.(type) {
|
switch data := v.(type) {
|
||||||
|
case **time.Time:
|
||||||
|
if data != nil && *data != nil {
|
||||||
|
field.Set(ctx, value, *data)
|
||||||
|
}
|
||||||
case time.Time:
|
case time.Time:
|
||||||
field.ReflectValueOf(value).Set(reflect.ValueOf(v))
|
field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(v))
|
||||||
case *time.Time:
|
case *time.Time:
|
||||||
if data != nil {
|
if data != nil {
|
||||||
field.ReflectValueOf(value).Set(reflect.ValueOf(data).Elem())
|
field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(data).Elem())
|
||||||
} else {
|
} else {
|
||||||
field.ReflectValueOf(value).Set(reflect.ValueOf(time.Time{}))
|
field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(time.Time{}))
|
||||||
}
|
}
|
||||||
case string:
|
case string:
|
||||||
if t, err := now.Parse(data); err == nil {
|
if t, err := now.Parse(data); err == nil {
|
||||||
field.ReflectValueOf(value).Set(reflect.ValueOf(t))
|
field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(t))
|
||||||
} else {
|
} else {
|
||||||
return fmt.Errorf("failed to set string %v to time.Time field %s, failed to parse it as time, got error %v", v, field.Name, err)
|
return fmt.Errorf("failed to set string %v to time.Time field %s, failed to parse it as time, got error %v", v, field.Name, err)
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
return fallbackSetter(value, v, field.Set)
|
return fallbackSetter(ctx, value, v, field.Set)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
case *time.Time:
|
case *time.Time:
|
||||||
field.Set = func(value reflect.Value, v interface{}) error {
|
field.Set = func(ctx context.Context, value reflect.Value, v interface{}) error {
|
||||||
switch data := v.(type) {
|
switch data := v.(type) {
|
||||||
|
case **time.Time:
|
||||||
|
if data != nil {
|
||||||
|
field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(*data))
|
||||||
|
}
|
||||||
case time.Time:
|
case time.Time:
|
||||||
fieldValue := field.ReflectValueOf(value)
|
fieldValue := field.ReflectValueOf(ctx, value)
|
||||||
if fieldValue.IsNil() {
|
if fieldValue.IsNil() {
|
||||||
fieldValue.Set(reflect.New(field.FieldType.Elem()))
|
fieldValue.Set(reflect.New(field.FieldType.Elem()))
|
||||||
}
|
}
|
||||||
fieldValue.Elem().Set(reflect.ValueOf(v))
|
fieldValue.Elem().Set(reflect.ValueOf(v))
|
||||||
case *time.Time:
|
case *time.Time:
|
||||||
field.ReflectValueOf(value).Set(reflect.ValueOf(v))
|
field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(v))
|
||||||
case string:
|
case string:
|
||||||
if t, err := now.Parse(data); err == nil {
|
if t, err := now.Parse(data); err == nil {
|
||||||
fieldValue := field.ReflectValueOf(value)
|
fieldValue := field.ReflectValueOf(ctx, value)
|
||||||
if fieldValue.IsNil() {
|
if fieldValue.IsNil() {
|
||||||
if v == "" {
|
if v == "" {
|
||||||
return nil
|
return nil
|
||||||
@ -778,27 +838,27 @@ func (field *Field) setupValuerAndSetter() {
|
|||||||
return fmt.Errorf("failed to set string %v to time.Time field %s, failed to parse it as time, got error %v", v, field.Name, err)
|
return fmt.Errorf("failed to set string %v to time.Time field %s, failed to parse it as time, got error %v", v, field.Name, err)
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
return fallbackSetter(value, v, field.Set)
|
return fallbackSetter(ctx, value, v, field.Set)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
if _, ok := fieldValue.Elem().Interface().(sql.Scanner); ok {
|
if _, ok := fieldValue.Elem().Interface().(sql.Scanner); ok {
|
||||||
// pointer scanner
|
// pointer scanner
|
||||||
field.Set = func(value reflect.Value, v interface{}) (err error) {
|
field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) {
|
||||||
reflectV := reflect.ValueOf(v)
|
reflectV := reflect.ValueOf(v)
|
||||||
if !reflectV.IsValid() {
|
if !reflectV.IsValid() {
|
||||||
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
|
field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem())
|
||||||
} else if reflectV.Type().AssignableTo(field.FieldType) {
|
} else if reflectV.Type().AssignableTo(field.FieldType) {
|
||||||
field.ReflectValueOf(value).Set(reflectV)
|
field.ReflectValueOf(ctx, value).Set(reflectV)
|
||||||
} else if reflectV.Kind() == reflect.Ptr {
|
} else if reflectV.Kind() == reflect.Ptr {
|
||||||
if reflectV.IsNil() || !reflectV.IsValid() {
|
if reflectV.IsNil() || !reflectV.IsValid() {
|
||||||
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
|
field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem())
|
||||||
} else {
|
} else {
|
||||||
return field.Set(value, reflectV.Elem().Interface())
|
return field.Set(ctx, value, reflectV.Elem().Interface())
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
fieldValue := field.ReflectValueOf(value)
|
fieldValue := field.ReflectValueOf(ctx, value)
|
||||||
if fieldValue.IsNil() {
|
if fieldValue.IsNil() {
|
||||||
fieldValue.Set(reflect.New(field.FieldType.Elem()))
|
fieldValue.Set(reflect.New(field.FieldType.Elem()))
|
||||||
}
|
}
|
||||||
@ -813,32 +873,80 @@ func (field *Field) setupValuerAndSetter() {
|
|||||||
}
|
}
|
||||||
} else if _, ok := fieldValue.Interface().(sql.Scanner); ok {
|
} else if _, ok := fieldValue.Interface().(sql.Scanner); ok {
|
||||||
// struct scanner
|
// struct scanner
|
||||||
field.Set = func(value reflect.Value, v interface{}) (err error) {
|
field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) {
|
||||||
reflectV := reflect.ValueOf(v)
|
reflectV := reflect.ValueOf(v)
|
||||||
if !reflectV.IsValid() {
|
if !reflectV.IsValid() {
|
||||||
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
|
field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem())
|
||||||
} else if reflectV.Type().AssignableTo(field.FieldType) {
|
} else if reflectV.Type().AssignableTo(field.FieldType) {
|
||||||
field.ReflectValueOf(value).Set(reflectV)
|
field.ReflectValueOf(ctx, value).Set(reflectV)
|
||||||
} else if reflectV.Kind() == reflect.Ptr {
|
} else if reflectV.Kind() == reflect.Ptr {
|
||||||
if reflectV.IsNil() || !reflectV.IsValid() {
|
if reflectV.IsNil() || !reflectV.IsValid() {
|
||||||
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
|
field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem())
|
||||||
} else {
|
} else {
|
||||||
return field.Set(value, reflectV.Elem().Interface())
|
return field.Set(ctx, value, reflectV.Elem().Interface())
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if valuer, ok := v.(driver.Valuer); ok {
|
if valuer, ok := v.(driver.Valuer); ok {
|
||||||
v, _ = valuer.Value()
|
v, _ = valuer.Value()
|
||||||
}
|
}
|
||||||
|
|
||||||
err = field.ReflectValueOf(value).Addr().Interface().(sql.Scanner).Scan(v)
|
err = field.ReflectValueOf(ctx, value).Addr().Interface().(sql.Scanner).Scan(v)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
field.Set = func(value reflect.Value, v interface{}) (err error) {
|
field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) {
|
||||||
return fallbackSetter(value, v, field.Set)
|
return fallbackSetter(ctx, value, v, field.Set)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if field.Serializer != nil {
|
||||||
|
var (
|
||||||
|
oldFieldSetter = field.Set
|
||||||
|
sameElemType bool
|
||||||
|
sameType = field.FieldType == reflect.ValueOf(field.Serializer).Type()
|
||||||
|
)
|
||||||
|
|
||||||
|
if reflect.ValueOf(field.Serializer).Kind() == reflect.Ptr {
|
||||||
|
sameElemType = field.FieldType == reflect.ValueOf(field.Serializer).Type().Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) {
|
||||||
|
if s, ok := v.(*serializer); ok {
|
||||||
|
if s.fieldValue != nil {
|
||||||
|
err = oldFieldSetter(ctx, value, s.fieldValue)
|
||||||
|
} else if err = s.Serializer.Scan(ctx, field, value, s.value); err == nil {
|
||||||
|
if sameElemType {
|
||||||
|
field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(s.Serializer).Elem())
|
||||||
|
s.Serializer = reflect.New(reflect.Indirect(reflect.ValueOf(field.Serializer)).Type()).Interface().(SerializerInterface)
|
||||||
|
} else if sameType {
|
||||||
|
field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(s.Serializer))
|
||||||
|
s.Serializer = reflect.New(reflect.Indirect(reflect.ValueOf(field.Serializer)).Type()).Interface().(SerializerInterface)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
err = oldFieldSetter(ctx, value, v)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (field *Field) setupNewValuePool() {
|
||||||
|
if field.Serializer != nil {
|
||||||
|
field.NewValuePool = &sync.Pool{
|
||||||
|
New: func() interface{} {
|
||||||
|
return &serializer{
|
||||||
|
Field: field,
|
||||||
|
Serializer: reflect.New(reflect.Indirect(reflect.ValueOf(field.Serializer)).Type()).Interface().(SerializerInterface),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if field.NewValuePool == nil {
|
||||||
|
field.NewValuePool = poolInitializer(reflect.PtrTo(field.IndirectFieldType))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package schema_test
|
package schema_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"reflect"
|
"reflect"
|
||||||
"sync"
|
"sync"
|
||||||
@ -57,7 +58,7 @@ func TestFieldValuerAndSetter(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for k, v := range newValues {
|
for k, v := range newValues {
|
||||||
if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil {
|
if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil {
|
||||||
t.Errorf("no error should happen when assign value to field %v, but got %v", k, err)
|
t.Errorf("no error should happen when assign value to field %v, but got %v", k, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -80,7 +81,7 @@ func TestFieldValuerAndSetter(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for k, v := range newValues2 {
|
for k, v := range newValues2 {
|
||||||
if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil {
|
if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil {
|
||||||
t.Errorf("no error should happen when assign value to field %v, but got %v", k, err)
|
t.Errorf("no error should happen when assign value to field %v, but got %v", k, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -132,7 +133,7 @@ func TestPointerFieldValuerAndSetter(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for k, v := range newValues {
|
for k, v := range newValues {
|
||||||
if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil {
|
if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil {
|
||||||
t.Errorf("no error should happen when assign value to field %v, but got %v", k, err)
|
t.Errorf("no error should happen when assign value to field %v, but got %v", k, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -151,7 +152,7 @@ func TestPointerFieldValuerAndSetter(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for k, v := range newValues2 {
|
for k, v := range newValues2 {
|
||||||
if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil {
|
if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil {
|
||||||
t.Errorf("no error should happen when assign value to field %v, but got %v", k, err)
|
t.Errorf("no error should happen when assign value to field %v, but got %v", k, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -202,7 +203,7 @@ func TestAdvancedDataTypeValuerAndSetter(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for k, v := range newValues {
|
for k, v := range newValues {
|
||||||
if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil {
|
if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil {
|
||||||
t.Errorf("no error should happen when assign value to field %v, but got %v", k, err)
|
t.Errorf("no error should happen when assign value to field %v, but got %v", k, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -219,7 +220,7 @@ func TestAdvancedDataTypeValuerAndSetter(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for k, v := range newValues2 {
|
for k, v := range newValues2 {
|
||||||
if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil {
|
if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil {
|
||||||
t.Errorf("no error should happen when assign value to field %v, but got %v", k, err)
|
t.Errorf("no error should happen when assign value to field %v, but got %v", k, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package schema
|
package schema
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"sort"
|
"sort"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
@ -31,7 +32,12 @@ func (schema *Schema) ParseIndexes() map[string]Index {
|
|||||||
|
|
||||||
for _, field := range schema.Fields {
|
for _, field := range schema.Fields {
|
||||||
if field.TagSettings["INDEX"] != "" || field.TagSettings["UNIQUEINDEX"] != "" {
|
if field.TagSettings["INDEX"] != "" || field.TagSettings["UNIQUEINDEX"] != "" {
|
||||||
for _, index := range parseFieldIndexes(field) {
|
fieldIndexes, err := parseFieldIndexes(field)
|
||||||
|
if err != nil {
|
||||||
|
schema.err = err
|
||||||
|
break
|
||||||
|
}
|
||||||
|
for _, index := range fieldIndexes {
|
||||||
idx := indexes[index.Name]
|
idx := indexes[index.Name]
|
||||||
idx.Name = index.Name
|
idx.Name = index.Name
|
||||||
if idx.Class == "" {
|
if idx.Class == "" {
|
||||||
@ -82,7 +88,7 @@ func (schema *Schema) LookIndex(name string) *Index {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseFieldIndexes(field *Field) (indexes []Index) {
|
func parseFieldIndexes(field *Field) (indexes []Index, err error) {
|
||||||
for _, value := range strings.Split(field.Tag.Get("gorm"), ";") {
|
for _, value := range strings.Split(field.Tag.Get("gorm"), ";") {
|
||||||
if value != "" {
|
if value != "" {
|
||||||
v := strings.Split(value, ":")
|
v := strings.Split(value, ":")
|
||||||
@ -92,7 +98,8 @@ func parseFieldIndexes(field *Field) (indexes []Index) {
|
|||||||
name string
|
name string
|
||||||
tag = strings.Join(v[1:], ":")
|
tag = strings.Join(v[1:], ":")
|
||||||
idx = strings.Index(tag, ",")
|
idx = strings.Index(tag, ",")
|
||||||
settings = ParseTagSetting(tag, ",")
|
tagSetting = strings.Join(strings.Split(tag, ",")[1:], ",")
|
||||||
|
settings = ParseTagSetting(tagSetting, ",")
|
||||||
length, _ = strconv.Atoi(settings["LENGTH"])
|
length, _ = strconv.Atoi(settings["LENGTH"])
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -105,7 +112,20 @@ func parseFieldIndexes(field *Field) (indexes []Index) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if name == "" {
|
if name == "" {
|
||||||
name = field.Schema.namer.IndexName(field.Schema.Table, field.Name)
|
subName := field.Name
|
||||||
|
const key = "COMPOSITE"
|
||||||
|
if composite, found := settings[key]; found {
|
||||||
|
if len(composite) == 0 || composite == key {
|
||||||
|
err = fmt.Errorf(
|
||||||
|
"The composite tag of %s.%s cannot be empty",
|
||||||
|
field.Schema.Name,
|
||||||
|
field.Name)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
subName = composite
|
||||||
|
}
|
||||||
|
name = field.Schema.namer.IndexName(
|
||||||
|
field.Schema.Table, subName)
|
||||||
}
|
}
|
||||||
|
|
||||||
if (k == "UNIQUEINDEX") || settings["UNIQUE"] != "" {
|
if (k == "UNIQUEINDEX") || settings["UNIQUE"] != "" {
|
||||||
@ -137,5 +157,6 @@ func parseFieldIndexes(field *Field) (indexes []Index) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = nil
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -18,6 +18,37 @@ type UserIndex struct {
|
|||||||
Age int64 `gorm:"index:profile,expression:ABS(age),option:WITH PARSER parser_name"`
|
Age int64 `gorm:"index:profile,expression:ABS(age),option:WITH PARSER parser_name"`
|
||||||
OID int64 `gorm:"index:idx_id;index:idx_oid,unique"`
|
OID int64 `gorm:"index:idx_id;index:idx_oid,unique"`
|
||||||
MemberNumber string `gorm:"index:idx_id,priority:1"`
|
MemberNumber string `gorm:"index:idx_id,priority:1"`
|
||||||
|
Name7 string `gorm:"index:type"`
|
||||||
|
|
||||||
|
// Composite Index: Flattened structure.
|
||||||
|
Data0A string `gorm:"index:,composite:comp_id0"`
|
||||||
|
Data0B string `gorm:"index:,composite:comp_id0"`
|
||||||
|
|
||||||
|
// Composite Index: Nested structure.
|
||||||
|
Data1A string `gorm:"index:,composite:comp_id1"`
|
||||||
|
CompIdxLevel1C
|
||||||
|
|
||||||
|
// Composite Index: Unique and priority.
|
||||||
|
Data2A string `gorm:"index:,unique,composite:comp_id2,priority:2"`
|
||||||
|
CompIdxLevel2C
|
||||||
|
}
|
||||||
|
|
||||||
|
type CompIdxLevel1C struct {
|
||||||
|
CompIdxLevel1B
|
||||||
|
Data1C string `gorm:"index:,composite:comp_id1"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type CompIdxLevel1B struct {
|
||||||
|
Data1B string `gorm:"index:,composite:comp_id1"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type CompIdxLevel2C struct {
|
||||||
|
CompIdxLevel2B
|
||||||
|
Data2C string `gorm:"index:,unique,composite:comp_id2,priority:1"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type CompIdxLevel2B struct {
|
||||||
|
Data2B string `gorm:"index:,unique,composite:comp_id2,priority:3"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParseIndex(t *testing.T) {
|
func TestParseIndex(t *testing.T) {
|
||||||
@ -78,6 +109,41 @@ func TestParseIndex(t *testing.T) {
|
|||||||
Class: "UNIQUE",
|
Class: "UNIQUE",
|
||||||
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "OID"}}},
|
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "OID"}}},
|
||||||
},
|
},
|
||||||
|
"type": {
|
||||||
|
Name: "type",
|
||||||
|
Type: "",
|
||||||
|
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name7"}}},
|
||||||
|
},
|
||||||
|
"idx_user_indices_comp_id0": {
|
||||||
|
Name: "idx_user_indices_comp_id0",
|
||||||
|
Type: "",
|
||||||
|
Fields: []schema.IndexOption{{
|
||||||
|
Field: &schema.Field{Name: "Data0A"},
|
||||||
|
}, {
|
||||||
|
Field: &schema.Field{Name: "Data0B"},
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
"idx_user_indices_comp_id1": {
|
||||||
|
Name: "idx_user_indices_comp_id1",
|
||||||
|
Fields: []schema.IndexOption{{
|
||||||
|
Field: &schema.Field{Name: "Data1A"},
|
||||||
|
}, {
|
||||||
|
Field: &schema.Field{Name: "Data1B"},
|
||||||
|
}, {
|
||||||
|
Field: &schema.Field{Name: "Data1C"},
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
"idx_user_indices_comp_id2": {
|
||||||
|
Name: "idx_user_indices_comp_id2",
|
||||||
|
Class: "UNIQUE",
|
||||||
|
Fields: []schema.IndexOption{{
|
||||||
|
Field: &schema.Field{Name: "Data2C"},
|
||||||
|
}, {
|
||||||
|
Field: &schema.Field{Name: "Data2A"},
|
||||||
|
}, {
|
||||||
|
Field: &schema.Field{Name: "Data2B"},
|
||||||
|
}},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
indices := user.ParseIndexes()
|
indices := user.ParseIndexes()
|
||||||
|
@ -4,22 +4,33 @@ import (
|
|||||||
"gorm.io/gorm/clause"
|
"gorm.io/gorm/clause"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// GormDataTypeInterface gorm data type interface
|
||||||
type GormDataTypeInterface interface {
|
type GormDataTypeInterface interface {
|
||||||
GormDataType() string
|
GormDataType() string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// FieldNewValuePool field new scan value pool
|
||||||
|
type FieldNewValuePool interface {
|
||||||
|
Get() interface{}
|
||||||
|
Put(interface{})
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateClausesInterface create clauses interface
|
||||||
type CreateClausesInterface interface {
|
type CreateClausesInterface interface {
|
||||||
CreateClauses(*Field) []clause.Interface
|
CreateClauses(*Field) []clause.Interface
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// QueryClausesInterface query clauses interface
|
||||||
type QueryClausesInterface interface {
|
type QueryClausesInterface interface {
|
||||||
QueryClauses(*Field) []clause.Interface
|
QueryClauses(*Field) []clause.Interface
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UpdateClausesInterface update clauses interface
|
||||||
type UpdateClausesInterface interface {
|
type UpdateClausesInterface interface {
|
||||||
UpdateClauses(*Field) []clause.Interface
|
UpdateClauses(*Field) []clause.Interface
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DeleteClausesInterface delete clauses interface
|
||||||
type DeleteClausesInterface interface {
|
type DeleteClausesInterface interface {
|
||||||
DeleteClauses(*Field) []clause.Interface
|
DeleteClauses(*Field) []clause.Interface
|
||||||
}
|
}
|
||||||
|
@ -3,7 +3,6 @@ package schema
|
|||||||
import (
|
import (
|
||||||
"crypto/sha1"
|
"crypto/sha1"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"fmt"
|
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
"unicode/utf8"
|
"unicode/utf8"
|
||||||
@ -86,16 +85,16 @@ func (ns NamingStrategy) IndexName(table, column string) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (ns NamingStrategy) formatName(prefix, table, name string) string {
|
func (ns NamingStrategy) formatName(prefix, table, name string) string {
|
||||||
formattedName := strings.Replace(strings.Join([]string{
|
formattedName := strings.ReplaceAll(strings.Join([]string{
|
||||||
prefix, table, name,
|
prefix, table, name,
|
||||||
}, "_"), ".", "_", -1)
|
}, "_"), ".", "_")
|
||||||
|
|
||||||
if utf8.RuneCountInString(formattedName) > 64 {
|
if utf8.RuneCountInString(formattedName) > 64 {
|
||||||
h := sha1.New()
|
h := sha1.New()
|
||||||
h.Write([]byte(formattedName))
|
h.Write([]byte(formattedName))
|
||||||
bs := h.Sum(nil)
|
bs := h.Sum(nil)
|
||||||
|
|
||||||
formattedName = fmt.Sprintf("%v%v%v", prefix, table, name)[0:56] + hex.EncodeToString(bs)[:8]
|
formattedName = formattedName[0:56] + hex.EncodeToString(bs)[:8]
|
||||||
}
|
}
|
||||||
return formattedName
|
return formattedName
|
||||||
}
|
}
|
||||||
@ -120,7 +119,13 @@ func (ns NamingStrategy) toDBName(name string) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if ns.NameReplacer != nil {
|
if ns.NameReplacer != nil {
|
||||||
name = ns.NameReplacer.Replace(name)
|
tmpName := ns.NameReplacer.Replace(name)
|
||||||
|
|
||||||
|
if tmpName == "" {
|
||||||
|
return name
|
||||||
|
}
|
||||||
|
|
||||||
|
name = tmpName
|
||||||
}
|
}
|
||||||
|
|
||||||
if ns.NoLowerCase {
|
if ns.NoLowerCase {
|
||||||
@ -168,7 +173,7 @@ func (ns NamingStrategy) toDBName(name string) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (ns NamingStrategy) toSchemaName(name string) string {
|
func (ns NamingStrategy) toSchemaName(name string) string {
|
||||||
result := strings.Replace(strings.Title(strings.Replace(name, "_", " ", -1)), " ", "", -1)
|
result := strings.ReplaceAll(strings.Title(strings.ReplaceAll(name, "_", " ")), " ", "")
|
||||||
for _, initialism := range commonInitialisms {
|
for _, initialism := range commonInitialisms {
|
||||||
result = regexp.MustCompile(strings.Title(strings.ToLower(initialism))+"([A-Z]|$|_)").ReplaceAllString(result, initialism+"$1")
|
result = regexp.MustCompile(strings.Title(strings.ToLower(initialism))+"([A-Z]|$|_)").ReplaceAllString(result, initialism+"$1")
|
||||||
}
|
}
|
||||||
|
@ -193,7 +193,18 @@ func TestFormatNameWithStringLongerThan64Characters(t *testing.T) {
|
|||||||
ns := NamingStrategy{}
|
ns := NamingStrategy{}
|
||||||
|
|
||||||
formattedName := ns.formatName("prefix", "table", "thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVeryLongString")
|
formattedName := ns.formatName("prefix", "table", "thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVeryLongString")
|
||||||
if formattedName != "prefixtablethisIsAVeryVeryVeryVeryVeryVeryVeryVeryVeryLo180f2c67" {
|
if formattedName != "prefix_table_thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVery180f2c67" {
|
||||||
t.Errorf("invalid formatted name generated, got %v", formattedName)
|
t.Errorf("invalid formatted name generated, got %v", formattedName)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestReplaceEmptyTableName(t *testing.T) {
|
||||||
|
ns := NamingStrategy{
|
||||||
|
SingularTable: true,
|
||||||
|
NameReplacer: strings.NewReplacer("Model", ""),
|
||||||
|
}
|
||||||
|
tableName := ns.TableName("Model")
|
||||||
|
if tableName != "Model" {
|
||||||
|
t.Errorf("invalid table name generated, got %v", tableName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
19
schema/pool.go
Normal file
19
schema/pool.go
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
package schema
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
// sync pools
|
||||||
|
var (
|
||||||
|
normalPool sync.Map
|
||||||
|
poolInitializer = func(reflectType reflect.Type) FieldNewValuePool {
|
||||||
|
v, _ := normalPool.LoadOrStore(reflectType, &sync.Pool{
|
||||||
|
New: func() interface{} {
|
||||||
|
return reflect.New(reflectType).Interface()
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return v.(FieldNewValuePool)
|
||||||
|
}
|
||||||
|
)
|
@ -1,6 +1,7 @@
|
|||||||
package schema
|
package schema
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
@ -234,7 +235,8 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
|
|||||||
Name: joinFieldName,
|
Name: joinFieldName,
|
||||||
PkgPath: ownField.StructField.PkgPath,
|
PkgPath: ownField.StructField.PkgPath,
|
||||||
Type: ownField.StructField.Type,
|
Type: ownField.StructField.Type,
|
||||||
Tag: removeSettingFromTag(ownField.StructField.Tag, "column", "autoincrement", "index", "unique", "uniqueindex"),
|
Tag: removeSettingFromTag(appendSettingFromTag(ownField.StructField.Tag, "primaryKey"),
|
||||||
|
"column", "autoincrement", "index", "unique", "uniqueindex"),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -257,7 +259,8 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
|
|||||||
Name: joinFieldName,
|
Name: joinFieldName,
|
||||||
PkgPath: relField.StructField.PkgPath,
|
PkgPath: relField.StructField.PkgPath,
|
||||||
Type: relField.StructField.Type,
|
Type: relField.StructField.Type,
|
||||||
Tag: removeSettingFromTag(relField.StructField.Tag, "column", "autoincrement", "index", "unique", "uniqueindex"),
|
Tag: removeSettingFromTag(appendSettingFromTag(relField.StructField.Tag, "primaryKey"),
|
||||||
|
"column", "autoincrement", "index", "unique", "uniqueindex"),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -415,6 +418,10 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu
|
|||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
var primaryFields []*Field
|
var primaryFields []*Field
|
||||||
|
var primarySchemaName = primarySchema.Name
|
||||||
|
if primarySchemaName == "" {
|
||||||
|
primarySchemaName = relation.FieldSchema.Name
|
||||||
|
}
|
||||||
|
|
||||||
if len(relation.primaryKeys) > 0 {
|
if len(relation.primaryKeys) > 0 {
|
||||||
for _, primaryKey := range relation.primaryKeys {
|
for _, primaryKey := range relation.primaryKeys {
|
||||||
@ -427,7 +434,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, primaryField := range primaryFields {
|
for _, primaryField := range primaryFields {
|
||||||
lookUpName := primarySchema.Name + primaryField.Name
|
lookUpName := primarySchemaName + primaryField.Name
|
||||||
if gl == guessBelongs {
|
if gl == guessBelongs {
|
||||||
lookUpName = field.Name + primaryField.Name
|
lookUpName = field.Name + primaryField.Name
|
||||||
}
|
}
|
||||||
@ -576,7 +583,7 @@ func (rel *Relationship) ParseConstraint() *Constraint {
|
|||||||
return &constraint
|
return &constraint
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rel *Relationship) ToQueryConditions(reflectValue reflect.Value) (conds []clause.Expression) {
|
func (rel *Relationship) ToQueryConditions(ctx context.Context, reflectValue reflect.Value) (conds []clause.Expression) {
|
||||||
table := rel.FieldSchema.Table
|
table := rel.FieldSchema.Table
|
||||||
foreignFields := []*Field{}
|
foreignFields := []*Field{}
|
||||||
relForeignKeys := []string{}
|
relForeignKeys := []string{}
|
||||||
@ -616,7 +623,7 @@ func (rel *Relationship) ToQueryConditions(reflectValue reflect.Value) (conds []
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
_, foreignValues := GetIdentityFieldValuesMap(reflectValue, foreignFields)
|
_, foreignValues := GetIdentityFieldValuesMap(ctx, reflectValue, foreignFields)
|
||||||
column, values := ToQueryValues(table, relForeignKeys, foreignValues)
|
column, values := ToQueryValues(table, relForeignKeys, foreignValues)
|
||||||
|
|
||||||
conds = append(conds, clause.IN{Column: column, Values: values})
|
conds = append(conds, clause.IN{Column: column, Values: values})
|
||||||
|
@ -491,6 +491,26 @@ func TestEmbeddedRelation(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestVariableRelation(t *testing.T) {
|
||||||
|
var result struct {
|
||||||
|
User
|
||||||
|
}
|
||||||
|
|
||||||
|
checkStructRelation(t, &result, Relation{
|
||||||
|
Name: "Account", Type: schema.HasOne, Schema: "", FieldSchema: "Account",
|
||||||
|
References: []Reference{
|
||||||
|
{"ID", "", "UserID", "Account", "", true},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
checkStructRelation(t, &result, Relation{
|
||||||
|
Name: "Company", Type: schema.BelongsTo, Schema: "", FieldSchema: "Company",
|
||||||
|
References: []Reference{
|
||||||
|
{"ID", "Company", "CompanyID", "", "", false},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestSameForeignKey(t *testing.T) {
|
func TestSameForeignKey(t *testing.T) {
|
||||||
type UserAux struct {
|
type UserAux struct {
|
||||||
gorm.Model
|
gorm.Model
|
||||||
@ -576,3 +596,39 @@ func TestHasManySameForeignKey(t *testing.T) {
|
|||||||
References: []Reference{{"ID", "User", "UserRefer", "Profile", "", true}},
|
References: []Reference{{"ID", "User", "UserRefer", "Profile", "", true}},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type Author struct {
|
||||||
|
gorm.Model
|
||||||
|
}
|
||||||
|
|
||||||
|
type Book struct {
|
||||||
|
gorm.Model
|
||||||
|
Author Author
|
||||||
|
AuthorID uint
|
||||||
|
}
|
||||||
|
|
||||||
|
func (Book) TableName() string {
|
||||||
|
return "my_schema.a_very_very_very_very_very_very_very_very_long_table_name"
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseConstraintNameWithSchemaQualifiedLongTableName(t *testing.T) {
|
||||||
|
s, err := schema.Parse(
|
||||||
|
&Book{},
|
||||||
|
&sync.Map{},
|
||||||
|
schema.NamingStrategy{},
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to parse schema")
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedConstraintName := "fk_my_schema_a_very_very_very_very_very_very_very_very_l4db13eec"
|
||||||
|
constraint := s.Relationships.Relations["Author"].ParseConstraint()
|
||||||
|
|
||||||
|
if constraint.Name != expectedConstraintName {
|
||||||
|
t.Fatalf(
|
||||||
|
"expected constraint name %s, got %s",
|
||||||
|
expectedConstraintName,
|
||||||
|
constraint.Name,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package schema_test
|
package schema_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
@ -203,7 +204,7 @@ func checkSchemaRelation(t *testing.T, s *schema.Schema, relation Relation) {
|
|||||||
func checkField(t *testing.T, s *schema.Schema, value reflect.Value, values map[string]interface{}) {
|
func checkField(t *testing.T, s *schema.Schema, value reflect.Value, values map[string]interface{}) {
|
||||||
for k, v := range values {
|
for k, v := range values {
|
||||||
t.Run("CheckField/"+k, func(t *testing.T) {
|
t.Run("CheckField/"+k, func(t *testing.T) {
|
||||||
fv, _ := s.FieldsByDBName[k].ValueOf(value)
|
fv, _ := s.FieldsByDBName[k].ValueOf(context.Background(), value)
|
||||||
tests.AssertEqual(t, v, fv)
|
tests.AssertEqual(t, v, fv)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
157
schema/serializer.go
Normal file
157
schema/serializer.go
Normal file
@ -0,0 +1,157 @@
|
|||||||
|
package schema
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"database/sql/driver"
|
||||||
|
"encoding/gob"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
var serializerMap = sync.Map{}
|
||||||
|
|
||||||
|
// RegisterSerializer register serializer
|
||||||
|
func RegisterSerializer(name string, serializer SerializerInterface) {
|
||||||
|
serializerMap.Store(strings.ToLower(name), serializer)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSerializer get serializer
|
||||||
|
func GetSerializer(name string) (serializer SerializerInterface, ok bool) {
|
||||||
|
v, ok := serializerMap.Load(strings.ToLower(name))
|
||||||
|
if ok {
|
||||||
|
serializer, ok = v.(SerializerInterface)
|
||||||
|
}
|
||||||
|
return serializer, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
RegisterSerializer("json", JSONSerializer{})
|
||||||
|
RegisterSerializer("unixtime", UnixSecondSerializer{})
|
||||||
|
RegisterSerializer("gob", GobSerializer{})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Serializer field value serializer
|
||||||
|
type serializer struct {
|
||||||
|
Field *Field
|
||||||
|
Serializer SerializerInterface
|
||||||
|
SerializeValuer SerializerValuerInterface
|
||||||
|
Destination reflect.Value
|
||||||
|
Context context.Context
|
||||||
|
value interface{}
|
||||||
|
fieldValue interface{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Scan implements sql.Scanner interface
|
||||||
|
func (s *serializer) Scan(value interface{}) error {
|
||||||
|
s.value = value
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Value implements driver.Valuer interface
|
||||||
|
func (s serializer) Value() (driver.Value, error) {
|
||||||
|
return s.SerializeValuer.Value(s.Context, s.Field, s.Destination, s.fieldValue)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SerializerInterface serializer interface
|
||||||
|
type SerializerInterface interface {
|
||||||
|
Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) error
|
||||||
|
SerializerValuerInterface
|
||||||
|
}
|
||||||
|
|
||||||
|
// SerializerValuerInterface serializer valuer interface
|
||||||
|
type SerializerValuerInterface interface {
|
||||||
|
Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (interface{}, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// JSONSerializer json serializer
|
||||||
|
type JSONSerializer struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
// Scan implements serializer interface
|
||||||
|
func (JSONSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) (err error) {
|
||||||
|
fieldValue := reflect.New(field.FieldType)
|
||||||
|
|
||||||
|
if dbValue != nil {
|
||||||
|
var bytes []byte
|
||||||
|
switch v := dbValue.(type) {
|
||||||
|
case []byte:
|
||||||
|
bytes = v
|
||||||
|
case string:
|
||||||
|
bytes = []byte(v)
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("failed to unmarshal JSONB value: %#v", dbValue)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = json.Unmarshal(bytes, fieldValue.Interface())
|
||||||
|
}
|
||||||
|
|
||||||
|
field.ReflectValueOf(ctx, dst).Set(fieldValue.Elem())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Value implements serializer interface
|
||||||
|
func (JSONSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) {
|
||||||
|
result, err := json.Marshal(fieldValue)
|
||||||
|
return string(result), err
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnixSecondSerializer json serializer
|
||||||
|
type UnixSecondSerializer struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
// Scan implements serializer interface
|
||||||
|
func (UnixSecondSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) (err error) {
|
||||||
|
t := sql.NullTime{}
|
||||||
|
if err = t.Scan(dbValue); err == nil && t.Valid {
|
||||||
|
err = field.Set(ctx, dst, t.Time.Unix())
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Value implements serializer interface
|
||||||
|
func (UnixSecondSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (result interface{}, err error) {
|
||||||
|
switch v := fieldValue.(type) {
|
||||||
|
case int64, int, uint, uint64, int32, uint32, int16, uint16, *int64, *int, *uint, *uint64, *int32, *uint32, *int16, *uint16:
|
||||||
|
result = time.Unix(reflect.Indirect(reflect.ValueOf(v)).Int(), 0)
|
||||||
|
default:
|
||||||
|
err = fmt.Errorf("invalid field type %#v for UnixSecondSerializer, only int, uint supported", v)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// GobSerializer gob serializer
|
||||||
|
type GobSerializer struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
// Scan implements serializer interface
|
||||||
|
func (GobSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) (err error) {
|
||||||
|
fieldValue := reflect.New(field.FieldType)
|
||||||
|
|
||||||
|
if dbValue != nil {
|
||||||
|
var bytesValue []byte
|
||||||
|
switch v := dbValue.(type) {
|
||||||
|
case []byte:
|
||||||
|
bytesValue = v
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("failed to unmarshal gob value: %#v", dbValue)
|
||||||
|
}
|
||||||
|
decoder := gob.NewDecoder(bytes.NewBuffer(bytesValue))
|
||||||
|
err = decoder.Decode(fieldValue.Interface())
|
||||||
|
}
|
||||||
|
field.ReflectValueOf(ctx, dst).Set(fieldValue.Elem())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Value implements serializer interface
|
||||||
|
func (GobSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) {
|
||||||
|
buf := new(bytes.Buffer)
|
||||||
|
err := gob.NewEncoder(buf).Encode(fieldValue)
|
||||||
|
return buf.Bytes(), err
|
||||||
|
}
|
@ -1,6 +1,8 @@
|
|||||||
package schema
|
package schema
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
@ -58,14 +60,22 @@ func removeSettingFromTag(tag reflect.StructTag, names ...string) reflect.Struct
|
|||||||
return tag
|
return tag
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func appendSettingFromTag(tag reflect.StructTag, value string) reflect.StructTag {
|
||||||
|
t := tag.Get("gorm")
|
||||||
|
if strings.Contains(t, value) {
|
||||||
|
return tag
|
||||||
|
}
|
||||||
|
return reflect.StructTag(fmt.Sprintf(`gorm:"%s;%s"`, value, t))
|
||||||
|
}
|
||||||
|
|
||||||
// GetRelationsValues get relations's values from a reflect value
|
// GetRelationsValues get relations's values from a reflect value
|
||||||
func GetRelationsValues(reflectValue reflect.Value, rels []*Relationship) (reflectResults reflect.Value) {
|
func GetRelationsValues(ctx context.Context, reflectValue reflect.Value, rels []*Relationship) (reflectResults reflect.Value) {
|
||||||
for _, rel := range rels {
|
for _, rel := range rels {
|
||||||
reflectResults = reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(rel.FieldSchema.ModelType)), 0, 1)
|
reflectResults = reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(rel.FieldSchema.ModelType)), 0, 1)
|
||||||
|
|
||||||
appendToResults := func(value reflect.Value) {
|
appendToResults := func(value reflect.Value) {
|
||||||
if _, isZero := rel.Field.ValueOf(value); !isZero {
|
if _, isZero := rel.Field.ValueOf(ctx, value); !isZero {
|
||||||
result := reflect.Indirect(rel.Field.ReflectValueOf(value))
|
result := reflect.Indirect(rel.Field.ReflectValueOf(ctx, value))
|
||||||
switch result.Kind() {
|
switch result.Kind() {
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
reflectResults = reflect.Append(reflectResults, result.Addr())
|
reflectResults = reflect.Append(reflectResults, result.Addr())
|
||||||
@ -97,7 +107,7 @@ func GetRelationsValues(reflectValue reflect.Value, rels []*Relationship) (refle
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetIdentityFieldValuesMap get identity map from fields
|
// GetIdentityFieldValuesMap get identity map from fields
|
||||||
func GetIdentityFieldValuesMap(reflectValue reflect.Value, fields []*Field) (map[string][]reflect.Value, [][]interface{}) {
|
func GetIdentityFieldValuesMap(ctx context.Context, reflectValue reflect.Value, fields []*Field) (map[string][]reflect.Value, [][]interface{}) {
|
||||||
var (
|
var (
|
||||||
results = [][]interface{}{}
|
results = [][]interface{}{}
|
||||||
dataResults = map[string][]reflect.Value{}
|
dataResults = map[string][]reflect.Value{}
|
||||||
@ -110,7 +120,7 @@ func GetIdentityFieldValuesMap(reflectValue reflect.Value, fields []*Field) (map
|
|||||||
results = [][]interface{}{make([]interface{}, len(fields))}
|
results = [][]interface{}{make([]interface{}, len(fields))}
|
||||||
|
|
||||||
for idx, field := range fields {
|
for idx, field := range fields {
|
||||||
results[0][idx], zero = field.ValueOf(reflectValue)
|
results[0][idx], zero = field.ValueOf(ctx, reflectValue)
|
||||||
notZero = notZero || !zero
|
notZero = notZero || !zero
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -135,7 +145,7 @@ func GetIdentityFieldValuesMap(reflectValue reflect.Value, fields []*Field) (map
|
|||||||
fieldValues := make([]interface{}, len(fields))
|
fieldValues := make([]interface{}, len(fields))
|
||||||
notZero = false
|
notZero = false
|
||||||
for idx, field := range fields {
|
for idx, field := range fields {
|
||||||
fieldValues[idx], zero = field.ValueOf(elem)
|
fieldValues[idx], zero = field.ValueOf(ctx, elem)
|
||||||
notZero = notZero || !zero
|
notZero = notZero || !zero
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -155,12 +165,12 @@ func GetIdentityFieldValuesMap(reflectValue reflect.Value, fields []*Field) (map
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetIdentityFieldValuesMapFromValues get identity map from fields
|
// GetIdentityFieldValuesMapFromValues get identity map from fields
|
||||||
func GetIdentityFieldValuesMapFromValues(values []interface{}, fields []*Field) (map[string][]reflect.Value, [][]interface{}) {
|
func GetIdentityFieldValuesMapFromValues(ctx context.Context, values []interface{}, fields []*Field) (map[string][]reflect.Value, [][]interface{}) {
|
||||||
resultsMap := map[string][]reflect.Value{}
|
resultsMap := map[string][]reflect.Value{}
|
||||||
results := [][]interface{}{}
|
results := [][]interface{}{}
|
||||||
|
|
||||||
for _, v := range values {
|
for _, v := range values {
|
||||||
rm, rs := GetIdentityFieldValuesMap(reflect.Indirect(reflect.ValueOf(v)), fields)
|
rm, rs := GetIdentityFieldValuesMap(ctx, reflect.Indirect(reflect.ValueOf(v)), fields)
|
||||||
for k, v := range rm {
|
for k, v := range rm {
|
||||||
resultsMap[k] = append(resultsMap[k], v...)
|
resultsMap[k] = append(resultsMap[k], v...)
|
||||||
}
|
}
|
||||||
|
@ -104,10 +104,8 @@ func (sd SoftDeleteUpdateClause) MergeClause(*clause.Clause) {
|
|||||||
|
|
||||||
func (sd SoftDeleteUpdateClause) ModifyStatement(stmt *Statement) {
|
func (sd SoftDeleteUpdateClause) ModifyStatement(stmt *Statement) {
|
||||||
if stmt.SQL.Len() == 0 && !stmt.Statement.Unscoped {
|
if stmt.SQL.Len() == 0 && !stmt.Statement.Unscoped {
|
||||||
if _, ok := stmt.Clauses["WHERE"]; stmt.DB.AllowGlobalUpdate || ok {
|
|
||||||
SoftDeleteQueryClause(sd).ModifyStatement(stmt)
|
SoftDeleteQueryClause(sd).ModifyStatement(stmt)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (DeletedAt) DeleteClauses(f *schema.Field) []clause.Interface {
|
func (DeletedAt) DeleteClauses(f *schema.Field) []clause.Interface {
|
||||||
@ -135,7 +133,7 @@ func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) {
|
|||||||
stmt.SetColumn(sd.Field.DBName, curTime, true)
|
stmt.SetColumn(sd.Field.DBName, curTime, true)
|
||||||
|
|
||||||
if stmt.Schema != nil {
|
if stmt.Schema != nil {
|
||||||
_, queryValues := schema.GetIdentityFieldValuesMap(stmt.ReflectValue, stmt.Schema.PrimaryFields)
|
_, queryValues := schema.GetIdentityFieldValuesMap(stmt.Context, stmt.ReflectValue, stmt.Schema.PrimaryFields)
|
||||||
column, values := schema.ToQueryValues(stmt.Table, stmt.Schema.PrimaryFieldDBNames, queryValues)
|
column, values := schema.ToQueryValues(stmt.Table, stmt.Schema.PrimaryFieldDBNames, queryValues)
|
||||||
|
|
||||||
if len(values) > 0 {
|
if len(values) > 0 {
|
||||||
@ -143,7 +141,7 @@ func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if stmt.ReflectValue.CanAddr() && stmt.Dest != stmt.Model && stmt.Model != nil {
|
if stmt.ReflectValue.CanAddr() && stmt.Dest != stmt.Model && stmt.Model != nil {
|
||||||
_, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(stmt.Model), stmt.Schema.PrimaryFields)
|
_, queryValues = schema.GetIdentityFieldValuesMap(stmt.Context, reflect.ValueOf(stmt.Model), stmt.Schema.PrimaryFields)
|
||||||
column, values = schema.ToQueryValues(stmt.Table, stmt.Schema.PrimaryFieldDBNames, queryValues)
|
column, values = schema.ToQueryValues(stmt.Table, stmt.Schema.PrimaryFieldDBNames, queryValues)
|
||||||
|
|
||||||
if len(values) > 0 {
|
if len(values) > 0 {
|
||||||
@ -152,12 +150,7 @@ func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, ok := stmt.Clauses["WHERE"]; !stmt.DB.AllowGlobalUpdate && !ok {
|
|
||||||
stmt.DB.AddError(ErrMissingWhereClause)
|
|
||||||
} else {
|
|
||||||
SoftDeleteQueryClause(sd).ModifyStatement(stmt)
|
SoftDeleteQueryClause(sd).ModifyStatement(stmt)
|
||||||
}
|
|
||||||
|
|
||||||
stmt.AddClauseIfNotExists(clause.Update{})
|
stmt.AddClauseIfNotExists(clause.Update{})
|
||||||
stmt.Build(stmt.DB.Callback().Update().Clauses...)
|
stmt.Build(stmt.DB.Callback().Update().Clauses...)
|
||||||
}
|
}
|
||||||
|
37
statement.go
37
statement.go
@ -130,7 +130,7 @@ func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) {
|
|||||||
writer.WriteByte('(')
|
writer.WriteByte('(')
|
||||||
for idx, d := range v {
|
for idx, d := range v {
|
||||||
if idx > 0 {
|
if idx > 0 {
|
||||||
writer.WriteString(",")
|
writer.WriteByte(',')
|
||||||
}
|
}
|
||||||
stmt.QuoteTo(writer, d)
|
stmt.QuoteTo(writer, d)
|
||||||
}
|
}
|
||||||
@ -143,7 +143,7 @@ func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) {
|
|||||||
writer.WriteByte('(')
|
writer.WriteByte('(')
|
||||||
for idx, d := range v {
|
for idx, d := range v {
|
||||||
if idx > 0 {
|
if idx > 0 {
|
||||||
writer.WriteString(",")
|
writer.WriteByte(',')
|
||||||
}
|
}
|
||||||
stmt.DB.Dialector.QuoteTo(writer, d)
|
stmt.DB.Dialector.QuoteTo(writer, d)
|
||||||
}
|
}
|
||||||
@ -179,9 +179,7 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) {
|
|||||||
} else {
|
} else {
|
||||||
stmt.AddVar(writer, v.GormValue(stmt.Context, stmt.DB))
|
stmt.AddVar(writer, v.GormValue(stmt.Context, stmt.DB))
|
||||||
}
|
}
|
||||||
case clause.Expr:
|
case clause.Expression:
|
||||||
v.Build(stmt)
|
|
||||||
case *clause.Expr:
|
|
||||||
v.Build(stmt)
|
v.Build(stmt)
|
||||||
case driver.Valuer:
|
case driver.Valuer:
|
||||||
stmt.Vars = append(stmt.Vars, v)
|
stmt.Vars = append(stmt.Vars, v)
|
||||||
@ -314,6 +312,10 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
|
|||||||
case clause.Expression:
|
case clause.Expression:
|
||||||
conds = append(conds, v)
|
conds = append(conds, v)
|
||||||
case *DB:
|
case *DB:
|
||||||
|
for _, scope := range v.Statement.scopes {
|
||||||
|
v = scope(v)
|
||||||
|
}
|
||||||
|
|
||||||
if cs, ok := v.Statement.Clauses["WHERE"]; ok {
|
if cs, ok := v.Statement.Clauses["WHERE"]; ok {
|
||||||
if where, ok := cs.Expression.(clause.Where); ok {
|
if where, ok := cs.Expression.(clause.Where); ok {
|
||||||
if len(where.Exprs) == 1 {
|
if len(where.Exprs) == 1 {
|
||||||
@ -391,7 +393,7 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
|
|||||||
for _, field := range s.Fields {
|
for _, field := range s.Fields {
|
||||||
selected := selectedColumns[field.DBName] || selectedColumns[field.Name]
|
selected := selectedColumns[field.DBName] || selectedColumns[field.Name]
|
||||||
if selected || (!restricted && field.Readable) {
|
if selected || (!restricted && field.Readable) {
|
||||||
if v, isZero := field.ValueOf(reflectValue); !isZero || selected {
|
if v, isZero := field.ValueOf(stmt.Context, reflectValue); !isZero || selected {
|
||||||
if field.DBName != "" {
|
if field.DBName != "" {
|
||||||
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v})
|
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v})
|
||||||
} else if field.DataType != "" {
|
} else if field.DataType != "" {
|
||||||
@ -405,7 +407,7 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
|
|||||||
for _, field := range s.Fields {
|
for _, field := range s.Fields {
|
||||||
selected := selectedColumns[field.DBName] || selectedColumns[field.Name]
|
selected := selectedColumns[field.DBName] || selectedColumns[field.Name]
|
||||||
if selected || (!restricted && field.Readable) {
|
if selected || (!restricted && field.Readable) {
|
||||||
if v, isZero := field.ValueOf(reflectValue.Index(i)); !isZero || selected {
|
if v, isZero := field.ValueOf(stmt.Context, reflectValue.Index(i)); !isZero || selected {
|
||||||
if field.DBName != "" {
|
if field.DBName != "" {
|
||||||
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v})
|
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v})
|
||||||
} else if field.DataType != "" {
|
} else if field.DataType != "" {
|
||||||
@ -564,7 +566,7 @@ func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks .
|
|||||||
|
|
||||||
switch destValue.Kind() {
|
switch destValue.Kind() {
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
field.Set(destValue, value)
|
stmt.AddError(field.Set(stmt.Context, destValue, value))
|
||||||
default:
|
default:
|
||||||
stmt.AddError(ErrInvalidData)
|
stmt.AddError(ErrInvalidData)
|
||||||
}
|
}
|
||||||
@ -574,10 +576,10 @@ func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks .
|
|||||||
case reflect.Slice, reflect.Array:
|
case reflect.Slice, reflect.Array:
|
||||||
if len(fromCallbacks) > 0 {
|
if len(fromCallbacks) > 0 {
|
||||||
for i := 0; i < stmt.ReflectValue.Len(); i++ {
|
for i := 0; i < stmt.ReflectValue.Len(); i++ {
|
||||||
field.Set(stmt.ReflectValue.Index(i), value)
|
stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue.Index(i), value))
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
field.Set(stmt.ReflectValue.Index(stmt.CurDestIndex), value)
|
stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue.Index(stmt.CurDestIndex), value))
|
||||||
}
|
}
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
if !stmt.ReflectValue.CanAddr() {
|
if !stmt.ReflectValue.CanAddr() {
|
||||||
@ -585,7 +587,7 @@ func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks .
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
field.Set(stmt.ReflectValue, value)
|
stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, value))
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
stmt.AddError(ErrInvalidField)
|
stmt.AddError(ErrInvalidField)
|
||||||
@ -605,12 +607,12 @@ func (stmt *Statement) Changed(fields ...string) bool {
|
|||||||
|
|
||||||
selectColumns, restricted := stmt.SelectAndOmitColumns(false, true)
|
selectColumns, restricted := stmt.SelectAndOmitColumns(false, true)
|
||||||
changed := func(field *schema.Field) bool {
|
changed := func(field *schema.Field) bool {
|
||||||
fieldValue, _ := field.ValueOf(modelValue)
|
fieldValue, _ := field.ValueOf(stmt.Context, modelValue)
|
||||||
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
|
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
|
||||||
if v, ok := stmt.Dest.(map[string]interface{}); ok {
|
if mv, mok := stmt.Dest.(map[string]interface{}); mok {
|
||||||
if fv, ok := v[field.Name]; ok {
|
if fv, ok := mv[field.Name]; ok {
|
||||||
return !utils.AssertEqual(fv, fieldValue)
|
return !utils.AssertEqual(fv, fieldValue)
|
||||||
} else if fv, ok := v[field.DBName]; ok {
|
} else if fv, ok := mv[field.DBName]; ok {
|
||||||
return !utils.AssertEqual(fv, fieldValue)
|
return !utils.AssertEqual(fv, fieldValue)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@ -619,7 +621,10 @@ func (stmt *Statement) Changed(fields ...string) bool {
|
|||||||
destValue = destValue.Elem()
|
destValue = destValue.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
changedValue, zero := field.ValueOf(destValue)
|
changedValue, zero := field.ValueOf(stmt.Context, destValue)
|
||||||
|
if v {
|
||||||
|
return !utils.AssertEqual(changedValue, fieldValue)
|
||||||
|
}
|
||||||
return !zero && !utils.AssertEqual(changedValue, fieldValue)
|
return !zero && !utils.AssertEqual(changedValue, fieldValue)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -220,3 +220,67 @@ func TestFullSaveAssociations(t *testing.T) {
|
|||||||
t.Errorf("Failed to preload AppliesToProduct")
|
t.Errorf("Failed to preload AppliesToProduct")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSaveBelongsCircularReference(t *testing.T) {
|
||||||
|
parent := Parent{}
|
||||||
|
DB.Create(&parent)
|
||||||
|
|
||||||
|
child := Child{ParentID: &parent.ID, Parent: &parent}
|
||||||
|
DB.Create(&child)
|
||||||
|
|
||||||
|
parent.FavChildID = child.ID
|
||||||
|
parent.FavChild = &child
|
||||||
|
DB.Save(&parent)
|
||||||
|
|
||||||
|
var parent1 Parent
|
||||||
|
DB.First(&parent1, parent.ID)
|
||||||
|
AssertObjEqual(t, parent, parent1, "ID", "FavChildID")
|
||||||
|
|
||||||
|
// Save and Updates is the same
|
||||||
|
DB.Updates(&parent)
|
||||||
|
DB.First(&parent1, parent.ID)
|
||||||
|
AssertObjEqual(t, parent, parent1, "ID", "FavChildID")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSaveHasManyCircularReference(t *testing.T) {
|
||||||
|
parent := Parent{}
|
||||||
|
DB.Create(&parent)
|
||||||
|
|
||||||
|
child := Child{ParentID: &parent.ID, Parent: &parent, Name: "HasManyCircularReference"}
|
||||||
|
child1 := Child{ParentID: &parent.ID, Parent: &parent, Name: "HasManyCircularReference1"}
|
||||||
|
|
||||||
|
parent.Children = []*Child{&child, &child1}
|
||||||
|
DB.Save(&parent)
|
||||||
|
|
||||||
|
var children []*Child
|
||||||
|
DB.Where("parent_id = ?", parent.ID).Find(&children)
|
||||||
|
if len(children) != len(parent.Children) ||
|
||||||
|
children[0].ID != parent.Children[0].ID ||
|
||||||
|
children[1].ID != parent.Children[1].ID {
|
||||||
|
t.Errorf("circular reference children save not equal children:%v parent.Children:%v",
|
||||||
|
children, parent.Children)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAssociationError(t *testing.T) {
|
||||||
|
user := *GetUser("TestAssociationError", Config{Pets: 2, Company: true, Account: true, Languages: 2})
|
||||||
|
DB.Create(&user)
|
||||||
|
|
||||||
|
var user1 User
|
||||||
|
DB.Preload("Company").Preload("Pets").Preload("Account").Preload("Languages").First(&user1)
|
||||||
|
|
||||||
|
var emptyUser User
|
||||||
|
var err error
|
||||||
|
// belongs to
|
||||||
|
err = DB.Model(&emptyUser).Association("Company").Delete(&user1.Company)
|
||||||
|
AssertEqual(t, err, gorm.ErrPrimaryKeyRequired)
|
||||||
|
// has many
|
||||||
|
err = DB.Model(&emptyUser).Association("Pets").Delete(&user1.Pets)
|
||||||
|
AssertEqual(t, err, gorm.ErrPrimaryKeyRequired)
|
||||||
|
// has one
|
||||||
|
err = DB.Model(&emptyUser).Association("Account").Delete(&user1.Account)
|
||||||
|
AssertEqual(t, err, gorm.ErrPrimaryKeyRequired)
|
||||||
|
// many to many
|
||||||
|
err = DB.Model(&emptyUser).Association("Languages").Delete(&user1.Languages)
|
||||||
|
AssertEqual(t, err, gorm.ErrPrimaryKeyRequired)
|
||||||
|
}
|
||||||
|
@ -38,6 +38,7 @@ func c2(*gorm.DB) {}
|
|||||||
func c3(*gorm.DB) {}
|
func c3(*gorm.DB) {}
|
||||||
func c4(*gorm.DB) {}
|
func c4(*gorm.DB) {}
|
||||||
func c5(*gorm.DB) {}
|
func c5(*gorm.DB) {}
|
||||||
|
func c6(*gorm.DB) {}
|
||||||
|
|
||||||
func TestCallbacks(t *testing.T) {
|
func TestCallbacks(t *testing.T) {
|
||||||
type callback struct {
|
type callback struct {
|
||||||
@ -168,3 +169,37 @@ func TestCallbacks(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestPluginCallbacks(t *testing.T) {
|
||||||
|
db, _ := gorm.Open(nil, nil)
|
||||||
|
createCallback := db.Callback().Create()
|
||||||
|
|
||||||
|
createCallback.Before("*").Register("plugin_1_fn1", c1)
|
||||||
|
createCallback.After("*").Register("plugin_1_fn2", c2)
|
||||||
|
|
||||||
|
if ok, msg := assertCallbacks(createCallback, []string{"c1", "c2"}); !ok {
|
||||||
|
t.Errorf("callbacks tests failed, got %v", msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
// plugin 2
|
||||||
|
createCallback.Before("*").Register("plugin_2_fn1", c3)
|
||||||
|
if ok, msg := assertCallbacks(createCallback, []string{"c3", "c1", "c2"}); !ok {
|
||||||
|
t.Errorf("callbacks tests failed, got %v", msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
createCallback.After("*").Register("plugin_2_fn2", c4)
|
||||||
|
if ok, msg := assertCallbacks(createCallback, []string{"c3", "c1", "c2", "c4"}); !ok {
|
||||||
|
t.Errorf("callbacks tests failed, got %v", msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
// plugin 3
|
||||||
|
createCallback.Before("*").Register("plugin_3_fn1", c5)
|
||||||
|
if ok, msg := assertCallbacks(createCallback, []string{"c5", "c3", "c1", "c2", "c4"}); !ok {
|
||||||
|
t.Errorf("callbacks tests failed, got %v", msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
createCallback.After("*").Register("plugin_3_fn2", c6)
|
||||||
|
if ok, msg := assertCallbacks(createCallback, []string{"c5", "c3", "c1", "c2", "c4", "c6"}); !ok {
|
||||||
|
t.Errorf("callbacks tests failed, got %v", msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -9,7 +9,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestWithSingleConnection(t *testing.T) {
|
func TestWithSingleConnection(t *testing.T) {
|
||||||
var expectedName = "test"
|
expectedName := "test"
|
||||||
var actualName string
|
var actualName string
|
||||||
|
|
||||||
setSQL, getSQL := getSetSQL(DB.Dialector.Name())
|
setSQL, getSQL := getSetSQL(DB.Dialector.Name())
|
||||||
@ -27,7 +27,6 @@ func TestWithSingleConnection(t *testing.T) {
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf(fmt.Sprintf("WithSingleConnection should work, but got err %v", err))
|
t.Errorf(fmt.Sprintf("WithSingleConnection should work, but got err %v", err))
|
||||||
}
|
}
|
||||||
|
171
tests/connpool_test.go
Normal file
171
tests/connpool_test.go
Normal file
@ -0,0 +1,171 @@
|
|||||||
|
package tests_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"os"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"gorm.io/driver/mysql"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
. "gorm.io/gorm/utils/tests"
|
||||||
|
)
|
||||||
|
|
||||||
|
type wrapperTx struct {
|
||||||
|
*sql.Tx
|
||||||
|
conn *wrapperConnPool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *wrapperTx) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) {
|
||||||
|
c.conn.got = append(c.conn.got, query)
|
||||||
|
return c.Tx.PrepareContext(ctx, query)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *wrapperTx) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
|
||||||
|
c.conn.got = append(c.conn.got, query)
|
||||||
|
return c.Tx.ExecContext(ctx, query, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *wrapperTx) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
|
||||||
|
c.conn.got = append(c.conn.got, query)
|
||||||
|
return c.Tx.QueryContext(ctx, query, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *wrapperTx) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
|
||||||
|
c.conn.got = append(c.conn.got, query)
|
||||||
|
return c.Tx.QueryRowContext(ctx, query, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
type wrapperConnPool struct {
|
||||||
|
db *sql.DB
|
||||||
|
got []string
|
||||||
|
expect []string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *wrapperConnPool) Ping() error {
|
||||||
|
return c.db.Ping()
|
||||||
|
}
|
||||||
|
|
||||||
|
// If you use BeginTx returned *sql.Tx as shown below then you can't record queries in a transaction.
|
||||||
|
// func (c *wrapperConnPool) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) {
|
||||||
|
// return c.db.BeginTx(ctx, opts)
|
||||||
|
// }
|
||||||
|
// You should use BeginTx returned gorm.Tx which could wrap *sql.Tx then you can record all queries.
|
||||||
|
func (c *wrapperConnPool) BeginTx(ctx context.Context, opts *sql.TxOptions) (gorm.ConnPool, error) {
|
||||||
|
tx, err := c.db.BeginTx(ctx, opts)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &wrapperTx{Tx: tx, conn: c}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *wrapperConnPool) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) {
|
||||||
|
c.got = append(c.got, query)
|
||||||
|
return c.db.PrepareContext(ctx, query)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *wrapperConnPool) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
|
||||||
|
c.got = append(c.got, query)
|
||||||
|
return c.db.ExecContext(ctx, query, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *wrapperConnPool) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
|
||||||
|
c.got = append(c.got, query)
|
||||||
|
return c.db.QueryContext(ctx, query, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *wrapperConnPool) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
|
||||||
|
c.got = append(c.got, query)
|
||||||
|
return c.db.QueryRowContext(ctx, query, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConnPoolWrapper(t *testing.T) {
|
||||||
|
dialect := os.Getenv("GORM_DIALECT")
|
||||||
|
if dialect != "mysql" {
|
||||||
|
t.SkipNow()
|
||||||
|
}
|
||||||
|
|
||||||
|
dbDSN := os.Getenv("GORM_DSN")
|
||||||
|
if dbDSN == "" {
|
||||||
|
dbDSN = "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True&loc=Local"
|
||||||
|
}
|
||||||
|
nativeDB, err := sql.Open("mysql", dbDSN)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Should open db success, but got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
conn := &wrapperConnPool{
|
||||||
|
db: nativeDB,
|
||||||
|
expect: []string{
|
||||||
|
"SELECT VERSION()",
|
||||||
|
"INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)",
|
||||||
|
"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1",
|
||||||
|
"INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)",
|
||||||
|
"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1",
|
||||||
|
"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1",
|
||||||
|
"INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)",
|
||||||
|
"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1",
|
||||||
|
"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
if !reflect.DeepEqual(conn.got, conn.expect) {
|
||||||
|
t.Errorf("expect %#v but got %#v", conn.expect, conn.got)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
db, err := gorm.Open(mysql.New(mysql.Config{Conn: conn}))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Should open db success, but got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tx := db.Begin()
|
||||||
|
user := *GetUser("transaction", Config{})
|
||||||
|
|
||||||
|
if err = tx.Save(&user).Error; err != nil {
|
||||||
|
t.Fatalf("No error should raise, but got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = tx.First(&User{}, "name = ?", "transaction").Error; err != nil {
|
||||||
|
t.Fatalf("Should find saved record, but got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
user1 := *GetUser("transaction1-1", Config{})
|
||||||
|
|
||||||
|
if err = tx.Save(&user1).Error; err != nil {
|
||||||
|
t.Fatalf("No error should raise, but got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = tx.First(&User{}, "name = ?", user1.Name).Error; err != nil {
|
||||||
|
t.Fatalf("Should find saved record, but got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if sqlTx, ok := tx.Statement.ConnPool.(gorm.TxCommitter); !ok || sqlTx == nil {
|
||||||
|
t.Fatalf("Should return the underlying sql.Tx")
|
||||||
|
}
|
||||||
|
|
||||||
|
tx.Rollback()
|
||||||
|
|
||||||
|
if err = db.First(&User{}, "name = ?", "transaction").Error; err == nil {
|
||||||
|
t.Fatalf("Should not find record after rollback, but got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
txDB := db.Where("fake_name = ?", "fake_name")
|
||||||
|
tx2 := txDB.Session(&gorm.Session{NewDB: true}).Begin()
|
||||||
|
user2 := *GetUser("transaction-2", Config{})
|
||||||
|
if err = tx2.Save(&user2).Error; err != nil {
|
||||||
|
t.Fatalf("No error should raise, but got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = tx2.First(&User{}, "name = ?", "transaction-2").Error; err != nil {
|
||||||
|
t.Fatalf("Should find saved record, but got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tx2.Commit()
|
||||||
|
|
||||||
|
if err = db.First(&User{}, "name = ?", "transaction-2").Error; err != nil {
|
||||||
|
t.Fatalf("Should be able to find committed record, but got %v", err)
|
||||||
|
}
|
||||||
|
}
|
@ -144,4 +144,22 @@ func TestCount(t *testing.T) {
|
|||||||
if err := DB.Model(&User{}).Where("name = ?", "count-4").Group("name").Count(&count11).Error; err != nil || count11 != 1 {
|
if err := DB.Model(&User{}).Where("name = ?", "count-4").Group("name").Count(&count11).Error; err != nil || count11 != 1 {
|
||||||
t.Fatalf("Count should be 3, but got count: %v err %v", count11, err)
|
t.Fatalf("Count should be 3, but got count: %v err %v", count11, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var count12 int64
|
||||||
|
if err := DB.Table("users").
|
||||||
|
Where("name in ?", []string{user1.Name, user2.Name, user3.Name}).
|
||||||
|
Preload("Toys", func(db *gorm.DB) *gorm.DB {
|
||||||
|
return db.Table("toys").Select("name")
|
||||||
|
}).Count(&count12).Error; err == nil {
|
||||||
|
t.Errorf("error should raise when using preload without schema")
|
||||||
|
}
|
||||||
|
|
||||||
|
var count13 int64
|
||||||
|
if err := DB.Model(User{}).
|
||||||
|
Where("name in ?", []string{user1.Name, user2.Name, user3.Name}).
|
||||||
|
Preload("Toys", func(db *gorm.DB) *gorm.DB {
|
||||||
|
return db.Table("toys").Select("name")
|
||||||
|
}).Count(&count13).Error; err != nil {
|
||||||
|
t.Errorf("no error should raise when using count with preload, but got %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -123,7 +123,7 @@ func TestCreateFromMap(t *testing.T) {
|
|||||||
{"name": "create_from_map_3", "Age": 20},
|
{"name": "create_from_map_3", "Age": 20},
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := DB.Model(&User{}).Create(datas).Error; err != nil {
|
if err := DB.Model(&User{}).Create(&datas).Error; err != nil {
|
||||||
t.Fatalf("failed to create data from slice of map, got error: %v", err)
|
t.Fatalf("failed to create data from slice of map, got error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -526,3 +526,17 @@ func TestCreateNilPointer(t *testing.T) {
|
|||||||
t.Fatalf("it is not ErrInvalidValue")
|
t.Fatalf("it is not ErrInvalidValue")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestFirstOrCreateRowsAffected(t *testing.T) {
|
||||||
|
user := User{Name: "TestFirstOrCreateRowsAffected"}
|
||||||
|
|
||||||
|
res := DB.FirstOrCreate(&user, "name = ?", user.Name)
|
||||||
|
if res.Error != nil || res.RowsAffected != 1 {
|
||||||
|
t.Fatalf("first or create rows affect err:%v rows:%d", res.Error, res.RowsAffected)
|
||||||
|
}
|
||||||
|
|
||||||
|
res = DB.FirstOrCreate(&user, "name = ?", user.Name)
|
||||||
|
if res.Error != nil || res.RowsAffected != 0 {
|
||||||
|
t.Fatalf("first or create rows affect err:%v rows:%d", res.Error, res.RowsAffected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -2,6 +2,7 @@ package tests_test
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
@ -14,6 +15,7 @@ func TestDefaultValue(t *testing.T) {
|
|||||||
Name2 string `gorm:"size:233;not null;default:'foo'"`
|
Name2 string `gorm:"size:233;not null;default:'foo'"`
|
||||||
Name3 string `gorm:"size:233;notNull;default:''"`
|
Name3 string `gorm:"size:233;notNull;default:''"`
|
||||||
Age int `gorm:"default:18"`
|
Age int `gorm:"default:18"`
|
||||||
|
Created time.Time `gorm:"default:2000-01-02"`
|
||||||
Enabled bool `gorm:"default:true"`
|
Enabled bool `gorm:"default:true"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -26,14 +28,14 @@ func TestDefaultValue(t *testing.T) {
|
|||||||
harumph := Harumph{Email: "hello@gorm.io"}
|
harumph := Harumph{Email: "hello@gorm.io"}
|
||||||
if err := DB.Create(&harumph).Error; err != nil {
|
if err := DB.Create(&harumph).Error; err != nil {
|
||||||
t.Fatalf("Failed to create data with default value, got error: %v", err)
|
t.Fatalf("Failed to create data with default value, got error: %v", err)
|
||||||
} else if harumph.Name != "foo" || harumph.Name2 != "foo" || harumph.Name3 != "" || harumph.Age != 18 || !harumph.Enabled {
|
} else if harumph.Name != "foo" || harumph.Name2 != "foo" || harumph.Name3 != "" || harumph.Age != 18 || !harumph.Enabled || harumph.Created.Format("20060102") != "20000102" {
|
||||||
t.Fatalf("Failed to create data with default value, got: %+v", harumph)
|
t.Fatalf("Failed to create data with default value, got: %+v", harumph)
|
||||||
}
|
}
|
||||||
|
|
||||||
var result Harumph
|
var result Harumph
|
||||||
if err := DB.First(&result, "email = ?", "hello@gorm.io").Error; err != nil {
|
if err := DB.First(&result, "email = ?", "hello@gorm.io").Error; err != nil {
|
||||||
t.Fatalf("Failed to find created data, got error: %v", err)
|
t.Fatalf("Failed to find created data, got error: %v", err)
|
||||||
} else if result.Name != "foo" || result.Name2 != "foo" || result.Name3 != "" || result.Age != 18 || !result.Enabled {
|
} else if result.Name != "foo" || result.Name2 != "foo" || result.Name3 != "" || result.Age != 18 || !result.Enabled || result.Created.Format("20060102") != "20000102" {
|
||||||
t.Fatalf("Failed to find created data with default data, got %+v", result)
|
t.Fatalf("Failed to find created data with default data, got %+v", result)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -2,7 +2,7 @@ version: '3'
|
|||||||
|
|
||||||
services:
|
services:
|
||||||
mysql:
|
mysql:
|
||||||
image: 'mysql:latest'
|
image: 'mysql/mysql-server:latest'
|
||||||
ports:
|
ports:
|
||||||
- 9910:3306
|
- 9910:3306
|
||||||
environment:
|
environment:
|
||||||
@ -20,7 +20,7 @@ services:
|
|||||||
- POSTGRES_USER=gorm
|
- POSTGRES_USER=gorm
|
||||||
- POSTGRES_PASSWORD=gorm
|
- POSTGRES_PASSWORD=gorm
|
||||||
mssql:
|
mssql:
|
||||||
image: 'mcmoe/mssqldocker:latest'
|
image: '${MSSQL_IMAGE:-mcmoe/mssqldocker}:latest'
|
||||||
ports:
|
ports:
|
||||||
- 9930:1433
|
- 9930:1433
|
||||||
environment:
|
environment:
|
||||||
|
18
tests/go.mod
18
tests/go.mod
@ -3,16 +3,16 @@ module gorm.io/gorm/tests
|
|||||||
go 1.14
|
go 1.14
|
||||||
|
|
||||||
require (
|
require (
|
||||||
|
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect
|
||||||
github.com/google/uuid v1.3.0
|
github.com/google/uuid v1.3.0
|
||||||
github.com/jackc/pgx/v4 v4.14.1 // indirect
|
github.com/jinzhu/now v1.1.5
|
||||||
github.com/jinzhu/now v1.1.4
|
github.com/lib/pq v1.10.5
|
||||||
github.com/lib/pq v1.10.4
|
golang.org/x/crypto v0.0.0-20220411220226-7b82a4e95df4 // indirect
|
||||||
golang.org/x/crypto v0.0.0-20211209193657-4570a0811e8b // indirect
|
gorm.io/driver/mysql v1.3.3
|
||||||
gorm.io/driver/mysql v1.2.3
|
gorm.io/driver/postgres v1.3.5
|
||||||
gorm.io/driver/postgres v1.2.3
|
gorm.io/driver/sqlite v1.3.2
|
||||||
gorm.io/driver/sqlite v1.2.6
|
gorm.io/driver/sqlserver v1.3.2
|
||||||
gorm.io/driver/sqlserver v1.2.1
|
gorm.io/gorm v1.23.4
|
||||||
gorm.io/gorm v1.22.4
|
|
||||||
)
|
)
|
||||||
|
|
||||||
replace gorm.io/gorm => ../
|
replace gorm.io/gorm => ../
|
||||||
|
@ -19,6 +19,7 @@ type Config struct {
|
|||||||
Team int
|
Team int
|
||||||
Languages int
|
Languages int
|
||||||
Friends int
|
Friends int
|
||||||
|
NamedPet bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetUser(name string, config Config) *User {
|
func GetUser(name string, config Config) *User {
|
||||||
@ -65,6 +66,10 @@ func GetUser(name string, config Config) *User {
|
|||||||
user.Friends = append(user.Friends, GetUser(name+"_friend_"+strconv.Itoa(i+1), Config{}))
|
user.Friends = append(user.Friends, GetUser(name+"_friend_"+strconv.Itoa(i+1), Config{}))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if config.NamedPet {
|
||||||
|
user.NamedPet = &Pet{Name: name + "_namepet"}
|
||||||
|
}
|
||||||
|
|
||||||
return &user
|
return &user
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -375,13 +375,19 @@ func TestSetColumn(t *testing.T) {
|
|||||||
t.Errorf("invalid data after update, got %+v", product)
|
t.Errorf("invalid data after update, got %+v", product)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Code changed, price should changed
|
||||||
|
DB.Model(&product).Select("Name", "Code", "Price").Updates(Product3{Name: "Product New4", Code: ""})
|
||||||
|
if product.Name != "Product New4" || product.Price != 320 || product.Code != "" {
|
||||||
|
t.Errorf("invalid data after update, got %+v", product)
|
||||||
|
}
|
||||||
|
|
||||||
DB.Model(&product).UpdateColumns(Product3{Code: "L1215"})
|
DB.Model(&product).UpdateColumns(Product3{Code: "L1215"})
|
||||||
if product.Price != 270 || product.Code != "L1215" {
|
if product.Price != 320 || product.Code != "L1215" {
|
||||||
t.Errorf("invalid data after update, got %+v", product)
|
t.Errorf("invalid data after update, got %+v", product)
|
||||||
}
|
}
|
||||||
|
|
||||||
DB.Model(&product).Session(&gorm.Session{SkipHooks: true}).Updates(Product3{Code: "L1216"})
|
DB.Model(&product).Session(&gorm.Session{SkipHooks: true}).Updates(Product3{Code: "L1216"})
|
||||||
if product.Price != 270 || product.Code != "L1216" {
|
if product.Price != 320 || product.Code != "L1216" {
|
||||||
t.Errorf("invalid data after update, got %+v", product)
|
t.Errorf("invalid data after update, got %+v", product)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -462,6 +468,7 @@ type ProductItem struct {
|
|||||||
gorm.Model
|
gorm.Model
|
||||||
Code string
|
Code string
|
||||||
Product4ID uint
|
Product4ID uint
|
||||||
|
AfterFindCallTimes int
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pi ProductItem) BeforeCreate(*gorm.DB) error {
|
func (pi ProductItem) BeforeCreate(*gorm.DB) error {
|
||||||
@ -471,6 +478,11 @@ func (pi ProductItem) BeforeCreate(*gorm.DB) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (pi *ProductItem) AfterFind(*gorm.DB) error {
|
||||||
|
pi.AfterFindCallTimes = pi.AfterFindCallTimes + 1
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func TestFailedToSaveAssociationShouldRollback(t *testing.T) {
|
func TestFailedToSaveAssociationShouldRollback(t *testing.T) {
|
||||||
DB.Migrator().DropTable(&Product4{}, &ProductItem{})
|
DB.Migrator().DropTable(&Product4{}, &ProductItem{})
|
||||||
DB.AutoMigrate(&Product4{}, &ProductItem{})
|
DB.AutoMigrate(&Product4{}, &ProductItem{})
|
||||||
@ -492,4 +504,13 @@ func TestFailedToSaveAssociationShouldRollback(t *testing.T) {
|
|||||||
if err := DB.First(&Product4{}, "name = ?", product.Name).Error; err != nil {
|
if err := DB.First(&Product4{}, "name = ?", product.Name).Error; err != nil {
|
||||||
t.Errorf("should find product, but got error %v", err)
|
t.Errorf("should find product, but got error %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var productWithItem Product4
|
||||||
|
if err := DB.Session(&gorm.Session{SkipHooks: true}).Preload("Item").First(&productWithItem, "name = ?", product.Name).Error; err != nil {
|
||||||
|
t.Errorf("should find product, but got error %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if productWithItem.Item.AfterFindCallTimes != 0 {
|
||||||
|
t.Fatalf("AfterFind should not be called times:%d", productWithItem.Item.AfterFindCallTimes)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -10,12 +10,12 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestJoins(t *testing.T) {
|
func TestJoins(t *testing.T) {
|
||||||
user := *GetUser("joins-1", Config{Company: true, Manager: true, Account: true})
|
user := *GetUser("joins-1", Config{Company: true, Manager: true, Account: true, NamedPet: false})
|
||||||
|
|
||||||
DB.Create(&user)
|
DB.Create(&user)
|
||||||
|
|
||||||
var user2 User
|
var user2 User
|
||||||
if err := DB.Joins("Company").Joins("Manager").Joins("Account").First(&user2, "users.name = ?", user.Name).Error; err != nil {
|
if err := DB.Joins("NamedPet").Joins("Company").Joins("Manager").Joins("Account").First(&user2, "users.name = ?", user.Name).Error; err != nil {
|
||||||
t.Fatalf("Failed to load with joins, got error: %v", err)
|
t.Fatalf("Failed to load with joins, got error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -158,6 +158,22 @@ func TestJoinsWithSelect(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestJoinWithOmit(t *testing.T) {
|
||||||
|
user := *GetUser("joins_with_omit", Config{Pets: 2})
|
||||||
|
DB.Save(&user)
|
||||||
|
|
||||||
|
results := make([]*User, 0)
|
||||||
|
|
||||||
|
if err := DB.Table("users").Omit("name").Where("users.name = ?", "joins_with_omit").Joins("left join pets on pets.user_id = users.id").Find(&results).Error; err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(results) != 2 || results[0].Name != "" || results[1].Name != "" {
|
||||||
|
t.Errorf("Should find all two pets with Join omit and should not find user's name, got %+v", results)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestJoinCount(t *testing.T) {
|
func TestJoinCount(t *testing.T) {
|
||||||
companyA := Company{Name: "A"}
|
companyA := Company{Name: "A"}
|
||||||
companyB := Company{Name: "B"}
|
companyB := Company{Name: "B"}
|
||||||
@ -184,3 +200,32 @@ func TestJoinCount(t *testing.T) {
|
|||||||
t.Fatalf("result's id, %d, doesn't match user's id, %d", result.ID, user.ID)
|
t.Fatalf("result's id, %d, doesn't match user's id, %d", result.ID, user.ID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestJoinWithSoftDeleted(t *testing.T) {
|
||||||
|
user := GetUser("TestJoinWithSoftDeletedUser", Config{Account: true, NamedPet: true})
|
||||||
|
DB.Create(&user)
|
||||||
|
|
||||||
|
var user1 User
|
||||||
|
DB.Model(&User{}).Joins("NamedPet").Joins("Account").First(&user1, user.ID)
|
||||||
|
if user1.NamedPet == nil || user1.Account.ID == 0 {
|
||||||
|
t.Fatalf("joins NamedPet and Account should not empty:%v", user1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Account should empty
|
||||||
|
DB.Delete(&user1.Account)
|
||||||
|
|
||||||
|
var user2 User
|
||||||
|
DB.Model(&User{}).Joins("NamedPet").Joins("Account").First(&user2, user.ID)
|
||||||
|
if user2.NamedPet == nil || user2.Account.ID != 0 {
|
||||||
|
t.Fatalf("joins Account should not empty:%v", user2)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NamedPet should empty
|
||||||
|
DB.Delete(&user1.NamedPet)
|
||||||
|
|
||||||
|
var user3 User
|
||||||
|
DB.Model(&User{}).Joins("NamedPet").Joins("Account").First(&user3, user.ID)
|
||||||
|
if user3.NamedPet != nil || user2.Account.ID != 0 {
|
||||||
|
t.Fatalf("joins NamedPet and Account should not empty:%v", user2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -43,10 +43,8 @@ func TestExceptionsWithInvalidSql(t *testing.T) {
|
|||||||
func TestSetAndGet(t *testing.T) {
|
func TestSetAndGet(t *testing.T) {
|
||||||
if value, ok := DB.Set("hello", "world").Get("hello"); !ok {
|
if value, ok := DB.Set("hello", "world").Get("hello"); !ok {
|
||||||
t.Errorf("Should be able to get setting after set")
|
t.Errorf("Should be able to get setting after set")
|
||||||
} else {
|
} else if value.(string) != "world" {
|
||||||
if value.(string) != "world" {
|
t.Errorf("Set value should not be changed")
|
||||||
t.Errorf("Setted value should not be changed")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, ok := DB.Get("non_existing"); ok {
|
if _, ok := DB.Get("non_existing"); ok {
|
||||||
|
@ -45,7 +45,7 @@ func TestMigrate(t *testing.T) {
|
|||||||
|
|
||||||
for _, m := range allModels {
|
for _, m := range allModels {
|
||||||
if !DB.Migrator().HasTable(m) {
|
if !DB.Migrator().HasTable(m) {
|
||||||
t.Fatalf("Failed to create table for %#v---", m)
|
t.Fatalf("Failed to create table for %#v", m)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -92,7 +92,7 @@ func TestAutoMigrateSelfReferential(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestSmartMigrateColumn(t *testing.T) {
|
func TestSmartMigrateColumn(t *testing.T) {
|
||||||
fullSupported := map[string]bool{"mysql": true}[DB.Dialector.Name()]
|
fullSupported := map[string]bool{"mysql": true, "postgres": true}[DB.Dialector.Name()]
|
||||||
|
|
||||||
type UserMigrateColumn struct {
|
type UserMigrateColumn struct {
|
||||||
ID uint
|
ID uint
|
||||||
@ -258,7 +258,7 @@ func TestMigrateTable(t *testing.T) {
|
|||||||
DB.Migrator().DropTable("new_table_structs")
|
DB.Migrator().DropTable("new_table_structs")
|
||||||
|
|
||||||
if DB.Migrator().HasTable(&NewTableStruct{}) {
|
if DB.Migrator().HasTable(&NewTableStruct{}) {
|
||||||
t.Fatal("should not found droped table")
|
t.Fatal("should not found dropped table")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -313,9 +313,16 @@ func TestMigrateIndexes(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestMigrateColumns(t *testing.T) {
|
func TestMigrateColumns(t *testing.T) {
|
||||||
|
sqlite := DB.Dialector.Name() == "sqlite"
|
||||||
|
sqlserver := DB.Dialector.Name() == "sqlserver"
|
||||||
|
|
||||||
type ColumnStruct struct {
|
type ColumnStruct struct {
|
||||||
gorm.Model
|
gorm.Model
|
||||||
Name string
|
Name string
|
||||||
|
Age int `gorm:"default:18;comment:my age"`
|
||||||
|
Code string `gorm:"unique;comment:my code;"`
|
||||||
|
Code2 string
|
||||||
|
Code3 string `gorm:"unique"`
|
||||||
}
|
}
|
||||||
|
|
||||||
DB.Migrator().DropTable(&ColumnStruct{})
|
DB.Migrator().DropTable(&ColumnStruct{})
|
||||||
@ -327,12 +334,19 @@ func TestMigrateColumns(t *testing.T) {
|
|||||||
type ColumnStruct2 struct {
|
type ColumnStruct2 struct {
|
||||||
gorm.Model
|
gorm.Model
|
||||||
Name string `gorm:"size:100"`
|
Name string `gorm:"size:100"`
|
||||||
|
Code string `gorm:"unique;comment:my code2;default:hello"`
|
||||||
|
Code2 string `gorm:"unique"`
|
||||||
|
// Code3 string
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := DB.Table("column_structs").Migrator().AlterColumn(&ColumnStruct2{}, "Name"); err != nil {
|
if err := DB.Table("column_structs").Migrator().AlterColumn(&ColumnStruct{}, "Name"); err != nil {
|
||||||
t.Fatalf("no error should happened when alter column, but got %v", err)
|
t.Fatalf("no error should happened when alter column, but got %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := DB.Table("column_structs").AutoMigrate(&ColumnStruct2{}); err != nil {
|
||||||
|
t.Fatalf("no error should happened when auto migrate column, but got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
if columnTypes, err := DB.Migrator().ColumnTypes(&ColumnStruct{}); err != nil {
|
if columnTypes, err := DB.Migrator().ColumnTypes(&ColumnStruct{}); err != nil {
|
||||||
t.Fatalf("no error should returns for ColumnTypes")
|
t.Fatalf("no error should returns for ColumnTypes")
|
||||||
} else {
|
} else {
|
||||||
@ -340,11 +354,45 @@ func TestMigrateColumns(t *testing.T) {
|
|||||||
stmt.Parse(&ColumnStruct2{})
|
stmt.Parse(&ColumnStruct2{})
|
||||||
|
|
||||||
for _, columnType := range columnTypes {
|
for _, columnType := range columnTypes {
|
||||||
if columnType.Name() == "name" {
|
switch columnType.Name() {
|
||||||
|
case "id":
|
||||||
|
if v, ok := columnType.PrimaryKey(); !ok || !v {
|
||||||
|
t.Fatalf("column id primary key should be correct, name: %v, column: %#v", columnType.Name(), columnType)
|
||||||
|
}
|
||||||
|
case "name":
|
||||||
dataType := DB.Dialector.DataTypeOf(stmt.Schema.LookUpField(columnType.Name()))
|
dataType := DB.Dialector.DataTypeOf(stmt.Schema.LookUpField(columnType.Name()))
|
||||||
if !strings.Contains(strings.ToUpper(dataType), strings.ToUpper(columnType.DatabaseTypeName())) {
|
if !strings.Contains(strings.ToUpper(dataType), strings.ToUpper(columnType.DatabaseTypeName())) {
|
||||||
t.Errorf("column type should be correct, name: %v, length: %v, expects: %v", columnType.Name(), columnType.DatabaseTypeName(), dataType)
|
t.Fatalf("column name type should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), columnType.DatabaseTypeName(), dataType, columnType)
|
||||||
}
|
}
|
||||||
|
if length, ok := columnType.Length(); !sqlite && (!ok || length != 100) {
|
||||||
|
t.Fatalf("column name length should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), length, 100, columnType)
|
||||||
|
}
|
||||||
|
case "age":
|
||||||
|
if v, ok := columnType.DefaultValue(); !ok || v != "18" {
|
||||||
|
t.Fatalf("column age default value should be correct, name: %v, column: %#v", columnType.Name(), columnType)
|
||||||
|
}
|
||||||
|
if v, ok := columnType.Comment(); !sqlite && !sqlserver && (!ok || v != "my age") {
|
||||||
|
t.Fatalf("column age comment should be correct, name: %v, column: %#v", columnType.Name(), columnType)
|
||||||
|
}
|
||||||
|
case "code":
|
||||||
|
if v, ok := columnType.Unique(); !ok || !v {
|
||||||
|
t.Fatalf("column code unique should be correct, name: %v, column: %#v", columnType.Name(), columnType)
|
||||||
|
}
|
||||||
|
if v, ok := columnType.DefaultValue(); !sqlserver && (!ok || v != "hello") {
|
||||||
|
t.Fatalf("column code default value should be correct, name: %v, column: %#v", columnType.Name(), columnType)
|
||||||
|
}
|
||||||
|
if v, ok := columnType.Comment(); !sqlite && !sqlserver && (!ok || v != "my code2") {
|
||||||
|
t.Fatalf("column code comment should be correct, name: %v, column: %#v", columnType.Name(), columnType)
|
||||||
|
}
|
||||||
|
case "code2":
|
||||||
|
if v, ok := columnType.Unique(); !sqlserver && (!ok || !v) {
|
||||||
|
t.Fatalf("column code2 unique should be correct, name: %v, column: %#v", columnType.Name(), columnType)
|
||||||
|
}
|
||||||
|
case "code3":
|
||||||
|
// TODO
|
||||||
|
// if v, ok := columnType.Unique(); !ok || v {
|
||||||
|
// t.Fatalf("column code unique should be correct, name: %v, column: %#v", columnType.Name(), columnType)
|
||||||
|
// }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -526,3 +574,122 @@ func TestMigrateColumnOrder(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// https://github.com/go-gorm/gorm/issues/5047
|
||||||
|
func TestMigrateSerialColumn(t *testing.T) {
|
||||||
|
if DB.Dialector.Name() != "postgres" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
type Event struct {
|
||||||
|
ID uint `gorm:"primarykey"`
|
||||||
|
UID uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
type Event1 struct {
|
||||||
|
ID uint `gorm:"primarykey"`
|
||||||
|
UID uint32 `gorm:"not null;autoIncrement"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Event2 struct {
|
||||||
|
ID uint `gorm:"primarykey"`
|
||||||
|
UID uint16 `gorm:"not null;autoIncrement"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var err error
|
||||||
|
err = DB.Migrator().DropTable(&Event{})
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("DropTable err:%v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// create sequence
|
||||||
|
err = DB.Table("events").AutoMigrate(&Event1{})
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("AutoMigrate err:%v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// delete sequence
|
||||||
|
err = DB.Table("events").AutoMigrate(&Event{})
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("AutoMigrate err:%v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// update sequence
|
||||||
|
err = DB.Table("events").AutoMigrate(&Event1{})
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("AutoMigrate err:%v", err)
|
||||||
|
}
|
||||||
|
err = DB.Table("events").AutoMigrate(&Event2{})
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("AutoMigrate err:%v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Table("events").Save(&Event2{})
|
||||||
|
DB.Table("events").Save(&Event2{})
|
||||||
|
DB.Table("events").Save(&Event2{})
|
||||||
|
|
||||||
|
events := make([]*Event, 0)
|
||||||
|
DB.Table("events").Find(&events)
|
||||||
|
|
||||||
|
AssertEqual(t, 3, len(events))
|
||||||
|
for _, v := range events {
|
||||||
|
AssertEqual(t, v.ID, v.UID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// https://github.com/go-gorm/gorm/issues/5300
|
||||||
|
func TestMigrateWithSpecialName(t *testing.T) {
|
||||||
|
var err error
|
||||||
|
err = DB.AutoMigrate(&Coupon{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("AutoMigrate err:%v", err)
|
||||||
|
}
|
||||||
|
err = DB.Table("coupon_product_1").AutoMigrate(&CouponProduct{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("AutoMigrate err:%v", err)
|
||||||
|
}
|
||||||
|
err = DB.Table("coupon_product_2").AutoMigrate(&CouponProduct{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("AutoMigrate err:%v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
AssertEqual(t, true, DB.Migrator().HasTable("coupons"))
|
||||||
|
AssertEqual(t, true, DB.Migrator().HasTable("coupon_product_1"))
|
||||||
|
AssertEqual(t, true, DB.Migrator().HasTable("coupon_product_2"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// https://github.com/go-gorm/gorm/issues/5320
|
||||||
|
func TestPrimarykeyID(t *testing.T) {
|
||||||
|
if DB.Dialector.Name() != "postgres" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
type MissPKLanguage struct {
|
||||||
|
ID string `gorm:"type:uuid;default:uuid_generate_v4()"`
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
|
||||||
|
type MissPKUser struct {
|
||||||
|
ID string `gorm:"type:uuid;default:uuid_generate_v4()"`
|
||||||
|
MissPKLanguages []MissPKLanguage `gorm:"many2many:miss_pk_user_languages;"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var err error
|
||||||
|
err = DB.Migrator().DropTable(&MissPKUser{}, &MissPKLanguage{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("DropTable err:%v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Exec(`CREATE EXTENSION IF NOT EXISTS "uuid-ossp";`)
|
||||||
|
|
||||||
|
err = DB.AutoMigrate(&MissPKUser{}, &MissPKLanguage{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("AutoMigrate err:%v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// patch
|
||||||
|
err = DB.AutoMigrate(&MissPKUser{}, &MissPKLanguage{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("AutoMigrate err:%v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -2,6 +2,7 @@ package tests_test
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/lib/pq"
|
"github.com/lib/pq"
|
||||||
@ -17,6 +18,8 @@ func TestPostgres(t *testing.T) {
|
|||||||
gorm.Model
|
gorm.Model
|
||||||
Name string `gorm:"check:name_checker,name <> ''"`
|
Name string `gorm:"check:name_checker,name <> ''"`
|
||||||
Test uuid.UUID `gorm:"type:uuid;not null;default:gen_random_uuid()"`
|
Test uuid.UUID `gorm:"type:uuid;not null;default:gen_random_uuid()"`
|
||||||
|
CreatedAt time.Time `gorm:"type:TIMESTAMP WITHOUT TIME ZONE"`
|
||||||
|
UpdatedAt time.Time `gorm:"type:TIMESTAMP WITHOUT TIME ZONE;default:current_timestamp"`
|
||||||
Things pq.StringArray `gorm:"type:text[]"`
|
Things pq.StringArray `gorm:"type:text[]"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -48,6 +51,15 @@ func TestPostgres(t *testing.T) {
|
|||||||
if err := DB.Where("id = $1", harumph.ID).First(&Harumph{}).Error; err != nil || harumph.Name != "jinzhu" {
|
if err := DB.Where("id = $1", harumph.ID).First(&Harumph{}).Error; err != nil || harumph.Name != "jinzhu" {
|
||||||
t.Errorf("No error should happen, but got %v", err)
|
t.Errorf("No error should happen, but got %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
harumph.Name = "jinzhu1"
|
||||||
|
if err := DB.Save(&harumph).Error; err != nil {
|
||||||
|
t.Errorf("Failed to update date, got error %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.First(&result, "id = ?", harumph.ID).Error; err != nil || harumph.Name != "jinzhu1" {
|
||||||
|
t.Errorf("No error should happen, but got %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type Post struct {
|
type Post struct {
|
||||||
|
@ -1335,7 +1335,7 @@ func TestNilPointerSlice(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !reflect.DeepEqual(got[0], want) && !reflect.DeepEqual(got[1], want) {
|
if !reflect.DeepEqual(got[0], want) && !reflect.DeepEqual(got[1], want) {
|
||||||
t.Errorf("got %s; want array containing %s", toJSONString(got), toJSONString(want))
|
t.Fatalf("got %s; want array containing %s", toJSONString(got), toJSONString(want))
|
||||||
}
|
}
|
||||||
|
|
||||||
if !reflect.DeepEqual(got[0], want2) && !reflect.DeepEqual(got[1], want2) {
|
if !reflect.DeepEqual(got[0], want2) && !reflect.DeepEqual(got[1], want2) {
|
||||||
|
@ -251,3 +251,21 @@ func TestPreloadGoroutine(t *testing.T) {
|
|||||||
}
|
}
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestPreloadWithDiffModel(t *testing.T) {
|
||||||
|
user := *GetUser("preload_with_diff_model", Config{Account: true})
|
||||||
|
|
||||||
|
if err := DB.Create(&user).Error; err != nil {
|
||||||
|
t.Fatalf("errors happened when create: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result struct {
|
||||||
|
Something string
|
||||||
|
User
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Model(User{}).Preload("Account", clause.Eq{Column: "number", Value: user.Account.Number}).Select(
|
||||||
|
"users.*, 'yo' as something").First(&result, "name = ?", user.Name)
|
||||||
|
|
||||||
|
CheckUser(t, user, result.User)
|
||||||
|
}
|
||||||
|
@ -292,6 +292,68 @@ func TestFindInBatches(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestFindInBatchesWithOffsetLimit(t *testing.T) {
|
||||||
|
users := []User{
|
||||||
|
*GetUser("find_in_batches_with_offset_limit", Config{}),
|
||||||
|
*GetUser("find_in_batches_with_offset_limit", Config{}),
|
||||||
|
*GetUser("find_in_batches_with_offset_limit", Config{}),
|
||||||
|
*GetUser("find_in_batches_with_offset_limit", Config{}),
|
||||||
|
*GetUser("find_in_batches_with_offset_limit", Config{}),
|
||||||
|
*GetUser("find_in_batches_with_offset_limit", Config{}),
|
||||||
|
*GetUser("find_in_batches_with_offset_limit", Config{}),
|
||||||
|
*GetUser("find_in_batches_with_offset_limit", Config{}),
|
||||||
|
*GetUser("find_in_batches_with_offset_limit", Config{}),
|
||||||
|
*GetUser("find_in_batches_with_offset_limit", Config{}),
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Create(&users)
|
||||||
|
|
||||||
|
var (
|
||||||
|
sub, results []User
|
||||||
|
lastBatch int
|
||||||
|
)
|
||||||
|
|
||||||
|
// offset limit
|
||||||
|
if result := DB.Offset(3).Limit(5).Where("name = ?", users[0].Name).FindInBatches(&sub, 2, func(tx *gorm.DB, batch int) error {
|
||||||
|
results = append(results, sub...)
|
||||||
|
lastBatch = batch
|
||||||
|
return nil
|
||||||
|
}); result.Error != nil || result.RowsAffected != 5 {
|
||||||
|
t.Errorf("Failed to batch find, got error %v, rows affected: %v", result.Error, result.RowsAffected)
|
||||||
|
}
|
||||||
|
if lastBatch != 3 {
|
||||||
|
t.Fatalf("incorrect last batch, expected: %v, got: %v", 3, lastBatch)
|
||||||
|
}
|
||||||
|
|
||||||
|
targetUsers := users[3:8]
|
||||||
|
for i := 0; i < len(targetUsers); i++ {
|
||||||
|
AssertEqual(t, results[i], targetUsers[i])
|
||||||
|
}
|
||||||
|
|
||||||
|
var sub1 []User
|
||||||
|
// limit < batchSize
|
||||||
|
if result := DB.Limit(5).Where("name = ?", users[0].Name).FindInBatches(&sub1, 10, func(tx *gorm.DB, batch int) error {
|
||||||
|
return nil
|
||||||
|
}); result.Error != nil || result.RowsAffected != 5 {
|
||||||
|
t.Errorf("Failed to batch find, got error %v, rows affected: %v", result.Error, result.RowsAffected)
|
||||||
|
}
|
||||||
|
|
||||||
|
var sub2 []User
|
||||||
|
// only offset
|
||||||
|
if result := DB.Offset(3).Where("name = ?", users[0].Name).FindInBatches(&sub2, 2, func(tx *gorm.DB, batch int) error {
|
||||||
|
return nil
|
||||||
|
}); result.Error != nil || result.RowsAffected != 7 {
|
||||||
|
t.Errorf("Failed to batch find, got error %v, rows affected: %v", result.Error, result.RowsAffected)
|
||||||
|
}
|
||||||
|
|
||||||
|
var sub3 []User
|
||||||
|
if result := DB.Limit(4).Where("name = ?", users[0].Name).FindInBatches(&sub3, 2, func(tx *gorm.DB, batch int) error {
|
||||||
|
return nil
|
||||||
|
}); result.Error != nil || result.RowsAffected != 4 {
|
||||||
|
t.Errorf("Failed to batch find, got error %v, rows affected: %v", result.Error, result.RowsAffected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestFindInBatchesWithError(t *testing.T) {
|
func TestFindInBatchesWithError(t *testing.T) {
|
||||||
if name := DB.Dialector.Name(); name == "sqlserver" {
|
if name := DB.Dialector.Name(); name == "sqlserver" {
|
||||||
t.Skip("skip sqlserver due to it will raise data race for invalid sql")
|
t.Skip("skip sqlserver due to it will raise data race for invalid sql")
|
||||||
@ -512,7 +574,13 @@ func TestNotWithAllFields(t *testing.T) {
|
|||||||
func TestOr(t *testing.T) {
|
func TestOr(t *testing.T) {
|
||||||
dryDB := DB.Session(&gorm.Session{DryRun: true})
|
dryDB := DB.Session(&gorm.Session{DryRun: true})
|
||||||
|
|
||||||
result := dryDB.Where("role = ?", "admin").Where(DB.Or("role = ?", "super_admin")).Find(&User{})
|
var count int64
|
||||||
|
result := dryDB.Model(&User{}).Or("role = ?", "admin").Count(&count)
|
||||||
|
if !regexp.MustCompile("SELECT count\\(\\*\\) FROM .*users.* WHERE role = .+ AND .*users.*\\..*deleted_at.* IS NULL").MatchString(result.Statement.SQL.String()) {
|
||||||
|
t.Fatalf("Build OR condition, but got %v", result.Statement.SQL.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
result = dryDB.Where("role = ?", "admin").Where(DB.Or("role = ?", "super_admin")).Find(&User{})
|
||||||
if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*role.* = .+ AND .*role.* = .+").MatchString(result.Statement.SQL.String()) {
|
if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*role.* = .+ AND .*role.* = .+").MatchString(result.Statement.SQL.String()) {
|
||||||
t.Fatalf("Build OR condition, but got %v", result.Statement.SQL.String())
|
t.Fatalf("Build OR condition, but got %v", result.Statement.SQL.String())
|
||||||
}
|
}
|
||||||
@ -577,7 +645,9 @@ func TestPluck(t *testing.T) {
|
|||||||
if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Order("name desc").Pluck("name", &names2).Error; err != nil {
|
if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Order("name desc").Pluck("name", &names2).Error; err != nil {
|
||||||
t.Errorf("got error when pluck name: %v", err)
|
t.Errorf("got error when pluck name: %v", err)
|
||||||
}
|
}
|
||||||
AssertEqual(t, names, sort.Reverse(sort.StringSlice(names2)))
|
|
||||||
|
sort.Slice(names2, func(i, j int) bool { return names2[i] < names2[j] })
|
||||||
|
AssertEqual(t, names, names2)
|
||||||
|
|
||||||
var ids []int
|
var ids []int
|
||||||
if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Pluck("id", &ids).Error; err != nil {
|
if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Pluck("id", &ids).Error; err != nil {
|
||||||
@ -1152,3 +1222,39 @@ func TestQueryWithTableAndConditionsAndAllFields(t *testing.T) {
|
|||||||
t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String())
|
t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type DoubleInt64 struct {
|
||||||
|
data int64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *DoubleInt64) Scan(val interface{}) error {
|
||||||
|
switch v := val.(type) {
|
||||||
|
case int64:
|
||||||
|
t.data = v * 2
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("DoubleInt64 cant not scan with:%v", v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// https://github.com/go-gorm/gorm/issues/5091
|
||||||
|
func TestQueryScannerWithSingleColumn(t *testing.T) {
|
||||||
|
user := User{Name: "scanner_raw_1", Age: 10}
|
||||||
|
DB.Create(&user)
|
||||||
|
|
||||||
|
var result1 DoubleInt64
|
||||||
|
if err := DB.Model(&User{}).Where("name LIKE ?", "scanner_raw_%").Limit(1).Pluck(
|
||||||
|
"age", &result1).Error; err != nil {
|
||||||
|
t.Errorf("Failed, got error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
AssertEqual(t, result1.data, 20)
|
||||||
|
|
||||||
|
var result2 DoubleInt64
|
||||||
|
if err := DB.Model(&User{}).Where("name LIKE ?", "scanner_raw_%").Limit(1).Select(
|
||||||
|
"age").Scan(&result2).Error; err != nil {
|
||||||
|
t.Errorf("Failed, got error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
AssertEqual(t, result2.data, 20)
|
||||||
|
}
|
||||||
|
@ -10,6 +10,11 @@ import (
|
|||||||
. "gorm.io/gorm/utils/tests"
|
. "gorm.io/gorm/utils/tests"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type PersonAddressInfo struct {
|
||||||
|
Person *Person `gorm:"embedded"`
|
||||||
|
Address *Address `gorm:"embedded"`
|
||||||
|
}
|
||||||
|
|
||||||
func TestScan(t *testing.T) {
|
func TestScan(t *testing.T) {
|
||||||
user1 := User{Name: "ScanUser1", Age: 1}
|
user1 := User{Name: "ScanUser1", Age: 1}
|
||||||
user2 := User{Name: "ScanUser2", Age: 10}
|
user2 := User{Name: "ScanUser2", Age: 10}
|
||||||
@ -156,3 +161,57 @@ func TestScanRows(t *testing.T) {
|
|||||||
t.Fatalf("failed to scan ages, got error %v, ages: %v", err, name)
|
t.Fatalf("failed to scan ages, got error %v, ages: %v", err, name)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestScanToEmbedded(t *testing.T) {
|
||||||
|
person1 := Person{Name: "person 1"}
|
||||||
|
person2 := Person{Name: "person 2"}
|
||||||
|
DB.Save(&person1).Save(&person2)
|
||||||
|
|
||||||
|
address1 := Address{Name: "address 1"}
|
||||||
|
address2 := Address{Name: "address 2"}
|
||||||
|
DB.Save(&address1).Save(&address2)
|
||||||
|
|
||||||
|
DB.Create(&PersonAddress{PersonID: person1.ID, AddressID: int(address1.ID)})
|
||||||
|
DB.Create(&PersonAddress{PersonID: person1.ID, AddressID: int(address2.ID)})
|
||||||
|
DB.Create(&PersonAddress{PersonID: person2.ID, AddressID: int(address1.ID)})
|
||||||
|
|
||||||
|
var personAddressInfoList []*PersonAddressInfo
|
||||||
|
if err := DB.Select("people.*, addresses.*").
|
||||||
|
Table("people").
|
||||||
|
Joins("inner join person_addresses on people.id = person_addresses.person_id").
|
||||||
|
Joins("inner join addresses on person_addresses.address_id = addresses.id").
|
||||||
|
Find(&personAddressInfoList).Error; err != nil {
|
||||||
|
t.Errorf("Failed to run join query, got error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
personMatched := false
|
||||||
|
addressMatched := false
|
||||||
|
|
||||||
|
for _, info := range personAddressInfoList {
|
||||||
|
if info.Person == nil {
|
||||||
|
t.Fatalf("Failed, expected not nil, got person nil")
|
||||||
|
}
|
||||||
|
if info.Address == nil {
|
||||||
|
t.Fatalf("Failed, expected not nil, got address nil")
|
||||||
|
}
|
||||||
|
if info.Person.ID == person1.ID {
|
||||||
|
personMatched = true
|
||||||
|
if info.Person.Name != person1.Name {
|
||||||
|
t.Errorf("Failed, expected %v, got %v", person1.Name, info.Person.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if info.Address.ID == address1.ID {
|
||||||
|
addressMatched = true
|
||||||
|
if info.Address.Name != address1.Name {
|
||||||
|
t.Errorf("Failed, expected %v, got %v", address1.Name, info.Address.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !personMatched {
|
||||||
|
t.Errorf("Failed, no person matched")
|
||||||
|
}
|
||||||
|
if !addressMatched {
|
||||||
|
t.Errorf("Failed, no address matched")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
138
tests/serializer_test.go
Normal file
138
tests/serializer_test.go
Normal file
@ -0,0 +1,138 @@
|
|||||||
|
package tests_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
|
"gorm.io/gorm/schema"
|
||||||
|
. "gorm.io/gorm/utils/tests"
|
||||||
|
)
|
||||||
|
|
||||||
|
type SerializerStruct struct {
|
||||||
|
gorm.Model
|
||||||
|
Name []byte `gorm:"json"`
|
||||||
|
Roles Roles `gorm:"serializer:json"`
|
||||||
|
Contracts map[string]interface{} `gorm:"serializer:json"`
|
||||||
|
JobInfo Job `gorm:"type:bytes;serializer:gob"`
|
||||||
|
CreatedTime int64 `gorm:"serializer:unixtime;type:time"` // store time in db, use int as field type
|
||||||
|
UpdatedTime *int64 `gorm:"serializer:unixtime;type:time"` // store time in db, use int as field type
|
||||||
|
EncryptedString EncryptedString
|
||||||
|
}
|
||||||
|
|
||||||
|
type Roles []string
|
||||||
|
|
||||||
|
type Job struct {
|
||||||
|
Title string
|
||||||
|
Number int
|
||||||
|
Location string
|
||||||
|
IsIntern bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type EncryptedString string
|
||||||
|
|
||||||
|
func (es *EncryptedString) Scan(ctx context.Context, field *schema.Field, dst reflect.Value, dbValue interface{}) (err error) {
|
||||||
|
switch value := dbValue.(type) {
|
||||||
|
case []byte:
|
||||||
|
*es = EncryptedString(bytes.TrimPrefix(value, []byte("hello")))
|
||||||
|
case string:
|
||||||
|
*es = EncryptedString(strings.TrimPrefix(value, "hello"))
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unsupported data %#v", dbValue)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (es EncryptedString) Value(ctx context.Context, field *schema.Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) {
|
||||||
|
return "hello" + string(es), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSerializer(t *testing.T) {
|
||||||
|
DB.Migrator().DropTable(&SerializerStruct{})
|
||||||
|
if err := DB.Migrator().AutoMigrate(&SerializerStruct{}); err != nil {
|
||||||
|
t.Fatalf("no error should happen when migrate scanner, valuer struct, got error %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
createdAt := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||||
|
updatedAt := createdAt.Unix()
|
||||||
|
|
||||||
|
data := SerializerStruct{
|
||||||
|
Name: []byte("jinzhu"),
|
||||||
|
Roles: []string{"r1", "r2"},
|
||||||
|
Contracts: map[string]interface{}{"name": "jinzhu", "age": 10},
|
||||||
|
EncryptedString: EncryptedString("pass"),
|
||||||
|
CreatedTime: createdAt.Unix(),
|
||||||
|
UpdatedTime: &updatedAt,
|
||||||
|
JobInfo: Job{
|
||||||
|
Title: "programmer",
|
||||||
|
Number: 9920,
|
||||||
|
Location: "Kenmawr",
|
||||||
|
IsIntern: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.Create(&data).Error; err != nil {
|
||||||
|
t.Fatalf("failed to create data, got error %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result SerializerStruct
|
||||||
|
if err := DB.First(&result, data.ID).Error; err != nil {
|
||||||
|
t.Fatalf("failed to query data, got error %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
AssertEqual(t, result, data)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSerializerAssignFirstOrCreate(t *testing.T) {
|
||||||
|
DB.Migrator().DropTable(&SerializerStruct{})
|
||||||
|
if err := DB.Migrator().AutoMigrate(&SerializerStruct{}); err != nil {
|
||||||
|
t.Fatalf("no error should happen when migrate scanner, valuer struct, got error %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
createdAt := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||||
|
|
||||||
|
data := SerializerStruct{
|
||||||
|
Name: []byte("ag9920"),
|
||||||
|
Roles: []string{"r1", "r2"},
|
||||||
|
Contracts: map[string]interface{}{"name": "jing1", "age": 11},
|
||||||
|
EncryptedString: EncryptedString("pass"),
|
||||||
|
CreatedTime: createdAt.Unix(),
|
||||||
|
JobInfo: Job{
|
||||||
|
Title: "programmer",
|
||||||
|
Number: 9920,
|
||||||
|
Location: "Shadyside",
|
||||||
|
IsIntern: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// first time insert record
|
||||||
|
out := SerializerStruct{}
|
||||||
|
if err := DB.Assign(data).FirstOrCreate(&out).Error; err != nil {
|
||||||
|
t.Fatalf("failed to FirstOrCreate Assigned data, got error %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result SerializerStruct
|
||||||
|
if err := DB.First(&result, out.ID).Error; err != nil {
|
||||||
|
t.Fatalf("failed to query data, got error %v", err)
|
||||||
|
}
|
||||||
|
AssertEqual(t, result, out)
|
||||||
|
|
||||||
|
//update record
|
||||||
|
data.Roles = append(data.Roles, "r3")
|
||||||
|
data.JobInfo.Location = "Gates Hillman Complex"
|
||||||
|
if err := DB.Assign(data).FirstOrCreate(&out).Error; err != nil {
|
||||||
|
t.Fatalf("failed to FirstOrCreate Assigned data, got error %v", err)
|
||||||
|
}
|
||||||
|
if err := DB.First(&result, out.ID).Error; err != nil {
|
||||||
|
t.Fatalf("failed to query data, got error %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
AssertEqual(t, result.Roles, data.Roles)
|
||||||
|
AssertEqual(t, result.JobInfo.Location, data.JobInfo.Location)
|
||||||
|
}
|
@ -168,6 +168,59 @@ func TestDryRun(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ageInt int8
|
||||||
|
|
||||||
|
func (ageInt) String() string {
|
||||||
|
return "age"
|
||||||
|
}
|
||||||
|
|
||||||
|
type ageBool bool
|
||||||
|
|
||||||
|
func (ageBool) String() string {
|
||||||
|
return "age"
|
||||||
|
}
|
||||||
|
|
||||||
|
type ageUint64 uint64
|
||||||
|
|
||||||
|
func (ageUint64) String() string {
|
||||||
|
return "age"
|
||||||
|
}
|
||||||
|
|
||||||
|
type ageFloat float64
|
||||||
|
|
||||||
|
func (ageFloat) String() string {
|
||||||
|
return "age"
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExplainSQL(t *testing.T) {
|
||||||
|
user := *GetUser("explain-sql", Config{})
|
||||||
|
dryRunDB := DB.Session(&gorm.Session{DryRun: true})
|
||||||
|
|
||||||
|
stmt := dryRunDB.Model(&user).Where("id = ?", 1).Updates(map[string]interface{}{"age": ageInt(8)}).Statement
|
||||||
|
sql := DB.Dialector.Explain(stmt.SQL.String(), stmt.Vars...)
|
||||||
|
if !regexp.MustCompile(`.*age.*=8,`).MatchString(sql) {
|
||||||
|
t.Errorf("Failed to generate sql, got %v", sql)
|
||||||
|
}
|
||||||
|
|
||||||
|
stmt = dryRunDB.Model(&user).Where("id = ?", 1).Updates(map[string]interface{}{"age": ageUint64(10241024)}).Statement
|
||||||
|
sql = DB.Dialector.Explain(stmt.SQL.String(), stmt.Vars...)
|
||||||
|
if !regexp.MustCompile(`.*age.*=10241024,`).MatchString(sql) {
|
||||||
|
t.Errorf("Failed to generate sql, got %v", sql)
|
||||||
|
}
|
||||||
|
|
||||||
|
stmt = dryRunDB.Model(&user).Where("id = ?", 1).Updates(map[string]interface{}{"age": ageBool(false)}).Statement
|
||||||
|
sql = DB.Dialector.Explain(stmt.SQL.String(), stmt.Vars...)
|
||||||
|
if !regexp.MustCompile(`.*age.*=false,`).MatchString(sql) {
|
||||||
|
t.Errorf("Failed to generate sql, got %v", sql)
|
||||||
|
}
|
||||||
|
|
||||||
|
stmt = dryRunDB.Model(&user).Where("id = ?", 1).Updates(map[string]interface{}{"age": ageFloat(0.12345678)}).Statement
|
||||||
|
sql = DB.Dialector.Explain(stmt.SQL.String(), stmt.Vars...)
|
||||||
|
if !regexp.MustCompile(`.*age.*=0.123457,`).MatchString(sql) {
|
||||||
|
t.Errorf("Failed to generate sql, got %v", sql)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestGroupConditions(t *testing.T) {
|
func TestGroupConditions(t *testing.T) {
|
||||||
type Pizza struct {
|
type Pizza struct {
|
||||||
ID uint
|
ID uint
|
||||||
@ -190,6 +243,21 @@ func TestGroupConditions(t *testing.T) {
|
|||||||
if !strings.HasSuffix(result, expects) {
|
if !strings.HasSuffix(result, expects) {
|
||||||
t.Errorf("expects: %v, got %v", expects, result)
|
t.Errorf("expects: %v, got %v", expects, result)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
stmt2 := dryRunDB.Where(
|
||||||
|
DB.Scopes(NameIn1And2),
|
||||||
|
).Or(
|
||||||
|
DB.Where("pizza = ?", "hawaiian").Where("size = ?", "xlarge"),
|
||||||
|
).Find(&Pizza{}).Statement
|
||||||
|
|
||||||
|
execStmt2 := dryRunDB.Exec(`WHERE name in ? OR (pizza = ? AND size = ?)`, []string{"ScopeUser1", "ScopeUser2"}, "hawaiian", "xlarge").Statement
|
||||||
|
|
||||||
|
result2 := DB.Dialector.Explain(stmt2.SQL.String(), stmt2.Vars...)
|
||||||
|
expects2 := DB.Dialector.Explain(execStmt2.SQL.String(), execStmt2.Vars...)
|
||||||
|
|
||||||
|
if !strings.HasSuffix(result2, expects2) {
|
||||||
|
t.Errorf("expects: %v, got %v", expects2, result2)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCombineStringConditions(t *testing.T) {
|
func TestCombineStringConditions(t *testing.T) {
|
||||||
@ -307,7 +375,7 @@ func TestToSQL(t *testing.T) {
|
|||||||
})
|
})
|
||||||
assertEqualSQL(t, `SELECT * FROM "users" WHERE id = 100 AND "users"."deleted_at" IS NULL ORDER BY age desc LIMIT 10`, sql)
|
assertEqualSQL(t, `SELECT * FROM "users" WHERE id = 100 AND "users"."deleted_at" IS NULL ORDER BY age desc LIMIT 10`, sql)
|
||||||
|
|
||||||
// after model chagned
|
// after model changed
|
||||||
if DB.Statement.DryRun || DB.DryRun {
|
if DB.Statement.DryRun || DB.DryRun {
|
||||||
t.Fatal("Failed expect DB.DryRun and DB.Statement.ToSQL to be false")
|
t.Fatal("Failed expect DB.DryRun and DB.Statement.ToSQL to be false")
|
||||||
}
|
}
|
||||||
@ -373,13 +441,13 @@ func TestToSQL(t *testing.T) {
|
|||||||
})
|
})
|
||||||
assertEqualSQL(t, `UPDATE "users" SET "name"='Foo',"age"=100 WHERE id = 100 AND "users"."deleted_at" IS NULL`, sql)
|
assertEqualSQL(t, `UPDATE "users" SET "name"='Foo',"age"=100 WHERE id = 100 AND "users"."deleted_at" IS NULL`, sql)
|
||||||
|
|
||||||
// after model chagned
|
// after model changed
|
||||||
if DB.Statement.DryRun || DB.DryRun {
|
if DB.Statement.DryRun || DB.DryRun {
|
||||||
t.Fatal("Failed expect DB.DryRun and DB.Statement.ToSQL to be false")
|
t.Fatal("Failed expect DB.DryRun and DB.Statement.ToSQL to be false")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// assertEqualSQL for assert that the sql is equal, this method will ignore quote, and dialect speicals.
|
// assertEqualSQL for assert that the sql is equal, this method will ignore quote, and dialect specials.
|
||||||
func assertEqualSQL(t *testing.T, expected string, actually string) {
|
func assertEqualSQL(t *testing.T, expected string, actually string) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
@ -387,7 +455,7 @@ func assertEqualSQL(t *testing.T, expected string, actually string) {
|
|||||||
expected = replaceQuoteInSQL(expected)
|
expected = replaceQuoteInSQL(expected)
|
||||||
actually = replaceQuoteInSQL(actually)
|
actually = replaceQuoteInSQL(actually)
|
||||||
|
|
||||||
// ignore updated_at value, becase it's generated in Gorm inernal, can't to mock value on update.
|
// ignore updated_at value, because it's generated in Gorm internal, can't to mock value on update.
|
||||||
updatedAtRe := regexp.MustCompile(`(?i)"updated_at"=".+?"`)
|
updatedAtRe := regexp.MustCompile(`(?i)"updated_at"=".+?"`)
|
||||||
actually = updatedAtRe.ReplaceAllString(actually, `"updated_at"=?`)
|
actually = updatedAtRe.ReplaceAllString(actually, `"updated_at"=?`)
|
||||||
expected = updatedAtRe.ReplaceAllString(expected, `"updated_at"=?`)
|
expected = updatedAtRe.ReplaceAllString(expected, `"updated_at"=?`)
|
||||||
@ -407,16 +475,16 @@ func assertEqualSQL(t *testing.T, expected string, actually string) {
|
|||||||
|
|
||||||
func replaceQuoteInSQL(sql string) string {
|
func replaceQuoteInSQL(sql string) string {
|
||||||
// convert single quote into double quote
|
// convert single quote into double quote
|
||||||
sql = strings.Replace(sql, `'`, `"`, -1)
|
sql = strings.ReplaceAll(sql, `'`, `"`)
|
||||||
|
|
||||||
// convert dialect speical quote into double quote
|
// convert dialect special quote into double quote
|
||||||
switch DB.Dialector.Name() {
|
switch DB.Dialector.Name() {
|
||||||
case "postgres":
|
case "postgres":
|
||||||
sql = strings.Replace(sql, `"`, `"`, -1)
|
sql = strings.ReplaceAll(sql, `"`, `"`)
|
||||||
case "mysql", "sqlite":
|
case "mysql", "sqlite":
|
||||||
sql = strings.Replace(sql, "`", `"`, -1)
|
sql = strings.ReplaceAll(sql, "`", `"`)
|
||||||
case "sqlserver":
|
case "sqlserver":
|
||||||
sql = strings.Replace(sql, `'`, `"`, -1)
|
sql = strings.ReplaceAll(sql, `'`, `"`)
|
||||||
}
|
}
|
||||||
|
|
||||||
return sql
|
return sql
|
||||||
|
@ -15,6 +15,25 @@ then
|
|||||||
cd ..
|
cd ..
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
# SqlServer for Mac M1
|
||||||
|
if [[ -z $GITHUB_ACTION ]]; then
|
||||||
|
if [ -d tests ]
|
||||||
|
then
|
||||||
|
cd tests
|
||||||
|
if [[ $(uname -a) == *" arm64" ]]; then
|
||||||
|
MSSQL_IMAGE=mcr.microsoft.com/azure-sql-edge docker-compose start || true
|
||||||
|
go install github.com/microsoft/go-sqlcmd/cmd/sqlcmd@latest || true
|
||||||
|
SQLCMDPASSWORD=LoremIpsum86 sqlcmd -U sa -S localhost:9930 -Q "IF DB_ID('gorm') IS NULL CREATE DATABASE gorm" > /dev/null || true
|
||||||
|
SQLCMDPASSWORD=LoremIpsum86 sqlcmd -U sa -S localhost:9930 -Q "IF SUSER_ID (N'gorm') IS NULL CREATE LOGIN gorm WITH PASSWORD = 'LoremIpsum86';" > /dev/null || true
|
||||||
|
SQLCMDPASSWORD=LoremIpsum86 sqlcmd -U sa -S localhost:9930 -Q "IF USER_ID (N'gorm') IS NULL CREATE USER gorm FROM LOGIN gorm; ALTER SERVER ROLE sysadmin ADD MEMBER [gorm];" > /dev/null || true
|
||||||
|
else
|
||||||
|
docker-compose start
|
||||||
|
fi
|
||||||
|
cd ..
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
|
||||||
for dialect in "${dialects[@]}" ; do
|
for dialect in "${dialects[@]}" ; do
|
||||||
if [ "$GORM_DIALECT" = "" ] || [ "$GORM_DIALECT" = "${dialect}" ]
|
if [ "$GORM_DIALECT" = "" ] || [ "$GORM_DIALECT" = "${dialect}" ]
|
||||||
then
|
then
|
||||||
|
@ -62,13 +62,14 @@ func OpenTestConnection() (db *gorm.DB, err error) {
|
|||||||
PreferSimpleProtocol: true,
|
PreferSimpleProtocol: true,
|
||||||
}), &gorm.Config{})
|
}), &gorm.Config{})
|
||||||
case "sqlserver":
|
case "sqlserver":
|
||||||
// CREATE LOGIN gorm WITH PASSWORD = 'LoremIpsum86';
|
// go install github.com/microsoft/go-sqlcmd/cmd/sqlcmd@latest
|
||||||
|
// SQLCMDPASSWORD=LoremIpsum86 sqlcmd -U sa -S localhost:9930
|
||||||
// CREATE DATABASE gorm;
|
// CREATE DATABASE gorm;
|
||||||
// USE gorm;
|
// GO
|
||||||
|
// CREATE LOGIN gorm WITH PASSWORD = 'LoremIpsum86';
|
||||||
// CREATE USER gorm FROM LOGIN gorm;
|
// CREATE USER gorm FROM LOGIN gorm;
|
||||||
// sp_changedbowner 'gorm';
|
// ALTER SERVER ROLE sysadmin ADD MEMBER [gorm];
|
||||||
// npm install -g sql-cli
|
// GO
|
||||||
// mssql -u gorm -p LoremIpsum86 -d gorm -o 9930
|
|
||||||
log.Println("testing sqlserver...")
|
log.Println("testing sqlserver...")
|
||||||
if dbDSN == "" {
|
if dbDSN == "" {
|
||||||
dbDSN = "sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm"
|
dbDSN = "sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm"
|
||||||
@ -94,7 +95,7 @@ func OpenTestConnection() (db *gorm.DB, err error) {
|
|||||||
|
|
||||||
func RunMigrations() {
|
func RunMigrations() {
|
||||||
var err error
|
var err error
|
||||||
allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Coupon{}, &CouponProduct{}, &Order{}}
|
allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Coupon{}, &CouponProduct{}, &Order{}, &Parent{}, &Child{}}
|
||||||
rand.Seed(time.Now().UnixNano())
|
rand.Seed(time.Now().UnixNano())
|
||||||
rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] })
|
rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] })
|
||||||
|
|
||||||
|
@ -645,7 +645,7 @@ func TestSave(t *testing.T) {
|
|||||||
|
|
||||||
dryDB := DB.Session(&gorm.Session{DryRun: true})
|
dryDB := DB.Session(&gorm.Session{DryRun: true})
|
||||||
stmt := dryDB.Save(&user).Statement
|
stmt := dryDB.Save(&user).Statement
|
||||||
if !regexp.MustCompile(`.id. = .* AND .users.\..deleted_at. IS NULL`).MatchString(stmt.SQL.String()) {
|
if !regexp.MustCompile(`.users.\..deleted_at. IS NULL`).MatchString(stmt.SQL.String()) {
|
||||||
t.Fatalf("invalid updating SQL, got %v", stmt.SQL.String())
|
t.Fatalf("invalid updating SQL, got %v", stmt.SQL.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -319,7 +319,7 @@ func TestUpdateWithMissWhere(t *testing.T) {
|
|||||||
tx := DB.Session(&gorm.Session{DryRun: true}).Save(&user)
|
tx := DB.Session(&gorm.Session{DryRun: true}).Save(&user)
|
||||||
|
|
||||||
if err := tx.Error; err != nil {
|
if err := tx.Error; err != nil {
|
||||||
t.Fatalf("failed to update user,missing where condtion,err=%+v", err)
|
t.Fatalf("failed to update user,missing where condition,err=%+v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !regexp.MustCompile("WHERE .id. = [^ ]+$").MatchString(tx.Statement.SQL.String()) {
|
if !regexp.MustCompile("WHERE .id. = [^ ]+$").MatchString(tx.Statement.SQL.String()) {
|
||||||
|
@ -49,7 +49,7 @@ func (DummyDialector) QuoteTo(writer clause.Writer, str string) {
|
|||||||
shiftDelimiter = 0
|
shiftDelimiter = 0
|
||||||
underQuoted = false
|
underQuoted = false
|
||||||
continuousBacktick = 0
|
continuousBacktick = 0
|
||||||
writer.WriteString("`")
|
writer.WriteByte('`')
|
||||||
}
|
}
|
||||||
writer.WriteByte(v)
|
writer.WriteByte(v)
|
||||||
continue
|
continue
|
||||||
@ -74,7 +74,7 @@ func (DummyDialector) QuoteTo(writer clause.Writer, str string) {
|
|||||||
if continuousBacktick > 0 && !selfQuoted {
|
if continuousBacktick > 0 && !selfQuoted {
|
||||||
writer.WriteString("``")
|
writer.WriteString("``")
|
||||||
}
|
}
|
||||||
writer.WriteString("`")
|
writer.WriteByte('`')
|
||||||
}
|
}
|
||||||
|
|
||||||
func (DummyDialector) Explain(sql string, vars ...interface{}) string {
|
func (DummyDialector) Explain(sql string, vars ...interface{}) string {
|
||||||
|
@ -80,3 +80,17 @@ type Order struct {
|
|||||||
Coupon *Coupon
|
Coupon *Coupon
|
||||||
CouponID string
|
CouponID string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type Parent struct {
|
||||||
|
gorm.Model
|
||||||
|
FavChildID uint
|
||||||
|
FavChild *Child
|
||||||
|
Children []*Child
|
||||||
|
}
|
||||||
|
|
||||||
|
type Child struct {
|
||||||
|
gorm.Model
|
||||||
|
Name string
|
||||||
|
ParentID *uint
|
||||||
|
Parent *Parent
|
||||||
|
}
|
||||||
|
@ -83,6 +83,7 @@ func AssertEqual(t *testing.T, got, expect interface{}) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if reflect.ValueOf(got).Kind() == reflect.Struct {
|
if reflect.ValueOf(got).Kind() == reflect.Struct {
|
||||||
|
if reflect.ValueOf(expect).Kind() == reflect.Struct {
|
||||||
if reflect.ValueOf(got).NumField() == reflect.ValueOf(expect).NumField() {
|
if reflect.ValueOf(got).NumField() == reflect.ValueOf(expect).NumField() {
|
||||||
exported := false
|
exported := false
|
||||||
for i := 0; i < reflect.ValueOf(got).NumField(); i++ {
|
for i := 0; i < reflect.ValueOf(got).NumField(); i++ {
|
||||||
@ -100,6 +101,7 @@ func AssertEqual(t *testing.T, got, expect interface{}) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if reflect.ValueOf(got).Type().ConvertibleTo(reflect.ValueOf(expect).Type()) {
|
if reflect.ValueOf(got).Type().ConvertibleTo(reflect.ValueOf(expect).Type()) {
|
||||||
got = reflect.ValueOf(got).Convert(reflect.ValueOf(expect).Type()).Interface()
|
got = reflect.ValueOf(got).Convert(reflect.ValueOf(expect).Type()).Interface()
|
||||||
@ -107,6 +109,9 @@ func AssertEqual(t *testing.T, got, expect interface{}) {
|
|||||||
} else if reflect.ValueOf(expect).Type().ConvertibleTo(reflect.ValueOf(got).Type()) {
|
} else if reflect.ValueOf(expect).Type().ConvertibleTo(reflect.ValueOf(got).Type()) {
|
||||||
expect = reflect.ValueOf(got).Convert(reflect.ValueOf(got).Type()).Interface()
|
expect = reflect.ValueOf(got).Convert(reflect.ValueOf(got).Type()).Interface()
|
||||||
isEqual()
|
isEqual()
|
||||||
|
} else {
|
||||||
|
t.Errorf("%v: expect: %+v, got %+v", utils.FileWithLineNum(), expect, got)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -36,17 +36,14 @@ func IsValidDBNameChar(c rune) bool {
|
|||||||
return !unicode.IsLetter(c) && !unicode.IsNumber(c) && c != '.' && c != '*' && c != '_' && c != '$' && c != '@'
|
return !unicode.IsLetter(c) && !unicode.IsNumber(c) && c != '.' && c != '*' && c != '_' && c != '$' && c != '@'
|
||||||
}
|
}
|
||||||
|
|
||||||
func CheckTruth(val interface{}) bool {
|
// CheckTruth check string true or not
|
||||||
if v, ok := val.(bool); ok {
|
func CheckTruth(vals ...string) bool {
|
||||||
return v
|
for _, val := range vals {
|
||||||
|
if val != "" && !strings.EqualFold(val, "false") {
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
if v, ok := val.(string); ok {
|
|
||||||
v = strings.ToLower(v)
|
|
||||||
return v != "false"
|
|
||||||
}
|
}
|
||||||
|
return false
|
||||||
return !reflect.ValueOf(val).IsZero()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func ToStringKey(values ...interface{}) string {
|
func ToStringKey(values ...interface{}) string {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user