Merge branch 'go-gorm:master' into master

This commit is contained in:
piyongcai 2022-05-09 16:15:34 +08:00 committed by GitHub
commit e0030749b6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
79 changed files with 2728 additions and 761 deletions

View File

@ -3,20 +3,26 @@ on:
schedule:
- cron: "*/10 * * * *"
permissions:
contents: read
jobs:
stale:
permissions:
issues: write # for actions/stale to close stale issues
pull-requests: write # for actions/stale to close stale PRs
runs-on: ubuntu-latest
env:
ACTIONS_STEP_DEBUG: true
steps:
- name: Close Stale Issues
uses: actions/stale@v4
uses: actions/stale@v5
with:
repo-token: ${{ secrets.GITHUB_TOKEN }}
stale-issue-message: "This issue has been marked as invalid question, please give more information by following the `Question` template, if you believe there is a bug of GORM, please create a pull request that could reproduce the issue on [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground), the issue will be closed in 2 days if no further activity occurs. most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨"
stale-issue-message: "This issue has been marked as invalid question, please give more information by following the `Question` template, if you believe there is a bug of GORM, please create a pull request that could reproduce the issue on [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground), the issue will be closed in 30 days if no further activity occurs. most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨"
stale-issue-label: "status:stale"
days-before-stale: 0
days-before-close: 2
days-before-close: 30
remove-stale-when-updated: true
only-labels: "type:invalid question"

View File

@ -11,7 +11,7 @@ jobs:
name: Label issues and pull requests
steps:
- name: check out
uses: actions/checkout@v2
uses: actions/checkout@v3
- name: labeler
uses: jinzhu/super-labeler-action@develop

View File

@ -3,19 +3,25 @@ on:
schedule:
- cron: "*/10 * * * *"
permissions:
contents: read
jobs:
stale:
permissions:
issues: write # for actions/stale to close stale issues
pull-requests: write # for actions/stale to close stale PRs
runs-on: ubuntu-latest
env:
ACTIONS_STEP_DEBUG: true
steps:
- name: Close Stale Issues
uses: actions/stale@v4
uses: actions/stale@v5
with:
repo-token: ${{ secrets.GITHUB_TOKEN }}
stale-issue-message: "The issue has been automatically marked as stale as it missing playground pull request link, which is important to help others understand your issue effectively and make sure the issue hasn't been fixed on latest master, checkout [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground) for details. it will be closed in 2 days if no further activity occurs. if you are asking question, please use the `Question` template, most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨"
stale-issue-message: "The issue has been automatically marked as stale as it missing playground pull request link, which is important to help others understand your issue effectively and make sure the issue hasn't been fixed on latest master, checkout [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground) for details. it will be closed in 30 days if no further activity occurs. if you are asking question, please use the `Question` template, most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨"
stale-issue-label: "status:stale"
days-before-stale: 0
days-before-close: 2
days-before-close: 30
remove-stale-when-updated: true
only-labels: "type:missing reproduction steps"

View File

@ -6,7 +6,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Check out code into the Go module directory
uses: actions/checkout@v2
uses: actions/checkout@v3
- name: golangci-lint
uses: reviewdog/action-golangci-lint@v2

View File

@ -3,19 +3,25 @@ on:
schedule:
- cron: "0 2 * * *"
permissions:
contents: read
jobs:
stale:
permissions:
issues: write # for actions/stale to close stale issues
pull-requests: write # for actions/stale to close stale PRs
runs-on: ubuntu-latest
env:
ACTIONS_STEP_DEBUG: true
steps:
- name: Close Stale Issues
uses: actions/stale@v4
uses: actions/stale@v5
with:
repo-token: ${{ secrets.GITHUB_TOKEN }}
stale-issue-message: "This issue has been automatically marked as stale because it has been open 60 days with no activity. Remove stale label or comment or this will be closed in 30 days"
days-before-stale: 60
days-before-close: 30
stale-issue-message: "This issue has been automatically marked as stale because it has been open 360 days with no activity. Remove stale label or comment or this will be closed in 180 days"
days-before-stale: 360
days-before-close: 180
stale-issue-label: "status:stale"
exempt-issue-labels: 'type:feature,type:with reproduction steps,type:has pull request'
stale-pr-label: 'status:stale'

View File

@ -8,38 +8,41 @@ on:
branches-ignore:
- 'gh-pages'
permissions:
contents: read
jobs:
# Label of the container job
sqlite:
strategy:
matrix:
go: ['1.17', '1.16']
go: ['1.18', '1.17', '1.16']
platform: [ubuntu-latest] # can not run in windows OS
runs-on: ${{ matrix.platform }}
steps:
- name: Set up Go 1.x
uses: actions/setup-go@v2
uses: actions/setup-go@v3
with:
go-version: ${{ matrix.go }}
- name: Check out code into the Go module directory
uses: actions/checkout@v2
uses: actions/checkout@v3
- name: go mod package cache
uses: actions/cache@v2
uses: actions/cache@v3
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
- name: Tests
run: GORM_DIALECT=sqlite ./tests/tests_all.sh
run: GITHUB_ACTION=true GORM_DIALECT=sqlite ./tests/tests_all.sh
mysql:
strategy:
matrix:
dbversion: ['mysql:latest', 'mysql:5.7', 'mariadb:latest']
go: ['1.17', '1.16']
go: ['1.18', '1.17', '1.16']
platform: [ubuntu-latest]
runs-on: ${{ matrix.platform }}
@ -62,28 +65,28 @@ jobs:
steps:
- name: Set up Go 1.x
uses: actions/setup-go@v2
uses: actions/setup-go@v3
with:
go-version: ${{ matrix.go }}
- name: Check out code into the Go module directory
uses: actions/checkout@v2
uses: actions/checkout@v3
- name: go mod package cache
uses: actions/cache@v2
uses: actions/cache@v3
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
- name: Tests
run: GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True" ./tests/tests_all.sh
run: GITHUB_ACTION=true GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True" ./tests/tests_all.sh
postgres:
strategy:
matrix:
dbversion: ['postgres:latest', 'postgres:13', 'postgres:12', 'postgres:11', 'postgres:10']
go: ['1.17', '1.16']
go: ['1.18', '1.17', '1.16']
platform: [ubuntu-latest] # can not run in macOS and Windows
runs-on: ${{ matrix.platform }}
@ -106,26 +109,26 @@ jobs:
steps:
- name: Set up Go 1.x
uses: actions/setup-go@v2
uses: actions/setup-go@v3
with:
go-version: ${{ matrix.go }}
- name: Check out code into the Go module directory
uses: actions/checkout@v2
uses: actions/checkout@v3
- name: go mod package cache
uses: actions/cache@v2
uses: actions/cache@v3
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
- name: Tests
run: GORM_DIALECT=postgres GORM_DSN="user=gorm password=gorm dbname=gorm host=localhost port=9920 sslmode=disable TimeZone=Asia/Shanghai" ./tests/tests_all.sh
run: GITHUB_ACTION=true GORM_DIALECT=postgres GORM_DSN="user=gorm password=gorm dbname=gorm host=localhost port=9920 sslmode=disable TimeZone=Asia/Shanghai" ./tests/tests_all.sh
sqlserver:
strategy:
matrix:
go: ['1.17', '1.16']
go: ['1.18', '1.17', '1.16']
platform: [ubuntu-latest] # can not run test in macOS and windows
runs-on: ${{ matrix.platform }}
@ -149,18 +152,18 @@ jobs:
steps:
- name: Set up Go 1.x
uses: actions/setup-go@v2
uses: actions/setup-go@v3
with:
go-version: ${{ matrix.go }}
- name: Check out code into the Go module directory
uses: actions/checkout@v2
uses: actions/checkout@v3
- name: go mod package cache
uses: actions/cache@v2
uses: actions/cache@v3
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
- name: Tests
run: GORM_DIALECT=sqlserver GORM_DSN="sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm" ./tests/tests_all.sh
run: GITHUB_ACTION=true GORM_DIALECT=sqlserver GORM_DSN="sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm" ./tests/tests_all.sh

1
.gitignore vendored
View File

@ -3,3 +3,4 @@ documents
coverage.txt
_book
.idea
vendor

View File

@ -30,6 +30,7 @@ The fantastic ORM library for Golang, aims to be developer friendly.
## Getting Started
* GORM Guides [https://gorm.io](https://gorm.io)
* GORM Gen [gorm/gen](https://github.com/go-gorm/gen#gormgen)
## Contributing

View File

@ -79,10 +79,10 @@ func (association *Association) Replace(values ...interface{}) error {
switch reflectValue.Kind() {
case reflect.Slice, reflect.Array:
for i := 0; i < reflectValue.Len(); i++ {
association.Error = rel.Field.Set(reflectValue.Index(i), reflect.Zero(rel.Field.FieldType).Interface())
association.Error = rel.Field.Set(association.DB.Statement.Context, reflectValue.Index(i), reflect.Zero(rel.Field.FieldType).Interface())
}
case reflect.Struct:
association.Error = rel.Field.Set(reflectValue, reflect.Zero(rel.Field.FieldType).Interface())
association.Error = rel.Field.Set(association.DB.Statement.Context, reflectValue, reflect.Zero(rel.Field.FieldType).Interface())
}
for _, ref := range rel.References {
@ -96,12 +96,12 @@ func (association *Association) Replace(values ...interface{}) error {
primaryFields []*schema.Field
foreignKeys []string
updateMap = map[string]interface{}{}
relValues = schema.GetRelationsValues(reflectValue, []*schema.Relationship{rel})
relValues = schema.GetRelationsValues(association.DB.Statement.Context, reflectValue, []*schema.Relationship{rel})
modelValue = reflect.New(rel.FieldSchema.ModelType).Interface()
tx = association.DB.Model(modelValue)
)
if _, rvs := schema.GetIdentityFieldValuesMap(relValues, rel.FieldSchema.PrimaryFields); len(rvs) > 0 {
if _, rvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, relValues, rel.FieldSchema.PrimaryFields); len(rvs) > 0 {
if column, values := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs); len(values) > 0 {
tx.Not(clause.IN{Column: column, Values: values})
}
@ -117,7 +117,7 @@ func (association *Association) Replace(values ...interface{}) error {
}
}
if _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields); len(pvs) > 0 {
if _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields); len(pvs) > 0 {
column, values := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs)
association.Error = tx.Where(clause.IN{Column: column, Values: values}).UpdateColumns(updateMap).Error
}
@ -143,14 +143,14 @@ func (association *Association) Replace(values ...interface{}) error {
}
}
_, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields)
_, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields)
if column, values := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs); len(values) > 0 {
tx.Where(clause.IN{Column: column, Values: values})
} else {
return ErrPrimaryKeyRequired
}
_, rvs := schema.GetIdentityFieldValuesMapFromValues(values, relPrimaryFields)
_, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, relPrimaryFields)
if relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs); len(relValues) > 0 {
tx.Where(clause.Not(clause.IN{Column: relColumn, Values: relValues}))
}
@ -186,11 +186,14 @@ func (association *Association) Delete(values ...interface{}) error {
case schema.BelongsTo:
tx := association.DB.Model(reflect.New(rel.Schema.ModelType).Interface())
_, pvs := schema.GetIdentityFieldValuesMap(reflectValue, rel.Schema.PrimaryFields)
pcolumn, pvalues := schema.ToQueryValues(rel.Schema.Table, rel.Schema.PrimaryFieldDBNames, pvs)
conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues})
_, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, rel.Schema.PrimaryFields)
if pcolumn, pvalues := schema.ToQueryValues(rel.Schema.Table, rel.Schema.PrimaryFieldDBNames, pvs); len(pvalues) > 0 {
conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues})
} else {
return ErrPrimaryKeyRequired
}
_, rvs := schema.GetIdentityFieldValuesMapFromValues(values, primaryFields)
_, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, primaryFields)
relColumn, relValues := schema.ToQueryValues(rel.Schema.Table, foreignKeys, rvs)
conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
@ -198,11 +201,14 @@ func (association *Association) Delete(values ...interface{}) error {
case schema.HasOne, schema.HasMany:
tx := association.DB.Model(reflect.New(rel.FieldSchema.ModelType).Interface())
_, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields)
pcolumn, pvalues := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs)
conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues})
_, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields)
if pcolumn, pvalues := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs); len(pvalues) > 0 {
conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues})
} else {
return ErrPrimaryKeyRequired
}
_, rvs := schema.GetIdentityFieldValuesMapFromValues(values, rel.FieldSchema.PrimaryFields)
_, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, rel.FieldSchema.PrimaryFields)
relColumn, relValues := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs)
conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
@ -228,11 +234,14 @@ func (association *Association) Delete(values ...interface{}) error {
}
}
_, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields)
pcolumn, pvalues := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs)
conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues})
_, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields)
if pcolumn, pvalues := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs); len(pvalues) > 0 {
conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues})
} else {
return ErrPrimaryKeyRequired
}
_, rvs := schema.GetIdentityFieldValuesMapFromValues(values, relPrimaryFields)
_, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, relPrimaryFields)
relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs)
conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
@ -241,11 +250,11 @@ func (association *Association) Delete(values ...interface{}) error {
if association.Error == nil {
// clean up deleted values's foreign key
relValuesMap, _ := schema.GetIdentityFieldValuesMapFromValues(values, rel.FieldSchema.PrimaryFields)
relValuesMap, _ := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, rel.FieldSchema.PrimaryFields)
cleanUpDeletedRelations := func(data reflect.Value) {
if _, zero := rel.Field.ValueOf(data); !zero {
fieldValue := reflect.Indirect(rel.Field.ReflectValueOf(data))
if _, zero := rel.Field.ValueOf(association.DB.Statement.Context, data); !zero {
fieldValue := reflect.Indirect(rel.Field.ReflectValueOf(association.DB.Statement.Context, data))
primaryValues := make([]interface{}, len(rel.FieldSchema.PrimaryFields))
switch fieldValue.Kind() {
@ -253,7 +262,7 @@ func (association *Association) Delete(values ...interface{}) error {
validFieldValues := reflect.Zero(rel.Field.IndirectFieldType)
for i := 0; i < fieldValue.Len(); i++ {
for idx, field := range rel.FieldSchema.PrimaryFields {
primaryValues[idx], _ = field.ValueOf(fieldValue.Index(i))
primaryValues[idx], _ = field.ValueOf(association.DB.Statement.Context, fieldValue.Index(i))
}
if _, ok := relValuesMap[utils.ToStringKey(primaryValues...)]; !ok {
@ -261,23 +270,23 @@ func (association *Association) Delete(values ...interface{}) error {
}
}
association.Error = rel.Field.Set(data, validFieldValues.Interface())
association.Error = rel.Field.Set(association.DB.Statement.Context, data, validFieldValues.Interface())
case reflect.Struct:
for idx, field := range rel.FieldSchema.PrimaryFields {
primaryValues[idx], _ = field.ValueOf(fieldValue)
primaryValues[idx], _ = field.ValueOf(association.DB.Statement.Context, fieldValue)
}
if _, ok := relValuesMap[utils.ToStringKey(primaryValues...)]; ok {
if association.Error = rel.Field.Set(data, reflect.Zero(rel.FieldSchema.ModelType).Interface()); association.Error != nil {
if association.Error = rel.Field.Set(association.DB.Statement.Context, data, reflect.Zero(rel.FieldSchema.ModelType).Interface()); association.Error != nil {
break
}
if rel.JoinTable == nil {
for _, ref := range rel.References {
if ref.OwnPrimaryKey || ref.PrimaryValue != "" {
association.Error = ref.ForeignKey.Set(fieldValue, reflect.Zero(ref.ForeignKey.FieldType).Interface())
association.Error = ref.ForeignKey.Set(association.DB.Statement.Context, fieldValue, reflect.Zero(ref.ForeignKey.FieldType).Interface())
} else {
association.Error = ref.ForeignKey.Set(data, reflect.Zero(ref.ForeignKey.FieldType).Interface())
association.Error = ref.ForeignKey.Set(association.DB.Statement.Context, data, reflect.Zero(ref.ForeignKey.FieldType).Interface())
}
}
}
@ -329,14 +338,14 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
switch rv.Kind() {
case reflect.Slice, reflect.Array:
if rv.Len() > 0 {
association.Error = association.Relationship.Field.Set(source, rv.Index(0).Addr().Interface())
association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, source, rv.Index(0).Addr().Interface())
if association.Relationship.Field.FieldType.Kind() == reflect.Struct {
assignBacks = append(assignBacks, assignBack{Source: source, Dest: rv.Index(0)})
}
}
case reflect.Struct:
association.Error = association.Relationship.Field.Set(source, rv.Addr().Interface())
association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, source, rv.Addr().Interface())
if association.Relationship.Field.FieldType.Kind() == reflect.Struct {
assignBacks = append(assignBacks, assignBack{Source: source, Dest: rv})
@ -344,7 +353,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
}
case schema.HasMany, schema.Many2Many:
elemType := association.Relationship.Field.IndirectFieldType.Elem()
fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(source))
fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(association.DB.Statement.Context, source))
if clear {
fieldValue = reflect.New(association.Relationship.Field.IndirectFieldType).Elem()
}
@ -373,7 +382,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
}
if association.Error == nil {
association.Error = association.Relationship.Field.Set(source, fieldValue.Interface())
association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, source, fieldValue.Interface())
}
}
}
@ -421,7 +430,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
// clear old data
if clear && len(values) == 0 {
for i := 0; i < reflectValue.Len(); i++ {
if err := association.Relationship.Field.Set(reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType).Interface()); err != nil {
if err := association.Relationship.Field.Set(association.DB.Statement.Context, reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType).Interface()); err != nil {
association.Error = err
break
}
@ -429,7 +438,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
if association.Relationship.JoinTable == nil {
for _, ref := range association.Relationship.References {
if !ref.OwnPrimaryKey && ref.PrimaryValue == "" {
if err := ref.ForeignKey.Set(reflectValue.Index(i), reflect.Zero(ref.ForeignKey.FieldType).Interface()); err != nil {
if err := ref.ForeignKey.Set(association.DB.Statement.Context, reflectValue.Index(i), reflect.Zero(ref.ForeignKey.FieldType).Interface()); err != nil {
association.Error = err
break
}
@ -453,12 +462,12 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
case reflect.Struct:
// clear old data
if clear && len(values) == 0 {
association.Error = association.Relationship.Field.Set(reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType).Interface())
association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType).Interface())
if association.Relationship.JoinTable == nil && association.Error == nil {
for _, ref := range association.Relationship.References {
if !ref.OwnPrimaryKey && ref.PrimaryValue == "" {
association.Error = ref.ForeignKey.Set(reflectValue, reflect.Zero(ref.ForeignKey.FieldType).Interface())
association.Error = ref.ForeignKey.Set(association.DB.Statement.Context, reflectValue, reflect.Zero(ref.ForeignKey.FieldType).Interface())
}
}
}
@ -475,7 +484,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
}
for _, assignBack := range assignBacks {
fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(assignBack.Source))
fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(association.DB.Statement.Context, assignBack.Source))
if assignBack.Index > 0 {
reflect.Indirect(assignBack.Dest).Set(fieldValue.Index(assignBack.Index - 1))
} else {
@ -486,14 +495,14 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
func (association *Association) buildCondition() *DB {
var (
queryConds = association.Relationship.ToQueryConditions(association.DB.Statement.ReflectValue)
queryConds = association.Relationship.ToQueryConditions(association.DB.Statement.Context, association.DB.Statement.ReflectValue)
modelValue = reflect.New(association.Relationship.FieldSchema.ModelType).Interface()
tx = association.DB.Model(modelValue)
)
if association.Relationship.JoinTable != nil {
if !tx.Statement.Unscoped && len(association.Relationship.JoinTable.QueryClauses) > 0 {
joinStmt := Statement{DB: tx, Schema: association.Relationship.JoinTable, Table: association.Relationship.JoinTable.Table, Clauses: map[string]clause.Clause{}}
joinStmt := Statement{DB: tx, Context: tx.Statement.Context, Schema: association.Relationship.JoinTable, Table: association.Relationship.JoinTable.Table, Clauses: map[string]clause.Clause{}}
for _, queryClause := range association.Relationship.JoinTable.QueryClauses {
joinStmt.AddClause(queryClause)
}

View File

@ -246,7 +246,13 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {
sortCallback func(*callback) error
)
sort.Slice(cs, func(i, j int) bool {
return cs[j].before == "*" || cs[j].after == "*"
if cs[j].before == "*" && cs[i].before != "*" {
return true
}
if cs[j].after == "*" && cs[i].after != "*" {
return true
}
return false
})
for _, c := range cs {

View File

@ -24,8 +24,8 @@ func SaveBeforeAssociations(create bool) func(db *gorm.DB) {
setupReferences := func(obj reflect.Value, elem reflect.Value) {
for _, ref := range rel.References {
if !ref.OwnPrimaryKey {
pv, _ := ref.PrimaryKey.ValueOf(elem)
db.AddError(ref.ForeignKey.Set(obj, pv))
pv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, elem)
db.AddError(ref.ForeignKey.Set(db.Statement.Context, obj, pv))
if dest, ok := db.Statement.Dest.(map[string]interface{}); ok {
dest[ref.ForeignKey.DBName] = pv
@ -57,8 +57,8 @@ func SaveBeforeAssociations(create bool) func(db *gorm.DB) {
break
}
if _, zero := rel.Field.ValueOf(obj); !zero { // check belongs to relation value
rv := rel.Field.ReflectValueOf(obj) // relation reflect value
if _, zero := rel.Field.ValueOf(db.Statement.Context, obj); !zero { // check belongs to relation value
rv := rel.Field.ReflectValueOf(db.Statement.Context, obj) // relation reflect value
objs = append(objs, obj)
if isPtr {
elems = reflect.Append(elems, rv)
@ -69,20 +69,20 @@ func SaveBeforeAssociations(create bool) func(db *gorm.DB) {
}
if elems.Len() > 0 {
if saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, nil) == nil {
if saveAssociations(db, rel, elems, selectColumns, restricted, nil) == nil {
for i := 0; i < elems.Len(); i++ {
setupReferences(objs[i], elems.Index(i))
}
}
}
case reflect.Struct:
if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero {
rv := rel.Field.ReflectValueOf(db.Statement.ReflectValue) // relation reflect value
if _, zero := rel.Field.ValueOf(db.Statement.Context, db.Statement.ReflectValue); !zero {
rv := rel.Field.ReflectValueOf(db.Statement.Context, db.Statement.ReflectValue) // relation reflect value
if rv.Kind() != reflect.Ptr {
rv = rv.Addr()
}
if saveAssociations(db, rel, rv.Interface(), selectColumns, restricted, nil) == nil {
if saveAssociations(db, rel, rv, selectColumns, restricted, nil) == nil {
setupReferences(db.Statement.ReflectValue, rv)
}
}
@ -120,18 +120,18 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
obj := db.Statement.ReflectValue.Index(i)
if reflect.Indirect(obj).Kind() == reflect.Struct {
if _, zero := rel.Field.ValueOf(obj); !zero {
rv := rel.Field.ReflectValueOf(obj)
if _, zero := rel.Field.ValueOf(db.Statement.Context, obj); !zero {
rv := rel.Field.ReflectValueOf(db.Statement.Context, obj)
if rv.Kind() != reflect.Ptr {
rv = rv.Addr()
}
for _, ref := range rel.References {
if ref.OwnPrimaryKey {
fv, _ := ref.PrimaryKey.ValueOf(obj)
db.AddError(ref.ForeignKey.Set(rv, fv))
fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, obj)
db.AddError(ref.ForeignKey.Set(db.Statement.Context, rv, fv))
} else if ref.PrimaryValue != "" {
db.AddError(ref.ForeignKey.Set(rv, ref.PrimaryValue))
db.AddError(ref.ForeignKey.Set(db.Statement.Context, rv, ref.PrimaryValue))
}
}
@ -146,11 +146,11 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
}
saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, assignmentColumns)
saveAssociations(db, rel, elems, selectColumns, restricted, assignmentColumns)
}
case reflect.Struct:
if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero {
f := rel.Field.ReflectValueOf(db.Statement.ReflectValue)
if _, zero := rel.Field.ValueOf(db.Statement.Context, db.Statement.ReflectValue); !zero {
f := rel.Field.ReflectValueOf(db.Statement.Context, db.Statement.ReflectValue)
if f.Kind() != reflect.Ptr {
f = f.Addr()
}
@ -158,15 +158,15 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
assignmentColumns := make([]string, 0, len(rel.References))
for _, ref := range rel.References {
if ref.OwnPrimaryKey {
fv, _ := ref.PrimaryKey.ValueOf(db.Statement.ReflectValue)
ref.ForeignKey.Set(f, fv)
fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, db.Statement.ReflectValue)
db.AddError(ref.ForeignKey.Set(db.Statement.Context, f, fv))
} else if ref.PrimaryValue != "" {
ref.ForeignKey.Set(f, ref.PrimaryValue)
db.AddError(ref.ForeignKey.Set(db.Statement.Context, f, ref.PrimaryValue))
}
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
}
saveAssociations(db, rel, f.Interface(), selectColumns, restricted, assignmentColumns)
saveAssociations(db, rel, f, selectColumns, restricted, assignmentColumns)
}
}
}
@ -185,23 +185,23 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
identityMap := map[string]bool{}
appendToElems := func(v reflect.Value) {
if _, zero := rel.Field.ValueOf(v); !zero {
f := reflect.Indirect(rel.Field.ReflectValueOf(v))
if _, zero := rel.Field.ValueOf(db.Statement.Context, v); !zero {
f := reflect.Indirect(rel.Field.ReflectValueOf(db.Statement.Context, v))
for i := 0; i < f.Len(); i++ {
elem := f.Index(i)
for _, ref := range rel.References {
if ref.OwnPrimaryKey {
pv, _ := ref.PrimaryKey.ValueOf(v)
ref.ForeignKey.Set(elem, pv)
pv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, v)
db.AddError(ref.ForeignKey.Set(db.Statement.Context, elem, pv))
} else if ref.PrimaryValue != "" {
ref.ForeignKey.Set(elem, ref.PrimaryValue)
db.AddError(ref.ForeignKey.Set(db.Statement.Context, elem, ref.PrimaryValue))
}
}
relPrimaryValues := make([]interface{}, 0, len(rel.FieldSchema.PrimaryFields))
for _, pf := range rel.FieldSchema.PrimaryFields {
if pfv, ok := pf.ValueOf(elem); !ok {
if pfv, ok := pf.ValueOf(db.Statement.Context, elem); !ok {
relPrimaryValues = append(relPrimaryValues, pfv)
}
}
@ -237,7 +237,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
}
saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, assignmentColumns)
saveAssociations(db, rel, elems, selectColumns, restricted, assignmentColumns)
}
}
@ -260,21 +260,21 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
joinValue := reflect.New(rel.JoinTable.ModelType)
for _, ref := range rel.References {
if ref.OwnPrimaryKey {
fv, _ := ref.PrimaryKey.ValueOf(obj)
ref.ForeignKey.Set(joinValue, fv)
fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, obj)
db.AddError(ref.ForeignKey.Set(db.Statement.Context, joinValue, fv))
} else if ref.PrimaryValue != "" {
ref.ForeignKey.Set(joinValue, ref.PrimaryValue)
db.AddError(ref.ForeignKey.Set(db.Statement.Context, joinValue, ref.PrimaryValue))
} else {
fv, _ := ref.PrimaryKey.ValueOf(elem)
ref.ForeignKey.Set(joinValue, fv)
fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, elem)
db.AddError(ref.ForeignKey.Set(db.Statement.Context, joinValue, fv))
}
}
joins = reflect.Append(joins, joinValue)
}
appendToElems := func(v reflect.Value) {
if _, zero := rel.Field.ValueOf(v); !zero {
f := reflect.Indirect(rel.Field.ReflectValueOf(v))
if _, zero := rel.Field.ValueOf(db.Statement.Context, v); !zero {
f := reflect.Indirect(rel.Field.ReflectValueOf(db.Statement.Context, v))
for i := 0; i < f.Len(); i++ {
elem := f.Index(i)
@ -304,7 +304,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
// optimize elems of reflect value length
if elemLen := elems.Len(); elemLen > 0 {
if v, ok := selectColumns[rel.Name+".*"]; !ok || v {
saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, nil)
saveAssociations(db, rel, elems, selectColumns, restricted, nil)
}
for i := 0; i < elemLen; i++ {
@ -323,7 +323,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
}
}
func onConflictOption(stmt *gorm.Statement, s *schema.Schema, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) (onConflict clause.OnConflict) {
func onConflictOption(stmt *gorm.Statement, s *schema.Schema, defaultUpdatingColumns []string) (onConflict clause.OnConflict) {
if len(defaultUpdatingColumns) > 0 || stmt.DB.FullSaveAssociations {
onConflict.Columns = make([]clause.Column, 0, len(s.PrimaryFieldDBNames))
for _, dbName := range s.PrimaryFieldDBNames {
@ -341,11 +341,17 @@ func onConflictOption(stmt *gorm.Statement, s *schema.Schema, selectColumns map[
return
}
func saveAssociations(db *gorm.DB, rel *schema.Relationship, values interface{}, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) error {
func saveAssociations(db *gorm.DB, rel *schema.Relationship, rValues reflect.Value, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) error {
// stop save association loop
if checkAssociationsSaved(db, rValues) {
return nil
}
var (
selects, omits []string
onConflict = onConflictOption(db.Statement, rel.FieldSchema, selectColumns, restricted, defaultUpdatingColumns)
onConflict = onConflictOption(db.Statement, rel.FieldSchema, defaultUpdatingColumns)
refName = rel.Name + "."
values = rValues.Interface()
)
for name, ok := range selectColumns {
@ -390,3 +396,24 @@ func saveAssociations(db *gorm.DB, rel *schema.Relationship, values interface{},
return db.AddError(tx.Create(values).Error)
}
// check association values has been saved
// if values kind is Struct, check it has been saved
// if values kind is Slice/Array, check all items have been saved
var visitMapStoreKey = "gorm:saved_association_map"
func checkAssociationsSaved(db *gorm.DB, values reflect.Value) bool {
if visit, ok := db.Get(visitMapStoreKey); ok {
if v, ok := visit.(*visitMap); ok {
if loadOrStoreVisitMap(v, values) {
return true
}
}
} else {
vistMap := make(visitMap)
loadOrStoreVisitMap(&vistMap, values)
db.Set(visitMapStoreKey, &vistMap)
}
return false
}

View File

@ -10,6 +10,7 @@ import (
"gorm.io/gorm/utils"
)
// BeforeCreate before create hooks
func BeforeCreate(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) {
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
@ -31,6 +32,7 @@ func BeforeCreate(db *gorm.DB) {
}
}
// Create create hook
func Create(config *Config) func(db *gorm.DB) {
supportReturning := utils.Contains(config.CreateClauses, "RETURNING")
@ -82,8 +84,10 @@ func Create(config *Config) func(db *gorm.DB) {
db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...,
)
if db.AddError(err) == nil {
defer func() {
db.AddError(rows.Close())
}()
gorm.Scan(rows, db, mode)
db.AddError(rows.Close())
}
return
@ -117,9 +121,9 @@ func Create(config *Config) func(db *gorm.DB) {
break
}
_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv)
_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv)
if isZero {
db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID)
db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, insertID))
insertID -= db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement
}
}
@ -130,38 +134,39 @@ func Create(config *Config) func(db *gorm.DB) {
break
}
if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv); isZero {
db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID)
if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv); isZero {
db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, insertID))
insertID += db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement
}
}
}
case reflect.Struct:
_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue)
_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, db.Statement.ReflectValue)
if isZero {
db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID)
db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, db.Statement.ReflectValue, insertID))
}
}
}
}
}
// AfterCreate after create hooks
func AfterCreate(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) {
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
if db.Statement.Schema.AfterSave {
if i, ok := value.(AfterSaveInterface); ok {
called = true
db.AddError(i.AfterSave(tx))
}
}
if db.Statement.Schema.AfterCreate {
if i, ok := value.(AfterCreateInterface); ok {
called = true
db.AddError(i.AfterCreate(tx))
}
}
if db.Statement.Schema.AfterSave {
if i, ok := value.(AfterSaveInterface); ok {
called = true
db.AddError(i.AfterSave(tx))
}
}
return called
})
}
@ -201,13 +206,15 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
switch stmt.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
rValLen := stmt.ReflectValue.Len()
stmt.SQL.Grow(rValLen * 18)
values.Values = make([][]interface{}, rValLen)
if rValLen == 0 {
stmt.AddError(gorm.ErrEmptySlice)
return
}
stmt.SQL.Grow(rValLen * 18)
stmt.Vars = make([]interface{}, 0, rValLen*len(values.Columns))
values.Values = make([][]interface{}, rValLen)
defaultValueFieldsHavingValue := map[*schema.Field][]interface{}{}
for i := 0; i < rValLen; i++ {
rv := reflect.Indirect(stmt.ReflectValue.Index(i))
@ -219,23 +226,23 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
values.Values[i] = make([]interface{}, len(values.Columns))
for idx, column := range values.Columns {
field := stmt.Schema.FieldsByDBName[column.Name]
if values.Values[i][idx], isZero = field.ValueOf(rv); isZero {
if values.Values[i][idx], isZero = field.ValueOf(stmt.Context, rv); isZero {
if field.DefaultValueInterface != nil {
values.Values[i][idx] = field.DefaultValueInterface
field.Set(rv, field.DefaultValueInterface)
stmt.AddError(field.Set(stmt.Context, rv, field.DefaultValueInterface))
} else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 {
field.Set(rv, curTime)
values.Values[i][idx], _ = field.ValueOf(rv)
stmt.AddError(field.Set(stmt.Context, rv, curTime))
values.Values[i][idx], _ = field.ValueOf(stmt.Context, rv)
}
} else if field.AutoUpdateTime > 0 && updateTrackTime {
field.Set(rv, curTime)
values.Values[i][idx], _ = field.ValueOf(rv)
stmt.AddError(field.Set(stmt.Context, rv, curTime))
values.Values[i][idx], _ = field.ValueOf(stmt.Context, rv)
}
}
for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
if rvOfvalue, isZero := field.ValueOf(rv); !isZero {
if rvOfvalue, isZero := field.ValueOf(stmt.Context, rv); !isZero {
if len(defaultValueFieldsHavingValue[field]) == 0 {
defaultValueFieldsHavingValue[field] = make([]interface{}, rValLen)
}
@ -259,23 +266,23 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
values.Values = [][]interface{}{make([]interface{}, len(values.Columns))}
for idx, column := range values.Columns {
field := stmt.Schema.FieldsByDBName[column.Name]
if values.Values[0][idx], isZero = field.ValueOf(stmt.ReflectValue); isZero {
if values.Values[0][idx], isZero = field.ValueOf(stmt.Context, stmt.ReflectValue); isZero {
if field.DefaultValueInterface != nil {
values.Values[0][idx] = field.DefaultValueInterface
field.Set(stmt.ReflectValue, field.DefaultValueInterface)
stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, field.DefaultValueInterface))
} else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 {
field.Set(stmt.ReflectValue, curTime)
values.Values[0][idx], _ = field.ValueOf(stmt.ReflectValue)
stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, curTime))
values.Values[0][idx], _ = field.ValueOf(stmt.Context, stmt.ReflectValue)
}
} else if field.AutoUpdateTime > 0 && updateTrackTime {
field.Set(stmt.ReflectValue, curTime)
values.Values[0][idx], _ = field.ValueOf(stmt.ReflectValue)
stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, curTime))
values.Values[0][idx], _ = field.ValueOf(stmt.Context, stmt.ReflectValue)
}
}
for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
if rvOfvalue, isZero := field.ValueOf(stmt.ReflectValue); !isZero {
if rvOfvalue, isZero := field.ValueOf(stmt.Context, stmt.ReflectValue); !isZero {
values.Columns = append(values.Columns, clause.Column{Name: field.DBName})
values.Values[0] = append(values.Values[0], rvOfvalue)
}

View File

@ -42,7 +42,7 @@ func DeleteBeforeAssociations(db *gorm.DB) {
switch rel.Type {
case schema.HasOne, schema.HasMany:
queryConds := rel.ToQueryConditions(db.Statement.ReflectValue)
queryConds := rel.ToQueryConditions(db.Statement.Context, db.Statement.ReflectValue)
modelValue := reflect.New(rel.FieldSchema.ModelType).Interface()
tx := db.Session(&gorm.Session{NewDB: true}).Model(modelValue)
withoutConditions := false
@ -97,7 +97,7 @@ func DeleteBeforeAssociations(db *gorm.DB) {
}
}
_, foreignValues := schema.GetIdentityFieldValuesMap(db.Statement.ReflectValue, foreignFields)
_, foreignValues := schema.GetIdentityFieldValuesMap(db.Statement.Context, db.Statement.ReflectValue, foreignFields)
column, values := schema.ToQueryValues(table, relForeignKeys, foreignValues)
queryConds = append(queryConds, clause.IN{Column: column, Values: values})
@ -118,12 +118,18 @@ func Delete(config *Config) func(db *gorm.DB) {
return
}
if db.Statement.Schema != nil {
for _, c := range db.Statement.Schema.DeleteClauses {
db.Statement.AddClause(c)
}
}
if db.Statement.SQL.Len() == 0 {
db.Statement.SQL.Grow(100)
db.Statement.AddClauseIfNotExists(clause.Delete{})
if db.Statement.Schema != nil {
_, queryValues := schema.GetIdentityFieldValuesMap(db.Statement.ReflectValue, db.Statement.Schema.PrimaryFields)
_, queryValues := schema.GetIdentityFieldValuesMap(db.Statement.Context, db.Statement.ReflectValue, db.Statement.Schema.PrimaryFields)
column, values := schema.ToQueryValues(db.Statement.Table, db.Statement.Schema.PrimaryFieldDBNames, queryValues)
if len(values) > 0 {
@ -131,7 +137,7 @@ func Delete(config *Config) func(db *gorm.DB) {
}
if db.Statement.ReflectValue.CanAddr() && db.Statement.Dest != db.Statement.Model && db.Statement.Model != nil {
_, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(db.Statement.Model), db.Statement.Schema.PrimaryFields)
_, queryValues = schema.GetIdentityFieldValuesMap(db.Statement.Context, reflect.ValueOf(db.Statement.Model), db.Statement.Schema.PrimaryFields)
column, values = schema.ToQueryValues(db.Statement.Table, db.Statement.Schema.PrimaryFieldDBNames, queryValues)
if len(values) > 0 {
@ -141,22 +147,11 @@ func Delete(config *Config) func(db *gorm.DB) {
}
db.Statement.AddClauseIfNotExists(clause.From{})
}
if db.Statement.Schema != nil {
for _, c := range db.Statement.Schema.DeleteClauses {
db.Statement.AddClause(c)
}
}
if db.Statement.SQL.Len() == 0 {
db.Statement.Build(db.Statement.BuildClauses...)
}
if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok && db.Error == nil {
db.AddError(gorm.ErrMissingWhereClause)
return
}
checkMissingWhereConditions(db)
if !db.DryRun && db.Error == nil {
ok, mode := hasReturning(db, supportReturning)

View File

@ -1,6 +1,7 @@
package callbacks
import (
"reflect"
"sort"
"gorm.io/gorm"
@ -104,3 +105,48 @@ func hasReturning(tx *gorm.DB, supportReturning bool) (bool, gorm.ScanMode) {
}
return false, 0
}
func checkMissingWhereConditions(db *gorm.DB) {
if !db.AllowGlobalUpdate && db.Error == nil {
where, withCondition := db.Statement.Clauses["WHERE"]
if withCondition {
if _, withSoftDelete := db.Statement.Clauses["soft_delete_enabled"]; withSoftDelete {
whereClause, _ := where.Expression.(clause.Where)
withCondition = len(whereClause.Exprs) > 1
}
}
if !withCondition {
db.AddError(gorm.ErrMissingWhereClause)
}
return
}
}
type visitMap = map[reflect.Value]bool
// Check if circular values, return true if loaded
func loadOrStoreVisitMap(visitMap *visitMap, v reflect.Value) (loaded bool) {
if v.Kind() == reflect.Ptr {
v = v.Elem()
}
switch v.Kind() {
case reflect.Slice, reflect.Array:
loaded = true
for i := 0; i < v.Len(); i++ {
if !loadOrStoreVisitMap(visitMap, v.Index(i)) {
loaded = false
}
}
case reflect.Struct, reflect.Interface:
if v.CanAddr() {
p := v.Addr()
if _, ok := (*visitMap)[p]; ok {
return true
}
(*visitMap)[p] = true
}
}
return
}

View File

@ -10,10 +10,9 @@ import (
"gorm.io/gorm/utils"
)
func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) {
func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) error {
var (
reflectValue = db.Statement.ReflectValue
tx = db.Session(&gorm.Session{NewDB: true}).Model(nil).Session(&gorm.Session{SkipHooks: db.Statement.SkipHooks})
reflectValue = tx.Statement.ReflectValue
relForeignKeys []string
relForeignFields []*schema.Field
foreignFields []*schema.Field
@ -22,11 +21,6 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
inlineConds []interface{}
)
db.Statement.Settings.Range(func(k, v interface{}) bool {
tx.Statement.Settings.Store(k, v)
return true
})
if rel.JoinTable != nil {
var (
joinForeignFields = make([]*schema.Field, 0, len(rel.References))
@ -48,14 +42,16 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
}
}
joinIdentityMap, joinForeignValues := schema.GetIdentityFieldValuesMap(reflectValue, foreignFields)
joinIdentityMap, joinForeignValues := schema.GetIdentityFieldValuesMap(tx.Statement.Context, reflectValue, foreignFields)
if len(joinForeignValues) == 0 {
return
return nil
}
joinResults := rel.JoinTable.MakeSlice().Elem()
column, values := schema.ToQueryValues(clause.CurrentTable, joinForeignKeys, joinForeignValues)
db.AddError(tx.Where(clause.IN{Column: column, Values: values}).Find(joinResults.Addr().Interface()).Error)
if err := tx.Where(clause.IN{Column: column, Values: values}).Find(joinResults.Addr().Interface()).Error; err != nil {
return err
}
// convert join identity map to relation identity map
fieldValues := make([]interface{}, len(joinForeignFields))
@ -63,11 +59,11 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
for i := 0; i < joinResults.Len(); i++ {
joinIndexValue := joinResults.Index(i)
for idx, field := range joinForeignFields {
fieldValues[idx], _ = field.ValueOf(joinIndexValue)
fieldValues[idx], _ = field.ValueOf(tx.Statement.Context, joinIndexValue)
}
for idx, field := range joinRelForeignFields {
joinFieldValues[idx], _ = field.ValueOf(joinIndexValue)
joinFieldValues[idx], _ = field.ValueOf(tx.Statement.Context, joinIndexValue)
}
if results, ok := joinIdentityMap[utils.ToStringKey(fieldValues...)]; ok {
@ -76,7 +72,7 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
}
}
_, foreignValues = schema.GetIdentityFieldValuesMap(joinResults, joinRelForeignFields)
_, foreignValues = schema.GetIdentityFieldValuesMap(tx.Statement.Context, joinResults, joinRelForeignFields)
} else {
for _, ref := range rel.References {
if ref.OwnPrimaryKey {
@ -92,9 +88,9 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
}
}
identityMap, foreignValues = schema.GetIdentityFieldValuesMap(reflectValue, foreignFields)
identityMap, foreignValues = schema.GetIdentityFieldValuesMap(tx.Statement.Context, reflectValue, foreignFields)
if len(foreignValues) == 0 {
return
return nil
}
}
@ -115,7 +111,9 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
}
}
db.AddError(tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), inlineConds...).Error)
if err := tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), inlineConds...).Error; err != nil {
return err
}
}
fieldValues := make([]interface{}, len(relForeignFields))
@ -125,17 +123,17 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
case reflect.Struct:
switch rel.Type {
case schema.HasMany, schema.Many2Many:
rel.Field.Set(reflectValue, reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface())
tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue, reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface()))
default:
rel.Field.Set(reflectValue, reflect.New(rel.Field.FieldType).Interface())
tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue, reflect.New(rel.Field.FieldType).Interface()))
}
case reflect.Slice, reflect.Array:
for i := 0; i < reflectValue.Len(); i++ {
switch rel.Type {
case schema.HasMany, schema.Many2Many:
rel.Field.Set(reflectValue.Index(i), reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface())
tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue.Index(i), reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface()))
default:
rel.Field.Set(reflectValue.Index(i), reflect.New(rel.Field.FieldType).Interface())
tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue.Index(i), reflect.New(rel.Field.FieldType).Interface()))
}
}
}
@ -143,18 +141,16 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
for i := 0; i < reflectResults.Len(); i++ {
elem := reflectResults.Index(i)
for idx, field := range relForeignFields {
fieldValues[idx], _ = field.ValueOf(elem)
fieldValues[idx], _ = field.ValueOf(tx.Statement.Context, elem)
}
datas, ok := identityMap[utils.ToStringKey(fieldValues...)]
if !ok {
db.AddError(fmt.Errorf("failed to assign association %#v, make sure foreign fields exists",
elem.Interface()))
continue
return fmt.Errorf("failed to assign association %#v, make sure foreign fields exists", elem.Interface())
}
for _, data := range datas {
reflectFieldValue := rel.Field.ReflectValueOf(data)
reflectFieldValue := rel.Field.ReflectValueOf(tx.Statement.Context, data)
if reflectFieldValue.Kind() == reflect.Ptr && reflectFieldValue.IsNil() {
reflectFieldValue.Set(reflect.New(rel.Field.FieldType.Elem()))
}
@ -162,14 +158,16 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
reflectFieldValue = reflect.Indirect(reflectFieldValue)
switch reflectFieldValue.Kind() {
case reflect.Struct:
rel.Field.Set(data, elem.Interface())
tx.AddError(rel.Field.Set(tx.Statement.Context, data, elem.Interface()))
case reflect.Slice, reflect.Array:
if reflectFieldValue.Type().Elem().Kind() == reflect.Ptr {
rel.Field.Set(data, reflect.Append(reflectFieldValue, elem).Interface())
tx.AddError(rel.Field.Set(tx.Statement.Context, data, reflect.Append(reflectFieldValue, elem).Interface()))
} else {
rel.Field.Set(data, reflect.Append(reflectFieldValue, elem.Elem()).Interface())
tx.AddError(rel.Field.Set(tx.Statement.Context, data, reflect.Append(reflectFieldValue, elem.Elem()).Interface()))
}
}
}
}
return tx.Error
}

