Merge branch 'go-gorm:master' into master

This commit is contained in:
piyongcai 2022-05-09 16:15:34 +08:00 committed by GitHub
commit e0030749b6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
79 changed files with 2728 additions and 761 deletions

View File

@ -3,20 +3,26 @@ on:
schedule: schedule:
- cron: "*/10 * * * *" - cron: "*/10 * * * *"
permissions:
contents: read
jobs: jobs:
stale: stale:
permissions:
issues: write # for actions/stale to close stale issues
pull-requests: write # for actions/stale to close stale PRs
runs-on: ubuntu-latest runs-on: ubuntu-latest
env: env:
ACTIONS_STEP_DEBUG: true ACTIONS_STEP_DEBUG: true
steps: steps:
- name: Close Stale Issues - name: Close Stale Issues
uses: actions/stale@v4 uses: actions/stale@v5
with: with:
repo-token: ${{ secrets.GITHUB_TOKEN }} repo-token: ${{ secrets.GITHUB_TOKEN }}
stale-issue-message: "This issue has been marked as invalid question, please give more information by following the `Question` template, if you believe there is a bug of GORM, please create a pull request that could reproduce the issue on [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground), the issue will be closed in 2 days if no further activity occurs. most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨" stale-issue-message: "This issue has been marked as invalid question, please give more information by following the `Question` template, if you believe there is a bug of GORM, please create a pull request that could reproduce the issue on [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground), the issue will be closed in 30 days if no further activity occurs. most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨"
stale-issue-label: "status:stale" stale-issue-label: "status:stale"
days-before-stale: 0 days-before-stale: 0
days-before-close: 2 days-before-close: 30
remove-stale-when-updated: true remove-stale-when-updated: true
only-labels: "type:invalid question" only-labels: "type:invalid question"

View File

@ -11,7 +11,7 @@ jobs:
name: Label issues and pull requests name: Label issues and pull requests
steps: steps:
- name: check out - name: check out
uses: actions/checkout@v2 uses: actions/checkout@v3
- name: labeler - name: labeler
uses: jinzhu/super-labeler-action@develop uses: jinzhu/super-labeler-action@develop

View File

@ -3,19 +3,25 @@ on:
schedule: schedule:
- cron: "*/10 * * * *" - cron: "*/10 * * * *"
permissions:
contents: read
jobs: jobs:
stale: stale:
permissions:
issues: write # for actions/stale to close stale issues
pull-requests: write # for actions/stale to close stale PRs
runs-on: ubuntu-latest runs-on: ubuntu-latest
env: env:
ACTIONS_STEP_DEBUG: true ACTIONS_STEP_DEBUG: true
steps: steps:
- name: Close Stale Issues - name: Close Stale Issues
uses: actions/stale@v4 uses: actions/stale@v5
with: with:
repo-token: ${{ secrets.GITHUB_TOKEN }} repo-token: ${{ secrets.GITHUB_TOKEN }}
stale-issue-message: "The issue has been automatically marked as stale as it missing playground pull request link, which is important to help others understand your issue effectively and make sure the issue hasn't been fixed on latest master, checkout [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground) for details. it will be closed in 2 days if no further activity occurs. if you are asking question, please use the `Question` template, most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨" stale-issue-message: "The issue has been automatically marked as stale as it missing playground pull request link, which is important to help others understand your issue effectively and make sure the issue hasn't been fixed on latest master, checkout [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground) for details. it will be closed in 30 days if no further activity occurs. if you are asking question, please use the `Question` template, most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨"
stale-issue-label: "status:stale" stale-issue-label: "status:stale"
days-before-stale: 0 days-before-stale: 0
days-before-close: 2 days-before-close: 30
remove-stale-when-updated: true remove-stale-when-updated: true
only-labels: "type:missing reproduction steps" only-labels: "type:missing reproduction steps"

View File

@ -6,7 +6,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Check out code into the Go module directory - name: Check out code into the Go module directory
uses: actions/checkout@v2 uses: actions/checkout@v3
- name: golangci-lint - name: golangci-lint
uses: reviewdog/action-golangci-lint@v2 uses: reviewdog/action-golangci-lint@v2

View File

@ -3,19 +3,25 @@ on:
schedule: schedule:
- cron: "0 2 * * *" - cron: "0 2 * * *"
permissions:
contents: read
jobs: jobs:
stale: stale:
permissions:
issues: write # for actions/stale to close stale issues
pull-requests: write # for actions/stale to close stale PRs
runs-on: ubuntu-latest runs-on: ubuntu-latest
env: env:
ACTIONS_STEP_DEBUG: true ACTIONS_STEP_DEBUG: true
steps: steps:
- name: Close Stale Issues - name: Close Stale Issues
uses: actions/stale@v4 uses: actions/stale@v5
with: with:
repo-token: ${{ secrets.GITHUB_TOKEN }} repo-token: ${{ secrets.GITHUB_TOKEN }}
stale-issue-message: "This issue has been automatically marked as stale because it has been open 60 days with no activity. Remove stale label or comment or this will be closed in 30 days" stale-issue-message: "This issue has been automatically marked as stale because it has been open 360 days with no activity. Remove stale label or comment or this will be closed in 180 days"
days-before-stale: 60 days-before-stale: 360
days-before-close: 30 days-before-close: 180
stale-issue-label: "status:stale" stale-issue-label: "status:stale"
exempt-issue-labels: 'type:feature,type:with reproduction steps,type:has pull request' exempt-issue-labels: 'type:feature,type:with reproduction steps,type:has pull request'
stale-pr-label: 'status:stale' stale-pr-label: 'status:stale'

View File

@ -8,38 +8,41 @@ on:
branches-ignore: branches-ignore:
- 'gh-pages' - 'gh-pages'
permissions:
contents: read
jobs: jobs:
# Label of the container job # Label of the container job
sqlite: sqlite:
strategy: strategy:
matrix: matrix:
go: ['1.17', '1.16'] go: ['1.18', '1.17', '1.16']
platform: [ubuntu-latest] # can not run in windows OS platform: [ubuntu-latest] # can not run in windows OS
runs-on: ${{ matrix.platform }} runs-on: ${{ matrix.platform }}
steps: steps:
- name: Set up Go 1.x - name: Set up Go 1.x
uses: actions/setup-go@v2 uses: actions/setup-go@v3
with: with:
go-version: ${{ matrix.go }} go-version: ${{ matrix.go }}
- name: Check out code into the Go module directory - name: Check out code into the Go module directory
uses: actions/checkout@v2 uses: actions/checkout@v3
- name: go mod package cache - name: go mod package cache
uses: actions/cache@v2 uses: actions/cache@v3
with: with:
path: ~/go/pkg/mod path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
- name: Tests - name: Tests
run: GORM_DIALECT=sqlite ./tests/tests_all.sh run: GITHUB_ACTION=true GORM_DIALECT=sqlite ./tests/tests_all.sh
mysql: mysql:
strategy: strategy:
matrix: matrix:
dbversion: ['mysql:latest', 'mysql:5.7', 'mariadb:latest'] dbversion: ['mysql:latest', 'mysql:5.7', 'mariadb:latest']
go: ['1.17', '1.16'] go: ['1.18', '1.17', '1.16']
platform: [ubuntu-latest] platform: [ubuntu-latest]
runs-on: ${{ matrix.platform }} runs-on: ${{ matrix.platform }}
@ -62,28 +65,28 @@ jobs:
steps: steps:
- name: Set up Go 1.x - name: Set up Go 1.x
uses: actions/setup-go@v2 uses: actions/setup-go@v3
with: with:
go-version: ${{ matrix.go }} go-version: ${{ matrix.go }}
- name: Check out code into the Go module directory - name: Check out code into the Go module directory
uses: actions/checkout@v2 uses: actions/checkout@v3
- name: go mod package cache - name: go mod package cache
uses: actions/cache@v2 uses: actions/cache@v3
with: with:
path: ~/go/pkg/mod path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
- name: Tests - name: Tests
run: GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True" ./tests/tests_all.sh run: GITHUB_ACTION=true GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True" ./tests/tests_all.sh
postgres: postgres:
strategy: strategy:
matrix: matrix:
dbversion: ['postgres:latest', 'postgres:13', 'postgres:12', 'postgres:11', 'postgres:10'] dbversion: ['postgres:latest', 'postgres:13', 'postgres:12', 'postgres:11', 'postgres:10']
go: ['1.17', '1.16'] go: ['1.18', '1.17', '1.16']
platform: [ubuntu-latest] # can not run in macOS and Windows platform: [ubuntu-latest] # can not run in macOS and Windows
runs-on: ${{ matrix.platform }} runs-on: ${{ matrix.platform }}
@ -106,26 +109,26 @@ jobs:
steps: steps:
- name: Set up Go 1.x - name: Set up Go 1.x
uses: actions/setup-go@v2 uses: actions/setup-go@v3
with: with:
go-version: ${{ matrix.go }} go-version: ${{ matrix.go }}
- name: Check out code into the Go module directory - name: Check out code into the Go module directory
uses: actions/checkout@v2 uses: actions/checkout@v3
- name: go mod package cache - name: go mod package cache
uses: actions/cache@v2 uses: actions/cache@v3
with: with:
path: ~/go/pkg/mod path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
- name: Tests - name: Tests
run: GORM_DIALECT=postgres GORM_DSN="user=gorm password=gorm dbname=gorm host=localhost port=9920 sslmode=disable TimeZone=Asia/Shanghai" ./tests/tests_all.sh run: GITHUB_ACTION=true GORM_DIALECT=postgres GORM_DSN="user=gorm password=gorm dbname=gorm host=localhost port=9920 sslmode=disable TimeZone=Asia/Shanghai" ./tests/tests_all.sh
sqlserver: sqlserver:
strategy: strategy:
matrix: matrix:
go: ['1.17', '1.16'] go: ['1.18', '1.17', '1.16']
platform: [ubuntu-latest] # can not run test in macOS and windows platform: [ubuntu-latest] # can not run test in macOS and windows
runs-on: ${{ matrix.platform }} runs-on: ${{ matrix.platform }}
@ -149,18 +152,18 @@ jobs:
steps: steps:
- name: Set up Go 1.x - name: Set up Go 1.x
uses: actions/setup-go@v2 uses: actions/setup-go@v3
with: with:
go-version: ${{ matrix.go }} go-version: ${{ matrix.go }}
- name: Check out code into the Go module directory - name: Check out code into the Go module directory
uses: actions/checkout@v2 uses: actions/checkout@v3
- name: go mod package cache - name: go mod package cache
uses: actions/cache@v2 uses: actions/cache@v3
with: with:
path: ~/go/pkg/mod path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
- name: Tests - name: Tests
run: GORM_DIALECT=sqlserver GORM_DSN="sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm" ./tests/tests_all.sh run: GITHUB_ACTION=true GORM_DIALECT=sqlserver GORM_DSN="sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm" ./tests/tests_all.sh

1
.gitignore vendored
View File

@ -3,3 +3,4 @@ documents
coverage.txt coverage.txt
_book _book
.idea .idea
vendor

View File

