diff --git a/.codeclimate.yml b/.codeclimate.yml deleted file mode 100644 index 51aba50c..00000000 --- a/.codeclimate.yml +++ /dev/null @@ -1,11 +0,0 @@ ---- -engines: - gofmt: - enabled: true - govet: - enabled: true - golint: - enabled: true -ratings: - paths: - - "**.go" diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml new file mode 100644 index 00000000..2e7a32d9 --- /dev/null +++ b/.github/FUNDING.yml @@ -0,0 +1,5 @@ +# These are supported funding model platforms + +github: [jinzhu] +patreon: jinzhu +open_collective: gorm diff --git a/.github/ISSUE_TEMPLATE.md b/.github/ISSUE_TEMPLATE.md deleted file mode 100644 index 8b4f03b7..00000000 --- a/.github/ISSUE_TEMPLATE.md +++ /dev/null @@ -1,54 +0,0 @@ -Before posting a bug report about a problem, please try to verify that it is a bug and that it has not been reported already, please apply corresponding GitHub labels to the issue, for feature requests, please apply `type:feature`. - -DON'T post usage related questions, ask in https://gitter.im/jinzhu/gorm or http://stackoverflow.com/questions/tagged/go-gorm, - -Please answer these questions before submitting your issue. Thanks! - - - -### What version of Go are you using (`go version`)? - - -### Which database and its version are you using? - - -### What did you do? - -Please provide a complete runnable program to reproduce your issue. - -```go -package main - -import ( - "github.com/jinzhu/gorm" - _ "github.com/jinzhu/gorm/dialects/mssql" - _ "github.com/jinzhu/gorm/dialects/mysql" - _ "github.com/jinzhu/gorm/dialects/postgres" - _ "github.com/jinzhu/gorm/dialects/sqlite" -) - -var db *gorm.DB - -func init() { - var err error - db, err = gorm.Open("sqlite3", "test.db") - // Please use below username, password as your database's account for the script. - // db, err = gorm.Open("postgres", "user=gorm dbname=gorm sslmode=disable") - // db, err = gorm.Open("mysql", "gorm:gorm@/dbname?charset=utf8&parseTime=True") - // db, err = gorm.Open("mssql", "sqlserver://gorm:LoremIpsum86@localhost:1433?database=gorm") - if err != nil { - panic(err) - } - db.LogMode(true) -} - -func main() { - // your code here - - if /* failure condition */ { - fmt.Println("failed") - } else { - fmt.Println("success") - } -} -``` diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md deleted file mode 100644 index 0ee0d73b..00000000 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ /dev/null @@ -1,14 +0,0 @@ -Make sure these boxes checked before submitting your pull request. - -- [] Do only one thing -- [] No API-breaking changes -- [] New code/logic commented & tested -- [] Write good commit message, try to squash your commits into a single one -- [] Run `./build.sh` in `gh-pages` branch for document changes - -For significant changes like big bug fixes, new features, please open an issue to make a agreement on an implementation design/plan first before starting it. - -Thank you. - - -### What did this pull request do? 
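The reproduction program in the removed issue template was a skeleton rather than runnable code (no `fmt` import, a placeholder `if` condition). A compilable version of that script, keeping the v1 `github.com/jinzhu/gorm` API and the SQLite default it targeted, and using an illustrative `Product` model that is not part of the repository, would look roughly like this:

```go
package main

import (
	"fmt"

	"github.com/jinzhu/gorm"
	_ "github.com/jinzhu/gorm/dialects/sqlite"
)

// Product is an illustrative model, not something from the repository.
type Product struct {
	gorm.Model
	Code string
}

var db *gorm.DB

func init() {
	var err error
	// Same SQLite connection as in the template; swap in the mysql/postgres/mssql
	// DSNs listed there if the bug is dialect specific.
	db, err = gorm.Open("sqlite3", "test.db")
	if err != nil {
		panic(err)
	}
	db.LogMode(true)
}

func main() {
	db.AutoMigrate(&Product{})

	// your code here; report failure via the returned error
	if err := db.Create(&Product{Code: "L1212"}).Error; err != nil {
		fmt.Println("failed:", err)
	} else {
		fmt.Println("success")
	}
}
```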
diff --git a/.github/labels.json b/.github/labels.json new file mode 100644 index 00000000..5c7eb7d1 --- /dev/null +++ b/.github/labels.json @@ -0,0 +1,166 @@ +{ + "labels": { + "critical": { + "name": "type:critical", + "colour": "#E84137", + "description": "critical questions" + }, + "question": { + "name": "type:question", + "colour": "#EDEDED", + "description": "general questions" + }, + "feature": { + "name": "type:feature_request", + "colour": "#43952A", + "description": "feature request" + }, + "invalid_question": { + "name": "type:invalid question", + "colour": "#CF2E1F", + "description": "invalid question (not related to GORM or described in document or not enough information provided)" + }, + "with_playground": { + "name": "type:with reproduction steps", + "colour": "#00ff00", + "description": "with reproduction steps" + }, + "without_playground": { + "name": "type:missing reproduction steps", + "colour": "#CF2E1F", + "description": "missing reproduction steps" + }, + "has_pr": { + "name": "type:has pull request", + "colour": "#43952A", + "description": "has pull request" + }, + "not_tested": { + "name": "type:not tested", + "colour": "#CF2E1F", + "description": "not tested" + }, + "tested": { + "name": "type:tested", + "colour": "#00ff00", + "description": "tested" + }, + "breaking_change": { + "name": "type:breaking change", + "colour": "#CF2E1F", + "description": "breaking change" + } + }, + "issue": { + "with_playground": { + "requires": 1, + "conditions": [ + { + "type": "descriptionMatches", + "pattern": "/github.com\/go-gorm\/playground\/pull\/\\d\\d+/s" + } + ] + }, + "critical": { + "requires": 1, + "conditions": [ + { + "type": "descriptionMatches", + "pattern": "/(critical|urgent)/i" + }, + { + "type": "titleMatches", + "pattern": "/(critical|urgent)/i" + } + ] + }, + "question": { + "requires": 1, + "conditions": [ + { + "type": "titleMatches", + "pattern": "/question/i" + }, + { + "type": "descriptionMatches", + "pattern": "/question/i" + } + ] + }, + "feature": { + "requires": 1, + "conditions": [ + { + "type": "titleMatches", + "pattern": "/feature/i" + }, + { + "type": "descriptionMatches", + "pattern": "/Describe the feature/i" + } + ] + }, + "without_playground": { + "requires": 6, + "conditions": [ + { + "type": "descriptionMatches", + "pattern": "/^((?!github.com\/go-gorm\/playground\/pull\/\\d\\d+).)*$/s" + }, + { + "type": "titleMatches", + "pattern": "/^((?!question).)*$/s" + }, + { + "type": "descriptionMatches", + "pattern": "/^((?!question).)*$/is" + }, + { + "type": "descriptionMatches", + "pattern": "/^((?!Describe the feature).)*$/is" + }, + { + "type": "titleMatches", + "pattern": "/^((?!critical|urgent).)*$/s" + }, + { + "type": "descriptionMatches", + "pattern": "/^((?!critical|urgent).)*$/s" + } + ] + } + }, + "pr": { + "critical": { + "requires": 1, + "conditions": [ + { + "type": "descriptionMatches", + "pattern": "/(critical|urgent)/i" + }, + { + "type": "titleMatches", + "pattern": "/(critical|urgent)/i" + } + ] + }, + "not_tested": { + "requires": 1, + "conditions": [ + { + "type": "descriptionMatches", + "pattern": "/\\[\\] Tested/" + } + ] + }, + "breaking_change": { + "requires": 1, + "conditions": [ + { + "type": "descriptionMatches", + "pattern": "/\\[\\] Non breaking API changes/" + } + ] + } + } +} diff --git a/.github/workflows/invalid_question.yml b/.github/workflows/invalid_question.yml new file mode 100644 index 00000000..5b0bd981 --- /dev/null +++ b/.github/workflows/invalid_question.yml @@ -0,0 +1,22 @@ +name: "Close invalid 
questions issues" +on: + schedule: + - cron: "*/10 * * * *" + +jobs: + stale: + runs-on: ubuntu-latest + env: + ACTIONS_STEP_DEBUG: true + steps: + - name: Close Stale Issues + uses: actions/stale@v3.0.7 + with: + repo-token: ${{ secrets.GITHUB_TOKEN }} + stale-issue-message: "This issue has been marked as invalid question, please give more information by following the `Question` template, if you believe there is a bug of GORM, please create a pull request that could reproduce the issue on [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground), the issue will be closed in 2 days if no further activity occurs. most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨" + stale-issue-label: "status:stale" + days-before-stale: 0 + days-before-close: 2 + remove-stale-when-updated: true + only-labels: "type:invalid question" + diff --git a/.github/workflows/labeler.yml b/.github/workflows/labeler.yml new file mode 100644 index 00000000..bc1add53 --- /dev/null +++ b/.github/workflows/labeler.yml @@ -0,0 +1,19 @@ +name: "Issue Labeler" +on: + issues: + types: [opened, edited, reopened] + pull_request: + types: [opened, edited, reopened] + +jobs: + triage: + runs-on: ubuntu-latest + name: Label issues and pull requests + steps: + - name: check out + uses: actions/checkout@v2 + + - name: labeler + uses: jinzhu/super-labeler-action@develop + with: + GITHUB_TOKEN: "${{ secrets.GITHUB_TOKEN }}" diff --git a/.github/workflows/missing_playground.yml b/.github/workflows/missing_playground.yml new file mode 100644 index 00000000..ea3207d6 --- /dev/null +++ b/.github/workflows/missing_playground.yml @@ -0,0 +1,21 @@ +name: "Close Missing Playground issues" +on: + schedule: + - cron: "*/10 * * * *" + +jobs: + stale: + runs-on: ubuntu-latest + env: + ACTIONS_STEP_DEBUG: true + steps: + - name: Close Stale Issues + uses: actions/stale@v3.0.7 + with: + repo-token: ${{ secrets.GITHUB_TOKEN }} + stale-issue-message: "The issue has been automatically marked as stale as it missing playground pull request link, which is important to help others understand your issue effectively and make sure the issue hasn't been fixed on latest master, checkout [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground) for details. it will be closed in 2 days if no further activity occurs. 
if you are asking question, please use the `Question` template, most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨" + stale-issue-label: "status:stale" + days-before-stale: 0 + days-before-close: 2 + remove-stale-when-updated: true + only-labels: "type:missing reproduction steps" diff --git a/.github/workflows/reviewdog.yml b/.github/workflows/reviewdog.yml new file mode 100644 index 00000000..4511c378 --- /dev/null +++ b/.github/workflows/reviewdog.yml @@ -0,0 +1,11 @@ +name: reviewdog +on: [pull_request] +jobs: + golangci-lint: + name: runner / golangci-lint + runs-on: ubuntu-latest + steps: + - name: Check out code into the Go module directory + uses: actions/checkout@v1 + - name: golangci-lint + uses: reviewdog/action-golangci-lint@v1 diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml new file mode 100644 index 00000000..f9c1bece --- /dev/null +++ b/.github/workflows/stale.yml @@ -0,0 +1,22 @@ +name: "Stale" +on: + schedule: + - cron: "0 2 * * *" + +jobs: + stale: + runs-on: ubuntu-latest + env: + ACTIONS_STEP_DEBUG: true + steps: + - name: Close Stale Issues + uses: actions/stale@v3.0.7 + with: + repo-token: ${{ secrets.GITHUB_TOKEN }} + stale-issue-message: "This issue has been automatically marked as stale because it has been open 60 days with no activity. Remove stale label or comment or this will be closed in 30 days" + days-before-stale: 60 + days-before-close: 30 + stale-issue-label: "status:stale" + exempt-issue-labels: 'type:feature,type:with reproduction steps,type:has pull request' + stale-pr-label: 'status:stale' + exempt-pr-labels: 'type:feature,type:with reproduction steps,type:has pull request' diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml new file mode 100644 index 00000000..370417fc --- /dev/null +++ b/.github/workflows/tests.yml @@ -0,0 +1,166 @@ +name: tests + +on: + push: + branches-ignore: + - 'gh-pages' + pull_request: + branches-ignore: + - 'gh-pages' + +jobs: + # Label of the container job + sqlite: + strategy: + matrix: + go: ['1.16', '1.15', '1.14'] + platform: [ubuntu-latest] # can not run in windows OS + runs-on: ${{ matrix.platform }} + + steps: + - name: Set up Go 1.x + uses: actions/setup-go@v2 + with: + go-version: ${{ matrix.go }} + + - name: Check out code into the Go module directory + uses: actions/checkout@v2 + + - name: go mod package cache + uses: actions/cache@v2 + with: + path: ~/go/pkg/mod + key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} + + - name: Tests + run: GORM_DIALECT=sqlite ./tests/tests_all.sh + + mysql: + strategy: + matrix: + dbversion: ['mysql:latest', 'mysql:5.7', 'mysql:5.6', 'mariadb:latest'] + go: ['1.16', '1.15', '1.14'] + platform: [ubuntu-latest] + runs-on: ${{ matrix.platform }} + + services: + mysql: + image: ${{ matrix.dbversion }} + env: + MYSQL_DATABASE: gorm + MYSQL_USER: gorm + MYSQL_PASSWORD: gorm + MYSQL_RANDOM_ROOT_PASSWORD: "yes" + ports: + - 9910:3306 + options: >- + --health-cmd "mysqladmin ping -ugorm -pgorm" + --health-interval 10s + --health-start-period 10s + --health-timeout 5s + --health-retries 10 + + steps: + - name: Set up Go 1.x + uses: actions/setup-go@v2 + with: + go-version: ${{ matrix.go }} + + - name: Check out code into the Go module directory + uses: actions/checkout@v2 + + + - name: go mod package cache + uses: actions/cache@v2 + with: + path: ~/go/pkg/mod + key: ${{ runner.os 
}}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} + + - name: Tests + run: GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True" ./tests/tests_all.sh + + postgres: + strategy: + matrix: + dbversion: ['postgres:latest', 'postgres:11', 'postgres:10'] + go: ['1.16', '1.15', '1.14'] + platform: [ubuntu-latest] # can not run in macOS and Windows + runs-on: ${{ matrix.platform }} + + services: + postgres: + image: ${{ matrix.dbversion }} + env: + POSTGRES_PASSWORD: gorm + POSTGRES_USER: gorm + POSTGRES_DB: gorm + TZ: Asia/Shanghai + ports: + - 9920:5432 + # Set health checks to wait until postgres has started + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 5 + + steps: + - name: Set up Go 1.x + uses: actions/setup-go@v2 + with: + go-version: ${{ matrix.go }} + + - name: Check out code into the Go module directory + uses: actions/checkout@v2 + + - name: go mod package cache + uses: actions/cache@v2 + with: + path: ~/go/pkg/mod + key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} + + - name: Tests + run: GORM_DIALECT=postgres GORM_DSN="user=gorm password=gorm dbname=gorm host=localhost port=9920 sslmode=disable TimeZone=Asia/Shanghai" ./tests/tests_all.sh + + sqlserver: + strategy: + matrix: + go: ['1.16', '1.15', '1.14'] + platform: [ubuntu-latest] # can not run test in macOS and windows + runs-on: ${{ matrix.platform }} + + services: + mssql: + image: mcmoe/mssqldocker:latest + env: + ACCEPT_EULA: Y + SA_PASSWORD: LoremIpsum86 + MSSQL_DB: gorm + MSSQL_USER: gorm + MSSQL_PASSWORD: LoremIpsum86 + ports: + - 9930:1433 + options: >- + --health-cmd="/opt/mssql-tools/bin/sqlcmd -S localhost -U sa -P LoremIpsum86 -l 30 -Q \"SELECT 1\" || exit 1" + --health-start-period 10s + --health-interval 10s + --health-timeout 5s + --health-retries 10 + + steps: + - name: Set up Go 1.x + uses: actions/setup-go@v2 + with: + go-version: ${{ matrix.go }} + + - name: Check out code into the Go module directory + uses: actions/checkout@v2 + + - name: go mod package cache + uses: actions/cache@v2 + with: + path: ~/go/pkg/mod + key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} + + - name: Tests + run: GORM_DIALECT=sqlserver GORM_DSN="sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm" ./tests/tests_all.sh diff --git a/.gitignore b/.gitignore index 01dc5ce0..e1b9ecea 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,5 @@ +TODO* documents +coverage.txt _book +.idea diff --git a/README.md b/README.md index 44eb4a69..a3eabe39 100644 --- a/README.md +++ b/README.md @@ -2,49 +2,41 @@ The fantastic ORM library for Golang, aims to be developer friendly. 
-[![Join the chat at https://gitter.im/jinzhu/gorm](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/jinzhu/gorm?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) -[![wercker status](https://app.wercker.com/status/0cb7bb1039e21b74f8274941428e0921/s/master "wercker status")](https://app.wercker.com/project/bykey/0cb7bb1039e21b74f8274941428e0921) -[![GoDoc](https://godoc.org/github.com/jinzhu/gorm?status.svg)](https://godoc.org/github.com/jinzhu/gorm) +[![go report card](https://goreportcard.com/badge/github.com/go-gorm/gorm "go report card")](https://goreportcard.com/report/github.com/go-gorm/gorm) +[![test status](https://github.com/go-gorm/gorm/workflows/tests/badge.svg?branch=master "test status")](https://github.com/go-gorm/gorm/actions) +[![Join the chat at https://gitter.im/jinzhu/gorm](https://img.shields.io/gitter/room/jinzhu/gorm.svg)](https://gitter.im/jinzhu/gorm?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) +[![Open Collective Backer](https://opencollective.com/gorm/tiers/backer/badge.svg?label=backer&color=brightgreen "Open Collective Backer")](https://opencollective.com/gorm) +[![Open Collective Sponsor](https://opencollective.com/gorm/tiers/sponsor/badge.svg?label=sponsor&color=brightgreen "Open Collective Sponsor")](https://opencollective.com/gorm) +[![MIT license](https://img.shields.io/badge/license-MIT-brightgreen.svg)](https://opensource.org/licenses/MIT) +[![Go.Dev reference](https://img.shields.io/badge/go.dev-reference-blue?logo=go&logoColor=white)](https://pkg.go.dev/gorm.io/gorm?tab=doc) ## Overview -* Full-Featured ORM (almost) -* Associations (Has One, Has Many, Belongs To, Many To Many, Polymorphism) -* Callbacks (Before/After Create/Save/Update/Delete/Find) -* Preloading (eager loading) -* Transactions +* Full-Featured ORM +* Associations (Has One, Has Many, Belongs To, Many To Many, Polymorphism, Single-table inheritance) +* Hooks (Before/After Create/Save/Update/Delete/Find) +* Eager loading with `Preload`, `Joins` +* Transactions, Nested Transactions, Save Point, RollbackTo to Saved Point +* Context, Prepared Statement Mode, DryRun Mode +* Batch Insert, FindInBatches, Find To Map +* SQL Builder, Upsert, Locking, Optimizer/Index/Comment Hints, NamedArg, Search/Update/Create with SQL Expr * Composite Primary Key -* SQL Builder * Auto Migrations * Logger -* Extendable, write Plugins based on GORM callbacks +* Extendable, flexible plugin API: Database Resolver (Multiple Databases, Read/Write Splitting) / Prometheus… * Every feature comes with tests * Developer Friendly ## Getting Started -* GORM Guides [jinzhu.github.com/gorm](http://jinzhu.github.io/gorm) +* GORM Guides [https://gorm.io](https://gorm.io) -## Upgrading To V1.0 +## Contributing -* [CHANGELOG](http://jinzhu.github.io/gorm/changelog.html) - -## Supporting the project - -[![http://patreon.com/jinzhu](http://patreon_public_assets.s3.amazonaws.com/sized/becomeAPatronBanner.png)](http://patreon.com/jinzhu) - -## Author - -**jinzhu** - -* -* -* - -## Contributors - -https://github.com/jinzhu/gorm/graphs/contributors +[You can help to deliver a better GORM, check out things you can do](https://gorm.io/contribute.html) ## License -Released under the [MIT License](https://github.com/jinzhu/gorm/blob/master/License). 
+© Jinzhu, 2013~time.Now + +Released under the [MIT License](https://github.com/go-gorm/gorm/blob/master/License) diff --git a/association.go b/association.go index 3d522ccc..572f1526 100644 --- a/association.go +++ b/association.go @@ -1,375 +1,513 @@ package gorm import ( - "errors" "fmt" "reflect" + "strings" + + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" + "gorm.io/gorm/utils" ) // Association Mode contains some helper methods to handle relationship things easily. type Association struct { - Error error - scope *Scope - column string - field *Field + DB *DB + Relationship *schema.Relationship + Error error } -// Find find out all related associations -func (association *Association) Find(value interface{}) *Association { - association.scope.related(value, association.column) - return association.setErr(association.scope.db.Error) -} +func (db *DB) Association(column string) *Association { + association := &Association{DB: db} + table := db.Statement.Table -// Append append new associations for many2many, has_many, replace current association for has_one, belongs_to -func (association *Association) Append(values ...interface{}) *Association { - if association.Error != nil { - return association - } + if err := db.Statement.Parse(db.Statement.Model); err == nil { + db.Statement.Table = table + association.Relationship = db.Statement.Schema.Relationships.Relations[column] - if relationship := association.field.Relationship; relationship.Kind == "has_one" { - return association.Replace(values...) - } - return association.saveAssociations(values...) -} + if association.Relationship == nil { + association.Error = fmt.Errorf("%w: %v", ErrUnsupportedRelation, column) + } -// Replace replace current associations with new one -func (association *Association) Replace(values ...interface{}) *Association { - if association.Error != nil { - return association - } - - var ( - relationship = association.field.Relationship - scope = association.scope - field = association.field.Field - newDB = scope.NewDB() - ) - - // Append new values - association.field.Set(reflect.Zero(association.field.Field.Type())) - association.saveAssociations(values...) 
- - // Belongs To - if relationship.Kind == "belongs_to" { - // Set foreign key to be null when clearing value (length equals 0) - if len(values) == 0 { - // Set foreign key to be nil - var foreignKeyMap = map[string]interface{}{} - for _, foreignKey := range relationship.ForeignDBNames { - foreignKeyMap[foreignKey] = nil - } - association.setErr(newDB.Model(scope.Value).UpdateColumn(foreignKeyMap).Error) + db.Statement.ReflectValue = reflect.ValueOf(db.Statement.Model) + for db.Statement.ReflectValue.Kind() == reflect.Ptr { + db.Statement.ReflectValue = db.Statement.ReflectValue.Elem() } } else { - // Polymorphic Relations - if relationship.PolymorphicDBName != "" { - newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), relationship.PolymorphicValue) - } - - // Delete Relations except new created - if len(values) > 0 { - var associationForeignFieldNames, associationForeignDBNames []string - if relationship.Kind == "many_to_many" { - // if many to many relations, get association fields name from association foreign keys - associationScope := scope.New(reflect.New(field.Type()).Interface()) - for idx, dbName := range relationship.AssociationForeignFieldNames { - if field, ok := associationScope.FieldByName(dbName); ok { - associationForeignFieldNames = append(associationForeignFieldNames, field.Name) - associationForeignDBNames = append(associationForeignDBNames, relationship.AssociationForeignDBNames[idx]) - } - } - } else { - // If has one/many relations, use primary keys - for _, field := range scope.New(reflect.New(field.Type()).Interface()).PrimaryFields() { - associationForeignFieldNames = append(associationForeignFieldNames, field.Name) - associationForeignDBNames = append(associationForeignDBNames, field.DBName) - } - } - - newPrimaryKeys := scope.getColumnAsArray(associationForeignFieldNames, field.Interface()) - - if len(newPrimaryKeys) > 0 { - sql := fmt.Sprintf("%v NOT IN (%v)", toQueryCondition(scope, associationForeignDBNames), toQueryMarks(newPrimaryKeys)) - newDB = newDB.Where(sql, toQueryValues(newPrimaryKeys)...) - } - } - - if relationship.Kind == "many_to_many" { - // if many to many relations, delete related relations from join table - var sourceForeignFieldNames []string - - for _, dbName := range relationship.ForeignFieldNames { - if field, ok := scope.FieldByName(dbName); ok { - sourceForeignFieldNames = append(sourceForeignFieldNames, field.Name) - } - } - - if sourcePrimaryKeys := scope.getColumnAsArray(sourceForeignFieldNames, scope.Value); len(sourcePrimaryKeys) > 0 { - newDB = newDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(sourcePrimaryKeys)), toQueryValues(sourcePrimaryKeys)...) - - association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB, relationship)) - } - } else if relationship.Kind == "has_one" || relationship.Kind == "has_many" { - // has_one or has_many relations, set foreign key to be nil (TODO or delete them?) 
- var foreignKeyMap = map[string]interface{}{} - for idx, foreignKey := range relationship.ForeignDBNames { - foreignKeyMap[foreignKey] = nil - if field, ok := scope.FieldByName(relationship.AssociationForeignFieldNames[idx]); ok { - newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) - } - } - - fieldValue := reflect.New(association.field.Field.Type()).Interface() - association.setErr(newDB.Model(fieldValue).UpdateColumn(foreignKeyMap).Error) - } + association.Error = err } + return association } -// Delete remove relationship between source & passed arguments, but won't delete those arguments -func (association *Association) Delete(values ...interface{}) *Association { - if association.Error != nil { - return association - } - - var ( - relationship = association.field.Relationship - scope = association.scope - field = association.field.Field - newDB = scope.NewDB() - ) - - if len(values) == 0 { - return association - } - - var deletingResourcePrimaryFieldNames, deletingResourcePrimaryDBNames []string - for _, field := range scope.New(reflect.New(field.Type()).Interface()).PrimaryFields() { - deletingResourcePrimaryFieldNames = append(deletingResourcePrimaryFieldNames, field.Name) - deletingResourcePrimaryDBNames = append(deletingResourcePrimaryDBNames, field.DBName) - } - - deletingPrimaryKeys := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, values...) - - if relationship.Kind == "many_to_many" { - // source value's foreign keys - for idx, foreignKey := range relationship.ForeignDBNames { - if field, ok := scope.FieldByName(relationship.ForeignFieldNames[idx]); ok { - newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) - } - } - - // get association's foreign fields name - var associationScope = scope.New(reflect.New(field.Type()).Interface()) - var associationForeignFieldNames []string - for _, associationDBName := range relationship.AssociationForeignFieldNames { - if field, ok := associationScope.FieldByName(associationDBName); ok { - associationForeignFieldNames = append(associationForeignFieldNames, field.Name) - } - } - - // association value's foreign keys - deletingPrimaryKeys := scope.getColumnAsArray(associationForeignFieldNames, values...) - sql := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(deletingPrimaryKeys)) - newDB = newDB.Where(sql, toQueryValues(deletingPrimaryKeys)...) - - association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB, relationship)) - } else { - var foreignKeyMap = map[string]interface{}{} - for _, foreignKey := range relationship.ForeignDBNames { - foreignKeyMap[foreignKey] = nil - } - - if relationship.Kind == "belongs_to" { - // find with deleting relation's foreign keys - primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, values...) 
- newDB = newDB.Where( - fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)), - toQueryValues(primaryKeys)..., - ) - - // set foreign key to be null if there are some records affected - modelValue := reflect.New(scope.GetModelStruct().ModelType).Interface() - if results := newDB.Model(modelValue).UpdateColumn(foreignKeyMap); results.Error == nil { - if results.RowsAffected > 0 { - scope.updatedAttrsWithValues(foreignKeyMap) - } - } else { - association.setErr(results.Error) - } - } else if relationship.Kind == "has_one" || relationship.Kind == "has_many" { - // find all relations - primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, scope.Value) - newDB = newDB.Where( - fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)), - toQueryValues(primaryKeys)..., - ) - - // only include those deleting relations - newDB = newDB.Where( - fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, deletingResourcePrimaryDBNames), toQueryMarks(deletingPrimaryKeys)), - toQueryValues(deletingPrimaryKeys)..., - ) - - // set matched relation's foreign key to be null - fieldValue := reflect.New(association.field.Field.Type()).Interface() - association.setErr(newDB.Model(fieldValue).UpdateColumn(foreignKeyMap).Error) - } - } - - // Remove deleted records from source's field +func (association *Association) Find(out interface{}, conds ...interface{}) error { if association.Error == nil { - if field.Kind() == reflect.Slice { - leftValues := reflect.Zero(field.Type()) + association.Error = association.buildCondition().Find(out, conds...).Error + } + return association.Error +} - for i := 0; i < field.Len(); i++ { - reflectValue := field.Index(i) - primaryKey := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, reflectValue.Interface())[0] - var isDeleted = false - for _, pk := range deletingPrimaryKeys { - if equalAsString(primaryKey, pk) { - isDeleted = true - break +func (association *Association) Append(values ...interface{}) error { + if association.Error == nil { + switch association.Relationship.Type { + case schema.HasOne, schema.BelongsTo: + if len(values) > 0 { + association.Error = association.Replace(values...) + } + default: + association.saveAssociation( /*clear*/ false, values...) 
+ } + } + + return association.Error +} + +func (association *Association) Replace(values ...interface{}) error { + if association.Error == nil { + // save associations + if association.saveAssociation( /*clear*/ true, values...); association.Error != nil { + return association.Error + } + + // set old associations's foreign key to null + reflectValue := association.DB.Statement.ReflectValue + rel := association.Relationship + switch rel.Type { + case schema.BelongsTo: + if len(values) == 0 { + updateMap := map[string]interface{}{} + switch reflectValue.Kind() { + case reflect.Slice, reflect.Array: + for i := 0; i < reflectValue.Len(); i++ { + association.Error = rel.Field.Set(reflectValue.Index(i), reflect.Zero(rel.Field.FieldType).Interface()) } + case reflect.Struct: + association.Error = rel.Field.Set(reflectValue, reflect.Zero(rel.Field.FieldType).Interface()) } - if !isDeleted { - leftValues = reflect.Append(leftValues, reflectValue) + + for _, ref := range rel.References { + updateMap[ref.ForeignKey.DBName] = nil + } + + association.Error = association.DB.UpdateColumns(updateMap).Error + } + case schema.HasOne, schema.HasMany: + var ( + primaryFields []*schema.Field + foreignKeys []string + updateMap = map[string]interface{}{} + relValues = schema.GetRelationsValues(reflectValue, []*schema.Relationship{rel}) + modelValue = reflect.New(rel.FieldSchema.ModelType).Interface() + tx = association.DB.Model(modelValue) + ) + + if _, rvs := schema.GetIdentityFieldValuesMap(relValues, rel.FieldSchema.PrimaryFields); len(rvs) > 0 { + if column, values := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs); len(values) > 0 { + tx.Not(clause.IN{Column: column, Values: values}) } } - association.field.Set(leftValues) - } else if field.Kind() == reflect.Struct { - primaryKey := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, field.Interface())[0] - for _, pk := range deletingPrimaryKeys { - if equalAsString(primaryKey, pk) { - association.field.Set(reflect.Zero(field.Type())) - break + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + primaryFields = append(primaryFields, ref.PrimaryKey) + foreignKeys = append(foreignKeys, ref.ForeignKey.DBName) + updateMap[ref.ForeignKey.DBName] = nil + } else if ref.PrimaryValue != "" { + tx.Where(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue}) } } + + if _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields); len(pvs) > 0 { + column, values := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs) + association.Error = tx.Where(clause.IN{Column: column, Values: values}).UpdateColumns(updateMap).Error + } + case schema.Many2Many: + var ( + primaryFields, relPrimaryFields []*schema.Field + joinPrimaryKeys, joinRelPrimaryKeys []string + modelValue = reflect.New(rel.JoinTable.ModelType).Interface() + tx = association.DB.Model(modelValue) + ) + + for _, ref := range rel.References { + if ref.PrimaryValue == "" { + if ref.OwnPrimaryKey { + primaryFields = append(primaryFields, ref.PrimaryKey) + joinPrimaryKeys = append(joinPrimaryKeys, ref.ForeignKey.DBName) + } else { + relPrimaryFields = append(relPrimaryFields, ref.PrimaryKey) + joinRelPrimaryKeys = append(joinRelPrimaryKeys, ref.ForeignKey.DBName) + } + } else { + tx.Clauses(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue}) + } + } + + _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields) + if column, values := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs); 
len(values) > 0 { + tx.Where(clause.IN{Column: column, Values: values}) + } else { + return ErrPrimaryKeyRequired + } + + _, rvs := schema.GetIdentityFieldValuesMapFromValues(values, relPrimaryFields) + if relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs); len(relValues) > 0 { + tx.Where(clause.Not(clause.IN{Column: relColumn, Values: relValues})) + } + + association.Error = tx.Delete(modelValue).Error + } + } + return association.Error +} + +func (association *Association) Delete(values ...interface{}) error { + if association.Error == nil { + var ( + reflectValue = association.DB.Statement.ReflectValue + rel = association.Relationship + primaryFields []*schema.Field + foreignKeys []string + updateAttrs = map[string]interface{}{} + conds []clause.Expression + ) + + for _, ref := range rel.References { + if ref.PrimaryValue == "" { + primaryFields = append(primaryFields, ref.PrimaryKey) + foreignKeys = append(foreignKeys, ref.ForeignKey.DBName) + updateAttrs[ref.ForeignKey.DBName] = nil + } else { + conds = append(conds, clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue}) + } + } + + switch rel.Type { + case schema.BelongsTo: + tx := association.DB.Model(reflect.New(rel.Schema.ModelType).Interface()) + + _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, rel.Schema.PrimaryFields) + pcolumn, pvalues := schema.ToQueryValues(rel.Schema.Table, rel.Schema.PrimaryFieldDBNames, pvs) + conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues}) + + _, rvs := schema.GetIdentityFieldValuesMapFromValues(values, primaryFields) + relColumn, relValues := schema.ToQueryValues(rel.Schema.Table, foreignKeys, rvs) + conds = append(conds, clause.IN{Column: relColumn, Values: relValues}) + + association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error + case schema.HasOne, schema.HasMany: + tx := association.DB.Model(reflect.New(rel.FieldSchema.ModelType).Interface()) + + _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields) + pcolumn, pvalues := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs) + conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues}) + + _, rvs := schema.GetIdentityFieldValuesMapFromValues(values, rel.FieldSchema.PrimaryFields) + relColumn, relValues := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs) + conds = append(conds, clause.IN{Column: relColumn, Values: relValues}) + + association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error + case schema.Many2Many: + var ( + primaryFields, relPrimaryFields []*schema.Field + joinPrimaryKeys, joinRelPrimaryKeys []string + joinValue = reflect.New(rel.JoinTable.ModelType).Interface() + ) + + for _, ref := range rel.References { + if ref.PrimaryValue == "" { + if ref.OwnPrimaryKey { + primaryFields = append(primaryFields, ref.PrimaryKey) + joinPrimaryKeys = append(joinPrimaryKeys, ref.ForeignKey.DBName) + } else { + relPrimaryFields = append(relPrimaryFields, ref.PrimaryKey) + joinRelPrimaryKeys = append(joinRelPrimaryKeys, ref.ForeignKey.DBName) + } + } else { + conds = append(conds, clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue}) + } + } + + _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields) + pcolumn, pvalues := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs) + conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues}) + + _, rvs := schema.GetIdentityFieldValuesMapFromValues(values, relPrimaryFields) + relColumn, 
relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs) + conds = append(conds, clause.IN{Column: relColumn, Values: relValues}) + + association.Error = association.DB.Where(clause.Where{Exprs: conds}).Model(nil).Delete(joinValue).Error + } + + if association.Error == nil { + // clean up deleted values's foreign key + relValuesMap, _ := schema.GetIdentityFieldValuesMapFromValues(values, rel.FieldSchema.PrimaryFields) + + cleanUpDeletedRelations := func(data reflect.Value) { + if _, zero := rel.Field.ValueOf(data); !zero { + fieldValue := reflect.Indirect(rel.Field.ReflectValueOf(data)) + primaryValues := make([]interface{}, len(rel.FieldSchema.PrimaryFields)) + + switch fieldValue.Kind() { + case reflect.Slice, reflect.Array: + validFieldValues := reflect.Zero(rel.Field.IndirectFieldType) + for i := 0; i < fieldValue.Len(); i++ { + for idx, field := range rel.FieldSchema.PrimaryFields { + primaryValues[idx], _ = field.ValueOf(fieldValue.Index(i)) + } + + if _, ok := relValuesMap[utils.ToStringKey(primaryValues...)]; !ok { + validFieldValues = reflect.Append(validFieldValues, fieldValue.Index(i)) + } + } + + association.Error = rel.Field.Set(data, validFieldValues.Interface()) + case reflect.Struct: + for idx, field := range rel.FieldSchema.PrimaryFields { + primaryValues[idx], _ = field.ValueOf(fieldValue) + } + + if _, ok := relValuesMap[utils.ToStringKey(primaryValues...)]; ok { + if association.Error = rel.Field.Set(data, reflect.Zero(rel.FieldSchema.ModelType).Interface()); association.Error != nil { + break + } + + if rel.JoinTable == nil { + for _, ref := range rel.References { + if ref.OwnPrimaryKey || ref.PrimaryValue != "" { + association.Error = ref.ForeignKey.Set(fieldValue, reflect.Zero(ref.ForeignKey.FieldType).Interface()) + } else { + association.Error = ref.ForeignKey.Set(data, reflect.Zero(ref.ForeignKey.FieldType).Interface()) + } + } + } + } + } + } + } + + switch reflectValue.Kind() { + case reflect.Slice, reflect.Array: + for i := 0; i < reflectValue.Len(); i++ { + cleanUpDeletedRelations(reflect.Indirect(reflectValue.Index(i))) + } + case reflect.Struct: + cleanUpDeletedRelations(reflectValue) + } } } - return association + return association.Error } -// Clear remove relationship between source & current associations, won't delete those associations -func (association *Association) Clear() *Association { +func (association *Association) Clear() error { return association.Replace() } -// Count return the count of current associations -func (association *Association) Count() int { +func (association *Association) Count() (count int64) { + if association.Error == nil { + association.Error = association.buildCondition().Count(&count).Error + } + return +} + +type assignBack struct { + Source reflect.Value + Index int + Dest reflect.Value +} + +func (association *Association) saveAssociation(clear bool, values ...interface{}) { var ( - count = 0 - relationship = association.field.Relationship - scope = association.scope - fieldValue = association.field.Field.Interface() - query = scope.DB() + reflectValue = association.DB.Statement.ReflectValue + assignBacks []assignBack // assign association values back to arguments after save ) - if relationship.Kind == "many_to_many" { - query = relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, query, scope.Value) - } else if relationship.Kind == "has_many" || relationship.Kind == "has_one" { - primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, scope.Value) - 
query = query.Where( - fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)), - toQueryValues(primaryKeys)..., - ) - } else if relationship.Kind == "belongs_to" { - primaryKeys := scope.getColumnAsArray(relationship.ForeignFieldNames, scope.Value) - query = query.Where( - fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(primaryKeys)), - toQueryValues(primaryKeys)..., - ) + appendToRelations := func(source, rv reflect.Value, clear bool) { + switch association.Relationship.Type { + case schema.HasOne, schema.BelongsTo: + switch rv.Kind() { + case reflect.Slice, reflect.Array: + if rv.Len() > 0 { + association.Error = association.Relationship.Field.Set(source, rv.Index(0).Addr().Interface()) + + if association.Relationship.Field.FieldType.Kind() == reflect.Struct { + assignBacks = append(assignBacks, assignBack{Source: source, Dest: rv.Index(0)}) + } + } + case reflect.Struct: + association.Error = association.Relationship.Field.Set(source, rv.Addr().Interface()) + + if association.Relationship.Field.FieldType.Kind() == reflect.Struct { + assignBacks = append(assignBacks, assignBack{Source: source, Dest: rv}) + } + } + case schema.HasMany, schema.Many2Many: + elemType := association.Relationship.Field.IndirectFieldType.Elem() + fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(source)) + if clear { + fieldValue = reflect.New(association.Relationship.Field.IndirectFieldType).Elem() + } + + appendToFieldValues := func(ev reflect.Value) { + if ev.Type().AssignableTo(elemType) { + fieldValue = reflect.Append(fieldValue, ev) + } else if ev.Type().Elem().AssignableTo(elemType) { + fieldValue = reflect.Append(fieldValue, ev.Elem()) + } else { + association.Error = fmt.Errorf("unsupported data type: %v for relation %v", ev.Type(), association.Relationship.Name) + } + + if elemType.Kind() == reflect.Struct { + assignBacks = append(assignBacks, assignBack{Source: source, Dest: ev, Index: fieldValue.Len()}) + } + } + + switch rv.Kind() { + case reflect.Slice, reflect.Array: + for i := 0; i < rv.Len(); i++ { + appendToFieldValues(reflect.Indirect(rv.Index(i)).Addr()) + } + case reflect.Struct: + appendToFieldValues(rv.Addr()) + } + + if association.Error == nil { + association.Error = association.Relationship.Field.Set(source, fieldValue.Interface()) + } + } } - if relationship.PolymorphicType != "" { - query = query.Where( - fmt.Sprintf("%v.%v = ?", scope.New(fieldValue).QuotedTableName(), scope.Quote(relationship.PolymorphicDBName)), - relationship.PolymorphicValue, - ) + selectedSaveColumns := []string{association.Relationship.Name} + omitColumns := []string{} + selectColumns, _ := association.DB.Statement.SelectAndOmitColumns(true, false) + for name, ok := range selectColumns { + columnName := "" + if strings.HasPrefix(name, association.Relationship.Name) { + if columnName = strings.TrimPrefix(name, association.Relationship.Name); columnName == ".*" { + columnName = name + } + } else if strings.HasPrefix(name, clause.Associations) { + columnName = name + } + + if columnName != "" { + if ok { + selectedSaveColumns = append(selectedSaveColumns, columnName) + } else { + omitColumns = append(omitColumns, columnName) + } + } } - if err := query.Model(fieldValue).Count(&count).Error; err != nil { - association.Error = err + for _, ref := range association.Relationship.References { + if !ref.OwnPrimaryKey { + selectedSaveColumns = append(selectedSaveColumns, 
ref.ForeignKey.Name) + } + } + + associationDB := association.DB.Session(&Session{}).Model(nil) + if !association.DB.FullSaveAssociations { + associationDB.Select(selectedSaveColumns) + } + if len(omitColumns) > 0 { + associationDB.Omit(omitColumns...) + } + associationDB = associationDB.Session(&Session{}) + + switch reflectValue.Kind() { + case reflect.Slice, reflect.Array: + if len(values) != reflectValue.Len() { + // clear old data + if clear && len(values) == 0 { + for i := 0; i < reflectValue.Len(); i++ { + if err := association.Relationship.Field.Set(reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType).Interface()); err != nil { + association.Error = err + break + } + + if association.Relationship.JoinTable == nil { + for _, ref := range association.Relationship.References { + if !ref.OwnPrimaryKey && ref.PrimaryValue == "" { + if err := ref.ForeignKey.Set(reflectValue.Index(i), reflect.Zero(ref.ForeignKey.FieldType).Interface()); err != nil { + association.Error = err + break + } + } + } + } + } + break + } + + association.Error = ErrInvalidValueOfLength + return + } + + for i := 0; i < reflectValue.Len(); i++ { + appendToRelations(reflectValue.Index(i), reflect.Indirect(reflect.ValueOf(values[i])), clear) + + // TODO support save slice data, sql with case? + association.Error = associationDB.Updates(reflectValue.Index(i).Addr().Interface()).Error + } + case reflect.Struct: + // clear old data + if clear && len(values) == 0 { + association.Error = association.Relationship.Field.Set(reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType).Interface()) + + if association.Relationship.JoinTable == nil && association.Error == nil { + for _, ref := range association.Relationship.References { + if !ref.OwnPrimaryKey && ref.PrimaryValue == "" { + association.Error = ref.ForeignKey.Set(reflectValue, reflect.Zero(ref.ForeignKey.FieldType).Interface()) + } + } + } + } + + for idx, value := range values { + rv := reflect.Indirect(reflect.ValueOf(value)) + appendToRelations(reflectValue, rv, clear && idx == 0) + } + + if len(values) > 0 { + association.Error = associationDB.Updates(reflectValue.Addr().Interface()).Error + } + } + + for _, assignBack := range assignBacks { + fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(assignBack.Source)) + if assignBack.Index > 0 { + reflect.Indirect(assignBack.Dest).Set(fieldValue.Index(assignBack.Index - 1)) + } else { + reflect.Indirect(assignBack.Dest).Set(fieldValue) + } } - return count } -// saveAssociations save passed values as associations -func (association *Association) saveAssociations(values ...interface{}) *Association { +func (association *Association) buildCondition() *DB { var ( - scope = association.scope - field = association.field - relationship = field.Relationship + queryConds = association.Relationship.ToQueryConditions(association.DB.Statement.ReflectValue) + modelValue = reflect.New(association.Relationship.FieldSchema.ModelType).Interface() + tx = association.DB.Model(modelValue) ) - saveAssociation := func(reflectValue reflect.Value) { - // value has to been pointer - if reflectValue.Kind() != reflect.Ptr { - reflectPtr := reflect.New(reflectValue.Type()) - reflectPtr.Elem().Set(reflectValue) - reflectValue = reflectPtr - } - - // value has to been saved for many2many - if relationship.Kind == "many_to_many" { - if scope.New(reflectValue.Interface()).PrimaryKeyZero() { - association.setErr(scope.NewDB().Save(reflectValue.Interface()).Error) + if 
association.Relationship.JoinTable != nil { + if !tx.Statement.Unscoped && len(association.Relationship.JoinTable.QueryClauses) > 0 { + joinStmt := Statement{DB: tx, Schema: association.Relationship.JoinTable, Table: association.Relationship.JoinTable.Table, Clauses: map[string]clause.Clause{}} + for _, queryClause := range association.Relationship.JoinTable.QueryClauses { + joinStmt.AddClause(queryClause) } + joinStmt.Build("WHERE") + tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars}) } - // Assign Fields - var fieldType = field.Field.Type() - var setFieldBackToValue, setSliceFieldBackToValue bool - if reflectValue.Type().AssignableTo(fieldType) { - field.Set(reflectValue) - } else if reflectValue.Type().Elem().AssignableTo(fieldType) { - // if field's type is struct, then need to set value back to argument after save - setFieldBackToValue = true - field.Set(reflectValue.Elem()) - } else if fieldType.Kind() == reflect.Slice { - if reflectValue.Type().AssignableTo(fieldType.Elem()) { - field.Set(reflect.Append(field.Field, reflectValue)) - } else if reflectValue.Type().Elem().AssignableTo(fieldType.Elem()) { - // if field's type is slice of struct, then need to set value back to argument after save - setSliceFieldBackToValue = true - field.Set(reflect.Append(field.Field, reflectValue.Elem())) - } - } - - if relationship.Kind == "many_to_many" { - association.setErr(relationship.JoinTableHandler.Add(relationship.JoinTableHandler, scope.NewDB(), scope.Value, reflectValue.Interface())) - } else { - association.setErr(scope.NewDB().Select(field.Name).Save(scope.Value).Error) - - if setFieldBackToValue { - reflectValue.Elem().Set(field.Field) - } else if setSliceFieldBackToValue { - reflectValue.Elem().Set(field.Field.Index(field.Field.Len() - 1)) - } - } + tx = tx.Session(&Session{QueryFields: true}).Clauses(clause.From{Joins: []clause.Join{{ + Table: clause.Table{Name: association.Relationship.JoinTable.Table}, + ON: clause.Where{Exprs: queryConds}, + }}}) + } else { + tx.Clauses(clause.Where{Exprs: queryConds}) } - for _, value := range values { - reflectValue := reflect.ValueOf(value) - indirectReflectValue := reflect.Indirect(reflectValue) - if indirectReflectValue.Kind() == reflect.Struct { - saveAssociation(reflectValue) - } else if indirectReflectValue.Kind() == reflect.Slice { - for i := 0; i < indirectReflectValue.Len(); i++ { - saveAssociation(indirectReflectValue.Index(i)) - } - } else { - association.setErr(errors.New("invalid value type")) - } - } - return association -} - -func (association *Association) setErr(err error) *Association { - if err != nil { - association.Error = err - } - return association + return tx } diff --git a/association_test.go b/association_test.go deleted file mode 100644 index c84f84ed..00000000 --- a/association_test.go +++ /dev/null @@ -1,907 +0,0 @@ -package gorm_test - -import ( - "fmt" - "os" - "reflect" - "sort" - "testing" - - "github.com/jinzhu/gorm" -) - -func TestBelongsTo(t *testing.T) { - post := Post{ - Title: "post belongs to", - Body: "body belongs to", - Category: Category{Name: "Category 1"}, - MainCategory: Category{Name: "Main Category 1"}, - } - - if err := DB.Save(&post).Error; err != nil { - t.Error("Got errors when save post", err) - } - - if post.Category.ID == 0 || post.MainCategory.ID == 0 { - t.Errorf("Category's primary key should be updated") - } - - if post.CategoryId.Int64 == 0 || post.MainCategoryId == 0 { - t.Errorf("post's foreign key should be updated") - } - - // 
Query - var category1 Category - DB.Model(&post).Association("Category").Find(&category1) - if category1.Name != "Category 1" { - t.Errorf("Query belongs to relations with Association") - } - - var mainCategory1 Category - DB.Model(&post).Association("MainCategory").Find(&mainCategory1) - if mainCategory1.Name != "Main Category 1" { - t.Errorf("Query belongs to relations with Association") - } - - var category11 Category - DB.Model(&post).Related(&category11) - if category11.Name != "Category 1" { - t.Errorf("Query belongs to relations with Related") - } - - if DB.Model(&post).Association("Category").Count() != 1 { - t.Errorf("Post's category count should be 1") - } - - if DB.Model(&post).Association("MainCategory").Count() != 1 { - t.Errorf("Post's main category count should be 1") - } - - // Append - var category2 = Category{ - Name: "Category 2", - } - DB.Model(&post).Association("Category").Append(&category2) - - if category2.ID == 0 { - t.Errorf("Category should has ID when created with Append") - } - - var category21 Category - DB.Model(&post).Related(&category21) - - if category21.Name != "Category 2" { - t.Errorf("Category should be updated with Append") - } - - if DB.Model(&post).Association("Category").Count() != 1 { - t.Errorf("Post's category count should be 1") - } - - // Replace - var category3 = Category{ - Name: "Category 3", - } - DB.Model(&post).Association("Category").Replace(&category3) - - if category3.ID == 0 { - t.Errorf("Category should has ID when created with Replace") - } - - var category31 Category - DB.Model(&post).Related(&category31) - if category31.Name != "Category 3" { - t.Errorf("Category should be updated with Replace") - } - - if DB.Model(&post).Association("Category").Count() != 1 { - t.Errorf("Post's category count should be 1") - } - - // Delete - DB.Model(&post).Association("Category").Delete(&category2) - if DB.Model(&post).Related(&Category{}).RecordNotFound() { - t.Errorf("Should not delete any category when Delete a unrelated Category") - } - - if post.Category.Name == "" { - t.Errorf("Post's category should not be reseted when Delete a unrelated Category") - } - - DB.Model(&post).Association("Category").Delete(&category3) - - if post.Category.Name != "" { - t.Errorf("Post's category should be reseted after Delete") - } - - var category41 Category - DB.Model(&post).Related(&category41) - if category41.Name != "" { - t.Errorf("Category should be deleted with Delete") - } - - if count := DB.Model(&post).Association("Category").Count(); count != 0 { - t.Errorf("Post's category count should be 0 after Delete, but got %v", count) - } - - // Clear - DB.Model(&post).Association("Category").Append(&Category{ - Name: "Category 2", - }) - - if DB.Model(&post).Related(&Category{}).RecordNotFound() { - t.Errorf("Should find category after append") - } - - if post.Category.Name == "" { - t.Errorf("Post's category should has value after Append") - } - - DB.Model(&post).Association("Category").Clear() - - if post.Category.Name != "" { - t.Errorf("Post's category should be cleared after Clear") - } - - if !DB.Model(&post).Related(&Category{}).RecordNotFound() { - t.Errorf("Should not find any category after Clear") - } - - if count := DB.Model(&post).Association("Category").Count(); count != 0 { - t.Errorf("Post's category count should be 0 after Clear, but got %v", count) - } - - // Check Association mode with soft delete - category6 := Category{ - Name: "Category 6", - } - DB.Model(&post).Association("Category").Append(&category6) - - if count := 
DB.Model(&post).Association("Category").Count(); count != 1 { - t.Errorf("Post's category count should be 1 after Append, but got %v", count) - } - - DB.Delete(&category6) - - if count := DB.Model(&post).Association("Category").Count(); count != 0 { - t.Errorf("Post's category count should be 0 after the category has been deleted, but got %v", count) - } - - if err := DB.Model(&post).Association("Category").Find(&Category{}).Error; err == nil { - t.Errorf("Post's category is not findable after Delete") - } - - if count := DB.Unscoped().Model(&post).Association("Category").Count(); count != 1 { - t.Errorf("Post's category count should be 1 when query with Unscoped, but got %v", count) - } - - if err := DB.Unscoped().Model(&post).Association("Category").Find(&Category{}).Error; err != nil { - t.Errorf("Post's category should be findable when query with Unscoped, got %v", err) - } -} - -func TestBelongsToOverrideForeignKey1(t *testing.T) { - type Profile struct { - gorm.Model - Name string - } - - type User struct { - gorm.Model - Profile Profile `gorm:"ForeignKey:ProfileRefer"` - ProfileRefer int - } - - if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok { - if relation.Relationship.Kind != "belongs_to" || - !reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"ProfileRefer"}) || - !reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"ID"}) { - t.Errorf("Override belongs to foreign key with tag") - } - } -} - -func TestBelongsToOverrideForeignKey2(t *testing.T) { - type Profile struct { - gorm.Model - Refer string - Name string - } - - type User struct { - gorm.Model - Profile Profile `gorm:"ForeignKey:ProfileID;AssociationForeignKey:Refer"` - ProfileID int - } - - if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok { - if relation.Relationship.Kind != "belongs_to" || - !reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"ProfileID"}) || - !reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"Refer"}) { - t.Errorf("Override belongs to foreign key with tag") - } - } -} - -func TestHasOne(t *testing.T) { - user := User{ - Name: "has one", - CreditCard: CreditCard{Number: "411111111111"}, - } - - if err := DB.Save(&user).Error; err != nil { - t.Error("Got errors when save user", err.Error()) - } - - if user.CreditCard.UserId.Int64 == 0 { - t.Errorf("CreditCard's foreign key should be updated") - } - - // Query - var creditCard1 CreditCard - DB.Model(&user).Related(&creditCard1) - - if creditCard1.Number != "411111111111" { - t.Errorf("Query has one relations with Related") - } - - var creditCard11 CreditCard - DB.Model(&user).Association("CreditCard").Find(&creditCard11) - - if creditCard11.Number != "411111111111" { - t.Errorf("Query has one relations with Related") - } - - if DB.Model(&user).Association("CreditCard").Count() != 1 { - t.Errorf("User's credit card count should be 1") - } - - // Append - var creditcard2 = CreditCard{ - Number: "411111111112", - } - DB.Model(&user).Association("CreditCard").Append(&creditcard2) - - if creditcard2.ID == 0 { - t.Errorf("Creditcard should has ID when created with Append") - } - - var creditcard21 CreditCard - DB.Model(&user).Related(&creditcard21) - if creditcard21.Number != "411111111112" { - t.Errorf("CreditCard should be updated with Append") - } - - if DB.Model(&user).Association("CreditCard").Count() != 1 { - t.Errorf("User's credit card count should be 1") - } - - // Replace - var creditcard3 = CreditCard{ - Number: 
"411111111113", - } - DB.Model(&user).Association("CreditCard").Replace(&creditcard3) - - if creditcard3.ID == 0 { - t.Errorf("Creditcard should has ID when created with Replace") - } - - var creditcard31 CreditCard - DB.Model(&user).Related(&creditcard31) - if creditcard31.Number != "411111111113" { - t.Errorf("CreditCard should be updated with Replace") - } - - if DB.Model(&user).Association("CreditCard").Count() != 1 { - t.Errorf("User's credit card count should be 1") - } - - // Delete - DB.Model(&user).Association("CreditCard").Delete(&creditcard2) - var creditcard4 CreditCard - DB.Model(&user).Related(&creditcard4) - if creditcard4.Number != "411111111113" { - t.Errorf("Should not delete credit card when Delete a unrelated CreditCard") - } - - if DB.Model(&user).Association("CreditCard").Count() != 1 { - t.Errorf("User's credit card count should be 1") - } - - DB.Model(&user).Association("CreditCard").Delete(&creditcard3) - if !DB.Model(&user).Related(&CreditCard{}).RecordNotFound() { - t.Errorf("Should delete credit card with Delete") - } - - if DB.Model(&user).Association("CreditCard").Count() != 0 { - t.Errorf("User's credit card count should be 0 after Delete") - } - - // Clear - var creditcard5 = CreditCard{ - Number: "411111111115", - } - DB.Model(&user).Association("CreditCard").Append(&creditcard5) - - if DB.Model(&user).Related(&CreditCard{}).RecordNotFound() { - t.Errorf("Should added credit card with Append") - } - - if DB.Model(&user).Association("CreditCard").Count() != 1 { - t.Errorf("User's credit card count should be 1") - } - - DB.Model(&user).Association("CreditCard").Clear() - if !DB.Model(&user).Related(&CreditCard{}).RecordNotFound() { - t.Errorf("Credit card should be deleted with Clear") - } - - if DB.Model(&user).Association("CreditCard").Count() != 0 { - t.Errorf("User's credit card count should be 0 after Clear") - } - - // Check Association mode with soft delete - var creditcard6 = CreditCard{ - Number: "411111111116", - } - DB.Model(&user).Association("CreditCard").Append(&creditcard6) - - if count := DB.Model(&user).Association("CreditCard").Count(); count != 1 { - t.Errorf("User's credit card count should be 1 after Append, but got %v", count) - } - - DB.Delete(&creditcard6) - - if count := DB.Model(&user).Association("CreditCard").Count(); count != 0 { - t.Errorf("User's credit card count should be 0 after credit card deleted, but got %v", count) - } - - if err := DB.Model(&user).Association("CreditCard").Find(&CreditCard{}).Error; err == nil { - t.Errorf("User's creditcard is not findable after Delete") - } - - if count := DB.Unscoped().Model(&user).Association("CreditCard").Count(); count != 1 { - t.Errorf("User's credit card count should be 1 when query with Unscoped, but got %v", count) - } - - if err := DB.Unscoped().Model(&user).Association("CreditCard").Find(&CreditCard{}).Error; err != nil { - t.Errorf("User's creditcard should be findable when query with Unscoped, got %v", err) - } -} - -func TestHasOneOverrideForeignKey1(t *testing.T) { - type Profile struct { - gorm.Model - Name string - UserRefer uint - } - - type User struct { - gorm.Model - Profile Profile `gorm:"ForeignKey:UserRefer"` - } - - if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok { - if relation.Relationship.Kind != "has_one" || - !reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"UserRefer"}) || - !reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"ID"}) { - t.Errorf("Override belongs to foreign key with tag") - 
} - } -} - -func TestHasOneOverrideForeignKey2(t *testing.T) { - type Profile struct { - gorm.Model - Name string - UserID uint - } - - type User struct { - gorm.Model - Refer string - Profile Profile `gorm:"ForeignKey:UserID;AssociationForeignKey:Refer"` - } - - if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok { - if relation.Relationship.Kind != "has_one" || - !reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"UserID"}) || - !reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"Refer"}) { - t.Errorf("Override belongs to foreign key with tag") - } - } -} - -func TestHasMany(t *testing.T) { - post := Post{ - Title: "post has many", - Body: "body has many", - Comments: []*Comment{{Content: "Comment 1"}, {Content: "Comment 2"}}, - } - - if err := DB.Save(&post).Error; err != nil { - t.Error("Got errors when save post", err) - } - - for _, comment := range post.Comments { - if comment.PostId == 0 { - t.Errorf("comment's PostID should be updated") - } - } - - var compareComments = func(comments []Comment, contents []string) bool { - var commentContents []string - for _, comment := range comments { - commentContents = append(commentContents, comment.Content) - } - sort.Strings(commentContents) - sort.Strings(contents) - return reflect.DeepEqual(commentContents, contents) - } - - // Query - if DB.First(&Comment{}, "content = ?", "Comment 1").Error != nil { - t.Errorf("Comment 1 should be saved") - } - - var comments1 []Comment - DB.Model(&post).Association("Comments").Find(&comments1) - if !compareComments(comments1, []string{"Comment 1", "Comment 2"}) { - t.Errorf("Query has many relations with Association") - } - - var comments11 []Comment - DB.Model(&post).Related(&comments11) - if !compareComments(comments11, []string{"Comment 1", "Comment 2"}) { - t.Errorf("Query has many relations with Related") - } - - if DB.Model(&post).Association("Comments").Count() != 2 { - t.Errorf("Post's comments count should be 2") - } - - // Append - DB.Model(&post).Association("Comments").Append(&Comment{Content: "Comment 3"}) - - var comments2 []Comment - DB.Model(&post).Related(&comments2) - if !compareComments(comments2, []string{"Comment 1", "Comment 2", "Comment 3"}) { - t.Errorf("Append new record to has many relations") - } - - if DB.Model(&post).Association("Comments").Count() != 3 { - t.Errorf("Post's comments count should be 3 after Append") - } - - // Delete - DB.Model(&post).Association("Comments").Delete(comments11) - - var comments3 []Comment - DB.Model(&post).Related(&comments3) - if !compareComments(comments3, []string{"Comment 3"}) { - t.Errorf("Delete an existing resource for has many relations") - } - - if DB.Model(&post).Association("Comments").Count() != 1 { - t.Errorf("Post's comments count should be 1 after Delete 2") - } - - // Replace - DB.Model(&Post{Id: 999}).Association("Comments").Replace() - - var comments4 []Comment - DB.Model(&post).Related(&comments4) - if len(comments4) == 0 { - t.Errorf("Replace for other resource should not clear all comments") - } - - DB.Model(&post).Association("Comments").Replace(&Comment{Content: "Comment 4"}, &Comment{Content: "Comment 5"}) - - var comments41 []Comment - DB.Model(&post).Related(&comments41) - if !compareComments(comments41, []string{"Comment 4", "Comment 5"}) { - t.Errorf("Replace has many relations") - } - - // Clear - DB.Model(&Post{Id: 999}).Association("Comments").Clear() - - var comments5 []Comment - DB.Model(&post).Related(&comments5) - if len(comments5) == 0 { - 
t.Errorf("Clear should not clear all comments") - } - - DB.Model(&post).Association("Comments").Clear() - - var comments51 []Comment - DB.Model(&post).Related(&comments51) - if len(comments51) != 0 { - t.Errorf("Clear has many relations") - } - - // Check Association mode with soft delete - var comment6 = Comment{ - Content: "comment 6", - } - DB.Model(&post).Association("Comments").Append(&comment6) - - if count := DB.Model(&post).Association("Comments").Count(); count != 1 { - t.Errorf("post's comments count should be 1 after Append, but got %v", count) - } - - DB.Delete(&comment6) - - if count := DB.Model(&post).Association("Comments").Count(); count != 0 { - t.Errorf("post's comments count should be 0 after comment been deleted, but got %v", count) - } - - var comments6 []Comment - if DB.Model(&post).Association("Comments").Find(&comments6); len(comments6) != 0 { - t.Errorf("post's comments count should be 0 when find with Find, but got %v", len(comments6)) - } - - if count := DB.Unscoped().Model(&post).Association("Comments").Count(); count != 1 { - t.Errorf("post's comments count should be 1 when query with Unscoped, but got %v", count) - } - - var comments61 []Comment - if DB.Unscoped().Model(&post).Association("Comments").Find(&comments61); len(comments61) != 1 { - t.Errorf("post's comments count should be 1 when query with Unscoped, but got %v", len(comments61)) - } -} - -func TestHasManyOverrideForeignKey1(t *testing.T) { - type Profile struct { - gorm.Model - Name string - UserRefer uint - } - - type User struct { - gorm.Model - Profile []Profile `gorm:"ForeignKey:UserRefer"` - } - - if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok { - if relation.Relationship.Kind != "has_many" || - !reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"UserRefer"}) || - !reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"ID"}) { - t.Errorf("Override belongs to foreign key with tag") - } - } -} - -func TestHasManyOverrideForeignKey2(t *testing.T) { - type Profile struct { - gorm.Model - Name string - UserID uint - } - - type User struct { - gorm.Model - Refer string - Profile []Profile `gorm:"ForeignKey:UserID;AssociationForeignKey:Refer"` - } - - if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok { - if relation.Relationship.Kind != "has_many" || - !reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"UserID"}) || - !reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"Refer"}) { - t.Errorf("Override belongs to foreign key with tag") - } - } -} - -func TestManyToMany(t *testing.T) { - DB.Raw("delete from languages") - var languages = []Language{{Name: "ZH"}, {Name: "EN"}} - user := User{Name: "Many2Many", Languages: languages} - DB.Save(&user) - - // Query - var newLanguages []Language - DB.Model(&user).Related(&newLanguages, "Languages") - if len(newLanguages) != len([]string{"ZH", "EN"}) { - t.Errorf("Query many to many relations") - } - - DB.Model(&user).Association("Languages").Find(&newLanguages) - if len(newLanguages) != len([]string{"ZH", "EN"}) { - t.Errorf("Should be able to find many to many relations") - } - - if DB.Model(&user).Association("Languages").Count() != len([]string{"ZH", "EN"}) { - t.Errorf("Count should return correct result") - } - - // Append - DB.Model(&user).Association("Languages").Append(&Language{Name: "DE"}) - if DB.Where("name = ?", "DE").First(&Language{}).RecordNotFound() { - t.Errorf("New record should be saved when append") - } - - 
languageA := Language{Name: "AA"} - DB.Save(&languageA) - DB.Model(&User{Id: user.Id}).Association("Languages").Append(&languageA) - - languageC := Language{Name: "CC"} - DB.Save(&languageC) - DB.Model(&user).Association("Languages").Append(&[]Language{{Name: "BB"}, languageC}) - - DB.Model(&User{Id: user.Id}).Association("Languages").Append(&[]Language{{Name: "DD"}, {Name: "EE"}}) - - totalLanguages := []string{"ZH", "EN", "DE", "AA", "BB", "CC", "DD", "EE"} - - if DB.Model(&user).Association("Languages").Count() != len(totalLanguages) { - t.Errorf("All appended languages should be saved") - } - - // Delete - user.Languages = []Language{} - DB.Model(&user).Association("Languages").Find(&user.Languages) - - var language Language - DB.Where("name = ?", "EE").First(&language) - DB.Model(&user).Association("Languages").Delete(language, &language) - - if DB.Model(&user).Association("Languages").Count() != len(totalLanguages)-1 || len(user.Languages) != len(totalLanguages)-1 { - t.Errorf("Relations should be deleted with Delete") - } - if DB.Where("name = ?", "EE").First(&Language{}).RecordNotFound() { - t.Errorf("Language EE should not be deleted") - } - - DB.Where("name IN (?)", []string{"CC", "DD"}).Find(&languages) - - user2 := User{Name: "Many2Many_User2", Languages: languages} - DB.Save(&user2) - - DB.Model(&user).Association("Languages").Delete(languages, &languages) - if DB.Model(&user).Association("Languages").Count() != len(totalLanguages)-3 || len(user.Languages) != len(totalLanguages)-3 { - t.Errorf("Relations should be deleted with Delete") - } - - if DB.Model(&user2).Association("Languages").Count() == 0 { - t.Errorf("Other user's relations should not be deleted") - } - - // Replace - var languageB Language - DB.Where("name = ?", "BB").First(&languageB) - DB.Model(&user).Association("Languages").Replace(languageB) - if len(user.Languages) != 1 || DB.Model(&user).Association("Languages").Count() != 1 { - t.Errorf("Relations should be replaced") - } - - DB.Model(&user).Association("Languages").Replace() - if len(user.Languages) != 0 || DB.Model(&user).Association("Languages").Count() != 0 { - t.Errorf("Relations should be replaced with empty") - } - - DB.Model(&user).Association("Languages").Replace(&[]Language{{Name: "FF"}, {Name: "JJ"}}) - if len(user.Languages) != 2 || DB.Model(&user).Association("Languages").Count() != len([]string{"FF", "JJ"}) { - t.Errorf("Relations should be replaced") - } - - // Clear - DB.Model(&user).Association("Languages").Clear() - if len(user.Languages) != 0 || DB.Model(&user).Association("Languages").Count() != 0 { - t.Errorf("Relations should be cleared") - } - - // Check Association mode with soft delete - var language6 = Language{ - Name: "language 6", - } - DB.Model(&user).Association("Languages").Append(&language6) - - if count := DB.Model(&user).Association("Languages").Count(); count != 1 { - t.Errorf("user's languages count should be 1 after Append, but got %v", count) - } - - DB.Delete(&language6) - - if count := DB.Model(&user).Association("Languages").Count(); count != 0 { - t.Errorf("user's languages count should be 0 after language been deleted, but got %v", count) - } - - var languages6 []Language - if DB.Model(&user).Association("Languages").Find(&languages6); len(languages6) != 0 { - t.Errorf("user's languages count should be 0 when find with Find, but got %v", len(languages6)) - } - - if count := DB.Unscoped().Model(&user).Association("Languages").Count(); count != 1 { - t.Errorf("user's languages count should be 1 when query with 
Unscoped, but got %v", count) - } - - var languages61 []Language - if DB.Unscoped().Model(&user).Association("Languages").Find(&languages61); len(languages61) != 1 { - t.Errorf("user's languages count should be 1 when query with Unscoped, but got %v", len(languages61)) - } -} - -func TestRelated(t *testing.T) { - user := User{ - Name: "jinzhu", - BillingAddress: Address{Address1: "Billing Address - Address 1"}, - ShippingAddress: Address{Address1: "Shipping Address - Address 1"}, - Emails: []Email{{Email: "jinzhu@example.com"}, {Email: "jinzhu-2@example@example.com"}}, - CreditCard: CreditCard{Number: "1234567890"}, - Company: Company{Name: "company1"}, - } - - if err := DB.Save(&user).Error; err != nil { - t.Errorf("No error should happen when saving user") - } - - if user.CreditCard.ID == 0 { - t.Errorf("After user save, credit card should have id") - } - - if user.BillingAddress.ID == 0 { - t.Errorf("After user save, billing address should have id") - } - - if user.Emails[0].Id == 0 { - t.Errorf("After user save, billing address should have id") - } - - var emails []Email - DB.Model(&user).Related(&emails) - if len(emails) != 2 { - t.Errorf("Should have two emails") - } - - var emails2 []Email - DB.Model(&user).Where("email = ?", "jinzhu@example.com").Related(&emails2) - if len(emails2) != 1 { - t.Errorf("Should have two emails") - } - - var emails3 []*Email - DB.Model(&user).Related(&emails3) - if len(emails3) != 2 { - t.Errorf("Should have two emails") - } - - var user1 User - DB.Model(&user).Related(&user1.Emails) - if len(user1.Emails) != 2 { - t.Errorf("Should have only one email match related condition") - } - - var address1 Address - DB.Model(&user).Related(&address1, "BillingAddressId") - if address1.Address1 != "Billing Address - Address 1" { - t.Errorf("Should get billing address from user correctly") - } - - user1 = User{} - DB.Model(&address1).Related(&user1, "BillingAddressId") - if DB.NewRecord(user1) { - t.Errorf("Should get user from address correctly") - } - - var user2 User - DB.Model(&emails[0]).Related(&user2) - if user2.Id != user.Id || user2.Name != user.Name { - t.Errorf("Should get user from email correctly") - } - - var creditcard CreditCard - var user3 User - DB.First(&creditcard, "number = ?", "1234567890") - DB.Model(&creditcard).Related(&user3) - if user3.Id != user.Id || user3.Name != user.Name { - t.Errorf("Should get user from credit card correctly") - } - - if !DB.Model(&CreditCard{}).Related(&User{}).RecordNotFound() { - t.Errorf("RecordNotFound for Related") - } - - var company Company - if DB.Model(&user).Related(&company, "Company").RecordNotFound() || company.Name != "company1" { - t.Errorf("RecordNotFound for Related") - } -} - -func TestForeignKey(t *testing.T) { - for _, structField := range DB.NewScope(&User{}).GetStructFields() { - for _, foreignKey := range []string{"BillingAddressID", "ShippingAddressId", "CompanyID"} { - if structField.Name == foreignKey && !structField.IsForeignKey { - t.Errorf(fmt.Sprintf("%v should be foreign key", foreignKey)) - } - } - } - - for _, structField := range DB.NewScope(&Email{}).GetStructFields() { - for _, foreignKey := range []string{"UserId"} { - if structField.Name == foreignKey && !structField.IsForeignKey { - t.Errorf(fmt.Sprintf("%v should be foreign key", foreignKey)) - } - } - } - - for _, structField := range DB.NewScope(&Post{}).GetStructFields() { - for _, foreignKey := range []string{"CategoryId", "MainCategoryId"} { - if structField.Name == foreignKey && !structField.IsForeignKey { - 
t.Errorf(fmt.Sprintf("%v should be foreign key", foreignKey)) - } - } - } - - for _, structField := range DB.NewScope(&Comment{}).GetStructFields() { - for _, foreignKey := range []string{"PostId"} { - if structField.Name == foreignKey && !structField.IsForeignKey { - t.Errorf(fmt.Sprintf("%v should be foreign key", foreignKey)) - } - } - } -} - -func testForeignKey(t *testing.T, source interface{}, sourceFieldName string, target interface{}, targetFieldName string) { - if dialect := os.Getenv("GORM_DIALECT"); dialect == "" || dialect == "sqlite" { - // sqlite does not support ADD CONSTRAINT in ALTER TABLE - return - } - targetScope := DB.NewScope(target) - targetTableName := targetScope.TableName() - modelScope := DB.NewScope(source) - modelField, ok := modelScope.FieldByName(sourceFieldName) - if !ok { - t.Fatalf(fmt.Sprintf("Failed to get field by name: %v", sourceFieldName)) - } - targetField, ok := targetScope.FieldByName(targetFieldName) - if !ok { - t.Fatalf(fmt.Sprintf("Failed to get field by name: %v", targetFieldName)) - } - dest := fmt.Sprintf("%v(%v)", targetTableName, targetField.DBName) - err := DB.Model(source).AddForeignKey(modelField.DBName, dest, "CASCADE", "CASCADE").Error - if err != nil { - t.Fatalf(fmt.Sprintf("Failed to create foreign key: %v", err)) - } -} - -func TestLongForeignKey(t *testing.T) { - testForeignKey(t, &NotSoLongTableName{}, "ReallyLongThingID", &ReallyLongTableNameToTestMySQLNameLengthLimit{}, "ID") -} - -func TestLongForeignKeyWithShortDest(t *testing.T) { - testForeignKey(t, &ReallyLongThingThatReferencesShort{}, "ShortID", &Short{}, "ID") -} - -func TestHasManyChildrenWithOneStruct(t *testing.T) { - category := Category{ - Name: "main", - Categories: []Category{ - {Name: "sub1"}, - {Name: "sub2"}, - }, - } - - DB.Save(&category) -} - -func TestSkipSaveAssociation(t *testing.T) { - type Company struct { - gorm.Model - Name string - } - - type User struct { - gorm.Model - Name string - CompanyID uint - Company Company `gorm:"save_associations:false"` - } - DB.AutoMigrate(&Company{}, &User{}) - - DB.Save(&User{Name: "jinzhu", Company: Company{Name: "skip_save_association"}}) - - if !DB.Where("name = ?", "skip_save_association").First(&Company{}).RecordNotFound() { - t.Errorf("Company skip_save_association should not been saved") - } -} diff --git a/callback.go b/callback.go deleted file mode 100644 index a4382147..00000000 --- a/callback.go +++ /dev/null @@ -1,242 +0,0 @@ -package gorm - -import "log" - -// DefaultCallback default callbacks defined by gorm -var DefaultCallback = &Callback{} - -// Callback is a struct that contains all CRUD callbacks -// Field `creates` contains callbacks will be call when creating object -// Field `updates` contains callbacks will be call when updating object -// Field `deletes` contains callbacks will be call when deleting object -// Field `queries` contains callbacks will be call when querying object with query methods like Find, First, Related, Association... -// Field `rowQueries` contains callbacks will be call when querying object with Row, Rows... 
-// Field `processors` contains all callback processors, will be used to generate above callbacks in order -type Callback struct { - creates []*func(scope *Scope) - updates []*func(scope *Scope) - deletes []*func(scope *Scope) - queries []*func(scope *Scope) - rowQueries []*func(scope *Scope) - processors []*CallbackProcessor -} - -// CallbackProcessor contains callback informations -type CallbackProcessor struct { - name string // current callback's name - before string // register current callback before a callback - after string // register current callback after a callback - replace bool // replace callbacks with same name - remove bool // delete callbacks with same name - kind string // callback type: create, update, delete, query, row_query - processor *func(scope *Scope) // callback handler - parent *Callback -} - -func (c *Callback) clone() *Callback { - return &Callback{ - creates: c.creates, - updates: c.updates, - deletes: c.deletes, - queries: c.queries, - rowQueries: c.rowQueries, - processors: c.processors, - } -} - -// Create could be used to register callbacks for creating object -// db.Callback().Create().After("gorm:create").Register("plugin:run_after_create", func(*Scope) { -// // business logic -// ... -// -// // set error if some thing wrong happened, will rollback the creating -// scope.Err(errors.New("error")) -// }) -func (c *Callback) Create() *CallbackProcessor { - return &CallbackProcessor{kind: "create", parent: c} -} - -// Update could be used to register callbacks for updating object, refer `Create` for usage -func (c *Callback) Update() *CallbackProcessor { - return &CallbackProcessor{kind: "update", parent: c} -} - -// Delete could be used to register callbacks for deleting object, refer `Create` for usage -func (c *Callback) Delete() *CallbackProcessor { - return &CallbackProcessor{kind: "delete", parent: c} -} - -// Query could be used to register callbacks for querying objects with query methods like `Find`, `First`, `Related`, `Association`... 
-// Refer `Create` for usage -func (c *Callback) Query() *CallbackProcessor { - return &CallbackProcessor{kind: "query", parent: c} -} - -// RowQuery could be used to register callbacks for querying objects with `Row`, `Rows`, refer `Create` for usage -func (c *Callback) RowQuery() *CallbackProcessor { - return &CallbackProcessor{kind: "row_query", parent: c} -} - -// After insert a new callback after callback `callbackName`, refer `Callbacks.Create` -func (cp *CallbackProcessor) After(callbackName string) *CallbackProcessor { - cp.after = callbackName - return cp -} - -// Before insert a new callback before callback `callbackName`, refer `Callbacks.Create` -func (cp *CallbackProcessor) Before(callbackName string) *CallbackProcessor { - cp.before = callbackName - return cp -} - -// Register a new callback, refer `Callbacks.Create` -func (cp *CallbackProcessor) Register(callbackName string, callback func(scope *Scope)) { - if cp.kind == "row_query" { - if cp.before == "" && cp.after == "" && callbackName != "gorm:row_query" { - log.Printf("Registing RowQuery callback %v without specify order with Before(), After(), applying Before('gorm:row_query') by default for compatibility...\n", callbackName) - cp.before = "gorm:row_query" - } - } - - cp.name = callbackName - cp.processor = &callback - cp.parent.processors = append(cp.parent.processors, cp) - cp.parent.reorder() -} - -// Remove a registered callback -// db.Callback().Create().Remove("gorm:update_time_stamp_when_create") -func (cp *CallbackProcessor) Remove(callbackName string) { - log.Printf("[info] removing callback `%v` from %v\n", callbackName, fileWithLineNum()) - cp.name = callbackName - cp.remove = true - cp.parent.processors = append(cp.parent.processors, cp) - cp.parent.reorder() -} - -// Replace a registered callback with new callback -// db.Callback().Create().Replace("gorm:update_time_stamp_when_create", func(*Scope) { -// scope.SetColumn("Created", now) -// scope.SetColumn("Updated", now) -// }) -func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *Scope)) { - log.Printf("[info] replacing callback `%v` from %v\n", callbackName, fileWithLineNum()) - cp.name = callbackName - cp.processor = &callback - cp.replace = true - cp.parent.processors = append(cp.parent.processors, cp) - cp.parent.reorder() -} - -// Get registered callback -// db.Callback().Create().Get("gorm:create") -func (cp *CallbackProcessor) Get(callbackName string) (callback func(scope *Scope)) { - for _, p := range cp.parent.processors { - if p.name == callbackName && p.kind == cp.kind && !cp.remove { - return *p.processor - } - } - return nil -} - -// getRIndex get right index from string slice -func getRIndex(strs []string, str string) int { - for i := len(strs) - 1; i >= 0; i-- { - if strs[i] == str { - return i - } - } - return -1 -} - -// sortProcessors sort callback processors based on its before, after, remove, replace -func sortProcessors(cps []*CallbackProcessor) []*func(scope *Scope) { - var ( - allNames, sortedNames []string - sortCallbackProcessor func(c *CallbackProcessor) - ) - - for _, cp := range cps { - // show warning message the callback name already exists - if index := getRIndex(allNames, cp.name); index > -1 && !cp.replace && !cp.remove { - log.Printf("[warning] duplicated callback `%v` from %v\n", cp.name, fileWithLineNum()) - } - allNames = append(allNames, cp.name) - } - - sortCallbackProcessor = func(c *CallbackProcessor) { - if getRIndex(sortedNames, c.name) == -1 { // if not sorted - if c.before != "" { // 
if defined before callback - if index := getRIndex(sortedNames, c.before); index != -1 { - // if before callback already sorted, append current callback just after it - sortedNames = append(sortedNames[:index], append([]string{c.name}, sortedNames[index:]...)...) - } else if index := getRIndex(allNames, c.before); index != -1 { - // if before callback exists but haven't sorted, append current callback to last - sortedNames = append(sortedNames, c.name) - sortCallbackProcessor(cps[index]) - } - } - - if c.after != "" { // if defined after callback - if index := getRIndex(sortedNames, c.after); index != -1 { - // if after callback already sorted, append current callback just before it - sortedNames = append(sortedNames[:index+1], append([]string{c.name}, sortedNames[index+1:]...)...) - } else if index := getRIndex(allNames, c.after); index != -1 { - // if after callback exists but haven't sorted - cp := cps[index] - // set after callback's before callback to current callback - if cp.before == "" { - cp.before = c.name - } - sortCallbackProcessor(cp) - } - } - - // if current callback haven't been sorted, append it to last - if getRIndex(sortedNames, c.name) == -1 { - sortedNames = append(sortedNames, c.name) - } - } - } - - for _, cp := range cps { - sortCallbackProcessor(cp) - } - - var sortedFuncs []*func(scope *Scope) - for _, name := range sortedNames { - if index := getRIndex(allNames, name); !cps[index].remove { - sortedFuncs = append(sortedFuncs, cps[index].processor) - } - } - - return sortedFuncs -} - -// reorder all registered processors, and reset CRUD callbacks -func (c *Callback) reorder() { - var creates, updates, deletes, queries, rowQueries []*CallbackProcessor - - for _, processor := range c.processors { - if processor.name != "" { - switch processor.kind { - case "create": - creates = append(creates, processor) - case "update": - updates = append(updates, processor) - case "delete": - deletes = append(deletes, processor) - case "query": - queries = append(queries, processor) - case "row_query": - rowQueries = append(rowQueries, processor) - } - } - } - - c.creates = sortProcessors(creates) - c.updates = sortProcessors(updates) - c.deletes = sortProcessors(deletes) - c.queries = sortProcessors(queries) - c.rowQueries = sortProcessors(rowQueries) -} diff --git a/callback_create.go b/callback_create.go deleted file mode 100644 index a4da39e8..00000000 --- a/callback_create.go +++ /dev/null @@ -1,163 +0,0 @@ -package gorm - -import ( - "fmt" - "strings" -) - -// Define callbacks for creating -func init() { - DefaultCallback.Create().Register("gorm:begin_transaction", beginTransactionCallback) - DefaultCallback.Create().Register("gorm:before_create", beforeCreateCallback) - DefaultCallback.Create().Register("gorm:save_before_associations", saveBeforeAssociationsCallback) - DefaultCallback.Create().Register("gorm:update_time_stamp", updateTimeStampForCreateCallback) - DefaultCallback.Create().Register("gorm:create", createCallback) - DefaultCallback.Create().Register("gorm:force_reload_after_create", forceReloadAfterCreateCallback) - DefaultCallback.Create().Register("gorm:save_after_associations", saveAfterAssociationsCallback) - DefaultCallback.Create().Register("gorm:after_create", afterCreateCallback) - DefaultCallback.Create().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback) -} - -// beforeCreateCallback will invoke `BeforeSave`, `BeforeCreate` method before creating -func beforeCreateCallback(scope *Scope) { - if !scope.HasError() { - 
scope.CallMethod("BeforeSave") - } - if !scope.HasError() { - scope.CallMethod("BeforeCreate") - } -} - -// updateTimeStampForCreateCallback will set `CreatedAt`, `UpdatedAt` when creating -func updateTimeStampForCreateCallback(scope *Scope) { - if !scope.HasError() { - now := NowFunc() - - if createdAtField, ok := scope.FieldByName("CreatedAt"); ok { - if createdAtField.IsBlank { - createdAtField.Set(now) - } - } - - if updatedAtField, ok := scope.FieldByName("UpdatedAt"); ok { - if updatedAtField.IsBlank { - updatedAtField.Set(now) - } - } - } -} - -// createCallback the callback used to insert data into database -func createCallback(scope *Scope) { - if !scope.HasError() { - defer scope.trace(NowFunc()) - - var ( - columns, placeholders []string - blankColumnsWithDefaultValue []string - ) - - for _, field := range scope.Fields() { - if scope.changeableField(field) { - if field.IsNormal { - if field.IsBlank && field.HasDefaultValue { - blankColumnsWithDefaultValue = append(blankColumnsWithDefaultValue, scope.Quote(field.DBName)) - scope.InstanceSet("gorm:blank_columns_with_default_value", blankColumnsWithDefaultValue) - } else if !field.IsPrimaryKey || !field.IsBlank { - columns = append(columns, scope.Quote(field.DBName)) - placeholders = append(placeholders, scope.AddToVars(field.Field.Interface())) - } - } else if field.Relationship != nil && field.Relationship.Kind == "belongs_to" { - for _, foreignKey := range field.Relationship.ForeignDBNames { - if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) { - columns = append(columns, scope.Quote(foreignField.DBName)) - placeholders = append(placeholders, scope.AddToVars(foreignField.Field.Interface())) - } - } - } - } - } - - var ( - returningColumn = "*" - quotedTableName = scope.QuotedTableName() - primaryField = scope.PrimaryField() - extraOption string - ) - - if str, ok := scope.Get("gorm:insert_option"); ok { - extraOption = fmt.Sprint(str) - } - - if primaryField != nil { - returningColumn = scope.Quote(primaryField.DBName) - } - - lastInsertIDReturningSuffix := scope.Dialect().LastInsertIDReturningSuffix(quotedTableName, returningColumn) - - if len(columns) == 0 { - scope.Raw(fmt.Sprintf( - "INSERT INTO %v DEFAULT VALUES%v%v", - quotedTableName, - addExtraSpaceIfExist(extraOption), - addExtraSpaceIfExist(lastInsertIDReturningSuffix), - )) - } else { - scope.Raw(fmt.Sprintf( - "INSERT INTO %v (%v) VALUES (%v)%v%v", - scope.QuotedTableName(), - strings.Join(columns, ","), - strings.Join(placeholders, ","), - addExtraSpaceIfExist(extraOption), - addExtraSpaceIfExist(lastInsertIDReturningSuffix), - )) - } - - // execute create sql - if lastInsertIDReturningSuffix == "" || primaryField == nil { - if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { - // set rows affected count - scope.db.RowsAffected, _ = result.RowsAffected() - - // set primary value to primary field - if primaryField != nil && primaryField.IsBlank { - if primaryValue, err := result.LastInsertId(); scope.Err(err) == nil { - scope.Err(primaryField.Set(primaryValue)) - } - } - } - } else { - if primaryField.Field.CanAddr() { - if err := scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil { - primaryField.IsBlank = false - scope.db.RowsAffected = 1 - } - } else { - scope.Err(ErrUnaddressable) - } - } - } -} - -// forceReloadAfterCreateCallback will reload columns that having default value, and set it back to current object 
-func forceReloadAfterCreateCallback(scope *Scope) { - if blankColumnsWithDefaultValue, ok := scope.InstanceGet("gorm:blank_columns_with_default_value"); ok { - db := scope.DB().New().Table(scope.TableName()).Select(blankColumnsWithDefaultValue.([]string)) - for _, field := range scope.Fields() { - if field.IsPrimaryKey && !field.IsBlank { - db = db.Where(fmt.Sprintf("%v = ?", field.DBName), field.Field.Interface()) - } - } - db.Scan(scope.Value) - } -} - -// afterCreateCallback will invoke `AfterCreate`, `AfterSave` method after creating -func afterCreateCallback(scope *Scope) { - if !scope.HasError() { - scope.CallMethod("AfterCreate") - } - if !scope.HasError() { - scope.CallMethod("AfterSave") - } -} diff --git a/callback_delete.go b/callback_delete.go deleted file mode 100644 index 73d90880..00000000 --- a/callback_delete.go +++ /dev/null @@ -1,63 +0,0 @@ -package gorm - -import ( - "errors" - "fmt" -) - -// Define callbacks for deleting -func init() { - DefaultCallback.Delete().Register("gorm:begin_transaction", beginTransactionCallback) - DefaultCallback.Delete().Register("gorm:before_delete", beforeDeleteCallback) - DefaultCallback.Delete().Register("gorm:delete", deleteCallback) - DefaultCallback.Delete().Register("gorm:after_delete", afterDeleteCallback) - DefaultCallback.Delete().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback) -} - -// beforeDeleteCallback will invoke `BeforeDelete` method before deleting -func beforeDeleteCallback(scope *Scope) { - if scope.DB().HasBlockGlobalUpdate() && !scope.hasConditions() { - scope.Err(errors.New("Missing WHERE clause while deleting")) - return - } - if !scope.HasError() { - scope.CallMethod("BeforeDelete") - } -} - -// deleteCallback used to delete data from database or set deleted_at to current time (when using with soft delete) -func deleteCallback(scope *Scope) { - if !scope.HasError() { - var extraOption string - if str, ok := scope.Get("gorm:delete_option"); ok { - extraOption = fmt.Sprint(str) - } - - deletedAtField, hasDeletedAtField := scope.FieldByName("DeletedAt") - - if !scope.Search.Unscoped && hasDeletedAtField { - scope.Raw(fmt.Sprintf( - "UPDATE %v SET %v=%v%v%v", - scope.QuotedTableName(), - scope.Quote(deletedAtField.DBName), - scope.AddToVars(NowFunc()), - addExtraSpaceIfExist(scope.CombinedConditionSql()), - addExtraSpaceIfExist(extraOption), - )).Exec() - } else { - scope.Raw(fmt.Sprintf( - "DELETE FROM %v%v%v", - scope.QuotedTableName(), - addExtraSpaceIfExist(scope.CombinedConditionSql()), - addExtraSpaceIfExist(extraOption), - )).Exec() - } - } -} - -// afterDeleteCallback will invoke `AfterDelete` method after deleting -func afterDeleteCallback(scope *Scope) { - if !scope.HasError() { - scope.CallMethod("AfterDelete") - } -} diff --git a/callback_query.go b/callback_query.go deleted file mode 100644 index 20e88161..00000000 --- a/callback_query.go +++ /dev/null @@ -1,95 +0,0 @@ -package gorm - -import ( - "errors" - "fmt" - "reflect" -) - -// Define callbacks for querying -func init() { - DefaultCallback.Query().Register("gorm:query", queryCallback) - DefaultCallback.Query().Register("gorm:preload", preloadCallback) - DefaultCallback.Query().Register("gorm:after_query", afterQueryCallback) -} - -// queryCallback used to query data from database -func queryCallback(scope *Scope) { - defer scope.trace(NowFunc()) - - var ( - isSlice, isPtr bool - resultType reflect.Type - results = scope.IndirectValue() - ) - - if orderBy, ok := scope.Get("gorm:order_by_primary_key"); ok { - if 
primaryField := scope.PrimaryField(); primaryField != nil { - scope.Search.Order(fmt.Sprintf("%v.%v %v", scope.QuotedTableName(), scope.Quote(primaryField.DBName), orderBy)) - } - } - - if value, ok := scope.Get("gorm:query_destination"); ok { - results = indirect(reflect.ValueOf(value)) - } - - if kind := results.Kind(); kind == reflect.Slice { - isSlice = true - resultType = results.Type().Elem() - results.Set(reflect.MakeSlice(results.Type(), 0, 0)) - - if resultType.Kind() == reflect.Ptr { - isPtr = true - resultType = resultType.Elem() - } - } else if kind != reflect.Struct { - scope.Err(errors.New("unsupported destination, should be slice or struct")) - return - } - - scope.prepareQuerySQL() - - if !scope.HasError() { - scope.db.RowsAffected = 0 - if str, ok := scope.Get("gorm:query_option"); ok { - scope.SQL += addExtraSpaceIfExist(fmt.Sprint(str)) - } - - if rows, err := scope.SQLDB().Query(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { - defer rows.Close() - - columns, _ := rows.Columns() - for rows.Next() { - scope.db.RowsAffected++ - - elem := results - if isSlice { - elem = reflect.New(resultType).Elem() - } - - scope.scan(rows, columns, scope.New(elem.Addr().Interface()).Fields()) - - if isSlice { - if isPtr { - results.Set(reflect.Append(results, elem.Addr())) - } else { - results.Set(reflect.Append(results, elem)) - } - } - } - - if err := rows.Err(); err != nil { - scope.Err(err) - } else if scope.db.RowsAffected == 0 && !isSlice { - scope.Err(ErrRecordNotFound) - } - } - } -} - -// afterQueryCallback will invoke `AfterFind` method after querying -func afterQueryCallback(scope *Scope) { - if !scope.HasError() { - scope.CallMethod("AfterFind") - } -} diff --git a/callback_query_preload.go b/callback_query_preload.go deleted file mode 100644 index 21ab22ce..00000000 --- a/callback_query_preload.go +++ /dev/null @@ -1,380 +0,0 @@ -package gorm - -import ( - "errors" - "fmt" - "reflect" - "strconv" - "strings" -) - -// preloadCallback used to preload associations -func preloadCallback(scope *Scope) { - - if _, ok := scope.Get("gorm:auto_preload"); ok { - autoPreload(scope) - } - - if scope.Search.preload == nil || scope.HasError() { - return - } - - var ( - preloadedMap = map[string]bool{} - fields = scope.Fields() - ) - - for _, preload := range scope.Search.preload { - var ( - preloadFields = strings.Split(preload.schema, ".") - currentScope = scope - currentFields = fields - ) - - for idx, preloadField := range preloadFields { - var currentPreloadConditions []interface{} - - if currentScope == nil { - continue - } - - // if not preloaded - if preloadKey := strings.Join(preloadFields[:idx+1], "."); !preloadedMap[preloadKey] { - - // assign search conditions to last preload - if idx == len(preloadFields)-1 { - currentPreloadConditions = preload.conditions - } - - for _, field := range currentFields { - if field.Name != preloadField || field.Relationship == nil { - continue - } - - switch field.Relationship.Kind { - case "has_one": - currentScope.handleHasOnePreload(field, currentPreloadConditions) - case "has_many": - currentScope.handleHasManyPreload(field, currentPreloadConditions) - case "belongs_to": - currentScope.handleBelongsToPreload(field, currentPreloadConditions) - case "many_to_many": - currentScope.handleManyToManyPreload(field, currentPreloadConditions) - default: - scope.Err(errors.New("unsupported relation")) - } - - preloadedMap[preloadKey] = true - break - } - - if !preloadedMap[preloadKey] { - scope.Err(fmt.Errorf("can't preload field %s for %s", 
preloadField, currentScope.GetModelStruct().ModelType)) - return - } - } - - // preload next level - if idx < len(preloadFields)-1 { - currentScope = currentScope.getColumnAsScope(preloadField) - if currentScope != nil { - currentFields = currentScope.Fields() - } - } - } - } -} - -func autoPreload(scope *Scope) { - for _, field := range scope.Fields() { - if field.Relationship == nil { - continue - } - - if val, ok := field.TagSettings["PRELOAD"]; ok { - if preload, err := strconv.ParseBool(val); err != nil { - scope.Err(errors.New("invalid preload option")) - return - } else if !preload { - continue - } - } - - scope.Search.Preload(field.Name) - } -} - -func (scope *Scope) generatePreloadDBWithConditions(conditions []interface{}) (*DB, []interface{}) { - var ( - preloadDB = scope.NewDB() - preloadConditions []interface{} - ) - - for _, condition := range conditions { - if scopes, ok := condition.(func(*DB) *DB); ok { - preloadDB = scopes(preloadDB) - } else { - preloadConditions = append(preloadConditions, condition) - } - } - - return preloadDB, preloadConditions -} - -// handleHasOnePreload used to preload has one associations -func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) { - relation := field.Relationship - - // get relations's primary keys - primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value) - if len(primaryKeys) == 0 { - return - } - - // preload conditions - preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions) - - // find relations - query := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)) - values := toQueryValues(primaryKeys) - if relation.PolymorphicType != "" { - query += fmt.Sprintf(" AND %v = ?", scope.Quote(relation.PolymorphicDBName)) - values = append(values, relation.PolymorphicValue) - } - - results := makeSlice(field.Struct.Type) - scope.Err(preloadDB.Where(query, values...).Find(results, preloadConditions...).Error) - - // assign find results - var ( - resultsValue = indirect(reflect.ValueOf(results)) - indirectScopeValue = scope.IndirectValue() - ) - - if indirectScopeValue.Kind() == reflect.Slice { - for j := 0; j < indirectScopeValue.Len(); j++ { - for i := 0; i < resultsValue.Len(); i++ { - result := resultsValue.Index(i) - foreignValues := getValueFromFields(result, relation.ForeignFieldNames) - if indirectValue := indirect(indirectScopeValue.Index(j)); equalAsString(getValueFromFields(indirectValue, relation.AssociationForeignFieldNames), foreignValues) { - indirectValue.FieldByName(field.Name).Set(result) - break - } - } - } - } else { - for i := 0; i < resultsValue.Len(); i++ { - result := resultsValue.Index(i) - scope.Err(field.Set(result)) - } - } -} - -// handleHasManyPreload used to preload has many associations -func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) { - relation := field.Relationship - - // get relations's primary keys - primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value) - if len(primaryKeys) == 0 { - return - } - - // preload conditions - preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions) - - // find relations - query := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)) - values := toQueryValues(primaryKeys) - if relation.PolymorphicType != "" { - query += fmt.Sprintf(" AND %v = ?", scope.Quote(relation.PolymorphicDBName)) - values = 
append(values, relation.PolymorphicValue) - } - - results := makeSlice(field.Struct.Type) - scope.Err(preloadDB.Where(query, values...).Find(results, preloadConditions...).Error) - - // assign find results - var ( - resultsValue = indirect(reflect.ValueOf(results)) - indirectScopeValue = scope.IndirectValue() - ) - - if indirectScopeValue.Kind() == reflect.Slice { - preloadMap := make(map[string][]reflect.Value) - for i := 0; i < resultsValue.Len(); i++ { - result := resultsValue.Index(i) - foreignValues := getValueFromFields(result, relation.ForeignFieldNames) - preloadMap[toString(foreignValues)] = append(preloadMap[toString(foreignValues)], result) - } - - for j := 0; j < indirectScopeValue.Len(); j++ { - object := indirect(indirectScopeValue.Index(j)) - objectRealValue := getValueFromFields(object, relation.AssociationForeignFieldNames) - f := object.FieldByName(field.Name) - if results, ok := preloadMap[toString(objectRealValue)]; ok { - f.Set(reflect.Append(f, results...)) - } else { - f.Set(reflect.MakeSlice(f.Type(), 0, 0)) - } - } - } else { - scope.Err(field.Set(resultsValue)) - } -} - -// handleBelongsToPreload used to preload belongs to associations -func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{}) { - relation := field.Relationship - - // preload conditions - preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions) - - // get relations's primary keys - primaryKeys := scope.getColumnAsArray(relation.ForeignFieldNames, scope.Value) - if len(primaryKeys) == 0 { - return - } - - // find relations - results := makeSlice(field.Struct.Type) - scope.Err(preloadDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.AssociationForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, preloadConditions...).Error) - - // assign find results - var ( - resultsValue = indirect(reflect.ValueOf(results)) - indirectScopeValue = scope.IndirectValue() - ) - - for i := 0; i < resultsValue.Len(); i++ { - result := resultsValue.Index(i) - if indirectScopeValue.Kind() == reflect.Slice { - value := getValueFromFields(result, relation.AssociationForeignFieldNames) - for j := 0; j < indirectScopeValue.Len(); j++ { - object := indirect(indirectScopeValue.Index(j)) - if equalAsString(getValueFromFields(object, relation.ForeignFieldNames), value) { - object.FieldByName(field.Name).Set(result) - } - } - } else { - scope.Err(field.Set(result)) - } - } -} - -// handleManyToManyPreload used to preload many to many associations -func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface{}) { - var ( - relation = field.Relationship - joinTableHandler = relation.JoinTableHandler - fieldType = field.Struct.Type.Elem() - foreignKeyValue interface{} - foreignKeyType = reflect.ValueOf(&foreignKeyValue).Type() - linkHash = map[string][]reflect.Value{} - isPtr bool - ) - - if fieldType.Kind() == reflect.Ptr { - isPtr = true - fieldType = fieldType.Elem() - } - - var sourceKeys = []string{} - for _, key := range joinTableHandler.SourceForeignKeys() { - sourceKeys = append(sourceKeys, key.DBName) - } - - // preload conditions - preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions) - - // generate query with join table - newScope := scope.New(reflect.New(fieldType).Interface()) - preloadDB = preloadDB.Table(newScope.TableName()).Model(newScope.Value) - - if len(preloadDB.search.selects) == 0 { - preloadDB = preloadDB.Select("*") - } - - preloadDB = 
joinTableHandler.JoinWith(joinTableHandler, preloadDB, scope.Value) - - // preload inline conditions - if len(preloadConditions) > 0 { - preloadDB = preloadDB.Where(preloadConditions[0], preloadConditions[1:]...) - } - - rows, err := preloadDB.Rows() - - if scope.Err(err) != nil { - return - } - defer rows.Close() - - columns, _ := rows.Columns() - for rows.Next() { - var ( - elem = reflect.New(fieldType).Elem() - fields = scope.New(elem.Addr().Interface()).Fields() - ) - - // register foreign keys in join tables - var joinTableFields []*Field - for _, sourceKey := range sourceKeys { - joinTableFields = append(joinTableFields, &Field{StructField: &StructField{DBName: sourceKey, IsNormal: true}, Field: reflect.New(foreignKeyType).Elem()}) - } - - scope.scan(rows, columns, append(fields, joinTableFields...)) - - var foreignKeys = make([]interface{}, len(sourceKeys)) - // generate hashed forkey keys in join table - for idx, joinTableField := range joinTableFields { - if !joinTableField.Field.IsNil() { - foreignKeys[idx] = joinTableField.Field.Elem().Interface() - } - } - hashedSourceKeys := toString(foreignKeys) - - if isPtr { - linkHash[hashedSourceKeys] = append(linkHash[hashedSourceKeys], elem.Addr()) - } else { - linkHash[hashedSourceKeys] = append(linkHash[hashedSourceKeys], elem) - } - } - - if err := rows.Err(); err != nil { - scope.Err(err) - } - - // assign find results - var ( - indirectScopeValue = scope.IndirectValue() - fieldsSourceMap = map[string][]reflect.Value{} - foreignFieldNames = []string{} - ) - - for _, dbName := range relation.ForeignFieldNames { - if field, ok := scope.FieldByName(dbName); ok { - foreignFieldNames = append(foreignFieldNames, field.Name) - } - } - - if indirectScopeValue.Kind() == reflect.Slice { - for j := 0; j < indirectScopeValue.Len(); j++ { - object := indirect(indirectScopeValue.Index(j)) - key := toString(getValueFromFields(object, foreignFieldNames)) - fieldsSourceMap[key] = append(fieldsSourceMap[key], object.FieldByName(field.Name)) - } - } else if indirectScopeValue.IsValid() { - key := toString(getValueFromFields(indirectScopeValue, foreignFieldNames)) - fieldsSourceMap[key] = append(fieldsSourceMap[key], indirectScopeValue.FieldByName(field.Name)) - } - for source, link := range linkHash { - for i, field := range fieldsSourceMap[source] { - //If not 0 this means Value is a pointer and we already added preloaded models to it - if fieldsSourceMap[source][i].Len() != 0 { - continue - } - field.Set(reflect.Append(fieldsSourceMap[source][i], link...)) - } - - } -} diff --git a/callback_row_query.go b/callback_row_query.go deleted file mode 100644 index c2ff4a08..00000000 --- a/callback_row_query.go +++ /dev/null @@ -1,30 +0,0 @@ -package gorm - -import "database/sql" - -// Define callbacks for row query -func init() { - DefaultCallback.RowQuery().Register("gorm:row_query", rowQueryCallback) -} - -type RowQueryResult struct { - Row *sql.Row -} - -type RowsQueryResult struct { - Rows *sql.Rows - Error error -} - -// queryCallback used to query data from database -func rowQueryCallback(scope *Scope) { - if result, ok := scope.InstanceGet("row_query_result"); ok { - scope.prepareQuerySQL() - - if rowResult, ok := result.(*RowQueryResult); ok { - rowResult.Row = scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...) - } else if rowsResult, ok := result.(*RowsQueryResult); ok { - rowsResult.Rows, rowsResult.Error = scope.SQLDB().Query(scope.SQL, scope.SQLVars...) 
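rowQueryCallback above serves the `Row` and `Rows` entry points by filling in the result value stored on the scope; a hedged sketch of how they are typically driven (v1 API; the users table and its name/age columns are assumptions):

```go
package sketches

import "github.com/jinzhu/gorm"

// rowQuerySketch shows the Row/Rows entry points served by rowQueryCallback.
func rowQuerySketch(db *gorm.DB) {
	var name string
	var age int

	// Single row: the scope holds a *RowQueryResult, filled in by the callback.
	row := db.Table("users").Where("name = ?", "jinzhu").Select("name, age").Row()
	_ = row.Scan(&name, &age)

	// Multiple rows: the scope holds a *RowsQueryResult instead.
	rows, err := db.Table("users").Where("age > ?", 20).Select("name, age").Rows()
	if err == nil {
		defer rows.Close()
		for rows.Next() {
			_ = rows.Scan(&name, &age)
		}
	}
}
```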
- } - } -} diff --git a/callback_save.go b/callback_save.go deleted file mode 100644 index f4bc918e..00000000 --- a/callback_save.go +++ /dev/null @@ -1,99 +0,0 @@ -package gorm - -import "reflect" - -func beginTransactionCallback(scope *Scope) { - scope.Begin() -} - -func commitOrRollbackTransactionCallback(scope *Scope) { - scope.CommitOrRollback() -} - -func saveFieldAsAssociation(scope *Scope, field *Field) (bool, *Relationship) { - if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored { - if value, ok := field.TagSettings["SAVE_ASSOCIATIONS"]; !ok || (value != "false" && value != "skip") { - if relationship := field.Relationship; relationship != nil { - return true, relationship - } - } - } - return false, nil -} - -func saveBeforeAssociationsCallback(scope *Scope) { - if !scope.shouldSaveAssociations() { - return - } - for _, field := range scope.Fields() { - if ok, relationship := saveFieldAsAssociation(scope, field); ok && relationship.Kind == "belongs_to" { - fieldValue := field.Field.Addr().Interface() - scope.Err(scope.NewDB().Save(fieldValue).Error) - if len(relationship.ForeignFieldNames) != 0 { - // set value's foreign key - for idx, fieldName := range relationship.ForeignFieldNames { - associationForeignName := relationship.AssociationForeignDBNames[idx] - if foreignField, ok := scope.New(fieldValue).FieldByName(associationForeignName); ok { - scope.Err(scope.SetColumn(fieldName, foreignField.Field.Interface())) - } - } - } - } - } -} - -func saveAfterAssociationsCallback(scope *Scope) { - if !scope.shouldSaveAssociations() { - return - } - for _, field := range scope.Fields() { - if ok, relationship := saveFieldAsAssociation(scope, field); ok && - (relationship.Kind == "has_one" || relationship.Kind == "has_many" || relationship.Kind == "many_to_many") { - value := field.Field - - switch value.Kind() { - case reflect.Slice: - for i := 0; i < value.Len(); i++ { - newDB := scope.NewDB() - elem := value.Index(i).Addr().Interface() - newScope := newDB.NewScope(elem) - - if relationship.JoinTableHandler == nil && len(relationship.ForeignFieldNames) != 0 { - for idx, fieldName := range relationship.ForeignFieldNames { - associationForeignName := relationship.AssociationForeignDBNames[idx] - if f, ok := scope.FieldByName(associationForeignName); ok { - scope.Err(newScope.SetColumn(fieldName, f.Field.Interface())) - } - } - } - - if relationship.PolymorphicType != "" { - scope.Err(newScope.SetColumn(relationship.PolymorphicType, relationship.PolymorphicValue)) - } - - scope.Err(newDB.Save(elem).Error) - - if joinTableHandler := relationship.JoinTableHandler; joinTableHandler != nil { - scope.Err(joinTableHandler.Add(joinTableHandler, newDB, scope.Value, newScope.Value)) - } - } - default: - elem := value.Addr().Interface() - newScope := scope.New(elem) - if len(relationship.ForeignFieldNames) != 0 { - for idx, fieldName := range relationship.ForeignFieldNames { - associationForeignName := relationship.AssociationForeignDBNames[idx] - if f, ok := scope.FieldByName(associationForeignName); ok { - scope.Err(newScope.SetColumn(fieldName, f.Field.Interface())) - } - } - } - - if relationship.PolymorphicType != "" { - scope.Err(newScope.SetColumn(relationship.PolymorphicType, relationship.PolymorphicValue)) - } - scope.Err(scope.NewDB().Save(elem).Error) - } - } - } -} diff --git a/callback_system_test.go b/callback_system_test.go deleted file mode 100644 index 13ca3f42..00000000 --- a/callback_system_test.go +++ /dev/null @@ -1,112 +0,0 @@ -package gorm - -import ( - 
"reflect" - "runtime" - "strings" - "testing" -) - -func equalFuncs(funcs []*func(s *Scope), fnames []string) bool { - var names []string - for _, f := range funcs { - fnames := strings.Split(runtime.FuncForPC(reflect.ValueOf(*f).Pointer()).Name(), ".") - names = append(names, fnames[len(fnames)-1]) - } - return reflect.DeepEqual(names, fnames) -} - -func create(s *Scope) {} -func beforeCreate1(s *Scope) {} -func beforeCreate2(s *Scope) {} -func afterCreate1(s *Scope) {} -func afterCreate2(s *Scope) {} - -func TestRegisterCallback(t *testing.T) { - var callback = &Callback{} - - callback.Create().Register("before_create1", beforeCreate1) - callback.Create().Register("before_create2", beforeCreate2) - callback.Create().Register("create", create) - callback.Create().Register("after_create1", afterCreate1) - callback.Create().Register("after_create2", afterCreate2) - - if !equalFuncs(callback.creates, []string{"beforeCreate1", "beforeCreate2", "create", "afterCreate1", "afterCreate2"}) { - t.Errorf("register callback") - } -} - -func TestRegisterCallbackWithOrder(t *testing.T) { - var callback1 = &Callback{} - callback1.Create().Register("before_create1", beforeCreate1) - callback1.Create().Register("create", create) - callback1.Create().Register("after_create1", afterCreate1) - callback1.Create().Before("after_create1").Register("after_create2", afterCreate2) - if !equalFuncs(callback1.creates, []string{"beforeCreate1", "create", "afterCreate2", "afterCreate1"}) { - t.Errorf("register callback with order") - } - - var callback2 = &Callback{} - - callback2.Update().Register("create", create) - callback2.Update().Before("create").Register("before_create1", beforeCreate1) - callback2.Update().After("after_create2").Register("after_create1", afterCreate1) - callback2.Update().Before("before_create1").Register("before_create2", beforeCreate2) - callback2.Update().Register("after_create2", afterCreate2) - - if !equalFuncs(callback2.updates, []string{"beforeCreate2", "beforeCreate1", "create", "afterCreate2", "afterCreate1"}) { - t.Errorf("register callback with order") - } -} - -func TestRegisterCallbackWithComplexOrder(t *testing.T) { - var callback1 = &Callback{} - - callback1.Query().Before("after_create1").After("before_create1").Register("create", create) - callback1.Query().Register("before_create1", beforeCreate1) - callback1.Query().Register("after_create1", afterCreate1) - - if !equalFuncs(callback1.queries, []string{"beforeCreate1", "create", "afterCreate1"}) { - t.Errorf("register callback with order") - } - - var callback2 = &Callback{} - - callback2.Delete().Before("after_create1").After("before_create1").Register("create", create) - callback2.Delete().Before("create").Register("before_create1", beforeCreate1) - callback2.Delete().After("before_create1").Register("before_create2", beforeCreate2) - callback2.Delete().Register("after_create1", afterCreate1) - callback2.Delete().After("after_create1").Register("after_create2", afterCreate2) - - if !equalFuncs(callback2.deletes, []string{"beforeCreate1", "beforeCreate2", "create", "afterCreate1", "afterCreate2"}) { - t.Errorf("register callback with order") - } -} - -func replaceCreate(s *Scope) {} - -func TestReplaceCallback(t *testing.T) { - var callback = &Callback{} - - callback.Create().Before("after_create1").After("before_create1").Register("create", create) - callback.Create().Register("before_create1", beforeCreate1) - callback.Create().Register("after_create1", afterCreate1) - callback.Create().Replace("create", replaceCreate) - 
- if !equalFuncs(callback.creates, []string{"beforeCreate1", "replaceCreate", "afterCreate1"}) { - t.Errorf("replace callback") - } -} - -func TestRemoveCallback(t *testing.T) { - var callback = &Callback{} - - callback.Create().Before("after_create1").After("before_create1").Register("create", create) - callback.Create().Register("before_create1", beforeCreate1) - callback.Create().Register("after_create1", afterCreate1) - callback.Create().Remove("create") - - if !equalFuncs(callback.creates, []string{"beforeCreate1", "afterCreate1"}) { - t.Errorf("remove callback") - } -} diff --git a/callback_update.go b/callback_update.go deleted file mode 100644 index 45fb9cc4..00000000 --- a/callback_update.go +++ /dev/null @@ -1,117 +0,0 @@ -package gorm - -import ( - "errors" - "fmt" - "strings" -) - -// Define callbacks for updating -func init() { - DefaultCallback.Update().Register("gorm:assign_updating_attributes", assignUpdatingAttributesCallback) - DefaultCallback.Update().Register("gorm:begin_transaction", beginTransactionCallback) - DefaultCallback.Update().Register("gorm:before_update", beforeUpdateCallback) - DefaultCallback.Update().Register("gorm:save_before_associations", saveBeforeAssociationsCallback) - DefaultCallback.Update().Register("gorm:update_time_stamp", updateTimeStampForUpdateCallback) - DefaultCallback.Update().Register("gorm:update", updateCallback) - DefaultCallback.Update().Register("gorm:save_after_associations", saveAfterAssociationsCallback) - DefaultCallback.Update().Register("gorm:after_update", afterUpdateCallback) - DefaultCallback.Update().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback) -} - -// assignUpdatingAttributesCallback assign updating attributes to model -func assignUpdatingAttributesCallback(scope *Scope) { - if attrs, ok := scope.InstanceGet("gorm:update_interface"); ok { - if updateMaps, hasUpdate := scope.updatedAttrsWithValues(attrs); hasUpdate { - scope.InstanceSet("gorm:update_attrs", updateMaps) - } else { - scope.SkipLeft() - } - } -} - -// beforeUpdateCallback will invoke `BeforeSave`, `BeforeUpdate` method before updating -func beforeUpdateCallback(scope *Scope) { - if scope.DB().HasBlockGlobalUpdate() && !scope.hasConditions() { - scope.Err(errors.New("Missing WHERE clause while updating")) - return - } - if _, ok := scope.Get("gorm:update_column"); !ok { - if !scope.HasError() { - scope.CallMethod("BeforeSave") - } - if !scope.HasError() { - scope.CallMethod("BeforeUpdate") - } - } -} - -// updateTimeStampForUpdateCallback will set `UpdatedAt` when updating -func updateTimeStampForUpdateCallback(scope *Scope) { - if _, ok := scope.Get("gorm:update_column"); !ok { - scope.SetColumn("UpdatedAt", NowFunc()) - } -} - -// updateCallback the callback used to update data to database -func updateCallback(scope *Scope) { - if !scope.HasError() { - var sqls []string - - if updateAttrs, ok := scope.InstanceGet("gorm:update_attrs"); ok { - for column, value := range updateAttrs.(map[string]interface{}) { - sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(column), scope.AddToVars(value))) - } - } else { - for _, field := range scope.Fields() { - if scope.changeableField(field) { - if !field.IsPrimaryKey && field.IsNormal { - sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) - } else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" { - for _, foreignKey := range relationship.ForeignDBNames { - if 
foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) { - sqls = append(sqls, - fmt.Sprintf("%v = %v", scope.Quote(foreignField.DBName), scope.AddToVars(foreignField.Field.Interface()))) - } - } - } - } - } - } - - var extraOption string - if str, ok := scope.Get("gorm:update_option"); ok { - extraOption = fmt.Sprint(str) - } - - if len(sqls) > 0 { - joinSQL := scope.joinsSQL() - whereSQL := scope.whereSQL() - if scope.Search.raw { - whereSQL = strings.TrimSuffix(strings.TrimPrefix(whereSQL, "WHERE ("), ")") - } - combinedSql := whereSQL + scope.groupSQL() + - scope.havingSQL() + scope.orderSQL() + scope.limitAndOffsetSQL() - scope.Raw(fmt.Sprintf( - "UPDATE %v %v SET %v%v%v", - scope.QuotedTableName(), - addExtraSpaceIfExist(joinSQL), - strings.Join(sqls, ", "), - addExtraSpaceIfExist(combinedSql), - addExtraSpaceIfExist(extraOption), - )).Exec() - } - } -} - -// afterUpdateCallback will invoke `AfterUpdate`, `AfterSave` method after updating -func afterUpdateCallback(scope *Scope) { - if _, ok := scope.Get("gorm:update_column"); !ok { - if !scope.HasError() { - scope.CallMethod("AfterUpdate") - } - if !scope.HasError() { - scope.CallMethod("AfterSave") - } - } -} diff --git a/callbacks.go b/callbacks.go new file mode 100644 index 00000000..01d9ed30 --- /dev/null +++ b/callbacks.go @@ -0,0 +1,327 @@ +package gorm + +import ( + "context" + "errors" + "fmt" + "reflect" + "sort" + "time" + + "gorm.io/gorm/schema" + "gorm.io/gorm/utils" +) + +func initializeCallbacks(db *DB) *callbacks { + return &callbacks{ + processors: map[string]*processor{ + "create": {db: db}, + "query": {db: db}, + "update": {db: db}, + "delete": {db: db}, + "row": {db: db}, + "raw": {db: db}, + }, + } +} + +// callbacks gorm callbacks manager +type callbacks struct { + processors map[string]*processor +} + +type processor struct { + db *DB + Clauses []string + fns []func(*DB) + callbacks []*callback +} + +type callback struct { + name string + before string + after string + remove bool + replace bool + match func(*DB) bool + handler func(*DB) + processor *processor +} + +func (cs *callbacks) Create() *processor { + return cs.processors["create"] +} + +func (cs *callbacks) Query() *processor { + return cs.processors["query"] +} + +func (cs *callbacks) Update() *processor { + return cs.processors["update"] +} + +func (cs *callbacks) Delete() *processor { + return cs.processors["delete"] +} + +func (cs *callbacks) Row() *processor { + return cs.processors["row"] +} + +func (cs *callbacks) Raw() *processor { + return cs.processors["raw"] +} + +func (p *processor) Execute(db *DB) { + // call scopes + for len(db.Statement.scopes) > 0 { + scopes := db.Statement.scopes + db.Statement.scopes = nil + for _, scope := range scopes { + db = scope(db) + } + } + + var ( + curTime = time.Now() + stmt = db.Statement + resetBuildClauses bool + ) + + if len(stmt.BuildClauses) == 0 { + stmt.BuildClauses = p.Clauses + resetBuildClauses = true + } + + // assign model values + if stmt.Model == nil { + stmt.Model = stmt.Dest + } else if stmt.Dest == nil { + stmt.Dest = stmt.Model + } + + // parse model values + if stmt.Model != nil { + if err := stmt.Parse(stmt.Model); err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || (stmt.Table == "" && stmt.SQL.Len() == 0)) { + if errors.Is(err, schema.ErrUnsupportedDataType) && stmt.Table == "" { + db.AddError(fmt.Errorf("%w: Table not set, please set it like: db.Model(&user) or db.Table(\"users\")", err)) + } else { + db.AddError(err) + } + } + } + 
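As an illustrative aside on the new `processor.Execute`: it drains `Statement.scopes` before running the registered functions, so conditions added through `Scopes` are applied exactly once per statement. A minimal sketch of how that surfaces in user code (the `ActiveUsers` scope and `User` model are illustrative assumptions, not part of this codebase):

```go
package examples

import "gorm.io/gorm"

type User struct {
	ID     uint
	Name   string
	Active bool
}

// ActiveUsers is an ordinary scope: a func(*gorm.DB) *gorm.DB that Execute
// pops from Statement.scopes and applies before the query callbacks run.
func ActiveUsers(db *gorm.DB) *gorm.DB {
	return db.Where("active = ?", true)
}

func findActive(db *gorm.DB) ([]User, error) {
	var users []User
	err := db.Scopes(ActiveUsers).Find(&users).Error
	return users, err
}
```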
+ // assign stmt.ReflectValue + if stmt.Dest != nil { + stmt.ReflectValue = reflect.ValueOf(stmt.Dest) + for stmt.ReflectValue.Kind() == reflect.Ptr { + if stmt.ReflectValue.IsNil() && stmt.ReflectValue.CanAddr() { + stmt.ReflectValue.Set(reflect.New(stmt.ReflectValue.Type().Elem())) + } + + stmt.ReflectValue = stmt.ReflectValue.Elem() + } + if !stmt.ReflectValue.IsValid() { + db.AddError(ErrInvalidValue) + } + } + + for _, f := range p.fns { + f(db) + } + + db.Logger.Trace(stmt.Context, curTime, func() (string, int64) { + return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...), db.RowsAffected + }, db.Error) + + if !stmt.DB.DryRun { + stmt.SQL.Reset() + stmt.Vars = nil + } + + if resetBuildClauses { + stmt.BuildClauses = nil + } +} + +func (p *processor) Get(name string) func(*DB) { + for i := len(p.callbacks) - 1; i >= 0; i-- { + if v := p.callbacks[i]; v.name == name && !v.remove { + return v.handler + } + } + return nil +} + +func (p *processor) Before(name string) *callback { + return &callback{before: name, processor: p} +} + +func (p *processor) After(name string) *callback { + return &callback{after: name, processor: p} +} + +func (p *processor) Match(fc func(*DB) bool) *callback { + return &callback{match: fc, processor: p} +} + +func (p *processor) Register(name string, fn func(*DB)) error { + return (&callback{processor: p}).Register(name, fn) +} + +func (p *processor) Remove(name string) error { + return (&callback{processor: p}).Remove(name) +} + +func (p *processor) Replace(name string, fn func(*DB)) error { + return (&callback{processor: p}).Replace(name, fn) +} + +func (p *processor) compile() (err error) { + var callbacks []*callback + for _, callback := range p.callbacks { + if callback.match == nil || callback.match(p.db) { + callbacks = append(callbacks, callback) + } + } + p.callbacks = callbacks + + if p.fns, err = sortCallbacks(p.callbacks); err != nil { + p.db.Logger.Error(context.Background(), "Got error when compile callbacks, got %v", err) + } + return +} + +func (c *callback) Before(name string) *callback { + c.before = name + return c +} + +func (c *callback) After(name string) *callback { + c.after = name + return c +} + +func (c *callback) Register(name string, fn func(*DB)) error { + c.name = name + c.handler = fn + c.processor.callbacks = append(c.processor.callbacks, c) + return c.processor.compile() +} + +func (c *callback) Remove(name string) error { + c.processor.db.Logger.Warn(context.Background(), "removing callback `%v` from %v\n", name, utils.FileWithLineNum()) + c.name = name + c.remove = true + c.processor.callbacks = append(c.processor.callbacks, c) + return c.processor.compile() +} + +func (c *callback) Replace(name string, fn func(*DB)) error { + c.processor.db.Logger.Info(context.Background(), "replacing callback `%v` from %v\n", name, utils.FileWithLineNum()) + c.name = name + c.handler = fn + c.replace = true + c.processor.callbacks = append(c.processor.callbacks, c) + return c.processor.compile() +} + +// getRIndex get right index from string slice +func getRIndex(strs []string, str string) int { + for i := len(strs) - 1; i >= 0; i-- { + if strs[i] == str { + return i + } + } + return -1 +} + +func sortCallbacks(cs []*callback) (fns []func(*DB), err error) { + var ( + names, sorted []string + sortCallback func(*callback) error + ) + sort.Slice(cs, func(i, j int) bool { + return cs[j].before == "*" || cs[j].after == "*" + }) + + for _, c := range cs { + // show warning message the callback name already exists + if idx := 
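`Register`, `Before`, `After`, `Replace` and `Remove` all funnel through `compile`, which re-sorts the callback list on every change. A sketch of registering a custom callback relative to the built-in `gorm:create` (the sqlite driver and the `my_plugin` name are illustrative assumptions):

```go
package examples

import (
	"gorm.io/driver/sqlite"
	"gorm.io/gorm"
)

func logCreate(db *gorm.DB) {
	if db.Statement.Schema != nil {
		db.Logger.Info(db.Statement.Context, "inserted %d row(s) into %s", db.RowsAffected, db.Statement.Table)
	}
}

func openWithAuditCallback() (*gorm.DB, error) {
	db, err := gorm.Open(sqlite.Open("gorm.db"), &gorm.Config{})
	if err != nil {
		return nil, err
	}
	// Sorted to run right after gorm:create; re-registering the same name
	// would trigger the "duplicated callback" warning emitted during sorting.
	err = db.Callback().Create().After("gorm:create").Register("my_plugin:log_create", logCreate)
	return db, err
}
```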
getRIndex(names, c.name); idx > -1 && !c.replace && !c.remove && !cs[idx].remove { + c.processor.db.Logger.Warn(context.Background(), "duplicated callback `%v` from %v\n", c.name, utils.FileWithLineNum()) + } + names = append(names, c.name) + } + + sortCallback = func(c *callback) error { + if c.before != "" { // if defined before callback + if c.before == "*" && len(sorted) > 0 { + if curIdx := getRIndex(sorted, c.name); curIdx == -1 { + sorted = append([]string{c.name}, sorted...) + } + } else if sortedIdx := getRIndex(sorted, c.before); sortedIdx != -1 { + if curIdx := getRIndex(sorted, c.name); curIdx == -1 { + // if before callback already sorted, append current callback just after it + sorted = append(sorted[:sortedIdx], append([]string{c.name}, sorted[sortedIdx:]...)...) + } else if curIdx > sortedIdx { + return fmt.Errorf("conflicting callback %v with before %v", c.name, c.before) + } + } else if idx := getRIndex(names, c.before); idx != -1 { + // if before callback exists + cs[idx].after = c.name + } + } + + if c.after != "" { // if defined after callback + if c.after == "*" && len(sorted) > 0 { + if curIdx := getRIndex(sorted, c.name); curIdx == -1 { + sorted = append(sorted, c.name) + } + } else if sortedIdx := getRIndex(sorted, c.after); sortedIdx != -1 { + if curIdx := getRIndex(sorted, c.name); curIdx == -1 { + // if after callback sorted, append current callback to last + sorted = append(sorted, c.name) + } else if curIdx < sortedIdx { + return fmt.Errorf("conflicting callback %v with before %v", c.name, c.after) + } + } else if idx := getRIndex(names, c.after); idx != -1 { + // if after callback exists but haven't sorted + // set after callback's before callback to current callback + after := cs[idx] + + if after.before == "" { + after.before = c.name + } + + if err := sortCallback(after); err != nil { + return err + } + + if err := sortCallback(c); err != nil { + return err + } + } + } + + // if current callback haven't been sorted, append it to last + if getRIndex(sorted, c.name) == -1 { + sorted = append(sorted, c.name) + } + + return nil + } + + for _, c := range cs { + if err = sortCallback(c); err != nil { + return + } + } + + for _, name := range sorted { + if idx := getRIndex(names, name); !cs[idx].remove { + fns = append(fns, cs[idx].handler) + } + } + + return +} diff --git a/callbacks/associations.go b/callbacks/associations.go new file mode 100644 index 00000000..6d74f20d --- /dev/null +++ b/callbacks/associations.go @@ -0,0 +1,390 @@ +package callbacks + +import ( + "reflect" + "strings" + + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" +) + +func SaveBeforeAssociations(create bool) func(db *gorm.DB) { + return func(db *gorm.DB) { + if db.Error == nil && db.Statement.Schema != nil { + selectColumns, restricted := db.Statement.SelectAndOmitColumns(create, !create) + + // Save Belongs To associations + for _, rel := range db.Statement.Schema.Relationships.BelongsTo { + if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) { + continue + } + + setupReferences := func(obj reflect.Value, elem reflect.Value) { + for _, ref := range rel.References { + if !ref.OwnPrimaryKey { + pv, _ := ref.PrimaryKey.ValueOf(elem) + db.AddError(ref.ForeignKey.Set(obj, pv)) + + if dest, ok := db.Statement.Dest.(map[string]interface{}); ok { + dest[ref.ForeignKey.DBName] = pv + if _, ok := dest[rel.Name]; ok { + dest[rel.Name] = elem.Interface() + } + } + } + } + } + + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + var 
( + objs = make([]reflect.Value, 0, db.Statement.ReflectValue.Len()) + fieldType = rel.Field.FieldType + isPtr = fieldType.Kind() == reflect.Ptr + ) + + if !isPtr { + fieldType = reflect.PtrTo(fieldType) + } + + elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) + for i := 0; i < db.Statement.ReflectValue.Len(); i++ { + obj := db.Statement.ReflectValue.Index(i) + + if reflect.Indirect(obj).Kind() == reflect.Struct { + if _, zero := rel.Field.ValueOf(obj); !zero { // check belongs to relation value + rv := rel.Field.ReflectValueOf(obj) // relation reflect value + objs = append(objs, obj) + if isPtr { + elems = reflect.Append(elems, rv) + } else { + elems = reflect.Append(elems, rv.Addr()) + } + } + } else { + break + } + } + + if elems.Len() > 0 { + if saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, nil) == nil { + for i := 0; i < elems.Len(); i++ { + setupReferences(objs[i], elems.Index(i)) + } + } + } + case reflect.Struct: + if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero { + rv := rel.Field.ReflectValueOf(db.Statement.ReflectValue) // relation reflect value + if rv.Kind() != reflect.Ptr { + rv = rv.Addr() + } + + if saveAssociations(db, rel, rv.Interface(), selectColumns, restricted, nil) == nil { + setupReferences(db.Statement.ReflectValue, rv) + } + } + } + } + } + } +} + +func SaveAfterAssociations(create bool) func(db *gorm.DB) { + return func(db *gorm.DB) { + if db.Error == nil && db.Statement.Schema != nil { + selectColumns, restricted := db.Statement.SelectAndOmitColumns(create, !create) + + // Save Has One associations + for _, rel := range db.Statement.Schema.Relationships.HasOne { + if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) { + continue + } + + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + var ( + fieldType = rel.Field.FieldType + isPtr = fieldType.Kind() == reflect.Ptr + ) + + if !isPtr { + fieldType = reflect.PtrTo(fieldType) + } + + elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) + + for i := 0; i < db.Statement.ReflectValue.Len(); i++ { + obj := db.Statement.ReflectValue.Index(i) + + if reflect.Indirect(obj).Kind() == reflect.Struct { + if _, zero := rel.Field.ValueOf(obj); !zero { + rv := rel.Field.ReflectValueOf(obj) + if rv.Kind() != reflect.Ptr { + rv = rv.Addr() + } + + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + fv, _ := ref.PrimaryKey.ValueOf(obj) + db.AddError(ref.ForeignKey.Set(rv, fv)) + } else if ref.PrimaryValue != "" { + db.AddError(ref.ForeignKey.Set(rv, ref.PrimaryValue)) + } + } + + elems = reflect.Append(elems, rv) + } + } + } + + if elems.Len() > 0 { + assignmentColumns := make([]string, 0, len(rel.References)) + for _, ref := range rel.References { + assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) + } + + saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, assignmentColumns) + } + case reflect.Struct: + if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero { + f := rel.Field.ReflectValueOf(db.Statement.ReflectValue) + if f.Kind() != reflect.Ptr { + f = f.Addr() + } + + assignmentColumns := make([]string, 0, len(rel.References)) + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + fv, _ := ref.PrimaryKey.ValueOf(db.Statement.ReflectValue) + ref.ForeignKey.Set(f, fv) + } else if ref.PrimaryValue != "" { + ref.ForeignKey.Set(f, ref.PrimaryValue) + } + assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) + } + + 
saveAssociations(db, rel, f.Interface(), selectColumns, restricted, assignmentColumns) + } + } + } + + // Save Has Many associations + for _, rel := range db.Statement.Schema.Relationships.HasMany { + if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) { + continue + } + + fieldType := rel.Field.IndirectFieldType.Elem() + isPtr := fieldType.Kind() == reflect.Ptr + if !isPtr { + fieldType = reflect.PtrTo(fieldType) + } + elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) + appendToElems := func(v reflect.Value) { + if _, zero := rel.Field.ValueOf(v); !zero { + f := reflect.Indirect(rel.Field.ReflectValueOf(v)) + + for i := 0; i < f.Len(); i++ { + elem := f.Index(i) + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + pv, _ := ref.PrimaryKey.ValueOf(v) + ref.ForeignKey.Set(elem, pv) + } else if ref.PrimaryValue != "" { + ref.ForeignKey.Set(elem, ref.PrimaryValue) + } + } + + if isPtr { + elems = reflect.Append(elems, elem) + } else { + elems = reflect.Append(elems, elem.Addr()) + } + } + } + } + + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + for i := 0; i < db.Statement.ReflectValue.Len(); i++ { + obj := db.Statement.ReflectValue.Index(i) + if reflect.Indirect(obj).Kind() == reflect.Struct { + appendToElems(obj) + } + } + case reflect.Struct: + appendToElems(db.Statement.ReflectValue) + } + + if elems.Len() > 0 { + assignmentColumns := make([]string, 0, len(rel.References)) + for _, ref := range rel.References { + assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) + } + + saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, assignmentColumns) + } + } + + // Save Many2Many associations + for _, rel := range db.Statement.Schema.Relationships.Many2Many { + if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) { + continue + } + + fieldType := rel.Field.IndirectFieldType.Elem() + isPtr := fieldType.Kind() == reflect.Ptr + if !isPtr { + fieldType = reflect.PtrTo(fieldType) + } + elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) + joins := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(rel.JoinTable.ModelType)), 0, 10) + objs := []reflect.Value{} + + appendToJoins := func(obj reflect.Value, elem reflect.Value) { + joinValue := reflect.New(rel.JoinTable.ModelType) + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + fv, _ := ref.PrimaryKey.ValueOf(obj) + ref.ForeignKey.Set(joinValue, fv) + } else if ref.PrimaryValue != "" { + ref.ForeignKey.Set(joinValue, ref.PrimaryValue) + } else { + fv, _ := ref.PrimaryKey.ValueOf(elem) + ref.ForeignKey.Set(joinValue, fv) + } + } + joins = reflect.Append(joins, joinValue) + } + + appendToElems := func(v reflect.Value) { + if _, zero := rel.Field.ValueOf(v); !zero { + f := reflect.Indirect(rel.Field.ReflectValueOf(v)) + + for i := 0; i < f.Len(); i++ { + elem := f.Index(i) + + objs = append(objs, v) + if isPtr { + elems = reflect.Append(elems, elem) + } else { + elems = reflect.Append(elems, elem.Addr()) + } + } + } + } + + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + for i := 0; i < db.Statement.ReflectValue.Len(); i++ { + obj := db.Statement.ReflectValue.Index(i) + if reflect.Indirect(obj).Kind() == reflect.Struct { + appendToElems(obj) + } + } + case reflect.Struct: + appendToElems(db.Statement.ReflectValue) + } + + // optimize elems of reflect value length + if elemLen := elems.Len(); elemLen > 0 { + if v, ok := selectColumns[rel.Name+".*"]; !ok || v { + 
saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, nil) + } + + for i := 0; i < elemLen; i++ { + appendToJoins(objs[i], elems.Index(i)) + } + } + + if joins.Len() > 0 { + db.AddError(db.Session(&gorm.Session{NewDB: true}).Clauses(clause.OnConflict{DoNothing: true}).Session(&gorm.Session{ + SkipHooks: db.Statement.SkipHooks, + DisableNestedTransaction: true, + }).Create(joins.Interface()).Error) + } + } + } + } +} + +func onConflictOption(stmt *gorm.Statement, s *schema.Schema, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) clause.OnConflict { + if stmt.DB.FullSaveAssociations { + defaultUpdatingColumns = make([]string, 0, len(s.DBNames)) + for _, dbName := range s.DBNames { + if v, ok := selectColumns[dbName]; (ok && !v) || (!ok && restricted) { + continue + } + + if !s.LookUpField(dbName).PrimaryKey { + defaultUpdatingColumns = append(defaultUpdatingColumns, dbName) + } + } + } + + if len(defaultUpdatingColumns) > 0 { + columns := make([]clause.Column, 0, len(s.PrimaryFieldDBNames)) + for _, dbName := range s.PrimaryFieldDBNames { + columns = append(columns, clause.Column{Name: dbName}) + } + + return clause.OnConflict{ + Columns: columns, + DoUpdates: clause.AssignmentColumns(defaultUpdatingColumns), + } + } + + return clause.OnConflict{DoNothing: true} +} + +func saveAssociations(db *gorm.DB, rel *schema.Relationship, values interface{}, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) error { + var ( + selects, omits []string + onConflict = onConflictOption(db.Statement, rel.FieldSchema, selectColumns, restricted, defaultUpdatingColumns) + refName = rel.Name + "." + ) + + for name, ok := range selectColumns { + columnName := "" + if strings.HasPrefix(name, refName) { + columnName = strings.TrimPrefix(name, refName) + } + + if columnName != "" { + if ok { + selects = append(selects, columnName) + } else { + omits = append(omits, columnName) + } + } + } + + tx := db.Session(&gorm.Session{NewDB: true}).Clauses(onConflict).Session(&gorm.Session{ + FullSaveAssociations: db.FullSaveAssociations, + SkipHooks: db.Statement.SkipHooks, + DisableNestedTransaction: true, + }) + + db.Statement.Settings.Range(func(k, v interface{}) bool { + tx.Statement.Settings.Store(k, v) + return true + }) + + if tx.Statement.FullSaveAssociations { + tx = tx.InstanceSet("gorm:update_track_time", true) + } + + if len(selects) > 0 { + tx = tx.Select(selects) + } else if restricted && len(omits) == 0 { + tx = tx.Omit(clause.Associations) + } + + if len(omits) > 0 { + tx = tx.Omit(omits...) 
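`saveAssociations` picks between `ON CONFLICT DO NOTHING` and a column-by-column upsert depending on `FullSaveAssociations`, and forwards `Select`/`Omit` entries prefixed with the relation name. A usage sketch (the `User`/`Company` models are illustrative assumptions):

```go
package examples

import "gorm.io/gorm"

type Company struct {
	ID   uint
	Name string
}

type User struct {
	ID        uint
	Name      string
	CompanyID uint
	Company   Company
}

func saveUser(db *gorm.DB, user *User) error {
	// Default: the associated Company row is inserted with ON CONFLICT DO NOTHING,
	// so an existing company keeps its current column values.
	if err := db.Create(user).Error; err != nil {
		return err
	}
	// FullSaveAssociations: the Company row is upserted column by column,
	// matching the onConflictOption branch above.
	return db.Session(&gorm.Session{FullSaveAssociations: true}).Save(user).Error
}
```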
+ } + + return db.AddError(tx.Create(values).Error) +} diff --git a/callbacks/callbacks.go b/callbacks/callbacks.go new file mode 100644 index 00000000..d85c1928 --- /dev/null +++ b/callbacks/callbacks.go @@ -0,0 +1,83 @@ +package callbacks + +import ( + "gorm.io/gorm" +) + +var ( + createClauses = []string{"INSERT", "VALUES", "ON CONFLICT"} + queryClauses = []string{"SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR"} + updateClauses = []string{"UPDATE", "SET", "WHERE"} + deleteClauses = []string{"DELETE", "FROM", "WHERE"} +) + +type Config struct { + LastInsertIDReversed bool + WithReturning bool + CreateClauses []string + QueryClauses []string + UpdateClauses []string + DeleteClauses []string +} + +func RegisterDefaultCallbacks(db *gorm.DB, config *Config) { + enableTransaction := func(db *gorm.DB) bool { + return !db.SkipDefaultTransaction + } + + createCallback := db.Callback().Create() + createCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction) + createCallback.Register("gorm:before_create", BeforeCreate) + createCallback.Register("gorm:save_before_associations", SaveBeforeAssociations(true)) + createCallback.Register("gorm:create", Create(config)) + createCallback.Register("gorm:save_after_associations", SaveAfterAssociations(true)) + createCallback.Register("gorm:after_create", AfterCreate) + createCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) + if len(config.CreateClauses) == 0 { + config.CreateClauses = createClauses + } + createCallback.Clauses = config.CreateClauses + + queryCallback := db.Callback().Query() + queryCallback.Register("gorm:query", Query) + queryCallback.Register("gorm:preload", Preload) + queryCallback.Register("gorm:after_query", AfterQuery) + if len(config.QueryClauses) == 0 { + config.QueryClauses = queryClauses + } + queryCallback.Clauses = config.QueryClauses + + deleteCallback := db.Callback().Delete() + deleteCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction) + deleteCallback.Register("gorm:before_delete", BeforeDelete) + deleteCallback.Register("gorm:delete_before_associations", DeleteBeforeAssociations) + deleteCallback.Register("gorm:delete", Delete) + deleteCallback.Register("gorm:after_delete", AfterDelete) + deleteCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) + if len(config.DeleteClauses) == 0 { + config.DeleteClauses = deleteClauses + } + deleteCallback.Clauses = config.DeleteClauses + + updateCallback := db.Callback().Update() + updateCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction) + updateCallback.Register("gorm:setup_reflect_value", SetupUpdateReflectValue) + updateCallback.Register("gorm:before_update", BeforeUpdate) + updateCallback.Register("gorm:save_before_associations", SaveBeforeAssociations(false)) + updateCallback.Register("gorm:update", Update) + updateCallback.Register("gorm:save_after_associations", SaveAfterAssociations(false)) + updateCallback.Register("gorm:after_update", AfterUpdate) + updateCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) + if len(config.UpdateClauses) == 0 { + config.UpdateClauses = updateClauses + } + updateCallback.Clauses = config.UpdateClauses + + rowCallback := db.Callback().Row() + rowCallback.Register("gorm:row", RowQuery) + rowCallback.Clauses = config.QueryClauses + + rawCallback := 
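`RegisterDefaultCallbacks` is intended to be called from a `Dialector`'s `Initialize`. A minimal sketch showing only the callback wiring (the `Dialector` struct and its remaining `gorm.Dialector` methods are assumed to exist elsewhere):

```go
package mydriver

import (
	"gorm.io/gorm"
	"gorm.io/gorm/callbacks"
)

type Dialector struct {
	DSN string
}

// Initialize installs the default create/query/update/delete/row/raw
// processors defined above; WithReturning switches Create to CreateWithReturning.
func (d Dialector) Initialize(db *gorm.DB) error {
	callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{WithReturning: false})
	// ... the real driver would also set up the connection pool, clause builders, etc.
	return nil
}
```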
db.Callback().Raw() + rawCallback.Register("gorm:raw", RawExec) + rawCallback.Clauses = config.QueryClauses +} diff --git a/callbacks/callmethod.go b/callbacks/callmethod.go new file mode 100644 index 00000000..bcaa03f3 --- /dev/null +++ b/callbacks/callmethod.go @@ -0,0 +1,23 @@ +package callbacks + +import ( + "reflect" + + "gorm.io/gorm" +) + +func callMethod(db *gorm.DB, fc func(value interface{}, tx *gorm.DB) bool) { + tx := db.Session(&gorm.Session{NewDB: true}) + if called := fc(db.Statement.ReflectValue.Interface(), tx); !called { + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + db.Statement.CurDestIndex = 0 + for i := 0; i < db.Statement.ReflectValue.Len(); i++ { + fc(reflect.Indirect(db.Statement.ReflectValue.Index(i)).Addr().Interface(), tx) + db.Statement.CurDestIndex++ + } + case reflect.Struct: + fc(db.Statement.ReflectValue.Addr().Interface(), tx) + } + } +} diff --git a/callbacks/create.go b/callbacks/create.go new file mode 100644 index 00000000..727bd380 --- /dev/null +++ b/callbacks/create.go @@ -0,0 +1,370 @@ +package callbacks + +import ( + "fmt" + "reflect" + + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" +) + +func BeforeCreate(db *gorm.DB) { + if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) { + callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { + if db.Statement.Schema.BeforeSave { + if i, ok := value.(BeforeSaveInterface); ok { + called = true + db.AddError(i.BeforeSave(tx)) + } + } + + if db.Statement.Schema.BeforeCreate { + if i, ok := value.(BeforeCreateInterface); ok { + called = true + db.AddError(i.BeforeCreate(tx)) + } + } + return called + }) + } +} + +func Create(config *Config) func(db *gorm.DB) { + if config.WithReturning { + return CreateWithReturning + } else { + return func(db *gorm.DB) { + if db.Error == nil { + if db.Statement.Schema != nil && !db.Statement.Unscoped { + for _, c := range db.Statement.Schema.CreateClauses { + db.Statement.AddClause(c) + } + } + + if db.Statement.SQL.String() == "" { + db.Statement.SQL.Grow(180) + db.Statement.AddClauseIfNotExists(clause.Insert{}) + db.Statement.AddClause(ConvertToCreateValues(db.Statement)) + + db.Statement.Build(db.Statement.BuildClauses...) + } + + if !db.DryRun && db.Error == nil { + result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) 
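`BeforeCreate`/`AfterCreate` only dispatch when the schema flags say a hook exists and the value implements the matching interface; `callMethod` handles calling it once per slice element. A sketch of a model-level hook (the `User` model and its validation rule are illustrative):

```go
package examples

import (
	"errors"

	"gorm.io/gorm"
)

type User struct {
	ID   uint
	Name string
}

// BeforeCreate satisfies BeforeCreateInterface, so the gorm:before_create
// callback above invokes it (once per element for slice creates).
func (u *User) BeforeCreate(tx *gorm.DB) error {
	if u.Name == "" {
		return errors.New("name must not be empty")
	}
	return nil
}
```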
+ + if err == nil { + db.RowsAffected, _ = result.RowsAffected() + + if db.RowsAffected > 0 { + if db.Statement.Schema != nil && db.Statement.Schema.PrioritizedPrimaryField != nil && db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue { + if insertID, err := result.LastInsertId(); err == nil && insertID > 0 { + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + if config.LastInsertIDReversed { + for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- { + rv := db.Statement.ReflectValue.Index(i) + if reflect.Indirect(rv).Kind() != reflect.Struct { + break + } + + _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv) + if isZero { + db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID) + insertID -= db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement + } + } + } else { + for i := 0; i < db.Statement.ReflectValue.Len(); i++ { + rv := db.Statement.ReflectValue.Index(i) + if reflect.Indirect(rv).Kind() != reflect.Struct { + break + } + + if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv); isZero { + db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID) + insertID += db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement + } + } + } + case reflect.Struct: + if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue); isZero { + db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID) + } + } + } else { + db.AddError(err) + } + } + } + } else { + db.AddError(err) + } + } + } + } + } +} + +func CreateWithReturning(db *gorm.DB) { + if db.Error == nil { + if db.Statement.Schema != nil && !db.Statement.Unscoped { + for _, c := range db.Statement.Schema.CreateClauses { + db.Statement.AddClause(c) + } + } + + if db.Statement.SQL.String() == "" { + db.Statement.AddClauseIfNotExists(clause.Insert{}) + db.Statement.AddClause(ConvertToCreateValues(db.Statement)) + + db.Statement.Build(db.Statement.BuildClauses...) + } + + if sch := db.Statement.Schema; sch != nil && len(sch.FieldsWithDefaultDBValue) > 0 { + db.Statement.WriteString(" RETURNING ") + + var ( + fields = make([]*schema.Field, len(sch.FieldsWithDefaultDBValue)) + values = make([]interface{}, len(sch.FieldsWithDefaultDBValue)) + ) + + for idx, field := range sch.FieldsWithDefaultDBValue { + if idx > 0 { + db.Statement.WriteByte(',') + } + + fields[idx] = field + db.Statement.WriteQuoted(field.DBName) + } + + if !db.DryRun && db.Error == nil { + db.RowsAffected = 0 + rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) 
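`Create` back-fills primary keys after the insert, walking the slice forwards or backwards from `LastInsertId` depending on `LastInsertIDReversed`, while `CreateWithReturning` scans them from a `RETURNING` clause instead. From the caller's side (the `User` model is illustrative):

```go
package examples

import "gorm.io/gorm"

type User struct {
	ID   uint
	Name string
}

func batchCreate(db *gorm.DB) ([]User, error) {
	users := []User{{Name: "jinzhu"}, {Name: "jinzhu2"}}
	if err := db.Create(&users).Error; err != nil {
		return nil, err
	}
	// users[0].ID and users[1].ID are now populated, either from LastInsertId
	// (auto-increment dialects) or from RETURNING (WithReturning dialects).
	return users, nil
}
```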
+ + if err == nil { + defer rows.Close() + + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + c := db.Statement.Clauses["ON CONFLICT"] + onConflict, _ := c.Expression.(clause.OnConflict) + + for rows.Next() { + BEGIN: + reflectValue := db.Statement.ReflectValue.Index(int(db.RowsAffected)) + if reflect.Indirect(reflectValue).Kind() != reflect.Struct { + break + } + + for idx, field := range fields { + fieldValue := field.ReflectValueOf(reflectValue) + + if onConflict.DoNothing && !fieldValue.IsZero() { + db.RowsAffected++ + + if int(db.RowsAffected) >= db.Statement.ReflectValue.Len() { + return + } + + goto BEGIN + } + + values[idx] = fieldValue.Addr().Interface() + } + + db.RowsAffected++ + if err := rows.Scan(values...); err != nil { + db.AddError(err) + } + } + case reflect.Struct: + for idx, field := range fields { + values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface() + } + + if rows.Next() { + db.RowsAffected++ + db.AddError(rows.Scan(values...)) + } + } + } else { + db.AddError(err) + } + } + } else if !db.DryRun && db.Error == nil { + if result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); err == nil { + db.RowsAffected, _ = result.RowsAffected() + } else { + db.AddError(err) + } + } + } +} + +func AfterCreate(db *gorm.DB) { + if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) { + callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { + if db.Statement.Schema.AfterSave { + if i, ok := value.(AfterSaveInterface); ok { + called = true + db.AddError(i.AfterSave(tx)) + } + } + + if db.Statement.Schema.AfterCreate { + if i, ok := value.(AfterCreateInterface); ok { + called = true + db.AddError(i.AfterCreate(tx)) + } + } + return called + }) + } +} + +// ConvertToCreateValues convert to create values +func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { + switch value := stmt.Dest.(type) { + case map[string]interface{}: + values = ConvertMapToValuesForCreate(stmt, value) + case *map[string]interface{}: + values = ConvertMapToValuesForCreate(stmt, *value) + case []map[string]interface{}: + values = ConvertSliceOfMapToValuesForCreate(stmt, value) + case *[]map[string]interface{}: + values = ConvertSliceOfMapToValuesForCreate(stmt, *value) + default: + var ( + selectColumns, restricted = stmt.SelectAndOmitColumns(true, false) + curTime = stmt.DB.NowFunc() + isZero bool + ) + values = clause.Values{Columns: make([]clause.Column, 0, len(stmt.Schema.DBNames))} + + for _, db := range stmt.Schema.DBNames { + if field := stmt.Schema.FieldsByDBName[db]; !field.HasDefaultValue || field.DefaultValueInterface != nil { + if v, ok := selectColumns[db]; (ok && v) || (!ok && (!restricted || field.AutoCreateTime > 0 || field.AutoUpdateTime > 0)) { + values.Columns = append(values.Columns, clause.Column{Name: db}) + } + } + } + + switch stmt.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + stmt.SQL.Grow(stmt.ReflectValue.Len() * 18) + values.Values = make([][]interface{}, stmt.ReflectValue.Len()) + defaultValueFieldsHavingValue := map[*schema.Field][]interface{}{} + if stmt.ReflectValue.Len() == 0 { + stmt.AddError(gorm.ErrEmptySlice) + return + } + + for i := 0; i < stmt.ReflectValue.Len(); i++ { + rv := reflect.Indirect(stmt.ReflectValue.Index(i)) + if !rv.IsValid() { + stmt.AddError(fmt.Errorf("slice data #%v is invalid: %w", i, 
gorm.ErrInvalidData)) + return + } + + values.Values[i] = make([]interface{}, len(values.Columns)) + for idx, column := range values.Columns { + field := stmt.Schema.FieldsByDBName[column.Name] + if values.Values[i][idx], isZero = field.ValueOf(rv); isZero { + if field.DefaultValueInterface != nil { + values.Values[i][idx] = field.DefaultValueInterface + field.Set(rv, field.DefaultValueInterface) + } else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 { + field.Set(rv, curTime) + values.Values[i][idx], _ = field.ValueOf(rv) + } + } else if field.AutoUpdateTime > 0 { + if _, ok := stmt.DB.InstanceGet("gorm:update_track_time"); ok { + field.Set(rv, curTime) + values.Values[i][idx], _ = field.ValueOf(rv) + } + } + } + + for _, field := range stmt.Schema.FieldsWithDefaultDBValue { + if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { + if v, isZero := field.ValueOf(rv); !isZero { + if len(defaultValueFieldsHavingValue[field]) == 0 { + defaultValueFieldsHavingValue[field] = make([]interface{}, stmt.ReflectValue.Len()) + } + defaultValueFieldsHavingValue[field][i] = v + } + } + } + } + + for field, vs := range defaultValueFieldsHavingValue { + values.Columns = append(values.Columns, clause.Column{Name: field.DBName}) + for idx := range values.Values { + if vs[idx] == nil { + values.Values[idx] = append(values.Values[idx], stmt.Dialector.DefaultValueOf(field)) + } else { + values.Values[idx] = append(values.Values[idx], vs[idx]) + } + } + } + case reflect.Struct: + values.Values = [][]interface{}{make([]interface{}, len(values.Columns))} + for idx, column := range values.Columns { + field := stmt.Schema.FieldsByDBName[column.Name] + if values.Values[0][idx], isZero = field.ValueOf(stmt.ReflectValue); isZero { + if field.DefaultValueInterface != nil { + values.Values[0][idx] = field.DefaultValueInterface + field.Set(stmt.ReflectValue, field.DefaultValueInterface) + } else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 { + field.Set(stmt.ReflectValue, curTime) + values.Values[0][idx], _ = field.ValueOf(stmt.ReflectValue) + } + } else if field.AutoUpdateTime > 0 { + if _, ok := stmt.DB.InstanceGet("gorm:update_track_time"); ok { + field.Set(stmt.ReflectValue, curTime) + values.Values[0][idx], _ = field.ValueOf(stmt.ReflectValue) + } + } + } + + for _, field := range stmt.Schema.FieldsWithDefaultDBValue { + if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { + if v, isZero := field.ValueOf(stmt.ReflectValue); !isZero { + values.Columns = append(values.Columns, clause.Column{Name: field.DBName}) + values.Values[0] = append(values.Values[0], v) + } + } + } + default: + stmt.AddError(gorm.ErrInvalidData) + } + } + + if c, ok := stmt.Clauses["ON CONFLICT"]; ok { + if onConflict, _ := c.Expression.(clause.OnConflict); onConflict.UpdateAll { + if stmt.Schema != nil && len(values.Columns) > 1 { + columns := make([]string, 0, len(values.Columns)-1) + for _, column := range values.Columns { + if field := stmt.Schema.LookUpField(column.Name); field != nil { + if !field.PrimaryKey && (!field.HasDefaultValue || field.DefaultValueInterface != nil) && field.AutoCreateTime == 0 { + columns = append(columns, column.Name) + } + } + } + + onConflict.DoUpdates = clause.AssignmentColumns(columns) + + // use primary fields as default OnConflict columns + if len(onConflict.Columns) == 0 { + for _, field := range stmt.Schema.PrimaryFields { + onConflict.Columns = append(onConflict.Columns, clause.Column{Name: field.DBName}) + } + } + 
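When an `ON CONFLICT` clause with `UpdateAll` is present, `ConvertToCreateValues` rewrites it into assignments for every non-primary, non-default column and defaults the conflict target to the primary key. Roughly how that is used (the `User` model is illustrative; the generated SQL varies by dialect):

```go
package examples

import (
	"gorm.io/gorm"
	"gorm.io/gorm/clause"
)

type User struct {
	ID   uint
	Name string
	Age  int
}

func upsertUser(db *gorm.DB, user *User) error {
	// Expanded by ConvertToCreateValues into something like
	// ON CONFLICT ("id") DO UPDATE SET "name"=excluded."name","age"=excluded."age".
	return db.Clauses(clause.OnConflict{UpdateAll: true}).Create(user).Error
}
```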
stmt.AddClause(onConflict) + } + } + } + + return values +} diff --git a/callbacks/delete.go b/callbacks/delete.go new file mode 100644 index 00000000..91659c51 --- /dev/null +++ b/callbacks/delete.go @@ -0,0 +1,168 @@ +package callbacks + +import ( + "reflect" + "strings" + + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" +) + +func BeforeDelete(db *gorm.DB) { + if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.BeforeDelete { + callMethod(db, func(value interface{}, tx *gorm.DB) bool { + if i, ok := value.(BeforeDeleteInterface); ok { + db.AddError(i.BeforeDelete(tx)) + return true + } + + return false + }) + } +} + +func DeleteBeforeAssociations(db *gorm.DB) { + if db.Error == nil && db.Statement.Schema != nil { + selectColumns, restricted := db.Statement.SelectAndOmitColumns(true, false) + + if restricted { + for column, v := range selectColumns { + if v { + if rel, ok := db.Statement.Schema.Relationships.Relations[column]; ok { + switch rel.Type { + case schema.HasOne, schema.HasMany: + queryConds := rel.ToQueryConditions(db.Statement.ReflectValue) + modelValue := reflect.New(rel.FieldSchema.ModelType).Interface() + tx := db.Session(&gorm.Session{NewDB: true}).Model(modelValue) + withoutConditions := false + if db.Statement.Unscoped { + tx = tx.Unscoped() + } + + if len(db.Statement.Selects) > 0 { + selects := make([]string, 0, len(db.Statement.Selects)) + for _, s := range db.Statement.Selects { + if s == clause.Associations { + selects = append(selects, s) + } else if strings.HasPrefix(s, column+".") { + selects = append(selects, strings.TrimPrefix(s, column+".")) + } + } + + if len(selects) > 0 { + tx = tx.Select(selects) + } + } + + for _, cond := range queryConds { + if c, ok := cond.(clause.IN); ok && len(c.Values) == 0 { + withoutConditions = true + break + } + } + + if !withoutConditions { + if db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil { + return + } + } + case schema.Many2Many: + var ( + queryConds = make([]clause.Expression, 0, len(rel.References)) + foreignFields = make([]*schema.Field, 0, len(rel.References)) + relForeignKeys = make([]string, 0, len(rel.References)) + modelValue = reflect.New(rel.JoinTable.ModelType).Interface() + table = rel.JoinTable.Table + tx = db.Session(&gorm.Session{NewDB: true}).Model(modelValue).Table(table) + ) + + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + foreignFields = append(foreignFields, ref.PrimaryKey) + relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName) + } else if ref.PrimaryValue != "" { + queryConds = append(queryConds, clause.Eq{ + Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName}, + Value: ref.PrimaryValue, + }) + } + } + + _, foreignValues := schema.GetIdentityFieldValuesMap(db.Statement.ReflectValue, foreignFields) + column, values := schema.ToQueryValues(table, relForeignKeys, foreignValues) + queryConds = append(queryConds, clause.IN{Column: column, Values: values}) + + if db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil { + return + } + } + } + } + } + } + } +} + +func Delete(db *gorm.DB) { + if db.Error == nil { + if db.Statement.Schema != nil && !db.Statement.Unscoped { + for _, c := range db.Statement.Schema.DeleteClauses { + db.Statement.AddClause(c) + } + } + + if db.Statement.SQL.String() == "" { + db.Statement.SQL.Grow(100) + db.Statement.AddClauseIfNotExists(clause.Delete{}) + + if db.Statement.Schema != nil { 
+ _, queryValues := schema.GetIdentityFieldValuesMap(db.Statement.ReflectValue, db.Statement.Schema.PrimaryFields) + column, values := schema.ToQueryValues(db.Statement.Table, db.Statement.Schema.PrimaryFieldDBNames, queryValues) + + if len(values) > 0 { + db.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}}) + } + + if db.Statement.ReflectValue.CanAddr() && db.Statement.Dest != db.Statement.Model && db.Statement.Model != nil { + _, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(db.Statement.Model), db.Statement.Schema.PrimaryFields) + column, values = schema.ToQueryValues(db.Statement.Table, db.Statement.Schema.PrimaryFieldDBNames, queryValues) + + if len(values) > 0 { + db.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}}) + } + } + } + + db.Statement.AddClauseIfNotExists(clause.From{}) + db.Statement.Build(db.Statement.BuildClauses...) + } + + if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok && db.Error == nil { + db.AddError(gorm.ErrMissingWhereClause) + return + } + + if !db.DryRun && db.Error == nil { + result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) + + if err == nil { + db.RowsAffected, _ = result.RowsAffected() + } else { + db.AddError(err) + } + } + } +} + +func AfterDelete(db *gorm.DB) { + if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.AfterDelete { + callMethod(db, func(value interface{}, tx *gorm.DB) bool { + if i, ok := value.(AfterDeleteInterface); ok { + db.AddError(i.AfterDelete(tx)) + return true + } + return false + }) + } +} diff --git a/callbacks/helper.go b/callbacks/helper.go new file mode 100644 index 00000000..ad85a1c6 --- /dev/null +++ b/callbacks/helper.go @@ -0,0 +1,95 @@ +package callbacks + +import ( + "sort" + + "gorm.io/gorm" + "gorm.io/gorm/clause" +) + +// ConvertMapToValuesForCreate convert map to values +func ConvertMapToValuesForCreate(stmt *gorm.Statement, mapValue map[string]interface{}) (values clause.Values) { + values.Columns = make([]clause.Column, 0, len(mapValue)) + selectColumns, restricted := stmt.SelectAndOmitColumns(true, false) + + var keys = make([]string, 0, len(mapValue)) + for k := range mapValue { + keys = append(keys, k) + } + sort.Strings(keys) + + for _, k := range keys { + value := mapValue[k] + if stmt.Schema != nil { + if field := stmt.Schema.LookUpField(k); field != nil { + k = field.DBName + } + } + + if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) { + values.Columns = append(values.Columns, clause.Column{Name: k}) + if len(values.Values) == 0 { + values.Values = [][]interface{}{{}} + } + + values.Values[0] = append(values.Values[0], value) + } + } + return +} + +// ConvertSliceOfMapToValuesForCreate convert slice of map to values +func ConvertSliceOfMapToValuesForCreate(stmt *gorm.Statement, mapValues []map[string]interface{}) (values clause.Values) { + var ( + columns = make([]string, 0, len(mapValues)) + ) + + // when the length of mapValues,return directly here + // no need to call stmt.SelectAndOmitColumns method + if len(mapValues) == 0 { + stmt.AddError(gorm.ErrEmptySlice) + return + } + + var ( + result = make(map[string][]interface{}, len(mapValues)) + selectColumns, restricted = stmt.SelectAndOmitColumns(true, false) + ) + + for idx, mapValue := range mapValues { + for k, v := range mapValue { + if stmt.Schema != nil { + if field := 
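`Delete` builds its WHERE clause from the primary key values of the destination (and of the model, when they differ) and refuses a global delete unless `AllowGlobalUpdate` is set. A sketch of the three cases (the `User` model is illustrative):

```go
package examples

import (
	"errors"

	"gorm.io/gorm"
)

type User struct {
	ID   uint
	Name string
}

func deleteUsers(db *gorm.DB) error {
	// Primary key taken from the value: DELETE FROM users WHERE id = 1.
	if err := db.Delete(&User{ID: 1}).Error; err != nil {
		return err
	}

	// No conditions at all: rejected with ErrMissingWhereClause.
	if err := db.Delete(&User{}).Error; !errors.Is(err, gorm.ErrMissingWhereClause) {
		return err
	}

	// Explicitly opted-in global delete.
	return db.Session(&gorm.Session{AllowGlobalUpdate: true}).Delete(&User{}).Error
}
```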
stmt.Schema.LookUpField(k); field != nil { + k = field.DBName + } + } + + if _, ok := result[k]; !ok { + if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) { + result[k] = make([]interface{}, len(mapValues)) + columns = append(columns, k) + } else { + continue + } + } + + result[k][idx] = v + } + } + + sort.Strings(columns) + values.Values = make([][]interface{}, len(mapValues)) + values.Columns = make([]clause.Column, len(columns)) + for idx, column := range columns { + values.Columns[idx] = clause.Column{Name: column} + + for i, v := range result[column] { + if len(values.Values[i]) == 0 { + values.Values[i] = make([]interface{}, len(columns)) + } + + values.Values[i][idx] = v + } + } + return +} diff --git a/callbacks/interfaces.go b/callbacks/interfaces.go new file mode 100644 index 00000000..2302470f --- /dev/null +++ b/callbacks/interfaces.go @@ -0,0 +1,39 @@ +package callbacks + +import "gorm.io/gorm" + +type BeforeCreateInterface interface { + BeforeCreate(*gorm.DB) error +} + +type AfterCreateInterface interface { + AfterCreate(*gorm.DB) error +} + +type BeforeUpdateInterface interface { + BeforeUpdate(*gorm.DB) error +} + +type AfterUpdateInterface interface { + AfterUpdate(*gorm.DB) error +} + +type BeforeSaveInterface interface { + BeforeSave(*gorm.DB) error +} + +type AfterSaveInterface interface { + AfterSave(*gorm.DB) error +} + +type BeforeDeleteInterface interface { + BeforeDelete(*gorm.DB) error +} + +type AfterDeleteInterface interface { + AfterDelete(*gorm.DB) error +} + +type AfterFindInterface interface { + AfterFind(*gorm.DB) error +} diff --git a/callbacks/preload.go b/callbacks/preload.go new file mode 100644 index 00000000..25c5e659 --- /dev/null +++ b/callbacks/preload.go @@ -0,0 +1,164 @@ +package callbacks + +import ( + "reflect" + + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" + "gorm.io/gorm/utils" +) + +func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) { + var ( + reflectValue = db.Statement.ReflectValue + tx = db.Session(&gorm.Session{NewDB: true}).Model(nil).Session(&gorm.Session{SkipHooks: db.Statement.SkipHooks}) + relForeignKeys []string + relForeignFields []*schema.Field + foreignFields []*schema.Field + foreignValues [][]interface{} + identityMap = map[string][]reflect.Value{} + inlineConds []interface{} + ) + + db.Statement.Settings.Range(func(k, v interface{}) bool { + tx.Statement.Settings.Store(k, v) + return true + }) + + if rel.JoinTable != nil { + var ( + joinForeignFields = make([]*schema.Field, 0, len(rel.References)) + joinRelForeignFields = make([]*schema.Field, 0, len(rel.References)) + joinForeignKeys = make([]string, 0, len(rel.References)) + ) + + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + joinForeignKeys = append(joinForeignKeys, ref.ForeignKey.DBName) + joinForeignFields = append(joinForeignFields, ref.ForeignKey) + foreignFields = append(foreignFields, ref.PrimaryKey) + } else if ref.PrimaryValue != "" { + tx = tx.Where(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue}) + } else { + joinRelForeignFields = append(joinRelForeignFields, ref.ForeignKey) + relForeignKeys = append(relForeignKeys, ref.PrimaryKey.DBName) + relForeignFields = append(relForeignFields, ref.PrimaryKey) + } + } + + joinIdentityMap, joinForeignValues := schema.GetIdentityFieldValuesMap(reflectValue, foreignFields) + if len(joinForeignValues) == 0 { + return + } + + joinResults := rel.JoinTable.MakeSlice().Elem() + column, values := 
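These two helpers are what make map-based creates work: keys are resolved to column names through `LookUpField` and filtered by `Select`/`Omit`. A usage sketch (the `users` table and its columns are illustrative assumptions):

```go
package examples

import "gorm.io/gorm"

func createFromMaps(db *gorm.DB) error {
	// Single map: handled by ConvertMapToValuesForCreate.
	err := db.Table("users").Create(map[string]interface{}{
		"name": "jinzhu",
		"age":  18,
	}).Error
	if err != nil {
		return err
	}

	// Slice of maps: handled by ConvertSliceOfMapToValuesForCreate; columns are
	// the sorted union of keys, and rows missing a key get a nil value.
	return db.Table("users").Create([]map[string]interface{}{
		{"name": "jinzhu2", "age": 19},
		{"name": "jinzhu3"},
	}).Error
}
```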
schema.ToQueryValues(clause.CurrentTable, joinForeignKeys, joinForeignValues) + db.AddError(tx.Where(clause.IN{Column: column, Values: values}).Find(joinResults.Addr().Interface()).Error) + + // convert join identity map to relation identity map + fieldValues := make([]interface{}, len(joinForeignFields)) + joinFieldValues := make([]interface{}, len(joinRelForeignFields)) + for i := 0; i < joinResults.Len(); i++ { + for idx, field := range joinForeignFields { + fieldValues[idx], _ = field.ValueOf(joinResults.Index(i)) + } + + for idx, field := range joinRelForeignFields { + joinFieldValues[idx], _ = field.ValueOf(joinResults.Index(i)) + } + + if results, ok := joinIdentityMap[utils.ToStringKey(fieldValues...)]; ok { + joinKey := utils.ToStringKey(joinFieldValues...) + identityMap[joinKey] = append(identityMap[joinKey], results...) + } + } + + _, foreignValues = schema.GetIdentityFieldValuesMap(joinResults, joinRelForeignFields) + } else { + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName) + relForeignFields = append(relForeignFields, ref.ForeignKey) + foreignFields = append(foreignFields, ref.PrimaryKey) + } else if ref.PrimaryValue != "" { + tx = tx.Where(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue}) + } else { + relForeignKeys = append(relForeignKeys, ref.PrimaryKey.DBName) + relForeignFields = append(relForeignFields, ref.PrimaryKey) + foreignFields = append(foreignFields, ref.ForeignKey) + } + } + + identityMap, foreignValues = schema.GetIdentityFieldValuesMap(reflectValue, foreignFields) + if len(foreignValues) == 0 { + return + } + } + + // nested preload + for p, pvs := range preloads { + tx = tx.Preload(p, pvs...) + } + + reflectResults := rel.FieldSchema.MakeSlice().Elem() + column, values := schema.ToQueryValues(clause.CurrentTable, relForeignKeys, foreignValues) + + for _, cond := range conds { + if fc, ok := cond.(func(*gorm.DB) *gorm.DB); ok { + tx = fc(tx) + } else { + inlineConds = append(inlineConds, cond) + } + } + + db.AddError(tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), inlineConds...).Error) + + fieldValues := make([]interface{}, len(relForeignFields)) + + // clean up old values before preloading + switch reflectValue.Kind() { + case reflect.Struct: + switch rel.Type { + case schema.HasMany, schema.Many2Many: + rel.Field.Set(reflectValue, reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface()) + default: + rel.Field.Set(reflectValue, reflect.New(rel.Field.FieldType).Interface()) + } + case reflect.Slice, reflect.Array: + for i := 0; i < reflectValue.Len(); i++ { + switch rel.Type { + case schema.HasMany, schema.Many2Many: + rel.Field.Set(reflectValue.Index(i), reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface()) + default: + rel.Field.Set(reflectValue.Index(i), reflect.New(rel.Field.FieldType).Interface()) + } + } + } + + for i := 0; i < reflectResults.Len(); i++ { + elem := reflectResults.Index(i) + for idx, field := range relForeignFields { + fieldValues[idx], _ = field.ValueOf(elem) + } + + for _, data := range identityMap[utils.ToStringKey(fieldValues...)] { + reflectFieldValue := rel.Field.ReflectValueOf(data) + if reflectFieldValue.Kind() == reflect.Ptr && reflectFieldValue.IsNil() { + reflectFieldValue.Set(reflect.New(rel.Field.FieldType.Elem())) + } + + reflectFieldValue = reflect.Indirect(reflectFieldValue) + switch reflectFieldValue.Kind() { + case reflect.Struct: + rel.Field.Set(data, 
reflectResults.Index(i).Interface()) + case reflect.Slice, reflect.Array: + if reflectFieldValue.Type().Elem().Kind() == reflect.Ptr { + rel.Field.Set(data, reflect.Append(reflectFieldValue, elem).Interface()) + } else { + rel.Field.Set(data, reflect.Append(reflectFieldValue, elem.Elem()).Interface()) + } + } + } + } +} diff --git a/callbacks/query.go b/callbacks/query.go new file mode 100644 index 00000000..d0341284 --- /dev/null +++ b/callbacks/query.go @@ -0,0 +1,228 @@ +package callbacks + +import ( + "fmt" + "reflect" + "sort" + "strings" + + "gorm.io/gorm" + "gorm.io/gorm/clause" +) + +func Query(db *gorm.DB) { + if db.Error == nil { + BuildQuerySQL(db) + + if !db.DryRun && db.Error == nil { + rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) + if err != nil { + db.AddError(err) + return + } + defer rows.Close() + + gorm.Scan(rows, db, false) + } + } +} + +func BuildQuerySQL(db *gorm.DB) { + if db.Statement.Schema != nil && !db.Statement.Unscoped { + for _, c := range db.Statement.Schema.QueryClauses { + db.Statement.AddClause(c) + } + } + + if db.Statement.SQL.String() == "" { + db.Statement.SQL.Grow(100) + clauseSelect := clause.Select{Distinct: db.Statement.Distinct} + + if db.Statement.ReflectValue.Kind() == reflect.Struct && db.Statement.ReflectValue.Type() == db.Statement.Schema.ModelType { + var conds []clause.Expression + for _, primaryField := range db.Statement.Schema.PrimaryFields { + if v, isZero := primaryField.ValueOf(db.Statement.ReflectValue); !isZero { + conds = append(conds, clause.Eq{Column: clause.Column{Table: db.Statement.Table, Name: primaryField.DBName}, Value: v}) + } + } + + if len(conds) > 0 { + db.Statement.AddClause(clause.Where{Exprs: conds}) + } + } + + if len(db.Statement.Selects) > 0 { + clauseSelect.Columns = make([]clause.Column, len(db.Statement.Selects)) + for idx, name := range db.Statement.Selects { + if db.Statement.Schema == nil { + clauseSelect.Columns[idx] = clause.Column{Name: name, Raw: true} + } else if f := db.Statement.Schema.LookUpField(name); f != nil { + clauseSelect.Columns[idx] = clause.Column{Name: f.DBName} + } else { + clauseSelect.Columns[idx] = clause.Column{Name: name, Raw: true} + } + } + } else if db.Statement.Schema != nil && len(db.Statement.Omits) > 0 { + selectColumns, _ := db.Statement.SelectAndOmitColumns(false, false) + clauseSelect.Columns = make([]clause.Column, 0, len(db.Statement.Schema.DBNames)) + for _, dbName := range db.Statement.Schema.DBNames { + if v, ok := selectColumns[dbName]; (ok && v) || !ok { + clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{Table: db.Statement.Table, Name: dbName}) + } + } + } else if db.Statement.Schema != nil && db.Statement.ReflectValue.IsValid() { + queryFields := db.QueryFields + if !queryFields { + switch db.Statement.ReflectValue.Kind() { + case reflect.Struct: + queryFields = db.Statement.ReflectValue.Type() != db.Statement.Schema.ModelType + case reflect.Slice: + queryFields = db.Statement.ReflectValue.Type().Elem() != db.Statement.Schema.ModelType + } + } + + if queryFields { + stmt := gorm.Statement{DB: db} + // smaller struct + if err := stmt.Parse(db.Statement.Dest); err == nil && (db.QueryFields || stmt.Schema.ModelType != db.Statement.Schema.ModelType) { + clauseSelect.Columns = make([]clause.Column, len(stmt.Schema.DBNames)) + + for idx, dbName := range stmt.Schema.DBNames { + clauseSelect.Columns[idx] = clause.Column{Table: db.Statement.Table, Name: dbName} + } + } + } + } + + // 
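The `preload` helper above loads related rows in a separate query, keyed by an identity map, and covers belongs-to, has-one, has-many and many2many alike. From the caller's side (the `User`/`Order` models are illustrative):

```go
package examples

import (
	"gorm.io/gorm"
	"gorm.io/gorm/clause"
)

type Order struct {
	ID     uint
	UserID uint
	Amount int
}

type User struct {
	ID     uint
	Name   string
	Orders []Order
}

func loadUsers(db *gorm.DB) ([]User, error) {
	var users []User
	// "Orders" is resolved via Relationships.Relations; a dotted path such as
	// "Orders.Items" would be carried along as a nested preload.
	if err := db.Preload("Orders").Find(&users).Error; err != nil {
		return nil, err
	}

	// clause.Associations preloads every direct relation of User.
	err := db.Preload(clause.Associations).Find(&users).Error
	return users, err
}
```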
inline joins + if len(db.Statement.Joins) != 0 { + if len(db.Statement.Selects) == 0 && db.Statement.Schema != nil { + clauseSelect.Columns = make([]clause.Column, len(db.Statement.Schema.DBNames)) + for idx, dbName := range db.Statement.Schema.DBNames { + clauseSelect.Columns[idx] = clause.Column{Table: db.Statement.Table, Name: dbName} + } + } + + joins := []clause.Join{} + + if fromClause, ok := db.Statement.Clauses["FROM"].Expression.(clause.From); ok { + joins = fromClause.Joins + } + + for _, join := range db.Statement.Joins { + if db.Statement.Schema == nil { + joins = append(joins, clause.Join{ + Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds}, + }) + } else if relation, ok := db.Statement.Schema.Relationships.Relations[join.Name]; ok { + tableAliasName := relation.Name + + for _, s := range relation.FieldSchema.DBNames { + clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ + Table: tableAliasName, + Name: s, + Alias: tableAliasName + "__" + s, + }) + } + + exprs := make([]clause.Expression, len(relation.References)) + for idx, ref := range relation.References { + if ref.OwnPrimaryKey { + exprs[idx] = clause.Eq{ + Column: clause.Column{Table: clause.CurrentTable, Name: ref.PrimaryKey.DBName}, + Value: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName}, + } + } else { + if ref.PrimaryValue == "" { + exprs[idx] = clause.Eq{ + Column: clause.Column{Table: clause.CurrentTable, Name: ref.ForeignKey.DBName}, + Value: clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName}, + } + } else { + exprs[idx] = clause.Eq{ + Column: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName}, + Value: ref.PrimaryValue, + } + } + } + } + + joins = append(joins, clause.Join{ + Type: clause.LeftJoin, + Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName}, + ON: clause.Where{Exprs: exprs}, + }) + } else { + joins = append(joins, clause.Join{ + Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds}, + }) + } + } + + db.Statement.Joins = nil + db.Statement.AddClause(clause.From{Joins: joins}) + } else { + db.Statement.AddClauseIfNotExists(clause.From{}) + } + + db.Statement.AddClauseIfNotExists(clauseSelect) + + db.Statement.Build(db.Statement.BuildClauses...) 
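// Illustrative sketch, not taken from this change: assuming a model `User` that
// belongs to a `Company` (hypothetical names), the relation branch above turns
//
//	db.Joins("Company").Find(&users)
//
// into a single query along the lines of
//
//	SELECT users.id, users.name, ..., Company.id AS Company__id, Company.name AS Company__name
//	FROM users LEFT JOIN companies Company ON users.company_id = Company.id
//
// with the exact column list and quoting depending on the parsed schema and dialect.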
+ } +} + +func Preload(db *gorm.DB) { + if db.Error == nil && len(db.Statement.Preloads) > 0 { + preloadMap := map[string]map[string][]interface{}{} + for name := range db.Statement.Preloads { + preloadFields := strings.Split(name, ".") + if preloadFields[0] == clause.Associations { + for _, rel := range db.Statement.Schema.Relationships.Relations { + if rel.Schema == db.Statement.Schema { + if _, ok := preloadMap[rel.Name]; !ok { + preloadMap[rel.Name] = map[string][]interface{}{} + } + + if value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), "."); value != "" { + preloadMap[rel.Name][value] = db.Statement.Preloads[name] + } + } + } + } else { + if _, ok := preloadMap[preloadFields[0]]; !ok { + preloadMap[preloadFields[0]] = map[string][]interface{}{} + } + + if value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), "."); value != "" { + preloadMap[preloadFields[0]][value] = db.Statement.Preloads[name] + } + } + } + + preloadNames := make([]string, 0, len(preloadMap)) + for key := range preloadMap { + preloadNames = append(preloadNames, key) + } + sort.Strings(preloadNames) + + for _, name := range preloadNames { + if rel := db.Statement.Schema.Relationships.Relations[name]; rel != nil { + preload(db, rel, db.Statement.Preloads[name], preloadMap[name]) + } else { + db.AddError(fmt.Errorf("%v: %w for schema %v", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name)) + } + } + } +} + +func AfterQuery(db *gorm.DB) { + if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.AfterFind && db.RowsAffected > 0 { + callMethod(db, func(value interface{}, tx *gorm.DB) bool { + if i, ok := value.(AfterFindInterface); ok { + db.AddError(i.AfterFind(tx)) + return true + } + return false + }) + } +} diff --git a/callbacks/raw.go b/callbacks/raw.go new file mode 100644 index 00000000..d594ab39 --- /dev/null +++ b/callbacks/raw.go @@ -0,0 +1,16 @@ +package callbacks + +import ( + "gorm.io/gorm" +) + +func RawExec(db *gorm.DB) { + if db.Error == nil && !db.DryRun { + result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) + if err != nil { + db.AddError(err) + } else { + db.RowsAffected, _ = result.RowsAffected() + } + } +} diff --git a/callbacks/row.go b/callbacks/row.go new file mode 100644 index 00000000..10e880e1 --- /dev/null +++ b/callbacks/row.go @@ -0,0 +1,21 @@ +package callbacks + +import ( + "gorm.io/gorm" +) + +func RowQuery(db *gorm.DB) { + if db.Error == nil { + BuildQuerySQL(db) + + if !db.DryRun { + if isRows, ok := db.InstanceGet("rows"); ok && isRows.(bool) { + db.Statement.Dest, db.Error = db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) + } else { + db.Statement.Dest = db.Statement.ConnPool.QueryRowContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) 
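// Illustrative usage sketch (assumed caller, not part of this change): the two
// branches above back Rows() and Row() respectively, e.g.
//
//	var name string
//	var age int
//	row := db.Table("users").Where("name = ?", "jinzhu").Select("name", "age").Row()
//	row.Scan(&name, &age)
//
// RowsAffected is set to -1 below since row queries do not report an affected-row count.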
+ } + + db.RowsAffected = -1 + } + } +} diff --git a/callbacks/transaction.go b/callbacks/transaction.go new file mode 100644 index 00000000..8ba2ba3b --- /dev/null +++ b/callbacks/transaction.go @@ -0,0 +1,31 @@ +package callbacks + +import ( + "gorm.io/gorm" +) + +func BeginTransaction(db *gorm.DB) { + if !db.Config.SkipDefaultTransaction { + if tx := db.Begin(); tx.Error == nil { + db.Statement.ConnPool = tx.Statement.ConnPool + db.InstanceSet("gorm:started_transaction", true) + } else if tx.Error == gorm.ErrInvalidTransaction { + tx.Error = nil + } else { + db.Error = tx.Error + } + } +} + +func CommitOrRollbackTransaction(db *gorm.DB) { + if !db.Config.SkipDefaultTransaction { + if _, ok := db.InstanceGet("gorm:started_transaction"); ok { + if db.Error == nil { + db.Commit() + } else { + db.Rollback() + } + db.Statement.ConnPool = db.ConnPool + } + } +} diff --git a/callbacks/update.go b/callbacks/update.go new file mode 100644 index 00000000..75bb02db --- /dev/null +++ b/callbacks/update.go @@ -0,0 +1,261 @@ +package callbacks + +import ( + "reflect" + "sort" + + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" +) + +func SetupUpdateReflectValue(db *gorm.DB) { + if db.Error == nil && db.Statement.Schema != nil { + if !db.Statement.ReflectValue.CanAddr() || db.Statement.Model != db.Statement.Dest { + db.Statement.ReflectValue = reflect.ValueOf(db.Statement.Model) + for db.Statement.ReflectValue.Kind() == reflect.Ptr { + db.Statement.ReflectValue = db.Statement.ReflectValue.Elem() + } + + if dest, ok := db.Statement.Dest.(map[string]interface{}); ok { + for _, rel := range db.Statement.Schema.Relationships.BelongsTo { + if _, ok := dest[rel.Name]; ok { + rel.Field.Set(db.Statement.ReflectValue, dest[rel.Name]) + } + } + } + } + } +} + +func BeforeUpdate(db *gorm.DB) { + if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) { + callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { + if db.Statement.Schema.BeforeSave { + if i, ok := value.(BeforeSaveInterface); ok { + called = true + db.AddError(i.BeforeSave(tx)) + } + } + + if db.Statement.Schema.BeforeUpdate { + if i, ok := value.(BeforeUpdateInterface); ok { + called = true + db.AddError(i.BeforeUpdate(tx)) + } + } + + return called + }) + } +} + +func Update(db *gorm.DB) { + if db.Error == nil { + if db.Statement.Schema != nil && !db.Statement.Unscoped { + for _, c := range db.Statement.Schema.UpdateClauses { + db.Statement.AddClause(c) + } + } + + if db.Statement.SQL.String() == "" { + db.Statement.SQL.Grow(180) + db.Statement.AddClauseIfNotExists(clause.Update{}) + if set := ConvertToAssignments(db.Statement); len(set) != 0 { + db.Statement.AddClause(set) + } else { + return + } + db.Statement.Build(db.Statement.BuildClauses...) + } + + if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok { + db.AddError(gorm.ErrMissingWhereClause) + return + } + + if !db.DryRun && db.Error == nil { + result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) 
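// Illustrative sketch (assumed usage, not part of this change): with the WHERE guard
// above, a conditionless update such as
//
//	db.Model(&User{}).Update("name", "hello")
//
// stops with ErrMissingWhereClause unless AllowGlobalUpdate is set, while
//
//	db.Model(&user).Update("name", "hello")
//
// uses the non-zero primary key of `user` as the condition and is executed here.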
+ + if err == nil { + db.RowsAffected, _ = result.RowsAffected() + } else { + db.AddError(err) + } + } + } +} + +func AfterUpdate(db *gorm.DB) { + if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) { + callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { + if db.Statement.Schema.AfterSave { + if i, ok := value.(AfterSaveInterface); ok { + called = true + db.AddError(i.AfterSave(tx)) + } + } + + if db.Statement.Schema.AfterUpdate { + if i, ok := value.(AfterUpdateInterface); ok { + called = true + db.AddError(i.AfterUpdate(tx)) + } + } + return called + }) + } +} + +// ConvertToAssignments convert to update assignments +func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { + var ( + selectColumns, restricted = stmt.SelectAndOmitColumns(false, true) + assignValue func(field *schema.Field, value interface{}) + ) + + switch stmt.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + assignValue = func(field *schema.Field, value interface{}) { + for i := 0; i < stmt.ReflectValue.Len(); i++ { + field.Set(stmt.ReflectValue.Index(i), value) + } + } + case reflect.Struct: + assignValue = func(field *schema.Field, value interface{}) { + if stmt.ReflectValue.CanAddr() { + field.Set(stmt.ReflectValue, value) + } + } + default: + assignValue = func(field *schema.Field, value interface{}) { + } + } + + updatingValue := reflect.ValueOf(stmt.Dest) + for updatingValue.Kind() == reflect.Ptr { + updatingValue = updatingValue.Elem() + } + + if !updatingValue.CanAddr() || stmt.Dest != stmt.Model { + switch stmt.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + var primaryKeyExprs []clause.Expression + for i := 0; i < stmt.ReflectValue.Len(); i++ { + var exprs = make([]clause.Expression, len(stmt.Schema.PrimaryFields)) + var notZero bool + for idx, field := range stmt.Schema.PrimaryFields { + value, isZero := field.ValueOf(stmt.ReflectValue.Index(i)) + exprs[idx] = clause.Eq{Column: field.DBName, Value: value} + notZero = notZero || !isZero + } + if notZero { + primaryKeyExprs = append(primaryKeyExprs, clause.And(exprs...)) + } + } + stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(primaryKeyExprs...)}}) + case reflect.Struct: + for _, field := range stmt.Schema.PrimaryFields { + if value, isZero := field.ValueOf(stmt.ReflectValue); !isZero { + stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}}) + } + } + } + } + + switch value := updatingValue.Interface().(type) { + case map[string]interface{}: + set = make([]clause.Assignment, 0, len(value)) + + keys := make([]string, 0, len(value)) + for k := range value { + keys = append(keys, k) + } + sort.Strings(keys) + + for _, k := range keys { + kv := value[k] + if _, ok := kv.(*gorm.DB); ok { + kv = []interface{}{kv} + } + + if stmt.Schema != nil { + if field := stmt.Schema.LookUpField(k); field != nil { + if field.DBName != "" { + if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { + set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: kv}) + assignValue(field, value[k]) + } + } else if v, ok := selectColumns[field.Name]; (ok && v) || (!ok && !restricted) { + assignValue(field, value[k]) + } + continue + } + } + + if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) { + set = append(set, clause.Assignment{Column: clause.Column{Name: k}, Value: kv}) + } + } + + if !stmt.SkipHooks && stmt.Schema 
!= nil { + for _, dbName := range stmt.Schema.DBNames { + field := stmt.Schema.LookUpField(dbName) + if field.AutoUpdateTime > 0 && value[field.Name] == nil && value[field.DBName] == nil { + if v, ok := selectColumns[field.DBName]; (ok && v) || !ok { + now := stmt.DB.NowFunc() + assignValue(field, now) + + if field.AutoUpdateTime == schema.UnixNanosecond { + set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano()}) + } else if field.AutoUpdateTime == schema.UnixMillisecond { + set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano() / 1e6}) + } else if field.GORMDataType == schema.Time { + set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now}) + } else { + set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.Unix()}) + } + } + } + } + } + default: + switch updatingValue.Kind() { + case reflect.Struct: + set = make([]clause.Assignment, 0, len(stmt.Schema.FieldsByDBName)) + for _, dbName := range stmt.Schema.DBNames { + field := stmt.Schema.LookUpField(dbName) + if !field.PrimaryKey || (!updatingValue.CanAddr() || stmt.Dest != stmt.Model) { + if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && (!restricted || (!stmt.SkipHooks && field.AutoUpdateTime > 0))) { + value, isZero := field.ValueOf(updatingValue) + if !stmt.SkipHooks && field.AutoUpdateTime > 0 { + if field.AutoUpdateTime == schema.UnixNanosecond { + value = stmt.DB.NowFunc().UnixNano() + } else if field.AutoUpdateTime == schema.UnixMillisecond { + value = stmt.DB.NowFunc().UnixNano() / 1e6 + } else if field.GORMDataType == schema.Time { + value = stmt.DB.NowFunc() + } else { + value = stmt.DB.NowFunc().Unix() + } + isZero = false + } + + if ok || !isZero { + set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value}) + assignValue(field, value) + } + } + } else { + if value, isZero := field.ValueOf(updatingValue); !isZero { + stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}}) + } + } + } + default: + stmt.AddError(gorm.ErrInvalidData) + } + } + + return +} diff --git a/callbacks_test.go b/callbacks_test.go deleted file mode 100644 index a58913d7..00000000 --- a/callbacks_test.go +++ /dev/null @@ -1,177 +0,0 @@ -package gorm_test - -import ( - "errors" - - "github.com/jinzhu/gorm" - - "reflect" - "testing" -) - -func (s *Product) BeforeCreate() (err error) { - if s.Code == "Invalid" { - err = errors.New("invalid product") - } - s.BeforeCreateCallTimes = s.BeforeCreateCallTimes + 1 - return -} - -func (s *Product) BeforeUpdate() (err error) { - if s.Code == "dont_update" { - err = errors.New("can't update") - } - s.BeforeUpdateCallTimes = s.BeforeUpdateCallTimes + 1 - return -} - -func (s *Product) BeforeSave() (err error) { - if s.Code == "dont_save" { - err = errors.New("can't save") - } - s.BeforeSaveCallTimes = s.BeforeSaveCallTimes + 1 - return -} - -func (s *Product) AfterFind() { - s.AfterFindCallTimes = s.AfterFindCallTimes + 1 -} - -func (s *Product) AfterCreate(tx *gorm.DB) { - tx.Model(s).UpdateColumn(Product{AfterCreateCallTimes: s.AfterCreateCallTimes + 1}) -} - -func (s *Product) AfterUpdate() { - s.AfterUpdateCallTimes = s.AfterUpdateCallTimes + 1 -} - -func (s *Product) AfterSave() (err error) { - if s.Code == "after_save_error" { - err = errors.New("can't save") - } - s.AfterSaveCallTimes = s.AfterSaveCallTimes + 1 - return -} - -func (s *Product) BeforeDelete() 
(err error) { - if s.Code == "dont_delete" { - err = errors.New("can't delete") - } - s.BeforeDeleteCallTimes = s.BeforeDeleteCallTimes + 1 - return -} - -func (s *Product) AfterDelete() (err error) { - if s.Code == "after_delete_error" { - err = errors.New("can't delete") - } - s.AfterDeleteCallTimes = s.AfterDeleteCallTimes + 1 - return -} - -func (s *Product) GetCallTimes() []int64 { - return []int64{s.BeforeCreateCallTimes, s.BeforeSaveCallTimes, s.BeforeUpdateCallTimes, s.AfterCreateCallTimes, s.AfterSaveCallTimes, s.AfterUpdateCallTimes, s.BeforeDeleteCallTimes, s.AfterDeleteCallTimes, s.AfterFindCallTimes} -} - -func TestRunCallbacks(t *testing.T) { - p := Product{Code: "unique_code", Price: 100} - DB.Save(&p) - - if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 1, 0, 0, 0, 0}) { - t.Errorf("Callbacks should be invoked successfully, %v", p.GetCallTimes()) - } - - DB.Where("Code = ?", "unique_code").First(&p) - if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 0, 0, 0, 0, 1}) { - t.Errorf("After callbacks values are not saved, %v", p.GetCallTimes()) - } - - p.Price = 200 - DB.Save(&p) - if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 1, 1, 0, 0, 1}) { - t.Errorf("After update callbacks should be invoked successfully, %v", p.GetCallTimes()) - } - - var products []Product - DB.Find(&products, "code = ?", "unique_code") - if products[0].AfterFindCallTimes != 2 { - t.Errorf("AfterFind callbacks should work with slice") - } - - DB.Where("Code = ?", "unique_code").First(&p) - if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 0, 0, 2}) { - t.Errorf("After update callbacks values are not saved, %v", p.GetCallTimes()) - } - - DB.Delete(&p) - if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 1, 1, 2}) { - t.Errorf("After delete callbacks should be invoked successfully, %v", p.GetCallTimes()) - } - - if DB.Where("Code = ?", "unique_code").First(&p).Error == nil { - t.Errorf("Can't find a deleted record") - } -} - -func TestCallbacksWithErrors(t *testing.T) { - p := Product{Code: "Invalid", Price: 100} - if DB.Save(&p).Error == nil { - t.Errorf("An error from before create callbacks happened when create with invalid value") - } - - if DB.Where("code = ?", "Invalid").First(&Product{}).Error == nil { - t.Errorf("Should not save record that have errors") - } - - if DB.Save(&Product{Code: "dont_save", Price: 100}).Error == nil { - t.Errorf("An error from after create callbacks happened when create with invalid value") - } - - p2 := Product{Code: "update_callback", Price: 100} - DB.Save(&p2) - - p2.Code = "dont_update" - if DB.Save(&p2).Error == nil { - t.Errorf("An error from before update callbacks happened when update with invalid value") - } - - if DB.Where("code = ?", "update_callback").First(&Product{}).Error != nil { - t.Errorf("Record Should not be updated due to errors happened in before update callback") - } - - if DB.Where("code = ?", "dont_update").First(&Product{}).Error == nil { - t.Errorf("Record Should not be updated due to errors happened in before update callback") - } - - p2.Code = "dont_save" - if DB.Save(&p2).Error == nil { - t.Errorf("An error from before save callbacks happened when update with invalid value") - } - - p3 := Product{Code: "dont_delete", Price: 100} - DB.Save(&p3) - if DB.Delete(&p3).Error == nil { - t.Errorf("An error from before delete callbacks happened when delete") - } - - if DB.Where("Code = ?", "dont_delete").First(&p3).Error != nil { - t.Errorf("An error from before delete callbacks 
happened") - } - - p4 := Product{Code: "after_save_error", Price: 100} - DB.Save(&p4) - if err := DB.First(&Product{}, "code = ?", "after_save_error").Error; err == nil { - t.Errorf("Record should be reverted if get an error in after save callback") - } - - p5 := Product{Code: "after_delete_error", Price: 100} - DB.Save(&p5) - if err := DB.First(&Product{}, "code = ?", "after_delete_error").Error; err != nil { - t.Errorf("Record should be found") - } - - DB.Delete(&p5) - if err := DB.First(&Product{}, "code = ?", "after_delete_error").Error; err != nil { - t.Errorf("Record shouldn't be deleted because of an error happened in after delete callback") - } -} diff --git a/chainable_api.go b/chainable_api.go new file mode 100644 index 00000000..e17d9bb2 --- /dev/null +++ b/chainable_api.go @@ -0,0 +1,293 @@ +package gorm + +import ( + "fmt" + "regexp" + "strings" + + "gorm.io/gorm/clause" + "gorm.io/gorm/utils" +) + +// Model specify the model you would like to run db operations +// // update all users's name to `hello` +// db.Model(&User{}).Update("name", "hello") +// // if user's primary key is non-blank, will use it as condition, then will only update the user's name to `hello` +// db.Model(&user).Update("name", "hello") +func (db *DB) Model(value interface{}) (tx *DB) { + tx = db.getInstance() + tx.Statement.Model = value + return +} + +// Clauses Add clauses +func (db *DB) Clauses(conds ...clause.Expression) (tx *DB) { + tx = db.getInstance() + var whereConds []interface{} + + for _, cond := range conds { + if c, ok := cond.(clause.Interface); ok { + tx.Statement.AddClause(c) + } else if optimizer, ok := cond.(StatementModifier); ok { + optimizer.ModifyStatement(tx.Statement) + } else { + whereConds = append(whereConds, cond) + } + } + + if len(whereConds) > 0 { + tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondition(whereConds[0], whereConds[1:]...)}) + } + return +} + +var tableRegexp = regexp.MustCompile(`(?i).+? AS (\w+)\s*(?:$|,)`) + +// Table specify the table you would like to run db operations +func (db *DB) Table(name string, args ...interface{}) (tx *DB) { + tx = db.getInstance() + if strings.Contains(name, " ") || strings.Contains(name, "`") || len(args) > 0 { + tx.Statement.TableExpr = &clause.Expr{SQL: name, Vars: args} + if results := tableRegexp.FindStringSubmatch(name); len(results) == 2 { + tx.Statement.Table = results[1] + return + } + } else if tables := strings.Split(name, "."); len(tables) == 2 { + tx.Statement.TableExpr = &clause.Expr{SQL: tx.Statement.Quote(name)} + tx.Statement.Table = tables[1] + return + } + + tx.Statement.Table = name + return +} + +// Distinct specify distinct fields that you want querying +func (db *DB) Distinct(args ...interface{}) (tx *DB) { + tx = db.getInstance() + tx.Statement.Distinct = true + if len(args) > 0 { + tx = tx.Select(args[0], args[1:]...) + } + return +} + +// Select specify fields that you want when querying, creating, updating +func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) { + tx = db.getInstance() + + switch v := query.(type) { + case []string: + tx.Statement.Selects = v + + for _, arg := range args { + switch arg := arg.(type) { + case string: + tx.Statement.Selects = append(tx.Statement.Selects, arg) + case []string: + tx.Statement.Selects = append(tx.Statement.Selects, arg...) 
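// Illustrative usage sketch (assumed model and dialect, not part of this change):
//
//	db.Select([]string{"name", "age"}).Find(&users)
//	db.Select("name", "age").Find(&users) // handled by the string case below, same effect
//
// both reduce to roughly `SELECT name, age FROM users`; any argument other than
// string/[]string falls through to the default branch just below, which records an
// unsupported-arguments error.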
+ default: + tx.AddError(fmt.Errorf("unsupported select args %v %v", query, args)) + return + } + } + delete(tx.Statement.Clauses, "SELECT") + case string: + if strings.Count(v, "?") >= len(args) && len(args) > 0 { + tx.Statement.AddClause(clause.Select{ + Distinct: db.Statement.Distinct, + Expression: clause.Expr{SQL: v, Vars: args}, + }) + } else if strings.Count(v, "@") > 0 && len(args) > 0 { + tx.Statement.AddClause(clause.Select{ + Distinct: db.Statement.Distinct, + Expression: clause.NamedExpr{SQL: v, Vars: args}, + }) + } else { + tx.Statement.Selects = []string{v} + + for _, arg := range args { + switch arg := arg.(type) { + case string: + tx.Statement.Selects = append(tx.Statement.Selects, arg) + case []string: + tx.Statement.Selects = append(tx.Statement.Selects, arg...) + default: + tx.Statement.AddClause(clause.Select{ + Distinct: db.Statement.Distinct, + Expression: clause.Expr{SQL: v, Vars: args}, + }) + return + } + } + + delete(tx.Statement.Clauses, "SELECT") + } + default: + tx.AddError(fmt.Errorf("unsupported select args %v %v", query, args)) + } + + return +} + +// Omit specify fields that you want to ignore when creating, updating and querying +func (db *DB) Omit(columns ...string) (tx *DB) { + tx = db.getInstance() + + if len(columns) == 1 && strings.ContainsRune(columns[0], ',') { + tx.Statement.Omits = strings.FieldsFunc(columns[0], utils.IsValidDBNameChar) + } else { + tx.Statement.Omits = columns + } + return +} + +// Where add conditions +func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) { + tx = db.getInstance() + if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 { + tx.Statement.AddClause(clause.Where{Exprs: conds}) + } + return +} + +// Not add NOT conditions +func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) { + tx = db.getInstance() + if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 { + tx.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.Not(conds...)}}) + } + return +} + +// Or add OR conditions +func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) { + tx = db.getInstance() + if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 { + tx.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(clause.And(conds...))}}) + } + return +} + +// Joins specify Joins conditions +// db.Joins("Account").Find(&user) +// db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user) +func (db *DB) Joins(query string, args ...interface{}) (tx *DB) { + tx = db.getInstance() + tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args}) + return +} + +// Group specify the group method on the find +func (db *DB) Group(name string) (tx *DB) { + tx = db.getInstance() + + fields := strings.FieldsFunc(name, utils.IsValidDBNameChar) + tx.Statement.AddClause(clause.GroupBy{ + Columns: []clause.Column{{Name: name, Raw: len(fields) != 1}}, + }) + return +} + +// Having specify HAVING conditions for GROUP BY +func (db *DB) Having(query interface{}, args ...interface{}) (tx *DB) { + tx = db.getInstance() + tx.Statement.AddClause(clause.GroupBy{ + Having: tx.Statement.BuildCondition(query, args...), + }) + return +} + +// Order specify order when retrieve records from database +// db.Order("name DESC") +// db.Order(clause.OrderByColumn{Column: clause.Column{Name: "name"}, Desc: true}) +func (db *DB) Order(value interface{}) (tx *DB) { + tx = db.getInstance() + + switch v := 
value.(type) { + case clause.OrderByColumn: + tx.Statement.AddClause(clause.OrderBy{ + Columns: []clause.OrderByColumn{v}, + }) + default: + tx.Statement.AddClause(clause.OrderBy{ + Columns: []clause.OrderByColumn{{ + Column: clause.Column{Name: fmt.Sprint(value), Raw: true}, + }}, + }) + } + return +} + +// Limit specify the number of records to be retrieved +func (db *DB) Limit(limit int) (tx *DB) { + tx = db.getInstance() + tx.Statement.AddClause(clause.Limit{Limit: limit}) + return +} + +// Offset specify the number of records to skip before starting to return the records +func (db *DB) Offset(offset int) (tx *DB) { + tx = db.getInstance() + tx.Statement.AddClause(clause.Limit{Offset: offset}) + return +} + +// Scopes pass current database connection to arguments `func(DB) DB`, which could be used to add conditions dynamically +// func AmountGreaterThan1000(db *gorm.DB) *gorm.DB { +// return db.Where("amount > ?", 1000) +// } +// +// func OrderStatus(status []string) func (db *gorm.DB) *gorm.DB { +// return func (db *gorm.DB) *gorm.DB { +// return db.Scopes(AmountGreaterThan1000).Where("status in (?)", status) +// } +// } +// +// db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders) +func (db *DB) Scopes(funcs ...func(*DB) *DB) (tx *DB) { + tx = db.getInstance() + tx.Statement.scopes = append(tx.Statement.scopes, funcs...) + return tx +} + +// Preload preload associations with given conditions +// db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users) +func (db *DB) Preload(query string, args ...interface{}) (tx *DB) { + tx = db.getInstance() + if tx.Statement.Preloads == nil { + tx.Statement.Preloads = map[string][]interface{}{} + } + tx.Statement.Preloads[query] = args + return +} + +func (db *DB) Attrs(attrs ...interface{}) (tx *DB) { + tx = db.getInstance() + tx.Statement.attrs = attrs + return +} + +func (db *DB) Assign(attrs ...interface{}) (tx *DB) { + tx = db.getInstance() + tx.Statement.assigns = attrs + return +} + +func (db *DB) Unscoped() (tx *DB) { + tx = db.getInstance() + tx.Statement.Unscoped = true + return +} + +func (db *DB) Raw(sql string, values ...interface{}) (tx *DB) { + tx = db.getInstance() + tx.Statement.SQL = strings.Builder{} + + if strings.Contains(sql, "@") { + clause.NamedExpr{SQL: sql, Vars: values}.Build(tx.Statement) + } else { + clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement) + } + return +} diff --git a/clause/benchmarks_test.go b/clause/benchmarks_test.go new file mode 100644 index 00000000..88a238e3 --- /dev/null +++ b/clause/benchmarks_test.go @@ -0,0 +1,56 @@ +package clause_test + +import ( + "sync" + "testing" + + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" + "gorm.io/gorm/utils/tests" +) + +func BenchmarkSelect(b *testing.B) { + user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) + + for i := 0; i < b.N; i++ { + stmt := gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}} + clauses := []clause.Interface{clause.Select{}, clause.From{}, clause.Where{Exprs: []clause.Expression{clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}, clause.Or(clause.Neq{Column: "name", Value: "jinzhu"})}}} + + for _, clause := range clauses { + stmt.AddClause(clause) + } + + stmt.Build("SELECT", "FROM", "WHERE") + _ = stmt.SQL.String() + } +} + +func BenchmarkComplexSelect(b *testing.B) { + user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) + + for i := 0; i < b.N; i++ { + stmt 
:= gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}} + clauses := []clause.Interface{ + clause.Select{}, clause.From{}, + clause.Where{Exprs: []clause.Expression{ + clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, + clause.Gt{Column: "age", Value: 18}, + clause.Or(clause.Neq{Column: "name", Value: "jinzhu"}), + }}, + clause.Where{Exprs: []clause.Expression{ + clause.Or(clause.Gt{Column: "score", Value: 100}, clause.Like{Column: "name", Value: "%linus%"}), + }}, + clause.GroupBy{Columns: []clause.Column{{Name: "role"}}, Having: []clause.Expression{clause.Eq{"role", "admin"}}}, + clause.Limit{Limit: 10, Offset: 20}, + clause.OrderBy{Columns: []clause.OrderByColumn{{Column: clause.PrimaryColumn, Desc: true}}}, + } + + for _, clause := range clauses { + stmt.AddClause(clause) + } + + stmt.Build("SELECT", "FROM", "WHERE", "GROUP BY", "LIMIT", "ORDER BY") + _ = stmt.SQL.String() + } +} diff --git a/clause/clause.go b/clause/clause.go new file mode 100644 index 00000000..de19f2e3 --- /dev/null +++ b/clause/clause.go @@ -0,0 +1,88 @@ +package clause + +// Interface clause interface +type Interface interface { + Name() string + Build(Builder) + MergeClause(*Clause) +} + +// ClauseBuilder clause builder, allows to customize how to build clause +type ClauseBuilder func(Clause, Builder) + +type Writer interface { + WriteByte(byte) error + WriteString(string) (int, error) +} + +// Builder builder interface +type Builder interface { + Writer + WriteQuoted(field interface{}) + AddVar(Writer, ...interface{}) +} + +// Clause +type Clause struct { + Name string // WHERE + BeforeExpression Expression + AfterNameExpression Expression + AfterExpression Expression + Expression Expression + Builder ClauseBuilder +} + +// Build build clause +func (c Clause) Build(builder Builder) { + if c.Builder != nil { + c.Builder(c, builder) + } else if c.Expression != nil { + if c.BeforeExpression != nil { + c.BeforeExpression.Build(builder) + builder.WriteByte(' ') + } + + if c.Name != "" { + builder.WriteString(c.Name) + builder.WriteByte(' ') + } + + if c.AfterNameExpression != nil { + c.AfterNameExpression.Build(builder) + builder.WriteByte(' ') + } + + c.Expression.Build(builder) + + if c.AfterExpression != nil { + builder.WriteByte(' ') + c.AfterExpression.Build(builder) + } + } +} + +const ( + PrimaryKey string = "~~~py~~~" // primary key + CurrentTable string = "~~~ct~~~" // current table + Associations string = "~~~as~~~" // associations +) + +var ( + currentTable = Table{Name: CurrentTable} + PrimaryColumn = Column{Table: CurrentTable, Name: PrimaryKey} +) + +// Column quote with name +type Column struct { + Table string + Name string + Alias string + Raw bool +} + +// Table quote with name +type Table struct { + Name string + Alias string + Raw bool +} diff --git a/clause/clause_test.go b/clause/clause_test.go new file mode 100644 index 00000000..6239ff39 --- /dev/null +++ b/clause/clause_test.go @@ -0,0 +1,43 @@ +package clause_test + +import ( + "reflect" + "strings" + "sync" + "testing" + + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" + "gorm.io/gorm/utils/tests" +) + +var db, _ = gorm.Open(tests.DummyDialector{}, nil) + +func checkBuildClauses(t *testing.T, clauses []clause.Interface, result string, vars []interface{}) { + var ( + buildNames []string + buildNamesMap = map[string]bool{} + user, _ = schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) + stmt = gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: 
map[string]clause.Clause{}} + ) + + for _, c := range clauses { + if _, ok := buildNamesMap[c.Name()]; !ok { + buildNames = append(buildNames, c.Name()) + buildNamesMap[c.Name()] = true + } + + stmt.AddClause(c) + } + + stmt.Build(buildNames...) + + if strings.TrimSpace(stmt.SQL.String()) != result { + t.Errorf("SQL expects %v got %v", result, stmt.SQL.String()) + } + + if !reflect.DeepEqual(stmt.Vars, vars) { + t.Errorf("Vars expects %+v got %v", stmt.Vars, vars) + } +} diff --git a/clause/delete.go b/clause/delete.go new file mode 100644 index 00000000..fc462cd7 --- /dev/null +++ b/clause/delete.go @@ -0,0 +1,23 @@ +package clause + +type Delete struct { + Modifier string +} + +func (d Delete) Name() string { + return "DELETE" +} + +func (d Delete) Build(builder Builder) { + builder.WriteString("DELETE") + + if d.Modifier != "" { + builder.WriteByte(' ') + builder.WriteString(d.Modifier) + } +} + +func (d Delete) MergeClause(clause *Clause) { + clause.Name = "" + clause.Expression = d +} diff --git a/clause/delete_test.go b/clause/delete_test.go new file mode 100644 index 00000000..a9a659b3 --- /dev/null +++ b/clause/delete_test.go @@ -0,0 +1,31 @@ +package clause_test + +import ( + "fmt" + "testing" + + "gorm.io/gorm/clause" +) + +func TestDelete(t *testing.T) { + results := []struct { + Clauses []clause.Interface + Result string + Vars []interface{} + }{ + { + []clause.Interface{clause.Delete{}, clause.From{}}, + "DELETE FROM `users`", nil, + }, + { + []clause.Interface{clause.Delete{Modifier: "LOW_PRIORITY"}, clause.From{}}, + "DELETE LOW_PRIORITY FROM `users`", nil, + }, + } + + for idx, result := range results { + t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { + checkBuildClauses(t, result.Clauses, result.Result, result.Vars) + }) + } +} diff --git a/clause/expression.go b/clause/expression.go new file mode 100644 index 00000000..f76ce138 --- /dev/null +++ b/clause/expression.go @@ -0,0 +1,344 @@ +package clause + +import ( + "database/sql" + "database/sql/driver" + "go/ast" + "reflect" +) + +// Expression expression interface +type Expression interface { + Build(builder Builder) +} + +// NegationExpressionBuilder negation expression builder +type NegationExpressionBuilder interface { + NegationBuild(builder Builder) +} + +// Expr raw expression +type Expr struct { + SQL string + Vars []interface{} + WithoutParentheses bool +} + +// Build build raw expression +func (expr Expr) Build(builder Builder) { + var ( + afterParenthesis bool + idx int + ) + + for _, v := range []byte(expr.SQL) { + if v == '?' 
&& len(expr.Vars) > idx { + if afterParenthesis || expr.WithoutParentheses { + if _, ok := expr.Vars[idx].(driver.Valuer); ok { + builder.AddVar(builder, expr.Vars[idx]) + } else { + switch rv := reflect.ValueOf(expr.Vars[idx]); rv.Kind() { + case reflect.Slice, reflect.Array: + if rv.Len() == 0 { + builder.AddVar(builder, nil) + } else { + for i := 0; i < rv.Len(); i++ { + if i > 0 { + builder.WriteByte(',') + } + builder.AddVar(builder, rv.Index(i).Interface()) + } + } + default: + builder.AddVar(builder, expr.Vars[idx]) + } + } + } else { + builder.AddVar(builder, expr.Vars[idx]) + } + + idx++ + } else { + if v == '(' { + afterParenthesis = true + } else { + afterParenthesis = false + } + builder.WriteByte(v) + } + } +} + +// NamedExpr raw expression for named expr +type NamedExpr struct { + SQL string + Vars []interface{} +} + +// Build build raw expression +func (expr NamedExpr) Build(builder Builder) { + var ( + idx int + inName bool + afterParenthesis bool + namedMap = make(map[string]interface{}, len(expr.Vars)) + ) + + for _, v := range expr.Vars { + switch value := v.(type) { + case sql.NamedArg: + namedMap[value.Name] = value.Value + case map[string]interface{}: + for k, v := range value { + namedMap[k] = v + } + default: + var appendFieldsToMap func(reflect.Value) + appendFieldsToMap = func(reflectValue reflect.Value) { + reflectValue = reflect.Indirect(reflectValue) + switch reflectValue.Kind() { + case reflect.Struct: + modelType := reflectValue.Type() + for i := 0; i < modelType.NumField(); i++ { + if fieldStruct := modelType.Field(i); ast.IsExported(fieldStruct.Name) { + namedMap[fieldStruct.Name] = reflectValue.Field(i).Interface() + + if fieldStruct.Anonymous { + appendFieldsToMap(reflectValue.Field(i)) + } + } + } + } + } + + appendFieldsToMap(reflect.ValueOf(value)) + } + } + + name := make([]byte, 0, 10) + + for _, v := range []byte(expr.SQL) { + if v == '@' && !inName { + inName = true + name = []byte{} + } else if v == ' ' || v == ',' || v == ')' || v == '"' || v == '\'' || v == '`' || v == '\n' { + if inName { + if nv, ok := namedMap[string(name)]; ok { + builder.AddVar(builder, nv) + } else { + builder.WriteByte('@') + builder.WriteString(string(name)) + } + inName = false + } + + afterParenthesis = false + builder.WriteByte(v) + } else if v == '?' 
&& len(expr.Vars) > idx { + if afterParenthesis { + if _, ok := expr.Vars[idx].(driver.Valuer); ok { + builder.AddVar(builder, expr.Vars[idx]) + } else { + switch rv := reflect.ValueOf(expr.Vars[idx]); rv.Kind() { + case reflect.Slice, reflect.Array: + if rv.Len() == 0 { + builder.AddVar(builder, nil) + } else { + for i := 0; i < rv.Len(); i++ { + if i > 0 { + builder.WriteByte(',') + } + builder.AddVar(builder, rv.Index(i).Interface()) + } + } + default: + builder.AddVar(builder, expr.Vars[idx]) + } + } + } else { + builder.AddVar(builder, expr.Vars[idx]) + } + + idx++ + } else if inName { + name = append(name, v) + } else { + if v == '(' { + afterParenthesis = true + } else { + afterParenthesis = false + } + builder.WriteByte(v) + } + } + + if inName { + builder.AddVar(builder, namedMap[string(name)]) + } +} + +// IN Whether a value is within a set of values +type IN struct { + Column interface{} + Values []interface{} +} + +func (in IN) Build(builder Builder) { + builder.WriteQuoted(in.Column) + + switch len(in.Values) { + case 0: + builder.WriteString(" IN (NULL)") + case 1: + if _, ok := in.Values[0].([]interface{}); !ok { + builder.WriteString(" = ") + builder.AddVar(builder, in.Values[0]) + break + } + + fallthrough + default: + builder.WriteString(" IN (") + builder.AddVar(builder, in.Values...) + builder.WriteByte(')') + } +} + +func (in IN) NegationBuild(builder Builder) { + switch len(in.Values) { + case 0: + case 1: + if _, ok := in.Values[0].([]interface{}); !ok { + builder.WriteQuoted(in.Column) + builder.WriteString(" <> ") + builder.AddVar(builder, in.Values[0]) + break + } + + fallthrough + default: + builder.WriteQuoted(in.Column) + builder.WriteString(" NOT IN (") + builder.AddVar(builder, in.Values...) + builder.WriteByte(')') + } +} + +// Eq equal to for where +type Eq struct { + Column interface{} + Value interface{} +} + +func (eq Eq) Build(builder Builder) { + builder.WriteQuoted(eq.Column) + + if eqNil(eq.Value) { + builder.WriteString(" IS NULL") + } else { + builder.WriteString(" = ") + builder.AddVar(builder, eq.Value) + } +} + +func (eq Eq) NegationBuild(builder Builder) { + Neq(eq).Build(builder) +} + +// Neq not equal to for where +type Neq Eq + +func (neq Neq) Build(builder Builder) { + builder.WriteQuoted(neq.Column) + + if eqNil(neq.Value) { + builder.WriteString(" IS NOT NULL") + } else { + builder.WriteString(" <> ") + builder.AddVar(builder, neq.Value) + } +} + +func (neq Neq) NegationBuild(builder Builder) { + Eq(neq).Build(builder) +} + +// Gt greater than for where +type Gt Eq + +func (gt Gt) Build(builder Builder) { + builder.WriteQuoted(gt.Column) + builder.WriteString(" > ") + builder.AddVar(builder, gt.Value) +} + +func (gt Gt) NegationBuild(builder Builder) { + Lte(gt).Build(builder) +} + +// Gte greater than or equal to for where +type Gte Eq + +func (gte Gte) Build(builder Builder) { + builder.WriteQuoted(gte.Column) + builder.WriteString(" >= ") + builder.AddVar(builder, gte.Value) +} + +func (gte Gte) NegationBuild(builder Builder) { + Lt(gte).Build(builder) +} + +// Lt less than for where +type Lt Eq + +func (lt Lt) Build(builder Builder) { + builder.WriteQuoted(lt.Column) + builder.WriteString(" < ") + builder.AddVar(builder, lt.Value) +} + +func (lt Lt) NegationBuild(builder Builder) { + Gte(lt).Build(builder) +} + +// Lte less than or equal to for where +type Lte Eq + +func (lte Lte) Build(builder Builder) { + builder.WriteQuoted(lte.Column) + builder.WriteString(" <= ") + builder.AddVar(builder, lte.Value) +} + +func (lte Lte) 
NegationBuild(builder Builder) { + Gt(lte).Build(builder) +} + +// Like whether string matches regular expression +type Like Eq + +func (like Like) Build(builder Builder) { + builder.WriteQuoted(like.Column) + builder.WriteString(" LIKE ") + builder.AddVar(builder, like.Value) +} + +func (like Like) NegationBuild(builder Builder) { + builder.WriteQuoted(like.Column) + builder.WriteString(" NOT LIKE ") + builder.AddVar(builder, like.Value) +} + +func eqNil(value interface{}) bool { + if valuer, ok := value.(driver.Valuer); ok { + value, _ = valuer.Value() + } + + return value == nil || eqNilReflect(value) +} + +func eqNilReflect(value interface{}) bool { + reflectValue := reflect.ValueOf(value) + return reflectValue.Kind() == reflect.Ptr && reflectValue.IsNil() +} diff --git a/clause/expression_test.go b/clause/expression_test.go new file mode 100644 index 00000000..4472bdb1 --- /dev/null +++ b/clause/expression_test.go @@ -0,0 +1,153 @@ +package clause_test + +import ( + "database/sql" + "fmt" + "reflect" + "sync" + "testing" + + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" + "gorm.io/gorm/utils/tests" +) + +func TestExpr(t *testing.T) { + results := []struct { + SQL string + Result string + Vars []interface{} + }{{ + SQL: "create table ? (? ?, ? ?)", + Vars: []interface{}{clause.Table{Name: "users"}, clause.Column{Name: "id"}, clause.Expr{SQL: "int"}, clause.Column{Name: "name"}, clause.Expr{SQL: "text"}}, + Result: "create table `users` (`id` int, `name` text)", + }} + + for idx, result := range results { + t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { + user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) + stmt := &gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}} + clause.Expr{SQL: result.SQL, Vars: result.Vars}.Build(stmt) + if stmt.SQL.String() != result.Result { + t.Errorf("generated SQL is not equal, expects %v, but got %v", result.Result, stmt.SQL.String()) + } + }) + } +} + +func TestNamedExpr(t *testing.T) { + type Base struct { + Name2 string + } + + type NamedArgument struct { + Name1 string + Base + } + + results := []struct { + SQL string + Result string + Vars []interface{} + ExpectedVars []interface{} + }{{ + SQL: "create table ? (? ?, ? ?)", + Vars: []interface{}{clause.Table{Name: "users"}, clause.Column{Name: "id"}, clause.Expr{SQL: "int"}, clause.Column{Name: "name"}, clause.Expr{SQL: "text"}}, + Result: "create table `users` (`id` int, `name` text)", + }, { + SQL: "name1 = @name AND name2 = @name", + Vars: []interface{}{sql.Named("name", "jinzhu")}, + Result: "name1 = ? AND name2 = ?", + ExpectedVars: []interface{}{"jinzhu", "jinzhu"}, + }, { + SQL: "name1 = @name1 AND name2 = @name2 AND name3 = @name1", + Vars: []interface{}{sql.Named("name1", "jinzhu"), sql.Named("name2", "jinzhu2")}, + Result: "name1 = ? AND name2 = ? AND name3 = ?", + ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu"}, + }, { + SQL: "name1 = @name1 AND name2 = @name2 AND name3 = @name1", + Vars: []interface{}{map[string]interface{}{"name1": "jinzhu", "name2": "jinzhu2"}}, + Result: "name1 = ? AND name2 = ? AND name3 = ?", + ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu"}, + }, { + SQL: "@@test AND name1 = @name1 AND name2 = @name2 AND name3 = @name1 @notexist", + Vars: []interface{}{sql.Named("name1", "jinzhu"), sql.Named("name2", "jinzhu2")}, + Result: "@@test AND name1 = ? AND name2 = ? AND name3 = ? 
?", + ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu", nil}, + }, { + SQL: "@@test AND name1 = @Name1 AND name2 = @Name2 AND name3 = @Name1 @Notexist", + Vars: []interface{}{NamedArgument{Name1: "jinzhu", Base: Base{Name2: "jinzhu2"}}}, + Result: "@@test AND name1 = ? AND name2 = ? AND name3 = ? ?", + ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu", nil}, + }, { + SQL: "create table ? (? ?, ? ?)", + Vars: []interface{}{}, + Result: "create table ? (? ?, ? ?)", + }} + + for idx, result := range results { + t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { + user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) + stmt := &gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}} + clause.NamedExpr{SQL: result.SQL, Vars: result.Vars}.Build(stmt) + if stmt.SQL.String() != result.Result { + t.Errorf("generated SQL is not equal, expects %v, but got %v", result.Result, stmt.SQL.String()) + } + + if !reflect.DeepEqual(result.ExpectedVars, stmt.Vars) { + t.Errorf("generated vars is not equal, expects %v, but got %v", result.ExpectedVars, stmt.Vars) + } + }) + } +} + +func TestExpression(t *testing.T) { + column := "column-name" + results := []struct { + Expressions []clause.Expression + Result string + }{{ + Expressions: []clause.Expression{ + clause.Eq{Column: column, Value: "column-value"}, + }, + Result: "`column-name` = ?", + }, { + Expressions: []clause.Expression{ + clause.Eq{Column: column, Value: nil}, + clause.Eq{Column: column, Value: (*string)(nil)}, + clause.Eq{Column: column, Value: (*int)(nil)}, + clause.Eq{Column: column, Value: (*bool)(nil)}, + clause.Eq{Column: column, Value: (interface{})(nil)}, + clause.Eq{Column: column, Value: sql.NullString{String: "", Valid: false}}, + }, + Result: "`column-name` IS NULL", + }, { + Expressions: []clause.Expression{ + clause.Neq{Column: column, Value: "column-value"}, + }, + Result: "`column-name` <> ?", + }, { + Expressions: []clause.Expression{ + clause.Neq{Column: column, Value: nil}, + clause.Neq{Column: column, Value: (*string)(nil)}, + clause.Neq{Column: column, Value: (*int)(nil)}, + clause.Neq{Column: column, Value: (*bool)(nil)}, + clause.Neq{Column: column, Value: (interface{})(nil)}, + }, + Result: "`column-name` IS NOT NULL", + }} + + for idx, result := range results { + for idy, expression := range result.Expressions { + t.Run(fmt.Sprintf("case #%v.%v", idx, idy), func(t *testing.T) { + user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) + stmt := &gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}} + expression.Build(stmt) + if stmt.SQL.String() != result.Result { + t.Errorf("generated SQL is not equal, expects %v, but got %v", result.Result, stmt.SQL.String()) + } + }) + } + } +} diff --git a/clause/from.go b/clause/from.go new file mode 100644 index 00000000..1ea2d595 --- /dev/null +++ b/clause/from.go @@ -0,0 +1,37 @@ +package clause + +// From from clause +type From struct { + Tables []Table + Joins []Join +} + +// Name from clause name +func (from From) Name() string { + return "FROM" +} + +// Build build from clause +func (from From) Build(builder Builder) { + if len(from.Tables) > 0 { + for idx, table := range from.Tables { + if idx > 0 { + builder.WriteByte(',') + } + + builder.WriteQuoted(table) + } + } else { + builder.WriteQuoted(currentTable) + } + + for _, join := range from.Joins { + builder.WriteByte(' ') + join.Build(builder) + } +} + +// MergeClause merge from clause +func 
(from From) MergeClause(clause *Clause) { + clause.Expression = from +} diff --git a/clause/from_test.go b/clause/from_test.go new file mode 100644 index 00000000..75422f8e --- /dev/null +++ b/clause/from_test.go @@ -0,0 +1,75 @@ +package clause_test + +import ( + "fmt" + "testing" + + "gorm.io/gorm/clause" +) + +func TestFrom(t *testing.T) { + results := []struct { + Clauses []clause.Interface + Result string + Vars []interface{} + }{ + { + []clause.Interface{clause.Select{}, clause.From{}}, + "SELECT * FROM `users`", nil, + }, + { + []clause.Interface{ + clause.Select{}, clause.From{ + Tables: []clause.Table{{Name: "users"}}, + Joins: []clause.Join{ + { + Type: clause.InnerJoin, + Table: clause.Table{Name: "articles"}, + ON: clause.Where{ + []clause.Expression{clause.Eq{clause.Column{Table: "articles", Name: "id"}, clause.PrimaryColumn}}, + }, + }, + }, + }, + }, + "SELECT * FROM `users` INNER JOIN `articles` ON `articles`.`id` = `users`.`id`", nil, + }, + { + []clause.Interface{ + clause.Select{}, clause.From{ + Tables: []clause.Table{{Name: "users"}}, + Joins: []clause.Join{ + { + Type: clause.RightJoin, + Table: clause.Table{Name: "profiles"}, + ON: clause.Where{ + []clause.Expression{clause.Eq{clause.Column{Table: "profiles", Name: "email"}, clause.Column{Table: clause.CurrentTable, Name: "email"}}}, + }, + }, + }, + }, clause.From{ + Joins: []clause.Join{ + { + Type: clause.InnerJoin, + Table: clause.Table{Name: "articles"}, + ON: clause.Where{ + []clause.Expression{clause.Eq{clause.Column{Table: "articles", Name: "id"}, clause.PrimaryColumn}}, + }, + }, { + Type: clause.LeftJoin, + Table: clause.Table{Name: "companies"}, + Using: []string{"company_name"}, + }, + }, + }, + }, + "SELECT * FROM `users` INNER JOIN `articles` ON `articles`.`id` = `users`.`id` LEFT JOIN `companies` USING (`company_name`)", nil, + }, + } + + for idx, result := range results { + t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { + checkBuildClauses(t, result.Clauses, result.Result, result.Vars) + }) + } +} diff --git a/clause/group_by.go b/clause/group_by.go new file mode 100644 index 00000000..84242fb8 --- /dev/null +++ b/clause/group_by.go @@ -0,0 +1,48 @@ +package clause + +// GroupBy group by clause +type GroupBy struct { + Columns []Column + Having []Expression +} + +// Name from clause name +func (groupBy GroupBy) Name() string { + return "GROUP BY" +} + +// Build build group by clause +func (groupBy GroupBy) Build(builder Builder) { + for idx, column := range groupBy.Columns { + if idx > 0 { + builder.WriteByte(',') + } + + builder.WriteQuoted(column) + } + + if len(groupBy.Having) > 0 { + builder.WriteString(" HAVING ") + Where{Exprs: groupBy.Having}.Build(builder) + } +} + +// MergeClause merge group by clause +func (groupBy GroupBy) MergeClause(clause *Clause) { + if v, ok := clause.Expression.(GroupBy); ok { + copiedColumns := make([]Column, len(v.Columns)) + copy(copiedColumns, v.Columns) + groupBy.Columns = append(copiedColumns, groupBy.Columns...) + + copiedHaving := make([]Expression, len(v.Having)) + copy(copiedHaving, v.Having) + groupBy.Having = append(copiedHaving, groupBy.Having...) 
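// Illustrative sketch of the merge behaviour (mirrors the second case of the GroupBy
// test later in this change): two clauses such as
//
//	GroupBy{Columns: []Column{{Name: "role"}}, Having: []Expression{Eq{"role", "admin"}}}
//	GroupBy{Columns: []Column{{Name: "gender"}}, Having: []Expression{Neq{"gender", "U"}}}
//
// merge into one clause that renders roughly as
//
//	GROUP BY `role`,`gender` HAVING `role` = ? AND `gender` <> ?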
+ } + clause.Expression = groupBy + + if len(groupBy.Columns) == 0 { + clause.Name = "" + } else { + clause.Name = groupBy.Name() + } +} diff --git a/clause/group_by_test.go b/clause/group_by_test.go new file mode 100644 index 00000000..589f9613 --- /dev/null +++ b/clause/group_by_test.go @@ -0,0 +1,40 @@ +package clause_test + +import ( + "fmt" + "testing" + + "gorm.io/gorm/clause" +) + +func TestGroupBy(t *testing.T) { + results := []struct { + Clauses []clause.Interface + Result string + Vars []interface{} + }{ + { + []clause.Interface{clause.Select{}, clause.From{}, clause.GroupBy{ + Columns: []clause.Column{{Name: "role"}}, + Having: []clause.Expression{clause.Eq{"role", "admin"}}, + }}, + "SELECT * FROM `users` GROUP BY `role` HAVING `role` = ?", []interface{}{"admin"}, + }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.GroupBy{ + Columns: []clause.Column{{Name: "role"}}, + Having: []clause.Expression{clause.Eq{"role", "admin"}}, + }, clause.GroupBy{ + Columns: []clause.Column{{Name: "gender"}}, + Having: []clause.Expression{clause.Neq{"gender", "U"}}, + }}, + "SELECT * FROM `users` GROUP BY `role`,`gender` HAVING `role` = ? AND `gender` <> ?", []interface{}{"admin", "U"}, + }, + } + + for idx, result := range results { + t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { + checkBuildClauses(t, result.Clauses, result.Result, result.Vars) + }) + } +} diff --git a/clause/insert.go b/clause/insert.go new file mode 100644 index 00000000..8efaa035 --- /dev/null +++ b/clause/insert.go @@ -0,0 +1,39 @@ +package clause + +type Insert struct { + Table Table + Modifier string +} + +// Name insert clause name +func (insert Insert) Name() string { + return "INSERT" +} + +// Build build insert clause +func (insert Insert) Build(builder Builder) { + if insert.Modifier != "" { + builder.WriteString(insert.Modifier) + builder.WriteByte(' ') + } + + builder.WriteString("INTO ") + if insert.Table.Name == "" { + builder.WriteQuoted(currentTable) + } else { + builder.WriteQuoted(insert.Table) + } +} + +// MergeClause merge insert clause +func (insert Insert) MergeClause(clause *Clause) { + if v, ok := clause.Expression.(Insert); ok { + if insert.Modifier == "" { + insert.Modifier = v.Modifier + } + if insert.Table.Name == "" { + insert.Table = v.Table + } + } + clause.Expression = insert +} diff --git a/clause/insert_test.go b/clause/insert_test.go new file mode 100644 index 00000000..70810bce --- /dev/null +++ b/clause/insert_test.go @@ -0,0 +1,35 @@ +package clause_test + +import ( + "fmt" + "testing" + + "gorm.io/gorm/clause" +) + +func TestInsert(t *testing.T) { + results := []struct { + Clauses []clause.Interface + Result string + Vars []interface{} + }{ + { + []clause.Interface{clause.Insert{}}, + "INSERT INTO `users`", nil, + }, + { + []clause.Interface{clause.Insert{Modifier: "LOW_PRIORITY"}}, + "INSERT LOW_PRIORITY INTO `users`", nil, + }, + { + []clause.Interface{clause.Insert{Table: clause.Table{Name: "products"}, Modifier: "LOW_PRIORITY"}}, + "INSERT LOW_PRIORITY INTO `products`", nil, + }, + } + + for idx, result := range results { + t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { + checkBuildClauses(t, result.Clauses, result.Result, result.Vars) + }) + } +} diff --git a/clause/joins.go b/clause/joins.go new file mode 100644 index 00000000..f3e373f2 --- /dev/null +++ b/clause/joins.go @@ -0,0 +1,47 @@ +package clause + +type JoinType string + +const ( + CrossJoin JoinType = "CROSS" + InnerJoin JoinType = "INNER" + LeftJoin JoinType = "LEFT" + RightJoin 
JoinType = "RIGHT" +) + +// Join join clause for from +type Join struct { + Type JoinType + Table Table + ON Where + Using []string + Expression Expression +} + +func (join Join) Build(builder Builder) { + if join.Expression != nil { + join.Expression.Build(builder) + } else { + if join.Type != "" { + builder.WriteString(string(join.Type)) + builder.WriteByte(' ') + } + + builder.WriteString("JOIN ") + builder.WriteQuoted(join.Table) + + if len(join.ON.Exprs) > 0 { + builder.WriteString(" ON ") + join.ON.Build(builder) + } else if len(join.Using) > 0 { + builder.WriteString(" USING (") + for idx, c := range join.Using { + if idx > 0 { + builder.WriteByte(',') + } + builder.WriteQuoted(c) + } + builder.WriteByte(')') + } + } +} diff --git a/clause/limit.go b/clause/limit.go new file mode 100644 index 00000000..2082f4d9 --- /dev/null +++ b/clause/limit.go @@ -0,0 +1,48 @@ +package clause + +import "strconv" + +// Limit limit clause +type Limit struct { + Limit int + Offset int +} + +// Name where clause name +func (limit Limit) Name() string { + return "LIMIT" +} + +// Build build where clause +func (limit Limit) Build(builder Builder) { + if limit.Limit > 0 { + builder.WriteString("LIMIT ") + builder.WriteString(strconv.Itoa(limit.Limit)) + } + if limit.Offset > 0 { + if limit.Limit > 0 { + builder.WriteString(" ") + } + builder.WriteString("OFFSET ") + builder.WriteString(strconv.Itoa(limit.Offset)) + } +} + +// MergeClause merge order by clauses +func (limit Limit) MergeClause(clause *Clause) { + clause.Name = "" + + if v, ok := clause.Expression.(Limit); ok { + if limit.Limit == 0 && v.Limit != 0 { + limit.Limit = v.Limit + } + + if limit.Offset == 0 && v.Offset > 0 { + limit.Offset = v.Offset + } else if limit.Offset < 0 { + limit.Offset = 0 + } + } + + clause.Expression = limit +} diff --git a/clause/limit_test.go b/clause/limit_test.go new file mode 100644 index 00000000..c26294aa --- /dev/null +++ b/clause/limit_test.go @@ -0,0 +1,58 @@ +package clause_test + +import ( + "fmt" + "testing" + + "gorm.io/gorm/clause" +) + +func TestLimit(t *testing.T) { + results := []struct { + Clauses []clause.Interface + Result string + Vars []interface{} + }{ + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{ + Limit: 10, + Offset: 20, + }}, + "SELECT * FROM `users` LIMIT 10 OFFSET 20", nil, + }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}}, + "SELECT * FROM `users` OFFSET 20", nil, + }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}, clause.Limit{Offset: 30}}, + "SELECT * FROM `users` OFFSET 30", nil, + }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}, clause.Limit{Limit: 10}}, + "SELECT * FROM `users` LIMIT 10 OFFSET 20", nil, + }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}}, + "SELECT * FROM `users` LIMIT 10 OFFSET 30", nil, + }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Offset: -10}}, + "SELECT * FROM `users` LIMIT 10", nil, + }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: -10}}, + "SELECT * FROM `users` OFFSET 30", nil, + }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: 50}}, + "SELECT * FROM `users` LIMIT 50 OFFSET 
30", nil, + }, + } + + for idx, result := range results { + t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { + checkBuildClauses(t, result.Clauses, result.Result, result.Vars) + }) + } +} diff --git a/clause/locking.go b/clause/locking.go new file mode 100644 index 00000000..290aac92 --- /dev/null +++ b/clause/locking.go @@ -0,0 +1,31 @@ +package clause + +type Locking struct { + Strength string + Table Table + Options string +} + +// Name where clause name +func (locking Locking) Name() string { + return "FOR" +} + +// Build build where clause +func (locking Locking) Build(builder Builder) { + builder.WriteString(locking.Strength) + if locking.Table.Name != "" { + builder.WriteString(" OF ") + builder.WriteQuoted(locking.Table) + } + + if locking.Options != "" { + builder.WriteByte(' ') + builder.WriteString(locking.Options) + } +} + +// MergeClause merge order by clauses +func (locking Locking) MergeClause(clause *Clause) { + clause.Expression = locking +} diff --git a/clause/locking_test.go b/clause/locking_test.go new file mode 100644 index 00000000..0e607312 --- /dev/null +++ b/clause/locking_test.go @@ -0,0 +1,35 @@ +package clause_test + +import ( + "fmt" + "testing" + + "gorm.io/gorm/clause" +) + +func TestLocking(t *testing.T) { + results := []struct { + Clauses []clause.Interface + Result string + Vars []interface{} + }{ + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: "UPDATE"}}, + "SELECT * FROM `users` FOR UPDATE", nil, + }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: "SHARE", Table: clause.Table{Name: clause.CurrentTable}}}, + "SELECT * FROM `users` FOR SHARE OF `users`", nil, + }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: "UPDATE"}, clause.Locking{Strength: "UPDATE", Options: "NOWAIT"}}, + "SELECT * FROM `users` FOR UPDATE NOWAIT", nil, + }, + } + + for idx, result := range results { + t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { + checkBuildClauses(t, result.Clauses, result.Result, result.Vars) + }) + } +} diff --git a/clause/on_conflict.go b/clause/on_conflict.go new file mode 100644 index 00000000..127d9bc1 --- /dev/null +++ b/clause/on_conflict.go @@ -0,0 +1,52 @@ +package clause + +type OnConflict struct { + Columns []Column + Where Where + OnConstraint string + DoNothing bool + DoUpdates Set + UpdateAll bool +} + +func (OnConflict) Name() string { + return "ON CONFLICT" +} + +// Build build onConflict clause +func (onConflict OnConflict) Build(builder Builder) { + if len(onConflict.Columns) > 0 { + builder.WriteByte('(') + for idx, column := range onConflict.Columns { + if idx > 0 { + builder.WriteByte(',') + } + builder.WriteQuoted(column) + } + builder.WriteString(`) `) + } + + if onConflict.OnConstraint != "" { + builder.WriteString("ON CONSTRAINT ") + builder.WriteString(onConflict.OnConstraint) + builder.WriteByte(' ') + } + + if onConflict.DoNothing { + builder.WriteString("DO NOTHING") + } else { + builder.WriteString("DO UPDATE SET ") + onConflict.DoUpdates.Build(builder) + } + + if len(onConflict.Where.Exprs) > 0 { + builder.WriteString(" WHERE ") + onConflict.Where.Build(builder) + builder.WriteByte(' ') + } +} + +// MergeClause merge onConflict clauses +func (onConflict OnConflict) MergeClause(clause *Clause) { + clause.Expression = onConflict +} diff --git a/clause/order_by.go b/clause/order_by.go new file mode 100644 index 00000000..41218025 --- /dev/null +++ b/clause/order_by.go @@ -0,0 +1,54 @@ +package clause + +type 
OrderByColumn struct { + Column Column + Desc bool + Reorder bool +} + +type OrderBy struct { + Columns []OrderByColumn + Expression Expression +} + +// Name where clause name +func (orderBy OrderBy) Name() string { + return "ORDER BY" +} + +// Build build where clause +func (orderBy OrderBy) Build(builder Builder) { + if orderBy.Expression != nil { + orderBy.Expression.Build(builder) + } else { + for idx, column := range orderBy.Columns { + if idx > 0 { + builder.WriteByte(',') + } + + builder.WriteQuoted(column.Column) + if column.Desc { + builder.WriteString(" DESC") + } + } + } +} + +// MergeClause merge order by clauses +func (orderBy OrderBy) MergeClause(clause *Clause) { + if v, ok := clause.Expression.(OrderBy); ok { + for i := len(orderBy.Columns) - 1; i >= 0; i-- { + if orderBy.Columns[i].Reorder { + orderBy.Columns = orderBy.Columns[i:] + clause.Expression = orderBy + return + } + } + + copiedColumns := make([]OrderByColumn, len(v.Columns)) + copy(copiedColumns, v.Columns) + orderBy.Columns = append(copiedColumns, orderBy.Columns...) + } + + clause.Expression = orderBy +} diff --git a/clause/order_by_test.go b/clause/order_by_test.go new file mode 100644 index 00000000..8fd1e2a8 --- /dev/null +++ b/clause/order_by_test.go @@ -0,0 +1,57 @@ +package clause_test + +import ( + "fmt" + "testing" + + "gorm.io/gorm/clause" +) + +func TestOrderBy(t *testing.T) { + results := []struct { + Clauses []clause.Interface + Result string + Vars []interface{} + }{ + { + []clause.Interface{clause.Select{}, clause.From{}, clause.OrderBy{ + Columns: []clause.OrderByColumn{{Column: clause.PrimaryColumn, Desc: true}}, + }}, + "SELECT * FROM `users` ORDER BY `users`.`id` DESC", nil, + }, + { + []clause.Interface{ + clause.Select{}, clause.From{}, clause.OrderBy{ + Columns: []clause.OrderByColumn{{Column: clause.PrimaryColumn, Desc: true}}, + }, clause.OrderBy{ + Columns: []clause.OrderByColumn{{Column: clause.Column{Name: "name"}}}, + }, + }, + "SELECT * FROM `users` ORDER BY `users`.`id` DESC,`name`", nil, + }, + { + []clause.Interface{ + clause.Select{}, clause.From{}, clause.OrderBy{ + Columns: []clause.OrderByColumn{{Column: clause.PrimaryColumn, Desc: true}}, + }, clause.OrderBy{ + Columns: []clause.OrderByColumn{{Column: clause.Column{Name: "name"}, Reorder: true}}, + }, + }, + "SELECT * FROM `users` ORDER BY `name`", nil, + }, + { + []clause.Interface{ + clause.Select{}, clause.From{}, clause.OrderBy{ + Expression: clause.Expr{SQL: "FIELD(id, ?)", Vars: []interface{}{[]int{1, 2, 3}}, WithoutParentheses: true}, + }, + }, + "SELECT * FROM `users` ORDER BY FIELD(id, ?,?,?)", []interface{}{1, 2, 3}, + }, + } + + for idx, result := range results { + t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { + checkBuildClauses(t, result.Clauses, result.Result, result.Vars) + }) + } +} diff --git a/clause/returning.go b/clause/returning.go new file mode 100644 index 00000000..04bc96da --- /dev/null +++ b/clause/returning.go @@ -0,0 +1,30 @@ +package clause + +type Returning struct { + Columns []Column +} + +// Name where clause name +func (returning Returning) Name() string { + return "RETURNING" +} + +// Build build where clause +func (returning Returning) Build(builder Builder) { + for idx, column := range returning.Columns { + if idx > 0 { + builder.WriteByte(',') + } + + builder.WriteQuoted(column) + } +} + +// MergeClause merge order by clauses +func (returning Returning) MergeClause(clause *Clause) { + if v, ok := clause.Expression.(Returning); ok { + returning.Columns = append(v.Columns, 
returning.Columns...) + } + + clause.Expression = returning +} diff --git a/clause/returning_test.go b/clause/returning_test.go new file mode 100644 index 00000000..bd0ecce8 --- /dev/null +++ b/clause/returning_test.go @@ -0,0 +1,36 @@ +package clause_test + +import ( + "fmt" + "testing" + + "gorm.io/gorm/clause" +) + +func TestReturning(t *testing.T) { + results := []struct { + Clauses []clause.Interface + Result string + Vars []interface{} + }{ + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Returning{ + []clause.Column{clause.PrimaryColumn}, + }}, + "SELECT * FROM `users` RETURNING `users`.`id`", nil, + }, { + []clause.Interface{clause.Select{}, clause.From{}, clause.Returning{ + []clause.Column{clause.PrimaryColumn}, + }, clause.Returning{ + []clause.Column{{Name: "name"}, {Name: "age"}}, + }}, + "SELECT * FROM `users` RETURNING `users`.`id`,`name`,`age`", nil, + }, + } + + for idx, result := range results { + t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { + checkBuildClauses(t, result.Clauses, result.Result, result.Vars) + }) + } +} diff --git a/clause/select.go b/clause/select.go new file mode 100644 index 00000000..b93b8769 --- /dev/null +++ b/clause/select.go @@ -0,0 +1,45 @@ +package clause + +// Select select attrs when querying, updating, creating +type Select struct { + Distinct bool + Columns []Column + Expression Expression +} + +func (s Select) Name() string { + return "SELECT" +} + +func (s Select) Build(builder Builder) { + if len(s.Columns) > 0 { + if s.Distinct { + builder.WriteString("DISTINCT ") + } + + for idx, column := range s.Columns { + if idx > 0 { + builder.WriteByte(',') + } + builder.WriteQuoted(column) + } + } else { + builder.WriteByte('*') + } +} + +func (s Select) MergeClause(clause *Clause) { + if s.Expression != nil { + if s.Distinct { + if expr, ok := s.Expression.(Expr); ok { + expr.SQL = "DISTINCT " + expr.SQL + clause.Expression = expr + return + } + } + + clause.Expression = s.Expression + } else { + clause.Expression = s + } +} diff --git a/clause/select_test.go b/clause/select_test.go new file mode 100644 index 00000000..b7296434 --- /dev/null +++ b/clause/select_test.go @@ -0,0 +1,41 @@ +package clause_test + +import ( + "fmt" + "testing" + + "gorm.io/gorm/clause" +) + +func TestSelect(t *testing.T) { + results := []struct { + Clauses []clause.Interface + Result string + Vars []interface{} + }{ + { + []clause.Interface{clause.Select{}, clause.From{}}, + "SELECT * FROM `users`", nil, + }, + { + []clause.Interface{clause.Select{ + Columns: []clause.Column{clause.PrimaryColumn}, + }, clause.From{}}, + "SELECT `users`.`id` FROM `users`", nil, + }, + { + []clause.Interface{clause.Select{ + Columns: []clause.Column{clause.PrimaryColumn}, + }, clause.Select{ + Columns: []clause.Column{{Name: "name"}}, + }, clause.From{}}, + "SELECT `name` FROM `users`", nil, + }, + } + + for idx, result := range results { + t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { + checkBuildClauses(t, result.Clauses, result.Result, result.Vars) + }) + } +} diff --git a/clause/set.go b/clause/set.go new file mode 100644 index 00000000..6a885711 --- /dev/null +++ b/clause/set.go @@ -0,0 +1,60 @@ +package clause + +import "sort" + +type Set []Assignment + +type Assignment struct { + Column Column + Value interface{} +} + +func (set Set) Name() string { + return "SET" +} + +func (set Set) Build(builder Builder) { + if len(set) > 0 { + for idx, assignment := range set { + if idx > 0 { + builder.WriteByte(',') + } + 
builder.WriteQuoted(assignment.Column) + builder.WriteByte('=') + builder.AddVar(builder, assignment.Value) + } + } else { + builder.WriteQuoted(PrimaryColumn) + builder.WriteByte('=') + builder.WriteQuoted(PrimaryColumn) + } +} + +// MergeClause merge assignments clauses +func (set Set) MergeClause(clause *Clause) { + copiedAssignments := make([]Assignment, len(set)) + copy(copiedAssignments, set) + clause.Expression = Set(copiedAssignments) +} + +func Assignments(values map[string]interface{}) Set { + keys := make([]string, 0, len(values)) + for key := range values { + keys = append(keys, key) + } + sort.Strings(keys) + + assignments := make([]Assignment, len(keys)) + for idx, key := range keys { + assignments[idx] = Assignment{Column: Column{Name: key}, Value: values[key]} + } + return assignments +} + +func AssignmentColumns(values []string) Set { + assignments := make([]Assignment, len(values)) + for idx, value := range values { + assignments[idx] = Assignment{Column: Column{Name: value}, Value: Column{Table: "excluded", Name: value}} + } + return assignments +} diff --git a/clause/set_test.go b/clause/set_test.go new file mode 100644 index 00000000..56fac706 --- /dev/null +++ b/clause/set_test.go @@ -0,0 +1,57 @@ +package clause_test + +import ( + "fmt" + "sort" + "strings" + "testing" + + "gorm.io/gorm/clause" +) + +func TestSet(t *testing.T) { + results := []struct { + Clauses []clause.Interface + Result string + Vars []interface{} + }{ + { + []clause.Interface{ + clause.Update{}, + clause.Set([]clause.Assignment{{clause.PrimaryColumn, 1}}), + }, + "UPDATE `users` SET `users`.`id`=?", []interface{}{1}, + }, + { + []clause.Interface{ + clause.Update{}, + clause.Set([]clause.Assignment{{clause.PrimaryColumn, 1}}), + clause.Set([]clause.Assignment{{clause.Column{Name: "name"}, "jinzhu"}}), + }, + "UPDATE `users` SET `name`=?", []interface{}{"jinzhu"}, + }, + } + + for idx, result := range results { + t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { + checkBuildClauses(t, result.Clauses, result.Result, result.Vars) + }) + } +} + +func TestAssignments(t *testing.T) { + set := clause.Assignments(map[string]interface{}{ + "name": "jinzhu", + "age": 18, + }) + + assignments := []clause.Assignment(set) + + sort.Slice(assignments, func(i, j int) bool { + return strings.Compare(assignments[i].Column.Name, assignments[j].Column.Name) > 0 + }) + + if len(assignments) != 2 || assignments[0].Column.Name != "name" || assignments[0].Value.(string) != "jinzhu" || assignments[1].Column.Name != "age" || assignments[1].Value.(int) != 18 { + t.Errorf("invalid assignments, got %v", assignments) + } +} diff --git a/clause/update.go b/clause/update.go new file mode 100644 index 00000000..f9d68ac6 --- /dev/null +++ b/clause/update.go @@ -0,0 +1,38 @@ +package clause + +type Update struct { + Modifier string + Table Table +} + +// Name update clause name +func (update Update) Name() string { + return "UPDATE" +} + +// Build build update clause +func (update Update) Build(builder Builder) { + if update.Modifier != "" { + builder.WriteString(update.Modifier) + builder.WriteByte(' ') + } + + if update.Table.Name == "" { + builder.WriteQuoted(currentTable) + } else { + builder.WriteQuoted(update.Table) + } +} + +// MergeClause merge update clause +func (update Update) MergeClause(clause *Clause) { + if v, ok := clause.Expression.(Update); ok { + if update.Modifier == "" { + update.Modifier = v.Modifier + } + if update.Table.Name == "" { + update.Table = v.Table + } + } + clause.Expression = update +} 
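For reference, the merge rule just shown can be exercised directly with the types introduced in this patch: a later `Update` only fills in the `Modifier` and `Table` fields that an earlier one left empty. The sketch below is illustrative only (it is not part of the patch) and relies solely on the exported `Clause{Name, Expression}` fields, `clause.Update`, and `clause.Table{Name}` seen in this package.

```go
package main

import (
	"fmt"

	"gorm.io/gorm/clause"
)

func main() {
	// An UPDATE clause that already carries a modifier.
	c := clause.Clause{Name: "UPDATE", Expression: clause.Update{Modifier: "LOW_PRIORITY"}}

	// Merging a second Update keeps the existing modifier and only
	// fills in the table, per the MergeClause rule above.
	clause.Update{Table: clause.Table{Name: "products"}}.MergeClause(&c)

	merged := c.Expression.(clause.Update)
	fmt.Println(merged.Modifier, merged.Table.Name) // LOW_PRIORITY products
}
```

The same pattern applies to the other clauses in this patch (Limit, OrderBy, Set, and so on): each type defines in its own `MergeClause` how a later clause with the same `Name()` folds into an earlier one.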
diff --git a/clause/update_test.go b/clause/update_test.go new file mode 100644 index 00000000..c704bf5e --- /dev/null +++ b/clause/update_test.go @@ -0,0 +1,35 @@ +package clause_test + +import ( + "fmt" + "testing" + + "gorm.io/gorm/clause" +) + +func TestUpdate(t *testing.T) { + results := []struct { + Clauses []clause.Interface + Result string + Vars []interface{} + }{ + { + []clause.Interface{clause.Update{}}, + "UPDATE `users`", nil, + }, + { + []clause.Interface{clause.Update{Modifier: "LOW_PRIORITY"}}, + "UPDATE LOW_PRIORITY `users`", nil, + }, + { + []clause.Interface{clause.Update{Table: clause.Table{Name: "products"}, Modifier: "LOW_PRIORITY"}}, + "UPDATE LOW_PRIORITY `products`", nil, + }, + } + + for idx, result := range results { + t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { + checkBuildClauses(t, result.Clauses, result.Result, result.Vars) + }) + } +} diff --git a/clause/values.go b/clause/values.go new file mode 100644 index 00000000..b2f5421b --- /dev/null +++ b/clause/values.go @@ -0,0 +1,45 @@ +package clause + +type Values struct { + Columns []Column + Values [][]interface{} +} + +// Name from clause name +func (Values) Name() string { + return "VALUES" +} + +// Build build from clause +func (values Values) Build(builder Builder) { + if len(values.Columns) > 0 { + builder.WriteByte('(') + for idx, column := range values.Columns { + if idx > 0 { + builder.WriteByte(',') + } + builder.WriteQuoted(column) + } + builder.WriteByte(')') + + builder.WriteString(" VALUES ") + + for idx, value := range values.Values { + if idx > 0 { + builder.WriteByte(',') + } + + builder.WriteByte('(') + builder.AddVar(builder, value...) + builder.WriteByte(')') + } + } else { + builder.WriteString("DEFAULT VALUES") + } +} + +// MergeClause merge values clauses +func (values Values) MergeClause(clause *Clause) { + clause.Name = "" + clause.Expression = values +} diff --git a/clause/values_test.go b/clause/values_test.go new file mode 100644 index 00000000..9c02c8a5 --- /dev/null +++ b/clause/values_test.go @@ -0,0 +1,33 @@ +package clause_test + +import ( + "fmt" + "testing" + + "gorm.io/gorm/clause" +) + +func TestValues(t *testing.T) { + results := []struct { + Clauses []clause.Interface + Result string + Vars []interface{} + }{ + { + []clause.Interface{ + clause.Insert{}, + clause.Values{ + Columns: []clause.Column{{Name: "name"}, {Name: "age"}}, + Values: [][]interface{}{{"jinzhu", 18}, {"josh", 1}}, + }, + }, + "INSERT INTO `users` (`name`,`age`) VALUES (?,?),(?,?)", []interface{}{"jinzhu", 18, "josh", 1}, + }, + } + + for idx, result := range results { + t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { + checkBuildClauses(t, result.Clauses, result.Result, result.Vars) + }) + } +} diff --git a/clause/where.go b/clause/where.go new file mode 100644 index 00000000..00b1a40e --- /dev/null +++ b/clause/where.go @@ -0,0 +1,177 @@ +package clause + +import ( + "strings" +) + +// Where where clause +type Where struct { + Exprs []Expression +} + +// Name where clause name +func (where Where) Name() string { + return "WHERE" +} + +// Build build where clause +func (where Where) Build(builder Builder) { + // Switch position if the first query expression is a single Or condition + for idx, expr := range where.Exprs { + if v, ok := expr.(OrConditions); !ok || len(v.Exprs) > 1 { + if idx != 0 { + where.Exprs[0], where.Exprs[idx] = where.Exprs[idx], where.Exprs[0] + } + break + } + } + + buildExprs(where.Exprs, builder, " AND ") +} + +func buildExprs(exprs []Expression, builder 
Builder, joinCond string) { + wrapInParentheses := false + + for idx, expr := range exprs { + if idx > 0 { + if v, ok := expr.(OrConditions); ok && len(v.Exprs) == 1 { + builder.WriteString(" OR ") + } else { + builder.WriteString(joinCond) + } + } + + if len(exprs) > 1 { + switch v := expr.(type) { + case OrConditions: + if len(v.Exprs) == 1 { + if e, ok := v.Exprs[0].(Expr); ok { + sql := strings.ToLower(e.SQL) + wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or") + } + } + case AndConditions: + if len(v.Exprs) == 1 { + if e, ok := v.Exprs[0].(Expr); ok { + sql := strings.ToLower(e.SQL) + wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or") + } + } + case Expr: + sql := strings.ToLower(v.SQL) + wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or") + } + } + + if wrapInParentheses { + builder.WriteString(`(`) + expr.Build(builder) + builder.WriteString(`)`) + wrapInParentheses = false + } else { + expr.Build(builder) + } + } +} + +// MergeClause merge where clauses +func (where Where) MergeClause(clause *Clause) { + if w, ok := clause.Expression.(Where); ok { + exprs := make([]Expression, len(w.Exprs)+len(where.Exprs)) + copy(exprs, w.Exprs) + copy(exprs[len(w.Exprs):], where.Exprs) + where.Exprs = exprs + } + + clause.Expression = where +} + +func And(exprs ...Expression) Expression { + if len(exprs) == 0 { + return nil + } else if len(exprs) == 1 { + return exprs[0] + } + return AndConditions{Exprs: exprs} +} + +type AndConditions struct { + Exprs []Expression +} + +func (and AndConditions) Build(builder Builder) { + if len(and.Exprs) > 1 { + builder.WriteByte('(') + buildExprs(and.Exprs, builder, " AND ") + builder.WriteByte(')') + } else { + buildExprs(and.Exprs, builder, " AND ") + } +} + +func Or(exprs ...Expression) Expression { + if len(exprs) == 0 { + return nil + } + return OrConditions{Exprs: exprs} +} + +type OrConditions struct { + Exprs []Expression +} + +func (or OrConditions) Build(builder Builder) { + if len(or.Exprs) > 1 { + builder.WriteByte('(') + buildExprs(or.Exprs, builder, " OR ") + builder.WriteByte(')') + } else { + buildExprs(or.Exprs, builder, " OR ") + } +} + +func Not(exprs ...Expression) Expression { + if len(exprs) == 0 { + return nil + } + return NotConditions{Exprs: exprs} +} + +type NotConditions struct { + Exprs []Expression +} + +func (not NotConditions) Build(builder Builder) { + if len(not.Exprs) > 1 { + builder.WriteByte('(') + } + + for idx, c := range not.Exprs { + if idx > 0 { + builder.WriteString(" AND ") + } + + if negationBuilder, ok := c.(NegationExpressionBuilder); ok { + negationBuilder.NegationBuild(builder) + } else { + builder.WriteString("NOT ") + e, wrapInParentheses := c.(Expr) + if wrapInParentheses { + sql := strings.ToLower(e.SQL) + if wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or"); wrapInParentheses { + builder.WriteByte('(') + } + } + + c.Build(builder) + + if wrapInParentheses { + builder.WriteByte(')') + } + } + } + + if len(not.Exprs) > 1 { + builder.WriteByte(')') + } +} diff --git a/clause/where_test.go b/clause/where_test.go new file mode 100644 index 00000000..2fa11d76 --- /dev/null +++ b/clause/where_test.go @@ -0,0 +1,69 @@ +package clause_test + +import ( + "fmt" + "testing" + + "gorm.io/gorm/clause" +) + +func TestWhere(t *testing.T) { + results := []struct { + Clauses []clause.Interface + Result string + Vars []interface{} + }{ + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Where{ + 
Exprs: []clause.Expression{clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}, clause.Or(clause.Neq{Column: "name", Value: "jinzhu"})}, + }}, + "SELECT * FROM `users` WHERE `users`.`id` = ? AND `age` > ? OR `name` <> ?", []interface{}{"1", 18, "jinzhu"}, + }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Where{ + Exprs: []clause.Expression{clause.Or(clause.Neq{Column: "name", Value: "jinzhu"}), clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}}, + }}, + "SELECT * FROM `users` WHERE `users`.`id` = ? OR `name` <> ? AND `age` > ?", []interface{}{"1", "jinzhu", 18}, + }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Where{ + Exprs: []clause.Expression{clause.Or(clause.Neq{Column: "name", Value: "jinzhu"}), clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}}, + }}, + "SELECT * FROM `users` WHERE `users`.`id` = ? OR `name` <> ? AND `age` > ?", []interface{}{"1", "jinzhu", 18}, + }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Where{ + Exprs: []clause.Expression{clause.Or(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}), clause.Or(clause.Neq{Column: "name", Value: "jinzhu"})}, + }}, + "SELECT * FROM `users` WHERE `users`.`id` = ? OR `name` <> ?", []interface{}{"1", "jinzhu"}, + }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Where{ + Exprs: []clause.Expression{clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}, clause.Or(clause.Neq{Column: "name", Value: "jinzhu"})}, + }, clause.Where{ + Exprs: []clause.Expression{clause.Or(clause.Gt{Column: "score", Value: 100}, clause.Like{Column: "name", Value: "%linus%"})}, + }}, + "SELECT * FROM `users` WHERE `users`.`id` = ? AND `age` > ? OR `name` <> ? AND (`score` > ? OR `name` LIKE ?)", []interface{}{"1", 18, "jinzhu", 100, "%linus%"}, + }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Where{ + Exprs: []clause.Expression{clause.Not(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}), clause.Or(clause.Neq{Column: "name", Value: "jinzhu"})}, + }, clause.Where{ + Exprs: []clause.Expression{clause.Or(clause.Not(clause.Gt{Column: "score", Value: 100}), clause.Like{Column: "name", Value: "%linus%"})}, + }}, + "SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `age` <= ?) OR `name` <> ? AND (`score` <= ? OR `name` LIKE ?)", []interface{}{"1", 18, "jinzhu", 100, "%linus%"}, + }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Where{ + Exprs: []clause.Expression{clause.And(clause.Eq{Column: "age", Value: 18}, clause.Or(clause.Neq{Column: "name", Value: "jinzhu"}))}, + }}, + "SELECT * FROM `users` WHERE (`age` = ? 
OR `name` <> ?)", []interface{}{18, "jinzhu"}, + }, + } + + for idx, result := range results { + t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { + checkBuildClauses(t, result.Clauses, result.Result, result.Vars) + }) + } +} diff --git a/clause/with.go b/clause/with.go new file mode 100644 index 00000000..7e9eaef1 --- /dev/null +++ b/clause/with.go @@ -0,0 +1,4 @@ +package clause + +type With struct { +} diff --git a/create_test.go b/create_test.go deleted file mode 100644 index 36472914..00000000 --- a/create_test.go +++ /dev/null @@ -1,218 +0,0 @@ -package gorm_test - -import ( - "os" - "reflect" - "testing" - "time" - - "github.com/jinzhu/now" -) - -func TestCreate(t *testing.T) { - float := 35.03554004971999 - now := time.Now() - user := User{Name: "CreateUser", Age: 18, Birthday: &now, UserNum: Num(111), PasswordHash: []byte{'f', 'a', 'k', '4'}, Latitude: float} - - if !DB.NewRecord(user) || !DB.NewRecord(&user) { - t.Error("User should be new record before create") - } - - if count := DB.Save(&user).RowsAffected; count != 1 { - t.Error("There should be one record be affected when create record") - } - - if DB.NewRecord(user) || DB.NewRecord(&user) { - t.Error("User should not new record after save") - } - - var newUser User - DB.First(&newUser, user.Id) - - if !reflect.DeepEqual(newUser.PasswordHash, []byte{'f', 'a', 'k', '4'}) { - t.Errorf("User's PasswordHash should be saved ([]byte)") - } - - if newUser.Age != 18 { - t.Errorf("User's Age should be saved (int)") - } - - if newUser.UserNum != Num(111) { - t.Errorf("User's UserNum should be saved (custom type)") - } - - if newUser.Latitude != float { - t.Errorf("Float64 should not be changed after save") - } - - if user.CreatedAt.IsZero() { - t.Errorf("Should have created_at after create") - } - - if newUser.CreatedAt.IsZero() { - t.Errorf("Should have created_at after create") - } - - DB.Model(user).Update("name", "create_user_new_name") - DB.First(&user, user.Id) - if user.CreatedAt.Format(time.RFC3339Nano) != newUser.CreatedAt.Format(time.RFC3339Nano) { - t.Errorf("CreatedAt should not be changed after update") - } -} - -func TestCreateWithExistingTimestamp(t *testing.T) { - user := User{Name: "CreateUserExistingTimestamp"} - - timeA := now.MustParse("2016-01-01") - user.CreatedAt = timeA - user.UpdatedAt = timeA - DB.Save(&user) - - if user.CreatedAt.UTC().Format(time.RFC3339) != timeA.UTC().Format(time.RFC3339) { - t.Errorf("CreatedAt should not be changed") - } - - if user.UpdatedAt.UTC().Format(time.RFC3339) != timeA.UTC().Format(time.RFC3339) { - t.Errorf("UpdatedAt should not be changed") - } - - var newUser User - DB.First(&newUser, user.Id) - - if newUser.CreatedAt.UTC().Format(time.RFC3339) != timeA.UTC().Format(time.RFC3339) { - t.Errorf("CreatedAt should not be changed") - } - - if newUser.UpdatedAt.UTC().Format(time.RFC3339) != timeA.UTC().Format(time.RFC3339) { - t.Errorf("UpdatedAt should not be changed") - } -} - -type AutoIncrementUser struct { - User - Sequence uint `gorm:"AUTO_INCREMENT"` -} - -func TestCreateWithAutoIncrement(t *testing.T) { - if dialect := os.Getenv("GORM_DIALECT"); dialect != "postgres" { - t.Skip("Skipping this because only postgres properly support auto_increment on a non-primary_key column") - } - - DB.AutoMigrate(&AutoIncrementUser{}) - - user1 := AutoIncrementUser{} - user2 := AutoIncrementUser{} - - DB.Create(&user1) - DB.Create(&user2) - - if user2.Sequence-user1.Sequence != 1 { - t.Errorf("Auto increment should apply on Sequence") - } -} - -func TestCreateWithNoGORMPrimayKey(t 
*testing.T) { - if dialect := os.Getenv("GORM_DIALECT"); dialect == "mssql" { - t.Skip("Skipping this because MSSQL will return identity only if the table has an Id column") - } - - jt := JoinTable{From: 1, To: 2} - err := DB.Create(&jt).Error - if err != nil { - t.Errorf("No error should happen when create a record without a GORM primary key. But in the database this primary key exists and is the union of 2 or more fields\n But got: %s", err) - } -} - -func TestCreateWithNoStdPrimaryKeyAndDefaultValues(t *testing.T) { - animal := Animal{Name: "Ferdinand"} - if DB.Save(&animal).Error != nil { - t.Errorf("No error should happen when create a record without std primary key") - } - - if animal.Counter == 0 { - t.Errorf("No std primary key should be filled value after create") - } - - if animal.Name != "Ferdinand" { - t.Errorf("Default value should be overrided") - } - - // Test create with default value not overrided - an := Animal{From: "nerdz"} - - if DB.Save(&an).Error != nil { - t.Errorf("No error should happen when create an record without std primary key") - } - - // We must fetch the value again, to have the default fields updated - // (We can't do this in the update statements, since sql default can be expressions - // And be different from the fields' type (eg. a time.Time fields has a default value of "now()" - DB.Model(Animal{}).Where(&Animal{Counter: an.Counter}).First(&an) - - if an.Name != "galeone" { - t.Errorf("Default value should fill the field. But got %v", an.Name) - } -} - -func TestAnonymousScanner(t *testing.T) { - user := User{Name: "anonymous_scanner", Role: Role{Name: "admin"}} - DB.Save(&user) - - var user2 User - DB.First(&user2, "name = ?", "anonymous_scanner") - if user2.Role.Name != "admin" { - t.Errorf("Should be able to get anonymous scanner") - } - - if !user2.Role.IsAdmin() { - t.Errorf("Should be able to get anonymous scanner") - } -} - -func TestAnonymousField(t *testing.T) { - user := User{Name: "anonymous_field", Company: Company{Name: "company"}} - DB.Save(&user) - - var user2 User - DB.First(&user2, "name = ?", "anonymous_field") - DB.Model(&user2).Related(&user2.Company) - if user2.Company.Name != "company" { - t.Errorf("Should be able to get anonymous field") - } -} - -func TestSelectWithCreate(t *testing.T) { - user := getPreparedUser("select_user", "select_with_create") - DB.Select("Name", "BillingAddress", "CreditCard", "Company", "Emails").Create(user) - - var queryuser User - DB.Preload("BillingAddress").Preload("ShippingAddress"). - Preload("CreditCard").Preload("Emails").Preload("Company").First(&queryuser, user.Id) - - if queryuser.Name != user.Name || queryuser.Age == user.Age { - t.Errorf("Should only create users with name column") - } - - if queryuser.BillingAddressID.Int64 == 0 || queryuser.ShippingAddressId != 0 || - queryuser.CreditCard.ID == 0 || len(queryuser.Emails) == 0 { - t.Errorf("Should only create selected relationships") - } -} - -func TestOmitWithCreate(t *testing.T) { - user := getPreparedUser("omit_user", "omit_with_create") - DB.Omit("Name", "BillingAddress", "CreditCard", "Company", "Emails").Create(user) - - var queryuser User - DB.Preload("BillingAddress").Preload("ShippingAddress"). 
- Preload("CreditCard").Preload("Emails").Preload("Company").First(&queryuser, user.Id) - - if queryuser.Name == user.Name || queryuser.Age != user.Age { - t.Errorf("Should only create users with age column") - } - - if queryuser.BillingAddressID.Int64 != 0 || queryuser.ShippingAddressId == 0 || - queryuser.CreditCard.ID != 0 || len(queryuser.Emails) != 0 { - t.Errorf("Should not create omitted relationships") - } -} diff --git a/customize_column_test.go b/customize_column_test.go deleted file mode 100644 index c96b2d40..00000000 --- a/customize_column_test.go +++ /dev/null @@ -1,303 +0,0 @@ -package gorm_test - -import ( - "testing" - "time" - - "github.com/jinzhu/gorm" -) - -type CustomizeColumn struct { - ID int64 `gorm:"column:mapped_id; primary_key:yes"` - Name string `gorm:"column:mapped_name"` - Date *time.Time `gorm:"column:mapped_time"` -} - -// Make sure an ignored field does not interfere with another field's custom -// column name that matches the ignored field. -type CustomColumnAndIgnoredFieldClash struct { - Body string `sql:"-"` - RawBody string `gorm:"column:body"` -} - -func TestCustomizeColumn(t *testing.T) { - col := "mapped_name" - DB.DropTable(&CustomizeColumn{}) - DB.AutoMigrate(&CustomizeColumn{}) - - scope := DB.NewScope(&CustomizeColumn{}) - if !scope.Dialect().HasColumn(scope.TableName(), col) { - t.Errorf("CustomizeColumn should have column %s", col) - } - - col = "mapped_id" - if scope.PrimaryKey() != col { - t.Errorf("CustomizeColumn should have primary key %s, but got %q", col, scope.PrimaryKey()) - } - - expected := "foo" - now := time.Now() - cc := CustomizeColumn{ID: 666, Name: expected, Date: &now} - - if count := DB.Create(&cc).RowsAffected; count != 1 { - t.Error("There should be one record be affected when create record") - } - - var cc1 CustomizeColumn - DB.First(&cc1, 666) - - if cc1.Name != expected { - t.Errorf("Failed to query CustomizeColumn") - } - - cc.Name = "bar" - DB.Save(&cc) - - var cc2 CustomizeColumn - DB.First(&cc2, 666) - if cc2.Name != "bar" { - t.Errorf("Failed to query CustomizeColumn") - } -} - -func TestCustomColumnAndIgnoredFieldClash(t *testing.T) { - DB.DropTable(&CustomColumnAndIgnoredFieldClash{}) - if err := DB.AutoMigrate(&CustomColumnAndIgnoredFieldClash{}).Error; err != nil { - t.Errorf("Should not raise error: %s", err) - } -} - -type CustomizePerson struct { - IdPerson string `gorm:"column:idPerson;primary_key:true"` - Accounts []CustomizeAccount `gorm:"many2many:PersonAccount;associationforeignkey:idAccount;foreignkey:idPerson"` -} - -type CustomizeAccount struct { - IdAccount string `gorm:"column:idAccount;primary_key:true"` - Name string -} - -func TestManyToManyWithCustomizedColumn(t *testing.T) { - DB.DropTable(&CustomizePerson{}, &CustomizeAccount{}, "PersonAccount") - DB.AutoMigrate(&CustomizePerson{}, &CustomizeAccount{}) - - account := CustomizeAccount{IdAccount: "account", Name: "id1"} - person := CustomizePerson{ - IdPerson: "person", - Accounts: []CustomizeAccount{account}, - } - - if err := DB.Create(&account).Error; err != nil { - t.Errorf("no error should happen, but got %v", err) - } - - if err := DB.Create(&person).Error; err != nil { - t.Errorf("no error should happen, but got %v", err) - } - - var person1 CustomizePerson - scope := DB.NewScope(nil) - if err := DB.Preload("Accounts").First(&person1, scope.Quote("idPerson")+" = ?", person.IdPerson).Error; err != nil { - t.Errorf("no error should happen when preloading customized column many2many relations, but got %v", err) - } - - if 
len(person1.Accounts) != 1 || person1.Accounts[0].IdAccount != "account" { - t.Errorf("should preload correct accounts") - } -} - -type CustomizeUser struct { - gorm.Model - Email string `sql:"column:email_address"` -} - -type CustomizeInvitation struct { - gorm.Model - Address string `sql:"column:invitation"` - Person *CustomizeUser `gorm:"foreignkey:Email;associationforeignkey:invitation"` -} - -func TestOneToOneWithCustomizedColumn(t *testing.T) { - DB.DropTable(&CustomizeUser{}, &CustomizeInvitation{}) - DB.AutoMigrate(&CustomizeUser{}, &CustomizeInvitation{}) - - user := CustomizeUser{ - Email: "hello@example.com", - } - invitation := CustomizeInvitation{ - Address: "hello@example.com", - } - - DB.Create(&user) - DB.Create(&invitation) - - var invitation2 CustomizeInvitation - if err := DB.Preload("Person").Find(&invitation2, invitation.ID).Error; err != nil { - t.Errorf("no error should happen, but got %v", err) - } - - if invitation2.Person.Email != user.Email { - t.Errorf("Should preload one to one relation with customize foreign keys") - } -} - -type PromotionDiscount struct { - gorm.Model - Name string - Coupons []*PromotionCoupon `gorm:"ForeignKey:discount_id"` - Rule *PromotionRule `gorm:"ForeignKey:discount_id"` - Benefits []PromotionBenefit `gorm:"ForeignKey:promotion_id"` -} - -type PromotionBenefit struct { - gorm.Model - Name string - PromotionID uint - Discount PromotionDiscount `gorm:"ForeignKey:promotion_id"` -} - -type PromotionCoupon struct { - gorm.Model - Code string - DiscountID uint - Discount PromotionDiscount -} - -type PromotionRule struct { - gorm.Model - Name string - Begin *time.Time - End *time.Time - DiscountID uint - Discount *PromotionDiscount -} - -func TestOneToManyWithCustomizedColumn(t *testing.T) { - DB.DropTable(&PromotionDiscount{}, &PromotionCoupon{}) - DB.AutoMigrate(&PromotionDiscount{}, &PromotionCoupon{}) - - discount := PromotionDiscount{ - Name: "Happy New Year", - Coupons: []*PromotionCoupon{ - {Code: "newyear1"}, - {Code: "newyear2"}, - }, - } - - if err := DB.Create(&discount).Error; err != nil { - t.Errorf("no error should happen but got %v", err) - } - - var discount1 PromotionDiscount - if err := DB.Preload("Coupons").First(&discount1, "id = ?", discount.ID).Error; err != nil { - t.Errorf("no error should happen but got %v", err) - } - - if len(discount.Coupons) != 2 { - t.Errorf("should find two coupons") - } - - var coupon PromotionCoupon - if err := DB.Preload("Discount").First(&coupon, "code = ?", "newyear1").Error; err != nil { - t.Errorf("no error should happen but got %v", err) - } - - if coupon.Discount.Name != "Happy New Year" { - t.Errorf("should preload discount from coupon") - } -} - -func TestHasOneWithPartialCustomizedColumn(t *testing.T) { - DB.DropTable(&PromotionDiscount{}, &PromotionRule{}) - DB.AutoMigrate(&PromotionDiscount{}, &PromotionRule{}) - - var begin = time.Now() - var end = time.Now().Add(24 * time.Hour) - discount := PromotionDiscount{ - Name: "Happy New Year 2", - Rule: &PromotionRule{ - Name: "time_limited", - Begin: &begin, - End: &end, - }, - } - - if err := DB.Create(&discount).Error; err != nil { - t.Errorf("no error should happen but got %v", err) - } - - var discount1 PromotionDiscount - if err := DB.Preload("Rule").First(&discount1, "id = ?", discount.ID).Error; err != nil { - t.Errorf("no error should happen but got %v", err) - } - - if discount.Rule.Begin.Format(time.RFC3339Nano) != begin.Format(time.RFC3339Nano) { - t.Errorf("Should be able to preload Rule") - } - - var rule PromotionRule - if 
err := DB.Preload("Discount").First(&rule, "name = ?", "time_limited").Error; err != nil { - t.Errorf("no error should happen but got %v", err) - } - - if rule.Discount.Name != "Happy New Year 2" { - t.Errorf("should preload discount from rule") - } -} - -func TestBelongsToWithPartialCustomizedColumn(t *testing.T) { - DB.DropTable(&PromotionDiscount{}, &PromotionBenefit{}) - DB.AutoMigrate(&PromotionDiscount{}, &PromotionBenefit{}) - - discount := PromotionDiscount{ - Name: "Happy New Year 3", - Benefits: []PromotionBenefit{ - {Name: "free cod"}, - {Name: "free shipping"}, - }, - } - - if err := DB.Create(&discount).Error; err != nil { - t.Errorf("no error should happen but got %v", err) - } - - var discount1 PromotionDiscount - if err := DB.Preload("Benefits").First(&discount1, "id = ?", discount.ID).Error; err != nil { - t.Errorf("no error should happen but got %v", err) - } - - if len(discount.Benefits) != 2 { - t.Errorf("should find two benefits") - } - - var benefit PromotionBenefit - if err := DB.Preload("Discount").First(&benefit, "name = ?", "free cod").Error; err != nil { - t.Errorf("no error should happen but got %v", err) - } - - if benefit.Discount.Name != "Happy New Year 3" { - t.Errorf("should preload discount from coupon") - } -} - -type SelfReferencingUser struct { - gorm.Model - Friends []*SelfReferencingUser `gorm:"many2many:UserFriends;AssociationForeignKey:ID=friend_id"` -} - -func TestSelfReferencingMany2ManyColumn(t *testing.T) { - DB.DropTable(&SelfReferencingUser{}, "UserFriends") - DB.AutoMigrate(&SelfReferencingUser{}) - - friend := SelfReferencingUser{} - if err := DB.Create(&friend).Error; err != nil { - t.Errorf("no error should happen, but got %v", err) - } - - user := SelfReferencingUser{ - Friends: []*SelfReferencingUser{&friend}, - } - if err := DB.Create(&user).Error; err != nil { - t.Errorf("no error should happen, but got %v", err) - } -} diff --git a/delete_test.go b/delete_test.go deleted file mode 100644 index 043641f7..00000000 --- a/delete_test.go +++ /dev/null @@ -1,91 +0,0 @@ -package gorm_test - -import ( - "testing" - "time" -) - -func TestDelete(t *testing.T) { - user1, user2 := User{Name: "delete1"}, User{Name: "delete2"} - DB.Save(&user1) - DB.Save(&user2) - - if err := DB.Delete(&user1).Error; err != nil { - t.Errorf("No error should happen when delete a record, err=%s", err) - } - - if !DB.Where("name = ?", user1.Name).First(&User{}).RecordNotFound() { - t.Errorf("User can't be found after delete") - } - - if DB.Where("name = ?", user2.Name).First(&User{}).RecordNotFound() { - t.Errorf("Other users that not deleted should be found-able") - } -} - -func TestInlineDelete(t *testing.T) { - user1, user2 := User{Name: "inline_delete1"}, User{Name: "inline_delete2"} - DB.Save(&user1) - DB.Save(&user2) - - if DB.Delete(&User{}, user1.Id).Error != nil { - t.Errorf("No error should happen when delete a record") - } else if !DB.Where("name = ?", user1.Name).First(&User{}).RecordNotFound() { - t.Errorf("User can't be found after delete") - } - - if err := DB.Delete(&User{}, "name = ?", user2.Name).Error; err != nil { - t.Errorf("No error should happen when delete a record, err=%s", err) - } else if !DB.Where("name = ?", user2.Name).First(&User{}).RecordNotFound() { - t.Errorf("User can't be found after delete") - } -} - -func TestSoftDelete(t *testing.T) { - type User struct { - Id int64 - Name string - DeletedAt *time.Time - } - DB.AutoMigrate(&User{}) - - user := User{Name: "soft_delete"} - DB.Save(&user) - DB.Delete(&user) - - if DB.First(&User{}, 
"name = ?", user.Name).Error == nil { - t.Errorf("Can't find a soft deleted record") - } - - if err := DB.Unscoped().First(&User{}, "name = ?", user.Name).Error; err != nil { - t.Errorf("Should be able to find soft deleted record with Unscoped, but err=%s", err) - } - - DB.Unscoped().Delete(&user) - if !DB.Unscoped().First(&User{}, "name = ?", user.Name).RecordNotFound() { - t.Errorf("Can't find permanently deleted record") - } -} - -func TestSoftDeleteWithCustomizedDeletedAtColumnName(t *testing.T) { - creditCard := CreditCard{Number: "411111111234567"} - DB.Save(&creditCard) - DB.Delete(&creditCard) - - if deletedAtField, ok := DB.NewScope(&CreditCard{}).FieldByName("DeletedAt"); !ok || deletedAtField.DBName != "deleted_time" { - t.Errorf("CreditCard's DeletedAt's column name should be `deleted_time`") - } - - if DB.First(&CreditCard{}, "number = ?", creditCard.Number).Error == nil { - t.Errorf("Can't find a soft deleted record") - } - - if err := DB.Unscoped().First(&CreditCard{}, "number = ?", creditCard.Number).Error; err != nil { - t.Errorf("Should be able to find soft deleted record with Unscoped, but err=%s", err) - } - - DB.Unscoped().Delete(&creditCard) - if !DB.Unscoped().First(&CreditCard{}, "number = ?", creditCard.Number).RecordNotFound() { - t.Errorf("Can't find permanently deleted record") - } -} diff --git a/dialect.go b/dialect.go deleted file mode 100644 index e879588b..00000000 --- a/dialect.go +++ /dev/null @@ -1,116 +0,0 @@ -package gorm - -import ( - "database/sql" - "fmt" - "reflect" - "strconv" - "strings" -) - -// Dialect interface contains behaviors that differ across SQL database -type Dialect interface { - // GetName get dialect's name - GetName() string - - // SetDB set db for dialect - SetDB(db SQLCommon) - - // BindVar return the placeholder for actual values in SQL statements, in many dbs it is "?", Postgres using $1 - BindVar(i int) string - // Quote quotes field name to avoid SQL parsing exceptions by using a reserved word as a field name - Quote(key string) string - // DataTypeOf return data's sql type - DataTypeOf(field *StructField) string - - // HasIndex check has index or not - HasIndex(tableName string, indexName string) bool - // HasForeignKey check has foreign key or not - HasForeignKey(tableName string, foreignKeyName string) bool - // RemoveIndex remove index - RemoveIndex(tableName string, indexName string) error - // HasTable check has table or not - HasTable(tableName string) bool - // HasColumn check has column or not - HasColumn(tableName string, columnName string) bool - - // LimitAndOffsetSQL return generated SQL with Limit and Offset, as mssql has special case - LimitAndOffsetSQL(limit, offset interface{}) string - // SelectFromDummyTable return select values, for most dbs, `SELECT values` just works, mysql needs `SELECT value FROM DUAL` - SelectFromDummyTable() string - // LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING` - LastInsertIDReturningSuffix(tableName, columnName string) string - - // BuildForeignKeyName returns a foreign key name for the given table, field and reference - BuildForeignKeyName(tableName, field, dest string) string - - // CurrentDatabase return current database name - CurrentDatabase() string -} - -var dialectsMap = map[string]Dialect{} - -func newDialect(name string, db SQLCommon) Dialect { - if value, ok := dialectsMap[name]; ok { - dialect := reflect.New(reflect.TypeOf(value).Elem()).Interface().(Dialect) - dialect.SetDB(db) - return dialect - } - - 
fmt.Printf("`%v` is not officially supported, running under compatibility mode.\n", name) - commontDialect := &commonDialect{} - commontDialect.SetDB(db) - return commontDialect -} - -// RegisterDialect register new dialect -func RegisterDialect(name string, dialect Dialect) { - dialectsMap[name] = dialect -} - -// ParseFieldStructForDialect get field's sql data type -var ParseFieldStructForDialect = func(field *StructField, dialect Dialect) (fieldValue reflect.Value, sqlType string, size int, additionalType string) { - // Get redirected field type - var ( - reflectType = field.Struct.Type - dataType = field.TagSettings["TYPE"] - ) - - for reflectType.Kind() == reflect.Ptr { - reflectType = reflectType.Elem() - } - - // Get redirected field value - fieldValue = reflect.Indirect(reflect.New(reflectType)) - - if gormDataType, ok := fieldValue.Interface().(interface { - GormDataType(Dialect) string - }); ok { - dataType = gormDataType.GormDataType(dialect) - } - - // Get scanner's real value - var getScannerValue func(reflect.Value) - getScannerValue = func(value reflect.Value) { - fieldValue = value - if _, isScanner := reflect.New(fieldValue.Type()).Interface().(sql.Scanner); isScanner && fieldValue.Kind() == reflect.Struct { - getScannerValue(fieldValue.Field(0)) - } - } - getScannerValue(fieldValue) - - // Default Size - if num, ok := field.TagSettings["SIZE"]; ok { - size, _ = strconv.Atoi(num) - } else { - size = 255 - } - - // Default type from tag setting - additionalType = field.TagSettings["NOT NULL"] + " " + field.TagSettings["UNIQUE"] - if value, ok := field.TagSettings["DEFAULT"]; ok { - additionalType = additionalType + " DEFAULT " + value - } - - return fieldValue, dataType, size, strings.TrimSpace(additionalType) -} diff --git a/dialect_common.go b/dialect_common.go deleted file mode 100644 index a99627f2..00000000 --- a/dialect_common.go +++ /dev/null @@ -1,156 +0,0 @@ -package gorm - -import ( - "fmt" - "reflect" - "regexp" - "strconv" - "strings" - "time" -) - -// DefaultForeignKeyNamer contains the default foreign key name generator method -type DefaultForeignKeyNamer struct { -} - -type commonDialect struct { - db SQLCommon - DefaultForeignKeyNamer -} - -func init() { - RegisterDialect("common", &commonDialect{}) -} - -func (commonDialect) GetName() string { - return "common" -} - -func (s *commonDialect) SetDB(db SQLCommon) { - s.db = db -} - -func (commonDialect) BindVar(i int) string { - return "$$$" // ? 
-} - -func (commonDialect) Quote(key string) string { - return fmt.Sprintf(`"%s"`, key) -} - -func (s *commonDialect) DataTypeOf(field *StructField) string { - var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s) - - if sqlType == "" { - switch dataValue.Kind() { - case reflect.Bool: - sqlType = "BOOLEAN" - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok { - sqlType = "INTEGER AUTO_INCREMENT" - } else { - sqlType = "INTEGER" - } - case reflect.Int64, reflect.Uint64: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok { - sqlType = "BIGINT AUTO_INCREMENT" - } else { - sqlType = "BIGINT" - } - case reflect.Float32, reflect.Float64: - sqlType = "FLOAT" - case reflect.String: - if size > 0 && size < 65532 { - sqlType = fmt.Sprintf("VARCHAR(%d)", size) - } else { - sqlType = "VARCHAR(65532)" - } - case reflect.Struct: - if _, ok := dataValue.Interface().(time.Time); ok { - sqlType = "TIMESTAMP" - } - default: - if _, ok := dataValue.Interface().([]byte); ok { - if size > 0 && size < 65532 { - sqlType = fmt.Sprintf("BINARY(%d)", size) - } else { - sqlType = "BINARY(65532)" - } - } - } - } - - if sqlType == "" { - panic(fmt.Sprintf("invalid sql type %s (%s) for commonDialect", dataValue.Type().Name(), dataValue.Kind().String())) - } - - if strings.TrimSpace(additionalType) == "" { - return sqlType - } - return fmt.Sprintf("%v %v", sqlType, additionalType) -} - -func (s commonDialect) HasIndex(tableName string, indexName string) bool { - var count int - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", s.CurrentDatabase(), tableName, indexName).Scan(&count) - return count > 0 -} - -func (s commonDialect) RemoveIndex(tableName string, indexName string) error { - _, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v", indexName)) - return err -} - -func (s commonDialect) HasForeignKey(tableName string, foreignKeyName string) bool { - return false -} - -func (s commonDialect) HasTable(tableName string) bool { - var count int - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", s.CurrentDatabase(), tableName).Scan(&count) - return count > 0 -} - -func (s commonDialect) HasColumn(tableName string, columnName string) bool { - var count int - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? 
AND column_name = ?", s.CurrentDatabase(), tableName, columnName).Scan(&count) - return count > 0 -} - -func (s commonDialect) CurrentDatabase() (name string) { - s.db.QueryRow("SELECT DATABASE()").Scan(&name) - return -} - -func (commonDialect) LimitAndOffsetSQL(limit, offset interface{}) (sql string) { - if limit != nil { - if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 { - sql += fmt.Sprintf(" LIMIT %d", parsedLimit) - } - } - if offset != nil { - if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset >= 0 { - sql += fmt.Sprintf(" OFFSET %d", parsedOffset) - } - } - return -} - -func (commonDialect) SelectFromDummyTable() string { - return "" -} - -func (commonDialect) LastInsertIDReturningSuffix(tableName, columnName string) string { - return "" -} - -func (DefaultForeignKeyNamer) BuildForeignKeyName(tableName, field, dest string) string { - keyName := fmt.Sprintf("%s_%s_%s_foreign", tableName, field, dest) - keyName = regexp.MustCompile("(_*[^a-zA-Z]+_*|_+)").ReplaceAllString(keyName, "_") - return keyName -} - -// IsByteArrayOrSlice returns true of the reflected value is an array or slice -func IsByteArrayOrSlice(value reflect.Value) bool { - return (value.Kind() == reflect.Array || value.Kind() == reflect.Slice) && value.Type().Elem() == reflect.TypeOf(uint8(0)) -} diff --git a/dialect_mysql.go b/dialect_mysql.go deleted file mode 100644 index 6fcd0079..00000000 --- a/dialect_mysql.go +++ /dev/null @@ -1,176 +0,0 @@ -package gorm - -import ( - "crypto/sha1" - "fmt" - "reflect" - "regexp" - "strconv" - "strings" - "time" - "unicode/utf8" -) - -type mysql struct { - commonDialect -} - -func init() { - RegisterDialect("mysql", &mysql{}) -} - -func (mysql) GetName() string { - return "mysql" -} - -func (mysql) Quote(key string) string { - return fmt.Sprintf("`%s`", key) -} - -// Get Data Type for MySQL Dialect -func (s *mysql) DataTypeOf(field *StructField) string { - var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s) - - // MySQL allows only one auto increment column per table, and it must - // be a KEY column. 
- if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok { - if _, ok = field.TagSettings["INDEX"]; !ok && !field.IsPrimaryKey { - delete(field.TagSettings, "AUTO_INCREMENT") - } - } - - if sqlType == "" { - switch dataValue.Kind() { - case reflect.Bool: - sqlType = "boolean" - case reflect.Int8: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { - field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" - sqlType = "tinyint AUTO_INCREMENT" - } else { - sqlType = "tinyint" - } - case reflect.Int, reflect.Int16, reflect.Int32: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { - field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" - sqlType = "int AUTO_INCREMENT" - } else { - sqlType = "int" - } - case reflect.Uint8: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { - field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" - sqlType = "tinyint unsigned AUTO_INCREMENT" - } else { - sqlType = "tinyint unsigned" - } - case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uintptr: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { - field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" - sqlType = "int unsigned AUTO_INCREMENT" - } else { - sqlType = "int unsigned" - } - case reflect.Int64: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { - field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" - sqlType = "bigint AUTO_INCREMENT" - } else { - sqlType = "bigint" - } - case reflect.Uint64: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { - field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" - sqlType = "bigint unsigned AUTO_INCREMENT" - } else { - sqlType = "bigint unsigned" - } - case reflect.Float32, reflect.Float64: - sqlType = "double" - case reflect.String: - if size > 0 && size < 65532 { - sqlType = fmt.Sprintf("varchar(%d)", size) - } else { - sqlType = "longtext" - } - case reflect.Struct: - if _, ok := dataValue.Interface().(time.Time); ok { - if _, ok := field.TagSettings["NOT NULL"]; ok { - sqlType = "timestamp" - } else { - sqlType = "timestamp NULL" - } - } - default: - if IsByteArrayOrSlice(dataValue) { - if size > 0 && size < 65532 { - sqlType = fmt.Sprintf("varbinary(%d)", size) - } else { - sqlType = "longblob" - } - } - } - } - - if sqlType == "" { - panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", dataValue.Type().Name(), dataValue.Kind().String())) - } - - if strings.TrimSpace(additionalType) == "" { - return sqlType - } - return fmt.Sprintf("%v %v", sqlType, additionalType) -} - -func (s mysql) RemoveIndex(tableName string, indexName string) error { - _, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName))) - return err -} - -func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) { - if limit != nil { - if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 { - sql += fmt.Sprintf(" LIMIT %d", parsedLimit) - - if offset != nil { - if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset >= 0 { - sql += fmt.Sprintf(" OFFSET %d", parsedOffset) - } - } - } - } - return -} - -func (s mysql) HasForeignKey(tableName string, foreignKeyName string) bool { - var count int - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_SCHEMA=? AND TABLE_NAME=? AND CONSTRAINT_NAME=? 
AND CONSTRAINT_TYPE='FOREIGN KEY'", s.CurrentDatabase(), tableName, foreignKeyName).Scan(&count) - return count > 0 -} - -func (s mysql) CurrentDatabase() (name string) { - s.db.QueryRow("SELECT DATABASE()").Scan(&name) - return -} - -func (mysql) SelectFromDummyTable() string { - return "FROM DUAL" -} - -func (s mysql) BuildForeignKeyName(tableName, field, dest string) string { - keyName := s.commonDialect.BuildForeignKeyName(tableName, field, dest) - if utf8.RuneCountInString(keyName) <= 64 { - return keyName - } - h := sha1.New() - h.Write([]byte(keyName)) - bs := h.Sum(nil) - - // sha1 is 40 digits, keep first 24 characters of destination - destRunes := []rune(regexp.MustCompile("(_*[^a-zA-Z]+_*|_+)").ReplaceAllString(dest, "_")) - if len(destRunes) > 24 { - destRunes = destRunes[:24] - } - - return fmt.Sprintf("%s%x", string(destRunes), bs) -} diff --git a/dialect_postgres.go b/dialect_postgres.go deleted file mode 100644 index 6fdf4df1..00000000 --- a/dialect_postgres.go +++ /dev/null @@ -1,132 +0,0 @@ -package gorm - -import ( - "fmt" - "reflect" - "strings" - "time" -) - -type postgres struct { - commonDialect -} - -func init() { - RegisterDialect("postgres", &postgres{}) - RegisterDialect("cloudsqlpostgres", &postgres{}) -} - -func (postgres) GetName() string { - return "postgres" -} - -func (postgres) BindVar(i int) string { - return fmt.Sprintf("$%v", i) -} - -func (s *postgres) DataTypeOf(field *StructField) string { - var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s) - - if sqlType == "" { - switch dataValue.Kind() { - case reflect.Bool: - sqlType = "boolean" - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uintptr: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { - field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" - sqlType = "serial" - } else { - sqlType = "integer" - } - case reflect.Int64, reflect.Uint32, reflect.Uint64: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { - field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" - sqlType = "bigserial" - } else { - sqlType = "bigint" - } - case reflect.Float32, reflect.Float64: - sqlType = "numeric" - case reflect.String: - if _, ok := field.TagSettings["SIZE"]; !ok { - size = 0 // if SIZE haven't been set, use `text` as the default type, as there are no performance different - } - - if size > 0 && size < 65532 { - sqlType = fmt.Sprintf("varchar(%d)", size) - } else { - sqlType = "text" - } - case reflect.Struct: - if _, ok := dataValue.Interface().(time.Time); ok { - sqlType = "timestamp with time zone" - } - case reflect.Map: - if dataValue.Type().Name() == "Hstore" { - sqlType = "hstore" - } - default: - if IsByteArrayOrSlice(dataValue) { - sqlType = "bytea" - if isUUID(dataValue) { - sqlType = "uuid" - } - } - } - } - - if sqlType == "" { - panic(fmt.Sprintf("invalid sql type %s (%s) for postgres", dataValue.Type().Name(), dataValue.Kind().String())) - } - - if strings.TrimSpace(additionalType) == "" { - return sqlType - } - return fmt.Sprintf("%v %v", sqlType, additionalType) -} - -func (s postgres) HasIndex(tableName string, indexName string) bool { - var count int - s.db.QueryRow("SELECT count(*) FROM pg_indexes WHERE tablename = $1 AND indexname = $2 AND schemaname = CURRENT_SCHEMA()", tableName, indexName).Scan(&count) - return count > 0 -} - -func (s postgres) HasForeignKey(tableName string, foreignKeyName string) bool { - var count int - 
s.db.QueryRow("SELECT count(con.conname) FROM pg_constraint con WHERE $1::regclass::oid = con.conrelid AND con.conname = $2 AND con.contype='f'", tableName, foreignKeyName).Scan(&count) - return count > 0 -} - -func (s postgres) HasTable(tableName string) bool { - var count int - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = $1 AND table_type = 'BASE TABLE' AND table_schema = CURRENT_SCHEMA()", tableName).Scan(&count) - return count > 0 -} - -func (s postgres) HasColumn(tableName string, columnName string) bool { - var count int - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = $1 AND column_name = $2 AND table_schema = CURRENT_SCHEMA()", tableName, columnName).Scan(&count) - return count > 0 -} - -func (s postgres) CurrentDatabase() (name string) { - s.db.QueryRow("SELECT CURRENT_DATABASE()").Scan(&name) - return -} - -func (s postgres) LastInsertIDReturningSuffix(tableName, key string) string { - return fmt.Sprintf("RETURNING %v.%v", tableName, key) -} - -func (postgres) SupportLastInsertID() bool { - return false -} - -func isUUID(value reflect.Value) bool { - if value.Kind() != reflect.Array || value.Type().Len() != 16 { - return false - } - typename := value.Type().Name() - lower := strings.ToLower(typename) - return "uuid" == lower || "guid" == lower -} diff --git a/dialect_sqlite3.go b/dialect_sqlite3.go deleted file mode 100644 index de9c05cb..00000000 --- a/dialect_sqlite3.go +++ /dev/null @@ -1,107 +0,0 @@ -package gorm - -import ( - "fmt" - "reflect" - "strings" - "time" -) - -type sqlite3 struct { - commonDialect -} - -func init() { - RegisterDialect("sqlite3", &sqlite3{}) -} - -func (sqlite3) GetName() string { - return "sqlite3" -} - -// Get Data Type for Sqlite Dialect -func (s *sqlite3) DataTypeOf(field *StructField) string { - var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s) - - if sqlType == "" { - switch dataValue.Kind() { - case reflect.Bool: - sqlType = "bool" - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: - if field.IsPrimaryKey { - field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" - sqlType = "integer primary key autoincrement" - } else { - sqlType = "integer" - } - case reflect.Int64, reflect.Uint64: - if field.IsPrimaryKey { - field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" - sqlType = "integer primary key autoincrement" - } else { - sqlType = "bigint" - } - case reflect.Float32, reflect.Float64: - sqlType = "real" - case reflect.String: - if size > 0 && size < 65532 { - sqlType = fmt.Sprintf("varchar(%d)", size) - } else { - sqlType = "text" - } - case reflect.Struct: - if _, ok := dataValue.Interface().(time.Time); ok { - sqlType = "datetime" - } - default: - if IsByteArrayOrSlice(dataValue) { - sqlType = "blob" - } - } - } - - if sqlType == "" { - panic(fmt.Sprintf("invalid sql type %s (%s) for sqlite3", dataValue.Type().Name(), dataValue.Kind().String())) - } - - if strings.TrimSpace(additionalType) == "" { - return sqlType - } - return fmt.Sprintf("%v %v", sqlType, additionalType) -} - -func (s sqlite3) HasIndex(tableName string, indexName string) bool { - var count int - s.db.QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? 
AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName).Scan(&count) - return count > 0 -} - -func (s sqlite3) HasTable(tableName string) bool { - var count int - s.db.QueryRow("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", tableName).Scan(&count) - return count > 0 -} - -func (s sqlite3) HasColumn(tableName string, columnName string) bool { - var count int - s.db.QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE '%%\"%v\" %%' OR sql LIKE '%%%v %%');\n", columnName, columnName), tableName).Scan(&count) - return count > 0 -} - -func (s sqlite3) CurrentDatabase() (name string) { - var ( - ifaces = make([]interface{}, 3) - pointers = make([]*string, 3) - i int - ) - for i = 0; i < 3; i++ { - ifaces[i] = &pointers[i] - } - if err := s.db.QueryRow("PRAGMA database_list").Scan(ifaces...); err != nil { - return - } - if pointers[1] != nil { - name = *pointers[1] - } - return -} diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go deleted file mode 100644 index de2ae7ca..00000000 --- a/dialects/mssql/mssql.go +++ /dev/null @@ -1,170 +0,0 @@ -package mssql - -import ( - "fmt" - "reflect" - "strconv" - "strings" - "time" - - _ "github.com/denisenkom/go-mssqldb" - "github.com/jinzhu/gorm" -) - -func setIdentityInsert(scope *gorm.Scope) { - if scope.Dialect().GetName() == "mssql" { - for _, field := range scope.PrimaryFields() { - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok && !field.IsBlank { - scope.NewDB().Exec(fmt.Sprintf("SET IDENTITY_INSERT %v ON", scope.TableName())) - scope.InstanceSet("mssql:identity_insert_on", true) - } - } - } -} - -func turnOffIdentityInsert(scope *gorm.Scope) { - if scope.Dialect().GetName() == "mssql" { - if _, ok := scope.InstanceGet("mssql:identity_insert_on"); ok { - scope.NewDB().Exec(fmt.Sprintf("SET IDENTITY_INSERT %v OFF", scope.TableName())) - } - } -} - -func init() { - gorm.DefaultCallback.Create().After("gorm:begin_transaction").Register("mssql:set_identity_insert", setIdentityInsert) - gorm.DefaultCallback.Create().Before("gorm:commit_or_rollback_transaction").Register("mssql:turn_off_identity_insert", turnOffIdentityInsert) - gorm.RegisterDialect("mssql", &mssql{}) -} - -type mssql struct { - db gorm.SQLCommon - gorm.DefaultForeignKeyNamer -} - -func (mssql) GetName() string { - return "mssql" -} - -func (s *mssql) SetDB(db gorm.SQLCommon) { - s.db = db -} - -func (mssql) BindVar(i int) string { - return "$$$" // ? 
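Note (not part of the patch): as a hedged illustration of the `DataTypeOf` switches in the dialect files above — the tag spellings follow the usual GORM v1 `sql:`/`gorm:` conventions and are an assumption here, since the tag parser is not part of these hunks.

```go
// Hypothetical v1 model; the comments trace the column types the dialect code above would pick.
type Attachment struct {
	ID      uint64 `gorm:"primary_key"`  // mysql: bigint unsigned AUTO_INCREMENT; sqlite3: integer primary key autoincrement
	Title   string `sql:"size:255"`      // 0 < size < 65532 => varchar(255) (mssql caps at 8000 and uses nvarchar)
	Body    string `sql:"type:longtext"` // an explicit type in the tag bypasses the kind switch entirely
	Payload []byte `sql:"size:1024"`     // mysql/mssql: varbinary(1024); postgres: bytea; sqlite3: blob
	Ready   bool   // mysql/postgres: boolean; sqlite3: bool; mssql: bit
}
```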
-} - -func (mssql) Quote(key string) string { - return fmt.Sprintf(`"%s"`, key) -} - -func (s *mssql) DataTypeOf(field *gorm.StructField) string { - var dataValue, sqlType, size, additionalType = gorm.ParseFieldStructForDialect(field, s) - - if sqlType == "" { - switch dataValue.Kind() { - case reflect.Bool: - sqlType = "bit" - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { - field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" - sqlType = "int IDENTITY(1,1)" - } else { - sqlType = "int" - } - case reflect.Int64, reflect.Uint64: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { - field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT" - sqlType = "bigint IDENTITY(1,1)" - } else { - sqlType = "bigint" - } - case reflect.Float32, reflect.Float64: - sqlType = "float" - case reflect.String: - if size > 0 && size < 8000 { - sqlType = fmt.Sprintf("nvarchar(%d)", size) - } else { - sqlType = "nvarchar(max)" - } - case reflect.Struct: - if _, ok := dataValue.Interface().(time.Time); ok { - sqlType = "datetimeoffset" - } - default: - if gorm.IsByteArrayOrSlice(dataValue) { - if size > 0 && size < 8000 { - sqlType = fmt.Sprintf("varbinary(%d)", size) - } else { - sqlType = "varbinary(max)" - } - } - } - } - - if sqlType == "" { - panic(fmt.Sprintf("invalid sql type %s (%s) for mssql", dataValue.Type().Name(), dataValue.Kind().String())) - } - - if strings.TrimSpace(additionalType) == "" { - return sqlType - } - return fmt.Sprintf("%v %v", sqlType, additionalType) -} - -func (s mssql) HasIndex(tableName string, indexName string) bool { - var count int - s.db.QueryRow("SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName).Scan(&count) - return count > 0 -} - -func (s mssql) RemoveIndex(tableName string, indexName string) error { - _, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName))) - return err -} - -func (s mssql) HasForeignKey(tableName string, foreignKeyName string) bool { - return false -} - -func (s mssql) HasTable(tableName string) bool { - var count int - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, s.CurrentDatabase()).Scan(&count) - return count > 0 -} - -func (s mssql) HasColumn(tableName string, columnName string) bool { - var count int - s.db.QueryRow("SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? 
AND column_name = ?", s.CurrentDatabase(), tableName, columnName).Scan(&count) - return count > 0 -} - -func (s mssql) CurrentDatabase() (name string) { - s.db.QueryRow("SELECT DB_NAME() AS [Current Database]").Scan(&name) - return -} - -func (mssql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) { - if offset != nil { - if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset >= 0 { - sql += fmt.Sprintf(" OFFSET %d ROWS", parsedOffset) - } - } - if limit != nil { - if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 { - if sql == "" { - // add default zero offset - sql += " OFFSET 0 ROWS" - } - sql += fmt.Sprintf(" FETCH NEXT %d ROWS ONLY", parsedLimit) - } - } - return -} - -func (mssql) SelectFromDummyTable() string { - return "" -} - -func (mssql) LastInsertIDReturningSuffix(tableName, columnName string) string { - return "" -} diff --git a/dialects/mysql/mysql.go b/dialects/mysql/mysql.go deleted file mode 100644 index 9deba48a..00000000 --- a/dialects/mysql/mysql.go +++ /dev/null @@ -1,3 +0,0 @@ -package mysql - -import _ "github.com/go-sql-driver/mysql" diff --git a/dialects/postgres/postgres.go b/dialects/postgres/postgres.go deleted file mode 100644 index b8e76891..00000000 --- a/dialects/postgres/postgres.go +++ /dev/null @@ -1,77 +0,0 @@ -package postgres - -import ( - "database/sql" - "database/sql/driver" - - _ "github.com/lib/pq" - "github.com/lib/pq/hstore" - "encoding/json" - "errors" - "fmt" -) - -type Hstore map[string]*string - -// Value get value of Hstore -func (h Hstore) Value() (driver.Value, error) { - hstore := hstore.Hstore{Map: map[string]sql.NullString{}} - if len(h) == 0 { - return nil, nil - } - - for key, value := range h { - var s sql.NullString - if value != nil { - s.String = *value - s.Valid = true - } - hstore.Map[key] = s - } - return hstore.Value() -} - -// Scan scan value into Hstore -func (h *Hstore) Scan(value interface{}) error { - hstore := hstore.Hstore{} - - if err := hstore.Scan(value); err != nil { - return err - } - - if len(hstore.Map) == 0 { - return nil - } - - *h = Hstore{} - for k := range hstore.Map { - if hstore.Map[k].Valid { - s := hstore.Map[k].String - (*h)[k] = &s - } else { - (*h)[k] = nil - } - } - - return nil -} - -// Jsonb Postgresql's JSONB data type -type Jsonb struct { - json.RawMessage -} - -// Value get value of Jsonb -func (j Jsonb) Value() (driver.Value, error) { - return j.MarshalJSON() -} - -// Scan scan value into Jsonb -func (j *Jsonb) Scan(value interface{}) error { - bytes, ok := value.([]byte) - if !ok { - return errors.New(fmt.Sprint("Failed to unmarshal JSONB value:", value)) - } - - return json.Unmarshal(bytes, j) -} diff --git a/dialects/sqlite/sqlite.go b/dialects/sqlite/sqlite.go deleted file mode 100644 index 069ad3a9..00000000 --- a/dialects/sqlite/sqlite.go +++ /dev/null @@ -1,3 +0,0 @@ -package sqlite - -import _ "github.com/mattn/go-sqlite3" diff --git a/embedded_struct_test.go b/embedded_struct_test.go deleted file mode 100644 index 91dd0563..00000000 --- a/embedded_struct_test.go +++ /dev/null @@ -1,73 +0,0 @@ -package gorm_test - -import "testing" - -type BasePost struct { - Id int64 - Title string - URL string -} - -type Author struct { - ID string - Name string - Email string -} - -type HNPost struct { - BasePost - Author `gorm:"embedded_prefix:user_"` // Embedded struct - Upvotes int32 -} - -type EngadgetPost struct { - BasePost BasePost `gorm:"embedded"` - Author Author 
`gorm:"embedded;embedded_prefix:author_"` // Embedded struct - ImageUrl string -} - -func TestPrefixColumnNameForEmbeddedStruct(t *testing.T) { - dialect := DB.NewScope(&EngadgetPost{}).Dialect() - engadgetPostScope := DB.NewScope(&EngadgetPost{}) - if !dialect.HasColumn(engadgetPostScope.TableName(), "author_id") || !dialect.HasColumn(engadgetPostScope.TableName(), "author_name") || !dialect.HasColumn(engadgetPostScope.TableName(), "author_email") { - t.Errorf("should has prefix for embedded columns") - } - - if len(engadgetPostScope.PrimaryFields()) != 1 { - t.Errorf("should have only one primary field with embedded struct, but got %v", len(engadgetPostScope.PrimaryFields())) - } - - hnScope := DB.NewScope(&HNPost{}) - if !dialect.HasColumn(hnScope.TableName(), "user_id") || !dialect.HasColumn(hnScope.TableName(), "user_name") || !dialect.HasColumn(hnScope.TableName(), "user_email") { - t.Errorf("should has prefix for embedded columns") - } -} - -func TestSaveAndQueryEmbeddedStruct(t *testing.T) { - DB.Save(&HNPost{BasePost: BasePost{Title: "news"}}) - DB.Save(&HNPost{BasePost: BasePost{Title: "hn_news"}}) - var news HNPost - if err := DB.First(&news, "title = ?", "hn_news").Error; err != nil { - t.Errorf("no error should happen when query with embedded struct, but got %v", err) - } else if news.Title != "hn_news" { - t.Errorf("embedded struct's value should be scanned correctly") - } - - DB.Save(&EngadgetPost{BasePost: BasePost{Title: "engadget_news"}}) - var egNews EngadgetPost - if err := DB.First(&egNews, "title = ?", "engadget_news").Error; err != nil { - t.Errorf("no error should happen when query with embedded struct, but got %v", err) - } else if egNews.BasePost.Title != "engadget_news" { - t.Errorf("embedded struct's value should be scanned correctly") - } - - if DB.NewScope(&HNPost{}).PrimaryField() == nil { - t.Errorf("primary key with embedded struct should works") - } - - for _, field := range DB.NewScope(&HNPost{}).Fields() { - if field.Name == "BasePost" { - t.Errorf("scope Fields should not contain embedded struct") - } - } -} diff --git a/errors.go b/errors.go index 6845188e..f1f6c137 100644 --- a/errors.go +++ b/errors.go @@ -2,59 +2,41 @@ package gorm import ( "errors" - "strings" + + "gorm.io/gorm/logger" ) var ( - // ErrRecordNotFound record not found error, happens when haven't find any matched data when looking up with a struct - ErrRecordNotFound = errors.New("record not found") - // ErrInvalidSQL invalid SQL error, happens when you passed invalid SQL - ErrInvalidSQL = errors.New("invalid SQL") + // ErrRecordNotFound record not found error + ErrRecordNotFound = logger.ErrRecordNotFound // ErrInvalidTransaction invalid transaction when you are trying to `Commit` or `Rollback` ErrInvalidTransaction = errors.New("no valid transaction") - // ErrCantStartTransaction can't start transaction when you are trying to start one with `Begin` - ErrCantStartTransaction = errors.New("can't start transaction") - // ErrUnaddressable unaddressable value - ErrUnaddressable = errors.New("using unaddressable value") + // ErrNotImplemented not implemented + ErrNotImplemented = errors.New("not implemented") + // ErrMissingWhereClause missing where clause + ErrMissingWhereClause = errors.New("WHERE conditions required") + // ErrUnsupportedRelation unsupported relations + ErrUnsupportedRelation = errors.New("unsupported relations") + // ErrPrimaryKeyRequired primary keys required + ErrPrimaryKeyRequired = errors.New("primary key required") + // ErrModelValueRequired model value required + 
ErrModelValueRequired = errors.New("model value required") + // ErrInvalidData unsupported data + ErrInvalidData = errors.New("unsupported data") + // ErrUnsupportedDriver unsupported driver + ErrUnsupportedDriver = errors.New("unsupported driver") + // ErrRegistered registered + ErrRegistered = errors.New("registered") + // ErrInvalidField invalid field + ErrInvalidField = errors.New("invalid field") + // ErrEmptySlice empty slice found + ErrEmptySlice = errors.New("empty slice found") + // ErrDryRunModeUnsupported dry run mode unsupported + ErrDryRunModeUnsupported = errors.New("dry run mode unsupported") + // ErrInvalidDB invalid db + ErrInvalidDB = errors.New("invalid db") + // ErrInvalidValue invalid value + ErrInvalidValue = errors.New("invalid value, should be pointer to struct or slice") + // ErrInvalidValueOfLength invalid values do not match length + ErrInvalidValueOfLength = errors.New("invalid association values, length doesn't match") ) - -// Errors contains all happened errors -type Errors []error - -// GetErrors gets all happened errors -func (errs Errors) GetErrors() []error { - return errs -} - -// Add adds an error -func (errs Errors) Add(newErrors ...error) Errors { - for _, err := range newErrors { - if err == nil { - continue - } - - if errors, ok := err.(Errors); ok { - errs = errs.Add(errors...) - } else { - ok = true - for _, e := range errs { - if err == e { - ok = false - } - } - if ok { - errs = append(errs, err) - } - } - } - return errs -} - -// Error format happened errors -func (errs Errors) Error() string { - var errors = []string{} - for _, e := range errs { - errors = append(errors, e.Error()) - } - return strings.Join(errors, "; ") -} diff --git a/errors_test.go b/errors_test.go deleted file mode 100644 index 9a428dec..00000000 --- a/errors_test.go +++ /dev/null @@ -1,20 +0,0 @@ -package gorm_test - -import ( - "errors" - "testing" - - "github.com/jinzhu/gorm" -) - -func TestErrorsCanBeUsedOutsideGorm(t *testing.T) { - errs := []error{errors.New("First"), errors.New("Second")} - - gErrs := gorm.Errors(errs) - gErrs = gErrs.Add(errors.New("Third")) - gErrs = gErrs.Add(gErrs) - - if gErrs.Error() != "First; Second; Third" { - t.Fatalf("Gave wrong error, got %s", gErrs.Error()) - } -} diff --git a/field.go b/field.go deleted file mode 100644 index 11c410b0..00000000 --- a/field.go +++ /dev/null @@ -1,58 +0,0 @@ -package gorm - -import ( - "database/sql" - "errors" - "fmt" - "reflect" -) - -// Field model field definition -type Field struct { - *StructField - IsBlank bool - Field reflect.Value -} - -// Set set a value to the field -func (field *Field) Set(value interface{}) (err error) { - if !field.Field.IsValid() { - return errors.New("field value not valid") - } - - if !field.Field.CanAddr() { - return ErrUnaddressable - } - - reflectValue, ok := value.(reflect.Value) - if !ok { - reflectValue = reflect.ValueOf(value) - } - - fieldValue := field.Field - if reflectValue.IsValid() { - if reflectValue.Type().ConvertibleTo(fieldValue.Type()) { - fieldValue.Set(reflectValue.Convert(fieldValue.Type())) - } else { - if fieldValue.Kind() == reflect.Ptr { - if fieldValue.IsNil() { - fieldValue.Set(reflect.New(field.Struct.Type.Elem())) - } - fieldValue = fieldValue.Elem() - } - - if reflectValue.Type().ConvertibleTo(fieldValue.Type()) { - fieldValue.Set(reflectValue.Convert(fieldValue.Type())) - } else if scanner, ok := fieldValue.Addr().Interface().(sql.Scanner); ok { - err = scanner.Scan(reflectValue.Interface()) - } else { - err = fmt.Errorf("could not convert 
argument of field %s from %s to %s", field.Name, reflectValue.Type(), fieldValue.Type()) - } - } - } else { - field.Field.Set(reflect.Zero(field.Field.Type())) - } - - field.IsBlank = isBlank(field.Field) - return err -} diff --git a/field_test.go b/field_test.go deleted file mode 100644 index 30e9a778..00000000 --- a/field_test.go +++ /dev/null @@ -1,49 +0,0 @@ -package gorm_test - -import ( - "testing" - - "github.com/jinzhu/gorm" -) - -type CalculateField struct { - gorm.Model - Name string - Children []CalculateFieldChild - Category CalculateFieldCategory - EmbeddedField -} - -type EmbeddedField struct { - EmbeddedName string `sql:"NOT NULL;DEFAULT:'hello'"` -} - -type CalculateFieldChild struct { - gorm.Model - CalculateFieldID uint - Name string -} - -type CalculateFieldCategory struct { - gorm.Model - CalculateFieldID uint - Name string -} - -func TestCalculateField(t *testing.T) { - var field CalculateField - var scope = DB.NewScope(&field) - if field, ok := scope.FieldByName("Children"); !ok || field.Relationship == nil { - t.Errorf("Should calculate fields correctly for the first time") - } - - if field, ok := scope.FieldByName("Category"); !ok || field.Relationship == nil { - t.Errorf("Should calculate fields correctly for the first time") - } - - if field, ok := scope.FieldByName("embedded_name"); !ok { - t.Errorf("should find embedded field") - } else if _, ok := field.TagSettings["NOT NULL"]; !ok { - t.Errorf("should find embedded field's tag settings") - } -} diff --git a/finisher_api.go b/finisher_api.go new file mode 100644 index 00000000..b5cbfaa6 --- /dev/null +++ b/finisher_api.go @@ -0,0 +1,649 @@ +package gorm + +import ( + "database/sql" + "errors" + "fmt" + "reflect" + "strings" + + "gorm.io/gorm/clause" + "gorm.io/gorm/logger" + "gorm.io/gorm/schema" + "gorm.io/gorm/utils" +) + +// Create insert the value into database +func (db *DB) Create(value interface{}) (tx *DB) { + if db.CreateBatchSize > 0 { + return db.CreateInBatches(value, db.CreateBatchSize) + } + + tx = db.getInstance() + tx.Statement.Dest = value + tx.callbacks.Create().Execute(tx) + return +} + +// CreateInBatches insert the value in batches into database +func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) { + reflectValue := reflect.Indirect(reflect.ValueOf(value)) + + switch reflectValue.Kind() { + case reflect.Slice, reflect.Array: + var rowsAffected int64 + tx = db.getInstance() + + callFc := func(tx *DB) error { + // the reflection length judgment of the optimized value + reflectLen := reflectValue.Len() + for i := 0; i < reflectLen; i += batchSize { + ends := i + batchSize + if ends > reflectLen { + ends = reflectLen + } + + subtx := tx.getInstance() + subtx.Statement.Dest = reflectValue.Slice(i, ends).Interface() + subtx.callbacks.Create().Execute(subtx) + if subtx.Error != nil { + return subtx.Error + } + rowsAffected += subtx.RowsAffected + } + return nil + } + + if tx.SkipDefaultTransaction { + tx.AddError(callFc(tx.Session(&Session{}))) + } else { + tx.AddError(tx.Transaction(callFc)) + } + + tx.RowsAffected = rowsAffected + default: + tx = db.getInstance() + tx.Statement.Dest = value + tx.callbacks.Create().Execute(tx) + } + return +} + +// Save update value in database, if the value doesn't have primary key, will insert it +func (db *DB) Save(value interface{}) (tx *DB) { + tx = db.getInstance() + tx.Statement.Dest = value + + reflectValue := reflect.Indirect(reflect.ValueOf(value)) + switch reflectValue.Kind() { + case reflect.Slice, reflect.Array: + if _, ok := 
tx.Statement.Clauses["ON CONFLICT"]; !ok { + tx = tx.Clauses(clause.OnConflict{UpdateAll: true}) + } + tx.callbacks.Create().Execute(tx.InstanceSet("gorm:update_track_time", true)) + case reflect.Struct: + if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil { + for _, pf := range tx.Statement.Schema.PrimaryFields { + if _, isZero := pf.ValueOf(reflectValue); isZero { + tx.callbacks.Create().Execute(tx) + return + } + } + } + + fallthrough + default: + selectedUpdate := len(tx.Statement.Selects) != 0 + // when updating, use all fields including those zero-value fields + if !selectedUpdate { + tx.Statement.Selects = append(tx.Statement.Selects, "*") + } + + tx.callbacks.Update().Execute(tx) + + if tx.Error == nil && tx.RowsAffected == 0 && !tx.DryRun && !selectedUpdate { + result := reflect.New(tx.Statement.Schema.ModelType).Interface() + if err := tx.Session(&Session{}).First(result).Error; errors.Is(err, ErrRecordNotFound) { + return tx.Create(value) + } + } + } + + return +} + +// First find first record that match given conditions, order by primary key +func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) { + tx = db.Limit(1).Order(clause.OrderByColumn{ + Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, + }) + if len(conds) > 0 { + if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 { + tx.Statement.AddClause(clause.Where{Exprs: exprs}) + } + } + tx.Statement.RaiseErrorOnNotFound = true + tx.Statement.Dest = dest + tx.callbacks.Query().Execute(tx) + return +} + +// Take return a record that match given conditions, the order will depend on the database implementation +func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) { + tx = db.Limit(1) + if len(conds) > 0 { + if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 { + tx.Statement.AddClause(clause.Where{Exprs: exprs}) + } + } + tx.Statement.RaiseErrorOnNotFound = true + tx.Statement.Dest = dest + tx.callbacks.Query().Execute(tx) + return +} + +// Last find last record that match given conditions, order by primary key +func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) { + tx = db.Limit(1).Order(clause.OrderByColumn{ + Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, + Desc: true, + }) + if len(conds) > 0 { + if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 { + tx.Statement.AddClause(clause.Where{Exprs: exprs}) + } + } + tx.Statement.RaiseErrorOnNotFound = true + tx.Statement.Dest = dest + tx.callbacks.Query().Execute(tx) + return +} + +// Find find records that match given conditions +func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) { + tx = db.getInstance() + if len(conds) > 0 { + if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 { + tx.Statement.AddClause(clause.Where{Exprs: exprs}) + } + } + tx.Statement.Dest = dest + tx.callbacks.Query().Execute(tx) + return +} + +// FindInBatches find records in batches +func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, batch int) error) *DB { + var ( + tx = db.Order(clause.OrderByColumn{ + Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, + }).Session(&Session{}) + queryDB = tx + rowsAffected int64 + batch int + ) + + for { + result := queryDB.Limit(batchSize).Find(dest) + rowsAffected += result.RowsAffected + batch++ + + if result.Error == nil && result.RowsAffected != 0 { + 
tx.AddError(fc(result, batch)) + } else if result.Error != nil { + tx.AddError(result.Error) + } + + if tx.Error != nil || int(result.RowsAffected) < batchSize { + break + } else { + resultsValue := reflect.Indirect(reflect.ValueOf(dest)) + if result.Statement.Schema.PrioritizedPrimaryField == nil { + tx.AddError(ErrPrimaryKeyRequired) + break + } else { + primaryValue, _ := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(resultsValue.Index(resultsValue.Len() - 1)) + queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue}) + } + } + } + + tx.RowsAffected = rowsAffected + return tx +} + +func (tx *DB) assignInterfacesToValue(values ...interface{}) { + for _, value := range values { + switch v := value.(type) { + case []clause.Expression: + for _, expr := range v { + if eq, ok := expr.(clause.Eq); ok { + switch column := eq.Column.(type) { + case string: + if field := tx.Statement.Schema.LookUpField(column); field != nil { + tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value)) + } + case clause.Column: + if field := tx.Statement.Schema.LookUpField(column.Name); field != nil { + tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value)) + } + } + } else if andCond, ok := expr.(clause.AndConditions); ok { + tx.assignInterfacesToValue(andCond.Exprs) + } + } + case clause.Expression, map[string]string, map[interface{}]interface{}, map[string]interface{}: + if exprs := tx.Statement.BuildCondition(value); len(exprs) > 0 { + tx.assignInterfacesToValue(exprs) + } + default: + if s, err := schema.Parse(value, tx.cacheStore, tx.NamingStrategy); err == nil { + reflectValue := reflect.Indirect(reflect.ValueOf(value)) + switch reflectValue.Kind() { + case reflect.Struct: + for _, f := range s.Fields { + if f.Readable { + if v, isZero := f.ValueOf(reflectValue); !isZero { + if field := tx.Statement.Schema.LookUpField(f.Name); field != nil { + tx.AddError(field.Set(tx.Statement.ReflectValue, v)) + } + } + } + } + } + } else if len(values) > 0 { + if exprs := tx.Statement.BuildCondition(values[0], values[1:]...); len(exprs) > 0 { + tx.assignInterfacesToValue(exprs) + } + return + } + } + } +} + +func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { + queryTx := db.Limit(1).Order(clause.OrderByColumn{ + Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, + }) + + if tx = queryTx.Find(dest, conds...); queryTx.RowsAffected == 0 { + if c, ok := tx.Statement.Clauses["WHERE"]; ok { + if where, ok := c.Expression.(clause.Where); ok { + tx.assignInterfacesToValue(where.Exprs) + } + } + + // initialize with attrs, conds + if len(tx.Statement.attrs) > 0 { + tx.assignInterfacesToValue(tx.Statement.attrs...) + } + } + + // initialize with attrs, conds + if len(tx.Statement.assigns) > 0 { + tx.assignInterfacesToValue(tx.Statement.assigns...) + } + return +} + +func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) { + queryTx := db.Limit(1).Order(clause.OrderByColumn{ + Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, + }) + + if tx = queryTx.Find(dest, conds...); queryTx.RowsAffected == 0 { + if c, ok := tx.Statement.Clauses["WHERE"]; ok { + if where, ok := c.Expression.(clause.Where); ok { + tx.assignInterfacesToValue(where.Exprs) + } + } + + // initialize with attrs, conds + if len(tx.Statement.attrs) > 0 { + tx.assignInterfacesToValue(tx.Statement.attrs...) 
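Note (not part of the patch): `FirstOrInit` and `FirstOrCreate` both seed the destination from the WHERE conditions plus any attrs/assigns stored on the statement. The chainable `Attrs`/`Assign` helpers are defined outside this hunk, so treat the calls below as an assumption-level sketch.

```go
var user User
// Miss: user is initialised (not saved) from the conditions and Attrs.
db.Where(User{Name: "non_existing"}).Attrs(User{Age: 20}).FirstOrInit(&user)

// FirstOrCreate inserts on a miss; with Assign, a hit turns into Model(dest).Updates(assigns),
// matching the assigns branch just below.
db.Where(User{Name: "jinzhu"}).Assign(User{Age: 30}).FirstOrCreate(&user)
```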
+ } + + // initialize with attrs, conds + if len(tx.Statement.assigns) > 0 { + tx.assignInterfacesToValue(tx.Statement.assigns...) + } + + return tx.Create(dest) + } else if len(db.Statement.assigns) > 0 { + exprs := tx.Statement.BuildCondition(tx.Statement.assigns[0], tx.Statement.assigns[1:]...) + assigns := map[string]interface{}{} + for _, expr := range exprs { + if eq, ok := expr.(clause.Eq); ok { + switch column := eq.Column.(type) { + case string: + assigns[column] = eq.Value + case clause.Column: + assigns[column.Name] = eq.Value + default: + } + } + } + + return tx.Model(dest).Updates(assigns) + } + + return db +} + +// Update update attributes with callbacks, refer: https://gorm.io/docs/update.html#Update-Changed-Fields +func (db *DB) Update(column string, value interface{}) (tx *DB) { + tx = db.getInstance() + tx.Statement.Dest = map[string]interface{}{column: value} + tx.callbacks.Update().Execute(tx) + return +} + +// Updates update attributes with callbacks, refer: https://gorm.io/docs/update.html#Update-Changed-Fields +func (db *DB) Updates(values interface{}) (tx *DB) { + tx = db.getInstance() + tx.Statement.Dest = values + tx.callbacks.Update().Execute(tx) + return +} + +func (db *DB) UpdateColumn(column string, value interface{}) (tx *DB) { + tx = db.getInstance() + tx.Statement.Dest = map[string]interface{}{column: value} + tx.Statement.SkipHooks = true + tx.callbacks.Update().Execute(tx) + return +} + +func (db *DB) UpdateColumns(values interface{}) (tx *DB) { + tx = db.getInstance() + tx.Statement.Dest = values + tx.Statement.SkipHooks = true + tx.callbacks.Update().Execute(tx) + return +} + +// Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition +func (db *DB) Delete(value interface{}, conds ...interface{}) (tx *DB) { + tx = db.getInstance() + if len(conds) > 0 { + if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 { + tx.Statement.AddClause(clause.Where{Exprs: exprs}) + } + } + tx.Statement.Dest = value + tx.callbacks.Delete().Execute(tx) + return +} + +func (db *DB) Count(count *int64) (tx *DB) { + tx = db.getInstance() + if tx.Statement.Model == nil { + tx.Statement.Model = tx.Statement.Dest + defer func() { + tx.Statement.Model = nil + }() + } + + if selectClause, ok := db.Statement.Clauses["SELECT"]; ok { + defer func() { + db.Statement.Clauses["SELECT"] = selectClause + }() + } else { + defer delete(tx.Statement.Clauses, "SELECT") + } + + if len(tx.Statement.Selects) == 0 { + tx.Statement.AddClause(clause.Select{Expression: clause.Expr{SQL: "count(1)"}}) + } else if !strings.HasPrefix(strings.TrimSpace(strings.ToLower(tx.Statement.Selects[0])), "count(") { + expr := clause.Expr{SQL: "count(1)"} + + if len(tx.Statement.Selects) == 1 { + dbName := tx.Statement.Selects[0] + fields := strings.FieldsFunc(dbName, utils.IsValidDBNameChar) + if len(fields) == 1 || (len(fields) == 3 && strings.ToUpper(fields[1]) == "AS") { + if tx.Statement.Parse(tx.Statement.Model) == nil { + if f := tx.Statement.Schema.LookUpField(dbName); f != nil { + dbName = f.DBName + } + } + + if tx.Statement.Distinct { + expr = clause.Expr{SQL: "COUNT(DISTINCT(?))", Vars: []interface{}{clause.Column{Name: dbName}}} + } else { + expr = clause.Expr{SQL: "COUNT(?)", Vars: []interface{}{clause.Column{Name: dbName}}} + } + } + } + + tx.Statement.AddClause(clause.Select{Expression: expr}) + } + + if orderByClause, ok := db.Statement.Clauses["ORDER BY"]; ok { + if _, ok := db.Statement.Clauses["GROUP 
BY"]; !ok { + delete(db.Statement.Clauses, "ORDER BY") + defer func() { + db.Statement.Clauses["ORDER BY"] = orderByClause + }() + } + } + + tx.Statement.Dest = count + tx.callbacks.Query().Execute(tx) + if tx.RowsAffected != 1 { + *count = tx.RowsAffected + } + return +} + +func (db *DB) Row() *sql.Row { + tx := db.getInstance().InstanceSet("rows", false) + tx.callbacks.Row().Execute(tx) + row, ok := tx.Statement.Dest.(*sql.Row) + if !ok && tx.DryRun { + db.Logger.Error(tx.Statement.Context, ErrDryRunModeUnsupported.Error()) + } + return row +} + +func (db *DB) Rows() (*sql.Rows, error) { + tx := db.getInstance().InstanceSet("rows", true) + tx.callbacks.Row().Execute(tx) + rows, ok := tx.Statement.Dest.(*sql.Rows) + if !ok && tx.DryRun && tx.Error == nil { + tx.Error = ErrDryRunModeUnsupported + } + return rows, tx.Error +} + +// Scan scan value to a struct +func (db *DB) Scan(dest interface{}) (tx *DB) { + config := *db.Config + currentLogger, newLogger := config.Logger, logger.Recorder.New() + config.Logger = newLogger + + tx = db.getInstance() + tx.Config = &config + + if rows, err := tx.Rows(); err != nil { + tx.AddError(err) + } else { + defer rows.Close() + if rows.Next() { + tx.ScanRows(rows, dest) + } else { + tx.RowsAffected = 0 + } + } + + currentLogger.Trace(tx.Statement.Context, newLogger.BeginAt, func() (string, int64) { + return newLogger.SQL, tx.RowsAffected + }, tx.Error) + tx.Logger = currentLogger + return +} + +// Pluck used to query single column from a model as a map +// var ages []int64 +// db.Find(&users).Pluck("age", &ages) +func (db *DB) Pluck(column string, dest interface{}) (tx *DB) { + tx = db.getInstance() + if tx.Statement.Model != nil { + if tx.Statement.Parse(tx.Statement.Model) == nil { + if f := tx.Statement.Schema.LookUpField(column); f != nil { + column = f.DBName + } + } + } else if tx.Statement.Table == "" { + tx.AddError(ErrModelValueRequired) + } + + if len(tx.Statement.Selects) != 1 { + fields := strings.FieldsFunc(column, utils.IsValidDBNameChar) + tx.Statement.AddClauseIfNotExists(clause.Select{ + Distinct: tx.Statement.Distinct, + Columns: []clause.Column{{Name: column, Raw: len(fields) != 1}}, + }) + } + tx.Statement.Dest = dest + tx.callbacks.Query().Execute(tx) + return +} + +func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error { + tx := db.getInstance() + if err := tx.Statement.Parse(dest); !errors.Is(err, schema.ErrUnsupportedDataType) { + tx.AddError(err) + } + tx.Statement.Dest = dest + tx.Statement.ReflectValue = reflect.ValueOf(dest) + for tx.Statement.ReflectValue.Kind() == reflect.Ptr { + tx.Statement.ReflectValue = tx.Statement.ReflectValue.Elem() + } + Scan(rows, tx, true) + return tx.Error +} + +// Transaction start a transaction as a block, return error will rollback, otherwise to commit. +func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err error) { + panicked := true + + if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil { + // nested transaction + if !db.DisableNestedTransaction { + err = db.SavePoint(fmt.Sprintf("sp%p", fc)).Error + defer func() { + // Make sure to rollback when panic, Block error or Commit error + if panicked || err != nil { + db.RollbackTo(fmt.Sprintf("sp%p", fc)) + } + }() + } + + if err == nil { + err = fc(db.Session(&Session{})) + } + } else { + tx := db.Begin(opts...) 
+ + defer func() { + // Make sure to rollback when panic, Block error or Commit error + if panicked || err != nil { + tx.Rollback() + } + }() + + if err = tx.Error; err == nil { + err = fc(tx) + } + + if err == nil { + err = tx.Commit().Error + } + } + + panicked = false + return +} + +// Begin begins a transaction +func (db *DB) Begin(opts ...*sql.TxOptions) *DB { + var ( + // clone statement + tx = db.getInstance().Session(&Session{Context: db.Statement.Context}) + opt *sql.TxOptions + err error + ) + + if len(opts) > 0 { + opt = opts[0] + } + + if beginner, ok := tx.Statement.ConnPool.(TxBeginner); ok { + tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) + } else if beginner, ok := tx.Statement.ConnPool.(ConnPoolBeginner); ok { + tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) + } else { + err = ErrInvalidTransaction + } + + if err != nil { + tx.AddError(err) + } + + return tx +} + +// Commit commit a transaction +func (db *DB) Commit() *DB { + if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil && !reflect.ValueOf(committer).IsNil() { + db.AddError(committer.Commit()) + } else { + db.AddError(ErrInvalidTransaction) + } + return db +} + +// Rollback rollback a transaction +func (db *DB) Rollback() *DB { + if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil { + if !reflect.ValueOf(committer).IsNil() { + db.AddError(committer.Rollback()) + } + } else { + db.AddError(ErrInvalidTransaction) + } + return db +} + +func (db *DB) SavePoint(name string) *DB { + if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok { + db.AddError(savePointer.SavePoint(db, name)) + } else { + db.AddError(ErrUnsupportedDriver) + } + return db +} + +func (db *DB) RollbackTo(name string) *DB { + if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok { + db.AddError(savePointer.RollbackTo(db, name)) + } else { + db.AddError(ErrUnsupportedDriver) + } + return db +} + +// Exec execute raw sql +func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) { + tx = db.getInstance() + tx.Statement.SQL = strings.Builder{} + + if strings.Contains(sql, "@") { + clause.NamedExpr{SQL: sql, Vars: values}.Build(tx.Statement) + } else { + clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement) + } + + tx.callbacks.Raw().Execute(tx) + return +} diff --git a/go.mod b/go.mod new file mode 100644 index 00000000..d95d3f10 --- /dev/null +++ b/go.mod @@ -0,0 +1,8 @@ +module gorm.io/gorm + +go 1.14 + +require ( + github.com/jinzhu/inflection v1.0.0 + github.com/jinzhu/now v1.1.2 +) diff --git a/go.sum b/go.sum new file mode 100644 index 00000000..c66a6b57 --- /dev/null +++ b/go.sum @@ -0,0 +1,4 @@ +github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= +github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/jinzhu/now v1.1.2 h1:eVKgfIdy9b6zbWBMgFpfDPoAMifwSZagU9HmEU6zgiI= +github.com/jinzhu/now v1.1.2/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= diff --git a/gorm.go b/gorm.go new file mode 100644 index 00000000..e105a933 --- /dev/null +++ b/gorm.go @@ -0,0 +1,441 @@ +package gorm + +import ( + "context" + "database/sql" + "fmt" + "sort" + "sync" + "time" + + "gorm.io/gorm/clause" + "gorm.io/gorm/logger" + "gorm.io/gorm/schema" +) + +// for Config.cacheStore store PreparedStmtDB key +const preparedStmtDBKey = "preparedStmt" + +// Config GORM config +type Config struct { + // GORM perform single create, update, delete 
operations in transactions by default to ensure database data integrity + // You can disable it by setting `SkipDefaultTransaction` to true + SkipDefaultTransaction bool + // NamingStrategy tables, columns naming strategy + NamingStrategy schema.Namer + // FullSaveAssociations full save associations + FullSaveAssociations bool + // Logger + Logger logger.Interface + // NowFunc the function to be used when creating a new timestamp + NowFunc func() time.Time + // DryRun generate sql without execute + DryRun bool + // PrepareStmt executes the given query in cached statement + PrepareStmt bool + // DisableAutomaticPing + DisableAutomaticPing bool + // DisableForeignKeyConstraintWhenMigrating + DisableForeignKeyConstraintWhenMigrating bool + // DisableNestedTransaction disable nested transaction + DisableNestedTransaction bool + // AllowGlobalUpdate allow global update + AllowGlobalUpdate bool + // QueryFields executes the SQL query with all fields of the table + QueryFields bool + // CreateBatchSize default create batch size + CreateBatchSize int + + // ClauseBuilders clause builder + ClauseBuilders map[string]clause.ClauseBuilder + // ConnPool db conn pool + ConnPool ConnPool + // Dialector database dialector + Dialector + // Plugins registered plugins + Plugins map[string]Plugin + + callbacks *callbacks + cacheStore *sync.Map +} + +func (c *Config) Apply(config *Config) error { + if config != c { + *config = *c + } + return nil +} + +func (c *Config) AfterInitialize(db *DB) error { + if db != nil { + for _, plugin := range c.Plugins { + if err := plugin.Initialize(db); err != nil { + return err + } + } + } + return nil +} + +type Option interface { + Apply(*Config) error + AfterInitialize(*DB) error +} + +// DB GORM DB definition +type DB struct { + *Config + Error error + RowsAffected int64 + Statement *Statement + clone int +} + +// Session session config when create session with Session() method +type Session struct { + DryRun bool + PrepareStmt bool + NewDB bool + SkipHooks bool + SkipDefaultTransaction bool + DisableNestedTransaction bool + AllowGlobalUpdate bool + FullSaveAssociations bool + QueryFields bool + Context context.Context + Logger logger.Interface + NowFunc func() time.Time + CreateBatchSize int +} + +// Open initialize db session based on dialector +func Open(dialector Dialector, opts ...Option) (db *DB, err error) { + config := &Config{} + + sort.Slice(opts, func(i, j int) bool { + _, isConfig := opts[i].(*Config) + _, isConfig2 := opts[j].(*Config) + return isConfig && !isConfig2 + }) + + for _, opt := range opts { + if opt != nil { + if err := opt.Apply(config); err != nil { + return nil, err + } + defer func(opt Option) { + if errr := opt.AfterInitialize(db); errr != nil { + err = errr + } + }(opt) + } + } + + if d, ok := dialector.(interface{ Apply(*Config) error }); ok { + if err = d.Apply(config); err != nil { + return + } + } + + if config.NamingStrategy == nil { + config.NamingStrategy = schema.NamingStrategy{} + } + + if config.Logger == nil { + config.Logger = logger.Default + } + + if config.NowFunc == nil { + config.NowFunc = func() time.Time { return time.Now().Local() } + } + + if dialector != nil { + config.Dialector = dialector + } + + if config.Plugins == nil { + config.Plugins = map[string]Plugin{} + } + + if config.cacheStore == nil { + config.cacheStore = &sync.Map{} + } + + db = &DB{Config: config, clone: 1} + + db.callbacks = initializeCallbacks(db) + + if config.ClauseBuilders == nil { + config.ClauseBuilders = map[string]clause.ClauseBuilder{} + } 
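Note (not part of the patch): a minimal sketch of opening a session with the `Config` options declared above. The `gorm.io/driver/sqlite` import path is an assumption — drivers live outside this module in v2.

```go
package main

import (
	"gorm.io/driver/sqlite" // driver import path assumed; not part of this diff
	"gorm.io/gorm"
	"gorm.io/gorm/logger"
)

func main() {
	db, err := gorm.Open(sqlite.Open("test.db"), &gorm.Config{
		SkipDefaultTransaction: true, // don't wrap single writes in a transaction
		PrepareStmt:            true, // route statements through the prepared-stmt cache
		CreateBatchSize:        100,  // Create delegates to CreateInBatches when > 0
		Logger:                 logger.Default.LogMode(logger.Info),
	})
	if err != nil {
		panic(err)
	}
	_ = db
}
```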
+ + if config.Dialector != nil { + err = config.Dialector.Initialize(db) + } + + preparedStmt := &PreparedStmtDB{ + ConnPool: db.ConnPool, + Stmts: map[string]Stmt{}, + Mux: &sync.RWMutex{}, + PreparedSQL: make([]string, 0, 100), + } + db.cacheStore.Store(preparedStmtDBKey, preparedStmt) + + if config.PrepareStmt { + db.ConnPool = preparedStmt + } + + db.Statement = &Statement{ + DB: db, + ConnPool: db.ConnPool, + Context: context.Background(), + Clauses: map[string]clause.Clause{}, + } + + if err == nil && !config.DisableAutomaticPing { + if pinger, ok := db.ConnPool.(interface{ Ping() error }); ok { + err = pinger.Ping() + } + } + + if err != nil { + config.Logger.Error(context.Background(), "failed to initialize database, got error %v", err) + } + + return +} + +// Session create new db session +func (db *DB) Session(config *Session) *DB { + var ( + txConfig = *db.Config + tx = &DB{ + Config: &txConfig, + Statement: db.Statement, + Error: db.Error, + clone: 1, + } + ) + if config.CreateBatchSize > 0 { + tx.Config.CreateBatchSize = config.CreateBatchSize + } + + if config.SkipDefaultTransaction { + tx.Config.SkipDefaultTransaction = true + } + + if config.AllowGlobalUpdate { + txConfig.AllowGlobalUpdate = true + } + + if config.FullSaveAssociations { + txConfig.FullSaveAssociations = true + } + + if config.Context != nil || config.PrepareStmt || config.SkipHooks { + tx.Statement = tx.Statement.clone() + tx.Statement.DB = tx + } + + if config.Context != nil { + tx.Statement.Context = config.Context + } + + if config.PrepareStmt { + if v, ok := db.cacheStore.Load(preparedStmtDBKey); ok { + preparedStmt := v.(*PreparedStmtDB) + tx.Statement.ConnPool = &PreparedStmtDB{ + ConnPool: db.Config.ConnPool, + Mux: preparedStmt.Mux, + Stmts: preparedStmt.Stmts, + } + txConfig.ConnPool = tx.Statement.ConnPool + txConfig.PrepareStmt = true + } + } + + if config.SkipHooks { + tx.Statement.SkipHooks = true + } + + if config.DisableNestedTransaction { + txConfig.DisableNestedTransaction = true + } + + if !config.NewDB { + tx.clone = 2 + } + + if config.DryRun { + tx.Config.DryRun = true + } + + if config.QueryFields { + tx.Config.QueryFields = true + } + + if config.Logger != nil { + tx.Config.Logger = config.Logger + } + + if config.NowFunc != nil { + tx.Config.NowFunc = config.NowFunc + } + + return tx +} + +// WithContext change current instance db's context to ctx +func (db *DB) WithContext(ctx context.Context) *DB { + return db.Session(&Session{Context: ctx}) +} + +// Debug start debug mode +func (db *DB) Debug() (tx *DB) { + return db.Session(&Session{ + Logger: db.Logger.LogMode(logger.Info), + }) +} + +// Set store value with key into current db instance's context +func (db *DB) Set(key string, value interface{}) *DB { + tx := db.getInstance() + tx.Statement.Settings.Store(key, value) + return tx +} + +// Get get value with key from current db instance's context +func (db *DB) Get(key string) (interface{}, bool) { + return db.Statement.Settings.Load(key) +} + +// InstanceSet store value with key into current db instance's context +func (db *DB) InstanceSet(key string, value interface{}) *DB { + tx := db.getInstance() + tx.Statement.Settings.Store(fmt.Sprintf("%p", tx.Statement)+key, value) + return tx +} + +// InstanceGet get value with key from current db instance's context +func (db *DB) InstanceGet(key string) (interface{}, bool) { + return db.Statement.Settings.Load(fmt.Sprintf("%p", db.Statement) + key) +} + +// Callback returns callback manager +func (db *DB) Callback() *callbacks { + 
return db.callbacks +} + +// AddError add error to db +func (db *DB) AddError(err error) error { + if db.Error == nil { + db.Error = err + } else if err != nil { + db.Error = fmt.Errorf("%v; %w", db.Error, err) + } + return db.Error +} + +// DB returns `*sql.DB` +func (db *DB) DB() (*sql.DB, error) { + connPool := db.ConnPool + + if dbConnector, ok := connPool.(GetDBConnector); ok && dbConnector != nil { + return dbConnector.GetDBConn() + } + + if sqldb, ok := connPool.(*sql.DB); ok { + return sqldb, nil + } + + return nil, ErrInvalidDB +} + +func (db *DB) getInstance() *DB { + if db.clone > 0 { + tx := &DB{Config: db.Config, Error: db.Error} + + if db.clone == 1 { + // clone with new statement + tx.Statement = &Statement{ + DB: tx, + ConnPool: db.Statement.ConnPool, + Context: db.Statement.Context, + Clauses: map[string]clause.Clause{}, + Vars: make([]interface{}, 0, 8), + } + } else { + // with clone statement + tx.Statement = db.Statement.clone() + tx.Statement.DB = tx + } + + return tx + } + + return db +} + +func Expr(expr string, args ...interface{}) clause.Expr { + return clause.Expr{SQL: expr, Vars: args} +} + +func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interface{}) error { + var ( + tx = db.getInstance() + stmt = tx.Statement + modelSchema, joinSchema *schema.Schema + ) + + if err := stmt.Parse(model); err == nil { + modelSchema = stmt.Schema + } else { + return err + } + + if err := stmt.Parse(joinTable); err == nil { + joinSchema = stmt.Schema + } else { + return err + } + + if relation, ok := modelSchema.Relationships.Relations[field]; ok && relation.JoinTable != nil { + for _, ref := range relation.References { + if f := joinSchema.LookUpField(ref.ForeignKey.DBName); f != nil { + f.DataType = ref.ForeignKey.DataType + f.GORMDataType = ref.ForeignKey.GORMDataType + if f.Size == 0 { + f.Size = ref.ForeignKey.Size + } + ref.ForeignKey = f + } else { + return fmt.Errorf("missing field %v for join table", ref.ForeignKey.DBName) + } + } + + for name, rel := range relation.JoinTable.Relationships.Relations { + if _, ok := joinSchema.Relationships.Relations[name]; !ok { + rel.Schema = joinSchema + joinSchema.Relationships.Relations[name] = rel + } + } + + relation.JoinTable = joinSchema + } else { + return fmt.Errorf("failed to found relation: %v", field) + } + + return nil +} + +func (db *DB) Use(plugin Plugin) error { + name := plugin.Name() + if _, ok := db.Plugins[name]; ok { + return ErrRegistered + } + if err := plugin.Initialize(db); err != nil { + return err + } + db.Plugins[name] = plugin + return nil +} diff --git a/interface.go b/interface.go deleted file mode 100644 index 55128f7f..00000000 --- a/interface.go +++ /dev/null @@ -1,20 +0,0 @@ -package gorm - -import "database/sql" - -// SQLCommon is the minimal database connection functionality gorm requires. Implemented by *sql.DB. 
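Note (not part of the patch): a minimal plugin sketch satisfying the `Plugin` interface (`Name`/`Initialize` in interfaces.go) and registered through `db.Use` above. The callback-registration chain (`Callback().Query().After(...).Register(...)`) is assumed from the callbacks implementation elsewhere in this change, and the tracing idea is purely illustrative.

```go
type tracePlugin struct{}

func (tracePlugin) Name() string { return "trace" }

func (tracePlugin) Initialize(db *gorm.DB) error {
	// Register an after-query callback; the exact chain is an assumption here.
	return db.Callback().Query().After("gorm:query").Register("trace:after_query", func(tx *gorm.DB) {
		tx.Logger.Info(tx.Statement.Context, "query affected %d rows", tx.RowsAffected)
	})
}

// usage: err := db.Use(tracePlugin{})
```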
-type SQLCommon interface { - Exec(query string, args ...interface{}) (sql.Result, error) - Prepare(query string) (*sql.Stmt, error) - Query(query string, args ...interface{}) (*sql.Rows, error) - QueryRow(query string, args ...interface{}) *sql.Row -} - -type sqlDb interface { - Begin() (*sql.Tx, error) -} - -type sqlTx interface { - Commit() error - Rollback() error -} diff --git a/interfaces.go b/interfaces.go new file mode 100644 index 00000000..44b2fced --- /dev/null +++ b/interfaces.go @@ -0,0 +1,63 @@ +package gorm + +import ( + "context" + "database/sql" + + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" +) + +// Dialector GORM database dialector +type Dialector interface { + Name() string + Initialize(*DB) error + Migrator(db *DB) Migrator + DataTypeOf(*schema.Field) string + DefaultValueOf(*schema.Field) clause.Expression + BindVarTo(writer clause.Writer, stmt *Statement, v interface{}) + QuoteTo(clause.Writer, string) + Explain(sql string, vars ...interface{}) string +} + +// Plugin GORM plugin interface +type Plugin interface { + Name() string + Initialize(*DB) error +} + +// ConnPool db conns pool interface +type ConnPool interface { + PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) + ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) + QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) + QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row +} + +// SavePointerDialectorInterface save pointer interface +type SavePointerDialectorInterface interface { + SavePoint(tx *DB, name string) error + RollbackTo(tx *DB, name string) error +} + +type TxBeginner interface { + BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) +} + +type ConnPoolBeginner interface { + BeginTx(ctx context.Context, opts *sql.TxOptions) (ConnPool, error) +} + +type TxCommitter interface { + Commit() error + Rollback() error +} + +// Valuer gorm valuer interface +type Valuer interface { + GormValue(context.Context, *DB) clause.Expr +} + +type GetDBConnector interface { + GetDBConn() (*sql.DB, error) +} diff --git a/join_table_handler.go b/join_table_handler.go deleted file mode 100644 index b4be6cf9..00000000 --- a/join_table_handler.go +++ /dev/null @@ -1,223 +0,0 @@ -package gorm - -import ( - "errors" - "fmt" - "reflect" - "strings" -) - -// JoinTableHandlerInterface is an interface for how to handle many2many relations -type JoinTableHandlerInterface interface { - // initialize join table handler - Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type) - // Table return join table's table name - Table(db *DB) string - // Add create relationship in join table for source and destination - Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error - // Delete delete relationship in join table for sources - Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error - // JoinWith query with `Join` conditions - JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB - // SourceForeignKeys return source foreign keys - SourceForeignKeys() []JoinTableForeignKey - // DestinationForeignKeys return destination foreign keys - DestinationForeignKeys() []JoinTableForeignKey -} - -// JoinTableForeignKey join table foreign key struct -type JoinTableForeignKey struct { - DBName string - AssociationDBName string -} - -// JoinTableSource is a struct that contains model 
type and foreign keys -type JoinTableSource struct { - ModelType reflect.Type - ForeignKeys []JoinTableForeignKey -} - -// JoinTableHandler default join table handler -type JoinTableHandler struct { - TableName string `sql:"-"` - Source JoinTableSource `sql:"-"` - Destination JoinTableSource `sql:"-"` -} - -// SourceForeignKeys return source foreign keys -func (s *JoinTableHandler) SourceForeignKeys() []JoinTableForeignKey { - return s.Source.ForeignKeys -} - -// DestinationForeignKeys return destination foreign keys -func (s *JoinTableHandler) DestinationForeignKeys() []JoinTableForeignKey { - return s.Destination.ForeignKeys -} - -// Setup initialize a default join table handler -func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type) { - s.TableName = tableName - - s.Source = JoinTableSource{ModelType: source} - s.Source.ForeignKeys = []JoinTableForeignKey{} - for idx, dbName := range relationship.ForeignFieldNames { - s.Source.ForeignKeys = append(s.Source.ForeignKeys, JoinTableForeignKey{ - DBName: relationship.ForeignDBNames[idx], - AssociationDBName: dbName, - }) - } - - s.Destination = JoinTableSource{ModelType: destination} - s.Destination.ForeignKeys = []JoinTableForeignKey{} - for idx, dbName := range relationship.AssociationForeignFieldNames { - s.Destination.ForeignKeys = append(s.Destination.ForeignKeys, JoinTableForeignKey{ - DBName: relationship.AssociationForeignDBNames[idx], - AssociationDBName: dbName, - }) - } -} - -// Table return join table's table name -func (s JoinTableHandler) Table(db *DB) string { - return s.TableName -} - -func (s JoinTableHandler) getSearchMap(db *DB, sources ...interface{}) map[string]interface{} { - values := map[string]interface{}{} - - for _, source := range sources { - scope := db.NewScope(source) - modelType := scope.GetModelStruct().ModelType - - if s.Source.ModelType == modelType { - for _, foreignKey := range s.Source.ForeignKeys { - if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok { - values[foreignKey.DBName] = field.Field.Interface() - } - } - } else if s.Destination.ModelType == modelType { - for _, foreignKey := range s.Destination.ForeignKeys { - if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok { - values[foreignKey.DBName] = field.Field.Interface() - } - } - } - } - return values -} - -// Add create relationship in join table for source and destination -func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error { - scope := db.NewScope("") - searchMap := map[string]interface{}{} - - // getSearchMap() cannot be used here since the source and destination - // model types may be identical - - sourceScope := db.NewScope(source) - for _, foreignKey := range s.Source.ForeignKeys { - if field, ok := sourceScope.FieldByName(foreignKey.AssociationDBName); ok { - searchMap[foreignKey.DBName] = field.Field.Interface() - } - } - - destinationScope := db.NewScope(destination) - for _, foreignKey := range s.Destination.ForeignKeys { - if field, ok := destinationScope.FieldByName(foreignKey.AssociationDBName); ok { - searchMap[foreignKey.DBName] = field.Field.Interface() - } - } - - var assignColumns, binVars, conditions []string - var values []interface{} - for key, value := range searchMap { - assignColumns = append(assignColumns, scope.Quote(key)) - binVars = append(binVars, `?`) - conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key))) - values = 
append(values, value) - } - - for _, value := range values { - values = append(values, value) - } - - quotedTable := scope.Quote(handler.Table(db)) - sql := fmt.Sprintf( - "INSERT INTO %v (%v) SELECT %v %v WHERE NOT EXISTS (SELECT * FROM %v WHERE %v)", - quotedTable, - strings.Join(assignColumns, ","), - strings.Join(binVars, ","), - scope.Dialect().SelectFromDummyTable(), - quotedTable, - strings.Join(conditions, " AND "), - ) - - return db.Exec(sql, values...).Error -} - -// Delete delete relationship in join table for sources -func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error { - var ( - scope = db.NewScope(nil) - conditions []string - values []interface{} - ) - - for key, value := range s.getSearchMap(db, sources...) { - conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key))) - values = append(values, value) - } - - return db.Table(handler.Table(db)).Where(strings.Join(conditions, " AND "), values...).Delete("").Error -} - -// JoinWith query with `Join` conditions -func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB { - var ( - scope = db.NewScope(source) - tableName = handler.Table(db) - quotedTableName = scope.Quote(tableName) - joinConditions []string - values []interface{} - ) - - if s.Source.ModelType == scope.GetModelStruct().ModelType { - destinationTableName := db.NewScope(reflect.New(s.Destination.ModelType).Interface()).QuotedTableName() - for _, foreignKey := range s.Destination.ForeignKeys { - joinConditions = append(joinConditions, fmt.Sprintf("%v.%v = %v.%v", quotedTableName, scope.Quote(foreignKey.DBName), destinationTableName, scope.Quote(foreignKey.AssociationDBName))) - } - - var foreignDBNames []string - var foreignFieldNames []string - - for _, foreignKey := range s.Source.ForeignKeys { - foreignDBNames = append(foreignDBNames, foreignKey.DBName) - if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok { - foreignFieldNames = append(foreignFieldNames, field.Name) - } - } - - foreignFieldValues := scope.getColumnAsArray(foreignFieldNames, scope.Value) - - var condString string - if len(foreignFieldValues) > 0 { - var quotedForeignDBNames []string - for _, dbName := range foreignDBNames { - quotedForeignDBNames = append(quotedForeignDBNames, tableName+"."+dbName) - } - - condString = fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, quotedForeignDBNames), toQueryMarks(foreignFieldValues)) - - keys := scope.getColumnAsArray(foreignFieldNames, scope.Value) - values = append(values, toQueryValues(keys)) - } else { - condString = fmt.Sprintf("1 <> 1") - } - - return db.Joins(fmt.Sprintf("INNER JOIN %v ON %v", quotedTableName, strings.Join(joinConditions, " AND "))). - Where(condString, toQueryValues(foreignFieldValues)...) 
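Editor's note: the doubled append over `values` in the removed `Add` above looks like a bug at first glance; it exists because the same bind values feed both the SELECT list and the NOT EXISTS guard of the generated statement. A minimal standalone sketch of that SQL shape is below — the table and column names are hypothetical, and identifier quoting plus the dialect's dummy-table fragment (`SelectFromDummyTable()`) are omitted for brevity:

```go
package main

import (
	"fmt"
	"strings"
)

// buildJoinInsert mirrors the shape of the SQL produced by the removed
// JoinTableHandler.Add: an idempotent insert guarded by NOT EXISTS.
func buildJoinInsert(table string, columns []string) (sql string, bindCount int) {
	binds := strings.TrimSuffix(strings.Repeat("?,", len(columns)), ",")
	var conditions []string
	for _, c := range columns {
		conditions = append(conditions, fmt.Sprintf("%v = ?", c))
	}
	sql = fmt.Sprintf(
		"INSERT INTO %v (%v) SELECT %v WHERE NOT EXISTS (SELECT * FROM %v WHERE %v)",
		table, strings.Join(columns, ","), binds, table, strings.Join(conditions, " AND "),
	)
	// The placeholders appear twice (SELECT list + NOT EXISTS guard), which is
	// why the removed code appends every search value to the args slice a second time.
	return sql, 2 * len(columns)
}

func main() {
	sql, n := buildJoinInsert("person_addresses", []string{"person_id", "address_id"})
	fmt.Println(sql, "-- expects", n, "bind values")
}
```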
- } - - db.Error = errors.New("wrong source type for join table handler") - return db -} diff --git a/join_table_test.go b/join_table_test.go deleted file mode 100644 index 6d5f427d..00000000 --- a/join_table_test.go +++ /dev/null @@ -1,117 +0,0 @@ -package gorm_test - -import ( - "fmt" - "strconv" - "testing" - "time" - - "github.com/jinzhu/gorm" -) - -type Person struct { - Id int - Name string - Addresses []*Address `gorm:"many2many:person_addresses;"` -} - -type PersonAddress struct { - gorm.JoinTableHandler - PersonID int - AddressID int - DeletedAt *time.Time - CreatedAt time.Time -} - -func (*PersonAddress) Add(handler gorm.JoinTableHandlerInterface, db *gorm.DB, foreignValue interface{}, associationValue interface{}) error { - foreignPrimaryKey, _ := strconv.Atoi(fmt.Sprint(db.NewScope(foreignValue).PrimaryKeyValue())) - associationPrimaryKey, _ := strconv.Atoi(fmt.Sprint(db.NewScope(associationValue).PrimaryKeyValue())) - if result := db.Unscoped().Model(&PersonAddress{}).Where(map[string]interface{}{ - "person_id": foreignPrimaryKey, - "address_id": associationPrimaryKey, - }).Update(map[string]interface{}{ - "person_id": foreignPrimaryKey, - "address_id": associationPrimaryKey, - "deleted_at": gorm.Expr("NULL"), - }).RowsAffected; result == 0 { - return db.Create(&PersonAddress{ - PersonID: foreignPrimaryKey, - AddressID: associationPrimaryKey, - }).Error - } - - return nil -} - -func (*PersonAddress) Delete(handler gorm.JoinTableHandlerInterface, db *gorm.DB, sources ...interface{}) error { - return db.Delete(&PersonAddress{}).Error -} - -func (pa *PersonAddress) JoinWith(handler gorm.JoinTableHandlerInterface, db *gorm.DB, source interface{}) *gorm.DB { - table := pa.Table(db) - return db.Joins("INNER JOIN person_addresses ON person_addresses.address_id = addresses.id").Where(fmt.Sprintf("%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02'", table, table)) -} - -func TestJoinTable(t *testing.T) { - DB.Exec("drop table person_addresses;") - DB.AutoMigrate(&Person{}) - DB.SetJoinTableHandler(&Person{}, "Addresses", &PersonAddress{}) - - address1 := &Address{Address1: "address 1"} - address2 := &Address{Address1: "address 2"} - person := &Person{Name: "person", Addresses: []*Address{address1, address2}} - DB.Save(person) - - DB.Model(person).Association("Addresses").Delete(address1) - - if DB.Find(&[]PersonAddress{}, "person_id = ?", person.Id).RowsAffected != 1 { - t.Errorf("Should found one address") - } - - if DB.Model(person).Association("Addresses").Count() != 1 { - t.Errorf("Should found one address") - } - - if DB.Unscoped().Find(&[]PersonAddress{}, "person_id = ?", person.Id).RowsAffected != 2 { - t.Errorf("Found two addresses with Unscoped") - } - - if DB.Model(person).Association("Addresses").Clear(); DB.Model(person).Association("Addresses").Count() != 0 { - t.Errorf("Should deleted all addresses") - } -} - -func TestEmbeddedMany2ManyRelationship(t *testing.T) { - type EmbeddedPerson struct { - ID int - Name string - Addresses []*Address `gorm:"many2many:person_addresses;"` - } - - type NewPerson struct { - EmbeddedPerson - ExternalID uint - } - DB.Exec("drop table person_addresses;") - DB.AutoMigrate(&NewPerson{}) - - address1 := &Address{Address1: "address 1"} - address2 := &Address{Address1: "address 2"} - person := &NewPerson{ExternalID: 100, EmbeddedPerson: EmbeddedPerson{Name: "person", Addresses: []*Address{address1, address2}}} - if err := DB.Save(person).Error; err != nil { - t.Errorf("no error should return when save embedded many2many relationship, but 
got %v", err) - } - - if err := DB.Model(person).Association("Addresses").Delete(address1).Error; err != nil { - t.Errorf("no error should return when delete embedded many2many relationship, but got %v", err) - } - - association := DB.Model(person).Association("Addresses") - if count := association.Count(); count != 1 || association.Error != nil { - t.Errorf("Should found one address, but got %v, error is %v", count, association.Error) - } - - if association.Clear(); association.Count() != 0 { - t.Errorf("Should deleted all addresses") - } -} diff --git a/logger.go b/logger.go deleted file mode 100644 index 4324a2e4..00000000 --- a/logger.go +++ /dev/null @@ -1,119 +0,0 @@ -package gorm - -import ( - "database/sql/driver" - "fmt" - "log" - "os" - "reflect" - "regexp" - "strconv" - "time" - "unicode" -) - -var ( - defaultLogger = Logger{log.New(os.Stdout, "\r\n", 0)} - sqlRegexp = regexp.MustCompile(`\?`) - numericPlaceHolderRegexp = regexp.MustCompile(`\$\d+`) -) - -func isPrintable(s string) bool { - for _, r := range s { - if !unicode.IsPrint(r) { - return false - } - } - return true -} - -var LogFormatter = func(values ...interface{}) (messages []interface{}) { - if len(values) > 1 { - var ( - sql string - formattedValues []string - level = values[0] - currentTime = "\n\033[33m[" + NowFunc().Format("2006-01-02 15:04:05") + "]\033[0m" - source = fmt.Sprintf("\033[35m(%v)\033[0m", values[1]) - ) - - messages = []interface{}{source, currentTime} - - if level == "sql" { - // duration - messages = append(messages, fmt.Sprintf(" \033[36;1m[%.2fms]\033[0m ", float64(values[2].(time.Duration).Nanoseconds()/1e4)/100.0)) - // sql - - for _, value := range values[4].([]interface{}) { - indirectValue := reflect.Indirect(reflect.ValueOf(value)) - if indirectValue.IsValid() { - value = indirectValue.Interface() - if t, ok := value.(time.Time); ok { - formattedValues = append(formattedValues, fmt.Sprintf("'%v'", t.Format("2006-01-02 15:04:05"))) - } else if b, ok := value.([]byte); ok { - if str := string(b); isPrintable(str) { - formattedValues = append(formattedValues, fmt.Sprintf("'%v'", str)) - } else { - formattedValues = append(formattedValues, "''") - } - } else if r, ok := value.(driver.Valuer); ok { - if value, err := r.Value(); err == nil && value != nil { - formattedValues = append(formattedValues, fmt.Sprintf("'%v'", value)) - } else { - formattedValues = append(formattedValues, "NULL") - } - } else { - formattedValues = append(formattedValues, fmt.Sprintf("'%v'", value)) - } - } else { - formattedValues = append(formattedValues, "NULL") - } - } - - // differentiate between $n placeholders or else treat like ? - if numericPlaceHolderRegexp.MatchString(values[3].(string)) { - sql = values[3].(string) - for index, value := range formattedValues { - placeholder := fmt.Sprintf(`\$%d([^\d]|$)`, index+1) - sql = regexp.MustCompile(placeholder).ReplaceAllString(sql, value+"$1") - } - } else { - formattedValuesLength := len(formattedValues) - for index, value := range sqlRegexp.Split(values[3].(string), -1) { - sql += value - if index < formattedValuesLength { - sql += formattedValues[index] - } - } - } - - messages = append(messages, sql) - messages = append(messages, fmt.Sprintf(" \n\033[36;31m[%v]\033[0m ", strconv.FormatInt(values[5].(int64), 10)+" rows affected or returned ")) - } else { - messages = append(messages, "\033[31;1m") - messages = append(messages, values[2:]...) 
- messages = append(messages, "\033[0m") - } - } - - return -} - -type logger interface { - Print(v ...interface{}) -} - -// LogWriter log writer interface -type LogWriter interface { - Println(v ...interface{}) -} - -// Logger default logger -type Logger struct { - LogWriter -} - -// Print format & print log -func (logger Logger) Print(values ...interface{}) { - logger.Println(LogFormatter(values...)...) -} diff --git a/logger/logger.go b/logger/logger.go new file mode 100644 index 00000000..f14748c1 --- /dev/null +++ b/logger/logger.go @@ -0,0 +1,187 @@ +package logger + +import ( + "context" + "errors" + "fmt" + "io/ioutil" + "log" + "os" + "time" + + "gorm.io/gorm/utils" +) + +var ErrRecordNotFound = errors.New("record not found") + +// Colors +const ( + Reset = "\033[0m" + Red = "\033[31m" + Green = "\033[32m" + Yellow = "\033[33m" + Blue = "\033[34m" + Magenta = "\033[35m" + Cyan = "\033[36m" + White = "\033[37m" + BlueBold = "\033[34;1m" + MagentaBold = "\033[35;1m" + RedBold = "\033[31;1m" + YellowBold = "\033[33;1m" +) + +// LogLevel +type LogLevel int + +const ( + Silent LogLevel = iota + 1 + Error + Warn + Info +) + +// Writer log writer interface +type Writer interface { + Printf(string, ...interface{}) +} + +type Config struct { + SlowThreshold time.Duration + Colorful bool + IgnoreRecordNotFoundError bool + LogLevel LogLevel +} + +// Interface logger interface +type Interface interface { + LogMode(LogLevel) Interface + Info(context.Context, string, ...interface{}) + Warn(context.Context, string, ...interface{}) + Error(context.Context, string, ...interface{}) + Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) +} + +var ( + Discard = New(log.New(ioutil.Discard, "", log.LstdFlags), Config{}) + Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{ + SlowThreshold: 200 * time.Millisecond, + LogLevel: Warn, + Colorful: true, + }) + Recorder = traceRecorder{Interface: Default, BeginAt: time.Now()} +) + +func New(writer Writer, config Config) Interface { + var ( + infoStr = "%s\n[info] " + warnStr = "%s\n[warn] " + errStr = "%s\n[error] " + traceStr = "%s\n[%.3fms] [rows:%v] %s" + traceWarnStr = "%s %s\n[%.3fms] [rows:%v] %s" + traceErrStr = "%s %s\n[%.3fms] [rows:%v] %s" + ) + + if config.Colorful { + infoStr = Green + "%s\n" + Reset + Green + "[info] " + Reset + warnStr = BlueBold + "%s\n" + Reset + Magenta + "[warn] " + Reset + errStr = Magenta + "%s\n" + Reset + Red + "[error] " + Reset + traceStr = Green + "%s\n" + Reset + Yellow + "[%.3fms] " + BlueBold + "[rows:%v]" + Reset + " %s" + traceWarnStr = Green + "%s " + Yellow + "%s\n" + Reset + RedBold + "[%.3fms] " + Yellow + "[rows:%v]" + Magenta + " %s" + Reset + traceErrStr = RedBold + "%s " + MagentaBold + "%s\n" + Reset + Yellow + "[%.3fms] " + BlueBold + "[rows:%v]" + Reset + " %s" + } + + return &logger{ + Writer: writer, + Config: config, + infoStr: infoStr, + warnStr: warnStr, + errStr: errStr, + traceStr: traceStr, + traceWarnStr: traceWarnStr, + traceErrStr: traceErrStr, + } +} + +type logger struct { + Writer + Config + infoStr, warnStr, errStr string + traceStr, traceErrStr, traceWarnStr string +} + +// LogMode log mode +func (l *logger) LogMode(level LogLevel) Interface { + newlogger := *l + newlogger.LogLevel = level + return &newlogger +} + +// Info print info +func (l logger) Info(ctx context.Context, msg string, data ...interface{}) { + if l.LogLevel >= Info { + l.Printf(l.infoStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...) 
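Editor's note: the `logger` package introduced in this file can be constructed and driven directly. The sketch below uses arbitrary example values for the writer, threshold, and SQL, and only calls methods declared in the `Interface` above:

```go
package main

import (
	"context"
	"log"
	"os"
	"time"

	"gorm.io/gorm/logger"
)

func main() {
	// Build a logger from the newly added constructor; thresholds and levels
	// here are arbitrary example choices.
	l := logger.New(log.New(os.Stdout, "\r\n", log.LstdFlags), logger.Config{
		SlowThreshold: 100 * time.Millisecond,
		LogLevel:      logger.Info,
		Colorful:      false,
	})

	ctx := context.Background()
	l.Info(ctx, "connected to %s", "example-db")

	// Trace receives a callback returning the SQL and the affected row count.
	begin := time.Now()
	l.Trace(ctx, begin, func() (string, int64) {
		return "SELECT * FROM users WHERE id = 1", 1
	}, nil)
}
```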
+ } +} + +// Warn print warn messages +func (l logger) Warn(ctx context.Context, msg string, data ...interface{}) { + if l.LogLevel >= Warn { + l.Printf(l.warnStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...) + } +} + +// Error print error messages +func (l logger) Error(ctx context.Context, msg string, data ...interface{}) { + if l.LogLevel >= Error { + l.Printf(l.errStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...) + } +} + +// Trace print sql message +func (l logger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) { + if l.LogLevel > Silent { + elapsed := time.Since(begin) + switch { + case err != nil && l.LogLevel >= Error && (!errors.Is(err, ErrRecordNotFound) || !l.IgnoreRecordNotFoundError): + sql, rows := fc() + if rows == -1 { + l.Printf(l.traceErrStr, utils.FileWithLineNum(), err, float64(elapsed.Nanoseconds())/1e6, "-", sql) + } else { + l.Printf(l.traceErrStr, utils.FileWithLineNum(), err, float64(elapsed.Nanoseconds())/1e6, rows, sql) + } + case elapsed > l.SlowThreshold && l.SlowThreshold != 0 && l.LogLevel >= Warn: + sql, rows := fc() + slowLog := fmt.Sprintf("SLOW SQL >= %v", l.SlowThreshold) + if rows == -1 { + l.Printf(l.traceWarnStr, utils.FileWithLineNum(), slowLog, float64(elapsed.Nanoseconds())/1e6, "-", sql) + } else { + l.Printf(l.traceWarnStr, utils.FileWithLineNum(), slowLog, float64(elapsed.Nanoseconds())/1e6, rows, sql) + } + case l.LogLevel == Info: + sql, rows := fc() + if rows == -1 { + l.Printf(l.traceStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, "-", sql) + } else { + l.Printf(l.traceStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, rows, sql) + } + } + } +} + +type traceRecorder struct { + Interface + BeginAt time.Time + SQL string + RowsAffected int64 + Err error +} + +func (l traceRecorder) New() *traceRecorder { + return &traceRecorder{Interface: l.Interface, BeginAt: time.Now()} +} + +func (l *traceRecorder) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) { + l.BeginAt = begin + l.SQL, l.RowsAffected = fc() + l.Err = err +} diff --git a/logger/sql.go b/logger/sql.go new file mode 100644 index 00000000..3d31d23c --- /dev/null +++ b/logger/sql.go @@ -0,0 +1,133 @@ +package logger + +import ( + "database/sql/driver" + "fmt" + "reflect" + "regexp" + "strconv" + "strings" + "time" + "unicode" + + "gorm.io/gorm/utils" +) + +const ( + tmFmtWithMS = "2006-01-02 15:04:05.999" + tmFmtZero = "0000-00-00 00:00:00" + nullStr = "NULL" +) + +func isPrintable(s []byte) bool { + for _, r := range s { + if !unicode.IsPrint(rune(r)) { + return false + } + } + return true +} + +var convertibleTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeOf(false), reflect.TypeOf([]byte{})} + +func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, avars ...interface{}) string { + var convertParams func(interface{}, int) + var vars = make([]string, len(avars)) + + convertParams = func(v interface{}, idx int) { + switch v := v.(type) { + case bool: + vars[idx] = strconv.FormatBool(v) + case time.Time: + if v.IsZero() { + vars[idx] = escaper + tmFmtZero + escaper + } else { + vars[idx] = escaper + v.Format(tmFmtWithMS) + escaper + } + case *time.Time: + if v != nil { + if v.IsZero() { + vars[idx] = escaper + tmFmtZero + escaper + } else { + vars[idx] = escaper + v.Format(tmFmtWithMS) + escaper + } + } else { + vars[idx] = nullStr + } + case driver.Valuer: + reflectValue := reflect.ValueOf(v) + if v != nil && 
reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) { + r, _ := v.Value() + convertParams(r, idx) + } else { + vars[idx] = nullStr + } + case fmt.Stringer: + reflectValue := reflect.ValueOf(v) + if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) { + vars[idx] = escaper + strings.Replace(fmt.Sprintf("%v", v), escaper, "\\"+escaper, -1) + escaper + } else { + vars[idx] = nullStr + } + case []byte: + if isPrintable(v) { + vars[idx] = escaper + strings.Replace(string(v), escaper, "\\"+escaper, -1) + escaper + } else { + vars[idx] = escaper + "" + escaper + } + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: + vars[idx] = utils.ToString(v) + case float64, float32: + vars[idx] = fmt.Sprintf("%.6f", v) + case string: + vars[idx] = escaper + strings.Replace(v, escaper, "\\"+escaper, -1) + escaper + default: + rv := reflect.ValueOf(v) + if v == nil || !rv.IsValid() || rv.Kind() == reflect.Ptr && rv.IsNil() { + vars[idx] = nullStr + } else if valuer, ok := v.(driver.Valuer); ok { + v, _ = valuer.Value() + convertParams(v, idx) + } else if rv.Kind() == reflect.Ptr && !rv.IsZero() { + convertParams(reflect.Indirect(rv).Interface(), idx) + } else { + for _, t := range convertibleTypes { + if rv.Type().ConvertibleTo(t) { + convertParams(rv.Convert(t).Interface(), idx) + return + } + } + vars[idx] = escaper + strings.Replace(fmt.Sprint(v), escaper, "\\"+escaper, -1) + escaper + } + } + } + + for idx, v := range avars { + convertParams(v, idx) + } + + if numericPlaceholder == nil { + var idx int + var newSQL strings.Builder + + for _, v := range []byte(sql) { + if v == '?' { + if len(vars) > idx { + newSQL.WriteString(vars[idx]) + idx++ + continue + } + } + newSQL.WriteByte(v) + } + + sql = newSQL.String() + } else { + sql = numericPlaceholder.ReplaceAllString(sql, "$$$1$$") + for idx, v := range vars { + sql = strings.Replace(sql, "$"+strconv.Itoa(idx+1)+"$", v, 1) + } + } + + return sql +} diff --git a/logger/sql_test.go b/logger/sql_test.go new file mode 100644 index 00000000..71aa841a --- /dev/null +++ b/logger/sql_test.go @@ -0,0 +1,105 @@ +package logger_test + +import ( + "database/sql/driver" + "encoding/json" + "fmt" + "regexp" + "strings" + "testing" + + "github.com/jinzhu/now" + "gorm.io/gorm/logger" +) + +type JSON json.RawMessage + +func (j JSON) Value() (driver.Value, error) { + if len(j) == 0 { + return nil, nil + } + return json.RawMessage(j).MarshalJSON() +} + +type ExampleStruct struct { + Name string + Val string +} + +func (s ExampleStruct) Value() (driver.Value, error) { + return json.Marshal(s) +} + +func format(v []byte, escaper string) string { + return escaper + strings.Replace(string(v), escaper, "\\"+escaper, -1) + escaper +} + +func TestExplainSQL(t *testing.T) { + type role string + type password []byte + var ( + tt = now.MustParse("2020-02-23 11:10:10") + myrole = role("admin") + pwd = password([]byte("pass")) + jsVal = []byte(`{"Name":"test","Val":"test"}`) + js = JSON(jsVal) + esVal = []byte(`{"Name":"test","Val":"test"}`) + es = ExampleStruct{Name: "test", Val: "test"} + ) + + results := []struct { + SQL string + NumericRegexp *regexp.Regexp + Vars []interface{} + Result string + }{ + { + SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + NumericRegexp: nil, + Vars: 
[]interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd}, + Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass")`, + }, + { + SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + NumericRegexp: nil, + Vars: []interface{}{"jinzhu?", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd}, + Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu?", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass")`, + }, + { + SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (@p1, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9, @p10, @p11)", + NumericRegexp: regexp.MustCompile(`@p(\d+)`), + Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.com", myrole, pwd}, + Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`, + }, + { + SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ($3, $4, $1, $2, $7, $8, $5, $6, $9, $10, $11)", + NumericRegexp: regexp.MustCompile(`\$(\d+)`), + Vars: []interface{}{999.99, true, "jinzhu", 1, &tt, nil, []byte("12345"), tt, "w@g.com", myrole, pwd}, + Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`, + }, + { + SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (@p1, @p11, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9, @p10)", + NumericRegexp: regexp.MustCompile(`@p(\d+)`), + Vars: []interface{}{"jinzhu", 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.com", myrole, pwd, 1}, + Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`, + }, + { + SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + NumericRegexp: nil, + Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, js, es}, + Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)), + }, + { + SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", 
+ NumericRegexp: nil, + Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es}, + Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)), + }, + } + + for idx, r := range results { + if result := logger.ExplainSQL(r.SQL, r.NumericRegexp, `"`, r.Vars...); result != r.Result { + t.Errorf("Explain SQL #%v expects %v, but got %v", idx, r.Result, result) + } + } +} diff --git a/main.go b/main.go deleted file mode 100644 index 16fa0b79..00000000 --- a/main.go +++ /dev/null @@ -1,753 +0,0 @@ -package gorm - -import ( - "database/sql" - "errors" - "fmt" - "reflect" - "strings" - "time" -) - -// DB contains information for current db connection -type DB struct { - Value interface{} - Error error - RowsAffected int64 - - // single db - db SQLCommon - blockGlobalUpdate bool - logMode int - logger logger - search *search - values map[string]interface{} - - // global db - parent *DB - callbacks *Callback - dialect Dialect - singularTable bool -} - -// Open initialize a new db connection, need to import driver first, e.g: -// -// import _ "github.com/go-sql-driver/mysql" -// func main() { -// db, err := gorm.Open("mysql", "user:password@/dbname?charset=utf8&parseTime=True&loc=Local") -// } -// GORM has wrapped some drivers, for easier to remember driver's import path, so you could import the mysql driver with -// import _ "github.com/jinzhu/gorm/dialects/mysql" -// // import _ "github.com/jinzhu/gorm/dialects/postgres" -// // import _ "github.com/jinzhu/gorm/dialects/sqlite" -// // import _ "github.com/jinzhu/gorm/dialects/mssql" -func Open(dialect string, args ...interface{}) (db *DB, err error) { - if len(args) == 0 { - err = errors.New("invalid database source") - return nil, err - } - var source string - var dbSQL SQLCommon - - switch value := args[0].(type) { - case string: - var driver = dialect - if len(args) == 1 { - source = value - } else if len(args) >= 2 { - driver = value - source = args[1].(string) - } - dbSQL, err = sql.Open(driver, source) - case SQLCommon: - dbSQL = value - } - - db = &DB{ - db: dbSQL, - logger: defaultLogger, - values: map[string]interface{}{}, - callbacks: DefaultCallback, - dialect: newDialect(dialect, dbSQL), - } - db.parent = db - if err != nil { - return - } - // Send a ping to make sure the database connection is alive. - if d, ok := dbSQL.(*sql.DB); ok { - if err = d.Ping(); err != nil { - d.Close() - } - } - return -} - -// New clone a new db connection without search conditions -func (s *DB) New() *DB { - clone := s.clone() - clone.search = nil - clone.Value = nil - return clone -} - -type closer interface { - Close() error -} - -// Close close current db connection. If database connection is not an io.Closer, returns an error. -func (s *DB) Close() error { - if db, ok := s.parent.db.(closer); ok { - return db.Close() - } - return errors.New("can't close current db") -} - -// DB get `*sql.DB` from current connection -// If the underlying database connection is not a *sql.DB, returns nil -func (s *DB) DB() *sql.DB { - db, _ := s.db.(*sql.DB) - return db -} - -// CommonDB return the underlying `*sql.DB` or `*sql.Tx` instance, mainly intended to allow coexistence with legacy non-GORM code. 
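Editor's note: the table-driven test above pins down the behaviour of `logger.ExplainSQL`; called in isolation it looks like the sketch below. The queries and values are arbitrary, chosen only to show both `?` and numbered placeholders:

```go
package main

import (
	"fmt"
	"regexp"
	"time"

	"gorm.io/gorm/logger"
)

func main() {
	// '?' placeholders: no numeric-placeholder regexp, '"' as the escaper.
	fmt.Println(logger.ExplainSQL(
		"INSERT INTO users (name, age, created_at) VALUES (?, ?, ?)",
		nil, `"`, "jinzhu", 18, time.Date(2020, 2, 23, 11, 10, 10, 0, time.UTC),
	))
	// INSERT INTO users (name, age, created_at) VALUES ("jinzhu", 18, "2020-02-23 11:10:10")

	// Numbered placeholders such as $1, $2 are resolved via the regexp argument.
	fmt.Println(logger.ExplainSQL(
		"SELECT * FROM users WHERE id = $1 AND name = $2",
		regexp.MustCompile(`\$(\d+)`), `"`, 1, "jinzhu",
	))
	// SELECT * FROM users WHERE id = 1 AND name = "jinzhu"
}
```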
-func (s *DB) CommonDB() SQLCommon { - return s.db -} - -// Dialect get dialect -func (s *DB) Dialect() Dialect { - return s.parent.dialect -} - -// Callback return `Callbacks` container, you could add/change/delete callbacks with it -// db.Callback().Create().Register("update_created_at", updateCreated) -// Refer https://jinzhu.github.io/gorm/development.html#callbacks -func (s *DB) Callback() *Callback { - s.parent.callbacks = s.parent.callbacks.clone() - return s.parent.callbacks -} - -// SetLogger replace default logger -func (s *DB) SetLogger(log logger) { - s.logger = log -} - -// LogMode set log mode, `true` for detailed logs, `false` for no log, default, will only print error logs -func (s *DB) LogMode(enable bool) *DB { - if enable { - s.logMode = 2 - } else { - s.logMode = 1 - } - return s -} - -// BlockGlobalUpdate if true, generates an error on update/delete without where clause. -// This is to prevent eventual error with empty objects updates/deletions -func (s *DB) BlockGlobalUpdate(enable bool) *DB { - s.blockGlobalUpdate = enable - return s -} - -// HasBlockGlobalUpdate return state of block -func (s *DB) HasBlockGlobalUpdate() bool { - return s.blockGlobalUpdate -} - -// SingularTable use singular table by default -func (s *DB) SingularTable(enable bool) { - modelStructsMap = newModelStructsMap() - s.parent.singularTable = enable -} - -// NewScope create a scope for current operation -func (s *DB) NewScope(value interface{}) *Scope { - dbClone := s.clone() - dbClone.Value = value - return &Scope{db: dbClone, Search: dbClone.search.clone(), Value: value} -} - -// QueryExpr returns the query as expr object -func (s *DB) QueryExpr() *expr { - scope := s.NewScope(s.Value) - scope.InstanceSet("skip_bindvar", true) - scope.prepareQuerySQL() - - return Expr(scope.SQL, scope.SQLVars...) -} - -// Where return a new relation, filter records with given conditions, accepts `map`, `struct` or `string` as conditions, refer http://jinzhu.github.io/gorm/crud.html#query -func (s *DB) Where(query interface{}, args ...interface{}) *DB { - return s.clone().search.Where(query, args...).db -} - -// Or filter records that match before conditions or this one, similar to `Where` -func (s *DB) Or(query interface{}, args ...interface{}) *DB { - return s.clone().search.Or(query, args...).db -} - -// Not filter records that don't match current conditions, similar to `Where` -func (s *DB) Not(query interface{}, args ...interface{}) *DB { - return s.clone().search.Not(query, args...).db -} - -// Limit specify the number of records to be retrieved -func (s *DB) Limit(limit interface{}) *DB { - return s.clone().search.Limit(limit).db -} - -// Offset specify the number of records to skip before starting to return the records -func (s *DB) Offset(offset interface{}) *DB { - return s.clone().search.Offset(offset).db -} - -// Order specify order when retrieve records from database, set reorder to `true` to overwrite defined conditions -// db.Order("name DESC") -// db.Order("name DESC", true) // reorder -// db.Order(gorm.Expr("name = ? 
DESC", "first")) // sql expression -func (s *DB) Order(value interface{}, reorder ...bool) *DB { - return s.clone().search.Order(value, reorder...).db -} - -// Select specify fields that you want to retrieve from database when querying, by default, will select all fields; -// When creating/updating, specify fields that you want to save to database -func (s *DB) Select(query interface{}, args ...interface{}) *DB { - return s.clone().search.Select(query, args...).db -} - -// Omit specify fields that you want to ignore when saving to database for creating, updating -func (s *DB) Omit(columns ...string) *DB { - return s.clone().search.Omit(columns...).db -} - -// Group specify the group method on the find -func (s *DB) Group(query string) *DB { - return s.clone().search.Group(query).db -} - -// Having specify HAVING conditions for GROUP BY -func (s *DB) Having(query interface{}, values ...interface{}) *DB { - return s.clone().search.Having(query, values...).db -} - -// Joins specify Joins conditions -// db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user) -func (s *DB) Joins(query string, args ...interface{}) *DB { - return s.clone().search.Joins(query, args...).db -} - -// Scopes pass current database connection to arguments `func(*DB) *DB`, which could be used to add conditions dynamically -// func AmountGreaterThan1000(db *gorm.DB) *gorm.DB { -// return db.Where("amount > ?", 1000) -// } -// -// func OrderStatus(status []string) func (db *gorm.DB) *gorm.DB { -// return func (db *gorm.DB) *gorm.DB { -// return db.Scopes(AmountGreaterThan1000).Where("status in (?)", status) -// } -// } -// -// db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders) -// Refer https://jinzhu.github.io/gorm/crud.html#scopes -func (s *DB) Scopes(funcs ...func(*DB) *DB) *DB { - for _, f := range funcs { - s = f(s) - } - return s -} - -// Unscoped return all record including deleted record, refer Soft Delete https://jinzhu.github.io/gorm/crud.html#soft-delete -func (s *DB) Unscoped() *DB { - return s.clone().search.unscoped().db -} - -// Attrs initialize struct with argument if record not found with `FirstOrInit` https://jinzhu.github.io/gorm/crud.html#firstorinit or `FirstOrCreate` https://jinzhu.github.io/gorm/crud.html#firstorcreate -func (s *DB) Attrs(attrs ...interface{}) *DB { - return s.clone().search.Attrs(attrs...).db -} - -// Assign assign result with argument regardless it is found or not with `FirstOrInit` https://jinzhu.github.io/gorm/crud.html#firstorinit or `FirstOrCreate` https://jinzhu.github.io/gorm/crud.html#firstorcreate -func (s *DB) Assign(attrs ...interface{}) *DB { - return s.clone().search.Assign(attrs...).db -} - -// First find first record that match given conditions, order by primary key -func (s *DB) First(out interface{}, where ...interface{}) *DB { - newScope := s.clone().NewScope(out) - newScope.Search.Limit(1) - return newScope.Set("gorm:order_by_primary_key", "ASC"). - inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db -} - -// Last find last record that match given conditions, order by primary key -func (s *DB) Last(out interface{}, where ...interface{}) *DB { - newScope := s.clone().NewScope(out) - newScope.Search.Limit(1) - return newScope.Set("gorm:order_by_primary_key", "DESC"). 
- inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db -} - -// Find find records that match given conditions -func (s *DB) Find(out interface{}, where ...interface{}) *DB { - return s.clone().NewScope(out).inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db -} - -// Scan scan value to a struct -func (s *DB) Scan(dest interface{}) *DB { - return s.clone().NewScope(s.Value).Set("gorm:query_destination", dest).callCallbacks(s.parent.callbacks.queries).db -} - -// Row return `*sql.Row` with given conditions -func (s *DB) Row() *sql.Row { - return s.NewScope(s.Value).row() -} - -// Rows return `*sql.Rows` with given conditions -func (s *DB) Rows() (*sql.Rows, error) { - return s.NewScope(s.Value).rows() -} - -// ScanRows scan `*sql.Rows` to give struct -func (s *DB) ScanRows(rows *sql.Rows, result interface{}) error { - var ( - clone = s.clone() - scope = clone.NewScope(result) - columns, err = rows.Columns() - ) - - if clone.AddError(err) == nil { - scope.scan(rows, columns, scope.Fields()) - } - - return clone.Error -} - -// Pluck used to query single column from a model as a map -// var ages []int64 -// db.Find(&users).Pluck("age", &ages) -func (s *DB) Pluck(column string, value interface{}) *DB { - return s.NewScope(s.Value).pluck(column, value).db -} - -// Count get how many records for a model -func (s *DB) Count(value interface{}) *DB { - return s.NewScope(s.Value).count(value).db -} - -// Related get related associations -func (s *DB) Related(value interface{}, foreignKeys ...string) *DB { - return s.clone().NewScope(s.Value).related(value, foreignKeys...).db -} - -// FirstOrInit find first matched record or initialize a new one with given conditions (only works with struct, map conditions) -// https://jinzhu.github.io/gorm/crud.html#firstorinit -func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB { - c := s.clone() - if result := c.First(out, where...); result.Error != nil { - if !result.RecordNotFound() { - return result - } - c.NewScope(out).inlineCondition(where...).initialize() - } else { - c.NewScope(out).updatedAttrsWithValues(c.search.assignAttrs) - } - return c -} - -// FirstOrCreate find first matched record or create a new one with given conditions (only works with struct, map conditions) -// https://jinzhu.github.io/gorm/crud.html#firstorcreate -func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) *DB { - c := s.clone() - if result := s.First(out, where...); result.Error != nil { - if !result.RecordNotFound() { - return result - } - return c.NewScope(out).inlineCondition(where...).initialize().callCallbacks(c.parent.callbacks.creates).db - } else if len(c.search.assignAttrs) > 0 { - return c.NewScope(out).InstanceSet("gorm:update_interface", c.search.assignAttrs).callCallbacks(c.parent.callbacks.updates).db - } - return c -} - -// Update update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update -func (s *DB) Update(attrs ...interface{}) *DB { - return s.Updates(toSearchableMap(attrs...), true) -} - -// Updates update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update -func (s *DB) Updates(values interface{}, ignoreProtectedAttrs ...bool) *DB { - return s.clone().NewScope(s.Value). - Set("gorm:ignore_protected_attrs", len(ignoreProtectedAttrs) > 0). - InstanceSet("gorm:update_interface", values). 
- callCallbacks(s.parent.callbacks.updates).db -} - -// UpdateColumn update attributes without callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update -func (s *DB) UpdateColumn(attrs ...interface{}) *DB { - return s.UpdateColumns(toSearchableMap(attrs...)) -} - -// UpdateColumns update attributes without callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update -func (s *DB) UpdateColumns(values interface{}) *DB { - return s.clone().NewScope(s.Value). - Set("gorm:update_column", true). - Set("gorm:save_associations", false). - InstanceSet("gorm:update_interface", values). - callCallbacks(s.parent.callbacks.updates).db -} - -// Save update value in database, if the value doesn't have primary key, will insert it -func (s *DB) Save(value interface{}) *DB { - scope := s.clone().NewScope(value) - if !scope.PrimaryKeyZero() { - newDB := scope.callCallbacks(s.parent.callbacks.updates).db - if newDB.Error == nil && newDB.RowsAffected == 0 { - return s.New().FirstOrCreate(value) - } - return newDB - } - return scope.callCallbacks(s.parent.callbacks.creates).db -} - -// Create insert the value into database -func (s *DB) Create(value interface{}) *DB { - scope := s.clone().NewScope(value) - return scope.callCallbacks(s.parent.callbacks.creates).db -} - -// Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition -func (s *DB) Delete(value interface{}, where ...interface{}) *DB { - return s.clone().NewScope(value).inlineCondition(where...).callCallbacks(s.parent.callbacks.deletes).db -} - -// Raw use raw sql as conditions, won't run it unless invoked by other methods -// db.Raw("SELECT name, age FROM users WHERE name = ?", 3).Scan(&result) -func (s *DB) Raw(sql string, values ...interface{}) *DB { - return s.clone().search.Raw(true).Where(sql, values...).db -} - -// Exec execute raw sql -func (s *DB) Exec(sql string, values ...interface{}) *DB { - scope := s.clone().NewScope(nil) - generatedSQL := scope.buildWhereCondition(map[string]interface{}{"query": sql, "args": values}) - generatedSQL = strings.TrimSuffix(strings.TrimPrefix(generatedSQL, "("), ")") - scope.Raw(generatedSQL) - return scope.Exec().db -} - -// Model specify the model you would like to run db operations -// // update all users's name to `hello` -// db.Model(&User{}).Update("name", "hello") -// // if user's primary key is non-blank, will use it as condition, then will only update the user's name to `hello` -// db.Model(&user).Update("name", "hello") -func (s *DB) Model(value interface{}) *DB { - c := s.clone() - c.Value = value - return c -} - -// Table specify the table you would like to run db operations -func (s *DB) Table(name string) *DB { - clone := s.clone() - clone.search.Table(name) - clone.Value = nil - return clone -} - -// Debug start debug mode -func (s *DB) Debug() *DB { - return s.clone().LogMode(true) -} - -// Begin begin a transaction -func (s *DB) Begin() *DB { - c := s.clone() - if db, ok := c.db.(sqlDb); ok && db != nil { - tx, err := db.Begin() - c.db = interface{}(tx).(SQLCommon) - c.AddError(err) - } else { - c.AddError(ErrCantStartTransaction) - } - return c -} - -// Commit commit a transaction -func (s *DB) Commit() *DB { - if db, ok := s.db.(sqlTx); ok && db != nil { - s.AddError(db.Commit()) - } else { - s.AddError(ErrInvalidTransaction) - } - return s -} - -// Rollback rollback a transaction -func (s *DB) Rollback() *DB { - if db, ok := s.db.(sqlTx); ok && db != nil { - s.AddError(db.Rollback()) - } else { - 
s.AddError(ErrInvalidTransaction) - } - return s -} - -// NewRecord check if value's primary key is blank -func (s *DB) NewRecord(value interface{}) bool { - return s.clone().NewScope(value).PrimaryKeyZero() -} - -// RecordNotFound check if returning ErrRecordNotFound error -func (s *DB) RecordNotFound() bool { - for _, err := range s.GetErrors() { - if err == ErrRecordNotFound { - return true - } - } - return false -} - -// CreateTable create table for models -func (s *DB) CreateTable(models ...interface{}) *DB { - db := s.Unscoped() - for _, model := range models { - db = db.NewScope(model).createTable().db - } - return db -} - -// DropTable drop table for models -func (s *DB) DropTable(values ...interface{}) *DB { - db := s.clone() - for _, value := range values { - if tableName, ok := value.(string); ok { - db = db.Table(tableName) - } - - db = db.NewScope(value).dropTable().db - } - return db -} - -// DropTableIfExists drop table if it is exist -func (s *DB) DropTableIfExists(values ...interface{}) *DB { - db := s.clone() - for _, value := range values { - if s.HasTable(value) { - db.AddError(s.DropTable(value).Error) - } - } - return db -} - -// HasTable check has table or not -func (s *DB) HasTable(value interface{}) bool { - var ( - scope = s.clone().NewScope(value) - tableName string - ) - - if name, ok := value.(string); ok { - tableName = name - } else { - tableName = scope.TableName() - } - - has := scope.Dialect().HasTable(tableName) - s.AddError(scope.db.Error) - return has -} - -// AutoMigrate run auto migration for given models, will only add missing fields, won't delete/change current data -func (s *DB) AutoMigrate(values ...interface{}) *DB { - db := s.Unscoped() - for _, value := range values { - db = db.NewScope(value).autoMigrate().db - } - return db -} - -// ModifyColumn modify column to type -func (s *DB) ModifyColumn(column string, typ string) *DB { - scope := s.clone().NewScope(s.Value) - scope.modifyColumn(column, typ) - return scope.db -} - -// DropColumn drop a column -func (s *DB) DropColumn(column string) *DB { - scope := s.clone().NewScope(s.Value) - scope.dropColumn(column) - return scope.db -} - -// AddIndex add index for columns with given name -func (s *DB) AddIndex(indexName string, columns ...string) *DB { - scope := s.Unscoped().NewScope(s.Value) - scope.addIndex(false, indexName, columns...) - return scope.db -} - -// AddUniqueIndex add unique index for columns with given name -func (s *DB) AddUniqueIndex(indexName string, columns ...string) *DB { - scope := s.Unscoped().NewScope(s.Value) - scope.addIndex(true, indexName, columns...) 
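Editor's note: as a usage reminder for the schema helpers being removed here (`AutoMigrate`, `AddIndex`, `AddUniqueIndex`), a typical call chain looked like the sketch below; the `Product` model and the index name are hypothetical:

```go
package example

import "github.com/jinzhu/gorm"

// Product is a hypothetical model used only for illustration.
type Product struct {
	ID   uint
	Code string
}

// migrate shows the typical chaining of the deleted schema helpers:
// create/extend the table, then add a named unique index on one column.
func migrate(db *gorm.DB) error {
	if err := db.AutoMigrate(&Product{}).Error; err != nil {
		return err
	}
	return db.Model(&Product{}).AddUniqueIndex("idx_products_code", "code").Error
}
```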
- return scope.db -} - -// RemoveIndex remove index with name -func (s *DB) RemoveIndex(indexName string) *DB { - scope := s.clone().NewScope(s.Value) - scope.removeIndex(indexName) - return scope.db -} - -// AddForeignKey Add foreign key to the given scope, e.g: -// db.Model(&User{}).AddForeignKey("city_id", "cities(id)", "RESTRICT", "RESTRICT") -func (s *DB) AddForeignKey(field string, dest string, onDelete string, onUpdate string) *DB { - scope := s.clone().NewScope(s.Value) - scope.addForeignKey(field, dest, onDelete, onUpdate) - return scope.db -} - -// Association start `Association Mode` to handler relations things easir in that mode, refer: https://jinzhu.github.io/gorm/associations.html#association-mode -func (s *DB) Association(column string) *Association { - var err error - var scope = s.Set("gorm:association:source", s.Value).NewScope(s.Value) - - if primaryField := scope.PrimaryField(); primaryField.IsBlank { - err = errors.New("primary key can't be nil") - } else { - if field, ok := scope.FieldByName(column); ok { - if field.Relationship == nil || len(field.Relationship.ForeignFieldNames) == 0 { - err = fmt.Errorf("invalid association %v for %v", column, scope.IndirectValue().Type()) - } else { - return &Association{scope: scope, column: column, field: field} - } - } else { - err = fmt.Errorf("%v doesn't have column %v", scope.IndirectValue().Type(), column) - } - } - - return &Association{Error: err} -} - -// Preload preload associations with given conditions -// db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users) -func (s *DB) Preload(column string, conditions ...interface{}) *DB { - return s.clone().search.Preload(column, conditions...).db -} - -// Set set setting by name, which could be used in callbacks, will clone a new db, and update its setting -func (s *DB) Set(name string, value interface{}) *DB { - return s.clone().InstantSet(name, value) -} - -// InstantSet instant set setting, will affect current db -func (s *DB) InstantSet(name string, value interface{}) *DB { - s.values[name] = value - return s -} - -// Get get setting by name -func (s *DB) Get(name string) (value interface{}, ok bool) { - value, ok = s.values[name] - return -} - -// SetJoinTableHandler set a model's join table handler for a relation -func (s *DB) SetJoinTableHandler(source interface{}, column string, handler JoinTableHandlerInterface) { - scope := s.NewScope(source) - for _, field := range scope.GetModelStruct().StructFields { - if field.Name == column || field.DBName == column { - if many2many := field.TagSettings["MANY2MANY"]; many2many != "" { - source := (&Scope{Value: source}).GetModelStruct().ModelType - destination := (&Scope{Value: reflect.New(field.Struct.Type).Interface()}).GetModelStruct().ModelType - handler.Setup(field.Relationship, many2many, source, destination) - field.Relationship.JoinTableHandler = handler - if table := handler.Table(s); scope.Dialect().HasTable(table) { - s.Table(table).AutoMigrate(handler) - } - } - } - } -} - -// AddError add error to the db -func (s *DB) AddError(err error) error { - if err != nil { - if err != ErrRecordNotFound { - if s.logMode == 0 { - go s.print(fileWithLineNum(), err) - } else { - s.log(err) - } - - errors := Errors(s.GetErrors()) - errors = errors.Add(err) - if len(errors) > 1 { - err = errors - } - } - - s.Error = err - } - return err -} - -// GetErrors get happened errors from the db -func (s *DB) GetErrors() []error { - if errs, ok := s.Error.(Errors); ok { - return errs - } else if s.Error != nil { - return 
[]error{s.Error} - } - return []error{} -} - -//////////////////////////////////////////////////////////////////////////////// -// Private Methods For DB -//////////////////////////////////////////////////////////////////////////////// - -func (s *DB) clone() *DB { - db := &DB{ - db: s.db, - parent: s.parent, - logger: s.logger, - logMode: s.logMode, - values: map[string]interface{}{}, - Value: s.Value, - Error: s.Error, - blockGlobalUpdate: s.blockGlobalUpdate, - } - - for key, value := range s.values { - db.values[key] = value - } - - if s.search == nil { - db.search = &search{limit: -1, offset: -1} - } else { - db.search = s.search.clone() - } - - db.search.db = db - return db -} - -func (s *DB) print(v ...interface{}) { - s.logger.Print(v...) -} - -func (s *DB) log(v ...interface{}) { - if s != nil && s.logMode == 2 { - s.print(append([]interface{}{"log", fileWithLineNum()}, v...)...) - } -} - -func (s *DB) slog(sql string, t time.Time, vars ...interface{}) { - if s.logMode == 2 { - s.print("sql", fileWithLineNum(), NowFunc().Sub(t), sql, vars, s.RowsAffected) - } -} diff --git a/main_test.go b/main_test.go deleted file mode 100644 index 34f96a86..00000000 --- a/main_test.go +++ /dev/null @@ -1,910 +0,0 @@ -package gorm_test - -import ( - "database/sql" - "database/sql/driver" - "fmt" - "os" - "path/filepath" - "reflect" - "strconv" - "testing" - "time" - - "github.com/erikstmartin/go-testdb" - "github.com/jinzhu/gorm" - _ "github.com/jinzhu/gorm/dialects/mssql" - _ "github.com/jinzhu/gorm/dialects/mysql" - "github.com/jinzhu/gorm/dialects/postgres" - _ "github.com/jinzhu/gorm/dialects/sqlite" - "github.com/jinzhu/now" -) - -var ( - DB *gorm.DB - t1, t2, t3, t4, t5 time.Time -) - -func init() { - var err error - - if DB, err = OpenTestConnection(); err != nil { - panic(fmt.Sprintf("No error should happen when connecting to test database, but got err=%+v", err)) - } - - runMigration() -} - -func OpenTestConnection() (db *gorm.DB, err error) { - switch os.Getenv("GORM_DIALECT") { - case "mysql": - // CREATE USER 'gorm'@'localhost' IDENTIFIED BY 'gorm'; - // CREATE DATABASE gorm; - // GRANT ALL ON gorm.* TO 'gorm'@'localhost'; - fmt.Println("testing mysql...") - dbhost := os.Getenv("GORM_DBADDRESS") - if dbhost != "" { - dbhost = fmt.Sprintf("tcp(%v)", dbhost) - } - db, err = gorm.Open("mysql", fmt.Sprintf("gorm:gorm@%v/gorm?charset=utf8&parseTime=True", dbhost)) - case "postgres": - fmt.Println("testing postgres...") - dbhost := os.Getenv("GORM_DBHOST") - if dbhost != "" { - dbhost = fmt.Sprintf("host=%v ", dbhost) - } - db, err = gorm.Open("postgres", fmt.Sprintf("%vuser=gorm password=gorm DB.name=gorm sslmode=disable", dbhost)) - case "foundation": - fmt.Println("testing foundation...") - db, err = gorm.Open("foundation", "dbname=gorm port=15432 sslmode=disable") - case "mssql": - // CREATE LOGIN gorm WITH PASSWORD = 'LoremIpsum86'; - // CREATE DATABASE gorm; - // USE gorm; - // CREATE USER gorm FROM LOGIN gorm; - // sp_changedbowner 'gorm'; - fmt.Println("testing mssql...") - db, err = gorm.Open("mssql", "sqlserver://gorm:LoremIpsum86@localhost:1433?database=gorm") - default: - fmt.Println("testing sqlite3...") - db, err = gorm.Open("sqlite3", filepath.Join(os.TempDir(), "gorm.db")) - } - - // db.SetLogger(Logger{log.New(os.Stdout, "\r\n", 0)}) - // db.SetLogger(log.New(os.Stdout, "\r\n", 0)) - if os.Getenv("DEBUG") == "true" { - db.LogMode(true) - } - - db.DB().SetMaxIdleConns(10) - - return -} - -func TestStringPrimaryKey(t *testing.T) { - type UUIDStruct struct { - ID string 
`gorm:"primary_key"` - Name string - } - DB.DropTable(&UUIDStruct{}) - DB.AutoMigrate(&UUIDStruct{}) - - data := UUIDStruct{ID: "uuid", Name: "hello"} - if err := DB.Save(&data).Error; err != nil || data.ID != "uuid" || data.Name != "hello" { - t.Errorf("string primary key should not be populated") - } - - data = UUIDStruct{ID: "uuid", Name: "hello world"} - if err := DB.Save(&data).Error; err != nil || data.ID != "uuid" || data.Name != "hello world" { - t.Errorf("string primary key should not be populated") - } -} - -func TestExceptionsWithInvalidSql(t *testing.T) { - var columns []string - if DB.Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns).Error == nil { - t.Errorf("Should got error with invalid SQL") - } - - if DB.Model(&User{}).Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns).Error == nil { - t.Errorf("Should got error with invalid SQL") - } - - if DB.Where("sdsd.zaaa = ?", "sd;;;aa").Find(&User{}).Error == nil { - t.Errorf("Should got error with invalid SQL") - } - - var count1, count2 int64 - DB.Model(&User{}).Count(&count1) - if count1 <= 0 { - t.Errorf("Should find some users") - } - - if DB.Where("name = ?", "jinzhu; delete * from users").First(&User{}).Error == nil { - t.Errorf("Should got error with invalid SQL") - } - - DB.Model(&User{}).Count(&count2) - if count1 != count2 { - t.Errorf("No user should not be deleted by invalid SQL") - } -} - -func TestSetTable(t *testing.T) { - DB.Create(getPreparedUser("pluck_user1", "pluck_user")) - DB.Create(getPreparedUser("pluck_user2", "pluck_user")) - DB.Create(getPreparedUser("pluck_user3", "pluck_user")) - - if err := DB.Table("users").Where("role = ?", "pluck_user").Pluck("age", &[]int{}).Error; err != nil { - t.Error("No errors should happen if set table for pluck", err) - } - - var users []User - if DB.Table("users").Find(&[]User{}).Error != nil { - t.Errorf("No errors should happen if set table for find") - } - - if DB.Table("invalid_table").Find(&users).Error == nil { - t.Errorf("Should got error when table is set to an invalid table") - } - - DB.Exec("drop table deleted_users;") - if DB.Table("deleted_users").CreateTable(&User{}).Error != nil { - t.Errorf("Create table with specified table") - } - - DB.Table("deleted_users").Save(&User{Name: "DeletedUser"}) - - var deletedUsers []User - DB.Table("deleted_users").Find(&deletedUsers) - if len(deletedUsers) != 1 { - t.Errorf("Query from specified table") - } - - DB.Save(getPreparedUser("normal_user", "reset_table")) - DB.Table("deleted_users").Save(getPreparedUser("deleted_user", "reset_table")) - var user1, user2, user3 User - DB.Where("role = ?", "reset_table").First(&user1).Table("deleted_users").First(&user2).Table("").First(&user3) - if (user1.Name != "normal_user") || (user2.Name != "deleted_user") || (user3.Name != "normal_user") { - t.Errorf("unset specified table with blank string") - } -} - -type Order struct { -} - -type Cart struct { -} - -func (c Cart) TableName() string { - return "shopping_cart" -} - -func TestHasTable(t *testing.T) { - type Foo struct { - Id int - Stuff string - } - DB.DropTable(&Foo{}) - - // Table should not exist at this point, HasTable should return false - if ok := DB.HasTable("foos"); ok { - t.Errorf("Table should not exist, but does") - } - if ok := DB.HasTable(&Foo{}); ok { - t.Errorf("Table should not exist, but does") - } - - // We create the table - if err := DB.CreateTable(&Foo{}).Error; err != nil { - t.Errorf("Table should be created") - } - - // And now it should exits, and HasTable should return true - if ok := 
DB.HasTable("foos"); !ok { - t.Errorf("Table should exist, but HasTable informs it does not") - } - if ok := DB.HasTable(&Foo{}); !ok { - t.Errorf("Table should exist, but HasTable informs it does not") - } -} - -func TestTableName(t *testing.T) { - DB := DB.Model("") - if DB.NewScope(Order{}).TableName() != "orders" { - t.Errorf("Order's table name should be orders") - } - - if DB.NewScope(&Order{}).TableName() != "orders" { - t.Errorf("&Order's table name should be orders") - } - - if DB.NewScope([]Order{}).TableName() != "orders" { - t.Errorf("[]Order's table name should be orders") - } - - if DB.NewScope(&[]Order{}).TableName() != "orders" { - t.Errorf("&[]Order's table name should be orders") - } - - DB.SingularTable(true) - if DB.NewScope(Order{}).TableName() != "order" { - t.Errorf("Order's singular table name should be order") - } - - if DB.NewScope(&Order{}).TableName() != "order" { - t.Errorf("&Order's singular table name should be order") - } - - if DB.NewScope([]Order{}).TableName() != "order" { - t.Errorf("[]Order's singular table name should be order") - } - - if DB.NewScope(&[]Order{}).TableName() != "order" { - t.Errorf("&[]Order's singular table name should be order") - } - - if DB.NewScope(&Cart{}).TableName() != "shopping_cart" { - t.Errorf("&Cart's singular table name should be shopping_cart") - } - - if DB.NewScope(Cart{}).TableName() != "shopping_cart" { - t.Errorf("Cart's singular table name should be shopping_cart") - } - - if DB.NewScope(&[]Cart{}).TableName() != "shopping_cart" { - t.Errorf("&[]Cart's singular table name should be shopping_cart") - } - - if DB.NewScope([]Cart{}).TableName() != "shopping_cart" { - t.Errorf("[]Cart's singular table name should be shopping_cart") - } - DB.SingularTable(false) -} - -func TestNullValues(t *testing.T) { - DB.DropTable(&NullValue{}) - DB.AutoMigrate(&NullValue{}) - - if err := DB.Save(&NullValue{ - Name: sql.NullString{String: "hello", Valid: true}, - Gender: &sql.NullString{String: "M", Valid: true}, - Age: sql.NullInt64{Int64: 18, Valid: true}, - Male: sql.NullBool{Bool: true, Valid: true}, - Height: sql.NullFloat64{Float64: 100.11, Valid: true}, - AddedAt: NullTime{Time: time.Now(), Valid: true}, - }).Error; err != nil { - t.Errorf("Not error should raise when test null value") - } - - var nv NullValue - DB.First(&nv, "name = ?", "hello") - - if nv.Name.String != "hello" || nv.Gender.String != "M" || nv.Age.Int64 != 18 || nv.Male.Bool != true || nv.Height.Float64 != 100.11 || nv.AddedAt.Valid != true { - t.Errorf("Should be able to fetch null value") - } - - if err := DB.Save(&NullValue{ - Name: sql.NullString{String: "hello-2", Valid: true}, - Gender: &sql.NullString{String: "F", Valid: true}, - Age: sql.NullInt64{Int64: 18, Valid: false}, - Male: sql.NullBool{Bool: true, Valid: true}, - Height: sql.NullFloat64{Float64: 100.11, Valid: true}, - AddedAt: NullTime{Time: time.Now(), Valid: false}, - }).Error; err != nil { - t.Errorf("Not error should raise when test null value") - } - - var nv2 NullValue - DB.First(&nv2, "name = ?", "hello-2") - if nv2.Name.String != "hello-2" || nv2.Gender.String != "F" || nv2.Age.Int64 != 0 || nv2.Male.Bool != true || nv2.Height.Float64 != 100.11 || nv2.AddedAt.Valid != false { - t.Errorf("Should be able to fetch null value") - } - - if err := DB.Save(&NullValue{ - Name: sql.NullString{String: "hello-3", Valid: false}, - Gender: &sql.NullString{String: "M", Valid: true}, - Age: sql.NullInt64{Int64: 18, Valid: false}, - Male: sql.NullBool{Bool: true, Valid: true}, - Height: 
sql.NullFloat64{Float64: 100.11, Valid: true}, - AddedAt: NullTime{Time: time.Now(), Valid: false}, - }).Error; err == nil { - t.Errorf("Can't save because of name can't be null") - } -} - -func TestNullValuesWithFirstOrCreate(t *testing.T) { - var nv1 = NullValue{ - Name: sql.NullString{String: "first_or_create", Valid: true}, - Gender: &sql.NullString{String: "M", Valid: true}, - } - - var nv2 NullValue - result := DB.Where(nv1).FirstOrCreate(&nv2) - - if result.RowsAffected != 1 { - t.Errorf("RowsAffected should be 1 after create some record") - } - - if result.Error != nil { - t.Errorf("Should not raise any error, but got %v", result.Error) - } - - if nv2.Name.String != "first_or_create" || nv2.Gender.String != "M" { - t.Errorf("first or create with nullvalues") - } - - if err := DB.Where(nv1).Assign(NullValue{Age: sql.NullInt64{Int64: 18, Valid: true}}).FirstOrCreate(&nv2).Error; err != nil { - t.Errorf("Should not raise any error, but got %v", err) - } - - if nv2.Age.Int64 != 18 { - t.Errorf("should update age to 18") - } -} - -func TestTransaction(t *testing.T) { - tx := DB.Begin() - u := User{Name: "transcation"} - if err := tx.Save(&u).Error; err != nil { - t.Errorf("No error should raise") - } - - if err := tx.First(&User{}, "name = ?", "transcation").Error; err != nil { - t.Errorf("Should find saved record") - } - - if sqlTx, ok := tx.CommonDB().(*sql.Tx); !ok || sqlTx == nil { - t.Errorf("Should return the underlying sql.Tx") - } - - tx.Rollback() - - if err := tx.First(&User{}, "name = ?", "transcation").Error; err == nil { - t.Errorf("Should not find record after rollback") - } - - tx2 := DB.Begin() - u2 := User{Name: "transcation-2"} - if err := tx2.Save(&u2).Error; err != nil { - t.Errorf("No error should raise") - } - - if err := tx2.First(&User{}, "name = ?", "transcation-2").Error; err != nil { - t.Errorf("Should find saved record") - } - - tx2.Commit() - - if err := DB.First(&User{}, "name = ?", "transcation-2").Error; err != nil { - t.Errorf("Should be able to find committed record") - } -} - -func TestRow(t *testing.T) { - user1 := User{Name: "RowUser1", Age: 1, Birthday: parseTime("2000-1-1")} - user2 := User{Name: "RowUser2", Age: 10, Birthday: parseTime("2010-1-1")} - user3 := User{Name: "RowUser3", Age: 20, Birthday: parseTime("2020-1-1")} - DB.Save(&user1).Save(&user2).Save(&user3) - - row := DB.Table("users").Where("name = ?", user2.Name).Select("age").Row() - var age int64 - row.Scan(&age) - if age != 10 { - t.Errorf("Scan with Row") - } -} - -func TestRows(t *testing.T) { - user1 := User{Name: "RowsUser1", Age: 1, Birthday: parseTime("2000-1-1")} - user2 := User{Name: "RowsUser2", Age: 10, Birthday: parseTime("2010-1-1")} - user3 := User{Name: "RowsUser3", Age: 20, Birthday: parseTime("2020-1-1")} - DB.Save(&user1).Save(&user2).Save(&user3) - - rows, err := DB.Table("users").Where("name = ? 
or name = ?", user2.Name, user3.Name).Select("name, age").Rows() - if err != nil { - t.Errorf("Not error should happen, got %v", err) - } - - count := 0 - for rows.Next() { - var name string - var age int64 - rows.Scan(&name, &age) - count++ - } - - if count != 2 { - t.Errorf("Should found two records") - } -} - -func TestScanRows(t *testing.T) { - user1 := User{Name: "ScanRowsUser1", Age: 1, Birthday: parseTime("2000-1-1")} - user2 := User{Name: "ScanRowsUser2", Age: 10, Birthday: parseTime("2010-1-1")} - user3 := User{Name: "ScanRowsUser3", Age: 20, Birthday: parseTime("2020-1-1")} - DB.Save(&user1).Save(&user2).Save(&user3) - - rows, err := DB.Table("users").Where("name = ? or name = ?", user2.Name, user3.Name).Select("name, age").Rows() - if err != nil { - t.Errorf("Not error should happen, got %v", err) - } - - type Result struct { - Name string - Age int - } - - var results []Result - for rows.Next() { - var result Result - if err := DB.ScanRows(rows, &result); err != nil { - t.Errorf("should get no error, but got %v", err) - } - results = append(results, result) - } - - if !reflect.DeepEqual(results, []Result{{Name: "ScanRowsUser2", Age: 10}, {Name: "ScanRowsUser3", Age: 20}}) { - t.Errorf("Should find expected results") - } -} - -func TestScan(t *testing.T) { - user1 := User{Name: "ScanUser1", Age: 1, Birthday: parseTime("2000-1-1")} - user2 := User{Name: "ScanUser2", Age: 10, Birthday: parseTime("2010-1-1")} - user3 := User{Name: "ScanUser3", Age: 20, Birthday: parseTime("2020-1-1")} - DB.Save(&user1).Save(&user2).Save(&user3) - - type result struct { - Name string - Age int - } - - var res result - DB.Table("users").Select("name, age").Where("name = ?", user3.Name).Scan(&res) - if res.Name != user3.Name { - t.Errorf("Scan into struct should work") - } - - var doubleAgeRes = &result{} - if err := DB.Table("users").Select("age + age as age").Where("name = ?", user3.Name).Scan(&doubleAgeRes).Error; err != nil { - t.Errorf("Scan to pointer of pointer") - } - if doubleAgeRes.Age != res.Age*2 { - t.Errorf("Scan double age as age") - } - - var ress []result - DB.Table("users").Select("name, age").Where("name in (?)", []string{user2.Name, user3.Name}).Scan(&ress) - if len(ress) != 2 || ress[0].Name != user2.Name || ress[1].Name != user3.Name { - t.Errorf("Scan into struct map") - } -} - -func TestRaw(t *testing.T) { - user1 := User{Name: "ExecRawSqlUser1", Age: 1, Birthday: parseTime("2000-1-1")} - user2 := User{Name: "ExecRawSqlUser2", Age: 10, Birthday: parseTime("2010-1-1")} - user3 := User{Name: "ExecRawSqlUser3", Age: 20, Birthday: parseTime("2020-1-1")} - DB.Save(&user1).Save(&user2).Save(&user3) - - type result struct { - Name string - Email string - } - - var ress []result - DB.Raw("SELECT name, age FROM users WHERE name = ? or name = ?", user2.Name, user3.Name).Scan(&ress) - if len(ress) != 2 || ress[0].Name != user2.Name || ress[1].Name != user3.Name { - t.Errorf("Raw with scan") - } - - rows, _ := DB.Raw("select name, age from users where name = ?", user3.Name).Rows() - count := 0 - for rows.Next() { - count++ - } - if count != 1 { - t.Errorf("Raw with Rows should find one record with name 3") - } - - DB.Exec("update users set name=? 
where name in (?)", "jinzhu", []string{user1.Name, user2.Name, user3.Name}) - if DB.Where("name in (?)", []string{user1.Name, user2.Name, user3.Name}).First(&User{}).Error != gorm.ErrRecordNotFound { - t.Error("Raw sql to update records") - } -} - -func TestGroup(t *testing.T) { - rows, err := DB.Select("name").Table("users").Group("name").Rows() - - if err == nil { - defer rows.Close() - for rows.Next() { - var name string - rows.Scan(&name) - } - } else { - t.Errorf("Should not raise any error") - } -} - -func TestJoins(t *testing.T) { - var user = User{ - Name: "joins", - CreditCard: CreditCard{Number: "411111111111"}, - Emails: []Email{{Email: "join1@example.com"}, {Email: "join2@example.com"}}, - } - DB.Save(&user) - - var users1 []User - DB.Joins("left join emails on emails.user_id = users.id").Where("name = ?", "joins").Find(&users1) - if len(users1) != 2 { - t.Errorf("should find two users using left join") - } - - var users2 []User - DB.Joins("left join emails on emails.user_id = users.id AND emails.email = ?", "join1@example.com").Where("name = ?", "joins").First(&users2) - if len(users2) != 1 { - t.Errorf("should find one users using left join with conditions") - } - - var users3 []User - DB.Joins("join emails on emails.user_id = users.id AND emails.email = ?", "join1@example.com").Joins("join credit_cards on credit_cards.user_id = users.id AND credit_cards.number = ?", "411111111111").Where("name = ?", "joins").First(&users3) - if len(users3) != 1 { - t.Errorf("should find one users using multiple left join conditions") - } - - var users4 []User - DB.Joins("join emails on emails.user_id = users.id AND emails.email = ?", "join1@example.com").Joins("join credit_cards on credit_cards.user_id = users.id AND credit_cards.number = ?", "422222222222").Where("name = ?", "joins").First(&users4) - if len(users4) != 0 { - t.Errorf("should find no user when searching with unexisting credit card") - } - - var users5 []User - db5 := DB.Joins("join emails on emails.user_id = users.id AND emails.email = ?", "join1@example.com").Joins("join credit_cards on credit_cards.user_id = users.id AND credit_cards.number = ?", "411111111111").Where(User{Id: 1}).Where(Email{Id: 1}).Not(Email{Id: 10}).First(&users5) - if db5.Error != nil { - t.Errorf("Should not raise error for join where identical fields in different tables. 
Error: %s", db5.Error.Error()) - } -} - -func TestJoinsWithSelect(t *testing.T) { - type result struct { - Name string - Email string - } - - user := User{ - Name: "joins_with_select", - Emails: []Email{{Email: "join1@example.com"}, {Email: "join2@example.com"}}, - } - DB.Save(&user) - - var results []result - DB.Table("users").Select("name, emails.email").Joins("left join emails on emails.user_id = users.id").Where("name = ?", "joins_with_select").Scan(&results) - if len(results) != 2 || results[0].Email != "join1@example.com" || results[1].Email != "join2@example.com" { - t.Errorf("Should find all two emails with Join select") - } -} - -func TestHaving(t *testing.T) { - rows, err := DB.Select("name, count(*) as total").Table("users").Group("name").Having("name IN (?)", []string{"2", "3"}).Rows() - - if err == nil { - defer rows.Close() - for rows.Next() { - var name string - var total int64 - rows.Scan(&name, &total) - - if name == "2" && total != 1 { - t.Errorf("Should have one user having name 2") - } - if name == "3" && total != 2 { - t.Errorf("Should have two users having name 3") - } - } - } else { - t.Errorf("Should not raise any error") - } -} - -func TestQueryBuilderSubselectInWhere(t *testing.T) { - user := User{Name: "query_expr_select_ruser1", Email: "root@user1.com", Age: 32} - DB.Save(&user) - user = User{Name: "query_expr_select_ruser2", Email: "nobody@user2.com", Age: 16} - DB.Save(&user) - user = User{Name: "query_expr_select_ruser3", Email: "root@user3.com", Age: 64} - DB.Save(&user) - user = User{Name: "query_expr_select_ruser4", Email: "somebody@user3.com", Age: 128} - DB.Save(&user) - - var users []User - DB.Select("*").Where("name IN (?)", DB. - Select("name").Table("users").Where("name LIKE ?", "query_expr_select%").QueryExpr()).Find(&users) - - if len(users) != 4 { - t.Errorf("Four users should be found, instead found %d", len(users)) - } - - DB.Select("*").Where("name LIKE ?", "query_expr_select%").Where("age >= (?)", DB. - Select("AVG(age)").Table("users").Where("name LIKE ?", "query_expr_select%").QueryExpr()).Find(&users) - - if len(users) != 2 { - t.Errorf("Two users should be found, instead found %d", len(users)) - } -} - -func TestQueryBuilderSubselectInHaving(t *testing.T) { - user := User{Name: "query_expr_having_ruser1", Email: "root@user1.com", Age: 64} - DB.Save(&user) - user = User{Name: "query_expr_having_ruser2", Email: "root@user2.com", Age: 128} - DB.Save(&user) - user = User{Name: "query_expr_having_ruser3", Email: "root@user1.com", Age: 64} - DB.Save(&user) - user = User{Name: "query_expr_having_ruser4", Email: "root@user2.com", Age: 128} - DB.Save(&user) - - var users []User - DB.Select("AVG(age) as avgage").Where("name LIKE ?", "query_expr_having_%").Group("email").Having("AVG(age) > (?)", DB. - Select("AVG(age)").Where("name LIKE ?", "query_expr_having_%").Table("users").QueryExpr()).Find(&users) - - if len(users) != 1 { - t.Errorf("Two user group should be found, instead found %d", len(users)) - } -} - -func DialectHasTzSupport() bool { - // NB: mssql and FoundationDB do not support time zones. 
- if dialect := os.Getenv("GORM_DIALECT"); dialect == "foundation" { - return false - } - return true -} - -func TestTimeWithZone(t *testing.T) { - var format = "2006-01-02 15:04:05 -0700" - var times []time.Time - GMT8, _ := time.LoadLocation("Asia/Shanghai") - times = append(times, time.Date(2013, 02, 19, 1, 51, 49, 123456789, GMT8)) - times = append(times, time.Date(2013, 02, 18, 17, 51, 49, 123456789, time.UTC)) - - for index, vtime := range times { - name := "time_with_zone_" + strconv.Itoa(index) - user := User{Name: name, Birthday: &vtime} - - if !DialectHasTzSupport() { - // If our driver dialect doesn't support TZ's, just use UTC for everything here. - utcBirthday := user.Birthday.UTC() - user.Birthday = &utcBirthday - } - - DB.Save(&user) - expectedBirthday := "2013-02-18 17:51:49 +0000" - foundBirthday := user.Birthday.UTC().Format(format) - if foundBirthday != expectedBirthday { - t.Errorf("User's birthday should not be changed after save for name=%s, expected bday=%+v but actual value=%+v", name, expectedBirthday, foundBirthday) - } - - var findUser, findUser2, findUser3 User - DB.First(&findUser, "name = ?", name) - foundBirthday = findUser.Birthday.UTC().Format(format) - if foundBirthday != expectedBirthday { - t.Errorf("User's birthday should not be changed after find for name=%s, expected bday=%+v but actual value=%+v", name, expectedBirthday, foundBirthday) - } - - if DB.Where("id = ? AND birthday >= ?", findUser.Id, user.Birthday.Add(-time.Minute)).First(&findUser2).RecordNotFound() { - t.Errorf("User should be found") - } - - if !DB.Where("id = ? AND birthday >= ?", findUser.Id, user.Birthday.Add(time.Minute)).First(&findUser3).RecordNotFound() { - t.Errorf("User should not be found") - } - } -} - -func TestHstore(t *testing.T) { - type Details struct { - Id int64 - Bulk postgres.Hstore - } - - if dialect := os.Getenv("GORM_DIALECT"); dialect != "postgres" { - t.Skip() - } - - if err := DB.Exec("CREATE EXTENSION IF NOT EXISTS hstore").Error; err != nil { - fmt.Println("\033[31mHINT: Must be superuser to create hstore extension (ALTER USER gorm WITH SUPERUSER;)\033[0m") - panic(fmt.Sprintf("No error should happen when create hstore extension, but got %+v", err)) - } - - DB.Exec("drop table details") - - if err := DB.CreateTable(&Details{}).Error; err != nil { - panic(fmt.Sprintf("No error should happen when create table, but got %+v", err)) - } - - bankAccountId, phoneNumber, opinion := "123456", "14151321232", "sharkbait" - bulk := map[string]*string{ - "bankAccountId": &bankAccountId, - "phoneNumber": &phoneNumber, - "opinion": &opinion, - } - d := Details{Bulk: bulk} - DB.Save(&d) - - var d2 Details - if err := DB.First(&d2).Error; err != nil { - t.Errorf("Got error when tried to fetch details: %+v", err) - } - - for k := range bulk { - if r, ok := d2.Bulk[k]; ok { - if res, _ := bulk[k]; *res != *r { - t.Errorf("Details should be equal") - } - } else { - t.Errorf("Details should be existed") - } - } -} - -func TestSetAndGet(t *testing.T) { - if value, ok := DB.Set("hello", "world").Get("hello"); !ok { - t.Errorf("Should be able to get setting after set") - } else { - if value.(string) != "world" { - t.Errorf("Setted value should not be changed") - } - } - - if _, ok := DB.Get("non_existing"); ok { - t.Errorf("Get non existing key should return error") - } -} - -func TestCompatibilityMode(t *testing.T) { - DB, _ := gorm.Open("testdb", "") - testdb.SetQueryFunc(func(query string) (driver.Rows, error) { - columns := []string{"id", "name", "age"} - result := ` - 1,Tim,20 
- 2,Joe,25 - 3,Bob,30 - ` - return testdb.RowsFromCSVString(columns, result), nil - }) - - var users []User - DB.Find(&users) - if (users[0].Name != "Tim") || len(users) != 3 { - t.Errorf("Unexcepted result returned") - } -} - -func TestOpenExistingDB(t *testing.T) { - DB.Save(&User{Name: "jnfeinstein"}) - dialect := os.Getenv("GORM_DIALECT") - - db, err := gorm.Open(dialect, DB.DB()) - if err != nil { - t.Errorf("Should have wrapped the existing DB connection") - } - - var user User - if db.Where("name = ?", "jnfeinstein").First(&user).Error == gorm.ErrRecordNotFound { - t.Errorf("Should have found existing record") - } -} - -func TestDdlErrors(t *testing.T) { - var err error - - if err = DB.Close(); err != nil { - t.Errorf("Closing DDL test db connection err=%s", err) - } - defer func() { - // Reopen DB connection. - if DB, err = OpenTestConnection(); err != nil { - t.Fatalf("Failed re-opening db connection: %s", err) - } - }() - - if err := DB.Find(&User{}).Error; err == nil { - t.Errorf("Expected operation on closed db to produce an error, but err was nil") - } -} - -func TestOpenWithOneParameter(t *testing.T) { - db, err := gorm.Open("dialect") - if db != nil { - t.Error("Open with one parameter returned non nil for db") - } - if err == nil { - t.Error("Open with one parameter returned err as nil") - } -} - -func TestBlockGlobalUpdate(t *testing.T) { - db := DB.New() - db.Create(&Toy{Name: "Stuffed Animal", OwnerType: "Nobody"}) - - err := db.Model(&Toy{}).Update("OwnerType", "Human").Error - if err != nil { - t.Error("Unexpected error on global update") - } - - err = db.Delete(&Toy{}).Error - if err != nil { - t.Error("Unexpected error on global delete") - } - - db.BlockGlobalUpdate(true) - - db.Create(&Toy{Name: "Stuffed Animal", OwnerType: "Nobody"}) - - err = db.Model(&Toy{}).Update("OwnerType", "Human").Error - if err == nil { - t.Error("Expected error on global update") - } - - err = db.Model(&Toy{}).Where(&Toy{OwnerType: "Martian"}).Update("OwnerType", "Astronaut").Error - if err != nil { - t.Error("Unxpected error on conditional update") - } - - err = db.Delete(&Toy{}).Error - if err == nil { - t.Error("Expected error on global delete") - } - err = db.Where(&Toy{OwnerType: "Martian"}).Delete(&Toy{}).Error - if err != nil { - t.Error("Unexpected error on conditional delete") - } -} - -func BenchmarkGorm(b *testing.B) { - b.N = 2000 - for x := 0; x < b.N; x++ { - e := strconv.Itoa(x) + "benchmark@example.org" - now := time.Now() - email := EmailWithIdx{Email: e, UserAgent: "pc", RegisteredAt: &now} - // Insert - DB.Save(&email) - // Query - DB.First(&EmailWithIdx{}, "email = ?", e) - // Update - DB.Model(&email).UpdateColumn("email", "new-"+e) - // Delete - DB.Delete(&email) - } -} - -func BenchmarkRawSql(b *testing.B) { - DB, _ := sql.Open("postgres", "user=gorm DB.ame=gorm sslmode=disable") - DB.SetMaxIdleConns(10) - insertSql := "INSERT INTO emails (user_id,email,user_agent,registered_at,created_at,updated_at) VALUES ($1,$2,$3,$4,$5,$6) RETURNING id" - querySql := "SELECT * FROM emails WHERE email = $1 ORDER BY id LIMIT 1" - updateSql := "UPDATE emails SET email = $1, updated_at = $2 WHERE id = $3" - deleteSql := "DELETE FROM orders WHERE id = $1" - - b.N = 2000 - for x := 0; x < b.N; x++ { - var id int64 - e := strconv.Itoa(x) + "benchmark@example.org" - now := time.Now() - email := EmailWithIdx{Email: e, UserAgent: "pc", RegisteredAt: &now} - // Insert - DB.QueryRow(insertSql, email.UserId, email.Email, email.UserAgent, email.RegisteredAt, time.Now(), time.Now()).Scan(&id) 
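- // RETURNING id captures the generated primary key, which the update and delete statements below reuse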
- // Query - rows, _ := DB.Query(querySql, email.Email) - rows.Close() - // Update - DB.Exec(updateSql, "new-"+e, time.Now(), id) - // Delete - DB.Exec(deleteSql, id) - } -} - -func parseTime(str string) *time.Time { - t := now.New(time.Now().UTC()).MustParse(str) - return &t -} diff --git a/migration_test.go b/migration_test.go deleted file mode 100644 index 3f3a5c8f..00000000 --- a/migration_test.go +++ /dev/null @@ -1,459 +0,0 @@ -package gorm_test - -import ( - "database/sql" - "database/sql/driver" - "errors" - "fmt" - "os" - "reflect" - "testing" - "time" - - "github.com/jinzhu/gorm" -) - -type User struct { - Id int64 - Age int64 - UserNum Num - Name string `sql:"size:255"` - Email string - Birthday *time.Time // Time - CreatedAt time.Time // CreatedAt: Time of record is created, will be insert automatically - UpdatedAt time.Time // UpdatedAt: Time of record is updated, will be updated automatically - Emails []Email // Embedded structs - BillingAddress Address // Embedded struct - BillingAddressID sql.NullInt64 // Embedded struct's foreign key - ShippingAddress Address // Embedded struct - ShippingAddressId int64 // Embedded struct's foreign key - CreditCard CreditCard - Latitude float64 - Languages []Language `gorm:"many2many:user_languages;"` - CompanyID *int - Company Company - Role Role - PasswordHash []byte - IgnoreMe int64 `sql:"-"` - IgnoreStringSlice []string `sql:"-"` - Ignored struct{ Name string } `sql:"-"` - IgnoredPointer *User `sql:"-"` -} - -type NotSoLongTableName struct { - Id int64 - ReallyLongThingID int64 - ReallyLongThing ReallyLongTableNameToTestMySQLNameLengthLimit -} - -type ReallyLongTableNameToTestMySQLNameLengthLimit struct { - Id int64 -} - -type ReallyLongThingThatReferencesShort struct { - Id int64 - ShortID int64 - Short Short -} - -type Short struct { - Id int64 -} - -type CreditCard struct { - ID int8 - Number string - UserId sql.NullInt64 - CreatedAt time.Time `sql:"not null"` - UpdatedAt time.Time - DeletedAt *time.Time `sql:"column:deleted_time"` -} - -type Email struct { - Id int16 - UserId int - Email string `sql:"type:varchar(100);"` - CreatedAt time.Time - UpdatedAt time.Time -} - -type Address struct { - ID int - Address1 string - Address2 string - Post string - CreatedAt time.Time - UpdatedAt time.Time - DeletedAt *time.Time -} - -type Language struct { - gorm.Model - Name string - Users []User `gorm:"many2many:user_languages;"` -} - -type Product struct { - Id int64 - Code string - Price int64 - CreatedAt time.Time - UpdatedAt time.Time - AfterFindCallTimes int64 - BeforeCreateCallTimes int64 - AfterCreateCallTimes int64 - BeforeUpdateCallTimes int64 - AfterUpdateCallTimes int64 - BeforeSaveCallTimes int64 - AfterSaveCallTimes int64 - BeforeDeleteCallTimes int64 - AfterDeleteCallTimes int64 -} - -type Company struct { - Id int64 - Name string - Owner *User `sql:"-"` -} - -type Role struct { - Name string `gorm:"size:256"` -} - -func (role *Role) Scan(value interface{}) error { - if b, ok := value.([]uint8); ok { - role.Name = string(b) - } else { - role.Name = value.(string) - } - return nil -} - -func (role Role) Value() (driver.Value, error) { - return role.Name, nil -} - -func (role Role) IsAdmin() bool { - return role.Name == "admin" -} - -type Num int64 - -func (i *Num) Scan(src interface{}) error { - switch s := src.(type) { - case []byte: - case int64: - *i = Num(s) - default: - return errors.New("Cannot scan NamedInt from " + reflect.ValueOf(src).String()) - } - return nil -} - -type Animal struct { - Counter uint64 
`gorm:"primary_key:yes"` - Name string `sql:"DEFAULT:'galeone'"` - From string //test reserved sql keyword as field name - Age time.Time `sql:"DEFAULT:current_timestamp"` - unexported string // unexported value - CreatedAt time.Time - UpdatedAt time.Time -} - -type JoinTable struct { - From uint64 - To uint64 - Time time.Time `sql:"default: null"` -} - -type Post struct { - Id int64 - CategoryId sql.NullInt64 - MainCategoryId int64 - Title string - Body string - Comments []*Comment - Category Category - MainCategory Category -} - -type Category struct { - gorm.Model - Name string - - Categories []Category - CategoryID *uint -} - -type Comment struct { - gorm.Model - PostId int64 - Content string - Post Post -} - -// Scanner -type NullValue struct { - Id int64 - Name sql.NullString `sql:"not null"` - Gender *sql.NullString `sql:"not null"` - Age sql.NullInt64 - Male sql.NullBool - Height sql.NullFloat64 - AddedAt NullTime -} - -type NullTime struct { - Time time.Time - Valid bool -} - -func (nt *NullTime) Scan(value interface{}) error { - if value == nil { - nt.Valid = false - return nil - } - nt.Time, nt.Valid = value.(time.Time), true - return nil -} - -func (nt NullTime) Value() (driver.Value, error) { - if !nt.Valid { - return nil, nil - } - return nt.Time, nil -} - -func getPreparedUser(name string, role string) *User { - var company Company - DB.Where(Company{Name: role}).FirstOrCreate(&company) - - return &User{ - Name: name, - Age: 20, - Role: Role{role}, - BillingAddress: Address{Address1: fmt.Sprintf("Billing Address %v", name)}, - ShippingAddress: Address{Address1: fmt.Sprintf("Shipping Address %v", name)}, - CreditCard: CreditCard{Number: fmt.Sprintf("123456%v", name)}, - Emails: []Email{ - {Email: fmt.Sprintf("user_%v@example1.com", name)}, {Email: fmt.Sprintf("user_%v@example2.com", name)}, - }, - Company: company, - Languages: []Language{ - {Name: fmt.Sprintf("lang_1_%v", name)}, - {Name: fmt.Sprintf("lang_2_%v", name)}, - }, - } -} - -func runMigration() { - if err := DB.DropTableIfExists(&User{}).Error; err != nil { - fmt.Printf("Got error when try to delete table users, %+v\n", err) - } - - for _, table := range []string{"animals", "user_languages"} { - DB.Exec(fmt.Sprintf("drop table %v;", table)) - } - - values := []interface{}{&Short{}, &ReallyLongThingThatReferencesShort{}, &ReallyLongTableNameToTestMySQLNameLengthLimit{}, &NotSoLongTableName{}, &Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}, &Animal{}, &User{}, &JoinTable{}, &Post{}, &Category{}, &Comment{}, &Cat{}, &Dog{}, &Hamster{}, &Toy{}, &ElementWithIgnoredField{}} - for _, value := range values { - DB.DropTable(value) - } - if err := DB.AutoMigrate(values...).Error; err != nil { - panic(fmt.Sprintf("No error should happen when create table, but got %+v", err)) - } -} - -func TestIndexes(t *testing.T) { - if err := DB.Model(&Email{}).AddIndex("idx_email_email", "email").Error; err != nil { - t.Errorf("Got error when tried to create index: %+v", err) - } - - scope := DB.NewScope(&Email{}) - if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email") { - t.Errorf("Email should have index idx_email_email") - } - - if err := DB.Model(&Email{}).RemoveIndex("idx_email_email").Error; err != nil { - t.Errorf("Got error when tried to remove index: %+v", err) - } - - if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email") { - t.Errorf("Email's index idx_email_email should be deleted") - } - - if err := 
DB.Model(&Email{}).AddIndex("idx_email_email_and_user_id", "user_id", "email").Error; err != nil { - t.Errorf("Got error when tried to create index: %+v", err) - } - - if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") { - t.Errorf("Email should have index idx_email_email_and_user_id") - } - - if err := DB.Model(&Email{}).RemoveIndex("idx_email_email_and_user_id").Error; err != nil { - t.Errorf("Got error when tried to remove index: %+v", err) - } - - if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") { - t.Errorf("Email's index idx_email_email_and_user_id should be deleted") - } - - if err := DB.Model(&Email{}).AddUniqueIndex("idx_email_email_and_user_id", "user_id", "email").Error; err != nil { - t.Errorf("Got error when tried to create index: %+v", err) - } - - if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") { - t.Errorf("Email should have index idx_email_email_and_user_id") - } - - if DB.Save(&User{Name: "unique_indexes", Emails: []Email{{Email: "user1@example.comiii"}, {Email: "user1@example.com"}, {Email: "user1@example.com"}}}).Error == nil { - t.Errorf("Should get to create duplicate record when having unique index") - } - - var user = User{Name: "sample_user"} - DB.Save(&user) - if DB.Model(&user).Association("Emails").Append(Email{Email: "not-1duplicated@gmail.com"}, Email{Email: "not-duplicated2@gmail.com"}).Error != nil { - t.Errorf("Should get no error when append two emails for user") - } - - if DB.Model(&user).Association("Emails").Append(Email{Email: "duplicated@gmail.com"}, Email{Email: "duplicated@gmail.com"}).Error == nil { - t.Errorf("Should get no duplicated email error when insert duplicated emails for a user") - } - - if err := DB.Model(&Email{}).RemoveIndex("idx_email_email_and_user_id").Error; err != nil { - t.Errorf("Got error when tried to remove index: %+v", err) - } - - if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") { - t.Errorf("Email's index idx_email_email_and_user_id should be deleted") - } - - if DB.Save(&User{Name: "unique_indexes", Emails: []Email{{Email: "user1@example.com"}, {Email: "user1@example.com"}}}).Error != nil { - t.Errorf("Should be able to create duplicated emails after remove unique index") - } -} - -type EmailWithIdx struct { - Id int64 - UserId int64 - Email string `sql:"index:idx_email_agent"` - UserAgent string `sql:"index:idx_email_agent"` - RegisteredAt *time.Time `sql:"unique_index"` - CreatedAt time.Time - UpdatedAt time.Time -} - -func TestAutoMigration(t *testing.T) { - DB.AutoMigrate(&Address{}) - DB.DropTable(&EmailWithIdx{}) - if err := DB.AutoMigrate(&EmailWithIdx{}).Error; err != nil { - t.Errorf("Auto Migrate should not raise any error") - } - - now := time.Now() - DB.Save(&EmailWithIdx{Email: "jinzhu@example.org", UserAgent: "pc", RegisteredAt: &now}) - - scope := DB.NewScope(&EmailWithIdx{}) - if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_agent") { - t.Errorf("Failed to create index") - } - - if !scope.Dialect().HasIndex(scope.TableName(), "uix_email_with_idxes_registered_at") { - t.Errorf("Failed to create index") - } - - var bigemail EmailWithIdx - DB.First(&bigemail, "user_agent = ?", "pc") - if bigemail.Email != "jinzhu@example.org" || bigemail.UserAgent != "pc" || bigemail.RegisteredAt.IsZero() { - t.Error("Big Emails should be saved and fetched correctly") - } -} - -type MultipleIndexes struct { - ID int64 - UserID int64 
`sql:"unique_index:uix_multipleindexes_user_name,uix_multipleindexes_user_email;index:idx_multipleindexes_user_other"` - Name string `sql:"unique_index:uix_multipleindexes_user_name"` - Email string `sql:"unique_index:,uix_multipleindexes_user_email"` - Other string `sql:"index:,idx_multipleindexes_user_other"` -} - -func TestMultipleIndexes(t *testing.T) { - if err := DB.DropTableIfExists(&MultipleIndexes{}).Error; err != nil { - fmt.Printf("Got error when try to delete table multiple_indexes, %+v\n", err) - } - - DB.AutoMigrate(&MultipleIndexes{}) - if err := DB.AutoMigrate(&EmailWithIdx{}).Error; err != nil { - t.Errorf("Auto Migrate should not raise any error") - } - - DB.Save(&MultipleIndexes{UserID: 1, Name: "jinzhu", Email: "jinzhu@example.org", Other: "foo"}) - - scope := DB.NewScope(&MultipleIndexes{}) - if !scope.Dialect().HasIndex(scope.TableName(), "uix_multipleindexes_user_name") { - t.Errorf("Failed to create index") - } - - if !scope.Dialect().HasIndex(scope.TableName(), "uix_multipleindexes_user_email") { - t.Errorf("Failed to create index") - } - - if !scope.Dialect().HasIndex(scope.TableName(), "uix_multiple_indexes_email") { - t.Errorf("Failed to create index") - } - - if !scope.Dialect().HasIndex(scope.TableName(), "idx_multipleindexes_user_other") { - t.Errorf("Failed to create index") - } - - if !scope.Dialect().HasIndex(scope.TableName(), "idx_multiple_indexes_other") { - t.Errorf("Failed to create index") - } - - var mutipleIndexes MultipleIndexes - DB.First(&mutipleIndexes, "name = ?", "jinzhu") - if mutipleIndexes.Email != "jinzhu@example.org" || mutipleIndexes.Name != "jinzhu" { - t.Error("MutipleIndexes should be saved and fetched correctly") - } - - // Check unique constraints - if err := DB.Save(&MultipleIndexes{UserID: 1, Name: "name1", Email: "jinzhu@example.org", Other: "foo"}).Error; err == nil { - t.Error("MultipleIndexes unique index failed") - } - - if err := DB.Save(&MultipleIndexes{UserID: 1, Name: "name1", Email: "foo@example.org", Other: "foo"}).Error; err != nil { - t.Error("MultipleIndexes unique index failed") - } - - if err := DB.Save(&MultipleIndexes{UserID: 2, Name: "name1", Email: "jinzhu@example.org", Other: "foo"}).Error; err == nil { - t.Error("MultipleIndexes unique index failed") - } - - if err := DB.Save(&MultipleIndexes{UserID: 2, Name: "name1", Email: "foo2@example.org", Other: "foo"}).Error; err != nil { - t.Error("MultipleIndexes unique index failed") - } -} - -func TestModifyColumnType(t *testing.T) { - dialect := os.Getenv("GORM_DIALECT") - if dialect != "postgres" && - dialect != "mysql" && - dialect != "mssql" { - t.Skip("Skipping this because only postgres, mysql and mssql support altering a column type") - } - - type ModifyColumnType struct { - gorm.Model - Name1 string `gorm:"length:100"` - Name2 string `gorm:"length:200"` - } - DB.DropTable(&ModifyColumnType{}) - DB.CreateTable(&ModifyColumnType{}) - - name2Field, _ := DB.NewScope(&ModifyColumnType{}).FieldByName("Name2") - name2Type := DB.Dialect().DataTypeOf(name2Field.StructField) - - if err := DB.Model(&ModifyColumnType{}).ModifyColumn("name1", name2Type).Error; err != nil { - t.Errorf("No error should happen when ModifyColumn, but got %v", err) - } -} diff --git a/migrator.go b/migrator.go new file mode 100644 index 00000000..7dddcabf --- /dev/null +++ b/migrator.go @@ -0,0 +1,81 @@ +package gorm + +import ( + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" +) + +// Migrator returns migrator +func (db *DB) Migrator() Migrator { + tx := db.getInstance() + + // apply scopes 
to migrator + for len(tx.Statement.scopes) > 0 { + scopes := tx.Statement.scopes + tx.Statement.scopes = nil + for _, scope := range scopes { + tx = scope(tx) + } + } + + return tx.Dialector.Migrator(tx.Session(&Session{})) +} + +// AutoMigrate run auto migration for given models +func (db *DB) AutoMigrate(dst ...interface{}) error { + return db.Migrator().AutoMigrate(dst...) +} + +// ViewOption view option +type ViewOption struct { + Replace bool + CheckOption string + Query *DB +} + +type ColumnType interface { + Name() string + DatabaseTypeName() string + Length() (length int64, ok bool) + DecimalSize() (precision int64, scale int64, ok bool) + Nullable() (nullable bool, ok bool) +} + +type Migrator interface { + // AutoMigrate + AutoMigrate(dst ...interface{}) error + + // Database + CurrentDatabase() string + FullDataTypeOf(*schema.Field) clause.Expr + + // Tables + CreateTable(dst ...interface{}) error + DropTable(dst ...interface{}) error + HasTable(dst interface{}) bool + RenameTable(oldName, newName interface{}) error + + // Columns + AddColumn(dst interface{}, field string) error + DropColumn(dst interface{}, field string) error + AlterColumn(dst interface{}, field string) error + MigrateColumn(dst interface{}, field *schema.Field, columnType ColumnType) error + HasColumn(dst interface{}, field string) bool + RenameColumn(dst interface{}, oldName, field string) error + ColumnTypes(dst interface{}) ([]ColumnType, error) + + // Views + CreateView(name string, option ViewOption) error + DropView(name string) error + + // Constraints + CreateConstraint(dst interface{}, name string) error + DropConstraint(dst interface{}, name string) error + HasConstraint(dst interface{}, name string) bool + + // Indexes + CreateIndex(dst interface{}, name string) error + DropIndex(dst interface{}, name string) error + HasIndex(dst interface{}, name string) bool + RenameIndex(dst interface{}, oldName, newName string) error +} diff --git a/migrator/migrator.go b/migrator/migrator.go new file mode 100644 index 00000000..1800ab54 --- /dev/null +++ b/migrator/migrator.go @@ -0,0 +1,758 @@ +package migrator + +import ( + "context" + "fmt" + "reflect" + "regexp" + "strings" + + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" +) + +var ( + regRealDataType = regexp.MustCompile(`[^\d](\d+)[^\d]?`) + regFullDataType = regexp.MustCompile(`[^\d]*(\d+)[^\d]?`) +) + +// Migrator m struct +type Migrator struct { + Config +} + +// Config schema config +type Config struct { + CreateIndexAfterCreateTable bool + DB *gorm.DB + gorm.Dialector +} + +type GormDataTypeInterface interface { + GormDBDataType(*gorm.DB, *schema.Field) string +} + +func (m Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error) error { + stmt := &gorm.Statement{DB: m.DB} + if m.DB.Statement != nil { + stmt.Table = m.DB.Statement.Table + stmt.TableExpr = m.DB.Statement.TableExpr + } + + if table, ok := value.(string); ok { + stmt.Table = table + } else if err := stmt.Parse(value); err != nil { + return err + } + + return fc(stmt) +} + +func (m Migrator) DataTypeOf(field *schema.Field) string { + fieldValue := reflect.New(field.IndirectFieldType) + if dataTyper, ok := fieldValue.Interface().(GormDataTypeInterface); ok { + if dataType := dataTyper.GormDBDataType(m.DB, field); dataType != "" { + return dataType + } + } + + return m.Dialector.DataTypeOf(field) +} + +func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) { + expr.SQL = m.DataTypeOf(field) + + if field.NotNull { + expr.SQL += " NOT 
NULL" + } + + if field.Unique { + expr.SQL += " UNIQUE" + } + + if field.HasDefaultValue && (field.DefaultValueInterface != nil || field.DefaultValue != "") { + if field.DefaultValueInterface != nil { + defaultStmt := &gorm.Statement{Vars: []interface{}{field.DefaultValueInterface}} + m.Dialector.BindVarTo(defaultStmt, defaultStmt, field.DefaultValueInterface) + expr.SQL += " DEFAULT " + m.Dialector.Explain(defaultStmt.SQL.String(), field.DefaultValueInterface) + } else if field.DefaultValue != "(-)" { + expr.SQL += " DEFAULT " + field.DefaultValue + } + } + + return +} + +// AutoMigrate +func (m Migrator) AutoMigrate(values ...interface{}) error { + for _, value := range m.ReorderModels(values, true) { + tx := m.DB.Session(&gorm.Session{}) + if !tx.Migrator().HasTable(value) { + if err := tx.Migrator().CreateTable(value); err != nil { + return err + } + } else { + if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) { + columnTypes, _ := m.DB.Migrator().ColumnTypes(value) + + for _, field := range stmt.Schema.FieldsByDBName { + var foundColumn gorm.ColumnType + + for _, columnType := range columnTypes { + if columnType.Name() == field.DBName { + foundColumn = columnType + break + } + } + + if foundColumn == nil { + // not found, add column + if err := tx.Migrator().AddColumn(value, field.DBName); err != nil { + return err + } + } else if err := m.DB.Migrator().MigrateColumn(value, field, foundColumn); err != nil { + // found, smart migrate + return err + } + } + + for _, rel := range stmt.Schema.Relationships.Relations { + if !m.DB.Config.DisableForeignKeyConstraintWhenMigrating { + if constraint := rel.ParseConstraint(); constraint != nil { + if constraint.Schema == stmt.Schema { + if !tx.Migrator().HasConstraint(value, constraint.Name) { + if err := tx.Migrator().CreateConstraint(value, constraint.Name); err != nil { + return err + } + } + } + } + } + + for _, chk := range stmt.Schema.ParseCheckConstraints() { + if !tx.Migrator().HasConstraint(value, chk.Name) { + if err := tx.Migrator().CreateConstraint(value, chk.Name); err != nil { + return err + } + } + } + } + + for _, idx := range stmt.Schema.ParseIndexes() { + if !tx.Migrator().HasIndex(value, idx.Name) { + if err := tx.Migrator().CreateIndex(value, idx.Name); err != nil { + return err + } + } + } + + return nil + }); err != nil { + return err + } + } + } + + return nil +} + +func (m Migrator) CreateTable(values ...interface{}) error { + for _, value := range m.ReorderModels(values, false) { + tx := m.DB.Session(&gorm.Session{}) + if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) { + var ( + createTableSQL = "CREATE TABLE ? (" + values = []interface{}{m.CurrentTable(stmt)} + hasPrimaryKeyInDataType bool + ) + + for _, dbName := range stmt.Schema.DBNames { + field := stmt.Schema.FieldsByDBName[dbName] + createTableSQL += "? ?" 
+ hasPrimaryKeyInDataType = hasPrimaryKeyInDataType || strings.Contains(strings.ToUpper(string(field.DataType)), "PRIMARY KEY") + values = append(values, clause.Column{Name: dbName}, m.DB.Migrator().FullDataTypeOf(field)) + createTableSQL += "," + } + + if !hasPrimaryKeyInDataType && len(stmt.Schema.PrimaryFields) > 0 { + createTableSQL += "PRIMARY KEY ?," + primaryKeys := []interface{}{} + for _, field := range stmt.Schema.PrimaryFields { + primaryKeys = append(primaryKeys, clause.Column{Name: field.DBName}) + } + + values = append(values, primaryKeys) + } + + for _, idx := range stmt.Schema.ParseIndexes() { + if m.CreateIndexAfterCreateTable { + defer func(value interface{}, name string) { + if errr == nil { + errr = tx.Migrator().CreateIndex(value, name) + } + }(value, idx.Name) + } else { + if idx.Class != "" { + createTableSQL += idx.Class + " " + } + createTableSQL += "INDEX ? ?" + + if idx.Option != "" { + createTableSQL += " " + idx.Option + } + + createTableSQL += "," + values = append(values, clause.Expr{SQL: idx.Name}, tx.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt)) + } + } + + for _, rel := range stmt.Schema.Relationships.Relations { + if !m.DB.DisableForeignKeyConstraintWhenMigrating { + if constraint := rel.ParseConstraint(); constraint != nil { + if constraint.Schema == stmt.Schema { + sql, vars := buildConstraint(constraint) + createTableSQL += sql + "," + values = append(values, vars...) + } + } + } + } + + for _, chk := range stmt.Schema.ParseCheckConstraints() { + createTableSQL += "CONSTRAINT ? CHECK (?)," + values = append(values, clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}) + } + + createTableSQL = strings.TrimSuffix(createTableSQL, ",") + + createTableSQL += ")" + + if tableOption, ok := m.DB.Get("gorm:table_options"); ok { + createTableSQL += fmt.Sprint(tableOption) + } + + errr = tx.Exec(createTableSQL, values...).Error + return errr + }); err != nil { + return err + } + } + return nil +} + +func (m Migrator) DropTable(values ...interface{}) error { + values = m.ReorderModels(values, false) + for i := len(values) - 1; i >= 0; i-- { + tx := m.DB.Session(&gorm.Session{}) + if err := m.RunWithValue(values[i], func(stmt *gorm.Statement) error { + return tx.Exec("DROP TABLE IF EXISTS ?", m.CurrentTable(stmt)).Error + }); err != nil { + return err + } + } + return nil +} + +func (m Migrator) HasTable(value interface{}) bool { + var count int64 + + m.RunWithValue(value, func(stmt *gorm.Statement) error { + currentDatabase := m.DB.Migrator().CurrentDatabase() + return m.DB.Raw("SELECT count(*) FROM information_schema.tables WHERE table_schema = ? AND table_name = ? AND table_type = ?", currentDatabase, stmt.Table, "BASE TABLE").Row().Scan(&count) + }) + + return count > 0 +} + +func (m Migrator) RenameTable(oldName, newName interface{}) error { + var oldTable, newTable interface{} + if v, ok := oldName.(string); ok { + oldTable = clause.Table{Name: v} + } else { + stmt := &gorm.Statement{DB: m.DB} + if err := stmt.Parse(oldName); err == nil { + oldTable = m.CurrentTable(stmt) + } else { + return err + } + } + + if v, ok := newName.(string); ok { + newTable = clause.Table{Name: v} + } else { + stmt := &gorm.Statement{DB: m.DB} + if err := stmt.Parse(newName); err == nil { + newTable = m.CurrentTable(stmt) + } else { + return err + } + } + + return m.DB.Exec("ALTER TABLE ? 
RENAME TO ?", oldTable, newTable).Error +} + +func (m Migrator) AddColumn(value interface{}, field string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + if field := stmt.Schema.LookUpField(field); field != nil { + if !field.IgnoreMigration { + return m.DB.Exec( + "ALTER TABLE ? ADD ? ?", + m.CurrentTable(stmt), clause.Column{Name: field.DBName}, m.DB.Migrator().FullDataTypeOf(field), + ).Error + } + return nil + } + return fmt.Errorf("failed to look up field with name: %s", field) + }) +} + +func (m Migrator) DropColumn(value interface{}, name string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + if field := stmt.Schema.LookUpField(name); field != nil { + name = field.DBName + } + + return m.DB.Exec( + "ALTER TABLE ? DROP COLUMN ?", m.CurrentTable(stmt), clause.Column{Name: name}, + ).Error + }) +} + +func (m Migrator) AlterColumn(value interface{}, field string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + if field := stmt.Schema.LookUpField(field); field != nil { + fileType := clause.Expr{SQL: m.DataTypeOf(field)} + return m.DB.Exec( + "ALTER TABLE ? ALTER COLUMN ? TYPE ?", + m.CurrentTable(stmt), clause.Column{Name: field.DBName}, fileType, + ).Error + + } + return fmt.Errorf("failed to look up field with name: %s", field) + }) +} + +func (m Migrator) HasColumn(value interface{}, field string) bool { + var count int64 + m.RunWithValue(value, func(stmt *gorm.Statement) error { + currentDatabase := m.DB.Migrator().CurrentDatabase() + name := field + if field := stmt.Schema.LookUpField(field); field != nil { + name = field.DBName + } + + return m.DB.Raw( + "SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_schema = ? AND table_name = ? AND column_name = ?", + currentDatabase, stmt.Table, name, + ).Row().Scan(&count) + }) + + return count > 0 +} + +func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + if field := stmt.Schema.LookUpField(oldName); field != nil { + oldName = field.DBName + } + + if field := stmt.Schema.LookUpField(newName); field != nil { + newName = field.DBName + } + + return m.DB.Exec( + "ALTER TABLE ? RENAME COLUMN ? 
TO ?", + m.CurrentTable(stmt), clause.Column{Name: oldName}, clause.Column{Name: newName}, + ).Error + }) +} + +func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnType gorm.ColumnType) error { + // found, smart migrate + fullDataType := strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL) + realDataType := strings.ToLower(columnType.DatabaseTypeName()) + + alterColumn := false + + // check size + if length, _ := columnType.Length(); length != int64(field.Size) { + if length > 0 && field.Size > 0 { + alterColumn = true + } else { + // has size in data type and not equal + + // Since the following code is frequently called in the for loop, reg optimization is needed here + matches := regRealDataType.FindAllStringSubmatch(realDataType, -1) + matches2 := regFullDataType.FindAllStringSubmatch(fullDataType, -1) + if (len(matches) == 1 && matches[0][1] != fmt.Sprint(field.Size) || !field.PrimaryKey) && (len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length)) { + alterColumn = true + } + } + } + + // check precision + if precision, _, ok := columnType.DecimalSize(); ok && int64(field.Precision) != precision { + if regexp.MustCompile(fmt.Sprintf("[^0-9]%d[^0-9]", field.Precision)).MatchString(m.DataTypeOf(field)) { + alterColumn = true + } + } + + // check nullable + if nullable, ok := columnType.Nullable(); ok && nullable == field.NotNull { + // not primary key & database is nullable + if !field.PrimaryKey && nullable { + alterColumn = true + } + } + + if alterColumn && !field.IgnoreMigration { + return m.DB.Migrator().AlterColumn(value, field.Name) + } + + return nil +} + +func (m Migrator) ColumnTypes(value interface{}) (columnTypes []gorm.ColumnType, err error) { + columnTypes = make([]gorm.ColumnType, 0) + err = m.RunWithValue(value, func(stmt *gorm.Statement) error { + rows, err := m.DB.Session(&gorm.Session{}).Table(stmt.Table).Limit(1).Rows() + if err == nil { + defer rows.Close() + rawColumnTypes, err := rows.ColumnTypes() + if err == nil { + for _, c := range rawColumnTypes { + columnTypes = append(columnTypes, c) + } + } + } + return err + }) + return +} + +func (m Migrator) CreateView(name string, option gorm.ViewOption) error { + return gorm.ErrNotImplemented +} + +func (m Migrator) DropView(name string) error { + return gorm.ErrNotImplemented +} + +func buildConstraint(constraint *schema.Constraint) (sql string, results []interface{}) { + sql = "CONSTRAINT ? FOREIGN KEY ? REFERENCES ??" 
+ if constraint.OnDelete != "" { + sql += " ON DELETE " + constraint.OnDelete + } + + if constraint.OnUpdate != "" { + sql += " ON UPDATE " + constraint.OnUpdate + } + + var foreignKeys, references []interface{} + for _, field := range constraint.ForeignKeys { + foreignKeys = append(foreignKeys, clause.Column{Name: field.DBName}) + } + + for _, field := range constraint.References { + references = append(references, clause.Column{Name: field.DBName}) + } + results = append(results, clause.Table{Name: constraint.Name}, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references) + return +} + +func (m Migrator) GuessConstraintAndTable(stmt *gorm.Statement, name string) (_ *schema.Constraint, _ *schema.Check, table string) { + if stmt.Schema == nil { + return nil, nil, stmt.Table + } + + checkConstraints := stmt.Schema.ParseCheckConstraints() + if chk, ok := checkConstraints[name]; ok { + return nil, &chk, stmt.Table + } + + getTable := func(rel *schema.Relationship) string { + switch rel.Type { + case schema.HasOne, schema.HasMany: + return rel.FieldSchema.Table + case schema.Many2Many: + return rel.JoinTable.Table + } + return stmt.Table + } + + for _, rel := range stmt.Schema.Relationships.Relations { + if constraint := rel.ParseConstraint(); constraint != nil && constraint.Name == name { + return constraint, nil, getTable(rel) + } + } + + if field := stmt.Schema.LookUpField(name); field != nil { + for _, cc := range checkConstraints { + if cc.Field == field { + return nil, &cc, stmt.Table + } + } + + for _, rel := range stmt.Schema.Relationships.Relations { + if constraint := rel.ParseConstraint(); constraint != nil && rel.Field == field { + return constraint, nil, getTable(rel) + } + } + } + + return nil, nil, stmt.Schema.Table +} + +func (m Migrator) CreateConstraint(value interface{}, name string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + constraint, chk, table := m.GuessConstraintAndTable(stmt, name) + if chk != nil { + return m.DB.Exec( + "ALTER TABLE ? ADD CONSTRAINT ? CHECK (?)", + m.CurrentTable(stmt), clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}, + ).Error + } + + if constraint != nil { + var vars = []interface{}{clause.Table{Name: table}} + if stmt.TableExpr != nil { + vars[0] = stmt.TableExpr + } + sql, values := buildConstraint(constraint) + return m.DB.Exec("ALTER TABLE ? ADD "+sql, append(vars, values...)...).Error + } + + return nil + }) +} + +func (m Migrator) DropConstraint(value interface{}, name string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + constraint, chk, table := m.GuessConstraintAndTable(stmt, name) + if constraint != nil { + name = constraint.Name + } else if chk != nil { + name = chk.Name + } + return m.DB.Exec("ALTER TABLE ? DROP CONSTRAINT ?", clause.Table{Name: table}, clause.Column{Name: name}).Error + }) +} + +func (m Migrator) HasConstraint(value interface{}, name string) bool { + var count int64 + m.RunWithValue(value, func(stmt *gorm.Statement) error { + currentDatabase := m.DB.Migrator().CurrentDatabase() + constraint, chk, table := m.GuessConstraintAndTable(stmt, name) + if constraint != nil { + name = constraint.Name + } else if chk != nil { + name = chk.Name + } + + return m.DB.Raw( + "SELECT count(*) FROM INFORMATION_SCHEMA.table_constraints WHERE constraint_schema = ? AND table_name = ? 
AND constraint_name = ?", + currentDatabase, table, name, + ).Row().Scan(&count) + }) + + return count > 0 +} + +func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) { + for _, opt := range opts { + str := stmt.Quote(opt.DBName) + if opt.Expression != "" { + str = opt.Expression + } else if opt.Length > 0 { + str += fmt.Sprintf("(%d)", opt.Length) + } + + if opt.Collate != "" { + str += " COLLATE " + opt.Collate + } + + if opt.Sort != "" { + str += " " + opt.Sort + } + results = append(results, clause.Expr{SQL: str}) + } + return +} + +type BuildIndexOptionsInterface interface { + BuildIndexOptions([]schema.IndexOption, *gorm.Statement) []interface{} +} + +func (m Migrator) CreateIndex(value interface{}, name string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + if idx := stmt.Schema.LookIndex(name); idx != nil { + opts := m.DB.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt) + values := []interface{}{clause.Column{Name: idx.Name}, m.CurrentTable(stmt), opts} + + createIndexSQL := "CREATE " + if idx.Class != "" { + createIndexSQL += idx.Class + " " + } + createIndexSQL += "INDEX ? ON ??" + + if idx.Type != "" { + createIndexSQL += " USING " + idx.Type + } + + if idx.Option != "" { + createIndexSQL += " " + idx.Option + } + + return m.DB.Exec(createIndexSQL, values...).Error + } + + return fmt.Errorf("failed to create index with name %v", name) + }) +} + +func (m Migrator) DropIndex(value interface{}, name string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + if idx := stmt.Schema.LookIndex(name); idx != nil { + name = idx.Name + } + + return m.DB.Exec("DROP INDEX ? ON ?", clause.Column{Name: name}, m.CurrentTable(stmt)).Error + }) +} + +func (m Migrator) HasIndex(value interface{}, name string) bool { + var count int64 + m.RunWithValue(value, func(stmt *gorm.Statement) error { + currentDatabase := m.DB.Migrator().CurrentDatabase() + if idx := stmt.Schema.LookIndex(name); idx != nil { + name = idx.Name + } + + return m.DB.Raw( + "SELECT count(*) FROM information_schema.statistics WHERE table_schema = ? AND table_name = ? AND index_name = ?", + currentDatabase, stmt.Table, name, + ).Row().Scan(&count) + }) + + return count > 0 +} + +func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + return m.DB.Exec( + "ALTER TABLE ? RENAME INDEX ? 
TO ?", + m.CurrentTable(stmt), clause.Column{Name: oldName}, clause.Column{Name: newName}, + ).Error + }) +} + +func (m Migrator) CurrentDatabase() (name string) { + m.DB.Raw("SELECT DATABASE()").Row().Scan(&name) + return +} + +// ReorderModels reorder models according to constraint dependencies +func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []interface{}) { + type Dependency struct { + *gorm.Statement + Depends []*schema.Schema + } + + var ( + modelNames, orderedModelNames []string + orderedModelNamesMap = map[string]bool{} + parsedSchemas = map[*schema.Schema]bool{} + valuesMap = map[string]Dependency{} + insertIntoOrderedList func(name string) + parseDependence func(value interface{}, addToList bool) + ) + + parseDependence = func(value interface{}, addToList bool) { + dep := Dependency{ + Statement: &gorm.Statement{DB: m.DB, Dest: value}, + } + beDependedOn := map[*schema.Schema]bool{} + if err := dep.Parse(value); err != nil { + m.DB.Logger.Error(context.Background(), "failed to parse value %#v, got error %v", value, err) + } + if _, ok := parsedSchemas[dep.Statement.Schema]; ok { + return + } + parsedSchemas[dep.Statement.Schema] = true + + for _, rel := range dep.Schema.Relationships.Relations { + if c := rel.ParseConstraint(); c != nil && c.Schema == dep.Statement.Schema && c.Schema != c.ReferenceSchema { + dep.Depends = append(dep.Depends, c.ReferenceSchema) + } + + if rel.Type == schema.HasOne || rel.Type == schema.HasMany { + beDependedOn[rel.FieldSchema] = true + } + + if rel.JoinTable != nil { + // append join value + defer func(rel *schema.Relationship, joinValue interface{}) { + if !beDependedOn[rel.FieldSchema] { + dep.Depends = append(dep.Depends, rel.FieldSchema) + } else { + fieldValue := reflect.New(rel.FieldSchema.ModelType).Interface() + parseDependence(fieldValue, autoAdd) + } + parseDependence(joinValue, autoAdd) + }(rel, reflect.New(rel.JoinTable.ModelType).Interface()) + } + } + + valuesMap[dep.Schema.Table] = dep + + if addToList { + modelNames = append(modelNames, dep.Schema.Table) + } + } + + insertIntoOrderedList = func(name string) { + if _, ok := orderedModelNamesMap[name]; ok { + return // avoid loop + } + orderedModelNamesMap[name] = true + + if autoAdd { + dep := valuesMap[name] + for _, d := range dep.Depends { + if _, ok := valuesMap[d.Table]; ok { + insertIntoOrderedList(d.Table) + } else { + parseDependence(reflect.New(d.ModelType).Interface(), autoAdd) + insertIntoOrderedList(d.Table) + } + } + } + + orderedModelNames = append(orderedModelNames, name) + } + + for _, value := range values { + if v, ok := value.(string); ok { + results = append(results, v) + } else { + parseDependence(value, true) + } + } + + for _, name := range modelNames { + insertIntoOrderedList(name) + } + + for _, name := range orderedModelNames { + results = append(results, valuesMap[name].Statement.Dest) + } + return +} + +func (m Migrator) CurrentTable(stmt *gorm.Statement) interface{} { + if stmt.TableExpr != nil { + return *stmt.TableExpr + } + return clause.Table{Name: stmt.Table} +} diff --git a/model.go b/model.go index f37ff7ea..3334d17c 100644 --- a/model.go +++ b/model.go @@ -2,13 +2,14 @@ package gorm import "time" -// Model base model definition, including fields `ID`, `CreatedAt`, `UpdatedAt`, `DeletedAt`, which could be embedded in your models +// Model a basic GoLang struct which includes the following fields: ID, CreatedAt, UpdatedAt, DeletedAt +// It may be embedded into your model or you may build your own model without it // 
type User struct { // gorm.Model // } type Model struct { - ID uint `gorm:"primary_key"` + ID uint `gorm:"primarykey"` CreatedAt time.Time UpdatedAt time.Time - DeletedAt *time.Time `sql:"index"` + DeletedAt DeletedAt `gorm:"index"` } diff --git a/model_struct.go b/model_struct.go deleted file mode 100644 index 205b8ba4..00000000 --- a/model_struct.go +++ /dev/null @@ -1,610 +0,0 @@ -package gorm - -import ( - "database/sql" - "errors" - "go/ast" - "reflect" - "strings" - "sync" - "time" - - "github.com/jinzhu/inflection" -) - -// DefaultTableNameHandler default table name handler -var DefaultTableNameHandler = func(db *DB, defaultTableName string) string { - return defaultTableName -} - -type safeModelStructsMap struct { - m map[reflect.Type]*ModelStruct - l *sync.RWMutex -} - -func (s *safeModelStructsMap) Set(key reflect.Type, value *ModelStruct) { - s.l.Lock() - defer s.l.Unlock() - s.m[key] = value -} - -func (s *safeModelStructsMap) Get(key reflect.Type) *ModelStruct { - s.l.RLock() - defer s.l.RUnlock() - return s.m[key] -} - -func newModelStructsMap() *safeModelStructsMap { - return &safeModelStructsMap{l: new(sync.RWMutex), m: make(map[reflect.Type]*ModelStruct)} -} - -var modelStructsMap = newModelStructsMap() - -// ModelStruct model definition -type ModelStruct struct { - PrimaryFields []*StructField - StructFields []*StructField - ModelType reflect.Type - defaultTableName string -} - -// TableName get model's table name -func (s *ModelStruct) TableName(db *DB) string { - if s.defaultTableName == "" && db != nil && s.ModelType != nil { - // Set default table name - if tabler, ok := reflect.New(s.ModelType).Interface().(tabler); ok { - s.defaultTableName = tabler.TableName() - } else { - tableName := ToDBName(s.ModelType.Name()) - if db == nil || !db.parent.singularTable { - tableName = inflection.Plural(tableName) - } - s.defaultTableName = tableName - } - } - - return DefaultTableNameHandler(db, s.defaultTableName) -} - -// StructField model field's struct definition -type StructField struct { - DBName string - Name string - Names []string - IsPrimaryKey bool - IsNormal bool - IsIgnored bool - IsScanner bool - HasDefaultValue bool - Tag reflect.StructTag - TagSettings map[string]string - Struct reflect.StructField - IsForeignKey bool - Relationship *Relationship -} - -func (structField *StructField) clone() *StructField { - clone := &StructField{ - DBName: structField.DBName, - Name: structField.Name, - Names: structField.Names, - IsPrimaryKey: structField.IsPrimaryKey, - IsNormal: structField.IsNormal, - IsIgnored: structField.IsIgnored, - IsScanner: structField.IsScanner, - HasDefaultValue: structField.HasDefaultValue, - Tag: structField.Tag, - TagSettings: map[string]string{}, - Struct: structField.Struct, - IsForeignKey: structField.IsForeignKey, - } - - if structField.Relationship != nil { - relationship := *structField.Relationship - clone.Relationship = &relationship - } - - for key, value := range structField.TagSettings { - clone.TagSettings[key] = value - } - - return clone -} - -// Relationship described the relationship between models -type Relationship struct { - Kind string - PolymorphicType string - PolymorphicDBName string - PolymorphicValue string - ForeignFieldNames []string - ForeignDBNames []string - AssociationForeignFieldNames []string - AssociationForeignDBNames []string - JoinTableHandler JoinTableHandlerInterface -} - -func getForeignField(column string, fields []*StructField) *StructField { - for _, field := range fields { - if field.Name == column || 
field.DBName == column || field.DBName == ToDBName(column) { - return field - } - } - return nil -} - -// GetModelStruct get value's model struct, relationships based on struct and tag definition -func (scope *Scope) GetModelStruct() *ModelStruct { - var modelStruct ModelStruct - // Scope value can't be nil - if scope.Value == nil { - return &modelStruct - } - - reflectType := reflect.ValueOf(scope.Value).Type() - for reflectType.Kind() == reflect.Slice || reflectType.Kind() == reflect.Ptr { - reflectType = reflectType.Elem() - } - - // Scope value need to be a struct - if reflectType.Kind() != reflect.Struct { - return &modelStruct - } - - // Get Cached model struct - if value := modelStructsMap.Get(reflectType); value != nil { - return value - } - - modelStruct.ModelType = reflectType - - // Get all fields - for i := 0; i < reflectType.NumField(); i++ { - if fieldStruct := reflectType.Field(i); ast.IsExported(fieldStruct.Name) { - field := &StructField{ - Struct: fieldStruct, - Name: fieldStruct.Name, - Names: []string{fieldStruct.Name}, - Tag: fieldStruct.Tag, - TagSettings: parseTagSetting(fieldStruct.Tag), - } - - // is ignored field - if _, ok := field.TagSettings["-"]; ok { - field.IsIgnored = true - } else { - if _, ok := field.TagSettings["PRIMARY_KEY"]; ok { - field.IsPrimaryKey = true - modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field) - } - - if _, ok := field.TagSettings["DEFAULT"]; ok { - field.HasDefaultValue = true - } - - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok && !field.IsPrimaryKey { - field.HasDefaultValue = true - } - - indirectType := fieldStruct.Type - for indirectType.Kind() == reflect.Ptr { - indirectType = indirectType.Elem() - } - - fieldValue := reflect.New(indirectType).Interface() - if _, isScanner := fieldValue.(sql.Scanner); isScanner { - // is scanner - field.IsScanner, field.IsNormal = true, true - if indirectType.Kind() == reflect.Struct { - for i := 0; i < indirectType.NumField(); i++ { - for key, value := range parseTagSetting(indirectType.Field(i).Tag) { - if _, ok := field.TagSettings[key]; !ok { - field.TagSettings[key] = value - } - } - } - } - } else if _, isTime := fieldValue.(*time.Time); isTime { - // is time - field.IsNormal = true - } else if _, ok := field.TagSettings["EMBEDDED"]; ok || fieldStruct.Anonymous { - // is embedded struct - for _, subField := range scope.New(fieldValue).GetModelStruct().StructFields { - subField = subField.clone() - subField.Names = append([]string{fieldStruct.Name}, subField.Names...) 
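// A minimal sketch (not part of the original model_struct.go; type names are illustrative)
// of what the embedded-struct branch above produces under GORM v1 tag parsing: each
// subfield is cloned, its Names path is prefixed with the outer field name, and the
// EMBEDDED_PREFIX setting (handled just below) is prepended to its column name.
//
//	type Address struct {
//		City string
//	}
//	type Customer struct {
//		ID      uint
//		Address Address `gorm:"embedded;embedded_prefix:addr_"`
//	}
//	// Customer's City subfield ends up with Names == []string{"Address", "City"}
//	// and DBName == "addr_city".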
- if prefix, ok := field.TagSettings["EMBEDDED_PREFIX"]; ok { - subField.DBName = prefix + subField.DBName - } - - if subField.IsPrimaryKey { - if _, ok := subField.TagSettings["PRIMARY_KEY"]; ok { - modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, subField) - } else { - subField.IsPrimaryKey = false - } - } - - if subField.Relationship != nil && subField.Relationship.JoinTableHandler != nil { - if joinTableHandler, ok := subField.Relationship.JoinTableHandler.(*JoinTableHandler); ok { - newJoinTableHandler := &JoinTableHandler{} - newJoinTableHandler.Setup(subField.Relationship, joinTableHandler.TableName, reflectType, joinTableHandler.Destination.ModelType) - subField.Relationship.JoinTableHandler = newJoinTableHandler - } - } - - modelStruct.StructFields = append(modelStruct.StructFields, subField) - } - continue - } else { - // build relationships - switch indirectType.Kind() { - case reflect.Slice: - defer func(field *StructField) { - var ( - relationship = &Relationship{} - toScope = scope.New(reflect.New(field.Struct.Type).Interface()) - foreignKeys []string - associationForeignKeys []string - elemType = field.Struct.Type - ) - - if foreignKey := field.TagSettings["FOREIGNKEY"]; foreignKey != "" { - foreignKeys = strings.Split(field.TagSettings["FOREIGNKEY"], ",") - } - - if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" { - associationForeignKeys = strings.Split(field.TagSettings["ASSOCIATIONFOREIGNKEY"], ",") - } - - for elemType.Kind() == reflect.Slice || elemType.Kind() == reflect.Ptr { - elemType = elemType.Elem() - } - - if elemType.Kind() == reflect.Struct { - if many2many := field.TagSettings["MANY2MANY"]; many2many != "" { - relationship.Kind = "many_to_many" - - // if no foreign keys defined with tag - if len(foreignKeys) == 0 { - for _, field := range modelStruct.PrimaryFields { - foreignKeys = append(foreignKeys, field.DBName) - } - } - - for _, foreignKey := range foreignKeys { - if foreignField := getForeignField(foreignKey, modelStruct.StructFields); foreignField != nil { - // source foreign keys (db names) - relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.DBName) - // join table foreign keys for source - joinTableDBName := ToDBName(reflectType.Name()) + "_" + foreignField.DBName - relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBName) - } - } - - // if no association foreign keys defined with tag - if len(associationForeignKeys) == 0 { - for _, field := range toScope.PrimaryFields() { - associationForeignKeys = append(associationForeignKeys, field.DBName) - } - } - - for _, name := range associationForeignKeys { - - // In order to allow self-referencing many2many tables, the name - // may be followed by "=" to allow renaming the column - parts := strings.Split(name, "=") - name = parts[0] - - if field, ok := toScope.FieldByName(name); ok { - // association foreign keys (db names) - relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.DBName) - - // If a new name was provided for the field, use it - name = ToDBName(elemType.Name()) + "_" + field.DBName - if len(parts) > 1 { - name = parts[1] - } - - // join table foreign keys for association - joinTableDBName := name - relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, joinTableDBName) - } - } - - joinTableHandler := JoinTableHandler{} - joinTableHandler.Setup(relationship, many2many, reflectType, elemType) - 
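// An illustrative example (assumption, not taken from this diff) of the join-table
// column naming derived by the many2many branch above: the source side combines the
// model's name with its primary key DBName, the association side combines the element
// type's name with its primary key DBName.
//
//	type Post struct {
//		ID   uint
//		Tags []Tag `gorm:"many2many:post_tags"`
//	}
//	type Tag struct {
//		ID uint
//	}
//	// the join table "post_tags" would get the columns "post_id" and "tag_id".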
relationship.JoinTableHandler = &joinTableHandler - field.Relationship = relationship - } else { - // User has many comments, associationType is User, comment use UserID as foreign key - var associationType = reflectType.Name() - var toFields = toScope.GetStructFields() - relationship.Kind = "has_many" - - if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" { - // Dog has many toys, tag polymorphic is Owner, then associationType is Owner - // Toy use OwnerID, OwnerType ('dogs') as foreign key - if polymorphicType := getForeignField(polymorphic+"Type", toFields); polymorphicType != nil { - associationType = polymorphic - relationship.PolymorphicType = polymorphicType.Name - relationship.PolymorphicDBName = polymorphicType.DBName - // if Dog has multiple set of toys set name of the set (instead of default 'dogs') - if value, ok := field.TagSettings["POLYMORPHIC_VALUE"]; ok { - relationship.PolymorphicValue = value - } else { - relationship.PolymorphicValue = scope.TableName() - } - polymorphicType.IsForeignKey = true - } - } - - // if no foreign keys defined with tag - if len(foreignKeys) == 0 { - // if no association foreign keys defined with tag - if len(associationForeignKeys) == 0 { - for _, field := range modelStruct.PrimaryFields { - foreignKeys = append(foreignKeys, associationType+field.Name) - associationForeignKeys = append(associationForeignKeys, field.Name) - } - } else { - // generate foreign keys from defined association foreign keys - for _, scopeFieldName := range associationForeignKeys { - if foreignField := getForeignField(scopeFieldName, modelStruct.StructFields); foreignField != nil { - foreignKeys = append(foreignKeys, associationType+foreignField.Name) - associationForeignKeys = append(associationForeignKeys, foreignField.Name) - } - } - } - } else { - // generate association foreign keys from foreign keys - if len(associationForeignKeys) == 0 { - for _, foreignKey := range foreignKeys { - if strings.HasPrefix(foreignKey, associationType) { - associationForeignKey := strings.TrimPrefix(foreignKey, associationType) - if foreignField := getForeignField(associationForeignKey, modelStruct.StructFields); foreignField != nil { - associationForeignKeys = append(associationForeignKeys, associationForeignKey) - } - } - } - if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 { - associationForeignKeys = []string{scope.PrimaryKey()} - } - } else if len(foreignKeys) != len(associationForeignKeys) { - scope.Err(errors.New("invalid foreign keys, should have same length")) - return - } - } - - for idx, foreignKey := range foreignKeys { - if foreignField := getForeignField(foreignKey, toFields); foreignField != nil { - if associationField := getForeignField(associationForeignKeys[idx], modelStruct.StructFields); associationField != nil { - // source foreign keys - foreignField.IsForeignKey = true - relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, associationField.Name) - relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationField.DBName) - - // association foreign keys - relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) - relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) - } - } - } - - if len(relationship.ForeignFieldNames) != 0 { - field.Relationship = relationship - } - } - } else { - field.IsNormal = true - } - }(field) - case reflect.Struct: - defer func(field *StructField) { - var ( - // 
user has one profile, associationType is User, profile use UserID as foreign key - // user belongs to profile, associationType is Profile, user use ProfileID as foreign key - associationType = reflectType.Name() - relationship = &Relationship{} - toScope = scope.New(reflect.New(field.Struct.Type).Interface()) - toFields = toScope.GetStructFields() - tagForeignKeys []string - tagAssociationForeignKeys []string - ) - - if foreignKey := field.TagSettings["FOREIGNKEY"]; foreignKey != "" { - tagForeignKeys = strings.Split(field.TagSettings["FOREIGNKEY"], ",") - } - - if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" { - tagAssociationForeignKeys = strings.Split(field.TagSettings["ASSOCIATIONFOREIGNKEY"], ",") - } - - if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" { - // Cat has one toy, tag polymorphic is Owner, then associationType is Owner - // Toy use OwnerID, OwnerType ('cats') as foreign key - if polymorphicType := getForeignField(polymorphic+"Type", toFields); polymorphicType != nil { - associationType = polymorphic - relationship.PolymorphicType = polymorphicType.Name - relationship.PolymorphicDBName = polymorphicType.DBName - // if Cat has several different types of toys set name for each (instead of default 'cats') - if value, ok := field.TagSettings["POLYMORPHIC_VALUE"]; ok { - relationship.PolymorphicValue = value - } else { - relationship.PolymorphicValue = scope.TableName() - } - polymorphicType.IsForeignKey = true - } - } - - // Has One - { - var foreignKeys = tagForeignKeys - var associationForeignKeys = tagAssociationForeignKeys - // if no foreign keys defined with tag - if len(foreignKeys) == 0 { - // if no association foreign keys defined with tag - if len(associationForeignKeys) == 0 { - for _, primaryField := range modelStruct.PrimaryFields { - foreignKeys = append(foreignKeys, associationType+primaryField.Name) - associationForeignKeys = append(associationForeignKeys, primaryField.Name) - } - } else { - // generate foreign keys form association foreign keys - for _, associationForeignKey := range tagAssociationForeignKeys { - if foreignField := getForeignField(associationForeignKey, modelStruct.StructFields); foreignField != nil { - foreignKeys = append(foreignKeys, associationType+foreignField.Name) - associationForeignKeys = append(associationForeignKeys, foreignField.Name) - } - } - } - } else { - // generate association foreign keys from foreign keys - if len(associationForeignKeys) == 0 { - for _, foreignKey := range foreignKeys { - if strings.HasPrefix(foreignKey, associationType) { - associationForeignKey := strings.TrimPrefix(foreignKey, associationType) - if foreignField := getForeignField(associationForeignKey, modelStruct.StructFields); foreignField != nil { - associationForeignKeys = append(associationForeignKeys, associationForeignKey) - } - } - } - if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 { - associationForeignKeys = []string{scope.PrimaryKey()} - } - } else if len(foreignKeys) != len(associationForeignKeys) { - scope.Err(errors.New("invalid foreign keys, should have same length")) - return - } - } - - for idx, foreignKey := range foreignKeys { - if foreignField := getForeignField(foreignKey, toFields); foreignField != nil { - if scopeField := getForeignField(associationForeignKeys[idx], modelStruct.StructFields); scopeField != nil { - foreignField.IsForeignKey = true - // source foreign keys - relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, 
scopeField.Name) - relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, scopeField.DBName) - - // association foreign keys - relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) - relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) - } - } - } - } - - if len(relationship.ForeignFieldNames) != 0 { - relationship.Kind = "has_one" - field.Relationship = relationship - } else { - var foreignKeys = tagForeignKeys - var associationForeignKeys = tagAssociationForeignKeys - - if len(foreignKeys) == 0 { - // generate foreign keys & association foreign keys - if len(associationForeignKeys) == 0 { - for _, primaryField := range toScope.PrimaryFields() { - foreignKeys = append(foreignKeys, field.Name+primaryField.Name) - associationForeignKeys = append(associationForeignKeys, primaryField.Name) - } - } else { - // generate foreign keys with association foreign keys - for _, associationForeignKey := range associationForeignKeys { - if foreignField := getForeignField(associationForeignKey, toFields); foreignField != nil { - foreignKeys = append(foreignKeys, field.Name+foreignField.Name) - associationForeignKeys = append(associationForeignKeys, foreignField.Name) - } - } - } - } else { - // generate foreign keys & association foreign keys - if len(associationForeignKeys) == 0 { - for _, foreignKey := range foreignKeys { - if strings.HasPrefix(foreignKey, field.Name) { - associationForeignKey := strings.TrimPrefix(foreignKey, field.Name) - if foreignField := getForeignField(associationForeignKey, toFields); foreignField != nil { - associationForeignKeys = append(associationForeignKeys, associationForeignKey) - } - } - } - if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 { - associationForeignKeys = []string{toScope.PrimaryKey()} - } - } else if len(foreignKeys) != len(associationForeignKeys) { - scope.Err(errors.New("invalid foreign keys, should have same length")) - return - } - } - - for idx, foreignKey := range foreignKeys { - if foreignField := getForeignField(foreignKey, modelStruct.StructFields); foreignField != nil { - if associationField := getForeignField(associationForeignKeys[idx], toFields); associationField != nil { - foreignField.IsForeignKey = true - - // association foreign keys - relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, associationField.Name) - relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationField.DBName) - - // source foreign keys - relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) - relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) - } - } - } - - if len(relationship.ForeignFieldNames) != 0 { - relationship.Kind = "belongs_to" - field.Relationship = relationship - } - } - }(field) - default: - field.IsNormal = true - } - } - } - - // Even it is ignored, also possible to decode db value into the field - if value, ok := field.TagSettings["COLUMN"]; ok { - field.DBName = value - } else { - field.DBName = ToDBName(fieldStruct.Name) - } - - modelStruct.StructFields = append(modelStruct.StructFields, field) - } - } - - if len(modelStruct.PrimaryFields) == 0 { - if field := getForeignField("id", modelStruct.StructFields); field != nil { - field.IsPrimaryKey = true - modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field) - } - } - - modelStructsMap.Set(reflectType, &modelStruct) - - 
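// Note on the caching step above: the parsed ModelStruct is stored in modelStructsMap
// keyed by reflect.Type behind an RWMutex, so later scopes over the same model type
// skip the whole reflection pass. A hedged sketch of the effect, reusing the
// hypothetical Customer type from the earlier example (GORM v1 API):
//
//	db.NewScope(&Customer{}).GetModelStruct() // first call: parse + cache
//	db.NewScope(&Customer{}).GetModelStruct() // later calls: served from the cache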
return &modelStruct -} - -// GetStructFields get model's field structs -func (scope *Scope) GetStructFields() (fields []*StructField) { - return scope.GetModelStruct().StructFields -} - -func parseTagSetting(tags reflect.StructTag) map[string]string { - setting := map[string]string{} - for _, str := range []string{tags.Get("sql"), tags.Get("gorm")} { - tags := strings.Split(str, ";") - for _, value := range tags { - v := strings.Split(value, ":") - k := strings.TrimSpace(strings.ToUpper(v[0])) - if len(v) >= 2 { - setting[k] = strings.Join(v[1:], ":") - } else { - setting[k] = k - } - } - } - return setting -} diff --git a/multi_primary_keys_test.go b/multi_primary_keys_test.go deleted file mode 100644 index 32a14772..00000000 --- a/multi_primary_keys_test.go +++ /dev/null @@ -1,381 +0,0 @@ -package gorm_test - -import ( - "os" - "reflect" - "sort" - "testing" -) - -type Blog struct { - ID uint `gorm:"primary_key"` - Locale string `gorm:"primary_key"` - Subject string - Body string - Tags []Tag `gorm:"many2many:blog_tags;"` - SharedTags []Tag `gorm:"many2many:shared_blog_tags;ForeignKey:id;AssociationForeignKey:id"` - LocaleTags []Tag `gorm:"many2many:locale_blog_tags;ForeignKey:id,locale;AssociationForeignKey:id"` -} - -type Tag struct { - ID uint `gorm:"primary_key"` - Locale string `gorm:"primary_key"` - Value string - Blogs []*Blog `gorm:"many2many:blogs_tags"` -} - -func compareTags(tags []Tag, contents []string) bool { - var tagContents []string - for _, tag := range tags { - tagContents = append(tagContents, tag.Value) - } - sort.Strings(tagContents) - sort.Strings(contents) - return reflect.DeepEqual(tagContents, contents) -} - -func TestManyToManyWithMultiPrimaryKeys(t *testing.T) { - if dialect := os.Getenv("GORM_DIALECT"); dialect != "" && dialect != "sqlite" && dialect != "mssql" { - DB.DropTable(&Blog{}, &Tag{}) - DB.DropTable("blog_tags") - DB.CreateTable(&Blog{}, &Tag{}) - blog := Blog{ - Locale: "ZH", - Subject: "subject", - Body: "body", - Tags: []Tag{ - {Locale: "ZH", Value: "tag1"}, - {Locale: "ZH", Value: "tag2"}, - }, - } - - DB.Save(&blog) - if !compareTags(blog.Tags, []string{"tag1", "tag2"}) { - t.Errorf("Blog should has two tags") - } - - // Append - var tag3 = &Tag{Locale: "ZH", Value: "tag3"} - DB.Model(&blog).Association("Tags").Append([]*Tag{tag3}) - if !compareTags(blog.Tags, []string{"tag1", "tag2", "tag3"}) { - t.Errorf("Blog should has three tags after Append") - } - - if DB.Model(&blog).Association("Tags").Count() != 3 { - t.Errorf("Blog should has three tags after Append") - } - - var tags []Tag - DB.Model(&blog).Related(&tags, "Tags") - if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) { - t.Errorf("Should find 3 tags with Related") - } - - var blog1 Blog - DB.Preload("Tags").Find(&blog1) - if !compareTags(blog1.Tags, []string{"tag1", "tag2", "tag3"}) { - t.Errorf("Preload many2many relations") - } - - // Replace - var tag5 = &Tag{Locale: "ZH", Value: "tag5"} - var tag6 = &Tag{Locale: "ZH", Value: "tag6"} - DB.Model(&blog).Association("Tags").Replace(tag5, tag6) - var tags2 []Tag - DB.Model(&blog).Related(&tags2, "Tags") - if !compareTags(tags2, []string{"tag5", "tag6"}) { - t.Errorf("Should find 2 tags after Replace") - } - - if DB.Model(&blog).Association("Tags").Count() != 2 { - t.Errorf("Blog should has three tags after Replace") - } - - // Delete - DB.Model(&blog).Association("Tags").Delete(tag5) - var tags3 []Tag - DB.Model(&blog).Related(&tags3, "Tags") - if !compareTags(tags3, []string{"tag6"}) { - t.Errorf("Should find 1 tags after 
Delete") - } - - if DB.Model(&blog).Association("Tags").Count() != 1 { - t.Errorf("Blog should has three tags after Delete") - } - - DB.Model(&blog).Association("Tags").Delete(tag3) - var tags4 []Tag - DB.Model(&blog).Related(&tags4, "Tags") - if !compareTags(tags4, []string{"tag6"}) { - t.Errorf("Tag should not be deleted when Delete with a unrelated tag") - } - - // Clear - DB.Model(&blog).Association("Tags").Clear() - if DB.Model(&blog).Association("Tags").Count() != 0 { - t.Errorf("All tags should be cleared") - } - } -} - -func TestManyToManyWithCustomizedForeignKeys(t *testing.T) { - if dialect := os.Getenv("GORM_DIALECT"); dialect != "" && dialect != "sqlite" && dialect != "mssql" { - DB.DropTable(&Blog{}, &Tag{}) - DB.DropTable("shared_blog_tags") - DB.CreateTable(&Blog{}, &Tag{}) - blog := Blog{ - Locale: "ZH", - Subject: "subject", - Body: "body", - SharedTags: []Tag{ - {Locale: "ZH", Value: "tag1"}, - {Locale: "ZH", Value: "tag2"}, - }, - } - DB.Save(&blog) - - blog2 := Blog{ - ID: blog.ID, - Locale: "EN", - } - DB.Create(&blog2) - - if !compareTags(blog.SharedTags, []string{"tag1", "tag2"}) { - t.Errorf("Blog should has two tags") - } - - // Append - var tag3 = &Tag{Locale: "ZH", Value: "tag3"} - DB.Model(&blog).Association("SharedTags").Append([]*Tag{tag3}) - if !compareTags(blog.SharedTags, []string{"tag1", "tag2", "tag3"}) { - t.Errorf("Blog should has three tags after Append") - } - - if DB.Model(&blog).Association("SharedTags").Count() != 3 { - t.Errorf("Blog should has three tags after Append") - } - - if DB.Model(&blog2).Association("SharedTags").Count() != 3 { - t.Errorf("Blog should has three tags after Append") - } - - var tags []Tag - DB.Model(&blog).Related(&tags, "SharedTags") - if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) { - t.Errorf("Should find 3 tags with Related") - } - - DB.Model(&blog2).Related(&tags, "SharedTags") - if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) { - t.Errorf("Should find 3 tags with Related") - } - - var blog1 Blog - DB.Preload("SharedTags").Find(&blog1) - if !compareTags(blog1.SharedTags, []string{"tag1", "tag2", "tag3"}) { - t.Errorf("Preload many2many relations") - } - - var tag4 = &Tag{Locale: "ZH", Value: "tag4"} - DB.Model(&blog2).Association("SharedTags").Append(tag4) - - DB.Model(&blog).Related(&tags, "SharedTags") - if !compareTags(tags, []string{"tag1", "tag2", "tag3", "tag4"}) { - t.Errorf("Should find 3 tags with Related") - } - - DB.Model(&blog2).Related(&tags, "SharedTags") - if !compareTags(tags, []string{"tag1", "tag2", "tag3", "tag4"}) { - t.Errorf("Should find 3 tags with Related") - } - - // Replace - var tag5 = &Tag{Locale: "ZH", Value: "tag5"} - var tag6 = &Tag{Locale: "ZH", Value: "tag6"} - DB.Model(&blog2).Association("SharedTags").Replace(tag5, tag6) - var tags2 []Tag - DB.Model(&blog).Related(&tags2, "SharedTags") - if !compareTags(tags2, []string{"tag5", "tag6"}) { - t.Errorf("Should find 2 tags after Replace") - } - - DB.Model(&blog2).Related(&tags2, "SharedTags") - if !compareTags(tags2, []string{"tag5", "tag6"}) { - t.Errorf("Should find 2 tags after Replace") - } - - if DB.Model(&blog).Association("SharedTags").Count() != 2 { - t.Errorf("Blog should has three tags after Replace") - } - - // Delete - DB.Model(&blog).Association("SharedTags").Delete(tag5) - var tags3 []Tag - DB.Model(&blog).Related(&tags3, "SharedTags") - if !compareTags(tags3, []string{"tag6"}) { - t.Errorf("Should find 1 tags after Delete") - } - - if DB.Model(&blog).Association("SharedTags").Count() != 1 { - t.Errorf("Blog 
should has three tags after Delete") - } - - DB.Model(&blog2).Association("SharedTags").Delete(tag3) - var tags4 []Tag - DB.Model(&blog).Related(&tags4, "SharedTags") - if !compareTags(tags4, []string{"tag6"}) { - t.Errorf("Tag should not be deleted when Delete with a unrelated tag") - } - - // Clear - DB.Model(&blog2).Association("SharedTags").Clear() - if DB.Model(&blog).Association("SharedTags").Count() != 0 { - t.Errorf("All tags should be cleared") - } - } -} - -func TestManyToManyWithCustomizedForeignKeys2(t *testing.T) { - if dialect := os.Getenv("GORM_DIALECT"); dialect != "" && dialect != "sqlite" && dialect != "mssql" { - DB.DropTable(&Blog{}, &Tag{}) - DB.DropTable("locale_blog_tags") - DB.CreateTable(&Blog{}, &Tag{}) - blog := Blog{ - Locale: "ZH", - Subject: "subject", - Body: "body", - LocaleTags: []Tag{ - {Locale: "ZH", Value: "tag1"}, - {Locale: "ZH", Value: "tag2"}, - }, - } - DB.Save(&blog) - - blog2 := Blog{ - ID: blog.ID, - Locale: "EN", - } - DB.Create(&blog2) - - // Append - var tag3 = &Tag{Locale: "ZH", Value: "tag3"} - DB.Model(&blog).Association("LocaleTags").Append([]*Tag{tag3}) - if !compareTags(blog.LocaleTags, []string{"tag1", "tag2", "tag3"}) { - t.Errorf("Blog should has three tags after Append") - } - - if DB.Model(&blog).Association("LocaleTags").Count() != 3 { - t.Errorf("Blog should has three tags after Append") - } - - if DB.Model(&blog2).Association("LocaleTags").Count() != 0 { - t.Errorf("EN Blog should has 0 tags after ZH Blog Append") - } - - var tags []Tag - DB.Model(&blog).Related(&tags, "LocaleTags") - if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) { - t.Errorf("Should find 3 tags with Related") - } - - DB.Model(&blog2).Related(&tags, "LocaleTags") - if len(tags) != 0 { - t.Errorf("Should find 0 tags with Related for EN Blog") - } - - var blog1 Blog - DB.Preload("LocaleTags").Find(&blog1, "locale = ? AND id = ?", "ZH", blog.ID) - if !compareTags(blog1.LocaleTags, []string{"tag1", "tag2", "tag3"}) { - t.Errorf("Preload many2many relations") - } - - var tag4 = &Tag{Locale: "ZH", Value: "tag4"} - DB.Model(&blog2).Association("LocaleTags").Append(tag4) - - DB.Model(&blog).Related(&tags, "LocaleTags") - if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) { - t.Errorf("Should find 3 tags with Related for EN Blog") - } - - DB.Model(&blog2).Related(&tags, "LocaleTags") - if !compareTags(tags, []string{"tag4"}) { - t.Errorf("Should find 1 tags with Related for EN Blog") - } - - // Replace - var tag5 = &Tag{Locale: "ZH", Value: "tag5"} - var tag6 = &Tag{Locale: "ZH", Value: "tag6"} - DB.Model(&blog2).Association("LocaleTags").Replace(tag5, tag6) - - var tags2 []Tag - DB.Model(&blog).Related(&tags2, "LocaleTags") - if !compareTags(tags2, []string{"tag1", "tag2", "tag3"}) { - t.Errorf("CN Blog's tags should not be changed after EN Blog Replace") - } - - var blog11 Blog - DB.Preload("LocaleTags").First(&blog11, "id = ? AND locale = ?", blog.ID, blog.Locale) - if !compareTags(blog11.LocaleTags, []string{"tag1", "tag2", "tag3"}) { - t.Errorf("CN Blog's tags should not be changed after EN Blog Replace") - } - - DB.Model(&blog2).Related(&tags2, "LocaleTags") - if !compareTags(tags2, []string{"tag5", "tag6"}) { - t.Errorf("Should find 2 tags after Replace") - } - - var blog21 Blog - DB.Preload("LocaleTags").First(&blog21, "id = ? 
AND locale = ?", blog2.ID, blog2.Locale) - if !compareTags(blog21.LocaleTags, []string{"tag5", "tag6"}) { - t.Errorf("EN Blog's tags should be changed after Replace") - } - - if DB.Model(&blog).Association("LocaleTags").Count() != 3 { - t.Errorf("ZH Blog should has three tags after Replace") - } - - if DB.Model(&blog2).Association("LocaleTags").Count() != 2 { - t.Errorf("EN Blog should has two tags after Replace") - } - - // Delete - DB.Model(&blog).Association("LocaleTags").Delete(tag5) - - if DB.Model(&blog).Association("LocaleTags").Count() != 3 { - t.Errorf("ZH Blog should has three tags after Delete with EN's tag") - } - - if DB.Model(&blog2).Association("LocaleTags").Count() != 2 { - t.Errorf("EN Blog should has two tags after ZH Blog Delete with EN's tag") - } - - DB.Model(&blog2).Association("LocaleTags").Delete(tag5) - - if DB.Model(&blog).Association("LocaleTags").Count() != 3 { - t.Errorf("ZH Blog should has three tags after EN Blog Delete with EN's tag") - } - - if DB.Model(&blog2).Association("LocaleTags").Count() != 1 { - t.Errorf("EN Blog should has 1 tags after EN Blog Delete with EN's tag") - } - - // Clear - DB.Model(&blog2).Association("LocaleTags").Clear() - if DB.Model(&blog).Association("LocaleTags").Count() != 3 { - t.Errorf("ZH Blog's tags should not be cleared when clear EN Blog's tags") - } - - if DB.Model(&blog2).Association("LocaleTags").Count() != 0 { - t.Errorf("EN Blog's tags should be cleared when clear EN Blog's tags") - } - - DB.Model(&blog).Association("LocaleTags").Clear() - if DB.Model(&blog).Association("LocaleTags").Count() != 0 { - t.Errorf("ZH Blog's tags should be cleared when clear ZH Blog's tags") - } - - if DB.Model(&blog2).Association("LocaleTags").Count() != 0 { - t.Errorf("EN Blog's tags should be cleared") - } - } -} diff --git a/pointer_test.go b/pointer_test.go deleted file mode 100644 index 2a68a5ab..00000000 --- a/pointer_test.go +++ /dev/null @@ -1,84 +0,0 @@ -package gorm_test - -import "testing" - -type PointerStruct struct { - ID int64 - Name *string - Num *int -} - -type NormalStruct struct { - ID int64 - Name string - Num int -} - -func TestPointerFields(t *testing.T) { - DB.DropTable(&PointerStruct{}) - DB.AutoMigrate(&PointerStruct{}) - var name = "pointer struct 1" - var num = 100 - pointerStruct := PointerStruct{Name: &name, Num: &num} - if DB.Create(&pointerStruct).Error != nil { - t.Errorf("Failed to save pointer struct") - } - - var pointerStructResult PointerStruct - if err := DB.First(&pointerStructResult, "id = ?", pointerStruct.ID).Error; err != nil || *pointerStructResult.Name != name || *pointerStructResult.Num != num { - t.Errorf("Failed to query saved pointer struct") - } - - var tableName = DB.NewScope(&PointerStruct{}).TableName() - - var normalStruct NormalStruct - DB.Table(tableName).First(&normalStruct) - if normalStruct.Name != name || normalStruct.Num != num { - t.Errorf("Failed to query saved Normal struct") - } - - var nilPointerStruct = PointerStruct{} - if err := DB.Create(&nilPointerStruct).Error; err != nil { - t.Error("Failed to save nil pointer struct", err) - } - - var pointerStruct2 PointerStruct - if err := DB.First(&pointerStruct2, "id = ?", nilPointerStruct.ID).Error; err != nil { - t.Error("Failed to query saved nil pointer struct", err) - } - - var normalStruct2 NormalStruct - if err := DB.Table(tableName).First(&normalStruct2, "id = ?", nilPointerStruct.ID).Error; err != nil { - t.Error("Failed to query saved nil pointer struct", err) - } - - var partialNilPointerStruct1 = PointerStruct{Num: 
&num} - if err := DB.Create(&partialNilPointerStruct1).Error; err != nil { - t.Error("Failed to save partial nil pointer struct", err) - } - - var pointerStruct3 PointerStruct - if err := DB.First(&pointerStruct3, "id = ?", partialNilPointerStruct1.ID).Error; err != nil || *pointerStruct3.Num != num { - t.Error("Failed to query saved partial nil pointer struct", err) - } - - var normalStruct3 NormalStruct - if err := DB.Table(tableName).First(&normalStruct3, "id = ?", partialNilPointerStruct1.ID).Error; err != nil || normalStruct3.Num != num { - t.Error("Failed to query saved partial pointer struct", err) - } - - var partialNilPointerStruct2 = PointerStruct{Name: &name} - if err := DB.Create(&partialNilPointerStruct2).Error; err != nil { - t.Error("Failed to save partial nil pointer struct", err) - } - - var pointerStruct4 PointerStruct - if err := DB.First(&pointerStruct4, "id = ?", partialNilPointerStruct2.ID).Error; err != nil || *pointerStruct4.Name != name { - t.Error("Failed to query saved partial nil pointer struct", err) - } - - var normalStruct4 NormalStruct - if err := DB.Table(tableName).First(&normalStruct4, "id = ?", partialNilPointerStruct2.ID).Error; err != nil || normalStruct4.Name != name { - t.Error("Failed to query saved partial pointer struct", err) - } -} diff --git a/polymorphic_test.go b/polymorphic_test.go deleted file mode 100644 index d1ecfbbb..00000000 --- a/polymorphic_test.go +++ /dev/null @@ -1,366 +0,0 @@ -package gorm_test - -import ( - "reflect" - "sort" - "testing" -) - -type Cat struct { - Id int - Name string - Toy Toy `gorm:"polymorphic:Owner;"` -} - -type Dog struct { - Id int - Name string - Toys []Toy `gorm:"polymorphic:Owner;"` -} - -type Hamster struct { - Id int - Name string - PreferredToy Toy `gorm:"polymorphic:Owner;polymorphic_value:hamster_preferred"` - OtherToy Toy `gorm:"polymorphic:Owner;polymorphic_value:hamster_other"` -} - -type Toy struct { - Id int - Name string - OwnerId int - OwnerType string -} - -var compareToys = func(toys []Toy, contents []string) bool { - var toyContents []string - for _, toy := range toys { - toyContents = append(toyContents, toy.Name) - } - sort.Strings(toyContents) - sort.Strings(contents) - return reflect.DeepEqual(toyContents, contents) -} - -func TestPolymorphic(t *testing.T) { - cat := Cat{Name: "Mr. 
Bigglesworth", Toy: Toy{Name: "cat toy"}} - dog := Dog{Name: "Pluto", Toys: []Toy{{Name: "dog toy 1"}, {Name: "dog toy 2"}}} - DB.Save(&cat).Save(&dog) - - if DB.Model(&cat).Association("Toy").Count() != 1 { - t.Errorf("Cat's toys count should be 1") - } - - if DB.Model(&dog).Association("Toys").Count() != 2 { - t.Errorf("Dog's toys count should be 2") - } - - // Query - var catToys []Toy - if DB.Model(&cat).Related(&catToys, "Toy").RecordNotFound() { - t.Errorf("Did not find any has one polymorphic association") - } else if len(catToys) != 1 { - t.Errorf("Should have found only one polymorphic has one association") - } else if catToys[0].Name != cat.Toy.Name { - t.Errorf("Should have found the proper has one polymorphic association") - } - - var dogToys []Toy - if DB.Model(&dog).Related(&dogToys, "Toys").RecordNotFound() { - t.Errorf("Did not find any polymorphic has many associations") - } else if len(dogToys) != len(dog.Toys) { - t.Errorf("Should have found all polymorphic has many associations") - } - - var catToy Toy - DB.Model(&cat).Association("Toy").Find(&catToy) - if catToy.Name != cat.Toy.Name { - t.Errorf("Should find has one polymorphic association") - } - - var dogToys1 []Toy - DB.Model(&dog).Association("Toys").Find(&dogToys1) - if !compareToys(dogToys1, []string{"dog toy 1", "dog toy 2"}) { - t.Errorf("Should find has many polymorphic association") - } - - // Append - DB.Model(&cat).Association("Toy").Append(&Toy{ - Name: "cat toy 2", - }) - - var catToy2 Toy - DB.Model(&cat).Association("Toy").Find(&catToy2) - if catToy2.Name != "cat toy 2" { - t.Errorf("Should update has one polymorphic association with Append") - } - - if DB.Model(&cat).Association("Toy").Count() != 1 { - t.Errorf("Cat's toys count should be 1 after Append") - } - - if DB.Model(&dog).Association("Toys").Count() != 2 { - t.Errorf("Should return two polymorphic has many associations") - } - - DB.Model(&dog).Association("Toys").Append(&Toy{ - Name: "dog toy 3", - }) - - var dogToys2 []Toy - DB.Model(&dog).Association("Toys").Find(&dogToys2) - if !compareToys(dogToys2, []string{"dog toy 1", "dog toy 2", "dog toy 3"}) { - t.Errorf("Dog's toys should be updated with Append") - } - - if DB.Model(&dog).Association("Toys").Count() != 3 { - t.Errorf("Should return three polymorphic has many associations") - } - - // Replace - DB.Model(&cat).Association("Toy").Replace(&Toy{ - Name: "cat toy 3", - }) - - var catToy3 Toy - DB.Model(&cat).Association("Toy").Find(&catToy3) - if catToy3.Name != "cat toy 3" { - t.Errorf("Should update has one polymorphic association with Replace") - } - - if DB.Model(&cat).Association("Toy").Count() != 1 { - t.Errorf("Cat's toys count should be 1 after Replace") - } - - if DB.Model(&dog).Association("Toys").Count() != 3 { - t.Errorf("Should return three polymorphic has many associations") - } - - DB.Model(&dog).Association("Toys").Replace(&Toy{ - Name: "dog toy 4", - }, []Toy{ - {Name: "dog toy 5"}, {Name: "dog toy 6"}, {Name: "dog toy 7"}, - }) - - var dogToys3 []Toy - DB.Model(&dog).Association("Toys").Find(&dogToys3) - if !compareToys(dogToys3, []string{"dog toy 4", "dog toy 5", "dog toy 6", "dog toy 7"}) { - t.Errorf("Dog's toys should be updated with Replace") - } - - if DB.Model(&dog).Association("Toys").Count() != 4 { - t.Errorf("Should return three polymorphic has many associations") - } - - // Delete - DB.Model(&cat).Association("Toy").Delete(&catToy2) - - var catToy4 Toy - DB.Model(&cat).Association("Toy").Find(&catToy4) - if catToy4.Name != "cat toy 3" { - t.Errorf("Should not 
update has one polymorphic association when Delete a unrelated Toy") - } - - if DB.Model(&cat).Association("Toy").Count() != 1 { - t.Errorf("Cat's toys count should be 1") - } - - if DB.Model(&dog).Association("Toys").Count() != 4 { - t.Errorf("Dog's toys count should be 4") - } - - DB.Model(&cat).Association("Toy").Delete(&catToy3) - - if !DB.Model(&cat).Related(&Toy{}, "Toy").RecordNotFound() { - t.Errorf("Toy should be deleted with Delete") - } - - if DB.Model(&cat).Association("Toy").Count() != 0 { - t.Errorf("Cat's toys count should be 0 after Delete") - } - - if DB.Model(&dog).Association("Toys").Count() != 4 { - t.Errorf("Dog's toys count should not be changed when delete cat's toy") - } - - DB.Model(&dog).Association("Toys").Delete(&dogToys2) - - if DB.Model(&dog).Association("Toys").Count() != 4 { - t.Errorf("Dog's toys count should not be changed when delete unrelated toys") - } - - DB.Model(&dog).Association("Toys").Delete(&dogToys3) - - if DB.Model(&dog).Association("Toys").Count() != 0 { - t.Errorf("Dog's toys count should be deleted with Delete") - } - - // Clear - DB.Model(&cat).Association("Toy").Append(&Toy{ - Name: "cat toy 2", - }) - - if DB.Model(&cat).Association("Toy").Count() != 1 { - t.Errorf("Cat's toys should be added with Append") - } - - DB.Model(&cat).Association("Toy").Clear() - - if DB.Model(&cat).Association("Toy").Count() != 0 { - t.Errorf("Cat's toys should be cleared with Clear") - } - - DB.Model(&dog).Association("Toys").Append(&Toy{ - Name: "dog toy 8", - }) - - if DB.Model(&dog).Association("Toys").Count() != 1 { - t.Errorf("Dog's toys should be added with Append") - } - - DB.Model(&dog).Association("Toys").Clear() - - if DB.Model(&dog).Association("Toys").Count() != 0 { - t.Errorf("Dog's toys should be cleared with Clear") - } -} - -func TestNamedPolymorphic(t *testing.T) { - hamster := Hamster{Name: "Mr. 
Hammond", PreferredToy: Toy{Name: "bike"}, OtherToy: Toy{Name: "treadmill"}} - DB.Save(&hamster) - - hamster2 := Hamster{} - DB.Preload("PreferredToy").Preload("OtherToy").Find(&hamster2, hamster.Id) - if hamster2.PreferredToy.Id != hamster.PreferredToy.Id || hamster2.PreferredToy.Name != hamster.PreferredToy.Name { - t.Errorf("Hamster's preferred toy couldn't be preloaded") - } - if hamster2.OtherToy.Id != hamster.OtherToy.Id || hamster2.OtherToy.Name != hamster.OtherToy.Name { - t.Errorf("Hamster's other toy couldn't be preloaded") - } - - // clear to omit Toy.Id in count - hamster2.PreferredToy = Toy{} - hamster2.OtherToy = Toy{} - - if DB.Model(&hamster2).Association("PreferredToy").Count() != 1 { - t.Errorf("Hamster's preferred toy count should be 1") - } - - if DB.Model(&hamster2).Association("OtherToy").Count() != 1 { - t.Errorf("Hamster's other toy count should be 1") - } - - // Query - var hamsterToys []Toy - if DB.Model(&hamster).Related(&hamsterToys, "PreferredToy").RecordNotFound() { - t.Errorf("Did not find any has one polymorphic association") - } else if len(hamsterToys) != 1 { - t.Errorf("Should have found only one polymorphic has one association") - } else if hamsterToys[0].Name != hamster.PreferredToy.Name { - t.Errorf("Should have found the proper has one polymorphic association") - } - - if DB.Model(&hamster).Related(&hamsterToys, "OtherToy").RecordNotFound() { - t.Errorf("Did not find any has one polymorphic association") - } else if len(hamsterToys) != 1 { - t.Errorf("Should have found only one polymorphic has one association") - } else if hamsterToys[0].Name != hamster.OtherToy.Name { - t.Errorf("Should have found the proper has one polymorphic association") - } - - hamsterToy := Toy{} - DB.Model(&hamster).Association("PreferredToy").Find(&hamsterToy) - if hamsterToy.Name != hamster.PreferredToy.Name { - t.Errorf("Should find has one polymorphic association") - } - hamsterToy = Toy{} - DB.Model(&hamster).Association("OtherToy").Find(&hamsterToy) - if hamsterToy.Name != hamster.OtherToy.Name { - t.Errorf("Should find has one polymorphic association") - } - - // Append - DB.Model(&hamster).Association("PreferredToy").Append(&Toy{ - Name: "bike 2", - }) - DB.Model(&hamster).Association("OtherToy").Append(&Toy{ - Name: "treadmill 2", - }) - - hamsterToy = Toy{} - DB.Model(&hamster).Association("PreferredToy").Find(&hamsterToy) - if hamsterToy.Name != "bike 2" { - t.Errorf("Should update has one polymorphic association with Append") - } - - hamsterToy = Toy{} - DB.Model(&hamster).Association("OtherToy").Find(&hamsterToy) - if hamsterToy.Name != "treadmill 2" { - t.Errorf("Should update has one polymorphic association with Append") - } - - if DB.Model(&hamster2).Association("PreferredToy").Count() != 1 { - t.Errorf("Hamster's toys count should be 1 after Append") - } - - if DB.Model(&hamster2).Association("OtherToy").Count() != 1 { - t.Errorf("Hamster's toys count should be 1 after Append") - } - - // Replace - DB.Model(&hamster).Association("PreferredToy").Replace(&Toy{ - Name: "bike 3", - }) - DB.Model(&hamster).Association("OtherToy").Replace(&Toy{ - Name: "treadmill 3", - }) - - hamsterToy = Toy{} - DB.Model(&hamster).Association("PreferredToy").Find(&hamsterToy) - if hamsterToy.Name != "bike 3" { - t.Errorf("Should update has one polymorphic association with Replace") - } - - hamsterToy = Toy{} - DB.Model(&hamster).Association("OtherToy").Find(&hamsterToy) - if hamsterToy.Name != "treadmill 3" { - t.Errorf("Should update has one polymorphic association with Replace") 
- } - - if DB.Model(&hamster2).Association("PreferredToy").Count() != 1 { - t.Errorf("hamster's toys count should be 1 after Replace") - } - - if DB.Model(&hamster2).Association("OtherToy").Count() != 1 { - t.Errorf("hamster's toys count should be 1 after Replace") - } - - // Clear - DB.Model(&hamster).Association("PreferredToy").Append(&Toy{ - Name: "bike 2", - }) - DB.Model(&hamster).Association("OtherToy").Append(&Toy{ - Name: "treadmill 2", - }) - - if DB.Model(&hamster).Association("PreferredToy").Count() != 1 { - t.Errorf("Hamster's toys should be added with Append") - } - if DB.Model(&hamster).Association("OtherToy").Count() != 1 { - t.Errorf("Hamster's toys should be added with Append") - } - - DB.Model(&hamster).Association("PreferredToy").Clear() - - if DB.Model(&hamster2).Association("PreferredToy").Count() != 0 { - t.Errorf("Hamster's preferred toy should be cleared with Clear") - } - if DB.Model(&hamster2).Association("OtherToy").Count() != 1 { - t.Errorf("Hamster's other toy should be still available") - } - - DB.Model(&hamster).Association("OtherToy").Clear() - if DB.Model(&hamster).Association("OtherToy").Count() != 0 { - t.Errorf("Hamster's other toy should be cleared with Clear") - } -} diff --git a/prepare_stmt.go b/prepare_stmt.go new file mode 100644 index 00000000..14570061 --- /dev/null +++ b/prepare_stmt.go @@ -0,0 +1,169 @@ +package gorm + +import ( + "context" + "database/sql" + "sync" +) + +type Stmt struct { + *sql.Stmt + Transaction bool +} + +type PreparedStmtDB struct { + Stmts map[string]Stmt + PreparedSQL []string + Mux *sync.RWMutex + ConnPool +} + +func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) { + if dbConnector, ok := db.ConnPool.(GetDBConnector); ok && dbConnector != nil { + return dbConnector.GetDBConn() + } + + if sqldb, ok := db.ConnPool.(*sql.DB); ok { + return sqldb, nil + } + + return nil, ErrInvalidDB +} + +func (db *PreparedStmtDB) Close() { + db.Mux.Lock() + for _, query := range db.PreparedSQL { + if stmt, ok := db.Stmts[query]; ok { + delete(db.Stmts, query) + stmt.Close() + } + } + + db.Mux.Unlock() +} + +func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) { + db.Mux.RLock() + if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) { + db.Mux.RUnlock() + return stmt, nil + } + db.Mux.RUnlock() + + db.Mux.Lock() + // double check + if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) { + db.Mux.Unlock() + return stmt, nil + } else if ok { + stmt.Close() + } + + stmt, err := conn.PrepareContext(ctx, query) + if err == nil { + db.Stmts[query] = Stmt{Stmt: stmt, Transaction: isTransaction} + db.PreparedSQL = append(db.PreparedSQL, query) + } + db.Mux.Unlock() + + return db.Stmts[query], err +} + +func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (ConnPool, error) { + if beginner, ok := db.ConnPool.(TxBeginner); ok { + tx, err := beginner.BeginTx(ctx, opt) + return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err + } + return nil, ErrInvalidTransaction +} + +func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args ...interface{}) (result sql.Result, err error) { + stmt, err := db.prepare(ctx, db.ConnPool, false, query) + if err == nil { + result, err = stmt.ExecContext(ctx, args...) 
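// Note on the step above (comment not present in the original prepare_stmt.go): when
// ExecContext fails, the branch that follows closes the cached *sql.Stmt and deletes
// it from db.Stmts under the write lock, so the next call with the same query text
// goes through prepare() again instead of reusing a possibly broken statement.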
+ if err != nil { + db.Mux.Lock() + stmt.Close() + delete(db.Stmts, query) + db.Mux.Unlock() + } + } + return result, err +} + +func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) { + stmt, err := db.prepare(ctx, db.ConnPool, false, query) + if err == nil { + rows, err = stmt.QueryContext(ctx, args...) + if err != nil { + db.Mux.Lock() + stmt.Close() + delete(db.Stmts, query) + db.Mux.Unlock() + } + } + return rows, err +} + +func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { + stmt, err := db.prepare(ctx, db.ConnPool, false, query) + if err == nil { + return stmt.QueryRowContext(ctx, args...) + } + return &sql.Row{} +} + +type PreparedStmtTX struct { + *sql.Tx + PreparedStmtDB *PreparedStmtDB +} + +func (tx *PreparedStmtTX) Commit() error { + if tx.Tx != nil { + return tx.Tx.Commit() + } + return ErrInvalidTransaction +} + +func (tx *PreparedStmtTX) Rollback() error { + if tx.Tx != nil { + return tx.Tx.Rollback() + } + return ErrInvalidTransaction +} + +func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ...interface{}) (result sql.Result, err error) { + stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query) + if err == nil { + result, err = tx.Tx.StmtContext(ctx, stmt.Stmt).ExecContext(ctx, args...) + if err != nil { + tx.PreparedStmtDB.Mux.Lock() + stmt.Close() + delete(tx.PreparedStmtDB.Stmts, query) + tx.PreparedStmtDB.Mux.Unlock() + } + } + return result, err +} + +func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) { + stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query) + if err == nil { + rows, err = tx.Tx.Stmt(stmt.Stmt).QueryContext(ctx, args...) + if err != nil { + tx.PreparedStmtDB.Mux.Lock() + stmt.Close() + delete(tx.PreparedStmtDB.Stmts, query) + tx.PreparedStmtDB.Mux.Unlock() + } + } + return rows, err +} + +func (tx *PreparedStmtTX) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { + stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query) + if err == nil { + return tx.Tx.StmtContext(ctx, stmt.Stmt).QueryRowContext(ctx, args...) 
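// Note: StmtContext above re-binds the connection-level cached statement to the
// current transaction, so transactional queries reuse the same prepared SQL. A short
// usage sketch (assumption: standard GORM v2 public API, not shown in this diff) of
// how this machinery is switched on:
//
//	import (
//		"gorm.io/driver/sqlite"
//		"gorm.io/gorm"
//	)
//
//	db, _ := gorm.Open(sqlite.Open("test.db"), &gorm.Config{PrepareStmt: true})
//	// or per-session:
//	tx := db.Session(&gorm.Session{PrepareStmt: true})
//	_ = tx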
+ } + return &sql.Row{} +} diff --git a/query_test.go b/query_test.go deleted file mode 100644 index def84e04..00000000 --- a/query_test.go +++ /dev/null @@ -1,667 +0,0 @@ -package gorm_test - -import ( - "fmt" - "reflect" - - "github.com/jinzhu/gorm" - - "testing" - "time" -) - -func TestFirstAndLast(t *testing.T) { - DB.Save(&User{Name: "user1", Emails: []Email{{Email: "user1@example.com"}}}) - DB.Save(&User{Name: "user2", Emails: []Email{{Email: "user2@example.com"}}}) - - var user1, user2, user3, user4 User - DB.First(&user1) - DB.Order("id").Limit(1).Find(&user2) - - ptrOfUser3 := &user3 - DB.Last(&ptrOfUser3) - DB.Order("id desc").Limit(1).Find(&user4) - if user1.Id != user2.Id || user3.Id != user4.Id { - t.Errorf("First and Last should by order by primary key") - } - - var users []User - DB.First(&users) - if len(users) != 1 { - t.Errorf("Find first record as slice") - } - - var user User - if DB.Joins("left join emails on emails.user_id = users.id").First(&user).Error != nil { - t.Errorf("Should not raise any error when order with Join table") - } - - if user.Email != "" { - t.Errorf("User's Email should be blank as no one set it") - } -} - -func TestFirstAndLastWithNoStdPrimaryKey(t *testing.T) { - DB.Save(&Animal{Name: "animal1"}) - DB.Save(&Animal{Name: "animal2"}) - - var animal1, animal2, animal3, animal4 Animal - DB.First(&animal1) - DB.Order("counter").Limit(1).Find(&animal2) - - DB.Last(&animal3) - DB.Order("counter desc").Limit(1).Find(&animal4) - if animal1.Counter != animal2.Counter || animal3.Counter != animal4.Counter { - t.Errorf("First and Last should work correctly") - } -} - -func TestFirstAndLastWithRaw(t *testing.T) { - user1 := User{Name: "user", Emails: []Email{{Email: "user1@example.com"}}} - user2 := User{Name: "user", Emails: []Email{{Email: "user2@example.com"}}} - DB.Save(&user1) - DB.Save(&user2) - - var user3, user4 User - DB.Raw("select * from users WHERE name = ?", "user").First(&user3) - if user3.Id != user1.Id { - t.Errorf("Find first record with raw") - } - - DB.Raw("select * from users WHERE name = ?", "user").Last(&user4) - if user4.Id != user2.Id { - t.Errorf("Find last record with raw") - } -} - -func TestUIntPrimaryKey(t *testing.T) { - var animal Animal - DB.First(&animal, uint64(1)) - if animal.Counter != 1 { - t.Errorf("Fetch a record from with a non-int primary key should work, but failed") - } - - DB.Model(Animal{}).Where(Animal{Counter: uint64(2)}).Scan(&animal) - if animal.Counter != 2 { - t.Errorf("Fetch a record from with a non-int primary key should work, but failed") - } -} - -func TestStringPrimaryKeyForNumericValueStartingWithZero(t *testing.T) { - type AddressByZipCode struct { - ZipCode string `gorm:"primary_key"` - Address string - } - - DB.AutoMigrate(&AddressByZipCode{}) - DB.Create(&AddressByZipCode{ZipCode: "00501", Address: "Holtsville"}) - - var address AddressByZipCode - DB.First(&address, "00501") - if address.ZipCode != "00501" { - t.Errorf("Fetch a record from with a string primary key for a numeric value starting with zero should work, but failed") - } -} - -func TestFindAsSliceOfPointers(t *testing.T) { - DB.Save(&User{Name: "user"}) - - var users []User - DB.Find(&users) - - var userPointers []*User - DB.Find(&userPointers) - - if len(users) == 0 || len(users) != len(userPointers) { - t.Errorf("Find slice of pointers") - } -} - -func TestSearchWithPlainSQL(t *testing.T) { - user1 := User{Name: "PlainSqlUser1", Age: 1, Birthday: parseTime("2000-1-1")} - user2 := User{Name: "PlainSqlUser2", Age: 10, Birthday: 
parseTime("2010-1-1")} - user3 := User{Name: "PlainSqlUser3", Age: 20, Birthday: parseTime("2020-1-1")} - DB.Save(&user1).Save(&user2).Save(&user3) - scopedb := DB.Where("name LIKE ?", "%PlainSqlUser%") - - if DB.Where("name = ?", user1.Name).First(&User{}).RecordNotFound() { - t.Errorf("Search with plain SQL") - } - - if DB.Where("name LIKE ?", "%"+user1.Name+"%").First(&User{}).RecordNotFound() { - t.Errorf("Search with plan SQL (regexp)") - } - - var users []User - DB.Find(&users, "name LIKE ? and age > ?", "%PlainSqlUser%", 1) - if len(users) != 2 { - t.Errorf("Should found 2 users that age > 1, but got %v", len(users)) - } - - DB.Where("name LIKE ?", "%PlainSqlUser%").Where("age >= ?", 1).Find(&users) - if len(users) != 3 { - t.Errorf("Should found 3 users that age >= 1, but got %v", len(users)) - } - - scopedb.Where("age <> ?", 20).Find(&users) - if len(users) != 2 { - t.Errorf("Should found 2 users age != 20, but got %v", len(users)) - } - - scopedb.Where("birthday > ?", parseTime("2000-1-1")).Find(&users) - if len(users) != 2 { - t.Errorf("Should found 2 users's birthday > 2000-1-1, but got %v", len(users)) - } - - scopedb.Where("birthday > ?", "2002-10-10").Find(&users) - if len(users) != 2 { - t.Errorf("Should found 2 users's birthday >= 2002-10-10, but got %v", len(users)) - } - - scopedb.Where("birthday >= ?", "2010-1-1").Where("birthday < ?", "2020-1-1").Find(&users) - if len(users) != 1 { - t.Errorf("Should found 1 users's birthday < 2020-1-1 and >= 2010-1-1, but got %v", len(users)) - } - - DB.Where("name in (?)", []string{user1.Name, user2.Name}).Find(&users) - if len(users) != 2 { - t.Errorf("Should found 2 users, but got %v", len(users)) - } - - DB.Where("id in (?)", []int64{user1.Id, user2.Id, user3.Id}).Find(&users) - if len(users) != 3 { - t.Errorf("Should found 3 users, but got %v", len(users)) - } - - DB.Where("id in (?)", user1.Id).Find(&users) - if len(users) != 1 { - t.Errorf("Should found 1 users, but got %v", len(users)) - } - - if err := DB.Where("id IN (?)", []string{}).Find(&users).Error; err != nil { - t.Error("no error should happen when query with empty slice, but got: ", err) - } - - if err := DB.Not("id IN (?)", []string{}).Find(&users).Error; err != nil { - t.Error("no error should happen when query with empty slice, but got: ", err) - } - - if DB.Where("name = ?", "none existing").Find(&[]User{}).RecordNotFound() { - t.Errorf("Should not get RecordNotFound error when looking for none existing records") - } -} - -func TestSearchWithStruct(t *testing.T) { - user1 := User{Name: "StructSearchUser1", Age: 1, Birthday: parseTime("2000-1-1")} - user2 := User{Name: "StructSearchUser2", Age: 10, Birthday: parseTime("2010-1-1")} - user3 := User{Name: "StructSearchUser3", Age: 20, Birthday: parseTime("2020-1-1")} - DB.Save(&user1).Save(&user2).Save(&user3) - - if DB.Where(user1.Id).First(&User{}).RecordNotFound() { - t.Errorf("Search with primary key") - } - - if DB.First(&User{}, user1.Id).RecordNotFound() { - t.Errorf("Search with primary key as inline condition") - } - - if DB.First(&User{}, fmt.Sprintf("%v", user1.Id)).RecordNotFound() { - t.Errorf("Search with primary key as inline condition") - } - - var users []User - DB.Where([]int64{user1.Id, user2.Id, user3.Id}).Find(&users) - if len(users) != 3 { - t.Errorf("Should found 3 users when search with primary keys, but got %v", len(users)) - } - - var user User - DB.First(&user, &User{Name: user1.Name}) - if user.Id == 0 || user.Name != user1.Name { - t.Errorf("Search first record with inline pointer of 
struct") - } - - DB.First(&user, User{Name: user1.Name}) - if user.Id == 0 || user.Name != user1.Name { - t.Errorf("Search first record with inline struct") - } - - DB.Where(&User{Name: user1.Name}).First(&user) - if user.Id == 0 || user.Name != user1.Name { - t.Errorf("Search first record with where struct") - } - - DB.Find(&users, &User{Name: user2.Name}) - if len(users) != 1 { - t.Errorf("Search all records with inline struct") - } -} - -func TestSearchWithMap(t *testing.T) { - companyID := 1 - user1 := User{Name: "MapSearchUser1", Age: 1, Birthday: parseTime("2000-1-1")} - user2 := User{Name: "MapSearchUser2", Age: 10, Birthday: parseTime("2010-1-1")} - user3 := User{Name: "MapSearchUser3", Age: 20, Birthday: parseTime("2020-1-1")} - user4 := User{Name: "MapSearchUser4", Age: 30, Birthday: parseTime("2020-1-1"), CompanyID: &companyID} - DB.Save(&user1).Save(&user2).Save(&user3).Save(&user4) - - var user User - DB.First(&user, map[string]interface{}{"name": user1.Name}) - if user.Id == 0 || user.Name != user1.Name { - t.Errorf("Search first record with inline map") - } - - user = User{} - DB.Where(map[string]interface{}{"name": user2.Name}).First(&user) - if user.Id == 0 || user.Name != user2.Name { - t.Errorf("Search first record with where map") - } - - var users []User - DB.Where(map[string]interface{}{"name": user3.Name}).Find(&users) - if len(users) != 1 { - t.Errorf("Search all records with inline map") - } - - DB.Find(&users, map[string]interface{}{"name": user3.Name}) - if len(users) != 1 { - t.Errorf("Search all records with inline map") - } - - DB.Find(&users, map[string]interface{}{"name": user4.Name, "company_id": nil}) - if len(users) != 0 { - t.Errorf("Search all records with inline map containing null value finding 0 records") - } - - DB.Find(&users, map[string]interface{}{"name": user1.Name, "company_id": nil}) - if len(users) != 1 { - t.Errorf("Search all records with inline map containing null value finding 1 record") - } - - DB.Find(&users, map[string]interface{}{"name": user4.Name, "company_id": companyID}) - if len(users) != 1 { - t.Errorf("Search all records with inline multiple value map") - } -} - -func TestSearchWithEmptyChain(t *testing.T) { - user1 := User{Name: "ChainSearchUser1", Age: 1, Birthday: parseTime("2000-1-1")} - user2 := User{Name: "ChainearchUser2", Age: 10, Birthday: parseTime("2010-1-1")} - user3 := User{Name: "ChainearchUser3", Age: 20, Birthday: parseTime("2020-1-1")} - DB.Save(&user1).Save(&user2).Save(&user3) - - if DB.Where("").Where("").First(&User{}).Error != nil { - t.Errorf("Should not raise any error if searching with empty strings") - } - - if DB.Where(&User{}).Where("name = ?", user1.Name).First(&User{}).Error != nil { - t.Errorf("Should not raise any error if searching with empty struct") - } - - if DB.Where(map[string]interface{}{}).Where("name = ?", user1.Name).First(&User{}).Error != nil { - t.Errorf("Should not raise any error if searching with empty map") - } -} - -func TestSelect(t *testing.T) { - user1 := User{Name: "SelectUser1"} - DB.Save(&user1) - - var user User - DB.Where("name = ?", user1.Name).Select("name").Find(&user) - if user.Id != 0 { - t.Errorf("Should not have ID because only selected name, %+v", user.Id) - } - - if user.Name != user1.Name { - t.Errorf("Should have user Name when selected it") - } -} - -func TestOrderAndPluck(t *testing.T) { - user1 := User{Name: "OrderPluckUser1", Age: 1} - user2 := User{Name: "OrderPluckUser2", Age: 10} - user3 := User{Name: "OrderPluckUser3", Age: 20} - 
DB.Save(&user1).Save(&user2).Save(&user3) - scopedb := DB.Model(&User{}).Where("name like ?", "%OrderPluckUser%") - - var user User - scopedb.Order(gorm.Expr("case when name = ? then 0 else 1 end", "OrderPluckUser2")).First(&user) - if user.Name != "OrderPluckUser2" { - t.Errorf("Order with sql expression") - } - - var ages []int64 - scopedb.Order("age desc").Pluck("age", &ages) - if ages[0] != 20 { - t.Errorf("The first age should be 20 when order with age desc") - } - - var ages1, ages2 []int64 - scopedb.Order("age desc").Pluck("age", &ages1).Pluck("age", &ages2) - if !reflect.DeepEqual(ages1, ages2) { - t.Errorf("The first order is the primary order") - } - - var ages3, ages4 []int64 - scopedb.Model(&User{}).Order("age desc").Pluck("age", &ages3).Order("age", true).Pluck("age", &ages4) - if reflect.DeepEqual(ages3, ages4) { - t.Errorf("Reorder should work") - } - - var names []string - var ages5 []int64 - scopedb.Model(User{}).Order("name").Order("age desc").Pluck("age", &ages5).Pluck("name", &names) - if names != nil && ages5 != nil { - if !(names[0] == user1.Name && names[1] == user2.Name && names[2] == user3.Name && ages5[2] == 20) { - t.Errorf("Order with multiple orders") - } - } else { - t.Errorf("Order with multiple orders") - } - - var ages6 []int64 - if err := scopedb.Order("").Pluck("age", &ages6).Error; err != nil { - t.Errorf("An empty string as order clause produces invalid queries") - } - - DB.Model(User{}).Select("name, age").Find(&[]User{}) -} - -func TestLimit(t *testing.T) { - user1 := User{Name: "LimitUser1", Age: 1} - user2 := User{Name: "LimitUser2", Age: 10} - user3 := User{Name: "LimitUser3", Age: 20} - user4 := User{Name: "LimitUser4", Age: 10} - user5 := User{Name: "LimitUser5", Age: 20} - DB.Save(&user1).Save(&user2).Save(&user3).Save(&user4).Save(&user5) - - var users1, users2, users3 []User - DB.Order("age desc").Limit(3).Find(&users1).Limit(5).Find(&users2).Limit(-1).Find(&users3) - - if len(users1) != 3 || len(users2) != 5 || len(users3) <= 5 { - t.Errorf("Limit should works") - } -} - -func TestOffset(t *testing.T) { - for i := 0; i < 20; i++ { - DB.Save(&User{Name: fmt.Sprintf("OffsetUser%v", i)}) - } - var users1, users2, users3, users4 []User - DB.Limit(100).Order("age desc").Find(&users1).Offset(3).Find(&users2).Offset(5).Find(&users3).Offset(-1).Find(&users4) - - if (len(users1) != len(users4)) || (len(users1)-len(users2) != 3) || (len(users1)-len(users3) != 5) { - t.Errorf("Offset should work") - } -} - -func TestOr(t *testing.T) { - user1 := User{Name: "OrUser1", Age: 1} - user2 := User{Name: "OrUser2", Age: 10} - user3 := User{Name: "OrUser3", Age: 20} - DB.Save(&user1).Save(&user2).Save(&user3) - - var users []User - DB.Where("name = ?", user1.Name).Or("name = ?", user2.Name).Find(&users) - if len(users) != 2 { - t.Errorf("Find users with or") - } -} - -func TestCount(t *testing.T) { - user1 := User{Name: "CountUser1", Age: 1} - user2 := User{Name: "CountUser2", Age: 10} - user3 := User{Name: "CountUser3", Age: 20} - - DB.Save(&user1).Save(&user2).Save(&user3) - var count, count1, count2 int64 - var users []User - - if err := DB.Where("name = ?", user1.Name).Or("name = ?", user3.Name).Find(&users).Count(&count).Error; err != nil { - t.Errorf(fmt.Sprintf("Count should work, but got err %v", err)) - } - - if count != int64(len(users)) { - t.Errorf("Count() method should get correct value") - } - - DB.Model(&User{}).Where("name = ?", user1.Name).Count(&count1).Or("name in (?)", []string{user2.Name, user3.Name}).Count(&count2) - if count1 != 1 || 
count2 != 3 { - t.Errorf("Multiple count in chain") - } -} - -func TestNot(t *testing.T) { - DB.Create(getPreparedUser("user1", "not")) - DB.Create(getPreparedUser("user2", "not")) - DB.Create(getPreparedUser("user3", "not")) - - user4 := getPreparedUser("user4", "not") - user4.Company = Company{} - DB.Create(user4) - - DB := DB.Where("role = ?", "not") - - var users1, users2, users3, users4, users5, users6, users7, users8, users9 []User - if DB.Find(&users1).RowsAffected != 4 { - t.Errorf("should find 4 not users") - } - DB.Not(users1[0].Id).Find(&users2) - - if len(users1)-len(users2) != 1 { - t.Errorf("Should ignore the first users with Not") - } - - DB.Not([]int{}).Find(&users3) - if len(users1)-len(users3) != 0 { - t.Errorf("Should find all users with a blank condition") - } - - var name3Count int64 - DB.Table("users").Where("name = ?", "user3").Count(&name3Count) - DB.Not("name", "user3").Find(&users4) - if len(users1)-len(users4) != int(name3Count) { - t.Errorf("Should find all users's name not equal 3") - } - - DB.Not("name = ?", "user3").Find(&users4) - if len(users1)-len(users4) != int(name3Count) { - t.Errorf("Should find all users's name not equal 3") - } - - DB.Not("name <> ?", "user3").Find(&users4) - if len(users4) != int(name3Count) { - t.Errorf("Should find all users's name not equal 3") - } - - DB.Not(User{Name: "user3"}).Find(&users5) - - if len(users1)-len(users5) != int(name3Count) { - t.Errorf("Should find all users's name not equal 3") - } - - DB.Not(map[string]interface{}{"name": "user3"}).Find(&users6) - if len(users1)-len(users6) != int(name3Count) { - t.Errorf("Should find all users's name not equal 3") - } - - DB.Not(map[string]interface{}{"name": "user3", "company_id": nil}).Find(&users7) - if len(users1)-len(users7) != 2 { // not user3 or user4 - t.Errorf("Should find all user's name not equal to 3 who do not have company id") - } - - DB.Not("name", []string{"user3"}).Find(&users8) - if len(users1)-len(users8) != int(name3Count) { - t.Errorf("Should find all users's name not equal 3") - } - - var name2Count int64 - DB.Table("users").Where("name = ?", "user2").Count(&name2Count) - DB.Not("name", []string{"user3", "user2"}).Find(&users9) - if len(users1)-len(users9) != (int(name3Count) + int(name2Count)) { - t.Errorf("Should find all users's name not equal 3") - } -} - -func TestFillSmallerStruct(t *testing.T) { - user1 := User{Name: "SmallerUser", Age: 100} - DB.Save(&user1) - type SimpleUser struct { - Name string - Id int64 - UpdatedAt time.Time - CreatedAt time.Time - } - - var simpleUser SimpleUser - DB.Table("users").Where("name = ?", user1.Name).First(&simpleUser) - - if simpleUser.Id == 0 || simpleUser.Name == "" { - t.Errorf("Should fill data correctly into smaller struct") - } -} - -func TestFindOrInitialize(t *testing.T) { - var user1, user2, user3, user4, user5, user6 User - DB.Where(&User{Name: "find or init", Age: 33}).FirstOrInit(&user1) - if user1.Name != "find or init" || user1.Id != 0 || user1.Age != 33 { - t.Errorf("user should be initialized with search value") - } - - DB.Where(User{Name: "find or init", Age: 33}).FirstOrInit(&user2) - if user2.Name != "find or init" || user2.Id != 0 || user2.Age != 33 { - t.Errorf("user should be initialized with search value") - } - - DB.FirstOrInit(&user3, map[string]interface{}{"name": "find or init 2"}) - if user3.Name != "find or init 2" || user3.Id != 0 { - t.Errorf("user should be initialized with inline search value") - } - - DB.Where(&User{Name: "find or init"}).Attrs(User{Age: 
44}).FirstOrInit(&user4) - if user4.Name != "find or init" || user4.Id != 0 || user4.Age != 44 { - t.Errorf("user should be initialized with search value and attrs") - } - - DB.Where(&User{Name: "find or init"}).Assign("age", 44).FirstOrInit(&user4) - if user4.Name != "find or init" || user4.Id != 0 || user4.Age != 44 { - t.Errorf("user should be initialized with search value and assign attrs") - } - - DB.Save(&User{Name: "find or init", Age: 33}) - DB.Where(&User{Name: "find or init"}).Attrs("age", 44).FirstOrInit(&user5) - if user5.Name != "find or init" || user5.Id == 0 || user5.Age != 33 { - t.Errorf("user should be found and not initialized by Attrs") - } - - DB.Where(&User{Name: "find or init", Age: 33}).FirstOrInit(&user6) - if user6.Name != "find or init" || user6.Id == 0 || user6.Age != 33 { - t.Errorf("user should be found with FirstOrInit") - } - - DB.Where(&User{Name: "find or init"}).Assign(User{Age: 44}).FirstOrInit(&user6) - if user6.Name != "find or init" || user6.Id == 0 || user6.Age != 44 { - t.Errorf("user should be found and updated with assigned attrs") - } -} - -func TestFindOrCreate(t *testing.T) { - var user1, user2, user3, user4, user5, user6, user7, user8 User - DB.Where(&User{Name: "find or create", Age: 33}).FirstOrCreate(&user1) - if user1.Name != "find or create" || user1.Id == 0 || user1.Age != 33 { - t.Errorf("user should be created with search value") - } - - DB.Where(&User{Name: "find or create", Age: 33}).FirstOrCreate(&user2) - if user1.Id != user2.Id || user2.Name != "find or create" || user2.Id == 0 || user2.Age != 33 { - t.Errorf("user should be created with search value") - } - - DB.FirstOrCreate(&user3, map[string]interface{}{"name": "find or create 2"}) - if user3.Name != "find or create 2" || user3.Id == 0 { - t.Errorf("user should be created with inline search value") - } - - DB.Where(&User{Name: "find or create 3"}).Attrs("age", 44).FirstOrCreate(&user4) - if user4.Name != "find or create 3" || user4.Id == 0 || user4.Age != 44 { - t.Errorf("user should be created with search value and attrs") - } - - updatedAt1 := user4.UpdatedAt - DB.Where(&User{Name: "find or create 3"}).Assign("age", 55).FirstOrCreate(&user4) - if updatedAt1.Format(time.RFC3339Nano) == user4.UpdatedAt.Format(time.RFC3339Nano) { - t.Errorf("UpdateAt should be changed when update values with assign") - } - - DB.Where(&User{Name: "find or create 4"}).Assign(User{Age: 44}).FirstOrCreate(&user4) - if user4.Name != "find or create 4" || user4.Id == 0 || user4.Age != 44 { - t.Errorf("user should be created with search value and assigned attrs") - } - - DB.Where(&User{Name: "find or create"}).Attrs("age", 44).FirstOrInit(&user5) - if user5.Name != "find or create" || user5.Id == 0 || user5.Age != 33 { - t.Errorf("user should be found and not initialized by Attrs") - } - - DB.Where(&User{Name: "find or create"}).Assign(User{Age: 44}).FirstOrCreate(&user6) - if user6.Name != "find or create" || user6.Id == 0 || user6.Age != 44 { - t.Errorf("user should be found and updated with assigned attrs") - } - - DB.Where(&User{Name: "find or create"}).Find(&user7) - if user7.Name != "find or create" || user7.Id == 0 || user7.Age != 44 { - t.Errorf("user should be found and updated with assigned attrs") - } - - DB.Where(&User{Name: "find or create embedded struct"}).Assign(User{Age: 44, CreditCard: CreditCard{Number: "1231231231"}, Emails: []Email{{Email: "jinzhu@assign_embedded_struct.com"}, {Email: "jinzhu-2@assign_embedded_struct.com"}}}).FirstOrCreate(&user8) - if DB.Where("email = ?", 
"jinzhu-2@assign_embedded_struct.com").First(&Email{}).RecordNotFound() { - t.Errorf("embedded struct email should be saved") - } - - if DB.Where("email = ?", "1231231231").First(&CreditCard{}).RecordNotFound() { - t.Errorf("embedded struct credit card should be saved") - } -} - -func TestSelectWithEscapedFieldName(t *testing.T) { - user1 := User{Name: "EscapedFieldNameUser", Age: 1} - user2 := User{Name: "EscapedFieldNameUser", Age: 10} - user3 := User{Name: "EscapedFieldNameUser", Age: 20} - DB.Save(&user1).Save(&user2).Save(&user3) - - var names []string - DB.Model(User{}).Where(&User{Name: "EscapedFieldNameUser"}).Pluck("\"name\"", &names) - - if len(names) != 3 { - t.Errorf("Expected 3 name, but got: %d", len(names)) - } -} - -func TestSelectWithVariables(t *testing.T) { - DB.Save(&User{Name: "jinzhu"}) - - rows, _ := DB.Table("users").Select("? as fake", gorm.Expr("name")).Rows() - - if !rows.Next() { - t.Errorf("Should have returned at least one row") - } else { - columns, _ := rows.Columns() - if !reflect.DeepEqual(columns, []string{"fake"}) { - t.Errorf("Should only contains one column") - } - } - - rows.Close() -} - -func TestSelectWithArrayInput(t *testing.T) { - DB.Save(&User{Name: "jinzhu", Age: 42}) - - var user User - DB.Select([]string{"name", "age"}).Where("age = 42 AND name = 'jinzhu'").First(&user) - - if user.Name != "jinzhu" || user.Age != 42 { - t.Errorf("Should have selected both age and name") - } -} diff --git a/scan.go b/scan.go new file mode 100644 index 00000000..e82e3f07 --- /dev/null +++ b/scan.go @@ -0,0 +1,251 @@ +package gorm + +import ( + "database/sql" + "database/sql/driver" + "reflect" + "strings" + "time" + + "gorm.io/gorm/schema" +) + +func prepareValues(values []interface{}, db *DB, columnTypes []*sql.ColumnType, columns []string) { + if db.Statement.Schema != nil { + for idx, name := range columns { + if field := db.Statement.Schema.LookUpField(name); field != nil { + values[idx] = reflect.New(reflect.PtrTo(field.FieldType)).Interface() + continue + } + values[idx] = new(interface{}) + } + } else if len(columnTypes) > 0 { + for idx, columnType := range columnTypes { + if columnType.ScanType() != nil { + values[idx] = reflect.New(reflect.PtrTo(columnType.ScanType())).Interface() + } else { + values[idx] = new(interface{}) + } + } + } else { + for idx := range columns { + values[idx] = new(interface{}) + } + } +} + +func scanIntoMap(mapValue map[string]interface{}, values []interface{}, columns []string) { + for idx, column := range columns { + if reflectValue := reflect.Indirect(reflect.Indirect(reflect.ValueOf(values[idx]))); reflectValue.IsValid() { + mapValue[column] = reflectValue.Interface() + if valuer, ok := mapValue[column].(driver.Valuer); ok { + mapValue[column], _ = valuer.Value() + } else if b, ok := mapValue[column].(sql.RawBytes); ok { + mapValue[column] = string(b) + } + } else { + mapValue[column] = nil + } + } +} + +func Scan(rows *sql.Rows, db *DB, initialized bool) { + columns, _ := rows.Columns() + values := make([]interface{}, len(columns)) + db.RowsAffected = 0 + + switch dest := db.Statement.Dest.(type) { + case map[string]interface{}, *map[string]interface{}: + if initialized || rows.Next() { + columnTypes, _ := rows.ColumnTypes() + prepareValues(values, db, columnTypes, columns) + + db.RowsAffected++ + db.AddError(rows.Scan(values...)) + + mapValue, ok := dest.(map[string]interface{}) + if !ok { + if v, ok := dest.(*map[string]interface{}); ok { + mapValue = *v + } + } + scanIntoMap(mapValue, values, columns) + } + case 
*[]map[string]interface{}: + columnTypes, _ := rows.ColumnTypes() + for initialized || rows.Next() { + prepareValues(values, db, columnTypes, columns) + + initialized = false + db.RowsAffected++ + db.AddError(rows.Scan(values...)) + + mapValue := map[string]interface{}{} + scanIntoMap(mapValue, values, columns) + *dest = append(*dest, mapValue) + } + case *int, *int8, *int16, *int32, *int64, + *uint, *uint8, *uint16, *uint32, *uint64, *uintptr, + *float32, *float64, + *bool, *string, *time.Time, + *sql.NullInt32, *sql.NullInt64, *sql.NullFloat64, + *sql.NullBool, *sql.NullString, *sql.NullTime: + for initialized || rows.Next() { + initialized = false + db.RowsAffected++ + db.AddError(rows.Scan(dest)) + } + default: + Schema := db.Statement.Schema + + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + var ( + reflectValueType = db.Statement.ReflectValue.Type().Elem() + isPtr = reflectValueType.Kind() == reflect.Ptr + fields = make([]*schema.Field, len(columns)) + joinFields [][2]*schema.Field + ) + + if isPtr { + reflectValueType = reflectValueType.Elem() + } + + db.Statement.ReflectValue.Set(reflect.MakeSlice(db.Statement.ReflectValue.Type(), 0, 20)) + + if Schema != nil { + if reflectValueType != Schema.ModelType && reflectValueType.Kind() == reflect.Struct { + Schema, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy) + } + + for idx, column := range columns { + if field := Schema.LookUpField(column); field != nil && field.Readable { + fields[idx] = field + } else if names := strings.Split(column, "__"); len(names) > 1 { + if rel, ok := Schema.Relationships.Relations[names[0]]; ok { + if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { + fields[idx] = field + + if len(joinFields) == 0 { + joinFields = make([][2]*schema.Field, len(columns)) + } + joinFields[idx] = [2]*schema.Field{rel.Field, field} + continue + } + } + values[idx] = &sql.RawBytes{} + } else { + values[idx] = &sql.RawBytes{} + } + } + } + + // pluck values into slice of data + isPluck := false + if len(fields) == 1 { + if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); ok || // is scanner + reflectValueType.Kind() != reflect.Struct || // is not struct + Schema.ModelType.ConvertibleTo(schema.TimeReflectType) { // is time + isPluck = true + } + } + + for initialized || rows.Next() { + initialized = false + db.RowsAffected++ + + elem := reflect.New(reflectValueType) + if isPluck { + db.AddError(rows.Scan(elem.Interface())) + } else { + for idx, field := range fields { + if field != nil { + values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface() + } + } + + db.AddError(rows.Scan(values...)) + + for idx, field := range fields { + if len(joinFields) != 0 && joinFields[idx][0] != nil { + value := reflect.ValueOf(values[idx]).Elem() + relValue := joinFields[idx][0].ReflectValueOf(elem) + + if relValue.Kind() == reflect.Ptr && relValue.IsNil() { + if value.IsNil() { + continue + } + relValue.Set(reflect.New(relValue.Type().Elem())) + } + + field.Set(relValue, values[idx]) + } else if field != nil { + field.Set(elem, values[idx]) + } + } + } + + if isPtr { + db.Statement.ReflectValue.Set(reflect.Append(db.Statement.ReflectValue, elem)) + } else { + db.Statement.ReflectValue.Set(reflect.Append(db.Statement.ReflectValue, elem.Elem())) + } + } + case reflect.Struct, reflect.Ptr: + if db.Statement.ReflectValue.Type() != Schema.ModelType { + Schema, _ = schema.Parse(db.Statement.Dest, 
db.cacheStore, db.NamingStrategy) + } + + if initialized || rows.Next() { + for idx, column := range columns { + if field := Schema.LookUpField(column); field != nil && field.Readable { + values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface() + } else if names := strings.Split(column, "__"); len(names) > 1 { + if rel, ok := Schema.Relationships.Relations[names[0]]; ok { + if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { + values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface() + continue + } + } + values[idx] = &sql.RawBytes{} + } else { + values[idx] = &sql.RawBytes{} + } + } + + db.RowsAffected++ + db.AddError(rows.Scan(values...)) + + for idx, column := range columns { + if field := Schema.LookUpField(column); field != nil && field.Readable { + field.Set(db.Statement.ReflectValue, values[idx]) + } else if names := strings.Split(column, "__"); len(names) > 1 { + if rel, ok := Schema.Relationships.Relations[names[0]]; ok { + if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { + relValue := rel.Field.ReflectValueOf(db.Statement.ReflectValue) + value := reflect.ValueOf(values[idx]).Elem() + + if relValue.Kind() == reflect.Ptr && relValue.IsNil() { + if value.IsNil() { + continue + } + relValue.Set(reflect.New(relValue.Type().Elem())) + } + + field.Set(relValue, values[idx]) + } + } + } + } + } + } + } + + if err := rows.Err(); err != nil && err != db.Error { + db.AddError(err) + } + + if db.RowsAffected == 0 && db.Statement.RaiseErrorOnNotFound && db.Error == nil { + db.AddError(ErrRecordNotFound) + } +} diff --git a/scaner_test.go b/scaner_test.go deleted file mode 100644 index 9e251dd6..00000000 --- a/scaner_test.go +++ /dev/null @@ -1,139 +0,0 @@ -package gorm_test - -import ( - "database/sql/driver" - "encoding/json" - "errors" - "testing" - - "github.com/jinzhu/gorm" -) - -func TestScannableSlices(t *testing.T) { - if err := DB.AutoMigrate(&RecordWithSlice{}).Error; err != nil { - t.Errorf("Should create table with slice values correctly: %s", err) - } - - r1 := RecordWithSlice{ - Strings: ExampleStringSlice{"a", "b", "c"}, - Structs: ExampleStructSlice{ - {"name1", "value1"}, - {"name2", "value2"}, - }, - } - - if err := DB.Save(&r1).Error; err != nil { - t.Errorf("Should save record with slice values") - } - - var r2 RecordWithSlice - - if err := DB.Find(&r2).Error; err != nil { - t.Errorf("Should fetch record with slice values") - } - - if len(r2.Strings) != 3 || r2.Strings[0] != "a" || r2.Strings[1] != "b" || r2.Strings[2] != "c" { - t.Errorf("Should have serialised and deserialised a string array") - } - - if len(r2.Structs) != 2 || r2.Structs[0].Name != "name1" || r2.Structs[0].Value != "value1" || r2.Structs[1].Name != "name2" || r2.Structs[1].Value != "value2" { - t.Errorf("Should have serialised and deserialised a struct array") - } -} - -type RecordWithSlice struct { - ID uint64 - Strings ExampleStringSlice `sql:"type:text"` - Structs ExampleStructSlice `sql:"type:text"` -} - -type ExampleStringSlice []string - -func (l ExampleStringSlice) Value() (driver.Value, error) { - bytes, err := json.Marshal(l) - return string(bytes), err -} - -func (l *ExampleStringSlice) Scan(input interface{}) error { - switch value := input.(type) { - case string: - return json.Unmarshal([]byte(value), l) - case []byte: - return json.Unmarshal(value, l) - default: - return errors.New("not supported") - } -} - -type ExampleStruct struct { - Name 
string - Value string -} - -type ExampleStructSlice []ExampleStruct - -func (l ExampleStructSlice) Value() (driver.Value, error) { - bytes, err := json.Marshal(l) - return string(bytes), err -} - -func (l *ExampleStructSlice) Scan(input interface{}) error { - switch value := input.(type) { - case string: - return json.Unmarshal([]byte(value), l) - case []byte: - return json.Unmarshal(value, l) - default: - return errors.New("not supported") - } -} - -type ScannerDataType struct { - Street string `sql:"TYPE:varchar(24)"` -} - -func (ScannerDataType) Value() (driver.Value, error) { - return nil, nil -} - -func (*ScannerDataType) Scan(input interface{}) error { - return nil -} - -type ScannerDataTypeTestStruct struct { - Field1 int - ScannerDataType *ScannerDataType `sql:"TYPE:json"` -} - -type ScannerDataType2 struct { - Street string `sql:"TYPE:varchar(24)"` -} - -func (ScannerDataType2) Value() (driver.Value, error) { - return nil, nil -} - -func (*ScannerDataType2) Scan(input interface{}) error { - return nil -} - -type ScannerDataTypeTestStruct2 struct { - Field1 int - ScannerDataType *ScannerDataType2 -} - -func TestScannerDataType(t *testing.T) { - scope := gorm.Scope{Value: &ScannerDataTypeTestStruct{}} - if field, ok := scope.FieldByName("ScannerDataType"); ok { - if DB.Dialect().DataTypeOf(field.StructField) != "json" { - t.Errorf("data type for scanner is wrong") - } - } - - scope = gorm.Scope{Value: &ScannerDataTypeTestStruct2{}} - if field, ok := scope.FieldByName("ScannerDataType"); ok { - if DB.Dialect().DataTypeOf(field.StructField) != "varchar(24)" { - t.Errorf("data type for scanner is wrong") - } - } -} diff --git a/schema/callbacks_test.go b/schema/callbacks_test.go new file mode 100644 index 00000000..dec41eba --- /dev/null +++ b/schema/callbacks_test.go @@ -0,0 +1,40 @@ +package schema_test + +import ( + "reflect" + "sync" + "testing" + + "gorm.io/gorm" + "gorm.io/gorm/schema" +) + +type UserWithCallback struct { +} + +func (UserWithCallback) BeforeSave(*gorm.DB) error { + return nil +} + +func (UserWithCallback) AfterCreate(*gorm.DB) error { + return nil +} + +func TestCallback(t *testing.T) { + user, err := schema.Parse(&UserWithCallback{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("failed to parse user with callback, got error %v", err) + } + + for _, str := range []string{"BeforeSave", "AfterCreate"} { + if !reflect.Indirect(reflect.ValueOf(user)).FieldByName(str).Interface().(bool) { + t.Errorf("%v should be true", str) + } + } + + for _, str := range []string{"BeforeCreate", "BeforeUpdate", "AfterUpdate", "AfterSave", "BeforeDelete", "AfterDelete", "AfterFind"} { + if reflect.Indirect(reflect.ValueOf(user)).FieldByName(str).Interface().(bool) { + t.Errorf("%v should be false", str) + } + } +} diff --git a/schema/check.go b/schema/check.go new file mode 100644 index 00000000..161a6ac6 --- /dev/null +++ b/schema/check.go @@ -0,0 +1,37 @@ +package schema + +import ( + "regexp" + "strings" +) + +var ( + // reg match english letters and midline + regEnLetterAndMidline = regexp.MustCompile("^[A-Za-z-_]+$") +) + +type Check struct { + Name string + Constraint string // length(phone) >= 10 + *Field +} + +// ParseCheckConstraints parse schema check constraints +func (schema *Schema) ParseCheckConstraints() map[string]Check { + var checks = map[string]Check{} + for _, field := range schema.FieldsByDBName { + if chk := field.TagSettings["CHECK"]; chk != "" { + names := strings.Split(chk, ",") + if len(names) > 1 && 
regEnLetterAndMidline.MatchString(names[0]) { + checks[names[0]] = Check{Name: names[0], Constraint: strings.Join(names[1:], ","), Field: field} + } else { + if names[0] == "" { + chk = strings.Join(names[1:], ",") + } + name := schema.namer.CheckerName(schema.Table, field.DBName) + checks[name] = Check{Name: name, Constraint: chk, Field: field} + } + } + } + return checks +} diff --git a/schema/check_test.go b/schema/check_test.go new file mode 100644 index 00000000..eda043b7 --- /dev/null +++ b/schema/check_test.go @@ -0,0 +1,55 @@ +package schema_test + +import ( + "reflect" + "sync" + "testing" + + "gorm.io/gorm/schema" +) + +type UserCheck struct { + Name string `gorm:"check:name_checker,name <> 'jinzhu'"` + Name2 string `gorm:"check:name <> 'jinzhu'"` + Name3 string `gorm:"check:,name <> 'jinzhu'"` +} + +func TestParseCheck(t *testing.T) { + user, err := schema.Parse(&UserCheck{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("failed to parse user check, got error %v", err) + } + + results := map[string]schema.Check{ + "name_checker": { + Name: "name_checker", + Constraint: "name <> 'jinzhu'", + }, + "chk_user_checks_name2": { + Name: "chk_user_checks_name2", + Constraint: "name <> 'jinzhu'", + }, + "chk_user_checks_name3": { + Name: "chk_user_checks_name3", + Constraint: "name <> 'jinzhu'", + }, + } + + checks := user.ParseCheckConstraints() + + for k, result := range results { + v, ok := checks[k] + if !ok { + t.Errorf("Failed to found check %v from parsed checks %+v", k, checks) + } + + for _, name := range []string{"Name", "Constraint"} { + if reflect.ValueOf(result).FieldByName(name).Interface() != reflect.ValueOf(v).FieldByName(name).Interface() { + t.Errorf( + "check %v %v should equal, expects %v, got %v", + k, name, reflect.ValueOf(result).FieldByName(name).Interface(), reflect.ValueOf(v).FieldByName(name).Interface(), + ) + } + } + } +} diff --git a/schema/field.go b/schema/field.go new file mode 100644 index 00000000..5dbc96f1 --- /dev/null +++ b/schema/field.go @@ -0,0 +1,834 @@ +package schema + +import ( + "database/sql" + "database/sql/driver" + "fmt" + "reflect" + "strconv" + "strings" + "sync" + "time" + + "github.com/jinzhu/now" + "gorm.io/gorm/utils" +) + +type DataType string + +type TimeType int64 + +var TimeReflectType = reflect.TypeOf(time.Time{}) + +const ( + UnixSecond TimeType = 1 + UnixMillisecond TimeType = 2 + UnixNanosecond TimeType = 3 +) + +const ( + Bool DataType = "bool" + Int DataType = "int" + Uint DataType = "uint" + Float DataType = "float" + String DataType = "string" + Time DataType = "time" + Bytes DataType = "bytes" +) + +type Field struct { + Name string + DBName string + BindNames []string + DataType DataType + GORMDataType DataType + PrimaryKey bool + AutoIncrement bool + AutoIncrementIncrement int64 + Creatable bool + Updatable bool + Readable bool + HasDefaultValue bool + AutoCreateTime TimeType + AutoUpdateTime TimeType + DefaultValue string + DefaultValueInterface interface{} + NotNull bool + Unique bool + Comment string + Size int + Precision int + Scale int + FieldType reflect.Type + IndirectFieldType reflect.Type + StructField reflect.StructField + Tag reflect.StructTag + TagSettings map[string]string + Schema *Schema + EmbeddedSchema *Schema + OwnerSchema *Schema + ReflectValueOf func(reflect.Value) reflect.Value + ValueOf func(reflect.Value) (value interface{}, zero bool) + Set func(reflect.Value, interface{}) error + IgnoreMigration bool +} + +func (schema *Schema) ParseField(fieldStruct reflect.StructField) 
*Field { + var err error + + field := &Field{ + Name: fieldStruct.Name, + BindNames: []string{fieldStruct.Name}, + FieldType: fieldStruct.Type, + IndirectFieldType: fieldStruct.Type, + StructField: fieldStruct, + Creatable: true, + Updatable: true, + Readable: true, + Tag: fieldStruct.Tag, + TagSettings: ParseTagSetting(fieldStruct.Tag.Get("gorm"), ";"), + Schema: schema, + AutoIncrementIncrement: 1, + } + + for field.IndirectFieldType.Kind() == reflect.Ptr { + field.IndirectFieldType = field.IndirectFieldType.Elem() + } + + fieldValue := reflect.New(field.IndirectFieldType) + // if field is valuer, used its value or first fields as data type + valuer, isValuer := fieldValue.Interface().(driver.Valuer) + if isValuer { + if _, ok := fieldValue.Interface().(GormDataTypeInterface); !ok { + if v, err := valuer.Value(); reflect.ValueOf(v).IsValid() && err == nil { + fieldValue = reflect.ValueOf(v) + } + + var getRealFieldValue func(reflect.Value) + getRealFieldValue = func(v reflect.Value) { + rv := reflect.Indirect(v) + if rv.Kind() == reflect.Struct && !rv.Type().ConvertibleTo(TimeReflectType) { + for i := 0; i < rv.Type().NumField(); i++ { + newFieldType := rv.Type().Field(i).Type + for newFieldType.Kind() == reflect.Ptr { + newFieldType = newFieldType.Elem() + } + + fieldValue = reflect.New(newFieldType) + + if rv.Type() != reflect.Indirect(fieldValue).Type() { + getRealFieldValue(fieldValue) + } + + if fieldValue.IsValid() { + return + } + + for key, value := range ParseTagSetting(field.IndirectFieldType.Field(i).Tag.Get("gorm"), ";") { + if _, ok := field.TagSettings[key]; !ok { + field.TagSettings[key] = value + } + } + } + } + } + + getRealFieldValue(fieldValue) + } + } + + if dbName, ok := field.TagSettings["COLUMN"]; ok { + field.DBName = dbName + } + + if val, ok := field.TagSettings["PRIMARYKEY"]; ok && utils.CheckTruth(val) { + field.PrimaryKey = true + } else if val, ok := field.TagSettings["PRIMARY_KEY"]; ok && utils.CheckTruth(val) { + field.PrimaryKey = true + } + + if val, ok := field.TagSettings["AUTOINCREMENT"]; ok && utils.CheckTruth(val) { + field.AutoIncrement = true + field.HasDefaultValue = true + } + + if num, ok := field.TagSettings["AUTOINCREMENTINCREMENT"]; ok { + field.AutoIncrementIncrement, _ = strconv.ParseInt(num, 10, 64) + } + + if v, ok := field.TagSettings["DEFAULT"]; ok { + field.HasDefaultValue = true + field.DefaultValue = v + } + + if num, ok := field.TagSettings["SIZE"]; ok { + if field.Size, err = strconv.Atoi(num); err != nil { + field.Size = -1 + } + } + + if p, ok := field.TagSettings["PRECISION"]; ok { + field.Precision, _ = strconv.Atoi(p) + } + + if s, ok := field.TagSettings["SCALE"]; ok { + field.Scale, _ = strconv.Atoi(s) + } + + if val, ok := field.TagSettings["NOT NULL"]; ok && utils.CheckTruth(val) { + field.NotNull = true + } else if val, ok := field.TagSettings["NOTNULL"]; ok && utils.CheckTruth(val) { + field.NotNull = true + } + + if val, ok := field.TagSettings["UNIQUE"]; ok && utils.CheckTruth(val) { + field.Unique = true + } + + if val, ok := field.TagSettings["COMMENT"]; ok { + field.Comment = val + } + + // default value is function or null or blank (primary keys) + field.DefaultValue = strings.TrimSpace(field.DefaultValue) + skipParseDefaultValue := strings.Contains(field.DefaultValue, "(") && + strings.Contains(field.DefaultValue, ")") || strings.ToLower(field.DefaultValue) == "null" || field.DefaultValue == "" + switch reflect.Indirect(fieldValue).Kind() { + case reflect.Bool: + field.DataType = Bool + if field.HasDefaultValue 
&& !skipParseDefaultValue { + if field.DefaultValueInterface, err = strconv.ParseBool(field.DefaultValue); err != nil { + schema.err = fmt.Errorf("failed to parse %v as default value for bool, got error: %v", field.DefaultValue, err) + } + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + field.DataType = Int + if field.HasDefaultValue && !skipParseDefaultValue { + if field.DefaultValueInterface, err = strconv.ParseInt(field.DefaultValue, 0, 64); err != nil { + schema.err = fmt.Errorf("failed to parse %v as default value for int, got error: %v", field.DefaultValue, err) + } + } + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + field.DataType = Uint + if field.HasDefaultValue && !skipParseDefaultValue { + if field.DefaultValueInterface, err = strconv.ParseUint(field.DefaultValue, 0, 64); err != nil { + schema.err = fmt.Errorf("failed to parse %v as default value for uint, got error: %v", field.DefaultValue, err) + } + } + case reflect.Float32, reflect.Float64: + field.DataType = Float + if field.HasDefaultValue && !skipParseDefaultValue { + if field.DefaultValueInterface, err = strconv.ParseFloat(field.DefaultValue, 64); err != nil { + schema.err = fmt.Errorf("failed to parse %v as default value for float, got error: %v", field.DefaultValue, err) + } + } + case reflect.String: + field.DataType = String + + if field.HasDefaultValue && !skipParseDefaultValue { + field.DefaultValue = strings.Trim(field.DefaultValue, "'") + field.DefaultValue = strings.Trim(field.DefaultValue, "\"") + field.DefaultValueInterface = field.DefaultValue + } + case reflect.Struct: + if _, ok := fieldValue.Interface().(*time.Time); ok { + field.DataType = Time + } else if fieldValue.Type().ConvertibleTo(TimeReflectType) { + field.DataType = Time + } else if fieldValue.Type().ConvertibleTo(reflect.TypeOf(&time.Time{})) { + field.DataType = Time + } + case reflect.Array, reflect.Slice: + if reflect.Indirect(fieldValue).Type().Elem() == reflect.TypeOf(uint8(0)) { + field.DataType = Bytes + } + } + + field.GORMDataType = field.DataType + + if dataTyper, ok := fieldValue.Interface().(GormDataTypeInterface); ok { + field.DataType = DataType(dataTyper.GormDataType()) + } + + if v, ok := field.TagSettings["AUTOCREATETIME"]; ok || (field.Name == "CreatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) { + if strings.ToUpper(v) == "NANO" { + field.AutoCreateTime = UnixNanosecond + } else if strings.ToUpper(v) == "MILLI" { + field.AutoCreateTime = UnixMillisecond + } else { + field.AutoCreateTime = UnixSecond + } + } + + if v, ok := field.TagSettings["AUTOUPDATETIME"]; ok || (field.Name == "UpdatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) { + if strings.ToUpper(v) == "NANO" { + field.AutoUpdateTime = UnixNanosecond + } else if strings.ToUpper(v) == "MILLI" { + field.AutoUpdateTime = UnixMillisecond + } else { + field.AutoUpdateTime = UnixSecond + } + } + + if val, ok := field.TagSettings["TYPE"]; ok { + switch DataType(strings.ToLower(val)) { + case Bool, Int, Uint, Float, String, Time, Bytes: + field.DataType = DataType(strings.ToLower(val)) + default: + field.DataType = DataType(val) + } + } + + if field.GORMDataType == "" { + field.GORMDataType = field.DataType + } + + if field.Size == 0 { + switch reflect.Indirect(fieldValue).Kind() { + case reflect.Int, reflect.Int64, reflect.Uint, reflect.Uint64, reflect.Float64: + field.Size = 64 + case reflect.Int8, reflect.Uint8: + field.Size 
= 8 + case reflect.Int16, reflect.Uint16: + field.Size = 16 + case reflect.Int32, reflect.Uint32, reflect.Float32: + field.Size = 32 + } + } + + // setup permission + if val, ok := field.TagSettings["-"]; ok { + val = strings.ToLower(strings.TrimSpace(val)) + switch val { + case "-": + field.Creatable = false + field.Updatable = false + field.Readable = false + field.DataType = "" + case "all": + field.Creatable = false + field.Updatable = false + field.Readable = false + field.DataType = "" + field.IgnoreMigration = true + case "migration": + field.IgnoreMigration = true + } + } + + if v, ok := field.TagSettings["->"]; ok { + field.Creatable = false + field.Updatable = false + if strings.ToLower(v) == "false" { + field.Readable = false + } else { + field.Readable = true + } + } + + if v, ok := field.TagSettings["<-"]; ok { + field.Creatable = true + field.Updatable = true + + if v != "<-" { + if !strings.Contains(v, "create") { + field.Creatable = false + } + + if !strings.Contains(v, "update") { + field.Updatable = false + } + } + } + + if _, ok := field.TagSettings["EMBEDDED"]; ok || (fieldStruct.Anonymous && !isValuer && (field.Creatable || field.Updatable || field.Readable)) { + if reflect.Indirect(fieldValue).Kind() == reflect.Struct { + var err error + field.Creatable = false + field.Updatable = false + field.Readable = false + + cacheStore := &sync.Map{} + cacheStore.Store(embeddedCacheKey, true) + if field.EmbeddedSchema, err = getOrParse(fieldValue.Interface(), cacheStore, embeddedNamer{Table: schema.Table, Namer: schema.namer}); err != nil { + schema.err = err + } + + for _, ef := range field.EmbeddedSchema.Fields { + ef.Schema = schema + ef.OwnerSchema = field.EmbeddedSchema + ef.BindNames = append([]string{fieldStruct.Name}, ef.BindNames...) + // index is negative means is pointer + if field.FieldType.Kind() == reflect.Struct { + ef.StructField.Index = append([]int{fieldStruct.Index[0]}, ef.StructField.Index...) + } else { + ef.StructField.Index = append([]int{-fieldStruct.Index[0] - 1}, ef.StructField.Index...) 
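A brief aside, not part of the change itself: the `append([]int{-fieldStruct.Index[0] - 1}, ...)` line just above encodes "this hop goes through a pointer embedded struct" by storing a negative index, and the valuer/setter closures later decode it with `-idx - 1`, allocating the pointer on demand. A minimal standalone sketch of that convention, using made-up `Inner`/`Outer` types rather than anything from the GORM source:

```go
package main

import (
	"fmt"
	"reflect"
)

// Inner is reachable only through a pointer on Outer, so touching its fields
// may require allocating Outer.Inner first.
type Inner struct{ Name string }

type Outer struct{ Inner *Inner }

// encodePtrIndex mirrors the convention above: a pointer hop is stored as a
// negative index so that -encoded-1 recovers the real struct field index.
func encodePtrIndex(i int) int { return -i - 1 }

func main() {
	outer := Outer{}
	v := reflect.Indirect(reflect.ValueOf(&outer))

	// Walk the encoded index path Outer.Inner -> Inner.Name.
	path := []int{encodePtrIndex(0), 0}
	for _, idx := range path {
		if idx < 0 {
			v = v.Field(-idx - 1) // decode the pointer hop
			if v.IsNil() {
				v.Set(reflect.New(v.Type().Elem())) // allocate, as ReflectValueOf does
			}
			v = v.Elem()
		} else {
			v = v.Field(idx)
		}
	}

	v.SetString("hello")
	fmt.Println(outer.Inner.Name) // hello
}
```

Folding the pointer hop into the index slice presumably keeps the hot `ValueOf`/`ReflectValueOf` closures simple: they only walk integers, with the sign telling them when a dereference (and possible allocation) is needed.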
+ } + + if prefix, ok := field.TagSettings["EMBEDDEDPREFIX"]; ok && ef.DBName != "" { + ef.DBName = prefix + ef.DBName + } + + if ef.PrimaryKey { + if val, ok := ef.TagSettings["PRIMARYKEY"]; ok && utils.CheckTruth(val) { + ef.PrimaryKey = true + } else if val, ok := ef.TagSettings["PRIMARY_KEY"]; ok && utils.CheckTruth(val) { + ef.PrimaryKey = true + } else { + ef.PrimaryKey = false + + if val, ok := ef.TagSettings["AUTOINCREMENT"]; !ok || !utils.CheckTruth(val) { + ef.AutoIncrement = false + } + + if ef.DefaultValue == "" { + ef.HasDefaultValue = false + } + } + } + + for k, v := range field.TagSettings { + ef.TagSettings[k] = v + } + } + } else { + schema.err = fmt.Errorf("invalid embedded struct for %v's field %v, should be struct, but got %v", field.Schema.Name, field.Name, field.FieldType) + } + } + + return field +} + +// create valuer, setter when parse struct +func (field *Field) setupValuerAndSetter() { + // ValueOf + switch { + case len(field.StructField.Index) == 1: + field.ValueOf = func(value reflect.Value) (interface{}, bool) { + fieldValue := reflect.Indirect(value).Field(field.StructField.Index[0]) + return fieldValue.Interface(), fieldValue.IsZero() + } + case len(field.StructField.Index) == 2 && field.StructField.Index[0] >= 0: + field.ValueOf = func(value reflect.Value) (interface{}, bool) { + fieldValue := reflect.Indirect(value).Field(field.StructField.Index[0]).Field(field.StructField.Index[1]) + return fieldValue.Interface(), fieldValue.IsZero() + } + default: + field.ValueOf = func(value reflect.Value) (interface{}, bool) { + v := reflect.Indirect(value) + + for _, idx := range field.StructField.Index { + if idx >= 0 { + v = v.Field(idx) + } else { + v = v.Field(-idx - 1) + + if v.Type().Elem().Kind() == reflect.Struct { + if !v.IsNil() { + v = v.Elem() + } else { + return nil, true + } + } else { + return nil, true + } + } + } + return v.Interface(), v.IsZero() + } + } + + // ReflectValueOf + switch { + case len(field.StructField.Index) == 1: + field.ReflectValueOf = func(value reflect.Value) reflect.Value { + return reflect.Indirect(value).Field(field.StructField.Index[0]) + } + case len(field.StructField.Index) == 2 && field.StructField.Index[0] >= 0 && field.FieldType.Kind() != reflect.Ptr: + field.ReflectValueOf = func(value reflect.Value) reflect.Value { + return reflect.Indirect(value).Field(field.StructField.Index[0]).Field(field.StructField.Index[1]) + } + default: + field.ReflectValueOf = func(value reflect.Value) reflect.Value { + v := reflect.Indirect(value) + for idx, fieldIdx := range field.StructField.Index { + if fieldIdx >= 0 { + v = v.Field(fieldIdx) + } else { + v = v.Field(-fieldIdx - 1) + } + + if v.Kind() == reflect.Ptr { + if v.Type().Elem().Kind() == reflect.Struct { + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) + } + } + + if idx < len(field.StructField.Index)-1 { + v = v.Elem() + } + } + } + return v + } + } + + fallbackSetter := func(value reflect.Value, v interface{}, setter func(reflect.Value, interface{}) error) (err error) { + if v == nil { + field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) + } else { + reflectV := reflect.ValueOf(v) + // Optimal value type acquisition for v + reflectValType := reflectV.Type() + + if reflectValType.AssignableTo(field.FieldType) { + field.ReflectValueOf(value).Set(reflectV) + return + } else if reflectValType.ConvertibleTo(field.FieldType) { + field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType)) + return + } else if field.FieldType.Kind() == reflect.Ptr { + 
fieldValue := field.ReflectValueOf(value) + + if reflectValType.AssignableTo(field.FieldType.Elem()) { + if !fieldValue.IsValid() { + fieldValue = reflect.New(field.FieldType.Elem()) + } else if fieldValue.IsNil() { + fieldValue.Set(reflect.New(field.FieldType.Elem())) + } + fieldValue.Elem().Set(reflectV) + return + } else if reflectValType.ConvertibleTo(field.FieldType.Elem()) { + if fieldValue.IsNil() { + fieldValue.Set(reflect.New(field.FieldType.Elem())) + } + + fieldValue.Elem().Set(reflectV.Convert(field.FieldType.Elem())) + return + } + } + + if reflectV.Kind() == reflect.Ptr { + if reflectV.IsNil() { + field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) + } else { + err = setter(value, reflectV.Elem().Interface()) + } + } else if valuer, ok := v.(driver.Valuer); ok { + if v, err = valuer.Value(); err == nil { + err = setter(value, v) + } + } else { + return fmt.Errorf("failed to set value %+v to field %v", v, field.Name) + } + } + + return + } + + // Set + switch field.FieldType.Kind() { + case reflect.Bool: + field.Set = func(value reflect.Value, v interface{}) error { + switch data := v.(type) { + case bool: + field.ReflectValueOf(value).SetBool(data) + case *bool: + if data != nil { + field.ReflectValueOf(value).SetBool(*data) + } else { + field.ReflectValueOf(value).SetBool(false) + } + case int64: + if data > 0 { + field.ReflectValueOf(value).SetBool(true) + } else { + field.ReflectValueOf(value).SetBool(false) + } + case string: + b, _ := strconv.ParseBool(data) + field.ReflectValueOf(value).SetBool(b) + default: + return fallbackSetter(value, v, field.Set) + } + return nil + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + field.Set = func(value reflect.Value, v interface{}) (err error) { + switch data := v.(type) { + case int64: + field.ReflectValueOf(value).SetInt(data) + case int: + field.ReflectValueOf(value).SetInt(int64(data)) + case int8: + field.ReflectValueOf(value).SetInt(int64(data)) + case int16: + field.ReflectValueOf(value).SetInt(int64(data)) + case int32: + field.ReflectValueOf(value).SetInt(int64(data)) + case uint: + field.ReflectValueOf(value).SetInt(int64(data)) + case uint8: + field.ReflectValueOf(value).SetInt(int64(data)) + case uint16: + field.ReflectValueOf(value).SetInt(int64(data)) + case uint32: + field.ReflectValueOf(value).SetInt(int64(data)) + case uint64: + field.ReflectValueOf(value).SetInt(int64(data)) + case float32: + field.ReflectValueOf(value).SetInt(int64(data)) + case float64: + field.ReflectValueOf(value).SetInt(int64(data)) + case []byte: + return field.Set(value, string(data)) + case string: + if i, err := strconv.ParseInt(data, 0, 64); err == nil { + field.ReflectValueOf(value).SetInt(i) + } else { + return err + } + case time.Time: + if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond { + field.ReflectValueOf(value).SetInt(data.UnixNano()) + } else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond { + field.ReflectValueOf(value).SetInt(data.UnixNano() / 1e6) + } else { + field.ReflectValueOf(value).SetInt(data.Unix()) + } + case *time.Time: + if data != nil { + if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond { + field.ReflectValueOf(value).SetInt(data.UnixNano()) + } else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond { + field.ReflectValueOf(value).SetInt(data.UnixNano() / 1e6) + } else { + field.ReflectValueOf(value).SetInt(data.Unix()) 
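As a quick illustration (not part of the diff) of the three resolutions the integer setter above switches between — seconds by default, milliseconds for `autoCreateTime:milli`, nanoseconds for `autoCreateTime:nano` — here is what the conversions produce for an arbitrary timestamp:

```go
package main

import (
	"fmt"
	"time"
)

func main() {
	// The same expressions the setter uses, applied to a fixed example time.
	t := time.Date(2020, 1, 2, 3, 4, 5, 600000000, time.UTC)

	fmt.Println(t.Unix())           // seconds:      1577934245
	fmt.Println(t.UnixNano() / 1e6) // milliseconds: 1577934245600
	fmt.Println(t.UnixNano())       // nanoseconds:  1577934245600000000
}
```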
+ } + } else { + field.ReflectValueOf(value).SetInt(0) + } + default: + return fallbackSetter(value, v, field.Set) + } + return err + } + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + field.Set = func(value reflect.Value, v interface{}) (err error) { + switch data := v.(type) { + case uint64: + field.ReflectValueOf(value).SetUint(data) + case uint: + field.ReflectValueOf(value).SetUint(uint64(data)) + case uint8: + field.ReflectValueOf(value).SetUint(uint64(data)) + case uint16: + field.ReflectValueOf(value).SetUint(uint64(data)) + case uint32: + field.ReflectValueOf(value).SetUint(uint64(data)) + case int64: + field.ReflectValueOf(value).SetUint(uint64(data)) + case int: + field.ReflectValueOf(value).SetUint(uint64(data)) + case int8: + field.ReflectValueOf(value).SetUint(uint64(data)) + case int16: + field.ReflectValueOf(value).SetUint(uint64(data)) + case int32: + field.ReflectValueOf(value).SetUint(uint64(data)) + case float32: + field.ReflectValueOf(value).SetUint(uint64(data)) + case float64: + field.ReflectValueOf(value).SetUint(uint64(data)) + case []byte: + return field.Set(value, string(data)) + case time.Time: + if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond { + field.ReflectValueOf(value).SetUint(uint64(data.UnixNano())) + } else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond { + field.ReflectValueOf(value).SetUint(uint64(data.UnixNano() / 1e6)) + } else { + field.ReflectValueOf(value).SetUint(uint64(data.Unix())) + } + case string: + if i, err := strconv.ParseUint(data, 0, 64); err == nil { + field.ReflectValueOf(value).SetUint(i) + } else { + return err + } + default: + return fallbackSetter(value, v, field.Set) + } + return err + } + case reflect.Float32, reflect.Float64: + field.Set = func(value reflect.Value, v interface{}) (err error) { + switch data := v.(type) { + case float64: + field.ReflectValueOf(value).SetFloat(data) + case float32: + field.ReflectValueOf(value).SetFloat(float64(data)) + case int64: + field.ReflectValueOf(value).SetFloat(float64(data)) + case int: + field.ReflectValueOf(value).SetFloat(float64(data)) + case int8: + field.ReflectValueOf(value).SetFloat(float64(data)) + case int16: + field.ReflectValueOf(value).SetFloat(float64(data)) + case int32: + field.ReflectValueOf(value).SetFloat(float64(data)) + case uint: + field.ReflectValueOf(value).SetFloat(float64(data)) + case uint8: + field.ReflectValueOf(value).SetFloat(float64(data)) + case uint16: + field.ReflectValueOf(value).SetFloat(float64(data)) + case uint32: + field.ReflectValueOf(value).SetFloat(float64(data)) + case uint64: + field.ReflectValueOf(value).SetFloat(float64(data)) + case []byte: + return field.Set(value, string(data)) + case string: + if i, err := strconv.ParseFloat(data, 64); err == nil { + field.ReflectValueOf(value).SetFloat(i) + } else { + return err + } + default: + return fallbackSetter(value, v, field.Set) + } + return err + } + case reflect.String: + field.Set = func(value reflect.Value, v interface{}) (err error) { + switch data := v.(type) { + case string: + field.ReflectValueOf(value).SetString(data) + case []byte: + field.ReflectValueOf(value).SetString(string(data)) + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: + field.ReflectValueOf(value).SetString(utils.ToString(data)) + case float64, float32: + field.ReflectValueOf(value).SetString(fmt.Sprintf("%."+strconv.Itoa(field.Precision)+"f", data)) + default: + return 
fallbackSetter(value, v, field.Set) + } + return err + } + default: + fieldValue := reflect.New(field.FieldType) + switch fieldValue.Elem().Interface().(type) { + case time.Time: + field.Set = func(value reflect.Value, v interface{}) error { + switch data := v.(type) { + case time.Time: + field.ReflectValueOf(value).Set(reflect.ValueOf(v)) + case *time.Time: + if data != nil { + field.ReflectValueOf(value).Set(reflect.ValueOf(data).Elem()) + } else { + field.ReflectValueOf(value).Set(reflect.ValueOf(time.Time{})) + } + case string: + if t, err := now.Parse(data); err == nil { + field.ReflectValueOf(value).Set(reflect.ValueOf(t)) + } else { + return fmt.Errorf("failed to set string %v to time.Time field %v, failed to parse it as time, got error %v", v, field.Name, err) + } + default: + return fallbackSetter(value, v, field.Set) + } + return nil + } + case *time.Time: + field.Set = func(value reflect.Value, v interface{}) error { + switch data := v.(type) { + case time.Time: + fieldValue := field.ReflectValueOf(value) + if fieldValue.IsNil() { + fieldValue.Set(reflect.New(field.FieldType.Elem())) + } + fieldValue.Elem().Set(reflect.ValueOf(v)) + case *time.Time: + field.ReflectValueOf(value).Set(reflect.ValueOf(v)) + case string: + if t, err := now.Parse(data); err == nil { + fieldValue := field.ReflectValueOf(value) + if fieldValue.IsNil() { + if v == "" { + return nil + } + fieldValue.Set(reflect.New(field.FieldType.Elem())) + } + fieldValue.Elem().Set(reflect.ValueOf(t)) + } else { + return fmt.Errorf("failed to set string %v to time.Time field %v, failed to parse it as time, got error %v", v, field.Name, err) + } + default: + return fallbackSetter(value, v, field.Set) + } + return nil + } + default: + if _, ok := fieldValue.Elem().Interface().(sql.Scanner); ok { + // pointer scanner + field.Set = func(value reflect.Value, v interface{}) (err error) { + reflectV := reflect.ValueOf(v) + if !reflectV.IsValid() { + field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) + } else if reflectV.Type().AssignableTo(field.FieldType) { + field.ReflectValueOf(value).Set(reflectV) + } else if reflectV.Kind() == reflect.Ptr { + if reflectV.IsNil() || !reflectV.IsValid() { + field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) + } else { + return field.Set(value, reflectV.Elem().Interface()) + } + } else { + fieldValue := field.ReflectValueOf(value) + if fieldValue.IsNil() { + fieldValue.Set(reflect.New(field.FieldType.Elem())) + } + + if valuer, ok := v.(driver.Valuer); ok { + v, _ = valuer.Value() + } + + err = fieldValue.Interface().(sql.Scanner).Scan(v) + } + return + } + } else if _, ok := fieldValue.Interface().(sql.Scanner); ok { + // struct scanner + field.Set = func(value reflect.Value, v interface{}) (err error) { + reflectV := reflect.ValueOf(v) + if !reflectV.IsValid() { + field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) + } else if reflectV.Type().AssignableTo(field.FieldType) { + field.ReflectValueOf(value).Set(reflectV) + } else if reflectV.Kind() == reflect.Ptr { + if reflectV.IsNil() || !reflectV.IsValid() { + field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) + } else { + return field.Set(value, reflectV.Elem().Interface()) + } + } else { + if valuer, ok := v.(driver.Valuer); ok { + v, _ = valuer.Value() + } + + err = field.ReflectValueOf(value).Addr().Interface().(sql.Scanner).Scan(v) + } + return + } + } else { + field.Set = func(value reflect.Value, v interface{}) (err error) { + return fallbackSetter(value, v, 
field.Set) + } + } + } + } +} diff --git a/schema/field_test.go b/schema/field_test.go new file mode 100644 index 00000000..64f4a909 --- /dev/null +++ b/schema/field_test.go @@ -0,0 +1,260 @@ +package schema_test + +import ( + "database/sql" + "reflect" + "sync" + "testing" + "time" + + "gorm.io/gorm" + "gorm.io/gorm/schema" + "gorm.io/gorm/utils/tests" +) + +func TestFieldValuerAndSetter(t *testing.T) { + var ( + userSchema, _ = schema.Parse(&tests.User{}, &sync.Map{}, schema.NamingStrategy{}) + user = tests.User{ + Model: gorm.Model{ + ID: 10, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + DeletedAt: gorm.DeletedAt{Time: time.Now(), Valid: true}, + }, + Name: "valuer_and_setter", + Age: 18, + Birthday: tests.Now(), + Active: true, + } + reflectValue = reflect.ValueOf(&user) + ) + + // test valuer + values := map[string]interface{}{ + "name": user.Name, + "id": user.ID, + "created_at": user.CreatedAt, + "updated_at": user.UpdatedAt, + "deleted_at": user.DeletedAt, + "age": user.Age, + "birthday": user.Birthday, + "active": true, + } + checkField(t, userSchema, reflectValue, values) + + var f *bool + // test setter + newValues := map[string]interface{}{ + "name": "valuer_and_setter_2", + "id": 2, + "created_at": time.Now(), + "updated_at": nil, + "deleted_at": time.Now(), + "age": 20, + "birthday": time.Now(), + "active": f, + } + + for k, v := range newValues { + if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil { + t.Errorf("no error should happen when assign value to field %v, but got %v", k, err) + } + } + newValues["updated_at"] = time.Time{} + newValues["active"] = false + checkField(t, userSchema, reflectValue, newValues) + + // test valuer and other type + age := myint(10) + var nilTime *time.Time + newValues2 := map[string]interface{}{ + "name": sql.NullString{String: "valuer_and_setter_3", Valid: true}, + "id": &sql.NullInt64{Int64: 3, Valid: true}, + "created_at": tests.Now(), + "updated_at": nilTime, + "deleted_at": time.Now(), + "age": &age, + "birthday": mytime(time.Now()), + "active": mybool(true), + } + + for k, v := range newValues2 { + if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil { + t.Errorf("no error should happen when assign value to field %v, but got %v", k, err) + } + } + newValues2["updated_at"] = time.Time{} + checkField(t, userSchema, reflectValue, newValues2) +} + +func TestPointerFieldValuerAndSetter(t *testing.T) { + var ( + userSchema, _ = schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{}) + name = "pointer_field_valuer_and_setter" + age uint = 18 + active = true + user = User{ + Model: &gorm.Model{ + ID: 10, + CreatedAt: time.Now(), + DeletedAt: gorm.DeletedAt{Time: time.Now(), Valid: true}, + }, + Name: &name, + Age: &age, + Birthday: tests.Now(), + Active: &active, + } + reflectValue = reflect.ValueOf(&user) + ) + + // test valuer + values := map[string]interface{}{ + "name": user.Name, + "id": user.ID, + "created_at": user.CreatedAt, + "deleted_at": user.DeletedAt, + "age": user.Age, + "birthday": user.Birthday, + "active": true, + } + checkField(t, userSchema, reflectValue, values) + + // test setter + newValues := map[string]interface{}{ + "name": "valuer_and_setter_2", + "id": 2, + "created_at": time.Now(), + "deleted_at": time.Now(), + "age": 20, + "birthday": time.Now(), + "active": false, + } + + for k, v := range newValues { + if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil { + t.Errorf("no error should happen when assign value to field %v, but got %v", k, err) + 
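The pattern these tests exercise boils down to a few lines: parse a model once, then drive a single field through `Set` and `ValueOf` with a `reflect.Value` of an instance. A small self-contained sketch — the `Product` model is invented for the example and is not part of the test suite:

```go
package main

import (
	"fmt"
	"reflect"
	"sync"

	"gorm.io/gorm/schema"
)

// Product is a throwaway model for this sketch only.
type Product struct {
	ID   uint
	Code string
}

func main() {
	s, err := schema.Parse(&Product{}, &sync.Map{}, schema.NamingStrategy{})
	if err != nil {
		panic(err)
	}

	p := Product{}
	rv := reflect.ValueOf(&p)

	// Set converts the loosely-typed input ([]byte here) to the field's kind.
	if err := s.FieldsByDBName["code"].Set(rv, []byte("L1212")); err != nil {
		panic(err)
	}

	// ValueOf reads the value back and reports whether it is the zero value.
	value, isZero := s.FieldsByDBName["code"].ValueOf(rv)
	fmt.Println(p.Code, value, isZero) // L1212 L1212 false
}
```

That tolerance for loosely-typed input is what lets the row-scanning path hand raw database values straight to the same setters.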
} + } + checkField(t, userSchema, reflectValue, newValues) + + // test valuer and other type + age2 := myint(10) + newValues2 := map[string]interface{}{ + "name": sql.NullString{String: "valuer_and_setter_3", Valid: true}, + "id": &sql.NullInt64{Int64: 3, Valid: true}, + "created_at": tests.Now(), + "deleted_at": time.Now(), + "age": &age2, + "birthday": mytime(time.Now()), + "active": mybool(true), + } + + for k, v := range newValues2 { + if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil { + t.Errorf("no error should happen when assign value to field %v, but got %v", k, err) + } + } + checkField(t, userSchema, reflectValue, newValues2) +} + +func TestAdvancedDataTypeValuerAndSetter(t *testing.T) { + var ( + userSchema, _ = schema.Parse(&AdvancedDataTypeUser{}, &sync.Map{}, schema.NamingStrategy{}) + name = "advanced_data_type_valuer_and_setter" + deletedAt = mytime(time.Now()) + isAdmin = mybool(false) + user = AdvancedDataTypeUser{ + ID: sql.NullInt64{Int64: 10, Valid: true}, + Name: &sql.NullString{String: name, Valid: true}, + Birthday: sql.NullTime{Time: time.Now(), Valid: true}, + RegisteredAt: mytime(time.Now()), + DeletedAt: &deletedAt, + Active: mybool(true), + Admin: &isAdmin, + } + reflectValue = reflect.ValueOf(&user) + ) + + // test valuer + values := map[string]interface{}{ + "id": user.ID, + "name": user.Name, + "birthday": user.Birthday, + "registered_at": user.RegisteredAt, + "deleted_at": user.DeletedAt, + "active": user.Active, + "admin": user.Admin, + } + checkField(t, userSchema, reflectValue, values) + + // test setter + newDeletedAt := mytime(time.Now()) + newIsAdmin := mybool(true) + newValues := map[string]interface{}{ + "id": sql.NullInt64{Int64: 1, Valid: true}, + "name": &sql.NullString{String: name + "rename", Valid: true}, + "birthday": time.Now(), + "registered_at": mytime(time.Now()), + "deleted_at": &newDeletedAt, + "active": mybool(false), + "admin": &newIsAdmin, + } + + for k, v := range newValues { + if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil { + t.Errorf("no error should happen when assign value to field %v, but got %v", k, err) + } + } + checkField(t, userSchema, reflectValue, newValues) + + newValues2 := map[string]interface{}{ + "id": 5, + "name": name + "rename2", + "birthday": time.Now(), + "registered_at": time.Now(), + "deleted_at": time.Now(), + "active": true, + "admin": false, + } + + for k, v := range newValues2 { + if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil { + t.Errorf("no error should happen when assign value to field %v, but got %v", k, err) + } + } + checkField(t, userSchema, reflectValue, newValues2) +} + +type UserWithPermissionControl struct { + ID uint + Name string `gorm:"-"` + Name2 string `gorm:"->"` + Name3 string `gorm:"<-"` + Name4 string `gorm:"<-:create"` + Name5 string `gorm:"<-:update"` + Name6 string `gorm:"<-:create,update"` + Name7 string `gorm:"->:false;<-:create,update"` +} + +func TestParseFieldWithPermission(t *testing.T) { + user, err := schema.Parse(&UserWithPermissionControl{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("Failed to parse user with permission, got error %v", err) + } + + fields := []schema.Field{ + {Name: "ID", DBName: "id", BindNames: []string{"ID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, Creatable: true, Updatable: true, Readable: true, HasDefaultValue: true, AutoIncrement: true}, + {Name: "Name", DBName: "", BindNames: []string{"Name"}, DataType: "", Tag: `gorm:"-"`, Creatable: 
false, Updatable: false, Readable: false}, + {Name: "Name2", DBName: "name2", BindNames: []string{"Name2"}, DataType: schema.String, Tag: `gorm:"->"`, Creatable: false, Updatable: false, Readable: true}, + {Name: "Name3", DBName: "name3", BindNames: []string{"Name3"}, DataType: schema.String, Tag: `gorm:"<-"`, Creatable: true, Updatable: true, Readable: true}, + {Name: "Name4", DBName: "name4", BindNames: []string{"Name4"}, DataType: schema.String, Tag: `gorm:"<-:create"`, Creatable: true, Updatable: false, Readable: true}, + {Name: "Name5", DBName: "name5", BindNames: []string{"Name5"}, DataType: schema.String, Tag: `gorm:"<-:update"`, Creatable: false, Updatable: true, Readable: true}, + {Name: "Name6", DBName: "name6", BindNames: []string{"Name6"}, DataType: schema.String, Tag: `gorm:"<-:create,update"`, Creatable: true, Updatable: true, Readable: true}, + {Name: "Name7", DBName: "name7", BindNames: []string{"Name7"}, DataType: schema.String, Tag: `gorm:"->:false;<-:create,update"`, Creatable: true, Updatable: true, Readable: false}, + } + + for _, f := range fields { + checkSchemaField(t, user, &f, func(f *schema.Field) {}) + } +} diff --git a/schema/index.go b/schema/index.go new file mode 100644 index 00000000..b54e08ad --- /dev/null +++ b/schema/index.go @@ -0,0 +1,141 @@ +package schema + +import ( + "sort" + "strconv" + "strings" +) + +type Index struct { + Name string + Class string // UNIQUE | FULLTEXT | SPATIAL + Type string // btree, hash, gist, spgist, gin, and brin + Where string + Comment string + Option string // WITH PARSER parser_name + Fields []IndexOption +} + +type IndexOption struct { + *Field + Expression string + Sort string // DESC, ASC + Collate string + Length int + priority int +} + +// ParseIndexes parse schema indexes +func (schema *Schema) ParseIndexes() map[string]Index { + var indexes = map[string]Index{} + + for _, field := range schema.Fields { + if field.TagSettings["INDEX"] != "" || field.TagSettings["UNIQUEINDEX"] != "" { + for _, index := range parseFieldIndexes(field) { + idx := indexes[index.Name] + idx.Name = index.Name + if idx.Class == "" { + idx.Class = index.Class + } + if idx.Type == "" { + idx.Type = index.Type + } + if idx.Where == "" { + idx.Where = index.Where + } + if idx.Comment == "" { + idx.Comment = index.Comment + } + if idx.Option == "" { + idx.Option = index.Option + } + + idx.Fields = append(idx.Fields, index.Fields...) 
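// Fields that reuse an index name are merged into one composite index and then
// (below) sorted by their tag priority, lowest first; the default priority is 10.
// A sketch of the tag shape this supports, based on index_test.go:
//
//	type UserIndex struct {
//		OID          int64  `gorm:"index:idx_id"`
//		MemberNumber string `gorm:"index:idx_id,priority:1"`
//	}
//
// Both columns end up in the single index "idx_id", with MemberNumber first.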
+ sort.Slice(idx.Fields, func(i, j int) bool { + return idx.Fields[i].priority < idx.Fields[j].priority + }) + + indexes[index.Name] = idx + } + } + } + + return indexes +} + +func (schema *Schema) LookIndex(name string) *Index { + if schema != nil { + indexes := schema.ParseIndexes() + for _, index := range indexes { + if index.Name == name { + return &index + } + + for _, field := range index.Fields { + if field.Name == name { + return &index + } + } + } + } + + return nil +} + +func parseFieldIndexes(field *Field) (indexes []Index) { + for _, value := range strings.Split(field.Tag.Get("gorm"), ";") { + if value != "" { + v := strings.Split(value, ":") + k := strings.TrimSpace(strings.ToUpper(v[0])) + if k == "INDEX" || k == "UNIQUEINDEX" { + var ( + name string + tag = strings.Join(v[1:], ":") + idx = strings.Index(tag, ",") + settings = ParseTagSetting(tag, ",") + length, _ = strconv.Atoi(settings["LENGTH"]) + ) + + if idx == -1 { + idx = len(tag) + } + + if idx != -1 { + name = tag[0:idx] + } + + if name == "" { + name = field.Schema.namer.IndexName(field.Schema.Table, field.Name) + } + + if (k == "UNIQUEINDEX") || settings["UNIQUE"] != "" { + settings["CLASS"] = "UNIQUE" + } + + priority, err := strconv.Atoi(settings["PRIORITY"]) + if err != nil { + priority = 10 + } + + indexes = append(indexes, Index{ + Name: name, + Class: settings["CLASS"], + Type: settings["TYPE"], + Where: settings["WHERE"], + Comment: settings["COMMENT"], + Option: settings["OPTION"], + Fields: []IndexOption{{ + Field: field, + Expression: settings["EXPRESSION"], + Sort: settings["SORT"], + Collate: settings["COLLATE"], + Length: length, + priority: priority, + }}, + }) + } + } + } + + return +} diff --git a/schema/index_test.go b/schema/index_test.go new file mode 100644 index 00000000..bc6bb8b6 --- /dev/null +++ b/schema/index_test.go @@ -0,0 +1,116 @@ +package schema_test + +import ( + "reflect" + "sync" + "testing" + + "gorm.io/gorm/schema" +) + +type UserIndex struct { + Name string `gorm:"index"` + Name2 string `gorm:"index:idx_name,unique"` + Name3 string `gorm:"index:,sort:desc,collate:utf8,type:btree,length:10,where:name3 != 'jinzhu'"` + Name4 string `gorm:"uniqueIndex"` + Name5 int64 `gorm:"index:,class:FULLTEXT,comment:hello \\, world,where:age > 10"` + Name6 int64 `gorm:"index:profile,comment:hello \\, world,where:age > 10"` + Age int64 `gorm:"index:profile,expression:ABS(age),option:WITH PARSER parser_name"` + OID int64 `gorm:"index:idx_id;index:idx_oid,unique"` + MemberNumber string `gorm:"index:idx_id,priority:1"` +} + +func TestParseIndex(t *testing.T) { + user, err := schema.Parse(&UserIndex{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("failed to parse user index, got error %v", err) + } + + results := map[string]schema.Index{ + "idx_user_indices_name": { + Name: "idx_user_indices_name", + Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name"}}}, + }, + "idx_name": { + Name: "idx_name", + Class: "UNIQUE", + Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name2"}}}, + }, + "idx_user_indices_name3": { + Name: "idx_user_indices_name3", + Type: "btree", + Where: "name3 != 'jinzhu'", + Fields: []schema.IndexOption{{ + Field: &schema.Field{Name: "Name3"}, + Sort: "desc", + Collate: "utf8", + Length: 10, + }}, + }, + "idx_user_indices_name4": { + Name: "idx_user_indices_name4", + Class: "UNIQUE", + Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name4"}}}, + }, + "idx_user_indices_name5": { + Name: "idx_user_indices_name5", + Class: "FULLTEXT", + 
Comment: "hello , world", + Where: "age > 10", + Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name5"}}}, + }, + "profile": { + Name: "profile", + Comment: "hello , world", + Where: "age > 10", + Option: "WITH PARSER parser_name", + Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name6"}}, { + Field: &schema.Field{Name: "Age"}, + Expression: "ABS(age)", + }}, + }, + "idx_id": { + Name: "idx_id", + Fields: []schema.IndexOption{{Field: &schema.Field{Name: "MemberNumber"}}, {Field: &schema.Field{Name: "OID"}}}, + }, + "idx_oid": { + Name: "idx_oid", + Class: "UNIQUE", + Fields: []schema.IndexOption{{Field: &schema.Field{Name: "OID"}}}, + }, + } + + indices := user.ParseIndexes() + + for k, result := range results { + v, ok := indices[k] + if !ok { + t.Fatalf("Failed to found index %v from parsed indices %+v", k, indices) + } + + for _, name := range []string{"Name", "Class", "Type", "Where", "Comment", "Option"} { + if reflect.ValueOf(result).FieldByName(name).Interface() != reflect.ValueOf(v).FieldByName(name).Interface() { + t.Errorf( + "index %v %v should equal, expects %v, got %v", + k, name, reflect.ValueOf(result).FieldByName(name).Interface(), reflect.ValueOf(v).FieldByName(name).Interface(), + ) + } + } + + for idx, ef := range result.Fields { + rf := v.Fields[idx] + if rf.Field.Name != ef.Field.Name { + t.Fatalf("index field should equal, expects %v, got %v", rf.Field.Name, ef.Field.Name) + } + + for _, name := range []string{"Expression", "Sort", "Collate", "Length"} { + if reflect.ValueOf(ef).FieldByName(name).Interface() != reflect.ValueOf(rf).FieldByName(name).Interface() { + t.Errorf( + "index %v field #%v's %v should equal, expects %v, got %v", k, idx+1, name, + reflect.ValueOf(ef).FieldByName(name).Interface(), reflect.ValueOf(rf).FieldByName(name).Interface(), + ) + } + } + } + } +} diff --git a/schema/interfaces.go b/schema/interfaces.go new file mode 100644 index 00000000..98abffbd --- /dev/null +++ b/schema/interfaces.go @@ -0,0 +1,25 @@ +package schema + +import ( + "gorm.io/gorm/clause" +) + +type GormDataTypeInterface interface { + GormDataType() string +} + +type CreateClausesInterface interface { + CreateClauses(*Field) []clause.Interface +} + +type QueryClausesInterface interface { + QueryClauses(*Field) []clause.Interface +} + +type UpdateClausesInterface interface { + UpdateClauses(*Field) []clause.Interface +} + +type DeleteClausesInterface interface { + DeleteClauses(*Field) []clause.Interface +} diff --git a/schema/model_test.go b/schema/model_test.go new file mode 100644 index 00000000..1f2b0948 --- /dev/null +++ b/schema/model_test.go @@ -0,0 +1,62 @@ +package schema_test + +import ( + "database/sql" + "time" + + "gorm.io/gorm" + "gorm.io/gorm/utils/tests" +) + +type User struct { + *gorm.Model + Name *string + Age *uint + Birthday *time.Time + Account *tests.Account + Pets []*tests.Pet + Toys []*tests.Toy `gorm:"polymorphic:Owner"` + CompanyID *int + Company *tests.Company + ManagerID *uint + Manager *User + Team []*User `gorm:"foreignkey:ManagerID"` + Languages []*tests.Language `gorm:"many2many:UserSpeak"` + Friends []*User `gorm:"many2many:user_friends"` + Active *bool +} + +type mytime time.Time +type myint int +type mybool = bool + +type AdvancedDataTypeUser struct { + ID sql.NullInt64 + Name *sql.NullString + Birthday sql.NullTime + RegisteredAt mytime + DeletedAt *mytime + Active mybool + Admin *mybool +} + +type BaseModel struct { + ID uint + CreatedAt time.Time + CreatedBy *int + Created *VersionUser 
`gorm:"foreignKey:CreatedBy"` + UpdatedAt time.Time + DeletedAt gorm.DeletedAt `gorm:"index"` +} + +type VersionModel struct { + BaseModel + Version int +} + +type VersionUser struct { + VersionModel + Name string + Age uint + Birthday *time.Time +} diff --git a/schema/naming.go b/schema/naming.go new file mode 100644 index 00000000..d53942e4 --- /dev/null +++ b/schema/naming.go @@ -0,0 +1,154 @@ +package schema + +import ( + "crypto/sha1" + "encoding/hex" + "fmt" + "strings" + "unicode/utf8" + + "github.com/jinzhu/inflection" +) + +// Namer namer interface +type Namer interface { + TableName(table string) string + ColumnName(table, column string) string + JoinTableName(joinTable string) string + RelationshipFKName(Relationship) string + CheckerName(table, column string) string + IndexName(table, column string) string +} + +// Replacer replacer interface like strings.Replacer +type Replacer interface { + Replace(name string) string +} + +// NamingStrategy tables, columns naming strategy +type NamingStrategy struct { + TablePrefix string + SingularTable bool + NameReplacer Replacer + NoLowerCase bool +} + +// TableName convert string to table name +func (ns NamingStrategy) TableName(str string) string { + if ns.SingularTable { + return ns.TablePrefix + ns.toDBName(str) + } + return ns.TablePrefix + inflection.Plural(ns.toDBName(str)) +} + +// ColumnName convert string to column name +func (ns NamingStrategy) ColumnName(table, column string) string { + return ns.toDBName(column) +} + +// JoinTableName convert string to join table name +func (ns NamingStrategy) JoinTableName(str string) string { + if !ns.NoLowerCase && strings.ToLower(str) == str { + return ns.TablePrefix + str + } + + if ns.SingularTable { + return ns.TablePrefix + ns.toDBName(str) + } + return ns.TablePrefix + inflection.Plural(ns.toDBName(str)) +} + +// RelationshipFKName generate fk name for relation +func (ns NamingStrategy) RelationshipFKName(rel Relationship) string { + return ns.formatName("fk", rel.Schema.Table, ns.toDBName(rel.Name)) +} + +// CheckerName generate checker name +func (ns NamingStrategy) CheckerName(table, column string) string { + return ns.formatName("chk", table, column) +} + +// IndexName generate index name +func (ns NamingStrategy) IndexName(table, column string) string { + return ns.formatName("idx", table, ns.toDBName(column)) +} + +func (ns NamingStrategy) formatName(prefix, table, name string) string { + formattedName := strings.Replace(fmt.Sprintf("%v_%v_%v", prefix, table, name), ".", "_", -1) + + if utf8.RuneCountInString(formattedName) > 64 { + h := sha1.New() + h.Write([]byte(formattedName)) + bs := h.Sum(nil) + + formattedName = fmt.Sprintf("%v%v%v", prefix, table, name)[0:56] + hex.EncodeToString(bs)[:8] + } + return formattedName +} + +var ( + // https://github.com/golang/lint/blob/master/lint.go#L770 + commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UID", "UI", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"} + commonInitialismsReplacer *strings.Replacer +) + +func init() { + commonInitialismsForReplacer := make([]string, 0, len(commonInitialisms)) + for _, initialism := range commonInitialisms { + commonInitialismsForReplacer = append(commonInitialismsForReplacer, initialism, strings.Title(strings.ToLower(initialism))) + } + commonInitialismsReplacer = strings.NewReplacer(commonInitialismsForReplacer...) 
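// The replacer built above rewrites common initialisms ("ID", "HTTP", "URL", ...)
// into single capitalized words so that toDBName, below, can snake_case them
// predictably. Expected conversions, taken from naming_test.go:
//
//	"EmployeeID"  -> "employee_id"
//	"HTTPAndSMTP" -> "http_and_smtp"
//	"HTTPURL"     -> "http_url"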
+} + +func (ns NamingStrategy) toDBName(name string) string { + if name == "" { + return "" + } + + if ns.NameReplacer != nil { + name = ns.NameReplacer.Replace(name) + } + + if ns.NoLowerCase { + return name + } + + var ( + value = commonInitialismsReplacer.Replace(name) + buf strings.Builder + lastCase, nextCase, nextNumber bool // upper case == true + curCase = value[0] <= 'Z' && value[0] >= 'A' + ) + + for i, v := range value[:len(value)-1] { + nextCase = value[i+1] <= 'Z' && value[i+1] >= 'A' + nextNumber = value[i+1] >= '0' && value[i+1] <= '9' + + if curCase { + if lastCase && (nextCase || nextNumber) { + buf.WriteRune(v + 32) + } else { + if i > 0 && value[i-1] != '_' && value[i+1] != '_' { + buf.WriteByte('_') + } + buf.WriteRune(v + 32) + } + } else { + buf.WriteRune(v) + } + + lastCase = curCase + curCase = nextCase + } + + if curCase { + if !lastCase && len(value) > 1 { + buf.WriteByte('_') + } + buf.WriteByte(value[len(value)-1] + 32) + } else { + buf.WriteByte(value[len(value)-1]) + } + ret := buf.String() + return ret +} diff --git a/schema/naming_test.go b/schema/naming_test.go new file mode 100644 index 00000000..face9364 --- /dev/null +++ b/schema/naming_test.go @@ -0,0 +1,179 @@ +package schema + +import ( + "strings" + "testing" +) + +func TestToDBName(t *testing.T) { + var maps = map[string]string{ + "": "", + "x": "x", + "X": "x", + "userRestrictions": "user_restrictions", + "ThisIsATest": "this_is_a_test", + "PFAndESI": "pf_and_esi", + "AbcAndJkl": "abc_and_jkl", + "EmployeeID": "employee_id", + "SKU_ID": "sku_id", + "FieldX": "field_x", + "HTTPAndSMTP": "http_and_smtp", + "HTTPServerHandlerForURLID": "http_server_handler_for_url_id", + "UUID": "uuid", + "HTTPURL": "http_url", + "HTTP_URL": "http_url", + "SHA256Hash": "sha256_hash", + "SHA256HASH": "sha256_hash", + "ThisIsActuallyATestSoWeMayBeAbleToUseThisCodeInGormPackageAlsoIdCanBeUsedAtTheEndAsID": "this_is_actually_a_test_so_we_may_be_able_to_use_this_code_in_gorm_package_also_id_can_be_used_at_the_end_as_id", + } + + ns := NamingStrategy{} + for key, value := range maps { + if ns.toDBName(key) != value { + t.Errorf("%v toName should equal %v, but got %v", key, value, ns.toDBName(key)) + } + } +} + +func TestNamingStrategy(t *testing.T) { + var ns = NamingStrategy{ + TablePrefix: "public.", + SingularTable: true, + NameReplacer: strings.NewReplacer("CID", "Cid"), + } + idxName := ns.IndexName("public.table", "name") + + if idxName != "idx_public_table_name" { + t.Errorf("invalid index name generated, got %v", idxName) + } + + chkName := ns.CheckerName("public.table", "name") + if chkName != "chk_public_table_name" { + t.Errorf("invalid checker name generated, got %v", chkName) + } + + joinTable := ns.JoinTableName("user_languages") + if joinTable != "public.user_languages" { + t.Errorf("invalid join table generated, got %v", joinTable) + } + + joinTable2 := ns.JoinTableName("UserLanguage") + if joinTable2 != "public.user_language" { + t.Errorf("invalid join table generated, got %v", joinTable2) + } + + tableName := ns.TableName("Company") + if tableName != "public.company" { + t.Errorf("invalid table name generated, got %v", tableName) + } + + columdName := ns.ColumnName("", "NameCID") + if columdName != "name_cid" { + t.Errorf("invalid column name generated, got %v", columdName) + } +} + +type CustomReplacer struct { + f func(string) string +} + +func (r CustomReplacer) Replace(name string) string { + return r.f(name) +} + +func TestCustomReplacer(t *testing.T) { + var ns = NamingStrategy{ + TablePrefix: 
"public.", + SingularTable: true, + NameReplacer: CustomReplacer{ + func(name string) string { + replaced := "REPLACED_" + strings.ToUpper(name) + return strings.NewReplacer("CID", "_Cid").Replace(replaced) + }, + }, + NoLowerCase: false, + } + + idxName := ns.IndexName("public.table", "name") + if idxName != "idx_public_table_replaced_name" { + t.Errorf("invalid index name generated, got %v", idxName) + } + + chkName := ns.CheckerName("public.table", "name") + if chkName != "chk_public_table_name" { + t.Errorf("invalid checker name generated, got %v", chkName) + } + + joinTable := ns.JoinTableName("user_languages") + if joinTable != "public.user_languages" { // Seems like a bug in NamingStrategy to skip the Replacer when the name is lowercase here. + t.Errorf("invalid join table generated, got %v", joinTable) + } + + joinTable2 := ns.JoinTableName("UserLanguage") + if joinTable2 != "public.replaced_userlanguage" { + t.Errorf("invalid join table generated, got %v", joinTable2) + } + + tableName := ns.TableName("Company") + if tableName != "public.replaced_company" { + t.Errorf("invalid table name generated, got %v", tableName) + } + + columdName := ns.ColumnName("", "NameCID") + if columdName != "replaced_name_cid" { + t.Errorf("invalid column name generated, got %v", columdName) + } +} + +func TestCustomReplacerWithNoLowerCase(t *testing.T) { + var ns = NamingStrategy{ + TablePrefix: "public.", + SingularTable: true, + NameReplacer: CustomReplacer{ + func(name string) string { + replaced := "REPLACED_" + strings.ToUpper(name) + return strings.NewReplacer("CID", "_Cid").Replace(replaced) + }, + }, + NoLowerCase: true, + } + + idxName := ns.IndexName("public.table", "name") + if idxName != "idx_public_table_REPLACED_NAME" { + t.Errorf("invalid index name generated, got %v", idxName) + } + + chkName := ns.CheckerName("public.table", "name") + if chkName != "chk_public_table_name" { + t.Errorf("invalid checker name generated, got %v", chkName) + } + + joinTable := ns.JoinTableName("user_languages") + if joinTable != "public.REPLACED_USER_LANGUAGES" { + t.Errorf("invalid join table generated, got %v", joinTable) + } + + joinTable2 := ns.JoinTableName("UserLanguage") + if joinTable2 != "public.REPLACED_USERLANGUAGE" { + t.Errorf("invalid join table generated, got %v", joinTable2) + } + + tableName := ns.TableName("Company") + if tableName != "public.REPLACED_COMPANY" { + t.Errorf("invalid table name generated, got %v", tableName) + } + + columdName := ns.ColumnName("", "NameCID") + if columdName != "REPLACED_NAME_Cid" { + t.Errorf("invalid column name generated, got %v", columdName) + } +} + +func TestFormatNameWithStringLongerThan64Characters(t *testing.T) { + var ns = NamingStrategy{} + + formattedName := ns.formatName("prefix", "table", "thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVeryLongString") + if formattedName != "prefixtablethisIsAVeryVeryVeryVeryVeryVeryVeryVeryVeryLo180f2c67" { + t.Errorf("invalid formatted name generated, got %v", formattedName) + } +} diff --git a/schema/relationship.go b/schema/relationship.go new file mode 100644 index 00000000..061e9120 --- /dev/null +++ b/schema/relationship.go @@ -0,0 +1,616 @@ +package schema + +import ( + "fmt" + "reflect" + "strings" + + "github.com/jinzhu/inflection" + "gorm.io/gorm/clause" +) + +// RelationshipType relationship type +type RelationshipType string + +const ( + HasOne RelationshipType = "has_one" // HasOneRel has one relationship + HasMany RelationshipType = "has_many" // HasManyRel has many relationship + BelongsTo 
RelationshipType = "belongs_to" // BelongsToRel belongs to relationship + Many2Many RelationshipType = "many_to_many" // Many2ManyRel many to many relationship + has RelationshipType = "has" +) + +type Relationships struct { + HasOne []*Relationship + BelongsTo []*Relationship + HasMany []*Relationship + Many2Many []*Relationship + Relations map[string]*Relationship +} + +type Relationship struct { + Name string + Type RelationshipType + Field *Field + Polymorphic *Polymorphic + References []*Reference + Schema *Schema + FieldSchema *Schema + JoinTable *Schema + foreignKeys, primaryKeys []string +} + +type Polymorphic struct { + PolymorphicID *Field + PolymorphicType *Field + Value string +} + +type Reference struct { + PrimaryKey *Field + PrimaryValue string + ForeignKey *Field + OwnPrimaryKey bool +} + +func (schema *Schema) parseRelation(field *Field) *Relationship { + var ( + err error + fieldValue = reflect.New(field.IndirectFieldType).Interface() + relation = &Relationship{ + Name: field.Name, + Field: field, + Schema: schema, + foreignKeys: toColumns(field.TagSettings["FOREIGNKEY"]), + primaryKeys: toColumns(field.TagSettings["REFERENCES"]), + } + ) + + cacheStore := schema.cacheStore + + if relation.FieldSchema, err = getOrParse(fieldValue, cacheStore, schema.namer); err != nil { + schema.err = err + return nil + } + + if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" { + schema.buildPolymorphicRelation(relation, field, polymorphic) + } else if many2many := field.TagSettings["MANY2MANY"]; many2many != "" { + schema.buildMany2ManyRelation(relation, field, many2many) + } else { + switch field.IndirectFieldType.Kind() { + case reflect.Struct: + schema.guessRelation(relation, field, guessGuess) + case reflect.Slice: + schema.guessRelation(relation, field, guessHas) + default: + schema.err = fmt.Errorf("unsupported data type %v for %v on field %v", relation.FieldSchema, schema, field.Name) + } + } + + if relation.Type == has { + // don't add relations to embedded schema, which might be shared + if relation.FieldSchema != relation.Schema && relation.Polymorphic == nil && field.OwnerSchema == nil { + relation.FieldSchema.Relationships.Relations["_"+relation.Schema.Name+"_"+relation.Name] = relation + } + + switch field.IndirectFieldType.Kind() { + case reflect.Struct: + relation.Type = HasOne + case reflect.Slice: + relation.Type = HasMany + } + } + + if schema.err == nil { + schema.Relationships.Relations[relation.Name] = relation + switch relation.Type { + case HasOne: + schema.Relationships.HasOne = append(schema.Relationships.HasOne, relation) + case HasMany: + schema.Relationships.HasMany = append(schema.Relationships.HasMany, relation) + case BelongsTo: + schema.Relationships.BelongsTo = append(schema.Relationships.BelongsTo, relation) + case Many2Many: + schema.Relationships.Many2Many = append(schema.Relationships.Many2Many, relation) + } + } + + return relation +} + +// User has many Toys, its `Polymorphic` is `Owner`, Pet has one Toy, its `Polymorphic` is `Owner` +// type User struct { +// Toys []Toy `gorm:"polymorphic:Owner;"` +// } +// type Pet struct { +// Toy Toy `gorm:"polymorphic:Owner;"` +// } +// type Toy struct { +// OwnerID int +// OwnerType string +// } +func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Field, polymorphic string) { + relation.Polymorphic = &Polymorphic{ + Value: schema.Table, + PolymorphicType: relation.FieldSchema.FieldsByName[polymorphic+"Type"], + PolymorphicID: 
relation.FieldSchema.FieldsByName[polymorphic+"ID"], + } + + if value, ok := field.TagSettings["POLYMORPHICVALUE"]; ok { + relation.Polymorphic.Value = strings.TrimSpace(value) + } + + if relation.Polymorphic.PolymorphicType == nil { + schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %v, missing field %v", relation.FieldSchema, schema, field.Name, polymorphic+"Type") + } + + if relation.Polymorphic.PolymorphicID == nil { + schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %v, missing field %v", relation.FieldSchema, schema, field.Name, polymorphic+"ID") + } + + if schema.err == nil { + relation.References = append(relation.References, &Reference{ + PrimaryValue: relation.Polymorphic.Value, + ForeignKey: relation.Polymorphic.PolymorphicType, + }) + + primaryKeyField := schema.PrioritizedPrimaryField + if len(relation.foreignKeys) > 0 { + if primaryKeyField = schema.LookUpField(relation.foreignKeys[0]); primaryKeyField == nil || len(relation.foreignKeys) > 1 { + schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %v", relation.foreignKeys, schema, field.Name) + } + } + + // use same data type for foreign keys + relation.Polymorphic.PolymorphicID.DataType = primaryKeyField.DataType + relation.Polymorphic.PolymorphicID.GORMDataType = primaryKeyField.GORMDataType + if relation.Polymorphic.PolymorphicID.Size == 0 { + relation.Polymorphic.PolymorphicID.Size = primaryKeyField.Size + } + + relation.References = append(relation.References, &Reference{ + PrimaryKey: primaryKeyField, + ForeignKey: relation.Polymorphic.PolymorphicID, + OwnPrimaryKey: true, + }) + } + + relation.Type = has +} + +func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Field, many2many string) { + relation.Type = Many2Many + + var ( + err error + joinTableFields []reflect.StructField + fieldsMap = map[string]*Field{} + ownFieldsMap = map[string]bool{} // fix self join many2many + joinForeignKeys = toColumns(field.TagSettings["JOINFOREIGNKEY"]) + joinReferences = toColumns(field.TagSettings["JOINREFERENCES"]) + ) + + ownForeignFields := schema.PrimaryFields + refForeignFields := relation.FieldSchema.PrimaryFields + + if len(relation.foreignKeys) > 0 { + ownForeignFields = []*Field{} + for _, foreignKey := range relation.foreignKeys { + if field := schema.LookUpField(foreignKey); field != nil { + ownForeignFields = append(ownForeignFields, field) + } else { + schema.err = fmt.Errorf("invalid foreign key: %v", foreignKey) + return + } + } + } + + if len(relation.primaryKeys) > 0 { + refForeignFields = []*Field{} + for _, foreignKey := range relation.primaryKeys { + if field := relation.FieldSchema.LookUpField(foreignKey); field != nil { + refForeignFields = append(refForeignFields, field) + } else { + schema.err = fmt.Errorf("invalid foreign key: %v", foreignKey) + return + } + } + } + + for idx, ownField := range ownForeignFields { + joinFieldName := strings.Title(schema.Name) + ownField.Name + if len(joinForeignKeys) > idx { + joinFieldName = strings.Title(joinForeignKeys[idx]) + } + + ownFieldsMap[joinFieldName] = true + fieldsMap[joinFieldName] = ownField + joinTableFields = append(joinTableFields, reflect.StructField{ + Name: joinFieldName, + PkgPath: ownField.StructField.PkgPath, + Type: ownField.StructField.Type, + Tag: removeSettingFromTag(ownField.StructField.Tag, "column", "autoincrement", "index", "unique", "uniqueindex"), + }) + } + + for idx, relField := range refForeignFields { + joinFieldName := relation.FieldSchema.Name + 
relField.Name + if len(joinReferences) > idx { + joinFieldName = strings.Title(joinReferences[idx]) + } + + if _, ok := ownFieldsMap[joinFieldName]; ok { + if field.Name != relation.FieldSchema.Name { + joinFieldName = inflection.Singular(field.Name) + relField.Name + } else { + joinFieldName += "Reference" + } + } + + fieldsMap[joinFieldName] = relField + joinTableFields = append(joinTableFields, reflect.StructField{ + Name: joinFieldName, + PkgPath: relField.StructField.PkgPath, + Type: relField.StructField.Type, + Tag: removeSettingFromTag(relField.StructField.Tag, "column", "autoincrement", "index", "unique", "uniqueindex"), + }) + } + + joinTableFields = append(joinTableFields, reflect.StructField{ + Name: strings.Title(schema.Name) + field.Name, + Type: schema.ModelType, + Tag: `gorm:"-"`, + }) + + if relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, schema.namer); err != nil { + schema.err = err + } + relation.JoinTable.Name = many2many + relation.JoinTable.Table = schema.namer.JoinTableName(many2many) + relation.JoinTable.PrimaryFields = make([]*Field, 0, len(relation.JoinTable.Fields)) + + relName := relation.Schema.Name + relRefName := relation.FieldSchema.Name + if relName == relRefName { + relRefName = relation.Field.Name + } + + if _, ok := relation.JoinTable.Relationships.Relations[relName]; !ok { + relation.JoinTable.Relationships.Relations[relName] = &Relationship{ + Name: relName, + Type: BelongsTo, + Schema: relation.JoinTable, + FieldSchema: relation.Schema, + } + } else { + relation.JoinTable.Relationships.Relations[relName].References = []*Reference{} + } + + if _, ok := relation.JoinTable.Relationships.Relations[relRefName]; !ok { + relation.JoinTable.Relationships.Relations[relRefName] = &Relationship{ + Name: relRefName, + Type: BelongsTo, + Schema: relation.JoinTable, + FieldSchema: relation.FieldSchema, + } + } else { + relation.JoinTable.Relationships.Relations[relRefName].References = []*Reference{} + } + + // build references + for _, f := range relation.JoinTable.Fields { + if f.Creatable || f.Readable || f.Updatable { + // use same data type for foreign keys + f.DataType = fieldsMap[f.Name].DataType + f.GORMDataType = fieldsMap[f.Name].GORMDataType + if f.Size == 0 { + f.Size = fieldsMap[f.Name].Size + } + relation.JoinTable.PrimaryFields = append(relation.JoinTable.PrimaryFields, f) + ownPrimaryField := schema == fieldsMap[f.Name].Schema && ownFieldsMap[f.Name] + + if ownPrimaryField { + joinRel := relation.JoinTable.Relationships.Relations[relName] + joinRel.Field = relation.Field + joinRel.References = append(joinRel.References, &Reference{ + PrimaryKey: fieldsMap[f.Name], + ForeignKey: f, + }) + } else { + joinRefRel := relation.JoinTable.Relationships.Relations[relRefName] + if joinRefRel.Field == nil { + joinRefRel.Field = relation.Field + } + joinRefRel.References = append(joinRefRel.References, &Reference{ + PrimaryKey: fieldsMap[f.Name], + ForeignKey: f, + }) + } + + relation.References = append(relation.References, &Reference{ + PrimaryKey: fieldsMap[f.Name], + ForeignKey: f, + OwnPrimaryKey: ownPrimaryField, + }) + } + } +} + +type guessLevel int + +const ( + guessGuess guessLevel = iota + guessBelongs + guessEmbeddedBelongs + guessHas + guessEmbeddedHas +) + +func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl guessLevel) { + var ( + primaryFields, foreignFields []*Field + primarySchema, foreignSchema = schema, relation.FieldSchema + gl = cgl + ) + + if gl == 
guessGuess { + if field.Schema == relation.FieldSchema { + gl = guessBelongs + } else { + gl = guessHas + } + } + + reguessOrErr := func() { + switch cgl { + case guessGuess: + schema.guessRelation(relation, field, guessBelongs) + case guessBelongs: + schema.guessRelation(relation, field, guessEmbeddedBelongs) + case guessEmbeddedBelongs: + schema.guessRelation(relation, field, guessHas) + case guessHas: + schema.guessRelation(relation, field, guessEmbeddedHas) + // case guessEmbeddedHas: + default: + schema.err = fmt.Errorf("invalid field found for struct %v's field %v, need to define a valid foreign key for relations or it need to implement the Valuer/Scanner interface", schema, field.Name) + } + } + + switch gl { + case guessBelongs: + primarySchema, foreignSchema = relation.FieldSchema, schema + case guessEmbeddedBelongs: + if field.OwnerSchema != nil { + primarySchema, foreignSchema = relation.FieldSchema, field.OwnerSchema + } else { + reguessOrErr() + return + } + case guessHas: + case guessEmbeddedHas: + if field.OwnerSchema != nil { + primarySchema, foreignSchema = field.OwnerSchema, relation.FieldSchema + } else { + reguessOrErr() + return + } + } + + if len(relation.foreignKeys) > 0 { + for _, foreignKey := range relation.foreignKeys { + if f := foreignSchema.LookUpField(foreignKey); f != nil { + foreignFields = append(foreignFields, f) + } else { + reguessOrErr() + return + } + } + } else { + var primaryFields []*Field + + if len(relation.primaryKeys) > 0 { + for _, primaryKey := range relation.primaryKeys { + if f := primarySchema.LookUpField(primaryKey); f != nil { + primaryFields = append(primaryFields, f) + } + } + } else { + primaryFields = primarySchema.PrimaryFields + } + + for _, primaryField := range primaryFields { + lookUpName := primarySchema.Name + primaryField.Name + if gl == guessBelongs { + lookUpName = field.Name + primaryField.Name + } + + lookUpNames := []string{lookUpName} + if len(primaryFields) == 1 { + lookUpNames = append(lookUpNames, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID", strings.TrimSuffix(lookUpName, primaryField.Name)+"Id", schema.namer.ColumnName(foreignSchema.Table, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID")) + } + + for _, name := range lookUpNames { + if f := foreignSchema.LookUpField(name); f != nil { + foreignFields = append(foreignFields, f) + primaryFields = append(primaryFields, primaryField) + break + } + } + } + } + + if len(foreignFields) == 0 { + reguessOrErr() + return + } else if len(relation.primaryKeys) > 0 { + for idx, primaryKey := range relation.primaryKeys { + if f := primarySchema.LookUpField(primaryKey); f != nil { + if len(primaryFields) < idx+1 { + primaryFields = append(primaryFields, f) + } else if f != primaryFields[idx] { + reguessOrErr() + return + } + } else { + reguessOrErr() + return + } + } + } else if len(primaryFields) == 0 { + if len(foreignFields) == 1 && primarySchema.PrioritizedPrimaryField != nil { + primaryFields = append(primaryFields, primarySchema.PrioritizedPrimaryField) + } else if len(primarySchema.PrimaryFields) == len(foreignFields) { + primaryFields = append(primaryFields, primarySchema.PrimaryFields...) 
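// At this point the foreign columns found by convention (for a primary schema
// named User with primary key ID the candidates are "UserID", "UserId" and
// "user_id") have been paired with the primary schema's primary key(s); if no
// pairing works, the else branch below re-guesses at the next level and
// finally reports an error.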
+ } else { + reguessOrErr() + return + } + } + + // build references + for idx, foreignField := range foreignFields { + // use same data type for foreign keys + foreignField.DataType = primaryFields[idx].DataType + foreignField.GORMDataType = primaryFields[idx].GORMDataType + if foreignField.Size == 0 { + foreignField.Size = primaryFields[idx].Size + } + + relation.References = append(relation.References, &Reference{ + PrimaryKey: primaryFields[idx], + ForeignKey: foreignField, + OwnPrimaryKey: (schema == primarySchema && gl == guessHas) || (field.OwnerSchema == primarySchema && gl == guessEmbeddedHas), + }) + } + + if gl == guessHas || gl == guessEmbeddedHas { + relation.Type = has + } else { + relation.Type = BelongsTo + } +} + +type Constraint struct { + Name string + Field *Field + Schema *Schema + ForeignKeys []*Field + ReferenceSchema *Schema + References []*Field + OnDelete string + OnUpdate string +} + +func (rel *Relationship) ParseConstraint() *Constraint { + str := rel.Field.TagSettings["CONSTRAINT"] + if str == "-" { + return nil + } + + if rel.Type == BelongsTo { + for _, r := range rel.FieldSchema.Relationships.Relations { + if r.FieldSchema == rel.Schema && len(rel.References) == len(r.References) { + matched := true + for idx, ref := range r.References { + if !(rel.References[idx].PrimaryKey == ref.PrimaryKey && rel.References[idx].ForeignKey == ref.ForeignKey && + rel.References[idx].PrimaryValue == ref.PrimaryValue) { + matched = false + } + } + + if matched { + return nil + } + } + } + } + + var ( + name string + idx = strings.Index(str, ",") + settings = ParseTagSetting(str, ",") + ) + + // optimize match english letters and midline + // The following code is basically called in for. + // In order to avoid the performance problems caused by repeated compilation of regular expressions, + // it only needs to be done once outside, so optimization is done here. 
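// The constraint tag is parsed as comma-separated settings; an optional leading
// name (plain letters and dashes, matched by the precompiled regexp below)
// overrides the generated fk_* name. A hedged sketch of the common tag shape:
//
//	Company Company `gorm:"constraint:OnUpdate:CASCADE,OnDelete:SET NULL"`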
+ if idx != -1 && regEnLetterAndMidline.MatchString(str[0:idx]) { + name = str[0:idx] + } else { + name = rel.Schema.namer.RelationshipFKName(*rel) + } + + constraint := Constraint{ + Name: name, + Field: rel.Field, + OnUpdate: settings["ONUPDATE"], + OnDelete: settings["ONDELETE"], + } + + for _, ref := range rel.References { + if ref.PrimaryKey != nil && (rel.JoinTable == nil || ref.OwnPrimaryKey) { + constraint.ForeignKeys = append(constraint.ForeignKeys, ref.ForeignKey) + constraint.References = append(constraint.References, ref.PrimaryKey) + + if ref.OwnPrimaryKey { + constraint.Schema = ref.ForeignKey.Schema + constraint.ReferenceSchema = rel.Schema + } else { + constraint.Schema = rel.Schema + constraint.ReferenceSchema = ref.PrimaryKey.Schema + } + } + } + + return &constraint +} + +func (rel *Relationship) ToQueryConditions(reflectValue reflect.Value) (conds []clause.Expression) { + table := rel.FieldSchema.Table + foreignFields := []*Field{} + relForeignKeys := []string{} + + if rel.JoinTable != nil { + table = rel.JoinTable.Table + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + foreignFields = append(foreignFields, ref.PrimaryKey) + relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName) + } else if ref.PrimaryValue != "" { + conds = append(conds, clause.Eq{ + Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName}, + Value: ref.PrimaryValue, + }) + } else { + conds = append(conds, clause.Eq{ + Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName}, + Value: clause.Column{Table: rel.FieldSchema.Table, Name: ref.PrimaryKey.DBName}, + }) + } + } + } else { + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName) + foreignFields = append(foreignFields, ref.PrimaryKey) + } else if ref.PrimaryValue != "" { + conds = append(conds, clause.Eq{ + Column: clause.Column{Table: rel.FieldSchema.Table, Name: ref.ForeignKey.DBName}, + Value: ref.PrimaryValue, + }) + } else { + relForeignKeys = append(relForeignKeys, ref.PrimaryKey.DBName) + foreignFields = append(foreignFields, ref.ForeignKey) + } + } + } + + _, foreignValues := GetIdentityFieldValuesMap(reflectValue, foreignFields) + column, values := ToQueryValues(table, relForeignKeys, foreignValues) + + conds = append(conds, clause.IN{Column: column, Values: values}) + return +} diff --git a/schema/relationship_test.go b/schema/relationship_test.go new file mode 100644 index 00000000..2971698c --- /dev/null +++ b/schema/relationship_test.go @@ -0,0 +1,484 @@ +package schema_test + +import ( + "sync" + "testing" + + "gorm.io/gorm" + "gorm.io/gorm/schema" +) + +func checkStructRelation(t *testing.T, data interface{}, relations ...Relation) { + if s, err := schema.Parse(data, &sync.Map{}, schema.NamingStrategy{}); err != nil { + t.Errorf("Failed to parse schema") + } else { + for _, rel := range relations { + checkSchemaRelation(t, s, rel) + } + } +} + +func TestBelongsToOverrideForeignKey(t *testing.T) { + type Profile struct { + gorm.Model + Name string + } + + type User struct { + gorm.Model + Profile Profile `gorm:"ForeignKey:ProfileRefer"` + ProfileRefer int + } + + checkStructRelation(t, &User{}, Relation{ + Name: "Profile", Type: schema.BelongsTo, Schema: "User", FieldSchema: "Profile", + References: []Reference{{"ID", "Profile", "ProfileRefer", "User", "", false}}, + }) +} + +func TestBelongsToOverrideReferences(t *testing.T) { + type Profile struct { + gorm.Model + Refer string + Name string + 
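// Refer is the non-primary column that the References:Refer tag on User.Profile
// (below) points this belongs-to association at instead of the default ID.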
} + + type User struct { + gorm.Model + Profile Profile `gorm:"ForeignKey:ProfileID;References:Refer"` + ProfileID int + } + + checkStructRelation(t, &User{}, Relation{ + Name: "Profile", Type: schema.BelongsTo, Schema: "User", FieldSchema: "Profile", + References: []Reference{{"Refer", "Profile", "ProfileID", "User", "", false}}, + }) +} + +func TestBelongsToWithOnlyReferences(t *testing.T) { + type Profile struct { + gorm.Model + Refer string + Name string + } + + type User struct { + gorm.Model + Profile Profile `gorm:"References:Refer"` + ProfileRefer int + } + + checkStructRelation(t, &User{}, Relation{ + Name: "Profile", Type: schema.BelongsTo, Schema: "User", FieldSchema: "Profile", + References: []Reference{{"Refer", "Profile", "ProfileRefer", "User", "", false}}, + }) +} + +func TestBelongsToWithOnlyReferences2(t *testing.T) { + type Profile struct { + gorm.Model + Refer string + Name string + } + + type User struct { + gorm.Model + Profile Profile `gorm:"References:Refer"` + ProfileID int + } + + checkStructRelation(t, &User{}, Relation{ + Name: "Profile", Type: schema.BelongsTo, Schema: "User", FieldSchema: "Profile", + References: []Reference{{"Refer", "Profile", "ProfileID", "User", "", false}}, + }) +} + +func TestSelfReferentialBelongsToOverrideReferences(t *testing.T) { + type User struct { + ID int32 `gorm:"primaryKey"` + Name string + CreatedBy *int32 + Creator *User `gorm:"foreignKey:CreatedBy;references:ID"` + } + + checkStructRelation(t, &User{}, Relation{ + Name: "Creator", Type: schema.BelongsTo, Schema: "User", FieldSchema: "User", + References: []Reference{{"ID", "User", "CreatedBy", "User", "", false}}, + }) +} + +func TestHasOneOverrideForeignKey(t *testing.T) { + type Profile struct { + gorm.Model + Name string + UserRefer uint + } + + type User struct { + gorm.Model + Profile Profile `gorm:"ForeignKey:UserRefer"` + } + + checkStructRelation(t, &User{}, Relation{ + Name: "Profile", Type: schema.HasOne, Schema: "User", FieldSchema: "Profile", + References: []Reference{{"ID", "User", "UserRefer", "Profile", "", true}}, + }) +} + +func TestHasOneOverrideReferences(t *testing.T) { + type Profile struct { + gorm.Model + Name string + UserID uint + } + + type User struct { + gorm.Model + Refer string + Profile Profile `gorm:"ForeignKey:UserID;References:Refer"` + } + + checkStructRelation(t, &User{}, Relation{ + Name: "Profile", Type: schema.HasOne, Schema: "User", FieldSchema: "Profile", + References: []Reference{{"Refer", "User", "UserID", "Profile", "", true}}, + }) +} + +func TestHasOneWithOnlyReferences(t *testing.T) { + type Profile struct { + gorm.Model + Name string + UserRefer uint + } + + type User struct { + gorm.Model + Refer string + Profile Profile `gorm:"References:Refer"` + } + + checkStructRelation(t, &User{}, Relation{ + Name: "Profile", Type: schema.HasOne, Schema: "User", FieldSchema: "Profile", + References: []Reference{{"Refer", "User", "UserRefer", "Profile", "", true}}, + }) +} + +func TestHasOneWithOnlyReferences2(t *testing.T) { + type Profile struct { + gorm.Model + Name string + UserID uint + } + + type User struct { + gorm.Model + Refer string + Profile Profile `gorm:"References:Refer"` + } + + checkStructRelation(t, &User{}, Relation{ + Name: "Profile", Type: schema.HasOne, Schema: "User", FieldSchema: "Profile", + References: []Reference{{"Refer", "User", "UserID", "Profile", "", true}}, + }) +} + +func TestHasManyOverrideForeignKey(t *testing.T) { + type Profile struct { + gorm.Model + Name string + UserRefer uint + } + + type User 
struct { + gorm.Model + Profile []Profile `gorm:"ForeignKey:UserRefer"` + } + + checkStructRelation(t, &User{}, Relation{ + Name: "Profile", Type: schema.HasMany, Schema: "User", FieldSchema: "Profile", + References: []Reference{{"ID", "User", "UserRefer", "Profile", "", true}}, + }) +} + +func TestHasManyOverrideReferences(t *testing.T) { + type Profile struct { + gorm.Model + Name string + UserID uint + } + + type User struct { + gorm.Model + Refer string + Profile []Profile `gorm:"ForeignKey:UserID;References:Refer"` + } + + checkStructRelation(t, &User{}, Relation{ + Name: "Profile", Type: schema.HasMany, Schema: "User", FieldSchema: "Profile", + References: []Reference{{"Refer", "User", "UserID", "Profile", "", true}}, + }) +} + +func TestMany2ManyOverrideForeignKeyAndReferences(t *testing.T) { + type Profile struct { + gorm.Model + Name string + UserRefer uint + } + + type User struct { + gorm.Model + Profiles []Profile `gorm:"many2many:user_profiles;ForeignKey:Refer;JoinForeignKey:UserReferID;References:UserRefer;JoinReferences:ProfileRefer"` + Profiles2 []Profile `gorm:"many2many:user_profiles2;ForeignKey:refer;JoinForeignKey:user_refer_id;References:user_refer;JoinReferences:profile_refer"` + Refer uint + } + + checkStructRelation(t, &User{}, Relation{ + Name: "Profiles", Type: schema.Many2Many, Schema: "User", FieldSchema: "Profile", + JoinTable: JoinTable{Name: "user_profiles", Table: "user_profiles"}, + References: []Reference{ + {"Refer", "User", "UserReferID", "user_profiles", "", true}, + {"UserRefer", "Profile", "ProfileRefer", "user_profiles", "", false}, + }, + }, Relation{ + Name: "Profiles2", Type: schema.Many2Many, Schema: "User", FieldSchema: "Profile", + JoinTable: JoinTable{Name: "user_profiles2", Table: "user_profiles2"}, + References: []Reference{ + {"Refer", "User", "User_refer_id", "user_profiles2", "", true}, + {"UserRefer", "Profile", "Profile_refer", "user_profiles2", "", false}, + }, + }) +} + +func TestMany2ManyOverrideForeignKey(t *testing.T) { + type Profile struct { + gorm.Model + Name string + UserRefer uint + } + + type User struct { + gorm.Model + Profiles []Profile `gorm:"many2many:user_profiles;ForeignKey:Refer;References:UserRefer"` + Refer uint + } + + checkStructRelation(t, &User{}, Relation{ + Name: "Profiles", Type: schema.Many2Many, Schema: "User", FieldSchema: "Profile", + JoinTable: JoinTable{Name: "user_profiles", Table: "user_profiles"}, + References: []Reference{ + {"Refer", "User", "UserRefer", "user_profiles", "", true}, + {"UserRefer", "Profile", "ProfileUserRefer", "user_profiles", "", false}, + }, + }) +} + +func TestMany2ManyOverrideJoinForeignKey(t *testing.T) { + type Profile struct { + gorm.Model + Name string + UserRefer uint + } + + type User struct { + gorm.Model + Profiles []Profile `gorm:"many2many:user_profile;JoinForeignKey:UserReferID;JoinReferences:ProfileRefer"` + Refer uint + } + + checkStructRelation(t, &User{}, Relation{ + Name: "Profiles", Type: schema.Many2Many, Schema: "User", FieldSchema: "Profile", + JoinTable: JoinTable{Name: "user_profile", Table: "user_profile"}, + References: []Reference{ + {"ID", "User", "UserReferID", "user_profile", "", true}, + {"ID", "Profile", "ProfileRefer", "user_profile", "", false}, + }, + }) +} + +func TestBuildReadonlyMany2ManyRelation(t *testing.T) { + type Profile struct { + gorm.Model + Name string + UserRefer uint + } + + type User struct { + gorm.Model + Profiles []Profile `gorm:"->;many2many:user_profile;JoinForeignKey:UserReferID;JoinReferences:ProfileRefer"` + Refer uint 
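// The "->" prefix marks Profiles as read-only; the expectations below show the
// relation is still parsed like the writable case, with the same join table
// (user_profile) and join foreign keys (UserReferID / ProfileRefer).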
+ } + + checkStructRelation(t, &User{}, Relation{ + Name: "Profiles", Type: schema.Many2Many, Schema: "User", FieldSchema: "Profile", + JoinTable: JoinTable{Name: "user_profile", Table: "user_profile"}, + References: []Reference{ + {"ID", "User", "UserReferID", "user_profile", "", true}, + {"ID", "Profile", "ProfileRefer", "user_profile", "", false}, + }, + }) +} + +func TestMany2ManyWithMultiPrimaryKeys(t *testing.T) { + type Tag struct { + ID uint `gorm:"primary_key"` + Locale string `gorm:"primary_key"` + Value string + } + + type Blog struct { + ID uint `gorm:"primary_key"` + Locale string `gorm:"primary_key"` + Subject string + Body string + Tags []Tag `gorm:"many2many:blog_tags;"` + SharedTags []Tag `gorm:"many2many:shared_blog_tags;ForeignKey:id;References:id"` + LocaleTags []Tag `gorm:"many2many:locale_blog_tags;ForeignKey:id,locale;References:id"` + } + + checkStructRelation(t, &Blog{}, + Relation{ + Name: "Tags", Type: schema.Many2Many, Schema: "Blog", FieldSchema: "Tag", + JoinTable: JoinTable{Name: "blog_tags", Table: "blog_tags"}, + References: []Reference{ + {"ID", "Blog", "BlogID", "blog_tags", "", true}, + {"Locale", "Blog", "BlogLocale", "blog_tags", "", true}, + {"ID", "Tag", "TagID", "blog_tags", "", false}, + {"Locale", "Tag", "TagLocale", "blog_tags", "", false}, + }, + }, + Relation{ + Name: "SharedTags", Type: schema.Many2Many, Schema: "Blog", FieldSchema: "Tag", + JoinTable: JoinTable{Name: "shared_blog_tags", Table: "shared_blog_tags"}, + References: []Reference{ + {"ID", "Blog", "BlogID", "shared_blog_tags", "", true}, + {"ID", "Tag", "TagID", "shared_blog_tags", "", false}, + }, + }, + Relation{ + Name: "LocaleTags", Type: schema.Many2Many, Schema: "Blog", FieldSchema: "Tag", + JoinTable: JoinTable{Name: "locale_blog_tags", Table: "locale_blog_tags"}, + References: []Reference{ + {"ID", "Blog", "BlogID", "locale_blog_tags", "", true}, + {"Locale", "Blog", "BlogLocale", "locale_blog_tags", "", true}, + {"ID", "Tag", "TagID", "locale_blog_tags", "", false}, + }, + }, + ) +} + +func TestMultipleMany2Many(t *testing.T) { + type Thing struct { + ID int + } + + type Person struct { + ID int + Likes []Thing `gorm:"many2many:likes"` + Dislikes []Thing `gorm:"many2many:dislikes"` + } + + checkStructRelation(t, &Person{}, + Relation{ + Name: "Likes", Type: schema.Many2Many, Schema: "Person", FieldSchema: "Thing", + JoinTable: JoinTable{Name: "likes", Table: "likes"}, + References: []Reference{ + {"ID", "Person", "PersonID", "likes", "", true}, + {"ID", "Thing", "ThingID", "likes", "", false}, + }, + }, + Relation{ + Name: "Dislikes", Type: schema.Many2Many, Schema: "Person", FieldSchema: "Thing", + JoinTable: JoinTable{Name: "dislikes", Table: "dislikes"}, + References: []Reference{ + {"ID", "Person", "PersonID", "dislikes", "", true}, + {"ID", "Thing", "ThingID", "dislikes", "", false}, + }, + }, + ) +} + +func TestSelfReferentialMany2Many(t *testing.T) { + type User struct { + ID int32 `gorm:"primaryKey"` + Name string + CreatedBy int32 + Creators []User `gorm:"foreignKey:CreatedBy"` + AnotherPro interface{} `gorm:"-"` + } + + checkStructRelation(t, &User{}, Relation{ + Name: "Creators", Type: schema.HasMany, Schema: "User", FieldSchema: "User", + References: []Reference{{"ID", "User", "CreatedBy", "User", "", true}}, + }) + + user, err := schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("failed to parse schema") + } + + relSchema := user.Relationships.Relations["Creators"].FieldSchema + if user != relSchema { + t.Fatalf("schema 
should be same, expects %p but got %p", user, relSchema) + } +} + +type CreatedByModel struct { + CreatedByID uint + CreatedBy *CreatedUser +} + +type CreatedUser struct { + gorm.Model + CreatedByModel +} + +func TestEmbeddedRelation(t *testing.T) { + checkStructRelation(t, &CreatedUser{}, Relation{ + Name: "CreatedBy", Type: schema.BelongsTo, Schema: "CreatedUser", FieldSchema: "CreatedUser", + References: []Reference{ + {"ID", "CreatedUser", "CreatedByID", "CreatedUser", "", false}, + }, + }) + + userSchema, err := schema.Parse(&CreatedUser{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("failed to parse schema, got error %v", err) + } + + if len(userSchema.Relationships.Relations) != 1 { + t.Fatalf("expects 1 relations, but got %v", len(userSchema.Relationships.Relations)) + } + + if createdByRel, ok := userSchema.Relationships.Relations["CreatedBy"]; ok { + if createdByRel.FieldSchema != userSchema { + t.Fatalf("expects same field schema, but got new %p, old %p", createdByRel.FieldSchema, userSchema) + } + } else { + t.Fatalf("expects created by relations, but not found") + } +} + +func TestSameForeignKey(t *testing.T) { + type UserAux struct { + gorm.Model + Aux string + UUID string + } + + type User struct { + gorm.Model + Name string + UUID string + Aux *UserAux `gorm:"foreignkey:UUID;references:UUID"` + } + + checkStructRelation(t, &User{}, + Relation{ + Name: "Aux", Type: schema.HasOne, Schema: "User", FieldSchema: "UserAux", + References: []Reference{ + {"UUID", "User", "UUID", "UserAux", "", true}, + }, + }, + ) +} diff --git a/schema/schema.go b/schema/schema.go new file mode 100644 index 00000000..d08842e6 --- /dev/null +++ b/schema/schema.go @@ -0,0 +1,283 @@ +package schema + +import ( + "context" + "errors" + "fmt" + "go/ast" + "reflect" + "sync" + + "gorm.io/gorm/clause" + "gorm.io/gorm/logger" +) + +// ErrUnsupportedDataType unsupported data type +var ErrUnsupportedDataType = errors.New("unsupported data type") + +type Schema struct { + Name string + ModelType reflect.Type + Table string + PrioritizedPrimaryField *Field + DBNames []string + PrimaryFields []*Field + PrimaryFieldDBNames []string + Fields []*Field + FieldsByName map[string]*Field + FieldsByDBName map[string]*Field + FieldsWithDefaultDBValue []*Field // fields with default value assigned by database + Relationships Relationships + CreateClauses []clause.Interface + QueryClauses []clause.Interface + UpdateClauses []clause.Interface + DeleteClauses []clause.Interface + BeforeCreate, AfterCreate bool + BeforeUpdate, AfterUpdate bool + BeforeDelete, AfterDelete bool + BeforeSave, AfterSave bool + AfterFind bool + err error + initialized chan struct{} + namer Namer + cacheStore *sync.Map +} + +func (schema Schema) String() string { + if schema.ModelType.Name() == "" { + return fmt.Sprintf("%v(%v)", schema.Name, schema.Table) + } + return fmt.Sprintf("%v.%v", schema.ModelType.PkgPath(), schema.ModelType.Name()) +} + +func (schema Schema) MakeSlice() reflect.Value { + slice := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(schema.ModelType)), 0, 20) + results := reflect.New(slice.Type()) + results.Elem().Set(slice) + return results +} + +func (schema Schema) LookUpField(name string) *Field { + if field, ok := schema.FieldsByDBName[name]; ok { + return field + } + if field, ok := schema.FieldsByName[name]; ok { + return field + } + return nil +} + +type Tabler interface { + TableName() string +} + +// get data type from dialector +func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) 
(*Schema, error) { + if dest == nil { + return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) + } + + modelType := reflect.ValueOf(dest).Type() + for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr { + modelType = modelType.Elem() + } + + if modelType.Kind() != reflect.Struct { + if modelType.PkgPath() == "" { + return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) + } + return nil, fmt.Errorf("%w: %v.%v", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name()) + } + + if v, ok := cacheStore.Load(modelType); ok { + s := v.(*Schema) + <-s.initialized + return s, s.err + } + + modelValue := reflect.New(modelType) + tableName := namer.TableName(modelType.Name()) + if tabler, ok := modelValue.Interface().(Tabler); ok { + tableName = tabler.TableName() + } + if en, ok := namer.(embeddedNamer); ok { + tableName = en.Table + } + + schema := &Schema{ + Name: modelType.Name(), + ModelType: modelType, + Table: tableName, + FieldsByName: map[string]*Field{}, + FieldsByDBName: map[string]*Field{}, + Relationships: Relationships{Relations: map[string]*Relationship{}}, + cacheStore: cacheStore, + namer: namer, + initialized: make(chan struct{}), + } + + defer func() { + if schema.err != nil { + logger.Default.Error(context.Background(), schema.err.Error()) + cacheStore.Delete(modelType) + } + }() + + for i := 0; i < modelType.NumField(); i++ { + if fieldStruct := modelType.Field(i); ast.IsExported(fieldStruct.Name) { + if field := schema.ParseField(fieldStruct); field.EmbeddedSchema != nil { + schema.Fields = append(schema.Fields, field.EmbeddedSchema.Fields...) + } else { + schema.Fields = append(schema.Fields, field) + } + } + } + + for _, field := range schema.Fields { + if field.DBName == "" && field.DataType != "" { + field.DBName = namer.ColumnName(schema.Table, field.Name) + } + + if field.DBName != "" { + // nonexistence or shortest path or first appear prioritized if has permission + if v, ok := schema.FieldsByDBName[field.DBName]; !ok || ((field.Creatable || field.Updatable || field.Readable) && len(field.BindNames) < len(v.BindNames)) { + if _, ok := schema.FieldsByDBName[field.DBName]; !ok { + schema.DBNames = append(schema.DBNames, field.DBName) + } + schema.FieldsByDBName[field.DBName] = field + schema.FieldsByName[field.Name] = field + + if v != nil && v.PrimaryKey { + for idx, f := range schema.PrimaryFields { + if f == v { + schema.PrimaryFields = append(schema.PrimaryFields[0:idx], schema.PrimaryFields[idx+1:]...) 
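// The field previously registered under this DB name (v) loses its primary-key
// slot here: a field with a shorter bind path, e.g. one shadowing an embedded
// struct's column, has just replaced it in FieldsByDBName.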
+ } + } + } + + if field.PrimaryKey { + schema.PrimaryFields = append(schema.PrimaryFields, field) + } + } + } + + if of, ok := schema.FieldsByName[field.Name]; !ok || of.TagSettings["-"] == "-" { + schema.FieldsByName[field.Name] = field + } + + field.setupValuerAndSetter() + } + + prioritizedPrimaryField := schema.LookUpField("id") + if prioritizedPrimaryField == nil { + prioritizedPrimaryField = schema.LookUpField("ID") + } + + if prioritizedPrimaryField != nil { + if prioritizedPrimaryField.PrimaryKey { + schema.PrioritizedPrimaryField = prioritizedPrimaryField + } else if len(schema.PrimaryFields) == 0 { + prioritizedPrimaryField.PrimaryKey = true + schema.PrioritizedPrimaryField = prioritizedPrimaryField + schema.PrimaryFields = append(schema.PrimaryFields, prioritizedPrimaryField) + } + } + + if schema.PrioritizedPrimaryField == nil && len(schema.PrimaryFields) == 1 { + schema.PrioritizedPrimaryField = schema.PrimaryFields[0] + } + + for _, field := range schema.PrimaryFields { + schema.PrimaryFieldDBNames = append(schema.PrimaryFieldDBNames, field.DBName) + } + + for _, field := range schema.FieldsByDBName { + if field.HasDefaultValue && field.DefaultValueInterface == nil { + schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field) + } + } + + if field := schema.PrioritizedPrimaryField; field != nil { + switch field.GORMDataType { + case Int, Uint: + if _, ok := field.TagSettings["AUTOINCREMENT"]; !ok { + if !field.HasDefaultValue || field.DefaultValueInterface != nil { + schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field) + } + + field.HasDefaultValue = true + field.AutoIncrement = true + } + } + } + + callbacks := []string{"BeforeCreate", "AfterCreate", "BeforeUpdate", "AfterUpdate", "BeforeSave", "AfterSave", "BeforeDelete", "AfterDelete", "AfterFind"} + for _, name := range callbacks { + if methodValue := modelValue.MethodByName(name); methodValue.IsValid() { + switch methodValue.Type().String() { + case "func(*gorm.DB) error": // TODO hack + reflect.Indirect(reflect.ValueOf(schema)).FieldByName(name).SetBool(true) + default: + logger.Default.Warn(context.Background(), "Model %v don't match %vInterface, should be %v(*gorm.DB)", schema, name, name) + } + } + } + + if v, loaded := cacheStore.LoadOrStore(modelType, schema); loaded { + s := v.(*Schema) + <-s.initialized + return s, s.err + } + + defer close(schema.initialized) + if _, embedded := schema.cacheStore.Load(embeddedCacheKey); !embedded { + for _, field := range schema.Fields { + if field.DataType == "" && (field.Creatable || field.Updatable || field.Readable) { + if schema.parseRelation(field); schema.err != nil { + return schema, schema.err + } else { + schema.FieldsByName[field.Name] = field + } + } + + fieldValue := reflect.New(field.IndirectFieldType) + if fc, ok := fieldValue.Interface().(CreateClausesInterface); ok { + field.Schema.CreateClauses = append(field.Schema.CreateClauses, fc.CreateClauses(field)...) + } + + if fc, ok := fieldValue.Interface().(QueryClausesInterface); ok { + field.Schema.QueryClauses = append(field.Schema.QueryClauses, fc.QueryClauses(field)...) + } + + if fc, ok := fieldValue.Interface().(UpdateClausesInterface); ok { + field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, fc.UpdateClauses(field)...) + } + + if fc, ok := fieldValue.Interface().(DeleteClausesInterface); ok { + field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, fc.DeleteClauses(field)...) 
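// Any field type can contribute clauses to its schema through the
// *ClausesInterface hooks declared in interfaces.go; gorm.DeletedAt relies on
// this to inject soft-delete conditions. A custom hook might look like the
// following sketch (illustrative only, not from this patch):
//
//	type TenantID uint
//
//	func (TenantID) QueryClauses(f *schema.Field) []clause.Interface {
//		return []clause.Interface{clause.Where{Exprs: []clause.Expression{
//			// fixed tenant value, just for the sketch
//			clause.Eq{Column: clause.Column{Table: f.Schema.Table, Name: f.DBName}, Value: 1},
//		}}}
//	}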
+ } + } + } + + return schema, schema.err +} + +func getOrParse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) { + modelType := reflect.ValueOf(dest).Type() + for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr { + modelType = modelType.Elem() + } + + if modelType.Kind() != reflect.Struct { + if modelType.PkgPath() == "" { + return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) + } + return nil, fmt.Errorf("%w: %v.%v", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name()) + } + + if v, ok := cacheStore.Load(modelType); ok { + return v.(*Schema), nil + } + + return Parse(dest, cacheStore, namer) +} diff --git a/schema/schema_helper_test.go b/schema/schema_helper_test.go new file mode 100644 index 00000000..6d2bc664 --- /dev/null +++ b/schema/schema_helper_test.go @@ -0,0 +1,210 @@ +package schema_test + +import ( + "fmt" + "reflect" + "strings" + "testing" + + "gorm.io/gorm/schema" + "gorm.io/gorm/utils/tests" +) + +func checkSchema(t *testing.T, s *schema.Schema, v schema.Schema, primaryFields []string) { + t.Run("CheckSchema/"+s.Name, func(t *testing.T) { + tests.AssertObjEqual(t, s, v, "Name", "Table") + + for idx, field := range primaryFields { + var found bool + for _, f := range s.PrimaryFields { + if f.Name == field { + found = true + } + } + + if idx == 0 { + if field != s.PrioritizedPrimaryField.Name { + t.Errorf("schema %v prioritized primary field should be %v, but got %v", s, field, s.PrioritizedPrimaryField.Name) + } + } + + if !found { + t.Errorf("schema %v failed to found primary key: %v", s, field) + } + } + }) +} + +func checkSchemaField(t *testing.T, s *schema.Schema, f *schema.Field, fc func(*schema.Field)) { + t.Run("CheckField/"+f.Name, func(t *testing.T) { + if fc != nil { + fc(f) + } + + if f.TagSettings == nil { + if f.Tag != "" { + f.TagSettings = schema.ParseTagSetting(f.Tag.Get("gorm"), ";") + } else { + f.TagSettings = map[string]string{} + } + } + + parsedField, ok := s.FieldsByDBName[f.DBName] + if !ok { + parsedField, ok = s.FieldsByName[f.Name] + } + + if !ok { + t.Errorf("schema %v failed to look up field with name %v", s, f.Name) + } else { + tests.AssertObjEqual(t, parsedField, f, "Name", "DBName", "BindNames", "DataType", "PrimaryKey", "AutoIncrement", "Creatable", "Updatable", "Readable", "HasDefaultValue", "DefaultValue", "NotNull", "Unique", "Comment", "Size", "Precision", "TagSettings") + + if f.DBName != "" { + if field, ok := s.FieldsByDBName[f.DBName]; !ok || parsedField != field { + t.Errorf("schema %v failed to look up field with dbname %v", s, f.DBName) + } + } + + for _, name := range []string{f.DBName, f.Name} { + if name != "" { + if field := s.LookUpField(name); field == nil || (field.Name != name && field.DBName != name) { + t.Errorf("schema %v failed to look up field with dbname %v", s, f.DBName) + } + } + } + + if f.PrimaryKey { + var found bool + for _, primaryField := range s.PrimaryFields { + if primaryField == parsedField { + found = true + } + } + + if !found { + t.Errorf("schema %v doesn't include field %v", s, f.Name) + } + } + } + }) +} + +type Relation struct { + Name string + Type schema.RelationshipType + Schema string + FieldSchema string + Polymorphic Polymorphic + JoinTable JoinTable + References []Reference +} + +type Polymorphic struct { + ID string + Type string + Value string +} + +type JoinTable struct { + Name string + Table string + Fields []schema.Field +} + +type Reference struct { + PrimaryKey string + PrimarySchema 
string + ForeignKey string + ForeignSchema string + PrimaryValue string + OwnPrimaryKey bool +} + +func checkSchemaRelation(t *testing.T, s *schema.Schema, relation Relation) { + t.Run("CheckRelation/"+relation.Name, func(t *testing.T) { + if r, ok := s.Relationships.Relations[relation.Name]; ok { + if r.Name != relation.Name { + t.Errorf("schema %v relation name expects %v, but got %v", s, r.Name, relation.Name) + } + + if r.Type != relation.Type { + t.Errorf("schema %v relation name expects %v, but got %v", s, r.Type, relation.Type) + } + + if r.Schema.Name != relation.Schema { + t.Errorf("schema %v relation's schema expects %v, but got %v", s, relation.Schema, r.Schema.Name) + } + + if r.FieldSchema.Name != relation.FieldSchema { + t.Errorf("schema %v field relation's schema expects %v, but got %v", s, relation.FieldSchema, r.FieldSchema.Name) + } + + if r.Polymorphic != nil { + if r.Polymorphic.PolymorphicID.Name != relation.Polymorphic.ID { + t.Errorf("schema %v relation's polymorphic id field expects %v, but got %v", s, relation.Polymorphic.ID, r.Polymorphic.PolymorphicID.Name) + } + + if r.Polymorphic.PolymorphicType.Name != relation.Polymorphic.Type { + t.Errorf("schema %v relation's polymorphic type field expects %v, but got %v", s, relation.Polymorphic.Type, r.Polymorphic.PolymorphicType.Name) + } + + if r.Polymorphic.Value != relation.Polymorphic.Value { + t.Errorf("schema %v relation's polymorphic value expects %v, but got %v", s, relation.Polymorphic.Value, r.Polymorphic.Value) + } + } + + if r.JoinTable != nil { + if r.JoinTable.Name != relation.JoinTable.Name { + t.Errorf("schema %v relation's join table name expects %v, but got %v", s, relation.JoinTable.Name, r.JoinTable.Name) + } + + if r.JoinTable.Table != relation.JoinTable.Table { + t.Errorf("schema %v relation's join table tablename expects %v, but got %v", s, relation.JoinTable.Table, r.JoinTable.Table) + } + + for _, f := range relation.JoinTable.Fields { + checkSchemaField(t, r.JoinTable, &f, nil) + } + } + + if len(relation.References) != len(r.References) { + t.Errorf("schema %v relation's reference's count doesn't match, expects %v, but got %v", s, len(relation.References), len(r.References)) + } + + for _, ref := range relation.References { + var found bool + for _, rf := range r.References { + if (rf.PrimaryKey == nil || (rf.PrimaryKey.Name == ref.PrimaryKey && rf.PrimaryKey.Schema.Name == ref.PrimarySchema)) && (rf.PrimaryValue == ref.PrimaryValue) && (rf.ForeignKey.Name == ref.ForeignKey && rf.ForeignKey.Schema.Name == ref.ForeignSchema) && (rf.OwnPrimaryKey == ref.OwnPrimaryKey) { + found = true + } + } + + if !found { + var refs []string + for _, rf := range r.References { + var primaryKey, primaryKeySchema string + if rf.PrimaryKey != nil { + primaryKey, primaryKeySchema = rf.PrimaryKey.Name, rf.PrimaryKey.Schema.Name + } + refs = append(refs, fmt.Sprintf( + "{PrimaryKey: %v PrimaryKeySchame: %v ForeignKey: %v ForeignKeySchema: %v PrimaryValue: %v OwnPrimaryKey: %v}", + primaryKey, primaryKeySchema, rf.ForeignKey.Name, rf.ForeignKey.Schema.Name, rf.PrimaryValue, rf.OwnPrimaryKey, + )) + } + t.Errorf("schema %v relation %v failed to found reference %+v, has %v", s, relation.Name, ref, strings.Join(refs, ", ")) + } + } + } else { + t.Errorf("schema %v failed to find relations by name %v", s, relation.Name) + } + }) +} + +func checkField(t *testing.T, s *schema.Schema, value reflect.Value, values map[string]interface{}) { + for k, v := range values { + t.Run("CheckField/"+k, func(t *testing.T) { + fv, _ := 
s.FieldsByDBName[k].ValueOf(value) + tests.AssertEqual(t, v, fv) + }) + } +} diff --git a/schema/schema_test.go b/schema/schema_test.go new file mode 100644 index 00000000..a426cd90 --- /dev/null +++ b/schema/schema_test.go @@ -0,0 +1,299 @@ +package schema_test + +import ( + "strings" + "sync" + "testing" + + "gorm.io/gorm" + "gorm.io/gorm/schema" + "gorm.io/gorm/utils/tests" +) + +func TestParseSchema(t *testing.T) { + user, err := schema.Parse(&tests.User{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("failed to parse user, got error %v", err) + } + + checkUserSchema(t, user) +} + +func TestParseSchemaWithPointerFields(t *testing.T) { + user, err := schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("failed to parse pointer user, got error %v", err) + } + + checkUserSchema(t, user) +} + +func checkUserSchema(t *testing.T, user *schema.Schema) { + // check schema + checkSchema(t, user, schema.Schema{Name: "User", Table: "users"}, []string{"ID"}) + + // check fields + fields := []schema.Field{ + {Name: "ID", DBName: "id", BindNames: []string{"Model", "ID"}, DataType: schema.Uint, PrimaryKey: true, Tag: `gorm:"primarykey"`, TagSettings: map[string]string{"PRIMARYKEY": "PRIMARYKEY"}, Size: 64, HasDefaultValue: true, AutoIncrement: true}, + {Name: "CreatedAt", DBName: "created_at", BindNames: []string{"Model", "CreatedAt"}, DataType: schema.Time}, + {Name: "UpdatedAt", DBName: "updated_at", BindNames: []string{"Model", "UpdatedAt"}, DataType: schema.Time}, + {Name: "DeletedAt", DBName: "deleted_at", BindNames: []string{"Model", "DeletedAt"}, Tag: `gorm:"index"`, DataType: schema.Time}, + {Name: "Name", DBName: "name", BindNames: []string{"Name"}, DataType: schema.String}, + {Name: "Age", DBName: "age", BindNames: []string{"Age"}, DataType: schema.Uint, Size: 64}, + {Name: "Birthday", DBName: "birthday", BindNames: []string{"Birthday"}, DataType: schema.Time}, + {Name: "CompanyID", DBName: "company_id", BindNames: []string{"CompanyID"}, DataType: schema.Int, Size: 64}, + {Name: "ManagerID", DBName: "manager_id", BindNames: []string{"ManagerID"}, DataType: schema.Uint, Size: 64}, + {Name: "Active", DBName: "active", BindNames: []string{"Active"}, DataType: schema.Bool}, + } + + for _, f := range fields { + checkSchemaField(t, user, &f, func(f *schema.Field) { + f.Creatable = true + f.Updatable = true + f.Readable = true + }) + } + + // check relations + relations := []Relation{ + { + Name: "Account", Type: schema.HasOne, Schema: "User", FieldSchema: "Account", + References: []Reference{{"ID", "User", "UserID", "Account", "", true}}, + }, + { + Name: "Pets", Type: schema.HasMany, Schema: "User", FieldSchema: "Pet", + References: []Reference{{"ID", "User", "UserID", "Pet", "", true}}, + }, + { + Name: "Toys", Type: schema.HasMany, Schema: "User", FieldSchema: "Toy", + Polymorphic: Polymorphic{ID: "OwnerID", Type: "OwnerType", Value: "users"}, + References: []Reference{{"ID", "User", "OwnerID", "Toy", "", true}, {"", "", "OwnerType", "Toy", "users", false}}, + }, + { + Name: "Company", Type: schema.BelongsTo, Schema: "User", FieldSchema: "Company", + References: []Reference{{"ID", "Company", "CompanyID", "User", "", false}}, + }, + { + Name: "Manager", Type: schema.BelongsTo, Schema: "User", FieldSchema: "User", + References: []Reference{{"ID", "User", "ManagerID", "User", "", false}}, + }, + { + Name: "Team", Type: schema.HasMany, Schema: "User", FieldSchema: "User", + References: []Reference{{"ID", "User", "ManagerID", "User", "", 
true}}, + }, + { + Name: "Languages", Type: schema.Many2Many, Schema: "User", FieldSchema: "Language", + JoinTable: JoinTable{Name: "UserSpeak", Table: "user_speaks", Fields: []schema.Field{ + { + Name: "UserID", DBName: "user_id", BindNames: []string{"UserID"}, DataType: schema.Uint, + Tag: `gorm:"primarykey"`, Creatable: true, Updatable: true, Readable: true, PrimaryKey: true, Size: 64, + }, + { + Name: "LanguageCode", DBName: "language_code", BindNames: []string{"LanguageCode"}, DataType: schema.String, + Tag: `gorm:"primarykey"`, Creatable: true, Updatable: true, Readable: true, PrimaryKey: true, + }, + }}, + References: []Reference{{"ID", "User", "UserID", "UserSpeak", "", true}, {"Code", "Language", "LanguageCode", "UserSpeak", "", false}}, + }, + { + Name: "Friends", Type: schema.Many2Many, Schema: "User", FieldSchema: "User", + JoinTable: JoinTable{Name: "user_friends", Table: "user_friends", Fields: []schema.Field{ + { + Name: "UserID", DBName: "user_id", BindNames: []string{"UserID"}, DataType: schema.Uint, + Tag: `gorm:"primarykey"`, Creatable: true, Updatable: true, Readable: true, PrimaryKey: true, Size: 64, + }, + { + Name: "FriendID", DBName: "friend_id", BindNames: []string{"FriendID"}, DataType: schema.Uint, + Tag: `gorm:"primarykey"`, Creatable: true, Updatable: true, Readable: true, PrimaryKey: true, Size: 64, + }, + }}, + References: []Reference{{"ID", "User", "UserID", "user_friends", "", true}, {"ID", "User", "FriendID", "user_friends", "", false}}, + }, + } + + for _, relation := range relations { + checkSchemaRelation(t, user, relation) + } +} + +func TestParseSchemaWithAdvancedDataType(t *testing.T) { + user, err := schema.Parse(&AdvancedDataTypeUser{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("failed to parse pointer user, got error %v", err) + } + + // check schema + checkSchema(t, user, schema.Schema{Name: "AdvancedDataTypeUser", Table: "advanced_data_type_users"}, []string{"ID"}) + + // check fields + fields := []schema.Field{ + {Name: "ID", DBName: "id", BindNames: []string{"ID"}, DataType: schema.Int, PrimaryKey: true, Size: 64, HasDefaultValue: true, AutoIncrement: true}, + {Name: "Name", DBName: "name", BindNames: []string{"Name"}, DataType: schema.String}, + {Name: "Birthday", DBName: "birthday", BindNames: []string{"Birthday"}, DataType: schema.Time}, + {Name: "RegisteredAt", DBName: "registered_at", BindNames: []string{"RegisteredAt"}, DataType: schema.Time}, + {Name: "DeletedAt", DBName: "deleted_at", BindNames: []string{"DeletedAt"}, DataType: schema.Time}, + {Name: "Active", DBName: "active", BindNames: []string{"Active"}, DataType: schema.Bool}, + {Name: "Admin", DBName: "admin", BindNames: []string{"Admin"}, DataType: schema.Bool}, + } + + for _, f := range fields { + checkSchemaField(t, user, &f, func(f *schema.Field) { + f.Creatable = true + f.Updatable = true + f.Readable = true + }) + } +} + +type CustomizeTable struct { +} + +func (CustomizeTable) TableName() string { + return "customize" +} + +func TestCustomizeTableName(t *testing.T) { + customize, err := schema.Parse(&CustomizeTable{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("failed to parse pointer user, got error %v", err) + } + + if customize.Table != "customize" { + t.Errorf("Failed to customize table with TableName method") + } +} + +func TestNestedModel(t *testing.T) { + versionUser, err := schema.Parse(&VersionUser{}, &sync.Map{}, schema.NamingStrategy{}) + + if err != nil { + t.Fatalf("failed to parse nested user, got error %v", 
err) + } + + fields := []schema.Field{ + {Name: "ID", DBName: "id", BindNames: []string{"VersionModel", "BaseModel", "ID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, HasDefaultValue: true, AutoIncrement: true}, + {Name: "CreatedBy", DBName: "created_by", BindNames: []string{"VersionModel", "BaseModel", "CreatedBy"}, DataType: schema.Uint, Size: 64}, + {Name: "Version", DBName: "version", BindNames: []string{"VersionModel", "Version"}, DataType: schema.Int, Size: 64}, + } + + for _, f := range fields { + checkSchemaField(t, versionUser, &f, func(f *schema.Field) { + f.Creatable = true + f.Updatable = true + f.Readable = true + }) + } +} + +func TestEmbeddedStruct(t *testing.T) { + type CorpBase struct { + gorm.Model + OwnerID string + } + + type Company struct { + ID int + OwnerID int + Name string + Ignored string `gorm:"-"` + } + + type Corp struct { + CorpBase + Base Company `gorm:"embedded;embeddedPrefix:company_"` + } + + cropSchema, err := schema.Parse(&Corp{}, &sync.Map{}, schema.NamingStrategy{}) + + if err != nil { + t.Fatalf("failed to parse embedded struct with primary key, got error %v", err) + } + + fields := []schema.Field{ + {Name: "ID", DBName: "id", BindNames: []string{"CorpBase", "Model", "ID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, HasDefaultValue: true, AutoIncrement: true, TagSettings: map[string]string{"PRIMARYKEY": "PRIMARYKEY"}}, + {Name: "ID", DBName: "company_id", BindNames: []string{"Base", "ID"}, DataType: schema.Int, Size: 64, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}}, + {Name: "Name", DBName: "company_name", BindNames: []string{"Base", "Name"}, DataType: schema.String, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}}, + {Name: "Ignored", BindNames: []string{"Base", "Ignored"}, TagSettings: map[string]string{"-": "-", "EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}}, + {Name: "OwnerID", DBName: "company_owner_id", BindNames: []string{"Base", "OwnerID"}, DataType: schema.Int, Size: 64, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}}, + {Name: "OwnerID", DBName: "owner_id", BindNames: []string{"CorpBase", "OwnerID"}, DataType: schema.String}, + } + + for _, f := range fields { + checkSchemaField(t, cropSchema, &f, func(f *schema.Field) { + if f.Name != "Ignored" { + f.Creatable = true + f.Updatable = true + f.Readable = true + } + }) + } +} + +type CustomizedNamingStrategy struct { + schema.NamingStrategy +} + +func (ns CustomizedNamingStrategy) ColumnName(table, column string) string { + baseColumnName := ns.NamingStrategy.ColumnName(table, column) + + if table == "" { + return baseColumnName + } + + s := strings.Split(table, "_") + + var prefix string + switch len(s) { + case 1: + prefix = s[0][:3] + case 2: + prefix = s[0][:1] + s[1][:2] + default: + prefix = s[0][:1] + s[1][:1] + s[2][:1] + } + return prefix + "_" + baseColumnName +} + +func TestEmbeddedStructForCustomizedNamingStrategy(t *testing.T) { + type CorpBase struct { + gorm.Model + OwnerID string + } + + type Company struct { + ID int + OwnerID int + Name string + Ignored string `gorm:"-"` + } + + type Corp struct { + CorpBase + Base Company `gorm:"embedded;embeddedPrefix:company_"` + } + + cropSchema, err := schema.Parse(&Corp{}, &sync.Map{}, CustomizedNamingStrategy{schema.NamingStrategy{}}) + + if err != nil { + t.Fatalf("failed to parse embedded struct with primary key, got error %v", err) + } + + fields := []schema.Field{ + {Name: "ID", DBName: 
"cor_id", BindNames: []string{"CorpBase", "Model", "ID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, HasDefaultValue: true, AutoIncrement: true, TagSettings: map[string]string{"PRIMARYKEY": "PRIMARYKEY"}}, + {Name: "ID", DBName: "company_cor_id", BindNames: []string{"Base", "ID"}, DataType: schema.Int, Size: 64, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}}, + {Name: "Name", DBName: "company_cor_name", BindNames: []string{"Base", "Name"}, DataType: schema.String, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}}, + {Name: "Ignored", BindNames: []string{"Base", "Ignored"}, TagSettings: map[string]string{"-": "-", "EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}}, + {Name: "OwnerID", DBName: "company_cor_owner_id", BindNames: []string{"Base", "OwnerID"}, DataType: schema.Int, Size: 64, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}}, + {Name: "OwnerID", DBName: "cor_owner_id", BindNames: []string{"CorpBase", "OwnerID"}, DataType: schema.String}, + } + + for _, f := range fields { + checkSchemaField(t, cropSchema, &f, func(f *schema.Field) { + if f.Name != "Ignored" { + f.Creatable = true + f.Updatable = true + f.Readable = true + } + }) + } +} diff --git a/schema/utils.go b/schema/utils.go new file mode 100644 index 00000000..add22047 --- /dev/null +++ b/schema/utils.go @@ -0,0 +1,197 @@ +package schema + +import ( + "reflect" + "regexp" + "strings" + + "gorm.io/gorm/clause" + "gorm.io/gorm/utils" +) + +var embeddedCacheKey = "embedded_cache_store" + +func ParseTagSetting(str string, sep string) map[string]string { + settings := map[string]string{} + names := strings.Split(str, sep) + + for i := 0; i < len(names); i++ { + j := i + if len(names[j]) > 0 { + for { + if names[j][len(names[j])-1] == '\\' { + i++ + names[j] = names[j][0:len(names[j])-1] + sep + names[i] + names[i] = "" + } else { + break + } + } + } + + values := strings.Split(names[j], ":") + k := strings.TrimSpace(strings.ToUpper(values[0])) + + if len(values) >= 2 { + settings[k] = strings.Join(values[1:], ":") + } else if k != "" { + settings[k] = k + } + } + + return settings +} + +func toColumns(val string) (results []string) { + if val != "" { + for _, v := range strings.Split(val, ",") { + results = append(results, strings.TrimSpace(v)) + } + } + return +} + +func removeSettingFromTag(tag reflect.StructTag, names ...string) reflect.StructTag { + for _, name := range names { + tag = reflect.StructTag(regexp.MustCompile(`(?i)(gorm:.*?)(`+name+`(:.*?)?)(;|("))`).ReplaceAllString(string(tag), "${1}${5}")) + } + return tag +} + +// GetRelationsValues get relations's values from a reflect value +func GetRelationsValues(reflectValue reflect.Value, rels []*Relationship) (reflectResults reflect.Value) { + for _, rel := range rels { + reflectResults = reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(rel.FieldSchema.ModelType)), 0, 1) + + appendToResults := func(value reflect.Value) { + if _, isZero := rel.Field.ValueOf(value); !isZero { + result := reflect.Indirect(rel.Field.ReflectValueOf(value)) + switch result.Kind() { + case reflect.Struct: + reflectResults = reflect.Append(reflectResults, result.Addr()) + case reflect.Slice, reflect.Array: + for i := 0; i < result.Len(); i++ { + if elem := result.Index(i); elem.Kind() == reflect.Ptr { + reflectResults = reflect.Append(reflectResults, elem) + } else { + reflectResults = reflect.Append(reflectResults, elem.Addr()) + } + } + } + } + } + + switch 
reflectValue.Kind() { + case reflect.Struct: + appendToResults(reflectValue) + case reflect.Slice: + for i := 0; i < reflectValue.Len(); i++ { + appendToResults(reflectValue.Index(i)) + } + } + + reflectValue = reflectResults + } + + return +} + +// GetIdentityFieldValuesMap get identity map from fields +func GetIdentityFieldValuesMap(reflectValue reflect.Value, fields []*Field) (map[string][]reflect.Value, [][]interface{}) { + var ( + results = [][]interface{}{} + dataResults = map[string][]reflect.Value{} + loaded = map[interface{}]bool{} + notZero, zero bool + ) + + switch reflectValue.Kind() { + case reflect.Struct: + results = [][]interface{}{make([]interface{}, len(fields))} + + for idx, field := range fields { + results[0][idx], zero = field.ValueOf(reflectValue) + notZero = notZero || !zero + } + + if !notZero { + return nil, nil + } + + dataResults[utils.ToStringKey(results[0]...)] = []reflect.Value{reflectValue} + case reflect.Slice, reflect.Array: + for i := 0; i < reflectValue.Len(); i++ { + elem := reflectValue.Index(i) + elemKey := elem.Interface() + if elem.Kind() != reflect.Ptr { + elemKey = elem.Addr().Interface() + } + + if _, ok := loaded[elemKey]; ok { + continue + } + loaded[elemKey] = true + + fieldValues := make([]interface{}, len(fields)) + notZero = false + for idx, field := range fields { + fieldValues[idx], zero = field.ValueOf(elem) + notZero = notZero || !zero + } + + if notZero { + dataKey := utils.ToStringKey(fieldValues...) + if _, ok := dataResults[dataKey]; !ok { + results = append(results, fieldValues) + dataResults[dataKey] = []reflect.Value{elem} + } else { + dataResults[dataKey] = append(dataResults[dataKey], elem) + } + } + } + } + + return dataResults, results +} + +// GetIdentityFieldValuesMapFromValues get identity map from fields +func GetIdentityFieldValuesMapFromValues(values []interface{}, fields []*Field) (map[string][]reflect.Value, [][]interface{}) { + resultsMap := map[string][]reflect.Value{} + results := [][]interface{}{} + + for _, v := range values { + rm, rs := GetIdentityFieldValuesMap(reflect.Indirect(reflect.ValueOf(v)), fields) + for k, v := range rm { + resultsMap[k] = append(resultsMap[k], v...) + } + results = append(results, rs...) 
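+ // identity maps built from each input value are merged: reflect values that share the same stringified field key end up under a single map entry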
+ } + return resultsMap, results +} + +// ToQueryValues to query values +func ToQueryValues(table string, foreignKeys []string, foreignValues [][]interface{}) (interface{}, []interface{}) { + queryValues := make([]interface{}, len(foreignValues)) + if len(foreignKeys) == 1 { + for idx, r := range foreignValues { + queryValues[idx] = r[0] + } + + return clause.Column{Table: table, Name: foreignKeys[0]}, queryValues + } else { + columns := make([]clause.Column, len(foreignKeys)) + for idx, key := range foreignKeys { + columns[idx] = clause.Column{Table: table, Name: key} + } + + for idx, r := range foreignValues { + queryValues[idx] = r + } + return columns, queryValues + } +} + +type embeddedNamer struct { + Table string + Namer +} diff --git a/schema/utils_test.go b/schema/utils_test.go new file mode 100644 index 00000000..1b47ef25 --- /dev/null +++ b/schema/utils_test.go @@ -0,0 +1,24 @@ +package schema + +import ( + "reflect" + "testing" +) + +func TestRemoveSettingFromTag(t *testing.T) { + tags := map[string]string{ + `gorm:"before:value;column:db;after:value" other:"before:value;column:db;after:value"`: `gorm:"before:value;after:value" other:"before:value;column:db;after:value"`, + `gorm:"before:value;column:db;" other:"before:value;column:db;after:value"`: `gorm:"before:value;" other:"before:value;column:db;after:value"`, + `gorm:"before:value;column:db" other:"before:value;column:db;after:value"`: `gorm:"before:value;" other:"before:value;column:db;after:value"`, + `gorm:"column:db" other:"before:value;column:db;after:value"`: `gorm:"" other:"before:value;column:db;after:value"`, + `gorm:"before:value;column:db ;after:value" other:"before:value;column:db;after:value"`: `gorm:"before:value;after:value" other:"before:value;column:db;after:value"`, + `gorm:"before:value;column:db; after:value" other:"before:value;column:db;after:value"`: `gorm:"before:value; after:value" other:"before:value;column:db;after:value"`, + `gorm:"before:value;column; after:value" other:"before:value;column:db;after:value"`: `gorm:"before:value; after:value" other:"before:value;column:db;after:value"`, + } + + for k, v := range tags { + if string(removeSettingFromTag(reflect.StructTag(k), "column")) != v { + t.Errorf("%v after removeSettingFromTag should equal %v, but got %v", k, v, removeSettingFromTag(reflect.StructTag(k), "column")) + } + } +} diff --git a/scope.go b/scope.go deleted file mode 100644 index 51ebd5a0..00000000 --- a/scope.go +++ /dev/null @@ -1,1327 +0,0 @@ -package gorm - -import ( - "database/sql" - "database/sql/driver" - "errors" - "fmt" - "regexp" - "strconv" - "strings" - "time" - - "reflect" -) - -// Scope contain current operation's information when you perform any operation on the database -type Scope struct { - Search *search - Value interface{} - SQL string - SQLVars []interface{} - db *DB - instanceID string - primaryKeyField *Field - skipLeft bool - fields *[]*Field - selectAttrs *[]string -} - -// IndirectValue return scope's reflect value's indirect value -func (scope *Scope) IndirectValue() reflect.Value { - return indirect(reflect.ValueOf(scope.Value)) -} - -// New create a new Scope without search information -func (scope *Scope) New(value interface{}) *Scope { - return &Scope{db: scope.NewDB(), Search: &search{}, Value: value} -} - -//////////////////////////////////////////////////////////////////////////////// -// Scope DB -//////////////////////////////////////////////////////////////////////////////// - -// DB return scope's DB connection -func (scope *Scope) DB() *DB { 
- return scope.db -} - -// NewDB create a new DB without search information -func (scope *Scope) NewDB() *DB { - if scope.db != nil { - db := scope.db.clone() - db.search = nil - db.Value = nil - return db - } - return nil -} - -// SQLDB return *sql.DB -func (scope *Scope) SQLDB() SQLCommon { - return scope.db.db -} - -// Dialect get dialect -func (scope *Scope) Dialect() Dialect { - return scope.db.parent.dialect -} - -// Quote used to quote string to escape them for database -func (scope *Scope) Quote(str string) string { - if strings.Index(str, ".") != -1 { - newStrs := []string{} - for _, str := range strings.Split(str, ".") { - newStrs = append(newStrs, scope.Dialect().Quote(str)) - } - return strings.Join(newStrs, ".") - } - - return scope.Dialect().Quote(str) -} - -// Err add error to Scope -func (scope *Scope) Err(err error) error { - if err != nil { - scope.db.AddError(err) - } - return err -} - -// HasError check if there are any error -func (scope *Scope) HasError() bool { - return scope.db.Error != nil -} - -// Log print log message -func (scope *Scope) Log(v ...interface{}) { - scope.db.log(v...) -} - -// SkipLeft skip remaining callbacks -func (scope *Scope) SkipLeft() { - scope.skipLeft = true -} - -// Fields get value's fields -func (scope *Scope) Fields() []*Field { - if scope.fields == nil { - var ( - fields []*Field - indirectScopeValue = scope.IndirectValue() - isStruct = indirectScopeValue.Kind() == reflect.Struct - ) - - for _, structField := range scope.GetModelStruct().StructFields { - if isStruct { - fieldValue := indirectScopeValue - for _, name := range structField.Names { - fieldValue = reflect.Indirect(fieldValue).FieldByName(name) - } - fields = append(fields, &Field{StructField: structField, Field: fieldValue, IsBlank: isBlank(fieldValue)}) - } else { - fields = append(fields, &Field{StructField: structField, IsBlank: true}) - } - } - scope.fields = &fields - } - - return *scope.fields -} - -// FieldByName find `gorm.Field` with field name or db name -func (scope *Scope) FieldByName(name string) (field *Field, ok bool) { - var ( - dbName = ToDBName(name) - mostMatchedField *Field - ) - - for _, field := range scope.Fields() { - if field.Name == name || field.DBName == name { - return field, true - } - if field.DBName == dbName { - mostMatchedField = field - } - } - return mostMatchedField, mostMatchedField != nil -} - -// PrimaryFields return scope's primary fields -func (scope *Scope) PrimaryFields() (fields []*Field) { - for _, field := range scope.Fields() { - if field.IsPrimaryKey { - fields = append(fields, field) - } - } - return fields -} - -// PrimaryField return scope's main primary field, if defined more that one primary fields, will return the one having column name `id` or the first one -func (scope *Scope) PrimaryField() *Field { - if primaryFields := scope.GetModelStruct().PrimaryFields; len(primaryFields) > 0 { - if len(primaryFields) > 1 { - if field, ok := scope.FieldByName("id"); ok { - return field - } - } - return scope.PrimaryFields()[0] - } - return nil -} - -// PrimaryKey get main primary field's db name -func (scope *Scope) PrimaryKey() string { - if field := scope.PrimaryField(); field != nil { - return field.DBName - } - return "" -} - -// PrimaryKeyZero check main primary field's value is blank or not -func (scope *Scope) PrimaryKeyZero() bool { - field := scope.PrimaryField() - return field == nil || field.IsBlank -} - -// PrimaryKeyValue get the primary key's value -func (scope *Scope) PrimaryKeyValue() interface{} { - if field := 
scope.PrimaryField(); field != nil && field.Field.IsValid() { - return field.Field.Interface() - } - return 0 -} - -// HasColumn to check if has column -func (scope *Scope) HasColumn(column string) bool { - for _, field := range scope.GetStructFields() { - if field.IsNormal && (field.Name == column || field.DBName == column) { - return true - } - } - return false -} - -// SetColumn to set the column's value, column could be field or field's name/dbname -func (scope *Scope) SetColumn(column interface{}, value interface{}) error { - var updateAttrs = map[string]interface{}{} - if attrs, ok := scope.InstanceGet("gorm:update_attrs"); ok { - updateAttrs = attrs.(map[string]interface{}) - defer scope.InstanceSet("gorm:update_attrs", updateAttrs) - } - - if field, ok := column.(*Field); ok { - updateAttrs[field.DBName] = value - return field.Set(value) - } else if name, ok := column.(string); ok { - var ( - dbName = ToDBName(name) - mostMatchedField *Field - ) - for _, field := range scope.Fields() { - if field.DBName == value { - updateAttrs[field.DBName] = value - return field.Set(value) - } - if (field.DBName == dbName) || (field.Name == name && mostMatchedField == nil) { - mostMatchedField = field - } - } - - if mostMatchedField != nil { - updateAttrs[mostMatchedField.DBName] = value - return mostMatchedField.Set(value) - } - } - return errors.New("could not convert column to field") -} - -// CallMethod call scope value's method, if it is a slice, will call its element's method one by one -func (scope *Scope) CallMethod(methodName string) { - if scope.Value == nil { - return - } - - if indirectScopeValue := scope.IndirectValue(); indirectScopeValue.Kind() == reflect.Slice { - for i := 0; i < indirectScopeValue.Len(); i++ { - scope.callMethod(methodName, indirectScopeValue.Index(i)) - } - } else { - scope.callMethod(methodName, indirectScopeValue) - } -} - -// AddToVars add value as sql's vars, used to prevent SQL injection -func (scope *Scope) AddToVars(value interface{}) string { - _, skipBindVar := scope.InstanceGet("skip_bindvar") - - if expr, ok := value.(*expr); ok { - exp := expr.expr - for _, arg := range expr.args { - if skipBindVar { - scope.AddToVars(arg) - } else { - exp = strings.Replace(exp, "?", scope.AddToVars(arg), 1) - } - } - return exp - } - - scope.SQLVars = append(scope.SQLVars, value) - - if skipBindVar { - return "?" - } - return scope.Dialect().BindVar(len(scope.SQLVars)) -} - -// SelectAttrs return selected attributes -func (scope *Scope) SelectAttrs() []string { - if scope.selectAttrs == nil { - attrs := []string{} - for _, value := range scope.Search.selects { - if str, ok := value.(string); ok { - attrs = append(attrs, str) - } else if strs, ok := value.([]string); ok { - attrs = append(attrs, strs...) 
- } else if strs, ok := value.([]interface{}); ok { - for _, str := range strs { - attrs = append(attrs, fmt.Sprintf("%v", str)) - } - } - } - scope.selectAttrs = &attrs - } - return *scope.selectAttrs -} - -// OmitAttrs return omitted attributes -func (scope *Scope) OmitAttrs() []string { - return scope.Search.omits -} - -type tabler interface { - TableName() string -} - -type dbTabler interface { - TableName(*DB) string -} - -// TableName return table name -func (scope *Scope) TableName() string { - if scope.Search != nil && len(scope.Search.tableName) > 0 { - return scope.Search.tableName - } - - if tabler, ok := scope.Value.(tabler); ok { - return tabler.TableName() - } - - if tabler, ok := scope.Value.(dbTabler); ok { - return tabler.TableName(scope.db) - } - - return scope.GetModelStruct().TableName(scope.db.Model(scope.Value)) -} - -// QuotedTableName return quoted table name -func (scope *Scope) QuotedTableName() (name string) { - if scope.Search != nil && len(scope.Search.tableName) > 0 { - if strings.Index(scope.Search.tableName, " ") != -1 { - return scope.Search.tableName - } - return scope.Quote(scope.Search.tableName) - } - - return scope.Quote(scope.TableName()) -} - -// CombinedConditionSql return combined condition sql -func (scope *Scope) CombinedConditionSql() string { - joinSQL := scope.joinsSQL() - whereSQL := scope.whereSQL() - if scope.Search.raw { - whereSQL = strings.TrimSuffix(strings.TrimPrefix(whereSQL, "WHERE ("), ")") - } - return joinSQL + whereSQL + scope.groupSQL() + - scope.havingSQL() + scope.orderSQL() + scope.limitAndOffsetSQL() -} - -// Raw set raw sql -func (scope *Scope) Raw(sql string) *Scope { - scope.SQL = strings.Replace(sql, "$$$", "?", -1) - return scope -} - -// Exec perform generated SQL -func (scope *Scope) Exec() *Scope { - defer scope.trace(NowFunc()) - - if !scope.HasError() { - if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { - if count, err := result.RowsAffected(); scope.Err(err) == nil { - scope.db.RowsAffected = count - } - } - } - return scope -} - -// Set set value by name -func (scope *Scope) Set(name string, value interface{}) *Scope { - scope.db.InstantSet(name, value) - return scope -} - -// Get get setting by name -func (scope *Scope) Get(name string) (interface{}, bool) { - return scope.db.Get(name) -} - -// InstanceID get InstanceID for scope -func (scope *Scope) InstanceID() string { - if scope.instanceID == "" { - scope.instanceID = fmt.Sprintf("%v%v", &scope, &scope.db) - } - return scope.instanceID -} - -// InstanceSet set instance setting for current operation, but not for operations in callbacks, like saving associations callback -func (scope *Scope) InstanceSet(name string, value interface{}) *Scope { - return scope.Set(name+scope.InstanceID(), value) -} - -// InstanceGet get instance setting from current operation -func (scope *Scope) InstanceGet(name string) (interface{}, bool) { - return scope.Get(name + scope.InstanceID()) -} - -// Begin start a transaction -func (scope *Scope) Begin() *Scope { - if db, ok := scope.SQLDB().(sqlDb); ok { - if tx, err := db.Begin(); err == nil { - scope.db.db = interface{}(tx).(SQLCommon) - scope.InstanceSet("gorm:started_transaction", true) - } - } - return scope -} - -// CommitOrRollback commit current transaction if no error happened, otherwise will rollback it -func (scope *Scope) CommitOrRollback() *Scope { - if _, ok := scope.InstanceGet("gorm:started_transaction"); ok { - if db, ok := scope.db.db.(sqlTx); ok { - if scope.HasError() { 
- db.Rollback() - } else { - scope.Err(db.Commit()) - } - scope.db.db = scope.db.parent.db - } - } - return scope -} - -//////////////////////////////////////////////////////////////////////////////// -// Private Methods For *gorm.Scope -//////////////////////////////////////////////////////////////////////////////// - -func (scope *Scope) callMethod(methodName string, reflectValue reflect.Value) { - // Only get address from non-pointer - if reflectValue.CanAddr() && reflectValue.Kind() != reflect.Ptr { - reflectValue = reflectValue.Addr() - } - - if methodValue := reflectValue.MethodByName(methodName); methodValue.IsValid() { - switch method := methodValue.Interface().(type) { - case func(): - method() - case func(*Scope): - method(scope) - case func(*DB): - newDB := scope.NewDB() - method(newDB) - scope.Err(newDB.Error) - case func() error: - scope.Err(method()) - case func(*Scope) error: - scope.Err(method(scope)) - case func(*DB) error: - newDB := scope.NewDB() - scope.Err(method(newDB)) - scope.Err(newDB.Error) - default: - scope.Err(fmt.Errorf("unsupported function %v", methodName)) - } - } -} - -var ( - columnRegexp = regexp.MustCompile("^[a-zA-Z\\d]+(\\.[a-zA-Z\\d]+)*$") // only match string like `name`, `users.name` - isNumberRegexp = regexp.MustCompile("^\\s*\\d+\\s*$") // match if string is number - comparisonRegexp = regexp.MustCompile("(?i) (=|<>|>|<|LIKE|IS|IN) ") - countingQueryRegexp = regexp.MustCompile("(?i)^count(.+)$") -) - -func (scope *Scope) quoteIfPossible(str string) string { - if columnRegexp.MatchString(str) { - return scope.Quote(str) - } - return str -} - -func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field) { - var ( - ignored interface{} - values = make([]interface{}, len(columns)) - selectFields []*Field - selectedColumnsMap = map[string]int{} - resetFields = map[int]*Field{} - ) - - for index, column := range columns { - values[index] = &ignored - - selectFields = fields - if idx, ok := selectedColumnsMap[column]; ok { - selectFields = selectFields[idx+1:] - } - - for fieldIndex, field := range selectFields { - if field.DBName == column { - if field.Field.Kind() == reflect.Ptr { - values[index] = field.Field.Addr().Interface() - } else { - reflectValue := reflect.New(reflect.PtrTo(field.Struct.Type)) - reflectValue.Elem().Set(field.Field.Addr()) - values[index] = reflectValue.Interface() - resetFields[index] = field - } - - selectedColumnsMap[column] = fieldIndex - - if field.IsNormal { - break - } - } - } - } - - scope.Err(rows.Scan(values...)) - - for index, field := range resetFields { - if v := reflect.ValueOf(values[index]).Elem().Elem(); v.IsValid() { - field.Field.Set(v) - } - } -} - -func (scope *Scope) primaryCondition(value interface{}) string { - return fmt.Sprintf("(%v.%v = %v)", scope.QuotedTableName(), scope.Quote(scope.PrimaryKey()), value) -} - -func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str string) { - switch value := clause["query"].(type) { - case string: - if isNumberRegexp.MatchString(value) { - return scope.primaryCondition(scope.AddToVars(value)) - } else if value != "" { - str = fmt.Sprintf("(%v)", value) - } - case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, sql.NullInt64: - return scope.primaryCondition(scope.AddToVars(value)) - case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string, []interface{}: - str = fmt.Sprintf("(%v.%v IN (?))", scope.QuotedTableName(), scope.Quote(scope.PrimaryKey())) - clause["args"] 
= []interface{}{value} - case map[string]interface{}: - var sqls []string - for key, value := range value { - if value != nil { - sqls = append(sqls, fmt.Sprintf("(%v.%v = %v)", scope.QuotedTableName(), scope.Quote(key), scope.AddToVars(value))) - } else { - sqls = append(sqls, fmt.Sprintf("(%v.%v IS NULL)", scope.QuotedTableName(), scope.Quote(key))) - } - } - return strings.Join(sqls, " AND ") - case interface{}: - var sqls []string - newScope := scope.New(value) - for _, field := range newScope.Fields() { - if !field.IsIgnored && !field.IsBlank { - sqls = append(sqls, fmt.Sprintf("(%v.%v = %v)", scope.QuotedTableName(), scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) - } - } - return strings.Join(sqls, " AND ") - } - - args := clause["args"].([]interface{}) - for _, arg := range args { - switch reflect.ValueOf(arg).Kind() { - case reflect.Slice: // For where("id in (?)", []int64{1,2}) - if bytes, ok := arg.([]byte); ok { - str = strings.Replace(str, "?", scope.AddToVars(bytes), 1) - } else if values := reflect.ValueOf(arg); values.Len() > 0 { - var tempMarks []string - for i := 0; i < values.Len(); i++ { - tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface())) - } - str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1) - } else { - str = strings.Replace(str, "?", scope.AddToVars(Expr("NULL")), 1) - } - default: - if valuer, ok := interface{}(arg).(driver.Valuer); ok { - arg, _ = valuer.Value() - } - - str = strings.Replace(str, "?", scope.AddToVars(arg), 1) - } - } - return -} - -func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string) { - var notEqualSQL string - var primaryKey = scope.PrimaryKey() - - switch value := clause["query"].(type) { - case string: - if isNumberRegexp.MatchString(value) { - id, _ := strconv.Atoi(value) - return fmt.Sprintf("(%v <> %v)", scope.Quote(primaryKey), id) - } else if comparisonRegexp.MatchString(value) { - str = fmt.Sprintf(" NOT (%v) ", value) - notEqualSQL = fmt.Sprintf("NOT (%v)", value) - } else { - str = fmt.Sprintf("(%v.%v NOT IN (?))", scope.QuotedTableName(), scope.Quote(value)) - notEqualSQL = fmt.Sprintf("(%v.%v <> ?)", scope.QuotedTableName(), scope.Quote(value)) - } - case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, sql.NullInt64: - return fmt.Sprintf("(%v.%v <> %v)", scope.QuotedTableName(), scope.Quote(primaryKey), value) - case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string: - if reflect.ValueOf(value).Len() > 0 { - str = fmt.Sprintf("(%v.%v NOT IN (?))", scope.QuotedTableName(), scope.Quote(primaryKey)) - clause["args"] = []interface{}{value} - } else { - return "" - } - case map[string]interface{}: - var sqls []string - for key, value := range value { - if value != nil { - sqls = append(sqls, fmt.Sprintf("(%v.%v <> %v)", scope.QuotedTableName(), scope.Quote(key), scope.AddToVars(value))) - } else { - sqls = append(sqls, fmt.Sprintf("(%v.%v IS NOT NULL)", scope.QuotedTableName(), scope.Quote(key))) - } - } - return strings.Join(sqls, " AND ") - case interface{}: - var sqls []string - var newScope = scope.New(value) - for _, field := range newScope.Fields() { - if !field.IsBlank { - sqls = append(sqls, fmt.Sprintf("(%v.%v <> %v)", scope.QuotedTableName(), scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) - } - } - return strings.Join(sqls, " AND ") - } - - args := clause["args"].([]interface{}) - for _, arg := range args { - switch reflect.ValueOf(arg).Kind() { - 
case reflect.Slice: // For where("id in (?)", []int64{1,2}) - if bytes, ok := arg.([]byte); ok { - str = strings.Replace(str, "?", scope.AddToVars(bytes), 1) - } else if values := reflect.ValueOf(arg); values.Len() > 0 { - var tempMarks []string - for i := 0; i < values.Len(); i++ { - tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface())) - } - str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1) - } else { - str = strings.Replace(str, "?", scope.AddToVars(Expr("NULL")), 1) - } - default: - if scanner, ok := interface{}(arg).(driver.Valuer); ok { - arg, _ = scanner.Value() - } - str = strings.Replace(notEqualSQL, "?", scope.AddToVars(arg), 1) - } - } - return -} - -func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string) { - switch value := clause["query"].(type) { - case string: - str = value - case []string: - str = strings.Join(value, ", ") - } - - args := clause["args"].([]interface{}) - for _, arg := range args { - switch reflect.ValueOf(arg).Kind() { - case reflect.Slice: - values := reflect.ValueOf(arg) - var tempMarks []string - for i := 0; i < values.Len(); i++ { - tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface())) - } - str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1) - default: - if valuer, ok := interface{}(arg).(driver.Valuer); ok { - arg, _ = valuer.Value() - } - str = strings.Replace(str, "?", scope.AddToVars(arg), 1) - } - } - return -} - -func (scope *Scope) whereSQL() (sql string) { - var ( - quotedTableName = scope.QuotedTableName() - deletedAtField, hasDeletedAtField = scope.FieldByName("DeletedAt") - primaryConditions, andConditions, orConditions []string - ) - - if !scope.Search.Unscoped && hasDeletedAtField { - sql := fmt.Sprintf("%v.%v IS NULL", quotedTableName, scope.Quote(deletedAtField.DBName)) - primaryConditions = append(primaryConditions, sql) - } - - if !scope.PrimaryKeyZero() { - for _, field := range scope.PrimaryFields() { - sql := fmt.Sprintf("%v.%v = %v", quotedTableName, scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())) - primaryConditions = append(primaryConditions, sql) - } - } - - for _, clause := range scope.Search.whereConditions { - if sql := scope.buildWhereCondition(clause); sql != "" { - andConditions = append(andConditions, sql) - } - } - - for _, clause := range scope.Search.orConditions { - if sql := scope.buildWhereCondition(clause); sql != "" { - orConditions = append(orConditions, sql) - } - } - - for _, clause := range scope.Search.notConditions { - if sql := scope.buildNotCondition(clause); sql != "" { - andConditions = append(andConditions, sql) - } - } - - orSQL := strings.Join(orConditions, " OR ") - combinedSQL := strings.Join(andConditions, " AND ") - if len(combinedSQL) > 0 { - if len(orSQL) > 0 { - combinedSQL = combinedSQL + " OR " + orSQL - } - } else { - combinedSQL = orSQL - } - - if len(primaryConditions) > 0 { - sql = "WHERE " + strings.Join(primaryConditions, " AND ") - if len(combinedSQL) > 0 { - sql = sql + " AND (" + combinedSQL + ")" - } - } else if len(combinedSQL) > 0 { - sql = "WHERE " + combinedSQL - } - return -} - -func (scope *Scope) selectSQL() string { - if len(scope.Search.selects) == 0 { - if len(scope.Search.joinConditions) > 0 { - return fmt.Sprintf("%v.*", scope.QuotedTableName()) - } - return "*" - } - return scope.buildSelectQuery(scope.Search.selects) -} - -func (scope *Scope) orderSQL() string { - if len(scope.Search.orders) == 0 || scope.Search.ignoreOrderQuery { - return "" - } - - 
var orders []string - for _, order := range scope.Search.orders { - if str, ok := order.(string); ok { - orders = append(orders, scope.quoteIfPossible(str)) - } else if expr, ok := order.(*expr); ok { - exp := expr.expr - for _, arg := range expr.args { - exp = strings.Replace(exp, "?", scope.AddToVars(arg), 1) - } - orders = append(orders, exp) - } - } - return " ORDER BY " + strings.Join(orders, ",") -} - -func (scope *Scope) limitAndOffsetSQL() string { - return scope.Dialect().LimitAndOffsetSQL(scope.Search.limit, scope.Search.offset) -} - -func (scope *Scope) groupSQL() string { - if len(scope.Search.group) == 0 { - return "" - } - return " GROUP BY " + scope.Search.group -} - -func (scope *Scope) havingSQL() string { - if len(scope.Search.havingConditions) == 0 { - return "" - } - - var andConditions []string - for _, clause := range scope.Search.havingConditions { - if sql := scope.buildWhereCondition(clause); sql != "" { - andConditions = append(andConditions, sql) - } - } - - combinedSQL := strings.Join(andConditions, " AND ") - if len(combinedSQL) == 0 { - return "" - } - - return " HAVING " + combinedSQL -} - -func (scope *Scope) joinsSQL() string { - var joinConditions []string - for _, clause := range scope.Search.joinConditions { - if sql := scope.buildWhereCondition(clause); sql != "" { - joinConditions = append(joinConditions, strings.TrimSuffix(strings.TrimPrefix(sql, "("), ")")) - } - } - - return strings.Join(joinConditions, " ") + " " -} - -func (scope *Scope) prepareQuerySQL() { - if scope.Search.raw { - scope.Raw(scope.CombinedConditionSql()) - } else { - scope.Raw(fmt.Sprintf("SELECT %v FROM %v %v", scope.selectSQL(), scope.QuotedTableName(), scope.CombinedConditionSql())) - } - return -} - -func (scope *Scope) inlineCondition(values ...interface{}) *Scope { - if len(values) > 0 { - scope.Search.Where(values[0], values[1:]...) 
- } - return scope -} - -func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope { - for _, f := range funcs { - (*f)(scope) - if scope.skipLeft { - break - } - } - return scope -} - -func convertInterfaceToMap(values interface{}, withIgnoredField bool) map[string]interface{} { - var attrs = map[string]interface{}{} - - switch value := values.(type) { - case map[string]interface{}: - return value - case []interface{}: - for _, v := range value { - for key, value := range convertInterfaceToMap(v, withIgnoredField) { - attrs[key] = value - } - } - case interface{}: - reflectValue := reflect.ValueOf(values) - - switch reflectValue.Kind() { - case reflect.Map: - for _, key := range reflectValue.MapKeys() { - attrs[ToDBName(key.Interface().(string))] = reflectValue.MapIndex(key).Interface() - } - default: - for _, field := range (&Scope{Value: values}).Fields() { - if !field.IsBlank && (withIgnoredField || !field.IsIgnored) { - attrs[field.DBName] = field.Field.Interface() - } - } - } - } - return attrs -} - -func (scope *Scope) updatedAttrsWithValues(value interface{}) (results map[string]interface{}, hasUpdate bool) { - if scope.IndirectValue().Kind() != reflect.Struct { - return convertInterfaceToMap(value, false), true - } - - results = map[string]interface{}{} - - for key, value := range convertInterfaceToMap(value, true) { - if field, ok := scope.FieldByName(key); ok && scope.changeableField(field) { - if _, ok := value.(*expr); ok { - hasUpdate = true - results[field.DBName] = value - } else { - err := field.Set(value) - if field.IsNormal { - hasUpdate = true - if err == ErrUnaddressable { - results[field.DBName] = value - } else { - results[field.DBName] = field.Field.Interface() - } - } - } - } - } - return -} - -func (scope *Scope) row() *sql.Row { - defer scope.trace(NowFunc()) - - result := &RowQueryResult{} - scope.InstanceSet("row_query_result", result) - scope.callCallbacks(scope.db.parent.callbacks.rowQueries) - - return result.Row -} - -func (scope *Scope) rows() (*sql.Rows, error) { - defer scope.trace(NowFunc()) - - result := &RowsQueryResult{} - scope.InstanceSet("row_query_result", result) - scope.callCallbacks(scope.db.parent.callbacks.rowQueries) - - return result.Rows, result.Error -} - -func (scope *Scope) initialize() *Scope { - for _, clause := range scope.Search.whereConditions { - scope.updatedAttrsWithValues(clause["query"]) - } - scope.updatedAttrsWithValues(scope.Search.initAttrs) - scope.updatedAttrsWithValues(scope.Search.assignAttrs) - return scope -} - -func (scope *Scope) pluck(column string, value interface{}) *Scope { - dest := reflect.Indirect(reflect.ValueOf(value)) - scope.Search.Select(column) - if dest.Kind() != reflect.Slice { - scope.Err(fmt.Errorf("results should be a slice, not %s", dest.Kind())) - return scope - } - - rows, err := scope.rows() - if scope.Err(err) == nil { - defer rows.Close() - for rows.Next() { - elem := reflect.New(dest.Type().Elem()).Interface() - scope.Err(rows.Scan(elem)) - dest.Set(reflect.Append(dest, reflect.ValueOf(elem).Elem())) - } - - if err := rows.Err(); err != nil { - scope.Err(err) - } - } - return scope -} - -func (scope *Scope) count(value interface{}) *Scope { - if query, ok := scope.Search.selects["query"]; !ok || !countingQueryRegexp.MatchString(fmt.Sprint(query)) { - scope.Search.Select("count(*)") - } - scope.Search.ignoreOrderQuery = true - scope.Err(scope.row().Scan(value)) - return scope -} - -func (scope *Scope) typeName() string { - typ := scope.IndirectValue().Type() - - for typ.Kind() == 
reflect.Slice || typ.Kind() == reflect.Ptr { - typ = typ.Elem() - } - - return typ.Name() -} - -// trace print sql log -func (scope *Scope) trace(t time.Time) { - if len(scope.SQL) > 0 { - scope.db.slog(scope.SQL, t, scope.SQLVars...) - } -} - -func (scope *Scope) changeableField(field *Field) bool { - if selectAttrs := scope.SelectAttrs(); len(selectAttrs) > 0 { - for _, attr := range selectAttrs { - if field.Name == attr || field.DBName == attr { - return true - } - } - return false - } - - for _, attr := range scope.OmitAttrs() { - if field.Name == attr || field.DBName == attr { - return false - } - } - - return true -} - -func (scope *Scope) shouldSaveAssociations() bool { - if saveAssociations, ok := scope.Get("gorm:save_associations"); ok { - if v, ok := saveAssociations.(bool); ok && !v { - return false - } - if v, ok := saveAssociations.(string); ok && (v != "skip") { - return false - } - } - return true && !scope.HasError() -} - -func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope { - toScope := scope.db.NewScope(value) - tx := scope.db.Set("gorm:association:source", scope.Value) - - for _, foreignKey := range append(foreignKeys, toScope.typeName()+"Id", scope.typeName()+"Id") { - fromField, _ := scope.FieldByName(foreignKey) - toField, _ := toScope.FieldByName(foreignKey) - - if fromField != nil { - if relationship := fromField.Relationship; relationship != nil { - if relationship.Kind == "many_to_many" { - joinTableHandler := relationship.JoinTableHandler - scope.Err(joinTableHandler.JoinWith(joinTableHandler, tx, scope.Value).Find(value).Error) - } else if relationship.Kind == "belongs_to" { - for idx, foreignKey := range relationship.ForeignDBNames { - if field, ok := scope.FieldByName(foreignKey); ok { - tx = tx.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.AssociationForeignDBNames[idx])), field.Field.Interface()) - } - } - scope.Err(tx.Find(value).Error) - } else if relationship.Kind == "has_many" || relationship.Kind == "has_one" { - for idx, foreignKey := range relationship.ForeignDBNames { - if field, ok := scope.FieldByName(relationship.AssociationForeignDBNames[idx]); ok { - tx = tx.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) - } - } - - if relationship.PolymorphicType != "" { - tx = tx.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), relationship.PolymorphicValue) - } - scope.Err(tx.Find(value).Error) - } - } else { - sql := fmt.Sprintf("%v = ?", scope.Quote(toScope.PrimaryKey())) - scope.Err(tx.Where(sql, fromField.Field.Interface()).Find(value).Error) - } - return scope - } else if toField != nil { - sql := fmt.Sprintf("%v = ?", scope.Quote(toField.DBName)) - scope.Err(tx.Where(sql, scope.PrimaryKeyValue()).Find(value).Error) - return scope - } - } - - scope.Err(fmt.Errorf("invalid association %v", foreignKeys)) - return scope -} - -// getTableOptions return the table options string or an empty string if the table options does not exist -func (scope *Scope) getTableOptions() string { - tableOptions, ok := scope.Get("gorm:table_options") - if !ok { - return "" - } - return tableOptions.(string) -} - -func (scope *Scope) createJoinTable(field *StructField) { - if relationship := field.Relationship; relationship != nil && relationship.JoinTableHandler != nil { - joinTableHandler := relationship.JoinTableHandler - joinTable := joinTableHandler.Table(scope.db) - if !scope.Dialect().HasTable(joinTable) { - toScope := &Scope{Value: reflect.New(field.Struct.Type).Interface()} - - var 
sqlTypes, primaryKeys []string - for idx, fieldName := range relationship.ForeignFieldNames { - if field, ok := scope.FieldByName(fieldName); ok { - foreignKeyStruct := field.clone() - foreignKeyStruct.IsPrimaryKey = false - foreignKeyStruct.TagSettings["IS_JOINTABLE_FOREIGNKEY"] = "true" - delete(foreignKeyStruct.TagSettings, "AUTO_INCREMENT") - sqlTypes = append(sqlTypes, scope.Quote(relationship.ForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(foreignKeyStruct)) - primaryKeys = append(primaryKeys, scope.Quote(relationship.ForeignDBNames[idx])) - } - } - - for idx, fieldName := range relationship.AssociationForeignFieldNames { - if field, ok := toScope.FieldByName(fieldName); ok { - foreignKeyStruct := field.clone() - foreignKeyStruct.IsPrimaryKey = false - foreignKeyStruct.TagSettings["IS_JOINTABLE_FOREIGNKEY"] = "true" - delete(foreignKeyStruct.TagSettings, "AUTO_INCREMENT") - sqlTypes = append(sqlTypes, scope.Quote(relationship.AssociationForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(foreignKeyStruct)) - primaryKeys = append(primaryKeys, scope.Quote(relationship.AssociationForeignDBNames[idx])) - } - } - - scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v, PRIMARY KEY (%v)) %s", scope.Quote(joinTable), strings.Join(sqlTypes, ","), strings.Join(primaryKeys, ","), scope.getTableOptions())).Error) - } - scope.NewDB().Table(joinTable).AutoMigrate(joinTableHandler) - } -} - -func (scope *Scope) createTable() *Scope { - var tags []string - var primaryKeys []string - var primaryKeyInColumnType = false - for _, field := range scope.GetModelStruct().StructFields { - if field.IsNormal { - sqlTag := scope.Dialect().DataTypeOf(field) - - // Check if the primary key constraint was specified as - // part of the column type. If so, we can only support - // one column as the primary key. 
- if strings.Contains(strings.ToLower(sqlTag), "primary key") { - primaryKeyInColumnType = true - } - - tags = append(tags, scope.Quote(field.DBName)+" "+sqlTag) - } - - if field.IsPrimaryKey { - primaryKeys = append(primaryKeys, scope.Quote(field.DBName)) - } - scope.createJoinTable(field) - } - - var primaryKeyStr string - if len(primaryKeys) > 0 && !primaryKeyInColumnType { - primaryKeyStr = fmt.Sprintf(", PRIMARY KEY (%v)", strings.Join(primaryKeys, ",")) - } - - scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v %v) %s", scope.QuotedTableName(), strings.Join(tags, ","), primaryKeyStr, scope.getTableOptions())).Exec() - - scope.autoIndex() - return scope -} - -func (scope *Scope) dropTable() *Scope { - scope.Raw(fmt.Sprintf("DROP TABLE %v", scope.QuotedTableName())).Exec() - return scope -} - -func (scope *Scope) modifyColumn(column string, typ string) { - scope.Raw(fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v TYPE %v", scope.QuotedTableName(), scope.Quote(column), typ)).Exec() -} - -func (scope *Scope) dropColumn(column string) { - scope.Raw(fmt.Sprintf("ALTER TABLE %v DROP COLUMN %v", scope.QuotedTableName(), scope.Quote(column))).Exec() -} - -func (scope *Scope) addIndex(unique bool, indexName string, column ...string) { - if scope.Dialect().HasIndex(scope.TableName(), indexName) { - return - } - - var columns []string - for _, name := range column { - columns = append(columns, scope.quoteIfPossible(name)) - } - - sqlCreate := "CREATE INDEX" - if unique { - sqlCreate = "CREATE UNIQUE INDEX" - } - - scope.Raw(fmt.Sprintf("%s %v ON %v(%v) %v", sqlCreate, indexName, scope.QuotedTableName(), strings.Join(columns, ", "), scope.whereSQL())).Exec() -} - -func (scope *Scope) addForeignKey(field string, dest string, onDelete string, onUpdate string) { - keyName := scope.Dialect().BuildForeignKeyName(scope.TableName(), field, dest) - - if scope.Dialect().HasForeignKey(scope.TableName(), keyName) { - return - } - var query = `ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s ON DELETE %s ON UPDATE %s;` - scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName), scope.quoteIfPossible(field), dest, onDelete, onUpdate)).Exec() -} - -func (scope *Scope) removeIndex(indexName string) { - scope.Dialect().RemoveIndex(scope.TableName(), indexName) -} - -func (scope *Scope) autoMigrate() *Scope { - tableName := scope.TableName() - quotedTableName := scope.QuotedTableName() - - if !scope.Dialect().HasTable(tableName) { - scope.createTable() - } else { - for _, field := range scope.GetModelStruct().StructFields { - if !scope.Dialect().HasColumn(tableName, field.DBName) { - if field.IsNormal { - sqlTag := scope.Dialect().DataTypeOf(field) - scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, scope.Quote(field.DBName), sqlTag)).Exec() - } - } - scope.createJoinTable(field) - } - scope.autoIndex() - } - return scope -} - -func (scope *Scope) autoIndex() *Scope { - var indexes = map[string][]string{} - var uniqueIndexes = map[string][]string{} - - for _, field := range scope.GetStructFields() { - if name, ok := field.TagSettings["INDEX"]; ok { - names := strings.Split(name, ",") - - for _, name := range names { - if name == "INDEX" || name == "" { - name = fmt.Sprintf("idx_%v_%v", scope.TableName(), field.DBName) - } - indexes[name] = append(indexes[name], field.DBName) - } - } - - if name, ok := field.TagSettings["UNIQUE_INDEX"]; ok { - names := strings.Split(name, ",") - - for _, name := range names { - if name == "UNIQUE_INDEX" || name == "" { - name = 
fmt.Sprintf("uix_%v_%v", scope.TableName(), field.DBName) - } - uniqueIndexes[name] = append(uniqueIndexes[name], field.DBName) - } - } - } - - for name, columns := range indexes { - scope.NewDB().Model(scope.Value).AddIndex(name, columns...) - } - - for name, columns := range uniqueIndexes { - scope.NewDB().Model(scope.Value).AddUniqueIndex(name, columns...) - } - - return scope -} - -func (scope *Scope) getColumnAsArray(columns []string, values ...interface{}) (results [][]interface{}) { - for _, value := range values { - indirectValue := indirect(reflect.ValueOf(value)) - - switch indirectValue.Kind() { - case reflect.Slice: - for i := 0; i < indirectValue.Len(); i++ { - var result []interface{} - var object = indirect(indirectValue.Index(i)) - var hasValue = false - for _, column := range columns { - field := object.FieldByName(column) - if hasValue || !isBlank(field) { - hasValue = true - } - result = append(result, field.Interface()) - } - - if hasValue { - results = append(results, result) - } - } - case reflect.Struct: - var result []interface{} - var hasValue = false - for _, column := range columns { - field := indirectValue.FieldByName(column) - if hasValue || !isBlank(field) { - hasValue = true - } - result = append(result, field.Interface()) - } - - if hasValue { - results = append(results, result) - } - } - } - - return -} - -func (scope *Scope) getColumnAsScope(column string) *Scope { - indirectScopeValue := scope.IndirectValue() - - switch indirectScopeValue.Kind() { - case reflect.Slice: - if fieldStruct, ok := scope.GetModelStruct().ModelType.FieldByName(column); ok { - fieldType := fieldStruct.Type - if fieldType.Kind() == reflect.Slice || fieldType.Kind() == reflect.Ptr { - fieldType = fieldType.Elem() - } - - resultsMap := map[interface{}]bool{} - results := reflect.New(reflect.SliceOf(reflect.PtrTo(fieldType))).Elem() - - for i := 0; i < indirectScopeValue.Len(); i++ { - result := indirect(indirect(indirectScopeValue.Index(i)).FieldByName(column)) - - if result.Kind() == reflect.Slice { - for j := 0; j < result.Len(); j++ { - if elem := result.Index(j); elem.CanAddr() && resultsMap[elem.Addr()] != true { - resultsMap[elem.Addr()] = true - results = reflect.Append(results, elem.Addr()) - } - } - } else if result.CanAddr() && resultsMap[result.Addr()] != true { - resultsMap[result.Addr()] = true - results = reflect.Append(results, result.Addr()) - } - } - return scope.New(results.Interface()) - } - case reflect.Struct: - if field := indirectScopeValue.FieldByName(column); field.CanAddr() { - return scope.New(field.Addr().Interface()) - } - } - return nil -} - -func (scope *Scope) hasConditions() bool { - return !scope.PrimaryKeyZero() || - len(scope.Search.whereConditions) > 0 || - len(scope.Search.orConditions) > 0 || - len(scope.Search.notConditions) > 0 -} diff --git a/scope_test.go b/scope_test.go deleted file mode 100644 index 42458995..00000000 --- a/scope_test.go +++ /dev/null @@ -1,43 +0,0 @@ -package gorm_test - -import ( - "github.com/jinzhu/gorm" - "testing" -) - -func NameIn1And2(d *gorm.DB) *gorm.DB { - return d.Where("name in (?)", []string{"ScopeUser1", "ScopeUser2"}) -} - -func NameIn2And3(d *gorm.DB) *gorm.DB { - return d.Where("name in (?)", []string{"ScopeUser2", "ScopeUser3"}) -} - -func NameIn(names []string) func(d *gorm.DB) *gorm.DB { - return func(d *gorm.DB) *gorm.DB { - return d.Where("name in (?)", names) - } -} - -func TestScopes(t *testing.T) { - user1 := User{Name: "ScopeUser1", Age: 1} - user2 := User{Name: "ScopeUser2", Age: 1} - user3 := 
User{Name: "ScopeUser3", Age: 2} - DB.Save(&user1).Save(&user2).Save(&user3) - - var users1, users2, users3 []User - DB.Scopes(NameIn1And2).Find(&users1) - if len(users1) != 2 { - t.Errorf("Should found two users's name in 1, 2") - } - - DB.Scopes(NameIn1And2, NameIn2And3).Find(&users2) - if len(users2) != 1 { - t.Errorf("Should found one user's name is 2") - } - - DB.Scopes(NameIn([]string{user1.Name, user3.Name})).Find(&users3) - if len(users3) != 2 { - t.Errorf("Should found two users's name in 1, 3") - } -} diff --git a/search.go b/search.go deleted file mode 100644 index 90138595..00000000 --- a/search.go +++ /dev/null @@ -1,153 +0,0 @@ -package gorm - -import ( - "fmt" -) - -type search struct { - db *DB - whereConditions []map[string]interface{} - orConditions []map[string]interface{} - notConditions []map[string]interface{} - havingConditions []map[string]interface{} - joinConditions []map[string]interface{} - initAttrs []interface{} - assignAttrs []interface{} - selects map[string]interface{} - omits []string - orders []interface{} - preload []searchPreload - offset interface{} - limit interface{} - group string - tableName string - raw bool - Unscoped bool - ignoreOrderQuery bool -} - -type searchPreload struct { - schema string - conditions []interface{} -} - -func (s *search) clone() *search { - clone := *s - return &clone -} - -func (s *search) Where(query interface{}, values ...interface{}) *search { - s.whereConditions = append(s.whereConditions, map[string]interface{}{"query": query, "args": values}) - return s -} - -func (s *search) Not(query interface{}, values ...interface{}) *search { - s.notConditions = append(s.notConditions, map[string]interface{}{"query": query, "args": values}) - return s -} - -func (s *search) Or(query interface{}, values ...interface{}) *search { - s.orConditions = append(s.orConditions, map[string]interface{}{"query": query, "args": values}) - return s -} - -func (s *search) Attrs(attrs ...interface{}) *search { - s.initAttrs = append(s.initAttrs, toSearchableMap(attrs...)) - return s -} - -func (s *search) Assign(attrs ...interface{}) *search { - s.assignAttrs = append(s.assignAttrs, toSearchableMap(attrs...)) - return s -} - -func (s *search) Order(value interface{}, reorder ...bool) *search { - if len(reorder) > 0 && reorder[0] { - s.orders = []interface{}{} - } - - if value != nil && value != "" { - s.orders = append(s.orders, value) - } - return s -} - -func (s *search) Select(query interface{}, args ...interface{}) *search { - s.selects = map[string]interface{}{"query": query, "args": args} - return s -} - -func (s *search) Omit(columns ...string) *search { - s.omits = columns - return s -} - -func (s *search) Limit(limit interface{}) *search { - s.limit = limit - return s -} - -func (s *search) Offset(offset interface{}) *search { - s.offset = offset - return s -} - -func (s *search) Group(query string) *search { - s.group = s.getInterfaceAsSQL(query) - return s -} - -func (s *search) Having(query interface{}, values ...interface{}) *search { - if val, ok := query.(*expr); ok { - s.havingConditions = append(s.havingConditions, map[string]interface{}{"query": val.expr, "args": val.args}) - } else { - s.havingConditions = append(s.havingConditions, map[string]interface{}{"query": query, "args": values}) - } - return s -} - -func (s *search) Joins(query string, values ...interface{}) *search { - s.joinConditions = append(s.joinConditions, map[string]interface{}{"query": query, "args": values}) - return s -} - -func (s *search) Preload(schema 
string, values ...interface{}) *search { - var preloads []searchPreload - for _, preload := range s.preload { - if preload.schema != schema { - preloads = append(preloads, preload) - } - } - preloads = append(preloads, searchPreload{schema, values}) - s.preload = preloads - return s -} - -func (s *search) Raw(b bool) *search { - s.raw = b - return s -} - -func (s *search) unscoped() *search { - s.Unscoped = true - return s -} - -func (s *search) Table(name string) *search { - s.tableName = name - return s -} - -func (s *search) getInterfaceAsSQL(value interface{}) (str string) { - switch value.(type) { - case string, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: - str = fmt.Sprintf("%v", value) - default: - s.db.AddError(ErrInvalidSQL) - } - - if str == "-1" { - return "" - } - return -} diff --git a/search_test.go b/search_test.go deleted file mode 100644 index 4db7ab6a..00000000 --- a/search_test.go +++ /dev/null @@ -1,30 +0,0 @@ -package gorm - -import ( - "reflect" - "testing" -) - -func TestCloneSearch(t *testing.T) { - s := new(search) - s.Where("name = ?", "jinzhu").Order("name").Attrs("name", "jinzhu").Select("name, age") - - s1 := s.clone() - s1.Where("age = ?", 20).Order("age").Attrs("email", "a@e.org").Select("email") - - if reflect.DeepEqual(s.whereConditions, s1.whereConditions) { - t.Errorf("Where should be copied") - } - - if reflect.DeepEqual(s.orders, s1.orders) { - t.Errorf("Order should be copied") - } - - if reflect.DeepEqual(s.initAttrs, s1.initAttrs) { - t.Errorf("InitAttrs should be copied") - } - - if reflect.DeepEqual(s.Select, s1.Select) { - t.Errorf("selectStr should be copied") - } -} diff --git a/soft_delete.go b/soft_delete.go new file mode 100644 index 00000000..b16041f1 --- /dev/null +++ b/soft_delete.go @@ -0,0 +1,138 @@ +package gorm + +import ( + "database/sql" + "database/sql/driver" + "encoding/json" + "reflect" + + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" +) + +type DeletedAt sql.NullTime + +// Scan implements the Scanner interface. +func (n *DeletedAt) Scan(value interface{}) error { + return (*sql.NullTime)(n).Scan(value) +} + +// Value implements the driver Valuer interface. 
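+// It returns the wrapped time when Valid is set, and database NULL otherwise.
+//
+// As a sketch of intended usage (the Order model is illustrative, not part of
+// this file), embedding the type gives a model soft-delete semantics:
+//
+//	type Order struct {
+//		ID        uint
+//		DeletedAt DeletedAt `gorm:"index"`
+//	}
+//
+// db.Delete(&Order{}, 1) then fills deleted_at through SoftDeleteDeleteClause
+// instead of issuing a hard DELETE, and ordinary queries gain an implicit
+// "deleted_at IS NULL" condition through SoftDeleteQueryClause below.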
+func (n DeletedAt) Value() (driver.Value, error) { + if !n.Valid { + return nil, nil + } + return n.Time, nil +} + +func (n DeletedAt) MarshalJSON() ([]byte, error) { + if n.Valid { + return json.Marshal(n.Time) + } + return json.Marshal(nil) +} + +func (n *DeletedAt) UnmarshalJSON(b []byte) error { + if string(b) == "null" { + n.Valid = false + return nil + } + err := json.Unmarshal(b, &n.Time) + if err == nil { + n.Valid = true + } + return err +} + +func (DeletedAt) QueryClauses(f *schema.Field) []clause.Interface { + return []clause.Interface{SoftDeleteQueryClause{Field: f}} +} + +type SoftDeleteQueryClause struct { + Field *schema.Field +} + +func (sd SoftDeleteQueryClause) Name() string { + return "" +} + +func (sd SoftDeleteQueryClause) Build(clause.Builder) { +} + +func (sd SoftDeleteQueryClause) MergeClause(*clause.Clause) { +} + +func (sd SoftDeleteQueryClause) ModifyStatement(stmt *Statement) { + if _, ok := stmt.Clauses["soft_delete_enabled"]; !ok { + if c, ok := stmt.Clauses["WHERE"]; ok { + if where, ok := c.Expression.(clause.Where); ok && len(where.Exprs) > 1 { + for _, expr := range where.Exprs { + if orCond, ok := expr.(clause.OrConditions); ok && len(orCond.Exprs) == 1 { + where.Exprs = []clause.Expression{clause.And(where.Exprs...)} + c.Expression = where + stmt.Clauses["WHERE"] = c + break + } + } + } + } + + stmt.AddClause(clause.Where{Exprs: []clause.Expression{ + clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: sd.Field.DBName}, Value: nil}, + }}) + stmt.Clauses["soft_delete_enabled"] = clause.Clause{} + } +} + +func (DeletedAt) DeleteClauses(f *schema.Field) []clause.Interface { + return []clause.Interface{SoftDeleteDeleteClause{Field: f}} +} + +type SoftDeleteDeleteClause struct { + Field *schema.Field +} + +func (sd SoftDeleteDeleteClause) Name() string { + return "" +} + +func (sd SoftDeleteDeleteClause) Build(clause.Builder) { +} + +func (sd SoftDeleteDeleteClause) MergeClause(*clause.Clause) { +} + +func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) { + if stmt.SQL.String() == "" { + curTime := stmt.DB.NowFunc() + stmt.AddClause(clause.Set{{Column: clause.Column{Name: sd.Field.DBName}, Value: curTime}}) + stmt.SetColumn(sd.Field.DBName, curTime, true) + + if stmt.Schema != nil { + _, queryValues := schema.GetIdentityFieldValuesMap(stmt.ReflectValue, stmt.Schema.PrimaryFields) + column, values := schema.ToQueryValues(stmt.Table, stmt.Schema.PrimaryFieldDBNames, queryValues) + + if len(values) > 0 { + stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}}) + } + + if stmt.ReflectValue.CanAddr() && stmt.Dest != stmt.Model && stmt.Model != nil { + _, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(stmt.Model), stmt.Schema.PrimaryFields) + column, values = schema.ToQueryValues(stmt.Table, stmt.Schema.PrimaryFieldDBNames, queryValues) + + if len(values) > 0 { + stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}}) + } + } + } + + if _, ok := stmt.Clauses["WHERE"]; !stmt.DB.AllowGlobalUpdate && !ok { + stmt.DB.AddError(ErrMissingWhereClause) + } else { + SoftDeleteQueryClause(sd).ModifyStatement(stmt) + } + + stmt.AddClauseIfNotExists(clause.Update{}) + stmt.Build("UPDATE", "SET", "WHERE") + } +} diff --git a/statement.go b/statement.go new file mode 100644 index 00000000..a87fd212 --- /dev/null +++ b/statement.go @@ -0,0 +1,674 @@ +package gorm + +import ( + "context" + "database/sql" + "database/sql/driver" + "fmt" + "reflect" + 
"sort" + "strconv" + "strings" + "sync" + + "gorm.io/gorm/clause" + "gorm.io/gorm/logger" + "gorm.io/gorm/schema" + "gorm.io/gorm/utils" +) + +// Statement statement +type Statement struct { + *DB + TableExpr *clause.Expr + Table string + Model interface{} + Unscoped bool + Dest interface{} + ReflectValue reflect.Value + Clauses map[string]clause.Clause + BuildClauses []string + Distinct bool + Selects []string // selected columns + Omits []string // omit columns + Joins []join + Preloads map[string][]interface{} + Settings sync.Map + ConnPool ConnPool + Schema *schema.Schema + Context context.Context + RaiseErrorOnNotFound bool + SkipHooks bool + SQL strings.Builder + Vars []interface{} + CurDestIndex int + attrs []interface{} + assigns []interface{} + scopes []func(*DB) *DB +} + +type join struct { + Name string + Conds []interface{} +} + +// StatementModifier statement modifier interface +type StatementModifier interface { + ModifyStatement(*Statement) +} + +// Write write string +func (stmt *Statement) WriteString(str string) (int, error) { + return stmt.SQL.WriteString(str) +} + +// Write write string +func (stmt *Statement) WriteByte(c byte) error { + return stmt.SQL.WriteByte(c) +} + +// WriteQuoted write quoted value +func (stmt *Statement) WriteQuoted(value interface{}) { + stmt.QuoteTo(&stmt.SQL, value) +} + +// QuoteTo write quoted value to writer +func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) { + switch v := field.(type) { + case clause.Table: + if v.Name == clause.CurrentTable { + if stmt.TableExpr != nil { + stmt.TableExpr.Build(stmt) + } else { + stmt.DB.Dialector.QuoteTo(writer, stmt.Table) + } + } else if v.Raw { + writer.WriteString(v.Name) + } else { + stmt.DB.Dialector.QuoteTo(writer, v.Name) + } + + if v.Alias != "" { + writer.WriteByte(' ') + stmt.DB.Dialector.QuoteTo(writer, v.Alias) + } + case clause.Column: + if v.Table != "" { + if v.Table == clause.CurrentTable { + stmt.DB.Dialector.QuoteTo(writer, stmt.Table) + } else { + stmt.DB.Dialector.QuoteTo(writer, v.Table) + } + writer.WriteByte('.') + } + + if v.Name == clause.PrimaryKey { + if stmt.Schema == nil { + stmt.DB.AddError(ErrModelValueRequired) + } else if stmt.Schema.PrioritizedPrimaryField != nil { + stmt.DB.Dialector.QuoteTo(writer, stmt.Schema.PrioritizedPrimaryField.DBName) + } else if len(stmt.Schema.DBNames) > 0 { + stmt.DB.Dialector.QuoteTo(writer, stmt.Schema.DBNames[0]) + } + } else if v.Raw { + writer.WriteString(v.Name) + } else { + stmt.DB.Dialector.QuoteTo(writer, v.Name) + } + + if v.Alias != "" { + writer.WriteString(" AS ") + stmt.DB.Dialector.QuoteTo(writer, v.Alias) + } + case []clause.Column: + writer.WriteByte('(') + for idx, d := range v { + if idx > 0 { + writer.WriteString(",") + } + stmt.QuoteTo(writer, d) + } + writer.WriteByte(')') + case string: + stmt.DB.Dialector.QuoteTo(writer, v) + case []string: + writer.WriteByte('(') + for idx, d := range v { + if idx > 0 { + writer.WriteString(",") + } + stmt.DB.Dialector.QuoteTo(writer, d) + } + writer.WriteByte(')') + default: + stmt.DB.Dialector.QuoteTo(writer, fmt.Sprint(field)) + } +} + +// Quote returns quoted value +func (stmt *Statement) Quote(field interface{}) string { + var builder strings.Builder + stmt.QuoteTo(&builder, field) + return builder.String() +} + +// Write write string +func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) { + for idx, v := range vars { + if idx > 0 { + writer.WriteByte(',') + } + + switch v := v.(type) { + case sql.NamedArg: + stmt.Vars = 
append(stmt.Vars, v.Value) + case clause.Column, clause.Table: + stmt.QuoteTo(writer, v) + case Valuer: + stmt.AddVar(writer, v.GormValue(stmt.Context, stmt.DB)) + case clause.Expr: + v.Build(stmt) + case *clause.Expr: + v.Build(stmt) + case driver.Valuer: + stmt.Vars = append(stmt.Vars, v) + stmt.DB.Dialector.BindVarTo(writer, stmt, v) + case []byte: + stmt.Vars = append(stmt.Vars, v) + stmt.DB.Dialector.BindVarTo(writer, stmt, v) + case []interface{}: + if len(v) > 0 { + writer.WriteByte('(') + stmt.AddVar(writer, v...) + writer.WriteByte(')') + } else { + writer.WriteString("(NULL)") + } + case *DB: + subdb := v.Session(&Session{Logger: logger.Discard, DryRun: true}).getInstance() + if v.Statement.SQL.Len() > 0 { + var ( + vars = subdb.Statement.Vars + sql = v.Statement.SQL.String() + ) + + subdb.Statement.Vars = make([]interface{}, 0, len(vars)) + for _, vv := range vars { + subdb.Statement.Vars = append(subdb.Statement.Vars, vv) + bindvar := strings.Builder{} + v.Dialector.BindVarTo(&bindvar, subdb.Statement, vv) + sql = strings.Replace(sql, bindvar.String(), "?", 1) + } + + subdb.Statement.SQL.Reset() + subdb.Statement.Vars = stmt.Vars + if strings.Contains(sql, "@") { + clause.NamedExpr{SQL: sql, Vars: vars}.Build(subdb.Statement) + } else { + clause.Expr{SQL: sql, Vars: vars}.Build(subdb.Statement) + } + } else { + subdb.Statement.Vars = append(stmt.Vars, subdb.Statement.Vars...) + subdb.callbacks.Query().Execute(subdb) + } + + writer.WriteString(subdb.Statement.SQL.String()) + stmt.Vars = subdb.Statement.Vars + default: + switch rv := reflect.ValueOf(v); rv.Kind() { + case reflect.Slice, reflect.Array: + if rv.Len() == 0 { + writer.WriteString("(NULL)") + } else { + writer.WriteByte('(') + for i := 0; i < rv.Len(); i++ { + if i > 0 { + writer.WriteByte(',') + } + stmt.AddVar(writer, rv.Index(i).Interface()) + } + writer.WriteByte(')') + } + default: + stmt.Vars = append(stmt.Vars, v) + stmt.DB.Dialector.BindVarTo(writer, stmt, v) + } + } + } +} + +// AddClause add clause +func (stmt *Statement) AddClause(v clause.Interface) { + if optimizer, ok := v.(StatementModifier); ok { + optimizer.ModifyStatement(stmt) + } else { + name := v.Name() + c := stmt.Clauses[name] + c.Name = name + v.MergeClause(&c) + stmt.Clauses[name] = c + } +} + +// AddClauseIfNotExists add clause if not exists +func (stmt *Statement) AddClauseIfNotExists(v clause.Interface) { + if c, ok := stmt.Clauses[v.Name()]; !ok || c.Expression == nil { + stmt.AddClause(v) + } +} + +// BuildCondition build condition +func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []clause.Expression { + if s, ok := query.(string); ok { + // if it is a number, then treats it as primary key + if _, err := strconv.Atoi(s); err != nil { + if s == "" && len(args) == 0 { + return nil + } else if len(args) == 0 || (len(args) > 0 && strings.Contains(s, "?")) { + // looks like a where condition + return []clause.Expression{clause.Expr{SQL: s, Vars: args}} + } else if len(args) > 0 && strings.Contains(s, "@") { + // looks like a named query + return []clause.Expression{clause.NamedExpr{SQL: s, Vars: args}} + } else if len(args) == 1 { + return []clause.Expression{clause.Eq{Column: s, Value: args[0]}} + } + } + } + + conds := make([]clause.Expression, 0, 4) + args = append([]interface{}{query}, args...) 
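+	// The loop below normalises each argument into clause expressions:
+	// clause.Expression values pass through unchanged, a *DB contributes the
+	// expressions of its WHERE clause, maps turn into Eq conditions (or IN for
+	// slice values) with string keys sorted so the generated SQL is stable,
+	// structs and slices of structs are parsed with schema.Parse and matched on
+	// their non-zero or explicitly selected fields, and any remaining plain
+	// values fall back to an IN condition on the primary key.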
+ for idx, arg := range args { + if valuer, ok := arg.(driver.Valuer); ok { + arg, _ = valuer.Value() + } + + switch v := arg.(type) { + case clause.Expression: + conds = append(conds, v) + case *DB: + if cs, ok := v.Statement.Clauses["WHERE"]; ok { + if where, ok := cs.Expression.(clause.Where); ok { + if len(where.Exprs) == 1 { + if orConds, ok := where.Exprs[0].(clause.OrConditions); ok { + where.Exprs[0] = clause.AndConditions(orConds) + } + } + conds = append(conds, clause.And(where.Exprs...)) + } else if cs.Expression != nil { + conds = append(conds, cs.Expression) + } + } + case map[interface{}]interface{}: + for i, j := range v { + conds = append(conds, clause.Eq{Column: i, Value: j}) + } + case map[string]string: + var keys = make([]string, 0, len(v)) + for i := range v { + keys = append(keys, i) + } + sort.Strings(keys) + + for _, key := range keys { + conds = append(conds, clause.Eq{Column: key, Value: v[key]}) + } + case map[string]interface{}: + var keys = make([]string, 0, len(v)) + for i := range v { + keys = append(keys, i) + } + sort.Strings(keys) + + for _, key := range keys { + reflectValue := reflect.Indirect(reflect.ValueOf(v[key])) + switch reflectValue.Kind() { + case reflect.Slice, reflect.Array: + if _, ok := v[key].(driver.Valuer); ok { + conds = append(conds, clause.Eq{Column: key, Value: v[key]}) + } else if _, ok := v[key].(Valuer); ok { + conds = append(conds, clause.Eq{Column: key, Value: v[key]}) + } else { + // optimize reflect value length + valueLen := reflectValue.Len() + values := make([]interface{}, valueLen) + for i := 0; i < valueLen; i++ { + values[i] = reflectValue.Index(i).Interface() + } + + conds = append(conds, clause.IN{Column: key, Values: values}) + } + default: + conds = append(conds, clause.Eq{Column: key, Value: v[key]}) + } + } + default: + reflectValue := reflect.Indirect(reflect.ValueOf(arg)) + for reflectValue.Kind() == reflect.Ptr { + reflectValue = reflectValue.Elem() + } + + if s, err := schema.Parse(arg, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil { + selectedColumns := map[string]bool{} + if idx == 0 { + for _, v := range args[1:] { + if vs, ok := v.(string); ok { + selectedColumns[vs] = true + } + } + } + restricted := len(selectedColumns) != 0 + + switch reflectValue.Kind() { + case reflect.Struct: + for _, field := range s.Fields { + selected := selectedColumns[field.DBName] || selectedColumns[field.Name] + if selected || (!restricted && field.Readable) { + if v, isZero := field.ValueOf(reflectValue); !isZero || selected { + if field.DBName != "" { + conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v}) + } else if field.DataType != "" { + conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v}) + } + } + } + } + case reflect.Slice, reflect.Array: + for i := 0; i < reflectValue.Len(); i++ { + for _, field := range s.Fields { + selected := selectedColumns[field.DBName] || selectedColumns[field.Name] + if selected || (!restricted && field.Readable) { + if v, isZero := field.ValueOf(reflectValue.Index(i)); !isZero || selected { + if field.DBName != "" { + conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v}) + } else if field.DataType != "" { + conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v}) + } + } + } + } + } + } + + if restricted { + break + } + } else if !reflectValue.IsValid() { 
+ stmt.AddError(ErrInvalidData) + } else if len(conds) == 0 { + if len(args) == 1 { + switch reflectValue.Kind() { + case reflect.Slice, reflect.Array: + // optimize reflect value length + valueLen := reflectValue.Len() + values := make([]interface{}, valueLen) + for i := 0; i < valueLen; i++ { + values[i] = reflectValue.Index(i).Interface() + } + + if len(values) > 0 { + conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: values}) + } + return conds + } + } + + conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: args}) + } + } + } + + return conds +} + +// Build build sql with clauses names +func (stmt *Statement) Build(clauses ...string) { + var firstClauseWritten bool + + for _, name := range clauses { + if c, ok := stmt.Clauses[name]; ok { + if firstClauseWritten { + stmt.WriteByte(' ') + } + + firstClauseWritten = true + if b, ok := stmt.DB.ClauseBuilders[name]; ok { + b(c, stmt) + } else { + c.Build(stmt) + } + } + } +} + +func (stmt *Statement) Parse(value interface{}) (err error) { + if stmt.Schema, err = schema.Parse(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil && stmt.Table == "" { + if tables := strings.Split(stmt.Schema.Table, "."); len(tables) == 2 { + stmt.TableExpr = &clause.Expr{SQL: stmt.Quote(stmt.Schema.Table)} + stmt.Table = tables[1] + return + } + + stmt.Table = stmt.Schema.Table + } + return err +} + +func (stmt *Statement) clone() *Statement { + newStmt := &Statement{ + TableExpr: stmt.TableExpr, + Table: stmt.Table, + Model: stmt.Model, + Unscoped: stmt.Unscoped, + Dest: stmt.Dest, + ReflectValue: stmt.ReflectValue, + Clauses: map[string]clause.Clause{}, + Distinct: stmt.Distinct, + Selects: stmt.Selects, + Omits: stmt.Omits, + Preloads: map[string][]interface{}{}, + ConnPool: stmt.ConnPool, + Schema: stmt.Schema, + Context: stmt.Context, + RaiseErrorOnNotFound: stmt.RaiseErrorOnNotFound, + SkipHooks: stmt.SkipHooks, + } + + if stmt.SQL.Len() > 0 { + newStmt.SQL.WriteString(stmt.SQL.String()) + newStmt.Vars = make([]interface{}, 0, len(stmt.Vars)) + newStmt.Vars = append(newStmt.Vars, stmt.Vars...) 
+ } + + for k, c := range stmt.Clauses { + newStmt.Clauses[k] = c + } + + for k, p := range stmt.Preloads { + newStmt.Preloads[k] = p + } + + if len(stmt.Joins) > 0 { + newStmt.Joins = make([]join, len(stmt.Joins)) + copy(newStmt.Joins, stmt.Joins) + } + + if len(stmt.scopes) > 0 { + newStmt.scopes = make([]func(*DB) *DB, len(stmt.scopes)) + copy(newStmt.scopes, stmt.scopes) + } + + stmt.Settings.Range(func(k, v interface{}) bool { + newStmt.Settings.Store(k, v) + return true + }) + + return newStmt +} + +// Helpers +// SetColumn set column's value +// stmt.SetColumn("Name", "jinzhu") // Hooks Method +// stmt.SetColumn("Name", "jinzhu", true) // Callbacks Method +func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks ...bool) { + if v, ok := stmt.Dest.(map[string]interface{}); ok { + v[name] = value + } else if v, ok := stmt.Dest.([]map[string]interface{}); ok { + for _, m := range v { + m[name] = value + } + } else if stmt.Schema != nil { + if field := stmt.Schema.LookUpField(name); field != nil { + destValue := reflect.ValueOf(stmt.Dest) + for destValue.Kind() == reflect.Ptr { + destValue = destValue.Elem() + } + + if stmt.ReflectValue != destValue { + if !destValue.CanAddr() { + destValueCanAddr := reflect.New(destValue.Type()) + destValueCanAddr.Elem().Set(destValue) + stmt.Dest = destValueCanAddr.Interface() + destValue = destValueCanAddr.Elem() + } + + switch destValue.Kind() { + case reflect.Struct: + field.Set(destValue, value) + default: + stmt.AddError(ErrInvalidData) + } + } + + if !stmt.ReflectValue.CanAddr() { + stmt.AddError(ErrInvalidValue) + return + } + + switch stmt.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + if len(fromCallbacks) > 0 { + for i := 0; i < stmt.ReflectValue.Len(); i++ { + field.Set(stmt.ReflectValue.Index(i), value) + } + } else { + field.Set(stmt.ReflectValue.Index(stmt.CurDestIndex), value) + } + case reflect.Struct: + field.Set(stmt.ReflectValue, value) + } + } else { + stmt.AddError(ErrInvalidField) + } + } else { + stmt.AddError(ErrInvalidField) + } +} + +// Changed check model changed or not when updating +func (stmt *Statement) Changed(fields ...string) bool { + modelValue := stmt.ReflectValue + switch modelValue.Kind() { + case reflect.Slice, reflect.Array: + modelValue = stmt.ReflectValue.Index(stmt.CurDestIndex) + } + + selectColumns, restricted := stmt.SelectAndOmitColumns(false, true) + changed := func(field *schema.Field) bool { + fieldValue, _ := field.ValueOf(modelValue) + if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { + if v, ok := stmt.Dest.(map[string]interface{}); ok { + if fv, ok := v[field.Name]; ok { + return !utils.AssertEqual(fv, fieldValue) + } else if fv, ok := v[field.DBName]; ok { + return !utils.AssertEqual(fv, fieldValue) + } + } else { + destValue := reflect.ValueOf(stmt.Dest) + for destValue.Kind() == reflect.Ptr { + destValue = destValue.Elem() + } + + changedValue, zero := field.ValueOf(destValue) + return !zero && !utils.AssertEqual(changedValue, fieldValue) + } + } + return false + } + + if len(fields) == 0 { + for _, field := range stmt.Schema.FieldsByDBName { + if changed(field) { + return true + } + } + } else { + for _, name := range fields { + if field := stmt.Schema.LookUpField(name); field != nil { + if changed(field) { + return true + } + } + } + } + + return false +} + +// SelectAndOmitColumns get select and omit columns, select -> true, omit -> false +func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) 
(map[string]bool, bool) { + results := map[string]bool{} + notRestricted := false + + // select columns + for _, column := range stmt.Selects { + if stmt.Schema == nil { + results[column] = true + } else if column == "*" { + notRestricted = true + for _, dbName := range stmt.Schema.DBNames { + results[dbName] = true + } + } else if column == clause.Associations { + for _, rel := range stmt.Schema.Relationships.Relations { + results[rel.Name] = true + } + } else if field := stmt.Schema.LookUpField(column); field != nil && field.DBName != "" { + results[field.DBName] = true + } else { + results[column] = true + } + } + + // omit columns + for _, omit := range stmt.Omits { + if stmt.Schema == nil { + results[omit] = false + } else if omit == clause.Associations { + for _, rel := range stmt.Schema.Relationships.Relations { + results[rel.Name] = false + } + } else if field := stmt.Schema.LookUpField(omit); field != nil && field.DBName != "" { + results[field.DBName] = false + } else { + results[omit] = false + } + } + + if stmt.Schema != nil { + for _, field := range stmt.Schema.FieldsByName { + name := field.DBName + if name == "" { + name = field.Name + } + + if requireCreate && !field.Creatable { + results[name] = false + } else if requireUpdate && !field.Updatable { + results[name] = false + } + } + } + + return results, !notRestricted && len(stmt.Selects) > 0 +} diff --git a/statement_test.go b/statement_test.go new file mode 100644 index 00000000..03ad81dc --- /dev/null +++ b/statement_test.go @@ -0,0 +1,36 @@ +package gorm + +import ( + "fmt" + "reflect" + "testing" + + "gorm.io/gorm/clause" +) + +func TestWhereCloneCorruption(t *testing.T) { + for whereCount := 1; whereCount <= 8; whereCount++ { + t.Run(fmt.Sprintf("w=%d", whereCount), func(t *testing.T) { + s := new(Statement) + for w := 0; w < whereCount; w++ { + s = s.clone() + s.AddClause(clause.Where{ + Exprs: s.BuildCondition(fmt.Sprintf("where%d", w)), + }) + } + + s1 := s.clone() + s1.AddClause(clause.Where{ + Exprs: s.BuildCondition("FINAL1"), + }) + s2 := s.clone() + s2.AddClause(clause.Where{ + Exprs: s.BuildCondition("FINAL2"), + }) + + if reflect.DeepEqual(s1.Clauses["WHERE"], s2.Clauses["WHERE"]) { + t.Errorf("Where conditions should be different") + } + }) + } +} diff --git a/test_all.sh b/test_all.sh deleted file mode 100755 index 80b319bf..00000000 --- a/test_all.sh +++ /dev/null @@ -1,5 +0,0 @@ -dialects=("postgres" "mysql" "mssql" "sqlite") - -for dialect in "${dialects[@]}" ; do - GORM_DIALECT=${dialect} go test -done diff --git a/tests/.gitignore b/tests/.gitignore new file mode 100644 index 00000000..08cb523c --- /dev/null +++ b/tests/.gitignore @@ -0,0 +1 @@ +go.sum diff --git a/tests/README.md b/tests/README.md new file mode 100644 index 00000000..6ae3337f --- /dev/null +++ b/tests/README.md @@ -0,0 +1,10 @@ +# Test Guide + +```bash +cd tests +# prepare test databases +docker-compose up + +# run all tests +./tests_all.sh +``` diff --git a/tests/associations_belongs_to_test.go b/tests/associations_belongs_to_test.go new file mode 100644 index 00000000..3e4de726 --- /dev/null +++ b/tests/associations_belongs_to_test.go @@ -0,0 +1,219 @@ +package tests_test + +import ( + "testing" + + . 
"gorm.io/gorm/utils/tests" +) + +func TestBelongsToAssociation(t *testing.T) { + var user = *GetUser("belongs-to", Config{Company: true, Manager: true}) + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + CheckUser(t, user, user) + + // Find + var user2 User + DB.Find(&user2, "id = ?", user.ID) + pointerOfUser := &user2 + if err := DB.Model(&pointerOfUser).Association("Company").Find(&user2.Company); err != nil { + t.Errorf("failed to query users, got error %#v", err) + } + user2.Manager = &User{} + DB.Model(&user2).Association("Manager").Find(user2.Manager) + CheckUser(t, user2, user) + + // Count + AssertAssociationCount(t, user, "Company", 1, "") + AssertAssociationCount(t, user, "Manager", 1, "") + + // Append + var company = Company{Name: "company-belongs-to-append"} + var manager = GetUser("manager-belongs-to-append", Config{}) + + if err := DB.Model(&user2).Association("Company").Append(&company); err != nil { + t.Fatalf("Error happened when append Company, got %v", err) + } + + if company.ID == 0 { + t.Fatalf("Company's ID should be created") + } + + if err := DB.Model(&user2).Association("Manager").Append(manager); err != nil { + t.Fatalf("Error happened when append Manager, got %v", err) + } + + if manager.ID == 0 { + t.Fatalf("Manager's ID should be created") + } + + user.Company = company + user.Manager = manager + user.CompanyID = &company.ID + user.ManagerID = &manager.ID + CheckUser(t, user2, user) + + AssertAssociationCount(t, user2, "Company", 1, "AfterAppend") + AssertAssociationCount(t, user2, "Manager", 1, "AfterAppend") + + // Replace + var company2 = Company{Name: "company-belongs-to-replace"} + var manager2 = GetUser("manager-belongs-to-replace", Config{}) + + if err := DB.Model(&user2).Association("Company").Replace(&company2); err != nil { + t.Fatalf("Error happened when replace Company, got %v", err) + } + + if company2.ID == 0 { + t.Fatalf("Company's ID should be created") + } + + if err := DB.Model(&user2).Association("Manager").Replace(manager2); err != nil { + t.Fatalf("Error happened when replace Manager, got %v", err) + } + + if manager2.ID == 0 { + t.Fatalf("Manager's ID should be created") + } + + user.Company = company2 + user.Manager = manager2 + user.CompanyID = &company2.ID + user.ManagerID = &manager2.ID + CheckUser(t, user2, user) + + AssertAssociationCount(t, user2, "Company", 1, "AfterReplace") + AssertAssociationCount(t, user2, "Manager", 1, "AfterReplace") + + // Delete + if err := DB.Model(&user2).Association("Company").Delete(&Company{}); err != nil { + t.Fatalf("Error happened when delete Company, got %v", err) + } + AssertAssociationCount(t, user2, "Company", 1, "after delete non-existing data") + + if err := DB.Model(&user2).Association("Company").Delete(&company2); err != nil { + t.Fatalf("Error happened when delete Company, got %v", err) + } + AssertAssociationCount(t, user2, "Company", 0, "after delete") + + if err := DB.Model(&user2).Association("Manager").Delete(&User{}); err != nil { + t.Fatalf("Error happened when delete Manager, got %v", err) + } + AssertAssociationCount(t, user2, "Manager", 1, "after delete non-existing data") + + if err := DB.Model(&user2).Association("Manager").Delete(manager2); err != nil { + t.Fatalf("Error happened when delete Manager, got %v", err) + } + AssertAssociationCount(t, user2, "Manager", 0, "after delete") + + // Prepare Data for Clear + if err := DB.Model(&user2).Association("Company").Append(&company); err != nil { + t.Fatalf("Error happened 
when append Company, got %v", err) + } + + if err := DB.Model(&user2).Association("Manager").Append(manager); err != nil { + t.Fatalf("Error happened when append Manager, got %v", err) + } + + AssertAssociationCount(t, user2, "Company", 1, "after prepare data") + AssertAssociationCount(t, user2, "Manager", 1, "after prepare data") + + // Clear + if err := DB.Model(&user2).Association("Company").Clear(); err != nil { + t.Errorf("Error happened when clear Company, got %v", err) + } + + if err := DB.Model(&user2).Association("Manager").Clear(); err != nil { + t.Errorf("Error happened when clear Manager, got %v", err) + } + + AssertAssociationCount(t, user2, "Company", 0, "after clear") + AssertAssociationCount(t, user2, "Manager", 0, "after clear") +} + +func TestBelongsToAssociationForSlice(t *testing.T) { + var users = []User{ + *GetUser("slice-belongs-to-1", Config{Company: true, Manager: true}), + *GetUser("slice-belongs-to-2", Config{Company: true, Manager: false}), + *GetUser("slice-belongs-to-3", Config{Company: true, Manager: true}), + } + + DB.Create(&users) + + AssertAssociationCount(t, users, "Company", 3, "") + AssertAssociationCount(t, users, "Manager", 2, "") + + // Find + var companies []Company + if DB.Model(&users).Association("Company").Find(&companies); len(companies) != 3 { + t.Errorf("companies count should be %v, but got %v", 3, len(companies)) + } + + var managers []User + if DB.Model(&users).Association("Manager").Find(&managers); len(managers) != 2 { + t.Errorf("managers count should be %v, but got %v", 2, len(managers)) + } + + // Append + DB.Model(&users).Association("Company").Append( + &Company{Name: "company-slice-append-1"}, + &Company{Name: "company-slice-append-2"}, + &Company{Name: "company-slice-append-3"}, + ) + + AssertAssociationCount(t, users, "Company", 3, "After Append") + + DB.Model(&users).Association("Manager").Append( + GetUser("manager-slice-belongs-to-1", Config{}), + GetUser("manager-slice-belongs-to-2", Config{}), + GetUser("manager-slice-belongs-to-3", Config{}), + ) + AssertAssociationCount(t, users, "Manager", 3, "After Append") + + if err := DB.Model(&users).Association("Manager").Append( + GetUser("manager-slice-belongs-to-test-1", Config{}), + ).Error; err == nil { + t.Errorf("unmatched length when update user's manager") + } + + // Replace -> same as append + + // Delete + if err := DB.Model(&users).Association("Company").Delete(&users[0].Company); err != nil { + t.Errorf("no error should happened when deleting company, but got %v", err) + } + + if users[0].CompanyID != nil || users[0].Company.ID != 0 { + t.Errorf("users[0]'s company should be deleted'") + } + + AssertAssociationCount(t, users, "Company", 2, "After Delete") + + // Clear + DB.Model(&users).Association("Company").Clear() + AssertAssociationCount(t, users, "Company", 0, "After Clear") + + DB.Model(&users).Association("Manager").Clear() + AssertAssociationCount(t, users, "Manager", 0, "After Clear") + + // shared company + company := Company{Name: "shared"} + if err := DB.Model(&users[0]).Association("Company").Append(&company); err != nil { + t.Errorf("Error happened when append company to user, got %v", err) + } + + if err := DB.Model(&users[1]).Association("Company").Append(&company); err != nil { + t.Errorf("Error happened when append company to user, got %v", err) + } + + if users[0].CompanyID == nil || users[1].CompanyID == nil || *users[0].CompanyID != *users[1].CompanyID { + t.Errorf("user's company id should exists and equal, but its: %v, %v", users[0].CompanyID, 
users[1].CompanyID) + } + + DB.Model(&users[0]).Association("Company").Delete(&company) + AssertAssociationCount(t, users[0], "Company", 0, "After Delete") + AssertAssociationCount(t, users[1], "Company", 1, "After other user Delete") +} diff --git a/tests/associations_has_many_test.go b/tests/associations_has_many_test.go new file mode 100644 index 00000000..173e9231 --- /dev/null +++ b/tests/associations_has_many_test.go @@ -0,0 +1,473 @@ +package tests_test + +import ( + "testing" + + . "gorm.io/gorm/utils/tests" +) + +func TestHasManyAssociation(t *testing.T) { + var user = *GetUser("hasmany", Config{Pets: 2}) + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + CheckUser(t, user, user) + + // Find + var user2 User + DB.Find(&user2, "id = ?", user.ID) + DB.Model(&user2).Association("Pets").Find(&user2.Pets) + CheckUser(t, user2, user) + + var pets []Pet + DB.Model(&user).Where("name = ?", user.Pets[0].Name).Association("Pets").Find(&pets) + + if len(pets) != 1 { + t.Fatalf("should only find one pets, but got %v", len(pets)) + } + + CheckPet(t, pets[0], *user.Pets[0]) + + if count := DB.Model(&user).Where("name = ?", user.Pets[1].Name).Association("Pets").Count(); count != 1 { + t.Fatalf("should only find one pets, but got %v", count) + } + + if count := DB.Model(&user).Where("name = ?", "not found").Association("Pets").Count(); count != 0 { + t.Fatalf("should only find no pet with invalid conditions, but got %v", count) + } + + // Count + AssertAssociationCount(t, user, "Pets", 2, "") + + // Append + var pet = Pet{Name: "pet-has-many-append"} + + if err := DB.Model(&user2).Association("Pets").Append(&pet); err != nil { + t.Fatalf("Error happened when append account, got %v", err) + } + + if pet.ID == 0 { + t.Fatalf("Pet's ID should be created") + } + + user.Pets = append(user.Pets, &pet) + CheckUser(t, user2, user) + + AssertAssociationCount(t, user, "Pets", 3, "AfterAppend") + + var pets2 = []Pet{{Name: "pet-has-many-append-1-1"}, {Name: "pet-has-many-append-1-1"}} + + if err := DB.Model(&user2).Association("Pets").Append(&pets2); err != nil { + t.Fatalf("Error happened when append pet, got %v", err) + } + + for _, pet := range pets2 { + var pet = pet + if pet.ID == 0 { + t.Fatalf("Pet's ID should be created") + } + + user.Pets = append(user.Pets, &pet) + } + + CheckUser(t, user2, user) + + AssertAssociationCount(t, user, "Pets", 5, "AfterAppendSlice") + + // Replace + var pet2 = Pet{Name: "pet-has-many-replace"} + + if err := DB.Model(&user2).Association("Pets").Replace(&pet2); err != nil { + t.Fatalf("Error happened when append pet, got %v", err) + } + + if pet2.ID == 0 { + t.Fatalf("pet2's ID should be created") + } + + user.Pets = []*Pet{&pet2} + CheckUser(t, user2, user) + + AssertAssociationCount(t, user2, "Pets", 1, "AfterReplace") + + // Delete + if err := DB.Model(&user2).Association("Pets").Delete(&Pet{}); err != nil { + t.Fatalf("Error happened when delete pet, got %v", err) + } + AssertAssociationCount(t, user2, "Pets", 1, "after delete non-existing data") + + if err := DB.Model(&user2).Association("Pets").Delete(&pet2); err != nil { + t.Fatalf("Error happened when delete Pets, got %v", err) + } + AssertAssociationCount(t, user2, "Pets", 0, "after delete") + + // Prepare Data for Clear + if err := DB.Model(&user2).Association("Pets").Append(&pet); err != nil { + t.Fatalf("Error happened when append Pets, got %v", err) + } + + AssertAssociationCount(t, user2, "Pets", 1, "after prepare data") + + // Clear + if err := 
DB.Model(&user2).Association("Pets").Clear(); err != nil { + t.Errorf("Error happened when clear Pets, got %v", err) + } + + AssertAssociationCount(t, user2, "Pets", 0, "after clear") +} + +func TestSingleTableHasManyAssociation(t *testing.T) { + var user = *GetUser("hasmany", Config{Team: 2}) + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + CheckUser(t, user, user) + + // Find + var user2 User + DB.Find(&user2, "id = ?", user.ID) + DB.Model(&user2).Association("Team").Find(&user2.Team) + CheckUser(t, user2, user) + + // Count + AssertAssociationCount(t, user, "Team", 2, "") + + // Append + var team = *GetUser("team", Config{}) + + if err := DB.Model(&user2).Association("Team").Append(&team); err != nil { + t.Fatalf("Error happened when append account, got %v", err) + } + + if team.ID == 0 { + t.Fatalf("Team's ID should be created") + } + + user.Team = append(user.Team, team) + CheckUser(t, user2, user) + + AssertAssociationCount(t, user, "Team", 3, "AfterAppend") + + var teams = []User{*GetUser("team-append-1", Config{}), *GetUser("team-append-2", Config{})} + + if err := DB.Model(&user2).Association("Team").Append(&teams); err != nil { + t.Fatalf("Error happened when append team, got %v", err) + } + + for _, team := range teams { + var team = team + if team.ID == 0 { + t.Fatalf("Team's ID should be created") + } + + user.Team = append(user.Team, team) + } + + CheckUser(t, user2, user) + + AssertAssociationCount(t, user, "Team", 5, "AfterAppendSlice") + + // Replace + var team2 = *GetUser("team-replace", Config{}) + + if err := DB.Model(&user2).Association("Team").Replace(&team2); err != nil { + t.Fatalf("Error happened when append team, got %v", err) + } + + if team2.ID == 0 { + t.Fatalf("team2's ID should be created") + } + + user.Team = []User{team2} + CheckUser(t, user2, user) + + AssertAssociationCount(t, user2, "Team", 1, "AfterReplace") + + // Delete + if err := DB.Model(&user2).Association("Team").Delete(&User{}); err != nil { + t.Fatalf("Error happened when delete team, got %v", err) + } + AssertAssociationCount(t, user2, "Team", 1, "after delete non-existing data") + + if err := DB.Model(&user2).Association("Team").Delete(&team2); err != nil { + t.Fatalf("Error happened when delete Team, got %v", err) + } + AssertAssociationCount(t, user2, "Team", 0, "after delete") + + // Prepare Data for Clear + if err := DB.Model(&user2).Association("Team").Append(&team); err != nil { + t.Fatalf("Error happened when append Team, got %v", err) + } + + AssertAssociationCount(t, user2, "Team", 1, "after prepare data") + + // Clear + if err := DB.Model(&user2).Association("Team").Clear(); err != nil { + t.Errorf("Error happened when clear Team, got %v", err) + } + + AssertAssociationCount(t, user2, "Team", 0, "after clear") +} + +func TestHasManyAssociationForSlice(t *testing.T) { + var users = []User{ + *GetUser("slice-hasmany-1", Config{Pets: 2}), + *GetUser("slice-hasmany-2", Config{Pets: 0}), + *GetUser("slice-hasmany-3", Config{Pets: 4}), + } + + DB.Create(&users) + + // Count + AssertAssociationCount(t, users, "Pets", 6, "") + + // Find + var pets []Pet + if DB.Model(&users).Association("Pets").Find(&pets); len(pets) != 6 { + t.Errorf("pets count should be %v, but got %v", 6, len(pets)) + } + + // Append + DB.Model(&users).Association("Pets").Append( + &Pet{Name: "pet-slice-append-1"}, + []*Pet{{Name: "pet-slice-append-2-1"}, {Name: "pet-slice-append-2-2"}}, + &Pet{Name: "pet-slice-append-3"}, + ) + + AssertAssociationCount(t, 
users, "Pets", 10, "After Append") + + // Replace -> same as append + DB.Model(&users).Association("Pets").Replace( + []*Pet{{Name: "pet-slice-replace-1-1"}, {Name: "pet-slice-replace-1-2"}}, + []*Pet{{Name: "pet-slice-replace-2-1"}, {Name: "pet-slice-replace-2-2"}}, + &Pet{Name: "pet-slice-replace-3"}, + ) + + AssertAssociationCount(t, users, "Pets", 5, "After Append") + + // Delete + if err := DB.Model(&users).Association("Pets").Delete(&users[2].Pets); err != nil { + t.Errorf("no error should happened when deleting pet, but got %v", err) + } + + AssertAssociationCount(t, users, "Pets", 4, "after delete") + + if err := DB.Model(&users).Association("Pets").Delete(users[0].Pets[0], users[1].Pets[1]); err != nil { + t.Errorf("no error should happened when deleting pet, but got %v", err) + } + + AssertAssociationCount(t, users, "Pets", 2, "after delete") + + // Clear + DB.Model(&users).Association("Pets").Clear() + AssertAssociationCount(t, users, "Pets", 0, "After Clear") +} + +func TestSingleTableHasManyAssociationForSlice(t *testing.T) { + var users = []User{ + *GetUser("slice-hasmany-1", Config{Team: 2}), + *GetUser("slice-hasmany-2", Config{Team: 0}), + *GetUser("slice-hasmany-3", Config{Team: 4}), + } + + if err := DB.Create(&users).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + // Count + AssertAssociationCount(t, users, "Team", 6, "") + + // Find + var teams []User + if DB.Model(&users).Association("Team").Find(&teams); len(teams) != 6 { + t.Errorf("teams count should be %v, but got %v", 6, len(teams)) + } + + // Append + DB.Model(&users).Association("Team").Append( + &User{Name: "pet-slice-append-1"}, + []*User{{Name: "pet-slice-append-2-1"}, {Name: "pet-slice-append-2-2"}}, + &User{Name: "pet-slice-append-3"}, + ) + + AssertAssociationCount(t, users, "Team", 10, "After Append") + + // Replace -> same as append + DB.Model(&users).Association("Team").Replace( + []*User{{Name: "pet-slice-replace-1-1"}, {Name: "pet-slice-replace-1-2"}}, + []*User{{Name: "pet-slice-replace-2-1"}, {Name: "pet-slice-replace-2-2"}}, + &User{Name: "pet-slice-replace-3"}, + ) + + AssertAssociationCount(t, users, "Team", 5, "After Append") + + // Delete + if err := DB.Model(&users).Association("Team").Delete(&users[2].Team); err != nil { + t.Errorf("no error should happened when deleting pet, but got %v", err) + } + + AssertAssociationCount(t, users, "Team", 4, "after delete") + + if err := DB.Model(&users).Association("Team").Delete(users[0].Team[0], users[1].Team[1]); err != nil { + t.Errorf("no error should happened when deleting pet, but got %v", err) + } + + AssertAssociationCount(t, users, "Team", 2, "after delete") + + // Clear + DB.Model(&users).Association("Team").Clear() + AssertAssociationCount(t, users, "Team", 0, "After Clear") +} + +func TestPolymorphicHasManyAssociation(t *testing.T) { + var user = *GetUser("hasmany", Config{Toys: 2}) + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + CheckUser(t, user, user) + + // Find + var user2 User + DB.Find(&user2, "id = ?", user.ID) + DB.Model(&user2).Association("Toys").Find(&user2.Toys) + CheckUser(t, user2, user) + + // Count + AssertAssociationCount(t, user, "Toys", 2, "") + + // Append + var toy = Toy{Name: "toy-has-many-append"} + + if err := DB.Model(&user2).Association("Toys").Append(&toy); err != nil { + t.Fatalf("Error happened when append account, got %v", err) + } + + if toy.ID == 0 { + t.Fatalf("Toy's ID should be created") + } + + user.Toys = 
append(user.Toys, toy) + CheckUser(t, user2, user) + + AssertAssociationCount(t, user, "Toys", 3, "AfterAppend") + + var toys = []Toy{{Name: "toy-has-many-append-1-1"}, {Name: "toy-has-many-append-1-1"}} + + if err := DB.Model(&user2).Association("Toys").Append(&toys); err != nil { + t.Fatalf("Error happened when append toy, got %v", err) + } + + for _, toy := range toys { + var toy = toy + if toy.ID == 0 { + t.Fatalf("Toy's ID should be created") + } + + user.Toys = append(user.Toys, toy) + } + + CheckUser(t, user2, user) + + AssertAssociationCount(t, user, "Toys", 5, "AfterAppendSlice") + + // Replace + var toy2 = Toy{Name: "toy-has-many-replace"} + + if err := DB.Model(&user2).Association("Toys").Replace(&toy2); err != nil { + t.Fatalf("Error happened when append toy, got %v", err) + } + + if toy2.ID == 0 { + t.Fatalf("toy2's ID should be created") + } + + user.Toys = []Toy{toy2} + CheckUser(t, user2, user) + + AssertAssociationCount(t, user2, "Toys", 1, "AfterReplace") + + // Delete + if err := DB.Model(&user2).Association("Toys").Delete(&Toy{}); err != nil { + t.Fatalf("Error happened when delete toy, got %v", err) + } + AssertAssociationCount(t, user2, "Toys", 1, "after delete non-existing data") + + if err := DB.Model(&user2).Association("Toys").Delete(&toy2); err != nil { + t.Fatalf("Error happened when delete Toys, got %v", err) + } + AssertAssociationCount(t, user2, "Toys", 0, "after delete") + + // Prepare Data for Clear + if err := DB.Model(&user2).Association("Toys").Append(&toy); err != nil { + t.Fatalf("Error happened when append Toys, got %v", err) + } + + AssertAssociationCount(t, user2, "Toys", 1, "after prepare data") + + // Clear + if err := DB.Model(&user2).Association("Toys").Clear(); err != nil { + t.Errorf("Error happened when clear Toys, got %v", err) + } + + AssertAssociationCount(t, user2, "Toys", 0, "after clear") +} + +func TestPolymorphicHasManyAssociationForSlice(t *testing.T) { + var users = []User{ + *GetUser("slice-hasmany-1", Config{Toys: 2}), + *GetUser("slice-hasmany-2", Config{Toys: 0}), + *GetUser("slice-hasmany-3", Config{Toys: 4}), + } + + DB.Create(&users) + + // Count + AssertAssociationCount(t, users, "Toys", 6, "") + + // Find + var toys []Toy + if DB.Model(&users).Association("Toys").Find(&toys); len(toys) != 6 { + t.Errorf("toys count should be %v, but got %v", 6, len(toys)) + } + + // Append + DB.Model(&users).Association("Toys").Append( + &Toy{Name: "toy-slice-append-1"}, + []Toy{{Name: "toy-slice-append-2-1"}, {Name: "toy-slice-append-2-2"}}, + &Toy{Name: "toy-slice-append-3"}, + ) + + AssertAssociationCount(t, users, "Toys", 10, "After Append") + + // Replace -> same as append + DB.Model(&users).Association("Toys").Replace( + []*Toy{{Name: "toy-slice-replace-1-1"}, {Name: "toy-slice-replace-1-2"}}, + []*Toy{{Name: "toy-slice-replace-2-1"}, {Name: "toy-slice-replace-2-2"}}, + &Toy{Name: "toy-slice-replace-3"}, + ) + + AssertAssociationCount(t, users, "Toys", 5, "After Append") + + // Delete + if err := DB.Model(&users).Association("Toys").Delete(&users[2].Toys); err != nil { + t.Errorf("no error should happened when deleting toy, but got %v", err) + } + + AssertAssociationCount(t, users, "Toys", 4, "after delete") + + if err := DB.Model(&users).Association("Toys").Delete(users[0].Toys[0], users[1].Toys[1]); err != nil { + t.Errorf("no error should happened when deleting toy, but got %v", err) + } + + AssertAssociationCount(t, users, "Toys", 2, "after delete") + + // Clear + DB.Model(&users).Association("Toys").Clear() + 
AssertAssociationCount(t, users, "Toys", 0, "After Clear") +} diff --git a/tests/associations_has_one_test.go b/tests/associations_has_one_test.go new file mode 100644 index 00000000..a4fc8c4f --- /dev/null +++ b/tests/associations_has_one_test.go @@ -0,0 +1,257 @@ +package tests_test + +import ( + "testing" + + . "gorm.io/gorm/utils/tests" +) + +func TestHasOneAssociation(t *testing.T) { + var user = *GetUser("hasone", Config{Account: true}) + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + CheckUser(t, user, user) + + // Find + var user2 User + DB.Find(&user2, "id = ?", user.ID) + DB.Model(&user2).Association("Account").Find(&user2.Account) + CheckUser(t, user2, user) + + // Count + AssertAssociationCount(t, user, "Account", 1, "") + + // Append + var account = Account{Number: "account-has-one-append"} + + if err := DB.Model(&user2).Association("Account").Append(&account); err != nil { + t.Fatalf("Error happened when append account, got %v", err) + } + + if account.ID == 0 { + t.Fatalf("Account's ID should be created") + } + + user.Account = account + CheckUser(t, user2, user) + + AssertAssociationCount(t, user, "Account", 1, "AfterAppend") + + // Replace + var account2 = Account{Number: "account-has-one-replace"} + + if err := DB.Model(&user2).Association("Account").Replace(&account2); err != nil { + t.Fatalf("Error happened when append Account, got %v", err) + } + + if account2.ID == 0 { + t.Fatalf("account2's ID should be created") + } + + user.Account = account2 + CheckUser(t, user2, user) + + AssertAssociationCount(t, user2, "Account", 1, "AfterReplace") + + // Delete + if err := DB.Model(&user2).Association("Account").Delete(&Account{}); err != nil { + t.Fatalf("Error happened when delete account, got %v", err) + } + AssertAssociationCount(t, user2, "Account", 1, "after delete non-existing data") + + if err := DB.Model(&user2).Association("Account").Delete(&account2); err != nil { + t.Fatalf("Error happened when delete Account, got %v", err) + } + AssertAssociationCount(t, user2, "Account", 0, "after delete") + + // Prepare Data for Clear + account = Account{Number: "account-has-one-append"} + if err := DB.Model(&user2).Association("Account").Append(&account); err != nil { + t.Fatalf("Error happened when append Account, got %v", err) + } + + AssertAssociationCount(t, user2, "Account", 1, "after prepare data") + + // Clear + if err := DB.Model(&user2).Association("Account").Clear(); err != nil { + t.Errorf("Error happened when clear Account, got %v", err) + } + + AssertAssociationCount(t, user2, "Account", 0, "after clear") +} + +func TestHasOneAssociationWithSelect(t *testing.T) { + var user = *GetUser("hasone", Config{Account: true}) + + DB.Omit("Account.Number").Create(&user) + + AssertAssociationCount(t, user, "Account", 1, "") + + var account Account + DB.Model(&user).Association("Account").Find(&account) + if account.Number != "" { + t.Errorf("account's number should not be saved") + } +} + +func TestHasOneAssociationForSlice(t *testing.T) { + var users = []User{ + *GetUser("slice-hasone-1", Config{Account: true}), + *GetUser("slice-hasone-2", Config{Account: false}), + *GetUser("slice-hasone-3", Config{Account: true}), + } + + DB.Create(&users) + + // Count + AssertAssociationCount(t, users, "Account", 2, "") + + // Find + var accounts []Account + if DB.Model(&users).Association("Account").Find(&accounts); len(accounts) != 2 { + t.Errorf("accounts count should be %v, but got %v", 3, len(accounts)) + } + + // Append + 
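+ // One account is appended per user in the slice (arguments are matched
+ // positionally), so each of the three users should end up with exactly one
+ // account and the count below is expected to be 3.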
DB.Model(&users).Association("Account").Append( + &Account{Number: "account-slice-append-1"}, + &Account{Number: "account-slice-append-2"}, + &Account{Number: "account-slice-append-3"}, + ) + + AssertAssociationCount(t, users, "Account", 3, "After Append") + + // Replace -> same as append + + // Delete + if err := DB.Model(&users).Association("Account").Delete(&users[0].Account); err != nil { + t.Errorf("no error should happen when deleting account, but got %v", err) + } + + AssertAssociationCount(t, users, "Account", 2, "after delete") + + // Clear + DB.Model(&users).Association("Account").Clear() + AssertAssociationCount(t, users, "Account", 0, "After Clear") +} + +func TestPolymorphicHasOneAssociation(t *testing.T) { + var pet = Pet{Name: "hasone", Toy: Toy{Name: "toy-has-one"}} + + if err := DB.Create(&pet).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + CheckPet(t, pet, pet) + + // Find + var pet2 Pet + DB.Find(&pet2, "id = ?", pet.ID) + DB.Model(&pet2).Association("Toy").Find(&pet2.Toy) + CheckPet(t, pet2, pet) + + // Count + AssertAssociationCount(t, pet, "Toy", 1, "") + + // Append + var toy = Toy{Name: "toy-has-one-append"} + + if err := DB.Model(&pet2).Association("Toy").Append(&toy); err != nil { + t.Fatalf("Error happened when append toy, got %v", err) + } + + if toy.ID == 0 { + t.Fatalf("Toy's ID should be created") + } + + pet.Toy = toy + CheckPet(t, pet2, pet) + + AssertAssociationCount(t, pet, "Toy", 1, "AfterAppend") + + // Replace + var toy2 = Toy{Name: "toy-has-one-replace"} + + if err := DB.Model(&pet2).Association("Toy").Replace(&toy2); err != nil { + t.Fatalf("Error happened when replace Toy, got %v", err) + } + + if toy2.ID == 0 { + t.Fatalf("toy2's ID should be created") + } + + pet.Toy = toy2 + CheckPet(t, pet2, pet) + + AssertAssociationCount(t, pet2, "Toy", 1, "AfterReplace") + + // Delete + if err := DB.Model(&pet2).Association("Toy").Delete(&Toy{}); err != nil { + t.Fatalf("Error happened when delete toy, got %v", err) + } + AssertAssociationCount(t, pet2, "Toy", 1, "after delete non-existing data") + + if err := DB.Model(&pet2).Association("Toy").Delete(&toy2); err != nil { + t.Fatalf("Error happened when delete Toy, got %v", err) + } + AssertAssociationCount(t, pet2, "Toy", 0, "after delete") + + // Prepare Data for Clear + toy = Toy{Name: "toy-has-one-append"} + if err := DB.Model(&pet2).Association("Toy").Append(&toy); err != nil { + t.Fatalf("Error happened when append Toy, got %v", err) + } + + AssertAssociationCount(t, pet2, "Toy", 1, "after prepare data") + + // Clear + if err := DB.Model(&pet2).Association("Toy").Clear(); err != nil { + t.Errorf("Error happened when clear Toy, got %v", err) + } + + AssertAssociationCount(t, pet2, "Toy", 0, "after clear") +} + +func TestPolymorphicHasOneAssociationForSlice(t *testing.T) { + var pets = []Pet{ + {Name: "hasone-1", Toy: Toy{Name: "toy-has-one"}}, + {Name: "hasone-2", Toy: Toy{}}, + {Name: "hasone-3", Toy: Toy{Name: "toy-has-one"}}, + } + + DB.Create(&pets) + + // Count + AssertAssociationCount(t, pets, "Toy", 2, "") + + // Find + var toys []Toy + if DB.Model(&pets).Association("Toy").Find(&toys); len(toys) != 2 { + t.Errorf("toys count should be %v, but got %v", 2, len(toys)) + } + + // Append + DB.Model(&pets).Association("Toy").Append( + &Toy{Name: "toy-slice-append-1"}, + &Toy{Name: "toy-slice-append-2"}, + &Toy{Name: "toy-slice-append-3"}, + ) + + AssertAssociationCount(t, pets, "Toy", 3, "After Append") + + // Replace -> same as append + + // Delete + if err := 
DB.Model(&pets).Association("Toy").Delete(&pets[0].Toy); err != nil { + t.Errorf("no error should happened when deleting toy, but got %v", err) + } + + AssertAssociationCount(t, pets, "Toy", 2, "after delete") + + // Clear + DB.Model(&pets).Association("Toy").Clear() + AssertAssociationCount(t, pets, "Toy", 0, "After Clear") +} diff --git a/tests/associations_many2many_test.go b/tests/associations_many2many_test.go new file mode 100644 index 00000000..739d1682 --- /dev/null +++ b/tests/associations_many2many_test.go @@ -0,0 +1,326 @@ +package tests_test + +import ( + "testing" + + . "gorm.io/gorm/utils/tests" +) + +func TestMany2ManyAssociation(t *testing.T) { + var user = *GetUser("many2many", Config{Languages: 2}) + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + CheckUser(t, user, user) + + // Find + var user2 User + DB.Find(&user2, "id = ?", user.ID) + DB.Model(&user2).Association("Languages").Find(&user2.Languages) + + CheckUser(t, user2, user) + + // Count + AssertAssociationCount(t, user, "Languages", 2, "") + + // Append + var language = Language{Code: "language-many2many-append", Name: "language-many2many-append"} + DB.Create(&language) + + if err := DB.Model(&user2).Association("Languages").Append(&language); err != nil { + t.Fatalf("Error happened when append account, got %v", err) + } + + user.Languages = append(user.Languages, language) + CheckUser(t, user2, user) + + AssertAssociationCount(t, user, "Languages", 3, "AfterAppend") + + var languages = []Language{ + {Code: "language-many2many-append-1-1", Name: "language-many2many-append-1-1"}, + {Code: "language-many2many-append-2-1", Name: "language-many2many-append-2-1"}, + } + DB.Create(&languages) + + if err := DB.Model(&user2).Association("Languages").Append(&languages); err != nil { + t.Fatalf("Error happened when append language, got %v", err) + } + + user.Languages = append(user.Languages, languages...) 
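+ // Keep the in-memory expectation (user) in sync with the appended languages so
+ // CheckUser can compare the reloaded user2 against it.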
+ + CheckUser(t, user2, user) + + AssertAssociationCount(t, user, "Languages", 5, "AfterAppendSlice") + + // Replace + var language2 = Language{Code: "language-many2many-replace", Name: "language-many2many-replace"} + DB.Create(&language2) + + if err := DB.Model(&user2).Association("Languages").Replace(&language2); err != nil { + t.Fatalf("Error happened when append language, got %v", err) + } + + user.Languages = []Language{language2} + CheckUser(t, user2, user) + + AssertAssociationCount(t, user2, "Languages", 1, "AfterReplace") + + // Delete + if err := DB.Model(&user2).Association("Languages").Delete(&Language{}); err != nil { + t.Fatalf("Error happened when delete language, got %v", err) + } + AssertAssociationCount(t, user2, "Languages", 1, "after delete non-existing data") + + if err := DB.Model(&user2).Association("Languages").Delete(&language2); err != nil { + t.Fatalf("Error happened when delete Languages, got %v", err) + } + AssertAssociationCount(t, user2, "Languages", 0, "after delete") + + // Prepare Data for Clear + if err := DB.Model(&user2).Association("Languages").Append(&language); err != nil { + t.Fatalf("Error happened when append Languages, got %v", err) + } + + AssertAssociationCount(t, user2, "Languages", 1, "after prepare data") + + // Clear + if err := DB.Model(&user2).Association("Languages").Clear(); err != nil { + t.Errorf("Error happened when clear Languages, got %v", err) + } + + AssertAssociationCount(t, user2, "Languages", 0, "after clear") +} + +func TestMany2ManyOmitAssociations(t *testing.T) { + var user = *GetUser("many2many_omit_associations", Config{Languages: 2}) + + if err := DB.Omit("Languages.*").Create(&user).Error; err == nil { + t.Fatalf("should raise error when create users without languages reference") + } + + if err := DB.Create(&user.Languages).Error; err != nil { + t.Fatalf("no error should happen when create languages, but got %v", err) + } + + if err := DB.Omit("Languages.*").Create(&user).Error; err != nil { + t.Fatalf("no error should happen when create user when languages exists, but got %v", err) + } + + // Find + var languages []Language + if DB.Model(&user).Association("Languages").Find(&languages); len(languages) != 2 { + t.Errorf("languages count should be %v, but got %v", 2, len(languages)) + } + + var newLang = Language{Code: "omitmany2many", Name: "omitmany2many"} + if err := DB.Model(&user).Omit("Languages.*").Association("Languages").Replace(&newLang); err == nil { + t.Errorf("should failed to insert languages due to constraint failed, error: %v", err) + } +} + +func TestMany2ManyAssociationForSlice(t *testing.T) { + var users = []User{ + *GetUser("slice-many2many-1", Config{Languages: 2}), + *GetUser("slice-many2many-2", Config{Languages: 0}), + *GetUser("slice-many2many-3", Config{Languages: 4}), + } + + DB.Create(&users) + + // Count + AssertAssociationCount(t, users, "Languages", 6, "") + + // Find + var languages []Language + if DB.Model(&users).Association("Languages").Find(&languages); len(languages) != 6 { + t.Errorf("languages count should be %v, but got %v", 6, len(languages)) + } + + // Append + var languages1 = []Language{ + {Code: "language-many2many-append-1", Name: "language-many2many-append-1"}, + } + var languages2 = []Language{} + var languages3 = []Language{ + {Code: "language-many2many-append-3-1", Name: "language-many2many-append-3-1"}, + {Code: "language-many2many-append-3-2", Name: "language-many2many-append-3-2"}, + } + DB.Create(&languages1) + DB.Create(&languages3) + + 
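+ // Batch Append passes one argument per element of the users slice: languages1
+ // for the first user, the empty languages2 for the second and languages3 for the
+ // third, which should raise the total association count from 6 to 9.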
DB.Model(&users).Association("Languages").Append(&languages1, &languages2, &languages3) + + AssertAssociationCount(t, users, "Languages", 9, "After Append") + + languages2_1 := []*Language{ + {Code: "language-slice-replace-1-1", Name: "language-slice-replace-1-1"}, + {Code: "language-slice-replace-1-2", Name: "language-slice-replace-1-2"}, + } + languages2_2 := []*Language{ + {Code: "language-slice-replace-2-1", Name: "language-slice-replace-2-1"}, + {Code: "language-slice-replace-2-2", Name: "language-slice-replace-2-2"}, + } + languages2_3 := &Language{Code: "language-slice-replace-3", Name: "language-slice-replace-3"} + DB.Create(&languages2_1) + DB.Create(&languages2_2) + DB.Create(&languages2_3) + + // Replace + DB.Model(&users).Association("Languages").Replace(&languages2_1, &languages2_2, languages2_3) + + AssertAssociationCount(t, users, "Languages", 5, "After Replace") + + // Delete + if err := DB.Model(&users).Association("Languages").Delete(&users[2].Languages); err != nil { + t.Errorf("no error should happened when deleting language, but got %v", err) + } + + AssertAssociationCount(t, users, "Languages", 4, "after delete") + + if err := DB.Model(&users).Association("Languages").Delete(users[0].Languages[0], users[1].Languages[1]); err != nil { + t.Errorf("no error should happened when deleting language, but got %v", err) + } + + AssertAssociationCount(t, users, "Languages", 2, "after delete") + + // Clear + DB.Model(&users).Association("Languages").Clear() + AssertAssociationCount(t, users, "Languages", 0, "After Clear") +} + +func TestSingleTableMany2ManyAssociation(t *testing.T) { + var user = *GetUser("many2many", Config{Friends: 2}) + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + CheckUser(t, user, user) + + // Find + var user2 User + DB.Find(&user2, "id = ?", user.ID) + DB.Model(&user2).Association("Friends").Find(&user2.Friends) + + CheckUser(t, user2, user) + + // Count + AssertAssociationCount(t, user, "Friends", 2, "") + + // Append + var friend = *GetUser("friend", Config{}) + + if err := DB.Model(&user2).Association("Friends").Append(&friend); err != nil { + t.Fatalf("Error happened when append account, got %v", err) + } + + user.Friends = append(user.Friends, &friend) + CheckUser(t, user2, user) + + AssertAssociationCount(t, user, "Friends", 3, "AfterAppend") + + var friends = []*User{GetUser("friend-append-1", Config{}), GetUser("friend-append-2", Config{})} + + if err := DB.Model(&user2).Association("Friends").Append(&friends); err != nil { + t.Fatalf("Error happened when append friend, got %v", err) + } + + user.Friends = append(user.Friends, friends...) 
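+ // Friends is a self-referential many2many (User to User), so the freshly
+ // appended users are also recorded on the expected struct before comparing it
+ // with the reloaded user2.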
+ + CheckUser(t, user2, user) + + AssertAssociationCount(t, user, "Friends", 5, "AfterAppendSlice") + + // Replace + var friend2 = *GetUser("friend-replace-2", Config{}) + + if err := DB.Model(&user2).Association("Friends").Replace(&friend2); err != nil { + t.Fatalf("Error happened when append friend, got %v", err) + } + + user.Friends = []*User{&friend2} + CheckUser(t, user2, user) + + AssertAssociationCount(t, user2, "Friends", 1, "AfterReplace") + + // Delete + if err := DB.Model(&user2).Association("Friends").Delete(&User{}); err != nil { + t.Fatalf("Error happened when delete friend, got %v", err) + } + AssertAssociationCount(t, user2, "Friends", 1, "after delete non-existing data") + + if err := DB.Model(&user2).Association("Friends").Delete(&friend2); err != nil { + t.Fatalf("Error happened when delete Friends, got %v", err) + } + AssertAssociationCount(t, user2, "Friends", 0, "after delete") + + // Prepare Data for Clear + if err := DB.Model(&user2).Association("Friends").Append(&friend); err != nil { + t.Fatalf("Error happened when append Friends, got %v", err) + } + + AssertAssociationCount(t, user2, "Friends", 1, "after prepare data") + + // Clear + if err := DB.Model(&user2).Association("Friends").Clear(); err != nil { + t.Errorf("Error happened when clear Friends, got %v", err) + } + + AssertAssociationCount(t, user2, "Friends", 0, "after clear") +} + +func TestSingleTableMany2ManyAssociationForSlice(t *testing.T) { + var users = []User{ + *GetUser("slice-many2many-1", Config{Team: 2}), + *GetUser("slice-many2many-2", Config{Team: 0}), + *GetUser("slice-many2many-3", Config{Team: 4}), + } + + DB.Create(&users) + + // Count + AssertAssociationCount(t, users, "Team", 6, "") + + // Find + var teams []User + if DB.Model(&users).Association("Team").Find(&teams); len(teams) != 6 { + t.Errorf("teams count should be %v, but got %v", 6, len(teams)) + } + + // Append + var teams1 = []User{*GetUser("friend-append-1", Config{})} + var teams2 = []User{} + var teams3 = []*User{GetUser("friend-append-3-1", Config{}), GetUser("friend-append-3-2", Config{})} + + DB.Model(&users).Association("Team").Append(&teams1, &teams2, &teams3) + + AssertAssociationCount(t, users, "Team", 9, "After Append") + + var teams2_1 = []User{*GetUser("friend-replace-1", Config{}), *GetUser("friend-replace-2", Config{})} + var teams2_2 = []User{*GetUser("friend-replace-2-1", Config{}), *GetUser("friend-replace-2-2", Config{})} + var teams2_3 = GetUser("friend-replace-3-1", Config{}) + + // Replace + DB.Model(&users).Association("Team").Replace(&teams2_1, &teams2_2, teams2_3) + + AssertAssociationCount(t, users, "Team", 5, "After Replace") + + // Delete + if err := DB.Model(&users).Association("Team").Delete(&users[2].Team); err != nil { + t.Errorf("no error should happened when deleting team, but got %v", err) + } + + AssertAssociationCount(t, users, "Team", 4, "after delete") + + if err := DB.Model(&users).Association("Team").Delete(users[0].Team[0], users[1].Team[1]); err != nil { + t.Errorf("no error should happened when deleting team, but got %v", err) + } + + AssertAssociationCount(t, users, "Team", 2, "after delete") + + // Clear + DB.Model(&users).Association("Team").Clear() + AssertAssociationCount(t, users, "Team", 0, "After Clear") +} diff --git a/tests/associations_test.go b/tests/associations_test.go new file mode 100644 index 00000000..f470338f --- /dev/null +++ b/tests/associations_test.go @@ -0,0 +1,178 @@ +package tests_test + +import ( + "testing" + + "gorm.io/gorm" + . 
"gorm.io/gorm/utils/tests" +) + +func AssertAssociationCount(t *testing.T, data interface{}, name string, result int64, reason string) { + if count := DB.Model(data).Association(name).Count(); count != result { + t.Fatalf("invalid %v count %v, expects: %v got %v", name, reason, result, count) + } + + var newUser User + if user, ok := data.(User); ok { + DB.Find(&newUser, "id = ?", user.ID) + } else if user, ok := data.(*User); ok { + DB.Find(&newUser, "id = ?", user.ID) + } + + if newUser.ID != 0 { + if count := DB.Model(&newUser).Association(name).Count(); count != result { + t.Fatalf("invalid %v count %v, expects: %v got %v", name, reason, result, count) + } + } +} + +func TestInvalidAssociation(t *testing.T) { + var user = *GetUser("invalid", Config{Company: true, Manager: true}) + if err := DB.Model(&user).Association("Invalid").Find(&user.Company).Error; err == nil { + t.Fatalf("should return errors for invalid association, but got nil") + } +} + +func TestAssociationNotNullClear(t *testing.T) { + type Profile struct { + gorm.Model + Number string + MemberID uint `gorm:"not null"` + } + + type Member struct { + gorm.Model + Profiles []Profile + } + + DB.Migrator().DropTable(&Member{}, &Profile{}) + + if err := DB.AutoMigrate(&Member{}, &Profile{}); err != nil { + t.Fatalf("Failed to migrate, got error: %v", err) + } + + member := &Member{ + Profiles: []Profile{{ + Number: "1", + }, { + Number: "2", + }}, + } + + if err := DB.Create(&member).Error; err != nil { + t.Fatalf("Failed to create test data, got error: %v", err) + } + + if err := DB.Model(member).Association("Profiles").Clear(); err == nil { + t.Fatalf("No error occured during clearind not null association") + } +} + +func TestForeignKeyConstraints(t *testing.T) { + type Profile struct { + ID uint + Name string + MemberID uint + } + + type Member struct { + ID uint + Refer uint `gorm:"uniqueIndex"` + Name string + Profile Profile `gorm:"Constraint:OnUpdate:CASCADE,OnDelete:CASCADE;FOREIGNKEY:MemberID;References:Refer"` + } + + DB.Migrator().DropTable(&Profile{}, &Member{}) + + if err := DB.AutoMigrate(&Profile{}, &Member{}); err != nil { + t.Fatalf("Failed to migrate, got error: %v", err) + } + + member := Member{Refer: 1, Name: "foreign_key_constraints", Profile: Profile{Name: "my_profile"}} + + DB.Create(&member) + + var profile Profile + if err := DB.First(&profile, "id = ?", member.Profile.ID).Error; err != nil { + t.Fatalf("failed to find profile, got error: %v", err) + } else if profile.MemberID != member.ID { + t.Fatalf("member id is not equal: expects: %v, got: %v", member.ID, profile.MemberID) + } + + member.Profile = Profile{} + DB.Model(&member).Update("Refer", 100) + + var profile2 Profile + if err := DB.First(&profile2, "id = ?", profile.ID).Error; err != nil { + t.Fatalf("failed to find profile, got error: %v", err) + } else if profile2.MemberID != 100 { + t.Fatalf("member id is not equal: expects: %v, got: %v", 100, profile2.MemberID) + } + + if r := DB.Delete(&member); r.Error != nil || r.RowsAffected != 1 { + t.Fatalf("Should delete member, got error: %v, affected: %v", r.Error, r.RowsAffected) + } + + var result Member + if err := DB.First(&result, member.ID).Error; err == nil { + t.Fatalf("Should not find deleted member") + } + + if err := DB.First(&profile2, profile.ID).Error; err == nil { + t.Fatalf("Should not find deleted profile") + } +} + +func TestForeignKeyConstraintsBelongsTo(t *testing.T) { + type Profile struct { + ID uint + Name string + Refer uint `gorm:"uniqueIndex"` + } + + type Member struct 
{ + ID uint + Name string + ProfileID uint + Profile Profile `gorm:"Constraint:OnUpdate:CASCADE,OnDelete:CASCADE;FOREIGNKEY:ProfileID;References:Refer"` + } + + DB.Migrator().DropTable(&Profile{}, &Member{}) + + if err := DB.AutoMigrate(&Profile{}, &Member{}); err != nil { + t.Fatalf("Failed to migrate, got error: %v", err) + } + + member := Member{Name: "foreign_key_constraints_belongs_to", Profile: Profile{Name: "my_profile_belongs_to", Refer: 1}} + + DB.Create(&member) + + var profile Profile + if err := DB.First(&profile, "id = ?", member.Profile.ID).Error; err != nil { + t.Fatalf("failed to find profile, got error: %v", err) + } else if profile.Refer != member.ProfileID { + t.Fatalf("member id is not equal: expects: %v, got: %v", profile.Refer, member.ProfileID) + } + + DB.Model(&profile).Update("Refer", 100) + + var member2 Member + if err := DB.First(&member2, "id = ?", member.ID).Error; err != nil { + t.Fatalf("failed to find member, got error: %v", err) + } else if member2.ProfileID != 100 { + t.Fatalf("member id is not equal: expects: %v, got: %v", 100, member2.ProfileID) + } + + if r := DB.Delete(&profile); r.Error != nil || r.RowsAffected != 1 { + t.Fatalf("Should delete member, got error: %v, affected: %v", r.Error, r.RowsAffected) + } + + var result Member + if err := DB.First(&result, member.ID).Error; err == nil { + t.Fatalf("Should not find deleted member") + } + + if err := DB.First(&profile, profile.ID).Error; err == nil { + t.Fatalf("Should not find deleted profile") + } +} diff --git a/tests/benchmark_test.go b/tests/benchmark_test.go new file mode 100644 index 00000000..c6ce93a2 --- /dev/null +++ b/tests/benchmark_test.go @@ -0,0 +1,44 @@ +package tests_test + +import ( + "testing" + + . "gorm.io/gorm/utils/tests" +) + +func BenchmarkCreate(b *testing.B) { + var user = *GetUser("bench", Config{}) + + for x := 0; x < b.N; x++ { + user.ID = 0 + DB.Create(&user) + } +} + +func BenchmarkFind(b *testing.B) { + var user = *GetUser("find", Config{}) + DB.Create(&user) + + for x := 0; x < b.N; x++ { + DB.Find(&User{}, "id = ?", user.ID) + } +} + +func BenchmarkUpdate(b *testing.B) { + var user = *GetUser("find", Config{}) + DB.Create(&user) + + for x := 0; x < b.N; x++ { + DB.Model(&user).Updates(map[string]interface{}{"Age": x}) + } +} + +func BenchmarkDelete(b *testing.B) { + var user = *GetUser("find", Config{}) + + for x := 0; x < b.N; x++ { + user.ID = 0 + DB.Create(&user) + DB.Delete(&user) + } +} diff --git a/tests/callbacks_test.go b/tests/callbacks_test.go new file mode 100644 index 00000000..02765b8c --- /dev/null +++ b/tests/callbacks_test.go @@ -0,0 +1,170 @@ +package tests_test + +import ( + "fmt" + "reflect" + "runtime" + "strings" + "testing" + + "gorm.io/gorm" +) + +func assertCallbacks(v interface{}, fnames []string) (result bool, msg string) { + var ( + got []string + funcs = reflect.ValueOf(v).Elem().FieldByName("fns") + ) + + for i := 0; i < funcs.Len(); i++ { + got = append(got, getFuncName(funcs.Index(i))) + } + + return fmt.Sprint(got) == fmt.Sprint(fnames), fmt.Sprintf("expects %v, got %v", fnames, got) +} + +func getFuncName(fc interface{}) string { + reflectValue, ok := fc.(reflect.Value) + if !ok { + reflectValue = reflect.ValueOf(fc) + } + + fnames := strings.Split(runtime.FuncForPC(reflectValue.Pointer()).Name(), ".") + return fnames[len(fnames)-1] +} + +func c1(*gorm.DB) {} +func c2(*gorm.DB) {} +func c3(*gorm.DB) {} +func c4(*gorm.DB) {} +func c5(*gorm.DB) {} + +func TestCallbacks(t *testing.T) { + type callback struct { + name string + before 
string + after string + remove bool + replace bool + err string + match func(*gorm.DB) bool + h func(*gorm.DB) + } + + datas := []struct { + callbacks []callback + err string + results []string + }{ + { + callbacks: []callback{{h: c1}, {h: c2}, {h: c3}, {h: c4}, {h: c5}}, + results: []string{"c1", "c2", "c3", "c4", "c5"}, + }, + { + callbacks: []callback{{h: c1}, {h: c2}, {h: c3}, {h: c4}, {h: c5, before: "c4"}}, + results: []string{"c1", "c2", "c3", "c5", "c4"}, + }, + { + callbacks: []callback{{h: c1}, {h: c2}, {h: c3}, {h: c4, after: "c5"}, {h: c5}}, + results: []string{"c1", "c2", "c3", "c5", "c4"}, + }, + { + callbacks: []callback{{h: c1}, {h: c2}, {h: c3}, {h: c4, after: "c5"}, {h: c5, before: "c4"}}, + results: []string{"c1", "c2", "c3", "c5", "c4"}, + }, + { + callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3}, {h: c4}, {h: c5}}, + results: []string{"c1", "c5", "c2", "c3", "c4"}, + }, + { + callbacks: []callback{{h: c1, after: "c3"}, {h: c2, before: "c4", after: "c5"}, {h: c3, before: "c5"}, {h: c4}, {h: c5}}, + results: []string{"c3", "c1", "c5", "c2", "c4"}, + }, + { + callbacks: []callback{{h: c1, before: "c4", after: "c3"}, {h: c2, before: "c4", after: "c5"}, {h: c3, before: "c5"}, {h: c4}, {h: c5}}, + results: []string{"c3", "c1", "c5", "c2", "c4"}, + }, + { + callbacks: []callback{{h: c1, before: "c3", after: "c4"}, {h: c2, before: "c4", after: "c5"}, {h: c3, before: "c5"}, {h: c4}, {h: c5}}, + err: "conflicting", + }, + { + callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3}, {h: c4}, {h: c5}, {h: c2, remove: true}}, + results: []string{"c1", "c5", "c3", "c4"}, + }, + { + callbacks: []callback{{h: c1}, {name: "c", h: c2}, {h: c3}, {name: "c", h: c4, replace: true}}, + results: []string{"c1", "c4", "c3"}, + }, + { + callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3}, {h: c4}, {h: c5, before: "*"}}, + results: []string{"c5", "c1", "c2", "c3", "c4"}, + }, + { + callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3, before: "*"}, {h: c4}, {h: c5, before: "*"}}, + results: []string{"c3", "c5", "c1", "c2", "c4"}, + }, + { + callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3, before: "c4", after: "*"}, {h: c4, after: "*"}, {h: c5, before: "*"}}, + results: []string{"c5", "c1", "c2", "c3", "c4"}, + }, + } + + for idx, data := range datas { + db, err := gorm.Open(nil, nil) + callbacks := db.Callback() + + for _, c := range data.callbacks { + var v interface{} = callbacks.Create() + callMethod := func(s interface{}, name string, args ...interface{}) { + var argValues []reflect.Value + for _, arg := range args { + argValues = append(argValues, reflect.ValueOf(arg)) + } + + results := reflect.ValueOf(s).MethodByName(name).Call(argValues) + if len(results) > 0 { + v = results[0].Interface() + } + } + + if c.name == "" { + c.name = getFuncName(c.h) + } + + if c.before != "" { + callMethod(v, "Before", c.before) + } + + if c.after != "" { + callMethod(v, "After", c.after) + } + + if c.match != nil { + callMethod(v, "Match", c.match) + } + + if c.remove { + callMethod(v, "Remove", c.name) + } else if c.replace { + callMethod(v, "Replace", c.name, c.h) + } else { + callMethod(v, "Register", c.name, c.h) + } + + if e, ok := v.(error); !ok || e != nil { + err = e + } + } + + if len(data.err) > 0 && err == nil { + t.Errorf("callbacks tests #%v should got error %v, but not", idx+1, data.err) + } else if len(data.err) == 0 && err != nil { + t.Errorf("callbacks tests #%v should not got 
error, but got %v", idx+1, err) + } + + if ok, msg := assertCallbacks(callbacks.Create(), data.results); !ok { + t.Errorf("callbacks tests #%v failed, got %v", idx+1, msg) + } + } +} diff --git a/tests/count_test.go b/tests/count_test.go new file mode 100644 index 00000000..0fef82f7 --- /dev/null +++ b/tests/count_test.go @@ -0,0 +1,132 @@ +package tests_test + +import ( + "fmt" + "regexp" + "sort" + "strings" + "testing" + + "gorm.io/gorm" + . "gorm.io/gorm/utils/tests" +) + +func TestCount(t *testing.T) { + var ( + user1 = *GetUser("count-1", Config{}) + user2 = *GetUser("count-2", Config{}) + user3 = *GetUser("count-3", Config{}) + users []User + count, count1, count2 int64 + ) + + DB.Save(&user1).Save(&user2).Save(&user3) + + if err := DB.Where("name = ?", user1.Name).Or("name = ?", user3.Name).Find(&users).Count(&count).Error; err != nil { + t.Errorf(fmt.Sprintf("Count should work, but got err %v", err)) + } + + if count != int64(len(users)) { + t.Errorf("Count() method should get correct value, expect: %v, got %v", count, len(users)) + } + + if err := DB.Model(&User{}).Where("name = ?", user1.Name).Or("name = ?", user3.Name).Count(&count).Find(&users).Error; err != nil { + t.Errorf(fmt.Sprintf("Count should work, but got err %v", err)) + } + + if count != int64(len(users)) { + t.Errorf("Count() method should get correct value, expect: %v, got %v", count, len(users)) + } + + DB.Model(&User{}).Where("name = ?", user1.Name).Count(&count1).Or("name in ?", []string{user2.Name, user3.Name}).Count(&count2) + if count1 != 1 || count2 != 3 { + t.Errorf("multiple count in chain should works") + } + + tx := DB.Model(&User{}).Where("name = ?", user1.Name).Session(&gorm.Session{}) + tx.Count(&count1) + tx.Or("name in ?", []string{user2.Name, user3.Name}).Count(&count2) + if count1 != 1 || count2 != 3 { + t.Errorf("count after new session should works") + } + + var count3 int64 + if err := DB.Model(&User{}).Where("name in ?", []string{user2.Name, user2.Name, user3.Name}).Group("id").Count(&count3).Error; err != nil { + t.Errorf("Error happened when count with group, but got %v", err) + } + + if count3 != 2 { + t.Errorf("Should get correct count for count with group, but got %v", count3) + } + + dryDB := DB.Session(&gorm.Session{DryRun: true}) + result := dryDB.Table("users").Select("name").Count(&count) + if !regexp.MustCompile(`SELECT COUNT\(.name.\) FROM .*users.*`).MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build count with select, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Table("users").Distinct("name").Count(&count) + if !regexp.MustCompile(`SELECT COUNT\(DISTINCT\(.name.\)\) FROM .*users.*`).MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build count with select, but got %v", result.Statement.SQL.String()) + } + + var count4 int64 + if err := DB.Table("users").Joins("LEFT JOIN companies on companies.name = users.name").Where("users.name = ?", user1.Name).Count(&count4).Error; err != nil || count4 != 1 { + t.Errorf("count with join, got error: %v, count %v", err, count4) + } + + var count5 int64 + if err := DB.Table("users").Where("users.name = ?", user1.Name).Order("name").Count(&count5).Error; err != nil || count5 != 1 { + t.Errorf("count with join, got error: %v, count %v", err, count) + } + + var count6 int64 + if err := DB.Model(&User{}).Where("name in ?", []string{user1.Name, user2.Name, user3.Name}).Select( + "(CASE WHEN name=? THEN ? ELSE ? 
END) as name", "count-1", "main", "other", + ).Count(&count6).Find(&users).Error; err != nil || count6 != 3 { + t.Fatalf("Count should work, but got err %v", err) + } + + expects := []User{{Name: "main"}, {Name: "other"}, {Name: "other"}} + sort.SliceStable(users, func(i, j int) bool { + return strings.Compare(users[i].Name, users[j].Name) < 0 + }) + + AssertEqual(t, users, expects) + + var count7 int64 + if err := DB.Model(&User{}).Where("name in ?", []string{user1.Name, user2.Name, user3.Name}).Select( + "(CASE WHEN name=? THEN ? ELSE ? END) as name, age", "count-1", "main", "other", + ).Count(&count7).Find(&users).Error; err != nil || count7 != 3 { + t.Fatalf("Count should work, but got err %v", err) + } + + expects = []User{{Name: "main", Age: 18}, {Name: "other", Age: 18}, {Name: "other", Age: 18}} + sort.SliceStable(users, func(i, j int) bool { + return strings.Compare(users[i].Name, users[j].Name) < 0 + }) + + AssertEqual(t, users, expects) + + var count8 int64 + if err := DB.Model(&User{}).Where("name in ?", []string{user1.Name, user2.Name, user3.Name}).Select( + "(CASE WHEN age=18 THEN 1 ELSE 2 END) as age", "name", + ).Count(&count8).Find(&users).Error; err != nil || count8 != 3 { + t.Fatalf("Count should work, but got err %v", err) + } + + expects = []User{{Name: "count-1", Age: 1}, {Name: "count-2", Age: 1}, {Name: "count-3", Age: 1}} + sort.SliceStable(users, func(i, j int) bool { + return strings.Compare(users[i].Name, users[j].Name) < 0 + }) + + AssertEqual(t, users, expects) + + var count9 int64 + if err := DB.Debug().Scopes(func(tx *gorm.DB) *gorm.DB { + return tx.Table("users") + }).Where("name in ?", []string{user1.Name, user2.Name, user3.Name}).Count(&count9).Find(&users).Error; err != nil || count9 != 3 { + t.Fatalf("Count should work, but got err %v", err) + } +} diff --git a/tests/create_test.go b/tests/create_test.go new file mode 100644 index 00000000..bd968ea8 --- /dev/null +++ b/tests/create_test.go @@ -0,0 +1,519 @@ +package tests_test + +import ( + "errors" + "regexp" + "testing" + "time" + + "github.com/jinzhu/now" + "gorm.io/gorm" + "gorm.io/gorm/clause" + . 
"gorm.io/gorm/utils/tests" +) + +func TestCreate(t *testing.T) { + var user = *GetUser("create", Config{}) + + if results := DB.Create(&user); results.Error != nil { + t.Fatalf("errors happened when create: %v", results.Error) + } else if results.RowsAffected != 1 { + t.Fatalf("rows affected expects: %v, got %v", 1, results.RowsAffected) + } + + if user.ID == 0 { + t.Errorf("user's primary key should has value after create, got : %v", user.ID) + } + + if user.CreatedAt.IsZero() { + t.Errorf("user's created at should be not zero") + } + + if user.UpdatedAt.IsZero() { + t.Errorf("user's updated at should be not zero") + } + + var newUser User + if err := DB.Where("id = ?", user.ID).First(&newUser).Error; err != nil { + t.Fatalf("errors happened when query: %v", err) + } else { + CheckUser(t, newUser, user) + } +} + +func TestCreateInBatches(t *testing.T) { + users := []User{ + *GetUser("create_in_batches_1", Config{Account: true, Pets: 2, Toys: 3, Company: true, Manager: true, Team: 0, Languages: 1, Friends: 1}), + *GetUser("create_in_batches_2", Config{Account: false, Pets: 2, Toys: 4, Company: false, Manager: false, Team: 1, Languages: 3, Friends: 5}), + *GetUser("create_in_batches_3", Config{Account: true, Pets: 0, Toys: 3, Company: true, Manager: false, Team: 4, Languages: 0, Friends: 1}), + *GetUser("create_in_batches_4", Config{Account: true, Pets: 3, Toys: 0, Company: false, Manager: true, Team: 0, Languages: 3, Friends: 0}), + *GetUser("create_in_batches_5", Config{Account: false, Pets: 0, Toys: 3, Company: true, Manager: false, Team: 1, Languages: 3, Friends: 1}), + *GetUser("create_in_batches_6", Config{Account: true, Pets: 4, Toys: 3, Company: false, Manager: true, Team: 1, Languages: 3, Friends: 0}), + } + + result := DB.CreateInBatches(&users, 2) + if result.RowsAffected != int64(len(users)) { + t.Errorf("affected rows should be %v, but got %v", len(users), result.RowsAffected) + } + + for _, user := range users { + if user.ID == 0 { + t.Fatalf("failed to fill user's ID, got %v", user.ID) + } else { + var newUser User + if err := DB.Where("id = ?", user.ID).Preload(clause.Associations).First(&newUser).Error; err != nil { + t.Fatalf("errors happened when query: %v", err) + } else { + CheckUser(t, newUser, user) + } + } + } +} + +func TestCreateInBatchesWithDefaultSize(t *testing.T) { + users := []User{ + *GetUser("create_with_default_batch_size_1", Config{Account: true, Pets: 2, Toys: 3, Company: true, Manager: true, Team: 0, Languages: 1, Friends: 1}), + *GetUser("create_with_default_batch_sizs_2", Config{Account: false, Pets: 2, Toys: 4, Company: false, Manager: false, Team: 1, Languages: 3, Friends: 5}), + *GetUser("create_with_default_batch_sizs_3", Config{Account: true, Pets: 0, Toys: 3, Company: true, Manager: false, Team: 4, Languages: 0, Friends: 1}), + *GetUser("create_with_default_batch_sizs_4", Config{Account: true, Pets: 3, Toys: 0, Company: false, Manager: true, Team: 0, Languages: 3, Friends: 0}), + *GetUser("create_with_default_batch_sizs_5", Config{Account: false, Pets: 0, Toys: 3, Company: true, Manager: false, Team: 1, Languages: 3, Friends: 1}), + *GetUser("create_with_default_batch_sizs_6", Config{Account: true, Pets: 4, Toys: 3, Company: false, Manager: true, Team: 1, Languages: 3, Friends: 0}), + } + + result := DB.Session(&gorm.Session{CreateBatchSize: 2}).Create(&users) + if result.RowsAffected != int64(len(users)) { + t.Errorf("affected rows should be %v, but got %v", len(users), result.RowsAffected) + } + + for _, user := range users { + if user.ID == 0 
{ + t.Fatalf("failed to fill user's ID, got %v", user.ID) + } else { + var newUser User + if err := DB.Where("id = ?", user.ID).Preload(clause.Associations).First(&newUser).Error; err != nil { + t.Fatalf("errors happened when query: %v", err) + } else { + CheckUser(t, newUser, user) + } + } + } +} + +func TestCreateFromMap(t *testing.T) { + if err := DB.Model(&User{}).Create(map[string]interface{}{"Name": "create_from_map", "Age": 18}).Error; err != nil { + t.Fatalf("failed to create data from map, got error: %v", err) + } + + var result User + if err := DB.Where("name = ?", "create_from_map").First(&result).Error; err != nil || result.Age != 18 { + t.Fatalf("failed to create from map, got error %v", err) + } + + if err := DB.Model(&User{}).Create(map[string]interface{}{"name": "create_from_map_1", "age": 18}).Error; err != nil { + t.Fatalf("failed to create data from map, got error: %v", err) + } + + var result1 User + if err := DB.Where("name = ?", "create_from_map_1").First(&result1).Error; err != nil || result1.Age != 18 { + t.Fatalf("failed to create from map, got error %v", err) + } + + datas := []map[string]interface{}{ + {"Name": "create_from_map_2", "Age": 19}, + {"name": "create_from_map_3", "Age": 20}, + } + + if err := DB.Model(&User{}).Create(datas).Error; err != nil { + t.Fatalf("failed to create data from slice of map, got error: %v", err) + } + + var result2 User + if err := DB.Where("name = ?", "create_from_map_2").First(&result2).Error; err != nil || result2.Age != 19 { + t.Fatalf("failed to query data after create from slice of map, got error %v", err) + } + + var result3 User + if err := DB.Where("name = ?", "create_from_map_3").First(&result3).Error; err != nil || result3.Age != 20 { + t.Fatalf("failed to query data after create from slice of map, got error %v", err) + } +} + +func TestCreateWithAssociations(t *testing.T) { + var user = *GetUser("create_with_associations", Config{ + Account: true, + Pets: 2, + Toys: 3, + Company: true, + Manager: true, + Team: 4, + Languages: 3, + Friends: 1, + }) + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + CheckUser(t, user, user) + + var user2 User + DB.Preload("Account").Preload("Pets").Preload("Toys").Preload("Company").Preload("Manager").Preload("Team").Preload("Languages").Preload("Friends").Find(&user2, "id = ?", user.ID) + CheckUser(t, user2, user) +} + +func TestBulkCreateWithAssociations(t *testing.T) { + users := []User{ + *GetUser("bulk_1", Config{Account: true, Pets: 2, Toys: 3, Company: true, Manager: true, Team: 0, Languages: 1, Friends: 1}), + *GetUser("bulk_2", Config{Account: false, Pets: 2, Toys: 4, Company: false, Manager: false, Team: 1, Languages: 3, Friends: 5}), + *GetUser("bulk_3", Config{Account: true, Pets: 0, Toys: 3, Company: true, Manager: false, Team: 4, Languages: 0, Friends: 1}), + *GetUser("bulk_4", Config{Account: true, Pets: 3, Toys: 0, Company: false, Manager: true, Team: 0, Languages: 3, Friends: 0}), + *GetUser("bulk_5", Config{Account: false, Pets: 0, Toys: 3, Company: true, Manager: false, Team: 1, Languages: 3, Friends: 1}), + *GetUser("bulk_6", Config{Account: true, Pets: 4, Toys: 3, Company: false, Manager: true, Team: 1, Languages: 3, Friends: 0}), + *GetUser("bulk_7", Config{Account: true, Pets: 1, Toys: 3, Company: true, Manager: true, Team: 4, Languages: 3, Friends: 1}), + *GetUser("bulk_8", Config{Account: false, Pets: 0, Toys: 0, Company: false, Manager: false, Team: 0, Languages: 0, Friends: 0}), + } + + if results := 
DB.Create(&users); results.Error != nil { + t.Fatalf("errors happened when create: %v", results.Error) + } else if results.RowsAffected != int64(len(users)) { + t.Fatalf("rows affected expects: %v, got %v", len(users), results.RowsAffected) + } + + var userIDs []uint + for _, user := range users { + userIDs = append(userIDs, user.ID) + CheckUser(t, user, user) + } + + var users2 []User + DB.Preload("Account").Preload("Pets").Preload("Toys").Preload("Company").Preload("Manager").Preload("Team").Preload("Languages").Preload("Friends").Find(&users2, "id IN ?", userIDs) + for idx, user := range users2 { + CheckUser(t, user, users[idx]) + } +} + +func TestBulkCreatePtrDataWithAssociations(t *testing.T) { + users := []*User{ + GetUser("bulk_ptr_1", Config{Account: true, Pets: 2, Toys: 3, Company: true, Manager: true, Team: 0, Languages: 1, Friends: 1}), + GetUser("bulk_ptr_2", Config{Account: false, Pets: 2, Toys: 4, Company: false, Manager: false, Team: 1, Languages: 3, Friends: 5}), + GetUser("bulk_ptr_3", Config{Account: true, Pets: 0, Toys: 3, Company: true, Manager: false, Team: 4, Languages: 0, Friends: 1}), + GetUser("bulk_ptr_4", Config{Account: true, Pets: 3, Toys: 0, Company: false, Manager: true, Team: 0, Languages: 3, Friends: 0}), + GetUser("bulk_ptr_5", Config{Account: false, Pets: 0, Toys: 3, Company: true, Manager: false, Team: 1, Languages: 3, Friends: 1}), + GetUser("bulk_ptr_6", Config{Account: true, Pets: 4, Toys: 3, Company: false, Manager: true, Team: 1, Languages: 3, Friends: 0}), + GetUser("bulk_ptr_7", Config{Account: true, Pets: 1, Toys: 3, Company: true, Manager: true, Team: 4, Languages: 3, Friends: 1}), + GetUser("bulk_ptr_8", Config{Account: false, Pets: 0, Toys: 0, Company: false, Manager: false, Team: 0, Languages: 0, Friends: 0}), + } + + if err := DB.Create(&users).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + var userIDs []uint + for _, user := range users { + userIDs = append(userIDs, user.ID) + CheckUser(t, *user, *user) + } + + var users2 []User + DB.Preload("Account").Preload("Pets").Preload("Toys").Preload("Company").Preload("Manager").Preload("Team").Preload("Languages").Preload("Friends").Find(&users2, "id IN ?", userIDs) + for idx, user := range users2 { + CheckUser(t, user, *users[idx]) + } +} + +func TestPolymorphicHasOne(t *testing.T) { + t.Run("Struct", func(t *testing.T) { + var pet = Pet{ + Name: "PolymorphicHasOne", + Toy: Toy{Name: "Toy-PolymorphicHasOne"}, + } + + if err := DB.Create(&pet).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + CheckPet(t, pet, pet) + + var pet2 Pet + DB.Preload("Toy").Find(&pet2, "id = ?", pet.ID) + CheckPet(t, pet2, pet) + }) + + t.Run("Slice", func(t *testing.T) { + var pets = []Pet{{ + Name: "PolymorphicHasOne-Slice-1", + Toy: Toy{Name: "Toy-PolymorphicHasOne-Slice-1"}, + }, { + Name: "PolymorphicHasOne-Slice-2", + Toy: Toy{Name: "Toy-PolymorphicHasOne-Slice-2"}, + }, { + Name: "PolymorphicHasOne-Slice-3", + Toy: Toy{Name: "Toy-PolymorphicHasOne-Slice-3"}, + }} + + if err := DB.Create(&pets).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + var petIDs []uint + for _, pet := range pets { + petIDs = append(petIDs, pet.ID) + CheckPet(t, pet, pet) + } + + var pets2 []Pet + DB.Preload("Toy").Find(&pets2, "id IN ?", petIDs) + for idx, pet := range pets2 { + CheckPet(t, pet, pets[idx]) + } + }) + + t.Run("SliceOfPtr", func(t *testing.T) { + var pets = []*Pet{{ + Name: "PolymorphicHasOne-Slice-1", + Toy: Toy{Name: 
"Toy-PolymorphicHasOne-Slice-1"}, + }, { + Name: "PolymorphicHasOne-Slice-2", + Toy: Toy{Name: "Toy-PolymorphicHasOne-Slice-2"}, + }, { + Name: "PolymorphicHasOne-Slice-3", + Toy: Toy{Name: "Toy-PolymorphicHasOne-Slice-3"}, + }} + + if err := DB.Create(&pets).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + for _, pet := range pets { + CheckPet(t, *pet, *pet) + } + }) + + t.Run("Array", func(t *testing.T) { + var pets = [...]Pet{{ + Name: "PolymorphicHasOne-Array-1", + Toy: Toy{Name: "Toy-PolymorphicHasOne-Array-1"}, + }, { + Name: "PolymorphicHasOne-Array-2", + Toy: Toy{Name: "Toy-PolymorphicHasOne-Array-2"}, + }, { + Name: "PolymorphicHasOne-Array-3", + Toy: Toy{Name: "Toy-PolymorphicHasOne-Array-3"}, + }} + + if err := DB.Create(&pets).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + for _, pet := range pets { + CheckPet(t, pet, pet) + } + }) + + t.Run("ArrayPtr", func(t *testing.T) { + var pets = [...]*Pet{{ + Name: "PolymorphicHasOne-Array-1", + Toy: Toy{Name: "Toy-PolymorphicHasOne-Array-1"}, + }, { + Name: "PolymorphicHasOne-Array-2", + Toy: Toy{Name: "Toy-PolymorphicHasOne-Array-2"}, + }, { + Name: "PolymorphicHasOne-Array-3", + Toy: Toy{Name: "Toy-PolymorphicHasOne-Array-3"}, + }} + + if err := DB.Create(&pets).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + for _, pet := range pets { + CheckPet(t, *pet, *pet) + } + }) +} + +func TestCreateEmptyStruct(t *testing.T) { + type EmptyStruct struct { + ID uint + } + DB.Migrator().DropTable(&EmptyStruct{}) + + if err := DB.AutoMigrate(&EmptyStruct{}); err != nil { + t.Errorf("no error should happen when auto migrate, but got %v", err) + } + + if err := DB.Create(&EmptyStruct{}).Error; err != nil { + t.Errorf("No error should happen when creating user, but got %v", err) + } +} + +func TestCreateEmptySlice(t *testing.T) { + var data = []User{} + if err := DB.Create(&data).Error; err != gorm.ErrEmptySlice { + t.Errorf("no data should be created, got %v", err) + } + + var sliceMap = []map[string]interface{}{} + if err := DB.Model(&User{}).Create(&sliceMap).Error; err != gorm.ErrEmptySlice { + t.Errorf("no data should be created, got %v", err) + } +} + +func TestCreateInvalidSlice(t *testing.T) { + users := []*User{ + GetUser("invalid_slice_1", Config{}), + GetUser("invalid_slice_2", Config{}), + nil, + } + + if err := DB.Create(&users).Error; !errors.Is(err, gorm.ErrInvalidData) { + t.Errorf("should returns error invalid data when creating from slice that contains invalid data") + } +} + +func TestCreateWithExistingTimestamp(t *testing.T) { + user := User{Name: "CreateUserExistingTimestamp"} + curTime := now.MustParse("2016-01-01") + user.CreatedAt = curTime + user.UpdatedAt = curTime + DB.Save(&user) + + AssertEqual(t, user.CreatedAt, curTime) + AssertEqual(t, user.UpdatedAt, curTime) + + var newUser User + DB.First(&newUser, user.ID) + + AssertEqual(t, newUser.CreatedAt, curTime) + AssertEqual(t, newUser.UpdatedAt, curTime) +} + +func TestCreateWithNowFuncOverride(t *testing.T) { + user := User{Name: "CreateUserTimestampOverride"} + curTime := now.MustParse("2016-01-01") + + NEW := DB.Session(&gorm.Session{ + NowFunc: func() time.Time { + return curTime + }, + }) + + NEW.Save(&user) + + AssertEqual(t, user.CreatedAt, curTime) + AssertEqual(t, user.UpdatedAt, curTime) + + var newUser User + NEW.First(&newUser, user.ID) + + AssertEqual(t, newUser.CreatedAt, curTime) + AssertEqual(t, newUser.UpdatedAt, curTime) +} + +func 
TestCreateWithNoGORMPrimaryKey(t *testing.T) { + type JoinTable struct { + UserID uint + FriendID uint + } + + DB.Migrator().DropTable(&JoinTable{}) + if err := DB.AutoMigrate(&JoinTable{}); err != nil { + t.Errorf("no error should happen when auto migrate, but got %v", err) + } + + jt := JoinTable{UserID: 1, FriendID: 2} + err := DB.Create(&jt).Error + if err != nil { + t.Errorf("No error should happen when create a record without a GORM primary key. But in the database this primary key exists and is the union of 2 or more fields\n But got: %s", err) + } +} + +func TestSelectWithCreate(t *testing.T) { + user := *GetUser("select_create", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4}) + DB.Select("Account", "Toys", "Manager", "ManagerID", "Languages", "Name", "CreatedAt", "Age", "Active").Create(&user) + + var user2 User + DB.Preload("Account").Preload("Pets").Preload("Toys").Preload("Company").Preload("Manager").Preload("Team").Preload("Languages").Preload("Friends").First(&user2, user.ID) + + user.Birthday = nil + user.Pets = nil + user.Company = Company{} + user.Team = nil + user.Friends = nil + + CheckUser(t, user2, user) +} + +func TestOmitWithCreate(t *testing.T) { + user := *GetUser("omit_create", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4}) + DB.Omit("Account", "Toys", "Manager", "Birthday").Create(&user) + + var result User + DB.Preload("Account").Preload("Pets").Preload("Toys").Preload("Company").Preload("Manager").Preload("Team").Preload("Languages").Preload("Friends").First(&result, user.ID) + + user.Birthday = nil + user.Account = Account{} + user.Toys = nil + user.Manager = nil + + CheckUser(t, result, user) + + user2 := *GetUser("omit_create", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4}) + DB.Omit(clause.Associations).Create(&user2) + + var result2 User + DB.Preload(clause.Associations).First(&result2, user2.ID) + + user2.Account = Account{} + user2.Toys = nil + user2.Manager = nil + user2.Company = Company{} + user2.Pets = nil + user2.Team = nil + user2.Languages = nil + user2.Friends = nil + + CheckUser(t, result2, user2) +} + +func TestFirstOrCreateWithPrimaryKey(t *testing.T) { + company := Company{ID: 100, Name: "company100_with_primarykey"} + DB.FirstOrCreate(&company) + + if company.ID != 100 { + t.Errorf("invalid primary key after creating, got %v", company.ID) + } + + companies := []Company{ + {ID: 101, Name: "company101_with_primarykey"}, + {ID: 102, Name: "company102_with_primarykey"}, + } + DB.Create(&companies) + + if companies[0].ID != 101 || companies[1].ID != 102 { + t.Errorf("invalid primary key after creating, got %v, %v", companies[0].ID, companies[1].ID) + } +} + +func TestCreateFromSubQuery(t *testing.T) { + user := User{Name: "jinzhu"} + + DB.Create(&user) + + subQuery := DB.Table("users").Where("name=?", user.Name).Select("id") + + result := DB.Session(&gorm.Session{DryRun: true}).Model(&Pet{}).Create([]map[string]interface{}{ + { + "name": "cat", + "user_id": gorm.Expr("(?)", DB.Table("(?) as tmp", subQuery).Select("@uid:=id")), + }, + { + "name": "dog", + "user_id": gorm.Expr("@uid"), + }, + }) + + if !regexp.MustCompile(`INSERT INTO .pets. \(.name.,.user_id.\) .*VALUES \(.+,\(SELECT @uid:=id FROM \(SELECT id FROM .users. 
WHERE name=.+\) as tmp\)\),\(.+,@uid\)`).MatchString(result.Statement.SQL.String()) { + t.Errorf("invalid insert SQL, got %v", result.Statement.SQL.String()) + } +} diff --git a/tests/customize_field_test.go b/tests/customize_field_test.go new file mode 100644 index 00000000..7802eb11 --- /dev/null +++ b/tests/customize_field_test.go @@ -0,0 +1,192 @@ +package tests_test + +import ( + "testing" + "time" + + "gorm.io/gorm" + . "gorm.io/gorm/utils/tests" +) + +func TestCustomizeColumn(t *testing.T) { + type CustomizeColumn struct { + ID int64 `gorm:"column:mapped_id; primary_key:yes"` + Name string `gorm:"column:mapped_name"` + Date *time.Time `gorm:"column:mapped_time"` + } + + DB.Migrator().DropTable(&CustomizeColumn{}) + DB.AutoMigrate(&CustomizeColumn{}) + + expected := "foo" + now := time.Now() + cc := CustomizeColumn{ID: 666, Name: expected, Date: &now} + + if count := DB.Create(&cc).RowsAffected; count != 1 { + t.Error("There should be one record be affected when create record") + } + + var cc1 CustomizeColumn + DB.First(&cc1, "mapped_name = ?", "foo") + + if cc1.Name != expected { + t.Errorf("Failed to query CustomizeColumn") + } + + cc.Name = "bar" + DB.Save(&cc) + + var cc2 CustomizeColumn + DB.First(&cc2, "mapped_id = ?", 666) + if cc2.Name != "bar" { + t.Errorf("Failed to query CustomizeColumn") + } +} + +func TestCustomColumnAndIgnoredFieldClash(t *testing.T) { + // Make sure an ignored field does not interfere with another field's custom + // column name that matches the ignored field. + type CustomColumnAndIgnoredFieldClash struct { + Body string `gorm:"-"` + RawBody string `gorm:"column:body"` + } + + DB.Migrator().DropTable(&CustomColumnAndIgnoredFieldClash{}) + + if err := DB.AutoMigrate(&CustomColumnAndIgnoredFieldClash{}); err != nil { + t.Errorf("Should not raise error: %v", err) + } +} + +func TestCustomizeField(t *testing.T) { + type CustomizeFieldStruct struct { + gorm.Model + Name string + FieldAllowCreate string `gorm:"<-:create"` + FieldAllowUpdate string `gorm:"<-:update"` + FieldAllowSave string `gorm:"<-"` + FieldAllowSave2 string `gorm:"<-:create,update"` + FieldAllowSave3 string `gorm:"->:false;<-:create"` + FieldReadonly string `gorm:"->"` + FieldIgnore string `gorm:"-"` + AutoUnixCreateTime int32 `gorm:"autocreatetime"` + AutoUnixMilliCreateTime int `gorm:"autocreatetime:milli"` + AutoUnixNanoCreateTime int64 `gorm:"autocreatetime:nano"` + AutoUnixUpdateTime uint32 `gorm:"autoupdatetime"` + AutoUnixMilliUpdateTime int `gorm:"autoupdatetime:milli"` + AutoUnixNanoUpdateTime uint64 `gorm:"autoupdatetime:nano"` + } + + DB.Migrator().DropTable(&CustomizeFieldStruct{}) + + if err := DB.AutoMigrate(&CustomizeFieldStruct{}); err != nil { + t.Errorf("Failed to migrate, got error: %v", err) + } + + if DB.Migrator().HasColumn(&CustomizeFieldStruct{}, "FieldIgnore") { + t.Errorf("FieldIgnore should not be created") + } + + if DB.Migrator().HasColumn(&CustomizeFieldStruct{}, "field_ignore") { + t.Errorf("FieldIgnore should not be created") + } + + generateStruct := func(name string) *CustomizeFieldStruct { + return &CustomizeFieldStruct{ + Name: name, + FieldAllowCreate: name + "_allow_create", + FieldAllowUpdate: name + "_allow_update", + FieldAllowSave: name + "_allow_save", + FieldAllowSave2: name + "_allow_save2", + FieldAllowSave3: name + "_allow_save3", + FieldReadonly: name + "_allow_readonly", + FieldIgnore: name + "_allow_ignore", + } + } + + create := generateStruct("create") + DB.Create(&create) + + var result CustomizeFieldStruct + DB.Find(&result, "name = 
?", "create") + + AssertObjEqual(t, result, create, "Name", "FieldAllowCreate", "FieldAllowSave", "FieldAllowSave2") + + if result.FieldAllowUpdate != "" || result.FieldReadonly != "" || result.FieldIgnore != "" || result.FieldAllowSave3 != "" { + t.Fatalf("invalid result: %#v", result) + } + + if int(result.AutoUnixCreateTime) != int(result.AutoUnixUpdateTime) || result.AutoUnixCreateTime == 0 { + t.Fatalf("invalid create/update unix time: %#v", result) + } + + if int(result.AutoUnixMilliCreateTime) != int(result.AutoUnixMilliUpdateTime) || result.AutoUnixMilliCreateTime == 0 || int(result.AutoUnixMilliCreateTime)/int(result.AutoUnixCreateTime) < 1e3 { + t.Fatalf("invalid create/update unix milli time: %#v", result) + } + + if int(result.AutoUnixNanoCreateTime) != int(result.AutoUnixNanoUpdateTime) || result.AutoUnixNanoCreateTime == 0 || int(result.AutoUnixNanoCreateTime)/int(result.AutoUnixCreateTime) < 1e6 { + t.Fatalf("invalid create/update unix nano time: %#v", result) + } + + result.FieldAllowUpdate = "field_allow_update_updated" + result.FieldReadonly = "field_readonly_updated" + result.FieldIgnore = "field_ignore_updated" + DB.Save(&result) + + var result2 CustomizeFieldStruct + DB.Find(&result2, "name = ?", "create") + + if result2.FieldAllowUpdate != result.FieldAllowUpdate || result2.FieldReadonly != "" || result2.FieldIgnore != "" { + t.Fatalf("invalid updated result: %#v", result2) + } + + if err := DB.Where(CustomizeFieldStruct{Name: create.Name, FieldReadonly: create.FieldReadonly, FieldIgnore: create.FieldIgnore}).First(&CustomizeFieldStruct{}).Error; err == nil { + t.Fatalf("Should failed to find result") + } + + if err := DB.Table("customize_field_structs").Where("1 = 1").UpdateColumn("field_readonly", "readonly").Error; err != nil { + t.Fatalf("failed to update field_readonly column") + } + + if err := DB.Where(CustomizeFieldStruct{Name: create.Name, FieldReadonly: "readonly", FieldIgnore: create.FieldIgnore}).First(&CustomizeFieldStruct{}).Error; err != nil { + t.Fatalf("Should find result") + } + + var result3 CustomizeFieldStruct + DB.Find(&result3, "name = ?", "create") + + if result3.FieldReadonly != "readonly" { + t.Fatalf("invalid updated result: %#v", result3) + } + + var result4 CustomizeFieldStruct + if err := DB.First(&result4, "field_allow_save3 = ?", create.FieldAllowSave3).Error; err != nil { + t.Fatalf("failed to query with inserted field, got error %v", err) + } + + AssertEqual(t, result3, result4) + + createWithDefaultTime := generateStruct("create_with_default_time") + createWithDefaultTime.AutoUnixCreateTime = 100 + createWithDefaultTime.AutoUnixUpdateTime = 100 + createWithDefaultTime.AutoUnixMilliCreateTime = 100 + createWithDefaultTime.AutoUnixMilliUpdateTime = 100 + createWithDefaultTime.AutoUnixNanoCreateTime = 100 + createWithDefaultTime.AutoUnixNanoUpdateTime = 100 + DB.Create(&createWithDefaultTime) + + var createWithDefaultTimeResult CustomizeFieldStruct + DB.Find(&createWithDefaultTimeResult, "name = ?", createWithDefaultTime.Name) + + if int(createWithDefaultTimeResult.AutoUnixCreateTime) != int(createWithDefaultTimeResult.AutoUnixUpdateTime) || createWithDefaultTimeResult.AutoUnixCreateTime != 100 { + t.Fatalf("invalid create/update unix time: %#v", createWithDefaultTimeResult) + } + + if int(createWithDefaultTimeResult.AutoUnixMilliCreateTime) != int(createWithDefaultTimeResult.AutoUnixMilliUpdateTime) || createWithDefaultTimeResult.AutoUnixMilliCreateTime != 100 { + t.Fatalf("invalid create/update unix milli time: %#v", 
createWithDefaultTimeResult) + } + + if int(createWithDefaultTimeResult.AutoUnixNanoCreateTime) != int(createWithDefaultTimeResult.AutoUnixNanoUpdateTime) || createWithDefaultTimeResult.AutoUnixNanoCreateTime != 100 { + t.Fatalf("invalid create/update unix nano time: %#v", createWithDefaultTimeResult) + } +} diff --git a/tests/default_value_test.go b/tests/default_value_test.go new file mode 100644 index 00000000..14a0a977 --- /dev/null +++ b/tests/default_value_test.go @@ -0,0 +1,39 @@ +package tests_test + +import ( + "testing" + + "gorm.io/gorm" +) + +func TestDefaultValue(t *testing.T) { + type Harumph struct { + gorm.Model + Email string `gorm:"not null;index:,unique"` + Name string `gorm:"notNull;default:foo"` + Name2 string `gorm:"size:233;not null;default:'foo'"` + Name3 string `gorm:"size:233;notNull;default:''"` + Age int `gorm:"default:18"` + Enabled bool `gorm:"default:true"` + } + + DB.Migrator().DropTable(&Harumph{}) + + if err := DB.AutoMigrate(&Harumph{}); err != nil { + t.Fatalf("Failed to migrate with default value, got error: %v", err) + } + + var harumph = Harumph{Email: "hello@gorm.io"} + if err := DB.Create(&harumph).Error; err != nil { + t.Fatalf("Failed to create data with default value, got error: %v", err) + } else if harumph.Name != "foo" || harumph.Name2 != "foo" || harumph.Name3 != "" || harumph.Age != 18 || !harumph.Enabled { + t.Fatalf("Failed to create data with default value, got: %+v", harumph) + } + + var result Harumph + if err := DB.First(&result, "email = ?", "hello@gorm.io").Error; err != nil { + t.Fatalf("Failed to find created data, got error: %v", err) + } else if result.Name != "foo" || result.Name2 != "foo" || result.Name3 != "" || result.Age != 18 || !result.Enabled { + t.Fatalf("Failed to find created data with default data, got %+v", result) + } +} diff --git a/tests/delete_test.go b/tests/delete_test.go new file mode 100644 index 00000000..abe85b0e --- /dev/null +++ b/tests/delete_test.go @@ -0,0 +1,207 @@ +package tests_test + +import ( + "errors" + "testing" + + "gorm.io/gorm" + "gorm.io/gorm/clause" + . 
"gorm.io/gorm/utils/tests" +) + +func TestDelete(t *testing.T) { + var users = []User{*GetUser("delete", Config{}), *GetUser("delete", Config{}), *GetUser("delete", Config{})} + + if err := DB.Create(&users).Error; err != nil { + t.Errorf("errors happened when create: %v", err) + } + + for _, user := range users { + if user.ID == 0 { + t.Fatalf("user's primary key should has value after create, got : %v", user.ID) + } + } + + if err := DB.Delete(&users[1]).Error; err != nil { + t.Errorf("errors happened when delete: %v", err) + } + + var result User + if err := DB.Where("id = ?", users[1].ID).First(&result).Error; err == nil || !errors.Is(err, gorm.ErrRecordNotFound) { + t.Errorf("should returns record not found error, but got %v", err) + } + + for _, user := range []User{users[0], users[2]} { + result = User{} + if err := DB.Where("id = ?", user.ID).First(&result).Error; err != nil { + t.Errorf("no error should returns when query %v, but got %v", user.ID, err) + } + } + + for _, user := range []User{users[0], users[2]} { + result = User{} + if err := DB.Where("id = ?", user.ID).First(&result).Error; err != nil { + t.Errorf("no error should returns when query %v, but got %v", user.ID, err) + } + } + + if err := DB.Delete(&users[0]).Error; err != nil { + t.Errorf("errors happened when delete: %v", err) + } + + if err := DB.Delete(&User{}).Error; err != gorm.ErrMissingWhereClause { + t.Errorf("errors happened when delete: %v", err) + } + + if err := DB.Where("id = ?", users[0].ID).First(&result).Error; err == nil || !errors.Is(err, gorm.ErrRecordNotFound) { + t.Errorf("should returns record not found error, but got %v", err) + } +} + +func TestDeleteWithTable(t *testing.T) { + type UserWithDelete struct { + gorm.Model + Name string + } + + DB.Table("deleted_users").Migrator().DropTable(UserWithDelete{}) + DB.Table("deleted_users").AutoMigrate(UserWithDelete{}) + + user := UserWithDelete{Name: "delete1"} + DB.Table("deleted_users").Create(&user) + + var result UserWithDelete + if err := DB.Table("deleted_users").First(&result).Error; err != nil { + t.Errorf("failed to find deleted user, got error %v", err) + } + + AssertEqual(t, result, user) + + if err := DB.Table("deleted_users").Delete(&result).Error; err != nil { + t.Errorf("failed to delete user, got error %v", err) + } + + var result2 UserWithDelete + if err := DB.Table("deleted_users").First(&result2, user.ID).Error; !errors.Is(err, gorm.ErrRecordNotFound) { + t.Errorf("should raise record not found error, but got error %v", err) + } + + var result3 UserWithDelete + if err := DB.Table("deleted_users").Unscoped().First(&result3, user.ID).Error; err != nil { + t.Fatalf("failed to find record, got error %v", err) + } + + if err := DB.Table("deleted_users").Unscoped().Delete(&result).Error; err != nil { + t.Errorf("failed to delete user with unscoped, got error %v", err) + } + + var result4 UserWithDelete + if err := DB.Table("deleted_users").Unscoped().First(&result4, user.ID).Error; !errors.Is(err, gorm.ErrRecordNotFound) { + t.Errorf("should raise record not found error, but got error %v", err) + } +} + +func TestInlineCondDelete(t *testing.T) { + user1 := *GetUser("inline_delete_1", Config{}) + user2 := *GetUser("inline_delete_2", Config{}) + DB.Save(&user1).Save(&user2) + + if DB.Delete(&User{}, user1.ID).Error != nil { + t.Errorf("No error should happen when delete a record") + } else if err := DB.Where("name = ?", user1.Name).First(&User{}).Error; !errors.Is(err, gorm.ErrRecordNotFound) { + t.Errorf("User can't be found after 
delete") + } + + if err := DB.Delete(&User{}, "name = ?", user2.Name).Error; err != nil { + t.Errorf("No error should happen when delete a record, err=%s", err) + } else if err := DB.Where("name = ?", user2.Name).First(&User{}).Error; !errors.Is(err, gorm.ErrRecordNotFound) { + t.Errorf("User can't be found after delete") + } +} + +func TestBlockGlobalDelete(t *testing.T) { + if err := DB.Delete(&User{}).Error; err == nil || !errors.Is(err, gorm.ErrMissingWhereClause) { + t.Errorf("should returns missing WHERE clause while deleting error") + } + + if err := DB.Session(&gorm.Session{AllowGlobalUpdate: true}).Delete(&User{}).Error; err != nil { + t.Errorf("should returns no error while enable global update, but got err %v", err) + } +} + +func TestDeleteWithAssociations(t *testing.T) { + user := GetUser("delete_with_associations", Config{Account: true, Pets: 2, Toys: 4, Company: true, Manager: true, Team: 1, Languages: 1, Friends: 1}) + + if err := DB.Create(user).Error; err != nil { + t.Fatalf("failed to create user, got error %v", err) + } + + if err := DB.Select(clause.Associations, "Pets.Toy").Delete(&user).Error; err != nil { + t.Fatalf("failed to delete user, got error %v", err) + } + + for key, value := range map[string]int64{"Account": 1, "Pets": 2, "Toys": 4, "Company": 1, "Manager": 1, "Team": 1, "Languages": 0, "Friends": 0} { + if count := DB.Unscoped().Model(&user).Association(key).Count(); count != value { + t.Errorf("user's %v expects: %v, got %v", key, value, count) + } + } + + for key, value := range map[string]int64{"Account": 0, "Pets": 0, "Toys": 0, "Company": 1, "Manager": 1, "Team": 0, "Languages": 0, "Friends": 0} { + if count := DB.Model(&user).Association(key).Count(); count != value { + t.Errorf("user's %v expects: %v, got %v", key, value, count) + } + } +} + +func TestDeleteAssociationsWithUnscoped(t *testing.T) { + user := GetUser("unscoped_delete_with_associations", Config{Account: true, Pets: 2, Toys: 4, Company: true, Manager: true, Team: 1, Languages: 1, Friends: 1}) + + if err := DB.Create(user).Error; err != nil { + t.Fatalf("failed to create user, got error %v", err) + } + + if err := DB.Unscoped().Select(clause.Associations, "Pets.Toy").Delete(&user).Error; err != nil { + t.Fatalf("failed to delete user, got error %v", err) + } + + for key, value := range map[string]int64{"Account": 0, "Pets": 0, "Toys": 0, "Company": 1, "Manager": 1, "Team": 0, "Languages": 0, "Friends": 0} { + if count := DB.Unscoped().Model(&user).Association(key).Count(); count != value { + t.Errorf("user's %v expects: %v, got %v", key, value, count) + } + } + + for key, value := range map[string]int64{"Account": 0, "Pets": 0, "Toys": 0, "Company": 1, "Manager": 1, "Team": 0, "Languages": 0, "Friends": 0} { + if count := DB.Model(&user).Association(key).Count(); count != value { + t.Errorf("user's %v expects: %v, got %v", key, value, count) + } + } +} + +func TestDeleteSliceWithAssociations(t *testing.T) { + users := []User{ + *GetUser("delete_slice_with_associations1", Config{Account: true, Pets: 4, Toys: 1, Company: true, Manager: true, Team: 1, Languages: 1, Friends: 4}), + *GetUser("delete_slice_with_associations2", Config{Account: true, Pets: 3, Toys: 2, Company: true, Manager: true, Team: 2, Languages: 2, Friends: 3}), + *GetUser("delete_slice_with_associations3", Config{Account: true, Pets: 2, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 2}), + *GetUser("delete_slice_with_associations4", Config{Account: true, Pets: 1, Toys: 4, Company: true, Manager: 
true, Team: 4, Languages: 4, Friends: 1}), + } + + if err := DB.Create(users).Error; err != nil { + t.Fatalf("failed to create user, got error %v", err) + } + + if err := DB.Select(clause.Associations).Delete(&users).Error; err != nil { + t.Fatalf("failed to delete user, got error %v", err) + } + + for key, value := range map[string]int64{"Account": 4, "Pets": 10, "Toys": 10, "Company": 4, "Manager": 4, "Team": 10, "Languages": 0, "Friends": 0} { + if count := DB.Unscoped().Model(&users).Association(key).Count(); count != value { + t.Errorf("user's %v expects: %v, got %v", key, value, count) + } + } + + for key, value := range map[string]int64{"Account": 0, "Pets": 0, "Toys": 0, "Company": 4, "Manager": 4, "Team": 0, "Languages": 0, "Friends": 0} { + if count := DB.Model(&users).Association(key).Count(); count != value { + t.Errorf("user's %v expects: %v, got %v", key, value, count) + } + } +} diff --git a/tests/distinct_test.go b/tests/distinct_test.go new file mode 100644 index 00000000..29a320ff --- /dev/null +++ b/tests/distinct_test.go @@ -0,0 +1,68 @@ +package tests_test + +import ( + "regexp" + "testing" + + "gorm.io/gorm" + . "gorm.io/gorm/utils/tests" +) + +func TestDistinct(t *testing.T) { + var users = []User{ + *GetUser("distinct", Config{}), + *GetUser("distinct", Config{}), + *GetUser("distinct", Config{}), + *GetUser("distinct-2", Config{}), + *GetUser("distinct-3", Config{}), + } + users[0].Age = 20 + + if err := DB.Create(&users).Error; err != nil { + t.Fatalf("errors happened when create users: %v", err) + } + + var names []string + DB.Table("users").Where("name like ?", "distinct%").Order("name").Pluck("name", &names) + AssertEqual(t, names, []string{"distinct", "distinct", "distinct", "distinct-2", "distinct-3"}) + + var names1 []string + DB.Model(&User{}).Where("name like ?", "distinct%").Distinct().Order("name").Pluck("Name", &names1) + + AssertEqual(t, names1, []string{"distinct", "distinct-2", "distinct-3"}) + + var results []User + if err := DB.Distinct("name", "age").Where("name like ?", "distinct%").Order("name, age desc").Find(&results).Error; err != nil { + t.Errorf("failed to query users, got error: %v", err) + } + + expects := []User{ + {Name: "distinct", Age: 20}, + {Name: "distinct", Age: 18}, + {Name: "distinct-2", Age: 18}, + {Name: "distinct-3", Age: 18}, + } + + if len(results) != 4 { + t.Fatalf("invalid results length found, expects: %v, got %v", len(expects), len(results)) + } + + for idx, expect := range expects { + AssertObjEqual(t, results[idx], expect, "Name", "Age") + } + + var count int64 + if err := DB.Model(&User{}).Where("name like ?", "distinct%").Count(&count).Error; err != nil || count != 5 { + t.Errorf("failed to query users count, got error: %v, count: %v", err, count) + } + + if err := DB.Model(&User{}).Distinct("name").Where("name like ?", "distinct%").Count(&count).Error; err != nil || count != 3 { + t.Errorf("failed to query users count, got error: %v, count %v", err, count) + } + + dryDB := DB.Session(&gorm.Session{DryRun: true}) + r := dryDB.Distinct("u.id, u.*").Table("user_speaks as s").Joins("inner join users as u on u.id = s.user_id").Where("s.language_code ='US' or s.language_code ='ES'").Find(&User{}) + if !regexp.MustCompile(`SELECT DISTINCT u\.id, u\.\* FROM user_speaks as s inner join users as u`).MatchString(r.Statement.SQL.String()) { + t.Fatalf("Build Distinct with u.*, but got %v", r.Statement.SQL.String()) + } +} diff --git a/tests/docker-compose.yml b/tests/docker-compose.yml new file mode 100644 index 
00000000..05e0956e --- /dev/null +++ b/tests/docker-compose.yml @@ -0,0 +1,31 @@ +version: '3' + +services: + mysql: + image: 'mysql:latest' + ports: + - 9910:3306 + environment: + - MYSQL_DATABASE=gorm + - MYSQL_USER=gorm + - MYSQL_PASSWORD=gorm + - MYSQL_RANDOM_ROOT_PASSWORD="yes" + postgres: + image: 'postgres:latest' + ports: + - 9920:5432 + environment: + - TZ=Asia/Shanghai + - POSTGRES_DB=gorm + - POSTGRES_USER=gorm + - POSTGRES_PASSWORD=gorm + mssql: + image: 'mcmoe/mssqldocker:latest' + ports: + - 9930:1433 + environment: + - ACCEPT_EULA=Y + - SA_PASSWORD=LoremIpsum86 + - MSSQL_DB=gorm + - MSSQL_USER=gorm + - MSSQL_PASSWORD=LoremIpsum86 diff --git a/tests/embedded_struct_test.go b/tests/embedded_struct_test.go new file mode 100644 index 00000000..312a5c37 --- /dev/null +++ b/tests/embedded_struct_test.go @@ -0,0 +1,170 @@ +package tests_test + +import ( + "database/sql/driver" + "encoding/json" + "errors" + "testing" + + "gorm.io/gorm" + . "gorm.io/gorm/utils/tests" +) + +func TestEmbeddedStruct(t *testing.T) { + type ReadOnly struct { + ReadOnly *bool + } + + type BasePost struct { + Id int64 + Title string + URL string + ReadOnly + } + + type Author struct { + ID string + Name string + Email string + } + + type HNPost struct { + BasePost + Author `gorm:"EmbeddedPrefix:user_"` // Embedded struct + Upvotes int32 + } + + type EngadgetPost struct { + BasePost BasePost `gorm:"Embedded"` + Author Author `gorm:"Embedded;EmbeddedPrefix:author_"` // Embedded struct + ImageUrl string + } + + DB.Migrator().DropTable(&HNPost{}, &EngadgetPost{}) + if err := DB.Migrator().AutoMigrate(&HNPost{}, &EngadgetPost{}); err != nil { + t.Fatalf("failed to auto migrate, got error: %v", err) + } + + for _, name := range []string{"author_id", "author_name", "author_email"} { + if !DB.Migrator().HasColumn(&EngadgetPost{}, name) { + t.Errorf("should has prefixed column %v", name) + } + } + + stmt := gorm.Statement{DB: DB} + if err := stmt.Parse(&EngadgetPost{}); err != nil { + t.Fatalf("failed to parse embedded struct") + } else if len(stmt.Schema.PrimaryFields) != 1 { + t.Errorf("should have only one primary field with embedded struct, but got %v", len(stmt.Schema.PrimaryFields)) + } + + for _, name := range []string{"user_id", "user_name", "user_email"} { + if !DB.Migrator().HasColumn(&HNPost{}, name) { + t.Errorf("should has prefixed column %v", name) + } + } + + // save embedded struct + DB.Save(&HNPost{BasePost: BasePost{Title: "news"}}) + DB.Save(&HNPost{BasePost: BasePost{Title: "hn_news"}}) + var news HNPost + if err := DB.First(&news, "title = ?", "hn_news").Error; err != nil { + t.Errorf("no error should happen when query with embedded struct, but got %v", err) + } else if news.Title != "hn_news" { + t.Errorf("embedded struct's value should be scanned correctly") + } + + DB.Save(&EngadgetPost{BasePost: BasePost{Title: "engadget_news"}}) + var egNews EngadgetPost + if err := DB.First(&egNews, "title = ?", "engadget_news").Error; err != nil { + t.Errorf("no error should happen when query with embedded struct, but got %v", err) + } else if egNews.BasePost.Title != "engadget_news" { + t.Errorf("embedded struct's value should be scanned correctly") + } +} + +func TestEmbeddedPointerTypeStruct(t *testing.T) { + type BasePost struct { + Id int64 + Title string + URL string + } + + type HNPost struct { + *BasePost + Upvotes int32 + } + + DB.Migrator().DropTable(&HNPost{}) + if err := DB.Migrator().AutoMigrate(&HNPost{}); err != nil { + t.Fatalf("failed to auto migrate, got error: %v", err) + } + + 
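+ // The anonymous *BasePost has no gorm tag; its fields should still be flattened into HNPost's columns, which the create and the title query below rely on.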
DB.Create(&HNPost{BasePost: &BasePost{Title: "embedded_pointer_type"}}) + + var hnPost HNPost + if err := DB.First(&hnPost, "title = ?", "embedded_pointer_type").Error; err != nil { + t.Errorf("No error should happen when find embedded pointer type, but got %v", err) + } + + if hnPost.Title != "embedded_pointer_type" { + t.Errorf("Should find correct value for embedded pointer type") + } +} + +type Content struct { + Content interface{} `gorm:"type:String"` +} + +func (c Content) Value() (driver.Value, error) { + return json.Marshal(c) +} + +func (c *Content) Scan(src interface{}) error { + b, ok := src.([]byte) + if !ok { + return errors.New("Embedded.Scan byte assertion failed") + } + + var value Content + if err := json.Unmarshal(b, &value); err != nil { + return err + } + + *c = value + + return nil +} + +func TestEmbeddedScanValuer(t *testing.T) { + type HNPost struct { + gorm.Model + Content + } + + DB.Migrator().DropTable(&HNPost{}) + if err := DB.Migrator().AutoMigrate(&HNPost{}); err != nil { + t.Fatalf("failed to auto migrate, got error: %v", err) + } + + hnPost := HNPost{Content: Content{Content: "hello world"}} + + if err := DB.Create(&hnPost).Error; err != nil { + t.Errorf("Failed to create got error %v", err) + } +} + +func TestEmbeddedRelations(t *testing.T) { + type AdvancedUser struct { + User `gorm:"embedded"` + Advanced bool + } + + DB.Migrator().DropTable(&AdvancedUser{}) + + if err := DB.AutoMigrate(&AdvancedUser{}); err != nil { + if DB.Dialector.Name() != "sqlite" { + t.Errorf("Failed to auto migrate advanced user, got error %v", err) + } + } +} diff --git a/tests/go.mod b/tests/go.mod new file mode 100644 index 00000000..d4b0c975 --- /dev/null +++ b/tests/go.mod @@ -0,0 +1,17 @@ +module gorm.io/gorm/tests + +go 1.14 + +require ( + github.com/google/uuid v1.1.1 + github.com/jinzhu/now v1.1.2 + github.com/lib/pq v1.6.0 + github.com/stretchr/testify v1.5.1 + gorm.io/driver/mysql v1.0.5 + gorm.io/driver/postgres v1.0.8 + gorm.io/driver/sqlite v1.1.4 + gorm.io/driver/sqlserver v1.0.7 + gorm.io/gorm v1.21.4 +) + +replace gorm.io/gorm => ../ diff --git a/tests/group_by_test.go b/tests/group_by_test.go new file mode 100644 index 00000000..96dfc547 --- /dev/null +++ b/tests/group_by_test.go @@ -0,0 +1,109 @@ +package tests_test + +import ( + "testing" + + . 
"gorm.io/gorm/utils/tests" +) + +func TestGroupBy(t *testing.T) { + var users = []User{{ + Name: "groupby", + Age: 10, + Birthday: Now(), + Active: true, + }, { + Name: "groupby", + Age: 20, + Birthday: Now(), + }, { + Name: "groupby", + Age: 30, + Birthday: Now(), + Active: true, + }, { + Name: "groupby1", + Age: 110, + Birthday: Now(), + }, { + Name: "groupby1", + Age: 220, + Birthday: Now(), + Active: true, + }, { + Name: "groupby1", + Age: 330, + Birthday: Now(), + Active: true, + }} + + if err := DB.Create(&users).Error; err != nil { + t.Errorf("errors happened when create: %v", err) + } + + var name string + var total int + if err := DB.Model(&User{}).Select("name, sum(age)").Where("name = ?", "groupby").Group("name").Row().Scan(&name, &total); err != nil { + t.Errorf("no error should happen, but got %v", err) + } + + if name != "groupby" || total != 60 { + t.Errorf("name should be groupby, but got %v, total should be 60, but got %v", name, total) + } + + if err := DB.Model(&User{}).Select("name, sum(age)").Where("name = ?", "groupby").Group("users.name").Row().Scan(&name, &total); err != nil { + t.Errorf("no error should happen, but got %v", err) + } + + if name != "groupby" || total != 60 { + t.Errorf("name should be groupby, but got %v, total should be 60, but got %v", name, total) + } + + if err := DB.Model(&User{}).Select("name, sum(age) as total").Where("name LIKE ?", "groupby%").Group("name").Having("name = ?", "groupby1").Row().Scan(&name, &total); err != nil { + t.Errorf("no error should happen, but got %v", err) + } + + if name != "groupby1" || total != 660 { + t.Errorf("name should be groupby, but got %v, total should be 660, but got %v", name, total) + } + + var result = struct { + Name string + Total int64 + }{} + + if err := DB.Model(&User{}).Select("name, sum(age) as total").Where("name LIKE ?", "groupby%").Group("name").Having("name = ?", "groupby1").Find(&result).Error; err != nil { + t.Errorf("no error should happen, but got %v", err) + } + + if result.Name != "groupby1" || result.Total != 660 { + t.Errorf("name should be groupby, total should be 660, but got %+v", result) + } + + if err := DB.Model(&User{}).Select("name, sum(age) as total").Where("name LIKE ?", "groupby%").Group("name").Having("name = ?", "groupby1").Scan(&result).Error; err != nil { + t.Errorf("no error should happen, but got %v", err) + } + + if result.Name != "groupby1" || result.Total != 660 { + t.Errorf("name should be groupby, total should be 660, but got %+v", result) + } + + var active bool + if err := DB.Model(&User{}).Select("name, active, sum(age)").Where("name = ? 
and active = ?", "groupby", true).Group("name").Group("active").Row().Scan(&name, &active, &total); err != nil { + t.Errorf("no error should happen, but got %v", err) + } + + if name != "groupby" || active != true || total != 40 { + t.Errorf("group by two columns, name %v, age %v, active: %v", name, total, active) + } + + if DB.Dialector.Name() == "mysql" { + if err := DB.Model(&User{}).Select("name, age as total").Where("name LIKE ?", "groupby%").Having("total > ?", 300).Scan(&result).Error; err != nil { + t.Errorf("no error should happen, but got %v", err) + } + + if result.Name != "groupby1" || result.Total != 330 { + t.Errorf("name should be groupby, total should be 660, but got %+v", result) + } + } +} diff --git a/tests/helper_test.go b/tests/helper_test.go new file mode 100644 index 00000000..eee34e99 --- /dev/null +++ b/tests/helper_test.go @@ -0,0 +1,230 @@ +package tests_test + +import ( + "sort" + "strconv" + "strings" + "testing" + "time" + + . "gorm.io/gorm/utils/tests" +) + +type Config struct { + Account bool + Pets int + Toys int + Company bool + Manager bool + Team int + Languages int + Friends int +} + +func GetUser(name string, config Config) *User { + var ( + birthday = time.Now().Round(time.Second) + user = User{ + Name: name, + Age: 18, + Birthday: &birthday, + } + ) + + if config.Account { + user.Account = Account{Number: name + "_account"} + } + + for i := 0; i < config.Pets; i++ { + user.Pets = append(user.Pets, &Pet{Name: name + "_pet_" + strconv.Itoa(i+1)}) + } + + for i := 0; i < config.Toys; i++ { + user.Toys = append(user.Toys, Toy{Name: name + "_toy_" + strconv.Itoa(i+1)}) + } + + if config.Company { + user.Company = Company{Name: "company-" + name} + } + + if config.Manager { + user.Manager = GetUser(name+"_manager", Config{}) + } + + for i := 0; i < config.Team; i++ { + user.Team = append(user.Team, *GetUser(name+"_team_"+strconv.Itoa(i+1), Config{})) + } + + for i := 0; i < config.Languages; i++ { + name := name + "_locale_" + strconv.Itoa(i+1) + language := Language{Code: name, Name: name} + user.Languages = append(user.Languages, language) + } + + for i := 0; i < config.Friends; i++ { + user.Friends = append(user.Friends, GetUser(name+"_friend_"+strconv.Itoa(i+1), Config{})) + } + + return &user +} + +func CheckPet(t *testing.T, pet Pet, expect Pet) { + if pet.ID != 0 { + var newPet Pet + if err := DB.Where("id = ?", pet.ID).First(&newPet).Error; err != nil { + t.Fatalf("errors happened when query: %v", err) + } else { + AssertObjEqual(t, newPet, pet, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "UserID", "Name") + } + } + + AssertObjEqual(t, pet, expect, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "UserID", "Name") + + AssertObjEqual(t, pet.Toy, expect.Toy, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "OwnerID", "OwnerType") + + if expect.Toy.Name != "" && expect.Toy.OwnerType != "pets" { + t.Errorf("toys's OwnerType, expect: %v, got %v", "pets", expect.Toy.OwnerType) + } +} + +func CheckUser(t *testing.T, user User, expect User) { + if user.ID != 0 { + var newUser User + if err := DB.Where("id = ?", user.ID).First(&newUser).Error; err != nil { + t.Fatalf("errors happened when query: %v", err) + } else { + AssertObjEqual(t, newUser, user, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active") + } + } + + AssertObjEqual(t, user, expect, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active") + + t.Run("Account", func(t *testing.T) { + 
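+ // The has-one Account should have been saved with a valid UserID foreign key; reload it by user_id below to double-check the persisted row.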
AssertObjEqual(t, user.Account, expect.Account, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "UserID", "Number") + + if user.Account.Number != "" { + if !user.Account.UserID.Valid { + t.Errorf("Account's foreign key should be saved") + } else { + var account Account + DB.First(&account, "user_id = ?", user.ID) + AssertObjEqual(t, account, user.Account, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "UserID", "Number") + } + } + }) + + t.Run("Pets", func(t *testing.T) { + if len(user.Pets) != len(expect.Pets) { + t.Fatalf("pets should equal, expect: %v, got %v", len(expect.Pets), len(user.Pets)) + } + + sort.Slice(user.Pets, func(i, j int) bool { + return user.Pets[i].ID > user.Pets[j].ID + }) + + sort.Slice(expect.Pets, func(i, j int) bool { + return expect.Pets[i].ID > expect.Pets[j].ID + }) + + for idx, pet := range user.Pets { + if pet == nil || expect.Pets[idx] == nil { + t.Errorf("pets#%v should equal, expect: %v, got %v", idx, expect.Pets[idx], pet) + } else { + CheckPet(t, *pet, *expect.Pets[idx]) + } + } + }) + + t.Run("Toys", func(t *testing.T) { + if len(user.Toys) != len(expect.Toys) { + t.Fatalf("toys should equal, expect: %v, got %v", len(expect.Toys), len(user.Toys)) + } + + sort.Slice(user.Toys, func(i, j int) bool { + return user.Toys[i].ID > user.Toys[j].ID + }) + + sort.Slice(expect.Toys, func(i, j int) bool { + return expect.Toys[i].ID > expect.Toys[j].ID + }) + + for idx, toy := range user.Toys { + if toy.OwnerType != "users" { + t.Errorf("toys's OwnerType, expect: %v, got %v", "users", toy.OwnerType) + } + + AssertObjEqual(t, toy, expect.Toys[idx], "ID", "CreatedAt", "UpdatedAt", "Name", "OwnerID", "OwnerType") + } + }) + + t.Run("Company", func(t *testing.T) { + AssertObjEqual(t, user.Company, expect.Company, "ID", "Name") + }) + + t.Run("Manager", func(t *testing.T) { + if user.Manager != nil { + if user.ManagerID == nil { + t.Errorf("Manager's foreign key should be saved") + } else { + var manager User + DB.First(&manager, "id = ?", *user.ManagerID) + AssertObjEqual(t, manager, user.Manager, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active") + } + } else if user.ManagerID != nil { + t.Errorf("Manager should not be created for zero value, got: %+v", user.ManagerID) + } + }) + + t.Run("Team", func(t *testing.T) { + if len(user.Team) != len(expect.Team) { + t.Fatalf("Team should equal, expect: %v, got %v", len(expect.Team), len(user.Team)) + } + + sort.Slice(user.Team, func(i, j int) bool { + return user.Team[i].ID > user.Team[j].ID + }) + + sort.Slice(expect.Team, func(i, j int) bool { + return expect.Team[i].ID > expect.Team[j].ID + }) + + for idx, team := range user.Team { + AssertObjEqual(t, team, expect.Team[idx], "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active") + } + }) + + t.Run("Languages", func(t *testing.T) { + if len(user.Languages) != len(expect.Languages) { + t.Fatalf("Languages should equal, expect: %v, got %v", len(expect.Languages), len(user.Languages)) + } + + sort.Slice(user.Languages, func(i, j int) bool { + return strings.Compare(user.Languages[i].Code, user.Languages[j].Code) > 0 + }) + + sort.Slice(expect.Languages, func(i, j int) bool { + return strings.Compare(expect.Languages[i].Code, expect.Languages[j].Code) > 0 + }) + for idx, language := range user.Languages { + AssertObjEqual(t, language, expect.Languages[idx], "Code", "Name") + } + }) + + t.Run("Friends", func(t *testing.T) { + if len(user.Friends) != len(expect.Friends) { + 
t.Fatalf("Friends should equal, expect: %v, got %v", len(expect.Friends), len(user.Friends)) + } + + sort.Slice(user.Friends, func(i, j int) bool { + return user.Friends[i].ID > user.Friends[j].ID + }) + + sort.Slice(expect.Friends, func(i, j int) bool { + return expect.Friends[i].ID > expect.Friends[j].ID + }) + + for idx, friend := range user.Friends { + AssertObjEqual(t, friend, expect.Friends[idx], "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active") + } + }) +} diff --git a/tests/hooks_test.go b/tests/hooks_test.go new file mode 100644 index 00000000..0e6ab2fe --- /dev/null +++ b/tests/hooks_test.go @@ -0,0 +1,495 @@ +package tests_test + +import ( + "errors" + "reflect" + "strings" + "testing" + + "gorm.io/gorm" + . "gorm.io/gorm/utils/tests" +) + +type Product struct { + gorm.Model + Name string + Code string + Price float64 + AfterFindCallTimes int64 + BeforeCreateCallTimes int64 + AfterCreateCallTimes int64 + BeforeUpdateCallTimes int64 + AfterUpdateCallTimes int64 + BeforeSaveCallTimes int64 + AfterSaveCallTimes int64 + BeforeDeleteCallTimes int64 + AfterDeleteCallTimes int64 +} + +func (s *Product) BeforeCreate(tx *gorm.DB) (err error) { + if s.Code == "Invalid" { + err = errors.New("invalid product") + } + s.BeforeCreateCallTimes = s.BeforeCreateCallTimes + 1 + return +} + +func (s *Product) BeforeUpdate(tx *gorm.DB) (err error) { + if s.Code == "dont_update" { + err = errors.New("can't update") + } + s.BeforeUpdateCallTimes = s.BeforeUpdateCallTimes + 1 + return +} + +func (s *Product) BeforeSave(tx *gorm.DB) (err error) { + if s.Code == "dont_save" { + err = errors.New("can't save") + } + s.BeforeSaveCallTimes = s.BeforeSaveCallTimes + 1 + return +} + +func (s *Product) AfterFind(tx *gorm.DB) (err error) { + s.AfterFindCallTimes = s.AfterFindCallTimes + 1 + return +} + +func (s *Product) AfterCreate(tx *gorm.DB) (err error) { + return tx.Model(s).UpdateColumn("AfterCreateCallTimes", s.AfterCreateCallTimes+1).Error +} + +func (s *Product) AfterUpdate(tx *gorm.DB) (err error) { + s.AfterUpdateCallTimes = s.AfterUpdateCallTimes + 1 + return +} + +func (s *Product) AfterSave(tx *gorm.DB) (err error) { + if s.Code == "after_save_error" { + err = errors.New("can't save") + } + s.AfterSaveCallTimes = s.AfterSaveCallTimes + 1 + return +} + +func (s *Product) BeforeDelete(tx *gorm.DB) (err error) { + if s.Code == "dont_delete" { + err = errors.New("can't delete") + } + s.BeforeDeleteCallTimes = s.BeforeDeleteCallTimes + 1 + return +} + +func (s *Product) AfterDelete(tx *gorm.DB) (err error) { + if s.Code == "after_delete_error" { + err = errors.New("can't delete") + } + s.AfterDeleteCallTimes = s.AfterDeleteCallTimes + 1 + return +} + +func (s *Product) GetCallTimes() []int64 { + return []int64{s.BeforeCreateCallTimes, s.BeforeSaveCallTimes, s.BeforeUpdateCallTimes, s.AfterCreateCallTimes, s.AfterSaveCallTimes, s.AfterUpdateCallTimes, s.BeforeDeleteCallTimes, s.AfterDeleteCallTimes, s.AfterFindCallTimes} +} + +func TestRunCallbacks(t *testing.T) { + DB.Migrator().DropTable(&Product{}) + DB.AutoMigrate(&Product{}) + + p := Product{Code: "unique_code", Price: 100} + DB.Save(&p) + + if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 1, 0, 0, 0, 0}) { + t.Fatalf("Callbacks should be invoked successfully, %v", p.GetCallTimes()) + } + + DB.Where("Code = ?", "unique_code").First(&p) + if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 0, 0, 0, 0, 1}) { + t.Fatalf("After callbacks values are not saved, %v", 
p.GetCallTimes()) + } + + p.Price = 200 + DB.Save(&p) + if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 1, 1, 0, 0, 1}) { + t.Fatalf("After update callbacks should be invoked successfully, %v", p.GetCallTimes()) + } + + var products []Product + DB.Find(&products, "code = ?", "unique_code") + if products[0].AfterFindCallTimes != 2 { + t.Fatalf("AfterFind callbacks should work with slice, called %v", products[0].AfterFindCallTimes) + } + + DB.Where("Code = ?", "unique_code").First(&p) + if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 0, 0, 2}) { + t.Fatalf("After update callbacks values are not saved, %v", p.GetCallTimes()) + } + + DB.Delete(&p) + if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 1, 1, 2}) { + t.Fatalf("After delete callbacks should be invoked successfully, %v", p.GetCallTimes()) + } + + if DB.Where("Code = ?", "unique_code").First(&p).Error == nil { + t.Fatalf("Can't find a deleted record") + } + + beforeCallTimes := p.AfterFindCallTimes + if DB.Where("Code = ?", "unique_code").Find(&p).Error != nil { + t.Fatalf("Find don't raise error when record not found") + } + + if p.AfterFindCallTimes != beforeCallTimes { + t.Fatalf("AfterFind should not be called") + } +} + +func TestCallbacksWithErrors(t *testing.T) { + DB.Migrator().DropTable(&Product{}) + DB.AutoMigrate(&Product{}) + + p := Product{Code: "Invalid", Price: 100} + if DB.Save(&p).Error == nil { + t.Fatalf("An error from before create callbacks happened when create with invalid value") + } + + if DB.Where("code = ?", "Invalid").First(&Product{}).Error == nil { + t.Fatalf("Should not save record that have errors") + } + + if DB.Save(&Product{Code: "dont_save", Price: 100}).Error == nil { + t.Fatalf("An error from after create callbacks happened when create with invalid value") + } + + p2 := Product{Code: "update_callback", Price: 100} + DB.Save(&p2) + + p2.Code = "dont_update" + if DB.Save(&p2).Error == nil { + t.Fatalf("An error from before update callbacks happened when update with invalid value") + } + + if DB.Where("code = ?", "update_callback").First(&Product{}).Error != nil { + t.Fatalf("Record Should not be updated due to errors happened in before update callback") + } + + if DB.Where("code = ?", "dont_update").First(&Product{}).Error == nil { + t.Fatalf("Record Should not be updated due to errors happened in before update callback") + } + + p2.Code = "dont_save" + if DB.Save(&p2).Error == nil { + t.Fatalf("An error from before save callbacks happened when update with invalid value") + } + + p3 := Product{Code: "dont_delete", Price: 100} + DB.Save(&p3) + if DB.Delete(&p3).Error == nil { + t.Fatalf("An error from before delete callbacks happened when delete") + } + + if DB.Where("Code = ?", "dont_delete").First(&p3).Error != nil { + t.Fatalf("An error from before delete callbacks happened") + } + + p4 := Product{Code: "after_save_error", Price: 100} + DB.Save(&p4) + if err := DB.First(&Product{}, "code = ?", "after_save_error").Error; err == nil { + t.Fatalf("Record should be reverted if get an error in after save callback") + } + + p5 := Product{Code: "after_delete_error", Price: 100} + DB.Save(&p5) + if err := DB.First(&Product{}, "code = ?", "after_delete_error").Error; err != nil { + t.Fatalf("Record should be found") + } + + DB.Delete(&p5) + if err := DB.First(&Product{}, "code = ?", "after_delete_error").Error; err != nil { + t.Fatalf("Record shouldn't be deleted because of an error happened in after delete callback") + } +} + +type Product2 struct { + 
gorm.Model + Name string + Code string + Price int64 + Owner string +} + +func (s Product2) BeforeCreate(tx *gorm.DB) (err error) { + if !strings.HasSuffix(s.Name, "_clone") { + newProduft := s + newProduft.Price *= 2 + newProduft.Name += "_clone" + err = tx.Create(&newProduft).Error + } + + if s.Name == "Invalid" { + return errors.New("invalid") + } + + return nil +} + +func (s *Product2) BeforeUpdate(tx *gorm.DB) (err error) { + tx.Statement.Where("owner != ?", "admin") + return +} + +func TestUseDBInHooks(t *testing.T) { + DB.Migrator().DropTable(&Product2{}) + DB.AutoMigrate(&Product2{}) + + product := Product2{Name: "Invalid", Price: 100} + + if err := DB.Create(&product).Error; err == nil { + t.Fatalf("should returns error %v when creating product, but got nil", err) + } + + product2 := Product2{Name: "Nice", Price: 100} + + if err := DB.Create(&product2).Error; err != nil { + t.Fatalf("Failed to create product, got error: %v", err) + } + + var result Product2 + if err := DB.First(&result, "name = ?", "Nice").Error; err != nil { + t.Fatalf("Failed to query product, got error: %v", err) + } + + var resultClone Product2 + if err := DB.First(&resultClone, "name = ?", "Nice_clone").Error; err != nil { + t.Fatalf("Failed to find cloned product, got error: %v", err) + } + + result.Price *= 2 + result.Name += "_clone" + AssertObjEqual(t, result, resultClone, "Price", "Name") + + DB.Model(&result).Update("Price", 500) + var result2 Product2 + DB.First(&result2, "name = ?", "Nice") + + if result2.Price != 500 { + t.Errorf("Failed to update product's price, expects: %v, got %v", 500, result2.Price) + } + + product3 := Product2{Name: "Nice2", Price: 600, Owner: "admin"} + if err := DB.Create(&product3).Error; err != nil { + t.Fatalf("Failed to create product, got error: %v", err) + } + + var result3 Product2 + if err := DB.First(&result3, "name = ?", "Nice2").Error; err != nil { + t.Fatalf("Failed to query product, got error: %v", err) + } + + DB.Model(&result3).Update("Price", 800) + var result4 Product2 + DB.First(&result4, "name = ?", "Nice2") + + if result4.Price != 600 { + t.Errorf("Admin product's price should not be changed, expects: %v, got %v", 600, result4.Price) + } +} + +type Product3 struct { + gorm.Model + Name string + Code string + Price int64 + Owner string +} + +func (s Product3) BeforeCreate(tx *gorm.DB) (err error) { + tx.Statement.SetColumn("Price", s.Price+100) + return nil +} + +func (s Product3) BeforeUpdate(tx *gorm.DB) (err error) { + if tx.Statement.Changed() { + tx.Statement.SetColumn("Price", s.Price+10) + } + + if tx.Statement.Changed("Code") { + s.Price += 20 + tx.Statement.SetColumn("Price", s.Price+30) + } + return nil +} + +func TestSetColumn(t *testing.T) { + DB.Migrator().DropTable(&Product3{}) + DB.AutoMigrate(&Product3{}) + + product := Product3{Name: "Product", Price: 0} + DB.Create(&product) + + if product.Price != 100 { + t.Errorf("invalid price after create, got %+v", product) + } + + DB.Model(&product).Select("code", "price").Updates(map[string]interface{}{"code": "L1212"}) + + if product.Price != 150 || product.Code != "L1212" { + t.Errorf("invalid data after update, got %+v", product) + } + + // Code not changed, price should not change + DB.Model(&product).Updates(map[string]interface{}{"Name": "Product New"}) + + if product.Name != "Product New" || product.Price != 160 || product.Code != "L1212" { + t.Errorf("invalid data after update, got %+v", product) + } + + // Code changed, but not selected, price should not change + 
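+ // Because Code is left out of Select, BeforeUpdate's Changed("Code") check reports false and only the +10 branch runs (160 -> 170).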
DB.Model(&product).Select("Name", "Price").Updates(map[string]interface{}{"Name": "Product New2", "code": "L1213"}) + + if product.Name != "Product New2" || product.Price != 170 || product.Code != "L1212" { + t.Errorf("invalid data after update, got %+v", product) + } + + // Code changed, price should changed + DB.Model(&product).Select("Name", "Code", "Price").Updates(map[string]interface{}{"Name": "Product New3", "code": "L1213"}) + + if product.Name != "Product New3" || product.Price != 220 || product.Code != "L1213" { + t.Errorf("invalid data after update, got %+v", product) + } + + var result Product3 + DB.First(&result, product.ID) + + AssertEqual(t, result, product) + + // Select to change Code, but nothing updated, price should not change + DB.Model(&product).Select("code").Updates(Product3{Name: "L1214", Code: "L1213"}) + + if product.Price != 220 || product.Code != "L1213" || product.Name != "Product New3" { + t.Errorf("invalid data after update, got %+v", product) + } + + DB.Model(&product).Updates(Product3{Code: "L1214"}) + if product.Price != 270 || product.Code != "L1214" { + t.Errorf("invalid data after update, got %+v", product) + } + + DB.Model(&product).UpdateColumns(Product3{Code: "L1215"}) + if product.Price != 270 || product.Code != "L1215" { + t.Errorf("invalid data after update, got %+v", product) + } + + DB.Model(&product).Session(&gorm.Session{SkipHooks: true}).Updates(Product3{Code: "L1216"}) + if product.Price != 270 || product.Code != "L1216" { + t.Errorf("invalid data after update, got %+v", product) + } + + var result2 Product3 + DB.First(&result2, product.ID) + + AssertEqual(t, result2, product) + + product2 := Product3{Name: "Product", Price: 0} + DB.Session(&gorm.Session{SkipHooks: true}).Create(&product2) + + if product2.Price != 0 { + t.Errorf("invalid price after create without hooks, got %+v", product2) + } +} + +func TestHooksForSlice(t *testing.T) { + DB.Migrator().DropTable(&Product3{}) + DB.AutoMigrate(&Product3{}) + + products := []*Product3{ + {Name: "Product-1", Price: 100}, + {Name: "Product-2", Price: 200}, + {Name: "Product-3", Price: 300}, + } + + DB.Create(&products) + + for idx, value := range []int64{200, 300, 400} { + if products[idx].Price != value { + t.Errorf("invalid price for product #%v, expects: %v, got %v", idx, value, products[idx].Price) + } + } + + DB.Model(&products).Update("Name", "product-name") + + // will set all product's price to last product's price + 10 + for idx, value := range []int64{410, 410, 410} { + if products[idx].Price != value { + t.Errorf("invalid price for product #%v, expects: %v, got %v", idx, value, products[idx].Price) + } + } + + products2 := []Product3{ + {Name: "Product-1", Price: 100}, + {Name: "Product-2", Price: 200}, + {Name: "Product-3", Price: 300}, + } + + DB.Create(&products2) + + for idx, value := range []int64{200, 300, 400} { + if products2[idx].Price != value { + t.Errorf("invalid price for product #%v, expects: %v, got %v", idx, value, products2[idx].Price) + } + } + + DB.Model(&products2).Update("Name", "product-name") + + // will set all product's price to last product's price + 10 + for idx, value := range []int64{410, 410, 410} { + if products2[idx].Price != value { + t.Errorf("invalid price for product #%v, expects: %v, got %v", idx, value, products2[idx].Price) + } + } +} + +type Product4 struct { + gorm.Model + Name string + Code string + Price int64 + Owner string + Item ProductItem +} + +type ProductItem struct { + gorm.Model + Code string + Product4ID uint +} + +func (pi 
ProductItem) BeforeCreate(*gorm.DB) error { + if pi.Code == "invalid" { + return errors.New("invalid item") + } + return nil +} + +func TestFailedToSaveAssociationShouldRollback(t *testing.T) { + DB.Migrator().DropTable(&Product4{}, &ProductItem{}) + DB.AutoMigrate(&Product4{}, &ProductItem{}) + + product := Product4{Name: "Product-1", Price: 100, Item: ProductItem{Code: "invalid"}} + if err := DB.Create(&product).Error; err == nil { + t.Errorf("should got failed to save, but error is nil") + } + + if DB.First(&Product4{}, "name = ?", product.Name).Error == nil { + t.Errorf("should got RecordNotFound, but got nil") + } + + product = Product4{Name: "Product-2", Price: 100, Item: ProductItem{Code: "valid"}} + if err := DB.Create(&product).Error; err != nil { + t.Errorf("should create product, but got error %v", err) + } + + if err := DB.First(&Product4{}, "name = ?", product.Name).Error; err != nil { + t.Errorf("should find product, but got error %v", err) + } +} diff --git a/tests/joins_table_test.go b/tests/joins_table_test.go new file mode 100644 index 00000000..084c2f2c --- /dev/null +++ b/tests/joins_table_test.go @@ -0,0 +1,116 @@ +package tests_test + +import ( + "testing" + "time" + + "gorm.io/gorm" + "gorm.io/gorm/clause" +) + +type Person struct { + ID int + Name string + Addresses []Address `gorm:"many2many:person_addresses;"` + DeletedAt gorm.DeletedAt +} + +type Address struct { + ID uint + Name string +} + +type PersonAddress struct { + PersonID int + AddressID int + CreatedAt time.Time + DeletedAt gorm.DeletedAt +} + +func TestOverrideJoinTable(t *testing.T) { + DB.Migrator().DropTable(&Person{}, &Address{}, &PersonAddress{}) + + if err := DB.SetupJoinTable(&Person{}, "Addresses", &PersonAddress{}); err != nil { + t.Fatalf("Failed to setup join table for person, got error %v", err) + } + + if err := DB.AutoMigrate(&Person{}, &Address{}); err != nil { + t.Fatalf("Failed to migrate, got %v", err) + } + + address1 := Address{Name: "address 1"} + address2 := Address{Name: "address 2"} + person := Person{Name: "person", Addresses: []Address{address1, address2}} + DB.Create(&person) + + var addresses1 []Address + if err := DB.Model(&person).Association("Addresses").Find(&addresses1); err != nil || len(addresses1) != 2 { + t.Fatalf("Failed to find address, got error %v, length: %v", err, len(addresses1)) + } + + if err := DB.Model(&person).Association("Addresses").Delete(&person.Addresses[0]); err != nil { + t.Fatalf("Failed to delete address, got error %v", err) + } + + if len(person.Addresses) != 1 { + t.Fatalf("Should have one address left") + } + + if DB.Find(&[]PersonAddress{}, "person_id = ?", person.ID).RowsAffected != 1 { + t.Fatalf("Should found one address") + } + + var addresses2 []Address + if err := DB.Model(&person).Association("Addresses").Find(&addresses2); err != nil || len(addresses2) != 1 { + t.Fatalf("Failed to find address, got error %v, length: %v", err, len(addresses2)) + } + + if DB.Model(&person).Association("Addresses").Count() != 1 { + t.Fatalf("Should found one address") + } + + var addresses3 []Address + if err := DB.Unscoped().Model(&person).Association("Addresses").Find(&addresses3); err != nil || len(addresses3) != 2 { + t.Fatalf("Failed to find address, got error %v, length: %v", err, len(addresses3)) + } + + if DB.Unscoped().Find(&[]PersonAddress{}, "person_id = ?", person.ID).RowsAffected != 2 { + t.Fatalf("Should found soft deleted addresses with unscoped") + } + + if DB.Unscoped().Model(&person).Association("Addresses").Count() != 2 { + 
t.Fatalf("Should found soft deleted addresses with unscoped") + } + + DB.Model(&person).Association("Addresses").Clear() + + if DB.Model(&person).Association("Addresses").Count() != 0 { + t.Fatalf("Should deleted all addresses") + } + + if DB.Unscoped().Model(&person).Association("Addresses").Count() != 2 { + t.Fatalf("Should found soft deleted addresses with unscoped") + } + + DB.Unscoped().Model(&person).Association("Addresses").Clear() + + if DB.Unscoped().Model(&person).Association("Addresses").Count() != 0 { + t.Fatalf("address should be deleted when clear with unscoped") + } + + address2_1 := Address{Name: "address 2-1"} + address2_2 := Address{Name: "address 2-2"} + person2 := Person{Name: "person_2", Addresses: []Address{address2_1, address2_2}} + DB.Create(&person2) + if err := DB.Select(clause.Associations).Delete(&person2).Error; err != nil { + t.Fatalf("failed to delete person, got error: %v", err) + } + + if count := DB.Unscoped().Model(&person2).Association("Addresses").Count(); count != 2 { + t.Errorf("person's addresses expects 2, got %v", count) + } + + if count := DB.Model(&person2).Association("Addresses").Count(); count != 0 { + t.Errorf("person's addresses expects 2, got %v", count) + } +} diff --git a/tests/joins_test.go b/tests/joins_test.go new file mode 100644 index 00000000..46611f5f --- /dev/null +++ b/tests/joins_test.go @@ -0,0 +1,132 @@ +package tests_test + +import ( + "regexp" + "sort" + "testing" + + "gorm.io/gorm" + . "gorm.io/gorm/utils/tests" +) + +func TestJoins(t *testing.T) { + user := *GetUser("joins-1", Config{Company: true, Manager: true, Account: true}) + + DB.Create(&user) + + var user2 User + if err := DB.Joins("Company").Joins("Manager").Joins("Account").First(&user2, "users.name = ?", user.Name).Error; err != nil { + t.Fatalf("Failed to load with joins, got error: %v", err) + } + + CheckUser(t, user2, user) +} + +func TestJoinsForSlice(t *testing.T) { + users := []User{ + *GetUser("slice-joins-1", Config{Company: true, Manager: true, Account: true}), + *GetUser("slice-joins-2", Config{Company: true, Manager: true, Account: true}), + *GetUser("slice-joins-3", Config{Company: true, Manager: true, Account: true}), + } + + DB.Create(&users) + + var userIDs []uint + for _, user := range users { + userIDs = append(userIDs, user.ID) + } + + var users2 []User + if err := DB.Joins("Company").Joins("Manager").Joins("Account").Find(&users2, "users.id IN ?", userIDs).Error; err != nil { + t.Fatalf("Failed to load with joins, got error: %v", err) + } else if len(users2) != len(users) { + t.Fatalf("Failed to load join users, got: %v, expect: %v", len(users2), len(users)) + } + + sort.Slice(users2, func(i, j int) bool { + return users2[i].ID > users2[j].ID + }) + + sort.Slice(users, func(i, j int) bool { + return users[i].ID > users[j].ID + }) + + for idx, user := range users { + CheckUser(t, user, users2[idx]) + } +} + +func TestJoinConds(t *testing.T) { + var user = *GetUser("joins-conds", Config{Account: true, Pets: 3}) + DB.Save(&user) + + var users1 []User + DB.Joins("inner join pets on pets.user_id = users.id").Where("users.name = ?", user.Name).Find(&users1) + if len(users1) != 3 { + t.Errorf("should find two users using left join, but got %v", len(users1)) + } + + var users2 []User + DB.Joins("inner join pets on pets.user_id = users.id AND pets.name = ?", user.Pets[0].Name).Where("users.name = ?", user.Name).First(&users2) + if len(users2) != 1 { + t.Errorf("should find one users using left join with conditions, but got %v", len(users2)) + } + + var 
users3 []User + DB.Joins("inner join pets on pets.user_id = users.id AND pets.name = ?", user.Pets[0].Name).Joins("join accounts on accounts.user_id = users.id AND accounts.number = ?", user.Account.Number).Where("users.name = ?", user.Name).First(&users3) + if len(users3) != 1 { + t.Errorf("should find one users using multiple left join conditions, but got %v", len(users3)) + } + + var users4 []User + DB.Joins("inner join pets on pets.user_id = users.id AND pets.name = ?", user.Pets[0].Name).Joins("join accounts on accounts.user_id = users.id AND accounts.number = ?", user.Account.Number+"non-exist").Where("users.name = ?", user.Name).First(&users4) + if len(users4) != 0 { + t.Errorf("should find no user when searching with unexisting credit card, but got %v", len(users4)) + } + + var users5 []User + db5 := DB.Joins("inner join pets on pets.user_id = users.id AND pets.name = ?", user.Pets[0].Name).Joins("join accounts on accounts.user_id = users.id AND accounts.number = ?", user.Account.Number).Where(User{Model: gorm.Model{ID: 1}}).Where(Account{Model: gorm.Model{ID: 1}}).Not(Pet{Model: gorm.Model{ID: 1}}).Find(&users5) + if db5.Error != nil { + t.Errorf("Should not raise error for join where identical fields in different tables. Error: %s", db5.Error.Error()) + } + + var users6 []User + DB.Joins("inner join pets on pets.user_id = users.id AND pets.name = @Name", user.Pets[0]).Where("users.name = ?", user.Name).First(&users6) + if len(users6) != 1 { + t.Errorf("should find one users using left join with conditions, but got %v", len(users6)) + } + + dryDB := DB.Session(&gorm.Session{DryRun: true}) + stmt := dryDB.Joins("left join pets on pets.user_id = users.id AND pets.name = ?", user.Pets[0].Name).Joins("join accounts on accounts.user_id = users.id AND accounts.number = ?", user.Account.Number).Where(User{Model: gorm.Model{ID: 1}}).Where(Account{Model: gorm.Model{ID: 1}}).Not(Pet{Model: gorm.Model{ID: 1}}).Find(&users5).Statement + + if !regexp.MustCompile("SELECT .* FROM .users. left join pets.*join accounts.*").MatchString(stmt.SQL.String()) { + t.Errorf("joins should be ordered, but got %v", stmt.SQL.String()) + } +} + +func TestJoinsWithSelect(t *testing.T) { + type result struct { + ID uint + PetID uint + Name string + } + + user := *GetUser("joins_with_select", Config{Pets: 2}) + DB.Save(&user) + + var results []result + + DB.Table("users").Select("users.id, pets.id as pet_id, pets.name").Joins("left join pets on pets.user_id = users.id").Where("users.name = ?", "joins_with_select").Scan(&results) + + sort.Slice(results, func(i, j int) bool { + return results[i].PetID > results[j].PetID + }) + + sort.Slice(results, func(i, j int) bool { + return user.Pets[i].ID > user.Pets[j].ID + }) + + if len(results) != 2 || results[0].Name != user.Pets[0].Name || results[1].Name != user.Pets[1].Name { + t.Errorf("Should find all two pets with Join select, got %+v", results) + } +} diff --git a/tests/main_test.go b/tests/main_test.go new file mode 100644 index 00000000..5b8c7dbb --- /dev/null +++ b/tests/main_test.go @@ -0,0 +1,55 @@ +package tests_test + +import ( + "testing" + + . 
"gorm.io/gorm/utils/tests" +) + +func TestExceptionsWithInvalidSql(t *testing.T) { + if name := DB.Dialector.Name(); name == "sqlserver" { + t.Skip("skip sqlserver due to it will raise data race for invalid sql") + } + + var columns []string + if DB.Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns).Error == nil { + t.Errorf("Should got error with invalid SQL") + } + + if DB.Model(&User{}).Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns).Error == nil { + t.Errorf("Should got error with invalid SQL") + } + + if DB.Where("sdsd.zaaa = ?", "sd;;;aa").Find(&User{}).Error == nil { + t.Errorf("Should got error with invalid SQL") + } + + var count1, count2 int64 + DB.Model(&User{}).Count(&count1) + if count1 <= 0 { + t.Errorf("Should find some users") + } + + if DB.Where("name = ?", "jinzhu; delete * from users").First(&User{}).Error == nil { + t.Errorf("Should got error with invalid SQL") + } + + DB.Model(&User{}).Count(&count2) + if count1 != count2 { + t.Errorf("No user should not be deleted by invalid SQL") + } +} + +func TestSetAndGet(t *testing.T) { + if value, ok := DB.Set("hello", "world").Get("hello"); !ok { + t.Errorf("Should be able to get setting after set") + } else { + if value.(string) != "world" { + t.Errorf("Setted value should not be changed") + } + } + + if _, ok := DB.Get("non_existing"); ok { + t.Errorf("Get non existing key should return error") + } +} diff --git a/tests/migrate_test.go b/tests/migrate_test.go new file mode 100644 index 00000000..4da3856f --- /dev/null +++ b/tests/migrate_test.go @@ -0,0 +1,368 @@ +package tests_test + +import ( + "math/rand" + "strings" + "testing" + "time" + + "gorm.io/gorm" + . "gorm.io/gorm/utils/tests" +) + +func TestMigrate(t *testing.T) { + allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}} + rand.Seed(time.Now().UnixNano()) + rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] }) + + DB.Migrator().DropTable("user_speaks", "user_friends", "ccc") + + if err := DB.Migrator().DropTable(allModels...); err != nil { + t.Fatalf("Failed to drop table, got error %v", err) + } + + if err := DB.AutoMigrate(allModels...); err != nil { + t.Fatalf("Failed to auto migrate, but got error %v", err) + } + + for _, m := range allModels { + if !DB.Migrator().HasTable(m) { + t.Fatalf("Failed to create table for %#v---", m) + } + } + + DB.Scopes(func(db *gorm.DB) *gorm.DB { + return db.Table("ccc") + }).Migrator().CreateTable(&Company{}) + + if !DB.Migrator().HasTable("ccc") { + t.Errorf("failed to create table ccc") + } + + for _, indexes := range [][2]string{ + {"user_speaks", "fk_user_speaks_user"}, + {"user_speaks", "fk_user_speaks_language"}, + {"user_friends", "fk_user_friends_user"}, + {"user_friends", "fk_user_friends_friends"}, + {"accounts", "fk_users_account"}, + {"users", "fk_users_team"}, + {"users", "fk_users_company"}, + } { + if !DB.Migrator().HasConstraint(indexes[0], indexes[1]) { + t.Fatalf("Failed to find index for many2many for %v %v", indexes[0], indexes[1]) + } + } +} + +func TestSmartMigrateColumn(t *testing.T) { + fullSupported := map[string]bool{"mysql": true, "postgres": true}[DB.Dialector.Name()] + + type UserMigrateColumn struct { + ID uint + Name string + Salary float64 + Birthday time.Time `gorm:"precision:4"` + } + + DB.Migrator().DropTable(&UserMigrateColumn{}) + + DB.AutoMigrate(&UserMigrateColumn{}) + + type UserMigrateColumn2 struct { + ID uint + Name string `gorm:"size:128"` + Salary float64 `gorm:"precision:2"` + 
Birthday time.Time `gorm:"precision:2"` + NameIgnoreMigration string `gorm:"size:100"` + } + + if err := DB.Table("user_migrate_columns").AutoMigrate(&UserMigrateColumn2{}); err != nil { + t.Fatalf("failed to auto migrate, got error: %v", err) + } + + columnTypes, err := DB.Table("user_migrate_columns").Migrator().ColumnTypes(&UserMigrateColumn{}) + if err != nil { + t.Fatalf("failed to get column types, got error: %v", err) + } + + for _, columnType := range columnTypes { + switch columnType.Name() { + case "name": + if length, _ := columnType.Length(); (fullSupported || length != 0) && length != 128 { + t.Fatalf("name's length should be 128, but got %v", length) + } + case "salary": + if precision, o, _ := columnType.DecimalSize(); (fullSupported || precision != 0) && precision != 2 { + t.Fatalf("salary's precision should be 2, but got %v %v", precision, o) + } + case "birthday": + if precision, _, _ := columnType.DecimalSize(); (fullSupported || precision != 0) && precision != 2 { + t.Fatalf("birthday's precision should be 2, but got %v", precision) + } + } + } + + type UserMigrateColumn3 struct { + ID uint + Name string `gorm:"size:256"` + Salary float64 `gorm:"precision:3"` + Birthday time.Time `gorm:"precision:3"` + NameIgnoreMigration string `gorm:"size:128;-:migration"` + } + + if err := DB.Table("user_migrate_columns").AutoMigrate(&UserMigrateColumn3{}); err != nil { + t.Fatalf("failed to auto migrate, got error: %v", err) + } + + columnTypes, err = DB.Table("user_migrate_columns").Migrator().ColumnTypes(&UserMigrateColumn{}) + if err != nil { + t.Fatalf("failed to get column types, got error: %v", err) + } + + for _, columnType := range columnTypes { + switch columnType.Name() { + case "name": + if length, _ := columnType.Length(); (fullSupported || length != 0) && length != 256 { + t.Fatalf("name's length should be 128, but got %v", length) + } + case "salary": + if precision, _, _ := columnType.DecimalSize(); (fullSupported || precision != 0) && precision != 3 { + t.Fatalf("salary's precision should be 2, but got %v", precision) + } + case "birthday": + if precision, _, _ := columnType.DecimalSize(); (fullSupported || precision != 0) && precision != 3 { + t.Fatalf("birthday's precision should be 2, but got %v", precision) + } + case "name_ignore_migration": + if length, _ := columnType.Length(); (fullSupported || length != 0) && length != 100 { + t.Fatalf("name_ignore_migration's length should still be 100 but got %v", length) + } + } + } + +} + +func TestMigrateWithComment(t *testing.T) { + type UserWithComment struct { + gorm.Model + Name string `gorm:"size:111;index:,comment:这是一个index;comment:this is a 字段"` + } + + if err := DB.Migrator().DropTable(&UserWithComment{}); err != nil { + t.Fatalf("Failed to drop table, got error %v", err) + } + + if err := DB.AutoMigrate(&UserWithComment{}); err != nil { + t.Fatalf("Failed to auto migrate, but got error %v", err) + } +} + +func TestMigrateWithUniqueIndex(t *testing.T) { + type UserWithUniqueIndex struct { + ID int + Name string `gorm:"size:20;index:idx_name,unique"` + Date time.Time `gorm:"index:idx_name,unique"` + } + + DB.Migrator().DropTable(&UserWithUniqueIndex{}) + if err := DB.AutoMigrate(&UserWithUniqueIndex{}); err != nil { + t.Fatalf("failed to migrate, got %v", err) + } + + if !DB.Migrator().HasIndex(&UserWithUniqueIndex{}, "idx_name") { + t.Errorf("Failed to find created index") + } +} + +func TestMigrateTable(t *testing.T) { + type TableStruct struct { + gorm.Model + Name string + } + + 
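+ // Start from a clean slate, then exercise HasTable, RenameTable (to new_table_structs) and DropTable against the table derived from the struct name.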
DB.Migrator().DropTable(&TableStruct{}) + DB.AutoMigrate(&TableStruct{}) + + if !DB.Migrator().HasTable(&TableStruct{}) { + t.Fatalf("should found created table") + } + + type NewTableStruct struct { + gorm.Model + Name string + } + + if err := DB.Migrator().RenameTable(&TableStruct{}, &NewTableStruct{}); err != nil { + t.Fatalf("Failed to rename table, got error %v", err) + } + + if !DB.Migrator().HasTable("new_table_structs") { + t.Fatal("should found renamed table") + } + + DB.Migrator().DropTable("new_table_structs") + + if DB.Migrator().HasTable(&NewTableStruct{}) { + t.Fatal("should not found droped table") + } +} + +func TestMigrateIndexes(t *testing.T) { + type IndexStruct struct { + gorm.Model + Name string `gorm:"size:255;index"` + } + + DB.Migrator().DropTable(&IndexStruct{}) + DB.AutoMigrate(&IndexStruct{}) + + if err := DB.Migrator().DropIndex(&IndexStruct{}, "Name"); err != nil { + t.Fatalf("Failed to drop index for user's name, got err %v", err) + } + + if err := DB.Migrator().CreateIndex(&IndexStruct{}, "Name"); err != nil { + t.Fatalf("Got error when tried to create index: %+v", err) + } + + if !DB.Migrator().HasIndex(&IndexStruct{}, "Name") { + t.Fatalf("Failed to find index for user's name") + } + + if err := DB.Migrator().DropIndex(&IndexStruct{}, "Name"); err != nil { + t.Fatalf("Failed to drop index for user's name, got err %v", err) + } + + if DB.Migrator().HasIndex(&IndexStruct{}, "Name") { + t.Fatalf("Should not find index for user's name after delete") + } + + if err := DB.Migrator().CreateIndex(&IndexStruct{}, "Name"); err != nil { + t.Fatalf("Got error when tried to create index: %+v", err) + } + + if err := DB.Migrator().RenameIndex(&IndexStruct{}, "idx_index_structs_name", "idx_users_name_1"); err != nil { + t.Fatalf("no error should happen when rename index, but got %v", err) + } + + if !DB.Migrator().HasIndex(&IndexStruct{}, "idx_users_name_1") { + t.Fatalf("Should find index for user's name after rename") + } + + if err := DB.Migrator().DropIndex(&IndexStruct{}, "idx_users_name_1"); err != nil { + t.Fatalf("Failed to drop index for user's name, got err %v", err) + } + + if DB.Migrator().HasIndex(&IndexStruct{}, "idx_users_name_1") { + t.Fatalf("Should not find index for user's name after delete") + } +} + +func TestMigrateColumns(t *testing.T) { + type ColumnStruct struct { + gorm.Model + Name string + } + + DB.Migrator().DropTable(&ColumnStruct{}) + + if err := DB.AutoMigrate(&ColumnStruct{}); err != nil { + t.Errorf("Failed to migrate, got %v", err) + } + + type ColumnStruct2 struct { + gorm.Model + Name string `gorm:"size:100"` + } + + if err := DB.Table("column_structs").Migrator().AlterColumn(&ColumnStruct2{}, "Name"); err != nil { + t.Fatalf("no error should happened when alter column, but got %v", err) + } + + if columnTypes, err := DB.Migrator().ColumnTypes(&ColumnStruct{}); err != nil { + t.Fatalf("no error should returns for ColumnTypes") + } else { + stmt := &gorm.Statement{DB: DB} + stmt.Parse(&ColumnStruct2{}) + + for _, columnType := range columnTypes { + if columnType.Name() == "name" { + dataType := DB.Dialector.DataTypeOf(stmt.Schema.LookUpField(columnType.Name())) + if !strings.Contains(strings.ToUpper(dataType), strings.ToUpper(columnType.DatabaseTypeName())) { + t.Errorf("column type should be correct, name: %v, length: %v, expects: %v", columnType.Name(), columnType.DatabaseTypeName(), dataType) + } + } + } + } + + type NewColumnStruct struct { + gorm.Model + Name string + NewName string + } + + if err := 
DB.Table("column_structs").Migrator().AddColumn(&NewColumnStruct{}, "NewName"); err != nil { + t.Fatalf("Failed to add column, got %v", err) + } + + if !DB.Table("column_structs").Migrator().HasColumn(&NewColumnStruct{}, "NewName") { + t.Fatalf("Failed to find added column") + } + + if err := DB.Table("column_structs").Migrator().DropColumn(&NewColumnStruct{}, "NewName"); err != nil { + t.Fatalf("Failed to add column, got %v", err) + } + + if DB.Table("column_structs").Migrator().HasColumn(&NewColumnStruct{}, "NewName") { + t.Fatalf("Found deleted column") + } + + if err := DB.Table("column_structs").Migrator().AddColumn(&NewColumnStruct{}, "NewName"); err != nil { + t.Fatalf("Failed to add column, got %v", err) + } + + if err := DB.Table("column_structs").Migrator().RenameColumn(&NewColumnStruct{}, "NewName", "new_new_name"); err != nil { + t.Fatalf("Failed to add column, got %v", err) + } + + if !DB.Table("column_structs").Migrator().HasColumn(&NewColumnStruct{}, "new_new_name") { + t.Fatalf("Failed to found renamed column") + } + + if err := DB.Table("column_structs").Migrator().DropColumn(&NewColumnStruct{}, "new_new_name"); err != nil { + t.Fatalf("Failed to add column, got %v", err) + } + + if DB.Table("column_structs").Migrator().HasColumn(&NewColumnStruct{}, "new_new_name") { + t.Fatalf("Found deleted column") + } +} + +func TestMigrateConstraint(t *testing.T) { + if DB.Dialector.Name() == "sqlite" { + t.Skip() + } + + names := []string{"Account", "fk_users_account", "Pets", "fk_users_pets", "Company", "fk_users_company", "Team", "fk_users_team", "Languages", "fk_users_languages"} + + for _, name := range names { + if !DB.Migrator().HasConstraint(&User{}, name) { + DB.Migrator().CreateConstraint(&User{}, name) + } + + if err := DB.Migrator().DropConstraint(&User{}, name); err != nil { + t.Fatalf("failed to drop constraint %v, got error %v", name, err) + } + + if DB.Migrator().HasConstraint(&User{}, name) { + t.Fatalf("constraint %v should been deleted", name) + } + + if err := DB.Migrator().CreateConstraint(&User{}, name); err != nil { + t.Fatalf("failed to create constraint %v, got error %v", name, err) + } + + if !DB.Migrator().HasConstraint(&User{}, name) { + t.Fatalf("failed to found constraint %v", name) + } + } +} diff --git a/tests/multi_primary_keys_test.go b/tests/multi_primary_keys_test.go new file mode 100644 index 00000000..dcc90cd9 --- /dev/null +++ b/tests/multi_primary_keys_test.go @@ -0,0 +1,448 @@ +package tests_test + +import ( + "reflect" + "sort" + "testing" + + "gorm.io/gorm" + . 
"gorm.io/gorm/utils/tests" +) + +type Blog struct { + ID uint `gorm:"primary_key"` + Locale string `gorm:"primary_key"` + Subject string + Body string + Tags []Tag `gorm:"many2many:blog_tags;"` + SharedTags []Tag `gorm:"many2many:shared_blog_tags;ForeignKey:id;References:id"` + LocaleTags []Tag `gorm:"many2many:locale_blog_tags;ForeignKey:id,locale;References:id"` +} + +type Tag struct { + ID uint `gorm:"primary_key"` + Locale string `gorm:"primary_key"` + Value string + Blogs []*Blog `gorm:"many2many:blog_tags"` +} + +func compareTags(tags []Tag, contents []string) bool { + var tagContents []string + for _, tag := range tags { + tagContents = append(tagContents, tag.Value) + } + sort.Strings(tagContents) + sort.Strings(contents) + return reflect.DeepEqual(tagContents, contents) +} + +func TestManyToManyWithMultiPrimaryKeys(t *testing.T) { + if name := DB.Dialector.Name(); name == "sqlite" || name == "sqlserver" { + t.Skip("skip sqlite, sqlserver due to it doesn't support multiple primary keys with auto increment") + } + + if name := DB.Dialector.Name(); name == "postgres" { + stmt := gorm.Statement{DB: DB} + stmt.Parse(&Blog{}) + stmt.Schema.LookUpField("ID").Unique = true + stmt.Parse(&Tag{}) + stmt.Schema.LookUpField("ID").Unique = true + // postgers only allow unique constraint matching given keys + } + + DB.Migrator().DropTable(&Blog{}, &Tag{}, "blog_tags", "locale_blog_tags", "shared_blog_tags") + if err := DB.AutoMigrate(&Blog{}, &Tag{}); err != nil { + t.Fatalf("Failed to auto migrate, got error: %v", err) + } + + blog := Blog{ + Locale: "ZH", + Subject: "subject", + Body: "body", + Tags: []Tag{ + {Locale: "ZH", Value: "tag1"}, + {Locale: "ZH", Value: "tag2"}, + }, + } + + DB.Save(&blog) + if !compareTags(blog.Tags, []string{"tag1", "tag2"}) { + t.Fatalf("Blog should has two tags") + } + + // Append + var tag3 = &Tag{Locale: "ZH", Value: "tag3"} + DB.Model(&blog).Association("Tags").Append([]*Tag{tag3}) + + if !compareTags(blog.Tags, []string{"tag1", "tag2", "tag3"}) { + t.Fatalf("Blog should has three tags after Append") + } + + if count := DB.Model(&blog).Association("Tags").Count(); count != 3 { + t.Fatalf("Blog should has 3 tags after Append, got %v", count) + } + + var tags []Tag + DB.Model(&blog).Association("Tags").Find(&tags) + if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) { + t.Fatalf("Should find 3 tags") + } + + var blog1 Blog + DB.Preload("Tags").Find(&blog1) + if !compareTags(blog1.Tags, []string{"tag1", "tag2", "tag3"}) { + t.Fatalf("Preload many2many relations") + } + + // Replace + var tag5 = &Tag{Locale: "ZH", Value: "tag5"} + var tag6 = &Tag{Locale: "ZH", Value: "tag6"} + DB.Model(&blog).Association("Tags").Replace(tag5, tag6) + var tags2 []Tag + DB.Model(&blog).Association("Tags").Find(&tags2) + if !compareTags(tags2, []string{"tag5", "tag6"}) { + t.Fatalf("Should find 2 tags after Replace") + } + + if DB.Model(&blog).Association("Tags").Count() != 2 { + t.Fatalf("Blog should has three tags after Replace") + } + + // Delete + DB.Model(&blog).Association("Tags").Delete(tag5) + var tags3 []Tag + DB.Model(&blog).Association("Tags").Find(&tags3) + if !compareTags(tags3, []string{"tag6"}) { + t.Fatalf("Should find 1 tags after Delete") + } + + if DB.Model(&blog).Association("Tags").Count() != 1 { + t.Fatalf("Blog should has three tags after Delete") + } + + DB.Model(&blog).Association("Tags").Delete(tag3) + var tags4 []Tag + DB.Model(&blog).Association("Tags").Find(&tags4) + if !compareTags(tags4, []string{"tag6"}) { + t.Fatalf("Tag should not be deleted 
when Delete with a unrelated tag") + } + + // Clear + DB.Model(&blog).Association("Tags").Clear() + if DB.Model(&blog).Association("Tags").Count() != 0 { + t.Fatalf("All tags should be cleared") + } +} + +func TestManyToManyWithCustomizedForeignKeys(t *testing.T) { + if name := DB.Dialector.Name(); name == "sqlite" || name == "sqlserver" { + t.Skip("skip sqlite, sqlserver due to it doesn't support multiple primary keys with auto increment") + } + + if name := DB.Dialector.Name(); name == "postgres" { + t.Skip("skip postgres due to it only allow unique constraint matching given keys") + } + + DB.Migrator().DropTable(&Blog{}, &Tag{}, "blog_tags", "locale_blog_tags", "shared_blog_tags") + if err := DB.AutoMigrate(&Blog{}, &Tag{}); err != nil { + t.Fatalf("Failed to auto migrate, got error: %v", err) + } + + blog := Blog{ + Locale: "ZH", + Subject: "subject", + Body: "body", + SharedTags: []Tag{ + {Locale: "ZH", Value: "tag1"}, + {Locale: "ZH", Value: "tag2"}, + }, + } + DB.Save(&blog) + + blog2 := Blog{ + ID: blog.ID, + Locale: "EN", + } + DB.Create(&blog2) + + if !compareTags(blog.SharedTags, []string{"tag1", "tag2"}) { + t.Fatalf("Blog should has two tags") + } + + // Append + var tag3 = &Tag{Locale: "ZH", Value: "tag3"} + DB.Model(&blog).Association("SharedTags").Append([]*Tag{tag3}) + if !compareTags(blog.SharedTags, []string{"tag1", "tag2", "tag3"}) { + t.Fatalf("Blog should has three tags after Append") + } + + if DB.Model(&blog).Association("SharedTags").Count() != 3 { + t.Fatalf("Blog should has three tags after Append") + } + + if DB.Model(&blog2).Association("SharedTags").Count() != 3 { + t.Fatalf("Blog should has three tags after Append") + } + + var tags []Tag + DB.Model(&blog).Association("SharedTags").Find(&tags) + if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) { + t.Fatalf("Should find 3 tags") + } + + DB.Model(&blog2).Association("SharedTags").Find(&tags) + if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) { + t.Fatalf("Should find 3 tags") + } + + var blog1 Blog + DB.Preload("SharedTags").Find(&blog1) + if !compareTags(blog1.SharedTags, []string{"tag1", "tag2", "tag3"}) { + t.Fatalf("Preload many2many relations") + } + + var tag4 = &Tag{Locale: "ZH", Value: "tag4"} + DB.Model(&blog2).Association("SharedTags").Append(tag4) + + DB.Model(&blog).Association("SharedTags").Find(&tags) + if !compareTags(tags, []string{"tag1", "tag2", "tag3", "tag4"}) { + t.Fatalf("Should find 3 tags") + } + + DB.Model(&blog2).Association("SharedTags").Find(&tags) + if !compareTags(tags, []string{"tag1", "tag2", "tag3", "tag4"}) { + t.Fatalf("Should find 3 tags") + } + + // Replace + var tag5 = &Tag{Locale: "ZH", Value: "tag5"} + var tag6 = &Tag{Locale: "ZH", Value: "tag6"} + DB.Model(&blog2).Association("SharedTags").Replace(tag5, tag6) + var tags2 []Tag + DB.Model(&blog).Association("SharedTags").Find(&tags2) + if !compareTags(tags2, []string{"tag5", "tag6"}) { + t.Fatalf("Should find 2 tags after Replace") + } + + DB.Model(&blog2).Association("SharedTags").Find(&tags2) + if !compareTags(tags2, []string{"tag5", "tag6"}) { + t.Fatalf("Should find 2 tags after Replace") + } + + if DB.Model(&blog).Association("SharedTags").Count() != 2 { + t.Fatalf("Blog should has three tags after Replace") + } + + // Delete + DB.Model(&blog).Association("SharedTags").Delete(tag5) + var tags3 []Tag + DB.Model(&blog).Association("SharedTags").Find(&tags3) + if !compareTags(tags3, []string{"tag6"}) { + t.Fatalf("Should find 1 tags after Delete") + } + + if 
DB.Model(&blog).Association("SharedTags").Count() != 1 { + t.Fatalf("Blog should has three tags after Delete") + } + + DB.Model(&blog2).Association("SharedTags").Delete(tag3) + var tags4 []Tag + DB.Model(&blog).Association("SharedTags").Find(&tags4) + if !compareTags(tags4, []string{"tag6"}) { + t.Fatalf("Tag should not be deleted when Delete with a unrelated tag") + } + + // Clear + DB.Model(&blog2).Association("SharedTags").Clear() + if DB.Model(&blog).Association("SharedTags").Count() != 0 { + t.Fatalf("All tags should be cleared") + } +} + +func TestManyToManyWithCustomizedForeignKeys2(t *testing.T) { + if name := DB.Dialector.Name(); name == "sqlite" || name == "sqlserver" { + t.Skip("skip sqlite, sqlserver due to it doesn't support multiple primary keys with auto increment") + } + + if name := DB.Dialector.Name(); name == "postgres" { + t.Skip("skip postgres due to it only allow unique constraint matching given keys") + } + + DB.Migrator().DropTable(&Blog{}, &Tag{}, "blog_tags", "locale_blog_tags", "shared_blog_tags") + if err := DB.AutoMigrate(&Blog{}, &Tag{}); err != nil { + t.Fatalf("Failed to auto migrate, got error: %v", err) + } + + blog := Blog{ + Locale: "ZH", + Subject: "subject", + Body: "body", + LocaleTags: []Tag{ + {Locale: "ZH", Value: "tag1"}, + {Locale: "ZH", Value: "tag2"}, + }, + } + DB.Save(&blog) + + blog2 := Blog{ + ID: blog.ID, + Locale: "EN", + } + DB.Create(&blog2) + + // Append + var tag3 = &Tag{Locale: "ZH", Value: "tag3"} + DB.Model(&blog).Association("LocaleTags").Append([]*Tag{tag3}) + if !compareTags(blog.LocaleTags, []string{"tag1", "tag2", "tag3"}) { + t.Fatalf("Blog should has three tags after Append") + } + + if DB.Model(&blog).Association("LocaleTags").Count() != 3 { + t.Fatalf("Blog should has three tags after Append") + } + + if DB.Model(&blog2).Association("LocaleTags").Count() != 0 { + t.Fatalf("EN Blog should has 0 tags after ZH Blog Append") + } + + var tags []Tag + DB.Model(&blog).Association("LocaleTags").Find(&tags) + if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) { + t.Fatalf("Should find 3 tags") + } + + DB.Model(&blog2).Association("LocaleTags").Find(&tags) + if len(tags) != 0 { + t.Fatalf("Should find 0 tags for EN Blog") + } + + var blog1 Blog + DB.Preload("LocaleTags").Find(&blog1, "locale = ? AND id = ?", "ZH", blog.ID) + if !compareTags(blog1.LocaleTags, []string{"tag1", "tag2", "tag3"}) { + t.Fatalf("Preload many2many relations") + } + + var tag4 = &Tag{Locale: "ZH", Value: "tag4"} + DB.Model(&blog2).Association("LocaleTags").Append(tag4) + + DB.Model(&blog).Association("LocaleTags").Find(&tags) + if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) { + t.Fatalf("Should find 3 tags for EN Blog") + } + + DB.Model(&blog2).Association("LocaleTags").Find(&tags) + if !compareTags(tags, []string{"tag4"}) { + t.Fatalf("Should find 1 tags for EN Blog") + } + + // Replace + var tag5 = &Tag{Locale: "ZH", Value: "tag5"} + var tag6 = &Tag{Locale: "ZH", Value: "tag6"} + DB.Model(&blog2).Association("LocaleTags").Replace(tag5, tag6) + + var tags2 []Tag + DB.Model(&blog).Association("LocaleTags").Find(&tags2) + if !compareTags(tags2, []string{"tag1", "tag2", "tag3"}) { + t.Fatalf("CN Blog's tags should not be changed after EN Blog Replace") + } + + var blog11 Blog + DB.Preload("LocaleTags").First(&blog11, "id = ? 
AND locale = ?", blog.ID, blog.Locale) + if !compareTags(blog11.LocaleTags, []string{"tag1", "tag2", "tag3"}) { + t.Fatalf("CN Blog's tags should not be changed after EN Blog Replace") + } + + DB.Model(&blog2).Association("LocaleTags").Find(&tags2) + if !compareTags(tags2, []string{"tag5", "tag6"}) { + t.Fatalf("Should find 2 tags after Replace") + } + + var blog21 Blog + DB.Preload("LocaleTags").First(&blog21, "id = ? AND locale = ?", blog2.ID, blog2.Locale) + if !compareTags(blog21.LocaleTags, []string{"tag5", "tag6"}) { + t.Fatalf("EN Blog's tags should be changed after Replace") + } + + if DB.Model(&blog).Association("LocaleTags").Count() != 3 { + t.Fatalf("ZH Blog should has three tags after Replace") + } + + if DB.Model(&blog2).Association("LocaleTags").Count() != 2 { + t.Fatalf("EN Blog should has two tags after Replace") + } + + // Delete + DB.Model(&blog).Association("LocaleTags").Delete(tag5) + + if DB.Model(&blog).Association("LocaleTags").Count() != 3 { + t.Fatalf("ZH Blog should has three tags after Delete with EN's tag") + } + + if DB.Model(&blog2).Association("LocaleTags").Count() != 2 { + t.Fatalf("EN Blog should has two tags after ZH Blog Delete with EN's tag") + } + + DB.Model(&blog2).Association("LocaleTags").Delete(tag5) + + if DB.Model(&blog).Association("LocaleTags").Count() != 3 { + t.Fatalf("ZH Blog should has three tags after EN Blog Delete with EN's tag") + } + + if DB.Model(&blog2).Association("LocaleTags").Count() != 1 { + t.Fatalf("EN Blog should has 1 tags after EN Blog Delete with EN's tag") + } + + // Clear + DB.Model(&blog2).Association("LocaleTags").Clear() + if DB.Model(&blog).Association("LocaleTags").Count() != 3 { + t.Fatalf("ZH Blog's tags should not be cleared when clear EN Blog's tags") + } + + if DB.Model(&blog2).Association("LocaleTags").Count() != 0 { + t.Fatalf("EN Blog's tags should be cleared when clear EN Blog's tags") + } + + DB.Model(&blog).Association("LocaleTags").Clear() + if DB.Model(&blog).Association("LocaleTags").Count() != 0 { + t.Fatalf("ZH Blog's tags should be cleared when clear ZH Blog's tags") + } + + if DB.Model(&blog2).Association("LocaleTags").Count() != 0 { + t.Fatalf("EN Blog's tags should be cleared") + } +} + +func TestCompositePrimaryKeysAssociations(t *testing.T) { + type Label struct { + BookID *uint `gorm:"primarykey"` + Name string `gorm:"primarykey"` + Value string + } + + type Book struct { + ID int + Name string + Labels []Label + } + + DB.Migrator().DropTable(&Label{}, &Book{}) + if err := DB.AutoMigrate(&Label{}, &Book{}); err != nil { + t.Fatalf("failed to migrate") + } + + book := Book{ + Name: "my book", + Labels: []Label{ + {Name: "region", Value: "emea"}, + }, + } + + DB.Create(&book) + + var result Book + if err := DB.Preload("Labels").First(&result, book.ID).Error; err != nil { + t.Fatalf("failed to preload, got error %v", err) + } + + AssertEqual(t, book, result) +} diff --git a/tests/named_argument_test.go b/tests/named_argument_test.go new file mode 100644 index 00000000..d0a6f915 --- /dev/null +++ b/tests/named_argument_test.go @@ -0,0 +1,69 @@ +package tests_test + +import ( + "database/sql" + "testing" + + "gorm.io/gorm" + . 
"gorm.io/gorm/utils/tests" +) + +func TestNamedArg(t *testing.T) { + type NamedUser struct { + gorm.Model + Name1 string + Name2 string + Name3 string + } + + DB.Migrator().DropTable(&NamedUser{}) + DB.AutoMigrate(&NamedUser{}) + + namedUser := NamedUser{Name1: "jinzhu1", Name2: "jinzhu2", Name3: "jinzhu3"} + DB.Create(&namedUser) + + var result NamedUser + DB.First(&result, "name1 = @name OR name2 = @name OR name3 = @name", sql.Named("name", "jinzhu2")) + + AssertEqual(t, result, namedUser) + + var result2 NamedUser + DB.Where("name1 = @name OR name2 = @name OR name3 = @name", sql.Named("name", "jinzhu2")).First(&result2) + + AssertEqual(t, result2, namedUser) + + var result3 NamedUser + DB.Where("name1 = @name OR name2 = @name OR name3 = @name", map[string]interface{}{"name": "jinzhu2"}).First(&result3) + + AssertEqual(t, result3, namedUser) + + var result4 NamedUser + if err := DB.Raw("SELECT * FROM named_users WHERE name1 = @name OR name2 = @name2 OR name3 = @name", sql.Named("name", "jinzhu-none"), sql.Named("name2", "jinzhu2")).Find(&result4).Error; err != nil { + t.Errorf("failed to update with named arg") + } + + AssertEqual(t, result4, namedUser) + + if err := DB.Exec("UPDATE named_users SET name1 = @name, name2 = @name2, name3 = @name", sql.Named("name", "jinzhu-new"), sql.Named("name2", "jinzhu-new2")).Error; err != nil { + t.Errorf("failed to update with named arg") + } + + namedUser.Name1 = "jinzhu-new" + namedUser.Name2 = "jinzhu-new2" + namedUser.Name3 = "jinzhu-new" + + var result5 NamedUser + if err := DB.Raw("SELECT * FROM named_users WHERE (name1 = @name AND name3 = @name) AND name2 = @name2", map[string]interface{}{"name": "jinzhu-new", "name2": "jinzhu-new2"}).Find(&result5).Error; err != nil { + t.Errorf("failed to update with named arg") + } + + AssertEqual(t, result5, namedUser) + + var result6 NamedUser + if err := DB.Raw(`SELECT * FROM named_users WHERE (name1 = @name + AND name3 = @name) AND name2 = @name2`, map[string]interface{}{"name": "jinzhu-new", "name2": "jinzhu-new2"}).Find(&result6).Error; err != nil { + t.Errorf("failed to update with named arg") + } + + AssertEqual(t, result6, namedUser) +} diff --git a/tests/named_polymorphic_test.go b/tests/named_polymorphic_test.go new file mode 100644 index 00000000..956f3a7e --- /dev/null +++ b/tests/named_polymorphic_test.go @@ -0,0 +1,147 @@ +package tests_test + +import ( + "testing" + + . "gorm.io/gorm/utils/tests" +) + +type Hamster struct { + Id int + Name string + PreferredToy Toy `gorm:"polymorphic:Owner;polymorphicValue:hamster_preferred"` + OtherToy Toy `gorm:"polymorphic:Owner;polymorphicValue:hamster_other"` +} + +func TestNamedPolymorphic(t *testing.T) { + DB.Migrator().DropTable(&Hamster{}) + DB.AutoMigrate(&Hamster{}) + + hamster := Hamster{Name: "Mr. 
Hammond", PreferredToy: Toy{Name: "bike"}, OtherToy: Toy{Name: "treadmill"}} + DB.Save(&hamster) + + hamster2 := Hamster{} + DB.Preload("PreferredToy").Preload("OtherToy").Find(&hamster2, hamster.Id) + + if hamster2.PreferredToy.ID != hamster.PreferredToy.ID || hamster2.PreferredToy.Name != hamster.PreferredToy.Name { + t.Errorf("Hamster's preferred toy failed to preload") + } + + if hamster2.OtherToy.ID != hamster.OtherToy.ID || hamster2.OtherToy.Name != hamster.OtherToy.Name { + t.Errorf("Hamster's other toy failed to preload") + } + + // clear to omit Toy.ID in count + hamster2.PreferredToy = Toy{} + hamster2.OtherToy = Toy{} + + if DB.Model(&hamster2).Association("PreferredToy").Count() != 1 { + t.Errorf("Hamster's preferred toy count should be 1") + } + + if DB.Model(&hamster2).Association("OtherToy").Count() != 1 { + t.Errorf("Hamster's other toy count should be 1") + } + + // Query + hamsterToy := Toy{} + DB.Model(&hamster).Association("PreferredToy").Find(&hamsterToy) + if hamsterToy.Name != hamster.PreferredToy.Name { + t.Errorf("Should find has one polymorphic association") + } + + hamsterToy = Toy{} + DB.Model(&hamster).Association("OtherToy").Find(&hamsterToy) + if hamsterToy.Name != hamster.OtherToy.Name { + t.Errorf("Should find has one polymorphic association") + } + + // Append + DB.Model(&hamster).Association("PreferredToy").Append(&Toy{ + Name: "bike 2", + }) + + DB.Model(&hamster).Association("OtherToy").Append(&Toy{ + Name: "treadmill 2", + }) + + hamsterToy = Toy{} + DB.Model(&hamster).Association("PreferredToy").Find(&hamsterToy) + if hamsterToy.Name != "bike 2" { + t.Errorf("Should update has one polymorphic association with Append") + } + + hamsterToy = Toy{} + DB.Model(&hamster).Association("OtherToy").Find(&hamsterToy) + if hamsterToy.Name != "treadmill 2" { + t.Errorf("Should update has one polymorphic association with Append") + } + + if DB.Model(&hamster2).Association("PreferredToy").Count() != 1 { + t.Errorf("Hamster's toys count should be 1 after Append") + } + + if DB.Model(&hamster2).Association("OtherToy").Count() != 1 { + t.Errorf("Hamster's toys count should be 1 after Append") + } + + // Replace + DB.Model(&hamster).Association("PreferredToy").Replace(&Toy{ + Name: "bike 3", + }) + + DB.Model(&hamster).Association("OtherToy").Replace(&Toy{ + Name: "treadmill 3", + }) + + hamsterToy = Toy{} + DB.Model(&hamster).Association("PreferredToy").Find(&hamsterToy) + if hamsterToy.Name != "bike 3" { + t.Errorf("Should update has one polymorphic association with Replace") + } + + hamsterToy = Toy{} + DB.Model(&hamster).Association("OtherToy").Find(&hamsterToy) + if hamsterToy.Name != "treadmill 3" { + t.Errorf("Should update has one polymorphic association with Replace") + } + + if DB.Model(&hamster2).Association("PreferredToy").Count() != 1 { + t.Errorf("hamster's toys count should be 1 after Replace") + } + + if DB.Model(&hamster2).Association("OtherToy").Count() != 1 { + t.Errorf("hamster's toys count should be 1 after Replace") + } + + // Clear + DB.Model(&hamster).Association("PreferredToy").Append(&Toy{ + Name: "bike 2", + }) + DB.Model(&hamster).Association("OtherToy").Append(&Toy{ + Name: "treadmill 2", + }) + + if DB.Model(&hamster).Association("PreferredToy").Count() != 1 { + t.Errorf("Hamster's toys should be added with Append") + } + + if DB.Model(&hamster).Association("OtherToy").Count() != 1 { + t.Errorf("Hamster's toys should be added with Append") + } + + DB.Model(&hamster).Association("PreferredToy").Clear() + + if 
DB.Model(&hamster2).Association("PreferredToy").Count() != 0 { + t.Errorf("Hamster's preferred toy should be cleared with Clear") + } + + if DB.Model(&hamster2).Association("OtherToy").Count() != 1 { + t.Errorf("Hamster's other toy should be still available") + } + + DB.Model(&hamster).Association("OtherToy").Clear() + if DB.Model(&hamster).Association("OtherToy").Count() != 0 { + t.Errorf("Hamster's other toy should be cleared with Clear") + } +} diff --git a/tests/non_std_test.go b/tests/non_std_test.go new file mode 100644 index 00000000..d3561b11 --- /dev/null +++ b/tests/non_std_test.go @@ -0,0 +1,61 @@ +package tests_test + +import ( + "testing" + "time" +) + +type Animal struct { + Counter uint64 `gorm:"primary_key:yes"` + Name string `gorm:"DEFAULT:'galeone'"` + From string //test reserved sql keyword as field name + Age *time.Time + unexported string // unexported value + CreatedAt time.Time + UpdatedAt time.Time +} + +func TestNonStdPrimaryKeyAndDefaultValues(t *testing.T) { + DB.Migrator().DropTable(&Animal{}) + if err := DB.AutoMigrate(&Animal{}); err != nil { + t.Fatalf("no error should happen when migrate but got %v", err) + } + + animal := Animal{Name: "Ferdinand"} + DB.Save(&animal) + updatedAt1 := animal.UpdatedAt + + DB.Save(&animal).Update("name", "Francis") + if updatedAt1.Format(time.RFC3339Nano) == animal.UpdatedAt.Format(time.RFC3339Nano) { + t.Errorf("UpdatedAt should be updated") + } + + var animals []Animal + DB.Find(&animals) + if count := DB.Model(Animal{}).Where("1=1").Update("CreatedAt", time.Now().Add(2*time.Hour)).RowsAffected; count != int64(len(animals)) { + t.Error("RowsAffected should be correct when do batch update") + } + + animal = Animal{From: "somewhere"} // No name fields, should be filled with the default value (galeone) + DB.Save(&animal).Update("From", "a nice place") // The name field shoul be untouched + DB.First(&animal, animal.Counter) + if animal.Name != "galeone" { + t.Errorf("Name fields shouldn't be changed if untouched, but got %v", animal.Name) + } + + // When changing a field with a default value, the change must occur + animal.Name = "amazing horse" + DB.Save(&animal) + DB.First(&animal, animal.Counter) + if animal.Name != "amazing horse" { + t.Errorf("Update a filed with a default value should occur. But got %v\n", animal.Name) + } + + // When changing a field with a default value with blank value + animal.Name = "" + DB.Save(&animal) + DB.First(&animal, animal.Counter) + if animal.Name != "" { + t.Errorf("Update a filed to blank with a default value should occur. 
But got %v\n", animal.Name) + } +} diff --git a/tests/postgres_test.go b/tests/postgres_test.go new file mode 100644 index 00000000..94077d1d --- /dev/null +++ b/tests/postgres_test.go @@ -0,0 +1,84 @@ +package tests_test + +import ( + "testing" + + "github.com/google/uuid" + "github.com/lib/pq" + "gorm.io/gorm" +) + +func TestPostgres(t *testing.T) { + if DB.Dialector.Name() != "postgres" { + t.Skip() + } + + type Harumph struct { + gorm.Model + Name string `gorm:"check:name_checker,name <> ''"` + Test uuid.UUID `gorm:"type:uuid;not null;default:gen_random_uuid()"` + Things pq.StringArray `gorm:"type:text[]"` + } + + if err := DB.Exec("CREATE EXTENSION IF NOT EXISTS pgcrypto;").Error; err != nil { + t.Errorf("Failed to create extension pgcrypto, got error %v", err) + } + + DB.Migrator().DropTable(&Harumph{}) + + if err := DB.AutoMigrate(&Harumph{}); err != nil { + t.Fatalf("Failed to migrate for uuid default value, got error: %v", err) + } + + harumph := Harumph{} + if err := DB.Create(&harumph).Error; err == nil { + t.Fatalf("should failed to create data, name can't be blank") + } + + harumph = Harumph{Name: "jinzhu"} + if err := DB.Create(&harumph).Error; err != nil { + t.Fatalf("should be able to create data, but got %v", err) + } + + var result Harumph + if err := DB.First(&result, "id = ?", harumph.ID).Error; err != nil || harumph.Name != "jinzhu" { + t.Errorf("No error should happen, but got %v", err) + } +} + +type Post struct { + ID uuid.UUID `gorm:"primary_key;type:uuid;default:uuid_generate_v4();autoincrement"` + Title string + Categories []*Category `gorm:"Many2Many:post_categories"` +} + +type Category struct { + ID uuid.UUID `gorm:"primary_key;type:uuid;default:uuid_generate_v4();autoincrement"` + Title string + Posts []*Post `gorm:"Many2Many:post_categories"` +} + +func TestMany2ManyWithDefaultValueUUID(t *testing.T) { + if DB.Dialector.Name() != "postgres" { + t.Skip() + } + + if err := DB.Exec(`create extension if not exists "uuid-ossp"`).Error; err != nil { + t.Fatalf("Failed to create 'uuid-ossp' extension, but got error %v", err) + } + + DB.Migrator().DropTable(&Post{}, &Category{}, "post_categories") + DB.AutoMigrate(&Post{}, &Category{}) + + post := Post{ + Title: "Hello World", + Categories: []*Category{ + {Title: "Coding"}, + {Title: "Golang"}, + }, + } + + if err := DB.Create(&post).Error; err != nil { + t.Errorf("Failed, got error: %v", err) + } +} diff --git a/preload_test.go b/tests/preload_suits_test.go similarity index 76% rename from preload_test.go rename to tests/preload_suits_test.go index 1b89e77b..0ef8890b 100644 --- a/preload_test.go +++ b/tests/preload_suits_test.go @@ -1,126 +1,19 @@ -package gorm_test +package tests_test import ( "database/sql" "encoding/json" - "os" "reflect" + "sort" + "sync/atomic" "testing" - "github.com/jinzhu/gorm" + "gorm.io/gorm" ) -func getPreloadUser(name string) *User { - return getPreparedUser(name, "Preload") -} - -func checkUserHasPreloadData(user User, t *testing.T) { - u := getPreloadUser(user.Name) - if user.BillingAddress.Address1 != u.BillingAddress.Address1 { - t.Error("Failed to preload user's BillingAddress") - } - - if user.ShippingAddress.Address1 != u.ShippingAddress.Address1 { - t.Error("Failed to preload user's ShippingAddress") - } - - if user.CreditCard.Number != u.CreditCard.Number { - t.Error("Failed to preload user's CreditCard") - } - - if user.Company.Name != u.Company.Name { - t.Error("Failed to preload user's Company") - } - - if len(user.Emails) != len(u.Emails) { - t.Error("Failed to preload 
user's Emails") - } else { - var found int - for _, e1 := range u.Emails { - for _, e2 := range user.Emails { - if e1.Email == e2.Email { - found++ - break - } - } - } - if found != len(u.Emails) { - t.Error("Failed to preload user's email details") - } - } -} - -func TestPreload(t *testing.T) { - user1 := getPreloadUser("user1") - DB.Save(user1) - - preloadDB := DB.Where("role = ?", "Preload").Preload("BillingAddress").Preload("ShippingAddress"). - Preload("CreditCard").Preload("Emails").Preload("Company") - var user User - preloadDB.Find(&user) - checkUserHasPreloadData(user, t) - - user2 := getPreloadUser("user2") - DB.Save(user2) - - user3 := getPreloadUser("user3") - DB.Save(user3) - - var users []User - preloadDB.Find(&users) - - for _, user := range users { - checkUserHasPreloadData(user, t) - } - - var users2 []*User - preloadDB.Find(&users2) - - for _, user := range users2 { - checkUserHasPreloadData(*user, t) - } - - var users3 []*User - preloadDB.Preload("Emails", "email = ?", user3.Emails[0].Email).Find(&users3) - - for _, user := range users3 { - if user.Name == user3.Name { - if len(user.Emails) != 1 { - t.Errorf("should only preload one emails for user3 when with condition") - } - } else if len(user.Emails) != 0 { - t.Errorf("should not preload any emails for other users when with condition") - } else if user.Emails == nil { - t.Errorf("should return an empty slice to indicate zero results") - } - } -} - -func TestAutoPreload(t *testing.T) { - user1 := getPreloadUser("auto_user1") - DB.Save(user1) - - preloadDB := DB.Set("gorm:auto_preload", true).Where("role = ?", "Preload") - var user User - preloadDB.Find(&user) - checkUserHasPreloadData(user, t) - - user2 := getPreloadUser("auto_user2") - DB.Save(user2) - - var users []User - preloadDB.Find(&users) - - for _, user := range users { - checkUserHasPreloadData(user, t) - } - - var users2 []*User - preloadDB.Find(&users2) - - for _, user := range users2 { - checkUserHasPreloadData(*user, t) - } +func toJSONString(v interface{}) []byte { + r, _ := json.Marshal(v) + return r } func TestNestedPreload1(t *testing.T) { @@ -141,10 +34,8 @@ func TestNestedPreload1(t *testing.T) { Level2 Level2 } ) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + DB.Migrator().DropTable(&Level3{}, &Level2{}, &Level1{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}); err != nil { t.Error(err) } @@ -162,7 +53,7 @@ func TestNestedPreload1(t *testing.T) { t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) } - if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got, "name = ?", "not_found").Error; err != gorm.ErrRecordNotFound { + if err := DB.Preload("Level2").Preload("Level2.Level1").First(&got, "name = ?", "not_found").Error; err != gorm.ErrRecordNotFound { t.Error(err) } } @@ -185,10 +76,8 @@ func TestNestedPreload2(t *testing.T) { Level2s []Level2 } ) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + DB.Migrator().DropTable(&Level3{}, &Level2{}, &Level1{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}); err != nil { t.Error(err) } @@ -239,10 +128,8 @@ func TestNestedPreload3(t *testing.T) { Level2s []Level2 } ) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - if err := 
DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + DB.Migrator().DropTable(&Level3{}, &Level2{}, &Level1{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}); err != nil { t.Error(err) } @@ -284,10 +171,8 @@ func TestNestedPreload4(t *testing.T) { Level2 Level2 } ) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + DB.Migrator().DropTable(&Level3{}, &Level2{}, &Level1{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}); err != nil { t.Error(err) } @@ -332,10 +217,8 @@ func TestNestedPreload5(t *testing.T) { Level2 Level2 } ) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + DB.Migrator().DropTable(&Level3{}, &Level2{}, &Level1{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}); err != nil { t.Error(err) } @@ -377,10 +260,8 @@ func TestNestedPreload6(t *testing.T) { Level2s []Level2 } ) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + DB.Migrator().DropTable(&Level3{}, &Level2{}, &Level1{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}); err != nil { t.Error(err) } @@ -451,10 +332,8 @@ func TestNestedPreload7(t *testing.T) { Level2s []Level2 } ) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + DB.Migrator().DropTable(&Level3{}, &Level2{}, &Level1{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}); err != nil { t.Error(err) } @@ -507,10 +386,8 @@ func TestNestedPreload8(t *testing.T) { Level2 Level2 } ) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + DB.Migrator().DropTable(&Level3{}, &Level2{}, &Level1{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}); err != nil { t.Error(err) } @@ -558,9 +435,9 @@ func TestNestedPreload9(t *testing.T) { Level1 struct { ID uint Value string - Level2ID uint - Level2_1ID uint - Level0s []Level0 + Level2ID *uint + Level2_1ID *uint + Level0s []Level0 `json:",omitempty"` } Level2 struct { ID uint @@ -569,7 +446,7 @@ func TestNestedPreload9(t *testing.T) { } Level2_1 struct { ID uint - Level1s []Level1 + Level1s []Level1 `json:",omitempty"` Level3ID uint } Level3 struct { @@ -579,12 +456,8 @@ func TestNestedPreload9(t *testing.T) { Level2_1 Level2_1 } ) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level2_1{}) - DB.DropTableIfExists(&Level1{}) - DB.DropTableIfExists(&Level0{}) - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}, &Level2_1{}, &Level0{}).Error; err != nil { + DB.Migrator().DropTable(&Level3{}, &Level2{}, &Level1{}, &Level2_1{}, &Level0{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}, &Level2_1{}, &Level0{}); err != nil { t.Error(err) } @@ -641,7 +514,7 @@ func TestNestedPreload9(t *testing.T) { t.Error(err) } - if !reflect.DeepEqual(got, want) { + if string(toJSONString(got)) != string(toJSONString(want)) { t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) } } @@ -654,7 +527,7 @@ type LevelA1 struct { type LevelA2 
struct { ID uint Value string - LevelA3s []*LevelA3 + LevelA3s []*LevelA3 `json:",omitempty"` } type LevelA3 struct { @@ -667,11 +540,8 @@ type LevelA3 struct { } func TestNestedPreload10(t *testing.T) { - DB.DropTableIfExists(&LevelA3{}) - DB.DropTableIfExists(&LevelA2{}) - DB.DropTableIfExists(&LevelA1{}) - - if err := DB.AutoMigrate(&LevelA1{}, &LevelA2{}, &LevelA3{}).Error; err != nil { + DB.Migrator().DropTable(&LevelA3{}, &LevelA2{}, &LevelA1{}) + if err := DB.AutoMigrate(&LevelA1{}, &LevelA2{}, &LevelA3{}); err != nil { t.Error(err) } @@ -706,7 +576,7 @@ func TestNestedPreload10(t *testing.T) { t.Error(err) } - if !reflect.DeepEqual(got, want) { + if !reflect.DeepEqual(toJSONString(got), toJSONString(want)) { t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) } } @@ -727,14 +597,12 @@ type LevelB3 struct { Value string LevelB1ID sql.NullInt64 LevelB1 *LevelB1 - LevelB2s []*LevelB2 `gorm:"many2many:levelb1_levelb3_levelb2s"` + LevelB2s []*LevelB2 `gorm:"many2many:levelb1_levelb3_levelb2s" json:",omitempty"` } func TestNestedPreload11(t *testing.T) { - DB.DropTableIfExists(&LevelB2{}) - DB.DropTableIfExists(&LevelB3{}) - DB.DropTableIfExists(&LevelB1{}) - if err := DB.AutoMigrate(&LevelB1{}, &LevelB2{}, &LevelB3{}).Error; err != nil { + DB.Migrator().DropTable(&LevelB3{}, &LevelB2{}, &LevelB1{}) + if err := DB.AutoMigrate(&LevelB1{}, &LevelB2{}, &LevelB3{}); err != nil { t.Error(err) } @@ -746,6 +614,7 @@ func TestNestedPreload11(t *testing.T) { levelB3 := &LevelB3{ Value: "bar", LevelB1ID: sql.NullInt64{Valid: true, Int64: int64(levelB1.ID)}, + LevelB2s: []*LevelB2{}, } if err := DB.Create(levelB3).Error; err != nil { t.Error(err) @@ -758,7 +627,7 @@ func TestNestedPreload11(t *testing.T) { t.Error(err) } - if !reflect.DeepEqual(got, want) { + if !reflect.DeepEqual(toJSONString(got), toJSONString(want)) { t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) } } @@ -783,10 +652,8 @@ type LevelC3 struct { } func TestNestedPreload12(t *testing.T) { - DB.DropTableIfExists(&LevelC2{}) - DB.DropTableIfExists(&LevelC3{}) - DB.DropTableIfExists(&LevelC1{}) - if err := DB.AutoMigrate(&LevelC1{}, &LevelC2{}, &LevelC3{}).Error; err != nil { + DB.Migrator().DropTable(&LevelC3{}, &LevelC2{}, &LevelC1{}) + if err := DB.AutoMigrate(&LevelC1{}, &LevelC2{}, &LevelC3{}); err != nil { t.Error(err) } @@ -825,8 +692,8 @@ func TestNestedPreload12(t *testing.T) { } func TestManyToManyPreloadWithMultiPrimaryKeys(t *testing.T) { - if dialect := os.Getenv("GORM_DIALECT"); dialect == "" || dialect == "sqlite" || dialect == "mssql" { - return + if name := DB.Dialector.Name(); name == "sqlite" || name == "sqlserver" { + t.Skip("skip sqlite, sqlserver due to it doesn't support multiple primary keys with auto increment") } type ( @@ -843,11 +710,10 @@ func TestManyToManyPreloadWithMultiPrimaryKeys(t *testing.T) { } ) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - DB.DropTableIfExists("levels") + DB.Migrator().DropTable(&Level2{}, &Level1{}) + DB.Migrator().DropTable("levels") - if err := DB.AutoMigrate(&Level2{}, &Level1{}).Error; err != nil { + if err := DB.AutoMigrate(&Level2{}, &Level1{}); err != nil { t.Error(err) } @@ -934,12 +800,10 @@ func TestManyToManyPreloadForNestedPointer(t *testing.T) { } ) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - DB.DropTableIfExists("levels") + DB.Migrator().DropTable(&Level3{}, &Level2{}, &Level1{}) + DB.Migrator().DropTable("levels") - if err := 
DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}); err != nil { t.Error(err) } @@ -1036,13 +900,9 @@ func TestNestedManyToManyPreload(t *testing.T) { } ) - DB.DropTableIfExists(&Level1{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists("level1_level2") - DB.DropTableIfExists("level2_level3") + DB.Migrator().DropTable(&Level3{}, &Level2{}, &Level1{}, "level1_level2", "level2_level3") - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}); err != nil { t.Error(err) } @@ -1078,7 +938,7 @@ func TestNestedManyToManyPreload(t *testing.T) { t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) } - if err := DB.Preload("Level2s.Level1s").Find(&got, "value = ?", "not_found").Error; err != gorm.ErrRecordNotFound { + if err := DB.Preload("Level2s.Level1s").First(&got, "value = ?", "not_found").Error; err != gorm.ErrRecordNotFound { t.Error(err) } } @@ -1102,12 +962,10 @@ func TestNestedManyToManyPreload2(t *testing.T) { } ) - DB.DropTableIfExists(&Level1{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists("level1_level2") + DB.Migrator().DropTable(&Level3{}, &Level2{}, &Level1{}) + DB.Migrator().DropTable("level1_level2") - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}); err != nil { t.Error(err) } @@ -1135,7 +993,7 @@ func TestNestedManyToManyPreload2(t *testing.T) { t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) } - if err := DB.Preload("Level2.Level1s").Find(&got, "value = ?", "not_found").Error; err != gorm.ErrRecordNotFound { + if err := DB.Preload("Level2.Level1s").First(&got, "value = ?", "not_found").Error; err != gorm.ErrRecordNotFound { t.Error(err) } } @@ -1159,12 +1017,9 @@ func TestNestedManyToManyPreload3(t *testing.T) { } ) - DB.DropTableIfExists(&Level1{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists("level1_level2") + DB.Migrator().DropTable(&Level3{}, &Level2{}, &Level1{}, "level1_level2") - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}); err != nil { t.Error(err) } @@ -1198,7 +1053,7 @@ func TestNestedManyToManyPreload3(t *testing.T) { } for _, want := range wants { - if err := DB.Save(&want).Error; err != nil { + if err := DB.Save(want).Error; err != nil { t.Error(err) } } @@ -1234,12 +1089,10 @@ func TestNestedManyToManyPreload3ForStruct(t *testing.T) { } ) - DB.DropTableIfExists(&Level1{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists("level1_level2") + DB.Migrator().DropTable(&Level3{}, &Level2{}, &Level1{}) + DB.Migrator().DropTable("level1_level2") - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}); err != nil { t.Error(err) } @@ -1273,7 +1126,7 @@ func TestNestedManyToManyPreload3ForStruct(t *testing.T) { } for _, want := range wants { - if err := DB.Save(&want).Error; err != nil { + if err := DB.Save(want).Error; err != nil { t.Error(err) } } @@ -1314,12 +1167,8 @@ func TestNestedManyToManyPreload4(t *testing.T) { } ) - DB.DropTableIfExists(&Level1{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level3{}) - 
DB.DropTableIfExists(&Level4{}) - DB.DropTableIfExists("level1_level2") - DB.DropTableIfExists("level2_level3") + DB.Migrator().DropTable("level1_level2", "level2_level3") + DB.Migrator().DropTable(&Level3{}, &Level2{}, &Level1{}, &Level4{}) dummy := Level1{ Value: "Level1", @@ -1334,7 +1183,7 @@ func TestNestedManyToManyPreload4(t *testing.T) { }}, } - if err := DB.AutoMigrate(&Level4{}, &Level3{}, &Level2{}, &Level1{}).Error; err != nil { + if err := DB.AutoMigrate(&Level4{}, &Level3{}, &Level2{}, &Level1{}); err != nil { t.Error(err) } @@ -1361,11 +1210,9 @@ func TestManyToManyPreloadForPointer(t *testing.T) { } ) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - DB.DropTableIfExists("levels") + DB.Migrator().DropTable("levels", &Level2{}, &Level1{}) - if err := DB.AutoMigrate(&Level2{}, &Level1{}).Error; err != nil { + if err := DB.AutoMigrate(&Level2{}, &Level1{}); err != nil { t.Error(err) } @@ -1447,16 +1294,13 @@ func TestNilPointerSlice(t *testing.T) { Level1 struct { ID uint Value string - Level2ID uint + Level2ID *uint Level2 *Level2 } ) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + DB.Migrator().DropTable(&Level3{}, &Level2{}, &Level1{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}); err != nil { t.Error(err) } @@ -1478,7 +1322,7 @@ func TestNilPointerSlice(t *testing.T) { Level2: nil, } if err := DB.Save(&want2).Error; err != nil { - t.Error(err) + t.Fatalf("Got error %v", err) } var got []Level1 @@ -1520,12 +1364,9 @@ func TestNilPointerSlice2(t *testing.T) { } ) - DB.DropTableIfExists(new(Level4)) - DB.DropTableIfExists(new(Level3)) - DB.DropTableIfExists(new(Level2)) - DB.DropTableIfExists(new(Level1)) + DB.Migrator().DropTable(&Level3{}, &Level2{}, &Level1{}, &Level4{}) - if err := DB.AutoMigrate(new(Level4), new(Level3), new(Level2), new(Level1)).Error; err != nil { + if err := DB.AutoMigrate(new(Level4), new(Level3), new(Level2), new(Level1)); err != nil { t.Error(err) } @@ -1555,7 +1396,7 @@ func TestPrefixedPreloadDuplication(t *testing.T) { Level3 struct { ID uint Name string - Level4s []*Level4 + Level4s []*Level4 `json:",omitempty"` } Level2 struct { ID uint @@ -1571,12 +1412,9 @@ func TestPrefixedPreloadDuplication(t *testing.T) { } ) - DB.DropTableIfExists(new(Level3)) - DB.DropTableIfExists(new(Level4)) - DB.DropTableIfExists(new(Level2)) - DB.DropTableIfExists(new(Level1)) + DB.Migrator().DropTable(&Level3{}, &Level2{}, &Level1{}, &Level4{}) - if err := DB.AutoMigrate(new(Level3), new(Level4), new(Level2), new(Level1)).Error; err != nil { + if err := DB.AutoMigrate(new(Level3), new(Level4), new(Level2), new(Level1)); err != nil { t.Error(err) } @@ -1622,12 +1460,52 @@ func TestPrefixedPreloadDuplication(t *testing.T) { t.Error(err) } + for _, level1 := range append(got, want...) 
{ + sort.Slice(level1.Level2.Level3.Level4s, func(i, j int) bool { + return level1.Level2.Level3.Level4s[i].ID > level1.Level2.Level3.Level4s[j].ID + }) + } + if !reflect.DeepEqual(got, want) { t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) } } -func toJSONString(v interface{}) []byte { - r, _ := json.MarshalIndent(v, "", " ") - return r +func TestPreloadManyToManyCallbacks(t *testing.T) { + type ( + Level2 struct { + ID uint + Name string + } + Level1 struct { + ID uint + Name string + Level2s []Level2 `gorm:"many2many:level1_level2s"` + } + ) + + DB.Migrator().DropTable("level1_level2s", &Level2{}, &Level1{}) + + if err := DB.AutoMigrate(new(Level1), new(Level2)); err != nil { + t.Error(err) + } + + lvl := Level1{ + Name: "l1", + Level2s: []Level2{ + {Name: "l2-1"}, {Name: "l2-2"}, + }, + } + DB.Save(&lvl) + + var called int64 + DB.Callback().Query().After("gorm:query").Register("TestPreloadManyToManyCallbacks", func(_ *gorm.DB) { + atomic.AddInt64(&called, 1) + }) + + DB.Preload("Level2s").First(&Level1{}, "id = ?", lvl.ID) + + if called != 3 { + t.Errorf("Wanted callback to be called 3 times but got %d", called) + } } diff --git a/tests/preload_test.go b/tests/preload_test.go new file mode 100644 index 00000000..8f49955e --- /dev/null +++ b/tests/preload_test.go @@ -0,0 +1,240 @@ +package tests_test + +import ( + "encoding/json" + "regexp" + "sort" + "strconv" + "sync" + "testing" + + "gorm.io/gorm" + "gorm.io/gorm/clause" + . "gorm.io/gorm/utils/tests" +) + +func TestPreloadWithAssociations(t *testing.T) { + var user = *GetUser("preload_with_associations", Config{ + Account: true, + Pets: 2, + Toys: 3, + Company: true, + Manager: true, + Team: 4, + Languages: 3, + Friends: 1, + }) + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + CheckUser(t, user, user) + + var user2 User + DB.Preload(clause.Associations).Find(&user2, "id = ?", user.ID) + CheckUser(t, user2, user) + + var user3 = *GetUser("preload_with_associations_new", Config{ + Account: true, + Pets: 2, + Toys: 3, + Company: true, + Manager: true, + Team: 4, + Languages: 3, + Friends: 1, + }) + + DB.Preload(clause.Associations).Find(&user3, "id = ?", user.ID) + CheckUser(t, user3, user) +} + +func TestNestedPreload(t *testing.T) { + var user = *GetUser("nested_preload", Config{Pets: 2}) + + for idx, pet := range user.Pets { + pet.Toy = Toy{Name: "toy_nested_preload_" + strconv.Itoa(idx+1)} + } + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + var user2 User + DB.Preload("Pets.Toy").Find(&user2, "id = ?", user.ID) + CheckUser(t, user2, user) + + var user3 User + DB.Preload(clause.Associations+"."+clause.Associations).Find(&user3, "id = ?", user.ID) + CheckUser(t, user3, user) + + var user4 *User + DB.Preload("Pets.Toy").Find(&user4, "id = ?", user.ID) + CheckUser(t, *user4, user) +} + +func TestNestedPreloadForSlice(t *testing.T) { + var users = []User{ + *GetUser("slice_nested_preload_1", Config{Pets: 2}), + *GetUser("slice_nested_preload_2", Config{Pets: 0}), + *GetUser("slice_nested_preload_3", Config{Pets: 3}), + } + + for _, user := range users { + for idx, pet := range user.Pets { + pet.Toy = Toy{Name: user.Name + "_toy_nested_preload_" + strconv.Itoa(idx+1)} + } + } + + if err := DB.Create(&users).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + var userIDs []uint + for _, user := range users { + userIDs = append(userIDs, user.ID) + } + + var users2 []User + 
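+ // One Find with "id IN ?" loads all created users; the nested Preload then resolves Pets and each pet's Toy in separate batched queries.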
DB.Preload("Pets.Toy").Find(&users2, "id IN ?", userIDs) + + for idx, user := range users2 { + CheckUser(t, user, users[idx]) + } +} + +func TestPreloadWithConds(t *testing.T) { + var users = []User{ + *GetUser("slice_nested_preload_1", Config{Account: true}), + *GetUser("slice_nested_preload_2", Config{Account: false}), + *GetUser("slice_nested_preload_3", Config{Account: true}), + } + + if err := DB.Create(&users).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + var userIDs []uint + for _, user := range users { + userIDs = append(userIDs, user.ID) + } + + var users2 []User + DB.Preload("Account", clause.Eq{Column: "number", Value: users[0].Account.Number}).Find(&users2, "id IN ?", userIDs) + sort.Slice(users2, func(i, j int) bool { + return users2[i].ID < users2[j].ID + }) + + for idx, user := range users2[1:2] { + if user.Account.Number != "" { + t.Errorf("No account should found for user %v but got %v", idx+2, user.Account.Number) + } + } + + CheckUser(t, users2[0], users[0]) + + var users3 []User + if err := DB.Preload("Account", func(tx *gorm.DB) *gorm.DB { + return tx.Table("accounts AS a").Select("a.*") + }).Find(&users3, "id IN ?", userIDs).Error; err != nil { + t.Errorf("failed to query, got error %v", err) + } + sort.Slice(users3, func(i, j int) bool { + return users2[i].ID < users2[j].ID + }) + + for i, u := range users3 { + CheckUser(t, u, users[i]) + } +} + +func TestNestedPreloadWithConds(t *testing.T) { + var users = []User{ + *GetUser("slice_nested_preload_1", Config{Pets: 2}), + *GetUser("slice_nested_preload_2", Config{Pets: 0}), + *GetUser("slice_nested_preload_3", Config{Pets: 3}), + } + + for _, user := range users { + for idx, pet := range user.Pets { + pet.Toy = Toy{Name: user.Name + "_toy_nested_preload_" + strconv.Itoa(idx+1)} + } + } + + if err := DB.Create(&users).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + var userIDs []uint + for _, user := range users { + userIDs = append(userIDs, user.ID) + } + + var users2 []User + DB.Preload("Pets.Toy", "name like ?", `%preload_3`).Find(&users2, "id IN ?", userIDs) + + for idx, user := range users2[0:2] { + for _, pet := range user.Pets { + if pet.Toy.Name != "" { + t.Errorf("No toy should for user %v's pet %v but got %v", idx+1, pet.Name, pet.Toy.Name) + } + } + } + + if len(users2[2].Pets) != 3 { + t.Errorf("Invalid pet toys found for user 3 got %v", len(users2[2].Pets)) + } else { + sort.Slice(users2[2].Pets, func(i, j int) bool { + return users2[2].Pets[i].ID < users2[2].Pets[j].ID + }) + + for _, pet := range users2[2].Pets[0:2] { + if pet.Toy.Name != "" { + t.Errorf("No toy should for user %v's pet %v but got %v", 3, pet.Name, pet.Toy.Name) + } + } + + CheckPet(t, *users2[2].Pets[2], *users[2].Pets[2]) + } +} + +func TestPreloadEmptyData(t *testing.T) { + var user = *GetUser("user_without_associations", Config{}) + DB.Create(&user) + + DB.Preload("Team").Preload("Languages").Preload("Friends").First(&user, "name = ?", user.Name) + + if r, err := json.Marshal(&user); err != nil { + t.Errorf("failed to marshal users, got error %v", err) + } else if !regexp.MustCompile(`"Team":\[\],"Languages":\[\],"Friends":\[\]`).MatchString(string(r)) { + t.Errorf("json marshal is not empty slice, got %v", string(r)) + } + + var results []User + DB.Preload("Team").Preload("Languages").Preload("Friends").Find(&results, "name = ?", user.Name) + + if r, err := json.Marshal(&results); err != nil { + t.Errorf("failed to marshal users, got error %v", err) + } else if 
!regexp.MustCompile(`"Team":\[\],"Languages":\[\],"Friends":\[\]`).MatchString(string(r)) { + t.Errorf("json marshal is not empty slice, got %v", string(r)) + } +} + +func TestPreloadGoroutine(t *testing.T) { + var wg sync.WaitGroup + + wg.Add(10) + for i := 0; i < 10; i++ { + go func() { + defer wg.Done() + var user2 []User + tx := DB.Where("id = ?", 1).Session(&gorm.Session{}) + + if err := tx.Preload("Team").Find(&user2).Error; err != nil { + t.Error(err) + } + }() + } + wg.Wait() +} diff --git a/tests/prepared_stmt_test.go b/tests/prepared_stmt_test.go new file mode 100644 index 00000000..8730e547 --- /dev/null +++ b/tests/prepared_stmt_test.go @@ -0,0 +1,90 @@ +package tests_test + +import ( + "context" + "testing" + "time" + + "gorm.io/gorm" + . "gorm.io/gorm/utils/tests" +) + +func TestPreparedStmt(t *testing.T) { + tx := DB.Session(&gorm.Session{PrepareStmt: true}) + + if _, ok := tx.ConnPool.(*gorm.PreparedStmtDB); !ok { + t.Fatalf("should assign PreparedStatement Manager back to database when using PrepareStmt mode") + } + + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + txCtx := tx.WithContext(ctx) + + user := *GetUser("prepared_stmt", Config{}) + + txCtx.Create(&user) + + var result1 User + if err := txCtx.Find(&result1, user.ID).Error; err != nil { + t.Fatalf("no error should happen but got %v", err) + } + + time.Sleep(time.Second) + + var result2 User + if err := tx.Find(&result2, user.ID).Error; err != nil { + t.Fatalf("no error should happen but got %v", err) + } + + user2 := *GetUser("prepared_stmt2", Config{}) + if err := txCtx.Create(&user2).Error; err == nil { + t.Fatalf("should failed to create with timeout context") + } + + if err := tx.Create(&user2).Error; err != nil { + t.Fatalf("failed to create, got error %v", err) + } + + var result3 User + if err := tx.Find(&result3, user2.ID).Error; err != nil { + t.Fatalf("no error should happen but got %v", err) + } +} + +func TestPreparedStmtFromTransaction(t *testing.T) { + db := DB.Session(&gorm.Session{PrepareStmt: true, SkipDefaultTransaction: true}) + + tx := db.Begin() + defer func() { + if r := recover(); r != nil { + tx.Rollback() + } + }() + if err := tx.Error; err != nil { + t.Errorf("Failed to start transaction, got error %v\n", err) + } + + if err := tx.Where("name=?", "zzjin").Delete(&User{}).Error; err != nil { + tx.Rollback() + t.Errorf("Failed to run one transaction, got error %v\n", err) + } + + if err := tx.Create(&User{Name: "zzjin"}).Error; err != nil { + tx.Rollback() + t.Errorf("Failed to run one transaction, got error %v\n", err) + } + + if err := tx.Commit().Error; err != nil { + t.Errorf("Failed to commit transaction, got error %v\n", err) + } + + if result := db.Where("name=?", "zzjin").Delete(&User{}); result.Error != nil || result.RowsAffected != 1 { + t.Fatalf("Failed, got error: %v, rows affected: %v", result.Error, result.RowsAffected) + } + + tx2 := db.Begin() + if result := tx2.Where("name=?", "zzjin").Delete(&User{}); result.Error != nil || result.RowsAffected != 0 { + t.Fatalf("Failed, got error: %v, rows affected: %v", result.Error, result.RowsAffected) + } + tx2.Commit() +} diff --git a/tests/query_test.go b/tests/query_test.go new file mode 100644 index 00000000..34999337 --- /dev/null +++ b/tests/query_test.go @@ -0,0 +1,1139 @@ +package tests_test + +import ( + "database/sql" + "fmt" + "reflect" + "regexp" + "sort" + "strconv" + "strings" + "testing" + "time" + + "gorm.io/gorm" + "gorm.io/gorm/clause" + . 
"gorm.io/gorm/utils/tests" +) + +func TestFind(t *testing.T) { + var users = []User{ + *GetUser("find", Config{}), + *GetUser("find", Config{}), + *GetUser("find", Config{}), + } + + if err := DB.Create(&users).Error; err != nil { + t.Fatalf("errors happened when create users: %v", err) + } + + t.Run("First", func(t *testing.T) { + var first User + if err := DB.Where("name = ?", "find").First(&first).Error; err != nil { + t.Errorf("errors happened when query first: %v", err) + } else { + CheckUser(t, first, users[0]) + } + }) + + t.Run("Last", func(t *testing.T) { + var last User + if err := DB.Where("name = ?", "find").Last(&last).Error; err != nil { + t.Errorf("errors happened when query last: %v", err) + } else { + CheckUser(t, last, users[2]) + } + }) + + var all []User + if err := DB.Where("name = ?", "find").Find(&all).Error; err != nil || len(all) != 3 { + t.Errorf("errors happened when query find: %v, length: %v", err, len(all)) + } else { + for idx, user := range users { + t.Run("FindAll#"+strconv.Itoa(idx+1), func(t *testing.T) { + CheckUser(t, all[idx], user) + }) + } + } + + t.Run("FirstMap", func(t *testing.T) { + var first = map[string]interface{}{} + if err := DB.Model(&User{}).Where("name = ?", "find").First(first).Error; err != nil { + t.Errorf("errors happened when query first: %v", err) + } else { + for _, name := range []string{"Name", "Age", "Birthday"} { + t.Run(name, func(t *testing.T) { + dbName := DB.NamingStrategy.ColumnName("", name) + + switch name { + case "Name": + if _, ok := first[dbName].(string); !ok { + t.Errorf("invalid data type for %v, got %#v", dbName, first[dbName]) + } + case "Age": + if _, ok := first[dbName].(uint); !ok { + t.Errorf("invalid data type for %v, got %#v", dbName, first[dbName]) + } + case "Birthday": + if _, ok := first[dbName].(*time.Time); !ok { + t.Errorf("invalid data type for %v, got %#v", dbName, first[dbName]) + } + } + + reflectValue := reflect.Indirect(reflect.ValueOf(users[0])) + AssertEqual(t, first[dbName], reflectValue.FieldByName(name).Interface()) + }) + } + } + }) + + t.Run("FirstMapWithTable", func(t *testing.T) { + var first = map[string]interface{}{} + if err := DB.Table("users").Where("name = ?", "find").Find(first).Error; err != nil { + t.Errorf("errors happened when query first: %v", err) + } else { + for _, name := range []string{"Name", "Age", "Birthday"} { + t.Run(name, func(t *testing.T) { + dbName := DB.NamingStrategy.ColumnName("", name) + resultType := reflect.ValueOf(first[dbName]).Type().Name() + + switch name { + case "Name": + if !strings.Contains(resultType, "string") { + t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, first[dbName]) + } + case "Age": + if !strings.Contains(resultType, "int") { + t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, first[dbName]) + } + case "Birthday": + if !strings.Contains(resultType, "Time") && !(DB.Dialector.Name() == "sqlite" && strings.Contains(resultType, "string")) { + t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, first[dbName]) + } + } + + reflectValue := reflect.Indirect(reflect.ValueOf(users[0])) + AssertEqual(t, first[dbName], reflectValue.FieldByName(name).Interface()) + }) + } + } + }) + + t.Run("FirstPtrMap", func(t *testing.T) { + var first = map[string]interface{}{} + if err := DB.Model(&User{}).Where("name = ?", "find").First(&first).Error; err != nil { + t.Errorf("errors happened when query first: %v", err) + } else { + for _, name := range []string{"Name", "Age", "Birthday"} { + 
t.Run(name, func(t *testing.T) { + dbName := DB.NamingStrategy.ColumnName("", name) + reflectValue := reflect.Indirect(reflect.ValueOf(users[0])) + AssertEqual(t, first[dbName], reflectValue.FieldByName(name).Interface()) + }) + } + } + }) + + t.Run("FirstSliceOfMap", func(t *testing.T) { + var allMap = []map[string]interface{}{} + if err := DB.Model(&User{}).Where("name = ?", "find").Find(&allMap).Error; err != nil { + t.Errorf("errors happened when query find: %v", err) + } else { + for idx, user := range users { + t.Run("FindAllMap#"+strconv.Itoa(idx+1), func(t *testing.T) { + for _, name := range []string{"Name", "Age", "Birthday"} { + t.Run(name, func(t *testing.T) { + dbName := DB.NamingStrategy.ColumnName("", name) + + switch name { + case "Name": + if _, ok := allMap[idx][dbName].(string); !ok { + t.Errorf("invalid data type for %v, got %#v", dbName, allMap[idx][dbName]) + } + case "Age": + if _, ok := allMap[idx][dbName].(uint); !ok { + t.Errorf("invalid data type for %v, got %#v", dbName, allMap[idx][dbName]) + } + case "Birthday": + if _, ok := allMap[idx][dbName].(*time.Time); !ok { + t.Errorf("invalid data type for %v, got %#v", dbName, allMap[idx][dbName]) + } + } + + reflectValue := reflect.Indirect(reflect.ValueOf(user)) + AssertEqual(t, allMap[idx][dbName], reflectValue.FieldByName(name).Interface()) + }) + } + }) + } + } + }) + + t.Run("FindSliceOfMapWithTable", func(t *testing.T) { + var allMap = []map[string]interface{}{} + if err := DB.Table("users").Where("name = ?", "find").Find(&allMap).Error; err != nil { + t.Errorf("errors happened when query find: %v", err) + } else { + for idx, user := range users { + t.Run("FindAllMap#"+strconv.Itoa(idx+1), func(t *testing.T) { + for _, name := range []string{"Name", "Age", "Birthday"} { + t.Run(name, func(t *testing.T) { + dbName := DB.NamingStrategy.ColumnName("", name) + resultType := reflect.ValueOf(allMap[idx][dbName]).Type().Name() + + switch name { + case "Name": + if !strings.Contains(resultType, "string") { + t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, allMap[idx][dbName]) + } + case "Age": + if !strings.Contains(resultType, "int") { + t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, allMap[idx][dbName]) + } + case "Birthday": + if !strings.Contains(resultType, "Time") && !(DB.Dialector.Name() == "sqlite" && strings.Contains(resultType, "string")) { + t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, allMap[idx][dbName]) + } + } + + reflectValue := reflect.Indirect(reflect.ValueOf(user)) + AssertEqual(t, allMap[idx][dbName], reflectValue.FieldByName(name).Interface()) + }) + } + }) + } + } + }) + + var models []User + if err := DB.Where("name in (?)", []string{"find"}).Find(&models).Error; err != nil || len(models) != 3 { + t.Errorf("errors happened when query find with in clause: %v, length: %v", err, len(models)) + } else { + for idx, user := range users { + t.Run("FindWithInClause#"+strconv.Itoa(idx+1), func(t *testing.T) { + CheckUser(t, models[idx], user) + }) + } + } + + var none []User + if err := DB.Where("name in (?)", []string{}).Find(&none).Error; err != nil || len(none) != 0 { + t.Errorf("errors happened when query find with in clause and zero length parameter: %v, length: %v", err, len(none)) + } +} + +func TestQueryWithAssociation(t *testing.T) { + user := *GetUser("query_with_association", Config{Account: true, Pets: 2, Toys: 1, Company: true, Manager: true, Team: 2, Languages: 1, Friends: 3}) + + if err := DB.Create(&user).Error; err != nil 
{ + t.Fatalf("errors happened when create user: %v", err) + } + + user.CreatedAt = time.Time{} + user.UpdatedAt = time.Time{} + if err := DB.Where(&user).First(&User{}).Error; err != nil { + t.Errorf("search with struct with association should return no error, but got %v", err) + } + + if err := DB.Where(user).First(&User{}).Error; err != nil { + t.Errorf("search with struct with association should return no error, but got %v", err) + } +} + +func TestFindInBatches(t *testing.T) { + var users = []User{ + *GetUser("find_in_batches", Config{}), + *GetUser("find_in_batches", Config{}), + *GetUser("find_in_batches", Config{}), + *GetUser("find_in_batches", Config{}), + *GetUser("find_in_batches", Config{}), + *GetUser("find_in_batches", Config{}), + } + + DB.Create(&users) + + var ( + results []User + totalBatch int + ) + + if result := DB.Where("name = ?", users[0].Name).FindInBatches(&results, 2, func(tx *gorm.DB, batch int) error { + totalBatch += batch + + if tx.RowsAffected != 2 { + t.Errorf("Incorrect affected rows, expects: 2, got %v", tx.RowsAffected) + } + + if len(results) != 2 { + t.Errorf("Incorrect users length, expects: 2, got %v", len(results)) + } + + for idx := range results { + results[idx].Name = results[idx].Name + "_new" + } + + if err := tx.Save(results).Error; err != nil { + t.Errorf("failed to save users, got error %v", err) + } + + return nil + }); result.Error != nil || result.RowsAffected != 6 { + t.Errorf("Failed to batch find, got error %v, rows affected: %v", result.Error, result.RowsAffected) + } + + if totalBatch != 6 { + t.Errorf("incorrect total batch, expects: %v, got %v", 6, totalBatch) + } + + var count int64 + DB.Model(&User{}).Where("name = ?", "find_in_batches_new").Count(&count) + if count != 6 { + t.Errorf("incorrect count after update, expects: %v, got %v", 6, count) + } +} + +func TestFindInBatchesWithError(t *testing.T) { + if name := DB.Dialector.Name(); name == "sqlserver" { + t.Skip("skip sqlserver because it will raise a data race for invalid sql") + } + + var users = []User{ + *GetUser("find_in_batches_with_error", Config{}), + *GetUser("find_in_batches_with_error", Config{}), + *GetUser("find_in_batches_with_error", Config{}), + *GetUser("find_in_batches_with_error", Config{}), + *GetUser("find_in_batches_with_error", Config{}), + *GetUser("find_in_batches_with_error", Config{}), + } + + DB.Create(&users) + + var ( + results []User + totalBatch int + ) + + if result := DB.Table("wrong_table").Where("name = ?", users[0].Name).FindInBatches(&results, 2, func(tx *gorm.DB, batch int) error { + totalBatch += batch + return nil + }); result.Error == nil || result.RowsAffected > 0 { + t.Fatal("expected errors to have occurred, but nothing happened") + } + if totalBatch != 0 { + t.Fatalf("incorrect total batch, expected: %v, got: %v", 0, totalBatch) + } +} + +func TestFillSmallerStruct(t *testing.T) { + user := User{Name: "SmallerUser", Age: 100} + DB.Save(&user) + type SimpleUser struct { + ID int64 + Name string + UpdatedAt time.Time + CreatedAt time.Time + } + + var simpleUser SimpleUser + if err := DB.Table("users").Where("name = ?", user.Name).First(&simpleUser).Error; err != nil { + t.Fatalf("Failed to query smaller user, got error %v", err) + } + + AssertObjEqual(t, user, simpleUser, "Name", "ID", "UpdatedAt", "CreatedAt") + + var simpleUser2 SimpleUser + if err := DB.Model(&User{}).Select("id").First(&simpleUser2, user.ID).Error; err != nil { + t.Fatalf("Failed to query smaller user, got error %v", err) + } + + AssertObjEqual(t, user, 
simpleUser2, "ID") + + var simpleUsers []SimpleUser + if err := DB.Model(&User{}).Select("id").Find(&simpleUsers, user.ID).Error; err != nil || len(simpleUsers) != 1 { + t.Fatalf("Failed to query smaller user, got error %v", err) + } + + AssertObjEqual(t, user, simpleUsers[0], "ID") + + result := DB.Session(&gorm.Session{DryRun: true}).Model(&User{}).Find(&simpleUsers, user.ID) + + if !regexp.MustCompile("SELECT .*id.*name.*updated_at.*created_at.* FROM .*users").MatchString(result.Statement.SQL.String()) { + t.Fatalf("SQL should include selected names, but got %v", result.Statement.SQL.String()) + } + + result = DB.Session(&gorm.Session{DryRun: true}).Model(&User{}).Find(&User{}, user.ID) + + if regexp.MustCompile("SELECT .*name.* FROM .*users").MatchString(result.Statement.SQL.String()) { + t.Fatalf("SQL should not include selected names, but got %v", result.Statement.SQL.String()) + } + + result = DB.Session(&gorm.Session{DryRun: true}).Model(&User{}).Find(&[]User{}, user.ID) + + if regexp.MustCompile("SELECT .*name.* FROM .*users").MatchString(result.Statement.SQL.String()) { + t.Fatalf("SQL should not include selected names, but got %v", result.Statement.SQL.String()) + } + + result = DB.Session(&gorm.Session{DryRun: true}).Model(&User{}).Find(&[]*User{}, user.ID) + + if regexp.MustCompile("SELECT .*name.* FROM .*users").MatchString(result.Statement.SQL.String()) { + t.Fatalf("SQL should not include selected names, but got %v", result.Statement.SQL.String()) + } +} + +func TestFillSmallerStructWithAllFields(t *testing.T) { + user := User{Name: "SmallerUser", Age: 100} + DB.Save(&user) + type SimpleUser struct { + ID int64 + Name string + UpdatedAt time.Time + CreatedAt time.Time + } + var simpleUsers []SimpleUser + dryDB := DB.Session(&gorm.Session{DryRun: true, QueryFields: true}) + + result := dryDB.Model(&User{}).Find(&simpleUsers, user.ID) + if !regexp.MustCompile("SELECT .users.*id.*users.*name.*users.*updated_at.*users.*created_at.* FROM .*users").MatchString(result.Statement.SQL.String()) { + t.Fatalf("SQL should include selected names, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Model(&User{}).Find(&User{}, user.ID) + if regexp.MustCompile("SELECT \\* FROM .*users").MatchString(result.Statement.SQL.String()) { + t.Fatalf("SQL should not include a * wildcard, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Model(&User{}).Find(&[]User{}, user.ID) + if regexp.MustCompile("SELECT \\* FROM .*users").MatchString(result.Statement.SQL.String()) { + t.Fatalf("SQL should not include a * wildcard, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Model(&User{}).Find(&[]*User{}, user.ID) + if regexp.MustCompile("SELECT \\* FROM .*users").MatchString(result.Statement.SQL.String()) { + t.Fatalf("SQL should not include a * wildcard, but got %v", result.Statement.SQL.String()) + } +} + +func TestNot(t *testing.T) { + dryDB := DB.Session(&gorm.Session{DryRun: true}) + + result := dryDB.Not(map[string]interface{}{"name": "jinzhu"}).Find(&User{}) + if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*name.* <> .+").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Where("name = ?", "jinzhu1").Not("name = ?", "jinzhu2").Find(&User{}) + if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*name.* = .+ AND NOT.*name.* = .+").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but got %v", 
result.Statement.SQL.String()) + } + + result = dryDB.Where(map[string]interface{}{"name": []string{"jinzhu", "jinzhu 2"}}).Find(&User{}) + if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*name.* IN \\(.+,.+\\)").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Not("name = ?", "jinzhu").Find(&User{}) + if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE NOT.*name.* = .+").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Not(map[string]interface{}{"name": []string{"jinzhu", "jinzhu 2"}}).Find(&User{}) + if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*name.* NOT IN \\(.+,.+\\)").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Not([]int64{1, 2}).First(&User{}) + if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*id.* NOT IN \\(.+,.+\\)").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Not([]int64{}).First(&User{}) + if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .users.\\..deleted_at. IS NULL ORDER BY").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Not(User{Name: "jinzhu", Age: 18}).First(&User{}) + if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*users.*..*name.* <> .+ AND .*users.*..*age.* <> .+").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) + } +} + +func TestNotWithAllFields(t *testing.T) { + dryDB := DB.Session(&gorm.Session{DryRun: true, QueryFields: true}) + userQuery := "SELECT .*users.*id.*users.*created_at.*users.*updated_at.*users.*deleted_at.*users.*name" + + ".*users.*age.*users.*birthday.*users.*company_id.*users.*manager_id.*users.*active.* FROM .*users.* " + + result := dryDB.Not(map[string]interface{}{"users.name": "jinzhu"}).Find(&User{}) + + if !regexp.MustCompile(userQuery + "WHERE .*users.*name.* <> .+").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Where("users.name = ?", "jinzhu1").Not("users.name = ?", "jinzhu2").Find(&User{}) + if !regexp.MustCompile(userQuery + "WHERE .*users.*name.* = .+ AND NOT .*users.*name.* = .+").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Where(map[string]interface{}{"users.name": []string{"jinzhu", "jinzhu 2"}}).Find(&User{}) + if !regexp.MustCompile(userQuery + "WHERE .*users.*name.* IN \\(.+,.+\\)").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Not("users.name = ?", "jinzhu").Find(&User{}) + if !regexp.MustCompile(userQuery + "WHERE NOT .*users.*name.* = .+").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Not(map[string]interface{}{"users.name": []string{"jinzhu", "jinzhu 2"}}).Find(&User{}) + if !regexp.MustCompile(userQuery + "WHERE .*users.*name.* NOT IN \\(.+,.+\\)").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but 
got %v", result.Statement.SQL.String()) + } + + result = dryDB.Not([]int64{1, 2}).First(&User{}) + if !regexp.MustCompile(userQuery + "WHERE .*users.*id.* NOT IN \\(.+,.+\\)").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Not([]int64{}).First(&User{}) + if !regexp.MustCompile(userQuery + "WHERE .users.\\..deleted_at. IS NULL ORDER BY").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Not(User{Name: "jinzhu", Age: 18}).First(&User{}) + if !regexp.MustCompile(userQuery + "WHERE .*users.*..*name.* <> .+ AND .*users.*..*age.* <> .+").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) + } +} + +func TestOr(t *testing.T) { + dryDB := DB.Session(&gorm.Session{DryRun: true}) + + result := dryDB.Where("role = ?", "admin").Where(DB.Or("role = ?", "super_admin")).Find(&User{}) + if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*role.* = .+ AND .*role.* = .+").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build OR condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Where("role = ?", "admin").Where(DB.Or("role = ?", "super_admin").Or("role = ?", "admin")).Find(&User{}) + if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*role.* = .+ AND (.*role.* = .+ OR .*role.* = .+)").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build OR condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Where("role = ?", "admin").Or("role = ?", "super_admin").Find(&User{}) + if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*role.* = .+ OR .*role.* = .+").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build OR condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Where("name = ?", "jinzhu").Or(User{Name: "jinzhu 2", Age: 18}).Find(&User{}) + if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*name.* = .+ OR \\(.*name.* AND .*age.*\\)").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build OR condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Where("name = ?", "jinzhu").Or(map[string]interface{}{"name": "jinzhu 2", "age": 18}).Find(&User{}) + if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*name.* = .+ OR \\(.*age.* AND .*name.*\\)").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build OR condition, but got %v", result.Statement.SQL.String()) + } +} + +func TestOrWithAllFields(t *testing.T) { + dryDB := DB.Session(&gorm.Session{DryRun: true, QueryFields: true}) + userQuery := "SELECT .*users.*id.*users.*created_at.*users.*updated_at.*users.*deleted_at.*users.*name" + + ".*users.*age.*users.*birthday.*users.*company_id.*users.*manager_id.*users.*active.* FROM .*users.* " + + result := dryDB.Where("role = ?", "admin").Or("role = ?", "super_admin").Find(&User{}) + if !regexp.MustCompile(userQuery + "WHERE .*role.* = .+ OR .*role.* = .+").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build OR condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Where("users.name = ?", "jinzhu").Or(User{Name: "jinzhu 2", Age: 18}).Find(&User{}) + if !regexp.MustCompile(userQuery + "WHERE .*users.*name.* = .+ OR \\(.*users.*name.* AND .*users.*age.*\\)").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build OR condition, but got %v", 
result.Statement.SQL.String()) + } + + result = dryDB.Where("users.name = ?", "jinzhu").Or(map[string]interface{}{"name": "jinzhu 2", "age": 18}).Find(&User{}) + if !regexp.MustCompile(userQuery + "WHERE .*users.*name.* = .+ OR \\(.*age.* AND .*name.*\\)").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build OR condition, but got %v", result.Statement.SQL.String()) + } +} + +func TestPluck(t *testing.T) { + users := []*User{ + GetUser("pluck-user1", Config{}), + GetUser("pluck-user2", Config{}), + GetUser("pluck-user3", Config{}), + } + + DB.Create(&users) + + var names []string + if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Order("name").Pluck("name", &names).Error; err != nil { + t.Errorf("got error when pluck name: %v", err) + } + + var names2 []string + if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Order("name desc").Pluck("name", &names2).Error; err != nil { + t.Errorf("got error when pluck name: %v", err) + } + sort.Sort(sort.Reverse(sort.StringSlice(names2))) + AssertEqual(t, names, names2) + + var ids []int + if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Pluck("id", &ids).Error; err != nil { + t.Errorf("got error when pluck id: %v", err) + } + + for idx, name := range names { + if name != users[idx].Name { + t.Errorf("Unexpected result on pluck name, got %+v", names) + } + } + + for idx, id := range ids { + if int(id) != int(users[idx].ID) { + t.Errorf("Unexpected result on pluck id, got %+v", ids) + } + } + + var times []time.Time + if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Pluck("created_at", &times).Error; err != nil { + t.Errorf("got error when pluck time: %v", err) + } + + for idx, tv := range times { + AssertEqual(t, tv, users[idx].CreatedAt) + } + + var ptrtimes []*time.Time + if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Pluck("created_at", &ptrtimes).Error; err != nil { + t.Errorf("got error when pluck time: %v", err) + } + + for idx, tv := range ptrtimes { + AssertEqual(t, tv, users[idx].CreatedAt) + } + + var nulltimes []sql.NullTime + if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Pluck("created_at", &nulltimes).Error; err != nil { + t.Errorf("got error when pluck time: %v", err) + } + + for idx, tv := range nulltimes { + AssertEqual(t, tv.Time, users[idx].CreatedAt) + } +} + +func TestSelect(t *testing.T) { + user := User{Name: "SelectUser1"} + DB.Save(&user) + + var result User + DB.Where("name = ?", user.Name).Select("name").Find(&result) + if result.ID != 0 { + t.Errorf("Should not have ID because only name is selected, got %+v", result.ID) + } + + if user.Name != result.Name { + t.Errorf("Should have user Name when it is selected") + } + + var result2 User + DB.Where("name = ?", user.Name).Select("name as name").Find(&result2) + if result2.ID != 0 { + t.Errorf("Should not have ID because only name is selected, got %+v", result2.ID) + } + + if user.Name != result2.Name { + t.Errorf("Should have user Name when it is selected") + } + + dryDB := DB.Session(&gorm.Session{DryRun: true}) + r := dryDB.Select("name", "age").Find(&User{}) + if !regexp.MustCompile("SELECT .*name.*,.*age.* FROM .*users.*").MatchString(r.Statement.SQL.String()) { + t.Fatalf("Build Select with strings, but got %v", r.Statement.SQL.String()) + } + + r = dryDB.Select([]string{"name", "age"}).Find(&User{}) + if !regexp.MustCompile("SELECT .*name.*,.*age.* FROM .*users.*").MatchString(r.Statement.SQL.String()) { + t.Fatalf("Build Select with slice, but got %v", r.Statement.SQL.String()) + } + + // SELECT COALESCE(age,'42') FROM 
users; + r = dryDB.Table("users").Select("COALESCE(age,?)", 42).Find(&User{}) + if !regexp.MustCompile(`SELECT COALESCE\(age,.*\) FROM .*users.*`).MatchString(r.Statement.SQL.String()) { + t.Fatalf("Build Select with func, but got %v", r.Statement.SQL.String()) + } + + // named arguments + r = dryDB.Table("users").Select("COALESCE(age, @default)", sql.Named("default", 42)).Find(&User{}) + if !regexp.MustCompile(`SELECT COALESCE\(age,.*\) FROM .*users.*`).MatchString(r.Statement.SQL.String()) { + t.Fatalf("Build Select with func, but got %v", r.Statement.SQL.String()) + } + + if _, err := DB.Table("users").Select("COALESCE(age,?)", "42").Rows(); err != nil { + t.Fatalf("Failed, got error: %v", err) + } + + r = dryDB.Select("u.*").Table("users as u").First(&User{}, user.ID) + if !regexp.MustCompile(`SELECT u\.\* FROM .*users.*`).MatchString(r.Statement.SQL.String()) { + t.Fatalf("Build Select with u.*, but got %v", r.Statement.SQL.String()) + } + + r = dryDB.Select("count(*)").Select("u.*").Table("users as u").First(&User{}, user.ID) + if !regexp.MustCompile(`SELECT u\.\* FROM .*users.*`).MatchString(r.Statement.SQL.String()) { + t.Fatalf("Build Select with u.*, but got %v", r.Statement.SQL.String()) + } +} + +func TestOmit(t *testing.T) { + user := User{Name: "OmitUser1", Age: 20} + DB.Save(&user) + + var result User + DB.Where("name = ?", user.Name).Omit("name").Find(&result) + if result.ID == 0 { + t.Errorf("Should have ID because only name is omitted, got %+v", result.ID) + } + + if result.Name != "" || result.Age != 20 { + t.Errorf("User Name should be omitted, got %v, Age should be ok, got %v", result.Name, result.Age) + } +} + +func TestOmitWithAllFields(t *testing.T) { + user := User{Name: "OmitUser1", Age: 20} + DB.Save(&user) + + var userResult User + DB.Session(&gorm.Session{QueryFields: true}).Where("users.name = ?", user.Name).Omit("name").Find(&userResult) + if userResult.ID == 0 { + t.Errorf("Should have ID because only name is omitted, got %+v", userResult.ID) + } + + if userResult.Name != "" || userResult.Age != 20 { + t.Errorf("User Name should be omitted, got %v, Age should be ok, got %v", userResult.Name, userResult.Age) + } + + dryDB := DB.Session(&gorm.Session{DryRun: true, QueryFields: true}) + userQuery := "SELECT .*users.*id.*users.*created_at.*users.*updated_at.*users.*deleted_at.*users.*birthday" + + ".*users.*company_id.*users.*manager_id.*users.*active.* FROM .*users.* " + + result := dryDB.Omit("name, age").Find(&User{}) + if !regexp.MustCompile(userQuery).MatchString(result.Statement.SQL.String()) { + t.Fatalf("SQL must include table name and selected fields, got %v", result.Statement.SQL.String()) + } +} + +func TestPluckWithSelect(t *testing.T) { + users := []User{ + {Name: "pluck_with_select_1", Age: 25}, + {Name: "pluck_with_select_2", Age: 26}, + } + + DB.Create(&users) + + var userAges []int + err := DB.Model(&User{}).Where("name like ?", "pluck_with_select%").Select("age + 1 as user_age").Pluck("user_age", &userAges).Error + if err != nil { + t.Fatalf("got error when pluck user_age: %v", err) + } + + sort.Ints(userAges) + + AssertEqual(t, userAges, []int{26, 27}) +} + +func TestSelectWithVariables(t *testing.T) { + DB.Save(&User{Name: "select_with_variables"}) + + rows, _ := DB.Table("users").Where("name = ?", "select_with_variables").Select("? 
as fake", gorm.Expr("name")).Rows() + + if !rows.Next() { + t.Errorf("Should have returned at least one row") + } else { + columns, _ := rows.Columns() + AssertEqual(t, columns, []string{"fake"}) + } + + rows.Close() +} + +func TestSelectWithArrayInput(t *testing.T) { + DB.Save(&User{Name: "select_with_array", Age: 42}) + + var user User + DB.Select([]string{"name", "age"}).Where("age = 42 AND name = ?", "select_with_array").First(&user) + + if user.Name != "select_with_array" || user.Age != 42 { + t.Errorf("Should have selected both age and name") + } +} + +func TestCustomizedTypePrimaryKey(t *testing.T) { + type ID uint + type CustomizedTypePrimaryKey struct { + ID ID + Name string + } + + DB.Migrator().DropTable(&CustomizedTypePrimaryKey{}) + if err := DB.AutoMigrate(&CustomizedTypePrimaryKey{}); err != nil { + t.Fatalf("failed to migrate, got error %v", err) + } + + p1 := CustomizedTypePrimaryKey{Name: "p1"} + p2 := CustomizedTypePrimaryKey{Name: "p2"} + p3 := CustomizedTypePrimaryKey{Name: "p3"} + DB.Create(&p1) + DB.Create(&p2) + DB.Create(&p3) + + var p CustomizedTypePrimaryKey + + if err := DB.First(&p, p2.ID).Error; err != nil { + t.Errorf("No error should returns, but got %v", err) + } + + AssertEqual(t, p, p2) + + if err := DB.First(&p, "id = ?", p2.ID).Error; err != nil { + t.Errorf("No error should happen when querying with customized type for primary key, got err %v", err) + } + + AssertEqual(t, p, p2) +} + +func TestStringPrimaryKeyForNumericValueStartingWithZero(t *testing.T) { + type AddressByZipCode struct { + ZipCode string `gorm:"primary_key"` + Address string + } + + DB.Migrator().DropTable(&AddressByZipCode{}) + if err := DB.AutoMigrate(&AddressByZipCode{}); err != nil { + t.Fatalf("failed to migrate, got error %v", err) + } + + address := AddressByZipCode{ZipCode: "00501", Address: "Holtsville"} + DB.Create(&address) + + var result AddressByZipCode + DB.First(&result, "00501") + + AssertEqual(t, result, address) +} + +func TestSearchWithEmptyChain(t *testing.T) { + user := User{Name: "search_with_empty_chain", Age: 1} + DB.Create(&user) + + var result User + if DB.Where("").Where("").First(&result).Error != nil { + t.Errorf("Should not raise any error if searching with empty strings") + } + + result = User{} + if DB.Where(&User{}).Where("name = ?", user.Name).First(&result).Error != nil { + t.Errorf("Should not raise any error if searching with empty struct") + } + + result = User{} + if DB.Where(map[string]interface{}{}).Where("name = ?", user.Name).First(&result).Error != nil { + t.Errorf("Should not raise any error if searching with empty map") + } +} + +func TestOrder(t *testing.T) { + dryDB := DB.Session(&gorm.Session{DryRun: true}) + + result := dryDB.Order("age desc, name").Find(&User{}) + if !regexp.MustCompile("SELECT \\* FROM .*users.* ORDER BY age desc, name").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build Order condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Order("age desc").Order("name").Find(&User{}) + if !regexp.MustCompile("SELECT \\* FROM .*users.* ORDER BY age desc,name").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build Order condition, but got %v", result.Statement.SQL.String()) + } + + stmt := dryDB.Clauses(clause.OrderBy{ + Expression: clause.Expr{SQL: "FIELD(id,?)", Vars: []interface{}{[]int{1, 2, 3}}, WithoutParentheses: true}, + }).Find(&User{}).Statement + + explainedSQL := dryDB.Dialector.Explain(stmt.SQL.String(), stmt.Vars...) 
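+ // Dialector.Explain is expected to interpolate the statement's bound variables into the SQL text, so the assertion below can match the literal FIELD(id,1,2,3) produced by the custom OrderBy clause.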
+ if !regexp.MustCompile("SELECT \\* FROM .*users.* ORDER BY FIELD\\(id,1,2,3\\)").MatchString(explainedSQL) { + t.Fatalf("Build Order condition, but got %v", explainedSQL) + } +} + +func TestOrderWithAllFields(t *testing.T) { + dryDB := DB.Session(&gorm.Session{DryRun: true, QueryFields: true}) + userQuery := "SELECT .*users.*id.*users.*created_at.*users.*updated_at.*users.*deleted_at.*users.*name.*users.*age" + + ".*users.*birthday.*users.*company_id.*users.*manager_id.*users.*active.* FROM .*users.* " + + result := dryDB.Order("users.age desc, users.name").Find(&User{}) + if !regexp.MustCompile(userQuery + "users.age desc, users.name").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build Order condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Order("users.age desc").Order("users.name").Find(&User{}) + if !regexp.MustCompile(userQuery + "ORDER BY users.age desc,users.name").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build Order condition, but got %v", result.Statement.SQL.String()) + } + + stmt := dryDB.Clauses(clause.OrderBy{ + Expression: clause.Expr{SQL: "FIELD(id,?)", Vars: []interface{}{[]int{1, 2, 3}}, WithoutParentheses: true}, + }).Find(&User{}).Statement + + explainedSQL := dryDB.Dialector.Explain(stmt.SQL.String(), stmt.Vars...) + if !regexp.MustCompile(userQuery + "ORDER BY FIELD\\(id,1,2,3\\)").MatchString(explainedSQL) { + t.Fatalf("Build Order condition, but got %v", explainedSQL) + } +} + +func TestLimit(t *testing.T) { + users := []User{ + {Name: "LimitUser1", Age: 1}, + {Name: "LimitUser2", Age: 10}, + {Name: "LimitUser3", Age: 20}, + {Name: "LimitUser4", Age: 10}, + {Name: "LimitUser5", Age: 20}, + {Name: "LimitUser6", Age: 20}, + } + + DB.Create(&users) + + var users1, users2, users3 []User + DB.Order("age desc").Limit(3).Find(&users1).Limit(5).Find(&users2).Limit(-1).Find(&users3) + + if len(users1) != 3 || len(users2) != 5 || len(users3) <= 5 { + t.Errorf("Limit should works, users1 %v users2 %v users3 %v", len(users1), len(users2), len(users3)) + } +} + +func TestOffset(t *testing.T) { + for i := 0; i < 20; i++ { + DB.Save(&User{Name: fmt.Sprintf("OffsetUser%v", i)}) + } + var users1, users2, users3, users4 []User + + DB.Limit(100).Where("name like ?", "OffsetUser%").Order("age desc").Find(&users1).Offset(3).Find(&users2).Offset(5).Find(&users3).Offset(-1).Find(&users4) + + if (len(users1) != len(users4)) || (len(users1)-len(users2) != 3) || (len(users1)-len(users3) != 5) { + t.Errorf("Offset should work") + } + + DB.Where("name like ?", "OffsetUser%").Order("age desc").Find(&users1).Offset(3).Find(&users2).Offset(5).Find(&users3).Offset(-1).Find(&users4) + + if (len(users1) != len(users4)) || (len(users1)-len(users2) != 3) || (len(users1)-len(users3) != 5) { + t.Errorf("Offset should work without limit.") + } +} + +func TestSearchWithMap(t *testing.T) { + users := []User{ + *GetUser("map_search_user1", Config{}), + *GetUser("map_search_user2", Config{}), + *GetUser("map_search_user3", Config{}), + *GetUser("map_search_user4", Config{Company: true}), + } + + DB.Create(&users) + + var user User + DB.First(&user, map[string]interface{}{"name": users[0].Name}) + CheckUser(t, user, users[0]) + + user = User{} + DB.Where(map[string]interface{}{"name": users[1].Name}).First(&user) + CheckUser(t, user, users[1]) + + var results []User + DB.Where(map[string]interface{}{"name": users[2].Name}).Find(&results) + if len(results) != 1 { + t.Fatalf("Search all records with inline map") + } + + CheckUser(t, results[0], users[2]) 
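+ // The checks below assume a nil value in a condition map is rendered as an IS NULL comparison: users[3] has a company, so it should not match, while users[0] has no company and should.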
+ + var results2 []User + DB.Find(&results2, map[string]interface{}{"name": users[3].Name, "company_id": nil}) + if len(results2) != 0 { + t.Errorf("Search all records with inline map containing null value finding 0 records") + } + + DB.Find(&results2, map[string]interface{}{"name": users[0].Name, "company_id": nil}) + if len(results2) != 1 { + t.Errorf("Search all records with inline map containing null value finding 1 record") + } + + DB.Find(&results2, map[string]interface{}{"name": users[3].Name, "company_id": users[3].CompanyID}) + if len(results2) != 1 { + t.Errorf("Search all records with inline multiple value map") + } +} + +func TestSearchWithStruct(t *testing.T) { + dryRunDB := DB.Session(&gorm.Session{DryRun: true}) + + result := dryRunDB.Where(User{Name: "jinzhu"}).Find(&User{}) + if !regexp.MustCompile(`WHERE .users.\..name. = .{1,3} AND .users.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) { + t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String()) + } + + result = dryRunDB.Where(User{Name: "jinzhu", Age: 18}).Find(&User{}) + if !regexp.MustCompile(`WHERE .users.\..name. = .{1,3} AND .users.\..age. = .{1,3} AND .users.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) { + t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String()) + } + + result = dryRunDB.Where(User{Name: "jinzhu"}, "name", "Age").Find(&User{}) + if !regexp.MustCompile(`WHERE .users.\..name. = .{1,3} AND .users.\..age. = .{1,3} AND .users.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) { + t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String()) + } + + result = dryRunDB.Where(User{Name: "jinzhu"}, "age").Find(&User{}) + if !regexp.MustCompile(`WHERE .users.\..age. = .{1,3} AND .users.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) { + t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String()) + } +} + +func TestSubQuery(t *testing.T) { + users := []User{ + {Name: "subquery_1", Age: 10}, + {Name: "subquery_2", Age: 20}, + {Name: "subquery_3", Age: 30}, + {Name: "subquery_4", Age: 40}, + } + + DB.Create(&users) + + if err := DB.Select("*").Where("name IN (?)", DB.Select("name").Table("users").Where("name LIKE ?", "subquery_%")).Find(&users).Error; err != nil { + t.Fatalf("got error: %v", err) + } + + if len(users) != 4 { + t.Errorf("Four users should be found, instead found %d", len(users)) + } + + DB.Select("*").Where("name LIKE ?", "subquery%").Where("age >= (?)", DB. + Select("AVG(age)").Table("users").Where("name LIKE ?", "subquery%")).Find(&users) + + if len(users) != 2 { + t.Errorf("Two users should be found, instead found %d", len(users)) + } +} + +func TestSubQueryWithRaw(t *testing.T) { + users := []User{ + {Name: "subquery_raw_1", Age: 10}, + {Name: "subquery_raw_2", Age: 20}, + {Name: "subquery_raw_3", Age: 30}, + {Name: "subquery_raw_4", Age: 40}, + } + DB.Create(&users) + + var count int64 + err := DB.Raw("select count(*) from (?) tmp where 1 = ? AND name IN (?)", DB.Raw("select name from users where age >= ? and name in (?)", 10, []string{"subquery_raw_1", "subquery_raw_2", "subquery_raw_3"}), 1, DB.Raw("select name from users where age >= ? and name in (?)", 20, []string{"subquery_raw_1", "subquery_raw_2", "subquery_raw_3"})).Scan(&count).Error + if err != nil { + t.Errorf("Expected to get no errors, but got %v", err) + } + + if count != 2 { + t.Errorf("Row count must be 2, instead got %d", count) + } + + err = DB.Raw("select count(*) from (?) tmp", + DB.Table("users"). 
+ Select("name"). + Where("age >= ? and name in (?)", 20, []string{"subquery_raw_1", "subquery_raw_3"}). + Group("name"), + ).Count(&count).Error + + if err != nil { + t.Errorf("Expected to get no errors, but got %v", err) + } + + if count != 1 { + t.Errorf("Row count must be 1, instead got %d", count) + } + + err = DB.Raw("select count(*) from (?) tmp", + DB.Table("users"). + Select("name"). + Where("name LIKE ?", "subquery_raw%"). + Not("age <= ?", 10).Not("name IN (?)", []string{"subquery_raw_1", "subquery_raw_3"}). + Group("name"), + ).Count(&count).Error + + if err != nil { + t.Errorf("Expected to get no errors, but got %v", err) + } + + if count != 2 { + t.Errorf("Row count must be 2, instead got %d", count) + } +} + +func TestSubQueryWithHaving(t *testing.T) { + users := []User{ + {Name: "subquery_having_1", Age: 10}, + {Name: "subquery_having_2", Age: 20}, + {Name: "subquery_having_3", Age: 30}, + {Name: "subquery_having_4", Age: 40}, + } + DB.Create(&users) + + var results []User + DB.Select("AVG(age) as avgage").Where("name LIKE ?", "subquery_having%").Group("name").Having("AVG(age) > (?)", DB. + Select("AVG(age)").Where("name LIKE ?", "subquery_having%").Table("users")).Find(&results) + + if len(results) != 2 { + t.Errorf("Two user group should be found, instead found %d", len(results)) + } +} + +func TestScanNullValue(t *testing.T) { + user := GetUser("scan_null_value", Config{}) + DB.Create(&user) + + if err := DB.Model(&user).Update("age", nil).Error; err != nil { + t.Fatalf("failed to update column age for struct, got error %v", err) + } + + var result User + if err := DB.First(&result, "id = ?", user.ID).Error; err != nil { + t.Fatalf("failed to query struct data with null age, got error %v", err) + } + + AssertEqual(t, result, user) + + users := []User{ + *GetUser("scan_null_value_for_slice_1", Config{}), + *GetUser("scan_null_value_for_slice_2", Config{}), + *GetUser("scan_null_value_for_slice_3", Config{}), + } + DB.Create(&users) + + if err := DB.Model(&users[0]).Update("age", nil).Error; err != nil { + t.Fatalf("failed to update column age for struct, got error %v", err) + } + + var results []User + if err := DB.Find(&results, "name like ?", "scan_null_value_for_slice%").Error; err != nil { + t.Fatalf("failed to query slice data with null age, got error %v", err) + } +} + +func TestQueryWithTableAndConditions(t *testing.T) { + result := DB.Session(&gorm.Session{DryRun: true}).Table("user").Find(&User{}, User{Name: "jinzhu"}) + + if !regexp.MustCompile(`SELECT \* FROM .user. WHERE .user.\..name. = .+ AND .user.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) { + t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String()) + } +} + +func TestQueryWithTableAndConditionsAndAllFields(t *testing.T) { + result := DB.Session(&gorm.Session{DryRun: true, QueryFields: true}).Table("user").Find(&User{}, User{Name: "jinzhu"}) + userQuery := "SELECT .*user.*id.*user.*created_at.*user.*updated_at.*user.*deleted_at.*user.*name.*user.*age" + + ".*user.*birthday.*user.*company_id.*user.*manager_id.*user.*active.* FROM .user. " + + if !regexp.MustCompile(userQuery + `WHERE .user.\..name. = .+ AND .user.\..deleted_at. 
IS NULL`).MatchString(result.Statement.SQL.String()) { + t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String()) + } +} diff --git a/tests/scan_test.go b/tests/scan_test.go new file mode 100644 index 00000000..86cb0399 --- /dev/null +++ b/tests/scan_test.go @@ -0,0 +1,110 @@ +package tests_test + +import ( + "reflect" + "sort" + "strings" + "testing" + + "gorm.io/gorm" + . "gorm.io/gorm/utils/tests" +) + +func TestScan(t *testing.T) { + user1 := User{Name: "ScanUser1", Age: 1} + user2 := User{Name: "ScanUser2", Age: 10} + user3 := User{Name: "ScanUser3", Age: 20} + DB.Save(&user1).Save(&user2).Save(&user3) + + type result struct { + ID uint + Name string + Age int + } + + var res result + DB.Table("users").Select("id, name, age").Where("id = ?", user3.ID).Scan(&res) + if res.ID != user3.ID || res.Name != user3.Name || res.Age != int(user3.Age) { + t.Fatalf("Scan into struct should work, got %#v, should %#v", res, user3) + } + + var resPointer *result + DB.Table("users").Select("id, name, age").Where("id = ?", user3.ID).Scan(&resPointer) + if resPointer.ID != user3.ID || resPointer.Name != user3.Name || resPointer.Age != int(user3.Age) { + t.Fatalf("Scan into struct pointer should work, got %#v, should %#v", resPointer, user3) + } + + DB.Table("users").Select("id, name, age").Where("id = ?", user2.ID).Scan(&res) + if res.ID != user2.ID || res.Name != user2.Name || res.Age != int(user2.Age) { + t.Fatalf("Scan into struct should work, got %#v, should %#v", res, user2) + } + + DB.Model(&User{Model: gorm.Model{ID: user3.ID}}).Select("id, name, age").Scan(&res) + if res.ID != user3.ID || res.Name != user3.Name || res.Age != int(user3.Age) { + t.Fatalf("Scan into struct should work, got %#v, should %#v", res, user3) + } + + var doubleAgeRes = &result{} + if err := DB.Table("users").Select("age + age as age").Where("id = ?", user3.ID).Scan(&doubleAgeRes).Error; err != nil { + t.Errorf("Scan to pointer of pointer should work, got error %v", err) + } + + if doubleAgeRes.Age != int(res.Age)*2 { + t.Errorf("Scan double age as age, expect: %v, got %v", res.Age*2, doubleAgeRes.Age) + } + + var results []result + DB.Table("users").Select("name, age").Where("id in ?", []uint{user2.ID, user3.ID}).Scan(&results) + + sort.Slice(results, func(i, j int) bool { + return strings.Compare(results[i].Name, results[j].Name) <= -1 + }) + + if len(results) != 2 || results[0].Name != user2.Name || results[1].Name != user3.Name { + t.Errorf("Scan into slice of struct, got %#v", results) + } +} + +func TestScanRows(t *testing.T) { + user1 := User{Name: "ScanRowsUser1", Age: 1} + user2 := User{Name: "ScanRowsUser2", Age: 10} + user3 := User{Name: "ScanRowsUser3", Age: 20} + DB.Save(&user1).Save(&user2).Save(&user3) + + rows, err := DB.Table("users").Where("name = ? or name = ?", user2.Name, user3.Name).Select("name, age").Rows() + if err != nil { + t.Errorf("No error should happen, got %v", err) + } + + type Result struct { + Name string + Age int + } + + var results []Result + for rows.Next() { + var result Result + if err := DB.ScanRows(rows, &result); err != nil { + t.Errorf("should get no error, but got %v", err) + } + results = append(results, result) + } + + sort.Slice(results, func(i, j int) bool { + return strings.Compare(results[i].Name, results[j].Name) <= -1 + }) + + if !reflect.DeepEqual(results, []Result{{Name: "ScanRowsUser2", Age: 10}, {Name: "ScanRowsUser3", Age: 20}}) { + t.Errorf("Should find expected results") + } + + var ages int + if err := DB.Table("users").Where("name = ? 
or name = ?", user2.Name, user3.Name).Select("SUM(age)").Scan(&ages).Error; err != nil || ages != 30 { + t.Fatalf("failed to scan ages, got error %v, ages: %v", err, ages) + } + + var name string + if err := DB.Table("users").Where("name = ?", user2.Name).Select("name").Scan(&name).Error; err != nil || name != user2.Name { + t.Fatalf("failed to scan ages, got error %v, ages: %v", err, name) + } +} diff --git a/tests/scanner_valuer_test.go b/tests/scanner_valuer_test.go new file mode 100644 index 00000000..fb1f5791 --- /dev/null +++ b/tests/scanner_valuer_test.go @@ -0,0 +1,393 @@ +package tests_test + +import ( + "context" + "database/sql" + "database/sql/driver" + "encoding/json" + "errors" + "fmt" + "reflect" + "regexp" + "strconv" + "testing" + "time" + + "gorm.io/gorm" + "gorm.io/gorm/clause" + . "gorm.io/gorm/utils/tests" +) + +func TestScannerValuer(t *testing.T) { + DB.Migrator().DropTable(&ScannerValuerStruct{}) + if err := DB.Migrator().AutoMigrate(&ScannerValuerStruct{}); err != nil { + t.Fatalf("no error should happen when migrate scanner, valuer struct, got error %v", err) + } + + data := ScannerValuerStruct{ + Name: sql.NullString{String: "name", Valid: true}, + Gender: &sql.NullString{String: "M", Valid: true}, + Age: sql.NullInt64{Int64: 18, Valid: true}, + Male: sql.NullBool{Bool: true, Valid: true}, + Height: sql.NullFloat64{Float64: 1.8888, Valid: true}, + Birthday: sql.NullTime{Time: time.Now(), Valid: true}, + Allergen: NullString{sql.NullString{String: "Allergen", Valid: true}}, + Password: EncryptedData("pass1"), + Bytes: []byte("byte"), + Num: 18, + Strings: StringsSlice{"a", "b", "c"}, + Structs: StructsSlice{ + {"name1", "value1"}, + {"name2", "value2"}, + }, + Role: Role{Name: "admin"}, + ExampleStruct: ExampleStruct{"name", "value1"}, + ExampleStructPtr: &ExampleStruct{"name", "value2"}, + } + + if err := DB.Create(&data).Error; err != nil { + t.Fatalf("No error should happened when create scanner valuer struct, but got %v", err) + } + + var result ScannerValuerStruct + + if err := DB.Find(&result, "id = ?", data.ID).Error; err != nil { + t.Fatalf("no error should happen when query scanner, valuer struct, but got %v", err) + } + + if result.ExampleStructPtr.Val != "value2" { + t.Errorf(`ExampleStructPtr.Val should equal to "value2", but got %v`, result.ExampleStructPtr.Val) + } + + if result.ExampleStruct.Val != "value1" { + t.Errorf(`ExampleStruct.Val should equal to "value1", but got %#v`, result.ExampleStruct) + } + AssertObjEqual(t, data, result, "Name", "Gender", "Age", "Male", "Height", "Birthday", "Password", "Bytes", "Num", "Strings", "Structs") +} + +func TestScannerValuerWithFirstOrCreate(t *testing.T) { + DB.Migrator().DropTable(&ScannerValuerStruct{}) + if err := DB.Migrator().AutoMigrate(&ScannerValuerStruct{}); err != nil { + t.Errorf("no error should happen when migrate scanner, valuer struct") + } + + data := ScannerValuerStruct{ + Name: sql.NullString{String: "name", Valid: true}, + Gender: &sql.NullString{String: "M", Valid: true}, + Age: sql.NullInt64{Int64: 18, Valid: true}, + ExampleStruct: ExampleStruct{"name", "value1"}, + ExampleStructPtr: &ExampleStruct{"name", "value2"}, + } + + var result ScannerValuerStruct + tx := DB.Where(data).FirstOrCreate(&result) + + if tx.RowsAffected != 1 { + t.Errorf("RowsAffected should be 1 after create some record") + } + + if tx.Error != nil { + t.Errorf("Should not raise any error, but got %v", tx.Error) + } + + AssertObjEqual(t, result, data, "Name", "Gender", "Age") + + if err := 
DB.Where(data).Assign(ScannerValuerStruct{Age: sql.NullInt64{Int64: 18, Valid: true}}).FirstOrCreate(&result).Error; err != nil { + t.Errorf("Should not raise any error, but got %v", err) + } + + if result.Age.Int64 != 18 { + t.Errorf("should update age to 18") + } + + var result2 ScannerValuerStruct + if err := DB.First(&result2, result.ID).Error; err != nil { + t.Errorf("got error %v when query with %v", err, result.ID) + } + + AssertObjEqual(t, result2, result, "ID", "CreatedAt", "UpdatedAt", "Name", "Gender", "Age") +} + +func TestInvalidValuer(t *testing.T) { + DB.Migrator().DropTable(&ScannerValuerStruct{}) + if err := DB.Migrator().AutoMigrate(&ScannerValuerStruct{}); err != nil { + t.Errorf("no error should happen when migrate scanner, valuer struct") + } + + data := ScannerValuerStruct{ + Password: EncryptedData("xpass1"), + ExampleStruct: ExampleStruct{"name", "value1"}, + ExampleStructPtr: &ExampleStruct{"name", "value2"}, + } + + if err := DB.Create(&data).Error; err == nil { + t.Errorf("Should failed to create data with invalid data") + } + + data.Password = EncryptedData("pass1") + if err := DB.Create(&data).Error; err != nil { + t.Errorf("Should got no error when creating data, but got %v", err) + } + + if err := DB.Model(&data).Update("password", EncryptedData("xnewpass")).Error; err == nil { + t.Errorf("Should failed to update data with invalid data") + } + + if err := DB.Model(&data).Update("password", EncryptedData("newpass")).Error; err != nil { + t.Errorf("Should got no error update data with valid data, but got %v", err) + } + + AssertEqual(t, data.Password, EncryptedData("newpass")) +} + +type ScannerValuerStruct struct { + gorm.Model + Name sql.NullString + Gender *sql.NullString + Age sql.NullInt64 + Male sql.NullBool + Height sql.NullFloat64 + Birthday sql.NullTime + Allergen NullString + Password EncryptedData + Bytes []byte + Num Num + Strings StringsSlice + Structs StructsSlice + Role Role + UserID *sql.NullInt64 + User User + EmptyTime EmptyTime + ExampleStruct ExampleStruct + ExampleStructPtr *ExampleStruct +} + +type EncryptedData []byte + +func (data *EncryptedData) Scan(value interface{}) error { + if b, ok := value.([]byte); ok { + if len(b) < 3 || b[0] != '*' || b[1] != '*' || b[2] != '*' { + return errors.New("Too short") + } + + *data = b[3:] + return nil + } else if s, ok := value.(string); ok { + *data = []byte(s)[3:] + return nil + } + + return errors.New("Bytes expected") +} + +func (data EncryptedData) Value() (driver.Value, error) { + if len(data) > 0 && data[0] == 'x' { + //needed to test failures + return nil, errors.New("Should not start with 'x'") + } + + //prepend asterisks + return append([]byte("***"), data...), nil +} + +type Num int64 + +func (i *Num) Scan(src interface{}) error { + switch s := src.(type) { + case []byte: + n, _ := strconv.Atoi(string(s)) + *i = Num(n) + case int64: + *i = Num(s) + default: + return errors.New("Cannot scan NamedInt from " + reflect.ValueOf(src).String()) + } + return nil +} + +type StringsSlice []string + +func (l StringsSlice) Value() (driver.Value, error) { + bytes, err := json.Marshal(l) + return string(bytes), err +} + +func (l *StringsSlice) Scan(input interface{}) error { + switch value := input.(type) { + case string: + return json.Unmarshal([]byte(value), l) + case []byte: + return json.Unmarshal(value, l) + default: + return errors.New("not supported") + } +} + +type ExampleStruct struct { + Name string + Val string +} + +func (ExampleStruct) GormDataType() string { + return "bytes" +} + +func 
(s ExampleStruct) Value() (driver.Value, error) { + if len(s.Name) == 0 { + return nil, nil + } + // for test, has no practical meaning + s.Name = "" + return json.Marshal(s) +} + +func (s *ExampleStruct) Scan(src interface{}) error { + switch value := src.(type) { + case string: + return json.Unmarshal([]byte(value), s) + case []byte: + return json.Unmarshal(value, s) + default: + return errors.New("not supported") + } +} + +type StructsSlice []ExampleStruct + +func (l StructsSlice) Value() (driver.Value, error) { + bytes, err := json.Marshal(l) + return string(bytes), err +} + +func (l *StructsSlice) Scan(input interface{}) error { + switch value := input.(type) { + case string: + return json.Unmarshal([]byte(value), l) + case []byte: + return json.Unmarshal(value, l) + default: + return errors.New("not supported") + } +} + +type Role struct { + Name string `gorm:"size:256"` +} + +func (role *Role) Scan(value interface{}) error { + if b, ok := value.([]uint8); ok { + role.Name = string(b) + } else { + role.Name = value.(string) + } + return nil +} + +func (role Role) Value() (driver.Value, error) { + return role.Name, nil +} + +func (role Role) IsAdmin() bool { + return role.Name == "admin" +} + +type EmptyTime struct { + time.Time +} + +func (t *EmptyTime) Scan(v interface{}) error { + nullTime := sql.NullTime{} + err := nullTime.Scan(v) + t.Time = nullTime.Time + return err +} + +func (t EmptyTime) Value() (driver.Value, error) { + return time.Now() /* pass tests, mysql 8 doesn't support 0000-00-00 by default */, nil +} + +type NullString struct { + sql.NullString +} + +type Point struct { + X, Y int +} + +func (point Point) GormDataType() string { + return "geo" +} + +func (point Point) GormValue(ctx context.Context, db *gorm.DB) clause.Expr { + return clause.Expr{ + SQL: "ST_PointFromText(?)", + Vars: []interface{}{fmt.Sprintf("POINT(%d %d)", point.X, point.Y)}, + } +} + +func TestGORMValuer(t *testing.T) { + type UserWithPoint struct { + Name string + Point Point + } + + dryRunDB := DB.Session(&gorm.Session{DryRun: true}) + + stmt := dryRunDB.Create(&UserWithPoint{ + Name: "jinzhu", + Point: Point{X: 100, Y: 100}, + }).Statement + + if stmt.SQL.String() == "" || len(stmt.Vars) != 2 { + t.Errorf("Failed to generate sql, got %v", stmt.SQL.String()) + } + + if !regexp.MustCompile(`INSERT INTO .user_with_points. \(.name.,.point.\) VALUES \(.+,ST_PointFromText\(.+\)\)`).MatchString(stmt.SQL.String()) { + t.Errorf("insert with sql.Expr, but got %v", stmt.SQL.String()) + } + + if !reflect.DeepEqual([]interface{}{"jinzhu", "POINT(100 100)"}, stmt.Vars) { + t.Errorf("generated vars is not equal, got %v", stmt.Vars) + } + + stmt = dryRunDB.Model(UserWithPoint{}).Create(map[string]interface{}{ + "Name": "jinzhu", + "Point": clause.Expr{SQL: "ST_PointFromText(?)", Vars: []interface{}{"POINT(100 100)"}}, + }).Statement + + if !regexp.MustCompile(`INSERT INTO .user_with_points. \(.name.,.point.\) VALUES \(.+,ST_PointFromText\(.+\)\)`).MatchString(stmt.SQL.String()) { + t.Errorf("insert with sql.Expr, but got %v", stmt.SQL.String()) + } + + if !reflect.DeepEqual([]interface{}{"jinzhu", "POINT(100 100)"}, stmt.Vars) { + t.Errorf("generated vars is not equal, got %v", stmt.Vars) + } + + stmt = dryRunDB.Table("user_with_points").Create(&map[string]interface{}{ + "Name": "jinzhu", + "Point": clause.Expr{SQL: "ST_PointFromText(?)", Vars: []interface{}{"POINT(100 100)"}}, + }).Statement + + if !regexp.MustCompile(`INSERT INTO .user_with_points. 
\(.Name.,.Point.\) VALUES \(.+,ST_PointFromText\(.+\)\)`).MatchString(stmt.SQL.String()) { + t.Errorf("insert with sql.Expr, but got %v", stmt.SQL.String()) + } + + if !reflect.DeepEqual([]interface{}{"jinzhu", "POINT(100 100)"}, stmt.Vars) { + t.Errorf("generated vars is not equal, got %v", stmt.Vars) + } + + stmt = dryRunDB.Session(&gorm.Session{ + AllowGlobalUpdate: true, + }).Model(&UserWithPoint{}).Updates(UserWithPoint{ + Name: "jinzhu", + Point: Point{X: 100, Y: 100}, + }).Statement + + if !regexp.MustCompile(`UPDATE .user_with_points. SET .name.=.+,.point.=ST_PointFromText\(.+\)`).MatchString(stmt.SQL.String()) { + t.Errorf("update with sql.Expr, but got %v", stmt.SQL.String()) + } + + if !reflect.DeepEqual([]interface{}{"jinzhu", "POINT(100 100)"}, stmt.Vars) { + t.Errorf("generated vars is not equal, got %v", stmt.Vars) + } +} diff --git a/tests/scopes_test.go b/tests/scopes_test.go new file mode 100644 index 00000000..9836c41e --- /dev/null +++ b/tests/scopes_test.go @@ -0,0 +1,57 @@ +package tests_test + +import ( + "testing" + + "gorm.io/gorm" + . "gorm.io/gorm/utils/tests" +) + +func NameIn1And2(d *gorm.DB) *gorm.DB { + return d.Where("name in (?)", []string{"ScopeUser1", "ScopeUser2"}) +} + +func NameIn2And3(d *gorm.DB) *gorm.DB { + return d.Where("name in (?)", []string{"ScopeUser2", "ScopeUser3"}) +} + +func NameIn(names []string) func(d *gorm.DB) *gorm.DB { + return func(d *gorm.DB) *gorm.DB { + return d.Where("name in (?)", names) + } +} + +func TestScopes(t *testing.T) { + var users = []*User{ + GetUser("ScopeUser1", Config{}), + GetUser("ScopeUser2", Config{}), + GetUser("ScopeUser3", Config{}), + } + + DB.Create(&users) + + var users1, users2, users3 []User + DB.Scopes(NameIn1And2).Find(&users1) + if len(users1) != 2 { + t.Errorf("Should found two users's name in 1, 2, but got %v", len(users1)) + } + + DB.Scopes(NameIn1And2, NameIn2And3).Find(&users2) + if len(users2) != 1 { + t.Errorf("Should found one user's name is 2, but got %v", len(users2)) + } + + DB.Scopes(NameIn([]string{users[0].Name, users[2].Name})).Find(&users3) + if len(users3) != 2 { + t.Errorf("Should found two users's name in 1, 3, but got %v", len(users3)) + } + + db := DB.Scopes(func(tx *gorm.DB) *gorm.DB { + return tx.Table("custom_table") + }).Session(&gorm.Session{}) + + db.AutoMigrate(&User{}) + if db.Find(&User{}).Statement.Table != "custom_table" { + t.Errorf("failed to call Scopes") + } +} diff --git a/tests/soft_delete_test.go b/tests/soft_delete_test.go new file mode 100644 index 00000000..0dfe24d5 --- /dev/null +++ b/tests/soft_delete_test.go @@ -0,0 +1,85 @@ +package tests_test + +import ( + "database/sql" + "encoding/json" + "errors" + "regexp" + "testing" + + "gorm.io/gorm" + . 
"gorm.io/gorm/utils/tests" +) + +func TestSoftDelete(t *testing.T) { + user := *GetUser("SoftDelete", Config{}) + DB.Save(&user) + + var count int64 + var age uint + + if DB.Model(&User{}).Where("name = ?", user.Name).Count(&count).Error != nil || count != 1 { + t.Errorf("Count soft deleted record, expects: %v, got: %v", 1, count) + } + + if DB.Model(&User{}).Select("age").Where("name = ?", user.Name).Scan(&age).Error != nil || age != user.Age { + t.Errorf("Age soft deleted record, expects: %v, got: %v", 0, age) + } + + if err := DB.Delete(&user).Error; err != nil { + t.Fatalf("No error should happen when soft delete user, but got %v", err) + } + + if sql.NullTime(user.DeletedAt).Time.IsZero() { + t.Fatalf("user's deleted at is zero") + } + + sql := DB.Session(&gorm.Session{DryRun: true}).Delete(&user).Statement.SQL.String() + if !regexp.MustCompile(`UPDATE .users. SET .deleted_at.=.* WHERE .users.\..id. = .* AND .users.\..deleted_at. IS NULL`).MatchString(sql) { + t.Fatalf("invalid sql generated, got %v", sql) + } + + if DB.First(&User{}, "name = ?", user.Name).Error == nil { + t.Errorf("Can't find a soft deleted record") + } + + count = 0 + if DB.Model(&User{}).Where("name = ?", user.Name).Count(&count).Error != nil || count != 0 { + t.Errorf("Count soft deleted record, expects: %v, got: %v", 0, count) + } + + age = 0 + if DB.Model(&User{}).Select("age").Where("name = ?", user.Name).Scan(&age).Error != nil || age != 0 { + t.Errorf("Age soft deleted record, expects: %v, got: %v", 0, age) + } + + if err := DB.Unscoped().First(&User{}, "name = ?", user.Name).Error; err != nil { + t.Errorf("Should find soft deleted record with Unscoped, but got err %s", err) + } + + count = 0 + if DB.Unscoped().Model(&User{}).Where("name = ?", user.Name).Count(&count).Error != nil || count != 1 { + t.Errorf("Count soft deleted record, expects: %v, count: %v", 1, count) + } + + age = 0 + if DB.Unscoped().Model(&User{}).Select("age").Where("name = ?", user.Name).Scan(&age).Error != nil || age != user.Age { + t.Errorf("Age soft deleted record, expects: %v, got: %v", 0, age) + } + + DB.Unscoped().Delete(&user) + if err := DB.Unscoped().First(&User{}, "name = ?", user.Name).Error; !errors.Is(err, gorm.ErrRecordNotFound) { + t.Errorf("Can't find permanently deleted record") + } +} + +func TestDeletedAtUnMarshal(t *testing.T) { + expected := &gorm.Model{} + b, _ := json.Marshal(expected) + + result := &gorm.Model{} + _ = json.Unmarshal(b, result) + if result.DeletedAt != expected.DeletedAt { + t.Errorf("Failed, result.DeletedAt: %v is not same as expected.DeletedAt: %v", result.DeletedAt, expected.DeletedAt) + } +} diff --git a/tests/sql_builder_test.go b/tests/sql_builder_test.go new file mode 100644 index 00000000..081b96c9 --- /dev/null +++ b/tests/sql_builder_test.go @@ -0,0 +1,289 @@ +package tests_test + +import ( + "regexp" + "strings" + "testing" + + "gorm.io/gorm" + "gorm.io/gorm/clause" + . 
"gorm.io/gorm/utils/tests" +) + +func TestRow(t *testing.T) { + user1 := User{Name: "RowUser1", Age: 1} + user2 := User{Name: "RowUser2", Age: 10} + user3 := User{Name: "RowUser3", Age: 20} + DB.Save(&user1).Save(&user2).Save(&user3) + + row := DB.Table("users").Where("name = ?", user2.Name).Select("age").Row() + + var age int64 + if err := row.Scan(&age); err != nil { + t.Fatalf("Failed to scan age, got %v", err) + } + + if age != 10 { + t.Errorf("Scan with Row, age expects: %v, got %v", user2.Age, age) + } + + table := "gorm.users" + if DB.Dialector.Name() != "mysql" { + table = "users" // other databases doesn't support select with `database.table` + } + + DB.Table(table).Where(map[string]interface{}{"name": user2.Name}).Update("age", 20) + + row = DB.Table(table+" as u").Where("u.name = ?", user2.Name).Select("age").Row() + if err := row.Scan(&age); err != nil { + t.Fatalf("Failed to scan age, got %v", err) + } + + if age != 20 { + t.Errorf("Scan with Row, age expects: %v, got %v", user2.Age, age) + } +} + +func TestRows(t *testing.T) { + user1 := User{Name: "RowsUser1", Age: 1} + user2 := User{Name: "RowsUser2", Age: 10} + user3 := User{Name: "RowsUser3", Age: 20} + DB.Save(&user1).Save(&user2).Save(&user3) + + rows, err := DB.Table("users").Where("name = ? or name = ?", user2.Name, user3.Name).Select("name, age").Rows() + if err != nil { + t.Errorf("Not error should happen, got %v", err) + } + + count := 0 + for rows.Next() { + var name string + var age int64 + rows.Scan(&name, &age) + count++ + } + + if count != 2 { + t.Errorf("Should found two records") + } +} + +func TestRaw(t *testing.T) { + user1 := User{Name: "ExecRawSqlUser1", Age: 1} + user2 := User{Name: "ExecRawSqlUser2", Age: 10} + user3 := User{Name: "ExecRawSqlUser3", Age: 20} + DB.Save(&user1).Save(&user2).Save(&user3) + + type result struct { + Name string + Email string + } + + var results []result + DB.Raw("SELECT name, age FROM users WHERE name = ? or name = ?", user2.Name, user3.Name).Scan(&results) + if len(results) != 2 || results[0].Name != user2.Name || results[1].Name != user3.Name { + t.Errorf("Raw with scan") + } + + rows, _ := DB.Raw("select name, age from users where name = ?", user3.Name).Rows() + count := 0 + for rows.Next() { + count++ + } + if count != 1 { + t.Errorf("Raw with Rows should find one record with name 3") + } + + DB.Exec("update users set name=? where name in (?)", "jinzhu-raw", []string{user1.Name, user2.Name, user3.Name}) + if DB.Where("name in (?)", []string{user1.Name, user2.Name, user3.Name}).First(&User{}).Error != gorm.ErrRecordNotFound { + t.Error("Raw sql to update records") + } + + DB.Exec("update users set age=? where name = ?", gorm.Expr("age * ? 
+ ?", 2, 10), "jinzhu-raw") + + var age int + DB.Raw("select sum(age) from users where name = ?", "jinzhu-raw").Scan(&age) + + if age != ((1+10+20)*2 + 30) { + t.Errorf("Invalid age, got %v", age) + } +} + +func TestRowsWithGroup(t *testing.T) { + users := []User{ + {Name: "having_user_1", Age: 1}, + {Name: "having_user_2", Age: 10}, + {Name: "having_user_1", Age: 20}, + {Name: "having_user_1", Age: 30}, + } + + DB.Create(&users) + + rows, err := DB.Select("name, count(*) as total").Table("users").Group("name").Having("name IN ?", []string{users[0].Name, users[1].Name}).Rows() + if err != nil { + t.Fatalf("got error %v", err) + } + + defer rows.Close() + for rows.Next() { + var name string + var total int64 + rows.Scan(&name, &total) + + if name == users[0].Name && total != 3 { + t.Errorf("Should have one user having name %v", users[0].Name) + } else if name == users[1].Name && total != 1 { + t.Errorf("Should have two users having name %v", users[1].Name) + } + } +} + +func TestQueryRaw(t *testing.T) { + users := []*User{ + GetUser("row_query_user", Config{}), + GetUser("row_query_user", Config{}), + GetUser("row_query_user", Config{}), + } + DB.Create(&users) + + var user User + DB.Raw("select * from users WHERE id = ?", users[1].ID).First(&user) + CheckUser(t, user, *users[1]) +} + +func TestDryRun(t *testing.T) { + user := *GetUser("dry-run", Config{}) + + dryRunDB := DB.Session(&gorm.Session{DryRun: true}) + + stmt := dryRunDB.Create(&user).Statement + if stmt.SQL.String() == "" || len(stmt.Vars) != 9 { + t.Errorf("Failed to generate sql, got %v", stmt.SQL.String()) + } + + stmt2 := dryRunDB.Find(&user, "id = ?", user.ID).Statement + if stmt2.SQL.String() == "" || len(stmt2.Vars) != 1 { + t.Errorf("Failed to generate sql, got %v", stmt2.SQL.String()) + } +} + +func TestGroupConditions(t *testing.T) { + type Pizza struct { + ID uint + Name string + Size string + } + dryRunDB := DB.Session(&gorm.Session{DryRun: true}) + + stmt := dryRunDB.Where( + DB.Where("pizza = ?", "pepperoni").Where(DB.Where("size = ?", "small").Or("size = ?", "medium")), + ).Or( + DB.Where("pizza = ?", "hawaiian").Where("size = ?", "xlarge"), + ).Find(&Pizza{}).Statement + + execStmt := dryRunDB.Exec("WHERE (pizza = ? AND (size = ? OR size = ?)) OR (pizza = ? AND size = ?)", "pepperoni", "small", "medium", "hawaiian", "xlarge").Statement + + result := DB.Dialector.Explain(stmt.SQL.String(), stmt.Vars...) + expects := DB.Dialector.Explain(execStmt.SQL.String(), execStmt.Vars...) + + if !strings.HasSuffix(result, expects) { + t.Errorf("expects: %v, got %v", expects, result) + } +} + +func TestCombineStringConditions(t *testing.T) { + dryRunDB := DB.Session(&gorm.Session{DryRun: true}) + sql := dryRunDB.Where("a = ? or b = ?", "a", "b").Find(&User{}).Statement.SQL.String() + if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) { + t.Fatalf("invalid sql generated, got %v", sql) + } + + sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Or("c = ? and d = ?", "c", "d").Find(&User{}).Statement.SQL.String() + if !regexp.MustCompile(`WHERE \(\(a = .+ or b = .+\) OR \(c = .+ and d = .+\)\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) { + t.Fatalf("invalid sql generated, got %v", sql) + } + + sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Or("c = ?", "c").Find(&User{}).Statement.SQL.String() + if !regexp.MustCompile(`WHERE \(\(a = .+ or b = .+\) OR c = .+\) AND .users.\..deleted_at. 
IS NULL`).MatchString(sql) { + t.Fatalf("invalid sql generated, got %v", sql) + } + + sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Or("c = ? and d = ?", "c", "d").Or("e = ? and f = ?", "e", "f").Find(&User{}).Statement.SQL.String() + if !regexp.MustCompile(`WHERE \(\(a = .+ or b = .+\) OR \(c = .+ and d = .+\) OR \(e = .+ and f = .+\)\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) { + t.Fatalf("invalid sql generated, got %v", sql) + } + + sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Where("c = ? and d = ?", "c", "d").Not("e = ? and f = ?", "e", "f").Find(&User{}).Statement.SQL.String() + if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) AND \(c = .+ and d = .+\) AND NOT \(e = .+ and f = .+\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) { + t.Fatalf("invalid sql generated, got %v", sql) + } + + sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Where("c = ?", "c").Not("e = ? and f = ?", "e", "f").Find(&User{}).Statement.SQL.String() + if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) AND c = .+ AND NOT \(e = .+ and f = .+\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) { + t.Fatalf("invalid sql generated, got %v", sql) + } + + sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Where("c = ? and d = ?", "c", "d").Not("e = ?", "e").Find(&User{}).Statement.SQL.String() + if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) AND \(c = .+ and d = .+\) AND NOT e = .+ AND .users.\..deleted_at. IS NULL`).MatchString(sql) { + t.Fatalf("invalid sql generated, got %v", sql) + } + + sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Unscoped().Find(&User{}).Statement.SQL.String() + if !regexp.MustCompile(`WHERE a = .+ or b = .+$`).MatchString(sql) { + t.Fatalf("invalid sql generated, got %v", sql) + } + + sql = dryRunDB.Or("a = ? or b = ?", "a", "b").Unscoped().Find(&User{}).Statement.SQL.String() + if !regexp.MustCompile(`WHERE a = .+ or b = .+$`).MatchString(sql) { + t.Fatalf("invalid sql generated, got %v", sql) + } + + sql = dryRunDB.Not("a = ? or b = ?", "a", "b").Unscoped().Find(&User{}).Statement.SQL.String() + if !regexp.MustCompile(`WHERE NOT \(a = .+ or b = .+\)$`).MatchString(sql) { + t.Fatalf("invalid sql generated, got %v", sql) + } +} + +func TestFromWithJoins(t *testing.T) { + var result User + + newDB := DB.Session(&gorm.Session{NewDB: true, DryRun: true}).Table("users") + + newDB.Clauses( + clause.From{ + Tables: []clause.Table{{Name: "users"}}, + Joins: []clause.Join{ + { + Table: clause.Table{Name: "companies", Raw: false}, + ON: clause.Where{ + Exprs: []clause.Expression{ + clause.Eq{ + Column: clause.Column{ + Table: "users", + Name: "company_id", + }, + Value: clause.Column{ + Table: "companies", + Name: "id", + }, + }, + }, + }, + }, + }, + }, + ) + + newDB.Joins("inner join rgs on rgs.id = user.id") + + stmt := newDB.First(&result).Statement + str := stmt.SQL.String() + + if !strings.Contains(str, "rgs.id = user.id") { + t.Errorf("The second join condition is over written instead of combining") + } + + if !strings.Contains(str, "`users`.`company_id` = `companies`.`id`") && !strings.Contains(str, "\"users\".\"company_id\" = \"companies\".\"id\"") { + t.Errorf("The first join condition is over written instead of combining") + } +} diff --git a/tests/table_test.go b/tests/table_test.go new file mode 100644 index 00000000..0c6b3eb0 --- /dev/null +++ b/tests/table_test.go @@ -0,0 +1,127 @@ +package tests_test + +import ( + "regexp" + "testing" + + "gorm.io/gorm" + . 
"gorm.io/gorm/utils/tests" +) + +type UserWithTable struct { + gorm.Model + Name string +} + +func (UserWithTable) TableName() string { + return "gorm.user" +} + +func TestTable(t *testing.T) { + dryDB := DB.Session(&gorm.Session{DryRun: true}) + + r := dryDB.Table("`user`").Find(&User{}).Statement + if !regexp.MustCompile("SELECT \\* FROM `user`").MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + + r = dryDB.Table("user as u").Select("name").Find(&User{}).Statement + if !regexp.MustCompile("SELECT .name. FROM user as u WHERE .u.\\..deleted_at. IS NULL").MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + + r = dryDB.Table("gorm.user").Select("name").Find(&User{}).Statement + if !regexp.MustCompile("SELECT .name. FROM .gorm.\\..user. WHERE .user.\\..deleted_at. IS NULL").MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + + r = dryDB.Select("name").Find(&UserWithTable{}).Statement + if !regexp.MustCompile("SELECT .name. FROM .gorm.\\..user. WHERE .user.\\..deleted_at. IS NULL").MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + + r = dryDB.Create(&UserWithTable{}).Statement + if DB.Dialector.Name() != "sqlite" { + if !regexp.MustCompile(`INSERT INTO .gorm.\..user. (.*name.*) VALUES (.*)`).MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + } else { + if !regexp.MustCompile(`INSERT INTO .user. (.*name.*) VALUES (.*)`).MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + } + + r = dryDB.Table("(?) as u", DB.Model(&User{}).Select("name")).Find(&User{}).Statement + if !regexp.MustCompile("SELECT \\* FROM \\(SELECT .name. FROM .users. WHERE .users.\\..deleted_at. IS NULL\\) as u WHERE .u.\\..deleted_at. IS NULL").MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + + r = dryDB.Table("(?) as u, (?) as p", DB.Model(&User{}).Select("name"), DB.Model(&Pet{}).Select("name")).Find(&User{}).Statement + if !regexp.MustCompile("SELECT \\* FROM \\(SELECT .name. FROM .users. WHERE .users.\\..deleted_at. IS NULL\\) as u, \\(SELECT .name. FROM .pets. WHERE .pets.\\..deleted_at. IS NULL\\) as p WHERE .u.\\..deleted_at. IS NULL").MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + + r = dryDB.Where("name = ?", 1).Table("(?) as u, (?) as p", DB.Model(&User{}).Select("name").Where("name = ?", 2), DB.Model(&Pet{}).Where("name = ?", 4).Select("name")).Where("name = ?", 3).Find(&User{}).Statement + if !regexp.MustCompile("SELECT \\* FROM \\(SELECT .name. FROM .users. WHERE name = .+ AND .users.\\..deleted_at. IS NULL\\) as u, \\(SELECT .name. FROM .pets. WHERE name = .+ AND .pets.\\..deleted_at. IS NULL\\) as p WHERE name = .+ AND name = .+ AND .u.\\..deleted_at. 
IS NULL").MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + + AssertEqual(t, r.Statement.Vars, []interface{}{2, 4, 1, 3}) +} + +func TestTableWithAllFields(t *testing.T) { + dryDB := DB.Session(&gorm.Session{DryRun: true, QueryFields: true}) + userQuery := "SELECT .*user.*id.*user.*created_at.*user.*updated_at.*user.*deleted_at.*user.*name.*user.*age" + + ".*user.*birthday.*user.*company_id.*user.*manager_id.*user.*active.* " + + r := dryDB.Table("`user`").Find(&User{}).Statement + if !regexp.MustCompile(userQuery + "FROM `user`").MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + + r = dryDB.Table("user as u").Select("name").Find(&User{}).Statement + if !regexp.MustCompile("SELECT .name. FROM user as u WHERE .u.\\..deleted_at. IS NULL").MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + + r = dryDB.Table("gorm.user").Select("name").Find(&User{}).Statement + if !regexp.MustCompile("SELECT .name. FROM .gorm.\\..user. WHERE .user.\\..deleted_at. IS NULL").MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + + r = dryDB.Select("name").Find(&UserWithTable{}).Statement + if !regexp.MustCompile("SELECT .name. FROM .gorm.\\..user. WHERE .user.\\..deleted_at. IS NULL").MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + + r = dryDB.Create(&UserWithTable{}).Statement + if DB.Dialector.Name() != "sqlite" { + if !regexp.MustCompile(`INSERT INTO .gorm.\..user. (.*name.*) VALUES (.*)`).MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + } else { + if !regexp.MustCompile(`INSERT INTO .user. (.*name.*) VALUES (.*)`).MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + } + + userQueryCharacter := "SELECT .*u.*id.*u.*created_at.*u.*updated_at.*u.*deleted_at.*u.*name.*u.*age.*u.*birthday" + + ".*u.*company_id.*u.*manager_id.*u.*active.* " + + r = dryDB.Table("(?) as u", DB.Model(&User{}).Select("name")).Find(&User{}).Statement + if !regexp.MustCompile(userQueryCharacter + "FROM \\(SELECT .name. FROM .users. WHERE .users.\\..deleted_at. IS NULL\\) as u WHERE .u.\\..deleted_at. IS NULL").MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + + r = dryDB.Table("(?) as u, (?) as p", DB.Model(&User{}).Select("name"), DB.Model(&Pet{}).Select("name")).Find(&User{}).Statement + if !regexp.MustCompile(userQueryCharacter + "FROM \\(SELECT .name. FROM .users. WHERE .users.\\..deleted_at. IS NULL\\) as u, \\(SELECT .name. FROM .pets. WHERE .pets.\\..deleted_at. IS NULL\\) as p WHERE .u.\\..deleted_at. IS NULL").MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + + r = dryDB.Where("name = ?", 1).Table("(?) as u, (?) as p", DB.Model(&User{}).Select("name").Where("name = ?", 2), DB.Model(&Pet{}).Where("name = ?", 4).Select("name")).Where("name = ?", 3).Find(&User{}).Statement + if !regexp.MustCompile(userQueryCharacter + "FROM \\(SELECT .name. FROM .users. WHERE name = .+ AND .users.\\..deleted_at. IS NULL\\) as u, \\(SELECT .name. FROM .pets. WHERE name = .+ AND .pets.\\..deleted_at. 
IS NULL\\) as p WHERE name = .+ AND name = .+ AND .u.\\..deleted_at. IS NULL").MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + + AssertEqual(t, r.Statement.Vars, []interface{}{2, 4, 1, 3}) +} diff --git a/tests/tests_all.sh b/tests/tests_all.sh new file mode 100755 index 00000000..e0ed97a4 --- /dev/null +++ b/tests/tests_all.sh @@ -0,0 +1,41 @@ +#!/bin/bash -e + +dialects=("sqlite" "mysql" "postgres" "sqlserver") + +if [[ $(pwd) == *"gorm/tests"* ]]; then + cd .. +fi + +if [ -d tests ] +then + cd tests + go get -u ./... + go mod download + cd .. +fi + +for dialect in "${dialects[@]}" ; do + if [ "$GORM_DIALECT" = "" ] || [ "$GORM_DIALECT" = "${dialect}" ] + then + echo "testing ${dialect}..." + + if [ "$GORM_VERBOSE" = "" ] + then + GORM_DIALECT=${dialect} go test -race -count=1 ./... + if [ -d tests ] + then + cd tests + GORM_DIALECT=${dialect} go test -race -count=1 ./... + cd .. + fi + else + GORM_DIALECT=${dialect} go test -race -count=1 -v ./... + if [ -d tests ] + then + cd tests + GORM_DIALECT=${dialect} go test -race -count=1 -v ./... + cd .. + fi + fi + fi +done diff --git a/tests/tests_test.go b/tests/tests_test.go new file mode 100644 index 00000000..cb73d267 --- /dev/null +++ b/tests/tests_test.go @@ -0,0 +1,112 @@ +package tests_test + +import ( + "log" + "math/rand" + "os" + "path/filepath" + "time" + + "gorm.io/driver/mysql" + "gorm.io/driver/postgres" + "gorm.io/driver/sqlite" + "gorm.io/driver/sqlserver" + "gorm.io/gorm" + "gorm.io/gorm/logger" + . "gorm.io/gorm/utils/tests" +) + +var DB *gorm.DB + +func init() { + var err error + if DB, err = OpenTestConnection(); err != nil { + log.Printf("failed to connect database, got error %v", err) + os.Exit(1) + } else { + sqlDB, err := DB.DB() + if err == nil { + err = sqlDB.Ping() + } + + if err != nil { + log.Printf("failed to connect database, got error %v", err) + } + + RunMigrations() + if DB.Dialector.Name() == "sqlite" { + DB.Exec("PRAGMA foreign_keys = ON") + } + } +} + +func OpenTestConnection() (db *gorm.DB, err error) { + dbDSN := os.Getenv("GORM_DSN") + switch os.Getenv("GORM_DIALECT") { + case "mysql": + log.Println("testing mysql...") + if dbDSN == "" { + dbDSN = "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True&loc=Local" + } + db, err = gorm.Open(mysql.Open(dbDSN), &gorm.Config{}) + case "postgres": + log.Println("testing postgres...") + if dbDSN == "" { + dbDSN = "user=gorm password=gorm dbname=gorm host=localhost port=9920 sslmode=disable TimeZone=Asia/Shanghai" + } + db, err = gorm.Open(postgres.New(postgres.Config{ + DSN: dbDSN, + PreferSimpleProtocol: true, + }), &gorm.Config{}) + case "sqlserver": + // CREATE LOGIN gorm WITH PASSWORD = 'LoremIpsum86'; + // CREATE DATABASE gorm; + // USE gorm; + // CREATE USER gorm FROM LOGIN gorm; + // sp_changedbowner 'gorm'; + // npm install -g sql-cli + // mssql -u gorm -p LoremIpsum86 -d gorm -o 9930 + log.Println("testing sqlserver...") + if dbDSN == "" { + dbDSN = "sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm" + } + db, err = gorm.Open(sqlserver.Open(dbDSN), &gorm.Config{}) + default: + log.Println("testing sqlite3...") + db, err = gorm.Open(sqlite.Open(filepath.Join(os.TempDir(), "gorm.db")), &gorm.Config{}) + } + + if debug := os.Getenv("DEBUG"); debug == "true" { + db.Logger = db.Logger.LogMode(logger.Info) + } else if debug == "false" { + db.Logger = db.Logger.LogMode(logger.Silent) + } + + return +} + +func RunMigrations() { + var err error + allModels := 
[]interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}} + rand.Seed(time.Now().UnixNano()) + rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] }) + + DB.Migrator().DropTable("user_friends", "user_speaks") + + if err = DB.Migrator().DropTable(allModels...); err != nil { + log.Printf("Failed to drop table, got error %v\n", err) + os.Exit(1) + } + + if err = DB.AutoMigrate(allModels...); err != nil { + log.Printf("Failed to auto migrate, but got error %v\n", err) + os.Exit(1) + } + + for _, m := range allModels { + if !DB.Migrator().HasTable(m) { + log.Printf("Failed to create table for %#v\n", m) + os.Exit(1) + } + } +} diff --git a/tests/transaction_test.go b/tests/transaction_test.go new file mode 100644 index 00000000..4e4b6149 --- /dev/null +++ b/tests/transaction_test.go @@ -0,0 +1,369 @@ +package tests_test + +import ( + "context" + "errors" + "testing" + + "gorm.io/gorm" + . "gorm.io/gorm/utils/tests" +) + +func TestTransaction(t *testing.T) { + tx := DB.Begin() + user := *GetUser("transaction", Config{}) + + if err := tx.Save(&user).Error; err != nil { + t.Fatalf("No error should raise, but got %v", err) + } + + if err := tx.First(&User{}, "name = ?", "transaction").Error; err != nil { + t.Fatalf("Should find saved record, but got %v", err) + } + + user1 := *GetUser("transaction1-1", Config{}) + + if err := tx.Save(&user1).Error; err != nil { + t.Fatalf("No error should raise, but got %v", err) + } + + if err := tx.First(&User{}, "name = ?", user1.Name).Error; err != nil { + t.Fatalf("Should find saved record, but got %v", err) + } + + if sqlTx, ok := tx.Statement.ConnPool.(gorm.TxCommitter); !ok || sqlTx == nil { + t.Fatalf("Should return the underlying sql.Tx") + } + + tx.Rollback() + + if err := DB.First(&User{}, "name = ?", "transaction").Error; err == nil { + t.Fatalf("Should not find record after rollback, but got %v", err) + } + + txDB := DB.Where("fake_name = ?", "fake_name") + tx2 := txDB.Session(&gorm.Session{NewDB: true}).Begin() + user2 := *GetUser("transaction-2", Config{}) + if err := tx2.Save(&user2).Error; err != nil { + t.Fatalf("No error should raise, but got %v", err) + } + + if err := tx2.First(&User{}, "name = ?", "transaction-2").Error; err != nil { + t.Fatalf("Should find saved record, but got %v", err) + } + + tx2.Commit() + + if err := DB.First(&User{}, "name = ?", "transaction-2").Error; err != nil { + t.Fatalf("Should be able to find committed record, but got %v", err) + } +} + +func TestCancelTransaction(t *testing.T) { + ctx := context.Background() + ctx, cancelFunc := context.WithCancel(ctx) + cancelFunc() + + user := *GetUser("cancel_transaction", Config{}) + DB.Create(&user) + + err := DB.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + var result User + tx.First(&result, user.ID) + return nil + }) + + if err == nil { + t.Fatalf("Transaction should get error when using cancelled context") + } +} + +func TestTransactionWithBlock(t *testing.T) { + assertPanic := func(f func()) { + defer func() { + if r := recover(); r == nil { + t.Fatalf("The code did not panic") + } + }() + f() + } + + // rollback + err := DB.Transaction(func(tx *gorm.DB) error { + user := *GetUser("transaction-block", Config{}) + if err := tx.Save(&user).Error; err != nil { + t.Fatalf("No error should raise") + } + + if err := tx.First(&User{}, "name = ?", user.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } + + return errors.New("the error message") + }) + + if err.Error() != "the 
error message" { + t.Fatalf("Transaction return error will equal the block returns error") + } + + if err := DB.First(&User{}, "name = ?", "transaction-block").Error; err == nil { + t.Fatalf("Should not find record after rollback") + } + + // commit + DB.Transaction(func(tx *gorm.DB) error { + user := *GetUser("transaction-block-2", Config{}) + if err := tx.Save(&user).Error; err != nil { + t.Fatalf("No error should raise") + } + + if err := tx.First(&User{}, "name = ?", user.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } + return nil + }) + + if err := DB.First(&User{}, "name = ?", "transaction-block-2").Error; err != nil { + t.Fatalf("Should be able to find committed record") + } + + // panic will rollback + assertPanic(func() { + DB.Transaction(func(tx *gorm.DB) error { + user := *GetUser("transaction-block-3", Config{}) + if err := tx.Save(&user).Error; err != nil { + t.Fatalf("No error should raise") + } + + if err := tx.First(&User{}, "name = ?", user.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } + + panic("force panic") + }) + }) + + if err := DB.First(&User{}, "name = ?", "transaction-block-3").Error; err == nil { + t.Fatalf("Should not find record after panic rollback") + } +} + +func TestTransactionRaiseErrorOnRollbackAfterCommit(t *testing.T) { + tx := DB.Begin() + user := User{Name: "transaction"} + if err := tx.Save(&user).Error; err != nil { + t.Fatalf("No error should raise") + } + + if err := tx.Commit().Error; err != nil { + t.Fatalf("Commit should not raise error") + } + + if err := tx.Rollback().Error; err == nil { + t.Fatalf("Rollback after commit should raise error") + } +} + +func TestTransactionWithSavePoint(t *testing.T) { + tx := DB.Begin() + + user := *GetUser("transaction-save-point", Config{}) + tx.Create(&user) + + if err := tx.First(&User{}, "name = ?", user.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } + + if err := tx.SavePoint("save_point1").Error; err != nil { + t.Fatalf("Failed to save point, got error %v", err) + } + + user1 := *GetUser("transaction-save-point-1", Config{}) + tx.Create(&user1) + + if err := tx.First(&User{}, "name = ?", user1.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } + + if err := tx.RollbackTo("save_point1").Error; err != nil { + t.Fatalf("Failed to save point, got error %v", err) + } + + if err := tx.First(&User{}, "name = ?", user1.Name).Error; err == nil { + t.Fatalf("Should not find rollbacked record") + } + + if err := tx.SavePoint("save_point2").Error; err != nil { + t.Fatalf("Failed to save point, got error %v", err) + } + + user2 := *GetUser("transaction-save-point-2", Config{}) + tx.Create(&user2) + + if err := tx.First(&User{}, "name = ?", user2.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } + + if err := tx.Commit().Error; err != nil { + t.Fatalf("Failed to commit, got error %v", err) + } + + if err := DB.First(&User{}, "name = ?", user.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } + + if err := DB.First(&User{}, "name = ?", user1.Name).Error; err == nil { + t.Fatalf("Should not find rollbacked record") + } + + if err := DB.First(&User{}, "name = ?", user2.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } +} + +func TestNestedTransactionWithBlock(t *testing.T) { + var ( + user = *GetUser("transaction-nested", Config{}) + user1 = *GetUser("transaction-nested-1", Config{}) + user2 = *GetUser("transaction-nested-2", Config{}) + ) + + if err := DB.Transaction(func(tx 
*gorm.DB) error { + tx.Create(&user) + + if err := tx.First(&User{}, "name = ?", user.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } + + if err := tx.Transaction(func(tx1 *gorm.DB) error { + tx1.Create(&user1) + + if err := tx1.First(&User{}, "name = ?", user1.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } + + return errors.New("rollback") + }); err == nil { + t.Fatalf("nested transaction should returns error") + } + + if err := tx.First(&User{}, "name = ?", user1.Name).Error; err == nil { + t.Fatalf("Should not find rollbacked record") + } + + if err := tx.Transaction(func(tx2 *gorm.DB) error { + tx2.Create(&user2) + + if err := tx2.First(&User{}, "name = ?", user2.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } + + return nil + }); err != nil { + t.Fatalf("nested transaction returns error: %v", err) + } + + if err := tx.First(&User{}, "name = ?", user2.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } + return nil + }); err != nil { + t.Fatalf("no error should return, but got %v", err) + } + + if err := DB.First(&User{}, "name = ?", user.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } + + if err := DB.First(&User{}, "name = ?", user1.Name).Error; err == nil { + t.Fatalf("Should not find rollbacked record") + } + + if err := DB.First(&User{}, "name = ?", user2.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } +} + +func TestDisabledNestedTransaction(t *testing.T) { + var ( + user = *GetUser("transaction-nested", Config{}) + user1 = *GetUser("transaction-nested-1", Config{}) + user2 = *GetUser("transaction-nested-2", Config{}) + ) + + if err := DB.Session(&gorm.Session{DisableNestedTransaction: true}).Transaction(func(tx *gorm.DB) error { + tx.Create(&user) + + if err := tx.First(&User{}, "name = ?", user.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } + + if err := tx.Transaction(func(tx1 *gorm.DB) error { + tx1.Create(&user1) + + if err := tx1.First(&User{}, "name = ?", user1.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } + + return errors.New("rollback") + }); err == nil { + t.Fatalf("nested transaction should returns error") + } + + if err := tx.First(&User{}, "name = ?", user1.Name).Error; err != nil { + t.Fatalf("Should not rollback record if disabled nested transaction support") + } + + if err := tx.Transaction(func(tx2 *gorm.DB) error { + tx2.Create(&user2) + + if err := tx2.First(&User{}, "name = ?", user2.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } + + return nil + }); err != nil { + t.Fatalf("nested transaction returns error: %v", err) + } + + if err := tx.First(&User{}, "name = ?", user2.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } + return nil + }); err != nil { + t.Fatalf("no error should return, but got %v", err) + } + + if err := DB.First(&User{}, "name = ?", user.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } + + if err := DB.First(&User{}, "name = ?", user1.Name).Error; err != nil { + t.Fatalf("Should not rollback record if disabled nested transaction support") + } + + if err := DB.First(&User{}, "name = ?", user2.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } +} + +func TestTransactionOnClosedConn(t *testing.T) { + DB, err := OpenTestConnection() + if err != nil { + t.Fatalf("failed to connect database, got error %v", err) + } + rawDB, _ := DB.DB() + rawDB.Close() + + if err := DB.Transaction(func(tx *gorm.DB) error 
{ + return nil + }); err == nil { + t.Errorf("should returns error when commit with closed conn, got error %v", err) + } + + if err := DB.Session(&gorm.Session{PrepareStmt: true}).Transaction(func(tx *gorm.DB) error { + return nil + }); err == nil { + t.Errorf("should returns error when commit with closed conn, got error %v", err) + } +} diff --git a/tests/update_belongs_to_test.go b/tests/update_belongs_to_test.go new file mode 100644 index 00000000..736dfc5b --- /dev/null +++ b/tests/update_belongs_to_test.go @@ -0,0 +1,44 @@ +package tests_test + +import ( + "testing" + + "gorm.io/gorm" + . "gorm.io/gorm/utils/tests" +) + +func TestUpdateBelongsTo(t *testing.T) { + var user = *GetUser("update-belongs-to", Config{}) + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + user.Company = Company{Name: "company-belongs-to-association"} + user.Manager = &User{Name: "manager-belongs-to-association"} + if err := DB.Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user2 User + DB.Preload("Company").Preload("Manager").Find(&user2, "id = ?", user.ID) + CheckUser(t, user2, user) + + user.Company.Name += "new" + user.Manager.Name += "new" + if err := DB.Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user3 User + DB.Preload("Company").Preload("Manager").Find(&user3, "id = ?", user.ID) + CheckUser(t, user2, user3) + + if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user4 User + DB.Preload("Company").Preload("Manager").Find(&user4, "id = ?", user.ID) + CheckUser(t, user4, user) +} diff --git a/tests/update_has_many_test.go b/tests/update_has_many_test.go new file mode 100644 index 00000000..9066cbac --- /dev/null +++ b/tests/update_has_many_test.go @@ -0,0 +1,82 @@ +package tests_test + +import ( + "testing" + + "gorm.io/gorm" + . 
"gorm.io/gorm/utils/tests" +) + +func TestUpdateHasManyAssociations(t *testing.T) { + var user = *GetUser("update-has-many", Config{}) + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + user.Pets = []*Pet{{Name: "pet1"}, {Name: "pet2"}} + if err := DB.Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user2 User + DB.Preload("Pets").Find(&user2, "id = ?", user.ID) + CheckUser(t, user2, user) + + for _, pet := range user.Pets { + pet.Name += "new" + } + + if err := DB.Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user3 User + DB.Preload("Pets").Find(&user3, "id = ?", user.ID) + CheckUser(t, user2, user3) + + if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user4 User + DB.Preload("Pets").Find(&user4, "id = ?", user.ID) + CheckUser(t, user4, user) + + t.Run("Polymorphic", func(t *testing.T) { + var user = *GetUser("update-has-many", Config{}) + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + user.Toys = []Toy{{Name: "toy1"}, {Name: "toy2"}} + if err := DB.Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user2 User + DB.Preload("Toys").Find(&user2, "id = ?", user.ID) + CheckUser(t, user2, user) + + for idx := range user.Toys { + user.Toys[idx].Name += "new" + } + + if err := DB.Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user3 User + DB.Preload("Toys").Find(&user3, "id = ?", user.ID) + CheckUser(t, user2, user3) + + if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user4 User + DB.Preload("Toys").Find(&user4, "id = ?", user.ID) + CheckUser(t, user4, user) + }) +} diff --git a/tests/update_has_one_test.go b/tests/update_has_one_test.go new file mode 100644 index 00000000..a61629f8 --- /dev/null +++ b/tests/update_has_one_test.go @@ -0,0 +1,88 @@ +package tests_test + +import ( + "testing" + "time" + + "gorm.io/gorm" + . 
"gorm.io/gorm/utils/tests" +) + +func TestUpdateHasOne(t *testing.T) { + var user = *GetUser("update-has-one", Config{}) + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + user.Account = Account{Number: "account-has-one-association"} + + if err := DB.Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user2 User + DB.Preload("Account").Find(&user2, "id = ?", user.ID) + CheckUser(t, user2, user) + + user.Account.Number += "new" + if err := DB.Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user3 User + DB.Preload("Account").Find(&user3, "id = ?", user.ID) + + CheckUser(t, user2, user3) + var lastUpdatedAt = user2.Account.UpdatedAt + time.Sleep(time.Second) + + if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user4 User + DB.Preload("Account").Find(&user4, "id = ?", user.ID) + + if lastUpdatedAt.Format(time.RFC3339) == user4.Account.UpdatedAt.Format(time.RFC3339) { + t.Fatalf("updated at should be updated, but not, old: %v, new %v", lastUpdatedAt.Format(time.RFC3339), user3.Account.UpdatedAt.Format(time.RFC3339)) + } else { + user.Account.UpdatedAt = user4.Account.UpdatedAt + CheckUser(t, user4, user) + } + + t.Run("Polymorphic", func(t *testing.T) { + var pet = Pet{Name: "create"} + + if err := DB.Create(&pet).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + pet.Toy = Toy{Name: "Update-HasOneAssociation-Polymorphic"} + + if err := DB.Save(&pet).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + var pet2 Pet + DB.Preload("Toy").Find(&pet2, "id = ?", pet.ID) + CheckPet(t, pet2, pet) + + pet.Toy.Name += "new" + if err := DB.Save(&pet).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var pet3 Pet + DB.Preload("Toy").Find(&pet3, "id = ?", pet.ID) + CheckPet(t, pet2, pet3) + + if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&pet).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var pet4 Pet + DB.Preload("Toy").Find(&pet4, "id = ?", pet.ID) + CheckPet(t, pet4, pet) + }) +} diff --git a/tests/update_many2many_test.go b/tests/update_many2many_test.go new file mode 100644 index 00000000..d94ef4ab --- /dev/null +++ b/tests/update_many2many_test.go @@ -0,0 +1,54 @@ +package tests_test + +import ( + "testing" + + "gorm.io/gorm" + . 
"gorm.io/gorm/utils/tests" +) + +func TestUpdateMany2ManyAssociations(t *testing.T) { + var user = *GetUser("update-many2many", Config{}) + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + user.Languages = []Language{{Code: "zh-CN", Name: "Chinese"}, {Code: "en", Name: "English"}} + for _, lang := range user.Languages { + DB.Create(&lang) + } + user.Friends = []*User{{Name: "friend-1"}, {Name: "friend-2"}} + + if err := DB.Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user2 User + DB.Preload("Languages").Preload("Friends").Find(&user2, "id = ?", user.ID) + CheckUser(t, user2, user) + + for idx := range user.Friends { + user.Friends[idx].Name += "new" + } + + for idx := range user.Languages { + user.Languages[idx].Name += "new" + } + + if err := DB.Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user3 User + DB.Preload("Languages").Preload("Friends").Find(&user3, "id = ?", user.ID) + CheckUser(t, user2, user3) + + if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user4 User + DB.Preload("Languages").Preload("Friends").Find(&user4, "id = ?", user.ID) + CheckUser(t, user4, user) +} diff --git a/tests/update_test.go b/tests/update_test.go new file mode 100644 index 00000000..5ad1bb39 --- /dev/null +++ b/tests/update_test.go @@ -0,0 +1,687 @@ +package tests_test + +import ( + "errors" + "regexp" + "sort" + "strings" + "testing" + "time" + + "gorm.io/gorm" + "gorm.io/gorm/utils" + . "gorm.io/gorm/utils/tests" +) + +func TestUpdate(t *testing.T) { + var ( + users = []*User{ + GetUser("update-1", Config{}), + GetUser("update-2", Config{}), + GetUser("update-3", Config{}), + } + user = users[1] + lastUpdatedAt time.Time + ) + + checkUpdatedAtChanged := func(name string, n time.Time) { + if n.UnixNano() == lastUpdatedAt.UnixNano() { + t.Errorf("%v: user's updated at should be changed, but got %v, was %v", name, n, lastUpdatedAt) + } + lastUpdatedAt = n + } + + checkOtherData := func(name string) { + var first, last User + if err := DB.Where("id = ?", users[0].ID).First(&first).Error; err != nil { + t.Errorf("errors happened when query before user: %v", err) + } + CheckUser(t, first, *users[0]) + + if err := DB.Where("id = ?", users[2].ID).First(&last).Error; err != nil { + t.Errorf("errors happened when query after user: %v", err) + } + CheckUser(t, last, *users[2]) + } + + if err := DB.Create(&users).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } else if user.ID == 0 { + t.Fatalf("user's primary value should not zero, %v", user.ID) + } else if user.UpdatedAt.IsZero() { + t.Fatalf("user's updated at should not zero, %v", user.UpdatedAt) + } + lastUpdatedAt = user.UpdatedAt + + if err := DB.Model(user).Update("Age", 10).Error; err != nil { + t.Errorf("errors happened when update: %v", err) + } else if user.Age != 10 { + t.Errorf("Age should equals to 10, but got %v", user.Age) + } + checkUpdatedAtChanged("Update", user.UpdatedAt) + checkOtherData("Update") + + var result User + if err := DB.Where("id = ?", user.ID).First(&result).Error; err != nil { + t.Errorf("errors happened when query: %v", err) + } else { + CheckUser(t, result, *user) + } + + values := map[string]interface{}{"Active": true, "age": 5} + if err := DB.Model(user).Updates(values).Error; err != nil { + t.Errorf("errors happened when update: %v", 
err) + } else if user.Age != 5 { + t.Errorf("Age should equals to 5, but got %v", user.Age) + } else if user.Active != true { + t.Errorf("Active should be true, but got %v", user.Active) + } + checkUpdatedAtChanged("Updates with map", user.UpdatedAt) + checkOtherData("Updates with map") + + var result2 User + if err := DB.Where("id = ?", user.ID).First(&result2).Error; err != nil { + t.Errorf("errors happened when query: %v", err) + } else { + CheckUser(t, result2, *user) + } + + if err := DB.Model(user).Updates(User{Age: 2}).Error; err != nil { + t.Errorf("errors happened when update: %v", err) + } else if user.Age != 2 { + t.Errorf("Age should equals to 2, but got %v", user.Age) + } + checkUpdatedAtChanged("Updates with struct", user.UpdatedAt) + checkOtherData("Updates with struct") + + var result3 User + if err := DB.Where("id = ?", user.ID).First(&result3).Error; err != nil { + t.Errorf("errors happened when query: %v", err) + } else { + CheckUser(t, result3, *user) + } + + user.Active = false + user.Age = 1 + if err := DB.Save(user).Error; err != nil { + t.Errorf("errors happened when update: %v", err) + } else if user.Age != 1 { + t.Errorf("Age should equals to 1, but got %v", user.Age) + } else if user.Active != false { + t.Errorf("Active should equals to false, but got %v", user.Active) + } + checkUpdatedAtChanged("Save", user.UpdatedAt) + checkOtherData("Save") + + var result4 User + if err := DB.Where("id = ?", user.ID).First(&result4).Error; err != nil { + t.Errorf("errors happened when query: %v", err) + } else { + CheckUser(t, result4, *user) + } +} + +func TestUpdates(t *testing.T) { + var users = []*User{ + GetUser("updates_01", Config{}), + GetUser("updates_02", Config{}), + } + + DB.Create(&users) + lastUpdatedAt := users[0].UpdatedAt + + // update with map + DB.Model(users[0]).Updates(map[string]interface{}{"name": "updates_01_newname", "age": 100}) + if users[0].Name != "updates_01_newname" || users[0].Age != 100 { + t.Errorf("Record should be updated also with map") + } + + if users[0].UpdatedAt.UnixNano() == lastUpdatedAt.UnixNano() { + t.Errorf("User's updated at should be changed, but got %v, was %v", users[0].UpdatedAt.UnixNano(), lastUpdatedAt) + } + + // user2 should not be updated + var user1, user2 User + DB.First(&user1, users[0].ID) + DB.First(&user2, users[1].ID) + CheckUser(t, user1, *users[0]) + CheckUser(t, user2, *users[1]) + + // update with struct + time.Sleep(1 * time.Second) + DB.Table("users").Where("name in ?", []string{users[1].Name}).Updates(User{Name: "updates_02_newname"}) + + var user3 User + if err := DB.First(&user3, "name = ?", "updates_02_newname").Error; err != nil { + t.Errorf("User2's name should be updated") + } + + if user2.UpdatedAt.Format(time.RFC1123Z) == user3.UpdatedAt.Format(time.RFC1123Z) { + t.Errorf("User's updated at should be changed, old %v, new %v", user2.UpdatedAt.Format(time.RFC1123Z), user3.UpdatedAt.Format(time.RFC1123Z)) + } + + // update with gorm exprs + if err := DB.Model(&user3).Updates(map[string]interface{}{"age": gorm.Expr("age + ?", 100)}).Error; err != nil { + t.Errorf("Not error should happen when updating with gorm expr, but got %v", err) + } + var user4 User + DB.First(&user4, user3.ID) + + user3.Age += 100 + AssertObjEqual(t, user4, user3, "UpdatedAt", "Age") +} + +func TestUpdateColumn(t *testing.T) { + var users = []*User{ + GetUser("update_column_01", Config{}), + GetUser("update_column_02", Config{}), + } + + DB.Create(&users) + lastUpdatedAt := users[1].UpdatedAt + + // update with map + 
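+ // UpdateColumns writes the given values directly, skipping hooks and UpdatedAt tracking, so the assertions below expect the timestamp to stay unchanged.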
DB.Model(users[1]).UpdateColumns(map[string]interface{}{"name": "update_column_02_newname", "age": 100}) + if users[1].Name != "update_column_02_newname" || users[1].Age != 100 { + t.Errorf("user 2 should be updated with update column") + } + AssertEqual(t, lastUpdatedAt.UnixNano(), users[1].UpdatedAt.UnixNano()) + + // user2 should not be updated + var user1, user2 User + DB.First(&user1, users[0].ID) + DB.First(&user2, users[1].ID) + CheckUser(t, user1, *users[0]) + CheckUser(t, user2, *users[1]) + + DB.Model(users[1]).UpdateColumn("name", "update_column_02_newnew") + AssertEqual(t, lastUpdatedAt.UnixNano(), users[1].UpdatedAt.UnixNano()) + + if users[1].Name != "update_column_02_newnew" { + t.Errorf("user 2's name should be updated, but got %v", users[1].Name) + } + + DB.Model(users[1]).UpdateColumn("age", gorm.Expr("age + 100 - 50")) + var user3 User + DB.First(&user3, users[1].ID) + + users[1].Age += 50 + CheckUser(t, user3, *users[1]) + + // update with struct + DB.Model(users[1]).UpdateColumns(User{Name: "update_column_02_newnew2", Age: 200}) + if users[1].Name != "update_column_02_newnew2" || users[1].Age != 200 { + t.Errorf("user 2 should be updated with update column") + } + AssertEqual(t, lastUpdatedAt.UnixNano(), users[1].UpdatedAt.UnixNano()) + + // user2 should not be updated + var user5, user6 User + DB.First(&user5, users[0].ID) + DB.First(&user6, users[1].ID) + CheckUser(t, user5, *users[0]) + CheckUser(t, user6, *users[1]) +} + +func TestBlockGlobalUpdate(t *testing.T) { + if err := DB.Model(&User{}).Update("name", "jinzhu").Error; err == nil || !errors.Is(err, gorm.ErrMissingWhereClause) { + t.Errorf("should returns missing WHERE clause while updating error, got err %v", err) + } + + if err := DB.Session(&gorm.Session{AllowGlobalUpdate: true}).Model(&User{}).Update("name", "jinzhu").Error; err != nil { + t.Errorf("should returns no error while enable global update, but got err %v", err) + } +} + +func TestSelectWithUpdate(t *testing.T) { + user := *GetUser("select_update", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4}) + DB.Create(&user) + + var result User + DB.First(&result, user.ID) + + user2 := *GetUser("select_update_new", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4}) + result.Name = user2.Name + result.Age = 50 + result.Account = user2.Account + result.Pets = user2.Pets + result.Toys = user2.Toys + result.Company = user2.Company + result.Manager = user2.Manager + result.Team = user2.Team + result.Languages = user2.Languages + result.Friends = user2.Friends + + DB.Select("Name", "Account", "Toys", "Manager", "ManagerID", "Languages").Save(&result) + + var result2 User + DB.Preload("Account").Preload("Pets").Preload("Toys").Preload("Company").Preload("Manager").Preload("Team").Preload("Languages").Preload("Friends").First(&result2, user.ID) + + result.Languages = append(user.Languages, result.Languages...) + result.Toys = append(user.Toys, result.Toys...) 
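+ // Sort Languages and Toys on both the expected and the reloaded side so the comparison below does not depend on the order rows come back in.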
+ + sort.Slice(result.Languages, func(i, j int) bool { + return strings.Compare(result.Languages[i].Code, result.Languages[j].Code) > 0 + }) + + sort.Slice(result.Toys, func(i, j int) bool { + return result.Toys[i].ID < result.Toys[j].ID + }) + + sort.Slice(result2.Languages, func(i, j int) bool { + return strings.Compare(result2.Languages[i].Code, result2.Languages[j].Code) > 0 + }) + + sort.Slice(result2.Toys, func(i, j int) bool { + return result2.Toys[i].ID < result2.Toys[j].ID + }) + + AssertObjEqual(t, result2, result, "Name", "Account", "Toys", "Manager", "ManagerID", "Languages") + + DB.Model(&result).Select("Name", "Age").Updates(User{Name: "update_with_select"}) + if result.Age != 0 || result.Name != "update_with_select" { + t.Fatalf("Failed to update struct with select, got %+v", result) + } + AssertObjEqual(t, result, user, "UpdatedAt") + + var result3 User + DB.First(&result3, result.ID) + AssertObjEqual(t, result, result3, "Name", "Age", "UpdatedAt") + + DB.Model(&result).Select("Name", "Age", "UpdatedAt").Updates(User{Name: "update_with_select"}) + + if utils.AssertEqual(result.UpdatedAt, user.UpdatedAt) { + t.Fatalf("Update struct should update UpdatedAt, was %+v, got %+v", result.UpdatedAt, user.UpdatedAt) + } +} + +func TestSelectWithUpdateWithMap(t *testing.T) { + user := *GetUser("select_update_map", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4}) + DB.Create(&user) + + var result User + DB.First(&result, user.ID) + + user2 := *GetUser("select_update_map_new", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4}) + updateValues := map[string]interface{}{ + "Name": user2.Name, + "Age": 50, + "Account": user2.Account, + "Pets": user2.Pets, + "Toys": user2.Toys, + "Company": user2.Company, + "Manager": user2.Manager, + "Team": user2.Team, + "Languages": user2.Languages, + "Friends": user2.Friends, + } + + DB.Model(&result).Omit("name", "updated_at").Updates(updateValues) + + var result2 User + DB.Preload("Account").Preload("Pets").Preload("Toys").Preload("Company").Preload("Manager").Preload("Team").Preload("Languages").Preload("Friends").First(&result2, user.ID) + + result.Languages = append(user.Languages, result.Languages...) + result.Toys = append(user.Toys, result.Toys...) 
+ + sort.Slice(result.Languages, func(i, j int) bool { + return strings.Compare(result.Languages[i].Code, result.Languages[j].Code) > 0 + }) + + sort.Slice(result.Toys, func(i, j int) bool { + return result.Toys[i].ID < result.Toys[j].ID + }) + + sort.Slice(result2.Languages, func(i, j int) bool { + return strings.Compare(result2.Languages[i].Code, result2.Languages[j].Code) > 0 + }) + + sort.Slice(result2.Toys, func(i, j int) bool { + return result2.Toys[i].ID < result2.Toys[j].ID + }) + + AssertObjEqual(t, result2, result, "Name", "Account", "Toys", "Manager", "ManagerID", "Languages") +} + +func TestWithUpdateWithInvalidMap(t *testing.T) { + user := *GetUser("update_with_invalid_map", Config{}) + DB.Create(&user) + + if err := DB.Model(&user).Updates(map[string]string{"name": "jinzhu"}).Error; !errors.Is(err, gorm.ErrInvalidData) { + t.Errorf("should returns error for unsupported updating data") + } +} + +func TestOmitWithUpdate(t *testing.T) { + user := *GetUser("omit_update", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4}) + DB.Create(&user) + + var result User + DB.First(&result, user.ID) + + user2 := *GetUser("omit_update_new", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4}) + result.Name = user2.Name + result.Age = 50 + result.Account = user2.Account + result.Pets = user2.Pets + result.Toys = user2.Toys + result.Company = user2.Company + result.Manager = user2.Manager + result.Team = user2.Team + result.Languages = user2.Languages + result.Friends = user2.Friends + + DB.Omit("Name", "Account", "Toys", "Manager", "ManagerID", "Languages").Save(&result) + + var result2 User + DB.Preload("Account").Preload("Pets").Preload("Toys").Preload("Company").Preload("Manager").Preload("Team").Preload("Languages").Preload("Friends").First(&result2, user.ID) + + result.Pets = append(user.Pets, result.Pets...) + result.Team = append(user.Team, result.Team...) + result.Friends = append(user.Friends, result.Friends...) 
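+ // Sort Pets, Team and Friends on both sides so AssertObjEqual compares them independently of row order.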
+ + sort.Slice(result.Pets, func(i, j int) bool { + return result.Pets[i].ID < result.Pets[j].ID + }) + sort.Slice(result.Team, func(i, j int) bool { + return result.Team[i].ID < result.Team[j].ID + }) + sort.Slice(result.Friends, func(i, j int) bool { + return result.Friends[i].ID < result.Friends[j].ID + }) + sort.Slice(result2.Pets, func(i, j int) bool { + return result2.Pets[i].ID < result2.Pets[j].ID + }) + sort.Slice(result2.Team, func(i, j int) bool { + return result2.Team[i].ID < result2.Team[j].ID + }) + sort.Slice(result2.Friends, func(i, j int) bool { + return result2.Friends[i].ID < result2.Friends[j].ID + }) + + AssertObjEqual(t, result2, result, "Age", "Pets", "Company", "CompanyID", "Team", "Friends") +} + +func TestOmitWithUpdateWithMap(t *testing.T) { + user := *GetUser("omit_update_map", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4}) + DB.Create(&user) + + var result User + DB.First(&result, user.ID) + + user2 := *GetUser("omit_update_map_new", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4}) + updateValues := map[string]interface{}{ + "Name": user2.Name, + "Age": 50, + "Account": user2.Account, + "Pets": user2.Pets, + "Toys": user2.Toys, + "Company": user2.Company, + "Manager": user2.Manager, + "Team": user2.Team, + "Languages": user2.Languages, + "Friends": user2.Friends, + } + + DB.Model(&result).Omit("Name", "Account", "Toys", "Manager", "ManagerID", "Languages").Updates(updateValues) + + var result2 User + DB.Preload("Account").Preload("Pets").Preload("Toys").Preload("Company").Preload("Manager").Preload("Team").Preload("Languages").Preload("Friends").First(&result2, user.ID) + + result.Pets = append(user.Pets, result.Pets...) + result.Team = append(user.Team, result.Team...) + result.Friends = append(user.Friends, result.Friends...) 
+ + sort.Slice(result.Pets, func(i, j int) bool { + return result.Pets[i].ID < result.Pets[j].ID + }) + sort.Slice(result.Team, func(i, j int) bool { + return result.Team[i].ID < result.Team[j].ID + }) + sort.Slice(result.Friends, func(i, j int) bool { + return result.Friends[i].ID < result.Friends[j].ID + }) + sort.Slice(result2.Pets, func(i, j int) bool { + return result2.Pets[i].ID < result2.Pets[j].ID + }) + sort.Slice(result2.Team, func(i, j int) bool { + return result2.Team[i].ID < result2.Team[j].ID + }) + sort.Slice(result2.Friends, func(i, j int) bool { + return result2.Friends[i].ID < result2.Friends[j].ID + }) + + AssertObjEqual(t, result2, result, "Age", "Pets", "Company", "CompanyID", "Team", "Friends") +} + +func TestSelectWithUpdateColumn(t *testing.T) { + user := *GetUser("select_with_update_column", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4}) + DB.Create(&user) + + updateValues := map[string]interface{}{"Name": "new_name", "Age": 50} + + var result User + DB.First(&result, user.ID) + + time.Sleep(time.Second) + lastUpdatedAt := result.UpdatedAt + DB.Model(&result).Select("Name").Updates(updateValues) + + var result2 User + DB.First(&result2, user.ID) + + if lastUpdatedAt.Format(time.RFC3339Nano) == result2.UpdatedAt.Format(time.RFC3339Nano) { + t.Errorf("UpdatedAt should be changed") + } + + if result2.Name == user.Name || result2.Age != user.Age { + t.Errorf("Should only update users with name column") + } +} + +func TestOmitWithUpdateColumn(t *testing.T) { + user := *GetUser("omit_with_update_column", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4}) + DB.Create(&user) + + updateValues := map[string]interface{}{"Name": "new_name", "Age": 50} + + var result User + DB.First(&result, user.ID) + DB.Model(&result).Omit("Name").UpdateColumns(updateValues) + + var result2 User + DB.First(&result2, user.ID) + + if result2.Name != user.Name || result2.Age == user.Age { + t.Errorf("Should only update users with name column") + } +} + +func TestUpdateColumnsSkipsAssociations(t *testing.T) { + user := *GetUser("update_column_skips_association", Config{}) + DB.Create(&user) + + // Update a single field of the user and verify that the changed address is not stored. + newAge := uint(100) + user.Account.Number = "new_account_number" + db := DB.Model(&user).UpdateColumns(User{Age: newAge}) + + if db.RowsAffected != 1 { + t.Errorf("Expected RowsAffected=1 but instead RowsAffected=%v", db.RowsAffected) + } + + // Verify that Age now=`newAge`. 
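+ // Reload the user with Account preloaded: Age should reflect newAge while the association stays untouched.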
+ result := &User{} + result.ID = user.ID + DB.Preload("Account").First(result) + + if result.Age != newAge { + t.Errorf("Expected freshly queried user to have Age=%v but instead found Age=%v", newAge, result.Age) + } + + if result.Account.Number != user.Account.Number { + t.Errorf("account number should not been changed, expects: %v, got %v", user.Account.Number, result.Account.Number) + } +} + +func TestUpdatesWithBlankValues(t *testing.T) { + user := *GetUser("updates_with_blank_value", Config{}) + DB.Save(&user) + + var user2 User + user2.ID = user.ID + DB.Model(&user2).Updates(&User{Age: 100}) + + var result User + DB.First(&result, user.ID) + + if result.Name != user.Name || result.Age != 100 { + t.Errorf("user's name should not be updated") + } +} + +func TestUpdatesTableWithIgnoredValues(t *testing.T) { + type ElementWithIgnoredField struct { + Id int64 + Value string + IgnoredField int64 `gorm:"-"` + } + DB.Migrator().DropTable(&ElementWithIgnoredField{}) + DB.AutoMigrate(&ElementWithIgnoredField{}) + + elem := ElementWithIgnoredField{Value: "foo", IgnoredField: 10} + DB.Save(&elem) + + DB.Model(&ElementWithIgnoredField{}). + Where("id = ?", elem.Id). + Updates(&ElementWithIgnoredField{Value: "bar", IgnoredField: 100}) + + var result ElementWithIgnoredField + if err := DB.First(&result, elem.Id).Error; err != nil { + t.Errorf("error getting an element from database: %s", err.Error()) + } + + if result.IgnoredField != 0 { + t.Errorf("element's ignored field should not be updated") + } +} + +func TestUpdateFromSubQuery(t *testing.T) { + user := *GetUser("update_from_sub_query", Config{Company: true}) + if err := DB.Create(&user).Error; err != nil { + t.Errorf("failed to create user, got error: %v", err) + } + + if err := DB.Model(&user).Update("name", DB.Model(&Company{}).Select("name").Where("companies.id = users.company_id")).Error; err != nil { + t.Errorf("failed to update with sub query, got error %v", err) + } + + var result User + DB.First(&result, user.ID) + + if result.Name != user.Company.Name { + t.Errorf("name should be %v, but got %v", user.Company.Name, result.Name) + } + + DB.Model(&user.Company).Update("Name", "new company name") + if err := DB.Table("users").Where("1 = 1").Update("name", DB.Table("companies").Select("name").Where("companies.id = users.company_id")).Error; err != nil { + t.Errorf("failed to update with sub query, got error %v", err) + } + + DB.First(&result, user.ID) + if result.Name != "new company name" { + t.Errorf("name should be %v, but got %v", user.Company.Name, result.Name) + } +} + +func TestSave(t *testing.T) { + user := *GetUser("save", Config{}) + DB.Create(&user) + + if err := DB.First(&User{}, "name = ?", "save").Error; err != nil { + t.Fatalf("failed to find created user") + } + + user.Name = "save2" + DB.Save(&user) + + var result User + if err := DB.First(&result, "name = ?", "save2").Error; err != nil || result.ID != user.ID { + t.Fatalf("failed to find updated user") + } + + user2 := *GetUser("save2", Config{}) + DB.Create(&user2) + + time.Sleep(time.Second) + user1UpdatedAt := result.UpdatedAt + user2UpdatedAt := user2.UpdatedAt + var users = []*User{&result, &user2} + DB.Save(&users) + + if user1UpdatedAt.Format(time.RFC1123Z) == result.UpdatedAt.Format(time.RFC1123Z) { + t.Fatalf("user's updated at should be changed, expects: %+v, got: %+v", user1UpdatedAt, result.UpdatedAt) + } + + if user2UpdatedAt.Format(time.RFC1123Z) == user2.UpdatedAt.Format(time.RFC1123Z) { + t.Fatalf("user's updated at should be changed, expects: %+v, 
got: %+v", user2UpdatedAt, user2.UpdatedAt) + } + + DB.First(&result) + if user1UpdatedAt.Format(time.RFC1123Z) == result.UpdatedAt.Format(time.RFC1123Z) { + t.Fatalf("user's updated at should be changed after reload, expects: %+v, got: %+v", user1UpdatedAt, result.UpdatedAt) + } + + DB.First(&user2) + if user2UpdatedAt.Format(time.RFC1123Z) == user2.UpdatedAt.Format(time.RFC1123Z) { + t.Fatalf("user2's updated at should be changed after reload, expects: %+v, got: %+v", user2UpdatedAt, user2.UpdatedAt) + } + + dryDB := DB.Session(&gorm.Session{DryRun: true}) + stmt := dryDB.Save(&user).Statement + if !regexp.MustCompile("WHERE .id. = [^ ]+$").MatchString(stmt.SQL.String()) { + t.Fatalf("invalid updating SQL, got %v", stmt.SQL.String()) + } +} + +func TestSaveWithPrimaryValue(t *testing.T) { + lang := Language{Code: "save", Name: "save"} + if result := DB.Save(&lang); result.RowsAffected != 1 { + t.Errorf("should create language, rows affected: %v", result.RowsAffected) + } + + var result Language + DB.First(&result, "code = ?", "save") + AssertEqual(t, result, lang) + + lang.Name = "save name2" + if result := DB.Save(&lang); result.RowsAffected != 1 { + t.Errorf("should update language") + } + + var result2 Language + DB.First(&result2, "code = ?", "save") + AssertEqual(t, result2, lang) + + DB.Table("langs").Migrator().DropTable(&Language{}) + DB.Table("langs").AutoMigrate(&Language{}) + + if err := DB.Table("langs").Save(&lang).Error; err != nil { + t.Errorf("no error should happen when creating data, but got %v", err) + } + + var result3 Language + if err := DB.Table("langs").First(&result3, "code = ?", lang.Code).Error; err != nil || result3.Name != lang.Name { + t.Errorf("failed to find created record, got error: %v, result: %+v", err, result3) + } + + lang.Name += "name2" + if err := DB.Table("langs").Save(&lang).Error; err != nil { + t.Errorf("no error should happen when creating data, but got %v", err) + } + + var result4 Language + if err := DB.Table("langs").First(&result4, "code = ?", lang.Code).Error; err != nil || result4.Name != lang.Name { + t.Errorf("failed to find created record, got error: %v, result: %+v", err, result4) + } +} diff --git a/tests/upsert_test.go b/tests/upsert_test.go new file mode 100644 index 00000000..0ba8b9f0 --- /dev/null +++ b/tests/upsert_test.go @@ -0,0 +1,276 @@ +package tests_test + +import ( + "testing" + "time" + + "gorm.io/gorm/clause" + . 
"gorm.io/gorm/utils/tests" +) + +func TestUpsert(t *testing.T) { + lang := Language{Code: "upsert", Name: "Upsert"} + if err := DB.Clauses(clause.OnConflict{DoNothing: true}).Create(&lang).Error; err != nil { + t.Fatalf("failed to upsert, got %v", err) + } + + lang2 := Language{Code: "upsert", Name: "Upsert"} + if err := DB.Clauses(clause.OnConflict{DoNothing: true}).Create(&lang2).Error; err != nil { + t.Fatalf("failed to upsert, got %v", err) + } + + var langs []Language + if err := DB.Find(&langs, "code = ?", lang.Code).Error; err != nil { + t.Errorf("no error should happen when find languages with code, but got %v", err) + } else if len(langs) != 1 { + t.Errorf("should only find only 1 languages, but got %+v", langs) + } + + lang3 := Language{Code: "upsert", Name: "Upsert"} + if err := DB.Clauses(clause.OnConflict{ + Columns: []clause.Column{{Name: "code"}}, + DoUpdates: clause.Assignments(map[string]interface{}{"name": "upsert-new"}), + }).Create(&lang3).Error; err != nil { + t.Fatalf("failed to upsert, got %v", err) + } + + if err := DB.Find(&langs, "code = ?", lang.Code).Error; err != nil { + t.Errorf("no error should happen when find languages with code, but got %v", err) + } else if len(langs) != 1 { + t.Errorf("should only find only 1 languages, but got %+v", langs) + } else if langs[0].Name != "upsert-new" { + t.Errorf("should update name on conflict, but got name %+v", langs[0].Name) + } + + lang = Language{Code: "upsert", Name: "Upsert-Newname"} + if err := DB.Clauses(clause.OnConflict{UpdateAll: true}).Create(&lang).Error; err != nil { + t.Fatalf("failed to upsert, got %v", err) + } + + var result Language + if err := DB.Find(&result, "code = ?", lang.Code).Error; err != nil || result.Name != lang.Name { + t.Fatalf("failed to upsert, got name %v", result.Name) + } +} + +func TestUpsertSlice(t *testing.T) { + langs := []Language{ + {Code: "upsert-slice1", Name: "Upsert-slice1"}, + {Code: "upsert-slice2", Name: "Upsert-slice2"}, + {Code: "upsert-slice3", Name: "Upsert-slice3"}, + } + DB.Clauses(clause.OnConflict{DoNothing: true}).Create(&langs) + + var langs2 []Language + if err := DB.Find(&langs2, "code LIKE ?", "upsert-slice%").Error; err != nil { + t.Errorf("no error should happen when find languages with code, but got %v", err) + } else if len(langs2) != 3 { + t.Errorf("should only find only 3 languages, but got %+v", langs2) + } + + DB.Clauses(clause.OnConflict{DoNothing: true}).Create(&langs) + var langs3 []Language + if err := DB.Find(&langs3, "code LIKE ?", "upsert-slice%").Error; err != nil { + t.Errorf("no error should happen when find languages with code, but got %v", err) + } else if len(langs3) != 3 { + t.Errorf("should only find only 3 languages, but got %+v", langs3) + } + + for idx, lang := range langs { + lang.Name = lang.Name + "_new" + langs[idx] = lang + } + + if err := DB.Clauses(clause.OnConflict{ + Columns: []clause.Column{{Name: "code"}}, + DoUpdates: clause.AssignmentColumns([]string{"name"}), + }).Create(&langs).Error; err != nil { + t.Fatalf("failed to upsert, got %v", err) + } + + for _, lang := range langs { + var results []Language + if err := DB.Find(&results, "code = ?", lang.Code).Error; err != nil { + t.Errorf("no error should happen when find languages with code, but got %v", err) + } else if len(results) != 1 { + t.Errorf("should only find only 1 languages, but got %+v", langs) + } else if results[0].Name != lang.Name { + t.Errorf("should update name on conflict, but got name %+v", results[0].Name) + } + } +} + +func TestUpsertWithSave(t 
*testing.T) { + langs := []Language{ + {Code: "upsert-save-1", Name: "Upsert-save-1"}, + {Code: "upsert-save-2", Name: "Upsert-save-2"}, + } + + if err := DB.Save(&langs).Error; err != nil { + t.Errorf("Failed to create, got error %v", err) + } + + for _, lang := range langs { + var result Language + if err := DB.First(&result, "code = ?", lang.Code).Error; err != nil { + t.Errorf("Failed to query lang, got error %v", err) + } else { + AssertEqual(t, result, lang) + } + } + + for idx, lang := range langs { + lang.Name += "_new" + langs[idx] = lang + } + + if err := DB.Save(&langs).Error; err != nil { + t.Errorf("Failed to upsert, got error %v", err) + } + + for _, lang := range langs { + var result Language + if err := DB.First(&result, "code = ?", lang.Code).Error; err != nil { + t.Errorf("Failed to query lang, got error %v", err) + } else { + AssertEqual(t, result, lang) + } + } + + // lang := Language{Code: "upsert-save-3", Name: "Upsert-save-3"} + // if err := DB.Save(&lang).Error; err != nil { + // t.Errorf("Failed to create, got error %v", err) + // } + + // var result Language + // if err := DB.First(&result, "code = ?", lang.Code).Error; err != nil { + // t.Errorf("Failed to query lang, got error %v", err) + // } else { + // AssertEqual(t, result, lang) + // } + + // lang.Name += "_new" + // if err := DB.Save(&lang).Error; err != nil { + // t.Errorf("Failed to create, got error %v", err) + // } + + // var result2 Language + // if err := DB.First(&result2, "code = ?", lang.Code).Error; err != nil { + // t.Errorf("Failed to query lang, got error %v", err) + // } else { + // AssertEqual(t, result2, lang) + // } +} + +func TestFindOrInitialize(t *testing.T) { + var user1, user2, user3, user4, user5, user6 User + if err := DB.Where(&User{Name: "find or init", Age: 33}).FirstOrInit(&user1).Error; err != nil { + t.Errorf("no error should happen when FirstOrInit, but got %v", err) + } + + if user1.Name != "find or init" || user1.ID != 0 || user1.Age != 33 { + t.Errorf("user should be initialized with search value") + } + + DB.Where(User{Name: "find or init", Age: 33}).FirstOrInit(&user2) + if user2.Name != "find or init" || user2.ID != 0 || user2.Age != 33 { + t.Errorf("user should be initialized with search value") + } + + DB.FirstOrInit(&user3, map[string]interface{}{"name": "find or init 2"}) + if user3.Name != "find or init 2" || user3.ID != 0 { + t.Errorf("user should be initialized with inline search value") + } + + DB.Where(&User{Name: "find or init"}).Attrs(User{Age: 44}).FirstOrInit(&user4) + if user4.Name != "find or init" || user4.ID != 0 || user4.Age != 44 { + t.Errorf("user should be initialized with search value and attrs") + } + + DB.Where(&User{Name: "find or init"}).Assign("age", 44).FirstOrInit(&user4) + if user4.Name != "find or init" || user4.ID != 0 || user4.Age != 44 { + t.Errorf("user should be initialized with search value and assign attrs") + } + + DB.Save(&User{Name: "find or init", Age: 33}) + DB.Where(&User{Name: "find or init"}).Attrs("age", 44).FirstOrInit(&user5) + if user5.Name != "find or init" || user5.ID == 0 || user5.Age != 33 { + t.Errorf("user should be found and not initialized by Attrs") + } + + DB.Where(&User{Name: "find or init", Age: 33}).FirstOrInit(&user6) + if user6.Name != "find or init" || user6.ID == 0 || user6.Age != 33 { + t.Errorf("user should be found with FirstOrInit") + } + + DB.Where(&User{Name: "find or init"}).Assign(User{Age: 44}).FirstOrInit(&user6) + if user6.Name != "find or init" || user6.ID == 0 || user6.Age != 44 { + 
t.Errorf("user should be found and updated with assigned attrs") + } +} + +func TestFindOrCreate(t *testing.T) { + var user1, user2, user3, user4, user5, user6, user7, user8 User + if err := DB.Where(&User{Name: "find or create", Age: 33}).FirstOrCreate(&user1).Error; err != nil { + t.Errorf("no error should happen when FirstOrInit, but got %v", err) + } + + if user1.Name != "find or create" || user1.ID == 0 || user1.Age != 33 { + t.Errorf("user should be created with search value") + } + + DB.Where(&User{Name: "find or create", Age: 33}).FirstOrCreate(&user2) + if user1.ID != user2.ID || user2.Name != "find or create" || user2.ID == 0 || user2.Age != 33 { + t.Errorf("user should be created with search value") + } + + DB.FirstOrCreate(&user3, map[string]interface{}{"name": "find or create 2"}) + if user3.Name != "find or create 2" || user3.ID == 0 { + t.Errorf("user should be created with inline search value") + } + + DB.Where(&User{Name: "find or create 3"}).Attrs("age", 44).FirstOrCreate(&user4) + if user4.Name != "find or create 3" || user4.ID == 0 || user4.Age != 44 { + t.Errorf("user should be created with search value and attrs") + } + + updatedAt1 := user4.UpdatedAt + DB.Where(&User{Name: "find or create 3"}).Assign("age", 55).FirstOrCreate(&user4) + + if user4.Age != 55 { + t.Errorf("Failed to set change to 55, got %v", user4.Age) + } + + if updatedAt1.Format(time.RFC3339Nano) == user4.UpdatedAt.Format(time.RFC3339Nano) { + t.Errorf("UpdateAt should be changed when update values with assign") + } + + DB.Where(&User{Name: "find or create 4"}).Assign(User{Age: 44}).FirstOrCreate(&user4) + if user4.Name != "find or create 4" || user4.ID == 0 || user4.Age != 44 { + t.Errorf("user should be created with search value and assigned attrs") + } + + DB.Where(&User{Name: "find or create"}).Attrs("age", 44).FirstOrInit(&user5) + if user5.Name != "find or create" || user5.ID == 0 || user5.Age != 33 { + t.Errorf("user should be found and not initialized by Attrs") + } + + DB.Where(&User{Name: "find or create"}).Assign(User{Age: 44}).FirstOrCreate(&user6) + if user6.Name != "find or create" || user6.ID == 0 || user6.Age != 44 { + t.Errorf("user should be found and updated with assigned attrs") + } + + DB.Where(&User{Name: "find or create"}).Find(&user7) + if user7.Name != "find or create" || user7.ID == 0 || user7.Age != 44 { + t.Errorf("user should be found and updated with assigned attrs") + } + + DB.Where(&User{Name: "find or create embedded struct"}).Assign(User{Age: 44, Account: Account{Number: "1231231231"}, Pets: []*Pet{{Name: "first_or_create_pet1"}, {Name: "first_or_create_pet2"}}}).FirstOrCreate(&user8) + if err := DB.Where("name = ?", "first_or_create_pet1").First(&Pet{}).Error; err != nil { + t.Errorf("has many association should be saved") + } + + if err := DB.Where("number = ?", "1231231231").First(&Account{}).Error; err != nil { + t.Errorf("belongs to association should be saved") + } +} diff --git a/update_test.go b/update_test.go deleted file mode 100644 index 85d53e5f..00000000 --- a/update_test.go +++ /dev/null @@ -1,465 +0,0 @@ -package gorm_test - -import ( - "testing" - "time" - - "github.com/jinzhu/gorm" -) - -func TestUpdate(t *testing.T) { - product1 := Product{Code: "product1code"} - product2 := Product{Code: "product2code"} - - DB.Save(&product1).Save(&product2).Update("code", "product2newcode") - - if product2.Code != "product2newcode" { - t.Errorf("Record should be updated") - } - - DB.First(&product1, product1.Id) - DB.First(&product2, product2.Id) - updatedAt1 := 
product1.UpdatedAt - - if DB.First(&Product{}, "code = ?", product1.Code).RecordNotFound() { - t.Errorf("Product1 should not be updated") - } - - if !DB.First(&Product{}, "code = ?", "product2code").RecordNotFound() { - t.Errorf("Product2's code should be updated") - } - - if DB.First(&Product{}, "code = ?", "product2newcode").RecordNotFound() { - t.Errorf("Product2's code should be updated") - } - - DB.Table("products").Where("code in (?)", []string{"product1code"}).Update("code", "product1newcode") - - var product4 Product - DB.First(&product4, product1.Id) - if updatedAt1.Format(time.RFC3339Nano) != product4.UpdatedAt.Format(time.RFC3339Nano) { - t.Errorf("updatedAt should be updated if something changed") - } - - if !DB.First(&Product{}, "code = 'product1code'").RecordNotFound() { - t.Errorf("Product1's code should be updated") - } - - if DB.First(&Product{}, "code = 'product1newcode'").RecordNotFound() { - t.Errorf("Product should not be changed to 789") - } - - if DB.Model(product2).Update("CreatedAt", time.Now().Add(time.Hour)).Error != nil { - t.Error("No error should raise when update with CamelCase") - } - - if DB.Model(&product2).UpdateColumn("CreatedAt", time.Now().Add(time.Hour)).Error != nil { - t.Error("No error should raise when update_column with CamelCase") - } - - var products []Product - DB.Find(&products) - if count := DB.Model(Product{}).Update("CreatedAt", time.Now().Add(2*time.Hour)).RowsAffected; count != int64(len(products)) { - t.Error("RowsAffected should be correct when do batch update") - } - - DB.First(&product4, product4.Id) - updatedAt4 := product4.UpdatedAt - DB.Model(&product4).Update("price", gorm.Expr("price + ? - ?", 100, 50)) - var product5 Product - DB.First(&product5, product4.Id) - if product5.Price != product4.Price+100-50 { - t.Errorf("Update with expression") - } - if product4.UpdatedAt.Format(time.RFC3339Nano) == updatedAt4.Format(time.RFC3339Nano) { - t.Errorf("Update with expression should update UpdatedAt") - } -} - -func TestUpdateWithNoStdPrimaryKeyAndDefaultValues(t *testing.T) { - animal := Animal{Name: "Ferdinand"} - DB.Save(&animal) - updatedAt1 := animal.UpdatedAt - - DB.Save(&animal).Update("name", "Francis") - - if updatedAt1.Format(time.RFC3339Nano) == animal.UpdatedAt.Format(time.RFC3339Nano) { - t.Errorf("updatedAt should not be updated if nothing changed") - } - - var animals []Animal - DB.Find(&animals) - if count := DB.Model(Animal{}).Update("CreatedAt", time.Now().Add(2*time.Hour)).RowsAffected; count != int64(len(animals)) { - t.Error("RowsAffected should be correct when do batch update") - } - - animal = Animal{From: "somewhere"} // No name fields, should be filled with the default value (galeone) - DB.Save(&animal).Update("From", "a nice place") // The name field shoul be untouched - DB.First(&animal, animal.Counter) - if animal.Name != "galeone" { - t.Errorf("Name fields shouldn't be changed if untouched, but got %v", animal.Name) - } - - // When changing a field with a default value, the change must occur - animal.Name = "amazing horse" - DB.Save(&animal) - DB.First(&animal, animal.Counter) - if animal.Name != "amazing horse" { - t.Errorf("Update a filed with a default value should occur. But got %v\n", animal.Name) - } - - // When changing a field with a default value with blank value - animal.Name = "" - DB.Save(&animal) - DB.First(&animal, animal.Counter) - if animal.Name != "" { - t.Errorf("Update a filed to blank with a default value should occur. 
But got %v\n", animal.Name) - } -} - -func TestUpdates(t *testing.T) { - product1 := Product{Code: "product1code", Price: 10} - product2 := Product{Code: "product2code", Price: 10} - DB.Save(&product1).Save(&product2) - DB.Model(&product1).Updates(map[string]interface{}{"code": "product1newcode", "price": 100}) - if product1.Code != "product1newcode" || product1.Price != 100 { - t.Errorf("Record should be updated also with map") - } - - DB.First(&product1, product1.Id) - DB.First(&product2, product2.Id) - updatedAt2 := product2.UpdatedAt - - if DB.First(&Product{}, "code = ? and price = ?", product2.Code, product2.Price).RecordNotFound() { - t.Errorf("Product2 should not be updated") - } - - if DB.First(&Product{}, "code = ?", "product1newcode").RecordNotFound() { - t.Errorf("Product1 should be updated") - } - - DB.Table("products").Where("code in (?)", []string{"product2code"}).Updates(Product{Code: "product2newcode"}) - if !DB.First(&Product{}, "code = 'product2code'").RecordNotFound() { - t.Errorf("Product2's code should be updated") - } - - var product4 Product - DB.First(&product4, product2.Id) - if updatedAt2.Format(time.RFC3339Nano) != product4.UpdatedAt.Format(time.RFC3339Nano) { - t.Errorf("updatedAt should be updated if something changed") - } - - if DB.First(&Product{}, "code = ?", "product2newcode").RecordNotFound() { - t.Errorf("product2's code should be updated") - } - - updatedAt4 := product4.UpdatedAt - DB.Model(&product4).Updates(map[string]interface{}{"price": gorm.Expr("price + ?", 100)}) - var product5 Product - DB.First(&product5, product4.Id) - if product5.Price != product4.Price+100 { - t.Errorf("Updates with expression") - } - // product4's UpdatedAt will be reset when updating - if product4.UpdatedAt.Format(time.RFC3339Nano) == updatedAt4.Format(time.RFC3339Nano) { - t.Errorf("Updates with expression should update UpdatedAt") - } -} - -func TestUpdateColumn(t *testing.T) { - product1 := Product{Code: "product1code", Price: 10} - product2 := Product{Code: "product2code", Price: 20} - DB.Save(&product1).Save(&product2).UpdateColumn(map[string]interface{}{"code": "product2newcode", "price": 100}) - if product2.Code != "product2newcode" || product2.Price != 100 { - t.Errorf("product 2 should be updated with update column") - } - - var product3 Product - DB.First(&product3, product1.Id) - if product3.Code != "product1code" || product3.Price != 10 { - t.Errorf("product 1 should not be updated") - } - - DB.First(&product2, product2.Id) - updatedAt2 := product2.UpdatedAt - DB.Model(product2).UpdateColumn("code", "update_column_new") - var product4 Product - DB.First(&product4, product2.Id) - if updatedAt2.Format(time.RFC3339Nano) != product4.UpdatedAt.Format(time.RFC3339Nano) { - t.Errorf("updatedAt should not be updated with update column") - } - - DB.Model(&product4).UpdateColumn("price", gorm.Expr("price + 100 - 50")) - var product5 Product - DB.First(&product5, product4.Id) - if product5.Price != product4.Price+100-50 { - t.Errorf("UpdateColumn with expression") - } - if product5.UpdatedAt.Format(time.RFC3339Nano) != product4.UpdatedAt.Format(time.RFC3339Nano) { - t.Errorf("UpdateColumn with expression should not update UpdatedAt") - } -} - -func TestSelectWithUpdate(t *testing.T) { - user := getPreparedUser("select_user", "select_with_update") - DB.Create(user) - - var reloadUser User - DB.First(&reloadUser, user.Id) - reloadUser.Name = "new_name" - reloadUser.Age = 50 - reloadUser.BillingAddress = Address{Address1: "New Billing Address"} - reloadUser.ShippingAddress 
= Address{Address1: "New ShippingAddress Address"} - reloadUser.CreditCard = CreditCard{Number: "987654321"} - reloadUser.Emails = []Email{ - {Email: "new_user_1@example1.com"}, {Email: "new_user_2@example2.com"}, {Email: "new_user_3@example2.com"}, - } - reloadUser.Company = Company{Name: "new company"} - - DB.Select("Name", "BillingAddress", "CreditCard", "Company", "Emails").Save(&reloadUser) - - var queryUser User - DB.Preload("BillingAddress").Preload("ShippingAddress"). - Preload("CreditCard").Preload("Emails").Preload("Company").First(&queryUser, user.Id) - - if queryUser.Name == user.Name || queryUser.Age != user.Age { - t.Errorf("Should only update users with name column") - } - - if queryUser.BillingAddressID.Int64 == user.BillingAddressID.Int64 || - queryUser.ShippingAddressId != user.ShippingAddressId || - queryUser.CreditCard.ID == user.CreditCard.ID || - len(queryUser.Emails) == len(user.Emails) || queryUser.Company.Id == user.Company.Id { - t.Errorf("Should only update selected relationships") - } -} - -func TestSelectWithUpdateWithMap(t *testing.T) { - user := getPreparedUser("select_user", "select_with_update_map") - DB.Create(user) - - updateValues := map[string]interface{}{ - "Name": "new_name", - "Age": 50, - "BillingAddress": Address{Address1: "New Billing Address"}, - "ShippingAddress": Address{Address1: "New ShippingAddress Address"}, - "CreditCard": CreditCard{Number: "987654321"}, - "Emails": []Email{ - {Email: "new_user_1@example1.com"}, {Email: "new_user_2@example2.com"}, {Email: "new_user_3@example2.com"}, - }, - "Company": Company{Name: "new company"}, - } - - var reloadUser User - DB.First(&reloadUser, user.Id) - DB.Model(&reloadUser).Select("Name", "BillingAddress", "CreditCard", "Company", "Emails").Update(updateValues) - - var queryUser User - DB.Preload("BillingAddress").Preload("ShippingAddress"). - Preload("CreditCard").Preload("Emails").Preload("Company").First(&queryUser, user.Id) - - if queryUser.Name == user.Name || queryUser.Age != user.Age { - t.Errorf("Should only update users with name column") - } - - if queryUser.BillingAddressID.Int64 == user.BillingAddressID.Int64 || - queryUser.ShippingAddressId != user.ShippingAddressId || - queryUser.CreditCard.ID == user.CreditCard.ID || - len(queryUser.Emails) == len(user.Emails) || queryUser.Company.Id == user.Company.Id { - t.Errorf("Should only update selected relationships") - } -} - -func TestOmitWithUpdate(t *testing.T) { - user := getPreparedUser("omit_user", "omit_with_update") - DB.Create(user) - - var reloadUser User - DB.First(&reloadUser, user.Id) - reloadUser.Name = "new_name" - reloadUser.Age = 50 - reloadUser.BillingAddress = Address{Address1: "New Billing Address"} - reloadUser.ShippingAddress = Address{Address1: "New ShippingAddress Address"} - reloadUser.CreditCard = CreditCard{Number: "987654321"} - reloadUser.Emails = []Email{ - {Email: "new_user_1@example1.com"}, {Email: "new_user_2@example2.com"}, {Email: "new_user_3@example2.com"}, - } - reloadUser.Company = Company{Name: "new company"} - - DB.Omit("Name", "BillingAddress", "CreditCard", "Company", "Emails").Save(&reloadUser) - - var queryUser User - DB.Preload("BillingAddress").Preload("ShippingAddress"). 
- Preload("CreditCard").Preload("Emails").Preload("Company").First(&queryUser, user.Id) - - if queryUser.Name != user.Name || queryUser.Age == user.Age { - t.Errorf("Should only update users with name column") - } - - if queryUser.BillingAddressID.Int64 != user.BillingAddressID.Int64 || - queryUser.ShippingAddressId == user.ShippingAddressId || - queryUser.CreditCard.ID != user.CreditCard.ID || - len(queryUser.Emails) != len(user.Emails) || queryUser.Company.Id != user.Company.Id { - t.Errorf("Should only update relationships that not omitted") - } -} - -func TestOmitWithUpdateWithMap(t *testing.T) { - user := getPreparedUser("select_user", "select_with_update_map") - DB.Create(user) - - updateValues := map[string]interface{}{ - "Name": "new_name", - "Age": 50, - "BillingAddress": Address{Address1: "New Billing Address"}, - "ShippingAddress": Address{Address1: "New ShippingAddress Address"}, - "CreditCard": CreditCard{Number: "987654321"}, - "Emails": []Email{ - {Email: "new_user_1@example1.com"}, {Email: "new_user_2@example2.com"}, {Email: "new_user_3@example2.com"}, - }, - "Company": Company{Name: "new company"}, - } - - var reloadUser User - DB.First(&reloadUser, user.Id) - DB.Model(&reloadUser).Omit("Name", "BillingAddress", "CreditCard", "Company", "Emails").Update(updateValues) - - var queryUser User - DB.Preload("BillingAddress").Preload("ShippingAddress"). - Preload("CreditCard").Preload("Emails").Preload("Company").First(&queryUser, user.Id) - - if queryUser.Name != user.Name || queryUser.Age == user.Age { - t.Errorf("Should only update users with name column") - } - - if queryUser.BillingAddressID.Int64 != user.BillingAddressID.Int64 || - queryUser.ShippingAddressId == user.ShippingAddressId || - queryUser.CreditCard.ID != user.CreditCard.ID || - len(queryUser.Emails) != len(user.Emails) || queryUser.Company.Id != user.Company.Id { - t.Errorf("Should only update relationships not omitted") - } -} - -func TestSelectWithUpdateColumn(t *testing.T) { - user := getPreparedUser("select_user", "select_with_update_map") - DB.Create(user) - - updateValues := map[string]interface{}{"Name": "new_name", "Age": 50} - - var reloadUser User - DB.First(&reloadUser, user.Id) - DB.Model(&reloadUser).Select("Name").UpdateColumn(updateValues) - - var queryUser User - DB.First(&queryUser, user.Id) - - if queryUser.Name == user.Name || queryUser.Age != user.Age { - t.Errorf("Should only update users with name column") - } -} - -func TestOmitWithUpdateColumn(t *testing.T) { - user := getPreparedUser("select_user", "select_with_update_map") - DB.Create(user) - - updateValues := map[string]interface{}{"Name": "new_name", "Age": 50} - - var reloadUser User - DB.First(&reloadUser, user.Id) - DB.Model(&reloadUser).Omit("Name").UpdateColumn(updateValues) - - var queryUser User - DB.First(&queryUser, user.Id) - - if queryUser.Name != user.Name || queryUser.Age == user.Age { - t.Errorf("Should omit name column when update user") - } -} - -func TestUpdateColumnsSkipsAssociations(t *testing.T) { - user := getPreparedUser("update_columns_user", "special_role") - user.Age = 99 - address1 := "first street" - user.BillingAddress = Address{Address1: address1} - DB.Save(user) - - // Update a single field of the user and verify that the changed address is not stored. 
- newAge := int64(100) - user.BillingAddress.Address1 = "second street" - db := DB.Model(user).UpdateColumns(User{Age: newAge}) - if db.RowsAffected != 1 { - t.Errorf("Expected RowsAffected=1 but instead RowsAffected=%v", DB.RowsAffected) - } - - // Verify that Age now=`newAge`. - freshUser := &User{Id: user.Id} - DB.First(freshUser) - if freshUser.Age != newAge { - t.Errorf("Expected freshly queried user to have Age=%v but instead found Age=%v", newAge, freshUser.Age) - } - - // Verify that user's BillingAddress.Address1 is not changed and is still "first street". - DB.First(&freshUser.BillingAddress, freshUser.BillingAddressID) - if freshUser.BillingAddress.Address1 != address1 { - t.Errorf("Expected user's BillingAddress.Address1=%s to remain unchanged after UpdateColumns invocation, but BillingAddress.Address1=%s", address1, freshUser.BillingAddress.Address1) - } -} - -func TestUpdatesWithBlankValues(t *testing.T) { - product := Product{Code: "product1", Price: 10} - DB.Save(&product) - - DB.Model(&Product{Id: product.Id}).Updates(&Product{Price: 100}) - - var product1 Product - DB.First(&product1, product.Id) - - if product1.Code != "product1" || product1.Price != 100 { - t.Errorf("product's code should not be updated") - } -} - -type ElementWithIgnoredField struct { - Id int64 - Value string - IgnoredField int64 `sql:"-"` -} - -func (e ElementWithIgnoredField) TableName() string { - return "element_with_ignored_field" -} - -func TestUpdatesTableWithIgnoredValues(t *testing.T) { - elem := ElementWithIgnoredField{Value: "foo", IgnoredField: 10} - DB.Save(&elem) - - DB.Table(elem.TableName()). - Where("id = ?", elem.Id). - // DB.Model(&ElementWithIgnoredField{Id: elem.Id}). - Updates(&ElementWithIgnoredField{Value: "bar", IgnoredField: 100}) - - var elem1 ElementWithIgnoredField - err := DB.First(&elem1, elem.Id).Error - if err != nil { - t.Errorf("error getting an element from database: %s", err.Error()) - } - - if elem1.IgnoredField != 0 { - t.Errorf("element's ignored field should not be updated") - } -} - -func TestUpdateDecodeVirtualAttributes(t *testing.T) { - var user = User{ - Name: "jinzhu", - IgnoreMe: 88, - } - - DB.Save(&user) - - DB.Model(&user).Updates(User{Name: "jinzhu2", IgnoreMe: 100}) - - if user.IgnoreMe != 100 { - t.Errorf("should decode virtual attributes to struct, so it could be used in callbacks") - } -} diff --git a/utils.go b/utils.go deleted file mode 100644 index 97a3d175..00000000 --- a/utils.go +++ /dev/null @@ -1,285 +0,0 @@ -package gorm - -import ( - "bytes" - "database/sql/driver" - "fmt" - "reflect" - "regexp" - "runtime" - "strings" - "sync" - "time" -) - -// NowFunc returns current time, this function is exported in order to be able -// to give the flexibility to the developer to customize it according to their -// needs, e.g: -// gorm.NowFunc = func() time.Time { -// return time.Now().UTC() -// } -var NowFunc = func() time.Time { - return time.Now() -} - -// Copied from golint -var commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UI", "UID", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"} -var commonInitialismsReplacer *strings.Replacer - -var goSrcRegexp = regexp.MustCompile(`jinzhu/gorm/.*.go`) -var goTestRegexp = regexp.MustCompile(`jinzhu/gorm/.*test.go`) - -func init() { - var commonInitialismsForReplacer []string - for _, initialism := range commonInitialisms { - 
commonInitialismsForReplacer = append(commonInitialismsForReplacer, initialism, strings.Title(strings.ToLower(initialism))) - } - commonInitialismsReplacer = strings.NewReplacer(commonInitialismsForReplacer...) -} - -type safeMap struct { - m map[string]string - l *sync.RWMutex -} - -func (s *safeMap) Set(key string, value string) { - s.l.Lock() - defer s.l.Unlock() - s.m[key] = value -} - -func (s *safeMap) Get(key string) string { - s.l.RLock() - defer s.l.RUnlock() - return s.m[key] -} - -func newSafeMap() *safeMap { - return &safeMap{l: new(sync.RWMutex), m: make(map[string]string)} -} - -var smap = newSafeMap() - -type strCase bool - -const ( - lower strCase = false - upper strCase = true -) - -// ToDBName convert string to db name -func ToDBName(name string) string { - if v := smap.Get(name); v != "" { - return v - } - - if name == "" { - return "" - } - - var ( - value = commonInitialismsReplacer.Replace(name) - buf = bytes.NewBufferString("") - lastCase, currCase, nextCase strCase - ) - - for i, v := range value[:len(value)-1] { - nextCase = strCase(value[i+1] >= 'A' && value[i+1] <= 'Z') - if i > 0 { - if currCase == upper { - if lastCase == upper && nextCase == upper { - buf.WriteRune(v) - } else { - if value[i-1] != '_' && value[i+1] != '_' { - buf.WriteRune('_') - } - buf.WriteRune(v) - } - } else { - buf.WriteRune(v) - if i == len(value)-2 && nextCase == upper { - buf.WriteRune('_') - } - } - } else { - currCase = upper - buf.WriteRune(v) - } - lastCase = currCase - currCase = nextCase - } - - buf.WriteByte(value[len(value)-1]) - - s := strings.ToLower(buf.String()) - smap.Set(name, s) - return s -} - -// SQL expression -type expr struct { - expr string - args []interface{} -} - -// Expr generate raw SQL expression, for example: -// DB.Model(&product).Update("price", gorm.Expr("price * ? 
+ ?", 2, 100)) -func Expr(expression string, args ...interface{}) *expr { - return &expr{expr: expression, args: args} -} - -func indirect(reflectValue reflect.Value) reflect.Value { - for reflectValue.Kind() == reflect.Ptr { - reflectValue = reflectValue.Elem() - } - return reflectValue -} - -func toQueryMarks(primaryValues [][]interface{}) string { - var results []string - - for _, primaryValue := range primaryValues { - var marks []string - for range primaryValue { - marks = append(marks, "?") - } - - if len(marks) > 1 { - results = append(results, fmt.Sprintf("(%v)", strings.Join(marks, ","))) - } else { - results = append(results, strings.Join(marks, "")) - } - } - return strings.Join(results, ",") -} - -func toQueryCondition(scope *Scope, columns []string) string { - var newColumns []string - for _, column := range columns { - newColumns = append(newColumns, scope.Quote(column)) - } - - if len(columns) > 1 { - return fmt.Sprintf("(%v)", strings.Join(newColumns, ",")) - } - return strings.Join(newColumns, ",") -} - -func toQueryValues(values [][]interface{}) (results []interface{}) { - for _, value := range values { - for _, v := range value { - results = append(results, v) - } - } - return -} - -func fileWithLineNum() string { - for i := 2; i < 15; i++ { - _, file, line, ok := runtime.Caller(i) - if ok && (!goSrcRegexp.MatchString(file) || goTestRegexp.MatchString(file)) { - return fmt.Sprintf("%v:%v", file, line) - } - } - return "" -} - -func isBlank(value reflect.Value) bool { - switch value.Kind() { - case reflect.String: - return value.Len() == 0 - case reflect.Bool: - return !value.Bool() - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return value.Int() == 0 - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: - return value.Uint() == 0 - case reflect.Float32, reflect.Float64: - return value.Float() == 0 - case reflect.Interface, reflect.Ptr: - return value.IsNil() - } - - return reflect.DeepEqual(value.Interface(), reflect.Zero(value.Type()).Interface()) -} - -func toSearchableMap(attrs ...interface{}) (result interface{}) { - if len(attrs) > 1 { - if str, ok := attrs[0].(string); ok { - result = map[string]interface{}{str: attrs[1]} - } - } else if len(attrs) == 1 { - if attr, ok := attrs[0].(map[string]interface{}); ok { - result = attr - } - - if attr, ok := attrs[0].(interface{}); ok { - result = attr - } - } - return -} - -func equalAsString(a interface{}, b interface{}) bool { - return toString(a) == toString(b) -} - -func toString(str interface{}) string { - if values, ok := str.([]interface{}); ok { - var results []string - for _, value := range values { - results = append(results, toString(value)) - } - return strings.Join(results, "_") - } else if bytes, ok := str.([]byte); ok { - return string(bytes) - } else if reflectValue := reflect.Indirect(reflect.ValueOf(str)); reflectValue.IsValid() { - return fmt.Sprintf("%v", reflectValue.Interface()) - } - return "" -} - -func makeSlice(elemType reflect.Type) interface{} { - if elemType.Kind() == reflect.Slice { - elemType = elemType.Elem() - } - sliceType := reflect.SliceOf(elemType) - slice := reflect.New(sliceType) - slice.Elem().Set(reflect.MakeSlice(sliceType, 0, 0)) - return slice.Interface() -} - -func strInSlice(a string, list []string) bool { - for _, b := range list { - if b == a { - return true - } - } - return false -} - -// getValueFromFields return given fields's value -func getValueFromFields(value reflect.Value, fieldNames 
[]string) (results []interface{}) { - // If value is a nil pointer, Indirect returns a zero Value! - // Therefor we need to check for a zero value, - // as FieldByName could panic - if indirectValue := reflect.Indirect(value); indirectValue.IsValid() { - for _, fieldName := range fieldNames { - if fieldValue := indirectValue.FieldByName(fieldName); fieldValue.IsValid() { - result := fieldValue.Interface() - if r, ok := result.(driver.Valuer); ok { - result, _ = r.Value() - } - results = append(results, result) - } - } - } - return -} - -func addExtraSpaceIfExist(str string) string { - if str != "" { - return " " + str - } - return "" -} diff --git a/utils/tests/dummy_dialecter.go b/utils/tests/dummy_dialecter.go new file mode 100644 index 00000000..b8452ef9 --- /dev/null +++ b/utils/tests/dummy_dialecter.go @@ -0,0 +1,45 @@ +package tests + +import ( + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/logger" + "gorm.io/gorm/schema" +) + +type DummyDialector struct { +} + +func (DummyDialector) Name() string { + return "dummy" +} + +func (DummyDialector) Initialize(*gorm.DB) error { + return nil +} + +func (DummyDialector) DefaultValueOf(field *schema.Field) clause.Expression { + return clause.Expr{SQL: "DEFAULT"} +} + +func (DummyDialector) Migrator(*gorm.DB) gorm.Migrator { + return nil +} + +func (DummyDialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) { + writer.WriteByte('?') +} + +func (DummyDialector) QuoteTo(writer clause.Writer, str string) { + writer.WriteByte('`') + writer.WriteString(str) + writer.WriteByte('`') +} + +func (DummyDialector) Explain(sql string, vars ...interface{}) string { + return logger.ExplainSQL(sql, nil, `"`, vars...) +} + +func (DummyDialector) DataTypeOf(*schema.Field) string { + return "" +} diff --git a/utils/tests/models.go b/utils/tests/models.go new file mode 100644 index 00000000..2c5e71c0 --- /dev/null +++ b/utils/tests/models.go @@ -0,0 +1,60 @@ +package tests + +import ( + "database/sql" + "time" + + "gorm.io/gorm" +) + +// User has one `Account` (has one), many `Pets` (has many) and `Toys` (has many - polymorphic) +// He works in a Company (belongs to), he has a Manager (belongs to - single-table), and also managed a Team (has many - single-table) +// He speaks many languages (many to many) and has many friends (many to many - single-table) +// His pet also has one Toy (has one - polymorphic) +type User struct { + gorm.Model + Name string + Age uint + Birthday *time.Time + Account Account + Pets []*Pet + Toys []Toy `gorm:"polymorphic:Owner"` + CompanyID *int + Company Company + ManagerID *uint + Manager *User + Team []User `gorm:"foreignkey:ManagerID"` + Languages []Language `gorm:"many2many:UserSpeak;"` + Friends []*User `gorm:"many2many:user_friends;"` + Active bool +} + +type Account struct { + gorm.Model + UserID sql.NullInt64 + Number string +} + +type Pet struct { + gorm.Model + UserID *uint + Name string + Toy Toy `gorm:"polymorphic:Owner;"` +} + +type Toy struct { + gorm.Model + Name string + OwnerID string + OwnerType string +} + +type Company struct { + ID int + Name string +} + +type Language struct { + Code string `gorm:"primarykey"` + Name string +} diff --git a/utils/tests/utils.go b/utils/tests/utils.go new file mode 100644 index 00000000..817e4b0b --- /dev/null +++ b/utils/tests/utils.go @@ -0,0 +1,117 @@ +package tests + +import ( + "database/sql/driver" + "fmt" + "go/ast" + "reflect" + "testing" + "time" + + "gorm.io/gorm/utils" +) + +func AssertObjEqual(t *testing.T, r, e interface{}, names 
...string) { + for _, name := range names { + got := reflect.Indirect(reflect.ValueOf(r)).FieldByName(name).Interface() + expect := reflect.Indirect(reflect.ValueOf(e)).FieldByName(name).Interface() + t.Run(name, func(t *testing.T) { + AssertEqual(t, got, expect) + }) + } +} + +func AssertEqual(t *testing.T, got, expect interface{}) { + if !reflect.DeepEqual(got, expect) { + isEqual := func() { + if curTime, ok := got.(time.Time); ok { + format := "2006-01-02T15:04:05Z07:00" + + if curTime.Round(time.Second).UTC().Format(format) != expect.(time.Time).Round(time.Second).UTC().Format(format) && curTime.Truncate(time.Second).UTC().Format(format) != expect.(time.Time).Truncate(time.Second).UTC().Format(format) { + t.Errorf("%v: expect: %v, got %v after time round", utils.FileWithLineNum(), expect.(time.Time), curTime) + } + } else if fmt.Sprint(got) != fmt.Sprint(expect) { + t.Errorf("%v: expect: %#v, got %#v", utils.FileWithLineNum(), expect, got) + } + } + + if fmt.Sprint(got) == fmt.Sprint(expect) { + return + } + + if reflect.Indirect(reflect.ValueOf(got)).IsValid() != reflect.Indirect(reflect.ValueOf(expect)).IsValid() { + t.Errorf("%v: expect: %+v, got %+v", utils.FileWithLineNum(), expect, got) + return + } + + if valuer, ok := got.(driver.Valuer); ok { + got, _ = valuer.Value() + } + + if valuer, ok := expect.(driver.Valuer); ok { + expect, _ = valuer.Value() + } + + if got != nil { + got = reflect.Indirect(reflect.ValueOf(got)).Interface() + } + + if expect != nil { + expect = reflect.Indirect(reflect.ValueOf(expect)).Interface() + } + + if reflect.ValueOf(got).IsValid() != reflect.ValueOf(expect).IsValid() { + t.Errorf("%v: expect: %+v, got %+v", utils.FileWithLineNum(), expect, got) + return + } + + if reflect.ValueOf(got).Kind() == reflect.Slice { + if reflect.ValueOf(expect).Kind() == reflect.Slice { + if reflect.ValueOf(got).Len() == reflect.ValueOf(expect).Len() { + for i := 0; i < reflect.ValueOf(got).Len(); i++ { + name := fmt.Sprintf(reflect.ValueOf(got).Type().Name()+" #%v", i) + t.Run(name, func(t *testing.T) { + AssertEqual(t, reflect.ValueOf(got).Index(i).Interface(), reflect.ValueOf(expect).Index(i).Interface()) + }) + } + } else { + name := reflect.ValueOf(got).Type().Elem().Name() + t.Errorf("%v expects length: %v, got %v (expects: %+v, got %+v)", name, reflect.ValueOf(expect).Len(), reflect.ValueOf(got).Len(), expect, got) + } + return + } + } + + if reflect.ValueOf(got).Kind() == reflect.Struct { + if reflect.ValueOf(got).NumField() == reflect.ValueOf(expect).NumField() { + exported := false + for i := 0; i < reflect.ValueOf(got).NumField(); i++ { + if fieldStruct := reflect.ValueOf(got).Type().Field(i); ast.IsExported(fieldStruct.Name) { + exported = true + field := reflect.ValueOf(got).Field(i) + t.Run(fieldStruct.Name, func(t *testing.T) { + AssertEqual(t, field.Interface(), reflect.ValueOf(expect).Field(i).Interface()) + }) + } + } + + if exported { + return + } + } + } + + if reflect.ValueOf(got).Type().ConvertibleTo(reflect.ValueOf(expect).Type()) { + got = reflect.ValueOf(got).Convert(reflect.ValueOf(expect).Type()).Interface() + isEqual() + } else if reflect.ValueOf(expect).Type().ConvertibleTo(reflect.ValueOf(got).Type()) { + expect = reflect.ValueOf(got).Convert(reflect.ValueOf(got).Type()).Interface() + isEqual() + } + } +} + +func Now() *time.Time { + now := time.Now() + return &now +} diff --git a/utils/utils.go b/utils/utils.go new file mode 100644 index 00000000..ecba7fb9 --- /dev/null +++ b/utils/utils.go @@ -0,0 +1,113 @@ +package utils + +import ( + 
"database/sql/driver" + "fmt" + "reflect" + "regexp" + "runtime" + "strconv" + "strings" + "unicode" +) + +var gormSourceDir string + +func init() { + _, file, _, _ := runtime.Caller(0) + gormSourceDir = regexp.MustCompile(`utils.utils\.go`).ReplaceAllString(file, "") +} + +func FileWithLineNum() string { + for i := 2; i < 15; i++ { + _, file, line, ok := runtime.Caller(i) + + if ok && (!strings.HasPrefix(file, gormSourceDir) || strings.HasSuffix(file, "_test.go")) { + return file + ":" + strconv.FormatInt(int64(line), 10) + } + } + return "" +} + +func IsValidDBNameChar(c rune) bool { + return !unicode.IsLetter(c) && !unicode.IsNumber(c) && c != '.' && c != '*' && c != '_' && c != '$' && c != '@' +} + +func CheckTruth(val interface{}) bool { + if v, ok := val.(bool); ok { + return v + } + + if v, ok := val.(string); ok { + v = strings.ToLower(v) + return v != "false" + } + + return !reflect.ValueOf(val).IsZero() +} + +func ToStringKey(values ...interface{}) string { + results := make([]string, len(values)) + + for idx, value := range values { + if valuer, ok := value.(driver.Valuer); ok { + value, _ = valuer.Value() + } + + switch v := value.(type) { + case string: + results[idx] = v + case []byte: + results[idx] = string(v) + case uint: + results[idx] = strconv.FormatUint(uint64(v), 10) + default: + results[idx] = fmt.Sprint(reflect.Indirect(reflect.ValueOf(v)).Interface()) + } + } + + return strings.Join(results, "_") +} + +func AssertEqual(src, dst interface{}) bool { + if !reflect.DeepEqual(src, dst) { + if valuer, ok := src.(driver.Valuer); ok { + src, _ = valuer.Value() + } + + if valuer, ok := dst.(driver.Valuer); ok { + dst, _ = valuer.Value() + } + + return reflect.DeepEqual(src, dst) + } + return true +} + +func ToString(value interface{}) string { + switch v := value.(type) { + case string: + return v + case int: + return strconv.FormatInt(int64(v), 10) + case int8: + return strconv.FormatInt(int64(v), 10) + case int16: + return strconv.FormatInt(int64(v), 10) + case int32: + return strconv.FormatInt(int64(v), 10) + case int64: + return strconv.FormatInt(v, 10) + case uint: + return strconv.FormatUint(uint64(v), 10) + case uint8: + return strconv.FormatUint(uint64(v), 10) + case uint16: + return strconv.FormatUint(uint64(v), 10) + case uint32: + return strconv.FormatUint(uint64(v), 10) + case uint64: + return strconv.FormatUint(v, 10) + } + return "" +} diff --git a/utils/utils_test.go b/utils/utils_test.go new file mode 100644 index 00000000..5737c511 --- /dev/null +++ b/utils/utils_test.go @@ -0,0 +1,14 @@ +package utils + +import ( + "strings" + "testing" +) + +func TestIsValidDBNameChar(t *testing.T) { + for _, db := range []string{"db", "dbName", "db_name", "db1", "1dbname", "db$name"} { + if fields := strings.FieldsFunc(db, IsValidDBNameChar); len(fields) != 1 { + t.Fatalf("failed to parse db name %v", db) + } + } +} diff --git a/utils_test.go b/utils_test.go deleted file mode 100644 index 152296d2..00000000 --- a/utils_test.go +++ /dev/null @@ -1,32 +0,0 @@ -package gorm_test - -import ( - "testing" - - "github.com/jinzhu/gorm" -) - -func TestToDBNameGenerateFriendlyName(t *testing.T) { - var maps = map[string]string{ - "": "", - "X": "x", - "ThisIsATest": "this_is_a_test", - "PFAndESI": "pf_and_esi", - "AbcAndJkl": "abc_and_jkl", - "EmployeeID": "employee_id", - "SKU_ID": "sku_id", - "FieldX": "field_x", - "HTTPAndSMTP": "http_and_smtp", - "HTTPServerHandlerForURLID": "http_server_handler_for_url_id", - "UUID": "uuid", - "HTTPURL": "http_url", - "HTTP_URL": "http_url", 
- "ThisIsActuallyATestSoWeMayBeAbleToUseThisCodeInGormPackageAlsoIdCanBeUsedAtTheEndAsID": "this_is_actually_a_test_so_we_may_be_able_to_use_this_code_in_gorm_package_also_id_can_be_used_at_the_end_as_id", - } - - for key, value := range maps { - if gorm.ToDBName(key) != value { - t.Errorf("%v ToDBName should equal %v, but got %v", key, value, gorm.ToDBName(key)) - } - } -} diff --git a/wercker.yml b/wercker.yml deleted file mode 100644 index ff6fb17c..00000000 --- a/wercker.yml +++ /dev/null @@ -1,53 +0,0 @@ -# use the default golang container from Docker Hub -box: golang - -services: - - id: mariadb:10.0 - env: - MYSQL_DATABASE: gorm - MYSQL_USER: gorm - MYSQL_PASSWORD: gorm - MYSQL_RANDOM_ROOT_PASSWORD: "yes" - - id: postgres - env: - POSTGRES_USER: gorm - POSTGRES_PASSWORD: gorm - POSTGRES_DB: gorm - -# The steps that will be executed in the build pipeline -build: - # The steps that will be executed on build - steps: - # Sets the go workspace and places you package - # at the right place in the workspace tree - - setup-go-workspace - - # Gets the dependencies - - script: - name: go get - code: | - cd $WERCKER_SOURCE_DIR - go version - go get -t ./... - - # Build the project - - script: - name: go build - code: | - go build ./... - - # Test the project - - script: - name: test sqlite - code: | - go test ./... - - - script: - name: test mysql - code: | - GORM_DIALECT=mysql GORM_DBADDRESS=mariadb:3306 go test ./... - - - script: - name: test postgres - code: | - GORM_DIALECT=postgres GORM_DBHOST=postgres go test ./...