View File

@ -20,8 +20,10 @@ func Query(db *gorm.DB) {
db.AddError(err)
return
}
defer func() {
db.AddError(rows.Close())
}()
gorm.Scan(rows, db, 0)
db.AddError(rows.Close())
}
}
}
@ -40,7 +42,7 @@ func BuildQuerySQL(db *gorm.DB) {
if db.Statement.ReflectValue.Kind() == reflect.Struct && db.Statement.ReflectValue.Type() == db.Statement.Schema.ModelType {
var conds []clause.Expression
for _, primaryField := range db.Statement.Schema.PrimaryFields {
if v, isZero := primaryField.ValueOf(db.Statement.ReflectValue); !isZero {
if v, isZero := primaryField.ValueOf(db.Statement.Context, db.Statement.ReflectValue); !isZero {
conds = append(conds, clause.Eq{Column: clause.Column{Table: db.Statement.Table, Name: primaryField.DBName}, Value: v})
}
}
@ -94,13 +96,13 @@ func BuildQuerySQL(db *gorm.DB) {
}
// inline joins
joins := []clause.Join{}
if fromClause, ok := db.Statement.Clauses["FROM"].Expression.(clause.From); ok {
joins = fromClause.Joins
fromClause := clause.From{}
if v, ok := db.Statement.Clauses["FROM"].Expression.(clause.From); ok {
fromClause = v
}
if len(db.Statement.Joins) != 0 || len(joins) != 0 {
if len(db.Statement.Selects) == 0 && db.Statement.Schema != nil {
if len(db.Statement.Joins) != 0 || len(fromClause.Joins) != 0 {
if len(db.Statement.Selects) == 0 && len(db.Statement.Omits) == 0 && db.Statement.Schema != nil {
clauseSelect.Columns = make([]clause.Column, len(db.Statement.Schema.DBNames))
for idx, dbName := range db.Statement.Schema.DBNames {
clauseSelect.Columns[idx] = clause.Column{Table: db.Statement.Table, Name: dbName}
@ -109,7 +111,7 @@ func BuildQuerySQL(db *gorm.DB) {
for _, join := range db.Statement.Joins {
if db.Statement.Schema == nil {
joins = append(joins, clause.Join{
fromClause.Joins = append(fromClause.Joins, clause.Join{
Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds},
})
} else if relation, ok := db.Statement.Schema.Relationships.Relations[join.Name]; ok {
@ -145,35 +147,49 @@ func BuildQuerySQL(db *gorm.DB) {
}
}
if join.On != nil {
onStmt := gorm.Statement{Table: tableAliasName, DB: db}
join.On.Build(&onStmt)
onSQL := onStmt.SQL.String()
vars := onStmt.Vars
for idx, v := range onStmt.Vars {
bindvar := strings.Builder{}
onStmt.Vars = vars[0 : idx+1]
db.Dialector.BindVarTo(&bindvar, &onStmt, v)
onSQL = strings.Replace(onSQL, bindvar.String(), "?", 1)
{
onStmt := gorm.Statement{Table: tableAliasName, DB: db, Clauses: map[string]clause.Clause{}}
for _, c := range relation.FieldSchema.QueryClauses {
onStmt.AddClause(c)
}
exprs = append(exprs, clause.Expr{SQL: onSQL, Vars: vars})
if join.On != nil {
onStmt.AddClause(join.On)
}
if cs, ok := onStmt.Clauses["WHERE"]; ok {
if where, ok := cs.Expression.(clause.Where); ok {
where.Build(&onStmt)
if onSQL := onStmt.SQL.String(); onSQL != "" {
vars := onStmt.Vars
for idx, v := range vars {
bindvar := strings.Builder{}
onStmt.Vars = vars[0 : idx+1]
db.Dialector.BindVarTo(&bindvar, &onStmt, v)
onSQL = strings.Replace(onSQL, bindvar.String(), "?", 1)
}
exprs = append(exprs, clause.Expr{SQL: onSQL, Vars: vars})
}
}
}
}
joins = append(joins, clause.Join{
fromClause.Joins = append(fromClause.Joins, clause.Join{
Type: clause.LeftJoin,
Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName},
ON: clause.Where{Exprs: exprs},
})
} else {
joins = append(joins, clause.Join{
fromClause.Joins = append(fromClause.Joins, clause.Join{
Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds},
})
}
}
db.Statement.AddClause(fromClause)
db.Statement.Joins = nil
db.Statement.AddClause(clause.From{Joins: joins})
} else {
db.Statement.AddClauseIfNotExists(clause.From{})
}
@ -186,6 +202,11 @@ func BuildQuerySQL(db *gorm.DB) {
func Preload(db *gorm.DB) {
if db.Error == nil && len(db.Statement.Preloads) > 0 {
if db.Statement.Schema == nil {
db.AddError(fmt.Errorf("%w when using preload", gorm.ErrModelValueRequired))
return
}
preloadMap := map[string]map[string][]interface{}{}
for name := range db.Statement.Preloads {
preloadFields := strings.Split(name, ".")
@ -218,9 +239,20 @@ func Preload(db *gorm.DB) {
}
sort.Strings(preloadNames)
preloadDB := db.Session(&gorm.Session{Context: db.Statement.Context, NewDB: true, SkipHooks: db.Statement.SkipHooks, Initialized: true})
db.Statement.Settings.Range(func(k, v interface{}) bool {
preloadDB.Statement.Settings.Store(k, v)
return true
})
if err := preloadDB.Statement.Parse(db.Statement.Dest); err != nil {
return
}
preloadDB.Statement.ReflectValue = db.Statement.ReflectValue
for _, name := range preloadNames {
if rel := db.Statement.Schema.Relationships.Relations[name]; rel != nil {
preload(db, rel, append(db.Statement.Preloads[name], db.Statement.Preloads[clause.Associations]...), preloadMap[name])
if rel := preloadDB.Statement.Schema.Relationships.Relations[name]; rel != nil {
db.AddError(preload(preloadDB.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks}), rel, append(db.Statement.Preloads[name], db.Statement.Preloads[clause.Associations]...), preloadMap[name]))
} else {
db.AddError(fmt.Errorf("%s: %w for schema %s", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name))
}

View File

@ -21,7 +21,7 @@ func SetupUpdateReflectValue(db *gorm.DB) {
if dest, ok := db.Statement.Dest.(map[string]interface{}); ok {
for _, rel := range db.Statement.Schema.Relationships.BelongsTo {
if _, ok := dest[rel.Name]; ok {
rel.Field.Set(db.Statement.ReflectValue, dest[rel.Name])
db.AddError(rel.Field.Set(db.Statement.Context, db.Statement.ReflectValue, dest[rel.Name]))
}
}
}
@ -29,6 +29,7 @@ func SetupUpdateReflectValue(db *gorm.DB) {
}
}
// BeforeUpdate before update hooks
func BeforeUpdate(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) {
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
@ -51,6 +52,7 @@ func BeforeUpdate(db *gorm.DB) {
}
}
// Update update hook
func Update(config *Config) func(db *gorm.DB) {
supportReturning := utils.Contains(config.UpdateClauses, "RETURNING")
@ -59,6 +61,12 @@ func Update(config *Config) func(db *gorm.DB) {
return
}
if db.Statement.Schema != nil {
for _, c := range db.Statement.Schema.UpdateClauses {
db.Statement.AddClause(c)
}
}
if db.Statement.SQL.Len() == 0 {
db.Statement.SQL.Grow(180)
db.Statement.AddClauseIfNotExists(clause.Update{})
@ -68,22 +76,10 @@ func Update(config *Config) func(db *gorm.DB) {
return
}
}
if db.Statement.Schema != nil {
for _, c := range db.Statement.Schema.UpdateClauses {
db.Statement.AddClause(c)
}
}
if db.Statement.SQL.Len() == 0 {
db.Statement.Build(db.Statement.BuildClauses...)
}
if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok {
db.AddError(gorm.ErrMissingWhereClause)
return
}
checkMissingWhereConditions(db)
if !db.DryRun && db.Error == nil {
if ok, mode := hasReturning(db, supportReturning); ok {
@ -105,9 +101,17 @@ func Update(config *Config) func(db *gorm.DB) {
}
}
// AfterUpdate after update hooks
func AfterUpdate(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) {
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
if db.Statement.Schema.AfterUpdate {
if i, ok := value.(AfterUpdateInterface); ok {
called = true
db.AddError(i.AfterUpdate(tx))
}
}
if db.Statement.Schema.AfterSave {
if i, ok := value.(AfterSaveInterface); ok {
called = true
@ -115,12 +119,6 @@ func AfterUpdate(db *gorm.DB) {
}
}
if db.Statement.Schema.AfterUpdate {
if i, ok := value.(AfterUpdateInterface); ok {
called = true
db.AddError(i.AfterUpdate(tx))
}
}
return called
})
}
@ -137,13 +135,13 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
case reflect.Slice, reflect.Array:
assignValue = func(field *schema.Field, value interface{}) {
for i := 0; i < stmt.ReflectValue.Len(); i++ {
field.Set(stmt.ReflectValue.Index(i), value)
field.Set(stmt.Context, stmt.ReflectValue.Index(i), value)
}
}
case reflect.Struct:
assignValue = func(field *schema.Field, value interface{}) {
if stmt.ReflectValue.CanAddr() {
field.Set(stmt.ReflectValue, value)
field.Set(stmt.Context, stmt.ReflectValue, value)
}
}
default:
@ -165,7 +163,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
exprs := make([]clause.Expression, len(stmt.Schema.PrimaryFields))
var notZero bool
for idx, field := range stmt.Schema.PrimaryFields {
value, isZero := field.ValueOf(stmt.ReflectValue.Index(i))
value, isZero := field.ValueOf(stmt.Context, stmt.ReflectValue.Index(i))
exprs[idx] = clause.Eq{Column: field.DBName, Value: value}
notZero = notZero || !isZero
}
@ -178,7 +176,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
}
case reflect.Struct:
for _, field := range stmt.Schema.PrimaryFields {
if value, isZero := field.ValueOf(stmt.ReflectValue); !isZero {
if value, isZero := field.ValueOf(stmt.Context, stmt.ReflectValue); !isZero {
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}})
}
}
@ -232,10 +230,10 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano()})
} else if field.AutoUpdateTime == schema.UnixMillisecond {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano() / 1e6})
} else if field.GORMDataType == schema.Time {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now})
} else {
} else if field.AutoUpdateTime == schema.UnixSecond {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.Unix()})
} else {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now})
}
}
}
@ -258,16 +256,16 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
if field := updatingSchema.LookUpField(dbName); field != nil {
if !field.PrimaryKey || !updatingValue.CanAddr() || stmt.Dest != stmt.Model {
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && (!restricted || (!stmt.SkipHooks && field.AutoUpdateTime > 0))) {
value, isZero := field.ValueOf(updatingValue)
value, isZero := field.ValueOf(stmt.Context, updatingValue)
if !stmt.SkipHooks && field.AutoUpdateTime > 0 {
if field.AutoUpdateTime == schema.UnixNanosecond {
value = stmt.DB.NowFunc().UnixNano()
} else if field.AutoUpdateTime == schema.UnixMillisecond {
value = stmt.DB.NowFunc().UnixNano() / 1e6
} else if field.GORMDataType == schema.Time {
value = stmt.DB.NowFunc()
} else {
} else if field.AutoUpdateTime == schema.UnixSecond {
value = stmt.DB.NowFunc().Unix()
} else {
value = stmt.DB.NowFunc()
}
isZero = false
}
@ -278,7 +276,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
}
}
} else {
if value, isZero := field.ValueOf(updatingValue); !isZero {
if value, isZero := field.ValueOf(stmt.Context, updatingValue); !isZero {
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}})
}
}