@ -30,6 +30,7 @@ The fantastic ORM library for Golang, aims to be developer friendly.
## Getting Started ## Getting Started
* GORM Guides [https://gorm.io](https://gorm.io) * GORM Guides [https://gorm.io](https://gorm.io)
* GORM Gen [gorm/gen](https://github.com/go-gorm/gen#gormgen)
## Contributing ## Contributing

View File

@ -79,10 +79,10 @@ func (association *Association) Replace(values ...interface{}) error {
switch reflectValue.Kind() { switch reflectValue.Kind() {
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
for i := 0; i < reflectValue.Len(); i++ { for i := 0; i < reflectValue.Len(); i++ {
association.Error = rel.Field.Set(reflectValue.Index(i), reflect.Zero(rel.Field.FieldType).Interface()) association.Error = rel.Field.Set(association.DB.Statement.Context, reflectValue.Index(i), reflect.Zero(rel.Field.FieldType).Interface())
} }
case reflect.Struct: case reflect.Struct:
association.Error = rel.Field.Set(reflectValue, reflect.Zero(rel.Field.FieldType).Interface()) association.Error = rel.Field.Set(association.DB.Statement.Context, reflectValue, reflect.Zero(rel.Field.FieldType).Interface())
} }
for _, ref := range rel.References { for _, ref := range rel.References {
@ -96,12 +96,12 @@ func (association *Association) Replace(values ...interface{}) error {
primaryFields []*schema.Field primaryFields []*schema.Field
foreignKeys []string foreignKeys []string
updateMap = map[string]interface{}{} updateMap = map[string]interface{}{}
relValues = schema.GetRelationsValues(reflectValue, []*schema.Relationship{rel}) relValues = schema.GetRelationsValues(association.DB.Statement.Context, reflectValue, []*schema.Relationship{rel})
modelValue = reflect.New(rel.FieldSchema.ModelType).Interface() modelValue = reflect.New(rel.FieldSchema.ModelType).Interface()
tx = association.DB.Model(modelValue) tx = association.DB.Model(modelValue)
) )
if _, rvs := schema.GetIdentityFieldValuesMap(relValues, rel.FieldSchema.PrimaryFields); len(rvs) > 0 { if _, rvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, relValues, rel.FieldSchema.PrimaryFields); len(rvs) > 0 {
if column, values := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs); len(values) > 0 { if column, values := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs); len(values) > 0 {
tx.Not(clause.IN{Column: column, Values: values}) tx.Not(clause.IN{Column: column, Values: values})
} }
@ -117,7 +117,7 @@ func (association *Association) Replace(values ...interface{}) error {
} }
} }
if _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields); len(pvs) > 0 { if _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields); len(pvs) > 0 {
column, values := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs) column, values := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs)
association.Error = tx.Where(clause.IN{Column: column, Values: values}).UpdateColumns(updateMap).Error association.Error = tx.Where(clause.IN{Column: column, Values: values}).UpdateColumns(updateMap).Error
} }
@ -143,14 +143,14 @@ func (association *Association) Replace(values ...interface{}) error {
} }
} }
_, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields) _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields)
if column, values := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs); len(values) > 0 { if column, values := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs); len(values) > 0 {
tx.Where(clause.IN{Column: column, Values: values}) tx.Where(clause.IN{Column: column, Values: values})
} else { } else {
return ErrPrimaryKeyRequired return ErrPrimaryKeyRequired
} }
_, rvs := schema.GetIdentityFieldValuesMapFromValues(values, relPrimaryFields) _, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, relPrimaryFields)
if relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs); len(relValues) > 0 { if relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs); len(relValues) > 0 {
tx.Where(clause.Not(clause.IN{Column: relColumn, Values: relValues})) tx.Where(clause.Not(clause.IN{Column: relColumn, Values: relValues}))
} }
@ -186,11 +186,14 @@ func (association *Association) Delete(values ...interface{}) error {
case schema.BelongsTo: case schema.BelongsTo:
tx := association.DB.Model(reflect.New(rel.Schema.ModelType).Interface()) tx := association.DB.Model(reflect.New(rel.Schema.ModelType).Interface())
_, pvs := schema.GetIdentityFieldValuesMap(reflectValue, rel.Schema.PrimaryFields) _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, rel.Schema.PrimaryFields)
pcolumn, pvalues := schema.ToQueryValues(rel.Schema.Table, rel.Schema.PrimaryFieldDBNames, pvs) if pcolumn, pvalues := schema.ToQueryValues(rel.Schema.Table, rel.Schema.PrimaryFieldDBNames, pvs); len(pvalues) > 0 {
conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues}) conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues})
} else {
return ErrPrimaryKeyRequired
}
_, rvs := schema.GetIdentityFieldValuesMapFromValues(values, primaryFields) _, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, primaryFields)
relColumn, relValues := schema.ToQueryValues(rel.Schema.Table, foreignKeys, rvs) relColumn, relValues := schema.ToQueryValues(rel.Schema.Table, foreignKeys, rvs)
conds = append(conds, clause.IN{Column: relColumn, Values: relValues}) conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
@ -198,11 +201,14 @@ func (association *Association) Delete(values ...interface{}) error {
case schema.HasOne, schema.HasMany: case schema.HasOne, schema.HasMany:
tx := association.DB.Model(reflect.New(rel.FieldSchema.ModelType).Interface()) tx := association.DB.Model(reflect.New(rel.FieldSchema.ModelType).Interface())
_, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields) _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields)
pcolumn, pvalues := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs) if pcolumn, pvalues := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs); len(pvalues) > 0 {
conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues}) conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues})
} else {
return ErrPrimaryKeyRequired
}
_, rvs := schema.GetIdentityFieldValuesMapFromValues(values, rel.FieldSchema.PrimaryFields) _, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, rel.FieldSchema.PrimaryFields)
relColumn, relValues := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs) relColumn, relValues := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs)
conds = append(conds, clause.IN{Column: relColumn, Values: relValues}) conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
@ -228,11 +234,14 @@ func (association *Association) Delete(values ...interface{}) error {
} }
} }
_, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields) _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields)
pcolumn, pvalues := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs) if pcolumn, pvalues := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs); len(pvalues) > 0 {
conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues}) conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues})
} else {
return ErrPrimaryKeyRequired
}
_, rvs := schema.GetIdentityFieldValuesMapFromValues(values, relPrimaryFields) _, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, relPrimaryFields)
relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs) relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs)
conds = append(conds, clause.IN{Column: relColumn, Values: relValues}) conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
@ -241,11 +250,11 @@ func (association *Association) Delete(values ...interface{}) error {
if association.Error == nil { if association.Error == nil {
// clean up deleted values's foreign key // clean up deleted values's foreign key
relValuesMap, _ := schema.GetIdentityFieldValuesMapFromValues(values, rel.FieldSchema.PrimaryFields) relValuesMap, _ := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, rel.FieldSchema.PrimaryFields)
cleanUpDeletedRelations := func(data reflect.Value) { cleanUpDeletedRelations := func(data reflect.Value) {
if _, zero := rel.Field.ValueOf(data); !zero { if _, zero := rel.Field.ValueOf(association.DB.Statement.Context, data); !zero {
fieldValue := reflect.Indirect(rel.Field.ReflectValueOf(data)) fieldValue := reflect.Indirect(rel.Field.ReflectValueOf(association.DB.Statement.Context, data))
primaryValues := make([]interface{}, len(rel.FieldSchema.PrimaryFields)) primaryValues := make([]interface{}, len(rel.FieldSchema.PrimaryFields))
switch fieldValue.Kind() { switch fieldValue.Kind() {
@ -253,7 +262,7 @@ func (association *Association) Delete(values ...interface{}) error {
validFieldValues := reflect.Zero(rel.Field.IndirectFieldType) validFieldValues := reflect.Zero(rel.Field.IndirectFieldType)
for i := 0; i < fieldValue.Len(); i++ { for i := 0; i < fieldValue.Len(); i++ {
for idx, field := range rel.FieldSchema.PrimaryFields { for idx, field := range rel.FieldSchema.PrimaryFields {
primaryValues[idx], _ = field.ValueOf(fieldValue.Index(i)) primaryValues[idx], _ = field.ValueOf(association.DB.Statement.Context, fieldValue.Index(i))
} }
if _, ok := relValuesMap[utils.ToStringKey(primaryValues...)]; !ok { if _, ok := relValuesMap[utils.ToStringKey(primaryValues...)]; !ok {
@ -261,23 +270,23 @@ func (association *Association) Delete(values ...interface{}) error {
} }
} }
association.Error = rel.Field.Set(data, validFieldValues.Interface()) association.Error = rel.Field.Set(association.DB.Statement.Context, data, validFieldValues.Interface())
case reflect.Struct: case reflect.Struct:
for idx, field := range rel.FieldSchema.PrimaryFields { for idx, field := range rel.FieldSchema.PrimaryFields {
primaryValues[idx], _ = field.ValueOf(fieldValue) primaryValues[idx], _ = field.ValueOf(association.DB.Statement.Context, fieldValue)
} }
if _, ok := relValuesMap[utils.ToStringKey(primaryValues...)]; ok { if _, ok := relValuesMap[utils.ToStringKey(primaryValues...)]; ok {
if association.Error = rel.Field.Set(data, reflect.Zero(rel.FieldSchema.ModelType).Interface()); association.Error != nil { if association.Error = rel.Field.Set(association.DB.Statement.Context, data, reflect.Zero(rel.FieldSchema.ModelType).Interface()); association.Error != nil {
break break
} }
if rel.JoinTable == nil { if rel.JoinTable == nil {
for _, ref := range rel.References { for _, ref := range rel.References {
if ref.OwnPrimaryKey || ref.PrimaryValue != "" { if ref.OwnPrimaryKey || ref.PrimaryValue != "" {
association.Error = ref.ForeignKey.Set(fieldValue, reflect.Zero(ref.ForeignKey.FieldType).Interface()) association.Error = ref.ForeignKey.Set(association.DB.Statement.Context, fieldValue, reflect.Zero(ref.ForeignKey.FieldType).Interface())
} else { } else {
association.Error = ref.ForeignKey.Set(data, reflect.Zero(ref.ForeignKey.FieldType).Interface()) association.Error = ref.ForeignKey.Set(association.DB.Statement.Context, data, reflect.Zero(ref.ForeignKey.FieldType).Interface())
} }
} }
} }
@ -329,14 +338,14 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
switch rv.Kind() { switch rv.Kind() {
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
if rv.Len() > 0 { if rv.Len() > 0 {
association.Error = association.Relationship.Field.Set(source, rv.Index(0).Addr().Interface()) association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, source, rv.Index(0).Addr().Interface())
if association.Relationship.Field.FieldType.Kind() == reflect.Struct { if association.Relationship.Field.FieldType.Kind() == reflect.Struct {
assignBacks = append(assignBacks, assignBack{Source: source, Dest: rv.Index(0)}) assignBacks = append(assignBacks, assignBack{Source: source, Dest: rv.Index(0)})
} }
} }
case reflect.Struct: case reflect.Struct:
association.Error = association.Relationship.Field.Set(source, rv.Addr().Interface()) association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, source, rv.Addr().Interface())
if association.Relationship.Field.FieldType.Kind() == reflect.Struct { if association.Relationship.Field.FieldType.Kind() == reflect.Struct {
assignBacks = append(assignBacks, assignBack{Source: source, Dest: rv}) assignBacks = append(assignBacks, assignBack{Source: source, Dest: rv})
@ -344,7 +353,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
} }
case schema.HasMany, schema.Many2Many: case schema.HasMany, schema.Many2Many:
elemType := association.Relationship.Field.IndirectFieldType.Elem() elemType := association.Relationship.Field.IndirectFieldType.Elem()
fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(source)) fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(association.DB.Statement.Context, source))
if clear { if clear {
fieldValue = reflect.New(association.Relationship.Field.IndirectFieldType).Elem() fieldValue = reflect.New(association.Relationship.Field.IndirectFieldType).Elem()
} }
@ -373,7 +382,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
} }
if association.Error == nil { if association.Error == nil {
association.Error = association.Relationship.Field.Set(source, fieldValue.Interface()) association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, source, fieldValue.Interface())
} }
} }
} }
@ -421,7 +430,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
// clear old data // clear old data
if clear && len(values) == 0 { if clear && len(values) == 0 {
for i := 0; i < reflectValue.Len(); i++ { for i := 0; i < reflectValue.Len(); i++ {
if err := association.Relationship.Field.Set(reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType).Interface()); err != nil { if err := association.Relationship.Field.Set(association.DB.Statement.Context, reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType).Interface()); err != nil {
association.Error = err association.Error = err
break break
} }
@ -429,7 +438,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
if association.Relationship.JoinTable == nil { if association.Relationship.JoinTable == nil {
for _, ref := range association.Relationship.References { for _, ref := range association.Relationship.References {
if !ref.OwnPrimaryKey && ref.PrimaryValue == "" { if !ref.OwnPrimaryKey && ref.PrimaryValue == "" {
if err := ref.ForeignKey.Set(reflectValue.Index(i), reflect.Zero(ref.ForeignKey.FieldType).Interface()); err != nil { if err := ref.ForeignKey.Set(association.DB.Statement.Context, reflectValue.Index(i), reflect.Zero(ref.ForeignKey.FieldType).Interface()); err != nil {
association.Error = err association.Error = err
break break
} }
@ -453,12 +462,12 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
case reflect.Struct: case reflect.Struct:
// clear old data // clear old data
if clear && len(values) == 0 { if clear && len(values) == 0 {
association.Error = association.Relationship.Field.Set(reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType).Interface()) association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType).Interface())
if association.Relationship.JoinTable == nil && association.Error == nil { if association.Relationship.JoinTable == nil && association.Error == nil {
for _, ref := range association.Relationship.References { for _, ref := range association.Relationship.References {
if !ref.OwnPrimaryKey && ref.PrimaryValue == "" { if !ref.OwnPrimaryKey && ref.PrimaryValue == "" {
association.Error = ref.ForeignKey.Set(reflectValue, reflect.Zero(ref.ForeignKey.FieldType).Interface()) association.Error = ref.ForeignKey.Set(association.DB.Statement.Context, reflectValue, reflect.Zero(ref.ForeignKey.FieldType).Interface())
} }
} }
} }
@ -475,7 +484,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
} }
for _, assignBack := range assignBacks { for _, assignBack := range assignBacks {
fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(assignBack.Source)) fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(association.DB.Statement.Context, assignBack.Source))
if assignBack.Index > 0 { if assignBack.Index > 0 {
reflect.Indirect(assignBack.Dest).Set(fieldValue.Index(assignBack.Index - 1)) reflect.Indirect(assignBack.Dest).Set(fieldValue.Index(assignBack.Index - 1))
} else { } else {
@ -486,14 +495,14 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
func (association *Association) buildCondition() *DB { func (association *Association) buildCondition() *DB {
var ( var (
queryConds = association.Relationship.ToQueryConditions(association.DB.Statement.ReflectValue) queryConds = association.Relationship.ToQueryConditions(association.DB.Statement.Context, association.DB.Statement.ReflectValue)
modelValue = reflect.New(association.Relationship.FieldSchema.ModelType).Interface() modelValue = reflect.New(association.Relationship.FieldSchema.ModelType).Interface()
tx = association.DB.Model(modelValue) tx = association.DB.Model(modelValue)
) )
if association.Relationship.JoinTable != nil { if association.Relationship.JoinTable != nil {
if !tx.Statement.Unscoped && len(association.Relationship.JoinTable.QueryClauses) > 0 { if !tx.Statement.Unscoped && len(association.Relationship.JoinTable.QueryClauses) > 0 {
joinStmt := Statement{DB: tx, Schema: association.Relationship.JoinTable, Table: association.Relationship.JoinTable.Table, Clauses: map[string]clause.Clause{}} joinStmt := Statement{DB: tx, Context: tx.Statement.Context, Schema: association.Relationship.JoinTable, Table: association.Relationship.JoinTable.Table, Clauses: map[string]clause.Clause{}}
for _, queryClause := range association.Relationship.JoinTable.QueryClauses { for _, queryClause := range association.Relationship.JoinTable.QueryClauses {
joinStmt.AddClause(queryClause) joinStmt.AddClause(queryClause)
} }

View File

@ -246,7 +246,13 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {
sortCallback func(*callback) error sortCallback func(*callback) error
) )
sort.Slice(cs, func(i, j int) bool { sort.Slice(cs, func(i, j int) bool {
return cs[j].before == "*" || cs[j].after == "*" if cs[j].before == "*" && cs[i].before != "*" {
return true
}
if cs[j].after == "*" && cs[i].after != "*" {
return true
}
return false
}) })
for _, c := range cs { for _, c := range cs {

View File

@ -24,8 +24,8 @@ func SaveBeforeAssociations(create bool) func(db *gorm.DB) {
setupReferences := func(obj reflect.Value, elem reflect.Value) { setupReferences := func(obj reflect.Value, elem reflect.Value) {
for _, ref := range rel.References { for _, ref := range rel.References {
if !ref.OwnPrimaryKey { if !ref.OwnPrimaryKey {
pv, _ := ref.PrimaryKey.ValueOf(elem) pv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, elem)
db.AddError(ref.ForeignKey.Set(obj, pv)) db.AddError(ref.ForeignKey.Set(db.Statement.Context, obj, pv))
if dest, ok := db.Statement.Dest.(map[string]interface{}); ok { if dest, ok := db.Statement.Dest.(map[string]interface{}); ok {
dest[ref.ForeignKey.DBName] = pv dest[ref.ForeignKey.DBName] = pv
@ -57,8 +57,8 @@ func SaveBeforeAssociations(create bool) func(db *gorm.DB) {
break break
} }
if _, zero := rel.Field.ValueOf(obj); !zero { // check belongs to relation value if _, zero := rel.Field.ValueOf(db.Statement.Context, obj); !zero { // check belongs to relation value
rv := rel.Field.ReflectValueOf(obj) // relation reflect value rv := rel.Field.ReflectValueOf(db.Statement.Context, obj) // relation reflect value
objs = append(objs, obj) objs = append(objs, obj)
if isPtr { if isPtr {
elems = reflect.Append(elems, rv) elems = reflect.Append(elems, rv)
@ -69,20 +69,20 @@ func SaveBeforeAssociations(create bool) func(db *gorm.DB) {
} }
if elems.Len() > 0 { if elems.Len() > 0 {
if saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, nil) == nil { if saveAssociations(db, rel, elems, selectColumns, restricted, nil) == nil {
for i := 0; i < elems.Len(); i++ { for i := 0; i < elems.Len(); i++ {
setupReferences(objs[i], elems.Index(i)) setupReferences(objs[i], elems.Index(i))
} }
} }
} }
case reflect.Struct: case reflect.Struct:
if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero { if _, zero := rel.Field.ValueOf(db.Statement.Context, db.Statement.ReflectValue); !zero {
rv := rel.Field.ReflectValueOf(db.Statement.ReflectValue) // relation reflect value rv := rel.Field.ReflectValueOf(db.Statement.Context, db.Statement.ReflectValue) // relation reflect value
if rv.Kind() != reflect.Ptr { if rv.Kind() != reflect.Ptr {
rv = rv.Addr() rv = rv.Addr()
} }
if saveAssociations(db, rel, rv.Interface(), selectColumns, restricted, nil) == nil { if saveAssociations(db, rel, rv, selectColumns, restricted, nil) == nil {
setupReferences(db.Statement.ReflectValue, rv) setupReferences(db.Statement.ReflectValue, rv)
} }
} }
@ -120,18 +120,18 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
obj := db.Statement.ReflectValue.Index(i) obj := db.Statement.ReflectValue.Index(i)
if reflect.Indirect(obj).Kind() == reflect.Struct { if reflect.Indirect(obj).Kind() == reflect.Struct {
if _, zero := rel.Field.ValueOf(obj); !zero { if _, zero := rel.Field.ValueOf(db.Statement.Context, obj); !zero {
rv := rel.Field.ReflectValueOf(obj) rv := rel.Field.ReflectValueOf(db.Statement.Context, obj)
if rv.Kind() != reflect.Ptr { if rv.Kind() != reflect.Ptr {
rv = rv.Addr() rv = rv.Addr()
} }
for _, ref := range rel.References { for _, ref := range rel.References {
if ref.OwnPrimaryKey { if ref.OwnPrimaryKey {
fv, _ := ref.PrimaryKey.ValueOf(obj) fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, obj)
db.AddError(ref.ForeignKey.Set(rv, fv)) db.AddError(ref.ForeignKey.Set(db.Statement.Context, rv, fv))
} else if ref.PrimaryValue != "" { } else if ref.PrimaryValue != "" {
db.AddError(ref.ForeignKey.Set(rv, ref.PrimaryValue)) db.AddError(ref.ForeignKey.Set(db.Statement.Context, rv, ref.PrimaryValue))
} }
} }
@ -146,11 +146,11 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
} }
saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, assignmentColumns) saveAssociations(db, rel, elems, selectColumns, restricted, assignmentColumns)
} }
case reflect.Struct: case reflect.Struct:
if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero { if _, zero := rel.Field.ValueOf(db.Statement.Context, db.Statement.ReflectValue); !zero {
f := rel.Field.ReflectValueOf(db.Statement.ReflectValue) f := rel.Field.ReflectValueOf(db.Statement.Context, db.Statement.ReflectValue)
if f.Kind() != reflect.Ptr { if f.Kind() != reflect.Ptr {
f = f.Addr() f = f.Addr()
} }
@ -158,15 +158,15 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
assignmentColumns := make([]string, 0, len(rel.References)) assignmentColumns := make([]string, 0, len(rel.References))
for _, ref := range rel.References { for _, ref := range rel.References {
if ref.OwnPrimaryKey { if ref.OwnPrimaryKey {
fv, _ := ref.PrimaryKey.ValueOf(db.Statement.ReflectValue) fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, db.Statement.ReflectValue)
ref.ForeignKey.Set(f, fv) db.AddError(ref.ForeignKey.Set(db.Statement.Context, f, fv))
} else if ref.PrimaryValue != "" { } else if ref.PrimaryValue != "" {
ref.ForeignKey.Set(f, ref.PrimaryValue) db.AddError(ref.ForeignKey.Set(db.Statement.Context, f, ref.PrimaryValue))
} }
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
} }
saveAssociations(db, rel, f.Interface(), selectColumns, restricted, assignmentColumns) saveAssociations(db, rel, f, selectColumns, restricted, assignmentColumns)
} }
} }
} }
@ -185,23 +185,23 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
identityMap := map[string]bool{} identityMap := map[string]bool{}
appendToElems := func(v reflect.Value) { appendToElems := func(v reflect.Value) {
if _, zero := rel.Field.ValueOf(v); !zero { if _, zero := rel.Field.ValueOf(db.Statement.Context, v); !zero {
f := reflect.Indirect(rel.Field.ReflectValueOf(v)) f := reflect.Indirect(rel.Field.ReflectValueOf(db.Statement.Context, v))
for i := 0; i < f.Len(); i++ { for i := 0; i < f.Len(); i++ {
elem := f.Index(i) elem := f.Index(i)
for _, ref := range rel.References { for _, ref := range rel.References {
if ref.OwnPrimaryKey { if ref.OwnPrimaryKey {
pv, _ := ref.PrimaryKey.ValueOf(v) pv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, v)
ref.ForeignKey.Set(elem, pv) db.AddError(ref.ForeignKey.Set(db.Statement.Context, elem, pv))
} else if ref.PrimaryValue != "" { } else if ref.PrimaryValue != "" {
ref.ForeignKey.Set(elem, ref.PrimaryValue) db.AddError(ref.ForeignKey.Set(db.Statement.Context, elem, ref.PrimaryValue))
} }
} }
relPrimaryValues := make([]interface{}, 0, len(rel.FieldSchema.PrimaryFields)) relPrimaryValues := make([]interface{}, 0, len(rel.FieldSchema.PrimaryFields))
for _, pf := range rel.FieldSchema.PrimaryFields { for _, pf := range rel.FieldSchema.PrimaryFields {
if pfv, ok := pf.ValueOf(elem); !ok { if pfv, ok := pf.ValueOf(db.Statement.Context, elem); !ok {
relPrimaryValues = append(relPrimaryValues, pfv) relPrimaryValues = append(relPrimaryValues, pfv)
} }
} }
@ -237,7 +237,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
} }
saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, assignmentColumns) saveAssociations(db, rel, elems, selectColumns, restricted, assignmentColumns)
} }
} }
@ -260,21 +260,21 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
joinValue := reflect.New(rel.JoinTable.ModelType) joinValue := reflect.New(rel.JoinTable.ModelType)
for _, ref := range rel.References { for _, ref := range rel.References {
if ref.OwnPrimaryKey { if ref.OwnPrimaryKey {
fv, _ := ref.PrimaryKey.ValueOf(obj) fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, obj)
ref.ForeignKey.Set(joinValue, fv) db.AddError(ref.ForeignKey.Set(db.Statement.Context, joinValue, fv))
} else if ref.PrimaryValue != "" { } else if ref.PrimaryValue != "" {
ref.ForeignKey.Set(joinValue, ref.PrimaryValue) db.AddError(ref.ForeignKey.Set(db.Statement.Context, joinValue, ref.PrimaryValue))
} else { } else {
fv, _ := ref.PrimaryKey.ValueOf(elem) fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, elem)
ref.ForeignKey.Set(joinValue, fv) db.AddError(ref.ForeignKey.Set(db.Statement.Context, joinValue, fv))
} }
} }
joins = reflect.Append(joins, joinValue) joins = reflect.Append(joins, joinValue)
} }
appendToElems := func(v reflect.Value) { appendToElems := func(v reflect.Value) {
if _, zero := rel.Field.ValueOf(v); !zero { if _, zero := rel.Field.ValueOf(db.Statement.Context, v); !zero {
f := reflect.Indirect(rel.Field.ReflectValueOf(v)) f := reflect.Indirect(rel.Field.ReflectValueOf(db.Statement.Context, v))
for i := 0; i < f.Len(); i++ { for i := 0; i < f.Len(); i++ {
elem := f.Index(i) elem := f.Index(i)
@ -304,7 +304,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
// optimize elems of reflect value length // optimize elems of reflect value length
if elemLen := elems.Len(); elemLen > 0 { if elemLen := elems.Len(); elemLen > 0 {
if v, ok := selectColumns[rel.Name+".*"]; !ok || v { if v, ok := selectColumns[rel.Name+".*"]; !ok || v {
saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, nil) saveAssociations(db, rel, elems, selectColumns, restricted, nil)
} }
for i := 0; i < elemLen; i++ { for i := 0; i < elemLen; i++ {
@ -323,7 +323,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
} }
} }
func onConflictOption(stmt *gorm.Statement, s *schema.Schema, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) (onConflict clause.OnConflict) { func onConflictOption(stmt *gorm.Statement, s *schema.Schema, defaultUpdatingColumns []string) (onConflict clause.OnConflict) {
if len(defaultUpdatingColumns) > 0 || stmt.DB.FullSaveAssociations { if len(defaultUpdatingColumns) > 0 || stmt.DB.FullSaveAssociations {
onConflict.Columns = make([]clause.Column, 0, len(s.PrimaryFieldDBNames)) onConflict.Columns = make([]clause.Column, 0, len(s.PrimaryFieldDBNames))
for _, dbName := range s.PrimaryFieldDBNames { for _, dbName := range s.PrimaryFieldDBNames {
@ -341,11 +341,17 @@ func onConflictOption(stmt *gorm.Statement, s *schema.Schema, selectColumns map[
return return
} }
func saveAssociations(db *gorm.DB, rel *schema.Relationship, values interface{}, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) error { func saveAssociations(db *gorm.DB, rel *schema.Relationship, rValues reflect.Value, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) error {
// stop save association loop
if checkAssociationsSaved(db, rValues) {
return nil
}
var ( var (
selects, omits []string selects, omits []string
onConflict = onConflictOption(db.Statement, rel.FieldSchema, selectColumns, restricted, defaultUpdatingColumns) onConflict = onConflictOption(db.Statement, rel.FieldSchema, defaultUpdatingColumns)
refName = rel.Name + "." refName = rel.Name + "."
values = rValues.Interface()
) )
for name, ok := range selectColumns { for name, ok := range selectColumns {
@ -390,3 +396,24 @@ func saveAssociations(db *gorm.DB, rel *schema.Relationship, values interface{},
return db.AddError(tx.Create(values).Error) return db.AddError(tx.Create(values).Error)
} }
// check association values has been saved
// if values kind is Struct, check it has been saved
// if values kind is Slice/Array, check all items have been saved
var visitMapStoreKey = "gorm:saved_association_map"
func checkAssociationsSaved(db *gorm.DB, values reflect.Value) bool {
if visit, ok := db.Get(visitMapStoreKey); ok {
if v, ok := visit.(*visitMap); ok {
if loadOrStoreVisitMap(v, values) {
return true
}
}
} else {
vistMap := make(visitMap)
loadOrStoreVisitMap(&vistMap, values)
db.Set(visitMapStoreKey, &vistMap)
}
return false
}

View File

@ -10,6 +10,7 @@ import (
"gorm.io/gorm/utils" "gorm.io/gorm/utils"
) )
// BeforeCreate before create hooks
func BeforeCreate(db *gorm.DB) { func BeforeCreate(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) { if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) {
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
@ -31,6 +32,7 @@ func BeforeCreate(db *gorm.DB) {
} }
} }
// Create create hook
func Create(config *Config) func(db *gorm.DB) { func Create(config *Config) func(db *gorm.DB) {
supportReturning := utils.Contains(config.CreateClauses, "RETURNING") supportReturning := utils.Contains(config.CreateClauses, "RETURNING")
@ -82,8 +84,10 @@ func Create(config *Config) func(db *gorm.DB) {
db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars..., db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...,
) )
if db.AddError(err) == nil { if db.AddError(err) == nil {
defer func() {
db.AddError(rows.Close())
}()
gorm.Scan(rows, db, mode) gorm.Scan(rows, db, mode)
db.AddError(rows.Close())
} }
return return
@ -117,9 +121,9 @@ func Create(config *Config) func(db *gorm.DB) {
break break
} }
_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv) _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv)
if isZero { if isZero {
db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID) db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, insertID))
insertID -= db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement insertID -= db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement
} }
} }
@ -130,38 +134,39 @@ func Create(config *Config) func(db *gorm.DB) {
break break
} }
if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv); isZero { if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv); isZero {
db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID) db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, insertID))
insertID += db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement insertID += db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement
} }
} }
} }
case reflect.Struct: case reflect.Struct:
_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue) _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, db.Statement.ReflectValue)
if isZero { if isZero {
db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID) db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, db.Statement.ReflectValue, insertID))
} }
} }
} }
} }
} }
// AfterCreate after create hooks
func AfterCreate(db *gorm.DB) { func AfterCreate(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) { if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) {
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
if db.Statement.Schema.AfterSave {
if i, ok := value.(AfterSaveInterface); ok {
called = true
db.AddError(i.AfterSave(tx))
}
}
if db.Statement.Schema.AfterCreate { if db.Statement.Schema.AfterCreate {
if i, ok := value.(AfterCreateInterface); ok { if i, ok := value.(AfterCreateInterface); ok {
called = true called = true
db.AddError(i.AfterCreate(tx)) db.AddError(i.AfterCreate(tx))
} }
} }
if db.Statement.Schema.AfterSave {
if i, ok := value.(AfterSaveInterface); ok {
called = true
db.AddError(i.AfterSave(tx))
}
}
return called return called
}) })
} }
@ -201,13 +206,15 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
switch stmt.ReflectValue.Kind() { switch stmt.ReflectValue.Kind() {
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
rValLen := stmt.ReflectValue.Len() rValLen := stmt.ReflectValue.Len()
stmt.SQL.Grow(rValLen * 18)
values.Values = make([][]interface{}, rValLen)
if rValLen == 0 { if rValLen == 0 {
stmt.AddError(gorm.ErrEmptySlice) stmt.AddError(gorm.ErrEmptySlice)
return return
} }
stmt.SQL.Grow(rValLen * 18)
stmt.Vars = make([]interface{}, 0, rValLen*len(values.Columns))
values.Values = make([][]interface{}, rValLen)
defaultValueFieldsHavingValue := map[*schema.Field][]interface{}{} defaultValueFieldsHavingValue := map[*schema.Field][]interface{}{}
for i := 0; i < rValLen; i++ { for i := 0; i < rValLen; i++ {
rv := reflect.Indirect(stmt.ReflectValue.Index(i)) rv := reflect.Indirect(stmt.ReflectValue.Index(i))
@ -219,23 +226,23 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
values.Values[i] = make([]interface{}, len(values.Columns)) values.Values[i] = make([]interface{}, len(values.Columns))
for idx, column := range values.Columns { for idx, column := range values.Columns {
field := stmt.Schema.FieldsByDBName[column.Name] field := stmt.Schema.FieldsByDBName[column.Name]
if values.Values[i][idx], isZero = field.ValueOf(rv); isZero { if values.Values[i][idx], isZero = field.ValueOf(stmt.Context, rv); isZero {
if field.DefaultValueInterface != nil { if field.DefaultValueInterface != nil {
values.Values[i][idx] = field.DefaultValueInterface values.Values[i][idx] = field.DefaultValueInterface
field.Set(rv, field.DefaultValueInterface) stmt.AddError(field.Set(stmt.Context, rv, field.DefaultValueInterface))
} else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 { } else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 {
field.Set(rv, curTime) stmt.AddError(field.Set(stmt.Context, rv, curTime))
values.Values[i][idx], _ = field.ValueOf(rv) values.Values[i][idx], _ = field.ValueOf(stmt.Context, rv)
} }
} else if field.AutoUpdateTime > 0 && updateTrackTime { } else if field.AutoUpdateTime > 0 && updateTrackTime {
field.Set(rv, curTime) stmt.AddError(field.Set(stmt.Context, rv, curTime))
values.Values[i][idx], _ = field.ValueOf(rv) values.Values[i][idx], _ = field.ValueOf(stmt.Context, rv)
} }
} }
for _, field := range stmt.Schema.FieldsWithDefaultDBValue { for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
if rvOfvalue, isZero := field.ValueOf(rv); !isZero { if rvOfvalue, isZero := field.ValueOf(stmt.Context, rv); !isZero {
if len(defaultValueFieldsHavingValue[field]) == 0 { if len(defaultValueFieldsHavingValue[field]) == 0 {
defaultValueFieldsHavingValue[field] = make([]interface{}, rValLen) defaultValueFieldsHavingValue[field] = make([]interface{}, rValLen)
} }
@ -259,23 +266,23 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
values.Values = [][]interface{}{make([]interface{}, len(values.Columns))} values.Values = [][]interface{}{make([]interface{}, len(values.Columns))}
for idx, column := range values.Columns { for idx, column := range values.Columns {
field := stmt.Schema.FieldsByDBName[column.Name] field := stmt.Schema.FieldsByDBName[column.Name]
if values.Values[0][idx], isZero = field.ValueOf(stmt.ReflectValue); isZero { if values.Values[0][idx], isZero = field.ValueOf(stmt.Context, stmt.ReflectValue); isZero {
if field.DefaultValueInterface != nil { if field.DefaultValueInterface != nil {
values.Values[0][idx] = field.DefaultValueInterface values.Values[0][idx] = field.DefaultValueInterface
field.Set(stmt.ReflectValue, field.DefaultValueInterface) stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, field.DefaultValueInterface))
} else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 { } else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 {
field.Set(stmt.ReflectValue, curTime) stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, curTime))
values.Values[0][idx], _ = field.ValueOf(stmt.ReflectValue) values.Values[0][idx], _ = field.ValueOf(stmt.Context, stmt.ReflectValue)
} }
} else if field.AutoUpdateTime > 0 && updateTrackTime { } else if field.AutoUpdateTime > 0 && updateTrackTime {
field.Set(stmt.ReflectValue, curTime) stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, curTime))
values.Values[0][idx], _ = field.ValueOf(stmt.ReflectValue) values.Values[0][idx], _ = field.ValueOf(stmt.Context, stmt.ReflectValue)
} }
} }
for _, field := range stmt.Schema.FieldsWithDefaultDBValue { for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
if rvOfvalue, isZero := field.ValueOf(stmt.ReflectValue); !isZero { if rvOfvalue, isZero := field.ValueOf(stmt.Context, stmt.ReflectValue); !isZero {
values.Columns = append(values.Columns, clause.Column{Name: field.DBName}) values.Columns = append(values.Columns, clause.Column{Name: field.DBName})
values.Values[0] = append(values.Values[0], rvOfvalue) values.Values[0] = append(values.Values[0], rvOfvalue)
} }

View File

@ -42,7 +42,7 @@ func DeleteBeforeAssociations(db *gorm.DB) {
switch rel.Type { switch rel.Type {
case schema.HasOne, schema.HasMany: case schema.HasOne, schema.HasMany:
queryConds := rel.ToQueryConditions(db.Statement.ReflectValue) queryConds := rel.ToQueryConditions(db.Statement.Context, db.Statement.ReflectValue)
modelValue := reflect.New(rel.FieldSchema.ModelType).Interface() modelValue := reflect.New(rel.FieldSchema.ModelType).Interface()
tx := db.Session(&gorm.Session{NewDB: true}).Model(modelValue) tx := db.Session(&gorm.Session{NewDB: true}).Model(modelValue)
withoutConditions := false withoutConditions := false
@ -97,7 +97,7 @@ func DeleteBeforeAssociations(db *gorm.DB) {
} }
} }
_, foreignValues := schema.GetIdentityFieldValuesMap(db.Statement.ReflectValue, foreignFields) _, foreignValues := schema.GetIdentityFieldValuesMap(db.Statement.Context, db.Statement.ReflectValue, foreignFields)
column, values := schema.ToQueryValues(table, relForeignKeys, foreignValues) column, values := schema.ToQueryValues(table, relForeignKeys, foreignValues)
queryConds = append(queryConds, clause.IN{Column: column, Values: values}) queryConds = append(queryConds, clause.IN{Column: column, Values: values})
@ -118,12 +118,18 @@ func Delete(config *Config) func(db *gorm.DB) {
return return
} }
if db.Statement.Schema != nil {
for _, c := range db.Statement.Schema.DeleteClauses {
db.Statement.AddClause(c)
}
}
if db.Statement.SQL.Len() == 0 { if db.Statement.SQL.Len() == 0 {
db.Statement.SQL.Grow(100) db.Statement.SQL.Grow(100)
db.Statement.AddClauseIfNotExists(clause.Delete{}) db.Statement.AddClauseIfNotExists(clause.Delete{})
if db.Statement.Schema != nil { if db.Statement.Schema != nil {
_, queryValues := schema.GetIdentityFieldValuesMap(db.Statement.ReflectValue, db.Statement.Schema.PrimaryFields) _, queryValues := schema.GetIdentityFieldValuesMap(db.Statement.Context, db.Statement.ReflectValue, db.Statement.Schema.PrimaryFields)
column, values := schema.ToQueryValues(db.Statement.Table, db.Statement.Schema.PrimaryFieldDBNames, queryValues) column, values := schema.ToQueryValues(db.Statement.Table, db.Statement.Schema.PrimaryFieldDBNames, queryValues)
if len(values) > 0 { if len(values) > 0 {
@ -131,7 +137,7 @@ func Delete(config *Config) func(db *gorm.DB) {
} }
if db.Statement.ReflectValue.CanAddr() && db.Statement.Dest != db.Statement.Model && db.Statement.Model != nil { if db.Statement.ReflectValue.CanAddr() && db.Statement.Dest != db.Statement.Model && db.Statement.Model != nil {
_, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(db.Statement.Model), db.Statement.Schema.PrimaryFields) _, queryValues = schema.GetIdentityFieldValuesMap(db.Statement.Context, reflect.ValueOf(db.Statement.Model), db.Statement.Schema.PrimaryFields)
column, values = schema.ToQueryValues(db.Statement.Table, db.Statement.Schema.PrimaryFieldDBNames, queryValues) column, values = schema.ToQueryValues(db.Statement.Table, db.Statement.Schema.PrimaryFieldDBNames, queryValues)
if len(values) > 0 { if len(values) > 0 {
@ -141,22 +147,11 @@ func Delete(config *Config) func(db *gorm.DB) {
} }
db.Statement.AddClauseIfNotExists(clause.From{}) db.Statement.AddClauseIfNotExists(clause.From{})
}
if db.Statement.Schema != nil {
for _, c := range db.Statement.Schema.DeleteClauses {
db.Statement.AddClause(c)
}
}
if db.Statement.SQL.Len() == 0 {
db.Statement.Build(db.Statement.BuildClauses...) db.Statement.Build(db.Statement.BuildClauses...)
} }
if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok && db.Error == nil { checkMissingWhereConditions(db)
db.AddError(gorm.ErrMissingWhereClause)
return
}
if !db.DryRun && db.Error == nil { if !db.DryRun && db.Error == nil {
ok, mode := hasReturning(db, supportReturning) ok, mode := hasReturning(db, supportReturning)

View File

@ -1,6 +1,7 @@
package callbacks package callbacks
import ( import (
"reflect"
"sort" "sort"
"gorm.io/gorm" "gorm.io/gorm"
@ -104,3 +105,48 @@ func hasReturning(tx *gorm.DB, supportReturning bool) (bool, gorm.ScanMode) {
} }
return false, 0 return false, 0
} }
func checkMissingWhereConditions(db *gorm.DB) {
if !db.AllowGlobalUpdate && db.Error == nil {
where, withCondition := db.Statement.Clauses["WHERE"]
if withCondition {
if _, withSoftDelete := db.Statement.Clauses["soft_delete_enabled"]; withSoftDelete {
whereClause, _ := where.Expression.(clause.Where)
withCondition = len(whereClause.Exprs) > 1
}
}
if !withCondition {
db.AddError(gorm.ErrMissingWhereClause)
}
return
}
}
type visitMap = map[reflect.Value]bool
// Check if circular values, return true if loaded
func loadOrStoreVisitMap(visitMap *visitMap, v reflect.Value) (loaded bool) {
if v.Kind() == reflect.Ptr {
v = v.Elem()
}
switch v.Kind() {
case reflect.Slice, reflect.Array:
loaded = true
for i := 0; i < v.Len(); i++ {
if !loadOrStoreVisitMap(visitMap, v.Index(i)) {
loaded = false
}
}
case reflect.Struct, reflect.Interface:
if v.CanAddr() {
p := v.Addr()
if _, ok := (*visitMap)[p]; ok {
return true
}
(*visitMap)[p] = true
}
}
return
}

View File

@ -10,10 +10,9 @@ import (
"gorm.io/gorm/utils" "gorm.io/gorm/utils"
) )
func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) { func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) error {
var ( var (
reflectValue = db.Statement.ReflectValue reflectValue = tx.Statement.ReflectValue
tx = db.Session(&gorm.Session{NewDB: true}).Model(nil).Session(&gorm.Session{SkipHooks: db.Statement.SkipHooks})
relForeignKeys []string relForeignKeys []string
relForeignFields []*schema.Field relForeignFields []*schema.Field
foreignFields []*schema.Field foreignFields []*schema.Field
@ -22,11 +21,6 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
inlineConds []interface{} inlineConds []interface{}
) )
db.Statement.Settings.Range(func(k, v interface{}) bool {
tx.Statement.Settings.Store(k, v)
return true
})
if rel.JoinTable != nil { if rel.JoinTable != nil {
var ( var (
joinForeignFields = make([]*schema.Field, 0, len(rel.References)) joinForeignFields = make([]*schema.Field, 0, len(rel.References))
@ -48,14 +42,16 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
} }
} }
joinIdentityMap, joinForeignValues := schema.GetIdentityFieldValuesMap(reflectValue, foreignFields) joinIdentityMap, joinForeignValues := schema.GetIdentityFieldValuesMap(tx.Statement.Context, reflectValue, foreignFields)
if len(joinForeignValues) == 0 { if len(joinForeignValues) == 0 {
return return nil
} }
joinResults := rel.JoinTable.MakeSlice().Elem() joinResults := rel.JoinTable.MakeSlice().Elem()
column, values := schema.ToQueryValues(clause.CurrentTable, joinForeignKeys, joinForeignValues) column, values := schema.ToQueryValues(clause.CurrentTable, joinForeignKeys, joinForeignValues)
db.AddError(tx.Where(clause.IN{Column: column, Values: values}).Find(joinResults.Addr().Interface()).Error) if err := tx.Where(clause.IN{Column: column, Values: values}).Find(joinResults.Addr().Interface()).Error; err != nil {
return err
}
// convert join identity map to relation identity map // convert join identity map to relation identity map
fieldValues := make([]interface{}, len(joinForeignFields)) fieldValues := make([]interface{}, len(joinForeignFields))
@ -63,11 +59,11 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
for i := 0; i < joinResults.Len(); i++ { for i := 0; i < joinResults.Len(); i++ {
joinIndexValue := joinResults.Index(i) joinIndexValue := joinResults.Index(i)
for idx, field := range joinForeignFields { for idx, field := range joinForeignFields {
fieldValues[idx], _ = field.ValueOf(joinIndexValue) fieldValues[idx], _ = field.ValueOf(tx.Statement.Context, joinIndexValue)
} }
for idx, field := range joinRelForeignFields { for idx, field := range joinRelForeignFields {
joinFieldValues[idx], _ = field.ValueOf(joinIndexValue) joinFieldValues[idx], _ = field.ValueOf(tx.Statement.Context, joinIndexValue)
} }
if results, ok := joinIdentityMap[utils.ToStringKey(fieldValues...)]; ok { if results, ok := joinIdentityMap[utils.ToStringKey(fieldValues...)]; ok {
@ -76,7 +72,7 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
} }
} }
_, foreignValues = schema.GetIdentityFieldValuesMap(joinResults, joinRelForeignFields) _, foreignValues = schema.GetIdentityFieldValuesMap(tx.Statement.Context, joinResults, joinRelForeignFields)
} else { } else {
for _, ref := range rel.References { for _, ref := range rel.References {
if ref.OwnPrimaryKey { if ref.OwnPrimaryKey {
@ -92,9 +88,9 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
} }
} }
identityMap, foreignValues = schema.GetIdentityFieldValuesMap(reflectValue, foreignFields) identityMap, foreignValues = schema.GetIdentityFieldValuesMap(tx.Statement.Context, reflectValue, foreignFields)
if len(foreignValues) == 0 { if len(foreignValues) == 0 {
return return nil
} }
} }
@ -115,7 +111,9 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
} }
} }
db.AddError(tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), inlineConds...).Error) if err := tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), inlineConds...).Error; err != nil {
return err
}
} }
fieldValues := make([]interface{}, len(relForeignFields)) fieldValues := make([]interface{}, len(relForeignFields))
@ -125,17 +123,17 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
case reflect.Struct: case reflect.Struct:
switch rel.Type { switch rel.Type {
case schema.HasMany, schema.Many2Many: case schema.HasMany, schema.Many2Many:
rel.Field.Set(reflectValue, reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface()) tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue, reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface()))
default: default:
rel.Field.Set(reflectValue, reflect.New(rel.Field.FieldType).Interface()) tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue, reflect.New(rel.Field.FieldType).Interface()))
} }
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
for i := 0; i < reflectValue.Len(); i++ { for i := 0; i < reflectValue.Len(); i++ {
switch rel.Type { switch rel.Type {
case schema.HasMany, schema.Many2Many: case schema.HasMany, schema.Many2Many:
rel.Field.Set(reflectValue.Index(i), reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface()) tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue.Index(i), reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface()))
default: default:
rel.Field.Set(reflectValue.Index(i), reflect.New(rel.Field.FieldType).Interface()) tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue.Index(i), reflect.New(rel.Field.FieldType).Interface()))
} }
} }
} }
@ -143,18 +141,16 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
for i := 0; i < reflectResults.Len(); i++ { for i := 0; i < reflectResults.Len(); i++ {
elem := reflectResults.Index(i) elem := reflectResults.Index(i)
for idx, field := range relForeignFields { for idx, field := range relForeignFields {
fieldValues[idx], _ = field.ValueOf(elem) fieldValues[idx], _ = field.ValueOf(tx.Statement.Context, elem)
} }
datas, ok := identityMap[utils.ToStringKey(fieldValues...)] datas, ok := identityMap[utils.ToStringKey(fieldValues...)]
if !ok { if !ok {
db.AddError(fmt.Errorf("failed to assign association %#v, make sure foreign fields exists", return fmt.Errorf("failed to assign association %#v, make sure foreign fields exists", elem.Interface())
elem.Interface()))
continue
} }
for _, data := range datas { for _, data := range datas {
reflectFieldValue := rel.Field.ReflectValueOf(data) reflectFieldValue := rel.Field.ReflectValueOf(tx.Statement.Context, data)
if reflectFieldValue.Kind() == reflect.Ptr && reflectFieldValue.IsNil() { if reflectFieldValue.Kind() == reflect.Ptr && reflectFieldValue.IsNil() {
reflectFieldValue.Set(reflect.New(rel.Field.FieldType.Elem())) reflectFieldValue.Set(reflect.New(rel.Field.FieldType.Elem()))
} }
@ -162,14 +158,16 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
reflectFieldValue = reflect.Indirect(reflectFieldValue) reflectFieldValue = reflect.Indirect(reflectFieldValue)
switch reflectFieldValue.Kind() { switch reflectFieldValue.Kind() {
case reflect.Struct: case reflect.Struct:
rel.Field.Set(data, elem.Interface()) tx.AddError(rel.Field.Set(tx.Statement.Context, data, elem.Interface()))
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
if reflectFieldValue.Type().Elem().Kind() == reflect.Ptr { if reflectFieldValue.Type().Elem().Kind() == reflect.Ptr {
rel.Field.Set(data, reflect.Append(reflectFieldValue, elem).Interface()) tx.AddError(rel.Field.Set(tx.Statement.Context, data, reflect.Append(reflectFieldValue, elem).Interface()))
} else { } else {
rel.Field.Set(data, reflect.Append(reflectFieldValue, elem.Elem()).Interface()) tx.AddError(rel.Field.Set(tx.Statement.Context, data, reflect.Append(reflectFieldValue, elem.Elem()).Interface()))
} }
} }
} }
} }
return tx.Error
} }

View File

@ -20,8 +20,10 @@ func Query(db *gorm.DB) {
db.AddError(err) db.AddError(err)
return return
} }
defer func() {
db.AddError(rows.Close())
}()
gorm.Scan(rows, db, 0) gorm.Scan(rows, db, 0)
db.AddError(rows.Close())
} }
} }
} }
@ -40,7 +42,7 @@ func BuildQuerySQL(db *gorm.DB) {
if db.Statement.ReflectValue.Kind() == reflect.Struct && db.Statement.ReflectValue.Type() == db.Statement.Schema.ModelType { if db.Statement.ReflectValue.Kind() == reflect.Struct && db.Statement.ReflectValue.Type() == db.Statement.Schema.ModelType {
var conds []clause.Expression var conds []clause.Expression
for _, primaryField := range db.Statement.Schema.PrimaryFields { for _, primaryField := range db.Statement.Schema.PrimaryFields {
if v, isZero := primaryField.ValueOf(db.Statement.ReflectValue); !isZero { if v, isZero := primaryField.ValueOf(db.Statement.Context, db.Statement.ReflectValue); !isZero {
conds = append(conds, clause.Eq{Column: clause.Column{Table: db.Statement.Table, Name: primaryField.DBName}, Value: v}) conds = append(conds, clause.Eq{Column: clause.Column{Table: db.Statement.Table, Name: primaryField.DBName}, Value: v})
} }
} }
@ -94,13 +96,13 @@ func BuildQuerySQL(db *gorm.DB) {
} }
// inline joins // inline joins
joins := []clause.Join{} fromClause := clause.From{}
if fromClause, ok := db.Statement.Clauses["FROM"].Expression.(clause.From); ok { if v, ok := db.Statement.Clauses["FROM"].Expression.(clause.From); ok {
joins = fromClause.Joins fromClause = v
} }
if len(db.Statement.Joins) != 0 || len(joins) != 0 { if len(db.Statement.Joins) != 0 || len(fromClause.Joins) != 0 {
if len(db.Statement.Selects) == 0 && db.Statement.Schema != nil { if len(db.Statement.Selects) == 0 && len(db.Statement.Omits) == 0 && db.Statement.Schema != nil {
clauseSelect.Columns = make([]clause.Column, len(db.Statement.Schema.DBNames)) clauseSelect.Columns = make([]clause.Column, len(db.Statement.Schema.DBNames))
for idx, dbName := range db.Statement.Schema.DBNames { for idx, dbName := range db.Statement.Schema.DBNames {
clauseSelect.Columns[idx] = clause.Column{Table: db.Statement.Table, Name: dbName} clauseSelect.Columns[idx] = clause.Column{Table: db.Statement.Table, Name: dbName}
@ -109,7 +111,7 @@ func BuildQuerySQL(db *gorm.DB) {
for _, join := range db.Statement.Joins { for _, join := range db.Statement.Joins {
if db.Statement.Schema == nil { if db.Statement.Schema == nil {
joins = append(joins, clause.Join{ fromClause.Joins = append(fromClause.Joins, clause.Join{
Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds}, Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds},
}) })
} else if relation, ok := db.Statement.Schema.Relationships.Relations[join.Name]; ok { } else if relation, ok := db.Statement.Schema.Relationships.Relations[join.Name]; ok {
@ -145,35 +147,49 @@ func BuildQuerySQL(db *gorm.DB) {
} }
} }
if join.On != nil { {
onStmt := gorm.Statement{Table: tableAliasName, DB: db} onStmt := gorm.Statement{Table: tableAliasName, DB: db, Clauses: map[string]clause.Clause{}}
join.On.Build(&onStmt) for _, c := range relation.FieldSchema.QueryClauses {
onSQL := onStmt.SQL.String() onStmt.AddClause(c)
vars := onStmt.Vars
for idx, v := range onStmt.Vars {
bindvar := strings.Builder{}
onStmt.Vars = vars[0 : idx+1]
db.Dialector.BindVarTo(&bindvar, &onStmt, v)
onSQL = strings.Replace(onSQL, bindvar.String(), "?", 1)
} }
exprs = append(exprs, clause.Expr{SQL: onSQL, Vars: vars}) if join.On != nil {
onStmt.AddClause(join.On)
}
if cs, ok := onStmt.Clauses["WHERE"]; ok {
if where, ok := cs.Expression.(clause.Where); ok {
where.Build(&onStmt)
if onSQL := onStmt.SQL.String(); onSQL != "" {
vars := onStmt.Vars
for idx, v := range vars {
bindvar := strings.Builder{}
onStmt.Vars = vars[0 : idx+1]
db.Dialector.BindVarTo(&bindvar, &onStmt, v)
onSQL = strings.Replace(onSQL, bindvar.String(), "?", 1)
}
exprs = append(exprs, clause.Expr{SQL: onSQL, Vars: vars})
}
}
}
} }
joins = append(joins, clause.Join{ fromClause.Joins = append(fromClause.Joins, clause.Join{
Type: clause.LeftJoin, Type: clause.LeftJoin,
Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName}, Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName},
ON: clause.Where{Exprs: exprs}, ON: clause.Where{Exprs: exprs},
}) })
} else { } else {
joins = append(joins, clause.Join{ fromClause.Joins = append(fromClause.Joins, clause.Join{
Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds}, Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds},
}) })
} }
} }
db.Statement.AddClause(fromClause)
db.Statement.Joins = nil db.Statement.Joins = nil
db.Statement.AddClause(clause.From{Joins: joins})
} else { } else {
db.Statement.AddClauseIfNotExists(clause.From{}) db.Statement.AddClauseIfNotExists(clause.From{})
} }
@ -186,6 +202,11 @@ func BuildQuerySQL(db *gorm.DB) {
func Preload(db *gorm.DB) { func Preload(db *gorm.DB) {
if db.Error == nil && len(db.Statement.Preloads) > 0 { if db.Error == nil && len(db.Statement.Preloads) > 0 {
if db.Statement.Schema == nil {
db.AddError(fmt.Errorf("%w when using preload", gorm.ErrModelValueRequired))
return
}
preloadMap := map[string]map[string][]interface{}{} preloadMap := map[string]map[string][]interface{}{}
for name := range db.Statement.Preloads { for name := range db.Statement.Preloads {
preloadFields := strings.Split(name, ".") preloadFields := strings.Split(name, ".")
@ -218,9 +239,20 @@ func Preload(db *gorm.DB) {
} }
sort.Strings(preloadNames) sort.Strings(preloadNames)
preloadDB := db.Session(&gorm.Session{Context: db.Statement.Context, NewDB: true, SkipHooks: db.Statement.SkipHooks, Initialized: true})
db.Statement.Settings.Range(func(k, v interface{}) bool {
preloadDB.Statement.Settings.Store(k, v)
return true
})
if err := preloadDB.Statement.Parse(db.Statement.Dest); err != nil {
return
}
preloadDB.Statement.ReflectValue = db.Statement.ReflectValue
for _, name := range preloadNames { for _, name := range preloadNames {
if rel := db.Statement.Schema.Relationships.Relations[name]; rel != nil { if rel := preloadDB.Statement.Schema.Relationships.Relations[name]; rel != nil {
preload(db, rel, append(db.Statement.Preloads[name], db.Statement.Preloads[clause.Associations]...), preloadMap[name]) db.AddError(preload(preloadDB.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks}), rel, append(db.Statement.Preloads[name], db.Statement.Preloads[clause.Associations]...), preloadMap[name]))
} else { } else {
db.AddError(fmt.Errorf("%s: %w for schema %s", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name)) db.AddError(fmt.Errorf("%s: %w for schema %s", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name))
} }

View File

@ -21,7 +21,7 @@ func SetupUpdateReflectValue(db *gorm.DB) {
if dest, ok := db.Statement.Dest.(map[string]interface{}); ok { if dest, ok := db.Statement.Dest.(map[string]interface{}); ok {
for _, rel := range db.Statement.Schema.Relationships.BelongsTo { for _, rel := range db.Statement.Schema.Relationships.BelongsTo {
if _, ok := dest[rel.Name]; ok { if _, ok := dest[rel.Name]; ok {
rel.Field.Set(db.Statement.ReflectValue, dest[rel.Name]) db.AddError(rel.Field.Set(db.Statement.Context, db.Statement.ReflectValue, dest[rel.Name]))
} }
} }
} }
@ -29,6 +29,7 @@ func SetupUpdateReflectValue(db *gorm.DB) {
} }
} }
// BeforeUpdate before update hooks
func BeforeUpdate(db *gorm.DB) { func BeforeUpdate(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) { if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) {
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
@ -51,6 +52,7 @@ func BeforeUpdate(db *gorm.DB) {
} }
} }
// Update update hook
func Update(config *Config) func(db *gorm.DB) { func Update(config *Config) func(db *gorm.DB) {
supportReturning := utils.Contains(config.UpdateClauses, "RETURNING") supportReturning := utils.Contains(config.UpdateClauses, "RETURNING")
@ -59,6 +61,12 @@ func Update(config *Config) func(db *gorm.DB) {
return return
} }
if db.Statement.Schema != nil {
for _, c := range db.Statement.Schema.UpdateClauses {
db.Statement.AddClause(c)
}
}
if db.Statement.SQL.Len() == 0 { if db.Statement.SQL.Len() == 0 {
db.Statement.SQL.Grow(180) db.Statement.SQL.Grow(180)
db.Statement.AddClauseIfNotExists(clause.Update{}) db.Statement.AddClauseIfNotExists(clause.Update{})
@ -68,22 +76,10 @@ func Update(config *Config) func(db *gorm.DB) {
return return
} }
}
if db.Statement.Schema != nil {
for _, c := range db.Statement.Schema.UpdateClauses {
db.Statement.AddClause(c)
}
}
if db.Statement.SQL.Len() == 0 {
db.Statement.Build(db.Statement.BuildClauses...) db.Statement.Build(db.Statement.BuildClauses...)
} }
if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok { checkMissingWhereConditions(db)
db.AddError(gorm.ErrMissingWhereClause)
return
}
if !db.DryRun && db.Error == nil { if !db.DryRun && db.Error == nil {
if ok, mode := hasReturning(db, supportReturning); ok { if ok, mode := hasReturning(db, supportReturning); ok {
@ -105,9 +101,17 @@ func Update(config *Config) func(db *gorm.DB) {
} }
} }
// AfterUpdate after update hooks
func AfterUpdate(db *gorm.DB) { func AfterUpdate(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) { if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) {
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
if db.Statement.Schema.AfterUpdate {
if i, ok := value.(AfterUpdateInterface); ok {
called = true
db.AddError(i.AfterUpdate(tx))
}
}
if db.Statement.Schema.AfterSave { if db.Statement.Schema.AfterSave {
if i, ok := value.(AfterSaveInterface); ok { if i, ok := value.(AfterSaveInterface); ok {
called = true called = true
@ -115,12 +119,6 @@ func AfterUpdate(db *gorm.DB) {
} }
} }
if db.Statement.Schema.AfterUpdate {
if i, ok := value.(AfterUpdateInterface); ok {
called = true
db.AddError(i.AfterUpdate(tx))
}
}
return called return called
}) })
} }
@ -137,13 +135,13 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
assignValue = func(field *schema.Field, value interface{}) { assignValue = func(field *schema.Field, value interface{}) {
for i := 0; i < stmt.ReflectValue.Len(); i++ { for i := 0; i < stmt.ReflectValue.Len(); i++ {
field.Set(stmt.ReflectValue.Index(i), value) field.Set(stmt.Context, stmt.ReflectValue.Index(i), value)
} }
} }
case reflect.Struct: case reflect.Struct:
assignValue = func(field *schema.Field, value interface{}) { assignValue = func(field *schema.Field, value interface{}) {
if stmt.ReflectValue.CanAddr() { if stmt.ReflectValue.CanAddr() {
field.Set(stmt.ReflectValue, value) field.Set(stmt.Context, stmt.ReflectValue, value)
} }
} }
default: default:
@ -165,7 +163,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
exprs := make([]clause.Expression, len(stmt.Schema.PrimaryFields)) exprs := make([]clause.Expression, len(stmt.Schema.PrimaryFields))
var notZero bool var notZero bool
for idx, field := range stmt.Schema.PrimaryFields { for idx, field := range stmt.Schema.PrimaryFields {
value, isZero := field.ValueOf(stmt.ReflectValue.Index(i)) value, isZero := field.ValueOf(stmt.Context, stmt.ReflectValue.Index(i))
exprs[idx] = clause.Eq{Column: field.DBName, Value: value} exprs[idx] = clause.Eq{Column: field.DBName, Value: value}
notZero = notZero || !isZero notZero = notZero || !isZero
} }
@ -178,7 +176,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
} }
case reflect.Struct: case reflect.Struct:
for _, field := range stmt.Schema.PrimaryFields { for _, field := range stmt.Schema.PrimaryFields {
if value, isZero := field.ValueOf(stmt.ReflectValue); !isZero { if value, isZero := field.ValueOf(stmt.Context, stmt.ReflectValue); !isZero {
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}}) stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}})
} }
} }
@ -232,10 +230,10 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano()}) set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano()})
} else if field.AutoUpdateTime == schema.UnixMillisecond { } else if field.AutoUpdateTime == schema.UnixMillisecond {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano() / 1e6}) set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano() / 1e6})
} else if field.GORMDataType == schema.Time { } else if field.AutoUpdateTime == schema.UnixSecond {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now})
} else {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.Unix()}) set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.Unix()})
} else {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now})
} }
} }
} }
@ -258,16 +256,16 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
if field := updatingSchema.LookUpField(dbName); field != nil { if field := updatingSchema.LookUpField(dbName); field != nil {
if !field.PrimaryKey || !updatingValue.CanAddr() || stmt.Dest != stmt.Model { if !field.PrimaryKey || !updatingValue.CanAddr() || stmt.Dest != stmt.Model {
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && (!restricted || (!stmt.SkipHooks && field.AutoUpdateTime > 0))) { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && (!restricted || (!stmt.SkipHooks && field.AutoUpdateTime > 0))) {
value, isZero := field.ValueOf(updatingValue) value, isZero := field.ValueOf(stmt.Context, updatingValue)
if !stmt.SkipHooks && field.AutoUpdateTime > 0 { if !stmt.SkipHooks && field.AutoUpdateTime > 0 {
if field.AutoUpdateTime == schema.UnixNanosecond { if field.AutoUpdateTime == schema.UnixNanosecond {
value = stmt.DB.NowFunc().UnixNano() value = stmt.DB.NowFunc().UnixNano()
} else if field.AutoUpdateTime == schema.UnixMillisecond { } else if field.AutoUpdateTime == schema.UnixMillisecond {
value = stmt.DB.NowFunc().UnixNano() / 1e6 value = stmt.DB.NowFunc().UnixNano() / 1e6
} else if field.GORMDataType == schema.Time { } else if field.AutoUpdateTime == schema.UnixSecond {
value = stmt.DB.NowFunc()
} else {
value = stmt.DB.NowFunc().Unix() value = stmt.DB.NowFunc().Unix()
} else {
value = stmt.DB.NowFunc()
} }
isZero = false isZero = false
} }
@ -278,7 +276,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
} }
} }
} else { } else {
if value, isZero := field.ValueOf(updatingValue); !isZero { if value, isZero := field.ValueOf(stmt.Context, updatingValue); !isZero {
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}}) stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}})
} }
} }

