diff --git a/.github/workflows/invalid_question.yml b/.github/workflows/invalid_question.yml index dfd2ddd9..aa1812d4 100644 --- a/.github/workflows/invalid_question.yml +++ b/.github/workflows/invalid_question.yml @@ -3,20 +3,26 @@ on: schedule: - cron: "*/10 * * * *" +permissions: + contents: read + jobs: stale: + permissions: + issues: write # for actions/stale to close stale issues + pull-requests: write # for actions/stale to close stale PRs runs-on: ubuntu-latest env: ACTIONS_STEP_DEBUG: true steps: - name: Close Stale Issues - uses: actions/stale@v4 + uses: actions/stale@v5 with: repo-token: ${{ secrets.GITHUB_TOKEN }} - stale-issue-message: "This issue has been marked as invalid question, please give more information by following the `Question` template, if you believe there is a bug of GORM, please create a pull request that could reproduce the issue on [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground), the issue will be closed in 2 days if no further activity occurs. most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨" + stale-issue-message: "This issue has been marked as invalid question, please give more information by following the `Question` template, if you believe there is a bug of GORM, please create a pull request that could reproduce the issue on [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground), the issue will be closed in 30 days if no further activity occurs. most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨" stale-issue-label: "status:stale" days-before-stale: 0 - days-before-close: 2 + days-before-close: 30 remove-stale-when-updated: true only-labels: "type:invalid question" diff --git a/.github/workflows/labeler.yml b/.github/workflows/labeler.yml index bc1add53..0e8aaa60 100644 --- a/.github/workflows/labeler.yml +++ b/.github/workflows/labeler.yml @@ -11,7 +11,7 @@ jobs: name: Label issues and pull requests steps: - name: check out - uses: actions/checkout@v2 + uses: actions/checkout@v3 - name: labeler uses: jinzhu/super-labeler-action@develop diff --git a/.github/workflows/missing_playground.yml b/.github/workflows/missing_playground.yml index cdb097de..c3c92beb 100644 --- a/.github/workflows/missing_playground.yml +++ b/.github/workflows/missing_playground.yml @@ -3,19 +3,25 @@ on: schedule: - cron: "*/10 * * * *" +permissions: + contents: read + jobs: stale: + permissions: + issues: write # for actions/stale to close stale issues + pull-requests: write # for actions/stale to close stale PRs runs-on: ubuntu-latest env: ACTIONS_STEP_DEBUG: true steps: - name: Close Stale Issues - uses: actions/stale@v4 + uses: actions/stale@v5 with: repo-token: ${{ secrets.GITHUB_TOKEN }} - stale-issue-message: "The issue has been automatically marked as stale as it missing playground pull request link, which is important to help others understand your issue effectively and make sure the issue hasn't been fixed on latest master, checkout [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground) for details. it will be closed in 2 days if no further activity occurs. if you are asking question, please use the `Question` template, most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨" + stale-issue-message: "The issue has been automatically marked as stale as it missing playground pull request link, which is important to help others understand your issue effectively and make sure the issue hasn't been fixed on latest master, checkout [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground) for details. it will be closed in 30 days if no further activity occurs. if you are asking question, please use the `Question` template, most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨" stale-issue-label: "status:stale" days-before-stale: 0 - days-before-close: 2 + days-before-close: 30 remove-stale-when-updated: true only-labels: "type:missing reproduction steps" diff --git a/.github/workflows/reviewdog.yml b/.github/workflows/reviewdog.yml index b252dd7a..a6542d57 100644 --- a/.github/workflows/reviewdog.yml +++ b/.github/workflows/reviewdog.yml @@ -6,7 +6,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Check out code into the Go module directory - uses: actions/checkout@v2 + uses: actions/checkout@v3 - name: golangci-lint uses: reviewdog/action-golangci-lint@v2 diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index d5419295..af8d3636 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -3,19 +3,25 @@ on: schedule: - cron: "0 2 * * *" +permissions: + contents: read + jobs: stale: + permissions: + issues: write # for actions/stale to close stale issues + pull-requests: write # for actions/stale to close stale PRs runs-on: ubuntu-latest env: ACTIONS_STEP_DEBUG: true steps: - name: Close Stale Issues - uses: actions/stale@v4 + uses: actions/stale@v5 with: repo-token: ${{ secrets.GITHUB_TOKEN }} - stale-issue-message: "This issue has been automatically marked as stale because it has been open 60 days with no activity. Remove stale label or comment or this will be closed in 30 days" - days-before-stale: 60 - days-before-close: 30 + stale-issue-message: "This issue has been automatically marked as stale because it has been open 360 days with no activity. Remove stale label or comment or this will be closed in 180 days" + days-before-stale: 360 + days-before-close: 180 stale-issue-label: "status:stale" exempt-issue-labels: 'type:feature,type:with reproduction steps,type:has pull request' stale-pr-label: 'status:stale' diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 700af759..b97da3f4 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -8,38 +8,41 @@ on: branches-ignore: - 'gh-pages' +permissions: + contents: read + jobs: # Label of the container job sqlite: strategy: matrix: - go: ['1.17', '1.16'] + go: ['1.18', '1.17', '1.16'] platform: [ubuntu-latest] # can not run in windows OS runs-on: ${{ matrix.platform }} steps: - name: Set up Go 1.x - uses: actions/setup-go@v2 + uses: actions/setup-go@v3 with: go-version: ${{ matrix.go }} - name: Check out code into the Go module directory - uses: actions/checkout@v2 + uses: actions/checkout@v3 - name: go mod package cache - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: ~/go/pkg/mod key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} - name: Tests - run: GORM_DIALECT=sqlite ./tests/tests_all.sh + run: GITHUB_ACTION=true GORM_DIALECT=sqlite ./tests/tests_all.sh mysql: strategy: matrix: dbversion: ['mysql:latest', 'mysql:5.7', 'mariadb:latest'] - go: ['1.17', '1.16'] + go: ['1.18', '1.17', '1.16'] platform: [ubuntu-latest] runs-on: ${{ matrix.platform }} @@ -62,28 +65,28 @@ jobs: steps: - name: Set up Go 1.x - uses: actions/setup-go@v2 + uses: actions/setup-go@v3 with: go-version: ${{ matrix.go }} - name: Check out code into the Go module directory - uses: actions/checkout@v2 + uses: actions/checkout@v3 - name: go mod package cache - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: ~/go/pkg/mod key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} - name: Tests - run: GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True" ./tests/tests_all.sh + run: GITHUB_ACTION=true GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True" ./tests/tests_all.sh postgres: strategy: matrix: dbversion: ['postgres:latest', 'postgres:13', 'postgres:12', 'postgres:11', 'postgres:10'] - go: ['1.17', '1.16'] + go: ['1.18', '1.17', '1.16'] platform: [ubuntu-latest] # can not run in macOS and Windows runs-on: ${{ matrix.platform }} @@ -106,26 +109,26 @@ jobs: steps: - name: Set up Go 1.x - uses: actions/setup-go@v2 + uses: actions/setup-go@v3 with: go-version: ${{ matrix.go }} - name: Check out code into the Go module directory - uses: actions/checkout@v2 + uses: actions/checkout@v3 - name: go mod package cache - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: ~/go/pkg/mod key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} - name: Tests - run: GORM_DIALECT=postgres GORM_DSN="user=gorm password=gorm dbname=gorm host=localhost port=9920 sslmode=disable TimeZone=Asia/Shanghai" ./tests/tests_all.sh + run: GITHUB_ACTION=true GORM_DIALECT=postgres GORM_DSN="user=gorm password=gorm dbname=gorm host=localhost port=9920 sslmode=disable TimeZone=Asia/Shanghai" ./tests/tests_all.sh sqlserver: strategy: matrix: - go: ['1.17', '1.16'] + go: ['1.18', '1.17', '1.16'] platform: [ubuntu-latest] # can not run test in macOS and windows runs-on: ${{ matrix.platform }} @@ -149,18 +152,18 @@ jobs: steps: - name: Set up Go 1.x - uses: actions/setup-go@v2 + uses: actions/setup-go@v3 with: go-version: ${{ matrix.go }} - name: Check out code into the Go module directory - uses: actions/checkout@v2 + uses: actions/checkout@v3 - name: go mod package cache - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: ~/go/pkg/mod key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} - name: Tests - run: GORM_DIALECT=sqlserver GORM_DSN="sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm" ./tests/tests_all.sh + run: GITHUB_ACTION=true GORM_DIALECT=sqlserver GORM_DSN="sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm" ./tests/tests_all.sh diff --git a/.gitignore b/.gitignore index e1b9ecea..45505cc9 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,4 @@ documents coverage.txt _book .idea +vendor \ No newline at end of file diff --git a/README.md b/README.md index a3eabe39..312a3a59 100644 --- a/README.md +++ b/README.md @@ -30,6 +30,7 @@ The fantastic ORM library for Golang, aims to be developer friendly. ## Getting Started * GORM Guides [https://gorm.io](https://gorm.io) +* GORM Gen [gorm/gen](https://github.com/go-gorm/gen#gormgen) ## Contributing diff --git a/association.go b/association.go index 62c25b71..35e10ddd 100644 --- a/association.go +++ b/association.go @@ -79,10 +79,10 @@ func (association *Association) Replace(values ...interface{}) error { switch reflectValue.Kind() { case reflect.Slice, reflect.Array: for i := 0; i < reflectValue.Len(); i++ { - association.Error = rel.Field.Set(reflectValue.Index(i), reflect.Zero(rel.Field.FieldType).Interface()) + association.Error = rel.Field.Set(association.DB.Statement.Context, reflectValue.Index(i), reflect.Zero(rel.Field.FieldType).Interface()) } case reflect.Struct: - association.Error = rel.Field.Set(reflectValue, reflect.Zero(rel.Field.FieldType).Interface()) + association.Error = rel.Field.Set(association.DB.Statement.Context, reflectValue, reflect.Zero(rel.Field.FieldType).Interface()) } for _, ref := range rel.References { @@ -96,12 +96,12 @@ func (association *Association) Replace(values ...interface{}) error { primaryFields []*schema.Field foreignKeys []string updateMap = map[string]interface{}{} - relValues = schema.GetRelationsValues(reflectValue, []*schema.Relationship{rel}) + relValues = schema.GetRelationsValues(association.DB.Statement.Context, reflectValue, []*schema.Relationship{rel}) modelValue = reflect.New(rel.FieldSchema.ModelType).Interface() tx = association.DB.Model(modelValue) ) - if _, rvs := schema.GetIdentityFieldValuesMap(relValues, rel.FieldSchema.PrimaryFields); len(rvs) > 0 { + if _, rvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, relValues, rel.FieldSchema.PrimaryFields); len(rvs) > 0 { if column, values := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs); len(values) > 0 { tx.Not(clause.IN{Column: column, Values: values}) } @@ -117,7 +117,7 @@ func (association *Association) Replace(values ...interface{}) error { } } - if _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields); len(pvs) > 0 { + if _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields); len(pvs) > 0 { column, values := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs) association.Error = tx.Where(clause.IN{Column: column, Values: values}).UpdateColumns(updateMap).Error } @@ -143,14 +143,14 @@ func (association *Association) Replace(values ...interface{}) error { } } - _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields) + _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields) if column, values := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs); len(values) > 0 { tx.Where(clause.IN{Column: column, Values: values}) } else { return ErrPrimaryKeyRequired } - _, rvs := schema.GetIdentityFieldValuesMapFromValues(values, relPrimaryFields) + _, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, relPrimaryFields) if relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs); len(relValues) > 0 { tx.Where(clause.Not(clause.IN{Column: relColumn, Values: relValues})) } @@ -186,11 +186,14 @@ func (association *Association) Delete(values ...interface{}) error { case schema.BelongsTo: tx := association.DB.Model(reflect.New(rel.Schema.ModelType).Interface()) - _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, rel.Schema.PrimaryFields) - pcolumn, pvalues := schema.ToQueryValues(rel.Schema.Table, rel.Schema.PrimaryFieldDBNames, pvs) - conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues}) + _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, rel.Schema.PrimaryFields) + if pcolumn, pvalues := schema.ToQueryValues(rel.Schema.Table, rel.Schema.PrimaryFieldDBNames, pvs); len(pvalues) > 0 { + conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues}) + } else { + return ErrPrimaryKeyRequired + } - _, rvs := schema.GetIdentityFieldValuesMapFromValues(values, primaryFields) + _, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, primaryFields) relColumn, relValues := schema.ToQueryValues(rel.Schema.Table, foreignKeys, rvs) conds = append(conds, clause.IN{Column: relColumn, Values: relValues}) @@ -198,11 +201,14 @@ func (association *Association) Delete(values ...interface{}) error { case schema.HasOne, schema.HasMany: tx := association.DB.Model(reflect.New(rel.FieldSchema.ModelType).Interface()) - _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields) - pcolumn, pvalues := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs) - conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues}) + _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields) + if pcolumn, pvalues := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs); len(pvalues) > 0 { + conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues}) + } else { + return ErrPrimaryKeyRequired + } - _, rvs := schema.GetIdentityFieldValuesMapFromValues(values, rel.FieldSchema.PrimaryFields) + _, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, rel.FieldSchema.PrimaryFields) relColumn, relValues := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs) conds = append(conds, clause.IN{Column: relColumn, Values: relValues}) @@ -228,11 +234,14 @@ func (association *Association) Delete(values ...interface{}) error { } } - _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields) - pcolumn, pvalues := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs) - conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues}) + _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields) + if pcolumn, pvalues := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs); len(pvalues) > 0 { + conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues}) + } else { + return ErrPrimaryKeyRequired + } - _, rvs := schema.GetIdentityFieldValuesMapFromValues(values, relPrimaryFields) + _, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, relPrimaryFields) relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs) conds = append(conds, clause.IN{Column: relColumn, Values: relValues}) @@ -241,11 +250,11 @@ func (association *Association) Delete(values ...interface{}) error { if association.Error == nil { // clean up deleted values's foreign key - relValuesMap, _ := schema.GetIdentityFieldValuesMapFromValues(values, rel.FieldSchema.PrimaryFields) + relValuesMap, _ := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, rel.FieldSchema.PrimaryFields) cleanUpDeletedRelations := func(data reflect.Value) { - if _, zero := rel.Field.ValueOf(data); !zero { - fieldValue := reflect.Indirect(rel.Field.ReflectValueOf(data)) + if _, zero := rel.Field.ValueOf(association.DB.Statement.Context, data); !zero { + fieldValue := reflect.Indirect(rel.Field.ReflectValueOf(association.DB.Statement.Context, data)) primaryValues := make([]interface{}, len(rel.FieldSchema.PrimaryFields)) switch fieldValue.Kind() { @@ -253,7 +262,7 @@ func (association *Association) Delete(values ...interface{}) error { validFieldValues := reflect.Zero(rel.Field.IndirectFieldType) for i := 0; i < fieldValue.Len(); i++ { for idx, field := range rel.FieldSchema.PrimaryFields { - primaryValues[idx], _ = field.ValueOf(fieldValue.Index(i)) + primaryValues[idx], _ = field.ValueOf(association.DB.Statement.Context, fieldValue.Index(i)) } if _, ok := relValuesMap[utils.ToStringKey(primaryValues...)]; !ok { @@ -261,23 +270,23 @@ func (association *Association) Delete(values ...interface{}) error { } } - association.Error = rel.Field.Set(data, validFieldValues.Interface()) + association.Error = rel.Field.Set(association.DB.Statement.Context, data, validFieldValues.Interface()) case reflect.Struct: for idx, field := range rel.FieldSchema.PrimaryFields { - primaryValues[idx], _ = field.ValueOf(fieldValue) + primaryValues[idx], _ = field.ValueOf(association.DB.Statement.Context, fieldValue) } if _, ok := relValuesMap[utils.ToStringKey(primaryValues...)]; ok { - if association.Error = rel.Field.Set(data, reflect.Zero(rel.FieldSchema.ModelType).Interface()); association.Error != nil { + if association.Error = rel.Field.Set(association.DB.Statement.Context, data, reflect.Zero(rel.FieldSchema.ModelType).Interface()); association.Error != nil { break } if rel.JoinTable == nil { for _, ref := range rel.References { if ref.OwnPrimaryKey || ref.PrimaryValue != "" { - association.Error = ref.ForeignKey.Set(fieldValue, reflect.Zero(ref.ForeignKey.FieldType).Interface()) + association.Error = ref.ForeignKey.Set(association.DB.Statement.Context, fieldValue, reflect.Zero(ref.ForeignKey.FieldType).Interface()) } else { - association.Error = ref.ForeignKey.Set(data, reflect.Zero(ref.ForeignKey.FieldType).Interface()) + association.Error = ref.ForeignKey.Set(association.DB.Statement.Context, data, reflect.Zero(ref.ForeignKey.FieldType).Interface()) } } } @@ -329,14 +338,14 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ switch rv.Kind() { case reflect.Slice, reflect.Array: if rv.Len() > 0 { - association.Error = association.Relationship.Field.Set(source, rv.Index(0).Addr().Interface()) + association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, source, rv.Index(0).Addr().Interface()) if association.Relationship.Field.FieldType.Kind() == reflect.Struct { assignBacks = append(assignBacks, assignBack{Source: source, Dest: rv.Index(0)}) } } case reflect.Struct: - association.Error = association.Relationship.Field.Set(source, rv.Addr().Interface()) + association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, source, rv.Addr().Interface()) if association.Relationship.Field.FieldType.Kind() == reflect.Struct { assignBacks = append(assignBacks, assignBack{Source: source, Dest: rv}) @@ -344,7 +353,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ } case schema.HasMany, schema.Many2Many: elemType := association.Relationship.Field.IndirectFieldType.Elem() - fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(source)) + fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(association.DB.Statement.Context, source)) if clear { fieldValue = reflect.New(association.Relationship.Field.IndirectFieldType).Elem() } @@ -373,7 +382,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ } if association.Error == nil { - association.Error = association.Relationship.Field.Set(source, fieldValue.Interface()) + association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, source, fieldValue.Interface()) } } } @@ -421,7 +430,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ // clear old data if clear && len(values) == 0 { for i := 0; i < reflectValue.Len(); i++ { - if err := association.Relationship.Field.Set(reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType).Interface()); err != nil { + if err := association.Relationship.Field.Set(association.DB.Statement.Context, reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType).Interface()); err != nil { association.Error = err break } @@ -429,7 +438,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ if association.Relationship.JoinTable == nil { for _, ref := range association.Relationship.References { if !ref.OwnPrimaryKey && ref.PrimaryValue == "" { - if err := ref.ForeignKey.Set(reflectValue.Index(i), reflect.Zero(ref.ForeignKey.FieldType).Interface()); err != nil { + if err := ref.ForeignKey.Set(association.DB.Statement.Context, reflectValue.Index(i), reflect.Zero(ref.ForeignKey.FieldType).Interface()); err != nil { association.Error = err break } @@ -453,12 +462,12 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ case reflect.Struct: // clear old data if clear && len(values) == 0 { - association.Error = association.Relationship.Field.Set(reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType).Interface()) + association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType).Interface()) if association.Relationship.JoinTable == nil && association.Error == nil { for _, ref := range association.Relationship.References { if !ref.OwnPrimaryKey && ref.PrimaryValue == "" { - association.Error = ref.ForeignKey.Set(reflectValue, reflect.Zero(ref.ForeignKey.FieldType).Interface()) + association.Error = ref.ForeignKey.Set(association.DB.Statement.Context, reflectValue, reflect.Zero(ref.ForeignKey.FieldType).Interface()) } } } @@ -475,7 +484,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ } for _, assignBack := range assignBacks { - fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(assignBack.Source)) + fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(association.DB.Statement.Context, assignBack.Source)) if assignBack.Index > 0 { reflect.Indirect(assignBack.Dest).Set(fieldValue.Index(assignBack.Index - 1)) } else { @@ -486,14 +495,14 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ func (association *Association) buildCondition() *DB { var ( - queryConds = association.Relationship.ToQueryConditions(association.DB.Statement.ReflectValue) + queryConds = association.Relationship.ToQueryConditions(association.DB.Statement.Context, association.DB.Statement.ReflectValue) modelValue = reflect.New(association.Relationship.FieldSchema.ModelType).Interface() tx = association.DB.Model(modelValue) ) if association.Relationship.JoinTable != nil { if !tx.Statement.Unscoped && len(association.Relationship.JoinTable.QueryClauses) > 0 { - joinStmt := Statement{DB: tx, Schema: association.Relationship.JoinTable, Table: association.Relationship.JoinTable.Table, Clauses: map[string]clause.Clause{}} + joinStmt := Statement{DB: tx, Context: tx.Statement.Context, Schema: association.Relationship.JoinTable, Table: association.Relationship.JoinTable.Table, Clauses: map[string]clause.Clause{}} for _, queryClause := range association.Relationship.JoinTable.QueryClauses { joinStmt.AddClause(queryClause) } diff --git a/callbacks.go b/callbacks.go index f344649e..c060ea70 100644 --- a/callbacks.go +++ b/callbacks.go @@ -246,7 +246,13 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) { sortCallback func(*callback) error ) sort.Slice(cs, func(i, j int) bool { - return cs[j].before == "*" || cs[j].after == "*" + if cs[j].before == "*" && cs[i].before != "*" { + return true + } + if cs[j].after == "*" && cs[i].after != "*" { + return true + } + return false }) for _, c := range cs { diff --git a/callbacks/associations.go b/callbacks/associations.go index 75bd6c6a..fd3141cf 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -24,8 +24,8 @@ func SaveBeforeAssociations(create bool) func(db *gorm.DB) { setupReferences := func(obj reflect.Value, elem reflect.Value) { for _, ref := range rel.References { if !ref.OwnPrimaryKey { - pv, _ := ref.PrimaryKey.ValueOf(elem) - db.AddError(ref.ForeignKey.Set(obj, pv)) + pv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, elem) + db.AddError(ref.ForeignKey.Set(db.Statement.Context, obj, pv)) if dest, ok := db.Statement.Dest.(map[string]interface{}); ok { dest[ref.ForeignKey.DBName] = pv @@ -57,8 +57,8 @@ func SaveBeforeAssociations(create bool) func(db *gorm.DB) { break } - if _, zero := rel.Field.ValueOf(obj); !zero { // check belongs to relation value - rv := rel.Field.ReflectValueOf(obj) // relation reflect value + if _, zero := rel.Field.ValueOf(db.Statement.Context, obj); !zero { // check belongs to relation value + rv := rel.Field.ReflectValueOf(db.Statement.Context, obj) // relation reflect value objs = append(objs, obj) if isPtr { elems = reflect.Append(elems, rv) @@ -69,20 +69,20 @@ func SaveBeforeAssociations(create bool) func(db *gorm.DB) { } if elems.Len() > 0 { - if saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, nil) == nil { + if saveAssociations(db, rel, elems, selectColumns, restricted, nil) == nil { for i := 0; i < elems.Len(); i++ { setupReferences(objs[i], elems.Index(i)) } } } case reflect.Struct: - if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero { - rv := rel.Field.ReflectValueOf(db.Statement.ReflectValue) // relation reflect value + if _, zero := rel.Field.ValueOf(db.Statement.Context, db.Statement.ReflectValue); !zero { + rv := rel.Field.ReflectValueOf(db.Statement.Context, db.Statement.ReflectValue) // relation reflect value if rv.Kind() != reflect.Ptr { rv = rv.Addr() } - if saveAssociations(db, rel, rv.Interface(), selectColumns, restricted, nil) == nil { + if saveAssociations(db, rel, rv, selectColumns, restricted, nil) == nil { setupReferences(db.Statement.ReflectValue, rv) } } @@ -120,18 +120,18 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { obj := db.Statement.ReflectValue.Index(i) if reflect.Indirect(obj).Kind() == reflect.Struct { - if _, zero := rel.Field.ValueOf(obj); !zero { - rv := rel.Field.ReflectValueOf(obj) + if _, zero := rel.Field.ValueOf(db.Statement.Context, obj); !zero { + rv := rel.Field.ReflectValueOf(db.Statement.Context, obj) if rv.Kind() != reflect.Ptr { rv = rv.Addr() } for _, ref := range rel.References { if ref.OwnPrimaryKey { - fv, _ := ref.PrimaryKey.ValueOf(obj) - db.AddError(ref.ForeignKey.Set(rv, fv)) + fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, obj) + db.AddError(ref.ForeignKey.Set(db.Statement.Context, rv, fv)) } else if ref.PrimaryValue != "" { - db.AddError(ref.ForeignKey.Set(rv, ref.PrimaryValue)) + db.AddError(ref.ForeignKey.Set(db.Statement.Context, rv, ref.PrimaryValue)) } } @@ -146,11 +146,11 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) } - saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, assignmentColumns) + saveAssociations(db, rel, elems, selectColumns, restricted, assignmentColumns) } case reflect.Struct: - if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero { - f := rel.Field.ReflectValueOf(db.Statement.ReflectValue) + if _, zero := rel.Field.ValueOf(db.Statement.Context, db.Statement.ReflectValue); !zero { + f := rel.Field.ReflectValueOf(db.Statement.Context, db.Statement.ReflectValue) if f.Kind() != reflect.Ptr { f = f.Addr() } @@ -158,15 +158,15 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { assignmentColumns := make([]string, 0, len(rel.References)) for _, ref := range rel.References { if ref.OwnPrimaryKey { - fv, _ := ref.PrimaryKey.ValueOf(db.Statement.ReflectValue) - ref.ForeignKey.Set(f, fv) + fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, db.Statement.ReflectValue) + db.AddError(ref.ForeignKey.Set(db.Statement.Context, f, fv)) } else if ref.PrimaryValue != "" { - ref.ForeignKey.Set(f, ref.PrimaryValue) + db.AddError(ref.ForeignKey.Set(db.Statement.Context, f, ref.PrimaryValue)) } assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) } - saveAssociations(db, rel, f.Interface(), selectColumns, restricted, assignmentColumns) + saveAssociations(db, rel, f, selectColumns, restricted, assignmentColumns) } } } @@ -185,23 +185,23 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) identityMap := map[string]bool{} appendToElems := func(v reflect.Value) { - if _, zero := rel.Field.ValueOf(v); !zero { - f := reflect.Indirect(rel.Field.ReflectValueOf(v)) + if _, zero := rel.Field.ValueOf(db.Statement.Context, v); !zero { + f := reflect.Indirect(rel.Field.ReflectValueOf(db.Statement.Context, v)) for i := 0; i < f.Len(); i++ { elem := f.Index(i) for _, ref := range rel.References { if ref.OwnPrimaryKey { - pv, _ := ref.PrimaryKey.ValueOf(v) - ref.ForeignKey.Set(elem, pv) + pv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, v) + db.AddError(ref.ForeignKey.Set(db.Statement.Context, elem, pv)) } else if ref.PrimaryValue != "" { - ref.ForeignKey.Set(elem, ref.PrimaryValue) + db.AddError(ref.ForeignKey.Set(db.Statement.Context, elem, ref.PrimaryValue)) } } relPrimaryValues := make([]interface{}, 0, len(rel.FieldSchema.PrimaryFields)) for _, pf := range rel.FieldSchema.PrimaryFields { - if pfv, ok := pf.ValueOf(elem); !ok { + if pfv, ok := pf.ValueOf(db.Statement.Context, elem); !ok { relPrimaryValues = append(relPrimaryValues, pfv) } } @@ -237,7 +237,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) } - saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, assignmentColumns) + saveAssociations(db, rel, elems, selectColumns, restricted, assignmentColumns) } } @@ -260,21 +260,21 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { joinValue := reflect.New(rel.JoinTable.ModelType) for _, ref := range rel.References { if ref.OwnPrimaryKey { - fv, _ := ref.PrimaryKey.ValueOf(obj) - ref.ForeignKey.Set(joinValue, fv) + fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, obj) + db.AddError(ref.ForeignKey.Set(db.Statement.Context, joinValue, fv)) } else if ref.PrimaryValue != "" { - ref.ForeignKey.Set(joinValue, ref.PrimaryValue) + db.AddError(ref.ForeignKey.Set(db.Statement.Context, joinValue, ref.PrimaryValue)) } else { - fv, _ := ref.PrimaryKey.ValueOf(elem) - ref.ForeignKey.Set(joinValue, fv) + fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, elem) + db.AddError(ref.ForeignKey.Set(db.Statement.Context, joinValue, fv)) } } joins = reflect.Append(joins, joinValue) } appendToElems := func(v reflect.Value) { - if _, zero := rel.Field.ValueOf(v); !zero { - f := reflect.Indirect(rel.Field.ReflectValueOf(v)) + if _, zero := rel.Field.ValueOf(db.Statement.Context, v); !zero { + f := reflect.Indirect(rel.Field.ReflectValueOf(db.Statement.Context, v)) for i := 0; i < f.Len(); i++ { elem := f.Index(i) @@ -304,7 +304,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { // optimize elems of reflect value length if elemLen := elems.Len(); elemLen > 0 { if v, ok := selectColumns[rel.Name+".*"]; !ok || v { - saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, nil) + saveAssociations(db, rel, elems, selectColumns, restricted, nil) } for i := 0; i < elemLen; i++ { @@ -323,7 +323,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { } } -func onConflictOption(stmt *gorm.Statement, s *schema.Schema, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) (onConflict clause.OnConflict) { +func onConflictOption(stmt *gorm.Statement, s *schema.Schema, defaultUpdatingColumns []string) (onConflict clause.OnConflict) { if len(defaultUpdatingColumns) > 0 || stmt.DB.FullSaveAssociations { onConflict.Columns = make([]clause.Column, 0, len(s.PrimaryFieldDBNames)) for _, dbName := range s.PrimaryFieldDBNames { @@ -341,11 +341,17 @@ func onConflictOption(stmt *gorm.Statement, s *schema.Schema, selectColumns map[ return } -func saveAssociations(db *gorm.DB, rel *schema.Relationship, values interface{}, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) error { +func saveAssociations(db *gorm.DB, rel *schema.Relationship, rValues reflect.Value, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) error { + // stop save association loop + if checkAssociationsSaved(db, rValues) { + return nil + } + var ( selects, omits []string - onConflict = onConflictOption(db.Statement, rel.FieldSchema, selectColumns, restricted, defaultUpdatingColumns) + onConflict = onConflictOption(db.Statement, rel.FieldSchema, defaultUpdatingColumns) refName = rel.Name + "." + values = rValues.Interface() ) for name, ok := range selectColumns { @@ -390,3 +396,24 @@ func saveAssociations(db *gorm.DB, rel *schema.Relationship, values interface{}, return db.AddError(tx.Create(values).Error) } + +// check association values has been saved +// if values kind is Struct, check it has been saved +// if values kind is Slice/Array, check all items have been saved +var visitMapStoreKey = "gorm:saved_association_map" + +func checkAssociationsSaved(db *gorm.DB, values reflect.Value) bool { + if visit, ok := db.Get(visitMapStoreKey); ok { + if v, ok := visit.(*visitMap); ok { + if loadOrStoreVisitMap(v, values) { + return true + } + } + } else { + vistMap := make(visitMap) + loadOrStoreVisitMap(&vistMap, values) + db.Set(visitMapStoreKey, &vistMap) + } + + return false +} diff --git a/callbacks/create.go b/callbacks/create.go index 29113128..0fe1dc93 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -10,6 +10,7 @@ import ( "gorm.io/gorm/utils" ) +// BeforeCreate before create hooks func BeforeCreate(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) { callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { @@ -31,6 +32,7 @@ func BeforeCreate(db *gorm.DB) { } } +// Create create hook func Create(config *Config) func(db *gorm.DB) { supportReturning := utils.Contains(config.CreateClauses, "RETURNING") @@ -82,8 +84,10 @@ func Create(config *Config) func(db *gorm.DB) { db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars..., ) if db.AddError(err) == nil { + defer func() { + db.AddError(rows.Close()) + }() gorm.Scan(rows, db, mode) - db.AddError(rows.Close()) } return @@ -117,9 +121,9 @@ func Create(config *Config) func(db *gorm.DB) { break } - _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv) + _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv) if isZero { - db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID) + db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, insertID)) insertID -= db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement } } @@ -130,38 +134,39 @@ func Create(config *Config) func(db *gorm.DB) { break } - if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv); isZero { - db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID) + if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv); isZero { + db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, insertID)) insertID += db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement } } } case reflect.Struct: - _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue) + _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, db.Statement.ReflectValue) if isZero { - db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID) + db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, db.Statement.ReflectValue, insertID)) } } } } } +// AfterCreate after create hooks func AfterCreate(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) { callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { - if db.Statement.Schema.AfterSave { - if i, ok := value.(AfterSaveInterface); ok { - called = true - db.AddError(i.AfterSave(tx)) - } - } - if db.Statement.Schema.AfterCreate { if i, ok := value.(AfterCreateInterface); ok { called = true db.AddError(i.AfterCreate(tx)) } } + + if db.Statement.Schema.AfterSave { + if i, ok := value.(AfterSaveInterface); ok { + called = true + db.AddError(i.AfterSave(tx)) + } + } return called }) } @@ -201,13 +206,15 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { switch stmt.ReflectValue.Kind() { case reflect.Slice, reflect.Array: rValLen := stmt.ReflectValue.Len() - stmt.SQL.Grow(rValLen * 18) - values.Values = make([][]interface{}, rValLen) if rValLen == 0 { stmt.AddError(gorm.ErrEmptySlice) return } + stmt.SQL.Grow(rValLen * 18) + stmt.Vars = make([]interface{}, 0, rValLen*len(values.Columns)) + values.Values = make([][]interface{}, rValLen) + defaultValueFieldsHavingValue := map[*schema.Field][]interface{}{} for i := 0; i < rValLen; i++ { rv := reflect.Indirect(stmt.ReflectValue.Index(i)) @@ -219,23 +226,23 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { values.Values[i] = make([]interface{}, len(values.Columns)) for idx, column := range values.Columns { field := stmt.Schema.FieldsByDBName[column.Name] - if values.Values[i][idx], isZero = field.ValueOf(rv); isZero { + if values.Values[i][idx], isZero = field.ValueOf(stmt.Context, rv); isZero { if field.DefaultValueInterface != nil { values.Values[i][idx] = field.DefaultValueInterface - field.Set(rv, field.DefaultValueInterface) + stmt.AddError(field.Set(stmt.Context, rv, field.DefaultValueInterface)) } else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 { - field.Set(rv, curTime) - values.Values[i][idx], _ = field.ValueOf(rv) + stmt.AddError(field.Set(stmt.Context, rv, curTime)) + values.Values[i][idx], _ = field.ValueOf(stmt.Context, rv) } } else if field.AutoUpdateTime > 0 && updateTrackTime { - field.Set(rv, curTime) - values.Values[i][idx], _ = field.ValueOf(rv) + stmt.AddError(field.Set(stmt.Context, rv, curTime)) + values.Values[i][idx], _ = field.ValueOf(stmt.Context, rv) } } for _, field := range stmt.Schema.FieldsWithDefaultDBValue { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { - if rvOfvalue, isZero := field.ValueOf(rv); !isZero { + if rvOfvalue, isZero := field.ValueOf(stmt.Context, rv); !isZero { if len(defaultValueFieldsHavingValue[field]) == 0 { defaultValueFieldsHavingValue[field] = make([]interface{}, rValLen) } @@ -259,23 +266,23 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { values.Values = [][]interface{}{make([]interface{}, len(values.Columns))} for idx, column := range values.Columns { field := stmt.Schema.FieldsByDBName[column.Name] - if values.Values[0][idx], isZero = field.ValueOf(stmt.ReflectValue); isZero { + if values.Values[0][idx], isZero = field.ValueOf(stmt.Context, stmt.ReflectValue); isZero { if field.DefaultValueInterface != nil { values.Values[0][idx] = field.DefaultValueInterface - field.Set(stmt.ReflectValue, field.DefaultValueInterface) + stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, field.DefaultValueInterface)) } else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 { - field.Set(stmt.ReflectValue, curTime) - values.Values[0][idx], _ = field.ValueOf(stmt.ReflectValue) + stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, curTime)) + values.Values[0][idx], _ = field.ValueOf(stmt.Context, stmt.ReflectValue) } } else if field.AutoUpdateTime > 0 && updateTrackTime { - field.Set(stmt.ReflectValue, curTime) - values.Values[0][idx], _ = field.ValueOf(stmt.ReflectValue) + stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, curTime)) + values.Values[0][idx], _ = field.ValueOf(stmt.Context, stmt.ReflectValue) } } for _, field := range stmt.Schema.FieldsWithDefaultDBValue { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { - if rvOfvalue, isZero := field.ValueOf(stmt.ReflectValue); !isZero { + if rvOfvalue, isZero := field.ValueOf(stmt.Context, stmt.ReflectValue); !isZero { values.Columns = append(values.Columns, clause.Column{Name: field.DBName}) values.Values[0] = append(values.Values[0], rvOfvalue) } diff --git a/callbacks/delete.go b/callbacks/delete.go index 7f1e09ce..84f446a3 100644 --- a/callbacks/delete.go +++ b/callbacks/delete.go @@ -42,7 +42,7 @@ func DeleteBeforeAssociations(db *gorm.DB) { switch rel.Type { case schema.HasOne, schema.HasMany: - queryConds := rel.ToQueryConditions(db.Statement.ReflectValue) + queryConds := rel.ToQueryConditions(db.Statement.Context, db.Statement.ReflectValue) modelValue := reflect.New(rel.FieldSchema.ModelType).Interface() tx := db.Session(&gorm.Session{NewDB: true}).Model(modelValue) withoutConditions := false @@ -97,7 +97,7 @@ func DeleteBeforeAssociations(db *gorm.DB) { } } - _, foreignValues := schema.GetIdentityFieldValuesMap(db.Statement.ReflectValue, foreignFields) + _, foreignValues := schema.GetIdentityFieldValuesMap(db.Statement.Context, db.Statement.ReflectValue, foreignFields) column, values := schema.ToQueryValues(table, relForeignKeys, foreignValues) queryConds = append(queryConds, clause.IN{Column: column, Values: values}) @@ -118,12 +118,18 @@ func Delete(config *Config) func(db *gorm.DB) { return } + if db.Statement.Schema != nil { + for _, c := range db.Statement.Schema.DeleteClauses { + db.Statement.AddClause(c) + } + } + if db.Statement.SQL.Len() == 0 { db.Statement.SQL.Grow(100) db.Statement.AddClauseIfNotExists(clause.Delete{}) if db.Statement.Schema != nil { - _, queryValues := schema.GetIdentityFieldValuesMap(db.Statement.ReflectValue, db.Statement.Schema.PrimaryFields) + _, queryValues := schema.GetIdentityFieldValuesMap(db.Statement.Context, db.Statement.ReflectValue, db.Statement.Schema.PrimaryFields) column, values := schema.ToQueryValues(db.Statement.Table, db.Statement.Schema.PrimaryFieldDBNames, queryValues) if len(values) > 0 { @@ -131,7 +137,7 @@ func Delete(config *Config) func(db *gorm.DB) { } if db.Statement.ReflectValue.CanAddr() && db.Statement.Dest != db.Statement.Model && db.Statement.Model != nil { - _, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(db.Statement.Model), db.Statement.Schema.PrimaryFields) + _, queryValues = schema.GetIdentityFieldValuesMap(db.Statement.Context, reflect.ValueOf(db.Statement.Model), db.Statement.Schema.PrimaryFields) column, values = schema.ToQueryValues(db.Statement.Table, db.Statement.Schema.PrimaryFieldDBNames, queryValues) if len(values) > 0 { @@ -141,22 +147,11 @@ func Delete(config *Config) func(db *gorm.DB) { } db.Statement.AddClauseIfNotExists(clause.From{}) - } - if db.Statement.Schema != nil { - for _, c := range db.Statement.Schema.DeleteClauses { - db.Statement.AddClause(c) - } - } - - if db.Statement.SQL.Len() == 0 { db.Statement.Build(db.Statement.BuildClauses...) } - if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok && db.Error == nil { - db.AddError(gorm.ErrMissingWhereClause) - return - } + checkMissingWhereConditions(db) if !db.DryRun && db.Error == nil { ok, mode := hasReturning(db, supportReturning) diff --git a/callbacks/helper.go b/callbacks/helper.go index a59e1880..ae9fd8c5 100644 --- a/callbacks/helper.go +++ b/callbacks/helper.go @@ -1,6 +1,7 @@ package callbacks import ( + "reflect" "sort" "gorm.io/gorm" @@ -104,3 +105,48 @@ func hasReturning(tx *gorm.DB, supportReturning bool) (bool, gorm.ScanMode) { } return false, 0 } + +func checkMissingWhereConditions(db *gorm.DB) { + if !db.AllowGlobalUpdate && db.Error == nil { + where, withCondition := db.Statement.Clauses["WHERE"] + if withCondition { + if _, withSoftDelete := db.Statement.Clauses["soft_delete_enabled"]; withSoftDelete { + whereClause, _ := where.Expression.(clause.Where) + withCondition = len(whereClause.Exprs) > 1 + } + } + if !withCondition { + db.AddError(gorm.ErrMissingWhereClause) + } + return + } +} + +type visitMap = map[reflect.Value]bool + +// Check if circular values, return true if loaded +func loadOrStoreVisitMap(visitMap *visitMap, v reflect.Value) (loaded bool) { + if v.Kind() == reflect.Ptr { + v = v.Elem() + } + + switch v.Kind() { + case reflect.Slice, reflect.Array: + loaded = true + for i := 0; i < v.Len(); i++ { + if !loadOrStoreVisitMap(visitMap, v.Index(i)) { + loaded = false + } + } + case reflect.Struct, reflect.Interface: + if v.CanAddr() { + p := v.Addr() + if _, ok := (*visitMap)[p]; ok { + return true + } + (*visitMap)[p] = true + } + } + + return +} diff --git a/callbacks/preload.go b/callbacks/preload.go index 41405a22..ea2570ba 100644 --- a/callbacks/preload.go +++ b/callbacks/preload.go @@ -10,10 +10,9 @@ import ( "gorm.io/gorm/utils" ) -func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) { +func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) error { var ( - reflectValue = db.Statement.ReflectValue - tx = db.Session(&gorm.Session{NewDB: true}).Model(nil).Session(&gorm.Session{SkipHooks: db.Statement.SkipHooks}) + reflectValue = tx.Statement.ReflectValue relForeignKeys []string relForeignFields []*schema.Field foreignFields []*schema.Field @@ -22,11 +21,6 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload inlineConds []interface{} ) - db.Statement.Settings.Range(func(k, v interface{}) bool { - tx.Statement.Settings.Store(k, v) - return true - }) - if rel.JoinTable != nil { var ( joinForeignFields = make([]*schema.Field, 0, len(rel.References)) @@ -48,14 +42,16 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload } } - joinIdentityMap, joinForeignValues := schema.GetIdentityFieldValuesMap(reflectValue, foreignFields) + joinIdentityMap, joinForeignValues := schema.GetIdentityFieldValuesMap(tx.Statement.Context, reflectValue, foreignFields) if len(joinForeignValues) == 0 { - return + return nil } joinResults := rel.JoinTable.MakeSlice().Elem() column, values := schema.ToQueryValues(clause.CurrentTable, joinForeignKeys, joinForeignValues) - db.AddError(tx.Where(clause.IN{Column: column, Values: values}).Find(joinResults.Addr().Interface()).Error) + if err := tx.Where(clause.IN{Column: column, Values: values}).Find(joinResults.Addr().Interface()).Error; err != nil { + return err + } // convert join identity map to relation identity map fieldValues := make([]interface{}, len(joinForeignFields)) @@ -63,11 +59,11 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload for i := 0; i < joinResults.Len(); i++ { joinIndexValue := joinResults.Index(i) for idx, field := range joinForeignFields { - fieldValues[idx], _ = field.ValueOf(joinIndexValue) + fieldValues[idx], _ = field.ValueOf(tx.Statement.Context, joinIndexValue) } for idx, field := range joinRelForeignFields { - joinFieldValues[idx], _ = field.ValueOf(joinIndexValue) + joinFieldValues[idx], _ = field.ValueOf(tx.Statement.Context, joinIndexValue) } if results, ok := joinIdentityMap[utils.ToStringKey(fieldValues...)]; ok { @@ -76,7 +72,7 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload } } - _, foreignValues = schema.GetIdentityFieldValuesMap(joinResults, joinRelForeignFields) + _, foreignValues = schema.GetIdentityFieldValuesMap(tx.Statement.Context, joinResults, joinRelForeignFields) } else { for _, ref := range rel.References { if ref.OwnPrimaryKey { @@ -92,9 +88,9 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload } } - identityMap, foreignValues = schema.GetIdentityFieldValuesMap(reflectValue, foreignFields) + identityMap, foreignValues = schema.GetIdentityFieldValuesMap(tx.Statement.Context, reflectValue, foreignFields) if len(foreignValues) == 0 { - return + return nil } } @@ -115,7 +111,9 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload } } - db.AddError(tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), inlineConds...).Error) + if err := tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), inlineConds...).Error; err != nil { + return err + } } fieldValues := make([]interface{}, len(relForeignFields)) @@ -125,17 +123,17 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload case reflect.Struct: switch rel.Type { case schema.HasMany, schema.Many2Many: - rel.Field.Set(reflectValue, reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface()) + tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue, reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface())) default: - rel.Field.Set(reflectValue, reflect.New(rel.Field.FieldType).Interface()) + tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue, reflect.New(rel.Field.FieldType).Interface())) } case reflect.Slice, reflect.Array: for i := 0; i < reflectValue.Len(); i++ { switch rel.Type { case schema.HasMany, schema.Many2Many: - rel.Field.Set(reflectValue.Index(i), reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface()) + tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue.Index(i), reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface())) default: - rel.Field.Set(reflectValue.Index(i), reflect.New(rel.Field.FieldType).Interface()) + tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue.Index(i), reflect.New(rel.Field.FieldType).Interface())) } } } @@ -143,18 +141,16 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload for i := 0; i < reflectResults.Len(); i++ { elem := reflectResults.Index(i) for idx, field := range relForeignFields { - fieldValues[idx], _ = field.ValueOf(elem) + fieldValues[idx], _ = field.ValueOf(tx.Statement.Context, elem) } datas, ok := identityMap[utils.ToStringKey(fieldValues...)] if !ok { - db.AddError(fmt.Errorf("failed to assign association %#v, make sure foreign fields exists", - elem.Interface())) - continue + return fmt.Errorf("failed to assign association %#v, make sure foreign fields exists", elem.Interface()) } for _, data := range datas { - reflectFieldValue := rel.Field.ReflectValueOf(data) + reflectFieldValue := rel.Field.ReflectValueOf(tx.Statement.Context, data) if reflectFieldValue.Kind() == reflect.Ptr && reflectFieldValue.IsNil() { reflectFieldValue.Set(reflect.New(rel.Field.FieldType.Elem())) } @@ -162,14 +158,16 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload reflectFieldValue = reflect.Indirect(reflectFieldValue) switch reflectFieldValue.Kind() { case reflect.Struct: - rel.Field.Set(data, elem.Interface()) + tx.AddError(rel.Field.Set(tx.Statement.Context, data, elem.Interface())) case reflect.Slice, reflect.Array: if reflectFieldValue.Type().Elem().Kind() == reflect.Ptr { - rel.Field.Set(data, reflect.Append(reflectFieldValue, elem).Interface()) + tx.AddError(rel.Field.Set(tx.Statement.Context, data, reflect.Append(reflectFieldValue, elem).Interface())) } else { - rel.Field.Set(data, reflect.Append(reflectFieldValue, elem.Elem()).Interface()) + tx.AddError(rel.Field.Set(tx.Statement.Context, data, reflect.Append(reflectFieldValue, elem.Elem()).Interface())) } } } } + + return tx.Error } diff --git a/callbacks/query.go b/callbacks/query.go index c2bbf5f9..26ee8c34 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -20,8 +20,10 @@ func Query(db *gorm.DB) { db.AddError(err) return } + defer func() { + db.AddError(rows.Close()) + }() gorm.Scan(rows, db, 0) - db.AddError(rows.Close()) } } } @@ -40,7 +42,7 @@ func BuildQuerySQL(db *gorm.DB) { if db.Statement.ReflectValue.Kind() == reflect.Struct && db.Statement.ReflectValue.Type() == db.Statement.Schema.ModelType { var conds []clause.Expression for _, primaryField := range db.Statement.Schema.PrimaryFields { - if v, isZero := primaryField.ValueOf(db.Statement.ReflectValue); !isZero { + if v, isZero := primaryField.ValueOf(db.Statement.Context, db.Statement.ReflectValue); !isZero { conds = append(conds, clause.Eq{Column: clause.Column{Table: db.Statement.Table, Name: primaryField.DBName}, Value: v}) } } @@ -94,13 +96,13 @@ func BuildQuerySQL(db *gorm.DB) { } // inline joins - joins := []clause.Join{} - if fromClause, ok := db.Statement.Clauses["FROM"].Expression.(clause.From); ok { - joins = fromClause.Joins + fromClause := clause.From{} + if v, ok := db.Statement.Clauses["FROM"].Expression.(clause.From); ok { + fromClause = v } - if len(db.Statement.Joins) != 0 || len(joins) != 0 { - if len(db.Statement.Selects) == 0 && db.Statement.Schema != nil { + if len(db.Statement.Joins) != 0 || len(fromClause.Joins) != 0 { + if len(db.Statement.Selects) == 0 && len(db.Statement.Omits) == 0 && db.Statement.Schema != nil { clauseSelect.Columns = make([]clause.Column, len(db.Statement.Schema.DBNames)) for idx, dbName := range db.Statement.Schema.DBNames { clauseSelect.Columns[idx] = clause.Column{Table: db.Statement.Table, Name: dbName} @@ -109,7 +111,7 @@ func BuildQuerySQL(db *gorm.DB) { for _, join := range db.Statement.Joins { if db.Statement.Schema == nil { - joins = append(joins, clause.Join{ + fromClause.Joins = append(fromClause.Joins, clause.Join{ Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds}, }) } else if relation, ok := db.Statement.Schema.Relationships.Relations[join.Name]; ok { @@ -145,35 +147,49 @@ func BuildQuerySQL(db *gorm.DB) { } } - if join.On != nil { - onStmt := gorm.Statement{Table: tableAliasName, DB: db} - join.On.Build(&onStmt) - onSQL := onStmt.SQL.String() - vars := onStmt.Vars - for idx, v := range onStmt.Vars { - bindvar := strings.Builder{} - onStmt.Vars = vars[0 : idx+1] - db.Dialector.BindVarTo(&bindvar, &onStmt, v) - onSQL = strings.Replace(onSQL, bindvar.String(), "?", 1) + { + onStmt := gorm.Statement{Table: tableAliasName, DB: db, Clauses: map[string]clause.Clause{}} + for _, c := range relation.FieldSchema.QueryClauses { + onStmt.AddClause(c) } - exprs = append(exprs, clause.Expr{SQL: onSQL, Vars: vars}) + if join.On != nil { + onStmt.AddClause(join.On) + } + + if cs, ok := onStmt.Clauses["WHERE"]; ok { + if where, ok := cs.Expression.(clause.Where); ok { + where.Build(&onStmt) + + if onSQL := onStmt.SQL.String(); onSQL != "" { + vars := onStmt.Vars + for idx, v := range vars { + bindvar := strings.Builder{} + onStmt.Vars = vars[0 : idx+1] + db.Dialector.BindVarTo(&bindvar, &onStmt, v) + onSQL = strings.Replace(onSQL, bindvar.String(), "?", 1) + } + + exprs = append(exprs, clause.Expr{SQL: onSQL, Vars: vars}) + } + } + } } - joins = append(joins, clause.Join{ + fromClause.Joins = append(fromClause.Joins, clause.Join{ Type: clause.LeftJoin, Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName}, ON: clause.Where{Exprs: exprs}, }) } else { - joins = append(joins, clause.Join{ + fromClause.Joins = append(fromClause.Joins, clause.Join{ Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds}, }) } } + db.Statement.AddClause(fromClause) db.Statement.Joins = nil - db.Statement.AddClause(clause.From{Joins: joins}) } else { db.Statement.AddClauseIfNotExists(clause.From{}) } @@ -186,6 +202,11 @@ func BuildQuerySQL(db *gorm.DB) { func Preload(db *gorm.DB) { if db.Error == nil && len(db.Statement.Preloads) > 0 { + if db.Statement.Schema == nil { + db.AddError(fmt.Errorf("%w when using preload", gorm.ErrModelValueRequired)) + return + } + preloadMap := map[string]map[string][]interface{}{} for name := range db.Statement.Preloads { preloadFields := strings.Split(name, ".") @@ -218,9 +239,20 @@ func Preload(db *gorm.DB) { } sort.Strings(preloadNames) + preloadDB := db.Session(&gorm.Session{Context: db.Statement.Context, NewDB: true, SkipHooks: db.Statement.SkipHooks, Initialized: true}) + db.Statement.Settings.Range(func(k, v interface{}) bool { + preloadDB.Statement.Settings.Store(k, v) + return true + }) + + if err := preloadDB.Statement.Parse(db.Statement.Dest); err != nil { + return + } + preloadDB.Statement.ReflectValue = db.Statement.ReflectValue + for _, name := range preloadNames { - if rel := db.Statement.Schema.Relationships.Relations[name]; rel != nil { - preload(db, rel, append(db.Statement.Preloads[name], db.Statement.Preloads[clause.Associations]...), preloadMap[name]) + if rel := preloadDB.Statement.Schema.Relationships.Relations[name]; rel != nil { + db.AddError(preload(preloadDB.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks}), rel, append(db.Statement.Preloads[name], db.Statement.Preloads[clause.Associations]...), preloadMap[name])) } else { db.AddError(fmt.Errorf("%s: %w for schema %s", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name)) } diff --git a/callbacks/update.go b/callbacks/update.go index 511e994e..01f40509 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -21,7 +21,7 @@ func SetupUpdateReflectValue(db *gorm.DB) { if dest, ok := db.Statement.Dest.(map[string]interface{}); ok { for _, rel := range db.Statement.Schema.Relationships.BelongsTo { if _, ok := dest[rel.Name]; ok { - rel.Field.Set(db.Statement.ReflectValue, dest[rel.Name]) + db.AddError(rel.Field.Set(db.Statement.Context, db.Statement.ReflectValue, dest[rel.Name])) } } } @@ -29,6 +29,7 @@ func SetupUpdateReflectValue(db *gorm.DB) { } } +// BeforeUpdate before update hooks func BeforeUpdate(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) { callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { @@ -51,6 +52,7 @@ func BeforeUpdate(db *gorm.DB) { } } +// Update update hook func Update(config *Config) func(db *gorm.DB) { supportReturning := utils.Contains(config.UpdateClauses, "RETURNING") @@ -59,6 +61,12 @@ func Update(config *Config) func(db *gorm.DB) { return } + if db.Statement.Schema != nil { + for _, c := range db.Statement.Schema.UpdateClauses { + db.Statement.AddClause(c) + } + } + if db.Statement.SQL.Len() == 0 { db.Statement.SQL.Grow(180) db.Statement.AddClauseIfNotExists(clause.Update{}) @@ -68,22 +76,10 @@ func Update(config *Config) func(db *gorm.DB) { return } - } - - if db.Statement.Schema != nil { - for _, c := range db.Statement.Schema.UpdateClauses { - db.Statement.AddClause(c) - } - } - - if db.Statement.SQL.Len() == 0 { db.Statement.Build(db.Statement.BuildClauses...) } - if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok { - db.AddError(gorm.ErrMissingWhereClause) - return - } + checkMissingWhereConditions(db) if !db.DryRun && db.Error == nil { if ok, mode := hasReturning(db, supportReturning); ok { @@ -105,9 +101,17 @@ func Update(config *Config) func(db *gorm.DB) { } } +// AfterUpdate after update hooks func AfterUpdate(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) { callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { + if db.Statement.Schema.AfterUpdate { + if i, ok := value.(AfterUpdateInterface); ok { + called = true + db.AddError(i.AfterUpdate(tx)) + } + } + if db.Statement.Schema.AfterSave { if i, ok := value.(AfterSaveInterface); ok { called = true @@ -115,12 +119,6 @@ func AfterUpdate(db *gorm.DB) { } } - if db.Statement.Schema.AfterUpdate { - if i, ok := value.(AfterUpdateInterface); ok { - called = true - db.AddError(i.AfterUpdate(tx)) - } - } return called }) } @@ -137,13 +135,13 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { case reflect.Slice, reflect.Array: assignValue = func(field *schema.Field, value interface{}) { for i := 0; i < stmt.ReflectValue.Len(); i++ { - field.Set(stmt.ReflectValue.Index(i), value) + field.Set(stmt.Context, stmt.ReflectValue.Index(i), value) } } case reflect.Struct: assignValue = func(field *schema.Field, value interface{}) { if stmt.ReflectValue.CanAddr() { - field.Set(stmt.ReflectValue, value) + field.Set(stmt.Context, stmt.ReflectValue, value) } } default: @@ -165,7 +163,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { exprs := make([]clause.Expression, len(stmt.Schema.PrimaryFields)) var notZero bool for idx, field := range stmt.Schema.PrimaryFields { - value, isZero := field.ValueOf(stmt.ReflectValue.Index(i)) + value, isZero := field.ValueOf(stmt.Context, stmt.ReflectValue.Index(i)) exprs[idx] = clause.Eq{Column: field.DBName, Value: value} notZero = notZero || !isZero } @@ -178,7 +176,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { } case reflect.Struct: for _, field := range stmt.Schema.PrimaryFields { - if value, isZero := field.ValueOf(stmt.ReflectValue); !isZero { + if value, isZero := field.ValueOf(stmt.Context, stmt.ReflectValue); !isZero { stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}}) } } @@ -232,10 +230,10 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano()}) } else if field.AutoUpdateTime == schema.UnixMillisecond { set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano() / 1e6}) - } else if field.GORMDataType == schema.Time { - set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now}) - } else { + } else if field.AutoUpdateTime == schema.UnixSecond { set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.Unix()}) + } else { + set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now}) } } } @@ -258,16 +256,16 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { if field := updatingSchema.LookUpField(dbName); field != nil { if !field.PrimaryKey || !updatingValue.CanAddr() || stmt.Dest != stmt.Model { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && (!restricted || (!stmt.SkipHooks && field.AutoUpdateTime > 0))) { - value, isZero := field.ValueOf(updatingValue) + value, isZero := field.ValueOf(stmt.Context, updatingValue) if !stmt.SkipHooks && field.AutoUpdateTime > 0 { if field.AutoUpdateTime == schema.UnixNanosecond { value = stmt.DB.NowFunc().UnixNano() } else if field.AutoUpdateTime == schema.UnixMillisecond { value = stmt.DB.NowFunc().UnixNano() / 1e6 - } else if field.GORMDataType == schema.Time { - value = stmt.DB.NowFunc() - } else { + } else if field.AutoUpdateTime == schema.UnixSecond { value = stmt.DB.NowFunc().Unix() + } else { + value = stmt.DB.NowFunc() } isZero = false } @@ -278,7 +276,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { } } } else { - if value, isZero := field.ValueOf(updatingValue); !isZero { + if value, isZero := field.ValueOf(stmt.Context, updatingValue); !isZero { stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}}) } } diff --git a/callbacks/visit_map_test.go b/callbacks/visit_map_test.go new file mode 100644 index 00000000..b1fb86db --- /dev/null +++ b/callbacks/visit_map_test.go @@ -0,0 +1,36 @@ +package callbacks + +import ( + "reflect" + "testing" +) + +func TestLoadOrStoreVisitMap(t *testing.T) { + var vm visitMap + var loaded bool + type testM struct { + Name string + } + + t1 := testM{Name: "t1"} + t2 := testM{Name: "t2"} + t3 := testM{Name: "t3"} + + vm = make(visitMap) + if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf(&t1)); loaded { + t.Fatalf("loaded should be false") + } + + if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf(&t1)); !loaded { + t.Fatalf("loaded should be true") + } + + // t1 already exist but t2 not + if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf([]*testM{&t1, &t2, &t3})); loaded { + t.Fatalf("loaded should be false") + } + + if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf([]*testM{&t2, &t3})); !loaded { + t.Fatalf("loaded should be true") + } +} diff --git a/chainable_api.go b/chainable_api.go index 173479d3..68b4d1aa 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -54,9 +54,12 @@ func (db *DB) Table(name string, args ...interface{}) (tx *DB) { } else if tables := strings.Split(name, "."); len(tables) == 2 { tx.Statement.TableExpr = &clause.Expr{SQL: tx.Statement.Quote(name)} tx.Statement.Table = tables[1] - } else { + } else if name != "" { tx.Statement.TableExpr = &clause.Expr{SQL: tx.Statement.Quote(name)} tx.Statement.Table = name + } else { + tx.Statement.TableExpr = nil + tx.Statement.Table = "" } return } @@ -90,7 +93,11 @@ func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) { return } } - delete(tx.Statement.Clauses, "SELECT") + + if clause, ok := tx.Statement.Clauses["SELECT"]; ok { + clause.Expression = nil + tx.Statement.Clauses["SELECT"] = clause + } case string: if strings.Count(v, "?") >= len(args) && len(args) > 0 { tx.Statement.AddClause(clause.Select{ @@ -120,7 +127,10 @@ func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) { } } - delete(tx.Statement.Clauses, "SELECT") + if clause, ok := tx.Statement.Clauses["SELECT"]; ok { + clause.Expression = nil + tx.Statement.Clauses["SELECT"] = clause + } } default: tx.AddError(fmt.Errorf("unsupported select args %v %v", query, args)) diff --git a/clause/limit.go b/clause/limit.go index 2082f4d9..184f6025 100644 --- a/clause/limit.go +++ b/clause/limit.go @@ -21,7 +21,7 @@ func (limit Limit) Build(builder Builder) { } if limit.Offset > 0 { if limit.Limit > 0 { - builder.WriteString(" ") + builder.WriteByte(' ') } builder.WriteString("OFFSET ") builder.WriteString(strconv.Itoa(limit.Offset)) diff --git a/clause/select_test.go b/clause/select_test.go index 9fce0783..18bc2693 100644 --- a/clause/select_test.go +++ b/clause/select_test.go @@ -43,6 +43,23 @@ func TestSelect(t *testing.T) { }, clause.From{}}, "SELECT `id`, `name`, LENGTH(`mobile`) FROM `users`", nil, }, + { + []clause.Interface{clause.Select{ + Expression: clause.CommaExpression{ + Exprs: []clause.Expression{ + clause.Expr{ + SQL: "? as name", + Vars: []interface{}{clause.Eq{ + Column: clause.Column{Name: "age"}, + Value: 18, + }, + }, + }, + }, + }, + }, clause.From{}}, + "SELECT `age` = ? as name FROM `users`", []interface{}{18}, + }, } for idx, result := range results { diff --git a/clause/where.go b/clause/where.go index 20a01136..a29401cf 100644 --- a/clause/where.go +++ b/clause/where.go @@ -4,6 +4,11 @@ import ( "strings" ) +const ( + AndWithSpace = " AND " + OrWithSpace = " OR " +) + // Where where clause type Where struct { Exprs []Expression @@ -26,7 +31,7 @@ func (where Where) Build(builder Builder) { } } - buildExprs(where.Exprs, builder, " AND ") + buildExprs(where.Exprs, builder, AndWithSpace) } func buildExprs(exprs []Expression, builder Builder, joinCond string) { @@ -35,7 +40,7 @@ func buildExprs(exprs []Expression, builder Builder, joinCond string) { for idx, expr := range exprs { if idx > 0 { if v, ok := expr.(OrConditions); ok && len(v.Exprs) == 1 { - builder.WriteString(" OR ") + builder.WriteString(OrWithSpace) } else { builder.WriteString(joinCond) } @@ -46,30 +51,30 @@ func buildExprs(exprs []Expression, builder Builder, joinCond string) { case OrConditions: if len(v.Exprs) == 1 { if e, ok := v.Exprs[0].(Expr); ok { - sql := strings.ToLower(e.SQL) - wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or") + sql := strings.ToUpper(e.SQL) + wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace) } } case AndConditions: if len(v.Exprs) == 1 { if e, ok := v.Exprs[0].(Expr); ok { - sql := strings.ToLower(e.SQL) - wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or") + sql := strings.ToUpper(e.SQL) + wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace) } } case Expr: - sql := strings.ToLower(v.SQL) - wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or") + sql := strings.ToUpper(v.SQL) + wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace) case NamedExpr: - sql := strings.ToLower(v.SQL) - wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or") + sql := strings.ToUpper(v.SQL) + wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace) } } if wrapInParentheses { - builder.WriteString(`(`) + builder.WriteByte('(') expr.Build(builder) - builder.WriteString(`)`) + builder.WriteByte(')') wrapInParentheses = false } else { expr.Build(builder) @@ -110,10 +115,10 @@ type AndConditions struct { func (and AndConditions) Build(builder Builder) { if len(and.Exprs) > 1 { builder.WriteByte('(') - buildExprs(and.Exprs, builder, " AND ") + buildExprs(and.Exprs, builder, AndWithSpace) builder.WriteByte(')') } else { - buildExprs(and.Exprs, builder, " AND ") + buildExprs(and.Exprs, builder, AndWithSpace) } } @@ -131,10 +136,10 @@ type OrConditions struct { func (or OrConditions) Build(builder Builder) { if len(or.Exprs) > 1 { builder.WriteByte('(') - buildExprs(or.Exprs, builder, " OR ") + buildExprs(or.Exprs, builder, OrWithSpace) builder.WriteByte(')') } else { - buildExprs(or.Exprs, builder, " OR ") + buildExprs(or.Exprs, builder, OrWithSpace) } } @@ -156,7 +161,7 @@ func (not NotConditions) Build(builder Builder) { for idx, c := range not.Exprs { if idx > 0 { - builder.WriteString(" AND ") + builder.WriteString(AndWithSpace) } if negationBuilder, ok := c.(NegationExpressionBuilder); ok { @@ -165,8 +170,8 @@ func (not NotConditions) Build(builder Builder) { builder.WriteString("NOT ") e, wrapInParentheses := c.(Expr) if wrapInParentheses { - sql := strings.ToLower(e.SQL) - if wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or"); wrapInParentheses { + sql := strings.ToUpper(e.SQL) + if wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace); wrapInParentheses { builder.WriteByte('(') } } diff --git a/clause/where_test.go b/clause/where_test.go index 272c7b76..35e3dbee 100644 --- a/clause/where_test.go +++ b/clause/where_test.go @@ -66,6 +66,45 @@ func TestWhere(t *testing.T) { "SELECT * FROM `users` WHERE (`age` = ? OR `name` <> ?)", []interface{}{18, "jinzhu"}, }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Where{ + Exprs: []clause.Expression{clause.Not(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}), clause.And(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false})}, + }}, + "SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `age` <= ?) AND `score` <= ?", + []interface{}{"1", 18, 100}, + }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Where{ + Exprs: []clause.Expression{clause.Not(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}), clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false}}, + }}, + "SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `age` <= ?) AND `score` <= ?", + []interface{}{"1", 18, 100}, + }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Where{ + Exprs: []clause.Expression{clause.Not(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}), clause.Or(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false})}, + }}, + "SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `age` <= ?) OR `score` <= ?", + []interface{}{"1", 18, 100}, + }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Where{ + Exprs: []clause.Expression{ + clause.And(clause.Not(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}), + clause.And(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false})), + }, + }}, + "SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `score` <= ?)", + []interface{}{"1", 100}, + }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Where{ + Exprs: []clause.Expression{clause.Not(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, + clause.And(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false}))}, + }}, + "SELECT * FROM `users` WHERE (`users`.`id` <> ? AND NOT `score` <= ?)", + []interface{}{"1", 100}, + }, } for idx, result := range results { diff --git a/errors.go b/errors.go index 145614d9..49cbfe64 100644 --- a/errors.go +++ b/errors.go @@ -39,4 +39,6 @@ var ( ErrInvalidValue = errors.New("invalid value, should be pointer to struct or slice") // ErrInvalidValueOfLength invalid values do not match length ErrInvalidValueOfLength = errors.New("invalid association values, length doesn't match") + // ErrPreloadNotAllowed preload is not allowed when count is used + ErrPreloadNotAllowed = errors.New("preload is not allowed when count is used") ) diff --git a/finisher_api.go b/finisher_api.go index bbb1ba4d..89a7e296 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -74,6 +74,10 @@ func (db *DB) Save(value interface{}) (tx *DB) { tx.Statement.Dest = value reflectValue := reflect.Indirect(reflect.ValueOf(value)) + for reflectValue.Kind() == reflect.Ptr || reflectValue.Kind() == reflect.Interface { + reflectValue = reflect.Indirect(reflectValue) + } + switch reflectValue.Kind() { case reflect.Slice, reflect.Array: if _, ok := tx.Statement.Clauses["ON CONFLICT"]; !ok { @@ -83,7 +87,7 @@ func (db *DB) Save(value interface{}) (tx *DB) { case reflect.Struct: if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil { for _, pf := range tx.Statement.Schema.PrimaryFields { - if _, isZero := pf.ValueOf(reflectValue); isZero { + if _, isZero := pf.ValueOf(tx.Statement.Context, reflectValue); isZero { return tx.callbacks.Create().Execute(tx) } } @@ -101,7 +105,7 @@ func (db *DB) Save(value interface{}) (tx *DB) { if tx.Error == nil && tx.RowsAffected == 0 && !tx.DryRun && !selectedUpdate { result := reflect.New(tx.Statement.Schema.ModelType).Interface() - if err := tx.Session(&Session{}).Take(result).Error; errors.Is(err, ErrRecordNotFound) { + if result := tx.Session(&Session{}).Limit(1).Find(result); result.RowsAffected == 0 { return tx.Create(value) } } @@ -177,6 +181,21 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat batch int ) + // user specified offset or limit + var totalSize int + if c, ok := tx.Statement.Clauses["LIMIT"]; ok { + if limit, ok := c.Expression.(clause.Limit); ok { + totalSize = limit.Limit + + if totalSize > 0 && batchSize > totalSize { + batchSize = totalSize + } + + // reset to offset to 0 in next batch + tx = tx.Offset(-1).Session(&Session{}) + } + } + for { result := queryDB.Limit(batchSize).Find(dest) rowsAffected += result.RowsAffected @@ -192,6 +211,15 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat break } + if totalSize > 0 { + if totalSize <= int(rowsAffected) { + break + } + if totalSize/batchSize == batch { + batchSize = totalSize % batchSize + } + } + // Optimize for-break resultsValue := reflect.Indirect(reflect.ValueOf(dest)) if result.Statement.Schema.PrioritizedPrimaryField == nil { @@ -199,7 +227,7 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat break } - primaryValue, _ := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(resultsValue.Index(resultsValue.Len() - 1)) + primaryValue, _ := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1)) queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue}) } @@ -207,7 +235,7 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat return tx } -func (tx *DB) assignInterfacesToValue(values ...interface{}) { +func (db *DB) assignInterfacesToValue(values ...interface{}) { for _, value := range values { switch v := value.(type) { case []clause.Expression: @@ -215,40 +243,40 @@ func (tx *DB) assignInterfacesToValue(values ...interface{}) { if eq, ok := expr.(clause.Eq); ok { switch column := eq.Column.(type) { case string: - if field := tx.Statement.Schema.LookUpField(column); field != nil { - tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value)) + if field := db.Statement.Schema.LookUpField(column); field != nil { + db.AddError(field.Set(db.Statement.Context, db.Statement.ReflectValue, eq.Value)) } case clause.Column: - if field := tx.Statement.Schema.LookUpField(column.Name); field != nil { - tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value)) + if field := db.Statement.Schema.LookUpField(column.Name); field != nil { + db.AddError(field.Set(db.Statement.Context, db.Statement.ReflectValue, eq.Value)) } } } else if andCond, ok := expr.(clause.AndConditions); ok { - tx.assignInterfacesToValue(andCond.Exprs) + db.assignInterfacesToValue(andCond.Exprs) } } case clause.Expression, map[string]string, map[interface{}]interface{}, map[string]interface{}: - if exprs := tx.Statement.BuildCondition(value); len(exprs) > 0 { - tx.assignInterfacesToValue(exprs) + if exprs := db.Statement.BuildCondition(value); len(exprs) > 0 { + db.assignInterfacesToValue(exprs) } default: - if s, err := schema.Parse(value, tx.cacheStore, tx.NamingStrategy); err == nil { + if s, err := schema.Parse(value, db.cacheStore, db.NamingStrategy); err == nil { reflectValue := reflect.Indirect(reflect.ValueOf(value)) switch reflectValue.Kind() { case reflect.Struct: for _, f := range s.Fields { if f.Readable { - if v, isZero := f.ValueOf(reflectValue); !isZero { - if field := tx.Statement.Schema.LookUpField(f.Name); field != nil { - tx.AddError(field.Set(tx.Statement.ReflectValue, v)) + if v, isZero := f.ValueOf(db.Statement.Context, reflectValue); !isZero { + if field := db.Statement.Schema.LookUpField(f.Name); field != nil { + db.AddError(field.Set(db.Statement.Context, db.Statement.ReflectValue, v)) } } } } } } else if len(values) > 0 { - if exprs := tx.Statement.BuildCondition(values[0], values[1:]...); len(exprs) > 0 { - tx.assignInterfacesToValue(exprs) + if exprs := db.Statement.BuildCondition(values[0], values[1:]...); len(exprs) > 0 { + db.assignInterfacesToValue(exprs) } return } @@ -256,12 +284,13 @@ func (tx *DB) assignInterfacesToValue(values ...interface{}) { } } +// FirstOrInit gets the first matched record or initialize a new instance with given conditions (only works with struct or map conditions) func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { queryTx := db.Limit(1).Order(clause.OrderByColumn{ Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, }) - if tx = queryTx.Find(dest, conds...); queryTx.RowsAffected == 0 { + if tx = queryTx.Find(dest, conds...); tx.RowsAffected == 0 { if c, ok := tx.Statement.Clauses["WHERE"]; ok { if where, ok := c.Expression.(clause.Where); ok { tx.assignInterfacesToValue(where.Exprs) @@ -281,26 +310,28 @@ func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { return } +// FirstOrCreate gets the first matched record or create a new one with given conditions (only works with struct, map conditions) func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) { - queryTx := db.Limit(1).Order(clause.OrderByColumn{ + tx = db.getInstance() + queryTx := db.Session(&Session{}).Limit(1).Order(clause.OrderByColumn{ Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, }) - if tx = queryTx.Find(dest, conds...); tx.Error == nil { - if tx.RowsAffected == 0 { - if c, ok := tx.Statement.Clauses["WHERE"]; ok { + if result := queryTx.Find(dest, conds...); result.Error == nil { + if result.RowsAffected == 0 { + if c, ok := result.Statement.Clauses["WHERE"]; ok { if where, ok := c.Expression.(clause.Where); ok { - tx.assignInterfacesToValue(where.Exprs) + result.assignInterfacesToValue(where.Exprs) } } // initialize with attrs, conds - if len(tx.Statement.attrs) > 0 { - tx.assignInterfacesToValue(tx.Statement.attrs...) + if len(db.Statement.attrs) > 0 { + result.assignInterfacesToValue(db.Statement.attrs...) } // initialize with attrs, conds - if len(tx.Statement.assigns) > 0 { - tx.assignInterfacesToValue(tx.Statement.assigns...) + if len(db.Statement.assigns) > 0 { + result.assignInterfacesToValue(db.Statement.assigns...) } return tx.Create(dest) @@ -320,6 +351,8 @@ func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) { } return tx.Model(dest).Updates(assigns) + } else { + tx.Error = result.Error } } return tx @@ -585,7 +618,7 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er func (db *DB) Begin(opts ...*sql.TxOptions) *DB { var ( // clone statement - tx = db.getInstance().Session(&Session{Context: db.Statement.Context}) + tx = db.getInstance().Session(&Session{Context: db.Statement.Context, NewDB: db.clone == 1}) opt *sql.TxOptions err error ) @@ -594,11 +627,12 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB { opt = opts[0] } - if beginner, ok := tx.Statement.ConnPool.(TxBeginner); ok { + switch beginner := tx.Statement.ConnPool.(type) { + case TxBeginner: tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) - } else if beginner, ok := tx.Statement.ConnPool.(ConnPoolBeginner); ok { + case ConnPoolBeginner: tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) - } else { + default: err = ErrInvalidTransaction } diff --git a/gorm.go b/gorm.go index fc70f684..6a6bb032 100644 --- a/gorm.go +++ b/gorm.go @@ -59,6 +59,7 @@ type Config struct { cacheStore *sync.Map } +// Apply update config to new config func (c *Config) Apply(config *Config) error { if config != c { *config = *c @@ -66,6 +67,7 @@ func (c *Config) Apply(config *Config) error { return nil } +// AfterInitialize initialize plugins after db connected func (c *Config) AfterInitialize(db *DB) error { if db != nil { for _, plugin := range c.Plugins { @@ -77,6 +79,7 @@ func (c *Config) AfterInitialize(db *DB) error { return nil } +// Option gorm option interface type Option interface { Apply(*Config) error AfterInitialize(*DB) error @@ -96,6 +99,7 @@ type Session struct { DryRun bool PrepareStmt bool NewDB bool + Initialized bool SkipHooks bool SkipDefaultTransaction bool DisableNestedTransaction bool @@ -120,8 +124,8 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) { for _, opt := range opts { if opt != nil { - if err := opt.Apply(config); err != nil { - return nil, err + if applyErr := opt.Apply(config); applyErr != nil { + return nil, applyErr } defer func(opt Option) { if errr := opt.AfterInitialize(db); errr != nil { @@ -282,6 +286,10 @@ func (db *DB) Session(config *Session) *DB { tx.Config.NowFunc = config.NowFunc } + if config.Initialized { + tx = tx.getInstance() + } + return tx } @@ -376,10 +384,12 @@ func (db *DB) getInstance() *DB { return db } +// Expr returns clause.Expr, which can be used to pass SQL expression as params func Expr(expr string, args ...interface{}) clause.Expr { return clause.Expr{SQL: expr, Vars: args} } +// SetupJoinTable setup join table schema func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interface{}) error { var ( tx = db.getInstance() @@ -430,6 +440,7 @@ func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interfac return nil } +// Use use plugin func (db *DB) Use(plugin Plugin) error { name := plugin.Name() if _, ok := db.Plugins[name]; ok { @@ -451,7 +462,7 @@ func (db *DB) Use(plugin Plugin) error { // .First(&User{}) // }) func (db *DB) ToSQL(queryFn func(tx *DB) *DB) string { - tx := queryFn(db.Session(&Session{DryRun: true})) + tx := queryFn(db.Session(&Session{DryRun: true, SkipDefaultTransaction: true})) stmt := tx.Statement return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...) diff --git a/interfaces.go b/interfaces.go index 44b2fced..32d49605 100644 --- a/interfaces.go +++ b/interfaces.go @@ -40,24 +40,45 @@ type SavePointerDialectorInterface interface { RollbackTo(tx *DB, name string) error } +// TxBeginner tx beginner type TxBeginner interface { BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) } +// ConnPoolBeginner conn pool beginner type ConnPoolBeginner interface { BeginTx(ctx context.Context, opts *sql.TxOptions) (ConnPool, error) } +// TxCommitter tx committer type TxCommitter interface { Commit() error Rollback() error } +// Tx sql.Tx interface +type Tx interface { + ConnPool + TxCommitter + StmtContext(ctx context.Context, stmt *sql.Stmt) *sql.Stmt +} + // Valuer gorm valuer interface type Valuer interface { GormValue(context.Context, *DB) clause.Expr } +// GetDBConnector SQL db connector type GetDBConnector interface { GetDBConn() (*sql.DB, error) } + +// Rows rows interface +type Rows interface { + Columns() ([]string, error) + ColumnTypes() ([]*sql.ColumnType, error) + Next() bool + Scan(dest ...interface{}) error + Err() error + Close() error +} diff --git a/logger/logger.go b/logger/logger.go index 0c4ca4a0..2ffd28d5 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -12,6 +12,7 @@ import ( "gorm.io/gorm/utils" ) +// ErrRecordNotFound record not found error var ErrRecordNotFound = errors.New("record not found") // Colors @@ -30,13 +31,17 @@ const ( YellowBold = "\033[33;1m" ) -// LogLevel +// LogLevel log level type LogLevel int const ( + // Silent silent log level Silent LogLevel = iota + 1 + // Error error log level Error + // Warn warn log level Warn + // Info info log level Info ) @@ -45,6 +50,7 @@ type Writer interface { Printf(string, ...interface{}) } +// Config logger config type Config struct { SlowThreshold time.Duration Colorful bool @@ -62,16 +68,20 @@ type Interface interface { } var ( + // Discard Discard logger will print any log to ioutil.Discard Discard = New(log.New(ioutil.Discard, "", log.LstdFlags), Config{}) + // Default Default logger Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{ SlowThreshold: 200 * time.Millisecond, LogLevel: Warn, IgnoreRecordNotFoundError: false, Colorful: true, }) + // Recorder Recorder logger records running SQL into a recorder instance Recorder = traceRecorder{Interface: Default, BeginAt: time.Now()} ) +// New initialize logger func New(writer Writer, config Config) Interface { var ( infoStr = "%s\n[info] " @@ -179,10 +189,12 @@ type traceRecorder struct { Err error } +// New new trace recorder func (l traceRecorder) New() *traceRecorder { return &traceRecorder{Interface: l.Interface, BeginAt: time.Now()} } +// Trace implement logger interface func (l *traceRecorder) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) { l.BeginAt = begin l.SQL, l.RowsAffected = fc() diff --git a/logger/sql.go b/logger/sql.go index 5ecb0ae2..c8b194c3 100644 --- a/logger/sql.go +++ b/logger/sql.go @@ -19,9 +19,9 @@ const ( nullStr = "NULL" ) -func isPrintable(s []byte) bool { +func isPrintable(s string) bool { for _, r := range s { - if !unicode.IsPrint(rune(r)) { + if !unicode.IsPrint(r) { return false } } @@ -30,9 +30,12 @@ func isPrintable(s []byte) bool { var convertibleTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeOf(false), reflect.TypeOf([]byte{})} +// ExplainSQL generate SQL string with given parameters, the generated SQL is expected to be used in logger, execute it might introduce a SQL injection vulnerability func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, avars ...interface{}) string { - var convertParams func(interface{}, int) - vars := make([]string, len(avars)) + var ( + convertParams func(interface{}, int) + vars = make([]string, len(avars)) + ) convertParams = func(v interface{}, idx int) { switch v := v.(type) { @@ -64,14 +67,25 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a } case fmt.Stringer: reflectValue := reflect.ValueOf(v) - if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) { - vars[idx] = escaper + strings.Replace(fmt.Sprintf("%v", v), escaper, "\\"+escaper, -1) + escaper - } else { - vars[idx] = nullStr + switch reflectValue.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + vars[idx] = fmt.Sprintf("%d", reflectValue.Interface()) + case reflect.Float32, reflect.Float64: + vars[idx] = fmt.Sprintf("%.6f", reflectValue.Interface()) + case reflect.Bool: + vars[idx] = fmt.Sprintf("%t", reflectValue.Interface()) + case reflect.String: + vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, "\\"+escaper) + escaper + default: + if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) { + vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, "\\"+escaper) + escaper + } else { + vars[idx] = nullStr + } } case []byte: - if isPrintable(v) { - vars[idx] = escaper + strings.Replace(string(v), escaper, "\\"+escaper, -1) + escaper + if s := string(v); isPrintable(s) { + vars[idx] = escaper + strings.ReplaceAll(s, escaper, "\\"+escaper) + escaper } else { vars[idx] = escaper + "" + escaper } @@ -80,7 +94,7 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a case float64, float32: vars[idx] = fmt.Sprintf("%.6f", v) case string: - vars[idx] = escaper + strings.Replace(v, escaper, "\\"+escaper, -1) + escaper + vars[idx] = escaper + strings.ReplaceAll(v, escaper, "\\"+escaper) + escaper default: rv := reflect.ValueOf(v) if v == nil || !rv.IsValid() || rv.Kind() == reflect.Ptr && rv.IsNil() { @@ -97,7 +111,7 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a return } } - vars[idx] = escaper + strings.Replace(fmt.Sprint(v), escaper, "\\"+escaper, -1) + escaper + vars[idx] = escaper + strings.ReplaceAll(fmt.Sprint(v), escaper, "\\"+escaper) + escaper } } } diff --git a/logger/sql_test.go b/logger/sql_test.go index 71aa841a..c5b181a9 100644 --- a/logger/sql_test.go +++ b/logger/sql_test.go @@ -31,7 +31,7 @@ func (s ExampleStruct) Value() (driver.Value, error) { } func format(v []byte, escaper string) string { - return escaper + strings.Replace(string(v), escaper, "\\"+escaper, -1) + escaper + return escaper + strings.ReplaceAll(string(v), escaper, "\\"+escaper) + escaper } func TestExplainSQL(t *testing.T) { diff --git a/migrator.go b/migrator.go index 2a8b4254..52443877 100644 --- a/migrator.go +++ b/migrator.go @@ -1,6 +1,8 @@ package gorm import ( + "reflect" + "gorm.io/gorm/clause" "gorm.io/gorm/schema" ) @@ -33,14 +35,23 @@ type ViewOption struct { Query *DB } +// ColumnType column type interface type ColumnType interface { Name() string - DatabaseTypeName() string + DatabaseTypeName() string // varchar + ColumnType() (columnType string, ok bool) // varchar(64) + PrimaryKey() (isPrimaryKey bool, ok bool) + AutoIncrement() (isAutoIncrement bool, ok bool) Length() (length int64, ok bool) DecimalSize() (precision int64, scale int64, ok bool) Nullable() (nullable bool, ok bool) + Unique() (unique bool, ok bool) + ScanType() reflect.Type + Comment() (value string, ok bool) + DefaultValue() (value string, ok bool) } +// Migrator migrator interface type Migrator interface { // AutoMigrate AutoMigrate(dst ...interface{}) error diff --git a/migrator/column_type.go b/migrator/column_type.go new file mode 100644 index 00000000..c6fdd6b2 --- /dev/null +++ b/migrator/column_type.go @@ -0,0 +1,107 @@ +package migrator + +import ( + "database/sql" + "reflect" +) + +// ColumnType column type implements ColumnType interface +type ColumnType struct { + SQLColumnType *sql.ColumnType + NameValue sql.NullString + DataTypeValue sql.NullString + ColumnTypeValue sql.NullString + PrimaryKeyValue sql.NullBool + UniqueValue sql.NullBool + AutoIncrementValue sql.NullBool + LengthValue sql.NullInt64 + DecimalSizeValue sql.NullInt64 + ScaleValue sql.NullInt64 + NullableValue sql.NullBool + ScanTypeValue reflect.Type + CommentValue sql.NullString + DefaultValueValue sql.NullString +} + +// Name returns the name or alias of the column. +func (ct ColumnType) Name() string { + if ct.NameValue.Valid { + return ct.NameValue.String + } + return ct.SQLColumnType.Name() +} + +// DatabaseTypeName returns the database system name of the column type. If an empty +// string is returned, then the driver type name is not supported. +// Consult your driver documentation for a list of driver data types. Length specifiers +// are not included. +// Common type names include "VARCHAR", "TEXT", "NVARCHAR", "DECIMAL", "BOOL", +// "INT", and "BIGINT". +func (ct ColumnType) DatabaseTypeName() string { + if ct.DataTypeValue.Valid { + return ct.DataTypeValue.String + } + return ct.SQLColumnType.DatabaseTypeName() +} + +// ColumnType returns the database type of the column. like `varchar(16)` +func (ct ColumnType) ColumnType() (columnType string, ok bool) { + return ct.ColumnTypeValue.String, ct.ColumnTypeValue.Valid +} + +// PrimaryKey returns the column is primary key or not. +func (ct ColumnType) PrimaryKey() (isPrimaryKey bool, ok bool) { + return ct.PrimaryKeyValue.Bool, ct.PrimaryKeyValue.Valid +} + +// AutoIncrement returns the column is auto increment or not. +func (ct ColumnType) AutoIncrement() (isAutoIncrement bool, ok bool) { + return ct.AutoIncrementValue.Bool, ct.AutoIncrementValue.Valid +} + +// Length returns the column type length for variable length column types +func (ct ColumnType) Length() (length int64, ok bool) { + if ct.LengthValue.Valid { + return ct.LengthValue.Int64, true + } + return ct.SQLColumnType.Length() +} + +// DecimalSize returns the scale and precision of a decimal type. +func (ct ColumnType) DecimalSize() (precision int64, scale int64, ok bool) { + if ct.DecimalSizeValue.Valid { + return ct.DecimalSizeValue.Int64, ct.ScaleValue.Int64, true + } + return ct.SQLColumnType.DecimalSize() +} + +// Nullable reports whether the column may be null. +func (ct ColumnType) Nullable() (nullable bool, ok bool) { + if ct.NullableValue.Valid { + return ct.NullableValue.Bool, true + } + return ct.SQLColumnType.Nullable() +} + +// Unique reports whether the column may be unique. +func (ct ColumnType) Unique() (unique bool, ok bool) { + return ct.UniqueValue.Bool, ct.UniqueValue.Valid +} + +// ScanType returns a Go type suitable for scanning into using Rows.Scan. +func (ct ColumnType) ScanType() reflect.Type { + if ct.ScanTypeValue != nil { + return ct.ScanTypeValue + } + return ct.SQLColumnType.ScanType() +} + +// Comment returns the comment of current column. +func (ct ColumnType) Comment() (value string, ok bool) { + return ct.CommentValue.String, ct.CommentValue.Valid +} + +// DefaultValue returns the default value of current column. +func (ct ColumnType) DefaultValue() (value string, ok bool) { + return ct.DefaultValueValue.String, ct.DefaultValueValue.Valid +} diff --git a/migrator/migrator.go b/migrator/migrator.go index 138917fb..d4989410 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -30,10 +30,12 @@ type Config struct { gorm.Dialector } +// GormDataTypeInterface gorm data type interface type GormDataTypeInterface interface { GormDBDataType(*gorm.DB, *schema.Field) string } +// RunWithValue run migration with statement value func (m Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error) error { stmt := &gorm.Statement{DB: m.DB} if m.DB.Statement != nil { @@ -50,6 +52,7 @@ func (m Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error return fc(stmt) } +// DataTypeOf return field's db data type func (m Migrator) DataTypeOf(field *schema.Field) string { fieldValue := reflect.New(field.IndirectFieldType) if dataTyper, ok := fieldValue.Interface().(GormDataTypeInterface); ok { @@ -61,6 +64,7 @@ func (m Migrator) DataTypeOf(field *schema.Field) string { return m.Dialector.DataTypeOf(field) } +// FullDataTypeOf returns field's db full data type func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) { expr.SQL = m.DataTypeOf(field) @@ -85,7 +89,7 @@ func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) { return } -// AutoMigrate +// AutoMigrate auto migrate values func (m Migrator) AutoMigrate(values ...interface{}) error { for _, value := range m.ReorderModels(values, true) { tx := m.DB.Session(&gorm.Session{}) @@ -95,7 +99,10 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { } } else { if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) { - columnTypes, _ := m.DB.Migrator().ColumnTypes(value) + columnTypes, err := m.DB.Migrator().ColumnTypes(value) + if err != nil { + return err + } for _, dbName := range stmt.Schema.DBNames { field := stmt.Schema.FieldsByDBName[dbName] @@ -156,12 +163,14 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { return nil } +// GetTables returns tables func (m Migrator) GetTables() (tableList []string, err error) { err = m.DB.Raw("SELECT TABLE_NAME FROM information_schema.tables where TABLE_SCHEMA=?", m.CurrentDatabase()). Scan(&tableList).Error return } +// CreateTable create table in database for values func (m Migrator) CreateTable(values ...interface{}) error { for _, value := range m.ReorderModels(values, false) { tx := m.DB.Session(&gorm.Session{}) @@ -252,6 +261,7 @@ func (m Migrator) CreateTable(values ...interface{}) error { return nil } +// DropTable drop table for values func (m Migrator) DropTable(values ...interface{}) error { values = m.ReorderModels(values, false) for i := len(values) - 1; i >= 0; i-- { @@ -265,6 +275,7 @@ func (m Migrator) DropTable(values ...interface{}) error { return nil } +// HasTable returns table exists or not for value, value could be a struct or string func (m Migrator) HasTable(value interface{}) bool { var count int64 @@ -276,6 +287,7 @@ func (m Migrator) HasTable(value interface{}) bool { return count > 0 } +// RenameTable rename table from oldName to newName func (m Migrator) RenameTable(oldName, newName interface{}) error { var oldTable, newTable interface{} if v, ok := oldName.(string); ok { @@ -303,12 +315,13 @@ func (m Migrator) RenameTable(oldName, newName interface{}) error { return m.DB.Exec("ALTER TABLE ? RENAME TO ?", oldTable, newTable).Error } -func (m Migrator) AddColumn(value interface{}, field string) error { +// AddColumn create `name` column for value +func (m Migrator) AddColumn(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { // avoid using the same name field - f := stmt.Schema.LookUpField(field) + f := stmt.Schema.LookUpField(name) if f == nil { - return fmt.Errorf("failed to look up field with name: %s", field) + return fmt.Errorf("failed to look up field with name: %s", name) } if !f.IgnoreMigration { @@ -322,6 +335,7 @@ func (m Migrator) AddColumn(value interface{}, field string) error { }) } +// DropColumn drop value's `name` column func (m Migrator) DropColumn(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { if field := stmt.Schema.LookUpField(name); field != nil { @@ -334,10 +348,11 @@ func (m Migrator) DropColumn(value interface{}, name string) error { }) } +// AlterColumn alter value's `field` column' type based on schema definition func (m Migrator) AlterColumn(value interface{}, field string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { if field := stmt.Schema.LookUpField(field); field != nil { - fileType := clause.Expr{SQL: m.DataTypeOf(field)} + fileType := m.FullDataTypeOf(field) return m.DB.Exec( "ALTER TABLE ? ALTER COLUMN ? TYPE ?", m.CurrentTable(stmt), clause.Column{Name: field.DBName}, fileType, @@ -348,6 +363,7 @@ func (m Migrator) AlterColumn(value interface{}, field string) error { }) } +// HasColumn check has column `field` for value or not func (m Migrator) HasColumn(value interface{}, field string) bool { var count int64 m.RunWithValue(value, func(stmt *gorm.Statement) error { @@ -366,6 +382,7 @@ func (m Migrator) HasColumn(value interface{}, field string) bool { return count > 0 } +// RenameColumn rename value's field name from oldName to newName func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { if field := stmt.Schema.LookUpField(oldName); field != nil { @@ -383,6 +400,7 @@ func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error }) } +// MigrateColumn migrate column func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnType gorm.ColumnType) error { // found, smart migrate fullDataType := strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL) @@ -421,6 +439,30 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy } } + // check unique + if unique, ok := columnType.Unique(); ok && unique != field.Unique { + // not primary key + if !field.PrimaryKey { + alterColumn = true + } + } + + // check default value + if v, ok := columnType.DefaultValue(); ok && v != field.DefaultValue { + // not primary key + if !field.PrimaryKey { + alterColumn = true + } + } + + // check comment + if comment, ok := columnType.Comment(); ok && comment != field.Comment { + // not primary key + if !field.PrimaryKey { + alterColumn = true + } + } + if alterColumn && !field.IgnoreMigration { return m.DB.Migrator().AlterColumn(value, field.Name) } @@ -448,7 +490,7 @@ func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) { } for _, c := range rawColumnTypes { - columnTypes = append(columnTypes, c) + columnTypes = append(columnTypes, ColumnType{SQLColumnType: c}) } return @@ -457,10 +499,12 @@ func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) { return columnTypes, execErr } +// CreateView create view func (m Migrator) CreateView(name string, option gorm.ViewOption) error { return gorm.ErrNotImplemented } +// DropView drop view func (m Migrator) DropView(name string) error { return gorm.ErrNotImplemented } @@ -487,6 +531,7 @@ func buildConstraint(constraint *schema.Constraint) (sql string, results []inter return } +// GuessConstraintAndTable guess statement's constraint and it's table based on name func (m Migrator) GuessConstraintAndTable(stmt *gorm.Statement, name string) (_ *schema.Constraint, _ *schema.Check, table string) { if stmt.Schema == nil { return nil, nil, stmt.Table @@ -531,6 +576,7 @@ func (m Migrator) GuessConstraintAndTable(stmt *gorm.Statement, name string) (_ return nil, nil, stmt.Schema.Table } +// CreateConstraint create constraint func (m Migrator) CreateConstraint(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { constraint, chk, table := m.GuessConstraintAndTable(stmt, name) @@ -554,6 +600,7 @@ func (m Migrator) CreateConstraint(value interface{}, name string) error { }) } +// DropConstraint drop constraint func (m Migrator) DropConstraint(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { constraint, chk, table := m.GuessConstraintAndTable(stmt, name) @@ -566,6 +613,7 @@ func (m Migrator) DropConstraint(value interface{}, name string) error { }) } +// HasConstraint check has constraint or not func (m Migrator) HasConstraint(value interface{}, name string) bool { var count int64 m.RunWithValue(value, func(stmt *gorm.Statement) error { @@ -586,6 +634,7 @@ func (m Migrator) HasConstraint(value interface{}, name string) bool { return count > 0 } +// BuildIndexOptions build index options func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) { for _, opt := range opts { str := stmt.Quote(opt.DBName) @@ -607,10 +656,12 @@ func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statem return } +// BuildIndexOptionsInterface build index options interface type BuildIndexOptionsInterface interface { BuildIndexOptions([]schema.IndexOption, *gorm.Statement) []interface{} } +// CreateIndex create index `name` func (m Migrator) CreateIndex(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { if idx := stmt.Schema.LookIndex(name); idx != nil { @@ -642,6 +693,7 @@ func (m Migrator) CreateIndex(value interface{}, name string) error { }) } +// DropIndex drop index `name` func (m Migrator) DropIndex(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { if idx := stmt.Schema.LookIndex(name); idx != nil { @@ -652,6 +704,7 @@ func (m Migrator) DropIndex(value interface{}, name string) error { }) } +// HasIndex check has index `name` or not func (m Migrator) HasIndex(value interface{}, name string) bool { var count int64 m.RunWithValue(value, func(stmt *gorm.Statement) error { @@ -669,6 +722,7 @@ func (m Migrator) HasIndex(value interface{}, name string) bool { return count > 0 } +// RenameIndex rename index from oldName to newName func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { return m.DB.Exec( @@ -678,6 +732,7 @@ func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error }) } +// CurrentDatabase returns current database name func (m Migrator) CurrentDatabase() (name string) { m.DB.Raw("SELECT DATABASE()").Row().Scan(&name) return @@ -704,7 +759,8 @@ func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []i Statement: &gorm.Statement{DB: m.DB, Dest: value}, } beDependedOn := map[*schema.Schema]bool{} - if err := dep.Parse(value); err != nil { + // support for special table name + if err := dep.ParseWithSpecialTableName(value, m.DB.Statement.Table); err != nil { m.DB.Logger.Error(context.Background(), "failed to parse value %#v, got error %v", value, err) } if _, ok := parsedSchemas[dep.Statement.Schema]; ok { @@ -781,6 +837,7 @@ func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []i return } +// CurrentTable returns current statement's table expression func (m Migrator) CurrentTable(stmt *gorm.Statement) interface{} { if stmt.TableExpr != nil { return *stmt.TableExpr diff --git a/prepare_stmt.go b/prepare_stmt.go index 88bec4e9..b062b0d6 100644 --- a/prepare_stmt.go +++ b/prepare_stmt.go @@ -115,7 +115,7 @@ func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string, arg } type PreparedStmtTX struct { - *sql.Tx + Tx PreparedStmtDB *PreparedStmtDB } @@ -151,7 +151,7 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args .. func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) { stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query) if err == nil { - rows, err = tx.Tx.Stmt(stmt.Stmt).QueryContext(ctx, args...) + rows, err = tx.Tx.StmtContext(ctx, stmt.Stmt).QueryContext(ctx, args...) if err != nil { tx.PreparedStmtDB.Mux.Lock() defer tx.PreparedStmtDB.Mux.Unlock() diff --git a/scan.go b/scan.go index b03b79b4..ad3734d8 100644 --- a/scan.go +++ b/scan.go @@ -10,6 +10,7 @@ import ( "gorm.io/gorm/schema" ) +// prepareValues prepare values slice func prepareValues(values []interface{}, db *DB, columnTypes []*sql.ColumnType, columns []string) { if db.Statement.Schema != nil { for idx, name := range columns { @@ -49,65 +50,56 @@ func scanIntoMap(mapValue map[string]interface{}, values []interface{}, columns } } -func (db *DB) scanIntoStruct(sch *schema.Schema, rows *sql.Rows, reflectValue reflect.Value, values []interface{}, columns []string, fields []*schema.Field, joinFields [][2]*schema.Field) { - for idx, column := range columns { - if sch == nil { - values[idx] = reflectValue.Interface() - } else if field := sch.LookUpField(column); field != nil && field.Readable { - values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface() - } else if names := strings.Split(column, "__"); len(names) > 1 { - if rel, ok := sch.Relationships.Relations[names[0]]; ok { - if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { - values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface() - continue - } +func (db *DB) scanIntoStruct(rows Rows, reflectValue reflect.Value, values []interface{}, fields []*schema.Field, joinFields [][2]*schema.Field) { + for idx, field := range fields { + if field != nil { + values[idx] = field.NewValuePool.Get() + } else if len(fields) == 1 { + if reflectValue.CanAddr() { + values[idx] = reflectValue.Addr().Interface() + } else { + values[idx] = reflectValue.Interface() } - values[idx] = &sql.RawBytes{} - } else if len(columns) == 1 { - sch = nil - values[idx] = reflectValue.Interface() - } else { - values[idx] = &sql.RawBytes{} } } db.RowsAffected++ db.AddError(rows.Scan(values...)) - if sch != nil { - for idx, column := range columns { - if field := sch.LookUpField(column); field != nil && field.Readable { - field.Set(reflectValue, values[idx]) - } else if names := strings.Split(column, "__"); len(names) > 1 { - if rel, ok := sch.Relationships.Relations[names[0]]; ok { - if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { - relValue := rel.Field.ReflectValueOf(reflectValue) - value := reflect.ValueOf(values[idx]).Elem() - - if relValue.Kind() == reflect.Ptr && relValue.IsNil() { - if value.IsNil() { - continue - } - relValue.Set(reflect.New(relValue.Type().Elem())) - } - - field.Set(relValue, values[idx]) + for idx, field := range fields { + if field != nil { + if len(joinFields) == 0 || joinFields[idx][0] == nil { + db.AddError(field.Set(db.Statement.Context, reflectValue, values[idx])) + } else { + relValue := joinFields[idx][0].ReflectValueOf(db.Statement.Context, reflectValue) + if relValue.Kind() == reflect.Ptr && relValue.IsNil() { + if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() { + continue } + + relValue.Set(reflect.New(relValue.Type().Elem())) } + db.AddError(joinFields[idx][1].Set(db.Statement.Context, relValue, values[idx])) } + + // release data to pool + field.NewValuePool.Put(values[idx]) } } } +// ScanMode scan data mode type ScanMode uint8 +// scan modes const ( ScanInitialized ScanMode = 1 << 0 // 1 ScanUpdate ScanMode = 1 << 1 // 2 ScanOnConflictDoNothing ScanMode = 1 << 2 // 4 ) -func Scan(rows *sql.Rows, db *DB, mode ScanMode) { +// Scan scan rows into db statement +func Scan(rows Rows, db *DB, mode ScanMode) { var ( columns, _ = rows.Columns() values = make([]interface{}, len(columns)) @@ -138,7 +130,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) { } scanIntoMap(mapValue, values, columns) } - case *[]map[string]interface{}, []map[string]interface{}: + case *[]map[string]interface{}: columnTypes, _ := rows.ColumnTypes() for initialized || rows.Next() { prepareValues(values, db, columnTypes, columns) @@ -149,11 +141,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) { mapValue := map[string]interface{}{} scanIntoMap(mapValue, values, columns) - if values, ok := dest.([]map[string]interface{}); ok { - values = append(values, mapValue) - } else if values, ok := dest.(*[]map[string]interface{}); ok { - *values = append(*values, mapValue) - } + *dest = append(*dest, mapValue) } case *int, *int8, *int16, *int32, *int64, *uint, *uint8, *uint16, *uint32, *uint64, *uintptr, @@ -168,10 +156,11 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) { } default: var ( - fields = make([]*schema.Field, len(columns)) - joinFields [][2]*schema.Field - sch = db.Statement.Schema - reflectValue = db.Statement.ReflectValue + fields = make([]*schema.Field, len(columns)) + selectedColumnsMap = make(map[string]int, len(columns)) + joinFields [][2]*schema.Field + sch = db.Statement.Schema + reflectValue = db.Statement.ReflectValue ) if reflectValue.Kind() == reflect.Interface { @@ -193,35 +182,49 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) { sch, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy) } - for idx, column := range columns { - if field := sch.LookUpField(column); field != nil && field.Readable { - fields[idx] = field - } else if names := strings.Split(column, "__"); len(names) > 1 { - if rel, ok := sch.Relationships.Relations[names[0]]; ok { - if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { - fields[idx] = field - - if len(joinFields) == 0 { - joinFields = make([][2]*schema.Field, len(columns)) - } - joinFields[idx] = [2]*schema.Field{rel.Field, field} - continue - } - } - values[idx] = &sql.RawBytes{} - } else { - values[idx] = &sql.RawBytes{} - } - } - if len(columns) == 1 { - // isPluck + // Is Pluck if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); (reflectValueType != sch.ModelType && ok) || // is scanner reflectValueType.Kind() != reflect.Struct || // is not struct sch.ModelType.ConvertibleTo(schema.TimeReflectType) { // is time sch = nil } } + + // Not Pluck + if sch != nil { + for idx, column := range columns { + if field := sch.LookUpField(column); field != nil && field.Readable { + if curIndex, ok := selectedColumnsMap[column]; ok { + for fieldIndex, selectField := range sch.Fields[curIndex+1:] { + if selectField.DBName == column && selectField.Readable { + selectedColumnsMap[column] = curIndex + fieldIndex + 1 + fields[idx] = selectField + break + } + } + } else { + fields[idx] = field + selectedColumnsMap[column] = idx + } + } else if names := strings.Split(column, "__"); len(names) > 1 { + if rel, ok := sch.Relationships.Relations[names[0]]; ok { + if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { + fields[idx] = field + + if len(joinFields) == 0 { + joinFields = make([][2]*schema.Field, len(columns)) + } + joinFields[idx] = [2]*schema.Field{rel.Field, field} + continue + } + } + values[idx] = &sql.RawBytes{} + } else { + values[idx] = &sql.RawBytes{} + } + } + } } switch reflectValue.Kind() { @@ -244,7 +247,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) { elem = reflectValue.Index(int(db.RowsAffected)) if onConflictDonothing { for _, field := range fields { - if _, ok := field.ValueOf(elem); !ok { + if _, ok := field.ValueOf(db.Statement.Context, elem); !ok { db.RowsAffected++ goto BEGIN } @@ -254,7 +257,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) { elem = reflect.New(reflectValueType) } - db.scanIntoStruct(sch, rows, elem, values, columns, fields, joinFields) + db.scanIntoStruct(rows, elem, values, fields, joinFields) if !update { if isPtr { @@ -270,7 +273,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) { } case reflect.Struct, reflect.Ptr: if initialized || rows.Next() { - db.scanIntoStruct(sch, rows, reflectValue, values, columns, fields, joinFields) + db.scanIntoStruct(rows, reflectValue, values, fields, joinFields) } default: db.AddError(rows.Scan(dest)) diff --git a/schema/field.go b/schema/field.go index 485bbdf3..d6df6596 100644 --- a/schema/field.go +++ b/schema/field.go @@ -1,6 +1,7 @@ package schema import ( + "context" "database/sql" "database/sql/driver" "fmt" @@ -11,15 +12,25 @@ import ( "time" "github.com/jinzhu/now" + "gorm.io/gorm/clause" "gorm.io/gorm/utils" ) -type DataType string +// special types' reflect type +var ( + TimeReflectType = reflect.TypeOf(time.Time{}) + TimePtrReflectType = reflect.TypeOf(&time.Time{}) + ByteReflectType = reflect.TypeOf(uint8(0)) +) -type TimeType int64 - -var TimeReflectType = reflect.TypeOf(time.Time{}) +type ( + // DataType GORM data type + DataType string + // TimeType GORM time type + TimeType int64 +) +// GORM time types const ( UnixTime TimeType = 1 UnixSecond TimeType = 2 @@ -27,6 +38,7 @@ const ( UnixNanosecond TimeType = 4 ) +// GORM fields types const ( Bool DataType = "bool" Int DataType = "int" @@ -37,6 +49,7 @@ const ( Bytes DataType = "bytes" ) +// Field is the representation of model schema's field type Field struct { Name string DBName string @@ -49,9 +62,9 @@ type Field struct { Creatable bool Updatable bool Readable bool - HasDefaultValue bool AutoCreateTime TimeType AutoUpdateTime TimeType + HasDefaultValue bool DefaultValue string DefaultValueInterface interface{} NotNull bool @@ -60,6 +73,7 @@ type Field struct { Size int Precision int Scale int + IgnoreMigration bool FieldType reflect.Type IndirectFieldType reflect.Type StructField reflect.StructField @@ -68,27 +82,39 @@ type Field struct { Schema *Schema EmbeddedSchema *Schema OwnerSchema *Schema - ReflectValueOf func(reflect.Value) reflect.Value - ValueOf func(reflect.Value) (value interface{}, zero bool) - Set func(reflect.Value, interface{}) error - IgnoreMigration bool + ReflectValueOf func(context.Context, reflect.Value) reflect.Value + ValueOf func(context.Context, reflect.Value) (value interface{}, zero bool) + Set func(context.Context, reflect.Value, interface{}) error + Serializer SerializerInterface + NewValuePool FieldNewValuePool } +// ParseField parses reflect.StructField to Field func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { - var err error + var ( + err error + tagSetting = ParseTagSetting(fieldStruct.Tag.Get("gorm"), ";") + ) field := &Field{ Name: fieldStruct.Name, + DBName: tagSetting["COLUMN"], BindNames: []string{fieldStruct.Name}, FieldType: fieldStruct.Type, IndirectFieldType: fieldStruct.Type, StructField: fieldStruct, + Tag: fieldStruct.Tag, + TagSettings: tagSetting, + Schema: schema, Creatable: true, Updatable: true, Readable: true, - Tag: fieldStruct.Tag, - TagSettings: ParseTagSetting(fieldStruct.Tag.Get("gorm"), ";"), - Schema: schema, + PrimaryKey: utils.CheckTruth(tagSetting["PRIMARYKEY"], tagSetting["PRIMARY_KEY"]), + AutoIncrement: utils.CheckTruth(tagSetting["AUTOINCREMENT"]), + HasDefaultValue: utils.CheckTruth(tagSetting["AUTOINCREMENT"]), + NotNull: utils.CheckTruth(tagSetting["NOT NULL"], tagSetting["NOTNULL"]), + Unique: utils.CheckTruth(tagSetting["UNIQUE"]), + Comment: tagSetting["COMMENT"], AutoIncrementIncrement: 1, } @@ -97,7 +123,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } fieldValue := reflect.New(field.IndirectFieldType) - // if field is valuer, used its value or first fields as data type + // if field is valuer, used its value or first field as data type valuer, isValuer := fieldValue.Interface().(driver.Valuer) if isValuer { if _, ok := fieldValue.Interface().(GormDataTypeInterface); !ok { @@ -105,31 +131,37 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { fieldValue = reflect.ValueOf(v) } + // Use the field struct's first field type as data type, e.g: use `string` for sql.NullString var getRealFieldValue func(reflect.Value) getRealFieldValue = func(v reflect.Value) { - rv := reflect.Indirect(v) - if rv.Kind() == reflect.Struct && !rv.Type().ConvertibleTo(TimeReflectType) { - for i := 0; i < rv.Type().NumField(); i++ { - newFieldType := rv.Type().Field(i).Type + var ( + rv = reflect.Indirect(v) + rvType = rv.Type() + ) + + if rv.Kind() == reflect.Struct && !rvType.ConvertibleTo(TimeReflectType) { + for i := 0; i < rvType.NumField(); i++ { + for key, value := range ParseTagSetting(rvType.Field(i).Tag.Get("gorm"), ";") { + if _, ok := field.TagSettings[key]; !ok { + field.TagSettings[key] = value + } + } + } + + for i := 0; i < rvType.NumField(); i++ { + newFieldType := rvType.Field(i).Type for newFieldType.Kind() == reflect.Ptr { newFieldType = newFieldType.Elem() } fieldValue = reflect.New(newFieldType) - - if rv.Type() != reflect.Indirect(fieldValue).Type() { + if rvType != reflect.Indirect(fieldValue).Type() { getRealFieldValue(fieldValue) } if fieldValue.IsValid() { return } - - for key, value := range ParseTagSetting(field.IndirectFieldType.Field(i).Tag.Get("gorm"), ";") { - if _, ok := field.TagSettings[key]; !ok { - field.TagSettings[key] = value - } - } } } } @@ -138,19 +170,23 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } } - if dbName, ok := field.TagSettings["COLUMN"]; ok { - field.DBName = dbName - } - - if val, ok := field.TagSettings["PRIMARYKEY"]; ok && utils.CheckTruth(val) { - field.PrimaryKey = true - } else if val, ok := field.TagSettings["PRIMARY_KEY"]; ok && utils.CheckTruth(val) { - field.PrimaryKey = true - } - - if val, ok := field.TagSettings["AUTOINCREMENT"]; ok && utils.CheckTruth(val) { - field.AutoIncrement = true - field.HasDefaultValue = true + if v, isSerializer := fieldValue.Interface().(SerializerInterface); isSerializer { + field.DataType = String + field.Serializer = v + } else { + var serializerName = field.TagSettings["JSON"] + if serializerName == "" { + serializerName = field.TagSettings["SERIALIZER"] + } + if serializerName != "" { + if serializer, ok := GetSerializer(serializerName); ok { + // Set default data type to string for serializer + field.DataType = String + field.Serializer = serializer + } else { + schema.err = fmt.Errorf("invalid serializer type %v", serializerName) + } + } } if num, ok := field.TagSettings["AUTOINCREMENTINCREMENT"]; ok { @@ -176,20 +212,6 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { field.Scale, _ = strconv.Atoi(s) } - if val, ok := field.TagSettings["NOT NULL"]; ok && utils.CheckTruth(val) { - field.NotNull = true - } else if val, ok := field.TagSettings["NOTNULL"]; ok && utils.CheckTruth(val) { - field.NotNull = true - } - - if val, ok := field.TagSettings["UNIQUE"]; ok && utils.CheckTruth(val) { - field.Unique = true - } - - if val, ok := field.TagSettings["COMMENT"]; ok { - field.Comment = val - } - // default value is function or null or blank (primary keys) field.DefaultValue = strings.TrimSpace(field.DefaultValue) skipParseDefaultValue := strings.Contains(field.DefaultValue, "(") && @@ -225,7 +247,6 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } case reflect.String: field.DataType = String - if field.HasDefaultValue && !skipParseDefaultValue { field.DefaultValue = strings.Trim(field.DefaultValue, "'") field.DefaultValue = strings.Trim(field.DefaultValue, `"`) @@ -236,22 +257,25 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { field.DataType = Time } else if fieldValue.Type().ConvertibleTo(TimeReflectType) { field.DataType = Time - } else if fieldValue.Type().ConvertibleTo(reflect.TypeOf(&time.Time{})) { + } else if fieldValue.Type().ConvertibleTo(TimePtrReflectType) { field.DataType = Time } + if field.HasDefaultValue && !skipParseDefaultValue && field.DataType == Time { + if t, err := now.Parse(field.DefaultValue); err == nil { + field.DefaultValueInterface = t + } + } case reflect.Array, reflect.Slice: - if reflect.Indirect(fieldValue).Type().Elem() == reflect.TypeOf(uint8(0)) { + if reflect.Indirect(fieldValue).Type().Elem() == ByteReflectType && field.DataType == "" { field.DataType = Bytes } } - field.GORMDataType = field.DataType - if dataTyper, ok := fieldValue.Interface().(GormDataTypeInterface); ok { field.DataType = DataType(dataTyper.GormDataType()) } - if v, ok := field.TagSettings["AUTOCREATETIME"]; ok || (field.Name == "CreatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) { + if v, ok := field.TagSettings["AUTOCREATETIME"]; (ok && utils.CheckTruth(v)) || (!ok && field.Name == "CreatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) { if field.DataType == Time { field.AutoCreateTime = UnixTime } else if strings.ToUpper(v) == "NANO" { @@ -263,7 +287,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } } - if v, ok := field.TagSettings["AUTOUPDATETIME"]; ok || (field.Name == "UpdatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) { + if v, ok := field.TagSettings["AUTOUPDATETIME"]; (ok && utils.CheckTruth(v)) || (!ok && field.Name == "UpdatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) { if field.DataType == Time { field.AutoUpdateTime = UnixTime } else if strings.ToUpper(v) == "NANO" { @@ -275,6 +299,10 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } } + if field.GORMDataType == "" { + field.GORMDataType = field.DataType + } + if val, ok := field.TagSettings["TYPE"]; ok { switch DataType(strings.ToLower(val)) { case Bool, Int, Uint, Float, String, Time, Bytes: @@ -284,10 +312,6 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } } - if field.GORMDataType == "" { - field.GORMDataType = field.DataType - } - if field.Size == 0 { switch reflect.Indirect(fieldValue).Kind() { case reflect.Int, reflect.Int64, reflect.Uint, reflect.Uint64, reflect.Float64: @@ -346,8 +370,9 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } } - if _, ok := field.TagSettings["EMBEDDED"]; field.GORMDataType != Time && field.GORMDataType != Bytes && - (ok || (fieldStruct.Anonymous && !isValuer && (field.Creatable || field.Updatable || field.Readable))) { + // Normal anonymous field or having `EMBEDDED` tag + if _, ok := field.TagSettings["EMBEDDED"]; ok || (field.GORMDataType != Time && field.GORMDataType != Bytes && !isValuer && + fieldStruct.Anonymous && (field.Creatable || field.Updatable || field.Readable)) { kind := reflect.Indirect(fieldValue).Kind() switch kind { case reflect.Struct: @@ -410,31 +435,25 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { // create valuer, setter when parse struct func (field *Field) setupValuerAndSetter() { - // ValueOf + // Setup NewValuePool + field.setupNewValuePool() + + // ValueOf returns field's value and if it is zero + fieldIndex := field.StructField.Index[0] switch { - case len(field.StructField.Index) == 1: - field.ValueOf = func(value reflect.Value) (interface{}, bool) { - fieldValue := reflect.Indirect(value).Field(field.StructField.Index[0]) - return fieldValue.Interface(), fieldValue.IsZero() - } - case len(field.StructField.Index) == 2 && field.StructField.Index[0] >= 0: - field.ValueOf = func(value reflect.Value) (interface{}, bool) { - fieldValue := reflect.Indirect(value).Field(field.StructField.Index[0]).Field(field.StructField.Index[1]) + case len(field.StructField.Index) == 1 && fieldIndex > 0: + field.ValueOf = func(ctx context.Context, value reflect.Value) (interface{}, bool) { + fieldValue := reflect.Indirect(value).Field(fieldIndex) return fieldValue.Interface(), fieldValue.IsZero() } default: - field.ValueOf = func(value reflect.Value) (interface{}, bool) { - v := reflect.Indirect(value) - - for _, idx := range field.StructField.Index { - if idx >= 0 { - v = v.Field(idx) + field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) { + v = reflect.Indirect(v) + for _, fieldIdx := range field.StructField.Index { + if fieldIdx >= 0 { + v = v.Field(fieldIdx) } else { - v = v.Field(-idx - 1) - - if v.Type().Elem().Kind() != reflect.Struct { - return nil, true - } + v = v.Field(-fieldIdx - 1) if !v.IsNil() { v = v.Elem() @@ -443,35 +462,52 @@ func (field *Field) setupValuerAndSetter() { } } } - return v.Interface(), v.IsZero() + + fv, zero := v.Interface(), v.IsZero() + return fv, zero } } - // ReflectValueOf - switch { - case len(field.StructField.Index) == 1: - field.ReflectValueOf = func(value reflect.Value) reflect.Value { - return reflect.Indirect(value).Field(field.StructField.Index[0]) + if field.Serializer != nil { + oldValuerOf := field.ValueOf + field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) { + value, zero := oldValuerOf(ctx, v) + if zero { + return value, zero + } + + s, ok := value.(SerializerValuerInterface) + if !ok { + s = field.Serializer + } + + return &serializer{ + Field: field, + SerializeValuer: s, + Destination: v, + Context: ctx, + fieldValue: value, + }, false } - case len(field.StructField.Index) == 2 && field.StructField.Index[0] >= 0 && field.FieldType.Kind() != reflect.Ptr: - field.ReflectValueOf = func(value reflect.Value) reflect.Value { - return reflect.Indirect(value).Field(field.StructField.Index[0]).Field(field.StructField.Index[1]) + } + + // ReflectValueOf returns field's reflect value + switch { + case len(field.StructField.Index) == 1 && fieldIndex > 0: + field.ReflectValueOf = func(ctx context.Context, value reflect.Value) reflect.Value { + return reflect.Indirect(value).Field(fieldIndex) } default: - field.ReflectValueOf = func(value reflect.Value) reflect.Value { - v := reflect.Indirect(value) + field.ReflectValueOf = func(ctx context.Context, v reflect.Value) reflect.Value { + v = reflect.Indirect(v) for idx, fieldIdx := range field.StructField.Index { if fieldIdx >= 0 { v = v.Field(fieldIdx) } else { v = v.Field(-fieldIdx - 1) - } - if v.Kind() == reflect.Ptr { - if v.Type().Elem().Kind() == reflect.Struct { - if v.IsNil() { - v.Set(reflect.New(v.Type().Elem())) - } + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) } if idx < len(field.StructField.Index)-1 { @@ -483,22 +519,25 @@ func (field *Field) setupValuerAndSetter() { } } - fallbackSetter := func(value reflect.Value, v interface{}, setter func(reflect.Value, interface{}) error) (err error) { + fallbackSetter := func(ctx context.Context, value reflect.Value, v interface{}, setter func(context.Context, reflect.Value, interface{}) error) (err error) { if v == nil { - field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) + field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem()) } else { reflectV := reflect.ValueOf(v) // Optimal value type acquisition for v reflectValType := reflectV.Type() if reflectValType.AssignableTo(field.FieldType) { - field.ReflectValueOf(value).Set(reflectV) + if reflectV.Kind() == reflect.Ptr && reflectV.Elem().Kind() == reflect.Ptr { + reflectV = reflect.Indirect(reflectV) + } + field.ReflectValueOf(ctx, value).Set(reflectV) return } else if reflectValType.ConvertibleTo(field.FieldType) { - field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType)) + field.ReflectValueOf(ctx, value).Set(reflectV.Convert(field.FieldType)) return } else if field.FieldType.Kind() == reflect.Ptr { - fieldValue := field.ReflectValueOf(value) + fieldValue := field.ReflectValueOf(ctx, value) fieldType := field.FieldType.Elem() if reflectValType.AssignableTo(fieldType) { @@ -521,16 +560,19 @@ func (field *Field) setupValuerAndSetter() { if reflectV.Kind() == reflect.Ptr { if reflectV.IsNil() { - field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) + field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem()) + } else if reflectV.Type().Elem().AssignableTo(field.FieldType) { + field.ReflectValueOf(ctx, value).Set(reflectV.Elem()) + return } else { - err = setter(value, reflectV.Elem().Interface()) + err = setter(ctx, value, reflectV.Elem().Interface()) } } else if valuer, ok := v.(driver.Valuer); ok { if v, err = valuer.Value(); err == nil { - err = setter(value, v) + err = setter(ctx, value, v) } - } else { - return fmt.Errorf("failed to set value %+v to field %s", v, field.Name) + } else if _, ok := v.(clause.Expr); !ok { + return fmt.Errorf("failed to set value %#v to field %s", v, field.Name) } } @@ -540,191 +582,201 @@ func (field *Field) setupValuerAndSetter() { // Set switch field.FieldType.Kind() { case reflect.Bool: - field.Set = func(value reflect.Value, v interface{}) error { + field.Set = func(ctx context.Context, value reflect.Value, v interface{}) error { switch data := v.(type) { + case **bool: + if data != nil && *data != nil { + field.ReflectValueOf(ctx, value).SetBool(**data) + } case bool: - field.ReflectValueOf(value).SetBool(data) - case *bool: - if data != nil { - field.ReflectValueOf(value).SetBool(*data) - } else { - field.ReflectValueOf(value).SetBool(false) - } + field.ReflectValueOf(ctx, value).SetBool(data) case int64: - if data > 0 { - field.ReflectValueOf(value).SetBool(true) - } else { - field.ReflectValueOf(value).SetBool(false) - } + field.ReflectValueOf(ctx, value).SetBool(data > 0) case string: b, _ := strconv.ParseBool(data) - field.ReflectValueOf(value).SetBool(b) + field.ReflectValueOf(ctx, value).SetBool(b) default: - return fallbackSetter(value, v, field.Set) + return fallbackSetter(ctx, value, v, field.Set) } return nil } case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - field.Set = func(value reflect.Value, v interface{}) (err error) { + field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) { switch data := v.(type) { + case **int64: + if data != nil && *data != nil { + field.ReflectValueOf(ctx, value).SetInt(**data) + } case int64: - field.ReflectValueOf(value).SetInt(data) + field.ReflectValueOf(ctx, value).SetInt(data) case int: - field.ReflectValueOf(value).SetInt(int64(data)) + field.ReflectValueOf(ctx, value).SetInt(int64(data)) case int8: - field.ReflectValueOf(value).SetInt(int64(data)) + field.ReflectValueOf(ctx, value).SetInt(int64(data)) case int16: - field.ReflectValueOf(value).SetInt(int64(data)) + field.ReflectValueOf(ctx, value).SetInt(int64(data)) case int32: - field.ReflectValueOf(value).SetInt(int64(data)) + field.ReflectValueOf(ctx, value).SetInt(int64(data)) case uint: - field.ReflectValueOf(value).SetInt(int64(data)) + field.ReflectValueOf(ctx, value).SetInt(int64(data)) case uint8: - field.ReflectValueOf(value).SetInt(int64(data)) + field.ReflectValueOf(ctx, value).SetInt(int64(data)) case uint16: - field.ReflectValueOf(value).SetInt(int64(data)) + field.ReflectValueOf(ctx, value).SetInt(int64(data)) case uint32: - field.ReflectValueOf(value).SetInt(int64(data)) + field.ReflectValueOf(ctx, value).SetInt(int64(data)) case uint64: - field.ReflectValueOf(value).SetInt(int64(data)) + field.ReflectValueOf(ctx, value).SetInt(int64(data)) case float32: - field.ReflectValueOf(value).SetInt(int64(data)) + field.ReflectValueOf(ctx, value).SetInt(int64(data)) case float64: - field.ReflectValueOf(value).SetInt(int64(data)) + field.ReflectValueOf(ctx, value).SetInt(int64(data)) case []byte: - return field.Set(value, string(data)) + return field.Set(ctx, value, string(data)) case string: if i, err := strconv.ParseInt(data, 0, 64); err == nil { - field.ReflectValueOf(value).SetInt(i) + field.ReflectValueOf(ctx, value).SetInt(i) } else { return err } case time.Time: if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond { - field.ReflectValueOf(value).SetInt(data.UnixNano()) + field.ReflectValueOf(ctx, value).SetInt(data.UnixNano()) } else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond { - field.ReflectValueOf(value).SetInt(data.UnixNano() / 1e6) + field.ReflectValueOf(ctx, value).SetInt(data.UnixNano() / 1e6) } else { - field.ReflectValueOf(value).SetInt(data.Unix()) + field.ReflectValueOf(ctx, value).SetInt(data.Unix()) } case *time.Time: if data != nil { if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond { - field.ReflectValueOf(value).SetInt(data.UnixNano()) + field.ReflectValueOf(ctx, value).SetInt(data.UnixNano()) } else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond { - field.ReflectValueOf(value).SetInt(data.UnixNano() / 1e6) + field.ReflectValueOf(ctx, value).SetInt(data.UnixNano() / 1e6) } else { - field.ReflectValueOf(value).SetInt(data.Unix()) + field.ReflectValueOf(ctx, value).SetInt(data.Unix()) } } else { - field.ReflectValueOf(value).SetInt(0) + field.ReflectValueOf(ctx, value).SetInt(0) } default: - return fallbackSetter(value, v, field.Set) + return fallbackSetter(ctx, value, v, field.Set) } return err } case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - field.Set = func(value reflect.Value, v interface{}) (err error) { + field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) { switch data := v.(type) { + case **uint64: + if data != nil && *data != nil { + field.ReflectValueOf(ctx, value).SetUint(**data) + } case uint64: - field.ReflectValueOf(value).SetUint(data) + field.ReflectValueOf(ctx, value).SetUint(data) case uint: - field.ReflectValueOf(value).SetUint(uint64(data)) + field.ReflectValueOf(ctx, value).SetUint(uint64(data)) case uint8: - field.ReflectValueOf(value).SetUint(uint64(data)) + field.ReflectValueOf(ctx, value).SetUint(uint64(data)) case uint16: - field.ReflectValueOf(value).SetUint(uint64(data)) + field.ReflectValueOf(ctx, value).SetUint(uint64(data)) case uint32: - field.ReflectValueOf(value).SetUint(uint64(data)) + field.ReflectValueOf(ctx, value).SetUint(uint64(data)) case int64: - field.ReflectValueOf(value).SetUint(uint64(data)) + field.ReflectValueOf(ctx, value).SetUint(uint64(data)) case int: - field.ReflectValueOf(value).SetUint(uint64(data)) + field.ReflectValueOf(ctx, value).SetUint(uint64(data)) case int8: - field.ReflectValueOf(value).SetUint(uint64(data)) + field.ReflectValueOf(ctx, value).SetUint(uint64(data)) case int16: - field.ReflectValueOf(value).SetUint(uint64(data)) + field.ReflectValueOf(ctx, value).SetUint(uint64(data)) case int32: - field.ReflectValueOf(value).SetUint(uint64(data)) + field.ReflectValueOf(ctx, value).SetUint(uint64(data)) case float32: - field.ReflectValueOf(value).SetUint(uint64(data)) + field.ReflectValueOf(ctx, value).SetUint(uint64(data)) case float64: - field.ReflectValueOf(value).SetUint(uint64(data)) + field.ReflectValueOf(ctx, value).SetUint(uint64(data)) case []byte: - return field.Set(value, string(data)) + return field.Set(ctx, value, string(data)) case time.Time: if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond { - field.ReflectValueOf(value).SetUint(uint64(data.UnixNano())) + field.ReflectValueOf(ctx, value).SetUint(uint64(data.UnixNano())) } else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond { - field.ReflectValueOf(value).SetUint(uint64(data.UnixNano() / 1e6)) + field.ReflectValueOf(ctx, value).SetUint(uint64(data.UnixNano() / 1e6)) } else { - field.ReflectValueOf(value).SetUint(uint64(data.Unix())) + field.ReflectValueOf(ctx, value).SetUint(uint64(data.Unix())) } case string: if i, err := strconv.ParseUint(data, 0, 64); err == nil { - field.ReflectValueOf(value).SetUint(i) + field.ReflectValueOf(ctx, value).SetUint(i) } else { return err } default: - return fallbackSetter(value, v, field.Set) + return fallbackSetter(ctx, value, v, field.Set) } return err } case reflect.Float32, reflect.Float64: - field.Set = func(value reflect.Value, v interface{}) (err error) { + field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) { switch data := v.(type) { + case **float64: + if data != nil && *data != nil { + field.ReflectValueOf(ctx, value).SetFloat(**data) + } case float64: - field.ReflectValueOf(value).SetFloat(data) + field.ReflectValueOf(ctx, value).SetFloat(data) case float32: - field.ReflectValueOf(value).SetFloat(float64(data)) + field.ReflectValueOf(ctx, value).SetFloat(float64(data)) case int64: - field.ReflectValueOf(value).SetFloat(float64(data)) + field.ReflectValueOf(ctx, value).SetFloat(float64(data)) case int: - field.ReflectValueOf(value).SetFloat(float64(data)) + field.ReflectValueOf(ctx, value).SetFloat(float64(data)) case int8: - field.ReflectValueOf(value).SetFloat(float64(data)) + field.ReflectValueOf(ctx, value).SetFloat(float64(data)) case int16: - field.ReflectValueOf(value).SetFloat(float64(data)) + field.ReflectValueOf(ctx, value).SetFloat(float64(data)) case int32: - field.ReflectValueOf(value).SetFloat(float64(data)) + field.ReflectValueOf(ctx, value).SetFloat(float64(data)) case uint: - field.ReflectValueOf(value).SetFloat(float64(data)) + field.ReflectValueOf(ctx, value).SetFloat(float64(data)) case uint8: - field.ReflectValueOf(value).SetFloat(float64(data)) + field.ReflectValueOf(ctx, value).SetFloat(float64(data)) case uint16: - field.ReflectValueOf(value).SetFloat(float64(data)) + field.ReflectValueOf(ctx, value).SetFloat(float64(data)) case uint32: - field.ReflectValueOf(value).SetFloat(float64(data)) + field.ReflectValueOf(ctx, value).SetFloat(float64(data)) case uint64: - field.ReflectValueOf(value).SetFloat(float64(data)) + field.ReflectValueOf(ctx, value).SetFloat(float64(data)) case []byte: - return field.Set(value, string(data)) + return field.Set(ctx, value, string(data)) case string: if i, err := strconv.ParseFloat(data, 64); err == nil { - field.ReflectValueOf(value).SetFloat(i) + field.ReflectValueOf(ctx, value).SetFloat(i) } else { return err } default: - return fallbackSetter(value, v, field.Set) + return fallbackSetter(ctx, value, v, field.Set) } return err } case reflect.String: - field.Set = func(value reflect.Value, v interface{}) (err error) { + field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) { switch data := v.(type) { + case **string: + if data != nil && *data != nil { + field.ReflectValueOf(ctx, value).SetString(**data) + } case string: - field.ReflectValueOf(value).SetString(data) + field.ReflectValueOf(ctx, value).SetString(data) case []byte: - field.ReflectValueOf(value).SetString(string(data)) + field.ReflectValueOf(ctx, value).SetString(string(data)) case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: - field.ReflectValueOf(value).SetString(utils.ToString(data)) + field.ReflectValueOf(ctx, value).SetString(utils.ToString(data)) case float64, float32: - field.ReflectValueOf(value).SetString(fmt.Sprintf("%."+strconv.Itoa(field.Precision)+"f", data)) + field.ReflectValueOf(ctx, value).SetString(fmt.Sprintf("%."+strconv.Itoa(field.Precision)+"f", data)) default: - return fallbackSetter(value, v, field.Set) + return fallbackSetter(ctx, value, v, field.Set) } return err } @@ -732,41 +784,49 @@ func (field *Field) setupValuerAndSetter() { fieldValue := reflect.New(field.FieldType) switch fieldValue.Elem().Interface().(type) { case time.Time: - field.Set = func(value reflect.Value, v interface{}) error { + field.Set = func(ctx context.Context, value reflect.Value, v interface{}) error { switch data := v.(type) { + case **time.Time: + if data != nil && *data != nil { + field.Set(ctx, value, *data) + } case time.Time: - field.ReflectValueOf(value).Set(reflect.ValueOf(v)) + field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(v)) case *time.Time: if data != nil { - field.ReflectValueOf(value).Set(reflect.ValueOf(data).Elem()) + field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(data).Elem()) } else { - field.ReflectValueOf(value).Set(reflect.ValueOf(time.Time{})) + field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(time.Time{})) } case string: if t, err := now.Parse(data); err == nil { - field.ReflectValueOf(value).Set(reflect.ValueOf(t)) + field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(t)) } else { return fmt.Errorf("failed to set string %v to time.Time field %s, failed to parse it as time, got error %v", v, field.Name, err) } default: - return fallbackSetter(value, v, field.Set) + return fallbackSetter(ctx, value, v, field.Set) } return nil } case *time.Time: - field.Set = func(value reflect.Value, v interface{}) error { + field.Set = func(ctx context.Context, value reflect.Value, v interface{}) error { switch data := v.(type) { + case **time.Time: + if data != nil { + field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(*data)) + } case time.Time: - fieldValue := field.ReflectValueOf(value) + fieldValue := field.ReflectValueOf(ctx, value) if fieldValue.IsNil() { fieldValue.Set(reflect.New(field.FieldType.Elem())) } fieldValue.Elem().Set(reflect.ValueOf(v)) case *time.Time: - field.ReflectValueOf(value).Set(reflect.ValueOf(v)) + field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(v)) case string: if t, err := now.Parse(data); err == nil { - fieldValue := field.ReflectValueOf(value) + fieldValue := field.ReflectValueOf(ctx, value) if fieldValue.IsNil() { if v == "" { return nil @@ -778,27 +838,27 @@ func (field *Field) setupValuerAndSetter() { return fmt.Errorf("failed to set string %v to time.Time field %s, failed to parse it as time, got error %v", v, field.Name, err) } default: - return fallbackSetter(value, v, field.Set) + return fallbackSetter(ctx, value, v, field.Set) } return nil } default: if _, ok := fieldValue.Elem().Interface().(sql.Scanner); ok { // pointer scanner - field.Set = func(value reflect.Value, v interface{}) (err error) { + field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) { reflectV := reflect.ValueOf(v) if !reflectV.IsValid() { - field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) + field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem()) } else if reflectV.Type().AssignableTo(field.FieldType) { - field.ReflectValueOf(value).Set(reflectV) + field.ReflectValueOf(ctx, value).Set(reflectV) } else if reflectV.Kind() == reflect.Ptr { if reflectV.IsNil() || !reflectV.IsValid() { - field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) + field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem()) } else { - return field.Set(value, reflectV.Elem().Interface()) + return field.Set(ctx, value, reflectV.Elem().Interface()) } } else { - fieldValue := field.ReflectValueOf(value) + fieldValue := field.ReflectValueOf(ctx, value) if fieldValue.IsNil() { fieldValue.Set(reflect.New(field.FieldType.Elem())) } @@ -813,32 +873,80 @@ func (field *Field) setupValuerAndSetter() { } } else if _, ok := fieldValue.Interface().(sql.Scanner); ok { // struct scanner - field.Set = func(value reflect.Value, v interface{}) (err error) { + field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) { reflectV := reflect.ValueOf(v) if !reflectV.IsValid() { - field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) + field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem()) } else if reflectV.Type().AssignableTo(field.FieldType) { - field.ReflectValueOf(value).Set(reflectV) + field.ReflectValueOf(ctx, value).Set(reflectV) } else if reflectV.Kind() == reflect.Ptr { if reflectV.IsNil() || !reflectV.IsValid() { - field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) + field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem()) } else { - return field.Set(value, reflectV.Elem().Interface()) + return field.Set(ctx, value, reflectV.Elem().Interface()) } } else { if valuer, ok := v.(driver.Valuer); ok { v, _ = valuer.Value() } - err = field.ReflectValueOf(value).Addr().Interface().(sql.Scanner).Scan(v) + err = field.ReflectValueOf(ctx, value).Addr().Interface().(sql.Scanner).Scan(v) } return } } else { - field.Set = func(value reflect.Value, v interface{}) (err error) { - return fallbackSetter(value, v, field.Set) + field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) { + return fallbackSetter(ctx, value, v, field.Set) } } } } + + if field.Serializer != nil { + var ( + oldFieldSetter = field.Set + sameElemType bool + sameType = field.FieldType == reflect.ValueOf(field.Serializer).Type() + ) + + if reflect.ValueOf(field.Serializer).Kind() == reflect.Ptr { + sameElemType = field.FieldType == reflect.ValueOf(field.Serializer).Type().Elem() + } + + field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) { + if s, ok := v.(*serializer); ok { + if s.fieldValue != nil { + err = oldFieldSetter(ctx, value, s.fieldValue) + } else if err = s.Serializer.Scan(ctx, field, value, s.value); err == nil { + if sameElemType { + field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(s.Serializer).Elem()) + s.Serializer = reflect.New(reflect.Indirect(reflect.ValueOf(field.Serializer)).Type()).Interface().(SerializerInterface) + } else if sameType { + field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(s.Serializer)) + s.Serializer = reflect.New(reflect.Indirect(reflect.ValueOf(field.Serializer)).Type()).Interface().(SerializerInterface) + } + } + } else { + err = oldFieldSetter(ctx, value, v) + } + return + } + } +} + +func (field *Field) setupNewValuePool() { + if field.Serializer != nil { + field.NewValuePool = &sync.Pool{ + New: func() interface{} { + return &serializer{ + Field: field, + Serializer: reflect.New(reflect.Indirect(reflect.ValueOf(field.Serializer)).Type()).Interface().(SerializerInterface), + } + }, + } + } + + if field.NewValuePool == nil { + field.NewValuePool = poolInitializer(reflect.PtrTo(field.IndirectFieldType)) + } } diff --git a/schema/field_test.go b/schema/field_test.go index 8fa46b87..300e375b 100644 --- a/schema/field_test.go +++ b/schema/field_test.go @@ -1,6 +1,7 @@ package schema_test import ( + "context" "database/sql" "reflect" "sync" @@ -57,7 +58,7 @@ func TestFieldValuerAndSetter(t *testing.T) { } for k, v := range newValues { - if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil { + if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil { t.Errorf("no error should happen when assign value to field %v, but got %v", k, err) } } @@ -80,7 +81,7 @@ func TestFieldValuerAndSetter(t *testing.T) { } for k, v := range newValues2 { - if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil { + if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil { t.Errorf("no error should happen when assign value to field %v, but got %v", k, err) } } @@ -132,7 +133,7 @@ func TestPointerFieldValuerAndSetter(t *testing.T) { } for k, v := range newValues { - if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil { + if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil { t.Errorf("no error should happen when assign value to field %v, but got %v", k, err) } } @@ -151,7 +152,7 @@ func TestPointerFieldValuerAndSetter(t *testing.T) { } for k, v := range newValues2 { - if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil { + if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil { t.Errorf("no error should happen when assign value to field %v, but got %v", k, err) } } @@ -202,7 +203,7 @@ func TestAdvancedDataTypeValuerAndSetter(t *testing.T) { } for k, v := range newValues { - if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil { + if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil { t.Errorf("no error should happen when assign value to field %v, but got %v", k, err) } } @@ -219,7 +220,7 @@ func TestAdvancedDataTypeValuerAndSetter(t *testing.T) { } for k, v := range newValues2 { - if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil { + if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil { t.Errorf("no error should happen when assign value to field %v, but got %v", k, err) } } diff --git a/schema/index.go b/schema/index.go index 5f775f30..5003c742 100644 --- a/schema/index.go +++ b/schema/index.go @@ -1,6 +1,7 @@ package schema import ( + "fmt" "sort" "strconv" "strings" @@ -31,7 +32,12 @@ func (schema *Schema) ParseIndexes() map[string]Index { for _, field := range schema.Fields { if field.TagSettings["INDEX"] != "" || field.TagSettings["UNIQUEINDEX"] != "" { - for _, index := range parseFieldIndexes(field) { + fieldIndexes, err := parseFieldIndexes(field) + if err != nil { + schema.err = err + break + } + for _, index := range fieldIndexes { idx := indexes[index.Name] idx.Name = index.Name if idx.Class == "" { @@ -82,18 +88,19 @@ func (schema *Schema) LookIndex(name string) *Index { return nil } -func parseFieldIndexes(field *Field) (indexes []Index) { +func parseFieldIndexes(field *Field) (indexes []Index, err error) { for _, value := range strings.Split(field.Tag.Get("gorm"), ";") { if value != "" { v := strings.Split(value, ":") k := strings.TrimSpace(strings.ToUpper(v[0])) if k == "INDEX" || k == "UNIQUEINDEX" { var ( - name string - tag = strings.Join(v[1:], ":") - idx = strings.Index(tag, ",") - settings = ParseTagSetting(tag, ",") - length, _ = strconv.Atoi(settings["LENGTH"]) + name string + tag = strings.Join(v[1:], ":") + idx = strings.Index(tag, ",") + tagSetting = strings.Join(strings.Split(tag, ",")[1:], ",") + settings = ParseTagSetting(tagSetting, ",") + length, _ = strconv.Atoi(settings["LENGTH"]) ) if idx == -1 { @@ -105,7 +112,20 @@ func parseFieldIndexes(field *Field) (indexes []Index) { } if name == "" { - name = field.Schema.namer.IndexName(field.Schema.Table, field.Name) + subName := field.Name + const key = "COMPOSITE" + if composite, found := settings[key]; found { + if len(composite) == 0 || composite == key { + err = fmt.Errorf( + "The composite tag of %s.%s cannot be empty", + field.Schema.Name, + field.Name) + return + } + subName = composite + } + name = field.Schema.namer.IndexName( + field.Schema.Table, subName) } if (k == "UNIQUEINDEX") || settings["UNIQUE"] != "" { @@ -137,5 +157,6 @@ func parseFieldIndexes(field *Field) (indexes []Index) { } } + err = nil return } diff --git a/schema/index_test.go b/schema/index_test.go index bc6bb8b6..1fe31cc1 100644 --- a/schema/index_test.go +++ b/schema/index_test.go @@ -18,6 +18,37 @@ type UserIndex struct { Age int64 `gorm:"index:profile,expression:ABS(age),option:WITH PARSER parser_name"` OID int64 `gorm:"index:idx_id;index:idx_oid,unique"` MemberNumber string `gorm:"index:idx_id,priority:1"` + Name7 string `gorm:"index:type"` + + // Composite Index: Flattened structure. + Data0A string `gorm:"index:,composite:comp_id0"` + Data0B string `gorm:"index:,composite:comp_id0"` + + // Composite Index: Nested structure. + Data1A string `gorm:"index:,composite:comp_id1"` + CompIdxLevel1C + + // Composite Index: Unique and priority. + Data2A string `gorm:"index:,unique,composite:comp_id2,priority:2"` + CompIdxLevel2C +} + +type CompIdxLevel1C struct { + CompIdxLevel1B + Data1C string `gorm:"index:,composite:comp_id1"` +} + +type CompIdxLevel1B struct { + Data1B string `gorm:"index:,composite:comp_id1"` +} + +type CompIdxLevel2C struct { + CompIdxLevel2B + Data2C string `gorm:"index:,unique,composite:comp_id2,priority:1"` +} + +type CompIdxLevel2B struct { + Data2B string `gorm:"index:,unique,composite:comp_id2,priority:3"` } func TestParseIndex(t *testing.T) { @@ -78,6 +109,41 @@ func TestParseIndex(t *testing.T) { Class: "UNIQUE", Fields: []schema.IndexOption{{Field: &schema.Field{Name: "OID"}}}, }, + "type": { + Name: "type", + Type: "", + Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name7"}}}, + }, + "idx_user_indices_comp_id0": { + Name: "idx_user_indices_comp_id0", + Type: "", + Fields: []schema.IndexOption{{ + Field: &schema.Field{Name: "Data0A"}, + }, { + Field: &schema.Field{Name: "Data0B"}, + }}, + }, + "idx_user_indices_comp_id1": { + Name: "idx_user_indices_comp_id1", + Fields: []schema.IndexOption{{ + Field: &schema.Field{Name: "Data1A"}, + }, { + Field: &schema.Field{Name: "Data1B"}, + }, { + Field: &schema.Field{Name: "Data1C"}, + }}, + }, + "idx_user_indices_comp_id2": { + Name: "idx_user_indices_comp_id2", + Class: "UNIQUE", + Fields: []schema.IndexOption{{ + Field: &schema.Field{Name: "Data2C"}, + }, { + Field: &schema.Field{Name: "Data2A"}, + }, { + Field: &schema.Field{Name: "Data2B"}, + }}, + }, } indices := user.ParseIndexes() diff --git a/schema/interfaces.go b/schema/interfaces.go index 98abffbd..a75a33c0 100644 --- a/schema/interfaces.go +++ b/schema/interfaces.go @@ -4,22 +4,33 @@ import ( "gorm.io/gorm/clause" ) +// GormDataTypeInterface gorm data type interface type GormDataTypeInterface interface { GormDataType() string } +// FieldNewValuePool field new scan value pool +type FieldNewValuePool interface { + Get() interface{} + Put(interface{}) +} + +// CreateClausesInterface create clauses interface type CreateClausesInterface interface { CreateClauses(*Field) []clause.Interface } +// QueryClausesInterface query clauses interface type QueryClausesInterface interface { QueryClauses(*Field) []clause.Interface } +// UpdateClausesInterface update clauses interface type UpdateClausesInterface interface { UpdateClauses(*Field) []clause.Interface } +// DeleteClausesInterface delete clauses interface type DeleteClausesInterface interface { DeleteClauses(*Field) []clause.Interface } diff --git a/schema/naming.go b/schema/naming.go index 8407bffa..a258beed 100644 --- a/schema/naming.go +++ b/schema/naming.go @@ -3,7 +3,6 @@ package schema import ( "crypto/sha1" "encoding/hex" - "fmt" "regexp" "strings" "unicode/utf8" @@ -86,16 +85,16 @@ func (ns NamingStrategy) IndexName(table, column string) string { } func (ns NamingStrategy) formatName(prefix, table, name string) string { - formattedName := strings.Replace(strings.Join([]string{ + formattedName := strings.ReplaceAll(strings.Join([]string{ prefix, table, name, - }, "_"), ".", "_", -1) + }, "_"), ".", "_") if utf8.RuneCountInString(formattedName) > 64 { h := sha1.New() h.Write([]byte(formattedName)) bs := h.Sum(nil) - formattedName = fmt.Sprintf("%v%v%v", prefix, table, name)[0:56] + hex.EncodeToString(bs)[:8] + formattedName = formattedName[0:56] + hex.EncodeToString(bs)[:8] } return formattedName } @@ -120,7 +119,13 @@ func (ns NamingStrategy) toDBName(name string) string { } if ns.NameReplacer != nil { - name = ns.NameReplacer.Replace(name) + tmpName := ns.NameReplacer.Replace(name) + + if tmpName == "" { + return name + } + + name = tmpName } if ns.NoLowerCase { @@ -168,7 +173,7 @@ func (ns NamingStrategy) toDBName(name string) string { } func (ns NamingStrategy) toSchemaName(name string) string { - result := strings.Replace(strings.Title(strings.Replace(name, "_", " ", -1)), " ", "", -1) + result := strings.ReplaceAll(strings.Title(strings.ReplaceAll(name, "_", " ")), " ", "") for _, initialism := range commonInitialisms { result = regexp.MustCompile(strings.Title(strings.ToLower(initialism))+"([A-Z]|$|_)").ReplaceAllString(result, initialism+"$1") } diff --git a/schema/naming_test.go b/schema/naming_test.go index c3e6bf92..3f598c33 100644 --- a/schema/naming_test.go +++ b/schema/naming_test.go @@ -193,7 +193,18 @@ func TestFormatNameWithStringLongerThan64Characters(t *testing.T) { ns := NamingStrategy{} formattedName := ns.formatName("prefix", "table", "thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVeryLongString") - if formattedName != "prefixtablethisIsAVeryVeryVeryVeryVeryVeryVeryVeryVeryLo180f2c67" { + if formattedName != "prefix_table_thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVery180f2c67" { t.Errorf("invalid formatted name generated, got %v", formattedName) } } + +func TestReplaceEmptyTableName(t *testing.T) { + ns := NamingStrategy{ + SingularTable: true, + NameReplacer: strings.NewReplacer("Model", ""), + } + tableName := ns.TableName("Model") + if tableName != "Model" { + t.Errorf("invalid table name generated, got %v", tableName) + } +} diff --git a/schema/pool.go b/schema/pool.go new file mode 100644 index 00000000..fa62fe22 --- /dev/null +++ b/schema/pool.go @@ -0,0 +1,19 @@ +package schema + +import ( + "reflect" + "sync" +) + +// sync pools +var ( + normalPool sync.Map + poolInitializer = func(reflectType reflect.Type) FieldNewValuePool { + v, _ := normalPool.LoadOrStore(reflectType, &sync.Pool{ + New: func() interface{} { + return reflect.New(reflectType).Interface() + }, + }) + return v.(FieldNewValuePool) + } +) diff --git a/schema/relationship.go b/schema/relationship.go index c5d3dcad..0aa33e51 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -1,6 +1,7 @@ package schema import ( + "context" "fmt" "reflect" "strings" @@ -234,7 +235,8 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel Name: joinFieldName, PkgPath: ownField.StructField.PkgPath, Type: ownField.StructField.Type, - Tag: removeSettingFromTag(ownField.StructField.Tag, "column", "autoincrement", "index", "unique", "uniqueindex"), + Tag: removeSettingFromTag(appendSettingFromTag(ownField.StructField.Tag, "primaryKey"), + "column", "autoincrement", "index", "unique", "uniqueindex"), }) } @@ -257,7 +259,8 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel Name: joinFieldName, PkgPath: relField.StructField.PkgPath, Type: relField.StructField.Type, - Tag: removeSettingFromTag(relField.StructField.Tag, "column", "autoincrement", "index", "unique", "uniqueindex"), + Tag: removeSettingFromTag(appendSettingFromTag(relField.StructField.Tag, "primaryKey"), + "column", "autoincrement", "index", "unique", "uniqueindex"), }) } @@ -415,6 +418,10 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu } } else { var primaryFields []*Field + var primarySchemaName = primarySchema.Name + if primarySchemaName == "" { + primarySchemaName = relation.FieldSchema.Name + } if len(relation.primaryKeys) > 0 { for _, primaryKey := range relation.primaryKeys { @@ -427,7 +434,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu } for _, primaryField := range primaryFields { - lookUpName := primarySchema.Name + primaryField.Name + lookUpName := primarySchemaName + primaryField.Name if gl == guessBelongs { lookUpName = field.Name + primaryField.Name } @@ -576,7 +583,7 @@ func (rel *Relationship) ParseConstraint() *Constraint { return &constraint } -func (rel *Relationship) ToQueryConditions(reflectValue reflect.Value) (conds []clause.Expression) { +func (rel *Relationship) ToQueryConditions(ctx context.Context, reflectValue reflect.Value) (conds []clause.Expression) { table := rel.FieldSchema.Table foreignFields := []*Field{} relForeignKeys := []string{} @@ -616,7 +623,7 @@ func (rel *Relationship) ToQueryConditions(reflectValue reflect.Value) (conds [] } } - _, foreignValues := GetIdentityFieldValuesMap(reflectValue, foreignFields) + _, foreignValues := GetIdentityFieldValuesMap(ctx, reflectValue, foreignFields) column, values := ToQueryValues(table, relForeignKeys, foreignValues) conds = append(conds, clause.IN{Column: column, Values: values}) diff --git a/schema/relationship_test.go b/schema/relationship_test.go index e2cf11a9..6fffbfcb 100644 --- a/schema/relationship_test.go +++ b/schema/relationship_test.go @@ -491,6 +491,26 @@ func TestEmbeddedRelation(t *testing.T) { } } +func TestVariableRelation(t *testing.T) { + var result struct { + User + } + + checkStructRelation(t, &result, Relation{ + Name: "Account", Type: schema.HasOne, Schema: "", FieldSchema: "Account", + References: []Reference{ + {"ID", "", "UserID", "Account", "", true}, + }, + }) + + checkStructRelation(t, &result, Relation{ + Name: "Company", Type: schema.BelongsTo, Schema: "", FieldSchema: "Company", + References: []Reference{ + {"ID", "Company", "CompanyID", "", "", false}, + }, + }) +} + func TestSameForeignKey(t *testing.T) { type UserAux struct { gorm.Model @@ -576,3 +596,39 @@ func TestHasManySameForeignKey(t *testing.T) { References: []Reference{{"ID", "User", "UserRefer", "Profile", "", true}}, }) } + +type Author struct { + gorm.Model +} + +type Book struct { + gorm.Model + Author Author + AuthorID uint +} + +func (Book) TableName() string { + return "my_schema.a_very_very_very_very_very_very_very_very_long_table_name" +} + +func TestParseConstraintNameWithSchemaQualifiedLongTableName(t *testing.T) { + s, err := schema.Parse( + &Book{}, + &sync.Map{}, + schema.NamingStrategy{}, + ) + if err != nil { + t.Fatalf("Failed to parse schema") + } + + expectedConstraintName := "fk_my_schema_a_very_very_very_very_very_very_very_very_l4db13eec" + constraint := s.Relationships.Relations["Author"].ParseConstraint() + + if constraint.Name != expectedConstraintName { + t.Fatalf( + "expected constraint name %s, got %s", + expectedConstraintName, + constraint.Name, + ) + } +} diff --git a/schema/schema_helper_test.go b/schema/schema_helper_test.go index 6d2bc664..9abaecba 100644 --- a/schema/schema_helper_test.go +++ b/schema/schema_helper_test.go @@ -1,6 +1,7 @@ package schema_test import ( + "context" "fmt" "reflect" "strings" @@ -203,7 +204,7 @@ func checkSchemaRelation(t *testing.T, s *schema.Schema, relation Relation) { func checkField(t *testing.T, s *schema.Schema, value reflect.Value, values map[string]interface{}) { for k, v := range values { t.Run("CheckField/"+k, func(t *testing.T) { - fv, _ := s.FieldsByDBName[k].ValueOf(value) + fv, _ := s.FieldsByDBName[k].ValueOf(context.Background(), value) tests.AssertEqual(t, v, fv) }) } diff --git a/schema/serializer.go b/schema/serializer.go new file mode 100644 index 00000000..758a6421 --- /dev/null +++ b/schema/serializer.go @@ -0,0 +1,157 @@ +package schema + +import ( + "bytes" + "context" + "database/sql" + "database/sql/driver" + "encoding/gob" + "encoding/json" + "fmt" + "reflect" + "strings" + "sync" + "time" +) + +var serializerMap = sync.Map{} + +// RegisterSerializer register serializer +func RegisterSerializer(name string, serializer SerializerInterface) { + serializerMap.Store(strings.ToLower(name), serializer) +} + +// GetSerializer get serializer +func GetSerializer(name string) (serializer SerializerInterface, ok bool) { + v, ok := serializerMap.Load(strings.ToLower(name)) + if ok { + serializer, ok = v.(SerializerInterface) + } + return serializer, ok +} + +func init() { + RegisterSerializer("json", JSONSerializer{}) + RegisterSerializer("unixtime", UnixSecondSerializer{}) + RegisterSerializer("gob", GobSerializer{}) +} + +// Serializer field value serializer +type serializer struct { + Field *Field + Serializer SerializerInterface + SerializeValuer SerializerValuerInterface + Destination reflect.Value + Context context.Context + value interface{} + fieldValue interface{} +} + +// Scan implements sql.Scanner interface +func (s *serializer) Scan(value interface{}) error { + s.value = value + return nil +} + +// Value implements driver.Valuer interface +func (s serializer) Value() (driver.Value, error) { + return s.SerializeValuer.Value(s.Context, s.Field, s.Destination, s.fieldValue) +} + +// SerializerInterface serializer interface +type SerializerInterface interface { + Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) error + SerializerValuerInterface +} + +// SerializerValuerInterface serializer valuer interface +type SerializerValuerInterface interface { + Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) +} + +// JSONSerializer json serializer +type JSONSerializer struct { +} + +// Scan implements serializer interface +func (JSONSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) (err error) { + fieldValue := reflect.New(field.FieldType) + + if dbValue != nil { + var bytes []byte + switch v := dbValue.(type) { + case []byte: + bytes = v + case string: + bytes = []byte(v) + default: + return fmt.Errorf("failed to unmarshal JSONB value: %#v", dbValue) + } + + err = json.Unmarshal(bytes, fieldValue.Interface()) + } + + field.ReflectValueOf(ctx, dst).Set(fieldValue.Elem()) + return +} + +// Value implements serializer interface +func (JSONSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) { + result, err := json.Marshal(fieldValue) + return string(result), err +} + +// UnixSecondSerializer json serializer +type UnixSecondSerializer struct { +} + +// Scan implements serializer interface +func (UnixSecondSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) (err error) { + t := sql.NullTime{} + if err = t.Scan(dbValue); err == nil && t.Valid { + err = field.Set(ctx, dst, t.Time.Unix()) + } + + return +} + +// Value implements serializer interface +func (UnixSecondSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (result interface{}, err error) { + switch v := fieldValue.(type) { + case int64, int, uint, uint64, int32, uint32, int16, uint16, *int64, *int, *uint, *uint64, *int32, *uint32, *int16, *uint16: + result = time.Unix(reflect.Indirect(reflect.ValueOf(v)).Int(), 0) + default: + err = fmt.Errorf("invalid field type %#v for UnixSecondSerializer, only int, uint supported", v) + } + return +} + +// GobSerializer gob serializer +type GobSerializer struct { +} + +// Scan implements serializer interface +func (GobSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) (err error) { + fieldValue := reflect.New(field.FieldType) + + if dbValue != nil { + var bytesValue []byte + switch v := dbValue.(type) { + case []byte: + bytesValue = v + default: + return fmt.Errorf("failed to unmarshal gob value: %#v", dbValue) + } + decoder := gob.NewDecoder(bytes.NewBuffer(bytesValue)) + err = decoder.Decode(fieldValue.Interface()) + } + field.ReflectValueOf(ctx, dst).Set(fieldValue.Elem()) + return +} + +// Value implements serializer interface +func (GobSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) { + buf := new(bytes.Buffer) + err := gob.NewEncoder(buf).Encode(fieldValue) + return buf.Bytes(), err +} diff --git a/schema/utils.go b/schema/utils.go index e005cc74..acf1a739 100644 --- a/schema/utils.go +++ b/schema/utils.go @@ -1,6 +1,8 @@ package schema import ( + "context" + "fmt" "reflect" "regexp" "strings" @@ -58,14 +60,22 @@ func removeSettingFromTag(tag reflect.StructTag, names ...string) reflect.Struct return tag } +func appendSettingFromTag(tag reflect.StructTag, value string) reflect.StructTag { + t := tag.Get("gorm") + if strings.Contains(t, value) { + return tag + } + return reflect.StructTag(fmt.Sprintf(`gorm:"%s;%s"`, value, t)) +} + // GetRelationsValues get relations's values from a reflect value -func GetRelationsValues(reflectValue reflect.Value, rels []*Relationship) (reflectResults reflect.Value) { +func GetRelationsValues(ctx context.Context, reflectValue reflect.Value, rels []*Relationship) (reflectResults reflect.Value) { for _, rel := range rels { reflectResults = reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(rel.FieldSchema.ModelType)), 0, 1) appendToResults := func(value reflect.Value) { - if _, isZero := rel.Field.ValueOf(value); !isZero { - result := reflect.Indirect(rel.Field.ReflectValueOf(value)) + if _, isZero := rel.Field.ValueOf(ctx, value); !isZero { + result := reflect.Indirect(rel.Field.ReflectValueOf(ctx, value)) switch result.Kind() { case reflect.Struct: reflectResults = reflect.Append(reflectResults, result.Addr()) @@ -97,7 +107,7 @@ func GetRelationsValues(reflectValue reflect.Value, rels []*Relationship) (refle } // GetIdentityFieldValuesMap get identity map from fields -func GetIdentityFieldValuesMap(reflectValue reflect.Value, fields []*Field) (map[string][]reflect.Value, [][]interface{}) { +func GetIdentityFieldValuesMap(ctx context.Context, reflectValue reflect.Value, fields []*Field) (map[string][]reflect.Value, [][]interface{}) { var ( results = [][]interface{}{} dataResults = map[string][]reflect.Value{} @@ -110,7 +120,7 @@ func GetIdentityFieldValuesMap(reflectValue reflect.Value, fields []*Field) (map results = [][]interface{}{make([]interface{}, len(fields))} for idx, field := range fields { - results[0][idx], zero = field.ValueOf(reflectValue) + results[0][idx], zero = field.ValueOf(ctx, reflectValue) notZero = notZero || !zero } @@ -135,7 +145,7 @@ func GetIdentityFieldValuesMap(reflectValue reflect.Value, fields []*Field) (map fieldValues := make([]interface{}, len(fields)) notZero = false for idx, field := range fields { - fieldValues[idx], zero = field.ValueOf(elem) + fieldValues[idx], zero = field.ValueOf(ctx, elem) notZero = notZero || !zero } @@ -155,12 +165,12 @@ func GetIdentityFieldValuesMap(reflectValue reflect.Value, fields []*Field) (map } // GetIdentityFieldValuesMapFromValues get identity map from fields -func GetIdentityFieldValuesMapFromValues(values []interface{}, fields []*Field) (map[string][]reflect.Value, [][]interface{}) { +func GetIdentityFieldValuesMapFromValues(ctx context.Context, values []interface{}, fields []*Field) (map[string][]reflect.Value, [][]interface{}) { resultsMap := map[string][]reflect.Value{} results := [][]interface{}{} for _, v := range values { - rm, rs := GetIdentityFieldValuesMap(reflect.Indirect(reflect.ValueOf(v)), fields) + rm, rs := GetIdentityFieldValuesMap(ctx, reflect.Indirect(reflect.ValueOf(v)), fields) for k, v := range rm { resultsMap[k] = append(resultsMap[k], v...) } diff --git a/soft_delete.go b/soft_delete.go index 4582161d..6d646288 100644 --- a/soft_delete.go +++ b/soft_delete.go @@ -104,9 +104,7 @@ func (sd SoftDeleteUpdateClause) MergeClause(*clause.Clause) { func (sd SoftDeleteUpdateClause) ModifyStatement(stmt *Statement) { if stmt.SQL.Len() == 0 && !stmt.Statement.Unscoped { - if _, ok := stmt.Clauses["WHERE"]; stmt.DB.AllowGlobalUpdate || ok { - SoftDeleteQueryClause(sd).ModifyStatement(stmt) - } + SoftDeleteQueryClause(sd).ModifyStatement(stmt) } } @@ -135,7 +133,7 @@ func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) { stmt.SetColumn(sd.Field.DBName, curTime, true) if stmt.Schema != nil { - _, queryValues := schema.GetIdentityFieldValuesMap(stmt.ReflectValue, stmt.Schema.PrimaryFields) + _, queryValues := schema.GetIdentityFieldValuesMap(stmt.Context, stmt.ReflectValue, stmt.Schema.PrimaryFields) column, values := schema.ToQueryValues(stmt.Table, stmt.Schema.PrimaryFieldDBNames, queryValues) if len(values) > 0 { @@ -143,7 +141,7 @@ func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) { } if stmt.ReflectValue.CanAddr() && stmt.Dest != stmt.Model && stmt.Model != nil { - _, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(stmt.Model), stmt.Schema.PrimaryFields) + _, queryValues = schema.GetIdentityFieldValuesMap(stmt.Context, reflect.ValueOf(stmt.Model), stmt.Schema.PrimaryFields) column, values = schema.ToQueryValues(stmt.Table, stmt.Schema.PrimaryFieldDBNames, queryValues) if len(values) > 0 { @@ -152,12 +150,7 @@ func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) { } } - if _, ok := stmt.Clauses["WHERE"]; !stmt.DB.AllowGlobalUpdate && !ok { - stmt.DB.AddError(ErrMissingWhereClause) - } else { - SoftDeleteQueryClause(sd).ModifyStatement(stmt) - } - + SoftDeleteQueryClause(sd).ModifyStatement(stmt) stmt.AddClauseIfNotExists(clause.Update{}) stmt.Build(stmt.DB.Callback().Update().Clauses...) } diff --git a/statement.go b/statement.go index 146722a9..ed3e8716 100644 --- a/statement.go +++ b/statement.go @@ -130,7 +130,7 @@ func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) { writer.WriteByte('(') for idx, d := range v { if idx > 0 { - writer.WriteString(",") + writer.WriteByte(',') } stmt.QuoteTo(writer, d) } @@ -143,7 +143,7 @@ func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) { writer.WriteByte('(') for idx, d := range v { if idx > 0 { - writer.WriteString(",") + writer.WriteByte(',') } stmt.DB.Dialector.QuoteTo(writer, d) } @@ -179,9 +179,7 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) { } else { stmt.AddVar(writer, v.GormValue(stmt.Context, stmt.DB)) } - case clause.Expr: - v.Build(stmt) - case *clause.Expr: + case clause.Expression: v.Build(stmt) case driver.Valuer: stmt.Vars = append(stmt.Vars, v) @@ -314,6 +312,10 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] case clause.Expression: conds = append(conds, v) case *DB: + for _, scope := range v.Statement.scopes { + v = scope(v) + } + if cs, ok := v.Statement.Clauses["WHERE"]; ok { if where, ok := cs.Expression.(clause.Where); ok { if len(where.Exprs) == 1 { @@ -391,7 +393,7 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] for _, field := range s.Fields { selected := selectedColumns[field.DBName] || selectedColumns[field.Name] if selected || (!restricted && field.Readable) { - if v, isZero := field.ValueOf(reflectValue); !isZero || selected { + if v, isZero := field.ValueOf(stmt.Context, reflectValue); !isZero || selected { if field.DBName != "" { conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v}) } else if field.DataType != "" { @@ -405,7 +407,7 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] for _, field := range s.Fields { selected := selectedColumns[field.DBName] || selectedColumns[field.Name] if selected || (!restricted && field.Readable) { - if v, isZero := field.ValueOf(reflectValue.Index(i)); !isZero || selected { + if v, isZero := field.ValueOf(stmt.Context, reflectValue.Index(i)); !isZero || selected { if field.DBName != "" { conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v}) } else if field.DataType != "" { @@ -564,7 +566,7 @@ func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks . switch destValue.Kind() { case reflect.Struct: - field.Set(destValue, value) + stmt.AddError(field.Set(stmt.Context, destValue, value)) default: stmt.AddError(ErrInvalidData) } @@ -574,10 +576,10 @@ func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks . case reflect.Slice, reflect.Array: if len(fromCallbacks) > 0 { for i := 0; i < stmt.ReflectValue.Len(); i++ { - field.Set(stmt.ReflectValue.Index(i), value) + stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue.Index(i), value)) } } else { - field.Set(stmt.ReflectValue.Index(stmt.CurDestIndex), value) + stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue.Index(stmt.CurDestIndex), value)) } case reflect.Struct: if !stmt.ReflectValue.CanAddr() { @@ -585,7 +587,7 @@ func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks . return } - field.Set(stmt.ReflectValue, value) + stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, value)) } } else { stmt.AddError(ErrInvalidField) @@ -605,12 +607,12 @@ func (stmt *Statement) Changed(fields ...string) bool { selectColumns, restricted := stmt.SelectAndOmitColumns(false, true) changed := func(field *schema.Field) bool { - fieldValue, _ := field.ValueOf(modelValue) + fieldValue, _ := field.ValueOf(stmt.Context, modelValue) if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { - if v, ok := stmt.Dest.(map[string]interface{}); ok { - if fv, ok := v[field.Name]; ok { + if mv, mok := stmt.Dest.(map[string]interface{}); mok { + if fv, ok := mv[field.Name]; ok { return !utils.AssertEqual(fv, fieldValue) - } else if fv, ok := v[field.DBName]; ok { + } else if fv, ok := mv[field.DBName]; ok { return !utils.AssertEqual(fv, fieldValue) } } else { @@ -619,7 +621,10 @@ func (stmt *Statement) Changed(fields ...string) bool { destValue = destValue.Elem() } - changedValue, zero := field.ValueOf(destValue) + changedValue, zero := field.ValueOf(stmt.Context, destValue) + if v { + return !utils.AssertEqual(changedValue, fieldValue) + } return !zero && !utils.AssertEqual(changedValue, fieldValue) } } diff --git a/tests/associations_test.go b/tests/associations_test.go index 5ce98c7d..e729e979 100644 --- a/tests/associations_test.go +++ b/tests/associations_test.go @@ -220,3 +220,67 @@ func TestFullSaveAssociations(t *testing.T) { t.Errorf("Failed to preload AppliesToProduct") } } + +func TestSaveBelongsCircularReference(t *testing.T) { + parent := Parent{} + DB.Create(&parent) + + child := Child{ParentID: &parent.ID, Parent: &parent} + DB.Create(&child) + + parent.FavChildID = child.ID + parent.FavChild = &child + DB.Save(&parent) + + var parent1 Parent + DB.First(&parent1, parent.ID) + AssertObjEqual(t, parent, parent1, "ID", "FavChildID") + + // Save and Updates is the same + DB.Updates(&parent) + DB.First(&parent1, parent.ID) + AssertObjEqual(t, parent, parent1, "ID", "FavChildID") +} + +func TestSaveHasManyCircularReference(t *testing.T) { + parent := Parent{} + DB.Create(&parent) + + child := Child{ParentID: &parent.ID, Parent: &parent, Name: "HasManyCircularReference"} + child1 := Child{ParentID: &parent.ID, Parent: &parent, Name: "HasManyCircularReference1"} + + parent.Children = []*Child{&child, &child1} + DB.Save(&parent) + + var children []*Child + DB.Where("parent_id = ?", parent.ID).Find(&children) + if len(children) != len(parent.Children) || + children[0].ID != parent.Children[0].ID || + children[1].ID != parent.Children[1].ID { + t.Errorf("circular reference children save not equal children:%v parent.Children:%v", + children, parent.Children) + } +} + +func TestAssociationError(t *testing.T) { + user := *GetUser("TestAssociationError", Config{Pets: 2, Company: true, Account: true, Languages: 2}) + DB.Create(&user) + + var user1 User + DB.Preload("Company").Preload("Pets").Preload("Account").Preload("Languages").First(&user1) + + var emptyUser User + var err error + // belongs to + err = DB.Model(&emptyUser).Association("Company").Delete(&user1.Company) + AssertEqual(t, err, gorm.ErrPrimaryKeyRequired) + // has many + err = DB.Model(&emptyUser).Association("Pets").Delete(&user1.Pets) + AssertEqual(t, err, gorm.ErrPrimaryKeyRequired) + // has one + err = DB.Model(&emptyUser).Association("Account").Delete(&user1.Account) + AssertEqual(t, err, gorm.ErrPrimaryKeyRequired) + // many to many + err = DB.Model(&emptyUser).Association("Languages").Delete(&user1.Languages) + AssertEqual(t, err, gorm.ErrPrimaryKeyRequired) +} diff --git a/tests/callbacks_test.go b/tests/callbacks_test.go index 02765b8c..2bf9496b 100644 --- a/tests/callbacks_test.go +++ b/tests/callbacks_test.go @@ -38,6 +38,7 @@ func c2(*gorm.DB) {} func c3(*gorm.DB) {} func c4(*gorm.DB) {} func c5(*gorm.DB) {} +func c6(*gorm.DB) {} func TestCallbacks(t *testing.T) { type callback struct { @@ -168,3 +169,37 @@ func TestCallbacks(t *testing.T) { } } } + +func TestPluginCallbacks(t *testing.T) { + db, _ := gorm.Open(nil, nil) + createCallback := db.Callback().Create() + + createCallback.Before("*").Register("plugin_1_fn1", c1) + createCallback.After("*").Register("plugin_1_fn2", c2) + + if ok, msg := assertCallbacks(createCallback, []string{"c1", "c2"}); !ok { + t.Errorf("callbacks tests failed, got %v", msg) + } + + // plugin 2 + createCallback.Before("*").Register("plugin_2_fn1", c3) + if ok, msg := assertCallbacks(createCallback, []string{"c3", "c1", "c2"}); !ok { + t.Errorf("callbacks tests failed, got %v", msg) + } + + createCallback.After("*").Register("plugin_2_fn2", c4) + if ok, msg := assertCallbacks(createCallback, []string{"c3", "c1", "c2", "c4"}); !ok { + t.Errorf("callbacks tests failed, got %v", msg) + } + + // plugin 3 + createCallback.Before("*").Register("plugin_3_fn1", c5) + if ok, msg := assertCallbacks(createCallback, []string{"c5", "c3", "c1", "c2", "c4"}); !ok { + t.Errorf("callbacks tests failed, got %v", msg) + } + + createCallback.After("*").Register("plugin_3_fn2", c6) + if ok, msg := assertCallbacks(createCallback, []string{"c5", "c3", "c1", "c2", "c4", "c6"}); !ok { + t.Errorf("callbacks tests failed, got %v", msg) + } +} diff --git a/tests/connection_test.go b/tests/connection_test.go index 92b13dd6..7bc23009 100644 --- a/tests/connection_test.go +++ b/tests/connection_test.go @@ -9,7 +9,7 @@ import ( ) func TestWithSingleConnection(t *testing.T) { - var expectedName = "test" + expectedName := "test" var actualName string setSQL, getSQL := getSetSQL(DB.Dialector.Name()) @@ -27,7 +27,6 @@ func TestWithSingleConnection(t *testing.T) { } return nil }) - if err != nil { t.Errorf(fmt.Sprintf("WithSingleConnection should work, but got err %v", err)) } diff --git a/tests/connpool_test.go b/tests/connpool_test.go new file mode 100644 index 00000000..fbae2294 --- /dev/null +++ b/tests/connpool_test.go @@ -0,0 +1,171 @@ +package tests_test + +import ( + "context" + "database/sql" + "os" + "reflect" + "testing" + + "gorm.io/driver/mysql" + "gorm.io/gorm" + . "gorm.io/gorm/utils/tests" +) + +type wrapperTx struct { + *sql.Tx + conn *wrapperConnPool +} + +func (c *wrapperTx) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) { + c.conn.got = append(c.conn.got, query) + return c.Tx.PrepareContext(ctx, query) +} + +func (c *wrapperTx) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + c.conn.got = append(c.conn.got, query) + return c.Tx.ExecContext(ctx, query, args...) +} + +func (c *wrapperTx) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { + c.conn.got = append(c.conn.got, query) + return c.Tx.QueryContext(ctx, query, args...) +} + +func (c *wrapperTx) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { + c.conn.got = append(c.conn.got, query) + return c.Tx.QueryRowContext(ctx, query, args...) +} + +type wrapperConnPool struct { + db *sql.DB + got []string + expect []string +} + +func (c *wrapperConnPool) Ping() error { + return c.db.Ping() +} + +// If you use BeginTx returned *sql.Tx as shown below then you can't record queries in a transaction. +// func (c *wrapperConnPool) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) { +// return c.db.BeginTx(ctx, opts) +// } +// You should use BeginTx returned gorm.Tx which could wrap *sql.Tx then you can record all queries. +func (c *wrapperConnPool) BeginTx(ctx context.Context, opts *sql.TxOptions) (gorm.ConnPool, error) { + tx, err := c.db.BeginTx(ctx, opts) + if err != nil { + return nil, err + } + return &wrapperTx{Tx: tx, conn: c}, nil +} + +func (c *wrapperConnPool) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) { + c.got = append(c.got, query) + return c.db.PrepareContext(ctx, query) +} + +func (c *wrapperConnPool) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + c.got = append(c.got, query) + return c.db.ExecContext(ctx, query, args...) +} + +func (c *wrapperConnPool) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { + c.got = append(c.got, query) + return c.db.QueryContext(ctx, query, args...) +} + +func (c *wrapperConnPool) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { + c.got = append(c.got, query) + return c.db.QueryRowContext(ctx, query, args...) +} + +func TestConnPoolWrapper(t *testing.T) { + dialect := os.Getenv("GORM_DIALECT") + if dialect != "mysql" { + t.SkipNow() + } + + dbDSN := os.Getenv("GORM_DSN") + if dbDSN == "" { + dbDSN = "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True&loc=Local" + } + nativeDB, err := sql.Open("mysql", dbDSN) + if err != nil { + t.Fatalf("Should open db success, but got %v", err) + } + + conn := &wrapperConnPool{ + db: nativeDB, + expect: []string{ + "SELECT VERSION()", + "INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)", + "SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1", + "INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)", + "SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1", + "SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1", + "INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)", + "SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1", + "SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1", + }, + } + + defer func() { + if !reflect.DeepEqual(conn.got, conn.expect) { + t.Errorf("expect %#v but got %#v", conn.expect, conn.got) + } + }() + + db, err := gorm.Open(mysql.New(mysql.Config{Conn: conn})) + if err != nil { + t.Fatalf("Should open db success, but got %v", err) + } + + tx := db.Begin() + user := *GetUser("transaction", Config{}) + + if err = tx.Save(&user).Error; err != nil { + t.Fatalf("No error should raise, but got %v", err) + } + + if err = tx.First(&User{}, "name = ?", "transaction").Error; err != nil { + t.Fatalf("Should find saved record, but got %v", err) + } + + user1 := *GetUser("transaction1-1", Config{}) + + if err = tx.Save(&user1).Error; err != nil { + t.Fatalf("No error should raise, but got %v", err) + } + + if err = tx.First(&User{}, "name = ?", user1.Name).Error; err != nil { + t.Fatalf("Should find saved record, but got %v", err) + } + + if sqlTx, ok := tx.Statement.ConnPool.(gorm.TxCommitter); !ok || sqlTx == nil { + t.Fatalf("Should return the underlying sql.Tx") + } + + tx.Rollback() + + if err = db.First(&User{}, "name = ?", "transaction").Error; err == nil { + t.Fatalf("Should not find record after rollback, but got %v", err) + } + + txDB := db.Where("fake_name = ?", "fake_name") + tx2 := txDB.Session(&gorm.Session{NewDB: true}).Begin() + user2 := *GetUser("transaction-2", Config{}) + if err = tx2.Save(&user2).Error; err != nil { + t.Fatalf("No error should raise, but got %v", err) + } + + if err = tx2.First(&User{}, "name = ?", "transaction-2").Error; err != nil { + t.Fatalf("Should find saved record, but got %v", err) + } + + tx2.Commit() + + if err = db.First(&User{}, "name = ?", "transaction-2").Error; err != nil { + t.Fatalf("Should be able to find committed record, but got %v", err) + } +} diff --git a/tests/count_test.go b/tests/count_test.go index 27d7ee60..b71e3de5 100644 --- a/tests/count_test.go +++ b/tests/count_test.go @@ -144,4 +144,22 @@ func TestCount(t *testing.T) { if err := DB.Model(&User{}).Where("name = ?", "count-4").Group("name").Count(&count11).Error; err != nil || count11 != 1 { t.Fatalf("Count should be 3, but got count: %v err %v", count11, err) } + + var count12 int64 + if err := DB.Table("users"). + Where("name in ?", []string{user1.Name, user2.Name, user3.Name}). + Preload("Toys", func(db *gorm.DB) *gorm.DB { + return db.Table("toys").Select("name") + }).Count(&count12).Error; err == nil { + t.Errorf("error should raise when using preload without schema") + } + + var count13 int64 + if err := DB.Model(User{}). + Where("name in ?", []string{user1.Name, user2.Name, user3.Name}). + Preload("Toys", func(db *gorm.DB) *gorm.DB { + return db.Table("toys").Select("name") + }).Count(&count13).Error; err != nil { + t.Errorf("no error should raise when using count with preload, but got %v", err) + } } diff --git a/tests/create_test.go b/tests/create_test.go index af2abdb0..3730172f 100644 --- a/tests/create_test.go +++ b/tests/create_test.go @@ -123,7 +123,7 @@ func TestCreateFromMap(t *testing.T) { {"name": "create_from_map_3", "Age": 20}, } - if err := DB.Model(&User{}).Create(datas).Error; err != nil { + if err := DB.Model(&User{}).Create(&datas).Error; err != nil { t.Fatalf("failed to create data from slice of map, got error: %v", err) } @@ -526,3 +526,17 @@ func TestCreateNilPointer(t *testing.T) { t.Fatalf("it is not ErrInvalidValue") } } + +func TestFirstOrCreateRowsAffected(t *testing.T) { + user := User{Name: "TestFirstOrCreateRowsAffected"} + + res := DB.FirstOrCreate(&user, "name = ?", user.Name) + if res.Error != nil || res.RowsAffected != 1 { + t.Fatalf("first or create rows affect err:%v rows:%d", res.Error, res.RowsAffected) + } + + res = DB.FirstOrCreate(&user, "name = ?", user.Name) + if res.Error != nil || res.RowsAffected != 0 { + t.Fatalf("first or create rows affect err:%v rows:%d", res.Error, res.RowsAffected) + } +} diff --git a/tests/default_value_test.go b/tests/default_value_test.go index 5e00b154..918f0796 100644 --- a/tests/default_value_test.go +++ b/tests/default_value_test.go @@ -2,6 +2,7 @@ package tests_test import ( "testing" + "time" "gorm.io/gorm" ) @@ -9,12 +10,13 @@ import ( func TestDefaultValue(t *testing.T) { type Harumph struct { gorm.Model - Email string `gorm:"not null;index:,unique"` - Name string `gorm:"notNull;default:foo"` - Name2 string `gorm:"size:233;not null;default:'foo'"` - Name3 string `gorm:"size:233;notNull;default:''"` - Age int `gorm:"default:18"` - Enabled bool `gorm:"default:true"` + Email string `gorm:"not null;index:,unique"` + Name string `gorm:"notNull;default:foo"` + Name2 string `gorm:"size:233;not null;default:'foo'"` + Name3 string `gorm:"size:233;notNull;default:''"` + Age int `gorm:"default:18"` + Created time.Time `gorm:"default:2000-01-02"` + Enabled bool `gorm:"default:true"` } DB.Migrator().DropTable(&Harumph{}) @@ -26,14 +28,14 @@ func TestDefaultValue(t *testing.T) { harumph := Harumph{Email: "hello@gorm.io"} if err := DB.Create(&harumph).Error; err != nil { t.Fatalf("Failed to create data with default value, got error: %v", err) - } else if harumph.Name != "foo" || harumph.Name2 != "foo" || harumph.Name3 != "" || harumph.Age != 18 || !harumph.Enabled { + } else if harumph.Name != "foo" || harumph.Name2 != "foo" || harumph.Name3 != "" || harumph.Age != 18 || !harumph.Enabled || harumph.Created.Format("20060102") != "20000102" { t.Fatalf("Failed to create data with default value, got: %+v", harumph) } var result Harumph if err := DB.First(&result, "email = ?", "hello@gorm.io").Error; err != nil { t.Fatalf("Failed to find created data, got error: %v", err) - } else if result.Name != "foo" || result.Name2 != "foo" || result.Name3 != "" || result.Age != 18 || !result.Enabled { + } else if result.Name != "foo" || result.Name2 != "foo" || result.Name3 != "" || result.Age != 18 || !result.Enabled || result.Created.Format("20060102") != "20000102" { t.Fatalf("Failed to find created data with default data, got %+v", result) } } diff --git a/tests/docker-compose.yml b/tests/docker-compose.yml index 05e0956e..9ab4ddb6 100644 --- a/tests/docker-compose.yml +++ b/tests/docker-compose.yml @@ -2,7 +2,7 @@ version: '3' services: mysql: - image: 'mysql:latest' + image: 'mysql/mysql-server:latest' ports: - 9910:3306 environment: @@ -20,7 +20,7 @@ services: - POSTGRES_USER=gorm - POSTGRES_PASSWORD=gorm mssql: - image: 'mcmoe/mssqldocker:latest' + image: '${MSSQL_IMAGE:-mcmoe/mssqldocker}:latest' ports: - 9930:1433 environment: diff --git a/tests/go.mod b/tests/go.mod index 3233ea95..6a2cf22f 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -3,16 +3,16 @@ module gorm.io/gorm/tests go 1.14 require ( + github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect github.com/google/uuid v1.3.0 - github.com/jackc/pgx/v4 v4.14.1 // indirect - github.com/jinzhu/now v1.1.4 - github.com/lib/pq v1.10.4 - golang.org/x/crypto v0.0.0-20211209193657-4570a0811e8b // indirect - gorm.io/driver/mysql v1.2.3 - gorm.io/driver/postgres v1.2.3 - gorm.io/driver/sqlite v1.2.6 - gorm.io/driver/sqlserver v1.2.1 - gorm.io/gorm v1.22.4 + github.com/jinzhu/now v1.1.5 + github.com/lib/pq v1.10.5 + golang.org/x/crypto v0.0.0-20220411220226-7b82a4e95df4 // indirect + gorm.io/driver/mysql v1.3.3 + gorm.io/driver/postgres v1.3.5 + gorm.io/driver/sqlite v1.3.2 + gorm.io/driver/sqlserver v1.3.2 + gorm.io/gorm v1.23.4 ) replace gorm.io/gorm => ../ diff --git a/tests/helper_test.go b/tests/helper_test.go index eee34e99..7ee2a576 100644 --- a/tests/helper_test.go +++ b/tests/helper_test.go @@ -19,6 +19,7 @@ type Config struct { Team int Languages int Friends int + NamedPet bool } func GetUser(name string, config Config) *User { @@ -65,6 +66,10 @@ func GetUser(name string, config Config) *User { user.Friends = append(user.Friends, GetUser(name+"_friend_"+strconv.Itoa(i+1), Config{})) } + if config.NamedPet { + user.NamedPet = &Pet{Name: name + "_namepet"} + } + return &user } diff --git a/tests/hooks_test.go b/tests/hooks_test.go index 0e6ab2fe..8e964fd8 100644 --- a/tests/hooks_test.go +++ b/tests/hooks_test.go @@ -375,13 +375,19 @@ func TestSetColumn(t *testing.T) { t.Errorf("invalid data after update, got %+v", product) } + // Code changed, price should changed + DB.Model(&product).Select("Name", "Code", "Price").Updates(Product3{Name: "Product New4", Code: ""}) + if product.Name != "Product New4" || product.Price != 320 || product.Code != "" { + t.Errorf("invalid data after update, got %+v", product) + } + DB.Model(&product).UpdateColumns(Product3{Code: "L1215"}) - if product.Price != 270 || product.Code != "L1215" { + if product.Price != 320 || product.Code != "L1215" { t.Errorf("invalid data after update, got %+v", product) } DB.Model(&product).Session(&gorm.Session{SkipHooks: true}).Updates(Product3{Code: "L1216"}) - if product.Price != 270 || product.Code != "L1216" { + if product.Price != 320 || product.Code != "L1216" { t.Errorf("invalid data after update, got %+v", product) } @@ -460,8 +466,9 @@ type Product4 struct { type ProductItem struct { gorm.Model - Code string - Product4ID uint + Code string + Product4ID uint + AfterFindCallTimes int } func (pi ProductItem) BeforeCreate(*gorm.DB) error { @@ -471,6 +478,11 @@ func (pi ProductItem) BeforeCreate(*gorm.DB) error { return nil } +func (pi *ProductItem) AfterFind(*gorm.DB) error { + pi.AfterFindCallTimes = pi.AfterFindCallTimes + 1 + return nil +} + func TestFailedToSaveAssociationShouldRollback(t *testing.T) { DB.Migrator().DropTable(&Product4{}, &ProductItem{}) DB.AutoMigrate(&Product4{}, &ProductItem{}) @@ -492,4 +504,13 @@ func TestFailedToSaveAssociationShouldRollback(t *testing.T) { if err := DB.First(&Product4{}, "name = ?", product.Name).Error; err != nil { t.Errorf("should find product, but got error %v", err) } + + var productWithItem Product4 + if err := DB.Session(&gorm.Session{SkipHooks: true}).Preload("Item").First(&productWithItem, "name = ?", product.Name).Error; err != nil { + t.Errorf("should find product, but got error %v", err) + } + + if productWithItem.Item.AfterFindCallTimes != 0 { + t.Fatalf("AfterFind should not be called times:%d", productWithItem.Item.AfterFindCallTimes) + } } diff --git a/tests/joins_test.go b/tests/joins_test.go index e276a74a..4908e5ba 100644 --- a/tests/joins_test.go +++ b/tests/joins_test.go @@ -10,12 +10,12 @@ import ( ) func TestJoins(t *testing.T) { - user := *GetUser("joins-1", Config{Company: true, Manager: true, Account: true}) + user := *GetUser("joins-1", Config{Company: true, Manager: true, Account: true, NamedPet: false}) DB.Create(&user) var user2 User - if err := DB.Joins("Company").Joins("Manager").Joins("Account").First(&user2, "users.name = ?", user.Name).Error; err != nil { + if err := DB.Joins("NamedPet").Joins("Company").Joins("Manager").Joins("Account").First(&user2, "users.name = ?", user.Name).Error; err != nil { t.Fatalf("Failed to load with joins, got error: %v", err) } @@ -158,6 +158,22 @@ func TestJoinsWithSelect(t *testing.T) { } } +func TestJoinWithOmit(t *testing.T) { + user := *GetUser("joins_with_omit", Config{Pets: 2}) + DB.Save(&user) + + results := make([]*User, 0) + + if err := DB.Table("users").Omit("name").Where("users.name = ?", "joins_with_omit").Joins("left join pets on pets.user_id = users.id").Find(&results).Error; err != nil { + return + } + + if len(results) != 2 || results[0].Name != "" || results[1].Name != "" { + t.Errorf("Should find all two pets with Join omit and should not find user's name, got %+v", results) + return + } +} + func TestJoinCount(t *testing.T) { companyA := Company{Name: "A"} companyB := Company{Name: "B"} @@ -184,3 +200,32 @@ func TestJoinCount(t *testing.T) { t.Fatalf("result's id, %d, doesn't match user's id, %d", result.ID, user.ID) } } + +func TestJoinWithSoftDeleted(t *testing.T) { + user := GetUser("TestJoinWithSoftDeletedUser", Config{Account: true, NamedPet: true}) + DB.Create(&user) + + var user1 User + DB.Model(&User{}).Joins("NamedPet").Joins("Account").First(&user1, user.ID) + if user1.NamedPet == nil || user1.Account.ID == 0 { + t.Fatalf("joins NamedPet and Account should not empty:%v", user1) + } + + // Account should empty + DB.Delete(&user1.Account) + + var user2 User + DB.Model(&User{}).Joins("NamedPet").Joins("Account").First(&user2, user.ID) + if user2.NamedPet == nil || user2.Account.ID != 0 { + t.Fatalf("joins Account should not empty:%v", user2) + } + + // NamedPet should empty + DB.Delete(&user1.NamedPet) + + var user3 User + DB.Model(&User{}).Joins("NamedPet").Joins("Account").First(&user3, user.ID) + if user3.NamedPet != nil || user2.Account.ID != 0 { + t.Fatalf("joins NamedPet and Account should not empty:%v", user2) + } +} diff --git a/tests/main_test.go b/tests/main_test.go index 5b8c7dbb..997714b9 100644 --- a/tests/main_test.go +++ b/tests/main_test.go @@ -43,10 +43,8 @@ func TestExceptionsWithInvalidSql(t *testing.T) { func TestSetAndGet(t *testing.T) { if value, ok := DB.Set("hello", "world").Get("hello"); !ok { t.Errorf("Should be able to get setting after set") - } else { - if value.(string) != "world" { - t.Errorf("Setted value should not be changed") - } + } else if value.(string) != "world" { + t.Errorf("Set value should not be changed") } if _, ok := DB.Get("non_existing"); ok { diff --git a/tests/migrate_test.go b/tests/migrate_test.go index aa0a84ab..f862eda0 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -45,7 +45,7 @@ func TestMigrate(t *testing.T) { for _, m := range allModels { if !DB.Migrator().HasTable(m) { - t.Fatalf("Failed to create table for %#v---", m) + t.Fatalf("Failed to create table for %#v", m) } } @@ -92,7 +92,7 @@ func TestAutoMigrateSelfReferential(t *testing.T) { } func TestSmartMigrateColumn(t *testing.T) { - fullSupported := map[string]bool{"mysql": true}[DB.Dialector.Name()] + fullSupported := map[string]bool{"mysql": true, "postgres": true}[DB.Dialector.Name()] type UserMigrateColumn struct { ID uint @@ -258,7 +258,7 @@ func TestMigrateTable(t *testing.T) { DB.Migrator().DropTable("new_table_structs") if DB.Migrator().HasTable(&NewTableStruct{}) { - t.Fatal("should not found droped table") + t.Fatal("should not found dropped table") } } @@ -313,9 +313,16 @@ func TestMigrateIndexes(t *testing.T) { } func TestMigrateColumns(t *testing.T) { + sqlite := DB.Dialector.Name() == "sqlite" + sqlserver := DB.Dialector.Name() == "sqlserver" + type ColumnStruct struct { gorm.Model - Name string + Name string + Age int `gorm:"default:18;comment:my age"` + Code string `gorm:"unique;comment:my code;"` + Code2 string + Code3 string `gorm:"unique"` } DB.Migrator().DropTable(&ColumnStruct{}) @@ -326,13 +333,20 @@ func TestMigrateColumns(t *testing.T) { type ColumnStruct2 struct { gorm.Model - Name string `gorm:"size:100"` + Name string `gorm:"size:100"` + Code string `gorm:"unique;comment:my code2;default:hello"` + Code2 string `gorm:"unique"` + // Code3 string } - if err := DB.Table("column_structs").Migrator().AlterColumn(&ColumnStruct2{}, "Name"); err != nil { + if err := DB.Table("column_structs").Migrator().AlterColumn(&ColumnStruct{}, "Name"); err != nil { t.Fatalf("no error should happened when alter column, but got %v", err) } + if err := DB.Table("column_structs").AutoMigrate(&ColumnStruct2{}); err != nil { + t.Fatalf("no error should happened when auto migrate column, but got %v", err) + } + if columnTypes, err := DB.Migrator().ColumnTypes(&ColumnStruct{}); err != nil { t.Fatalf("no error should returns for ColumnTypes") } else { @@ -340,11 +354,45 @@ func TestMigrateColumns(t *testing.T) { stmt.Parse(&ColumnStruct2{}) for _, columnType := range columnTypes { - if columnType.Name() == "name" { + switch columnType.Name() { + case "id": + if v, ok := columnType.PrimaryKey(); !ok || !v { + t.Fatalf("column id primary key should be correct, name: %v, column: %#v", columnType.Name(), columnType) + } + case "name": dataType := DB.Dialector.DataTypeOf(stmt.Schema.LookUpField(columnType.Name())) if !strings.Contains(strings.ToUpper(dataType), strings.ToUpper(columnType.DatabaseTypeName())) { - t.Errorf("column type should be correct, name: %v, length: %v, expects: %v", columnType.Name(), columnType.DatabaseTypeName(), dataType) + t.Fatalf("column name type should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), columnType.DatabaseTypeName(), dataType, columnType) } + if length, ok := columnType.Length(); !sqlite && (!ok || length != 100) { + t.Fatalf("column name length should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), length, 100, columnType) + } + case "age": + if v, ok := columnType.DefaultValue(); !ok || v != "18" { + t.Fatalf("column age default value should be correct, name: %v, column: %#v", columnType.Name(), columnType) + } + if v, ok := columnType.Comment(); !sqlite && !sqlserver && (!ok || v != "my age") { + t.Fatalf("column age comment should be correct, name: %v, column: %#v", columnType.Name(), columnType) + } + case "code": + if v, ok := columnType.Unique(); !ok || !v { + t.Fatalf("column code unique should be correct, name: %v, column: %#v", columnType.Name(), columnType) + } + if v, ok := columnType.DefaultValue(); !sqlserver && (!ok || v != "hello") { + t.Fatalf("column code default value should be correct, name: %v, column: %#v", columnType.Name(), columnType) + } + if v, ok := columnType.Comment(); !sqlite && !sqlserver && (!ok || v != "my code2") { + t.Fatalf("column code comment should be correct, name: %v, column: %#v", columnType.Name(), columnType) + } + case "code2": + if v, ok := columnType.Unique(); !sqlserver && (!ok || !v) { + t.Fatalf("column code2 unique should be correct, name: %v, column: %#v", columnType.Name(), columnType) + } + case "code3": + // TODO + // if v, ok := columnType.Unique(); !ok || v { + // t.Fatalf("column code unique should be correct, name: %v, column: %#v", columnType.Name(), columnType) + // } } } } @@ -526,3 +574,122 @@ func TestMigrateColumnOrder(t *testing.T) { } } } + +// https://github.com/go-gorm/gorm/issues/5047 +func TestMigrateSerialColumn(t *testing.T) { + if DB.Dialector.Name() != "postgres" { + return + } + + type Event struct { + ID uint `gorm:"primarykey"` + UID uint32 + } + + type Event1 struct { + ID uint `gorm:"primarykey"` + UID uint32 `gorm:"not null;autoIncrement"` + } + + type Event2 struct { + ID uint `gorm:"primarykey"` + UID uint16 `gorm:"not null;autoIncrement"` + } + + var err error + err = DB.Migrator().DropTable(&Event{}) + if err != nil { + t.Errorf("DropTable err:%v", err) + } + + // create sequence + err = DB.Table("events").AutoMigrate(&Event1{}) + if err != nil { + t.Errorf("AutoMigrate err:%v", err) + } + + // delete sequence + err = DB.Table("events").AutoMigrate(&Event{}) + if err != nil { + t.Errorf("AutoMigrate err:%v", err) + } + + // update sequence + err = DB.Table("events").AutoMigrate(&Event1{}) + if err != nil { + t.Errorf("AutoMigrate err:%v", err) + } + err = DB.Table("events").AutoMigrate(&Event2{}) + if err != nil { + t.Errorf("AutoMigrate err:%v", err) + } + + DB.Table("events").Save(&Event2{}) + DB.Table("events").Save(&Event2{}) + DB.Table("events").Save(&Event2{}) + + events := make([]*Event, 0) + DB.Table("events").Find(&events) + + AssertEqual(t, 3, len(events)) + for _, v := range events { + AssertEqual(t, v.ID, v.UID) + } +} + +// https://github.com/go-gorm/gorm/issues/5300 +func TestMigrateWithSpecialName(t *testing.T) { + var err error + err = DB.AutoMigrate(&Coupon{}) + if err != nil { + t.Fatalf("AutoMigrate err:%v", err) + } + err = DB.Table("coupon_product_1").AutoMigrate(&CouponProduct{}) + if err != nil { + t.Fatalf("AutoMigrate err:%v", err) + } + err = DB.Table("coupon_product_2").AutoMigrate(&CouponProduct{}) + if err != nil { + t.Fatalf("AutoMigrate err:%v", err) + } + + AssertEqual(t, true, DB.Migrator().HasTable("coupons")) + AssertEqual(t, true, DB.Migrator().HasTable("coupon_product_1")) + AssertEqual(t, true, DB.Migrator().HasTable("coupon_product_2")) +} + +// https://github.com/go-gorm/gorm/issues/5320 +func TestPrimarykeyID(t *testing.T) { + if DB.Dialector.Name() != "postgres" { + return + } + + type MissPKLanguage struct { + ID string `gorm:"type:uuid;default:uuid_generate_v4()"` + Name string + } + + type MissPKUser struct { + ID string `gorm:"type:uuid;default:uuid_generate_v4()"` + MissPKLanguages []MissPKLanguage `gorm:"many2many:miss_pk_user_languages;"` + } + + var err error + err = DB.Migrator().DropTable(&MissPKUser{}, &MissPKLanguage{}) + if err != nil { + t.Fatalf("DropTable err:%v", err) + } + + DB.Exec(`CREATE EXTENSION IF NOT EXISTS "uuid-ossp";`) + + err = DB.AutoMigrate(&MissPKUser{}, &MissPKLanguage{}) + if err != nil { + t.Fatalf("AutoMigrate err:%v", err) + } + + // patch + err = DB.AutoMigrate(&MissPKUser{}, &MissPKLanguage{}) + if err != nil { + t.Fatalf("AutoMigrate err:%v", err) + } +} diff --git a/tests/postgres_test.go b/tests/postgres_test.go index 85671864..66b988c3 100644 --- a/tests/postgres_test.go +++ b/tests/postgres_test.go @@ -2,6 +2,7 @@ package tests_test import ( "testing" + "time" "github.com/google/uuid" "github.com/lib/pq" @@ -15,9 +16,11 @@ func TestPostgres(t *testing.T) { type Harumph struct { gorm.Model - Name string `gorm:"check:name_checker,name <> ''"` - Test uuid.UUID `gorm:"type:uuid;not null;default:gen_random_uuid()"` - Things pq.StringArray `gorm:"type:text[]"` + Name string `gorm:"check:name_checker,name <> ''"` + Test uuid.UUID `gorm:"type:uuid;not null;default:gen_random_uuid()"` + CreatedAt time.Time `gorm:"type:TIMESTAMP WITHOUT TIME ZONE"` + UpdatedAt time.Time `gorm:"type:TIMESTAMP WITHOUT TIME ZONE;default:current_timestamp"` + Things pq.StringArray `gorm:"type:text[]"` } if err := DB.Exec("CREATE EXTENSION IF NOT EXISTS pgcrypto;").Error; err != nil { @@ -48,6 +51,15 @@ func TestPostgres(t *testing.T) { if err := DB.Where("id = $1", harumph.ID).First(&Harumph{}).Error; err != nil || harumph.Name != "jinzhu" { t.Errorf("No error should happen, but got %v", err) } + + harumph.Name = "jinzhu1" + if err := DB.Save(&harumph).Error; err != nil { + t.Errorf("Failed to update date, got error %v", err) + } + + if err := DB.First(&result, "id = ?", harumph.ID).Error; err != nil || harumph.Name != "jinzhu1" { + t.Errorf("No error should happen, but got %v", err) + } } type Post struct { diff --git a/tests/preload_suits_test.go b/tests/preload_suits_test.go index 0ef8890b..b5b6a70f 100644 --- a/tests/preload_suits_test.go +++ b/tests/preload_suits_test.go @@ -1335,7 +1335,7 @@ func TestNilPointerSlice(t *testing.T) { } if !reflect.DeepEqual(got[0], want) && !reflect.DeepEqual(got[1], want) { - t.Errorf("got %s; want array containing %s", toJSONString(got), toJSONString(want)) + t.Fatalf("got %s; want array containing %s", toJSONString(got), toJSONString(want)) } if !reflect.DeepEqual(got[0], want2) && !reflect.DeepEqual(got[1], want2) { diff --git a/tests/preload_test.go b/tests/preload_test.go index adb54ee1..cb4343ec 100644 --- a/tests/preload_test.go +++ b/tests/preload_test.go @@ -251,3 +251,21 @@ func TestPreloadGoroutine(t *testing.T) { } wg.Wait() } + +func TestPreloadWithDiffModel(t *testing.T) { + user := *GetUser("preload_with_diff_model", Config{Account: true}) + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + var result struct { + Something string + User + } + + DB.Model(User{}).Preload("Account", clause.Eq{Column: "number", Value: user.Account.Number}).Select( + "users.*, 'yo' as something").First(&result, "name = ?", user.Name) + + CheckUser(t, user, result.User) +} diff --git a/tests/query_test.go b/tests/query_test.go index c99214b6..f66cf83a 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -292,6 +292,68 @@ func TestFindInBatches(t *testing.T) { } } +func TestFindInBatchesWithOffsetLimit(t *testing.T) { + users := []User{ + *GetUser("find_in_batches_with_offset_limit", Config{}), + *GetUser("find_in_batches_with_offset_limit", Config{}), + *GetUser("find_in_batches_with_offset_limit", Config{}), + *GetUser("find_in_batches_with_offset_limit", Config{}), + *GetUser("find_in_batches_with_offset_limit", Config{}), + *GetUser("find_in_batches_with_offset_limit", Config{}), + *GetUser("find_in_batches_with_offset_limit", Config{}), + *GetUser("find_in_batches_with_offset_limit", Config{}), + *GetUser("find_in_batches_with_offset_limit", Config{}), + *GetUser("find_in_batches_with_offset_limit", Config{}), + } + + DB.Create(&users) + + var ( + sub, results []User + lastBatch int + ) + + // offset limit + if result := DB.Offset(3).Limit(5).Where("name = ?", users[0].Name).FindInBatches(&sub, 2, func(tx *gorm.DB, batch int) error { + results = append(results, sub...) + lastBatch = batch + return nil + }); result.Error != nil || result.RowsAffected != 5 { + t.Errorf("Failed to batch find, got error %v, rows affected: %v", result.Error, result.RowsAffected) + } + if lastBatch != 3 { + t.Fatalf("incorrect last batch, expected: %v, got: %v", 3, lastBatch) + } + + targetUsers := users[3:8] + for i := 0; i < len(targetUsers); i++ { + AssertEqual(t, results[i], targetUsers[i]) + } + + var sub1 []User + // limit < batchSize + if result := DB.Limit(5).Where("name = ?", users[0].Name).FindInBatches(&sub1, 10, func(tx *gorm.DB, batch int) error { + return nil + }); result.Error != nil || result.RowsAffected != 5 { + t.Errorf("Failed to batch find, got error %v, rows affected: %v", result.Error, result.RowsAffected) + } + + var sub2 []User + // only offset + if result := DB.Offset(3).Where("name = ?", users[0].Name).FindInBatches(&sub2, 2, func(tx *gorm.DB, batch int) error { + return nil + }); result.Error != nil || result.RowsAffected != 7 { + t.Errorf("Failed to batch find, got error %v, rows affected: %v", result.Error, result.RowsAffected) + } + + var sub3 []User + if result := DB.Limit(4).Where("name = ?", users[0].Name).FindInBatches(&sub3, 2, func(tx *gorm.DB, batch int) error { + return nil + }); result.Error != nil || result.RowsAffected != 4 { + t.Errorf("Failed to batch find, got error %v, rows affected: %v", result.Error, result.RowsAffected) + } +} + func TestFindInBatchesWithError(t *testing.T) { if name := DB.Dialector.Name(); name == "sqlserver" { t.Skip("skip sqlserver due to it will raise data race for invalid sql") @@ -512,7 +574,13 @@ func TestNotWithAllFields(t *testing.T) { func TestOr(t *testing.T) { dryDB := DB.Session(&gorm.Session{DryRun: true}) - result := dryDB.Where("role = ?", "admin").Where(DB.Or("role = ?", "super_admin")).Find(&User{}) + var count int64 + result := dryDB.Model(&User{}).Or("role = ?", "admin").Count(&count) + if !regexp.MustCompile("SELECT count\\(\\*\\) FROM .*users.* WHERE role = .+ AND .*users.*\\..*deleted_at.* IS NULL").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build OR condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Where("role = ?", "admin").Where(DB.Or("role = ?", "super_admin")).Find(&User{}) if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*role.* = .+ AND .*role.* = .+").MatchString(result.Statement.SQL.String()) { t.Fatalf("Build OR condition, but got %v", result.Statement.SQL.String()) } @@ -577,7 +645,9 @@ func TestPluck(t *testing.T) { if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Order("name desc").Pluck("name", &names2).Error; err != nil { t.Errorf("got error when pluck name: %v", err) } - AssertEqual(t, names, sort.Reverse(sort.StringSlice(names2))) + + sort.Slice(names2, func(i, j int) bool { return names2[i] < names2[j] }) + AssertEqual(t, names, names2) var ids []int if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Pluck("id", &ids).Error; err != nil { @@ -1152,3 +1222,39 @@ func TestQueryWithTableAndConditionsAndAllFields(t *testing.T) { t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String()) } } + +type DoubleInt64 struct { + data int64 +} + +func (t *DoubleInt64) Scan(val interface{}) error { + switch v := val.(type) { + case int64: + t.data = v * 2 + return nil + default: + return fmt.Errorf("DoubleInt64 cant not scan with:%v", v) + } +} + +// https://github.com/go-gorm/gorm/issues/5091 +func TestQueryScannerWithSingleColumn(t *testing.T) { + user := User{Name: "scanner_raw_1", Age: 10} + DB.Create(&user) + + var result1 DoubleInt64 + if err := DB.Model(&User{}).Where("name LIKE ?", "scanner_raw_%").Limit(1).Pluck( + "age", &result1).Error; err != nil { + t.Errorf("Failed, got error: %v", err) + } + + AssertEqual(t, result1.data, 20) + + var result2 DoubleInt64 + if err := DB.Model(&User{}).Where("name LIKE ?", "scanner_raw_%").Limit(1).Select( + "age").Scan(&result2).Error; err != nil { + t.Errorf("Failed, got error: %v", err) + } + + AssertEqual(t, result2.data, 20) +} diff --git a/tests/scan_test.go b/tests/scan_test.go index 1a188fac..425c0a29 100644 --- a/tests/scan_test.go +++ b/tests/scan_test.go @@ -10,6 +10,11 @@ import ( . "gorm.io/gorm/utils/tests" ) +type PersonAddressInfo struct { + Person *Person `gorm:"embedded"` + Address *Address `gorm:"embedded"` +} + func TestScan(t *testing.T) { user1 := User{Name: "ScanUser1", Age: 1} user2 := User{Name: "ScanUser2", Age: 10} @@ -156,3 +161,57 @@ func TestScanRows(t *testing.T) { t.Fatalf("failed to scan ages, got error %v, ages: %v", err, name) } } + +func TestScanToEmbedded(t *testing.T) { + person1 := Person{Name: "person 1"} + person2 := Person{Name: "person 2"} + DB.Save(&person1).Save(&person2) + + address1 := Address{Name: "address 1"} + address2 := Address{Name: "address 2"} + DB.Save(&address1).Save(&address2) + + DB.Create(&PersonAddress{PersonID: person1.ID, AddressID: int(address1.ID)}) + DB.Create(&PersonAddress{PersonID: person1.ID, AddressID: int(address2.ID)}) + DB.Create(&PersonAddress{PersonID: person2.ID, AddressID: int(address1.ID)}) + + var personAddressInfoList []*PersonAddressInfo + if err := DB.Select("people.*, addresses.*"). + Table("people"). + Joins("inner join person_addresses on people.id = person_addresses.person_id"). + Joins("inner join addresses on person_addresses.address_id = addresses.id"). + Find(&personAddressInfoList).Error; err != nil { + t.Errorf("Failed to run join query, got error: %v", err) + } + + personMatched := false + addressMatched := false + + for _, info := range personAddressInfoList { + if info.Person == nil { + t.Fatalf("Failed, expected not nil, got person nil") + } + if info.Address == nil { + t.Fatalf("Failed, expected not nil, got address nil") + } + if info.Person.ID == person1.ID { + personMatched = true + if info.Person.Name != person1.Name { + t.Errorf("Failed, expected %v, got %v", person1.Name, info.Person.Name) + } + } + if info.Address.ID == address1.ID { + addressMatched = true + if info.Address.Name != address1.Name { + t.Errorf("Failed, expected %v, got %v", address1.Name, info.Address.Name) + } + } + } + + if !personMatched { + t.Errorf("Failed, no person matched") + } + if !addressMatched { + t.Errorf("Failed, no address matched") + } +} diff --git a/tests/serializer_test.go b/tests/serializer_test.go new file mode 100644 index 00000000..ee14841a --- /dev/null +++ b/tests/serializer_test.go @@ -0,0 +1,138 @@ +package tests_test + +import ( + "bytes" + "context" + "fmt" + "reflect" + "strings" + "testing" + "time" + + "gorm.io/gorm" + "gorm.io/gorm/schema" + . "gorm.io/gorm/utils/tests" +) + +type SerializerStruct struct { + gorm.Model + Name []byte `gorm:"json"` + Roles Roles `gorm:"serializer:json"` + Contracts map[string]interface{} `gorm:"serializer:json"` + JobInfo Job `gorm:"type:bytes;serializer:gob"` + CreatedTime int64 `gorm:"serializer:unixtime;type:time"` // store time in db, use int as field type + UpdatedTime *int64 `gorm:"serializer:unixtime;type:time"` // store time in db, use int as field type + EncryptedString EncryptedString +} + +type Roles []string + +type Job struct { + Title string + Number int + Location string + IsIntern bool +} + +type EncryptedString string + +func (es *EncryptedString) Scan(ctx context.Context, field *schema.Field, dst reflect.Value, dbValue interface{}) (err error) { + switch value := dbValue.(type) { + case []byte: + *es = EncryptedString(bytes.TrimPrefix(value, []byte("hello"))) + case string: + *es = EncryptedString(strings.TrimPrefix(value, "hello")) + default: + return fmt.Errorf("unsupported data %#v", dbValue) + } + return nil +} + +func (es EncryptedString) Value(ctx context.Context, field *schema.Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) { + return "hello" + string(es), nil +} + +func TestSerializer(t *testing.T) { + DB.Migrator().DropTable(&SerializerStruct{}) + if err := DB.Migrator().AutoMigrate(&SerializerStruct{}); err != nil { + t.Fatalf("no error should happen when migrate scanner, valuer struct, got error %v", err) + } + + createdAt := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC) + updatedAt := createdAt.Unix() + + data := SerializerStruct{ + Name: []byte("jinzhu"), + Roles: []string{"r1", "r2"}, + Contracts: map[string]interface{}{"name": "jinzhu", "age": 10}, + EncryptedString: EncryptedString("pass"), + CreatedTime: createdAt.Unix(), + UpdatedTime: &updatedAt, + JobInfo: Job{ + Title: "programmer", + Number: 9920, + Location: "Kenmawr", + IsIntern: false, + }, + } + + if err := DB.Create(&data).Error; err != nil { + t.Fatalf("failed to create data, got error %v", err) + } + + var result SerializerStruct + if err := DB.First(&result, data.ID).Error; err != nil { + t.Fatalf("failed to query data, got error %v", err) + } + + AssertEqual(t, result, data) + +} + +func TestSerializerAssignFirstOrCreate(t *testing.T) { + DB.Migrator().DropTable(&SerializerStruct{}) + if err := DB.Migrator().AutoMigrate(&SerializerStruct{}); err != nil { + t.Fatalf("no error should happen when migrate scanner, valuer struct, got error %v", err) + } + + createdAt := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC) + + data := SerializerStruct{ + Name: []byte("ag9920"), + Roles: []string{"r1", "r2"}, + Contracts: map[string]interface{}{"name": "jing1", "age": 11}, + EncryptedString: EncryptedString("pass"), + CreatedTime: createdAt.Unix(), + JobInfo: Job{ + Title: "programmer", + Number: 9920, + Location: "Shadyside", + IsIntern: false, + }, + } + + // first time insert record + out := SerializerStruct{} + if err := DB.Assign(data).FirstOrCreate(&out).Error; err != nil { + t.Fatalf("failed to FirstOrCreate Assigned data, got error %v", err) + } + + var result SerializerStruct + if err := DB.First(&result, out.ID).Error; err != nil { + t.Fatalf("failed to query data, got error %v", err) + } + AssertEqual(t, result, out) + + //update record + data.Roles = append(data.Roles, "r3") + data.JobInfo.Location = "Gates Hillman Complex" + if err := DB.Assign(data).FirstOrCreate(&out).Error; err != nil { + t.Fatalf("failed to FirstOrCreate Assigned data, got error %v", err) + } + if err := DB.First(&result, out.ID).Error; err != nil { + t.Fatalf("failed to query data, got error %v", err) + } + + AssertEqual(t, result.Roles, data.Roles) + AssertEqual(t, result.JobInfo.Location, data.JobInfo.Location) +} diff --git a/tests/sql_builder_test.go b/tests/sql_builder_test.go index 237d807b..a9b920dc 100644 --- a/tests/sql_builder_test.go +++ b/tests/sql_builder_test.go @@ -168,6 +168,59 @@ func TestDryRun(t *testing.T) { } } +type ageInt int8 + +func (ageInt) String() string { + return "age" +} + +type ageBool bool + +func (ageBool) String() string { + return "age" +} + +type ageUint64 uint64 + +func (ageUint64) String() string { + return "age" +} + +type ageFloat float64 + +func (ageFloat) String() string { + return "age" +} + +func TestExplainSQL(t *testing.T) { + user := *GetUser("explain-sql", Config{}) + dryRunDB := DB.Session(&gorm.Session{DryRun: true}) + + stmt := dryRunDB.Model(&user).Where("id = ?", 1).Updates(map[string]interface{}{"age": ageInt(8)}).Statement + sql := DB.Dialector.Explain(stmt.SQL.String(), stmt.Vars...) + if !regexp.MustCompile(`.*age.*=8,`).MatchString(sql) { + t.Errorf("Failed to generate sql, got %v", sql) + } + + stmt = dryRunDB.Model(&user).Where("id = ?", 1).Updates(map[string]interface{}{"age": ageUint64(10241024)}).Statement + sql = DB.Dialector.Explain(stmt.SQL.String(), stmt.Vars...) + if !regexp.MustCompile(`.*age.*=10241024,`).MatchString(sql) { + t.Errorf("Failed to generate sql, got %v", sql) + } + + stmt = dryRunDB.Model(&user).Where("id = ?", 1).Updates(map[string]interface{}{"age": ageBool(false)}).Statement + sql = DB.Dialector.Explain(stmt.SQL.String(), stmt.Vars...) + if !regexp.MustCompile(`.*age.*=false,`).MatchString(sql) { + t.Errorf("Failed to generate sql, got %v", sql) + } + + stmt = dryRunDB.Model(&user).Where("id = ?", 1).Updates(map[string]interface{}{"age": ageFloat(0.12345678)}).Statement + sql = DB.Dialector.Explain(stmt.SQL.String(), stmt.Vars...) + if !regexp.MustCompile(`.*age.*=0.123457,`).MatchString(sql) { + t.Errorf("Failed to generate sql, got %v", sql) + } +} + func TestGroupConditions(t *testing.T) { type Pizza struct { ID uint @@ -190,6 +243,21 @@ func TestGroupConditions(t *testing.T) { if !strings.HasSuffix(result, expects) { t.Errorf("expects: %v, got %v", expects, result) } + + stmt2 := dryRunDB.Where( + DB.Scopes(NameIn1And2), + ).Or( + DB.Where("pizza = ?", "hawaiian").Where("size = ?", "xlarge"), + ).Find(&Pizza{}).Statement + + execStmt2 := dryRunDB.Exec(`WHERE name in ? OR (pizza = ? AND size = ?)`, []string{"ScopeUser1", "ScopeUser2"}, "hawaiian", "xlarge").Statement + + result2 := DB.Dialector.Explain(stmt2.SQL.String(), stmt2.Vars...) + expects2 := DB.Dialector.Explain(execStmt2.SQL.String(), execStmt2.Vars...) + + if !strings.HasSuffix(result2, expects2) { + t.Errorf("expects: %v, got %v", expects2, result2) + } } func TestCombineStringConditions(t *testing.T) { @@ -307,7 +375,7 @@ func TestToSQL(t *testing.T) { }) assertEqualSQL(t, `SELECT * FROM "users" WHERE id = 100 AND "users"."deleted_at" IS NULL ORDER BY age desc LIMIT 10`, sql) - // after model chagned + // after model changed if DB.Statement.DryRun || DB.DryRun { t.Fatal("Failed expect DB.DryRun and DB.Statement.ToSQL to be false") } @@ -373,13 +441,13 @@ func TestToSQL(t *testing.T) { }) assertEqualSQL(t, `UPDATE "users" SET "name"='Foo',"age"=100 WHERE id = 100 AND "users"."deleted_at" IS NULL`, sql) - // after model chagned + // after model changed if DB.Statement.DryRun || DB.DryRun { t.Fatal("Failed expect DB.DryRun and DB.Statement.ToSQL to be false") } } -// assertEqualSQL for assert that the sql is equal, this method will ignore quote, and dialect speicals. +// assertEqualSQL for assert that the sql is equal, this method will ignore quote, and dialect specials. func assertEqualSQL(t *testing.T, expected string, actually string) { t.Helper() @@ -387,7 +455,7 @@ func assertEqualSQL(t *testing.T, expected string, actually string) { expected = replaceQuoteInSQL(expected) actually = replaceQuoteInSQL(actually) - // ignore updated_at value, becase it's generated in Gorm inernal, can't to mock value on update. + // ignore updated_at value, because it's generated in Gorm internal, can't to mock value on update. updatedAtRe := regexp.MustCompile(`(?i)"updated_at"=".+?"`) actually = updatedAtRe.ReplaceAllString(actually, `"updated_at"=?`) expected = updatedAtRe.ReplaceAllString(expected, `"updated_at"=?`) @@ -407,16 +475,16 @@ func assertEqualSQL(t *testing.T, expected string, actually string) { func replaceQuoteInSQL(sql string) string { // convert single quote into double quote - sql = strings.Replace(sql, `'`, `"`, -1) + sql = strings.ReplaceAll(sql, `'`, `"`) - // convert dialect speical quote into double quote + // convert dialect special quote into double quote switch DB.Dialector.Name() { case "postgres": - sql = strings.Replace(sql, `"`, `"`, -1) + sql = strings.ReplaceAll(sql, `"`, `"`) case "mysql", "sqlite": - sql = strings.Replace(sql, "`", `"`, -1) + sql = strings.ReplaceAll(sql, "`", `"`) case "sqlserver": - sql = strings.Replace(sql, `'`, `"`, -1) + sql = strings.ReplaceAll(sql, `'`, `"`) } return sql diff --git a/tests/tests_all.sh b/tests/tests_all.sh index 79e0b5b7..5b9bae97 100755 --- a/tests/tests_all.sh +++ b/tests/tests_all.sh @@ -15,6 +15,25 @@ then cd .. fi +# SqlServer for Mac M1 +if [[ -z $GITHUB_ACTION ]]; then + if [ -d tests ] + then + cd tests + if [[ $(uname -a) == *" arm64" ]]; then + MSSQL_IMAGE=mcr.microsoft.com/azure-sql-edge docker-compose start || true + go install github.com/microsoft/go-sqlcmd/cmd/sqlcmd@latest || true + SQLCMDPASSWORD=LoremIpsum86 sqlcmd -U sa -S localhost:9930 -Q "IF DB_ID('gorm') IS NULL CREATE DATABASE gorm" > /dev/null || true + SQLCMDPASSWORD=LoremIpsum86 sqlcmd -U sa -S localhost:9930 -Q "IF SUSER_ID (N'gorm') IS NULL CREATE LOGIN gorm WITH PASSWORD = 'LoremIpsum86';" > /dev/null || true + SQLCMDPASSWORD=LoremIpsum86 sqlcmd -U sa -S localhost:9930 -Q "IF USER_ID (N'gorm') IS NULL CREATE USER gorm FROM LOGIN gorm; ALTER SERVER ROLE sysadmin ADD MEMBER [gorm];" > /dev/null || true + else + docker-compose start + fi + cd .. + fi +fi + + for dialect in "${dialects[@]}" ; do if [ "$GORM_DIALECT" = "" ] || [ "$GORM_DIALECT" = "${dialect}" ] then diff --git a/tests/tests_test.go b/tests/tests_test.go index e26f358d..08f4f193 100644 --- a/tests/tests_test.go +++ b/tests/tests_test.go @@ -62,13 +62,14 @@ func OpenTestConnection() (db *gorm.DB, err error) { PreferSimpleProtocol: true, }), &gorm.Config{}) case "sqlserver": - // CREATE LOGIN gorm WITH PASSWORD = 'LoremIpsum86'; + // go install github.com/microsoft/go-sqlcmd/cmd/sqlcmd@latest + // SQLCMDPASSWORD=LoremIpsum86 sqlcmd -U sa -S localhost:9930 // CREATE DATABASE gorm; - // USE gorm; + // GO + // CREATE LOGIN gorm WITH PASSWORD = 'LoremIpsum86'; // CREATE USER gorm FROM LOGIN gorm; - // sp_changedbowner 'gorm'; - // npm install -g sql-cli - // mssql -u gorm -p LoremIpsum86 -d gorm -o 9930 + // ALTER SERVER ROLE sysadmin ADD MEMBER [gorm]; + // GO log.Println("testing sqlserver...") if dbDSN == "" { dbDSN = "sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm" @@ -94,7 +95,7 @@ func OpenTestConnection() (db *gorm.DB, err error) { func RunMigrations() { var err error - allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Coupon{}, &CouponProduct{}, &Order{}} + allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Coupon{}, &CouponProduct{}, &Order{}, &Parent{}, &Child{}} rand.Seed(time.Now().UnixNano()) rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] }) diff --git a/tests/update_test.go b/tests/update_test.go index b471ba9b..41ea5d27 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -645,7 +645,7 @@ func TestSave(t *testing.T) { dryDB := DB.Session(&gorm.Session{DryRun: true}) stmt := dryDB.Save(&user).Statement - if !regexp.MustCompile(`.id. = .* AND .users.\..deleted_at. IS NULL`).MatchString(stmt.SQL.String()) { + if !regexp.MustCompile(`.users.\..deleted_at. IS NULL`).MatchString(stmt.SQL.String()) { t.Fatalf("invalid updating SQL, got %v", stmt.SQL.String()) } diff --git a/tests/upsert_test.go b/tests/upsert_test.go index c5d19605..f90c4518 100644 --- a/tests/upsert_test.go +++ b/tests/upsert_test.go @@ -319,7 +319,7 @@ func TestUpdateWithMissWhere(t *testing.T) { tx := DB.Session(&gorm.Session{DryRun: true}).Save(&user) if err := tx.Error; err != nil { - t.Fatalf("failed to update user,missing where condtion,err=%+v", err) + t.Fatalf("failed to update user,missing where condition,err=%+v", err) } if !regexp.MustCompile("WHERE .id. = [^ ]+$").MatchString(tx.Statement.SQL.String()) { diff --git a/utils/tests/dummy_dialecter.go b/utils/tests/dummy_dialecter.go index 9543f750..2990c20f 100644 --- a/utils/tests/dummy_dialecter.go +++ b/utils/tests/dummy_dialecter.go @@ -49,7 +49,7 @@ func (DummyDialector) QuoteTo(writer clause.Writer, str string) { shiftDelimiter = 0 underQuoted = false continuousBacktick = 0 - writer.WriteString("`") + writer.WriteByte('`') } writer.WriteByte(v) continue @@ -74,7 +74,7 @@ func (DummyDialector) QuoteTo(writer clause.Writer, str string) { if continuousBacktick > 0 && !selfQuoted { writer.WriteString("``") } - writer.WriteString("`") + writer.WriteByte('`') } func (DummyDialector) Explain(sql string, vars ...interface{}) string { diff --git a/utils/tests/models.go b/utils/tests/models.go index c84f9cae..22e8e659 100644 --- a/utils/tests/models.go +++ b/utils/tests/models.go @@ -80,3 +80,17 @@ type Order struct { Coupon *Coupon CouponID string } + +type Parent struct { + gorm.Model + FavChildID uint + FavChild *Child + Children []*Child +} + +type Child struct { + gorm.Model + Name string + ParentID *uint + Parent *Parent +} diff --git a/utils/tests/utils.go b/utils/tests/utils.go index 817e4b0b..661d727f 100644 --- a/utils/tests/utils.go +++ b/utils/tests/utils.go @@ -83,20 +83,22 @@ func AssertEqual(t *testing.T, got, expect interface{}) { } if reflect.ValueOf(got).Kind() == reflect.Struct { - if reflect.ValueOf(got).NumField() == reflect.ValueOf(expect).NumField() { - exported := false - for i := 0; i < reflect.ValueOf(got).NumField(); i++ { - if fieldStruct := reflect.ValueOf(got).Type().Field(i); ast.IsExported(fieldStruct.Name) { - exported = true - field := reflect.ValueOf(got).Field(i) - t.Run(fieldStruct.Name, func(t *testing.T) { - AssertEqual(t, field.Interface(), reflect.ValueOf(expect).Field(i).Interface()) - }) + if reflect.ValueOf(expect).Kind() == reflect.Struct { + if reflect.ValueOf(got).NumField() == reflect.ValueOf(expect).NumField() { + exported := false + for i := 0; i < reflect.ValueOf(got).NumField(); i++ { + if fieldStruct := reflect.ValueOf(got).Type().Field(i); ast.IsExported(fieldStruct.Name) { + exported = true + field := reflect.ValueOf(got).Field(i) + t.Run(fieldStruct.Name, func(t *testing.T) { + AssertEqual(t, field.Interface(), reflect.ValueOf(expect).Field(i).Interface()) + }) + } } - } - if exported { - return + if exported { + return + } } } } @@ -107,6 +109,9 @@ func AssertEqual(t *testing.T, got, expect interface{}) { } else if reflect.ValueOf(expect).Type().ConvertibleTo(reflect.ValueOf(got).Type()) { expect = reflect.ValueOf(got).Convert(reflect.ValueOf(got).Type()).Interface() isEqual() + } else { + t.Errorf("%v: expect: %+v, got %+v", utils.FileWithLineNum(), expect, got) + return } } } diff --git a/utils/utils.go b/utils/utils.go index f00f92ba..296917b9 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -36,17 +36,14 @@ func IsValidDBNameChar(c rune) bool { return !unicode.IsLetter(c) && !unicode.IsNumber(c) && c != '.' && c != '*' && c != '_' && c != '$' && c != '@' } -func CheckTruth(val interface{}) bool { - if v, ok := val.(bool); ok { - return v +// CheckTruth check string true or not +func CheckTruth(vals ...string) bool { + for _, val := range vals { + if val != "" && !strings.EqualFold(val, "false") { + return true + } } - - if v, ok := val.(string); ok { - v = strings.ToLower(v) - return v != "false" - } - - return !reflect.ValueOf(val).IsZero() + return false } func ToStringKey(values ...interface{}) string {