View File

@ -0,0 +1,36 @@
package callbacks
import (
"reflect"
"testing"
)
func TestLoadOrStoreVisitMap(t *testing.T) {
var vm visitMap
var loaded bool
type testM struct {
Name string
}
t1 := testM{Name: "t1"}
t2 := testM{Name: "t2"}
t3 := testM{Name: "t3"}
vm = make(visitMap)
if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf(&t1)); loaded {
t.Fatalf("loaded should be false")
}
if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf(&t1)); !loaded {
t.Fatalf("loaded should be true")
}
// t1 already exist but t2 not
if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf([]*testM{&t1, &t2, &t3})); loaded {
t.Fatalf("loaded should be false")
}
if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf([]*testM{&t2, &t3})); !loaded {
t.Fatalf("loaded should be true")
}
}

View File

@ -54,9 +54,12 @@ func (db *DB) Table(name string, args ...interface{}) (tx *DB) {
} else if tables := strings.Split(name, "."); len(tables) == 2 {
tx.Statement.TableExpr = &clause.Expr{SQL: tx.Statement.Quote(name)}
tx.Statement.Table = tables[1]
} else {
} else if name != "" {
tx.Statement.TableExpr = &clause.Expr{SQL: tx.Statement.Quote(name)}
tx.Statement.Table = name
} else {
tx.Statement.TableExpr = nil
tx.Statement.Table = ""
}
return
}
@ -90,7 +93,11 @@ func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) {
return
}
}
delete(tx.Statement.Clauses, "SELECT")
if clause, ok := tx.Statement.Clauses["SELECT"]; ok {
clause.Expression = nil
tx.Statement.Clauses["SELECT"] = clause
}
case string:
if strings.Count(v, "?") >= len(args) && len(args) > 0 {
tx.Statement.AddClause(clause.Select{
@ -120,7 +127,10 @@ func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) {
}
}
delete(tx.Statement.Clauses, "SELECT")
if clause, ok := tx.Statement.Clauses["SELECT"]; ok {
clause.Expression = nil
tx.Statement.Clauses["SELECT"] = clause
}
}
default:
tx.AddError(fmt.Errorf("unsupported select args %v %v", query, args))

View File