View File

@ -0,0 +1,36 @@
package callbacks
import (
"reflect"
"testing"
)
func TestLoadOrStoreVisitMap(t *testing.T) {
var vm visitMap
var loaded bool
type testM struct {
Name string
}
t1 := testM{Name: "t1"}
t2 := testM{Name: "t2"}
t3 := testM{Name: "t3"}
vm = make(visitMap)
if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf(&t1)); loaded {
t.Fatalf("loaded should be false")
}
if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf(&t1)); !loaded {
t.Fatalf("loaded should be true")
}
// t1 already exist but t2 not
if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf([]*testM{&t1, &t2, &t3})); loaded {
t.Fatalf("loaded should be false")
}
if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf([]*testM{&t2, &t3})); !loaded {
t.Fatalf("loaded should be true")
}
}

View File

@ -54,9 +54,12 @@ func (db *DB) Table(name string, args ...interface{}) (tx *DB) {
} else if tables := strings.Split(name, "."); len(tables) == 2 { } else if tables := strings.Split(name, "."); len(tables) == 2 {
tx.Statement.TableExpr = &clause.Expr{SQL: tx.Statement.Quote(name)} tx.Statement.TableExpr = &clause.Expr{SQL: tx.Statement.Quote(name)}
tx.Statement.Table = tables[1] tx.Statement.Table = tables[1]
} else { } else if name != "" {
tx.Statement.TableExpr = &clause.Expr{SQL: tx.Statement.Quote(name)} tx.Statement.TableExpr = &clause.Expr{SQL: tx.Statement.Quote(name)}
tx.Statement.Table = name tx.Statement.Table = name
} else {
tx.Statement.TableExpr = nil
tx.Statement.Table = ""
} }
return return
} }
@ -90,7 +93,11 @@ func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) {
return return
} }
} }
delete(tx.Statement.Clauses, "SELECT")
if clause, ok := tx.Statement.Clauses["SELECT"]; ok {
clause.Expression = nil
tx.Statement.Clauses["SELECT"] = clause
}
case string: case string:
if strings.Count(v, "?") >= len(args) && len(args) > 0 { if strings.Count(v, "?") >= len(args) && len(args) > 0 {
tx.Statement.AddClause(clause.Select{ tx.Statement.AddClause(clause.Select{
@ -120,7 +127,10 @@ func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) {
} }
} }
delete(tx.Statement.Clauses, "SELECT") if clause, ok := tx.Statement.Clauses["SELECT"]; ok {
clause.Expression = nil
tx.Statement.Clauses["SELECT"] = clause
}
} }
default: default:
tx.AddError(fmt.Errorf("unsupported select args %v %v", query, args)) tx.AddError(fmt.Errorf("unsupported select args %v %v", query, args))

View File

