diff --git a/.github/workflows/invalid_question.yml b/.github/workflows/invalid_question.yml index aa1812d4..fbebfc12 100644 --- a/.github/workflows/invalid_question.yml +++ b/.github/workflows/invalid_question.yml @@ -16,7 +16,7 @@ jobs: ACTIONS_STEP_DEBUG: true steps: - name: Close Stale Issues - uses: actions/stale@v5 + uses: actions/stale@v8 with: repo-token: ${{ secrets.GITHUB_TOKEN }} stale-issue-message: "This issue has been marked as invalid question, please give more information by following the `Question` template, if you believe there is a bug of GORM, please create a pull request that could reproduce the issue on [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground), the issue will be closed in 30 days if no further activity occurs. most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨" diff --git a/.github/workflows/labeler.yml b/.github/workflows/labeler.yml index 0e8aaa60..ef852765 100644 --- a/.github/workflows/labeler.yml +++ b/.github/workflows/labeler.yml @@ -11,7 +11,7 @@ jobs: name: Label issues and pull requests steps: - name: check out - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: labeler uses: jinzhu/super-labeler-action@develop diff --git a/.github/workflows/missing_playground.yml b/.github/workflows/missing_playground.yml index c3c92beb..b23a5bf9 100644 --- a/.github/workflows/missing_playground.yml +++ b/.github/workflows/missing_playground.yml @@ -16,7 +16,7 @@ jobs: ACTIONS_STEP_DEBUG: true steps: - name: Close Stale Issues - uses: actions/stale@v5 + uses: actions/stale@v8 with: repo-token: ${{ secrets.GITHUB_TOKEN }} stale-issue-message: "The issue has been automatically marked as stale as it missing playground pull request link, which is important to help others understand your issue effectively and make sure the issue hasn't been fixed on latest master, checkout [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground) for details. it will be closed in 30 days if no further activity occurs. if you are asking question, please use the `Question` template, most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨" diff --git a/.github/workflows/reviewdog.yml b/.github/workflows/reviewdog.yml index a6542d57..3a65f0bc 100644 --- a/.github/workflows/reviewdog.yml +++ b/.github/workflows/reviewdog.yml @@ -6,7 +6,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Check out code into the Go module directory - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: golangci-lint uses: reviewdog/action-golangci-lint@v2 diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index af8d3636..c9752883 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -16,7 +16,7 @@ jobs: ACTIONS_STEP_DEBUG: true steps: - name: Close Stale Issues - uses: actions/stale@v5 + uses: actions/stale@v8 with: repo-token: ${{ secrets.GITHUB_TOKEN }} stale-issue-message: "This issue has been automatically marked as stale because it has been open 360 days with no activity. Remove stale label or comment or this will be closed in 180 days" diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index b97da3f4..af471d20 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -16,21 +16,21 @@ jobs: sqlite: strategy: matrix: - go: ['1.18', '1.17', '1.16'] + go: ['1.21', '1.20', '1.19'] platform: [ubuntu-latest] # can not run in windows OS runs-on: ${{ matrix.platform }} steps: - name: Set up Go 1.x - uses: actions/setup-go@v3 + uses: actions/setup-go@v4 with: go-version: ${{ matrix.go }} - name: Check out code into the Go module directory - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: go mod package cache - uses: actions/cache@v3 + uses: actions/cache@v4 with: path: ~/go/pkg/mod key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} @@ -41,8 +41,8 @@ jobs: mysql: strategy: matrix: - dbversion: ['mysql:latest', 'mysql:5.7', 'mariadb:latest'] - go: ['1.18', '1.17', '1.16'] + dbversion: ['mysql:latest', 'mysql:5.7'] + go: ['1.21', '1.20', '1.19'] platform: [ubuntu-latest] runs-on: ${{ matrix.platform }} @@ -65,16 +65,15 @@ jobs: steps: - name: Set up Go 1.x - uses: actions/setup-go@v3 + uses: actions/setup-go@v4 with: go-version: ${{ matrix.go }} - name: Check out code into the Go module directory - uses: actions/checkout@v3 - + uses: actions/checkout@v4 - name: go mod package cache - uses: actions/cache@v3 + uses: actions/cache@v4 with: path: ~/go/pkg/mod key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} @@ -82,11 +81,54 @@ jobs: - name: Tests run: GITHUB_ACTION=true GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True" ./tests/tests_all.sh + mariadb: + strategy: + matrix: + dbversion: [ 'mariadb:latest' ] + go: ['1.21', '1.20', '1.19'] + platform: [ ubuntu-latest ] + runs-on: ${{ matrix.platform }} + + services: + mysql: + image: ${{ matrix.dbversion }} + env: + MYSQL_DATABASE: gorm + MYSQL_USER: gorm + MYSQL_PASSWORD: gorm + MYSQL_RANDOM_ROOT_PASSWORD: "yes" + ports: + - 9910:3306 + options: >- + --health-cmd "mariadb-admin ping -ugorm -pgorm" + --health-interval 10s + --health-start-period 10s + --health-timeout 5s + --health-retries 10 + + steps: + - name: Set up Go 1.x + uses: actions/setup-go@v4 + with: + go-version: ${{ matrix.go }} + + - name: Check out code into the Go module directory + uses: actions/checkout@v4 + + - name: go mod package cache + uses: actions/cache@v4 + with: + path: ~/go/pkg/mod + key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} + + - name: Tests + run: GITHUB_ACTION=true GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True" ./tests/tests_all.sh + postgres: strategy: matrix: dbversion: ['postgres:latest', 'postgres:13', 'postgres:12', 'postgres:11', 'postgres:10'] - go: ['1.18', '1.17', '1.16'] + go: ['1.21', '1.20', '1.19'] platform: [ubuntu-latest] # can not run in macOS and Windows runs-on: ${{ matrix.platform }} @@ -109,15 +151,15 @@ jobs: steps: - name: Set up Go 1.x - uses: actions/setup-go@v3 + uses: actions/setup-go@v4 with: go-version: ${{ matrix.go }} - name: Check out code into the Go module directory - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: go mod package cache - uses: actions/cache@v3 + uses: actions/cache@v4 with: path: ~/go/pkg/mod key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} @@ -128,7 +170,7 @@ jobs: sqlserver: strategy: matrix: - go: ['1.18', '1.17', '1.16'] + go: ['1.21', '1.20', '1.19'] platform: [ubuntu-latest] # can not run test in macOS and windows runs-on: ${{ matrix.platform }} @@ -152,18 +194,51 @@ jobs: steps: - name: Set up Go 1.x - uses: actions/setup-go@v3 + uses: actions/setup-go@v4 with: go-version: ${{ matrix.go }} - name: Check out code into the Go module directory - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: go mod package cache - uses: actions/cache@v3 + uses: actions/cache@v4 with: path: ~/go/pkg/mod key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} - name: Tests run: GITHUB_ACTION=true GORM_DIALECT=sqlserver GORM_DSN="sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm" ./tests/tests_all.sh + + tidb: + strategy: + matrix: + dbversion: [ 'v6.5.0' ] + go: ['1.21', '1.20', '1.19'] + platform: [ ubuntu-latest ] + runs-on: ${{ matrix.platform }} + + steps: + - name: Setup TiDB + uses: Icemap/tidb-action@main + with: + port: 9940 + version: ${{matrix.dbversion}} + + - name: Set up Go 1.x + uses: actions/setup-go@v4 + with: + go-version: ${{ matrix.go }} + + - name: Check out code into the Go module directory + uses: actions/checkout@v4 + + + - name: go mod package cache + uses: actions/cache@v4 + with: + path: ~/go/pkg/mod + key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} + + - name: Tests + run: GITHUB_ACTION=true GORM_DIALECT=tidb GORM_DSN="root:@tcp(localhost:9940)/test?charset=utf8&parseTime=True&loc=Local" ./tests/tests_all.sh diff --git a/.gitignore b/.gitignore index 45505cc9..72733326 100644 --- a/.gitignore +++ b/.gitignore @@ -3,4 +3,5 @@ documents coverage.txt _book .idea -vendor \ No newline at end of file +vendor +.vscode diff --git a/.golangci.yml b/.golangci.yml index 16903ed6..b88bf672 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -9,3 +9,12 @@ linters: - prealloc - unconvert - unparam + - goimports + - whitespace + +linters-settings: + whitespace: + multi-func: true + goimports: + local-prefixes: gorm.io/gorm + diff --git a/License b/LICENSE similarity index 100% rename from License rename to LICENSE diff --git a/README.md b/README.md index 312a3a59..745dad60 100644 --- a/README.md +++ b/README.md @@ -4,9 +4,6 @@ The fantastic ORM library for Golang, aims to be developer friendly. [![go report card](https://goreportcard.com/badge/github.com/go-gorm/gorm "go report card")](https://goreportcard.com/report/github.com/go-gorm/gorm) [![test status](https://github.com/go-gorm/gorm/workflows/tests/badge.svg?branch=master "test status")](https://github.com/go-gorm/gorm/actions) -[![Join the chat at https://gitter.im/jinzhu/gorm](https://img.shields.io/gitter/room/jinzhu/gorm.svg)](https://gitter.im/jinzhu/gorm?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) -[![Open Collective Backer](https://opencollective.com/gorm/tiers/backer/badge.svg?label=backer&color=brightgreen "Open Collective Backer")](https://opencollective.com/gorm) -[![Open Collective Sponsor](https://opencollective.com/gorm/tiers/sponsor/badge.svg?label=sponsor&color=brightgreen "Open Collective Sponsor")](https://opencollective.com/gorm) [![MIT license](https://img.shields.io/badge/license-MIT-brightgreen.svg)](https://opensource.org/licenses/MIT) [![Go.Dev reference](https://img.shields.io/badge/go.dev-reference-blue?logo=go&logoColor=white)](https://pkg.go.dev/gorm.io/gorm?tab=doc) @@ -30,14 +27,18 @@ The fantastic ORM library for Golang, aims to be developer friendly. ## Getting Started * GORM Guides [https://gorm.io](https://gorm.io) -* GORM Gen [gorm/gen](https://github.com/go-gorm/gen#gormgen) +* Gen Guides [https://gorm.io/gen/index.html](https://gorm.io/gen/index.html) ## Contributing [You can help to deliver a better GORM, check out things you can do](https://gorm.io/contribute.html) +## Contributors + +[Thank you](https://github.com/go-gorm/gorm/graphs/contributors) for contributing to the GORM framework! + ## License © Jinzhu, 2013~time.Now -Released under the [MIT License](https://github.com/go-gorm/gorm/blob/master/License) +Released under the [MIT License](https://github.com/go-gorm/gorm/blob/master/LICENSE) diff --git a/association.go b/association.go index 35e10ddd..7c93ebea 100644 --- a/association.go +++ b/association.go @@ -14,6 +14,7 @@ import ( type Association struct { DB *DB Relationship *schema.Relationship + Unscope bool Error error } @@ -40,6 +41,15 @@ func (db *DB) Association(column string) *Association { return association } +func (association *Association) Unscoped() *Association { + return &Association{ + DB: association.DB, + Relationship: association.Relationship, + Error: association.Error, + Unscope: true, + } +} + func (association *Association) Find(out interface{}, conds ...interface{}) error { if association.Error == nil { association.Error = association.buildCondition().Find(out, conds...).Error @@ -64,14 +74,30 @@ func (association *Association) Append(values ...interface{}) error { func (association *Association) Replace(values ...interface{}) error { if association.Error == nil { + reflectValue := association.DB.Statement.ReflectValue + rel := association.Relationship + + var oldBelongsToExpr clause.Expression + // we have to record the old BelongsTo value + if association.Unscope && rel.Type == schema.BelongsTo { + var foreignFields []*schema.Field + for _, ref := range rel.References { + if !ref.OwnPrimaryKey { + foreignFields = append(foreignFields, ref.ForeignKey) + } + } + if _, fvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, foreignFields); len(fvs) > 0 { + column, values := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, fvs) + oldBelongsToExpr = clause.IN{Column: column, Values: values} + } + } + // save associations if association.saveAssociation( /*clear*/ true, values...); association.Error != nil { return association.Error } // set old associations's foreign key to null - reflectValue := association.DB.Statement.ReflectValue - rel := association.Relationship switch rel.Type { case schema.BelongsTo: if len(values) == 0 { @@ -91,6 +117,9 @@ func (association *Association) Replace(values ...interface{}) error { association.Error = association.DB.UpdateColumns(updateMap).Error } + if association.Unscope && oldBelongsToExpr != nil { + association.Error = association.DB.Model(nil).Where(oldBelongsToExpr).Delete(reflect.New(rel.FieldSchema.ModelType).Interface()).Error + } case schema.HasOne, schema.HasMany: var ( primaryFields []*schema.Field @@ -119,7 +148,11 @@ func (association *Association) Replace(values ...interface{}) error { if _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields); len(pvs) > 0 { column, values := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs) - association.Error = tx.Where(clause.IN{Column: column, Values: values}).UpdateColumns(updateMap).Error + if association.Unscope { + association.Error = tx.Where(clause.IN{Column: column, Values: values}).Delete(modelValue).Error + } else { + association.Error = tx.Where(clause.IN{Column: column, Values: values}).UpdateColumns(updateMap).Error + } } case schema.Many2Many: var ( @@ -184,7 +217,8 @@ func (association *Association) Delete(values ...interface{}) error { switch rel.Type { case schema.BelongsTo: - tx := association.DB.Model(reflect.New(rel.Schema.ModelType).Interface()) + associationDB := association.DB.Session(&Session{}) + tx := associationDB.Model(reflect.New(rel.Schema.ModelType).Interface()) _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, rel.Schema.PrimaryFields) if pcolumn, pvalues := schema.ToQueryValues(rel.Schema.Table, rel.Schema.PrimaryFieldDBNames, pvs); len(pvalues) > 0 { @@ -198,8 +232,21 @@ func (association *Association) Delete(values ...interface{}) error { conds = append(conds, clause.IN{Column: relColumn, Values: relValues}) association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error + if association.Unscope { + var foreignFields []*schema.Field + for _, ref := range rel.References { + if !ref.OwnPrimaryKey { + foreignFields = append(foreignFields, ref.ForeignKey) + } + } + if _, fvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, foreignFields); len(fvs) > 0 { + column, values := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, fvs) + association.Error = associationDB.Model(nil).Where(clause.IN{Column: column, Values: values}).Delete(reflect.New(rel.FieldSchema.ModelType).Interface()).Error + } + } case schema.HasOne, schema.HasMany: - tx := association.DB.Model(reflect.New(rel.FieldSchema.ModelType).Interface()) + model := reflect.New(rel.FieldSchema.ModelType).Interface() + tx := association.DB.Model(model) _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields) if pcolumn, pvalues := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs); len(pvalues) > 0 { @@ -212,7 +259,11 @@ func (association *Association) Delete(values ...interface{}) error { relColumn, relValues := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs) conds = append(conds, clause.IN{Column: relColumn, Values: relValues}) - association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error + if association.Unscope { + association.Error = tx.Clauses(conds...).Delete(model).Error + } else { + association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error + } case schema.Many2Many: var ( primaryFields, relPrimaryFields []*schema.Field @@ -353,9 +404,13 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ } case schema.HasMany, schema.Many2Many: elemType := association.Relationship.Field.IndirectFieldType.Elem() - fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(association.DB.Statement.Context, source)) + oldFieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(association.DB.Statement.Context, source)) + var fieldValue reflect.Value if clear { - fieldValue = reflect.New(association.Relationship.Field.IndirectFieldType).Elem() + fieldValue = reflect.MakeSlice(oldFieldValue.Type(), 0, oldFieldValue.Cap()) + } else { + fieldValue = reflect.MakeSlice(oldFieldValue.Type(), oldFieldValue.Len(), oldFieldValue.Cap()) + reflect.Copy(fieldValue, oldFieldValue) } appendToFieldValues := func(ev reflect.Value) { @@ -507,7 +562,9 @@ func (association *Association) buildCondition() *DB { joinStmt.AddClause(queryClause) } joinStmt.Build("WHERE") - tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars}) + if len(joinStmt.SQL.String()) > 0 { + tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars}) + } } tx = tx.Session(&Session{QueryFields: true}).Clauses(clause.From{Joins: []clause.Join{{ diff --git a/callbacks.go b/callbacks.go index f344649e..50b5b0e9 100644 --- a/callbacks.go +++ b/callbacks.go @@ -75,11 +75,7 @@ func (cs *callbacks) Raw() *processor { func (p *processor) Execute(db *DB) *DB { // call scopes for len(db.Statement.scopes) > 0 { - scopes := db.Statement.scopes - db.Statement.scopes = nil - for _, scope := range scopes { - db = scope(db) - } + db = db.executeScopes() } var ( @@ -93,6 +89,10 @@ func (p *processor) Execute(db *DB) *DB { resetBuildClauses = true } + if optimizer, ok := db.Statement.Dest.(StatementModifier); ok { + optimizer.ModifyStatement(stmt) + } + // assign model values if stmt.Model == nil { stmt.Model = stmt.Dest @@ -132,7 +132,11 @@ func (p *processor) Execute(db *DB) *DB { if stmt.SQL.Len() > 0 { db.Logger.Trace(stmt.Context, curTime, func() (string, int64) { - return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...), db.RowsAffected + sql, vars := stmt.SQL.String(), stmt.Vars + if filter, ok := db.Logger.(ParamsFilter); ok { + sql, vars = filter.ParamsFilter(stmt.Context, stmt.SQL.String(), stmt.Vars...) + } + return db.Dialector.Explain(sql, vars...), db.RowsAffected }, db.Error) } @@ -183,10 +187,18 @@ func (p *processor) Replace(name string, fn func(*DB)) error { func (p *processor) compile() (err error) { var callbacks []*callback + removedMap := map[string]bool{} for _, callback := range p.callbacks { if callback.match == nil || callback.match(p.db) { callbacks = append(callbacks, callback) } + if callback.remove { + removedMap[callback.name] = true + } + } + + if len(removedMap) > 0 { + callbacks = removeCallbacks(callbacks, removedMap) } p.callbacks = callbacks @@ -245,8 +257,14 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) { names, sorted []string sortCallback func(*callback) error ) - sort.Slice(cs, func(i, j int) bool { - return cs[j].before == "*" || cs[j].after == "*" + sort.SliceStable(cs, func(i, j int) bool { + if cs[j].before == "*" && cs[i].before != "*" { + return true + } + if cs[j].after == "*" && cs[i].after != "*" { + return true + } + return false }) for _, c := range cs { @@ -329,3 +347,14 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) { return } + +func removeCallbacks(cs []*callback, nameMap map[string]bool) []*callback { + callbacks := make([]*callback, 0, len(cs)) + for _, callback := range cs { + if nameMap[callback.name] { + continue + } + callbacks = append(callbacks, callback) + } + return callbacks +} diff --git a/callbacks/associations.go b/callbacks/associations.go index fd3141cf..f3cd464a 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -51,25 +51,40 @@ func SaveBeforeAssociations(create bool) func(db *gorm.DB) { } elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) + distinctElems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) + identityMap := map[string]bool{} for i := 0; i < rValLen; i++ { obj := db.Statement.ReflectValue.Index(i) if reflect.Indirect(obj).Kind() != reflect.Struct { break } - if _, zero := rel.Field.ValueOf(db.Statement.Context, obj); !zero { // check belongs to relation value rv := rel.Field.ReflectValueOf(db.Statement.Context, obj) // relation reflect value + if !isPtr { + rv = rv.Addr() + } objs = append(objs, obj) - if isPtr { - elems = reflect.Append(elems, rv) - } else { - elems = reflect.Append(elems, rv.Addr()) + elems = reflect.Append(elems, rv) + + relPrimaryValues := make([]interface{}, 0, len(rel.FieldSchema.PrimaryFields)) + for _, pf := range rel.FieldSchema.PrimaryFields { + if pfv, ok := pf.ValueOf(db.Statement.Context, rv); !ok { + relPrimaryValues = append(relPrimaryValues, pfv) + } + } + cacheKey := utils.ToStringKey(relPrimaryValues...) + if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] { + if cacheKey != "" { // has primary fields + identityMap[cacheKey] = true + } + + distinctElems = reflect.Append(distinctElems, rv) } } } if elems.Len() > 0 { - if saveAssociations(db, rel, elems, selectColumns, restricted, nil) == nil { + if saveAssociations(db, rel, distinctElems, selectColumns, restricted, nil) == nil { for i := 0; i < elems.Len(); i++ { setupReferences(objs[i], elems.Index(i)) } @@ -206,9 +221,12 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { } } - cacheKey := utils.ToStringKey(relPrimaryValues) + cacheKey := utils.ToStringKey(relPrimaryValues...) if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] { - identityMap[cacheKey] = true + if cacheKey != "" { // has primary fields + identityMap[cacheKey] = true + } + if isPtr { elems = reflect.Append(elems, elem) } else { @@ -253,6 +271,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { fieldType = reflect.PtrTo(fieldType) } elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) + distinctElems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) joins := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(rel.JoinTable.ModelType)), 0, 10) objs := []reflect.Value{} @@ -272,19 +291,34 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { joins = reflect.Append(joins, joinValue) } + identityMap := map[string]bool{} appendToElems := func(v reflect.Value) { if _, zero := rel.Field.ValueOf(db.Statement.Context, v); !zero { f := reflect.Indirect(rel.Field.ReflectValueOf(db.Statement.Context, v)) - for i := 0; i < f.Len(); i++ { elem := f.Index(i) - - objs = append(objs, v) - if isPtr { - elems = reflect.Append(elems, elem) - } else { - elems = reflect.Append(elems, elem.Addr()) + if !isPtr { + elem = elem.Addr() } + objs = append(objs, v) + elems = reflect.Append(elems, elem) + + relPrimaryValues := make([]interface{}, 0, len(rel.FieldSchema.PrimaryFields)) + for _, pf := range rel.FieldSchema.PrimaryFields { + if pfv, ok := pf.ValueOf(db.Statement.Context, elem); !ok { + relPrimaryValues = append(relPrimaryValues, pfv) + } + } + + cacheKey := utils.ToStringKey(relPrimaryValues...) + if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] { + if cacheKey != "" { // has primary fields + identityMap[cacheKey] = true + } + + distinctElems = reflect.Append(distinctElems, elem) + } + } } } @@ -304,7 +338,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { // optimize elems of reflect value length if elemLen := elems.Len(); elemLen > 0 { if v, ok := selectColumns[rel.Name+".*"]; !ok || v { - saveAssociations(db, rel, elems, selectColumns, restricted, nil) + saveAssociations(db, rel, distinctElems, selectColumns, restricted, nil) } for i := 0; i < elemLen; i++ { diff --git a/callbacks/callmethod.go b/callbacks/callmethod.go index bcaa03f3..fb900037 100644 --- a/callbacks/callmethod.go +++ b/callbacks/callmethod.go @@ -13,11 +13,20 @@ func callMethod(db *gorm.DB, fc func(value interface{}, tx *gorm.DB) bool) { case reflect.Slice, reflect.Array: db.Statement.CurDestIndex = 0 for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - fc(reflect.Indirect(db.Statement.ReflectValue.Index(i)).Addr().Interface(), tx) + if value := reflect.Indirect(db.Statement.ReflectValue.Index(i)); value.CanAddr() { + fc(value.Addr().Interface(), tx) + } else { + db.AddError(gorm.ErrInvalidValue) + return + } db.Statement.CurDestIndex++ } case reflect.Struct: - fc(db.Statement.ReflectValue.Addr().Interface(), tx) + if db.Statement.ReflectValue.CanAddr() { + fc(db.Statement.ReflectValue.Addr().Interface(), tx) + } else { + db.AddError(gorm.ErrInvalidValue) + } } } } diff --git a/callbacks/create.go b/callbacks/create.go index 0fe1dc93..afea2cca 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -3,6 +3,7 @@ package callbacks import ( "fmt" "reflect" + "strings" "gorm.io/gorm" "gorm.io/gorm/clause" @@ -102,13 +103,62 @@ func Create(config *Config) func(db *gorm.DB) { } db.RowsAffected, _ = result.RowsAffected() - if db.RowsAffected != 0 && db.Statement.Schema != nil && - db.Statement.Schema.PrioritizedPrimaryField != nil && - db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue { - insertID, err := result.LastInsertId() - insertOk := err == nil && insertID > 0 - if !insertOk { + if db.RowsAffected == 0 { + return + } + + var ( + pkField *schema.Field + pkFieldName = "@id" + ) + + insertID, err := result.LastInsertId() + insertOk := err == nil && insertID > 0 + + if !insertOk { + if !supportReturning { db.AddError(err) + } + return + } + + if db.Statement.Schema != nil { + if db.Statement.Schema.PrioritizedPrimaryField == nil || !db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue { + return + } + pkField = db.Statement.Schema.PrioritizedPrimaryField + pkFieldName = db.Statement.Schema.PrioritizedPrimaryField.DBName + } + + // append @id column with value for auto-increment primary key + // the @id value is correct, when: 1. without setting auto-increment primary key, 2. database AutoIncrementIncrement = 1 + switch values := db.Statement.Dest.(type) { + case map[string]interface{}: + values[pkFieldName] = insertID + case *map[string]interface{}: + (*values)[pkFieldName] = insertID + case []map[string]interface{}, *[]map[string]interface{}: + mapValues, ok := values.([]map[string]interface{}) + if !ok { + if v, ok := values.(*[]map[string]interface{}); ok { + if *v != nil { + mapValues = *v + } + } + } + + if config.LastInsertIDReversed { + insertID -= int64(len(mapValues)-1) * schema.DefaultAutoIncrementIncrement + } + + for _, mapValue := range mapValues { + if mapValue != nil { + mapValue[pkFieldName] = insertID + } + insertID += schema.DefaultAutoIncrementIncrement + } + default: + if pkField == nil { return } @@ -121,10 +171,10 @@ func Create(config *Config) func(db *gorm.DB) { break } - _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv) + _, isZero := pkField.ValueOf(db.Statement.Context, rv) if isZero { - db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, insertID)) - insertID -= db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement + db.AddError(pkField.Set(db.Statement.Context, rv, insertID)) + insertID -= pkField.AutoIncrementIncrement } } } else { @@ -134,16 +184,16 @@ func Create(config *Config) func(db *gorm.DB) { break } - if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv); isZero { - db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, insertID)) - insertID += db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement + if _, isZero := pkField.ValueOf(db.Statement.Context, rv); isZero { + db.AddError(pkField.Set(db.Statement.Context, rv, insertID)) + insertID += pkField.AutoIncrementIncrement } } } case reflect.Struct: - _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, db.Statement.ReflectValue) + _, isZero := pkField.ValueOf(db.Statement.Context, db.Statement.ReflectValue) if isZero { - db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, db.Statement.ReflectValue, insertID)) + db.AddError(pkField.Set(db.Statement.Context, db.Statement.ReflectValue, insertID)) } } } @@ -252,13 +302,15 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { } } - for field, vs := range defaultValueFieldsHavingValue { - values.Columns = append(values.Columns, clause.Column{Name: field.DBName}) - for idx := range values.Values { - if vs[idx] == nil { - values.Values[idx] = append(values.Values[idx], stmt.Dialector.DefaultValueOf(field)) - } else { - values.Values[idx] = append(values.Values[idx], vs[idx]) + for _, field := range stmt.Schema.FieldsWithDefaultDBValue { + if vs, ok := defaultValueFieldsHavingValue[field]; ok { + values.Columns = append(values.Columns, clause.Column{Name: field.DBName}) + for idx := range values.Values { + if vs[idx] == nil { + values.Values[idx] = append(values.Values[idx], stmt.Dialector.DefaultValueOf(field)) + } else { + values.Values[idx] = append(values.Values[idx], vs[idx]) + } } } } @@ -302,14 +354,15 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { for _, column := range values.Columns { if field := stmt.Schema.LookUpField(column.Name); field != nil { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { - if !field.PrimaryKey && (!field.HasDefaultValue || field.DefaultValueInterface != nil) && field.AutoCreateTime == 0 { + if !field.PrimaryKey && (!field.HasDefaultValue || field.DefaultValueInterface != nil || + strings.EqualFold(field.DefaultValue, "NULL")) && field.AutoCreateTime == 0 { if field.AutoUpdateTime > 0 { assignment := clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: curTime} switch field.AutoUpdateTime { case schema.UnixNanosecond: assignment.Value = curTime.UnixNano() case schema.UnixMillisecond: - assignment.Value = curTime.UnixNano() / 1e6 + assignment.Value = curTime.UnixMilli() case schema.UnixSecond: assignment.Value = curTime.Unix() } diff --git a/callbacks/create_test.go b/callbacks/create_test.go new file mode 100644 index 00000000..da6b172b --- /dev/null +++ b/callbacks/create_test.go @@ -0,0 +1,71 @@ +package callbacks + +import ( + "reflect" + "sync" + "testing" + "time" + + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" +) + +var schemaCache = &sync.Map{} + +func TestConvertToCreateValues_DestType_Slice(t *testing.T) { + type user struct { + ID int `gorm:"primaryKey"` + Name string + Email string `gorm:"default:(-)"` + Age int `gorm:"default:(-)"` + } + + s, err := schema.Parse(&user{}, schemaCache, schema.NamingStrategy{}) + if err != nil { + t.Errorf("parse schema error: %v, is not expected", err) + return + } + dest := []*user{ + { + ID: 1, + Name: "alice", + Email: "email", + Age: 18, + }, + { + ID: 2, + Name: "bob", + Email: "email", + Age: 19, + }, + } + stmt := &gorm.Statement{ + DB: &gorm.DB{ + Config: &gorm.Config{ + NowFunc: func() time.Time { return time.Time{} }, + }, + Statement: &gorm.Statement{ + Settings: sync.Map{}, + Schema: s, + }, + }, + ReflectValue: reflect.ValueOf(dest), + Dest: dest, + } + + stmt.Schema = s + + values := ConvertToCreateValues(stmt) + expected := clause.Values{ + // column has value + defaultValue column has value (which should have a stable order) + Columns: []clause.Column{{Name: "name"}, {Name: "email"}, {Name: "age"}, {Name: "id"}}, + Values: [][]interface{}{ + {"alice", "email", 18, 1}, + {"bob", "email", 19, 2}, + }, + } + if !reflect.DeepEqual(expected, values) { + t.Errorf("expected: %v got %v", expected, values) + } +} diff --git a/callbacks/helper_test.go b/callbacks/helper_test.go new file mode 100644 index 00000000..08f94e20 --- /dev/null +++ b/callbacks/helper_test.go @@ -0,0 +1,157 @@ +package callbacks + +import ( + "reflect" + "testing" + + "gorm.io/gorm" + "gorm.io/gorm/clause" +) + +func TestLoadOrStoreVisitMap(t *testing.T) { + var vm visitMap + var loaded bool + type testM struct { + Name string + } + + t1 := testM{Name: "t1"} + t2 := testM{Name: "t2"} + t3 := testM{Name: "t3"} + + vm = make(visitMap) + if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf(&t1)); loaded { + t.Fatalf("loaded should be false") + } + + if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf(&t1)); !loaded { + t.Fatalf("loaded should be true") + } + + // t1 already exist but t2 not + if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf([]*testM{&t1, &t2, &t3})); loaded { + t.Fatalf("loaded should be false") + } + + if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf([]*testM{&t2, &t3})); !loaded { + t.Fatalf("loaded should be true") + } +} + +func TestConvertMapToValuesForCreate(t *testing.T) { + testCase := []struct { + name string + input map[string]interface{} + expect clause.Values + }{ + { + name: "Test convert string value", + input: map[string]interface{}{ + "name": "my name", + }, + expect: clause.Values{ + Columns: []clause.Column{{Name: "name"}}, + Values: [][]interface{}{{"my name"}}, + }, + }, + { + name: "Test convert int value", + input: map[string]interface{}{ + "age": 18, + }, + expect: clause.Values{ + Columns: []clause.Column{{Name: "age"}}, + Values: [][]interface{}{{18}}, + }, + }, + { + name: "Test convert float value", + input: map[string]interface{}{ + "score": 99.5, + }, + expect: clause.Values{ + Columns: []clause.Column{{Name: "score"}}, + Values: [][]interface{}{{99.5}}, + }, + }, + { + name: "Test convert bool value", + input: map[string]interface{}{ + "active": true, + }, + expect: clause.Values{ + Columns: []clause.Column{{Name: "active"}}, + Values: [][]interface{}{{true}}, + }, + }, + } + + for _, tc := range testCase { + t.Run(tc.name, func(t *testing.T) { + actual := ConvertMapToValuesForCreate(&gorm.Statement{}, tc.input) + if !reflect.DeepEqual(actual, tc.expect) { + t.Errorf("expect %v got %v", tc.expect, actual) + } + }) + } +} + +func TestConvertSliceOfMapToValuesForCreate(t *testing.T) { + testCase := []struct { + name string + input []map[string]interface{} + expect clause.Values + }{ + { + name: "Test convert slice of string value", + input: []map[string]interface{}{ + {"name": "my name"}, + }, + expect: clause.Values{ + Columns: []clause.Column{{Name: "name"}}, + Values: [][]interface{}{{"my name"}}, + }, + }, + { + name: "Test convert slice of int value", + input: []map[string]interface{}{ + {"age": 18}, + }, + expect: clause.Values{ + Columns: []clause.Column{{Name: "age"}}, + Values: [][]interface{}{{18}}, + }, + }, + { + name: "Test convert slice of float value", + input: []map[string]interface{}{ + {"score": 99.5}, + }, + expect: clause.Values{ + Columns: []clause.Column{{Name: "score"}}, + Values: [][]interface{}{{99.5}}, + }, + }, + { + name: "Test convert slice of bool value", + input: []map[string]interface{}{ + {"active": true}, + }, + expect: clause.Values{ + Columns: []clause.Column{{Name: "active"}}, + Values: [][]interface{}{{true}}, + }, + }, + } + + for _, tc := range testCase { + t.Run(tc.name, func(t *testing.T) { + actual := ConvertSliceOfMapToValuesForCreate(&gorm.Statement{}, tc.input) + + if !reflect.DeepEqual(actual, tc.expect) { + t.Errorf("expected %v but got %v", tc.expect, actual) + } + }) + } + +} diff --git a/callbacks/preload.go b/callbacks/preload.go index ea2570ba..cf7a0d2b 100644 --- a/callbacks/preload.go +++ b/callbacks/preload.go @@ -3,6 +3,8 @@ package callbacks import ( "fmt" "reflect" + "sort" + "strings" "gorm.io/gorm" "gorm.io/gorm/clause" @@ -10,6 +12,164 @@ import ( "gorm.io/gorm/utils" ) +// parsePreloadMap extracts nested preloads. e.g. +// +// // schema has a "k0" relation and a "k7.k8" embedded relation +// parsePreloadMap(schema, map[string][]interface{}{ +// clause.Associations: {"arg1"}, +// "k1": {"arg2"}, +// "k2.k3": {"arg3"}, +// "k4.k5.k6": {"arg4"}, +// }) +// // preloadMap is +// map[string]map[string][]interface{}{ +// "k0": {}, +// "k7": { +// "k8": {}, +// }, +// "k1": {}, +// "k2": { +// "k3": {"arg3"}, +// }, +// "k4": { +// "k5.k6": {"arg4"}, +// }, +// } +func parsePreloadMap(s *schema.Schema, preloads map[string][]interface{}) map[string]map[string][]interface{} { + preloadMap := map[string]map[string][]interface{}{} + setPreloadMap := func(name, value string, args []interface{}) { + if _, ok := preloadMap[name]; !ok { + preloadMap[name] = map[string][]interface{}{} + } + if value != "" { + preloadMap[name][value] = args + } + } + + for name, args := range preloads { + preloadFields := strings.Split(name, ".") + value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), ".") + if preloadFields[0] == clause.Associations { + for _, relation := range s.Relationships.Relations { + if relation.Schema == s { + setPreloadMap(relation.Name, value, args) + } + } + + for embedded, embeddedRelations := range s.Relationships.EmbeddedRelations { + for _, value := range embeddedValues(embeddedRelations) { + setPreloadMap(embedded, value, args) + } + } + } else { + setPreloadMap(preloadFields[0], value, args) + } + } + return preloadMap +} + +func embeddedValues(embeddedRelations *schema.Relationships) []string { + if embeddedRelations == nil { + return nil + } + names := make([]string, 0, len(embeddedRelations.Relations)+len(embeddedRelations.EmbeddedRelations)) + for _, relation := range embeddedRelations.Relations { + // skip first struct name + names = append(names, strings.Join(relation.Field.BindNames[1:], ".")) + } + for _, relations := range embeddedRelations.EmbeddedRelations { + names = append(names, embeddedValues(relations)...) + } + return names +} + +// preloadEntryPoint enters layer by layer. It will call real preload if it finds the right entry point. +// If the current relationship is embedded or joined, current query will be ignored. +// +//nolint:cyclop +func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relationships, preloads map[string][]interface{}, associationsConds []interface{}) error { + preloadMap := parsePreloadMap(db.Statement.Schema, preloads) + + // avoid random traversal of the map + preloadNames := make([]string, 0, len(preloadMap)) + for key := range preloadMap { + preloadNames = append(preloadNames, key) + } + sort.Strings(preloadNames) + + isJoined := func(name string) (joined bool, nestedJoins []string) { + for _, join := range joins { + if _, ok := relationships.Relations[join]; ok && name == join { + joined = true + continue + } + joinNames := strings.SplitN(join, ".", 2) + if len(joinNames) == 2 { + if _, ok := relationships.Relations[joinNames[0]]; ok && name == joinNames[0] { + joined = true + nestedJoins = append(nestedJoins, joinNames[1]) + } + } + } + return joined, nestedJoins + } + + for _, name := range preloadNames { + if relations := relationships.EmbeddedRelations[name]; relations != nil { + if err := preloadEntryPoint(db, joins, relations, preloadMap[name], associationsConds); err != nil { + return err + } + } else if rel := relationships.Relations[name]; rel != nil { + if joined, nestedJoins := isJoined(name); joined { + switch rv := db.Statement.ReflectValue; rv.Kind() { + case reflect.Slice, reflect.Array: + for i := 0; i < rv.Len(); i++ { + reflectValue := rel.Field.ReflectValueOf(db.Statement.Context, rv.Index(i)) + tx := preloadDB(db, reflectValue, reflectValue.Interface()) + if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil { + return err + } + } + case reflect.Struct: + reflectValue := rel.Field.ReflectValueOf(db.Statement.Context, rv) + tx := preloadDB(db, reflectValue, reflectValue.Interface()) + if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil { + return err + } + default: + return gorm.ErrInvalidData + } + } else { + tx := db.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks}) + tx.Statement.ReflectValue = db.Statement.ReflectValue + tx.Statement.Unscoped = db.Statement.Unscoped + if err := preload(tx, rel, append(preloads[name], associationsConds...), preloadMap[name]); err != nil { + return err + } + } + } else { + return fmt.Errorf("%s: %w for schema %s", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name) + } + } + return nil +} + +func preloadDB(db *gorm.DB, reflectValue reflect.Value, dest interface{}) *gorm.DB { + tx := db.Session(&gorm.Session{Context: db.Statement.Context, NewDB: true, SkipHooks: db.Statement.SkipHooks, Initialized: true}) + db.Statement.Settings.Range(func(k, v interface{}) bool { + tx.Statement.Settings.Store(k, v) + return true + }) + + if err := tx.Statement.Parse(dest); err != nil { + tx.AddError(err) + return tx + } + tx.Statement.ReflectValue = reflectValue + tx.Statement.Unscoped = db.Statement.Unscoped + return tx +} + func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) error { var ( reflectValue = tx.Statement.ReflectValue diff --git a/callbacks/query.go b/callbacks/query.go index fb2bb37a..2a82eaba 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -3,11 +3,12 @@ package callbacks import ( "fmt" "reflect" - "sort" "strings" "gorm.io/gorm" "gorm.io/gorm/clause" + "gorm.io/gorm/schema" + "gorm.io/gorm/utils" ) func Query(db *gorm.DB) { @@ -109,78 +110,141 @@ func BuildQuerySQL(db *gorm.DB) { } } + specifiedRelationsName := make(map[string]interface{}) for _, join := range db.Statement.Joins { - if db.Statement.Schema == nil { - fromClause.Joins = append(fromClause.Joins, clause.Join{ - Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds}, - }) - } else if relation, ok := db.Statement.Schema.Relationships.Relations[join.Name]; ok { - tableAliasName := relation.Name + if db.Statement.Schema != nil { + var isRelations bool // is relations or raw sql + var relations []*schema.Relationship + relation, ok := db.Statement.Schema.Relationships.Relations[join.Name] + if ok { + isRelations = true + relations = append(relations, relation) + } else { + // handle nested join like "Manager.Company" + nestedJoinNames := strings.Split(join.Name, ".") + if len(nestedJoinNames) > 1 { + isNestedJoin := true + gussNestedRelations := make([]*schema.Relationship, 0, len(nestedJoinNames)) + currentRelations := db.Statement.Schema.Relationships.Relations + for _, relname := range nestedJoinNames { + // incomplete match, only treated as raw sql + if relation, ok = currentRelations[relname]; ok { + gussNestedRelations = append(gussNestedRelations, relation) + currentRelations = relation.FieldSchema.Relationships.Relations + } else { + isNestedJoin = false + break + } + } - for _, s := range relation.FieldSchema.DBNames { - clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ - Table: tableAliasName, - Name: s, - Alias: tableAliasName + "__" + s, + if isNestedJoin { + isRelations = true + relations = gussNestedRelations + } + } + } + + if isRelations { + genJoinClause := func(joinType clause.JoinType, parentTableName string, relation *schema.Relationship) clause.Join { + tableAliasName := relation.Name + if parentTableName != clause.CurrentTable { + tableAliasName = utils.NestedRelationName(parentTableName, tableAliasName) + } + + columnStmt := gorm.Statement{ + Table: tableAliasName, DB: db, Schema: relation.FieldSchema, + Selects: join.Selects, Omits: join.Omits, + } + + selectColumns, restricted := columnStmt.SelectAndOmitColumns(false, false) + for _, s := range relation.FieldSchema.DBNames { + if v, ok := selectColumns[s]; (ok && v) || (!ok && !restricted) { + clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ + Table: tableAliasName, + Name: s, + Alias: utils.NestedRelationName(tableAliasName, s), + }) + } + } + + exprs := make([]clause.Expression, len(relation.References)) + for idx, ref := range relation.References { + if ref.OwnPrimaryKey { + exprs[idx] = clause.Eq{ + Column: clause.Column{Table: parentTableName, Name: ref.PrimaryKey.DBName}, + Value: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName}, + } + } else { + if ref.PrimaryValue == "" { + exprs[idx] = clause.Eq{ + Column: clause.Column{Table: parentTableName, Name: ref.ForeignKey.DBName}, + Value: clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName}, + } + } else { + exprs[idx] = clause.Eq{ + Column: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName}, + Value: ref.PrimaryValue, + } + } + } + } + + { + onStmt := gorm.Statement{Table: tableAliasName, DB: db, Clauses: map[string]clause.Clause{}} + for _, c := range relation.FieldSchema.QueryClauses { + onStmt.AddClause(c) + } + + if join.On != nil { + onStmt.AddClause(join.On) + } + + if cs, ok := onStmt.Clauses["WHERE"]; ok { + if where, ok := cs.Expression.(clause.Where); ok { + where.Build(&onStmt) + + if onSQL := onStmt.SQL.String(); onSQL != "" { + vars := onStmt.Vars + for idx, v := range vars { + bindvar := strings.Builder{} + onStmt.Vars = vars[0 : idx+1] + db.Dialector.BindVarTo(&bindvar, &onStmt, v) + onSQL = strings.Replace(onSQL, bindvar.String(), "?", 1) + } + + exprs = append(exprs, clause.Expr{SQL: onSQL, Vars: vars}) + } + } + } + } + + return clause.Join{ + Type: joinType, + Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName}, + ON: clause.Where{Exprs: exprs}, + } + } + + parentTableName := clause.CurrentTable + for _, rel := range relations { + // joins table alias like "Manager, Company, Manager__Company" + nestedAlias := utils.NestedRelationName(parentTableName, rel.Name) + if _, ok := specifiedRelationsName[nestedAlias]; !ok { + fromClause.Joins = append(fromClause.Joins, genJoinClause(join.JoinType, parentTableName, rel)) + specifiedRelationsName[nestedAlias] = nil + } + + if parentTableName != clause.CurrentTable { + parentTableName = utils.NestedRelationName(parentTableName, rel.Name) + } else { + parentTableName = rel.Name + } + } + } else { + fromClause.Joins = append(fromClause.Joins, clause.Join{ + Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds}, }) } - - exprs := make([]clause.Expression, len(relation.References)) - for idx, ref := range relation.References { - if ref.OwnPrimaryKey { - exprs[idx] = clause.Eq{ - Column: clause.Column{Table: clause.CurrentTable, Name: ref.PrimaryKey.DBName}, - Value: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName}, - } - } else { - if ref.PrimaryValue == "" { - exprs[idx] = clause.Eq{ - Column: clause.Column{Table: clause.CurrentTable, Name: ref.ForeignKey.DBName}, - Value: clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName}, - } - } else { - exprs[idx] = clause.Eq{ - Column: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName}, - Value: ref.PrimaryValue, - } - } - } - } - - { - onStmt := gorm.Statement{Table: tableAliasName, DB: db, Clauses: map[string]clause.Clause{}} - for _, c := range relation.FieldSchema.QueryClauses { - onStmt.AddClause(c) - } - - if join.On != nil { - onStmt.AddClause(join.On) - } - - if cs, ok := onStmt.Clauses["WHERE"]; ok { - if where, ok := cs.Expression.(clause.Where); ok { - where.Build(&onStmt) - - if onSQL := onStmt.SQL.String(); onSQL != "" { - vars := onStmt.Vars - for idx, v := range vars { - bindvar := strings.Builder{} - onStmt.Vars = vars[0 : idx+1] - db.Dialector.BindVarTo(&bindvar, &onStmt, v) - onSQL = strings.Replace(onSQL, bindvar.String(), "?", 1) - } - - exprs = append(exprs, clause.Expr{SQL: onSQL, Vars: vars}) - } - } - } - } - - fromClause.Joins = append(fromClause.Joins, clause.Join{ - Type: clause.LeftJoin, - Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName}, - ON: clause.Where{Exprs: exprs}, - }) } else { fromClause.Joins = append(fromClause.Joins, clause.Join{ Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds}, @@ -189,7 +253,6 @@ func BuildQuerySQL(db *gorm.DB) { } db.Statement.AddClause(fromClause) - db.Statement.Joins = nil } else { db.Statement.AddClauseIfNotExists(clause.From{}) } @@ -207,60 +270,23 @@ func Preload(db *gorm.DB) { return } - preloadMap := map[string]map[string][]interface{}{} - for name := range db.Statement.Preloads { - preloadFields := strings.Split(name, ".") - if preloadFields[0] == clause.Associations { - for _, rel := range db.Statement.Schema.Relationships.Relations { - if rel.Schema == db.Statement.Schema { - if _, ok := preloadMap[rel.Name]; !ok { - preloadMap[rel.Name] = map[string][]interface{}{} - } - - if value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), "."); value != "" { - preloadMap[rel.Name][value] = db.Statement.Preloads[name] - } - } - } - } else { - if _, ok := preloadMap[preloadFields[0]]; !ok { - preloadMap[preloadFields[0]] = map[string][]interface{}{} - } - - if value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), "."); value != "" { - preloadMap[preloadFields[0]][value] = db.Statement.Preloads[name] - } - } + joins := make([]string, 0, len(db.Statement.Joins)) + for _, join := range db.Statement.Joins { + joins = append(joins, join.Name) } - preloadNames := make([]string, 0, len(preloadMap)) - for key := range preloadMap { - preloadNames = append(preloadNames, key) - } - sort.Strings(preloadNames) - - preloadDB := db.Session(&gorm.Session{Context: db.Statement.Context, NewDB: true, SkipHooks: db.Statement.SkipHooks, Initialized: true}) - db.Statement.Settings.Range(func(k, v interface{}) bool { - preloadDB.Statement.Settings.Store(k, v) - return true - }) - - if err := preloadDB.Statement.Parse(db.Statement.Dest); err != nil { + tx := preloadDB(db, db.Statement.ReflectValue, db.Statement.Dest) + if tx.Error != nil { return } - preloadDB.Statement.ReflectValue = db.Statement.ReflectValue - for _, name := range preloadNames { - if rel := preloadDB.Statement.Schema.Relationships.Relations[name]; rel != nil { - db.AddError(preload(preloadDB.Table("").Session(&gorm.Session{}), rel, append(db.Statement.Preloads[name], db.Statement.Preloads[clause.Associations]...), preloadMap[name])) - } else { - db.AddError(fmt.Errorf("%s: %w for schema %s", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name)) - } - } + db.AddError(preloadEntryPoint(tx, joins, &tx.Statement.Schema.Relationships, db.Statement.Preloads, db.Statement.Preloads[clause.Associations])) } } func AfterQuery(db *gorm.DB) { + // clear the joins after query because preload need it + db.Statement.Joins = nil if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.AfterFind && db.RowsAffected > 0 { callMethod(db, func(value interface{}, tx *gorm.DB) bool { if i, ok := value.(AfterFindInterface); ok { diff --git a/callbacks/row.go b/callbacks/row.go index 56be742e..beaa189e 100644 --- a/callbacks/row.go +++ b/callbacks/row.go @@ -7,7 +7,7 @@ import ( func RowQuery(db *gorm.DB) { if db.Error == nil { BuildQuerySQL(db) - if db.DryRun { + if db.DryRun || db.Error != nil { return } diff --git a/callbacks/update.go b/callbacks/update.go index 01f40509..7cde7f61 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -70,10 +70,13 @@ func Update(config *Config) func(db *gorm.DB) { if db.Statement.SQL.Len() == 0 { db.Statement.SQL.Grow(180) db.Statement.AddClauseIfNotExists(clause.Update{}) - if set := ConvertToAssignments(db.Statement); len(set) != 0 { - db.Statement.AddClause(set) - } else if _, ok := db.Statement.Clauses["SET"]; !ok { - return + if _, ok := db.Statement.Clauses["SET"]; !ok { + if set := ConvertToAssignments(db.Statement); len(set) != 0 { + defer delete(db.Statement.Clauses, "SET") + db.Statement.AddClause(set) + } else { + return + } } db.Statement.Build(db.Statement.BuildClauses...) @@ -135,7 +138,9 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { case reflect.Slice, reflect.Array: assignValue = func(field *schema.Field, value interface{}) { for i := 0; i < stmt.ReflectValue.Len(); i++ { - field.Set(stmt.Context, stmt.ReflectValue.Index(i), value) + if stmt.ReflectValue.CanAddr() { + field.Set(stmt.Context, stmt.ReflectValue.Index(i), value) + } } } case reflect.Struct: @@ -158,21 +163,21 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { switch stmt.ReflectValue.Kind() { case reflect.Slice, reflect.Array: if size := stmt.ReflectValue.Len(); size > 0 { - var primaryKeyExprs []clause.Expression + var isZero bool for i := 0; i < size; i++ { - exprs := make([]clause.Expression, len(stmt.Schema.PrimaryFields)) - var notZero bool - for idx, field := range stmt.Schema.PrimaryFields { - value, isZero := field.ValueOf(stmt.Context, stmt.ReflectValue.Index(i)) - exprs[idx] = clause.Eq{Column: field.DBName, Value: value} - notZero = notZero || !isZero - } - if notZero { - primaryKeyExprs = append(primaryKeyExprs, clause.And(exprs...)) + for _, field := range stmt.Schema.PrimaryFields { + _, isZero = field.ValueOf(stmt.Context, stmt.ReflectValue.Index(i)) + if !isZero { + break + } } } - stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(primaryKeyExprs...)}}) + if !isZero { + _, primaryValues := schema.GetIdentityFieldValuesMap(stmt.Context, stmt.ReflectValue, stmt.Schema.PrimaryFields) + column, values := schema.ToQueryValues("", stmt.Schema.PrimaryFieldDBNames, primaryValues) + stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}}) + } } case reflect.Struct: for _, field := range stmt.Schema.PrimaryFields { @@ -229,7 +234,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { if field.AutoUpdateTime == schema.UnixNanosecond { set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano()}) } else if field.AutoUpdateTime == schema.UnixMillisecond { - set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano() / 1e6}) + set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixMilli()}) } else if field.AutoUpdateTime == schema.UnixSecond { set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.Unix()}) } else { @@ -241,11 +246,13 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { } default: updatingSchema := stmt.Schema + var isDiffSchema bool if !updatingValue.CanAddr() || stmt.Dest != stmt.Model { // different schema updatingStmt := &gorm.Statement{DB: stmt.DB} if err := updatingStmt.Parse(stmt.Dest); err == nil { updatingSchema = updatingStmt.Schema + isDiffSchema = true } } @@ -261,7 +268,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { if field.AutoUpdateTime == schema.UnixNanosecond { value = stmt.DB.NowFunc().UnixNano() } else if field.AutoUpdateTime == schema.UnixMillisecond { - value = stmt.DB.NowFunc().UnixNano() / 1e6 + value = stmt.DB.NowFunc().UnixMilli() } else if field.AutoUpdateTime == schema.UnixSecond { value = stmt.DB.NowFunc().Unix() } else { @@ -272,7 +279,13 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { if (ok || !isZero) && field.Updatable { set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value}) - assignValue(field, value) + assignField := field + if isDiffSchema { + if originField := stmt.Schema.LookUpField(dbName); originField != nil { + assignField = originField + } + } + assignValue(assignField, value) } } } else { diff --git a/callbacks/visit_map_test.go b/callbacks/visit_map_test.go deleted file mode 100644 index b1fb86db..00000000 --- a/callbacks/visit_map_test.go +++ /dev/null @@ -1,36 +0,0 @@ -package callbacks - -import ( - "reflect" - "testing" -) - -func TestLoadOrStoreVisitMap(t *testing.T) { - var vm visitMap - var loaded bool - type testM struct { - Name string - } - - t1 := testM{Name: "t1"} - t2 := testM{Name: "t2"} - t3 := testM{Name: "t3"} - - vm = make(visitMap) - if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf(&t1)); loaded { - t.Fatalf("loaded should be false") - } - - if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf(&t1)); !loaded { - t.Fatalf("loaded should be true") - } - - // t1 already exist but t2 not - if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf([]*testM{&t1, &t2, &t3})); loaded { - t.Fatalf("loaded should be false") - } - - if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf([]*testM{&t2, &t3})); !loaded { - t.Fatalf("loaded should be true") - } -} diff --git a/chainable_api.go b/chainable_api.go index 68b4d1aa..1ec9b865 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -10,10 +10,11 @@ import ( ) // Model specify the model you would like to run db operations -// // update all users's name to `hello` -// db.Model(&User{}).Update("name", "hello") -// // if user's primary key is non-blank, will use it as condition, then will only update the user's name to `hello` -// db.Model(&user).Update("name", "hello") +// +// // update all users's name to `hello` +// db.Model(&User{}).Update("name", "hello") +// // if user's primary key is non-blank, will use it as condition, then will only update that user's name to `hello` +// db.Model(&user).Update("name", "hello") func (db *DB) Model(value interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.Model = value @@ -21,6 +22,19 @@ func (db *DB) Model(value interface{}) (tx *DB) { } // Clauses Add clauses +// +// This supports both standard clauses (clause.OrderBy, clause.Limit, clause.Where) and more +// advanced techniques like specifying lock strength and optimizer hints. See the +// [docs] for more depth. +// +// // add a simple limit clause +// db.Clauses(clause.Limit{Limit: 1}).Find(&User{}) +// // tell the optimizer to use the `idx_user_name` index +// db.Clauses(hints.UseIndex("idx_user_name")).Find(&User{}) +// // specify the lock strength to UPDATE +// db.Clauses(clause.Locking{Strength: "UPDATE"}).Find(&users) +// +// [docs]: https://gorm.io/docs/sql_builder.html#Clauses func (db *DB) Clauses(conds ...clause.Expression) (tx *DB) { tx = db.getInstance() var whereConds []interface{} @@ -41,15 +55,22 @@ func (db *DB) Clauses(conds ...clause.Expression) (tx *DB) { return } -var tableRegexp = regexp.MustCompile(`(?i).+? AS (\w+)\s*(?:$|,)`) +var tableRegexp = regexp.MustCompile(`(?i)(?:.+? AS (\w+)\s*(?:$|,)|^\w+\s+(\w+)$)`) // Table specify the table you would like to run db operations +// +// // Get a user +// db.Table("users").Take(&result) func (db *DB) Table(name string, args ...interface{}) (tx *DB) { tx = db.getInstance() if strings.Contains(name, " ") || strings.Contains(name, "`") || len(args) > 0 { tx.Statement.TableExpr = &clause.Expr{SQL: name, Vars: args} - if results := tableRegexp.FindStringSubmatch(name); len(results) == 2 { - tx.Statement.Table = results[1] + if results := tableRegexp.FindStringSubmatch(name); len(results) == 3 { + if results[1] != "" { + tx.Statement.Table = results[1] + } else { + tx.Statement.Table = results[2] + } } } else if tables := strings.Split(name, "."); len(tables) == 2 { tx.Statement.TableExpr = &clause.Expr{SQL: tx.Statement.Quote(name)} @@ -65,6 +86,11 @@ func (db *DB) Table(name string, args ...interface{}) (tx *DB) { } // Distinct specify distinct fields that you want querying +// +// // Select distinct names of users +// db.Distinct("name").Find(&results) +// // Select distinct name/age pairs from users +// db.Distinct("name", "age").Find(&results) func (db *DB) Distinct(args ...interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.Distinct = true @@ -75,6 +101,14 @@ func (db *DB) Distinct(args ...interface{}) (tx *DB) { } // Select specify fields that you want when querying, creating, updating +// +// Use Select when you only want a subset of the fields. By default, GORM will select all fields. +// Select accepts both string arguments and arrays. +// +// // Select name and age of user using multiple arguments +// db.Select("name", "age").Find(&users) +// // Select name and age of user using an array +// db.Select([]string{"name", "age"}).Find(&users) func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) { tx = db.getInstance() @@ -152,6 +186,17 @@ func (db *DB) Omit(columns ...string) (tx *DB) { } // Where add conditions +// +// See the [docs] for details on the various formats that where clauses can take. By default, where clauses chain with AND. +// +// // Find the first user with name jinzhu +// db.Where("name = ?", "jinzhu").First(&user) +// // Find the first user with name jinzhu and age 20 +// db.Where(&User{Name: "jinzhu", Age: 20}).First(&user) +// // Find the first user with name jinzhu and age not equal to 20 +// db.Where("name = ?", "jinzhu").Where("age <> ?", "20").First(&user) +// +// [docs]: https://gorm.io/docs/query.html#Conditions func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) { tx = db.getInstance() if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 { @@ -161,6 +206,11 @@ func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) { } // Not add NOT conditions +// +// Not works similarly to where, and has the same syntax. +// +// // Find the first user with name not equal to jinzhu +// db.Not("name = ?", "jinzhu").First(&user) func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) { tx = db.getInstance() if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 { @@ -170,6 +220,11 @@ func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) { } // Or add OR conditions +// +// Or is used to chain together queries with an OR. +// +// // Find the first user with name equal to jinzhu or john +// db.Where("name = ?", "jinzhu").Or("name = ?", "john").First(&user) func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) { tx = db.getInstance() if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 { @@ -179,26 +234,45 @@ func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) { } // Joins specify Joins conditions -// db.Joins("Account").Find(&user) -// db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user) -// db.Joins("Account", DB.Select("id").Where("user_id = users.id AND name = ?", "someName").Model(&Account{})) +// +// db.Joins("Account").Find(&user) +// db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user) +// db.Joins("Account", DB.Select("id").Where("user_id = users.id AND name = ?", "someName").Model(&Account{})) func (db *DB) Joins(query string, args ...interface{}) (tx *DB) { + return joins(db, clause.LeftJoin, query, args...) +} + +// InnerJoins specify inner joins conditions +// db.InnerJoins("Account").Find(&user) +func (db *DB) InnerJoins(query string, args ...interface{}) (tx *DB) { + return joins(db, clause.InnerJoin, query, args...) +} + +func joins(db *DB, joinType clause.JoinType, query string, args ...interface{}) (tx *DB) { tx = db.getInstance() if len(args) == 1 { if db, ok := args[0].(*DB); ok { - if where, ok := db.Statement.Clauses["WHERE"].Expression.(clause.Where); ok { - tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args, On: &where}) - return + j := join{ + Name: query, Conds: args, Selects: db.Statement.Selects, + Omits: db.Statement.Omits, JoinType: joinType, } + if where, ok := db.Statement.Clauses["WHERE"].Expression.(clause.Where); ok { + j.On = &where + } + tx.Statement.Joins = append(tx.Statement.Joins, j) + return } } - tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args}) + tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args, JoinType: joinType}) return } // Group specify the group method on the find +// +// // Select the sum age of users with given names +// db.Model(&User{}).Select("name, sum(age) as total").Group("name").Find(&results) func (db *DB) Group(name string) (tx *DB) { tx = db.getInstance() @@ -210,6 +284,9 @@ func (db *DB) Group(name string) (tx *DB) { } // Having specify HAVING conditions for GROUP BY +// +// // Select the sum age of users with name jinzhu +// db.Model(&User{}).Select("name, sum(age) as total").Group("name").Having("name = ?", "jinzhu").Find(&result) func (db *DB) Having(query interface{}, args ...interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.AddClause(clause.GroupBy{ @@ -218,9 +295,10 @@ func (db *DB) Having(query interface{}, args ...interface{}) (tx *DB) { return } -// Order specify order when retrieve records from database -// db.Order("name DESC") -// db.Order(clause.OrderByColumn{Column: clause.Column{Name: "name"}, Desc: true}) +// Order specify order when retrieving records from database +// +// db.Order("name DESC") +// db.Order(clause.OrderByColumn{Column: clause.Column{Name: "name"}, Desc: true}) func (db *DB) Order(value interface{}) (tx *DB) { tx = db.getInstance() @@ -242,13 +320,27 @@ func (db *DB) Order(value interface{}) (tx *DB) { } // Limit specify the number of records to be retrieved +// +// Limit conditions can be cancelled by using `Limit(-1)`. +// +// // retrieve 3 users +// db.Limit(3).Find(&users) +// // retrieve 3 users into users1, and all users into users2 +// db.Limit(3).Find(&users1).Limit(-1).Find(&users2) func (db *DB) Limit(limit int) (tx *DB) { tx = db.getInstance() - tx.Statement.AddClause(clause.Limit{Limit: limit}) + tx.Statement.AddClause(clause.Limit{Limit: &limit}) return } // Offset specify the number of records to skip before starting to return the records +// +// Offset conditions can be cancelled by using `Offset(-1)`. +// +// // select the third user +// db.Offset(2).First(&user) +// // select the first user by cancelling an earlier chained offset +// db.Offset(5).Offset(-1).First(&user) func (db *DB) Offset(offset int) (tx *DB) { tx = db.getInstance() tx.Statement.AddClause(clause.Limit{Offset: offset}) @@ -256,25 +348,37 @@ func (db *DB) Offset(offset int) (tx *DB) { } // Scopes pass current database connection to arguments `func(DB) DB`, which could be used to add conditions dynamically -// func AmountGreaterThan1000(db *gorm.DB) *gorm.DB { -// return db.Where("amount > ?", 1000) -// } // -// func OrderStatus(status []string) func (db *gorm.DB) *gorm.DB { -// return func (db *gorm.DB) *gorm.DB { -// return db.Scopes(AmountGreaterThan1000).Where("status in (?)", status) -// } -// } +// func AmountGreaterThan1000(db *gorm.DB) *gorm.DB { +// return db.Where("amount > ?", 1000) +// } // -// db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders) +// func OrderStatus(status []string) func (db *gorm.DB) *gorm.DB { +// return func (db *gorm.DB) *gorm.DB { +// return db.Scopes(AmountGreaterThan1000).Where("status in (?)", status) +// } +// } +// +// db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders) func (db *DB) Scopes(funcs ...func(*DB) *DB) (tx *DB) { tx = db.getInstance() tx.Statement.scopes = append(tx.Statement.scopes, funcs...) return tx } +func (db *DB) executeScopes() (tx *DB) { + scopes := db.Statement.scopes + db.Statement.scopes = nil + for _, scope := range scopes { + db = scope(db) + } + return db +} + // Preload preload associations with given conditions -// db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users) +// +// // get all users, and preload all non-cancelled orders +// db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users) func (db *DB) Preload(query string, args ...interface{}) (tx *DB) { tx = db.getInstance() if tx.Statement.Preloads == nil { @@ -284,12 +388,41 @@ func (db *DB) Preload(query string, args ...interface{}) (tx *DB) { return } +// Attrs provide attributes used in [FirstOrCreate] or [FirstOrInit] +// +// Attrs only adds attributes if the record is not found. +// +// // assign an email if the record is not found +// db.Where(User{Name: "non_existing"}).Attrs(User{Email: "fake@fake.org"}).FirstOrInit(&user) +// // user -> User{Name: "non_existing", Email: "fake@fake.org"} +// +// // assign an email if the record is not found, otherwise ignore provided email +// db.Where(User{Name: "jinzhu"}).Attrs(User{Email: "fake@fake.org"}).FirstOrInit(&user) +// // user -> User{Name: "jinzhu", Age: 20} +// +// [FirstOrCreate]: https://gorm.io/docs/advanced_query.html#FirstOrCreate +// [FirstOrInit]: https://gorm.io/docs/advanced_query.html#FirstOrInit func (db *DB) Attrs(attrs ...interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.attrs = attrs return } +// Assign provide attributes used in [FirstOrCreate] or [FirstOrInit] +// +// Assign adds attributes even if the record is found. If using FirstOrCreate, this means that +// records will be updated even if they are found. +// +// // assign an email regardless of if the record is not found +// db.Where(User{Name: "non_existing"}).Assign(User{Email: "fake@fake.org"}).FirstOrInit(&user) +// // user -> User{Name: "non_existing", Email: "fake@fake.org"} +// +// // assign email regardless of if record is found +// db.Where(User{Name: "jinzhu"}).Assign(User{Email: "fake@fake.org"}).FirstOrInit(&user) +// // user -> User{Name: "jinzhu", Age: 20, Email: "fake@fake.org"} +// +// [FirstOrCreate]: https://gorm.io/docs/advanced_query.html#FirstOrCreate +// [FirstOrInit]: https://gorm.io/docs/advanced_query.html#FirstOrInit func (db *DB) Assign(attrs ...interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.assigns = attrs diff --git a/clause/benchmarks_test.go b/clause/benchmarks_test.go index e08677ac..34d5df41 100644 --- a/clause/benchmarks_test.go +++ b/clause/benchmarks_test.go @@ -29,6 +29,7 @@ func BenchmarkSelect(b *testing.B) { func BenchmarkComplexSelect(b *testing.B) { user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) + limit10 := 10 for i := 0; i < b.N; i++ { stmt := gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}} clauses := []clause.Interface{ @@ -43,7 +44,7 @@ func BenchmarkComplexSelect(b *testing.B) { clause.Or(clause.Gt{Column: "score", Value: 100}, clause.Like{Column: "name", Value: "%linus%"}), }}, clause.GroupBy{Columns: []clause.Column{{Name: "role"}}, Having: []clause.Expression{clause.Eq{"role", "admin"}}}, - clause.Limit{Limit: 10, Offset: 20}, + clause.Limit{Limit: &limit10, Offset: 20}, clause.OrderBy{Columns: []clause.OrderByColumn{{Column: clause.PrimaryColumn, Desc: true}}}, } diff --git a/clause/clause.go b/clause/clause.go index de19f2e3..1354fc05 100644 --- a/clause/clause.go +++ b/clause/clause.go @@ -20,6 +20,7 @@ type Builder interface { Writer WriteQuoted(field interface{}) AddVar(Writer, ...interface{}) + AddError(error) error } // Clause diff --git a/clause/expression.go b/clause/expression.go index dde00b1d..3140846e 100644 --- a/clause/expression.go +++ b/clause/expression.go @@ -126,8 +126,8 @@ func (expr NamedExpr) Build(builder Builder) { for _, v := range []byte(expr.SQL) { if v == '@' && !inName { inName = true - name = []byte{} - } else if v == ' ' || v == ',' || v == ')' || v == '"' || v == '\'' || v == '`' || v == '\n' || v == ';' { + name = name[:0] + } else if v == ' ' || v == ',' || v == ')' || v == '"' || v == '\'' || v == '`' || v == '\r' || v == '\n' || v == ';' { if inName { if nv, ok := namedMap[string(name)]; ok { builder.AddVar(builder, nv) @@ -246,15 +246,19 @@ func (eq Eq) Build(builder Builder) { switch eq.Value.(type) { case []string, []int, []int32, []int64, []uint, []uint32, []uint64, []interface{}: - builder.WriteString(" IN (") rv := reflect.ValueOf(eq.Value) - for i := 0; i < rv.Len(); i++ { - if i > 0 { - builder.WriteByte(',') + if rv.Len() == 0 { + builder.WriteString(" IN (NULL)") + } else { + builder.WriteString(" IN (") + for i := 0; i < rv.Len(); i++ { + if i > 0 { + builder.WriteByte(',') + } + builder.AddVar(builder, rv.Index(i).Interface()) } - builder.AddVar(builder, rv.Index(i).Interface()) + builder.WriteByte(')') } - builder.WriteByte(')') default: if eqNil(eq.Value) { builder.WriteString(" IS NULL") diff --git a/clause/expression_test.go b/clause/expression_test.go index 4826db38..b997bf11 100644 --- a/clause/expression_test.go +++ b/clause/expression_test.go @@ -94,6 +94,16 @@ func TestNamedExpr(t *testing.T) { Vars: []interface{}{sql.Named("name", "jinzhu")}, Result: "name1 = ? AND name2 = ?;", ExpectedVars: []interface{}{"jinzhu", "jinzhu"}, + }, { + SQL: "name1 = @name1\r\n AND name2 = @name2", + Vars: []interface{}{map[string]interface{}{"name1": "jinzhu", "name2": "jinzhu"}}, + Result: "name1 = ?\r\n AND name2 = ?", + ExpectedVars: []interface{}{"jinzhu", "jinzhu"}, + }, { + SQL: "name1 = @name1\r AND name2 = @name2", + Vars: []interface{}{map[string]interface{}{"name1": "jinzhu", "name2": "jinzhu"}}, + Result: "name1 = ?\r AND name2 = ?", + ExpectedVars: []interface{}{"jinzhu", "jinzhu"}, }, { SQL: "?", Vars: []interface{}{clause.Column{Table: "table", Name: "col"}}, @@ -189,6 +199,11 @@ func TestExpression(t *testing.T) { }, ExpectedVars: []interface{}{"a", "b"}, Result: "`column-name` NOT IN (?,?)", + }, { + Expressions: []clause.Expression{ + clause.Eq{Column: column, Value: []string{}}, + }, + Result: "`column-name` IN (NULL)", }, { Expressions: []clause.Expression{ clause.Eq{Column: clause.Expr{SQL: "SUM(?)", Vars: []interface{}{clause.Column{Name: "id"}}}, Value: 100}, diff --git a/clause/joins.go b/clause/joins.go index f3e373f2..879892be 100644 --- a/clause/joins.go +++ b/clause/joins.go @@ -9,7 +9,7 @@ const ( RightJoin JoinType = "RIGHT" ) -// Join join clause for from +// Join clause for from type Join struct { Type JoinType Table Table diff --git a/clause/joins_test.go b/clause/joins_test.go new file mode 100644 index 00000000..f1f20ec3 --- /dev/null +++ b/clause/joins_test.go @@ -0,0 +1,101 @@ +package clause_test + +import ( + "sync" + "testing" + + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" + "gorm.io/gorm/utils/tests" +) + +func TestJoin(t *testing.T) { + results := []struct { + name string + join clause.Join + sql string + }{ + { + name: "LEFT JOIN", + join: clause.Join{ + Type: clause.LeftJoin, + Table: clause.Table{Name: "user"}, + ON: clause.Where{ + Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}}, + }, + }, + sql: "LEFT JOIN `user` ON `user_info`.`user_id` = `users`.`id`", + }, + { + name: "RIGHT JOIN", + join: clause.Join{ + Type: clause.RightJoin, + Table: clause.Table{Name: "user"}, + ON: clause.Where{ + Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}}, + }, + }, + sql: "RIGHT JOIN `user` ON `user_info`.`user_id` = `users`.`id`", + }, + { + name: "INNER JOIN", + join: clause.Join{ + Type: clause.InnerJoin, + Table: clause.Table{Name: "user"}, + ON: clause.Where{ + Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}}, + }, + }, + sql: "INNER JOIN `user` ON `user_info`.`user_id` = `users`.`id`", + }, + { + name: "CROSS JOIN", + join: clause.Join{ + Type: clause.CrossJoin, + Table: clause.Table{Name: "user"}, + ON: clause.Where{ + Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}}, + }, + }, + sql: "CROSS JOIN `user` ON `user_info`.`user_id` = `users`.`id`", + }, + { + name: "USING", + join: clause.Join{ + Type: clause.InnerJoin, + Table: clause.Table{Name: "user"}, + Using: []string{"id"}, + }, + sql: "INNER JOIN `user` USING (`id`)", + }, + { + name: "Expression", + join: clause.Join{ + // Invalid + Type: clause.LeftJoin, + Table: clause.Table{Name: "user"}, + ON: clause.Where{ + Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}}, + }, + // Valid + Expression: clause.Join{ + Type: clause.InnerJoin, + Table: clause.Table{Name: "user"}, + Using: []string{"id"}, + }, + }, + sql: "INNER JOIN `user` USING (`id`)", + }, + } + for _, result := range results { + t.Run(result.name, func(t *testing.T) { + user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) + stmt := &gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}} + result.join.Build(stmt) + if result.sql != stmt.SQL.String() { + t.Errorf("want: %s, got: %s", result.sql, stmt.SQL.String()) + } + }) + } +} diff --git a/clause/limit.go b/clause/limit.go index 184f6025..3edde434 100644 --- a/clause/limit.go +++ b/clause/limit.go @@ -1,10 +1,8 @@ package clause -import "strconv" - // Limit limit clause type Limit struct { - Limit int + Limit *int Offset int } @@ -15,16 +13,16 @@ func (limit Limit) Name() string { // Build build where clause func (limit Limit) Build(builder Builder) { - if limit.Limit > 0 { + if limit.Limit != nil && *limit.Limit >= 0 { builder.WriteString("LIMIT ") - builder.WriteString(strconv.Itoa(limit.Limit)) + builder.AddVar(builder, *limit.Limit) } if limit.Offset > 0 { - if limit.Limit > 0 { + if limit.Limit != nil && *limit.Limit >= 0 { builder.WriteByte(' ') } builder.WriteString("OFFSET ") - builder.WriteString(strconv.Itoa(limit.Offset)) + builder.AddVar(builder, limit.Offset) } } @@ -33,7 +31,7 @@ func (limit Limit) MergeClause(clause *Clause) { clause.Name = "" if v, ok := clause.Expression.(Limit); ok { - if limit.Limit == 0 && v.Limit != 0 { + if (limit.Limit == nil || *limit.Limit == 0) && v.Limit != nil { limit.Limit = v.Limit } diff --git a/clause/limit_test.go b/clause/limit_test.go index c26294aa..96a7e7e6 100644 --- a/clause/limit_test.go +++ b/clause/limit_test.go @@ -8,6 +8,10 @@ import ( ) func TestLimit(t *testing.T) { + limit0 := 0 + limit10 := 10 + limit50 := 50 + limitNeg10 := -10 results := []struct { Clauses []clause.Interface Result string @@ -15,38 +19,56 @@ func TestLimit(t *testing.T) { }{ { []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{ - Limit: 10, + Limit: &limit10, Offset: 20, }}, - "SELECT * FROM `users` LIMIT 10 OFFSET 20", nil, + "SELECT * FROM `users` LIMIT ? OFFSET ?", + []interface{}{limit10, 20}, + }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit0}}, + "SELECT * FROM `users` LIMIT ?", + []interface{}{limit0}, + }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit0}, clause.Limit{Offset: 0}}, + "SELECT * FROM `users` LIMIT ?", + []interface{}{limit0}, }, { []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}}, - "SELECT * FROM `users` OFFSET 20", nil, + "SELECT * FROM `users` OFFSET ?", + []interface{}{20}, }, { []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}, clause.Limit{Offset: 30}}, - "SELECT * FROM `users` OFFSET 30", nil, + "SELECT * FROM `users` OFFSET ?", + []interface{}{30}, }, { - []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}, clause.Limit{Limit: 10}}, - "SELECT * FROM `users` LIMIT 10 OFFSET 20", nil, + []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}, clause.Limit{Limit: &limit10}}, + "SELECT * FROM `users` LIMIT ? OFFSET ?", + []interface{}{limit10, 20}, }, { - []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}}, - "SELECT * FROM `users` LIMIT 10 OFFSET 30", nil, + []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}}, + "SELECT * FROM `users` LIMIT ? OFFSET ?", + []interface{}{limit10, 30}, }, { - []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Offset: -10}}, - "SELECT * FROM `users` LIMIT 10", nil, + []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Offset: -10}}, + "SELECT * FROM `users` LIMIT ?", + []interface{}{limit10}, }, { - []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: -10}}, - "SELECT * FROM `users` OFFSET 30", nil, + []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: &limitNeg10}}, + "SELECT * FROM `users` OFFSET ?", + []interface{}{30}, }, { - []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: 50}}, - "SELECT * FROM `users` LIMIT 50 OFFSET 30", nil, + []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: &limit50}}, + "SELECT * FROM `users` LIMIT ? OFFSET ?", + []interface{}{limit50, 30}, }, } diff --git a/clause/locking.go b/clause/locking.go index 290aac92..2bc48ceb 100644 --- a/clause/locking.go +++ b/clause/locking.go @@ -1,5 +1,12 @@ package clause +const ( + LockingStrengthUpdate = "UPDATE" + LockingStrengthShare = "SHARE" + LockingOptionsSkipLocked = "SKIP LOCKED" + LockingOptionsNoWait = "NOWAIT" +) + type Locking struct { Strength string Table Table diff --git a/clause/locking_test.go b/clause/locking_test.go index 0e607312..e45c8e7d 100644 --- a/clause/locking_test.go +++ b/clause/locking_test.go @@ -14,17 +14,21 @@ func TestLocking(t *testing.T) { Vars []interface{} }{ { - []clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: "UPDATE"}}, + []clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: clause.LockingStrengthUpdate}}, "SELECT * FROM `users` FOR UPDATE", nil, }, { - []clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: "SHARE", Table: clause.Table{Name: clause.CurrentTable}}}, + []clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: clause.LockingStrengthShare, Table: clause.Table{Name: clause.CurrentTable}}}, "SELECT * FROM `users` FOR SHARE OF `users`", nil, }, { - []clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: "UPDATE"}, clause.Locking{Strength: "UPDATE", Options: "NOWAIT"}}, + []clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: clause.LockingStrengthUpdate, Options: clause.LockingOptionsNoWait}}, "SELECT * FROM `users` FOR UPDATE NOWAIT", nil, }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: clause.LockingStrengthUpdate, Options: clause.LockingOptionsSkipLocked}}, + "SELECT * FROM `users` FOR UPDATE SKIP LOCKED", nil, + }, } for idx, result := range results { diff --git a/clause/on_conflict.go b/clause/on_conflict.go index 309c5fcd..032bf4a1 100644 --- a/clause/on_conflict.go +++ b/clause/on_conflict.go @@ -16,27 +16,27 @@ func (OnConflict) Name() string { // Build build onConflict clause func (onConflict OnConflict) Build(builder Builder) { - if len(onConflict.Columns) > 0 { - builder.WriteByte('(') - for idx, column := range onConflict.Columns { - if idx > 0 { - builder.WriteByte(',') - } - builder.WriteQuoted(column) - } - builder.WriteString(`) `) - } - - if len(onConflict.TargetWhere.Exprs) > 0 { - builder.WriteString(" WHERE ") - onConflict.TargetWhere.Build(builder) - builder.WriteByte(' ') - } - if onConflict.OnConstraint != "" { builder.WriteString("ON CONSTRAINT ") builder.WriteString(onConflict.OnConstraint) builder.WriteByte(' ') + } else { + if len(onConflict.Columns) > 0 { + builder.WriteByte('(') + for idx, column := range onConflict.Columns { + if idx > 0 { + builder.WriteByte(',') + } + builder.WriteQuoted(column) + } + builder.WriteString(`) `) + } + + if len(onConflict.TargetWhere.Exprs) > 0 { + builder.WriteString(" WHERE ") + onConflict.TargetWhere.Build(builder) + builder.WriteByte(' ') + } } if onConflict.DoNothing { diff --git a/clause/select_test.go b/clause/select_test.go index 18bc2693..9c11b90d 100644 --- a/clause/select_test.go +++ b/clause/select_test.go @@ -49,16 +49,18 @@ func TestSelect(t *testing.T) { Exprs: []clause.Expression{ clause.Expr{ SQL: "? as name", - Vars: []interface{}{clause.Eq{ - Column: clause.Column{Name: "age"}, - Value: 18, - }, + Vars: []interface{}{ + clause.Eq{ + Column: clause.Column{Name: "age"}, + Value: 18, + }, }, }, }, }, }, clause.From{}}, - "SELECT `age` = ? as name FROM `users`", []interface{}{18}, + "SELECT `age` = ? as name FROM `users`", + []interface{}{18}, }, } diff --git a/clause/where.go b/clause/where.go index a29401cf..9ac78578 100644 --- a/clause/where.go +++ b/clause/where.go @@ -21,6 +21,12 @@ func (where Where) Name() string { // Build build where clause func (where Where) Build(builder Builder) { + if len(where.Exprs) == 1 { + if andCondition, ok := where.Exprs[0].(AndConditions); ok { + where.Exprs = andCondition.Exprs + } + } + // Switch position if the first query expression is a single Or condition for idx, expr := range where.Exprs { if v, ok := expr.(OrConditions); !ok || len(v.Exprs) > 1 { @@ -147,6 +153,11 @@ func Not(exprs ...Expression) Expression { if len(exprs) == 0 { return nil } + if len(exprs) == 1 { + if andCondition, ok := exprs[0].(AndConditions); ok { + exprs = andCondition.Exprs + } + } return NotConditions{Exprs: exprs} } @@ -155,19 +166,58 @@ type NotConditions struct { } func (not NotConditions) Build(builder Builder) { - if len(not.Exprs) > 1 { - builder.WriteByte('(') + anyNegationBuilder := false + for _, c := range not.Exprs { + if _, ok := c.(NegationExpressionBuilder); ok { + anyNegationBuilder = true + break + } } - for idx, c := range not.Exprs { - if idx > 0 { - builder.WriteString(AndWithSpace) + if anyNegationBuilder { + if len(not.Exprs) > 1 { + builder.WriteByte('(') } - if negationBuilder, ok := c.(NegationExpressionBuilder); ok { - negationBuilder.NegationBuild(builder) - } else { - builder.WriteString("NOT ") + for idx, c := range not.Exprs { + if idx > 0 { + builder.WriteString(AndWithSpace) + } + + if negationBuilder, ok := c.(NegationExpressionBuilder); ok { + negationBuilder.NegationBuild(builder) + } else { + builder.WriteString("NOT ") + e, wrapInParentheses := c.(Expr) + if wrapInParentheses { + sql := strings.ToUpper(e.SQL) + if wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace); wrapInParentheses { + builder.WriteByte('(') + } + } + + c.Build(builder) + + if wrapInParentheses { + builder.WriteByte(')') + } + } + } + + if len(not.Exprs) > 1 { + builder.WriteByte(')') + } + } else { + builder.WriteString("NOT ") + if len(not.Exprs) > 1 { + builder.WriteByte('(') + } + + for idx, c := range not.Exprs { + if idx > 0 { + builder.WriteString(AndWithSpace) + } + e, wrapInParentheses := c.(Expr) if wrapInParentheses { sql := strings.ToUpper(e.SQL) @@ -182,9 +232,9 @@ func (not NotConditions) Build(builder Builder) { builder.WriteByte(')') } } - } - if len(not.Exprs) > 1 { - builder.WriteByte(')') + if len(not.Exprs) > 1 { + builder.WriteByte(')') + } } } diff --git a/clause/where_test.go b/clause/where_test.go index 35e3dbee..7d5aca1f 100644 --- a/clause/where_test.go +++ b/clause/where_test.go @@ -63,7 +63,7 @@ func TestWhere(t *testing.T) { []clause.Interface{clause.Select{}, clause.From{}, clause.Where{ Exprs: []clause.Expression{clause.And(clause.Eq{Column: "age", Value: 18}, clause.Or(clause.Neq{Column: "name", Value: "jinzhu"}))}, }}, - "SELECT * FROM `users` WHERE (`age` = ? OR `name` <> ?)", + "SELECT * FROM `users` WHERE `age` = ? OR `name` <> ?", []interface{}{18, "jinzhu"}, }, { @@ -94,7 +94,7 @@ func TestWhere(t *testing.T) { clause.And(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false})), }, }}, - "SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `score` <= ?)", + "SELECT * FROM `users` WHERE `users`.`id` <> ? AND `score` <= ?", []interface{}{"1", 100}, }, { @@ -105,6 +105,14 @@ func TestWhere(t *testing.T) { "SELECT * FROM `users` WHERE (`users`.`id` <> ? AND NOT `score` <= ?)", []interface{}{"1", 100}, }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Where{ + Exprs: []clause.Expression{clause.Not(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}}, + clause.Expr{SQL: "`age` <= ?", Vars: []interface{}{60}})}, + }}, + "SELECT * FROM `users` WHERE NOT (`score` <= ? AND `age` <= ?)", + []interface{}{100, 60}, + }, } for idx, result := range results { diff --git a/errors.go b/errors.go index 49cbfe64..cd76f1f5 100644 --- a/errors.go +++ b/errors.go @@ -21,6 +21,10 @@ var ( ErrPrimaryKeyRequired = errors.New("primary key required") // ErrModelValueRequired model value required ErrModelValueRequired = errors.New("model value required") + // ErrModelAccessibleFieldsRequired model accessible fields required + ErrModelAccessibleFieldsRequired = errors.New("model accessible fields required") + // ErrSubQueryRequired sub query required + ErrSubQueryRequired = errors.New("sub query required") // ErrInvalidData unsupported data ErrInvalidData = errors.New("unsupported data") // ErrUnsupportedDriver unsupported driver @@ -41,4 +45,8 @@ var ( ErrInvalidValueOfLength = errors.New("invalid association values, length doesn't match") // ErrPreloadNotAllowed preload is not allowed when count is used ErrPreloadNotAllowed = errors.New("preload is not allowed when count is used") + // ErrDuplicatedKey occurs when there is a unique key constraint violation + ErrDuplicatedKey = errors.New("duplicated key not allowed") + // ErrForeignKeyViolated occurs when there is a foreign key constraint violation + ErrForeignKeyViolated = errors.New("violates foreign key constraint") ) diff --git a/finisher_api.go b/finisher_api.go index 663d532b..f97571ed 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -13,7 +13,7 @@ import ( "gorm.io/gorm/utils" ) -// Create insert the value into database +// Create inserts value, returning the inserted data's primary key in value's id func (db *DB) Create(value interface{}) (tx *DB) { if db.CreateBatchSize > 0 { return db.CreateInBatches(value, db.CreateBatchSize) @@ -24,7 +24,7 @@ func (db *DB) Create(value interface{}) (tx *DB) { return tx.callbacks.Create().Execute(tx) } -// CreateInBatches insert the value in batches into database +// CreateInBatches inserts value in batches of batchSize func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) { reflectValue := reflect.Indirect(reflect.ValueOf(value)) @@ -33,9 +33,10 @@ func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) { var rowsAffected int64 tx = db.getInstance() + // the reflection length judgment of the optimized value + reflectLen := reflectValue.Len() + callFc := func(tx *DB) error { - // the reflection length judgment of the optimized value - reflectLen := reflectValue.Len() for i := 0; i < reflectLen; i += batchSize { ends := i + batchSize if ends > reflectLen { @@ -53,7 +54,7 @@ func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) { return nil } - if tx.SkipDefaultTransaction { + if tx.SkipDefaultTransaction || reflectLen <= batchSize { tx.AddError(callFc(tx.Session(&Session{}))) } else { tx.AddError(tx.Transaction(callFc)) @@ -68,7 +69,7 @@ func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) { return } -// Save update value in database, if the value doesn't have primary key, will insert it +// Save updates value in database. If value doesn't contain a matching primary key, value is inserted. func (db *DB) Save(value interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.Dest = value @@ -101,20 +102,19 @@ func (db *DB) Save(value interface{}) (tx *DB) { tx.Statement.Selects = append(tx.Statement.Selects, "*") } - tx = tx.callbacks.Update().Execute(tx) + updateTx := tx.callbacks.Update().Execute(tx.Session(&Session{Initialized: true})) - if tx.Error == nil && tx.RowsAffected == 0 && !tx.DryRun && !selectedUpdate { - result := reflect.New(tx.Statement.Schema.ModelType).Interface() - if result := tx.Session(&Session{}).Limit(1).Find(result); result.RowsAffected == 0 { - return tx.Create(value) - } + if updateTx.Error == nil && updateTx.RowsAffected == 0 && !updateTx.DryRun && !selectedUpdate { + return tx.Session(&Session{SkipHooks: true}).Clauses(clause.OnConflict{UpdateAll: true}).Create(value) } + + return updateTx } return } -// First find first record that match given conditions, order by primary key +// First finds the first record ordered by primary key, matching given conditions conds func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) { tx = db.Limit(1).Order(clause.OrderByColumn{ Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, @@ -129,7 +129,7 @@ func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) { return tx.callbacks.Query().Execute(tx) } -// Take return a record that match given conditions, the order will depend on the database implementation +// Take finds the first record returned by the database in no specified order, matching given conditions conds func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) { tx = db.Limit(1) if len(conds) > 0 { @@ -142,7 +142,7 @@ func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) { return tx.callbacks.Query().Execute(tx) } -// Last find last record that match given conditions, order by primary key +// Last finds the last record ordered by primary key, matching given conditions conds func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) { tx = db.Limit(1).Order(clause.OrderByColumn{ Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, @@ -158,7 +158,7 @@ func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) { return tx.callbacks.Query().Execute(tx) } -// Find find records that match given conditions +// Find finds all records matching given conditions conds func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) { tx = db.getInstance() if len(conds) > 0 { @@ -170,7 +170,7 @@ func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) { return tx.callbacks.Query().Execute(tx) } -// FindInBatches find records in batches +// FindInBatches finds all records in batches of batchSize func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, batch int) error) *DB { var ( tx = db.Order(clause.OrderByColumn{ @@ -185,7 +185,9 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat var totalSize int if c, ok := tx.Statement.Clauses["LIMIT"]; ok { if limit, ok := c.Expression.(clause.Limit); ok { - totalSize = limit.Limit + if limit.Limit != nil { + totalSize = *limit.Limit + } if totalSize > 0 && batchSize > totalSize { batchSize = totalSize @@ -202,7 +204,9 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat batch++ if result.Error == nil && result.RowsAffected != 0 { - tx.AddError(fc(result, batch)) + fcTx := result.Session(&Session{NewDB: true}) + fcTx.RowsAffected = result.RowsAffected + tx.AddError(fc(fcTx, batch)) } else if result.Error != nil { tx.AddError(result.Error) } @@ -227,7 +231,11 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat break } - primaryValue, _ := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1)) + primaryValue, zero := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1)) + if zero { + tx.AddError(ErrPrimaryKeyRequired) + break + } queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue}) } @@ -284,7 +292,18 @@ func (db *DB) assignInterfacesToValue(values ...interface{}) { } } -// FirstOrInit gets the first matched record or initialize a new instance with given conditions (only works with struct or map conditions) +// FirstOrInit finds the first matching record, otherwise if not found initializes a new instance with given conds. +// Each conds must be a struct or map. +// +// FirstOrInit never modifies the database. It is often used with Assign and Attrs. +// +// // assign an email if the record is not found +// db.Where(User{Name: "non_existing"}).Attrs(User{Email: "fake@fake.org"}).FirstOrInit(&user) +// // user -> User{Name: "non_existing", Email: "fake@fake.org"} +// +// // assign email regardless of if record is found +// db.Where(User{Name: "jinzhu"}).Assign(User{Email: "fake@fake.org"}).FirstOrInit(&user) +// // user -> User{Name: "jinzhu", Age: 20, Email: "fake@fake.org"} func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { queryTx := db.Limit(1).Order(clause.OrderByColumn{ Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, @@ -310,62 +329,82 @@ func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { return } -// FirstOrCreate gets the first matched record or create a new one with given conditions (only works with struct, map conditions) +// FirstOrCreate finds the first matching record, otherwise if not found creates a new instance with given conds. +// Each conds must be a struct or map. +// +// Using FirstOrCreate in conjunction with Assign will result in an update to the database even if the record exists. +// +// // assign an email if the record is not found +// result := db.Where(User{Name: "non_existing"}).Attrs(User{Email: "fake@fake.org"}).FirstOrCreate(&user) +// // user -> User{Name: "non_existing", Email: "fake@fake.org"} +// // result.RowsAffected -> 1 +// +// // assign email regardless of if record is found +// result := db.Where(User{Name: "jinzhu"}).Assign(User{Email: "fake@fake.org"}).FirstOrCreate(&user) +// // user -> User{Name: "jinzhu", Age: 20, Email: "fake@fake.org"} +// // result.RowsAffected -> 1 func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) { tx = db.getInstance() queryTx := db.Session(&Session{}).Limit(1).Order(clause.OrderByColumn{ Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, }) - if result := queryTx.Find(dest, conds...); result.Error == nil { - if result.RowsAffected == 0 { - if c, ok := result.Statement.Clauses["WHERE"]; ok { - if where, ok := c.Expression.(clause.Where); ok { - result.assignInterfacesToValue(where.Exprs) - } - } - // initialize with attrs, conds - if len(db.Statement.attrs) > 0 { - result.assignInterfacesToValue(db.Statement.attrs...) - } - - // initialize with attrs, conds - if len(db.Statement.assigns) > 0 { - result.assignInterfacesToValue(db.Statement.assigns...) - } - - return tx.Create(dest) - } else if len(db.Statement.assigns) > 0 { - exprs := tx.Statement.BuildCondition(db.Statement.assigns[0], db.Statement.assigns[1:]...) - assigns := map[string]interface{}{} - for _, expr := range exprs { - if eq, ok := expr.(clause.Eq); ok { - switch column := eq.Column.(type) { - case string: - assigns[column] = eq.Value - case clause.Column: - assigns[column.Name] = eq.Value - default: - } - } - } - - return tx.Model(dest).Updates(assigns) - } else { - tx.Error = result.Error - } + result := queryTx.Find(dest, conds...) + if result.Error != nil { + tx.Error = result.Error + return tx } + + if result.RowsAffected == 0 { + if c, ok := result.Statement.Clauses["WHERE"]; ok { + if where, ok := c.Expression.(clause.Where); ok { + result.assignInterfacesToValue(where.Exprs) + } + } + + // initialize with attrs, conds + if len(db.Statement.attrs) > 0 { + result.assignInterfacesToValue(db.Statement.attrs...) + } + + // initialize with attrs, conds + if len(db.Statement.assigns) > 0 { + result.assignInterfacesToValue(db.Statement.assigns...) + } + + return tx.Create(dest) + } else if len(db.Statement.assigns) > 0 { + exprs := tx.Statement.BuildCondition(db.Statement.assigns[0], db.Statement.assigns[1:]...) + assigns := map[string]interface{}{} + for i := 0; i < len(exprs); i++ { + expr := exprs[i] + + if eq, ok := expr.(clause.AndConditions); ok { + exprs = append(exprs, eq.Exprs...) + } else if eq, ok := expr.(clause.Eq); ok { + switch column := eq.Column.(type) { + case string: + assigns[column] = eq.Value + case clause.Column: + assigns[column.Name] = eq.Value + } + } + } + + return tx.Model(dest).Updates(assigns) + } + return tx } -// Update update attributes with callbacks, refer: https://gorm.io/docs/update.html#Update-Changed-Fields +// Update updates column with value using callbacks. Reference: https://gorm.io/docs/update.html#Update-Changed-Fields func (db *DB) Update(column string, value interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.Dest = map[string]interface{}{column: value} return tx.callbacks.Update().Execute(tx) } -// Updates update attributes with callbacks, refer: https://gorm.io/docs/update.html#Update-Changed-Fields +// Updates updates attributes using callbacks. values must be a struct or map. Reference: https://gorm.io/docs/update.html#Update-Changed-Fields func (db *DB) Updates(values interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.Dest = values @@ -386,7 +425,9 @@ func (db *DB) UpdateColumns(values interface{}) (tx *DB) { return tx.callbacks.Update().Execute(tx) } -// Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition +// Delete deletes value matching given conditions. If value contains primary key it is included in the conditions. If +// value includes a deleted_at field, then Delete performs a soft delete instead by setting deleted_at with the current +// time if null. func (db *DB) Delete(value interface{}, conds ...interface{}) (tx *DB) { tx = db.getInstance() if len(conds) > 0 { @@ -480,7 +521,7 @@ func (db *DB) Rows() (*sql.Rows, error) { return rows, tx.Error } -// Scan scan value to a struct +// Scan scans selected value to the struct dest func (db *DB) Scan(dest interface{}) (tx *DB) { config := *db.Config currentLogger, newLogger := config.Logger, logger.Recorder.New() @@ -494,6 +535,7 @@ func (db *DB) Scan(dest interface{}) (tx *DB) { tx.ScanRows(rows, dest) } else { tx.RowsAffected = 0 + tx.AddError(rows.Err()) } tx.AddError(rows.Close()) } @@ -505,9 +547,10 @@ func (db *DB) Scan(dest interface{}) (tx *DB) { return } -// Pluck used to query single column from a model as a map -// var ages []int64 -// db.Model(&users).Pluck("age", &ages) +// Pluck queries a single column from a model, returning in the slice dest. E.g.: +// +// var ages []int64 +// db.Model(&users).Pluck("age", &ages) func (db *DB) Pluck(column string, dest interface{}) (tx *DB) { tx = db.getInstance() if tx.Statement.Model != nil { @@ -548,7 +591,8 @@ func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error { return tx.Error } -// Connection use a db conn to execute Multiple commands,this conn will put conn pool after it is executed. +// Connection uses a db connection to execute an arbitrary number of commands in fc. When finished, the connection is +// returned to the connection pool. func (db *DB) Connection(fc func(tx *DB) error) (err error) { if db.Error != nil { return db.Error @@ -570,7 +614,9 @@ func (db *DB) Connection(fc func(tx *DB) error) (err error) { return fc(tx) } -// Transaction start a transaction as a block, return error will rollback, otherwise to commit. +// Transaction start a transaction as a block, return error will rollback, otherwise to commit. Transaction executes an +// arbitrary number of commands in fc within a transaction. On success the changes are committed; if an error occurs +// they are rolled back. func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err error) { panicked := true @@ -581,7 +627,6 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er if err != nil { return } - defer func() { // Make sure to rollback when panic, Block error or Commit error if panicked || err != nil { @@ -589,8 +634,7 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er } }() } - - err = fc(db.Session(&Session{})) + err = fc(db.Session(&Session{NewDB: db.clone == 1})) } else { tx := db.Begin(opts...) if tx.Error != nil { @@ -614,7 +658,7 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er return } -// Begin begins a transaction +// Begin begins a transaction with any transaction options opts func (db *DB) Begin(opts ...*sql.TxOptions) *DB { var ( // clone statement @@ -643,7 +687,7 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB { return tx } -// Commit commit a transaction +// Commit commits the changes in a transaction func (db *DB) Commit() *DB { if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil && !reflect.ValueOf(committer).IsNil() { db.AddError(committer.Commit()) @@ -653,7 +697,7 @@ func (db *DB) Commit() *DB { return db } -// Rollback rollback a transaction +// Rollback rollbacks the changes in a transaction func (db *DB) Rollback() *DB { if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil { if !reflect.ValueOf(committer).IsNil() { @@ -667,7 +711,21 @@ func (db *DB) Rollback() *DB { func (db *DB) SavePoint(name string) *DB { if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok { + // close prepared statement, because SavePoint not support prepared statement. + // e.g. mysql8.0 doc: https://dev.mysql.com/doc/refman/8.0/en/sql-prepared-statements.html + var ( + preparedStmtTx *PreparedStmtTX + isPreparedStmtTx bool + ) + // close prepared statement, because SavePoint not support prepared statement. + if preparedStmtTx, isPreparedStmtTx = db.Statement.ConnPool.(*PreparedStmtTX); isPreparedStmtTx { + db.Statement.ConnPool = preparedStmtTx.Tx + } db.AddError(savePointer.SavePoint(db, name)) + // restore prepared statement + if isPreparedStmtTx { + db.Statement.ConnPool = preparedStmtTx + } } else { db.AddError(ErrUnsupportedDriver) } @@ -676,14 +734,28 @@ func (db *DB) SavePoint(name string) *DB { func (db *DB) RollbackTo(name string) *DB { if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok { + // close prepared statement, because RollbackTo not support prepared statement. + // e.g. mysql8.0 doc: https://dev.mysql.com/doc/refman/8.0/en/sql-prepared-statements.html + var ( + preparedStmtTx *PreparedStmtTX + isPreparedStmtTx bool + ) + // close prepared statement, because SavePoint not support prepared statement. + if preparedStmtTx, isPreparedStmtTx = db.Statement.ConnPool.(*PreparedStmtTX); isPreparedStmtTx { + db.Statement.ConnPool = preparedStmtTx.Tx + } db.AddError(savePointer.RollbackTo(db, name)) + // restore prepared statement + if isPreparedStmtTx { + db.Statement.ConnPool = preparedStmtTx + } } else { db.AddError(ErrUnsupportedDriver) } return db } -// Exec execute raw sql +// Exec executes raw sql func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.SQL = strings.Builder{} diff --git a/go.mod b/go.mod index 57362745..deb61b74 100644 --- a/go.mod +++ b/go.mod @@ -1,8 +1,8 @@ module gorm.io/gorm -go 1.14 +go 1.18 require ( github.com/jinzhu/inflection v1.0.0 - github.com/jinzhu/now v1.1.4 + github.com/jinzhu/now v1.1.5 ) diff --git a/go.sum b/go.sum index 50fbba2f..bd6104c9 100644 --- a/go.sum +++ b/go.sum @@ -1,4 +1,4 @@ github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= -github.com/jinzhu/now v1.1.4 h1:tHnRBy1i5F2Dh8BAFxqFzxKqqvezXrL2OW1TnX+Mlas= -github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= +github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= +github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= diff --git a/gorm.go b/gorm.go index 6a6bb032..775cd3de 100644 --- a/gorm.go +++ b/gorm.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "fmt" + "reflect" "sort" "sync" "time" @@ -37,6 +38,8 @@ type Config struct { DisableAutomaticPing bool // DisableForeignKeyConstraintWhenMigrating DisableForeignKeyConstraintWhenMigrating bool + // IgnoreRelationshipsWhenMigrating + IgnoreRelationshipsWhenMigrating bool // DisableNestedTransaction disable nested transaction DisableNestedTransaction bool // AllowGlobalUpdate allow global update @@ -45,6 +48,8 @@ type Config struct { QueryFields bool // CreateBatchSize default create batch size CreateBatchSize int + // TranslateError enabling error translation + TranslateError bool // ClauseBuilders clause builder ClauseBuilders map[string]clause.ClauseBuilder @@ -142,7 +147,7 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) { } if config.NamingStrategy == nil { - config.NamingStrategy = schema.NamingStrategy{} + config.NamingStrategy = schema.NamingStrategy{IdentifierMaxLength: 64} // Default Identifier length is 64 } if config.Logger == nil { @@ -175,17 +180,17 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) { if config.Dialector != nil { err = config.Dialector.Initialize(db) - } - preparedStmt := &PreparedStmtDB{ - ConnPool: db.ConnPool, - Stmts: map[string]Stmt{}, - Mux: &sync.RWMutex{}, - PreparedSQL: make([]string, 0, 100), + if err != nil { + if db, _ := db.DB(); db != nil { + _ = db.Close() + } + } } - db.cacheStore.Store(preparedStmtDBKey, preparedStmt) if config.PrepareStmt { + preparedStmt := NewPreparedStmtDB(db.ConnPool) + db.cacheStore.Store(preparedStmtDBKey, preparedStmt) db.ConnPool = preparedStmt } @@ -246,16 +251,30 @@ func (db *DB) Session(config *Session) *DB { } if config.PrepareStmt { + var preparedStmt *PreparedStmtDB + if v, ok := db.cacheStore.Load(preparedStmtDBKey); ok { - preparedStmt := v.(*PreparedStmtDB) + preparedStmt = v.(*PreparedStmtDB) + } else { + preparedStmt = NewPreparedStmtDB(db.ConnPool) + db.cacheStore.Store(preparedStmtDBKey, preparedStmt) + } + + switch t := tx.Statement.ConnPool.(type) { + case Tx: + tx.Statement.ConnPool = &PreparedStmtTX{ + Tx: t, + PreparedStmtDB: preparedStmt, + } + default: tx.Statement.ConnPool = &PreparedStmtDB{ ConnPool: db.Config.ConnPool, Mux: preparedStmt.Mux, Stmts: preparedStmt.Stmts, } - txConfig.ConnPool = tx.Statement.ConnPool - txConfig.PrepareStmt = true } + txConfig.ConnPool = tx.Statement.ConnPool + txConfig.PrepareStmt = true } if config.SkipHooks { @@ -300,7 +319,8 @@ func (db *DB) WithContext(ctx context.Context) *DB { // Debug start debug mode func (db *DB) Debug() (tx *DB) { - return db.Session(&Session{ + tx = db.getInstance() + return tx.Session(&Session{ Logger: db.Logger.LogMode(logger.Info), }) } @@ -336,10 +356,18 @@ func (db *DB) Callback() *callbacks { // AddError add error to db func (db *DB) AddError(err error) error { - if db.Error == nil { - db.Error = err - } else if err != nil { - db.Error = fmt.Errorf("%v; %w", db.Error, err) + if err != nil { + if db.Config.TranslateError { + if errTranslator, ok := db.Dialector.(ErrorTranslator); ok { + err = errTranslator.Translate(err) + } + } + + if db.Error == nil { + db.Error = err + } else { + db.Error = fmt.Errorf("%v; %w", db.Error, err) + } } return db.Error } @@ -347,12 +375,20 @@ func (db *DB) AddError(err error) error { // DB returns `*sql.DB` func (db *DB) DB() (*sql.DB, error) { connPool := db.ConnPool - - if dbConnector, ok := connPool.(GetDBConnector); ok && dbConnector != nil { - return dbConnector.GetDBConn() + if db.Statement != nil && db.Statement.ConnPool != nil { + connPool = db.Statement.ConnPool + } + if tx, ok := connPool.(*sql.Tx); ok && tx != nil { + return (*sql.DB)(reflect.ValueOf(tx).Elem().FieldByName("db").UnsafePointer()), nil } - if sqldb, ok := connPool.(*sql.DB); ok { + if dbConnector, ok := connPool.(GetDBConnector); ok && dbConnector != nil { + if sqldb, err := dbConnector.GetDBConn(); sqldb != nil || err != nil { + return sqldb, err + } + } + + if sqldb, ok := connPool.(*sql.DB); ok && sqldb != nil { return sqldb, nil } @@ -366,11 +402,12 @@ func (db *DB) getInstance() *DB { if db.clone == 1 { // clone with new statement tx.Statement = &Statement{ - DB: tx, - ConnPool: db.Statement.ConnPool, - Context: db.Statement.Context, - Clauses: map[string]clause.Clause{}, - Vars: make([]interface{}, 0, 8), + DB: tx, + ConnPool: db.Statement.ConnPool, + Context: db.Statement.Context, + Clauses: map[string]clause.Clause{}, + Vars: make([]interface{}, 0, 8), + SkipHooks: db.Statement.SkipHooks, } } else { // with clone statement @@ -412,7 +449,7 @@ func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interfac relation, ok := modelSchema.Relationships.Relations[field] isRelation := ok && relation.JoinTable != nil if !isRelation { - return fmt.Errorf("failed to found relation: %s", field) + return fmt.Errorf("failed to find relation: %s", field) } for _, ref := range relation.References { @@ -455,12 +492,12 @@ func (db *DB) Use(plugin Plugin) error { // ToSQL for generate SQL string. // -// db.ToSQL(func(tx *gorm.DB) *gorm.DB { -// return tx.Model(&User{}).Where(&User{Name: "foo", Age: 20}) -// .Limit(10).Offset(5) -// .Order("name ASC") -// .First(&User{}) -// }) +// db.ToSQL(func(tx *gorm.DB) *gorm.DB { +// return tx.Model(&User{}).Where(&User{Name: "foo", Age: 20}) +// .Limit(10).Offset(5) +// .Order("name ASC") +// .First(&User{}) +// }) func (db *DB) ToSQL(queryFn func(tx *DB) *DB) string { tx := queryFn(db.Session(&Session{DryRun: true, SkipDefaultTransaction: true})) stmt := tx.Statement diff --git a/interfaces.go b/interfaces.go index 32d49605..3bcc3d57 100644 --- a/interfaces.go +++ b/interfaces.go @@ -26,6 +26,10 @@ type Plugin interface { Initialize(*DB) error } +type ParamsFilter interface { + ParamsFilter(ctx context.Context, sql string, params ...interface{}) (string, []interface{}) +} + // ConnPool db conns pool interface type ConnPool interface { PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) @@ -82,3 +86,7 @@ type Rows interface { Err() error Close() error } + +type ErrorTranslator interface { + Translate(err error) error +} diff --git a/logger/logger.go b/logger/logger.go index 2ffd28d5..253f0325 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -4,7 +4,7 @@ import ( "context" "errors" "fmt" - "io/ioutil" + "io" "log" "os" "time" @@ -55,6 +55,7 @@ type Config struct { SlowThreshold time.Duration Colorful bool IgnoreRecordNotFoundError bool + ParameterizedQueries bool LogLevel LogLevel } @@ -68,8 +69,8 @@ type Interface interface { } var ( - // Discard Discard logger will print any log to ioutil.Discard - Discard = New(log.New(ioutil.Discard, "", log.LstdFlags), Config{}) + // Discard logger will print any log to io.Discard + Discard = New(log.New(io.Discard, "", log.LstdFlags), Config{}) // Default Default logger Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{ SlowThreshold: 200 * time.Millisecond, @@ -77,7 +78,7 @@ var ( IgnoreRecordNotFoundError: false, Colorful: true, }) - // Recorder Recorder logger records running SQL into a recorder instance + // Recorder logger records running SQL into a recorder instance Recorder = traceRecorder{Interface: Default, BeginAt: time.Now()} ) @@ -128,28 +129,30 @@ func (l *logger) LogMode(level LogLevel) Interface { } // Info print info -func (l logger) Info(ctx context.Context, msg string, data ...interface{}) { +func (l *logger) Info(ctx context.Context, msg string, data ...interface{}) { if l.LogLevel >= Info { l.Printf(l.infoStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...) } } // Warn print warn messages -func (l logger) Warn(ctx context.Context, msg string, data ...interface{}) { +func (l *logger) Warn(ctx context.Context, msg string, data ...interface{}) { if l.LogLevel >= Warn { l.Printf(l.warnStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...) } } // Error print error messages -func (l logger) Error(ctx context.Context, msg string, data ...interface{}) { +func (l *logger) Error(ctx context.Context, msg string, data ...interface{}) { if l.LogLevel >= Error { l.Printf(l.errStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...) } } // Trace print sql message -func (l logger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) { +// +//nolint:cyclop +func (l *logger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) { if l.LogLevel <= Silent { return } @@ -181,6 +184,14 @@ func (l logger) Trace(ctx context.Context, begin time.Time, fc func() (string, i } } +// ParamsFilter filter params +func (l *logger) ParamsFilter(ctx context.Context, sql string, params ...interface{}) (string, []interface{}) { + if l.Config.ParameterizedQueries { + return sql, nil + } + return sql, params +} + type traceRecorder struct { Interface BeginAt time.Time @@ -189,8 +200,8 @@ type traceRecorder struct { Err error } -// New new trace recorder -func (l traceRecorder) New() *traceRecorder { +// New trace recorder +func (l *traceRecorder) New() *traceRecorder { return &traceRecorder{Interface: l.Interface, BeginAt: time.Now()} } diff --git a/logger/sql.go b/logger/sql.go index ae5dcde4..c3390cea 100644 --- a/logger/sql.go +++ b/logger/sql.go @@ -28,8 +28,25 @@ func isPrintable(s string) bool { return true } +// A list of Go types that should be converted to SQL primitives var convertibleTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeOf(false), reflect.TypeOf([]byte{})} +// RegEx matches only numeric values +var numericPlaceholderRe = regexp.MustCompile(`\$\d+\$`) + +func isNumeric(k reflect.Kind) bool { + switch k { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return true + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return true + case reflect.Float32, reflect.Float64: + return true + default: + return false + } +} + // ExplainSQL generate SQL string with given parameters, the generated SQL is expected to be used in logger, execute it might introduce a SQL injection vulnerability func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, avars ...interface{}) string { var ( @@ -75,24 +92,26 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a case reflect.Bool: vars[idx] = fmt.Sprintf("%t", reflectValue.Interface()) case reflect.String: - vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, "\\"+escaper) + escaper + vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, escaper+escaper) + escaper default: if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) { - vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, "\\"+escaper) + escaper + vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, escaper+escaper) + escaper } else { vars[idx] = nullStr } } case []byte: if s := string(v); isPrintable(s) { - vars[idx] = escaper + strings.ReplaceAll(s, escaper, "\\"+escaper) + escaper + vars[idx] = escaper + strings.ReplaceAll(s, escaper, escaper+escaper) + escaper } else { vars[idx] = escaper + "" + escaper } case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: vars[idx] = utils.ToString(v) - case float64, float32: - vars[idx] = fmt.Sprintf("%.6f", v) + case float32: + vars[idx] = strconv.FormatFloat(float64(v), 'f', -1, 32) + case float64: + vars[idx] = strconv.FormatFloat(v, 'f', -1, 64) case string: v = strings.ReplaceAll(v, "\\"+escaper, "") @@ -110,6 +129,12 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a convertParams(v, idx) } else if rv.Kind() == reflect.Ptr && !rv.IsZero() { convertParams(reflect.Indirect(rv).Interface(), idx) + } else if isNumeric(rv.Kind()) { + if rv.CanInt() || rv.CanUint() { + vars[idx] = fmt.Sprintf("%d", rv.Interface()) + } else { + vars[idx] = fmt.Sprintf("%.6f", rv.Interface()) + } } else { for _, t := range convertibleTypes { if rv.Type().ConvertibleTo(t) { @@ -117,7 +142,7 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a return } } - vars[idx] = escaper + strings.ReplaceAll(fmt.Sprint(v), escaper, "\\"+escaper) + escaper + vars[idx] = escaper + strings.ReplaceAll(fmt.Sprint(v), escaper, escaper+escaper) + escaper } } } @@ -144,9 +169,18 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a sql = newSQL.String() } else { sql = numericPlaceholder.ReplaceAllString(sql, "$$$1$$") - for idx, v := range vars { - sql = strings.Replace(sql, "$"+strconv.Itoa(idx+1)+"$", v, 1) - } + + sql = numericPlaceholderRe.ReplaceAllStringFunc(sql, func(v string) string { + num := v[1 : len(v)-1] + n, _ := strconv.Atoi(num) + + // position var start from 1 ($1, $2) + n -= 1 + if n >= 0 && n <= len(vars)-1 { + return vars[n] + } + return v + }) } return sql diff --git a/logger/sql_test.go b/logger/sql_test.go index f759f8c6..e83544a8 100644 --- a/logger/sql_test.go +++ b/logger/sql_test.go @@ -31,20 +31,24 @@ func (s ExampleStruct) Value() (driver.Value, error) { } func format(v []byte, escaper string) string { - return escaper + strings.ReplaceAll(string(v), escaper, "\\"+escaper) + escaper + return escaper + strings.ReplaceAll(string(v), escaper, escaper+escaper) + escaper } func TestExplainSQL(t *testing.T) { type role string type password []byte + type intType int + type floatType float64 var ( - tt = now.MustParse("2020-02-23 11:10:10") - myrole = role("admin") - pwd = password([]byte("pass")) - jsVal = []byte(`{"Name":"test","Val":"test"}`) - js = JSON(jsVal) - esVal = []byte(`{"Name":"test","Val":"test"}`) - es = ExampleStruct{Name: "test", Val: "test"} + tt = now.MustParse("2020-02-23 11:10:10") + myrole = role("admin") + pwd = password("pass") + jsVal = []byte(`{"Name":"test","Val":"test"}`) + js = JSON(jsVal) + esVal = []byte(`{"Name":"test","Val":"test"}`) + es = ExampleStruct{Name: "test", Val: "test"} + intVal intType = 1 + floatVal floatType = 1.23 ) results := []struct { @@ -69,19 +73,19 @@ func TestExplainSQL(t *testing.T) { SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (@p1, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9, @p10, @p11)", NumericRegexp: regexp.MustCompile(`@p(\d+)`), Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.com", myrole, pwd}, - Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`, + Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`, }, { SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ($3, $4, $1, $2, $7, $8, $5, $6, $9, $10, $11)", NumericRegexp: regexp.MustCompile(`\$(\d+)`), Vars: []interface{}{999.99, true, "jinzhu", 1, &tt, nil, []byte("12345"), tt, "w@g.com", myrole, pwd}, - Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`, + Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`, }, { SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (@p1, @p11, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9, @p10)", NumericRegexp: regexp.MustCompile(`@p(\d+)`), Vars: []interface{}{"jinzhu", 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.com", myrole, pwd, 1}, - Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`, + Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`, }, { SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", @@ -95,6 +99,30 @@ func TestExplainSQL(t *testing.T) { Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es}, Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, E"w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)), }, + { + SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + NumericRegexp: nil, + Vars: []interface{}{"jinzhu", 1, 0.1753607109, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es}, + Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 0.1753607109, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)), + }, + { + SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + NumericRegexp: nil, + Vars: []interface{}{"jinzhu", 1, float32(999.99), true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es}, + Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)), + }, + { + SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, int_val) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + NumericRegexp: nil, + Vars: []interface{}{"jinzhu?", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, intVal}, + Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, int_val) values ("jinzhu?", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", 1)`, + }, + { + SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, float_val) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + NumericRegexp: nil, + Vars: []interface{}{"jinzhu?", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, floatVal}, + Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, float_val) values ("jinzhu?", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", 1.230000)`, + }, } for idx, r := range results { diff --git a/migrator.go b/migrator.go index 52443877..3d2b032b 100644 --- a/migrator.go +++ b/migrator.go @@ -13,11 +13,7 @@ func (db *DB) Migrator() Migrator { // apply scopes to migrator for len(tx.Statement.scopes) > 0 { - scopes := tx.Statement.scopes - tx.Statement.scopes = nil - for _, scope := range scopes { - tx = scope(tx) - } + tx = tx.executeScopes() } return tx.Dialector.Migrator(tx.Session(&Session{})) @@ -30,9 +26,9 @@ func (db *DB) AutoMigrate(dst ...interface{}) error { // ViewOption view option type ViewOption struct { - Replace bool - CheckOption string - Query *DB + Replace bool // If true, exec `CREATE`. If false, exec `CREATE OR REPLACE` + CheckOption string // optional. e.g. `WITH [ CASCADED | LOCAL ] CHECK OPTION` + Query *DB // required subquery. } // ColumnType column type interface @@ -51,6 +47,23 @@ type ColumnType interface { DefaultValue() (value string, ok bool) } +type Index interface { + Table() string + Name() string + Columns() []string + PrimaryKey() (isPrimaryKey bool, ok bool) + Unique() (unique bool, ok bool) + Option() string +} + +// TableType table type interface +type TableType interface { + Schema() string + Name() string + Type() string + Comment() (comment string, ok bool) +} + // Migrator migrator interface type Migrator interface { // AutoMigrate @@ -59,6 +72,7 @@ type Migrator interface { // Database CurrentDatabase() string FullDataTypeOf(*schema.Field) clause.Expr + GetTypeAliases(databaseTypeName string) []string // Tables CreateTable(dst ...interface{}) error @@ -66,12 +80,15 @@ type Migrator interface { HasTable(dst interface{}) bool RenameTable(oldName, newName interface{}) error GetTables() (tableList []string, err error) + TableType(dst interface{}) (TableType, error) // Columns AddColumn(dst interface{}, field string) error DropColumn(dst interface{}, field string) error AlterColumn(dst interface{}, field string) error MigrateColumn(dst interface{}, field *schema.Field, columnType ColumnType) error + // MigrateColumnUnique migrate column's UNIQUE constraint, it's part of MigrateColumn. + MigrateColumnUnique(dst interface{}, field *schema.Field, columnType ColumnType) error HasColumn(dst interface{}, field string) bool RenameColumn(dst interface{}, oldName, field string) error ColumnTypes(dst interface{}) ([]ColumnType, error) @@ -90,4 +107,5 @@ type Migrator interface { DropIndex(dst interface{}, name string) error HasIndex(dst interface{}, name string) bool RenameIndex(dst interface{}, oldName, newName string) error + GetIndexes(dst interface{}) ([]Index, error) } diff --git a/migrator/index.go b/migrator/index.go new file mode 100644 index 00000000..8845da95 --- /dev/null +++ b/migrator/index.go @@ -0,0 +1,43 @@ +package migrator + +import "database/sql" + +// Index implements gorm.Index interface +type Index struct { + TableName string + NameValue string + ColumnList []string + PrimaryKeyValue sql.NullBool + UniqueValue sql.NullBool + OptionValue string +} + +// Table return the table name of the index. +func (idx Index) Table() string { + return idx.TableName +} + +// Name return the name of the index. +func (idx Index) Name() string { + return idx.NameValue +} + +// Columns return the columns of the index +func (idx Index) Columns() []string { + return idx.ColumnList +} + +// PrimaryKey returns the index is primary key or not. +func (idx Index) PrimaryKey() (isPrimaryKey bool, ok bool) { + return idx.PrimaryKeyValue.Bool, idx.PrimaryKeyValue.Valid +} + +// Unique returns whether the index is unique or not. +func (idx Index) Unique() (unique bool, ok bool) { + return idx.UniqueValue.Bool, idx.UniqueValue.Valid +} + +// Option return the optional attribute of the index +func (idx Index) Option() string { + return idx.OptionValue +} diff --git a/migrator/migrator.go b/migrator/migrator.go index 93f4c5d0..acce5df2 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -3,20 +3,32 @@ package migrator import ( "context" "database/sql" + "errors" "fmt" "reflect" "regexp" + "strconv" "strings" + "time" "gorm.io/gorm" "gorm.io/gorm/clause" + "gorm.io/gorm/logger" "gorm.io/gorm/schema" ) -var ( - regRealDataType = regexp.MustCompile(`[^\d](\d+)[^\d]?`) - regFullDataType = regexp.MustCompile(`[^\d]*(\d+)[^\d]?`) -) +// This regular expression seeks to find a sequence of digits (\d+) among zero or more non-digit characters (\D*), +// with a possible trailing non-digit character (\D?). + +// For example, values that can pass this regular expression are: +// - "123" +// - "abc456" +// -"%$#@789" +var regFullDataType = regexp.MustCompile(`\D*(\d+)\D?`) + +// TODO:? Create const vars for raw sql queries ? + +var _ gorm.Migrator = (*Migrator)(nil) // Migrator m struct type Migrator struct { @@ -30,6 +42,16 @@ type Config struct { gorm.Dialector } +type printSQLLogger struct { + logger.Interface +} + +func (l *printSQLLogger) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) { + sql, _ := fc() + fmt.Println(sql + ";") + l.Interface.Trace(ctx, begin, fc, err) +} + // GormDataTypeInterface gorm data type interface type GormDataTypeInterface interface { GormDBDataType(*gorm.DB, *schema.Field) string @@ -72,10 +94,6 @@ func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) { expr.SQL += " NOT NULL" } - if field.Unique { - expr.SQL += " UNIQUE" - } - if field.HasDefaultValue && (field.DefaultValueInterface != nil || field.DefaultValue != "") { if field.DefaultValueInterface != nil { defaultStmt := &gorm.Statement{Vars: []interface{}{field.DefaultValueInterface}} @@ -89,23 +107,35 @@ func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) { return } +func (m Migrator) GetQueryAndExecTx() (queryTx, execTx *gorm.DB) { + queryTx = m.DB.Session(&gorm.Session{}) + execTx = queryTx + if m.DB.DryRun { + queryTx.DryRun = false + execTx = m.DB.Session(&gorm.Session{Logger: &printSQLLogger{Interface: m.DB.Logger}}) + } + return queryTx, execTx +} + // AutoMigrate auto migrate values func (m Migrator) AutoMigrate(values ...interface{}) error { for _, value := range m.ReorderModels(values, true) { - tx := m.DB.Session(&gorm.Session{}) - if !tx.Migrator().HasTable(value) { - if err := tx.Migrator().CreateTable(value); err != nil { + queryTx, execTx := m.GetQueryAndExecTx() + if !queryTx.Migrator().HasTable(value) { + if err := execTx.Migrator().CreateTable(value); err != nil { return err } } else { - if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) { - columnTypes, err := m.DB.Migrator().ColumnTypes(value) + if err := m.RunWithValue(value, func(stmt *gorm.Statement) error { + columnTypes, err := queryTx.Migrator().ColumnTypes(value) if err != nil { return err } - + var ( + parseIndexes = stmt.Schema.ParseIndexes() + parseCheckConstraints = stmt.Schema.ParseCheckConstraints() + ) for _, dbName := range stmt.Schema.DBNames { - field := stmt.Schema.FieldsByDBName[dbName] var foundColumn gorm.ColumnType for _, columnType := range columnTypes { @@ -117,37 +147,43 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { if foundColumn == nil { // not found, add column - if err := tx.Migrator().AddColumn(value, dbName); err != nil { + if err = execTx.Migrator().AddColumn(value, dbName); err != nil { + return err + } + } else { + // found, smartly migrate + field := stmt.Schema.FieldsByDBName[dbName] + if err = execTx.Migrator().MigrateColumn(value, field, foundColumn); err != nil { return err } - } else if err := m.DB.Migrator().MigrateColumn(value, field, foundColumn); err != nil { - // found, smart migrate - return err } } - for _, rel := range stmt.Schema.Relationships.Relations { - if !m.DB.Config.DisableForeignKeyConstraintWhenMigrating { + if !m.DB.DisableForeignKeyConstraintWhenMigrating && !m.DB.IgnoreRelationshipsWhenMigrating { + for _, rel := range stmt.Schema.Relationships.Relations { + if rel.Field.IgnoreMigration { + continue + } if constraint := rel.ParseConstraint(); constraint != nil && - constraint.Schema == stmt.Schema && !tx.Migrator().HasConstraint(value, constraint.Name) { - if err := tx.Migrator().CreateConstraint(value, constraint.Name); err != nil { - return err - } - } - } - - for _, chk := range stmt.Schema.ParseCheckConstraints() { - if !tx.Migrator().HasConstraint(value, chk.Name) { - if err := tx.Migrator().CreateConstraint(value, chk.Name); err != nil { + constraint.Schema == stmt.Schema && !queryTx.Migrator().HasConstraint(value, constraint.Name) { + if err := execTx.Migrator().CreateConstraint(value, constraint.Name); err != nil { return err } } } } - for _, idx := range stmt.Schema.ParseIndexes() { - if !tx.Migrator().HasIndex(value, idx.Name) { - if err := tx.Migrator().CreateIndex(value, idx.Name); err != nil { + for _, chk := range parseCheckConstraints { + if !queryTx.Migrator().HasConstraint(value, chk.Name) { + if err := execTx.Migrator().CreateConstraint(value, chk.Name); err != nil { + return err + } + } + } + + for _, idx := range parseIndexes { + if !queryTx.Migrator().HasIndex(value, idx.Name) { + if err := execTx.Migrator().CreateIndex(value, idx.Name); err != nil { return err } } @@ -174,7 +210,7 @@ func (m Migrator) GetTables() (tableList []string, err error) { func (m Migrator) CreateTable(values ...interface{}) error { for _, value := range m.ReorderModels(values, false) { tx := m.DB.Session(&gorm.Session{}) - if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) { + if err := m.RunWithValue(value, func(stmt *gorm.Statement) (err error) { var ( createTableSQL = "CREATE TABLE ? (" values = []interface{}{m.CurrentTable(stmt)} @@ -185,7 +221,7 @@ func (m Migrator) CreateTable(values ...interface{}) error { field := stmt.Schema.FieldsByDBName[dbName] if !field.IgnoreMigration { createTableSQL += "? ?" - hasPrimaryKeyInDataType = hasPrimaryKeyInDataType || strings.Contains(strings.ToUpper(string(field.DataType)), "PRIMARY KEY") + hasPrimaryKeyInDataType = hasPrimaryKeyInDataType || strings.Contains(strings.ToUpper(m.DataTypeOf(field)), "PRIMARY KEY") values = append(values, clause.Column{Name: dbName}, m.DB.Migrator().FullDataTypeOf(field)) createTableSQL += "," } @@ -193,7 +229,7 @@ func (m Migrator) CreateTable(values ...interface{}) error { if !hasPrimaryKeyInDataType && len(stmt.Schema.PrimaryFields) > 0 { createTableSQL += "PRIMARY KEY ?," - primaryKeys := []interface{}{} + primaryKeys := make([]interface{}, 0, len(stmt.Schema.PrimaryFields)) for _, field := range stmt.Schema.PrimaryFields { primaryKeys = append(primaryKeys, clause.Column{Name: field.DBName}) } @@ -204,8 +240,8 @@ func (m Migrator) CreateTable(values ...interface{}) error { for _, idx := range stmt.Schema.ParseIndexes() { if m.CreateIndexAfterCreateTable { defer func(value interface{}, name string) { - if errr == nil { - errr = tx.Migrator().CreateIndex(value, name) + if err == nil { + err = tx.Migrator().CreateIndex(value, name) } }(value, idx.Name) } else { @@ -223,15 +259,18 @@ func (m Migrator) CreateTable(values ...interface{}) error { } createTableSQL += "," - values = append(values, clause.Expr{SQL: idx.Name}, tx.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt)) + values = append(values, clause.Column{Name: idx.Name}, tx.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt)) } } - for _, rel := range stmt.Schema.Relationships.Relations { - if !m.DB.DisableForeignKeyConstraintWhenMigrating { + if !m.DB.DisableForeignKeyConstraintWhenMigrating && !m.DB.IgnoreRelationshipsWhenMigrating { + for _, rel := range stmt.Schema.Relationships.Relations { + if rel.Field.IgnoreMigration { + continue + } if constraint := rel.ParseConstraint(); constraint != nil { if constraint.Schema == stmt.Schema { - sql, vars := buildConstraint(constraint) + sql, vars := constraint.Build() createTableSQL += sql + "," values = append(values, vars...) } @@ -239,6 +278,11 @@ func (m Migrator) CreateTable(values ...interface{}) error { } } + for _, uni := range stmt.Schema.ParseUniqueConstraints() { + createTableSQL += "CONSTRAINT ? UNIQUE (?)," + values = append(values, clause.Column{Name: uni.Name}, clause.Expr{SQL: stmt.Quote(uni.Field.DBName)}) + } + for _, chk := range stmt.Schema.ParseCheckConstraints() { createTableSQL += "CONSTRAINT ? CHECK (?)," values = append(values, clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}) @@ -252,8 +296,8 @@ func (m Migrator) CreateTable(values ...interface{}) error { createTableSQL += fmt.Sprint(tableOption) } - errr = tx.Exec(createTableSQL, values...).Error - return errr + err = tx.Exec(createTableSQL, values...).Error + return err }); err != nil { return err } @@ -402,32 +446,58 @@ func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error // MigrateColumn migrate column func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnType gorm.ColumnType) error { + if field.IgnoreMigration { + return nil + } + // found, smart migrate - fullDataType := strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL) + fullDataType := strings.TrimSpace(strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL)) realDataType := strings.ToLower(columnType.DatabaseTypeName()) - alterColumn := false + var ( + alterColumn bool + isSameType = fullDataType == realDataType + ) - // check size - if length, ok := columnType.Length(); length != int64(field.Size) { - if length > 0 && field.Size > 0 { - alterColumn = true - } else { - // has size in data type and not equal - // Since the following code is frequently called in the for loop, reg optimization is needed here - matches := regRealDataType.FindAllStringSubmatch(realDataType, -1) - matches2 := regFullDataType.FindAllStringSubmatch(fullDataType, -1) - if (len(matches) == 1 && matches[0][1] != fmt.Sprint(field.Size) || !field.PrimaryKey) && - (len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length) && ok) { + if !field.PrimaryKey { + // check type + if !strings.HasPrefix(fullDataType, realDataType) { + // check type aliases + aliases := m.DB.Migrator().GetTypeAliases(realDataType) + for _, alias := range aliases { + if strings.HasPrefix(fullDataType, alias) { + isSameType = true + break + } + } + + if !isSameType { alterColumn = true } } } - // check precision - if precision, _, ok := columnType.DecimalSize(); ok && int64(field.Precision) != precision { - if regexp.MustCompile(fmt.Sprintf("[^0-9]%d[^0-9]", field.Precision)).MatchString(m.DataTypeOf(field)) { - alterColumn = true + if !isSameType { + // check size + if length, ok := columnType.Length(); length != int64(field.Size) { + if length > 0 && field.Size > 0 { + alterColumn = true + } else { + // has size in data type and not equal + // Since the following code is frequently called in the for loop, reg optimization is needed here + matches2 := regFullDataType.FindAllStringSubmatch(fullDataType, -1) + if !field.PrimaryKey && + (len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length) && ok) { + alterColumn = true + } + } + } + + // check precision + if precision, _, ok := columnType.DecimalSize(); ok && int64(field.Precision) != precision { + if regexp.MustCompile(fmt.Sprintf("[^0-9]%d[^0-9]", field.Precision)).MatchString(m.DataTypeOf(field)) { + alterColumn = true + } } } @@ -439,19 +509,29 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy } } - // check unique - if unique, ok := columnType.Unique(); ok && unique != field.Unique { - // not primary key - if !field.PrimaryKey { - alterColumn = true - } - } - // check default value - if v, ok := columnType.DefaultValue(); ok && v != field.DefaultValue { - // not primary key - if !field.PrimaryKey { + if !field.PrimaryKey { + currentDefaultNotNull := field.HasDefaultValue && (field.DefaultValueInterface != nil || !strings.EqualFold(field.DefaultValue, "NULL")) + dv, dvNotNull := columnType.DefaultValue() + if dvNotNull && !currentDefaultNotNull { + // default value -> null alterColumn = true + } else if !dvNotNull && currentDefaultNotNull { + // null -> default value + alterColumn = true + } else if currentDefaultNotNull || dvNotNull { + switch field.GORMDataType { + case schema.Time: + if !strings.EqualFold(strings.TrimSuffix(dv, "()"), strings.TrimSuffix(field.DefaultValue, "()")) { + alterColumn = true + } + case schema.Bool: + v1, _ := strconv.ParseBool(dv) + v2, _ := strconv.ParseBool(field.DefaultValue) + alterColumn = v1 != v2 + default: + alterColumn = dv != field.DefaultValue + } } } @@ -463,13 +543,39 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy } } - if alterColumn && !field.IgnoreMigration { - return m.DB.Migrator().AlterColumn(value, field.Name) + if alterColumn { + if err := m.DB.Migrator().AlterColumn(value, field.DBName); err != nil { + return err + } + } + + if err := m.DB.Migrator().MigrateColumnUnique(value, field, columnType); err != nil { + return err } return nil } +func (m Migrator) MigrateColumnUnique(value interface{}, field *schema.Field, columnType gorm.ColumnType) error { + unique, ok := columnType.Unique() + if !ok || field.PrimaryKey { + return nil // skip primary key + } + // By default, ColumnType's Unique is not affected by UniqueIndex, so we don't care about UniqueIndex. + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + // We're currently only receiving boolean values on `Unique` tag, + // so the UniqueConstraint name is fixed + constraint := m.DB.NamingStrategy.UniqueName(stmt.Table, field.DBName) + if unique && !field.Unique { + return m.DB.Migrator().DropConstraint(value, constraint) + } + if !unique && field.Unique { + return m.DB.Migrator().CreateConstraint(value, constraint) + } + return nil + }) +} + // ColumnTypes return columnTypes []gorm.ColumnType and execErr error func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) { columnTypes := make([]gorm.ColumnType, 0) @@ -499,47 +605,76 @@ func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) { return columnTypes, execErr } -// CreateView create view +// CreateView create view from Query in gorm.ViewOption. +// Query in gorm.ViewOption is a [subquery] +// +// // CREATE VIEW `user_view` AS SELECT * FROM `users` WHERE age > 20 +// q := DB.Model(&User{}).Where("age > ?", 20) +// DB.Debug().Migrator().CreateView("user_view", gorm.ViewOption{Query: q}) +// +// // CREATE OR REPLACE VIEW `users_view` AS SELECT * FROM `users` WITH CHECK OPTION +// q := DB.Model(&User{}) +// DB.Debug().Migrator().CreateView("user_view", gorm.ViewOption{Query: q, Replace: true, CheckOption: "WITH CHECK OPTION"}) +// +// [subquery]: https://gorm.io/docs/advanced_query.html#SubQuery func (m Migrator) CreateView(name string, option gorm.ViewOption) error { - return gorm.ErrNotImplemented + if option.Query == nil { + return gorm.ErrSubQueryRequired + } + + sql := new(strings.Builder) + sql.WriteString("CREATE ") + if option.Replace { + sql.WriteString("OR REPLACE ") + } + sql.WriteString("VIEW ") + m.QuoteTo(sql, name) + sql.WriteString(" AS ") + + m.DB.Statement.AddVar(sql, option.Query) + + if option.CheckOption != "" { + sql.WriteString(" ") + sql.WriteString(option.CheckOption) + } + return m.DB.Exec(m.Explain(sql.String(), m.DB.Statement.Vars...)).Error } // DropView drop view func (m Migrator) DropView(name string) error { - return gorm.ErrNotImplemented -} - -func buildConstraint(constraint *schema.Constraint) (sql string, results []interface{}) { - sql = "CONSTRAINT ? FOREIGN KEY ? REFERENCES ??" - if constraint.OnDelete != "" { - sql += " ON DELETE " + constraint.OnDelete - } - - if constraint.OnUpdate != "" { - sql += " ON UPDATE " + constraint.OnUpdate - } - - var foreignKeys, references []interface{} - for _, field := range constraint.ForeignKeys { - foreignKeys = append(foreignKeys, clause.Column{Name: field.DBName}) - } - - for _, field := range constraint.References { - references = append(references, clause.Column{Name: field.DBName}) - } - results = append(results, clause.Table{Name: constraint.Name}, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references) - return + return m.DB.Exec("DROP VIEW IF EXISTS ?", clause.Table{Name: name}).Error } // GuessConstraintAndTable guess statement's constraint and it's table based on name -func (m Migrator) GuessConstraintAndTable(stmt *gorm.Statement, name string) (_ *schema.Constraint, _ *schema.Check, table string) { +// +// Deprecated: use GuessConstraintInterfaceAndTable instead. +func (m Migrator) GuessConstraintAndTable(stmt *gorm.Statement, name string) (*schema.Constraint, *schema.CheckConstraint, string) { + constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name) + switch c := constraint.(type) { + case *schema.Constraint: + return c, nil, table + case *schema.CheckConstraint: + return nil, c, table + default: + return nil, nil, table + } +} + +// GuessConstraintInterfaceAndTable guess statement's constraint and it's table based on name +// nolint:cyclop +func (m Migrator) GuessConstraintInterfaceAndTable(stmt *gorm.Statement, name string) (_ schema.ConstraintInterface, table string) { if stmt.Schema == nil { - return nil, nil, stmt.Table + return nil, stmt.Table } checkConstraints := stmt.Schema.ParseCheckConstraints() if chk, ok := checkConstraints[name]; ok { - return nil, &chk, stmt.Table + return &chk, stmt.Table + } + + uniqueConstraints := stmt.Schema.ParseUniqueConstraints() + if uni, ok := uniqueConstraints[name]; ok { + return &uni, stmt.Table } getTable := func(rel *schema.Relationship) string { @@ -554,7 +689,7 @@ func (m Migrator) GuessConstraintAndTable(stmt *gorm.Statement, name string) (_ for _, rel := range stmt.Schema.Relationships.Relations { if constraint := rel.ParseConstraint(); constraint != nil && constraint.Name == name { - return constraint, nil, getTable(rel) + return constraint, getTable(rel) } } @@ -562,40 +697,39 @@ func (m Migrator) GuessConstraintAndTable(stmt *gorm.Statement, name string) (_ for k := range checkConstraints { if checkConstraints[k].Field == field { v := checkConstraints[k] - return nil, &v, stmt.Table + return &v, stmt.Table + } + } + + for k := range uniqueConstraints { + if uniqueConstraints[k].Field == field { + v := uniqueConstraints[k] + return &v, stmt.Table } } for _, rel := range stmt.Schema.Relationships.Relations { if constraint := rel.ParseConstraint(); constraint != nil && rel.Field == field { - return constraint, nil, getTable(rel) + return constraint, getTable(rel) } } } - return nil, nil, stmt.Schema.Table + return nil, stmt.Schema.Table } // CreateConstraint create constraint func (m Migrator) CreateConstraint(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { - constraint, chk, table := m.GuessConstraintAndTable(stmt, name) - if chk != nil { - return m.DB.Exec( - "ALTER TABLE ? ADD CONSTRAINT ? CHECK (?)", - m.CurrentTable(stmt), clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}, - ).Error - } - + constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name) if constraint != nil { vars := []interface{}{clause.Table{Name: table}} if stmt.TableExpr != nil { vars[0] = stmt.TableExpr } - sql, values := buildConstraint(constraint) + sql, values := constraint.Build() return m.DB.Exec("ALTER TABLE ? ADD "+sql, append(vars, values...)...).Error } - return nil }) } @@ -603,11 +737,9 @@ func (m Migrator) CreateConstraint(value interface{}, name string) error { // DropConstraint drop constraint func (m Migrator) DropConstraint(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { - constraint, chk, table := m.GuessConstraintAndTable(stmt, name) + constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name) if constraint != nil { - name = constraint.Name - } else if chk != nil { - name = chk.Name + name = constraint.GetName() } return m.DB.Exec("ALTER TABLE ? DROP CONSTRAINT ?", clause.Table{Name: table}, clause.Column{Name: name}).Error }) @@ -618,11 +750,9 @@ func (m Migrator) HasConstraint(value interface{}, name string) bool { var count int64 m.RunWithValue(value, func(stmt *gorm.Statement) error { currentDatabase := m.DB.Migrator().CurrentDatabase() - constraint, chk, table := m.GuessConstraintAndTable(stmt, name) + constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name) if constraint != nil { - name = constraint.Name - } else if chk != nil { - name = chk.Name + name = constraint.GetName() } return m.DB.Raw( @@ -759,7 +889,8 @@ func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []i Statement: &gorm.Statement{DB: m.DB, Dest: value}, } beDependedOn := map[*schema.Schema]bool{} - if err := dep.Parse(value); err != nil { + // support for special table name + if err := dep.ParseWithSpecialTableName(value, m.DB.Statement.Table); err != nil { m.DB.Logger.Error(context.Background(), "failed to parse value %#v, got error %v", value, err) } if _, ok := parsedSchemas[dep.Statement.Schema]; ok { @@ -767,26 +898,31 @@ func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []i } parsedSchemas[dep.Statement.Schema] = true - for _, rel := range dep.Schema.Relationships.Relations { - if c := rel.ParseConstraint(); c != nil && c.Schema == dep.Statement.Schema && c.Schema != c.ReferenceSchema { - dep.Depends = append(dep.Depends, c.ReferenceSchema) - } + if !m.DB.IgnoreRelationshipsWhenMigrating { + for _, rel := range dep.Schema.Relationships.Relations { + if rel.Field.IgnoreMigration { + continue + } + if c := rel.ParseConstraint(); c != nil && c.Schema == dep.Statement.Schema && c.Schema != c.ReferenceSchema { + dep.Depends = append(dep.Depends, c.ReferenceSchema) + } - if rel.Type == schema.HasOne || rel.Type == schema.HasMany { - beDependedOn[rel.FieldSchema] = true - } + if rel.Type == schema.HasOne || rel.Type == schema.HasMany { + beDependedOn[rel.FieldSchema] = true + } - if rel.JoinTable != nil { - // append join value - defer func(rel *schema.Relationship, joinValue interface{}) { - if !beDependedOn[rel.FieldSchema] { - dep.Depends = append(dep.Depends, rel.FieldSchema) - } else { - fieldValue := reflect.New(rel.FieldSchema.ModelType).Interface() - parseDependence(fieldValue, autoAdd) - } - parseDependence(joinValue, autoAdd) - }(rel, reflect.New(rel.JoinTable.ModelType).Interface()) + if rel.JoinTable != nil { + // append join value + defer func(rel *schema.Relationship, joinValue interface{}) { + if !beDependedOn[rel.FieldSchema] { + dep.Depends = append(dep.Depends, rel.FieldSchema) + } else { + fieldValue := reflect.New(rel.FieldSchema.ModelType).Interface() + parseDependence(fieldValue, autoAdd) + } + parseDependence(joinValue, autoAdd) + }(rel, reflect.New(rel.JoinTable.ModelType).Interface()) + } } } @@ -843,3 +979,18 @@ func (m Migrator) CurrentTable(stmt *gorm.Statement) interface{} { } return clause.Table{Name: stmt.Table} } + +// GetIndexes return Indexes []gorm.Index and execErr error +func (m Migrator) GetIndexes(dst interface{}) ([]gorm.Index, error) { + return nil, errors.New("not support") +} + +// GetTypeAliases return database type aliases +func (m Migrator) GetTypeAliases(databaseTypeName string) []string { + return nil +} + +// TableType return tableType gorm.TableType and execErr error +func (m Migrator) TableType(dst interface{}) (gorm.TableType, error) { + return nil, errors.New("not support") +} diff --git a/migrator/table_type.go b/migrator/table_type.go new file mode 100644 index 00000000..ed6e42a0 --- /dev/null +++ b/migrator/table_type.go @@ -0,0 +1,33 @@ +package migrator + +import ( + "database/sql" +) + +// TableType table type implements TableType interface +type TableType struct { + SchemaValue string + NameValue string + TypeValue string + CommentValue sql.NullString +} + +// Schema returns the schema of the table. +func (ct TableType) Schema() string { + return ct.SchemaValue +} + +// Name returns the name of the table. +func (ct TableType) Name() string { + return ct.NameValue +} + +// Type returns the type of the table. +func (ct TableType) Type() string { + return ct.TypeValue +} + +// Comment returns the comment of current table. +func (ct TableType) Comment() (comment string, ok bool) { + return ct.CommentValue.String, ct.CommentValue.Valid +} diff --git a/model.go b/model.go index 3334d17c..fa705df1 100644 --- a/model.go +++ b/model.go @@ -4,9 +4,10 @@ import "time" // Model a basic GoLang struct which includes the following fields: ID, CreatedAt, UpdatedAt, DeletedAt // It may be embedded into your model or you may build your own model without it -// type User struct { -// gorm.Model -// } +// +// type User struct { +// gorm.Model +// } type Model struct { ID uint `gorm:"primarykey"` CreatedAt time.Time diff --git a/prepare_stmt.go b/prepare_stmt.go index b062b0d6..4d533885 100644 --- a/prepare_stmt.go +++ b/prepare_stmt.go @@ -3,30 +3,44 @@ package gorm import ( "context" "database/sql" + "database/sql/driver" + "errors" + "reflect" "sync" ) type Stmt struct { *sql.Stmt Transaction bool + prepared chan struct{} + prepareErr error } type PreparedStmtDB struct { - Stmts map[string]Stmt + Stmts map[string]*Stmt PreparedSQL []string Mux *sync.RWMutex ConnPool } -func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) { - if dbConnector, ok := db.ConnPool.(GetDBConnector); ok && dbConnector != nil { - return dbConnector.GetDBConn() +func NewPreparedStmtDB(connPool ConnPool) *PreparedStmtDB { + return &PreparedStmtDB{ + ConnPool: connPool, + Stmts: make(map[string]*Stmt), + Mux: &sync.RWMutex{}, + PreparedSQL: make([]string, 0, 100), } +} +func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) { if sqldb, ok := db.ConnPool.(*sql.DB); ok { return sqldb, nil } + if dbConnector, ok := db.ConnPool.(GetDBConnector); ok && dbConnector != nil { + return dbConnector.GetDBConn() + } + return nil, ErrInvalidDB } @@ -42,31 +56,72 @@ func (db *PreparedStmtDB) Close() { } } +func (sdb *PreparedStmtDB) Reset() { + sdb.Mux.Lock() + defer sdb.Mux.Unlock() + + for _, stmt := range sdb.Stmts { + go stmt.Close() + } + sdb.PreparedSQL = make([]string, 0, 100) + sdb.Stmts = make(map[string]*Stmt) +} + func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) { db.Mux.RLock() if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) { db.Mux.RUnlock() - return stmt, nil + // wait for other goroutines prepared + <-stmt.prepared + if stmt.prepareErr != nil { + return Stmt{}, stmt.prepareErr + } + + return *stmt, nil } db.Mux.RUnlock() db.Mux.Lock() - defer db.Mux.Unlock() - // double check if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) { - return stmt, nil - } else if ok { - go stmt.Close() + db.Mux.Unlock() + // wait for other goroutines prepared + <-stmt.prepared + if stmt.prepareErr != nil { + return Stmt{}, stmt.prepareErr + } + + return *stmt, nil } + // cache preparing stmt first + cacheStmt := Stmt{Transaction: isTransaction, prepared: make(chan struct{})} + db.Stmts[query] = &cacheStmt + db.Mux.Unlock() + + // prepare completed + defer close(cacheStmt.prepared) + + // Reason why cannot lock conn.PrepareContext + // suppose the maxopen is 1, g1 is creating record and g2 is querying record. + // 1. g1 begin tx, g1 is requeue because of waiting for the system call, now `db.ConnPool` db.numOpen == 1. + // 2. g2 select lock `conn.PrepareContext(ctx, query)`, now db.numOpen == db.maxOpen , wait for release. + // 3. g1 tx exec insert, wait for unlock `conn.PrepareContext(ctx, query)` to finish tx and release. stmt, err := conn.PrepareContext(ctx, query) - if err == nil { - db.Stmts[query] = Stmt{Stmt: stmt, Transaction: isTransaction} - db.PreparedSQL = append(db.PreparedSQL, query) + if err != nil { + cacheStmt.prepareErr = err + db.Mux.Lock() + delete(db.Stmts, query) + db.Mux.Unlock() + return Stmt{}, err } - return db.Stmts[query], err + db.Mux.Lock() + cacheStmt.Stmt = stmt + db.PreparedSQL = append(db.PreparedSQL, query) + db.Mux.Unlock() + + return cacheStmt, nil } func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (ConnPool, error) { @@ -74,6 +129,19 @@ func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (Conn tx, err := beginner.BeginTx(ctx, opt) return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err } + + beginner, ok := db.ConnPool.(ConnPoolBeginner) + if !ok { + return nil, ErrInvalidTransaction + } + + connPool, err := beginner.BeginTx(ctx, opt) + if err != nil { + return nil, err + } + if tx, ok := connPool.(Tx); ok { + return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, nil + } return nil, ErrInvalidTransaction } @@ -81,7 +149,7 @@ func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args .. stmt, err := db.prepare(ctx, db.ConnPool, false, query) if err == nil { result, err = stmt.ExecContext(ctx, args...) - if err != nil { + if errors.Is(err, driver.ErrBadConn) { db.Mux.Lock() defer db.Mux.Unlock() go stmt.Close() @@ -95,7 +163,7 @@ func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args . stmt, err := db.prepare(ctx, db.ConnPool, false, query) if err == nil { rows, err = stmt.QueryContext(ctx, args...) - if err != nil { + if errors.Is(err, driver.ErrBadConn) { db.Mux.Lock() defer db.Mux.Unlock() @@ -114,20 +182,32 @@ func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string, arg return &sql.Row{} } +func (db *PreparedStmtDB) Ping() error { + conn, err := db.GetDBConn() + if err != nil { + return err + } + return conn.Ping() +} + type PreparedStmtTX struct { Tx PreparedStmtDB *PreparedStmtDB } +func (db *PreparedStmtTX) GetDBConn() (*sql.DB, error) { + return db.PreparedStmtDB.GetDBConn() +} + func (tx *PreparedStmtTX) Commit() error { - if tx.Tx != nil { + if tx.Tx != nil && !reflect.ValueOf(tx.Tx).IsNil() { return tx.Tx.Commit() } return ErrInvalidTransaction } func (tx *PreparedStmtTX) Rollback() error { - if tx.Tx != nil { + if tx.Tx != nil && !reflect.ValueOf(tx.Tx).IsNil() { return tx.Tx.Rollback() } return ErrInvalidTransaction @@ -137,7 +217,7 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args .. stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query) if err == nil { result, err = tx.Tx.StmtContext(ctx, stmt.Stmt).ExecContext(ctx, args...) - if err != nil { + if errors.Is(err, driver.ErrBadConn) { tx.PreparedStmtDB.Mux.Lock() defer tx.PreparedStmtDB.Mux.Unlock() @@ -152,7 +232,7 @@ func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args . stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query) if err == nil { rows, err = tx.Tx.StmtContext(ctx, stmt.Stmt).QueryContext(ctx, args...) - if err != nil { + if errors.Is(err, driver.ErrBadConn) { tx.PreparedStmtDB.Mux.Lock() defer tx.PreparedStmtDB.Mux.Unlock() @@ -170,3 +250,11 @@ func (tx *PreparedStmtTX) QueryRowContext(ctx context.Context, query string, arg } return &sql.Row{} } + +func (tx *PreparedStmtTX) Ping() error { + conn, err := tx.GetDBConn() + if err != nil { + return err + } + return conn.Ping() +} diff --git a/scan.go b/scan.go index ad3734d8..415b9f0d 100644 --- a/scan.go +++ b/scan.go @@ -4,10 +4,10 @@ import ( "database/sql" "database/sql/driver" "reflect" - "strings" "time" "gorm.io/gorm/schema" + "gorm.io/gorm/utils" ) // prepareValues prepare values slice @@ -50,7 +50,7 @@ func scanIntoMap(mapValue map[string]interface{}, values []interface{}, columns } } -func (db *DB) scanIntoStruct(rows Rows, reflectValue reflect.Value, values []interface{}, fields []*schema.Field, joinFields [][2]*schema.Field) { +func (db *DB) scanIntoStruct(rows Rows, reflectValue reflect.Value, values []interface{}, fields []*schema.Field, joinFields [][]*schema.Field) { for idx, field := range fields { if field != nil { values[idx] = field.NewValuePool.Get() @@ -65,26 +65,49 @@ func (db *DB) scanIntoStruct(rows Rows, reflectValue reflect.Value, values []int db.RowsAffected++ db.AddError(rows.Scan(values...)) - + joinedNestedSchemaMap := make(map[string]interface{}) for idx, field := range fields { - if field != nil { - if len(joinFields) == 0 || joinFields[idx][0] == nil { - db.AddError(field.Set(db.Statement.Context, reflectValue, values[idx])) - } else { - relValue := joinFields[idx][0].ReflectValueOf(db.Statement.Context, reflectValue) - if relValue.Kind() == reflect.Ptr && relValue.IsNil() { - if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() { - continue - } + if field == nil { + continue + } - relValue.Set(reflect.New(relValue.Type().Elem())) + if len(joinFields) == 0 || len(joinFields[idx]) == 0 { + db.AddError(field.Set(db.Statement.Context, reflectValue, values[idx])) + } else { // joinFields count is larger than 2 when using join + var isNilPtrValue bool + var relValue reflect.Value + // does not contain raw dbname + nestedJoinSchemas := joinFields[idx][:len(joinFields[idx])-1] + // current reflect value + currentReflectValue := reflectValue + fullRels := make([]string, 0, len(nestedJoinSchemas)) + for _, joinSchema := range nestedJoinSchemas { + fullRels = append(fullRels, joinSchema.Name) + relValue = joinSchema.ReflectValueOf(db.Statement.Context, currentReflectValue) + if relValue.Kind() == reflect.Ptr { + fullRelsName := utils.JoinNestedRelationNames(fullRels) + // same nested structure + if _, ok := joinedNestedSchemaMap[fullRelsName]; !ok { + if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() { + isNilPtrValue = true + break + } + + relValue.Set(reflect.New(relValue.Type().Elem())) + joinedNestedSchemaMap[fullRelsName] = nil + } } - db.AddError(joinFields[idx][1].Set(db.Statement.Context, relValue, values[idx])) + currentReflectValue = relValue } - // release data to pool - field.NewValuePool.Put(values[idx]) + if !isNilPtrValue { // ignore if value is nil + f := joinFields[idx][len(joinFields[idx])-1] + db.AddError(f.Set(db.Statement.Context, relValue, values[idx])) + } } + + // release data to pool + field.NewValuePool.Put(values[idx]) } } @@ -156,11 +179,10 @@ func Scan(rows Rows, db *DB, mode ScanMode) { } default: var ( - fields = make([]*schema.Field, len(columns)) - selectedColumnsMap = make(map[string]int, len(columns)) - joinFields [][2]*schema.Field - sch = db.Statement.Schema - reflectValue = db.Statement.ReflectValue + fields = make([]*schema.Field, len(columns)) + joinFields [][]*schema.Field + sch = db.Statement.Schema + reflectValue = db.Statement.ReflectValue ) if reflectValue.Kind() == reflect.Interface { @@ -193,29 +215,45 @@ func Scan(rows Rows, db *DB, mode ScanMode) { // Not Pluck if sch != nil { + matchedFieldCount := make(map[string]int, len(columns)) for idx, column := range columns { if field := sch.LookUpField(column); field != nil && field.Readable { - if curIndex, ok := selectedColumnsMap[column]; ok { - for fieldIndex, selectField := range sch.Fields[curIndex+1:] { + fields[idx] = field + if count, ok := matchedFieldCount[column]; ok { + // handle duplicate fields + for _, selectField := range sch.Fields { if selectField.DBName == column && selectField.Readable { - selectedColumnsMap[column] = curIndex + fieldIndex + 1 - fields[idx] = selectField - break + if count == 0 { + matchedFieldCount[column]++ + fields[idx] = selectField + break + } + count-- } } } else { - fields[idx] = field - selectedColumnsMap[column] = idx + matchedFieldCount[column] = 1 } - } else if names := strings.Split(column, "__"); len(names) > 1 { + } else if names := utils.SplitNestedRelationName(column); len(names) > 1 { // has nested relation if rel, ok := sch.Relationships.Relations[names[0]]; ok { - if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { + subNameCount := len(names) + // nested relation fields + relFields := make([]*schema.Field, 0, subNameCount-1) + relFields = append(relFields, rel.Field) + for _, name := range names[1 : subNameCount-1] { + rel = rel.FieldSchema.Relationships.Relations[name] + relFields = append(relFields, rel.Field) + } + // lastest name is raw dbname + dbName := names[subNameCount-1] + if field := rel.FieldSchema.LookUpField(dbName); field != nil && field.Readable { fields[idx] = field if len(joinFields) == 0 { - joinFields = make([][2]*schema.Field, len(columns)) + joinFields = make([][]*schema.Field, len(columns)) } - joinFields[idx] = [2]*schema.Field{rel.Field, field} + relFields = append(relFields, field) + joinFields[idx] = relFields continue } } @@ -229,11 +267,24 @@ func Scan(rows Rows, db *DB, mode ScanMode) { switch reflectValue.Kind() { case reflect.Slice, reflect.Array: - var elem reflect.Value + var ( + elem reflect.Value + isArrayKind = reflectValue.Kind() == reflect.Array + ) if !update || reflectValue.Len() == 0 { update = false - db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20)) + if isArrayKind { + db.Statement.ReflectValue.Set(reflect.Zero(reflectValue.Type())) + } else { + // if the slice cap is externally initialized, the externally initialized slice is directly used here + if reflectValue.Cap() == 0 { + db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20)) + } else { + reflectValue.SetLen(0) + db.Statement.ReflectValue.Set(reflectValue) + } + } } for initialized || rows.Next() { @@ -260,10 +311,15 @@ func Scan(rows Rows, db *DB, mode ScanMode) { db.scanIntoStruct(rows, elem, values, fields, joinFields) if !update { - if isPtr { - reflectValue = reflect.Append(reflectValue, elem) + if !isPtr { + elem = elem.Elem() + } + if isArrayKind { + if reflectValue.Len() >= int(db.RowsAffected) { + reflectValue.Index(int(db.RowsAffected - 1)).Set(elem) + } } else { - reflectValue = reflect.Append(reflectValue, elem.Elem()) + reflectValue = reflect.Append(reflectValue, elem) } } } diff --git a/schema/check.go b/schema/check.go deleted file mode 100644 index 89e732d3..00000000 --- a/schema/check.go +++ /dev/null @@ -1,35 +0,0 @@ -package schema - -import ( - "regexp" - "strings" -) - -// reg match english letters and midline -var regEnLetterAndMidline = regexp.MustCompile("^[A-Za-z-_]+$") - -type Check struct { - Name string - Constraint string // length(phone) >= 10 - *Field -} - -// ParseCheckConstraints parse schema check constraints -func (schema *Schema) ParseCheckConstraints() map[string]Check { - checks := map[string]Check{} - for _, field := range schema.FieldsByDBName { - if chk := field.TagSettings["CHECK"]; chk != "" { - names := strings.Split(chk, ",") - if len(names) > 1 && regEnLetterAndMidline.MatchString(names[0]) { - checks[names[0]] = Check{Name: names[0], Constraint: strings.Join(names[1:], ","), Field: field} - } else { - if names[0] == "" { - chk = strings.Join(names[1:], ",") - } - name := schema.namer.CheckerName(schema.Table, field.DBName) - checks[name] = Check{Name: name, Constraint: chk, Field: field} - } - } - } - return checks -} diff --git a/schema/constraint.go b/schema/constraint.go new file mode 100644 index 00000000..80a743a8 --- /dev/null +++ b/schema/constraint.go @@ -0,0 +1,66 @@ +package schema + +import ( + "regexp" + "strings" + + "gorm.io/gorm/clause" +) + +// reg match english letters and midline +var regEnLetterAndMidline = regexp.MustCompile(`^[\w-]+$`) + +type CheckConstraint struct { + Name string + Constraint string // length(phone) >= 10 + *Field +} + +func (chk *CheckConstraint) GetName() string { return chk.Name } + +func (chk *CheckConstraint) Build() (sql string, vars []interface{}) { + return "CONSTRAINT ? CHECK (?)", []interface{}{clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}} +} + +// ParseCheckConstraints parse schema check constraints +func (schema *Schema) ParseCheckConstraints() map[string]CheckConstraint { + checks := map[string]CheckConstraint{} + for _, field := range schema.FieldsByDBName { + if chk := field.TagSettings["CHECK"]; chk != "" { + names := strings.Split(chk, ",") + if len(names) > 1 && regEnLetterAndMidline.MatchString(names[0]) { + checks[names[0]] = CheckConstraint{Name: names[0], Constraint: strings.Join(names[1:], ","), Field: field} + } else { + if names[0] == "" { + chk = strings.Join(names[1:], ",") + } + name := schema.namer.CheckerName(schema.Table, field.DBName) + checks[name] = CheckConstraint{Name: name, Constraint: chk, Field: field} + } + } + } + return checks +} + +type UniqueConstraint struct { + Name string + Field *Field +} + +func (uni *UniqueConstraint) GetName() string { return uni.Name } + +func (uni *UniqueConstraint) Build() (sql string, vars []interface{}) { + return "CONSTRAINT ? UNIQUE (?)", []interface{}{clause.Column{Name: uni.Name}, clause.Column{Name: uni.Field.DBName}} +} + +// ParseUniqueConstraints parse schema unique constraints +func (schema *Schema) ParseUniqueConstraints() map[string]UniqueConstraint { + uniques := make(map[string]UniqueConstraint) + for _, field := range schema.Fields { + if field.Unique { + name := schema.namer.UniqueName(schema.Table, field.DBName) + uniques[name] = UniqueConstraint{Name: name, Field: field} + } + } + return uniques +} diff --git a/schema/check_test.go b/schema/constraint_test.go similarity index 59% rename from schema/check_test.go rename to schema/constraint_test.go index eda043b7..6fcb1b85 100644 --- a/schema/check_test.go +++ b/schema/constraint_test.go @@ -6,6 +6,7 @@ import ( "testing" "gorm.io/gorm/schema" + "gorm.io/gorm/utils/tests" ) type UserCheck struct { @@ -20,7 +21,7 @@ func TestParseCheck(t *testing.T) { t.Fatalf("failed to parse user check, got error %v", err) } - results := map[string]schema.Check{ + results := map[string]schema.CheckConstraint{ "name_checker": { Name: "name_checker", Constraint: "name <> 'jinzhu'", @@ -53,3 +54,31 @@ func TestParseCheck(t *testing.T) { } } } + +func TestParseUniqueConstraints(t *testing.T) { + type UserUnique struct { + Name1 string `gorm:"unique"` + Name2 string `gorm:"uniqueIndex"` + } + + user, err := schema.Parse(&UserUnique{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("failed to parse user unique, got error %v", err) + } + constraints := user.ParseUniqueConstraints() + + results := map[string]schema.UniqueConstraint{ + "uni_user_uniques_name1": { + Name: "uni_user_uniques_name1", + Field: &schema.Field{Name: "Name1", Unique: true}, + }, + } + for k, result := range results { + v, ok := constraints[k] + if !ok { + t.Errorf("Failed to found unique constraint %v from parsed constraints %+v", k, constraints) + } + tests.AssertObjEqual(t, result, v, "Name") + tests.AssertObjEqual(t, result.Field, v.Field, "Name", "Unique", "UniqueIndex") + } +} diff --git a/schema/field.go b/schema/field.go index d6df6596..ca2e1148 100644 --- a/schema/field.go +++ b/schema/field.go @@ -49,6 +49,8 @@ const ( Bytes DataType = "bytes" ) +const DefaultAutoIncrementIncrement int64 = 1 + // Field is the representation of model schema's field type Field struct { Name string @@ -87,6 +89,16 @@ type Field struct { Set func(context.Context, reflect.Value, interface{}) error Serializer SerializerInterface NewValuePool FieldNewValuePool + + // In some db (e.g. MySQL), Unique and UniqueIndex are indistinguishable. + // When a column has a (not Mul) UniqueIndex, Migrator always reports its gorm.ColumnType is Unique. + // It causes field unnecessarily migration. + // Therefore, we need to record the UniqueIndex on this column (exclude Mul UniqueIndex) for MigrateColumnUnique. + UniqueIndex string +} + +func (field *Field) BindName() string { + return strings.Join(field.BindNames, ".") } // ParseField parses reflect.StructField to Field @@ -115,7 +127,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { NotNull: utils.CheckTruth(tagSetting["NOT NULL"], tagSetting["NOTNULL"]), Unique: utils.CheckTruth(tagSetting["UNIQUE"]), Comment: tagSetting["COMMENT"], - AutoIncrementIncrement: 1, + AutoIncrementIncrement: DefaultAutoIncrementIncrement, } for field.IndirectFieldType.Kind() == reflect.Ptr { @@ -174,7 +186,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { field.DataType = String field.Serializer = v } else { - var serializerName = field.TagSettings["JSON"] + serializerName := field.TagSettings["JSON"] if serializerName == "" { serializerName = field.TagSettings["SERIALIZER"] } @@ -403,18 +415,14 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } if ef.PrimaryKey { - if val, ok := ef.TagSettings["PRIMARYKEY"]; ok && utils.CheckTruth(val) { - ef.PrimaryKey = true - } else if val, ok := ef.TagSettings["PRIMARY_KEY"]; ok && utils.CheckTruth(val) { - ef.PrimaryKey = true - } else { + if !utils.CheckTruth(ef.TagSettings["PRIMARYKEY"], ef.TagSettings["PRIMARY_KEY"]) { ef.PrimaryKey = false if val, ok := ef.TagSettings["AUTOINCREMENT"]; !ok || !utils.CheckTruth(val) { ef.AutoIncrement = false } - if ef.DefaultValue == "" { + if !ef.AutoIncrement && ef.DefaultValue == "" { ef.HasDefaultValue = false } } @@ -472,9 +480,6 @@ func (field *Field) setupValuerAndSetter() { oldValuerOf := field.ValueOf field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) { value, zero := oldValuerOf(ctx, v) - if zero { - return value, zero - } s, ok := value.(SerializerValuerInterface) if !ok { @@ -487,7 +492,7 @@ func (field *Field) setupValuerAndSetter() { Destination: v, Context: ctx, fieldValue: value, - }, false + }, zero } } @@ -607,6 +612,22 @@ func (field *Field) setupValuerAndSetter() { if data != nil && *data != nil { field.ReflectValueOf(ctx, value).SetInt(**data) } + case **int: + if data != nil && *data != nil { + field.ReflectValueOf(ctx, value).SetInt(int64(**data)) + } + case **int8: + if data != nil && *data != nil { + field.ReflectValueOf(ctx, value).SetInt(int64(**data)) + } + case **int16: + if data != nil && *data != nil { + field.ReflectValueOf(ctx, value).SetInt(int64(**data)) + } + case **int32: + if data != nil && *data != nil { + field.ReflectValueOf(ctx, value).SetInt(int64(**data)) + } case int64: field.ReflectValueOf(ctx, value).SetInt(data) case int: @@ -643,7 +664,7 @@ func (field *Field) setupValuerAndSetter() { if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond { field.ReflectValueOf(ctx, value).SetInt(data.UnixNano()) } else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond { - field.ReflectValueOf(ctx, value).SetInt(data.UnixNano() / 1e6) + field.ReflectValueOf(ctx, value).SetInt(data.UnixMilli()) } else { field.ReflectValueOf(ctx, value).SetInt(data.Unix()) } @@ -652,7 +673,7 @@ func (field *Field) setupValuerAndSetter() { if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond { field.ReflectValueOf(ctx, value).SetInt(data.UnixNano()) } else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond { - field.ReflectValueOf(ctx, value).SetInt(data.UnixNano() / 1e6) + field.ReflectValueOf(ctx, value).SetInt(data.UnixMilli()) } else { field.ReflectValueOf(ctx, value).SetInt(data.Unix()) } @@ -671,6 +692,22 @@ func (field *Field) setupValuerAndSetter() { if data != nil && *data != nil { field.ReflectValueOf(ctx, value).SetUint(**data) } + case **uint: + if data != nil && *data != nil { + field.ReflectValueOf(ctx, value).SetUint(uint64(**data)) + } + case **uint8: + if data != nil && *data != nil { + field.ReflectValueOf(ctx, value).SetUint(uint64(**data)) + } + case **uint16: + if data != nil && *data != nil { + field.ReflectValueOf(ctx, value).SetUint(uint64(**data)) + } + case **uint32: + if data != nil && *data != nil { + field.ReflectValueOf(ctx, value).SetUint(uint64(**data)) + } case uint64: field.ReflectValueOf(ctx, value).SetUint(data) case uint: @@ -701,7 +738,7 @@ func (field *Field) setupValuerAndSetter() { if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond { field.ReflectValueOf(ctx, value).SetUint(uint64(data.UnixNano())) } else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond { - field.ReflectValueOf(ctx, value).SetUint(uint64(data.UnixNano() / 1e6)) + field.ReflectValueOf(ctx, value).SetUint(uint64(data.UnixMilli())) } else { field.ReflectValueOf(ctx, value).SetUint(uint64(data.Unix())) } @@ -723,6 +760,10 @@ func (field *Field) setupValuerAndSetter() { if data != nil && *data != nil { field.ReflectValueOf(ctx, value).SetFloat(**data) } + case **float32: + if data != nil && *data != nil { + field.ReflectValueOf(ctx, value).SetFloat(float64(**data)) + } case float64: field.ReflectValueOf(ctx, value).SetFloat(data) case float32: @@ -813,7 +854,7 @@ func (field *Field) setupValuerAndSetter() { field.Set = func(ctx context.Context, value reflect.Value, v interface{}) error { switch data := v.(type) { case **time.Time: - if data != nil { + if data != nil && *data != nil { field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(*data)) } case time.Time: @@ -849,14 +890,12 @@ func (field *Field) setupValuerAndSetter() { reflectV := reflect.ValueOf(v) if !reflectV.IsValid() { field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem()) + } else if reflectV.Kind() == reflect.Ptr && reflectV.IsNil() { + return } else if reflectV.Type().AssignableTo(field.FieldType) { field.ReflectValueOf(ctx, value).Set(reflectV) } else if reflectV.Kind() == reflect.Ptr { - if reflectV.IsNil() || !reflectV.IsValid() { - field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem()) - } else { - return field.Set(ctx, value, reflectV.Elem().Interface()) - } + return field.Set(ctx, value, reflectV.Elem().Interface()) } else { fieldValue := field.ReflectValueOf(ctx, value) if fieldValue.IsNil() { @@ -877,14 +916,12 @@ func (field *Field) setupValuerAndSetter() { reflectV := reflect.ValueOf(v) if !reflectV.IsValid() { field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem()) + } else if reflectV.Kind() == reflect.Ptr && reflectV.IsNil() { + return } else if reflectV.Type().AssignableTo(field.FieldType) { field.ReflectValueOf(ctx, value).Set(reflectV) } else if reflectV.Kind() == reflect.Ptr { - if reflectV.IsNil() || !reflectV.IsValid() { - field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem()) - } else { - return field.Set(ctx, value, reflectV.Elem().Interface()) - } + return field.Set(ctx, value, reflectV.Elem().Interface()) } else { if valuer, ok := v.(driver.Valuer); ok { v, _ = valuer.Value() @@ -913,6 +950,8 @@ func (field *Field) setupValuerAndSetter() { sameElemType = field.FieldType == reflect.ValueOf(field.Serializer).Type().Elem() } + serializerValue := reflect.Indirect(reflect.ValueOf(field.Serializer)) + serializerType := serializerValue.Type() field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) { if s, ok := v.(*serializer); ok { if s.fieldValue != nil { @@ -920,11 +959,12 @@ func (field *Field) setupValuerAndSetter() { } else if err = s.Serializer.Scan(ctx, field, value, s.value); err == nil { if sameElemType { field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(s.Serializer).Elem()) - s.Serializer = reflect.New(reflect.Indirect(reflect.ValueOf(field.Serializer)).Type()).Interface().(SerializerInterface) } else if sameType { field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(s.Serializer)) - s.Serializer = reflect.New(reflect.Indirect(reflect.ValueOf(field.Serializer)).Type()).Interface().(SerializerInterface) } + si := reflect.New(serializerType) + si.Elem().Set(serializerValue) + s.Serializer = si.Interface().(SerializerInterface) } } else { err = oldFieldSetter(ctx, value, v) @@ -936,11 +976,15 @@ func (field *Field) setupValuerAndSetter() { func (field *Field) setupNewValuePool() { if field.Serializer != nil { + serializerValue := reflect.Indirect(reflect.ValueOf(field.Serializer)) + serializerType := serializerValue.Type() field.NewValuePool = &sync.Pool{ New: func() interface{} { + si := reflect.New(serializerType) + si.Elem().Set(serializerValue) return &serializer{ Field: field, - Serializer: reflect.New(reflect.Indirect(reflect.ValueOf(field.Serializer)).Type()).Interface().(SerializerInterface), + Serializer: si.Interface().(SerializerInterface), } }, } diff --git a/schema/index.go b/schema/index.go index 5003c742..f4f36751 100644 --- a/schema/index.go +++ b/schema/index.go @@ -13,8 +13,8 @@ type Index struct { Type string // btree, hash, gist, spgist, gin, and brin Where string Comment string - Option string // WITH PARSER parser_name - Fields []IndexOption + Option string // WITH PARSER parser_name + Fields []IndexOption // Note: IndexOption's Field maybe the same } type IndexOption struct { @@ -65,7 +65,11 @@ func (schema *Schema) ParseIndexes() map[string]Index { } } } - + for _, index := range indexes { + if index.Class == "UNIQUE" && len(index.Fields) == 1 { + index.Fields[0].Field.UniqueIndex = index.Name + } + } return indexes } diff --git a/schema/index_test.go b/schema/index_test.go index 1fe31cc1..2f1e36af 100644 --- a/schema/index_test.go +++ b/schema/index_test.go @@ -1,11 +1,11 @@ package schema_test import ( - "reflect" "sync" "testing" "gorm.io/gorm/schema" + "gorm.io/gorm/utils/tests" ) type UserIndex struct { @@ -19,6 +19,7 @@ type UserIndex struct { OID int64 `gorm:"index:idx_id;index:idx_oid,unique"` MemberNumber string `gorm:"index:idx_id,priority:1"` Name7 string `gorm:"index:type"` + Name8 string `gorm:"index:,length:10;index:,collate:utf8"` // Composite Index: Flattened structure. Data0A string `gorm:"index:,composite:comp_id0"` @@ -65,7 +66,7 @@ func TestParseIndex(t *testing.T) { "idx_name": { Name: "idx_name", Class: "UNIQUE", - Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name2"}}}, + Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name2", UniqueIndex: "idx_name"}}}, }, "idx_user_indices_name3": { Name: "idx_user_indices_name3", @@ -81,7 +82,7 @@ func TestParseIndex(t *testing.T) { "idx_user_indices_name4": { Name: "idx_user_indices_name4", Class: "UNIQUE", - Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name4"}}}, + Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name4", UniqueIndex: "idx_user_indices_name4"}}}, }, "idx_user_indices_name5": { Name: "idx_user_indices_name5", @@ -102,18 +103,27 @@ func TestParseIndex(t *testing.T) { }, "idx_id": { Name: "idx_id", - Fields: []schema.IndexOption{{Field: &schema.Field{Name: "MemberNumber"}}, {Field: &schema.Field{Name: "OID"}}}, + Fields: []schema.IndexOption{{Field: &schema.Field{Name: "MemberNumber"}}, {Field: &schema.Field{Name: "OID", UniqueIndex: "idx_oid"}}}, }, "idx_oid": { Name: "idx_oid", Class: "UNIQUE", - Fields: []schema.IndexOption{{Field: &schema.Field{Name: "OID"}}}, + Fields: []schema.IndexOption{{Field: &schema.Field{Name: "OID", UniqueIndex: "idx_oid"}}}, }, "type": { Name: "type", Type: "", Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name7"}}}, }, + "idx_user_indices_name8": { + Name: "idx_user_indices_name8", + Type: "", + Fields: []schema.IndexOption{ + {Field: &schema.Field{Name: "Name8"}, Length: 10}, + // Note: Duplicate Columns + {Field: &schema.Field{Name: "Name8"}, Collate: "utf8"}, + }, + }, "idx_user_indices_comp_id0": { Name: "idx_user_indices_comp_id0", Type: "", @@ -146,37 +156,109 @@ func TestParseIndex(t *testing.T) { }, } - indices := user.ParseIndexes() + CheckIndices(t, results, user.ParseIndexes()) +} - for k, result := range results { - v, ok := indices[k] - if !ok { - t.Fatalf("Failed to found index %v from parsed indices %+v", k, indices) - } +func TestParseIndexWithUniqueIndexAndUnique(t *testing.T) { + type IndexTest struct { + FieldA string `gorm:"unique;index"` // unique and index + FieldB string `gorm:"unique"` // unique - for _, name := range []string{"Name", "Class", "Type", "Where", "Comment", "Option"} { - if reflect.ValueOf(result).FieldByName(name).Interface() != reflect.ValueOf(v).FieldByName(name).Interface() { - t.Errorf( - "index %v %v should equal, expects %v, got %v", - k, name, reflect.ValueOf(result).FieldByName(name).Interface(), reflect.ValueOf(v).FieldByName(name).Interface(), - ) + FieldC string `gorm:"index:,unique"` // uniqueIndex + FieldD string `gorm:"uniqueIndex;index"` // uniqueIndex and index + + FieldE1 string `gorm:"uniqueIndex:uniq_field_e1_e2"` // mul uniqueIndex + FieldE2 string `gorm:"uniqueIndex:uniq_field_e1_e2"` + + FieldF1 string `gorm:"uniqueIndex:uniq_field_f1_f2;index"` // mul uniqueIndex and index + FieldF2 string `gorm:"uniqueIndex:uniq_field_f1_f2;"` + + FieldG string `gorm:"unique;uniqueIndex"` // unique and uniqueIndex + + FieldH1 string `gorm:"unique;uniqueIndex:uniq_field_h1_h2"` // unique and mul uniqueIndex + FieldH2 string `gorm:"uniqueIndex:uniq_field_h1_h2"` // unique and mul uniqueIndex + } + indexSchema, err := schema.Parse(&IndexTest{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("failed to parse user index, got error %v", err) + } + indices := indexSchema.ParseIndexes() + CheckIndices(t, map[string]schema.Index{ + "idx_index_tests_field_a": { + Name: "idx_index_tests_field_a", + Fields: []schema.IndexOption{{Field: &schema.Field{Name: "FieldA", Unique: true}}}, + }, + "idx_index_tests_field_c": { + Name: "idx_index_tests_field_c", + Class: "UNIQUE", + Fields: []schema.IndexOption{{Field: &schema.Field{Name: "FieldC", UniqueIndex: "idx_index_tests_field_c"}}}, + }, + "idx_index_tests_field_d": { + Name: "idx_index_tests_field_d", + Class: "UNIQUE", + Fields: []schema.IndexOption{ + {Field: &schema.Field{Name: "FieldD"}}, + // Note: Duplicate Columns + {Field: &schema.Field{Name: "FieldD"}}, + }, + }, + "uniq_field_e1_e2": { + Name: "uniq_field_e1_e2", + Class: "UNIQUE", + Fields: []schema.IndexOption{ + {Field: &schema.Field{Name: "FieldE1"}}, + {Field: &schema.Field{Name: "FieldE2"}}, + }, + }, + "idx_index_tests_field_f1": { + Name: "idx_index_tests_field_f1", + Fields: []schema.IndexOption{{Field: &schema.Field{Name: "FieldF1"}}}, + }, + "uniq_field_f1_f2": { + Name: "uniq_field_f1_f2", + Class: "UNIQUE", + Fields: []schema.IndexOption{ + {Field: &schema.Field{Name: "FieldF1"}}, + {Field: &schema.Field{Name: "FieldF2"}}, + }, + }, + "idx_index_tests_field_g": { + Name: "idx_index_tests_field_g", + Class: "UNIQUE", + Fields: []schema.IndexOption{{Field: &schema.Field{Name: "FieldG", Unique: true, UniqueIndex: "idx_index_tests_field_g"}}}, + }, + "uniq_field_h1_h2": { + Name: "uniq_field_h1_h2", + Class: "UNIQUE", + Fields: []schema.IndexOption{ + {Field: &schema.Field{Name: "FieldH1", Unique: true}}, + {Field: &schema.Field{Name: "FieldH2"}}, + }, + }, + }, indices) +} + +func CheckIndices(t *testing.T, expected, actual map[string]schema.Index) { + for k, ei := range expected { + t.Run(k, func(t *testing.T) { + ai, ok := actual[k] + if !ok { + t.Errorf("expected index %q but actual missing", k) + return } - } - - for idx, ef := range result.Fields { - rf := v.Fields[idx] - if rf.Field.Name != ef.Field.Name { - t.Fatalf("index field should equal, expects %v, got %v", rf.Field.Name, ef.Field.Name) + tests.AssertObjEqual(t, ai, ei, "Name", "Class", "Type", "Where", "Comment", "Option") + if len(ei.Fields) != len(ai.Fields) { + t.Errorf("expected index %q field length is %d but actual %d", k, len(ei.Fields), len(ai.Fields)) + return } - - for _, name := range []string{"Expression", "Sort", "Collate", "Length"} { - if reflect.ValueOf(ef).FieldByName(name).Interface() != reflect.ValueOf(rf).FieldByName(name).Interface() { - t.Errorf( - "index %v field #%v's %v should equal, expects %v, got %v", k, idx+1, name, - reflect.ValueOf(ef).FieldByName(name).Interface(), reflect.ValueOf(rf).FieldByName(name).Interface(), - ) - } + for i, ef := range ei.Fields { + af := ai.Fields[i] + tests.AssertObjEqual(t, af, ef, "Name", "Unique", "UniqueIndex", "Expression", "Sort", "Collate", "Length") } - } + }) + delete(actual, k) + } + for k := range actual { + t.Errorf("unexpected index %q", k) } } diff --git a/schema/interfaces.go b/schema/interfaces.go index a75a33c0..306d4f4e 100644 --- a/schema/interfaces.go +++ b/schema/interfaces.go @@ -4,6 +4,12 @@ import ( "gorm.io/gorm/clause" ) +// ConstraintInterface database constraint interface +type ConstraintInterface interface { + GetName() string + Build() (sql string, vars []interface{}) +} + // GormDataTypeInterface gorm data type interface type GormDataTypeInterface interface { GormDataType() string diff --git a/schema/naming.go b/schema/naming.go index a258beed..e6fb81b2 100644 --- a/schema/naming.go +++ b/schema/naming.go @@ -19,6 +19,7 @@ type Namer interface { RelationshipFKName(Relationship) string CheckerName(table, column string) string IndexName(table, column string) string + UniqueName(table, column string) string } // Replacer replacer interface like strings.Replacer @@ -26,12 +27,15 @@ type Replacer interface { Replace(name string) string } +var _ Namer = (*NamingStrategy)(nil) + // NamingStrategy tables, columns naming strategy type NamingStrategy struct { - TablePrefix string - SingularTable bool - NameReplacer Replacer - NoLowerCase bool + TablePrefix string + SingularTable bool + NameReplacer Replacer + NoLowerCase bool + IdentifierMaxLength int } // TableName convert string to table name @@ -84,17 +88,26 @@ func (ns NamingStrategy) IndexName(table, column string) string { return ns.formatName("idx", table, ns.toDBName(column)) } +// UniqueName generate unique constraint name +func (ns NamingStrategy) UniqueName(table, column string) string { + return ns.formatName("uni", table, ns.toDBName(column)) +} + func (ns NamingStrategy) formatName(prefix, table, name string) string { formattedName := strings.ReplaceAll(strings.Join([]string{ prefix, table, name, }, "_"), ".", "_") - if utf8.RuneCountInString(formattedName) > 64 { + if ns.IdentifierMaxLength == 0 { + ns.IdentifierMaxLength = 64 + } + + if utf8.RuneCountInString(formattedName) > ns.IdentifierMaxLength { h := sha1.New() h.Write([]byte(formattedName)) bs := h.Sum(nil) - formattedName = formattedName[0:56] + hex.EncodeToString(bs)[:8] + formattedName = formattedName[0:ns.IdentifierMaxLength-8] + hex.EncodeToString(bs)[:8] } return formattedName } diff --git a/schema/naming_test.go b/schema/naming_test.go index 3f598c33..ab7a5e31 100644 --- a/schema/naming_test.go +++ b/schema/naming_test.go @@ -189,8 +189,17 @@ func TestCustomReplacerWithNoLowerCase(t *testing.T) { } } +func TestFormatNameWithStringLongerThan63Characters(t *testing.T) { + ns := NamingStrategy{IdentifierMaxLength: 63} + + formattedName := ns.formatName("prefix", "table", "thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVeryLongString") + if formattedName != "prefix_table_thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVer180f2c67" { + t.Errorf("invalid formatted name generated, got %v", formattedName) + } +} + func TestFormatNameWithStringLongerThan64Characters(t *testing.T) { - ns := NamingStrategy{} + ns := NamingStrategy{IdentifierMaxLength: 64} formattedName := ns.formatName("prefix", "table", "thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVeryLongString") if formattedName != "prefix_table_thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVery180f2c67" { diff --git a/schema/relationship.go b/schema/relationship.go index b5100897..2e94fc2c 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -27,6 +27,8 @@ type Relationships struct { HasMany []*Relationship Many2Many []*Relationship Relations map[string]*Relationship + + EmbeddedRelations map[string]*Relationships } type Relationship struct { @@ -74,8 +76,8 @@ func (schema *Schema) parseRelation(field *Field) *Relationship { return nil } - if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" { - schema.buildPolymorphicRelation(relation, field, polymorphic) + if hasPolymorphicRelation(field.TagSettings) { + schema.buildPolymorphicRelation(relation, field) } else if many2many := field.TagSettings["MANY2MANY"]; many2many != "" { schema.buildMany2ManyRelation(relation, field, many2many) } else if belongsTo := field.TagSettings["BELONGSTO"]; belongsTo != "" { @@ -87,7 +89,8 @@ func (schema *Schema) parseRelation(field *Field) *Relationship { case reflect.Slice: schema.guessRelation(relation, field, guessHas) default: - schema.err = fmt.Errorf("unsupported data type %v for %v on field %s", relation.FieldSchema, schema, field.Name) + schema.err = fmt.Errorf("unsupported data type %v for %v on field %s", relation.FieldSchema, schema, + field.Name) } } @@ -106,7 +109,7 @@ func (schema *Schema) parseRelation(field *Field) *Relationship { } if schema.err == nil { - schema.Relationships.Relations[relation.Name] = relation + schema.setRelation(relation) switch relation.Type { case HasOne: schema.Relationships.HasOne = append(schema.Relationships.HasOne, relation) @@ -122,34 +125,100 @@ func (schema *Schema) parseRelation(field *Field) *Relationship { return relation } -// User has many Toys, its `Polymorphic` is `Owner`, Pet has one Toy, its `Polymorphic` is `Owner` -// type User struct { -// Toys []Toy `gorm:"polymorphic:Owner;"` -// } -// type Pet struct { -// Toy Toy `gorm:"polymorphic:Owner;"` -// } -// type Toy struct { -// OwnerID int -// OwnerType string -// } -func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Field, polymorphic string) { - relation.Polymorphic = &Polymorphic{ - Value: schema.Table, - PolymorphicType: relation.FieldSchema.FieldsByName[polymorphic+"Type"], - PolymorphicID: relation.FieldSchema.FieldsByName[polymorphic+"ID"], +// hasPolymorphicRelation check if has polymorphic relation +// 1. `POLYMORPHIC` tag +// 2. `POLYMORPHICTYPE` and `POLYMORPHICID` tag +func hasPolymorphicRelation(tagSettings map[string]string) bool { + if _, ok := tagSettings["POLYMORPHIC"]; ok { + return true } + _, hasType := tagSettings["POLYMORPHICTYPE"] + _, hasId := tagSettings["POLYMORPHICID"] + + return hasType && hasId +} + +func (schema *Schema) setRelation(relation *Relationship) { + // set non-embedded relation + if rel := schema.Relationships.Relations[relation.Name]; rel != nil { + if len(rel.Field.BindNames) > 1 { + schema.Relationships.Relations[relation.Name] = relation + } + } else { + schema.Relationships.Relations[relation.Name] = relation + } + + // set embedded relation + if len(relation.Field.BindNames) <= 1 { + return + } + relationships := &schema.Relationships + for i, name := range relation.Field.BindNames { + if i < len(relation.Field.BindNames)-1 { + if relationships.EmbeddedRelations == nil { + relationships.EmbeddedRelations = map[string]*Relationships{} + } + if r := relationships.EmbeddedRelations[name]; r == nil { + relationships.EmbeddedRelations[name] = &Relationships{} + } + relationships = relationships.EmbeddedRelations[name] + } else { + if relationships.Relations == nil { + relationships.Relations = map[string]*Relationship{} + } + relationships.Relations[relation.Name] = relation + } + } +} + +// User has many Toys, its `Polymorphic` is `Owner`, Pet has one Toy, its `Polymorphic` is `Owner` +// +// type User struct { +// Toys []Toy `gorm:"polymorphic:Owner;"` +// } +// type Pet struct { +// Toy Toy `gorm:"polymorphic:Owner;"` +// } +// type Toy struct { +// OwnerID int +// OwnerType string +// } +func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Field) { + polymorphic := field.TagSettings["POLYMORPHIC"] + + relation.Polymorphic = &Polymorphic{ + Value: schema.Table, + } + + var ( + typeName = polymorphic + "Type" + typeId = polymorphic + "ID" + ) + + if value, ok := field.TagSettings["POLYMORPHICTYPE"]; ok { + typeName = strings.TrimSpace(value) + } + + if value, ok := field.TagSettings["POLYMORPHICID"]; ok { + typeId = strings.TrimSpace(value) + } + + relation.Polymorphic.PolymorphicType = relation.FieldSchema.FieldsByName[typeName] + relation.Polymorphic.PolymorphicID = relation.FieldSchema.FieldsByName[typeId] + if value, ok := field.TagSettings["POLYMORPHICVALUE"]; ok { relation.Polymorphic.Value = strings.TrimSpace(value) } if relation.Polymorphic.PolymorphicType == nil { - schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s", relation.FieldSchema, schema, field.Name, polymorphic+"Type") + schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s", + relation.FieldSchema, schema, field.Name, polymorphic+"Type") } if relation.Polymorphic.PolymorphicID == nil { - schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s", relation.FieldSchema, schema, field.Name, polymorphic+"ID") + schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s", + relation.FieldSchema, schema, field.Name, polymorphic+"ID") } if schema.err == nil { @@ -161,10 +230,17 @@ func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Fi primaryKeyField := schema.PrioritizedPrimaryField if len(relation.foreignKeys) > 0 { if primaryKeyField = schema.LookUpField(relation.foreignKeys[0]); primaryKeyField == nil || len(relation.foreignKeys) > 1 { - schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %s", relation.foreignKeys, schema, field.Name) + schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %s", relation.foreignKeys, + schema, field.Name) } } + if primaryKeyField == nil { + schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing primaryKey field", + relation.FieldSchema, schema, field.Name) + return + } + // use same data type for foreign keys if copyableDataType(primaryKeyField.DataType) { relation.Polymorphic.PolymorphicID.DataType = primaryKeyField.DataType @@ -191,7 +267,8 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel err error joinTableFields []reflect.StructField fieldsMap = map[string]*Field{} - ownFieldsMap = map[string]bool{} // fix self join many2many + ownFieldsMap = map[string]*Field{} // fix self join many2many + referFieldsMap = map[string]*Field{} joinForeignKeys = toColumns(field.TagSettings["JOINFOREIGNKEY"]) joinReferences = toColumns(field.TagSettings["JOINREFERENCES"]) ) @@ -229,21 +306,19 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel joinFieldName = strings.Title(joinForeignKeys[idx]) } - ownFieldsMap[joinFieldName] = true + ownFieldsMap[joinFieldName] = ownField fieldsMap[joinFieldName] = ownField joinTableFields = append(joinTableFields, reflect.StructField{ Name: joinFieldName, PkgPath: ownField.StructField.PkgPath, Type: ownField.StructField.Type, - Tag: removeSettingFromTag(ownField.StructField.Tag, "column", "autoincrement", "index", "unique", "uniqueindex"), + Tag: removeSettingFromTag(appendSettingFromTag(ownField.StructField.Tag, "primaryKey"), + "column", "autoincrement", "index", "unique", "uniqueindex"), }) } for idx, relField := range refForeignFields { joinFieldName := strings.Title(relation.FieldSchema.Name) + relField.Name - if len(joinReferences) > idx { - joinFieldName = strings.Title(joinReferences[idx]) - } if _, ok := ownFieldsMap[joinFieldName]; ok { if field.Name != relation.FieldSchema.Name { @@ -253,13 +328,22 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel } } - fieldsMap[joinFieldName] = relField - joinTableFields = append(joinTableFields, reflect.StructField{ - Name: joinFieldName, - PkgPath: relField.StructField.PkgPath, - Type: relField.StructField.Type, - Tag: removeSettingFromTag(relField.StructField.Tag, "column", "autoincrement", "index", "unique", "uniqueindex"), - }) + if len(joinReferences) > idx { + joinFieldName = strings.Title(joinReferences[idx]) + } + + referFieldsMap[joinFieldName] = relField + + if _, ok := fieldsMap[joinFieldName]; !ok { + fieldsMap[joinFieldName] = relField + joinTableFields = append(joinTableFields, reflect.StructField{ + Name: joinFieldName, + PkgPath: relField.StructField.PkgPath, + Type: relField.StructField.Type, + Tag: removeSettingFromTag(appendSettingFromTag(relField.StructField.Tag, "primaryKey"), + "column", "autoincrement", "index", "unique", "uniqueindex"), + }) + } } joinTableFields = append(joinTableFields, reflect.StructField{ @@ -268,7 +352,8 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel Tag: `gorm:"-"`, }) - if relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, schema.namer); err != nil { + if relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, + schema.namer); err != nil { schema.err = err } relation.JoinTable.Name = many2many @@ -315,31 +400,37 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel f.Size = fieldsMap[f.Name].Size } relation.JoinTable.PrimaryFields = append(relation.JoinTable.PrimaryFields, f) - ownPrimaryField := schema == fieldsMap[f.Name].Schema && ownFieldsMap[f.Name] - if ownPrimaryField { + if of, ok := ownFieldsMap[f.Name]; ok { joinRel := relation.JoinTable.Relationships.Relations[relName] joinRel.Field = relation.Field joinRel.References = append(joinRel.References, &Reference{ - PrimaryKey: fieldsMap[f.Name], + PrimaryKey: of, ForeignKey: f, }) - } else { + + relation.References = append(relation.References, &Reference{ + PrimaryKey: of, + ForeignKey: f, + OwnPrimaryKey: true, + }) + } + + if rf, ok := referFieldsMap[f.Name]; ok { joinRefRel := relation.JoinTable.Relationships.Relations[relRefName] if joinRefRel.Field == nil { joinRefRel.Field = relation.Field } joinRefRel.References = append(joinRefRel.References, &Reference{ - PrimaryKey: fieldsMap[f.Name], + PrimaryKey: rf, + ForeignKey: f, + }) + + relation.References = append(relation.References, &Reference{ + PrimaryKey: rf, ForeignKey: f, }) } - - relation.References = append(relation.References, &Reference{ - PrimaryKey: fieldsMap[f.Name], - ForeignKey: f, - OwnPrimaryKey: ownPrimaryField, - }) } } } @@ -381,7 +472,8 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu schema.guessRelation(relation, field, guessEmbeddedHas) // case guessEmbeddedHas: default: - schema.err = fmt.Errorf("invalid field found for struct %v's field %s: define a valid foreign key for relations or implement the Valuer/Scanner interface", schema, field.Name) + schema.err = fmt.Errorf("invalid field found for struct %v's field %s: define a valid foreign key for relations or implement the Valuer/Scanner interface", + schema, field.Name) } } @@ -389,34 +481,31 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu case guessBelongs: primarySchema, foreignSchema = relation.FieldSchema, schema case guessEmbeddedBelongs: - if field.OwnerSchema != nil { - primarySchema, foreignSchema = relation.FieldSchema, field.OwnerSchema - } else { + if field.OwnerSchema == nil { reguessOrErr() return } + primarySchema, foreignSchema = relation.FieldSchema, field.OwnerSchema case guessHas: case guessEmbeddedHas: - if field.OwnerSchema != nil { - primarySchema, foreignSchema = field.OwnerSchema, relation.FieldSchema - } else { + if field.OwnerSchema == nil { reguessOrErr() return } + primarySchema, foreignSchema = field.OwnerSchema, relation.FieldSchema } if len(relation.foreignKeys) > 0 { for _, foreignKey := range relation.foreignKeys { - if f := foreignSchema.LookUpField(foreignKey); f != nil { - foreignFields = append(foreignFields, f) - } else { + f := foreignSchema.LookUpField(foreignKey) + if f == nil { reguessOrErr() return } + foreignFields = append(foreignFields, f) } } else { - var primaryFields []*Field - var primarySchemaName = primarySchema.Name + primarySchemaName := primarySchema.Name if primarySchemaName == "" { primarySchemaName = relation.FieldSchema.Name } @@ -431,6 +520,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu primaryFields = primarySchema.PrimaryFields } + primaryFieldLoop: for _, primaryField := range primaryFields { lookUpName := primarySchemaName + primaryField.Name if gl == guessBelongs { @@ -439,23 +529,33 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu lookUpNames := []string{lookUpName} if len(primaryFields) == 1 { - lookUpNames = append(lookUpNames, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID", strings.TrimSuffix(lookUpName, primaryField.Name)+"Id", schema.namer.ColumnName(foreignSchema.Table, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID")) + lookUpNames = append(lookUpNames, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID", + strings.TrimSuffix(lookUpName, primaryField.Name)+"Id", schema.namer.ColumnName(foreignSchema.Table, + strings.TrimSuffix(lookUpName, primaryField.Name)+"ID")) } + for _, name := range lookUpNames { + if f := foreignSchema.LookUpFieldByBindName(field.BindNames, name); f != nil { + foreignFields = append(foreignFields, f) + primaryFields = append(primaryFields, primaryField) + continue primaryFieldLoop + } + } for _, name := range lookUpNames { if f := foreignSchema.LookUpField(name); f != nil { foreignFields = append(foreignFields, f) primaryFields = append(primaryFields, primaryField) - break + continue primaryFieldLoop } } } } - if len(foreignFields) == 0 { + switch { + case len(foreignFields) == 0: reguessOrErr() return - } else if len(relation.primaryKeys) > 0 { + case len(relation.primaryKeys) > 0: for idx, primaryKey := range relation.primaryKeys { if f := primarySchema.LookUpField(primaryKey); f != nil { if len(primaryFields) < idx+1 { @@ -469,7 +569,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu return } } - } else if len(primaryFields) == 0 { + case len(primaryFields) == 0: if len(foreignFields) == 1 && primarySchema.PrioritizedPrimaryField != nil { primaryFields = append(primaryFields, primarySchema.PrioritizedPrimaryField) } else if len(primarySchema.PrimaryFields) == len(foreignFields) { @@ -505,6 +605,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu } } +// Constraint is ForeignKey Constraint type Constraint struct { Name string Field *Field @@ -516,6 +617,31 @@ type Constraint struct { OnUpdate string } +func (constraint *Constraint) GetName() string { return constraint.Name } + +func (constraint *Constraint) Build() (sql string, vars []interface{}) { + sql = "CONSTRAINT ? FOREIGN KEY ? REFERENCES ??" + if constraint.OnDelete != "" { + sql += " ON DELETE " + constraint.OnDelete + } + + if constraint.OnUpdate != "" { + sql += " ON UPDATE " + constraint.OnUpdate + } + + foreignKeys := make([]interface{}, 0, len(constraint.ForeignKeys)) + for _, field := range constraint.ForeignKeys { + foreignKeys = append(foreignKeys, clause.Column{Name: field.DBName}) + } + + references := make([]interface{}, 0, len(constraint.References)) + for _, field := range constraint.References { + references = append(references, clause.Column{Name: field.DBName}) + } + vars = append(vars, clause.Table{Name: constraint.Name}, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references) + return +} + func (rel *Relationship) ParseConstraint() *Constraint { str := rel.Field.TagSettings["CONSTRAINT"] if str == "-" { diff --git a/schema/relationship_test.go b/schema/relationship_test.go index 6fffbfcb..23d79bbb 100644 --- a/schema/relationship_test.go +++ b/schema/relationship_test.go @@ -10,7 +10,7 @@ import ( func checkStructRelation(t *testing.T, data interface{}, relations ...Relation) { if s, err := schema.Parse(data, &sync.Map{}, schema.NamingStrategy{}); err != nil { - t.Errorf("Failed to parse schema") + t.Errorf("Failed to parse schema, got error %v", err) } else { for _, rel := range relations { checkSchemaRelation(t, s, rel) @@ -305,6 +305,33 @@ func TestMany2ManyOverrideForeignKey(t *testing.T) { }) } +func TestMany2ManySharedForeignKey(t *testing.T) { + type Profile struct { + gorm.Model + Name string + Kind string + ProfileRefer uint + } + + type User struct { + gorm.Model + Profiles []Profile `gorm:"many2many:user_profiles;foreignKey:Refer,Kind;joinForeignKey:UserRefer,Kind;References:ProfileRefer,Kind;joinReferences:ProfileR,Kind"` + Kind string + Refer uint + } + + checkStructRelation(t, &User{}, Relation{ + Name: "Profiles", Type: schema.Many2Many, Schema: "User", FieldSchema: "Profile", + JoinTable: JoinTable{Name: "user_profiles", Table: "user_profiles"}, + References: []Reference{ + {"Refer", "User", "UserRefer", "user_profiles", "", true}, + {"Kind", "User", "Kind", "user_profiles", "", true}, + {"ProfileRefer", "Profile", "ProfileR", "user_profiles", "", false}, + {"Kind", "Profile", "Kind", "user_profiles", "", false}, + }, + }) +} + func TestMany2ManyOverrideJoinForeignKey(t *testing.T) { type Profile struct { gorm.Model @@ -491,6 +518,319 @@ func TestEmbeddedRelation(t *testing.T) { } } +func TestEmbeddedHas(t *testing.T) { + type Toy struct { + ID int + Name string + OwnerID int + OwnerType string + } + type User struct { + ID int + Cat struct { + Name string + Toy Toy `gorm:"polymorphic:Owner;"` + Toys []Toy `gorm:"polymorphic:Owner;"` + } `gorm:"embedded;embeddedPrefix:cat_"` + Dog struct { + ID int + Name string + UserID int + Toy Toy `gorm:"polymorphic:Owner;"` + Toys []Toy `gorm:"polymorphic:Owner;"` + } + Toys []Toy `gorm:"polymorphic:Owner;"` + } + + s, err := schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("Failed to parse schema, got error %v", err) + } + + checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{ + "Cat": { + Relations: map[string]Relation{ + "Toy": { + Name: "Toy", + Type: schema.HasOne, + Schema: "User", + FieldSchema: "Toy", + Polymorphic: Polymorphic{ID: "OwnerID", Type: "OwnerType", Value: "users"}, + References: []Reference{ + {ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"}, + {ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"}, + }, + }, + "Toys": { + Name: "Toys", + Type: schema.HasMany, + Schema: "User", + FieldSchema: "Toy", + Polymorphic: Polymorphic{ID: "OwnerID", Type: "OwnerType", Value: "users"}, + References: []Reference{ + {ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"}, + {ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"}, + }, + }, + }, + }, + }) +} + +func TestPolymorphic(t *testing.T) { + t.Run("has one", func(t *testing.T) { + type Toy struct { + ID int + Name string + OwnerID int + OwnerType string + } + + type Cat struct { + ID int + Name string + Toy Toy `gorm:"polymorphic:Owner;"` + } + + s, err := schema.Parse(&Cat{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("Failed to parse schema, got error %v", err) + } + + checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{ + "Cat": { + Relations: map[string]Relation{ + "Toy": { + Name: "Toy", + Type: schema.HasOne, + Schema: "User", + FieldSchema: "Toy", + Polymorphic: Polymorphic{ID: "OwnerID", Type: "OwnerType", Value: "users"}, + References: []Reference{ + {ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"}, + }, + }, + }, + }, + }) + }) + + t.Run("has one with custom polymorphic type and id", func(t *testing.T) { + type Toy struct { + ID int + Name string + RefId int + Type string + } + + type Cat struct { + ID int + Name string + Toy Toy `gorm:"polymorphic:Owner;polymorphicType:Type;polymorphicId:RefId"` + } + + s, err := schema.Parse(&Cat{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("Failed to parse schema, got error %v", err) + } + + checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{ + "Cat": { + Relations: map[string]Relation{ + "Toy": { + Name: "Toy", + Type: schema.HasOne, + Schema: "User", + FieldSchema: "Toy", + Polymorphic: Polymorphic{ID: "ref_id", Type: "Type", Value: "users"}, + References: []Reference{ + {ForeignKey: "Type", ForeignSchema: "Toy", PrimaryValue: "users"}, + }, + }, + }, + }, + }) + }) + + t.Run("has one with only polymorphic type", func(t *testing.T) { + type Toy struct { + ID int + Name string + OwnerID int + Type string + } + + type Cat struct { + ID int + Name string + Toy Toy `gorm:"polymorphic:Owner;polymorphicType:Type"` + } + + s, err := schema.Parse(&Cat{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("Failed to parse schema, got error %v", err) + } + + checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{ + "Cat": { + Relations: map[string]Relation{ + "Toy": { + Name: "Toy", + Type: schema.HasOne, + Schema: "User", + FieldSchema: "Toy", + Polymorphic: Polymorphic{ID: "owner_id", Type: "Type", Value: "users"}, + References: []Reference{ + {ForeignKey: "Type", ForeignSchema: "Toy", PrimaryValue: "users"}, + }, + }, + }, + }, + }) + }) + + t.Run("has many", func(t *testing.T) { + type Toy struct { + ID int + Name string + OwnerID int + OwnerType string + } + + type Cat struct { + ID int + Name string + Toys []Toy `gorm:"polymorphic:Owner;"` + } + + s, err := schema.Parse(&Cat{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("Failed to parse schema, got error %v", err) + } + + checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{ + "Cat": { + Relations: map[string]Relation{ + "Toys": { + Name: "Toys", + Type: schema.HasMany, + Schema: "User", + FieldSchema: "Toy", + Polymorphic: Polymorphic{ID: "OwnerID", Type: "OwnerType", Value: "users"}, + References: []Reference{ + {ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"}, + }, + }, + }, + }, + }) + }) + + t.Run("has many with custom polymorphic type and id", func(t *testing.T) { + type Toy struct { + ID int + Name string + RefId int + Type string + } + + type Cat struct { + ID int + Name string + Toys []Toy `gorm:"polymorphicType:Type;polymorphicId:RefId"` + } + + s, err := schema.Parse(&Cat{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("Failed to parse schema, got error %v", err) + } + + checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{ + "Cat": { + Relations: map[string]Relation{ + "Toys": { + Name: "Toys", + Type: schema.HasMany, + Schema: "User", + FieldSchema: "Toy", + Polymorphic: Polymorphic{ID: "ref_id", Type: "Type", Value: "users"}, + References: []Reference{ + {ForeignKey: "Type", ForeignSchema: "Toy", PrimaryValue: "users"}, + }, + }, + }, + }, + }) + }) +} + +func TestEmbeddedBelongsTo(t *testing.T) { + type Country struct { + ID int `gorm:"primaryKey"` + Name string + } + type Address struct { + CountryID int + Country Country + } + type NestedAddress struct { + Address + } + type Org struct { + ID int + PostalAddress Address `gorm:"embedded;embeddedPrefix:postal_address_"` + VisitingAddress Address `gorm:"embedded;embeddedPrefix:visiting_address_"` + AddressID int + Address struct { + ID int + Address + } + NestedAddress *NestedAddress `gorm:"embedded;embeddedPrefix:nested_address_"` + } + + s, err := schema.Parse(&Org{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Errorf("Failed to parse schema, got error %v", err) + } + + checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{ + "PostalAddress": { + Relations: map[string]Relation{ + "Country": { + Name: "Country", Type: schema.BelongsTo, Schema: "Org", FieldSchema: "Country", + References: []Reference{ + {PrimaryKey: "ID", PrimarySchema: "Country", ForeignKey: "CountryID", ForeignSchema: "Org"}, + }, + }, + }, + }, + "VisitingAddress": { + Relations: map[string]Relation{ + "Country": { + Name: "Country", Type: schema.BelongsTo, Schema: "Org", FieldSchema: "Country", + References: []Reference{ + {PrimaryKey: "ID", PrimarySchema: "Country", ForeignKey: "CountryID", ForeignSchema: "Org"}, + }, + }, + }, + }, + "NestedAddress": { + EmbeddedRelations: map[string]EmbeddedRelations{ + "Address": { + Relations: map[string]Relation{ + "Country": { + Name: "Country", Type: schema.BelongsTo, Schema: "Org", FieldSchema: "Country", + References: []Reference{ + {PrimaryKey: "ID", PrimarySchema: "Country", ForeignKey: "CountryID", ForeignSchema: "Org"}, + }, + }, + }, + }, + }, + }, + }) +} + func TestVariableRelation(t *testing.T) { var result struct { User @@ -615,7 +955,7 @@ func TestParseConstraintNameWithSchemaQualifiedLongTableName(t *testing.T) { s, err := schema.Parse( &Book{}, &sync.Map{}, - schema.NamingStrategy{}, + schema.NamingStrategy{IdentifierMaxLength: 64}, ) if err != nil { t.Fatalf("Failed to parse schema") diff --git a/schema/schema.go b/schema/schema.go index eca113e9..3e7459ce 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -6,12 +6,27 @@ import ( "fmt" "go/ast" "reflect" + "strings" "sync" "gorm.io/gorm/clause" "gorm.io/gorm/logger" ) +type callbackType string + +const ( + callbackTypeBeforeCreate callbackType = "BeforeCreate" + callbackTypeBeforeUpdate callbackType = "BeforeUpdate" + callbackTypeAfterCreate callbackType = "AfterCreate" + callbackTypeAfterUpdate callbackType = "AfterUpdate" + callbackTypeBeforeSave callbackType = "BeforeSave" + callbackTypeAfterSave callbackType = "AfterSave" + callbackTypeBeforeDelete callbackType = "BeforeDelete" + callbackTypeAfterDelete callbackType = "AfterDelete" + callbackTypeAfterFind callbackType = "AfterFind" +) + // ErrUnsupportedDataType unsupported data type var ErrUnsupportedDataType = errors.New("unsupported data type") @@ -25,6 +40,7 @@ type Schema struct { PrimaryFieldDBNames []string Fields []*Field FieldsByName map[string]*Field + FieldsByBindName map[string]*Field // embedded fields is 'Embed.Field' FieldsByDBName map[string]*Field FieldsWithDefaultDBValue []*Field // fields with default value assigned by database Relationships Relationships @@ -67,10 +83,35 @@ func (schema Schema) LookUpField(name string) *Field { return nil } +// LookUpFieldByBindName looks for the closest field in the embedded struct. +// +// type Struct struct { +// Embedded struct { +// ID string // is selected by LookUpFieldByBindName([]string{"Embedded", "ID"}, "ID") +// } +// ID string // is selected by LookUpFieldByBindName([]string{"ID"}, "ID") +// } +func (schema Schema) LookUpFieldByBindName(bindNames []string, name string) *Field { + if len(bindNames) == 0 { + return nil + } + for i := len(bindNames) - 1; i >= 0; i-- { + find := strings.Join(bindNames[:i], ".") + "." + name + if field, ok := schema.FieldsByBindName[find]; ok { + return field + } + } + return nil +} + type Tabler interface { TableName() string } +type TablerWithNamer interface { + TableName(Namer) string +} + // Parse get data type from dialector func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) { return ParseWithSpecialTableName(dest, cacheStore, namer, "") @@ -112,7 +153,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam schemaCacheKey = modelType } - // Load exist schmema cache, return if exists + // Load exist schema cache, return if exists if v, ok := cacheStore.Load(schemaCacheKey); ok { s := v.(*Schema) // Wait for the initialization of other goroutines to complete @@ -125,6 +166,9 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam if tabler, ok := modelValue.Interface().(Tabler); ok { tableName = tabler.TableName() } + if tabler, ok := modelValue.Interface().(TablerWithNamer); ok { + tableName = tabler.TableName(namer) + } if en, ok := namer.(embeddedNamer); ok { tableName = en.Table } @@ -133,20 +177,21 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam } schema := &Schema{ - Name: modelType.Name(), - ModelType: modelType, - Table: tableName, - FieldsByName: map[string]*Field{}, - FieldsByDBName: map[string]*Field{}, - Relationships: Relationships{Relations: map[string]*Relationship{}}, - cacheStore: cacheStore, - namer: namer, - initialized: make(chan struct{}), + Name: modelType.Name(), + ModelType: modelType, + Table: tableName, + FieldsByName: map[string]*Field{}, + FieldsByBindName: map[string]*Field{}, + FieldsByDBName: map[string]*Field{}, + Relationships: Relationships{Relations: map[string]*Relationship{}}, + cacheStore: cacheStore, + namer: namer, + initialized: make(chan struct{}), } // When the schema initialization is completed, the channel will be closed defer close(schema.initialized) - // Load exist schmema cache, return if exists + // Load exist schema cache, return if exists if v, ok := cacheStore.Load(schemaCacheKey); ok { s := v.(*Schema) // Wait for the initialization of other goroutines to complete @@ -169,6 +214,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam field.DBName = namer.ColumnName(schema.Table, field.Name) } + bindName := field.BindName() if field.DBName != "" { // nonexistence or shortest path or first appear prioritized if has permission if v, ok := schema.FieldsByDBName[field.DBName]; !ok || ((field.Creatable || field.Updatable || field.Readable) && len(field.BindNames) < len(v.BindNames)) { @@ -177,6 +223,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam } schema.FieldsByDBName[field.DBName] = field schema.FieldsByName[field.Name] = field + schema.FieldsByBindName[bindName] = field if v != nil && v.PrimaryKey { for idx, f := range schema.PrimaryFields { @@ -195,6 +242,9 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam if of, ok := schema.FieldsByName[field.Name]; !ok || of.TagSettings["-"] == "-" { schema.FieldsByName[field.Name] = field } + if of, ok := schema.FieldsByBindName[bindName]; !ok || of.TagSettings["-"] == "-" { + schema.FieldsByBindName[bindName] = field + } field.setupValuerAndSetter() } @@ -214,8 +264,18 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam } } - if schema.PrioritizedPrimaryField == nil && len(schema.PrimaryFields) == 1 { - schema.PrioritizedPrimaryField = schema.PrimaryFields[0] + if schema.PrioritizedPrimaryField == nil { + if len(schema.PrimaryFields) == 1 { + schema.PrioritizedPrimaryField = schema.PrimaryFields[0] + } else if len(schema.PrimaryFields) > 1 { + // If there are multiple primary keys, the AUTOINCREMENT field is prioritized + for _, field := range schema.PrimaryFields { + if field.AutoIncrement { + schema.PrioritizedPrimaryField = field + break + } + } + } } for _, field := range schema.PrimaryFields { @@ -223,7 +283,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam } for _, field := range schema.Fields { - if field.HasDefaultValue && field.DefaultValueInterface == nil { + if field.DataType != "" && field.HasDefaultValue && field.DefaultValueInterface == nil { schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field) } } @@ -242,14 +302,20 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam } } - callbacks := []string{"BeforeCreate", "AfterCreate", "BeforeUpdate", "AfterUpdate", "BeforeSave", "AfterSave", "BeforeDelete", "AfterDelete", "AfterFind"} - for _, name := range callbacks { - if methodValue := modelValue.MethodByName(name); methodValue.IsValid() { + callbackTypes := []callbackType{ + callbackTypeBeforeCreate, callbackTypeAfterCreate, + callbackTypeBeforeUpdate, callbackTypeAfterUpdate, + callbackTypeBeforeSave, callbackTypeAfterSave, + callbackTypeBeforeDelete, callbackTypeAfterDelete, + callbackTypeAfterFind, + } + for _, cbName := range callbackTypes { + if methodValue := callBackToMethodValue(modelValue, cbName); methodValue.IsValid() { switch methodValue.Type().String() { case "func(*gorm.DB) error": // TODO hack - reflect.Indirect(reflect.ValueOf(schema)).FieldByName(name).SetBool(true) + reflect.Indirect(reflect.ValueOf(schema)).FieldByName(string(cbName)).SetBool(true) default: - logger.Default.Warn(context.Background(), "Model %v don't match %vInterface, should be `%v(*gorm.DB) error`. Please see https://gorm.io/docs/hooks.html", schema, name, name) + logger.Default.Warn(context.Background(), "Model %v don't match %vInterface, should be `%v(*gorm.DB) error`. Please see https://gorm.io/docs/hooks.html", schema, cbName, cbName) } } } @@ -276,6 +342,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam return schema, schema.err } else { schema.FieldsByName[field.Name] = field + schema.FieldsByBindName[field.BindName()] = field } } @@ -302,6 +369,39 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam return schema, schema.err } +// This unrolling is needed to show to the compiler the exact set of methods +// that can be used on the modelType. +// Prior to go1.22 any use of MethodByName would cause the linker to +// abandon dead code elimination for the entire binary. +// As of go1.22 the compiler supports one special case of a string constant +// being passed to MethodByName. For enterprise customers or those building +// large binaries, this gives a significant reduction in binary size. +// https://github.com/golang/go/issues/62257 +func callBackToMethodValue(modelType reflect.Value, cbType callbackType) reflect.Value { + switch cbType { + case callbackTypeBeforeCreate: + return modelType.MethodByName(string(callbackTypeBeforeCreate)) + case callbackTypeAfterCreate: + return modelType.MethodByName(string(callbackTypeAfterCreate)) + case callbackTypeBeforeUpdate: + return modelType.MethodByName(string(callbackTypeBeforeUpdate)) + case callbackTypeAfterUpdate: + return modelType.MethodByName(string(callbackTypeAfterUpdate)) + case callbackTypeBeforeSave: + return modelType.MethodByName(string(callbackTypeBeforeSave)) + case callbackTypeAfterSave: + return modelType.MethodByName(string(callbackTypeAfterSave)) + case callbackTypeBeforeDelete: + return modelType.MethodByName(string(callbackTypeBeforeDelete)) + case callbackTypeAfterDelete: + return modelType.MethodByName(string(callbackTypeAfterDelete)) + case callbackTypeAfterFind: + return modelType.MethodByName(string(callbackTypeAfterFind)) + default: + return reflect.ValueOf(nil) + } +} + func getOrParse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) { modelType := reflect.ValueOf(dest).Type() for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr { diff --git a/schema/schema_helper_test.go b/schema/schema_helper_test.go index 9abaecba..bc326686 100644 --- a/schema/schema_helper_test.go +++ b/schema/schema_helper_test.go @@ -163,8 +163,8 @@ func checkSchemaRelation(t *testing.T, s *schema.Schema, relation Relation) { t.Errorf("schema %v relation's join table tablename expects %v, but got %v", s, relation.JoinTable.Table, r.JoinTable.Table) } - for _, f := range relation.JoinTable.Fields { - checkSchemaField(t, r.JoinTable, &f, nil) + for i := range relation.JoinTable.Fields { + checkSchemaField(t, r.JoinTable, &relation.JoinTable.Fields[i], nil) } } @@ -201,6 +201,37 @@ func checkSchemaRelation(t *testing.T, s *schema.Schema, relation Relation) { }) } +type EmbeddedRelations struct { + Relations map[string]Relation + EmbeddedRelations map[string]EmbeddedRelations +} + +func checkEmbeddedRelations(t *testing.T, actual map[string]*schema.Relationships, expected map[string]EmbeddedRelations) { + for name, relations := range actual { + rs := expected[name] + t.Run("CheckEmbeddedRelations/"+name, func(t *testing.T) { + if len(relations.Relations) != len(rs.Relations) { + t.Errorf("schema relations count don't match, expects %d, got %d", len(rs.Relations), len(relations.Relations)) + } + if len(relations.EmbeddedRelations) != len(rs.EmbeddedRelations) { + t.Errorf("schema embedded relations count don't match, expects %d, got %d", len(rs.EmbeddedRelations), len(relations.EmbeddedRelations)) + } + for n, rel := range relations.Relations { + if r, ok := rs.Relations[n]; !ok { + t.Errorf("failed to find relation by name %s", n) + } else { + checkSchemaRelation(t, &schema.Schema{ + Relationships: schema.Relationships{ + Relations: map[string]*schema.Relationship{n: rel}, + }, + }, r) + } + } + checkEmbeddedRelations(t, relations.EmbeddedRelations, rs.EmbeddedRelations) + }) + } +} + func checkField(t *testing.T, s *schema.Schema, value reflect.Value, values map[string]interface{}) { for k, v := range values { t.Run("CheckField/"+k, func(t *testing.T) { diff --git a/schema/schema_test.go b/schema/schema_test.go index 8a752fb7..45e152e9 100644 --- a/schema/schema_test.go +++ b/schema/schema_test.go @@ -46,8 +46,8 @@ func checkUserSchema(t *testing.T, user *schema.Schema) { {Name: "Active", DBName: "active", BindNames: []string{"Active"}, DataType: schema.Bool}, } - for _, f := range fields { - checkSchemaField(t, user, &f, func(f *schema.Field) { + for i := range fields { + checkSchemaField(t, user, &fields[i], func(f *schema.Field) { f.Creatable = true f.Updatable = true f.Readable = true @@ -136,8 +136,8 @@ func TestParseSchemaWithAdvancedDataType(t *testing.T) { {Name: "Admin", DBName: "admin", BindNames: []string{"Admin"}, DataType: schema.Bool}, } - for _, f := range fields { - checkSchemaField(t, user, &f, func(f *schema.Field) { + for i := range fields { + checkSchemaField(t, user, &fields[i], func(f *schema.Field) { f.Creatable = true f.Updatable = true f.Readable = true @@ -293,3 +293,44 @@ func TestEmbeddedStructForCustomizedNamingStrategy(t *testing.T) { }) } } + +func TestCompositePrimaryKeyWithAutoIncrement(t *testing.T) { + type Product struct { + ProductID uint `gorm:"primaryKey;autoIncrement"` + LanguageCode uint `gorm:"primaryKey"` + Code string + Name string + } + type ProductNonAutoIncrement struct { + ProductID uint `gorm:"primaryKey;autoIncrement:false"` + LanguageCode uint `gorm:"primaryKey"` + Code string + Name string + } + + product, err := schema.Parse(&Product{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("failed to parse product struct with composite primary key, got error %v", err) + } + + prioritizedPrimaryField := schema.Field{ + Name: "ProductID", DBName: "product_id", BindNames: []string{"ProductID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, HasDefaultValue: true, AutoIncrement: true, TagSettings: map[string]string{"PRIMARYKEY": "PRIMARYKEY", "AUTOINCREMENT": "AUTOINCREMENT"}, + } + + product.Fields = []*schema.Field{product.PrioritizedPrimaryField} + + checkSchemaField(t, product, &prioritizedPrimaryField, func(f *schema.Field) { + f.Creatable = true + f.Updatable = true + f.Readable = true + }) + + productNonAutoIncrement, err := schema.Parse(&ProductNonAutoIncrement{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("failed to parse productNonAutoIncrement struct with composite primary key, got error %v", err) + } + + if productNonAutoIncrement.PrioritizedPrimaryField != nil { + t.Fatalf("PrioritizedPrimaryField of non autoincrement composite key should be nil") + } +} diff --git a/schema/serializer.go b/schema/serializer.go index 758a6421..f500521e 100644 --- a/schema/serializer.go +++ b/schema/serializer.go @@ -70,8 +70,7 @@ type SerializerValuerInterface interface { } // JSONSerializer json serializer -type JSONSerializer struct { -} +type JSONSerializer struct{} // Scan implements serializer interface func (JSONSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) (err error) { @@ -88,7 +87,9 @@ func (JSONSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, return fmt.Errorf("failed to unmarshal JSONB value: %#v", dbValue) } - err = json.Unmarshal(bytes, fieldValue.Interface()) + if len(bytes) > 0 { + err = json.Unmarshal(bytes, fieldValue.Interface()) + } } field.ReflectValueOf(ctx, dst).Set(fieldValue.Elem()) @@ -98,12 +99,17 @@ func (JSONSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, // Value implements serializer interface func (JSONSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) { result, err := json.Marshal(fieldValue) + if string(result) == "null" { + if field.TagSettings["NOT NULL"] != "" { + return "", nil + } + return nil, err + } return string(result), err } // UnixSecondSerializer json serializer -type UnixSecondSerializer struct { -} +type UnixSecondSerializer struct{} // Scan implements serializer interface func (UnixSecondSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) (err error) { @@ -117,9 +123,15 @@ func (UnixSecondSerializer) Scan(ctx context.Context, field *Field, dst reflect. // Value implements serializer interface func (UnixSecondSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (result interface{}, err error) { + rv := reflect.ValueOf(fieldValue) switch v := fieldValue.(type) { - case int64, int, uint, uint64, int32, uint32, int16, uint16, *int64, *int, *uint, *uint64, *int32, *uint32, *int16, *uint16: - result = time.Unix(reflect.Indirect(reflect.ValueOf(v)).Int(), 0) + case int64, int, uint, uint64, int32, uint32, int16, uint16: + result = time.Unix(reflect.Indirect(rv).Int(), 0).UTC() + case *int64, *int, *uint, *uint64, *int32, *uint32, *int16, *uint16: + if rv.IsZero() { + return nil, nil + } + result = time.Unix(reflect.Indirect(rv).Int(), 0).UTC() default: err = fmt.Errorf("invalid field type %#v for UnixSecondSerializer, only int, uint supported", v) } @@ -127,8 +139,7 @@ func (UnixSecondSerializer) Value(ctx context.Context, field *Field, dst reflect } // GobSerializer gob serializer -type GobSerializer struct { -} +type GobSerializer struct{} // Scan implements serializer interface func (GobSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) (err error) { @@ -142,8 +153,10 @@ func (GobSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, default: return fmt.Errorf("failed to unmarshal gob value: %#v", dbValue) } - decoder := gob.NewDecoder(bytes.NewBuffer(bytesValue)) - err = decoder.Decode(fieldValue.Interface()) + if len(bytesValue) > 0 { + decoder := gob.NewDecoder(bytes.NewBuffer(bytesValue)) + err = decoder.Decode(fieldValue.Interface()) + } } field.ReflectValueOf(ctx, dst).Set(fieldValue.Elem()) return diff --git a/schema/utils.go b/schema/utils.go index 2720c530..7fdda185 100644 --- a/schema/utils.go +++ b/schema/utils.go @@ -2,6 +2,7 @@ package schema import ( "context" + "fmt" "reflect" "regexp" "strings" @@ -59,6 +60,14 @@ func removeSettingFromTag(tag reflect.StructTag, names ...string) reflect.Struct return tag } +func appendSettingFromTag(tag reflect.StructTag, value string) reflect.StructTag { + t := tag.Get("gorm") + if strings.Contains(t, value) { + return tag + } + return reflect.StructTag(fmt.Sprintf(`gorm:"%s;%s"`, value, t)) +} + // GetRelationsValues get relations's values from a reflect value func GetRelationsValues(ctx context.Context, reflectValue reflect.Value, rels []*Relationship) (reflectResults reflect.Value) { for _, rel := range rels { @@ -106,6 +115,11 @@ func GetIdentityFieldValuesMap(ctx context.Context, reflectValue reflect.Value, notZero, zero bool ) + if reflectValue.Kind() == reflect.Ptr || + reflectValue.Kind() == reflect.Interface { + reflectValue = reflectValue.Elem() + } + switch reflectValue.Kind() { case reflect.Struct: results = [][]interface{}{make([]interface{}, len(fields))} @@ -124,7 +138,7 @@ func GetIdentityFieldValuesMap(ctx context.Context, reflectValue reflect.Value, for i := 0; i < reflectValue.Len(); i++ { elem := reflectValue.Index(i) elemKey := elem.Interface() - if elem.Kind() != reflect.Ptr { + if elem.Kind() != reflect.Ptr && elem.CanAddr() { elemKey = elem.Addr().Interface() } diff --git a/soft_delete.go b/soft_delete.go index 6d646288..5673d3b8 100644 --- a/soft_delete.go +++ b/soft_delete.go @@ -6,6 +6,7 @@ import ( "encoding/json" "reflect" + "github.com/jinzhu/now" "gorm.io/gorm/clause" "gorm.io/gorm/schema" ) @@ -45,11 +46,21 @@ func (n *DeletedAt) UnmarshalJSON(b []byte) error { } func (DeletedAt) QueryClauses(f *schema.Field) []clause.Interface { - return []clause.Interface{SoftDeleteQueryClause{Field: f}} + return []clause.Interface{SoftDeleteQueryClause{Field: f, ZeroValue: parseZeroValueTag(f)}} +} + +func parseZeroValueTag(f *schema.Field) sql.NullString { + if v, ok := f.TagSettings["ZEROVALUE"]; ok { + if _, err := now.Parse(v); err == nil { + return sql.NullString{String: v, Valid: true} + } + } + return sql.NullString{Valid: false} } type SoftDeleteQueryClause struct { - Field *schema.Field + ZeroValue sql.NullString + Field *schema.Field } func (sd SoftDeleteQueryClause) Name() string { @@ -78,18 +89,19 @@ func (sd SoftDeleteQueryClause) ModifyStatement(stmt *Statement) { } stmt.AddClause(clause.Where{Exprs: []clause.Expression{ - clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: sd.Field.DBName}, Value: nil}, + clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: sd.Field.DBName}, Value: sd.ZeroValue}, }}) stmt.Clauses["soft_delete_enabled"] = clause.Clause{} } } func (DeletedAt) UpdateClauses(f *schema.Field) []clause.Interface { - return []clause.Interface{SoftDeleteUpdateClause{Field: f}} + return []clause.Interface{SoftDeleteUpdateClause{Field: f, ZeroValue: parseZeroValueTag(f)}} } type SoftDeleteUpdateClause struct { - Field *schema.Field + ZeroValue sql.NullString + Field *schema.Field } func (sd SoftDeleteUpdateClause) Name() string { @@ -109,11 +121,12 @@ func (sd SoftDeleteUpdateClause) ModifyStatement(stmt *Statement) { } func (DeletedAt) DeleteClauses(f *schema.Field) []clause.Interface { - return []clause.Interface{SoftDeleteDeleteClause{Field: f}} + return []clause.Interface{SoftDeleteDeleteClause{Field: f, ZeroValue: parseZeroValueTag(f)}} } type SoftDeleteDeleteClause struct { - Field *schema.Field + ZeroValue sql.NullString + Field *schema.Field } func (sd SoftDeleteDeleteClause) Name() string { diff --git a/statement.go b/statement.go index ed3e8716..ae79aa32 100644 --- a/statement.go +++ b/statement.go @@ -49,9 +49,12 @@ type Statement struct { } type join struct { - Name string - Conds []interface{} - On *clause.Where + Name string + Conds []interface{} + On *clause.Where + Selects []string + Omits []string + JoinType clause.JoinType } // StatementModifier statement modifier interface @@ -117,6 +120,8 @@ func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) { write(v.Raw, stmt.Schema.PrioritizedPrimaryField.DBName) } else if len(stmt.Schema.DBNames) > 0 { write(v.Raw, stmt.Schema.DBNames[0]) + } else { + stmt.DB.AddError(ErrModelAccessibleFieldsRequired) //nolint:typecheck,errcheck } } else { write(v.Raw, v.Name) @@ -179,6 +184,10 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) { } else { stmt.AddVar(writer, v.GormValue(stmt.Context, stmt.DB)) } + case clause.Interface: + c := clause.Clause{Name: v.Name()} + v.MergeClause(&c) + c.Build(stmt) case clause.Expression: v.Build(stmt) case driver.Valuer: @@ -304,6 +313,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] conds := make([]clause.Expression, 0, 4) args = append([]interface{}{query}, args...) for idx, arg := range args { + if arg == nil { + continue + } if valuer, ok := arg.(driver.Valuer); ok { arg, _ = valuer.Value() } @@ -312,9 +324,7 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] case clause.Expression: conds = append(conds, v) case *DB: - for _, scope := range v.Statement.scopes { - v = scope(v) - } + v.executeScopes() if cs, ok := v.Statement.Clauses["WHERE"]; ok { if where, ok := cs.Expression.(clause.Where); ok { @@ -437,8 +447,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] if len(values) > 0 { conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: values}) + return []clause.Expression{clause.And(conds...)} } - return conds + return nil } } @@ -447,7 +458,10 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] } } - return conds + if len(conds) > 0 { + return []clause.Expression{clause.And(conds...)} + } + return nil } // Build build sql with clauses names @@ -540,8 +554,9 @@ func (stmt *Statement) clone() *Statement { } // SetColumn set column's value -// stmt.SetColumn("Name", "jinzhu") // Hooks Method -// stmt.SetColumn("Name", "jinzhu", true) // Callbacks Method +// +// stmt.SetColumn("Name", "jinzhu") // Hooks Method +// stmt.SetColumn("Name", "jinzhu", true) // Callbacks Method func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks ...bool) { if v, ok := stmt.Dest.(map[string]interface{}); ok { v[name] = value @@ -650,54 +665,62 @@ func (stmt *Statement) Changed(fields ...string) bool { return false } -var nameMatcher = regexp.MustCompile(`^[\W]?(?:[a-z_]+?)[\W]?\.[\W]?([a-z_]+?)[\W]?$`) +var matchName = func() func(tableColumn string) (table, column string) { + nameMatcher := regexp.MustCompile(`^(?:\W?(\w+?)\W?\.)?(?:(\*)|\W?(\w+?)\W?)$`) + return func(tableColumn string) (table, column string) { + if matches := nameMatcher.FindStringSubmatch(tableColumn); len(matches) == 4 { + table = matches[1] + star := matches[2] + columnName := matches[3] + if star != "" { + return table, star + } + return table, columnName + } + return "", "" + } +}() // SelectAndOmitColumns get select and omit columns, select -> true, omit -> false func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (map[string]bool, bool) { results := map[string]bool{} notRestricted := false - // select columns - for _, column := range stmt.Selects { + processColumn := func(column string, result bool) { if stmt.Schema == nil { - results[column] = true + results[column] = result } else if column == "*" { - notRestricted = true + notRestricted = result for _, dbName := range stmt.Schema.DBNames { - results[dbName] = true + results[dbName] = result } } else if column == clause.Associations { for _, rel := range stmt.Schema.Relationships.Relations { - results[rel.Name] = true + results[rel.Name] = result } } else if field := stmt.Schema.LookUpField(column); field != nil && field.DBName != "" { - results[field.DBName] = true - } else if matches := nameMatcher.FindStringSubmatch(column); len(matches) == 2 { - results[matches[1]] = true + results[field.DBName] = result + } else if table, col := matchName(column); col != "" && (table == stmt.Table || table == "") { + if col == "*" { + for _, dbName := range stmt.Schema.DBNames { + results[dbName] = result + } + } else { + results[col] = result + } } else { - results[column] = true + results[column] = result } } + // select columns + for _, column := range stmt.Selects { + processColumn(column, true) + } + // omit columns - for _, omit := range stmt.Omits { - if stmt.Schema == nil { - results[omit] = false - } else if omit == "*" { - for _, dbName := range stmt.Schema.DBNames { - results[dbName] = false - } - } else if omit == clause.Associations { - for _, rel := range stmt.Schema.Relationships.Relations { - results[rel.Name] = false - } - } else if field := stmt.Schema.LookUpField(omit); field != nil && field.DBName != "" { - results[field.DBName] = false - } else if matches := nameMatcher.FindStringSubmatch(omit); len(matches) == 2 { - results[matches[1]] = false - } else { - results[omit] = false - } + for _, column := range stmt.Omits { + processColumn(column, false) } if stmt.Schema != nil { diff --git a/statement_test.go b/statement_test.go index 3f099d61..0995d547 100644 --- a/statement_test.go +++ b/statement_test.go @@ -35,15 +35,36 @@ func TestWhereCloneCorruption(t *testing.T) { } } +func TestNilCondition(t *testing.T) { + s := new(Statement) + if len(s.BuildCondition(nil)) != 0 { + t.Errorf("Nil condition should be empty") + } +} + func TestNameMatcher(t *testing.T) { - for k, v := range map[string]string{ - "table.name": "name", - "`table`.`name`": "name", - "'table'.'name'": "name", - "'table'.name": "name", + for k, v := range map[string][]string{ + "table.name": {"table", "name"}, + "`table`.`name`": {"table", "name"}, + "'table'.'name'": {"table", "name"}, + "'table'.name": {"table", "name"}, + "table1.name_23": {"table1", "name_23"}, + "`table_1`.`name23`": {"table_1", "name23"}, + "'table23'.'name_1'": {"table23", "name_1"}, + "'table23'.name1": {"table23", "name1"}, + "'name1'": {"", "name1"}, + "`name_1`": {"", "name_1"}, + "`Name_1`": {"", "Name_1"}, + "`Table`.`nAme`": {"Table", "nAme"}, + "my_table.*": {"my_table", "*"}, + "`my_table`.*": {"my_table", "*"}, + "User__Company.*": {"User__Company", "*"}, + "`User__Company`.*": {"User__Company", "*"}, + `"User__Company".*`: {"User__Company", "*"}, + `"table"."*"`: {"", ""}, } { - if matches := nameMatcher.FindStringSubmatch(k); len(matches) < 2 || matches[1] != v { - t.Errorf("failed to match value: %v, got %v, expect: %v", k, matches, v) + if table, column := matchName(k); table != v[0] || column != v[1] { + t.Errorf("failed to match value: %v, got %v, expect: %v", k, []string{table, column}, v) } } } diff --git a/tests/associations_belongs_to_test.go b/tests/associations_belongs_to_test.go index f74799ce..103da032 100644 --- a/tests/associations_belongs_to_test.go +++ b/tests/associations_belongs_to_test.go @@ -3,6 +3,7 @@ package tests_test import ( "testing" + "gorm.io/gorm" . "gorm.io/gorm/utils/tests" ) @@ -137,6 +138,7 @@ func TestBelongsToAssociation(t *testing.T) { unexistCompanyID := company.ID + 9999999 user = User{Name: "invalid-user-with-invalid-belongs-to-foreign-key", CompanyID: &unexistCompanyID} if err := DB.Create(&user).Error; err == nil { + tidbSkip(t, "not support the foreign key feature") t.Errorf("should have gotten foreign key violation error") } } @@ -224,3 +226,81 @@ func TestBelongsToAssociationForSlice(t *testing.T) { AssertAssociationCount(t, users[0], "Company", 0, "After Delete") AssertAssociationCount(t, users[1], "Company", 1, "After other user Delete") } + +func TestBelongsToDefaultValue(t *testing.T) { + type Org struct { + ID string + } + type BelongsToUser struct { + OrgID string + Org Org `gorm:"default:NULL"` + } + + tx := DB.Session(&gorm.Session{}) + tx.Config.DisableForeignKeyConstraintWhenMigrating = true + AssertEqual(t, DB.Config.DisableForeignKeyConstraintWhenMigrating, false) + + tx.Migrator().DropTable(&BelongsToUser{}, &Org{}) + tx.AutoMigrate(&BelongsToUser{}, &Org{}) + + user := &BelongsToUser{ + Org: Org{ + ID: "BelongsToUser_Org_1", + }, + } + err := DB.Create(&user).Error + AssertEqual(t, err, nil) +} + +func TestBelongsToAssociationUnscoped(t *testing.T) { + type ItemParent struct { + gorm.Model + Logo string `gorm:"not null;type:varchar(50)"` + } + type ItemChild struct { + gorm.Model + Name string `gorm:"type:varchar(50)"` + ItemParentID uint + ItemParent ItemParent + } + + tx := DB.Session(&gorm.Session{}) + tx.Migrator().DropTable(&ItemParent{}, &ItemChild{}) + tx.AutoMigrate(&ItemParent{}, &ItemChild{}) + + item := ItemChild{ + Name: "name", + ItemParent: ItemParent{ + Logo: "logo", + }, + } + if err := tx.Create(&item).Error; err != nil { + t.Fatalf("failed to create items, got error: %v", err) + } + + // test replace + if err := tx.Model(&item).Association("ItemParent").Unscoped().Replace(&ItemParent{ + Logo: "updated logo", + }); err != nil { + t.Errorf("failed to replace item parent, got error: %v", err) + } + + var parents []ItemParent + if err := tx.Find(&parents).Error; err != nil { + t.Errorf("failed to find item parent, got error: %v", err) + } + if len(parents) != 1 { + t.Errorf("expected %d parents, got %d", 1, len(parents)) + } + + // test delete + if err := tx.Model(&item).Association("ItemParent").Unscoped().Delete(&parents); err != nil { + t.Errorf("failed to delete item parent, got error: %v", err) + } + if err := tx.Find(&parents).Error; err != nil { + t.Errorf("failed to find item parent, got error: %v", err) + } + if len(parents) != 0 { + t.Errorf("expected %d parents, got %d", 0, len(parents)) + } +} diff --git a/tests/associations_has_many_test.go b/tests/associations_has_many_test.go index 002ae636..b8e8ff5e 100644 --- a/tests/associations_has_many_test.go +++ b/tests/associations_has_many_test.go @@ -3,6 +3,7 @@ package tests_test import ( "testing" + "gorm.io/gorm" . "gorm.io/gorm/utils/tests" ) @@ -421,7 +422,7 @@ func TestPolymorphicHasManyAssociation(t *testing.T) { func TestPolymorphicHasManyAssociationForSlice(t *testing.T) { users := []User{ *GetUser("slice-hasmany-1", Config{Toys: 2}), - *GetUser("slice-hasmany-2", Config{Toys: 0}), + *GetUser("slice-hasmany-2", Config{Toys: 0, Tools: 2}), *GetUser("slice-hasmany-3", Config{Toys: 4}), } @@ -429,6 +430,7 @@ func TestPolymorphicHasManyAssociationForSlice(t *testing.T) { // Count AssertAssociationCount(t, users, "Toys", 6, "") + AssertAssociationCount(t, users, "Tools", 2, "") // Find var toys []Toy @@ -436,6 +438,14 @@ func TestPolymorphicHasManyAssociationForSlice(t *testing.T) { t.Errorf("toys count should be %v, but got %v", 6, len(toys)) } + // Find Tools (polymorphic with custom type and id) + var tools []Tools + DB.Model(&users).Association("Tools").Find(&tools) + + if len(tools) != 2 { + t.Errorf("tools count should be %v, but got %v", 2, len(tools)) + } + // Append DB.Model(&users).Association("Toys").Append( &Toy{Name: "toy-slice-append-1"}, @@ -471,3 +481,76 @@ func TestPolymorphicHasManyAssociationForSlice(t *testing.T) { DB.Model(&users).Association("Toys").Clear() AssertAssociationCount(t, users, "Toys", 0, "After Clear") } + +func TestHasManyAssociationUnscoped(t *testing.T) { + type ItemContent struct { + gorm.Model + ItemID uint `gorm:"not null"` + Name string `gorm:"not null;type:varchar(50)"` + LanguageCode string `gorm:"not null;type:varchar(2)"` + } + type Item struct { + gorm.Model + Logo string `gorm:"not null;type:varchar(50)"` + Contents []ItemContent `gorm:"foreignKey:ItemID"` + } + + tx := DB.Session(&gorm.Session{}) + tx.Migrator().DropTable(&ItemContent{}, &Item{}) + tx.AutoMigrate(&ItemContent{}, &Item{}) + + item := Item{ + Logo: "logo", + Contents: []ItemContent{ + {Name: "name", LanguageCode: "en"}, + {Name: "ar name", LanguageCode: "ar"}, + }, + } + if err := tx.Create(&item).Error; err != nil { + t.Fatalf("failed to create items, got error: %v", err) + } + + // test Replace + if err := tx.Model(&item).Association("Contents").Unscoped().Replace([]ItemContent{ + {Name: "updated name", LanguageCode: "en"}, + {Name: "ar updated name", LanguageCode: "ar"}, + {Name: "le nom", LanguageCode: "fr"}, + }); err != nil { + t.Errorf("failed to replace item content, got error: %v", err) + } + + if count := tx.Model(&item).Association("Contents").Count(); count != 3 { + t.Errorf("expected %d contents, got %d", 3, count) + } + + var contents []ItemContent + if err := tx.Find(&contents).Error; err != nil { + t.Errorf("failed to find contents, got error: %v", err) + } + if len(contents) != 3 { + t.Errorf("expected %d contents, got %d", 3, len(contents)) + } + + // test delete + if err := tx.Model(&item).Association("Contents").Unscoped().Delete(&contents[0]); err != nil { + t.Errorf("failed to delete Contents, got error: %v", err) + } + if count := tx.Model(&item).Association("Contents").Count(); count != 2 { + t.Errorf("expected %d contents, got %d", 2, count) + } + + // test clear + if err := tx.Model(&item).Association("Contents").Unscoped().Clear(); err != nil { + t.Errorf("failed to clear contents association, got error: %v", err) + } + if count := tx.Model(&item).Association("Contents").Count(); count != 0 { + t.Errorf("expected %d contents, got %d", 0, count) + } + + if err := tx.Find(&contents).Error; err != nil { + t.Errorf("failed to find contents, got error: %v", err) + } + if len(contents) != 0 { + t.Errorf("expected %d contents, got %d", 0, len(contents)) + } +} diff --git a/tests/associations_many2many_test.go b/tests/associations_many2many_test.go index 28b441bd..39410aed 100644 --- a/tests/associations_many2many_test.go +++ b/tests/associations_many2many_test.go @@ -1,8 +1,12 @@ package tests_test import ( + "fmt" + "sync" "testing" + "gorm.io/gorm" + "gorm.io/gorm/clause" . "gorm.io/gorm/utils/tests" ) @@ -94,6 +98,8 @@ func TestMany2ManyAssociation(t *testing.T) { } func TestMany2ManyOmitAssociations(t *testing.T) { + tidbSkip(t, "not support the foreign key feature") + user := *GetUser("many2many_omit_associations", Config{Languages: 2}) if err := DB.Omit("Languages.*").Create(&user).Error; err == nil { @@ -324,3 +330,96 @@ func TestSingleTableMany2ManyAssociationForSlice(t *testing.T) { DB.Model(&users).Association("Team").Clear() AssertAssociationCount(t, users, "Team", 0, "After Clear") } + +func TestDuplicateMany2ManyAssociation(t *testing.T) { + user1 := User{Name: "TestDuplicateMany2ManyAssociation-1", Languages: []Language{ + {Code: "TestDuplicateMany2ManyAssociation-language-1"}, + {Code: "TestDuplicateMany2ManyAssociation-language-2"}, + }} + + user2 := User{Name: "TestDuplicateMany2ManyAssociation-1", Languages: []Language{ + {Code: "TestDuplicateMany2ManyAssociation-language-1"}, + {Code: "TestDuplicateMany2ManyAssociation-language-3"}, + }} + users := []*User{&user1, &user2} + var err error + err = DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(users).Error + AssertEqual(t, nil, err) + + var findUser1 User + err = DB.Preload("Languages").Where("id = ?", user1.ID).First(&findUser1).Error + AssertEqual(t, nil, err) + AssertEqual(t, user1, findUser1) + + var findUser2 User + err = DB.Preload("Languages").Where("id = ?", user2.ID).First(&findUser2).Error + AssertEqual(t, nil, err) + AssertEqual(t, user2, findUser2) +} + +func TestConcurrentMany2ManyAssociation(t *testing.T) { + db, err := OpenTestConnection(&gorm.Config{}) + if err != nil { + t.Fatalf("open test connection failed, err: %+v", err) + } + + count := 3 + + var languages []Language + for i := 0; i < count; i++ { + language := Language{Code: fmt.Sprintf("consurrent %d", i)} + db.Create(&language) + languages = append(languages, language) + } + + user := User{} + db.Create(&user) + db.Preload("Languages").FirstOrCreate(&user) + + var wg sync.WaitGroup + for i := 0; i < count; i++ { + wg.Add(1) + go func(user User, language Language) { + err := db.Model(&user).Association("Languages").Append(&language) + AssertEqual(t, err, nil) + + wg.Done() + }(user, languages[i]) + } + wg.Wait() + + var find User + err = db.Preload(clause.Associations).Where("id = ?", user.ID).First(&find).Error + AssertEqual(t, err, nil) + AssertAssociationCount(t, find, "Languages", int64(count), "after concurrent append") +} + +func TestMany2ManyDuplicateBelongsToAssociation(t *testing.T) { + user1 := User{Name: "TestMany2ManyDuplicateBelongsToAssociation-1", Friends: []*User{ + {Name: "TestMany2ManyDuplicateBelongsToAssociation-friend-1", Company: Company{ + ID: 1, + Name: "Test-company-1", + }}, + }} + + user2 := User{Name: "TestMany2ManyDuplicateBelongsToAssociation-2", Friends: []*User{ + {Name: "TestMany2ManyDuplicateBelongsToAssociation-friend-2", Company: Company{ + ID: 1, + Name: "Test-company-1", + }}, + }} + users := []*User{&user1, &user2} + var err error + err = DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(users).Error + AssertEqual(t, nil, err) + + var findUser1 User + err = DB.Preload("Friends.Company").Where("id = ?", user1.ID).First(&findUser1).Error + AssertEqual(t, nil, err) + AssertEqual(t, user1, findUser1) + + var findUser2 User + err = DB.Preload("Friends.Company").Where("id = ?", user2.ID).First(&findUser2).Error + AssertEqual(t, nil, err) + AssertEqual(t, user2, findUser2) +} diff --git a/tests/associations_test.go b/tests/associations_test.go index e729e979..4e8862e5 100644 --- a/tests/associations_test.go +++ b/tests/associations_test.go @@ -4,6 +4,8 @@ import ( "testing" "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" . "gorm.io/gorm/utils/tests" ) @@ -69,6 +71,8 @@ func TestAssociationNotNullClear(t *testing.T) { } func TestForeignKeyConstraints(t *testing.T) { + tidbSkip(t, "not support the foreign key feature") + type Profile struct { ID uint Name string @@ -124,6 +128,8 @@ func TestForeignKeyConstraints(t *testing.T) { } func TestForeignKeyConstraintsBelongsTo(t *testing.T) { + tidbSkip(t, "not support the foreign key feature") + type Profile struct { ID uint Name string @@ -284,3 +290,107 @@ func TestAssociationError(t *testing.T) { err = DB.Model(&emptyUser).Association("Languages").Delete(&user1.Languages) AssertEqual(t, err, gorm.ErrPrimaryKeyRequired) } + +type ( + myType string + emptyQueryClause struct { + Field *schema.Field + } +) + +func (myType) QueryClauses(f *schema.Field) []clause.Interface { + return []clause.Interface{emptyQueryClause{Field: f}} +} + +func (sd emptyQueryClause) Name() string { + return "empty" +} + +func (sd emptyQueryClause) Build(clause.Builder) { +} + +func (sd emptyQueryClause) MergeClause(*clause.Clause) { +} + +func (sd emptyQueryClause) ModifyStatement(stmt *gorm.Statement) { + // do nothing +} + +func TestAssociationEmptyQueryClause(t *testing.T) { + type Organization struct { + gorm.Model + Name string + } + type Region struct { + gorm.Model + Name string + Organizations []Organization `gorm:"many2many:region_orgs;"` + } + type RegionOrg struct { + RegionId uint + OrganizationId uint + Empty myType + } + if err := DB.SetupJoinTable(&Region{}, "Organizations", &RegionOrg{}); err != nil { + t.Fatalf("Failed to set up join table, got error: %s", err) + } + if err := DB.Migrator().DropTable(&Organization{}, &Region{}); err != nil { + t.Fatalf("Failed to migrate, got error: %s", err) + } + if err := DB.AutoMigrate(&Organization{}, &Region{}); err != nil { + t.Fatalf("Failed to migrate, got error: %v", err) + } + region := &Region{Name: "Region1"} + if err := DB.Create(region).Error; err != nil { + t.Fatalf("fail to create region %v", err) + } + var orgs []Organization + + if err := DB.Model(&Region{}).Association("Organizations").Find(&orgs); err != nil { + t.Fatalf("fail to find region organizations %v", err) + } else { + AssertEqual(t, len(orgs), 0) + } +} + +type AssociationEmptyUser struct { + ID uint + Name string + Pets []AssociationEmptyPet +} + +type AssociationEmptyPet struct { + AssociationEmptyUserID *uint `gorm:"uniqueIndex:uniq_user_id_name"` + Name string `gorm:"uniqueIndex:uniq_user_id_name;size:256"` +} + +func TestAssociationEmptyPrimaryKey(t *testing.T) { + if DB.Dialector.Name() != "mysql" { + t.Skip() + } + DB.Migrator().DropTable(&AssociationEmptyUser{}, &AssociationEmptyPet{}) + DB.AutoMigrate(&AssociationEmptyUser{}, &AssociationEmptyPet{}) + + id := uint(100) + user := AssociationEmptyUser{ + ID: id, + Name: "jinzhu", + Pets: []AssociationEmptyPet{ + {AssociationEmptyUserID: &id, Name: "bar"}, + {AssociationEmptyUserID: &id, Name: "foo"}, + }, + } + + err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Create(&user).Error + if err != nil { + t.Fatalf("Failed to create, got error: %v", err) + } + + var result AssociationEmptyUser + err = DB.Preload("Pets").First(&result, &id).Error + if err != nil { + t.Fatalf("Failed to find, got error: %v", err) + } + + AssertEqual(t, result, user) +} diff --git a/tests/benchmark_test.go b/tests/benchmark_test.go index d897a634..22d15898 100644 --- a/tests/benchmark_test.go +++ b/tests/benchmark_test.go @@ -1,6 +1,7 @@ package tests_test import ( + "fmt" "testing" . "gorm.io/gorm/utils/tests" @@ -24,6 +25,45 @@ func BenchmarkFind(b *testing.B) { } } +func BenchmarkScan(b *testing.B) { + user := *GetUser("scan", Config{}) + DB.Create(&user) + + var u User + b.ResetTimer() + for x := 0; x < b.N; x++ { + DB.Raw("select * from users where id = ?", user.ID).Scan(&u) + } +} + +func BenchmarkScanSlice(b *testing.B) { + DB.Exec("delete from users") + for i := 0; i < 10_000; i++ { + user := *GetUser(fmt.Sprintf("scan-%d", i), Config{}) + DB.Create(&user) + } + + var u []User + b.ResetTimer() + for x := 0; x < b.N; x++ { + DB.Raw("select * from users").Scan(&u) + } +} + +func BenchmarkScanSlicePointer(b *testing.B) { + DB.Exec("delete from users") + for i := 0; i < 10_000; i++ { + user := *GetUser(fmt.Sprintf("scan-%d", i), Config{}) + DB.Create(&user) + } + + var u []*User + b.ResetTimer() + for x := 0; x < b.N; x++ { + DB.Raw("select * from users").Scan(&u) + } +} + func BenchmarkUpdate(b *testing.B) { user := *GetUser("find", Config{}) DB.Create(&user) diff --git a/tests/callbacks_test.go b/tests/callbacks_test.go index 02765b8c..f77209f1 100644 --- a/tests/callbacks_test.go +++ b/tests/callbacks_test.go @@ -38,6 +38,7 @@ func c2(*gorm.DB) {} func c3(*gorm.DB) {} func c4(*gorm.DB) {} func c5(*gorm.DB) {} +func c6(*gorm.DB) {} func TestCallbacks(t *testing.T) { type callback struct { @@ -90,7 +91,7 @@ func TestCallbacks(t *testing.T) { }, { callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3}, {h: c4}, {h: c5}, {h: c2, remove: true}}, - results: []string{"c1", "c5", "c3", "c4"}, + results: []string{"c1", "c3", "c4", "c5"}, }, { callbacks: []callback{{h: c1}, {name: "c", h: c2}, {h: c3}, {name: "c", h: c4, replace: true}}, @@ -112,6 +113,9 @@ func TestCallbacks(t *testing.T) { for idx, data := range datas { db, err := gorm.Open(nil, nil) + if err != nil { + t.Fatal(err) + } callbacks := db.Callback() for _, c := range data.callbacks { @@ -168,3 +172,83 @@ func TestCallbacks(t *testing.T) { } } } + +func TestPluginCallbacks(t *testing.T) { + db, _ := gorm.Open(nil, nil) + createCallback := db.Callback().Create() + + createCallback.Before("*").Register("plugin_1_fn1", c1) + createCallback.After("*").Register("plugin_1_fn2", c2) + + if ok, msg := assertCallbacks(createCallback, []string{"c1", "c2"}); !ok { + t.Errorf("callbacks tests failed, got %v", msg) + } + + // plugin 2 + createCallback.Before("*").Register("plugin_2_fn1", c3) + if ok, msg := assertCallbacks(createCallback, []string{"c3", "c1", "c2"}); !ok { + t.Errorf("callbacks tests failed, got %v", msg) + } + + createCallback.After("*").Register("plugin_2_fn2", c4) + if ok, msg := assertCallbacks(createCallback, []string{"c3", "c1", "c2", "c4"}); !ok { + t.Errorf("callbacks tests failed, got %v", msg) + } + + // plugin 3 + createCallback.Before("*").Register("plugin_3_fn1", c5) + if ok, msg := assertCallbacks(createCallback, []string{"c5", "c3", "c1", "c2", "c4"}); !ok { + t.Errorf("callbacks tests failed, got %v", msg) + } + + createCallback.After("*").Register("plugin_3_fn2", c6) + if ok, msg := assertCallbacks(createCallback, []string{"c5", "c3", "c1", "c2", "c4", "c6"}); !ok { + t.Errorf("callbacks tests failed, got %v", msg) + } +} + +func TestCallbacksGet(t *testing.T) { + db, _ := gorm.Open(nil, nil) + createCallback := db.Callback().Create() + + createCallback.Before("*").Register("c1", c1) + if cb := createCallback.Get("c1"); reflect.DeepEqual(cb, c1) { + t.Errorf("callbacks tests failed, got: %p, want: %p", cb, c1) + } + + createCallback.Remove("c1") + if cb := createCallback.Get("c2"); cb != nil { + t.Errorf("callbacks test failed. got: %p, want: nil", cb) + } +} + +func TestCallbacksRemove(t *testing.T) { + db, _ := gorm.Open(nil, nil) + createCallback := db.Callback().Create() + + createCallback.Before("*").Register("c1", c1) + createCallback.After("*").Register("c2", c2) + createCallback.Before("c4").Register("c3", c3) + createCallback.After("c2").Register("c4", c4) + + // callbacks: []string{"c1", "c3", "c4", "c2"} + createCallback.Remove("c1") + if ok, msg := assertCallbacks(createCallback, []string{"c3", "c4", "c2"}); !ok { + t.Errorf("callbacks tests failed, got %v", msg) + } + + createCallback.Remove("c4") + if ok, msg := assertCallbacks(createCallback, []string{"c3", "c2"}); !ok { + t.Errorf("callbacks tests failed, got %v", msg) + } + + createCallback.Remove("c2") + if ok, msg := assertCallbacks(createCallback, []string{"c3"}); !ok { + t.Errorf("callbacks tests failed, got %v", msg) + } + + createCallback.Remove("c3") + if ok, msg := assertCallbacks(createCallback, []string{}); !ok { + t.Errorf("callbacks tests failed, got %v", msg) + } +} diff --git a/tests/connpool_test.go b/tests/connpool_test.go index fbae2294..21a2bad0 100644 --- a/tests/connpool_test.go +++ b/tests/connpool_test.go @@ -48,9 +48,11 @@ func (c *wrapperConnPool) Ping() error { } // If you use BeginTx returned *sql.Tx as shown below then you can't record queries in a transaction. -// func (c *wrapperConnPool) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) { -// return c.db.BeginTx(ctx, opts) -// } +// +// func (c *wrapperConnPool) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) { +// return c.db.BeginTx(ctx, opts) +// } +// // You should use BeginTx returned gorm.Tx which could wrap *sql.Tx then you can record all queries. func (c *wrapperConnPool) BeginTx(ctx context.Context, opts *sql.TxOptions) (gorm.ConnPool, error) { tx, err := c.db.BeginTx(ctx, opts) @@ -100,13 +102,13 @@ func TestConnPoolWrapper(t *testing.T) { expect: []string{ "SELECT VERSION()", "INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)", - "SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1", + "SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT ?", "INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)", - "SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1", - "SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1", + "SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT ?", + "SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT ?", "INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)", - "SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1", - "SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1", + "SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT ?", + "SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT ?", }, } @@ -116,7 +118,7 @@ func TestConnPoolWrapper(t *testing.T) { } }() - db, err := gorm.Open(mysql.New(mysql.Config{Conn: conn})) + db, err := gorm.Open(mysql.New(mysql.Config{Conn: conn, DisableWithReturning: true})) if err != nil { t.Fatalf("Should open db success, but got %v", err) } diff --git a/tests/count_test.go b/tests/count_test.go index b71e3de5..4449515b 100644 --- a/tests/count_test.go +++ b/tests/count_test.go @@ -11,6 +11,32 @@ import ( . "gorm.io/gorm/utils/tests" ) +func TestCountWithGroup(t *testing.T) { + DB.Create([]Company{ + {Name: "company_count_group_a"}, + {Name: "company_count_group_a"}, + {Name: "company_count_group_a"}, + {Name: "company_count_group_b"}, + {Name: "company_count_group_c"}, + }) + + var count1 int64 + if err := DB.Model(&Company{}).Where("name = ?", "company_count_group_a").Group("name").Count(&count1).Error; err != nil { + t.Errorf(fmt.Sprintf("Count should work, but got err %v", err)) + } + if count1 != 1 { + t.Errorf("Count with group should be 1, but got count: %v", count1) + } + + var count2 int64 + if err := DB.Model(&Company{}).Where("name in ?", []string{"company_count_group_b", "company_count_group_c"}).Group("name").Count(&count2).Error; err != nil { + t.Errorf(fmt.Sprintf("Count should work, but got err %v", err)) + } + if count2 != 2 { + t.Errorf("Count with group should be 2, but got count: %v", count2) + } +} + func TestCount(t *testing.T) { var ( user1 = *GetUser("count-1", Config{}) @@ -142,7 +168,7 @@ func TestCount(t *testing.T) { DB.Create(sameUsers) if err := DB.Model(&User{}).Where("name = ?", "count-4").Group("name").Count(&count11).Error; err != nil || count11 != 1 { - t.Fatalf("Count should be 3, but got count: %v err %v", count11, err) + t.Fatalf("Count should be 1, but got count: %v err %v", count11, err) } var count12 int64 diff --git a/tests/create_test.go b/tests/create_test.go index 3730172f..abb82472 100644 --- a/tests/create_test.go +++ b/tests/create_test.go @@ -2,6 +2,7 @@ package tests_test import ( "errors" + "fmt" "regexp" "testing" "time" @@ -476,6 +477,13 @@ func TestOmitWithCreate(t *testing.T) { CheckUser(t, result2, user2) } +func TestFirstOrCreateNotExistsTable(t *testing.T) { + company := Company{Name: "first_or_create_if_not_exists_table"} + if err := DB.Table("not_exists").FirstOrCreate(&company).Error; err == nil { + t.Errorf("not exists table, but err is nil") + } +} + func TestFirstOrCreateWithPrimaryKey(t *testing.T) { company := Company{ID: 100, Name: "company100_with_primarykey"} DB.FirstOrCreate(&company) @@ -540,3 +548,246 @@ func TestFirstOrCreateRowsAffected(t *testing.T) { t.Fatalf("first or create rows affect err:%v rows:%d", res.Error, res.RowsAffected) } } + +func TestCreateWithAutoIncrementCompositeKey(t *testing.T) { + type CompositeKeyProduct struct { + ProductID int `gorm:"primaryKey;autoIncrement:true;"` // primary key + LanguageCode int `gorm:"primaryKey;"` // primary key + Code string + Name string + } + + if err := DB.Migrator().DropTable(&CompositeKeyProduct{}); err != nil { + t.Fatalf("failed to migrate, got error %v", err) + } + if err := DB.AutoMigrate(&CompositeKeyProduct{}); err != nil { + t.Fatalf("failed to migrate, got error %v", err) + } + + prod := &CompositeKeyProduct{ + LanguageCode: 56, + Code: "Code56", + Name: "ProductName56", + } + if err := DB.Create(&prod).Error; err != nil { + t.Fatalf("failed to create, got error %v", err) + } + + newProd := &CompositeKeyProduct{} + if err := DB.First(&newProd).Error; err != nil { + t.Fatalf("errors happened when query: %v", err) + } else { + AssertObjEqual(t, newProd, prod, "ProductID", "LanguageCode", "Code", "Name") + } +} + +func TestCreateOnConflictWithDefaultNull(t *testing.T) { + type OnConflictUser struct { + ID string + Name string `gorm:"default:null"` + Email string + Mobile string `gorm:"default:'133xxxx'"` + } + + err := DB.Migrator().DropTable(&OnConflictUser{}) + AssertEqual(t, err, nil) + err = DB.AutoMigrate(&OnConflictUser{}) + AssertEqual(t, err, nil) + + u := OnConflictUser{ + ID: "on-conflict-user-id", + Name: "on-conflict-user-name", + Email: "on-conflict-user-email", + Mobile: "on-conflict-user-mobile", + } + err = DB.Create(&u).Error + AssertEqual(t, err, nil) + + u.Name = "on-conflict-user-name-2" + u.Email = "on-conflict-user-email-2" + u.Mobile = "" + err = DB.Clauses(clause.OnConflict{UpdateAll: true}).Create(&u).Error + AssertEqual(t, err, nil) + + var u2 OnConflictUser + err = DB.Where("id = ?", u.ID).First(&u2).Error + AssertEqual(t, err, nil) + AssertEqual(t, u2.Name, "on-conflict-user-name-2") + AssertEqual(t, u2.Email, "on-conflict-user-email-2") + AssertEqual(t, u2.Mobile, "133xxxx") +} + +func TestCreateFromMapWithoutPK(t *testing.T) { + if !isMysql() { + t.Skipf("This test case skipped, because of only supporting for mysql") + } + + // case 1: one record, create from map[string]interface{} + mapValue1 := map[string]interface{}{"name": "create_from_map_with_schema1", "age": 1} + if err := DB.Model(&User{}).Create(mapValue1).Error; err != nil { + t.Fatalf("failed to create data from map, got error: %v", err) + } + + if _, ok := mapValue1["id"]; !ok { + t.Fatal("failed to create data from map with table, returning map has no primary key") + } + + var result1 User + if err := DB.Where("name = ?", "create_from_map_with_schema1").First(&result1).Error; err != nil || result1.Age != 1 { + t.Fatalf("failed to create from map, got error %v", err) + } + + var idVal int64 + _, ok := mapValue1["id"].(uint) + if ok { + t.Skipf("This test case skipped, because the db supports returning") + } + + idVal, ok = mapValue1["id"].(int64) + if !ok { + t.Fatal("ret result missing id") + } + + if int64(result1.ID) != idVal { + t.Fatal("failed to create data from map with table, @id != id") + } + + // case2: one record, create from *map[string]interface{} + mapValue2 := map[string]interface{}{"name": "create_from_map_with_schema2", "age": 1} + if err := DB.Model(&User{}).Create(&mapValue2).Error; err != nil { + t.Fatalf("failed to create data from map, got error: %v", err) + } + + if _, ok := mapValue2["id"]; !ok { + t.Fatal("failed to create data from map with table, returning map has no primary key") + } + + var result2 User + if err := DB.Where("name = ?", "create_from_map_with_schema2").First(&result2).Error; err != nil || result2.Age != 1 { + t.Fatalf("failed to create from map, got error %v", err) + } + + _, ok = mapValue2["id"].(uint) + if ok { + t.Skipf("This test case skipped, because the db supports returning") + } + + idVal, ok = mapValue2["id"].(int64) + if !ok { + t.Fatal("ret result missing id") + } + + if int64(result2.ID) != idVal { + t.Fatal("failed to create data from map with table, @id != id") + } + + // case 3: records + values := []map[string]interface{}{ + {"name": "create_from_map_with_schema11", "age": 1}, {"name": "create_from_map_with_schema12", "age": 1}, + } + + beforeLen := len(values) + if err := DB.Model(&User{}).Create(&values).Error; err != nil { + t.Fatalf("failed to create data from map, got error: %v", err) + } + + // mariadb with returning, values will be appended with id map + if len(values) == beforeLen*2 { + t.Skipf("This test case skipped, because the db supports returning") + } + + for i := range values { + v, ok := values[i]["id"] + if !ok { + t.Fatal("failed to create data from map with table, returning map has no primary key") + } + + var result User + if err := DB.Where("name = ?", fmt.Sprintf("create_from_map_with_schema1%d", i+1)).First(&result).Error; err != nil || result.Age != 1 { + t.Fatalf("failed to create from map, got error %v", err) + } + if int64(result.ID) != v.(int64) { + t.Fatal("failed to create data from map with table, @id != id") + } + } +} + +func TestCreateFromMapWithTable(t *testing.T) { + tableDB := DB.Table("users") + supportLastInsertID := isMysql() || isSqlite() + + // case 1: create from map[string]interface{} + record := map[string]interface{}{"name": "create_from_map_with_table", "age": 18} + if err := tableDB.Create(record).Error; err != nil { + t.Fatalf("failed to create data from map with table, got error: %v", err) + } + + if _, ok := record["@id"]; !ok && supportLastInsertID { + t.Fatal("failed to create data from map with table, returning map has no key '@id'") + } + + var res map[string]interface{} + if err := tableDB.Select([]string{"id", "name", "age"}).Where("name = ?", "create_from_map_with_table").Find(&res).Error; err != nil || res["age"] != int64(18) { + t.Fatalf("failed to create from map, got error %v", err) + } + + if _, ok := record["@id"]; ok && fmt.Sprint(res["id"]) != fmt.Sprint(record["@id"]) { + t.Fatalf("failed to create data from map with table, @id != id, got %v, expect %v", res["id"], record["@id"]) + } + + // case 2: create from *map[string]interface{} + record1 := map[string]interface{}{"name": "create_from_map_with_table_1", "age": 18} + tableDB2 := DB.Table("users") + if err := tableDB2.Create(&record1).Error; err != nil { + t.Fatalf("failed to create data from map, got error: %v", err) + } + if _, ok := record1["@id"]; !ok && supportLastInsertID { + t.Fatal("failed to create data from map with table, returning map has no key '@id'") + } + + var res1 map[string]interface{} + if err := tableDB2.Select([]string{"id", "name", "age"}).Where("name = ?", "create_from_map_with_table_1").Find(&res1).Error; err != nil || res1["age"] != int64(18) { + t.Fatalf("failed to create from map, got error %v", err) + } + + if _, ok := record1["@id"]; ok && fmt.Sprint(res1["id"]) != fmt.Sprint(record1["@id"]) { + t.Fatal("failed to create data from map with table, @id != id") + } + + // case 3: create from []map[string]interface{} + records := []map[string]interface{}{ + {"name": "create_from_map_with_table_2", "age": 19}, + {"name": "create_from_map_with_table_3", "age": 20}, + } + + tableDB = DB.Table("users") + if err := tableDB.Create(&records).Error; err != nil { + t.Fatalf("failed to create data from slice of map, got error: %v", err) + } + + if _, ok := records[0]["@id"]; !ok && supportLastInsertID { + t.Fatal("failed to create data from map with table, returning map has no key '@id'") + } + + if _, ok := records[1]["@id"]; !ok && supportLastInsertID { + t.Fatal("failed to create data from map with table, returning map has no key '@id'") + } + + var res2 map[string]interface{} + if err := tableDB.Select([]string{"id", "name", "age"}).Where("name = ?", "create_from_map_with_table_2").Find(&res2).Error; err != nil || res2["age"] != int64(19) { + t.Fatalf("failed to query data after create from slice of map, got error %v", err) + } + + var res3 map[string]interface{} + if err := DB.Table("users").Select([]string{"id", "name", "age"}).Where("name = ?", "create_from_map_with_table_3").Find(&res3).Error; err != nil || res3["age"] != int64(20) { + t.Fatalf("failed to query data after create from slice of map, got error %v", err) + } + + if _, ok := records[0]["@id"]; ok && fmt.Sprint(res2["id"]) != fmt.Sprint(records[0]["@id"]) { + t.Errorf("failed to create data from map with table, @id != id, got %v, expect %v", res2["id"], records[0]["@id"]) + } + + if _, ok := records[1]["id"]; ok && fmt.Sprint(res3["id"]) != fmt.Sprint(records[1]["@id"]) { + t.Errorf("failed to create data from map with table, @id != id") + } +} diff --git a/tests/delete_test.go b/tests/delete_test.go index 5cb4b91e..5d112b4e 100644 --- a/tests/delete_test.go +++ b/tests/delete_test.go @@ -206,9 +206,9 @@ func TestDeleteSliceWithAssociations(t *testing.T) { } } -// only sqlite, postgres support returning +// only sqlite, postgres, sqlserver support returning func TestSoftDeleteReturning(t *testing.T) { - if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" { + if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" && DB.Dialector.Name() != "sqlserver" { return } @@ -233,7 +233,7 @@ func TestSoftDeleteReturning(t *testing.T) { } func TestDeleteReturning(t *testing.T) { - if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" { + if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" && DB.Dialector.Name() != "sqlserver" { return } diff --git a/tests/docker-compose.yml b/tests/docker-compose.yml index 9ab4ddb6..8abd4d0f 100644 --- a/tests/docker-compose.yml +++ b/tests/docker-compose.yml @@ -4,7 +4,7 @@ services: mysql: image: 'mysql/mysql-server:latest' ports: - - 9910:3306 + - "9910:3306" environment: - MYSQL_DATABASE=gorm - MYSQL_USER=gorm @@ -13,7 +13,7 @@ services: postgres: image: 'postgres:latest' ports: - - 9920:5432 + - "9920:5432" environment: - TZ=Asia/Shanghai - POSTGRES_DB=gorm @@ -22,10 +22,16 @@ services: mssql: image: '${MSSQL_IMAGE:-mcmoe/mssqldocker}:latest' ports: - - 9930:1433 + - "9930:1433" environment: + - TZ=Asia/Shanghai - ACCEPT_EULA=Y - SA_PASSWORD=LoremIpsum86 - MSSQL_DB=gorm - MSSQL_USER=gorm - MSSQL_PASSWORD=LoremIpsum86 + tidb: + image: 'pingcap/tidb:v6.5.0' + ports: + - "9940:4000" + command: /tidb-server -store unistore -path "" -lease 0s > tidb.log 2>&1 & diff --git a/tests/embedded_struct_test.go b/tests/embedded_struct_test.go index 312a5c37..873bba2a 100644 --- a/tests/embedded_struct_test.go +++ b/tests/embedded_struct_test.go @@ -4,7 +4,9 @@ import ( "database/sql/driver" "encoding/json" "errors" + "reflect" "testing" + "time" "gorm.io/gorm" . "gorm.io/gorm/utils/tests" @@ -36,7 +38,7 @@ func TestEmbeddedStruct(t *testing.T) { type EngadgetPost struct { BasePost BasePost `gorm:"Embedded"` - Author Author `gorm:"Embedded;EmbeddedPrefix:author_"` // Embedded struct + Author *Author `gorm:"Embedded;EmbeddedPrefix:author_"` // Embedded struct ImageUrl string } @@ -74,13 +76,26 @@ func TestEmbeddedStruct(t *testing.T) { t.Errorf("embedded struct's value should be scanned correctly") } - DB.Save(&EngadgetPost{BasePost: BasePost{Title: "engadget_news"}}) + DB.Save(&EngadgetPost{BasePost: BasePost{Title: "engadget_news"}, Author: &Author{Name: "Edward"}}) + DB.Save(&EngadgetPost{BasePost: BasePost{Title: "engadget_article"}, Author: &Author{Name: "George"}}) var egNews EngadgetPost if err := DB.First(&egNews, "title = ?", "engadget_news").Error; err != nil { t.Errorf("no error should happen when query with embedded struct, but got %v", err) } else if egNews.BasePost.Title != "engadget_news" { t.Errorf("embedded struct's value should be scanned correctly") } + + var egPosts []EngadgetPost + if err := DB.Order("author_name asc").Find(&egPosts).Error; err != nil { + t.Fatalf("no error should happen when query with embedded struct, but got %v", err) + } + expectAuthors := []string{"Edward", "George"} + for i, post := range egPosts { + t.Log(i, post.Author) + if want := expectAuthors[i]; post.Author.Name != want { + t.Errorf("expected author %s got %s", want, post.Author.Name) + } + } } func TestEmbeddedPointerTypeStruct(t *testing.T) { @@ -90,9 +105,21 @@ func TestEmbeddedPointerTypeStruct(t *testing.T) { URL string } + type Author struct { + ID string + Name string + Email string + Age int + Content Content + ContentPtr *Content + Birthday time.Time + BirthdayPtr *time.Time + } + type HNPost struct { *BasePost Upvotes int32 + *Author `gorm:"EmbeddedPrefix:user_"` // Embedded struct } DB.Migrator().DropTable(&HNPost{}) @@ -110,6 +137,52 @@ func TestEmbeddedPointerTypeStruct(t *testing.T) { if hnPost.Title != "embedded_pointer_type" { t.Errorf("Should find correct value for embedded pointer type") } + + if hnPost.Author != nil { + t.Errorf("Expected to get back a nil Author but got: %v", hnPost.Author) + } + + now := time.Now().Round(time.Second) + NewPost := HNPost{ + BasePost: &BasePost{Title: "embedded_pointer_type2"}, + Author: &Author{ + Name: "test", + Content: Content{"test"}, + ContentPtr: nil, + Birthday: now, + BirthdayPtr: nil, + }, + } + DB.Create(&NewPost) + + hnPost = HNPost{} + if err := DB.First(&hnPost, "title = ?", NewPost.Title).Error; err != nil { + t.Errorf("No error should happen when find embedded pointer type, but got %v", err) + } + + if hnPost.Title != NewPost.Title { + t.Errorf("Should find correct value for embedded pointer type") + } + + if hnPost.Author.Name != NewPost.Author.Name { + t.Errorf("Expected to get Author name %v but got: %v", NewPost.Author.Name, hnPost.Author.Name) + } + + if !reflect.DeepEqual(NewPost.Author.Content, hnPost.Author.Content) { + t.Errorf("Expected to get Author content %v but got: %v", NewPost.Author.Content, hnPost.Author.Content) + } + + if hnPost.Author.ContentPtr != nil { + t.Errorf("Expected to get nil Author contentPtr but got: %v", hnPost.Author.ContentPtr) + } + + if NewPost.Author.Birthday.UnixMilli() != hnPost.Author.Birthday.UnixMilli() { + t.Errorf("Expected to get Author birthday with %+v but got: %+v", NewPost.Author.Birthday, hnPost.Author.Birthday) + } + + if hnPost.Author.BirthdayPtr != nil { + t.Errorf("Expected to get nil Author birthdayPtr but got: %+v", hnPost.Author.BirthdayPtr) + } } type Content struct { @@ -117,18 +190,26 @@ type Content struct { } func (c Content) Value() (driver.Value, error) { - return json.Marshal(c) + // mssql driver with issue on handling null bytes https://github.com/denisenkom/go-mssqldb/issues/530, + b, err := json.Marshal(c) + return string(b[:]), err } func (c *Content) Scan(src interface{}) error { - b, ok := src.([]byte) - if !ok { - return errors.New("Embedded.Scan byte assertion failed") - } - var value Content - if err := json.Unmarshal(b, &value); err != nil { - return err + str, ok := src.(string) + if !ok { + byt, ok := src.([]byte) + if !ok { + return errors.New("Embedded.Scan byte assertion failed") + } + if err := json.Unmarshal(byt, &value); err != nil { + return err + } + } else { + if err := json.Unmarshal([]byte(str), &value); err != nil { + return err + } } *c = value @@ -155,8 +236,15 @@ func TestEmbeddedScanValuer(t *testing.T) { } func TestEmbeddedRelations(t *testing.T) { + type EmbUser struct { + gorm.Model + Name string + Age uint + Languages []Language `gorm:"many2many:EmbUserSpeak;"` + } + type AdvancedUser struct { - User `gorm:"embedded"` + EmbUser `gorm:"embedded"` Advanced bool } @@ -168,3 +256,29 @@ func TestEmbeddedRelations(t *testing.T) { } } } + +func TestEmbeddedTagSetting(t *testing.T) { + type Tag1 struct { + Id int64 `gorm:"autoIncrement"` + } + type Tag2 struct { + Id int64 + } + + type EmbeddedTag struct { + Tag1 Tag1 `gorm:"Embedded;"` + Tag2 Tag2 `gorm:"Embedded;EmbeddedPrefix:t2_"` + Name string + } + + DB.Migrator().DropTable(&EmbeddedTag{}) + err := DB.Migrator().AutoMigrate(&EmbeddedTag{}) + AssertEqual(t, err, nil) + + t1 := EmbeddedTag{Name: "embedded_tag"} + err = DB.Save(&t1).Error + AssertEqual(t, err, nil) + if t1.Tag1.Id == 0 { + t.Errorf("embedded struct's primary field should be rewrited") + } +} diff --git a/tests/error_translator_test.go b/tests/error_translator_test.go new file mode 100644 index 00000000..ee54300e --- /dev/null +++ b/tests/error_translator_test.go @@ -0,0 +1,111 @@ +package tests_test + +import ( + "errors" + "testing" + + "gorm.io/gorm" + "gorm.io/gorm/utils/tests" +) + +func TestDialectorWithErrorTranslatorSupport(t *testing.T) { + // it shouldn't translate error when the TranslateError flag is false + translatedErr := errors.New("translated error") + untranslatedErr := errors.New("some random error") + db, _ := gorm.Open(tests.DummyDialector{TranslatedErr: translatedErr}) + + err := db.AddError(untranslatedErr) + if !errors.Is(err, untranslatedErr) { + t.Fatalf("expected err: %v got err: %v", untranslatedErr, err) + } + + // it should translate error when the TranslateError flag is true + db, _ = gorm.Open(tests.DummyDialector{TranslatedErr: translatedErr}, &gorm.Config{TranslateError: true}) + + err = db.AddError(untranslatedErr) + if !errors.Is(err, translatedErr) { + t.Fatalf("expected err: %v got err: %v", translatedErr, err) + } +} + +func TestSupportedDialectorWithErrDuplicatedKey(t *testing.T) { + type City struct { + gorm.Model + Name string `gorm:"unique"` + } + + db, err := OpenTestConnection(&gorm.Config{TranslateError: true}) + if err != nil { + t.Fatalf("failed to connect database, got error %v", err) + } + + dialectors := map[string]bool{"sqlite": true, "postgres": true, "mysql": true, "sqlserver": true} + if supported, found := dialectors[db.Dialector.Name()]; !(found && supported) { + return + } + + DB.Migrator().DropTable(&City{}) + + if err = db.AutoMigrate(&City{}); err != nil { + t.Fatalf("failed to migrate cities table, got error: %v", err) + } + + err = db.Create(&City{Name: "Kabul"}).Error + if err != nil { + t.Fatalf("failed to create record: %v", err) + } + + err = db.Create(&City{Name: "Kabul"}).Error + if !errors.Is(err, gorm.ErrDuplicatedKey) { + t.Fatalf("expected err: %v got err: %v", gorm.ErrDuplicatedKey, err) + } +} + +func TestSupportedDialectorWithErrForeignKeyViolated(t *testing.T) { + tidbSkip(t, "not support the foreign key feature") + + type City struct { + gorm.Model + Name string `gorm:"unique"` + } + + type Museum struct { + gorm.Model + Name string `gorm:"unique"` + CityID uint + City City `gorm:"Constraint:OnUpdate:CASCADE,OnDelete:CASCADE;FOREIGNKEY:CityID;References:ID"` + } + + db, err := OpenTestConnection(&gorm.Config{TranslateError: true}) + if err != nil { + t.Fatalf("failed to connect database, got error %v", err) + } + + dialectors := map[string]bool{"sqlite": true, "postgres": true, "mysql": true, "sqlserver": true} + if supported, found := dialectors[db.Dialector.Name()]; !(found && supported) { + return + } + + DB.Migrator().DropTable(&City{}, &Museum{}) + + if err = db.AutoMigrate(&City{}, &Museum{}); err != nil { + t.Fatalf("failed to migrate countries & cities tables, got error: %v", err) + } + + city := City{Name: "Amsterdam"} + + err = db.Create(&city).Error + if err != nil { + t.Fatalf("failed to create city: %v", err) + } + + err = db.Create(&Museum{Name: "Eye Filmmuseum", CityID: city.ID}).Error + if err != nil { + t.Fatalf("failed to create museum: %v", err) + } + + err = db.Create(&Museum{Name: "Dungeon", CityID: 123}).Error + if !errors.Is(err, gorm.ErrForeignKeyViolated) { + t.Fatalf("expected err: %v got err: %v", gorm.ErrForeignKeyViolated, err) + } +} diff --git a/tests/go.mod b/tests/go.mod index 6a2cf22f..3d3901d9 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -1,18 +1,39 @@ module gorm.io/gorm/tests -go 1.14 +go 1.18 require ( - github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect - github.com/google/uuid v1.3.0 + github.com/google/uuid v1.6.0 github.com/jinzhu/now v1.1.5 - github.com/lib/pq v1.10.5 - golang.org/x/crypto v0.0.0-20220411220226-7b82a4e95df4 // indirect - gorm.io/driver/mysql v1.3.3 - gorm.io/driver/postgres v1.3.5 - gorm.io/driver/sqlite v1.3.2 - gorm.io/driver/sqlserver v1.3.2 - gorm.io/gorm v1.23.4 + github.com/lib/pq v1.10.9 + github.com/stretchr/testify v1.9.0 + gorm.io/driver/mysql v1.5.6 + gorm.io/driver/postgres v1.5.7 + gorm.io/driver/sqlite v1.5.5 + gorm.io/driver/sqlserver v1.5.3 + gorm.io/gorm v1.25.8 +) + +require ( + filippo.io/edwards25519 v1.1.0 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/go-sql-driver/mysql v1.8.0 // indirect + github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect + github.com/golang-sql/sqlexp v0.1.0 // indirect + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 // indirect + github.com/jackc/pgx/v5 v5.5.5 // indirect + github.com/jinzhu/inflection v1.0.0 // indirect + github.com/kr/text v0.2.0 // indirect + github.com/mattn/go-sqlite3 v1.14.22 // indirect + github.com/microsoft/go-mssqldb v1.7.0 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/rogpeppe/go-internal v1.12.0 // indirect + golang.org/x/crypto v0.21.0 // indirect + golang.org/x/text v0.14.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) replace gorm.io/gorm => ../ + +replace github.com/jackc/pgx/v5 => github.com/jackc/pgx/v5 v5.4.3 diff --git a/tests/gorm_test.go b/tests/gorm_test.go index 9827465c..4c31b88b 100644 --- a/tests/gorm_test.go +++ b/tests/gorm_test.go @@ -3,9 +3,19 @@ package tests_test import ( "testing" + "gorm.io/driver/mysql" + "gorm.io/gorm" ) +func TestOpen(t *testing.T) { + dsn := "gorm:gorm@tcp(localhost:9910)/gorm?loc=Asia%2FHongKong" // invalid loc + _, err := gorm.Open(mysql.Open(dsn), &gorm.Config{}) + if err == nil { + t.Fatalf("should returns error but got nil") + } +} + func TestReturningWithNullToZeroValues(t *testing.T) { dialect := DB.Dialector.Name() switch dialect { diff --git a/tests/helper_test.go b/tests/helper_test.go index 7ee2a576..dc250b7c 100644 --- a/tests/helper_test.go +++ b/tests/helper_test.go @@ -1,12 +1,15 @@ package tests_test import ( + "os" "sort" "strconv" "strings" "testing" "time" + "gorm.io/gorm" + . "gorm.io/gorm/utils/tests" ) @@ -20,6 +23,7 @@ type Config struct { Languages int Friends int NamedPet bool + Tools int } func GetUser(name string, config Config) *User { @@ -44,6 +48,10 @@ func GetUser(name string, config Config) *User { user.Toys = append(user.Toys, Toy{Name: name + "_toy_" + strconv.Itoa(i+1)}) } + for i := 0; i < config.Tools; i++ { + user.Tools = append(user.Tools, Tools{Name: name + "_tool_" + strconv.Itoa(i+1)}) + } + if config.Company { user.Company = Company{Name: "company-" + name} } @@ -73,13 +81,22 @@ func GetUser(name string, config Config) *User { return &user } +func CheckPetUnscoped(t *testing.T, pet Pet, expect Pet) { + doCheckPet(t, pet, expect, true) +} + func CheckPet(t *testing.T, pet Pet, expect Pet) { + doCheckPet(t, pet, expect, false) +} + +func doCheckPet(t *testing.T, pet Pet, expect Pet, unscoped bool) { if pet.ID != 0 { var newPet Pet - if err := DB.Where("id = ?", pet.ID).First(&newPet).Error; err != nil { + if err := db(unscoped).Where("id = ?", pet.ID).First(&newPet).Error; err != nil { t.Fatalf("errors happened when query: %v", err) } else { AssertObjEqual(t, newPet, pet, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "UserID", "Name") + AssertObjEqual(t, newPet, expect, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "UserID", "Name") } } @@ -92,17 +109,27 @@ func CheckPet(t *testing.T, pet Pet, expect Pet) { } } +func CheckUserUnscoped(t *testing.T, user User, expect User) { + doCheckUser(t, user, expect, true) +} + func CheckUser(t *testing.T, user User, expect User) { + doCheckUser(t, user, expect, false) +} + +func doCheckUser(t *testing.T, user User, expect User, unscoped bool) { if user.ID != 0 { var newUser User - if err := DB.Where("id = ?", user.ID).First(&newUser).Error; err != nil { + if err := db(unscoped).Where("id = ?", user.ID).First(&newUser).Error; err != nil { t.Fatalf("errors happened when query: %v", err) } else { - AssertObjEqual(t, newUser, user, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active") + AssertObjEqual(t, newUser, user, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", + "CompanyID", "ManagerID", "Active") } } - AssertObjEqual(t, user, expect, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active") + AssertObjEqual(t, user, expect, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", + "ManagerID", "Active") t.Run("Account", func(t *testing.T) { AssertObjEqual(t, user.Account, expect.Account, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "UserID", "Number") @@ -112,8 +139,9 @@ func CheckUser(t *testing.T, user User, expect User) { t.Errorf("Account's foreign key should be saved") } else { var account Account - DB.First(&account, "user_id = ?", user.ID) - AssertObjEqual(t, account, user.Account, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "UserID", "Number") + db(unscoped).First(&account, "user_id = ?", user.ID) + AssertObjEqual(t, account, user.Account, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "UserID", + "Number") } } }) @@ -135,7 +163,7 @@ func CheckUser(t *testing.T, user User, expect User) { if pet == nil || expect.Pets[idx] == nil { t.Errorf("pets#%v should equal, expect: %v, got %v", idx, expect.Pets[idx], pet) } else { - CheckPet(t, *pet, *expect.Pets[idx]) + doCheckPet(t, *pet, *expect.Pets[idx], unscoped) } } }) @@ -172,8 +200,11 @@ func CheckUser(t *testing.T, user User, expect User) { t.Errorf("Manager's foreign key should be saved") } else { var manager User - DB.First(&manager, "id = ?", *user.ManagerID) - AssertObjEqual(t, manager, user.Manager, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active") + db(unscoped).First(&manager, "id = ?", *user.ManagerID) + AssertObjEqual(t, manager, user.Manager, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", + "Birthday", "CompanyID", "ManagerID", "Active") + AssertObjEqual(t, manager, expect.Manager, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", + "Birthday", "CompanyID", "ManagerID", "Active") } } else if user.ManagerID != nil { t.Errorf("Manager should not be created for zero value, got: %+v", user.ManagerID) @@ -194,7 +225,8 @@ func CheckUser(t *testing.T, user User, expect User) { }) for idx, team := range user.Team { - AssertObjEqual(t, team, expect.Team[idx], "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active") + AssertObjEqual(t, team, expect.Team[idx], "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", + "Birthday", "CompanyID", "ManagerID", "Active") } }) @@ -229,7 +261,34 @@ func CheckUser(t *testing.T, user User, expect User) { }) for idx, friend := range user.Friends { - AssertObjEqual(t, friend, expect.Friends[idx], "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active") + AssertObjEqual(t, friend, expect.Friends[idx], "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", + "Birthday", "CompanyID", "ManagerID", "Active") } }) } + +func tidbSkip(t *testing.T, reason string) { + if isTiDB() { + t.Skipf("This test case skipped, because of TiDB '%s'", reason) + } +} + +func isTiDB() bool { + return os.Getenv("GORM_DIALECT") == "tidb" +} + +func isMysql() bool { + return os.Getenv("GORM_DIALECT") == "mysql" +} + +func isSqlite() bool { + return os.Getenv("GORM_DIALECT") == "sqlite" +} + +func db(unscoped bool) *gorm.DB { + if unscoped { + return DB.Unscoped() + } else { + return DB + } +} diff --git a/tests/hooks_test.go b/tests/hooks_test.go index 20e8dc18..0753dd0b 100644 --- a/tests/hooks_test.go +++ b/tests/hooks_test.go @@ -466,8 +466,9 @@ type Product4 struct { type ProductItem struct { gorm.Model - Code string - Product4ID uint + Code string + Product4ID uint + AfterFindCallTimes int } func (pi ProductItem) BeforeCreate(*gorm.DB) error { @@ -477,6 +478,11 @@ func (pi ProductItem) BeforeCreate(*gorm.DB) error { return nil } +func (pi *ProductItem) AfterFind(*gorm.DB) error { + pi.AfterFindCallTimes = pi.AfterFindCallTimes + 1 + return nil +} + func TestFailedToSaveAssociationShouldRollback(t *testing.T) { DB.Migrator().DropTable(&Product4{}, &ProductItem{}) DB.AutoMigrate(&Product4{}, &ProductItem{}) @@ -498,4 +504,65 @@ func TestFailedToSaveAssociationShouldRollback(t *testing.T) { if err := DB.First(&Product4{}, "name = ?", product.Name).Error; err != nil { t.Errorf("should find product, but got error %v", err) } + + var productWithItem Product4 + if err := DB.Session(&gorm.Session{SkipHooks: true}).Preload("Item").First(&productWithItem, "name = ?", product.Name).Error; err != nil { + t.Errorf("should find product, but got error %v", err) + } + + if productWithItem.Item.AfterFindCallTimes != 0 { + t.Fatalf("AfterFind should not be called times:%d", productWithItem.Item.AfterFindCallTimes) + } +} + +type Product5 struct { + gorm.Model + Name string +} + +var beforeUpdateCall int + +func (p *Product5) BeforeUpdate(*gorm.DB) error { + beforeUpdateCall = beforeUpdateCall + 1 + return nil +} + +func TestUpdateCallbacks(t *testing.T) { + DB.Migrator().DropTable(&Product5{}) + DB.AutoMigrate(&Product5{}) + + p := Product5{Name: "unique_code"} + DB.Model(&Product5{}).Create(&p) + + err := DB.Model(&Product5{}).Where("id", p.ID).Update("name", "update_name_1").Error + if err != nil { + t.Fatalf("should update success, but got err %v", err) + } + if beforeUpdateCall != 1 { + t.Fatalf("before update should be called") + } + + err = DB.Model(Product5{}).Where("id", p.ID).Update("name", "update_name_2").Error + if !errors.Is(err, gorm.ErrInvalidValue) { + t.Fatalf("should got RecordNotFound, but got %v", err) + } + if beforeUpdateCall != 1 { + t.Fatalf("before update should not be called") + } + + err = DB.Model([1]*Product5{&p}).Update("name", "update_name_3").Error + if err != nil { + t.Fatalf("should update success, but got err %v", err) + } + if beforeUpdateCall != 2 { + t.Fatalf("before update should be called") + } + + err = DB.Model([1]Product5{p}).Update("name", "update_name_4").Error + if !errors.Is(err, gorm.ErrInvalidValue) { + t.Fatalf("should got RecordNotFound, but got %v", err) + } + if beforeUpdateCall != 2 { + t.Fatalf("before update should not be called") + } } diff --git a/tests/joins_test.go b/tests/joins_test.go index 4908e5ba..786fc37e 100644 --- a/tests/joins_test.go +++ b/tests/joins_test.go @@ -229,3 +229,174 @@ func TestJoinWithSoftDeleted(t *testing.T) { t.Fatalf("joins NamedPet and Account should not empty:%v", user2) } } + +func TestInnerJoins(t *testing.T) { + user := *GetUser("inner-joins-1", Config{Company: true, Manager: true, Account: true, NamedPet: false}) + + DB.Create(&user) + + var user2 User + var err error + err = DB.InnerJoins("Company").InnerJoins("Manager").InnerJoins("Account").First(&user2, "users.name = ?", user.Name).Error + AssertEqual(t, err, nil) + CheckUser(t, user2, user) + + // inner join and NamedPet is nil + err = DB.InnerJoins("NamedPet").InnerJoins("Company").InnerJoins("Manager").InnerJoins("Account").First(&user2, "users.name = ?", user.Name).Error + AssertEqual(t, err, gorm.ErrRecordNotFound) + + // mixed inner join and left join + var user3 User + err = DB.Joins("NamedPet").InnerJoins("Company").InnerJoins("Manager").InnerJoins("Account").First(&user3, "users.name = ?", user.Name).Error + AssertEqual(t, err, nil) + CheckUser(t, user3, user) +} + +func TestJoinWithSameColumnName(t *testing.T) { + user := GetUser("TestJoinWithSameColumnName", Config{ + Languages: 1, + Pets: 1, + }) + DB.Create(user) + type UserSpeak struct { + UserID uint + LanguageCode string + } + type Result struct { + User + UserSpeak + Language + Pet + } + + results := make([]Result, 0, 1) + DB.Select("users.*, user_speaks.*, languages.*, pets.*").Table("users").Joins("JOIN user_speaks ON user_speaks.user_id = users.id"). + Joins("JOIN languages ON languages.code = user_speaks.language_code"). + Joins("LEFT OUTER JOIN pets ON pets.user_id = users.id").Find(&results) + + if len(results) == 0 { + t.Fatalf("no record find") + } else if results[0].Pet.UserID == nil || *(results[0].Pet.UserID) != user.ID { + t.Fatalf("wrong user id in pet") + } else if results[0].Pet.Name != user.Pets[0].Name { + t.Fatalf("wrong pet name") + } +} + +func TestJoinArgsWithDB(t *testing.T) { + user := *GetUser("joins-args-db", Config{Pets: 2}) + DB.Save(&user) + + // test where + var user1 User + onQuery := DB.Where(&Pet{Name: "joins-args-db_pet_2"}) + if err := DB.Joins("NamedPet", onQuery).Where("users.name = ?", user.Name).First(&user1).Error; err != nil { + t.Fatalf("Failed to load with joins on, got error: %v", err) + } + + AssertEqual(t, user1.NamedPet.Name, "joins-args-db_pet_2") + + // test where and omit + onQuery2 := DB.Where(&Pet{Name: "joins-args-db_pet_2"}).Omit("Name") + var user2 User + if err := DB.Joins("NamedPet", onQuery2).Where("users.name = ?", user.Name).First(&user2).Error; err != nil { + t.Fatalf("Failed to load with joins on, got error: %v", err) + } + AssertEqual(t, user2.NamedPet.ID, user1.NamedPet.ID) + AssertEqual(t, user2.NamedPet.Name, "") + + // test where and select + onQuery3 := DB.Where(&Pet{Name: "joins-args-db_pet_2"}).Select("Name") + var user3 User + if err := DB.Joins("NamedPet", onQuery3).Where("users.name = ?", user.Name).First(&user3).Error; err != nil { + t.Fatalf("Failed to load with joins on, got error: %v", err) + } + AssertEqual(t, user3.NamedPet.ID, 0) + AssertEqual(t, user3.NamedPet.Name, "joins-args-db_pet_2") + + // test select + onQuery4 := DB.Select("ID") + var user4 User + if err := DB.Joins("NamedPet", onQuery4).Where("users.name = ?", user.Name).First(&user4).Error; err != nil { + t.Fatalf("Failed to load with joins on, got error: %v", err) + } + if user4.NamedPet.ID == 0 { + t.Fatal("Pet ID can not be empty") + } + AssertEqual(t, user4.NamedPet.Name, "") +} + +func TestNestedJoins(t *testing.T) { + users := []User{ + { + Name: "nested-joins-1", + Manager: &User{ + Name: "nested-joins-manager-1", + Company: Company{ + Name: "nested-joins-manager-company-1", + }, + NamedPet: &Pet{ + Name: "nested-joins-manager-namepet-1", + Toy: Toy{ + Name: "nested-joins-manager-namepet-toy-1", + }, + }, + }, + NamedPet: &Pet{Name: "nested-joins-namepet-1", Toy: Toy{Name: "nested-joins-namepet-toy-1"}}, + }, + { + Name: "nested-joins-2", + Manager: GetUser("nested-joins-manager-2", Config{Company: true, NamedPet: true}), + NamedPet: &Pet{Name: "nested-joins-namepet-2", Toy: Toy{Name: "nested-joins-namepet-toy-2"}}, + }, + } + + DB.Create(&users) + + var userIDs []uint + for _, user := range users { + userIDs = append(userIDs, user.ID) + } + + var users2 []User + if err := DB. + Joins("Manager"). + Joins("Manager.Company"). + Joins("Manager.NamedPet"). + Joins("Manager.NamedPet.Toy"). + Joins("NamedPet"). + Joins("NamedPet.Toy"). + Find(&users2, "users.id IN ?", userIDs).Error; err != nil { + t.Fatalf("Failed to load with joins, got error: %v", err) + } else if len(users2) != len(users) { + t.Fatalf("Failed to load join users, got: %v, expect: %v", len(users2), len(users)) + } + + sort.Slice(users2, func(i, j int) bool { + return users2[i].ID > users2[j].ID + }) + + sort.Slice(users, func(i, j int) bool { + return users[i].ID > users[j].ID + }) + + for idx, user := range users { + // user + CheckUser(t, user, users2[idx]) + if users2[idx].Manager == nil { + t.Fatalf("Failed to load Manager") + } + // manager + CheckUser(t, *user.Manager, *users2[idx].Manager) + // user pet + if users2[idx].NamedPet == nil { + t.Fatalf("Failed to load NamedPet") + } + CheckPet(t, *user.NamedPet, *users2[idx].NamedPet) + // manager pet + if users2[idx].Manager.NamedPet == nil { + t.Fatalf("Failed to load NamedPet") + } + CheckPet(t, *user.Manager.NamedPet, *users2[idx].Manager.NamedPet) + } +} diff --git a/tests/migrate_test.go b/tests/migrate_test.go index d6a6c4db..d955c8d7 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -1,19 +1,31 @@ package tests_test import ( + "context" + "database/sql" + "fmt" "math/rand" + "os" "reflect" + "strconv" "strings" "testing" "time" + "github.com/stretchr/testify/assert" + "gorm.io/driver/postgres" + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/logger" + "gorm.io/gorm/migrator" "gorm.io/gorm/schema" + "gorm.io/gorm/utils" . "gorm.io/gorm/utils/tests" ) func TestMigrate(t *testing.T) { - allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}} + allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Tools{}} rand.Seed(time.Now().UnixNano()) rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] }) DB.Migrator().DropTable("user_speaks", "user_friends", "ccc") @@ -23,13 +35,13 @@ func TestMigrate(t *testing.T) { } if err := DB.AutoMigrate(allModels...); err != nil { - t.Fatalf("Failed to auto migrate, but got error %v", err) + t.Fatalf("Failed to auto migrate, got error %v", err) } if tables, err := DB.Migrator().GetTables(); err != nil { t.Fatalf("Failed to get database all tables, but got error %v", err) } else { - for _, t1 := range []string{"users", "accounts", "pets", "companies", "toys", "languages"} { + for _, t1 := range []string{"users", "accounts", "pets", "companies", "toys", "languages", "tools"} { hasTable := false for _, t2 := range tables { if t2 == t1 { @@ -72,6 +84,44 @@ func TestMigrate(t *testing.T) { } } +func TestAutoMigrateInt8PG(t *testing.T) { + if DB.Dialector.Name() != "postgres" { + return + } + + type Smallint int8 + + type MigrateInt struct { + Int8 Smallint + } + + tracer := Tracer{ + Logger: DB.Config.Logger, + Test: func(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) { + sql, _ := fc() + if strings.HasPrefix(sql, "ALTER TABLE \"migrate_ints\" ALTER COLUMN \"int8\" TYPE smallint") { + t.Fatalf("shouldn't execute ALTER COLUMN TYPE if such type is already existed in DB schema: sql: %s", + sql) + } + }, + } + + DB.Migrator().DropTable(&MigrateInt{}) + + // The first AutoMigrate to make table with field with correct type + if err := DB.AutoMigrate(&MigrateInt{}); err != nil { + t.Fatalf("Failed to auto migrate: error: %v", err) + } + + // make new session to set custom logger tracer + session := DB.Session(&gorm.Session{Logger: tracer}) + + // The second AutoMigrate to catch an error + if err := session.AutoMigrate(&MigrateInt{}); err != nil { + t.Fatalf("Failed to auto migrate: error: %v", err) + } +} + func TestAutoMigrateSelfReferential(t *testing.T) { type MigratePerson struct { ID uint @@ -214,9 +264,10 @@ func TestMigrateWithIndexComment(t *testing.T) { func TestMigrateWithUniqueIndex(t *testing.T) { type UserWithUniqueIndex struct { - ID int - Name string `gorm:"size:20;index:idx_name,unique"` - Date time.Time `gorm:"index:idx_name,unique"` + ID int + Name string `gorm:"size:20;index:idx_name,unique"` + Date time.Time `gorm:"index:idx_name,unique"` + UName string `gorm:"uniqueIndex;size:255"` } DB.Migrator().DropTable(&UserWithUniqueIndex{}) @@ -227,6 +278,18 @@ func TestMigrateWithUniqueIndex(t *testing.T) { if !DB.Migrator().HasIndex(&UserWithUniqueIndex{}, "idx_name") { t.Errorf("Failed to find created index") } + + if !DB.Migrator().HasIndex(&UserWithUniqueIndex{}, "idx_user_with_unique_indices_u_name") { + t.Errorf("Failed to find created index") + } + + if err := DB.AutoMigrate(&UserWithUniqueIndex{}); err != nil { + t.Fatalf("failed to migrate, got %v", err) + } + + if !DB.Migrator().HasIndex(&UserWithUniqueIndex{}, "idx_user_with_unique_indices_u_name") { + t.Errorf("Failed to find created index") + } } func TestMigrateTable(t *testing.T) { @@ -262,6 +325,25 @@ func TestMigrateTable(t *testing.T) { } } +func TestMigrateWithQuotedIndex(t *testing.T) { + if DB.Dialector.Name() != "mysql" { + t.Skip() + } + + type QuotedIndexStruct struct { + gorm.Model + Name string `gorm:"size:255;index:AS"` // AS is one of MySQL reserved words + } + + if err := DB.Migrator().DropTable(&QuotedIndexStruct{}); err != nil { + t.Fatalf("Failed to drop table, got error %v", err) + } + + if err := DB.AutoMigrate(&QuotedIndexStruct{}); err != nil { + t.Fatalf("Failed to auto migrate, but got error %v", err) + } +} + func TestMigrateIndexes(t *testing.T) { type IndexStruct struct { gorm.Model @@ -312,7 +394,148 @@ func TestMigrateIndexes(t *testing.T) { } } +func TestTiDBMigrateColumns(t *testing.T) { + if !isTiDB() { + t.Skip() + } + + // TiDB can't change column constraint and has auto_random feature + type ColumnStruct struct { + ID int `gorm:"primarykey;default:auto_random()"` + Name string + Age int `gorm:"default:18;comment:my age"` + Code string `gorm:"unique;comment:my code;"` + Code2 string + Code3 string `gorm:"unique"` + } + + DB.Migrator().DropTable(&ColumnStruct{}) + + if err := DB.AutoMigrate(&ColumnStruct{}); err != nil { + t.Errorf("Failed to migrate, got %v", err) + } + + type ColumnStruct2 struct { + ID int `gorm:"primarykey;default:auto_random()"` + Name string `gorm:"size:100"` + Code string `gorm:"unique;comment:my code2;default:hello"` + Code2 string `gorm:"comment:my code2;default:hello"` + } + + if err := DB.Table("column_structs").Migrator().AlterColumn(&ColumnStruct{}, "Name"); err != nil { + t.Fatalf("no error should happened when alter column, but got %v", err) + } + + if err := DB.Table("column_structs").AutoMigrate(&ColumnStruct2{}); err != nil { + t.Fatalf("no error should happened when auto migrate column, but got %v", err) + } + + if columnTypes, err := DB.Migrator().ColumnTypes(&ColumnStruct{}); err != nil { + t.Fatalf("no error should returns for ColumnTypes") + } else { + stmt := &gorm.Statement{DB: DB} + stmt.Parse(&ColumnStruct2{}) + + for _, columnType := range columnTypes { + switch columnType.Name() { + case "id": + if v, ok := columnType.PrimaryKey(); !ok || !v { + t.Fatalf("column id primary key should be correct, name: %v, column: %#v", columnType.Name(), + columnType) + } + case "name": + dataType := DB.Dialector.DataTypeOf(stmt.Schema.LookUpField(columnType.Name())) + if !strings.Contains(strings.ToUpper(dataType), strings.ToUpper(columnType.DatabaseTypeName())) { + t.Fatalf("column name type should be correct, name: %v, length: %v, expects: %v, column: %#v", + columnType.Name(), columnType.DatabaseTypeName(), dataType, columnType) + } + if length, ok := columnType.Length(); !ok || length != 100 { + t.Fatalf("column name length should be correct, name: %v, length: %v, expects: %v, column: %#v", + columnType.Name(), length, 100, columnType) + } + case "age": + if v, ok := columnType.DefaultValue(); !ok || v != "18" { + t.Fatalf("column age default value should be correct, name: %v, column: %#v", columnType.Name(), + columnType) + } + if v, ok := columnType.Comment(); !ok || v != "my age" { + t.Fatalf("column age comment should be correct, name: %v, column: %#v", columnType.Name(), + columnType) + } + case "code": + if v, ok := columnType.Unique(); !ok || !v { + t.Fatalf("column code unique should be correct, name: %v, column: %#v", columnType.Name(), + columnType) + } + if v, ok := columnType.DefaultValue(); !ok || v != "hello" { + t.Fatalf("column code default value should be correct, name: %v, column: %#v, default value: %v", + columnType.Name(), columnType, v) + } + if v, ok := columnType.Comment(); !ok || v != "my code2" { + t.Fatalf("column code comment should be correct, name: %v, column: %#v", columnType.Name(), + columnType) + } + case "code2": + // Code2 string `gorm:"comment:my code2;default:hello"` + if v, ok := columnType.DefaultValue(); !ok || v != "hello" { + t.Fatalf("column code default value should be correct, name: %v, column: %#v, default value: %v", + columnType.Name(), columnType, v) + } + if v, ok := columnType.Comment(); !ok || v != "my code2" { + t.Fatalf("column code comment should be correct, name: %v, column: %#v", columnType.Name(), + columnType) + } + } + } + } + + type NewColumnStruct struct { + gorm.Model + Name string + NewName string + } + + if err := DB.Table("column_structs").Migrator().AddColumn(&NewColumnStruct{}, "NewName"); err != nil { + t.Fatalf("Failed to add column, got %v", err) + } + + if !DB.Table("column_structs").Migrator().HasColumn(&NewColumnStruct{}, "NewName") { + t.Fatalf("Failed to find added column") + } + + if err := DB.Table("column_structs").Migrator().DropColumn(&NewColumnStruct{}, "NewName"); err != nil { + t.Fatalf("Failed to add column, got %v", err) + } + + if DB.Table("column_structs").Migrator().HasColumn(&NewColumnStruct{}, "NewName") { + t.Fatalf("Found deleted column") + } + + if err := DB.Table("column_structs").Migrator().AddColumn(&NewColumnStruct{}, "NewName"); err != nil { + t.Fatalf("Failed to add column, got %v", err) + } + + if err := DB.Table("column_structs").Migrator().RenameColumn(&NewColumnStruct{}, "NewName", + "new_new_name"); err != nil { + t.Fatalf("Failed to add column, got %v", err) + } + + if !DB.Table("column_structs").Migrator().HasColumn(&NewColumnStruct{}, "new_new_name") { + t.Fatalf("Failed to found renamed column") + } + + if err := DB.Table("column_structs").Migrator().DropColumn(&NewColumnStruct{}, "new_new_name"); err != nil { + t.Fatalf("Failed to add column, got %v", err) + } + + if DB.Table("column_structs").Migrator().HasColumn(&NewColumnStruct{}, "new_new_name") { + t.Fatalf("Found deleted column") + } +} + func TestMigrateColumns(t *testing.T) { + tidbSkip(t, "use another test case") + sqlite := DB.Dialector.Name() == "sqlite" sqlserver := DB.Dialector.Name() == "sqlserver" @@ -357,36 +580,45 @@ func TestMigrateColumns(t *testing.T) { switch columnType.Name() { case "id": if v, ok := columnType.PrimaryKey(); !ok || !v { - t.Fatalf("column id primary key should be correct, name: %v, column: %#v", columnType.Name(), columnType) + t.Fatalf("column id primary key should be correct, name: %v, column: %#v", columnType.Name(), + columnType) } case "name": dataType := DB.Dialector.DataTypeOf(stmt.Schema.LookUpField(columnType.Name())) if !strings.Contains(strings.ToUpper(dataType), strings.ToUpper(columnType.DatabaseTypeName())) { - t.Fatalf("column name type should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), columnType.DatabaseTypeName(), dataType, columnType) + t.Fatalf("column name type should be correct, name: %v, length: %v, expects: %v, column: %#v", + columnType.Name(), columnType.DatabaseTypeName(), dataType, columnType) } if length, ok := columnType.Length(); !sqlite && (!ok || length != 100) { - t.Fatalf("column name length should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), length, 100, columnType) + t.Fatalf("column name length should be correct, name: %v, length: %v, expects: %v, column: %#v", + columnType.Name(), length, 100, columnType) } case "age": if v, ok := columnType.DefaultValue(); !ok || v != "18" { - t.Fatalf("column age default value should be correct, name: %v, column: %#v", columnType.Name(), columnType) + t.Fatalf("column age default value should be correct, name: %v, column: %#v", columnType.Name(), + columnType) } if v, ok := columnType.Comment(); !sqlite && !sqlserver && (!ok || v != "my age") { - t.Fatalf("column age comment should be correct, name: %v, column: %#v", columnType.Name(), columnType) + t.Fatalf("column age comment should be correct, name: %v, column: %#v", columnType.Name(), + columnType) } case "code": if v, ok := columnType.Unique(); !ok || !v { - t.Fatalf("column code unique should be correct, name: %v, column: %#v", columnType.Name(), columnType) + t.Fatalf("column code unique should be correct, name: %v, column: %#v", columnType.Name(), + columnType) } if v, ok := columnType.DefaultValue(); !sqlserver && (!ok || v != "hello") { - t.Fatalf("column code default value should be correct, name: %v, column: %#v", columnType.Name(), columnType) + t.Fatalf("column code default value should be correct, name: %v, column: %#v, default value: %v", + columnType.Name(), columnType, v) } if v, ok := columnType.Comment(); !sqlite && !sqlserver && (!ok || v != "my code2") { - t.Fatalf("column code comment should be correct, name: %v, column: %#v", columnType.Name(), columnType) + t.Fatalf("column code comment should be correct, name: %v, column: %#v", columnType.Name(), + columnType) } case "code2": if v, ok := columnType.Unique(); !sqlserver && (!ok || !v) { - t.Fatalf("column code2 unique should be correct, name: %v, column: %#v", columnType.Name(), columnType) + t.Fatalf("column code2 unique should be correct, name: %v, column: %#v", columnType.Name(), + columnType) } case "code3": // TODO @@ -423,7 +655,8 @@ func TestMigrateColumns(t *testing.T) { t.Fatalf("Failed to add column, got %v", err) } - if err := DB.Table("column_structs").Migrator().RenameColumn(&NewColumnStruct{}, "NewName", "new_new_name"); err != nil { + if err := DB.Table("column_structs").Migrator().RenameColumn(&NewColumnStruct{}, "NewName", + "new_new_name"); err != nil { t.Fatalf("Failed to add column, got %v", err) } @@ -636,3 +869,1079 @@ func TestMigrateSerialColumn(t *testing.T) { AssertEqual(t, v.ID, v.UID) } } + +// https://github.com/go-gorm/gorm/issues/5300 +func TestMigrateWithSpecialName(t *testing.T) { + var err error + err = DB.AutoMigrate(&Coupon{}) + if err != nil { + t.Fatalf("AutoMigrate err:%v", err) + } + err = DB.Table("coupon_product_1").AutoMigrate(&CouponProduct{}) + if err != nil { + t.Fatalf("AutoMigrate err:%v", err) + } + err = DB.Table("coupon_product_2").AutoMigrate(&CouponProduct{}) + if err != nil { + t.Fatalf("AutoMigrate err:%v", err) + } + + AssertEqual(t, true, DB.Migrator().HasTable("coupons")) + AssertEqual(t, true, DB.Migrator().HasTable("coupon_product_1")) + AssertEqual(t, true, DB.Migrator().HasTable("coupon_product_2")) +} + +// https://github.com/go-gorm/gorm/issues/4760 +func TestMigrateAutoIncrement(t *testing.T) { + type AutoIncrementStruct struct { + ID int64 `gorm:"primarykey;autoIncrement"` + Field1 uint32 `gorm:"column:field1"` + Field2 float32 `gorm:"column:field2"` + } + + if err := DB.AutoMigrate(&AutoIncrementStruct{}); err != nil { + t.Fatalf("AutoMigrate err: %v", err) + } + + const ROWS = 10 + for idx := 0; idx < ROWS; idx++ { + if err := DB.Create(&AutoIncrementStruct{}).Error; err != nil { + t.Fatalf("create auto_increment_struct fail, err: %v", err) + } + } + + rows := make([]*AutoIncrementStruct, 0, ROWS) + if err := DB.Order("id ASC").Find(&rows).Error; err != nil { + t.Fatalf("find auto_increment_struct fail, err: %v", err) + } + + ids := make([]int64, 0, len(rows)) + for _, row := range rows { + ids = append(ids, row.ID) + } + lastID := ids[len(ids)-1] + + if err := DB.Where("id IN (?)", ids).Delete(&AutoIncrementStruct{}).Error; err != nil { + t.Fatalf("delete auto_increment_struct fail, err: %v", err) + } + + newRow := &AutoIncrementStruct{} + if err := DB.Create(newRow).Error; err != nil { + t.Fatalf("create auto_increment_struct fail, err: %v", err) + } + + AssertEqual(t, newRow.ID, lastID+1) +} + +// https://github.com/go-gorm/gorm/issues/5320 +func TestPrimarykeyID(t *testing.T) { + if DB.Dialector.Name() != "postgres" { + return + } + + type MissPKLanguage struct { + ID string `gorm:"type:uuid;default:uuid_generate_v4()"` + Name string + } + + type MissPKUser struct { + ID string `gorm:"type:uuid;default:uuid_generate_v4()"` + MissPKLanguages []MissPKLanguage `gorm:"many2many:miss_pk_user_languages;"` + } + + var err error + err = DB.Migrator().DropTable(&MissPKUser{}, &MissPKLanguage{}) + if err != nil { + t.Fatalf("DropTable err:%v", err) + } + + DB.Exec(`CREATE EXTENSION IF NOT EXISTS "uuid-ossp";`) + + err = DB.AutoMigrate(&MissPKUser{}, &MissPKLanguage{}) + if err != nil { + t.Fatalf("AutoMigrate err:%v", err) + } + + // patch + err = DB.AutoMigrate(&MissPKUser{}, &MissPKLanguage{}) + if err != nil { + t.Fatalf("AutoMigrate err:%v", err) + } +} + +func TestCurrentTimestamp(t *testing.T) { + if DB.Dialector.Name() != "mysql" { + return + } + type CurrentTimestampTest struct { + ID string `gorm:"primary_key"` + TimeAt *time.Time `gorm:"type:datetime;not null;default:CURRENT_TIMESTAMP;unique"` + } + var err error + err = DB.Migrator().DropTable(&CurrentTimestampTest{}) + if err != nil { + t.Errorf("DropTable err:%v", err) + } + err = DB.AutoMigrate(&CurrentTimestampTest{}) + if err != nil { + t.Fatalf("AutoMigrate err:%v", err) + } + + err = DB.AutoMigrate(&CurrentTimestampTest{}) + if err != nil { + t.Fatalf("AutoMigrate err:%v", err) + } + AssertEqual(t, true, DB.Migrator().HasConstraint(&CurrentTimestampTest{}, "uni_current_timestamp_tests_time_at")) + AssertEqual(t, false, DB.Migrator().HasIndex(&CurrentTimestampTest{}, "time_at")) + AssertEqual(t, false, DB.Migrator().HasIndex(&CurrentTimestampTest{}, "time_at_2")) +} + +func TestUniqueColumn(t *testing.T) { + if DB.Dialector.Name() != "mysql" { + return + } + + type UniqueTest struct { + ID string `gorm:"primary_key"` + Name string `gorm:"unique"` + } + + type UniqueTest2 struct { + ID string `gorm:"primary_key"` + Name string `gorm:"unique;default:NULL"` + } + + type UniqueTest3 struct { + ID string `gorm:"primary_key"` + Name string `gorm:"unique;default:''"` + } + + type UniqueTest4 struct { + ID string `gorm:"primary_key"` + Name string `gorm:"unique;default:'123'"` + } + + var err error + err = DB.Migrator().DropTable(&UniqueTest{}) + if err != nil { + t.Errorf("DropTable err:%v", err) + } + + err = DB.AutoMigrate(&UniqueTest{}) + if err != nil { + t.Fatalf("AutoMigrate err:%v", err) + } + + // null -> null + err = DB.AutoMigrate(&UniqueTest{}) + if err != nil { + t.Fatalf("AutoMigrate err:%v", err) + } + + ct, err := findColumnType(&UniqueTest{}, "name") + if err != nil { + t.Fatalf("findColumnType err:%v", err) + } + + value, ok := ct.DefaultValue() + AssertEqual(t, "", value) + AssertEqual(t, false, ok) + + // null -> null + err = DB.Table("unique_tests").AutoMigrate(&UniqueTest2{}) + if err != nil { + t.Fatalf("AutoMigrate err:%v", err) + } + + // not trigger alert column + AssertEqual(t, true, DB.Migrator().HasConstraint(&UniqueTest{}, "uni_unique_tests_name")) + AssertEqual(t, false, DB.Migrator().HasIndex(&UniqueTest{}, "name")) + AssertEqual(t, false, DB.Migrator().HasIndex(&UniqueTest{}, "name_1")) + AssertEqual(t, false, DB.Migrator().HasIndex(&UniqueTest{}, "name_2")) + + ct, err = findColumnType(&UniqueTest{}, "name") + if err != nil { + t.Fatalf("findColumnType err:%v", err) + } + + value, ok = ct.DefaultValue() + AssertEqual(t, "", value) + AssertEqual(t, false, ok) + + tidbSkip(t, "can't change column constraint") + + // null -> empty string + err = DB.Table("unique_tests").AutoMigrate(&UniqueTest3{}) + if err != nil { + t.Fatalf("AutoMigrate err:%v", err) + } + + ct, err = findColumnType(&UniqueTest{}, "name") + if err != nil { + t.Fatalf("findColumnType err:%v", err) + } + + value, ok = ct.DefaultValue() + AssertEqual(t, "", value) + AssertEqual(t, true, ok) + + // empty string -> 123 + err = DB.Table("unique_tests").AutoMigrate(&UniqueTest4{}) + if err != nil { + t.Fatalf("AutoMigrate err:%v", err) + } + + ct, err = findColumnType(&UniqueTest{}, "name") + if err != nil { + t.Fatalf("findColumnType err:%v", err) + } + + value, ok = ct.DefaultValue() + AssertEqual(t, "123", value) + AssertEqual(t, true, ok) + + // 123 -> null + err = DB.Table("unique_tests").AutoMigrate(&UniqueTest2{}) + if err != nil { + t.Fatalf("AutoMigrate err:%v", err) + } + + ct, err = findColumnType(&UniqueTest{}, "name") + if err != nil { + t.Fatalf("findColumnType err:%v", err) + } + + value, ok = ct.DefaultValue() + AssertEqual(t, "", value) + AssertEqual(t, false, ok) +} + +func findColumnType(dest interface{}, columnName string) ( + foundColumn gorm.ColumnType, err error, +) { + columnTypes, err := DB.Migrator().ColumnTypes(dest) + if err != nil { + err = fmt.Errorf("ColumnTypes err:%v", err) + return + } + + for _, c := range columnTypes { + if c.Name() == columnName { + foundColumn = c + break + } + } + return +} + +func TestInvalidCachedPlanSimpleProtocol(t *testing.T) { + if DB.Dialector.Name() != "postgres" { + return + } + + db, err := gorm.Open(postgres.Open(postgresDSN), &gorm.Config{}) + if err != nil { + t.Errorf("Open err:%v", err) + } + + type Object1 struct{} + type Object2 struct { + Field1 string + } + type Object3 struct { + Field2 string + } + db.Migrator().DropTable("objects") + + err = db.Table("objects").AutoMigrate(&Object1{}) + if err != nil { + t.Errorf("AutoMigrate err:%v", err) + } + + err = db.Table("objects").AutoMigrate(&Object2{}) + if err != nil { + t.Errorf("AutoMigrate err:%v", err) + } + + err = db.Table("objects").AutoMigrate(&Object3{}) + if err != nil { + t.Errorf("AutoMigrate err:%v", err) + } +} + +func TestInvalidCachedPlanPrepareStmt(t *testing.T) { + if DB.Dialector.Name() != "postgres" { + return + } + + db, err := gorm.Open(postgres.Open(postgresDSN), &gorm.Config{PrepareStmt: true}) + if err != nil { + t.Errorf("Open err:%v", err) + } + if debug := os.Getenv("DEBUG"); debug == "true" { + db.Logger = db.Logger.LogMode(logger.Info) + } else if debug == "false" { + db.Logger = db.Logger.LogMode(logger.Silent) + } + + type Object1 struct { + ID uint + } + type Object2 struct { + ID uint + Field1 int `gorm:"type:int8"` + } + type Object3 struct { + ID uint + Field1 int `gorm:"type:int4"` + } + type Object4 struct { + ID uint + Field2 int + } + db.Migrator().DropTable("objects") + + err = db.Table("objects").AutoMigrate(&Object1{}) + if err != nil { + t.Errorf("AutoMigrate err:%v", err) + } + err = db.Table("objects").Create(&Object1{}).Error + if err != nil { + t.Errorf("create err:%v", err) + } + + // AddColumn + err = db.Table("objects").AutoMigrate(&Object2{}) + if err != nil { + t.Errorf("AutoMigrate err:%v", err) + } + + err = db.Table("objects").Take(&Object2{}).Error + if err != nil { + t.Errorf("take err:%v", err) + } + + // AlterColumn + err = db.Table("objects").AutoMigrate(&Object3{}) + if err != nil { + t.Errorf("AutoMigrate err:%v", err) + } + + err = db.Table("objects").Take(&Object3{}).Error + if err != nil { + t.Errorf("take err:%v", err) + } + + // AddColumn + err = db.Table("objects").AutoMigrate(&Object4{}) + if err != nil { + t.Errorf("AutoMigrate err:%v", err) + } + + err = db.Table("objects").Take(&Object4{}).Error + if err != nil { + t.Errorf("take err:%v", err) + } + + db.Table("objects").Migrator().RenameColumn(&Object4{}, "field2", "field3") + if err != nil { + t.Errorf("RenameColumn err:%v", err) + } + + err = db.Table("objects").Take(&Object4{}).Error + if err != nil { + t.Errorf("take err:%v", err) + } + + db.Table("objects").Migrator().DropColumn(&Object4{}, "field3") + if err != nil { + t.Errorf("RenameColumn err:%v", err) + } + + err = db.Table("objects").Take(&Object4{}).Error + if err != nil { + t.Errorf("take err:%v", err) + } +} + +func TestDifferentTypeWithoutDeclaredLength(t *testing.T) { + type DiffType struct { + ID uint + Name string `gorm:"type:varchar(20)"` + } + + type DiffType1 struct { + ID uint + Name string `gorm:"type:text"` + } + + var err error + DB.Migrator().DropTable(&DiffType{}) + + err = DB.AutoMigrate(&DiffType{}) + if err != nil { + t.Errorf("AutoMigrate err:%v", err) + } + + ct, err := findColumnType(&DiffType{}, "name") + if err != nil { + t.Errorf("findColumnType err:%v", err) + } + + AssertEqual(t, "varchar", strings.ToLower(ct.DatabaseTypeName())) + + err = DB.Table("diff_types").AutoMigrate(&DiffType1{}) + if err != nil { + t.Errorf("AutoMigrate err:%v", err) + } + + ct, err = findColumnType(&DiffType{}, "name") + if err != nil { + t.Errorf("findColumnType err:%v", err) + } + + AssertEqual(t, "text", strings.ToLower(ct.DatabaseTypeName())) +} + +func TestMigrateArrayTypeModel(t *testing.T) { + if DB.Dialector.Name() != "postgres" { + return + } + + type ArrayTypeModel struct { + ID uint + Number string `gorm:"type:varchar(51);NOT NULL"` + TextArray []string `gorm:"type:text[];NOT NULL"` + NestedTextArray [][]string `gorm:"type:text[][]"` + NestedIntArray [][]int64 `gorm:"type:integer[3][3]"` + } + + var err error + DB.Migrator().DropTable(&ArrayTypeModel{}) + + err = DB.AutoMigrate(&ArrayTypeModel{}) + AssertEqual(t, nil, err) + + ct, err := findColumnType(&ArrayTypeModel{}, "number") + AssertEqual(t, nil, err) + AssertEqual(t, "varchar", ct.DatabaseTypeName()) + + ct, err = findColumnType(&ArrayTypeModel{}, "text_array") + AssertEqual(t, nil, err) + AssertEqual(t, "text[]", ct.DatabaseTypeName()) + + ct, err = findColumnType(&ArrayTypeModel{}, "nested_text_array") + AssertEqual(t, nil, err) + AssertEqual(t, "text[]", ct.DatabaseTypeName()) + + ct, err = findColumnType(&ArrayTypeModel{}, "nested_int_array") + AssertEqual(t, nil, err) + AssertEqual(t, "integer[]", ct.DatabaseTypeName()) +} + +type mockMigrator struct { + gorm.Migrator +} + +func (mm mockMigrator) AlterColumn(dst interface{}, field string) error { + err := mm.Migrator.AlterColumn(dst, field) + if err != nil { + return err + } + return fmt.Errorf("trigger alter column error, field: %s", field) +} + +func TestMigrateDonotAlterColumn(t *testing.T) { + wrapMockMigrator := func(m gorm.Migrator) mockMigrator { + return mockMigrator{ + Migrator: m, + } + } + m := DB.Migrator() + mockM := wrapMockMigrator(m) + + type NotTriggerUpdate struct { + ID uint + F1 uint16 + F2 uint32 + F3 int + F4 int64 + F5 string + F6 float32 + F7 float64 + F8 time.Time + F9 bool + F10 []byte + } + + var err error + err = mockM.DropTable(&NotTriggerUpdate{}) + AssertEqual(t, err, nil) + err = mockM.AutoMigrate(&NotTriggerUpdate{}) + AssertEqual(t, err, nil) + err = mockM.AutoMigrate(&NotTriggerUpdate{}) + AssertEqual(t, err, nil) +} + +func TestMigrateSameEmbeddedFieldName(t *testing.T) { + type UserStat struct { + GroundDestroyCount int + } + + type GameUser struct { + gorm.Model + StatAb UserStat `gorm:"embedded;embeddedPrefix:stat_ab_"` + } + + type UserStat1 struct { + GroundDestroyCount string + } + + type GroundRate struct { + GroundDestroyCount int + } + + type GameUser1 struct { + gorm.Model + StatAb UserStat1 `gorm:"embedded;embeddedPrefix:stat_ab_"` + GroundRateRb GroundRate `gorm:"embedded;embeddedPrefix:rate_ground_rb_"` + } + + DB.Migrator().DropTable(&GameUser{}) + err := DB.AutoMigrate(&GameUser{}) + AssertEqual(t, nil, err) + + err = DB.Table("game_users").AutoMigrate(&GameUser1{}) + AssertEqual(t, nil, err) + + _, err = findColumnType(&GameUser{}, "stat_ab_ground_destroy_count") + AssertEqual(t, nil, err) + + _, err = findColumnType(&GameUser{}, "rate_ground_rb_ground_destroy_count") + AssertEqual(t, nil, err) +} + +func TestMigrateWithDefaultValue(t *testing.T) { + if DB.Dialector.Name() == "sqlserver" { + // sqlserver driver treats NULL and 'NULL' the same + t.Skip("skip sqlserver") + } + + type NullModel struct { + ID uint + Content string `gorm:"default:null"` + } + + type NullStringModel struct { + ID uint + Content string `gorm:"default:'null'"` + Active bool `gorm:"default:false"` + } + + tableName := "null_string_model" + + DB.Migrator().DropTable(tableName) + + err := DB.Table(tableName).AutoMigrate(&NullModel{}) + AssertEqual(t, err, nil) + + // default null -> 'null' + err = DB.Table(tableName).AutoMigrate(&NullStringModel{}) + AssertEqual(t, err, nil) + + columnType, err := findColumnType(tableName, "content") + AssertEqual(t, err, nil) + + defVal, ok := columnType.DefaultValue() + AssertEqual(t, defVal, "null") + AssertEqual(t, ok, true) + + columnType2, err := findColumnType(tableName, "active") + AssertEqual(t, err, nil) + + defVal, ok = columnType2.DefaultValue() + bv, _ := strconv.ParseBool(defVal) + AssertEqual(t, bv, false) + AssertEqual(t, ok, true) + + // default 'null' -> 'null' + session := DB.Session(&gorm.Session{Logger: Tracer{ + Logger: DB.Config.Logger, + Test: func(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) { + sql, _ := fc() + if strings.HasPrefix(sql, "ALTER TABLE") { + t.Errorf("shouldn't execute: sql=%s", sql) + } + }, + }}) + err = session.Table(tableName).AutoMigrate(&NullStringModel{}) + AssertEqual(t, err, nil) + + columnType, err = findColumnType(tableName, "content") + AssertEqual(t, err, nil) + + defVal, ok = columnType.DefaultValue() + AssertEqual(t, defVal, "null") + AssertEqual(t, ok, true) + + // default 'null' -> null + err = DB.Table(tableName).AutoMigrate(&NullModel{}) + AssertEqual(t, err, nil) + + columnType, err = findColumnType(tableName, "content") + AssertEqual(t, err, nil) + + defVal, ok = columnType.DefaultValue() + AssertEqual(t, defVal, "") + AssertEqual(t, ok, false) +} + +func TestMigrateMySQLWithCustomizedTypes(t *testing.T) { + if DB.Dialector.Name() != "mysql" { + t.Skip() + } + + type MyTable struct { + Def string `gorm:"size:512;index:idx_def,unique"` + Abc string `gorm:"size:65000000"` + } + + DB.Migrator().DropTable("my_tables") + + sql := "CREATE TABLE `my_tables` (`def` varchar(512),`abc` longtext,UNIQUE INDEX `idx_def` (`def`))" + if err := DB.Exec(sql).Error; err != nil { + t.Errorf("Failed, got error: %v", err) + } + + session := DB.Session(&gorm.Session{Logger: Tracer{ + Logger: DB.Config.Logger, + Test: func(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) { + sql, _ := fc() + if strings.HasPrefix(sql, "ALTER TABLE") { + t.Errorf("shouldn't execute: sql=%s", sql) + } + }, + }}) + + if err := session.AutoMigrate(&MyTable{}); err != nil { + t.Errorf("Failed, got error: %v", err) + } +} + +func TestMigrateIgnoreRelations(t *testing.T) { + type RelationModel1 struct { + ID uint + } + type RelationModel2 struct { + ID uint + } + type RelationModel3 struct { + ID uint + RelationModel1ID uint + RelationModel1 *RelationModel1 + RelationModel2ID uint + RelationModel2 *RelationModel2 `gorm:"-:migration"` + } + + var err error + _ = DB.Migrator().DropTable(&RelationModel1{}, &RelationModel2{}, &RelationModel3{}) + + tx := DB.Session(&gorm.Session{}) + tx.IgnoreRelationshipsWhenMigrating = true + + err = tx.AutoMigrate(&RelationModel3{}) + if err != nil { + t.Errorf("AutoMigrate err:%v", err) + } + + // RelationModel3 should be existed + _, err = findColumnType(&RelationModel3{}, "id") + AssertEqual(t, nil, err) + + // RelationModel1 should not be existed + _, err = findColumnType(&RelationModel1{}, "id") + if err == nil { + t.Errorf("RelationModel1 should not be migrated") + } + + // RelationModel2 should not be existed + _, err = findColumnType(&RelationModel2{}, "id") + if err == nil { + t.Errorf("RelationModel2 should not be migrated") + } + + tx.IgnoreRelationshipsWhenMigrating = false + + err = tx.AutoMigrate(&RelationModel3{}) + if err != nil { + t.Errorf("AutoMigrate err:%v", err) + } + + // RelationModel3 should be existed + _, err = findColumnType(&RelationModel3{}, "id") + AssertEqual(t, nil, err) + + // RelationModel1 should be existed + _, err = findColumnType(&RelationModel1{}, "id") + AssertEqual(t, nil, err) + + // RelationModel2 should not be existed + _, err = findColumnType(&RelationModel2{}, "id") + if err == nil { + t.Errorf("RelationModel2 should not be migrated") + } +} + +func TestMigrateView(t *testing.T) { + DB.Save(GetUser("joins-args-db", Config{Pets: 2})) + + if err := DB.Migrator().CreateView("invalid_users_pets", + gorm.ViewOption{Query: nil}); err != gorm.ErrSubQueryRequired { + t.Fatalf("no view should be created, got %v", err) + } + + query := DB.Model(&User{}). + Select("users.id as users_id, users.name as users_name, pets.id as pets_id, pets.name as pets_name"). + Joins("inner join pets on pets.user_id = users.id") + + if err := DB.Migrator().CreateView("users_pets", gorm.ViewOption{Query: query}); err != nil { + t.Fatalf("Failed to crate view, got %v", err) + } + + var count int64 + if err := DB.Table("users_pets").Count(&count).Error; err != nil { + t.Fatalf("should found created view") + } + + if err := DB.Migrator().DropView("users_pets"); err != nil { + t.Fatalf("Failed to drop view, got %v", err) + } + + query = DB.Model(&User{}).Where("age > ?", 20) + if err := DB.Migrator().CreateView("users_view", gorm.ViewOption{Query: query}); err != nil { + t.Fatalf("Failed to crate view, got %v", err) + } + if err := DB.Migrator().DropView("users_view"); err != nil { + t.Fatalf("Failed to drop view, got %v", err) + } +} + +func TestMigrateExistingBoolColumnPG(t *testing.T) { + if DB.Dialector.Name() != "postgres" { + return + } + + type ColumnStruct struct { + gorm.Model + Name string + StringBool string + SmallintBool int `gorm:"type:smallint"` + } + + type ColumnStruct2 struct { + gorm.Model + Name string + StringBool bool // change existing boolean column from string to boolean + SmallintBool bool // change existing boolean column from smallint or other to boolean + } + + DB.Migrator().DropTable(&ColumnStruct{}) + + if err := DB.AutoMigrate(&ColumnStruct{}); err != nil { + t.Errorf("Failed to migrate, got %v", err) + } + + if err := DB.Table("column_structs").AutoMigrate(&ColumnStruct2{}); err != nil { + t.Fatalf("no error should happened when auto migrate column, but got %v", err) + } + + if columnTypes, err := DB.Migrator().ColumnTypes(&ColumnStruct{}); err != nil { + t.Fatalf("no error should returns for ColumnTypes") + } else { + stmt := &gorm.Statement{DB: DB} + stmt.Parse(&ColumnStruct2{}) + + for _, columnType := range columnTypes { + switch columnType.Name() { + case "id": + if v, ok := columnType.PrimaryKey(); !ok || !v { + t.Fatalf("column id primary key should be correct, name: %v, column: %#v", columnType.Name(), + columnType) + } + case "string_bool": + dataType := DB.Dialector.DataTypeOf(stmt.Schema.LookUpField(columnType.Name())) + if !strings.Contains(strings.ToUpper(dataType), strings.ToUpper(columnType.DatabaseTypeName())) { + t.Fatalf("column name type should be correct, name: %v, length: %v, expects: %v, column: %#v", + columnType.Name(), columnType.DatabaseTypeName(), dataType, columnType) + } + case "smallint_bool": + dataType := DB.Dialector.DataTypeOf(stmt.Schema.LookUpField(columnType.Name())) + if !strings.Contains(strings.ToUpper(dataType), strings.ToUpper(columnType.DatabaseTypeName())) { + t.Fatalf("column name type should be correct, name: %v, length: %v, expects: %v, column: %#v", + columnType.Name(), columnType.DatabaseTypeName(), dataType, columnType) + } + } + } + } +} + +func TestTableType(t *testing.T) { + // currently it is only supported for mysql driver + if !isMysql() { + return + } + + const tblName = "cities" + const tblSchema = "gorm" + const tblType = "BASE TABLE" + const tblComment = "foobar comment" + + type City struct { + gorm.Model + Name string `gorm:"unique"` + } + + DB.Migrator().DropTable(&City{}) + + if err := DB.Set("gorm:table_options", + fmt.Sprintf("ENGINE InnoDB COMMENT '%s'", tblComment)).AutoMigrate(&City{}); err != nil { + t.Fatalf("failed to migrate cities tables, got error: %v", err) + } + + tableType, err := DB.Table("cities").Migrator().TableType(&City{}) + if err != nil { + t.Fatalf("failed to get table type, got error %v", err) + } + + if tableType.Schema() != tblSchema { + t.Fatalf("expected tblSchema to be %s but got %s", tblSchema, tableType.Schema()) + } + + if tableType.Name() != tblName { + t.Fatalf("expected table name to be %s but got %s", tblName, tableType.Name()) + } + + if tableType.Type() != tblType { + t.Fatalf("expected table type to be %s but got %s", tblType, tableType.Type()) + } + + comment, ok := tableType.Comment() + if !ok || comment != tblComment { + t.Fatalf("expected comment %s got %s", tblComment, comment) + } +} + +func TestMigrateWithUniqueIndexAndUnique(t *testing.T) { + const table = "unique_struct" + + checkField := func(model interface{}, fieldName string, unique bool, uniqueIndex string) { + stmt := &gorm.Statement{DB: DB} + err := stmt.Parse(model) + if err != nil { + t.Fatalf("%v: failed to parse schema, got error: %v", utils.FileWithLineNum(), err) + } + _ = stmt.Schema.ParseIndexes() + field := stmt.Schema.LookUpField(fieldName) + if field == nil { + t.Fatalf("%v: failed to find column %q", utils.FileWithLineNum(), fieldName) + } + if field.Unique != unique { + t.Fatalf("%v: %q column %q unique should be %v but got %v", utils.FileWithLineNum(), stmt.Schema.Table, fieldName, unique, field.Unique) + } + if field.UniqueIndex != uniqueIndex { + t.Fatalf("%v: %q column %q uniqueIndex should be %v but got %v", utils.FileWithLineNum(), stmt.Schema, fieldName, uniqueIndex, field.UniqueIndex) + } + } + + type ( // not unique + UniqueStruct1 struct { + Name string `gorm:"size:10"` + } + UniqueStruct2 struct { + Name string `gorm:"size:20"` + } + ) + checkField(&UniqueStruct1{}, "name", false, "") + checkField(&UniqueStruct2{}, "name", false, "") + + type ( // unique + UniqueStruct3 struct { + Name string `gorm:"size:30;unique"` + } + UniqueStruct4 struct { + Name string `gorm:"size:40;unique"` + } + ) + checkField(&UniqueStruct3{}, "name", true, "") + checkField(&UniqueStruct4{}, "name", true, "") + + type ( // uniqueIndex + UniqueStruct5 struct { + Name string `gorm:"size:50;uniqueIndex"` + } + UniqueStruct6 struct { + Name string `gorm:"size:60;uniqueIndex"` + } + UniqueStruct7 struct { + Name string `gorm:"size:70;uniqueIndex:idx_us6_all_names"` + NickName string `gorm:"size:70;uniqueIndex:idx_us6_all_names"` + } + ) + checkField(&UniqueStruct5{}, "name", false, "idx_unique_struct5_name") + checkField(&UniqueStruct6{}, "name", false, "idx_unique_struct6_name") + + checkField(&UniqueStruct7{}, "name", false, "") + checkField(&UniqueStruct7{}, "nick_name", false, "") + checkField(&UniqueStruct7{}, "nick_name", false, "") + + type UniqueStruct8 struct { // unique and uniqueIndex + Name string `gorm:"size:60;unique;index:my_us8_index,unique;"` + } + checkField(&UniqueStruct8{}, "name", true, "my_us8_index") + + type TestCase struct { + name string + from, to interface{} + checkFunc func(t *testing.T) + } + + checkColumnType := func(t *testing.T, fieldName string, unique bool) { + columnTypes, err := DB.Migrator().ColumnTypes(table) + if err != nil { + t.Fatalf("%v: failed to get column types, got error: %v", utils.FileWithLineNum(), err) + } + var found gorm.ColumnType + for _, columnType := range columnTypes { + if columnType.Name() == fieldName { + found = columnType + } + } + if found == nil { + t.Fatalf("%v: failed to find column type %q", utils.FileWithLineNum(), fieldName) + } + if actualUnique, ok := found.Unique(); !ok || actualUnique != unique { + t.Fatalf("%v: column %q unique should be %v but got %v", utils.FileWithLineNum(), fieldName, unique, actualUnique) + } + } + + checkIndex := func(t *testing.T, expected []gorm.Index) { + indexes, err := DB.Migrator().GetIndexes(table) + if err != nil { + t.Fatalf("%v: failed to get indexes, got error: %v", utils.FileWithLineNum(), err) + } + assert.ElementsMatch(t, expected, indexes) + } + + uniqueIndex := &migrator.Index{TableName: table, NameValue: DB.Config.NamingStrategy.IndexName(table, "name"), ColumnList: []string{"name"}, PrimaryKeyValue: sql.NullBool{Bool: false, Valid: true}, UniqueValue: sql.NullBool{Bool: true, Valid: true}} + myIndex := &migrator.Index{TableName: table, NameValue: "my_us8_index", ColumnList: []string{"name"}, PrimaryKeyValue: sql.NullBool{Bool: false, Valid: true}, UniqueValue: sql.NullBool{Bool: true, Valid: true}} + mulIndex := &migrator.Index{TableName: table, NameValue: "idx_us6_all_names", ColumnList: []string{"name", "nick_name"}, PrimaryKeyValue: sql.NullBool{Bool: false, Valid: true}, UniqueValue: sql.NullBool{Bool: true, Valid: true}} + + var checkNotUnique, checkUnique, checkUniqueIndex, checkMyIndex, checkMulIndex func(t *testing.T) + // UniqueAffectedByUniqueIndex is true + if DB.Dialector.Name() == "mysql" { + uniqueConstraintIndex := &migrator.Index{TableName: table, NameValue: DB.Config.NamingStrategy.UniqueName(table, "name"), ColumnList: []string{"name"}, PrimaryKeyValue: sql.NullBool{Bool: false, Valid: true}, UniqueValue: sql.NullBool{Bool: true, Valid: true}} + checkNotUnique = func(t *testing.T) { + checkColumnType(t, "name", false) + checkIndex(t, nil) + } + checkUnique = func(t *testing.T) { + checkColumnType(t, "name", true) + checkIndex(t, []gorm.Index{uniqueConstraintIndex}) + } + checkUniqueIndex = func(t *testing.T) { + checkColumnType(t, "name", true) + checkIndex(t, []gorm.Index{uniqueIndex}) + } + checkMyIndex = func(t *testing.T) { + checkColumnType(t, "name", true) + checkIndex(t, []gorm.Index{uniqueConstraintIndex, myIndex}) + } + checkMulIndex = func(t *testing.T) { + checkColumnType(t, "name", false) + checkColumnType(t, "nick_name", false) + checkIndex(t, []gorm.Index{mulIndex}) + } + } else { + checkNotUnique = func(t *testing.T) { checkColumnType(t, "name", false) } + checkUnique = func(t *testing.T) { checkColumnType(t, "name", true) } + checkUniqueIndex = func(t *testing.T) { + checkColumnType(t, "name", false) + checkIndex(t, []gorm.Index{uniqueIndex}) + } + checkMyIndex = func(t *testing.T) { + checkColumnType(t, "name", true) + if !DB.Migrator().HasIndex(table, myIndex.Name()) { + t.Errorf("%v: should has index %s but not", utils.FileWithLineNum(), myIndex.Name()) + } + } + checkMulIndex = func(t *testing.T) { + checkColumnType(t, "name", false) + checkColumnType(t, "nick_name", false) + if !DB.Migrator().HasIndex(table, mulIndex.Name()) { + t.Errorf("%v: should has index %s but not", utils.FileWithLineNum(), mulIndex.Name()) + } + } + } + + tests := []TestCase{ + {name: "notUnique to notUnique", from: &UniqueStruct1{}, to: &UniqueStruct2{}, checkFunc: checkNotUnique}, + {name: "notUnique to unique", from: &UniqueStruct1{}, to: &UniqueStruct3{}, checkFunc: checkUnique}, + {name: "notUnique to uniqueIndex", from: &UniqueStruct1{}, to: &UniqueStruct5{}, checkFunc: checkUniqueIndex}, + {name: "notUnique to uniqueAndUniqueIndex", from: &UniqueStruct1{}, to: &UniqueStruct8{}, checkFunc: checkMyIndex}, + {name: "unique to unique", from: &UniqueStruct3{}, to: &UniqueStruct4{}, checkFunc: checkUnique}, + {name: "unique to uniqueIndex", from: &UniqueStruct3{}, to: &UniqueStruct5{}, checkFunc: checkUniqueIndex}, + {name: "unique to uniqueAndUniqueIndex", from: &UniqueStruct3{}, to: &UniqueStruct8{}, checkFunc: checkMyIndex}, + {name: "uniqueIndex to uniqueIndex", from: &UniqueStruct5{}, to: &UniqueStruct6{}, checkFunc: checkUniqueIndex}, + {name: "uniqueIndex to uniqueAndUniqueIndex", from: &UniqueStruct5{}, to: &UniqueStruct8{}, checkFunc: checkMyIndex}, + {name: "uniqueIndex to multi uniqueIndex", from: &UniqueStruct5{}, to: &UniqueStruct7{}, checkFunc: checkMulIndex}, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if err := DB.Migrator().DropTable(table); err != nil { + t.Fatalf("failed to drop table, got error: %v", err) + } + if err := DB.Table(table).AutoMigrate(test.from); err != nil { + t.Fatalf("failed to migrate table, got error: %v", err) + } + if err := DB.Table(table).AutoMigrate(test.to); err != nil { + t.Fatalf("failed to migrate table, got error: %v", err) + } + test.checkFunc(t) + }) + } + + if DB.Dialector.Name() != "sqlserver" { + // In SQLServer, If an index or constraint depends on the column, + // this column will not be able to run ALTER + // see https://stackoverflow.com/questions/19460912/the-object-df-is-dependent-on-column-changing-int-to-double/19461205#19461205 + // may we need to create another PR to fix it, see https://github.com/go-gorm/sqlserver/pull/106 + tests = []TestCase{ + {name: "unique to notUnique", from: &UniqueStruct3{}, to: &UniqueStruct1{}, checkFunc: checkNotUnique}, + {name: "uniqueIndex to notUnique", from: &UniqueStruct5{}, to: &UniqueStruct2{}, checkFunc: checkNotUnique}, + {name: "uniqueIndex to unique", from: &UniqueStruct5{}, to: &UniqueStruct3{}, checkFunc: checkUnique}, + } + } + + if DB.Dialector.Name() == "mysql" { + compatibilityTests := []TestCase{ + {name: "oldUnique to notUnique", to: UniqueStruct1{}, checkFunc: checkNotUnique}, + {name: "oldUnique to unique", to: UniqueStruct3{}, checkFunc: checkUnique}, + {name: "oldUnique to uniqueIndex", to: UniqueStruct5{}, checkFunc: checkUniqueIndex}, + {name: "oldUnique to uniqueAndUniqueIndex", to: UniqueStruct8{}, checkFunc: checkMyIndex}, + } + for _, test := range compatibilityTests { + t.Run(test.name, func(t *testing.T) { + if err := DB.Migrator().DropTable(table); err != nil { + t.Fatalf("failed to drop table, got error: %v", err) + } + if err := DB.Exec("CREATE TABLE ? (`name` varchar(10) UNIQUE)", clause.Table{Name: table}).Error; err != nil { + t.Fatalf("failed to create table, got error: %v", err) + } + if err := DB.Table(table).AutoMigrate(test.to); err != nil { + t.Fatalf("failed to migrate table, got error: %v", err) + } + test.checkFunc(t) + }) + } + } +} diff --git a/tests/postgres_test.go b/tests/postgres_test.go index 66b988c3..44cac6bf 100644 --- a/tests/postgres_test.go +++ b/tests/postgres_test.go @@ -7,8 +7,60 @@ import ( "github.com/google/uuid" "github.com/lib/pq" "gorm.io/gorm" + "gorm.io/gorm/clause" + . "gorm.io/gorm/utils/tests" ) +func TestPostgresReturningIDWhichHasStringType(t *testing.T) { + if DB.Dialector.Name() != "postgres" { + t.Skip() + } + + type Yasuo struct { + ID string `gorm:"default:gen_random_uuid()"` + Name string + CreatedAt time.Time `gorm:"type:TIMESTAMP WITHOUT TIME ZONE"` + UpdatedAt time.Time `gorm:"type:TIMESTAMP WITHOUT TIME ZONE;default:current_timestamp"` + } + + if err := DB.Exec("CREATE EXTENSION IF NOT EXISTS pgcrypto;").Error; err != nil { + t.Errorf("Failed to create extension pgcrypto, got error %v", err) + } + + DB.Migrator().DropTable(&Yasuo{}) + + if err := DB.AutoMigrate(&Yasuo{}); err != nil { + t.Fatalf("Failed to migrate for uuid default value, got error: %v", err) + } + + yasuo := Yasuo{Name: "jinzhu"} + if err := DB.Create(&yasuo).Error; err != nil { + t.Fatalf("should be able to create data, but got %v", err) + } + + if yasuo.ID == "" { + t.Fatal("should be able to has ID, but got zero value") + } + + var result Yasuo + if err := DB.First(&result, "id = ?", yasuo.ID).Error; err != nil || yasuo.Name != "jinzhu" { + t.Errorf("No error should happen, but got %v", err) + } + + if err := DB.Where("id = $1", yasuo.ID).First(&Yasuo{}).Error; err != nil || yasuo.Name != "jinzhu" { + t.Errorf("No error should happen, but got %v", err) + } + + yasuo.Name = "jinzhu1" + if err := DB.Save(&yasuo).Error; err != nil { + t.Errorf("Failed to update date, got error %v", err) + } + + if err := DB.First(&result, "id = ?", yasuo.ID).Error; err != nil || yasuo.Name != "jinzhu1" { + t.Errorf("No error should happen, but got %v", err) + } +} + func TestPostgres(t *testing.T) { if DB.Dialector.Name() != "postgres" { t.Skip() @@ -60,16 +112,55 @@ func TestPostgres(t *testing.T) { if err := DB.First(&result, "id = ?", harumph.ID).Error; err != nil || harumph.Name != "jinzhu1" { t.Errorf("No error should happen, but got %v", err) } + + DB.Migrator().DropTable("log_usage") + + if err := DB.Exec(` +CREATE TABLE public.log_usage ( + log_id bigint NOT NULL +); + +ALTER TABLE public.log_usage ALTER COLUMN log_id ADD GENERATED BY DEFAULT AS IDENTITY ( + SEQUENCE NAME public.log_usage_log_id_seq + START WITH 1 + INCREMENT BY 1 + NO MINVALUE + NO MAXVALUE + CACHE 1 +); + `).Error; err != nil { + t.Fatalf("failed to create table, got error %v", err) + } + + columns, err := DB.Migrator().ColumnTypes("log_usage") + if err != nil { + t.Fatalf("failed to get columns, got error %v", err) + } + + hasLogID := false + for _, column := range columns { + if column.Name() == "log_id" { + hasLogID = true + autoIncrement, ok := column.AutoIncrement() + if !ok || !autoIncrement { + t.Fatalf("column log_id should be auto incrementment") + } + } + } + + if !hasLogID { + t.Fatalf("failed to found column log_id") + } } type Post struct { - ID uuid.UUID `gorm:"primary_key;type:uuid;default:uuid_generate_v4();autoincrement"` + ID uuid.UUID `gorm:"primary_key;type:uuid;default:uuid_generate_v4();"` Title string Categories []*Category `gorm:"Many2Many:post_categories"` } type Category struct { - ID uuid.UUID `gorm:"primary_key;type:uuid;default:uuid_generate_v4();autoincrement"` + ID uuid.UUID `gorm:"primary_key;type:uuid;default:uuid_generate_v4();"` Title string Posts []*Post `gorm:"Many2Many:post_categories"` } @@ -98,3 +189,68 @@ func TestMany2ManyWithDefaultValueUUID(t *testing.T) { t.Errorf("Failed, got error: %v", err) } } + +func TestPostgresOnConstraint(t *testing.T) { + if DB.Dialector.Name() != "postgres" { + t.Skip() + } + + type Thing struct { + gorm.Model + SomeID string + OtherID string + Data string + } + + DB.Migrator().DropTable(&Thing{}) + DB.Migrator().CreateTable(&Thing{}) + if err := DB.Exec("ALTER TABLE things ADD CONSTRAINT some_id_other_id_unique UNIQUE (some_id, other_id)").Error; err != nil { + t.Error(err) + } + + thing := Thing{ + SomeID: "1234", + OtherID: "1234", + Data: "something", + } + + DB.Create(&thing) + + thing2 := Thing{ + SomeID: "1234", + OtherID: "1234", + Data: "something else", + } + + result := DB.Clauses(clause.OnConflict{ + OnConstraint: "some_id_other_id_unique", + UpdateAll: true, + }).Create(&thing2) + if result.Error != nil { + t.Errorf("creating second thing: %v", result.Error) + } + + var things []Thing + if err := DB.Find(&things).Error; err != nil { + t.Errorf("Failed, got error: %v", err) + } + + if len(things) > 1 { + t.Errorf("expected 1 thing got more") + } +} + +type CompanyNew struct { + ID int + Name int +} + +func TestAlterColumnDataType(t *testing.T) { + DB.AutoMigrate(Company{}) + + if err := DB.Table("companies").Migrator().AlterColumn(CompanyNew{}, "name"); err != nil { + t.Fatalf("failed to alter column from string to int, got error %v", err) + } + + DB.AutoMigrate(Company{}) +} diff --git a/tests/preload_test.go b/tests/preload_test.go index cb4343ec..14f94139 100644 --- a/tests/preload_test.go +++ b/tests/preload_test.go @@ -8,6 +8,8 @@ import ( "sync" "testing" + "github.com/stretchr/testify/require" + "gorm.io/gorm" "gorm.io/gorm/clause" . "gorm.io/gorm/utils/tests" @@ -269,3 +271,242 @@ func TestPreloadWithDiffModel(t *testing.T) { CheckUser(t, user, result.User) } + +func TestNestedPreloadWithUnscoped(t *testing.T) { + user := *GetUser("nested_preload", Config{Pets: 1}) + pet := user.Pets[0] + pet.Toy = Toy{Name: "toy_nested_preload_" + strconv.Itoa(1)} + pet.Toy = Toy{Name: "toy_nested_preload_" + strconv.Itoa(2)} + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + var user2 User + DB.Preload("Pets.Toy").Find(&user2, "id = ?", user.ID) + CheckUser(t, user2, user) + + DB.Delete(&pet) + + var user3 User + DB.Preload(clause.Associations+"."+clause.Associations).Find(&user3, "id = ?", user.ID) + if len(user3.Pets) != 0 { + t.Fatalf("User.Pet[0] was deleted and should not exist.") + } + + var user4 *User + DB.Preload("Pets.Toy").Find(&user4, "id = ?", user.ID) + if len(user4.Pets) != 0 { + t.Fatalf("User.Pet[0] was deleted and should not exist.") + } + + var user5 User + DB.Unscoped().Preload(clause.Associations+"."+clause.Associations).Find(&user5, "id = ?", user.ID) + CheckUserUnscoped(t, user5, user) + + var user6 *User + DB.Unscoped().Preload("Pets.Toy").Find(&user6, "id = ?", user.ID) + CheckUserUnscoped(t, *user6, user) +} + +func TestNestedPreloadWithNestedJoin(t *testing.T) { + type ( + Preload struct { + ID uint + Value string + NestedID uint + } + Join struct { + ID uint + Value string + NestedID uint + } + Nested struct { + ID uint + Preloads []*Preload + Join Join + ValueID uint + } + Value struct { + ID uint + Name string + Nested Nested + } + ) + + DB.Migrator().DropTable(&Preload{}, &Join{}, &Nested{}, &Value{}) + DB.Migrator().AutoMigrate(&Preload{}, &Join{}, &Nested{}, &Value{}) + + value := Value{ + Name: "value", + Nested: Nested{ + Preloads: []*Preload{ + {Value: "p1"}, {Value: "p2"}, + }, + Join: Join{Value: "j1"}, + }, + } + if err := DB.Create(&value).Error; err != nil { + t.Errorf("failed to create value, got err: %v", err) + } + + var find1 Value + err := DB.Joins("Nested").Joins("Nested.Join").Preload("Nested.Preloads").First(&find1).Error + if err != nil { + t.Errorf("failed to find value, got err: %v", err) + } + AssertEqual(t, find1, value) + + var find2 Value + // Joins will automatically add Nested queries. + err = DB.Joins("Nested.Join").Preload("Nested.Preloads").First(&find2).Error + if err != nil { + t.Errorf("failed to find value, got err: %v", err) + } + AssertEqual(t, find2, value) + + var finds []Value + err = DB.Joins("Nested.Join").Joins("Nested").Preload("Nested.Preloads").Find(&finds).Error + if err != nil { + t.Errorf("failed to find value, got err: %v", err) + } + require.Len(t, finds, 1) + AssertEqual(t, finds[0], value) +} + +func TestEmbedPreload(t *testing.T) { + type Country struct { + ID int `gorm:"primaryKey"` + Name string + } + type EmbeddedAddress struct { + ID int + Name string + CountryID *int + Country *Country + } + type NestedAddress struct { + EmbeddedAddress + } + type Org struct { + ID int + PostalAddress EmbeddedAddress `gorm:"embedded;embeddedPrefix:postal_address_"` + VisitingAddress EmbeddedAddress `gorm:"embedded;embeddedPrefix:visiting_address_"` + AddressID *int + Address *EmbeddedAddress + NestedAddress NestedAddress `gorm:"embedded;embeddedPrefix:nested_address_"` + } + + DB.Migrator().DropTable(&Org{}, &EmbeddedAddress{}, &Country{}) + DB.AutoMigrate(&Org{}, &EmbeddedAddress{}, &Country{}) + + org := Org{ + PostalAddress: EmbeddedAddress{Name: "a1", Country: &Country{Name: "c1"}}, + VisitingAddress: EmbeddedAddress{Name: "a2", Country: &Country{Name: "c2"}}, + Address: &EmbeddedAddress{Name: "a3", Country: &Country{Name: "c3"}}, + NestedAddress: NestedAddress{ + EmbeddedAddress: EmbeddedAddress{Name: "a4", Country: &Country{Name: "c4"}}, + }, + } + if err := DB.Create(&org).Error; err != nil { + t.Errorf("failed to create org, got err: %v", err) + } + + tests := []struct { + name string + preloads map[string][]interface{} + expect Org + }{ + { + name: "address country", + preloads: map[string][]interface{}{"Address.Country": {}}, + expect: Org{ + ID: org.ID, + PostalAddress: EmbeddedAddress{ + ID: org.PostalAddress.ID, + Name: org.PostalAddress.Name, + CountryID: org.PostalAddress.CountryID, + Country: nil, + }, + VisitingAddress: EmbeddedAddress{ + ID: org.VisitingAddress.ID, + Name: org.VisitingAddress.Name, + CountryID: org.VisitingAddress.CountryID, + Country: nil, + }, + AddressID: org.AddressID, + Address: org.Address, + NestedAddress: NestedAddress{EmbeddedAddress{ + ID: org.NestedAddress.ID, + Name: org.NestedAddress.Name, + CountryID: org.NestedAddress.CountryID, + Country: nil, + }}, + }, + }, { + name: "postal address country", + preloads: map[string][]interface{}{"PostalAddress.Country": {}}, + expect: Org{ + ID: org.ID, + PostalAddress: org.PostalAddress, + VisitingAddress: EmbeddedAddress{ + ID: org.VisitingAddress.ID, + Name: org.VisitingAddress.Name, + CountryID: org.VisitingAddress.CountryID, + Country: nil, + }, + AddressID: org.AddressID, + Address: nil, + NestedAddress: NestedAddress{EmbeddedAddress{ + ID: org.NestedAddress.ID, + Name: org.NestedAddress.Name, + CountryID: org.NestedAddress.CountryID, + Country: nil, + }}, + }, + }, { + name: "nested address country", + preloads: map[string][]interface{}{"NestedAddress.EmbeddedAddress.Country": {}}, + expect: Org{ + ID: org.ID, + PostalAddress: EmbeddedAddress{ + ID: org.PostalAddress.ID, + Name: org.PostalAddress.Name, + CountryID: org.PostalAddress.CountryID, + Country: nil, + }, + VisitingAddress: EmbeddedAddress{ + ID: org.VisitingAddress.ID, + Name: org.VisitingAddress.Name, + CountryID: org.VisitingAddress.CountryID, + Country: nil, + }, + AddressID: org.AddressID, + Address: nil, + NestedAddress: org.NestedAddress, + }, + }, { + name: "associations", + preloads: map[string][]interface{}{ + clause.Associations: {}, + // clause.Associations won’t preload nested associations + "Address.Country": {}, + }, + expect: org, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + actual := Org{} + tx := DB.Where("id = ?", org.ID).Session(&gorm.Session{}) + for name, args := range test.preloads { + tx = tx.Preload(name, args...) + } + if err := tx.Find(&actual).Error; err != nil { + t.Errorf("failed to find org, got err: %v", err) + } + AssertEqual(t, actual, test.expect) + }) + } +} diff --git a/tests/prepared_stmt_test.go b/tests/prepared_stmt_test.go index 8730e547..b86bc3d6 100644 --- a/tests/prepared_stmt_test.go +++ b/tests/prepared_stmt_test.go @@ -2,6 +2,8 @@ package tests_test import ( "context" + "errors" + "sync" "testing" "time" @@ -88,3 +90,80 @@ func TestPreparedStmtFromTransaction(t *testing.T) { } tx2.Commit() } + +func TestPreparedStmtDeadlock(t *testing.T) { + tx, err := OpenTestConnection(&gorm.Config{}) + AssertEqual(t, err, nil) + + sqlDB, _ := tx.DB() + sqlDB.SetMaxOpenConns(1) + + tx = tx.Session(&gorm.Session{PrepareStmt: true}) + + wg := sync.WaitGroup{} + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + user := User{Name: "jinzhu"} + tx.Create(&user) + + var result User + tx.First(&result) + wg.Done() + }() + } + wg.Wait() + + conn, ok := tx.ConnPool.(*gorm.PreparedStmtDB) + AssertEqual(t, ok, true) + AssertEqual(t, len(conn.Stmts), 2) + for _, stmt := range conn.Stmts { + if stmt == nil { + t.Fatalf("stmt cannot bee nil") + } + } + + AssertEqual(t, sqlDB.Stats().InUse, 0) +} + +func TestPreparedStmtInTransaction(t *testing.T) { + user := User{Name: "jinzhu"} + + if err := DB.Transaction(func(tx *gorm.DB) error { + tx.Session(&gorm.Session{PrepareStmt: true}).Create(&user) + return errors.New("test") + }); err == nil { + t.Error(err) + } + + var result User + if err := DB.First(&result, user.ID).Error; err == nil { + t.Errorf("Failed, got error: %v", err) + } +} + +func TestPreparedStmtReset(t *testing.T) { + tx := DB.Session(&gorm.Session{PrepareStmt: true}) + + user := *GetUser("prepared_stmt_reset", Config{}) + tx = tx.Create(&user) + + pdb, ok := tx.ConnPool.(*gorm.PreparedStmtDB) + if !ok { + t.Fatalf("should assign PreparedStatement Manager back to database when using PrepareStmt mode") + } + + pdb.Mux.Lock() + if len(pdb.Stmts) == 0 { + pdb.Mux.Unlock() + t.Fatalf("prepared stmt can not be empty") + } + pdb.Mux.Unlock() + + pdb.Reset() + pdb.Mux.Lock() + defer pdb.Mux.Unlock() + if len(pdb.Stmts) != 0 { + t.Fatalf("prepared stmt should be empty") + } +} diff --git a/tests/query_test.go b/tests/query_test.go index f66cf83a..c0259a14 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -2,6 +2,7 @@ package tests_test import ( "database/sql" + "database/sql/driver" "fmt" "reflect" "regexp" @@ -216,6 +217,30 @@ func TestFind(t *testing.T) { } } + // test array + var models2 [3]User + if err := DB.Where("name in (?)", []string{"find"}).Find(&models2).Error; err != nil { + t.Errorf("errors happened when query find with in clause: %v, length: %v", err, len(models2)) + } else { + for idx, user := range users { + t.Run("FindWithInClause#"+strconv.Itoa(idx+1), func(t *testing.T) { + CheckUser(t, models2[idx], user) + }) + } + } + + // test smaller array + var models3 [2]User + if err := DB.Where("name in (?)", []string{"find"}).Find(&models3).Error; err != nil { + t.Errorf("errors happened when query find with in clause: %v, length: %v", err, len(models3)) + } else { + for idx, user := range users[:2] { + t.Run("FindWithInClause#"+strconv.Itoa(idx+1), func(t *testing.T) { + CheckUser(t, models3[idx], user) + }) + } + } + var none []User if err := DB.Where("name in (?)", []string{}).Find(&none).Error; err != nil || len(none) != 0 { t.Errorf("errors happened when query find with in clause and zero length parameter: %v, length: %v", err, len(none)) @@ -257,7 +282,7 @@ func TestFindInBatches(t *testing.T) { totalBatch int ) - if result := DB.Where("name = ?", users[0].Name).FindInBatches(&results, 2, func(tx *gorm.DB, batch int) error { + if result := DB.Table("users as u").Where("name = ?", users[0].Name).FindInBatches(&results, 2, func(tx *gorm.DB, batch int) error { totalBatch += batch if tx.RowsAffected != 2 { @@ -273,7 +298,7 @@ func TestFindInBatches(t *testing.T) { } if err := tx.Save(results).Error; err != nil { - t.Errorf("failed to save users, got error %v", err) + t.Fatalf("failed to save users, got error %v", err) } return nil @@ -384,6 +409,13 @@ func TestFindInBatchesWithError(t *testing.T) { if totalBatch != 0 { t.Fatalf("incorrect total batch, expected: %v, got: %v", 0, totalBatch) } + + if result := DB.Omit("id").Where("name = ?", users[0].Name).FindInBatches(&results, 2, func(tx *gorm.DB, batch int) error { + totalBatch += batch + return nil + }); result.Error != gorm.ErrPrimaryKeyRequired { + t.Fatal("expected errors to have occurred, but nothing happened") + } } func TestFillSmallerStruct(t *testing.T) { @@ -522,6 +554,11 @@ func TestNot(t *testing.T) { if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*users.*..*name.* <> .+ AND .*users.*..*age.* <> .+").MatchString(result.Statement.SQL.String()) { t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) } + + result = dryDB.Not(DB.Where("manager IS NULL").Where("age >= ?", 20)).Find(&User{}) + if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE NOT \\(manager IS NULL AND age >= .+\\) AND .users.\\..deleted_at. IS NULL").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) + } } func TestNotWithAllFields(t *testing.T) { @@ -627,6 +664,18 @@ func TestOrWithAllFields(t *testing.T) { } } +type Int64 int64 + +func (v Int64) Value() (driver.Value, error) { + return v - 1, nil +} + +func (f *Int64) Scan(v interface{}) error { + y := v.(int64) + *f = Int64(y + 1) + return nil +} + func TestPluck(t *testing.T) { users := []*User{ GetUser("pluck-user1", Config{}), @@ -654,6 +703,11 @@ func TestPluck(t *testing.T) { t.Errorf("got error when pluck id: %v", err) } + var ids2 []Int64 + if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Pluck("id", &ids2).Error; err != nil { + t.Errorf("got error when pluck id: %v", err) + } + for idx, name := range names { if name != users[idx].Name { t.Errorf("Unexpected result on pluck name, got %+v", names) @@ -666,6 +720,12 @@ func TestPluck(t *testing.T) { } } + for idx, id := range ids2 { + if int(id) != int(users[idx].ID+1) { + t.Errorf("Unexpected result on pluck id, got %+v", ids) + } + } + var times []time.Time if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Pluck("created_at", ×).Error; err != nil { t.Errorf("got error when pluck time: %v", err) @@ -1063,12 +1123,12 @@ func TestSearchWithStruct(t *testing.T) { } result = dryRunDB.Where(User{Name: "jinzhu", Age: 18}).Find(&User{}) - if !regexp.MustCompile(`WHERE .users.\..name. = .{1,3} AND .users.\..age. = .{1,3} AND .users.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) { + if !regexp.MustCompile(`WHERE \(.users.\..name. = .{1,3} AND .users.\..age. = .{1,3}\) AND .users.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) { t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String()) } result = dryRunDB.Where(User{Name: "jinzhu"}, "name", "Age").Find(&User{}) - if !regexp.MustCompile(`WHERE .users.\..name. = .{1,3} AND .users.\..age. = .{1,3} AND .users.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) { + if !regexp.MustCompile(`WHERE \(.users.\..name. = .{1,3} AND .users.\..age. = .{1,3}\) AND .users.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) { t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String()) } @@ -1258,3 +1318,113 @@ func TestQueryScannerWithSingleColumn(t *testing.T) { AssertEqual(t, result2.data, 20) } + +func TestQueryResetNullValue(t *testing.T) { + type QueryResetItem struct { + ID string `gorm:"type:varchar(5)"` + Name string + } + + type QueryResetNullValue struct { + ID int + Name string `gorm:"default:NULL"` + Flag bool `gorm:"default:NULL"` + Number1 int64 `gorm:"default:NULL"` + Number2 uint64 `gorm:"default:NULL"` + Number3 float64 `gorm:"default:NULL"` + Now *time.Time `gorm:"defalut:NULL"` + Item1Id string + Item1 *QueryResetItem `gorm:"references:ID"` + Item2Id string + Item2 *QueryResetItem `gorm:"references:ID"` + } + + DB.Migrator().DropTable(&QueryResetNullValue{}, &QueryResetItem{}) + DB.AutoMigrate(&QueryResetNullValue{}, &QueryResetItem{}) + + now := time.Now() + q1 := QueryResetNullValue{ + Name: "name", + Flag: true, + Number1: 100, + Number2: 200, + Number3: 300.1, + Now: &now, + Item1: &QueryResetItem{ + ID: "u_1_1", + Name: "item_1_1", + }, + Item2: &QueryResetItem{ + ID: "u_1_2", + Name: "item_1_2", + }, + } + + q2 := QueryResetNullValue{ + Item1: &QueryResetItem{ + ID: "u_2_1", + Name: "item_2_1", + }, + Item2: &QueryResetItem{ + ID: "u_2_2", + Name: "item_2_2", + }, + } + + var err error + err = DB.Create(&q1).Error + if err != nil { + t.Errorf("failed to create:%v", err) + } + + err = DB.Create(&q2).Error + if err != nil { + t.Errorf("failed to create:%v", err) + } + + var qs []QueryResetNullValue + err = DB.Joins("Item1").Joins("Item2").Find(&qs).Error + if err != nil { + t.Errorf("failed to find:%v", err) + } + + if len(qs) != 2 { + t.Fatalf("find count not equal:%d", len(qs)) + } + + AssertEqual(t, q1, qs[0]) + AssertEqual(t, q2, qs[1]) +} + +func TestQueryError(t *testing.T) { + type P struct{} + var p1 P + err := DB.Take(&p1, 1).Error + AssertEqual(t, err, gorm.ErrModelAccessibleFieldsRequired) + + var p2 interface{} + + err = DB.Table("ps").Clauses(clause.Eq{Column: clause.Column{ + Table: clause.CurrentTable, Name: clause.PrimaryKey, + }, Value: 1}).Scan(&p2).Error + AssertEqual(t, err, gorm.ErrModelValueRequired) +} + +func TestQueryScanToArray(t *testing.T) { + err := DB.Create(&User{Name: "testname1", Age: 10}).Error + if err != nil { + t.Fatal(err) + } + + users := [2]*User{{Name: "1"}, {Name: "2"}} + err = DB.Model(&User{}).Where("name = ?", "testname1").Find(&users).Error + if err != nil { + t.Fatal(err) + } + if users[0] == nil || users[0].Name != "testname1" { + t.Error("users[0] not covere") + } + if users[1] != nil { + t.Error("users[1] should be empty") + } +} diff --git a/tests/scan_test.go b/tests/scan_test.go index 425c0a29..6f2e9f54 100644 --- a/tests/scan_test.go +++ b/tests/scan_test.go @@ -214,4 +214,29 @@ func TestScanToEmbedded(t *testing.T) { if !addressMatched { t.Errorf("Failed, no address matched") } + + personDupField := Person{ID: person1.ID} + if err := DB.Select("people.id, people.*"). + First(&personDupField).Error; err != nil { + t.Errorf("Failed to run join query, got error: %v", err) + } + AssertEqual(t, person1, personDupField) + + user := User{ + Name: "TestScanToEmbedded_1", + Manager: &User{ + Name: "TestScanToEmbedded_1_m1", + Manager: &User{Name: "TestScanToEmbedded_1_m1_m1"}, + }, + } + DB.Create(&user) + + type UserScan struct { + ID uint + Name string + ManagerID *uint + } + var user2 UserScan + err := DB.Raw("SELECT * FROM users INNER JOIN users Manager ON users.manager_id = Manager.id WHERE users.id = ?", user.ID).Scan(&user2).Error + AssertEqual(t, err, nil) } diff --git a/tests/scanner_valuer_test.go b/tests/scanner_valuer_test.go index 14121699..472434b4 100644 --- a/tests/scanner_valuer_test.go +++ b/tests/scanner_valuer_test.go @@ -170,10 +170,10 @@ func (data *EncryptedData) Scan(value interface{}) error { return errors.New("Too short") } - *data = b[3:] + *data = append((*data)[0:], b[3:]...) return nil } else if s, ok := value.(string); ok { - *data = []byte(s)[3:] + *data = []byte(s[3:]) return nil } diff --git a/tests/scopes_test.go b/tests/scopes_test.go index ab3807ea..84aeb990 100644 --- a/tests/scopes_test.go +++ b/tests/scopes_test.go @@ -72,3 +72,58 @@ func TestScopes(t *testing.T) { t.Errorf("select max(id)") } } + +func TestComplexScopes(t *testing.T) { + tests := []struct { + name string + queryFn func(tx *gorm.DB) *gorm.DB + expected string + }{ + { + name: "depth_1", + queryFn: func(tx *gorm.DB) *gorm.DB { + return tx.Scopes( + func(d *gorm.DB) *gorm.DB { return d.Where("a = 1") }, + func(d *gorm.DB) *gorm.DB { + return d.Where(DB.Or("b = 2").Or("c = 3")) + }, + ).Find(&Language{}) + }, + expected: `SELECT * FROM "languages" WHERE a = 1 AND (b = 2 OR c = 3)`, + }, { + name: "depth_1_pre_cond", + queryFn: func(tx *gorm.DB) *gorm.DB { + return tx.Where("z = 0").Scopes( + func(d *gorm.DB) *gorm.DB { return d.Where("a = 1") }, + func(d *gorm.DB) *gorm.DB { + return d.Or(DB.Where("b = 2").Or("c = 3")) + }, + ).Find(&Language{}) + }, + expected: `SELECT * FROM "languages" WHERE z = 0 AND a = 1 OR (b = 2 OR c = 3)`, + }, { + name: "depth_2", + queryFn: func(tx *gorm.DB) *gorm.DB { + return tx.Scopes( + func(d *gorm.DB) *gorm.DB { return d.Model(&Language{}) }, + func(d *gorm.DB) *gorm.DB { + return d. + Or(DB.Scopes( + func(d *gorm.DB) *gorm.DB { return d.Where("a = 1") }, + func(d *gorm.DB) *gorm.DB { return d.Where("b = 2") }, + )). + Or("c = 3") + }, + func(d *gorm.DB) *gorm.DB { return d.Where("d = 4") }, + ).Find(&Language{}) + }, + expected: `SELECT * FROM "languages" WHERE d = 4 OR c = 3 OR (a = 1 AND b = 2)`, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + assertEqualSQL(t, test.expected, DB.ToSQL(test.queryFn)) + }) + } +} diff --git a/tests/serializer_test.go b/tests/serializer_test.go index ee14841a..f1b8a336 100644 --- a/tests/serializer_test.go +++ b/tests/serializer_test.go @@ -16,13 +16,40 @@ import ( type SerializerStruct struct { gorm.Model - Name []byte `gorm:"json"` - Roles Roles `gorm:"serializer:json"` - Contracts map[string]interface{} `gorm:"serializer:json"` - JobInfo Job `gorm:"type:bytes;serializer:gob"` - CreatedTime int64 `gorm:"serializer:unixtime;type:time"` // store time in db, use int as field type - UpdatedTime *int64 `gorm:"serializer:unixtime;type:time"` // store time in db, use int as field type - EncryptedString EncryptedString + Name []byte `gorm:"json"` + Roles Roles `gorm:"serializer:json"` + Roles2 *Roles `gorm:"serializer:json"` + Roles3 *Roles `gorm:"serializer:json;not null"` + Contracts map[string]interface{} `gorm:"serializer:json"` + JobInfo Job `gorm:"type:bytes;serializer:gob"` + CreatedTime int64 `gorm:"serializer:unixtime;type:datetime"` // store time in db, use int as field type + UpdatedTime *int64 `gorm:"serializer:unixtime;type:datetime"` // store time in db, use int as field type + CustomSerializerString string `gorm:"serializer:custom"` + EncryptedString EncryptedString +} + +type SerializerPostgresStruct struct { + gorm.Model + Name []byte `gorm:"json"` + Roles Roles `gorm:"serializer:json"` + Roles2 *Roles `gorm:"serializer:json"` + Roles3 *Roles `gorm:"serializer:json;not null"` + Contracts map[string]interface{} `gorm:"serializer:json"` + JobInfo Job `gorm:"type:bytes;serializer:gob"` + CreatedTime int64 `gorm:"serializer:unixtime;type:timestamptz"` // store time in db, use int as field type + UpdatedTime *int64 `gorm:"serializer:unixtime;type:timestamptz"` // store time in db, use int as field type + CustomSerializerString string `gorm:"serializer:custom"` + EncryptedString EncryptedString +} + +func (*SerializerPostgresStruct) TableName() string { return "serializer_structs" } + +func adaptorSerializerModel(s *SerializerStruct) interface{} { + if DB.Dialector.Name() == "postgres" { + sps := SerializerPostgresStruct(*s) + return &sps + } + return s } type Roles []string @@ -52,9 +79,34 @@ func (es EncryptedString) Value(ctx context.Context, field *schema.Field, dst re return "hello" + string(es), nil } +type CustomSerializer struct { + prefix []byte +} + +func NewCustomSerializer(prefix string) *CustomSerializer { + return &CustomSerializer{prefix: []byte(prefix)} +} + +func (c *CustomSerializer) Scan(ctx context.Context, field *schema.Field, dst reflect.Value, dbValue interface{}) (err error) { + switch value := dbValue.(type) { + case []byte: + err = field.Set(ctx, dst, bytes.TrimPrefix(value, c.prefix)) + case string: + err = field.Set(ctx, dst, strings.TrimPrefix(value, string(c.prefix))) + default: + err = fmt.Errorf("unsupported data %#v", dbValue) + } + return err +} + +func (c *CustomSerializer) Value(ctx context.Context, field *schema.Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) { + return fmt.Sprintf("%s%s", c.prefix, fieldValue), nil +} + func TestSerializer(t *testing.T) { - DB.Migrator().DropTable(&SerializerStruct{}) - if err := DB.Migrator().AutoMigrate(&SerializerStruct{}); err != nil { + schema.RegisterSerializer("custom", NewCustomSerializer("hello")) + DB.Migrator().DropTable(adaptorSerializerModel(&SerializerStruct{})) + if err := DB.Migrator().AutoMigrate(adaptorSerializerModel(&SerializerStruct{})); err != nil { t.Fatalf("no error should happen when migrate scanner, valuer struct, got error %v", err) } @@ -74,12 +126,42 @@ func TestSerializer(t *testing.T) { Location: "Kenmawr", IsIntern: false, }, + CustomSerializerString: "world", } if err := DB.Create(&data).Error; err != nil { t.Fatalf("failed to create data, got error %v", err) } + var result SerializerStruct + if err := DB.Where("roles2 IS NULL AND roles3 = ?", "").First(&result, data.ID).Error; err != nil { + t.Fatalf("failed to query data, got error %v", err) + } + + AssertEqual(t, result, data) + + if err := DB.Model(&result).Update("roles", "").Error; err != nil { + t.Fatalf("failed to update data's roles, got error %v", err) + } + + if err := DB.First(&result, data.ID).Error; err != nil { + t.Fatalf("failed to query data, got error %v", err) + } +} + +func TestSerializerZeroValue(t *testing.T) { + schema.RegisterSerializer("custom", NewCustomSerializer("hello")) + DB.Migrator().DropTable(adaptorSerializerModel(&SerializerStruct{})) + if err := DB.Migrator().AutoMigrate(adaptorSerializerModel(&SerializerStruct{})); err != nil { + t.Fatalf("no error should happen when migrate scanner, valuer struct, got error %v", err) + } + + data := SerializerStruct{} + + if err := DB.Create(&data).Error; err != nil { + t.Fatalf("failed to create data, got error %v", err) + } + var result SerializerStruct if err := DB.First(&result, data.ID).Error; err != nil { t.Fatalf("failed to query data, got error %v", err) @@ -87,11 +169,19 @@ func TestSerializer(t *testing.T) { AssertEqual(t, result, data) + if err := DB.Model(&result).Update("roles", "").Error; err != nil { + t.Fatalf("failed to update data's roles, got error %v", err) + } + + if err := DB.First(&result, data.ID).Error; err != nil { + t.Fatalf("failed to query data, got error %v", err) + } } func TestSerializerAssignFirstOrCreate(t *testing.T) { - DB.Migrator().DropTable(&SerializerStruct{}) - if err := DB.Migrator().AutoMigrate(&SerializerStruct{}); err != nil { + schema.RegisterSerializer("custom", NewCustomSerializer("hello")) + DB.Migrator().DropTable(adaptorSerializerModel(&SerializerStruct{})) + if err := DB.Migrator().AutoMigrate(adaptorSerializerModel(&SerializerStruct{})); err != nil { t.Fatalf("no error should happen when migrate scanner, valuer struct, got error %v", err) } @@ -109,6 +199,7 @@ func TestSerializerAssignFirstOrCreate(t *testing.T) { Location: "Shadyside", IsIntern: false, }, + CustomSerializerString: "world", } // first time insert record @@ -123,7 +214,7 @@ func TestSerializerAssignFirstOrCreate(t *testing.T) { } AssertEqual(t, result, out) - //update record + // update record data.Roles = append(data.Roles, "r3") data.JobInfo.Location = "Gates Hillman Complex" if err := DB.Assign(data).FirstOrCreate(&out).Error; err != nil { diff --git a/tests/soft_delete_test.go b/tests/soft_delete_test.go index 9ac8da10..179ae426 100644 --- a/tests/soft_delete_test.go +++ b/tests/soft_delete_test.go @@ -7,6 +7,7 @@ import ( "regexp" "testing" + "github.com/jinzhu/now" "gorm.io/gorm" . "gorm.io/gorm/utils/tests" ) @@ -39,6 +40,11 @@ func TestSoftDelete(t *testing.T) { t.Fatalf("invalid sql generated, got %v", sql) } + sql = DB.Session(&gorm.Session{DryRun: true}).Table("user u").Select("name").Find(&User{}).Statement.SQL.String() + if !regexp.MustCompile(`SELECT .name. FROM user u WHERE .u.\..deleted_at. IS NULL`).MatchString(sql) { + t.Errorf("Table with escape character, got %v", sql) + } + if DB.First(&User{}, "name = ?", user.Name).Error == nil { t.Errorf("Can't find a soft deleted record") } @@ -93,3 +99,71 @@ func TestDeletedAtOneOr(t *testing.T) { t.Fatalf("invalid sql generated, got %v", actualSQL) } } + +func TestSoftDeleteZeroValue(t *testing.T) { + type SoftDeleteBook struct { + ID uint + Name string + Pages uint + DeletedAt gorm.DeletedAt `gorm:"zeroValue:'1970-01-01 00:00:01'"` + } + DB.Migrator().DropTable(&SoftDeleteBook{}) + if err := DB.AutoMigrate(&SoftDeleteBook{}); err != nil { + t.Fatalf("failed to auto migrate soft delete table") + } + + book := SoftDeleteBook{Name: "jinzhu", Pages: 10} + DB.Save(&book) + + var count int64 + if DB.Model(&SoftDeleteBook{}).Where("name = ?", book.Name).Count(&count).Error != nil || count != 1 { + t.Errorf("Count soft deleted record, expects: %v, got: %v", 1, count) + } + + var pages uint + if DB.Model(&SoftDeleteBook{}).Select("pages").Where("name = ?", book.Name).Scan(&pages).Error != nil || pages != book.Pages { + t.Errorf("Pages soft deleted record, expects: %v, got: %v", 0, pages) + } + + if err := DB.Delete(&book).Error; err != nil { + t.Fatalf("No error should happen when soft delete user, but got %v", err) + } + + zeroTime, _ := now.Parse("1970-01-01 00:00:01") + if book.DeletedAt.Time.Equal(zeroTime) { + t.Errorf("book's deleted at should not be zero, DeletedAt: %v", book.DeletedAt) + } + + if DB.First(&SoftDeleteBook{}, "name = ?", book.Name).Error == nil { + t.Errorf("Can't find a soft deleted record") + } + + count = 0 + if DB.Model(&SoftDeleteBook{}).Where("name = ?", book.Name).Count(&count).Error != nil || count != 0 { + t.Errorf("Count soft deleted record, expects: %v, got: %v", 0, count) + } + + pages = 0 + if err := DB.Model(&SoftDeleteBook{}).Select("pages").Where("name = ?", book.Name).Scan(&pages).Error; err != nil || pages != 0 { + t.Fatalf("Age soft deleted record, expects: %v, got: %v, err %v", 0, pages, err) + } + + if err := DB.Unscoped().First(&SoftDeleteBook{}, "name = ?", book.Name).Error; err != nil { + t.Errorf("Should find soft deleted record with Unscoped, but got err %s", err) + } + + count = 0 + if DB.Unscoped().Model(&SoftDeleteBook{}).Where("name = ?", book.Name).Count(&count).Error != nil || count != 1 { + t.Errorf("Count soft deleted record, expects: %v, count: %v", 1, count) + } + + pages = 0 + if DB.Unscoped().Model(&SoftDeleteBook{}).Select("pages").Where("name = ?", book.Name).Scan(&pages).Error != nil || pages != book.Pages { + t.Errorf("Age soft deleted record, expects: %v, got: %v", 0, pages) + } + + DB.Unscoped().Delete(&book) + if err := DB.Unscoped().First(&SoftDeleteBook{}, "name = ?", book.Name).Error; !errors.Is(err, gorm.ErrRecordNotFound) { + t.Errorf("Can't find permanently deleted record") + } +} diff --git a/tests/sql_builder_test.go b/tests/sql_builder_test.go index a9b920dc..0c204db4 100644 --- a/tests/sql_builder_test.go +++ b/tests/sql_builder_test.go @@ -29,7 +29,7 @@ func TestRow(t *testing.T) { } table := "gorm.users" - if DB.Dialector.Name() != "mysql" { + if DB.Dialector.Name() != "mysql" || isTiDB() { table = "users" // other databases doesn't support select with `database.table` } @@ -367,7 +367,7 @@ func TestToSQL(t *testing.T) { t.Skip("Skip SQL Server for this test, because it too difference with other dialects.") } - date, _ := time.Parse("2006-01-02", "2021-10-18") + date, _ := time.ParseInLocation("2006-01-02", "2021-10-18", time.Local) // find sql := DB.ToSQL(func(tx *gorm.DB) *gorm.DB { @@ -388,7 +388,7 @@ func TestToSQL(t *testing.T) { sql = DB.ToSQL(func(tx *gorm.DB) *gorm.DB { return tx.Model(&User{}).Where(&User{Name: "foo", Age: 20}).Limit(10).Offset(5).Order("name ASC").First(&User{}) }) - assertEqualSQL(t, `SELECT * FROM "users" WHERE "users"."name" = 'foo' AND "users"."age" = 20 AND "users"."deleted_at" IS NULL ORDER BY name ASC,"users"."id" LIMIT 1 OFFSET 5`, sql) + assertEqualSQL(t, `SELECT * FROM "users" WHERE ("users"."name" = 'foo' AND "users"."age" = 20) AND "users"."deleted_at" IS NULL ORDER BY name ASC,"users"."id" LIMIT 1 OFFSET 5`, sql) // last and unscoped sql = DB.ToSQL(func(tx *gorm.DB) *gorm.DB { @@ -445,6 +445,14 @@ func TestToSQL(t *testing.T) { if DB.Statement.DryRun || DB.DryRun { t.Fatal("Failed expect DB.DryRun and DB.Statement.ToSQL to be false") } + + // UpdateColumns + sql = DB.ToSQL(func(tx *gorm.DB) *gorm.DB { + return tx.Raw("SELECT * FROM users ?", clause.OrderBy{ + Columns: []clause.OrderByColumn{{Column: clause.Column{Name: "id", Raw: true}, Desc: true}}, + }) + }) + assertEqualSQL(t, `SELECT * FROM users ORDER BY id DESC`, sql) } // assertEqualSQL for assert that the sql is equal, this method will ignore quote, and dialect specials. diff --git a/tests/table_test.go b/tests/table_test.go index 0289b7b8..0d44a15b 100644 --- a/tests/table_test.go +++ b/tests/table_test.go @@ -2,9 +2,13 @@ package tests_test import ( "regexp" + "sync" "testing" + "gorm.io/driver/postgres" "gorm.io/gorm" + "gorm.io/gorm/schema" + "gorm.io/gorm/utils/tests" . "gorm.io/gorm/utils/tests" ) @@ -145,3 +149,113 @@ func TestTableWithAllFields(t *testing.T) { AssertEqual(t, r.Statement.Vars, []interface{}{2, 4, 1, 3}) } + +type UserWithTableNamer struct { + gorm.Model + Name string +} + +func (UserWithTableNamer) TableName(namer schema.Namer) string { + return namer.TableName("user") +} + +func TestTableWithNamer(t *testing.T) { + db, _ := gorm.Open(tests.DummyDialector{}, &gorm.Config{ + NamingStrategy: schema.NamingStrategy{ + TablePrefix: "t_", + }, + }) + + sql := db.ToSQL(func(tx *gorm.DB) *gorm.DB { + return tx.Model(&UserWithTableNamer{}).Find(&UserWithTableNamer{}) + }) + + if !regexp.MustCompile("SELECT \\* FROM `t_users`").MatchString(sql) { + t.Errorf("Table with namer, got %v", sql) + } +} + +func TestPostgresTableWithIdentifierLength(t *testing.T) { + if DB.Dialector.Name() != "postgres" { + return + } + + type LongString struct { + ThisIsAVeryVeryVeryVeryVeryVeryVeryVeryVeryLongString string `gorm:"unique"` + } + + t.Run("default", func(t *testing.T) { + db, _ := gorm.Open(postgres.Open(postgresDSN), &gorm.Config{}) + user, err := schema.Parse(&LongString{}, &sync.Map{}, db.Config.NamingStrategy) + if err != nil { + t.Fatalf("failed to parse user unique, got error %v", err) + } + + constraints := user.ParseUniqueConstraints() + if len(constraints) != 1 { + t.Fatalf("failed to find unique constraint, got %v", constraints) + } + + for key := range constraints { + if len(key) != 63 { + t.Errorf("failed to find unique constraint, got %v", constraints) + } + } + }) + + t.Run("naming strategy", func(t *testing.T) { + db, _ := gorm.Open(postgres.Open(postgresDSN), &gorm.Config{ + NamingStrategy: schema.NamingStrategy{}, + }) + + user, err := schema.Parse(&LongString{}, &sync.Map{}, db.Config.NamingStrategy) + if err != nil { + t.Fatalf("failed to parse user unique, got error %v", err) + } + + constraints := user.ParseUniqueConstraints() + if len(constraints) != 1 { + t.Fatalf("failed to find unique constraint, got %v", constraints) + } + + for key := range constraints { + if len(key) != 63 { + t.Errorf("failed to find unique constraint, got %v", constraints) + } + } + }) + + t.Run("namer", func(t *testing.T) { + uname := "custom_unique_name" + db, _ := gorm.Open(postgres.Open(postgresDSN), &gorm.Config{ + NamingStrategy: mockUniqueNamingStrategy{ + UName: uname, + }, + }) + + user, err := schema.Parse(&LongString{}, &sync.Map{}, db.Config.NamingStrategy) + if err != nil { + t.Fatalf("failed to parse user unique, got error %v", err) + } + + constraints := user.ParseUniqueConstraints() + if len(constraints) != 1 { + t.Fatalf("failed to find unique constraint, got %v", constraints) + } + + for key := range constraints { + if key != uname { + t.Errorf("failed to find unique constraint, got %v", constraints) + } + } + }) +} + +type mockUniqueNamingStrategy struct { + UName string + schema.NamingStrategy +} + +func (a mockUniqueNamingStrategy) UniqueName(table, column string) string { + return a.UName +} diff --git a/tests/tests_all.sh b/tests/tests_all.sh index 5b9bae97..ee9e7675 100755 --- a/tests/tests_all.sh +++ b/tests/tests_all.sh @@ -1,6 +1,6 @@ #!/bin/bash -e -dialects=("sqlite" "mysql" "postgres" "sqlserver") +dialects=("sqlite" "mysql" "postgres" "sqlserver" "tidb") if [[ $(pwd) == *"gorm/tests"* ]]; then cd .. diff --git a/tests/tests_test.go b/tests/tests_test.go index 08f4f193..a127734e 100644 --- a/tests/tests_test.go +++ b/tests/tests_test.go @@ -17,10 +17,16 @@ import ( ) var DB *gorm.DB +var ( + mysqlDSN = "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True&loc=Local" + postgresDSN = "user=gorm password=gorm dbname=gorm host=localhost port=9920 sslmode=disable TimeZone=Asia/Shanghai" + sqlserverDSN = "sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm" + tidbDSN = "root:@tcp(localhost:9940)/test?charset=utf8&parseTime=True&loc=Local" +) func init() { var err error - if DB, err = OpenTestConnection(); err != nil { + if DB, err = OpenTestConnection(&gorm.Config{}); err != nil { log.Printf("failed to connect database, got error %v", err) os.Exit(1) } else { @@ -37,30 +43,27 @@ func init() { } RunMigrations() - if DB.Dialector.Name() == "sqlite" { - DB.Exec("PRAGMA foreign_keys = ON") - } } } -func OpenTestConnection() (db *gorm.DB, err error) { +func OpenTestConnection(cfg *gorm.Config) (db *gorm.DB, err error) { dbDSN := os.Getenv("GORM_DSN") switch os.Getenv("GORM_DIALECT") { case "mysql": log.Println("testing mysql...") if dbDSN == "" { - dbDSN = "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True&loc=Local" + dbDSN = mysqlDSN } - db, err = gorm.Open(mysql.Open(dbDSN), &gorm.Config{}) + db, err = gorm.Open(mysql.Open(dbDSN), cfg) case "postgres": log.Println("testing postgres...") if dbDSN == "" { - dbDSN = "user=gorm password=gorm dbname=gorm host=localhost port=9920 sslmode=disable TimeZone=Asia/Shanghai" + dbDSN = postgresDSN } db, err = gorm.Open(postgres.New(postgres.Config{ DSN: dbDSN, PreferSimpleProtocol: true, - }), &gorm.Config{}) + }), cfg) case "sqlserver": // go install github.com/microsoft/go-sqlcmd/cmd/sqlcmd@latest // SQLCMDPASSWORD=LoremIpsum86 sqlcmd -U sa -S localhost:9930 @@ -72,12 +75,21 @@ func OpenTestConnection() (db *gorm.DB, err error) { // GO log.Println("testing sqlserver...") if dbDSN == "" { - dbDSN = "sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm" + dbDSN = sqlserverDSN } - db, err = gorm.Open(sqlserver.Open(dbDSN), &gorm.Config{}) + db, err = gorm.Open(sqlserver.Open(dbDSN), cfg) + case "tidb": + log.Println("testing tidb...") + if dbDSN == "" { + dbDSN = tidbDSN + } + db, err = gorm.Open(mysql.Open(dbDSN), cfg) default: log.Println("testing sqlite3...") - db, err = gorm.Open(sqlite.Open(filepath.Join(os.TempDir(), "gorm.db")), &gorm.Config{}) + db, err = gorm.Open(sqlite.Open(filepath.Join(os.TempDir(), "gorm.db")), cfg) + if err == nil { + db.Exec("PRAGMA foreign_keys = ON") + } } if err != nil { @@ -95,7 +107,7 @@ func OpenTestConnection() (db *gorm.DB, err error) { func RunMigrations() { var err error - allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Coupon{}, &CouponProduct{}, &Order{}, &Parent{}, &Child{}} + allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Coupon{}, &CouponProduct{}, &Order{}, &Parent{}, &Child{}, &Tools{}} rand.Seed(time.Now().UnixNano()) rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] }) diff --git a/tests/tracer_test.go b/tests/tracer_test.go new file mode 100644 index 00000000..3e9a4052 --- /dev/null +++ b/tests/tracer_test.go @@ -0,0 +1,34 @@ +package tests_test + +import ( + "context" + "time" + + "gorm.io/gorm/logger" +) + +type Tracer struct { + Logger logger.Interface + Test func(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) +} + +func (S Tracer) LogMode(level logger.LogLevel) logger.Interface { + return S.Logger.LogMode(level) +} + +func (S Tracer) Info(ctx context.Context, s string, i ...interface{}) { + S.Logger.Info(ctx, s, i...) +} + +func (S Tracer) Warn(ctx context.Context, s string, i ...interface{}) { + S.Logger.Warn(ctx, s, i...) +} + +func (S Tracer) Error(ctx context.Context, s string, i ...interface{}) { + S.Logger.Error(ctx, s, i...) +} + +func (S Tracer) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) { + S.Logger.Trace(ctx, begin, fc, err) + S.Test(ctx, begin, fc, err) +} diff --git a/tests/transaction_test.go b/tests/transaction_test.go index 4e4b6149..126ccb23 100644 --- a/tests/transaction_test.go +++ b/tests/transaction_test.go @@ -57,6 +57,19 @@ func TestTransaction(t *testing.T) { if err := DB.First(&User{}, "name = ?", "transaction-2").Error; err != nil { t.Fatalf("Should be able to find committed record, but got %v", err) } + + t.Run("this is test nested transaction and prepareStmt coexist case", func(t *testing.T) { + // enable prepare statement + tx3 := DB.Session(&gorm.Session{PrepareStmt: true}) + if err := tx3.Transaction(func(tx4 *gorm.DB) error { + // nested transaction + return tx4.Transaction(func(tx5 *gorm.DB) error { + return tx5.First(&User{}, "name = ?", "transaction-2").Error + }) + }); err != nil { + t.Fatalf("prepare statement and nested transcation coexist" + err.Error()) + } + }) } func TestCancelTransaction(t *testing.T) { @@ -102,7 +115,7 @@ func TestTransactionWithBlock(t *testing.T) { return errors.New("the error message") }) - if err.Error() != "the error message" { + if err != nil && err.Error() != "the error message" { t.Fatalf("Transaction return error will equal the block returns error") } @@ -348,7 +361,7 @@ func TestDisabledNestedTransaction(t *testing.T) { } func TestTransactionOnClosedConn(t *testing.T) { - DB, err := OpenTestConnection() + DB, err := OpenTestConnection(&gorm.Config{}) if err != nil { t.Fatalf("failed to connect database, got error %v", err) } @@ -367,3 +380,33 @@ func TestTransactionOnClosedConn(t *testing.T) { t.Errorf("should returns error when commit with closed conn, got error %v", err) } } + +func TestTransactionWithHooks(t *testing.T) { + user := GetUser("tTestTransactionWithHooks", Config{Account: true}) + DB.Create(&user) + + var err error + err = DB.Transaction(func(tx *gorm.DB) error { + return tx.Model(&User{}).Limit(1).Transaction(func(tx2 *gorm.DB) error { + return tx2.Scan(&User{}).Error + }) + }) + + if err != nil { + t.Error(err) + } + + // method with hooks + err = DB.Transaction(func(tx1 *gorm.DB) error { + // callMethod do + tx2 := tx1.Find(&User{}).Session(&gorm.Session{NewDB: true}) + // trx in hooks + return tx2.Transaction(func(tx3 *gorm.DB) error { + return tx3.Where("user_id", user.ID).Delete(&Account{}).Error + }) + }) + + if err != nil { + t.Error(err) + } +} diff --git a/tests/update_belongs_to_test.go b/tests/update_belongs_to_test.go index 8fe0f289..4e94cfd5 100644 --- a/tests/update_belongs_to_test.go +++ b/tests/update_belongs_to_test.go @@ -41,4 +41,19 @@ func TestUpdateBelongsTo(t *testing.T) { var user4 User DB.Preload("Company").Preload("Manager").Find(&user4, "id = ?", user.ID) CheckUser(t, user4, user) + + user.Company.Name += "new2" + user.Manager.Name += "new2" + if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Select("`Company`").Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user5 User + DB.Preload("Company").Preload("Manager").Find(&user5, "id = ?", user.ID) + if user5.Manager.Name != user4.Manager.Name { + t.Errorf("should not update user's manager") + } else { + user.Manager.Name = user4.Manager.Name + } + CheckUser(t, user, user5) } diff --git a/tests/update_has_one_test.go b/tests/update_has_one_test.go index c926fbcf..40af6ae7 100644 --- a/tests/update_has_one_test.go +++ b/tests/update_has_one_test.go @@ -90,8 +90,9 @@ func TestUpdateHasOne(t *testing.T) { t.Run("Restriction", func(t *testing.T) { type CustomizeAccount struct { gorm.Model - UserID sql.NullInt64 - Number string `gorm:"<-:create"` + UserID sql.NullInt64 + Number string `gorm:"<-:create"` + Number2 string } type CustomizeUser struct { @@ -114,7 +115,8 @@ func TestUpdateHasOne(t *testing.T) { cusUser := CustomizeUser{ Name: "update-has-one-associations", Account: CustomizeAccount{ - Number: number, + Number: number, + Number2: number, }, } @@ -122,6 +124,7 @@ func TestUpdateHasOne(t *testing.T) { t.Fatalf("errors happened when create: %v", err) } cusUser.Account.Number += "-update" + cusUser.Account.Number2 += "-update" if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Updates(&cusUser).Error; err != nil { t.Fatalf("errors happened when create: %v", err) } @@ -129,5 +132,6 @@ func TestUpdateHasOne(t *testing.T) { var account2 CustomizeAccount DB.Find(&account2, "user_id = ?", cusUser.ID) AssertEqual(t, account2.Number, number) + AssertEqual(t, account2.Number2, cusUser.Account.Number2) }) } diff --git a/tests/update_test.go b/tests/update_test.go index 41ea5d27..9eb9dbfc 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -122,6 +122,14 @@ func TestUpdate(t *testing.T) { } else { CheckUser(t, result4, *user) } + + if rowsAffected := DB.Model([]User{result4}).Where("age > 0").Update("name", "jinzhu").RowsAffected; rowsAffected != 1 { + t.Errorf("should only update one record, but got %v", rowsAffected) + } + + if rowsAffected := DB.Model(users).Where("age > 0").Update("name", "jinzhu").RowsAffected; rowsAffected != 3 { + t.Errorf("should only update one record, but got %v", rowsAffected) + } } func TestUpdates(t *testing.T) { @@ -200,13 +208,17 @@ func TestUpdateColumn(t *testing.T) { CheckUser(t, user1, *users[0]) CheckUser(t, user2, *users[1]) - DB.Model(users[1]).UpdateColumn("name", "update_column_02_newnew") + DB.Model(users[1]).UpdateColumn("name", "update_column_02_newnew").UpdateColumn("age", 19) AssertEqual(t, lastUpdatedAt.UnixNano(), users[1].UpdatedAt.UnixNano()) if users[1].Name != "update_column_02_newnew" { t.Errorf("user 2's name should be updated, but got %v", users[1].Name) } + if users[1].Age != 19 { + t.Errorf("user 2's name should be updated, but got %v", users[1].Age) + } + DB.Model(users[1]).UpdateColumn("age", gorm.Expr("age + 100 - 50")) var user3 User DB.First(&user3, users[1].ID) @@ -299,6 +311,8 @@ func TestSelectWithUpdate(t *testing.T) { if utils.AssertEqual(result.UpdatedAt, user.UpdatedAt) { t.Fatalf("Update struct should update UpdatedAt, was %+v, got %+v", result.UpdatedAt, user.UpdatedAt) } + + AssertObjEqual(t, result, User{Name: "update_with_select"}, "Name", "Age") } func TestSelectWithUpdateWithMap(t *testing.T) { @@ -600,6 +614,25 @@ func TestUpdateFromSubQuery(t *testing.T) { } } +func TestIdempotentSave(t *testing.T) { + create := Company{ + Name: "company_idempotent", + } + DB.Create(&create) + + var company Company + if err := DB.Find(&company, "id = ?", create.ID).Error; err != nil { + t.Fatalf("failed to find created company, got err: %v", err) + } + + if err := DB.Save(&company).Error; err != nil || company.ID != create.ID { + t.Errorf("failed to save company, got err: %v", err) + } + if err := DB.Save(&company).Error; err != nil || company.ID != create.ID { + t.Errorf("failed to save company, got err: %v", err) + } +} + func TestSave(t *testing.T) { user := *GetUser("save", Config{}) DB.Create(&user) @@ -732,9 +765,9 @@ func TestSaveWithPrimaryValue(t *testing.T) { } } -// only sqlite, postgres support returning +// only sqlite, postgres, sqlserver support returning func TestUpdateReturning(t *testing.T) { - if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" { + if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" && DB.Dialector.Name() != "sqlserver" { return } @@ -763,3 +796,138 @@ func TestUpdateReturning(t *testing.T) { t.Errorf("failed to return updated age column") } } + +func TestUpdateWithDiffSchema(t *testing.T) { + user := GetUser("update-diff-schema-1", Config{}) + DB.Create(&user) + + type UserTemp struct { + Name string + } + + err := DB.Model(&user).Updates(&UserTemp{Name: "update-diff-schema-2"}).Error + AssertEqual(t, err, nil) + AssertEqual(t, "update-diff-schema-2", user.Name) +} + +type TokenOwner struct { + ID int + Name string + Token Token `gorm:"foreignKey:UserID"` +} + +func (t *TokenOwner) BeforeSave(tx *gorm.DB) error { + t.Name += "_name" + return nil +} + +type Token struct { + UserID int `gorm:"primary_key"` + Content string `gorm:"type:varchar(100)"` +} + +func (t *Token) BeforeSave(tx *gorm.DB) error { + t.Content += "_encrypted" + return nil +} + +func TestSaveWithHooks(t *testing.T) { + DB.Migrator().DropTable(&Token{}, &TokenOwner{}) + DB.AutoMigrate(&Token{}, &TokenOwner{}) + + saveTokenOwner := func(owner *TokenOwner) (*TokenOwner, error) { + var newOwner TokenOwner + if err := DB.Transaction(func(tx *gorm.DB) error { + if err := tx.Session(&gorm.Session{FullSaveAssociations: true}).Save(owner).Error; err != nil { + return err + } + if err := tx.Preload("Token").First(&newOwner, owner.ID).Error; err != nil { + return err + } + return nil + }); err != nil { + return nil, err + } + return &newOwner, nil + } + + owner := TokenOwner{ + Name: "user", + Token: Token{Content: "token"}, + } + o1, err := saveTokenOwner(&owner) + if err != nil { + t.Errorf("failed to save token owner, got error: %v", err) + } + if o1.Name != "user_name" { + t.Errorf(`owner name should be "user_name", but got: "%s"`, o1.Name) + } + if o1.Token.Content != "token_encrypted" { + t.Errorf(`token content should be "token_encrypted", but got: "%s"`, o1.Token.Content) + } + + owner = TokenOwner{ + ID: owner.ID, + Name: "user", + Token: Token{Content: "token2"}, + } + o2, err := saveTokenOwner(&owner) + if err != nil { + t.Errorf("failed to save token owner, got error: %v", err) + } + if o2.Name != "user_name" { + t.Errorf(`owner name should be "user_name", but got: "%s"`, o2.Name) + } + if o2.Token.Content != "token2_encrypted" { + t.Errorf(`token content should be "token2_encrypted", but got: "%s"`, o2.Token.Content) + } +} + +// only postgres, sqlserver, sqlite support update from +func TestUpdateFrom(t *testing.T) { + if DB.Dialector.Name() != "postgres" && DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "sqlserver" { + return + } + + users := []*User{ + GetUser("update-from-1", Config{Account: true}), + GetUser("update-from-2", Config{Account: true}), + GetUser("update-from-3", Config{}), + } + + if err := DB.Create(&users).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } else if users[0].ID == 0 { + t.Fatalf("user's primary value should not zero, %v", users[0].ID) + } else if users[0].UpdatedAt.IsZero() { + t.Fatalf("user's updated at should not zero, %v", users[0].UpdatedAt) + } + + if rowsAffected := DB.Model(&User{}).Clauses(clause.From{Tables: []clause.Table{{Name: "accounts"}}}).Where("accounts.user_id = users.id AND accounts.number = ? AND accounts.deleted_at IS NULL", users[0].Account.Number).Update("name", "franco").RowsAffected; rowsAffected != 1 { + t.Errorf("should only update one record, but got %v", rowsAffected) + } + + var result User + if err := DB.Where("id = ?", users[0].ID).First(&result).Error; err != nil { + t.Errorf("errors happened when query before user: %v", err) + } else if result.UpdatedAt.UnixNano() == users[0].UpdatedAt.UnixNano() { + t.Errorf("user's updated at should be changed, but got %v, was %v", result.UpdatedAt, users[0].UpdatedAt) + } else if result.Name != "franco" { + t.Errorf("user's name should be updated") + } + + if rowsAffected := DB.Model(&User{}).Clauses(clause.From{Tables: []clause.Table{{Name: "accounts"}}}).Where("accounts.user_id = users.id AND accounts.number IN ? AND accounts.deleted_at IS NULL", []string{users[0].Account.Number, users[1].Account.Number}).Update("name", gorm.Expr("accounts.number")).RowsAffected; rowsAffected != 2 { + t.Errorf("should update two records, but got %v", rowsAffected) + } + + var results []User + if err := DB.Preload("Account").Find(&results, []uint{users[0].ID, users[1].ID}).Error; err != nil { + t.Errorf("Not error should happen when finding users, but got %v", err) + } + + for _, user := range results { + if user.Name != user.Account.Number { + t.Errorf("user's name should be equal to the account's number %v, but got %v", user.Account.Number, user.Name) + } + } +} diff --git a/tests/upsert_test.go b/tests/upsert_test.go index f90c4518..e84dc14a 100644 --- a/tests/upsert_test.go +++ b/tests/upsert_test.go @@ -62,7 +62,7 @@ func TestUpsert(t *testing.T) { } r := DB.Session(&gorm.Session{DryRun: true}).Clauses(clause.OnConflict{UpdateAll: true}).Create(&RestrictedLanguage{Code: "upsert_code", Name: "upsert_name", Lang: "upsert_lang"}) - if !regexp.MustCompile(`INTO .restricted_languages. .*\(.code.,.name.,.lang.\) .* (SET|UPDATE) .name.=.*.name.[^\w]*$`).MatchString(r.Statement.SQL.String()) { + if !regexp.MustCompile(`INTO .restricted_languages. .*\(.code.,.name.,.lang.\) .* (SET|UPDATE) .name.=.*.name.\W*$`).MatchString(r.Statement.SQL.String()) { t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) } } diff --git a/utils/tests/dummy_dialecter.go b/utils/tests/dummy_dialecter.go index 2990c20f..a2d9c33d 100644 --- a/utils/tests/dummy_dialecter.go +++ b/utils/tests/dummy_dialecter.go @@ -2,18 +2,28 @@ package tests import ( "gorm.io/gorm" + "gorm.io/gorm/callbacks" "gorm.io/gorm/clause" "gorm.io/gorm/logger" "gorm.io/gorm/schema" ) -type DummyDialector struct{} +type DummyDialector struct { + TranslatedErr error +} func (DummyDialector) Name() string { return "dummy" } -func (DummyDialector) Initialize(*gorm.DB) error { +func (DummyDialector) Initialize(db *gorm.DB) error { + callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{ + CreateClauses: []string{"INSERT", "VALUES", "ON CONFLICT", "RETURNING"}, + UpdateClauses: []string{"UPDATE", "SET", "WHERE", "RETURNING"}, + DeleteClauses: []string{"DELETE", "FROM", "WHERE", "RETURNING"}, + LastInsertIDReversed: true, + }) + return nil } @@ -84,3 +94,7 @@ func (DummyDialector) Explain(sql string, vars ...interface{}) string { func (DummyDialector) DataTypeOf(*schema.Field) string { return "" } + +func (d DummyDialector) Translate(err error) error { + return d.TranslatedErr +} diff --git a/utils/tests/models.go b/utils/tests/models.go index 22e8e659..f9f4f50e 100644 --- a/utils/tests/models.go +++ b/utils/tests/models.go @@ -11,7 +11,7 @@ import ( // He works in a Company (belongs to), he has a Manager (belongs to - single-table), and also managed a Team (has many - single-table) // He speaks many languages (many to many) and has many friends (many to many - single-table) // His pet also has one Toy (has one - polymorphic) -// NamedPet is a reference to a Named `Pets` (has many) +// NamedPet is a reference to a named `Pet` (has one) type User struct { gorm.Model Name string @@ -20,7 +20,8 @@ type User struct { Account Account Pets []*Pet NamedPet *Pet - Toys []Toy `gorm:"polymorphic:Owner"` + Toys []Toy `gorm:"polymorphic:Owner"` + Tools []Tools `gorm:"polymorphicType:Type;polymorphicId:CustomID"` CompanyID *int Company Company ManagerID *uint @@ -51,6 +52,13 @@ type Toy struct { OwnerType string } +type Tools struct { + gorm.Model + Name string + CustomID string + Type string +} + type Company struct { ID int Name string @@ -64,8 +72,8 @@ type Language struct { type Coupon struct { ID int `gorm:"primarykey; size:255"` AppliesToProduct []*CouponProduct `gorm:"foreignKey:CouponId;constraint:OnDelete:CASCADE"` - AmountOff uint32 `gorm:"amount_off"` - PercentOff float32 `gorm:"percent_off"` + AmountOff uint32 `gorm:"column:amount_off"` + PercentOff float32 `gorm:"column:percent_off"` } type CouponProduct struct { diff --git a/utils/tests/utils.go b/utils/tests/utils.go index 661d727f..49d01f2e 100644 --- a/utils/tests/utils.go +++ b/utils/tests/utils.go @@ -13,8 +13,14 @@ import ( func AssertObjEqual(t *testing.T, r, e interface{}, names ...string) { for _, name := range names { - got := reflect.Indirect(reflect.ValueOf(r)).FieldByName(name).Interface() - expect := reflect.Indirect(reflect.ValueOf(e)).FieldByName(name).Interface() + rv := reflect.Indirect(reflect.ValueOf(r)) + ev := reflect.Indirect(reflect.ValueOf(e)) + if rv.IsValid() != ev.IsValid() { + t.Errorf("%v: expect: %+v, got %+v", utils.FileWithLineNum(), r, e) + return + } + got := rv.FieldByName(name).Interface() + expect := ev.FieldByName(name).Interface() t.Run(name, func(t *testing.T) { AssertEqual(t, got, expect) }) diff --git a/utils/utils.go b/utils/utils.go index 296917b9..347a331f 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -3,8 +3,8 @@ package utils import ( "database/sql/driver" "fmt" + "path/filepath" "reflect" - "regexp" "runtime" "strconv" "strings" @@ -16,7 +16,18 @@ var gormSourceDir string func init() { _, file, _, _ := runtime.Caller(0) // compatible solution to get gorm source directory with various operating systems - gormSourceDir = regexp.MustCompile(`utils.utils\.go`).ReplaceAllString(file, "") + gormSourceDir = sourceDir(file) +} + +func sourceDir(file string) string { + dir := filepath.Dir(file) + dir = filepath.Dir(dir) + + s := filepath.Dir(dir) + if filepath.Base(s) != "gorm.io" { + s = dir + } + return filepath.ToSlash(s) + "/" } // FileWithLineNum return the file name and line number of the current file @@ -24,7 +35,8 @@ func FileWithLineNum() string { // the second caller usually from gorm internal, so set i start from 2 for i := 2; i < 15; i++ { _, file, line, ok := runtime.Caller(i) - if ok && (!strings.HasPrefix(file, gormSourceDir) || strings.HasSuffix(file, "_test.go")) { + if ok && (!strings.HasPrefix(file, gormSourceDir) || strings.HasSuffix(file, "_test.go")) && + !strings.HasSuffix(file, ".gen.go") { return file + ":" + strconv.FormatInt(int64(line), 10) } } @@ -62,7 +74,11 @@ func ToStringKey(values ...interface{}) string { case uint: results[idx] = strconv.FormatUint(uint64(v), 10) default: - results[idx] = fmt.Sprint(reflect.Indirect(reflect.ValueOf(v)).Interface()) + results[idx] = "nil" + vv := reflect.ValueOf(v) + if vv.IsValid() && !vv.IsZero() { + results[idx] = fmt.Sprint(reflect.Indirect(vv).Interface()) + } } } @@ -78,19 +94,28 @@ func Contains(elems []string, elem string) bool { return false } -func AssertEqual(src, dst interface{}) bool { - if !reflect.DeepEqual(src, dst) { - if valuer, ok := src.(driver.Valuer); ok { - src, _ = valuer.Value() - } - - if valuer, ok := dst.(driver.Valuer); ok { - dst, _ = valuer.Value() - } - - return reflect.DeepEqual(src, dst) +func AssertEqual(x, y interface{}) bool { + if reflect.DeepEqual(x, y) { + return true } - return true + if x == nil || y == nil { + return false + } + + xval := reflect.ValueOf(x) + yval := reflect.ValueOf(y) + if xval.Kind() == reflect.Ptr && xval.IsNil() || + yval.Kind() == reflect.Ptr && yval.IsNil() { + return false + } + + if valuer, ok := x.(driver.Valuer); ok { + x, _ = valuer.Value() + } + if valuer, ok := y.(driver.Valuer); ok { + y, _ = valuer.Value() + } + return reflect.DeepEqual(x, y) } func ToString(value interface{}) string { @@ -120,3 +145,20 @@ func ToString(value interface{}) string { } return "" } + +const nestedRelationSplit = "__" + +// NestedRelationName nested relationships like `Manager__Company` +func NestedRelationName(prefix, name string) string { + return prefix + nestedRelationSplit + name +} + +// SplitNestedRelationName Split nested relationships to `[]string{"Manager","Company"}` +func SplitNestedRelationName(name string) []string { + return strings.Split(name, nestedRelationSplit) +} + +// JoinNestedRelationNames nested relationships like `Manager__Company` +func JoinNestedRelationNames(relationNames []string) string { + return strings.Join(relationNames, nestedRelationSplit) +} diff --git a/utils/utils_test.go b/utils/utils_test.go index 5737c511..8ff42af8 100644 --- a/utils/utils_test.go +++ b/utils/utils_test.go @@ -1,8 +1,13 @@ package utils import ( + "database/sql" + "database/sql/driver" + "errors" + "math" "strings" "testing" + "time" ) func TestIsValidDBNameChar(t *testing.T) { @@ -12,3 +17,124 @@ func TestIsValidDBNameChar(t *testing.T) { } } } + +func TestCheckTruth(t *testing.T) { + checkTruthTests := []struct { + v string + out bool + }{ + {"123", true}, + {"true", true}, + {"", false}, + {"false", false}, + {"False", false}, + {"FALSE", false}, + {"\u0046alse", false}, + } + + for _, test := range checkTruthTests { + t.Run(test.v, func(t *testing.T) { + if out := CheckTruth(test.v); out != test.out { + t.Errorf("CheckTruth(%s) want: %t, got: %t", test.v, test.out, out) + } + }) + } +} + +func TestToStringKey(t *testing.T) { + cases := []struct { + values []interface{} + key string + }{ + {[]interface{}{"a"}, "a"}, + {[]interface{}{1, 2, 3}, "1_2_3"}, + {[]interface{}{1, nil, 3}, "1_nil_3"}, + {[]interface{}{[]interface{}{1, 2, 3}}, "[1 2 3]"}, + {[]interface{}{[]interface{}{"1", "2", "3"}}, "[1 2 3]"}, + {[]interface{}{[]interface{}{"1", nil, "3"}}, "[1 3]"}, + } + for _, c := range cases { + if key := ToStringKey(c.values...); key != c.key { + t.Errorf("%v: expected %v, got %v", c.values, c.key, key) + } + } +} + +func TestContains(t *testing.T) { + containsTests := []struct { + name string + elems []string + elem string + out bool + }{ + {"exists", []string{"1", "2", "3"}, "1", true}, + {"not exists", []string{"1", "2", "3"}, "4", false}, + } + for _, test := range containsTests { + t.Run(test.name, func(t *testing.T) { + if out := Contains(test.elems, test.elem); test.out != out { + t.Errorf("Contains(%v, %s) want: %t, got: %t", test.elems, test.elem, test.out, out) + } + }) + } +} + +type ModifyAt sql.NullTime + +// Value return a Unix time. +func (n ModifyAt) Value() (driver.Value, error) { + if !n.Valid { + return nil, nil + } + return n.Time.Unix(), nil +} + +func TestAssertEqual(t *testing.T) { + now := time.Now() + assertEqualTests := []struct { + name string + src, dst interface{} + out bool + }{ + {"error equal", errors.New("1"), errors.New("1"), true}, + {"error not equal", errors.New("1"), errors.New("2"), false}, + {"driver.Valuer equal", ModifyAt{Time: now, Valid: true}, ModifyAt{Time: now, Valid: true}, true}, + {"driver.Valuer not equal", ModifyAt{Time: now, Valid: true}, ModifyAt{Time: now.Add(time.Second), Valid: true}, false}, + {"driver.Valuer equal (ptr to nil ptr)", (*ModifyAt)(nil), &ModifyAt{}, false}, + } + for _, test := range assertEqualTests { + t.Run(test.name, func(t *testing.T) { + if out := AssertEqual(test.src, test.dst); test.out != out { + t.Errorf("AssertEqual(%v, %v) want: %t, got: %t", test.src, test.dst, test.out, out) + } + }) + } +} + +func TestToString(t *testing.T) { + tests := []struct { + name string + in interface{} + out string + }{ + {"int", math.MaxInt64, "9223372036854775807"}, + {"int8", int8(math.MaxInt8), "127"}, + {"int16", int16(math.MaxInt16), "32767"}, + {"int32", int32(math.MaxInt32), "2147483647"}, + {"int64", int64(math.MaxInt64), "9223372036854775807"}, + {"uint", uint(math.MaxUint64), "18446744073709551615"}, + {"uint8", uint8(math.MaxUint8), "255"}, + {"uint16", uint16(math.MaxUint16), "65535"}, + {"uint32", uint32(math.MaxUint32), "4294967295"}, + {"uint64", uint64(math.MaxUint64), "18446744073709551615"}, + {"string", "abc", "abc"}, + {"other", true, ""}, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if out := ToString(test.in); test.out != out { + t.Fatalf("ToString(%v) want: %s, got: %s", test.in, test.out, out) + } + }) + } +} diff --git a/utils/utils_unix_test.go b/utils/utils_unix_test.go new file mode 100644 index 00000000..450cbe2a --- /dev/null +++ b/utils/utils_unix_test.go @@ -0,0 +1,38 @@ +//go:build unix +// +build unix + +package utils + +import ( + "testing" +) + +func TestSourceDir(t *testing.T) { + cases := []struct { + file string + want string + }{ + { + file: "/Users/name/go/pkg/mod/gorm.io/gorm@v1.2.3/utils/utils.go", + want: "/Users/name/go/pkg/mod/gorm.io/", + }, + { + file: "/go/work/proj/gorm/utils/utils.go", + want: "/go/work/proj/gorm/", + }, + { + file: "/go/work/proj/gorm_alias/utils/utils.go", + want: "/go/work/proj/gorm_alias/", + }, + { + file: "/go/work/proj/my.gorm.io/gorm@v1.2.3/utils/utils.go", + want: "/go/work/proj/my.gorm.io/gorm@v1.2.3/", + }, + } + for _, c := range cases { + s := sourceDir(c.file) + if s != c.want { + t.Fatalf("%s: expected %s, got %s", c.file, c.want, s) + } + } +} diff --git a/utils/utils_windows_test.go b/utils/utils_windows_test.go new file mode 100644 index 00000000..8b1c519d --- /dev/null +++ b/utils/utils_windows_test.go @@ -0,0 +1,35 @@ +package utils + +import ( + "testing" +) + +func TestSourceDir(t *testing.T) { + cases := []struct { + file string + want string + }{ + { + file: `C:/Users/name/go/pkg/mod/gorm.io/gorm@v1.2.3/utils/utils.go`, + want: `C:/Users/name/go/pkg/mod/gorm.io/`, + }, + { + file: `C:/go/work/proj/gorm/utils/utils.go`, + want: `C:/go/work/proj/gorm/`, + }, + { + file: `C:/go/work/proj/gorm_alias/utils/utils.go`, + want: `C:/go/work/proj/gorm_alias/`, + }, + { + file: `C:/go/work/proj/my.gorm.io/gorm@v1.2.3/utils/utils.go`, + want: `C:/go/work/proj/my.gorm.io/gorm@v1.2.3/`, + }, + } + for _, c := range cases { + s := sourceDir(c.file) + if s != c.want { + t.Fatalf("%s: expected %s, got %s", c.file, c.want, s) + } + } +}