@ -21,7 +21,7 @@ func (limit Limit) Build(builder Builder) {
}
if limit.Offset > 0 {
if limit.Limit > 0 {
builder.WriteString(" ")
builder.WriteByte(' ')
}
builder.WriteString("OFFSET ")
builder.WriteString(strconv.Itoa(limit.Offset))

View File

@ -43,6 +43,23 @@ func TestSelect(t *testing.T) {
}, clause.From{}},
"SELECT `id`, `name`, LENGTH(`mobile`) FROM `users`", nil,
},
{
[]clause.Interface{clause.Select{
Expression: clause.CommaExpression{
Exprs: []clause.Expression{
clause.Expr{
SQL: "? as name",
Vars: []interface{}{clause.Eq{
Column: clause.Column{Name: "age"},
Value: 18,
},
},
},
},
},
}, clause.From{}},
"SELECT `age` = ? as name FROM `users`", []interface{}{18},
},
}
for idx, result := range results {

View File

@ -4,6 +4,11 @@ import (
"strings"
)
const (
AndWithSpace = " AND "
OrWithSpace = " OR "
)
// Where where clause
type Where struct {
Exprs []Expression
@ -26,7 +31,7 @@ func (where Where) Build(builder Builder) {
}
}
buildExprs(where.Exprs, builder, " AND ")
buildExprs(where.Exprs, builder, AndWithSpace)
}
func buildExprs(exprs []Expression, builder Builder, joinCond string) {
@ -35,7 +40,7 @@ func buildExprs(exprs []Expression, builder Builder, joinCond string) {
for idx, expr := range exprs {
if idx > 0 {
if v, ok := expr.(OrConditions); ok && len(v.Exprs) == 1 {
builder.WriteString(" OR ")
builder.WriteString(OrWithSpace)
} else {
builder.WriteString(joinCond)
}
@ -46,30 +51,30 @@ func buildExprs(exprs []Expression, builder Builder, joinCond string) {
case OrConditions:
if len(v.Exprs) == 1 {
if e, ok := v.Exprs[0].(Expr); ok {
sql := strings.ToLower(e.SQL)
wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or")
sql := strings.ToUpper(e.SQL)
wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace)
}
}
case AndConditions:
if len(v.Exprs) == 1 {
if e, ok := v.Exprs[0].(Expr); ok {
sql := strings.ToLower(e.SQL)
wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or")
sql := strings.ToUpper(e.SQL)
wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace)
}
}
case Expr:
sql := strings.ToLower(v.SQL)
wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or")
sql := strings.ToUpper(v.SQL)
wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace)
case NamedExpr:
sql := strings.ToLower(v.SQL)
wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or")
sql := strings.ToUpper(v.SQL)
wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace)
}
}
if wrapInParentheses {
builder.WriteString(`(`)
builder.WriteByte('(')
expr.Build(builder)
builder.WriteString(`)`)
builder.WriteByte(')')
wrapInParentheses = false
} else {
expr.Build(builder)
@ -110,10 +115,10 @@ type AndConditions struct {
func (and AndConditions) Build(builder Builder) {
if len(and.Exprs) > 1 {
builder.WriteByte('(')
buildExprs(and.Exprs, builder, " AND ")
buildExprs(and.Exprs, builder, AndWithSpace)
builder.WriteByte(')')
} else {
buildExprs(and.Exprs, builder, " AND ")
buildExprs(and.Exprs, builder, AndWithSpace)
}
}
@ -131,10 +136,10 @@ type OrConditions struct {
func (or OrConditions) Build(builder Builder) {
if len(or.Exprs) > 1 {
builder.WriteByte('(')
buildExprs(or.Exprs, builder, " OR ")
buildExprs(or.Exprs, builder, OrWithSpace)
builder.WriteByte(')')
} else {
buildExprs(or.Exprs, builder, " OR ")
buildExprs(or.Exprs, builder, OrWithSpace)
}
}
@ -156,7 +161,7 @@ func (not NotConditions) Build(builder Builder) {
for idx, c := range not.Exprs {
if idx > 0 {
builder.WriteString(" AND ")
builder.WriteString(AndWithSpace)
}
if negationBuilder, ok := c.(NegationExpressionBuilder); ok {
@ -165,8 +170,8 @@ func (not NotConditions) Build(builder Builder) {
builder.WriteString("NOT ")
e, wrapInParentheses := c.(Expr)
if wrapInParentheses {
sql := strings.ToLower(e.SQL)
if wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or"); wrapInParentheses {
sql := strings.ToUpper(e.SQL)
if wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace); wrapInParentheses {
builder.WriteByte('(')
}
}

View File

@ -66,6 +66,45 @@ func TestWhere(t *testing.T) {
"SELECT * FROM `users` WHERE (`age` = ? OR `name` <> ?)",
[]interface{}{18, "jinzhu"},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
Exprs: []clause.Expression{clause.Not(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}), clause.And(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false})},
}},
"SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `age` <= ?) AND `score` <= ?",
[]interface{}{"1", 18, 100},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
Exprs: []clause.Expression{clause.Not(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}), clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false}},
}},
"SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `age` <= ?) AND `score` <= ?",
[]interface{}{"1", 18, 100},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
Exprs: []clause.Expression{clause.Not(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}), clause.Or(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false})},
}},
"SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `age` <= ?) OR `score` <= ?",
[]interface{}{"1", 18, 100},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
Exprs: []clause.Expression{
clause.And(clause.Not(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}),
clause.And(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false})),
},
}},
"SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `score` <= ?)",
[]interface{}{"1", 100},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
Exprs: []clause.Expression{clause.Not(clause.Eq{Column: clause.PrimaryColumn, Value: "1"},
clause.And(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false}))},
}},
"SELECT * FROM `users` WHERE (`users`.`id` <> ? AND NOT `score` <= ?)",
[]interface{}{"1", 100},
},
}
for idx, result := range results {

View File

@ -39,4 +39,6 @@ var (
ErrInvalidValue = errors.New("invalid value, should be pointer to struct or slice")
// ErrInvalidValueOfLength invalid values do not match length
ErrInvalidValueOfLength = errors.New("invalid association values, length doesn't match")
// ErrPreloadNotAllowed preload is not allowed when count is used
ErrPreloadNotAllowed = errors.New("preload is not allowed when count is used")
)

View File

@ -74,6 +74,10 @@ func (db *DB) Save(value interface{}) (tx *DB) {
tx.Statement.Dest = value
reflectValue := reflect.Indirect(reflect.ValueOf(value))
for reflectValue.Kind() == reflect.Ptr || reflectValue.Kind() == reflect.Interface {
reflectValue = reflect.Indirect(reflectValue)
}
switch reflectValue.Kind() {
case reflect.Slice, reflect.Array:
if _, ok := tx.Statement.Clauses["ON CONFLICT"]; !ok {
@ -83,7 +87,7 @@ func (db *DB) Save(value interface{}) (tx *DB) {
case reflect.Struct:
if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil {
for _, pf := range tx.Statement.Schema.PrimaryFields {
if _, isZero := pf.ValueOf(reflectValue); isZero {
if _, isZero := pf.ValueOf(tx.Statement.Context, reflectValue); isZero {
return tx.callbacks.Create().Execute(tx)
}
}
@ -101,7 +105,7 @@ func (db *DB) Save(value interface{}) (tx *DB) {
if tx.Error == nil && tx.RowsAffected == 0 && !tx.DryRun && !selectedUpdate {
result := reflect.New(tx.Statement.Schema.ModelType).Interface()
if err := tx.Session(&Session{}).Take(result).Error; errors.Is(err, ErrRecordNotFound) {
if result := tx.Session(&Session{}).Limit(1).Find(result); result.RowsAffected == 0 {
return tx.Create(value)
}
}
@ -177,6 +181,21 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat
batch int
)
// user specified offset or limit
var totalSize int
if c, ok := tx.Statement.Clauses["LIMIT"]; ok {
if limit, ok := c.Expression.(clause.Limit); ok {
totalSize = limit.Limit
if totalSize > 0 && batchSize > totalSize {
batchSize = totalSize
}
// reset to offset to 0 in next batch
tx = tx.Offset(-1).Session(&Session{})
}
}
for {
result := queryDB.Limit(batchSize).Find(dest)
rowsAffected += result.RowsAffected
@ -192,6 +211,15 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat
break
}
if totalSize > 0 {
if totalSize <= int(rowsAffected) {
break
}
if totalSize/batchSize == batch {
batchSize = totalSize % batchSize
}
}
// Optimize for-break
resultsValue := reflect.Indirect(reflect.ValueOf(dest))
if result.Statement.Schema.PrioritizedPrimaryField == nil {
@ -199,7 +227,7 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat
break
}
primaryValue, _ := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(resultsValue.Index(resultsValue.Len() - 1))
primaryValue, _ := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1))
queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue})
}
@ -207,7 +235,7 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat
return tx
}
func (tx *DB) assignInterfacesToValue(values ...interface{}) {
func (db *DB) assignInterfacesToValue(values ...interface{}) {
for _, value := range values {
switch v := value.(type) {
case []clause.Expression:
@ -215,40 +243,40 @@ func (tx *DB) assignInterfacesToValue(values ...interface{}) {
if eq, ok := expr.(clause.Eq); ok {
switch column := eq.Column.(type) {
case string:
if field := tx.Statement.Schema.LookUpField(column); field != nil {
tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value))
if field := db.Statement.Schema.LookUpField(column); field != nil {
db.AddError(field.Set(db.Statement.Context, db.Statement.ReflectValue, eq.Value))
}
case clause.Column:
if field := tx.Statement.Schema.LookUpField(column.Name); field != nil {
tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value))
if field := db.Statement.Schema.LookUpField(column.Name); field != nil {
db.AddError(field.Set(db.Statement.Context, db.Statement.ReflectValue, eq.Value))
}
}
} else if andCond, ok := expr.(clause.AndConditions); ok {
tx.assignInterfacesToValue(andCond.Exprs)
db.assignInterfacesToValue(andCond.Exprs)
}
}
case clause.Expression, map[string]string, map[interface{}]interface{}, map[string]interface{}:
if exprs := tx.Statement.BuildCondition(value); len(exprs) > 0 {
tx.assignInterfacesToValue(exprs)
if exprs := db.Statement.BuildCondition(value); len(exprs) > 0 {
db.assignInterfacesToValue(exprs)
}
default:
if s, err := schema.Parse(value, tx.cacheStore, tx.NamingStrategy); err == nil {
if s, err := schema.Parse(value, db.cacheStore, db.NamingStrategy); err == nil {
reflectValue := reflect.Indirect(reflect.ValueOf(value))
switch reflectValue.Kind() {
case reflect.Struct:
for _, f := range s.Fields {
if f.Readable {
if v, isZero := f.ValueOf(reflectValue); !isZero {
if field := tx.Statement.Schema.LookUpField(f.Name); field != nil {
tx.AddError(field.Set(tx.Statement.ReflectValue, v))
if v, isZero := f.ValueOf(db.Statement.Context, reflectValue); !isZero {
if field := db.Statement.Schema.LookUpField(f.Name); field != nil {
db.AddError(field.Set(db.Statement.Context, db.Statement.ReflectValue, v))
}
}
}
}
}
} else if len(values) > 0 {
if exprs := tx.Statement.BuildCondition(values[0], values[1:]...); len(exprs) > 0 {
tx.assignInterfacesToValue(exprs)
if exprs := db.Statement.BuildCondition(values[0], values[1:]...); len(exprs) > 0 {
db.assignInterfacesToValue(exprs)
}
return
}
@ -256,12 +284,13 @@ func (tx *DB) assignInterfacesToValue(values ...interface{}) {
}
}
// FirstOrInit gets the first matched record or initialize a new instance with given conditions (only works with struct or map conditions)
func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) {
queryTx := db.Limit(1).Order(clause.OrderByColumn{
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
})
if tx = queryTx.Find(dest, conds...); queryTx.RowsAffected == 0 {
if tx = queryTx.Find(dest, conds...); tx.RowsAffected == 0 {
if c, ok := tx.Statement.Clauses["WHERE"]; ok {
if where, ok := c.Expression.(clause.Where); ok {
tx.assignInterfacesToValue(where.Exprs)
@ -281,26 +310,28 @@ func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) {
return
}
// FirstOrCreate gets the first matched record or create a new one with given conditions (only works with struct, map conditions)
func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) {
queryTx := db.Limit(1).Order(clause.OrderByColumn{
tx = db.getInstance()
queryTx := db.Session(&Session{}).Limit(1).Order(clause.OrderByColumn{
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
})
if tx = queryTx.Find(dest, conds...); tx.Error == nil {
if tx.RowsAffected == 0 {
if c, ok := tx.Statement.Clauses["WHERE"]; ok {
if result := queryTx.Find(dest, conds...); result.Error == nil {
if result.RowsAffected == 0 {
if c, ok := result.Statement.Clauses["WHERE"]; ok {
if where, ok := c.Expression.(clause.Where); ok {
tx.assignInterfacesToValue(where.Exprs)
result.assignInterfacesToValue(where.Exprs)
}
}
// initialize with attrs, conds
if len(tx.Statement.attrs) > 0 {
tx.assignInterfacesToValue(tx.Statement.attrs...)
if len(db.Statement.attrs) > 0 {
result.assignInterfacesToValue(db.Statement.attrs...)
}
// initialize with attrs, conds
if len(tx.Statement.assigns) > 0 {
tx.assignInterfacesToValue(tx.Statement.assigns...)
if len(db.Statement.assigns) > 0 {
result.assignInterfacesToValue(db.Statement.assigns...)
}
return tx.Create(dest)
@ -320,6 +351,8 @@ func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) {
}
return tx.Model(dest).Updates(assigns)
} else {
tx.Error = result.Error
}
}
return tx
@ -585,7 +618,7 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er
func (db *DB) Begin(opts ...*sql.TxOptions) *DB {
var (
// clone statement
tx = db.getInstance().Session(&Session{Context: db.Statement.Context})
tx = db.getInstance().Session(&Session{Context: db.Statement.Context, NewDB: db.clone == 1})
opt *sql.TxOptions
err error
)
@ -594,11 +627,12 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB {
opt = opts[0]
}
if beginner, ok := tx.Statement.ConnPool.(TxBeginner); ok {
switch beginner := tx.Statement.ConnPool.(type) {
case TxBeginner:
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
} else if beginner, ok := tx.Statement.ConnPool.(ConnPoolBeginner); ok {
case ConnPoolBeginner:
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
} else {
default:
err = ErrInvalidTransaction
}

17
gorm.go
View File

@ -59,6 +59,7 @@ type Config struct {
cacheStore *sync.Map
}
// Apply update config to new config
func (c *Config) Apply(config *Config) error {
if config != c {
*config = *c
@ -66,6 +67,7 @@ func (c *Config) Apply(config *Config) error {
return nil
}
// AfterInitialize initialize plugins after db connected
func (c *Config) AfterInitialize(db *DB) error {
if db != nil {
for _, plugin := range c.Plugins {
@ -77,6 +79,7 @@ func (c *Config) AfterInitialize(db *DB) error {
return nil
}
// Option gorm option interface
type Option interface {
Apply(*Config) error
AfterInitialize(*DB) error
@ -96,6 +99,7 @@ type Session struct {
DryRun bool
PrepareStmt bool
NewDB bool
Initialized bool
SkipHooks bool
SkipDefaultTransaction bool
DisableNestedTransaction bool
@ -120,8 +124,8 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) {
for _, opt := range opts {
if opt != nil {
if err := opt.Apply(config); err != nil {
return nil, err
if applyErr := opt.Apply(config); applyErr != nil {
return nil, applyErr
}
defer func(opt Option) {
if errr := opt.AfterInitialize(db); errr != nil {
@ -282,6 +286,10 @@ func (db *DB) Session(config *Session) *DB {
tx.Config.NowFunc = config.NowFunc
}
if config.Initialized {
tx = tx.getInstance()
}
return tx
}
@ -376,10 +384,12 @@ func (db *DB) getInstance() *DB {
return db
}
// Expr returns clause.Expr, which can be used to pass SQL expression as params
func Expr(expr string, args ...interface{}) clause.Expr {
return clause.Expr{SQL: expr, Vars: args}
}
// SetupJoinTable setup join table schema
func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interface{}) error {
var (
tx = db.getInstance()
@ -430,6 +440,7 @@ func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interfac
return nil
}
// Use use plugin
func (db *DB) Use(plugin Plugin) error {
name := plugin.Name()
if _, ok := db.Plugins[name]; ok {
@ -451,7 +462,7 @@ func (db *DB) Use(plugin Plugin) error {
// .First(&User{})
// })
func (db *DB) ToSQL(queryFn func(tx *DB) *DB) string {
tx := queryFn(db.Session(&Session{DryRun: true}))
tx := queryFn(db.Session(&Session{DryRun: true, SkipDefaultTransaction: true}))
stmt := tx.Statement
return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...)

View File

@ -40,24 +40,45 @@ type SavePointerDialectorInterface interface {
RollbackTo(tx *DB, name string) error
}
// TxBeginner tx beginner
type TxBeginner interface {
BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
}
// ConnPoolBeginner conn pool beginner
type ConnPoolBeginner interface {
BeginTx(ctx context.Context, opts *sql.TxOptions) (ConnPool, error)
}
// TxCommitter tx committer
type TxCommitter interface {
Commit() error
Rollback() error
}
// Tx sql.Tx interface
type Tx interface {
ConnPool
TxCommitter
StmtContext(ctx context.Context, stmt *sql.Stmt) *sql.Stmt
}
// Valuer gorm valuer interface
type Valuer interface {
GormValue(context.Context, *DB) clause.Expr
}
// GetDBConnector SQL db connector
type GetDBConnector interface {
GetDBConn() (*sql.DB, error)
}
// Rows rows interface
type Rows interface {
Columns() ([]string, error)
ColumnTypes() ([]*sql.ColumnType, error)
Next() bool
Scan(dest ...interface{}) error
Err() error
Close() error
}

View File

@ -12,6 +12,7 @@ import (
"gorm.io/gorm/utils"
)
// ErrRecordNotFound record not found error
var ErrRecordNotFound = errors.New("record not found")
// Colors
@ -30,13 +31,17 @@ const (
YellowBold = "\033[33;1m"
)
// LogLevel
// LogLevel log level
type LogLevel int
const (
// Silent silent log level
Silent LogLevel = iota + 1
// Error error log level
Error
// Warn warn log level
Warn
// Info info log level
Info
)
@ -45,6 +50,7 @@ type Writer interface {
Printf(string, ...interface{})
}
// Config logger config
type Config struct {
SlowThreshold time.Duration
Colorful bool
@ -62,16 +68,20 @@ type Interface interface {
}
var (
// Discard Discard logger will print any log to ioutil.Discard
Discard = New(log.New(ioutil.Discard, "", log.LstdFlags), Config{})
// Default Default logger
Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{
SlowThreshold: 200 * time.Millisecond,
LogLevel: Warn,
IgnoreRecordNotFoundError: false,
Colorful: true,
})
// Recorder Recorder logger records running SQL into a recorder instance
Recorder = traceRecorder{Interface: Default, BeginAt: time.Now()}
)
// New initialize logger
func New(writer Writer, config Config) Interface {
var (
infoStr = "%s\n[info] "
@ -179,10 +189,12 @@ type traceRecorder struct {
Err error
}
// New new trace recorder
func (l traceRecorder) New() *traceRecorder {
return &traceRecorder{Interface: l.Interface, BeginAt: time.Now()}
}
// Trace implement logger interface
func (l *traceRecorder) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) {
l.BeginAt = begin
l.SQL, l.RowsAffected = fc()

View File

@ -19,9 +19,9 @@ const (
nullStr = "NULL"
)
func isPrintable(s []byte) bool {
func isPrintable(s string) bool {
for _, r := range s {
if !unicode.IsPrint(rune(r)) {
if !unicode.IsPrint(r) {
return false
}
}
@ -30,9 +30,12 @@ func isPrintable(s []byte) bool {
var convertibleTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeOf(false), reflect.TypeOf([]byte{})}
// ExplainSQL generate SQL string with given parameters, the generated SQL is expected to be used in logger, execute it might introduce a SQL injection vulnerability
func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, avars ...interface{}) string {
var convertParams func(interface{}, int)
vars := make([]string, len(avars))
var (
convertParams func(interface{}, int)
vars = make([]string, len(avars))
)
convertParams = func(v interface{}, idx int) {
switch v := v.(type) {
@ -64,14 +67,25 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a
}
case fmt.Stringer:
reflectValue := reflect.ValueOf(v)
if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) {
vars[idx] = escaper + strings.Replace(fmt.Sprintf("%v", v), escaper, "\\"+escaper, -1) + escaper
} else {
vars[idx] = nullStr
switch reflectValue.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
vars[idx] = fmt.Sprintf("%d", reflectValue.Interface())
case reflect.Float32, reflect.Float64:
vars[idx] = fmt.Sprintf("%.6f", reflectValue.Interface())
case reflect.Bool:
vars[idx] = fmt.Sprintf("%t", reflectValue.Interface())
case reflect.String:
vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, "\\"+escaper) + escaper
default:
if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) {
vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, "\\"+escaper) + escaper
} else {
vars[idx] = nullStr
}
}
case []byte:
if isPrintable(v) {
vars[idx] = escaper + strings.Replace(string(v), escaper, "\\"+escaper, -1) + escaper
if s := string(v); isPrintable(s) {
vars[idx] = escaper + strings.ReplaceAll(s, escaper, "\\"+escaper) + escaper
} else {
vars[idx] = escaper + "<binary>" + escaper
}
@ -80,7 +94,7 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a
case float64, float32:
vars[idx] = fmt.Sprintf("%.6f", v)
case string:
vars[idx] = escaper + strings.Replace(v, escaper, "\\"+escaper, -1) + escaper
vars[idx] = escaper + strings.ReplaceAll(v, escaper, "\\"+escaper) + escaper
default:
rv := reflect.ValueOf(v)
if v == nil || !rv.IsValid() || rv.Kind() == reflect.Ptr && rv.IsNil() {
@ -97,7 +111,7 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a
return
}
}
vars[idx] = escaper + strings.Replace(fmt.Sprint(v), escaper, "\\"+escaper, -1) + escaper
vars[idx] = escaper + strings.ReplaceAll(fmt.Sprint(v), escaper, "\\"+escaper) + escaper
}
}
}

View File

@ -31,7 +31,7 @@ func (s ExampleStruct) Value() (driver.Value, error) {
}
func format(v []byte, escaper string) string {
return escaper + strings.Replace(string(v), escaper, "\\"+escaper, -1) + escaper
return escaper + strings.ReplaceAll(string(v), escaper, "\\"+escaper) + escaper
}
func TestExplainSQL(t *testing.T) {

View File

@ -1,6 +1,8 @@
package gorm
import (
"reflect"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
)
@ -33,14 +35,23 @@ type ViewOption struct {
Query *DB
}
// ColumnType column type interface
type ColumnType interface {
Name() string
DatabaseTypeName() string
DatabaseTypeName() string // varchar
ColumnType() (columnType string, ok bool) // varchar(64)
PrimaryKey() (isPrimaryKey bool, ok bool)
AutoIncrement() (isAutoIncrement bool, ok bool)
Length() (length int64, ok bool)
DecimalSize() (precision int64, scale int64, ok bool)
Nullable() (nullable bool, ok bool)
Unique() (unique bool, ok bool)
ScanType() reflect.Type
Comment() (value string, ok bool)
DefaultValue() (value string, ok bool)
}
// Migrator migrator interface
type Migrator interface {
// AutoMigrate
AutoMigrate(dst ...interface{}) error

107
migrator/column_type.go Normal file
View File

@ -0,0 +1,107 @@
package migrator
import (
"database/sql"
"reflect"
)
// ColumnType column type implements ColumnType interface
type ColumnType struct {
SQLColumnType *sql.ColumnType
NameValue sql.NullString
DataTypeValue sql.NullString
ColumnTypeValue sql.NullString
PrimaryKeyValue sql.NullBool
UniqueValue sql.NullBool
AutoIncrementValue sql.NullBool
LengthValue sql.NullInt64
DecimalSizeValue sql.NullInt64
ScaleValue sql.NullInt64
NullableValue sql.NullBool
ScanTypeValue reflect.Type
CommentValue sql.NullString
DefaultValueValue sql.NullString
}
// Name returns the name or alias of the column.
func (ct ColumnType) Name() string {
if ct.NameValue.Valid {
return ct.NameValue.String
}
return ct.SQLColumnType.Name()
}
// DatabaseTypeName returns the database system name of the column type. If an empty
// string is returned, then the driver type name is not supported.
// Consult your driver documentation for a list of driver data types. Length specifiers
// are not included.
// Common type names include "VARCHAR", "TEXT", "NVARCHAR", "DECIMAL", "BOOL",
// "INT", and "BIGINT".
func (ct ColumnType) DatabaseTypeName() string {
if ct.DataTypeValue.Valid {
return ct.DataTypeValue.String
}
return ct.SQLColumnType.DatabaseTypeName()
}
// ColumnType returns the database type of the column. like `varchar(16)`
func (ct ColumnType) ColumnType() (columnType string, ok bool) {
return ct.ColumnTypeValue.String, ct.ColumnTypeValue.Valid
}
// PrimaryKey returns the column is primary key or not.
func (ct ColumnType) PrimaryKey() (isPrimaryKey bool, ok bool) {
return ct.PrimaryKeyValue.Bool, ct.PrimaryKeyValue.Valid
}
// AutoIncrement returns the column is auto increment or not.
func (ct ColumnType) AutoIncrement() (isAutoIncrement bool, ok bool) {
return ct.AutoIncrementValue.Bool, ct.AutoIncrementValue.Valid
}
// Length returns the column type length for variable length column types
func (ct ColumnType) Length() (length int64, ok bool) {
if ct.LengthValue.Valid {
return ct.LengthValue.Int64, true
}
return ct.SQLColumnType.Length()
}
// DecimalSize returns the scale and precision of a decimal type.
func (ct ColumnType) DecimalSize() (precision int64, scale int64, ok bool) {
if ct.DecimalSizeValue.Valid {
return ct.DecimalSizeValue.Int64, ct.ScaleValue.Int64, true
}
return ct.SQLColumnType.DecimalSize()
}
// Nullable reports whether the column may be null.
func (ct ColumnType) Nullable() (nullable bool, ok bool) {
if ct.NullableValue.Valid {
return ct.NullableValue.Bool, true
}
return ct.SQLColumnType.Nullable()
}
// Unique reports whether the column may be unique.
func (ct ColumnType) Unique() (unique bool, ok bool) {
return ct.UniqueValue.Bool, ct.UniqueValue.Valid
}
// ScanType returns a Go type suitable for scanning into using Rows.Scan.
func (ct ColumnType) ScanType() reflect.Type {
if ct.ScanTypeValue != nil {
return ct.ScanTypeValue
}
return ct.SQLColumnType.ScanType()
}
// Comment returns the comment of current column.
func (ct ColumnType) Comment() (value string, ok bool) {
return ct.CommentValue.String, ct.CommentValue.Valid
}
// DefaultValue returns the default value of current column.
func (ct ColumnType) DefaultValue() (value string, ok bool) {
return ct.DefaultValueValue.String, ct.DefaultValueValue.Valid
}

View File

@ -30,10 +30,12 @@ type Config struct {
gorm.Dialector
}
// GormDataTypeInterface gorm data type interface
type GormDataTypeInterface interface {
GormDBDataType(*gorm.DB, *schema.Field) string
}
// RunWithValue run migration with statement value
func (m Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error) error {
stmt := &gorm.Statement{DB: m.DB}
if m.DB.Statement != nil {
@ -50,6 +52,7 @@ func (m Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error
return fc(stmt)
}
// DataTypeOf return field's db data type
func (m Migrator) DataTypeOf(field *schema.Field) string {
fieldValue := reflect.New(field.IndirectFieldType)
if dataTyper, ok := fieldValue.Interface().(GormDataTypeInterface); ok {
@ -61,6 +64,7 @@ func (m Migrator) DataTypeOf(field *schema.Field) string {
return m.Dialector.DataTypeOf(field)
}
// FullDataTypeOf returns field's db full data type
func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) {
expr.SQL = m.DataTypeOf(field)
@ -85,7 +89,7 @@ func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) {
return
}
// AutoMigrate
// AutoMigrate auto migrate values
func (m Migrator) AutoMigrate(values ...interface{}) error {
for _, value := range m.ReorderModels(values, true) {
tx := m.DB.Session(&gorm.Session{})
@ -95,7 +99,10 @@ func (m Migrator) AutoMigrate(values ...interface{}) error {
}
} else {
if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) {
columnTypes, _ := m.DB.Migrator().ColumnTypes(value)
columnTypes, err := m.DB.Migrator().ColumnTypes(value)
if err != nil {
return err
}
for _, dbName := range stmt.Schema.DBNames {
field := stmt.Schema.FieldsByDBName[dbName]
@ -156,12 +163,14 @@ func (m Migrator) AutoMigrate(values ...interface{}) error {
return nil
}
// GetTables returns tables
func (m Migrator) GetTables() (tableList []string, err error) {
err = m.DB.Raw("SELECT TABLE_NAME FROM information_schema.tables where TABLE_SCHEMA=?", m.CurrentDatabase()).
Scan(&tableList).Error
return
}
// CreateTable create table in database for values
func (m Migrator) CreateTable(values ...interface{}) error {
for _, value := range m.ReorderModels(values, false) {
tx := m.DB.Session(&gorm.Session{})
@ -252,6 +261,7 @@ func (m Migrator) CreateTable(values ...interface{}) error {
return nil
}
// DropTable drop table for values
func (m Migrator) DropTable(values ...interface{}) error {
values = m.ReorderModels(values, false)
for i := len(values) - 1; i >= 0; i-- {
@ -265,6 +275,7 @@ func (m Migrator) DropTable(values ...interface{}) error {
return nil
}
// HasTable returns table exists or not for value, value could be a struct or string
func (m Migrator) HasTable(value interface{}) bool {
var count int64
@ -276,6 +287,7 @@ func (m Migrator) HasTable(value interface{}) bool {
return count > 0
}
// RenameTable rename table from oldName to newName
func (m Migrator) RenameTable(oldName, newName interface{}) error {
var oldTable, newTable interface{}
if v, ok := oldName.(string); ok {
@ -303,12 +315,13 @@ func (m Migrator) RenameTable(oldName, newName interface{}) error {
return m.DB.Exec("ALTER TABLE ? RENAME TO ?", oldTable, newTable).Error
}
func (m Migrator) AddColumn(value interface{}, field string) error {
// AddColumn create `name` column for value
func (m Migrator) AddColumn(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
// avoid using the same name field
f := stmt.Schema.LookUpField(field)
f := stmt.Schema.LookUpField(name)
if f == nil {
return fmt.Errorf("failed to look up field with name: %s", field)
return fmt.Errorf("failed to look up field with name: %s", name)
}
if !f.IgnoreMigration {
@ -322,6 +335,7 @@ func (m Migrator) AddColumn(value interface{}, field string) error {
})
}
// DropColumn drop value's `name` column
func (m Migrator) DropColumn(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if field := stmt.Schema.LookUpField(name); field != nil {
@ -334,10 +348,11 @@ func (m Migrator) DropColumn(value interface{}, name string) error {
})
}
// AlterColumn alter value's `field` column' type based on schema definition
func (m Migrator) AlterColumn(value interface{}, field string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if field := stmt.Schema.LookUpField(field); field != nil {
fileType := clause.Expr{SQL: m.DataTypeOf(field)}
fileType := m.FullDataTypeOf(field)
return m.DB.Exec(
"ALTER TABLE ? ALTER COLUMN ? TYPE ?",
m.CurrentTable(stmt), clause.Column{Name: field.DBName}, fileType,
@ -348,6 +363,7 @@ func (m Migrator) AlterColumn(value interface{}, field string) error {
})
}
// HasColumn check has column `field` for value or not
func (m Migrator) HasColumn(value interface{}, field string) bool {
var count int64
m.RunWithValue(value, func(stmt *gorm.Statement) error {
@ -366,6 +382,7 @@ func (m Migrator) HasColumn(value interface{}, field string) bool {
return count > 0
}
// RenameColumn rename value's field name from oldName to newName
func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if field := stmt.Schema.LookUpField(oldName); field != nil {
@ -383,6 +400,7 @@ func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error
})
}
// MigrateColumn migrate column
func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnType gorm.ColumnType) error {
// found, smart migrate
fullDataType := strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL)
@ -421,6 +439,30 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
}
}
// check unique
if unique, ok := columnType.Unique(); ok && unique != field.Unique {
// not primary key
if !field.PrimaryKey {
alterColumn = true
}
}
// check default value
if v, ok := columnType.DefaultValue(); ok && v != field.DefaultValue {
// not primary key
if !field.PrimaryKey {
alterColumn = true
}
}
// check comment
if comment, ok := columnType.Comment(); ok && comment != field.Comment {
// not primary key
if !field.PrimaryKey {
alterColumn = true
}
}
if alterColumn && !field.IgnoreMigration {
return m.DB.Migrator().AlterColumn(value, field.Name)
}
@ -448,7 +490,7 @@ func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
}
for _, c := range rawColumnTypes {
columnTypes = append(columnTypes, c)
columnTypes = append(columnTypes, ColumnType{SQLColumnType: c})
}
return
@ -457,10 +499,12 @@ func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
return columnTypes, execErr
}
// CreateView create view
func (m Migrator) CreateView(name string, option gorm.ViewOption) error {
return gorm.ErrNotImplemented
}
// DropView drop view
func (m Migrator) DropView(name string) error {
return gorm.ErrNotImplemented
}
@ -487,6 +531,7 @@ func buildConstraint(constraint *schema.Constraint) (sql string, results []inter
return
}
// GuessConstraintAndTable guess statement's constraint and it's table based on name
func (m Migrator) GuessConstraintAndTable(stmt *gorm.Statement, name string) (_ *schema.Constraint, _ *schema.Check, table string) {
if stmt.Schema == nil {
return nil, nil, stmt.Table
@ -531,6 +576,7 @@ func (m Migrator) GuessConstraintAndTable(stmt *gorm.Statement, name string) (_
return nil, nil, stmt.Schema.Table
}
// CreateConstraint create constraint
func (m Migrator) CreateConstraint(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
@ -554,6 +600,7 @@ func (m Migrator) CreateConstraint(value interface{}, name string) error {
})
}
// DropConstraint drop constraint
func (m Migrator) DropConstraint(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
@ -566,6 +613,7 @@ func (m Migrator) DropConstraint(value interface{}, name string) error {
})
}
// HasConstraint check has constraint or not
func (m Migrator) HasConstraint(value interface{}, name string) bool {
var count int64
m.RunWithValue(value, func(stmt *gorm.Statement) error {
@ -586,6 +634,7 @@ func (m Migrator) HasConstraint(value interface{}, name string) bool {
return count > 0
}
// BuildIndexOptions build index options
func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) {
for _, opt := range opts {
str := stmt.Quote(opt.DBName)
@ -607,10 +656,12 @@ func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statem
return
}
// BuildIndexOptionsInterface build index options interface
type BuildIndexOptionsInterface interface {
BuildIndexOptions([]schema.IndexOption, *gorm.Statement) []interface{}
}
// CreateIndex create index `name`
func (m Migrator) CreateIndex(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if idx := stmt.Schema.LookIndex(name); idx != nil {
@ -642,6 +693,7 @@ func (m Migrator) CreateIndex(value interface{}, name string) error {
})
}
// DropIndex drop index `name`
func (m Migrator) DropIndex(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if idx := stmt.Schema.LookIndex(name); idx != nil {
@ -652,6 +704,7 @@ func (m Migrator) DropIndex(value interface{}, name string) error {
})
}
// HasIndex check has index `name` or not
func (m Migrator) HasIndex(value interface{}, name string) bool {
var count int64
m.RunWithValue(value, func(stmt *gorm.Statement) error {
@ -669,6 +722,7 @@ func (m Migrator) HasIndex(value interface{}, name string) bool {
return count > 0
}
// RenameIndex rename index from oldName to newName
func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
return m.DB.Exec(
@ -678,6 +732,7 @@ func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error
})
}
// CurrentDatabase returns current database name
func (m Migrator) CurrentDatabase() (name string) {
m.DB.Raw("SELECT DATABASE()").Row().Scan(&name)
return
@ -704,7 +759,8 @@ func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []i
Statement: &gorm.Statement{DB: m.DB, Dest: value},
}
beDependedOn := map[*schema.Schema]bool{}
if err := dep.Parse(value); err != nil {
// support for special table name
if err := dep.ParseWithSpecialTableName(value, m.DB.Statement.Table); err != nil {
m.DB.Logger.Error(context.Background(), "failed to parse value %#v, got error %v", value, err)
}
if _, ok := parsedSchemas[dep.Statement.Schema]; ok {
@ -781,6 +837,7 @@ func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []i
return
}
// CurrentTable returns current statement's table expression
func (m Migrator) CurrentTable(stmt *gorm.Statement) interface{} {
if stmt.TableExpr != nil {
return *stmt.TableExpr

View File

@ -115,7 +115,7 @@ func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string, arg
}
type PreparedStmtTX struct {
*sql.Tx
Tx
PreparedStmtDB *PreparedStmtDB
}
@ -151,7 +151,7 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ..
func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) {
stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
if err == nil {
rows, err = tx.Tx.Stmt(stmt.Stmt).QueryContext(ctx, args...)
rows, err = tx.Tx.StmtContext(ctx, stmt.Stmt).QueryContext(ctx, args...)
if err != nil {
tx.PreparedStmtDB.Mux.Lock()
defer tx.PreparedStmtDB.Mux.Unlock()

147
scan.go
View File

@ -10,6 +10,7 @@ import (
"gorm.io/gorm/schema"
)
// prepareValues prepare values slice
func prepareValues(values []interface{}, db *DB, columnTypes []*sql.ColumnType, columns []string) {
if db.Statement.Schema != nil {
for idx, name := range columns {
@ -49,65 +50,56 @@ func scanIntoMap(mapValue map[string]interface{}, values []interface{}, columns
}
}
func (db *DB) scanIntoStruct(sch *schema.Schema, rows *sql.Rows, reflectValue reflect.Value, values []interface{}, columns []string, fields []*schema.Field, joinFields [][2]*schema.Field) {
for idx, column := range columns {
if sch == nil {
values[idx] = reflectValue.Interface()
} else if field := sch.LookUpField(column); field != nil && field.Readable {
values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface()
} else if names := strings.Split(column, "__"); len(names) > 1 {
if rel, ok := sch.Relationships.Relations[names[0]]; ok {
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface()
continue
}
func (db *DB) scanIntoStruct(rows Rows, reflectValue reflect.Value, values []interface{}, fields []*schema.Field, joinFields [][2]*schema.Field) {
for idx, field := range fields {
if field != nil {
values[idx] = field.NewValuePool.Get()
} else if len(fields) == 1 {
if reflectValue.CanAddr() {
values[idx] = reflectValue.Addr().Interface()
} else {
values[idx] = reflectValue.Interface()
}
values[idx] = &sql.RawBytes{}
} else if len(columns) == 1 {
sch = nil
values[idx] = reflectValue.Interface()
} else {
values[idx] = &sql.RawBytes{}
}
}
db.RowsAffected++
db.AddError(rows.Scan(values...))
if sch != nil {
for idx, column := range columns {
if field := sch.LookUpField(column); field != nil && field.Readable {
field.Set(reflectValue, values[idx])
} else if names := strings.Split(column, "__"); len(names) > 1 {
if rel, ok := sch.Relationships.Relations[names[0]]; ok {
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
relValue := rel.Field.ReflectValueOf(reflectValue)
value := reflect.ValueOf(values[idx]).Elem()
if relValue.Kind() == reflect.Ptr && relValue.IsNil() {
if value.IsNil() {
continue
}
relValue.Set(reflect.New(relValue.Type().Elem()))
}
field.Set(relValue, values[idx])
for idx, field := range fields {
if field != nil {
if len(joinFields) == 0 || joinFields[idx][0] == nil {
db.AddError(field.Set(db.Statement.Context, reflectValue, values[idx]))
} else {
relValue := joinFields[idx][0].ReflectValueOf(db.Statement.Context, reflectValue)
if relValue.Kind() == reflect.Ptr && relValue.IsNil() {
if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() {
continue
}
relValue.Set(reflect.New(relValue.Type().Elem()))
}
db.AddError(joinFields[idx][1].Set(db.Statement.Context, relValue, values[idx]))
}
// release data to pool
field.NewValuePool.Put(values[idx])
}
}
}
// ScanMode scan data mode
type ScanMode uint8
// scan modes
const (
ScanInitialized ScanMode = 1 << 0 // 1
ScanUpdate ScanMode = 1 << 1 // 2
ScanOnConflictDoNothing ScanMode = 1 << 2 // 4
)
func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
// Scan scan rows into db statement
func Scan(rows Rows, db *DB, mode ScanMode) {
var (
columns, _ = rows.Columns()
values = make([]interface{}, len(columns))
@ -138,7 +130,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
}
scanIntoMap(mapValue, values, columns)
}
case *[]map[string]interface{}, []map[string]interface{}:
case *[]map[string]interface{}:
columnTypes, _ := rows.ColumnTypes()
for initialized || rows.Next() {
prepareValues(values, db, columnTypes, columns)
@ -149,11 +141,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
mapValue := map[string]interface{}{}
scanIntoMap(mapValue, values, columns)
if values, ok := dest.([]map[string]interface{}); ok {
values = append(values, mapValue)
} else if values, ok := dest.(*[]map[string]interface{}); ok {
*values = append(*values, mapValue)
}
*dest = append(*dest, mapValue)
}
case *int, *int8, *int16, *int32, *int64,
*uint, *uint8, *uint16, *uint32, *uint64, *uintptr,
@ -168,10 +156,11 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
}
default:
var (
fields = make([]*schema.Field, len(columns))
joinFields [][2]*schema.Field
sch = db.Statement.Schema
reflectValue = db.Statement.ReflectValue
fields = make([]*schema.Field, len(columns))
selectedColumnsMap = make(map[string]int, len(columns))
joinFields [][2]*schema.Field
sch = db.Statement.Schema
reflectValue = db.Statement.ReflectValue
)
if reflectValue.Kind() == reflect.Interface {
@ -193,35 +182,49 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
sch, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy)
}
for idx, column := range columns {
if field := sch.LookUpField(column); field != nil && field.Readable {
fields[idx] = field
} else if names := strings.Split(column, "__"); len(names) > 1 {
if rel, ok := sch.Relationships.Relations[names[0]]; ok {
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
fields[idx] = field
if len(joinFields) == 0 {
joinFields = make([][2]*schema.Field, len(columns))
}
joinFields[idx] = [2]*schema.Field{rel.Field, field}
continue
}
}
values[idx] = &sql.RawBytes{}
} else {
values[idx] = &sql.RawBytes{}
}
}
if len(columns) == 1 {
// isPluck
// Is Pluck
if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); (reflectValueType != sch.ModelType && ok) || // is scanner
reflectValueType.Kind() != reflect.Struct || // is not struct
sch.ModelType.ConvertibleTo(schema.TimeReflectType) { // is time
sch = nil
}
}
// Not Pluck
if sch != nil {
for idx, column := range columns {
if field := sch.LookUpField(column); field != nil && field.Readable {
if curIndex, ok := selectedColumnsMap[column]; ok {
for fieldIndex, selectField := range sch.Fields[curIndex+1:] {
if selectField.DBName == column && selectField.Readable {
selectedColumnsMap[column] = curIndex + fieldIndex + 1
fields[idx] = selectField
break
}
}
} else {
fields[idx] = field
selectedColumnsMap[column] = idx
}
} else if names := strings.Split(column, "__"); len(names) > 1 {
if rel, ok := sch.Relationships.Relations[names[0]]; ok {
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
fields[idx] = field
if len(joinFields) == 0 {
joinFields = make([][2]*schema.Field, len(columns))
}
joinFields[idx] = [2]*schema.Field{rel.Field, field}
continue
}
}
values[idx] = &sql.RawBytes{}
} else {
values[idx] = &sql.RawBytes{}
}
}
}
}
switch reflectValue.Kind() {
@ -244,7 +247,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
elem = reflectValue.Index(int(db.RowsAffected))
if onConflictDonothing {
for _, field := range fields {
if _, ok := field.ValueOf(elem); !ok {
if _, ok := field.ValueOf(db.Statement.Context, elem); !ok {
db.RowsAffected++
goto BEGIN
}
@ -254,7 +257,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
elem = reflect.New(reflectValueType)
}
db.scanIntoStruct(sch, rows, elem, values, columns, fields, joinFields)
db.scanIntoStruct(rows, elem, values, fields, joinFields)
if !update {
if isPtr {
@ -270,7 +273,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
}
case reflect.Struct, reflect.Ptr:
if initialized || rows.Next() {
db.scanIntoStruct(sch, rows, reflectValue, values, columns, fields, joinFields)
db.scanIntoStruct(rows, reflectValue, values, fields, joinFields)
}
default:
db.AddError(rows.Scan(dest))

View File

@ -1,6 +1,7 @@
package schema
import (
"context"
"database/sql"
"database/sql/driver"
"fmt"
@ -11,15 +12,25 @@ import (
"time"
"github.com/jinzhu/now"
"gorm.io/gorm/clause"
"gorm.io/gorm/utils"
)
type DataType string
// special types' reflect type
var (
TimeReflectType = reflect.TypeOf(time.Time{})
TimePtrReflectType = reflect.TypeOf(&time.Time{})
ByteReflectType = reflect.TypeOf(uint8(0))
)
type TimeType int64
var TimeReflectType = reflect.TypeOf(time.Time{})
type (
// DataType GORM data type
DataType string
// TimeType GORM time type
TimeType int64
)
// GORM time types
const (
UnixTime TimeType = 1
UnixSecond TimeType = 2
@ -27,6 +38,7 @@ const (
UnixNanosecond TimeType = 4
)
// GORM fields types
const (
Bool DataType = "bool"
Int DataType = "int"
@ -37,6 +49,7 @@ const (
Bytes DataType = "bytes"
)
// Field is the representation of model schema's field
type Field struct {
Name string
DBName string
@ -49,9 +62,9 @@ type Field struct {
Creatable bool
Updatable bool
Readable bool
HasDefaultValue bool
AutoCreateTime TimeType
AutoUpdateTime TimeType
HasDefaultValue bool
DefaultValue string
DefaultValueInterface interface{}
NotNull bool
@ -60,6 +73,7 @@ type Field struct {
Size int
Precision int
Scale int
IgnoreMigration bool
FieldType reflect.Type
IndirectFieldType reflect.Type
StructField reflect.StructField
@ -68,27 +82,39 @@ type Field struct {
Schema *Schema
EmbeddedSchema *Schema
OwnerSchema *Schema
ReflectValueOf func(reflect.Value) reflect.Value
ValueOf func(reflect.Value) (value interface{}, zero bool)
Set func(reflect.Value, interface{}) error
IgnoreMigration bool
ReflectValueOf func(context.Context, reflect.Value) reflect.Value
ValueOf func(context.Context, reflect.Value) (value interface{}, zero bool)
Set func(context.Context, reflect.Value, interface{}) error
Serializer SerializerInterface
NewValuePool FieldNewValuePool
}
// ParseField parses reflect.StructField to Field
func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
var err error
var (
err error
tagSetting = ParseTagSetting(fieldStruct.Tag.Get("gorm"), ";")
)
field := &Field{
Name: fieldStruct.Name,
DBName: tagSetting["COLUMN"],
BindNames: []string{fieldStruct.Name},
FieldType: fieldStruct.Type,
IndirectFieldType: fieldStruct.Type,
StructField: fieldStruct,
Tag: fieldStruct.Tag,
TagSettings: tagSetting,
Schema: schema,
Creatable: true,
Updatable: true,
Readable: true,
Tag: fieldStruct.Tag,
TagSettings: ParseTagSetting(fieldStruct.Tag.Get("gorm"), ";"),
Schema: schema,
PrimaryKey: utils.CheckTruth(tagSetting["PRIMARYKEY"], tagSetting["PRIMARY_KEY"]),
AutoIncrement: utils.CheckTruth(tagSetting["AUTOINCREMENT"]),
HasDefaultValue: utils.CheckTruth(tagSetting["AUTOINCREMENT"]),
NotNull: utils.CheckTruth(tagSetting["NOT NULL"], tagSetting["NOTNULL"]),
Unique: utils.CheckTruth(tagSetting["UNIQUE"]),
Comment: tagSetting["COMMENT"],
AutoIncrementIncrement: 1,
}
@ -97,7 +123,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
}
fieldValue := reflect.New(field.IndirectFieldType)
// if field is valuer, used its value or first fields as data type
// if field is valuer, used its value or first field as data type
valuer, isValuer := fieldValue.Interface().(driver.Valuer)
if isValuer {
if _, ok := fieldValue.Interface().(GormDataTypeInterface); !ok {
@ -105,31 +131,37 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
fieldValue = reflect.ValueOf(v)
}
// Use the field struct's first field type as data type, e.g: use `string` for sql.NullString
var getRealFieldValue func(reflect.Value)
getRealFieldValue = func(v reflect.Value) {
rv := reflect.Indirect(v)
if rv.Kind() == reflect.Struct && !rv.Type().ConvertibleTo(TimeReflectType) {
for i := 0; i < rv.Type().NumField(); i++ {
newFieldType := rv.Type().Field(i).Type
var (
rv = reflect.Indirect(v)
rvType = rv.Type()
)
if rv.Kind() == reflect.Struct && !rvType.ConvertibleTo(TimeReflectType) {
for i := 0; i < rvType.NumField(); i++ {
for key, value := range ParseTagSetting(rvType.Field(i).Tag.Get("gorm"), ";") {
if _, ok := field.TagSettings[key]; !ok {
field.TagSettings[key] = value
}
}
}
for i := 0; i < rvType.NumField(); i++ {
newFieldType := rvType.Field(i).Type
for newFieldType.Kind() == reflect.Ptr {
newFieldType = newFieldType.Elem()
}
fieldValue = reflect.New(newFieldType)
if rv.Type() != reflect.Indirect(fieldValue).Type() {
if rvType != reflect.Indirect(fieldValue).Type() {
getRealFieldValue(fieldValue)
}
if fieldValue.IsValid() {
return
}
for key, value := range ParseTagSetting(field.IndirectFieldType.Field(i).Tag.Get("gorm"), ";") {
if _, ok := field.TagSettings[key]; !ok {
field.TagSettings[key] = value
}
}
}
}
}
@ -138,19 +170,23 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
}
}
if dbName, ok := field.TagSettings["COLUMN"]; ok {
field.DBName = dbName
}
if val, ok := field.TagSettings["PRIMARYKEY"]; ok && utils.CheckTruth(val) {
field.PrimaryKey = true
} else if val, ok := field.TagSettings["PRIMARY_KEY"]; ok && utils.CheckTruth(val) {
field.PrimaryKey = true
}
if val, ok := field.TagSettings["AUTOINCREMENT"]; ok && utils.CheckTruth(val) {
field.AutoIncrement = true
field.HasDefaultValue = true
if v, isSerializer := fieldValue.Interface().(SerializerInterface); isSerializer {
field.DataType = String
field.Serializer = v
} else {
var serializerName = field.TagSettings["JSON"]
if serializerName == "" {
serializerName = field.TagSettings["SERIALIZER"]
}
if serializerName != "" {
if serializer, ok := GetSerializer(serializerName); ok {
// Set default data type to string for serializer
field.DataType = String
field.Serializer = serializer
} else {
schema.err = fmt.Errorf("invalid serializer type %v", serializerName)
}
}
}
if num, ok := field.TagSettings["AUTOINCREMENTINCREMENT"]; ok {
@ -176,20 +212,6 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
field.Scale, _ = strconv.Atoi(s)
}
if val, ok := field.TagSettings["NOT NULL"]; ok && utils.CheckTruth(val) {
field.NotNull = true
} else if val, ok := field.TagSettings["NOTNULL"]; ok && utils.CheckTruth(val) {
field.NotNull = true
}
if val, ok := field.TagSettings["UNIQUE"]; ok && utils.CheckTruth(val) {
field.Unique = true
}
if val, ok := field.TagSettings["COMMENT"]; ok {
field.Comment = val
}
// default value is function or null or blank (primary keys)
field.DefaultValue = strings.TrimSpace(field.DefaultValue)
skipParseDefaultValue := strings.Contains(field.DefaultValue, "(") &&
@ -225,7 +247,6 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
}
case reflect.String:
field.DataType = String
if field.HasDefaultValue && !skipParseDefaultValue {
field.DefaultValue = strings.Trim(field.DefaultValue, "'")
field.DefaultValue = strings.Trim(field.DefaultValue, `"`)
@ -236,22 +257,25 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
field.DataType = Time
} else if fieldValue.Type().ConvertibleTo(TimeReflectType) {
field.DataType = Time
} else if fieldValue.Type().ConvertibleTo(reflect.TypeOf(&time.Time{})) {
} else if fieldValue.Type().ConvertibleTo(TimePtrReflectType) {
field.DataType = Time
}
if field.HasDefaultValue && !skipParseDefaultValue && field.DataType == Time {
if t, err := now.Parse(field.DefaultValue); err == nil {
field.DefaultValueInterface = t
}
}
case reflect.Array, reflect.Slice:
if reflect.Indirect(fieldValue).Type().Elem() == reflect.TypeOf(uint8(0)) {
if reflect.Indirect(fieldValue).Type().Elem() == ByteReflectType && field.DataType == "" {
field.DataType = Bytes
}
}
field.GORMDataType = field.DataType
if dataTyper, ok := fieldValue.Interface().(GormDataTypeInterface); ok {
field.DataType = DataType(dataTyper.GormDataType())
}
if v, ok := field.TagSettings["AUTOCREATETIME"]; ok || (field.Name == "CreatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) {
if v, ok := field.TagSettings["AUTOCREATETIME"]; (ok && utils.CheckTruth(v)) || (!ok && field.Name == "CreatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) {
if field.DataType == Time {
field.AutoCreateTime = UnixTime
} else if strings.ToUpper(v) == "NANO" {
@ -263,7 +287,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
}
}
if v, ok := field.TagSettings["AUTOUPDATETIME"]; ok || (field.Name == "UpdatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) {
if v, ok := field.TagSettings["AUTOUPDATETIME"]; (ok && utils.CheckTruth(v)) || (!ok && field.Name == "UpdatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) {
if field.DataType == Time {
field.AutoUpdateTime = UnixTime
} else if strings.ToUpper(v) == "NANO" {
@ -275,6 +299,10 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
}
}
if field.GORMDataType == "" {
field.GORMDataType = field.DataType
}
if val, ok := field.TagSettings["TYPE"]; ok {
switch DataType(strings.ToLower(val)) {
case Bool, Int, Uint, Float, String, Time, Bytes:
@ -284,10 +312,6 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
}
}
if field.GORMDataType == "" {
field.GORMDataType = field.DataType
}
if field.Size == 0 {
switch reflect.Indirect(fieldValue).Kind() {
case reflect.Int, reflect.Int64, reflect.Uint, reflect.Uint64, reflect.Float64:
@ -346,8 +370,9 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
}
}
if _, ok := field.TagSettings["EMBEDDED"]; field.GORMDataType != Time && field.GORMDataType != Bytes &&
(ok || (fieldStruct.Anonymous && !isValuer && (field.Creatable || field.Updatable || field.Readable))) {
// Normal anonymous field or having `EMBEDDED` tag
if _, ok := field.TagSettings["EMBEDDED"]; ok || (field.GORMDataType != Time && field.GORMDataType != Bytes && !isValuer &&
fieldStruct.Anonymous && (field.Creatable || field.Updatable || field.Readable)) {
kind := reflect.Indirect(fieldValue).Kind()
switch kind {
case reflect.Struct:
@ -410,31 +435,25 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
// create valuer, setter when parse struct
func (field *Field) setupValuerAndSetter() {
// ValueOf
// Setup NewValuePool
field.setupNewValuePool()
// ValueOf returns field's value and if it is zero
fieldIndex := field.StructField.Index[0]
switch {
case len(field.StructField.Index) == 1:
field.ValueOf = func(value reflect.Value) (interface{}, bool) {
fieldValue := reflect.Indirect(value).Field(field.StructField.Index[0])
return fieldValue.Interface(), fieldValue.IsZero()
}
case len(field.StructField.Index) == 2 && field.StructField.Index[0] >= 0:
field.ValueOf = func(value reflect.Value) (interface{}, bool) {
fieldValue := reflect.Indirect(value).Field(field.StructField.Index[0]).Field(field.StructField.Index[1])
case len(field.StructField.Index) == 1 && fieldIndex > 0:
field.ValueOf = func(ctx context.Context, value reflect.Value) (interface{}, bool) {
fieldValue := reflect.Indirect(value).Field(fieldIndex)
return fieldValue.Interface(), fieldValue.IsZero()
}
default:
field.ValueOf = func(value reflect.Value) (interface{}, bool) {
v := reflect.Indirect(value)
for _, idx := range field.StructField.Index {
if idx >= 0 {
v = v.Field(idx)
field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) {
v = reflect.Indirect(v)
for _, fieldIdx := range field.StructField.Index {
if fieldIdx >= 0 {
v = v.Field(fieldIdx)
} else {
v = v.Field(-idx - 1)
if v.Type().Elem().Kind() != reflect.Struct {
return nil, true
}
v = v.Field(-fieldIdx - 1)
if !v.IsNil() {
v = v.Elem()
@ -443,35 +462,52 @@ func (field *Field) setupValuerAndSetter() {
}
}
}
return v.Interface(), v.IsZero()
fv, zero := v.Interface(), v.IsZero()
return fv, zero
}
}
// ReflectValueOf
switch {
case len(field.StructField.Index) == 1:
field.ReflectValueOf = func(value reflect.Value) reflect.Value {
return reflect.Indirect(value).Field(field.StructField.Index[0])
if field.Serializer != nil {
oldValuerOf := field.ValueOf
field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) {
value, zero := oldValuerOf(ctx, v)
if zero {
return value, zero
}
s, ok := value.(SerializerValuerInterface)
if !ok {
s = field.Serializer
}
return &serializer{
Field: field,
SerializeValuer: s,
Destination: v,
Context: ctx,
fieldValue: value,
}, false
}
case len(field.StructField.Index) == 2 && field.StructField.Index[0] >= 0 && field.FieldType.Kind() != reflect.Ptr:
field.ReflectValueOf = func(value reflect.Value) reflect.Value {
return reflect.Indirect(value).Field(field.StructField.Index[0]).Field(field.StructField.Index[1])
}
// ReflectValueOf returns field's reflect value
switch {
case len(field.StructField.Index) == 1 && fieldIndex > 0:
field.ReflectValueOf = func(ctx context.Context, value reflect.Value) reflect.Value {
return reflect.Indirect(value).Field(fieldIndex)
}
default:
field.ReflectValueOf = func(value reflect.Value) reflect.Value {
v := reflect.Indirect(value)
field.ReflectValueOf = func(ctx context.Context, v reflect.Value) reflect.Value {
v = reflect.Indirect(v)
for idx, fieldIdx := range field.StructField.Index {
if fieldIdx >= 0 {
v = v.Field(fieldIdx)
} else {
v = v.Field(-fieldIdx - 1)
}
if v.Kind() == reflect.Ptr {
if v.Type().Elem().Kind() == reflect.Struct {
if v.IsNil() {
v.Set(reflect.New(v.Type().Elem()))
}
if v.IsNil() {
v.Set(reflect.New(v.Type().Elem()))
}
if idx < len(field.StructField.Index)-1 {
@ -483,22 +519,25 @@ func (field *Field) setupValuerAndSetter() {
}
}
fallbackSetter := func(value reflect.Value, v interface{}, setter func(reflect.Value, interface{}) error) (err error) {
fallbackSetter := func(ctx context.Context, value reflect.Value, v interface{}, setter func(context.Context, reflect.Value, interface{}) error) (err error) {
if v == nil {
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem())
} else {
reflectV := reflect.ValueOf(v)
// Optimal value type acquisition for v
reflectValType := reflectV.Type()
if reflectValType.AssignableTo(field.FieldType) {
field.ReflectValueOf(value).Set(reflectV)
if reflectV.Kind() == reflect.Ptr && reflectV.Elem().Kind() == reflect.Ptr {
reflectV = reflect.Indirect(reflectV)
}
field.ReflectValueOf(ctx, value).Set(reflectV)
return
} else if reflectValType.ConvertibleTo(field.FieldType) {
field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType))
field.ReflectValueOf(ctx, value).Set(reflectV.Convert(field.FieldType))
return
} else if field.FieldType.Kind() == reflect.Ptr {
fieldValue := field.ReflectValueOf(value)
fieldValue := field.ReflectValueOf(ctx, value)
fieldType := field.FieldType.Elem()
if reflectValType.AssignableTo(fieldType) {
@ -521,16 +560,19 @@ func (field *Field) setupValuerAndSetter() {
if reflectV.Kind() == reflect.Ptr {
if reflectV.IsNil() {
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem())
} else if reflectV.Type().Elem().AssignableTo(field.FieldType) {
field.ReflectValueOf(ctx, value).Set(reflectV.Elem())
return
} else {
err = setter(value, reflectV.Elem().Interface())
err = setter(ctx, value, reflectV.Elem().Interface())
}
} else if valuer, ok := v.(driver.Valuer); ok {
if v, err = valuer.Value(); err == nil {
err = setter(value, v)
err = setter(ctx, value, v)
}
} else {
return fmt.Errorf("failed to set value %+v to field %s", v, field.Name)
} else if _, ok := v.(clause.Expr); !ok {
return fmt.Errorf("failed to set value %#v to field %s", v, field.Name)
}
}
@ -540,191 +582,201 @@ func (field *Field) setupValuerAndSetter() {
// Set
switch field.FieldType.Kind() {
case reflect.Bool:
field.Set = func(value reflect.Value, v interface{}) error {
field.Set = func(ctx context.Context, value reflect.Value, v interface{}) error {
switch data := v.(type) {
case **bool:
if data != nil && *data != nil {
field.ReflectValueOf(ctx, value).SetBool(**data)
}
case bool:
field.ReflectValueOf(value).SetBool(data)
case *bool:
if data != nil {
field.ReflectValueOf(value).SetBool(*data)
} else {
field.ReflectValueOf(value).SetBool(false)
}
field.ReflectValueOf(ctx, value).SetBool(data)
case int64:
if data > 0 {
field.ReflectValueOf(value).SetBool(true)
} else {
field.ReflectValueOf(value).SetBool(false)
}
field.ReflectValueOf(ctx, value).SetBool(data > 0)
case string:
b, _ := strconv.ParseBool(data)
field.ReflectValueOf(value).SetBool(b)
field.ReflectValueOf(ctx, value).SetBool(b)
default:
return fallbackSetter(value, v, field.Set)
return fallbackSetter(ctx, value, v, field.Set)
}
return nil
}
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
field.Set = func(value reflect.Value, v interface{}) (err error) {
field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) {
switch data := v.(type) {
case **int64:
if data != nil && *data != nil {
field.ReflectValueOf(ctx, value).SetInt(**data)
}
case int64:
field.ReflectValueOf(value).SetInt(data)
field.ReflectValueOf(ctx, value).SetInt(data)
case int:
field.ReflectValueOf(value).SetInt(int64(data))
field.ReflectValueOf(ctx, value).SetInt(int64(data))
case int8:
field.ReflectValueOf(value).SetInt(int64(data))
field.ReflectValueOf(ctx, value).SetInt(int64(data))
case int16:
field.ReflectValueOf(value).SetInt(int64(data))
field.ReflectValueOf(ctx, value).SetInt(int64(data))
case int32:
field.ReflectValueOf(value).SetInt(int64(data))
field.ReflectValueOf(ctx, value).SetInt(int64(data))
case uint:
field.ReflectValueOf(value).SetInt(int64(data))
field.ReflectValueOf(ctx, value).SetInt(int64(data))
case uint8:
field.ReflectValueOf(value).SetInt(int64(data))
field.ReflectValueOf(ctx, value).SetInt(int64(data))
case uint16:
field.ReflectValueOf(value).SetInt(int64(data))
field.ReflectValueOf(ctx, value).SetInt(int64(data))
case uint32:
field.ReflectValueOf(value).SetInt(int64(data))
field.ReflectValueOf(ctx, value).SetInt(int64(data))
case uint64:
field.ReflectValueOf(value).SetInt(int64(data))
field.ReflectValueOf(ctx, value).SetInt(int64(data))
case float32:
field.ReflectValueOf(value).SetInt(int64(data))
field.ReflectValueOf(ctx, value).SetInt(int64(data))
case float64:
field.ReflectValueOf(value).SetInt(int64(data))
field.ReflectValueOf(ctx, value).SetInt(int64(data))
case []byte:
return field.Set(value, string(data))
return field.Set(ctx, value, string(data))
case string:
if i, err := strconv.ParseInt(data, 0, 64); err == nil {
field.ReflectValueOf(value).SetInt(i)
field.ReflectValueOf(ctx, value).SetInt(i)
} else {
return err
}
case time.Time:
if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond {
field.ReflectValueOf(value).SetInt(data.UnixNano())
field.ReflectValueOf(ctx, value).SetInt(data.UnixNano())
} else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond {
field.ReflectValueOf(value).SetInt(data.UnixNano() / 1e6)
field.ReflectValueOf(ctx, value).SetInt(data.UnixNano() / 1e6)
} else {
field.ReflectValueOf(value).SetInt(data.Unix())
field.ReflectValueOf(ctx, value).SetInt(data.Unix())
}
case *time.Time:
if data != nil {
if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond {
field.ReflectValueOf(value).SetInt(data.UnixNano())
field.ReflectValueOf(ctx, value).SetInt(data.UnixNano())
} else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond {
field.ReflectValueOf(value).SetInt(data.UnixNano() / 1e6)
field.ReflectValueOf(ctx, value).SetInt(data.UnixNano() / 1e6)
} else {
field.ReflectValueOf(value).SetInt(data.Unix())
field.ReflectValueOf(ctx, value).SetInt(data.Unix())
}
} else {
field.ReflectValueOf(value).SetInt(0)
field.ReflectValueOf(ctx, value).SetInt(0)
}
default:
return fallbackSetter(value, v, field.Set)
return fallbackSetter(ctx, value, v, field.Set)
}
return err
}
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
field.Set = func(value reflect.Value, v interface{}) (err error) {
field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) {
switch data := v.(type) {
case **uint64:
if data != nil && *data != nil {
field.ReflectValueOf(ctx, value).SetUint(**data)
}
case uint64:
field.ReflectValueOf(value).SetUint(data)
field.ReflectValueOf(ctx, value).SetUint(data)
case uint:
field.ReflectValueOf(value).SetUint(uint64(data))
field.ReflectValueOf(ctx, value).SetUint(uint64(data))
case uint8:
field.ReflectValueOf(value).SetUint(uint64(data))
field.ReflectValueOf(ctx, value).SetUint(uint64(data))
case uint16:
field.ReflectValueOf(value).SetUint(uint64(data))
field.ReflectValueOf(ctx, value).SetUint(uint64(data))
case uint32:
field.ReflectValueOf(value).SetUint(uint64(data))
field.ReflectValueOf(ctx, value).SetUint(uint64(data))
case int64:
field.ReflectValueOf(value).SetUint(uint64(data))
field.ReflectValueOf(ctx, value).SetUint(uint64(data))
case int:
field.ReflectValueOf(value).SetUint(uint64(data))
field.ReflectValueOf(ctx, value).SetUint(uint64(data))
case int8:
field.ReflectValueOf(value).SetUint(uint64(data))
field.ReflectValueOf(ctx, value).SetUint(uint64(data))
case int16:
field.ReflectValueOf(value).SetUint(uint64(data))
field.ReflectValueOf(ctx, value).SetUint(uint64(data))
case int32:
field.ReflectValueOf(value).SetUint(uint64(data))
field.ReflectValueOf(ctx, value).SetUint(uint64(data))
case float32:
field.ReflectValueOf(value).SetUint(uint64(data))
field.ReflectValueOf(ctx, value).SetUint(uint64(data))
case float64:
field.ReflectValueOf(value).SetUint(uint64(data))
field.ReflectValueOf(ctx, value).SetUint(uint64(data))
case []byte:
return field.Set(value, string(data))
return field.Set(ctx, value, string(data))
case time.Time:
if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond {
field.ReflectValueOf(value).SetUint(uint64(data.UnixNano()))
field.ReflectValueOf(ctx, value).SetUint(uint64(data.UnixNano()))
} else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond {
field.ReflectValueOf(value).SetUint(uint64(data.UnixNano() / 1e6))
field.ReflectValueOf(ctx, value).SetUint(uint64(data.UnixNano() / 1e6))
} else {
field.ReflectValueOf(value).SetUint(uint64(data.Unix()))
field.ReflectValueOf(ctx, value).SetUint(uint64(data.Unix()))
}
case string:
if i, err := strconv.ParseUint(data, 0, 64); err == nil {
field.ReflectValueOf(value).SetUint(i)
field.ReflectValueOf(ctx, value).SetUint(i)
} else {
return err
}
default:
return fallbackSetter(value, v, field.Set)
return fallbackSetter(ctx, value, v, field.Set)
}
return err
}
case reflect.Float32, reflect.Float64:
field.Set = func(value reflect.Value, v interface{}) (err error) {
field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) {
switch data := v.(type) {
case **float64:
if data != nil && *data != nil {
field.ReflectValueOf(ctx, value).SetFloat(**data)
}
case float64:
field.ReflectValueOf(value).SetFloat(data)
field.ReflectValueOf(ctx, value).SetFloat(data)
case float32:
field.ReflectValueOf(value).SetFloat(float64(data))
field.ReflectValueOf(ctx, value).SetFloat(float64(data))
case int64:
field.ReflectValueOf(value).SetFloat(float64(data))
field.ReflectValueOf(ctx, value).SetFloat(float64(data))
case int:
field.ReflectValueOf(value).SetFloat(float64(data))
field.ReflectValueOf(ctx, value).SetFloat(float64(data))
case int8:
field.ReflectValueOf(value).SetFloat(float64(data))
field.ReflectValueOf(ctx, value).SetFloat(float64(data))
case int16:
field.ReflectValueOf(value).SetFloat(float64(data))
field.ReflectValueOf(ctx, value).SetFloat(float64(data))
case int32:
field.ReflectValueOf(value).SetFloat(float64(data))
field.ReflectValueOf(ctx, value).SetFloat(float64(data))
case uint:
field.ReflectValueOf(value).SetFloat(float64(data))
field.ReflectValueOf(ctx, value).SetFloat(float64(data))
case uint8:
field.ReflectValueOf(value).SetFloat(float64(data))
field.ReflectValueOf(ctx, value).SetFloat(float64(data))
case uint16:
field.ReflectValueOf(value).SetFloat(float64(data))
field.ReflectValueOf(ctx, value).SetFloat(float64(data))
case uint32:
field.ReflectValueOf(value).SetFloat(float64(data))
field.ReflectValueOf(ctx, value).SetFloat(float64(data))
case uint64:
field.ReflectValueOf(value).SetFloat(float64(data))
field.ReflectValueOf(ctx, value).SetFloat(float64(data))
case []byte:
return field.Set(value, string(data))
return field.Set(ctx, value, string(data))
case string:
if i, err := strconv.ParseFloat(data, 64); err == nil {
field.ReflectValueOf(value).SetFloat(i)
field.ReflectValueOf(ctx, value).SetFloat(i)
} else {
return err
}
default:
return fallbackSetter(value, v, field.Set)
return fallbackSetter(ctx, value, v, field.Set)
}
return err
}
case reflect.String:
field.Set = func(value reflect.Value, v interface{}) (err error) {
field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) {
switch data := v.(type) {
case **string:
if data != nil && *data != nil {
field.ReflectValueOf(ctx, value).SetString(**data)
}
case string:
field.ReflectValueOf(value).SetString(data)
field.ReflectValueOf(ctx, value).SetString(data)
case []byte:
field.ReflectValueOf(value).SetString(string(data))
field.ReflectValueOf(ctx, value).SetString(string(data))
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
field.ReflectValueOf(value).SetString(utils.ToString(data))
field.ReflectValueOf(ctx, value).SetString(utils.ToString(data))
case float64, float32:
field.ReflectValueOf(value).SetString(fmt.Sprintf("%."+strconv.Itoa(field.Precision)+"f", data))
field.ReflectValueOf(ctx, value).SetString(fmt.Sprintf("%."+strconv.Itoa(field.Precision)+"f", data))
default:
return fallbackSetter(value, v, field.Set)
return fallbackSetter(ctx, value, v, field.Set)
}
return err
}
@ -732,41 +784,49 @@ func (field *Field) setupValuerAndSetter() {
fieldValue := reflect.New(field.FieldType)
switch fieldValue.Elem().Interface().(type) {
case time.Time:
field.Set = func(value reflect.Value, v interface{}) error {
field.Set = func(ctx context.Context, value reflect.Value, v interface{}) error {
switch data := v.(type) {
case **time.Time:
if data != nil && *data != nil {
field.Set(ctx, value, *data)
}
case time.Time:
field.ReflectValueOf(value).Set(reflect.ValueOf(v))
field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(v))
case *time.Time:
if data != nil {
field.ReflectValueOf(value).Set(reflect.ValueOf(data).Elem())
field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(data).Elem())
} else {
field.ReflectValueOf(value).Set(reflect.ValueOf(time.Time{}))
field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(time.Time{}))
}
case string:
if t, err := now.Parse(data); err == nil {
field.ReflectValueOf(value).Set(reflect.ValueOf(t))
field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(t))
} else {
return fmt.Errorf("failed to set string %v to time.Time field %s, failed to parse it as time, got error %v", v, field.Name, err)
}
default:
return fallbackSetter(value, v, field.Set)
return fallbackSetter(ctx, value, v, field.Set)
}
return nil
}
case *time.Time:
field.Set = func(value reflect.Value, v interface{}) error {
field.Set = func(ctx context.Context, value reflect.Value, v interface{}) error {
switch data := v.(type) {
case **time.Time:
if data != nil {
field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(*data))
}
case time.Time:
fieldValue := field.ReflectValueOf(value)
fieldValue := field.ReflectValueOf(ctx, value)
if fieldValue.IsNil() {
fieldValue.Set(reflect.New(field.FieldType.Elem()))
}
fieldValue.Elem().Set(reflect.ValueOf(v))
case *time.Time:
field.ReflectValueOf(value).Set(reflect.ValueOf(v))
field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(v))
case string:
if t, err := now.Parse(data); err == nil {
fieldValue := field.ReflectValueOf(value)
fieldValue := field.ReflectValueOf(ctx, value)
if fieldValue.IsNil() {
if v == "" {
return nil
@ -778,27 +838,27 @@ func (field *Field) setupValuerAndSetter() {
return fmt.Errorf("failed to set string %v to time.Time field %s, failed to parse it as time, got error %v", v, field.Name, err)
}
default:
return fallbackSetter(value, v, field.Set)
return fallbackSetter(ctx, value, v, field.Set)
}
return nil
}
default:
if _, ok := fieldValue.Elem().Interface().(sql.Scanner); ok {
// pointer scanner
field.Set = func(value reflect.Value, v interface{}) (err error) {
field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) {
reflectV := reflect.ValueOf(v)
if !reflectV.IsValid() {
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem())
} else if reflectV.Type().AssignableTo(field.FieldType) {
field.ReflectValueOf(value).Set(reflectV)
field.ReflectValueOf(ctx, value).Set(reflectV)
} else if reflectV.Kind() == reflect.Ptr {
if reflectV.IsNil() || !reflectV.IsValid() {
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem())
} else {
return field.Set(value, reflectV.Elem().Interface())
return field.Set(ctx, value, reflectV.Elem().Interface())
}
} else {
fieldValue := field.ReflectValueOf(value)
fieldValue := field.ReflectValueOf(ctx, value)
if fieldValue.IsNil() {
fieldValue.Set(reflect.New(field.FieldType.Elem()))
}
@ -813,32 +873,80 @@ func (field *Field) setupValuerAndSetter() {
}
} else if _, ok := fieldValue.Interface().(sql.Scanner); ok {
// struct scanner
field.Set = func(value reflect.Value, v interface{}) (err error) {
field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) {
reflectV := reflect.ValueOf(v)
if !reflectV.IsValid() {
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem())
} else if reflectV.Type().AssignableTo(field.FieldType) {
field.ReflectValueOf(value).Set(reflectV)
field.ReflectValueOf(ctx, value).Set(reflectV)
} else if reflectV.Kind() == reflect.Ptr {
if reflectV.IsNil() || !reflectV.IsValid() {
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem())
} else {
return field.Set(value, reflectV.Elem().Interface())
return field.Set(ctx, value, reflectV.Elem().Interface())
}
} else {
if valuer, ok := v.(driver.Valuer); ok {
v, _ = valuer.Value()
}
err = field.ReflectValueOf(value).Addr().Interface().(sql.Scanner).Scan(v)
err = field.ReflectValueOf(ctx, value).Addr().Interface().(sql.Scanner).Scan(v)
}
return
}
} else {
field.Set = func(value reflect.Value, v interface{}) (err error) {
return fallbackSetter(value, v, field.Set)
field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) {
return fallbackSetter(ctx, value, v, field.Set)
}
}
}
}
if field.Serializer != nil {
var (
oldFieldSetter = field.Set
sameElemType bool
sameType = field.FieldType == reflect.ValueOf(field.Serializer).Type()
)
if reflect.ValueOf(field.Serializer).Kind() == reflect.Ptr {
sameElemType = field.FieldType == reflect.ValueOf(field.Serializer).Type().Elem()
}
field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) {
if s, ok := v.(*serializer); ok {
if s.fieldValue != nil {
err = oldFieldSetter(ctx, value, s.fieldValue)
} else if err = s.Serializer.Scan(ctx, field, value, s.value); err == nil {
if sameElemType {
field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(s.Serializer).Elem())
s.Serializer = reflect.New(reflect.Indirect(reflect.ValueOf(field.Serializer)).Type()).Interface().(SerializerInterface)
} else if sameType {
field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(s.Serializer))
s.Serializer = reflect.New(reflect.Indirect(reflect.ValueOf(field.Serializer)).Type()).Interface().(SerializerInterface)
}
}
} else {
err = oldFieldSetter(ctx, value, v)
}
return
}
}
}
func (field *Field) setupNewValuePool() {
if field.Serializer != nil {
field.NewValuePool = &sync.Pool{
New: func() interface{} {
return &serializer{
Field: field,
Serializer: reflect.New(reflect.Indirect(reflect.ValueOf(field.Serializer)).Type()).Interface().(SerializerInterface),
}
},
}
}
if field.NewValuePool == nil {
field.NewValuePool = poolInitializer(reflect.PtrTo(field.IndirectFieldType))
}
}

View File

@ -1,6 +1,7 @@
package schema_test
import (
"context"
"database/sql"
"reflect"
"sync"
@ -57,7 +58,7 @@ func TestFieldValuerAndSetter(t *testing.T) {
}
for k, v := range newValues {
if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil {
if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil {
t.Errorf("no error should happen when assign value to field %v, but got %v", k, err)
}
}
@ -80,7 +81,7 @@ func TestFieldValuerAndSetter(t *testing.T) {
}
for k, v := range newValues2 {
if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil {
if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil {
t.Errorf("no error should happen when assign value to field %v, but got %v", k, err)
}
}
@ -132,7 +133,7 @@ func TestPointerFieldValuerAndSetter(t *testing.T) {
}
for k, v := range newValues {
if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil {
if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil {
t.Errorf("no error should happen when assign value to field %v, but got %v", k, err)
}
}
@ -151,7 +152,7 @@ func TestPointerFieldValuerAndSetter(t *testing.T) {
}
for k, v := range newValues2 {
if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil {
if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil {
t.Errorf("no error should happen when assign value to field %v, but got %v", k, err)
}
}
@ -202,7 +203,7 @@ func TestAdvancedDataTypeValuerAndSetter(t *testing.T) {
}
for k, v := range newValues {
if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil {
if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil {
t.Errorf("no error should happen when assign value to field %v, but got %v", k, err)
}
}
@ -219,7 +220,7 @@ func TestAdvancedDataTypeValuerAndSetter(t *testing.T) {
}
for k, v := range newValues2 {
if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil {
if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil {
t.Errorf("no error should happen when assign value to field %v, but got %v", k, err)
}
}

View File

@ -1,6 +1,7 @@
package schema
import (
"fmt"
"sort"
"strconv"
"strings"
@ -31,7 +32,12 @@ func (schema *Schema) ParseIndexes() map[string]Index {
for _, field := range schema.Fields {
if field.TagSettings["INDEX"] != "" || field.TagSettings["UNIQUEINDEX"] != "" {
for _, index := range parseFieldIndexes(field) {
fieldIndexes, err := parseFieldIndexes(field)
if err != nil {
schema.err = err
break
}
for _, index := range fieldIndexes {
idx := indexes[index.Name]
idx.Name = index.Name
if idx.Class == "" {
@ -82,18 +88,19 @@ func (schema *Schema) LookIndex(name string) *Index {
return nil
}
func parseFieldIndexes(field *Field) (indexes []Index) {
func parseFieldIndexes(field *Field) (indexes []Index, err error) {
for _, value := range strings.Split(field.Tag.Get("gorm"), ";") {
if value != "" {
v := strings.Split(value, ":")
k := strings.TrimSpace(strings.ToUpper(v[0]))
if k == "INDEX" || k == "UNIQUEINDEX" {
var (
name string
tag = strings.Join(v[1:], ":")
idx = strings.Index(tag, ",")
settings = ParseTagSetting(tag, ",")
length, _ = strconv.Atoi(settings["LENGTH"])
name string
tag = strings.Join(v[1:], ":")
idx = strings.Index(tag, ",")
tagSetting = strings.Join(strings.Split(tag, ",")[1:], ",")
settings = ParseTagSetting(tagSetting, ",")
length, _ = strconv.Atoi(settings["LENGTH"])
)
if idx == -1 {
@ -105,7 +112,20 @@ func parseFieldIndexes(field *Field) (indexes []Index) {
}
if name == "" {
name = field.Schema.namer.IndexName(field.Schema.Table, field.Name)
subName := field.Name
const key = "COMPOSITE"
if composite, found := settings[key]; found {
if len(composite) == 0 || composite == key {
err = fmt.Errorf(
"The composite tag of %s.%s cannot be empty",
field.Schema.Name,
field.Name)
return
}
subName = composite
}
name = field.Schema.namer.IndexName(
field.Schema.Table, subName)
}
if (k == "UNIQUEINDEX") || settings["UNIQUE"] != "" {
@ -137,5 +157,6 @@ func parseFieldIndexes(field *Field) (indexes []Index) {
}
}
err = nil
return
}

View File

@ -18,6 +18,37 @@ type UserIndex struct {
Age int64 `gorm:"index:profile,expression:ABS(age),option:WITH PARSER parser_name"`
OID int64 `gorm:"index:idx_id;index:idx_oid,unique"`
MemberNumber string `gorm:"index:idx_id,priority:1"`
Name7 string `gorm:"index:type"`
// Composite Index: Flattened structure.
Data0A string `gorm:"index:,composite:comp_id0"`
Data0B string `gorm:"index:,composite:comp_id0"`
// Composite Index: Nested structure.
Data1A string `gorm:"index:,composite:comp_id1"`
CompIdxLevel1C
// Composite Index: Unique and priority.
Data2A string `gorm:"index:,unique,composite:comp_id2,priority:2"`
CompIdxLevel2C
}
type CompIdxLevel1C struct {
CompIdxLevel1B
Data1C string `gorm:"index:,composite:comp_id1"`
}
type CompIdxLevel1B struct {
Data1B string `gorm:"index:,composite:comp_id1"`
}
type CompIdxLevel2C struct {
CompIdxLevel2B
Data2C string `gorm:"index:,unique,composite:comp_id2,priority:1"`
}
type CompIdxLevel2B struct {
Data2B string `gorm:"index:,unique,composite:comp_id2,priority:3"`
}
func TestParseIndex(t *testing.T) {
@ -78,6 +109,41 @@ func TestParseIndex(t *testing.T) {
Class: "UNIQUE",
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "OID"}}},
},
"type": {
Name: "type",
Type: "",
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name7"}}},
},
"idx_user_indices_comp_id0": {
Name: "idx_user_indices_comp_id0",
Type: "",
Fields: []schema.IndexOption{{
Field: &schema.Field{Name: "Data0A"},
}, {
Field: &schema.Field{Name: "Data0B"},
}},
},
"idx_user_indices_comp_id1": {
Name: "idx_user_indices_comp_id1",
Fields: []schema.IndexOption{{
Field: &schema.Field{Name: "Data1A"},
}, {
Field: &schema.Field{Name: "Data1B"},
}, {
Field: &schema.Field{Name: "Data1C"},
}},
},
"idx_user_indices_comp_id2": {
Name: "idx_user_indices_comp_id2",
Class: "UNIQUE",
Fields: []schema.IndexOption{{
Field: &schema.Field{Name: "Data2C"},
}, {
Field: &schema.Field{Name: "Data2A"},
}, {
Field: &schema.Field{Name: "Data2B"},
}},
},
}
indices := user.ParseIndexes()

View File

@ -4,22 +4,33 @@ import (
"gorm.io/gorm/clause"
)
// GormDataTypeInterface gorm data type interface
type GormDataTypeInterface interface {
GormDataType() string
}
// FieldNewValuePool field new scan value pool
type FieldNewValuePool interface {
Get() interface{}
Put(interface{})
}
// CreateClausesInterface create clauses interface
type CreateClausesInterface interface {
CreateClauses(*Field) []clause.Interface
}
// QueryClausesInterface query clauses interface
type QueryClausesInterface interface {
QueryClauses(*Field) []clause.Interface
}
// UpdateClausesInterface update clauses interface
type UpdateClausesInterface interface {
UpdateClauses(*Field) []clause.Interface
}
// DeleteClausesInterface delete clauses interface
type DeleteClausesInterface interface {
DeleteClauses(*Field) []clause.Interface
}

View File

@ -3,7 +3,6 @@ package schema
import (
"crypto/sha1"
"encoding/hex"
"fmt"
"regexp"
"strings"
"unicode/utf8"
@ -86,16 +85,16 @@ func (ns NamingStrategy) IndexName(table, column string) string {
}
func (ns NamingStrategy) formatName(prefix, table, name string) string {
formattedName := strings.Replace(strings.Join([]string{
formattedName := strings.ReplaceAll(strings.Join([]string{
prefix, table, name,
}, "_"), ".", "_", -1)
}, "_"), ".", "_")
if utf8.RuneCountInString(formattedName) > 64 {
h := sha1.New()
h.Write([]byte(formattedName))
bs := h.Sum(nil)
formattedName = fmt.Sprintf("%v%v%v", prefix, table, name)[0:56] + hex.EncodeToString(bs)[:8]
formattedName = formattedName[0:56] + hex.EncodeToString(bs)[:8]
}
return formattedName
}
@ -120,7 +119,13 @@ func (ns NamingStrategy) toDBName(name string) string {
}
if ns.NameReplacer != nil {
name = ns.NameReplacer.Replace(name)
tmpName := ns.NameReplacer.Replace(name)
if tmpName == "" {
return name
}
name = tmpName
}
if ns.NoLowerCase {
@ -168,7 +173,7 @@ func (ns NamingStrategy) toDBName(name string) string {
}
func (ns NamingStrategy) toSchemaName(name string) string {
result := strings.Replace(strings.Title(strings.Replace(name, "_", " ", -1)), " ", "", -1)
result := strings.ReplaceAll(strings.Title(strings.ReplaceAll(name, "_", " ")), " ", "")
for _, initialism := range commonInitialisms {
result = regexp.MustCompile(strings.Title(strings.ToLower(initialism))+"([A-Z]|$|_)").ReplaceAllString(result, initialism+"$1")
}

View File

@ -193,7 +193,18 @@ func TestFormatNameWithStringLongerThan64Characters(t *testing.T) {
ns := NamingStrategy{}
formattedName := ns.formatName("prefix", "table", "thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVeryLongString")
if formattedName != "prefixtablethisIsAVeryVeryVeryVeryVeryVeryVeryVeryVeryLo180f2c67" {
if formattedName != "prefix_table_thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVery180f2c67" {
t.Errorf("invalid formatted name generated, got %v", formattedName)
}
}
func TestReplaceEmptyTableName(t *testing.T) {
ns := NamingStrategy{
SingularTable: true,
NameReplacer: strings.NewReplacer("Model", ""),
}
tableName := ns.TableName("Model")
if tableName != "Model" {
t.Errorf("invalid table name generated, got %v", tableName)
}
}

19
schema/pool.go Normal file
View File

@ -0,0 +1,19 @@
package schema
import (
"reflect"
"sync"
)
// sync pools
var (
normalPool sync.Map
poolInitializer = func(reflectType reflect.Type) FieldNewValuePool {
v, _ := normalPool.LoadOrStore(reflectType, &sync.Pool{
New: func() interface{} {
return reflect.New(reflectType).Interface()
},
})
return v.(FieldNewValuePool)
}
)

View File

@ -1,6 +1,7 @@
package schema
import (
"context"
"fmt"
"reflect"
"strings"
@ -234,7 +235,8 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
Name: joinFieldName,
PkgPath: ownField.StructField.PkgPath,
Type: ownField.StructField.Type,
Tag: removeSettingFromTag(ownField.StructField.Tag, "column", "autoincrement", "index", "unique", "uniqueindex"),
Tag: removeSettingFromTag(appendSettingFromTag(ownField.StructField.Tag, "primaryKey"),
"column", "autoincrement", "index", "unique", "uniqueindex"),
})
}
@ -257,7 +259,8 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
Name: joinFieldName,
PkgPath: relField.StructField.PkgPath,
Type: relField.StructField.Type,
Tag: removeSettingFromTag(relField.StructField.Tag, "column", "autoincrement", "index", "unique", "uniqueindex"),
Tag: removeSettingFromTag(appendSettingFromTag(relField.StructField.Tag, "primaryKey"),
"column", "autoincrement", "index", "unique", "uniqueindex"),
})
}
@ -415,6 +418,10 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu
}
} else {
var primaryFields []*Field
var primarySchemaName = primarySchema.Name
if primarySchemaName == "" {
primarySchemaName = relation.FieldSchema.Name
}
if len(relation.primaryKeys) > 0 {
for _, primaryKey := range relation.primaryKeys {
@ -427,7 +434,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu
}
for _, primaryField := range primaryFields {
lookUpName := primarySchema.Name + primaryField.Name
lookUpName := primarySchemaName + primaryField.Name
if gl == guessBelongs {
lookUpName = field.Name + primaryField.Name
}
@ -576,7 +583,7 @@ func (rel *Relationship) ParseConstraint() *Constraint {
return &constraint
}
func (rel *Relationship) ToQueryConditions(reflectValue reflect.Value) (conds []clause.Expression) {
func (rel *Relationship) ToQueryConditions(ctx context.Context, reflectValue reflect.Value) (conds []clause.Expression) {
table := rel.FieldSchema.Table
foreignFields := []*Field{}
relForeignKeys := []string{}
@ -616,7 +623,7 @@ func (rel *Relationship) ToQueryConditions(reflectValue reflect.Value) (conds []
}
}
_, foreignValues := GetIdentityFieldValuesMap(reflectValue, foreignFields)
_, foreignValues := GetIdentityFieldValuesMap(ctx, reflectValue, foreignFields)
column, values := ToQueryValues(table, relForeignKeys, foreignValues)
conds = append(conds, clause.IN{Column: column, Values: values})

View File

@ -491,6 +491,26 @@ func TestEmbeddedRelation(t *testing.T) {
}
}
func TestVariableRelation(t *testing.T) {
var result struct {
User
}
checkStructRelation(t, &result, Relation{
Name: "Account", Type: schema.HasOne, Schema: "", FieldSchema: "Account",
References: []Reference{
{"ID", "", "UserID", "Account", "", true},
},
})
checkStructRelation(t, &result, Relation{
Name: "Company", Type: schema.BelongsTo, Schema: "", FieldSchema: "Company",
References: []Reference{
{"ID", "Company", "CompanyID", "", "", false},
},
})
}
func TestSameForeignKey(t *testing.T) {
type UserAux struct {
gorm.Model
@ -576,3 +596,39 @@ func TestHasManySameForeignKey(t *testing.T) {
References: []Reference{{"ID", "User", "UserRefer", "Profile", "", true}},
})
}
type Author struct {
gorm.Model
}
type Book struct {
gorm.Model
Author Author
AuthorID uint
}
func (Book) TableName() string {
return "my_schema.a_very_very_very_very_very_very_very_very_long_table_name"
}
func TestParseConstraintNameWithSchemaQualifiedLongTableName(t *testing.T) {
s, err := schema.Parse(
&Book{},
&sync.Map{},
schema.NamingStrategy{},
)
if err != nil {
t.Fatalf("Failed to parse schema")
}
expectedConstraintName := "fk_my_schema_a_very_very_very_very_very_very_very_very_l4db13eec"
constraint := s.Relationships.Relations["Author"].ParseConstraint()
if constraint.Name != expectedConstraintName {
t.Fatalf(
"expected constraint name %s, got %s",
expectedConstraintName,
constraint.Name,
)
}
}

View File

@ -1,6 +1,7 @@
package schema_test
import (
"context"
"fmt"
"reflect"
"strings"
@ -203,7 +204,7 @@ func checkSchemaRelation(t *testing.T, s *schema.Schema, relation Relation) {
func checkField(t *testing.T, s *schema.Schema, value reflect.Value, values map[string]interface{}) {
for k, v := range values {
t.Run("CheckField/"+k, func(t *testing.T) {
fv, _ := s.FieldsByDBName[k].ValueOf(value)
fv, _ := s.FieldsByDBName[k].ValueOf(context.Background(), value)
tests.AssertEqual(t, v, fv)
})
}

157
schema/serializer.go Normal file
View File

@ -0,0 +1,157 @@
package schema
import (
"bytes"
"context"
"database/sql"
"database/sql/driver"
"encoding/gob"
"encoding/json"
"fmt"
"reflect"
"strings"
"sync"
"time"
)
var serializerMap = sync.Map{}
// RegisterSerializer register serializer
func RegisterSerializer(name string, serializer SerializerInterface) {
serializerMap.Store(strings.ToLower(name), serializer)
}
// GetSerializer get serializer
func GetSerializer(name string) (serializer SerializerInterface, ok bool) {
v, ok := serializerMap.Load(strings.ToLower(name))
if ok {
serializer, ok = v.(SerializerInterface)
}
return serializer, ok
}
func init() {
RegisterSerializer("json", JSONSerializer{})
RegisterSerializer("unixtime", UnixSecondSerializer{})
RegisterSerializer("gob", GobSerializer{})
}
// Serializer field value serializer
type serializer struct {
Field *Field
Serializer SerializerInterface
SerializeValuer SerializerValuerInterface
Destination reflect.Value
Context context.Context
value interface{}
fieldValue interface{}
}
// Scan implements sql.Scanner interface
func (s *serializer) Scan(value interface{}) error {
s.value = value
return nil
}
// Value implements driver.Valuer interface
func (s serializer) Value() (driver.Value, error) {
return s.SerializeValuer.Value(s.Context, s.Field, s.Destination, s.fieldValue)
}
// SerializerInterface serializer interface
type SerializerInterface interface {
Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) error
SerializerValuerInterface
}
// SerializerValuerInterface serializer valuer interface
type SerializerValuerInterface interface {
Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (interface{}, error)
}
// JSONSerializer json serializer
type JSONSerializer struct {
}
// Scan implements serializer interface
func (JSONSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) (err error) {
fieldValue := reflect.New(field.FieldType)
if dbValue != nil {
var bytes []byte
switch v := dbValue.(type) {
case []byte:
bytes = v
case string:
bytes = []byte(v)
default:
return fmt.Errorf("failed to unmarshal JSONB value: %#v", dbValue)
}
err = json.Unmarshal(bytes, fieldValue.Interface())
}
field.ReflectValueOf(ctx, dst).Set(fieldValue.Elem())
return
}
// Value implements serializer interface
func (JSONSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) {
result, err := json.Marshal(fieldValue)
return string(result), err
}
// UnixSecondSerializer json serializer
type UnixSecondSerializer struct {
}
// Scan implements serializer interface
func (UnixSecondSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) (err error) {
t := sql.NullTime{}
if err = t.Scan(dbValue); err == nil && t.Valid {
err = field.Set(ctx, dst, t.Time.Unix())
}
return
}
// Value implements serializer interface
func (UnixSecondSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (result interface{}, err error) {
switch v := fieldValue.(type) {
case int64, int, uint, uint64, int32, uint32, int16, uint16, *int64, *int, *uint, *uint64, *int32, *uint32, *int16, *uint16:
result = time.Unix(reflect.Indirect(reflect.ValueOf(v)).Int(), 0)
default:
err = fmt.Errorf("invalid field type %#v for UnixSecondSerializer, only int, uint supported", v)
}
return
}
// GobSerializer gob serializer
type GobSerializer struct {
}
// Scan implements serializer interface
func (GobSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) (err error) {
fieldValue := reflect.New(field.FieldType)
if dbValue != nil {
var bytesValue []byte
switch v := dbValue.(type) {
case []byte:
bytesValue = v
default:
return fmt.Errorf("failed to unmarshal gob value: %#v", dbValue)
}
decoder := gob.NewDecoder(bytes.NewBuffer(bytesValue))
err = decoder.Decode(fieldValue.Interface())
}
field.ReflectValueOf(ctx, dst).Set(fieldValue.Elem())
return
}
// Value implements serializer interface
func (GobSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) {
buf := new(bytes.Buffer)
err := gob.NewEncoder(buf).Encode(fieldValue)
return buf.Bytes(), err
}

View File

@ -1,6 +1,8 @@
package schema
import (
"context"
"fmt"
"reflect"
"regexp"
"strings"
@ -58,14 +60,22 @@ func removeSettingFromTag(tag reflect.StructTag, names ...string) reflect.Struct
return tag
}
func appendSettingFromTag(tag reflect.StructTag, value string) reflect.StructTag {
t := tag.Get("gorm")
if strings.Contains(t, value) {
return tag
}
return reflect.StructTag(fmt.Sprintf(`gorm:"%s;%s"`, value, t))
}
// GetRelationsValues get relations's values from a reflect value
func GetRelationsValues(reflectValue reflect.Value, rels []*Relationship) (reflectResults reflect.Value) {
func GetRelationsValues(ctx context.Context, reflectValue reflect.Value, rels []*Relationship) (reflectResults reflect.Value) {
for _, rel := range rels {
reflectResults = reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(rel.FieldSchema.ModelType)), 0, 1)
appendToResults := func(value reflect.Value) {
if _, isZero := rel.Field.ValueOf(value); !isZero {
result := reflect.Indirect(rel.Field.ReflectValueOf(value))
if _, isZero := rel.Field.ValueOf(ctx, value); !isZero {
result := reflect.Indirect(rel.Field.ReflectValueOf(ctx, value))
switch result.Kind() {
case reflect.Struct:
reflectResults = reflect.Append(reflectResults, result.Addr())
@ -97,7 +107,7 @@ func GetRelationsValues(reflectValue reflect.Value, rels []*Relationship) (refle
}
// GetIdentityFieldValuesMap get identity map from fields
func GetIdentityFieldValuesMap(reflectValue reflect.Value, fields []*Field) (map[string][]reflect.Value, [][]interface{}) {
func GetIdentityFieldValuesMap(ctx context.Context, reflectValue reflect.Value, fields []*Field) (map[string][]reflect.Value, [][]interface{}) {
var (
results = [][]interface{}{}
dataResults = map[string][]reflect.Value{}
@ -110,7 +120,7 @@ func GetIdentityFieldValuesMap(reflectValue reflect.Value, fields []*Field) (map
results = [][]interface{}{make([]interface{}, len(fields))}
for idx, field := range fields {
results[0][idx], zero = field.ValueOf(reflectValue)
results[0][idx], zero = field.ValueOf(ctx, reflectValue)
notZero = notZero || !zero
}
@ -135,7 +145,7 @@ func GetIdentityFieldValuesMap(reflectValue reflect.Value, fields []*Field) (map
fieldValues := make([]interface{}, len(fields))
notZero = false
for idx, field := range fields {
fieldValues[idx], zero = field.ValueOf(elem)
fieldValues[idx], zero = field.ValueOf(ctx, elem)
notZero = notZero || !zero
}
@ -155,12 +165,12 @@ func GetIdentityFieldValuesMap(reflectValue reflect.Value, fields []*Field) (map
}
// GetIdentityFieldValuesMapFromValues get identity map from fields
func GetIdentityFieldValuesMapFromValues(values []interface{}, fields []*Field) (map[string][]reflect.Value, [][]interface{}) {
func GetIdentityFieldValuesMapFromValues(ctx context.Context, values []interface{}, fields []*Field) (map[string][]reflect.Value, [][]interface{}) {
resultsMap := map[string][]reflect.Value{}
results := [][]interface{}{}
for _, v := range values {
rm, rs := GetIdentityFieldValuesMap(reflect.Indirect(reflect.ValueOf(v)), fields)
rm, rs := GetIdentityFieldValuesMap(ctx, reflect.Indirect(reflect.ValueOf(v)), fields)
for k, v := range rm {
resultsMap[k] = append(resultsMap[k], v...)
}

View File

@ -104,9 +104,7 @@ func (sd SoftDeleteUpdateClause) MergeClause(*clause.Clause) {
func (sd SoftDeleteUpdateClause) ModifyStatement(stmt *Statement) {
if stmt.SQL.Len() == 0 && !stmt.Statement.Unscoped {
if _, ok := stmt.Clauses["WHERE"]; stmt.DB.AllowGlobalUpdate || ok {
SoftDeleteQueryClause(sd).ModifyStatement(stmt)
}
SoftDeleteQueryClause(sd).ModifyStatement(stmt)
}
}
@ -135,7 +133,7 @@ func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) {
stmt.SetColumn(sd.Field.DBName, curTime, true)
if stmt.Schema != nil {
_, queryValues := schema.GetIdentityFieldValuesMap(stmt.ReflectValue, stmt.Schema.PrimaryFields)
_, queryValues := schema.GetIdentityFieldValuesMap(stmt.Context, stmt.ReflectValue, stmt.Schema.PrimaryFields)
column, values := schema.ToQueryValues(stmt.Table, stmt.Schema.PrimaryFieldDBNames, queryValues)
if len(values) > 0 {
@ -143,7 +141,7 @@ func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) {
}
if stmt.ReflectValue.CanAddr() && stmt.Dest != stmt.Model && stmt.Model != nil {
_, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(stmt.Model), stmt.Schema.PrimaryFields)
_, queryValues = schema.GetIdentityFieldValuesMap(stmt.Context, reflect.ValueOf(stmt.Model), stmt.Schema.PrimaryFields)
column, values = schema.ToQueryValues(stmt.Table, stmt.Schema.PrimaryFieldDBNames, queryValues)
if len(values) > 0 {
@ -152,12 +150,7 @@ func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) {
}
}
if _, ok := stmt.Clauses["WHERE"]; !stmt.DB.AllowGlobalUpdate && !ok {
stmt.DB.AddError(ErrMissingWhereClause)
} else {
SoftDeleteQueryClause(sd).ModifyStatement(stmt)
}
SoftDeleteQueryClause(sd).ModifyStatement(stmt)
stmt.AddClauseIfNotExists(clause.Update{})
stmt.Build(stmt.DB.Callback().Update().Clauses...)
}

View File

@ -130,7 +130,7 @@ func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) {
writer.WriteByte('(')
for idx, d := range v {
if idx > 0 {
writer.WriteString(",")
writer.WriteByte(',')
}
stmt.QuoteTo(writer, d)
}
@ -143,7 +143,7 @@ func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) {
writer.WriteByte('(')
for idx, d := range v {
if idx > 0 {
writer.WriteString(",")
writer.WriteByte(',')
}
stmt.DB.Dialector.QuoteTo(writer, d)
}
@ -179,9 +179,7 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) {
} else {
stmt.AddVar(writer, v.GormValue(stmt.Context, stmt.DB))
}
case clause.Expr:
v.Build(stmt)
case *clause.Expr:
case clause.Expression:
v.Build(stmt)
case driver.Valuer:
stmt.Vars = append(stmt.Vars, v)
@ -314,6 +312,10 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
case clause.Expression:
conds = append(conds, v)
case *DB:
for _, scope := range v.Statement.scopes {
v = scope(v)
}
if cs, ok := v.Statement.Clauses["WHERE"]; ok {
if where, ok := cs.Expression.(clause.Where); ok {
if len(where.Exprs) == 1 {
@ -391,7 +393,7 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
for _, field := range s.Fields {
selected := selectedColumns[field.DBName] || selectedColumns[field.Name]
if selected || (!restricted && field.Readable) {
if v, isZero := field.ValueOf(reflectValue); !isZero || selected {
if v, isZero := field.ValueOf(stmt.Context, reflectValue); !isZero || selected {
if field.DBName != "" {
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v})
} else if field.DataType != "" {
@ -405,7 +407,7 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
for _, field := range s.Fields {
selected := selectedColumns[field.DBName] || selectedColumns[field.Name]
if selected || (!restricted && field.Readable) {
if v, isZero := field.ValueOf(reflectValue.Index(i)); !isZero || selected {
if v, isZero := field.ValueOf(stmt.Context, reflectValue.Index(i)); !isZero || selected {
if field.DBName != "" {
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v})
} else if field.DataType != "" {
@ -564,7 +566,7 @@ func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks .
switch destValue.Kind() {
case reflect.Struct:
field.Set(destValue, value)
stmt.AddError(field.Set(stmt.Context, destValue, value))
default:
stmt.AddError(ErrInvalidData)
}
@ -574,10 +576,10 @@ func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks .
case reflect.Slice, reflect.Array:
if len(fromCallbacks) > 0 {
for i := 0; i < stmt.ReflectValue.Len(); i++ {
field.Set(stmt.ReflectValue.Index(i), value)
stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue.Index(i), value))
}
} else {
field.Set(stmt.ReflectValue.Index(stmt.CurDestIndex), value)
stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue.Index(stmt.CurDestIndex), value))
}
case reflect.Struct:
if !stmt.ReflectValue.CanAddr() {
@ -585,7 +587,7 @@ func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks .
return
}
field.Set(stmt.ReflectValue, value)
stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, value))
}
} else {
stmt.AddError(ErrInvalidField)
@ -605,12 +607,12 @@ func (stmt *Statement) Changed(fields ...string) bool {
selectColumns, restricted := stmt.SelectAndOmitColumns(false, true)
changed := func(field *schema.Field) bool {
fieldValue, _ := field.ValueOf(modelValue)
fieldValue, _ := field.ValueOf(stmt.Context, modelValue)
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
if v, ok := stmt.Dest.(map[string]interface{}); ok {
if fv, ok := v[field.Name]; ok {
if mv, mok := stmt.Dest.(map[string]interface{}); mok {
if fv, ok := mv[field.Name]; ok {
return !utils.AssertEqual(fv, fieldValue)
} else if fv, ok := v[field.DBName]; ok {
} else if fv, ok := mv[field.DBName]; ok {
return !utils.AssertEqual(fv, fieldValue)
}
} else {
@ -619,7 +621,10 @@ func (stmt *Statement) Changed(fields ...string) bool {
destValue = destValue.Elem()
}
changedValue, zero := field.ValueOf(destValue)
changedValue, zero := field.ValueOf(stmt.Context, destValue)
if v {
return !utils.AssertEqual(changedValue, fieldValue)
}
return !zero && !utils.AssertEqual(changedValue, fieldValue)
}
}

View File

@ -220,3 +220,67 @@ func TestFullSaveAssociations(t *testing.T) {
t.Errorf("Failed to preload AppliesToProduct")
}
}
func TestSaveBelongsCircularReference(t *testing.T) {
parent := Parent{}
DB.Create(&parent)
child := Child{ParentID: &parent.ID, Parent: &parent}
DB.Create(&child)
parent.FavChildID = child.ID
parent.FavChild = &child
DB.Save(&parent)
var parent1 Parent
DB.First(&parent1, parent.ID)
AssertObjEqual(t, parent, parent1, "ID", "FavChildID")
// Save and Updates is the same
DB.Updates(&parent)
DB.First(&parent1, parent.ID)
AssertObjEqual(t, parent, parent1, "ID", "FavChildID")
}
func TestSaveHasManyCircularReference(t *testing.T) {
parent := Parent{}
DB.Create(&parent)
child := Child{ParentID: &parent.ID, Parent: &parent, Name: "HasManyCircularReference"}
child1 := Child{ParentID: &parent.ID, Parent: &parent, Name: "HasManyCircularReference1"}
parent.Children = []*Child{&child, &child1}
DB.Save(&parent)
var children []*Child
DB.Where("parent_id = ?", parent.ID).Find(&children)
if len(children) != len(parent.Children) ||
children[0].ID != parent.Children[0].ID ||
children[1].ID != parent.Children[1].ID {
t.Errorf("circular reference children save not equal children:%v parent.Children:%v",
children, parent.Children)
}
}
func TestAssociationError(t *testing.T) {
user := *GetUser("TestAssociationError", Config{Pets: 2, Company: true, Account: true, Languages: 2})
DB.Create(&user)
var user1 User
DB.Preload("Company").Preload("Pets").Preload("Account").Preload("Languages").First(&user1)
var emptyUser User
var err error
// belongs to
err = DB.Model(&emptyUser).Association("Company").Delete(&user1.Company)
AssertEqual(t, err, gorm.ErrPrimaryKeyRequired)
// has many
err = DB.Model(&emptyUser).Association("Pets").Delete(&user1.Pets)
AssertEqual(t, err, gorm.ErrPrimaryKeyRequired)
// has one
err = DB.Model(&emptyUser).Association("Account").Delete(&user1.Account)
AssertEqual(t, err, gorm.ErrPrimaryKeyRequired)
// many to many
err = DB.Model(&emptyUser).Association("Languages").Delete(&user1.Languages)
AssertEqual(t, err, gorm.ErrPrimaryKeyRequired)
}

View File

@ -38,6 +38,7 @@ func c2(*gorm.DB) {}
func c3(*gorm.DB) {}
func c4(*gorm.DB) {}
func c5(*gorm.DB) {}
func c6(*gorm.DB) {}
func TestCallbacks(t *testing.T) {
type callback struct {
@ -168,3 +169,37 @@ func TestCallbacks(t *testing.T) {
}
}
}
func TestPluginCallbacks(t *testing.T) {
db, _ := gorm.Open(nil, nil)
createCallback := db.Callback().Create()
createCallback.Before("*").Register("plugin_1_fn1", c1)
createCallback.After("*").Register("plugin_1_fn2", c2)
if ok, msg := assertCallbacks(createCallback, []string{"c1", "c2"}); !ok {
t.Errorf("callbacks tests failed, got %v", msg)
}
// plugin 2
createCallback.Before("*").Register("plugin_2_fn1", c3)
if ok, msg := assertCallbacks(createCallback, []string{"c3", "c1", "c2"}); !ok {
t.Errorf("callbacks tests failed, got %v", msg)
}
createCallback.After("*").Register("plugin_2_fn2", c4)
if ok, msg := assertCallbacks(createCallback, []string{"c3", "c1", "c2", "c4"}); !ok {
t.Errorf("callbacks tests failed, got %v", msg)
}
// plugin 3
createCallback.Before("*").Register("plugin_3_fn1", c5)
if ok, msg := assertCallbacks(createCallback, []string{"c5", "c3", "c1", "c2", "c4"}); !ok {
t.Errorf("callbacks tests failed, got %v", msg)
}
createCallback.After("*").Register("plugin_3_fn2", c6)
if ok, msg := assertCallbacks(createCallback, []string{"c5", "c3", "c1", "c2", "c4", "c6"}); !ok {
t.Errorf("callbacks tests failed, got %v", msg)
}
}

View File

@ -9,7 +9,7 @@ import (
)
func TestWithSingleConnection(t *testing.T) {
var expectedName = "test"
expectedName := "test"
var actualName string
setSQL, getSQL := getSetSQL(DB.Dialector.Name())
@ -27,7 +27,6 @@ func TestWithSingleConnection(t *testing.T) {
}
return nil
})
if err != nil {
t.Errorf(fmt.Sprintf("WithSingleConnection should work, but got err %v", err))
}

171
tests/connpool_test.go Normal file
View File

@ -0,0 +1,171 @@
package tests_test
import (
"context"
"database/sql"
"os"
"reflect"
"testing"
"gorm.io/driver/mysql"
"gorm.io/gorm"
. "gorm.io/gorm/utils/tests"
)
type wrapperTx struct {
*sql.Tx
conn *wrapperConnPool
}
func (c *wrapperTx) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) {
c.conn.got = append(c.conn.got, query)
return c.Tx.PrepareContext(ctx, query)
}
func (c *wrapperTx) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
c.conn.got = append(c.conn.got, query)
return c.Tx.ExecContext(ctx, query, args...)
}
func (c *wrapperTx) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
c.conn.got = append(c.conn.got, query)
return c.Tx.QueryContext(ctx, query, args...)
}
func (c *wrapperTx) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
c.conn.got = append(c.conn.got, query)
return c.Tx.QueryRowContext(ctx, query, args...)
}
type wrapperConnPool struct {
db *sql.DB
got []string
expect []string
}
func (c *wrapperConnPool) Ping() error {
return c.db.Ping()
}
// If you use BeginTx returned *sql.Tx as shown below then you can't record queries in a transaction.
// func (c *wrapperConnPool) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) {
// return c.db.BeginTx(ctx, opts)
// }
// You should use BeginTx returned gorm.Tx which could wrap *sql.Tx then you can record all queries.
func (c *wrapperConnPool) BeginTx(ctx context.Context, opts *sql.TxOptions) (gorm.ConnPool, error) {
tx, err := c.db.BeginTx(ctx, opts)
if err != nil {
return nil, err
}
return &wrapperTx{Tx: tx, conn: c}, nil
}
func (c *wrapperConnPool) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) {
c.got = append(c.got, query)
return c.db.PrepareContext(ctx, query)
}
func (c *wrapperConnPool) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
c.got = append(c.got, query)
return c.db.ExecContext(ctx, query, args...)
}
func (c *wrapperConnPool) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
c.got = append(c.got, query)
return c.db.QueryContext(ctx, query, args...)
}
func (c *wrapperConnPool) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
c.got = append(c.got, query)
return c.db.QueryRowContext(ctx, query, args...)
}
func TestConnPoolWrapper(t *testing.T) {
dialect := os.Getenv("GORM_DIALECT")
if dialect != "mysql" {
t.SkipNow()
}
dbDSN := os.Getenv("GORM_DSN")
if dbDSN == "" {
dbDSN = "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True&loc=Local"
}
nativeDB, err := sql.Open("mysql", dbDSN)
if err != nil {
t.Fatalf("Should open db success, but got %v", err)
}
conn := &wrapperConnPool{
db: nativeDB,
expect: []string{
"SELECT VERSION()",
"INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)",
"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1",
"INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)",
"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1",
"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1",
"INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)",
"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1",
"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1",
},
}
defer func() {
if !reflect.DeepEqual(conn.got, conn.expect) {
t.Errorf("expect %#v but got %#v", conn.expect, conn.got)
}
}()
db, err := gorm.Open(mysql.New(mysql.Config{Conn: conn}))
if err != nil {
t.Fatalf("Should open db success, but got %v", err)
}
tx := db.Begin()
user := *GetUser("transaction", Config{})
if err = tx.Save(&user).Error; err != nil {
t.Fatalf("No error should raise, but got %v", err)
}
if err = tx.First(&User{}, "name = ?", "transaction").Error; err != nil {
t.Fatalf("Should find saved record, but got %v", err)
}
user1 := *GetUser("transaction1-1", Config{})
if err = tx.Save(&user1).Error; err != nil {
t.Fatalf("No error should raise, but got %v", err)
}
if err = tx.First(&User{}, "name = ?", user1.Name).Error; err != nil {
t.Fatalf("Should find saved record, but got %v", err)
}
if sqlTx, ok := tx.Statement.ConnPool.(gorm.TxCommitter); !ok || sqlTx == nil {
t.Fatalf("Should return the underlying sql.Tx")
}
tx.Rollback()
if err = db.First(&User{}, "name = ?", "transaction").Error; err == nil {
t.Fatalf("Should not find record after rollback, but got %v", err)
}
txDB := db.Where("fake_name = ?", "fake_name")
tx2 := txDB.Session(&gorm.Session{NewDB: true}).Begin()
user2 := *GetUser("transaction-2", Config{})
if err = tx2.Save(&user2).Error; err != nil {
t.Fatalf("No error should raise, but got %v", err)
}
if err = tx2.First(&User{}, "name = ?", "transaction-2").Error; err != nil {
t.Fatalf("Should find saved record, but got %v", err)
}
tx2.Commit()
if err = db.First(&User{}, "name = ?", "transaction-2").Error; err != nil {
t.Fatalf("Should be able to find committed record, but got %v", err)
}
}

View File

@ -144,4 +144,22 @@ func TestCount(t *testing.T) {
if err := DB.Model(&User{}).Where("name = ?", "count-4").Group("name").Count(&count11).Error; err != nil || count11 != 1 {
t.Fatalf("Count should be 3, but got count: %v err %v", count11, err)
}
var count12 int64
if err := DB.Table("users").
Where("name in ?", []string{user1.Name, user2.Name, user3.Name}).
Preload("Toys", func(db *gorm.DB) *gorm.DB {
return db.Table("toys").Select("name")
}).Count(&count12).Error; err == nil {
t.Errorf("error should raise when using preload without schema")
}
var count13 int64
if err := DB.Model(User{}).
Where("name in ?", []string{user1.Name, user2.Name, user3.Name}).
Preload("Toys", func(db *gorm.DB) *gorm.DB {
return db.Table("toys").Select("name")
}).Count(&count13).Error; err != nil {
t.Errorf("no error should raise when using count with preload, but got %v", err)
}
}

View File

@ -123,7 +123,7 @@ func TestCreateFromMap(t *testing.T) {
{"name": "create_from_map_3", "Age": 20},
}
if err := DB.Model(&User{}).Create(datas).Error; err != nil {
if err := DB.Model(&User{}).Create(&datas).Error; err != nil {
t.Fatalf("failed to create data from slice of map, got error: %v", err)
}
@ -526,3 +526,17 @@ func TestCreateNilPointer(t *testing.T) {
t.Fatalf("it is not ErrInvalidValue")
}
}
func TestFirstOrCreateRowsAffected(t *testing.T) {
user := User{Name: "TestFirstOrCreateRowsAffected"}
res := DB.FirstOrCreate(&user, "name = ?", user.Name)
if res.Error != nil || res.RowsAffected != 1 {
t.Fatalf("first or create rows affect err:%v rows:%d", res.Error, res.RowsAffected)
}
res = DB.FirstOrCreate(&user, "name = ?", user.Name)
if res.Error != nil || res.RowsAffected != 0 {
t.Fatalf("first or create rows affect err:%v rows:%d", res.Error, res.RowsAffected)
}
}

View File

@ -2,6 +2,7 @@ package tests_test
import (
"testing"
"time"
"gorm.io/gorm"
)
@ -9,12 +10,13 @@ import (
func TestDefaultValue(t *testing.T) {
type Harumph struct {
gorm.Model
Email string `gorm:"not null;index:,unique"`
Name string `gorm:"notNull;default:foo"`
Name2 string `gorm:"size:233;not null;default:'foo'"`
Name3 string `gorm:"size:233;notNull;default:''"`
Age int `gorm:"default:18"`
Enabled bool `gorm:"default:true"`
Email string `gorm:"not null;index:,unique"`
Name string `gorm:"notNull;default:foo"`
Name2 string `gorm:"size:233;not null;default:'foo'"`
Name3 string `gorm:"size:233;notNull;default:''"`
Age int `gorm:"default:18"`
Created time.Time `gorm:"default:2000-01-02"`
Enabled bool `gorm:"default:true"`
}
DB.Migrator().DropTable(&Harumph{})
@ -26,14 +28,14 @@ func TestDefaultValue(t *testing.T) {
harumph := Harumph{Email: "hello@gorm.io"}
if err := DB.Create(&harumph).Error; err != nil {
t.Fatalf("Failed to create data with default value, got error: %v", err)
} else if harumph.Name != "foo" || harumph.Name2 != "foo" || harumph.Name3 != "" || harumph.Age != 18 || !harumph.Enabled {
} else if harumph.Name != "foo" || harumph.Name2 != "foo" || harumph.Name3 != "" || harumph.Age != 18 || !harumph.Enabled || harumph.Created.Format("20060102") != "20000102" {
t.Fatalf("Failed to create data with default value, got: %+v", harumph)
}
var result Harumph
if err := DB.First(&result, "email = ?", "hello@gorm.io").Error; err != nil {
t.Fatalf("Failed to find created data, got error: %v", err)
} else if result.Name != "foo" || result.Name2 != "foo" || result.Name3 != "" || result.Age != 18 || !result.Enabled {
} else if result.Name != "foo" || result.Name2 != "foo" || result.Name3 != "" || result.Age != 18 || !result.Enabled || result.Created.Format("20060102") != "20000102" {
t.Fatalf("Failed to find created data with default data, got %+v", result)
}
}

View File

@ -2,7 +2,7 @@ version: '3'
services:
mysql:
image: 'mysql:latest'
image: 'mysql/mysql-server:latest'
ports:
- 9910:3306
environment:
@ -20,7 +20,7 @@ services:
- POSTGRES_USER=gorm
- POSTGRES_PASSWORD=gorm
mssql:
image: 'mcmoe/mssqldocker:latest'
image: '${MSSQL_IMAGE:-mcmoe/mssqldocker}:latest'
ports:
- 9930:1433
environment:

View File

@ -3,16 +3,16 @@ module gorm.io/gorm/tests
go 1.14
require (
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect
github.com/google/uuid v1.3.0
github.com/jackc/pgx/v4 v4.14.1 // indirect
github.com/jinzhu/now v1.1.4
github.com/lib/pq v1.10.4
golang.org/x/crypto v0.0.0-20211209193657-4570a0811e8b // indirect
gorm.io/driver/mysql v1.2.3
gorm.io/driver/postgres v1.2.3
gorm.io/driver/sqlite v1.2.6
gorm.io/driver/sqlserver v1.2.1
gorm.io/gorm v1.22.4
github.com/jinzhu/now v1.1.5
github.com/lib/pq v1.10.5
golang.org/x/crypto v0.0.0-20220411220226-7b82a4e95df4 // indirect
gorm.io/driver/mysql v1.3.3
gorm.io/driver/postgres v1.3.5
gorm.io/driver/sqlite v1.3.2
gorm.io/driver/sqlserver v1.3.2
gorm.io/gorm v1.23.4
)
replace gorm.io/gorm => ../

View File

@ -19,6 +19,7 @@ type Config struct {
Team int
Languages int
Friends int
NamedPet bool
}
func GetUser(name string, config Config) *User {
@ -65,6 +66,10 @@ func GetUser(name string, config Config) *User {
user.Friends = append(user.Friends, GetUser(name+"_friend_"+strconv.Itoa(i+1), Config{}))
}
if config.NamedPet {
user.NamedPet = &Pet{Name: name + "_namepet"}
}
return &user
}

View File

@ -375,13 +375,19 @@ func TestSetColumn(t *testing.T) {
t.Errorf("invalid data after update, got %+v", product)
}
// Code changed, price should changed
DB.Model(&product).Select("Name", "Code", "Price").Updates(Product3{Name: "Product New4", Code: ""})
if product.Name != "Product New4" || product.Price != 320 || product.Code != "" {
t.Errorf("invalid data after update, got %+v", product)
}
DB.Model(&product).UpdateColumns(Product3{Code: "L1215"})
if product.Price != 270 || product.Code != "L1215" {
if product.Price != 320 || product.Code != "L1215" {
t.Errorf("invalid data after update, got %+v", product)
}
DB.Model(&product).Session(&gorm.Session{SkipHooks: true}).Updates(Product3{Code: "L1216"})
if product.Price != 270 || product.Code != "L1216" {
if product.Price != 320 || product.Code != "L1216" {
t.Errorf("invalid data after update, got %+v", product)
}
@ -460,8 +466,9 @@ type Product4 struct {
type ProductItem struct {
gorm.Model
Code string
Product4ID uint
Code string
Product4ID uint
AfterFindCallTimes int
}
func (pi ProductItem) BeforeCreate(*gorm.DB) error {
@ -471,6 +478,11 @@ func (pi ProductItem) BeforeCreate(*gorm.DB) error {
return nil
}
func (pi *ProductItem) AfterFind(*gorm.DB) error {
pi.AfterFindCallTimes = pi.AfterFindCallTimes + 1
return nil
}
func TestFailedToSaveAssociationShouldRollback(t *testing.T) {
DB.Migrator().DropTable(&Product4{}, &ProductItem{})
DB.AutoMigrate(&Product4{}, &ProductItem{})
@ -492,4 +504,13 @@ func TestFailedToSaveAssociationShouldRollback(t *testing.T) {
if err := DB.First(&Product4{}, "name = ?", product.Name).Error; err != nil {
t.Errorf("should find product, but got error %v", err)
}
var productWithItem Product4
if err := DB.Session(&gorm.Session{SkipHooks: true}).Preload("Item").First(&productWithItem, "name = ?", product.Name).Error; err != nil {
t.Errorf("should find product, but got error %v", err)
}
if productWithItem.Item.AfterFindCallTimes != 0 {
t.Fatalf("AfterFind should not be called times:%d", productWithItem.Item.AfterFindCallTimes)
}
}

View File

@ -10,12 +10,12 @@ import (
)
func TestJoins(t *testing.T) {
user := *GetUser("joins-1", Config{Company: true, Manager: true, Account: true})
user := *GetUser("joins-1", Config{Company: true, Manager: true, Account: true, NamedPet: false})
DB.Create(&user)
var user2 User
if err := DB.Joins("Company").Joins("Manager").Joins("Account").First(&user2, "users.name = ?", user.Name).Error; err != nil {
if err := DB.Joins("NamedPet").Joins("Company").Joins("Manager").Joins("Account").First(&user2, "users.name = ?", user.Name).Error; err != nil {
t.Fatalf("Failed to load with joins, got error: %v", err)
}
@ -158,6 +158,22 @@ func TestJoinsWithSelect(t *testing.T) {
}
}
func TestJoinWithOmit(t *testing.T) {
user := *GetUser("joins_with_omit", Config{Pets: 2})
DB.Save(&user)
results := make([]*User, 0)
if err := DB.Table("users").Omit("name").Where("users.name = ?", "joins_with_omit").Joins("left join pets on pets.user_id = users.id").Find(&results).Error; err != nil {
return
}
if len(results) != 2 || results[0].Name != "" || results[1].Name != "" {
t.Errorf("Should find all two pets with Join omit and should not find user's name, got %+v", results)
return
}
}
func TestJoinCount(t *testing.T) {
companyA := Company{Name: "A"}
companyB := Company{Name: "B"}
@ -184,3 +200,32 @@ func TestJoinCount(t *testing.T) {
t.Fatalf("result's id, %d, doesn't match user's id, %d", result.ID, user.ID)
}
}
func TestJoinWithSoftDeleted(t *testing.T) {
user := GetUser("TestJoinWithSoftDeletedUser", Config{Account: true, NamedPet: true})
DB.Create(&user)
var user1 User
DB.Model(&User{}).Joins("NamedPet").Joins("Account").First(&user1, user.ID)
if user1.NamedPet == nil || user1.Account.ID == 0 {
t.Fatalf("joins NamedPet and Account should not empty:%v", user1)
}
// Account should empty
DB.Delete(&user1.Account)
var user2 User
DB.Model(&User{}).Joins("NamedPet").Joins("Account").First(&user2, user.ID)
if user2.NamedPet == nil || user2.Account.ID != 0 {
t.Fatalf("joins Account should not empty:%v", user2)
}
// NamedPet should empty
DB.Delete(&user1.NamedPet)
var user3 User
DB.Model(&User{}).Joins("NamedPet").Joins("Account").First(&user3, user.ID)
if user3.NamedPet != nil || user2.Account.ID != 0 {
t.Fatalf("joins NamedPet and Account should not empty:%v", user2)
}
}

View File

@ -43,10 +43,8 @@ func TestExceptionsWithInvalidSql(t *testing.T) {
func TestSetAndGet(t *testing.T) {
if value, ok := DB.Set("hello", "world").Get("hello"); !ok {
t.Errorf("Should be able to get setting after set")
} else {
if value.(string) != "world" {
t.Errorf("Setted value should not be changed")
}
} else if value.(string) != "world" {
t.Errorf("Set value should not be changed")
}
if _, ok := DB.Get("non_existing"); ok {

View File

@ -45,7 +45,7 @@ func TestMigrate(t *testing.T) {
for _, m := range allModels {
if !DB.Migrator().HasTable(m) {
t.Fatalf("Failed to create table for %#v---", m)
t.Fatalf("Failed to create table for %#v", m)
}
}
@ -92,7 +92,7 @@ func TestAutoMigrateSelfReferential(t *testing.T) {
}
func TestSmartMigrateColumn(t *testing.T) {
fullSupported := map[string]bool{"mysql": true}[DB.Dialector.Name()]
fullSupported := map[string]bool{"mysql": true, "postgres": true}[DB.Dialector.Name()]
type UserMigrateColumn struct {
ID uint
@ -258,7 +258,7 @@ func TestMigrateTable(t *testing.T) {
DB.Migrator().DropTable("new_table_structs")
if DB.Migrator().HasTable(&NewTableStruct{}) {
t.Fatal("should not found droped table")
t.Fatal("should not found dropped table")
}
}
@ -313,9 +313,16 @@ func TestMigrateIndexes(t *testing.T) {
}
func TestMigrateColumns(t *testing.T) {
sqlite := DB.Dialector.Name() == "sqlite"
sqlserver := DB.Dialector.Name() == "sqlserver"
type ColumnStruct struct {
gorm.Model
Name string
Name string
Age int `gorm:"default:18;comment:my age"`
Code string `gorm:"unique;comment:my code;"`
Code2 string
Code3 string `gorm:"unique"`
}
DB.Migrator().DropTable(&ColumnStruct{})
@ -326,13 +333,20 @@ func TestMigrateColumns(t *testing.T) {
type ColumnStruct2 struct {
gorm.Model
Name string `gorm:"size:100"`
Name string `gorm:"size:100"`
Code string `gorm:"unique;comment:my code2;default:hello"`
Code2 string `gorm:"unique"`
// Code3 string
}
if err := DB.Table("column_structs").Migrator().AlterColumn(&ColumnStruct2{}, "Name"); err != nil {
if err := DB.Table("column_structs").Migrator().AlterColumn(&ColumnStruct{}, "Name"); err != nil {
t.Fatalf("no error should happened when alter column, but got %v", err)
}
if err := DB.Table("column_structs").AutoMigrate(&ColumnStruct2{}); err != nil {
t.Fatalf("no error should happened when auto migrate column, but got %v", err)
}
if columnTypes, err := DB.Migrator().ColumnTypes(&ColumnStruct{}); err != nil {
t.Fatalf("no error should returns for ColumnTypes")
} else {
@ -340,11 +354,45 @@ func TestMigrateColumns(t *testing.T) {
stmt.Parse(&ColumnStruct2{})
for _, columnType := range columnTypes {
if columnType.Name() == "name" {
switch columnType.Name() {
case "id":
if v, ok := columnType.PrimaryKey(); !ok || !v {
t.Fatalf("column id primary key should be correct, name: %v, column: %#v", columnType.Name(), columnType)
}
case "name":
dataType := DB.Dialector.DataTypeOf(stmt.Schema.LookUpField(columnType.Name()))
if !strings.Contains(strings.ToUpper(dataType), strings.ToUpper(columnType.DatabaseTypeName())) {
t.Errorf("column type should be correct, name: %v, length: %v, expects: %v", columnType.Name(), columnType.DatabaseTypeName(), dataType)
t.Fatalf("column name type should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), columnType.DatabaseTypeName(), dataType, columnType)
}
if length, ok := columnType.Length(); !sqlite && (!ok || length != 100) {
t.Fatalf("column name length should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), length, 100, columnType)
}
case "age":
if v, ok := columnType.DefaultValue(); !ok || v != "18" {
t.Fatalf("column age default value should be correct, name: %v, column: %#v", columnType.Name(), columnType)
}
if v, ok := columnType.Comment(); !sqlite && !sqlserver && (!ok || v != "my age") {
t.Fatalf("column age comment should be correct, name: %v, column: %#v", columnType.Name(), columnType)
}
case "code":
if v, ok := columnType.Unique(); !ok || !v {
t.Fatalf("column code unique should be correct, name: %v, column: %#v", columnType.Name(), columnType)
}
if v, ok := columnType.DefaultValue(); !sqlserver && (!ok || v != "hello") {
t.Fatalf("column code default value should be correct, name: %v, column: %#v", columnType.Name(), columnType)
}
if v, ok := columnType.Comment(); !sqlite && !sqlserver && (!ok || v != "my code2") {
t.Fatalf("column code comment should be correct, name: %v, column: %#v", columnType.Name(), columnType)
}
case "code2":
if v, ok := columnType.Unique(); !sqlserver && (!ok || !v) {
t.Fatalf("column code2 unique should be correct, name: %v, column: %#v", columnType.Name(), columnType)
}
case "code3":
// TODO
// if v, ok := columnType.Unique(); !ok || v {
// t.Fatalf("column code unique should be correct, name: %v, column: %#v", columnType.Name(), columnType)
// }
}
}
}
@ -526,3 +574,122 @@ func TestMigrateColumnOrder(t *testing.T) {
}
}
}
// https://github.com/go-gorm/gorm/issues/5047
func TestMigrateSerialColumn(t *testing.T) {
if DB.Dialector.Name() != "postgres" {
return
}
type Event struct {
ID uint `gorm:"primarykey"`
UID uint32
}
type Event1 struct {
ID uint `gorm:"primarykey"`
UID uint32 `gorm:"not null;autoIncrement"`
}
type Event2 struct {
ID uint `gorm:"primarykey"`
UID uint16 `gorm:"not null;autoIncrement"`
}
var err error
err = DB.Migrator().DropTable(&Event{})
if err != nil {
t.Errorf("DropTable err:%v", err)
}
// create sequence
err = DB.Table("events").AutoMigrate(&Event1{})
if err != nil {
t.Errorf("AutoMigrate err:%v", err)
}
// delete sequence
err = DB.Table("events").AutoMigrate(&Event{})
if err != nil {
t.Errorf("AutoMigrate err:%v", err)
}
// update sequence
err = DB.Table("events").AutoMigrate(&Event1{})
if err != nil {
t.Errorf("AutoMigrate err:%v", err)
}
err = DB.Table("events").AutoMigrate(&Event2{})
if err != nil {
t.Errorf("AutoMigrate err:%v", err)
}
DB.Table("events").Save(&Event2{})
DB.Table("events").Save(&Event2{})
DB.Table("events").Save(&Event2{})
events := make([]*Event, 0)
DB.Table("events").Find(&events)
AssertEqual(t, 3, len(events))
for _, v := range events {
AssertEqual(t, v.ID, v.UID)
}
}
// https://github.com/go-gorm/gorm/issues/5300
func TestMigrateWithSpecialName(t *testing.T) {
var err error
err = DB.AutoMigrate(&Coupon{})
if err != nil {
t.Fatalf("AutoMigrate err:%v", err)
}
err = DB.Table("coupon_product_1").AutoMigrate(&CouponProduct{})
if err != nil {
t.Fatalf("AutoMigrate err:%v", err)
}
err = DB.Table("coupon_product_2").AutoMigrate(&CouponProduct{})
if err != nil {
t.Fatalf("AutoMigrate err:%v", err)
}
AssertEqual(t, true, DB.Migrator().HasTable("coupons"))
AssertEqual(t, true, DB.Migrator().HasTable("coupon_product_1"))
AssertEqual(t, true, DB.Migrator().HasTable("coupon_product_2"))
}
// https://github.com/go-gorm/gorm/issues/5320
func TestPrimarykeyID(t *testing.T) {
if DB.Dialector.Name() != "postgres" {
return
}
type MissPKLanguage struct {
ID string `gorm:"type:uuid;default:uuid_generate_v4()"`
Name string
}
type MissPKUser struct {
ID string `gorm:"type:uuid;default:uuid_generate_v4()"`
MissPKLanguages []MissPKLanguage `gorm:"many2many:miss_pk_user_languages;"`
}
var err error
err = DB.Migrator().DropTable(&MissPKUser{}, &MissPKLanguage{})
if err != nil {
t.Fatalf("DropTable err:%v", err)
}
DB.Exec(`CREATE EXTENSION IF NOT EXISTS "uuid-ossp";`)
err = DB.AutoMigrate(&MissPKUser{}, &MissPKLanguage{})
if err != nil {
t.Fatalf("AutoMigrate err:%v", err)
}
// patch
err = DB.AutoMigrate(&MissPKUser{}, &MissPKLanguage{})
if err != nil {
t.Fatalf("AutoMigrate err:%v", err)
}
}

View File

@ -2,6 +2,7 @@ package tests_test
import (
"testing"
"time"
"github.com/google/uuid"
"github.com/lib/pq"
@ -15,9 +16,11 @@ func TestPostgres(t *testing.T) {
type Harumph struct {
gorm.Model
Name string `gorm:"check:name_checker,name <> ''"`
Test uuid.UUID `gorm:"type:uuid;not null;default:gen_random_uuid()"`
Things pq.StringArray `gorm:"type:text[]"`
Name string `gorm:"check:name_checker,name <> ''"`
Test uuid.UUID `gorm:"type:uuid;not null;default:gen_random_uuid()"`
CreatedAt time.Time `gorm:"type:TIMESTAMP WITHOUT TIME ZONE"`
UpdatedAt time.Time `gorm:"type:TIMESTAMP WITHOUT TIME ZONE;default:current_timestamp"`
Things pq.StringArray `gorm:"type:text[]"`
}
if err := DB.Exec("CREATE EXTENSION IF NOT EXISTS pgcrypto;").Error; err != nil {
@ -48,6 +51,15 @@ func TestPostgres(t *testing.T) {
if err := DB.Where("id = $1", harumph.ID).First(&Harumph{}).Error; err != nil || harumph.Name != "jinzhu" {
t.Errorf("No error should happen, but got %v", err)
}
harumph.Name = "jinzhu1"
if err := DB.Save(&harumph).Error; err != nil {
t.Errorf("Failed to update date, got error %v", err)
}
if err := DB.First(&result, "id = ?", harumph.ID).Error; err != nil || harumph.Name != "jinzhu1" {
t.Errorf("No error should happen, but got %v", err)
}
}
type Post struct {

View File

@ -1335,7 +1335,7 @@ func TestNilPointerSlice(t *testing.T) {
}
if !reflect.DeepEqual(got[0], want) && !reflect.DeepEqual(got[1], want) {
t.Errorf("got %s; want array containing %s", toJSONString(got), toJSONString(want))
t.Fatalf("got %s; want array containing %s", toJSONString(got), toJSONString(want))
}
if !reflect.DeepEqual(got[0], want2) && !reflect.DeepEqual(got[1], want2) {

View File

@ -251,3 +251,21 @@ func TestPreloadGoroutine(t *testing.T) {
}
wg.Wait()
}
func TestPreloadWithDiffModel(t *testing.T) {
user := *GetUser("preload_with_diff_model", Config{Account: true})
if err := DB.Create(&user).Error; err != nil {
t.Fatalf("errors happened when create: %v", err)
}
var result struct {
Something string
User
}
DB.Model(User{}).Preload("Account", clause.Eq{Column: "number", Value: user.Account.Number}).Select(
"users.*, 'yo' as something").First(&result, "name = ?", user.Name)
CheckUser(t, user, result.User)
}

View File

@ -292,6 +292,68 @@ func TestFindInBatches(t *testing.T) {
}
}
func TestFindInBatchesWithOffsetLimit(t *testing.T) {
users := []User{
*GetUser("find_in_batches_with_offset_limit", Config{}),
*GetUser("find_in_batches_with_offset_limit", Config{}),
*GetUser("find_in_batches_with_offset_limit", Config{}),
*GetUser("find_in_batches_with_offset_limit", Config{}),
*GetUser("find_in_batches_with_offset_limit", Config{}),
*GetUser("find_in_batches_with_offset_limit", Config{}),
*GetUser("find_in_batches_with_offset_limit", Config{}),
*GetUser("find_in_batches_with_offset_limit", Config{}),
*GetUser("find_in_batches_with_offset_limit", Config{}),
*GetUser("find_in_batches_with_offset_limit", Config{}),
}
DB.Create(&users)
var (
sub, results []User
lastBatch int
)
// offset limit
if result := DB.Offset(3).Limit(5).Where("name = ?", users[0].Name).FindInBatches(&sub, 2, func(tx *gorm.DB, batch int) error {
results = append(results, sub...)
lastBatch = batch
return nil
}); result.Error != nil || result.RowsAffected != 5 {
t.Errorf("Failed to batch find, got error %v, rows affected: %v", result.Error, result.RowsAffected)
}
if lastBatch != 3 {
t.Fatalf("incorrect last batch, expected: %v, got: %v", 3, lastBatch)
}
targetUsers := users[3:8]
for i := 0; i < len(targetUsers); i++ {
AssertEqual(t, results[i], targetUsers[i])
}
var sub1 []User
// limit < batchSize
if result := DB.Limit(5).Where("name = ?", users[0].Name).FindInBatches(&sub1, 10, func(tx *gorm.DB, batch int) error {
return nil
}); result.Error != nil || result.RowsAffected != 5 {
t.Errorf("Failed to batch find, got error %v, rows affected: %v", result.Error, result.RowsAffected)
}
var sub2 []User
// only offset
if result := DB.Offset(3).Where("name = ?", users[0].Name).FindInBatches(&sub2, 2, func(tx *gorm.DB, batch int) error {
return nil
}); result.Error != nil || result.RowsAffected != 7 {
t.Errorf("Failed to batch find, got error %v, rows affected: %v", result.Error, result.RowsAffected)
}
var sub3 []User
if result := DB.Limit(4).Where("name = ?", users[0].Name).FindInBatches(&sub3, 2, func(tx *gorm.DB, batch int) error {
return nil
}); result.Error != nil || result.RowsAffected != 4 {
t.Errorf("Failed to batch find, got error %v, rows affected: %v", result.Error, result.RowsAffected)
}
}
func TestFindInBatchesWithError(t *testing.T) {
if name := DB.Dialector.Name(); name == "sqlserver" {
t.Skip("skip sqlserver due to it will raise data race for invalid sql")
@ -512,7 +574,13 @@ func TestNotWithAllFields(t *testing.T) {
func TestOr(t *testing.T) {
dryDB := DB.Session(&gorm.Session{DryRun: true})
result := dryDB.Where("role = ?", "admin").Where(DB.Or("role = ?", "super_admin")).Find(&User{})
var count int64
result := dryDB.Model(&User{}).Or("role = ?", "admin").Count(&count)
if !regexp.MustCompile("SELECT count\\(\\*\\) FROM .*users.* WHERE role = .+ AND .*users.*\\..*deleted_at.* IS NULL").MatchString(result.Statement.SQL.String()) {
t.Fatalf("Build OR condition, but got %v", result.Statement.SQL.String())
}
result = dryDB.Where("role = ?", "admin").Where(DB.Or("role = ?", "super_admin")).Find(&User{})
if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*role.* = .+ AND .*role.* = .+").MatchString(result.Statement.SQL.String()) {
t.Fatalf("Build OR condition, but got %v", result.Statement.SQL.String())
}
@ -577,7 +645,9 @@ func TestPluck(t *testing.T) {
if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Order("name desc").Pluck("name", &names2).Error; err != nil {
t.Errorf("got error when pluck name: %v", err)
}
AssertEqual(t, names, sort.Reverse(sort.StringSlice(names2)))
sort.Slice(names2, func(i, j int) bool { return names2[i] < names2[j] })
AssertEqual(t, names, names2)
var ids []int
if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Pluck("id", &ids).Error; err != nil {
@ -1152,3 +1222,39 @@ func TestQueryWithTableAndConditionsAndAllFields(t *testing.T) {
t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String())
}
}
type DoubleInt64 struct {
data int64
}
func (t *DoubleInt64) Scan(val interface{}) error {
switch v := val.(type) {
case int64:
t.data = v * 2
return nil
default:
return fmt.Errorf("DoubleInt64 cant not scan with:%v", v)
}
}
// https://github.com/go-gorm/gorm/issues/5091
func TestQueryScannerWithSingleColumn(t *testing.T) {
user := User{Name: "scanner_raw_1", Age: 10}
DB.Create(&user)
var result1 DoubleInt64
if err := DB.Model(&User{}).Where("name LIKE ?", "scanner_raw_%").Limit(1).Pluck(
"age", &result1).Error; err != nil {
t.Errorf("Failed, got error: %v", err)
}
AssertEqual(t, result1.data, 20)
var result2 DoubleInt64
if err := DB.Model(&User{}).Where("name LIKE ?", "scanner_raw_%").Limit(1).Select(
"age").Scan(&result2).Error; err != nil {
t.Errorf("Failed, got error: %v", err)
}
AssertEqual(t, result2.data, 20)
}

View File

@ -10,6 +10,11 @@ import (
. "gorm.io/gorm/utils/tests"
)
type PersonAddressInfo struct {
Person *Person `gorm:"embedded"`
Address *Address `gorm:"embedded"`
}
func TestScan(t *testing.T) {
user1 := User{Name: "ScanUser1", Age: 1}
user2 := User{Name: "ScanUser2", Age: 10}
@ -156,3 +161,57 @@ func TestScanRows(t *testing.T) {
t.Fatalf("failed to scan ages, got error %v, ages: %v", err, name)
}
}
func TestScanToEmbedded(t *testing.T) {
person1 := Person{Name: "person 1"}
person2 := Person{Name: "person 2"}
DB.Save(&person1).Save(&person2)
address1 := Address{Name: "address 1"}
address2 := Address{Name: "address 2"}
DB.Save(&address1).Save(&address2)
DB.Create(&PersonAddress{PersonID: person1.ID, AddressID: int(address1.ID)})
DB.Create(&PersonAddress{PersonID: person1.ID, AddressID: int(address2.ID)})
DB.Create(&PersonAddress{PersonID: person2.ID, AddressID: int(address1.ID)})
var personAddressInfoList []*PersonAddressInfo
if err := DB.Select("people.*, addresses.*").
Table("people").
Joins("inner join person_addresses on people.id = person_addresses.person_id").
Joins("inner join addresses on person_addresses.address_id = addresses.id").
Find(&personAddressInfoList).Error; err != nil {
t.Errorf("Failed to run join query, got error: %v", err)
}
personMatched := false
addressMatched := false
for _, info := range personAddressInfoList {
if info.Person == nil {
t.Fatalf("Failed, expected not nil, got person nil")
}
if info.Address == nil {
t.Fatalf("Failed, expected not nil, got address nil")
}
if info.Person.ID == person1.ID {
personMatched = true
if info.Person.Name != person1.Name {
t.Errorf("Failed, expected %v, got %v", person1.Name, info.Person.Name)
}
}
if info.Address.ID == address1.ID {
addressMatched = true
if info.Address.Name != address1.Name {
t.Errorf("Failed, expected %v, got %v", address1.Name, info.Address.Name)
}
}
}
if !personMatched {
t.Errorf("Failed, no person matched")
}
if !addressMatched {
t.Errorf("Failed, no address matched")
}
}

138
tests/serializer_test.go Normal file
View File

@ -0,0 +1,138 @@
package tests_test
import (
"bytes"
"context"
"fmt"
"reflect"
"strings"
"testing"
"time"
"gorm.io/gorm"
"gorm.io/gorm/schema"
. "gorm.io/gorm/utils/tests"
)
type SerializerStruct struct {
gorm.Model
Name []byte `gorm:"json"`
Roles Roles `gorm:"serializer:json"`
Contracts map[string]interface{} `gorm:"serializer:json"`
JobInfo Job `gorm:"type:bytes;serializer:gob"`
CreatedTime int64 `gorm:"serializer:unixtime;type:time"` // store time in db, use int as field type
UpdatedTime *int64 `gorm:"serializer:unixtime;type:time"` // store time in db, use int as field type
EncryptedString EncryptedString
}
type Roles []string
type Job struct {
Title string
Number int
Location string
IsIntern bool
}
type EncryptedString string
func (es *EncryptedString) Scan(ctx context.Context, field *schema.Field, dst reflect.Value, dbValue interface{}) (err error) {
switch value := dbValue.(type) {
case []byte:
*es = EncryptedString(bytes.TrimPrefix(value, []byte("hello")))
case string:
*es = EncryptedString(strings.TrimPrefix(value, "hello"))
default:
return fmt.Errorf("unsupported data %#v", dbValue)
}
return nil
}
func (es EncryptedString) Value(ctx context.Context, field *schema.Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) {
return "hello" + string(es), nil
}
func TestSerializer(t *testing.T) {
DB.Migrator().DropTable(&SerializerStruct{})
if err := DB.Migrator().AutoMigrate(&SerializerStruct{}); err != nil {
t.Fatalf("no error should happen when migrate scanner, valuer struct, got error %v", err)
}
createdAt := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC)
updatedAt := createdAt.Unix()
data := SerializerStruct{
Name: []byte("jinzhu"),
Roles: []string{"r1", "r2"},
Contracts: map[string]interface{}{"name": "jinzhu", "age": 10},
EncryptedString: EncryptedString("pass"),
CreatedTime: createdAt.Unix(),
UpdatedTime: &updatedAt,
JobInfo: Job{
Title: "programmer",
Number: 9920,
Location: "Kenmawr",
IsIntern: false,
},
}
if err := DB.Create(&data).Error; err != nil {
t.Fatalf("failed to create data, got error %v", err)
}
var result SerializerStruct
if err := DB.First(&result, data.ID).Error; err != nil {
t.Fatalf("failed to query data, got error %v", err)
}
AssertEqual(t, result, data)
}
func TestSerializerAssignFirstOrCreate(t *testing.T) {
DB.Migrator().DropTable(&SerializerStruct{})
if err := DB.Migrator().AutoMigrate(&SerializerStruct{}); err != nil {
t.Fatalf("no error should happen when migrate scanner, valuer struct, got error %v", err)
}
createdAt := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC)
data := SerializerStruct{
Name: []byte("ag9920"),
Roles: []string{"r1", "r2"},
Contracts: map[string]interface{}{"name": "jing1", "age": 11},
EncryptedString: EncryptedString("pass"),
CreatedTime: createdAt.Unix(),
JobInfo: Job{
Title: "programmer",
Number: 9920,
Location: "Shadyside",
IsIntern: false,
},
}
// first time insert record
out := SerializerStruct{}
if err := DB.Assign(data).FirstOrCreate(&out).Error; err != nil {
t.Fatalf("failed to FirstOrCreate Assigned data, got error %v", err)
}
var result SerializerStruct
if err := DB.First(&result, out.ID).Error; err != nil {
t.Fatalf("failed to query data, got error %v", err)
}
AssertEqual(t, result, out)
//update record
data.Roles = append(data.Roles, "r3")
data.JobInfo.Location = "Gates Hillman Complex"
if err := DB.Assign(data).FirstOrCreate(&out).Error; err != nil {
t.Fatalf("failed to FirstOrCreate Assigned data, got error %v", err)
}
if err := DB.First(&result, out.ID).Error; err != nil {
t.Fatalf("failed to query data, got error %v", err)
}
AssertEqual(t, result.Roles, data.Roles)
AssertEqual(t, result.JobInfo.Location, data.JobInfo.Location)
}

View File

@ -168,6 +168,59 @@ func TestDryRun(t *testing.T) {
}
}
type ageInt int8
func (ageInt) String() string {
return "age"
}
type ageBool bool
func (ageBool) String() string {
return "age"
}
type ageUint64 uint64
func (ageUint64) String() string {
return "age"
}
type ageFloat float64
func (ageFloat) String() string {
return "age"
}
func TestExplainSQL(t *testing.T) {
user := *GetUser("explain-sql", Config{})
dryRunDB := DB.Session(&gorm.Session{DryRun: true})
stmt := dryRunDB.Model(&user).Where("id = ?", 1).Updates(map[string]interface{}{"age": ageInt(8)}).Statement
sql := DB.Dialector.Explain(stmt.SQL.String(), stmt.Vars...)
if !regexp.MustCompile(`.*age.*=8,`).MatchString(sql) {
t.Errorf("Failed to generate sql, got %v", sql)
}
stmt = dryRunDB.Model(&user).Where("id = ?", 1).Updates(map[string]interface{}{"age": ageUint64(10241024)}).Statement
sql = DB.Dialector.Explain(stmt.SQL.String(), stmt.Vars...)
if !regexp.MustCompile(`.*age.*=10241024,`).MatchString(sql) {
t.Errorf("Failed to generate sql, got %v", sql)
}
stmt = dryRunDB.Model(&user).Where("id = ?", 1).Updates(map[string]interface{}{"age": ageBool(false)}).Statement
sql = DB.Dialector.Explain(stmt.SQL.String(), stmt.Vars...)
if !regexp.MustCompile(`.*age.*=false,`).MatchString(sql) {
t.Errorf("Failed to generate sql, got %v", sql)
}
stmt = dryRunDB.Model(&user).Where("id = ?", 1).Updates(map[string]interface{}{"age": ageFloat(0.12345678)}).Statement
sql = DB.Dialector.Explain(stmt.SQL.String(), stmt.Vars...)
if !regexp.MustCompile(`.*age.*=0.123457,`).MatchString(sql) {
t.Errorf("Failed to generate sql, got %v", sql)
}
}
func TestGroupConditions(t *testing.T) {
type Pizza struct {
ID uint
@ -190,6 +243,21 @@ func TestGroupConditions(t *testing.T) {
if !strings.HasSuffix(result, expects) {
t.Errorf("expects: %v, got %v", expects, result)
}
stmt2 := dryRunDB.Where(
DB.Scopes(NameIn1And2),
).Or(
DB.Where("pizza = ?", "hawaiian").Where("size = ?", "xlarge"),
).Find(&Pizza{}).Statement
execStmt2 := dryRunDB.Exec(`WHERE name in ? OR (pizza = ? AND size = ?)`, []string{"ScopeUser1", "ScopeUser2"}, "hawaiian", "xlarge").Statement
result2 := DB.Dialector.Explain(stmt2.SQL.String(), stmt2.Vars...)
expects2 := DB.Dialector.Explain(execStmt2.SQL.String(), execStmt2.Vars...)
if !strings.HasSuffix(result2, expects2) {
t.Errorf("expects: %v, got %v", expects2, result2)
}
}
func TestCombineStringConditions(t *testing.T) {
@ -307,7 +375,7 @@ func TestToSQL(t *testing.T) {
})
assertEqualSQL(t, `SELECT * FROM "users" WHERE id = 100 AND "users"."deleted_at" IS NULL ORDER BY age desc LIMIT 10`, sql)
// after model chagned
// after model changed
if DB.Statement.DryRun || DB.DryRun {
t.Fatal("Failed expect DB.DryRun and DB.Statement.ToSQL to be false")
}
@ -373,13 +441,13 @@ func TestToSQL(t *testing.T) {
})
assertEqualSQL(t, `UPDATE "users" SET "name"='Foo',"age"=100 WHERE id = 100 AND "users"."deleted_at" IS NULL`, sql)
// after model chagned
// after model changed
if DB.Statement.DryRun || DB.DryRun {
t.Fatal("Failed expect DB.DryRun and DB.Statement.ToSQL to be false")
}
}
// assertEqualSQL for assert that the sql is equal, this method will ignore quote, and dialect speicals.
// assertEqualSQL for assert that the sql is equal, this method will ignore quote, and dialect specials.
func assertEqualSQL(t *testing.T, expected string, actually string) {
t.Helper()
@ -387,7 +455,7 @@ func assertEqualSQL(t *testing.T, expected string, actually string) {
expected = replaceQuoteInSQL(expected)
actually = replaceQuoteInSQL(actually)
// ignore updated_at value, becase it's generated in Gorm inernal, can't to mock value on update.
// ignore updated_at value, because it's generated in Gorm internal, can't to mock value on update.
updatedAtRe := regexp.MustCompile(`(?i)"updated_at"=".+?"`)
actually = updatedAtRe.ReplaceAllString(actually, `"updated_at"=?`)
expected = updatedAtRe.ReplaceAllString(expected, `"updated_at"=?`)
@ -407,16 +475,16 @@ func assertEqualSQL(t *testing.T, expected string, actually string) {
func replaceQuoteInSQL(sql string) string {
// convert single quote into double quote
sql = strings.Replace(sql, `'`, `"`, -1)
sql = strings.ReplaceAll(sql, `'`, `"`)
// convert dialect speical quote into double quote
// convert dialect special quote into double quote
switch DB.Dialector.Name() {
case "postgres":
sql = strings.Replace(sql, `"`, `"`, -1)
sql = strings.ReplaceAll(sql, `"`, `"`)
case "mysql", "sqlite":
sql = strings.Replace(sql, "`", `"`, -1)
sql = strings.ReplaceAll(sql, "`", `"`)
case "sqlserver":
sql = strings.Replace(sql, `'`, `"`, -1)
sql = strings.ReplaceAll(sql, `'`, `"`)
}
return sql

View File

@ -15,6 +15,25 @@ then
cd ..
fi
# SqlServer for Mac M1
if [[ -z $GITHUB_ACTION ]]; then
if [ -d tests ]
then
cd tests
if [[ $(uname -a) == *" arm64" ]]; then
MSSQL_IMAGE=mcr.microsoft.com/azure-sql-edge docker-compose start || true
go install github.com/microsoft/go-sqlcmd/cmd/sqlcmd@latest || true
SQLCMDPASSWORD=LoremIpsum86 sqlcmd -U sa -S localhost:9930 -Q "IF DB_ID('gorm') IS NULL CREATE DATABASE gorm" > /dev/null || true
SQLCMDPASSWORD=LoremIpsum86 sqlcmd -U sa -S localhost:9930 -Q "IF SUSER_ID (N'gorm') IS NULL CREATE LOGIN gorm WITH PASSWORD = 'LoremIpsum86';" > /dev/null || true
SQLCMDPASSWORD=LoremIpsum86 sqlcmd -U sa -S localhost:9930 -Q "IF USER_ID (N'gorm') IS NULL CREATE USER gorm FROM LOGIN gorm; ALTER SERVER ROLE sysadmin ADD MEMBER [gorm];" > /dev/null || true
else
docker-compose start
fi
cd ..
fi
fi
for dialect in "${dialects[@]}" ; do
if [ "$GORM_DIALECT" = "" ] || [ "$GORM_DIALECT" = "${dialect}" ]
then

View File

@ -62,13 +62,14 @@ func OpenTestConnection() (db *gorm.DB, err error) {
PreferSimpleProtocol: true,
}), &gorm.Config{})
case "sqlserver":
// CREATE LOGIN gorm WITH PASSWORD = 'LoremIpsum86';
// go install github.com/microsoft/go-sqlcmd/cmd/sqlcmd@latest
// SQLCMDPASSWORD=LoremIpsum86 sqlcmd -U sa -S localhost:9930
// CREATE DATABASE gorm;
// USE gorm;
// GO
// CREATE LOGIN gorm WITH PASSWORD = 'LoremIpsum86';
// CREATE USER gorm FROM LOGIN gorm;
// sp_changedbowner 'gorm';
// npm install -g sql-cli
// mssql -u gorm -p LoremIpsum86 -d gorm -o 9930
// ALTER SERVER ROLE sysadmin ADD MEMBER [gorm];
// GO
log.Println("testing sqlserver...")
if dbDSN == "" {
dbDSN = "sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm"
@ -94,7 +95,7 @@ func OpenTestConnection() (db *gorm.DB, err error) {
func RunMigrations() {
var err error
allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Coupon{}, &CouponProduct{}, &Order{}}
allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Coupon{}, &CouponProduct{}, &Order{}, &Parent{}, &Child{}}
rand.Seed(time.Now().UnixNano())
rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] })

View File

@ -645,7 +645,7 @@ func TestSave(t *testing.T) {
dryDB := DB.Session(&gorm.Session{DryRun: true})
stmt := dryDB.Save(&user).Statement
if !regexp.MustCompile(`.id. = .* AND .users.\..deleted_at. IS NULL`).MatchString(stmt.SQL.String()) {
if !regexp.MustCompile(`.users.\..deleted_at. IS NULL`).MatchString(stmt.SQL.String()) {
t.Fatalf("invalid updating SQL, got %v", stmt.SQL.String())
}

View File

@ -319,7 +319,7 @@ func TestUpdateWithMissWhere(t *testing.T) {
tx := DB.Session(&gorm.Session{DryRun: true}).Save(&user)
if err := tx.Error; err != nil {
t.Fatalf("failed to update user,missing where condtion,err=%+v", err)
t.Fatalf("failed to update user,missing where condition,err=%+v", err)
}
if !regexp.MustCompile("WHERE .id. = [^ ]+$").MatchString(tx.Statement.SQL.String()) {

View File

@ -49,7 +49,7 @@ func (DummyDialector) QuoteTo(writer clause.Writer, str string) {
shiftDelimiter = 0
underQuoted = false
continuousBacktick = 0
writer.WriteString("`")
writer.WriteByte('`')
}
writer.WriteByte(v)
continue
@ -74,7 +74,7 @@ func (DummyDialector) QuoteTo(writer clause.Writer, str string) {
if continuousBacktick > 0 && !selfQuoted {
writer.WriteString("``")
}
writer.WriteString("`")
writer.WriteByte('`')
}
func (DummyDialector) Explain(sql string, vars ...interface{}) string {

View File

@ -80,3 +80,17 @@ type Order struct {
Coupon *Coupon
CouponID string
}
type Parent struct {
gorm.Model
FavChildID uint
FavChild *Child
Children []*Child
}
type Child struct {
gorm.Model
Name string
ParentID *uint
Parent *Parent
}

View File

@ -83,20 +83,22 @@ func AssertEqual(t *testing.T, got, expect interface{}) {
}
if reflect.ValueOf(got).Kind() == reflect.Struct {
if reflect.ValueOf(got).NumField() == reflect.ValueOf(expect).NumField() {
exported := false
for i := 0; i < reflect.ValueOf(got).NumField(); i++ {
if fieldStruct := reflect.ValueOf(got).Type().Field(i); ast.IsExported(fieldStruct.Name) {
exported = true
field := reflect.ValueOf(got).Field(i)
t.Run(fieldStruct.Name, func(t *testing.T) {
AssertEqual(t, field.Interface(), reflect.ValueOf(expect).Field(i).Interface())
})
if reflect.ValueOf(expect).Kind() == reflect.Struct {
if reflect.ValueOf(got).NumField() == reflect.ValueOf(expect).NumField() {
exported := false
for i := 0; i < reflect.ValueOf(got).NumField(); i++ {
if fieldStruct := reflect.ValueOf(got).Type().Field(i); ast.IsExported(fieldStruct.Name) {
exported = true
field := reflect.ValueOf(got).Field(i)
t.Run(fieldStruct.Name, func(t *testing.T) {
AssertEqual(t, field.Interface(), reflect.ValueOf(expect).Field(i).Interface())
})
}
}
}
if exported {
return
if exported {
return
}
}
}
}
@ -107,6 +109,9 @@ func AssertEqual(t *testing.T, got, expect interface{}) {
} else if reflect.ValueOf(expect).Type().ConvertibleTo(reflect.ValueOf(got).Type()) {
expect = reflect.ValueOf(got).Convert(reflect.ValueOf(got).Type()).Interface()
isEqual()
} else {
t.Errorf("%v: expect: %+v, got %+v", utils.FileWithLineNum(), expect, got)
return
}
}
}

View File

@ -36,17 +36,14 @@ func IsValidDBNameChar(c rune) bool {
return !unicode.IsLetter(c) && !unicode.IsNumber(c) && c != '.' && c != '*' && c != '_' && c != '$' && c != '@'
}
func CheckTruth(val interface{}) bool {
if v, ok := val.(bool); ok {
return v
// CheckTruth check string true or not
func CheckTruth(vals ...string) bool {
for _, val := range vals {
if val != "" && !strings.EqualFold(val, "false") {
return true
}
}
if v, ok := val.(string); ok {
v = strings.ToLower(v)
return v != "false"
}
return !reflect.ValueOf(val).IsZero()
return false
}
func ToStringKey(values ...interface{}) string {