@ -21,7 +21,7 @@ func (limit Limit) Build(builder Builder) {
} }
if limit.Offset > 0 { if limit.Offset > 0 {
if limit.Limit > 0 { if limit.Limit > 0 {
builder.WriteString(" ") builder.WriteByte(' ')
} }
builder.WriteString("OFFSET ") builder.WriteString("OFFSET ")
builder.WriteString(strconv.Itoa(limit.Offset)) builder.WriteString(strconv.Itoa(limit.Offset))

View File

@ -43,6 +43,23 @@ func TestSelect(t *testing.T) {
}, clause.From{}}, }, clause.From{}},
"SELECT `id`, `name`, LENGTH(`mobile`) FROM `users`", nil, "SELECT `id`, `name`, LENGTH(`mobile`) FROM `users`", nil,
}, },
{
[]clause.Interface{clause.Select{
Expression: clause.CommaExpression{
Exprs: []clause.Expression{
clause.Expr{
SQL: "? as name",
Vars: []interface{}{clause.Eq{
Column: clause.Column{Name: "age"},
Value: 18,
},
},
},
},
},
}, clause.From{}},
"SELECT `age` = ? as name FROM `users`", []interface{}{18},
},
} }
for idx, result := range results { for idx, result := range results {

View File

@ -4,6 +4,11 @@ import (
"strings" "strings"
) )
const (
AndWithSpace = " AND "
OrWithSpace = " OR "
)
// Where where clause // Where where clause
type Where struct { type Where struct {
Exprs []Expression Exprs []Expression
@ -26,7 +31,7 @@ func (where Where) Build(builder Builder) {
} }
} }
buildExprs(where.Exprs, builder, " AND ") buildExprs(where.Exprs, builder, AndWithSpace)
} }
func buildExprs(exprs []Expression, builder Builder, joinCond string) { func buildExprs(exprs []Expression, builder Builder, joinCond string) {
@ -35,7 +40,7 @@ func buildExprs(exprs []Expression, builder Builder, joinCond string) {
for idx, expr := range exprs { for idx, expr := range exprs {
if idx > 0 { if idx > 0 {
if v, ok := expr.(OrConditions); ok && len(v.Exprs) == 1 { if v, ok := expr.(OrConditions); ok && len(v.Exprs) == 1 {
builder.WriteString(" OR ") builder.WriteString(OrWithSpace)
} else { } else {
builder.WriteString(joinCond) builder.WriteString(joinCond)
} }
@ -46,30 +51,30 @@ func buildExprs(exprs []Expression, builder Builder, joinCond string) {
case OrConditions: case OrConditions:
if len(v.Exprs) == 1 { if len(v.Exprs) == 1 {
if e, ok := v.Exprs[0].(Expr); ok { if e, ok := v.Exprs[0].(Expr); ok {
sql := strings.ToLower(e.SQL) sql := strings.ToUpper(e.SQL)
wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or") wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace)
} }
} }
case AndConditions: case AndConditions:
if len(v.Exprs) == 1 { if len(v.Exprs) == 1 {
if e, ok := v.Exprs[0].(Expr); ok { if e, ok := v.Exprs[0].(Expr); ok {
sql := strings.ToLower(e.SQL) sql := strings.ToUpper(e.SQL)
wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or") wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace)
} }
} }
case Expr: case Expr:
sql := strings.ToLower(v.SQL) sql := strings.ToUpper(v.SQL)
wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or") wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace)
case NamedExpr: case NamedExpr:
sql := strings.ToLower(v.SQL) sql := strings.ToUpper(v.SQL)
wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or") wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace)
} }
} }
if wrapInParentheses { if wrapInParentheses {
builder.WriteString(`(`) builder.WriteByte('(')
expr.Build(builder) expr.Build(builder)
builder.WriteString(`)`) builder.WriteByte(')')
wrapInParentheses = false wrapInParentheses = false
} else { } else {
expr.Build(builder) expr.Build(builder)
@ -110,10 +115,10 @@ type AndConditions struct {
func (and AndConditions) Build(builder Builder) { func (and AndConditions) Build(builder Builder) {
if len(and.Exprs) > 1 { if len(and.Exprs) > 1 {
builder.WriteByte('(') builder.WriteByte('(')
buildExprs(and.Exprs, builder, " AND ") buildExprs(and.Exprs, builder, AndWithSpace)
builder.WriteByte(')') builder.WriteByte(')')
} else { } else {
buildExprs(and.Exprs, builder, " AND ") buildExprs(and.Exprs, builder, AndWithSpace)
} }
} }
@ -131,10 +136,10 @@ type OrConditions struct {
func (or OrConditions) Build(builder Builder) { func (or OrConditions) Build(builder Builder) {
if len(or.Exprs) > 1 { if len(or.Exprs) > 1 {
builder.WriteByte('(') builder.WriteByte('(')
buildExprs(or.Exprs, builder, " OR ") buildExprs(or.Exprs, builder, OrWithSpace)
builder.WriteByte(')') builder.WriteByte(')')
} else { } else {
buildExprs(or.Exprs, builder, " OR ") buildExprs(or.Exprs, builder, OrWithSpace)
} }
} }
@ -156,7 +161,7 @@ func (not NotConditions) Build(builder Builder) {
for idx, c := range not.Exprs { for idx, c := range not.Exprs {
if idx > 0 { if idx > 0 {
builder.WriteString(" AND ") builder.WriteString(AndWithSpace)
} }
if negationBuilder, ok := c.(NegationExpressionBuilder); ok { if negationBuilder, ok := c.(NegationExpressionBuilder); ok {
@ -165,8 +170,8 @@ func (not NotConditions) Build(builder Builder) {
builder.WriteString("NOT ") builder.WriteString("NOT ")
e, wrapInParentheses := c.(Expr) e, wrapInParentheses := c.(Expr)
if wrapInParentheses { if wrapInParentheses {
sql := strings.ToLower(e.SQL) sql := strings.ToUpper(e.SQL)
if wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or"); wrapInParentheses { if wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace); wrapInParentheses {
builder.WriteByte('(') builder.WriteByte('(')
} }
} }

View File

@ -66,6 +66,45 @@ func TestWhere(t *testing.T) {
"SELECT * FROM `users` WHERE (`age` = ? OR `name` <> ?)", "SELECT * FROM `users` WHERE (`age` = ? OR `name` <> ?)",
[]interface{}{18, "jinzhu"}, []interface{}{18, "jinzhu"},
}, },
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
Exprs: []clause.Expression{clause.Not(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}), clause.And(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false})},
}},
"SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `age` <= ?) AND `score` <= ?",
[]interface{}{"1", 18, 100},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
Exprs: []clause.Expression{clause.Not(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}), clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false}},
}},
"SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `age` <= ?) AND `score` <= ?",
[]interface{}{"1", 18, 100},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
Exprs: []clause.Expression{clause.Not(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}), clause.Or(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false})},
}},
"SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `age` <= ?) OR `score` <= ?",
[]interface{}{"1", 18, 100},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
Exprs: []clause.Expression{
clause.And(clause.Not(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}),
clause.And(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false})),
},
}},
"SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `score` <= ?)",
[]interface{}{"1", 100},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
Exprs: []clause.Expression{clause.Not(clause.Eq{Column: clause.PrimaryColumn, Value: "1"},
clause.And(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false}))},
}},
"SELECT * FROM `users` WHERE (`users`.`id` <> ? AND NOT `score` <= ?)",
[]interface{}{"1", 100},
},
} }
for idx, result := range results { for idx, result := range results {

View File

@ -39,4 +39,6 @@ var (
ErrInvalidValue = errors.New("invalid value, should be pointer to struct or slice") ErrInvalidValue = errors.New("invalid value, should be pointer to struct or slice")
// ErrInvalidValueOfLength invalid values do not match length // ErrInvalidValueOfLength invalid values do not match length
ErrInvalidValueOfLength = errors.New("invalid association values, length doesn't match") ErrInvalidValueOfLength = errors.New("invalid association values, length doesn't match")
// ErrPreloadNotAllowed preload is not allowed when count is used
ErrPreloadNotAllowed = errors.New("preload is not allowed when count is used")
) )

View File

@ -74,6 +74,10 @@ func (db *DB) Save(value interface{}) (tx *DB) {
tx.Statement.Dest = value tx.Statement.Dest = value
reflectValue := reflect.Indirect(reflect.ValueOf(value)) reflectValue := reflect.Indirect(reflect.ValueOf(value))
for reflectValue.Kind() == reflect.Ptr || reflectValue.Kind() == reflect.Interface {
reflectValue = reflect.Indirect(reflectValue)
}
switch reflectValue.Kind() { switch reflectValue.Kind() {
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
if _, ok := tx.Statement.Clauses["ON CONFLICT"]; !ok { if _, ok := tx.Statement.Clauses["ON CONFLICT"]; !ok {
@ -83,7 +87,7 @@ func (db *DB) Save(value interface{}) (tx *DB) {
case reflect.Struct: case reflect.Struct:
if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil { if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil {
for _, pf := range tx.Statement.Schema.PrimaryFields { for _, pf := range tx.Statement.Schema.PrimaryFields {
if _, isZero := pf.ValueOf(reflectValue); isZero { if _, isZero := pf.ValueOf(tx.Statement.Context, reflectValue); isZero {
return tx.callbacks.Create().Execute(tx) return tx.callbacks.Create().Execute(tx)
} }
} }
@ -101,7 +105,7 @@ func (db *DB) Save(value interface{}) (tx *DB) {
if tx.Error == nil && tx.RowsAffected == 0 && !tx.DryRun && !selectedUpdate { if tx.Error == nil && tx.RowsAffected == 0 && !tx.DryRun && !selectedUpdate {
result := reflect.New(tx.Statement.Schema.ModelType).Interface() result := reflect.New(tx.Statement.Schema.ModelType).Interface()
if err := tx.Session(&Session{}).Take(result).Error; errors.Is(err, ErrRecordNotFound) { if result := tx.Session(&Session{}).Limit(1).Find(result); result.RowsAffected == 0 {
return tx.Create(value) return tx.Create(value)
} }
} }
@ -177,6 +181,21 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat
batch int batch int
) )
// user specified offset or limit
var totalSize int
if c, ok := tx.Statement.Clauses["LIMIT"]; ok {
if limit, ok := c.Expression.(clause.Limit); ok {
totalSize = limit.Limit
if totalSize > 0 && batchSize > totalSize {
batchSize = totalSize
}
// reset to offset to 0 in next batch
tx = tx.Offset(-1).Session(&Session{})
}
}
for { for {
result := queryDB.Limit(batchSize).Find(dest) result := queryDB.Limit(batchSize).Find(dest)
rowsAffected += result.RowsAffected rowsAffected += result.RowsAffected
@ -192,6 +211,15 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat
break break
} }
if totalSize > 0 {
if totalSize <= int(rowsAffected) {
break
}
if totalSize/batchSize == batch {
batchSize = totalSize % batchSize
}
}
// Optimize for-break // Optimize for-break
resultsValue := reflect.Indirect(reflect.ValueOf(dest)) resultsValue := reflect.Indirect(reflect.ValueOf(dest))
if result.Statement.Schema.PrioritizedPrimaryField == nil { if result.Statement.Schema.PrioritizedPrimaryField == nil {
@ -199,7 +227,7 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat
break break
} }
primaryValue, _ := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(resultsValue.Index(resultsValue.Len() - 1)) primaryValue, _ := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1))
queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue}) queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue})
} }
@ -207,7 +235,7 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat
return tx return tx
} }
func (tx *DB) assignInterfacesToValue(values ...interface{}) { func (db *DB) assignInterfacesToValue(values ...interface{}) {
for _, value := range values { for _, value := range values {
switch v := value.(type) { switch v := value.(type) {
case []clause.Expression: case []clause.Expression:
@ -215,40 +243,40 @@ func (tx *DB) assignInterfacesToValue(values ...interface{}) {
if eq, ok := expr.(clause.Eq); ok { if eq, ok := expr.(clause.Eq); ok {
switch column := eq.Column.(type) { switch column := eq.Column.(type) {
case string: case string:
if field := tx.Statement.Schema.LookUpField(column); field != nil { if field := db.Statement.Schema.LookUpField(column); field != nil {
tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value)) db.AddError(field.Set(db.Statement.Context, db.Statement.ReflectValue, eq.Value))
} }
case clause.Column: case clause.Column:
if field := tx.Statement.Schema.LookUpField(column.Name); field != nil { if field := db.Statement.Schema.LookUpField(column.Name); field != nil {
tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value)) db.AddError(field.Set(db.Statement.Context, db.Statement.ReflectValue, eq.Value))
} }
} }
} else if andCond, ok := expr.(clause.AndConditions); ok { } else if andCond, ok := expr.(clause.AndConditions); ok {
tx.assignInterfacesToValue(andCond.Exprs) db.assignInterfacesToValue(andCond.Exprs)
} }
} }
case clause.Expression, map[string]string, map[interface{}]interface{}, map[string]interface{}: case clause.Expression, map[string]string, map[interface{}]interface{}, map[string]interface{}:
if exprs := tx.Statement.BuildCondition(value); len(exprs) > 0 { if exprs := db.Statement.BuildCondition(value); len(exprs) > 0 {
tx.assignInterfacesToValue(exprs) db.assignInterfacesToValue(exprs)
} }
default: default:
if s, err := schema.Parse(value, tx.cacheStore, tx.NamingStrategy); err == nil { if s, err := schema.Parse(value, db.cacheStore, db.NamingStrategy); err == nil {
reflectValue := reflect.Indirect(reflect.ValueOf(value)) reflectValue := reflect.Indirect(reflect.ValueOf(value))
switch reflectValue.Kind() { switch reflectValue.Kind() {
case reflect.Struct: case reflect.Struct:
for _, f := range s.Fields { for _, f := range s.Fields {
if f.Readable { if f.Readable {
if v, isZero := f.ValueOf(reflectValue); !isZero { if v, isZero := f.ValueOf(db.Statement.Context, reflectValue); !isZero {
if field := tx.Statement.Schema.LookUpField(f.Name); field != nil { if field := db.Statement.Schema.LookUpField(f.Name); field != nil {
tx.AddError(field.Set(tx.Statement.ReflectValue, v)) db.AddError(field.Set(db.Statement.Context, db.Statement.ReflectValue, v))
} }
} }
} }
} }
} }
} else if len(values) > 0 { } else if len(values) > 0 {
if exprs := tx.Statement.BuildCondition(values[0], values[1:]...); len(exprs) > 0 { if exprs := db.Statement.BuildCondition(values[0], values[1:]...); len(exprs) > 0 {
tx.assignInterfacesToValue(exprs) db.assignInterfacesToValue(exprs)
} }
return return
} }
@ -256,12 +284,13 @@ func (tx *DB) assignInterfacesToValue(values ...interface{}) {
} }
} }
// FirstOrInit gets the first matched record or initialize a new instance with given conditions (only works with struct or map conditions)
func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) {
queryTx := db.Limit(1).Order(clause.OrderByColumn{ queryTx := db.Limit(1).Order(clause.OrderByColumn{
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
}) })
if tx = queryTx.Find(dest, conds...); queryTx.RowsAffected == 0 { if tx = queryTx.Find(dest, conds...); tx.RowsAffected == 0 {
if c, ok := tx.Statement.Clauses["WHERE"]; ok { if c, ok := tx.Statement.Clauses["WHERE"]; ok {
if where, ok := c.Expression.(clause.Where); ok { if where, ok := c.Expression.(clause.Where); ok {
tx.assignInterfacesToValue(where.Exprs) tx.assignInterfacesToValue(where.Exprs)
@ -281,26 +310,28 @@ func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) {
return return
} }
// FirstOrCreate gets the first matched record or create a new one with given conditions (only works with struct, map conditions)
func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) { func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) {
queryTx := db.Limit(1).Order(clause.OrderByColumn{ tx = db.getInstance()
queryTx := db.Session(&Session{}).Limit(1).Order(clause.OrderByColumn{
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
}) })
if tx = queryTx.Find(dest, conds...); tx.Error == nil { if result := queryTx.Find(dest, conds...); result.Error == nil {
if tx.RowsAffected == 0 { if result.RowsAffected == 0 {
if c, ok := tx.Statement.Clauses["WHERE"]; ok { if c, ok := result.Statement.Clauses["WHERE"]; ok {
if where, ok := c.Expression.(clause.Where); ok { if where, ok := c.Expression.(clause.Where); ok {
tx.assignInterfacesToValue(where.Exprs) result.assignInterfacesToValue(where.Exprs)
} }
} }
// initialize with attrs, conds // initialize with attrs, conds
if len(tx.Statement.attrs) > 0 { if len(db.Statement.attrs) > 0 {
tx.assignInterfacesToValue(tx.Statement.attrs...) result.assignInterfacesToValue(db.Statement.attrs...)
} }
// initialize with attrs, conds // initialize with attrs, conds
if len(tx.Statement.assigns) > 0 { if len(db.Statement.assigns) > 0 {
tx.assignInterfacesToValue(tx.Statement.assigns...) result.assignInterfacesToValue(db.Statement.assigns...)
} }
return tx.Create(dest) return tx.Create(dest)
@ -320,6 +351,8 @@ func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) {
} }
return tx.Model(dest).Updates(assigns) return tx.Model(dest).Updates(assigns)
} else {
tx.Error = result.Error
} }
} }
return tx return tx
@ -585,7 +618,7 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er
func (db *DB) Begin(opts ...*sql.TxOptions) *DB { func (db *DB) Begin(opts ...*sql.TxOptions) *DB {
var ( var (
// clone statement // clone statement
tx = db.getInstance().Session(&Session{Context: db.Statement.Context}) tx = db.getInstance().Session(&Session{Context: db.Statement.Context, NewDB: db.clone == 1})
opt *sql.TxOptions opt *sql.TxOptions
err error err error
) )
@ -594,11 +627,12 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB {
opt = opts[0] opt = opts[0]
} }
if beginner, ok := tx.Statement.ConnPool.(TxBeginner); ok { switch beginner := tx.Statement.ConnPool.(type) {
case TxBeginner:
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
} else if beginner, ok := tx.Statement.ConnPool.(ConnPoolBeginner); ok { case ConnPoolBeginner:
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
} else { default:
err = ErrInvalidTransaction err = ErrInvalidTransaction
} }

17
gorm.go
View File

@ -59,6 +59,7 @@ type Config struct {
cacheStore *sync.Map cacheStore *sync.Map
} }
// Apply update config to new config
func (c *Config) Apply(config *Config) error { func (c *Config) Apply(config *Config) error {
if config != c { if config != c {
*config = *c *config = *c
@ -66,6 +67,7 @@ func (c *Config) Apply(config *Config) error {
return nil return nil
} }
// AfterInitialize initialize plugins after db connected
func (c *Config) AfterInitialize(db *DB) error { func (c *Config) AfterInitialize(db *DB) error {
if db != nil { if db != nil {
for _, plugin := range c.Plugins { for _, plugin := range c.Plugins {
@ -77,6 +79,7 @@ func (c *Config) AfterInitialize(db *DB) error {
return nil return nil
} }
// Option gorm option interface
type Option interface { type Option interface {
Apply(*Config) error Apply(*Config) error
AfterInitialize(*DB) error AfterInitialize(*DB) error
@ -96,6 +99,7 @@ type Session struct {
DryRun bool DryRun bool
PrepareStmt bool PrepareStmt bool
NewDB bool NewDB bool
Initialized bool
SkipHooks bool SkipHooks bool
SkipDefaultTransaction bool SkipDefaultTransaction bool
DisableNestedTransaction bool DisableNestedTransaction bool
@ -120,8 +124,8 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) {
for _, opt := range opts { for _, opt := range opts {
if opt != nil { if opt != nil {
if err := opt.Apply(config); err != nil { if applyErr := opt.Apply(config); applyErr != nil {
return nil, err return nil, applyErr
} }
defer func(opt Option) { defer func(opt Option) {
if errr := opt.AfterInitialize(db); errr != nil { if errr := opt.AfterInitialize(db); errr != nil {
@ -282,6 +286,10 @@ func (db *DB) Session(config *Session) *DB {
tx.Config.NowFunc = config.NowFunc tx.Config.NowFunc = config.NowFunc
} }
if config.Initialized {
tx = tx.getInstance()
}
return tx return tx
} }
@ -376,10 +384,12 @@ func (db *DB) getInstance() *DB {
return db return db
} }
// Expr returns clause.Expr, which can be used to pass SQL expression as params
func Expr(expr string, args ...interface{}) clause.Expr { func Expr(expr string, args ...interface{}) clause.Expr {
return clause.Expr{SQL: expr, Vars: args} return clause.Expr{SQL: expr, Vars: args}
} }
// SetupJoinTable setup join table schema
func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interface{}) error { func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interface{}) error {
var ( var (
tx = db.getInstance() tx = db.getInstance()
@ -430,6 +440,7 @@ func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interfac
return nil return nil
} }
// Use use plugin
func (db *DB) Use(plugin Plugin) error { func (db *DB) Use(plugin Plugin) error {
name := plugin.Name() name := plugin.Name()
if _, ok := db.Plugins[name]; ok { if _, ok := db.Plugins[name]; ok {
@ -451,7 +462,7 @@ func (db *DB) Use(plugin Plugin) error {
// .First(&User{}) // .First(&User{})
// }) // })
func (db *DB) ToSQL(queryFn func(tx *DB) *DB) string { func (db *DB) ToSQL(queryFn func(tx *DB) *DB) string {
tx := queryFn(db.Session(&Session{DryRun: true})) tx := queryFn(db.Session(&Session{DryRun: true, SkipDefaultTransaction: true}))
stmt := tx.Statement stmt := tx.Statement
return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...) return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...)

View File

@ -40,24 +40,45 @@ type SavePointerDialectorInterface interface {
RollbackTo(tx *DB, name string) error RollbackTo(tx *DB, name string) error
} }
// TxBeginner tx beginner
type TxBeginner interface { type TxBeginner interface {
BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
} }
// ConnPoolBeginner conn pool beginner
type ConnPoolBeginner interface { type ConnPoolBeginner interface {
BeginTx(ctx context.Context, opts *sql.TxOptions) (ConnPool, error) BeginTx(ctx context.Context, opts *sql.TxOptions) (ConnPool, error)
} }
// TxCommitter tx committer
type TxCommitter interface { type TxCommitter interface {
Commit() error Commit() error
Rollback() error Rollback() error
} }
// Tx sql.Tx interface
type Tx interface {
ConnPool
TxCommitter
StmtContext(ctx context.Context, stmt *sql.Stmt) *sql.Stmt
}
// Valuer gorm valuer interface // Valuer gorm valuer interface
type Valuer interface { type Valuer interface {
GormValue(context.Context, *DB) clause.Expr GormValue(context.Context, *DB) clause.Expr
} }
// GetDBConnector SQL db connector
type GetDBConnector interface { type GetDBConnector interface {
GetDBConn() (*sql.DB, error) GetDBConn() (*sql.DB, error)
} }
// Rows rows interface
type Rows interface {
Columns() ([]string, error)
ColumnTypes() ([]*sql.ColumnType, error)
Next() bool
Scan(dest ...interface{}) error
Err() error
Close() error
}

View File

@ -12,6 +12,7 @@ import (
"gorm.io/gorm/utils" "gorm.io/gorm/utils"
) )
// ErrRecordNotFound record not found error
var ErrRecordNotFound = errors.New("record not found") var ErrRecordNotFound = errors.New("record not found")
// Colors // Colors
@ -30,13 +31,17 @@ const (
YellowBold = "\033[33;1m" YellowBold = "\033[33;1m"
) )
// LogLevel // LogLevel log level
type LogLevel int type LogLevel int
const ( const (
// Silent silent log level
Silent LogLevel = iota + 1 Silent LogLevel = iota + 1
// Error error log level
Error Error
// Warn warn log level
Warn Warn
// Info info log level
Info Info
) )
@ -45,6 +50,7 @@ type Writer interface {
Printf(string, ...interface{}) Printf(string, ...interface{})
} }
// Config logger config
type Config struct { type Config struct {
SlowThreshold time.Duration SlowThreshold time.Duration
Colorful bool Colorful bool
@ -62,16 +68,20 @@ type Interface interface {
} }
var ( var (
// Discard Discard logger will print any log to ioutil.Discard
Discard = New(log.New(ioutil.Discard, "", log.LstdFlags), Config{}) Discard = New(log.New(ioutil.Discard, "", log.LstdFlags), Config{})
// Default Default logger
Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{ Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{
SlowThreshold: 200 * time.Millisecond, SlowThreshold: 200 * time.Millisecond,
LogLevel: Warn, LogLevel: Warn,
IgnoreRecordNotFoundError: false, IgnoreRecordNotFoundError: false,
Colorful: true, Colorful: true,
}) })
// Recorder Recorder logger records running SQL into a recorder instance
Recorder = traceRecorder{Interface: Default, BeginAt: time.Now()} Recorder = traceRecorder{Interface: Default, BeginAt: time.Now()}
) )
// New initialize logger
func New(writer Writer, config Config) Interface { func New(writer Writer, config Config) Interface {
var ( var (
infoStr = "%s\n[info] " infoStr = "%s\n[info] "
@ -179,10 +189,12 @@ type traceRecorder struct {
Err error Err error
} }
// New new trace recorder
func (l traceRecorder) New() *traceRecorder { func (l traceRecorder) New() *traceRecorder {
return &traceRecorder{Interface: l.Interface, BeginAt: time.Now()} return &traceRecorder{Interface: l.Interface, BeginAt: time.Now()}
} }
// Trace implement logger interface
func (l *traceRecorder) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) { func (l *traceRecorder) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) {
l.BeginAt = begin l.BeginAt = begin
l.SQL, l.RowsAffected = fc() l.SQL, l.RowsAffected = fc()

View File

@ -19,9 +19,9 @@ const (
nullStr = "NULL" nullStr = "NULL"
) )
func isPrintable(s []byte) bool { func isPrintable(s string) bool {
for _, r := range s { for _, r := range s {
if !unicode.IsPrint(rune(r)) { if !unicode.IsPrint(r) {
return false return false
} }
} }
@ -30,9 +30,12 @@ func isPrintable(s []byte) bool {
var convertibleTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeOf(false), reflect.TypeOf([]byte{})} var convertibleTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeOf(false), reflect.TypeOf([]byte{})}
// ExplainSQL generate SQL string with given parameters, the generated SQL is expected to be used in logger, execute it might introduce a SQL injection vulnerability
func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, avars ...interface{}) string { func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, avars ...interface{}) string {
var convertParams func(interface{}, int) var (
vars := make([]string, len(avars)) convertParams func(interface{}, int)
vars = make([]string, len(avars))
)
convertParams = func(v interface{}, idx int) { convertParams = func(v interface{}, idx int) {
switch v := v.(type) { switch v := v.(type) {
@ -64,14 +67,25 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a
} }
case fmt.Stringer: case fmt.Stringer:
reflectValue := reflect.ValueOf(v) reflectValue := reflect.ValueOf(v)
if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) { switch reflectValue.Kind() {
vars[idx] = escaper + strings.Replace(fmt.Sprintf("%v", v), escaper, "\\"+escaper, -1) + escaper case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
} else { vars[idx] = fmt.Sprintf("%d", reflectValue.Interface())
vars[idx] = nullStr case reflect.Float32, reflect.Float64:
vars[idx] = fmt.Sprintf("%.6f", reflectValue.Interface())
case reflect.Bool:
vars[idx] = fmt.Sprintf("%t", reflectValue.Interface())
case reflect.String:
vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, "\\"+escaper) + escaper
default:
if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) {
vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, "\\"+escaper) + escaper
} else {
vars[idx] = nullStr
}
} }
case []byte: case []byte:
if isPrintable(v) { if s := string(v); isPrintable(s) {
vars[idx] = escaper + strings.Replace(string(v), escaper, "\\"+escaper, -1) + escaper vars[idx] = escaper + strings.ReplaceAll(s, escaper, "\\"+escaper) + escaper
} else { } else {
vars[idx] = escaper + "<binary>" + escaper vars[idx] = escaper + "<binary>" + escaper
} }
@ -80,7 +94,7 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a
case float64, float32: case float64, float32:
vars[idx] = fmt.Sprintf("%.6f", v) vars[idx] = fmt.Sprintf("%.6f", v)
case string: case string:
vars[idx] = escaper + strings.Replace(v, escaper, "\\"+escaper, -1) + escaper vars[idx] = escaper + strings.ReplaceAll(v, escaper, "\\"+escaper) + escaper
default: default:
rv := reflect.ValueOf(v) rv := reflect.ValueOf(v)
if v == nil || !rv.IsValid() || rv.Kind() == reflect.Ptr && rv.IsNil() { if v == nil || !rv.IsValid() || rv.Kind() == reflect.Ptr && rv.IsNil() {
@ -97,7 +111,7 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a
return return
} }
} }
vars[idx] = escaper + strings.Replace(fmt.Sprint(v), escaper, "\\"+escaper, -1) + escaper vars[idx] = escaper + strings.ReplaceAll(fmt.Sprint(v), escaper, "\\"+escaper) + escaper
} }
} }
} }

View File

@ -31,7 +31,7 @@ func (s ExampleStruct) Value() (driver.Value, error) {
} }
func format(v []byte, escaper string) string { func format(v []byte, escaper string) string {
return escaper + strings.Replace(string(v), escaper, "\\"+escaper, -1) + escaper return escaper + strings.ReplaceAll(string(v), escaper, "\\"+escaper) + escaper
} }
func TestExplainSQL(t *testing.T) { func TestExplainSQL(t *testing.T) {

View File

@ -1,6 +1,8 @@
package gorm package gorm
import ( import (
"reflect"
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
"gorm.io/gorm/schema" "gorm.io/gorm/schema"
) )
@ -33,14 +35,23 @@ type ViewOption struct {
Query *DB Query *DB
} }
// ColumnType column type interface
type ColumnType interface { type ColumnType interface {
Name() string Name() string
DatabaseTypeName() string DatabaseTypeName() string // varchar
ColumnType() (columnType string, ok bool) // varchar(64)
PrimaryKey() (isPrimaryKey bool, ok bool)
AutoIncrement() (isAutoIncrement bool, ok bool)
Length() (length int64, ok bool) Length() (length int64, ok bool)
DecimalSize() (precision int64, scale int64, ok bool) DecimalSize() (precision int64, scale int64, ok bool)
Nullable() (nullable bool, ok bool) Nullable() (nullable bool, ok bool)
Unique() (unique bool, ok bool)
ScanType() reflect.Type
Comment() (value string, ok bool)
DefaultValue() (value string, ok bool)
} }
// Migrator migrator interface
type Migrator interface { type Migrator interface {
// AutoMigrate // AutoMigrate
AutoMigrate(dst ...interface{}) error AutoMigrate(dst ...interface{}) error

107
migrator/column_type.go Normal file
View File

@ -0,0 +1,107 @@
package migrator
import (
"database/sql"
"reflect"
)
// ColumnType column type implements ColumnType interface
type ColumnType struct {
SQLColumnType *sql.ColumnType
NameValue sql.NullString
DataTypeValue sql.NullString
ColumnTypeValue sql.NullString
PrimaryKeyValue sql.NullBool
UniqueValue sql.NullBool
AutoIncrementValue sql.NullBool
LengthValue sql.NullInt64
DecimalSizeValue sql.NullInt64
ScaleValue sql.NullInt64
NullableValue sql.NullBool
ScanTypeValue reflect.Type
CommentValue sql.NullString
DefaultValueValue sql.NullString
}
// Name returns the name or alias of the column.
func (ct ColumnType) Name() string {
if ct.NameValue.Valid {
return ct.NameValue.String
}
return ct.SQLColumnType.Name()
}
// DatabaseTypeName returns the database system name of the column type. If an empty
// string is returned, then the driver type name is not supported.
// Consult your driver documentation for a list of driver data types. Length specifiers
// are not included.
// Common type names include "VARCHAR", "TEXT", "NVARCHAR", "DECIMAL", "BOOL",
// "INT", and "BIGINT".
func (ct ColumnType) DatabaseTypeName() string {
if ct.DataTypeValue.Valid {
return ct.DataTypeValue.String
}
return ct.SQLColumnType.DatabaseTypeName()
}
// ColumnType returns the database type of the column. like `varchar(16)`
func (ct ColumnType) ColumnType() (columnType string, ok bool) {
return ct.ColumnTypeValue.String, ct.ColumnTypeValue.Valid
}
// PrimaryKey returns the column is primary key or not.
func (ct ColumnType) PrimaryKey() (isPrimaryKey bool, ok bool) {
return ct.PrimaryKeyValue.Bool, ct.PrimaryKeyValue.Valid
}
// AutoIncrement returns the column is auto increment or not.
func (ct ColumnType) AutoIncrement() (isAutoIncrement bool, ok bool) {
return ct.AutoIncrementValue.Bool, ct.AutoIncrementValue.Valid
}
// Length returns the column type length for variable length column types
func (ct ColumnType) Length() (length int64, ok bool) {
if ct.LengthValue.Valid {
return ct.LengthValue.Int64, true
}
return ct.SQLColumnType.Length()
}
// DecimalSize returns the scale and precision of a decimal type.
func (ct ColumnType) DecimalSize() (precision int64, scale int64, ok bool) {
if ct.DecimalSizeValue.Valid {
return ct.DecimalSizeValue.Int64, ct.ScaleValue.Int64, true
}
return ct.SQLColumnType.DecimalSize()
}
// Nullable reports whether the column may be null.
func (ct ColumnType) Nullable() (nullable bool, ok bool) {
if ct.NullableValue.Valid {
return ct.NullableValue.Bool, true
}
return ct.SQLColumnType.Nullable()
}
// Unique reports whether the column may be unique.
func (ct ColumnType) Unique() (unique bool, ok bool) {
return ct.UniqueValue.Bool, ct.UniqueValue.Valid
}
// ScanType returns a Go type suitable for scanning into using Rows.Scan.
func (ct ColumnType) ScanType() reflect.Type {
if ct.ScanTypeValue != nil {
return ct.ScanTypeValue
}
return ct.SQLColumnType.ScanType()
}
// Comment returns the comment of current column.
func (ct ColumnType) Comment() (value string, ok bool) {
return ct.CommentValue.String, ct.CommentValue.Valid
}
// DefaultValue returns the default value of current column.
func (ct ColumnType) DefaultValue() (value string, ok bool) {
return ct.DefaultValueValue.String, ct.DefaultValueValue.Valid
}

View File

@ -30,10 +30,12 @@ type Config struct {
gorm.Dialector gorm.Dialector
} }
// GormDataTypeInterface gorm data type interface
type GormDataTypeInterface interface { type GormDataTypeInterface interface {
GormDBDataType(*gorm.DB, *schema.Field) string GormDBDataType(*gorm.DB, *schema.Field) string
} }
// RunWithValue run migration with statement value
func (m Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error) error { func (m Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error) error {
stmt := &gorm.Statement{DB: m.DB} stmt := &gorm.Statement{DB: m.DB}
if m.DB.Statement != nil { if m.DB.Statement != nil {
@ -50,6 +52,7 @@ func (m Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error
return fc(stmt) return fc(stmt)
} }
// DataTypeOf return field's db data type
func (m Migrator) DataTypeOf(field *schema.Field) string { func (m Migrator) DataTypeOf(field *schema.Field) string {
fieldValue := reflect.New(field.IndirectFieldType) fieldValue := reflect.New(field.IndirectFieldType)
if dataTyper, ok := fieldValue.Interface().(GormDataTypeInterface); ok { if dataTyper, ok := fieldValue.Interface().(GormDataTypeInterface); ok {
@ -61,6 +64,7 @@ func (m Migrator) DataTypeOf(field *schema.Field) string {
return m.Dialector.DataTypeOf(field) return m.Dialector.DataTypeOf(field)
} }
// FullDataTypeOf returns field's db full data type
func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) { func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) {
expr.SQL = m.DataTypeOf(field) expr.SQL = m.DataTypeOf(field)
@ -85,7 +89,7 @@ func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) {
return return
} }
// AutoMigrate // AutoMigrate auto migrate values
func (m Migrator) AutoMigrate(values ...interface{}) error { func (m Migrator) AutoMigrate(values ...interface{}) error {
for _, value := range m.ReorderModels(values, true) { for _, value := range m.ReorderModels(values, true) {
tx := m.DB.Session(&gorm.Session{}) tx := m.DB.Session(&gorm.Session{})
@ -95,7 +99,10 @@ func (m Migrator) AutoMigrate(values ...interface{}) error {
} }
} else { } else {
if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) { if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) {
columnTypes, _ := m.DB.Migrator().ColumnTypes(value) columnTypes, err := m.DB.Migrator().ColumnTypes(value)
if err != nil {
return err
}
for _, dbName := range stmt.Schema.DBNames { for _, dbName := range stmt.Schema.DBNames {
field := stmt.Schema.FieldsByDBName[dbName] field := stmt.Schema.FieldsByDBName[dbName]
@ -156,12 +163,14 @@ func (m Migrator) AutoMigrate(values ...interface{}) error {
return nil return nil
} }
// GetTables returns tables
func (m Migrator) GetTables() (tableList []string, err error) { func (m Migrator) GetTables() (tableList []string, err error) {
err = m.DB.Raw("SELECT TABLE_NAME FROM information_schema.tables where TABLE_SCHEMA=?", m.CurrentDatabase()). err = m.DB.Raw("SELECT TABLE_NAME FROM information_schema.tables where TABLE_SCHEMA=?", m.CurrentDatabase()).
Scan(&tableList).Error Scan(&tableList).Error
return return
} }
// CreateTable create table in database for values
func (m Migrator) CreateTable(values ...interface{}) error { func (m Migrator) CreateTable(values ...interface{}) error {
for _, value := range m.ReorderModels(values, false) { for _, value := range m.ReorderModels(values, false) {
tx := m.DB.Session(&gorm.Session{}) tx := m.DB.Session(&gorm.Session{})
@ -252,6 +261,7 @@ func (m Migrator) CreateTable(values ...interface{}) error {
return nil return nil
} }
// DropTable drop table for values
func (m Migrator) DropTable(values ...interface{}) error { func (m Migrator) DropTable(values ...interface{}) error {
values = m.ReorderModels(values, false) values = m.ReorderModels(values, false)
for i := len(values) - 1; i >= 0; i-- { for i := len(values) - 1; i >= 0; i-- {
@ -265,6 +275,7 @@ func (m Migrator) DropTable(values ...interface{}) error {
return nil return nil
} }
// HasTable returns table exists or not for value, value could be a struct or string
func (m Migrator) HasTable(value interface{}) bool { func (m Migrator) HasTable(value interface{}) bool {
var count int64 var count int64
@ -276,6 +287,7 @@ func (m Migrator) HasTable(value interface{}) bool {
return count > 0 return count > 0
} }
// RenameTable rename table from oldName to newName
func (m Migrator) RenameTable(oldName, newName interface{}) error { func (m Migrator) RenameTable(oldName, newName interface{}) error {
var oldTable, newTable interface{} var oldTable, newTable interface{}
if v, ok := oldName.(string); ok { if v, ok := oldName.(string); ok {
@ -303,12 +315,13 @@ func (m Migrator) RenameTable(oldName, newName interface{}) error {
return m.DB.Exec("ALTER TABLE ? RENAME TO ?", oldTable, newTable).Error return m.DB.Exec("ALTER TABLE ? RENAME TO ?", oldTable, newTable).Error
} }
func (m Migrator) AddColumn(value interface{}, field string) error { // AddColumn create `name` column for value
func (m Migrator) AddColumn(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error {
// avoid using the same name field // avoid using the same name field
f := stmt.Schema.LookUpField(field) f := stmt.Schema.LookUpField(name)
if f == nil { if f == nil {
return fmt.Errorf("failed to look up field with name: %s", field) return fmt.Errorf("failed to look up field with name: %s", name)
} }
if !f.IgnoreMigration { if !f.IgnoreMigration {
@ -322,6 +335,7 @@ func (m Migrator) AddColumn(value interface{}, field string) error {
}) })
} }
// DropColumn drop value's `name` column
func (m Migrator) DropColumn(value interface{}, name string) error { func (m Migrator) DropColumn(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if field := stmt.Schema.LookUpField(name); field != nil { if field := stmt.Schema.LookUpField(name); field != nil {
@ -334,10 +348,11 @@ func (m Migrator) DropColumn(value interface{}, name string) error {
}) })
} }
// AlterColumn alter value's `field` column' type based on schema definition
func (m Migrator) AlterColumn(value interface{}, field string) error { func (m Migrator) AlterColumn(value interface{}, field string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if field := stmt.Schema.LookUpField(field); field != nil { if field := stmt.Schema.LookUpField(field); field != nil {
fileType := clause.Expr{SQL: m.DataTypeOf(field)} fileType := m.FullDataTypeOf(field)
return m.DB.Exec( return m.DB.Exec(
"ALTER TABLE ? ALTER COLUMN ? TYPE ?", "ALTER TABLE ? ALTER COLUMN ? TYPE ?",
m.CurrentTable(stmt), clause.Column{Name: field.DBName}, fileType, m.CurrentTable(stmt), clause.Column{Name: field.DBName}, fileType,
@ -348,6 +363,7 @@ func (m Migrator) AlterColumn(value interface{}, field string) error {
}) })
} }
// HasColumn check has column `field` for value or not
func (m Migrator) HasColumn(value interface{}, field string) bool { func (m Migrator) HasColumn(value interface{}, field string) bool {
var count int64 var count int64
m.RunWithValue(value, func(stmt *gorm.Statement) error { m.RunWithValue(value, func(stmt *gorm.Statement) error {
@ -366,6 +382,7 @@ func (m Migrator) HasColumn(value interface{}, field string) bool {
return count > 0 return count > 0
} }
// RenameColumn rename value's field name from oldName to newName
func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error { func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if field := stmt.Schema.LookUpField(oldName); field != nil { if field := stmt.Schema.LookUpField(oldName); field != nil {
@ -383,6 +400,7 @@ func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error
}) })
} }
// MigrateColumn migrate column
func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnType gorm.ColumnType) error { func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnType gorm.ColumnType) error {
// found, smart migrate // found, smart migrate
fullDataType := strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL) fullDataType := strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL)
@ -421,6 +439,30 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
} }
} }
// check unique
if unique, ok := columnType.Unique(); ok && unique != field.Unique {
// not primary key
if !field.PrimaryKey {
alterColumn = true
}
}
// check default value
if v, ok := columnType.DefaultValue(); ok && v != field.DefaultValue {
// not primary key
if !field.PrimaryKey {
alterColumn = true
}
}
// check comment
if comment, ok := columnType.Comment(); ok && comment != field.Comment {
// not primary key
if !field.PrimaryKey {
alterColumn = true
}
}
if alterColumn && !field.IgnoreMigration { if alterColumn && !field.IgnoreMigration {
return m.DB.Migrator().AlterColumn(value, field.Name) return m.DB.Migrator().AlterColumn(value, field.Name)
} }
@ -448,7 +490,7 @@ func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
} }
for _, c := range rawColumnTypes { for _, c := range rawColumnTypes {
columnTypes = append(columnTypes, c) columnTypes = append(columnTypes, ColumnType{SQLColumnType: c})
} }
return return
@ -457,10 +499,12 @@ func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
return columnTypes, execErr return columnTypes, execErr
} }
// CreateView create view
func (m Migrator) CreateView(name string, option gorm.ViewOption) error { func (m Migrator) CreateView(name string, option gorm.ViewOption) error {
return gorm.ErrNotImplemented return gorm.ErrNotImplemented
} }
// DropView drop view
func (m Migrator) DropView(name string) error { func (m Migrator) DropView(name string) error {
return gorm.ErrNotImplemented return gorm.ErrNotImplemented
} }
@ -487,6 +531,7 @@ func buildConstraint(constraint *schema.Constraint) (sql string, results []inter
return return
} }
// GuessConstraintAndTable guess statement's constraint and it's table based on name
func (m Migrator) GuessConstraintAndTable(stmt *gorm.Statement, name string) (_ *schema.Constraint, _ *schema.Check, table string) { func (m Migrator) GuessConstraintAndTable(stmt *gorm.Statement, name string) (_ *schema.Constraint, _ *schema.Check, table string) {
if stmt.Schema == nil { if stmt.Schema == nil {
return nil, nil, stmt.Table return nil, nil, stmt.Table
@ -531,6 +576,7 @@ func (m Migrator) GuessConstraintAndTable(stmt *gorm.Statement, name string) (_
return nil, nil, stmt.Schema.Table return nil, nil, stmt.Schema.Table
} }
// CreateConstraint create constraint
func (m Migrator) CreateConstraint(value interface{}, name string) error { func (m Migrator) CreateConstraint(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error {
constraint, chk, table := m.GuessConstraintAndTable(stmt, name) constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
@ -554,6 +600,7 @@ func (m Migrator) CreateConstraint(value interface{}, name string) error {
}) })
} }
// DropConstraint drop constraint
func (m Migrator) DropConstraint(value interface{}, name string) error { func (m Migrator) DropConstraint(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error {
constraint, chk, table := m.GuessConstraintAndTable(stmt, name) constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
@ -566,6 +613,7 @@ func (m Migrator) DropConstraint(value interface{}, name string) error {
}) })
} }
// HasConstraint check has constraint or not
func (m Migrator) HasConstraint(value interface{}, name string) bool { func (m Migrator) HasConstraint(value interface{}, name string) bool {
var count int64 var count int64
m.RunWithValue(value, func(stmt *gorm.Statement) error { m.RunWithValue(value, func(stmt *gorm.Statement) error {
@ -586,6 +634,7 @@ func (m Migrator) HasConstraint(value interface{}, name string) bool {
return count > 0 return count > 0
} }
// BuildIndexOptions build index options
func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) { func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) {
for _, opt := range opts { for _, opt := range opts {
str := stmt.Quote(opt.DBName) str := stmt.Quote(opt.DBName)
@ -607,10 +656,12 @@ func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statem
return return
} }
// BuildIndexOptionsInterface build index options interface
type BuildIndexOptionsInterface interface { type BuildIndexOptionsInterface interface {
BuildIndexOptions([]schema.IndexOption, *gorm.Statement) []interface{} BuildIndexOptions([]schema.IndexOption, *gorm.Statement) []interface{}
} }
// CreateIndex create index `name`
func (m Migrator) CreateIndex(value interface{}, name string) error { func (m Migrator) CreateIndex(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if idx := stmt.Schema.LookIndex(name); idx != nil { if idx := stmt.Schema.LookIndex(name); idx != nil {
@ -642,6 +693,7 @@ func (m Migrator) CreateIndex(value interface{}, name string) error {
}) })
} }
// DropIndex drop index `name`
func (m Migrator) DropIndex(value interface{}, name string) error { func (m Migrator) DropIndex(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if idx := stmt.Schema.LookIndex(name); idx != nil { if idx := stmt.Schema.LookIndex(name); idx != nil {
@ -652,6 +704,7 @@ func (m Migrator) DropIndex(value interface{}, name string) error {
}) })
} }
// HasIndex check has index `name` or not
func (m Migrator) HasIndex(value interface{}, name string) bool { func (m Migrator) HasIndex(value interface{}, name string) bool {
var count int64 var count int64
m.RunWithValue(value, func(stmt *gorm.Statement) error { m.RunWithValue(value, func(stmt *gorm.Statement) error {
@ -669,6 +722,7 @@ func (m Migrator) HasIndex(value interface{}, name string) bool {
return count > 0 return count > 0
} }
// RenameIndex rename index from oldName to newName
func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error { func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error {
return m.DB.Exec( return m.DB.Exec(
@ -678,6 +732,7 @@ func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error
}) })
} }
// CurrentDatabase returns current database name
func (m Migrator) CurrentDatabase() (name string) { func (m Migrator) CurrentDatabase() (name string) {
m.DB.Raw("SELECT DATABASE()").Row().Scan(&name) m.DB.Raw("SELECT DATABASE()").Row().Scan(&name)
return return
@ -704,7 +759,8 @@ func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []i
Statement: &gorm.Statement{DB: m.DB, Dest: value}, Statement: &gorm.Statement{DB: m.DB, Dest: value},
} }
beDependedOn := map[*schema.Schema]bool{} beDependedOn := map[*schema.Schema]bool{}
if err := dep.Parse(value); err != nil { // support for special table name
if err := dep.ParseWithSpecialTableName(value, m.DB.Statement.Table); err != nil {
m.DB.Logger.Error(context.Background(), "failed to parse value %#v, got error %v", value, err) m.DB.Logger.Error(context.Background(), "failed to parse value %#v, got error %v", value, err)
} }
if _, ok := parsedSchemas[dep.Statement.Schema]; ok { if _, ok := parsedSchemas[dep.Statement.Schema]; ok {
@ -781,6 +837,7 @@ func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []i
return return
} }
// CurrentTable returns current statement's table expression
func (m Migrator) CurrentTable(stmt *gorm.Statement) interface{} { func (m Migrator) CurrentTable(stmt *gorm.Statement) interface{} {
if stmt.TableExpr != nil { if stmt.TableExpr != nil {
return *stmt.TableExpr return *stmt.TableExpr

View File

@ -115,7 +115,7 @@ func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string, arg
} }
type PreparedStmtTX struct { type PreparedStmtTX struct {
*sql.Tx Tx
PreparedStmtDB *PreparedStmtDB PreparedStmtDB *PreparedStmtDB
} }
@ -151,7 +151,7 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ..
func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) { func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) {
stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query) stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
if err == nil { if err == nil {
rows, err = tx.Tx.Stmt(stmt.Stmt).QueryContext(ctx, args...) rows, err = tx.Tx.StmtContext(ctx, stmt.Stmt).QueryContext(ctx, args...)
if err != nil { if err != nil {
tx.PreparedStmtDB.Mux.Lock() tx.PreparedStmtDB.Mux.Lock()
defer tx.PreparedStmtDB.Mux.Unlock() defer tx.PreparedStmtDB.Mux.Unlock()

147
scan.go
View File

@ -10,6 +10,7 @@ import (
"gorm.io/gorm/schema" "gorm.io/gorm/schema"
) )
// prepareValues prepare values slice
func prepareValues(values []interface{}, db *DB, columnTypes []*sql.ColumnType, columns []string) { func prepareValues(values []interface{}, db *DB, columnTypes []*sql.ColumnType, columns []string) {
if db.Statement.Schema != nil { if db.Statement.Schema != nil {
for idx, name := range columns { for idx, name := range columns {
@ -49,65 +50,56 @@ func scanIntoMap(mapValue map[string]interface{}, values []interface{}, columns
} }
} }
func (db *DB) scanIntoStruct(sch *schema.Schema, rows *sql.Rows, reflectValue reflect.Value, values []interface{}, columns []string, fields []*schema.Field, joinFields [][2]*schema.Field) { func (db *DB) scanIntoStruct(rows Rows, reflectValue reflect.Value, values []interface{}, fields []*schema.Field, joinFields [][2]*schema.Field) {
for idx, column := range columns { for idx, field := range fields {
if sch == nil { if field != nil {
values[idx] = reflectValue.Interface() values[idx] = field.NewValuePool.Get()
} else if field := sch.LookUpField(column); field != nil && field.Readable { } else if len(fields) == 1 {
values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface() if reflectValue.CanAddr() {
} else if names := strings.Split(column, "__"); len(names) > 1 { values[idx] = reflectValue.Addr().Interface()
if rel, ok := sch.Relationships.Relations[names[0]]; ok { } else {
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { values[idx] = reflectValue.Interface()
values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface()
continue
}
} }
values[idx] = &sql.RawBytes{}
} else if len(columns) == 1 {
sch = nil
values[idx] = reflectValue.Interface()
} else {
values[idx] = &sql.RawBytes{}
} }
} }
db.RowsAffected++ db.RowsAffected++
db.AddError(rows.Scan(values...)) db.AddError(rows.Scan(values...))
if sch != nil { for idx, field := range fields {
for idx, column := range columns { if field != nil {
if field := sch.LookUpField(column); field != nil && field.Readable { if len(joinFields) == 0 || joinFields[idx][0] == nil {
field.Set(reflectValue, values[idx]) db.AddError(field.Set(db.Statement.Context, reflectValue, values[idx]))
} else if names := strings.Split(column, "__"); len(names) > 1 { } else {
if rel, ok := sch.Relationships.Relations[names[0]]; ok { relValue := joinFields[idx][0].ReflectValueOf(db.Statement.Context, reflectValue)
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { if relValue.Kind() == reflect.Ptr && relValue.IsNil() {
relValue := rel.Field.ReflectValueOf(reflectValue) if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() {
value := reflect.ValueOf(values[idx]).Elem() continue
if relValue.Kind() == reflect.Ptr && relValue.IsNil() {
if value.IsNil() {
continue
}
relValue.Set(reflect.New(relValue.Type().Elem()))
}
field.Set(relValue, values[idx])
} }
relValue.Set(reflect.New(relValue.Type().Elem()))
} }
db.AddError(joinFields[idx][1].Set(db.Statement.Context, relValue, values[idx]))
} }
// release data to pool
field.NewValuePool.Put(values[idx])
} }
} }
} }
// ScanMode scan data mode
type ScanMode uint8 type ScanMode uint8
// scan modes
const ( const (
ScanInitialized ScanMode = 1 << 0 // 1 ScanInitialized ScanMode = 1 << 0 // 1
ScanUpdate ScanMode = 1 << 1 // 2 ScanUpdate ScanMode = 1 << 1 // 2
ScanOnConflictDoNothing ScanMode = 1 << 2 // 4 ScanOnConflictDoNothing ScanMode = 1 << 2 // 4
) )
func Scan(rows *sql.Rows, db *DB, mode ScanMode) { // Scan scan rows into db statement
func Scan(rows Rows, db *DB, mode ScanMode) {
var ( var (
columns, _ = rows.Columns() columns, _ = rows.Columns()
values = make([]interface{}, len(columns)) values = make([]interface{}, len(columns))
@ -138,7 +130,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
} }
scanIntoMap(mapValue, values, columns) scanIntoMap(mapValue, values, columns)
} }
case *[]map[string]interface{}, []map[string]interface{}: case *[]map[string]interface{}:
columnTypes, _ := rows.ColumnTypes() columnTypes, _ := rows.ColumnTypes()
for initialized || rows.Next() { for initialized || rows.Next() {
prepareValues(values, db, columnTypes, columns) prepareValues(values, db, columnTypes, columns)
@ -149,11 +141,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
mapValue := map[string]interface{}{} mapValue := map[string]interface{}{}
scanIntoMap(mapValue, values, columns) scanIntoMap(mapValue, values, columns)
if values, ok := dest.([]map[string]interface{}); ok { *dest = append(*dest, mapValue)
values = append(values, mapValue)
} else if values, ok := dest.(*[]map[string]interface{}); ok {
*values = append(*values, mapValue)
}
} }
case *int, *int8, *int16, *int32, *int64, case *int, *int8, *int16, *int32, *int64,
*uint, *uint8, *uint16, *uint32, *uint64, *uintptr, *uint, *uint8, *uint16, *uint32, *uint64, *uintptr,
@ -168,10 +156,11 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
} }
default: default:
var ( var (
fields = make([]*schema.Field, len(columns)) fields = make([]*schema.Field, len(columns))
joinFields [][2]*schema.Field selectedColumnsMap = make(map[string]int, len(columns))
sch = db.Statement.Schema joinFields [][2]*schema.Field
reflectValue = db.Statement.ReflectValue sch = db.Statement.Schema
reflectValue = db.Statement.ReflectValue
) )
if reflectValue.Kind() == reflect.Interface { if reflectValue.Kind() == reflect.Interface {
@ -193,35 +182,49 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
sch, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy) sch, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy)
} }
for idx, column := range columns {
if field := sch.LookUpField(column); field != nil && field.Readable {
fields[idx] = field
} else if names := strings.Split(column, "__"); len(names) > 1 {
if rel, ok := sch.Relationships.Relations[names[0]]; ok {
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
fields[idx] = field
if len(joinFields) == 0 {
joinFields = make([][2]*schema.Field, len(columns))
}
joinFields[idx] = [2]*schema.Field{rel.Field, field}
continue
}
}
values[idx] = &sql.RawBytes{}
} else {
values[idx] = &sql.RawBytes{}
}
}
if len(columns) == 1 { if len(columns) == 1 {
// isPluck // Is Pluck
if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); (reflectValueType != sch.ModelType && ok) || // is scanner if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); (reflectValueType != sch.ModelType && ok) || // is scanner
reflectValueType.Kind() != reflect.Struct || // is not struct reflectValueType.Kind() != reflect.Struct || // is not struct
sch.ModelType.ConvertibleTo(schema.TimeReflectType) { // is time sch.ModelType.ConvertibleTo(schema.TimeReflectType) { // is time
sch = nil sch = nil
} }
} }
// Not Pluck
if sch != nil {
for idx, column := range columns {
if field := sch.LookUpField(column); field != nil && field.Readable {
if curIndex, ok := selectedColumnsMap[column]; ok {
for fieldIndex, selectField := range sch.Fields[curIndex+1:] {
if selectField.DBName == column && selectField.Readable {
selectedColumnsMap[column] = curIndex + fieldIndex + 1
fields[idx] = selectField
break
}
}
} else {
fields[idx] = field
selectedColumnsMap[column] = idx
}
} else if names := strings.Split(column, "__"); len(names) > 1 {
if rel, ok := sch.Relationships.Relations[names[0]]; ok {
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
fields[idx] = field
if len(joinFields) == 0 {
joinFields = make([][2]*schema.Field, len(columns))
}
joinFields[idx] = [2]*schema.Field{rel.Field, field}
continue
}
}
values[idx] = &sql.RawBytes{}
} else {
values[idx] = &sql.RawBytes{}
}
}
}
} }
switch reflectValue.Kind() { switch reflectValue.Kind() {
@ -244,7 +247,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
elem = reflectValue.Index(int(db.RowsAffected)) elem = reflectValue.Index(int(db.RowsAffected))
if onConflictDonothing { if onConflictDonothing {
for _, field := range fields { for _, field := range fields {
if _, ok := field.ValueOf(elem); !ok { if _, ok := field.ValueOf(db.Statement.Context, elem); !ok {
db.RowsAffected++ db.RowsAffected++
goto BEGIN goto BEGIN
} }
@ -254,7 +257,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
elem = reflect.New(reflectValueType) elem = reflect.New(reflectValueType)
} }
db.scanIntoStruct(sch, rows, elem, values, columns, fields, joinFields) db.scanIntoStruct(rows, elem, values, fields, joinFields)
if !update { if !update {
if isPtr { if isPtr {
@ -270,7 +273,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
} }
case reflect.Struct, reflect.Ptr: case reflect.Struct, reflect.Ptr:
if initialized || rows.Next() { if initialized || rows.Next() {
db.scanIntoStruct(sch, rows, reflectValue, values, columns, fields, joinFields) db.scanIntoStruct(rows, reflectValue, values, fields, joinFields)
} }
default: default:
db.AddError(rows.Scan(dest)) db.AddError(rows.Scan(dest))

View File

@ -1,6 +1,7 @@
package schema package schema
import ( import (
"context"
"database/sql" "database/sql"
"database/sql/driver" "database/sql/driver"
"fmt" "fmt"
@ -11,15 +12,25 @@ import (
"time" "time"
"github.com/jinzhu/now" "github.com/jinzhu/now"
"gorm.io/gorm/clause"
"gorm.io/gorm/utils" "gorm.io/gorm/utils"
) )
type DataType string // special types' reflect type
var (
TimeReflectType = reflect.TypeOf(time.Time{})
TimePtrReflectType = reflect.TypeOf(&time.Time{})
ByteReflectType = reflect.TypeOf(uint8(0))
)
type TimeType int64 type (
// DataType GORM data type
var TimeReflectType = reflect.TypeOf(time.Time{}) DataType string
// TimeType GORM time type
TimeType int64
)
// GORM time types
const ( const (
UnixTime TimeType = 1 UnixTime TimeType = 1
UnixSecond TimeType = 2 UnixSecond TimeType = 2
@ -27,6 +38,7 @@ const (
UnixNanosecond TimeType = 4 UnixNanosecond TimeType = 4
) )
// GORM fields types
const ( const (
Bool DataType = "bool" Bool DataType = "bool"
Int DataType = "int" Int DataType = "int"
@ -37,6 +49,7 @@ const (
Bytes DataType = "bytes" Bytes DataType = "bytes"
) )
// Field is the representation of model schema's field
type Field struct { type Field struct {
Name string Name string
DBName string DBName string
@ -49,9 +62,9 @@ type Field struct {
Creatable bool Creatable bool
Updatable bool Updatable bool
Readable bool Readable bool
HasDefaultValue bool
AutoCreateTime TimeType AutoCreateTime TimeType
AutoUpdateTime TimeType AutoUpdateTime TimeType
HasDefaultValue bool
DefaultValue string DefaultValue string
DefaultValueInterface interface{} DefaultValueInterface interface{}
NotNull bool NotNull bool
@ -60,6 +73,7 @@ type Field struct {
Size int Size int
Precision int Precision int
Scale int Scale int
IgnoreMigration bool
FieldType reflect.Type FieldType reflect.Type
IndirectFieldType reflect.Type IndirectFieldType reflect.Type
StructField reflect.StructField StructField reflect.StructField
@ -68,27 +82,39 @@ type Field struct {
Schema *Schema Schema *Schema
EmbeddedSchema *Schema EmbeddedSchema *Schema
OwnerSchema *Schema OwnerSchema *Schema
ReflectValueOf func(reflect.Value) reflect.Value ReflectValueOf func(context.Context, reflect.Value) reflect.Value
ValueOf func(reflect.Value) (value interface{}, zero bool) ValueOf func(context.Context, reflect.Value) (value interface{}, zero bool)
Set func(reflect.Value, interface{}) error Set func(context.Context, reflect.Value, interface{}) error
IgnoreMigration bool Serializer SerializerInterface
NewValuePool FieldNewValuePool
} }
// ParseField parses reflect.StructField to Field
func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
var err error var (
err error
tagSetting = ParseTagSetting(fieldStruct.Tag.Get("gorm"), ";")
)
field := &Field{ field := &Field{
Name: fieldStruct.Name, Name: fieldStruct.Name,
DBName: tagSetting["COLUMN"],
BindNames: []string{fieldStruct.Name}, BindNames: []string{fieldStruct.Name},
FieldType: fieldStruct.Type, FieldType: fieldStruct.Type,
IndirectFieldType: fieldStruct.Type, IndirectFieldType: fieldStruct.Type,
StructField: fieldStruct, StructField: fieldStruct,
Tag: fieldStruct.Tag,
TagSettings: tagSetting,
Schema: schema,
Creatable: true, Creatable: true,
Updatable: true, Updatable: true,
Readable: true, Readable: true,
Tag: fieldStruct.Tag, PrimaryKey: utils.CheckTruth(tagSetting["PRIMARYKEY"], tagSetting["PRIMARY_KEY"]),
TagSettings: ParseTagSetting(fieldStruct.Tag.Get("gorm"), ";"), AutoIncrement: utils.CheckTruth(tagSetting["AUTOINCREMENT"]),
Schema: schema, HasDefaultValue: utils.CheckTruth(tagSetting["AUTOINCREMENT"]),
NotNull: utils.CheckTruth(tagSetting["NOT NULL"], tagSetting["NOTNULL"]),
Unique: utils.CheckTruth(tagSetting["UNIQUE"]),
Comment: tagSetting["COMMENT"],
AutoIncrementIncrement: 1, AutoIncrementIncrement: 1,
} }
@ -97,7 +123,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
} }
fieldValue := reflect.New(field.IndirectFieldType) fieldValue := reflect.New(field.IndirectFieldType)
// if field is valuer, used its value or first fields as data type // if field is valuer, used its value or first field as data type
valuer, isValuer := fieldValue.Interface().(driver.Valuer) valuer, isValuer := fieldValue.Interface().(driver.Valuer)
if isValuer { if isValuer {
if _, ok := fieldValue.Interface().(GormDataTypeInterface); !ok { if _, ok := fieldValue.Interface().(GormDataTypeInterface); !ok {
@ -105,31 +131,37 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
fieldValue = reflect.ValueOf(v) fieldValue = reflect.ValueOf(v)
} }
// Use the field struct's first field type as data type, e.g: use `string` for sql.NullString
var getRealFieldValue func(reflect.Value) var getRealFieldValue func(reflect.Value)
getRealFieldValue = func(v reflect.Value) { getRealFieldValue = func(v reflect.Value) {
rv := reflect.Indirect(v) var (
if rv.Kind() == reflect.Struct && !rv.Type().ConvertibleTo(TimeReflectType) { rv = reflect.Indirect(v)
for i := 0; i < rv.Type().NumField(); i++ { rvType = rv.Type()
newFieldType := rv.Type().Field(i).Type )
if rv.Kind() == reflect.Struct && !rvType.ConvertibleTo(TimeReflectType) {
for i := 0; i < rvType.NumField(); i++ {
for key, value := range ParseTagSetting(rvType.Field(i).Tag.Get("gorm"), ";") {
if _, ok := field.TagSettings[key]; !ok {
field.TagSettings[key] = value
}
}
}
for i := 0; i < rvType.NumField(); i++ {
newFieldType := rvType.Field(i).Type
for newFieldType.Kind() == reflect.Ptr { for newFieldType.Kind() == reflect.Ptr {
newFieldType = newFieldType.Elem() newFieldType = newFieldType.Elem()
} }
fieldValue = reflect.New(newFieldType) fieldValue = reflect.New(newFieldType)
if rvType != reflect.Indirect(fieldValue).Type() {
if rv.Type() != reflect.Indirect(fieldValue).Type() {
getRealFieldValue(fieldValue) getRealFieldValue(fieldValue)
} }
if fieldValue.IsValid() { if fieldValue.IsValid() {
return return
} }
for key, value := range ParseTagSetting(field.IndirectFieldType.Field(i).Tag.Get("gorm"), ";") {
if _, ok := field.TagSettings[key]; !ok {
field.TagSettings[key] = value
}
}
} }
} }
} }
@ -138,19 +170,23 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
} }
} }
if dbName, ok := field.TagSettings["COLUMN"]; ok { if v, isSerializer := fieldValue.Interface().(SerializerInterface); isSerializer {
field.DBName = dbName field.DataType = String
} field.Serializer = v
} else {
if val, ok := field.TagSettings["PRIMARYKEY"]; ok && utils.CheckTruth(val) { var serializerName = field.TagSettings["JSON"]
field.PrimaryKey = true if serializerName == "" {
} else if val, ok := field.TagSettings["PRIMARY_KEY"]; ok && utils.CheckTruth(val) { serializerName = field.TagSettings["SERIALIZER"]
field.PrimaryKey = true }
} if serializerName != "" {
if serializer, ok := GetSerializer(serializerName); ok {
if val, ok := field.TagSettings["AUTOINCREMENT"]; ok && utils.CheckTruth(val) { // Set default data type to string for serializer
field.AutoIncrement = true field.DataType = String
field.HasDefaultValue = true field.Serializer = serializer
} else {
schema.err = fmt.Errorf("invalid serializer type %v", serializerName)
}
}
} }
if num, ok := field.TagSettings["AUTOINCREMENTINCREMENT"]; ok { if num, ok := field.TagSettings["AUTOINCREMENTINCREMENT"]; ok {
@ -176,20 +212,6 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
field.Scale, _ = strconv.Atoi(s) field.Scale, _ = strconv.Atoi(s)
} }
if val, ok := field.TagSettings["NOT NULL"]; ok && utils.CheckTruth(val) {
field.NotNull = true
} else if val, ok := field.TagSettings["NOTNULL"]; ok && utils.CheckTruth(val) {
field.NotNull = true
}
if val, ok := field.TagSettings["UNIQUE"]; ok && utils.CheckTruth(val) {
field.Unique = true
}
if val, ok := field.TagSettings["COMMENT"]; ok {
field.Comment = val
}
// default value is function or null or blank (primary keys) // default value is function or null or blank (primary keys)
field.DefaultValue = strings.TrimSpace(field.DefaultValue) field.DefaultValue = strings.TrimSpace(field.DefaultValue)
skipParseDefaultValue := strings.Contains(field.DefaultValue, "(") && skipParseDefaultValue := strings.Contains(field.DefaultValue, "(") &&
@ -225,7 +247,6 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
} }
case reflect.String: case reflect.String:
field.DataType = String field.DataType = String
if field.HasDefaultValue && !skipParseDefaultValue { if field.HasDefaultValue && !skipParseDefaultValue {
field.DefaultValue = strings.Trim(field.DefaultValue, "'") field.DefaultValue = strings.Trim(field.DefaultValue, "'")
field.DefaultValue = strings.Trim(field.DefaultValue, `"`) field.DefaultValue = strings.Trim(field.DefaultValue, `"`)
@ -236,22 +257,25 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
field.DataType = Time field.DataType = Time
} else if fieldValue.Type().ConvertibleTo(TimeReflectType) { } else if fieldValue.Type().ConvertibleTo(TimeReflectType) {
field.DataType = Time field.DataType = Time
} else if fieldValue.Type().ConvertibleTo(reflect.TypeOf(&time.Time{})) { } else if fieldValue.Type().ConvertibleTo(TimePtrReflectType) {
field.DataType = Time field.DataType = Time
} }
if field.HasDefaultValue && !skipParseDefaultValue && field.DataType == Time {
if t, err := now.Parse(field.DefaultValue); err == nil {
field.DefaultValueInterface = t
}
}
case reflect.Array, reflect.Slice: case reflect.Array, reflect.Slice:
if reflect.Indirect(fieldValue).Type().Elem() == reflect.TypeOf(uint8(0)) { if reflect.Indirect(fieldValue).Type().Elem() == ByteReflectType && field.DataType == "" {
field.DataType = Bytes field.DataType = Bytes
} }
} }
field.GORMDataType = field.DataType
if dataTyper, ok := fieldValue.Interface().(GormDataTypeInterface); ok { if dataTyper, ok := fieldValue.Interface().(GormDataTypeInterface); ok {
field.DataType = DataType(dataTyper.GormDataType()) field.DataType = DataType(dataTyper.GormDataType())
} }
if v, ok := field.TagSettings["AUTOCREATETIME"]; ok || (field.Name == "CreatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) { if v, ok := field.TagSettings["AUTOCREATETIME"]; (ok && utils.CheckTruth(v)) || (!ok && field.Name == "CreatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) {
if field.DataType == Time { if field.DataType == Time {
field.AutoCreateTime = UnixTime field.AutoCreateTime = UnixTime
} else if strings.ToUpper(v) == "NANO" { } else if strings.ToUpper(v) == "NANO" {
@ -263,7 +287,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
} }
} }
if v, ok := field.TagSettings["AUTOUPDATETIME"]; ok || (field.Name == "UpdatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) { if v, ok := field.TagSettings["AUTOUPDATETIME"]; (ok && utils.CheckTruth(v)) || (!ok && field.Name == "UpdatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) {
if field.DataType == Time { if field.DataType == Time {
field.AutoUpdateTime = UnixTime field.AutoUpdateTime = UnixTime
} else if strings.ToUpper(v) == "NANO" { } else if strings.ToUpper(v) == "NANO" {
@ -275,6 +299,10 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
} }
} }
if field.GORMDataType == "" {
field.GORMDataType = field.DataType
}
if val, ok := field.TagSettings["TYPE"]; ok { if val, ok := field.TagSettings["TYPE"]; ok {
switch DataType(strings.ToLower(val)) { switch DataType(strings.ToLower(val)) {
case Bool, Int, Uint, Float, String, Time, Bytes: case Bool, Int, Uint, Float, String, Time, Bytes:
@ -284,10 +312,6 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
} }
} }
if field.GORMDataType == "" {
field.GORMDataType = field.DataType
}
if field.Size == 0 { if field.Size == 0 {
switch reflect.Indirect(fieldValue).Kind() { switch reflect.Indirect(fieldValue).Kind() {
case reflect.Int, reflect.Int64, reflect.Uint, reflect.Uint64, reflect.Float64: case reflect.Int, reflect.Int64, reflect.Uint, reflect.Uint64, reflect.Float64:
@ -346,8 +370,9 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
} }
} }
if _, ok := field.TagSettings["EMBEDDED"]; field.GORMDataType != Time && field.GORMDataType != Bytes && // Normal anonymous field or having `EMBEDDED` tag
(ok || (fieldStruct.Anonymous && !isValuer && (field.Creatable || field.Updatable || field.Readable))) { if _, ok := field.TagSettings["EMBEDDED"]; ok || (field.GORMDataType != Time && field.GORMDataType != Bytes && !isValuer &&
fieldStruct.Anonymous && (field.Creatable || field.Updatable || field.Readable)) {
kind := reflect.Indirect(fieldValue).Kind() kind := reflect.Indirect(fieldValue).Kind()
switch kind { switch kind {
case reflect.Struct: case reflect.Struct:
@ -410,31 +435,25 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
// create valuer, setter when parse struct // create valuer, setter when parse struct
func (field *Field) setupValuerAndSetter() { func (field *Field) setupValuerAndSetter() {
// ValueOf // Setup NewValuePool
field.setupNewValuePool()
// ValueOf returns field's value and if it is zero
fieldIndex := field.StructField.Index[0]
switch { switch {
case len(field.StructField.Index) == 1: case len(field.StructField.Index) == 1 && fieldIndex > 0:
field.ValueOf = func(value reflect.Value) (interface{}, bool) { field.ValueOf = func(ctx context.Context, value reflect.Value) (interface{}, bool) {
fieldValue := reflect.Indirect(value).Field(field.StructField.Index[0]) fieldValue := reflect.Indirect(value).Field(fieldIndex)
return fieldValue.Interface(), fieldValue.IsZero()
}
case len(field.StructField.Index) == 2 && field.StructField.Index[0] >= 0:
field.ValueOf = func(value reflect.Value) (interface{}, bool) {
fieldValue := reflect.Indirect(value).Field(field.StructField.Index[0]).Field(field.StructField.Index[1])
return fieldValue.Interface(), fieldValue.IsZero() return fieldValue.Interface(), fieldValue.IsZero()
} }
default: default:
field.ValueOf = func(value reflect.Value) (interface{}, bool) { field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) {
v := reflect.Indirect(value) v = reflect.Indirect(v)
for _, fieldIdx := range field.StructField.Index {
for _, idx := range field.StructField.Index { if fieldIdx >= 0 {
if idx >= 0 { v = v.Field(fieldIdx)
v = v.Field(idx)
} else { } else {
v = v.Field(-idx - 1) v = v.Field(-fieldIdx - 1)
if v.Type().Elem().Kind() != reflect.Struct {
return nil, true
}
if !v.IsNil() { if !v.IsNil() {
v = v.Elem() v = v.Elem()
@ -443,35 +462,52 @@ func (field *Field) setupValuerAndSetter() {
} }
} }
} }
return v.Interface(), v.IsZero()
fv, zero := v.Interface(), v.IsZero()
return fv, zero
} }
} }
// ReflectValueOf if field.Serializer != nil {
switch { oldValuerOf := field.ValueOf
case len(field.StructField.Index) == 1: field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) {
field.ReflectValueOf = func(value reflect.Value) reflect.Value { value, zero := oldValuerOf(ctx, v)
return reflect.Indirect(value).Field(field.StructField.Index[0]) if zero {
return value, zero
}
s, ok := value.(SerializerValuerInterface)
if !ok {
s = field.Serializer
}
return &serializer{
Field: field,
SerializeValuer: s,
Destination: v,
Context: ctx,
fieldValue: value,
}, false
} }
case len(field.StructField.Index) == 2 && field.StructField.Index[0] >= 0 && field.FieldType.Kind() != reflect.Ptr: }
field.ReflectValueOf = func(value reflect.Value) reflect.Value {
return reflect.Indirect(value).Field(field.StructField.Index[0]).Field(field.StructField.Index[1]) // ReflectValueOf returns field's reflect value
switch {
case len(field.StructField.Index) == 1 && fieldIndex > 0:
field.ReflectValueOf = func(ctx context.Context, value reflect.Value) reflect.Value {
return reflect.Indirect(value).Field(fieldIndex)
} }
default: default:
field.ReflectValueOf = func(value reflect.Value) reflect.Value { field.ReflectValueOf = func(ctx context.Context, v reflect.Value) reflect.Value {
v := reflect.Indirect(value) v = reflect.Indirect(v)
for idx, fieldIdx := range field.StructField.Index { for idx, fieldIdx := range field.StructField.Index {
if fieldIdx >= 0 { if fieldIdx >= 0 {
v = v.Field(fieldIdx) v = v.Field(fieldIdx)
} else { } else {
v = v.Field(-fieldIdx - 1) v = v.Field(-fieldIdx - 1)
}
if v.Kind() == reflect.Ptr { if v.IsNil() {
if v.Type().Elem().Kind() == reflect.Struct { v.Set(reflect.New(v.Type().Elem()))
if v.IsNil() {
v.Set(reflect.New(v.Type().Elem()))
}
} }
if idx < len(field.StructField.Index)-1 { if idx < len(field.StructField.Index)-1 {
@ -483,22 +519,25 @@ func (field *Field) setupValuerAndSetter() {
} }
} }
fallbackSetter := func(value reflect.Value, v interface{}, setter func(reflect.Value, interface{}) error) (err error) { fallbackSetter := func(ctx context.Context, value reflect.Value, v interface{}, setter func(context.Context, reflect.Value, interface{}) error) (err error) {
if v == nil { if v == nil {
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem())
} else { } else {
reflectV := reflect.ValueOf(v) reflectV := reflect.ValueOf(v)
// Optimal value type acquisition for v // Optimal value type acquisition for v
reflectValType := reflectV.Type() reflectValType := reflectV.Type()
if reflectValType.AssignableTo(field.FieldType) { if reflectValType.AssignableTo(field.FieldType) {
field.ReflectValueOf(value).Set(reflectV) if reflectV.Kind() == reflect.Ptr && reflectV.Elem().Kind() == reflect.Ptr {
reflectV = reflect.Indirect(reflectV)
}
field.ReflectValueOf(ctx, value).Set(reflectV)
return return
} else if reflectValType.ConvertibleTo(field.FieldType) { } else if reflectValType.ConvertibleTo(field.FieldType) {
field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType)) field.ReflectValueOf(ctx, value).Set(reflectV.Convert(field.FieldType))
return return
} else if field.FieldType.Kind() == reflect.Ptr { } else if field.FieldType.Kind() == reflect.Ptr {
fieldValue := field.ReflectValueOf(value) fieldValue := field.ReflectValueOf(ctx, value)
fieldType := field.FieldType.Elem() fieldType := field.FieldType.Elem()
if reflectValType.AssignableTo(fieldType) { if reflectValType.AssignableTo(fieldType) {
@ -521,16 +560,19 @@ func (field *Field) setupValuerAndSetter() {
if reflectV.Kind() == reflect.Ptr { if reflectV.Kind() == reflect.Ptr {
if reflectV.IsNil() { if reflectV.IsNil() {
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem())
} else if reflectV.Type().Elem().AssignableTo(field.FieldType) {
field.ReflectValueOf(ctx, value).Set(reflectV.Elem())
return
} else { } else {
err = setter(value, reflectV.Elem().Interface()) err = setter(ctx, value, reflectV.Elem().Interface())
} }
} else if valuer, ok := v.(driver.Valuer); ok { } else if valuer, ok := v.(driver.Valuer); ok {
if v, err = valuer.Value(); err == nil { if v, err = valuer.Value(); err == nil {
err = setter(value, v) err = setter(ctx, value, v)
} }
} else { } else if _, ok := v.(clause.Expr); !ok {
return fmt.Errorf("failed to set value %+v to field %s", v, field.Name) return fmt.Errorf("failed to set value %#v to field %s", v, field.Name)
} }
} }
@ -540,191 +582,201 @@ func (field *Field) setupValuerAndSetter() {
// Set // Set
switch field.FieldType.Kind() { switch field.FieldType.Kind() {
case reflect.Bool: case reflect.Bool:
field.Set = func(value reflect.Value, v interface{}) error { field.Set = func(ctx context.Context, value reflect.Value, v interface{}) error {
switch data := v.(type) { switch data := v.(type) {
case **bool:
if data != nil && *data != nil {
field.ReflectValueOf(ctx, value).SetBool(**data)
}
case bool: case bool:
field.ReflectValueOf(value).SetBool(data) field.ReflectValueOf(ctx, value).SetBool(data)
case *bool:
if data != nil {
field.ReflectValueOf(value).SetBool(*data)
} else {
field.ReflectValueOf(value).SetBool(false)
}
case int64: case int64:
if data > 0 { field.ReflectValueOf(ctx, value).SetBool(data > 0)
field.ReflectValueOf(value).SetBool(true)
} else {
field.ReflectValueOf(value).SetBool(false)
}
case string: case string:
b, _ := strconv.ParseBool(data) b, _ := strconv.ParseBool(data)
field.ReflectValueOf(value).SetBool(b) field.ReflectValueOf(ctx, value).SetBool(b)
default: default:
return fallbackSetter(value, v, field.Set) return fallbackSetter(ctx, value, v, field.Set)
} }
return nil return nil
} }
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
field.Set = func(value reflect.Value, v interface{}) (err error) { field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) {
switch data := v.(type) { switch data := v.(type) {
case **int64:
if data != nil && *data != nil {
field.ReflectValueOf(ctx, value).SetInt(**data)
}
case int64: case int64:
field.ReflectValueOf(value).SetInt(data) field.ReflectValueOf(ctx, value).SetInt(data)
case int: case int:
field.ReflectValueOf(value).SetInt(int64(data)) field.ReflectValueOf(ctx, value).SetInt(int64(data))
case int8: case int8:
field.ReflectValueOf(value).SetInt(int64(data)) field.ReflectValueOf(ctx, value).SetInt(int64(data))
case int16: case int16:
field.ReflectValueOf(value).SetInt(int64(data)) field.ReflectValueOf(ctx, value).SetInt(int64(data))
case int32: case int32:
field.ReflectValueOf(value).SetInt(int64(data)) field.ReflectValueOf(ctx, value).SetInt(int64(data))
case uint: case uint:
field.ReflectValueOf(value).SetInt(int64(data)) field.ReflectValueOf(ctx, value).SetInt(int64(data))
case uint8: case uint8:
field.ReflectValueOf(value).SetInt(int64(data)) field.ReflectValueOf(ctx, value).SetInt(int64(data))
case uint16: case uint16:
field.ReflectValueOf(value).SetInt(int64(data)) field.ReflectValueOf(ctx, value).SetInt(int64(data))
case uint32: case uint32:
field.ReflectValueOf(value).SetInt(int64(data)) field.ReflectValueOf(ctx, value).SetInt(int64(data))
case uint64: case uint64:
field.ReflectValueOf(value).SetInt(int64(data)) field.ReflectValueOf(ctx, value).SetInt(int64(data))
case float32: case float32:
field.ReflectValueOf(value).SetInt(int64(data)) field.ReflectValueOf(ctx, value).SetInt(int64(data))
case float64: case float64:
field.ReflectValueOf(value).SetInt(int64(data)) field.ReflectValueOf(ctx, value).SetInt(int64(data))
case []byte: case []byte:
return field.Set(value, string(data)) return field.Set(ctx, value, string(data))
case string: case string:
if i, err := strconv.ParseInt(data, 0, 64); err == nil { if i, err := strconv.ParseInt(data, 0, 64); err == nil {
field.ReflectValueOf(value).SetInt(i) field.ReflectValueOf(ctx, value).SetInt(i)
} else { } else {
return err return err
} }
case time.Time: case time.Time:
if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond { if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond {
field.ReflectValueOf(value).SetInt(data.UnixNano()) field.ReflectValueOf(ctx, value).SetInt(data.UnixNano())
} else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond { } else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond {
field.ReflectValueOf(value).SetInt(data.UnixNano() / 1e6) field.ReflectValueOf(ctx, value).SetInt(data.UnixNano() / 1e6)
} else { } else {
field.ReflectValueOf(value).SetInt(data.Unix()) field.ReflectValueOf(ctx, value).SetInt(data.Unix())
} }
case *time.Time: case *time.Time:
if data != nil { if data != nil {
if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond { if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond {
field.ReflectValueOf(value).SetInt(data.UnixNano()) field.ReflectValueOf(ctx, value).SetInt(data.UnixNano())
} else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond { } else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond {
field.ReflectValueOf(value).SetInt(data.UnixNano() / 1e6) field.ReflectValueOf(ctx, value).SetInt(data.UnixNano() / 1e6)
} else { } else {
field.ReflectValueOf(value).SetInt(data.Unix()) field.ReflectValueOf(ctx, value).SetInt(data.Unix())
} }
} else { } else {
field.ReflectValueOf(value).SetInt(0) field.ReflectValueOf(ctx, value).SetInt(0)
} }
default: default:
return fallbackSetter(value, v, field.Set) return fallbackSetter(ctx, value, v, field.Set)
} }
return err return err
} }
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
field.Set = func(value reflect.Value, v interface{}) (err error) { field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) {
switch data := v.(type) { switch data := v.(type) {
case **uint64:
if data != nil && *data != nil {
field.ReflectValueOf(ctx, value).SetUint(**data)
}
case uint64: case uint64:
field.ReflectValueOf(value).SetUint(data) field.ReflectValueOf(ctx, value).SetUint(data)
case uint: case uint:
field.ReflectValueOf(value).SetUint(uint64(data)) field.ReflectValueOf(ctx, value).SetUint(uint64(data))
case uint8: case uint8:
field.ReflectValueOf(value).SetUint(uint64(data)) field.ReflectValueOf(ctx, value).SetUint(uint64(data))
case uint16: case uint16:
field.ReflectValueOf(value).SetUint(uint64(data)) field.ReflectValueOf(ctx, value).SetUint(uint64(data))
case uint32: case uint32:
field.ReflectValueOf(value).SetUint(uint64(data)) field.ReflectValueOf(ctx, value).SetUint(uint64(data))
case int64: case int64:
field.ReflectValueOf(value).SetUint(uint64(data)) field.ReflectValueOf(ctx, value).SetUint(uint64(data))
case int: case int:
field.ReflectValueOf(value).SetUint(uint64(data)) field.ReflectValueOf(ctx, value).SetUint(uint64(data))
case int8: case int8:
field.ReflectValueOf(value).SetUint(uint64(data)) field.ReflectValueOf(ctx, value).SetUint(uint64(data))
case int16: case int16:
field.ReflectValueOf(value).SetUint(uint64(data)) field.ReflectValueOf(ctx, value).SetUint(uint64(data))
case int32: case int32:
field.ReflectValueOf(value).SetUint(uint64(data)) field.ReflectValueOf(ctx, value).SetUint(uint64(data))
case float32: case float32:
field.ReflectValueOf(value).SetUint(uint64(data)) field.ReflectValueOf(ctx, value).SetUint(uint64(data))
case float64: case float64:
field.ReflectValueOf(value).SetUint(uint64(data)) field.ReflectValueOf(ctx, value).SetUint(uint64(data))
case []byte: case []byte:
return field.Set(value, string(data)) return field.Set(ctx, value, string(data))
case time.Time: case time.Time:
if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond { if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond {
field.ReflectValueOf(value).SetUint(uint64(data.UnixNano())) field.ReflectValueOf(ctx, value).SetUint(uint64(data.UnixNano()))
} else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond { } else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond {
field.ReflectValueOf(value).SetUint(uint64(data.UnixNano() / 1e6)) field.ReflectValueOf(ctx, value).SetUint(uint64(data.UnixNano() / 1e6))
} else { } else {
field.ReflectValueOf(value).SetUint(uint64(data.Unix())) field.ReflectValueOf(ctx, value).SetUint(uint64(data.Unix()))
} }
case string: case string:
if i, err := strconv.ParseUint(data, 0, 64); err == nil { if i, err := strconv.ParseUint(data, 0, 64); err == nil {
field.ReflectValueOf(value).SetUint(i) field.ReflectValueOf(ctx, value).SetUint(i)
} else { } else {
return err return err
} }
default: default:
return fallbackSetter(value, v, field.Set) return fallbackSetter(ctx, value, v, field.Set)
} }
return err return err
} }
case reflect.Float32, reflect.Float64: case reflect.Float32, reflect.Float64:
field.Set = func(value reflect.Value, v interface{}) (err error) { field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) {
switch data := v.(type) { switch data := v.(type) {
case **float64:
if data != nil && *data != nil {
field.ReflectValueOf(ctx, value).SetFloat(**data)
}
case float64: case float64:
field.ReflectValueOf(value).SetFloat(data) field.ReflectValueOf(ctx, value).SetFloat(data)
case float32: case float32:
field.ReflectValueOf(value).SetFloat(float64(data)) field.ReflectValueOf(ctx, value).SetFloat(float64(data))
case int64: case int64:
field.ReflectValueOf(value).SetFloat(float64(data)) field.ReflectValueOf(ctx, value).SetFloat(float64(data))
case int: case int:
field.ReflectValueOf(value).SetFloat(float64(data)) field.ReflectValueOf(ctx, value).SetFloat(float64(data))
case int8: case int8:
field.ReflectValueOf(value).SetFloat(float64(data)) field.ReflectValueOf(ctx, value).SetFloat(float64(data))
case int16: case int16:
field.ReflectValueOf(value).SetFloat(float64(data)) field.ReflectValueOf(ctx, value).SetFloat(float64(data))
case int32: case int32:
field.ReflectValueOf(value).SetFloat(float64(data)) field.ReflectValueOf(ctx, value).SetFloat(float64(data))
case uint: case uint:
field.ReflectValueOf(value).SetFloat(float64(data)) field.ReflectValueOf(ctx, value).SetFloat(float64(data))
case uint8: case uint8:
field.ReflectValueOf(value).SetFloat(float64(data)) field.ReflectValueOf(ctx, value).SetFloat(float64(data))
case uint16: case uint16:
field.ReflectValueOf(value).SetFloat(float64(data)) field.ReflectValueOf(ctx, value).SetFloat(float64(data))
case uint32: case uint32:
field.ReflectValueOf(value).SetFloat(float64(data)) field.ReflectValueOf(ctx, value).SetFloat(float64(data))
case uint64: case uint64:
field.ReflectValueOf(value).SetFloat(float64(data)) field.ReflectValueOf(ctx, value).SetFloat(float64(data))
case []byte: case []byte:
return field.Set(value, string(data)) return field.Set(ctx, value, string(data))
case string: case string:
if i, err := strconv.ParseFloat(data, 64); err == nil { if i, err := strconv.ParseFloat(data, 64); err == nil {
field.ReflectValueOf(value).SetFloat(i) field.ReflectValueOf(ctx, value).SetFloat(i)
} else { } else {
return err return err
} }
default: default:
return fallbackSetter(value, v, field.Set) return fallbackSetter(ctx, value, v, field.Set)
} }
return err return err
} }
case reflect.String: case reflect.String:
field.Set = func(value reflect.Value, v interface{}) (err error) { field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) {
switch data := v.(type) { switch data := v.(type) {
case **string:
if data != nil && *data != nil {
field.ReflectValueOf(ctx, value).SetString(**data)
}
case string: case string:
field.ReflectValueOf(value).SetString(data) field.ReflectValueOf(ctx, value).SetString(data)
case []byte: case []byte:
field.ReflectValueOf(value).SetString(string(data)) field.ReflectValueOf(ctx, value).SetString(string(data))
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
field.ReflectValueOf(value).SetString(utils.ToString(data)) field.ReflectValueOf(ctx, value).SetString(utils.ToString(data))
case float64, float32: case float64, float32:
field.ReflectValueOf(value).SetString(fmt.Sprintf("%."+strconv.Itoa(field.Precision)+"f", data)) field.ReflectValueOf(ctx, value).SetString(fmt.Sprintf("%."+strconv.Itoa(field.Precision)+"f", data))
default: default:
return fallbackSetter(value, v, field.Set) return fallbackSetter(ctx, value, v, field.Set)
} }
return err return err
} }
@ -732,41 +784,49 @@ func (field *Field) setupValuerAndSetter() {
fieldValue := reflect.New(field.FieldType) fieldValue := reflect.New(field.FieldType)
switch fieldValue.Elem().Interface().(type) { switch fieldValue.Elem().Interface().(type) {
case time.Time: case time.Time:
field.Set = func(value reflect.Value, v interface{}) error { field.Set = func(ctx context.Context, value reflect.Value, v interface{}) error {
switch data := v.(type) { switch data := v.(type) {
case **time.Time:
if data != nil && *data != nil {
field.Set(ctx, value, *data)
}
case time.Time: case time.Time:
field.ReflectValueOf(value).Set(reflect.ValueOf(v)) field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(v))
case *time.Time: case *time.Time:
if data != nil { if data != nil {
field.ReflectValueOf(value).Set(reflect.ValueOf(data).Elem()) field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(data).Elem())
} else { } else {
field.ReflectValueOf(value).Set(reflect.ValueOf(time.Time{})) field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(time.Time{}))
} }
case string: case string:
if t, err := now.Parse(data); err == nil { if t, err := now.Parse(data); err == nil {
field.ReflectValueOf(value).Set(reflect.ValueOf(t)) field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(t))
} else { } else {
return fmt.Errorf("failed to set string %v to time.Time field %s, failed to parse it as time, got error %v", v, field.Name, err) return fmt.Errorf("failed to set string %v to time.Time field %s, failed to parse it as time, got error %v", v, field.Name, err)
} }
default: default:
return fallbackSetter(value, v, field.Set) return fallbackSetter(ctx, value, v, field.Set)
} }
return nil return nil
} }
case *time.Time: case *time.Time:
field.Set = func(value reflect.Value, v interface{}) error { field.Set = func(ctx context.Context, value reflect.Value, v interface{}) error {
switch data := v.(type) { switch data := v.(type) {
case **time.Time:
if data != nil {
field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(*data))
}
case time.Time: case time.Time:
fieldValue := field.ReflectValueOf(value) fieldValue := field.ReflectValueOf(ctx, value)
if fieldValue.IsNil() { if fieldValue.IsNil() {
fieldValue.Set(reflect.New(field.FieldType.Elem())) fieldValue.Set(reflect.New(field.FieldType.Elem()))
} }
fieldValue.Elem().Set(reflect.ValueOf(v)) fieldValue.Elem().Set(reflect.ValueOf(v))
case *time.Time: case *time.Time:
field.ReflectValueOf(value).Set(reflect.ValueOf(v)) field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(v))
case string: case string:
if t, err := now.Parse(data); err == nil { if t, err := now.Parse(data); err == nil {
fieldValue := field.ReflectValueOf(value) fieldValue := field.ReflectValueOf(ctx, value)
if fieldValue.IsNil() { if fieldValue.IsNil() {
if v == "" { if v == "" {
return nil return nil
@ -778,27 +838,27 @@ func (field *Field) setupValuerAndSetter() {
return fmt.Errorf("failed to set string %v to time.Time field %s, failed to parse it as time, got error %v", v, field.Name, err) return fmt.Errorf("failed to set string %v to time.Time field %s, failed to parse it as time, got error %v", v, field.Name, err)
} }
default: default:
return fallbackSetter(value, v, field.Set) return fallbackSetter(ctx, value, v, field.Set)
} }
return nil return nil
} }
default: default:
if _, ok := fieldValue.Elem().Interface().(sql.Scanner); ok { if _, ok := fieldValue.Elem().Interface().(sql.Scanner); ok {
// pointer scanner // pointer scanner
field.Set = func(value reflect.Value, v interface{}) (err error) { field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) {
reflectV := reflect.ValueOf(v) reflectV := reflect.ValueOf(v)
if !reflectV.IsValid() { if !reflectV.IsValid() {
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem())
} else if reflectV.Type().AssignableTo(field.FieldType) { } else if reflectV.Type().AssignableTo(field.FieldType) {
field.ReflectValueOf(value).Set(reflectV) field.ReflectValueOf(ctx, value).Set(reflectV)
} else if reflectV.Kind() == reflect.Ptr { } else if reflectV.Kind() == reflect.Ptr {
if reflectV.IsNil() || !reflectV.IsValid() { if reflectV.IsNil() || !reflectV.IsValid() {
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem())
} else { } else {
return field.Set(value, reflectV.Elem().Interface()) return field.Set(ctx, value, reflectV.Elem().Interface())
} }
} else { } else {
fieldValue := field.ReflectValueOf(value) fieldValue := field.ReflectValueOf(ctx, value)
if fieldValue.IsNil() { if fieldValue.IsNil() {
fieldValue.Set(reflect.New(field.FieldType.Elem())) fieldValue.Set(reflect.New(field.FieldType.Elem()))
} }
@ -813,32 +873,80 @@ func (field *Field) setupValuerAndSetter() {
} }
} else if _, ok := fieldValue.Interface().(sql.Scanner); ok { } else if _, ok := fieldValue.Interface().(sql.Scanner); ok {
// struct scanner // struct scanner
field.Set = func(value reflect.Value, v interface{}) (err error) { field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) {
reflectV := reflect.ValueOf(v) reflectV := reflect.ValueOf(v)
if !reflectV.IsValid() { if !reflectV.IsValid() {
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem())
} else if reflectV.Type().AssignableTo(field.FieldType) { } else if reflectV.Type().AssignableTo(field.FieldType) {
field.ReflectValueOf(value).Set(reflectV) field.ReflectValueOf(ctx, value).Set(reflectV)
} else if reflectV.Kind() == reflect.Ptr { } else if reflectV.Kind() == reflect.Ptr {
if reflectV.IsNil() || !reflectV.IsValid() { if reflectV.IsNil() || !reflectV.IsValid() {
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem())
} else { } else {
return field.Set(value, reflectV.Elem().Interface()) return field.Set(ctx, value, reflectV.Elem().Interface())
} }
} else { } else {
if valuer, ok := v.(driver.Valuer); ok { if valuer, ok := v.(driver.Valuer); ok {
v, _ = valuer.Value() v, _ = valuer.Value()
} }
err = field.ReflectValueOf(value).Addr().Interface().(sql.Scanner).Scan(v) err = field.ReflectValueOf(ctx, value).Addr().Interface().(sql.Scanner).Scan(v)
} }
return return
} }
} else { } else {
field.Set = func(value reflect.Value, v interface{}) (err error) { field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) {
return fallbackSetter(value, v, field.Set) return fallbackSetter(ctx, value, v, field.Set)
} }
} }
} }
} }
if field.Serializer != nil {
var (
oldFieldSetter = field.Set
sameElemType bool
sameType = field.FieldType == reflect.ValueOf(field.Serializer).Type()
)
if reflect.ValueOf(field.Serializer).Kind() == reflect.Ptr {
sameElemType = field.FieldType == reflect.ValueOf(field.Serializer).Type().Elem()
}
field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) {
if s, ok := v.(*serializer); ok {
if s.fieldValue != nil {
err = oldFieldSetter(ctx, value, s.fieldValue)
} else if err = s.Serializer.Scan(ctx, field, value, s.value); err == nil {
if sameElemType {
field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(s.Serializer).Elem())
s.Serializer = reflect.New(reflect.Indirect(reflect.ValueOf(field.Serializer)).Type()).Interface().(SerializerInterface)
} else if sameType {
field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(s.Serializer))
s.Serializer = reflect.New(reflect.Indirect(reflect.ValueOf(field.Serializer)).Type()).Interface().(SerializerInterface)
}
}
} else {
err = oldFieldSetter(ctx, value, v)
}
return
}
}
}
func (field *Field) setupNewValuePool() {
if field.Serializer != nil {
field.NewValuePool = &sync.Pool{
New: func() interface{} {
return &serializer{
Field: field,
Serializer: reflect.New(reflect.Indirect(reflect.ValueOf(field.Serializer)).Type()).Interface().(SerializerInterface),
}
},
}
}
if field.NewValuePool == nil {
field.NewValuePool = poolInitializer(reflect.PtrTo(field.IndirectFieldType))
}
} }

View File

@ -1,6 +1,7 @@
package schema_test package schema_test
import ( import (
"context"
"database/sql" "database/sql"
"reflect" "reflect"
"sync" "sync"
@ -57,7 +58,7 @@ func TestFieldValuerAndSetter(t *testing.T) {
} }
for k, v := range newValues { for k, v := range newValues {
if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil { if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil {
t.Errorf("no error should happen when assign value to field %v, but got %v", k, err) t.Errorf("no error should happen when assign value to field %v, but got %v", k, err)
} }
} }
@ -80,7 +81,7 @@ func TestFieldValuerAndSetter(t *testing.T) {
} }
for k, v := range newValues2 { for k, v := range newValues2 {
if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil { if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil {
t.Errorf("no error should happen when assign value to field %v, but got %v", k, err) t.Errorf("no error should happen when assign value to field %v, but got %v", k, err)
} }
} }
@ -132,7 +133,7 @@ func TestPointerFieldValuerAndSetter(t *testing.T) {
} }
for k, v := range newValues { for k, v := range newValues {
if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil { if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil {
t.Errorf("no error should happen when assign value to field %v, but got %v", k, err) t.Errorf("no error should happen when assign value to field %v, but got %v", k, err)
} }
} }
@ -151,7 +152,7 @@ func TestPointerFieldValuerAndSetter(t *testing.T) {
} }
for k, v := range newValues2 { for k, v := range newValues2 {
if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil { if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil {
t.Errorf("no error should happen when assign value to field %v, but got %v", k, err) t.Errorf("no error should happen when assign value to field %v, but got %v", k, err)
} }
} }
@ -202,7 +203,7 @@ func TestAdvancedDataTypeValuerAndSetter(t *testing.T) {
} }
for k, v := range newValues { for k, v := range newValues {
if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil { if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil {
t.Errorf("no error should happen when assign value to field %v, but got %v", k, err) t.Errorf("no error should happen when assign value to field %v, but got %v", k, err)
} }
} }
@ -219,7 +220,7 @@ func TestAdvancedDataTypeValuerAndSetter(t *testing.T) {
} }
for k, v := range newValues2 { for k, v := range newValues2 {
if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil { if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil {
t.Errorf("no error should happen when assign value to field %v, but got %v", k, err) t.Errorf("no error should happen when assign value to field %v, but got %v", k, err)
} }
} }

View File

@ -1,6 +1,7 @@
package schema package schema
import ( import (
"fmt"
"sort" "sort"
"strconv" "strconv"
"strings" "strings"
@ -31,7 +32,12 @@ func (schema *Schema) ParseIndexes() map[string]Index {
for _, field := range schema.Fields { for _, field := range schema.Fields {
if field.TagSettings["INDEX"] != "" || field.TagSettings["UNIQUEINDEX"] != "" { if field.TagSettings["INDEX"] != "" || field.TagSettings["UNIQUEINDEX"] != "" {
for _, index := range parseFieldIndexes(field) { fieldIndexes, err := parseFieldIndexes(field)
if err != nil {
schema.err = err
break
}
for _, index := range fieldIndexes {
idx := indexes[index.Name] idx := indexes[index.Name]
idx.Name = index.Name idx.Name = index.Name
if idx.Class == "" { if idx.Class == "" {
@ -82,18 +88,19 @@ func (schema *Schema) LookIndex(name string) *Index {
return nil return nil
} }
func parseFieldIndexes(field *Field) (indexes []Index) { func parseFieldIndexes(field *Field) (indexes []Index, err error) {
for _, value := range strings.Split(field.Tag.Get("gorm"), ";") { for _, value := range strings.Split(field.Tag.Get("gorm"), ";") {
if value != "" { if value != "" {
v := strings.Split(value, ":") v := strings.Split(value, ":")
k := strings.TrimSpace(strings.ToUpper(v[0])) k := strings.TrimSpace(strings.ToUpper(v[0]))
if k == "INDEX" || k == "UNIQUEINDEX" { if k == "INDEX" || k == "UNIQUEINDEX" {
var ( var (
name string name string
tag = strings.Join(v[1:], ":") tag = strings.Join(v[1:], ":")
idx = strings.Index(tag, ",") idx = strings.Index(tag, ",")
settings = ParseTagSetting(tag, ",") tagSetting = strings.Join(strings.Split(tag, ",")[1:], ",")
length, _ = strconv.Atoi(settings["LENGTH"]) settings = ParseTagSetting(tagSetting, ",")
length, _ = strconv.Atoi(settings["LENGTH"])
) )
if idx == -1 { if idx == -1 {
@ -105,7 +112,20 @@ func parseFieldIndexes(field *Field) (indexes []Index) {
} }
if name == "" { if name == "" {
name = field.Schema.namer.IndexName(field.Schema.Table, field.Name) subName := field.Name
const key = "COMPOSITE"
if composite, found := settings[key]; found {
if len(composite) == 0 || composite == key {
err = fmt.Errorf(
"The composite tag of %s.%s cannot be empty",
field.Schema.Name,
field.Name)
return
}
subName = composite
}
name = field.Schema.namer.IndexName(
field.Schema.Table, subName)
} }
if (k == "UNIQUEINDEX") || settings["UNIQUE"] != "" { if (k == "UNIQUEINDEX") || settings["UNIQUE"] != "" {
@ -137,5 +157,6 @@ func parseFieldIndexes(field *Field) (indexes []Index) {
} }
} }
err = nil
return return
} }

View File

@ -18,6 +18,37 @@ type UserIndex struct {
Age int64 `gorm:"index:profile,expression:ABS(age),option:WITH PARSER parser_name"` Age int64 `gorm:"index:profile,expression:ABS(age),option:WITH PARSER parser_name"`
OID int64 `gorm:"index:idx_id;index:idx_oid,unique"` OID int64 `gorm:"index:idx_id;index:idx_oid,unique"`
MemberNumber string `gorm:"index:idx_id,priority:1"` MemberNumber string `gorm:"index:idx_id,priority:1"`
Name7 string `gorm:"index:type"`
// Composite Index: Flattened structure.
Data0A string `gorm:"index:,composite:comp_id0"`
Data0B string `gorm:"index:,composite:comp_id0"`
// Composite Index: Nested structure.
Data1A string `gorm:"index:,composite:comp_id1"`
CompIdxLevel1C
// Composite Index: Unique and priority.
Data2A string `gorm:"index:,unique,composite:comp_id2,priority:2"`
CompIdxLevel2C
}
type CompIdxLevel1C struct {
CompIdxLevel1B
Data1C string `gorm:"index:,composite:comp_id1"`
}
type CompIdxLevel1B struct {
Data1B string `gorm:"index:,composite:comp_id1"`
}
type CompIdxLevel2C struct {
CompIdxLevel2B
Data2C string `gorm:"index:,unique,composite:comp_id2,priority:1"`
}
type CompIdxLevel2B struct {
Data2B string `gorm:"index:,unique,composite:comp_id2,priority:3"`
} }
func TestParseIndex(t *testing.T) { func TestParseIndex(t *testing.T) {
@ -78,6 +109,41 @@ func TestParseIndex(t *testing.T) {
Class: "UNIQUE", Class: "UNIQUE",
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "OID"}}}, Fields: []schema.IndexOption{{Field: &schema.Field{Name: "OID"}}},
}, },
"type": {
Name: "type",
Type: "",
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name7"}}},
},
"idx_user_indices_comp_id0": {
Name: "idx_user_indices_comp_id0",
Type: "",
Fields: []schema.IndexOption{{
Field: &schema.Field{Name: "Data0A"},
}, {
Field: &schema.Field{Name: "Data0B"},
}},
},
"idx_user_indices_comp_id1": {
Name: "idx_user_indices_comp_id1",
Fields: []schema.IndexOption{{
Field: &schema.Field{Name: "Data1A"},
}, {
Field: &schema.Field{Name: "Data1B"},
}, {
Field: &schema.Field{Name: "Data1C"},
}},
},
"idx_user_indices_comp_id2": {
Name: "idx_user_indices_comp_id2",
Class: "UNIQUE",
Fields: []schema.IndexOption{{
Field: &schema.Field{Name: "Data2C"},
}, {
Field: &schema.Field{Name: "Data2A"},
}, {
Field: &schema.Field{Name: "Data2B"},
}},
},
} }
indices := user.ParseIndexes() indices := user.ParseIndexes()

View File

@ -4,22 +4,33 @@ import (
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
) )
// GormDataTypeInterface gorm data type interface
type GormDataTypeInterface interface { type GormDataTypeInterface interface {
GormDataType() string GormDataType() string
} }
// FieldNewValuePool field new scan value pool
type FieldNewValuePool interface {
Get() interface{}
Put(interface{})
}
// CreateClausesInterface create clauses interface
type CreateClausesInterface interface { type CreateClausesInterface interface {
CreateClauses(*Field) []clause.Interface CreateClauses(*Field) []clause.Interface
} }
// QueryClausesInterface query clauses interface
type QueryClausesInterface interface { type QueryClausesInterface interface {
QueryClauses(*Field) []clause.Interface QueryClauses(*Field) []clause.Interface
} }
// UpdateClausesInterface update clauses interface
type UpdateClausesInterface interface { type UpdateClausesInterface interface {
UpdateClauses(*Field) []clause.Interface UpdateClauses(*Field) []clause.Interface
} }
// DeleteClausesInterface delete clauses interface
type DeleteClausesInterface interface { type DeleteClausesInterface interface {
DeleteClauses(*Field) []clause.Interface DeleteClauses(*Field) []clause.Interface
} }

View File

@ -3,7 +3,6 @@ package schema
import ( import (
"crypto/sha1" "crypto/sha1"
"encoding/hex" "encoding/hex"
"fmt"
"regexp" "regexp"
"strings" "strings"
"unicode/utf8" "unicode/utf8"
@ -86,16 +85,16 @@ func (ns NamingStrategy) IndexName(table, column string) string {
} }
func (ns NamingStrategy) formatName(prefix, table, name string) string { func (ns NamingStrategy) formatName(prefix, table, name string) string {
formattedName := strings.Replace(strings.Join([]string{ formattedName := strings.ReplaceAll(strings.Join([]string{
prefix, table, name, prefix, table, name,
}, "_"), ".", "_", -1) }, "_"), ".", "_")
if utf8.RuneCountInString(formattedName) > 64 { if utf8.RuneCountInString(formattedName) > 64 {
h := sha1.New() h := sha1.New()
h.Write([]byte(formattedName)) h.Write([]byte(formattedName))
bs := h.Sum(nil) bs := h.Sum(nil)
formattedName = fmt.Sprintf("%v%v%v", prefix, table, name)[0:56] + hex.EncodeToString(bs)[:8] formattedName = formattedName[0:56] + hex.EncodeToString(bs)[:8]
} }
return formattedName return formattedName
} }
@ -120,7 +119,13 @@ func (ns NamingStrategy) toDBName(name string) string {
} }
if ns.NameReplacer != nil { if ns.NameReplacer != nil {
name = ns.NameReplacer.Replace(name) tmpName := ns.NameReplacer.Replace(name)
if tmpName == "" {
return name
}
name = tmpName
} }
if ns.NoLowerCase { if ns.NoLowerCase {
@ -168,7 +173,7 @@ func (ns NamingStrategy) toDBName(name string) string {
} }
func (ns NamingStrategy) toSchemaName(name string) string { func (ns NamingStrategy) toSchemaName(name string) string {
result := strings.Replace(strings.Title(strings.Replace(name, "_", " ", -1)), " ", "", -1) result := strings.ReplaceAll(strings.Title(strings.ReplaceAll(name, "_", " ")), " ", "")
for _, initialism := range commonInitialisms { for _, initialism := range commonInitialisms {
result = regexp.MustCompile(strings.Title(strings.ToLower(initialism))+"([A-Z]|$|_)").ReplaceAllString(result, initialism+"$1") result = regexp.MustCompile(strings.Title(strings.ToLower(initialism))+"([A-Z]|$|_)").ReplaceAllString(result, initialism+"$1")
} }

View File

@ -193,7 +193,18 @@ func TestFormatNameWithStringLongerThan64Characters(t *testing.T) {
ns := NamingStrategy{} ns := NamingStrategy{}
formattedName := ns.formatName("prefix", "table", "thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVeryLongString") formattedName := ns.formatName("prefix", "table", "thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVeryLongString")
if formattedName != "prefixtablethisIsAVeryVeryVeryVeryVeryVeryVeryVeryVeryLo180f2c67" { if formattedName != "prefix_table_thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVery180f2c67" {
t.Errorf("invalid formatted name generated, got %v", formattedName) t.Errorf("invalid formatted name generated, got %v", formattedName)
} }
} }
func TestReplaceEmptyTableName(t *testing.T) {
ns := NamingStrategy{
SingularTable: true,
NameReplacer: strings.NewReplacer("Model", ""),
}
tableName := ns.TableName("Model")
if tableName != "Model" {
t.Errorf("invalid table name generated, got %v", tableName)
}
}

19
schema/pool.go Normal file
View File

@ -0,0 +1,19 @@
package schema
import (
"reflect"
"sync"
)
// sync pools
var (
normalPool sync.Map
poolInitializer = func(reflectType reflect.Type) FieldNewValuePool {
v, _ := normalPool.LoadOrStore(reflectType, &sync.Pool{
New: func() interface{} {
return reflect.New(reflectType).Interface()
},
})
return v.(FieldNewValuePool)
}
)

View File

@ -1,6 +1,7 @@
package schema package schema
import ( import (
"context"
"fmt" "fmt"
"reflect" "reflect"
"strings" "strings"
@ -234,7 +235,8 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
Name: joinFieldName, Name: joinFieldName,
PkgPath: ownField.StructField.PkgPath, PkgPath: ownField.StructField.PkgPath,
Type: ownField.StructField.Type, Type: ownField.StructField.Type,
Tag: removeSettingFromTag(ownField.StructField.Tag, "column", "autoincrement", "index", "unique", "uniqueindex"), Tag: removeSettingFromTag(appendSettingFromTag(ownField.StructField.Tag, "primaryKey"),
"column", "autoincrement", "index", "unique", "uniqueindex"),
}) })
} }
@ -257,7 +259,8 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
Name: joinFieldName, Name: joinFieldName,
PkgPath: relField.StructField.PkgPath, PkgPath: relField.StructField.PkgPath,
Type: relField.StructField.Type, Type: relField.StructField.Type,
Tag: removeSettingFromTag(relField.StructField.Tag, "column", "autoincrement", "index", "unique", "uniqueindex"), Tag: removeSettingFromTag(appendSettingFromTag(relField.StructField.Tag, "primaryKey"),
"column", "autoincrement", "index", "unique", "uniqueindex"),
}) })
} }
@ -415,6 +418,10 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu
} }
} else { } else {
var primaryFields []*Field var primaryFields []*Field
var primarySchemaName = primarySchema.Name
if primarySchemaName == "" {
primarySchemaName = relation.FieldSchema.Name
}
if len(relation.primaryKeys) > 0 { if len(relation.primaryKeys) > 0 {
for _, primaryKey := range relation.primaryKeys { for _, primaryKey := range relation.primaryKeys {
@ -427,7 +434,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu
} }
for _, primaryField := range primaryFields { for _, primaryField := range primaryFields {
lookUpName := primarySchema.Name + primaryField.Name lookUpName := primarySchemaName + primaryField.Name
if gl == guessBelongs { if gl == guessBelongs {
lookUpName = field.Name + primaryField.Name lookUpName = field.Name + primaryField.Name
} }
@ -576,7 +583,7 @@ func (rel *Relationship) ParseConstraint() *Constraint {
return &constraint return &constraint
} }
func (rel *Relationship) ToQueryConditions(reflectValue reflect.Value) (conds []clause.Expression) { func (rel *Relationship) ToQueryConditions(ctx context.Context, reflectValue reflect.Value) (conds []clause.Expression) {
table := rel.FieldSchema.Table table := rel.FieldSchema.Table
foreignFields := []*Field{} foreignFields := []*Field{}
relForeignKeys := []string{} relForeignKeys := []string{}
@ -616,7 +623,7 @@ func (rel *Relationship) ToQueryConditions(reflectValue reflect.Value) (conds []
} }
} }
_, foreignValues := GetIdentityFieldValuesMap(reflectValue, foreignFields) _, foreignValues := GetIdentityFieldValuesMap(ctx, reflectValue, foreignFields)
column, values := ToQueryValues(table, relForeignKeys, foreignValues) column, values := ToQueryValues(table, relForeignKeys, foreignValues)
conds = append(conds, clause.IN{Column: column, Values: values}) conds = append(conds, clause.IN{Column: column, Values: values})

View File

@ -491,6 +491,26 @@ func TestEmbeddedRelation(t *testing.T) {
} }
} }
func TestVariableRelation(t *testing.T) {
var result struct {
User
}
checkStructRelation(t, &result, Relation{
Name: "Account", Type: schema.HasOne, Schema: "", FieldSchema: "Account",
References: []Reference{
{"ID", "", "UserID", "Account", "", true},
},
})
checkStructRelation(t, &result, Relation{
Name: "Company", Type: schema.BelongsTo, Schema: "", FieldSchema: "Company",
References: []Reference{
{"ID", "Company", "CompanyID", "", "", false},
},
})
}
func TestSameForeignKey(t *testing.T) { func TestSameForeignKey(t *testing.T) {
type UserAux struct { type UserAux struct {
gorm.Model gorm.Model
@ -576,3 +596,39 @@ func TestHasManySameForeignKey(t *testing.T) {
References: []Reference{{"ID", "User", "UserRefer", "Profile", "", true}}, References: []Reference{{"ID", "User", "UserRefer", "Profile", "", true}},
}) })
} }
type Author struct {
gorm.Model
}
type Book struct {
gorm.Model
Author Author
AuthorID uint
}
func (Book) TableName() string {
return "my_schema.a_very_very_very_very_very_very_very_very_long_table_name"
}
func TestParseConstraintNameWithSchemaQualifiedLongTableName(t *testing.T) {
s, err := schema.Parse(
&Book{},
&sync.Map{},
schema.NamingStrategy{},
)
if err != nil {
t.Fatalf("Failed to parse schema")
}
expectedConstraintName := "fk_my_schema_a_very_very_very_very_very_very_very_very_l4db13eec"
constraint := s.Relationships.Relations["Author"].ParseConstraint()
if constraint.Name != expectedConstraintName {
t.Fatalf(
"expected constraint name %s, got %s",
expectedConstraintName,
constraint.Name,
)
}
}

View File

@ -1,6 +1,7 @@
package schema_test package schema_test
import ( import (
"context"
"fmt" "fmt"
"reflect" "reflect"
"strings" "strings"
@ -203,7 +204,7 @@ func checkSchemaRelation(t *testing.T, s *schema.Schema, relation Relation) {
func checkField(t *testing.T, s *schema.Schema, value reflect.Value, values map[string]interface{}) { func checkField(t *testing.T, s *schema.Schema, value reflect.Value, values map[string]interface{}) {
for k, v := range values { for k, v := range values {
t.Run("CheckField/"+k, func(t *testing.T) { t.Run("CheckField/"+k, func(t *testing.T) {
fv, _ := s.FieldsByDBName[k].ValueOf(value) fv, _ := s.FieldsByDBName[k].ValueOf(context.Background(), value)
tests.AssertEqual(t, v, fv) tests.AssertEqual(t, v, fv)
}) })
} }

157
schema/serializer.go Normal file
View File

@ -0,0 +1,157 @@
package schema
import (
"bytes"
"context"
"database/sql"
"database/sql/driver"
"encoding/gob"
"encoding/json"
"fmt"
"reflect"
"strings"
"sync"
"time"
)
var serializerMap = sync.Map{}
// RegisterSerializer register serializer
func RegisterSerializer(name string, serializer SerializerInterface) {
serializerMap.Store(strings.ToLower(name), serializer)
}
// GetSerializer get serializer
func GetSerializer(name string) (serializer SerializerInterface, ok bool) {
v, ok := serializerMap.Load(strings.ToLower(name))
if ok {
serializer, ok = v.(SerializerInterface)
}
return serializer, ok
}
func init() {
RegisterSerializer("json", JSONSerializer{})
RegisterSerializer("unixtime", UnixSecondSerializer{})
RegisterSerializer("gob", GobSerializer{})
}
// Serializer field value serializer
type serializer struct {
Field *Field
Serializer SerializerInterface
SerializeValuer SerializerValuerInterface
Destination reflect.Value
Context context.Context
value interface{}
fieldValue interface{}
}
// Scan implements sql.Scanner interface
func (s *serializer) Scan(value interface{}) error {
s.value = value
return nil
}
// Value implements driver.Valuer interface
func (s serializer) Value() (driver.Value, error) {
return s.SerializeValuer.Value(s.Context, s.Field, s.Destination, s.fieldValue)
}
// SerializerInterface serializer interface
type SerializerInterface interface {
Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) error
SerializerValuerInterface
}
// SerializerValuerInterface serializer valuer interface
type SerializerValuerInterface interface {
Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (interface{}, error)
}
// JSONSerializer json serializer
type JSONSerializer struct {
}
// Scan implements serializer interface
func (JSONSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) (err error) {
fieldValue := reflect.New(field.FieldType)
if dbValue != nil {
var bytes []byte
switch v := dbValue.(type) {
case []byte:
bytes = v
case string:
bytes = []byte(v)
default:
return fmt.Errorf("failed to unmarshal JSONB value: %#v", dbValue)
}
err = json.Unmarshal(bytes, fieldValue.Interface())
}
field.ReflectValueOf(ctx, dst).Set(fieldValue.Elem())
return
}
// Value implements serializer interface
func (JSONSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) {
result, err := json.Marshal(fieldValue)
return string(result), err
}
// UnixSecondSerializer json serializer
type UnixSecondSerializer struct {
}
// Scan implements serializer interface
func (UnixSecondSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) (err error) {
t := sql.NullTime{}
if err = t.Scan(dbValue); err == nil && t.Valid {
err = field.Set(ctx, dst, t.Time.Unix())
}
return
}
// Value implements serializer interface
func (UnixSecondSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (result interface{}, err error) {
switch v := fieldValue.(type) {
case int64, int, uint, uint64, int32, uint32, int16, uint16, *int64, *int, *uint, *uint64, *int32, *uint32, *int16, *uint16:
result = time.Unix(reflect.Indirect(reflect.ValueOf(v)).Int(), 0)
default:
err = fmt.Errorf("invalid field type %#v for UnixSecondSerializer, only int, uint supported", v)
}
return
}
// GobSerializer gob serializer
type GobSerializer struct {
}
// Scan implements serializer interface
func (GobSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) (err error) {
fieldValue := reflect.New(field.FieldType)
if dbValue != nil {
var bytesValue []byte
switch v := dbValue.(type) {
case []byte:
bytesValue = v
default:
return fmt.Errorf("failed to unmarshal gob value: %#v", dbValue)
}
decoder := gob.NewDecoder(bytes.NewBuffer(bytesValue))
err = decoder.Decode(fieldValue.Interface())
}
field.ReflectValueOf(ctx, dst).Set(fieldValue.Elem())
return
}
// Value implements serializer interface
func (GobSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) {
buf := new(bytes.Buffer)
err := gob.NewEncoder(buf).Encode(fieldValue)
return buf.Bytes(), err
}

View File

@ -1,6 +1,8 @@
package schema package schema
import ( import (
"context"
"fmt"
"reflect" "reflect"
"regexp" "regexp"
"strings" "strings"
@ -58,14 +60,22 @@ func removeSettingFromTag(tag reflect.StructTag, names ...string) reflect.Struct
return tag return tag
} }
func appendSettingFromTag(tag reflect.StructTag, value string) reflect.StructTag {
t := tag.Get("gorm")
if strings.Contains(t, value) {
return tag
}
return reflect.StructTag(fmt.Sprintf(`gorm:"%s;%s"`, value, t))
}
// GetRelationsValues get relations's values from a reflect value // GetRelationsValues get relations's values from a reflect value
func GetRelationsValues(reflectValue reflect.Value, rels []*Relationship) (reflectResults reflect.Value) { func GetRelationsValues(ctx context.Context, reflectValue reflect.Value, rels []*Relationship) (reflectResults reflect.Value) {
for _, rel := range rels { for _, rel := range rels {
reflectResults = reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(rel.FieldSchema.ModelType)), 0, 1) reflectResults = reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(rel.FieldSchema.ModelType)), 0, 1)
appendToResults := func(value reflect.Value) { appendToResults := func(value reflect.Value) {
if _, isZero := rel.Field.ValueOf(value); !isZero { if _, isZero := rel.Field.ValueOf(ctx, value); !isZero {
result := reflect.Indirect(rel.Field.ReflectValueOf(value)) result := reflect.Indirect(rel.Field.ReflectValueOf(ctx, value))
switch result.Kind() { switch result.Kind() {
case reflect.Struct: case reflect.Struct:
reflectResults = reflect.Append(reflectResults, result.Addr()) reflectResults = reflect.Append(reflectResults, result.Addr())
@ -97,7 +107,7 @@ func GetRelationsValues(reflectValue reflect.Value, rels []*Relationship) (refle
} }
// GetIdentityFieldValuesMap get identity map from fields // GetIdentityFieldValuesMap get identity map from fields
func GetIdentityFieldValuesMap(reflectValue reflect.Value, fields []*Field) (map[string][]reflect.Value, [][]interface{}) { func GetIdentityFieldValuesMap(ctx context.Context, reflectValue reflect.Value, fields []*Field) (map[string][]reflect.Value, [][]interface{}) {
var ( var (
results = [][]interface{}{} results = [][]interface{}{}
dataResults = map[string][]reflect.Value{} dataResults = map[string][]reflect.Value{}
@ -110,7 +120,7 @@ func GetIdentityFieldValuesMap(reflectValue reflect.Value, fields []*Field) (map
results = [][]interface{}{make([]interface{}, len(fields))} results = [][]interface{}{make([]interface{}, len(fields))}
for idx, field := range fields { for idx, field := range fields {
results[0][idx], zero = field.ValueOf(reflectValue) results[0][idx], zero = field.ValueOf(ctx, reflectValue)
notZero = notZero || !zero notZero = notZero || !zero
} }
@ -135,7 +145,7 @@ func GetIdentityFieldValuesMap(reflectValue reflect.Value, fields []*Field) (map
fieldValues := make([]interface{}, len(fields)) fieldValues := make([]interface{}, len(fields))
notZero = false notZero = false
for idx, field := range fields { for idx, field := range fields {
fieldValues[idx], zero = field.ValueOf(elem) fieldValues[idx], zero = field.ValueOf(ctx, elem)
notZero = notZero || !zero notZero = notZero || !zero
} }
@ -155,12 +165,12 @@ func GetIdentityFieldValuesMap(reflectValue reflect.Value, fields []*Field) (map
} }
// GetIdentityFieldValuesMapFromValues get identity map from fields // GetIdentityFieldValuesMapFromValues get identity map from fields
func GetIdentityFieldValuesMapFromValues(values []interface{}, fields []*Field) (map[string][]reflect.Value, [][]interface{}) { func GetIdentityFieldValuesMapFromValues(ctx context.Context, values []interface{}, fields []*Field) (map[string][]reflect.Value, [][]interface{}) {
resultsMap := map[string][]reflect.Value{} resultsMap := map[string][]reflect.Value{}
results := [][]interface{}{} results := [][]interface{}{}
for _, v := range values { for _, v := range values {
rm, rs := GetIdentityFieldValuesMap(reflect.Indirect(reflect.ValueOf(v)), fields) rm, rs := GetIdentityFieldValuesMap(ctx, reflect.Indirect(reflect.ValueOf(v)), fields)
for k, v := range rm { for k, v := range rm {
resultsMap[k] = append(resultsMap[k], v...) resultsMap[k] = append(resultsMap[k], v...)
} }

View File

@ -104,9 +104,7 @@ func (sd SoftDeleteUpdateClause) MergeClause(*clause.Clause) {
func (sd SoftDeleteUpdateClause) ModifyStatement(stmt *Statement) { func (sd SoftDeleteUpdateClause) ModifyStatement(stmt *Statement) {
if stmt.SQL.Len() == 0 && !stmt.Statement.Unscoped { if stmt.SQL.Len() == 0 && !stmt.Statement.Unscoped {
if _, ok := stmt.Clauses["WHERE"]; stmt.DB.AllowGlobalUpdate || ok { SoftDeleteQueryClause(sd).ModifyStatement(stmt)
SoftDeleteQueryClause(sd).ModifyStatement(stmt)
}
} }
} }
@ -135,7 +133,7 @@ func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) {
stmt.SetColumn(sd.Field.DBName, curTime, true) stmt.SetColumn(sd.Field.DBName, curTime, true)
if stmt.Schema != nil { if stmt.Schema != nil {
_, queryValues := schema.GetIdentityFieldValuesMap(stmt.ReflectValue, stmt.Schema.PrimaryFields) _, queryValues := schema.GetIdentityFieldValuesMap(stmt.Context, stmt.ReflectValue, stmt.Schema.PrimaryFields)
column, values := schema.ToQueryValues(stmt.Table, stmt.Schema.PrimaryFieldDBNames, queryValues) column, values := schema.ToQueryValues(stmt.Table, stmt.Schema.PrimaryFieldDBNames, queryValues)
if len(values) > 0 { if len(values) > 0 {
@ -143,7 +141,7 @@ func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) {
} }
if stmt.ReflectValue.CanAddr() && stmt.Dest != stmt.Model && stmt.Model != nil { if stmt.ReflectValue.CanAddr() && stmt.Dest != stmt.Model && stmt.Model != nil {
_, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(stmt.Model), stmt.Schema.PrimaryFields) _, queryValues = schema.GetIdentityFieldValuesMap(stmt.Context, reflect.ValueOf(stmt.Model), stmt.Schema.PrimaryFields)
column, values = schema.ToQueryValues(stmt.Table, stmt.Schema.PrimaryFieldDBNames, queryValues) column, values = schema.ToQueryValues(stmt.Table, stmt.Schema.PrimaryFieldDBNames, queryValues)
if len(values) > 0 { if len(values) > 0 {
@ -152,12 +150,7 @@ func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) {
} }
} }
if _, ok := stmt.Clauses["WHERE"]; !stmt.DB.AllowGlobalUpdate && !ok { SoftDeleteQueryClause(sd).ModifyStatement(stmt)
stmt.DB.AddError(ErrMissingWhereClause)
} else {
SoftDeleteQueryClause(sd).ModifyStatement(stmt)
}
stmt.AddClauseIfNotExists(clause.Update{}) stmt.AddClauseIfNotExists(clause.Update{})
stmt.Build(stmt.DB.Callback().Update().Clauses...) stmt.Build(stmt.DB.Callback().Update().Clauses...)
} }

View File

@ -130,7 +130,7 @@ func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) {
writer.WriteByte('(') writer.WriteByte('(')
for idx, d := range v { for idx, d := range v {
if idx > 0 { if idx > 0 {
writer.WriteString(",") writer.WriteByte(',')
} }
stmt.QuoteTo(writer, d) stmt.QuoteTo(writer, d)
} }
@ -143,7 +143,7 @@ func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) {
writer.WriteByte('(') writer.WriteByte('(')
for idx, d := range v { for idx, d := range v {
if idx > 0 { if idx > 0 {
writer.WriteString(",") writer.WriteByte(',')
} }
stmt.DB.Dialector.QuoteTo(writer, d) stmt.DB.Dialector.QuoteTo(writer, d)
} }
@ -179,9 +179,7 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) {
} else { } else {
stmt.AddVar(writer, v.GormValue(stmt.Context, stmt.DB)) stmt.AddVar(writer, v.GormValue(stmt.Context, stmt.DB))
} }
case clause.Expr: case clause.Expression:
v.Build(stmt)
case *clause.Expr:
v.Build(stmt) v.Build(stmt)
case driver.Valuer: case driver.Valuer:
stmt.Vars = append(stmt.Vars, v) stmt.Vars = append(stmt.Vars, v)
@ -314,6 +312,10 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
case clause.Expression: case clause.Expression:
conds = append(conds, v) conds = append(conds, v)
case *DB: case *DB:
for _, scope := range v.Statement.scopes {
v = scope(v)
}
if cs, ok := v.Statement.Clauses["WHERE"]; ok { if cs, ok := v.Statement.Clauses["WHERE"]; ok {
if where, ok := cs.Expression.(clause.Where); ok { if where, ok := cs.Expression.(clause.Where); ok {
if len(where.Exprs) == 1 { if len(where.Exprs) == 1 {
@ -391,7 +393,7 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
for _, field := range s.Fields { for _, field := range s.Fields {
selected := selectedColumns[field.DBName] || selectedColumns[field.Name] selected := selectedColumns[field.DBName] || selectedColumns[field.Name]
if selected || (!restricted && field.Readable) { if selected || (!restricted && field.Readable) {
if v, isZero := field.ValueOf(reflectValue); !isZero || selected { if v, isZero := field.ValueOf(stmt.Context, reflectValue); !isZero || selected {
if field.DBName != "" { if field.DBName != "" {
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v}) conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v})
} else if field.DataType != "" { } else if field.DataType != "" {
@ -405,7 +407,7 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
for _, field := range s.Fields { for _, field := range s.Fields {
selected := selectedColumns[field.DBName] || selectedColumns[field.Name] selected := selectedColumns[field.DBName] || selectedColumns[field.Name]
if selected || (!restricted && field.Readable) { if selected || (!restricted && field.Readable) {
if v, isZero := field.ValueOf(reflectValue.Index(i)); !isZero || selected { if v, isZero := field.ValueOf(stmt.Context, reflectValue.Index(i)); !isZero || selected {
if field.DBName != "" { if field.DBName != "" {
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v}) conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v})
} else if field.DataType != "" { } else if field.DataType != "" {
@ -564,7 +566,7 @@ func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks .
switch destValue.Kind() { switch destValue.Kind() {
case reflect.Struct: case reflect.Struct:
field.Set(destValue, value) stmt.AddError(field.Set(stmt.Context, destValue, value))
default: default:
stmt.AddError(ErrInvalidData) stmt.AddError(ErrInvalidData)
} }
@ -574,10 +576,10 @@ func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks .
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
if len(fromCallbacks) > 0 { if len(fromCallbacks) > 0 {
for i := 0; i < stmt.ReflectValue.Len(); i++ { for i := 0; i < stmt.ReflectValue.Len(); i++ {
field.Set(stmt.ReflectValue.Index(i), value) stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue.Index(i), value))
} }
} else { } else {
field.Set(stmt.ReflectValue.Index(stmt.CurDestIndex), value) stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue.Index(stmt.CurDestIndex), value))
} }
case reflect.Struct: case reflect.Struct:
if !stmt.ReflectValue.CanAddr() { if !stmt.ReflectValue.CanAddr() {
@ -585,7 +587,7 @@ func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks .
return return
} }
field.Set(stmt.ReflectValue, value) stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, value))
} }
} else { } else {
stmt.AddError(ErrInvalidField) stmt.AddError(ErrInvalidField)
@ -605,12 +607,12 @@ func (stmt *Statement) Changed(fields ...string) bool {
selectColumns, restricted := stmt.SelectAndOmitColumns(false, true) selectColumns, restricted := stmt.SelectAndOmitColumns(false, true)
changed := func(field *schema.Field) bool { changed := func(field *schema.Field) bool {
fieldValue, _ := field.ValueOf(modelValue) fieldValue, _ := field.ValueOf(stmt.Context, modelValue)
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
if v, ok := stmt.Dest.(map[string]interface{}); ok { if mv, mok := stmt.Dest.(map[string]interface{}); mok {
if fv, ok := v[field.Name]; ok { if fv, ok := mv[field.Name]; ok {
return !utils.AssertEqual(fv, fieldValue) return !utils.AssertEqual(fv, fieldValue)
} else if fv, ok := v[field.DBName]; ok { } else if fv, ok := mv[field.DBName]; ok {
return !utils.AssertEqual(fv, fieldValue) return !utils.AssertEqual(fv, fieldValue)
} }
} else { } else {
@ -619,7 +621,10 @@ func (stmt *Statement) Changed(fields ...string) bool {
destValue = destValue.Elem() destValue = destValue.Elem()
} }
changedValue, zero := field.ValueOf(destValue) changedValue, zero := field.ValueOf(stmt.Context, destValue)
if v {
return !utils.AssertEqual(changedValue, fieldValue)
}
return !zero && !utils.AssertEqual(changedValue, fieldValue) return !zero && !utils.AssertEqual(changedValue, fieldValue)
} }
} }

View File

@ -220,3 +220,67 @@ func TestFullSaveAssociations(t *testing.T) {
t.Errorf("Failed to preload AppliesToProduct") t.Errorf("Failed to preload AppliesToProduct")
} }
} }
func TestSaveBelongsCircularReference(t *testing.T) {
parent := Parent{}
DB.Create(&parent)
child := Child{ParentID: &parent.ID, Parent: &parent}
DB.Create(&child)
parent.FavChildID = child.ID
parent.FavChild = &child
DB.Save(&parent)
var parent1 Parent
DB.First(&parent1, parent.ID)
AssertObjEqual(t, parent, parent1, "ID", "FavChildID")
// Save and Updates is the same
DB.Updates(&parent)
DB.First(&parent1, parent.ID)
AssertObjEqual(t, parent, parent1, "ID", "FavChildID")
}
func TestSaveHasManyCircularReference(t *testing.T) {
parent := Parent{}
DB.Create(&parent)
child := Child{ParentID: &parent.ID, Parent: &parent, Name: "HasManyCircularReference"}
child1 := Child{ParentID: &parent.ID, Parent: &parent, Name: "HasManyCircularReference1"}
parent.Children = []*Child{&child, &child1}
DB.Save(&parent)
var children []*Child
DB.Where("parent_id = ?", parent.ID).Find(&children)
if len(children) != len(parent.Children) ||
children[0].ID != parent.Children[0].ID ||
children[1].ID != parent.Children[1].ID {
t.Errorf("circular reference children save not equal children:%v parent.Children:%v",
children, parent.Children)
}
}
func TestAssociationError(t *testing.T) {
user := *GetUser("TestAssociationError", Config{Pets: 2, Company: true, Account: true, Languages: 2})
DB.Create(&user)
var user1 User
DB.Preload("Company").Preload("Pets").Preload("Account").Preload("Languages").First(&user1)
var emptyUser User
var err error
// belongs to
err = DB.Model(&emptyUser).Association("Company").Delete(&user1.Company)
AssertEqual(t, err, gorm.ErrPrimaryKeyRequired)
// has many
err = DB.Model(&emptyUser).Association("Pets").Delete(&user1.Pets)
AssertEqual(t, err, gorm.ErrPrimaryKeyRequired)
// has one
err = DB.Model(&emptyUser).Association("Account").Delete(&user1.Account)
AssertEqual(t, err, gorm.ErrPrimaryKeyRequired)
// many to many
err = DB.Model(&emptyUser).Association("Languages").Delete(&user1.Languages)
AssertEqual(t, err, gorm.ErrPrimaryKeyRequired)
}

View File

@ -38,6 +38,7 @@ func c2(*gorm.DB) {}
func c3(*gorm.DB) {} func c3(*gorm.DB) {}
func c4(*gorm.DB) {} func c4(*gorm.DB) {}
func c5(*gorm.DB) {} func c5(*gorm.DB) {}
func c6(*gorm.DB) {}
func TestCallbacks(t *testing.T) { func TestCallbacks(t *testing.T) {
type callback struct { type callback struct {
@ -168,3 +169,37 @@ func TestCallbacks(t *testing.T) {
} }
} }
} }
func TestPluginCallbacks(t *testing.T) {
db, _ := gorm.Open(nil, nil)
createCallback := db.Callback().Create()
createCallback.Before("*").Register("plugin_1_fn1", c1)
createCallback.After("*").Register("plugin_1_fn2", c2)
if ok, msg := assertCallbacks(createCallback, []string{"c1", "c2"}); !ok {
t.Errorf("callbacks tests failed, got %v", msg)
}
// plugin 2
createCallback.Before("*").Register("plugin_2_fn1", c3)
if ok, msg := assertCallbacks(createCallback, []string{"c3", "c1", "c2"}); !ok {
t.Errorf("callbacks tests failed, got %v", msg)
}
createCallback.After("*").Register("plugin_2_fn2", c4)
if ok, msg := assertCallbacks(createCallback, []string{"c3", "c1", "c2", "c4"}); !ok {
t.Errorf("callbacks tests failed, got %v", msg)
}
// plugin 3
createCallback.Before("*").Register("plugin_3_fn1", c5)
if ok, msg := assertCallbacks(createCallback, []string{"c5", "c3", "c1", "c2", "c4"}); !ok {
t.Errorf("callbacks tests failed, got %v", msg)
}
createCallback.After("*").Register("plugin_3_fn2", c6)
if ok, msg := assertCallbacks(createCallback, []string{"c5", "c3", "c1", "c2", "c4", "c6"}); !ok {
t.Errorf("callbacks tests failed, got %v", msg)
}
}

View File

@ -9,7 +9,7 @@ import (
) )
func TestWithSingleConnection(t *testing.T) { func TestWithSingleConnection(t *testing.T) {
var expectedName = "test" expectedName := "test"
var actualName string var actualName string
setSQL, getSQL := getSetSQL(DB.Dialector.Name()) setSQL, getSQL := getSetSQL(DB.Dialector.Name())
@ -27,7 +27,6 @@ func TestWithSingleConnection(t *testing.T) {
} }
return nil return nil
}) })
if err != nil { if err != nil {
t.Errorf(fmt.Sprintf("WithSingleConnection should work, but got err %v", err)) t.Errorf(fmt.Sprintf("WithSingleConnection should work, but got err %v", err))
} }

171
tests/connpool_test.go Normal file
View File

@ -0,0 +1,171 @@
package tests_test
import (
"context"
"database/sql"
"os"
"reflect"
"testing"
"gorm.io/driver/mysql"
"gorm.io/gorm"
. "gorm.io/gorm/utils/tests"
)
type wrapperTx struct {
*sql.Tx
conn *wrapperConnPool
}
func (c *wrapperTx) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) {
c.conn.got = append(c.conn.got, query)
return c.Tx.PrepareContext(ctx, query)
}
func (c *wrapperTx) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
c.conn.got = append(c.conn.got, query)
return c.Tx.ExecContext(ctx, query, args...)
}
func (c *wrapperTx) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
c.conn.got = append(c.conn.got, query)
return c.Tx.QueryContext(ctx, query, args...)
}
func (c *wrapperTx) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
c.conn.got = append(c.conn.got, query)
return c.Tx.QueryRowContext(ctx, query, args...)
}
type wrapperConnPool struct {
db *sql.DB
got []string
expect []string
}
func (c *wrapperConnPool) Ping() error {
return c.db.Ping()
}
// If you use BeginTx returned *sql.Tx as shown below then you can't record queries in a transaction.
// func (c *wrapperConnPool) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) {
// return c.db.BeginTx(ctx, opts)
// }
// You should use BeginTx returned gorm.Tx which could wrap *sql.Tx then you can record all queries.
func (c *wrapperConnPool) BeginTx(ctx context.Context, opts *sql.TxOptions) (gorm.ConnPool, error) {
tx, err := c.db.BeginTx(ctx, opts)
if err != nil {
return nil, err
}
return &wrapperTx{Tx: tx, conn: c}, nil
}
func (c *wrapperConnPool) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) {
c.got = append(c.got, query)
return c.db.PrepareContext(ctx, query)
}
func (c *wrapperConnPool) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
c.got = append(c.got, query)
return c.db.ExecContext(ctx, query, args...)
}
func (c *wrapperConnPool) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
c.got = append(c.got, query)
return c.db.QueryContext(ctx, query, args...)
}
func (c *wrapperConnPool) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
c.got = append(c.got, query)
return c.db.QueryRowContext(ctx, query, args...)
}
func TestConnPoolWrapper(t *testing.T) {
dialect := os.Getenv("GORM_DIALECT")
if dialect != "mysql" {
t.SkipNow()
}
dbDSN := os.Getenv("GORM_DSN")
if dbDSN == "" {
dbDSN = "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True&loc=Local"
}
nativeDB, err := sql.Open("mysql", dbDSN)
if err != nil {
t.Fatalf("Should open db success, but got %v", err)
}
conn := &wrapperConnPool{
db: nativeDB,
expect: []string{
"SELECT VERSION()",
"INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)",
"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1",
"INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)",
"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1",
"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1",
"INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)",
"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1",
"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1",
},
}
defer func() {
if !reflect.DeepEqual(conn.got, conn.expect) {
t.Errorf("expect %#v but got %#v", conn.expect, conn.got)
}
}()
db, err := gorm.Open(mysql.New(mysql.Config{Conn: conn}))
if err != nil {
t.Fatalf("Should open db success, but got %v", err)
}
tx := db.Begin()
user := *GetUser("transaction", Config{})
if err = tx.Save(&user).Error; err != nil {
t.Fatalf("No error should raise, but got %v", err)
}
if err = tx.First(&User{}, "name = ?", "transaction").Error; err != nil {
t.Fatalf("Should find saved record, but got %v", err)
}
user1 := *GetUser("transaction1-1", Config{})
if err = tx.Save(&user1).Error; err != nil {
t.Fatalf("No error should raise, but got %v", err)
}
if err = tx.First(&User{}, "name = ?", user1.Name).Error; err != nil {
t.Fatalf("Should find saved record, but got %v", err)
}
if sqlTx, ok := tx.Statement.ConnPool.(gorm.TxCommitter); !ok || sqlTx == nil {
t.Fatalf("Should return the underlying sql.Tx")
}
tx.Rollback()
if err = db.First(&User{}, "name = ?", "transaction").Error; err == nil {
t.Fatalf("Should not find record after rollback, but got %v", err)
}
txDB := db.Where("fake_name = ?", "fake_name")
tx2 := txDB.Session(&gorm.Session{NewDB: true}).Begin()
user2 := *GetUser("transaction-2", Config{})
if err = tx2.Save(&user2).Error; err != nil {
t.Fatalf("No error should raise, but got %v", err)
}
if err = tx2.First(&User{}, "name = ?", "transaction-2").Error; err != nil {
t.Fatalf("Should find saved record, but got %v", err)
}
tx2.Commit()
if err = db.First(&User{}, "name = ?", "transaction-2").Error; err != nil {
t.Fatalf("Should be able to find committed record, but got %v", err)
}
}

View File

@ -144,4 +144,22 @@ func TestCount(t *testing.T) {
if err := DB.Model(&User{}).Where("name = ?", "count-4").Group("name").Count(&count11).Error; err != nil || count11 != 1 { if err := DB.Model(&User{}).Where("name = ?", "count-4").Group("name").Count(&count11).Error; err != nil || count11 != 1 {
t.Fatalf("Count should be 3, but got count: %v err %v", count11, err) t.Fatalf("Count should be 3, but got count: %v err %v", count11, err)
} }
var count12 int64
if err := DB.Table("users").
Where("name in ?", []string{user1.Name, user2.Name, user3.Name}).
Preload("Toys", func(db *gorm.DB) *gorm.DB {
return db.Table("toys").Select("name")
}).Count(&count12).Error; err == nil {
t.Errorf("error should raise when using preload without schema")
}
var count13 int64
if err := DB.Model(User{}).
Where("name in ?", []string{user1.Name, user2.Name, user3.Name}).
Preload("Toys", func(db *gorm.DB) *gorm.DB {
return db.Table("toys").Select("name")
}).Count(&count13).Error; err != nil {
t.Errorf("no error should raise when using count with preload, but got %v", err)
}
} }

View File

@ -123,7 +123,7 @@ func TestCreateFromMap(t *testing.T) {
{"name": "create_from_map_3", "Age": 20}, {"name": "create_from_map_3", "Age": 20},
} }
if err := DB.Model(&User{}).Create(datas).Error; err != nil { if err := DB.Model(&User{}).Create(&datas).Error; err != nil {
t.Fatalf("failed to create data from slice of map, got error: %v", err) t.Fatalf("failed to create data from slice of map, got error: %v", err)
} }
@ -526,3 +526,17 @@ func TestCreateNilPointer(t *testing.T) {
t.Fatalf("it is not ErrInvalidValue") t.Fatalf("it is not ErrInvalidValue")
} }
} }
func TestFirstOrCreateRowsAffected(t *testing.T) {
user := User{Name: "TestFirstOrCreateRowsAffected"}
res := DB.FirstOrCreate(&user, "name = ?", user.Name)
if res.Error != nil || res.RowsAffected != 1 {
t.Fatalf("first or create rows affect err:%v rows:%d", res.Error, res.RowsAffected)
}
res = DB.FirstOrCreate(&user, "name = ?", user.Name)
if res.Error != nil || res.RowsAffected != 0 {
t.Fatalf("first or create rows affect err:%v rows:%d", res.Error, res.RowsAffected)
}
}

View File

@ -2,6 +2,7 @@ package tests_test
import ( import (
"testing" "testing"
"time"
"gorm.io/gorm" "gorm.io/gorm"
) )
@ -9,12 +10,13 @@ import (
func TestDefaultValue(t *testing.T) { func TestDefaultValue(t *testing.T) {
type Harumph struct { type Harumph struct {
gorm.Model gorm.Model
Email string `gorm:"not null;index:,unique"` Email string `gorm:"not null;index:,unique"`
Name string `gorm:"notNull;default:foo"` Name string `gorm:"notNull;default:foo"`
Name2 string `gorm:"size:233;not null;default:'foo'"` Name2 string `gorm:"size:233;not null;default:'foo'"`
Name3 string `gorm:"size:233;notNull;default:''"` Name3 string `gorm:"size:233;notNull;default:''"`
Age int `gorm:"default:18"` Age int `gorm:"default:18"`
Enabled bool `gorm:"default:true"` Created time.Time `gorm:"default:2000-01-02"`
Enabled bool `gorm:"default:true"`
} }
DB.Migrator().DropTable(&Harumph{}) DB.Migrator().DropTable(&Harumph{})
@ -26,14 +28,14 @@ func TestDefaultValue(t *testing.T) {
harumph := Harumph{Email: "hello@gorm.io"} harumph := Harumph{Email: "hello@gorm.io"}
if err := DB.Create(&harumph).Error; err != nil { if err := DB.Create(&harumph).Error; err != nil {
t.Fatalf("Failed to create data with default value, got error: %v", err) t.Fatalf("Failed to create data with default value, got error: %v", err)
} else if harumph.Name != "foo" || harumph.Name2 != "foo" || harumph.Name3 != "" || harumph.Age != 18 || !harumph.Enabled { } else if harumph.Name != "foo" || harumph.Name2 != "foo" || harumph.Name3 != "" || harumph.Age != 18 || !harumph.Enabled || harumph.Created.Format("20060102") != "20000102" {
t.Fatalf("Failed to create data with default value, got: %+v", harumph) t.Fatalf("Failed to create data with default value, got: %+v", harumph)
} }
var result Harumph var result Harumph
if err := DB.First(&result, "email = ?", "hello@gorm.io").Error; err != nil { if err := DB.First(&result, "email = ?", "hello@gorm.io").Error; err != nil {
t.Fatalf("Failed to find created data, got error: %v", err) t.Fatalf("Failed to find created data, got error: %v", err)
} else if result.Name != "foo" || result.Name2 != "foo" || result.Name3 != "" || result.Age != 18 || !result.Enabled { } else if result.Name != "foo" || result.Name2 != "foo" || result.Name3 != "" || result.Age != 18 || !result.Enabled || result.Created.Format("20060102") != "20000102" {
t.Fatalf("Failed to find created data with default data, got %+v", result) t.Fatalf("Failed to find created data with default data, got %+v", result)
} }
} }

View File

@ -2,7 +2,7 @@ version: '3'
services: services:
mysql: mysql:
image: 'mysql:latest' image: 'mysql/mysql-server:latest'
ports: ports:
- 9910:3306 - 9910:3306
environment: environment:
@ -20,7 +20,7 @@ services:
- POSTGRES_USER=gorm - POSTGRES_USER=gorm
- POSTGRES_PASSWORD=gorm - POSTGRES_PASSWORD=gorm
mssql: mssql:
image: 'mcmoe/mssqldocker:latest' image: '${MSSQL_IMAGE:-mcmoe/mssqldocker}:latest'
ports: ports:
- 9930:1433 - 9930:1433
environment: environment:

View File

@ -3,16 +3,16 @@ module gorm.io/gorm/tests
go 1.14 go 1.14
require ( require (
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect
github.com/google/uuid v1.3.0 github.com/google/uuid v1.3.0
github.com/jackc/pgx/v4 v4.14.1 // indirect github.com/jinzhu/now v1.1.5
github.com/jinzhu/now v1.1.4 github.com/lib/pq v1.10.5
github.com/lib/pq v1.10.4 golang.org/x/crypto v0.0.0-20220411220226-7b82a4e95df4 // indirect
golang.org/x/crypto v0.0.0-20211209193657-4570a0811e8b // indirect gorm.io/driver/mysql v1.3.3
gorm.io/driver/mysql v1.2.3 gorm.io/driver/postgres v1.3.5
gorm.io/driver/postgres v1.2.3 gorm.io/driver/sqlite v1.3.2
gorm.io/driver/sqlite v1.2.6 gorm.io/driver/sqlserver v1.3.2
gorm.io/driver/sqlserver v1.2.1 gorm.io/gorm v1.23.4
gorm.io/gorm v1.22.4
) )
replace gorm.io/gorm => ../ replace gorm.io/gorm => ../

View File

@ -19,6 +19,7 @@ type Config struct {
Team int Team int
Languages int Languages int
Friends int Friends int
NamedPet bool
} }
func GetUser(name string, config Config) *User { func GetUser(name string, config Config) *User {
@ -65,6 +66,10 @@ func GetUser(name string, config Config) *User {
user.Friends = append(user.Friends, GetUser(name+"_friend_"+strconv.Itoa(i+1), Config{})) user.Friends = append(user.Friends, GetUser(name+"_friend_"+strconv.Itoa(i+1), Config{}))
} }
if config.NamedPet {
user.NamedPet = &Pet{Name: name + "_namepet"}
}
return &user return &user
} }

View File

@ -375,13 +375,19 @@ func TestSetColumn(t *testing.T) {
t.Errorf("invalid data after update, got %+v", product) t.Errorf("invalid data after update, got %+v", product)
} }
// Code changed, price should changed
DB.Model(&product).Select("Name", "Code", "Price").Updates(Product3{Name: "Product New4", Code: ""})
if product.Name != "Product New4" || product.Price != 320 || product.Code != "" {
t.Errorf("invalid data after update, got %+v", product)
}
DB.Model(&product).UpdateColumns(Product3{Code: "L1215"}) DB.Model(&product).UpdateColumns(Product3{Code: "L1215"})
if product.Price != 270 || product.Code != "L1215" { if product.Price != 320 || product.Code != "L1215" {
t.Errorf("invalid data after update, got %+v", product) t.Errorf("invalid data after update, got %+v", product)
} }
DB.Model(&product).Session(&gorm.Session{SkipHooks: true}).Updates(Product3{Code: "L1216"}) DB.Model(&product).Session(&gorm.Session{SkipHooks: true}).Updates(Product3{Code: "L1216"})
if product.Price != 270 || product.Code != "L1216" { if product.Price != 320 || product.Code != "L1216" {
t.Errorf("invalid data after update, got %+v", product) t.Errorf("invalid data after update, got %+v", product)
} }
@ -460,8 +466,9 @@ type Product4 struct {
type ProductItem struct { type ProductItem struct {
gorm.Model gorm.Model
Code string Code string
Product4ID uint Product4ID uint
AfterFindCallTimes int
} }
func (pi ProductItem) BeforeCreate(*gorm.DB) error { func (pi ProductItem) BeforeCreate(*gorm.DB) error {
@ -471,6 +478,11 @@ func (pi ProductItem) BeforeCreate(*gorm.DB) error {
return nil return nil
} }
func (pi *ProductItem) AfterFind(*gorm.DB) error {
pi.AfterFindCallTimes = pi.AfterFindCallTimes + 1
return nil
}
func TestFailedToSaveAssociationShouldRollback(t *testing.T) { func TestFailedToSaveAssociationShouldRollback(t *testing.T) {
DB.Migrator().DropTable(&Product4{}, &ProductItem{}) DB.Migrator().DropTable(&Product4{}, &ProductItem{})
DB.AutoMigrate(&Product4{}, &ProductItem{}) DB.AutoMigrate(&Product4{}, &ProductItem{})
@ -492,4 +504,13 @@ func TestFailedToSaveAssociationShouldRollback(t *testing.T) {
if err := DB.First(&Product4{}, "name = ?", product.Name).Error; err != nil { if err := DB.First(&Product4{}, "name = ?", product.Name).Error; err != nil {
t.Errorf("should find product, but got error %v", err) t.Errorf("should find product, but got error %v", err)
} }
var productWithItem Product4
if err := DB.Session(&gorm.Session{SkipHooks: true}).Preload("Item").First(&productWithItem, "name = ?", product.Name).Error; err != nil {
t.Errorf("should find product, but got error %v", err)
}
if productWithItem.Item.AfterFindCallTimes != 0 {
t.Fatalf("AfterFind should not be called times:%d", productWithItem.Item.AfterFindCallTimes)
}
} }

View File

@ -10,12 +10,12 @@ import (
) )
func TestJoins(t *testing.T) { func TestJoins(t *testing.T) {
user := *GetUser("joins-1", Config{Company: true, Manager: true, Account: true}) user := *GetUser("joins-1", Config{Company: true, Manager: true, Account: true, NamedPet: false})
DB.Create(&user) DB.Create(&user)
var user2 User var user2 User
if err := DB.Joins("Company").Joins("Manager").Joins("Account").First(&user2, "users.name = ?", user.Name).Error; err != nil { if err := DB.Joins("NamedPet").Joins("Company").Joins("Manager").Joins("Account").First(&user2, "users.name = ?", user.Name).Error; err != nil {
t.Fatalf("Failed to load with joins, got error: %v", err) t.Fatalf("Failed to load with joins, got error: %v", err)
} }
@ -158,6 +158,22 @@ func TestJoinsWithSelect(t *testing.T) {
} }
} }
func TestJoinWithOmit(t *testing.T) {
user := *GetUser("joins_with_omit", Config{Pets: 2})
DB.Save(&user)
results := make([]*User, 0)
if err := DB.Table("users").Omit("name").Where("users.name = ?", "joins_with_omit").Joins("left join pets on pets.user_id = users.id").Find(&results).Error; err != nil {
return
}
if len(results) != 2 || results[0].Name != "" || results[1].Name != "" {
t.Errorf("Should find all two pets with Join omit and should not find user's name, got %+v", results)
return
}
}
func TestJoinCount(t *testing.T) { func TestJoinCount(t *testing.T) {
companyA := Company{Name: "A"} companyA := Company{Name: "A"}
companyB := Company{Name: "B"} companyB := Company{Name: "B"}
@ -184,3 +200,32 @@ func TestJoinCount(t *testing.T) {
t.Fatalf("result's id, %d, doesn't match user's id, %d", result.ID, user.ID) t.Fatalf("result's id, %d, doesn't match user's id, %d", result.ID, user.ID)
} }
} }
func TestJoinWithSoftDeleted(t *testing.T) {
user := GetUser("TestJoinWithSoftDeletedUser", Config{Account: true, NamedPet: true})
DB.Create(&user)
var user1 User
DB.Model(&User{}).Joins("NamedPet").Joins("Account").First(&user1, user.ID)
if user1.NamedPet == nil || user1.Account.ID == 0 {
t.Fatalf("joins NamedPet and Account should not empty:%v", user1)
}
// Account should empty
DB.Delete(&user1.Account)
var user2 User
DB.Model(&User{}).Joins("NamedPet").Joins("Account").First(&user2, user.ID)
if user2.NamedPet == nil || user2.Account.ID != 0 {
t.Fatalf("joins Account should not empty:%v", user2)
}
// NamedPet should empty
DB.Delete(&user1.NamedPet)
var user3 User
DB.Model(&User{}).Joins("NamedPet").Joins("Account").First(&user3, user.ID)
if user3.NamedPet != nil || user2.Account.ID != 0 {
t.Fatalf("joins NamedPet and Account should not empty:%v", user2)
}
}

View File

@ -43,10 +43,8 @@ func TestExceptionsWithInvalidSql(t *testing.T) {
func TestSetAndGet(t *testing.T) { func TestSetAndGet(t *testing.T) {
if value, ok := DB.Set("hello", "world").Get("hello"); !ok { if value, ok := DB.Set("hello", "world").Get("hello"); !ok {
t.Errorf("Should be able to get setting after set") t.Errorf("Should be able to get setting after set")
} else { } else if value.(string) != "world" {
if value.(string) != "world" { t.Errorf("Set value should not be changed")
t.Errorf("Setted value should not be changed")
}
} }
if _, ok := DB.Get("non_existing"); ok { if _, ok := DB.Get("non_existing"); ok {

View File

@ -45,7 +45,7 @@ func TestMigrate(t *testing.T) {
for _, m := range allModels { for _, m := range allModels {
if !DB.Migrator().HasTable(m) { if !DB.Migrator().HasTable(m) {
t.Fatalf("Failed to create table for %#v---", m) t.Fatalf("Failed to create table for %#v", m)
} }
} }
@ -92,7 +92,7 @@ func TestAutoMigrateSelfReferential(t *testing.T) {
} }
func TestSmartMigrateColumn(t *testing.T) { func TestSmartMigrateColumn(t *testing.T) {
fullSupported := map[string]bool{"mysql": true}[DB.Dialector.Name()] fullSupported := map[string]bool{"mysql": true, "postgres": true}[DB.Dialector.Name()]
type UserMigrateColumn struct { type UserMigrateColumn struct {
ID uint ID uint
@ -258,7 +258,7 @@ func TestMigrateTable(t *testing.T) {
DB.Migrator().DropTable("new_table_structs") DB.Migrator().DropTable("new_table_structs")
if DB.Migrator().HasTable(&NewTableStruct{}) { if DB.Migrator().HasTable(&NewTableStruct{}) {
t.Fatal("should not found droped table") t.Fatal("should not found dropped table")
} }
} }
@ -313,9 +313,16 @@ func TestMigrateIndexes(t *testing.T) {
} }
func TestMigrateColumns(t *testing.T) { func TestMigrateColumns(t *testing.T) {
sqlite := DB.Dialector.Name() == "sqlite"
sqlserver := DB.Dialector.Name() == "sqlserver"
type ColumnStruct struct { type ColumnStruct struct {
gorm.Model gorm.Model
Name string Name string
Age int `gorm:"default:18;comment:my age"`
Code string `gorm:"unique;comment:my code;"`
Code2 string
Code3 string `gorm:"unique"`
} }
DB.Migrator().DropTable(&ColumnStruct{}) DB.Migrator().DropTable(&ColumnStruct{})
@ -326,13 +333,20 @@ func TestMigrateColumns(t *testing.T) {
type ColumnStruct2 struct { type ColumnStruct2 struct {
gorm.Model gorm.Model
Name string `gorm:"size:100"` Name string `gorm:"size:100"`
Code string `gorm:"unique;comment:my code2;default:hello"`
Code2 string `gorm:"unique"`
// Code3 string
} }
if err := DB.Table("column_structs").Migrator().AlterColumn(&ColumnStruct2{}, "Name"); err != nil { if err := DB.Table("column_structs").Migrator().AlterColumn(&ColumnStruct{}, "Name"); err != nil {
t.Fatalf("no error should happened when alter column, but got %v", err) t.Fatalf("no error should happened when alter column, but got %v", err)
} }
if err := DB.Table("column_structs").AutoMigrate(&ColumnStruct2{}); err != nil {
t.Fatalf("no error should happened when auto migrate column, but got %v", err)
}
if columnTypes, err := DB.Migrator().ColumnTypes(&ColumnStruct{}); err != nil { if columnTypes, err := DB.Migrator().ColumnTypes(&ColumnStruct{}); err != nil {
t.Fatalf("no error should returns for ColumnTypes") t.Fatalf("no error should returns for ColumnTypes")
} else { } else {
@ -340,11 +354,45 @@ func TestMigrateColumns(t *testing.T) {
stmt.Parse(&ColumnStruct2{}) stmt.Parse(&ColumnStruct2{})
for _, columnType := range columnTypes { for _, columnType := range columnTypes {
if columnType.Name() == "name" { switch columnType.Name() {
case "id":
if v, ok := columnType.PrimaryKey(); !ok || !v {
t.Fatalf("column id primary key should be correct, name: %v, column: %#v", columnType.Name(), columnType)
}
case "name":
dataType := DB.Dialector.DataTypeOf(stmt.Schema.LookUpField(columnType.Name())) dataType := DB.Dialector.DataTypeOf(stmt.Schema.LookUpField(columnType.Name()))
if !strings.Contains(strings.ToUpper(dataType), strings.ToUpper(columnType.DatabaseTypeName())) { if !strings.Contains(strings.ToUpper(dataType), strings.ToUpper(columnType.DatabaseTypeName())) {
t.Errorf("column type should be correct, name: %v, length: %v, expects: %v", columnType.Name(), columnType.DatabaseTypeName(), dataType) t.Fatalf("column name type should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), columnType.DatabaseTypeName(), dataType, columnType)
} }
if length, ok := columnType.Length(); !sqlite && (!ok || length != 100) {
t.Fatalf("column name length should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), length, 100, columnType)
}
case "age":
if v, ok := columnType.DefaultValue(); !ok || v != "18" {
t.Fatalf("column age default value should be correct, name: %v, column: %#v", columnType.Name(), columnType)
}
if v, ok := columnType.Comment(); !sqlite && !sqlserver && (!ok || v != "my age") {
t.Fatalf("column age comment should be correct, name: %v, column: %#v", columnType.Name(), columnType)
}
case "code":
if v, ok := columnType.Unique(); !ok || !v {
t.Fatalf("column code unique should be correct, name: %v, column: %#v", columnType.Name(), columnType)
}
if v, ok := columnType.DefaultValue(); !sqlserver && (!ok || v != "hello") {
t.Fatalf("column code default value should be correct, name: %v, column: %#v", columnType.Name(), columnType)
}
if v, ok := columnType.Comment(); !sqlite && !sqlserver && (!ok || v != "my code2") {
t.Fatalf("column code comment should be correct, name: %v, column: %#v", columnType.Name(), columnType)
}
case "code2":
if v, ok := columnType.Unique(); !sqlserver && (!ok || !v) {
t.Fatalf("column code2 unique should be correct, name: %v, column: %#v", columnType.Name(), columnType)
}
case "code3":
// TODO
// if v, ok := columnType.Unique(); !ok || v {
// t.Fatalf("column code unique should be correct, name: %v, column: %#v", columnType.Name(), columnType)
// }
} }
} }
} }
@ -526,3 +574,122 @@ func TestMigrateColumnOrder(t *testing.T) {
} }
} }
} }
// https://github.com/go-gorm/gorm/issues/5047
func TestMigrateSerialColumn(t *testing.T) {
if DB.Dialector.Name() != "postgres" {
return
}
type Event struct {
ID uint `gorm:"primarykey"`
UID uint32
}
type Event1 struct {
ID uint `gorm:"primarykey"`
UID uint32 `gorm:"not null;autoIncrement"`
}
type Event2 struct {
ID uint `gorm:"primarykey"`
UID uint16 `gorm:"not null;autoIncrement"`
}
var err error
err = DB.Migrator().DropTable(&Event{})
if err != nil {
t.Errorf("DropTable err:%v", err)
}
// create sequence
err = DB.Table("events").AutoMigrate(&Event1{})
if err != nil {
t.Errorf("AutoMigrate err:%v", err)
}
// delete sequence
err = DB.Table("events").AutoMigrate(&Event{})
if err != nil {
t.Errorf("AutoMigrate err:%v", err)
}
// update sequence
err = DB.Table("events").AutoMigrate(&Event1{})
if err != nil {
t.Errorf("AutoMigrate err:%v", err)
}
err = DB.Table("events").AutoMigrate(&Event2{})
if err != nil {
t.Errorf("AutoMigrate err:%v", err)
}
DB.Table("events").Save(&Event2{})
DB.Table("events").Save(&Event2{})
DB.Table("events").Save(&Event2{})
events := make([]*Event, 0)
DB.Table("events").Find(&events)
AssertEqual(t, 3, len(events))
for _, v := range events {
AssertEqual(t, v.ID, v.UID)
}
}
// https://github.com/go-gorm/gorm/issues/5300
func TestMigrateWithSpecialName(t *testing.T) {
var err error
err = DB.AutoMigrate(&Coupon{})
if err != nil {
t.Fatalf("AutoMigrate err:%v", err)
}
err = DB.Table("coupon_product_1").AutoMigrate(&CouponProduct{})
if err != nil {
t.Fatalf("AutoMigrate err:%v", err)
}
err = DB.Table("coupon_product_2").AutoMigrate(&CouponProduct{})
if err != nil {
t.Fatalf("AutoMigrate err:%v", err)
}
AssertEqual(t, true, DB.Migrator().HasTable("coupons"))
AssertEqual(t, true, DB.Migrator().HasTable("coupon_product_1"))
AssertEqual(t, true, DB.Migrator().HasTable("coupon_product_2"))
}
// https://github.com/go-gorm/gorm/issues/5320
func TestPrimarykeyID(t *testing.T) {
if DB.Dialector.Name() != "postgres" {
return
}
type MissPKLanguage struct {
ID string `gorm:"type:uuid;default:uuid_generate_v4()"`
Name string
}
type MissPKUser struct {
ID string `gorm:"type:uuid;default:uuid_generate_v4()"`
MissPKLanguages []MissPKLanguage `gorm:"many2many:miss_pk_user_languages;"`
}
var err error
err = DB.Migrator().DropTable(&MissPKUser{}, &MissPKLanguage{})
if err != nil {
t.Fatalf("DropTable err:%v", err)
}
DB.Exec(`CREATE EXTENSION IF NOT EXISTS "uuid-ossp";`)
err = DB.AutoMigrate(&MissPKUser{}, &MissPKLanguage{})
if err != nil {
t.Fatalf("AutoMigrate err:%v", err)
}
// patch
err = DB.AutoMigrate(&MissPKUser{}, &MissPKLanguage{})
if err != nil {
t.Fatalf("AutoMigrate err:%v", err)
}
}

View File

@ -2,6 +2,7 @@ package tests_test
import ( import (
"testing" "testing"
"time"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/lib/pq" "github.com/lib/pq"
@ -15,9 +16,11 @@ func TestPostgres(t *testing.T) {
type Harumph struct { type Harumph struct {
gorm.Model gorm.Model
Name string `gorm:"check:name_checker,name <> ''"` Name string `gorm:"check:name_checker,name <> ''"`
Test uuid.UUID `gorm:"type:uuid;not null;default:gen_random_uuid()"` Test uuid.UUID `gorm:"type:uuid;not null;default:gen_random_uuid()"`
Things pq.StringArray `gorm:"type:text[]"` CreatedAt time.Time `gorm:"type:TIMESTAMP WITHOUT TIME ZONE"`
UpdatedAt time.Time `gorm:"type:TIMESTAMP WITHOUT TIME ZONE;default:current_timestamp"`
Things pq.StringArray `gorm:"type:text[]"`
} }
if err := DB.Exec("CREATE EXTENSION IF NOT EXISTS pgcrypto;").Error; err != nil { if err := DB.Exec("CREATE EXTENSION IF NOT EXISTS pgcrypto;").Error; err != nil {
@ -48,6 +51,15 @@ func TestPostgres(t *testing.T) {
if err := DB.Where("id = $1", harumph.ID).First(&Harumph{}).Error; err != nil || harumph.Name != "jinzhu" { if err := DB.Where("id = $1", harumph.ID).First(&Harumph{}).Error; err != nil || harumph.Name != "jinzhu" {
t.Errorf("No error should happen, but got %v", err) t.Errorf("No error should happen, but got %v", err)
} }
harumph.Name = "jinzhu1"
if err := DB.Save(&harumph).Error; err != nil {
t.Errorf("Failed to update date, got error %v", err)
}
if err := DB.First(&result, "id = ?", harumph.ID).Error; err != nil || harumph.Name != "jinzhu1" {
t.Errorf("No error should happen, but got %v", err)
}
} }
type Post struct { type Post struct {

View File

@ -1335,7 +1335,7 @@ func TestNilPointerSlice(t *testing.T) {
} }
if !reflect.DeepEqual(got[0], want) && !reflect.DeepEqual(got[1], want) { if !reflect.DeepEqual(got[0], want) && !reflect.DeepEqual(got[1], want) {
t.Errorf("got %s; want array containing %s", toJSONString(got), toJSONString(want)) t.Fatalf("got %s; want array containing %s", toJSONString(got), toJSONString(want))
} }
if !reflect.DeepEqual(got[0], want2) && !reflect.DeepEqual(got[1], want2) { if !reflect.DeepEqual(got[0], want2) && !reflect.DeepEqual(got[1], want2) {

View File

@ -251,3 +251,21 @@ func TestPreloadGoroutine(t *testing.T) {
} }
wg.Wait() wg.Wait()
} }
func TestPreloadWithDiffModel(t *testing.T) {
user := *GetUser("preload_with_diff_model", Config{Account: true})
if err := DB.Create(&user).Error; err != nil {
t.Fatalf("errors happened when create: %v", err)
}
var result struct {
Something string
User
}
DB.Model(User{}).Preload("Account", clause.Eq{Column: "number", Value: user.Account.Number}).Select(
"users.*, 'yo' as something").First(&result, "name = ?", user.Name)
CheckUser(t, user, result.User)
}

View File

@ -292,6 +292,68 @@ func TestFindInBatches(t *testing.T) {
} }
} }
func TestFindInBatchesWithOffsetLimit(t *testing.T) {
users := []User{
*GetUser("find_in_batches_with_offset_limit", Config{}),
*GetUser("find_in_batches_with_offset_limit", Config{}),
*GetUser("find_in_batches_with_offset_limit", Config{}),
*GetUser("find_in_batches_with_offset_limit", Config{}),
*GetUser("find_in_batches_with_offset_limit", Config{}),
*GetUser("find_in_batches_with_offset_limit", Config{}),
*GetUser("find_in_batches_with_offset_limit", Config{}),
*GetUser("find_in_batches_with_offset_limit", Config{}),
*GetUser("find_in_batches_with_offset_limit", Config{}),
*GetUser("find_in_batches_with_offset_limit", Config{}),
}
DB.Create(&users)
var (
sub, results []User
lastBatch int
)
// offset limit
if result := DB.Offset(3).Limit(5).Where("name = ?", users[0].Name).FindInBatches(&sub, 2, func(tx *gorm.DB, batch int) error {
results = append(results, sub...)
lastBatch = batch
return nil
}); result.Error != nil || result.RowsAffected != 5 {
t.Errorf("Failed to batch find, got error %v, rows affected: %v", result.Error, result.RowsAffected)
}
if lastBatch != 3 {
t.Fatalf("incorrect last batch, expected: %v, got: %v", 3, lastBatch)
}
targetUsers := users[3:8]
for i := 0; i < len(targetUsers); i++ {
AssertEqual(t, results[i], targetUsers[i])
}
var sub1 []User
// limit < batchSize
if result := DB.Limit(5).Where("name = ?", users[0].Name).FindInBatches(&sub1, 10, func(tx *gorm.DB, batch int) error {
return nil
}); result.Error != nil || result.RowsAffected != 5 {
t.Errorf("Failed to batch find, got error %v, rows affected: %v", result.Error, result.RowsAffected)
}
var sub2 []User
// only offset
if result := DB.Offset(3).Where("name = ?", users[0].Name).FindInBatches(&sub2, 2, func(tx *gorm.DB, batch int) error {
return nil
}); result.Error != nil || result.RowsAffected != 7 {
t.Errorf("Failed to batch find, got error %v, rows affected: %v", result.Error, result.RowsAffected)
}
var sub3 []User
if result := DB.Limit(4).Where("name = ?", users[0].Name).FindInBatches(&sub3, 2, func(tx *gorm.DB, batch int) error {
return nil
}); result.Error != nil || result.RowsAffected != 4 {
t.Errorf("Failed to batch find, got error %v, rows affected: %v", result.Error, result.RowsAffected)
}
}
func TestFindInBatchesWithError(t *testing.T) { func TestFindInBatchesWithError(t *testing.T) {
if name := DB.Dialector.Name(); name == "sqlserver" { if name := DB.Dialector.Name(); name == "sqlserver" {
t.Skip("skip sqlserver due to it will raise data race for invalid sql") t.Skip("skip sqlserver due to it will raise data race for invalid sql")
@ -512,7 +574,13 @@ func TestNotWithAllFields(t *testing.T) {
func TestOr(t *testing.T) { func TestOr(t *testing.T) {
dryDB := DB.Session(&gorm.Session{DryRun: true}) dryDB := DB.Session(&gorm.Session{DryRun: true})
result := dryDB.Where("role = ?", "admin").Where(DB.Or("role = ?", "super_admin")).Find(&User{}) var count int64
result := dryDB.Model(&User{}).Or("role = ?", "admin").Count(&count)
if !regexp.MustCompile("SELECT count\\(\\*\\) FROM .*users.* WHERE role = .+ AND .*users.*\\..*deleted_at.* IS NULL").MatchString(result.Statement.SQL.String()) {
t.Fatalf("Build OR condition, but got %v", result.Statement.SQL.String())
}
result = dryDB.Where("role = ?", "admin").Where(DB.Or("role = ?", "super_admin")).Find(&User{})
if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*role.* = .+ AND .*role.* = .+").MatchString(result.Statement.SQL.String()) { if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*role.* = .+ AND .*role.* = .+").MatchString(result.Statement.SQL.String()) {
t.Fatalf("Build OR condition, but got %v", result.Statement.SQL.String()) t.Fatalf("Build OR condition, but got %v", result.Statement.SQL.String())
} }
@ -577,7 +645,9 @@ func TestPluck(t *testing.T) {
if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Order("name desc").Pluck("name", &names2).Error; err != nil { if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Order("name desc").Pluck("name", &names2).Error; err != nil {
t.Errorf("got error when pluck name: %v", err) t.Errorf("got error when pluck name: %v", err)
} }
AssertEqual(t, names, sort.Reverse(sort.StringSlice(names2)))
sort.Slice(names2, func(i, j int) bool { return names2[i] < names2[j] })
AssertEqual(t, names, names2)
var ids []int var ids []int
if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Pluck("id", &ids).Error; err != nil { if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Pluck("id", &ids).Error; err != nil {
@ -1152,3 +1222,39 @@ func TestQueryWithTableAndConditionsAndAllFields(t *testing.T) {
t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String()) t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String())
} }
} }
type DoubleInt64 struct {
data int64
}
func (t *DoubleInt64) Scan(val interface{}) error {
switch v := val.(type) {
case int64:
t.data = v * 2
return nil
default:
return fmt.Errorf("DoubleInt64 cant not scan with:%v", v)
}
}
// https://github.com/go-gorm/gorm/issues/5091
func TestQueryScannerWithSingleColumn(t *testing.T) {
user := User{Name: "scanner_raw_1", Age: 10}
DB.Create(&user)
var result1 DoubleInt64
if err := DB.Model(&User{}).Where("name LIKE ?", "scanner_raw_%").Limit(1).Pluck(
"age", &result1).Error; err != nil {
t.Errorf("Failed, got error: %v", err)
}
AssertEqual(t, result1.data, 20)
var result2 DoubleInt64
if err := DB.Model(&User{}).Where("name LIKE ?", "scanner_raw_%").Limit(1).Select(
"age").Scan(&result2).Error; err != nil {
t.Errorf("Failed, got error: %v", err)
}
AssertEqual(t, result2.data, 20)
}

View File

@ -10,6 +10,11 @@ import (
. "gorm.io/gorm/utils/tests" . "gorm.io/gorm/utils/tests"
) )
type PersonAddressInfo struct {
Person *Person `gorm:"embedded"`
Address *Address `gorm:"embedded"`
}
func TestScan(t *testing.T) { func TestScan(t *testing.T) {
user1 := User{Name: "ScanUser1", Age: 1} user1 := User{Name: "ScanUser1", Age: 1}
user2 := User{Name: "ScanUser2", Age: 10} user2 := User{Name: "ScanUser2", Age: 10}
@ -156,3 +161,57 @@ func TestScanRows(t *testing.T) {
t.Fatalf("failed to scan ages, got error %v, ages: %v", err, name) t.Fatalf("failed to scan ages, got error %v, ages: %v", err, name)
} }
} }
func TestScanToEmbedded(t *testing.T) {
person1 := Person{Name: "person 1"}
person2 := Person{Name: "person 2"}
DB.Save(&person1).Save(&person2)
address1 := Address{Name: "address 1"}
address2 := Address{Name: "address 2"}
DB.Save(&address1).Save(&address2)
DB.Create(&PersonAddress{PersonID: person1.ID, AddressID: int(address1.ID)})
DB.Create(&PersonAddress{PersonID: person1.ID, AddressID: int(address2.ID)})
DB.Create(&PersonAddress{PersonID: person2.ID, AddressID: int(address1.ID)})
var personAddressInfoList []*PersonAddressInfo
if err := DB.Select("people.*, addresses.*").
Table("people").
Joins("inner join person_addresses on people.id = person_addresses.person_id").
Joins("inner join addresses on person_addresses.address_id = addresses.id").
Find(&personAddressInfoList).Error; err != nil {
t.Errorf("Failed to run join query, got error: %v", err)
}
personMatched := false
addressMatched := false
for _, info := range personAddressInfoList {
if info.Person == nil {
t.Fatalf("Failed, expected not nil, got person nil")
}
if info.Address == nil {
t.Fatalf("Failed, expected not nil, got address nil")
}
if info.Person.ID == person1.ID {
personMatched = true
if info.Person.Name != person1.Name {
t.Errorf("Failed, expected %v, got %v", person1.Name, info.Person.Name)
}
}
if info.Address.ID == address1.ID {
addressMatched = true
if info.Address.Name != address1.Name {
t.Errorf("Failed, expected %v, got %v", address1.Name, info.Address.Name)
}
}
}
if !personMatched {
t.Errorf("Failed, no person matched")
}
if !addressMatched {
t.Errorf("Failed, no address matched")
}
}

138
tests/serializer_test.go Normal file
View File

@ -0,0 +1,138 @@
package tests_test
import (
"bytes"
"context"
"fmt"
"reflect"
"strings"
"testing"
"time"
"gorm.io/gorm"
"gorm.io/gorm/schema"
. "gorm.io/gorm/utils/tests"
)
type SerializerStruct struct {
gorm.Model
Name []byte `gorm:"json"`
Roles Roles `gorm:"serializer:json"`
Contracts map[string]interface{} `gorm:"serializer:json"`
JobInfo Job `gorm:"type:bytes;serializer:gob"`
CreatedTime int64 `gorm:"serializer:unixtime;type:time"` // store time in db, use int as field type
UpdatedTime *int64 `gorm:"serializer:unixtime;type:time"` // store time in db, use int as field type
EncryptedString EncryptedString
}
type Roles []string
type Job struct {
Title string
Number int
Location string
IsIntern bool
}
type EncryptedString string
func (es *EncryptedString) Scan(ctx context.Context, field *schema.Field, dst reflect.Value, dbValue interface{}) (err error) {
switch value := dbValue.(type) {
case []byte:
*es = EncryptedString(bytes.TrimPrefix(value, []byte("hello")))
case string:
*es = EncryptedString(strings.TrimPrefix(value, "hello"))
default:
return fmt.Errorf("unsupported data %#v", dbValue)
}
return nil
}
func (es EncryptedString) Value(ctx context.Context, field *schema.Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) {
return "hello" + string(es), nil
}
func TestSerializer(t *testing.T) {
DB.Migrator().DropTable(&SerializerStruct{})
if err := DB.Migrator().AutoMigrate(&SerializerStruct{}); err != nil {
t.Fatalf("no error should happen when migrate scanner, valuer struct, got error %v", err)
}
createdAt := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC)
updatedAt := createdAt.Unix()
data := SerializerStruct{
Name: []byte("jinzhu"),
Roles: []string{"r1", "r2"},
Contracts: map[string]interface{}{"name": "jinzhu", "age": 10},
EncryptedString: EncryptedString("pass"),
CreatedTime: createdAt.Unix(),
UpdatedTime: &updatedAt,
JobInfo: Job{
Title: "programmer",
Number: 9920,
Location: "Kenmawr",
IsIntern: false,
},
}
if err := DB.Create(&data).Error; err != nil {
t.Fatalf("failed to create data, got error %v", err)
}
var result SerializerStruct
if err := DB.First(&result, data.ID).Error; err != nil {
t.Fatalf("failed to query data, got error %v", err)
}
AssertEqual(t, result, data)
}
func TestSerializerAssignFirstOrCreate(t *testing.T) {
DB.Migrator().DropTable(&SerializerStruct{})
if err := DB.Migrator().AutoMigrate(&SerializerStruct{}); err != nil {
t.Fatalf("no error should happen when migrate scanner, valuer struct, got error %v", err)
}
createdAt := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC)
data := SerializerStruct{
Name: []byte("ag9920"),
Roles: []string{"r1", "r2"},
Contracts: map[string]interface{}{"name": "jing1", "age": 11},
EncryptedString: EncryptedString("pass"),
CreatedTime: createdAt.Unix(),
JobInfo: Job{
Title: "programmer",
Number: 9920,
Location: "Shadyside",
IsIntern: false,
},
}
// first time insert record
out := SerializerStruct{}
if err := DB.Assign(data).FirstOrCreate(&out).Error; err != nil {
t.Fatalf("failed to FirstOrCreate Assigned data, got error %v", err)
}
var result SerializerStruct
if err := DB.First(&result, out.ID).Error; err != nil {
t.Fatalf("failed to query data, got error %v", err)
}
AssertEqual(t, result, out)
//update record
data.Roles = append(data.Roles, "r3")
data.JobInfo.Location = "Gates Hillman Complex"
if err := DB.Assign(data).FirstOrCreate(&out).Error; err != nil {
t.Fatalf("failed to FirstOrCreate Assigned data, got error %v", err)
}
if err := DB.First(&result, out.ID).Error; err != nil {
t.Fatalf("failed to query data, got error %v", err)
}
AssertEqual(t, result.Roles, data.Roles)
AssertEqual(t, result.JobInfo.Location, data.JobInfo.Location)
}

View File

@ -168,6 +168,59 @@ func TestDryRun(t *testing.T) {
} }
} }
type ageInt int8
func (ageInt) String() string {
return "age"
}
type ageBool bool
func (ageBool) String() string {
return "age"
}
type ageUint64 uint64
func (ageUint64) String() string {
return "age"
}
type ageFloat float64
func (ageFloat) String() string {
return "age"
}
func TestExplainSQL(t *testing.T) {
user := *GetUser("explain-sql", Config{})
dryRunDB := DB.Session(&gorm.Session{DryRun: true})
stmt := dryRunDB.Model(&user).Where("id = ?", 1).Updates(map[string]interface{}{"age": ageInt(8)}).Statement
sql := DB.Dialector.Explain(stmt.SQL.String(), stmt.Vars...)
if !regexp.MustCompile(`.*age.*=8,`).MatchString(sql) {
t.Errorf("Failed to generate sql, got %v", sql)
}
stmt = dryRunDB.Model(&user).Where("id = ?", 1).Updates(map[string]interface{}{"age": ageUint64(10241024)}).Statement
sql = DB.Dialector.Explain(stmt.SQL.String(), stmt.Vars...)
if !regexp.MustCompile(`.*age.*=10241024,`).MatchString(sql) {
t.Errorf("Failed to generate sql, got %v", sql)
}
stmt = dryRunDB.Model(&user).Where("id = ?", 1).Updates(map[string]interface{}{"age": ageBool(false)}).Statement
sql = DB.Dialector.Explain(stmt.SQL.String(), stmt.Vars...)
if !regexp.MustCompile(`.*age.*=false,`).MatchString(sql) {
t.Errorf("Failed to generate sql, got %v", sql)
}
stmt = dryRunDB.Model(&user).Where("id = ?", 1).Updates(map[string]interface{}{"age": ageFloat(0.12345678)}).Statement
sql = DB.Dialector.Explain(stmt.SQL.String(), stmt.Vars...)
if !regexp.MustCompile(`.*age.*=0.123457,`).MatchString(sql) {
t.Errorf("Failed to generate sql, got %v", sql)
}
}
func TestGroupConditions(t *testing.T) { func TestGroupConditions(t *testing.T) {
type Pizza struct { type Pizza struct {
ID uint ID uint
@ -190,6 +243,21 @@ func TestGroupConditions(t *testing.T) {
if !strings.HasSuffix(result, expects) { if !strings.HasSuffix(result, expects) {
t.Errorf("expects: %v, got %v", expects, result) t.Errorf("expects: %v, got %v", expects, result)
} }
stmt2 := dryRunDB.Where(
DB.Scopes(NameIn1And2),
).Or(
DB.Where("pizza = ?", "hawaiian").Where("size = ?", "xlarge"),
).Find(&Pizza{}).Statement
execStmt2 := dryRunDB.Exec(`WHERE name in ? OR (pizza = ? AND size = ?)`, []string{"ScopeUser1", "ScopeUser2"}, "hawaiian", "xlarge").Statement
result2 := DB.Dialector.Explain(stmt2.SQL.String(), stmt2.Vars...)
expects2 := DB.Dialector.Explain(execStmt2.SQL.String(), execStmt2.Vars...)
if !strings.HasSuffix(result2, expects2) {
t.Errorf("expects: %v, got %v", expects2, result2)
}
} }
func TestCombineStringConditions(t *testing.T) { func TestCombineStringConditions(t *testing.T) {
@ -307,7 +375,7 @@ func TestToSQL(t *testing.T) {
}) })
assertEqualSQL(t, `SELECT * FROM "users" WHERE id = 100 AND "users"."deleted_at" IS NULL ORDER BY age desc LIMIT 10`, sql) assertEqualSQL(t, `SELECT * FROM "users" WHERE id = 100 AND "users"."deleted_at" IS NULL ORDER BY age desc LIMIT 10`, sql)
// after model chagned // after model changed
if DB.Statement.DryRun || DB.DryRun { if DB.Statement.DryRun || DB.DryRun {
t.Fatal("Failed expect DB.DryRun and DB.Statement.ToSQL to be false") t.Fatal("Failed expect DB.DryRun and DB.Statement.ToSQL to be false")
} }
@ -373,13 +441,13 @@ func TestToSQL(t *testing.T) {
}) })
assertEqualSQL(t, `UPDATE "users" SET "name"='Foo',"age"=100 WHERE id = 100 AND "users"."deleted_at" IS NULL`, sql) assertEqualSQL(t, `UPDATE "users" SET "name"='Foo',"age"=100 WHERE id = 100 AND "users"."deleted_at" IS NULL`, sql)
// after model chagned // after model changed
if DB.Statement.DryRun || DB.DryRun { if DB.Statement.DryRun || DB.DryRun {
t.Fatal("Failed expect DB.DryRun and DB.Statement.ToSQL to be false") t.Fatal("Failed expect DB.DryRun and DB.Statement.ToSQL to be false")
} }
} }
// assertEqualSQL for assert that the sql is equal, this method will ignore quote, and dialect speicals. // assertEqualSQL for assert that the sql is equal, this method will ignore quote, and dialect specials.
func assertEqualSQL(t *testing.T, expected string, actually string) { func assertEqualSQL(t *testing.T, expected string, actually string) {
t.Helper() t.Helper()
@ -387,7 +455,7 @@ func assertEqualSQL(t *testing.T, expected string, actually string) {
expected = replaceQuoteInSQL(expected) expected = replaceQuoteInSQL(expected)
actually = replaceQuoteInSQL(actually) actually = replaceQuoteInSQL(actually)
// ignore updated_at value, becase it's generated in Gorm inernal, can't to mock value on update. // ignore updated_at value, because it's generated in Gorm internal, can't to mock value on update.
updatedAtRe := regexp.MustCompile(`(?i)"updated_at"=".+?"`) updatedAtRe := regexp.MustCompile(`(?i)"updated_at"=".+?"`)
actually = updatedAtRe.ReplaceAllString(actually, `"updated_at"=?`) actually = updatedAtRe.ReplaceAllString(actually, `"updated_at"=?`)
expected = updatedAtRe.ReplaceAllString(expected, `"updated_at"=?`) expected = updatedAtRe.ReplaceAllString(expected, `"updated_at"=?`)
@ -407,16 +475,16 @@ func assertEqualSQL(t *testing.T, expected string, actually string) {
func replaceQuoteInSQL(sql string) string { func replaceQuoteInSQL(sql string) string {
// convert single quote into double quote // convert single quote into double quote
sql = strings.Replace(sql, `'`, `"`, -1) sql = strings.ReplaceAll(sql, `'`, `"`)
// convert dialect speical quote into double quote // convert dialect special quote into double quote
switch DB.Dialector.Name() { switch DB.Dialector.Name() {
case "postgres": case "postgres":
sql = strings.Replace(sql, `"`, `"`, -1) sql = strings.ReplaceAll(sql, `"`, `"`)
case "mysql", "sqlite": case "mysql", "sqlite":
sql = strings.Replace(sql, "`", `"`, -1) sql = strings.ReplaceAll(sql, "`", `"`)
case "sqlserver": case "sqlserver":
sql = strings.Replace(sql, `'`, `"`, -1) sql = strings.ReplaceAll(sql, `'`, `"`)
} }
return sql return sql

View File

@ -15,6 +15,25 @@ then
cd .. cd ..
fi fi
# SqlServer for Mac M1
if [[ -z $GITHUB_ACTION ]]; then
if [ -d tests ]
then
cd tests
if [[ $(uname -a) == *" arm64" ]]; then
MSSQL_IMAGE=mcr.microsoft.com/azure-sql-edge docker-compose start || true
go install github.com/microsoft/go-sqlcmd/cmd/sqlcmd@latest || true
SQLCMDPASSWORD=LoremIpsum86 sqlcmd -U sa -S localhost:9930 -Q "IF DB_ID('gorm') IS NULL CREATE DATABASE gorm" > /dev/null || true
SQLCMDPASSWORD=LoremIpsum86 sqlcmd -U sa -S localhost:9930 -Q "IF SUSER_ID (N'gorm') IS NULL CREATE LOGIN gorm WITH PASSWORD = 'LoremIpsum86';" > /dev/null || true
SQLCMDPASSWORD=LoremIpsum86 sqlcmd -U sa -S localhost:9930 -Q "IF USER_ID (N'gorm') IS NULL CREATE USER gorm FROM LOGIN gorm; ALTER SERVER ROLE sysadmin ADD MEMBER [gorm];" > /dev/null || true
else
docker-compose start
fi
cd ..
fi
fi
for dialect in "${dialects[@]}" ; do for dialect in "${dialects[@]}" ; do
if [ "$GORM_DIALECT" = "" ] || [ "$GORM_DIALECT" = "${dialect}" ] if [ "$GORM_DIALECT" = "" ] || [ "$GORM_DIALECT" = "${dialect}" ]
then then

View File

@ -62,13 +62,14 @@ func OpenTestConnection() (db *gorm.DB, err error) {
PreferSimpleProtocol: true, PreferSimpleProtocol: true,
}), &gorm.Config{}) }), &gorm.Config{})
case "sqlserver": case "sqlserver":
// CREATE LOGIN gorm WITH PASSWORD = 'LoremIpsum86'; // go install github.com/microsoft/go-sqlcmd/cmd/sqlcmd@latest
// SQLCMDPASSWORD=LoremIpsum86 sqlcmd -U sa -S localhost:9930
// CREATE DATABASE gorm; // CREATE DATABASE gorm;
// USE gorm; // GO
// CREATE LOGIN gorm WITH PASSWORD = 'LoremIpsum86';
// CREATE USER gorm FROM LOGIN gorm; // CREATE USER gorm FROM LOGIN gorm;
// sp_changedbowner 'gorm'; // ALTER SERVER ROLE sysadmin ADD MEMBER [gorm];
// npm install -g sql-cli // GO
// mssql -u gorm -p LoremIpsum86 -d gorm -o 9930
log.Println("testing sqlserver...") log.Println("testing sqlserver...")
if dbDSN == "" { if dbDSN == "" {
dbDSN = "sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm" dbDSN = "sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm"
@ -94,7 +95,7 @@ func OpenTestConnection() (db *gorm.DB, err error) {
func RunMigrations() { func RunMigrations() {
var err error var err error
allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Coupon{}, &CouponProduct{}, &Order{}} allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Coupon{}, &CouponProduct{}, &Order{}, &Parent{}, &Child{}}
rand.Seed(time.Now().UnixNano()) rand.Seed(time.Now().UnixNano())
rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] }) rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] })

View File

@ -645,7 +645,7 @@ func TestSave(t *testing.T) {
dryDB := DB.Session(&gorm.Session{DryRun: true}) dryDB := DB.Session(&gorm.Session{DryRun: true})
stmt := dryDB.Save(&user).Statement stmt := dryDB.Save(&user).Statement
if !regexp.MustCompile(`.id. = .* AND .users.\..deleted_at. IS NULL`).MatchString(stmt.SQL.String()) { if !regexp.MustCompile(`.users.\..deleted_at. IS NULL`).MatchString(stmt.SQL.String()) {
t.Fatalf("invalid updating SQL, got %v", stmt.SQL.String()) t.Fatalf("invalid updating SQL, got %v", stmt.SQL.String())
} }

View File

@ -319,7 +319,7 @@ func TestUpdateWithMissWhere(t *testing.T) {
tx := DB.Session(&gorm.Session{DryRun: true}).Save(&user) tx := DB.Session(&gorm.Session{DryRun: true}).Save(&user)
if err := tx.Error; err != nil { if err := tx.Error; err != nil {
t.Fatalf("failed to update user,missing where condtion,err=%+v", err) t.Fatalf("failed to update user,missing where condition,err=%+v", err)
} }
if !regexp.MustCompile("WHERE .id. = [^ ]+$").MatchString(tx.Statement.SQL.String()) { if !regexp.MustCompile("WHERE .id. = [^ ]+$").MatchString(tx.Statement.SQL.String()) {

View File

@ -49,7 +49,7 @@ func (DummyDialector) QuoteTo(writer clause.Writer, str string) {
shiftDelimiter = 0 shiftDelimiter = 0
underQuoted = false underQuoted = false
continuousBacktick = 0 continuousBacktick = 0
writer.WriteString("`") writer.WriteByte('`')
} }
writer.WriteByte(v) writer.WriteByte(v)
continue continue
@ -74,7 +74,7 @@ func (DummyDialector) QuoteTo(writer clause.Writer, str string) {
if continuousBacktick > 0 && !selfQuoted { if continuousBacktick > 0 && !selfQuoted {
writer.WriteString("``") writer.WriteString("``")
} }
writer.WriteString("`") writer.WriteByte('`')
} }
func (DummyDialector) Explain(sql string, vars ...interface{}) string { func (DummyDialector) Explain(sql string, vars ...interface{}) string {

View File

@ -80,3 +80,17 @@ type Order struct {
Coupon *Coupon Coupon *Coupon
CouponID string CouponID string
} }
type Parent struct {
gorm.Model
FavChildID uint
FavChild *Child
Children []*Child
}
type Child struct {
gorm.Model
Name string
ParentID *uint
Parent *Parent
}

View File

@ -83,20 +83,22 @@ func AssertEqual(t *testing.T, got, expect interface{}) {
} }
if reflect.ValueOf(got).Kind() == reflect.Struct { if reflect.ValueOf(got).Kind() == reflect.Struct {
if reflect.ValueOf(got).NumField() == reflect.ValueOf(expect).NumField() { if reflect.ValueOf(expect).Kind() == reflect.Struct {
exported := false if reflect.ValueOf(got).NumField() == reflect.ValueOf(expect).NumField() {
for i := 0; i < reflect.ValueOf(got).NumField(); i++ { exported := false
if fieldStruct := reflect.ValueOf(got).Type().Field(i); ast.IsExported(fieldStruct.Name) { for i := 0; i < reflect.ValueOf(got).NumField(); i++ {
exported = true if fieldStruct := reflect.ValueOf(got).Type().Field(i); ast.IsExported(fieldStruct.Name) {
field := reflect.ValueOf(got).Field(i) exported = true
t.Run(fieldStruct.Name, func(t *testing.T) { field := reflect.ValueOf(got).Field(i)
AssertEqual(t, field.Interface(), reflect.ValueOf(expect).Field(i).Interface()) t.Run(fieldStruct.Name, func(t *testing.T) {
}) AssertEqual(t, field.Interface(), reflect.ValueOf(expect).Field(i).Interface())
})
}
} }
}
if exported { if exported {
return return
}
} }
} }
} }
@ -107,6 +109,9 @@ func AssertEqual(t *testing.T, got, expect interface{}) {
} else if reflect.ValueOf(expect).Type().ConvertibleTo(reflect.ValueOf(got).Type()) { } else if reflect.ValueOf(expect).Type().ConvertibleTo(reflect.ValueOf(got).Type()) {
expect = reflect.ValueOf(got).Convert(reflect.ValueOf(got).Type()).Interface() expect = reflect.ValueOf(got).Convert(reflect.ValueOf(got).Type()).Interface()
isEqual() isEqual()
} else {
t.Errorf("%v: expect: %+v, got %+v", utils.FileWithLineNum(), expect, got)
return
} }
} }
} }

View File

@ -36,17 +36,14 @@ func IsValidDBNameChar(c rune) bool {
return !unicode.IsLetter(c) && !unicode.IsNumber(c) && c != '.' && c != '*' && c != '_' && c != '$' && c != '@' return !unicode.IsLetter(c) && !unicode.IsNumber(c) && c != '.' && c != '*' && c != '_' && c != '$' && c != '@'
} }
func CheckTruth(val interface{}) bool { // CheckTruth check string true or not
if v, ok := val.(bool); ok { func CheckTruth(vals ...string) bool {
return v for _, val := range vals {
if val != "" && !strings.EqualFold(val, "false") {
return true
}
} }
return false
if v, ok := val.(string); ok {
v = strings.ToLower(v)
return v != "false"
}
return !reflect.ValueOf(val).IsZero()
} }
func ToStringKey(values ...interface{}) string { func ToStringKey(values ...interface{}) string {