Merge remote-tracking branch 'gorm/master'
This commit is contained in:
commit
cd42f9571b
@ -1,11 +0,0 @@
|
||||
---
|
||||
engines:
|
||||
gofmt:
|
||||
enabled: true
|
||||
govet:
|
||||
enabled: true
|
||||
golint:
|
||||
enabled: true
|
||||
ratings:
|
||||
paths:
|
||||
- "**.go"
|
||||
5
.github/FUNDING.yml
vendored
Normal file
5
.github/FUNDING.yml
vendored
Normal file
@ -0,0 +1,5 @@
|
||||
# These are supported funding model platforms
|
||||
|
||||
github: [jinzhu]
|
||||
patreon: jinzhu
|
||||
open_collective: gorm
|
||||
54
.github/ISSUE_TEMPLATE.md
vendored
54
.github/ISSUE_TEMPLATE.md
vendored
@ -1,54 +0,0 @@
|
||||
Before posting a bug report about a problem, please try to verify that it is a bug and that it has not been reported already, please apply corresponding GitHub labels to the issue, for feature requests, please apply `type:feature`.
|
||||
|
||||
DON'T post usage related questions, ask in https://gitter.im/jinzhu/gorm or http://stackoverflow.com/questions/tagged/go-gorm,
|
||||
|
||||
Please answer these questions before submitting your issue. Thanks!
|
||||
|
||||
|
||||
|
||||
### What version of Go are you using (`go version`)?
|
||||
|
||||
|
||||
### Which database and its version are you using?
|
||||
|
||||
|
||||
### What did you do?
|
||||
|
||||
Please provide a complete runnable program to reproduce your issue.
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"github.com/jinzhu/gorm"
|
||||
_ "github.com/jinzhu/gorm/dialects/mssql"
|
||||
_ "github.com/jinzhu/gorm/dialects/mysql"
|
||||
_ "github.com/jinzhu/gorm/dialects/postgres"
|
||||
_ "github.com/jinzhu/gorm/dialects/sqlite"
|
||||
)
|
||||
|
||||
var db *gorm.DB
|
||||
|
||||
func init() {
|
||||
var err error
|
||||
db, err = gorm.Open("sqlite3", "test.db")
|
||||
// Please use below username, password as your database's account for the script.
|
||||
// db, err = gorm.Open("postgres", "user=gorm dbname=gorm sslmode=disable")
|
||||
// db, err = gorm.Open("mysql", "gorm:gorm@/dbname?charset=utf8&parseTime=True")
|
||||
// db, err = gorm.Open("mssql", "sqlserver://gorm:LoremIpsum86@localhost:1433?database=gorm")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
db.LogMode(true)
|
||||
}
|
||||
|
||||
func main() {
|
||||
// your code here
|
||||
|
||||
if /* failure condition */ {
|
||||
fmt.Println("failed")
|
||||
} else {
|
||||
fmt.Println("success")
|
||||
}
|
||||
}
|
||||
```
|
||||
14
.github/PULL_REQUEST_TEMPLATE.md
vendored
14
.github/PULL_REQUEST_TEMPLATE.md
vendored
@ -1,14 +0,0 @@
|
||||
Make sure these boxes checked before submitting your pull request.
|
||||
|
||||
- [] Do only one thing
|
||||
- [] No API-breaking changes
|
||||
- [] New code/logic commented & tested
|
||||
- [] Write good commit message, try to squash your commits into a single one
|
||||
- [] Run `./build.sh` in `gh-pages` branch for document changes
|
||||
|
||||
For significant changes like big bug fixes, new features, please open an issue to make a agreement on an implementation design/plan first before starting it.
|
||||
|
||||
Thank you.
|
||||
|
||||
|
||||
### What did this pull request do?
|
||||
166
.github/labels.json
vendored
Normal file
166
.github/labels.json
vendored
Normal file
@ -0,0 +1,166 @@
|
||||
{
|
||||
"labels": {
|
||||
"critical": {
|
||||
"name": "type:critical",
|
||||
"colour": "#E84137",
|
||||
"description": "critical questions"
|
||||
},
|
||||
"question": {
|
||||
"name": "type:question",
|
||||
"colour": "#EDEDED",
|
||||
"description": "general questions"
|
||||
},
|
||||
"feature": {
|
||||
"name": "type:feature_request",
|
||||
"colour": "#43952A",
|
||||
"description": "feature request"
|
||||
},
|
||||
"invalid_question": {
|
||||
"name": "type:invalid question",
|
||||
"colour": "#CF2E1F",
|
||||
"description": "invalid question (not related to GORM or described in document or not enough information provided)"
|
||||
},
|
||||
"with_playground": {
|
||||
"name": "type:with reproduction steps",
|
||||
"colour": "#00ff00",
|
||||
"description": "with reproduction steps"
|
||||
},
|
||||
"without_playground": {
|
||||
"name": "type:missing reproduction steps",
|
||||
"colour": "#CF2E1F",
|
||||
"description": "missing reproduction steps"
|
||||
},
|
||||
"has_pr": {
|
||||
"name": "type:has pull request",
|
||||
"colour": "#43952A",
|
||||
"description": "has pull request"
|
||||
},
|
||||
"not_tested": {
|
||||
"name": "type:not tested",
|
||||
"colour": "#CF2E1F",
|
||||
"description": "not tested"
|
||||
},
|
||||
"tested": {
|
||||
"name": "type:tested",
|
||||
"colour": "#00ff00",
|
||||
"description": "tested"
|
||||
},
|
||||
"breaking_change": {
|
||||
"name": "type:breaking change",
|
||||
"colour": "#CF2E1F",
|
||||
"description": "breaking change"
|
||||
}
|
||||
},
|
||||
"issue": {
|
||||
"with_playground": {
|
||||
"requires": 1,
|
||||
"conditions": [
|
||||
{
|
||||
"type": "descriptionMatches",
|
||||
"pattern": "/github.com\/go-gorm\/playground\/pull\/\\d\\d+/s"
|
||||
}
|
||||
]
|
||||
},
|
||||
"critical": {
|
||||
"requires": 1,
|
||||
"conditions": [
|
||||
{
|
||||
"type": "descriptionMatches",
|
||||
"pattern": "/(critical|urgent)/i"
|
||||
},
|
||||
{
|
||||
"type": "titleMatches",
|
||||
"pattern": "/(critical|urgent)/i"
|
||||
}
|
||||
]
|
||||
},
|
||||
"question": {
|
||||
"requires": 1,
|
||||
"conditions": [
|
||||
{
|
||||
"type": "titleMatches",
|
||||
"pattern": "/question/i"
|
||||
},
|
||||
{
|
||||
"type": "descriptionMatches",
|
||||
"pattern": "/question/i"
|
||||
}
|
||||
]
|
||||
},
|
||||
"feature": {
|
||||
"requires": 1,
|
||||
"conditions": [
|
||||
{
|
||||
"type": "titleMatches",
|
||||
"pattern": "/feature/i"
|
||||
},
|
||||
{
|
||||
"type": "descriptionMatches",
|
||||
"pattern": "/Describe the feature/i"
|
||||
}
|
||||
]
|
||||
},
|
||||
"without_playground": {
|
||||
"requires": 6,
|
||||
"conditions": [
|
||||
{
|
||||
"type": "descriptionMatches",
|
||||
"pattern": "/^((?!github.com\/go-gorm\/playground\/pull\/\\d\\d+).)*$/s"
|
||||
},
|
||||
{
|
||||
"type": "titleMatches",
|
||||
"pattern": "/^((?!question).)*$/s"
|
||||
},
|
||||
{
|
||||
"type": "descriptionMatches",
|
||||
"pattern": "/^((?!question).)*$/is"
|
||||
},
|
||||
{
|
||||
"type": "descriptionMatches",
|
||||
"pattern": "/^((?!Describe the feature).)*$/is"
|
||||
},
|
||||
{
|
||||
"type": "titleMatches",
|
||||
"pattern": "/^((?!critical|urgent).)*$/s"
|
||||
},
|
||||
{
|
||||
"type": "descriptionMatches",
|
||||
"pattern": "/^((?!critical|urgent).)*$/s"
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"pr": {
|
||||
"critical": {
|
||||
"requires": 1,
|
||||
"conditions": [
|
||||
{
|
||||
"type": "descriptionMatches",
|
||||
"pattern": "/(critical|urgent)/i"
|
||||
},
|
||||
{
|
||||
"type": "titleMatches",
|
||||
"pattern": "/(critical|urgent)/i"
|
||||
}
|
||||
]
|
||||
},
|
||||
"not_tested": {
|
||||
"requires": 1,
|
||||
"conditions": [
|
||||
{
|
||||
"type": "descriptionMatches",
|
||||
"pattern": "/\\[\\] Tested/"
|
||||
}
|
||||
]
|
||||
},
|
||||
"breaking_change": {
|
||||
"requires": 1,
|
||||
"conditions": [
|
||||
{
|
||||
"type": "descriptionMatches",
|
||||
"pattern": "/\\[\\] Non breaking API changes/"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
22
.github/workflows/invalid_question.yml
vendored
Normal file
22
.github/workflows/invalid_question.yml
vendored
Normal file
@ -0,0 +1,22 @@
|
||||
name: "Close invalid questions issues"
|
||||
on:
|
||||
schedule:
|
||||
- cron: "*/10 * * * *"
|
||||
|
||||
jobs:
|
||||
stale:
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
ACTIONS_STEP_DEBUG: true
|
||||
steps:
|
||||
- name: Close Stale Issues
|
||||
uses: actions/stale@v3.0.7
|
||||
with:
|
||||
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
stale-issue-message: "This issue has been marked as invalid question, please give more information by following the `Question` template, if you believe there is a bug of GORM, please create a pull request that could reproduce the issue on [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground), the issue will be closed in 2 days if no further activity occurs. most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨"
|
||||
stale-issue-label: "status:stale"
|
||||
days-before-stale: 0
|
||||
days-before-close: 2
|
||||
remove-stale-when-updated: true
|
||||
only-labels: "type:invalid question"
|
||||
|
||||
19
.github/workflows/labeler.yml
vendored
Normal file
19
.github/workflows/labeler.yml
vendored
Normal file
@ -0,0 +1,19 @@
|
||||
name: "Issue Labeler"
|
||||
on:
|
||||
issues:
|
||||
types: [opened, edited, reopened]
|
||||
pull_request:
|
||||
types: [opened, edited, reopened]
|
||||
|
||||
jobs:
|
||||
triage:
|
||||
runs-on: ubuntu-latest
|
||||
name: Label issues and pull requests
|
||||
steps:
|
||||
- name: check out
|
||||
uses: actions/checkout@v2
|
||||
|
||||
- name: labeler
|
||||
uses: jinzhu/super-labeler-action@develop
|
||||
with:
|
||||
GITHUB_TOKEN: "${{ secrets.GITHUB_TOKEN }}"
|
||||
21
.github/workflows/missing_playground.yml
vendored
Normal file
21
.github/workflows/missing_playground.yml
vendored
Normal file
@ -0,0 +1,21 @@
|
||||
name: "Close Missing Playground issues"
|
||||
on:
|
||||
schedule:
|
||||
- cron: "*/10 * * * *"
|
||||
|
||||
jobs:
|
||||
stale:
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
ACTIONS_STEP_DEBUG: true
|
||||
steps:
|
||||
- name: Close Stale Issues
|
||||
uses: actions/stale@v3.0.7
|
||||
with:
|
||||
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
stale-issue-message: "The issue has been automatically marked as stale as it missing playground pull request link, which is important to help others understand your issue effectively and make sure the issue hasn't been fixed on latest master, checkout [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground) for details. it will be closed in 2 days if no further activity occurs. if you are asking question, please use the `Question` template, most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨"
|
||||
stale-issue-label: "status:stale"
|
||||
days-before-stale: 0
|
||||
days-before-close: 2
|
||||
remove-stale-when-updated: true
|
||||
only-labels: "type:missing reproduction steps"
|
||||
11
.github/workflows/reviewdog.yml
vendored
Normal file
11
.github/workflows/reviewdog.yml
vendored
Normal file
@ -0,0 +1,11 @@
|
||||
name: reviewdog
|
||||
on: [pull_request]
|
||||
jobs:
|
||||
golangci-lint:
|
||||
name: runner / golangci-lint
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Check out code into the Go module directory
|
||||
uses: actions/checkout@v1
|
||||
- name: golangci-lint
|
||||
uses: reviewdog/action-golangci-lint@v1
|
||||
22
.github/workflows/stale.yml
vendored
Normal file
22
.github/workflows/stale.yml
vendored
Normal file
@ -0,0 +1,22 @@
|
||||
name: "Stale"
|
||||
on:
|
||||
schedule:
|
||||
- cron: "0 2 * * *"
|
||||
|
||||
jobs:
|
||||
stale:
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
ACTIONS_STEP_DEBUG: true
|
||||
steps:
|
||||
- name: Close Stale Issues
|
||||
uses: actions/stale@v3.0.7
|
||||
with:
|
||||
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
stale-issue-message: "This issue has been automatically marked as stale because it has been open 60 days with no activity. Remove stale label or comment or this will be closed in 30 days"
|
||||
days-before-stale: 60
|
||||
days-before-close: 30
|
||||
stale-issue-label: "status:stale"
|
||||
exempt-issue-labels: 'type:feature,type:with reproduction steps,type:has pull request'
|
||||
stale-pr-label: 'status:stale'
|
||||
exempt-pr-labels: 'type:feature,type:with reproduction steps,type:has pull request'
|
||||
166
.github/workflows/tests.yml
vendored
Normal file
166
.github/workflows/tests.yml
vendored
Normal file
@ -0,0 +1,166 @@
|
||||
name: tests
|
||||
|
||||
on:
|
||||
push:
|
||||
branches-ignore:
|
||||
- 'gh-pages'
|
||||
pull_request:
|
||||
branches-ignore:
|
||||
- 'gh-pages'
|
||||
|
||||
jobs:
|
||||
# Label of the container job
|
||||
sqlite:
|
||||
strategy:
|
||||
matrix:
|
||||
go: ['1.16', '1.15', '1.14']
|
||||
platform: [ubuntu-latest] # can not run in windows OS
|
||||
runs-on: ${{ matrix.platform }}
|
||||
|
||||
steps:
|
||||
- name: Set up Go 1.x
|
||||
uses: actions/setup-go@v2
|
||||
with:
|
||||
go-version: ${{ matrix.go }}
|
||||
|
||||
- name: Check out code into the Go module directory
|
||||
uses: actions/checkout@v2
|
||||
|
||||
- name: go mod package cache
|
||||
uses: actions/cache@v2
|
||||
with:
|
||||
path: ~/go/pkg/mod
|
||||
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
|
||||
|
||||
- name: Tests
|
||||
run: GORM_DIALECT=sqlite ./tests/tests_all.sh
|
||||
|
||||
mysql:
|
||||
strategy:
|
||||
matrix:
|
||||
dbversion: ['mysql:latest', 'mysql:5.7', 'mysql:5.6', 'mariadb:latest']
|
||||
go: ['1.16', '1.15', '1.14']
|
||||
platform: [ubuntu-latest]
|
||||
runs-on: ${{ matrix.platform }}
|
||||
|
||||
services:
|
||||
mysql:
|
||||
image: ${{ matrix.dbversion }}
|
||||
env:
|
||||
MYSQL_DATABASE: gorm
|
||||
MYSQL_USER: gorm
|
||||
MYSQL_PASSWORD: gorm
|
||||
MYSQL_RANDOM_ROOT_PASSWORD: "yes"
|
||||
ports:
|
||||
- 9910:3306
|
||||
options: >-
|
||||
--health-cmd "mysqladmin ping -ugorm -pgorm"
|
||||
--health-interval 10s
|
||||
--health-start-period 10s
|
||||
--health-timeout 5s
|
||||
--health-retries 10
|
||||
|
||||
steps:
|
||||
- name: Set up Go 1.x
|
||||
uses: actions/setup-go@v2
|
||||
with:
|
||||
go-version: ${{ matrix.go }}
|
||||
|
||||
- name: Check out code into the Go module directory
|
||||
uses: actions/checkout@v2
|
||||
|
||||
|
||||
- name: go mod package cache
|
||||
uses: actions/cache@v2
|
||||
with:
|
||||
path: ~/go/pkg/mod
|
||||
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
|
||||
|
||||
- name: Tests
|
||||
run: GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True" ./tests/tests_all.sh
|
||||
|
||||
postgres:
|
||||
strategy:
|
||||
matrix:
|
||||
dbversion: ['postgres:latest', 'postgres:11', 'postgres:10']
|
||||
go: ['1.16', '1.15', '1.14']
|
||||
platform: [ubuntu-latest] # can not run in macOS and Windows
|
||||
runs-on: ${{ matrix.platform }}
|
||||
|
||||
services:
|
||||
postgres:
|
||||
image: ${{ matrix.dbversion }}
|
||||
env:
|
||||
POSTGRES_PASSWORD: gorm
|
||||
POSTGRES_USER: gorm
|
||||
POSTGRES_DB: gorm
|
||||
TZ: Asia/Shanghai
|
||||
ports:
|
||||
- 9920:5432
|
||||
# Set health checks to wait until postgres has started
|
||||
options: >-
|
||||
--health-cmd pg_isready
|
||||
--health-interval 10s
|
||||
--health-timeout 5s
|
||||
--health-retries 5
|
||||
|
||||
steps:
|
||||
- name: Set up Go 1.x
|
||||
uses: actions/setup-go@v2
|
||||
with:
|
||||
go-version: ${{ matrix.go }}
|
||||
|
||||
- name: Check out code into the Go module directory
|
||||
uses: actions/checkout@v2
|
||||
|
||||
- name: go mod package cache
|
||||
uses: actions/cache@v2
|
||||
with:
|
||||
path: ~/go/pkg/mod
|
||||
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
|
||||
|
||||
- name: Tests
|
||||
run: GORM_DIALECT=postgres GORM_DSN="user=gorm password=gorm dbname=gorm host=localhost port=9920 sslmode=disable TimeZone=Asia/Shanghai" ./tests/tests_all.sh
|
||||
|
||||
sqlserver:
|
||||
strategy:
|
||||
matrix:
|
||||
go: ['1.16', '1.15', '1.14']
|
||||
platform: [ubuntu-latest] # can not run test in macOS and windows
|
||||
runs-on: ${{ matrix.platform }}
|
||||
|
||||
services:
|
||||
mssql:
|
||||
image: mcmoe/mssqldocker:latest
|
||||
env:
|
||||
ACCEPT_EULA: Y
|
||||
SA_PASSWORD: LoremIpsum86
|
||||
MSSQL_DB: gorm
|
||||
MSSQL_USER: gorm
|
||||
MSSQL_PASSWORD: LoremIpsum86
|
||||
ports:
|
||||
- 9930:1433
|
||||
options: >-
|
||||
--health-cmd="/opt/mssql-tools/bin/sqlcmd -S localhost -U sa -P LoremIpsum86 -l 30 -Q \"SELECT 1\" || exit 1"
|
||||
--health-start-period 10s
|
||||
--health-interval 10s
|
||||
--health-timeout 5s
|
||||
--health-retries 10
|
||||
|
||||
steps:
|
||||
- name: Set up Go 1.x
|
||||
uses: actions/setup-go@v2
|
||||
with:
|
||||
go-version: ${{ matrix.go }}
|
||||
|
||||
- name: Check out code into the Go module directory
|
||||
uses: actions/checkout@v2
|
||||
|
||||
- name: go mod package cache
|
||||
uses: actions/cache@v2
|
||||
with:
|
||||
path: ~/go/pkg/mod
|
||||
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
|
||||
|
||||
- name: Tests
|
||||
run: GORM_DIALECT=sqlserver GORM_DSN="sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm" ./tests/tests_all.sh
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@ -1,2 +1,5 @@
|
||||
TODO*
|
||||
documents
|
||||
coverage.txt
|
||||
_book
|
||||
.idea
|
||||
|
||||
52
README.md
52
README.md
@ -2,49 +2,41 @@
|
||||
|
||||
The fantastic ORM library for Golang, aims to be developer friendly.
|
||||
|
||||
[](https://gitter.im/jinzhu/gorm?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)
|
||||
[](https://app.wercker.com/project/bykey/0cb7bb1039e21b74f8274941428e0921)
|
||||
[](https://godoc.org/github.com/jinzhu/gorm)
|
||||
[](https://goreportcard.com/report/github.com/go-gorm/gorm)
|
||||
[](https://github.com/go-gorm/gorm/actions)
|
||||
[](https://gitter.im/jinzhu/gorm?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)
|
||||
[](https://opencollective.com/gorm)
|
||||
[](https://opencollective.com/gorm)
|
||||
[](https://opensource.org/licenses/MIT)
|
||||
[](https://pkg.go.dev/gorm.io/gorm?tab=doc)
|
||||
|
||||
## Overview
|
||||
|
||||
* Full-Featured ORM (almost)
|
||||
* Associations (Has One, Has Many, Belongs To, Many To Many, Polymorphism)
|
||||
* Callbacks (Before/After Create/Save/Update/Delete/Find)
|
||||
* Preloading (eager loading)
|
||||
* Transactions
|
||||
* Full-Featured ORM
|
||||
* Associations (Has One, Has Many, Belongs To, Many To Many, Polymorphism, Single-table inheritance)
|
||||
* Hooks (Before/After Create/Save/Update/Delete/Find)
|
||||
* Eager loading with `Preload`, `Joins`
|
||||
* Transactions, Nested Transactions, Save Point, RollbackTo to Saved Point
|
||||
* Context, Prepared Statement Mode, DryRun Mode
|
||||
* Batch Insert, FindInBatches, Find To Map
|
||||
* SQL Builder, Upsert, Locking, Optimizer/Index/Comment Hints, NamedArg, Search/Update/Create with SQL Expr
|
||||
* Composite Primary Key
|
||||
* SQL Builder
|
||||
* Auto Migrations
|
||||
* Logger
|
||||
* Extendable, write Plugins based on GORM callbacks
|
||||
* Extendable, flexible plugin API: Database Resolver (Multiple Databases, Read/Write Splitting) / Prometheus…
|
||||
* Every feature comes with tests
|
||||
* Developer Friendly
|
||||
|
||||
## Getting Started
|
||||
|
||||
* GORM Guides [jinzhu.github.com/gorm](http://jinzhu.github.io/gorm)
|
||||
* GORM Guides [https://gorm.io](https://gorm.io)
|
||||
|
||||
## Upgrading To V1.0
|
||||
## Contributing
|
||||
|
||||
* [CHANGELOG](http://jinzhu.github.io/gorm/changelog.html)
|
||||
|
||||
## Supporting the project
|
||||
|
||||
[](http://patreon.com/jinzhu)
|
||||
|
||||
## Author
|
||||
|
||||
**jinzhu**
|
||||
|
||||
* <http://github.com/jinzhu>
|
||||
* <wosmvp@gmail.com>
|
||||
* <http://twitter.com/zhangjinzhu>
|
||||
|
||||
## Contributors
|
||||
|
||||
https://github.com/jinzhu/gorm/graphs/contributors
|
||||
[You can help to deliver a better GORM, check out things you can do](https://gorm.io/contribute.html)
|
||||
|
||||
## License
|
||||
|
||||
Released under the [MIT License](https://github.com/jinzhu/gorm/blob/master/License).
|
||||
© Jinzhu, 2013~time.Now
|
||||
|
||||
Released under the [MIT License](https://github.com/go-gorm/gorm/blob/master/License)
|
||||
|
||||
790
association.go
790
association.go
@ -1,375 +1,513 @@
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/schema"
|
||||
"gorm.io/gorm/utils"
|
||||
)
|
||||
|
||||
// Association Mode contains some helper methods to handle relationship things easily.
|
||||
type Association struct {
|
||||
DB *DB
|
||||
Relationship *schema.Relationship
|
||||
Error error
|
||||
scope *Scope
|
||||
column string
|
||||
field *Field
|
||||
}
|
||||
|
||||
// Find find out all related associations
|
||||
func (association *Association) Find(value interface{}) *Association {
|
||||
association.scope.related(value, association.column)
|
||||
return association.setErr(association.scope.db.Error)
|
||||
}
|
||||
func (db *DB) Association(column string) *Association {
|
||||
association := &Association{DB: db}
|
||||
table := db.Statement.Table
|
||||
|
||||
// Append append new associations for many2many, has_many, replace current association for has_one, belongs_to
|
||||
func (association *Association) Append(values ...interface{}) *Association {
|
||||
if association.Error != nil {
|
||||
return association
|
||||
if err := db.Statement.Parse(db.Statement.Model); err == nil {
|
||||
db.Statement.Table = table
|
||||
association.Relationship = db.Statement.Schema.Relationships.Relations[column]
|
||||
|
||||
if association.Relationship == nil {
|
||||
association.Error = fmt.Errorf("%w: %v", ErrUnsupportedRelation, column)
|
||||
}
|
||||
|
||||
if relationship := association.field.Relationship; relationship.Kind == "has_one" {
|
||||
return association.Replace(values...)
|
||||
}
|
||||
return association.saveAssociations(values...)
|
||||
}
|
||||
|
||||
// Replace replace current associations with new one
|
||||
func (association *Association) Replace(values ...interface{}) *Association {
|
||||
if association.Error != nil {
|
||||
return association
|
||||
}
|
||||
|
||||
var (
|
||||
relationship = association.field.Relationship
|
||||
scope = association.scope
|
||||
field = association.field.Field
|
||||
newDB = scope.NewDB()
|
||||
)
|
||||
|
||||
// Append new values
|
||||
association.field.Set(reflect.Zero(association.field.Field.Type()))
|
||||
association.saveAssociations(values...)
|
||||
|
||||
// Belongs To
|
||||
if relationship.Kind == "belongs_to" {
|
||||
// Set foreign key to be null when clearing value (length equals 0)
|
||||
if len(values) == 0 {
|
||||
// Set foreign key to be nil
|
||||
var foreignKeyMap = map[string]interface{}{}
|
||||
for _, foreignKey := range relationship.ForeignDBNames {
|
||||
foreignKeyMap[foreignKey] = nil
|
||||
}
|
||||
association.setErr(newDB.Model(scope.Value).UpdateColumn(foreignKeyMap).Error)
|
||||
db.Statement.ReflectValue = reflect.ValueOf(db.Statement.Model)
|
||||
for db.Statement.ReflectValue.Kind() == reflect.Ptr {
|
||||
db.Statement.ReflectValue = db.Statement.ReflectValue.Elem()
|
||||
}
|
||||
} else {
|
||||
// Polymorphic Relations
|
||||
if relationship.PolymorphicDBName != "" {
|
||||
newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), relationship.PolymorphicValue)
|
||||
association.Error = err
|
||||
}
|
||||
|
||||
// Delete Relations except new created
|
||||
if len(values) > 0 {
|
||||
var associationForeignFieldNames, associationForeignDBNames []string
|
||||
if relationship.Kind == "many_to_many" {
|
||||
// if many to many relations, get association fields name from association foreign keys
|
||||
associationScope := scope.New(reflect.New(field.Type()).Interface())
|
||||
for idx, dbName := range relationship.AssociationForeignFieldNames {
|
||||
if field, ok := associationScope.FieldByName(dbName); ok {
|
||||
associationForeignFieldNames = append(associationForeignFieldNames, field.Name)
|
||||
associationForeignDBNames = append(associationForeignDBNames, relationship.AssociationForeignDBNames[idx])
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// If has one/many relations, use primary keys
|
||||
for _, field := range scope.New(reflect.New(field.Type()).Interface()).PrimaryFields() {
|
||||
associationForeignFieldNames = append(associationForeignFieldNames, field.Name)
|
||||
associationForeignDBNames = append(associationForeignDBNames, field.DBName)
|
||||
}
|
||||
}
|
||||
|
||||
newPrimaryKeys := scope.getColumnAsArray(associationForeignFieldNames, field.Interface())
|
||||
|
||||
if len(newPrimaryKeys) > 0 {
|
||||
sql := fmt.Sprintf("%v NOT IN (%v)", toQueryCondition(scope, associationForeignDBNames), toQueryMarks(newPrimaryKeys))
|
||||
newDB = newDB.Where(sql, toQueryValues(newPrimaryKeys)...)
|
||||
}
|
||||
}
|
||||
|
||||
if relationship.Kind == "many_to_many" {
|
||||
// if many to many relations, delete related relations from join table
|
||||
var sourceForeignFieldNames []string
|
||||
|
||||
for _, dbName := range relationship.ForeignFieldNames {
|
||||
if field, ok := scope.FieldByName(dbName); ok {
|
||||
sourceForeignFieldNames = append(sourceForeignFieldNames, field.Name)
|
||||
}
|
||||
}
|
||||
|
||||
if sourcePrimaryKeys := scope.getColumnAsArray(sourceForeignFieldNames, scope.Value); len(sourcePrimaryKeys) > 0 {
|
||||
newDB = newDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(sourcePrimaryKeys)), toQueryValues(sourcePrimaryKeys)...)
|
||||
|
||||
association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB, relationship))
|
||||
}
|
||||
} else if relationship.Kind == "has_one" || relationship.Kind == "has_many" {
|
||||
// has_one or has_many relations, set foreign key to be nil (TODO or delete them?)
|
||||
var foreignKeyMap = map[string]interface{}{}
|
||||
for idx, foreignKey := range relationship.ForeignDBNames {
|
||||
foreignKeyMap[foreignKey] = nil
|
||||
if field, ok := scope.FieldByName(relationship.AssociationForeignFieldNames[idx]); ok {
|
||||
newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface())
|
||||
}
|
||||
}
|
||||
|
||||
fieldValue := reflect.New(association.field.Field.Type()).Interface()
|
||||
association.setErr(newDB.Model(fieldValue).UpdateColumn(foreignKeyMap).Error)
|
||||
}
|
||||
}
|
||||
return association
|
||||
}
|
||||
|
||||
// Delete remove relationship between source & passed arguments, but won't delete those arguments
|
||||
func (association *Association) Delete(values ...interface{}) *Association {
|
||||
if association.Error != nil {
|
||||
return association
|
||||
}
|
||||
|
||||
var (
|
||||
relationship = association.field.Relationship
|
||||
scope = association.scope
|
||||
field = association.field.Field
|
||||
newDB = scope.NewDB()
|
||||
)
|
||||
|
||||
if len(values) == 0 {
|
||||
return association
|
||||
}
|
||||
|
||||
var deletingResourcePrimaryFieldNames, deletingResourcePrimaryDBNames []string
|
||||
for _, field := range scope.New(reflect.New(field.Type()).Interface()).PrimaryFields() {
|
||||
deletingResourcePrimaryFieldNames = append(deletingResourcePrimaryFieldNames, field.Name)
|
||||
deletingResourcePrimaryDBNames = append(deletingResourcePrimaryDBNames, field.DBName)
|
||||
}
|
||||
|
||||
deletingPrimaryKeys := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, values...)
|
||||
|
||||
if relationship.Kind == "many_to_many" {
|
||||
// source value's foreign keys
|
||||
for idx, foreignKey := range relationship.ForeignDBNames {
|
||||
if field, ok := scope.FieldByName(relationship.ForeignFieldNames[idx]); ok {
|
||||
newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface())
|
||||
}
|
||||
}
|
||||
|
||||
// get association's foreign fields name
|
||||
var associationScope = scope.New(reflect.New(field.Type()).Interface())
|
||||
var associationForeignFieldNames []string
|
||||
for _, associationDBName := range relationship.AssociationForeignFieldNames {
|
||||
if field, ok := associationScope.FieldByName(associationDBName); ok {
|
||||
associationForeignFieldNames = append(associationForeignFieldNames, field.Name)
|
||||
}
|
||||
}
|
||||
|
||||
// association value's foreign keys
|
||||
deletingPrimaryKeys := scope.getColumnAsArray(associationForeignFieldNames, values...)
|
||||
sql := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(deletingPrimaryKeys))
|
||||
newDB = newDB.Where(sql, toQueryValues(deletingPrimaryKeys)...)
|
||||
|
||||
association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB, relationship))
|
||||
} else {
|
||||
var foreignKeyMap = map[string]interface{}{}
|
||||
for _, foreignKey := range relationship.ForeignDBNames {
|
||||
foreignKeyMap[foreignKey] = nil
|
||||
}
|
||||
|
||||
if relationship.Kind == "belongs_to" {
|
||||
// find with deleting relation's foreign keys
|
||||
primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, values...)
|
||||
newDB = newDB.Where(
|
||||
fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)),
|
||||
toQueryValues(primaryKeys)...,
|
||||
)
|
||||
|
||||
// set foreign key to be null if there are some records affected
|
||||
modelValue := reflect.New(scope.GetModelStruct().ModelType).Interface()
|
||||
if results := newDB.Model(modelValue).UpdateColumn(foreignKeyMap); results.Error == nil {
|
||||
if results.RowsAffected > 0 {
|
||||
scope.updatedAttrsWithValues(foreignKeyMap)
|
||||
}
|
||||
} else {
|
||||
association.setErr(results.Error)
|
||||
}
|
||||
} else if relationship.Kind == "has_one" || relationship.Kind == "has_many" {
|
||||
// find all relations
|
||||
primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, scope.Value)
|
||||
newDB = newDB.Where(
|
||||
fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)),
|
||||
toQueryValues(primaryKeys)...,
|
||||
)
|
||||
|
||||
// only include those deleting relations
|
||||
newDB = newDB.Where(
|
||||
fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, deletingResourcePrimaryDBNames), toQueryMarks(deletingPrimaryKeys)),
|
||||
toQueryValues(deletingPrimaryKeys)...,
|
||||
)
|
||||
|
||||
// set matched relation's foreign key to be null
|
||||
fieldValue := reflect.New(association.field.Field.Type()).Interface()
|
||||
association.setErr(newDB.Model(fieldValue).UpdateColumn(foreignKeyMap).Error)
|
||||
}
|
||||
}
|
||||
|
||||
// Remove deleted records from source's field
|
||||
func (association *Association) Find(out interface{}, conds ...interface{}) error {
|
||||
if association.Error == nil {
|
||||
if field.Kind() == reflect.Slice {
|
||||
leftValues := reflect.Zero(field.Type())
|
||||
|
||||
for i := 0; i < field.Len(); i++ {
|
||||
reflectValue := field.Index(i)
|
||||
primaryKey := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, reflectValue.Interface())[0]
|
||||
var isDeleted = false
|
||||
for _, pk := range deletingPrimaryKeys {
|
||||
if equalAsString(primaryKey, pk) {
|
||||
isDeleted = true
|
||||
break
|
||||
association.Error = association.buildCondition().Find(out, conds...).Error
|
||||
}
|
||||
}
|
||||
if !isDeleted {
|
||||
leftValues = reflect.Append(leftValues, reflectValue)
|
||||
}
|
||||
}
|
||||
|
||||
association.field.Set(leftValues)
|
||||
} else if field.Kind() == reflect.Struct {
|
||||
primaryKey := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, field.Interface())[0]
|
||||
for _, pk := range deletingPrimaryKeys {
|
||||
if equalAsString(primaryKey, pk) {
|
||||
association.field.Set(reflect.Zero(field.Type()))
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return association
|
||||
return association.Error
|
||||
}
|
||||
|
||||
// Clear remove relationship between source & current associations, won't delete those associations
|
||||
func (association *Association) Clear() *Association {
|
||||
func (association *Association) Append(values ...interface{}) error {
|
||||
if association.Error == nil {
|
||||
switch association.Relationship.Type {
|
||||
case schema.HasOne, schema.BelongsTo:
|
||||
if len(values) > 0 {
|
||||
association.Error = association.Replace(values...)
|
||||
}
|
||||
default:
|
||||
association.saveAssociation( /*clear*/ false, values...)
|
||||
}
|
||||
}
|
||||
|
||||
return association.Error
|
||||
}
|
||||
|
||||
func (association *Association) Replace(values ...interface{}) error {
|
||||
if association.Error == nil {
|
||||
// save associations
|
||||
if association.saveAssociation( /*clear*/ true, values...); association.Error != nil {
|
||||
return association.Error
|
||||
}
|
||||
|
||||
// set old associations's foreign key to null
|
||||
reflectValue := association.DB.Statement.ReflectValue
|
||||
rel := association.Relationship
|
||||
switch rel.Type {
|
||||
case schema.BelongsTo:
|
||||
if len(values) == 0 {
|
||||
updateMap := map[string]interface{}{}
|
||||
switch reflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
for i := 0; i < reflectValue.Len(); i++ {
|
||||
association.Error = rel.Field.Set(reflectValue.Index(i), reflect.Zero(rel.Field.FieldType).Interface())
|
||||
}
|
||||
case reflect.Struct:
|
||||
association.Error = rel.Field.Set(reflectValue, reflect.Zero(rel.Field.FieldType).Interface())
|
||||
}
|
||||
|
||||
for _, ref := range rel.References {
|
||||
updateMap[ref.ForeignKey.DBName] = nil
|
||||
}
|
||||
|
||||
association.Error = association.DB.UpdateColumns(updateMap).Error
|
||||
}
|
||||
case schema.HasOne, schema.HasMany:
|
||||
var (
|
||||
primaryFields []*schema.Field
|
||||
foreignKeys []string
|
||||
updateMap = map[string]interface{}{}
|
||||
relValues = schema.GetRelationsValues(reflectValue, []*schema.Relationship{rel})
|
||||
modelValue = reflect.New(rel.FieldSchema.ModelType).Interface()
|
||||
tx = association.DB.Model(modelValue)
|
||||
)
|
||||
|
||||
if _, rvs := schema.GetIdentityFieldValuesMap(relValues, rel.FieldSchema.PrimaryFields); len(rvs) > 0 {
|
||||
if column, values := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs); len(values) > 0 {
|
||||
tx.Not(clause.IN{Column: column, Values: values})
|
||||
}
|
||||
}
|
||||
|
||||
for _, ref := range rel.References {
|
||||
if ref.OwnPrimaryKey {
|
||||
primaryFields = append(primaryFields, ref.PrimaryKey)
|
||||
foreignKeys = append(foreignKeys, ref.ForeignKey.DBName)
|
||||
updateMap[ref.ForeignKey.DBName] = nil
|
||||
} else if ref.PrimaryValue != "" {
|
||||
tx.Where(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue})
|
||||
}
|
||||
}
|
||||
|
||||
if _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields); len(pvs) > 0 {
|
||||
column, values := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs)
|
||||
association.Error = tx.Where(clause.IN{Column: column, Values: values}).UpdateColumns(updateMap).Error
|
||||
}
|
||||
case schema.Many2Many:
|
||||
var (
|
||||
primaryFields, relPrimaryFields []*schema.Field
|
||||
joinPrimaryKeys, joinRelPrimaryKeys []string
|
||||
modelValue = reflect.New(rel.JoinTable.ModelType).Interface()
|
||||
tx = association.DB.Model(modelValue)
|
||||
)
|
||||
|
||||
for _, ref := range rel.References {
|
||||
if ref.PrimaryValue == "" {
|
||||
if ref.OwnPrimaryKey {
|
||||
primaryFields = append(primaryFields, ref.PrimaryKey)
|
||||
joinPrimaryKeys = append(joinPrimaryKeys, ref.ForeignKey.DBName)
|
||||
} else {
|
||||
relPrimaryFields = append(relPrimaryFields, ref.PrimaryKey)
|
||||
joinRelPrimaryKeys = append(joinRelPrimaryKeys, ref.ForeignKey.DBName)
|
||||
}
|
||||
} else {
|
||||
tx.Clauses(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue})
|
||||
}
|
||||
}
|
||||
|
||||
_, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields)
|
||||
if column, values := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs); len(values) > 0 {
|
||||
tx.Where(clause.IN{Column: column, Values: values})
|
||||
} else {
|
||||
return ErrPrimaryKeyRequired
|
||||
}
|
||||
|
||||
_, rvs := schema.GetIdentityFieldValuesMapFromValues(values, relPrimaryFields)
|
||||
if relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs); len(relValues) > 0 {
|
||||
tx.Where(clause.Not(clause.IN{Column: relColumn, Values: relValues}))
|
||||
}
|
||||
|
||||
association.Error = tx.Delete(modelValue).Error
|
||||
}
|
||||
}
|
||||
return association.Error
|
||||
}
|
||||
|
||||
func (association *Association) Delete(values ...interface{}) error {
|
||||
if association.Error == nil {
|
||||
var (
|
||||
reflectValue = association.DB.Statement.ReflectValue
|
||||
rel = association.Relationship
|
||||
primaryFields []*schema.Field
|
||||
foreignKeys []string
|
||||
updateAttrs = map[string]interface{}{}
|
||||
conds []clause.Expression
|
||||
)
|
||||
|
||||
for _, ref := range rel.References {
|
||||
if ref.PrimaryValue == "" {
|
||||
primaryFields = append(primaryFields, ref.PrimaryKey)
|
||||
foreignKeys = append(foreignKeys, ref.ForeignKey.DBName)
|
||||
updateAttrs[ref.ForeignKey.DBName] = nil
|
||||
} else {
|
||||
conds = append(conds, clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue})
|
||||
}
|
||||
}
|
||||
|
||||
switch rel.Type {
|
||||
case schema.BelongsTo:
|
||||
tx := association.DB.Model(reflect.New(rel.Schema.ModelType).Interface())
|
||||
|
||||
_, pvs := schema.GetIdentityFieldValuesMap(reflectValue, rel.Schema.PrimaryFields)
|
||||
pcolumn, pvalues := schema.ToQueryValues(rel.Schema.Table, rel.Schema.PrimaryFieldDBNames, pvs)
|
||||
conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues})
|
||||
|
||||
_, rvs := schema.GetIdentityFieldValuesMapFromValues(values, primaryFields)
|
||||
relColumn, relValues := schema.ToQueryValues(rel.Schema.Table, foreignKeys, rvs)
|
||||
conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
|
||||
|
||||
association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error
|
||||
case schema.HasOne, schema.HasMany:
|
||||
tx := association.DB.Model(reflect.New(rel.FieldSchema.ModelType).Interface())
|
||||
|
||||
_, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields)
|
||||
pcolumn, pvalues := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs)
|
||||
conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues})
|
||||
|
||||
_, rvs := schema.GetIdentityFieldValuesMapFromValues(values, rel.FieldSchema.PrimaryFields)
|
||||
relColumn, relValues := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs)
|
||||
conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
|
||||
|
||||
association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error
|
||||
case schema.Many2Many:
|
||||
var (
|
||||
primaryFields, relPrimaryFields []*schema.Field
|
||||
joinPrimaryKeys, joinRelPrimaryKeys []string
|
||||
joinValue = reflect.New(rel.JoinTable.ModelType).Interface()
|
||||
)
|
||||
|
||||
for _, ref := range rel.References {
|
||||
if ref.PrimaryValue == "" {
|
||||
if ref.OwnPrimaryKey {
|
||||
primaryFields = append(primaryFields, ref.PrimaryKey)
|
||||
joinPrimaryKeys = append(joinPrimaryKeys, ref.ForeignKey.DBName)
|
||||
} else {
|
||||
relPrimaryFields = append(relPrimaryFields, ref.PrimaryKey)
|
||||
joinRelPrimaryKeys = append(joinRelPrimaryKeys, ref.ForeignKey.DBName)
|
||||
}
|
||||
} else {
|
||||
conds = append(conds, clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue})
|
||||
}
|
||||
}
|
||||
|
||||
_, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields)
|
||||
pcolumn, pvalues := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs)
|
||||
conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues})
|
||||
|
||||
_, rvs := schema.GetIdentityFieldValuesMapFromValues(values, relPrimaryFields)
|
||||
relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs)
|
||||
conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
|
||||
|
||||
association.Error = association.DB.Where(clause.Where{Exprs: conds}).Model(nil).Delete(joinValue).Error
|
||||
}
|
||||
|
||||
if association.Error == nil {
|
||||
// clean up deleted values's foreign key
|
||||
relValuesMap, _ := schema.GetIdentityFieldValuesMapFromValues(values, rel.FieldSchema.PrimaryFields)
|
||||
|
||||
cleanUpDeletedRelations := func(data reflect.Value) {
|
||||
if _, zero := rel.Field.ValueOf(data); !zero {
|
||||
fieldValue := reflect.Indirect(rel.Field.ReflectValueOf(data))
|
||||
primaryValues := make([]interface{}, len(rel.FieldSchema.PrimaryFields))
|
||||
|
||||
switch fieldValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
validFieldValues := reflect.Zero(rel.Field.IndirectFieldType)
|
||||
for i := 0; i < fieldValue.Len(); i++ {
|
||||
for idx, field := range rel.FieldSchema.PrimaryFields {
|
||||
primaryValues[idx], _ = field.ValueOf(fieldValue.Index(i))
|
||||
}
|
||||
|
||||
if _, ok := relValuesMap[utils.ToStringKey(primaryValues...)]; !ok {
|
||||
validFieldValues = reflect.Append(validFieldValues, fieldValue.Index(i))
|
||||
}
|
||||
}
|
||||
|
||||
association.Error = rel.Field.Set(data, validFieldValues.Interface())
|
||||
case reflect.Struct:
|
||||
for idx, field := range rel.FieldSchema.PrimaryFields {
|
||||
primaryValues[idx], _ = field.ValueOf(fieldValue)
|
||||
}
|
||||
|
||||
if _, ok := relValuesMap[utils.ToStringKey(primaryValues...)]; ok {
|
||||
if association.Error = rel.Field.Set(data, reflect.Zero(rel.FieldSchema.ModelType).Interface()); association.Error != nil {
|
||||
break
|
||||
}
|
||||
|
||||
if rel.JoinTable == nil {
|
||||
for _, ref := range rel.References {
|
||||
if ref.OwnPrimaryKey || ref.PrimaryValue != "" {
|
||||
association.Error = ref.ForeignKey.Set(fieldValue, reflect.Zero(ref.ForeignKey.FieldType).Interface())
|
||||
} else {
|
||||
association.Error = ref.ForeignKey.Set(data, reflect.Zero(ref.ForeignKey.FieldType).Interface())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
switch reflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
for i := 0; i < reflectValue.Len(); i++ {
|
||||
cleanUpDeletedRelations(reflect.Indirect(reflectValue.Index(i)))
|
||||
}
|
||||
case reflect.Struct:
|
||||
cleanUpDeletedRelations(reflectValue)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return association.Error
|
||||
}
|
||||
|
||||
func (association *Association) Clear() error {
|
||||
return association.Replace()
|
||||
}
|
||||
|
||||
// Count return the count of current associations
|
||||
func (association *Association) Count() int {
|
||||
func (association *Association) Count() (count int64) {
|
||||
if association.Error == nil {
|
||||
association.Error = association.buildCondition().Count(&count).Error
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
type assignBack struct {
|
||||
Source reflect.Value
|
||||
Index int
|
||||
Dest reflect.Value
|
||||
}
|
||||
|
||||
func (association *Association) saveAssociation(clear bool, values ...interface{}) {
|
||||
var (
|
||||
count = 0
|
||||
relationship = association.field.Relationship
|
||||
scope = association.scope
|
||||
fieldValue = association.field.Field.Interface()
|
||||
query = scope.DB()
|
||||
reflectValue = association.DB.Statement.ReflectValue
|
||||
assignBacks []assignBack // assign association values back to arguments after save
|
||||
)
|
||||
|
||||
if relationship.Kind == "many_to_many" {
|
||||
query = relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, query, scope.Value)
|
||||
} else if relationship.Kind == "has_many" || relationship.Kind == "has_one" {
|
||||
primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, scope.Value)
|
||||
query = query.Where(
|
||||
fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)),
|
||||
toQueryValues(primaryKeys)...,
|
||||
)
|
||||
} else if relationship.Kind == "belongs_to" {
|
||||
primaryKeys := scope.getColumnAsArray(relationship.ForeignFieldNames, scope.Value)
|
||||
query = query.Where(
|
||||
fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(primaryKeys)),
|
||||
toQueryValues(primaryKeys)...,
|
||||
)
|
||||
appendToRelations := func(source, rv reflect.Value, clear bool) {
|
||||
switch association.Relationship.Type {
|
||||
case schema.HasOne, schema.BelongsTo:
|
||||
switch rv.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
if rv.Len() > 0 {
|
||||
association.Error = association.Relationship.Field.Set(source, rv.Index(0).Addr().Interface())
|
||||
|
||||
if association.Relationship.Field.FieldType.Kind() == reflect.Struct {
|
||||
assignBacks = append(assignBacks, assignBack{Source: source, Dest: rv.Index(0)})
|
||||
}
|
||||
}
|
||||
case reflect.Struct:
|
||||
association.Error = association.Relationship.Field.Set(source, rv.Addr().Interface())
|
||||
|
||||
if association.Relationship.Field.FieldType.Kind() == reflect.Struct {
|
||||
assignBacks = append(assignBacks, assignBack{Source: source, Dest: rv})
|
||||
}
|
||||
}
|
||||
case schema.HasMany, schema.Many2Many:
|
||||
elemType := association.Relationship.Field.IndirectFieldType.Elem()
|
||||
fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(source))
|
||||
if clear {
|
||||
fieldValue = reflect.New(association.Relationship.Field.IndirectFieldType).Elem()
|
||||
}
|
||||
|
||||
if relationship.PolymorphicType != "" {
|
||||
query = query.Where(
|
||||
fmt.Sprintf("%v.%v = ?", scope.New(fieldValue).QuotedTableName(), scope.Quote(relationship.PolymorphicDBName)),
|
||||
relationship.PolymorphicValue,
|
||||
)
|
||||
appendToFieldValues := func(ev reflect.Value) {
|
||||
if ev.Type().AssignableTo(elemType) {
|
||||
fieldValue = reflect.Append(fieldValue, ev)
|
||||
} else if ev.Type().Elem().AssignableTo(elemType) {
|
||||
fieldValue = reflect.Append(fieldValue, ev.Elem())
|
||||
} else {
|
||||
association.Error = fmt.Errorf("unsupported data type: %v for relation %v", ev.Type(), association.Relationship.Name)
|
||||
}
|
||||
|
||||
if err := query.Model(fieldValue).Count(&count).Error; err != nil {
|
||||
if elemType.Kind() == reflect.Struct {
|
||||
assignBacks = append(assignBacks, assignBack{Source: source, Dest: ev, Index: fieldValue.Len()})
|
||||
}
|
||||
}
|
||||
|
||||
switch rv.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
for i := 0; i < rv.Len(); i++ {
|
||||
appendToFieldValues(reflect.Indirect(rv.Index(i)).Addr())
|
||||
}
|
||||
case reflect.Struct:
|
||||
appendToFieldValues(rv.Addr())
|
||||
}
|
||||
|
||||
if association.Error == nil {
|
||||
association.Error = association.Relationship.Field.Set(source, fieldValue.Interface())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
selectedSaveColumns := []string{association.Relationship.Name}
|
||||
omitColumns := []string{}
|
||||
selectColumns, _ := association.DB.Statement.SelectAndOmitColumns(true, false)
|
||||
for name, ok := range selectColumns {
|
||||
columnName := ""
|
||||
if strings.HasPrefix(name, association.Relationship.Name) {
|
||||
if columnName = strings.TrimPrefix(name, association.Relationship.Name); columnName == ".*" {
|
||||
columnName = name
|
||||
}
|
||||
} else if strings.HasPrefix(name, clause.Associations) {
|
||||
columnName = name
|
||||
}
|
||||
|
||||
if columnName != "" {
|
||||
if ok {
|
||||
selectedSaveColumns = append(selectedSaveColumns, columnName)
|
||||
} else {
|
||||
omitColumns = append(omitColumns, columnName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, ref := range association.Relationship.References {
|
||||
if !ref.OwnPrimaryKey {
|
||||
selectedSaveColumns = append(selectedSaveColumns, ref.ForeignKey.Name)
|
||||
}
|
||||
}
|
||||
|
||||
associationDB := association.DB.Session(&Session{}).Model(nil)
|
||||
if !association.DB.FullSaveAssociations {
|
||||
associationDB.Select(selectedSaveColumns)
|
||||
}
|
||||
if len(omitColumns) > 0 {
|
||||
associationDB.Omit(omitColumns...)
|
||||
}
|
||||
associationDB = associationDB.Session(&Session{})
|
||||
|
||||
switch reflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
if len(values) != reflectValue.Len() {
|
||||
// clear old data
|
||||
if clear && len(values) == 0 {
|
||||
for i := 0; i < reflectValue.Len(); i++ {
|
||||
if err := association.Relationship.Field.Set(reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType).Interface()); err != nil {
|
||||
association.Error = err
|
||||
break
|
||||
}
|
||||
|
||||
if association.Relationship.JoinTable == nil {
|
||||
for _, ref := range association.Relationship.References {
|
||||
if !ref.OwnPrimaryKey && ref.PrimaryValue == "" {
|
||||
if err := ref.ForeignKey.Set(reflectValue.Index(i), reflect.Zero(ref.ForeignKey.FieldType).Interface()); err != nil {
|
||||
association.Error = err
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
association.Error = ErrInvalidValueOfLength
|
||||
return
|
||||
}
|
||||
|
||||
for i := 0; i < reflectValue.Len(); i++ {
|
||||
appendToRelations(reflectValue.Index(i), reflect.Indirect(reflect.ValueOf(values[i])), clear)
|
||||
|
||||
// TODO support save slice data, sql with case?
|
||||
association.Error = associationDB.Updates(reflectValue.Index(i).Addr().Interface()).Error
|
||||
}
|
||||
case reflect.Struct:
|
||||
// clear old data
|
||||
if clear && len(values) == 0 {
|
||||
association.Error = association.Relationship.Field.Set(reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType).Interface())
|
||||
|
||||
if association.Relationship.JoinTable == nil && association.Error == nil {
|
||||
for _, ref := range association.Relationship.References {
|
||||
if !ref.OwnPrimaryKey && ref.PrimaryValue == "" {
|
||||
association.Error = ref.ForeignKey.Set(reflectValue, reflect.Zero(ref.ForeignKey.FieldType).Interface())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for idx, value := range values {
|
||||
rv := reflect.Indirect(reflect.ValueOf(value))
|
||||
appendToRelations(reflectValue, rv, clear && idx == 0)
|
||||
}
|
||||
|
||||
if len(values) > 0 {
|
||||
association.Error = associationDB.Updates(reflectValue.Addr().Interface()).Error
|
||||
}
|
||||
}
|
||||
|
||||
for _, assignBack := range assignBacks {
|
||||
fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(assignBack.Source))
|
||||
if assignBack.Index > 0 {
|
||||
reflect.Indirect(assignBack.Dest).Set(fieldValue.Index(assignBack.Index - 1))
|
||||
} else {
|
||||
reflect.Indirect(assignBack.Dest).Set(fieldValue)
|
||||
}
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
// saveAssociations save passed values as associations
|
||||
func (association *Association) saveAssociations(values ...interface{}) *Association {
|
||||
func (association *Association) buildCondition() *DB {
|
||||
var (
|
||||
scope = association.scope
|
||||
field = association.field
|
||||
relationship = field.Relationship
|
||||
queryConds = association.Relationship.ToQueryConditions(association.DB.Statement.ReflectValue)
|
||||
modelValue = reflect.New(association.Relationship.FieldSchema.ModelType).Interface()
|
||||
tx = association.DB.Model(modelValue)
|
||||
)
|
||||
|
||||
saveAssociation := func(reflectValue reflect.Value) {
|
||||
// value has to been pointer
|
||||
if reflectValue.Kind() != reflect.Ptr {
|
||||
reflectPtr := reflect.New(reflectValue.Type())
|
||||
reflectPtr.Elem().Set(reflectValue)
|
||||
reflectValue = reflectPtr
|
||||
if association.Relationship.JoinTable != nil {
|
||||
if !tx.Statement.Unscoped && len(association.Relationship.JoinTable.QueryClauses) > 0 {
|
||||
joinStmt := Statement{DB: tx, Schema: association.Relationship.JoinTable, Table: association.Relationship.JoinTable.Table, Clauses: map[string]clause.Clause{}}
|
||||
for _, queryClause := range association.Relationship.JoinTable.QueryClauses {
|
||||
joinStmt.AddClause(queryClause)
|
||||
}
|
||||
joinStmt.Build("WHERE")
|
||||
tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars})
|
||||
}
|
||||
|
||||
// value has to been saved for many2many
|
||||
if relationship.Kind == "many_to_many" {
|
||||
if scope.New(reflectValue.Interface()).PrimaryKeyZero() {
|
||||
association.setErr(scope.NewDB().Save(reflectValue.Interface()).Error)
|
||||
}
|
||||
}
|
||||
|
||||
// Assign Fields
|
||||
var fieldType = field.Field.Type()
|
||||
var setFieldBackToValue, setSliceFieldBackToValue bool
|
||||
if reflectValue.Type().AssignableTo(fieldType) {
|
||||
field.Set(reflectValue)
|
||||
} else if reflectValue.Type().Elem().AssignableTo(fieldType) {
|
||||
// if field's type is struct, then need to set value back to argument after save
|
||||
setFieldBackToValue = true
|
||||
field.Set(reflectValue.Elem())
|
||||
} else if fieldType.Kind() == reflect.Slice {
|
||||
if reflectValue.Type().AssignableTo(fieldType.Elem()) {
|
||||
field.Set(reflect.Append(field.Field, reflectValue))
|
||||
} else if reflectValue.Type().Elem().AssignableTo(fieldType.Elem()) {
|
||||
// if field's type is slice of struct, then need to set value back to argument after save
|
||||
setSliceFieldBackToValue = true
|
||||
field.Set(reflect.Append(field.Field, reflectValue.Elem()))
|
||||
}
|
||||
}
|
||||
|
||||
if relationship.Kind == "many_to_many" {
|
||||
association.setErr(relationship.JoinTableHandler.Add(relationship.JoinTableHandler, scope.NewDB(), scope.Value, reflectValue.Interface()))
|
||||
tx = tx.Session(&Session{QueryFields: true}).Clauses(clause.From{Joins: []clause.Join{{
|
||||
Table: clause.Table{Name: association.Relationship.JoinTable.Table},
|
||||
ON: clause.Where{Exprs: queryConds},
|
||||
}}})
|
||||
} else {
|
||||
association.setErr(scope.NewDB().Select(field.Name).Save(scope.Value).Error)
|
||||
|
||||
if setFieldBackToValue {
|
||||
reflectValue.Elem().Set(field.Field)
|
||||
} else if setSliceFieldBackToValue {
|
||||
reflectValue.Elem().Set(field.Field.Index(field.Field.Len() - 1))
|
||||
}
|
||||
}
|
||||
tx.Clauses(clause.Where{Exprs: queryConds})
|
||||
}
|
||||
|
||||
for _, value := range values {
|
||||
reflectValue := reflect.ValueOf(value)
|
||||
indirectReflectValue := reflect.Indirect(reflectValue)
|
||||
if indirectReflectValue.Kind() == reflect.Struct {
|
||||
saveAssociation(reflectValue)
|
||||
} else if indirectReflectValue.Kind() == reflect.Slice {
|
||||
for i := 0; i < indirectReflectValue.Len(); i++ {
|
||||
saveAssociation(indirectReflectValue.Index(i))
|
||||
}
|
||||
} else {
|
||||
association.setErr(errors.New("invalid value type"))
|
||||
}
|
||||
}
|
||||
return association
|
||||
}
|
||||
|
||||
func (association *Association) setErr(err error) *Association {
|
||||
if err != nil {
|
||||
association.Error = err
|
||||
}
|
||||
return association
|
||||
return tx
|
||||
}
|
||||
|
||||
@ -1,907 +0,0 @@
|
||||
package gorm_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"reflect"
|
||||
"sort"
|
||||
"testing"
|
||||
|
||||
"github.com/jinzhu/gorm"
|
||||
)
|
||||
|
||||
func TestBelongsTo(t *testing.T) {
|
||||
post := Post{
|
||||
Title: "post belongs to",
|
||||
Body: "body belongs to",
|
||||
Category: Category{Name: "Category 1"},
|
||||
MainCategory: Category{Name: "Main Category 1"},
|
||||
}
|
||||
|
||||
if err := DB.Save(&post).Error; err != nil {
|
||||
t.Error("Got errors when save post", err)
|
||||
}
|
||||
|
||||
if post.Category.ID == 0 || post.MainCategory.ID == 0 {
|
||||
t.Errorf("Category's primary key should be updated")
|
||||
}
|
||||
|
||||
if post.CategoryId.Int64 == 0 || post.MainCategoryId == 0 {
|
||||
t.Errorf("post's foreign key should be updated")
|
||||
}
|
||||
|
||||
// Query
|
||||
var category1 Category
|
||||
DB.Model(&post).Association("Category").Find(&category1)
|
||||
if category1.Name != "Category 1" {
|
||||
t.Errorf("Query belongs to relations with Association")
|
||||
}
|
||||
|
||||
var mainCategory1 Category
|
||||
DB.Model(&post).Association("MainCategory").Find(&mainCategory1)
|
||||
if mainCategory1.Name != "Main Category 1" {
|
||||
t.Errorf("Query belongs to relations with Association")
|
||||
}
|
||||
|
||||
var category11 Category
|
||||
DB.Model(&post).Related(&category11)
|
||||
if category11.Name != "Category 1" {
|
||||
t.Errorf("Query belongs to relations with Related")
|
||||
}
|
||||
|
||||
if DB.Model(&post).Association("Category").Count() != 1 {
|
||||
t.Errorf("Post's category count should be 1")
|
||||
}
|
||||
|
||||
if DB.Model(&post).Association("MainCategory").Count() != 1 {
|
||||
t.Errorf("Post's main category count should be 1")
|
||||
}
|
||||
|
||||
// Append
|
||||
var category2 = Category{
|
||||
Name: "Category 2",
|
||||
}
|
||||
DB.Model(&post).Association("Category").Append(&category2)
|
||||
|
||||
if category2.ID == 0 {
|
||||
t.Errorf("Category should has ID when created with Append")
|
||||
}
|
||||
|
||||
var category21 Category
|
||||
DB.Model(&post).Related(&category21)
|
||||
|
||||
if category21.Name != "Category 2" {
|
||||
t.Errorf("Category should be updated with Append")
|
||||
}
|
||||
|
||||
if DB.Model(&post).Association("Category").Count() != 1 {
|
||||
t.Errorf("Post's category count should be 1")
|
||||
}
|
||||
|
||||
// Replace
|
||||
var category3 = Category{
|
||||
Name: "Category 3",
|
||||
}
|
||||
DB.Model(&post).Association("Category").Replace(&category3)
|
||||
|
||||
if category3.ID == 0 {
|
||||
t.Errorf("Category should has ID when created with Replace")
|
||||
}
|
||||
|
||||
var category31 Category
|
||||
DB.Model(&post).Related(&category31)
|
||||
if category31.Name != "Category 3" {
|
||||
t.Errorf("Category should be updated with Replace")
|
||||
}
|
||||
|
||||
if DB.Model(&post).Association("Category").Count() != 1 {
|
||||
t.Errorf("Post's category count should be 1")
|
||||
}
|
||||
|
||||
// Delete
|
||||
DB.Model(&post).Association("Category").Delete(&category2)
|
||||
if DB.Model(&post).Related(&Category{}).RecordNotFound() {
|
||||
t.Errorf("Should not delete any category when Delete a unrelated Category")
|
||||
}
|
||||
|
||||
if post.Category.Name == "" {
|
||||
t.Errorf("Post's category should not be reseted when Delete a unrelated Category")
|
||||
}
|
||||
|
||||
DB.Model(&post).Association("Category").Delete(&category3)
|
||||
|
||||
if post.Category.Name != "" {
|
||||
t.Errorf("Post's category should be reseted after Delete")
|
||||
}
|
||||
|
||||
var category41 Category
|
||||
DB.Model(&post).Related(&category41)
|
||||
if category41.Name != "" {
|
||||
t.Errorf("Category should be deleted with Delete")
|
||||
}
|
||||
|
||||
if count := DB.Model(&post).Association("Category").Count(); count != 0 {
|
||||
t.Errorf("Post's category count should be 0 after Delete, but got %v", count)
|
||||
}
|
||||
|
||||
// Clear
|
||||
DB.Model(&post).Association("Category").Append(&Category{
|
||||
Name: "Category 2",
|
||||
})
|
||||
|
||||
if DB.Model(&post).Related(&Category{}).RecordNotFound() {
|
||||
t.Errorf("Should find category after append")
|
||||
}
|
||||
|
||||
if post.Category.Name == "" {
|
||||
t.Errorf("Post's category should has value after Append")
|
||||
}
|
||||
|
||||
DB.Model(&post).Association("Category").Clear()
|
||||
|
||||
if post.Category.Name != "" {
|
||||
t.Errorf("Post's category should be cleared after Clear")
|
||||
}
|
||||
|
||||
if !DB.Model(&post).Related(&Category{}).RecordNotFound() {
|
||||
t.Errorf("Should not find any category after Clear")
|
||||
}
|
||||
|
||||
if count := DB.Model(&post).Association("Category").Count(); count != 0 {
|
||||
t.Errorf("Post's category count should be 0 after Clear, but got %v", count)
|
||||
}
|
||||
|
||||
// Check Association mode with soft delete
|
||||
category6 := Category{
|
||||
Name: "Category 6",
|
||||
}
|
||||
DB.Model(&post).Association("Category").Append(&category6)
|
||||
|
||||
if count := DB.Model(&post).Association("Category").Count(); count != 1 {
|
||||
t.Errorf("Post's category count should be 1 after Append, but got %v", count)
|
||||
}
|
||||
|
||||
DB.Delete(&category6)
|
||||
|
||||
if count := DB.Model(&post).Association("Category").Count(); count != 0 {
|
||||
t.Errorf("Post's category count should be 0 after the category has been deleted, but got %v", count)
|
||||
}
|
||||
|
||||
if err := DB.Model(&post).Association("Category").Find(&Category{}).Error; err == nil {
|
||||
t.Errorf("Post's category is not findable after Delete")
|
||||
}
|
||||
|
||||
if count := DB.Unscoped().Model(&post).Association("Category").Count(); count != 1 {
|
||||
t.Errorf("Post's category count should be 1 when query with Unscoped, but got %v", count)
|
||||
}
|
||||
|
||||
if err := DB.Unscoped().Model(&post).Association("Category").Find(&Category{}).Error; err != nil {
|
||||
t.Errorf("Post's category should be findable when query with Unscoped, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBelongsToOverrideForeignKey1(t *testing.T) {
|
||||
type Profile struct {
|
||||
gorm.Model
|
||||
Name string
|
||||
}
|
||||
|
||||
type User struct {
|
||||
gorm.Model
|
||||
Profile Profile `gorm:"ForeignKey:ProfileRefer"`
|
||||
ProfileRefer int
|
||||
}
|
||||
|
||||
if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok {
|
||||
if relation.Relationship.Kind != "belongs_to" ||
|
||||
!reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"ProfileRefer"}) ||
|
||||
!reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"ID"}) {
|
||||
t.Errorf("Override belongs to foreign key with tag")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBelongsToOverrideForeignKey2(t *testing.T) {
|
||||
type Profile struct {
|
||||
gorm.Model
|
||||
Refer string
|
||||
Name string
|
||||
}
|
||||
|
||||
type User struct {
|
||||
gorm.Model
|
||||
Profile Profile `gorm:"ForeignKey:ProfileID;AssociationForeignKey:Refer"`
|
||||
ProfileID int
|
||||
}
|
||||
|
||||
if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok {
|
||||
if relation.Relationship.Kind != "belongs_to" ||
|
||||
!reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"ProfileID"}) ||
|
||||
!reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"Refer"}) {
|
||||
t.Errorf("Override belongs to foreign key with tag")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasOne(t *testing.T) {
|
||||
user := User{
|
||||
Name: "has one",
|
||||
CreditCard: CreditCard{Number: "411111111111"},
|
||||
}
|
||||
|
||||
if err := DB.Save(&user).Error; err != nil {
|
||||
t.Error("Got errors when save user", err.Error())
|
||||
}
|
||||
|
||||
if user.CreditCard.UserId.Int64 == 0 {
|
||||
t.Errorf("CreditCard's foreign key should be updated")
|
||||
}
|
||||
|
||||
// Query
|
||||
var creditCard1 CreditCard
|
||||
DB.Model(&user).Related(&creditCard1)
|
||||
|
||||
if creditCard1.Number != "411111111111" {
|
||||
t.Errorf("Query has one relations with Related")
|
||||
}
|
||||
|
||||
var creditCard11 CreditCard
|
||||
DB.Model(&user).Association("CreditCard").Find(&creditCard11)
|
||||
|
||||
if creditCard11.Number != "411111111111" {
|
||||
t.Errorf("Query has one relations with Related")
|
||||
}
|
||||
|
||||
if DB.Model(&user).Association("CreditCard").Count() != 1 {
|
||||
t.Errorf("User's credit card count should be 1")
|
||||
}
|
||||
|
||||
// Append
|
||||
var creditcard2 = CreditCard{
|
||||
Number: "411111111112",
|
||||
}
|
||||
DB.Model(&user).Association("CreditCard").Append(&creditcard2)
|
||||
|
||||
if creditcard2.ID == 0 {
|
||||
t.Errorf("Creditcard should has ID when created with Append")
|
||||
}
|
||||
|
||||
var creditcard21 CreditCard
|
||||
DB.Model(&user).Related(&creditcard21)
|
||||
if creditcard21.Number != "411111111112" {
|
||||
t.Errorf("CreditCard should be updated with Append")
|
||||
}
|
||||
|
||||
if DB.Model(&user).Association("CreditCard").Count() != 1 {
|
||||
t.Errorf("User's credit card count should be 1")
|
||||
}
|
||||
|
||||
// Replace
|
||||
var creditcard3 = CreditCard{
|
||||
Number: "411111111113",
|
||||
}
|
||||
DB.Model(&user).Association("CreditCard").Replace(&creditcard3)
|
||||
|
||||
if creditcard3.ID == 0 {
|
||||
t.Errorf("Creditcard should has ID when created with Replace")
|
||||
}
|
||||
|
||||
var creditcard31 CreditCard
|
||||
DB.Model(&user).Related(&creditcard31)
|
||||
if creditcard31.Number != "411111111113" {
|
||||
t.Errorf("CreditCard should be updated with Replace")
|
||||
}
|
||||
|
||||
if DB.Model(&user).Association("CreditCard").Count() != 1 {
|
||||
t.Errorf("User's credit card count should be 1")
|
||||
}
|
||||
|
||||
// Delete
|
||||
DB.Model(&user).Association("CreditCard").Delete(&creditcard2)
|
||||
var creditcard4 CreditCard
|
||||
DB.Model(&user).Related(&creditcard4)
|
||||
if creditcard4.Number != "411111111113" {
|
||||
t.Errorf("Should not delete credit card when Delete a unrelated CreditCard")
|
||||
}
|
||||
|
||||
if DB.Model(&user).Association("CreditCard").Count() != 1 {
|
||||
t.Errorf("User's credit card count should be 1")
|
||||
}
|
||||
|
||||
DB.Model(&user).Association("CreditCard").Delete(&creditcard3)
|
||||
if !DB.Model(&user).Related(&CreditCard{}).RecordNotFound() {
|
||||
t.Errorf("Should delete credit card with Delete")
|
||||
}
|
||||
|
||||
if DB.Model(&user).Association("CreditCard").Count() != 0 {
|
||||
t.Errorf("User's credit card count should be 0 after Delete")
|
||||
}
|
||||
|
||||
// Clear
|
||||
var creditcard5 = CreditCard{
|
||||
Number: "411111111115",
|
||||
}
|
||||
DB.Model(&user).Association("CreditCard").Append(&creditcard5)
|
||||
|
||||
if DB.Model(&user).Related(&CreditCard{}).RecordNotFound() {
|
||||
t.Errorf("Should added credit card with Append")
|
||||
}
|
||||
|
||||
if DB.Model(&user).Association("CreditCard").Count() != 1 {
|
||||
t.Errorf("User's credit card count should be 1")
|
||||
}
|
||||
|
||||
DB.Model(&user).Association("CreditCard").Clear()
|
||||
if !DB.Model(&user).Related(&CreditCard{}).RecordNotFound() {
|
||||
t.Errorf("Credit card should be deleted with Clear")
|
||||
}
|
||||
|
||||
if DB.Model(&user).Association("CreditCard").Count() != 0 {
|
||||
t.Errorf("User's credit card count should be 0 after Clear")
|
||||
}
|
||||
|
||||
// Check Association mode with soft delete
|
||||
var creditcard6 = CreditCard{
|
||||
Number: "411111111116",
|
||||
}
|
||||
DB.Model(&user).Association("CreditCard").Append(&creditcard6)
|
||||
|
||||
if count := DB.Model(&user).Association("CreditCard").Count(); count != 1 {
|
||||
t.Errorf("User's credit card count should be 1 after Append, but got %v", count)
|
||||
}
|
||||
|
||||
DB.Delete(&creditcard6)
|
||||
|
||||
if count := DB.Model(&user).Association("CreditCard").Count(); count != 0 {
|
||||
t.Errorf("User's credit card count should be 0 after credit card deleted, but got %v", count)
|
||||
}
|
||||
|
||||
if err := DB.Model(&user).Association("CreditCard").Find(&CreditCard{}).Error; err == nil {
|
||||
t.Errorf("User's creditcard is not findable after Delete")
|
||||
}
|
||||
|
||||
if count := DB.Unscoped().Model(&user).Association("CreditCard").Count(); count != 1 {
|
||||
t.Errorf("User's credit card count should be 1 when query with Unscoped, but got %v", count)
|
||||
}
|
||||
|
||||
if err := DB.Unscoped().Model(&user).Association("CreditCard").Find(&CreditCard{}).Error; err != nil {
|
||||
t.Errorf("User's creditcard should be findable when query with Unscoped, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasOneOverrideForeignKey1(t *testing.T) {
|
||||
type Profile struct {
|
||||
gorm.Model
|
||||
Name string
|
||||
UserRefer uint
|
||||
}
|
||||
|
||||
type User struct {
|
||||
gorm.Model
|
||||
Profile Profile `gorm:"ForeignKey:UserRefer"`
|
||||
}
|
||||
|
||||
if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok {
|
||||
if relation.Relationship.Kind != "has_one" ||
|
||||
!reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"UserRefer"}) ||
|
||||
!reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"ID"}) {
|
||||
t.Errorf("Override belongs to foreign key with tag")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasOneOverrideForeignKey2(t *testing.T) {
|
||||
type Profile struct {
|
||||
gorm.Model
|
||||
Name string
|
||||
UserID uint
|
||||
}
|
||||
|
||||
type User struct {
|
||||
gorm.Model
|
||||
Refer string
|
||||
Profile Profile `gorm:"ForeignKey:UserID;AssociationForeignKey:Refer"`
|
||||
}
|
||||
|
||||
if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok {
|
||||
if relation.Relationship.Kind != "has_one" ||
|
||||
!reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"UserID"}) ||
|
||||
!reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"Refer"}) {
|
||||
t.Errorf("Override belongs to foreign key with tag")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasMany(t *testing.T) {
|
||||
post := Post{
|
||||
Title: "post has many",
|
||||
Body: "body has many",
|
||||
Comments: []*Comment{{Content: "Comment 1"}, {Content: "Comment 2"}},
|
||||
}
|
||||
|
||||
if err := DB.Save(&post).Error; err != nil {
|
||||
t.Error("Got errors when save post", err)
|
||||
}
|
||||
|
||||
for _, comment := range post.Comments {
|
||||
if comment.PostId == 0 {
|
||||
t.Errorf("comment's PostID should be updated")
|
||||
}
|
||||
}
|
||||
|
||||
var compareComments = func(comments []Comment, contents []string) bool {
|
||||
var commentContents []string
|
||||
for _, comment := range comments {
|
||||
commentContents = append(commentContents, comment.Content)
|
||||
}
|
||||
sort.Strings(commentContents)
|
||||
sort.Strings(contents)
|
||||
return reflect.DeepEqual(commentContents, contents)
|
||||
}
|
||||
|
||||
// Query
|
||||
if DB.First(&Comment{}, "content = ?", "Comment 1").Error != nil {
|
||||
t.Errorf("Comment 1 should be saved")
|
||||
}
|
||||
|
||||
var comments1 []Comment
|
||||
DB.Model(&post).Association("Comments").Find(&comments1)
|
||||
if !compareComments(comments1, []string{"Comment 1", "Comment 2"}) {
|
||||
t.Errorf("Query has many relations with Association")
|
||||
}
|
||||
|
||||
var comments11 []Comment
|
||||
DB.Model(&post).Related(&comments11)
|
||||
if !compareComments(comments11, []string{"Comment 1", "Comment 2"}) {
|
||||
t.Errorf("Query has many relations with Related")
|
||||
}
|
||||
|
||||
if DB.Model(&post).Association("Comments").Count() != 2 {
|
||||
t.Errorf("Post's comments count should be 2")
|
||||
}
|
||||
|
||||
// Append
|
||||
DB.Model(&post).Association("Comments").Append(&Comment{Content: "Comment 3"})
|
||||
|
||||
var comments2 []Comment
|
||||
DB.Model(&post).Related(&comments2)
|
||||
if !compareComments(comments2, []string{"Comment 1", "Comment 2", "Comment 3"}) {
|
||||
t.Errorf("Append new record to has many relations")
|
||||
}
|
||||
|
||||
if DB.Model(&post).Association("Comments").Count() != 3 {
|
||||
t.Errorf("Post's comments count should be 3 after Append")
|
||||
}
|
||||
|
||||
// Delete
|
||||
DB.Model(&post).Association("Comments").Delete(comments11)
|
||||
|
||||
var comments3 []Comment
|
||||
DB.Model(&post).Related(&comments3)
|
||||
if !compareComments(comments3, []string{"Comment 3"}) {
|
||||
t.Errorf("Delete an existing resource for has many relations")
|
||||
}
|
||||
|
||||
if DB.Model(&post).Association("Comments").Count() != 1 {
|
||||
t.Errorf("Post's comments count should be 1 after Delete 2")
|
||||
}
|
||||
|
||||
// Replace
|
||||
DB.Model(&Post{Id: 999}).Association("Comments").Replace()
|
||||
|
||||
var comments4 []Comment
|
||||
DB.Model(&post).Related(&comments4)
|
||||
if len(comments4) == 0 {
|
||||
t.Errorf("Replace for other resource should not clear all comments")
|
||||
}
|
||||
|
||||
DB.Model(&post).Association("Comments").Replace(&Comment{Content: "Comment 4"}, &Comment{Content: "Comment 5"})
|
||||
|
||||
var comments41 []Comment
|
||||
DB.Model(&post).Related(&comments41)
|
||||
if !compareComments(comments41, []string{"Comment 4", "Comment 5"}) {
|
||||
t.Errorf("Replace has many relations")
|
||||
}
|
||||
|
||||
// Clear
|
||||
DB.Model(&Post{Id: 999}).Association("Comments").Clear()
|
||||
|
||||
var comments5 []Comment
|
||||
DB.Model(&post).Related(&comments5)
|
||||
if len(comments5) == 0 {
|
||||
t.Errorf("Clear should not clear all comments")
|
||||
}
|
||||
|
||||
DB.Model(&post).Association("Comments").Clear()
|
||||
|
||||
var comments51 []Comment
|
||||
DB.Model(&post).Related(&comments51)
|
||||
if len(comments51) != 0 {
|
||||
t.Errorf("Clear has many relations")
|
||||
}
|
||||
|
||||
// Check Association mode with soft delete
|
||||
var comment6 = Comment{
|
||||
Content: "comment 6",
|
||||
}
|
||||
DB.Model(&post).Association("Comments").Append(&comment6)
|
||||
|
||||
if count := DB.Model(&post).Association("Comments").Count(); count != 1 {
|
||||
t.Errorf("post's comments count should be 1 after Append, but got %v", count)
|
||||
}
|
||||
|
||||
DB.Delete(&comment6)
|
||||
|
||||
if count := DB.Model(&post).Association("Comments").Count(); count != 0 {
|
||||
t.Errorf("post's comments count should be 0 after comment been deleted, but got %v", count)
|
||||
}
|
||||
|
||||
var comments6 []Comment
|
||||
if DB.Model(&post).Association("Comments").Find(&comments6); len(comments6) != 0 {
|
||||
t.Errorf("post's comments count should be 0 when find with Find, but got %v", len(comments6))
|
||||
}
|
||||
|
||||
if count := DB.Unscoped().Model(&post).Association("Comments").Count(); count != 1 {
|
||||
t.Errorf("post's comments count should be 1 when query with Unscoped, but got %v", count)
|
||||
}
|
||||
|
||||
var comments61 []Comment
|
||||
if DB.Unscoped().Model(&post).Association("Comments").Find(&comments61); len(comments61) != 1 {
|
||||
t.Errorf("post's comments count should be 1 when query with Unscoped, but got %v", len(comments61))
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasManyOverrideForeignKey1(t *testing.T) {
|
||||
type Profile struct {
|
||||
gorm.Model
|
||||
Name string
|
||||
UserRefer uint
|
||||
}
|
||||
|
||||
type User struct {
|
||||
gorm.Model
|
||||
Profile []Profile `gorm:"ForeignKey:UserRefer"`
|
||||
}
|
||||
|
||||
if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok {
|
||||
if relation.Relationship.Kind != "has_many" ||
|
||||
!reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"UserRefer"}) ||
|
||||
!reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"ID"}) {
|
||||
t.Errorf("Override belongs to foreign key with tag")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasManyOverrideForeignKey2(t *testing.T) {
|
||||
type Profile struct {
|
||||
gorm.Model
|
||||
Name string
|
||||
UserID uint
|
||||
}
|
||||
|
||||
type User struct {
|
||||
gorm.Model
|
||||
Refer string
|
||||
Profile []Profile `gorm:"ForeignKey:UserID;AssociationForeignKey:Refer"`
|
||||
}
|
||||
|
||||
if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok {
|
||||
if relation.Relationship.Kind != "has_many" ||
|
||||
!reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"UserID"}) ||
|
||||
!reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"Refer"}) {
|
||||
t.Errorf("Override belongs to foreign key with tag")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestManyToMany(t *testing.T) {
|
||||
DB.Raw("delete from languages")
|
||||
var languages = []Language{{Name: "ZH"}, {Name: "EN"}}
|
||||
user := User{Name: "Many2Many", Languages: languages}
|
||||
DB.Save(&user)
|
||||
|
||||
// Query
|
||||
var newLanguages []Language
|
||||
DB.Model(&user).Related(&newLanguages, "Languages")
|
||||
if len(newLanguages) != len([]string{"ZH", "EN"}) {
|
||||
t.Errorf("Query many to many relations")
|
||||
}
|
||||
|
||||
DB.Model(&user).Association("Languages").Find(&newLanguages)
|
||||
if len(newLanguages) != len([]string{"ZH", "EN"}) {
|
||||
t.Errorf("Should be able to find many to many relations")
|
||||
}
|
||||
|
||||
if DB.Model(&user).Association("Languages").Count() != len([]string{"ZH", "EN"}) {
|
||||
t.Errorf("Count should return correct result")
|
||||
}
|
||||
|
||||
// Append
|
||||
DB.Model(&user).Association("Languages").Append(&Language{Name: "DE"})
|
||||
if DB.Where("name = ?", "DE").First(&Language{}).RecordNotFound() {
|
||||
t.Errorf("New record should be saved when append")
|
||||
}
|
||||
|
||||
languageA := Language{Name: "AA"}
|
||||
DB.Save(&languageA)
|
||||
DB.Model(&User{Id: user.Id}).Association("Languages").Append(&languageA)
|
||||
|
||||
languageC := Language{Name: "CC"}
|
||||
DB.Save(&languageC)
|
||||
DB.Model(&user).Association("Languages").Append(&[]Language{{Name: "BB"}, languageC})
|
||||
|
||||
DB.Model(&User{Id: user.Id}).Association("Languages").Append(&[]Language{{Name: "DD"}, {Name: "EE"}})
|
||||
|
||||
totalLanguages := []string{"ZH", "EN", "DE", "AA", "BB", "CC", "DD", "EE"}
|
||||
|
||||
if DB.Model(&user).Association("Languages").Count() != len(totalLanguages) {
|
||||
t.Errorf("All appended languages should be saved")
|
||||
}
|
||||
|
||||
// Delete
|
||||
user.Languages = []Language{}
|
||||
DB.Model(&user).Association("Languages").Find(&user.Languages)
|
||||
|
||||
var language Language
|
||||
DB.Where("name = ?", "EE").First(&language)
|
||||
DB.Model(&user).Association("Languages").Delete(language, &language)
|
||||
|
||||
if DB.Model(&user).Association("Languages").Count() != len(totalLanguages)-1 || len(user.Languages) != len(totalLanguages)-1 {
|
||||
t.Errorf("Relations should be deleted with Delete")
|
||||
}
|
||||
if DB.Where("name = ?", "EE").First(&Language{}).RecordNotFound() {
|
||||
t.Errorf("Language EE should not be deleted")
|
||||
}
|
||||
|
||||
DB.Where("name IN (?)", []string{"CC", "DD"}).Find(&languages)
|
||||
|
||||
user2 := User{Name: "Many2Many_User2", Languages: languages}
|
||||
DB.Save(&user2)
|
||||
|
||||
DB.Model(&user).Association("Languages").Delete(languages, &languages)
|
||||
if DB.Model(&user).Association("Languages").Count() != len(totalLanguages)-3 || len(user.Languages) != len(totalLanguages)-3 {
|
||||
t.Errorf("Relations should be deleted with Delete")
|
||||
}
|
||||
|
||||
if DB.Model(&user2).Association("Languages").Count() == 0 {
|
||||
t.Errorf("Other user's relations should not be deleted")
|
||||
}
|
||||
|
||||
// Replace
|
||||
var languageB Language
|
||||
DB.Where("name = ?", "BB").First(&languageB)
|
||||
DB.Model(&user).Association("Languages").Replace(languageB)
|
||||
if len(user.Languages) != 1 || DB.Model(&user).Association("Languages").Count() != 1 {
|
||||
t.Errorf("Relations should be replaced")
|
||||
}
|
||||
|
||||
DB.Model(&user).Association("Languages").Replace()
|
||||
if len(user.Languages) != 0 || DB.Model(&user).Association("Languages").Count() != 0 {
|
||||
t.Errorf("Relations should be replaced with empty")
|
||||
}
|
||||
|
||||
DB.Model(&user).Association("Languages").Replace(&[]Language{{Name: "FF"}, {Name: "JJ"}})
|
||||
if len(user.Languages) != 2 || DB.Model(&user).Association("Languages").Count() != len([]string{"FF", "JJ"}) {
|
||||
t.Errorf("Relations should be replaced")
|
||||
}
|
||||
|
||||
// Clear
|
||||
DB.Model(&user).Association("Languages").Clear()
|
||||
if len(user.Languages) != 0 || DB.Model(&user).Association("Languages").Count() != 0 {
|
||||
t.Errorf("Relations should be cleared")
|
||||
}
|
||||
|
||||
// Check Association mode with soft delete
|
||||
var language6 = Language{
|
||||
Name: "language 6",
|
||||
}
|
||||
DB.Model(&user).Association("Languages").Append(&language6)
|
||||
|
||||
if count := DB.Model(&user).Association("Languages").Count(); count != 1 {
|
||||
t.Errorf("user's languages count should be 1 after Append, but got %v", count)
|
||||
}
|
||||
|
||||
DB.Delete(&language6)
|
||||
|
||||
if count := DB.Model(&user).Association("Languages").Count(); count != 0 {
|
||||
t.Errorf("user's languages count should be 0 after language been deleted, but got %v", count)
|
||||
}
|
||||
|
||||
var languages6 []Language
|
||||
if DB.Model(&user).Association("Languages").Find(&languages6); len(languages6) != 0 {
|
||||
t.Errorf("user's languages count should be 0 when find with Find, but got %v", len(languages6))
|
||||
}
|
||||
|
||||
if count := DB.Unscoped().Model(&user).Association("Languages").Count(); count != 1 {
|
||||
t.Errorf("user's languages count should be 1 when query with Unscoped, but got %v", count)
|
||||
}
|
||||
|
||||
var languages61 []Language
|
||||
if DB.Unscoped().Model(&user).Association("Languages").Find(&languages61); len(languages61) != 1 {
|
||||
t.Errorf("user's languages count should be 1 when query with Unscoped, but got %v", len(languages61))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRelated(t *testing.T) {
|
||||
user := User{
|
||||
Name: "jinzhu",
|
||||
BillingAddress: Address{Address1: "Billing Address - Address 1"},
|
||||
ShippingAddress: Address{Address1: "Shipping Address - Address 1"},
|
||||
Emails: []Email{{Email: "jinzhu@example.com"}, {Email: "jinzhu-2@example@example.com"}},
|
||||
CreditCard: CreditCard{Number: "1234567890"},
|
||||
Company: Company{Name: "company1"},
|
||||
}
|
||||
|
||||
if err := DB.Save(&user).Error; err != nil {
|
||||
t.Errorf("No error should happen when saving user")
|
||||
}
|
||||
|
||||
if user.CreditCard.ID == 0 {
|
||||
t.Errorf("After user save, credit card should have id")
|
||||
}
|
||||
|
||||
if user.BillingAddress.ID == 0 {
|
||||
t.Errorf("After user save, billing address should have id")
|
||||
}
|
||||
|
||||
if user.Emails[0].Id == 0 {
|
||||
t.Errorf("After user save, billing address should have id")
|
||||
}
|
||||
|
||||
var emails []Email
|
||||
DB.Model(&user).Related(&emails)
|
||||
if len(emails) != 2 {
|
||||
t.Errorf("Should have two emails")
|
||||
}
|
||||
|
||||
var emails2 []Email
|
||||
DB.Model(&user).Where("email = ?", "jinzhu@example.com").Related(&emails2)
|
||||
if len(emails2) != 1 {
|
||||
t.Errorf("Should have two emails")
|
||||
}
|
||||
|
||||
var emails3 []*Email
|
||||
DB.Model(&user).Related(&emails3)
|
||||
if len(emails3) != 2 {
|
||||
t.Errorf("Should have two emails")
|
||||
}
|
||||
|
||||
var user1 User
|
||||
DB.Model(&user).Related(&user1.Emails)
|
||||
if len(user1.Emails) != 2 {
|
||||
t.Errorf("Should have only one email match related condition")
|
||||
}
|
||||
|
||||
var address1 Address
|
||||
DB.Model(&user).Related(&address1, "BillingAddressId")
|
||||
if address1.Address1 != "Billing Address - Address 1" {
|
||||
t.Errorf("Should get billing address from user correctly")
|
||||
}
|
||||
|
||||
user1 = User{}
|
||||
DB.Model(&address1).Related(&user1, "BillingAddressId")
|
||||
if DB.NewRecord(user1) {
|
||||
t.Errorf("Should get user from address correctly")
|
||||
}
|
||||
|
||||
var user2 User
|
||||
DB.Model(&emails[0]).Related(&user2)
|
||||
if user2.Id != user.Id || user2.Name != user.Name {
|
||||
t.Errorf("Should get user from email correctly")
|
||||
}
|
||||
|
||||
var creditcard CreditCard
|
||||
var user3 User
|
||||
DB.First(&creditcard, "number = ?", "1234567890")
|
||||
DB.Model(&creditcard).Related(&user3)
|
||||
if user3.Id != user.Id || user3.Name != user.Name {
|
||||
t.Errorf("Should get user from credit card correctly")
|
||||
}
|
||||
|
||||
if !DB.Model(&CreditCard{}).Related(&User{}).RecordNotFound() {
|
||||
t.Errorf("RecordNotFound for Related")
|
||||
}
|
||||
|
||||
var company Company
|
||||
if DB.Model(&user).Related(&company, "Company").RecordNotFound() || company.Name != "company1" {
|
||||
t.Errorf("RecordNotFound for Related")
|
||||
}
|
||||
}
|
||||
|
||||
func TestForeignKey(t *testing.T) {
|
||||
for _, structField := range DB.NewScope(&User{}).GetStructFields() {
|
||||
for _, foreignKey := range []string{"BillingAddressID", "ShippingAddressId", "CompanyID"} {
|
||||
if structField.Name == foreignKey && !structField.IsForeignKey {
|
||||
t.Errorf(fmt.Sprintf("%v should be foreign key", foreignKey))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, structField := range DB.NewScope(&Email{}).GetStructFields() {
|
||||
for _, foreignKey := range []string{"UserId"} {
|
||||
if structField.Name == foreignKey && !structField.IsForeignKey {
|
||||
t.Errorf(fmt.Sprintf("%v should be foreign key", foreignKey))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, structField := range DB.NewScope(&Post{}).GetStructFields() {
|
||||
for _, foreignKey := range []string{"CategoryId", "MainCategoryId"} {
|
||||
if structField.Name == foreignKey && !structField.IsForeignKey {
|
||||
t.Errorf(fmt.Sprintf("%v should be foreign key", foreignKey))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, structField := range DB.NewScope(&Comment{}).GetStructFields() {
|
||||
for _, foreignKey := range []string{"PostId"} {
|
||||
if structField.Name == foreignKey && !structField.IsForeignKey {
|
||||
t.Errorf(fmt.Sprintf("%v should be foreign key", foreignKey))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func testForeignKey(t *testing.T, source interface{}, sourceFieldName string, target interface{}, targetFieldName string) {
|
||||
if dialect := os.Getenv("GORM_DIALECT"); dialect == "" || dialect == "sqlite" {
|
||||
// sqlite does not support ADD CONSTRAINT in ALTER TABLE
|
||||
return
|
||||
}
|
||||
targetScope := DB.NewScope(target)
|
||||
targetTableName := targetScope.TableName()
|
||||
modelScope := DB.NewScope(source)
|
||||
modelField, ok := modelScope.FieldByName(sourceFieldName)
|
||||
if !ok {
|
||||
t.Fatalf(fmt.Sprintf("Failed to get field by name: %v", sourceFieldName))
|
||||
}
|
||||
targetField, ok := targetScope.FieldByName(targetFieldName)
|
||||
if !ok {
|
||||
t.Fatalf(fmt.Sprintf("Failed to get field by name: %v", targetFieldName))
|
||||
}
|
||||
dest := fmt.Sprintf("%v(%v)", targetTableName, targetField.DBName)
|
||||
err := DB.Model(source).AddForeignKey(modelField.DBName, dest, "CASCADE", "CASCADE").Error
|
||||
if err != nil {
|
||||
t.Fatalf(fmt.Sprintf("Failed to create foreign key: %v", err))
|
||||
}
|
||||
}
|
||||
|
||||
func TestLongForeignKey(t *testing.T) {
|
||||
testForeignKey(t, &NotSoLongTableName{}, "ReallyLongThingID", &ReallyLongTableNameToTestMySQLNameLengthLimit{}, "ID")
|
||||
}
|
||||
|
||||
func TestLongForeignKeyWithShortDest(t *testing.T) {
|
||||
testForeignKey(t, &ReallyLongThingThatReferencesShort{}, "ShortID", &Short{}, "ID")
|
||||
}
|
||||
|
||||
func TestHasManyChildrenWithOneStruct(t *testing.T) {
|
||||
category := Category{
|
||||
Name: "main",
|
||||
Categories: []Category{
|
||||
{Name: "sub1"},
|
||||
{Name: "sub2"},
|
||||
},
|
||||
}
|
||||
|
||||
DB.Save(&category)
|
||||
}
|
||||
|
||||
func TestSkipSaveAssociation(t *testing.T) {
|
||||
type Company struct {
|
||||
gorm.Model
|
||||
Name string
|
||||
}
|
||||
|
||||
type User struct {
|
||||
gorm.Model
|
||||
Name string
|
||||
CompanyID uint
|
||||
Company Company `gorm:"save_associations:false"`
|
||||
}
|
||||
DB.AutoMigrate(&Company{}, &User{})
|
||||
|
||||
DB.Save(&User{Name: "jinzhu", Company: Company{Name: "skip_save_association"}})
|
||||
|
||||
if !DB.Where("name = ?", "skip_save_association").First(&Company{}).RecordNotFound() {
|
||||
t.Errorf("Company skip_save_association should not been saved")
|
||||
}
|
||||
}
|
||||
242
callback.go
242
callback.go
@ -1,242 +0,0 @@
|
||||
package gorm
|
||||
|
||||
import "log"
|
||||
|
||||
// DefaultCallback default callbacks defined by gorm
|
||||
var DefaultCallback = &Callback{}
|
||||
|
||||
// Callback is a struct that contains all CRUD callbacks
|
||||
// Field `creates` contains callbacks will be call when creating object
|
||||
// Field `updates` contains callbacks will be call when updating object
|
||||
// Field `deletes` contains callbacks will be call when deleting object
|
||||
// Field `queries` contains callbacks will be call when querying object with query methods like Find, First, Related, Association...
|
||||
// Field `rowQueries` contains callbacks will be call when querying object with Row, Rows...
|
||||
// Field `processors` contains all callback processors, will be used to generate above callbacks in order
|
||||
type Callback struct {
|
||||
creates []*func(scope *Scope)
|
||||
updates []*func(scope *Scope)
|
||||
deletes []*func(scope *Scope)
|
||||
queries []*func(scope *Scope)
|
||||
rowQueries []*func(scope *Scope)
|
||||
processors []*CallbackProcessor
|
||||
}
|
||||
|
||||
// CallbackProcessor contains callback informations
|
||||
type CallbackProcessor struct {
|
||||
name string // current callback's name
|
||||
before string // register current callback before a callback
|
||||
after string // register current callback after a callback
|
||||
replace bool // replace callbacks with same name
|
||||
remove bool // delete callbacks with same name
|
||||
kind string // callback type: create, update, delete, query, row_query
|
||||
processor *func(scope *Scope) // callback handler
|
||||
parent *Callback
|
||||
}
|
||||
|
||||
func (c *Callback) clone() *Callback {
|
||||
return &Callback{
|
||||
creates: c.creates,
|
||||
updates: c.updates,
|
||||
deletes: c.deletes,
|
||||
queries: c.queries,
|
||||
rowQueries: c.rowQueries,
|
||||
processors: c.processors,
|
||||
}
|
||||
}
|
||||
|
||||
// Create could be used to register callbacks for creating object
|
||||
// db.Callback().Create().After("gorm:create").Register("plugin:run_after_create", func(*Scope) {
|
||||
// // business logic
|
||||
// ...
|
||||
//
|
||||
// // set error if some thing wrong happened, will rollback the creating
|
||||
// scope.Err(errors.New("error"))
|
||||
// })
|
||||
func (c *Callback) Create() *CallbackProcessor {
|
||||
return &CallbackProcessor{kind: "create", parent: c}
|
||||
}
|
||||
|
||||
// Update could be used to register callbacks for updating object, refer `Create` for usage
|
||||
func (c *Callback) Update() *CallbackProcessor {
|
||||
return &CallbackProcessor{kind: "update", parent: c}
|
||||
}
|
||||
|
||||
// Delete could be used to register callbacks for deleting object, refer `Create` for usage
|
||||
func (c *Callback) Delete() *CallbackProcessor {
|
||||
return &CallbackProcessor{kind: "delete", parent: c}
|
||||
}
|
||||
|
||||
// Query could be used to register callbacks for querying objects with query methods like `Find`, `First`, `Related`, `Association`...
|
||||
// Refer `Create` for usage
|
||||
func (c *Callback) Query() *CallbackProcessor {
|
||||
return &CallbackProcessor{kind: "query", parent: c}
|
||||
}
|
||||
|
||||
// RowQuery could be used to register callbacks for querying objects with `Row`, `Rows`, refer `Create` for usage
|
||||
func (c *Callback) RowQuery() *CallbackProcessor {
|
||||
return &CallbackProcessor{kind: "row_query", parent: c}
|
||||
}
|
||||
|
||||
// After insert a new callback after callback `callbackName`, refer `Callbacks.Create`
|
||||
func (cp *CallbackProcessor) After(callbackName string) *CallbackProcessor {
|
||||
cp.after = callbackName
|
||||
return cp
|
||||
}
|
||||
|
||||
// Before insert a new callback before callback `callbackName`, refer `Callbacks.Create`
|
||||
func (cp *CallbackProcessor) Before(callbackName string) *CallbackProcessor {
|
||||
cp.before = callbackName
|
||||
return cp
|
||||
}
|
||||
|
||||
// Register a new callback, refer `Callbacks.Create`
|
||||
func (cp *CallbackProcessor) Register(callbackName string, callback func(scope *Scope)) {
|
||||
if cp.kind == "row_query" {
|
||||
if cp.before == "" && cp.after == "" && callbackName != "gorm:row_query" {
|
||||
log.Printf("Registing RowQuery callback %v without specify order with Before(), After(), applying Before('gorm:row_query') by default for compatibility...\n", callbackName)
|
||||
cp.before = "gorm:row_query"
|
||||
}
|
||||
}
|
||||
|
||||
cp.name = callbackName
|
||||
cp.processor = &callback
|
||||
cp.parent.processors = append(cp.parent.processors, cp)
|
||||
cp.parent.reorder()
|
||||
}
|
||||
|
||||
// Remove a registered callback
|
||||
// db.Callback().Create().Remove("gorm:update_time_stamp_when_create")
|
||||
func (cp *CallbackProcessor) Remove(callbackName string) {
|
||||
log.Printf("[info] removing callback `%v` from %v\n", callbackName, fileWithLineNum())
|
||||
cp.name = callbackName
|
||||
cp.remove = true
|
||||
cp.parent.processors = append(cp.parent.processors, cp)
|
||||
cp.parent.reorder()
|
||||
}
|
||||
|
||||
// Replace a registered callback with new callback
|
||||
// db.Callback().Create().Replace("gorm:update_time_stamp_when_create", func(*Scope) {
|
||||
// scope.SetColumn("Created", now)
|
||||
// scope.SetColumn("Updated", now)
|
||||
// })
|
||||
func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *Scope)) {
|
||||
log.Printf("[info] replacing callback `%v` from %v\n", callbackName, fileWithLineNum())
|
||||
cp.name = callbackName
|
||||
cp.processor = &callback
|
||||
cp.replace = true
|
||||
cp.parent.processors = append(cp.parent.processors, cp)
|
||||
cp.parent.reorder()
|
||||
}
|
||||
|
||||
// Get registered callback
|
||||
// db.Callback().Create().Get("gorm:create")
|
||||
func (cp *CallbackProcessor) Get(callbackName string) (callback func(scope *Scope)) {
|
||||
for _, p := range cp.parent.processors {
|
||||
if p.name == callbackName && p.kind == cp.kind && !cp.remove {
|
||||
return *p.processor
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// getRIndex get right index from string slice
|
||||
func getRIndex(strs []string, str string) int {
|
||||
for i := len(strs) - 1; i >= 0; i-- {
|
||||
if strs[i] == str {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
// sortProcessors sort callback processors based on its before, after, remove, replace
|
||||
func sortProcessors(cps []*CallbackProcessor) []*func(scope *Scope) {
|
||||
var (
|
||||
allNames, sortedNames []string
|
||||
sortCallbackProcessor func(c *CallbackProcessor)
|
||||
)
|
||||
|
||||
for _, cp := range cps {
|
||||
// show warning message the callback name already exists
|
||||
if index := getRIndex(allNames, cp.name); index > -1 && !cp.replace && !cp.remove {
|
||||
log.Printf("[warning] duplicated callback `%v` from %v\n", cp.name, fileWithLineNum())
|
||||
}
|
||||
allNames = append(allNames, cp.name)
|
||||
}
|
||||
|
||||
sortCallbackProcessor = func(c *CallbackProcessor) {
|
||||
if getRIndex(sortedNames, c.name) == -1 { // if not sorted
|
||||
if c.before != "" { // if defined before callback
|
||||
if index := getRIndex(sortedNames, c.before); index != -1 {
|
||||
// if before callback already sorted, append current callback just after it
|
||||
sortedNames = append(sortedNames[:index], append([]string{c.name}, sortedNames[index:]...)...)
|
||||
} else if index := getRIndex(allNames, c.before); index != -1 {
|
||||
// if before callback exists but haven't sorted, append current callback to last
|
||||
sortedNames = append(sortedNames, c.name)
|
||||
sortCallbackProcessor(cps[index])
|
||||
}
|
||||
}
|
||||
|
||||
if c.after != "" { // if defined after callback
|
||||
if index := getRIndex(sortedNames, c.after); index != -1 {
|
||||
// if after callback already sorted, append current callback just before it
|
||||
sortedNames = append(sortedNames[:index+1], append([]string{c.name}, sortedNames[index+1:]...)...)
|
||||
} else if index := getRIndex(allNames, c.after); index != -1 {
|
||||
// if after callback exists but haven't sorted
|
||||
cp := cps[index]
|
||||
// set after callback's before callback to current callback
|
||||
if cp.before == "" {
|
||||
cp.before = c.name
|
||||
}
|
||||
sortCallbackProcessor(cp)
|
||||
}
|
||||
}
|
||||
|
||||
// if current callback haven't been sorted, append it to last
|
||||
if getRIndex(sortedNames, c.name) == -1 {
|
||||
sortedNames = append(sortedNames, c.name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, cp := range cps {
|
||||
sortCallbackProcessor(cp)
|
||||
}
|
||||
|
||||
var sortedFuncs []*func(scope *Scope)
|
||||
for _, name := range sortedNames {
|
||||
if index := getRIndex(allNames, name); !cps[index].remove {
|
||||
sortedFuncs = append(sortedFuncs, cps[index].processor)
|
||||
}
|
||||
}
|
||||
|
||||
return sortedFuncs
|
||||
}
|
||||
|
||||
// reorder all registered processors, and reset CRUD callbacks
|
||||
func (c *Callback) reorder() {
|
||||
var creates, updates, deletes, queries, rowQueries []*CallbackProcessor
|
||||
|
||||
for _, processor := range c.processors {
|
||||
if processor.name != "" {
|
||||
switch processor.kind {
|
||||
case "create":
|
||||
creates = append(creates, processor)
|
||||
case "update":
|
||||
updates = append(updates, processor)
|
||||
case "delete":
|
||||
deletes = append(deletes, processor)
|
||||
case "query":
|
||||
queries = append(queries, processor)
|
||||
case "row_query":
|
||||
rowQueries = append(rowQueries, processor)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.creates = sortProcessors(creates)
|
||||
c.updates = sortProcessors(updates)
|
||||
c.deletes = sortProcessors(deletes)
|
||||
c.queries = sortProcessors(queries)
|
||||
c.rowQueries = sortProcessors(rowQueries)
|
||||
}
|
||||
@ -1,163 +0,0 @@
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Define callbacks for creating
|
||||
func init() {
|
||||
DefaultCallback.Create().Register("gorm:begin_transaction", beginTransactionCallback)
|
||||
DefaultCallback.Create().Register("gorm:before_create", beforeCreateCallback)
|
||||
DefaultCallback.Create().Register("gorm:save_before_associations", saveBeforeAssociationsCallback)
|
||||
DefaultCallback.Create().Register("gorm:update_time_stamp", updateTimeStampForCreateCallback)
|
||||
DefaultCallback.Create().Register("gorm:create", createCallback)
|
||||
DefaultCallback.Create().Register("gorm:force_reload_after_create", forceReloadAfterCreateCallback)
|
||||
DefaultCallback.Create().Register("gorm:save_after_associations", saveAfterAssociationsCallback)
|
||||
DefaultCallback.Create().Register("gorm:after_create", afterCreateCallback)
|
||||
DefaultCallback.Create().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback)
|
||||
}
|
||||
|
||||
// beforeCreateCallback will invoke `BeforeSave`, `BeforeCreate` method before creating
|
||||
func beforeCreateCallback(scope *Scope) {
|
||||
if !scope.HasError() {
|
||||
scope.CallMethod("BeforeSave")
|
||||
}
|
||||
if !scope.HasError() {
|
||||
scope.CallMethod("BeforeCreate")
|
||||
}
|
||||
}
|
||||
|
||||
// updateTimeStampForCreateCallback will set `CreatedAt`, `UpdatedAt` when creating
|
||||
func updateTimeStampForCreateCallback(scope *Scope) {
|
||||
if !scope.HasError() {
|
||||
now := NowFunc()
|
||||
|
||||
if createdAtField, ok := scope.FieldByName("CreatedAt"); ok {
|
||||
if createdAtField.IsBlank {
|
||||
createdAtField.Set(now)
|
||||
}
|
||||
}
|
||||
|
||||
if updatedAtField, ok := scope.FieldByName("UpdatedAt"); ok {
|
||||
if updatedAtField.IsBlank {
|
||||
updatedAtField.Set(now)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// createCallback the callback used to insert data into database
|
||||
func createCallback(scope *Scope) {
|
||||
if !scope.HasError() {
|
||||
defer scope.trace(NowFunc())
|
||||
|
||||
var (
|
||||
columns, placeholders []string
|
||||
blankColumnsWithDefaultValue []string
|
||||
)
|
||||
|
||||
for _, field := range scope.Fields() {
|
||||
if scope.changeableField(field) {
|
||||
if field.IsNormal {
|
||||
if field.IsBlank && field.HasDefaultValue {
|
||||
blankColumnsWithDefaultValue = append(blankColumnsWithDefaultValue, scope.Quote(field.DBName))
|
||||
scope.InstanceSet("gorm:blank_columns_with_default_value", blankColumnsWithDefaultValue)
|
||||
} else if !field.IsPrimaryKey || !field.IsBlank {
|
||||
columns = append(columns, scope.Quote(field.DBName))
|
||||
placeholders = append(placeholders, scope.AddToVars(field.Field.Interface()))
|
||||
}
|
||||
} else if field.Relationship != nil && field.Relationship.Kind == "belongs_to" {
|
||||
for _, foreignKey := range field.Relationship.ForeignDBNames {
|
||||
if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) {
|
||||
columns = append(columns, scope.Quote(foreignField.DBName))
|
||||
placeholders = append(placeholders, scope.AddToVars(foreignField.Field.Interface()))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
returningColumn = "*"
|
||||
quotedTableName = scope.QuotedTableName()
|
||||
primaryField = scope.PrimaryField()
|
||||
extraOption string
|
||||
)
|
||||
|
||||
if str, ok := scope.Get("gorm:insert_option"); ok {
|
||||
extraOption = fmt.Sprint(str)
|
||||
}
|
||||
|
||||
if primaryField != nil {
|
||||
returningColumn = scope.Quote(primaryField.DBName)
|
||||
}
|
||||
|
||||
lastInsertIDReturningSuffix := scope.Dialect().LastInsertIDReturningSuffix(quotedTableName, returningColumn)
|
||||
|
||||
if len(columns) == 0 {
|
||||
scope.Raw(fmt.Sprintf(
|
||||
"INSERT INTO %v DEFAULT VALUES%v%v",
|
||||
quotedTableName,
|
||||
addExtraSpaceIfExist(extraOption),
|
||||
addExtraSpaceIfExist(lastInsertIDReturningSuffix),
|
||||
))
|
||||
} else {
|
||||
scope.Raw(fmt.Sprintf(
|
||||
"INSERT INTO %v (%v) VALUES (%v)%v%v",
|
||||
scope.QuotedTableName(),
|
||||
strings.Join(columns, ","),
|
||||
strings.Join(placeholders, ","),
|
||||
addExtraSpaceIfExist(extraOption),
|
||||
addExtraSpaceIfExist(lastInsertIDReturningSuffix),
|
||||
))
|
||||
}
|
||||
|
||||
// execute create sql
|
||||
if lastInsertIDReturningSuffix == "" || primaryField == nil {
|
||||
if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
|
||||
// set rows affected count
|
||||
scope.db.RowsAffected, _ = result.RowsAffected()
|
||||
|
||||
// set primary value to primary field
|
||||
if primaryField != nil && primaryField.IsBlank {
|
||||
if primaryValue, err := result.LastInsertId(); scope.Err(err) == nil {
|
||||
scope.Err(primaryField.Set(primaryValue))
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if primaryField.Field.CanAddr() {
|
||||
if err := scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil {
|
||||
primaryField.IsBlank = false
|
||||
scope.db.RowsAffected = 1
|
||||
}
|
||||
} else {
|
||||
scope.Err(ErrUnaddressable)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// forceReloadAfterCreateCallback will reload columns that having default value, and set it back to current object
|
||||
func forceReloadAfterCreateCallback(scope *Scope) {
|
||||
if blankColumnsWithDefaultValue, ok := scope.InstanceGet("gorm:blank_columns_with_default_value"); ok {
|
||||
db := scope.DB().New().Table(scope.TableName()).Select(blankColumnsWithDefaultValue.([]string))
|
||||
for _, field := range scope.Fields() {
|
||||
if field.IsPrimaryKey && !field.IsBlank {
|
||||
db = db.Where(fmt.Sprintf("%v = ?", field.DBName), field.Field.Interface())
|
||||
}
|
||||
}
|
||||
db.Scan(scope.Value)
|
||||
}
|
||||
}
|
||||
|
||||
// afterCreateCallback will invoke `AfterCreate`, `AfterSave` method after creating
|
||||
func afterCreateCallback(scope *Scope) {
|
||||
if !scope.HasError() {
|
||||
scope.CallMethod("AfterCreate")
|
||||
}
|
||||
if !scope.HasError() {
|
||||
scope.CallMethod("AfterSave")
|
||||
}
|
||||
}
|
||||
@ -1,63 +0,0 @@
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// Define callbacks for deleting
|
||||
func init() {
|
||||
DefaultCallback.Delete().Register("gorm:begin_transaction", beginTransactionCallback)
|
||||
DefaultCallback.Delete().Register("gorm:before_delete", beforeDeleteCallback)
|
||||
DefaultCallback.Delete().Register("gorm:delete", deleteCallback)
|
||||
DefaultCallback.Delete().Register("gorm:after_delete", afterDeleteCallback)
|
||||
DefaultCallback.Delete().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback)
|
||||
}
|
||||
|
||||
// beforeDeleteCallback will invoke `BeforeDelete` method before deleting
|
||||
func beforeDeleteCallback(scope *Scope) {
|
||||
if scope.DB().HasBlockGlobalUpdate() && !scope.hasConditions() {
|
||||
scope.Err(errors.New("Missing WHERE clause while deleting"))
|
||||
return
|
||||
}
|
||||
if !scope.HasError() {
|
||||
scope.CallMethod("BeforeDelete")
|
||||
}
|
||||
}
|
||||
|
||||
// deleteCallback used to delete data from database or set deleted_at to current time (when using with soft delete)
|
||||
func deleteCallback(scope *Scope) {
|
||||
if !scope.HasError() {
|
||||
var extraOption string
|
||||
if str, ok := scope.Get("gorm:delete_option"); ok {
|
||||
extraOption = fmt.Sprint(str)
|
||||
}
|
||||
|
||||
deletedAtField, hasDeletedAtField := scope.FieldByName("DeletedAt")
|
||||
|
||||
if !scope.Search.Unscoped && hasDeletedAtField {
|
||||
scope.Raw(fmt.Sprintf(
|
||||
"UPDATE %v SET %v=%v%v%v",
|
||||
scope.QuotedTableName(),
|
||||
scope.Quote(deletedAtField.DBName),
|
||||
scope.AddToVars(NowFunc()),
|
||||
addExtraSpaceIfExist(scope.CombinedConditionSql()),
|
||||
addExtraSpaceIfExist(extraOption),
|
||||
)).Exec()
|
||||
} else {
|
||||
scope.Raw(fmt.Sprintf(
|
||||
"DELETE FROM %v%v%v",
|
||||
scope.QuotedTableName(),
|
||||
addExtraSpaceIfExist(scope.CombinedConditionSql()),
|
||||
addExtraSpaceIfExist(extraOption),
|
||||
)).Exec()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// afterDeleteCallback will invoke `AfterDelete` method after deleting
|
||||
func afterDeleteCallback(scope *Scope) {
|
||||
if !scope.HasError() {
|
||||
scope.CallMethod("AfterDelete")
|
||||
}
|
||||
}
|
||||
@ -1,95 +0,0 @@
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
// Define callbacks for querying
|
||||
func init() {
|
||||
DefaultCallback.Query().Register("gorm:query", queryCallback)
|
||||
DefaultCallback.Query().Register("gorm:preload", preloadCallback)
|
||||
DefaultCallback.Query().Register("gorm:after_query", afterQueryCallback)
|
||||
}
|
||||
|
||||
// queryCallback used to query data from database
|
||||
func queryCallback(scope *Scope) {
|
||||
defer scope.trace(NowFunc())
|
||||
|
||||
var (
|
||||
isSlice, isPtr bool
|
||||
resultType reflect.Type
|
||||
results = scope.IndirectValue()
|
||||
)
|
||||
|
||||
if orderBy, ok := scope.Get("gorm:order_by_primary_key"); ok {
|
||||
if primaryField := scope.PrimaryField(); primaryField != nil {
|
||||
scope.Search.Order(fmt.Sprintf("%v.%v %v", scope.QuotedTableName(), scope.Quote(primaryField.DBName), orderBy))
|
||||
}
|
||||
}
|
||||
|
||||
if value, ok := scope.Get("gorm:query_destination"); ok {
|
||||
results = indirect(reflect.ValueOf(value))
|
||||
}
|
||||
|
||||
if kind := results.Kind(); kind == reflect.Slice {
|
||||
isSlice = true
|
||||
resultType = results.Type().Elem()
|
||||
results.Set(reflect.MakeSlice(results.Type(), 0, 0))
|
||||
|
||||
if resultType.Kind() == reflect.Ptr {
|
||||
isPtr = true
|
||||
resultType = resultType.Elem()
|
||||
}
|
||||
} else if kind != reflect.Struct {
|
||||
scope.Err(errors.New("unsupported destination, should be slice or struct"))
|
||||
return
|
||||
}
|
||||
|
||||
scope.prepareQuerySQL()
|
||||
|
||||
if !scope.HasError() {
|
||||
scope.db.RowsAffected = 0
|
||||
if str, ok := scope.Get("gorm:query_option"); ok {
|
||||
scope.SQL += addExtraSpaceIfExist(fmt.Sprint(str))
|
||||
}
|
||||
|
||||
if rows, err := scope.SQLDB().Query(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
|
||||
defer rows.Close()
|
||||
|
||||
columns, _ := rows.Columns()
|
||||
for rows.Next() {
|
||||
scope.db.RowsAffected++
|
||||
|
||||
elem := results
|
||||
if isSlice {
|
||||
elem = reflect.New(resultType).Elem()
|
||||
}
|
||||
|
||||
scope.scan(rows, columns, scope.New(elem.Addr().Interface()).Fields())
|
||||
|
||||
if isSlice {
|
||||
if isPtr {
|
||||
results.Set(reflect.Append(results, elem.Addr()))
|
||||
} else {
|
||||
results.Set(reflect.Append(results, elem))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
scope.Err(err)
|
||||
} else if scope.db.RowsAffected == 0 && !isSlice {
|
||||
scope.Err(ErrRecordNotFound)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// afterQueryCallback will invoke `AfterFind` method after querying
|
||||
func afterQueryCallback(scope *Scope) {
|
||||
if !scope.HasError() {
|
||||
scope.CallMethod("AfterFind")
|
||||
}
|
||||
}
|
||||
@ -1,380 +0,0 @@
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// preloadCallback used to preload associations
|
||||
func preloadCallback(scope *Scope) {
|
||||
|
||||
if _, ok := scope.Get("gorm:auto_preload"); ok {
|
||||
autoPreload(scope)
|
||||
}
|
||||
|
||||
if scope.Search.preload == nil || scope.HasError() {
|
||||
return
|
||||
}
|
||||
|
||||
var (
|
||||
preloadedMap = map[string]bool{}
|
||||
fields = scope.Fields()
|
||||
)
|
||||
|
||||
for _, preload := range scope.Search.preload {
|
||||
var (
|
||||
preloadFields = strings.Split(preload.schema, ".")
|
||||
currentScope = scope
|
||||
currentFields = fields
|
||||
)
|
||||
|
||||
for idx, preloadField := range preloadFields {
|
||||
var currentPreloadConditions []interface{}
|
||||
|
||||
if currentScope == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// if not preloaded
|
||||
if preloadKey := strings.Join(preloadFields[:idx+1], "."); !preloadedMap[preloadKey] {
|
||||
|
||||
// assign search conditions to last preload
|
||||
if idx == len(preloadFields)-1 {
|
||||
currentPreloadConditions = preload.conditions
|
||||
}
|
||||
|
||||
for _, field := range currentFields {
|
||||
if field.Name != preloadField || field.Relationship == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
switch field.Relationship.Kind {
|
||||
case "has_one":
|
||||
currentScope.handleHasOnePreload(field, currentPreloadConditions)
|
||||
case "has_many":
|
||||
currentScope.handleHasManyPreload(field, currentPreloadConditions)
|
||||
case "belongs_to":
|
||||
currentScope.handleBelongsToPreload(field, currentPreloadConditions)
|
||||
case "many_to_many":
|
||||
currentScope.handleManyToManyPreload(field, currentPreloadConditions)
|
||||
default:
|
||||
scope.Err(errors.New("unsupported relation"))
|
||||
}
|
||||
|
||||
preloadedMap[preloadKey] = true
|
||||
break
|
||||
}
|
||||
|
||||
if !preloadedMap[preloadKey] {
|
||||
scope.Err(fmt.Errorf("can't preload field %s for %s", preloadField, currentScope.GetModelStruct().ModelType))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// preload next level
|
||||
if idx < len(preloadFields)-1 {
|
||||
currentScope = currentScope.getColumnAsScope(preloadField)
|
||||
if currentScope != nil {
|
||||
currentFields = currentScope.Fields()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func autoPreload(scope *Scope) {
|
||||
for _, field := range scope.Fields() {
|
||||
if field.Relationship == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if val, ok := field.TagSettings["PRELOAD"]; ok {
|
||||
if preload, err := strconv.ParseBool(val); err != nil {
|
||||
scope.Err(errors.New("invalid preload option"))
|
||||
return
|
||||
} else if !preload {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
scope.Search.Preload(field.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func (scope *Scope) generatePreloadDBWithConditions(conditions []interface{}) (*DB, []interface{}) {
|
||||
var (
|
||||
preloadDB = scope.NewDB()
|
||||
preloadConditions []interface{}
|
||||
)
|
||||
|
||||
for _, condition := range conditions {
|
||||
if scopes, ok := condition.(func(*DB) *DB); ok {
|
||||
preloadDB = scopes(preloadDB)
|
||||
} else {
|
||||
preloadConditions = append(preloadConditions, condition)
|
||||
}
|
||||
}
|
||||
|
||||
return preloadDB, preloadConditions
|
||||
}
|
||||
|
||||
// handleHasOnePreload used to preload has one associations
|
||||
func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) {
|
||||
relation := field.Relationship
|
||||
|
||||
// get relations's primary keys
|
||||
primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value)
|
||||
if len(primaryKeys) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// preload conditions
|
||||
preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
|
||||
|
||||
// find relations
|
||||
query := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys))
|
||||
values := toQueryValues(primaryKeys)
|
||||
if relation.PolymorphicType != "" {
|
||||
query += fmt.Sprintf(" AND %v = ?", scope.Quote(relation.PolymorphicDBName))
|
||||
values = append(values, relation.PolymorphicValue)
|
||||
}
|
||||
|
||||
results := makeSlice(field.Struct.Type)
|
||||
scope.Err(preloadDB.Where(query, values...).Find(results, preloadConditions...).Error)
|
||||
|
||||
// assign find results
|
||||
var (
|
||||
resultsValue = indirect(reflect.ValueOf(results))
|
||||
indirectScopeValue = scope.IndirectValue()
|
||||
)
|
||||
|
||||
if indirectScopeValue.Kind() == reflect.Slice {
|
||||
for j := 0; j < indirectScopeValue.Len(); j++ {
|
||||
for i := 0; i < resultsValue.Len(); i++ {
|
||||
result := resultsValue.Index(i)
|
||||
foreignValues := getValueFromFields(result, relation.ForeignFieldNames)
|
||||
if indirectValue := indirect(indirectScopeValue.Index(j)); equalAsString(getValueFromFields(indirectValue, relation.AssociationForeignFieldNames), foreignValues) {
|
||||
indirectValue.FieldByName(field.Name).Set(result)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for i := 0; i < resultsValue.Len(); i++ {
|
||||
result := resultsValue.Index(i)
|
||||
scope.Err(field.Set(result))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleHasManyPreload used to preload has many associations
|
||||
func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) {
|
||||
relation := field.Relationship
|
||||
|
||||
// get relations's primary keys
|
||||
primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value)
|
||||
if len(primaryKeys) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// preload conditions
|
||||
preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
|
||||
|
||||
// find relations
|
||||
query := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys))
|
||||
values := toQueryValues(primaryKeys)
|
||||
if relation.PolymorphicType != "" {
|
||||
query += fmt.Sprintf(" AND %v = ?", scope.Quote(relation.PolymorphicDBName))
|
||||
values = append(values, relation.PolymorphicValue)
|
||||
}
|
||||
|
||||
results := makeSlice(field.Struct.Type)
|
||||
scope.Err(preloadDB.Where(query, values...).Find(results, preloadConditions...).Error)
|
||||
|
||||
// assign find results
|
||||
var (
|
||||
resultsValue = indirect(reflect.ValueOf(results))
|
||||
indirectScopeValue = scope.IndirectValue()
|
||||
)
|
||||
|
||||
if indirectScopeValue.Kind() == reflect.Slice {
|
||||
preloadMap := make(map[string][]reflect.Value)
|
||||
for i := 0; i < resultsValue.Len(); i++ {
|
||||
result := resultsValue.Index(i)
|
||||
foreignValues := getValueFromFields(result, relation.ForeignFieldNames)
|
||||
preloadMap[toString(foreignValues)] = append(preloadMap[toString(foreignValues)], result)
|
||||
}
|
||||
|
||||
for j := 0; j < indirectScopeValue.Len(); j++ {
|
||||
object := indirect(indirectScopeValue.Index(j))
|
||||
objectRealValue := getValueFromFields(object, relation.AssociationForeignFieldNames)
|
||||
f := object.FieldByName(field.Name)
|
||||
if results, ok := preloadMap[toString(objectRealValue)]; ok {
|
||||
f.Set(reflect.Append(f, results...))
|
||||
} else {
|
||||
f.Set(reflect.MakeSlice(f.Type(), 0, 0))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
scope.Err(field.Set(resultsValue))
|
||||
}
|
||||
}
|
||||
|
||||
// handleBelongsToPreload used to preload belongs to associations
|
||||
func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{}) {
|
||||
relation := field.Relationship
|
||||
|
||||
// preload conditions
|
||||
preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
|
||||
|
||||
// get relations's primary keys
|
||||
primaryKeys := scope.getColumnAsArray(relation.ForeignFieldNames, scope.Value)
|
||||
if len(primaryKeys) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// find relations
|
||||
results := makeSlice(field.Struct.Type)
|
||||
scope.Err(preloadDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.AssociationForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, preloadConditions...).Error)
|
||||
|
||||
// assign find results
|
||||
var (
|
||||
resultsValue = indirect(reflect.ValueOf(results))
|
||||
indirectScopeValue = scope.IndirectValue()
|
||||
)
|
||||
|
||||
for i := 0; i < resultsValue.Len(); i++ {
|
||||
result := resultsValue.Index(i)
|
||||
if indirectScopeValue.Kind() == reflect.Slice {
|
||||
value := getValueFromFields(result, relation.AssociationForeignFieldNames)
|
||||
for j := 0; j < indirectScopeValue.Len(); j++ {
|
||||
object := indirect(indirectScopeValue.Index(j))
|
||||
if equalAsString(getValueFromFields(object, relation.ForeignFieldNames), value) {
|
||||
object.FieldByName(field.Name).Set(result)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
scope.Err(field.Set(result))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleManyToManyPreload used to preload many to many associations
|
||||
func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface{}) {
|
||||
var (
|
||||
relation = field.Relationship
|
||||
joinTableHandler = relation.JoinTableHandler
|
||||
fieldType = field.Struct.Type.Elem()
|
||||
foreignKeyValue interface{}
|
||||
foreignKeyType = reflect.ValueOf(&foreignKeyValue).Type()
|
||||
linkHash = map[string][]reflect.Value{}
|
||||
isPtr bool
|
||||
)
|
||||
|
||||
if fieldType.Kind() == reflect.Ptr {
|
||||
isPtr = true
|
||||
fieldType = fieldType.Elem()
|
||||
}
|
||||
|
||||
var sourceKeys = []string{}
|
||||
for _, key := range joinTableHandler.SourceForeignKeys() {
|
||||
sourceKeys = append(sourceKeys, key.DBName)
|
||||
}
|
||||
|
||||
// preload conditions
|
||||
preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
|
||||
|
||||
// generate query with join table
|
||||
newScope := scope.New(reflect.New(fieldType).Interface())
|
||||
preloadDB = preloadDB.Table(newScope.TableName()).Model(newScope.Value)
|
||||
|
||||
if len(preloadDB.search.selects) == 0 {
|
||||
preloadDB = preloadDB.Select("*")
|
||||
}
|
||||
|
||||
preloadDB = joinTableHandler.JoinWith(joinTableHandler, preloadDB, scope.Value)
|
||||
|
||||
// preload inline conditions
|
||||
if len(preloadConditions) > 0 {
|
||||
preloadDB = preloadDB.Where(preloadConditions[0], preloadConditions[1:]...)
|
||||
}
|
||||
|
||||
rows, err := preloadDB.Rows()
|
||||
|
||||
if scope.Err(err) != nil {
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
columns, _ := rows.Columns()
|
||||
for rows.Next() {
|
||||
var (
|
||||
elem = reflect.New(fieldType).Elem()
|
||||
fields = scope.New(elem.Addr().Interface()).Fields()
|
||||
)
|
||||
|
||||
// register foreign keys in join tables
|
||||
var joinTableFields []*Field
|
||||
for _, sourceKey := range sourceKeys {
|
||||
joinTableFields = append(joinTableFields, &Field{StructField: &StructField{DBName: sourceKey, IsNormal: true}, Field: reflect.New(foreignKeyType).Elem()})
|
||||
}
|
||||
|
||||
scope.scan(rows, columns, append(fields, joinTableFields...))
|
||||
|
||||
var foreignKeys = make([]interface{}, len(sourceKeys))
|
||||
// generate hashed forkey keys in join table
|
||||
for idx, joinTableField := range joinTableFields {
|
||||
if !joinTableField.Field.IsNil() {
|
||||
foreignKeys[idx] = joinTableField.Field.Elem().Interface()
|
||||
}
|
||||
}
|
||||
hashedSourceKeys := toString(foreignKeys)
|
||||
|
||||
if isPtr {
|
||||
linkHash[hashedSourceKeys] = append(linkHash[hashedSourceKeys], elem.Addr())
|
||||
} else {
|
||||
linkHash[hashedSourceKeys] = append(linkHash[hashedSourceKeys], elem)
|
||||
}
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
scope.Err(err)
|
||||
}
|
||||
|
||||
// assign find results
|
||||
var (
|
||||
indirectScopeValue = scope.IndirectValue()
|
||||
fieldsSourceMap = map[string][]reflect.Value{}
|
||||
foreignFieldNames = []string{}
|
||||
)
|
||||
|
||||
for _, dbName := range relation.ForeignFieldNames {
|
||||
if field, ok := scope.FieldByName(dbName); ok {
|
||||
foreignFieldNames = append(foreignFieldNames, field.Name)
|
||||
}
|
||||
}
|
||||
|
||||
if indirectScopeValue.Kind() == reflect.Slice {
|
||||
for j := 0; j < indirectScopeValue.Len(); j++ {
|
||||
object := indirect(indirectScopeValue.Index(j))
|
||||
key := toString(getValueFromFields(object, foreignFieldNames))
|
||||
fieldsSourceMap[key] = append(fieldsSourceMap[key], object.FieldByName(field.Name))
|
||||
}
|
||||
} else if indirectScopeValue.IsValid() {
|
||||
key := toString(getValueFromFields(indirectScopeValue, foreignFieldNames))
|
||||
fieldsSourceMap[key] = append(fieldsSourceMap[key], indirectScopeValue.FieldByName(field.Name))
|
||||
}
|
||||
for source, link := range linkHash {
|
||||
for i, field := range fieldsSourceMap[source] {
|
||||
//If not 0 this means Value is a pointer and we already added preloaded models to it
|
||||
if fieldsSourceMap[source][i].Len() != 0 {
|
||||
continue
|
||||
}
|
||||
field.Set(reflect.Append(fieldsSourceMap[source][i], link...))
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
@ -1,30 +0,0 @@
|
||||
package gorm
|
||||
|
||||
import "database/sql"
|
||||
|
||||
// Define callbacks for row query
|
||||
func init() {
|
||||
DefaultCallback.RowQuery().Register("gorm:row_query", rowQueryCallback)
|
||||
}
|
||||
|
||||
type RowQueryResult struct {
|
||||
Row *sql.Row
|
||||
}
|
||||
|
||||
type RowsQueryResult struct {
|
||||
Rows *sql.Rows
|
||||
Error error
|
||||
}
|
||||
|
||||
// queryCallback used to query data from database
|
||||
func rowQueryCallback(scope *Scope) {
|
||||
if result, ok := scope.InstanceGet("row_query_result"); ok {
|
||||
scope.prepareQuerySQL()
|
||||
|
||||
if rowResult, ok := result.(*RowQueryResult); ok {
|
||||
rowResult.Row = scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...)
|
||||
} else if rowsResult, ok := result.(*RowsQueryResult); ok {
|
||||
rowsResult.Rows, rowsResult.Error = scope.SQLDB().Query(scope.SQL, scope.SQLVars...)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1,99 +0,0 @@
|
||||
package gorm
|
||||
|
||||
import "reflect"
|
||||
|
||||
func beginTransactionCallback(scope *Scope) {
|
||||
scope.Begin()
|
||||
}
|
||||
|
||||
func commitOrRollbackTransactionCallback(scope *Scope) {
|
||||
scope.CommitOrRollback()
|
||||
}
|
||||
|
||||
func saveFieldAsAssociation(scope *Scope, field *Field) (bool, *Relationship) {
|
||||
if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored {
|
||||
if value, ok := field.TagSettings["SAVE_ASSOCIATIONS"]; !ok || (value != "false" && value != "skip") {
|
||||
if relationship := field.Relationship; relationship != nil {
|
||||
return true, relationship
|
||||
}
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func saveBeforeAssociationsCallback(scope *Scope) {
|
||||
if !scope.shouldSaveAssociations() {
|
||||
return
|
||||
}
|
||||
for _, field := range scope.Fields() {
|
||||
if ok, relationship := saveFieldAsAssociation(scope, field); ok && relationship.Kind == "belongs_to" {
|
||||
fieldValue := field.Field.Addr().Interface()
|
||||
scope.Err(scope.NewDB().Save(fieldValue).Error)
|
||||
if len(relationship.ForeignFieldNames) != 0 {
|
||||
// set value's foreign key
|
||||
for idx, fieldName := range relationship.ForeignFieldNames {
|
||||
associationForeignName := relationship.AssociationForeignDBNames[idx]
|
||||
if foreignField, ok := scope.New(fieldValue).FieldByName(associationForeignName); ok {
|
||||
scope.Err(scope.SetColumn(fieldName, foreignField.Field.Interface()))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func saveAfterAssociationsCallback(scope *Scope) {
|
||||
if !scope.shouldSaveAssociations() {
|
||||
return
|
||||
}
|
||||
for _, field := range scope.Fields() {
|
||||
if ok, relationship := saveFieldAsAssociation(scope, field); ok &&
|
||||
(relationship.Kind == "has_one" || relationship.Kind == "has_many" || relationship.Kind == "many_to_many") {
|
||||
value := field.Field
|
||||
|
||||
switch value.Kind() {
|
||||
case reflect.Slice:
|
||||
for i := 0; i < value.Len(); i++ {
|
||||
newDB := scope.NewDB()
|
||||
elem := value.Index(i).Addr().Interface()
|
||||
newScope := newDB.NewScope(elem)
|
||||
|
||||
if relationship.JoinTableHandler == nil && len(relationship.ForeignFieldNames) != 0 {
|
||||
for idx, fieldName := range relationship.ForeignFieldNames {
|
||||
associationForeignName := relationship.AssociationForeignDBNames[idx]
|
||||
if f, ok := scope.FieldByName(associationForeignName); ok {
|
||||
scope.Err(newScope.SetColumn(fieldName, f.Field.Interface()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if relationship.PolymorphicType != "" {
|
||||
scope.Err(newScope.SetColumn(relationship.PolymorphicType, relationship.PolymorphicValue))
|
||||
}
|
||||
|
||||
scope.Err(newDB.Save(elem).Error)
|
||||
|
||||
if joinTableHandler := relationship.JoinTableHandler; joinTableHandler != nil {
|
||||
scope.Err(joinTableHandler.Add(joinTableHandler, newDB, scope.Value, newScope.Value))
|
||||
}
|
||||
}
|
||||
default:
|
||||
elem := value.Addr().Interface()
|
||||
newScope := scope.New(elem)
|
||||
if len(relationship.ForeignFieldNames) != 0 {
|
||||
for idx, fieldName := range relationship.ForeignFieldNames {
|
||||
associationForeignName := relationship.AssociationForeignDBNames[idx]
|
||||
if f, ok := scope.FieldByName(associationForeignName); ok {
|
||||
scope.Err(newScope.SetColumn(fieldName, f.Field.Interface()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if relationship.PolymorphicType != "" {
|
||||
scope.Err(newScope.SetColumn(relationship.PolymorphicType, relationship.PolymorphicValue))
|
||||
}
|
||||
scope.Err(scope.NewDB().Save(elem).Error)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1,112 +0,0 @@
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func equalFuncs(funcs []*func(s *Scope), fnames []string) bool {
|
||||
var names []string
|
||||
for _, f := range funcs {
|
||||
fnames := strings.Split(runtime.FuncForPC(reflect.ValueOf(*f).Pointer()).Name(), ".")
|
||||
names = append(names, fnames[len(fnames)-1])
|
||||
}
|
||||
return reflect.DeepEqual(names, fnames)
|
||||
}
|
||||
|
||||
func create(s *Scope) {}
|
||||
func beforeCreate1(s *Scope) {}
|
||||
func beforeCreate2(s *Scope) {}
|
||||
func afterCreate1(s *Scope) {}
|
||||
func afterCreate2(s *Scope) {}
|
||||
|
||||
func TestRegisterCallback(t *testing.T) {
|
||||
var callback = &Callback{}
|
||||
|
||||
callback.Create().Register("before_create1", beforeCreate1)
|
||||
callback.Create().Register("before_create2", beforeCreate2)
|
||||
callback.Create().Register("create", create)
|
||||
callback.Create().Register("after_create1", afterCreate1)
|
||||
callback.Create().Register("after_create2", afterCreate2)
|
||||
|
||||
if !equalFuncs(callback.creates, []string{"beforeCreate1", "beforeCreate2", "create", "afterCreate1", "afterCreate2"}) {
|
||||
t.Errorf("register callback")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegisterCallbackWithOrder(t *testing.T) {
|
||||
var callback1 = &Callback{}
|
||||
callback1.Create().Register("before_create1", beforeCreate1)
|
||||
callback1.Create().Register("create", create)
|
||||
callback1.Create().Register("after_create1", afterCreate1)
|
||||
callback1.Create().Before("after_create1").Register("after_create2", afterCreate2)
|
||||
if !equalFuncs(callback1.creates, []string{"beforeCreate1", "create", "afterCreate2", "afterCreate1"}) {
|
||||
t.Errorf("register callback with order")
|
||||
}
|
||||
|
||||
var callback2 = &Callback{}
|
||||
|
||||
callback2.Update().Register("create", create)
|
||||
callback2.Update().Before("create").Register("before_create1", beforeCreate1)
|
||||
callback2.Update().After("after_create2").Register("after_create1", afterCreate1)
|
||||
callback2.Update().Before("before_create1").Register("before_create2", beforeCreate2)
|
||||
callback2.Update().Register("after_create2", afterCreate2)
|
||||
|
||||
if !equalFuncs(callback2.updates, []string{"beforeCreate2", "beforeCreate1", "create", "afterCreate2", "afterCreate1"}) {
|
||||
t.Errorf("register callback with order")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegisterCallbackWithComplexOrder(t *testing.T) {
|
||||
var callback1 = &Callback{}
|
||||
|
||||
callback1.Query().Before("after_create1").After("before_create1").Register("create", create)
|
||||
callback1.Query().Register("before_create1", beforeCreate1)
|
||||
callback1.Query().Register("after_create1", afterCreate1)
|
||||
|
||||
if !equalFuncs(callback1.queries, []string{"beforeCreate1", "create", "afterCreate1"}) {
|
||||
t.Errorf("register callback with order")
|
||||
}
|
||||
|
||||
var callback2 = &Callback{}
|
||||
|
||||
callback2.Delete().Before("after_create1").After("before_create1").Register("create", create)
|
||||
callback2.Delete().Before("create").Register("before_create1", beforeCreate1)
|
||||
callback2.Delete().After("before_create1").Register("before_create2", beforeCreate2)
|
||||
callback2.Delete().Register("after_create1", afterCreate1)
|
||||
callback2.Delete().After("after_create1").Register("after_create2", afterCreate2)
|
||||
|
||||
if !equalFuncs(callback2.deletes, []string{"beforeCreate1", "beforeCreate2", "create", "afterCreate1", "afterCreate2"}) {
|
||||
t.Errorf("register callback with order")
|
||||
}
|
||||
}
|
||||
|
||||
func replaceCreate(s *Scope) {}
|
||||
|
||||
func TestReplaceCallback(t *testing.T) {
|
||||
var callback = &Callback{}
|
||||
|
||||
callback.Create().Before("after_create1").After("before_create1").Register("create", create)
|
||||
callback.Create().Register("before_create1", beforeCreate1)
|
||||
callback.Create().Register("after_create1", afterCreate1)
|
||||
callback.Create().Replace("create", replaceCreate)
|
||||
|
||||
if !equalFuncs(callback.creates, []string{"beforeCreate1", "replaceCreate", "afterCreate1"}) {
|
||||
t.Errorf("replace callback")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoveCallback(t *testing.T) {
|
||||
var callback = &Callback{}
|
||||
|
||||
callback.Create().Before("after_create1").After("before_create1").Register("create", create)
|
||||
callback.Create().Register("before_create1", beforeCreate1)
|
||||
callback.Create().Register("after_create1", afterCreate1)
|
||||
callback.Create().Remove("create")
|
||||
|
||||
if !equalFuncs(callback.creates, []string{"beforeCreate1", "afterCreate1"}) {
|
||||
t.Errorf("remove callback")
|
||||
}
|
||||
}
|
||||
@ -1,117 +0,0 @@
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Define callbacks for updating
|
||||
func init() {
|
||||
DefaultCallback.Update().Register("gorm:assign_updating_attributes", assignUpdatingAttributesCallback)
|
||||
DefaultCallback.Update().Register("gorm:begin_transaction", beginTransactionCallback)
|
||||
DefaultCallback.Update().Register("gorm:before_update", beforeUpdateCallback)
|
||||
DefaultCallback.Update().Register("gorm:save_before_associations", saveBeforeAssociationsCallback)
|
||||
DefaultCallback.Update().Register("gorm:update_time_stamp", updateTimeStampForUpdateCallback)
|
||||
DefaultCallback.Update().Register("gorm:update", updateCallback)
|
||||
DefaultCallback.Update().Register("gorm:save_after_associations", saveAfterAssociationsCallback)
|
||||
DefaultCallback.Update().Register("gorm:after_update", afterUpdateCallback)
|
||||
DefaultCallback.Update().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback)
|
||||
}
|
||||
|
||||
// assignUpdatingAttributesCallback assign updating attributes to model
|
||||
func assignUpdatingAttributesCallback(scope *Scope) {
|
||||
if attrs, ok := scope.InstanceGet("gorm:update_interface"); ok {
|
||||
if updateMaps, hasUpdate := scope.updatedAttrsWithValues(attrs); hasUpdate {
|
||||
scope.InstanceSet("gorm:update_attrs", updateMaps)
|
||||
} else {
|
||||
scope.SkipLeft()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// beforeUpdateCallback will invoke `BeforeSave`, `BeforeUpdate` method before updating
|
||||
func beforeUpdateCallback(scope *Scope) {
|
||||
if scope.DB().HasBlockGlobalUpdate() && !scope.hasConditions() {
|
||||
scope.Err(errors.New("Missing WHERE clause while updating"))
|
||||
return
|
||||
}
|
||||
if _, ok := scope.Get("gorm:update_column"); !ok {
|
||||
if !scope.HasError() {
|
||||
scope.CallMethod("BeforeSave")
|
||||
}
|
||||
if !scope.HasError() {
|
||||
scope.CallMethod("BeforeUpdate")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// updateTimeStampForUpdateCallback will set `UpdatedAt` when updating
|
||||
func updateTimeStampForUpdateCallback(scope *Scope) {
|
||||
if _, ok := scope.Get("gorm:update_column"); !ok {
|
||||
scope.SetColumn("UpdatedAt", NowFunc())
|
||||
}
|
||||
}
|
||||
|
||||
// updateCallback the callback used to update data to database
|
||||
func updateCallback(scope *Scope) {
|
||||
if !scope.HasError() {
|
||||
var sqls []string
|
||||
|
||||
if updateAttrs, ok := scope.InstanceGet("gorm:update_attrs"); ok {
|
||||
for column, value := range updateAttrs.(map[string]interface{}) {
|
||||
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(column), scope.AddToVars(value)))
|
||||
}
|
||||
} else {
|
||||
for _, field := range scope.Fields() {
|
||||
if scope.changeableField(field) {
|
||||
if !field.IsPrimaryKey && field.IsNormal {
|
||||
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
|
||||
} else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
|
||||
for _, foreignKey := range relationship.ForeignDBNames {
|
||||
if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) {
|
||||
sqls = append(sqls,
|
||||
fmt.Sprintf("%v = %v", scope.Quote(foreignField.DBName), scope.AddToVars(foreignField.Field.Interface())))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var extraOption string
|
||||
if str, ok := scope.Get("gorm:update_option"); ok {
|
||||
extraOption = fmt.Sprint(str)
|
||||
}
|
||||
|
||||
if len(sqls) > 0 {
|
||||
joinSQL := scope.joinsSQL()
|
||||
whereSQL := scope.whereSQL()
|
||||
if scope.Search.raw {
|
||||
whereSQL = strings.TrimSuffix(strings.TrimPrefix(whereSQL, "WHERE ("), ")")
|
||||
}
|
||||
combinedSql := whereSQL + scope.groupSQL() +
|
||||
scope.havingSQL() + scope.orderSQL() + scope.limitAndOffsetSQL()
|
||||
scope.Raw(fmt.Sprintf(
|
||||
"UPDATE %v %v SET %v%v%v",
|
||||
scope.QuotedTableName(),
|
||||
addExtraSpaceIfExist(joinSQL),
|
||||
strings.Join(sqls, ", "),
|
||||
addExtraSpaceIfExist(combinedSql),
|
||||
addExtraSpaceIfExist(extraOption),
|
||||
)).Exec()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// afterUpdateCallback will invoke `AfterUpdate`, `AfterSave` method after updating
|
||||
func afterUpdateCallback(scope *Scope) {
|
||||
if _, ok := scope.Get("gorm:update_column"); !ok {
|
||||
if !scope.HasError() {
|
||||
scope.CallMethod("AfterUpdate")
|
||||
}
|
||||
if !scope.HasError() {
|
||||
scope.CallMethod("AfterSave")
|
||||
}
|
||||
}
|
||||
}
|
||||
327
callbacks.go
Normal file
327
callbacks.go
Normal file
@ -0,0 +1,327 @@
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm/schema"
|
||||
"gorm.io/gorm/utils"
|
||||
)
|
||||
|
||||
func initializeCallbacks(db *DB) *callbacks {
|
||||
return &callbacks{
|
||||
processors: map[string]*processor{
|
||||
"create": {db: db},
|
||||
"query": {db: db},
|
||||
"update": {db: db},
|
||||
"delete": {db: db},
|
||||
"row": {db: db},
|
||||
"raw": {db: db},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// callbacks gorm callbacks manager
|
||||
type callbacks struct {
|
||||
processors map[string]*processor
|
||||
}
|
||||
|
||||
type processor struct {
|
||||
db *DB
|
||||
Clauses []string
|
||||
fns []func(*DB)
|
||||
callbacks []*callback
|
||||
}
|
||||
|
||||
type callback struct {
|
||||
name string
|
||||
before string
|
||||
after string
|
||||
remove bool
|
||||
replace bool
|
||||
match func(*DB) bool
|
||||
handler func(*DB)
|
||||
processor *processor
|
||||
}
|
||||
|
||||
func (cs *callbacks) Create() *processor {
|
||||
return cs.processors["create"]
|
||||
}
|
||||
|
||||
func (cs *callbacks) Query() *processor {
|
||||
return cs.processors["query"]
|
||||
}
|
||||
|
||||
func (cs *callbacks) Update() *processor {
|
||||
return cs.processors["update"]
|
||||
}
|
||||
|
||||
func (cs *callbacks) Delete() *processor {
|
||||
return cs.processors["delete"]
|
||||
}
|
||||
|
||||
func (cs *callbacks) Row() *processor {
|
||||
return cs.processors["row"]
|
||||
}
|
||||
|
||||
func (cs *callbacks) Raw() *processor {
|
||||
return cs.processors["raw"]
|
||||
}
|
||||
|
||||
func (p *processor) Execute(db *DB) {
|
||||
// call scopes
|
||||
for len(db.Statement.scopes) > 0 {
|
||||
scopes := db.Statement.scopes
|
||||
db.Statement.scopes = nil
|
||||
for _, scope := range scopes {
|
||||
db = scope(db)
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
curTime = time.Now()
|
||||
stmt = db.Statement
|
||||
resetBuildClauses bool
|
||||
)
|
||||
|
||||
if len(stmt.BuildClauses) == 0 {
|
||||
stmt.BuildClauses = p.Clauses
|
||||
resetBuildClauses = true
|
||||
}
|
||||
|
||||
// assign model values
|
||||
if stmt.Model == nil {
|
||||
stmt.Model = stmt.Dest
|
||||
} else if stmt.Dest == nil {
|
||||
stmt.Dest = stmt.Model
|
||||
}
|
||||
|
||||
// parse model values
|
||||
if stmt.Model != nil {
|
||||
if err := stmt.Parse(stmt.Model); err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || (stmt.Table == "" && stmt.SQL.Len() == 0)) {
|
||||
if errors.Is(err, schema.ErrUnsupportedDataType) && stmt.Table == "" {
|
||||
db.AddError(fmt.Errorf("%w: Table not set, please set it like: db.Model(&user) or db.Table(\"users\")", err))
|
||||
} else {
|
||||
db.AddError(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// assign stmt.ReflectValue
|
||||
if stmt.Dest != nil {
|
||||
stmt.ReflectValue = reflect.ValueOf(stmt.Dest)
|
||||
for stmt.ReflectValue.Kind() == reflect.Ptr {
|
||||
if stmt.ReflectValue.IsNil() && stmt.ReflectValue.CanAddr() {
|
||||
stmt.ReflectValue.Set(reflect.New(stmt.ReflectValue.Type().Elem()))
|
||||
}
|
||||
|
||||
stmt.ReflectValue = stmt.ReflectValue.Elem()
|
||||
}
|
||||
if !stmt.ReflectValue.IsValid() {
|
||||
db.AddError(ErrInvalidValue)
|
||||
}
|
||||
}
|
||||
|
||||
for _, f := range p.fns {
|
||||
f(db)
|
||||
}
|
||||
|
||||
db.Logger.Trace(stmt.Context, curTime, func() (string, int64) {
|
||||
return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...), db.RowsAffected
|
||||
}, db.Error)
|
||||
|
||||
if !stmt.DB.DryRun {
|
||||
stmt.SQL.Reset()
|
||||
stmt.Vars = nil
|
||||
}
|
||||
|
||||
if resetBuildClauses {
|
||||
stmt.BuildClauses = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (p *processor) Get(name string) func(*DB) {
|
||||
for i := len(p.callbacks) - 1; i >= 0; i-- {
|
||||
if v := p.callbacks[i]; v.name == name && !v.remove {
|
||||
return v.handler
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *processor) Before(name string) *callback {
|
||||
return &callback{before: name, processor: p}
|
||||
}
|
||||
|
||||
func (p *processor) After(name string) *callback {
|
||||
return &callback{after: name, processor: p}
|
||||
}
|
||||
|
||||
func (p *processor) Match(fc func(*DB) bool) *callback {
|
||||
return &callback{match: fc, processor: p}
|
||||
}
|
||||
|
||||
func (p *processor) Register(name string, fn func(*DB)) error {
|
||||
return (&callback{processor: p}).Register(name, fn)
|
||||
}
|
||||
|
||||
func (p *processor) Remove(name string) error {
|
||||
return (&callback{processor: p}).Remove(name)
|
||||
}
|
||||
|
||||
func (p *processor) Replace(name string, fn func(*DB)) error {
|
||||
return (&callback{processor: p}).Replace(name, fn)
|
||||
}
|
||||
|
||||
func (p *processor) compile() (err error) {
|
||||
var callbacks []*callback
|
||||
for _, callback := range p.callbacks {
|
||||
if callback.match == nil || callback.match(p.db) {
|
||||
callbacks = append(callbacks, callback)
|
||||
}
|
||||
}
|
||||
p.callbacks = callbacks
|
||||
|
||||
if p.fns, err = sortCallbacks(p.callbacks); err != nil {
|
||||
p.db.Logger.Error(context.Background(), "Got error when compile callbacks, got %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *callback) Before(name string) *callback {
|
||||
c.before = name
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *callback) After(name string) *callback {
|
||||
c.after = name
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *callback) Register(name string, fn func(*DB)) error {
|
||||
c.name = name
|
||||
c.handler = fn
|
||||
c.processor.callbacks = append(c.processor.callbacks, c)
|
||||
return c.processor.compile()
|
||||
}
|
||||
|
||||
func (c *callback) Remove(name string) error {
|
||||
c.processor.db.Logger.Warn(context.Background(), "removing callback `%v` from %v\n", name, utils.FileWithLineNum())
|
||||
c.name = name
|
||||
c.remove = true
|
||||
c.processor.callbacks = append(c.processor.callbacks, c)
|
||||
return c.processor.compile()
|
||||
}
|
||||
|
||||
func (c *callback) Replace(name string, fn func(*DB)) error {
|
||||
c.processor.db.Logger.Info(context.Background(), "replacing callback `%v` from %v\n", name, utils.FileWithLineNum())
|
||||
c.name = name
|
||||
c.handler = fn
|
||||
c.replace = true
|
||||
c.processor.callbacks = append(c.processor.callbacks, c)
|
||||
return c.processor.compile()
|
||||
}
|
||||
|
||||
// getRIndex get right index from string slice
|
||||
func getRIndex(strs []string, str string) int {
|
||||
for i := len(strs) - 1; i >= 0; i-- {
|
||||
if strs[i] == str {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {
|
||||
var (
|
||||
names, sorted []string
|
||||
sortCallback func(*callback) error
|
||||
)
|
||||
sort.Slice(cs, func(i, j int) bool {
|
||||
return cs[j].before == "*" || cs[j].after == "*"
|
||||
})
|
||||
|
||||
for _, c := range cs {
|
||||
// show warning message the callback name already exists
|
||||
if idx := getRIndex(names, c.name); idx > -1 && !c.replace && !c.remove && !cs[idx].remove {
|
||||
c.processor.db.Logger.Warn(context.Background(), "duplicated callback `%v` from %v\n", c.name, utils.FileWithLineNum())
|
||||
}
|
||||
names = append(names, c.name)
|
||||
}
|
||||
|
||||
sortCallback = func(c *callback) error {
|
||||
if c.before != "" { // if defined before callback
|
||||
if c.before == "*" && len(sorted) > 0 {
|
||||
if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
|
||||
sorted = append([]string{c.name}, sorted...)
|
||||
}
|
||||
} else if sortedIdx := getRIndex(sorted, c.before); sortedIdx != -1 {
|
||||
if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
|
||||
// if before callback already sorted, append current callback just after it
|
||||
sorted = append(sorted[:sortedIdx], append([]string{c.name}, sorted[sortedIdx:]...)...)
|
||||
} else if curIdx > sortedIdx {
|
||||
return fmt.Errorf("conflicting callback %v with before %v", c.name, c.before)
|
||||
}
|
||||
} else if idx := getRIndex(names, c.before); idx != -1 {
|
||||
// if before callback exists
|
||||
cs[idx].after = c.name
|
||||
}
|
||||
}
|
||||
|
||||
if c.after != "" { // if defined after callback
|
||||
if c.after == "*" && len(sorted) > 0 {
|
||||
if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
|
||||
sorted = append(sorted, c.name)
|
||||
}
|
||||
} else if sortedIdx := getRIndex(sorted, c.after); sortedIdx != -1 {
|
||||
if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
|
||||
// if after callback sorted, append current callback to last
|
||||
sorted = append(sorted, c.name)
|
||||
} else if curIdx < sortedIdx {
|
||||
return fmt.Errorf("conflicting callback %v with before %v", c.name, c.after)
|
||||
}
|
||||
} else if idx := getRIndex(names, c.after); idx != -1 {
|
||||
// if after callback exists but haven't sorted
|
||||
// set after callback's before callback to current callback
|
||||
after := cs[idx]
|
||||
|
||||
if after.before == "" {
|
||||
after.before = c.name
|
||||
}
|
||||
|
||||
if err := sortCallback(after); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := sortCallback(c); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// if current callback haven't been sorted, append it to last
|
||||
if getRIndex(sorted, c.name) == -1 {
|
||||
sorted = append(sorted, c.name)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, c := range cs {
|
||||
if err = sortCallback(c); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
for _, name := range sorted {
|
||||
if idx := getRIndex(names, name); !cs[idx].remove {
|
||||
fns = append(fns, cs[idx].handler)
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
390
callbacks/associations.go
Normal file
390
callbacks/associations.go
Normal file
@ -0,0 +1,390 @@
|
||||
package callbacks
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/schema"
|
||||
)
|
||||
|
||||
func SaveBeforeAssociations(create bool) func(db *gorm.DB) {
|
||||
return func(db *gorm.DB) {
|
||||
if db.Error == nil && db.Statement.Schema != nil {
|
||||
selectColumns, restricted := db.Statement.SelectAndOmitColumns(create, !create)
|
||||
|
||||
// Save Belongs To associations
|
||||
for _, rel := range db.Statement.Schema.Relationships.BelongsTo {
|
||||
if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) {
|
||||
continue
|
||||
}
|
||||
|
||||
setupReferences := func(obj reflect.Value, elem reflect.Value) {
|
||||
for _, ref := range rel.References {
|
||||
if !ref.OwnPrimaryKey {
|
||||
pv, _ := ref.PrimaryKey.ValueOf(elem)
|
||||
db.AddError(ref.ForeignKey.Set(obj, pv))
|
||||
|
||||
if dest, ok := db.Statement.Dest.(map[string]interface{}); ok {
|
||||
dest[ref.ForeignKey.DBName] = pv
|
||||
if _, ok := dest[rel.Name]; ok {
|
||||
dest[rel.Name] = elem.Interface()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
switch db.Statement.ReflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
var (
|
||||
objs = make([]reflect.Value, 0, db.Statement.ReflectValue.Len())
|
||||
fieldType = rel.Field.FieldType
|
||||
isPtr = fieldType.Kind() == reflect.Ptr
|
||||
)
|
||||
|
||||
if !isPtr {
|
||||
fieldType = reflect.PtrTo(fieldType)
|
||||
}
|
||||
|
||||
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
|
||||
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
|
||||
obj := db.Statement.ReflectValue.Index(i)
|
||||
|
||||
if reflect.Indirect(obj).Kind() == reflect.Struct {
|
||||
if _, zero := rel.Field.ValueOf(obj); !zero { // check belongs to relation value
|
||||
rv := rel.Field.ReflectValueOf(obj) // relation reflect value
|
||||
objs = append(objs, obj)
|
||||
if isPtr {
|
||||
elems = reflect.Append(elems, rv)
|
||||
} else {
|
||||
elems = reflect.Append(elems, rv.Addr())
|
||||
}
|
||||
}
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if elems.Len() > 0 {
|
||||
if saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, nil) == nil {
|
||||
for i := 0; i < elems.Len(); i++ {
|
||||
setupReferences(objs[i], elems.Index(i))
|
||||
}
|
||||
}
|
||||
}
|
||||
case reflect.Struct:
|
||||
if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero {
|
||||
rv := rel.Field.ReflectValueOf(db.Statement.ReflectValue) // relation reflect value
|
||||
if rv.Kind() != reflect.Ptr {
|
||||
rv = rv.Addr()
|
||||
}
|
||||
|
||||
if saveAssociations(db, rel, rv.Interface(), selectColumns, restricted, nil) == nil {
|
||||
setupReferences(db.Statement.ReflectValue, rv)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func SaveAfterAssociations(create bool) func(db *gorm.DB) {
|
||||
return func(db *gorm.DB) {
|
||||
if db.Error == nil && db.Statement.Schema != nil {
|
||||
selectColumns, restricted := db.Statement.SelectAndOmitColumns(create, !create)
|
||||
|
||||
// Save Has One associations
|
||||
for _, rel := range db.Statement.Schema.Relationships.HasOne {
|
||||
if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) {
|
||||
continue
|
||||
}
|
||||
|
||||
switch db.Statement.ReflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
var (
|
||||
fieldType = rel.Field.FieldType
|
||||
isPtr = fieldType.Kind() == reflect.Ptr
|
||||
)
|
||||
|
||||
if !isPtr {
|
||||
fieldType = reflect.PtrTo(fieldType)
|
||||
}
|
||||
|
||||
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
|
||||
|
||||
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
|
||||
obj := db.Statement.ReflectValue.Index(i)
|
||||
|
||||
if reflect.Indirect(obj).Kind() == reflect.Struct {
|
||||
if _, zero := rel.Field.ValueOf(obj); !zero {
|
||||
rv := rel.Field.ReflectValueOf(obj)
|
||||
if rv.Kind() != reflect.Ptr {
|
||||
rv = rv.Addr()
|
||||
}
|
||||
|
||||
for _, ref := range rel.References {
|
||||
if ref.OwnPrimaryKey {
|
||||
fv, _ := ref.PrimaryKey.ValueOf(obj)
|
||||
db.AddError(ref.ForeignKey.Set(rv, fv))
|
||||
} else if ref.PrimaryValue != "" {
|
||||
db.AddError(ref.ForeignKey.Set(rv, ref.PrimaryValue))
|
||||
}
|
||||
}
|
||||
|
||||
elems = reflect.Append(elems, rv)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if elems.Len() > 0 {
|
||||
assignmentColumns := make([]string, 0, len(rel.References))
|
||||
for _, ref := range rel.References {
|
||||
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
|
||||
}
|
||||
|
||||
saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, assignmentColumns)
|
||||
}
|
||||
case reflect.Struct:
|
||||
if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero {
|
||||
f := rel.Field.ReflectValueOf(db.Statement.ReflectValue)
|
||||
if f.Kind() != reflect.Ptr {
|
||||
f = f.Addr()
|
||||
}
|
||||
|
||||
assignmentColumns := make([]string, 0, len(rel.References))
|
||||
for _, ref := range rel.References {
|
||||
if ref.OwnPrimaryKey {
|
||||
fv, _ := ref.PrimaryKey.ValueOf(db.Statement.ReflectValue)
|
||||
ref.ForeignKey.Set(f, fv)
|
||||
} else if ref.PrimaryValue != "" {
|
||||
ref.ForeignKey.Set(f, ref.PrimaryValue)
|
||||
}
|
||||
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
|
||||
}
|
||||
|
||||
saveAssociations(db, rel, f.Interface(), selectColumns, restricted, assignmentColumns)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Save Has Many associations
|
||||
for _, rel := range db.Statement.Schema.Relationships.HasMany {
|
||||
if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) {
|
||||
continue
|
||||
}
|
||||
|
||||
fieldType := rel.Field.IndirectFieldType.Elem()
|
||||
isPtr := fieldType.Kind() == reflect.Ptr
|
||||
if !isPtr {
|
||||
fieldType = reflect.PtrTo(fieldType)
|
||||
}
|
||||
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
|
||||
appendToElems := func(v reflect.Value) {
|
||||
if _, zero := rel.Field.ValueOf(v); !zero {
|
||||
f := reflect.Indirect(rel.Field.ReflectValueOf(v))
|
||||
|
||||
for i := 0; i < f.Len(); i++ {
|
||||
elem := f.Index(i)
|
||||
for _, ref := range rel.References {
|
||||
if ref.OwnPrimaryKey {
|
||||
pv, _ := ref.PrimaryKey.ValueOf(v)
|
||||
ref.ForeignKey.Set(elem, pv)
|
||||
} else if ref.PrimaryValue != "" {
|
||||
ref.ForeignKey.Set(elem, ref.PrimaryValue)
|
||||
}
|
||||
}
|
||||
|
||||
if isPtr {
|
||||
elems = reflect.Append(elems, elem)
|
||||
} else {
|
||||
elems = reflect.Append(elems, elem.Addr())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
switch db.Statement.ReflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
|
||||
obj := db.Statement.ReflectValue.Index(i)
|
||||
if reflect.Indirect(obj).Kind() == reflect.Struct {
|
||||
appendToElems(obj)
|
||||
}
|
||||
}
|
||||
case reflect.Struct:
|
||||
appendToElems(db.Statement.ReflectValue)
|
||||
}
|
||||
|
||||
if elems.Len() > 0 {
|
||||
assignmentColumns := make([]string, 0, len(rel.References))
|
||||
for _, ref := range rel.References {
|
||||
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
|
||||
}
|
||||
|
||||
saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, assignmentColumns)
|
||||
}
|
||||
}
|
||||
|
||||
// Save Many2Many associations
|
||||
for _, rel := range db.Statement.Schema.Relationships.Many2Many {
|
||||
if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) {
|
||||
continue
|
||||
}
|
||||
|
||||
fieldType := rel.Field.IndirectFieldType.Elem()
|
||||
isPtr := fieldType.Kind() == reflect.Ptr
|
||||
if !isPtr {
|
||||
fieldType = reflect.PtrTo(fieldType)
|
||||
}
|
||||
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
|
||||
joins := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(rel.JoinTable.ModelType)), 0, 10)
|
||||
objs := []reflect.Value{}
|
||||
|
||||
appendToJoins := func(obj reflect.Value, elem reflect.Value) {
|
||||
joinValue := reflect.New(rel.JoinTable.ModelType)
|
||||
for _, ref := range rel.References {
|
||||
if ref.OwnPrimaryKey {
|
||||
fv, _ := ref.PrimaryKey.ValueOf(obj)
|
||||
ref.ForeignKey.Set(joinValue, fv)
|
||||
} else if ref.PrimaryValue != "" {
|
||||
ref.ForeignKey.Set(joinValue, ref.PrimaryValue)
|
||||
} else {
|
||||
fv, _ := ref.PrimaryKey.ValueOf(elem)
|
||||
ref.ForeignKey.Set(joinValue, fv)
|
||||
}
|
||||
}
|
||||
joins = reflect.Append(joins, joinValue)
|
||||
}
|
||||
|
||||
appendToElems := func(v reflect.Value) {
|
||||
if _, zero := rel.Field.ValueOf(v); !zero {
|
||||
f := reflect.Indirect(rel.Field.ReflectValueOf(v))
|
||||
|
||||
for i := 0; i < f.Len(); i++ {
|
||||
elem := f.Index(i)
|
||||
|
||||
objs = append(objs, v)
|
||||
if isPtr {
|
||||
elems = reflect.Append(elems, elem)
|
||||
} else {
|
||||
elems = reflect.Append(elems, elem.Addr())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
switch db.Statement.ReflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
|
||||
obj := db.Statement.ReflectValue.Index(i)
|
||||
if reflect.Indirect(obj).Kind() == reflect.Struct {
|
||||
appendToElems(obj)
|
||||
}
|
||||
}
|
||||
case reflect.Struct:
|
||||
appendToElems(db.Statement.ReflectValue)
|
||||
}
|
||||
|
||||
// optimize elems of reflect value length
|
||||
if elemLen := elems.Len(); elemLen > 0 {
|
||||
if v, ok := selectColumns[rel.Name+".*"]; !ok || v {
|
||||
saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, nil)
|
||||
}
|
||||
|
||||
for i := 0; i < elemLen; i++ {
|
||||
appendToJoins(objs[i], elems.Index(i))
|
||||
}
|
||||
}
|
||||
|
||||
if joins.Len() > 0 {
|
||||
db.AddError(db.Session(&gorm.Session{NewDB: true}).Clauses(clause.OnConflict{DoNothing: true}).Session(&gorm.Session{
|
||||
SkipHooks: db.Statement.SkipHooks,
|
||||
DisableNestedTransaction: true,
|
||||
}).Create(joins.Interface()).Error)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func onConflictOption(stmt *gorm.Statement, s *schema.Schema, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) clause.OnConflict {
|
||||
if stmt.DB.FullSaveAssociations {
|
||||
defaultUpdatingColumns = make([]string, 0, len(s.DBNames))
|
||||
for _, dbName := range s.DBNames {
|
||||
if v, ok := selectColumns[dbName]; (ok && !v) || (!ok && restricted) {
|
||||
continue
|
||||
}
|
||||
|
||||
if !s.LookUpField(dbName).PrimaryKey {
|
||||
defaultUpdatingColumns = append(defaultUpdatingColumns, dbName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(defaultUpdatingColumns) > 0 {
|
||||
columns := make([]clause.Column, 0, len(s.PrimaryFieldDBNames))
|
||||
for _, dbName := range s.PrimaryFieldDBNames {
|
||||
columns = append(columns, clause.Column{Name: dbName})
|
||||
}
|
||||
|
||||
return clause.OnConflict{
|
||||
Columns: columns,
|
||||
DoUpdates: clause.AssignmentColumns(defaultUpdatingColumns),
|
||||
}
|
||||
}
|
||||
|
||||
return clause.OnConflict{DoNothing: true}
|
||||
}
|
||||
|
||||
func saveAssociations(db *gorm.DB, rel *schema.Relationship, values interface{}, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) error {
|
||||
var (
|
||||
selects, omits []string
|
||||
onConflict = onConflictOption(db.Statement, rel.FieldSchema, selectColumns, restricted, defaultUpdatingColumns)
|
||||
refName = rel.Name + "."
|
||||
)
|
||||
|
||||
for name, ok := range selectColumns {
|
||||
columnName := ""
|
||||
if strings.HasPrefix(name, refName) {
|
||||
columnName = strings.TrimPrefix(name, refName)
|
||||
}
|
||||
|
||||
if columnName != "" {
|
||||
if ok {
|
||||
selects = append(selects, columnName)
|
||||
} else {
|
||||
omits = append(omits, columnName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tx := db.Session(&gorm.Session{NewDB: true}).Clauses(onConflict).Session(&gorm.Session{
|
||||
FullSaveAssociations: db.FullSaveAssociations,
|
||||
SkipHooks: db.Statement.SkipHooks,
|
||||
DisableNestedTransaction: true,
|
||||
})
|
||||
|
||||
db.Statement.Settings.Range(func(k, v interface{}) bool {
|
||||
tx.Statement.Settings.Store(k, v)
|
||||
return true
|
||||
})
|
||||
|
||||
if tx.Statement.FullSaveAssociations {
|
||||
tx = tx.InstanceSet("gorm:update_track_time", true)
|
||||
}
|
||||
|
||||
if len(selects) > 0 {
|
||||
tx = tx.Select(selects)
|
||||
} else if restricted && len(omits) == 0 {
|
||||
tx = tx.Omit(clause.Associations)
|
||||
}
|
||||
|
||||
if len(omits) > 0 {
|
||||
tx = tx.Omit(omits...)
|
||||
}
|
||||
|
||||
return db.AddError(tx.Create(values).Error)
|
||||
}
|
||||
83
callbacks/callbacks.go
Normal file
83
callbacks/callbacks.go
Normal file
@ -0,0 +1,83 @@
|
||||
package callbacks
|
||||
|
||||
import (
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var (
|
||||
createClauses = []string{"INSERT", "VALUES", "ON CONFLICT"}
|
||||
queryClauses = []string{"SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR"}
|
||||
updateClauses = []string{"UPDATE", "SET", "WHERE"}
|
||||
deleteClauses = []string{"DELETE", "FROM", "WHERE"}
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
LastInsertIDReversed bool
|
||||
WithReturning bool
|
||||
CreateClauses []string
|
||||
QueryClauses []string
|
||||
UpdateClauses []string
|
||||
DeleteClauses []string
|
||||
}
|
||||
|
||||
func RegisterDefaultCallbacks(db *gorm.DB, config *Config) {
|
||||
enableTransaction := func(db *gorm.DB) bool {
|
||||
return !db.SkipDefaultTransaction
|
||||
}
|
||||
|
||||
createCallback := db.Callback().Create()
|
||||
createCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction)
|
||||
createCallback.Register("gorm:before_create", BeforeCreate)
|
||||
createCallback.Register("gorm:save_before_associations", SaveBeforeAssociations(true))
|
||||
createCallback.Register("gorm:create", Create(config))
|
||||
createCallback.Register("gorm:save_after_associations", SaveAfterAssociations(true))
|
||||
createCallback.Register("gorm:after_create", AfterCreate)
|
||||
createCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
|
||||
if len(config.CreateClauses) == 0 {
|
||||
config.CreateClauses = createClauses
|
||||
}
|
||||
createCallback.Clauses = config.CreateClauses
|
||||
|
||||
queryCallback := db.Callback().Query()
|
||||
queryCallback.Register("gorm:query", Query)
|
||||
queryCallback.Register("gorm:preload", Preload)
|
||||
queryCallback.Register("gorm:after_query", AfterQuery)
|
||||
if len(config.QueryClauses) == 0 {
|
||||
config.QueryClauses = queryClauses
|
||||
}
|
||||
queryCallback.Clauses = config.QueryClauses
|
||||
|
||||
deleteCallback := db.Callback().Delete()
|
||||
deleteCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction)
|
||||
deleteCallback.Register("gorm:before_delete", BeforeDelete)
|
||||
deleteCallback.Register("gorm:delete_before_associations", DeleteBeforeAssociations)
|
||||
deleteCallback.Register("gorm:delete", Delete)
|
||||
deleteCallback.Register("gorm:after_delete", AfterDelete)
|
||||
deleteCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
|
||||
if len(config.DeleteClauses) == 0 {
|
||||
config.DeleteClauses = deleteClauses
|
||||
}
|
||||
deleteCallback.Clauses = config.DeleteClauses
|
||||
|
||||
updateCallback := db.Callback().Update()
|
||||
updateCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction)
|
||||
updateCallback.Register("gorm:setup_reflect_value", SetupUpdateReflectValue)
|
||||
updateCallback.Register("gorm:before_update", BeforeUpdate)
|
||||
updateCallback.Register("gorm:save_before_associations", SaveBeforeAssociations(false))
|
||||
updateCallback.Register("gorm:update", Update)
|
||||
updateCallback.Register("gorm:save_after_associations", SaveAfterAssociations(false))
|
||||
updateCallback.Register("gorm:after_update", AfterUpdate)
|
||||
updateCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
|
||||
if len(config.UpdateClauses) == 0 {
|
||||
config.UpdateClauses = updateClauses
|
||||
}
|
||||
updateCallback.Clauses = config.UpdateClauses
|
||||
|
||||
rowCallback := db.Callback().Row()
|
||||
rowCallback.Register("gorm:row", RowQuery)
|
||||
rowCallback.Clauses = config.QueryClauses
|
||||
|
||||
rawCallback := db.Callback().Raw()
|
||||
rawCallback.Register("gorm:raw", RawExec)
|
||||
rawCallback.Clauses = config.QueryClauses
|
||||
}
|
||||
23
callbacks/callmethod.go
Normal file
23
callbacks/callmethod.go
Normal file
@ -0,0 +1,23 @@
|
||||
package callbacks
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func callMethod(db *gorm.DB, fc func(value interface{}, tx *gorm.DB) bool) {
|
||||
tx := db.Session(&gorm.Session{NewDB: true})
|
||||
if called := fc(db.Statement.ReflectValue.Interface(), tx); !called {
|
||||
switch db.Statement.ReflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
db.Statement.CurDestIndex = 0
|
||||
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
|
||||
fc(reflect.Indirect(db.Statement.ReflectValue.Index(i)).Addr().Interface(), tx)
|
||||
db.Statement.CurDestIndex++
|
||||
}
|
||||
case reflect.Struct:
|
||||
fc(db.Statement.ReflectValue.Addr().Interface(), tx)
|
||||
}
|
||||
}
|
||||
}
|
||||
370
callbacks/create.go
Normal file
370
callbacks/create.go
Normal file
@ -0,0 +1,370 @@
|
||||
package callbacks
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/schema"
|
||||
)
|
||||
|
||||
func BeforeCreate(db *gorm.DB) {
|
||||
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) {
|
||||
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
|
||||
if db.Statement.Schema.BeforeSave {
|
||||
if i, ok := value.(BeforeSaveInterface); ok {
|
||||
called = true
|
||||
db.AddError(i.BeforeSave(tx))
|
||||
}
|
||||
}
|
||||
|
||||
if db.Statement.Schema.BeforeCreate {
|
||||
if i, ok := value.(BeforeCreateInterface); ok {
|
||||
called = true
|
||||
db.AddError(i.BeforeCreate(tx))
|
||||
}
|
||||
}
|
||||
return called
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Create(config *Config) func(db *gorm.DB) {
|
||||
if config.WithReturning {
|
||||
return CreateWithReturning
|
||||
} else {
|
||||
return func(db *gorm.DB) {
|
||||
if db.Error == nil {
|
||||
if db.Statement.Schema != nil && !db.Statement.Unscoped {
|
||||
for _, c := range db.Statement.Schema.CreateClauses {
|
||||
db.Statement.AddClause(c)
|
||||
}
|
||||
}
|
||||
|
||||
if db.Statement.SQL.String() == "" {
|
||||
db.Statement.SQL.Grow(180)
|
||||
db.Statement.AddClauseIfNotExists(clause.Insert{})
|
||||
db.Statement.AddClause(ConvertToCreateValues(db.Statement))
|
||||
|
||||
db.Statement.Build(db.Statement.BuildClauses...)
|
||||
}
|
||||
|
||||
if !db.DryRun && db.Error == nil {
|
||||
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||
|
||||
if err == nil {
|
||||
db.RowsAffected, _ = result.RowsAffected()
|
||||
|
||||
if db.RowsAffected > 0 {
|
||||
if db.Statement.Schema != nil && db.Statement.Schema.PrioritizedPrimaryField != nil && db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue {
|
||||
if insertID, err := result.LastInsertId(); err == nil && insertID > 0 {
|
||||
switch db.Statement.ReflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
if config.LastInsertIDReversed {
|
||||
for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- {
|
||||
rv := db.Statement.ReflectValue.Index(i)
|
||||
if reflect.Indirect(rv).Kind() != reflect.Struct {
|
||||
break
|
||||
}
|
||||
|
||||
_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv)
|
||||
if isZero {
|
||||
db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID)
|
||||
insertID -= db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
|
||||
rv := db.Statement.ReflectValue.Index(i)
|
||||
if reflect.Indirect(rv).Kind() != reflect.Struct {
|
||||
break
|
||||
}
|
||||
|
||||
if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv); isZero {
|
||||
db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID)
|
||||
insertID += db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement
|
||||
}
|
||||
}
|
||||
}
|
||||
case reflect.Struct:
|
||||
if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue); isZero {
|
||||
db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
db.AddError(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
db.AddError(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func CreateWithReturning(db *gorm.DB) {
|
||||
if db.Error == nil {
|
||||
if db.Statement.Schema != nil && !db.Statement.Unscoped {
|
||||
for _, c := range db.Statement.Schema.CreateClauses {
|
||||
db.Statement.AddClause(c)
|
||||
}
|
||||
}
|
||||
|
||||
if db.Statement.SQL.String() == "" {
|
||||
db.Statement.AddClauseIfNotExists(clause.Insert{})
|
||||
db.Statement.AddClause(ConvertToCreateValues(db.Statement))
|
||||
|
||||
db.Statement.Build(db.Statement.BuildClauses...)
|
||||
}
|
||||
|
||||
if sch := db.Statement.Schema; sch != nil && len(sch.FieldsWithDefaultDBValue) > 0 {
|
||||
db.Statement.WriteString(" RETURNING ")
|
||||
|
||||
var (
|
||||
fields = make([]*schema.Field, len(sch.FieldsWithDefaultDBValue))
|
||||
values = make([]interface{}, len(sch.FieldsWithDefaultDBValue))
|
||||
)
|
||||
|
||||
for idx, field := range sch.FieldsWithDefaultDBValue {
|
||||
if idx > 0 {
|
||||
db.Statement.WriteByte(',')
|
||||
}
|
||||
|
||||
fields[idx] = field
|
||||
db.Statement.WriteQuoted(field.DBName)
|
||||
}
|
||||
|
||||
if !db.DryRun && db.Error == nil {
|
||||
db.RowsAffected = 0
|
||||
rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||
|
||||
if err == nil {
|
||||
defer rows.Close()
|
||||
|
||||
switch db.Statement.ReflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
c := db.Statement.Clauses["ON CONFLICT"]
|
||||
onConflict, _ := c.Expression.(clause.OnConflict)
|
||||
|
||||
for rows.Next() {
|
||||
BEGIN:
|
||||
reflectValue := db.Statement.ReflectValue.Index(int(db.RowsAffected))
|
||||
if reflect.Indirect(reflectValue).Kind() != reflect.Struct {
|
||||
break
|
||||
}
|
||||
|
||||
for idx, field := range fields {
|
||||
fieldValue := field.ReflectValueOf(reflectValue)
|
||||
|
||||
if onConflict.DoNothing && !fieldValue.IsZero() {
|
||||
db.RowsAffected++
|
||||
|
||||
if int(db.RowsAffected) >= db.Statement.ReflectValue.Len() {
|
||||
return
|
||||
}
|
||||
|
||||
goto BEGIN
|
||||
}
|
||||
|
||||
values[idx] = fieldValue.Addr().Interface()
|
||||
}
|
||||
|
||||
db.RowsAffected++
|
||||
if err := rows.Scan(values...); err != nil {
|
||||
db.AddError(err)
|
||||
}
|
||||
}
|
||||
case reflect.Struct:
|
||||
for idx, field := range fields {
|
||||
values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface()
|
||||
}
|
||||
|
||||
if rows.Next() {
|
||||
db.RowsAffected++
|
||||
db.AddError(rows.Scan(values...))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
db.AddError(err)
|
||||
}
|
||||
}
|
||||
} else if !db.DryRun && db.Error == nil {
|
||||
if result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); err == nil {
|
||||
db.RowsAffected, _ = result.RowsAffected()
|
||||
} else {
|
||||
db.AddError(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func AfterCreate(db *gorm.DB) {
|
||||
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) {
|
||||
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
|
||||
if db.Statement.Schema.AfterSave {
|
||||
if i, ok := value.(AfterSaveInterface); ok {
|
||||
called = true
|
||||
db.AddError(i.AfterSave(tx))
|
||||
}
|
||||
}
|
||||
|
||||
if db.Statement.Schema.AfterCreate {
|
||||
if i, ok := value.(AfterCreateInterface); ok {
|
||||
called = true
|
||||
db.AddError(i.AfterCreate(tx))
|
||||
}
|
||||
}
|
||||
return called
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ConvertToCreateValues convert to create values
|
||||
func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
|
||||
switch value := stmt.Dest.(type) {
|
||||
case map[string]interface{}:
|
||||
values = ConvertMapToValuesForCreate(stmt, value)
|
||||
case *map[string]interface{}:
|
||||
values = ConvertMapToValuesForCreate(stmt, *value)
|
||||
case []map[string]interface{}:
|
||||
values = ConvertSliceOfMapToValuesForCreate(stmt, value)
|
||||
case *[]map[string]interface{}:
|
||||
values = ConvertSliceOfMapToValuesForCreate(stmt, *value)
|
||||
default:
|
||||
var (
|
||||
selectColumns, restricted = stmt.SelectAndOmitColumns(true, false)
|
||||
curTime = stmt.DB.NowFunc()
|
||||
isZero bool
|
||||
)
|
||||
values = clause.Values{Columns: make([]clause.Column, 0, len(stmt.Schema.DBNames))}
|
||||
|
||||
for _, db := range stmt.Schema.DBNames {
|
||||
if field := stmt.Schema.FieldsByDBName[db]; !field.HasDefaultValue || field.DefaultValueInterface != nil {
|
||||
if v, ok := selectColumns[db]; (ok && v) || (!ok && (!restricted || field.AutoCreateTime > 0 || field.AutoUpdateTime > 0)) {
|
||||
values.Columns = append(values.Columns, clause.Column{Name: db})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
switch stmt.ReflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
stmt.SQL.Grow(stmt.ReflectValue.Len() * 18)
|
||||
values.Values = make([][]interface{}, stmt.ReflectValue.Len())
|
||||
defaultValueFieldsHavingValue := map[*schema.Field][]interface{}{}
|
||||
if stmt.ReflectValue.Len() == 0 {
|
||||
stmt.AddError(gorm.ErrEmptySlice)
|
||||
return
|
||||
}
|
||||
|
||||
for i := 0; i < stmt.ReflectValue.Len(); i++ {
|
||||
rv := reflect.Indirect(stmt.ReflectValue.Index(i))
|
||||
if !rv.IsValid() {
|
||||
stmt.AddError(fmt.Errorf("slice data #%v is invalid: %w", i, gorm.ErrInvalidData))
|
||||
return
|
||||
}
|
||||
|
||||
values.Values[i] = make([]interface{}, len(values.Columns))
|
||||
for idx, column := range values.Columns {
|
||||
field := stmt.Schema.FieldsByDBName[column.Name]
|
||||
if values.Values[i][idx], isZero = field.ValueOf(rv); isZero {
|
||||
if field.DefaultValueInterface != nil {
|
||||
values.Values[i][idx] = field.DefaultValueInterface
|
||||
field.Set(rv, field.DefaultValueInterface)
|
||||
} else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 {
|
||||
field.Set(rv, curTime)
|
||||
values.Values[i][idx], _ = field.ValueOf(rv)
|
||||
}
|
||||
} else if field.AutoUpdateTime > 0 {
|
||||
if _, ok := stmt.DB.InstanceGet("gorm:update_track_time"); ok {
|
||||
field.Set(rv, curTime)
|
||||
values.Values[i][idx], _ = field.ValueOf(rv)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
|
||||
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
|
||||
if v, isZero := field.ValueOf(rv); !isZero {
|
||||
if len(defaultValueFieldsHavingValue[field]) == 0 {
|
||||
defaultValueFieldsHavingValue[field] = make([]interface{}, stmt.ReflectValue.Len())
|
||||
}
|
||||
defaultValueFieldsHavingValue[field][i] = v
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for field, vs := range defaultValueFieldsHavingValue {
|
||||
values.Columns = append(values.Columns, clause.Column{Name: field.DBName})
|
||||
for idx := range values.Values {
|
||||
if vs[idx] == nil {
|
||||
values.Values[idx] = append(values.Values[idx], stmt.Dialector.DefaultValueOf(field))
|
||||
} else {
|
||||
values.Values[idx] = append(values.Values[idx], vs[idx])
|
||||
}
|
||||
}
|
||||
}
|
||||
case reflect.Struct:
|
||||
values.Values = [][]interface{}{make([]interface{}, len(values.Columns))}
|
||||
for idx, column := range values.Columns {
|
||||
field := stmt.Schema.FieldsByDBName[column.Name]
|
||||
if values.Values[0][idx], isZero = field.ValueOf(stmt.ReflectValue); isZero {
|
||||
if field.DefaultValueInterface != nil {
|
||||
values.Values[0][idx] = field.DefaultValueInterface
|
||||
field.Set(stmt.ReflectValue, field.DefaultValueInterface)
|
||||
} else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 {
|
||||
field.Set(stmt.ReflectValue, curTime)
|
||||
values.Values[0][idx], _ = field.ValueOf(stmt.ReflectValue)
|
||||
}
|
||||
} else if field.AutoUpdateTime > 0 {
|
||||
if _, ok := stmt.DB.InstanceGet("gorm:update_track_time"); ok {
|
||||
field.Set(stmt.ReflectValue, curTime)
|
||||
values.Values[0][idx], _ = field.ValueOf(stmt.ReflectValue)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
|
||||
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
|
||||
if v, isZero := field.ValueOf(stmt.ReflectValue); !isZero {
|
||||
values.Columns = append(values.Columns, clause.Column{Name: field.DBName})
|
||||
values.Values[0] = append(values.Values[0], v)
|
||||
}
|
||||
}
|
||||
}
|
||||
default:
|
||||
stmt.AddError(gorm.ErrInvalidData)
|
||||
}
|
||||
}
|
||||
|
||||
if c, ok := stmt.Clauses["ON CONFLICT"]; ok {
|
||||
if onConflict, _ := c.Expression.(clause.OnConflict); onConflict.UpdateAll {
|
||||
if stmt.Schema != nil && len(values.Columns) > 1 {
|
||||
columns := make([]string, 0, len(values.Columns)-1)
|
||||
for _, column := range values.Columns {
|
||||
if field := stmt.Schema.LookUpField(column.Name); field != nil {
|
||||
if !field.PrimaryKey && (!field.HasDefaultValue || field.DefaultValueInterface != nil) && field.AutoCreateTime == 0 {
|
||||
columns = append(columns, column.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
onConflict.DoUpdates = clause.AssignmentColumns(columns)
|
||||
|
||||
// use primary fields as default OnConflict columns
|
||||
if len(onConflict.Columns) == 0 {
|
||||
for _, field := range stmt.Schema.PrimaryFields {
|
||||
onConflict.Columns = append(onConflict.Columns, clause.Column{Name: field.DBName})
|
||||
}
|
||||
}
|
||||
stmt.AddClause(onConflict)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return values
|
||||
}
|
||||
168
callbacks/delete.go
Normal file
168
callbacks/delete.go
Normal file
@ -0,0 +1,168 @@
|
||||
package callbacks
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/schema"
|
||||
)
|
||||
|
||||
func BeforeDelete(db *gorm.DB) {
|
||||
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.BeforeDelete {
|
||||
callMethod(db, func(value interface{}, tx *gorm.DB) bool {
|
||||
if i, ok := value.(BeforeDeleteInterface); ok {
|
||||
db.AddError(i.BeforeDelete(tx))
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func DeleteBeforeAssociations(db *gorm.DB) {
|
||||
if db.Error == nil && db.Statement.Schema != nil {
|
||||
selectColumns, restricted := db.Statement.SelectAndOmitColumns(true, false)
|
||||
|
||||
if restricted {
|
||||
for column, v := range selectColumns {
|
||||
if v {
|
||||
if rel, ok := db.Statement.Schema.Relationships.Relations[column]; ok {
|
||||
switch rel.Type {
|
||||
case schema.HasOne, schema.HasMany:
|
||||
queryConds := rel.ToQueryConditions(db.Statement.ReflectValue)
|
||||
modelValue := reflect.New(rel.FieldSchema.ModelType).Interface()
|
||||
tx := db.Session(&gorm.Session{NewDB: true}).Model(modelValue)
|
||||
withoutConditions := false
|
||||
if db.Statement.Unscoped {
|
||||
tx = tx.Unscoped()
|
||||
}
|
||||
|
||||
if len(db.Statement.Selects) > 0 {
|
||||
selects := make([]string, 0, len(db.Statement.Selects))
|
||||
for _, s := range db.Statement.Selects {
|
||||
if s == clause.Associations {
|
||||
selects = append(selects, s)
|
||||
} else if strings.HasPrefix(s, column+".") {
|
||||
selects = append(selects, strings.TrimPrefix(s, column+"."))
|
||||
}
|
||||
}
|
||||
|
||||
if len(selects) > 0 {
|
||||
tx = tx.Select(selects)
|
||||
}
|
||||
}
|
||||
|
||||
for _, cond := range queryConds {
|
||||
if c, ok := cond.(clause.IN); ok && len(c.Values) == 0 {
|
||||
withoutConditions = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !withoutConditions {
|
||||
if db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
case schema.Many2Many:
|
||||
var (
|
||||
queryConds = make([]clause.Expression, 0, len(rel.References))
|
||||
foreignFields = make([]*schema.Field, 0, len(rel.References))
|
||||
relForeignKeys = make([]string, 0, len(rel.References))
|
||||
modelValue = reflect.New(rel.JoinTable.ModelType).Interface()
|
||||
table = rel.JoinTable.Table
|
||||
tx = db.Session(&gorm.Session{NewDB: true}).Model(modelValue).Table(table)
|
||||
)
|
||||
|
||||
for _, ref := range rel.References {
|
||||
if ref.OwnPrimaryKey {
|
||||
foreignFields = append(foreignFields, ref.PrimaryKey)
|
||||
relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName)
|
||||
} else if ref.PrimaryValue != "" {
|
||||
queryConds = append(queryConds, clause.Eq{
|
||||
Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName},
|
||||
Value: ref.PrimaryValue,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
_, foreignValues := schema.GetIdentityFieldValuesMap(db.Statement.ReflectValue, foreignFields)
|
||||
column, values := schema.ToQueryValues(table, relForeignKeys, foreignValues)
|
||||
queryConds = append(queryConds, clause.IN{Column: column, Values: values})
|
||||
|
||||
if db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Delete(db *gorm.DB) {
|
||||
if db.Error == nil {
|
||||
if db.Statement.Schema != nil && !db.Statement.Unscoped {
|
||||
for _, c := range db.Statement.Schema.DeleteClauses {
|
||||
db.Statement.AddClause(c)
|
||||
}
|
||||
}
|
||||
|
||||
if db.Statement.SQL.String() == "" {
|
||||
db.Statement.SQL.Grow(100)
|
||||
db.Statement.AddClauseIfNotExists(clause.Delete{})
|
||||
|
||||
if db.Statement.Schema != nil {
|
||||
_, queryValues := schema.GetIdentityFieldValuesMap(db.Statement.ReflectValue, db.Statement.Schema.PrimaryFields)
|
||||
column, values := schema.ToQueryValues(db.Statement.Table, db.Statement.Schema.PrimaryFieldDBNames, queryValues)
|
||||
|
||||
if len(values) > 0 {
|
||||
db.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
|
||||
}
|
||||
|
||||
if db.Statement.ReflectValue.CanAddr() && db.Statement.Dest != db.Statement.Model && db.Statement.Model != nil {
|
||||
_, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(db.Statement.Model), db.Statement.Schema.PrimaryFields)
|
||||
column, values = schema.ToQueryValues(db.Statement.Table, db.Statement.Schema.PrimaryFieldDBNames, queryValues)
|
||||
|
||||
if len(values) > 0 {
|
||||
db.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
db.Statement.AddClauseIfNotExists(clause.From{})
|
||||
db.Statement.Build(db.Statement.BuildClauses...)
|
||||
}
|
||||
|
||||
if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok && db.Error == nil {
|
||||
db.AddError(gorm.ErrMissingWhereClause)
|
||||
return
|
||||
}
|
||||
|
||||
if !db.DryRun && db.Error == nil {
|
||||
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||
|
||||
if err == nil {
|
||||
db.RowsAffected, _ = result.RowsAffected()
|
||||
} else {
|
||||
db.AddError(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func AfterDelete(db *gorm.DB) {
|
||||
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.AfterDelete {
|
||||
callMethod(db, func(value interface{}, tx *gorm.DB) bool {
|
||||
if i, ok := value.(AfterDeleteInterface); ok {
|
||||
db.AddError(i.AfterDelete(tx))
|
||||
return true
|
||||
}
|
||||
return false
|
||||
})
|
||||
}
|
||||
}
|
||||
95
callbacks/helper.go
Normal file
95
callbacks/helper.go
Normal file
@ -0,0 +1,95 @@
|
||||
package callbacks
|
||||
|
||||
import (
|
||||
"sort"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
// ConvertMapToValuesForCreate convert map to values
|
||||
func ConvertMapToValuesForCreate(stmt *gorm.Statement, mapValue map[string]interface{}) (values clause.Values) {
|
||||
values.Columns = make([]clause.Column, 0, len(mapValue))
|
||||
selectColumns, restricted := stmt.SelectAndOmitColumns(true, false)
|
||||
|
||||
var keys = make([]string, 0, len(mapValue))
|
||||
for k := range mapValue {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
for _, k := range keys {
|
||||
value := mapValue[k]
|
||||
if stmt.Schema != nil {
|
||||
if field := stmt.Schema.LookUpField(k); field != nil {
|
||||
k = field.DBName
|
||||
}
|
||||
}
|
||||
|
||||
if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) {
|
||||
values.Columns = append(values.Columns, clause.Column{Name: k})
|
||||
if len(values.Values) == 0 {
|
||||
values.Values = [][]interface{}{{}}
|
||||
}
|
||||
|
||||
values.Values[0] = append(values.Values[0], value)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// ConvertSliceOfMapToValuesForCreate convert slice of map to values
|
||||
func ConvertSliceOfMapToValuesForCreate(stmt *gorm.Statement, mapValues []map[string]interface{}) (values clause.Values) {
|
||||
var (
|
||||
columns = make([]string, 0, len(mapValues))
|
||||
)
|
||||
|
||||
// when the length of mapValues,return directly here
|
||||
// no need to call stmt.SelectAndOmitColumns method
|
||||
if len(mapValues) == 0 {
|
||||
stmt.AddError(gorm.ErrEmptySlice)
|
||||
return
|
||||
}
|
||||
|
||||
var (
|
||||
result = make(map[string][]interface{}, len(mapValues))
|
||||
selectColumns, restricted = stmt.SelectAndOmitColumns(true, false)
|
||||
)
|
||||
|
||||
for idx, mapValue := range mapValues {
|
||||
for k, v := range mapValue {
|
||||
if stmt.Schema != nil {
|
||||
if field := stmt.Schema.LookUpField(k); field != nil {
|
||||
k = field.DBName
|
||||
}
|
||||
}
|
||||
|
||||
if _, ok := result[k]; !ok {
|
||||
if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) {
|
||||
result[k] = make([]interface{}, len(mapValues))
|
||||
columns = append(columns, k)
|
||||
} else {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
result[k][idx] = v
|
||||
}
|
||||
}
|
||||
|
||||
sort.Strings(columns)
|
||||
values.Values = make([][]interface{}, len(mapValues))
|
||||
values.Columns = make([]clause.Column, len(columns))
|
||||
for idx, column := range columns {
|
||||
values.Columns[idx] = clause.Column{Name: column}
|
||||
|
||||
for i, v := range result[column] {
|
||||
if len(values.Values[i]) == 0 {
|
||||
values.Values[i] = make([]interface{}, len(columns))
|
||||
}
|
||||
|
||||
values.Values[i][idx] = v
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
39
callbacks/interfaces.go
Normal file
39
callbacks/interfaces.go
Normal file
@ -0,0 +1,39 @@
|
||||
package callbacks
|
||||
|
||||
import "gorm.io/gorm"
|
||||
|
||||
type BeforeCreateInterface interface {
|
||||
BeforeCreate(*gorm.DB) error
|
||||
}
|
||||
|
||||
type AfterCreateInterface interface {
|
||||
AfterCreate(*gorm.DB) error
|
||||
}
|
||||
|
||||
type BeforeUpdateInterface interface {
|
||||
BeforeUpdate(*gorm.DB) error
|
||||
}
|
||||
|
||||
type AfterUpdateInterface interface {
|
||||
AfterUpdate(*gorm.DB) error
|
||||
}
|
||||
|
||||
type BeforeSaveInterface interface {
|
||||
BeforeSave(*gorm.DB) error
|
||||
}
|
||||
|
||||
type AfterSaveInterface interface {
|
||||
AfterSave(*gorm.DB) error
|
||||
}
|
||||
|
||||
type BeforeDeleteInterface interface {
|
||||
BeforeDelete(*gorm.DB) error
|
||||
}
|
||||
|
||||
type AfterDeleteInterface interface {
|
||||
AfterDelete(*gorm.DB) error
|
||||
}
|
||||
|
||||
type AfterFindInterface interface {
|
||||
AfterFind(*gorm.DB) error
|
||||
}
|
||||
164
callbacks/preload.go
Normal file
164
callbacks/preload.go
Normal file
@ -0,0 +1,164 @@
|
||||
package callbacks
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/schema"
|
||||
"gorm.io/gorm/utils"
|
||||
)
|
||||
|
||||
func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) {
|
||||
var (
|
||||
reflectValue = db.Statement.ReflectValue
|
||||
tx = db.Session(&gorm.Session{NewDB: true}).Model(nil).Session(&gorm.Session{SkipHooks: db.Statement.SkipHooks})
|
||||
relForeignKeys []string
|
||||
relForeignFields []*schema.Field
|
||||
foreignFields []*schema.Field
|
||||
foreignValues [][]interface{}
|
||||
identityMap = map[string][]reflect.Value{}
|
||||
inlineConds []interface{}
|
||||
)
|
||||
|
||||
db.Statement.Settings.Range(func(k, v interface{}) bool {
|
||||
tx.Statement.Settings.Store(k, v)
|
||||
return true
|
||||
})
|
||||
|
||||
if rel.JoinTable != nil {
|
||||
var (
|
||||
joinForeignFields = make([]*schema.Field, 0, len(rel.References))
|
||||
joinRelForeignFields = make([]*schema.Field, 0, len(rel.References))
|
||||
joinForeignKeys = make([]string, 0, len(rel.References))
|
||||
)
|
||||
|
||||
for _, ref := range rel.References {
|
||||
if ref.OwnPrimaryKey {
|
||||
joinForeignKeys = append(joinForeignKeys, ref.ForeignKey.DBName)
|
||||
joinForeignFields = append(joinForeignFields, ref.ForeignKey)
|
||||
foreignFields = append(foreignFields, ref.PrimaryKey)
|
||||
} else if ref.PrimaryValue != "" {
|
||||
tx = tx.Where(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue})
|
||||
} else {
|
||||
joinRelForeignFields = append(joinRelForeignFields, ref.ForeignKey)
|
||||
relForeignKeys = append(relForeignKeys, ref.PrimaryKey.DBName)
|
||||
relForeignFields = append(relForeignFields, ref.PrimaryKey)
|
||||
}
|
||||
}
|
||||
|
||||
joinIdentityMap, joinForeignValues := schema.GetIdentityFieldValuesMap(reflectValue, foreignFields)
|
||||
if len(joinForeignValues) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
joinResults := rel.JoinTable.MakeSlice().Elem()
|
||||
column, values := schema.ToQueryValues(clause.CurrentTable, joinForeignKeys, joinForeignValues)
|
||||
db.AddError(tx.Where(clause.IN{Column: column, Values: values}).Find(joinResults.Addr().Interface()).Error)
|
||||
|
||||
// convert join identity map to relation identity map
|
||||
fieldValues := make([]interface{}, len(joinForeignFields))
|
||||
joinFieldValues := make([]interface{}, len(joinRelForeignFields))
|
||||
for i := 0; i < joinResults.Len(); i++ {
|
||||
for idx, field := range joinForeignFields {
|
||||
fieldValues[idx], _ = field.ValueOf(joinResults.Index(i))
|
||||
}
|
||||
|
||||
for idx, field := range joinRelForeignFields {
|
||||
joinFieldValues[idx], _ = field.ValueOf(joinResults.Index(i))
|
||||
}
|
||||
|
||||
if results, ok := joinIdentityMap[utils.ToStringKey(fieldValues...)]; ok {
|
||||
joinKey := utils.ToStringKey(joinFieldValues...)
|
||||
identityMap[joinKey] = append(identityMap[joinKey], results...)
|
||||
}
|
||||
}
|
||||
|
||||
_, foreignValues = schema.GetIdentityFieldValuesMap(joinResults, joinRelForeignFields)
|
||||
} else {
|
||||
for _, ref := range rel.References {
|
||||
if ref.OwnPrimaryKey {
|
||||
relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName)
|
||||
relForeignFields = append(relForeignFields, ref.ForeignKey)
|
||||
foreignFields = append(foreignFields, ref.PrimaryKey)
|
||||
} else if ref.PrimaryValue != "" {
|
||||
tx = tx.Where(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue})
|
||||
} else {
|
||||
relForeignKeys = append(relForeignKeys, ref.PrimaryKey.DBName)
|
||||
relForeignFields = append(relForeignFields, ref.PrimaryKey)
|
||||
foreignFields = append(foreignFields, ref.ForeignKey)
|
||||
}
|
||||
}
|
||||
|
||||
identityMap, foreignValues = schema.GetIdentityFieldValuesMap(reflectValue, foreignFields)
|
||||
if len(foreignValues) == 0 {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// nested preload
|
||||
for p, pvs := range preloads {
|
||||
tx = tx.Preload(p, pvs...)
|
||||
}
|
||||
|
||||
reflectResults := rel.FieldSchema.MakeSlice().Elem()
|
||||
column, values := schema.ToQueryValues(clause.CurrentTable, relForeignKeys, foreignValues)
|
||||
|
||||
for _, cond := range conds {
|
||||
if fc, ok := cond.(func(*gorm.DB) *gorm.DB); ok {
|
||||
tx = fc(tx)
|
||||
} else {
|
||||
inlineConds = append(inlineConds, cond)
|
||||
}
|
||||
}
|
||||
|
||||
db.AddError(tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), inlineConds...).Error)
|
||||
|
||||
fieldValues := make([]interface{}, len(relForeignFields))
|
||||
|
||||
// clean up old values before preloading
|
||||
switch reflectValue.Kind() {
|
||||
case reflect.Struct:
|
||||
switch rel.Type {
|
||||
case schema.HasMany, schema.Many2Many:
|
||||
rel.Field.Set(reflectValue, reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface())
|
||||
default:
|
||||
rel.Field.Set(reflectValue, reflect.New(rel.Field.FieldType).Interface())
|
||||
}
|
||||
case reflect.Slice, reflect.Array:
|
||||
for i := 0; i < reflectValue.Len(); i++ {
|
||||
switch rel.Type {
|
||||
case schema.HasMany, schema.Many2Many:
|
||||
rel.Field.Set(reflectValue.Index(i), reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface())
|
||||
default:
|
||||
rel.Field.Set(reflectValue.Index(i), reflect.New(rel.Field.FieldType).Interface())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for i := 0; i < reflectResults.Len(); i++ {
|
||||
elem := reflectResults.Index(i)
|
||||
for idx, field := range relForeignFields {
|
||||
fieldValues[idx], _ = field.ValueOf(elem)
|
||||
}
|
||||
|
||||
for _, data := range identityMap[utils.ToStringKey(fieldValues...)] {
|
||||
reflectFieldValue := rel.Field.ReflectValueOf(data)
|
||||
if reflectFieldValue.Kind() == reflect.Ptr && reflectFieldValue.IsNil() {
|
||||
reflectFieldValue.Set(reflect.New(rel.Field.FieldType.Elem()))
|
||||
}
|
||||
|
||||
reflectFieldValue = reflect.Indirect(reflectFieldValue)
|
||||
switch reflectFieldValue.Kind() {
|
||||
case reflect.Struct:
|
||||
rel.Field.Set(data, reflectResults.Index(i).Interface())
|
||||
case reflect.Slice, reflect.Array:
|
||||
if reflectFieldValue.Type().Elem().Kind() == reflect.Ptr {
|
||||
rel.Field.Set(data, reflect.Append(reflectFieldValue, elem).Interface())
|
||||
} else {
|
||||
rel.Field.Set(data, reflect.Append(reflectFieldValue, elem.Elem()).Interface())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
228
callbacks/query.go
Normal file
228
callbacks/query.go
Normal file
@ -0,0 +1,228 @@
|
||||
package callbacks
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
func Query(db *gorm.DB) {
|
||||
if db.Error == nil {
|
||||
BuildQuerySQL(db)
|
||||
|
||||
if !db.DryRun && db.Error == nil {
|
||||
rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||
if err != nil {
|
||||
db.AddError(err)
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
gorm.Scan(rows, db, false)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BuildQuerySQL(db *gorm.DB) {
|
||||
if db.Statement.Schema != nil && !db.Statement.Unscoped {
|
||||
for _, c := range db.Statement.Schema.QueryClauses {
|
||||
db.Statement.AddClause(c)
|
||||
}
|
||||
}
|
||||
|
||||
if db.Statement.SQL.String() == "" {
|
||||
db.Statement.SQL.Grow(100)
|
||||
clauseSelect := clause.Select{Distinct: db.Statement.Distinct}
|
||||
|
||||
if db.Statement.ReflectValue.Kind() == reflect.Struct && db.Statement.ReflectValue.Type() == db.Statement.Schema.ModelType {
|
||||
var conds []clause.Expression
|
||||
for _, primaryField := range db.Statement.Schema.PrimaryFields {
|
||||
if v, isZero := primaryField.ValueOf(db.Statement.ReflectValue); !isZero {
|
||||
conds = append(conds, clause.Eq{Column: clause.Column{Table: db.Statement.Table, Name: primaryField.DBName}, Value: v})
|
||||
}
|
||||
}
|
||||
|
||||
if len(conds) > 0 {
|
||||
db.Statement.AddClause(clause.Where{Exprs: conds})
|
||||
}
|
||||
}
|
||||
|
||||
if len(db.Statement.Selects) > 0 {
|
||||
clauseSelect.Columns = make([]clause.Column, len(db.Statement.Selects))
|
||||
for idx, name := range db.Statement.Selects {
|
||||
if db.Statement.Schema == nil {
|
||||
clauseSelect.Columns[idx] = clause.Column{Name: name, Raw: true}
|
||||
} else if f := db.Statement.Schema.LookUpField(name); f != nil {
|
||||
clauseSelect.Columns[idx] = clause.Column{Name: f.DBName}
|
||||
} else {
|
||||
clauseSelect.Columns[idx] = clause.Column{Name: name, Raw: true}
|
||||
}
|
||||
}
|
||||
} else if db.Statement.Schema != nil && len(db.Statement.Omits) > 0 {
|
||||
selectColumns, _ := db.Statement.SelectAndOmitColumns(false, false)
|
||||
clauseSelect.Columns = make([]clause.Column, 0, len(db.Statement.Schema.DBNames))
|
||||
for _, dbName := range db.Statement.Schema.DBNames {
|
||||
if v, ok := selectColumns[dbName]; (ok && v) || !ok {
|
||||
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{Table: db.Statement.Table, Name: dbName})
|
||||
}
|
||||
}
|
||||
} else if db.Statement.Schema != nil && db.Statement.ReflectValue.IsValid() {
|
||||
queryFields := db.QueryFields
|
||||
if !queryFields {
|
||||
switch db.Statement.ReflectValue.Kind() {
|
||||
case reflect.Struct:
|
||||
queryFields = db.Statement.ReflectValue.Type() != db.Statement.Schema.ModelType
|
||||
case reflect.Slice:
|
||||
queryFields = db.Statement.ReflectValue.Type().Elem() != db.Statement.Schema.ModelType
|
||||
}
|
||||
}
|
||||
|
||||
if queryFields {
|
||||
stmt := gorm.Statement{DB: db}
|
||||
// smaller struct
|
||||
if err := stmt.Parse(db.Statement.Dest); err == nil && (db.QueryFields || stmt.Schema.ModelType != db.Statement.Schema.ModelType) {
|
||||
clauseSelect.Columns = make([]clause.Column, len(stmt.Schema.DBNames))
|
||||
|
||||
for idx, dbName := range stmt.Schema.DBNames {
|
||||
clauseSelect.Columns[idx] = clause.Column{Table: db.Statement.Table, Name: dbName}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// inline joins
|
||||
if len(db.Statement.Joins) != 0 {
|
||||
if len(db.Statement.Selects) == 0 && db.Statement.Schema != nil {
|
||||
clauseSelect.Columns = make([]clause.Column, len(db.Statement.Schema.DBNames))
|
||||
for idx, dbName := range db.Statement.Schema.DBNames {
|
||||
clauseSelect.Columns[idx] = clause.Column{Table: db.Statement.Table, Name: dbName}
|
||||
}
|
||||
}
|
||||
|
||||
joins := []clause.Join{}
|
||||
|
||||
if fromClause, ok := db.Statement.Clauses["FROM"].Expression.(clause.From); ok {
|
||||
joins = fromClause.Joins
|
||||
}
|
||||
|
||||
for _, join := range db.Statement.Joins {
|
||||
if db.Statement.Schema == nil {
|
||||
joins = append(joins, clause.Join{
|
||||
Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds},
|
||||
})
|
||||
} else if relation, ok := db.Statement.Schema.Relationships.Relations[join.Name]; ok {
|
||||
tableAliasName := relation.Name
|
||||
|
||||
for _, s := range relation.FieldSchema.DBNames {
|
||||
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{
|
||||
Table: tableAliasName,
|
||||
Name: s,
|
||||
Alias: tableAliasName + "__" + s,
|
||||
})
|
||||
}
|
||||
|
||||
exprs := make([]clause.Expression, len(relation.References))
|
||||
for idx, ref := range relation.References {
|
||||
if ref.OwnPrimaryKey {
|
||||
exprs[idx] = clause.Eq{
|
||||
Column: clause.Column{Table: clause.CurrentTable, Name: ref.PrimaryKey.DBName},
|
||||
Value: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
|
||||
}
|
||||
} else {
|
||||
if ref.PrimaryValue == "" {
|
||||
exprs[idx] = clause.Eq{
|
||||
Column: clause.Column{Table: clause.CurrentTable, Name: ref.ForeignKey.DBName},
|
||||
Value: clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName},
|
||||
}
|
||||
} else {
|
||||
exprs[idx] = clause.Eq{
|
||||
Column: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
|
||||
Value: ref.PrimaryValue,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
joins = append(joins, clause.Join{
|
||||
Type: clause.LeftJoin,
|
||||
Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName},
|
||||
ON: clause.Where{Exprs: exprs},
|
||||
})
|
||||
} else {
|
||||
joins = append(joins, clause.Join{
|
||||
Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
db.Statement.Joins = nil
|
||||
db.Statement.AddClause(clause.From{Joins: joins})
|
||||
} else {
|
||||
db.Statement.AddClauseIfNotExists(clause.From{})
|
||||
}
|
||||
|
||||
db.Statement.AddClauseIfNotExists(clauseSelect)
|
||||
|
||||
db.Statement.Build(db.Statement.BuildClauses...)
|
||||
}
|
||||
}
|
||||
|
||||
func Preload(db *gorm.DB) {
|
||||
if db.Error == nil && len(db.Statement.Preloads) > 0 {
|
||||
preloadMap := map[string]map[string][]interface{}{}
|
||||
for name := range db.Statement.Preloads {
|
||||
preloadFields := strings.Split(name, ".")
|
||||
if preloadFields[0] == clause.Associations {
|
||||
for _, rel := range db.Statement.Schema.Relationships.Relations {
|
||||
if rel.Schema == db.Statement.Schema {
|
||||
if _, ok := preloadMap[rel.Name]; !ok {
|
||||
preloadMap[rel.Name] = map[string][]interface{}{}
|
||||
}
|
||||
|
||||
if value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), "."); value != "" {
|
||||
preloadMap[rel.Name][value] = db.Statement.Preloads[name]
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if _, ok := preloadMap[preloadFields[0]]; !ok {
|
||||
preloadMap[preloadFields[0]] = map[string][]interface{}{}
|
||||
}
|
||||
|
||||
if value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), "."); value != "" {
|
||||
preloadMap[preloadFields[0]][value] = db.Statement.Preloads[name]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
preloadNames := make([]string, 0, len(preloadMap))
|
||||
for key := range preloadMap {
|
||||
preloadNames = append(preloadNames, key)
|
||||
}
|
||||
sort.Strings(preloadNames)
|
||||
|
||||
for _, name := range preloadNames {
|
||||
if rel := db.Statement.Schema.Relationships.Relations[name]; rel != nil {
|
||||
preload(db, rel, db.Statement.Preloads[name], preloadMap[name])
|
||||
} else {
|
||||
db.AddError(fmt.Errorf("%v: %w for schema %v", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func AfterQuery(db *gorm.DB) {
|
||||
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.AfterFind && db.RowsAffected > 0 {
|
||||
callMethod(db, func(value interface{}, tx *gorm.DB) bool {
|
||||
if i, ok := value.(AfterFindInterface); ok {
|
||||
db.AddError(i.AfterFind(tx))
|
||||
return true
|
||||
}
|
||||
return false
|
||||
})
|
||||
}
|
||||
}
|
||||
16
callbacks/raw.go
Normal file
16
callbacks/raw.go
Normal file
@ -0,0 +1,16 @@
|
||||
package callbacks
|
||||
|
||||
import (
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func RawExec(db *gorm.DB) {
|
||||
if db.Error == nil && !db.DryRun {
|
||||
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||
if err != nil {
|
||||
db.AddError(err)
|
||||
} else {
|
||||
db.RowsAffected, _ = result.RowsAffected()
|
||||
}
|
||||
}
|
||||
}
|
||||
21
callbacks/row.go
Normal file
21
callbacks/row.go
Normal file
@ -0,0 +1,21 @@
|
||||
package callbacks
|
||||
|
||||
import (
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func RowQuery(db *gorm.DB) {
|
||||
if db.Error == nil {
|
||||
BuildQuerySQL(db)
|
||||
|
||||
if !db.DryRun {
|
||||
if isRows, ok := db.InstanceGet("rows"); ok && isRows.(bool) {
|
||||
db.Statement.Dest, db.Error = db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||
} else {
|
||||
db.Statement.Dest = db.Statement.ConnPool.QueryRowContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||
}
|
||||
|
||||
db.RowsAffected = -1
|
||||
}
|
||||
}
|
||||
}
|
||||
31
callbacks/transaction.go
Normal file
31
callbacks/transaction.go
Normal file
@ -0,0 +1,31 @@
|
||||
package callbacks
|
||||
|
||||
import (
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func BeginTransaction(db *gorm.DB) {
|
||||
if !db.Config.SkipDefaultTransaction {
|
||||
if tx := db.Begin(); tx.Error == nil {
|
||||
db.Statement.ConnPool = tx.Statement.ConnPool
|
||||
db.InstanceSet("gorm:started_transaction", true)
|
||||
} else if tx.Error == gorm.ErrInvalidTransaction {
|
||||
tx.Error = nil
|
||||
} else {
|
||||
db.Error = tx.Error
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func CommitOrRollbackTransaction(db *gorm.DB) {
|
||||
if !db.Config.SkipDefaultTransaction {
|
||||
if _, ok := db.InstanceGet("gorm:started_transaction"); ok {
|
||||
if db.Error == nil {
|
||||
db.Commit()
|
||||
} else {
|
||||
db.Rollback()
|
||||
}
|
||||
db.Statement.ConnPool = db.ConnPool
|
||||
}
|
||||
}
|
||||
}
|
||||
261
callbacks/update.go
Normal file
261
callbacks/update.go
Normal file
@ -0,0 +1,261 @@
|
||||
package callbacks
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"sort"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/schema"
|
||||
)
|
||||
|
||||
func SetupUpdateReflectValue(db *gorm.DB) {
|
||||
if db.Error == nil && db.Statement.Schema != nil {
|
||||
if !db.Statement.ReflectValue.CanAddr() || db.Statement.Model != db.Statement.Dest {
|
||||
db.Statement.ReflectValue = reflect.ValueOf(db.Statement.Model)
|
||||
for db.Statement.ReflectValue.Kind() == reflect.Ptr {
|
||||
db.Statement.ReflectValue = db.Statement.ReflectValue.Elem()
|
||||
}
|
||||
|
||||
if dest, ok := db.Statement.Dest.(map[string]interface{}); ok {
|
||||
for _, rel := range db.Statement.Schema.Relationships.BelongsTo {
|
||||
if _, ok := dest[rel.Name]; ok {
|
||||
rel.Field.Set(db.Statement.ReflectValue, dest[rel.Name])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BeforeUpdate(db *gorm.DB) {
|
||||
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) {
|
||||
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
|
||||
if db.Statement.Schema.BeforeSave {
|
||||
if i, ok := value.(BeforeSaveInterface); ok {
|
||||
called = true
|
||||
db.AddError(i.BeforeSave(tx))
|
||||
}
|
||||
}
|
||||
|
||||
if db.Statement.Schema.BeforeUpdate {
|
||||
if i, ok := value.(BeforeUpdateInterface); ok {
|
||||
called = true
|
||||
db.AddError(i.BeforeUpdate(tx))
|
||||
}
|
||||
}
|
||||
|
||||
return called
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Update(db *gorm.DB) {
|
||||
if db.Error == nil {
|
||||
if db.Statement.Schema != nil && !db.Statement.Unscoped {
|
||||
for _, c := range db.Statement.Schema.UpdateClauses {
|
||||
db.Statement.AddClause(c)
|
||||
}
|
||||
}
|
||||
|
||||
if db.Statement.SQL.String() == "" {
|
||||
db.Statement.SQL.Grow(180)
|
||||
db.Statement.AddClauseIfNotExists(clause.Update{})
|
||||
if set := ConvertToAssignments(db.Statement); len(set) != 0 {
|
||||
db.Statement.AddClause(set)
|
||||
} else {
|
||||
return
|
||||
}
|
||||
db.Statement.Build(db.Statement.BuildClauses...)
|
||||
}
|
||||
|
||||
if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok {
|
||||
db.AddError(gorm.ErrMissingWhereClause)
|
||||
return
|
||||
}
|
||||
|
||||
if !db.DryRun && db.Error == nil {
|
||||
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||
|
||||
if err == nil {
|
||||
db.RowsAffected, _ = result.RowsAffected()
|
||||
} else {
|
||||
db.AddError(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func AfterUpdate(db *gorm.DB) {
|
||||
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) {
|
||||
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
|
||||
if db.Statement.Schema.AfterSave {
|
||||
if i, ok := value.(AfterSaveInterface); ok {
|
||||
called = true
|
||||
db.AddError(i.AfterSave(tx))
|
||||
}
|
||||
}
|
||||
|
||||
if db.Statement.Schema.AfterUpdate {
|
||||
if i, ok := value.(AfterUpdateInterface); ok {
|
||||
called = true
|
||||
db.AddError(i.AfterUpdate(tx))
|
||||
}
|
||||
}
|
||||
return called
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ConvertToAssignments convert to update assignments
|
||||
func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
|
||||
var (
|
||||
selectColumns, restricted = stmt.SelectAndOmitColumns(false, true)
|
||||
assignValue func(field *schema.Field, value interface{})
|
||||
)
|
||||
|
||||
switch stmt.ReflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
assignValue = func(field *schema.Field, value interface{}) {
|
||||
for i := 0; i < stmt.ReflectValue.Len(); i++ {
|
||||
field.Set(stmt.ReflectValue.Index(i), value)
|
||||
}
|
||||
}
|
||||
case reflect.Struct:
|
||||
assignValue = func(field *schema.Field, value interface{}) {
|
||||
if stmt.ReflectValue.CanAddr() {
|
||||
field.Set(stmt.ReflectValue, value)
|
||||
}
|
||||
}
|
||||
default:
|
||||
assignValue = func(field *schema.Field, value interface{}) {
|
||||
}
|
||||
}
|
||||
|
||||
updatingValue := reflect.ValueOf(stmt.Dest)
|
||||
for updatingValue.Kind() == reflect.Ptr {
|
||||
updatingValue = updatingValue.Elem()
|
||||
}
|
||||
|
||||
if !updatingValue.CanAddr() || stmt.Dest != stmt.Model {
|
||||
switch stmt.ReflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
var primaryKeyExprs []clause.Expression
|
||||
for i := 0; i < stmt.ReflectValue.Len(); i++ {
|
||||
var exprs = make([]clause.Expression, len(stmt.Schema.PrimaryFields))
|
||||
var notZero bool
|
||||
for idx, field := range stmt.Schema.PrimaryFields {
|
||||
value, isZero := field.ValueOf(stmt.ReflectValue.Index(i))
|
||||
exprs[idx] = clause.Eq{Column: field.DBName, Value: value}
|
||||
notZero = notZero || !isZero
|
||||
}
|
||||
if notZero {
|
||||
primaryKeyExprs = append(primaryKeyExprs, clause.And(exprs...))
|
||||
}
|
||||
}
|
||||
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(primaryKeyExprs...)}})
|
||||
case reflect.Struct:
|
||||
for _, field := range stmt.Schema.PrimaryFields {
|
||||
if value, isZero := field.ValueOf(stmt.ReflectValue); !isZero {
|
||||
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
switch value := updatingValue.Interface().(type) {
|
||||
case map[string]interface{}:
|
||||
set = make([]clause.Assignment, 0, len(value))
|
||||
|
||||
keys := make([]string, 0, len(value))
|
||||
for k := range value {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
for _, k := range keys {
|
||||
kv := value[k]
|
||||
if _, ok := kv.(*gorm.DB); ok {
|
||||
kv = []interface{}{kv}
|
||||
}
|
||||
|
||||
if stmt.Schema != nil {
|
||||
if field := stmt.Schema.LookUpField(k); field != nil {
|
||||
if field.DBName != "" {
|
||||
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
|
||||
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: kv})
|
||||
assignValue(field, value[k])
|
||||
}
|
||||
} else if v, ok := selectColumns[field.Name]; (ok && v) || (!ok && !restricted) {
|
||||
assignValue(field, value[k])
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) {
|
||||
set = append(set, clause.Assignment{Column: clause.Column{Name: k}, Value: kv})
|
||||
}
|
||||
}
|
||||
|
||||
if !stmt.SkipHooks && stmt.Schema != nil {
|
||||
for _, dbName := range stmt.Schema.DBNames {
|
||||
field := stmt.Schema.LookUpField(dbName)
|
||||
if field.AutoUpdateTime > 0 && value[field.Name] == nil && value[field.DBName] == nil {
|
||||
if v, ok := selectColumns[field.DBName]; (ok && v) || !ok {
|
||||
now := stmt.DB.NowFunc()
|
||||
assignValue(field, now)
|
||||
|
||||
if field.AutoUpdateTime == schema.UnixNanosecond {
|
||||
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano()})
|
||||
} else if field.AutoUpdateTime == schema.UnixMillisecond {
|
||||
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano() / 1e6})
|
||||
} else if field.GORMDataType == schema.Time {
|
||||
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now})
|
||||
} else {
|
||||
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.Unix()})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
default:
|
||||
switch updatingValue.Kind() {
|
||||
case reflect.Struct:
|
||||
set = make([]clause.Assignment, 0, len(stmt.Schema.FieldsByDBName))
|
||||
for _, dbName := range stmt.Schema.DBNames {
|
||||
field := stmt.Schema.LookUpField(dbName)
|
||||
if !field.PrimaryKey || (!updatingValue.CanAddr() || stmt.Dest != stmt.Model) {
|
||||
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && (!restricted || (!stmt.SkipHooks && field.AutoUpdateTime > 0))) {
|
||||
value, isZero := field.ValueOf(updatingValue)
|
||||
if !stmt.SkipHooks && field.AutoUpdateTime > 0 {
|
||||
if field.AutoUpdateTime == schema.UnixNanosecond {
|
||||
value = stmt.DB.NowFunc().UnixNano()
|
||||
} else if field.AutoUpdateTime == schema.UnixMillisecond {
|
||||
value = stmt.DB.NowFunc().UnixNano() / 1e6
|
||||
} else if field.GORMDataType == schema.Time {
|
||||
value = stmt.DB.NowFunc()
|
||||
} else {
|
||||
value = stmt.DB.NowFunc().Unix()
|
||||
}
|
||||
isZero = false
|
||||
}
|
||||
|
||||
if ok || !isZero {
|
||||
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value})
|
||||
assignValue(field, value)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if value, isZero := field.ValueOf(updatingValue); !isZero {
|
||||
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}})
|
||||
}
|
||||
}
|
||||
}
|
||||
default:
|
||||
stmt.AddError(gorm.ErrInvalidData)
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
@ -1,177 +0,0 @@
|
||||
package gorm_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/jinzhu/gorm"
|
||||
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func (s *Product) BeforeCreate() (err error) {
|
||||
if s.Code == "Invalid" {
|
||||
err = errors.New("invalid product")
|
||||
}
|
||||
s.BeforeCreateCallTimes = s.BeforeCreateCallTimes + 1
|
||||
return
|
||||
}
|
||||
|
||||
func (s *Product) BeforeUpdate() (err error) {
|
||||
if s.Code == "dont_update" {
|
||||
err = errors.New("can't update")
|
||||
}
|
||||
s.BeforeUpdateCallTimes = s.BeforeUpdateCallTimes + 1
|
||||
return
|
||||
}
|
||||
|
||||
func (s *Product) BeforeSave() (err error) {
|
||||
if s.Code == "dont_save" {
|
||||
err = errors.New("can't save")
|
||||
}
|
||||
s.BeforeSaveCallTimes = s.BeforeSaveCallTimes + 1
|
||||
return
|
||||
}
|
||||
|
||||
func (s *Product) AfterFind() {
|
||||
s.AfterFindCallTimes = s.AfterFindCallTimes + 1
|
||||
}
|
||||
|
||||
func (s *Product) AfterCreate(tx *gorm.DB) {
|
||||
tx.Model(s).UpdateColumn(Product{AfterCreateCallTimes: s.AfterCreateCallTimes + 1})
|
||||
}
|
||||
|
||||
func (s *Product) AfterUpdate() {
|
||||
s.AfterUpdateCallTimes = s.AfterUpdateCallTimes + 1
|
||||
}
|
||||
|
||||
func (s *Product) AfterSave() (err error) {
|
||||
if s.Code == "after_save_error" {
|
||||
err = errors.New("can't save")
|
||||
}
|
||||
s.AfterSaveCallTimes = s.AfterSaveCallTimes + 1
|
||||
return
|
||||
}
|
||||
|
||||
func (s *Product) BeforeDelete() (err error) {
|
||||
if s.Code == "dont_delete" {
|
||||
err = errors.New("can't delete")
|
||||
}
|
||||
s.BeforeDeleteCallTimes = s.BeforeDeleteCallTimes + 1
|
||||
return
|
||||
}
|
||||
|
||||
func (s *Product) AfterDelete() (err error) {
|
||||
if s.Code == "after_delete_error" {
|
||||
err = errors.New("can't delete")
|
||||
}
|
||||
s.AfterDeleteCallTimes = s.AfterDeleteCallTimes + 1
|
||||
return
|
||||
}
|
||||
|
||||
func (s *Product) GetCallTimes() []int64 {
|
||||
return []int64{s.BeforeCreateCallTimes, s.BeforeSaveCallTimes, s.BeforeUpdateCallTimes, s.AfterCreateCallTimes, s.AfterSaveCallTimes, s.AfterUpdateCallTimes, s.BeforeDeleteCallTimes, s.AfterDeleteCallTimes, s.AfterFindCallTimes}
|
||||
}
|
||||
|
||||
func TestRunCallbacks(t *testing.T) {
|
||||
p := Product{Code: "unique_code", Price: 100}
|
||||
DB.Save(&p)
|
||||
|
||||
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 1, 0, 0, 0, 0}) {
|
||||
t.Errorf("Callbacks should be invoked successfully, %v", p.GetCallTimes())
|
||||
}
|
||||
|
||||
DB.Where("Code = ?", "unique_code").First(&p)
|
||||
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 0, 0, 0, 0, 1}) {
|
||||
t.Errorf("After callbacks values are not saved, %v", p.GetCallTimes())
|
||||
}
|
||||
|
||||
p.Price = 200
|
||||
DB.Save(&p)
|
||||
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 1, 1, 0, 0, 1}) {
|
||||
t.Errorf("After update callbacks should be invoked successfully, %v", p.GetCallTimes())
|
||||
}
|
||||
|
||||
var products []Product
|
||||
DB.Find(&products, "code = ?", "unique_code")
|
||||
if products[0].AfterFindCallTimes != 2 {
|
||||
t.Errorf("AfterFind callbacks should work with slice")
|
||||
}
|
||||
|
||||
DB.Where("Code = ?", "unique_code").First(&p)
|
||||
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 0, 0, 2}) {
|
||||
t.Errorf("After update callbacks values are not saved, %v", p.GetCallTimes())
|
||||
}
|
||||
|
||||
DB.Delete(&p)
|
||||
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 1, 1, 2}) {
|
||||
t.Errorf("After delete callbacks should be invoked successfully, %v", p.GetCallTimes())
|
||||
}
|
||||
|
||||
if DB.Where("Code = ?", "unique_code").First(&p).Error == nil {
|
||||
t.Errorf("Can't find a deleted record")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCallbacksWithErrors(t *testing.T) {
|
||||
p := Product{Code: "Invalid", Price: 100}
|
||||
if DB.Save(&p).Error == nil {
|
||||
t.Errorf("An error from before create callbacks happened when create with invalid value")
|
||||
}
|
||||
|
||||
if DB.Where("code = ?", "Invalid").First(&Product{}).Error == nil {
|
||||
t.Errorf("Should not save record that have errors")
|
||||
}
|
||||
|
||||
if DB.Save(&Product{Code: "dont_save", Price: 100}).Error == nil {
|
||||
t.Errorf("An error from after create callbacks happened when create with invalid value")
|
||||
}
|
||||
|
||||
p2 := Product{Code: "update_callback", Price: 100}
|
||||
DB.Save(&p2)
|
||||
|
||||
p2.Code = "dont_update"
|
||||
if DB.Save(&p2).Error == nil {
|
||||
t.Errorf("An error from before update callbacks happened when update with invalid value")
|
||||
}
|
||||
|
||||
if DB.Where("code = ?", "update_callback").First(&Product{}).Error != nil {
|
||||
t.Errorf("Record Should not be updated due to errors happened in before update callback")
|
||||
}
|
||||
|
||||
if DB.Where("code = ?", "dont_update").First(&Product{}).Error == nil {
|
||||
t.Errorf("Record Should not be updated due to errors happened in before update callback")
|
||||
}
|
||||
|
||||
p2.Code = "dont_save"
|
||||
if DB.Save(&p2).Error == nil {
|
||||
t.Errorf("An error from before save callbacks happened when update with invalid value")
|
||||
}
|
||||
|
||||
p3 := Product{Code: "dont_delete", Price: 100}
|
||||
DB.Save(&p3)
|
||||
if DB.Delete(&p3).Error == nil {
|
||||
t.Errorf("An error from before delete callbacks happened when delete")
|
||||
}
|
||||
|
||||
if DB.Where("Code = ?", "dont_delete").First(&p3).Error != nil {
|
||||
t.Errorf("An error from before delete callbacks happened")
|
||||
}
|
||||
|
||||
p4 := Product{Code: "after_save_error", Price: 100}
|
||||
DB.Save(&p4)
|
||||
if err := DB.First(&Product{}, "code = ?", "after_save_error").Error; err == nil {
|
||||
t.Errorf("Record should be reverted if get an error in after save callback")
|
||||
}
|
||||
|
||||
p5 := Product{Code: "after_delete_error", Price: 100}
|
||||
DB.Save(&p5)
|
||||
if err := DB.First(&Product{}, "code = ?", "after_delete_error").Error; err != nil {
|
||||
t.Errorf("Record should be found")
|
||||
}
|
||||
|
||||
DB.Delete(&p5)
|
||||
if err := DB.First(&Product{}, "code = ?", "after_delete_error").Error; err != nil {
|
||||
t.Errorf("Record shouldn't be deleted because of an error happened in after delete callback")
|
||||
}
|
||||
}
|
||||
293
chainable_api.go
Normal file
293
chainable_api.go
Normal file
@ -0,0 +1,293 @@
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/utils"
|
||||
)
|
||||
|
||||
// Model specify the model you would like to run db operations
|
||||
// // update all users's name to `hello`
|
||||
// db.Model(&User{}).Update("name", "hello")
|
||||
// // if user's primary key is non-blank, will use it as condition, then will only update the user's name to `hello`
|
||||
// db.Model(&user).Update("name", "hello")
|
||||
func (db *DB) Model(value interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.Model = value
|
||||
return
|
||||
}
|
||||
|
||||
// Clauses Add clauses
|
||||
func (db *DB) Clauses(conds ...clause.Expression) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
var whereConds []interface{}
|
||||
|
||||
for _, cond := range conds {
|
||||
if c, ok := cond.(clause.Interface); ok {
|
||||
tx.Statement.AddClause(c)
|
||||
} else if optimizer, ok := cond.(StatementModifier); ok {
|
||||
optimizer.ModifyStatement(tx.Statement)
|
||||
} else {
|
||||
whereConds = append(whereConds, cond)
|
||||
}
|
||||
}
|
||||
|
||||
if len(whereConds) > 0 {
|
||||
tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondition(whereConds[0], whereConds[1:]...)})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
var tableRegexp = regexp.MustCompile(`(?i).+? AS (\w+)\s*(?:$|,)`)
|
||||
|
||||
// Table specify the table you would like to run db operations
|
||||
func (db *DB) Table(name string, args ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
if strings.Contains(name, " ") || strings.Contains(name, "`") || len(args) > 0 {
|
||||
tx.Statement.TableExpr = &clause.Expr{SQL: name, Vars: args}
|
||||
if results := tableRegexp.FindStringSubmatch(name); len(results) == 2 {
|
||||
tx.Statement.Table = results[1]
|
||||
return
|
||||
}
|
||||
} else if tables := strings.Split(name, "."); len(tables) == 2 {
|
||||
tx.Statement.TableExpr = &clause.Expr{SQL: tx.Statement.Quote(name)}
|
||||
tx.Statement.Table = tables[1]
|
||||
return
|
||||
}
|
||||
|
||||
tx.Statement.Table = name
|
||||
return
|
||||
}
|
||||
|
||||
// Distinct specify distinct fields that you want querying
|
||||
func (db *DB) Distinct(args ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.Distinct = true
|
||||
if len(args) > 0 {
|
||||
tx = tx.Select(args[0], args[1:]...)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Select specify fields that you want when querying, creating, updating
|
||||
func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
|
||||
switch v := query.(type) {
|
||||
case []string:
|
||||
tx.Statement.Selects = v
|
||||
|
||||
for _, arg := range args {
|
||||
switch arg := arg.(type) {
|
||||
case string:
|
||||
tx.Statement.Selects = append(tx.Statement.Selects, arg)
|
||||
case []string:
|
||||
tx.Statement.Selects = append(tx.Statement.Selects, arg...)
|
||||
default:
|
||||
tx.AddError(fmt.Errorf("unsupported select args %v %v", query, args))
|
||||
return
|
||||
}
|
||||
}
|
||||
delete(tx.Statement.Clauses, "SELECT")
|
||||
case string:
|
||||
if strings.Count(v, "?") >= len(args) && len(args) > 0 {
|
||||
tx.Statement.AddClause(clause.Select{
|
||||
Distinct: db.Statement.Distinct,
|
||||
Expression: clause.Expr{SQL: v, Vars: args},
|
||||
})
|
||||
} else if strings.Count(v, "@") > 0 && len(args) > 0 {
|
||||
tx.Statement.AddClause(clause.Select{
|
||||
Distinct: db.Statement.Distinct,
|
||||
Expression: clause.NamedExpr{SQL: v, Vars: args},
|
||||
})
|
||||
} else {
|
||||
tx.Statement.Selects = []string{v}
|
||||
|
||||
for _, arg := range args {
|
||||
switch arg := arg.(type) {
|
||||
case string:
|
||||
tx.Statement.Selects = append(tx.Statement.Selects, arg)
|
||||
case []string:
|
||||
tx.Statement.Selects = append(tx.Statement.Selects, arg...)
|
||||
default:
|
||||
tx.Statement.AddClause(clause.Select{
|
||||
Distinct: db.Statement.Distinct,
|
||||
Expression: clause.Expr{SQL: v, Vars: args},
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
delete(tx.Statement.Clauses, "SELECT")
|
||||
}
|
||||
default:
|
||||
tx.AddError(fmt.Errorf("unsupported select args %v %v", query, args))
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Omit specify fields that you want to ignore when creating, updating and querying
|
||||
func (db *DB) Omit(columns ...string) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
|
||||
if len(columns) == 1 && strings.ContainsRune(columns[0], ',') {
|
||||
tx.Statement.Omits = strings.FieldsFunc(columns[0], utils.IsValidDBNameChar)
|
||||
} else {
|
||||
tx.Statement.Omits = columns
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Where add conditions
|
||||
func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 {
|
||||
tx.Statement.AddClause(clause.Where{Exprs: conds})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Not add NOT conditions
|
||||
func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 {
|
||||
tx.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.Not(conds...)}})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Or add OR conditions
|
||||
func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 {
|
||||
tx.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(clause.And(conds...))}})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Joins specify Joins conditions
|
||||
// db.Joins("Account").Find(&user)
|
||||
// db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user)
|
||||
func (db *DB) Joins(query string, args ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args})
|
||||
return
|
||||
}
|
||||
|
||||
// Group specify the group method on the find
|
||||
func (db *DB) Group(name string) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
|
||||
fields := strings.FieldsFunc(name, utils.IsValidDBNameChar)
|
||||
tx.Statement.AddClause(clause.GroupBy{
|
||||
Columns: []clause.Column{{Name: name, Raw: len(fields) != 1}},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Having specify HAVING conditions for GROUP BY
|
||||
func (db *DB) Having(query interface{}, args ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.AddClause(clause.GroupBy{
|
||||
Having: tx.Statement.BuildCondition(query, args...),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Order specify order when retrieve records from database
|
||||
// db.Order("name DESC")
|
||||
// db.Order(clause.OrderByColumn{Column: clause.Column{Name: "name"}, Desc: true})
|
||||
func (db *DB) Order(value interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
|
||||
switch v := value.(type) {
|
||||
case clause.OrderByColumn:
|
||||
tx.Statement.AddClause(clause.OrderBy{
|
||||
Columns: []clause.OrderByColumn{v},
|
||||
})
|
||||
default:
|
||||
tx.Statement.AddClause(clause.OrderBy{
|
||||
Columns: []clause.OrderByColumn{{
|
||||
Column: clause.Column{Name: fmt.Sprint(value), Raw: true},
|
||||
}},
|
||||
})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Limit specify the number of records to be retrieved
|
||||
func (db *DB) Limit(limit int) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.AddClause(clause.Limit{Limit: limit})
|
||||
return
|
||||
}
|
||||
|
||||
// Offset specify the number of records to skip before starting to return the records
|
||||
func (db *DB) Offset(offset int) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.AddClause(clause.Limit{Offset: offset})
|
||||
return
|
||||
}
|
||||
|
||||
// Scopes pass current database connection to arguments `func(DB) DB`, which could be used to add conditions dynamically
|
||||
// func AmountGreaterThan1000(db *gorm.DB) *gorm.DB {
|
||||
// return db.Where("amount > ?", 1000)
|
||||
// }
|
||||
//
|
||||
// func OrderStatus(status []string) func (db *gorm.DB) *gorm.DB {
|
||||
// return func (db *gorm.DB) *gorm.DB {
|
||||
// return db.Scopes(AmountGreaterThan1000).Where("status in (?)", status)
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders)
|
||||
func (db *DB) Scopes(funcs ...func(*DB) *DB) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.scopes = append(tx.Statement.scopes, funcs...)
|
||||
return tx
|
||||
}
|
||||
|
||||
// Preload preload associations with given conditions
|
||||
// db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users)
|
||||
func (db *DB) Preload(query string, args ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
if tx.Statement.Preloads == nil {
|
||||
tx.Statement.Preloads = map[string][]interface{}{}
|
||||
}
|
||||
tx.Statement.Preloads[query] = args
|
||||
return
|
||||
}
|
||||
|
||||
func (db *DB) Attrs(attrs ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.attrs = attrs
|
||||
return
|
||||
}
|
||||
|
||||
func (db *DB) Assign(attrs ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.assigns = attrs
|
||||
return
|
||||
}
|
||||
|
||||
func (db *DB) Unscoped() (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.Unscoped = true
|
||||
return
|
||||
}
|
||||
|
||||
func (db *DB) Raw(sql string, values ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.SQL = strings.Builder{}
|
||||
|
||||
if strings.Contains(sql, "@") {
|
||||
clause.NamedExpr{SQL: sql, Vars: values}.Build(tx.Statement)
|
||||
} else {
|
||||
clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement)
|
||||
}
|
||||
return
|
||||
}
|
||||
56
clause/benchmarks_test.go
Normal file
56
clause/benchmarks_test.go
Normal file
@ -0,0 +1,56 @@
|
||||
package clause_test
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/schema"
|
||||
"gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
func BenchmarkSelect(b *testing.B) {
|
||||
user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
stmt := gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}}
|
||||
clauses := []clause.Interface{clause.Select{}, clause.From{}, clause.Where{Exprs: []clause.Expression{clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}, clause.Or(clause.Neq{Column: "name", Value: "jinzhu"})}}}
|
||||
|
||||
for _, clause := range clauses {
|
||||
stmt.AddClause(clause)
|
||||
}
|
||||
|
||||
stmt.Build("SELECT", "FROM", "WHERE")
|
||||
_ = stmt.SQL.String()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkComplexSelect(b *testing.B) {
|
||||
user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
stmt := gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}}
|
||||
clauses := []clause.Interface{
|
||||
clause.Select{}, clause.From{},
|
||||
clause.Where{Exprs: []clause.Expression{
|
||||
clause.Eq{Column: clause.PrimaryColumn, Value: "1"},
|
||||
clause.Gt{Column: "age", Value: 18},
|
||||
clause.Or(clause.Neq{Column: "name", Value: "jinzhu"}),
|
||||
}},
|
||||
clause.Where{Exprs: []clause.Expression{
|
||||
clause.Or(clause.Gt{Column: "score", Value: 100}, clause.Like{Column: "name", Value: "%linus%"}),
|
||||
}},
|
||||
clause.GroupBy{Columns: []clause.Column{{Name: "role"}}, Having: []clause.Expression{clause.Eq{"role", "admin"}}},
|
||||
clause.Limit{Limit: 10, Offset: 20},
|
||||
clause.OrderBy{Columns: []clause.OrderByColumn{{Column: clause.PrimaryColumn, Desc: true}}},
|
||||
}
|
||||
|
||||
for _, clause := range clauses {
|
||||
stmt.AddClause(clause)
|
||||
}
|
||||
|
||||
stmt.Build("SELECT", "FROM", "WHERE", "GROUP BY", "LIMIT", "ORDER BY")
|
||||
_ = stmt.SQL.String()
|
||||
}
|
||||
}
|
||||
88
clause/clause.go
Normal file
88
clause/clause.go
Normal file
@ -0,0 +1,88 @@
|
||||
package clause
|
||||
|
||||
// Interface clause interface
|
||||
type Interface interface {
|
||||
Name() string
|
||||
Build(Builder)
|
||||
MergeClause(*Clause)
|
||||
}
|
||||
|
||||
// ClauseBuilder clause builder, allows to customize how to build clause
|
||||
type ClauseBuilder func(Clause, Builder)
|
||||
|
||||
type Writer interface {
|
||||
WriteByte(byte) error
|
||||
WriteString(string) (int, error)
|
||||
}
|
||||
|
||||
// Builder builder interface
|
||||
type Builder interface {
|
||||
Writer
|
||||
WriteQuoted(field interface{})
|
||||
AddVar(Writer, ...interface{})
|
||||
}
|
||||
|
||||
// Clause
|
||||
type Clause struct {
|
||||
Name string // WHERE
|
||||
BeforeExpression Expression
|
||||
AfterNameExpression Expression
|
||||
AfterExpression Expression
|
||||
Expression Expression
|
||||
Builder ClauseBuilder
|
||||
}
|
||||
|
||||
// Build build clause
|
||||
func (c Clause) Build(builder Builder) {
|
||||
if c.Builder != nil {
|
||||
c.Builder(c, builder)
|
||||
} else if c.Expression != nil {
|
||||
if c.BeforeExpression != nil {
|
||||
c.BeforeExpression.Build(builder)
|
||||
builder.WriteByte(' ')
|
||||
}
|
||||
|
||||
if c.Name != "" {
|
||||
builder.WriteString(c.Name)
|
||||
builder.WriteByte(' ')
|
||||
}
|
||||
|
||||
if c.AfterNameExpression != nil {
|
||||
c.AfterNameExpression.Build(builder)
|
||||
builder.WriteByte(' ')
|
||||
}
|
||||
|
||||
c.Expression.Build(builder)
|
||||
|
||||
if c.AfterExpression != nil {
|
||||
builder.WriteByte(' ')
|
||||
c.AfterExpression.Build(builder)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
PrimaryKey string = "~~~py~~~" // primary key
|
||||
CurrentTable string = "~~~ct~~~" // current table
|
||||
Associations string = "~~~as~~~" // associations
|
||||
)
|
||||
|
||||
var (
|
||||
currentTable = Table{Name: CurrentTable}
|
||||
PrimaryColumn = Column{Table: CurrentTable, Name: PrimaryKey}
|
||||
)
|
||||
|
||||
// Column quote with name
|
||||
type Column struct {
|
||||
Table string
|
||||
Name string
|
||||
Alias string
|
||||
Raw bool
|
||||
}
|
||||
|
||||
// Table quote with name
|
||||
type Table struct {
|
||||
Name string
|
||||
Alias string
|
||||
Raw bool
|
||||
}
|
||||
43
clause/clause_test.go
Normal file
43
clause/clause_test.go
Normal file
@ -0,0 +1,43 @@
|
||||
package clause_test
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/schema"
|
||||
"gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
var db, _ = gorm.Open(tests.DummyDialector{}, nil)
|
||||
|
||||
func checkBuildClauses(t *testing.T, clauses []clause.Interface, result string, vars []interface{}) {
|
||||
var (
|
||||
buildNames []string
|
||||
buildNamesMap = map[string]bool{}
|
||||
user, _ = schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
|
||||
stmt = gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}}
|
||||
)
|
||||
|
||||
for _, c := range clauses {
|
||||
if _, ok := buildNamesMap[c.Name()]; !ok {
|
||||
buildNames = append(buildNames, c.Name())
|
||||
buildNamesMap[c.Name()] = true
|
||||
}
|
||||
|
||||
stmt.AddClause(c)
|
||||
}
|
||||
|
||||
stmt.Build(buildNames...)
|
||||
|
||||
if strings.TrimSpace(stmt.SQL.String()) != result {
|
||||
t.Errorf("SQL expects %v got %v", result, stmt.SQL.String())
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(stmt.Vars, vars) {
|
||||
t.Errorf("Vars expects %+v got %v", stmt.Vars, vars)
|
||||
}
|
||||
}
|
||||
23
clause/delete.go
Normal file
23
clause/delete.go
Normal file
@ -0,0 +1,23 @@
|
||||
package clause
|
||||
|
||||
type Delete struct {
|
||||
Modifier string
|
||||
}
|
||||
|
||||
func (d Delete) Name() string {
|
||||
return "DELETE"
|
||||
}
|
||||
|
||||
func (d Delete) Build(builder Builder) {
|
||||
builder.WriteString("DELETE")
|
||||
|
||||
if d.Modifier != "" {
|
||||
builder.WriteByte(' ')
|
||||
builder.WriteString(d.Modifier)
|
||||
}
|
||||
}
|
||||
|
||||
func (d Delete) MergeClause(clause *Clause) {
|
||||
clause.Name = ""
|
||||
clause.Expression = d
|
||||
}
|
||||
31
clause/delete_test.go
Normal file
31
clause/delete_test.go
Normal file
@ -0,0 +1,31 @@
|
||||
package clause_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
func TestDelete(t *testing.T) {
|
||||
results := []struct {
|
||||
Clauses []clause.Interface
|
||||
Result string
|
||||
Vars []interface{}
|
||||
}{
|
||||
{
|
||||
[]clause.Interface{clause.Delete{}, clause.From{}},
|
||||
"DELETE FROM `users`", nil,
|
||||
},
|
||||
{
|
||||
[]clause.Interface{clause.Delete{Modifier: "LOW_PRIORITY"}, clause.From{}},
|
||||
"DELETE LOW_PRIORITY FROM `users`", nil,
|
||||
},
|
||||
}
|
||||
|
||||
for idx, result := range results {
|
||||
t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) {
|
||||
checkBuildClauses(t, result.Clauses, result.Result, result.Vars)
|
||||
})
|
||||
}
|
||||
}
|
||||
344
clause/expression.go
Normal file
344
clause/expression.go
Normal file
@ -0,0 +1,344 @@
|
||||
package clause
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"go/ast"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
// Expression expression interface
|
||||
type Expression interface {
|
||||
Build(builder Builder)
|
||||
}
|
||||
|
||||
// NegationExpressionBuilder negation expression builder
|
||||
type NegationExpressionBuilder interface {
|
||||
NegationBuild(builder Builder)
|
||||
}
|
||||
|
||||
// Expr raw expression
|
||||
type Expr struct {
|
||||
SQL string
|
||||
Vars []interface{}
|
||||
WithoutParentheses bool
|
||||
}
|
||||
|
||||
// Build build raw expression
|
||||
func (expr Expr) Build(builder Builder) {
|
||||
var (
|
||||
afterParenthesis bool
|
||||
idx int
|
||||
)
|
||||
|
||||
for _, v := range []byte(expr.SQL) {
|
||||
if v == '?' && len(expr.Vars) > idx {
|
||||
if afterParenthesis || expr.WithoutParentheses {
|
||||
if _, ok := expr.Vars[idx].(driver.Valuer); ok {
|
||||
builder.AddVar(builder, expr.Vars[idx])
|
||||
} else {
|
||||
switch rv := reflect.ValueOf(expr.Vars[idx]); rv.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
if rv.Len() == 0 {
|
||||
builder.AddVar(builder, nil)
|
||||
} else {
|
||||
for i := 0; i < rv.Len(); i++ {
|
||||
if i > 0 {
|
||||
builder.WriteByte(',')
|
||||
}
|
||||
builder.AddVar(builder, rv.Index(i).Interface())
|
||||
}
|
||||
}
|
||||
default:
|
||||
builder.AddVar(builder, expr.Vars[idx])
|
||||
}
|
||||
}
|
||||
} else {
|
||||
builder.AddVar(builder, expr.Vars[idx])
|
||||
}
|
||||
|
||||
idx++
|
||||
} else {
|
||||
if v == '(' {
|
||||
afterParenthesis = true
|
||||
} else {
|
||||
afterParenthesis = false
|
||||
}
|
||||
builder.WriteByte(v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// NamedExpr raw expression for named expr
|
||||
type NamedExpr struct {
|
||||
SQL string
|
||||
Vars []interface{}
|
||||
}
|
||||
|
||||
// Build build raw expression
|
||||
func (expr NamedExpr) Build(builder Builder) {
|
||||
var (
|
||||
idx int
|
||||
inName bool
|
||||
afterParenthesis bool
|
||||
namedMap = make(map[string]interface{}, len(expr.Vars))
|
||||
)
|
||||
|
||||
for _, v := range expr.Vars {
|
||||
switch value := v.(type) {
|
||||
case sql.NamedArg:
|
||||
namedMap[value.Name] = value.Value
|
||||
case map[string]interface{}:
|
||||
for k, v := range value {
|
||||
namedMap[k] = v
|
||||
}
|
||||
default:
|
||||
var appendFieldsToMap func(reflect.Value)
|
||||
appendFieldsToMap = func(reflectValue reflect.Value) {
|
||||
reflectValue = reflect.Indirect(reflectValue)
|
||||
switch reflectValue.Kind() {
|
||||
case reflect.Struct:
|
||||
modelType := reflectValue.Type()
|
||||
for i := 0; i < modelType.NumField(); i++ {
|
||||
if fieldStruct := modelType.Field(i); ast.IsExported(fieldStruct.Name) {
|
||||
namedMap[fieldStruct.Name] = reflectValue.Field(i).Interface()
|
||||
|
||||
if fieldStruct.Anonymous {
|
||||
appendFieldsToMap(reflectValue.Field(i))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
appendFieldsToMap(reflect.ValueOf(value))
|
||||
}
|
||||
}
|
||||
|
||||
name := make([]byte, 0, 10)
|
||||
|
||||
for _, v := range []byte(expr.SQL) {
|
||||
if v == '@' && !inName {
|
||||
inName = true
|
||||
name = []byte{}
|
||||
} else if v == ' ' || v == ',' || v == ')' || v == '"' || v == '\'' || v == '`' || v == '\n' {
|
||||
if inName {
|
||||
if nv, ok := namedMap[string(name)]; ok {
|
||||
builder.AddVar(builder, nv)
|
||||
} else {
|
||||
builder.WriteByte('@')
|
||||
builder.WriteString(string(name))
|
||||
}
|
||||
inName = false
|
||||
}
|
||||
|
||||
afterParenthesis = false
|
||||
builder.WriteByte(v)
|
||||
} else if v == '?' && len(expr.Vars) > idx {
|
||||
if afterParenthesis {
|
||||
if _, ok := expr.Vars[idx].(driver.Valuer); ok {
|
||||
builder.AddVar(builder, expr.Vars[idx])
|
||||
} else {
|
||||
switch rv := reflect.ValueOf(expr.Vars[idx]); rv.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
if rv.Len() == 0 {
|
||||
builder.AddVar(builder, nil)
|
||||
} else {
|
||||
for i := 0; i < rv.Len(); i++ {
|
||||
if i > 0 {
|
||||
builder.WriteByte(',')
|
||||
}
|
||||
builder.AddVar(builder, rv.Index(i).Interface())
|
||||
}
|
||||
}
|
||||
default:
|
||||
builder.AddVar(builder, expr.Vars[idx])
|
||||
}
|
||||
}
|
||||
} else {
|
||||
builder.AddVar(builder, expr.Vars[idx])
|
||||
}
|
||||
|
||||
idx++
|
||||
} else if inName {
|
||||
name = append(name, v)
|
||||
} else {
|
||||
if v == '(' {
|
||||
afterParenthesis = true
|
||||
} else {
|
||||
afterParenthesis = false
|
||||
}
|
||||
builder.WriteByte(v)
|
||||
}
|
||||
}
|
||||
|
||||
if inName {
|
||||
builder.AddVar(builder, namedMap[string(name)])
|
||||
}
|
||||
}
|
||||
|
||||
// IN Whether a value is within a set of values
|
||||
type IN struct {
|
||||
Column interface{}
|
||||
Values []interface{}
|
||||
}
|
||||
|
||||
func (in IN) Build(builder Builder) {
|
||||
builder.WriteQuoted(in.Column)
|
||||
|
||||
switch len(in.Values) {
|
||||
case 0:
|
||||
builder.WriteString(" IN (NULL)")
|
||||
case 1:
|
||||
if _, ok := in.Values[0].([]interface{}); !ok {
|
||||
builder.WriteString(" = ")
|
||||
builder.AddVar(builder, in.Values[0])
|
||||
break
|
||||
}
|
||||
|
||||
fallthrough
|
||||
default:
|
||||
builder.WriteString(" IN (")
|
||||
builder.AddVar(builder, in.Values...)
|
||||
builder.WriteByte(')')
|
||||
}
|
||||
}
|
||||
|
||||
func (in IN) NegationBuild(builder Builder) {
|
||||
switch len(in.Values) {
|
||||
case 0:
|
||||
case 1:
|
||||
if _, ok := in.Values[0].([]interface{}); !ok {
|
||||
builder.WriteQuoted(in.Column)
|
||||
builder.WriteString(" <> ")
|
||||
builder.AddVar(builder, in.Values[0])
|
||||
break
|
||||
}
|
||||
|
||||
fallthrough
|
||||
default:
|
||||
builder.WriteQuoted(in.Column)
|
||||
builder.WriteString(" NOT IN (")
|
||||
builder.AddVar(builder, in.Values...)
|
||||
builder.WriteByte(')')
|
||||
}
|
||||
}
|
||||
|
||||
// Eq equal to for where
|
||||
type Eq struct {
|
||||
Column interface{}
|
||||
Value interface{}
|
||||
}
|
||||
|
||||
func (eq Eq) Build(builder Builder) {
|
||||
builder.WriteQuoted(eq.Column)
|
||||
|
||||
if eqNil(eq.Value) {
|
||||
builder.WriteString(" IS NULL")
|
||||
} else {
|
||||
builder.WriteString(" = ")
|
||||
builder.AddVar(builder, eq.Value)
|
||||
}
|
||||
}
|
||||
|
||||
func (eq Eq) NegationBuild(builder Builder) {
|
||||
Neq(eq).Build(builder)
|
||||
}
|
||||
|
||||
// Neq not equal to for where
|
||||
type Neq Eq
|
||||
|
||||
func (neq Neq) Build(builder Builder) {
|
||||
builder.WriteQuoted(neq.Column)
|
||||
|
||||
if eqNil(neq.Value) {
|
||||
builder.WriteString(" IS NOT NULL")
|
||||
} else {
|
||||
builder.WriteString(" <> ")
|
||||
builder.AddVar(builder, neq.Value)
|
||||
}
|
||||
}
|
||||
|
||||
func (neq Neq) NegationBuild(builder Builder) {
|
||||
Eq(neq).Build(builder)
|
||||
}
|
||||
|
||||
// Gt greater than for where
|
||||
type Gt Eq
|
||||
|
||||
func (gt Gt) Build(builder Builder) {
|
||||
builder.WriteQuoted(gt.Column)
|
||||
builder.WriteString(" > ")
|
||||
builder.AddVar(builder, gt.Value)
|
||||
}
|
||||
|
||||
func (gt Gt) NegationBuild(builder Builder) {
|
||||
Lte(gt).Build(builder)
|
||||
}
|
||||
|
||||
// Gte greater than or equal to for where
|
||||
type Gte Eq
|
||||
|
||||
func (gte Gte) Build(builder Builder) {
|
||||
builder.WriteQuoted(gte.Column)
|
||||
builder.WriteString(" >= ")
|
||||
builder.AddVar(builder, gte.Value)
|
||||
}
|
||||
|
||||
func (gte Gte) NegationBuild(builder Builder) {
|
||||
Lt(gte).Build(builder)
|
||||
}
|
||||
|
||||
// Lt less than for where
|
||||
type Lt Eq
|
||||
|
||||
func (lt Lt) Build(builder Builder) {
|
||||
builder.WriteQuoted(lt.Column)
|
||||
builder.WriteString(" < ")
|
||||
builder.AddVar(builder, lt.Value)
|
||||
}
|
||||
|
||||
func (lt Lt) NegationBuild(builder Builder) {
|
||||
Gte(lt).Build(builder)
|
||||
}
|
||||
|
||||
// Lte less than or equal to for where
|
||||
type Lte Eq
|
||||
|
||||
func (lte Lte) Build(builder Builder) {
|
||||
builder.WriteQuoted(lte.Column)
|
||||
builder.WriteString(" <= ")
|
||||
builder.AddVar(builder, lte.Value)
|
||||
}
|
||||
|
||||
func (lte Lte) NegationBuild(builder Builder) {
|
||||
Gt(lte).Build(builder)
|
||||
}
|
||||
|
||||
// Like whether string matches regular expression
|
||||
type Like Eq
|
||||
|
||||
func (like Like) Build(builder Builder) {
|
||||
builder.WriteQuoted(like.Column)
|
||||
builder.WriteString(" LIKE ")
|
||||
builder.AddVar(builder, like.Value)
|
||||
}
|
||||
|
||||
func (like Like) NegationBuild(builder Builder) {
|
||||
builder.WriteQuoted(like.Column)
|
||||
builder.WriteString(" NOT LIKE ")
|
||||
builder.AddVar(builder, like.Value)
|
||||
}
|
||||
|
||||
func eqNil(value interface{}) bool {
|
||||
if valuer, ok := value.(driver.Valuer); ok {
|
||||
value, _ = valuer.Value()
|
||||
}
|
||||
|
||||
return value == nil || eqNilReflect(value)
|
||||
}
|
||||
|
||||
func eqNilReflect(value interface{}) bool {
|
||||
reflectValue := reflect.ValueOf(value)
|
||||
return reflectValue.Kind() == reflect.Ptr && reflectValue.IsNil()
|
||||
}
|
||||
153
clause/expression_test.go
Normal file
153
clause/expression_test.go
Normal file
@ -0,0 +1,153 @@
|
||||
package clause_test
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/schema"
|
||||
"gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
func TestExpr(t *testing.T) {
|
||||
results := []struct {
|
||||
SQL string
|
||||
Result string
|
||||
Vars []interface{}
|
||||
}{{
|
||||
SQL: "create table ? (? ?, ? ?)",
|
||||
Vars: []interface{}{clause.Table{Name: "users"}, clause.Column{Name: "id"}, clause.Expr{SQL: "int"}, clause.Column{Name: "name"}, clause.Expr{SQL: "text"}},
|
||||
Result: "create table `users` (`id` int, `name` text)",
|
||||
}}
|
||||
|
||||
for idx, result := range results {
|
||||
t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) {
|
||||
user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
|
||||
stmt := &gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}}
|
||||
clause.Expr{SQL: result.SQL, Vars: result.Vars}.Build(stmt)
|
||||
if stmt.SQL.String() != result.Result {
|
||||
t.Errorf("generated SQL is not equal, expects %v, but got %v", result.Result, stmt.SQL.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNamedExpr(t *testing.T) {
|
||||
type Base struct {
|
||||
Name2 string
|
||||
}
|
||||
|
||||
type NamedArgument struct {
|
||||
Name1 string
|
||||
Base
|
||||
}
|
||||
|
||||
results := []struct {
|
||||
SQL string
|
||||
Result string
|
||||
Vars []interface{}
|
||||
ExpectedVars []interface{}
|
||||
}{{
|
||||
SQL: "create table ? (? ?, ? ?)",
|
||||
Vars: []interface{}{clause.Table{Name: "users"}, clause.Column{Name: "id"}, clause.Expr{SQL: "int"}, clause.Column{Name: "name"}, clause.Expr{SQL: "text"}},
|
||||
Result: "create table `users` (`id` int, `name` text)",
|
||||
}, {
|
||||
SQL: "name1 = @name AND name2 = @name",
|
||||
Vars: []interface{}{sql.Named("name", "jinzhu")},
|
||||
Result: "name1 = ? AND name2 = ?",
|
||||
ExpectedVars: []interface{}{"jinzhu", "jinzhu"},
|
||||
}, {
|
||||
SQL: "name1 = @name1 AND name2 = @name2 AND name3 = @name1",
|
||||
Vars: []interface{}{sql.Named("name1", "jinzhu"), sql.Named("name2", "jinzhu2")},
|
||||
Result: "name1 = ? AND name2 = ? AND name3 = ?",
|
||||
ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu"},
|
||||
}, {
|
||||
SQL: "name1 = @name1 AND name2 = @name2 AND name3 = @name1",
|
||||
Vars: []interface{}{map[string]interface{}{"name1": "jinzhu", "name2": "jinzhu2"}},
|
||||
Result: "name1 = ? AND name2 = ? AND name3 = ?",
|
||||
ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu"},
|
||||
}, {
|
||||
SQL: "@@test AND name1 = @name1 AND name2 = @name2 AND name3 = @name1 @notexist",
|
||||
Vars: []interface{}{sql.Named("name1", "jinzhu"), sql.Named("name2", "jinzhu2")},
|
||||
Result: "@@test AND name1 = ? AND name2 = ? AND name3 = ? ?",
|
||||
ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu", nil},
|
||||
}, {
|
||||
SQL: "@@test AND name1 = @Name1 AND name2 = @Name2 AND name3 = @Name1 @Notexist",
|
||||
Vars: []interface{}{NamedArgument{Name1: "jinzhu", Base: Base{Name2: "jinzhu2"}}},
|
||||
Result: "@@test AND name1 = ? AND name2 = ? AND name3 = ? ?",
|
||||
ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu", nil},
|
||||
}, {
|
||||
SQL: "create table ? (? ?, ? ?)",
|
||||
Vars: []interface{}{},
|
||||
Result: "create table ? (? ?, ? ?)",
|
||||
}}
|
||||
|
||||
for idx, result := range results {
|
||||
t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) {
|
||||
user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
|
||||
stmt := &gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}}
|
||||
clause.NamedExpr{SQL: result.SQL, Vars: result.Vars}.Build(stmt)
|
||||
if stmt.SQL.String() != result.Result {
|
||||
t.Errorf("generated SQL is not equal, expects %v, but got %v", result.Result, stmt.SQL.String())
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(result.ExpectedVars, stmt.Vars) {
|
||||
t.Errorf("generated vars is not equal, expects %v, but got %v", result.ExpectedVars, stmt.Vars)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpression(t *testing.T) {
|
||||
column := "column-name"
|
||||
results := []struct {
|
||||
Expressions []clause.Expression
|
||||
Result string
|
||||
}{{
|
||||
Expressions: []clause.Expression{
|
||||
clause.Eq{Column: column, Value: "column-value"},
|
||||
},
|
||||
Result: "`column-name` = ?",
|
||||
}, {
|
||||
Expressions: []clause.Expression{
|
||||
clause.Eq{Column: column, Value: nil},
|
||||
clause.Eq{Column: column, Value: (*string)(nil)},
|
||||
clause.Eq{Column: column, Value: (*int)(nil)},
|
||||
clause.Eq{Column: column, Value: (*bool)(nil)},
|
||||
clause.Eq{Column: column, Value: (interface{})(nil)},
|
||||
clause.Eq{Column: column, Value: sql.NullString{String: "", Valid: false}},
|
||||
},
|
||||
Result: "`column-name` IS NULL",
|
||||
}, {
|
||||
Expressions: []clause.Expression{
|
||||
clause.Neq{Column: column, Value: "column-value"},
|
||||
},
|
||||
Result: "`column-name` <> ?",
|
||||
}, {
|
||||
Expressions: []clause.Expression{
|
||||
clause.Neq{Column: column, Value: nil},
|
||||
clause.Neq{Column: column, Value: (*string)(nil)},
|
||||
clause.Neq{Column: column, Value: (*int)(nil)},
|
||||
clause.Neq{Column: column, Value: (*bool)(nil)},
|
||||
clause.Neq{Column: column, Value: (interface{})(nil)},
|
||||
},
|
||||
Result: "`column-name` IS NOT NULL",
|
||||
}}
|
||||
|
||||
for idx, result := range results {
|
||||
for idy, expression := range result.Expressions {
|
||||
t.Run(fmt.Sprintf("case #%v.%v", idx, idy), func(t *testing.T) {
|
||||
user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
|
||||
stmt := &gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}}
|
||||
expression.Build(stmt)
|
||||
if stmt.SQL.String() != result.Result {
|
||||
t.Errorf("generated SQL is not equal, expects %v, but got %v", result.Result, stmt.SQL.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
37
clause/from.go
Normal file
37
clause/from.go
Normal file
@ -0,0 +1,37 @@
|
||||
package clause
|
||||
|
||||
// From from clause
|
||||
type From struct {
|
||||
Tables []Table
|
||||
Joins []Join
|
||||
}
|
||||
|
||||
// Name from clause name
|
||||
func (from From) Name() string {
|
||||
return "FROM"
|
||||
}
|
||||
|
||||
// Build build from clause
|
||||
func (from From) Build(builder Builder) {
|
||||
if len(from.Tables) > 0 {
|
||||
for idx, table := range from.Tables {
|
||||
if idx > 0 {
|
||||
builder.WriteByte(',')
|
||||
}
|
||||
|
||||
builder.WriteQuoted(table)
|
||||
}
|
||||
} else {
|
||||
builder.WriteQuoted(currentTable)
|
||||
}
|
||||
|
||||
for _, join := range from.Joins {
|
||||
builder.WriteByte(' ')
|
||||
join.Build(builder)
|
||||
}
|
||||
}
|
||||
|
||||
// MergeClause merge from clause
|
||||
func (from From) MergeClause(clause *Clause) {
|
||||
clause.Expression = from
|
||||
}
|
||||
75
clause/from_test.go
Normal file
75
clause/from_test.go
Normal file
@ -0,0 +1,75 @@
|
||||
package clause_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
func TestFrom(t *testing.T) {
|
||||
results := []struct {
|
||||
Clauses []clause.Interface
|
||||
Result string
|
||||
Vars []interface{}
|
||||
}{
|
||||
{
|
||||
[]clause.Interface{clause.Select{}, clause.From{}},
|
||||
"SELECT * FROM `users`", nil,
|
||||
},
|
||||
{
|
||||
[]clause.Interface{
|
||||
clause.Select{}, clause.From{
|
||||
Tables: []clause.Table{{Name: "users"}},
|
||||
Joins: []clause.Join{
|
||||
{
|
||||
Type: clause.InnerJoin,
|
||||
Table: clause.Table{Name: "articles"},
|
||||
ON: clause.Where{
|
||||
[]clause.Expression{clause.Eq{clause.Column{Table: "articles", Name: "id"}, clause.PrimaryColumn}},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"SELECT * FROM `users` INNER JOIN `articles` ON `articles`.`id` = `users`.`id`", nil,
|
||||
},
|
||||
{
|
||||
[]clause.Interface{
|
||||
clause.Select{}, clause.From{
|
||||
Tables: []clause.Table{{Name: "users"}},
|
||||
Joins: []clause.Join{
|
||||
{
|
||||
Type: clause.RightJoin,
|
||||
Table: clause.Table{Name: "profiles"},
|
||||
ON: clause.Where{
|
||||
[]clause.Expression{clause.Eq{clause.Column{Table: "profiles", Name: "email"}, clause.Column{Table: clause.CurrentTable, Name: "email"}}},
|
||||
},
|
||||
},
|
||||
},
|
||||
}, clause.From{
|
||||
Joins: []clause.Join{
|
||||
{
|
||||
Type: clause.InnerJoin,
|
||||
Table: clause.Table{Name: "articles"},
|
||||
ON: clause.Where{
|
||||
[]clause.Expression{clause.Eq{clause.Column{Table: "articles", Name: "id"}, clause.PrimaryColumn}},
|
||||
},
|
||||
}, {
|
||||
Type: clause.LeftJoin,
|
||||
Table: clause.Table{Name: "companies"},
|
||||
Using: []string{"company_name"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"SELECT * FROM `users` INNER JOIN `articles` ON `articles`.`id` = `users`.`id` LEFT JOIN `companies` USING (`company_name`)", nil,
|
||||
},
|
||||
}
|
||||
|
||||
for idx, result := range results {
|
||||
t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) {
|
||||
checkBuildClauses(t, result.Clauses, result.Result, result.Vars)
|
||||
})
|
||||
}
|
||||
}
|
||||
48
clause/group_by.go
Normal file
48
clause/group_by.go
Normal file
@ -0,0 +1,48 @@
|
||||
package clause
|
||||
|
||||
// GroupBy group by clause
|
||||
type GroupBy struct {
|
||||
Columns []Column
|
||||
Having []Expression
|
||||
}
|
||||
|
||||
// Name from clause name
|
||||
func (groupBy GroupBy) Name() string {
|
||||
return "GROUP BY"
|
||||
}
|
||||
|
||||
// Build build group by clause
|
||||
func (groupBy GroupBy) Build(builder Builder) {
|
||||
for idx, column := range groupBy.Columns {
|
||||
if idx > 0 {
|
||||
builder.WriteByte(',')
|
||||
}
|
||||
|
||||
builder.WriteQuoted(column)
|
||||
}
|
||||
|
||||
if len(groupBy.Having) > 0 {
|
||||
builder.WriteString(" HAVING ")
|
||||
Where{Exprs: groupBy.Having}.Build(builder)
|
||||
}
|
||||
}
|
||||
|
||||
// MergeClause merge group by clause
|
||||
func (groupBy GroupBy) MergeClause(clause *Clause) {
|
||||
if v, ok := clause.Expression.(GroupBy); ok {
|
||||
copiedColumns := make([]Column, len(v.Columns))
|
||||
copy(copiedColumns, v.Columns)
|
||||
groupBy.Columns = append(copiedColumns, groupBy.Columns...)
|
||||
|
||||
copiedHaving := make([]Expression, len(v.Having))
|
||||
copy(copiedHaving, v.Having)
|
||||
groupBy.Having = append(copiedHaving, groupBy.Having...)
|
||||
}
|
||||
clause.Expression = groupBy
|
||||
|
||||
if len(groupBy.Columns) == 0 {
|
||||
clause.Name = ""
|
||||
} else {
|
||||
clause.Name = groupBy.Name()
|
||||
}
|
||||
}
|
||||
40
clause/group_by_test.go
Normal file
40
clause/group_by_test.go
Normal file
@ -0,0 +1,40 @@
|
||||
package clause_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
func TestGroupBy(t *testing.T) {
|
||||
results := []struct {
|
||||
Clauses []clause.Interface
|
||||
Result string
|
||||
Vars []interface{}
|
||||
}{
|
||||
{
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.GroupBy{
|
||||
Columns: []clause.Column{{Name: "role"}},
|
||||
Having: []clause.Expression{clause.Eq{"role", "admin"}},
|
||||
}},
|
||||
"SELECT * FROM `users` GROUP BY `role` HAVING `role` = ?", []interface{}{"admin"},
|
||||
},
|
||||
{
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.GroupBy{
|
||||
Columns: []clause.Column{{Name: "role"}},
|
||||
Having: []clause.Expression{clause.Eq{"role", "admin"}},
|
||||
}, clause.GroupBy{
|
||||
Columns: []clause.Column{{Name: "gender"}},
|
||||
Having: []clause.Expression{clause.Neq{"gender", "U"}},
|
||||
}},
|
||||
"SELECT * FROM `users` GROUP BY `role`,`gender` HAVING `role` = ? AND `gender` <> ?", []interface{}{"admin", "U"},
|
||||
},
|
||||
}
|
||||
|
||||
for idx, result := range results {
|
||||
t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) {
|
||||
checkBuildClauses(t, result.Clauses, result.Result, result.Vars)
|
||||
})
|
||||
}
|
||||
}
|
||||
39
clause/insert.go
Normal file
39
clause/insert.go
Normal file
@ -0,0 +1,39 @@
|
||||
package clause
|
||||
|
||||
type Insert struct {
|
||||
Table Table
|
||||
Modifier string
|
||||
}
|
||||
|
||||
// Name insert clause name
|
||||
func (insert Insert) Name() string {
|
||||
return "INSERT"
|
||||
}
|
||||
|
||||
// Build build insert clause
|
||||
func (insert Insert) Build(builder Builder) {
|
||||
if insert.Modifier != "" {
|
||||
builder.WriteString(insert.Modifier)
|
||||
builder.WriteByte(' ')
|
||||
}
|
||||
|
||||
builder.WriteString("INTO ")
|
||||
if insert.Table.Name == "" {
|
||||
builder.WriteQuoted(currentTable)
|
||||
} else {
|
||||
builder.WriteQuoted(insert.Table)
|
||||
}
|
||||
}
|
||||
|
||||
// MergeClause merge insert clause
|
||||
func (insert Insert) MergeClause(clause *Clause) {
|
||||
if v, ok := clause.Expression.(Insert); ok {
|
||||
if insert.Modifier == "" {
|
||||
insert.Modifier = v.Modifier
|
||||
}
|
||||
if insert.Table.Name == "" {
|
||||
insert.Table = v.Table
|
||||
}
|
||||
}
|
||||
clause.Expression = insert
|
||||
}
|
||||
35
clause/insert_test.go
Normal file
35
clause/insert_test.go
Normal file
@ -0,0 +1,35 @@
|
||||
package clause_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
func TestInsert(t *testing.T) {
|
||||
results := []struct {
|
||||
Clauses []clause.Interface
|
||||
Result string
|
||||
Vars []interface{}
|
||||
}{
|
||||
{
|
||||
[]clause.Interface{clause.Insert{}},
|
||||
"INSERT INTO `users`", nil,
|
||||
},
|
||||
{
|
||||
[]clause.Interface{clause.Insert{Modifier: "LOW_PRIORITY"}},
|
||||
"INSERT LOW_PRIORITY INTO `users`", nil,
|
||||
},
|
||||
{
|
||||
[]clause.Interface{clause.Insert{Table: clause.Table{Name: "products"}, Modifier: "LOW_PRIORITY"}},
|
||||
"INSERT LOW_PRIORITY INTO `products`", nil,
|
||||
},
|
||||
}
|
||||
|
||||
for idx, result := range results {
|
||||
t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) {
|
||||
checkBuildClauses(t, result.Clauses, result.Result, result.Vars)
|
||||
})
|
||||
}
|
||||
}
|
||||
47
clause/joins.go
Normal file
47
clause/joins.go
Normal file
@ -0,0 +1,47 @@
|
||||
package clause
|
||||
|
||||
type JoinType string
|
||||
|
||||
const (
|
||||
CrossJoin JoinType = "CROSS"
|
||||
InnerJoin JoinType = "INNER"
|
||||
LeftJoin JoinType = "LEFT"
|
||||
RightJoin JoinType = "RIGHT"
|
||||
)
|
||||
|
||||
// Join join clause for from
|
||||
type Join struct {
|
||||
Type JoinType
|
||||
Table Table
|
||||
ON Where
|
||||
Using []string
|
||||
Expression Expression
|
||||
}
|
||||
|
||||
func (join Join) Build(builder Builder) {
|
||||
if join.Expression != nil {
|
||||
join.Expression.Build(builder)
|
||||
} else {
|
||||
if join.Type != "" {
|
||||
builder.WriteString(string(join.Type))
|
||||
builder.WriteByte(' ')
|
||||
}
|
||||
|
||||
builder.WriteString("JOIN ")
|
||||
builder.WriteQuoted(join.Table)
|
||||
|
||||
if len(join.ON.Exprs) > 0 {
|
||||
builder.WriteString(" ON ")
|
||||
join.ON.Build(builder)
|
||||
} else if len(join.Using) > 0 {
|
||||
builder.WriteString(" USING (")
|
||||
for idx, c := range join.Using {
|
||||
if idx > 0 {
|
||||
builder.WriteByte(',')
|
||||
}
|
||||
builder.WriteQuoted(c)
|
||||
}
|
||||
builder.WriteByte(')')
|
||||
}
|
||||
}
|
||||
}
|
||||
48
clause/limit.go
Normal file
48
clause/limit.go
Normal file
@ -0,0 +1,48 @@
|
||||
package clause
|
||||
|
||||
import "strconv"
|
||||
|
||||
// Limit limit clause
|
||||
type Limit struct {
|
||||
Limit int
|
||||
Offset int
|
||||
}
|
||||
|
||||
// Name where clause name
|
||||
func (limit Limit) Name() string {
|
||||
return "LIMIT"
|
||||
}
|
||||
|
||||
// Build build where clause
|
||||
func (limit Limit) Build(builder Builder) {
|
||||
if limit.Limit > 0 {
|
||||
builder.WriteString("LIMIT ")
|
||||
builder.WriteString(strconv.Itoa(limit.Limit))
|
||||
}
|
||||
if limit.Offset > 0 {
|
||||
if limit.Limit > 0 {
|
||||
builder.WriteString(" ")
|
||||
}
|
||||
builder.WriteString("OFFSET ")
|
||||
builder.WriteString(strconv.Itoa(limit.Offset))
|
||||
}
|
||||
}
|
||||
|
||||
// MergeClause merge order by clauses
|
||||
func (limit Limit) MergeClause(clause *Clause) {
|
||||
clause.Name = ""
|
||||
|
||||
if v, ok := clause.Expression.(Limit); ok {
|
||||
if limit.Limit == 0 && v.Limit != 0 {
|
||||
limit.Limit = v.Limit
|
||||
}
|
||||
|
||||
if limit.Offset == 0 && v.Offset > 0 {
|
||||
limit.Offset = v.Offset
|
||||
} else if limit.Offset < 0 {
|
||||
limit.Offset = 0
|
||||
}
|
||||
}
|
||||
|
||||
clause.Expression = limit
|
||||
}
|
||||
58
clause/limit_test.go
Normal file
58
clause/limit_test.go
Normal file
@ -0,0 +1,58 @@
|
||||
package clause_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
func TestLimit(t *testing.T) {
|
||||
results := []struct {
|
||||
Clauses []clause.Interface
|
||||
Result string
|
||||
Vars []interface{}
|
||||
}{
|
||||
{
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{
|
||||
Limit: 10,
|
||||
Offset: 20,
|
||||
}},
|
||||
"SELECT * FROM `users` LIMIT 10 OFFSET 20", nil,
|
||||
},
|
||||
{
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}},
|
||||
"SELECT * FROM `users` OFFSET 20", nil,
|
||||
},
|
||||
{
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}, clause.Limit{Offset: 30}},
|
||||
"SELECT * FROM `users` OFFSET 30", nil,
|
||||
},
|
||||
{
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}, clause.Limit{Limit: 10}},
|
||||
"SELECT * FROM `users` LIMIT 10 OFFSET 20", nil,
|
||||
},
|
||||
{
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}},
|
||||
"SELECT * FROM `users` LIMIT 10 OFFSET 30", nil,
|
||||
},
|
||||
{
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Offset: -10}},
|
||||
"SELECT * FROM `users` LIMIT 10", nil,
|
||||
},
|
||||
{
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: -10}},
|
||||
"SELECT * FROM `users` OFFSET 30", nil,
|
||||
},
|
||||
{
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: 50}},
|
||||
"SELECT * FROM `users` LIMIT 50 OFFSET 30", nil,
|
||||
},
|
||||
}
|
||||
|
||||
for idx, result := range results {
|
||||
t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) {
|
||||
checkBuildClauses(t, result.Clauses, result.Result, result.Vars)
|
||||
})
|
||||
}
|
||||
}
|
||||
31
clause/locking.go
Normal file
31
clause/locking.go
Normal file
@ -0,0 +1,31 @@
|
||||
package clause
|
||||
|
||||
type Locking struct {
|
||||
Strength string
|
||||
Table Table
|
||||
Options string
|
||||
}
|
||||
|
||||
// Name where clause name
|
||||
func (locking Locking) Name() string {
|
||||
return "FOR"
|
||||
}
|
||||
|
||||
// Build build where clause
|
||||
func (locking Locking) Build(builder Builder) {
|
||||
builder.WriteString(locking.Strength)
|
||||
if locking.Table.Name != "" {
|
||||
builder.WriteString(" OF ")
|
||||
builder.WriteQuoted(locking.Table)
|
||||
}
|
||||
|
||||
if locking.Options != "" {
|
||||
builder.WriteByte(' ')
|
||||
builder.WriteString(locking.Options)
|
||||
}
|
||||
}
|
||||
|
||||
// MergeClause merge order by clauses
|
||||
func (locking Locking) MergeClause(clause *Clause) {
|
||||
clause.Expression = locking
|
||||
}
|
||||
35
clause/locking_test.go
Normal file
35
clause/locking_test.go
Normal file
@ -0,0 +1,35 @@
|
||||
package clause_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
func TestLocking(t *testing.T) {
|
||||
results := []struct {
|
||||
Clauses []clause.Interface
|
||||
Result string
|
||||
Vars []interface{}
|
||||
}{
|
||||
{
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: "UPDATE"}},
|
||||
"SELECT * FROM `users` FOR UPDATE", nil,
|
||||
},
|
||||
{
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: "SHARE", Table: clause.Table{Name: clause.CurrentTable}}},
|
||||
"SELECT * FROM `users` FOR SHARE OF `users`", nil,
|
||||
},
|
||||
{
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: "UPDATE"}, clause.Locking{Strength: "UPDATE", Options: "NOWAIT"}},
|
||||
"SELECT * FROM `users` FOR UPDATE NOWAIT", nil,
|
||||
},
|
||||
}
|
||||
|
||||
for idx, result := range results {
|
||||
t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) {
|
||||
checkBuildClauses(t, result.Clauses, result.Result, result.Vars)
|
||||
})
|
||||
}
|
||||
}
|
||||
52
clause/on_conflict.go
Normal file
52
clause/on_conflict.go
Normal file
@ -0,0 +1,52 @@
|
||||
package clause
|
||||
|
||||
type OnConflict struct {
|
||||
Columns []Column
|
||||
Where Where
|
||||
OnConstraint string
|
||||
DoNothing bool
|
||||
DoUpdates Set
|
||||
UpdateAll bool
|
||||
}
|
||||
|
||||
func (OnConflict) Name() string {
|
||||
return "ON CONFLICT"
|
||||
}
|
||||
|
||||
// Build build onConflict clause
|
||||
func (onConflict OnConflict) Build(builder Builder) {
|
||||
if len(onConflict.Columns) > 0 {
|
||||
builder.WriteByte('(')
|
||||
for idx, column := range onConflict.Columns {
|
||||
if idx > 0 {
|
||||
builder.WriteByte(',')
|
||||
}
|
||||
builder.WriteQuoted(column)
|
||||
}
|
||||
builder.WriteString(`) `)
|
||||
}
|
||||
|
||||
if onConflict.OnConstraint != "" {
|
||||
builder.WriteString("ON CONSTRAINT ")
|
||||
builder.WriteString(onConflict.OnConstraint)
|
||||
builder.WriteByte(' ')
|
||||
}
|
||||
|
||||
if onConflict.DoNothing {
|
||||
builder.WriteString("DO NOTHING")
|
||||
} else {
|
||||
builder.WriteString("DO UPDATE SET ")
|
||||
onConflict.DoUpdates.Build(builder)
|
||||
}
|
||||
|
||||
if len(onConflict.Where.Exprs) > 0 {
|
||||
builder.WriteString(" WHERE ")
|
||||
onConflict.Where.Build(builder)
|
||||
builder.WriteByte(' ')
|
||||
}
|
||||
}
|
||||
|
||||
// MergeClause merge onConflict clauses
|
||||
func (onConflict OnConflict) MergeClause(clause *Clause) {
|
||||
clause.Expression = onConflict
|
||||
}
|
||||
54
clause/order_by.go
Normal file
54
clause/order_by.go
Normal file
@ -0,0 +1,54 @@
|
||||
package clause
|
||||
|
||||
type OrderByColumn struct {
|
||||
Column Column
|
||||
Desc bool
|
||||
Reorder bool
|
||||
}
|
||||
|
||||
type OrderBy struct {
|
||||
Columns []OrderByColumn
|
||||
Expression Expression
|
||||
}
|
||||
|
||||
// Name where clause name
|
||||
func (orderBy OrderBy) Name() string {
|
||||
return "ORDER BY"
|
||||
}
|
||||
|
||||
// Build build where clause
|
||||
func (orderBy OrderBy) Build(builder Builder) {
|
||||
if orderBy.Expression != nil {
|
||||
orderBy.Expression.Build(builder)
|
||||
} else {
|
||||
for idx, column := range orderBy.Columns {
|
||||
if idx > 0 {
|
||||
builder.WriteByte(',')
|
||||
}
|
||||
|
||||
builder.WriteQuoted(column.Column)
|
||||
if column.Desc {
|
||||
builder.WriteString(" DESC")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// MergeClause merge order by clauses
|
||||
func (orderBy OrderBy) MergeClause(clause *Clause) {
|
||||
if v, ok := clause.Expression.(OrderBy); ok {
|
||||
for i := len(orderBy.Columns) - 1; i >= 0; i-- {
|
||||
if orderBy.Columns[i].Reorder {
|
||||
orderBy.Columns = orderBy.Columns[i:]
|
||||
clause.Expression = orderBy
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
copiedColumns := make([]OrderByColumn, len(v.Columns))
|
||||
copy(copiedColumns, v.Columns)
|
||||
orderBy.Columns = append(copiedColumns, orderBy.Columns...)
|
||||
}
|
||||
|
||||
clause.Expression = orderBy
|
||||
}
|
||||
57
clause/order_by_test.go
Normal file
57
clause/order_by_test.go
Normal file
@ -0,0 +1,57 @@
|
||||
package clause_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
func TestOrderBy(t *testing.T) {
|
||||
results := []struct {
|
||||
Clauses []clause.Interface
|
||||
Result string
|
||||
Vars []interface{}
|
||||
}{
|
||||
{
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.OrderBy{
|
||||
Columns: []clause.OrderByColumn{{Column: clause.PrimaryColumn, Desc: true}},
|
||||
}},
|
||||
"SELECT * FROM `users` ORDER BY `users`.`id` DESC", nil,
|
||||
},
|
||||
{
|
||||
[]clause.Interface{
|
||||
clause.Select{}, clause.From{}, clause.OrderBy{
|
||||
Columns: []clause.OrderByColumn{{Column: clause.PrimaryColumn, Desc: true}},
|
||||
}, clause.OrderBy{
|
||||
Columns: []clause.OrderByColumn{{Column: clause.Column{Name: "name"}}},
|
||||
},
|
||||
},
|
||||
"SELECT * FROM `users` ORDER BY `users`.`id` DESC,`name`", nil,
|
||||
},
|
||||
{
|
||||
[]clause.Interface{
|
||||
clause.Select{}, clause.From{}, clause.OrderBy{
|
||||
Columns: []clause.OrderByColumn{{Column: clause.PrimaryColumn, Desc: true}},
|
||||
}, clause.OrderBy{
|
||||
Columns: []clause.OrderByColumn{{Column: clause.Column{Name: "name"}, Reorder: true}},
|
||||
},
|
||||
},
|
||||
"SELECT * FROM `users` ORDER BY `name`", nil,
|
||||
},
|
||||
{
|
||||
[]clause.Interface{
|
||||
clause.Select{}, clause.From{}, clause.OrderBy{
|
||||
Expression: clause.Expr{SQL: "FIELD(id, ?)", Vars: []interface{}{[]int{1, 2, 3}}, WithoutParentheses: true},
|
||||
},
|
||||
},
|
||||
"SELECT * FROM `users` ORDER BY FIELD(id, ?,?,?)", []interface{}{1, 2, 3},
|
||||
},
|
||||
}
|
||||
|
||||
for idx, result := range results {
|
||||
t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) {
|
||||
checkBuildClauses(t, result.Clauses, result.Result, result.Vars)
|
||||
})
|
||||
}
|
||||
}
|
||||
30
clause/returning.go
Normal file
30
clause/returning.go
Normal file
@ -0,0 +1,30 @@
|
||||
package clause
|
||||
|
||||
type Returning struct {
|
||||
Columns []Column
|
||||
}
|
||||
|
||||
// Name where clause name
|
||||
func (returning Returning) Name() string {
|
||||
return "RETURNING"
|
||||
}
|
||||
|
||||
// Build build where clause
|
||||
func (returning Returning) Build(builder Builder) {
|
||||
for idx, column := range returning.Columns {
|
||||
if idx > 0 {
|
||||
builder.WriteByte(',')
|
||||
}
|
||||
|
||||
builder.WriteQuoted(column)
|
||||
}
|
||||
}
|
||||
|
||||
// MergeClause merge order by clauses
|
||||
func (returning Returning) MergeClause(clause *Clause) {
|
||||
if v, ok := clause.Expression.(Returning); ok {
|
||||
returning.Columns = append(v.Columns, returning.Columns...)
|
||||
}
|
||||
|
||||
clause.Expression = returning
|
||||
}
|
||||
36
clause/returning_test.go
Normal file
36
clause/returning_test.go
Normal file
@ -0,0 +1,36 @@
|
||||
package clause_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
func TestReturning(t *testing.T) {
|
||||
results := []struct {
|
||||
Clauses []clause.Interface
|
||||
Result string
|
||||
Vars []interface{}
|
||||
}{
|
||||
{
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Returning{
|
||||
[]clause.Column{clause.PrimaryColumn},
|
||||
}},
|
||||
"SELECT * FROM `users` RETURNING `users`.`id`", nil,
|
||||
}, {
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Returning{
|
||||
[]clause.Column{clause.PrimaryColumn},
|
||||
}, clause.Returning{
|
||||
[]clause.Column{{Name: "name"}, {Name: "age"}},
|
||||
}},
|
||||
"SELECT * FROM `users` RETURNING `users`.`id`,`name`,`age`", nil,
|
||||
},
|
||||
}
|
||||
|
||||
for idx, result := range results {
|
||||
t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) {
|
||||
checkBuildClauses(t, result.Clauses, result.Result, result.Vars)
|
||||
})
|
||||
}
|
||||
}
|
||||
45
clause/select.go
Normal file
45
clause/select.go
Normal file
@ -0,0 +1,45 @@
|
||||
package clause
|
||||
|
||||
// Select select attrs when querying, updating, creating
|
||||
type Select struct {
|
||||
Distinct bool
|
||||
Columns []Column
|
||||
Expression Expression
|
||||
}
|
||||
|
||||
func (s Select) Name() string {
|
||||
return "SELECT"
|
||||
}
|
||||
|
||||
func (s Select) Build(builder Builder) {
|
||||
if len(s.Columns) > 0 {
|
||||
if s.Distinct {
|
||||
builder.WriteString("DISTINCT ")
|
||||
}
|
||||
|
||||
for idx, column := range s.Columns {
|
||||
if idx > 0 {
|
||||
builder.WriteByte(',')
|
||||
}
|
||||
builder.WriteQuoted(column)
|
||||
}
|
||||
} else {
|
||||
builder.WriteByte('*')
|
||||
}
|
||||
}
|
||||
|
||||
func (s Select) MergeClause(clause *Clause) {
|
||||
if s.Expression != nil {
|
||||
if s.Distinct {
|
||||
if expr, ok := s.Expression.(Expr); ok {
|
||||
expr.SQL = "DISTINCT " + expr.SQL
|
||||
clause.Expression = expr
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
clause.Expression = s.Expression
|
||||
} else {
|
||||
clause.Expression = s
|
||||
}
|
||||
}
|
||||
41
clause/select_test.go
Normal file
41
clause/select_test.go
Normal file
@ -0,0 +1,41 @@
|
||||
package clause_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
func TestSelect(t *testing.T) {
|
||||
results := []struct {
|
||||
Clauses []clause.Interface
|
||||
Result string
|
||||
Vars []interface{}
|
||||
}{
|
||||
{
|
||||
[]clause.Interface{clause.Select{}, clause.From{}},
|
||||
"SELECT * FROM `users`", nil,
|
||||
},
|
||||
{
|
||||
[]clause.Interface{clause.Select{
|
||||
Columns: []clause.Column{clause.PrimaryColumn},
|
||||
}, clause.From{}},
|
||||
"SELECT `users`.`id` FROM `users`", nil,
|
||||
},
|
||||
{
|
||||
[]clause.Interface{clause.Select{
|
||||
Columns: []clause.Column{clause.PrimaryColumn},
|
||||
}, clause.Select{
|
||||
Columns: []clause.Column{{Name: "name"}},
|
||||
}, clause.From{}},
|
||||
"SELECT `name` FROM `users`", nil,
|
||||
},
|
||||
}
|
||||
|
||||
for idx, result := range results {
|
||||
t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) {
|
||||
checkBuildClauses(t, result.Clauses, result.Result, result.Vars)
|
||||
})
|
||||
}
|
||||
}
|
||||
60
clause/set.go
Normal file
60
clause/set.go
Normal file
@ -0,0 +1,60 @@
|
||||
package clause
|
||||
|
||||
import "sort"
|
||||
|
||||
type Set []Assignment
|
||||
|
||||
type Assignment struct {
|
||||
Column Column
|
||||
Value interface{}
|
||||
}
|
||||
|
||||
func (set Set) Name() string {
|
||||
return "SET"
|
||||
}
|
||||
|
||||
func (set Set) Build(builder Builder) {
|
||||
if len(set) > 0 {
|
||||
for idx, assignment := range set {
|
||||
if idx > 0 {
|
||||
builder.WriteByte(',')
|
||||
}
|
||||
builder.WriteQuoted(assignment.Column)
|
||||
builder.WriteByte('=')
|
||||
builder.AddVar(builder, assignment.Value)
|
||||
}
|
||||
} else {
|
||||
builder.WriteQuoted(PrimaryColumn)
|
||||
builder.WriteByte('=')
|
||||
builder.WriteQuoted(PrimaryColumn)
|
||||
}
|
||||
}
|
||||
|
||||
// MergeClause merge assignments clauses
|
||||
func (set Set) MergeClause(clause *Clause) {
|
||||
copiedAssignments := make([]Assignment, len(set))
|
||||
copy(copiedAssignments, set)
|
||||
clause.Expression = Set(copiedAssignments)
|
||||
}
|
||||
|
||||
func Assignments(values map[string]interface{}) Set {
|
||||
keys := make([]string, 0, len(values))
|
||||
for key := range values {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
assignments := make([]Assignment, len(keys))
|
||||
for idx, key := range keys {
|
||||
assignments[idx] = Assignment{Column: Column{Name: key}, Value: values[key]}
|
||||
}
|
||||
return assignments
|
||||
}
|
||||
|
||||
func AssignmentColumns(values []string) Set {
|
||||
assignments := make([]Assignment, len(values))
|
||||
for idx, value := range values {
|
||||
assignments[idx] = Assignment{Column: Column{Name: value}, Value: Column{Table: "excluded", Name: value}}
|
||||
}
|
||||
return assignments
|
||||
}
|
||||
57
clause/set_test.go
Normal file
57
clause/set_test.go
Normal file
@ -0,0 +1,57 @@
|
||||
package clause_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
func TestSet(t *testing.T) {
|
||||
results := []struct {
|
||||
Clauses []clause.Interface
|
||||
Result string
|
||||
Vars []interface{}
|
||||
}{
|
||||
{
|
||||
[]clause.Interface{
|
||||
clause.Update{},
|
||||
clause.Set([]clause.Assignment{{clause.PrimaryColumn, 1}}),
|
||||
},
|
||||
"UPDATE `users` SET `users`.`id`=?", []interface{}{1},
|
||||
},
|
||||
{
|
||||
[]clause.Interface{
|
||||
clause.Update{},
|
||||
clause.Set([]clause.Assignment{{clause.PrimaryColumn, 1}}),
|
||||
clause.Set([]clause.Assignment{{clause.Column{Name: "name"}, "jinzhu"}}),
|
||||
},
|
||||
"UPDATE `users` SET `name`=?", []interface{}{"jinzhu"},
|
||||
},
|
||||
}
|
||||
|
||||
for idx, result := range results {
|
||||
t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) {
|
||||
checkBuildClauses(t, result.Clauses, result.Result, result.Vars)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAssignments(t *testing.T) {
|
||||
set := clause.Assignments(map[string]interface{}{
|
||||
"name": "jinzhu",
|
||||
"age": 18,
|
||||
})
|
||||
|
||||
assignments := []clause.Assignment(set)
|
||||
|
||||
sort.Slice(assignments, func(i, j int) bool {
|
||||
return strings.Compare(assignments[i].Column.Name, assignments[j].Column.Name) > 0
|
||||
})
|
||||
|
||||
if len(assignments) != 2 || assignments[0].Column.Name != "name" || assignments[0].Value.(string) != "jinzhu" || assignments[1].Column.Name != "age" || assignments[1].Value.(int) != 18 {
|
||||
t.Errorf("invalid assignments, got %v", assignments)
|
||||
}
|
||||
}
|
||||
38
clause/update.go
Normal file
38
clause/update.go
Normal file
@ -0,0 +1,38 @@
|
||||
package clause
|
||||
|
||||
type Update struct {
|
||||
Modifier string
|
||||
Table Table
|
||||
}
|
||||
|
||||
// Name update clause name
|
||||
func (update Update) Name() string {
|
||||
return "UPDATE"
|
||||
}
|
||||
|
||||
// Build build update clause
|
||||
func (update Update) Build(builder Builder) {
|
||||
if update.Modifier != "" {
|
||||
builder.WriteString(update.Modifier)
|
||||
builder.WriteByte(' ')
|
||||
}
|
||||
|
||||
if update.Table.Name == "" {
|
||||
builder.WriteQuoted(currentTable)
|
||||
} else {
|
||||
builder.WriteQuoted(update.Table)
|
||||
}
|
||||
}
|
||||
|
||||
// MergeClause merge update clause
|
||||
func (update Update) MergeClause(clause *Clause) {
|
||||
if v, ok := clause.Expression.(Update); ok {
|
||||
if update.Modifier == "" {
|
||||
update.Modifier = v.Modifier
|
||||
}
|
||||
if update.Table.Name == "" {
|
||||
update.Table = v.Table
|
||||
}
|
||||
}
|
||||
clause.Expression = update
|
||||
}
|
||||
35
clause/update_test.go
Normal file
35
clause/update_test.go
Normal file
@ -0,0 +1,35 @@
|
||||
package clause_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
func TestUpdate(t *testing.T) {
|
||||
results := []struct {
|
||||
Clauses []clause.Interface
|
||||
Result string
|
||||
Vars []interface{}
|
||||
}{
|
||||
{
|
||||
[]clause.Interface{clause.Update{}},
|
||||
"UPDATE `users`", nil,
|
||||
},
|
||||
{
|
||||
[]clause.Interface{clause.Update{Modifier: "LOW_PRIORITY"}},
|
||||
"UPDATE LOW_PRIORITY `users`", nil,
|
||||
},
|
||||
{
|
||||
[]clause.Interface{clause.Update{Table: clause.Table{Name: "products"}, Modifier: "LOW_PRIORITY"}},
|
||||
"UPDATE LOW_PRIORITY `products`", nil,
|
||||
},
|
||||
}
|
||||
|
||||
for idx, result := range results {
|
||||
t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) {
|
||||
checkBuildClauses(t, result.Clauses, result.Result, result.Vars)
|
||||
})
|
||||
}
|
||||
}
|
||||
45
clause/values.go
Normal file
45
clause/values.go
Normal file
@ -0,0 +1,45 @@
|
||||
package clause
|
||||
|
||||
type Values struct {
|
||||
Columns []Column
|
||||
Values [][]interface{}
|
||||
}
|
||||
|
||||
// Name from clause name
|
||||
func (Values) Name() string {
|
||||
return "VALUES"
|
||||
}
|
||||
|
||||
// Build build from clause
|
||||
func (values Values) Build(builder Builder) {
|
||||
if len(values.Columns) > 0 {
|
||||
builder.WriteByte('(')
|
||||
for idx, column := range values.Columns {
|
||||
if idx > 0 {
|
||||
builder.WriteByte(',')
|
||||
}
|
||||
builder.WriteQuoted(column)
|
||||
}
|
||||
builder.WriteByte(')')
|
||||
|
||||
builder.WriteString(" VALUES ")
|
||||
|
||||
for idx, value := range values.Values {
|
||||
if idx > 0 {
|
||||
builder.WriteByte(',')
|
||||
}
|
||||
|
||||
builder.WriteByte('(')
|
||||
builder.AddVar(builder, value...)
|
||||
builder.WriteByte(')')
|
||||
}
|
||||
} else {
|
||||
builder.WriteString("DEFAULT VALUES")
|
||||
}
|
||||
}
|
||||
|
||||
// MergeClause merge values clauses
|
||||
func (values Values) MergeClause(clause *Clause) {
|
||||
clause.Name = ""
|
||||
clause.Expression = values
|
||||
}
|
||||
33
clause/values_test.go
Normal file
33
clause/values_test.go
Normal file
@ -0,0 +1,33 @@
|
||||
package clause_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
func TestValues(t *testing.T) {
|
||||
results := []struct {
|
||||
Clauses []clause.Interface
|
||||
Result string
|
||||
Vars []interface{}
|
||||
}{
|
||||
{
|
||||
[]clause.Interface{
|
||||
clause.Insert{},
|
||||
clause.Values{
|
||||
Columns: []clause.Column{{Name: "name"}, {Name: "age"}},
|
||||
Values: [][]interface{}{{"jinzhu", 18}, {"josh", 1}},
|
||||
},
|
||||
},
|
||||
"INSERT INTO `users` (`name`,`age`) VALUES (?,?),(?,?)", []interface{}{"jinzhu", 18, "josh", 1},
|
||||
},
|
||||
}
|
||||
|
||||
for idx, result := range results {
|
||||
t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) {
|
||||
checkBuildClauses(t, result.Clauses, result.Result, result.Vars)
|
||||
})
|
||||
}
|
||||
}
|
||||
177
clause/where.go
Normal file
177
clause/where.go
Normal file
@ -0,0 +1,177 @@
|
||||
package clause
|
||||
|
||||
import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Where where clause
|
||||
type Where struct {
|
||||
Exprs []Expression
|
||||
}
|
||||
|
||||
// Name where clause name
|
||||
func (where Where) Name() string {
|
||||
return "WHERE"
|
||||
}
|
||||
|
||||
// Build build where clause
|
||||
func (where Where) Build(builder Builder) {
|
||||
// Switch position if the first query expression is a single Or condition
|
||||
for idx, expr := range where.Exprs {
|
||||
if v, ok := expr.(OrConditions); !ok || len(v.Exprs) > 1 {
|
||||
if idx != 0 {
|
||||
where.Exprs[0], where.Exprs[idx] = where.Exprs[idx], where.Exprs[0]
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
buildExprs(where.Exprs, builder, " AND ")
|
||||
}
|
||||
|
||||
func buildExprs(exprs []Expression, builder Builder, joinCond string) {
|
||||
wrapInParentheses := false
|
||||
|
||||
for idx, expr := range exprs {
|
||||
if idx > 0 {
|
||||
if v, ok := expr.(OrConditions); ok && len(v.Exprs) == 1 {
|
||||
builder.WriteString(" OR ")
|
||||
} else {
|
||||
builder.WriteString(joinCond)
|
||||
}
|
||||
}
|
||||
|
||||
if len(exprs) > 1 {
|
||||
switch v := expr.(type) {
|
||||
case OrConditions:
|
||||
if len(v.Exprs) == 1 {
|
||||
if e, ok := v.Exprs[0].(Expr); ok {
|
||||
sql := strings.ToLower(e.SQL)
|
||||
wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or")
|
||||
}
|
||||
}
|
||||
case AndConditions:
|
||||
if len(v.Exprs) == 1 {
|
||||
if e, ok := v.Exprs[0].(Expr); ok {
|
||||
sql := strings.ToLower(e.SQL)
|
||||
wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or")
|
||||
}
|
||||
}
|
||||
case Expr:
|
||||
sql := strings.ToLower(v.SQL)
|
||||
wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or")
|
||||
}
|
||||
}
|
||||
|
||||
if wrapInParentheses {
|
||||
builder.WriteString(`(`)
|
||||
expr.Build(builder)
|
||||
builder.WriteString(`)`)
|
||||
wrapInParentheses = false
|
||||
} else {
|
||||
expr.Build(builder)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// MergeClause merge where clauses
|
||||
func (where Where) MergeClause(clause *Clause) {
|
||||
if w, ok := clause.Expression.(Where); ok {
|
||||
exprs := make([]Expression, len(w.Exprs)+len(where.Exprs))
|
||||
copy(exprs, w.Exprs)
|
||||
copy(exprs[len(w.Exprs):], where.Exprs)
|
||||
where.Exprs = exprs
|
||||
}
|
||||
|
||||
clause.Expression = where
|
||||
}
|
||||
|
||||
func And(exprs ...Expression) Expression {
|
||||
if len(exprs) == 0 {
|
||||
return nil
|
||||
} else if len(exprs) == 1 {
|
||||
return exprs[0]
|
||||
}
|
||||
return AndConditions{Exprs: exprs}
|
||||
}
|
||||
|
||||
type AndConditions struct {
|
||||
Exprs []Expression
|
||||
}
|
||||
|
||||
func (and AndConditions) Build(builder Builder) {
|
||||
if len(and.Exprs) > 1 {
|
||||
builder.WriteByte('(')
|
||||
buildExprs(and.Exprs, builder, " AND ")
|
||||
builder.WriteByte(')')
|
||||
} else {
|
||||
buildExprs(and.Exprs, builder, " AND ")
|
||||
}
|
||||
}
|
||||
|
||||
func Or(exprs ...Expression) Expression {
|
||||
if len(exprs) == 0 {
|
||||
return nil
|
||||
}
|
||||
return OrConditions{Exprs: exprs}
|
||||
}
|
||||
|
||||
type OrConditions struct {
|
||||
Exprs []Expression
|
||||
}
|
||||
|
||||
func (or OrConditions) Build(builder Builder) {
|
||||
if len(or.Exprs) > 1 {
|
||||
builder.WriteByte('(')
|
||||
buildExprs(or.Exprs, builder, " OR ")
|
||||
builder.WriteByte(')')
|
||||
} else {
|
||||
buildExprs(or.Exprs, builder, " OR ")
|
||||
}
|
||||
}
|
||||
|
||||
func Not(exprs ...Expression) Expression {
|
||||
if len(exprs) == 0 {
|
||||
return nil
|
||||
}
|
||||
return NotConditions{Exprs: exprs}
|
||||
}
|
||||
|
||||
type NotConditions struct {
|
||||
Exprs []Expression
|
||||
}
|
||||
|
||||
func (not NotConditions) Build(builder Builder) {
|
||||
if len(not.Exprs) > 1 {
|
||||
builder.WriteByte('(')
|
||||
}
|
||||
|
||||
for idx, c := range not.Exprs {
|
||||
if idx > 0 {
|
||||
builder.WriteString(" AND ")
|
||||
}
|
||||
|
||||
if negationBuilder, ok := c.(NegationExpressionBuilder); ok {
|
||||
negationBuilder.NegationBuild(builder)
|
||||
} else {
|
||||
builder.WriteString("NOT ")
|
||||
e, wrapInParentheses := c.(Expr)
|
||||
if wrapInParentheses {
|
||||
sql := strings.ToLower(e.SQL)
|
||||
if wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or"); wrapInParentheses {
|
||||
builder.WriteByte('(')
|
||||
}
|
||||
}
|
||||
|
||||
c.Build(builder)
|
||||
|
||||
if wrapInParentheses {
|
||||
builder.WriteByte(')')
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(not.Exprs) > 1 {
|
||||
builder.WriteByte(')')
|
||||
}
|
||||
}
|
||||
69
clause/where_test.go
Normal file
69
clause/where_test.go
Normal file
@ -0,0 +1,69 @@
|
||||
package clause_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
func TestWhere(t *testing.T) {
|
||||
results := []struct {
|
||||
Clauses []clause.Interface
|
||||
Result string
|
||||
Vars []interface{}
|
||||
}{
|
||||
{
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
|
||||
Exprs: []clause.Expression{clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}, clause.Or(clause.Neq{Column: "name", Value: "jinzhu"})},
|
||||
}},
|
||||
"SELECT * FROM `users` WHERE `users`.`id` = ? AND `age` > ? OR `name` <> ?", []interface{}{"1", 18, "jinzhu"},
|
||||
},
|
||||
{
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
|
||||
Exprs: []clause.Expression{clause.Or(clause.Neq{Column: "name", Value: "jinzhu"}), clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}},
|
||||
}},
|
||||
"SELECT * FROM `users` WHERE `users`.`id` = ? OR `name` <> ? AND `age` > ?", []interface{}{"1", "jinzhu", 18},
|
||||
},
|
||||
{
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
|
||||
Exprs: []clause.Expression{clause.Or(clause.Neq{Column: "name", Value: "jinzhu"}), clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}},
|
||||
}},
|
||||
"SELECT * FROM `users` WHERE `users`.`id` = ? OR `name` <> ? AND `age` > ?", []interface{}{"1", "jinzhu", 18},
|
||||
},
|
||||
{
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
|
||||
Exprs: []clause.Expression{clause.Or(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}), clause.Or(clause.Neq{Column: "name", Value: "jinzhu"})},
|
||||
}},
|
||||
"SELECT * FROM `users` WHERE `users`.`id` = ? OR `name` <> ?", []interface{}{"1", "jinzhu"},
|
||||
},
|
||||
{
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
|
||||
Exprs: []clause.Expression{clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}, clause.Or(clause.Neq{Column: "name", Value: "jinzhu"})},
|
||||
}, clause.Where{
|
||||
Exprs: []clause.Expression{clause.Or(clause.Gt{Column: "score", Value: 100}, clause.Like{Column: "name", Value: "%linus%"})},
|
||||
}},
|
||||
"SELECT * FROM `users` WHERE `users`.`id` = ? AND `age` > ? OR `name` <> ? AND (`score` > ? OR `name` LIKE ?)", []interface{}{"1", 18, "jinzhu", 100, "%linus%"},
|
||||
},
|
||||
{
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
|
||||
Exprs: []clause.Expression{clause.Not(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}), clause.Or(clause.Neq{Column: "name", Value: "jinzhu"})},
|
||||
}, clause.Where{
|
||||
Exprs: []clause.Expression{clause.Or(clause.Not(clause.Gt{Column: "score", Value: 100}), clause.Like{Column: "name", Value: "%linus%"})},
|
||||
}},
|
||||
"SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `age` <= ?) OR `name` <> ? AND (`score` <= ? OR `name` LIKE ?)", []interface{}{"1", 18, "jinzhu", 100, "%linus%"},
|
||||
},
|
||||
{
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
|
||||
Exprs: []clause.Expression{clause.And(clause.Eq{Column: "age", Value: 18}, clause.Or(clause.Neq{Column: "name", Value: "jinzhu"}))},
|
||||
}},
|
||||
"SELECT * FROM `users` WHERE (`age` = ? OR `name` <> ?)", []interface{}{18, "jinzhu"},
|
||||
},
|
||||
}
|
||||
|
||||
for idx, result := range results {
|
||||
t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) {
|
||||
checkBuildClauses(t, result.Clauses, result.Result, result.Vars)
|
||||
})
|
||||
}
|
||||
}
|
||||
4
clause/with.go
Normal file
4
clause/with.go
Normal file
@ -0,0 +1,4 @@
|
||||
package clause
|
||||
|
||||
type With struct {
|
||||
}
|
||||
218
create_test.go
218
create_test.go
@ -1,218 +0,0 @@
|
||||
package gorm_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jinzhu/now"
|
||||
)
|
||||
|
||||
func TestCreate(t *testing.T) {
|
||||
float := 35.03554004971999
|
||||
now := time.Now()
|
||||
user := User{Name: "CreateUser", Age: 18, Birthday: &now, UserNum: Num(111), PasswordHash: []byte{'f', 'a', 'k', '4'}, Latitude: float}
|
||||
|
||||
if !DB.NewRecord(user) || !DB.NewRecord(&user) {
|
||||
t.Error("User should be new record before create")
|
||||
}
|
||||
|
||||
if count := DB.Save(&user).RowsAffected; count != 1 {
|
||||
t.Error("There should be one record be affected when create record")
|
||||
}
|
||||
|
||||
if DB.NewRecord(user) || DB.NewRecord(&user) {
|
||||
t.Error("User should not new record after save")
|
||||
}
|
||||
|
||||
var newUser User
|
||||
DB.First(&newUser, user.Id)
|
||||
|
||||
if !reflect.DeepEqual(newUser.PasswordHash, []byte{'f', 'a', 'k', '4'}) {
|
||||
t.Errorf("User's PasswordHash should be saved ([]byte)")
|
||||
}
|
||||
|
||||
if newUser.Age != 18 {
|
||||
t.Errorf("User's Age should be saved (int)")
|
||||
}
|
||||
|
||||
if newUser.UserNum != Num(111) {
|
||||
t.Errorf("User's UserNum should be saved (custom type)")
|
||||
}
|
||||
|
||||
if newUser.Latitude != float {
|
||||
t.Errorf("Float64 should not be changed after save")
|
||||
}
|
||||
|
||||
if user.CreatedAt.IsZero() {
|
||||
t.Errorf("Should have created_at after create")
|
||||
}
|
||||
|
||||
if newUser.CreatedAt.IsZero() {
|
||||
t.Errorf("Should have created_at after create")
|
||||
}
|
||||
|
||||
DB.Model(user).Update("name", "create_user_new_name")
|
||||
DB.First(&user, user.Id)
|
||||
if user.CreatedAt.Format(time.RFC3339Nano) != newUser.CreatedAt.Format(time.RFC3339Nano) {
|
||||
t.Errorf("CreatedAt should not be changed after update")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateWithExistingTimestamp(t *testing.T) {
|
||||
user := User{Name: "CreateUserExistingTimestamp"}
|
||||
|
||||
timeA := now.MustParse("2016-01-01")
|
||||
user.CreatedAt = timeA
|
||||
user.UpdatedAt = timeA
|
||||
DB.Save(&user)
|
||||
|
||||
if user.CreatedAt.UTC().Format(time.RFC3339) != timeA.UTC().Format(time.RFC3339) {
|
||||
t.Errorf("CreatedAt should not be changed")
|
||||
}
|
||||
|
||||
if user.UpdatedAt.UTC().Format(time.RFC3339) != timeA.UTC().Format(time.RFC3339) {
|
||||
t.Errorf("UpdatedAt should not be changed")
|
||||
}
|
||||
|
||||
var newUser User
|
||||
DB.First(&newUser, user.Id)
|
||||
|
||||
if newUser.CreatedAt.UTC().Format(time.RFC3339) != timeA.UTC().Format(time.RFC3339) {
|
||||
t.Errorf("CreatedAt should not be changed")
|
||||
}
|
||||
|
||||
if newUser.UpdatedAt.UTC().Format(time.RFC3339) != timeA.UTC().Format(time.RFC3339) {
|
||||
t.Errorf("UpdatedAt should not be changed")
|
||||
}
|
||||
}
|
||||
|
||||
type AutoIncrementUser struct {
|
||||
User
|
||||
Sequence uint `gorm:"AUTO_INCREMENT"`
|
||||
}
|
||||
|
||||
func TestCreateWithAutoIncrement(t *testing.T) {
|
||||
if dialect := os.Getenv("GORM_DIALECT"); dialect != "postgres" {
|
||||
t.Skip("Skipping this because only postgres properly support auto_increment on a non-primary_key column")
|
||||
}
|
||||
|
||||
DB.AutoMigrate(&AutoIncrementUser{})
|
||||
|
||||
user1 := AutoIncrementUser{}
|
||||
user2 := AutoIncrementUser{}
|
||||
|
||||
DB.Create(&user1)
|
||||
DB.Create(&user2)
|
||||
|
||||
if user2.Sequence-user1.Sequence != 1 {
|
||||
t.Errorf("Auto increment should apply on Sequence")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateWithNoGORMPrimayKey(t *testing.T) {
|
||||
if dialect := os.Getenv("GORM_DIALECT"); dialect == "mssql" {
|
||||
t.Skip("Skipping this because MSSQL will return identity only if the table has an Id column")
|
||||
}
|
||||
|
||||
jt := JoinTable{From: 1, To: 2}
|
||||
err := DB.Create(&jt).Error
|
||||
if err != nil {
|
||||
t.Errorf("No error should happen when create a record without a GORM primary key. But in the database this primary key exists and is the union of 2 or more fields\n But got: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateWithNoStdPrimaryKeyAndDefaultValues(t *testing.T) {
|
||||
animal := Animal{Name: "Ferdinand"}
|
||||
if DB.Save(&animal).Error != nil {
|
||||
t.Errorf("No error should happen when create a record without std primary key")
|
||||
}
|
||||
|
||||
if animal.Counter == 0 {
|
||||
t.Errorf("No std primary key should be filled value after create")
|
||||
}
|
||||
|
||||
if animal.Name != "Ferdinand" {
|
||||
t.Errorf("Default value should be overrided")
|
||||
}
|
||||
|
||||
// Test create with default value not overrided
|
||||
an := Animal{From: "nerdz"}
|
||||
|
||||
if DB.Save(&an).Error != nil {
|
||||
t.Errorf("No error should happen when create an record without std primary key")
|
||||
}
|
||||
|
||||
// We must fetch the value again, to have the default fields updated
|
||||
// (We can't do this in the update statements, since sql default can be expressions
|
||||
// And be different from the fields' type (eg. a time.Time fields has a default value of "now()"
|
||||
DB.Model(Animal{}).Where(&Animal{Counter: an.Counter}).First(&an)
|
||||
|
||||
if an.Name != "galeone" {
|
||||
t.Errorf("Default value should fill the field. But got %v", an.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAnonymousScanner(t *testing.T) {
|
||||
user := User{Name: "anonymous_scanner", Role: Role{Name: "admin"}}
|
||||
DB.Save(&user)
|
||||
|
||||
var user2 User
|
||||
DB.First(&user2, "name = ?", "anonymous_scanner")
|
||||
if user2.Role.Name != "admin" {
|
||||
t.Errorf("Should be able to get anonymous scanner")
|
||||
}
|
||||
|
||||
if !user2.Role.IsAdmin() {
|
||||
t.Errorf("Should be able to get anonymous scanner")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAnonymousField(t *testing.T) {
|
||||
user := User{Name: "anonymous_field", Company: Company{Name: "company"}}
|
||||
DB.Save(&user)
|
||||
|
||||
var user2 User
|
||||
DB.First(&user2, "name = ?", "anonymous_field")
|
||||
DB.Model(&user2).Related(&user2.Company)
|
||||
if user2.Company.Name != "company" {
|
||||
t.Errorf("Should be able to get anonymous field")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectWithCreate(t *testing.T) {
|
||||
user := getPreparedUser("select_user", "select_with_create")
|
||||
DB.Select("Name", "BillingAddress", "CreditCard", "Company", "Emails").Create(user)
|
||||
|
||||
var queryuser User
|
||||
DB.Preload("BillingAddress").Preload("ShippingAddress").
|
||||
Preload("CreditCard").Preload("Emails").Preload("Company").First(&queryuser, user.Id)
|
||||
|
||||
if queryuser.Name != user.Name || queryuser.Age == user.Age {
|
||||
t.Errorf("Should only create users with name column")
|
||||
}
|
||||
|
||||
if queryuser.BillingAddressID.Int64 == 0 || queryuser.ShippingAddressId != 0 ||
|
||||
queryuser.CreditCard.ID == 0 || len(queryuser.Emails) == 0 {
|
||||
t.Errorf("Should only create selected relationships")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOmitWithCreate(t *testing.T) {
|
||||
user := getPreparedUser("omit_user", "omit_with_create")
|
||||
DB.Omit("Name", "BillingAddress", "CreditCard", "Company", "Emails").Create(user)
|
||||
|
||||
var queryuser User
|
||||
DB.Preload("BillingAddress").Preload("ShippingAddress").
|
||||
Preload("CreditCard").Preload("Emails").Preload("Company").First(&queryuser, user.Id)
|
||||
|
||||
if queryuser.Name == user.Name || queryuser.Age != user.Age {
|
||||
t.Errorf("Should only create users with age column")
|
||||
}
|
||||
|
||||
if queryuser.BillingAddressID.Int64 != 0 || queryuser.ShippingAddressId == 0 ||
|
||||
queryuser.CreditCard.ID != 0 || len(queryuser.Emails) != 0 {
|
||||
t.Errorf("Should not create omitted relationships")
|
||||
}
|
||||
}
|
||||
@ -1,303 +0,0 @@
|
||||
package gorm_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jinzhu/gorm"
|
||||
)
|
||||
|
||||
type CustomizeColumn struct {
|
||||
ID int64 `gorm:"column:mapped_id; primary_key:yes"`
|
||||
Name string `gorm:"column:mapped_name"`
|
||||
Date *time.Time `gorm:"column:mapped_time"`
|
||||
}
|
||||
|
||||
// Make sure an ignored field does not interfere with another field's custom
|
||||
// column name that matches the ignored field.
|
||||
type CustomColumnAndIgnoredFieldClash struct {
|
||||
Body string `sql:"-"`
|
||||
RawBody string `gorm:"column:body"`
|
||||
}
|
||||
|
||||
func TestCustomizeColumn(t *testing.T) {
|
||||
col := "mapped_name"
|
||||
DB.DropTable(&CustomizeColumn{})
|
||||
DB.AutoMigrate(&CustomizeColumn{})
|
||||
|
||||
scope := DB.NewScope(&CustomizeColumn{})
|
||||
if !scope.Dialect().HasColumn(scope.TableName(), col) {
|
||||
t.Errorf("CustomizeColumn should have column %s", col)
|
||||
}
|
||||
|
||||
col = "mapped_id"
|
||||
if scope.PrimaryKey() != col {
|
||||
t.Errorf("CustomizeColumn should have primary key %s, but got %q", col, scope.PrimaryKey())
|
||||
}
|
||||
|
||||
expected := "foo"
|
||||
now := time.Now()
|
||||
cc := CustomizeColumn{ID: 666, Name: expected, Date: &now}
|
||||
|
||||
if count := DB.Create(&cc).RowsAffected; count != 1 {
|
||||
t.Error("There should be one record be affected when create record")
|
||||
}
|
||||
|
||||
var cc1 CustomizeColumn
|
||||
DB.First(&cc1, 666)
|
||||
|
||||
if cc1.Name != expected {
|
||||
t.Errorf("Failed to query CustomizeColumn")
|
||||
}
|
||||
|
||||
cc.Name = "bar"
|
||||
DB.Save(&cc)
|
||||
|
||||
var cc2 CustomizeColumn
|
||||
DB.First(&cc2, 666)
|
||||
if cc2.Name != "bar" {
|
||||
t.Errorf("Failed to query CustomizeColumn")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCustomColumnAndIgnoredFieldClash(t *testing.T) {
|
||||
DB.DropTable(&CustomColumnAndIgnoredFieldClash{})
|
||||
if err := DB.AutoMigrate(&CustomColumnAndIgnoredFieldClash{}).Error; err != nil {
|
||||
t.Errorf("Should not raise error: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
type CustomizePerson struct {
|
||||
IdPerson string `gorm:"column:idPerson;primary_key:true"`
|
||||
Accounts []CustomizeAccount `gorm:"many2many:PersonAccount;associationforeignkey:idAccount;foreignkey:idPerson"`
|
||||
}
|
||||
|
||||
type CustomizeAccount struct {
|
||||
IdAccount string `gorm:"column:idAccount;primary_key:true"`
|
||||
Name string
|
||||
}
|
||||
|
||||
func TestManyToManyWithCustomizedColumn(t *testing.T) {
|
||||
DB.DropTable(&CustomizePerson{}, &CustomizeAccount{}, "PersonAccount")
|
||||
DB.AutoMigrate(&CustomizePerson{}, &CustomizeAccount{})
|
||||
|
||||
account := CustomizeAccount{IdAccount: "account", Name: "id1"}
|
||||
person := CustomizePerson{
|
||||
IdPerson: "person",
|
||||
Accounts: []CustomizeAccount{account},
|
||||
}
|
||||
|
||||
if err := DB.Create(&account).Error; err != nil {
|
||||
t.Errorf("no error should happen, but got %v", err)
|
||||
}
|
||||
|
||||
if err := DB.Create(&person).Error; err != nil {
|
||||
t.Errorf("no error should happen, but got %v", err)
|
||||
}
|
||||
|
||||
var person1 CustomizePerson
|
||||
scope := DB.NewScope(nil)
|
||||
if err := DB.Preload("Accounts").First(&person1, scope.Quote("idPerson")+" = ?", person.IdPerson).Error; err != nil {
|
||||
t.Errorf("no error should happen when preloading customized column many2many relations, but got %v", err)
|
||||
}
|
||||
|
||||
if len(person1.Accounts) != 1 || person1.Accounts[0].IdAccount != "account" {
|
||||
t.Errorf("should preload correct accounts")
|
||||
}
|
||||
}
|
||||
|
||||
type CustomizeUser struct {
|
||||
gorm.Model
|
||||
Email string `sql:"column:email_address"`
|
||||
}
|
||||
|
||||
type CustomizeInvitation struct {
|
||||
gorm.Model
|
||||
Address string `sql:"column:invitation"`
|
||||
Person *CustomizeUser `gorm:"foreignkey:Email;associationforeignkey:invitation"`
|
||||
}
|
||||
|
||||
func TestOneToOneWithCustomizedColumn(t *testing.T) {
|
||||
DB.DropTable(&CustomizeUser{}, &CustomizeInvitation{})
|
||||
DB.AutoMigrate(&CustomizeUser{}, &CustomizeInvitation{})
|
||||
|
||||
user := CustomizeUser{
|
||||
Email: "hello@example.com",
|
||||
}
|
||||
invitation := CustomizeInvitation{
|
||||
Address: "hello@example.com",
|
||||
}
|
||||
|
||||
DB.Create(&user)
|
||||
DB.Create(&invitation)
|
||||
|
||||
var invitation2 CustomizeInvitation
|
||||
if err := DB.Preload("Person").Find(&invitation2, invitation.ID).Error; err != nil {
|
||||
t.Errorf("no error should happen, but got %v", err)
|
||||
}
|
||||
|
||||
if invitation2.Person.Email != user.Email {
|
||||
t.Errorf("Should preload one to one relation with customize foreign keys")
|
||||
}
|
||||
}
|
||||
|
||||
type PromotionDiscount struct {
|
||||
gorm.Model
|
||||
Name string
|
||||
Coupons []*PromotionCoupon `gorm:"ForeignKey:discount_id"`
|
||||
Rule *PromotionRule `gorm:"ForeignKey:discount_id"`
|
||||
Benefits []PromotionBenefit `gorm:"ForeignKey:promotion_id"`
|
||||
}
|
||||
|
||||
type PromotionBenefit struct {
|
||||
gorm.Model
|
||||
Name string
|
||||
PromotionID uint
|
||||
Discount PromotionDiscount `gorm:"ForeignKey:promotion_id"`
|
||||
}
|
||||
|
||||
type PromotionCoupon struct {
|
||||
gorm.Model
|
||||
Code string
|
||||
DiscountID uint
|
||||
Discount PromotionDiscount
|
||||
}
|
||||
|
||||
type PromotionRule struct {
|
||||
gorm.Model
|
||||
Name string
|
||||
Begin *time.Time
|
||||
End *time.Time
|
||||
DiscountID uint
|
||||
Discount *PromotionDiscount
|
||||
}
|
||||
|
||||
func TestOneToManyWithCustomizedColumn(t *testing.T) {
|
||||
DB.DropTable(&PromotionDiscount{}, &PromotionCoupon{})
|
||||
DB.AutoMigrate(&PromotionDiscount{}, &PromotionCoupon{})
|
||||
|
||||
discount := PromotionDiscount{
|
||||
Name: "Happy New Year",
|
||||
Coupons: []*PromotionCoupon{
|
||||
{Code: "newyear1"},
|
||||
{Code: "newyear2"},
|
||||
},
|
||||
}
|
||||
|
||||
if err := DB.Create(&discount).Error; err != nil {
|
||||
t.Errorf("no error should happen but got %v", err)
|
||||
}
|
||||
|
||||
var discount1 PromotionDiscount
|
||||
if err := DB.Preload("Coupons").First(&discount1, "id = ?", discount.ID).Error; err != nil {
|
||||
t.Errorf("no error should happen but got %v", err)
|
||||
}
|
||||
|
||||
if len(discount.Coupons) != 2 {
|
||||
t.Errorf("should find two coupons")
|
||||
}
|
||||
|
||||
var coupon PromotionCoupon
|
||||
if err := DB.Preload("Discount").First(&coupon, "code = ?", "newyear1").Error; err != nil {
|
||||
t.Errorf("no error should happen but got %v", err)
|
||||
}
|
||||
|
||||
if coupon.Discount.Name != "Happy New Year" {
|
||||
t.Errorf("should preload discount from coupon")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasOneWithPartialCustomizedColumn(t *testing.T) {
|
||||
DB.DropTable(&PromotionDiscount{}, &PromotionRule{})
|
||||
DB.AutoMigrate(&PromotionDiscount{}, &PromotionRule{})
|
||||
|
||||
var begin = time.Now()
|
||||
var end = time.Now().Add(24 * time.Hour)
|
||||
discount := PromotionDiscount{
|
||||
Name: "Happy New Year 2",
|
||||
Rule: &PromotionRule{
|
||||
Name: "time_limited",
|
||||
Begin: &begin,
|
||||
End: &end,
|
||||
},
|
||||
}
|
||||
|
||||
if err := DB.Create(&discount).Error; err != nil {
|
||||
t.Errorf("no error should happen but got %v", err)
|
||||
}
|
||||
|
||||
var discount1 PromotionDiscount
|
||||
if err := DB.Preload("Rule").First(&discount1, "id = ?", discount.ID).Error; err != nil {
|
||||
t.Errorf("no error should happen but got %v", err)
|
||||
}
|
||||
|
||||
if discount.Rule.Begin.Format(time.RFC3339Nano) != begin.Format(time.RFC3339Nano) {
|
||||
t.Errorf("Should be able to preload Rule")
|
||||
}
|
||||
|
||||
var rule PromotionRule
|
||||
if err := DB.Preload("Discount").First(&rule, "name = ?", "time_limited").Error; err != nil {
|
||||
t.Errorf("no error should happen but got %v", err)
|
||||
}
|
||||
|
||||
if rule.Discount.Name != "Happy New Year 2" {
|
||||
t.Errorf("should preload discount from rule")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBelongsToWithPartialCustomizedColumn(t *testing.T) {
|
||||
DB.DropTable(&PromotionDiscount{}, &PromotionBenefit{})
|
||||
DB.AutoMigrate(&PromotionDiscount{}, &PromotionBenefit{})
|
||||
|
||||
discount := PromotionDiscount{
|
||||
Name: "Happy New Year 3",
|
||||
Benefits: []PromotionBenefit{
|
||||
{Name: "free cod"},
|
||||
{Name: "free shipping"},
|
||||
},
|
||||
}
|
||||
|
||||
if err := DB.Create(&discount).Error; err != nil {
|
||||
t.Errorf("no error should happen but got %v", err)
|
||||
}
|
||||
|
||||
var discount1 PromotionDiscount
|
||||
if err := DB.Preload("Benefits").First(&discount1, "id = ?", discount.ID).Error; err != nil {
|
||||
t.Errorf("no error should happen but got %v", err)
|
||||
}
|
||||
|
||||
if len(discount.Benefits) != 2 {
|
||||
t.Errorf("should find two benefits")
|
||||
}
|
||||
|
||||
var benefit PromotionBenefit
|
||||
if err := DB.Preload("Discount").First(&benefit, "name = ?", "free cod").Error; err != nil {
|
||||
t.Errorf("no error should happen but got %v", err)
|
||||
}
|
||||
|
||||
if benefit.Discount.Name != "Happy New Year 3" {
|
||||
t.Errorf("should preload discount from coupon")
|
||||
}
|
||||
}
|
||||
|
||||
type SelfReferencingUser struct {
|
||||
gorm.Model
|
||||
Friends []*SelfReferencingUser `gorm:"many2many:UserFriends;AssociationForeignKey:ID=friend_id"`
|
||||
}
|
||||
|
||||
func TestSelfReferencingMany2ManyColumn(t *testing.T) {
|
||||
DB.DropTable(&SelfReferencingUser{}, "UserFriends")
|
||||
DB.AutoMigrate(&SelfReferencingUser{})
|
||||
|
||||
friend := SelfReferencingUser{}
|
||||
if err := DB.Create(&friend).Error; err != nil {
|
||||
t.Errorf("no error should happen, but got %v", err)
|
||||
}
|
||||
|
||||
user := SelfReferencingUser{
|
||||
Friends: []*SelfReferencingUser{&friend},
|
||||
}
|
||||
if err := DB.Create(&user).Error; err != nil {
|
||||
t.Errorf("no error should happen, but got %v", err)
|
||||
}
|
||||
}
|
||||
@ -1,91 +0,0 @@
|
||||
package gorm_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestDelete(t *testing.T) {
|
||||
user1, user2 := User{Name: "delete1"}, User{Name: "delete2"}
|
||||
DB.Save(&user1)
|
||||
DB.Save(&user2)
|
||||
|
||||
if err := DB.Delete(&user1).Error; err != nil {
|
||||
t.Errorf("No error should happen when delete a record, err=%s", err)
|
||||
}
|
||||
|
||||
if !DB.Where("name = ?", user1.Name).First(&User{}).RecordNotFound() {
|
||||
t.Errorf("User can't be found after delete")
|
||||
}
|
||||
|
||||
if DB.Where("name = ?", user2.Name).First(&User{}).RecordNotFound() {
|
||||
t.Errorf("Other users that not deleted should be found-able")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInlineDelete(t *testing.T) {
|
||||
user1, user2 := User{Name: "inline_delete1"}, User{Name: "inline_delete2"}
|
||||
DB.Save(&user1)
|
||||
DB.Save(&user2)
|
||||
|
||||
if DB.Delete(&User{}, user1.Id).Error != nil {
|
||||
t.Errorf("No error should happen when delete a record")
|
||||
} else if !DB.Where("name = ?", user1.Name).First(&User{}).RecordNotFound() {
|
||||
t.Errorf("User can't be found after delete")
|
||||
}
|
||||
|
||||
if err := DB.Delete(&User{}, "name = ?", user2.Name).Error; err != nil {
|
||||
t.Errorf("No error should happen when delete a record, err=%s", err)
|
||||
} else if !DB.Where("name = ?", user2.Name).First(&User{}).RecordNotFound() {
|
||||
t.Errorf("User can't be found after delete")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSoftDelete(t *testing.T) {
|
||||
type User struct {
|
||||
Id int64
|
||||
Name string
|
||||
DeletedAt *time.Time
|
||||
}
|
||||
DB.AutoMigrate(&User{})
|
||||
|
||||
user := User{Name: "soft_delete"}
|
||||
DB.Save(&user)
|
||||
DB.Delete(&user)
|
||||
|
||||
if DB.First(&User{}, "name = ?", user.Name).Error == nil {
|
||||
t.Errorf("Can't find a soft deleted record")
|
||||
}
|
||||
|
||||
if err := DB.Unscoped().First(&User{}, "name = ?", user.Name).Error; err != nil {
|
||||
t.Errorf("Should be able to find soft deleted record with Unscoped, but err=%s", err)
|
||||
}
|
||||
|
||||
DB.Unscoped().Delete(&user)
|
||||
if !DB.Unscoped().First(&User{}, "name = ?", user.Name).RecordNotFound() {
|
||||
t.Errorf("Can't find permanently deleted record")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSoftDeleteWithCustomizedDeletedAtColumnName(t *testing.T) {
|
||||
creditCard := CreditCard{Number: "411111111234567"}
|
||||
DB.Save(&creditCard)
|
||||
DB.Delete(&creditCard)
|
||||
|
||||
if deletedAtField, ok := DB.NewScope(&CreditCard{}).FieldByName("DeletedAt"); !ok || deletedAtField.DBName != "deleted_time" {
|
||||
t.Errorf("CreditCard's DeletedAt's column name should be `deleted_time`")
|
||||
}
|
||||
|
||||
if DB.First(&CreditCard{}, "number = ?", creditCard.Number).Error == nil {
|
||||
t.Errorf("Can't find a soft deleted record")
|
||||
}
|
||||
|
||||
if err := DB.Unscoped().First(&CreditCard{}, "number = ?", creditCard.Number).Error; err != nil {
|
||||
t.Errorf("Should be able to find soft deleted record with Unscoped, but err=%s", err)
|
||||
}
|
||||
|
||||
DB.Unscoped().Delete(&creditCard)
|
||||
if !DB.Unscoped().First(&CreditCard{}, "number = ?", creditCard.Number).RecordNotFound() {
|
||||
t.Errorf("Can't find permanently deleted record")
|
||||
}
|
||||
}
|
||||
116
dialect.go
116
dialect.go
@ -1,116 +0,0 @@
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Dialect interface contains behaviors that differ across SQL database
|
||||
type Dialect interface {
|
||||
// GetName get dialect's name
|
||||
GetName() string
|
||||
|
||||
// SetDB set db for dialect
|
||||
SetDB(db SQLCommon)
|
||||
|
||||
// BindVar return the placeholder for actual values in SQL statements, in many dbs it is "?", Postgres using $1
|
||||
BindVar(i int) string
|
||||
// Quote quotes field name to avoid SQL parsing exceptions by using a reserved word as a field name
|
||||
Quote(key string) string
|
||||
// DataTypeOf return data's sql type
|
||||
DataTypeOf(field *StructField) string
|
||||
|
||||
// HasIndex check has index or not
|
||||
HasIndex(tableName string, indexName string) bool
|
||||
// HasForeignKey check has foreign key or not
|
||||
HasForeignKey(tableName string, foreignKeyName string) bool
|
||||
// RemoveIndex remove index
|
||||
RemoveIndex(tableName string, indexName string) error
|
||||
// HasTable check has table or not
|
||||
HasTable(tableName string) bool
|
||||
// HasColumn check has column or not
|
||||
HasColumn(tableName string, columnName string) bool
|
||||
|
||||
// LimitAndOffsetSQL return generated SQL with Limit and Offset, as mssql has special case
|
||||
LimitAndOffsetSQL(limit, offset interface{}) string
|
||||
// SelectFromDummyTable return select values, for most dbs, `SELECT values` just works, mysql needs `SELECT value FROM DUAL`
|
||||
SelectFromDummyTable() string
|
||||
// LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING`
|
||||
LastInsertIDReturningSuffix(tableName, columnName string) string
|
||||
|
||||
// BuildForeignKeyName returns a foreign key name for the given table, field and reference
|
||||
BuildForeignKeyName(tableName, field, dest string) string
|
||||
|
||||
// CurrentDatabase return current database name
|
||||
CurrentDatabase() string
|
||||
}
|
||||
|
||||
var dialectsMap = map[string]Dialect{}
|
||||
|
||||
func newDialect(name string, db SQLCommon) Dialect {
|
||||
if value, ok := dialectsMap[name]; ok {
|
||||
dialect := reflect.New(reflect.TypeOf(value).Elem()).Interface().(Dialect)
|
||||
dialect.SetDB(db)
|
||||
return dialect
|
||||
}
|
||||
|
||||
fmt.Printf("`%v` is not officially supported, running under compatibility mode.\n", name)
|
||||
commontDialect := &commonDialect{}
|
||||
commontDialect.SetDB(db)
|
||||
return commontDialect
|
||||
}
|
||||
|
||||
// RegisterDialect register new dialect
|
||||
func RegisterDialect(name string, dialect Dialect) {
|
||||
dialectsMap[name] = dialect
|
||||
}
|
||||
|
||||
// ParseFieldStructForDialect get field's sql data type
|
||||
var ParseFieldStructForDialect = func(field *StructField, dialect Dialect) (fieldValue reflect.Value, sqlType string, size int, additionalType string) {
|
||||
// Get redirected field type
|
||||
var (
|
||||
reflectType = field.Struct.Type
|
||||
dataType = field.TagSettings["TYPE"]
|
||||
)
|
||||
|
||||
for reflectType.Kind() == reflect.Ptr {
|
||||
reflectType = reflectType.Elem()
|
||||
}
|
||||
|
||||
// Get redirected field value
|
||||
fieldValue = reflect.Indirect(reflect.New(reflectType))
|
||||
|
||||
if gormDataType, ok := fieldValue.Interface().(interface {
|
||||
GormDataType(Dialect) string
|
||||
}); ok {
|
||||
dataType = gormDataType.GormDataType(dialect)
|
||||
}
|
||||
|
||||
// Get scanner's real value
|
||||
var getScannerValue func(reflect.Value)
|
||||
getScannerValue = func(value reflect.Value) {
|
||||
fieldValue = value
|
||||
if _, isScanner := reflect.New(fieldValue.Type()).Interface().(sql.Scanner); isScanner && fieldValue.Kind() == reflect.Struct {
|
||||
getScannerValue(fieldValue.Field(0))
|
||||
}
|
||||
}
|
||||
getScannerValue(fieldValue)
|
||||
|
||||
// Default Size
|
||||
if num, ok := field.TagSettings["SIZE"]; ok {
|
||||
size, _ = strconv.Atoi(num)
|
||||
} else {
|
||||
size = 255
|
||||
}
|
||||
|
||||
// Default type from tag setting
|
||||
additionalType = field.TagSettings["NOT NULL"] + " " + field.TagSettings["UNIQUE"]
|
||||
if value, ok := field.TagSettings["DEFAULT"]; ok {
|
||||
additionalType = additionalType + " DEFAULT " + value
|
||||
}
|
||||
|
||||
return fieldValue, dataType, size, strings.TrimSpace(additionalType)
|
||||
}
|
||||
@ -1,156 +0,0 @@
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// DefaultForeignKeyNamer contains the default foreign key name generator method
|
||||
type DefaultForeignKeyNamer struct {
|
||||
}
|
||||
|
||||
type commonDialect struct {
|
||||
db SQLCommon
|
||||
DefaultForeignKeyNamer
|
||||
}
|
||||
|
||||
func init() {
|
||||
RegisterDialect("common", &commonDialect{})
|
||||
}
|
||||
|
||||
func (commonDialect) GetName() string {
|
||||
return "common"
|
||||
}
|
||||
|
||||
func (s *commonDialect) SetDB(db SQLCommon) {
|
||||
s.db = db
|
||||
}
|
||||
|
||||
func (commonDialect) BindVar(i int) string {
|
||||
return "$$$" // ?
|
||||
}
|
||||
|
||||
func (commonDialect) Quote(key string) string {
|
||||
return fmt.Sprintf(`"%s"`, key)
|
||||
}
|
||||
|
||||
func (s *commonDialect) DataTypeOf(field *StructField) string {
|
||||
var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s)
|
||||
|
||||
if sqlType == "" {
|
||||
switch dataValue.Kind() {
|
||||
case reflect.Bool:
|
||||
sqlType = "BOOLEAN"
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok {
|
||||
sqlType = "INTEGER AUTO_INCREMENT"
|
||||
} else {
|
||||
sqlType = "INTEGER"
|
||||
}
|
||||
case reflect.Int64, reflect.Uint64:
|
||||
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok {
|
||||
sqlType = "BIGINT AUTO_INCREMENT"
|
||||
} else {
|
||||
sqlType = "BIGINT"
|
||||
}
|
||||
case reflect.Float32, reflect.Float64:
|
||||
sqlType = "FLOAT"
|
||||
case reflect.String:
|
||||
if size > 0 && size < 65532 {
|
||||
sqlType = fmt.Sprintf("VARCHAR(%d)", size)
|
||||
} else {
|
||||
sqlType = "VARCHAR(65532)"
|
||||
}
|
||||
case reflect.Struct:
|
||||
if _, ok := dataValue.Interface().(time.Time); ok {
|
||||
sqlType = "TIMESTAMP"
|
||||
}
|
||||
default:
|
||||
if _, ok := dataValue.Interface().([]byte); ok {
|
||||
if size > 0 && size < 65532 {
|
||||
sqlType = fmt.Sprintf("BINARY(%d)", size)
|
||||
} else {
|
||||
sqlType = "BINARY(65532)"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if sqlType == "" {
|
||||
panic(fmt.Sprintf("invalid sql type %s (%s) for commonDialect", dataValue.Type().Name(), dataValue.Kind().String()))
|
||||
}
|
||||
|
||||
if strings.TrimSpace(additionalType) == "" {
|
||||
return sqlType
|
||||
}
|
||||
return fmt.Sprintf("%v %v", sqlType, additionalType)
|
||||
}
|
||||
|
||||
func (s commonDialect) HasIndex(tableName string, indexName string) bool {
|
||||
var count int
|
||||
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", s.CurrentDatabase(), tableName, indexName).Scan(&count)
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (s commonDialect) RemoveIndex(tableName string, indexName string) error {
|
||||
_, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v", indexName))
|
||||
return err
|
||||
}
|
||||
|
||||
func (s commonDialect) HasForeignKey(tableName string, foreignKeyName string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (s commonDialect) HasTable(tableName string) bool {
|
||||
var count int
|
||||
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", s.CurrentDatabase(), tableName).Scan(&count)
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (s commonDialect) HasColumn(tableName string, columnName string) bool {
|
||||
var count int
|
||||
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", s.CurrentDatabase(), tableName, columnName).Scan(&count)
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (s commonDialect) CurrentDatabase() (name string) {
|
||||
s.db.QueryRow("SELECT DATABASE()").Scan(&name)
|
||||
return
|
||||
}
|
||||
|
||||
func (commonDialect) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {
|
||||
if limit != nil {
|
||||
if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 {
|
||||
sql += fmt.Sprintf(" LIMIT %d", parsedLimit)
|
||||
}
|
||||
}
|
||||
if offset != nil {
|
||||
if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset >= 0 {
|
||||
sql += fmt.Sprintf(" OFFSET %d", parsedOffset)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (commonDialect) SelectFromDummyTable() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (commonDialect) LastInsertIDReturningSuffix(tableName, columnName string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (DefaultForeignKeyNamer) BuildForeignKeyName(tableName, field, dest string) string {
|
||||
keyName := fmt.Sprintf("%s_%s_%s_foreign", tableName, field, dest)
|
||||
keyName = regexp.MustCompile("(_*[^a-zA-Z]+_*|_+)").ReplaceAllString(keyName, "_")
|
||||
return keyName
|
||||
}
|
||||
|
||||
// IsByteArrayOrSlice returns true of the reflected value is an array or slice
|
||||
func IsByteArrayOrSlice(value reflect.Value) bool {
|
||||
return (value.Kind() == reflect.Array || value.Kind() == reflect.Slice) && value.Type().Elem() == reflect.TypeOf(uint8(0))
|
||||
}
|
||||
176
dialect_mysql.go
176
dialect_mysql.go
@ -1,176 +0,0 @@
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"crypto/sha1"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
type mysql struct {
|
||||
commonDialect
|
||||
}
|
||||
|
||||
func init() {
|
||||
RegisterDialect("mysql", &mysql{})
|
||||
}
|
||||
|
||||
func (mysql) GetName() string {
|
||||
return "mysql"
|
||||
}
|
||||
|
||||
func (mysql) Quote(key string) string {
|
||||
return fmt.Sprintf("`%s`", key)
|
||||
}
|
||||
|
||||
// Get Data Type for MySQL Dialect
|
||||
func (s *mysql) DataTypeOf(field *StructField) string {
|
||||
var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s)
|
||||
|
||||
// MySQL allows only one auto increment column per table, and it must
|
||||
// be a KEY column.
|
||||
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok {
|
||||
if _, ok = field.TagSettings["INDEX"]; !ok && !field.IsPrimaryKey {
|
||||
delete(field.TagSettings, "AUTO_INCREMENT")
|
||||
}
|
||||
}
|
||||
|
||||
if sqlType == "" {
|
||||
switch dataValue.Kind() {
|
||||
case reflect.Bool:
|
||||
sqlType = "boolean"
|
||||
case reflect.Int8:
|
||||
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
|
||||
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
|
||||
sqlType = "tinyint AUTO_INCREMENT"
|
||||
} else {
|
||||
sqlType = "tinyint"
|
||||
}
|
||||
case reflect.Int, reflect.Int16, reflect.Int32:
|
||||
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
|
||||
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
|
||||
sqlType = "int AUTO_INCREMENT"
|
||||
} else {
|
||||
sqlType = "int"
|
||||
}
|
||||
case reflect.Uint8:
|
||||
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
|
||||
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
|
||||
sqlType = "tinyint unsigned AUTO_INCREMENT"
|
||||
} else {
|
||||
sqlType = "tinyint unsigned"
|
||||
}
|
||||
case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
|
||||
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
|
||||
sqlType = "int unsigned AUTO_INCREMENT"
|
||||
} else {
|
||||
sqlType = "int unsigned"
|
||||
}
|
||||
case reflect.Int64:
|
||||
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
|
||||
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
|
||||
sqlType = "bigint AUTO_INCREMENT"
|
||||
} else {
|
||||
sqlType = "bigint"
|
||||
}
|
||||
case reflect.Uint64:
|
||||
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
|
||||
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
|
||||
sqlType = "bigint unsigned AUTO_INCREMENT"
|
||||
} else {
|
||||
sqlType = "bigint unsigned"
|
||||
}
|
||||
case reflect.Float32, reflect.Float64:
|
||||
sqlType = "double"
|
||||
case reflect.String:
|
||||
if size > 0 && size < 65532 {
|
||||
sqlType = fmt.Sprintf("varchar(%d)", size)
|
||||
} else {
|
||||
sqlType = "longtext"
|
||||
}
|
||||
case reflect.Struct:
|
||||
if _, ok := dataValue.Interface().(time.Time); ok {
|
||||
if _, ok := field.TagSettings["NOT NULL"]; ok {
|
||||
sqlType = "timestamp"
|
||||
} else {
|
||||
sqlType = "timestamp NULL"
|
||||
}
|
||||
}
|
||||
default:
|
||||
if IsByteArrayOrSlice(dataValue) {
|
||||
if size > 0 && size < 65532 {
|
||||
sqlType = fmt.Sprintf("varbinary(%d)", size)
|
||||
} else {
|
||||
sqlType = "longblob"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if sqlType == "" {
|
||||
panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", dataValue.Type().Name(), dataValue.Kind().String()))
|
||||
}
|
||||
|
||||
if strings.TrimSpace(additionalType) == "" {
|
||||
return sqlType
|
||||
}
|
||||
return fmt.Sprintf("%v %v", sqlType, additionalType)
|
||||
}
|
||||
|
||||
func (s mysql) RemoveIndex(tableName string, indexName string) error {
|
||||
_, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName)))
|
||||
return err
|
||||
}
|
||||
|
||||
func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {
|
||||
if limit != nil {
|
||||
if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 {
|
||||
sql += fmt.Sprintf(" LIMIT %d", parsedLimit)
|
||||
|
||||
if offset != nil {
|
||||
if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset >= 0 {
|
||||
sql += fmt.Sprintf(" OFFSET %d", parsedOffset)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s mysql) HasForeignKey(tableName string, foreignKeyName string) bool {
|
||||
var count int
|
||||
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_SCHEMA=? AND TABLE_NAME=? AND CONSTRAINT_NAME=? AND CONSTRAINT_TYPE='FOREIGN KEY'", s.CurrentDatabase(), tableName, foreignKeyName).Scan(&count)
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (s mysql) CurrentDatabase() (name string) {
|
||||
s.db.QueryRow("SELECT DATABASE()").Scan(&name)
|
||||
return
|
||||
}
|
||||
|
||||
func (mysql) SelectFromDummyTable() string {
|
||||
return "FROM DUAL"
|
||||
}
|
||||
|
||||
func (s mysql) BuildForeignKeyName(tableName, field, dest string) string {
|
||||
keyName := s.commonDialect.BuildForeignKeyName(tableName, field, dest)
|
||||
if utf8.RuneCountInString(keyName) <= 64 {
|
||||
return keyName
|
||||
}
|
||||
h := sha1.New()
|
||||
h.Write([]byte(keyName))
|
||||
bs := h.Sum(nil)
|
||||
|
||||
// sha1 is 40 digits, keep first 24 characters of destination
|
||||
destRunes := []rune(regexp.MustCompile("(_*[^a-zA-Z]+_*|_+)").ReplaceAllString(dest, "_"))
|
||||
if len(destRunes) > 24 {
|
||||
destRunes = destRunes[:24]
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s%x", string(destRunes), bs)
|
||||
}
|
||||
@ -1,132 +0,0 @@
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type postgres struct {
|
||||
commonDialect
|
||||
}
|
||||
|
||||
func init() {
|
||||
RegisterDialect("postgres", &postgres{})
|
||||
RegisterDialect("cloudsqlpostgres", &postgres{})
|
||||
}
|
||||
|
||||
func (postgres) GetName() string {
|
||||
return "postgres"
|
||||
}
|
||||
|
||||
func (postgres) BindVar(i int) string {
|
||||
return fmt.Sprintf("$%v", i)
|
||||
}
|
||||
|
||||
func (s *postgres) DataTypeOf(field *StructField) string {
|
||||
var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s)
|
||||
|
||||
if sqlType == "" {
|
||||
switch dataValue.Kind() {
|
||||
case reflect.Bool:
|
||||
sqlType = "boolean"
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uintptr:
|
||||
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
|
||||
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
|
||||
sqlType = "serial"
|
||||
} else {
|
||||
sqlType = "integer"
|
||||
}
|
||||
case reflect.Int64, reflect.Uint32, reflect.Uint64:
|
||||
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
|
||||
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
|
||||
sqlType = "bigserial"
|
||||
} else {
|
||||
sqlType = "bigint"
|
||||
}
|
||||
case reflect.Float32, reflect.Float64:
|
||||
sqlType = "numeric"
|
||||
case reflect.String:
|
||||
if _, ok := field.TagSettings["SIZE"]; !ok {
|
||||
size = 0 // if SIZE haven't been set, use `text` as the default type, as there are no performance different
|
||||
}
|
||||
|
||||
if size > 0 && size < 65532 {
|
||||
sqlType = fmt.Sprintf("varchar(%d)", size)
|
||||
} else {
|
||||
sqlType = "text"
|
||||
}
|
||||
case reflect.Struct:
|
||||
if _, ok := dataValue.Interface().(time.Time); ok {
|
||||
sqlType = "timestamp with time zone"
|
||||
}
|
||||
case reflect.Map:
|
||||
if dataValue.Type().Name() == "Hstore" {
|
||||
sqlType = "hstore"
|
||||
}
|
||||
default:
|
||||
if IsByteArrayOrSlice(dataValue) {
|
||||
sqlType = "bytea"
|
||||
if isUUID(dataValue) {
|
||||
sqlType = "uuid"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if sqlType == "" {
|
||||
panic(fmt.Sprintf("invalid sql type %s (%s) for postgres", dataValue.Type().Name(), dataValue.Kind().String()))
|
||||
}
|
||||
|
||||
if strings.TrimSpace(additionalType) == "" {
|
||||
return sqlType
|
||||
}
|
||||
return fmt.Sprintf("%v %v", sqlType, additionalType)
|
||||
}
|
||||
|
||||
func (s postgres) HasIndex(tableName string, indexName string) bool {
|
||||
var count int
|
||||
s.db.QueryRow("SELECT count(*) FROM pg_indexes WHERE tablename = $1 AND indexname = $2 AND schemaname = CURRENT_SCHEMA()", tableName, indexName).Scan(&count)
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (s postgres) HasForeignKey(tableName string, foreignKeyName string) bool {
|
||||
var count int
|
||||
s.db.QueryRow("SELECT count(con.conname) FROM pg_constraint con WHERE $1::regclass::oid = con.conrelid AND con.conname = $2 AND con.contype='f'", tableName, foreignKeyName).Scan(&count)
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (s postgres) HasTable(tableName string) bool {
|
||||
var count int
|
||||
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = $1 AND table_type = 'BASE TABLE' AND table_schema = CURRENT_SCHEMA()", tableName).Scan(&count)
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (s postgres) HasColumn(tableName string, columnName string) bool {
|
||||
var count int
|
||||
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = $1 AND column_name = $2 AND table_schema = CURRENT_SCHEMA()", tableName, columnName).Scan(&count)
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (s postgres) CurrentDatabase() (name string) {
|
||||
s.db.QueryRow("SELECT CURRENT_DATABASE()").Scan(&name)
|
||||
return
|
||||
}
|
||||
|
||||
func (s postgres) LastInsertIDReturningSuffix(tableName, key string) string {
|
||||
return fmt.Sprintf("RETURNING %v.%v", tableName, key)
|
||||
}
|
||||
|
||||
func (postgres) SupportLastInsertID() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func isUUID(value reflect.Value) bool {
|
||||
if value.Kind() != reflect.Array || value.Type().Len() != 16 {
|
||||
return false
|
||||
}
|
||||
typename := value.Type().Name()
|
||||
lower := strings.ToLower(typename)
|
||||
return "uuid" == lower || "guid" == lower
|
||||
}
|
||||
@ -1,107 +0,0 @@
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type sqlite3 struct {
|
||||
commonDialect
|
||||
}
|
||||
|
||||
func init() {
|
||||
RegisterDialect("sqlite3", &sqlite3{})
|
||||
}
|
||||
|
||||
func (sqlite3) GetName() string {
|
||||
return "sqlite3"
|
||||
}
|
||||
|
||||
// Get Data Type for Sqlite Dialect
|
||||
func (s *sqlite3) DataTypeOf(field *StructField) string {
|
||||
var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s)
|
||||
|
||||
if sqlType == "" {
|
||||
switch dataValue.Kind() {
|
||||
case reflect.Bool:
|
||||
sqlType = "bool"
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||
if field.IsPrimaryKey {
|
||||
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
|
||||
sqlType = "integer primary key autoincrement"
|
||||
} else {
|
||||
sqlType = "integer"
|
||||
}
|
||||
case reflect.Int64, reflect.Uint64:
|
||||
if field.IsPrimaryKey {
|
||||
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
|
||||
sqlType = "integer primary key autoincrement"
|
||||
} else {
|
||||
sqlType = "bigint"
|
||||
}
|
||||
case reflect.Float32, reflect.Float64:
|
||||
sqlType = "real"
|
||||
case reflect.String:
|
||||
if size > 0 && size < 65532 {
|
||||
sqlType = fmt.Sprintf("varchar(%d)", size)
|
||||
} else {
|
||||
sqlType = "text"
|
||||
}
|
||||
case reflect.Struct:
|
||||
if _, ok := dataValue.Interface().(time.Time); ok {
|
||||
sqlType = "datetime"
|
||||
}
|
||||
default:
|
||||
if IsByteArrayOrSlice(dataValue) {
|
||||
sqlType = "blob"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if sqlType == "" {
|
||||
panic(fmt.Sprintf("invalid sql type %s (%s) for sqlite3", dataValue.Type().Name(), dataValue.Kind().String()))
|
||||
}
|
||||
|
||||
if strings.TrimSpace(additionalType) == "" {
|
||||
return sqlType
|
||||
}
|
||||
return fmt.Sprintf("%v %v", sqlType, additionalType)
|
||||
}
|
||||
|
||||
func (s sqlite3) HasIndex(tableName string, indexName string) bool {
|
||||
var count int
|
||||
s.db.QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName).Scan(&count)
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (s sqlite3) HasTable(tableName string) bool {
|
||||
var count int
|
||||
s.db.QueryRow("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", tableName).Scan(&count)
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (s sqlite3) HasColumn(tableName string, columnName string) bool {
|
||||
var count int
|
||||
s.db.QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE '%%\"%v\" %%' OR sql LIKE '%%%v %%');\n", columnName, columnName), tableName).Scan(&count)
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (s sqlite3) CurrentDatabase() (name string) {
|
||||
var (
|
||||
ifaces = make([]interface{}, 3)
|
||||
pointers = make([]*string, 3)
|
||||
i int
|
||||
)
|
||||
for i = 0; i < 3; i++ {
|
||||
ifaces[i] = &pointers[i]
|
||||
}
|
||||
if err := s.db.QueryRow("PRAGMA database_list").Scan(ifaces...); err != nil {
|
||||
return
|
||||
}
|
||||
if pointers[1] != nil {
|
||||
name = *pointers[1]
|
||||
}
|
||||
return
|
||||
}
|
||||
@ -1,170 +0,0 @@
|
||||
package mssql
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
_ "github.com/denisenkom/go-mssqldb"
|
||||
"github.com/jinzhu/gorm"
|
||||
)
|
||||
|
||||
func setIdentityInsert(scope *gorm.Scope) {
|
||||
if scope.Dialect().GetName() == "mssql" {
|
||||
for _, field := range scope.PrimaryFields() {
|
||||
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok && !field.IsBlank {
|
||||
scope.NewDB().Exec(fmt.Sprintf("SET IDENTITY_INSERT %v ON", scope.TableName()))
|
||||
scope.InstanceSet("mssql:identity_insert_on", true)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func turnOffIdentityInsert(scope *gorm.Scope) {
|
||||
if scope.Dialect().GetName() == "mssql" {
|
||||
if _, ok := scope.InstanceGet("mssql:identity_insert_on"); ok {
|
||||
scope.NewDB().Exec(fmt.Sprintf("SET IDENTITY_INSERT %v OFF", scope.TableName()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func init() {
|
||||
gorm.DefaultCallback.Create().After("gorm:begin_transaction").Register("mssql:set_identity_insert", setIdentityInsert)
|
||||
gorm.DefaultCallback.Create().Before("gorm:commit_or_rollback_transaction").Register("mssql:turn_off_identity_insert", turnOffIdentityInsert)
|
||||
gorm.RegisterDialect("mssql", &mssql{})
|
||||
}
|
||||
|
||||
type mssql struct {
|
||||
db gorm.SQLCommon
|
||||
gorm.DefaultForeignKeyNamer
|
||||
}
|
||||
|
||||
func (mssql) GetName() string {
|
||||
return "mssql"
|
||||
}
|
||||
|
||||
func (s *mssql) SetDB(db gorm.SQLCommon) {
|
||||
s.db = db
|
||||
}
|
||||
|
||||
func (mssql) BindVar(i int) string {
|
||||
return "$$$" // ?
|
||||
}
|
||||
|
||||
func (mssql) Quote(key string) string {
|
||||
return fmt.Sprintf(`"%s"`, key)
|
||||
}
|
||||
|
||||
func (s *mssql) DataTypeOf(field *gorm.StructField) string {
|
||||
var dataValue, sqlType, size, additionalType = gorm.ParseFieldStructForDialect(field, s)
|
||||
|
||||
if sqlType == "" {
|
||||
switch dataValue.Kind() {
|
||||
case reflect.Bool:
|
||||
sqlType = "bit"
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
|
||||
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
|
||||
sqlType = "int IDENTITY(1,1)"
|
||||
} else {
|
||||
sqlType = "int"
|
||||
}
|
||||
case reflect.Int64, reflect.Uint64:
|
||||
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
|
||||
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
|
||||
sqlType = "bigint IDENTITY(1,1)"
|
||||
} else {
|
||||
sqlType = "bigint"
|
||||
}
|
||||
case reflect.Float32, reflect.Float64:
|
||||
sqlType = "float"
|
||||
case reflect.String:
|
||||
if size > 0 && size < 8000 {
|
||||
sqlType = fmt.Sprintf("nvarchar(%d)", size)
|
||||
} else {
|
||||
sqlType = "nvarchar(max)"
|
||||
}
|
||||
case reflect.Struct:
|
||||
if _, ok := dataValue.Interface().(time.Time); ok {
|
||||
sqlType = "datetimeoffset"
|
||||
}
|
||||
default:
|
||||
if gorm.IsByteArrayOrSlice(dataValue) {
|
||||
if size > 0 && size < 8000 {
|
||||
sqlType = fmt.Sprintf("varbinary(%d)", size)
|
||||
} else {
|
||||
sqlType = "varbinary(max)"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if sqlType == "" {
|
||||
panic(fmt.Sprintf("invalid sql type %s (%s) for mssql", dataValue.Type().Name(), dataValue.Kind().String()))
|
||||
}
|
||||
|
||||
if strings.TrimSpace(additionalType) == "" {
|
||||
return sqlType
|
||||
}
|
||||
return fmt.Sprintf("%v %v", sqlType, additionalType)
|
||||
}
|
||||
|
||||
func (s mssql) HasIndex(tableName string, indexName string) bool {
|
||||
var count int
|
||||
s.db.QueryRow("SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName).Scan(&count)
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (s mssql) RemoveIndex(tableName string, indexName string) error {
|
||||
_, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName)))
|
||||
return err
|
||||
}
|
||||
|
||||
func (s mssql) HasForeignKey(tableName string, foreignKeyName string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (s mssql) HasTable(tableName string) bool {
|
||||
var count int
|
||||
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, s.CurrentDatabase()).Scan(&count)
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (s mssql) HasColumn(tableName string, columnName string) bool {
|
||||
var count int
|
||||
s.db.QueryRow("SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", s.CurrentDatabase(), tableName, columnName).Scan(&count)
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (s mssql) CurrentDatabase() (name string) {
|
||||
s.db.QueryRow("SELECT DB_NAME() AS [Current Database]").Scan(&name)
|
||||
return
|
||||
}
|
||||
|
||||
func (mssql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {
|
||||
if offset != nil {
|
||||
if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset >= 0 {
|
||||
sql += fmt.Sprintf(" OFFSET %d ROWS", parsedOffset)
|
||||
}
|
||||
}
|
||||
if limit != nil {
|
||||
if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 {
|
||||
if sql == "" {
|
||||
// add default zero offset
|
||||
sql += " OFFSET 0 ROWS"
|
||||
}
|
||||
sql += fmt.Sprintf(" FETCH NEXT %d ROWS ONLY", parsedLimit)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (mssql) SelectFromDummyTable() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (mssql) LastInsertIDReturningSuffix(tableName, columnName string) string {
|
||||
return ""
|
||||
}
|
||||
@ -1,3 +0,0 @@
|
||||
package mysql
|
||||
|
||||
import _ "github.com/go-sql-driver/mysql"
|
||||
@ -1,77 +0,0 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
|
||||
_ "github.com/lib/pq"
|
||||
"github.com/lib/pq/hstore"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type Hstore map[string]*string
|
||||
|
||||
// Value get value of Hstore
|
||||
func (h Hstore) Value() (driver.Value, error) {
|
||||
hstore := hstore.Hstore{Map: map[string]sql.NullString{}}
|
||||
if len(h) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
for key, value := range h {
|
||||
var s sql.NullString
|
||||
if value != nil {
|
||||
s.String = *value
|
||||
s.Valid = true
|
||||
}
|
||||
hstore.Map[key] = s
|
||||
}
|
||||
return hstore.Value()
|
||||
}
|
||||
|
||||
// Scan scan value into Hstore
|
||||
func (h *Hstore) Scan(value interface{}) error {
|
||||
hstore := hstore.Hstore{}
|
||||
|
||||
if err := hstore.Scan(value); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(hstore.Map) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
*h = Hstore{}
|
||||
for k := range hstore.Map {
|
||||
if hstore.Map[k].Valid {
|
||||
s := hstore.Map[k].String
|
||||
(*h)[k] = &s
|
||||
} else {
|
||||
(*h)[k] = nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Jsonb Postgresql's JSONB data type
|
||||
type Jsonb struct {
|
||||
json.RawMessage
|
||||
}
|
||||
|
||||
// Value get value of Jsonb
|
||||
func (j Jsonb) Value() (driver.Value, error) {
|
||||
return j.MarshalJSON()
|
||||
}
|
||||
|
||||
// Scan scan value into Jsonb
|
||||
func (j *Jsonb) Scan(value interface{}) error {
|
||||
bytes, ok := value.([]byte)
|
||||
if !ok {
|
||||
return errors.New(fmt.Sprint("Failed to unmarshal JSONB value:", value))
|
||||
}
|
||||
|
||||
return json.Unmarshal(bytes, j)
|
||||
}
|
||||
@ -1,3 +0,0 @@
|
||||
package sqlite
|
||||
|
||||
import _ "github.com/mattn/go-sqlite3"
|
||||
@ -1,73 +0,0 @@
|
||||
package gorm_test
|
||||
|
||||
import "testing"
|
||||
|
||||
type BasePost struct {
|
||||
Id int64
|
||||
Title string
|
||||
URL string
|
||||
}
|
||||
|
||||
type Author struct {
|
||||
ID string
|
||||
Name string
|
||||
Email string
|
||||
}
|
||||
|
||||
type HNPost struct {
|
||||
BasePost
|
||||
Author `gorm:"embedded_prefix:user_"` // Embedded struct
|
||||
Upvotes int32
|
||||
}
|
||||
|
||||
type EngadgetPost struct {
|
||||
BasePost BasePost `gorm:"embedded"`
|
||||
Author Author `gorm:"embedded;embedded_prefix:author_"` // Embedded struct
|
||||
ImageUrl string
|
||||
}
|
||||
|
||||
func TestPrefixColumnNameForEmbeddedStruct(t *testing.T) {
|
||||
dialect := DB.NewScope(&EngadgetPost{}).Dialect()
|
||||
engadgetPostScope := DB.NewScope(&EngadgetPost{})
|
||||
if !dialect.HasColumn(engadgetPostScope.TableName(), "author_id") || !dialect.HasColumn(engadgetPostScope.TableName(), "author_name") || !dialect.HasColumn(engadgetPostScope.TableName(), "author_email") {
|
||||
t.Errorf("should has prefix for embedded columns")
|
||||
}
|
||||
|
||||
if len(engadgetPostScope.PrimaryFields()) != 1 {
|
||||
t.Errorf("should have only one primary field with embedded struct, but got %v", len(engadgetPostScope.PrimaryFields()))
|
||||
}
|
||||
|
||||
hnScope := DB.NewScope(&HNPost{})
|
||||
if !dialect.HasColumn(hnScope.TableName(), "user_id") || !dialect.HasColumn(hnScope.TableName(), "user_name") || !dialect.HasColumn(hnScope.TableName(), "user_email") {
|
||||
t.Errorf("should has prefix for embedded columns")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveAndQueryEmbeddedStruct(t *testing.T) {
|
||||
DB.Save(&HNPost{BasePost: BasePost{Title: "news"}})
|
||||
DB.Save(&HNPost{BasePost: BasePost{Title: "hn_news"}})
|
||||
var news HNPost
|
||||
if err := DB.First(&news, "title = ?", "hn_news").Error; err != nil {
|
||||
t.Errorf("no error should happen when query with embedded struct, but got %v", err)
|
||||
} else if news.Title != "hn_news" {
|
||||
t.Errorf("embedded struct's value should be scanned correctly")
|
||||
}
|
||||
|
||||
DB.Save(&EngadgetPost{BasePost: BasePost{Title: "engadget_news"}})
|
||||
var egNews EngadgetPost
|
||||
if err := DB.First(&egNews, "title = ?", "engadget_news").Error; err != nil {
|
||||
t.Errorf("no error should happen when query with embedded struct, but got %v", err)
|
||||
} else if egNews.BasePost.Title != "engadget_news" {
|
||||
t.Errorf("embedded struct's value should be scanned correctly")
|
||||
}
|
||||
|
||||
if DB.NewScope(&HNPost{}).PrimaryField() == nil {
|
||||
t.Errorf("primary key with embedded struct should works")
|
||||
}
|
||||
|
||||
for _, field := range DB.NewScope(&HNPost{}).Fields() {
|
||||
if field.Name == "BasePost" {
|
||||
t.Errorf("scope Fields should not contain embedded struct")
|
||||
}
|
||||
}
|
||||
}
|
||||
82
errors.go
82
errors.go
@ -2,59 +2,41 @@ package gorm
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrRecordNotFound record not found error, happens when haven't find any matched data when looking up with a struct
|
||||
ErrRecordNotFound = errors.New("record not found")
|
||||
// ErrInvalidSQL invalid SQL error, happens when you passed invalid SQL
|
||||
ErrInvalidSQL = errors.New("invalid SQL")
|
||||
// ErrRecordNotFound record not found error
|
||||
ErrRecordNotFound = logger.ErrRecordNotFound
|
||||
// ErrInvalidTransaction invalid transaction when you are trying to `Commit` or `Rollback`
|
||||
ErrInvalidTransaction = errors.New("no valid transaction")
|
||||
// ErrCantStartTransaction can't start transaction when you are trying to start one with `Begin`
|
||||
ErrCantStartTransaction = errors.New("can't start transaction")
|
||||
// ErrUnaddressable unaddressable value
|
||||
ErrUnaddressable = errors.New("using unaddressable value")
|
||||
// ErrNotImplemented not implemented
|
||||
ErrNotImplemented = errors.New("not implemented")
|
||||
// ErrMissingWhereClause missing where clause
|
||||
ErrMissingWhereClause = errors.New("WHERE conditions required")
|
||||
// ErrUnsupportedRelation unsupported relations
|
||||
ErrUnsupportedRelation = errors.New("unsupported relations")
|
||||
// ErrPrimaryKeyRequired primary keys required
|
||||
ErrPrimaryKeyRequired = errors.New("primary key required")
|
||||
// ErrModelValueRequired model value required
|
||||
ErrModelValueRequired = errors.New("model value required")
|
||||
// ErrInvalidData unsupported data
|
||||
ErrInvalidData = errors.New("unsupported data")
|
||||
// ErrUnsupportedDriver unsupported driver
|
||||
ErrUnsupportedDriver = errors.New("unsupported driver")
|
||||
// ErrRegistered registered
|
||||
ErrRegistered = errors.New("registered")
|
||||
// ErrInvalidField invalid field
|
||||
ErrInvalidField = errors.New("invalid field")
|
||||
// ErrEmptySlice empty slice found
|
||||
ErrEmptySlice = errors.New("empty slice found")
|
||||
// ErrDryRunModeUnsupported dry run mode unsupported
|
||||
ErrDryRunModeUnsupported = errors.New("dry run mode unsupported")
|
||||
// ErrInvalidDB invalid db
|
||||
ErrInvalidDB = errors.New("invalid db")
|
||||
// ErrInvalidValue invalid value
|
||||
ErrInvalidValue = errors.New("invalid value, should be pointer to struct or slice")
|
||||
// ErrInvalidValueOfLength invalid values do not match length
|
||||
ErrInvalidValueOfLength = errors.New("invalid association values, length doesn't match")
|
||||
)
|
||||
|
||||
// Errors contains all happened errors
|
||||
type Errors []error
|
||||
|
||||
// GetErrors gets all happened errors
|
||||
func (errs Errors) GetErrors() []error {
|
||||
return errs
|
||||
}
|
||||
|
||||
// Add adds an error
|
||||
func (errs Errors) Add(newErrors ...error) Errors {
|
||||
for _, err := range newErrors {
|
||||
if err == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if errors, ok := err.(Errors); ok {
|
||||
errs = errs.Add(errors...)
|
||||
} else {
|
||||
ok = true
|
||||
for _, e := range errs {
|
||||
if err == e {
|
||||
ok = false
|
||||
}
|
||||
}
|
||||
if ok {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return errs
|
||||
}
|
||||
|
||||
// Error format happened errors
|
||||
func (errs Errors) Error() string {
|
||||
var errors = []string{}
|
||||
for _, e := range errs {
|
||||
errors = append(errors, e.Error())
|
||||
}
|
||||
return strings.Join(errors, "; ")
|
||||
}
|
||||
|
||||
@ -1,20 +0,0 @@
|
||||
package gorm_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/jinzhu/gorm"
|
||||
)
|
||||
|
||||
func TestErrorsCanBeUsedOutsideGorm(t *testing.T) {
|
||||
errs := []error{errors.New("First"), errors.New("Second")}
|
||||
|
||||
gErrs := gorm.Errors(errs)
|
||||
gErrs = gErrs.Add(errors.New("Third"))
|
||||
gErrs = gErrs.Add(gErrs)
|
||||
|
||||
if gErrs.Error() != "First; Second; Third" {
|
||||
t.Fatalf("Gave wrong error, got %s", gErrs.Error())
|
||||
}
|
||||
}
|
||||
58
field.go
58
field.go
@ -1,58 +0,0 @@
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
// Field model field definition
|
||||
type Field struct {
|
||||
*StructField
|
||||
IsBlank bool
|
||||
Field reflect.Value
|
||||
}
|
||||
|
||||
// Set set a value to the field
|
||||
func (field *Field) Set(value interface{}) (err error) {
|
||||
if !field.Field.IsValid() {
|
||||
return errors.New("field value not valid")
|
||||
}
|
||||
|
||||
if !field.Field.CanAddr() {
|
||||
return ErrUnaddressable
|
||||
}
|
||||
|
||||
reflectValue, ok := value.(reflect.Value)
|
||||
if !ok {
|
||||
reflectValue = reflect.ValueOf(value)
|
||||
}
|
||||
|
||||
fieldValue := field.Field
|
||||
if reflectValue.IsValid() {
|
||||
if reflectValue.Type().ConvertibleTo(fieldValue.Type()) {
|
||||
fieldValue.Set(reflectValue.Convert(fieldValue.Type()))
|
||||
} else {
|
||||
if fieldValue.Kind() == reflect.Ptr {
|
||||
if fieldValue.IsNil() {
|
||||
fieldValue.Set(reflect.New(field.Struct.Type.Elem()))
|
||||
}
|
||||
fieldValue = fieldValue.Elem()
|
||||
}
|
||||
|
||||
if reflectValue.Type().ConvertibleTo(fieldValue.Type()) {
|
||||
fieldValue.Set(reflectValue.Convert(fieldValue.Type()))
|
||||
} else if scanner, ok := fieldValue.Addr().Interface().(sql.Scanner); ok {
|
||||
err = scanner.Scan(reflectValue.Interface())
|
||||
} else {
|
||||
err = fmt.Errorf("could not convert argument of field %s from %s to %s", field.Name, reflectValue.Type(), fieldValue.Type())
|
||||
}
|
||||
}
|
||||
} else {
|
||||
field.Field.Set(reflect.Zero(field.Field.Type()))
|
||||
}
|
||||
|
||||
field.IsBlank = isBlank(field.Field)
|
||||
return err
|
||||
}
|
||||
@ -1,49 +0,0 @@
|
||||
package gorm_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/jinzhu/gorm"
|
||||
)
|
||||
|
||||
type CalculateField struct {
|
||||
gorm.Model
|
||||
Name string
|
||||
Children []CalculateFieldChild
|
||||
Category CalculateFieldCategory
|
||||
EmbeddedField
|
||||
}
|
||||
|
||||
type EmbeddedField struct {
|
||||
EmbeddedName string `sql:"NOT NULL;DEFAULT:'hello'"`
|
||||
}
|
||||
|
||||
type CalculateFieldChild struct {
|
||||
gorm.Model
|
||||
CalculateFieldID uint
|
||||
Name string
|
||||
}
|
||||
|
||||
type CalculateFieldCategory struct {
|
||||
gorm.Model
|
||||
CalculateFieldID uint
|
||||
Name string
|
||||
}
|
||||
|
||||
func TestCalculateField(t *testing.T) {
|
||||
var field CalculateField
|
||||
var scope = DB.NewScope(&field)
|
||||
if field, ok := scope.FieldByName("Children"); !ok || field.Relationship == nil {
|
||||
t.Errorf("Should calculate fields correctly for the first time")
|
||||
}
|
||||
|
||||
if field, ok := scope.FieldByName("Category"); !ok || field.Relationship == nil {
|
||||
t.Errorf("Should calculate fields correctly for the first time")
|
||||
}
|
||||
|
||||
if field, ok := scope.FieldByName("embedded_name"); !ok {
|
||||
t.Errorf("should find embedded field")
|
||||
} else if _, ok := field.TagSettings["NOT NULL"]; !ok {
|
||||
t.Errorf("should find embedded field's tag settings")
|
||||
}
|
||||
}
|
||||
649
finisher_api.go
Normal file
649
finisher_api.go
Normal file
@ -0,0 +1,649 @@
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/logger"
|
||||
"gorm.io/gorm/schema"
|
||||
"gorm.io/gorm/utils"
|
||||
)
|
||||
|
||||
// Create insert the value into database
|
||||
func (db *DB) Create(value interface{}) (tx *DB) {
|
||||
if db.CreateBatchSize > 0 {
|
||||
return db.CreateInBatches(value, db.CreateBatchSize)
|
||||
}
|
||||
|
||||
tx = db.getInstance()
|
||||
tx.Statement.Dest = value
|
||||
tx.callbacks.Create().Execute(tx)
|
||||
return
|
||||
}
|
||||
|
||||
// CreateInBatches insert the value in batches into database
|
||||
func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) {
|
||||
reflectValue := reflect.Indirect(reflect.ValueOf(value))
|
||||
|
||||
switch reflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
var rowsAffected int64
|
||||
tx = db.getInstance()
|
||||
|
||||
callFc := func(tx *DB) error {
|
||||
// the reflection length judgment of the optimized value
|
||||
reflectLen := reflectValue.Len()
|
||||
for i := 0; i < reflectLen; i += batchSize {
|
||||
ends := i + batchSize
|
||||
if ends > reflectLen {
|
||||
ends = reflectLen
|
||||
}
|
||||
|
||||
subtx := tx.getInstance()
|
||||
subtx.Statement.Dest = reflectValue.Slice(i, ends).Interface()
|
||||
subtx.callbacks.Create().Execute(subtx)
|
||||
if subtx.Error != nil {
|
||||
return subtx.Error
|
||||
}
|
||||
rowsAffected += subtx.RowsAffected
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if tx.SkipDefaultTransaction {
|
||||
tx.AddError(callFc(tx.Session(&Session{})))
|
||||
} else {
|
||||
tx.AddError(tx.Transaction(callFc))
|
||||
}
|
||||
|
||||
tx.RowsAffected = rowsAffected
|
||||
default:
|
||||
tx = db.getInstance()
|
||||
tx.Statement.Dest = value
|
||||
tx.callbacks.Create().Execute(tx)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Save update value in database, if the value doesn't have primary key, will insert it
|
||||
func (db *DB) Save(value interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.Dest = value
|
||||
|
||||
reflectValue := reflect.Indirect(reflect.ValueOf(value))
|
||||
switch reflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
if _, ok := tx.Statement.Clauses["ON CONFLICT"]; !ok {
|
||||
tx = tx.Clauses(clause.OnConflict{UpdateAll: true})
|
||||
}
|
||||
tx.callbacks.Create().Execute(tx.InstanceSet("gorm:update_track_time", true))
|
||||
case reflect.Struct:
|
||||
if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil {
|
||||
for _, pf := range tx.Statement.Schema.PrimaryFields {
|
||||
if _, isZero := pf.ValueOf(reflectValue); isZero {
|
||||
tx.callbacks.Create().Execute(tx)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fallthrough
|
||||
default:
|
||||
selectedUpdate := len(tx.Statement.Selects) != 0
|
||||
// when updating, use all fields including those zero-value fields
|
||||
if !selectedUpdate {
|
||||
tx.Statement.Selects = append(tx.Statement.Selects, "*")
|
||||
}
|
||||
|
||||
tx.callbacks.Update().Execute(tx)
|
||||
|
||||
if tx.Error == nil && tx.RowsAffected == 0 && !tx.DryRun && !selectedUpdate {
|
||||
result := reflect.New(tx.Statement.Schema.ModelType).Interface()
|
||||
if err := tx.Session(&Session{}).First(result).Error; errors.Is(err, ErrRecordNotFound) {
|
||||
return tx.Create(value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// First find first record that match given conditions, order by primary key
|
||||
func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) {
|
||||
tx = db.Limit(1).Order(clause.OrderByColumn{
|
||||
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
|
||||
})
|
||||
if len(conds) > 0 {
|
||||
if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
|
||||
tx.Statement.AddClause(clause.Where{Exprs: exprs})
|
||||
}
|
||||
}
|
||||
tx.Statement.RaiseErrorOnNotFound = true
|
||||
tx.Statement.Dest = dest
|
||||
tx.callbacks.Query().Execute(tx)
|
||||
return
|
||||
}
|
||||
|
||||
// Take return a record that match given conditions, the order will depend on the database implementation
|
||||
func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) {
|
||||
tx = db.Limit(1)
|
||||
if len(conds) > 0 {
|
||||
if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
|
||||
tx.Statement.AddClause(clause.Where{Exprs: exprs})
|
||||
}
|
||||
}
|
||||
tx.Statement.RaiseErrorOnNotFound = true
|
||||
tx.Statement.Dest = dest
|
||||
tx.callbacks.Query().Execute(tx)
|
||||
return
|
||||
}
|
||||
|
||||
// Last find last record that match given conditions, order by primary key
|
||||
func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) {
|
||||
tx = db.Limit(1).Order(clause.OrderByColumn{
|
||||
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
|
||||
Desc: true,
|
||||
})
|
||||
if len(conds) > 0 {
|
||||
if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
|
||||
tx.Statement.AddClause(clause.Where{Exprs: exprs})
|
||||
}
|
||||
}
|
||||
tx.Statement.RaiseErrorOnNotFound = true
|
||||
tx.Statement.Dest = dest
|
||||
tx.callbacks.Query().Execute(tx)
|
||||
return
|
||||
}
|
||||
|
||||
// Find find records that match given conditions
|
||||
func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
if len(conds) > 0 {
|
||||
if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
|
||||
tx.Statement.AddClause(clause.Where{Exprs: exprs})
|
||||
}
|
||||
}
|
||||
tx.Statement.Dest = dest
|
||||
tx.callbacks.Query().Execute(tx)
|
||||
return
|
||||
}
|
||||
|
||||
// FindInBatches find records in batches
|
||||
func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, batch int) error) *DB {
|
||||
var (
|
||||
tx = db.Order(clause.OrderByColumn{
|
||||
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
|
||||
}).Session(&Session{})
|
||||
queryDB = tx
|
||||
rowsAffected int64
|
||||
batch int
|
||||
)
|
||||
|
||||
for {
|
||||
result := queryDB.Limit(batchSize).Find(dest)
|
||||
rowsAffected += result.RowsAffected
|
||||
batch++
|
||||
|
||||
if result.Error == nil && result.RowsAffected != 0 {
|
||||
tx.AddError(fc(result, batch))
|
||||
} else if result.Error != nil {
|
||||
tx.AddError(result.Error)
|
||||
}
|
||||
|
||||
if tx.Error != nil || int(result.RowsAffected) < batchSize {
|
||||
break
|
||||
} else {
|
||||
resultsValue := reflect.Indirect(reflect.ValueOf(dest))
|
||||
if result.Statement.Schema.PrioritizedPrimaryField == nil {
|
||||
tx.AddError(ErrPrimaryKeyRequired)
|
||||
break
|
||||
} else {
|
||||
primaryValue, _ := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(resultsValue.Index(resultsValue.Len() - 1))
|
||||
queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tx.RowsAffected = rowsAffected
|
||||
return tx
|
||||
}
|
||||
|
||||
func (tx *DB) assignInterfacesToValue(values ...interface{}) {
|
||||
for _, value := range values {
|
||||
switch v := value.(type) {
|
||||
case []clause.Expression:
|
||||
for _, expr := range v {
|
||||
if eq, ok := expr.(clause.Eq); ok {
|
||||
switch column := eq.Column.(type) {
|
||||
case string:
|
||||
if field := tx.Statement.Schema.LookUpField(column); field != nil {
|
||||
tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value))
|
||||
}
|
||||
case clause.Column:
|
||||
if field := tx.Statement.Schema.LookUpField(column.Name); field != nil {
|
||||
tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value))
|
||||
}
|
||||
}
|
||||
} else if andCond, ok := expr.(clause.AndConditions); ok {
|
||||
tx.assignInterfacesToValue(andCond.Exprs)
|
||||
}
|
||||
}
|
||||
case clause.Expression, map[string]string, map[interface{}]interface{}, map[string]interface{}:
|
||||
if exprs := tx.Statement.BuildCondition(value); len(exprs) > 0 {
|
||||
tx.assignInterfacesToValue(exprs)
|
||||
}
|
||||
default:
|
||||
if s, err := schema.Parse(value, tx.cacheStore, tx.NamingStrategy); err == nil {
|
||||
reflectValue := reflect.Indirect(reflect.ValueOf(value))
|
||||
switch reflectValue.Kind() {
|
||||
case reflect.Struct:
|
||||
for _, f := range s.Fields {
|
||||
if f.Readable {
|
||||
if v, isZero := f.ValueOf(reflectValue); !isZero {
|
||||
if field := tx.Statement.Schema.LookUpField(f.Name); field != nil {
|
||||
tx.AddError(field.Set(tx.Statement.ReflectValue, v))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if len(values) > 0 {
|
||||
if exprs := tx.Statement.BuildCondition(values[0], values[1:]...); len(exprs) > 0 {
|
||||
tx.assignInterfacesToValue(exprs)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) {
|
||||
queryTx := db.Limit(1).Order(clause.OrderByColumn{
|
||||
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
|
||||
})
|
||||
|
||||
if tx = queryTx.Find(dest, conds...); queryTx.RowsAffected == 0 {
|
||||
if c, ok := tx.Statement.Clauses["WHERE"]; ok {
|
||||
if where, ok := c.Expression.(clause.Where); ok {
|
||||
tx.assignInterfacesToValue(where.Exprs)
|
||||
}
|
||||
}
|
||||
|
||||
// initialize with attrs, conds
|
||||
if len(tx.Statement.attrs) > 0 {
|
||||
tx.assignInterfacesToValue(tx.Statement.attrs...)
|
||||
}
|
||||
}
|
||||
|
||||
// initialize with attrs, conds
|
||||
if len(tx.Statement.assigns) > 0 {
|
||||
tx.assignInterfacesToValue(tx.Statement.assigns...)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) {
|
||||
queryTx := db.Limit(1).Order(clause.OrderByColumn{
|
||||
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
|
||||
})
|
||||
|
||||
if tx = queryTx.Find(dest, conds...); queryTx.RowsAffected == 0 {
|
||||
if c, ok := tx.Statement.Clauses["WHERE"]; ok {
|
||||
if where, ok := c.Expression.(clause.Where); ok {
|
||||
tx.assignInterfacesToValue(where.Exprs)
|
||||
}
|
||||
}
|
||||
|
||||
// initialize with attrs, conds
|
||||
if len(tx.Statement.attrs) > 0 {
|
||||
tx.assignInterfacesToValue(tx.Statement.attrs...)
|
||||
}
|
||||
|
||||
// initialize with attrs, conds
|
||||
if len(tx.Statement.assigns) > 0 {
|
||||
tx.assignInterfacesToValue(tx.Statement.assigns...)
|
||||
}
|
||||
|
||||
return tx.Create(dest)
|
||||
} else if len(db.Statement.assigns) > 0 {
|
||||
exprs := tx.Statement.BuildCondition(tx.Statement.assigns[0], tx.Statement.assigns[1:]...)
|
||||
assigns := map[string]interface{}{}
|
||||
for _, expr := range exprs {
|
||||
if eq, ok := expr.(clause.Eq); ok {
|
||||
switch column := eq.Column.(type) {
|
||||
case string:
|
||||
assigns[column] = eq.Value
|
||||
case clause.Column:
|
||||
assigns[column.Name] = eq.Value
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return tx.Model(dest).Updates(assigns)
|
||||
}
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
// Update update attributes with callbacks, refer: https://gorm.io/docs/update.html#Update-Changed-Fields
|
||||
func (db *DB) Update(column string, value interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.Dest = map[string]interface{}{column: value}
|
||||
tx.callbacks.Update().Execute(tx)
|
||||
return
|
||||
}
|
||||
|
||||
// Updates update attributes with callbacks, refer: https://gorm.io/docs/update.html#Update-Changed-Fields
|
||||
func (db *DB) Updates(values interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.Dest = values
|
||||
tx.callbacks.Update().Execute(tx)
|
||||
return
|
||||
}
|
||||
|
||||
func (db *DB) UpdateColumn(column string, value interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.Dest = map[string]interface{}{column: value}
|
||||
tx.Statement.SkipHooks = true
|
||||
tx.callbacks.Update().Execute(tx)
|
||||
return
|
||||
}
|
||||
|
||||
func (db *DB) UpdateColumns(values interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.Dest = values
|
||||
tx.Statement.SkipHooks = true
|
||||
tx.callbacks.Update().Execute(tx)
|
||||
return
|
||||
}
|
||||
|
||||
// Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition
|
||||
func (db *DB) Delete(value interface{}, conds ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
if len(conds) > 0 {
|
||||
if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
|
||||
tx.Statement.AddClause(clause.Where{Exprs: exprs})
|
||||
}
|
||||
}
|
||||
tx.Statement.Dest = value
|
||||
tx.callbacks.Delete().Execute(tx)
|
||||
return
|
||||
}
|
||||
|
||||
func (db *DB) Count(count *int64) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
if tx.Statement.Model == nil {
|
||||
tx.Statement.Model = tx.Statement.Dest
|
||||
defer func() {
|
||||
tx.Statement.Model = nil
|
||||
}()
|
||||
}
|
||||
|
||||
if selectClause, ok := db.Statement.Clauses["SELECT"]; ok {
|
||||
defer func() {
|
||||
db.Statement.Clauses["SELECT"] = selectClause
|
||||
}()
|
||||
} else {
|
||||
defer delete(tx.Statement.Clauses, "SELECT")
|
||||
}
|
||||
|
||||
if len(tx.Statement.Selects) == 0 {
|
||||
tx.Statement.AddClause(clause.Select{Expression: clause.Expr{SQL: "count(1)"}})
|
||||
} else if !strings.HasPrefix(strings.TrimSpace(strings.ToLower(tx.Statement.Selects[0])), "count(") {
|
||||
expr := clause.Expr{SQL: "count(1)"}
|
||||
|
||||
if len(tx.Statement.Selects) == 1 {
|
||||
dbName := tx.Statement.Selects[0]
|
||||
fields := strings.FieldsFunc(dbName, utils.IsValidDBNameChar)
|
||||
if len(fields) == 1 || (len(fields) == 3 && strings.ToUpper(fields[1]) == "AS") {
|
||||
if tx.Statement.Parse(tx.Statement.Model) == nil {
|
||||
if f := tx.Statement.Schema.LookUpField(dbName); f != nil {
|
||||
dbName = f.DBName
|
||||
}
|
||||
}
|
||||
|
||||
if tx.Statement.Distinct {
|
||||
expr = clause.Expr{SQL: "COUNT(DISTINCT(?))", Vars: []interface{}{clause.Column{Name: dbName}}}
|
||||
} else {
|
||||
expr = clause.Expr{SQL: "COUNT(?)", Vars: []interface{}{clause.Column{Name: dbName}}}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tx.Statement.AddClause(clause.Select{Expression: expr})
|
||||
}
|
||||
|
||||
if orderByClause, ok := db.Statement.Clauses["ORDER BY"]; ok {
|
||||
if _, ok := db.Statement.Clauses["GROUP BY"]; !ok {
|
||||
delete(db.Statement.Clauses, "ORDER BY")
|
||||
defer func() {
|
||||
db.Statement.Clauses["ORDER BY"] = orderByClause
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
tx.Statement.Dest = count
|
||||
tx.callbacks.Query().Execute(tx)
|
||||
if tx.RowsAffected != 1 {
|
||||
*count = tx.RowsAffected
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (db *DB) Row() *sql.Row {
|
||||
tx := db.getInstance().InstanceSet("rows", false)
|
||||
tx.callbacks.Row().Execute(tx)
|
||||
row, ok := tx.Statement.Dest.(*sql.Row)
|
||||
if !ok && tx.DryRun {
|
||||
db.Logger.Error(tx.Statement.Context, ErrDryRunModeUnsupported.Error())
|
||||
}
|
||||
return row
|
||||
}
|
||||
|
||||
func (db *DB) Rows() (*sql.Rows, error) {
|
||||
tx := db.getInstance().InstanceSet("rows", true)
|
||||
tx.callbacks.Row().Execute(tx)
|
||||
rows, ok := tx.Statement.Dest.(*sql.Rows)
|
||||
if !ok && tx.DryRun && tx.Error == nil {
|
||||
tx.Error = ErrDryRunModeUnsupported
|
||||
}
|
||||
return rows, tx.Error
|
||||
}
|
||||
|
||||
// Scan scan value to a struct
|
||||
func (db *DB) Scan(dest interface{}) (tx *DB) {
|
||||
config := *db.Config
|
||||
currentLogger, newLogger := config.Logger, logger.Recorder.New()
|
||||
config.Logger = newLogger
|
||||
|
||||
tx = db.getInstance()
|
||||
tx.Config = &config
|
||||
|
||||
if rows, err := tx.Rows(); err != nil {
|
||||
tx.AddError(err)
|
||||
} else {
|
||||
defer rows.Close()
|
||||
if rows.Next() {
|
||||
tx.ScanRows(rows, dest)
|
||||
} else {
|
||||
tx.RowsAffected = 0
|
||||
}
|
||||
}
|
||||
|
||||
currentLogger.Trace(tx.Statement.Context, newLogger.BeginAt, func() (string, int64) {
|
||||
return newLogger.SQL, tx.RowsAffected
|
||||
}, tx.Error)
|
||||
tx.Logger = currentLogger
|
||||
return
|
||||
}
|
||||
|
||||
// Pluck used to query single column from a model as a map
|
||||
// var ages []int64
|
||||
// db.Find(&users).Pluck("age", &ages)
|
||||
func (db *DB) Pluck(column string, dest interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
if tx.Statement.Model != nil {
|
||||
if tx.Statement.Parse(tx.Statement.Model) == nil {
|
||||
if f := tx.Statement.Schema.LookUpField(column); f != nil {
|
||||
column = f.DBName
|
||||
}
|
||||
}
|
||||
} else if tx.Statement.Table == "" {
|
||||
tx.AddError(ErrModelValueRequired)
|
||||
}
|
||||
|
||||
if len(tx.Statement.Selects) != 1 {
|
||||
fields := strings.FieldsFunc(column, utils.IsValidDBNameChar)
|
||||
tx.Statement.AddClauseIfNotExists(clause.Select{
|
||||
Distinct: tx.Statement.Distinct,
|
||||
Columns: []clause.Column{{Name: column, Raw: len(fields) != 1}},
|
||||
})
|
||||
}
|
||||
tx.Statement.Dest = dest
|
||||
tx.callbacks.Query().Execute(tx)
|
||||
return
|
||||
}
|
||||
|
||||
func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error {
|
||||
tx := db.getInstance()
|
||||
if err := tx.Statement.Parse(dest); !errors.Is(err, schema.ErrUnsupportedDataType) {
|
||||
tx.AddError(err)
|
||||
}
|
||||
tx.Statement.Dest = dest
|
||||
tx.Statement.ReflectValue = reflect.ValueOf(dest)
|
||||
for tx.Statement.ReflectValue.Kind() == reflect.Ptr {
|
||||
tx.Statement.ReflectValue = tx.Statement.ReflectValue.Elem()
|
||||
}
|
||||
Scan(rows, tx, true)
|
||||
return tx.Error
|
||||
}
|
||||
|
||||
// Transaction start a transaction as a block, return error will rollback, otherwise to commit.
|
||||
func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err error) {
|
||||
panicked := true
|
||||
|
||||
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
|
||||
// nested transaction
|
||||
if !db.DisableNestedTransaction {
|
||||
err = db.SavePoint(fmt.Sprintf("sp%p", fc)).Error
|
||||
defer func() {
|
||||
// Make sure to rollback when panic, Block error or Commit error
|
||||
if panicked || err != nil {
|
||||
db.RollbackTo(fmt.Sprintf("sp%p", fc))
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
if err == nil {
|
||||
err = fc(db.Session(&Session{}))
|
||||
}
|
||||
} else {
|
||||
tx := db.Begin(opts...)
|
||||
|
||||
defer func() {
|
||||
// Make sure to rollback when panic, Block error or Commit error
|
||||
if panicked || err != nil {
|
||||
tx.Rollback()
|
||||
}
|
||||
}()
|
||||
|
||||
if err = tx.Error; err == nil {
|
||||
err = fc(tx)
|
||||
}
|
||||
|
||||
if err == nil {
|
||||
err = tx.Commit().Error
|
||||
}
|
||||
}
|
||||
|
||||
panicked = false
|
||||
return
|
||||
}
|
||||
|
||||
// Begin begins a transaction
|
||||
func (db *DB) Begin(opts ...*sql.TxOptions) *DB {
|
||||
var (
|
||||
// clone statement
|
||||
tx = db.getInstance().Session(&Session{Context: db.Statement.Context})
|
||||
opt *sql.TxOptions
|
||||
err error
|
||||
)
|
||||
|
||||
if len(opts) > 0 {
|
||||
opt = opts[0]
|
||||
}
|
||||
|
||||
if beginner, ok := tx.Statement.ConnPool.(TxBeginner); ok {
|
||||
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
|
||||
} else if beginner, ok := tx.Statement.ConnPool.(ConnPoolBeginner); ok {
|
||||
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
|
||||
} else {
|
||||
err = ErrInvalidTransaction
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
tx.AddError(err)
|
||||
}
|
||||
|
||||
return tx
|
||||
}
|
||||
|
||||
// Commit commit a transaction
|
||||
func (db *DB) Commit() *DB {
|
||||
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil && !reflect.ValueOf(committer).IsNil() {
|
||||
db.AddError(committer.Commit())
|
||||
} else {
|
||||
db.AddError(ErrInvalidTransaction)
|
||||
}
|
||||
return db
|
||||
}
|
||||
|
||||
// Rollback rollback a transaction
|
||||
func (db *DB) Rollback() *DB {
|
||||
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
|
||||
if !reflect.ValueOf(committer).IsNil() {
|
||||
db.AddError(committer.Rollback())
|
||||
}
|
||||
} else {
|
||||
db.AddError(ErrInvalidTransaction)
|
||||
}
|
||||
return db
|
||||
}
|
||||
|
||||
func (db *DB) SavePoint(name string) *DB {
|
||||
if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok {
|
||||
db.AddError(savePointer.SavePoint(db, name))
|
||||
} else {
|
||||
db.AddError(ErrUnsupportedDriver)
|
||||
}
|
||||
return db
|
||||
}
|
||||
|
||||
func (db *DB) RollbackTo(name string) *DB {
|
||||
if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok {
|
||||
db.AddError(savePointer.RollbackTo(db, name))
|
||||
} else {
|
||||
db.AddError(ErrUnsupportedDriver)
|
||||
}
|
||||
return db
|
||||
}
|
||||
|
||||
// Exec execute raw sql
|
||||
func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.SQL = strings.Builder{}
|
||||
|
||||
if strings.Contains(sql, "@") {
|
||||
clause.NamedExpr{SQL: sql, Vars: values}.Build(tx.Statement)
|
||||
} else {
|
||||
clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement)
|
||||
}
|
||||
|
||||
tx.callbacks.Raw().Execute(tx)
|
||||
return
|
||||
}
|
||||
8
go.mod
Normal file
8
go.mod
Normal file
@ -0,0 +1,8 @@
|
||||
module gorm.io/gorm
|
||||
|
||||
go 1.14
|
||||
|
||||
require (
|
||||
github.com/jinzhu/inflection v1.0.0
|
||||
github.com/jinzhu/now v1.1.2
|
||||
)
|
||||
4
go.sum
Normal file
4
go.sum
Normal file
@ -0,0 +1,4 @@
|
||||
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
||||
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
||||
github.com/jinzhu/now v1.1.2 h1:eVKgfIdy9b6zbWBMgFpfDPoAMifwSZagU9HmEU6zgiI=
|
||||
github.com/jinzhu/now v1.1.2/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
||||
441
gorm.go
Normal file
441
gorm.go
Normal file
@ -0,0 +1,441 @@
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/logger"
|
||||
"gorm.io/gorm/schema"
|
||||
)
|
||||
|
||||
// for Config.cacheStore store PreparedStmtDB key
|
||||
const preparedStmtDBKey = "preparedStmt"
|
||||
|
||||
// Config GORM config
|
||||
type Config struct {
|
||||
// GORM perform single create, update, delete operations in transactions by default to ensure database data integrity
|
||||
// You can disable it by setting `SkipDefaultTransaction` to true
|
||||
SkipDefaultTransaction bool
|
||||
// NamingStrategy tables, columns naming strategy
|
||||
NamingStrategy schema.Namer
|
||||
// FullSaveAssociations full save associations
|
||||
FullSaveAssociations bool
|
||||
// Logger
|
||||
Logger logger.Interface
|
||||
// NowFunc the function to be used when creating a new timestamp
|
||||
NowFunc func() time.Time
|
||||
// DryRun generate sql without execute
|
||||
DryRun bool
|
||||
// PrepareStmt executes the given query in cached statement
|
||||
PrepareStmt bool
|
||||
// DisableAutomaticPing
|
||||
DisableAutomaticPing bool
|
||||
// DisableForeignKeyConstraintWhenMigrating
|
||||
DisableForeignKeyConstraintWhenMigrating bool
|
||||
// DisableNestedTransaction disable nested transaction
|
||||
DisableNestedTransaction bool
|
||||
// AllowGlobalUpdate allow global update
|
||||
AllowGlobalUpdate bool
|
||||
// QueryFields executes the SQL query with all fields of the table
|
||||
QueryFields bool
|
||||
// CreateBatchSize default create batch size
|
||||
CreateBatchSize int
|
||||
|
||||
// ClauseBuilders clause builder
|
||||
ClauseBuilders map[string]clause.ClauseBuilder
|
||||
// ConnPool db conn pool
|
||||
ConnPool ConnPool
|
||||
// Dialector database dialector
|
||||
Dialector
|
||||
// Plugins registered plugins
|
||||
Plugins map[string]Plugin
|
||||
|
||||
callbacks *callbacks
|
||||
cacheStore *sync.Map
|
||||
}
|
||||
|
||||
func (c *Config) Apply(config *Config) error {
|
||||
if config != c {
|
||||
*config = *c
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Config) AfterInitialize(db *DB) error {
|
||||
if db != nil {
|
||||
for _, plugin := range c.Plugins {
|
||||
if err := plugin.Initialize(db); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type Option interface {
|
||||
Apply(*Config) error
|
||||
AfterInitialize(*DB) error
|
||||
}
|
||||
|
||||
// DB GORM DB definition
|
||||
type DB struct {
|
||||
*Config
|
||||
Error error
|
||||
RowsAffected int64
|
||||
Statement *Statement
|
||||
clone int
|
||||
}
|
||||
|
||||
// Session session config when create session with Session() method
|
||||
type Session struct {
|
||||
DryRun bool
|
||||
PrepareStmt bool
|
||||
NewDB bool
|
||||
SkipHooks bool
|
||||
SkipDefaultTransaction bool
|
||||
DisableNestedTransaction bool
|
||||
AllowGlobalUpdate bool
|
||||
FullSaveAssociations bool
|
||||
QueryFields bool
|
||||
Context context.Context
|
||||
Logger logger.Interface
|
||||
NowFunc func() time.Time
|
||||
CreateBatchSize int
|
||||
}
|
||||
|
||||
// Open initialize db session based on dialector
|
||||
func Open(dialector Dialector, opts ...Option) (db *DB, err error) {
|
||||
config := &Config{}
|
||||
|
||||
sort.Slice(opts, func(i, j int) bool {
|
||||
_, isConfig := opts[i].(*Config)
|
||||
_, isConfig2 := opts[j].(*Config)
|
||||
return isConfig && !isConfig2
|
||||
})
|
||||
|
||||
for _, opt := range opts {
|
||||
if opt != nil {
|
||||
if err := opt.Apply(config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func(opt Option) {
|
||||
if errr := opt.AfterInitialize(db); errr != nil {
|
||||
err = errr
|
||||
}
|
||||
}(opt)
|
||||
}
|
||||
}
|
||||
|
||||
if d, ok := dialector.(interface{ Apply(*Config) error }); ok {
|
||||
if err = d.Apply(config); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if config.NamingStrategy == nil {
|
||||
config.NamingStrategy = schema.NamingStrategy{}
|
||||
}
|
||||
|
||||
if config.Logger == nil {
|
||||
config.Logger = logger.Default
|
||||
}
|
||||
|
||||
if config.NowFunc == nil {
|
||||
config.NowFunc = func() time.Time { return time.Now().Local() }
|
||||
}
|
||||
|
||||
if dialector != nil {
|
||||
config.Dialector = dialector
|
||||
}
|
||||
|
||||
if config.Plugins == nil {
|
||||
config.Plugins = map[string]Plugin{}
|
||||
}
|
||||
|
||||
if config.cacheStore == nil {
|
||||
config.cacheStore = &sync.Map{}
|
||||
}
|
||||
|
||||
db = &DB{Config: config, clone: 1}
|
||||
|
||||
db.callbacks = initializeCallbacks(db)
|
||||
|
||||
if config.ClauseBuilders == nil {
|
||||
config.ClauseBuilders = map[string]clause.ClauseBuilder{}
|
||||
}
|
||||
|
||||
if config.Dialector != nil {
|
||||
err = config.Dialector.Initialize(db)
|
||||
}
|
||||
|
||||
preparedStmt := &PreparedStmtDB{
|
||||
ConnPool: db.ConnPool,
|
||||
Stmts: map[string]Stmt{},
|
||||
Mux: &sync.RWMutex{},
|
||||
PreparedSQL: make([]string, 0, 100),
|
||||
}
|
||||
db.cacheStore.Store(preparedStmtDBKey, preparedStmt)
|
||||
|
||||
if config.PrepareStmt {
|
||||
db.ConnPool = preparedStmt
|
||||
}
|
||||
|
||||
db.Statement = &Statement{
|
||||
DB: db,
|
||||
ConnPool: db.ConnPool,
|
||||
Context: context.Background(),
|
||||
Clauses: map[string]clause.Clause{},
|
||||
}
|
||||
|
||||
if err == nil && !config.DisableAutomaticPing {
|
||||
if pinger, ok := db.ConnPool.(interface{ Ping() error }); ok {
|
||||
err = pinger.Ping()
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
config.Logger.Error(context.Background(), "failed to initialize database, got error %v", err)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Session create new db session
|
||||
func (db *DB) Session(config *Session) *DB {
|
||||
var (
|
||||
txConfig = *db.Config
|
||||
tx = &DB{
|
||||
Config: &txConfig,
|
||||
Statement: db.Statement,
|
||||
Error: db.Error,
|
||||
clone: 1,
|
||||
}
|
||||
)
|
||||
if config.CreateBatchSize > 0 {
|
||||
tx.Config.CreateBatchSize = config.CreateBatchSize
|
||||
}
|
||||
|
||||
if config.SkipDefaultTransaction {
|
||||
tx.Config.SkipDefaultTransaction = true
|
||||
}
|
||||
|
||||
if config.AllowGlobalUpdate {
|
||||
txConfig.AllowGlobalUpdate = true
|
||||
}
|
||||
|
||||
if config.FullSaveAssociations {
|
||||
txConfig.FullSaveAssociations = true
|
||||
}
|
||||
|
||||
if config.Context != nil || config.PrepareStmt || config.SkipHooks {
|
||||
tx.Statement = tx.Statement.clone()
|
||||
tx.Statement.DB = tx
|
||||
}
|
||||
|
||||
if config.Context != nil {
|
||||
tx.Statement.Context = config.Context
|
||||
}
|
||||
|
||||
if config.PrepareStmt {
|
||||
if v, ok := db.cacheStore.Load(preparedStmtDBKey); ok {
|
||||
preparedStmt := v.(*PreparedStmtDB)
|
||||
tx.Statement.ConnPool = &PreparedStmtDB{
|
||||
ConnPool: db.Config.ConnPool,
|
||||
Mux: preparedStmt.Mux,
|
||||
Stmts: preparedStmt.Stmts,
|
||||
}
|
||||
txConfig.ConnPool = tx.Statement.ConnPool
|
||||
txConfig.PrepareStmt = true
|
||||
}
|
||||
}
|
||||
|
||||
if config.SkipHooks {
|
||||
tx.Statement.SkipHooks = true
|
||||
}
|
||||
|
||||
if config.DisableNestedTransaction {
|
||||
txConfig.DisableNestedTransaction = true
|
||||
}
|
||||
|
||||
if !config.NewDB {
|
||||
tx.clone = 2
|
||||
}
|
||||
|
||||
if config.DryRun {
|
||||
tx.Config.DryRun = true
|
||||
}
|
||||
|
||||
if config.QueryFields {
|
||||
tx.Config.QueryFields = true
|
||||
}
|
||||
|
||||
if config.Logger != nil {
|
||||
tx.Config.Logger = config.Logger
|
||||
}
|
||||
|
||||
if config.NowFunc != nil {
|
||||
tx.Config.NowFunc = config.NowFunc
|
||||
}
|
||||
|
||||
return tx
|
||||
}
|
||||
|
||||
// WithContext change current instance db's context to ctx
|
||||
func (db *DB) WithContext(ctx context.Context) *DB {
|
||||
return db.Session(&Session{Context: ctx})
|
||||
}
|
||||
|
||||
// Debug start debug mode
|
||||
func (db *DB) Debug() (tx *DB) {
|
||||
return db.Session(&Session{
|
||||
Logger: db.Logger.LogMode(logger.Info),
|
||||
})
|
||||
}
|
||||
|
||||
// Set store value with key into current db instance's context
|
||||
func (db *DB) Set(key string, value interface{}) *DB {
|
||||
tx := db.getInstance()
|
||||
tx.Statement.Settings.Store(key, value)
|
||||
return tx
|
||||
}
|
||||
|
||||
// Get get value with key from current db instance's context
|
||||
func (db *DB) Get(key string) (interface{}, bool) {
|
||||
return db.Statement.Settings.Load(key)
|
||||
}
|
||||
|
||||
// InstanceSet store value with key into current db instance's context
|
||||
func (db *DB) InstanceSet(key string, value interface{}) *DB {
|
||||
tx := db.getInstance()
|
||||
tx.Statement.Settings.Store(fmt.Sprintf("%p", tx.Statement)+key, value)
|
||||
return tx
|
||||
}
|
||||
|
||||
// InstanceGet get value with key from current db instance's context
|
||||
func (db *DB) InstanceGet(key string) (interface{}, bool) {
|
||||
return db.Statement.Settings.Load(fmt.Sprintf("%p", db.Statement) + key)
|
||||
}
|
||||
|
||||
// Callback returns callback manager
|
||||
func (db *DB) Callback() *callbacks {
|
||||
return db.callbacks
|
||||
}
|
||||
|
||||
// AddError add error to db
|
||||
func (db *DB) AddError(err error) error {
|
||||
if db.Error == nil {
|
||||
db.Error = err
|
||||
} else if err != nil {
|
||||
db.Error = fmt.Errorf("%v; %w", db.Error, err)
|
||||
}
|
||||
return db.Error
|
||||
}
|
||||
|
||||
// DB returns `*sql.DB`
|
||||
func (db *DB) DB() (*sql.DB, error) {
|
||||
connPool := db.ConnPool
|
||||
|
||||
if dbConnector, ok := connPool.(GetDBConnector); ok && dbConnector != nil {
|
||||
return dbConnector.GetDBConn()
|
||||
}
|
||||
|
||||
if sqldb, ok := connPool.(*sql.DB); ok {
|
||||
return sqldb, nil
|
||||
}
|
||||
|
||||
return nil, ErrInvalidDB
|
||||
}
|
||||
|
||||
func (db *DB) getInstance() *DB {
|
||||
if db.clone > 0 {
|
||||
tx := &DB{Config: db.Config, Error: db.Error}
|
||||
|
||||
if db.clone == 1 {
|
||||
// clone with new statement
|
||||
tx.Statement = &Statement{
|
||||
DB: tx,
|
||||
ConnPool: db.Statement.ConnPool,
|
||||
Context: db.Statement.Context,
|
||||
Clauses: map[string]clause.Clause{},
|
||||
Vars: make([]interface{}, 0, 8),
|
||||
}
|
||||
} else {
|
||||
// with clone statement
|
||||
tx.Statement = db.Statement.clone()
|
||||
tx.Statement.DB = tx
|
||||
}
|
||||
|
||||
return tx
|
||||
}
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
func Expr(expr string, args ...interface{}) clause.Expr {
|
||||
return clause.Expr{SQL: expr, Vars: args}
|
||||
}
|
||||
|
||||
func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interface{}) error {
|
||||
var (
|
||||
tx = db.getInstance()
|
||||
stmt = tx.Statement
|
||||
modelSchema, joinSchema *schema.Schema
|
||||
)
|
||||
|
||||
if err := stmt.Parse(model); err == nil {
|
||||
modelSchema = stmt.Schema
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := stmt.Parse(joinTable); err == nil {
|
||||
joinSchema = stmt.Schema
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
|
||||
if relation, ok := modelSchema.Relationships.Relations[field]; ok && relation.JoinTable != nil {
|
||||
for _, ref := range relation.References {
|
||||
if f := joinSchema.LookUpField(ref.ForeignKey.DBName); f != nil {
|
||||
f.DataType = ref.ForeignKey.DataType
|
||||
f.GORMDataType = ref.ForeignKey.GORMDataType
|
||||
if f.Size == 0 {
|
||||
f.Size = ref.ForeignKey.Size
|
||||
}
|
||||
ref.ForeignKey = f
|
||||
} else {
|
||||
return fmt.Errorf("missing field %v for join table", ref.ForeignKey.DBName)
|
||||
}
|
||||
}
|
||||
|
||||
for name, rel := range relation.JoinTable.Relationships.Relations {
|
||||
if _, ok := joinSchema.Relationships.Relations[name]; !ok {
|
||||
rel.Schema = joinSchema
|
||||
joinSchema.Relationships.Relations[name] = rel
|
||||
}
|
||||
}
|
||||
|
||||
relation.JoinTable = joinSchema
|
||||
} else {
|
||||
return fmt.Errorf("failed to found relation: %v", field)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *DB) Use(plugin Plugin) error {
|
||||
name := plugin.Name()
|
||||
if _, ok := db.Plugins[name]; ok {
|
||||
return ErrRegistered
|
||||
}
|
||||
if err := plugin.Initialize(db); err != nil {
|
||||
return err
|
||||
}
|
||||
db.Plugins[name] = plugin
|
||||
return nil
|
||||
}
|
||||
20
interface.go
20
interface.go
@ -1,20 +0,0 @@
|
||||
package gorm
|
||||
|
||||
import "database/sql"
|
||||
|
||||
// SQLCommon is the minimal database connection functionality gorm requires. Implemented by *sql.DB.
|
||||
type SQLCommon interface {
|
||||
Exec(query string, args ...interface{}) (sql.Result, error)
|
||||
Prepare(query string) (*sql.Stmt, error)
|
||||
Query(query string, args ...interface{}) (*sql.Rows, error)
|
||||
QueryRow(query string, args ...interface{}) *sql.Row
|
||||
}
|
||||
|
||||
type sqlDb interface {
|
||||
Begin() (*sql.Tx, error)
|
||||
}
|
||||
|
||||
type sqlTx interface {
|
||||
Commit() error
|
||||
Rollback() error
|
||||
}
|
||||
63
interfaces.go
Normal file
63
interfaces.go
Normal file
@ -0,0 +1,63 @@
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/schema"
|
||||
)
|
||||
|
||||
// Dialector GORM database dialector
|
||||
type Dialector interface {
|
||||
Name() string
|
||||
Initialize(*DB) error
|
||||
Migrator(db *DB) Migrator
|
||||
DataTypeOf(*schema.Field) string
|
||||
DefaultValueOf(*schema.Field) clause.Expression
|
||||
BindVarTo(writer clause.Writer, stmt *Statement, v interface{})
|
||||
QuoteTo(clause.Writer, string)
|
||||
Explain(sql string, vars ...interface{}) string
|
||||
}
|
||||
|
||||
// Plugin GORM plugin interface
|
||||
type Plugin interface {
|
||||
Name() string
|
||||
Initialize(*DB) error
|
||||
}
|
||||
|
||||
// ConnPool db conns pool interface
|
||||
type ConnPool interface {
|
||||
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
|
||||
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
|
||||
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
|
||||
QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
|
||||
}
|
||||
|
||||
// SavePointerDialectorInterface save pointer interface
|
||||
type SavePointerDialectorInterface interface {
|
||||
SavePoint(tx *DB, name string) error
|
||||
RollbackTo(tx *DB, name string) error
|
||||
}
|
||||
|
||||
type TxBeginner interface {
|
||||
BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
|
||||
}
|
||||
|
||||
type ConnPoolBeginner interface {
|
||||
BeginTx(ctx context.Context, opts *sql.TxOptions) (ConnPool, error)
|
||||
}
|
||||
|
||||
type TxCommitter interface {
|
||||
Commit() error
|
||||
Rollback() error
|
||||
}
|
||||
|
||||
// Valuer gorm valuer interface
|
||||
type Valuer interface {
|
||||
GormValue(context.Context, *DB) clause.Expr
|
||||
}
|
||||
|
||||
type GetDBConnector interface {
|
||||
GetDBConn() (*sql.DB, error)
|
||||
}
|
||||
@ -1,223 +0,0 @@
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// JoinTableHandlerInterface is an interface for how to handle many2many relations
|
||||
type JoinTableHandlerInterface interface {
|
||||
// initialize join table handler
|
||||
Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type)
|
||||
// Table return join table's table name
|
||||
Table(db *DB) string
|
||||
// Add create relationship in join table for source and destination
|
||||
Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error
|
||||
// Delete delete relationship in join table for sources
|
||||
Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error
|
||||
// JoinWith query with `Join` conditions
|
||||
JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB
|
||||
// SourceForeignKeys return source foreign keys
|
||||
SourceForeignKeys() []JoinTableForeignKey
|
||||
// DestinationForeignKeys return destination foreign keys
|
||||
DestinationForeignKeys() []JoinTableForeignKey
|
||||
}
|
||||
|
||||
// JoinTableForeignKey join table foreign key struct
|
||||
type JoinTableForeignKey struct {
|
||||
DBName string
|
||||
AssociationDBName string
|
||||
}
|
||||
|
||||
// JoinTableSource is a struct that contains model type and foreign keys
|
||||
type JoinTableSource struct {
|
||||
ModelType reflect.Type
|
||||
ForeignKeys []JoinTableForeignKey
|
||||
}
|
||||
|
||||
// JoinTableHandler default join table handler
|
||||
type JoinTableHandler struct {
|
||||
TableName string `sql:"-"`
|
||||
Source JoinTableSource `sql:"-"`
|
||||
Destination JoinTableSource `sql:"-"`
|
||||
}
|
||||
|
||||
// SourceForeignKeys return source foreign keys
|
||||
func (s *JoinTableHandler) SourceForeignKeys() []JoinTableForeignKey {
|
||||
return s.Source.ForeignKeys
|
||||
}
|
||||
|
||||
// DestinationForeignKeys return destination foreign keys
|
||||
func (s *JoinTableHandler) DestinationForeignKeys() []JoinTableForeignKey {
|
||||
return s.Destination.ForeignKeys
|
||||
}
|
||||
|
||||
// Setup initialize a default join table handler
|
||||
func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type) {
|
||||
s.TableName = tableName
|
||||
|
||||
s.Source = JoinTableSource{ModelType: source}
|
||||
s.Source.ForeignKeys = []JoinTableForeignKey{}
|
||||
for idx, dbName := range relationship.ForeignFieldNames {
|
||||
s.Source.ForeignKeys = append(s.Source.ForeignKeys, JoinTableForeignKey{
|
||||
DBName: relationship.ForeignDBNames[idx],
|
||||
AssociationDBName: dbName,
|
||||
})
|
||||
}
|
||||
|
||||
s.Destination = JoinTableSource{ModelType: destination}
|
||||
s.Destination.ForeignKeys = []JoinTableForeignKey{}
|
||||
for idx, dbName := range relationship.AssociationForeignFieldNames {
|
||||
s.Destination.ForeignKeys = append(s.Destination.ForeignKeys, JoinTableForeignKey{
|
||||
DBName: relationship.AssociationForeignDBNames[idx],
|
||||
AssociationDBName: dbName,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Table return join table's table name
|
||||
func (s JoinTableHandler) Table(db *DB) string {
|
||||
return s.TableName
|
||||
}
|
||||
|
||||
func (s JoinTableHandler) getSearchMap(db *DB, sources ...interface{}) map[string]interface{} {
|
||||
values := map[string]interface{}{}
|
||||
|
||||
for _, source := range sources {
|
||||
scope := db.NewScope(source)
|
||||
modelType := scope.GetModelStruct().ModelType
|
||||
|
||||
if s.Source.ModelType == modelType {
|
||||
for _, foreignKey := range s.Source.ForeignKeys {
|
||||
if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok {
|
||||
values[foreignKey.DBName] = field.Field.Interface()
|
||||
}
|
||||
}
|
||||
} else if s.Destination.ModelType == modelType {
|
||||
for _, foreignKey := range s.Destination.ForeignKeys {
|
||||
if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok {
|
||||
values[foreignKey.DBName] = field.Field.Interface()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return values
|
||||
}
|
||||
|
||||
// Add create relationship in join table for source and destination
|
||||
func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error {
|
||||
scope := db.NewScope("")
|
||||
searchMap := map[string]interface{}{}
|
||||
|
||||
// getSearchMap() cannot be used here since the source and destination
|
||||
// model types may be identical
|
||||
|
||||
sourceScope := db.NewScope(source)
|
||||
for _, foreignKey := range s.Source.ForeignKeys {
|
||||
if field, ok := sourceScope.FieldByName(foreignKey.AssociationDBName); ok {
|
||||
searchMap[foreignKey.DBName] = field.Field.Interface()
|
||||
}
|
||||
}
|
||||
|
||||
destinationScope := db.NewScope(destination)
|
||||
for _, foreignKey := range s.Destination.ForeignKeys {
|
||||
if field, ok := destinationScope.FieldByName(foreignKey.AssociationDBName); ok {
|
||||
searchMap[foreignKey.DBName] = field.Field.Interface()
|
||||
}
|
||||
}
|
||||
|
||||
var assignColumns, binVars, conditions []string
|
||||
var values []interface{}
|
||||
for key, value := range searchMap {
|
||||
assignColumns = append(assignColumns, scope.Quote(key))
|
||||
binVars = append(binVars, `?`)
|
||||
conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key)))
|
||||
values = append(values, value)
|
||||
}
|
||||
|
||||
for _, value := range values {
|
||||
values = append(values, value)
|
||||
}
|
||||
|
||||
quotedTable := scope.Quote(handler.Table(db))
|
||||
sql := fmt.Sprintf(
|
||||
"INSERT INTO %v (%v) SELECT %v %v WHERE NOT EXISTS (SELECT * FROM %v WHERE %v)",
|
||||
quotedTable,
|
||||
strings.Join(assignColumns, ","),
|
||||
strings.Join(binVars, ","),
|
||||
scope.Dialect().SelectFromDummyTable(),
|
||||
quotedTable,
|
||||
strings.Join(conditions, " AND "),
|
||||
)
|
||||
|
||||
return db.Exec(sql, values...).Error
|
||||
}
|
||||
|
||||
// Delete delete relationship in join table for sources
|
||||
func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error {
|
||||
var (
|
||||
scope = db.NewScope(nil)
|
||||
conditions []string
|
||||
values []interface{}
|
||||
)
|
||||
|
||||
for key, value := range s.getSearchMap(db, sources...) {
|
||||
conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key)))
|
||||
values = append(values, value)
|
||||
}
|
||||
|
||||
return db.Table(handler.Table(db)).Where(strings.Join(conditions, " AND "), values...).Delete("").Error
|
||||
}
|
||||
|
||||
// JoinWith query with `Join` conditions
|
||||
func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB {
|
||||
var (
|
||||
scope = db.NewScope(source)
|
||||
tableName = handler.Table(db)
|
||||
quotedTableName = scope.Quote(tableName)
|
||||
joinConditions []string
|
||||
values []interface{}
|
||||
)
|
||||
|
||||
if s.Source.ModelType == scope.GetModelStruct().ModelType {
|
||||
destinationTableName := db.NewScope(reflect.New(s.Destination.ModelType).Interface()).QuotedTableName()
|
||||
for _, foreignKey := range s.Destination.ForeignKeys {
|
||||
joinConditions = append(joinConditions, fmt.Sprintf("%v.%v = %v.%v", quotedTableName, scope.Quote(foreignKey.DBName), destinationTableName, scope.Quote(foreignKey.AssociationDBName)))
|
||||
}
|
||||
|
||||
var foreignDBNames []string
|
||||
var foreignFieldNames []string
|
||||
|
||||
for _, foreignKey := range s.Source.ForeignKeys {
|
||||
foreignDBNames = append(foreignDBNames, foreignKey.DBName)
|
||||
if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok {
|
||||
foreignFieldNames = append(foreignFieldNames, field.Name)
|
||||
}
|
||||
}
|
||||
|
||||
foreignFieldValues := scope.getColumnAsArray(foreignFieldNames, scope.Value)
|
||||
|
||||
var condString string
|
||||
if len(foreignFieldValues) > 0 {
|
||||
var quotedForeignDBNames []string
|
||||
for _, dbName := range foreignDBNames {
|
||||
quotedForeignDBNames = append(quotedForeignDBNames, tableName+"."+dbName)
|
||||
}
|
||||
|
||||
condString = fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, quotedForeignDBNames), toQueryMarks(foreignFieldValues))
|
||||
|
||||
keys := scope.getColumnAsArray(foreignFieldNames, scope.Value)
|
||||
values = append(values, toQueryValues(keys))
|
||||
} else {
|
||||
condString = fmt.Sprintf("1 <> 1")
|
||||
}
|
||||
|
||||
return db.Joins(fmt.Sprintf("INNER JOIN %v ON %v", quotedTableName, strings.Join(joinConditions, " AND "))).
|
||||
Where(condString, toQueryValues(foreignFieldValues)...)
|
||||
}
|
||||
|
||||
db.Error = errors.New("wrong source type for join table handler")
|
||||
return db
|
||||
}
|
||||
@ -1,117 +0,0 @@
|
||||
package gorm_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jinzhu/gorm"
|
||||
)
|
||||
|
||||
type Person struct {
|
||||
Id int
|
||||
Name string
|
||||
Addresses []*Address `gorm:"many2many:person_addresses;"`
|
||||
}
|
||||
|
||||
type PersonAddress struct {
|
||||
gorm.JoinTableHandler
|
||||
PersonID int
|
||||
AddressID int
|
||||
DeletedAt *time.Time
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
func (*PersonAddress) Add(handler gorm.JoinTableHandlerInterface, db *gorm.DB, foreignValue interface{}, associationValue interface{}) error {
|
||||
foreignPrimaryKey, _ := strconv.Atoi(fmt.Sprint(db.NewScope(foreignValue).PrimaryKeyValue()))
|
||||
associationPrimaryKey, _ := strconv.Atoi(fmt.Sprint(db.NewScope(associationValue).PrimaryKeyValue()))
|
||||
if result := db.Unscoped().Model(&PersonAddress{}).Where(map[string]interface{}{
|
||||
"person_id": foreignPrimaryKey,
|
||||
"address_id": associationPrimaryKey,
|
||||
}).Update(map[string]interface{}{
|
||||
"person_id": foreignPrimaryKey,
|
||||
"address_id": associationPrimaryKey,
|
||||
"deleted_at": gorm.Expr("NULL"),
|
||||
}).RowsAffected; result == 0 {
|
||||
return db.Create(&PersonAddress{
|
||||
PersonID: foreignPrimaryKey,
|
||||
AddressID: associationPrimaryKey,
|
||||
}).Error
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (*PersonAddress) Delete(handler gorm.JoinTableHandlerInterface, db *gorm.DB, sources ...interface{}) error {
|
||||
return db.Delete(&PersonAddress{}).Error
|
||||
}
|
||||
|
||||
func (pa *PersonAddress) JoinWith(handler gorm.JoinTableHandlerInterface, db *gorm.DB, source interface{}) *gorm.DB {
|
||||
table := pa.Table(db)
|
||||
return db.Joins("INNER JOIN person_addresses ON person_addresses.address_id = addresses.id").Where(fmt.Sprintf("%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02'", table, table))
|
||||
}
|
||||
|
||||
func TestJoinTable(t *testing.T) {
|
||||
DB.Exec("drop table person_addresses;")
|
||||
DB.AutoMigrate(&Person{})
|
||||
DB.SetJoinTableHandler(&Person{}, "Addresses", &PersonAddress{})
|
||||
|
||||
address1 := &Address{Address1: "address 1"}
|
||||
address2 := &Address{Address1: "address 2"}
|
||||
person := &Person{Name: "person", Addresses: []*Address{address1, address2}}
|
||||
DB.Save(person)
|
||||
|
||||
DB.Model(person).Association("Addresses").Delete(address1)
|
||||
|
||||
if DB.Find(&[]PersonAddress{}, "person_id = ?", person.Id).RowsAffected != 1 {
|
||||
t.Errorf("Should found one address")
|
||||
}
|
||||
|
||||
if DB.Model(person).Association("Addresses").Count() != 1 {
|
||||
t.Errorf("Should found one address")
|
||||
}
|
||||
|
||||
if DB.Unscoped().Find(&[]PersonAddress{}, "person_id = ?", person.Id).RowsAffected != 2 {
|
||||
t.Errorf("Found two addresses with Unscoped")
|
||||
}
|
||||
|
||||
if DB.Model(person).Association("Addresses").Clear(); DB.Model(person).Association("Addresses").Count() != 0 {
|
||||
t.Errorf("Should deleted all addresses")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEmbeddedMany2ManyRelationship(t *testing.T) {
|
||||
type EmbeddedPerson struct {
|
||||
ID int
|
||||
Name string
|
||||
Addresses []*Address `gorm:"many2many:person_addresses;"`
|
||||
}
|
||||
|
||||
type NewPerson struct {
|
||||
EmbeddedPerson
|
||||
ExternalID uint
|
||||
}
|
||||
DB.Exec("drop table person_addresses;")
|
||||
DB.AutoMigrate(&NewPerson{})
|
||||
|
||||
address1 := &Address{Address1: "address 1"}
|
||||
address2 := &Address{Address1: "address 2"}
|
||||
person := &NewPerson{ExternalID: 100, EmbeddedPerson: EmbeddedPerson{Name: "person", Addresses: []*Address{address1, address2}}}
|
||||
if err := DB.Save(person).Error; err != nil {
|
||||
t.Errorf("no error should return when save embedded many2many relationship, but got %v", err)
|
||||
}
|
||||
|
||||
if err := DB.Model(person).Association("Addresses").Delete(address1).Error; err != nil {
|
||||
t.Errorf("no error should return when delete embedded many2many relationship, but got %v", err)
|
||||
}
|
||||
|
||||
association := DB.Model(person).Association("Addresses")
|
||||
if count := association.Count(); count != 1 || association.Error != nil {
|
||||
t.Errorf("Should found one address, but got %v, error is %v", count, association.Error)
|
||||
}
|
||||
|
||||
if association.Clear(); association.Count() != 0 {
|
||||
t.Errorf("Should deleted all addresses")
|
||||
}
|
||||
}
|
||||
119
logger.go
119
logger.go
@ -1,119 +0,0 @@
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"time"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
var (
|
||||
defaultLogger = Logger{log.New(os.Stdout, "\r\n", 0)}
|
||||
sqlRegexp = regexp.MustCompile(`\?`)
|
||||
numericPlaceHolderRegexp = regexp.MustCompile(`\$\d+`)
|
||||
)
|
||||
|
||||
func isPrintable(s string) bool {
|
||||
for _, r := range s {
|
||||
if !unicode.IsPrint(r) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
var LogFormatter = func(values ...interface{}) (messages []interface{}) {
|
||||
if len(values) > 1 {
|
||||
var (
|
||||
sql string
|
||||
formattedValues []string
|
||||
level = values[0]
|
||||
currentTime = "\n\033[33m[" + NowFunc().Format("2006-01-02 15:04:05") + "]\033[0m"
|
||||
source = fmt.Sprintf("\033[35m(%v)\033[0m", values[1])
|
||||
)
|
||||
|
||||
messages = []interface{}{source, currentTime}
|
||||
|
||||
if level == "sql" {
|
||||
// duration
|
||||
messages = append(messages, fmt.Sprintf(" \033[36;1m[%.2fms]\033[0m ", float64(values[2].(time.Duration).Nanoseconds()/1e4)/100.0))
|
||||
// sql
|
||||
|
||||
for _, value := range values[4].([]interface{}) {
|
||||
indirectValue := reflect.Indirect(reflect.ValueOf(value))
|
||||
if indirectValue.IsValid() {
|
||||
value = indirectValue.Interface()
|
||||
if t, ok := value.(time.Time); ok {
|
||||
formattedValues = append(formattedValues, fmt.Sprintf("'%v'", t.Format("2006-01-02 15:04:05")))
|
||||
} else if b, ok := value.([]byte); ok {
|
||||
if str := string(b); isPrintable(str) {
|
||||
formattedValues = append(formattedValues, fmt.Sprintf("'%v'", str))
|
||||
} else {
|
||||
formattedValues = append(formattedValues, "'<binary>'")
|
||||
}
|
||||
} else if r, ok := value.(driver.Valuer); ok {
|
||||
if value, err := r.Value(); err == nil && value != nil {
|
||||
formattedValues = append(formattedValues, fmt.Sprintf("'%v'", value))
|
||||
} else {
|
||||
formattedValues = append(formattedValues, "NULL")
|
||||
}
|
||||
} else {
|
||||
formattedValues = append(formattedValues, fmt.Sprintf("'%v'", value))
|
||||
}
|
||||
} else {
|
||||
formattedValues = append(formattedValues, "NULL")
|
||||
}
|
||||
}
|
||||
|
||||
// differentiate between $n placeholders or else treat like ?
|
||||
if numericPlaceHolderRegexp.MatchString(values[3].(string)) {
|
||||
sql = values[3].(string)
|
||||
for index, value := range formattedValues {
|
||||
placeholder := fmt.Sprintf(`\$%d([^\d]|$)`, index+1)
|
||||
sql = regexp.MustCompile(placeholder).ReplaceAllString(sql, value+"$1")
|
||||
}
|
||||
} else {
|
||||
formattedValuesLength := len(formattedValues)
|
||||
for index, value := range sqlRegexp.Split(values[3].(string), -1) {
|
||||
sql += value
|
||||
if index < formattedValuesLength {
|
||||
sql += formattedValues[index]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
messages = append(messages, sql)
|
||||
messages = append(messages, fmt.Sprintf(" \n\033[36;31m[%v]\033[0m ", strconv.FormatInt(values[5].(int64), 10)+" rows affected or returned "))
|
||||
} else {
|
||||
messages = append(messages, "\033[31;1m")
|
||||
messages = append(messages, values[2:]...)
|
||||
messages = append(messages, "\033[0m")
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
type logger interface {
|
||||
Print(v ...interface{})
|
||||
}
|
||||
|
||||
// LogWriter log writer interface
|
||||
type LogWriter interface {
|
||||
Println(v ...interface{})
|
||||
}
|
||||
|
||||
// Logger default logger
|
||||
type Logger struct {
|
||||
LogWriter
|
||||
}
|
||||
|
||||
// Print format & print log
|
||||
func (logger Logger) Print(values ...interface{}) {
|
||||
logger.Println(LogFormatter(values...)...)
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user