Compare commits
1 Commits
master
...
allow_shar
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
118a3f469e |
20
.github/release-drafter.yml
vendored
20
.github/release-drafter.yml
vendored
@ -1,20 +0,0 @@
|
||||
name-template: 'v Release $NEXT_PATCH_VERSION 🌈'
|
||||
tag-template: 'v$NEXT_PATCH_VERSION'
|
||||
categories:
|
||||
- title: '🚀 Features'
|
||||
labels:
|
||||
- 'feature'
|
||||
- 'enhancement'
|
||||
- title: '🐛 Bug Fixes'
|
||||
labels:
|
||||
- 'fix'
|
||||
- 'bugfix'
|
||||
- 'bug'
|
||||
- title: '🧰 Maintenance'
|
||||
label: 'chore'
|
||||
change-template: '- $TITLE @$AUTHOR (#$NUMBER)'
|
||||
change-title-escapes: '\<*_&'
|
||||
template: |
|
||||
## Changes
|
||||
|
||||
$CHANGES
|
||||
31
.github/workflows/create-release.yml
vendored
31
.github/workflows/create-release.yml
vendored
@ -1,31 +0,0 @@
|
||||
name: Create Release
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- 'v*.*.*'
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: read
|
||||
|
||||
jobs:
|
||||
create_release:
|
||||
name: Create Release
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Generate Release Notes and Publish
|
||||
id: generate_release_notes
|
||||
uses: release-drafter/release-drafter@v6
|
||||
with:
|
||||
config-name: 'release-drafter.yml'
|
||||
name: "Release ${{ github.ref_name }}"
|
||||
tag: ${{ github.ref_name }}
|
||||
publish: true
|
||||
prerelease: false
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
26
.github/workflows/golangci-lint.yml
vendored
26
.github/workflows/golangci-lint.yml
vendored
@ -1,26 +0,0 @@
|
||||
name: golangci-lint
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
- master
|
||||
pull_request:
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: read
|
||||
|
||||
jobs:
|
||||
golangci:
|
||||
name: lint
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: stable
|
||||
- name: golangci-lint
|
||||
uses: golangci/golangci-lint-action@v7
|
||||
with:
|
||||
version: v2.0
|
||||
only-new-issues: true
|
||||
2
.github/workflows/invalid_question.yml
vendored
2
.github/workflows/invalid_question.yml
vendored
@ -16,7 +16,7 @@ jobs:
|
||||
ACTIONS_STEP_DEBUG: true
|
||||
steps:
|
||||
- name: Close Stale Issues
|
||||
uses: actions/stale@v8
|
||||
uses: actions/stale@v5
|
||||
with:
|
||||
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
stale-issue-message: "This issue has been marked as invalid question, please give more information by following the `Question` template, if you believe there is a bug of GORM, please create a pull request that could reproduce the issue on [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground), the issue will be closed in 30 days if no further activity occurs. most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨"
|
||||
|
||||
2
.github/workflows/labeler.yml
vendored
2
.github/workflows/labeler.yml
vendored
@ -11,7 +11,7 @@ jobs:
|
||||
name: Label issues and pull requests
|
||||
steps:
|
||||
- name: check out
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: labeler
|
||||
uses: jinzhu/super-labeler-action@develop
|
||||
|
||||
2
.github/workflows/missing_playground.yml
vendored
2
.github/workflows/missing_playground.yml
vendored
@ -16,7 +16,7 @@ jobs:
|
||||
ACTIONS_STEP_DEBUG: true
|
||||
steps:
|
||||
- name: Close Stale Issues
|
||||
uses: actions/stale@v8
|
||||
uses: actions/stale@v5
|
||||
with:
|
||||
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
stale-issue-message: "The issue has been automatically marked as stale as it missing playground pull request link, which is important to help others understand your issue effectively and make sure the issue hasn't been fixed on latest master, checkout [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground) for details. it will be closed in 30 days if no further activity occurs. if you are asking question, please use the `Question` template, most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨"
|
||||
|
||||
22
.github/workflows/reviewdog.yml
vendored
Normal file
22
.github/workflows/reviewdog.yml
vendored
Normal file
@ -0,0 +1,22 @@
|
||||
name: reviewdog
|
||||
on: [pull_request]
|
||||
jobs:
|
||||
golangci-lint:
|
||||
name: runner / golangci-lint
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Check out code into the Go module directory
|
||||
uses: actions/checkout@v3
|
||||
- name: golangci-lint
|
||||
uses: reviewdog/action-golangci-lint@v2
|
||||
|
||||
- name: Setup reviewdog
|
||||
uses: reviewdog/action-setup@v1
|
||||
|
||||
- name: gofumpt -s with reviewdog
|
||||
env:
|
||||
REVIEWDOG_GITHUB_API_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: |
|
||||
go install mvdan.cc/gofumpt@v0.2.0
|
||||
gofumpt -e -d . | \
|
||||
reviewdog -name="gofumpt" -f=diff -f.diff.strip=0 -reporter=github-pr-review
|
||||
2
.github/workflows/stale.yml
vendored
2
.github/workflows/stale.yml
vendored
@ -16,7 +16,7 @@ jobs:
|
||||
ACTIONS_STEP_DEBUG: true
|
||||
steps:
|
||||
- name: Close Stale Issues
|
||||
uses: actions/stale@v8
|
||||
uses: actions/stale@v5
|
||||
with:
|
||||
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
stale-issue-message: "This issue has been automatically marked as stale because it has been open 360 days with no activity. Remove stale label or comment or this will be closed in 180 days"
|
||||
|
||||
193
.github/workflows/tests.yml
vendored
193
.github/workflows/tests.yml
vendored
@ -16,21 +16,21 @@ jobs:
|
||||
sqlite:
|
||||
strategy:
|
||||
matrix:
|
||||
go: ['1.23', '1.24']
|
||||
go: ['1.19', '1.18', '1.17', '1.16']
|
||||
platform: [ubuntu-latest] # can not run in windows OS
|
||||
runs-on: ${{ matrix.platform }}
|
||||
|
||||
steps:
|
||||
- name: Set up Go 1.x
|
||||
uses: actions/setup-go@v4
|
||||
uses: actions/setup-go@v3
|
||||
with:
|
||||
go-version: ${{ matrix.go }}
|
||||
|
||||
- name: Check out code into the Go module directory
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: go mod package cache
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@v3
|
||||
with:
|
||||
path: ~/go/pkg/mod
|
||||
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
|
||||
@ -41,8 +41,8 @@ jobs:
|
||||
mysql:
|
||||
strategy:
|
||||
matrix:
|
||||
dbversion: ['mysql:9', 'mysql:8', 'mysql:5.7']
|
||||
go: ['1.23', '1.24']
|
||||
dbversion: ['mysql:latest', 'mysql:5.7', 'mariadb:latest']
|
||||
go: ['1.19', '1.18', '1.17', '1.16']
|
||||
platform: [ubuntu-latest]
|
||||
runs-on: ${{ matrix.platform }}
|
||||
|
||||
@ -65,15 +65,16 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Set up Go 1.x
|
||||
uses: actions/setup-go@v4
|
||||
uses: actions/setup-go@v3
|
||||
with:
|
||||
go-version: ${{ matrix.go }}
|
||||
|
||||
- name: Check out code into the Go module directory
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v3
|
||||
|
||||
|
||||
- name: go mod package cache
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@v3
|
||||
with:
|
||||
path: ~/go/pkg/mod
|
||||
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
|
||||
@ -81,54 +82,11 @@ jobs:
|
||||
- name: Tests
|
||||
run: GITHUB_ACTION=true GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True" ./tests/tests_all.sh
|
||||
|
||||
mariadb:
|
||||
strategy:
|
||||
matrix:
|
||||
dbversion: [ 'mariadb:latest' ]
|
||||
go: ['1.23', '1.24']
|
||||
platform: [ ubuntu-latest ]
|
||||
runs-on: ${{ matrix.platform }}
|
||||
|
||||
services:
|
||||
mysql:
|
||||
image: ${{ matrix.dbversion }}
|
||||
env:
|
||||
MYSQL_DATABASE: gorm
|
||||
MYSQL_USER: gorm
|
||||
MYSQL_PASSWORD: gorm
|
||||
MYSQL_RANDOM_ROOT_PASSWORD: "yes"
|
||||
ports:
|
||||
- 9910:3306
|
||||
options: >-
|
||||
--health-cmd "mariadb-admin ping -ugorm -pgorm"
|
||||
--health-interval 10s
|
||||
--health-start-period 10s
|
||||
--health-timeout 5s
|
||||
--health-retries 10
|
||||
|
||||
steps:
|
||||
- name: Set up Go 1.x
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: ${{ matrix.go }}
|
||||
|
||||
- name: Check out code into the Go module directory
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: go mod package cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/go/pkg/mod
|
||||
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
|
||||
|
||||
- name: Tests
|
||||
run: GITHUB_ACTION=true GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True" ./tests/tests_all.sh
|
||||
|
||||
postgres:
|
||||
strategy:
|
||||
matrix:
|
||||
dbversion: ['postgres:latest', 'postgres:15', 'postgres:14', 'postgres:13']
|
||||
go: ['1.23', '1.24']
|
||||
dbversion: ['postgres:latest', 'postgres:13', 'postgres:12', 'postgres:11', 'postgres:10']
|
||||
go: ['1.19', '1.18', '1.17', '1.16']
|
||||
platform: [ubuntu-latest] # can not run in macOS and Windows
|
||||
runs-on: ${{ matrix.platform }}
|
||||
|
||||
@ -151,15 +109,15 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Set up Go 1.x
|
||||
uses: actions/setup-go@v4
|
||||
uses: actions/setup-go@v3
|
||||
with:
|
||||
go-version: ${{ matrix.go }}
|
||||
|
||||
- name: Check out code into the Go module directory
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: go mod package cache
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@v3
|
||||
with:
|
||||
path: ~/go/pkg/mod
|
||||
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
|
||||
@ -170,21 +128,23 @@ jobs:
|
||||
sqlserver:
|
||||
strategy:
|
||||
matrix:
|
||||
go: ['1.23', '1.24']
|
||||
go: ['1.19', '1.18', '1.17', '1.16']
|
||||
platform: [ubuntu-latest] # can not run test in macOS and windows
|
||||
runs-on: ${{ matrix.platform }}
|
||||
|
||||
services:
|
||||
mssql:
|
||||
image: mcr.microsoft.com/mssql/server:2022-latest
|
||||
image: mcmoe/mssqldocker:latest
|
||||
env:
|
||||
TZ: Asia/Shanghai
|
||||
ACCEPT_EULA: Y
|
||||
MSSQL_SA_PASSWORD: LoremIpsum86
|
||||
SA_PASSWORD: LoremIpsum86
|
||||
MSSQL_DB: gorm
|
||||
MSSQL_USER: gorm
|
||||
MSSQL_PASSWORD: LoremIpsum86
|
||||
ports:
|
||||
- 9930:1433
|
||||
options: >-
|
||||
--health-cmd="/opt/mssql-tools18/bin/sqlcmd -S localhost -U sa -P ${MSSQL_SA_PASSWORD} -N -C -l 30 -Q \"SELECT 1\" || exit 1"
|
||||
--health-cmd="/opt/mssql-tools/bin/sqlcmd -S localhost -U sa -P LoremIpsum86 -l 30 -Q \"SELECT 1\" || exit 1"
|
||||
--health-start-period 10s
|
||||
--health-interval 10s
|
||||
--health-timeout 5s
|
||||
@ -192,119 +152,18 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Set up Go 1.x
|
||||
uses: actions/setup-go@v4
|
||||
uses: actions/setup-go@v3
|
||||
with:
|
||||
go-version: ${{ matrix.go }}
|
||||
|
||||
- name: Check out code into the Go module directory
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: go mod package cache
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@v3
|
||||
with:
|
||||
path: ~/go/pkg/mod
|
||||
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
|
||||
|
||||
- name: Tests
|
||||
run: GITHUB_ACTION=true GORM_DIALECT=sqlserver GORM_DSN="sqlserver://sa:LoremIpsum86@localhost:9930?database=master" ./tests/tests_all.sh
|
||||
|
||||
tidb:
|
||||
strategy:
|
||||
matrix:
|
||||
dbversion: [ 'v6.5.0' ]
|
||||
go: ['1.23', '1.24']
|
||||
platform: [ ubuntu-latest ]
|
||||
runs-on: ${{ matrix.platform }}
|
||||
|
||||
steps:
|
||||
- name: Setup TiDB
|
||||
uses: Icemap/tidb-action@main
|
||||
with:
|
||||
port: 9940
|
||||
version: ${{matrix.dbversion}}
|
||||
|
||||
- name: Set up Go 1.x
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: ${{ matrix.go }}
|
||||
|
||||
- name: Check out code into the Go module directory
|
||||
uses: actions/checkout@v4
|
||||
|
||||
|
||||
- name: go mod package cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/go/pkg/mod
|
||||
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
|
||||
|
||||
- name: Tests
|
||||
run: GITHUB_ACTION=true GORM_DIALECT=tidb GORM_DSN="root:@tcp(localhost:9940)/test?charset=utf8&parseTime=True&loc=Local" ./tests/tests_all.sh
|
||||
|
||||
gaussdb:
|
||||
strategy:
|
||||
matrix:
|
||||
dbversion: ['opengauss/opengauss:7.0.0-RC1.B023']
|
||||
go: ['1.23', '1.24']
|
||||
platform: [ubuntu-latest] # can not run in macOS and Windows
|
||||
runs-on: ${{ matrix.platform }}
|
||||
|
||||
services:
|
||||
gaussdb:
|
||||
image: ${{ matrix.dbversion }}
|
||||
env:
|
||||
# GaussDB has password limitations
|
||||
GS_PASSWORD: Gaussdb@123
|
||||
TZ: Asia/Shanghai
|
||||
ports:
|
||||
- 9950:5432
|
||||
|
||||
steps:
|
||||
- name: Set up Go 1.x
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: ${{ matrix.go }}
|
||||
|
||||
- name: Check out code into the Go module directory
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Waiting for GaussDB to be ready
|
||||
run: |
|
||||
container_name=$(docker ps --filter "ancestor=opengauss/opengauss:7.0.0-RC1.B023" --format "{{.Names}}")
|
||||
if [ -z "$container_name" ]; then
|
||||
echo "Error: failed to find a container created from the 'opengauss/opengauss:7.0.0-RC1.B023' image."
|
||||
exit 1
|
||||
fi
|
||||
max_retries=12
|
||||
retry_count=0
|
||||
if [ -t 0 ]; then
|
||||
TTY_FLAG="-t"
|
||||
else
|
||||
TTY_FLAG=""
|
||||
fi
|
||||
while [ $retry_count -lt $max_retries ]; do
|
||||
if docker exec -i "${container_name}" bash -c "su - omm -c 'gsql -U omm -c \"select 1;\"'"
|
||||
then
|
||||
echo "Creating database gorm..."
|
||||
sql_file='/tmp/create_database.sql'
|
||||
echo "CREATE DATABASE gorm DBCOMPATIBILITY 'PG';" > ${sql_file}
|
||||
docker cp "${sql_file}" "${container_name}":"${sql_file}"
|
||||
docker exec -i ${TTY_FLAG} "${container_name}" bash -c "su - omm -c 'gsql -U omm -f ${sql_file}'"
|
||||
echo "Database initialization completed."
|
||||
break
|
||||
fi
|
||||
|
||||
echo "Waiting for database to be ready... (attempt $((retry_count + 1))/$max_retries)"
|
||||
sleep 10
|
||||
((++retry_count))
|
||||
done
|
||||
exit 0
|
||||
|
||||
- name: go mod package cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/go/pkg/mod
|
||||
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
|
||||
|
||||
- name: Tests
|
||||
run: GITHUB_ACTION=true GORM_DIALECT=gaussdb GORM_DSN="user=gaussdb password=Gaussdb@123 dbname=gorm host=localhost port=9950 sslmode=disable TimeZone=Asia/Shanghai" ./tests/tests_all.sh
|
||||
run: GITHUB_ACTION=true GORM_DIALECT=sqlserver GORM_DSN="sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm" ./tests/tests_all.sh
|
||||
|
||||
@ -1,9 +1,7 @@
|
||||
version: "2"
|
||||
|
||||
linters:
|
||||
default: standard
|
||||
enable:
|
||||
- cyclop
|
||||
- exportloopref
|
||||
- gocritic
|
||||
- gosec
|
||||
- ineffassign
|
||||
@ -11,9 +9,3 @@ linters:
|
||||
- prealloc
|
||||
- unconvert
|
||||
- unparam
|
||||
- whitespace
|
||||
|
||||
formatters:
|
||||
enable:
|
||||
- gofumpt
|
||||
- goimports
|
||||
|
||||
@ -1,128 +0,0 @@
|
||||
# Contributor Covenant Code of Conduct
|
||||
|
||||
## Our Pledge
|
||||
|
||||
We as members, contributors, and leaders pledge to participate in our
|
||||
community a harassment-free experience for everyone, regardless of age, body
|
||||
size, visible or invisible disability, ethnicity, sex characteristics, gender
|
||||
identity and expression, level of experience, education, socio-economic status,
|
||||
nationality, personal appearance, race, religion, or sexual identity
|
||||
and orientation.
|
||||
|
||||
We pledge to act and interact in ways that contribute to an open, welcoming,
|
||||
diverse, inclusive, and healthy community.
|
||||
|
||||
## Our Standards
|
||||
|
||||
Examples of behavior that contributes to a positive environment for our
|
||||
community includes:
|
||||
|
||||
* Demonstrating empathy and kindness toward other people
|
||||
* Being respectful of differing opinions, viewpoints, and experiences
|
||||
* Giving and gracefully accepting constructive feedback
|
||||
* Accepting responsibility and apologizing to those affected by our mistakes,
|
||||
and learning from the experience
|
||||
* Focusing on what is best not just for us as individuals, but for the
|
||||
overall community
|
||||
|
||||
Examples of unacceptable behavior include:
|
||||
|
||||
* The use of sexualized language or imagery, and sexual attention or
|
||||
advances of any kind
|
||||
* Trolling, insulting or derogatory comments, and personal or political attacks
|
||||
* Public or private harassment
|
||||
* Publishing others' private information, such as a physical or email
|
||||
address, without their explicit permission
|
||||
* Other conduct which could reasonably be considered inappropriate in a
|
||||
professional setting
|
||||
|
||||
## Enforcement Responsibilities
|
||||
|
||||
Community leaders are responsible for clarifying and enforcing our standards of
|
||||
acceptable behavior and will take appropriate and fair corrective action in
|
||||
response to any behavior that they deem inappropriate, threatening, offensive,
|
||||
or harmful.
|
||||
|
||||
Community leaders have the right and responsibility to remove, edit, or reject
|
||||
comments, commits, code, wiki edits, issues, and other contributions that are
|
||||
not aligned to this Code of Conduct, and will communicate reasons for moderation
|
||||
decisions when appropriate.
|
||||
|
||||
## Scope
|
||||
|
||||
This Code of Conduct applies within all community spaces and also applies when
|
||||
an individual is officially representing the community in public spaces.
|
||||
Examples of representing our community include using an official e-mail address,
|
||||
posting via an official social media account, or acting as an appointed
|
||||
representative at an online or offline event.
|
||||
|
||||
## Enforcement
|
||||
|
||||
Instances of abusive, harassing, or otherwise unacceptable behavior may be
|
||||
reported to the community leaders responsible for enforcement at
|
||||
.
|
||||
All complaints will be reviewed and investigated promptly and fairly.
|
||||
|
||||
All community leaders are obligated to respect the privacy and security of the
|
||||
reporter of any incident.
|
||||
|
||||
## Enforcement Guidelines
|
||||
|
||||
Community leaders will follow these Community Impact Guidelines in determining
|
||||
the consequences for any action they deem in violation of this Code of Conduct:
|
||||
|
||||
### 1. Correction
|
||||
|
||||
**Community Impact**: Use of inappropriate language or other behavior deemed
|
||||
unprofessional or unwelcome in the community.
|
||||
|
||||
**Consequence**: A private, written warning from community leaders, providing
|
||||
clarity around the nature of the violation and an explanation of why the
|
||||
behavior was inappropriate. A public apology may be requested.
|
||||
|
||||
### 2. Warning
|
||||
|
||||
**Community Impact**: A violation through a single incident or series
|
||||
of actions.
|
||||
|
||||
**Consequence**: A warning with consequences for continued behavior. No
|
||||
interaction with the people involved, including unsolicited interaction with
|
||||
those enforcing the Code of Conduct, for a specified period. This
|
||||
includes avoiding interactions in community spaces and external channels
|
||||
like social media. Violating these terms may lead to a temporary or
|
||||
permanent ban.
|
||||
|
||||
### 3. Temporary Ban
|
||||
|
||||
**Community Impact**: A serious violation of community standards, including
|
||||
sustained inappropriate behavior.
|
||||
|
||||
**Consequence**: A temporary ban from any interaction or public
|
||||
communication with the community for a specified period. No public or
|
||||
private interaction with the people involved, including unsolicited interaction
|
||||
with those enforcing the Code of Conduct, is allowed during this period.
|
||||
Violating these terms may lead to a permanent ban.
|
||||
|
||||
### 4. Permanent Ban
|
||||
|
||||
**Community Impact**: Demonstrating a pattern of violation of community
|
||||
standards, including sustained inappropriate behavior, harassment of an
|
||||
individual, or aggression toward or disparagement of classes of individuals.
|
||||
|
||||
**Consequence**: A permanent ban from any sort of public interaction within
|
||||
the community.
|
||||
|
||||
## Attribution
|
||||
|
||||
This Code of Conduct is adapted from the [Contributor Covenant][homepage],
|
||||
version 2.0, available at
|
||||
https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
|
||||
|
||||
Community Impact Guidelines were inspired by [Mozilla's code of conduct
|
||||
enforcement ladder](https://github.com/mozilla/diversity).
|
||||
|
||||
[homepage]: https://www.contributor-covenant.org
|
||||
|
||||
For answers to common questions about this code of conduct, see the FAQ at
|
||||
https://www.contributor-covenant.org/faq. Translations are available at
|
||||
https://www.contributor-covenant.org/translations.
|
||||
@ -1,6 +1,6 @@
|
||||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2013-present Jinzhu <wosmvp@gmail.com>
|
||||
Copyright (c) 2013-NOW Jinzhu <wosmvp@gmail.com>
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
11
README.md
11
README.md
@ -4,6 +4,9 @@ The fantastic ORM library for Golang, aims to be developer friendly.
|
||||
|
||||
[](https://goreportcard.com/report/github.com/go-gorm/gorm)
|
||||
[](https://github.com/go-gorm/gorm/actions)
|
||||
[](https://gitter.im/jinzhu/gorm?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)
|
||||
[](https://opencollective.com/gorm)
|
||||
[](https://opencollective.com/gorm)
|
||||
[](https://opensource.org/licenses/MIT)
|
||||
[](https://pkg.go.dev/gorm.io/gorm?tab=doc)
|
||||
|
||||
@ -27,18 +30,14 @@ The fantastic ORM library for Golang, aims to be developer friendly.
|
||||
## Getting Started
|
||||
|
||||
* GORM Guides [https://gorm.io](https://gorm.io)
|
||||
* Gen Guides [https://gorm.io/gen/index.html](https://gorm.io/gen/index.html)
|
||||
* GORM Gen [gorm/gen](https://github.com/go-gorm/gen#gormgen)
|
||||
|
||||
## Contributing
|
||||
|
||||
[You can help to deliver a better GORM, check out things you can do](https://gorm.io/contribute.html)
|
||||
|
||||
## Contributors
|
||||
|
||||
[Thank you](https://github.com/go-gorm/gorm/graphs/contributors) for contributing to the GORM framework!
|
||||
|
||||
## License
|
||||
|
||||
© Jinzhu, 2013~time.Now
|
||||
|
||||
Released under the [MIT License](https://github.com/go-gorm/gorm/blob/master/LICENSE)
|
||||
Released under the [MIT License](https://github.com/go-gorm/gorm/blob/master/License)
|
||||
|
||||
@ -14,7 +14,6 @@ import (
|
||||
type Association struct {
|
||||
DB *DB
|
||||
Relationship *schema.Relationship
|
||||
Unscope bool
|
||||
Error error
|
||||
}
|
||||
|
||||
@ -41,15 +40,6 @@ func (db *DB) Association(column string) *Association {
|
||||
return association
|
||||
}
|
||||
|
||||
func (association *Association) Unscoped() *Association {
|
||||
return &Association{
|
||||
DB: association.DB,
|
||||
Relationship: association.Relationship,
|
||||
Error: association.Error,
|
||||
Unscope: true,
|
||||
}
|
||||
}
|
||||
|
||||
func (association *Association) Find(out interface{}, conds ...interface{}) error {
|
||||
if association.Error == nil {
|
||||
association.Error = association.buildCondition().Find(out, conds...).Error
|
||||
@ -74,30 +64,14 @@ func (association *Association) Append(values ...interface{}) error {
|
||||
|
||||
func (association *Association) Replace(values ...interface{}) error {
|
||||
if association.Error == nil {
|
||||
reflectValue := association.DB.Statement.ReflectValue
|
||||
rel := association.Relationship
|
||||
|
||||
var oldBelongsToExpr clause.Expression
|
||||
// we have to record the old BelongsTo value
|
||||
if association.Unscope && rel.Type == schema.BelongsTo {
|
||||
var foreignFields []*schema.Field
|
||||
for _, ref := range rel.References {
|
||||
if !ref.OwnPrimaryKey {
|
||||
foreignFields = append(foreignFields, ref.ForeignKey)
|
||||
}
|
||||
}
|
||||
if _, fvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, foreignFields); len(fvs) > 0 {
|
||||
column, values := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, fvs)
|
||||
oldBelongsToExpr = clause.IN{Column: column, Values: values}
|
||||
}
|
||||
}
|
||||
|
||||
// save associations
|
||||
if association.saveAssociation( /*clear*/ true, values...); association.Error != nil {
|
||||
return association.Error
|
||||
}
|
||||
|
||||
// set old associations's foreign key to null
|
||||
reflectValue := association.DB.Statement.ReflectValue
|
||||
rel := association.Relationship
|
||||
switch rel.Type {
|
||||
case schema.BelongsTo:
|
||||
if len(values) == 0 {
|
||||
@ -117,9 +91,6 @@ func (association *Association) Replace(values ...interface{}) error {
|
||||
|
||||
association.Error = association.DB.UpdateColumns(updateMap).Error
|
||||
}
|
||||
if association.Unscope && oldBelongsToExpr != nil {
|
||||
association.Error = association.DB.Model(nil).Where(oldBelongsToExpr).Delete(reflect.New(rel.FieldSchema.ModelType).Interface()).Error
|
||||
}
|
||||
case schema.HasOne, schema.HasMany:
|
||||
var (
|
||||
primaryFields []*schema.Field
|
||||
@ -148,11 +119,7 @@ func (association *Association) Replace(values ...interface{}) error {
|
||||
|
||||
if _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields); len(pvs) > 0 {
|
||||
column, values := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs)
|
||||
if association.Unscope {
|
||||
association.Error = tx.Where(clause.IN{Column: column, Values: values}).Delete(modelValue).Error
|
||||
} else {
|
||||
association.Error = tx.Where(clause.IN{Column: column, Values: values}).UpdateColumns(updateMap).Error
|
||||
}
|
||||
association.Error = tx.Where(clause.IN{Column: column, Values: values}).UpdateColumns(updateMap).Error
|
||||
}
|
||||
case schema.Many2Many:
|
||||
var (
|
||||
@ -217,8 +184,7 @@ func (association *Association) Delete(values ...interface{}) error {
|
||||
|
||||
switch rel.Type {
|
||||
case schema.BelongsTo:
|
||||
associationDB := association.DB.Session(&Session{})
|
||||
tx := associationDB.Model(reflect.New(rel.Schema.ModelType).Interface())
|
||||
tx := association.DB.Model(reflect.New(rel.Schema.ModelType).Interface())
|
||||
|
||||
_, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, rel.Schema.PrimaryFields)
|
||||
if pcolumn, pvalues := schema.ToQueryValues(rel.Schema.Table, rel.Schema.PrimaryFieldDBNames, pvs); len(pvalues) > 0 {
|
||||
@ -232,21 +198,8 @@ func (association *Association) Delete(values ...interface{}) error {
|
||||
conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
|
||||
|
||||
association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error
|
||||
if association.Unscope {
|
||||
var foreignFields []*schema.Field
|
||||
for _, ref := range rel.References {
|
||||
if !ref.OwnPrimaryKey {
|
||||
foreignFields = append(foreignFields, ref.ForeignKey)
|
||||
}
|
||||
}
|
||||
if _, fvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, foreignFields); len(fvs) > 0 {
|
||||
column, values := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, fvs)
|
||||
association.Error = associationDB.Model(nil).Where(clause.IN{Column: column, Values: values}).Delete(reflect.New(rel.FieldSchema.ModelType).Interface()).Error
|
||||
}
|
||||
}
|
||||
case schema.HasOne, schema.HasMany:
|
||||
model := reflect.New(rel.FieldSchema.ModelType).Interface()
|
||||
tx := association.DB.Model(model)
|
||||
tx := association.DB.Model(reflect.New(rel.FieldSchema.ModelType).Interface())
|
||||
|
||||
_, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields)
|
||||
if pcolumn, pvalues := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs); len(pvalues) > 0 {
|
||||
@ -259,11 +212,7 @@ func (association *Association) Delete(values ...interface{}) error {
|
||||
relColumn, relValues := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs)
|
||||
conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
|
||||
|
||||
if association.Unscope {
|
||||
association.Error = tx.Clauses(conds...).Delete(model).Error
|
||||
} else {
|
||||
association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error
|
||||
}
|
||||
association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error
|
||||
case schema.Many2Many:
|
||||
var (
|
||||
primaryFields, relPrimaryFields []*schema.Field
|
||||
@ -396,10 +345,6 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
|
||||
}
|
||||
}
|
||||
case reflect.Struct:
|
||||
if !rv.CanAddr() {
|
||||
association.Error = ErrInvalidValue
|
||||
return
|
||||
}
|
||||
association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, source, rv.Addr().Interface())
|
||||
|
||||
if association.Relationship.Field.FieldType.Kind() == reflect.Struct {
|
||||
@ -408,13 +353,9 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
|
||||
}
|
||||
case schema.HasMany, schema.Many2Many:
|
||||
elemType := association.Relationship.Field.IndirectFieldType.Elem()
|
||||
oldFieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(association.DB.Statement.Context, source))
|
||||
var fieldValue reflect.Value
|
||||
fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(association.DB.Statement.Context, source))
|
||||
if clear {
|
||||
fieldValue = reflect.MakeSlice(oldFieldValue.Type(), 0, oldFieldValue.Cap())
|
||||
} else {
|
||||
fieldValue = reflect.MakeSlice(oldFieldValue.Type(), oldFieldValue.Len(), oldFieldValue.Cap())
|
||||
reflect.Copy(fieldValue, oldFieldValue)
|
||||
fieldValue = reflect.New(association.Relationship.Field.IndirectFieldType).Elem()
|
||||
}
|
||||
|
||||
appendToFieldValues := func(ev reflect.Value) {
|
||||
@ -437,10 +378,6 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
|
||||
appendToFieldValues(reflect.Indirect(rv.Index(i)).Addr())
|
||||
}
|
||||
case reflect.Struct:
|
||||
if !rv.CanAddr() {
|
||||
association.Error = ErrInvalidValue
|
||||
return
|
||||
}
|
||||
appendToFieldValues(rv.Addr())
|
||||
}
|
||||
|
||||
@ -518,9 +455,6 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
|
||||
|
||||
for i := 0; i < reflectValue.Len(); i++ {
|
||||
appendToRelations(reflectValue.Index(i), reflect.Indirect(reflect.ValueOf(values[i])), clear)
|
||||
if association.Error != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// TODO support save slice data, sql with case?
|
||||
association.Error = associationDB.Updates(reflectValue.Index(i).Addr().Interface()).Error
|
||||
@ -542,9 +476,6 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
|
||||
for idx, value := range values {
|
||||
rv := reflect.Indirect(reflect.ValueOf(value))
|
||||
appendToRelations(reflectValue, rv, clear && idx == 0)
|
||||
if association.Error != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if len(values) > 0 {
|
||||
|
||||
37
callbacks.go
37
callbacks.go
@ -75,7 +75,11 @@ func (cs *callbacks) Raw() *processor {
|
||||
func (p *processor) Execute(db *DB) *DB {
|
||||
// call scopes
|
||||
for len(db.Statement.scopes) > 0 {
|
||||
db = db.executeScopes()
|
||||
scopes := db.Statement.scopes
|
||||
db.Statement.scopes = nil
|
||||
for _, scope := range scopes {
|
||||
db = scope(db)
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
@ -89,10 +93,6 @@ func (p *processor) Execute(db *DB) *DB {
|
||||
resetBuildClauses = true
|
||||
}
|
||||
|
||||
if optimizer, ok := db.Statement.Dest.(StatementModifier); ok {
|
||||
optimizer.ModifyStatement(stmt)
|
||||
}
|
||||
|
||||
// assign model values
|
||||
if stmt.Model == nil {
|
||||
stmt.Model = stmt.Dest
|
||||
@ -132,11 +132,7 @@ func (p *processor) Execute(db *DB) *DB {
|
||||
|
||||
if stmt.SQL.Len() > 0 {
|
||||
db.Logger.Trace(stmt.Context, curTime, func() (string, int64) {
|
||||
sql, vars := stmt.SQL.String(), stmt.Vars
|
||||
if filter, ok := db.Logger.(ParamsFilter); ok {
|
||||
sql, vars = filter.ParamsFilter(stmt.Context, stmt.SQL.String(), stmt.Vars...)
|
||||
}
|
||||
return db.Dialector.Explain(sql, vars...), db.RowsAffected
|
||||
return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...), db.RowsAffected
|
||||
}, db.Error)
|
||||
}
|
||||
|
||||
@ -187,18 +183,10 @@ func (p *processor) Replace(name string, fn func(*DB)) error {
|
||||
|
||||
func (p *processor) compile() (err error) {
|
||||
var callbacks []*callback
|
||||
removedMap := map[string]bool{}
|
||||
for _, callback := range p.callbacks {
|
||||
if callback.match == nil || callback.match(p.db) {
|
||||
callbacks = append(callbacks, callback)
|
||||
}
|
||||
if callback.remove {
|
||||
removedMap[callback.name] = true
|
||||
}
|
||||
}
|
||||
|
||||
if len(removedMap) > 0 {
|
||||
callbacks = removeCallbacks(callbacks, removedMap)
|
||||
}
|
||||
p.callbacks = callbacks
|
||||
|
||||
@ -257,7 +245,7 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {
|
||||
names, sorted []string
|
||||
sortCallback func(*callback) error
|
||||
)
|
||||
sort.SliceStable(cs, func(i, j int) bool {
|
||||
sort.Slice(cs, func(i, j int) bool {
|
||||
if cs[j].before == "*" && cs[i].before != "*" {
|
||||
return true
|
||||
}
|
||||
@ -347,14 +335,3 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func removeCallbacks(cs []*callback, nameMap map[string]bool) []*callback {
|
||||
callbacks := make([]*callback, 0, len(cs))
|
||||
for _, callback := range cs {
|
||||
if nameMap[callback.name] {
|
||||
continue
|
||||
}
|
||||
callbacks = append(callbacks, callback)
|
||||
}
|
||||
return callbacks
|
||||
}
|
||||
|
||||
@ -47,44 +47,29 @@ func SaveBeforeAssociations(create bool) func(db *gorm.DB) {
|
||||
)
|
||||
|
||||
if !isPtr {
|
||||
fieldType = reflect.PointerTo(fieldType)
|
||||
fieldType = reflect.PtrTo(fieldType)
|
||||
}
|
||||
|
||||
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
|
||||
distinctElems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
|
||||
identityMap := map[string]bool{}
|
||||
for i := 0; i < rValLen; i++ {
|
||||
obj := db.Statement.ReflectValue.Index(i)
|
||||
if reflect.Indirect(obj).Kind() != reflect.Struct {
|
||||
break
|
||||
}
|
||||
|
||||
if _, zero := rel.Field.ValueOf(db.Statement.Context, obj); !zero { // check belongs to relation value
|
||||
rv := rel.Field.ReflectValueOf(db.Statement.Context, obj) // relation reflect value
|
||||
if !isPtr {
|
||||
rv = rv.Addr()
|
||||
}
|
||||
objs = append(objs, obj)
|
||||
elems = reflect.Append(elems, rv)
|
||||
|
||||
relPrimaryValues := make([]interface{}, 0, len(rel.FieldSchema.PrimaryFields))
|
||||
for _, pf := range rel.FieldSchema.PrimaryFields {
|
||||
if pfv, ok := pf.ValueOf(db.Statement.Context, rv); !ok {
|
||||
relPrimaryValues = append(relPrimaryValues, pfv)
|
||||
}
|
||||
}
|
||||
cacheKey := utils.ToStringKey(relPrimaryValues...)
|
||||
if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] {
|
||||
if cacheKey != "" { // has primary fields
|
||||
identityMap[cacheKey] = true
|
||||
}
|
||||
|
||||
distinctElems = reflect.Append(distinctElems, rv)
|
||||
if isPtr {
|
||||
elems = reflect.Append(elems, rv)
|
||||
} else {
|
||||
elems = reflect.Append(elems, rv.Addr())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if elems.Len() > 0 {
|
||||
if saveAssociations(db, rel, distinctElems, selectColumns, restricted, nil) == nil {
|
||||
if saveAssociations(db, rel, elems, selectColumns, restricted, nil) == nil {
|
||||
for i := 0; i < elems.Len(); i++ {
|
||||
setupReferences(objs[i], elems.Index(i))
|
||||
}
|
||||
@ -126,7 +111,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
|
||||
)
|
||||
|
||||
if !isPtr {
|
||||
fieldType = reflect.PointerTo(fieldType)
|
||||
fieldType = reflect.PtrTo(fieldType)
|
||||
}
|
||||
|
||||
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
|
||||
@ -195,7 +180,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
|
||||
fieldType := rel.Field.IndirectFieldType.Elem()
|
||||
isPtr := fieldType.Kind() == reflect.Ptr
|
||||
if !isPtr {
|
||||
fieldType = reflect.PointerTo(fieldType)
|
||||
fieldType = reflect.PtrTo(fieldType)
|
||||
}
|
||||
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
|
||||
identityMap := map[string]bool{}
|
||||
@ -223,10 +208,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
|
||||
|
||||
cacheKey := utils.ToStringKey(relPrimaryValues...)
|
||||
if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] {
|
||||
if cacheKey != "" { // has primary fields
|
||||
identityMap[cacheKey] = true
|
||||
}
|
||||
|
||||
identityMap[cacheKey] = true
|
||||
if isPtr {
|
||||
elems = reflect.Append(elems, elem)
|
||||
} else {
|
||||
@ -268,11 +250,11 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
|
||||
fieldType := rel.Field.IndirectFieldType.Elem()
|
||||
isPtr := fieldType.Kind() == reflect.Ptr
|
||||
if !isPtr {
|
||||
fieldType = reflect.PointerTo(fieldType)
|
||||
fieldType = reflect.PtrTo(fieldType)
|
||||
}
|
||||
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
|
||||
distinctElems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
|
||||
joins := reflect.MakeSlice(reflect.SliceOf(reflect.PointerTo(rel.JoinTable.ModelType)), 0, 10)
|
||||
joins := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(rel.JoinTable.ModelType)), 0, 10)
|
||||
objs := []reflect.Value{}
|
||||
|
||||
appendToJoins := func(obj reflect.Value, elem reflect.Value) {
|
||||
@ -312,10 +294,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
|
||||
|
||||
cacheKey := utils.ToStringKey(relPrimaryValues...)
|
||||
if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] {
|
||||
if cacheKey != "" { // has primary fields
|
||||
identityMap[cacheKey] = true
|
||||
}
|
||||
|
||||
identityMap[cacheKey] = true
|
||||
distinctElems = reflect.Append(distinctElems, elem)
|
||||
}
|
||||
|
||||
|
||||
@ -13,20 +13,11 @@ func callMethod(db *gorm.DB, fc func(value interface{}, tx *gorm.DB) bool) {
|
||||
case reflect.Slice, reflect.Array:
|
||||
db.Statement.CurDestIndex = 0
|
||||
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
|
||||
if value := reflect.Indirect(db.Statement.ReflectValue.Index(i)); value.CanAddr() {
|
||||
fc(value.Addr().Interface(), tx)
|
||||
} else {
|
||||
db.AddError(gorm.ErrInvalidValue)
|
||||
return
|
||||
}
|
||||
fc(reflect.Indirect(db.Statement.ReflectValue.Index(i)).Addr().Interface(), tx)
|
||||
db.Statement.CurDestIndex++
|
||||
}
|
||||
case reflect.Struct:
|
||||
if db.Statement.ReflectValue.CanAddr() {
|
||||
fc(db.Statement.ReflectValue.Addr().Interface(), tx)
|
||||
} else {
|
||||
db.AddError(gorm.ErrInvalidValue)
|
||||
}
|
||||
fc(db.Statement.ReflectValue.Addr().Interface(), tx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -3,7 +3,6 @@ package callbacks
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
@ -53,13 +52,9 @@ func Create(config *Config) func(db *gorm.DB) {
|
||||
if _, ok := db.Statement.Clauses["RETURNING"]; !ok {
|
||||
fromColumns := make([]clause.Column, 0, len(db.Statement.Schema.FieldsWithDefaultDBValue))
|
||||
for _, field := range db.Statement.Schema.FieldsWithDefaultDBValue {
|
||||
if field.Readable {
|
||||
fromColumns = append(fromColumns, clause.Column{Name: field.DBName})
|
||||
}
|
||||
}
|
||||
if len(fromColumns) > 0 {
|
||||
db.Statement.AddClause(clause.Returning{Columns: fromColumns})
|
||||
fromColumns = append(fromColumns, clause.Column{Name: field.DBName})
|
||||
}
|
||||
db.Statement.AddClause(clause.Returning{Columns: fromColumns})
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -93,10 +88,6 @@ func Create(config *Config) func(db *gorm.DB) {
|
||||
db.AddError(rows.Close())
|
||||
}()
|
||||
gorm.Scan(rows, db, mode)
|
||||
|
||||
if db.Statement.Result != nil {
|
||||
db.Statement.Result.RowsAffected = db.RowsAffected
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
@ -111,70 +102,13 @@ func Create(config *Config) func(db *gorm.DB) {
|
||||
}
|
||||
|
||||
db.RowsAffected, _ = result.RowsAffected()
|
||||
|
||||
if db.Statement.Result != nil {
|
||||
db.Statement.Result.Result = result
|
||||
db.Statement.Result.RowsAffected = db.RowsAffected
|
||||
}
|
||||
|
||||
if db.RowsAffected == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
var (
|
||||
pkField *schema.Field
|
||||
pkFieldName = "@id"
|
||||
)
|
||||
|
||||
if db.Statement.Schema != nil {
|
||||
if db.Statement.Schema.PrioritizedPrimaryField == nil ||
|
||||
!db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue ||
|
||||
!db.Statement.Schema.PrioritizedPrimaryField.Readable {
|
||||
return
|
||||
}
|
||||
pkField = db.Statement.Schema.PrioritizedPrimaryField
|
||||
pkFieldName = db.Statement.Schema.PrioritizedPrimaryField.DBName
|
||||
}
|
||||
|
||||
insertID, err := result.LastInsertId()
|
||||
insertOk := err == nil && insertID > 0
|
||||
|
||||
if !insertOk {
|
||||
if !supportReturning {
|
||||
if db.RowsAffected != 0 && db.Statement.Schema != nil &&
|
||||
db.Statement.Schema.PrioritizedPrimaryField != nil &&
|
||||
db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue {
|
||||
insertID, err := result.LastInsertId()
|
||||
insertOk := err == nil && insertID > 0
|
||||
if !insertOk {
|
||||
db.AddError(err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// append @id column with value for auto-increment primary key
|
||||
// the @id value is correct, when: 1. without setting auto-increment primary key, 2. database AutoIncrementIncrement = 1
|
||||
switch values := db.Statement.Dest.(type) {
|
||||
case map[string]interface{}:
|
||||
values[pkFieldName] = insertID
|
||||
case *map[string]interface{}:
|
||||
(*values)[pkFieldName] = insertID
|
||||
case []map[string]interface{}, *[]map[string]interface{}:
|
||||
mapValues, ok := values.([]map[string]interface{})
|
||||
if !ok {
|
||||
if v, ok := values.(*[]map[string]interface{}); ok {
|
||||
if *v != nil {
|
||||
mapValues = *v
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if config.LastInsertIDReversed {
|
||||
insertID -= int64(len(mapValues)-1) * schema.DefaultAutoIncrementIncrement
|
||||
}
|
||||
|
||||
for _, mapValue := range mapValues {
|
||||
if mapValue != nil {
|
||||
mapValue[pkFieldName] = insertID
|
||||
}
|
||||
insertID += schema.DefaultAutoIncrementIncrement
|
||||
}
|
||||
default:
|
||||
if pkField == nil {
|
||||
return
|
||||
}
|
||||
|
||||
@ -187,10 +121,10 @@ func Create(config *Config) func(db *gorm.DB) {
|
||||
break
|
||||
}
|
||||
|
||||
_, isZero := pkField.ValueOf(db.Statement.Context, rv)
|
||||
_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv)
|
||||
if isZero {
|
||||
db.AddError(pkField.Set(db.Statement.Context, rv, insertID))
|
||||
insertID -= pkField.AutoIncrementIncrement
|
||||
db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, insertID))
|
||||
insertID -= db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement
|
||||
}
|
||||
}
|
||||
} else {
|
||||
@ -200,16 +134,16 @@ func Create(config *Config) func(db *gorm.DB) {
|
||||
break
|
||||
}
|
||||
|
||||
if _, isZero := pkField.ValueOf(db.Statement.Context, rv); isZero {
|
||||
db.AddError(pkField.Set(db.Statement.Context, rv, insertID))
|
||||
insertID += pkField.AutoIncrementIncrement
|
||||
if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv); isZero {
|
||||
db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, insertID))
|
||||
insertID += db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement
|
||||
}
|
||||
}
|
||||
}
|
||||
case reflect.Struct:
|
||||
_, isZero := pkField.ValueOf(db.Statement.Context, db.Statement.ReflectValue)
|
||||
_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, db.Statement.ReflectValue)
|
||||
if isZero {
|
||||
db.AddError(pkField.Set(db.Statement.Context, db.Statement.ReflectValue, insertID))
|
||||
db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, db.Statement.ReflectValue, insertID))
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -318,15 +252,13 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
|
||||
}
|
||||
}
|
||||
|
||||
for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
|
||||
if vs, ok := defaultValueFieldsHavingValue[field]; ok {
|
||||
values.Columns = append(values.Columns, clause.Column{Name: field.DBName})
|
||||
for idx := range values.Values {
|
||||
if vs[idx] == nil {
|
||||
values.Values[idx] = append(values.Values[idx], stmt.DefaultValueOf(field))
|
||||
} else {
|
||||
values.Values[idx] = append(values.Values[idx], vs[idx])
|
||||
}
|
||||
for field, vs := range defaultValueFieldsHavingValue {
|
||||
values.Columns = append(values.Columns, clause.Column{Name: field.DBName})
|
||||
for idx := range values.Values {
|
||||
if vs[idx] == nil {
|
||||
values.Values[idx] = append(values.Values[idx], stmt.Dialector.DefaultValueOf(field))
|
||||
} else {
|
||||
values.Values[idx] = append(values.Values[idx], vs[idx])
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -349,7 +281,7 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
|
||||
}
|
||||
|
||||
for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
|
||||
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) && field.DefaultValueInterface == nil {
|
||||
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
|
||||
if rvOfvalue, isZero := field.ValueOf(stmt.Context, stmt.ReflectValue); !isZero {
|
||||
values.Columns = append(values.Columns, clause.Column{Name: field.DBName})
|
||||
values.Values[0] = append(values.Values[0], rvOfvalue)
|
||||
@ -370,15 +302,14 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
|
||||
for _, column := range values.Columns {
|
||||
if field := stmt.Schema.LookUpField(column.Name); field != nil {
|
||||
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
|
||||
if !field.PrimaryKey && (!field.HasDefaultValue || field.DefaultValueInterface != nil ||
|
||||
strings.EqualFold(field.DefaultValue, "NULL")) && field.AutoCreateTime == 0 {
|
||||
if !field.PrimaryKey && (!field.HasDefaultValue || field.DefaultValueInterface != nil) && field.AutoCreateTime == 0 {
|
||||
if field.AutoUpdateTime > 0 {
|
||||
assignment := clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: curTime}
|
||||
switch field.AutoUpdateTime {
|
||||
case schema.UnixNanosecond:
|
||||
assignment.Value = curTime.UnixNano()
|
||||
case schema.UnixMillisecond:
|
||||
assignment.Value = curTime.UnixMilli()
|
||||
assignment.Value = curTime.UnixNano() / 1e6
|
||||
case schema.UnixSecond:
|
||||
assignment.Value = curTime.Unix()
|
||||
}
|
||||
|
||||
@ -1,71 +0,0 @@
|
||||
package callbacks
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/schema"
|
||||
)
|
||||
|
||||
var schemaCache = &sync.Map{}
|
||||
|
||||
func TestConvertToCreateValues_DestType_Slice(t *testing.T) {
|
||||
type user struct {
|
||||
ID int `gorm:"primaryKey"`
|
||||
Name string
|
||||
Email string `gorm:"default:(-)"`
|
||||
Age int `gorm:"default:(-)"`
|
||||
}
|
||||
|
||||
s, err := schema.Parse(&user{}, schemaCache, schema.NamingStrategy{})
|
||||
if err != nil {
|
||||
t.Errorf("parse schema error: %v, is not expected", err)
|
||||
return
|
||||
}
|
||||
dest := []*user{
|
||||
{
|
||||
ID: 1,
|
||||
Name: "alice",
|
||||
Email: "email",
|
||||
Age: 18,
|
||||
},
|
||||
{
|
||||
ID: 2,
|
||||
Name: "bob",
|
||||
Email: "email",
|
||||
Age: 19,
|
||||
},
|
||||
}
|
||||
stmt := &gorm.Statement{
|
||||
DB: &gorm.DB{
|
||||
Config: &gorm.Config{
|
||||
NowFunc: func() time.Time { return time.Time{} },
|
||||
},
|
||||
Statement: &gorm.Statement{
|
||||
Settings: sync.Map{},
|
||||
Schema: s,
|
||||
},
|
||||
},
|
||||
ReflectValue: reflect.ValueOf(dest),
|
||||
Dest: dest,
|
||||
}
|
||||
|
||||
stmt.Schema = s
|
||||
|
||||
values := ConvertToCreateValues(stmt)
|
||||
expected := clause.Values{
|
||||
// column has value + defaultValue column has value (which should have a stable order)
|
||||
Columns: []clause.Column{{Name: "name"}, {Name: "email"}, {Name: "age"}, {Name: "id"}},
|
||||
Values: [][]interface{}{
|
||||
{"alice", "email", 18, 1},
|
||||
{"bob", "email", 19, 2},
|
||||
},
|
||||
}
|
||||
if !reflect.DeepEqual(expected, values) {
|
||||
t.Errorf("expected: %v got %v", expected, values)
|
||||
}
|
||||
}
|
||||
@ -157,14 +157,8 @@ func Delete(config *Config) func(db *gorm.DB) {
|
||||
ok, mode := hasReturning(db, supportReturning)
|
||||
if !ok {
|
||||
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||
|
||||
if db.AddError(err) == nil {
|
||||
db.RowsAffected, _ = result.RowsAffected()
|
||||
|
||||
if db.Statement.Result != nil {
|
||||
db.Statement.Result.Result = result
|
||||
db.Statement.Result.RowsAffected = db.RowsAffected
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
@ -172,10 +166,6 @@ func Delete(config *Config) func(db *gorm.DB) {
|
||||
|
||||
if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil {
|
||||
gorm.Scan(rows, db, mode)
|
||||
|
||||
if db.Statement.Result != nil {
|
||||
db.Statement.Result.RowsAffected = db.RowsAffected
|
||||
}
|
||||
db.AddError(rows.Close())
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,157 +0,0 @@
|
||||
package callbacks
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
func TestLoadOrStoreVisitMap(t *testing.T) {
|
||||
var vm visitMap
|
||||
var loaded bool
|
||||
type testM struct {
|
||||
Name string
|
||||
}
|
||||
|
||||
t1 := testM{Name: "t1"}
|
||||
t2 := testM{Name: "t2"}
|
||||
t3 := testM{Name: "t3"}
|
||||
|
||||
vm = make(visitMap)
|
||||
if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf(&t1)); loaded {
|
||||
t.Fatalf("loaded should be false")
|
||||
}
|
||||
|
||||
if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf(&t1)); !loaded {
|
||||
t.Fatalf("loaded should be true")
|
||||
}
|
||||
|
||||
// t1 already exist but t2 not
|
||||
if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf([]*testM{&t1, &t2, &t3})); loaded {
|
||||
t.Fatalf("loaded should be false")
|
||||
}
|
||||
|
||||
if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf([]*testM{&t2, &t3})); !loaded {
|
||||
t.Fatalf("loaded should be true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertMapToValuesForCreate(t *testing.T) {
|
||||
testCase := []struct {
|
||||
name string
|
||||
input map[string]interface{}
|
||||
expect clause.Values
|
||||
}{
|
||||
{
|
||||
name: "Test convert string value",
|
||||
input: map[string]interface{}{
|
||||
"name": "my name",
|
||||
},
|
||||
expect: clause.Values{
|
||||
Columns: []clause.Column{{Name: "name"}},
|
||||
Values: [][]interface{}{{"my name"}},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Test convert int value",
|
||||
input: map[string]interface{}{
|
||||
"age": 18,
|
||||
},
|
||||
expect: clause.Values{
|
||||
Columns: []clause.Column{{Name: "age"}},
|
||||
Values: [][]interface{}{{18}},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Test convert float value",
|
||||
input: map[string]interface{}{
|
||||
"score": 99.5,
|
||||
},
|
||||
expect: clause.Values{
|
||||
Columns: []clause.Column{{Name: "score"}},
|
||||
Values: [][]interface{}{{99.5}},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Test convert bool value",
|
||||
input: map[string]interface{}{
|
||||
"active": true,
|
||||
},
|
||||
expect: clause.Values{
|
||||
Columns: []clause.Column{{Name: "active"}},
|
||||
Values: [][]interface{}{{true}},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCase {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
actual := ConvertMapToValuesForCreate(&gorm.Statement{}, tc.input)
|
||||
if !reflect.DeepEqual(actual, tc.expect) {
|
||||
t.Errorf("expect %v got %v", tc.expect, actual)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertSliceOfMapToValuesForCreate(t *testing.T) {
|
||||
testCase := []struct {
|
||||
name string
|
||||
input []map[string]interface{}
|
||||
expect clause.Values
|
||||
}{
|
||||
{
|
||||
name: "Test convert slice of string value",
|
||||
input: []map[string]interface{}{
|
||||
{"name": "my name"},
|
||||
},
|
||||
expect: clause.Values{
|
||||
Columns: []clause.Column{{Name: "name"}},
|
||||
Values: [][]interface{}{{"my name"}},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Test convert slice of int value",
|
||||
input: []map[string]interface{}{
|
||||
{"age": 18},
|
||||
},
|
||||
expect: clause.Values{
|
||||
Columns: []clause.Column{{Name: "age"}},
|
||||
Values: [][]interface{}{{18}},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Test convert slice of float value",
|
||||
input: []map[string]interface{}{
|
||||
{"score": 99.5},
|
||||
},
|
||||
expect: clause.Values{
|
||||
Columns: []clause.Column{{Name: "score"}},
|
||||
Values: [][]interface{}{{99.5}},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Test convert slice of bool value",
|
||||
input: []map[string]interface{}{
|
||||
{"active": true},
|
||||
},
|
||||
expect: clause.Values{
|
||||
Columns: []clause.Column{{Name: "active"}},
|
||||
Values: [][]interface{}{{true}},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCase {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
actual := ConvertSliceOfMapToValuesForCreate(&gorm.Statement{}, tc.input)
|
||||
|
||||
if !reflect.DeepEqual(actual, tc.expect) {
|
||||
t.Errorf("expected %v but got %v", tc.expect, actual)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
@ -3,8 +3,6 @@ package callbacks
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
@ -12,176 +10,6 @@ import (
|
||||
"gorm.io/gorm/utils"
|
||||
)
|
||||
|
||||
// parsePreloadMap extracts nested preloads. e.g.
|
||||
//
|
||||
// // schema has a "k0" relation and a "k7.k8" embedded relation
|
||||
// parsePreloadMap(schema, map[string][]interface{}{
|
||||
// clause.Associations: {"arg1"},
|
||||
// "k1": {"arg2"},
|
||||
// "k2.k3": {"arg3"},
|
||||
// "k4.k5.k6": {"arg4"},
|
||||
// })
|
||||
// // preloadMap is
|
||||
// map[string]map[string][]interface{}{
|
||||
// "k0": {},
|
||||
// "k7": {
|
||||
// "k8": {},
|
||||
// },
|
||||
// "k1": {},
|
||||
// "k2": {
|
||||
// "k3": {"arg3"},
|
||||
// },
|
||||
// "k4": {
|
||||
// "k5.k6": {"arg4"},
|
||||
// },
|
||||
// }
|
||||
func parsePreloadMap(s *schema.Schema, preloads map[string][]interface{}) map[string]map[string][]interface{} {
|
||||
preloadMap := map[string]map[string][]interface{}{}
|
||||
setPreloadMap := func(name, value string, args []interface{}) {
|
||||
if _, ok := preloadMap[name]; !ok {
|
||||
preloadMap[name] = map[string][]interface{}{}
|
||||
}
|
||||
if value != "" {
|
||||
preloadMap[name][value] = args
|
||||
}
|
||||
}
|
||||
|
||||
for name, args := range preloads {
|
||||
preloadFields := strings.Split(name, ".")
|
||||
value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), ".")
|
||||
if preloadFields[0] == clause.Associations {
|
||||
for _, relation := range s.Relationships.Relations {
|
||||
if relation.Schema == s {
|
||||
setPreloadMap(relation.Name, value, args)
|
||||
}
|
||||
}
|
||||
|
||||
for embedded, embeddedRelations := range s.Relationships.EmbeddedRelations {
|
||||
for _, value := range embeddedValues(embeddedRelations) {
|
||||
setPreloadMap(embedded, value, args)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
setPreloadMap(preloadFields[0], value, args)
|
||||
}
|
||||
}
|
||||
return preloadMap
|
||||
}
|
||||
|
||||
func embeddedValues(embeddedRelations *schema.Relationships) []string {
|
||||
if embeddedRelations == nil {
|
||||
return nil
|
||||
}
|
||||
names := make([]string, 0, len(embeddedRelations.Relations)+len(embeddedRelations.EmbeddedRelations))
|
||||
for _, relation := range embeddedRelations.Relations {
|
||||
// skip first struct name
|
||||
names = append(names, strings.Join(relation.Field.EmbeddedBindNames[1:], "."))
|
||||
}
|
||||
for _, relations := range embeddedRelations.EmbeddedRelations {
|
||||
names = append(names, embeddedValues(relations)...)
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
// preloadEntryPoint enters layer by layer. It will call real preload if it finds the right entry point.
|
||||
// If the current relationship is embedded or joined, current query will be ignored.
|
||||
//
|
||||
//nolint:cyclop
|
||||
func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relationships, preloads map[string][]interface{}, associationsConds []interface{}) error {
|
||||
preloadMap := parsePreloadMap(db.Statement.Schema, preloads)
|
||||
|
||||
// avoid random traversal of the map
|
||||
preloadNames := make([]string, 0, len(preloadMap))
|
||||
for key := range preloadMap {
|
||||
preloadNames = append(preloadNames, key)
|
||||
}
|
||||
sort.Strings(preloadNames)
|
||||
|
||||
isJoined := func(name string) (joined bool, nestedJoins []string) {
|
||||
for _, join := range joins {
|
||||
if _, ok := relationships.Relations[join]; ok && name == join {
|
||||
joined = true
|
||||
continue
|
||||
}
|
||||
join0, join1, cut := strings.Cut(join, ".")
|
||||
if cut {
|
||||
if _, ok := relationships.Relations[join0]; ok && name == join0 {
|
||||
joined = true
|
||||
nestedJoins = append(nestedJoins, join1)
|
||||
}
|
||||
}
|
||||
}
|
||||
return joined, nestedJoins
|
||||
}
|
||||
|
||||
for _, name := range preloadNames {
|
||||
if relations := relationships.EmbeddedRelations[name]; relations != nil {
|
||||
if err := preloadEntryPoint(db, joins, relations, preloadMap[name], associationsConds); err != nil {
|
||||
return err
|
||||
}
|
||||
} else if rel := relationships.Relations[name]; rel != nil {
|
||||
if joined, nestedJoins := isJoined(name); joined {
|
||||
switch rv := db.Statement.ReflectValue; rv.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
if rv.Len() > 0 {
|
||||
reflectValue := rel.FieldSchema.MakeSlice().Elem()
|
||||
for i := 0; i < rv.Len(); i++ {
|
||||
frv := rel.Field.ReflectValueOf(db.Statement.Context, rv.Index(i))
|
||||
if frv.Kind() != reflect.Ptr {
|
||||
reflectValue = reflect.Append(reflectValue, frv.Addr())
|
||||
} else {
|
||||
if frv.IsNil() {
|
||||
continue
|
||||
}
|
||||
reflectValue = reflect.Append(reflectValue, frv)
|
||||
}
|
||||
}
|
||||
|
||||
tx := preloadDB(db, reflectValue, reflectValue.Interface())
|
||||
if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
case reflect.Struct, reflect.Pointer:
|
||||
reflectValue := rel.Field.ReflectValueOf(db.Statement.Context, rv)
|
||||
tx := preloadDB(db, reflectValue, reflectValue.Interface())
|
||||
if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
return gorm.ErrInvalidData
|
||||
}
|
||||
} else {
|
||||
tx := db.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks})
|
||||
tx.Statement.ReflectValue = db.Statement.ReflectValue
|
||||
tx.Statement.Unscoped = db.Statement.Unscoped
|
||||
if err := preload(tx, rel, append(preloads[name], associationsConds...), preloadMap[name]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("%s: %w for schema %s", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func preloadDB(db *gorm.DB, reflectValue reflect.Value, dest interface{}) *gorm.DB {
|
||||
tx := db.Session(&gorm.Session{Context: db.Statement.Context, NewDB: true, SkipHooks: db.Statement.SkipHooks, Initialized: true})
|
||||
db.Statement.Settings.Range(func(k, v interface{}) bool {
|
||||
tx.Statement.Settings.Store(k, v)
|
||||
return true
|
||||
})
|
||||
|
||||
if err := tx.Statement.Parse(dest); err != nil {
|
||||
tx.AddError(err)
|
||||
return tx
|
||||
}
|
||||
tx.Statement.ReflectValue = reflectValue
|
||||
tx.Statement.Unscoped = db.Statement.Unscoped
|
||||
return tx
|
||||
}
|
||||
|
||||
func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) error {
|
||||
var (
|
||||
reflectValue = tx.Statement.ReflectValue
|
||||
@ -275,8 +103,6 @@ func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
|
||||
column, values := schema.ToQueryValues(clause.CurrentTable, relForeignKeys, foreignValues)
|
||||
|
||||
if len(values) != 0 {
|
||||
tx = tx.Model(reflectResults.Addr().Interface()).Where(clause.IN{Column: column, Values: values})
|
||||
|
||||
for _, cond := range conds {
|
||||
if fc, ok := cond.(func(*gorm.DB) *gorm.DB); ok {
|
||||
tx = fc(tx)
|
||||
@ -285,11 +111,7 @@ func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
|
||||
}
|
||||
}
|
||||
|
||||
if len(inlineConds) > 0 {
|
||||
tx = tx.Where(inlineConds[0], inlineConds[1:]...)
|
||||
}
|
||||
|
||||
if err := tx.Find(reflectResults.Addr().Interface()).Error; err != nil {
|
||||
if err := tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), inlineConds...).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
@ -3,12 +3,11 @@ package callbacks
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/schema"
|
||||
"gorm.io/gorm/utils"
|
||||
)
|
||||
|
||||
func Query(db *gorm.DB) {
|
||||
@ -25,10 +24,6 @@ func Query(db *gorm.DB) {
|
||||
db.AddError(rows.Close())
|
||||
}()
|
||||
gorm.Scan(rows, db, 0)
|
||||
|
||||
if db.Statement.Result != nil {
|
||||
db.Statement.Result.RowsAffected = db.RowsAffected
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -114,148 +109,78 @@ func BuildQuerySQL(db *gorm.DB) {
|
||||
}
|
||||
}
|
||||
|
||||
specifiedRelationsName := map[string]string{clause.CurrentTable: clause.CurrentTable}
|
||||
for _, join := range db.Statement.Joins {
|
||||
if db.Statement.Schema != nil {
|
||||
var isRelations bool // is relations or raw sql
|
||||
var relations []*schema.Relationship
|
||||
relation, ok := db.Statement.Schema.Relationships.Relations[join.Name]
|
||||
if ok {
|
||||
isRelations = true
|
||||
relations = append(relations, relation)
|
||||
} else {
|
||||
// handle nested join like "Manager.Company"
|
||||
nestedJoinNames := strings.Split(join.Name, ".")
|
||||
if len(nestedJoinNames) > 1 {
|
||||
isNestedJoin := true
|
||||
guessNestedRelations := make([]*schema.Relationship, 0, len(nestedJoinNames))
|
||||
currentRelations := db.Statement.Schema.Relationships.Relations
|
||||
for _, relname := range nestedJoinNames {
|
||||
// incomplete match, only treated as raw sql
|
||||
if relation, ok = currentRelations[relname]; ok {
|
||||
guessNestedRelations = append(guessNestedRelations, relation)
|
||||
currentRelations = relation.FieldSchema.Relationships.Relations
|
||||
} else {
|
||||
isNestedJoin = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if db.Statement.Schema == nil {
|
||||
fromClause.Joins = append(fromClause.Joins, clause.Join{
|
||||
Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds},
|
||||
})
|
||||
} else if relation, ok := db.Statement.Schema.Relationships.Relations[join.Name]; ok {
|
||||
tableAliasName := relation.Name
|
||||
|
||||
if isNestedJoin {
|
||||
isRelations = true
|
||||
relations = guessNestedRelations
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if isRelations {
|
||||
genJoinClause := func(joinType clause.JoinType, tableAliasName string, parentTableName string, relation *schema.Relationship) clause.Join {
|
||||
columnStmt := gorm.Statement{
|
||||
Table: tableAliasName, DB: db, Schema: relation.FieldSchema,
|
||||
Selects: join.Selects, Omits: join.Omits,
|
||||
}
|
||||
|
||||
selectColumns, restricted := columnStmt.SelectAndOmitColumns(false, false)
|
||||
for _, s := range relation.FieldSchema.DBNames {
|
||||
if v, ok := selectColumns[s]; (ok && v) || (!ok && !restricted) {
|
||||
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{
|
||||
Table: tableAliasName,
|
||||
Name: s,
|
||||
Alias: utils.NestedRelationName(tableAliasName, s),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if join.Expression != nil {
|
||||
return clause.Join{
|
||||
Type: join.JoinType,
|
||||
Expression: join.Expression,
|
||||
}
|
||||
}
|
||||
|
||||
exprs := make([]clause.Expression, len(relation.References))
|
||||
for idx, ref := range relation.References {
|
||||
if ref.OwnPrimaryKey {
|
||||
exprs[idx] = clause.Eq{
|
||||
Column: clause.Column{Table: parentTableName, Name: ref.PrimaryKey.DBName},
|
||||
Value: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
|
||||
}
|
||||
} else {
|
||||
if ref.PrimaryValue == "" {
|
||||
exprs[idx] = clause.Eq{
|
||||
Column: clause.Column{Table: parentTableName, Name: ref.ForeignKey.DBName},
|
||||
Value: clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName},
|
||||
}
|
||||
} else {
|
||||
exprs[idx] = clause.Eq{
|
||||
Column: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
|
||||
Value: ref.PrimaryValue,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
onStmt := gorm.Statement{Table: tableAliasName, DB: db, Clauses: map[string]clause.Clause{}}
|
||||
for _, c := range relation.FieldSchema.QueryClauses {
|
||||
onStmt.AddClause(c)
|
||||
}
|
||||
|
||||
if join.On != nil {
|
||||
onStmt.AddClause(join.On)
|
||||
}
|
||||
|
||||
if cs, ok := onStmt.Clauses["WHERE"]; ok {
|
||||
if where, ok := cs.Expression.(clause.Where); ok {
|
||||
where.Build(&onStmt)
|
||||
|
||||
if onSQL := onStmt.SQL.String(); onSQL != "" {
|
||||
vars := onStmt.Vars
|
||||
for idx, v := range vars {
|
||||
bindvar := strings.Builder{}
|
||||
onStmt.Vars = vars[0 : idx+1]
|
||||
db.Dialector.BindVarTo(&bindvar, &onStmt, v)
|
||||
onSQL = strings.Replace(onSQL, bindvar.String(), "?", 1)
|
||||
}
|
||||
|
||||
exprs = append(exprs, clause.Expr{SQL: onSQL, Vars: vars})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return clause.Join{
|
||||
Type: joinType,
|
||||
Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName},
|
||||
ON: clause.Where{Exprs: exprs},
|
||||
}
|
||||
}
|
||||
|
||||
parentTableName := clause.CurrentTable
|
||||
for idx, rel := range relations {
|
||||
// joins table alias like "Manager, Company, Manager__Company"
|
||||
curAliasName := rel.Name
|
||||
if parentTableName != clause.CurrentTable {
|
||||
curAliasName = utils.NestedRelationName(parentTableName, curAliasName)
|
||||
}
|
||||
|
||||
if _, ok := specifiedRelationsName[curAliasName]; !ok {
|
||||
aliasName := curAliasName
|
||||
if idx == len(relations)-1 && join.Alias != "" {
|
||||
aliasName = join.Alias
|
||||
}
|
||||
|
||||
fromClause.Joins = append(fromClause.Joins, genJoinClause(join.JoinType, aliasName, specifiedRelationsName[parentTableName], rel))
|
||||
specifiedRelationsName[curAliasName] = aliasName
|
||||
}
|
||||
|
||||
parentTableName = curAliasName
|
||||
}
|
||||
} else {
|
||||
fromClause.Joins = append(fromClause.Joins, clause.Join{
|
||||
Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds},
|
||||
for _, s := range relation.FieldSchema.DBNames {
|
||||
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{
|
||||
Table: tableAliasName,
|
||||
Name: s,
|
||||
Alias: tableAliasName + "__" + s,
|
||||
})
|
||||
}
|
||||
|
||||
exprs := make([]clause.Expression, len(relation.References))
|
||||
for idx, ref := range relation.References {
|
||||
if ref.OwnPrimaryKey {
|
||||
exprs[idx] = clause.Eq{
|
||||
Column: clause.Column{Table: clause.CurrentTable, Name: ref.PrimaryKey.DBName},
|
||||
Value: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
|
||||
}
|
||||
} else {
|
||||
if ref.PrimaryValue == "" {
|
||||
exprs[idx] = clause.Eq{
|
||||
Column: clause.Column{Table: clause.CurrentTable, Name: ref.ForeignKey.DBName},
|
||||
Value: clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName},
|
||||
}
|
||||
} else {
|
||||
exprs[idx] = clause.Eq{
|
||||
Column: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
|
||||
Value: ref.PrimaryValue,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
onStmt := gorm.Statement{Table: tableAliasName, DB: db, Clauses: map[string]clause.Clause{}}
|
||||
for _, c := range relation.FieldSchema.QueryClauses {
|
||||
onStmt.AddClause(c)
|
||||
}
|
||||
|
||||
if join.On != nil {
|
||||
onStmt.AddClause(join.On)
|
||||
}
|
||||
|
||||
if cs, ok := onStmt.Clauses["WHERE"]; ok {
|
||||
if where, ok := cs.Expression.(clause.Where); ok {
|
||||
where.Build(&onStmt)
|
||||
|
||||
if onSQL := onStmt.SQL.String(); onSQL != "" {
|
||||
vars := onStmt.Vars
|
||||
for idx, v := range vars {
|
||||
bindvar := strings.Builder{}
|
||||
onStmt.Vars = vars[0 : idx+1]
|
||||
db.Dialector.BindVarTo(&bindvar, &onStmt, v)
|
||||
onSQL = strings.Replace(onSQL, bindvar.String(), "?", 1)
|
||||
}
|
||||
|
||||
exprs = append(exprs, clause.Expr{SQL: onSQL, Vars: vars})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fromClause.Joins = append(fromClause.Joins, clause.Join{
|
||||
Type: clause.LeftJoin,
|
||||
Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName},
|
||||
ON: clause.Where{Exprs: exprs},
|
||||
})
|
||||
} else {
|
||||
fromClause.Joins = append(fromClause.Joins, clause.Join{
|
||||
Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds},
|
||||
@ -264,6 +189,7 @@ func BuildQuerySQL(db *gorm.DB) {
|
||||
}
|
||||
|
||||
db.Statement.AddClause(fromClause)
|
||||
db.Statement.Joins = nil
|
||||
} else {
|
||||
db.Statement.AddClauseIfNotExists(clause.From{})
|
||||
}
|
||||
@ -281,27 +207,60 @@ func Preload(db *gorm.DB) {
|
||||
return
|
||||
}
|
||||
|
||||
joins := make([]string, 0, len(db.Statement.Joins))
|
||||
for _, join := range db.Statement.Joins {
|
||||
joins = append(joins, join.Name)
|
||||
preloadMap := map[string]map[string][]interface{}{}
|
||||
for name := range db.Statement.Preloads {
|
||||
preloadFields := strings.Split(name, ".")
|
||||
if preloadFields[0] == clause.Associations {
|
||||
for _, rel := range db.Statement.Schema.Relationships.Relations {
|
||||
if rel.Schema == db.Statement.Schema {
|
||||
if _, ok := preloadMap[rel.Name]; !ok {
|
||||
preloadMap[rel.Name] = map[string][]interface{}{}
|
||||
}
|
||||
|
||||
if value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), "."); value != "" {
|
||||
preloadMap[rel.Name][value] = db.Statement.Preloads[name]
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if _, ok := preloadMap[preloadFields[0]]; !ok {
|
||||
preloadMap[preloadFields[0]] = map[string][]interface{}{}
|
||||
}
|
||||
|
||||
if value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), "."); value != "" {
|
||||
preloadMap[preloadFields[0]][value] = db.Statement.Preloads[name]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tx := preloadDB(db, db.Statement.ReflectValue, db.Statement.Dest)
|
||||
if tx.Error != nil {
|
||||
preloadNames := make([]string, 0, len(preloadMap))
|
||||
for key := range preloadMap {
|
||||
preloadNames = append(preloadNames, key)
|
||||
}
|
||||
sort.Strings(preloadNames)
|
||||
|
||||
preloadDB := db.Session(&gorm.Session{Context: db.Statement.Context, NewDB: true, SkipHooks: db.Statement.SkipHooks, Initialized: true})
|
||||
db.Statement.Settings.Range(func(k, v interface{}) bool {
|
||||
preloadDB.Statement.Settings.Store(k, v)
|
||||
return true
|
||||
})
|
||||
|
||||
if err := preloadDB.Statement.Parse(db.Statement.Dest); err != nil {
|
||||
return
|
||||
}
|
||||
preloadDB.Statement.ReflectValue = db.Statement.ReflectValue
|
||||
|
||||
db.AddError(preloadEntryPoint(tx, joins, &tx.Statement.Schema.Relationships, db.Statement.Preloads, db.Statement.Preloads[clause.Associations]))
|
||||
for _, name := range preloadNames {
|
||||
if rel := preloadDB.Statement.Schema.Relationships.Relations[name]; rel != nil {
|
||||
db.AddError(preload(preloadDB.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks}), rel, append(db.Statement.Preloads[name], db.Statement.Preloads[clause.Associations]...), preloadMap[name]))
|
||||
} else {
|
||||
db.AddError(fmt.Errorf("%s: %w for schema %s", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func AfterQuery(db *gorm.DB) {
|
||||
// clear the joins after query because preload need it
|
||||
if v, ok := db.Statement.Clauses["FROM"].Expression.(clause.From); ok {
|
||||
fromClause := db.Statement.Clauses["FROM"]
|
||||
fromClause.Expression = clause.From{Tables: v.Tables, Joins: utils.RTrimSlice(v.Joins, len(db.Statement.Joins))} // keep the original From Joins
|
||||
db.Statement.Clauses["FROM"] = fromClause
|
||||
}
|
||||
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.AfterFind && db.RowsAffected > 0 {
|
||||
callMethod(db, func(value interface{}, tx *gorm.DB) bool {
|
||||
if i, ok := value.(AfterFindInterface); ok {
|
||||
|
||||
@ -13,10 +13,5 @@ func RawExec(db *gorm.DB) {
|
||||
}
|
||||
|
||||
db.RowsAffected, _ = result.RowsAffected()
|
||||
|
||||
if db.Statement.Result != nil {
|
||||
db.Statement.Result.Result = result
|
||||
db.Statement.Result.RowsAffected = db.RowsAffected
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -7,7 +7,7 @@ import (
|
||||
func RowQuery(db *gorm.DB) {
|
||||
if db.Error == nil {
|
||||
BuildQuerySQL(db)
|
||||
if db.DryRun || db.Error != nil {
|
||||
if db.DryRun {
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@ -72,7 +72,6 @@ func Update(config *Config) func(db *gorm.DB) {
|
||||
db.Statement.AddClauseIfNotExists(clause.Update{})
|
||||
if _, ok := db.Statement.Clauses["SET"]; !ok {
|
||||
if set := ConvertToAssignments(db.Statement); len(set) != 0 {
|
||||
defer delete(db.Statement.Clauses, "SET")
|
||||
db.Statement.AddClause(set)
|
||||
} else {
|
||||
return
|
||||
@ -92,10 +91,6 @@ func Update(config *Config) func(db *gorm.DB) {
|
||||
gorm.Scan(rows, db, mode)
|
||||
db.Statement.Dest = dest
|
||||
db.AddError(rows.Close())
|
||||
|
||||
if db.Statement.Result != nil {
|
||||
db.Statement.Result.RowsAffected = db.RowsAffected
|
||||
}
|
||||
}
|
||||
} else {
|
||||
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||
@ -103,11 +98,6 @@ func Update(config *Config) func(db *gorm.DB) {
|
||||
if db.AddError(err) == nil {
|
||||
db.RowsAffected, _ = result.RowsAffected()
|
||||
}
|
||||
|
||||
if db.Statement.Result != nil {
|
||||
db.Statement.Result.Result = result
|
||||
db.Statement.Result.RowsAffected = db.RowsAffected
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -147,9 +137,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
|
||||
case reflect.Slice, reflect.Array:
|
||||
assignValue = func(field *schema.Field, value interface{}) {
|
||||
for i := 0; i < stmt.ReflectValue.Len(); i++ {
|
||||
if stmt.ReflectValue.CanAddr() {
|
||||
field.Set(stmt.Context, stmt.ReflectValue.Index(i), value)
|
||||
}
|
||||
field.Set(stmt.Context, stmt.ReflectValue.Index(i), value)
|
||||
}
|
||||
}
|
||||
case reflect.Struct:
|
||||
@ -243,7 +231,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
|
||||
if field.AutoUpdateTime == schema.UnixNanosecond {
|
||||
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano()})
|
||||
} else if field.AutoUpdateTime == schema.UnixMillisecond {
|
||||
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixMilli()})
|
||||
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano() / 1e6})
|
||||
} else if field.AutoUpdateTime == schema.UnixSecond {
|
||||
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.Unix()})
|
||||
} else {
|
||||
@ -255,13 +243,11 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
|
||||
}
|
||||
default:
|
||||
updatingSchema := stmt.Schema
|
||||
var isDiffSchema bool
|
||||
if !updatingValue.CanAddr() || stmt.Dest != stmt.Model {
|
||||
// different schema
|
||||
updatingStmt := &gorm.Statement{DB: stmt.DB}
|
||||
if err := updatingStmt.Parse(stmt.Dest); err == nil {
|
||||
updatingSchema = updatingStmt.Schema
|
||||
isDiffSchema = true
|
||||
}
|
||||
}
|
||||
|
||||
@ -277,7 +263,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
|
||||
if field.AutoUpdateTime == schema.UnixNanosecond {
|
||||
value = stmt.DB.NowFunc().UnixNano()
|
||||
} else if field.AutoUpdateTime == schema.UnixMillisecond {
|
||||
value = stmt.DB.NowFunc().UnixMilli()
|
||||
value = stmt.DB.NowFunc().UnixNano() / 1e6
|
||||
} else if field.AutoUpdateTime == schema.UnixSecond {
|
||||
value = stmt.DB.NowFunc().Unix()
|
||||
} else {
|
||||
@ -288,13 +274,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
|
||||
|
||||
if (ok || !isZero) && field.Updatable {
|
||||
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value})
|
||||
assignField := field
|
||||
if isDiffSchema {
|
||||
if originField := stmt.Schema.LookUpField(dbName); originField != nil {
|
||||
assignField = originField
|
||||
}
|
||||
}
|
||||
assignValue(assignField, value)
|
||||
assignValue(field, value)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
|
||||
36
callbacks/visit_map_test.go
Normal file
36
callbacks/visit_map_test.go
Normal file
@ -0,0 +1,36 @@
|
||||
package callbacks
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestLoadOrStoreVisitMap(t *testing.T) {
|
||||
var vm visitMap
|
||||
var loaded bool
|
||||
type testM struct {
|
||||
Name string
|
||||
}
|
||||
|
||||
t1 := testM{Name: "t1"}
|
||||
t2 := testM{Name: "t2"}
|
||||
t3 := testM{Name: "t3"}
|
||||
|
||||
vm = make(visitMap)
|
||||
if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf(&t1)); loaded {
|
||||
t.Fatalf("loaded should be false")
|
||||
}
|
||||
|
||||
if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf(&t1)); !loaded {
|
||||
t.Fatalf("loaded should be true")
|
||||
}
|
||||
|
||||
// t1 already exist but t2 not
|
||||
if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf([]*testM{&t1, &t2, &t3})); loaded {
|
||||
t.Fatalf("loaded should be false")
|
||||
}
|
||||
|
||||
if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf([]*testM{&t2, &t3})); !loaded {
|
||||
t.Fatalf("loaded should be true")
|
||||
}
|
||||
}
|
||||
210
chainable_api.go
210
chainable_api.go
@ -10,11 +10,10 @@ import (
|
||||
)
|
||||
|
||||
// Model specify the model you would like to run db operations
|
||||
//
|
||||
// // update all users's name to `hello`
|
||||
// db.Model(&User{}).Update("name", "hello")
|
||||
// // if user's primary key is non-blank, will use it as condition, then will only update that user's name to `hello`
|
||||
// db.Model(&user).Update("name", "hello")
|
||||
// // update all users's name to `hello`
|
||||
// db.Model(&User{}).Update("name", "hello")
|
||||
// // if user's primary key is non-blank, will use it as condition, then will only update the user's name to `hello`
|
||||
// db.Model(&user).Update("name", "hello")
|
||||
func (db *DB) Model(value interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.Model = value
|
||||
@ -22,19 +21,6 @@ func (db *DB) Model(value interface{}) (tx *DB) {
|
||||
}
|
||||
|
||||
// Clauses Add clauses
|
||||
//
|
||||
// This supports both standard clauses (clause.OrderBy, clause.Limit, clause.Where) and more
|
||||
// advanced techniques like specifying lock strength and optimizer hints. See the
|
||||
// [docs] for more depth.
|
||||
//
|
||||
// // add a simple limit clause
|
||||
// db.Clauses(clause.Limit{Limit: 1}).Find(&User{})
|
||||
// // tell the optimizer to use the `idx_user_name` index
|
||||
// db.Clauses(hints.UseIndex("idx_user_name")).Find(&User{})
|
||||
// // specify the lock strength to UPDATE
|
||||
// db.Clauses(clause.Locking{Strength: "UPDATE"}).Find(&users)
|
||||
//
|
||||
// [docs]: https://gorm.io/docs/sql_builder.html#Clauses
|
||||
func (db *DB) Clauses(conds ...clause.Expression) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
var whereConds []interface{}
|
||||
@ -55,22 +41,15 @@ func (db *DB) Clauses(conds ...clause.Expression) (tx *DB) {
|
||||
return
|
||||
}
|
||||
|
||||
var tableRegexp = regexp.MustCompile(`(?i)(?:.+? AS (\w+)\s*(?:$|,)|^\w+\s+(\w+)$)`)
|
||||
var tableRegexp = regexp.MustCompile(`(?i).+? AS (\w+)\s*(?:$|,)`)
|
||||
|
||||
// Table specify the table you would like to run db operations
|
||||
//
|
||||
// // Get a user
|
||||
// db.Table("users").Take(&result)
|
||||
func (db *DB) Table(name string, args ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
if strings.Contains(name, " ") || strings.Contains(name, "`") || len(args) > 0 {
|
||||
tx.Statement.TableExpr = &clause.Expr{SQL: name, Vars: args}
|
||||
if results := tableRegexp.FindStringSubmatch(name); len(results) == 3 {
|
||||
if results[1] != "" {
|
||||
tx.Statement.Table = results[1]
|
||||
} else {
|
||||
tx.Statement.Table = results[2]
|
||||
}
|
||||
if results := tableRegexp.FindStringSubmatch(name); len(results) == 2 {
|
||||
tx.Statement.Table = results[1]
|
||||
}
|
||||
} else if tables := strings.Split(name, "."); len(tables) == 2 {
|
||||
tx.Statement.TableExpr = &clause.Expr{SQL: tx.Statement.Quote(name)}
|
||||
@ -86,11 +65,6 @@ func (db *DB) Table(name string, args ...interface{}) (tx *DB) {
|
||||
}
|
||||
|
||||
// Distinct specify distinct fields that you want querying
|
||||
//
|
||||
// // Select distinct names of users
|
||||
// db.Distinct("name").Find(&results)
|
||||
// // Select distinct name/age pairs from users
|
||||
// db.Distinct("name", "age").Find(&results)
|
||||
func (db *DB) Distinct(args ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.Distinct = true
|
||||
@ -101,14 +75,6 @@ func (db *DB) Distinct(args ...interface{}) (tx *DB) {
|
||||
}
|
||||
|
||||
// Select specify fields that you want when querying, creating, updating
|
||||
//
|
||||
// Use Select when you only want a subset of the fields. By default, GORM will select all fields.
|
||||
// Select accepts both string arguments and arrays.
|
||||
//
|
||||
// // Select name and age of user using multiple arguments
|
||||
// db.Select("name", "age").Find(&users)
|
||||
// // Select name and age of user using an array
|
||||
// db.Select([]string{"name", "age"}).Find(&users)
|
||||
func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
|
||||
@ -185,25 +151,7 @@ func (db *DB) Omit(columns ...string) (tx *DB) {
|
||||
return
|
||||
}
|
||||
|
||||
// MapColumns modify the column names in the query results to facilitate align to the corresponding structural fields
|
||||
func (db *DB) MapColumns(m map[string]string) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.ColumnMapping = m
|
||||
return
|
||||
}
|
||||
|
||||
// Where add conditions
|
||||
//
|
||||
// See the [docs] for details on the various formats that where clauses can take. By default, where clauses chain with AND.
|
||||
//
|
||||
// // Find the first user with name jinzhu
|
||||
// db.Where("name = ?", "jinzhu").First(&user)
|
||||
// // Find the first user with name jinzhu and age 20
|
||||
// db.Where(&User{Name: "jinzhu", Age: 20}).First(&user)
|
||||
// // Find the first user with name jinzhu and age not equal to 20
|
||||
// db.Where("name = ?", "jinzhu").Where("age <> ?", "20").First(&user)
|
||||
//
|
||||
// [docs]: https://gorm.io/docs/query.html#Conditions
|
||||
func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 {
|
||||
@ -213,11 +161,6 @@ func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) {
|
||||
}
|
||||
|
||||
// Not add NOT conditions
|
||||
//
|
||||
// Not works similarly to where, and has the same syntax.
|
||||
//
|
||||
// // Find the first user with name not equal to jinzhu
|
||||
// db.Not("name = ?", "jinzhu").First(&user)
|
||||
func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 {
|
||||
@ -227,11 +170,6 @@ func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) {
|
||||
}
|
||||
|
||||
// Or add OR conditions
|
||||
//
|
||||
// Or is used to chain together queries with an OR.
|
||||
//
|
||||
// // Find the first user with name equal to jinzhu or john
|
||||
// db.Where("name = ?", "jinzhu").Or("name = ?", "john").First(&user)
|
||||
func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 {
|
||||
@ -241,45 +179,26 @@ func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) {
|
||||
}
|
||||
|
||||
// Joins specify Joins conditions
|
||||
//
|
||||
// db.Joins("Account").Find(&user)
|
||||
// db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user)
|
||||
// db.Joins("Account", DB.Select("id").Where("user_id = users.id AND name = ?", "someName").Model(&Account{}))
|
||||
// db.Joins("Account").Find(&user)
|
||||
// db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user)
|
||||
// db.Joins("Account", DB.Select("id").Where("user_id = users.id AND name = ?", "someName").Model(&Account{}))
|
||||
func (db *DB) Joins(query string, args ...interface{}) (tx *DB) {
|
||||
return joins(db, clause.LeftJoin, query, args...)
|
||||
}
|
||||
|
||||
// InnerJoins specify inner joins conditions
|
||||
// db.InnerJoins("Account").Find(&user)
|
||||
func (db *DB) InnerJoins(query string, args ...interface{}) (tx *DB) {
|
||||
return joins(db, clause.InnerJoin, query, args...)
|
||||
}
|
||||
|
||||
func joins(db *DB, joinType clause.JoinType, query string, args ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
|
||||
if len(args) == 1 {
|
||||
if db, ok := args[0].(*DB); ok {
|
||||
j := join{
|
||||
Name: query, Conds: args, Selects: db.Statement.Selects,
|
||||
Omits: db.Statement.Omits, JoinType: joinType,
|
||||
}
|
||||
if where, ok := db.Statement.Clauses["WHERE"].Expression.(clause.Where); ok {
|
||||
j.On = &where
|
||||
tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args, On: &where})
|
||||
return
|
||||
}
|
||||
tx.Statement.Joins = append(tx.Statement.Joins, j)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args, JoinType: joinType})
|
||||
tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args})
|
||||
return
|
||||
}
|
||||
|
||||
// Group specify the group method on the find
|
||||
//
|
||||
// // Select the sum age of users with given names
|
||||
// db.Model(&User{}).Select("name, sum(age) as total").Group("name").Find(&results)
|
||||
func (db *DB) Group(name string) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
|
||||
@ -291,9 +210,6 @@ func (db *DB) Group(name string) (tx *DB) {
|
||||
}
|
||||
|
||||
// Having specify HAVING conditions for GROUP BY
|
||||
//
|
||||
// // Select the sum age of users with name jinzhu
|
||||
// db.Model(&User{}).Select("name, sum(age) as total").Group("name").Having("name = ?", "jinzhu").Find(&result)
|
||||
func (db *DB) Having(query interface{}, args ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.AddClause(clause.GroupBy{
|
||||
@ -302,20 +218,13 @@ func (db *DB) Having(query interface{}, args ...interface{}) (tx *DB) {
|
||||
return
|
||||
}
|
||||
|
||||
// Order specify order when retrieving records from database
|
||||
//
|
||||
// db.Order("name DESC")
|
||||
// db.Order(clause.OrderByColumn{Column: clause.Column{Name: "name"}, Desc: true})
|
||||
// db.Order(clause.OrderBy{Columns: []clause.OrderByColumn{
|
||||
// {Column: clause.Column{Name: "name"}, Desc: true},
|
||||
// {Column: clause.Column{Name: "age"}, Desc: true},
|
||||
// }})
|
||||
// Order specify order when retrieve records from database
|
||||
// db.Order("name DESC")
|
||||
// db.Order(clause.OrderByColumn{Column: clause.Column{Name: "name"}, Desc: true})
|
||||
func (db *DB) Order(value interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
|
||||
switch v := value.(type) {
|
||||
case clause.OrderBy:
|
||||
tx.Statement.AddClause(v)
|
||||
case clause.OrderByColumn:
|
||||
tx.Statement.AddClause(clause.OrderBy{
|
||||
Columns: []clause.OrderByColumn{v},
|
||||
@ -333,27 +242,13 @@ func (db *DB) Order(value interface{}) (tx *DB) {
|
||||
}
|
||||
|
||||
// Limit specify the number of records to be retrieved
|
||||
//
|
||||
// Limit conditions can be cancelled by using `Limit(-1)`.
|
||||
//
|
||||
// // retrieve 3 users
|
||||
// db.Limit(3).Find(&users)
|
||||
// // retrieve 3 users into users1, and all users into users2
|
||||
// db.Limit(3).Find(&users1).Limit(-1).Find(&users2)
|
||||
func (db *DB) Limit(limit int) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.AddClause(clause.Limit{Limit: &limit})
|
||||
tx.Statement.AddClause(clause.Limit{Limit: limit})
|
||||
return
|
||||
}
|
||||
|
||||
// Offset specify the number of records to skip before starting to return the records
|
||||
//
|
||||
// Offset conditions can be cancelled by using `Offset(-1)`.
|
||||
//
|
||||
// // select the third user
|
||||
// db.Offset(2).First(&user)
|
||||
// // select the first user by cancelling an earlier chained offset
|
||||
// db.Offset(5).Offset(-1).First(&user)
|
||||
func (db *DB) Offset(offset int) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.AddClause(clause.Limit{Offset: offset})
|
||||
@ -361,37 +256,25 @@ func (db *DB) Offset(offset int) (tx *DB) {
|
||||
}
|
||||
|
||||
// Scopes pass current database connection to arguments `func(DB) DB`, which could be used to add conditions dynamically
|
||||
// func AmountGreaterThan1000(db *gorm.DB) *gorm.DB {
|
||||
// return db.Where("amount > ?", 1000)
|
||||
// }
|
||||
//
|
||||
// func AmountGreaterThan1000(db *gorm.DB) *gorm.DB {
|
||||
// return db.Where("amount > ?", 1000)
|
||||
// }
|
||||
// func OrderStatus(status []string) func (db *gorm.DB) *gorm.DB {
|
||||
// return func (db *gorm.DB) *gorm.DB {
|
||||
// return db.Scopes(AmountGreaterThan1000).Where("status in (?)", status)
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// func OrderStatus(status []string) func (db *gorm.DB) *gorm.DB {
|
||||
// return func (db *gorm.DB) *gorm.DB {
|
||||
// return db.Scopes(AmountGreaterThan1000).Where("status in (?)", status)
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders)
|
||||
// db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders)
|
||||
func (db *DB) Scopes(funcs ...func(*DB) *DB) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.scopes = append(tx.Statement.scopes, funcs...)
|
||||
return tx
|
||||
}
|
||||
|
||||
func (db *DB) executeScopes() (tx *DB) {
|
||||
scopes := db.Statement.scopes
|
||||
db.Statement.scopes = nil
|
||||
for _, scope := range scopes {
|
||||
db = scope(db)
|
||||
}
|
||||
return db
|
||||
}
|
||||
|
||||
// Preload preload associations with given conditions
|
||||
//
|
||||
// // get all users, and preload all non-cancelled orders
|
||||
// db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users)
|
||||
// db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users)
|
||||
func (db *DB) Preload(query string, args ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
if tx.Statement.Preloads == nil {
|
||||
@ -401,57 +284,18 @@ func (db *DB) Preload(query string, args ...interface{}) (tx *DB) {
|
||||
return
|
||||
}
|
||||
|
||||
// Attrs provide attributes used in [FirstOrCreate] or [FirstOrInit]
|
||||
//
|
||||
// Attrs only adds attributes if the record is not found.
|
||||
//
|
||||
// // assign an email if the record is not found
|
||||
// db.Where(User{Name: "non_existing"}).Attrs(User{Email: "fake@fake.org"}).FirstOrInit(&user)
|
||||
// // user -> User{Name: "non_existing", Email: "fake@fake.org"}
|
||||
//
|
||||
// // assign an email if the record is not found, otherwise ignore provided email
|
||||
// db.Where(User{Name: "jinzhu"}).Attrs(User{Email: "fake@fake.org"}).FirstOrInit(&user)
|
||||
// // user -> User{Name: "jinzhu", Age: 20}
|
||||
//
|
||||
// [FirstOrCreate]: https://gorm.io/docs/advanced_query.html#FirstOrCreate
|
||||
// [FirstOrInit]: https://gorm.io/docs/advanced_query.html#FirstOrInit
|
||||
func (db *DB) Attrs(attrs ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.attrs = attrs
|
||||
return
|
||||
}
|
||||
|
||||
// Assign provide attributes used in [FirstOrCreate] or [FirstOrInit]
|
||||
//
|
||||
// Assign adds attributes even if the record is found. If using FirstOrCreate, this means that
|
||||
// records will be updated even if they are found.
|
||||
//
|
||||
// // assign an email regardless of if the record is not found
|
||||
// db.Where(User{Name: "non_existing"}).Assign(User{Email: "fake@fake.org"}).FirstOrInit(&user)
|
||||
// // user -> User{Name: "non_existing", Email: "fake@fake.org"}
|
||||
//
|
||||
// // assign email regardless of if record is found
|
||||
// db.Where(User{Name: "jinzhu"}).Assign(User{Email: "fake@fake.org"}).FirstOrInit(&user)
|
||||
// // user -> User{Name: "jinzhu", Age: 20, Email: "fake@fake.org"}
|
||||
//
|
||||
// [FirstOrCreate]: https://gorm.io/docs/advanced_query.html#FirstOrCreate
|
||||
// [FirstOrInit]: https://gorm.io/docs/advanced_query.html#FirstOrInit
|
||||
func (db *DB) Assign(attrs ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.assigns = attrs
|
||||
return
|
||||
}
|
||||
|
||||
// Unscoped disables the global scope of soft deletion in a query.
|
||||
// By default, GORM uses soft deletion, marking records as "deleted"
|
||||
// by setting a timestamp on a specific field (e.g., `deleted_at`).
|
||||
// Unscoped allows queries to include records marked as deleted,
|
||||
// overriding the soft deletion behavior.
|
||||
// Example:
|
||||
//
|
||||
// var users []User
|
||||
// db.Unscoped().Find(&users)
|
||||
// // Retrieves all users, including deleted ones.
|
||||
func (db *DB) Unscoped() (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.Unscoped = true
|
||||
|
||||
@ -29,7 +29,6 @@ func BenchmarkSelect(b *testing.B) {
|
||||
func BenchmarkComplexSelect(b *testing.B) {
|
||||
user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
|
||||
|
||||
limit10 := 10
|
||||
for i := 0; i < b.N; i++ {
|
||||
stmt := gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}}
|
||||
clauses := []clause.Interface{
|
||||
@ -44,7 +43,7 @@ func BenchmarkComplexSelect(b *testing.B) {
|
||||
clause.Or(clause.Gt{Column: "score", Value: 100}, clause.Like{Column: "name", Value: "%linus%"}),
|
||||
}},
|
||||
clause.GroupBy{Columns: []clause.Column{{Name: "role"}}, Having: []clause.Expression{clause.Eq{"role", "admin"}}},
|
||||
clause.Limit{Limit: &limit10, Offset: 20},
|
||||
clause.Limit{Limit: 10, Offset: 20},
|
||||
clause.OrderBy{Columns: []clause.OrderByColumn{{Column: clause.PrimaryColumn, Desc: true}}},
|
||||
}
|
||||
|
||||
|
||||
@ -20,7 +20,6 @@ type Builder interface {
|
||||
Writer
|
||||
WriteQuoted(field interface{})
|
||||
AddVar(Writer, ...interface{})
|
||||
AddError(error) error
|
||||
}
|
||||
|
||||
// Clause
|
||||
|
||||
@ -126,7 +126,7 @@ func (expr NamedExpr) Build(builder Builder) {
|
||||
for _, v := range []byte(expr.SQL) {
|
||||
if v == '@' && !inName {
|
||||
inName = true
|
||||
name = name[:0]
|
||||
name = []byte{}
|
||||
} else if v == ' ' || v == ',' || v == ')' || v == '"' || v == '\'' || v == '`' || v == '\r' || v == '\n' || v == ';' {
|
||||
if inName {
|
||||
if nv, ok := namedMap[string(name)]; ok {
|
||||
@ -246,19 +246,15 @@ func (eq Eq) Build(builder Builder) {
|
||||
|
||||
switch eq.Value.(type) {
|
||||
case []string, []int, []int32, []int64, []uint, []uint32, []uint64, []interface{}:
|
||||
builder.WriteString(" IN (")
|
||||
rv := reflect.ValueOf(eq.Value)
|
||||
if rv.Len() == 0 {
|
||||
builder.WriteString(" IN (NULL)")
|
||||
} else {
|
||||
builder.WriteString(" IN (")
|
||||
for i := 0; i < rv.Len(); i++ {
|
||||
if i > 0 {
|
||||
builder.WriteByte(',')
|
||||
}
|
||||
builder.AddVar(builder, rv.Index(i).Interface())
|
||||
for i := 0; i < rv.Len(); i++ {
|
||||
if i > 0 {
|
||||
builder.WriteByte(',')
|
||||
}
|
||||
builder.WriteByte(')')
|
||||
builder.AddVar(builder, rv.Index(i).Interface())
|
||||
}
|
||||
builder.WriteByte(')')
|
||||
default:
|
||||
if eqNil(eq.Value) {
|
||||
builder.WriteString(" IS NULL")
|
||||
|
||||
@ -199,11 +199,6 @@ func TestExpression(t *testing.T) {
|
||||
},
|
||||
ExpectedVars: []interface{}{"a", "b"},
|
||||
Result: "`column-name` NOT IN (?,?)",
|
||||
}, {
|
||||
Expressions: []clause.Expression{
|
||||
clause.Eq{Column: column, Value: []string{}},
|
||||
},
|
||||
Result: "`column-name` IN (NULL)",
|
||||
}, {
|
||||
Expressions: []clause.Expression{
|
||||
clause.Eq{Column: clause.Expr{SQL: "SUM(?)", Vars: []interface{}{clause.Column{Name: "id"}}}, Value: 100},
|
||||
|
||||
@ -1,7 +1,5 @@
|
||||
package clause
|
||||
|
||||
import "gorm.io/gorm/utils"
|
||||
|
||||
type JoinType string
|
||||
|
||||
const (
|
||||
@ -11,31 +9,7 @@ const (
|
||||
RightJoin JoinType = "RIGHT"
|
||||
)
|
||||
|
||||
type JoinTarget struct {
|
||||
Type JoinType
|
||||
Association string
|
||||
Subquery Expression
|
||||
Table string
|
||||
}
|
||||
|
||||
func Has(name string) JoinTarget {
|
||||
return JoinTarget{Type: InnerJoin, Association: name}
|
||||
}
|
||||
|
||||
func (jt JoinType) Association(name string) JoinTarget {
|
||||
return JoinTarget{Type: jt, Association: name}
|
||||
}
|
||||
|
||||
func (jt JoinType) AssociationFrom(name string, subquery Expression) JoinTarget {
|
||||
return JoinTarget{Type: jt, Association: name, Subquery: subquery}
|
||||
}
|
||||
|
||||
func (jt JoinTarget) As(name string) JoinTarget {
|
||||
jt.Table = name
|
||||
return jt
|
||||
}
|
||||
|
||||
// Join clause for from
|
||||
// Join join clause for from
|
||||
type Join struct {
|
||||
Type JoinType
|
||||
Table Table
|
||||
@ -44,12 +18,6 @@ type Join struct {
|
||||
Expression Expression
|
||||
}
|
||||
|
||||
func JoinTable(names ...string) Table {
|
||||
return Table{
|
||||
Name: utils.JoinNestedRelationNames(names),
|
||||
}
|
||||
}
|
||||
|
||||
func (join Join) Build(builder Builder) {
|
||||
if join.Expression != nil {
|
||||
join.Expression.Build(builder)
|
||||
|
||||
@ -1,101 +0,0 @@
|
||||
package clause_test
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/schema"
|
||||
"gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
func TestJoin(t *testing.T) {
|
||||
results := []struct {
|
||||
name string
|
||||
join clause.Join
|
||||
sql string
|
||||
}{
|
||||
{
|
||||
name: "LEFT JOIN",
|
||||
join: clause.Join{
|
||||
Type: clause.LeftJoin,
|
||||
Table: clause.Table{Name: "user"},
|
||||
ON: clause.Where{
|
||||
Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}},
|
||||
},
|
||||
},
|
||||
sql: "LEFT JOIN `user` ON `user_info`.`user_id` = `users`.`id`",
|
||||
},
|
||||
{
|
||||
name: "RIGHT JOIN",
|
||||
join: clause.Join{
|
||||
Type: clause.RightJoin,
|
||||
Table: clause.Table{Name: "user"},
|
||||
ON: clause.Where{
|
||||
Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}},
|
||||
},
|
||||
},
|
||||
sql: "RIGHT JOIN `user` ON `user_info`.`user_id` = `users`.`id`",
|
||||
},
|
||||
{
|
||||
name: "INNER JOIN",
|
||||
join: clause.Join{
|
||||
Type: clause.InnerJoin,
|
||||
Table: clause.Table{Name: "user"},
|
||||
ON: clause.Where{
|
||||
Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}},
|
||||
},
|
||||
},
|
||||
sql: "INNER JOIN `user` ON `user_info`.`user_id` = `users`.`id`",
|
||||
},
|
||||
{
|
||||
name: "CROSS JOIN",
|
||||
join: clause.Join{
|
||||
Type: clause.CrossJoin,
|
||||
Table: clause.Table{Name: "user"},
|
||||
ON: clause.Where{
|
||||
Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}},
|
||||
},
|
||||
},
|
||||
sql: "CROSS JOIN `user` ON `user_info`.`user_id` = `users`.`id`",
|
||||
},
|
||||
{
|
||||
name: "USING",
|
||||
join: clause.Join{
|
||||
Type: clause.InnerJoin,
|
||||
Table: clause.Table{Name: "user"},
|
||||
Using: []string{"id"},
|
||||
},
|
||||
sql: "INNER JOIN `user` USING (`id`)",
|
||||
},
|
||||
{
|
||||
name: "Expression",
|
||||
join: clause.Join{
|
||||
// Invalid
|
||||
Type: clause.LeftJoin,
|
||||
Table: clause.Table{Name: "user"},
|
||||
ON: clause.Where{
|
||||
Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}},
|
||||
},
|
||||
// Valid
|
||||
Expression: clause.Join{
|
||||
Type: clause.InnerJoin,
|
||||
Table: clause.Table{Name: "user"},
|
||||
Using: []string{"id"},
|
||||
},
|
||||
},
|
||||
sql: "INNER JOIN `user` USING (`id`)",
|
||||
},
|
||||
}
|
||||
for _, result := range results {
|
||||
t.Run(result.name, func(t *testing.T) {
|
||||
user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
|
||||
stmt := &gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}}
|
||||
result.join.Build(stmt)
|
||||
if result.sql != stmt.SQL.String() {
|
||||
t.Errorf("want: %s, got: %s", result.sql, stmt.SQL.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -1,8 +1,10 @@
|
||||
package clause
|
||||
|
||||
import "strconv"
|
||||
|
||||
// Limit limit clause
|
||||
type Limit struct {
|
||||
Limit *int
|
||||
Limit int
|
||||
Offset int
|
||||
}
|
||||
|
||||
@ -13,16 +15,16 @@ func (limit Limit) Name() string {
|
||||
|
||||
// Build build where clause
|
||||
func (limit Limit) Build(builder Builder) {
|
||||
if limit.Limit != nil && *limit.Limit >= 0 {
|
||||
if limit.Limit > 0 {
|
||||
builder.WriteString("LIMIT ")
|
||||
builder.AddVar(builder, *limit.Limit)
|
||||
builder.WriteString(strconv.Itoa(limit.Limit))
|
||||
}
|
||||
if limit.Offset > 0 {
|
||||
if limit.Limit != nil && *limit.Limit >= 0 {
|
||||
if limit.Limit > 0 {
|
||||
builder.WriteByte(' ')
|
||||
}
|
||||
builder.WriteString("OFFSET ")
|
||||
builder.AddVar(builder, limit.Offset)
|
||||
builder.WriteString(strconv.Itoa(limit.Offset))
|
||||
}
|
||||
}
|
||||
|
||||
@ -31,7 +33,7 @@ func (limit Limit) MergeClause(clause *Clause) {
|
||||
clause.Name = ""
|
||||
|
||||
if v, ok := clause.Expression.(Limit); ok {
|
||||
if (limit.Limit == nil || *limit.Limit == 0) && v.Limit != nil {
|
||||
if limit.Limit == 0 && v.Limit != 0 {
|
||||
limit.Limit = v.Limit
|
||||
}
|
||||
|
||||
|
||||
@ -8,10 +8,6 @@ import (
|
||||
)
|
||||
|
||||
func TestLimit(t *testing.T) {
|
||||
limit0 := 0
|
||||
limit10 := 10
|
||||
limit50 := 50
|
||||
limitNeg10 := -10
|
||||
results := []struct {
|
||||
Clauses []clause.Interface
|
||||
Result string
|
||||
@ -19,56 +15,38 @@ func TestLimit(t *testing.T) {
|
||||
}{
|
||||
{
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{
|
||||
Limit: &limit10,
|
||||
Limit: 10,
|
||||
Offset: 20,
|
||||
}},
|
||||
"SELECT * FROM `users` LIMIT ? OFFSET ?",
|
||||
[]interface{}{limit10, 20},
|
||||
},
|
||||
{
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit0}},
|
||||
"SELECT * FROM `users` LIMIT ?",
|
||||
[]interface{}{limit0},
|
||||
},
|
||||
{
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit0}, clause.Limit{Offset: 0}},
|
||||
"SELECT * FROM `users` LIMIT ?",
|
||||
[]interface{}{limit0},
|
||||
"SELECT * FROM `users` LIMIT 10 OFFSET 20", nil,
|
||||
},
|
||||
{
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}},
|
||||
"SELECT * FROM `users` OFFSET ?",
|
||||
[]interface{}{20},
|
||||
"SELECT * FROM `users` OFFSET 20", nil,
|
||||
},
|
||||
{
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}, clause.Limit{Offset: 30}},
|
||||
"SELECT * FROM `users` OFFSET ?",
|
||||
[]interface{}{30},
|
||||
"SELECT * FROM `users` OFFSET 30", nil,
|
||||
},
|
||||
{
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}, clause.Limit{Limit: &limit10}},
|
||||
"SELECT * FROM `users` LIMIT ? OFFSET ?",
|
||||
[]interface{}{limit10, 20},
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}, clause.Limit{Limit: 10}},
|
||||
"SELECT * FROM `users` LIMIT 10 OFFSET 20", nil,
|
||||
},
|
||||
{
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}},
|
||||
"SELECT * FROM `users` LIMIT ? OFFSET ?",
|
||||
[]interface{}{limit10, 30},
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}},
|
||||
"SELECT * FROM `users` LIMIT 10 OFFSET 30", nil,
|
||||
},
|
||||
{
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Offset: -10}},
|
||||
"SELECT * FROM `users` LIMIT ?",
|
||||
[]interface{}{limit10},
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Offset: -10}},
|
||||
"SELECT * FROM `users` LIMIT 10", nil,
|
||||
},
|
||||
{
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: &limitNeg10}},
|
||||
"SELECT * FROM `users` OFFSET ?",
|
||||
[]interface{}{30},
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: -10}},
|
||||
"SELECT * FROM `users` OFFSET 30", nil,
|
||||
},
|
||||
{
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: &limit50}},
|
||||
"SELECT * FROM `users` LIMIT ? OFFSET ?",
|
||||
[]interface{}{limit50, 30},
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: 50}},
|
||||
"SELECT * FROM `users` LIMIT 50 OFFSET 30", nil,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@ -1,12 +1,5 @@
|
||||
package clause
|
||||
|
||||
const (
|
||||
LockingStrengthUpdate = "UPDATE"
|
||||
LockingStrengthShare = "SHARE"
|
||||
LockingOptionsSkipLocked = "SKIP LOCKED"
|
||||
LockingOptionsNoWait = "NOWAIT"
|
||||
)
|
||||
|
||||
type Locking struct {
|
||||
Strength string
|
||||
Table Table
|
||||
|
||||
@ -14,21 +14,17 @@ func TestLocking(t *testing.T) {
|
||||
Vars []interface{}
|
||||
}{
|
||||
{
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: clause.LockingStrengthUpdate}},
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: "UPDATE"}},
|
||||
"SELECT * FROM `users` FOR UPDATE", nil,
|
||||
},
|
||||
{
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: clause.LockingStrengthShare, Table: clause.Table{Name: clause.CurrentTable}}},
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: "SHARE", Table: clause.Table{Name: clause.CurrentTable}}},
|
||||
"SELECT * FROM `users` FOR SHARE OF `users`", nil,
|
||||
},
|
||||
{
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: clause.LockingStrengthUpdate, Options: clause.LockingOptionsNoWait}},
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: "UPDATE"}, clause.Locking{Strength: "UPDATE", Options: "NOWAIT"}},
|
||||
"SELECT * FROM `users` FOR UPDATE NOWAIT", nil,
|
||||
},
|
||||
{
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: clause.LockingStrengthUpdate, Options: clause.LockingOptionsSkipLocked}},
|
||||
"SELECT * FROM `users` FOR UPDATE SKIP LOCKED", nil,
|
||||
},
|
||||
}
|
||||
|
||||
for idx, result := range results {
|
||||
|
||||
@ -16,27 +16,27 @@ func (OnConflict) Name() string {
|
||||
|
||||
// Build build onConflict clause
|
||||
func (onConflict OnConflict) Build(builder Builder) {
|
||||
if len(onConflict.Columns) > 0 {
|
||||
builder.WriteByte('(')
|
||||
for idx, column := range onConflict.Columns {
|
||||
if idx > 0 {
|
||||
builder.WriteByte(',')
|
||||
}
|
||||
builder.WriteQuoted(column)
|
||||
}
|
||||
builder.WriteString(`) `)
|
||||
}
|
||||
|
||||
if len(onConflict.TargetWhere.Exprs) > 0 {
|
||||
builder.WriteString(" WHERE ")
|
||||
onConflict.TargetWhere.Build(builder)
|
||||
builder.WriteByte(' ')
|
||||
}
|
||||
|
||||
if onConflict.OnConstraint != "" {
|
||||
builder.WriteString("ON CONSTRAINT ")
|
||||
builder.WriteString(onConflict.OnConstraint)
|
||||
builder.WriteByte(' ')
|
||||
} else {
|
||||
if len(onConflict.Columns) > 0 {
|
||||
builder.WriteByte('(')
|
||||
for idx, column := range onConflict.Columns {
|
||||
if idx > 0 {
|
||||
builder.WriteByte(',')
|
||||
}
|
||||
builder.WriteQuoted(column)
|
||||
}
|
||||
builder.WriteString(`) `)
|
||||
}
|
||||
|
||||
if len(onConflict.TargetWhere.Exprs) > 0 {
|
||||
builder.WriteString(" WHERE ")
|
||||
onConflict.TargetWhere.Build(builder)
|
||||
builder.WriteByte(' ')
|
||||
}
|
||||
}
|
||||
|
||||
if onConflict.DoNothing {
|
||||
|
||||
@ -26,12 +26,9 @@ func (returning Returning) Build(builder Builder) {
|
||||
|
||||
// MergeClause merge order by clauses
|
||||
func (returning Returning) MergeClause(clause *Clause) {
|
||||
if v, ok := clause.Expression.(Returning); ok && len(returning.Columns) > 0 {
|
||||
if v.Columns != nil {
|
||||
returning.Columns = append(v.Columns, returning.Columns...)
|
||||
} else {
|
||||
returning.Columns = nil
|
||||
}
|
||||
if v, ok := clause.Expression.(Returning); ok {
|
||||
returning.Columns = append(v.Columns, returning.Columns...)
|
||||
}
|
||||
|
||||
clause.Expression = returning
|
||||
}
|
||||
|
||||
@ -26,22 +26,6 @@ func TestReturning(t *testing.T) {
|
||||
}},
|
||||
"SELECT * FROM `users` RETURNING `users`.`id`,`name`,`age`", nil,
|
||||
},
|
||||
{
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Returning{
|
||||
[]clause.Column{clause.PrimaryColumn},
|
||||
}, clause.Returning{}, clause.Returning{
|
||||
[]clause.Column{{Name: "name"}, {Name: "age"}},
|
||||
}},
|
||||
"SELECT * FROM `users` RETURNING *", nil,
|
||||
},
|
||||
{
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Returning{
|
||||
[]clause.Column{clause.PrimaryColumn},
|
||||
}, clause.Returning{
|
||||
[]clause.Column{{Name: "name"}, {Name: "age"}},
|
||||
}, clause.Returning{}},
|
||||
"SELECT * FROM `users` RETURNING *", nil,
|
||||
},
|
||||
}
|
||||
|
||||
for idx, result := range results {
|
||||
|
||||
@ -49,18 +49,16 @@ func TestSelect(t *testing.T) {
|
||||
Exprs: []clause.Expression{
|
||||
clause.Expr{
|
||||
SQL: "? as name",
|
||||
Vars: []interface{}{
|
||||
clause.Eq{
|
||||
Column: clause.Column{Name: "age"},
|
||||
Value: 18,
|
||||
},
|
||||
Vars: []interface{}{clause.Eq{
|
||||
Column: clause.Column{Name: "age"},
|
||||
Value: 18,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}, clause.From{}},
|
||||
"SELECT `age` = ? as name FROM `users`",
|
||||
[]interface{}{18},
|
||||
"SELECT `age` = ? as name FROM `users`", []interface{}{18},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@ -21,12 +21,6 @@ func (where Where) Name() string {
|
||||
|
||||
// Build build where clause
|
||||
func (where Where) Build(builder Builder) {
|
||||
if len(where.Exprs) == 1 {
|
||||
if andCondition, ok := where.Exprs[0].(AndConditions); ok {
|
||||
where.Exprs = andCondition.Exprs
|
||||
}
|
||||
}
|
||||
|
||||
// Switch position if the first query expression is a single Or condition
|
||||
for idx, expr := range where.Exprs {
|
||||
if v, ok := expr.(OrConditions); !ok || len(v.Exprs) > 1 {
|
||||
@ -153,11 +147,6 @@ func Not(exprs ...Expression) Expression {
|
||||
if len(exprs) == 0 {
|
||||
return nil
|
||||
}
|
||||
if len(exprs) == 1 {
|
||||
if andCondition, ok := exprs[0].(AndConditions); ok {
|
||||
exprs = andCondition.Exprs
|
||||
}
|
||||
}
|
||||
return NotConditions{Exprs: exprs}
|
||||
}
|
||||
|
||||
@ -166,63 +155,19 @@ type NotConditions struct {
|
||||
}
|
||||
|
||||
func (not NotConditions) Build(builder Builder) {
|
||||
anyNegationBuilder := false
|
||||
for _, c := range not.Exprs {
|
||||
if _, ok := c.(NegationExpressionBuilder); ok {
|
||||
anyNegationBuilder = true
|
||||
break
|
||||
}
|
||||
if len(not.Exprs) > 1 {
|
||||
builder.WriteByte('(')
|
||||
}
|
||||
|
||||
if anyNegationBuilder {
|
||||
if len(not.Exprs) > 1 {
|
||||
builder.WriteByte('(')
|
||||
for idx, c := range not.Exprs {
|
||||
if idx > 0 {
|
||||
builder.WriteString(AndWithSpace)
|
||||
}
|
||||
|
||||
for idx, c := range not.Exprs {
|
||||
if idx > 0 {
|
||||
builder.WriteString(AndWithSpace)
|
||||
}
|
||||
|
||||
if negationBuilder, ok := c.(NegationExpressionBuilder); ok {
|
||||
negationBuilder.NegationBuild(builder)
|
||||
} else {
|
||||
builder.WriteString("NOT ")
|
||||
e, wrapInParentheses := c.(Expr)
|
||||
if wrapInParentheses {
|
||||
sql := strings.ToUpper(e.SQL)
|
||||
if wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace); wrapInParentheses {
|
||||
builder.WriteByte('(')
|
||||
}
|
||||
}
|
||||
|
||||
c.Build(builder)
|
||||
|
||||
if wrapInParentheses {
|
||||
builder.WriteByte(')')
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(not.Exprs) > 1 {
|
||||
builder.WriteByte(')')
|
||||
}
|
||||
} else {
|
||||
builder.WriteString("NOT ")
|
||||
if len(not.Exprs) > 1 {
|
||||
builder.WriteByte('(')
|
||||
}
|
||||
|
||||
for idx, c := range not.Exprs {
|
||||
if idx > 0 {
|
||||
switch c.(type) {
|
||||
case OrConditions:
|
||||
builder.WriteString(OrWithSpace)
|
||||
default:
|
||||
builder.WriteString(AndWithSpace)
|
||||
}
|
||||
}
|
||||
|
||||
if negationBuilder, ok := c.(NegationExpressionBuilder); ok {
|
||||
negationBuilder.NegationBuild(builder)
|
||||
} else {
|
||||
builder.WriteString("NOT ")
|
||||
e, wrapInParentheses := c.(Expr)
|
||||
if wrapInParentheses {
|
||||
sql := strings.ToUpper(e.SQL)
|
||||
@ -237,9 +182,9 @@ func (not NotConditions) Build(builder Builder) {
|
||||
builder.WriteByte(')')
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(not.Exprs) > 1 {
|
||||
builder.WriteByte(')')
|
||||
}
|
||||
if len(not.Exprs) > 1 {
|
||||
builder.WriteByte(')')
|
||||
}
|
||||
}
|
||||
|
||||
@ -63,7 +63,7 @@ func TestWhere(t *testing.T) {
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
|
||||
Exprs: []clause.Expression{clause.And(clause.Eq{Column: "age", Value: 18}, clause.Or(clause.Neq{Column: "name", Value: "jinzhu"}))},
|
||||
}},
|
||||
"SELECT * FROM `users` WHERE `age` = ? OR `name` <> ?",
|
||||
"SELECT * FROM `users` WHERE (`age` = ? OR `name` <> ?)",
|
||||
[]interface{}{18, "jinzhu"},
|
||||
},
|
||||
{
|
||||
@ -94,7 +94,7 @@ func TestWhere(t *testing.T) {
|
||||
clause.And(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false})),
|
||||
},
|
||||
}},
|
||||
"SELECT * FROM `users` WHERE `users`.`id` <> ? AND `score` <= ?",
|
||||
"SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `score` <= ?)",
|
||||
[]interface{}{"1", 100},
|
||||
},
|
||||
{
|
||||
@ -105,30 +105,6 @@ func TestWhere(t *testing.T) {
|
||||
"SELECT * FROM `users` WHERE (`users`.`id` <> ? AND NOT `score` <= ?)",
|
||||
[]interface{}{"1", 100},
|
||||
},
|
||||
{
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
|
||||
Exprs: []clause.Expression{clause.Not(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}},
|
||||
clause.Expr{SQL: "`age` <= ?", Vars: []interface{}{60}})},
|
||||
}},
|
||||
"SELECT * FROM `users` WHERE NOT (`score` <= ? AND `age` <= ?)",
|
||||
[]interface{}{100, 60},
|
||||
},
|
||||
{
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
|
||||
Exprs: []clause.Expression{
|
||||
clause.Not(clause.AndConditions{
|
||||
Exprs: []clause.Expression{
|
||||
clause.Eq{Column: clause.PrimaryColumn, Value: "1"},
|
||||
clause.Gt{Column: "age", Value: 18},
|
||||
}}, clause.OrConditions{
|
||||
Exprs: []clause.Expression{
|
||||
clause.Lt{Column: "score", Value: 100},
|
||||
},
|
||||
}),
|
||||
}}},
|
||||
"SELECT * FROM `users` WHERE NOT ((`users`.`id` = ? AND `age` > ?) OR `score` < ?)",
|
||||
[]interface{}{"1", 18, 100},
|
||||
},
|
||||
}
|
||||
|
||||
for idx, result := range results {
|
||||
|
||||
10
errors.go
10
errors.go
@ -21,10 +21,6 @@ var (
|
||||
ErrPrimaryKeyRequired = errors.New("primary key required")
|
||||
// ErrModelValueRequired model value required
|
||||
ErrModelValueRequired = errors.New("model value required")
|
||||
// ErrModelAccessibleFieldsRequired model accessible fields required
|
||||
ErrModelAccessibleFieldsRequired = errors.New("model accessible fields required")
|
||||
// ErrSubQueryRequired sub query required
|
||||
ErrSubQueryRequired = errors.New("sub query required")
|
||||
// ErrInvalidData unsupported data
|
||||
ErrInvalidData = errors.New("unsupported data")
|
||||
// ErrUnsupportedDriver unsupported driver
|
||||
@ -45,10 +41,4 @@ var (
|
||||
ErrInvalidValueOfLength = errors.New("invalid association values, length doesn't match")
|
||||
// ErrPreloadNotAllowed preload is not allowed when count is used
|
||||
ErrPreloadNotAllowed = errors.New("preload is not allowed when count is used")
|
||||
// ErrDuplicatedKey occurs when there is a unique key constraint violation
|
||||
ErrDuplicatedKey = errors.New("duplicated key not allowed")
|
||||
// ErrForeignKeyViolated occurs when there is a foreign key constraint violation
|
||||
ErrForeignKeyViolated = errors.New("violates foreign key constraint")
|
||||
// ErrCheckConstraintViolated occurs when there is a check constraint violation
|
||||
ErrCheckConstraintViolated = errors.New("violates check constraint")
|
||||
)
|
||||
|
||||
180
finisher_api.go
180
finisher_api.go
@ -1,11 +1,9 @@
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash/maphash"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
@ -35,10 +33,9 @@ func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) {
|
||||
var rowsAffected int64
|
||||
tx = db.getInstance()
|
||||
|
||||
// the reflection length judgment of the optimized value
|
||||
reflectLen := reflectValue.Len()
|
||||
|
||||
callFc := func(tx *DB) error {
|
||||
// the reflection length judgment of the optimized value
|
||||
reflectLen := reflectValue.Len()
|
||||
for i := 0; i < reflectLen; i += batchSize {
|
||||
ends := i + batchSize
|
||||
if ends > reflectLen {
|
||||
@ -56,7 +53,7 @@ func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if tx.SkipDefaultTransaction || reflectLen <= batchSize {
|
||||
if tx.SkipDefaultTransaction {
|
||||
tx.AddError(callFc(tx.Session(&Session{})))
|
||||
} else {
|
||||
tx.AddError(tx.Transaction(callFc))
|
||||
@ -104,13 +101,14 @@ func (db *DB) Save(value interface{}) (tx *DB) {
|
||||
tx.Statement.Selects = append(tx.Statement.Selects, "*")
|
||||
}
|
||||
|
||||
updateTx := tx.callbacks.Update().Execute(tx.Session(&Session{Initialized: true}))
|
||||
tx = tx.callbacks.Update().Execute(tx)
|
||||
|
||||
if updateTx.Error == nil && updateTx.RowsAffected == 0 && !updateTx.DryRun && !selectedUpdate {
|
||||
return tx.Session(&Session{SkipHooks: true}).Clauses(clause.OnConflict{UpdateAll: true}).Create(value)
|
||||
if tx.Error == nil && tx.RowsAffected == 0 && !tx.DryRun && !selectedUpdate {
|
||||
result := reflect.New(tx.Statement.Schema.ModelType).Interface()
|
||||
if result := tx.Session(&Session{}).Limit(1).Find(result); result.RowsAffected == 0 {
|
||||
return tx.Create(value)
|
||||
}
|
||||
}
|
||||
|
||||
return updateTx
|
||||
}
|
||||
|
||||
return
|
||||
@ -187,9 +185,7 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat
|
||||
var totalSize int
|
||||
if c, ok := tx.Statement.Clauses["LIMIT"]; ok {
|
||||
if limit, ok := c.Expression.(clause.Limit); ok {
|
||||
if limit.Limit != nil {
|
||||
totalSize = *limit.Limit
|
||||
}
|
||||
totalSize = limit.Limit
|
||||
|
||||
if totalSize > 0 && batchSize > totalSize {
|
||||
batchSize = totalSize
|
||||
@ -233,11 +229,7 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat
|
||||
break
|
||||
}
|
||||
|
||||
primaryValue, zero := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1))
|
||||
if zero {
|
||||
tx.AddError(ErrPrimaryKeyRequired)
|
||||
break
|
||||
}
|
||||
primaryValue, _ := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1))
|
||||
queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue})
|
||||
}
|
||||
|
||||
@ -296,16 +288,6 @@ func (db *DB) assignInterfacesToValue(values ...interface{}) {
|
||||
|
||||
// FirstOrInit finds the first matching record, otherwise if not found initializes a new instance with given conds.
|
||||
// Each conds must be a struct or map.
|
||||
//
|
||||
// FirstOrInit never modifies the database. It is often used with Assign and Attrs.
|
||||
//
|
||||
// // assign an email if the record is not found
|
||||
// db.Where(User{Name: "non_existing"}).Attrs(User{Email: "fake@fake.org"}).FirstOrInit(&user)
|
||||
// // user -> User{Name: "non_existing", Email: "fake@fake.org"}
|
||||
//
|
||||
// // assign email regardless of if record is found
|
||||
// db.Where(User{Name: "jinzhu"}).Assign(User{Email: "fake@fake.org"}).FirstOrInit(&user)
|
||||
// // user -> User{Name: "jinzhu", Age: 20, Email: "fake@fake.org"}
|
||||
func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) {
|
||||
queryTx := db.Limit(1).Order(clause.OrderByColumn{
|
||||
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
|
||||
@ -333,69 +315,50 @@ func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) {
|
||||
|
||||
// FirstOrCreate finds the first matching record, otherwise if not found creates a new instance with given conds.
|
||||
// Each conds must be a struct or map.
|
||||
//
|
||||
// Using FirstOrCreate in conjunction with Assign will result in an update to the database even if the record exists.
|
||||
//
|
||||
// // assign an email if the record is not found
|
||||
// result := db.Where(User{Name: "non_existing"}).Attrs(User{Email: "fake@fake.org"}).FirstOrCreate(&user)
|
||||
// // user -> User{Name: "non_existing", Email: "fake@fake.org"}
|
||||
// // result.RowsAffected -> 1
|
||||
//
|
||||
// // assign email regardless of if record is found
|
||||
// result := db.Where(User{Name: "jinzhu"}).Assign(User{Email: "fake@fake.org"}).FirstOrCreate(&user)
|
||||
// // user -> User{Name: "jinzhu", Age: 20, Email: "fake@fake.org"}
|
||||
// // result.RowsAffected -> 1
|
||||
func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
queryTx := db.Session(&Session{}).Limit(1).Order(clause.OrderByColumn{
|
||||
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
|
||||
})
|
||||
|
||||
result := queryTx.Find(dest, conds...)
|
||||
if result.Error != nil {
|
||||
tx.Error = result.Error
|
||||
return tx
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
if c, ok := result.Statement.Clauses["WHERE"]; ok {
|
||||
if where, ok := c.Expression.(clause.Where); ok {
|
||||
result.assignInterfacesToValue(where.Exprs)
|
||||
}
|
||||
}
|
||||
|
||||
// initialize with attrs, conds
|
||||
if len(db.Statement.attrs) > 0 {
|
||||
result.assignInterfacesToValue(db.Statement.attrs...)
|
||||
}
|
||||
|
||||
// initialize with attrs, conds
|
||||
if len(db.Statement.assigns) > 0 {
|
||||
result.assignInterfacesToValue(db.Statement.assigns...)
|
||||
}
|
||||
|
||||
return tx.Create(dest)
|
||||
} else if len(db.Statement.assigns) > 0 {
|
||||
exprs := tx.Statement.BuildCondition(db.Statement.assigns[0], db.Statement.assigns[1:]...)
|
||||
assigns := map[string]interface{}{}
|
||||
for i := 0; i < len(exprs); i++ {
|
||||
expr := exprs[i]
|
||||
|
||||
if eq, ok := expr.(clause.AndConditions); ok {
|
||||
exprs = append(exprs, eq.Exprs...)
|
||||
} else if eq, ok := expr.(clause.Eq); ok {
|
||||
switch column := eq.Column.(type) {
|
||||
case string:
|
||||
assigns[column] = eq.Value
|
||||
case clause.Column:
|
||||
assigns[column.Name] = eq.Value
|
||||
if result := queryTx.Find(dest, conds...); result.Error == nil {
|
||||
if result.RowsAffected == 0 {
|
||||
if c, ok := result.Statement.Clauses["WHERE"]; ok {
|
||||
if where, ok := c.Expression.(clause.Where); ok {
|
||||
result.assignInterfacesToValue(where.Exprs)
|
||||
}
|
||||
}
|
||||
|
||||
// initialize with attrs, conds
|
||||
if len(db.Statement.attrs) > 0 {
|
||||
result.assignInterfacesToValue(db.Statement.attrs...)
|
||||
}
|
||||
|
||||
// initialize with attrs, conds
|
||||
if len(db.Statement.assigns) > 0 {
|
||||
result.assignInterfacesToValue(db.Statement.assigns...)
|
||||
}
|
||||
|
||||
return tx.Create(dest)
|
||||
} else if len(db.Statement.assigns) > 0 {
|
||||
exprs := tx.Statement.BuildCondition(db.Statement.assigns[0], db.Statement.assigns[1:]...)
|
||||
assigns := map[string]interface{}{}
|
||||
for _, expr := range exprs {
|
||||
if eq, ok := expr.(clause.Eq); ok {
|
||||
switch column := eq.Column.(type) {
|
||||
case string:
|
||||
assigns[column] = eq.Value
|
||||
case clause.Column:
|
||||
assigns[column.Name] = eq.Value
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return tx.Model(dest).Updates(assigns)
|
||||
}
|
||||
|
||||
return tx.Model(dest).Updates(assigns)
|
||||
} else {
|
||||
tx.Error = result.Error
|
||||
}
|
||||
|
||||
return tx
|
||||
}
|
||||
|
||||
@ -537,7 +500,6 @@ func (db *DB) Scan(dest interface{}) (tx *DB) {
|
||||
tx.ScanRows(rows, dest)
|
||||
} else {
|
||||
tx.RowsAffected = 0
|
||||
tx.AddError(rows.Err())
|
||||
}
|
||||
tx.AddError(rows.Close())
|
||||
}
|
||||
@ -550,9 +512,8 @@ func (db *DB) Scan(dest interface{}) (tx *DB) {
|
||||
}
|
||||
|
||||
// Pluck queries a single column from a model, returning in the slice dest. E.g.:
|
||||
//
|
||||
// var ages []int64
|
||||
// db.Model(&users).Pluck("age", &ages)
|
||||
// var ages []int64
|
||||
// db.Model(&users).Pluck("age", &ages)
|
||||
func (db *DB) Pluck(column string, dest interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
if tx.Statement.Model != nil {
|
||||
@ -625,15 +586,15 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er
|
||||
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
|
||||
// nested transaction
|
||||
if !db.DisableNestedTransaction {
|
||||
spID := new(maphash.Hash).Sum64()
|
||||
err = db.SavePoint(fmt.Sprintf("sp%d", spID)).Error
|
||||
err = db.SavePoint(fmt.Sprintf("sp%p", fc)).Error
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
// Make sure to rollback when panic, Block error or Commit error
|
||||
if panicked || err != nil {
|
||||
db.RollbackTo(fmt.Sprintf("sp%d", spID))
|
||||
db.RollbackTo(fmt.Sprintf("sp%p", fc))
|
||||
}
|
||||
}()
|
||||
}
|
||||
@ -674,18 +635,11 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB {
|
||||
opt = opts[0]
|
||||
}
|
||||
|
||||
ctx := tx.Statement.Context
|
||||
if _, ok := ctx.Deadline(); !ok {
|
||||
if db.Config.DefaultTransactionTimeout > 0 {
|
||||
ctx, _ = context.WithTimeout(ctx, db.Config.DefaultTransactionTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
switch beginner := tx.Statement.ConnPool.(type) {
|
||||
case TxBeginner:
|
||||
tx.Statement.ConnPool, err = beginner.BeginTx(ctx, opt)
|
||||
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
|
||||
case ConnPoolBeginner:
|
||||
tx.Statement.ConnPool, err = beginner.BeginTx(ctx, opt)
|
||||
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
|
||||
default:
|
||||
err = ErrInvalidTransaction
|
||||
}
|
||||
@ -721,21 +675,7 @@ func (db *DB) Rollback() *DB {
|
||||
|
||||
func (db *DB) SavePoint(name string) *DB {
|
||||
if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok {
|
||||
// close prepared statement, because SavePoint not support prepared statement.
|
||||
// e.g. mysql8.0 doc: https://dev.mysql.com/doc/refman/8.0/en/sql-prepared-statements.html
|
||||
var (
|
||||
preparedStmtTx *PreparedStmtTX
|
||||
isPreparedStmtTx bool
|
||||
)
|
||||
// close prepared statement, because SavePoint not support prepared statement.
|
||||
if preparedStmtTx, isPreparedStmtTx = db.Statement.ConnPool.(*PreparedStmtTX); isPreparedStmtTx {
|
||||
db.Statement.ConnPool = preparedStmtTx.Tx
|
||||
}
|
||||
db.AddError(savePointer.SavePoint(db, name))
|
||||
// restore prepared statement
|
||||
if isPreparedStmtTx {
|
||||
db.Statement.ConnPool = preparedStmtTx
|
||||
}
|
||||
} else {
|
||||
db.AddError(ErrUnsupportedDriver)
|
||||
}
|
||||
@ -744,21 +684,7 @@ func (db *DB) SavePoint(name string) *DB {
|
||||
|
||||
func (db *DB) RollbackTo(name string) *DB {
|
||||
if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok {
|
||||
// close prepared statement, because RollbackTo not support prepared statement.
|
||||
// e.g. mysql8.0 doc: https://dev.mysql.com/doc/refman/8.0/en/sql-prepared-statements.html
|
||||
var (
|
||||
preparedStmtTx *PreparedStmtTX
|
||||
isPreparedStmtTx bool
|
||||
)
|
||||
// close prepared statement, because SavePoint not support prepared statement.
|
||||
if preparedStmtTx, isPreparedStmtTx = db.Statement.ConnPool.(*PreparedStmtTX); isPreparedStmtTx {
|
||||
db.Statement.ConnPool = preparedStmtTx.Tx
|
||||
}
|
||||
db.AddError(savePointer.RollbackTo(db, name))
|
||||
// restore prepared statement
|
||||
if isPreparedStmtTx {
|
||||
db.Statement.ConnPool = preparedStmtTx
|
||||
}
|
||||
} else {
|
||||
db.AddError(ErrUnsupportedDriver)
|
||||
}
|
||||
|
||||
605
generics.go
605
generics.go
@ -1,605 +0,0 @@
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
type result struct {
|
||||
Result sql.Result
|
||||
RowsAffected int64
|
||||
}
|
||||
|
||||
func (info *result) ModifyStatement(stmt *Statement) {
|
||||
stmt.Result = info
|
||||
}
|
||||
|
||||
// Build implements clause.Expression interface
|
||||
func (result) Build(clause.Builder) {
|
||||
}
|
||||
|
||||
func WithResult() *result {
|
||||
return &result{}
|
||||
}
|
||||
|
||||
type Interface[T any] interface {
|
||||
Raw(sql string, values ...interface{}) ExecInterface[T]
|
||||
Exec(ctx context.Context, sql string, values ...interface{}) error
|
||||
CreateInterface[T]
|
||||
}
|
||||
|
||||
type CreateInterface[T any] interface {
|
||||
ChainInterface[T]
|
||||
Table(name string, args ...interface{}) CreateInterface[T]
|
||||
Create(ctx context.Context, r *T) error
|
||||
CreateInBatches(ctx context.Context, r *[]T, batchSize int) error
|
||||
}
|
||||
|
||||
type ChainInterface[T any] interface {
|
||||
ExecInterface[T]
|
||||
Scopes(scopes ...func(db *Statement)) ChainInterface[T]
|
||||
Where(query interface{}, args ...interface{}) ChainInterface[T]
|
||||
Not(query interface{}, args ...interface{}) ChainInterface[T]
|
||||
Or(query interface{}, args ...interface{}) ChainInterface[T]
|
||||
Limit(offset int) ChainInterface[T]
|
||||
Offset(offset int) ChainInterface[T]
|
||||
Joins(query clause.JoinTarget, on func(db JoinBuilder, joinTable clause.Table, curTable clause.Table) error) ChainInterface[T]
|
||||
Preload(association string, query func(db PreloadBuilder) error) ChainInterface[T]
|
||||
Select(query string, args ...interface{}) ChainInterface[T]
|
||||
Omit(columns ...string) ChainInterface[T]
|
||||
MapColumns(m map[string]string) ChainInterface[T]
|
||||
Distinct(args ...interface{}) ChainInterface[T]
|
||||
Group(name string) ChainInterface[T]
|
||||
Having(query interface{}, args ...interface{}) ChainInterface[T]
|
||||
Order(value interface{}) ChainInterface[T]
|
||||
|
||||
Build(builder clause.Builder)
|
||||
|
||||
Delete(ctx context.Context) (rowsAffected int, err error)
|
||||
Update(ctx context.Context, name string, value any) (rowsAffected int, err error)
|
||||
Updates(ctx context.Context, t T) (rowsAffected int, err error)
|
||||
Count(ctx context.Context, column string) (result int64, err error)
|
||||
}
|
||||
|
||||
type ExecInterface[T any] interface {
|
||||
Scan(ctx context.Context, r interface{}) error
|
||||
First(context.Context) (T, error)
|
||||
Last(ctx context.Context) (T, error)
|
||||
Take(context.Context) (T, error)
|
||||
Find(ctx context.Context) ([]T, error)
|
||||
FindInBatches(ctx context.Context, batchSize int, fc func(data []T, batch int) error) error
|
||||
Row(ctx context.Context) *sql.Row
|
||||
Rows(ctx context.Context) (*sql.Rows, error)
|
||||
}
|
||||
|
||||
type JoinBuilder interface {
|
||||
Select(...string) JoinBuilder
|
||||
Omit(...string) JoinBuilder
|
||||
Where(query interface{}, args ...interface{}) JoinBuilder
|
||||
Not(query interface{}, args ...interface{}) JoinBuilder
|
||||
Or(query interface{}, args ...interface{}) JoinBuilder
|
||||
}
|
||||
|
||||
type PreloadBuilder interface {
|
||||
Select(...string) PreloadBuilder
|
||||
Omit(...string) PreloadBuilder
|
||||
Where(query interface{}, args ...interface{}) PreloadBuilder
|
||||
Not(query interface{}, args ...interface{}) PreloadBuilder
|
||||
Or(query interface{}, args ...interface{}) PreloadBuilder
|
||||
Limit(offset int) PreloadBuilder
|
||||
Offset(offset int) PreloadBuilder
|
||||
Order(value interface{}) PreloadBuilder
|
||||
LimitPerRecord(num int) PreloadBuilder
|
||||
}
|
||||
|
||||
type op func(*DB) *DB
|
||||
|
||||
func G[T any](db *DB, opts ...clause.Expression) Interface[T] {
|
||||
v := &g[T]{
|
||||
db: db,
|
||||
ops: make([]op, 0, 5),
|
||||
}
|
||||
|
||||
if len(opts) > 0 {
|
||||
v.ops = append(v.ops, func(db *DB) *DB {
|
||||
return db.Clauses(opts...)
|
||||
})
|
||||
}
|
||||
|
||||
v.createG = &createG[T]{
|
||||
chainG: chainG[T]{
|
||||
execG: execG[T]{g: v},
|
||||
},
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
type g[T any] struct {
|
||||
*createG[T]
|
||||
db *DB
|
||||
ops []op
|
||||
}
|
||||
|
||||
func (g *g[T]) apply(ctx context.Context) *DB {
|
||||
db := g.db
|
||||
if !db.DryRun {
|
||||
db = db.Session(&Session{NewDB: true, Context: ctx}).getInstance()
|
||||
}
|
||||
|
||||
for _, op := range g.ops {
|
||||
db = op(db)
|
||||
}
|
||||
return db
|
||||
}
|
||||
|
||||
func (c *g[T]) Raw(sql string, values ...interface{}) ExecInterface[T] {
|
||||
return execG[T]{g: &g[T]{
|
||||
db: c.db,
|
||||
ops: append(c.ops, func(db *DB) *DB {
|
||||
return db.Raw(sql, values...)
|
||||
}),
|
||||
}}
|
||||
}
|
||||
|
||||
func (c *g[T]) Exec(ctx context.Context, sql string, values ...interface{}) error {
|
||||
return c.apply(ctx).Exec(sql, values...).Error
|
||||
}
|
||||
|
||||
type createG[T any] struct {
|
||||
chainG[T]
|
||||
}
|
||||
|
||||
func (c createG[T]) Table(name string, args ...interface{}) CreateInterface[T] {
|
||||
return createG[T]{c.with(func(db *DB) *DB {
|
||||
return db.Table(name, args...)
|
||||
})}
|
||||
}
|
||||
|
||||
func (c createG[T]) Create(ctx context.Context, r *T) error {
|
||||
return c.g.apply(ctx).Create(r).Error
|
||||
}
|
||||
|
||||
func (c createG[T]) CreateInBatches(ctx context.Context, r *[]T, batchSize int) error {
|
||||
return c.g.apply(ctx).CreateInBatches(r, batchSize).Error
|
||||
}
|
||||
|
||||
type chainG[T any] struct {
|
||||
execG[T]
|
||||
}
|
||||
|
||||
func (c chainG[T]) getInstance() *DB {
|
||||
var r T
|
||||
return c.g.apply(context.Background()).Model(r).getInstance()
|
||||
}
|
||||
|
||||
func (c chainG[T]) with(v op) chainG[T] {
|
||||
return chainG[T]{
|
||||
execG: execG[T]{g: &g[T]{
|
||||
db: c.g.db,
|
||||
ops: append(append([]op(nil), c.g.ops...), v),
|
||||
}},
|
||||
}
|
||||
}
|
||||
|
||||
func (c chainG[T]) Scopes(scopes ...func(db *Statement)) ChainInterface[T] {
|
||||
return c.with(func(db *DB) *DB {
|
||||
for _, fc := range scopes {
|
||||
fc(db.Statement)
|
||||
}
|
||||
return db
|
||||
})
|
||||
}
|
||||
|
||||
func (c chainG[T]) Table(name string, args ...interface{}) ChainInterface[T] {
|
||||
return c.with(func(db *DB) *DB {
|
||||
return db.Table(name, args...)
|
||||
})
|
||||
}
|
||||
|
||||
func (c chainG[T]) Where(query interface{}, args ...interface{}) ChainInterface[T] {
|
||||
return c.with(func(db *DB) *DB {
|
||||
return db.Where(query, args...)
|
||||
})
|
||||
}
|
||||
|
||||
func (c chainG[T]) Not(query interface{}, args ...interface{}) ChainInterface[T] {
|
||||
return c.with(func(db *DB) *DB {
|
||||
return db.Not(query, args...)
|
||||
})
|
||||
}
|
||||
|
||||
func (c chainG[T]) Or(query interface{}, args ...interface{}) ChainInterface[T] {
|
||||
return c.with(func(db *DB) *DB {
|
||||
return db.Or(query, args...)
|
||||
})
|
||||
}
|
||||
|
||||
func (c chainG[T]) Limit(offset int) ChainInterface[T] {
|
||||
return c.with(func(db *DB) *DB {
|
||||
return db.Limit(offset)
|
||||
})
|
||||
}
|
||||
|
||||
func (c chainG[T]) Offset(offset int) ChainInterface[T] {
|
||||
return c.with(func(db *DB) *DB {
|
||||
return db.Offset(offset)
|
||||
})
|
||||
}
|
||||
|
||||
type joinBuilder struct {
|
||||
db *DB
|
||||
}
|
||||
|
||||
func (q *joinBuilder) Where(query interface{}, args ...interface{}) JoinBuilder {
|
||||
q.db.Where(query, args...)
|
||||
return q
|
||||
}
|
||||
|
||||
func (q *joinBuilder) Or(query interface{}, args ...interface{}) JoinBuilder {
|
||||
q.db.Where(query, args...)
|
||||
return q
|
||||
}
|
||||
|
||||
func (q *joinBuilder) Not(query interface{}, args ...interface{}) JoinBuilder {
|
||||
q.db.Where(query, args...)
|
||||
return q
|
||||
}
|
||||
|
||||
func (q *joinBuilder) Select(columns ...string) JoinBuilder {
|
||||
q.db.Select(columns)
|
||||
return q
|
||||
}
|
||||
|
||||
func (q *joinBuilder) Omit(columns ...string) JoinBuilder {
|
||||
q.db.Omit(columns...)
|
||||
return q
|
||||
}
|
||||
|
||||
type preloadBuilder struct {
|
||||
limitPerRecord int
|
||||
db *DB
|
||||
}
|
||||
|
||||
func (q *preloadBuilder) Where(query interface{}, args ...interface{}) PreloadBuilder {
|
||||
q.db.Where(query, args...)
|
||||
return q
|
||||
}
|
||||
|
||||
func (q *preloadBuilder) Or(query interface{}, args ...interface{}) PreloadBuilder {
|
||||
q.db.Where(query, args...)
|
||||
return q
|
||||
}
|
||||
|
||||
func (q *preloadBuilder) Not(query interface{}, args ...interface{}) PreloadBuilder {
|
||||
q.db.Where(query, args...)
|
||||
return q
|
||||
}
|
||||
|
||||
func (q *preloadBuilder) Select(columns ...string) PreloadBuilder {
|
||||
q.db.Select(columns)
|
||||
return q
|
||||
}
|
||||
|
||||
func (q *preloadBuilder) Omit(columns ...string) PreloadBuilder {
|
||||
q.db.Omit(columns...)
|
||||
return q
|
||||
}
|
||||
|
||||
func (q *preloadBuilder) Limit(limit int) PreloadBuilder {
|
||||
q.db.Limit(limit)
|
||||
return q
|
||||
}
|
||||
|
||||
func (q *preloadBuilder) Offset(offset int) PreloadBuilder {
|
||||
q.db.Offset(offset)
|
||||
return q
|
||||
}
|
||||
|
||||
func (q *preloadBuilder) Order(value interface{}) PreloadBuilder {
|
||||
q.db.Order(value)
|
||||
return q
|
||||
}
|
||||
|
||||
func (q *preloadBuilder) LimitPerRecord(num int) PreloadBuilder {
|
||||
q.limitPerRecord = num
|
||||
return q
|
||||
}
|
||||
|
||||
func (c chainG[T]) Joins(jt clause.JoinTarget, on func(db JoinBuilder, joinTable clause.Table, curTable clause.Table) error) ChainInterface[T] {
|
||||
return c.with(func(db *DB) *DB {
|
||||
if jt.Table == "" {
|
||||
jt.Table = clause.JoinTable(strings.Split(jt.Association, ".")...).Name
|
||||
}
|
||||
|
||||
q := joinBuilder{db: db.Session(&Session{NewDB: true, Initialized: true}).Table(jt.Table)}
|
||||
if on != nil {
|
||||
if err := on(&q, clause.Table{Name: jt.Table}, clause.Table{Name: clause.CurrentTable}); err != nil {
|
||||
db.AddError(err)
|
||||
}
|
||||
}
|
||||
|
||||
j := join{
|
||||
Name: jt.Association,
|
||||
Alias: jt.Table,
|
||||
Selects: q.db.Statement.Selects,
|
||||
Omits: q.db.Statement.Omits,
|
||||
JoinType: jt.Type,
|
||||
}
|
||||
|
||||
if where, ok := q.db.Statement.Clauses["WHERE"].Expression.(clause.Where); ok {
|
||||
j.On = &where
|
||||
}
|
||||
|
||||
if jt.Subquery != nil {
|
||||
joinType := j.JoinType
|
||||
if joinType == "" {
|
||||
joinType = clause.LeftJoin
|
||||
}
|
||||
|
||||
if db, ok := jt.Subquery.(interface{ getInstance() *DB }); ok {
|
||||
stmt := db.getInstance().Statement
|
||||
if len(j.Selects) == 0 {
|
||||
j.Selects = stmt.Selects
|
||||
}
|
||||
if len(j.Omits) == 0 {
|
||||
j.Omits = stmt.Omits
|
||||
}
|
||||
}
|
||||
|
||||
expr := clause.NamedExpr{SQL: fmt.Sprintf("%s JOIN (?) AS ?", joinType), Vars: []interface{}{jt.Subquery, clause.Table{Name: j.Alias}}}
|
||||
|
||||
if j.On != nil {
|
||||
expr.SQL += " ON ?"
|
||||
expr.Vars = append(expr.Vars, clause.AndConditions{Exprs: j.On.Exprs})
|
||||
}
|
||||
|
||||
j.Expression = expr
|
||||
}
|
||||
|
||||
db.Statement.Joins = append(db.Statement.Joins, j)
|
||||
sort.Slice(db.Statement.Joins, func(i, j int) bool {
|
||||
return db.Statement.Joins[i].Name < db.Statement.Joins[j].Name
|
||||
})
|
||||
return db
|
||||
})
|
||||
}
|
||||
|
||||
func (c chainG[T]) Select(query string, args ...interface{}) ChainInterface[T] {
|
||||
return c.with(func(db *DB) *DB {
|
||||
return db.Select(query, args...)
|
||||
})
|
||||
}
|
||||
|
||||
func (c chainG[T]) Omit(columns ...string) ChainInterface[T] {
|
||||
return c.with(func(db *DB) *DB {
|
||||
return db.Omit(columns...)
|
||||
})
|
||||
}
|
||||
|
||||
func (c chainG[T]) MapColumns(m map[string]string) ChainInterface[T] {
|
||||
return c.with(func(db *DB) *DB {
|
||||
return db.MapColumns(m)
|
||||
})
|
||||
}
|
||||
|
||||
func (c chainG[T]) Distinct(args ...interface{}) ChainInterface[T] {
|
||||
return c.with(func(db *DB) *DB {
|
||||
return db.Distinct(args...)
|
||||
})
|
||||
}
|
||||
|
||||
func (c chainG[T]) Group(name string) ChainInterface[T] {
|
||||
return c.with(func(db *DB) *DB {
|
||||
return db.Group(name)
|
||||
})
|
||||
}
|
||||
|
||||
func (c chainG[T]) Having(query interface{}, args ...interface{}) ChainInterface[T] {
|
||||
return c.with(func(db *DB) *DB {
|
||||
return db.Having(query, args...)
|
||||
})
|
||||
}
|
||||
|
||||
func (c chainG[T]) Order(value interface{}) ChainInterface[T] {
|
||||
return c.with(func(db *DB) *DB {
|
||||
return db.Order(value)
|
||||
})
|
||||
}
|
||||
|
||||
func (c chainG[T]) Preload(association string, query func(db PreloadBuilder) error) ChainInterface[T] {
|
||||
return c.with(func(db *DB) *DB {
|
||||
return db.Preload(association, func(tx *DB) *DB {
|
||||
q := preloadBuilder{db: tx.getInstance()}
|
||||
if query != nil {
|
||||
if err := query(&q); err != nil {
|
||||
db.AddError(err)
|
||||
}
|
||||
}
|
||||
|
||||
relation, ok := db.Statement.Schema.Relationships.Relations[association]
|
||||
if !ok {
|
||||
if preloadFields := strings.Split(association, "."); len(preloadFields) > 1 {
|
||||
relationships := db.Statement.Schema.Relationships
|
||||
for _, field := range preloadFields {
|
||||
var ok bool
|
||||
relation, ok = relationships.Relations[field]
|
||||
if ok {
|
||||
relationships = relation.FieldSchema.Relationships
|
||||
} else {
|
||||
db.AddError(fmt.Errorf("relation %s not found", association))
|
||||
return nil
|
||||
}
|
||||
}
|
||||
} else {
|
||||
db.AddError(fmt.Errorf("relation %s not found", association))
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
if q.limitPerRecord > 0 {
|
||||
if relation.JoinTable != nil {
|
||||
tx.AddError(fmt.Errorf("many2many relation %s don't support LimitPerRecord", association))
|
||||
return tx
|
||||
}
|
||||
|
||||
refColumns := []clause.Column{}
|
||||
for _, rel := range relation.References {
|
||||
if rel.OwnPrimaryKey {
|
||||
refColumns = append(refColumns, clause.Column{Name: rel.ForeignKey.DBName})
|
||||
}
|
||||
}
|
||||
|
||||
if len(refColumns) != 0 {
|
||||
selectExpr := clause.CommaExpression{}
|
||||
for _, column := range q.db.Statement.Selects {
|
||||
selectExpr.Exprs = append(selectExpr.Exprs, clause.Expr{SQL: "?", Vars: []interface{}{clause.Column{Name: column}}})
|
||||
}
|
||||
|
||||
if len(selectExpr.Exprs) == 0 {
|
||||
selectExpr.Exprs = []clause.Expression{clause.Expr{SQL: "*", Vars: []interface{}{}}}
|
||||
}
|
||||
|
||||
partitionBy := clause.CommaExpression{}
|
||||
for _, column := range refColumns {
|
||||
partitionBy.Exprs = append(partitionBy.Exprs, clause.Expr{SQL: "?", Vars: []interface{}{clause.Column{Name: column.Name}}})
|
||||
}
|
||||
|
||||
rnnColumn := clause.Column{Name: "gorm_preload_rnn"}
|
||||
sql := "ROW_NUMBER() OVER (PARTITION BY ? ?)"
|
||||
vars := []interface{}{partitionBy}
|
||||
if orderBy, ok := q.db.Statement.Clauses["ORDER BY"]; ok {
|
||||
vars = append(vars, orderBy)
|
||||
} else {
|
||||
vars = append(vars, clause.Clause{Name: "ORDER BY", Expression: clause.OrderBy{
|
||||
Columns: []clause.OrderByColumn{{Column: clause.PrimaryColumn, Desc: true}},
|
||||
}})
|
||||
}
|
||||
vars = append(vars, rnnColumn)
|
||||
|
||||
selectExpr.Exprs = append(selectExpr.Exprs, clause.Expr{SQL: sql + " AS ?", Vars: vars})
|
||||
|
||||
q.db.Clauses(clause.Select{Expression: selectExpr})
|
||||
|
||||
return q.db.Session(&Session{NewDB: true}).Unscoped().Table("(?) t", q.db).Where("? <= ?", rnnColumn, q.limitPerRecord)
|
||||
}
|
||||
}
|
||||
|
||||
return q.db
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func (c chainG[T]) Delete(ctx context.Context) (rowsAffected int, err error) {
|
||||
r := new(T)
|
||||
res := c.g.apply(ctx).Delete(r)
|
||||
return int(res.RowsAffected), res.Error
|
||||
}
|
||||
|
||||
func (c chainG[T]) Update(ctx context.Context, name string, value any) (rowsAffected int, err error) {
|
||||
var r T
|
||||
res := c.g.apply(ctx).Model(r).Update(name, value)
|
||||
return int(res.RowsAffected), res.Error
|
||||
}
|
||||
|
||||
func (c chainG[T]) Updates(ctx context.Context, t T) (rowsAffected int, err error) {
|
||||
res := c.g.apply(ctx).Updates(t)
|
||||
return int(res.RowsAffected), res.Error
|
||||
}
|
||||
|
||||
func (c chainG[T]) Count(ctx context.Context, column string) (result int64, err error) {
|
||||
var r T
|
||||
err = c.g.apply(ctx).Model(r).Select(column).Count(&result).Error
|
||||
return
|
||||
}
|
||||
|
||||
func (c chainG[T]) Build(builder clause.Builder) {
|
||||
subdb := c.getInstance()
|
||||
subdb.Logger = logger.Discard
|
||||
subdb.DryRun = true
|
||||
|
||||
if stmt, ok := builder.(*Statement); ok {
|
||||
if subdb.Statement.SQL.Len() > 0 {
|
||||
var (
|
||||
vars = subdb.Statement.Vars
|
||||
sql = subdb.Statement.SQL.String()
|
||||
)
|
||||
|
||||
subdb.Statement.Vars = make([]interface{}, 0, len(vars))
|
||||
for _, vv := range vars {
|
||||
subdb.Statement.Vars = append(subdb.Statement.Vars, vv)
|
||||
bindvar := strings.Builder{}
|
||||
subdb.BindVarTo(&bindvar, subdb.Statement, vv)
|
||||
sql = strings.Replace(sql, bindvar.String(), "?", 1)
|
||||
}
|
||||
|
||||
subdb.Statement.SQL.Reset()
|
||||
subdb.Statement.Vars = stmt.Vars
|
||||
if strings.Contains(sql, "@") {
|
||||
clause.NamedExpr{SQL: sql, Vars: vars}.Build(subdb.Statement)
|
||||
} else {
|
||||
clause.Expr{SQL: sql, Vars: vars}.Build(subdb.Statement)
|
||||
}
|
||||
} else {
|
||||
subdb.Statement.Vars = append(stmt.Vars, subdb.Statement.Vars...)
|
||||
subdb.callbacks.Query().Execute(subdb)
|
||||
}
|
||||
|
||||
builder.WriteString(subdb.Statement.SQL.String())
|
||||
stmt.Vars = subdb.Statement.Vars
|
||||
}
|
||||
}
|
||||
|
||||
type execG[T any] struct {
|
||||
g *g[T]
|
||||
}
|
||||
|
||||
func (g execG[T]) First(ctx context.Context) (T, error) {
|
||||
var r T
|
||||
err := g.g.apply(ctx).First(&r).Error
|
||||
return r, err
|
||||
}
|
||||
|
||||
func (g execG[T]) Scan(ctx context.Context, result interface{}) error {
|
||||
var r T
|
||||
err := g.g.apply(ctx).Model(r).Find(result).Error
|
||||
return err
|
||||
}
|
||||
|
||||
func (g execG[T]) Last(ctx context.Context) (T, error) {
|
||||
var r T
|
||||
err := g.g.apply(ctx).Last(&r).Error
|
||||
return r, err
|
||||
}
|
||||
|
||||
func (g execG[T]) Take(ctx context.Context) (T, error) {
|
||||
var r T
|
||||
err := g.g.apply(ctx).Take(&r).Error
|
||||
return r, err
|
||||
}
|
||||
|
||||
func (g execG[T]) Find(ctx context.Context) ([]T, error) {
|
||||
var r []T
|
||||
err := g.g.apply(ctx).Find(&r).Error
|
||||
return r, err
|
||||
}
|
||||
|
||||
func (g execG[T]) FindInBatches(ctx context.Context, batchSize int, fc func(data []T, batch int) error) error {
|
||||
var data []T
|
||||
return g.g.apply(ctx).FindInBatches(&data, batchSize, func(tx *DB, batch int) error {
|
||||
return fc(data, batch)
|
||||
}).Error
|
||||
}
|
||||
|
||||
func (g execG[T]) Row(ctx context.Context) *sql.Row {
|
||||
return g.g.apply(ctx).Row()
|
||||
}
|
||||
|
||||
func (g execG[T]) Rows(ctx context.Context) (*sql.Rows, error) {
|
||||
return g.g.apply(ctx).Rows()
|
||||
}
|
||||
5
go.mod
5
go.mod
@ -1,9 +1,8 @@
|
||||
module gorm.io/gorm
|
||||
|
||||
go 1.18
|
||||
go 1.16
|
||||
|
||||
require (
|
||||
github.com/jinzhu/inflection v1.0.0
|
||||
github.com/jinzhu/now v1.1.5
|
||||
golang.org/x/text v0.20.0
|
||||
github.com/jinzhu/now v1.1.4
|
||||
)
|
||||
|
||||
6
go.sum
6
go.sum
@ -1,6 +1,4 @@
|
||||
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
||||
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
||||
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
|
||||
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
||||
golang.org/x/text v0.20.0 h1:gK/Kv2otX8gz+wn7Rmb3vT96ZwuoxnQlY+HlJVj7Qug=
|
||||
golang.org/x/text v0.20.0/go.mod h1:D4IsuqiFMhST5bX19pQ9ikHC2GsaKyk/oF+pn3ducp4=
|
||||
github.com/jinzhu/now v1.1.4 h1:tHnRBy1i5F2Dh8BAFxqFzxKqqvezXrL2OW1TnX+Mlas=
|
||||
github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
||||
|
||||
136
gorm.go
136
gorm.go
@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
@ -21,9 +20,7 @@ const preparedStmtDBKey = "preparedStmt"
|
||||
type Config struct {
|
||||
// GORM perform single create, update, delete operations in transactions by default to ensure database data integrity
|
||||
// You can disable it by setting `SkipDefaultTransaction` to true
|
||||
SkipDefaultTransaction bool
|
||||
DefaultTransactionTimeout time.Duration
|
||||
|
||||
SkipDefaultTransaction bool
|
||||
// NamingStrategy tables, columns naming strategy
|
||||
NamingStrategy schema.Namer
|
||||
// FullSaveAssociations full save associations
|
||||
@ -36,17 +33,10 @@ type Config struct {
|
||||
DryRun bool
|
||||
// PrepareStmt executes the given query in cached statement
|
||||
PrepareStmt bool
|
||||
// PrepareStmt cache support LRU expired,
|
||||
// default maxsize=int64 Max value and ttl=1h
|
||||
PrepareStmtMaxSize int
|
||||
PrepareStmtTTL time.Duration
|
||||
|
||||
// DisableAutomaticPing
|
||||
DisableAutomaticPing bool
|
||||
// DisableForeignKeyConstraintWhenMigrating
|
||||
DisableForeignKeyConstraintWhenMigrating bool
|
||||
// IgnoreRelationshipsWhenMigrating
|
||||
IgnoreRelationshipsWhenMigrating bool
|
||||
// DisableNestedTransaction disable nested transaction
|
||||
DisableNestedTransaction bool
|
||||
// AllowGlobalUpdate allow global update
|
||||
@ -55,10 +45,6 @@ type Config struct {
|
||||
QueryFields bool
|
||||
// CreateBatchSize default create batch size
|
||||
CreateBatchSize int
|
||||
// TranslateError enabling error translation
|
||||
TranslateError bool
|
||||
// PropagateUnscoped propagate Unscoped to every other nested statement
|
||||
PropagateUnscoped bool
|
||||
|
||||
// ClauseBuilders clause builder
|
||||
ClauseBuilders map[string]clause.ClauseBuilder
|
||||
@ -119,7 +105,6 @@ type Session struct {
|
||||
DisableNestedTransaction bool
|
||||
AllowGlobalUpdate bool
|
||||
FullSaveAssociations bool
|
||||
PropagateUnscoped bool
|
||||
QueryFields bool
|
||||
Context context.Context
|
||||
Logger logger.Interface
|
||||
@ -137,24 +122,12 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) {
|
||||
return isConfig && !isConfig2
|
||||
})
|
||||
|
||||
if len(opts) > 0 {
|
||||
if c, ok := opts[0].(*Config); ok {
|
||||
config = c
|
||||
} else {
|
||||
opts = append([]Option{config}, opts...)
|
||||
}
|
||||
}
|
||||
|
||||
var skipAfterInitialize bool
|
||||
for _, opt := range opts {
|
||||
if opt != nil {
|
||||
if applyErr := opt.Apply(config); applyErr != nil {
|
||||
return nil, applyErr
|
||||
}
|
||||
defer func(opt Option) {
|
||||
if skipAfterInitialize {
|
||||
return
|
||||
}
|
||||
if errr := opt.AfterInitialize(db); errr != nil {
|
||||
err = errr
|
||||
}
|
||||
@ -169,7 +142,7 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) {
|
||||
}
|
||||
|
||||
if config.NamingStrategy == nil {
|
||||
config.NamingStrategy = schema.NamingStrategy{IdentifierMaxLength: 64} // Default Identifier length is 64
|
||||
config.NamingStrategy = schema.NamingStrategy{}
|
||||
}
|
||||
|
||||
if config.Logger == nil {
|
||||
@ -202,26 +175,17 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) {
|
||||
|
||||
if config.Dialector != nil {
|
||||
err = config.Dialector.Initialize(db)
|
||||
if err != nil {
|
||||
if db, _ := db.DB(); db != nil {
|
||||
_ = db.Close()
|
||||
}
|
||||
|
||||
// DB is not initialized, so we skip AfterInitialize
|
||||
skipAfterInitialize = true
|
||||
return
|
||||
}
|
||||
|
||||
if config.TranslateError {
|
||||
if _, ok := db.Dialector.(ErrorTranslator); !ok {
|
||||
config.Logger.Warn(context.Background(), "The TranslateError option is enabled, but the Dialector %s does not implement ErrorTranslator.", db.Dialector.Name())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
preparedStmt := &PreparedStmtDB{
|
||||
ConnPool: db.ConnPool,
|
||||
Stmts: map[string]Stmt{},
|
||||
Mux: &sync.RWMutex{},
|
||||
PreparedSQL: make([]string, 0, 100),
|
||||
}
|
||||
db.cacheStore.Store(preparedStmtDBKey, preparedStmt)
|
||||
|
||||
if config.PrepareStmt {
|
||||
preparedStmt := NewPreparedStmtDB(db.ConnPool, config.PrepareStmtMaxSize, config.PrepareStmtTTL)
|
||||
db.cacheStore.Store(preparedStmtDBKey, preparedStmt)
|
||||
db.ConnPool = preparedStmt
|
||||
}
|
||||
|
||||
@ -272,10 +236,6 @@ func (db *DB) Session(config *Session) *DB {
|
||||
txConfig.FullSaveAssociations = true
|
||||
}
|
||||
|
||||
if config.PropagateUnscoped {
|
||||
txConfig.PropagateUnscoped = true
|
||||
}
|
||||
|
||||
if config.Context != nil || config.PrepareStmt || config.SkipHooks {
|
||||
tx.Statement = tx.Statement.clone()
|
||||
tx.Statement.DB = tx
|
||||
@ -286,30 +246,16 @@ func (db *DB) Session(config *Session) *DB {
|
||||
}
|
||||
|
||||
if config.PrepareStmt {
|
||||
var preparedStmt *PreparedStmtDB
|
||||
|
||||
if v, ok := db.cacheStore.Load(preparedStmtDBKey); ok {
|
||||
preparedStmt = v.(*PreparedStmtDB)
|
||||
} else {
|
||||
preparedStmt = NewPreparedStmtDB(db.ConnPool, db.PrepareStmtMaxSize, db.PrepareStmtTTL)
|
||||
db.cacheStore.Store(preparedStmtDBKey, preparedStmt)
|
||||
}
|
||||
|
||||
switch t := tx.Statement.ConnPool.(type) {
|
||||
case Tx:
|
||||
tx.Statement.ConnPool = &PreparedStmtTX{
|
||||
Tx: t,
|
||||
PreparedStmtDB: preparedStmt,
|
||||
}
|
||||
default:
|
||||
preparedStmt := v.(*PreparedStmtDB)
|
||||
tx.Statement.ConnPool = &PreparedStmtDB{
|
||||
ConnPool: db.Config.ConnPool,
|
||||
Mux: preparedStmt.Mux,
|
||||
Stmts: preparedStmt.Stmts,
|
||||
}
|
||||
txConfig.ConnPool = tx.Statement.ConnPool
|
||||
txConfig.PrepareStmt = true
|
||||
}
|
||||
txConfig.ConnPool = tx.Statement.ConnPool
|
||||
txConfig.PrepareStmt = true
|
||||
}
|
||||
|
||||
if config.SkipHooks {
|
||||
@ -391,18 +337,10 @@ func (db *DB) Callback() *callbacks {
|
||||
|
||||
// AddError add error to db
|
||||
func (db *DB) AddError(err error) error {
|
||||
if err != nil {
|
||||
if db.Config.TranslateError {
|
||||
if errTranslator, ok := db.Dialector.(ErrorTranslator); ok {
|
||||
err = errTranslator.Translate(err)
|
||||
}
|
||||
}
|
||||
|
||||
if db.Error == nil {
|
||||
db.Error = err
|
||||
} else {
|
||||
db.Error = fmt.Errorf("%v; %w", db.Error, err)
|
||||
}
|
||||
if db.Error == nil {
|
||||
db.Error = err
|
||||
} else if err != nil {
|
||||
db.Error = fmt.Errorf("%v; %w", db.Error, err)
|
||||
}
|
||||
return db.Error
|
||||
}
|
||||
@ -410,20 +348,12 @@ func (db *DB) AddError(err error) error {
|
||||
// DB returns `*sql.DB`
|
||||
func (db *DB) DB() (*sql.DB, error) {
|
||||
connPool := db.ConnPool
|
||||
if db.Statement != nil && db.Statement.ConnPool != nil {
|
||||
connPool = db.Statement.ConnPool
|
||||
}
|
||||
if tx, ok := connPool.(*sql.Tx); ok && tx != nil {
|
||||
return (*sql.DB)(reflect.ValueOf(tx).Elem().FieldByName("db").UnsafePointer()), nil
|
||||
}
|
||||
|
||||
if dbConnector, ok := connPool.(GetDBConnector); ok && dbConnector != nil {
|
||||
if sqldb, err := dbConnector.GetDBConn(); sqldb != nil || err != nil {
|
||||
return sqldb, err
|
||||
}
|
||||
return dbConnector.GetDBConn()
|
||||
}
|
||||
|
||||
if sqldb, ok := connPool.(*sql.DB); ok && sqldb != nil {
|
||||
if sqldb, ok := connPool.(*sql.DB); ok {
|
||||
return sqldb, nil
|
||||
}
|
||||
|
||||
@ -437,15 +367,11 @@ func (db *DB) getInstance() *DB {
|
||||
if db.clone == 1 {
|
||||
// clone with new statement
|
||||
tx.Statement = &Statement{
|
||||
DB: tx,
|
||||
ConnPool: db.Statement.ConnPool,
|
||||
Context: db.Statement.Context,
|
||||
Clauses: map[string]clause.Clause{},
|
||||
Vars: make([]interface{}, 0, 8),
|
||||
SkipHooks: db.Statement.SkipHooks,
|
||||
}
|
||||
if db.Config.PropagateUnscoped {
|
||||
tx.Statement.Unscoped = db.Statement.Unscoped
|
||||
DB: tx,
|
||||
ConnPool: db.Statement.ConnPool,
|
||||
Context: db.Statement.Context,
|
||||
Clauses: map[string]clause.Clause{},
|
||||
Vars: make([]interface{}, 0, 8),
|
||||
}
|
||||
} else {
|
||||
// with clone statement
|
||||
@ -530,14 +456,14 @@ func (db *DB) Use(plugin Plugin) error {
|
||||
|
||||
// ToSQL for generate SQL string.
|
||||
//
|
||||
// db.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
||||
// return tx.Model(&User{}).Where(&User{Name: "foo", Age: 20})
|
||||
// .Limit(10).Offset(5)
|
||||
// .Order("name ASC")
|
||||
// .First(&User{})
|
||||
// })
|
||||
// db.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
||||
// return tx.Model(&User{}).Where(&User{Name: "foo", Age: 20})
|
||||
// .Limit(10).Offset(5)
|
||||
// .Order("name ASC")
|
||||
// .First(&User{})
|
||||
// })
|
||||
func (db *DB) ToSQL(queryFn func(tx *DB) *DB) string {
|
||||
tx := queryFn(db.Session(&Session{DryRun: true, SkipDefaultTransaction: true}).getInstance())
|
||||
tx := queryFn(db.Session(&Session{DryRun: true, SkipDefaultTransaction: true}))
|
||||
stmt := tx.Statement
|
||||
|
||||
return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...)
|
||||
|
||||
@ -26,10 +26,6 @@ type Plugin interface {
|
||||
Initialize(*DB) error
|
||||
}
|
||||
|
||||
type ParamsFilter interface {
|
||||
ParamsFilter(ctx context.Context, sql string, params ...interface{}) (string, []interface{})
|
||||
}
|
||||
|
||||
// ConnPool db conns pool interface
|
||||
type ConnPool interface {
|
||||
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
|
||||
@ -86,7 +82,3 @@ type Rows interface {
|
||||
Err() error
|
||||
Close() error
|
||||
}
|
||||
|
||||
type ErrorTranslator interface {
|
||||
Translate(err error) error
|
||||
}
|
||||
|
||||
@ -1,493 +0,0 @@
|
||||
package lru
|
||||
|
||||
// golang -lru
|
||||
// https://github.com/hashicorp/golang-lru
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// EvictCallback is used to get a callback when a cache entry is evicted
|
||||
type EvictCallback[K comparable, V any] func(key K, value V)
|
||||
|
||||
// LRU implements a thread-safe LRU with expirable entries.
|
||||
type LRU[K comparable, V any] struct {
|
||||
size int
|
||||
evictList *LruList[K, V]
|
||||
items map[K]*Entry[K, V]
|
||||
onEvict EvictCallback[K, V]
|
||||
|
||||
// expirable options
|
||||
mu sync.Mutex
|
||||
ttl time.Duration
|
||||
done chan struct{}
|
||||
|
||||
// buckets for expiration
|
||||
buckets []bucket[K, V]
|
||||
// uint8 because it's number between 0 and numBuckets
|
||||
nextCleanupBucket uint8
|
||||
}
|
||||
|
||||
// bucket is a container for holding entries to be expired
|
||||
type bucket[K comparable, V any] struct {
|
||||
entries map[K]*Entry[K, V]
|
||||
newestEntry time.Time
|
||||
}
|
||||
|
||||
// noEvictionTTL - very long ttl to prevent eviction
|
||||
const noEvictionTTL = time.Hour * 24 * 365 * 10
|
||||
|
||||
// because of uint8 usage for nextCleanupBucket, should not exceed 256.
|
||||
// casting it as uint8 explicitly requires type conversions in multiple places
|
||||
const numBuckets = 100
|
||||
|
||||
// NewLRU returns a new thread-safe cache with expirable entries.
|
||||
//
|
||||
// Size parameter set to 0 makes cache of unlimited size, e.g. turns LRU mechanism off.
|
||||
//
|
||||
// Providing 0 TTL turns expiring off.
|
||||
//
|
||||
// Delete expired entries every 1/100th of ttl value. Goroutine which deletes expired entries runs indefinitely.
|
||||
func NewLRU[K comparable, V any](size int, onEvict EvictCallback[K, V], ttl time.Duration) *LRU[K, V] {
|
||||
if size < 0 {
|
||||
size = 0
|
||||
}
|
||||
if ttl <= 0 {
|
||||
ttl = noEvictionTTL
|
||||
}
|
||||
|
||||
res := LRU[K, V]{
|
||||
ttl: ttl,
|
||||
size: size,
|
||||
evictList: NewList[K, V](),
|
||||
items: make(map[K]*Entry[K, V]),
|
||||
onEvict: onEvict,
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
|
||||
// initialize the buckets
|
||||
res.buckets = make([]bucket[K, V], numBuckets)
|
||||
for i := 0; i < numBuckets; i++ {
|
||||
res.buckets[i] = bucket[K, V]{entries: make(map[K]*Entry[K, V])}
|
||||
}
|
||||
|
||||
// enable deleteExpired() running in separate goroutine for cache with non-zero TTL
|
||||
//
|
||||
// Important: done channel is never closed, so deleteExpired() goroutine will never exit,
|
||||
// it's decided to add functionality to close it in the version later than v2.
|
||||
if res.ttl != noEvictionTTL {
|
||||
go func(done <-chan struct{}) {
|
||||
ticker := time.NewTicker(res.ttl / numBuckets)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-done:
|
||||
return
|
||||
case <-ticker.C:
|
||||
res.deleteExpired()
|
||||
}
|
||||
}
|
||||
}(res.done)
|
||||
}
|
||||
return &res
|
||||
}
|
||||
|
||||
// Purge clears the cache completely.
|
||||
// onEvict is called for each evicted key.
|
||||
func (c *LRU[K, V]) Purge() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
for k, v := range c.items {
|
||||
if c.onEvict != nil {
|
||||
c.onEvict(k, v.Value)
|
||||
}
|
||||
delete(c.items, k)
|
||||
}
|
||||
for _, b := range c.buckets {
|
||||
for _, ent := range b.entries {
|
||||
delete(b.entries, ent.Key)
|
||||
}
|
||||
}
|
||||
c.evictList.Init()
|
||||
}
|
||||
|
||||
// Add adds a value to the cache. Returns true if an eviction occurred.
|
||||
// Returns false if there was no eviction: the item was already in the cache,
|
||||
// or the size was not exceeded.
|
||||
func (c *LRU[K, V]) Add(key K, value V) (evicted bool) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
now := time.Now()
|
||||
|
||||
// Check for existing item
|
||||
if ent, ok := c.items[key]; ok {
|
||||
c.evictList.MoveToFront(ent)
|
||||
c.removeFromBucket(ent) // remove the entry from its current bucket as expiresAt is renewed
|
||||
ent.Value = value
|
||||
ent.ExpiresAt = now.Add(c.ttl)
|
||||
c.addToBucket(ent)
|
||||
return false
|
||||
}
|
||||
|
||||
// Add new item
|
||||
ent := c.evictList.PushFrontExpirable(key, value, now.Add(c.ttl))
|
||||
c.items[key] = ent
|
||||
c.addToBucket(ent) // adds the entry to the appropriate bucket and sets entry.expireBucket
|
||||
|
||||
evict := c.size > 0 && c.evictList.Length() > c.size
|
||||
// Verify size not exceeded
|
||||
if evict {
|
||||
c.removeOldest()
|
||||
}
|
||||
return evict
|
||||
}
|
||||
|
||||
// Get looks up a key's value from the cache.
|
||||
func (c *LRU[K, V]) Get(key K) (value V, ok bool) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
var ent *Entry[K, V]
|
||||
if ent, ok = c.items[key]; ok {
|
||||
// Expired item check
|
||||
if time.Now().After(ent.ExpiresAt) {
|
||||
return value, false
|
||||
}
|
||||
c.evictList.MoveToFront(ent)
|
||||
return ent.Value, true
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Contains checks if a key is in the cache, without updating the recent-ness
|
||||
// or deleting it for being stale.
|
||||
func (c *LRU[K, V]) Contains(key K) (ok bool) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
_, ok = c.items[key]
|
||||
return ok
|
||||
}
|
||||
|
||||
// Peek returns the key value (or undefined if not found) without updating
|
||||
// the "recently used"-ness of the key.
|
||||
func (c *LRU[K, V]) Peek(key K) (value V, ok bool) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
var ent *Entry[K, V]
|
||||
if ent, ok = c.items[key]; ok {
|
||||
// Expired item check
|
||||
if time.Now().After(ent.ExpiresAt) {
|
||||
return value, false
|
||||
}
|
||||
return ent.Value, true
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Remove removes the provided key from the cache, returning if the
|
||||
// key was contained.
|
||||
func (c *LRU[K, V]) Remove(key K) bool {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if ent, ok := c.items[key]; ok {
|
||||
c.removeElement(ent)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// RemoveOldest removes the oldest item from the cache.
|
||||
func (c *LRU[K, V]) RemoveOldest() (key K, value V, ok bool) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if ent := c.evictList.Back(); ent != nil {
|
||||
c.removeElement(ent)
|
||||
return ent.Key, ent.Value, true
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// GetOldest returns the oldest entry
|
||||
func (c *LRU[K, V]) GetOldest() (key K, value V, ok bool) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if ent := c.evictList.Back(); ent != nil {
|
||||
return ent.Key, ent.Value, true
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *LRU[K, V]) KeyValues() map[K]V {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
maps := make(map[K]V)
|
||||
now := time.Now()
|
||||
for ent := c.evictList.Back(); ent != nil; ent = ent.PrevEntry() {
|
||||
if now.After(ent.ExpiresAt) {
|
||||
continue
|
||||
}
|
||||
maps[ent.Key] = ent.Value
|
||||
// keys = append(keys, ent.Key)
|
||||
}
|
||||
return maps
|
||||
}
|
||||
|
||||
// Keys returns a slice of the keys in the cache, from oldest to newest.
|
||||
// Expired entries are filtered out.
|
||||
func (c *LRU[K, V]) Keys() []K {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
keys := make([]K, 0, len(c.items))
|
||||
now := time.Now()
|
||||
for ent := c.evictList.Back(); ent != nil; ent = ent.PrevEntry() {
|
||||
if now.After(ent.ExpiresAt) {
|
||||
continue
|
||||
}
|
||||
keys = append(keys, ent.Key)
|
||||
}
|
||||
return keys
|
||||
}
|
||||
|
||||
// Values returns a slice of the values in the cache, from oldest to newest.
|
||||
// Expired entries are filtered out.
|
||||
func (c *LRU[K, V]) Values() []V {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
values := make([]V, 0, len(c.items))
|
||||
now := time.Now()
|
||||
for ent := c.evictList.Back(); ent != nil; ent = ent.PrevEntry() {
|
||||
if now.After(ent.ExpiresAt) {
|
||||
continue
|
||||
}
|
||||
values = append(values, ent.Value)
|
||||
}
|
||||
return values
|
||||
}
|
||||
|
||||
// Len returns the number of items in the cache.
|
||||
func (c *LRU[K, V]) Len() int {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return c.evictList.Length()
|
||||
}
|
||||
|
||||
// Resize changes the cache size. Size of 0 means unlimited.
|
||||
func (c *LRU[K, V]) Resize(size int) (evicted int) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if size <= 0 {
|
||||
c.size = 0
|
||||
return 0
|
||||
}
|
||||
diff := c.evictList.Length() - size
|
||||
if diff < 0 {
|
||||
diff = 0
|
||||
}
|
||||
for i := 0; i < diff; i++ {
|
||||
c.removeOldest()
|
||||
}
|
||||
c.size = size
|
||||
return diff
|
||||
}
|
||||
|
||||
// Close destroys cleanup goroutine. To clean up the cache, run Purge() before Close().
|
||||
// func (c *LRU[K, V]) Close() {
|
||||
// c.mu.Lock()
|
||||
// defer c.mu.Unlock()
|
||||
// select {
|
||||
// case <-c.done:
|
||||
// return
|
||||
// default:
|
||||
// }
|
||||
// close(c.done)
|
||||
// }
|
||||
|
||||
// removeOldest removes the oldest item from the cache. Has to be called with lock!
|
||||
func (c *LRU[K, V]) removeOldest() {
|
||||
if ent := c.evictList.Back(); ent != nil {
|
||||
c.removeElement(ent)
|
||||
}
|
||||
}
|
||||
|
||||
// removeElement is used to remove a given list element from the cache. Has to be called with lock!
|
||||
func (c *LRU[K, V]) removeElement(e *Entry[K, V]) {
|
||||
c.evictList.Remove(e)
|
||||
delete(c.items, e.Key)
|
||||
c.removeFromBucket(e)
|
||||
if c.onEvict != nil {
|
||||
c.onEvict(e.Key, e.Value)
|
||||
}
|
||||
}
|
||||
|
||||
// deleteExpired deletes expired records from the oldest bucket, waiting for the newest entry
|
||||
// in it to expire first.
|
||||
func (c *LRU[K, V]) deleteExpired() {
|
||||
c.mu.Lock()
|
||||
bucketIdx := c.nextCleanupBucket
|
||||
timeToExpire := time.Until(c.buckets[bucketIdx].newestEntry)
|
||||
// wait for newest entry to expire before cleanup without holding lock
|
||||
if timeToExpire > 0 {
|
||||
c.mu.Unlock()
|
||||
time.Sleep(timeToExpire)
|
||||
c.mu.Lock()
|
||||
}
|
||||
for _, ent := range c.buckets[bucketIdx].entries {
|
||||
c.removeElement(ent)
|
||||
}
|
||||
c.nextCleanupBucket = (c.nextCleanupBucket + 1) % numBuckets
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
// addToBucket adds entry to expire bucket so that it will be cleaned up when the time comes. Has to be called with lock!
|
||||
func (c *LRU[K, V]) addToBucket(e *Entry[K, V]) {
|
||||
bucketID := (numBuckets + c.nextCleanupBucket - 1) % numBuckets
|
||||
e.ExpireBucket = bucketID
|
||||
c.buckets[bucketID].entries[e.Key] = e
|
||||
if c.buckets[bucketID].newestEntry.Before(e.ExpiresAt) {
|
||||
c.buckets[bucketID].newestEntry = e.ExpiresAt
|
||||
}
|
||||
}
|
||||
|
||||
// removeFromBucket removes the entry from its corresponding bucket. Has to be called with lock!
|
||||
func (c *LRU[K, V]) removeFromBucket(e *Entry[K, V]) {
|
||||
delete(c.buckets[e.ExpireBucket].entries, e.Key)
|
||||
}
|
||||
|
||||
// Cap returns the capacity of the cache
|
||||
func (c *LRU[K, V]) Cap() int {
|
||||
return c.size
|
||||
}
|
||||
|
||||
// Entry is an LRU Entry
|
||||
type Entry[K comparable, V any] struct {
|
||||
// Next and previous pointers in the doubly-linked list of elements.
|
||||
// To simplify the implementation, internally a list l is implemented
|
||||
// as a ring, such that &l.root is both the next element of the last
|
||||
// list element (l.Back()) and the previous element of the first list
|
||||
// element (l.Front()).
|
||||
next, prev *Entry[K, V]
|
||||
|
||||
// The list to which this element belongs.
|
||||
list *LruList[K, V]
|
||||
|
||||
// The LRU Key of this element.
|
||||
Key K
|
||||
|
||||
// The Value stored with this element.
|
||||
Value V
|
||||
|
||||
// The time this element would be cleaned up, optional
|
||||
ExpiresAt time.Time
|
||||
|
||||
// The expiry bucket item was put in, optional
|
||||
ExpireBucket uint8
|
||||
}
|
||||
|
||||
// PrevEntry returns the previous list element or nil.
|
||||
func (e *Entry[K, V]) PrevEntry() *Entry[K, V] {
|
||||
if p := e.prev; e.list != nil && p != &e.list.root {
|
||||
return p
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// LruList represents a doubly linked list.
|
||||
// The zero Value for LruList is an empty list ready to use.
|
||||
type LruList[K comparable, V any] struct {
|
||||
root Entry[K, V] // sentinel list element, only &root, root.prev, and root.next are used
|
||||
len int // current list Length excluding (this) sentinel element
|
||||
}
|
||||
|
||||
// Init initializes or clears list l.
|
||||
func (l *LruList[K, V]) Init() *LruList[K, V] {
|
||||
l.root.next = &l.root
|
||||
l.root.prev = &l.root
|
||||
l.len = 0
|
||||
return l
|
||||
}
|
||||
|
||||
// NewList returns an initialized list.
|
||||
func NewList[K comparable, V any]() *LruList[K, V] { return new(LruList[K, V]).Init() }
|
||||
|
||||
// Length returns the number of elements of list l.
|
||||
// The complexity is O(1).
|
||||
func (l *LruList[K, V]) Length() int { return l.len }
|
||||
|
||||
// Back returns the last element of list l or nil if the list is empty.
|
||||
func (l *LruList[K, V]) Back() *Entry[K, V] {
|
||||
if l.len == 0 {
|
||||
return nil
|
||||
}
|
||||
return l.root.prev
|
||||
}
|
||||
|
||||
// lazyInit lazily initializes a zero List Value.
|
||||
func (l *LruList[K, V]) lazyInit() {
|
||||
if l.root.next == nil {
|
||||
l.Init()
|
||||
}
|
||||
}
|
||||
|
||||
// insert inserts e after at, increments l.len, and returns e.
|
||||
func (l *LruList[K, V]) insert(e, at *Entry[K, V]) *Entry[K, V] {
|
||||
e.prev = at
|
||||
e.next = at.next
|
||||
e.prev.next = e
|
||||
e.next.prev = e
|
||||
e.list = l
|
||||
l.len++
|
||||
return e
|
||||
}
|
||||
|
||||
// insertValue is a convenience wrapper for insert(&Entry{Value: v, ExpiresAt: ExpiresAt}, at).
|
||||
func (l *LruList[K, V]) insertValue(k K, v V, expiresAt time.Time, at *Entry[K, V]) *Entry[K, V] {
|
||||
return l.insert(&Entry[K, V]{Value: v, Key: k, ExpiresAt: expiresAt}, at)
|
||||
}
|
||||
|
||||
// Remove removes e from its list, decrements l.len
|
||||
func (l *LruList[K, V]) Remove(e *Entry[K, V]) V {
|
||||
e.prev.next = e.next
|
||||
e.next.prev = e.prev
|
||||
e.next = nil // avoid memory leaks
|
||||
e.prev = nil // avoid memory leaks
|
||||
e.list = nil
|
||||
l.len--
|
||||
|
||||
return e.Value
|
||||
}
|
||||
|
||||
// move moves e to next to at.
|
||||
func (l *LruList[K, V]) move(e, at *Entry[K, V]) {
|
||||
if e == at {
|
||||
return
|
||||
}
|
||||
e.prev.next = e.next
|
||||
e.next.prev = e.prev
|
||||
|
||||
e.prev = at
|
||||
e.next = at.next
|
||||
e.prev.next = e
|
||||
e.next.prev = e
|
||||
}
|
||||
|
||||
// PushFront inserts a new element e with value v at the front of list l and returns e.
|
||||
func (l *LruList[K, V]) PushFront(k K, v V) *Entry[K, V] {
|
||||
l.lazyInit()
|
||||
return l.insertValue(k, v, time.Time{}, &l.root)
|
||||
}
|
||||
|
||||
// PushFrontExpirable inserts a new expirable element e with Value v at the front of list l and returns e.
|
||||
func (l *LruList[K, V]) PushFrontExpirable(k K, v V, expiresAt time.Time) *Entry[K, V] {
|
||||
l.lazyInit()
|
||||
return l.insertValue(k, v, expiresAt, &l.root)
|
||||
}
|
||||
|
||||
// MoveToFront moves element e to the front of list l.
|
||||
// If e is not an element of l, the list is not modified.
|
||||
// The element must not be nil.
|
||||
func (l *LruList[K, V]) MoveToFront(e *Entry[K, V]) {
|
||||
if e.list != l || l.root.next == e {
|
||||
return
|
||||
}
|
||||
// see comment in List.Remove about initialization of l
|
||||
l.move(e, &l.root)
|
||||
}
|
||||
@ -1,183 +0,0 @@
|
||||
package stmt_store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"math"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm/internal/lru"
|
||||
)
|
||||
|
||||
type Stmt struct {
|
||||
*sql.Stmt
|
||||
Transaction bool
|
||||
prepared chan struct{}
|
||||
prepareErr error
|
||||
}
|
||||
|
||||
func (stmt *Stmt) Error() error {
|
||||
return stmt.prepareErr
|
||||
}
|
||||
|
||||
func (stmt *Stmt) Close() error {
|
||||
<-stmt.prepared
|
||||
|
||||
if stmt.Stmt != nil {
|
||||
return stmt.Stmt.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Store defines an interface for managing the caching operations of SQL statements (Stmt).
|
||||
// This interface provides methods for creating new statements, retrieving all cache keys,
|
||||
// getting cached statements, setting cached statements, and deleting cached statements.
|
||||
type Store interface {
|
||||
// New creates a new Stmt object and caches it.
|
||||
// Parameters:
|
||||
// ctx: The context for the request, which can carry deadlines, cancellation signals, etc.
|
||||
// key: The key representing the SQL query, used for caching and preparing the statement.
|
||||
// isTransaction: Indicates whether this operation is part of a transaction, which may affect the caching strategy.
|
||||
// connPool: A connection pool that provides database connections.
|
||||
// locker: A synchronization lock that is unlocked after initialization to avoid deadlocks.
|
||||
// Returns:
|
||||
// *Stmt: A newly created statement object for executing SQL operations.
|
||||
// error: An error if the statement preparation fails.
|
||||
New(ctx context.Context, key string, isTransaction bool, connPool ConnPool, locker sync.Locker) (*Stmt, error)
|
||||
|
||||
// Keys returns a slice of all cache keys in the store.
|
||||
Keys() []string
|
||||
|
||||
// Get retrieves a Stmt object from the store based on the given key.
|
||||
// Parameters:
|
||||
// key: The key used to look up the Stmt object.
|
||||
// Returns:
|
||||
// *Stmt: The found Stmt object, or nil if not found.
|
||||
// bool: Indicates whether the corresponding Stmt object was successfully found.
|
||||
Get(key string) (*Stmt, bool)
|
||||
|
||||
// Set stores the given Stmt object in the store and associates it with the specified key.
|
||||
// Parameters:
|
||||
// key: The key used to associate the Stmt object.
|
||||
// value: The Stmt object to be stored.
|
||||
Set(key string, value *Stmt)
|
||||
|
||||
// Delete removes the Stmt object corresponding to the specified key from the store.
|
||||
// Parameters:
|
||||
// key: The key associated with the Stmt object to be deleted.
|
||||
Delete(key string)
|
||||
}
|
||||
|
||||
// defaultMaxSize defines the default maximum capacity of the cache.
|
||||
// Its value is the maximum value of the int64 type, which means that when the cache size is not specified,
|
||||
// the cache can theoretically store as many elements as possible.
|
||||
// (1 << 63) - 1 is the maximum value that an int64 type can represent.
|
||||
const (
|
||||
defaultMaxSize = math.MaxInt
|
||||
// defaultTTL defines the default time-to-live (TTL) for each cache entry.
|
||||
// When the TTL for cache entries is not specified, each cache entry will expire after 24 hours.
|
||||
defaultTTL = time.Hour * 24
|
||||
)
|
||||
|
||||
// New creates and returns a new Store instance.
|
||||
//
|
||||
// Parameters:
|
||||
// - size: The maximum capacity of the cache. If the provided size is less than or equal to 0,
|
||||
// it defaults to defaultMaxSize.
|
||||
// - ttl: The time-to-live duration for each cache entry. If the provided ttl is less than or equal to 0,
|
||||
// it defaults to defaultTTL.
|
||||
//
|
||||
// This function defines an onEvicted callback that is invoked when a cache entry is evicted.
|
||||
// The callback ensures that if the evicted value (v) is not nil, its Close method is called asynchronously
|
||||
// to release associated resources.
|
||||
//
|
||||
// Returns:
|
||||
// - A Store instance implemented by lruStore, which internally uses an LRU cache with the specified size,
|
||||
// eviction callback, and TTL.
|
||||
func New(size int, ttl time.Duration) Store {
|
||||
if size <= 0 {
|
||||
size = defaultMaxSize
|
||||
}
|
||||
|
||||
if ttl <= 0 {
|
||||
ttl = defaultTTL
|
||||
}
|
||||
|
||||
onEvicted := func(k string, v *Stmt) {
|
||||
if v != nil {
|
||||
go v.Close()
|
||||
}
|
||||
}
|
||||
return &lruStore{lru: lru.NewLRU[string, *Stmt](size, onEvicted, ttl)}
|
||||
}
|
||||
|
||||
type lruStore struct {
|
||||
lru *lru.LRU[string, *Stmt]
|
||||
}
|
||||
|
||||
func (s *lruStore) Keys() []string {
|
||||
return s.lru.Keys()
|
||||
}
|
||||
|
||||
func (s *lruStore) Get(key string) (*Stmt, bool) {
|
||||
stmt, ok := s.lru.Get(key)
|
||||
if ok && stmt != nil {
|
||||
<-stmt.prepared
|
||||
}
|
||||
return stmt, ok
|
||||
}
|
||||
|
||||
func (s *lruStore) Set(key string, value *Stmt) {
|
||||
s.lru.Add(key, value)
|
||||
}
|
||||
|
||||
func (s *lruStore) Delete(key string) {
|
||||
s.lru.Remove(key)
|
||||
}
|
||||
|
||||
type ConnPool interface {
|
||||
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
|
||||
}
|
||||
|
||||
// New creates a new Stmt object for executing SQL queries.
|
||||
// It caches the Stmt object for future use and handles preparation and error states.
|
||||
// Parameters:
|
||||
//
|
||||
// ctx: Context for the request, used to carry deadlines, cancellation signals, etc.
|
||||
// key: The key representing the SQL query, used for caching and preparing the statement.
|
||||
// isTransaction: Indicates whether this operation is part of a transaction, affecting cache strategy.
|
||||
// conn: A connection pool that provides database connections.
|
||||
// locker: A synchronization lock that is unlocked after initialization to avoid deadlocks.
|
||||
//
|
||||
// Returns:
|
||||
//
|
||||
// *Stmt: A newly created statement object for executing SQL operations.
|
||||
// error: An error if the statement preparation fails.
|
||||
func (s *lruStore) New(ctx context.Context, key string, isTransaction bool, conn ConnPool, locker sync.Locker) (_ *Stmt, err error) {
|
||||
// Create a Stmt object and set its Transaction property.
|
||||
// The prepared channel is used to synchronize the statement preparation state.
|
||||
cacheStmt := &Stmt{
|
||||
Transaction: isTransaction,
|
||||
prepared: make(chan struct{}),
|
||||
}
|
||||
// Cache the Stmt object with the associated key.
|
||||
s.Set(key, cacheStmt)
|
||||
// Unlock after completing initialization to prevent deadlocks.
|
||||
locker.Unlock()
|
||||
|
||||
// Ensure the prepared channel is closed after the function execution completes.
|
||||
defer close(cacheStmt.prepared)
|
||||
|
||||
// Prepare the SQL statement using the provided connection.
|
||||
cacheStmt.Stmt, err = conn.PrepareContext(ctx, key)
|
||||
if err != nil {
|
||||
// If statement preparation fails, record the error and remove the invalid Stmt object from the cache.
|
||||
cacheStmt.prepareErr = err
|
||||
s.Delete(key)
|
||||
return &Stmt{}, err
|
||||
}
|
||||
|
||||
// Return the successfully prepared Stmt object.
|
||||
return cacheStmt, nil
|
||||
}
|
||||
@ -55,7 +55,6 @@ type Config struct {
|
||||
SlowThreshold time.Duration
|
||||
Colorful bool
|
||||
IgnoreRecordNotFoundError bool
|
||||
ParameterizedQueries bool
|
||||
LogLevel LogLevel
|
||||
}
|
||||
|
||||
@ -69,7 +68,7 @@ type Interface interface {
|
||||
}
|
||||
|
||||
var (
|
||||
// Discard logger will print any log to io.Discard
|
||||
// Discard Discard logger will print any log to io.Discard
|
||||
Discard = New(log.New(io.Discard, "", log.LstdFlags), Config{})
|
||||
// Default Default logger
|
||||
Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{
|
||||
@ -78,13 +77,8 @@ var (
|
||||
IgnoreRecordNotFoundError: false,
|
||||
Colorful: true,
|
||||
})
|
||||
// Recorder logger records running SQL into a recorder instance
|
||||
// Recorder Recorder logger records running SQL into a recorder instance
|
||||
Recorder = traceRecorder{Interface: Default, BeginAt: time.Now()}
|
||||
|
||||
// RecorderParamsFilter defaults to no-op, allows to be run-over by a different implementation
|
||||
RecorderParamsFilter = func(ctx context.Context, sql string, params ...interface{}) (string, []interface{}) {
|
||||
return sql, params
|
||||
}
|
||||
)
|
||||
|
||||
// New initialize logger
|
||||
@ -134,30 +128,28 @@ func (l *logger) LogMode(level LogLevel) Interface {
|
||||
}
|
||||
|
||||
// Info print info
|
||||
func (l *logger) Info(ctx context.Context, msg string, data ...interface{}) {
|
||||
func (l logger) Info(ctx context.Context, msg string, data ...interface{}) {
|
||||
if l.LogLevel >= Info {
|
||||
l.Printf(l.infoStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...)
|
||||
}
|
||||
}
|
||||
|
||||
// Warn print warn messages
|
||||
func (l *logger) Warn(ctx context.Context, msg string, data ...interface{}) {
|
||||
func (l logger) Warn(ctx context.Context, msg string, data ...interface{}) {
|
||||
if l.LogLevel >= Warn {
|
||||
l.Printf(l.warnStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...)
|
||||
}
|
||||
}
|
||||
|
||||
// Error print error messages
|
||||
func (l *logger) Error(ctx context.Context, msg string, data ...interface{}) {
|
||||
func (l logger) Error(ctx context.Context, msg string, data ...interface{}) {
|
||||
if l.LogLevel >= Error {
|
||||
l.Printf(l.errStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...)
|
||||
}
|
||||
}
|
||||
|
||||
// Trace print sql message
|
||||
//
|
||||
//nolint:cyclop
|
||||
func (l *logger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) {
|
||||
func (l logger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) {
|
||||
if l.LogLevel <= Silent {
|
||||
return
|
||||
}
|
||||
@ -189,14 +181,6 @@ func (l *logger) Trace(ctx context.Context, begin time.Time, fc func() (string,
|
||||
}
|
||||
}
|
||||
|
||||
// ParamsFilter filter params
|
||||
func (l *logger) ParamsFilter(ctx context.Context, sql string, params ...interface{}) (string, []interface{}) {
|
||||
if l.Config.ParameterizedQueries {
|
||||
return sql, nil
|
||||
}
|
||||
return sql, params
|
||||
}
|
||||
|
||||
type traceRecorder struct {
|
||||
Interface
|
||||
BeginAt time.Time
|
||||
@ -205,8 +189,8 @@ type traceRecorder struct {
|
||||
Err error
|
||||
}
|
||||
|
||||
// New trace recorder
|
||||
func (l *traceRecorder) New() *traceRecorder {
|
||||
// New new trace recorder
|
||||
func (l traceRecorder) New() *traceRecorder {
|
||||
return &traceRecorder{Interface: l.Interface, BeginAt: time.Now()}
|
||||
}
|
||||
|
||||
@ -216,10 +200,3 @@ func (l *traceRecorder) Trace(ctx context.Context, begin time.Time, fc func() (s
|
||||
l.SQL, l.RowsAffected = fc()
|
||||
l.Err = err
|
||||
}
|
||||
|
||||
func (l *traceRecorder) ParamsFilter(ctx context.Context, sql string, params ...interface{}) (string, []interface{}) {
|
||||
if RecorderParamsFilter == nil {
|
||||
return sql, params
|
||||
}
|
||||
return RecorderParamsFilter(ctx, sql, params...)
|
||||
}
|
||||
|
||||
@ -28,25 +28,10 @@ func isPrintable(s string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// A list of Go types that should be converted to SQL primitives
|
||||
var convertibleTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeOf(false), reflect.TypeOf([]byte{})}
|
||||
|
||||
// RegEx matches only numeric values
|
||||
var numericPlaceholderRe = regexp.MustCompile(`\$\d+\$`)
|
||||
|
||||
func isNumeric(k reflect.Kind) bool {
|
||||
switch k {
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
return true
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
return true
|
||||
case reflect.Float32, reflect.Float64:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// ExplainSQL generate SQL string with given parameters, the generated SQL is expected to be used in logger, execute it might introduce a SQL injection vulnerability
|
||||
func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, avars ...interface{}) string {
|
||||
var (
|
||||
@ -92,28 +77,26 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a
|
||||
case reflect.Bool:
|
||||
vars[idx] = fmt.Sprintf("%t", reflectValue.Interface())
|
||||
case reflect.String:
|
||||
vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, escaper+escaper) + escaper
|
||||
vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, "\\"+escaper) + escaper
|
||||
default:
|
||||
if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) {
|
||||
vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, escaper+escaper) + escaper
|
||||
vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, "\\"+escaper) + escaper
|
||||
} else {
|
||||
vars[idx] = nullStr
|
||||
}
|
||||
}
|
||||
case []byte:
|
||||
if s := string(v); isPrintable(s) {
|
||||
vars[idx] = escaper + strings.ReplaceAll(s, escaper, escaper+escaper) + escaper
|
||||
vars[idx] = escaper + strings.ReplaceAll(s, escaper, "\\"+escaper) + escaper
|
||||
} else {
|
||||
vars[idx] = escaper + "<binary>" + escaper
|
||||
}
|
||||
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
|
||||
vars[idx] = utils.ToString(v)
|
||||
case float32:
|
||||
vars[idx] = strconv.FormatFloat(float64(v), 'f', -1, 32)
|
||||
case float64:
|
||||
vars[idx] = strconv.FormatFloat(v, 'f', -1, 64)
|
||||
case float64, float32:
|
||||
vars[idx] = fmt.Sprintf("%.6f", v)
|
||||
case string:
|
||||
vars[idx] = escaper + strings.ReplaceAll(v, escaper, escaper+escaper) + escaper
|
||||
vars[idx] = escaper + strings.ReplaceAll(v, escaper, "\\"+escaper) + escaper
|
||||
default:
|
||||
rv := reflect.ValueOf(v)
|
||||
if v == nil || !rv.IsValid() || rv.Kind() == reflect.Ptr && rv.IsNil() {
|
||||
@ -123,12 +106,6 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a
|
||||
convertParams(v, idx)
|
||||
} else if rv.Kind() == reflect.Ptr && !rv.IsZero() {
|
||||
convertParams(reflect.Indirect(rv).Interface(), idx)
|
||||
} else if isNumeric(rv.Kind()) {
|
||||
if rv.CanInt() || rv.CanUint() {
|
||||
vars[idx] = fmt.Sprintf("%d", rv.Interface())
|
||||
} else {
|
||||
vars[idx] = fmt.Sprintf("%.6f", rv.Interface())
|
||||
}
|
||||
} else {
|
||||
for _, t := range convertibleTypes {
|
||||
if rv.Type().ConvertibleTo(t) {
|
||||
@ -136,7 +113,7 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a
|
||||
return
|
||||
}
|
||||
}
|
||||
vars[idx] = escaper + strings.ReplaceAll(fmt.Sprint(v), escaper, escaper+escaper) + escaper
|
||||
vars[idx] = escaper + strings.ReplaceAll(fmt.Sprint(v), escaper, "\\"+escaper) + escaper
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -31,24 +31,20 @@ func (s ExampleStruct) Value() (driver.Value, error) {
|
||||
}
|
||||
|
||||
func format(v []byte, escaper string) string {
|
||||
return escaper + strings.ReplaceAll(string(v), escaper, escaper+escaper) + escaper
|
||||
return escaper + strings.ReplaceAll(string(v), escaper, "\\"+escaper) + escaper
|
||||
}
|
||||
|
||||
func TestExplainSQL(t *testing.T) {
|
||||
type role string
|
||||
type password []byte
|
||||
type intType int
|
||||
type floatType float64
|
||||
var (
|
||||
tt = now.MustParse("2020-02-23 11:10:10")
|
||||
myrole = role("admin")
|
||||
pwd = password("pass")
|
||||
jsVal = []byte(`{"Name":"test","Val":"test"}`)
|
||||
js = JSON(jsVal)
|
||||
esVal = []byte(`{"Name":"test","Val":"test"}`)
|
||||
es = ExampleStruct{Name: "test", Val: "test"}
|
||||
intVal intType = 1
|
||||
floatVal floatType = 1.23
|
||||
tt = now.MustParse("2020-02-23 11:10:10")
|
||||
myrole = role("admin")
|
||||
pwd = password([]byte("pass"))
|
||||
jsVal = []byte(`{"Name":"test","Val":"test"}`)
|
||||
js = JSON(jsVal)
|
||||
esVal = []byte(`{"Name":"test","Val":"test"}`)
|
||||
es = ExampleStruct{Name: "test", Val: "test"}
|
||||
)
|
||||
|
||||
results := []struct {
|
||||
@ -61,67 +57,43 @@ func TestExplainSQL(t *testing.T) {
|
||||
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
NumericRegexp: nil,
|
||||
Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd},
|
||||
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass")`,
|
||||
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass")`,
|
||||
},
|
||||
{
|
||||
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
NumericRegexp: nil,
|
||||
Vars: []interface{}{"jinzhu?", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd},
|
||||
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu?", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass")`,
|
||||
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu?", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass")`,
|
||||
},
|
||||
{
|
||||
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (@p1, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9, @p10, @p11)",
|
||||
NumericRegexp: regexp.MustCompile(`@p(\d+)`),
|
||||
Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.com", myrole, pwd},
|
||||
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`,
|
||||
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`,
|
||||
},
|
||||
{
|
||||
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ($3, $4, $1, $2, $7, $8, $5, $6, $9, $10, $11)",
|
||||
NumericRegexp: regexp.MustCompile(`\$(\d+)`),
|
||||
Vars: []interface{}{999.99, true, "jinzhu", 1, &tt, nil, []byte("12345"), tt, "w@g.com", myrole, pwd},
|
||||
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`,
|
||||
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`,
|
||||
},
|
||||
{
|
||||
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (@p1, @p11, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9, @p10)",
|
||||
NumericRegexp: regexp.MustCompile(`@p(\d+)`),
|
||||
Vars: []interface{}{"jinzhu", 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.com", myrole, pwd, 1},
|
||||
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`,
|
||||
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`,
|
||||
},
|
||||
{
|
||||
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
NumericRegexp: nil,
|
||||
Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, js, es},
|
||||
Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
|
||||
Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
|
||||
},
|
||||
{
|
||||
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
NumericRegexp: nil,
|
||||
Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es},
|
||||
Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
|
||||
},
|
||||
{
|
||||
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
NumericRegexp: nil,
|
||||
Vars: []interface{}{"jinzhu", 1, 0.1753607109, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es},
|
||||
Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 0.1753607109, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
|
||||
},
|
||||
{
|
||||
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
NumericRegexp: nil,
|
||||
Vars: []interface{}{"jinzhu", 1, float32(999.99), true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es},
|
||||
Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
|
||||
},
|
||||
{
|
||||
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, int_val) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
NumericRegexp: nil,
|
||||
Vars: []interface{}{"jinzhu?", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, intVal},
|
||||
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, int_val) values ("jinzhu?", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", 1)`,
|
||||
},
|
||||
{
|
||||
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, float_val) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
NumericRegexp: nil,
|
||||
Vars: []interface{}{"jinzhu?", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, floatVal},
|
||||
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, float_val) values ("jinzhu?", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", 1.230000)`,
|
||||
Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
24
migrator.go
24
migrator.go
@ -13,7 +13,11 @@ func (db *DB) Migrator() Migrator {
|
||||
|
||||
// apply scopes to migrator
|
||||
for len(tx.Statement.scopes) > 0 {
|
||||
tx = tx.executeScopes()
|
||||
scopes := tx.Statement.scopes
|
||||
tx.Statement.scopes = nil
|
||||
for _, scope := range scopes {
|
||||
tx = scope(tx)
|
||||
}
|
||||
}
|
||||
|
||||
return tx.Dialector.Migrator(tx.Session(&Session{}))
|
||||
@ -26,9 +30,9 @@ func (db *DB) AutoMigrate(dst ...interface{}) error {
|
||||
|
||||
// ViewOption view option
|
||||
type ViewOption struct {
|
||||
Replace bool // If true, exec `CREATE`. If false, exec `CREATE OR REPLACE`
|
||||
CheckOption string // optional. e.g. `WITH [ CASCADED | LOCAL ] CHECK OPTION`
|
||||
Query *DB // required subquery.
|
||||
Replace bool
|
||||
CheckOption string
|
||||
Query *DB
|
||||
}
|
||||
|
||||
// ColumnType column type interface
|
||||
@ -56,14 +60,6 @@ type Index interface {
|
||||
Option() string
|
||||
}
|
||||
|
||||
// TableType table type interface
|
||||
type TableType interface {
|
||||
Schema() string
|
||||
Name() string
|
||||
Type() string
|
||||
Comment() (comment string, ok bool)
|
||||
}
|
||||
|
||||
// Migrator migrator interface
|
||||
type Migrator interface {
|
||||
// AutoMigrate
|
||||
@ -72,7 +68,6 @@ type Migrator interface {
|
||||
// Database
|
||||
CurrentDatabase() string
|
||||
FullDataTypeOf(*schema.Field) clause.Expr
|
||||
GetTypeAliases(databaseTypeName string) []string
|
||||
|
||||
// Tables
|
||||
CreateTable(dst ...interface{}) error
|
||||
@ -80,15 +75,12 @@ type Migrator interface {
|
||||
HasTable(dst interface{}) bool
|
||||
RenameTable(oldName, newName interface{}) error
|
||||
GetTables() (tableList []string, err error)
|
||||
TableType(dst interface{}) (TableType, error)
|
||||
|
||||
// Columns
|
||||
AddColumn(dst interface{}, field string) error
|
||||
DropColumn(dst interface{}, field string) error
|
||||
AlterColumn(dst interface{}, field string) error
|
||||
MigrateColumn(dst interface{}, field *schema.Field, columnType ColumnType) error
|
||||
// MigrateColumnUnique migrate column's UNIQUE constraint, it's part of MigrateColumn.
|
||||
MigrateColumnUnique(dst interface{}, field *schema.Field, columnType ColumnType) error
|
||||
HasColumn(dst interface{}, field string) bool
|
||||
RenameColumn(dst interface{}, oldName, field string) error
|
||||
ColumnTypes(dst interface{}) ([]ColumnType, error)
|
||||
|
||||
@ -17,12 +17,12 @@ func (idx Index) Table() string {
|
||||
return idx.TableName
|
||||
}
|
||||
|
||||
// Name return the name of the index.
|
||||
// Name return the name of the index.
|
||||
func (idx Index) Name() string {
|
||||
return idx.NameValue
|
||||
}
|
||||
|
||||
// Columns return the columns of the index
|
||||
// Columns return the columns fo the index
|
||||
func (idx Index) Columns() []string {
|
||||
return idx.ColumnList
|
||||
}
|
||||
@ -37,7 +37,7 @@ func (idx Index) Unique() (unique bool, ok bool) {
|
||||
return idx.UniqueValue.Bool, idx.UniqueValue.Valid
|
||||
}
|
||||
|
||||
// Option return the optional attribute of the index
|
||||
// Option return the optional attribute fo the index
|
||||
func (idx Index) Option() string {
|
||||
return idx.OptionValue
|
||||
}
|
||||
|
||||
@ -7,28 +7,16 @@ import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/logger"
|
||||
"gorm.io/gorm/schema"
|
||||
)
|
||||
|
||||
// This regular expression seeks to find a sequence of digits (\d+) among zero or more non-digit characters (\D*),
|
||||
// with a possible trailing non-digit character (\D?).
|
||||
|
||||
// For example, values that can pass this regular expression are:
|
||||
// - "123"
|
||||
// - "abc456"
|
||||
// -"%$#@789"
|
||||
var regFullDataType = regexp.MustCompile(`\D*(\d+)\D?`)
|
||||
|
||||
// TODO:? Create const vars for raw sql queries ?
|
||||
|
||||
var _ gorm.Migrator = (*Migrator)(nil)
|
||||
var (
|
||||
regFullDataType = regexp.MustCompile(`\D*(\d+)\D?`)
|
||||
)
|
||||
|
||||
// Migrator m struct
|
||||
type Migrator struct {
|
||||
@ -42,16 +30,6 @@ type Config struct {
|
||||
gorm.Dialector
|
||||
}
|
||||
|
||||
type printSQLLogger struct {
|
||||
logger.Interface
|
||||
}
|
||||
|
||||
func (l *printSQLLogger) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) {
|
||||
sql, _ := fc()
|
||||
fmt.Println(sql + ";")
|
||||
l.Interface.Trace(ctx, begin, fc, err)
|
||||
}
|
||||
|
||||
// GormDataTypeInterface gorm data type interface
|
||||
type GormDataTypeInterface interface {
|
||||
GormDBDataType(*gorm.DB, *schema.Field) string
|
||||
@ -94,6 +72,10 @@ func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) {
|
||||
expr.SQL += " NOT NULL"
|
||||
}
|
||||
|
||||
if field.Unique {
|
||||
expr.SQL += " UNIQUE"
|
||||
}
|
||||
|
||||
if field.HasDefaultValue && (field.DefaultValueInterface != nil || field.DefaultValue != "") {
|
||||
if field.DefaultValueInterface != nil {
|
||||
defaultStmt := &gorm.Statement{Vars: []interface{}{field.DefaultValueInterface}}
|
||||
@ -107,40 +89,23 @@ func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) {
|
||||
return
|
||||
}
|
||||
|
||||
func (m Migrator) GetQueryAndExecTx() (queryTx, execTx *gorm.DB) {
|
||||
queryTx = m.DB.Session(&gorm.Session{})
|
||||
execTx = queryTx
|
||||
if m.DB.DryRun {
|
||||
queryTx.DryRun = false
|
||||
execTx = m.DB.Session(&gorm.Session{Logger: &printSQLLogger{Interface: m.DB.Logger}})
|
||||
}
|
||||
return queryTx, execTx
|
||||
}
|
||||
|
||||
// AutoMigrate auto migrate values
|
||||
func (m Migrator) AutoMigrate(values ...interface{}) error {
|
||||
for _, value := range m.ReorderModels(values, true) {
|
||||
queryTx, execTx := m.GetQueryAndExecTx()
|
||||
if !queryTx.Migrator().HasTable(value) {
|
||||
if err := execTx.Migrator().CreateTable(value); err != nil {
|
||||
tx := m.DB.Session(&gorm.Session{})
|
||||
if !tx.Migrator().HasTable(value) {
|
||||
if err := tx.Migrator().CreateTable(value); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
|
||||
if stmt.Schema == nil {
|
||||
return errors.New("failed to get schema")
|
||||
}
|
||||
|
||||
columnTypes, err := queryTx.Migrator().ColumnTypes(value)
|
||||
if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) {
|
||||
columnTypes, err := m.DB.Migrator().ColumnTypes(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var (
|
||||
parseIndexes = stmt.Schema.ParseIndexes()
|
||||
parseCheckConstraints = stmt.Schema.ParseCheckConstraints()
|
||||
)
|
||||
|
||||
for _, dbName := range stmt.Schema.DBNames {
|
||||
field := stmt.Schema.FieldsByDBName[dbName]
|
||||
var foundColumn gorm.ColumnType
|
||||
|
||||
for _, columnType := range columnTypes {
|
||||
@ -152,43 +117,37 @@ func (m Migrator) AutoMigrate(values ...interface{}) error {
|
||||
|
||||
if foundColumn == nil {
|
||||
// not found, add column
|
||||
if err = execTx.Migrator().AddColumn(value, dbName); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// found, smartly migrate
|
||||
field := stmt.Schema.FieldsByDBName[dbName]
|
||||
if err = execTx.Migrator().MigrateColumn(value, field, foundColumn); err != nil {
|
||||
if err := tx.Migrator().AddColumn(value, dbName); err != nil {
|
||||
return err
|
||||
}
|
||||
} else if err := m.DB.Migrator().MigrateColumn(value, field, foundColumn); err != nil {
|
||||
// found, smart migrate
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if !m.DB.DisableForeignKeyConstraintWhenMigrating && !m.DB.IgnoreRelationshipsWhenMigrating {
|
||||
for _, rel := range stmt.Schema.Relationships.Relations {
|
||||
if rel.Field.IgnoreMigration {
|
||||
continue
|
||||
}
|
||||
for _, rel := range stmt.Schema.Relationships.Relations {
|
||||
if !m.DB.Config.DisableForeignKeyConstraintWhenMigrating {
|
||||
if constraint := rel.ParseConstraint(); constraint != nil &&
|
||||
constraint.Schema == stmt.Schema && !queryTx.Migrator().HasConstraint(value, constraint.Name) {
|
||||
if err := execTx.Migrator().CreateConstraint(value, constraint.Name); err != nil {
|
||||
constraint.Schema == stmt.Schema && !tx.Migrator().HasConstraint(value, constraint.Name) {
|
||||
if err := tx.Migrator().CreateConstraint(value, constraint.Name); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, chk := range parseCheckConstraints {
|
||||
if !queryTx.Migrator().HasConstraint(value, chk.Name) {
|
||||
if err := execTx.Migrator().CreateConstraint(value, chk.Name); err != nil {
|
||||
for _, chk := range stmt.Schema.ParseCheckConstraints() {
|
||||
if !tx.Migrator().HasConstraint(value, chk.Name) {
|
||||
if err := tx.Migrator().CreateConstraint(value, chk.Name); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, idx := range parseIndexes {
|
||||
if !queryTx.Migrator().HasIndex(value, idx.Name) {
|
||||
if err := execTx.Migrator().CreateIndex(value, idx.Name); err != nil {
|
||||
for _, idx := range stmt.Schema.ParseIndexes() {
|
||||
if !tx.Migrator().HasIndex(value, idx.Name) {
|
||||
if err := tx.Migrator().CreateIndex(value, idx.Name); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@ -215,12 +174,7 @@ func (m Migrator) GetTables() (tableList []string, err error) {
|
||||
func (m Migrator) CreateTable(values ...interface{}) error {
|
||||
for _, value := range m.ReorderModels(values, false) {
|
||||
tx := m.DB.Session(&gorm.Session{})
|
||||
if err := m.RunWithValue(value, func(stmt *gorm.Statement) (err error) {
|
||||
|
||||
if stmt.Schema == nil {
|
||||
return errors.New("failed to get schema")
|
||||
}
|
||||
|
||||
if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) {
|
||||
var (
|
||||
createTableSQL = "CREATE TABLE ? ("
|
||||
values = []interface{}{m.CurrentTable(stmt)}
|
||||
@ -231,7 +185,7 @@ func (m Migrator) CreateTable(values ...interface{}) error {
|
||||
field := stmt.Schema.FieldsByDBName[dbName]
|
||||
if !field.IgnoreMigration {
|
||||
createTableSQL += "? ?"
|
||||
hasPrimaryKeyInDataType = hasPrimaryKeyInDataType || strings.Contains(strings.ToUpper(m.DataTypeOf(field)), "PRIMARY KEY")
|
||||
hasPrimaryKeyInDataType = hasPrimaryKeyInDataType || strings.Contains(strings.ToUpper(string(field.DataType)), "PRIMARY KEY")
|
||||
values = append(values, clause.Column{Name: dbName}, m.DB.Migrator().FullDataTypeOf(field))
|
||||
createTableSQL += ","
|
||||
}
|
||||
@ -239,7 +193,7 @@ func (m Migrator) CreateTable(values ...interface{}) error {
|
||||
|
||||
if !hasPrimaryKeyInDataType && len(stmt.Schema.PrimaryFields) > 0 {
|
||||
createTableSQL += "PRIMARY KEY ?,"
|
||||
primaryKeys := make([]interface{}, 0, len(stmt.Schema.PrimaryFields))
|
||||
primaryKeys := []interface{}{}
|
||||
for _, field := range stmt.Schema.PrimaryFields {
|
||||
primaryKeys = append(primaryKeys, clause.Column{Name: field.DBName})
|
||||
}
|
||||
@ -250,8 +204,8 @@ func (m Migrator) CreateTable(values ...interface{}) error {
|
||||
for _, idx := range stmt.Schema.ParseIndexes() {
|
||||
if m.CreateIndexAfterCreateTable {
|
||||
defer func(value interface{}, name string) {
|
||||
if err == nil {
|
||||
err = tx.Migrator().CreateIndex(value, name)
|
||||
if errr == nil {
|
||||
errr = tx.Migrator().CreateIndex(value, name)
|
||||
}
|
||||
}(value, idx.Name)
|
||||
} else {
|
||||
@ -273,14 +227,11 @@ func (m Migrator) CreateTable(values ...interface{}) error {
|
||||
}
|
||||
}
|
||||
|
||||
if !m.DB.DisableForeignKeyConstraintWhenMigrating && !m.DB.IgnoreRelationshipsWhenMigrating {
|
||||
for _, rel := range stmt.Schema.Relationships.Relations {
|
||||
if rel.Field.IgnoreMigration {
|
||||
continue
|
||||
}
|
||||
for _, rel := range stmt.Schema.Relationships.Relations {
|
||||
if !m.DB.DisableForeignKeyConstraintWhenMigrating {
|
||||
if constraint := rel.ParseConstraint(); constraint != nil {
|
||||
if constraint.Schema == stmt.Schema {
|
||||
sql, vars := constraint.Build()
|
||||
sql, vars := buildConstraint(constraint)
|
||||
createTableSQL += sql + ","
|
||||
values = append(values, vars...)
|
||||
}
|
||||
@ -288,11 +239,6 @@ func (m Migrator) CreateTable(values ...interface{}) error {
|
||||
}
|
||||
}
|
||||
|
||||
for _, uni := range stmt.Schema.ParseUniqueConstraints() {
|
||||
createTableSQL += "CONSTRAINT ? UNIQUE (?),"
|
||||
values = append(values, clause.Column{Name: uni.Name}, clause.Expr{SQL: stmt.Quote(uni.Field.DBName)})
|
||||
}
|
||||
|
||||
for _, chk := range stmt.Schema.ParseCheckConstraints() {
|
||||
createTableSQL += "CONSTRAINT ? CHECK (?),"
|
||||
values = append(values, clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint})
|
||||
@ -306,8 +252,8 @@ func (m Migrator) CreateTable(values ...interface{}) error {
|
||||
createTableSQL += fmt.Sprint(tableOption)
|
||||
}
|
||||
|
||||
err = tx.Exec(createTableSQL, values...).Error
|
||||
return err
|
||||
errr = tx.Exec(createTableSQL, values...).Error
|
||||
return errr
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
@ -373,9 +319,6 @@ func (m Migrator) RenameTable(oldName, newName interface{}) error {
|
||||
func (m Migrator) AddColumn(value interface{}, name string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
// avoid using the same name field
|
||||
if stmt.Schema == nil {
|
||||
return errors.New("failed to get schema")
|
||||
}
|
||||
f := stmt.Schema.LookUpField(name)
|
||||
if f == nil {
|
||||
return fmt.Errorf("failed to look up field with name: %s", name)
|
||||
@ -395,10 +338,8 @@ func (m Migrator) AddColumn(value interface{}, name string) error {
|
||||
// DropColumn drop value's `name` column
|
||||
func (m Migrator) DropColumn(value interface{}, name string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
if stmt.Schema != nil {
|
||||
if field := stmt.Schema.LookUpField(name); field != nil {
|
||||
name = field.DBName
|
||||
}
|
||||
if field := stmt.Schema.LookUpField(name); field != nil {
|
||||
name = field.DBName
|
||||
}
|
||||
|
||||
return m.DB.Exec(
|
||||
@ -410,15 +351,13 @@ func (m Migrator) DropColumn(value interface{}, name string) error {
|
||||
// AlterColumn alter value's `field` column' type based on schema definition
|
||||
func (m Migrator) AlterColumn(value interface{}, field string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
if stmt.Schema != nil {
|
||||
if field := stmt.Schema.LookUpField(field); field != nil {
|
||||
fileType := m.FullDataTypeOf(field)
|
||||
return m.DB.Exec(
|
||||
"ALTER TABLE ? ALTER COLUMN ? TYPE ?",
|
||||
m.CurrentTable(stmt), clause.Column{Name: field.DBName}, fileType,
|
||||
).Error
|
||||
if field := stmt.Schema.LookUpField(field); field != nil {
|
||||
fileType := m.FullDataTypeOf(field)
|
||||
return m.DB.Exec(
|
||||
"ALTER TABLE ? ALTER COLUMN ? TYPE ?",
|
||||
m.CurrentTable(stmt), clause.Column{Name: field.DBName}, fileType,
|
||||
).Error
|
||||
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("failed to look up field with name: %s", field)
|
||||
})
|
||||
@ -430,10 +369,8 @@ func (m Migrator) HasColumn(value interface{}, field string) bool {
|
||||
m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
currentDatabase := m.DB.Migrator().CurrentDatabase()
|
||||
name := field
|
||||
if stmt.Schema != nil {
|
||||
if field := stmt.Schema.LookUpField(field); field != nil {
|
||||
name = field.DBName
|
||||
}
|
||||
if field := stmt.Schema.LookUpField(field); field != nil {
|
||||
name = field.DBName
|
||||
}
|
||||
|
||||
return m.DB.Raw(
|
||||
@ -448,14 +385,12 @@ func (m Migrator) HasColumn(value interface{}, field string) bool {
|
||||
// RenameColumn rename value's field name from oldName to newName
|
||||
func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
if stmt.Schema != nil {
|
||||
if field := stmt.Schema.LookUpField(oldName); field != nil {
|
||||
oldName = field.DBName
|
||||
}
|
||||
if field := stmt.Schema.LookUpField(oldName); field != nil {
|
||||
oldName = field.DBName
|
||||
}
|
||||
|
||||
if field := stmt.Schema.LookUpField(newName); field != nil {
|
||||
newName = field.DBName
|
||||
}
|
||||
if field := stmt.Schema.LookUpField(newName); field != nil {
|
||||
newName = field.DBName
|
||||
}
|
||||
|
||||
return m.DB.Exec(
|
||||
@ -467,101 +402,69 @@ func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error
|
||||
|
||||
// MigrateColumn migrate column
|
||||
func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnType gorm.ColumnType) error {
|
||||
if field.IgnoreMigration {
|
||||
return nil
|
||||
}
|
||||
|
||||
// found, smart migrate
|
||||
fullDataType := strings.TrimSpace(strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL))
|
||||
realDataType := strings.ToLower(columnType.DatabaseTypeName())
|
||||
var (
|
||||
alterColumn bool
|
||||
isSameType = fullDataType == realDataType
|
||||
)
|
||||
|
||||
if !field.PrimaryKey {
|
||||
// check type
|
||||
if !strings.HasPrefix(fullDataType, realDataType) {
|
||||
// check type aliases
|
||||
aliases := m.DB.Migrator().GetTypeAliases(realDataType)
|
||||
for _, alias := range aliases {
|
||||
if strings.HasPrefix(fullDataType, alias) {
|
||||
isSameType = true
|
||||
break
|
||||
}
|
||||
}
|
||||
alterColumn := false
|
||||
|
||||
if !isSameType {
|
||||
alterColumn = true
|
||||
}
|
||||
}
|
||||
// check type
|
||||
if !field.PrimaryKey && !strings.HasPrefix(fullDataType, realDataType) {
|
||||
alterColumn = true
|
||||
}
|
||||
|
||||
if !isSameType {
|
||||
// check size
|
||||
if length, ok := columnType.Length(); length != int64(field.Size) {
|
||||
if length > 0 && field.Size > 0 {
|
||||
// check size
|
||||
if length, ok := columnType.Length(); length != int64(field.Size) {
|
||||
if length > 0 && field.Size > 0 {
|
||||
alterColumn = true
|
||||
} else {
|
||||
// has size in data type and not equal
|
||||
// Since the following code is frequently called in the for loop, reg optimization is needed here
|
||||
matches2 := regFullDataType.FindAllStringSubmatch(fullDataType, -1)
|
||||
if !field.PrimaryKey &&
|
||||
(len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length) && ok) {
|
||||
alterColumn = true
|
||||
} else {
|
||||
// has size in data type and not equal
|
||||
// Since the following code is frequently called in the for loop, reg optimization is needed here
|
||||
matches2 := regFullDataType.FindAllStringSubmatch(fullDataType, -1)
|
||||
if !field.PrimaryKey &&
|
||||
(len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length) && ok) {
|
||||
alterColumn = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// check precision
|
||||
if realDataType == "decimal" || realDataType == "numeric" &&
|
||||
regexp.MustCompile(realDataType+`\(.*\)`).FindString(fullDataType) != "" { // if realDataType has no precision,ignore
|
||||
precision, scale, ok := columnType.DecimalSize()
|
||||
if ok {
|
||||
if !strings.HasPrefix(fullDataType, fmt.Sprintf("%s(%d,%d)", realDataType, precision, scale)) &&
|
||||
!strings.HasPrefix(fullDataType, fmt.Sprintf("%s(%d)", realDataType, precision)) {
|
||||
alterColumn = true
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if precision, _, ok := columnType.DecimalSize(); ok && int64(field.Precision) != precision {
|
||||
if regexp.MustCompile(fmt.Sprintf("[^0-9]%d[^0-9]", field.Precision)).MatchString(m.DataTypeOf(field)) {
|
||||
alterColumn = true
|
||||
}
|
||||
if precision, _, ok := columnType.DecimalSize(); ok && int64(field.Precision) != precision {
|
||||
if regexp.MustCompile(fmt.Sprintf("[^0-9]%d[^0-9]", field.Precision)).MatchString(m.DataTypeOf(field)) {
|
||||
alterColumn = true
|
||||
}
|
||||
}
|
||||
|
||||
// check nullable
|
||||
if nullable, ok := columnType.Nullable(); ok && nullable == field.NotNull {
|
||||
// not primary key & current database is non-nullable(to be nullable)
|
||||
if !field.PrimaryKey && !nullable {
|
||||
// not primary key & database is nullable
|
||||
if !field.PrimaryKey && nullable {
|
||||
alterColumn = true
|
||||
}
|
||||
}
|
||||
|
||||
// check unique
|
||||
if unique, ok := columnType.Unique(); ok && unique != field.Unique {
|
||||
// not primary key
|
||||
if !field.PrimaryKey {
|
||||
alterColumn = true
|
||||
}
|
||||
}
|
||||
|
||||
// check default value
|
||||
if !field.PrimaryKey {
|
||||
currentDefaultNotNull := field.HasDefaultValue && (field.DefaultValueInterface != nil || !strings.EqualFold(field.DefaultValue, "NULL"))
|
||||
dv, dvNotNull := columnType.DefaultValue()
|
||||
if dvNotNull && !currentDefaultNotNull {
|
||||
// default value -> null
|
||||
if dvNotNull && field.DefaultValueInterface == nil {
|
||||
// defalut value -> null
|
||||
alterColumn = true
|
||||
} else if !dvNotNull && currentDefaultNotNull {
|
||||
} else if !dvNotNull && field.DefaultValueInterface != nil {
|
||||
// null -> default value
|
||||
alterColumn = true
|
||||
} else if currentDefaultNotNull || dvNotNull {
|
||||
switch field.GORMDataType {
|
||||
case schema.Time:
|
||||
if !strings.EqualFold(strings.TrimSuffix(dv, "()"), strings.TrimSuffix(field.DefaultValue, "()")) {
|
||||
alterColumn = true
|
||||
}
|
||||
case schema.Bool:
|
||||
v1, _ := strconv.ParseBool(dv)
|
||||
v2, _ := strconv.ParseBool(field.DefaultValue)
|
||||
alterColumn = v1 != v2
|
||||
default:
|
||||
alterColumn = dv != field.DefaultValue
|
||||
} else if dv != field.DefaultValue {
|
||||
// default value not equal
|
||||
// not both null
|
||||
if !(field.DefaultValueInterface == nil && !dvNotNull) {
|
||||
alterColumn = true
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -574,39 +477,13 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
|
||||
}
|
||||
}
|
||||
|
||||
if alterColumn {
|
||||
if err := m.DB.Migrator().AlterColumn(value, field.DBName); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if err := m.DB.Migrator().MigrateColumnUnique(value, field, columnType); err != nil {
|
||||
return err
|
||||
if alterColumn && !field.IgnoreMigration {
|
||||
return m.DB.Migrator().AlterColumn(value, field.Name)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m Migrator) MigrateColumnUnique(value interface{}, field *schema.Field, columnType gorm.ColumnType) error {
|
||||
unique, ok := columnType.Unique()
|
||||
if !ok || field.PrimaryKey {
|
||||
return nil // skip primary key
|
||||
}
|
||||
// By default, ColumnType's Unique is not affected by UniqueIndex, so we don't care about UniqueIndex.
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
// We're currently only receiving boolean values on `Unique` tag,
|
||||
// so the UniqueConstraint name is fixed
|
||||
constraint := m.DB.NamingStrategy.UniqueName(stmt.Table, field.DBName)
|
||||
if unique && !field.Unique {
|
||||
return m.DB.Migrator().DropConstraint(value, constraint)
|
||||
}
|
||||
if !unique && field.Unique {
|
||||
return m.DB.Migrator().CreateConstraint(value, constraint)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// ColumnTypes return columnTypes []gorm.ColumnType and execErr error
|
||||
func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
|
||||
columnTypes := make([]gorm.ColumnType, 0)
|
||||
@ -636,76 +513,47 @@ func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
|
||||
return columnTypes, execErr
|
||||
}
|
||||
|
||||
// CreateView create view from Query in gorm.ViewOption.
|
||||
// Query in gorm.ViewOption is a [subquery]
|
||||
//
|
||||
// // CREATE VIEW `user_view` AS SELECT * FROM `users` WHERE age > 20
|
||||
// q := DB.Model(&User{}).Where("age > ?", 20)
|
||||
// DB.Debug().Migrator().CreateView("user_view", gorm.ViewOption{Query: q})
|
||||
//
|
||||
// // CREATE OR REPLACE VIEW `users_view` AS SELECT * FROM `users` WITH CHECK OPTION
|
||||
// q := DB.Model(&User{})
|
||||
// DB.Debug().Migrator().CreateView("user_view", gorm.ViewOption{Query: q, Replace: true, CheckOption: "WITH CHECK OPTION"})
|
||||
//
|
||||
// [subquery]: https://gorm.io/docs/advanced_query.html#SubQuery
|
||||
// CreateView create view
|
||||
func (m Migrator) CreateView(name string, option gorm.ViewOption) error {
|
||||
if option.Query == nil {
|
||||
return gorm.ErrSubQueryRequired
|
||||
}
|
||||
|
||||
sql := new(strings.Builder)
|
||||
sql.WriteString("CREATE ")
|
||||
if option.Replace {
|
||||
sql.WriteString("OR REPLACE ")
|
||||
}
|
||||
sql.WriteString("VIEW ")
|
||||
m.QuoteTo(sql, name)
|
||||
sql.WriteString(" AS ")
|
||||
|
||||
m.DB.Statement.AddVar(sql, option.Query)
|
||||
|
||||
if option.CheckOption != "" {
|
||||
sql.WriteString(" ")
|
||||
sql.WriteString(option.CheckOption)
|
||||
}
|
||||
return m.DB.Exec(m.Explain(sql.String(), m.DB.Statement.Vars...)).Error
|
||||
return gorm.ErrNotImplemented
|
||||
}
|
||||
|
||||
// DropView drop view
|
||||
func (m Migrator) DropView(name string) error {
|
||||
return m.DB.Exec("DROP VIEW IF EXISTS ?", clause.Table{Name: name}).Error
|
||||
return gorm.ErrNotImplemented
|
||||
}
|
||||
|
||||
func buildConstraint(constraint *schema.Constraint) (sql string, results []interface{}) {
|
||||
sql = "CONSTRAINT ? FOREIGN KEY ? REFERENCES ??"
|
||||
if constraint.OnDelete != "" {
|
||||
sql += " ON DELETE " + constraint.OnDelete
|
||||
}
|
||||
|
||||
if constraint.OnUpdate != "" {
|
||||
sql += " ON UPDATE " + constraint.OnUpdate
|
||||
}
|
||||
|
||||
var foreignKeys, references []interface{}
|
||||
for _, field := range constraint.ForeignKeys {
|
||||
foreignKeys = append(foreignKeys, clause.Column{Name: field.DBName})
|
||||
}
|
||||
|
||||
for _, field := range constraint.References {
|
||||
references = append(references, clause.Column{Name: field.DBName})
|
||||
}
|
||||
results = append(results, clause.Table{Name: constraint.Name}, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references)
|
||||
return
|
||||
}
|
||||
|
||||
// GuessConstraintAndTable guess statement's constraint and it's table based on name
|
||||
//
|
||||
// Deprecated: use GuessConstraintInterfaceAndTable instead.
|
||||
func (m Migrator) GuessConstraintAndTable(stmt *gorm.Statement, name string) (*schema.Constraint, *schema.CheckConstraint, string) {
|
||||
constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name)
|
||||
switch c := constraint.(type) {
|
||||
case *schema.Constraint:
|
||||
return c, nil, table
|
||||
case *schema.CheckConstraint:
|
||||
return nil, c, table
|
||||
default:
|
||||
return nil, nil, table
|
||||
}
|
||||
}
|
||||
|
||||
// GuessConstraintInterfaceAndTable guess statement's constraint and it's table based on name
|
||||
// nolint:cyclop
|
||||
func (m Migrator) GuessConstraintInterfaceAndTable(stmt *gorm.Statement, name string) (_ schema.ConstraintInterface, table string) {
|
||||
func (m Migrator) GuessConstraintAndTable(stmt *gorm.Statement, name string) (_ *schema.Constraint, _ *schema.Check, table string) {
|
||||
if stmt.Schema == nil {
|
||||
return nil, stmt.Table
|
||||
return nil, nil, stmt.Table
|
||||
}
|
||||
|
||||
checkConstraints := stmt.Schema.ParseCheckConstraints()
|
||||
if chk, ok := checkConstraints[name]; ok {
|
||||
return &chk, stmt.Table
|
||||
}
|
||||
|
||||
uniqueConstraints := stmt.Schema.ParseUniqueConstraints()
|
||||
if uni, ok := uniqueConstraints[name]; ok {
|
||||
return &uni, stmt.Table
|
||||
return nil, &chk, stmt.Table
|
||||
}
|
||||
|
||||
getTable := func(rel *schema.Relationship) string {
|
||||
@ -720,7 +568,7 @@ func (m Migrator) GuessConstraintInterfaceAndTable(stmt *gorm.Statement, name st
|
||||
|
||||
for _, rel := range stmt.Schema.Relationships.Relations {
|
||||
if constraint := rel.ParseConstraint(); constraint != nil && constraint.Name == name {
|
||||
return constraint, getTable(rel)
|
||||
return constraint, nil, getTable(rel)
|
||||
}
|
||||
}
|
||||
|
||||
@ -728,39 +576,40 @@ func (m Migrator) GuessConstraintInterfaceAndTable(stmt *gorm.Statement, name st
|
||||
for k := range checkConstraints {
|
||||
if checkConstraints[k].Field == field {
|
||||
v := checkConstraints[k]
|
||||
return &v, stmt.Table
|
||||
}
|
||||
}
|
||||
|
||||
for k := range uniqueConstraints {
|
||||
if uniqueConstraints[k].Field == field {
|
||||
v := uniqueConstraints[k]
|
||||
return &v, stmt.Table
|
||||
return nil, &v, stmt.Table
|
||||
}
|
||||
}
|
||||
|
||||
for _, rel := range stmt.Schema.Relationships.Relations {
|
||||
if constraint := rel.ParseConstraint(); constraint != nil && rel.Field == field {
|
||||
return constraint, getTable(rel)
|
||||
return constraint, nil, getTable(rel)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil, stmt.Schema.Table
|
||||
return nil, nil, stmt.Schema.Table
|
||||
}
|
||||
|
||||
// CreateConstraint create constraint
|
||||
func (m Migrator) CreateConstraint(value interface{}, name string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name)
|
||||
constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
|
||||
if chk != nil {
|
||||
return m.DB.Exec(
|
||||
"ALTER TABLE ? ADD CONSTRAINT ? CHECK (?)",
|
||||
m.CurrentTable(stmt), clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint},
|
||||
).Error
|
||||
}
|
||||
|
||||
if constraint != nil {
|
||||
vars := []interface{}{clause.Table{Name: table}}
|
||||
if stmt.TableExpr != nil {
|
||||
vars[0] = stmt.TableExpr
|
||||
}
|
||||
sql, values := constraint.Build()
|
||||
sql, values := buildConstraint(constraint)
|
||||
return m.DB.Exec("ALTER TABLE ? ADD "+sql, append(vars, values...)...).Error
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
@ -768,9 +617,11 @@ func (m Migrator) CreateConstraint(value interface{}, name string) error {
|
||||
// DropConstraint drop constraint
|
||||
func (m Migrator) DropConstraint(value interface{}, name string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name)
|
||||
constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
|
||||
if constraint != nil {
|
||||
name = constraint.GetName()
|
||||
name = constraint.Name
|
||||
} else if chk != nil {
|
||||
name = chk.Name
|
||||
}
|
||||
return m.DB.Exec("ALTER TABLE ? DROP CONSTRAINT ?", clause.Table{Name: table}, clause.Column{Name: name}).Error
|
||||
})
|
||||
@ -781,9 +632,11 @@ func (m Migrator) HasConstraint(value interface{}, name string) bool {
|
||||
var count int64
|
||||
m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
currentDatabase := m.DB.Migrator().CurrentDatabase()
|
||||
constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name)
|
||||
constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
|
||||
if constraint != nil {
|
||||
name = constraint.GetName()
|
||||
name = constraint.Name
|
||||
} else if chk != nil {
|
||||
name = chk.Name
|
||||
}
|
||||
|
||||
return m.DB.Raw(
|
||||
@ -825,9 +678,6 @@ type BuildIndexOptionsInterface interface {
|
||||
// CreateIndex create index `name`
|
||||
func (m Migrator) CreateIndex(value interface{}, name string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
if stmt.Schema == nil {
|
||||
return errors.New("failed to get schema")
|
||||
}
|
||||
if idx := stmt.Schema.LookIndex(name); idx != nil {
|
||||
opts := m.DB.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt)
|
||||
values := []interface{}{clause.Column{Name: idx.Name}, m.CurrentTable(stmt), opts}
|
||||
@ -860,10 +710,8 @@ func (m Migrator) CreateIndex(value interface{}, name string) error {
|
||||
// DropIndex drop index `name`
|
||||
func (m Migrator) DropIndex(value interface{}, name string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
if stmt.Schema != nil {
|
||||
if idx := stmt.Schema.LookIndex(name); idx != nil {
|
||||
name = idx.Name
|
||||
}
|
||||
if idx := stmt.Schema.LookIndex(name); idx != nil {
|
||||
name = idx.Name
|
||||
}
|
||||
|
||||
return m.DB.Exec("DROP INDEX ? ON ?", clause.Column{Name: name}, m.CurrentTable(stmt)).Error
|
||||
@ -875,10 +723,8 @@ func (m Migrator) HasIndex(value interface{}, name string) bool {
|
||||
var count int64
|
||||
m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
currentDatabase := m.DB.Migrator().CurrentDatabase()
|
||||
if stmt.Schema != nil {
|
||||
if idx := stmt.Schema.LookIndex(name); idx != nil {
|
||||
name = idx.Name
|
||||
}
|
||||
if idx := stmt.Schema.LookIndex(name); idx != nil {
|
||||
name = idx.Name
|
||||
}
|
||||
|
||||
return m.DB.Raw(
|
||||
@ -936,31 +782,26 @@ func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []i
|
||||
}
|
||||
parsedSchemas[dep.Statement.Schema] = true
|
||||
|
||||
if !m.DB.IgnoreRelationshipsWhenMigrating {
|
||||
for _, rel := range dep.Schema.Relationships.Relations {
|
||||
if rel.Field.IgnoreMigration {
|
||||
continue
|
||||
}
|
||||
if c := rel.ParseConstraint(); c != nil && c.Schema == dep.Statement.Schema && c.Schema != c.ReferenceSchema {
|
||||
dep.Depends = append(dep.Depends, c.ReferenceSchema)
|
||||
}
|
||||
for _, rel := range dep.Schema.Relationships.Relations {
|
||||
if c := rel.ParseConstraint(); c != nil && c.Schema == dep.Statement.Schema && c.Schema != c.ReferenceSchema {
|
||||
dep.Depends = append(dep.Depends, c.ReferenceSchema)
|
||||
}
|
||||
|
||||
if rel.Type == schema.HasOne || rel.Type == schema.HasMany {
|
||||
beDependedOn[rel.FieldSchema] = true
|
||||
}
|
||||
if rel.Type == schema.HasOne || rel.Type == schema.HasMany {
|
||||
beDependedOn[rel.FieldSchema] = true
|
||||
}
|
||||
|
||||
if rel.JoinTable != nil {
|
||||
// append join value
|
||||
defer func(rel *schema.Relationship, joinValue interface{}) {
|
||||
if !beDependedOn[rel.FieldSchema] {
|
||||
dep.Depends = append(dep.Depends, rel.FieldSchema)
|
||||
} else {
|
||||
fieldValue := reflect.New(rel.FieldSchema.ModelType).Interface()
|
||||
parseDependence(fieldValue, autoAdd)
|
||||
}
|
||||
parseDependence(joinValue, autoAdd)
|
||||
}(rel, reflect.New(rel.JoinTable.ModelType).Interface())
|
||||
}
|
||||
if rel.JoinTable != nil {
|
||||
// append join value
|
||||
defer func(rel *schema.Relationship, joinValue interface{}) {
|
||||
if !beDependedOn[rel.FieldSchema] {
|
||||
dep.Depends = append(dep.Depends, rel.FieldSchema)
|
||||
} else {
|
||||
fieldValue := reflect.New(rel.FieldSchema.ModelType).Interface()
|
||||
parseDependence(fieldValue, autoAdd)
|
||||
}
|
||||
parseDependence(joinValue, autoAdd)
|
||||
}(rel, reflect.New(rel.JoinTable.ModelType).Interface())
|
||||
}
|
||||
}
|
||||
|
||||
@ -1022,13 +863,3 @@ func (m Migrator) CurrentTable(stmt *gorm.Statement) interface{} {
|
||||
func (m Migrator) GetIndexes(dst interface{}) ([]gorm.Index, error) {
|
||||
return nil, errors.New("not support")
|
||||
}
|
||||
|
||||
// GetTypeAliases return database type aliases
|
||||
func (m Migrator) GetTypeAliases(databaseTypeName string) []string {
|
||||
return nil
|
||||
}
|
||||
|
||||
// TableType return tableType gorm.TableType and execErr error
|
||||
func (m Migrator) TableType(dst interface{}) (gorm.TableType, error) {
|
||||
return nil, errors.New("not support")
|
||||
}
|
||||
|
||||
@ -1,33 +0,0 @@
|
||||
package migrator
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
// TableType table type implements TableType interface
|
||||
type TableType struct {
|
||||
SchemaValue string
|
||||
NameValue string
|
||||
TypeValue string
|
||||
CommentValue sql.NullString
|
||||
}
|
||||
|
||||
// Schema returns the schema of the table.
|
||||
func (ct TableType) Schema() string {
|
||||
return ct.SchemaValue
|
||||
}
|
||||
|
||||
// Name returns the name of the table.
|
||||
func (ct TableType) Name() string {
|
||||
return ct.NameValue
|
||||
}
|
||||
|
||||
// Type returns the type of the table.
|
||||
func (ct TableType) Type() string {
|
||||
return ct.TypeValue
|
||||
}
|
||||
|
||||
// Comment returns the comment of current table.
|
||||
func (ct TableType) Comment() (comment string, ok bool) {
|
||||
return ct.CommentValue.String, ct.CommentValue.Valid
|
||||
}
|
||||
7
model.go
7
model.go
@ -4,10 +4,9 @@ import "time"
|
||||
|
||||
// Model a basic GoLang struct which includes the following fields: ID, CreatedAt, UpdatedAt, DeletedAt
|
||||
// It may be embedded into your model or you may build your own model without it
|
||||
//
|
||||
// type User struct {
|
||||
// gorm.Model
|
||||
// }
|
||||
// type User struct {
|
||||
// gorm.Model
|
||||
// }
|
||||
type Model struct {
|
||||
ID uint `gorm:"primarykey"`
|
||||
CreatedAt time.Time
|
||||
|
||||
154
prepare_stmt.go
154
prepare_stmt.go
@ -3,86 +3,70 @@ package gorm
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"reflect"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm/internal/stmt_store"
|
||||
)
|
||||
|
||||
type Stmt struct {
|
||||
*sql.Stmt
|
||||
Transaction bool
|
||||
}
|
||||
|
||||
type PreparedStmtDB struct {
|
||||
Stmts stmt_store.Store
|
||||
Mux *sync.RWMutex
|
||||
Stmts map[string]Stmt
|
||||
PreparedSQL []string
|
||||
Mux *sync.RWMutex
|
||||
ConnPool
|
||||
}
|
||||
|
||||
// NewPreparedStmtDB creates and initializes a new instance of PreparedStmtDB.
|
||||
//
|
||||
// Parameters:
|
||||
// - connPool: A connection pool that implements the ConnPool interface, used for managing database connections.
|
||||
// - maxSize: The maximum number of prepared statements that can be stored in the statement store.
|
||||
// - ttl: The time-to-live duration for each prepared statement in the store. Statements older than this duration will be automatically removed.
|
||||
//
|
||||
// Returns:
|
||||
// - A pointer to a PreparedStmtDB instance, which manages prepared statements using the provided connection pool and configuration.
|
||||
func NewPreparedStmtDB(connPool ConnPool, maxSize int, ttl time.Duration) *PreparedStmtDB {
|
||||
return &PreparedStmtDB{
|
||||
ConnPool: connPool, // Assigns the provided connection pool to manage database connections.
|
||||
Stmts: stmt_store.New(maxSize, ttl), // Initializes a new statement store with the specified maximum size and TTL.
|
||||
Mux: &sync.RWMutex{}, // Sets up a read-write mutex for synchronizing access to the statement store.
|
||||
}
|
||||
}
|
||||
|
||||
// GetDBConn returns the underlying *sql.DB connection
|
||||
func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) {
|
||||
if sqldb, ok := db.ConnPool.(*sql.DB); ok {
|
||||
return sqldb, nil
|
||||
}
|
||||
|
||||
if dbConnector, ok := db.ConnPool.(GetDBConnector); ok && dbConnector != nil {
|
||||
return dbConnector.GetDBConn()
|
||||
}
|
||||
|
||||
if sqldb, ok := db.ConnPool.(*sql.DB); ok {
|
||||
return sqldb, nil
|
||||
}
|
||||
|
||||
return nil, ErrInvalidDB
|
||||
}
|
||||
|
||||
// Close closes all prepared statements in the store
|
||||
func (db *PreparedStmtDB) Close() {
|
||||
db.Mux.Lock()
|
||||
defer db.Mux.Unlock()
|
||||
|
||||
for _, key := range db.Stmts.Keys() {
|
||||
db.Stmts.Delete(key)
|
||||
for _, query := range db.PreparedSQL {
|
||||
if stmt, ok := db.Stmts[query]; ok {
|
||||
delete(db.Stmts, query)
|
||||
go stmt.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Reset Deprecated use Close instead
|
||||
func (db *PreparedStmtDB) Reset() {
|
||||
db.Close()
|
||||
}
|
||||
|
||||
func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (_ *stmt_store.Stmt, err error) {
|
||||
func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) {
|
||||
db.Mux.RLock()
|
||||
if db.Stmts != nil {
|
||||
if stmt, ok := db.Stmts.Get(query); ok && (!stmt.Transaction || isTransaction) {
|
||||
db.Mux.RUnlock()
|
||||
return stmt, stmt.Error()
|
||||
}
|
||||
if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) {
|
||||
db.Mux.RUnlock()
|
||||
return stmt, nil
|
||||
}
|
||||
db.Mux.RUnlock()
|
||||
|
||||
// retry
|
||||
db.Mux.Lock()
|
||||
if db.Stmts != nil {
|
||||
if stmt, ok := db.Stmts.Get(query); ok && (!stmt.Transaction || isTransaction) {
|
||||
db.Mux.Unlock()
|
||||
return stmt, stmt.Error()
|
||||
}
|
||||
defer db.Mux.Unlock()
|
||||
|
||||
// double check
|
||||
if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) {
|
||||
return stmt, nil
|
||||
} else if ok {
|
||||
go stmt.Close()
|
||||
}
|
||||
|
||||
return db.Stmts.New(ctx, query, isTransaction, conn, db.Mux)
|
||||
stmt, err := conn.PrepareContext(ctx, query)
|
||||
if err == nil {
|
||||
db.Stmts[query] = Stmt{Stmt: stmt, Transaction: isTransaction}
|
||||
db.PreparedSQL = append(db.PreparedSQL, query)
|
||||
}
|
||||
|
||||
return db.Stmts[query], err
|
||||
}
|
||||
|
||||
func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (ConnPool, error) {
|
||||
@ -90,19 +74,6 @@ func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (Conn
|
||||
tx, err := beginner.BeginTx(ctx, opt)
|
||||
return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err
|
||||
}
|
||||
|
||||
beginner, ok := db.ConnPool.(ConnPoolBeginner)
|
||||
if !ok {
|
||||
return nil, ErrInvalidTransaction
|
||||
}
|
||||
|
||||
connPool, err := beginner.BeginTx(ctx, opt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if tx, ok := connPool.(Tx); ok {
|
||||
return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, nil
|
||||
}
|
||||
return nil, ErrInvalidTransaction
|
||||
}
|
||||
|
||||
@ -110,8 +81,11 @@ func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args ..
|
||||
stmt, err := db.prepare(ctx, db.ConnPool, false, query)
|
||||
if err == nil {
|
||||
result, err = stmt.ExecContext(ctx, args...)
|
||||
if errors.Is(err, driver.ErrBadConn) {
|
||||
db.Stmts.Delete(query)
|
||||
if err != nil {
|
||||
db.Mux.Lock()
|
||||
defer db.Mux.Unlock()
|
||||
go stmt.Close()
|
||||
delete(db.Stmts, query)
|
||||
}
|
||||
}
|
||||
return result, err
|
||||
@ -121,8 +95,12 @@ func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args .
|
||||
stmt, err := db.prepare(ctx, db.ConnPool, false, query)
|
||||
if err == nil {
|
||||
rows, err = stmt.QueryContext(ctx, args...)
|
||||
if errors.Is(err, driver.ErrBadConn) {
|
||||
db.Stmts.Delete(query)
|
||||
if err != nil {
|
||||
db.Mux.Lock()
|
||||
defer db.Mux.Unlock()
|
||||
|
||||
go stmt.Close()
|
||||
delete(db.Stmts, query)
|
||||
}
|
||||
}
|
||||
return rows, err
|
||||
@ -136,32 +114,20 @@ func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string, arg
|
||||
return &sql.Row{}
|
||||
}
|
||||
|
||||
func (db *PreparedStmtDB) Ping() error {
|
||||
conn, err := db.GetDBConn()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return conn.Ping()
|
||||
}
|
||||
|
||||
type PreparedStmtTX struct {
|
||||
Tx
|
||||
PreparedStmtDB *PreparedStmtDB
|
||||
}
|
||||
|
||||
func (db *PreparedStmtTX) GetDBConn() (*sql.DB, error) {
|
||||
return db.PreparedStmtDB.GetDBConn()
|
||||
}
|
||||
|
||||
func (tx *PreparedStmtTX) Commit() error {
|
||||
if tx.Tx != nil && !reflect.ValueOf(tx.Tx).IsNil() {
|
||||
if tx.Tx != nil {
|
||||
return tx.Tx.Commit()
|
||||
}
|
||||
return ErrInvalidTransaction
|
||||
}
|
||||
|
||||
func (tx *PreparedStmtTX) Rollback() error {
|
||||
if tx.Tx != nil && !reflect.ValueOf(tx.Tx).IsNil() {
|
||||
if tx.Tx != nil {
|
||||
return tx.Tx.Rollback()
|
||||
}
|
||||
return ErrInvalidTransaction
|
||||
@ -171,8 +137,12 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ..
|
||||
stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
|
||||
if err == nil {
|
||||
result, err = tx.Tx.StmtContext(ctx, stmt.Stmt).ExecContext(ctx, args...)
|
||||
if errors.Is(err, driver.ErrBadConn) {
|
||||
tx.PreparedStmtDB.Stmts.Delete(query)
|
||||
if err != nil {
|
||||
tx.PreparedStmtDB.Mux.Lock()
|
||||
defer tx.PreparedStmtDB.Mux.Unlock()
|
||||
|
||||
go stmt.Close()
|
||||
delete(tx.PreparedStmtDB.Stmts, query)
|
||||
}
|
||||
}
|
||||
return result, err
|
||||
@ -182,8 +152,12 @@ func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args .
|
||||
stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
|
||||
if err == nil {
|
||||
rows, err = tx.Tx.StmtContext(ctx, stmt.Stmt).QueryContext(ctx, args...)
|
||||
if errors.Is(err, driver.ErrBadConn) {
|
||||
tx.PreparedStmtDB.Stmts.Delete(query)
|
||||
if err != nil {
|
||||
tx.PreparedStmtDB.Mux.Lock()
|
||||
defer tx.PreparedStmtDB.Mux.Unlock()
|
||||
|
||||
go stmt.Close()
|
||||
delete(tx.PreparedStmtDB.Stmts, query)
|
||||
}
|
||||
}
|
||||
return rows, err
|
||||
@ -196,11 +170,3 @@ func (tx *PreparedStmtTX) QueryRowContext(ctx context.Context, query string, arg
|
||||
}
|
||||
return &sql.Row{}
|
||||
}
|
||||
|
||||
func (tx *PreparedStmtTX) Ping() error {
|
||||
conn, err := tx.GetDBConn()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return conn.Ping()
|
||||
}
|
||||
|
||||
162
scan.go
162
scan.go
@ -8,7 +8,6 @@ import (
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm/schema"
|
||||
"gorm.io/gorm/utils"
|
||||
)
|
||||
|
||||
// prepareValues prepare values slice
|
||||
@ -16,7 +15,7 @@ func prepareValues(values []interface{}, db *DB, columnTypes []*sql.ColumnType,
|
||||
if db.Statement.Schema != nil {
|
||||
for idx, name := range columns {
|
||||
if field := db.Statement.Schema.LookUpField(name); field != nil {
|
||||
values[idx] = reflect.New(reflect.PointerTo(field.FieldType)).Interface()
|
||||
values[idx] = reflect.New(reflect.PtrTo(field.FieldType)).Interface()
|
||||
continue
|
||||
}
|
||||
values[idx] = new(interface{})
|
||||
@ -24,7 +23,7 @@ func prepareValues(values []interface{}, db *DB, columnTypes []*sql.ColumnType,
|
||||
} else if len(columnTypes) > 0 {
|
||||
for idx, columnType := range columnTypes {
|
||||
if columnType.ScanType() != nil {
|
||||
values[idx] = reflect.New(reflect.PointerTo(columnType.ScanType())).Interface()
|
||||
values[idx] = reflect.New(reflect.PtrTo(columnType.ScanType())).Interface()
|
||||
} else {
|
||||
values[idx] = new(interface{})
|
||||
}
|
||||
@ -51,7 +50,7 @@ func scanIntoMap(mapValue map[string]interface{}, values []interface{}, columns
|
||||
}
|
||||
}
|
||||
|
||||
func (db *DB) scanIntoStruct(rows Rows, reflectValue reflect.Value, values []interface{}, fields []*schema.Field, joinFields [][]*schema.Field) {
|
||||
func (db *DB) scanIntoStruct(rows Rows, reflectValue reflect.Value, values []interface{}, fields []*schema.Field, joinFields [][2]*schema.Field) {
|
||||
for idx, field := range fields {
|
||||
if field != nil {
|
||||
values[idx] = field.NewValuePool.Get()
|
||||
@ -66,45 +65,29 @@ func (db *DB) scanIntoStruct(rows Rows, reflectValue reflect.Value, values []int
|
||||
|
||||
db.RowsAffected++
|
||||
db.AddError(rows.Scan(values...))
|
||||
joinedNestedSchemaMap := make(map[string]interface{})
|
||||
|
||||
joinedSchemaMap := make(map[*schema.Field]interface{})
|
||||
for idx, field := range fields {
|
||||
if field == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if len(joinFields) == 0 || len(joinFields[idx]) == 0 {
|
||||
if len(joinFields) == 0 || joinFields[idx][0] == nil {
|
||||
db.AddError(field.Set(db.Statement.Context, reflectValue, values[idx]))
|
||||
} else { // joinFields count is larger than 2 when using join
|
||||
var isNilPtrValue bool
|
||||
var relValue reflect.Value
|
||||
// does not contain raw dbname
|
||||
nestedJoinSchemas := joinFields[idx][:len(joinFields[idx])-1]
|
||||
// current reflect value
|
||||
currentReflectValue := reflectValue
|
||||
fullRels := make([]string, 0, len(nestedJoinSchemas))
|
||||
for _, joinSchema := range nestedJoinSchemas {
|
||||
fullRels = append(fullRels, joinSchema.Name)
|
||||
relValue = joinSchema.ReflectValueOf(db.Statement.Context, currentReflectValue)
|
||||
if relValue.Kind() == reflect.Ptr {
|
||||
fullRelsName := utils.JoinNestedRelationNames(fullRels)
|
||||
// same nested structure
|
||||
if _, ok := joinedNestedSchemaMap[fullRelsName]; !ok {
|
||||
if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() {
|
||||
isNilPtrValue = true
|
||||
break
|
||||
}
|
||||
|
||||
relValue.Set(reflect.New(relValue.Type().Elem()))
|
||||
joinedNestedSchemaMap[fullRelsName] = nil
|
||||
} else {
|
||||
joinSchema := joinFields[idx][0]
|
||||
relValue := joinSchema.ReflectValueOf(db.Statement.Context, reflectValue)
|
||||
if relValue.Kind() == reflect.Ptr {
|
||||
if _, ok := joinedSchemaMap[joinSchema]; !ok {
|
||||
if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() {
|
||||
continue
|
||||
}
|
||||
}
|
||||
currentReflectValue = relValue
|
||||
}
|
||||
|
||||
if !isNilPtrValue { // ignore if value is nil
|
||||
f := joinFields[idx][len(joinFields[idx])-1]
|
||||
db.AddError(f.Set(db.Statement.Context, relValue, values[idx]))
|
||||
relValue.Set(reflect.New(relValue.Type().Elem()))
|
||||
joinedSchemaMap[joinSchema] = nil
|
||||
}
|
||||
}
|
||||
db.AddError(joinFields[idx][1].Set(db.Statement.Context, relValue, values[idx]))
|
||||
}
|
||||
|
||||
// release data to pool
|
||||
@ -132,15 +115,6 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
|
||||
onConflictDonothing = mode&ScanOnConflictDoNothing != 0
|
||||
)
|
||||
|
||||
if len(db.Statement.ColumnMapping) > 0 {
|
||||
for i, column := range columns {
|
||||
v, ok := db.Statement.ColumnMapping[column]
|
||||
if ok {
|
||||
columns[i] = v
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
db.RowsAffected = 0
|
||||
|
||||
switch dest := db.Statement.Dest.(type) {
|
||||
@ -189,10 +163,11 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
|
||||
}
|
||||
default:
|
||||
var (
|
||||
fields = make([]*schema.Field, len(columns))
|
||||
joinFields [][]*schema.Field
|
||||
sch = db.Statement.Schema
|
||||
reflectValue = db.Statement.ReflectValue
|
||||
fields = make([]*schema.Field, len(columns))
|
||||
selectedColumnsMap = make(map[string]int, len(columns))
|
||||
joinFields [][2]*schema.Field
|
||||
sch = db.Statement.Schema
|
||||
reflectValue = db.Statement.ReflectValue
|
||||
)
|
||||
|
||||
if reflectValue.Kind() == reflect.Interface {
|
||||
@ -225,61 +200,42 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
|
||||
|
||||
// Not Pluck
|
||||
if sch != nil {
|
||||
matchedFieldCount := make(map[string]int, len(columns))
|
||||
schFieldsCount := len(sch.Fields)
|
||||
for idx, column := range columns {
|
||||
if field := sch.LookUpField(column); field != nil && field.Readable {
|
||||
fields[idx] = field
|
||||
if count, ok := matchedFieldCount[column]; ok {
|
||||
// handle duplicate fields
|
||||
for _, selectField := range sch.Fields {
|
||||
if selectField.DBName == column && selectField.Readable {
|
||||
if count == 0 {
|
||||
matchedFieldCount[column]++
|
||||
if curIndex, ok := selectedColumnsMap[column]; ok {
|
||||
fields[idx] = field // handle duplicate fields
|
||||
offset := curIndex + 1
|
||||
// handle sch inconsistent with database
|
||||
// like Raw(`...`).Scan
|
||||
if schFieldsCount > offset {
|
||||
for fieldIndex, selectField := range sch.Fields[offset:] {
|
||||
if selectField.DBName == column && selectField.Readable {
|
||||
selectedColumnsMap[column] = curIndex + fieldIndex + 1
|
||||
fields[idx] = selectField
|
||||
break
|
||||
}
|
||||
count--
|
||||
}
|
||||
}
|
||||
} else {
|
||||
matchedFieldCount[column] = 1
|
||||
fields[idx] = field
|
||||
selectedColumnsMap[column] = idx
|
||||
}
|
||||
} else if names := utils.SplitNestedRelationName(column); len(names) > 1 { // has nested relation
|
||||
aliasName := utils.JoinNestedRelationNames(names[0 : len(names)-1])
|
||||
for _, join := range db.Statement.Joins {
|
||||
if join.Alias == aliasName {
|
||||
names = append(strings.Split(join.Name, "."), names[len(names)-1])
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
} else if names := strings.Split(column, "__"); len(names) > 1 {
|
||||
if rel, ok := sch.Relationships.Relations[names[0]]; ok {
|
||||
subNameCount := len(names)
|
||||
// nested relation fields
|
||||
relFields := make([]*schema.Field, 0, subNameCount-1)
|
||||
relFields = append(relFields, rel.Field)
|
||||
for _, name := range names[1 : subNameCount-1] {
|
||||
rel = rel.FieldSchema.Relationships.Relations[name]
|
||||
relFields = append(relFields, rel.Field)
|
||||
}
|
||||
// latest name is raw dbname
|
||||
dbName := names[subNameCount-1]
|
||||
if field := rel.FieldSchema.LookUpField(dbName); field != nil && field.Readable {
|
||||
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
|
||||
fields[idx] = field
|
||||
|
||||
if len(joinFields) == 0 {
|
||||
joinFields = make([][]*schema.Field, len(columns))
|
||||
joinFields = make([][2]*schema.Field, len(columns))
|
||||
}
|
||||
relFields = append(relFields, field)
|
||||
joinFields[idx] = relFields
|
||||
joinFields[idx] = [2]*schema.Field{rel.Field, field}
|
||||
continue
|
||||
}
|
||||
}
|
||||
var val interface{}
|
||||
values[idx] = &val
|
||||
values[idx] = &sql.RawBytes{}
|
||||
} else {
|
||||
var val interface{}
|
||||
values[idx] = &val
|
||||
values[idx] = &sql.RawBytes{}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -287,24 +243,12 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
|
||||
|
||||
switch reflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
var (
|
||||
elem reflect.Value
|
||||
isArrayKind = reflectValue.Kind() == reflect.Array
|
||||
)
|
||||
var elem reflect.Value
|
||||
recyclableStruct := reflect.New(reflectValueType)
|
||||
|
||||
if !update || reflectValue.Len() == 0 {
|
||||
update = false
|
||||
if isArrayKind {
|
||||
db.Statement.ReflectValue.Set(reflect.Zero(reflectValue.Type()))
|
||||
} else {
|
||||
// if the slice cap is externally initialized, the externally initialized slice is directly used here
|
||||
if reflectValue.Cap() == 0 {
|
||||
db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20))
|
||||
} else {
|
||||
reflectValue.SetLen(0)
|
||||
db.Statement.ReflectValue.Set(reflectValue)
|
||||
}
|
||||
}
|
||||
db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20))
|
||||
}
|
||||
|
||||
for initialized || rows.Next() {
|
||||
@ -325,21 +269,20 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
|
||||
}
|
||||
}
|
||||
} else {
|
||||
elem = reflect.New(reflectValueType)
|
||||
if isPtr && db.RowsAffected > 0 {
|
||||
elem = reflect.New(reflectValueType)
|
||||
} else {
|
||||
elem = recyclableStruct
|
||||
}
|
||||
}
|
||||
|
||||
db.scanIntoStruct(rows, elem, values, fields, joinFields)
|
||||
|
||||
if !update {
|
||||
if !isPtr {
|
||||
elem = elem.Elem()
|
||||
}
|
||||
if isArrayKind {
|
||||
if reflectValue.Len() >= int(db.RowsAffected) {
|
||||
reflectValue.Index(int(db.RowsAffected - 1)).Set(elem)
|
||||
}
|
||||
} else {
|
||||
if isPtr {
|
||||
reflectValue = reflect.Append(reflectValue, elem)
|
||||
} else {
|
||||
reflectValue = reflect.Append(reflectValue, elem.Elem())
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -349,9 +292,6 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
|
||||
}
|
||||
case reflect.Struct, reflect.Ptr:
|
||||
if initialized || rows.Next() {
|
||||
if mode == ScanInitialized && reflectValue.Kind() == reflect.Struct {
|
||||
db.Statement.ReflectValue.Set(reflect.Zero(reflectValue.Type()))
|
||||
}
|
||||
db.scanIntoStruct(rows, reflectValue, values, fields, joinFields)
|
||||
}
|
||||
default:
|
||||
|
||||
35
schema/check.go
Normal file
35
schema/check.go
Normal file
@ -0,0 +1,35 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// reg match english letters and midline
|
||||
var regEnLetterAndMidline = regexp.MustCompile("^[A-Za-z-_]+$")
|
||||
|
||||
type Check struct {
|
||||
Name string
|
||||
Constraint string // length(phone) >= 10
|
||||
*Field
|
||||
}
|
||||
|
||||
// ParseCheckConstraints parse schema check constraints
|
||||
func (schema *Schema) ParseCheckConstraints() map[string]Check {
|
||||
checks := map[string]Check{}
|
||||
for _, field := range schema.FieldsByDBName {
|
||||
if chk := field.TagSettings["CHECK"]; chk != "" {
|
||||
names := strings.Split(chk, ",")
|
||||
if len(names) > 1 && regEnLetterAndMidline.MatchString(names[0]) {
|
||||
checks[names[0]] = Check{Name: names[0], Constraint: strings.Join(names[1:], ","), Field: field}
|
||||
} else {
|
||||
if names[0] == "" {
|
||||
chk = strings.Join(names[1:], ",")
|
||||
}
|
||||
name := schema.namer.CheckerName(schema.Table, field.DBName)
|
||||
checks[name] = Check{Name: name, Constraint: chk, Field: field}
|
||||
}
|
||||
}
|
||||
}
|
||||
return checks
|
||||
}
|
||||
@ -6,7 +6,6 @@ import (
|
||||
"testing"
|
||||
|
||||
"gorm.io/gorm/schema"
|
||||
"gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
type UserCheck struct {
|
||||
@ -21,7 +20,7 @@ func TestParseCheck(t *testing.T) {
|
||||
t.Fatalf("failed to parse user check, got error %v", err)
|
||||
}
|
||||
|
||||
results := map[string]schema.CheckConstraint{
|
||||
results := map[string]schema.Check{
|
||||
"name_checker": {
|
||||
Name: "name_checker",
|
||||
Constraint: "name <> 'jinzhu'",
|
||||
@ -54,31 +53,3 @@ func TestParseCheck(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseUniqueConstraints(t *testing.T) {
|
||||
type UserUnique struct {
|
||||
Name1 string `gorm:"unique"`
|
||||
Name2 string `gorm:"uniqueIndex"`
|
||||
}
|
||||
|
||||
user, err := schema.Parse(&UserUnique{}, &sync.Map{}, schema.NamingStrategy{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse user unique, got error %v", err)
|
||||
}
|
||||
constraints := user.ParseUniqueConstraints()
|
||||
|
||||
results := map[string]schema.UniqueConstraint{
|
||||
"uni_user_uniques_name1": {
|
||||
Name: "uni_user_uniques_name1",
|
||||
Field: &schema.Field{Name: "Name1", Unique: true},
|
||||
},
|
||||
}
|
||||
for k, result := range results {
|
||||
v, ok := constraints[k]
|
||||
if !ok {
|
||||
t.Errorf("Failed to found unique constraint %v from parsed constraints %+v", k, constraints)
|
||||
}
|
||||
tests.AssertObjEqual(t, result, v, "Name")
|
||||
tests.AssertObjEqual(t, result.Field, v.Field, "Name", "Unique", "UniqueIndex")
|
||||
}
|
||||
}
|
||||
@ -1,66 +0,0 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
// reg match english letters and midline
|
||||
var regEnLetterAndMidline = regexp.MustCompile(`^[\w-]+$`)
|
||||
|
||||
type CheckConstraint struct {
|
||||
Name string
|
||||
Constraint string // length(phone) >= 10
|
||||
*Field
|
||||
}
|
||||
|
||||
func (chk *CheckConstraint) GetName() string { return chk.Name }
|
||||
|
||||
func (chk *CheckConstraint) Build() (sql string, vars []interface{}) {
|
||||
return "CONSTRAINT ? CHECK (?)", []interface{}{clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}}
|
||||
}
|
||||
|
||||
// ParseCheckConstraints parse schema check constraints
|
||||
func (schema *Schema) ParseCheckConstraints() map[string]CheckConstraint {
|
||||
checks := map[string]CheckConstraint{}
|
||||
for _, field := range schema.FieldsByDBName {
|
||||
if chk := field.TagSettings["CHECK"]; chk != "" {
|
||||
names := strings.Split(chk, ",")
|
||||
if len(names) > 1 && regEnLetterAndMidline.MatchString(names[0]) {
|
||||
checks[names[0]] = CheckConstraint{Name: names[0], Constraint: strings.Join(names[1:], ","), Field: field}
|
||||
} else {
|
||||
if names[0] == "" {
|
||||
chk = strings.Join(names[1:], ",")
|
||||
}
|
||||
name := schema.namer.CheckerName(schema.Table, field.DBName)
|
||||
checks[name] = CheckConstraint{Name: name, Constraint: chk, Field: field}
|
||||
}
|
||||
}
|
||||
}
|
||||
return checks
|
||||
}
|
||||
|
||||
type UniqueConstraint struct {
|
||||
Name string
|
||||
Field *Field
|
||||
}
|
||||
|
||||
func (uni *UniqueConstraint) GetName() string { return uni.Name }
|
||||
|
||||
func (uni *UniqueConstraint) Build() (sql string, vars []interface{}) {
|
||||
return "CONSTRAINT ? UNIQUE (?)", []interface{}{clause.Column{Name: uni.Name}, clause.Column{Name: uni.Field.DBName}}
|
||||
}
|
||||
|
||||
// ParseUniqueConstraints parse schema unique constraints
|
||||
func (schema *Schema) ParseUniqueConstraints() map[string]UniqueConstraint {
|
||||
uniques := make(map[string]UniqueConstraint)
|
||||
for _, field := range schema.Fields {
|
||||
if field.Unique {
|
||||
name := schema.namer.UniqueName(schema.Table, field.DBName)
|
||||
uniques[name] = UniqueConstraint{Name: name, Field: field}
|
||||
}
|
||||
}
|
||||
return uniques
|
||||
}
|
||||
141
schema/field.go
141
schema/field.go
@ -49,14 +49,11 @@ const (
|
||||
Bytes DataType = "bytes"
|
||||
)
|
||||
|
||||
const DefaultAutoIncrementIncrement int64 = 1
|
||||
|
||||
// Field is the representation of model schema's field
|
||||
type Field struct {
|
||||
Name string
|
||||
DBName string
|
||||
BindNames []string
|
||||
EmbeddedBindNames []string
|
||||
DataType DataType
|
||||
GORMDataType DataType
|
||||
PrimaryKey bool
|
||||
@ -90,16 +87,6 @@ type Field struct {
|
||||
Set func(context.Context, reflect.Value, interface{}) error
|
||||
Serializer SerializerInterface
|
||||
NewValuePool FieldNewValuePool
|
||||
|
||||
// In some db (e.g. MySQL), Unique and UniqueIndex are indistinguishable.
|
||||
// When a column has a (not Mul) UniqueIndex, Migrator always reports its gorm.ColumnType is Unique.
|
||||
// It causes field unnecessarily migration.
|
||||
// Therefore, we need to record the UniqueIndex on this column (exclude Mul UniqueIndex) for MigrateColumnUnique.
|
||||
UniqueIndex string
|
||||
}
|
||||
|
||||
func (field *Field) BindName() string {
|
||||
return strings.Join(field.BindNames, ".")
|
||||
}
|
||||
|
||||
// ParseField parses reflect.StructField to Field
|
||||
@ -113,7 +100,6 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
||||
Name: fieldStruct.Name,
|
||||
DBName: tagSetting["COLUMN"],
|
||||
BindNames: []string{fieldStruct.Name},
|
||||
EmbeddedBindNames: []string{fieldStruct.Name},
|
||||
FieldType: fieldStruct.Type,
|
||||
IndirectFieldType: fieldStruct.Type,
|
||||
StructField: fieldStruct,
|
||||
@ -129,7 +115,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
||||
NotNull: utils.CheckTruth(tagSetting["NOT NULL"], tagSetting["NOTNULL"]),
|
||||
Unique: utils.CheckTruth(tagSetting["UNIQUE"]),
|
||||
Comment: tagSetting["COMMENT"],
|
||||
AutoIncrementIncrement: DefaultAutoIncrementIncrement,
|
||||
AutoIncrementIncrement: 1,
|
||||
}
|
||||
|
||||
for field.IndirectFieldType.Kind() == reflect.Ptr {
|
||||
@ -188,7 +174,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
||||
field.DataType = String
|
||||
field.Serializer = v
|
||||
} else {
|
||||
serializerName := field.TagSettings["JSON"]
|
||||
var serializerName = field.TagSettings["JSON"]
|
||||
if serializerName == "" {
|
||||
serializerName = field.TagSettings["SERIALIZER"]
|
||||
}
|
||||
@ -318,10 +304,9 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
||||
}
|
||||
|
||||
if val, ok := field.TagSettings["TYPE"]; ok {
|
||||
lowerVal := DataType(strings.ToLower(val))
|
||||
switch lowerVal {
|
||||
switch DataType(strings.ToLower(val)) {
|
||||
case Bool, Int, Uint, Float, String, Time, Bytes:
|
||||
field.DataType = lowerVal
|
||||
field.DataType = DataType(strings.ToLower(val))
|
||||
default:
|
||||
field.DataType = DataType(val)
|
||||
}
|
||||
@ -406,9 +391,6 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
||||
ef.Schema = schema
|
||||
ef.OwnerSchema = field.EmbeddedSchema
|
||||
ef.BindNames = append([]string{fieldStruct.Name}, ef.BindNames...)
|
||||
if _, ok := field.TagSettings["EMBEDDED"]; ok || !fieldStruct.Anonymous {
|
||||
ef.EmbeddedBindNames = append([]string{fieldStruct.Name}, ef.EmbeddedBindNames...)
|
||||
}
|
||||
// index is negative means is pointer
|
||||
if field.FieldType.Kind() == reflect.Struct {
|
||||
ef.StructField.Index = append([]int{fieldStruct.Index[0]}, ef.StructField.Index...)
|
||||
@ -448,30 +430,21 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
||||
}
|
||||
|
||||
// create valuer, setter when parse struct
|
||||
func (field *Field) setupValuerAndSetter(modelType reflect.Type) {
|
||||
func (field *Field) setupValuerAndSetter() {
|
||||
// Setup NewValuePool
|
||||
field.setupNewValuePool()
|
||||
|
||||
// ValueOf returns field's value and if it is zero
|
||||
fieldIndex := field.StructField.Index[0]
|
||||
switch {
|
||||
case len(field.StructField.Index) == 1 && fieldIndex >= 0:
|
||||
field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) {
|
||||
v = reflect.Indirect(v)
|
||||
if v.Type() != modelType {
|
||||
fieldValue := v.FieldByName(field.Name)
|
||||
return fieldValue.Interface(), fieldValue.IsZero()
|
||||
}
|
||||
fieldValue := v.Field(fieldIndex)
|
||||
case len(field.StructField.Index) == 1 && fieldIndex > 0:
|
||||
field.ValueOf = func(ctx context.Context, value reflect.Value) (interface{}, bool) {
|
||||
fieldValue := reflect.Indirect(value).Field(fieldIndex)
|
||||
return fieldValue.Interface(), fieldValue.IsZero()
|
||||
}
|
||||
default:
|
||||
field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) {
|
||||
v = reflect.Indirect(v)
|
||||
if v.Type() != modelType {
|
||||
fieldValue := v.FieldByName(field.Name)
|
||||
return fieldValue.Interface(), fieldValue.IsZero()
|
||||
}
|
||||
for _, fieldIdx := range field.StructField.Index {
|
||||
if fieldIdx >= 0 {
|
||||
v = v.Field(fieldIdx)
|
||||
@ -513,20 +486,13 @@ func (field *Field) setupValuerAndSetter(modelType reflect.Type) {
|
||||
|
||||
// ReflectValueOf returns field's reflect value
|
||||
switch {
|
||||
case len(field.StructField.Index) == 1 && fieldIndex >= 0:
|
||||
field.ReflectValueOf = func(ctx context.Context, v reflect.Value) reflect.Value {
|
||||
v = reflect.Indirect(v)
|
||||
if v.Type() != modelType {
|
||||
return v.FieldByName(field.Name)
|
||||
}
|
||||
return v.Field(fieldIndex)
|
||||
case len(field.StructField.Index) == 1 && fieldIndex > 0:
|
||||
field.ReflectValueOf = func(ctx context.Context, value reflect.Value) reflect.Value {
|
||||
return reflect.Indirect(value).Field(fieldIndex)
|
||||
}
|
||||
default:
|
||||
field.ReflectValueOf = func(ctx context.Context, v reflect.Value) reflect.Value {
|
||||
v = reflect.Indirect(v)
|
||||
if v.Type() != modelType {
|
||||
return v.FieldByName(field.Name)
|
||||
}
|
||||
for idx, fieldIdx := range field.StructField.Index {
|
||||
if fieldIdx >= 0 {
|
||||
v = v.Field(fieldIdx)
|
||||
@ -614,6 +580,8 @@ func (field *Field) setupValuerAndSetter(modelType reflect.Type) {
|
||||
case **bool:
|
||||
if data != nil && *data != nil {
|
||||
field.ReflectValueOf(ctx, value).SetBool(**data)
|
||||
} else {
|
||||
field.ReflectValueOf(ctx, value).SetBool(false)
|
||||
}
|
||||
case bool:
|
||||
field.ReflectValueOf(ctx, value).SetBool(data)
|
||||
@ -633,22 +601,8 @@ func (field *Field) setupValuerAndSetter(modelType reflect.Type) {
|
||||
case **int64:
|
||||
if data != nil && *data != nil {
|
||||
field.ReflectValueOf(ctx, value).SetInt(**data)
|
||||
}
|
||||
case **int:
|
||||
if data != nil && *data != nil {
|
||||
field.ReflectValueOf(ctx, value).SetInt(int64(**data))
|
||||
}
|
||||
case **int8:
|
||||
if data != nil && *data != nil {
|
||||
field.ReflectValueOf(ctx, value).SetInt(int64(**data))
|
||||
}
|
||||
case **int16:
|
||||
if data != nil && *data != nil {
|
||||
field.ReflectValueOf(ctx, value).SetInt(int64(**data))
|
||||
}
|
||||
case **int32:
|
||||
if data != nil && *data != nil {
|
||||
field.ReflectValueOf(ctx, value).SetInt(int64(**data))
|
||||
} else {
|
||||
field.ReflectValueOf(ctx, value).SetInt(0)
|
||||
}
|
||||
case int64:
|
||||
field.ReflectValueOf(ctx, value).SetInt(data)
|
||||
@ -686,7 +640,7 @@ func (field *Field) setupValuerAndSetter(modelType reflect.Type) {
|
||||
if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond {
|
||||
field.ReflectValueOf(ctx, value).SetInt(data.UnixNano())
|
||||
} else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond {
|
||||
field.ReflectValueOf(ctx, value).SetInt(data.UnixMilli())
|
||||
field.ReflectValueOf(ctx, value).SetInt(data.UnixNano() / 1e6)
|
||||
} else {
|
||||
field.ReflectValueOf(ctx, value).SetInt(data.Unix())
|
||||
}
|
||||
@ -695,7 +649,7 @@ func (field *Field) setupValuerAndSetter(modelType reflect.Type) {
|
||||
if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond {
|
||||
field.ReflectValueOf(ctx, value).SetInt(data.UnixNano())
|
||||
} else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond {
|
||||
field.ReflectValueOf(ctx, value).SetInt(data.UnixMilli())
|
||||
field.ReflectValueOf(ctx, value).SetInt(data.UnixNano() / 1e6)
|
||||
} else {
|
||||
field.ReflectValueOf(ctx, value).SetInt(data.Unix())
|
||||
}
|
||||
@ -713,22 +667,8 @@ func (field *Field) setupValuerAndSetter(modelType reflect.Type) {
|
||||
case **uint64:
|
||||
if data != nil && *data != nil {
|
||||
field.ReflectValueOf(ctx, value).SetUint(**data)
|
||||
}
|
||||
case **uint:
|
||||
if data != nil && *data != nil {
|
||||
field.ReflectValueOf(ctx, value).SetUint(uint64(**data))
|
||||
}
|
||||
case **uint8:
|
||||
if data != nil && *data != nil {
|
||||
field.ReflectValueOf(ctx, value).SetUint(uint64(**data))
|
||||
}
|
||||
case **uint16:
|
||||
if data != nil && *data != nil {
|
||||
field.ReflectValueOf(ctx, value).SetUint(uint64(**data))
|
||||
}
|
||||
case **uint32:
|
||||
if data != nil && *data != nil {
|
||||
field.ReflectValueOf(ctx, value).SetUint(uint64(**data))
|
||||
} else {
|
||||
field.ReflectValueOf(ctx, value).SetUint(0)
|
||||
}
|
||||
case uint64:
|
||||
field.ReflectValueOf(ctx, value).SetUint(data)
|
||||
@ -760,7 +700,7 @@ func (field *Field) setupValuerAndSetter(modelType reflect.Type) {
|
||||
if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond {
|
||||
field.ReflectValueOf(ctx, value).SetUint(uint64(data.UnixNano()))
|
||||
} else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond {
|
||||
field.ReflectValueOf(ctx, value).SetUint(uint64(data.UnixMilli()))
|
||||
field.ReflectValueOf(ctx, value).SetUint(uint64(data.UnixNano() / 1e6))
|
||||
} else {
|
||||
field.ReflectValueOf(ctx, value).SetUint(uint64(data.Unix()))
|
||||
}
|
||||
@ -781,10 +721,8 @@ func (field *Field) setupValuerAndSetter(modelType reflect.Type) {
|
||||
case **float64:
|
||||
if data != nil && *data != nil {
|
||||
field.ReflectValueOf(ctx, value).SetFloat(**data)
|
||||
}
|
||||
case **float32:
|
||||
if data != nil && *data != nil {
|
||||
field.ReflectValueOf(ctx, value).SetFloat(float64(**data))
|
||||
} else {
|
||||
field.ReflectValueOf(ctx, value).SetFloat(0)
|
||||
}
|
||||
case float64:
|
||||
field.ReflectValueOf(ctx, value).SetFloat(data)
|
||||
@ -829,6 +767,8 @@ func (field *Field) setupValuerAndSetter(modelType reflect.Type) {
|
||||
case **string:
|
||||
if data != nil && *data != nil {
|
||||
field.ReflectValueOf(ctx, value).SetString(**data)
|
||||
} else {
|
||||
field.ReflectValueOf(ctx, value).SetString("")
|
||||
}
|
||||
case string:
|
||||
field.ReflectValueOf(ctx, value).SetString(data)
|
||||
@ -876,7 +816,7 @@ func (field *Field) setupValuerAndSetter(modelType reflect.Type) {
|
||||
field.Set = func(ctx context.Context, value reflect.Value, v interface{}) error {
|
||||
switch data := v.(type) {
|
||||
case **time.Time:
|
||||
if data != nil && *data != nil {
|
||||
if data != nil {
|
||||
field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(*data))
|
||||
}
|
||||
case time.Time:
|
||||
@ -912,12 +852,14 @@ func (field *Field) setupValuerAndSetter(modelType reflect.Type) {
|
||||
reflectV := reflect.ValueOf(v)
|
||||
if !reflectV.IsValid() {
|
||||
field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem())
|
||||
} else if reflectV.Kind() == reflect.Ptr && reflectV.IsNil() {
|
||||
return
|
||||
} else if reflectV.Type().AssignableTo(field.FieldType) {
|
||||
field.ReflectValueOf(ctx, value).Set(reflectV)
|
||||
} else if reflectV.Kind() == reflect.Ptr {
|
||||
return field.Set(ctx, value, reflectV.Elem().Interface())
|
||||
if reflectV.IsNil() || !reflectV.IsValid() {
|
||||
field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem())
|
||||
} else {
|
||||
return field.Set(ctx, value, reflectV.Elem().Interface())
|
||||
}
|
||||
} else {
|
||||
fieldValue := field.ReflectValueOf(ctx, value)
|
||||
if fieldValue.IsNil() {
|
||||
@ -938,12 +880,14 @@ func (field *Field) setupValuerAndSetter(modelType reflect.Type) {
|
||||
reflectV := reflect.ValueOf(v)
|
||||
if !reflectV.IsValid() {
|
||||
field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem())
|
||||
} else if reflectV.Kind() == reflect.Ptr && reflectV.IsNil() {
|
||||
return
|
||||
} else if reflectV.Type().AssignableTo(field.FieldType) {
|
||||
field.ReflectValueOf(ctx, value).Set(reflectV)
|
||||
} else if reflectV.Kind() == reflect.Ptr {
|
||||
return field.Set(ctx, value, reflectV.Elem().Interface())
|
||||
if reflectV.IsNil() || !reflectV.IsValid() {
|
||||
field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem())
|
||||
} else {
|
||||
return field.Set(ctx, value, reflectV.Elem().Interface())
|
||||
}
|
||||
} else {
|
||||
if valuer, ok := v.(driver.Valuer); ok {
|
||||
v, _ = valuer.Value()
|
||||
@ -972,8 +916,6 @@ func (field *Field) setupValuerAndSetter(modelType reflect.Type) {
|
||||
sameElemType = field.FieldType == reflect.ValueOf(field.Serializer).Type().Elem()
|
||||
}
|
||||
|
||||
serializerValue := reflect.Indirect(reflect.ValueOf(field.Serializer))
|
||||
serializerType := serializerValue.Type()
|
||||
field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) {
|
||||
if s, ok := v.(*serializer); ok {
|
||||
if s.fieldValue != nil {
|
||||
@ -981,12 +923,11 @@ func (field *Field) setupValuerAndSetter(modelType reflect.Type) {
|
||||
} else if err = s.Serializer.Scan(ctx, field, value, s.value); err == nil {
|
||||
if sameElemType {
|
||||
field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(s.Serializer).Elem())
|
||||
s.Serializer = reflect.New(reflect.Indirect(reflect.ValueOf(field.Serializer)).Type()).Interface().(SerializerInterface)
|
||||
} else if sameType {
|
||||
field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(s.Serializer))
|
||||
s.Serializer = reflect.New(reflect.Indirect(reflect.ValueOf(field.Serializer)).Type()).Interface().(SerializerInterface)
|
||||
}
|
||||
si := reflect.New(serializerType)
|
||||
si.Elem().Set(serializerValue)
|
||||
s.Serializer = si.Interface().(SerializerInterface)
|
||||
}
|
||||
} else {
|
||||
err = oldFieldSetter(ctx, value, v)
|
||||
@ -998,21 +939,17 @@ func (field *Field) setupValuerAndSetter(modelType reflect.Type) {
|
||||
|
||||
func (field *Field) setupNewValuePool() {
|
||||
if field.Serializer != nil {
|
||||
serializerValue := reflect.Indirect(reflect.ValueOf(field.Serializer))
|
||||
serializerType := serializerValue.Type()
|
||||
field.NewValuePool = &sync.Pool{
|
||||
New: func() interface{} {
|
||||
si := reflect.New(serializerType)
|
||||
si.Elem().Set(serializerValue)
|
||||
return &serializer{
|
||||
Field: field,
|
||||
Serializer: si.Interface().(SerializerInterface),
|
||||
Serializer: field.Serializer,
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
if field.NewValuePool == nil {
|
||||
field.NewValuePool = poolInitializer(reflect.PointerTo(field.IndirectFieldType))
|
||||
field.NewValuePool = poolInitializer(reflect.PtrTo(field.IndirectFieldType))
|
||||
}
|
||||
}
|
||||
|
||||
@ -13,8 +13,8 @@ type Index struct {
|
||||
Type string // btree, hash, gist, spgist, gin, and brin
|
||||
Where string
|
||||
Comment string
|
||||
Option string // WITH PARSER parser_name
|
||||
Fields []IndexOption // Note: IndexOption's Field maybe the same
|
||||
Option string // WITH PARSER parser_name
|
||||
Fields []IndexOption
|
||||
}
|
||||
|
||||
type IndexOption struct {
|
||||
@ -23,13 +23,12 @@ type IndexOption struct {
|
||||
Sort string // DESC, ASC
|
||||
Collate string
|
||||
Length int
|
||||
Priority int
|
||||
priority int
|
||||
}
|
||||
|
||||
// ParseIndexes parse schema indexes
|
||||
func (schema *Schema) ParseIndexes() []*Index {
|
||||
indexesByName := map[string]*Index{}
|
||||
indexes := []*Index{}
|
||||
func (schema *Schema) ParseIndexes() map[string]Index {
|
||||
indexes := map[string]Index{}
|
||||
|
||||
for _, field := range schema.Fields {
|
||||
if field.TagSettings["INDEX"] != "" || field.TagSettings["UNIQUEINDEX"] != "" {
|
||||
@ -39,12 +38,7 @@ func (schema *Schema) ParseIndexes() []*Index {
|
||||
break
|
||||
}
|
||||
for _, index := range fieldIndexes {
|
||||
idx := indexesByName[index.Name]
|
||||
if idx == nil {
|
||||
idx = &Index{Name: index.Name}
|
||||
indexesByName[index.Name] = idx
|
||||
indexes = append(indexes, idx)
|
||||
}
|
||||
idx := indexes[index.Name]
|
||||
idx.Name = index.Name
|
||||
if idx.Class == "" {
|
||||
idx.Class = index.Class
|
||||
@ -64,16 +58,14 @@ func (schema *Schema) ParseIndexes() []*Index {
|
||||
|
||||
idx.Fields = append(idx.Fields, index.Fields...)
|
||||
sort.Slice(idx.Fields, func(i, j int) bool {
|
||||
return idx.Fields[i].Priority < idx.Fields[j].Priority
|
||||
return idx.Fields[i].priority < idx.Fields[j].priority
|
||||
})
|
||||
|
||||
indexes[index.Name] = idx
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, index := range indexes {
|
||||
if index.Class == "UNIQUE" && len(index.Fields) == 1 {
|
||||
index.Fields[0].Field.UniqueIndex = index.Name
|
||||
}
|
||||
}
|
||||
|
||||
return indexes
|
||||
}
|
||||
|
||||
@ -82,12 +74,12 @@ func (schema *Schema) LookIndex(name string) *Index {
|
||||
indexes := schema.ParseIndexes()
|
||||
for _, index := range indexes {
|
||||
if index.Name == name {
|
||||
return index
|
||||
return &index
|
||||
}
|
||||
|
||||
for _, field := range index.Fields {
|
||||
if field.Name == name {
|
||||
return index
|
||||
return &index
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -105,7 +97,7 @@ func parseFieldIndexes(field *Field) (indexes []Index, err error) {
|
||||
var (
|
||||
name string
|
||||
tag = strings.Join(v[1:], ":")
|
||||
idx = strings.IndexByte(tag, ',')
|
||||
idx = strings.Index(tag, ",")
|
||||
tagSetting = strings.Join(strings.Split(tag, ",")[1:], ",")
|
||||
settings = ParseTagSetting(tagSetting, ",")
|
||||
length, _ = strconv.Atoi(settings["LENGTH"])
|
||||
@ -115,14 +107,17 @@ func parseFieldIndexes(field *Field) (indexes []Index, err error) {
|
||||
idx = len(tag)
|
||||
}
|
||||
|
||||
name = tag[0:idx]
|
||||
if idx != -1 {
|
||||
name = tag[0:idx]
|
||||
}
|
||||
|
||||
if name == "" {
|
||||
subName := field.Name
|
||||
const key = "COMPOSITE"
|
||||
if composite, found := settings[key]; found {
|
||||
if len(composite) == 0 || composite == key {
|
||||
err = fmt.Errorf(
|
||||
"the composite tag of %s.%s cannot be empty",
|
||||
"The composite tag of %s.%s cannot be empty",
|
||||
field.Schema.Name,
|
||||
field.Name)
|
||||
return
|
||||
@ -155,7 +150,7 @@ func parseFieldIndexes(field *Field) (indexes []Index, err error) {
|
||||
Sort: settings["SORT"],
|
||||
Collate: settings["COLLATE"],
|
||||
Length: length,
|
||||
Priority: priority,
|
||||
priority: priority,
|
||||
}},
|
||||
})
|
||||
}
|
||||
|
||||
@ -1,11 +1,11 @@
|
||||
package schema_test
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"gorm.io/gorm/schema"
|
||||
"gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
type UserIndex struct {
|
||||
@ -19,10 +19,6 @@ type UserIndex struct {
|
||||
OID int64 `gorm:"index:idx_id;index:idx_oid,unique"`
|
||||
MemberNumber string `gorm:"index:idx_id,priority:1"`
|
||||
Name7 string `gorm:"index:type"`
|
||||
Name8 string `gorm:"index:,length:10;index:,collate:utf8"`
|
||||
|
||||
CompName1 string `gorm:"index:,unique,composite:idx_compname_1,option:NULLS NOT DISTINCT;not null"`
|
||||
CompName2 string `gorm:"index:,composite:idx_compname_1"`
|
||||
|
||||
// Composite Index: Flattened structure.
|
||||
Data0A string `gorm:"index:,composite:comp_id0"`
|
||||
@ -61,17 +57,17 @@ func TestParseIndex(t *testing.T) {
|
||||
t.Fatalf("failed to parse user index, got error %v", err)
|
||||
}
|
||||
|
||||
results := []*schema.Index{
|
||||
{
|
||||
results := map[string]schema.Index{
|
||||
"idx_user_indices_name": {
|
||||
Name: "idx_user_indices_name",
|
||||
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name"}}},
|
||||
},
|
||||
{
|
||||
"idx_name": {
|
||||
Name: "idx_name",
|
||||
Class: "UNIQUE",
|
||||
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name2", UniqueIndex: "idx_name"}}},
|
||||
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name2"}}},
|
||||
},
|
||||
{
|
||||
"idx_user_indices_name3": {
|
||||
Name: "idx_user_indices_name3",
|
||||
Type: "btree",
|
||||
Where: "name3 != 'jinzhu'",
|
||||
@ -82,19 +78,19 @@ func TestParseIndex(t *testing.T) {
|
||||
Length: 10,
|
||||
}},
|
||||
},
|
||||
{
|
||||
"idx_user_indices_name4": {
|
||||
Name: "idx_user_indices_name4",
|
||||
Class: "UNIQUE",
|
||||
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name4", UniqueIndex: "idx_user_indices_name4"}}},
|
||||
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name4"}}},
|
||||
},
|
||||
{
|
||||
"idx_user_indices_name5": {
|
||||
Name: "idx_user_indices_name5",
|
||||
Class: "FULLTEXT",
|
||||
Comment: "hello , world",
|
||||
Where: "age > 10",
|
||||
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name5"}}},
|
||||
},
|
||||
{
|
||||
"profile": {
|
||||
Name: "profile",
|
||||
Comment: "hello , world",
|
||||
Where: "age > 10",
|
||||
@ -104,39 +100,21 @@ func TestParseIndex(t *testing.T) {
|
||||
Expression: "ABS(age)",
|
||||
}},
|
||||
},
|
||||
{
|
||||
"idx_id": {
|
||||
Name: "idx_id",
|
||||
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "MemberNumber"}}, {Field: &schema.Field{Name: "OID", UniqueIndex: "idx_oid"}}},
|
||||
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "MemberNumber"}}, {Field: &schema.Field{Name: "OID"}}},
|
||||
},
|
||||
{
|
||||
"idx_oid": {
|
||||
Name: "idx_oid",
|
||||
Class: "UNIQUE",
|
||||
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "OID", UniqueIndex: "idx_oid"}}},
|
||||
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "OID"}}},
|
||||
},
|
||||
{
|
||||
"type": {
|
||||
Name: "type",
|
||||
Type: "",
|
||||
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name7"}}},
|
||||
},
|
||||
{
|
||||
Name: "idx_user_indices_name8",
|
||||
Type: "",
|
||||
Fields: []schema.IndexOption{
|
||||
{Field: &schema.Field{Name: "Name8"}, Length: 10},
|
||||
// Note: Duplicate Columns
|
||||
{Field: &schema.Field{Name: "Name8"}, Collate: "utf8"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Class: "UNIQUE",
|
||||
Name: "idx_user_indices_idx_compname_1",
|
||||
Option: "NULLS NOT DISTINCT",
|
||||
Fields: []schema.IndexOption{
|
||||
{Field: &schema.Field{Name: "CompName1", NotNull: true}},
|
||||
{Field: &schema.Field{Name: "CompName2"}},
|
||||
},
|
||||
},
|
||||
{
|
||||
"idx_user_indices_comp_id0": {
|
||||
Name: "idx_user_indices_comp_id0",
|
||||
Type: "",
|
||||
Fields: []schema.IndexOption{{
|
||||
@ -145,7 +123,7 @@ func TestParseIndex(t *testing.T) {
|
||||
Field: &schema.Field{Name: "Data0B"},
|
||||
}},
|
||||
},
|
||||
{
|
||||
"idx_user_indices_comp_id1": {
|
||||
Name: "idx_user_indices_comp_id1",
|
||||
Fields: []schema.IndexOption{{
|
||||
Field: &schema.Field{Name: "Data1A"},
|
||||
@ -155,7 +133,7 @@ func TestParseIndex(t *testing.T) {
|
||||
Field: &schema.Field{Name: "Data1C"},
|
||||
}},
|
||||
},
|
||||
{
|
||||
"idx_user_indices_comp_id2": {
|
||||
Name: "idx_user_indices_comp_id2",
|
||||
Class: "UNIQUE",
|
||||
Fields: []schema.IndexOption{{
|
||||
@ -168,108 +146,37 @@ func TestParseIndex(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
CheckIndices(t, results, user.ParseIndexes())
|
||||
}
|
||||
indices := user.ParseIndexes()
|
||||
|
||||
func TestParseIndexWithUniqueIndexAndUnique(t *testing.T) {
|
||||
type IndexTest struct {
|
||||
FieldA string `gorm:"unique;index"` // unique and index
|
||||
FieldB string `gorm:"unique"` // unique
|
||||
for k, result := range results {
|
||||
v, ok := indices[k]
|
||||
if !ok {
|
||||
t.Fatalf("Failed to found index %v from parsed indices %+v", k, indices)
|
||||
}
|
||||
|
||||
FieldC string `gorm:"index:,unique"` // uniqueIndex
|
||||
FieldD string `gorm:"uniqueIndex;index"` // uniqueIndex and index
|
||||
|
||||
FieldE1 string `gorm:"uniqueIndex:uniq_field_e1_e2"` // mul uniqueIndex
|
||||
FieldE2 string `gorm:"uniqueIndex:uniq_field_e1_e2"`
|
||||
|
||||
FieldF1 string `gorm:"uniqueIndex:uniq_field_f1_f2;index"` // mul uniqueIndex and index
|
||||
FieldF2 string `gorm:"uniqueIndex:uniq_field_f1_f2;"`
|
||||
|
||||
FieldG string `gorm:"unique;uniqueIndex"` // unique and uniqueIndex
|
||||
|
||||
FieldH1 string `gorm:"unique;uniqueIndex:uniq_field_h1_h2"` // unique and mul uniqueIndex
|
||||
FieldH2 string `gorm:"uniqueIndex:uniq_field_h1_h2"` // unique and mul uniqueIndex
|
||||
}
|
||||
indexSchema, err := schema.Parse(&IndexTest{}, &sync.Map{}, schema.NamingStrategy{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse user index, got error %v", err)
|
||||
}
|
||||
indices := indexSchema.ParseIndexes()
|
||||
expectedIndices := []*schema.Index{
|
||||
{
|
||||
Name: "idx_index_tests_field_a",
|
||||
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "FieldA", Unique: true}}},
|
||||
},
|
||||
{
|
||||
Name: "idx_index_tests_field_c",
|
||||
Class: "UNIQUE",
|
||||
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "FieldC", UniqueIndex: "idx_index_tests_field_c"}}},
|
||||
},
|
||||
{
|
||||
Name: "idx_index_tests_field_d",
|
||||
Class: "UNIQUE",
|
||||
Fields: []schema.IndexOption{
|
||||
{Field: &schema.Field{Name: "FieldD"}},
|
||||
// Note: Duplicate Columns
|
||||
{Field: &schema.Field{Name: "FieldD"}},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "uniq_field_e1_e2",
|
||||
Class: "UNIQUE",
|
||||
Fields: []schema.IndexOption{
|
||||
{Field: &schema.Field{Name: "FieldE1"}},
|
||||
{Field: &schema.Field{Name: "FieldE2"}},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "uniq_field_f1_f2",
|
||||
Class: "UNIQUE",
|
||||
Fields: []schema.IndexOption{
|
||||
{Field: &schema.Field{Name: "FieldF1"}},
|
||||
{Field: &schema.Field{Name: "FieldF2"}},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "idx_index_tests_field_f1",
|
||||
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "FieldF1"}}},
|
||||
},
|
||||
{
|
||||
Name: "idx_index_tests_field_g",
|
||||
Class: "UNIQUE",
|
||||
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "FieldG", Unique: true, UniqueIndex: "idx_index_tests_field_g"}}},
|
||||
},
|
||||
{
|
||||
Name: "uniq_field_h1_h2",
|
||||
Class: "UNIQUE",
|
||||
Fields: []schema.IndexOption{
|
||||
{Field: &schema.Field{Name: "FieldH1", Unique: true}},
|
||||
{Field: &schema.Field{Name: "FieldH2"}},
|
||||
},
|
||||
},
|
||||
}
|
||||
CheckIndices(t, expectedIndices, indices)
|
||||
}
|
||||
|
||||
func CheckIndices(t *testing.T, expected, actual []*schema.Index) {
|
||||
if len(expected) != len(actual) {
|
||||
t.Errorf("expected %d indices, but got %d", len(expected), len(actual))
|
||||
return
|
||||
}
|
||||
|
||||
for i, ei := range expected {
|
||||
t.Run(ei.Name, func(t *testing.T) {
|
||||
ai := actual[i]
|
||||
tests.AssertObjEqual(t, ai, ei, "Name", "Class", "Type", "Where", "Comment", "Option")
|
||||
|
||||
if len(ei.Fields) != len(ai.Fields) {
|
||||
t.Errorf("expected index %q field length is %d but actual %d", ei.Name, len(ei.Fields), len(ai.Fields))
|
||||
return
|
||||
for _, name := range []string{"Name", "Class", "Type", "Where", "Comment", "Option"} {
|
||||
if reflect.ValueOf(result).FieldByName(name).Interface() != reflect.ValueOf(v).FieldByName(name).Interface() {
|
||||
t.Errorf(
|
||||
"index %v %v should equal, expects %v, got %v",
|
||||
k, name, reflect.ValueOf(result).FieldByName(name).Interface(), reflect.ValueOf(v).FieldByName(name).Interface(),
|
||||
)
|
||||
}
|
||||
for i, ef := range ei.Fields {
|
||||
af := ai.Fields[i]
|
||||
tests.AssertObjEqual(t, af, ef, "Name", "Unique", "UniqueIndex", "Expression", "Sort", "Collate", "Length", "NotNull")
|
||||
}
|
||||
|
||||
for idx, ef := range result.Fields {
|
||||
rf := v.Fields[idx]
|
||||
if rf.Field.Name != ef.Field.Name {
|
||||
t.Fatalf("index field should equal, expects %v, got %v", rf.Field.Name, ef.Field.Name)
|
||||
}
|
||||
})
|
||||
|
||||
for _, name := range []string{"Expression", "Sort", "Collate", "Length"} {
|
||||
if reflect.ValueOf(ef).FieldByName(name).Interface() != reflect.ValueOf(rf).FieldByName(name).Interface() {
|
||||
t.Errorf(
|
||||
"index %v field #%v's %v should equal, expects %v, got %v", k, idx+1, name,
|
||||
reflect.ValueOf(ef).FieldByName(name).Interface(), reflect.ValueOf(rf).FieldByName(name).Interface(),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -4,12 +4,6 @@ import (
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
// ConstraintInterface database constraint interface
|
||||
type ConstraintInterface interface {
|
||||
GetName() string
|
||||
Build() (sql string, vars []interface{})
|
||||
}
|
||||
|
||||
// GormDataTypeInterface gorm data type interface
|
||||
type GormDataTypeInterface interface {
|
||||
GormDataType() string
|
||||
|
||||
@ -8,8 +8,6 @@ import (
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/jinzhu/inflection"
|
||||
"golang.org/x/text/cases"
|
||||
"golang.org/x/text/language"
|
||||
)
|
||||
|
||||
// Namer namer interface
|
||||
@ -21,7 +19,6 @@ type Namer interface {
|
||||
RelationshipFKName(Relationship) string
|
||||
CheckerName(table, column string) string
|
||||
IndexName(table, column string) string
|
||||
UniqueName(table, column string) string
|
||||
}
|
||||
|
||||
// Replacer replacer interface like strings.Replacer
|
||||
@ -29,15 +26,12 @@ type Replacer interface {
|
||||
Replace(name string) string
|
||||
}
|
||||
|
||||
var _ Namer = (*NamingStrategy)(nil)
|
||||
|
||||
// NamingStrategy tables, columns naming strategy
|
||||
type NamingStrategy struct {
|
||||
TablePrefix string
|
||||
SingularTable bool
|
||||
NameReplacer Replacer
|
||||
NoLowerCase bool
|
||||
IdentifierMaxLength int
|
||||
TablePrefix string
|
||||
SingularTable bool
|
||||
NameReplacer Replacer
|
||||
NoLowerCase bool
|
||||
}
|
||||
|
||||
// TableName convert string to table name
|
||||
@ -90,26 +84,17 @@ func (ns NamingStrategy) IndexName(table, column string) string {
|
||||
return ns.formatName("idx", table, ns.toDBName(column))
|
||||
}
|
||||
|
||||
// UniqueName generate unique constraint name
|
||||
func (ns NamingStrategy) UniqueName(table, column string) string {
|
||||
return ns.formatName("uni", table, ns.toDBName(column))
|
||||
}
|
||||
|
||||
func (ns NamingStrategy) formatName(prefix, table, name string) string {
|
||||
formattedName := strings.ReplaceAll(strings.Join([]string{
|
||||
prefix, table, name,
|
||||
}, "_"), ".", "_")
|
||||
|
||||
if ns.IdentifierMaxLength == 0 {
|
||||
ns.IdentifierMaxLength = 64
|
||||
}
|
||||
|
||||
if utf8.RuneCountInString(formattedName) > ns.IdentifierMaxLength {
|
||||
if utf8.RuneCountInString(formattedName) > 64 {
|
||||
h := sha1.New()
|
||||
h.Write([]byte(formattedName))
|
||||
bs := h.Sum(nil)
|
||||
|
||||
formattedName = formattedName[0:ns.IdentifierMaxLength-8] + hex.EncodeToString(bs)[:8]
|
||||
formattedName = formattedName[0:56] + hex.EncodeToString(bs)[:8]
|
||||
}
|
||||
return formattedName
|
||||
}
|
||||
@ -123,7 +108,7 @@ var (
|
||||
func init() {
|
||||
commonInitialismsForReplacer := make([]string, 0, len(commonInitialisms))
|
||||
for _, initialism := range commonInitialisms {
|
||||
commonInitialismsForReplacer = append(commonInitialismsForReplacer, initialism, cases.Title(language.Und).String(initialism))
|
||||
commonInitialismsForReplacer = append(commonInitialismsForReplacer, initialism, strings.Title(strings.ToLower(initialism)))
|
||||
}
|
||||
commonInitialismsReplacer = strings.NewReplacer(commonInitialismsForReplacer...)
|
||||
}
|
||||
@ -188,9 +173,9 @@ func (ns NamingStrategy) toDBName(name string) string {
|
||||
}
|
||||
|
||||
func (ns NamingStrategy) toSchemaName(name string) string {
|
||||
result := strings.ReplaceAll(cases.Title(language.Und, cases.NoLower).String(strings.ReplaceAll(name, "_", " ")), " ", "")
|
||||
result := strings.ReplaceAll(strings.Title(strings.ReplaceAll(name, "_", " ")), " ", "")
|
||||
for _, initialism := range commonInitialisms {
|
||||
result = regexp.MustCompile(cases.Title(language.Und, cases.NoLower).String(strings.ToLower(initialism))+"([A-Z]|$|_)").ReplaceAllString(result, initialism+"$1")
|
||||
result = regexp.MustCompile(strings.Title(strings.ToLower(initialism))+"([A-Z]|$|_)").ReplaceAllString(result, initialism+"$1")
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
@ -189,17 +189,8 @@ func TestCustomReplacerWithNoLowerCase(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatNameWithStringLongerThan63Characters(t *testing.T) {
|
||||
ns := NamingStrategy{IdentifierMaxLength: 63}
|
||||
|
||||
formattedName := ns.formatName("prefix", "table", "thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVeryLongString")
|
||||
if formattedName != "prefix_table_thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVer180f2c67" {
|
||||
t.Errorf("invalid formatted name generated, got %v", formattedName)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatNameWithStringLongerThan64Characters(t *testing.T) {
|
||||
ns := NamingStrategy{IdentifierMaxLength: 64}
|
||||
ns := NamingStrategy{}
|
||||
|
||||
formattedName := ns.formatName("prefix", "table", "thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVeryLongString")
|
||||
if formattedName != "prefix_table_thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVery180f2c67" {
|
||||
|
||||
@ -5,12 +5,8 @@ import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/jinzhu/inflection"
|
||||
"golang.org/x/text/cases"
|
||||
"golang.org/x/text/language"
|
||||
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
@ -31,10 +27,6 @@ type Relationships struct {
|
||||
HasMany []*Relationship
|
||||
Many2Many []*Relationship
|
||||
Relations map[string]*Relationship
|
||||
|
||||
EmbeddedRelations map[string]*Relationships
|
||||
|
||||
Mux sync.RWMutex
|
||||
}
|
||||
|
||||
type Relationship struct {
|
||||
@ -78,12 +70,12 @@ func (schema *Schema) parseRelation(field *Field) *Relationship {
|
||||
cacheStore := schema.cacheStore
|
||||
|
||||
if relation.FieldSchema, err = getOrParse(fieldValue, cacheStore, schema.namer); err != nil {
|
||||
schema.err = fmt.Errorf("failed to parse field: %s, error: %w", field.Name, err)
|
||||
schema.err = err
|
||||
return nil
|
||||
}
|
||||
|
||||
if hasPolymorphicRelation(field.TagSettings) {
|
||||
schema.buildPolymorphicRelation(relation, field)
|
||||
if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" {
|
||||
schema.buildPolymorphicRelation(relation, field, polymorphic)
|
||||
} else if many2many := field.TagSettings["MANY2MANY"]; many2many != "" {
|
||||
schema.buildMany2ManyRelation(relation, field, many2many)
|
||||
} else if belongsTo := field.TagSettings["BELONGSTO"]; belongsTo != "" {
|
||||
@ -95,16 +87,14 @@ func (schema *Schema) parseRelation(field *Field) *Relationship {
|
||||
case reflect.Slice:
|
||||
schema.guessRelation(relation, field, guessHas)
|
||||
default:
|
||||
schema.err = fmt.Errorf("unsupported data type %v for %v on field %s", relation.FieldSchema, schema,
|
||||
field.Name)
|
||||
schema.err = fmt.Errorf("unsupported data type %v for %v on field %s", relation.FieldSchema, schema, field.Name)
|
||||
}
|
||||
}
|
||||
|
||||
if relation.Type == has {
|
||||
// don't add relations to embedded schema, which might be shared
|
||||
if relation.FieldSchema != relation.Schema && relation.Polymorphic == nil && field.OwnerSchema == nil {
|
||||
relation.FieldSchema.Relationships.Mux.Lock()
|
||||
relation.FieldSchema.Relationships.Relations["_"+relation.Schema.Name+"_"+relation.Name] = relation
|
||||
relation.FieldSchema.Relationships.Mux.Unlock()
|
||||
}
|
||||
|
||||
switch field.IndirectFieldType.Kind() {
|
||||
@ -116,7 +106,7 @@ func (schema *Schema) parseRelation(field *Field) *Relationship {
|
||||
}
|
||||
|
||||
if schema.err == nil {
|
||||
schema.setRelation(relation)
|
||||
schema.Relationships.Relations[relation.Name] = relation
|
||||
switch relation.Type {
|
||||
case HasOne:
|
||||
schema.Relationships.HasOne = append(schema.Relationships.HasOne, relation)
|
||||
@ -132,100 +122,34 @@ func (schema *Schema) parseRelation(field *Field) *Relationship {
|
||||
return relation
|
||||
}
|
||||
|
||||
// hasPolymorphicRelation check if has polymorphic relation
|
||||
// 1. `POLYMORPHIC` tag
|
||||
// 2. `POLYMORPHICTYPE` and `POLYMORPHICID` tag
|
||||
func hasPolymorphicRelation(tagSettings map[string]string) bool {
|
||||
if _, ok := tagSettings["POLYMORPHIC"]; ok {
|
||||
return true
|
||||
}
|
||||
|
||||
_, hasType := tagSettings["POLYMORPHICTYPE"]
|
||||
_, hasId := tagSettings["POLYMORPHICID"]
|
||||
|
||||
return hasType && hasId
|
||||
}
|
||||
|
||||
func (schema *Schema) setRelation(relation *Relationship) {
|
||||
// set non-embedded relation
|
||||
if rel := schema.Relationships.Relations[relation.Name]; rel != nil {
|
||||
if len(rel.Field.BindNames) > 1 {
|
||||
schema.Relationships.Relations[relation.Name] = relation
|
||||
}
|
||||
} else {
|
||||
schema.Relationships.Relations[relation.Name] = relation
|
||||
}
|
||||
|
||||
// set embedded relation
|
||||
if len(relation.Field.EmbeddedBindNames) <= 1 {
|
||||
return
|
||||
}
|
||||
relationships := &schema.Relationships
|
||||
for i, name := range relation.Field.EmbeddedBindNames {
|
||||
if i < len(relation.Field.EmbeddedBindNames)-1 {
|
||||
if relationships.EmbeddedRelations == nil {
|
||||
relationships.EmbeddedRelations = map[string]*Relationships{}
|
||||
}
|
||||
if r := relationships.EmbeddedRelations[name]; r == nil {
|
||||
relationships.EmbeddedRelations[name] = &Relationships{}
|
||||
}
|
||||
relationships = relationships.EmbeddedRelations[name]
|
||||
} else {
|
||||
if relationships.Relations == nil {
|
||||
relationships.Relations = map[string]*Relationship{}
|
||||
}
|
||||
relationships.Relations[relation.Name] = relation
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// User has many Toys, its `Polymorphic` is `Owner`, Pet has one Toy, its `Polymorphic` is `Owner`
|
||||
//
|
||||
// type User struct {
|
||||
// Toys []Toy `gorm:"polymorphic:Owner;"`
|
||||
// }
|
||||
// type Pet struct {
|
||||
// Toy Toy `gorm:"polymorphic:Owner;"`
|
||||
// }
|
||||
// type Toy struct {
|
||||
// OwnerID int
|
||||
// OwnerType string
|
||||
// }
|
||||
func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Field) {
|
||||
polymorphic := field.TagSettings["POLYMORPHIC"]
|
||||
|
||||
// type User struct {
|
||||
// Toys []Toy `gorm:"polymorphic:Owner;"`
|
||||
// }
|
||||
// type Pet struct {
|
||||
// Toy Toy `gorm:"polymorphic:Owner;"`
|
||||
// }
|
||||
// type Toy struct {
|
||||
// OwnerID int
|
||||
// OwnerType string
|
||||
// }
|
||||
func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Field, polymorphic string) {
|
||||
relation.Polymorphic = &Polymorphic{
|
||||
Value: schema.Table,
|
||||
Value: schema.Table,
|
||||
PolymorphicType: relation.FieldSchema.FieldsByName[polymorphic+"Type"],
|
||||
PolymorphicID: relation.FieldSchema.FieldsByName[polymorphic+"ID"],
|
||||
}
|
||||
|
||||
var (
|
||||
typeName = polymorphic + "Type"
|
||||
typeId = polymorphic + "ID"
|
||||
)
|
||||
|
||||
if value, ok := field.TagSettings["POLYMORPHICTYPE"]; ok {
|
||||
typeName = strings.TrimSpace(value)
|
||||
}
|
||||
|
||||
if value, ok := field.TagSettings["POLYMORPHICID"]; ok {
|
||||
typeId = strings.TrimSpace(value)
|
||||
}
|
||||
|
||||
relation.Polymorphic.PolymorphicType = relation.FieldSchema.FieldsByName[typeName]
|
||||
relation.Polymorphic.PolymorphicID = relation.FieldSchema.FieldsByName[typeId]
|
||||
|
||||
if value, ok := field.TagSettings["POLYMORPHICVALUE"]; ok {
|
||||
relation.Polymorphic.Value = strings.TrimSpace(value)
|
||||
}
|
||||
|
||||
if relation.Polymorphic.PolymorphicType == nil {
|
||||
schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s",
|
||||
relation.FieldSchema, schema, field.Name, polymorphic+"Type")
|
||||
schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s", relation.FieldSchema, schema, field.Name, polymorphic+"Type")
|
||||
}
|
||||
|
||||
if relation.Polymorphic.PolymorphicID == nil {
|
||||
schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s",
|
||||
relation.FieldSchema, schema, field.Name, polymorphic+"ID")
|
||||
schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s", relation.FieldSchema, schema, field.Name, polymorphic+"ID")
|
||||
}
|
||||
|
||||
if schema.err == nil {
|
||||
@ -237,17 +161,10 @@ func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Fi
|
||||
primaryKeyField := schema.PrioritizedPrimaryField
|
||||
if len(relation.foreignKeys) > 0 {
|
||||
if primaryKeyField = schema.LookUpField(relation.foreignKeys[0]); primaryKeyField == nil || len(relation.foreignKeys) > 1 {
|
||||
schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %s", relation.foreignKeys,
|
||||
schema, field.Name)
|
||||
schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %s", relation.foreignKeys, schema, field.Name)
|
||||
}
|
||||
}
|
||||
|
||||
if primaryKeyField == nil {
|
||||
schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing primaryKey field",
|
||||
relation.FieldSchema, schema, field.Name)
|
||||
return
|
||||
}
|
||||
|
||||
// use same data type for foreign keys
|
||||
if copyableDataType(primaryKeyField.DataType) {
|
||||
relation.Polymorphic.PolymorphicID.DataType = primaryKeyField.DataType
|
||||
@ -308,9 +225,9 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
|
||||
}
|
||||
|
||||
for idx, ownField := range ownForeignFields {
|
||||
joinFieldName := cases.Title(language.Und, cases.NoLower).String(schema.Name) + ownField.Name
|
||||
joinFieldName := strings.Title(schema.Name) + ownField.Name
|
||||
if len(joinForeignKeys) > idx {
|
||||
joinFieldName = cases.Title(language.Und, cases.NoLower).String(joinForeignKeys[idx])
|
||||
joinFieldName = strings.Title(joinForeignKeys[idx])
|
||||
}
|
||||
|
||||
ownFieldsMap[joinFieldName] = ownField
|
||||
@ -325,7 +242,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
|
||||
}
|
||||
|
||||
for idx, relField := range refForeignFields {
|
||||
joinFieldName := cases.Title(language.Und, cases.NoLower).String(relation.FieldSchema.Name) + relField.Name
|
||||
joinFieldName := strings.Title(relation.FieldSchema.Name) + relField.Name
|
||||
|
||||
if _, ok := ownFieldsMap[joinFieldName]; ok {
|
||||
if field.Name != relation.FieldSchema.Name {
|
||||
@ -336,7 +253,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
|
||||
}
|
||||
|
||||
if len(joinReferences) > idx {
|
||||
joinFieldName = cases.Title(language.Und, cases.NoLower).String(joinReferences[idx])
|
||||
joinFieldName = strings.Title(joinReferences[idx])
|
||||
}
|
||||
|
||||
referFieldsMap[joinFieldName] = relField
|
||||
@ -354,13 +271,12 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
|
||||
}
|
||||
|
||||
joinTableFields = append(joinTableFields, reflect.StructField{
|
||||
Name: cases.Title(language.Und, cases.NoLower).String(schema.Name) + field.Name,
|
||||
Name: strings.Title(schema.Name) + field.Name,
|
||||
Type: schema.ModelType,
|
||||
Tag: `gorm:"-"`,
|
||||
})
|
||||
|
||||
if relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore,
|
||||
schema.namer); err != nil {
|
||||
if relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, schema.namer); err != nil {
|
||||
schema.err = err
|
||||
}
|
||||
relation.JoinTable.Name = many2many
|
||||
@ -479,8 +395,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu
|
||||
schema.guessRelation(relation, field, guessEmbeddedHas)
|
||||
// case guessEmbeddedHas:
|
||||
default:
|
||||
schema.err = fmt.Errorf("invalid field found for struct %v's field %s: define a valid foreign key for relations or implement the Valuer/Scanner interface",
|
||||
schema, field.Name)
|
||||
schema.err = fmt.Errorf("invalid field found for struct %v's field %s: define a valid foreign key for relations or implement the Valuer/Scanner interface", schema, field.Name)
|
||||
}
|
||||
}
|
||||
|
||||
@ -488,31 +403,34 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu
|
||||
case guessBelongs:
|
||||
primarySchema, foreignSchema = relation.FieldSchema, schema
|
||||
case guessEmbeddedBelongs:
|
||||
if field.OwnerSchema == nil {
|
||||
if field.OwnerSchema != nil {
|
||||
primarySchema, foreignSchema = relation.FieldSchema, field.OwnerSchema
|
||||
} else {
|
||||
reguessOrErr()
|
||||
return
|
||||
}
|
||||
primarySchema, foreignSchema = relation.FieldSchema, field.OwnerSchema
|
||||
case guessHas:
|
||||
case guessEmbeddedHas:
|
||||
if field.OwnerSchema == nil {
|
||||
if field.OwnerSchema != nil {
|
||||
primarySchema, foreignSchema = field.OwnerSchema, relation.FieldSchema
|
||||
} else {
|
||||
reguessOrErr()
|
||||
return
|
||||
}
|
||||
primarySchema, foreignSchema = field.OwnerSchema, relation.FieldSchema
|
||||
}
|
||||
|
||||
if len(relation.foreignKeys) > 0 {
|
||||
for _, foreignKey := range relation.foreignKeys {
|
||||
f := foreignSchema.LookUpField(foreignKey)
|
||||
if f == nil {
|
||||
if f := foreignSchema.LookUpField(foreignKey); f != nil {
|
||||
foreignFields = append(foreignFields, f)
|
||||
} else {
|
||||
reguessOrErr()
|
||||
return
|
||||
}
|
||||
foreignFields = append(foreignFields, f)
|
||||
}
|
||||
} else {
|
||||
primarySchemaName := primarySchema.Name
|
||||
var primaryFields []*Field
|
||||
var primarySchemaName = primarySchema.Name
|
||||
if primarySchemaName == "" {
|
||||
primarySchemaName = relation.FieldSchema.Name
|
||||
}
|
||||
@ -527,7 +445,6 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu
|
||||
primaryFields = primarySchema.PrimaryFields
|
||||
}
|
||||
|
||||
primaryFieldLoop:
|
||||
for _, primaryField := range primaryFields {
|
||||
lookUpName := primarySchemaName + primaryField.Name
|
||||
if gl == guessBelongs {
|
||||
@ -536,33 +453,23 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu
|
||||
|
||||
lookUpNames := []string{lookUpName}
|
||||
if len(primaryFields) == 1 {
|
||||
lookUpNames = append(lookUpNames, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID",
|
||||
strings.TrimSuffix(lookUpName, primaryField.Name)+"Id", schema.namer.ColumnName(foreignSchema.Table,
|
||||
strings.TrimSuffix(lookUpName, primaryField.Name)+"ID"))
|
||||
lookUpNames = append(lookUpNames, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID", strings.TrimSuffix(lookUpName, primaryField.Name)+"Id", schema.namer.ColumnName(foreignSchema.Table, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID"))
|
||||
}
|
||||
|
||||
for _, name := range lookUpNames {
|
||||
if f := foreignSchema.LookUpFieldByBindName(field.BindNames, name); f != nil {
|
||||
foreignFields = append(foreignFields, f)
|
||||
primaryFields = append(primaryFields, primaryField)
|
||||
continue primaryFieldLoop
|
||||
}
|
||||
}
|
||||
for _, name := range lookUpNames {
|
||||
if f := foreignSchema.LookUpField(name); f != nil {
|
||||
foreignFields = append(foreignFields, f)
|
||||
primaryFields = append(primaryFields, primaryField)
|
||||
continue primaryFieldLoop
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
switch {
|
||||
case len(foreignFields) == 0:
|
||||
if len(foreignFields) == 0 {
|
||||
reguessOrErr()
|
||||
return
|
||||
case len(relation.primaryKeys) > 0:
|
||||
} else if len(relation.primaryKeys) > 0 {
|
||||
for idx, primaryKey := range relation.primaryKeys {
|
||||
if f := primarySchema.LookUpField(primaryKey); f != nil {
|
||||
if len(primaryFields) < idx+1 {
|
||||
@ -576,7 +483,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu
|
||||
return
|
||||
}
|
||||
}
|
||||
case len(primaryFields) == 0:
|
||||
} else if len(primaryFields) == 0 {
|
||||
if len(foreignFields) == 1 && primarySchema.PrioritizedPrimaryField != nil {
|
||||
primaryFields = append(primaryFields, primarySchema.PrioritizedPrimaryField)
|
||||
} else if len(primarySchema.PrimaryFields) == len(foreignFields) {
|
||||
@ -612,7 +519,6 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu
|
||||
}
|
||||
}
|
||||
|
||||
// Constraint is ForeignKey Constraint
|
||||
type Constraint struct {
|
||||
Name string
|
||||
Field *Field
|
||||
@ -624,31 +530,6 @@ type Constraint struct {
|
||||
OnUpdate string
|
||||
}
|
||||
|
||||
func (constraint *Constraint) GetName() string { return constraint.Name }
|
||||
|
||||
func (constraint *Constraint) Build() (sql string, vars []interface{}) {
|
||||
sql = "CONSTRAINT ? FOREIGN KEY ? REFERENCES ??"
|
||||
if constraint.OnDelete != "" {
|
||||
sql += " ON DELETE " + constraint.OnDelete
|
||||
}
|
||||
|
||||
if constraint.OnUpdate != "" {
|
||||
sql += " ON UPDATE " + constraint.OnUpdate
|
||||
}
|
||||
|
||||
foreignKeys := make([]interface{}, 0, len(constraint.ForeignKeys))
|
||||
for _, field := range constraint.ForeignKeys {
|
||||
foreignKeys = append(foreignKeys, clause.Column{Name: field.DBName})
|
||||
}
|
||||
|
||||
references := make([]interface{}, 0, len(constraint.References))
|
||||
for _, field := range constraint.References {
|
||||
references = append(references, clause.Column{Name: field.DBName})
|
||||
}
|
||||
vars = append(vars, clause.Table{Name: constraint.Name}, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references)
|
||||
return
|
||||
}
|
||||
|
||||
func (rel *Relationship) ParseConstraint() *Constraint {
|
||||
str := rel.Field.TagSettings["CONSTRAINT"]
|
||||
if str == "-" {
|
||||
@ -663,7 +544,6 @@ func (rel *Relationship) ParseConstraint() *Constraint {
|
||||
if !(rel.References[idx].PrimaryKey == ref.PrimaryKey && rel.References[idx].ForeignKey == ref.ForeignKey &&
|
||||
rel.References[idx].PrimaryValue == ref.PrimaryValue) {
|
||||
matched = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
@ -676,7 +556,7 @@ func (rel *Relationship) ParseConstraint() *Constraint {
|
||||
|
||||
var (
|
||||
name string
|
||||
idx = strings.IndexByte(str, ',')
|
||||
idx = strings.Index(str, ",")
|
||||
settings = ParseTagSetting(str, ",")
|
||||
)
|
||||
|
||||
@ -763,9 +643,8 @@ func (rel *Relationship) ToQueryConditions(ctx context.Context, reflectValue ref
|
||||
}
|
||||
|
||||
func copyableDataType(str DataType) bool {
|
||||
lowerStr := strings.ToLower(string(str))
|
||||
for _, s := range []string{"auto_increment", "primary key"} {
|
||||
if strings.Contains(lowerStr, s) {
|
||||
if strings.Contains(strings.ToLower(string(str)), s) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
@ -121,29 +121,6 @@ func TestSelfReferentialBelongsToOverrideReferences(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestBelongsToWithMixin(t *testing.T) {
|
||||
type Profile struct {
|
||||
gorm.Model
|
||||
Refer string
|
||||
Name string
|
||||
}
|
||||
|
||||
type ProfileMixin struct {
|
||||
Profile Profile `gorm:"References:Refer"`
|
||||
ProfileRefer int
|
||||
}
|
||||
|
||||
type User struct {
|
||||
gorm.Model
|
||||
ProfileMixin
|
||||
}
|
||||
|
||||
checkStructRelation(t, &User{}, Relation{
|
||||
Name: "Profile", Type: schema.BelongsTo, Schema: "User", FieldSchema: "Profile",
|
||||
References: []Reference{{"Refer", "Profile", "ProfileRefer", "User", "", false}},
|
||||
})
|
||||
}
|
||||
|
||||
func TestHasOneOverrideForeignKey(t *testing.T) {
|
||||
type Profile struct {
|
||||
gorm.Model
|
||||
@ -541,320 +518,6 @@ func TestEmbeddedRelation(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestEmbeddedHas(t *testing.T) {
|
||||
type Toy struct {
|
||||
ID int
|
||||
Name string
|
||||
OwnerID int
|
||||
OwnerType string
|
||||
}
|
||||
type User struct {
|
||||
ID int
|
||||
Cat struct {
|
||||
Name string
|
||||
Toy Toy `gorm:"polymorphic:Owner;"`
|
||||
Toys []Toy `gorm:"polymorphic:Owner;"`
|
||||
} `gorm:"embedded;embeddedPrefix:cat_"`
|
||||
Dog struct {
|
||||
ID int
|
||||
Name string
|
||||
UserID int
|
||||
Toy Toy `gorm:"polymorphic:Owner;"`
|
||||
Toys []Toy `gorm:"polymorphic:Owner;"`
|
||||
}
|
||||
Toys []Toy `gorm:"polymorphic:Owner;"`
|
||||
}
|
||||
|
||||
s, err := schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse schema, got error %v", err)
|
||||
}
|
||||
|
||||
checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
|
||||
"Cat": {
|
||||
Relations: map[string]Relation{
|
||||
"Toy": {
|
||||
Name: "Toy",
|
||||
Type: schema.HasOne,
|
||||
Schema: "User",
|
||||
FieldSchema: "Toy",
|
||||
Polymorphic: Polymorphic{ID: "OwnerID", Type: "OwnerType", Value: "users"},
|
||||
References: []Reference{
|
||||
{ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"},
|
||||
{ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"},
|
||||
},
|
||||
},
|
||||
"Toys": {
|
||||
Name: "Toys",
|
||||
Type: schema.HasMany,
|
||||
Schema: "User",
|
||||
FieldSchema: "Toy",
|
||||
Polymorphic: Polymorphic{ID: "OwnerID", Type: "OwnerType", Value: "users"},
|
||||
References: []Reference{
|
||||
{ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"},
|
||||
{ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func TestPolymorphic(t *testing.T) {
|
||||
t.Run("has one", func(t *testing.T) {
|
||||
type Toy struct {
|
||||
ID int
|
||||
Name string
|
||||
OwnerID int
|
||||
OwnerType string
|
||||
}
|
||||
|
||||
type Cat struct {
|
||||
ID int
|
||||
Name string
|
||||
Toy Toy `gorm:"polymorphic:Owner;"`
|
||||
}
|
||||
|
||||
s, err := schema.Parse(&Cat{}, &sync.Map{}, schema.NamingStrategy{})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse schema, got error %v", err)
|
||||
}
|
||||
|
||||
checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
|
||||
"Cat": {
|
||||
Relations: map[string]Relation{
|
||||
"Toy": {
|
||||
Name: "Toy",
|
||||
Type: schema.HasOne,
|
||||
Schema: "User",
|
||||
FieldSchema: "Toy",
|
||||
Polymorphic: Polymorphic{ID: "OwnerID", Type: "OwnerType", Value: "users"},
|
||||
References: []Reference{
|
||||
{ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("has one with custom polymorphic type and id", func(t *testing.T) {
|
||||
type Toy struct {
|
||||
ID int
|
||||
Name string
|
||||
RefId int
|
||||
Type string
|
||||
}
|
||||
|
||||
type Cat struct {
|
||||
ID int
|
||||
Name string
|
||||
Toy Toy `gorm:"polymorphic:Owner;polymorphicType:Type;polymorphicId:RefId"`
|
||||
}
|
||||
|
||||
s, err := schema.Parse(&Cat{}, &sync.Map{}, schema.NamingStrategy{})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse schema, got error %v", err)
|
||||
}
|
||||
|
||||
checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
|
||||
"Cat": {
|
||||
Relations: map[string]Relation{
|
||||
"Toy": {
|
||||
Name: "Toy",
|
||||
Type: schema.HasOne,
|
||||
Schema: "User",
|
||||
FieldSchema: "Toy",
|
||||
Polymorphic: Polymorphic{ID: "ref_id", Type: "Type", Value: "users"},
|
||||
References: []Reference{
|
||||
{ForeignKey: "Type", ForeignSchema: "Toy", PrimaryValue: "users"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("has one with only polymorphic type", func(t *testing.T) {
|
||||
type Toy struct {
|
||||
ID int
|
||||
Name string
|
||||
OwnerID int
|
||||
Type string
|
||||
}
|
||||
|
||||
type Cat struct {
|
||||
ID int
|
||||
Name string
|
||||
Toy Toy `gorm:"polymorphic:Owner;polymorphicType:Type"`
|
||||
}
|
||||
|
||||
s, err := schema.Parse(&Cat{}, &sync.Map{}, schema.NamingStrategy{})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse schema, got error %v", err)
|
||||
}
|
||||
|
||||
checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
|
||||
"Cat": {
|
||||
Relations: map[string]Relation{
|
||||
"Toy": {
|
||||
Name: "Toy",
|
||||
Type: schema.HasOne,
|
||||
Schema: "User",
|
||||
FieldSchema: "Toy",
|
||||
Polymorphic: Polymorphic{ID: "owner_id", Type: "Type", Value: "users"},
|
||||
References: []Reference{
|
||||
{ForeignKey: "Type", ForeignSchema: "Toy", PrimaryValue: "users"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("has many", func(t *testing.T) {
|
||||
type Toy struct {
|
||||
ID int
|
||||
Name string
|
||||
OwnerID int
|
||||
OwnerType string
|
||||
}
|
||||
|
||||
type Cat struct {
|
||||
ID int
|
||||
Name string
|
||||
Toys []Toy `gorm:"polymorphic:Owner;"`
|
||||
}
|
||||
|
||||
s, err := schema.Parse(&Cat{}, &sync.Map{}, schema.NamingStrategy{})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse schema, got error %v", err)
|
||||
}
|
||||
|
||||
checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
|
||||
"Cat": {
|
||||
Relations: map[string]Relation{
|
||||
"Toys": {
|
||||
Name: "Toys",
|
||||
Type: schema.HasMany,
|
||||
Schema: "User",
|
||||
FieldSchema: "Toy",
|
||||
Polymorphic: Polymorphic{ID: "OwnerID", Type: "OwnerType", Value: "users"},
|
||||
References: []Reference{
|
||||
{ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("has many with custom polymorphic type and id", func(t *testing.T) {
|
||||
type Toy struct {
|
||||
ID int
|
||||
Name string
|
||||
RefId int
|
||||
Type string
|
||||
}
|
||||
|
||||
type Cat struct {
|
||||
ID int
|
||||
Name string
|
||||
Toys []Toy `gorm:"polymorphicType:Type;polymorphicId:RefId"`
|
||||
}
|
||||
|
||||
s, err := schema.Parse(&Cat{}, &sync.Map{}, schema.NamingStrategy{})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse schema, got error %v", err)
|
||||
}
|
||||
|
||||
checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
|
||||
"Cat": {
|
||||
Relations: map[string]Relation{
|
||||
"Toys": {
|
||||
Name: "Toys",
|
||||
Type: schema.HasMany,
|
||||
Schema: "User",
|
||||
FieldSchema: "Toy",
|
||||
Polymorphic: Polymorphic{ID: "ref_id", Type: "Type", Value: "users"},
|
||||
References: []Reference{
|
||||
{ForeignKey: "Type", ForeignSchema: "Toy", PrimaryValue: "users"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestEmbeddedBelongsTo(t *testing.T) {
|
||||
type Country struct {
|
||||
ID int `gorm:"primaryKey"`
|
||||
Name string
|
||||
}
|
||||
type Address struct {
|
||||
CountryID int
|
||||
Country Country
|
||||
}
|
||||
type NestedAddress struct {
|
||||
Address
|
||||
}
|
||||
type CountryMixin struct {
|
||||
CountryID int
|
||||
Country Country
|
||||
}
|
||||
type Org struct {
|
||||
ID int
|
||||
PostalAddress Address `gorm:"embedded;embeddedPrefix:postal_address_"`
|
||||
VisitingAddress Address `gorm:"embedded;embeddedPrefix:visiting_address_"`
|
||||
AddressID int
|
||||
Address struct {
|
||||
ID int
|
||||
Address
|
||||
}
|
||||
NestedAddress *NestedAddress `gorm:"embedded;embeddedPrefix:nested_address_"`
|
||||
CountryMixin
|
||||
}
|
||||
|
||||
s, err := schema.Parse(&Org{}, &sync.Map{}, schema.NamingStrategy{})
|
||||
if err != nil {
|
||||
t.Errorf("Failed to parse schema, got error %v", err)
|
||||
}
|
||||
|
||||
checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
|
||||
"PostalAddress": {
|
||||
Relations: map[string]Relation{
|
||||
"Country": {
|
||||
Name: "Country", Type: schema.BelongsTo, Schema: "Org", FieldSchema: "Country",
|
||||
References: []Reference{
|
||||
{PrimaryKey: "ID", PrimarySchema: "Country", ForeignKey: "CountryID", ForeignSchema: "Org"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"VisitingAddress": {
|
||||
Relations: map[string]Relation{
|
||||
"Country": {
|
||||
Name: "Country", Type: schema.BelongsTo, Schema: "Org", FieldSchema: "Country",
|
||||
References: []Reference{
|
||||
{PrimaryKey: "ID", PrimarySchema: "Country", ForeignKey: "CountryID", ForeignSchema: "Org"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"NestedAddress": {
|
||||
Relations: map[string]Relation{
|
||||
"Country": {
|
||||
Name: "Country", Type: schema.BelongsTo, Schema: "Org", FieldSchema: "Country",
|
||||
References: []Reference{
|
||||
{PrimaryKey: "ID", PrimarySchema: "Country", ForeignKey: "CountryID", ForeignSchema: "Org"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func TestVariableRelation(t *testing.T) {
|
||||
var result struct {
|
||||
User
|
||||
@ -979,7 +642,7 @@ func TestParseConstraintNameWithSchemaQualifiedLongTableName(t *testing.T) {
|
||||
s, err := schema.Parse(
|
||||
&Book{},
|
||||
&sync.Map{},
|
||||
schema.NamingStrategy{IdentifierMaxLength: 64},
|
||||
schema.NamingStrategy{},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse schema")
|
||||
|
||||
150
schema/schema.go
150
schema/schema.go
@ -5,29 +5,13 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"go/ast"
|
||||
"path"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
type callbackType string
|
||||
|
||||
const (
|
||||
callbackTypeBeforeCreate callbackType = "BeforeCreate"
|
||||
callbackTypeBeforeUpdate callbackType = "BeforeUpdate"
|
||||
callbackTypeAfterCreate callbackType = "AfterCreate"
|
||||
callbackTypeAfterUpdate callbackType = "AfterUpdate"
|
||||
callbackTypeBeforeSave callbackType = "BeforeSave"
|
||||
callbackTypeAfterSave callbackType = "AfterSave"
|
||||
callbackTypeBeforeDelete callbackType = "BeforeDelete"
|
||||
callbackTypeAfterDelete callbackType = "AfterDelete"
|
||||
callbackTypeAfterFind callbackType = "AfterFind"
|
||||
)
|
||||
|
||||
// ErrUnsupportedDataType unsupported data type
|
||||
var ErrUnsupportedDataType = errors.New("unsupported data type")
|
||||
|
||||
@ -41,7 +25,6 @@ type Schema struct {
|
||||
PrimaryFieldDBNames []string
|
||||
Fields []*Field
|
||||
FieldsByName map[string]*Field
|
||||
FieldsByBindName map[string]*Field // embedded fields is 'Embed.Field'
|
||||
FieldsByDBName map[string]*Field
|
||||
FieldsWithDefaultDBValue []*Field // fields with default value assigned by database
|
||||
Relationships Relationships
|
||||
@ -68,10 +51,9 @@ func (schema Schema) String() string {
|
||||
}
|
||||
|
||||
func (schema Schema) MakeSlice() reflect.Value {
|
||||
slice := reflect.MakeSlice(reflect.SliceOf(reflect.PointerTo(schema.ModelType)), 0, 20)
|
||||
slice := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(schema.ModelType)), 0, 20)
|
||||
results := reflect.New(slice.Type())
|
||||
results.Elem().Set(slice)
|
||||
|
||||
return results
|
||||
}
|
||||
|
||||
@ -85,35 +67,10 @@ func (schema Schema) LookUpField(name string) *Field {
|
||||
return nil
|
||||
}
|
||||
|
||||
// LookUpFieldByBindName looks for the closest field in the embedded struct.
|
||||
//
|
||||
// type Struct struct {
|
||||
// Embedded struct {
|
||||
// ID string // is selected by LookUpFieldByBindName([]string{"Embedded", "ID"}, "ID")
|
||||
// }
|
||||
// ID string // is selected by LookUpFieldByBindName([]string{"ID"}, "ID")
|
||||
// }
|
||||
func (schema Schema) LookUpFieldByBindName(bindNames []string, name string) *Field {
|
||||
if len(bindNames) == 0 {
|
||||
return nil
|
||||
}
|
||||
for i := len(bindNames) - 1; i >= 0; i-- {
|
||||
find := strings.Join(bindNames[:i], ".") + "." + name
|
||||
if field, ok := schema.FieldsByBindName[find]; ok {
|
||||
return field
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type Tabler interface {
|
||||
TableName() string
|
||||
}
|
||||
|
||||
type TablerWithNamer interface {
|
||||
TableName(Namer) string
|
||||
}
|
||||
|
||||
// Parse get data type from dialector
|
||||
func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) {
|
||||
return ParseWithSpecialTableName(dest, cacheStore, namer, "")
|
||||
@ -168,9 +125,6 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
|
||||
if tabler, ok := modelValue.Interface().(Tabler); ok {
|
||||
tableName = tabler.TableName()
|
||||
}
|
||||
if tabler, ok := modelValue.Interface().(TablerWithNamer); ok {
|
||||
tableName = tabler.TableName(namer)
|
||||
}
|
||||
if en, ok := namer.(embeddedNamer); ok {
|
||||
tableName = en.Table
|
||||
}
|
||||
@ -179,16 +133,15 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
|
||||
}
|
||||
|
||||
schema := &Schema{
|
||||
Name: modelType.Name(),
|
||||
ModelType: modelType,
|
||||
Table: tableName,
|
||||
FieldsByName: map[string]*Field{},
|
||||
FieldsByBindName: map[string]*Field{},
|
||||
FieldsByDBName: map[string]*Field{},
|
||||
Relationships: Relationships{Relations: map[string]*Relationship{}},
|
||||
cacheStore: cacheStore,
|
||||
namer: namer,
|
||||
initialized: make(chan struct{}),
|
||||
Name: modelType.Name(),
|
||||
ModelType: modelType,
|
||||
Table: tableName,
|
||||
FieldsByName: map[string]*Field{},
|
||||
FieldsByDBName: map[string]*Field{},
|
||||
Relationships: Relationships{Relations: map[string]*Relationship{}},
|
||||
cacheStore: cacheStore,
|
||||
namer: namer,
|
||||
initialized: make(chan struct{}),
|
||||
}
|
||||
// When the schema initialization is completed, the channel will be closed
|
||||
defer close(schema.initialized)
|
||||
@ -216,7 +169,6 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
|
||||
field.DBName = namer.ColumnName(schema.Table, field.Name)
|
||||
}
|
||||
|
||||
bindName := field.BindName()
|
||||
if field.DBName != "" {
|
||||
// nonexistence or shortest path or first appear prioritized if has permission
|
||||
if v, ok := schema.FieldsByDBName[field.DBName]; !ok || ((field.Creatable || field.Updatable || field.Readable) && len(field.BindNames) < len(v.BindNames)) {
|
||||
@ -225,7 +177,6 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
|
||||
}
|
||||
schema.FieldsByDBName[field.DBName] = field
|
||||
schema.FieldsByName[field.Name] = field
|
||||
schema.FieldsByBindName[bindName] = field
|
||||
|
||||
if v != nil && v.PrimaryKey {
|
||||
for idx, f := range schema.PrimaryFields {
|
||||
@ -244,11 +195,8 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
|
||||
if of, ok := schema.FieldsByName[field.Name]; !ok || of.TagSettings["-"] == "-" {
|
||||
schema.FieldsByName[field.Name] = field
|
||||
}
|
||||
if of, ok := schema.FieldsByBindName[bindName]; !ok || of.TagSettings["-"] == "-" {
|
||||
schema.FieldsByBindName[bindName] = field
|
||||
}
|
||||
|
||||
field.setupValuerAndSetter(modelType)
|
||||
field.setupValuerAndSetter()
|
||||
}
|
||||
|
||||
prioritizedPrimaryField := schema.LookUpField("id")
|
||||
@ -266,18 +214,8 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
|
||||
}
|
||||
}
|
||||
|
||||
if schema.PrioritizedPrimaryField == nil {
|
||||
if len(schema.PrimaryFields) == 1 {
|
||||
schema.PrioritizedPrimaryField = schema.PrimaryFields[0]
|
||||
} else if len(schema.PrimaryFields) > 1 {
|
||||
// If there are multiple primary keys, the AUTOINCREMENT field is prioritized
|
||||
for _, field := range schema.PrimaryFields {
|
||||
if field.AutoIncrement {
|
||||
schema.PrioritizedPrimaryField = field
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if schema.PrioritizedPrimaryField == nil && len(schema.PrimaryFields) == 1 {
|
||||
schema.PrioritizedPrimaryField = schema.PrimaryFields[0]
|
||||
}
|
||||
|
||||
for _, field := range schema.PrimaryFields {
|
||||
@ -285,7 +223,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
|
||||
}
|
||||
|
||||
for _, field := range schema.Fields {
|
||||
if field.DataType != "" && field.HasDefaultValue && field.DefaultValueInterface == nil {
|
||||
if field.HasDefaultValue && field.DefaultValueInterface == nil {
|
||||
schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field)
|
||||
}
|
||||
}
|
||||
@ -304,26 +242,14 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
|
||||
}
|
||||
}
|
||||
|
||||
callbackTypes := []callbackType{
|
||||
callbackTypeBeforeCreate, callbackTypeAfterCreate,
|
||||
callbackTypeBeforeUpdate, callbackTypeAfterUpdate,
|
||||
callbackTypeBeforeSave, callbackTypeAfterSave,
|
||||
callbackTypeBeforeDelete, callbackTypeAfterDelete,
|
||||
callbackTypeAfterFind,
|
||||
}
|
||||
for _, cbName := range callbackTypes {
|
||||
if methodValue := callBackToMethodValue(modelValue, cbName); methodValue.IsValid() {
|
||||
callbacks := []string{"BeforeCreate", "AfterCreate", "BeforeUpdate", "AfterUpdate", "BeforeSave", "AfterSave", "BeforeDelete", "AfterDelete", "AfterFind"}
|
||||
for _, name := range callbacks {
|
||||
if methodValue := modelValue.MethodByName(name); methodValue.IsValid() {
|
||||
switch methodValue.Type().String() {
|
||||
case "func(*gorm.DB) error":
|
||||
expectedPkgPath := path.Dir(reflect.TypeOf(schema).Elem().PkgPath())
|
||||
if inVarPkg := methodValue.Type().In(0).Elem().PkgPath(); inVarPkg == expectedPkgPath {
|
||||
reflect.Indirect(reflect.ValueOf(schema)).FieldByName(string(cbName)).SetBool(true)
|
||||
} else {
|
||||
logger.Default.Warn(context.Background(), "In model %v, the hook function `%v(*gorm.DB) error` has an incorrect parameter type. The expected parameter type is `%v`, but the provided type is `%v`.", schema, cbName, expectedPkgPath, inVarPkg)
|
||||
// PASS
|
||||
}
|
||||
case "func(*gorm.DB) error": // TODO hack
|
||||
reflect.Indirect(reflect.ValueOf(schema)).FieldByName(name).SetBool(true)
|
||||
default:
|
||||
logger.Default.Warn(context.Background(), "Model %v don't match %vInterface, should be `%v(*gorm.DB) error`. Please see https://gorm.io/docs/hooks.html", schema, cbName, cbName)
|
||||
logger.Default.Warn(context.Background(), "Model %v don't match %vInterface, should be `%v(*gorm.DB) error`. Please see https://gorm.io/docs/hooks.html", schema, name, name)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -345,12 +271,11 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
|
||||
|
||||
if _, embedded := schema.cacheStore.Load(embeddedCacheKey); !embedded {
|
||||
for _, field := range schema.Fields {
|
||||
if field.DataType == "" && field.GORMDataType == "" && (field.Creatable || field.Updatable || field.Readable) {
|
||||
if field.DataType == "" && (field.Creatable || field.Updatable || field.Readable) {
|
||||
if schema.parseRelation(field); schema.err != nil {
|
||||
return schema, schema.err
|
||||
} else {
|
||||
schema.FieldsByName[field.Name] = field
|
||||
schema.FieldsByBindName[field.BindName()] = field
|
||||
}
|
||||
}
|
||||
|
||||
@ -377,39 +302,6 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
|
||||
return schema, schema.err
|
||||
}
|
||||
|
||||
// This unrolling is needed to show to the compiler the exact set of methods
|
||||
// that can be used on the modelType.
|
||||
// Prior to go1.22 any use of MethodByName would cause the linker to
|
||||
// abandon dead code elimination for the entire binary.
|
||||
// As of go1.22 the compiler supports one special case of a string constant
|
||||
// being passed to MethodByName. For enterprise customers or those building
|
||||
// large binaries, this gives a significant reduction in binary size.
|
||||
// https://github.com/golang/go/issues/62257
|
||||
func callBackToMethodValue(modelType reflect.Value, cbType callbackType) reflect.Value {
|
||||
switch cbType {
|
||||
case callbackTypeBeforeCreate:
|
||||
return modelType.MethodByName(string(callbackTypeBeforeCreate))
|
||||
case callbackTypeAfterCreate:
|
||||
return modelType.MethodByName(string(callbackTypeAfterCreate))
|
||||
case callbackTypeBeforeUpdate:
|
||||
return modelType.MethodByName(string(callbackTypeBeforeUpdate))
|
||||
case callbackTypeAfterUpdate:
|
||||
return modelType.MethodByName(string(callbackTypeAfterUpdate))
|
||||
case callbackTypeBeforeSave:
|
||||
return modelType.MethodByName(string(callbackTypeBeforeSave))
|
||||
case callbackTypeAfterSave:
|
||||
return modelType.MethodByName(string(callbackTypeAfterSave))
|
||||
case callbackTypeBeforeDelete:
|
||||
return modelType.MethodByName(string(callbackTypeBeforeDelete))
|
||||
case callbackTypeAfterDelete:
|
||||
return modelType.MethodByName(string(callbackTypeAfterDelete))
|
||||
case callbackTypeAfterFind:
|
||||
return modelType.MethodByName(string(callbackTypeAfterFind))
|
||||
default:
|
||||
return reflect.ValueOf(nil)
|
||||
}
|
||||
}
|
||||
|
||||
func getOrParse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) {
|
||||
modelType := reflect.ValueOf(dest).Type()
|
||||
for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr {
|
||||
|
||||
@ -163,8 +163,8 @@ func checkSchemaRelation(t *testing.T, s *schema.Schema, relation Relation) {
|
||||
t.Errorf("schema %v relation's join table tablename expects %v, but got %v", s, relation.JoinTable.Table, r.JoinTable.Table)
|
||||
}
|
||||
|
||||
for i := range relation.JoinTable.Fields {
|
||||
checkSchemaField(t, r.JoinTable, &relation.JoinTable.Fields[i], nil)
|
||||
for _, f := range relation.JoinTable.Fields {
|
||||
checkSchemaField(t, r.JoinTable, &f, nil)
|
||||
}
|
||||
}
|
||||
|
||||
@ -201,37 +201,6 @@ func checkSchemaRelation(t *testing.T, s *schema.Schema, relation Relation) {
|
||||
})
|
||||
}
|
||||
|
||||
type EmbeddedRelations struct {
|
||||
Relations map[string]Relation
|
||||
EmbeddedRelations map[string]EmbeddedRelations
|
||||
}
|
||||
|
||||
func checkEmbeddedRelations(t *testing.T, actual map[string]*schema.Relationships, expected map[string]EmbeddedRelations) {
|
||||
for name, relations := range actual {
|
||||
rs := expected[name]
|
||||
t.Run("CheckEmbeddedRelations/"+name, func(t *testing.T) {
|
||||
if len(relations.Relations) != len(rs.Relations) {
|
||||
t.Errorf("schema relations count don't match, expects %d, got %d", len(rs.Relations), len(relations.Relations))
|
||||
}
|
||||
if len(relations.EmbeddedRelations) != len(rs.EmbeddedRelations) {
|
||||
t.Errorf("schema embedded relations count don't match, expects %d, got %d", len(rs.EmbeddedRelations), len(relations.EmbeddedRelations))
|
||||
}
|
||||
for n, rel := range relations.Relations {
|
||||
if r, ok := rs.Relations[n]; !ok {
|
||||
t.Errorf("failed to find relation by name %s", n)
|
||||
} else {
|
||||
checkSchemaRelation(t, &schema.Schema{
|
||||
Relationships: schema.Relationships{
|
||||
Relations: map[string]*schema.Relationship{n: rel},
|
||||
},
|
||||
}, r)
|
||||
}
|
||||
}
|
||||
checkEmbeddedRelations(t, relations.EmbeddedRelations, rs.EmbeddedRelations)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func checkField(t *testing.T, s *schema.Schema, value reflect.Value, values map[string]interface{}) {
|
||||
for k, v := range values {
|
||||
t.Run("CheckField/"+k, func(t *testing.T) {
|
||||
|
||||
@ -19,22 +19,6 @@ func TestParseSchema(t *testing.T) {
|
||||
checkUserSchema(t, user)
|
||||
}
|
||||
|
||||
func TestParseSchemaWithMap(t *testing.T) {
|
||||
type User struct {
|
||||
tests.User
|
||||
Attrs map[string]string `gorm:"type:Map(String,String);"`
|
||||
}
|
||||
|
||||
user, err := schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse user with map, got error %v", err)
|
||||
}
|
||||
|
||||
if field := user.FieldsByName["Attrs"]; field.DataType != "Map(String,String)" {
|
||||
t.Errorf("failed to parse user field Attrs")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSchemaWithPointerFields(t *testing.T) {
|
||||
user, err := schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{})
|
||||
if err != nil {
|
||||
@ -62,8 +46,8 @@ func checkUserSchema(t *testing.T, user *schema.Schema) {
|
||||
{Name: "Active", DBName: "active", BindNames: []string{"Active"}, DataType: schema.Bool},
|
||||
}
|
||||
|
||||
for i := range fields {
|
||||
checkSchemaField(t, user, &fields[i], func(f *schema.Field) {
|
||||
for _, f := range fields {
|
||||
checkSchemaField(t, user, &f, func(f *schema.Field) {
|
||||
f.Creatable = true
|
||||
f.Updatable = true
|
||||
f.Readable = true
|
||||
@ -152,8 +136,8 @@ func TestParseSchemaWithAdvancedDataType(t *testing.T) {
|
||||
{Name: "Admin", DBName: "admin", BindNames: []string{"Admin"}, DataType: schema.Bool},
|
||||
}
|
||||
|
||||
for i := range fields {
|
||||
checkSchemaField(t, user, &fields[i], func(f *schema.Field) {
|
||||
for _, f := range fields {
|
||||
checkSchemaField(t, user, &f, func(f *schema.Field) {
|
||||
f.Creatable = true
|
||||
f.Updatable = true
|
||||
f.Readable = true
|
||||
@ -309,44 +293,3 @@ func TestEmbeddedStructForCustomizedNamingStrategy(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompositePrimaryKeyWithAutoIncrement(t *testing.T) {
|
||||
type Product struct {
|
||||
ProductID uint `gorm:"primaryKey;autoIncrement"`
|
||||
LanguageCode uint `gorm:"primaryKey"`
|
||||
Code string
|
||||
Name string
|
||||
}
|
||||
type ProductNonAutoIncrement struct {
|
||||
ProductID uint `gorm:"primaryKey;autoIncrement:false"`
|
||||
LanguageCode uint `gorm:"primaryKey"`
|
||||
Code string
|
||||
Name string
|
||||
}
|
||||
|
||||
product, err := schema.Parse(&Product{}, &sync.Map{}, schema.NamingStrategy{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse product struct with composite primary key, got error %v", err)
|
||||
}
|
||||
|
||||
prioritizedPrimaryField := schema.Field{
|
||||
Name: "ProductID", DBName: "product_id", BindNames: []string{"ProductID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, HasDefaultValue: true, AutoIncrement: true, TagSettings: map[string]string{"PRIMARYKEY": "PRIMARYKEY", "AUTOINCREMENT": "AUTOINCREMENT"},
|
||||
}
|
||||
|
||||
product.Fields = []*schema.Field{product.PrioritizedPrimaryField}
|
||||
|
||||
checkSchemaField(t, product, &prioritizedPrimaryField, func(f *schema.Field) {
|
||||
f.Creatable = true
|
||||
f.Updatable = true
|
||||
f.Readable = true
|
||||
})
|
||||
|
||||
productNonAutoIncrement, err := schema.Parse(&ProductNonAutoIncrement{}, &sync.Map{}, schema.NamingStrategy{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse productNonAutoIncrement struct with composite primary key, got error %v", err)
|
||||
}
|
||||
|
||||
if productNonAutoIncrement.PrioritizedPrimaryField != nil {
|
||||
t.Fatalf("PrioritizedPrimaryField of non autoincrement composite key should be nil")
|
||||
}
|
||||
}
|
||||
|
||||
@ -70,7 +70,8 @@ type SerializerValuerInterface interface {
|
||||
}
|
||||
|
||||
// JSONSerializer json serializer
|
||||
type JSONSerializer struct{}
|
||||
type JSONSerializer struct {
|
||||
}
|
||||
|
||||
// Scan implements serializer interface
|
||||
func (JSONSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) (err error) {
|
||||
@ -84,10 +85,7 @@ func (JSONSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value,
|
||||
case string:
|
||||
bytes = []byte(v)
|
||||
default:
|
||||
bytes, err = json.Marshal(v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return fmt.Errorf("failed to unmarshal JSONB value: %#v", dbValue)
|
||||
}
|
||||
|
||||
if len(bytes) > 0 {
|
||||
@ -102,17 +100,12 @@ func (JSONSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value,
|
||||
// Value implements serializer interface
|
||||
func (JSONSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) {
|
||||
result, err := json.Marshal(fieldValue)
|
||||
if string(result) == "null" {
|
||||
if field.TagSettings["NOT NULL"] != "" {
|
||||
return "", nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return string(result), err
|
||||
}
|
||||
|
||||
// UnixSecondSerializer json serializer
|
||||
type UnixSecondSerializer struct{}
|
||||
type UnixSecondSerializer struct {
|
||||
}
|
||||
|
||||
// Scan implements serializer interface
|
||||
func (UnixSecondSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) (err error) {
|
||||
@ -129,12 +122,12 @@ func (UnixSecondSerializer) Value(ctx context.Context, field *Field, dst reflect
|
||||
rv := reflect.ValueOf(fieldValue)
|
||||
switch v := fieldValue.(type) {
|
||||
case int64, int, uint, uint64, int32, uint32, int16, uint16:
|
||||
result = time.Unix(reflect.Indirect(rv).Int(), 0).UTC()
|
||||
result = time.Unix(reflect.Indirect(rv).Int(), 0)
|
||||
case *int64, *int, *uint, *uint64, *int32, *uint32, *int16, *uint16:
|
||||
if rv.IsZero() {
|
||||
return nil, nil
|
||||
}
|
||||
result = time.Unix(reflect.Indirect(rv).Int(), 0).UTC()
|
||||
result = time.Unix(reflect.Indirect(rv).Int(), 0)
|
||||
default:
|
||||
err = fmt.Errorf("invalid field type %#v for UnixSecondSerializer, only int, uint supported", v)
|
||||
}
|
||||
@ -142,7 +135,8 @@ func (UnixSecondSerializer) Value(ctx context.Context, field *Field, dst reflect
|
||||
}
|
||||
|
||||
// GobSerializer gob serializer
|
||||
type GobSerializer struct{}
|
||||
type GobSerializer struct {
|
||||
}
|
||||
|
||||
// Scan implements serializer interface
|
||||
func (GobSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) (err error) {
|
||||
|
||||
@ -71,7 +71,7 @@ func appendSettingFromTag(tag reflect.StructTag, value string) reflect.StructTag
|
||||
// GetRelationsValues get relations's values from a reflect value
|
||||
func GetRelationsValues(ctx context.Context, reflectValue reflect.Value, rels []*Relationship) (reflectResults reflect.Value) {
|
||||
for _, rel := range rels {
|
||||
reflectResults = reflect.MakeSlice(reflect.SliceOf(reflect.PointerTo(rel.FieldSchema.ModelType)), 0, 1)
|
||||
reflectResults = reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(rel.FieldSchema.ModelType)), 0, 1)
|
||||
|
||||
appendToResults := func(value reflect.Value) {
|
||||
if _, isZero := rel.Field.ValueOf(ctx, value); !isZero {
|
||||
@ -115,11 +115,6 @@ func GetIdentityFieldValuesMap(ctx context.Context, reflectValue reflect.Value,
|
||||
notZero, zero bool
|
||||
)
|
||||
|
||||
if reflectValue.Kind() == reflect.Ptr ||
|
||||
reflectValue.Kind() == reflect.Interface {
|
||||
reflectValue = reflectValue.Elem()
|
||||
}
|
||||
|
||||
switch reflectValue.Kind() {
|
||||
case reflect.Struct:
|
||||
results = [][]interface{}{make([]interface{}, len(fields))}
|
||||
@ -138,7 +133,7 @@ func GetIdentityFieldValuesMap(ctx context.Context, reflectValue reflect.Value,
|
||||
for i := 0; i < reflectValue.Len(); i++ {
|
||||
elem := reflectValue.Index(i)
|
||||
elemKey := elem.Interface()
|
||||
if elem.Kind() != reflect.Ptr && elem.CanAddr() {
|
||||
if elem.Kind() != reflect.Ptr {
|
||||
elemKey = elem.Addr().Interface()
|
||||
}
|
||||
|
||||
|
||||
@ -6,7 +6,6 @@ import (
|
||||
"encoding/json"
|
||||
"reflect"
|
||||
|
||||
"github.com/jinzhu/now"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/schema"
|
||||
)
|
||||
@ -46,21 +45,11 @@ func (n *DeletedAt) UnmarshalJSON(b []byte) error {
|
||||
}
|
||||
|
||||
func (DeletedAt) QueryClauses(f *schema.Field) []clause.Interface {
|
||||
return []clause.Interface{SoftDeleteQueryClause{Field: f, ZeroValue: parseZeroValueTag(f)}}
|
||||
}
|
||||
|
||||
func parseZeroValueTag(f *schema.Field) sql.NullString {
|
||||
if v, ok := f.TagSettings["ZEROVALUE"]; ok {
|
||||
if _, err := now.Parse(v); err == nil {
|
||||
return sql.NullString{String: v, Valid: true}
|
||||
}
|
||||
}
|
||||
return sql.NullString{Valid: false}
|
||||
return []clause.Interface{SoftDeleteQueryClause{Field: f}}
|
||||
}
|
||||
|
||||
type SoftDeleteQueryClause struct {
|
||||
ZeroValue sql.NullString
|
||||
Field *schema.Field
|
||||
Field *schema.Field
|
||||
}
|
||||
|
||||
func (sd SoftDeleteQueryClause) Name() string {
|
||||
@ -89,19 +78,18 @@ func (sd SoftDeleteQueryClause) ModifyStatement(stmt *Statement) {
|
||||
}
|
||||
|
||||
stmt.AddClause(clause.Where{Exprs: []clause.Expression{
|
||||
clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: sd.Field.DBName}, Value: sd.ZeroValue},
|
||||
clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: sd.Field.DBName}, Value: nil},
|
||||
}})
|
||||
stmt.Clauses["soft_delete_enabled"] = clause.Clause{}
|
||||
}
|
||||
}
|
||||
|
||||
func (DeletedAt) UpdateClauses(f *schema.Field) []clause.Interface {
|
||||
return []clause.Interface{SoftDeleteUpdateClause{Field: f, ZeroValue: parseZeroValueTag(f)}}
|
||||
return []clause.Interface{SoftDeleteUpdateClause{Field: f}}
|
||||
}
|
||||
|
||||
type SoftDeleteUpdateClause struct {
|
||||
ZeroValue sql.NullString
|
||||
Field *schema.Field
|
||||
Field *schema.Field
|
||||
}
|
||||
|
||||
func (sd SoftDeleteUpdateClause) Name() string {
|
||||
@ -121,12 +109,11 @@ func (sd SoftDeleteUpdateClause) ModifyStatement(stmt *Statement) {
|
||||
}
|
||||
|
||||
func (DeletedAt) DeleteClauses(f *schema.Field) []clause.Interface {
|
||||
return []clause.Interface{SoftDeleteDeleteClause{Field: f, ZeroValue: parseZeroValueTag(f)}}
|
||||
return []clause.Interface{SoftDeleteDeleteClause{Field: f}}
|
||||
}
|
||||
|
||||
type SoftDeleteDeleteClause struct {
|
||||
ZeroValue sql.NullString
|
||||
Field *schema.Field
|
||||
Field *schema.Field
|
||||
}
|
||||
|
||||
func (sd SoftDeleteDeleteClause) Name() string {
|
||||
|
||||
162
statement.go
162
statement.go
@ -30,9 +30,8 @@ type Statement struct {
|
||||
Clauses map[string]clause.Clause
|
||||
BuildClauses []string
|
||||
Distinct bool
|
||||
Selects []string // selected columns
|
||||
Omits []string // omit columns
|
||||
ColumnMapping map[string]string // map columns
|
||||
Selects []string // selected columns
|
||||
Omits []string // omit columns
|
||||
Joins []join
|
||||
Preloads map[string][]interface{}
|
||||
Settings sync.Map
|
||||
@ -47,18 +46,12 @@ type Statement struct {
|
||||
attrs []interface{}
|
||||
assigns []interface{}
|
||||
scopes []func(*DB) *DB
|
||||
Result *result
|
||||
}
|
||||
|
||||
type join struct {
|
||||
Name string
|
||||
Alias string
|
||||
Conds []interface{}
|
||||
On *clause.Where
|
||||
Selects []string
|
||||
Omits []string
|
||||
Expression clause.Expression
|
||||
JoinType clause.JoinType
|
||||
Name string
|
||||
Conds []interface{}
|
||||
On *clause.Where
|
||||
}
|
||||
|
||||
// StatementModifier statement modifier interface
|
||||
@ -124,8 +117,6 @@ func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) {
|
||||
write(v.Raw, stmt.Schema.PrioritizedPrimaryField.DBName)
|
||||
} else if len(stmt.Schema.DBNames) > 0 {
|
||||
write(v.Raw, stmt.Schema.DBNames[0])
|
||||
} else {
|
||||
stmt.DB.AddError(ErrModelAccessibleFieldsRequired) //nolint:typecheck,errcheck
|
||||
}
|
||||
} else {
|
||||
write(v.Raw, v.Name)
|
||||
@ -188,10 +179,6 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) {
|
||||
} else {
|
||||
stmt.AddVar(writer, v.GormValue(stmt.Context, stmt.DB))
|
||||
}
|
||||
case clause.Interface:
|
||||
c := clause.Clause{Name: v.Name()}
|
||||
v.MergeClause(&c)
|
||||
c.Build(stmt)
|
||||
case clause.Expression:
|
||||
v.Build(stmt)
|
||||
case driver.Valuer:
|
||||
@ -208,21 +195,19 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) {
|
||||
} else {
|
||||
writer.WriteString("(NULL)")
|
||||
}
|
||||
case interface{ getInstance() *DB }:
|
||||
cv := v.getInstance()
|
||||
|
||||
subdb := cv.Session(&Session{Logger: logger.Discard, DryRun: true}).getInstance()
|
||||
if cv.Statement.SQL.Len() > 0 {
|
||||
case *DB:
|
||||
subdb := v.Session(&Session{Logger: logger.Discard, DryRun: true}).getInstance()
|
||||
if v.Statement.SQL.Len() > 0 {
|
||||
var (
|
||||
vars = subdb.Statement.Vars
|
||||
sql = cv.Statement.SQL.String()
|
||||
sql = v.Statement.SQL.String()
|
||||
)
|
||||
|
||||
subdb.Statement.Vars = make([]interface{}, 0, len(vars))
|
||||
for _, vv := range vars {
|
||||
subdb.Statement.Vars = append(subdb.Statement.Vars, vv)
|
||||
bindvar := strings.Builder{}
|
||||
cv.BindVarTo(&bindvar, subdb.Statement, vv)
|
||||
v.Dialector.BindVarTo(&bindvar, subdb.Statement, vv)
|
||||
sql = strings.Replace(sql, bindvar.String(), "?", 1)
|
||||
}
|
||||
|
||||
@ -319,31 +304,23 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
|
||||
conds := make([]clause.Expression, 0, 4)
|
||||
args = append([]interface{}{query}, args...)
|
||||
for idx, arg := range args {
|
||||
if arg == nil {
|
||||
continue
|
||||
}
|
||||
if valuer, ok := arg.(driver.Valuer); ok {
|
||||
arg, _ = valuer.Value()
|
||||
}
|
||||
|
||||
curTable := stmt.Table
|
||||
if curTable == "" {
|
||||
curTable = clause.CurrentTable
|
||||
}
|
||||
|
||||
switch v := arg.(type) {
|
||||
case clause.Expression:
|
||||
conds = append(conds, v)
|
||||
case *DB:
|
||||
v.executeScopes()
|
||||
for _, scope := range v.Statement.scopes {
|
||||
v = scope(v)
|
||||
}
|
||||
|
||||
if cs, ok := v.Statement.Clauses["WHERE"]; ok {
|
||||
if where, ok := cs.Expression.(clause.Where); ok {
|
||||
if len(where.Exprs) == 1 {
|
||||
if orConds, ok := where.Exprs[0].(clause.OrConditions); ok {
|
||||
if len(orConds.Exprs) == 1 {
|
||||
where.Exprs[0] = clause.AndConditions(orConds)
|
||||
}
|
||||
where.Exprs[0] = clause.AndConditions(orConds)
|
||||
}
|
||||
}
|
||||
conds = append(conds, clause.And(where.Exprs...))
|
||||
@ -363,11 +340,7 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
|
||||
sort.Strings(keys)
|
||||
|
||||
for _, key := range keys {
|
||||
column := clause.Column{Name: key, Table: curTable}
|
||||
if strings.Contains(key, ".") {
|
||||
column = clause.Column{Name: key}
|
||||
}
|
||||
conds = append(conds, clause.Eq{Column: column, Value: v[key]})
|
||||
conds = append(conds, clause.Eq{Column: key, Value: v[key]})
|
||||
}
|
||||
case map[string]interface{}:
|
||||
keys := make([]string, 0, len(v))
|
||||
@ -378,16 +351,12 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
|
||||
|
||||
for _, key := range keys {
|
||||
reflectValue := reflect.Indirect(reflect.ValueOf(v[key]))
|
||||
column := clause.Column{Name: key, Table: curTable}
|
||||
if strings.Contains(key, ".") {
|
||||
column = clause.Column{Name: key}
|
||||
}
|
||||
switch reflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
if _, ok := v[key].(driver.Valuer); ok {
|
||||
conds = append(conds, clause.Eq{Column: column, Value: v[key]})
|
||||
conds = append(conds, clause.Eq{Column: key, Value: v[key]})
|
||||
} else if _, ok := v[key].(Valuer); ok {
|
||||
conds = append(conds, clause.Eq{Column: column, Value: v[key]})
|
||||
conds = append(conds, clause.Eq{Column: key, Value: v[key]})
|
||||
} else {
|
||||
// optimize reflect value length
|
||||
valueLen := reflectValue.Len()
|
||||
@ -396,10 +365,10 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
|
||||
values[i] = reflectValue.Index(i).Interface()
|
||||
}
|
||||
|
||||
conds = append(conds, clause.IN{Column: column, Values: values})
|
||||
conds = append(conds, clause.IN{Column: key, Values: values})
|
||||
}
|
||||
default:
|
||||
conds = append(conds, clause.Eq{Column: column, Value: v[key]})
|
||||
conds = append(conds, clause.Eq{Column: key, Value: v[key]})
|
||||
}
|
||||
}
|
||||
default:
|
||||
@ -426,9 +395,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
|
||||
if selected || (!restricted && field.Readable) {
|
||||
if v, isZero := field.ValueOf(stmt.Context, reflectValue); !isZero || selected {
|
||||
if field.DBName != "" {
|
||||
conds = append(conds, clause.Eq{Column: clause.Column{Table: curTable, Name: field.DBName}, Value: v})
|
||||
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v})
|
||||
} else if field.DataType != "" {
|
||||
conds = append(conds, clause.Eq{Column: clause.Column{Table: curTable, Name: field.Name}, Value: v})
|
||||
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v})
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -440,9 +409,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
|
||||
if selected || (!restricted && field.Readable) {
|
||||
if v, isZero := field.ValueOf(stmt.Context, reflectValue.Index(i)); !isZero || selected {
|
||||
if field.DBName != "" {
|
||||
conds = append(conds, clause.Eq{Column: clause.Column{Table: curTable, Name: field.DBName}, Value: v})
|
||||
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v})
|
||||
} else if field.DataType != "" {
|
||||
conds = append(conds, clause.Eq{Column: clause.Column{Table: curTable, Name: field.Name}, Value: v})
|
||||
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v})
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -467,22 +436,18 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
|
||||
}
|
||||
|
||||
if len(values) > 0 {
|
||||
conds = append(conds, clause.IN{Column: clause.Column{Table: curTable, Name: clause.PrimaryKey}, Values: values})
|
||||
return []clause.Expression{clause.And(conds...)}
|
||||
conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: values})
|
||||
}
|
||||
return nil
|
||||
return conds
|
||||
}
|
||||
}
|
||||
|
||||
conds = append(conds, clause.IN{Column: clause.Column{Table: curTable, Name: clause.PrimaryKey}, Values: args})
|
||||
conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: args})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(conds) > 0 {
|
||||
return []clause.Expression{clause.And(conds...)}
|
||||
}
|
||||
return nil
|
||||
return conds
|
||||
}
|
||||
|
||||
// Build build sql with clauses names
|
||||
@ -534,14 +499,12 @@ func (stmt *Statement) clone() *Statement {
|
||||
Distinct: stmt.Distinct,
|
||||
Selects: stmt.Selects,
|
||||
Omits: stmt.Omits,
|
||||
ColumnMapping: stmt.ColumnMapping,
|
||||
Preloads: map[string][]interface{}{},
|
||||
ConnPool: stmt.ConnPool,
|
||||
Schema: stmt.Schema,
|
||||
Context: stmt.Context,
|
||||
RaiseErrorOnNotFound: stmt.RaiseErrorOnNotFound,
|
||||
SkipHooks: stmt.SkipHooks,
|
||||
Result: stmt.Result,
|
||||
}
|
||||
|
||||
if stmt.SQL.Len() > 0 {
|
||||
@ -577,9 +540,8 @@ func (stmt *Statement) clone() *Statement {
|
||||
}
|
||||
|
||||
// SetColumn set column's value
|
||||
//
|
||||
// stmt.SetColumn("Name", "jinzhu") // Hooks Method
|
||||
// stmt.SetColumn("Name", "jinzhu", true) // Callbacks Method
|
||||
// stmt.SetColumn("Name", "jinzhu") // Hooks Method
|
||||
// stmt.SetColumn("Name", "jinzhu", true) // Callbacks Method
|
||||
func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks ...bool) {
|
||||
if v, ok := stmt.Dest.(map[string]interface{}); ok {
|
||||
v[name] = value
|
||||
@ -688,62 +650,54 @@ func (stmt *Statement) Changed(fields ...string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
var matchName = func() func(tableColumn string) (table, column string) {
|
||||
nameMatcher := regexp.MustCompile(`^(?:\W?(\w+?)\W?\.)?(?:(\*)|\W?(\w+?)\W?)$`)
|
||||
return func(tableColumn string) (table, column string) {
|
||||
if matches := nameMatcher.FindStringSubmatch(tableColumn); len(matches) == 4 {
|
||||
table = matches[1]
|
||||
star := matches[2]
|
||||
columnName := matches[3]
|
||||
if star != "" {
|
||||
return table, star
|
||||
}
|
||||
return table, columnName
|
||||
}
|
||||
return "", ""
|
||||
}
|
||||
}()
|
||||
var nameMatcher = regexp.MustCompile(`^(?:\W?(\w+?)\W?\.)?\W?(\w+?)\W?$`)
|
||||
|
||||
// SelectAndOmitColumns get select and omit columns, select -> true, omit -> false
|
||||
func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (map[string]bool, bool) {
|
||||
results := map[string]bool{}
|
||||
notRestricted := false
|
||||
|
||||
processColumn := func(column string, result bool) {
|
||||
// select columns
|
||||
for _, column := range stmt.Selects {
|
||||
if stmt.Schema == nil {
|
||||
results[column] = result
|
||||
results[column] = true
|
||||
} else if column == "*" {
|
||||
notRestricted = result
|
||||
notRestricted = true
|
||||
for _, dbName := range stmt.Schema.DBNames {
|
||||
results[dbName] = result
|
||||
results[dbName] = true
|
||||
}
|
||||
} else if column == clause.Associations {
|
||||
for _, rel := range stmt.Schema.Relationships.Relations {
|
||||
results[rel.Name] = result
|
||||
results[rel.Name] = true
|
||||
}
|
||||
} else if field := stmt.Schema.LookUpField(column); field != nil && field.DBName != "" {
|
||||
results[field.DBName] = result
|
||||
} else if table, col := matchName(column); col != "" && (table == stmt.Table || table == "") {
|
||||
if col == "*" {
|
||||
for _, dbName := range stmt.Schema.DBNames {
|
||||
results[dbName] = result
|
||||
}
|
||||
} else {
|
||||
results[col] = result
|
||||
}
|
||||
results[field.DBName] = true
|
||||
} else if matches := nameMatcher.FindStringSubmatch(column); len(matches) == 3 && (matches[1] == stmt.Table || matches[1] == "") {
|
||||
results[matches[2]] = true
|
||||
} else {
|
||||
results[column] = result
|
||||
results[column] = true
|
||||
}
|
||||
}
|
||||
|
||||
// select columns
|
||||
for _, column := range stmt.Selects {
|
||||
processColumn(column, true)
|
||||
}
|
||||
|
||||
// omit columns
|
||||
for _, column := range stmt.Omits {
|
||||
processColumn(column, false)
|
||||
for _, omit := range stmt.Omits {
|
||||
if stmt.Schema == nil {
|
||||
results[omit] = false
|
||||
} else if omit == "*" {
|
||||
for _, dbName := range stmt.Schema.DBNames {
|
||||
results[dbName] = false
|
||||
}
|
||||
} else if omit == clause.Associations {
|
||||
for _, rel := range stmt.Schema.Relationships.Relations {
|
||||
results[rel.Name] = false
|
||||
}
|
||||
} else if field := stmt.Schema.LookUpField(omit); field != nil && field.DBName != "" {
|
||||
results[field.DBName] = false
|
||||
} else if matches := nameMatcher.FindStringSubmatch(omit); len(matches) == 2 {
|
||||
results[matches[1]] = false
|
||||
} else {
|
||||
results[omit] = false
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Schema != nil {
|
||||
|
||||
@ -35,13 +35,6 @@ func TestWhereCloneCorruption(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestNilCondition(t *testing.T) {
|
||||
s := new(Statement)
|
||||
if len(s.BuildCondition(nil)) != 0 {
|
||||
t.Errorf("Nil condition should be empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNameMatcher(t *testing.T) {
|
||||
for k, v := range map[string][]string{
|
||||
"table.name": {"table", "name"},
|
||||
@ -56,15 +49,9 @@ func TestNameMatcher(t *testing.T) {
|
||||
"`name_1`": {"", "name_1"},
|
||||
"`Name_1`": {"", "Name_1"},
|
||||
"`Table`.`nAme`": {"Table", "nAme"},
|
||||
"my_table.*": {"my_table", "*"},
|
||||
"`my_table`.*": {"my_table", "*"},
|
||||
"User__Company.*": {"User__Company", "*"},
|
||||
"`User__Company`.*": {"User__Company", "*"},
|
||||
`"User__Company".*`: {"User__Company", "*"},
|
||||
`"table"."*"`: {"", ""},
|
||||
} {
|
||||
if table, column := matchName(k); table != v[0] || column != v[1] {
|
||||
t.Errorf("failed to match value: %v, got %v, expect: %v", k, []string{table, column}, v)
|
||||
if matches := nameMatcher.FindStringSubmatch(k); len(matches) < 3 || matches[1] != v[0] || matches[2] != v[1] {
|
||||
t.Errorf("failed to match value: %v, got %v, expect: %v", k, matches, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -3,7 +3,6 @@ package tests_test
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"gorm.io/gorm"
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
@ -138,7 +137,6 @@ func TestBelongsToAssociation(t *testing.T) {
|
||||
unexistCompanyID := company.ID + 9999999
|
||||
user = User{Name: "invalid-user-with-invalid-belongs-to-foreign-key", CompanyID: &unexistCompanyID}
|
||||
if err := DB.Create(&user).Error; err == nil {
|
||||
tidbSkip(t, "not support the foreign key feature")
|
||||
t.Errorf("should have gotten foreign key violation error")
|
||||
}
|
||||
}
|
||||
@ -226,81 +224,3 @@ func TestBelongsToAssociationForSlice(t *testing.T) {
|
||||
AssertAssociationCount(t, users[0], "Company", 0, "After Delete")
|
||||
AssertAssociationCount(t, users[1], "Company", 1, "After other user Delete")
|
||||
}
|
||||
|
||||
func TestBelongsToDefaultValue(t *testing.T) {
|
||||
type Org struct {
|
||||
ID string
|
||||
}
|
||||
type BelongsToUser struct {
|
||||
OrgID string
|
||||
Org Org `gorm:"default:NULL"`
|
||||
}
|
||||
|
||||
tx := DB.Session(&gorm.Session{})
|
||||
tx.Config.DisableForeignKeyConstraintWhenMigrating = true
|
||||
AssertEqual(t, DB.Config.DisableForeignKeyConstraintWhenMigrating, false)
|
||||
|
||||
tx.Migrator().DropTable(&BelongsToUser{}, &Org{})
|
||||
tx.AutoMigrate(&BelongsToUser{}, &Org{})
|
||||
|
||||
user := &BelongsToUser{
|
||||
Org: Org{
|
||||
ID: "BelongsToUser_Org_1",
|
||||
},
|
||||
}
|
||||
err := DB.Create(&user).Error
|
||||
AssertEqual(t, err, nil)
|
||||
}
|
||||
|
||||
func TestBelongsToAssociationUnscoped(t *testing.T) {
|
||||
type ItemParent struct {
|
||||
gorm.Model
|
||||
Logo string `gorm:"not null;type:varchar(50)"`
|
||||
}
|
||||
type ItemChild struct {
|
||||
gorm.Model
|
||||
Name string `gorm:"type:varchar(50)"`
|
||||
ItemParentID uint
|
||||
ItemParent ItemParent
|
||||
}
|
||||
|
||||
tx := DB.Session(&gorm.Session{})
|
||||
tx.Migrator().DropTable(&ItemParent{}, &ItemChild{})
|
||||
tx.AutoMigrate(&ItemParent{}, &ItemChild{})
|
||||
|
||||
item := ItemChild{
|
||||
Name: "name",
|
||||
ItemParent: ItemParent{
|
||||
Logo: "logo",
|
||||
},
|
||||
}
|
||||
if err := tx.Create(&item).Error; err != nil {
|
||||
t.Fatalf("failed to create items, got error: %v", err)
|
||||
}
|
||||
|
||||
// test replace
|
||||
if err := tx.Model(&item).Association("ItemParent").Unscoped().Replace(&ItemParent{
|
||||
Logo: "updated logo",
|
||||
}); err != nil {
|
||||
t.Errorf("failed to replace item parent, got error: %v", err)
|
||||
}
|
||||
|
||||
var parents []ItemParent
|
||||
if err := tx.Find(&parents).Error; err != nil {
|
||||
t.Errorf("failed to find item parent, got error: %v", err)
|
||||
}
|
||||
if len(parents) != 1 {
|
||||
t.Errorf("expected %d parents, got %d", 1, len(parents))
|
||||
}
|
||||
|
||||
// test delete
|
||||
if err := tx.Model(&item).Association("ItemParent").Unscoped().Delete(&parents); err != nil {
|
||||
t.Errorf("failed to delete item parent, got error: %v", err)
|
||||
}
|
||||
if err := tx.Find(&parents).Error; err != nil {
|
||||
t.Errorf("failed to find item parent, got error: %v", err)
|
||||
}
|
||||
if len(parents) != 0 {
|
||||
t.Errorf("expected %d parents, got %d", 0, len(parents))
|
||||
}
|
||||
}
|
||||
|
||||
@ -3,7 +3,6 @@ package tests_test
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"gorm.io/gorm"
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
@ -422,7 +421,7 @@ func TestPolymorphicHasManyAssociation(t *testing.T) {
|
||||
func TestPolymorphicHasManyAssociationForSlice(t *testing.T) {
|
||||
users := []User{
|
||||
*GetUser("slice-hasmany-1", Config{Toys: 2}),
|
||||
*GetUser("slice-hasmany-2", Config{Toys: 0, Tools: 2}),
|
||||
*GetUser("slice-hasmany-2", Config{Toys: 0}),
|
||||
*GetUser("slice-hasmany-3", Config{Toys: 4}),
|
||||
}
|
||||
|
||||
@ -430,7 +429,6 @@ func TestPolymorphicHasManyAssociationForSlice(t *testing.T) {
|
||||
|
||||
// Count
|
||||
AssertAssociationCount(t, users, "Toys", 6, "")
|
||||
AssertAssociationCount(t, users, "Tools", 2, "")
|
||||
|
||||
// Find
|
||||
var toys []Toy
|
||||
@ -438,14 +436,6 @@ func TestPolymorphicHasManyAssociationForSlice(t *testing.T) {
|
||||
t.Errorf("toys count should be %v, but got %v", 6, len(toys))
|
||||
}
|
||||
|
||||
// Find Tools (polymorphic with custom type and id)
|
||||
var tools []Tools
|
||||
DB.Model(&users).Association("Tools").Find(&tools)
|
||||
|
||||
if len(tools) != 2 {
|
||||
t.Errorf("tools count should be %v, but got %v", 2, len(tools))
|
||||
}
|
||||
|
||||
// Append
|
||||
DB.Model(&users).Association("Toys").Append(
|
||||
&Toy{Name: "toy-slice-append-1"},
|
||||
@ -481,88 +471,3 @@ func TestPolymorphicHasManyAssociationForSlice(t *testing.T) {
|
||||
DB.Model(&users).Association("Toys").Clear()
|
||||
AssertAssociationCount(t, users, "Toys", 0, "After Clear")
|
||||
}
|
||||
|
||||
func TestHasManyAssociationUnscoped(t *testing.T) {
|
||||
type ItemContent struct {
|
||||
gorm.Model
|
||||
ItemID uint `gorm:"not null"`
|
||||
Name string `gorm:"not null;type:varchar(50)"`
|
||||
LanguageCode string `gorm:"not null;type:varchar(2)"`
|
||||
}
|
||||
type Item struct {
|
||||
gorm.Model
|
||||
Logo string `gorm:"not null;type:varchar(50)"`
|
||||
Contents []ItemContent `gorm:"foreignKey:ItemID"`
|
||||
}
|
||||
|
||||
tx := DB.Session(&gorm.Session{})
|
||||
tx.Migrator().DropTable(&ItemContent{}, &Item{})
|
||||
tx.AutoMigrate(&ItemContent{}, &Item{})
|
||||
|
||||
item := Item{
|
||||
Logo: "logo",
|
||||
Contents: []ItemContent{
|
||||
{Name: "name", LanguageCode: "en"},
|
||||
{Name: "ar name", LanguageCode: "ar"},
|
||||
},
|
||||
}
|
||||
if err := tx.Create(&item).Error; err != nil {
|
||||
t.Fatalf("failed to create items, got error: %v", err)
|
||||
}
|
||||
|
||||
// test Replace
|
||||
if err := tx.Model(&item).Association("Contents").Unscoped().Replace([]ItemContent{
|
||||
{Name: "updated name", LanguageCode: "en"},
|
||||
{Name: "ar updated name", LanguageCode: "ar"},
|
||||
{Name: "le nom", LanguageCode: "fr"},
|
||||
}); err != nil {
|
||||
t.Errorf("failed to replace item content, got error: %v", err)
|
||||
}
|
||||
|
||||
if count := tx.Model(&item).Association("Contents").Count(); count != 3 {
|
||||
t.Errorf("expected %d contents, got %d", 3, count)
|
||||
}
|
||||
|
||||
var contents []ItemContent
|
||||
if err := tx.Find(&contents).Error; err != nil {
|
||||
t.Errorf("failed to find contents, got error: %v", err)
|
||||
}
|
||||
if len(contents) != 3 {
|
||||
t.Errorf("expected %d contents, got %d", 3, len(contents))
|
||||
}
|
||||
|
||||
// test delete
|
||||
if err := tx.Model(&item).Association("Contents").Unscoped().Delete(&contents[0]); err != nil {
|
||||
t.Errorf("failed to delete Contents, got error: %v", err)
|
||||
}
|
||||
if count := tx.Model(&item).Association("Contents").Count(); count != 2 {
|
||||
t.Errorf("expected %d contents, got %d", 2, count)
|
||||
}
|
||||
|
||||
// test clear
|
||||
if err := tx.Model(&item).Association("Contents").Unscoped().Clear(); err != nil {
|
||||
t.Errorf("failed to clear contents association, got error: %v", err)
|
||||
}
|
||||
if count := tx.Model(&item).Association("Contents").Count(); count != 0 {
|
||||
t.Errorf("expected %d contents, got %d", 0, count)
|
||||
}
|
||||
|
||||
if err := tx.Find(&contents).Error; err != nil {
|
||||
t.Errorf("failed to find contents, got error: %v", err)
|
||||
}
|
||||
if len(contents) != 0 {
|
||||
t.Errorf("expected %d contents, got %d", 0, len(contents))
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasManyAssociationReplaceWithNonValidValue(t *testing.T) {
|
||||
user := User{Name: "jinzhu", Languages: []Language{{Name: "EN"}}}
|
||||
|
||||
if err := DB.Create(&user).Error; err != nil {
|
||||
t.Fatalf("errors happened when create: %v", err)
|
||||
}
|
||||
|
||||
if err := DB.Model(&user).Association("Languages").Replace(Language{Name: "DE"}, Language{Name: "FR"}); err == nil {
|
||||
t.Error("expected association error to be not nil")
|
||||
}
|
||||
}
|
||||
|
||||
@ -255,15 +255,3 @@ func TestPolymorphicHasOneAssociationForSlice(t *testing.T) {
|
||||
DB.Model(&pets).Association("Toy").Clear()
|
||||
AssertAssociationCount(t, pets, "Toy", 0, "After Clear")
|
||||
}
|
||||
|
||||
func TestHasOneAssociationReplaceWithNonValidValue(t *testing.T) {
|
||||
user := User{Name: "jinzhu", Account: Account{Number: "1"}}
|
||||
|
||||
if err := DB.Create(&user).Error; err != nil {
|
||||
t.Fatalf("errors happened when create: %v", err)
|
||||
}
|
||||
|
||||
if err := DB.Model(&user).Association("Languages").Replace(Account{Number: "2"}); err == nil {
|
||||
t.Error("expected association error to be not nil")
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,12 +1,9 @@
|
||||
package tests_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
@ -98,8 +95,6 @@ func TestMany2ManyAssociation(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestMany2ManyOmitAssociations(t *testing.T) {
|
||||
tidbSkip(t, "not support the foreign key feature")
|
||||
|
||||
user := *GetUser("many2many_omit_associations", Config{Languages: 2})
|
||||
|
||||
if err := DB.Omit("Languages.*").Create(&user).Error; err == nil {
|
||||
@ -356,70 +351,3 @@ func TestDuplicateMany2ManyAssociation(t *testing.T) {
|
||||
AssertEqual(t, nil, err)
|
||||
AssertEqual(t, user2, findUser2)
|
||||
}
|
||||
|
||||
func TestConcurrentMany2ManyAssociation(t *testing.T) {
|
||||
db, err := OpenTestConnection(&gorm.Config{})
|
||||
if err != nil {
|
||||
t.Fatalf("open test connection failed, err: %+v", err)
|
||||
}
|
||||
|
||||
count := 3
|
||||
|
||||
var languages []Language
|
||||
for i := 0; i < count; i++ {
|
||||
language := Language{Code: fmt.Sprintf("consurrent %d", i)}
|
||||
db.Create(&language)
|
||||
languages = append(languages, language)
|
||||
}
|
||||
|
||||
user := User{}
|
||||
db.Create(&user)
|
||||
db.Preload("Languages").FirstOrCreate(&user)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < count; i++ {
|
||||
wg.Add(1)
|
||||
go func(user User, language Language) {
|
||||
err := db.Model(&user).Association("Languages").Append(&language)
|
||||
AssertEqual(t, err, nil)
|
||||
|
||||
wg.Done()
|
||||
}(user, languages[i])
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
var find User
|
||||
err = db.Preload(clause.Associations).Where("id = ?", user.ID).First(&find).Error
|
||||
AssertEqual(t, err, nil)
|
||||
AssertAssociationCount(t, find, "Languages", int64(count), "after concurrent append")
|
||||
}
|
||||
|
||||
func TestMany2ManyDuplicateBelongsToAssociation(t *testing.T) {
|
||||
user1 := User{Name: "TestMany2ManyDuplicateBelongsToAssociation-1", Friends: []*User{
|
||||
{Name: "TestMany2ManyDuplicateBelongsToAssociation-friend-1", Company: Company{
|
||||
ID: 1,
|
||||
Name: "Test-company-1",
|
||||
}},
|
||||
}}
|
||||
|
||||
user2 := User{Name: "TestMany2ManyDuplicateBelongsToAssociation-2", Friends: []*User{
|
||||
{Name: "TestMany2ManyDuplicateBelongsToAssociation-friend-2", Company: Company{
|
||||
ID: 1,
|
||||
Name: "Test-company-1",
|
||||
}},
|
||||
}}
|
||||
users := []*User{&user1, &user2}
|
||||
var err error
|
||||
err = DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(users).Error
|
||||
AssertEqual(t, nil, err)
|
||||
|
||||
var findUser1 User
|
||||
err = DB.Preload("Friends.Company").Where("id = ?", user1.ID).First(&findUser1).Error
|
||||
AssertEqual(t, nil, err)
|
||||
AssertEqual(t, user1, findUser1)
|
||||
|
||||
var findUser2 User
|
||||
err = DB.Preload("Friends.Company").Where("id = ?", user2.ID).First(&findUser2).Error
|
||||
AssertEqual(t, nil, err)
|
||||
AssertEqual(t, user2, findUser2)
|
||||
}
|
||||
|
||||
@ -71,8 +71,6 @@ func TestAssociationNotNullClear(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestForeignKeyConstraints(t *testing.T) {
|
||||
tidbSkip(t, "not support the foreign key feature")
|
||||
|
||||
type Profile struct {
|
||||
ID uint
|
||||
Name string
|
||||
@ -128,8 +126,6 @@ func TestForeignKeyConstraints(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestForeignKeyConstraintsBelongsTo(t *testing.T) {
|
||||
tidbSkip(t, "not support the foreign key feature")
|
||||
|
||||
type Profile struct {
|
||||
ID uint
|
||||
Name string
|
||||
@ -352,45 +348,3 @@ func TestAssociationEmptyQueryClause(t *testing.T) {
|
||||
AssertEqual(t, len(orgs), 0)
|
||||
}
|
||||
}
|
||||
|
||||
type AssociationEmptyUser struct {
|
||||
ID uint
|
||||
Name string
|
||||
Pets []AssociationEmptyPet
|
||||
}
|
||||
|
||||
type AssociationEmptyPet struct {
|
||||
AssociationEmptyUserID *uint `gorm:"uniqueIndex:uniq_user_id_name"`
|
||||
Name string `gorm:"uniqueIndex:uniq_user_id_name;size:256"`
|
||||
}
|
||||
|
||||
func TestAssociationEmptyPrimaryKey(t *testing.T) {
|
||||
if DB.Dialector.Name() != "mysql" {
|
||||
t.Skip()
|
||||
}
|
||||
DB.Migrator().DropTable(&AssociationEmptyUser{}, &AssociationEmptyPet{})
|
||||
DB.AutoMigrate(&AssociationEmptyUser{}, &AssociationEmptyPet{})
|
||||
|
||||
id := uint(100)
|
||||
user := AssociationEmptyUser{
|
||||
ID: id,
|
||||
Name: "jinzhu",
|
||||
Pets: []AssociationEmptyPet{
|
||||
{AssociationEmptyUserID: &id, Name: "bar"},
|
||||
{AssociationEmptyUserID: &id, Name: "foo"},
|
||||
},
|
||||
}
|
||||
|
||||
err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Create(&user).Error
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create, got error: %v", err)
|
||||
}
|
||||
|
||||
var result AssociationEmptyUser
|
||||
err = DB.Preload("Pets").First(&result, &id).Error
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to find, got error: %v", err)
|
||||
}
|
||||
|
||||
AssertEqual(t, result, user)
|
||||
}
|
||||
|
||||
@ -91,7 +91,7 @@ func TestCallbacks(t *testing.T) {
|
||||
},
|
||||
{
|
||||
callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3}, {h: c4}, {h: c5}, {h: c2, remove: true}},
|
||||
results: []string{"c1", "c3", "c4", "c5"},
|
||||
results: []string{"c1", "c5", "c3", "c4"},
|
||||
},
|
||||
{
|
||||
callbacks: []callback{{h: c1}, {name: "c", h: c2}, {h: c3}, {name: "c", h: c4, replace: true}},
|
||||
@ -113,9 +113,6 @@ func TestCallbacks(t *testing.T) {
|
||||
|
||||
for idx, data := range datas {
|
||||
db, err := gorm.Open(nil, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
callbacks := db.Callback()
|
||||
|
||||
for _, c := range data.callbacks {
|
||||
@ -206,49 +203,3 @@ func TestPluginCallbacks(t *testing.T) {
|
||||
t.Errorf("callbacks tests failed, got %v", msg)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCallbacksGet(t *testing.T) {
|
||||
db, _ := gorm.Open(nil, nil)
|
||||
createCallback := db.Callback().Create()
|
||||
|
||||
createCallback.Before("*").Register("c1", c1)
|
||||
if cb := createCallback.Get("c1"); reflect.DeepEqual(cb, c1) {
|
||||
t.Errorf("callbacks tests failed, got: %p, want: %p", cb, c1)
|
||||
}
|
||||
|
||||
createCallback.Remove("c1")
|
||||
if cb := createCallback.Get("c2"); cb != nil {
|
||||
t.Errorf("callbacks test failed. got: %p, want: nil", cb)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCallbacksRemove(t *testing.T) {
|
||||
db, _ := gorm.Open(nil, nil)
|
||||
createCallback := db.Callback().Create()
|
||||
|
||||
createCallback.Before("*").Register("c1", c1)
|
||||
createCallback.After("*").Register("c2", c2)
|
||||
createCallback.Before("c4").Register("c3", c3)
|
||||
createCallback.After("c2").Register("c4", c4)
|
||||
|
||||
// callbacks: []string{"c1", "c3", "c4", "c2"}
|
||||
createCallback.Remove("c1")
|
||||
if ok, msg := assertCallbacks(createCallback, []string{"c3", "c4", "c2"}); !ok {
|
||||
t.Errorf("callbacks tests failed, got %v", msg)
|
||||
}
|
||||
|
||||
createCallback.Remove("c4")
|
||||
if ok, msg := assertCallbacks(createCallback, []string{"c3", "c2"}); !ok {
|
||||
t.Errorf("callbacks tests failed, got %v", msg)
|
||||
}
|
||||
|
||||
createCallback.Remove("c2")
|
||||
if ok, msg := assertCallbacks(createCallback, []string{"c3"}); !ok {
|
||||
t.Errorf("callbacks tests failed, got %v", msg)
|
||||
}
|
||||
|
||||
createCallback.Remove("c3")
|
||||
if ok, msg := assertCallbacks(createCallback, []string{}); !ok {
|
||||
t.Errorf("callbacks tests failed, got %v", msg)
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,88 +0,0 @@
|
||||
package tests_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type Man struct {
|
||||
ID int
|
||||
Age int
|
||||
Name string
|
||||
Detail string
|
||||
}
|
||||
|
||||
// Panic-safe BeforeUpdate hook that checks for Changed("age")
|
||||
func (m *Man) BeforeUpdate(tx *gorm.DB) (err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = fmt.Errorf("panic in BeforeUpdate: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
if !tx.Statement.Changed("age") {
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Man) update(data interface{}) error {
|
||||
return DB.Set("data", data).Model(m).Where("id = ?", m.ID).Updates(data).Error
|
||||
}
|
||||
|
||||
func TestBeforeUpdateStatementChanged(t *testing.T) {
|
||||
DB.AutoMigrate(&Man{})
|
||||
type TestCase struct {
|
||||
BaseObjects Man
|
||||
change interface{}
|
||||
expectError bool
|
||||
}
|
||||
|
||||
testCases := []TestCase{
|
||||
{
|
||||
BaseObjects: Man{ID: 1, Age: 18, Name: "random-name"},
|
||||
change: struct {
|
||||
Age int
|
||||
}{Age: 20},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
BaseObjects: Man{ID: 2, Age: 18, Name: "random-name"},
|
||||
change: struct {
|
||||
Name string
|
||||
}{Name: "name-only"},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
BaseObjects: Man{ID: 2, Age: 18, Name: "random-name"},
|
||||
change: struct {
|
||||
Name string
|
||||
Age int
|
||||
}{Name: "name-only", Age: 20},
|
||||
expectError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
DB.Create(&test.BaseObjects)
|
||||
|
||||
// below comment is stored for future reference
|
||||
// err := DB.Set("data", test.change).Model(&test.BaseObjects).Where("id = ?", test.BaseObjects.ID).Updates(test.change).Error
|
||||
err := test.BaseObjects.update(test.change)
|
||||
if strings.Contains(fmt.Sprint(err), "panic in BeforeUpdate") {
|
||||
if !test.expectError {
|
||||
t.Errorf("unexpected panic in BeforeUpdate for input: %+v\nerror: %v", test.change, err)
|
||||
}
|
||||
} else {
|
||||
if test.expectError {
|
||||
t.Errorf("expected panic did not occur for input: %+v", test.change)
|
||||
}
|
||||
if err != nil {
|
||||
t.Errorf("unexpected GORM error: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -48,11 +48,9 @@ func (c *wrapperConnPool) Ping() error {
|
||||
}
|
||||
|
||||
// If you use BeginTx returned *sql.Tx as shown below then you can't record queries in a transaction.
|
||||
//
|
||||
// func (c *wrapperConnPool) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) {
|
||||
// return c.db.BeginTx(ctx, opts)
|
||||
// }
|
||||
//
|
||||
// func (c *wrapperConnPool) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) {
|
||||
// return c.db.BeginTx(ctx, opts)
|
||||
// }
|
||||
// You should use BeginTx returned gorm.Tx which could wrap *sql.Tx then you can record all queries.
|
||||
func (c *wrapperConnPool) BeginTx(ctx context.Context, opts *sql.TxOptions) (gorm.ConnPool, error) {
|
||||
tx, err := c.db.BeginTx(ctx, opts)
|
||||
@ -102,13 +100,13 @@ func TestConnPoolWrapper(t *testing.T) {
|
||||
expect: []string{
|
||||
"SELECT VERSION()",
|
||||
"INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)",
|
||||
"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT ?",
|
||||
"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1",
|
||||
"INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)",
|
||||
"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT ?",
|
||||
"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT ?",
|
||||
"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1",
|
||||
"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1",
|
||||
"INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)",
|
||||
"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT ?",
|
||||
"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT ?",
|
||||
"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1",
|
||||
"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1",
|
||||
},
|
||||
}
|
||||
|
||||
@ -118,8 +116,7 @@ func TestConnPoolWrapper(t *testing.T) {
|
||||
}
|
||||
}()
|
||||
|
||||
db, err := gorm.Open(mysql.New(mysql.Config{Conn: conn, DisableWithReturning: true}))
|
||||
db.Logger = DB.Logger
|
||||
db, err := gorm.Open(mysql.New(mysql.Config{Conn: conn}))
|
||||
if err != nil {
|
||||
t.Fatalf("Should open db success, but got %v", err)
|
||||
}
|
||||
|
||||
@ -11,32 +11,6 @@ import (
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
func TestCountWithGroup(t *testing.T) {
|
||||
DB.Create([]Company{
|
||||
{Name: "company_count_group_a"},
|
||||
{Name: "company_count_group_a"},
|
||||
{Name: "company_count_group_a"},
|
||||
{Name: "company_count_group_b"},
|
||||
{Name: "company_count_group_c"},
|
||||
})
|
||||
|
||||
var count1 int64
|
||||
if err := DB.Model(&Company{}).Where("name = ?", "company_count_group_a").Group("name").Count(&count1).Error; err != nil {
|
||||
t.Errorf(fmt.Sprintf("Count should work, but got err %v", err))
|
||||
}
|
||||
if count1 != 1 {
|
||||
t.Errorf("Count with group should be 1, but got count: %v", count1)
|
||||
}
|
||||
|
||||
var count2 int64
|
||||
if err := DB.Model(&Company{}).Where("name in ?", []string{"company_count_group_b", "company_count_group_c"}).Group("name").Count(&count2).Error; err != nil {
|
||||
t.Errorf(fmt.Sprintf("Count should work, but got err %v", err))
|
||||
}
|
||||
if count2 != 2 {
|
||||
t.Errorf("Count with group should be 2, but got count: %v", count2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCount(t *testing.T) {
|
||||
var (
|
||||
user1 = *GetUser("count-1", Config{})
|
||||
@ -168,7 +142,7 @@ func TestCount(t *testing.T) {
|
||||
DB.Create(sameUsers)
|
||||
|
||||
if err := DB.Model(&User{}).Where("name = ?", "count-4").Group("name").Count(&count11).Error; err != nil || count11 != 1 {
|
||||
t.Fatalf("Count should be 1, but got count: %v err %v", count11, err)
|
||||
t.Fatalf("Count should be 3, but got count: %v err %v", count11, err)
|
||||
}
|
||||
|
||||
var count12 int64
|
||||
|
||||
@ -2,7 +2,6 @@ package tests_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"testing"
|
||||
"time"
|
||||
@ -14,48 +13,31 @@ import (
|
||||
)
|
||||
|
||||
func TestCreate(t *testing.T) {
|
||||
u1 := *GetUser("create", Config{})
|
||||
user := *GetUser("create", Config{})
|
||||
|
||||
if results := DB.Create(&u1); results.Error != nil {
|
||||
if results := DB.Create(&user); results.Error != nil {
|
||||
t.Fatalf("errors happened when create: %v", results.Error)
|
||||
} else if results.RowsAffected != 1 {
|
||||
t.Fatalf("rows affected expects: %v, got %v", 1, results.RowsAffected)
|
||||
}
|
||||
|
||||
if u1.ID == 0 {
|
||||
t.Errorf("user's primary key should has value after create, got : %v", u1.ID)
|
||||
if user.ID == 0 {
|
||||
t.Errorf("user's primary key should has value after create, got : %v", user.ID)
|
||||
}
|
||||
|
||||
if u1.CreatedAt.IsZero() {
|
||||
if user.CreatedAt.IsZero() {
|
||||
t.Errorf("user's created at should be not zero")
|
||||
}
|
||||
|
||||
if u1.UpdatedAt.IsZero() {
|
||||
if user.UpdatedAt.IsZero() {
|
||||
t.Errorf("user's updated at should be not zero")
|
||||
}
|
||||
|
||||
var newUser User
|
||||
if err := DB.Where("id = ?", u1.ID).First(&newUser).Error; err != nil {
|
||||
if err := DB.Where("id = ?", user.ID).First(&newUser).Error; err != nil {
|
||||
t.Fatalf("errors happened when query: %v", err)
|
||||
} else {
|
||||
CheckUser(t, newUser, u1)
|
||||
}
|
||||
|
||||
type user struct {
|
||||
ID int `gorm:"primaryKey;->:false"`
|
||||
Name string
|
||||
Age int
|
||||
}
|
||||
|
||||
var u2 user
|
||||
if results := DB.Create(&u2); results.Error != nil {
|
||||
t.Fatalf("errors happened when create: %v", results.Error)
|
||||
} else if results.RowsAffected != 1 {
|
||||
t.Fatalf("rows affected expects: %v, got %v", 1, results.RowsAffected)
|
||||
}
|
||||
|
||||
if u2.ID != 0 {
|
||||
t.Errorf("don't have the permission to read primary key from db, but got %v", u2.ID)
|
||||
CheckUser(t, newUser, user)
|
||||
}
|
||||
}
|
||||
|
||||
@ -565,246 +547,3 @@ func TestFirstOrCreateRowsAffected(t *testing.T) {
|
||||
t.Fatalf("first or create rows affect err:%v rows:%d", res.Error, res.RowsAffected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateWithAutoIncrementCompositeKey(t *testing.T) {
|
||||
type CompositeKeyProduct struct {
|
||||
ProductID int `gorm:"primaryKey;autoIncrement:true;"` // primary key
|
||||
LanguageCode int `gorm:"primaryKey;"` // primary key
|
||||
Code string
|
||||
Name string
|
||||
}
|
||||
|
||||
if err := DB.Migrator().DropTable(&CompositeKeyProduct{}); err != nil {
|
||||
t.Fatalf("failed to migrate, got error %v", err)
|
||||
}
|
||||
if err := DB.AutoMigrate(&CompositeKeyProduct{}); err != nil {
|
||||
t.Fatalf("failed to migrate, got error %v", err)
|
||||
}
|
||||
|
||||
prod := &CompositeKeyProduct{
|
||||
LanguageCode: 56,
|
||||
Code: "Code56",
|
||||
Name: "ProductName56",
|
||||
}
|
||||
if err := DB.Create(&prod).Error; err != nil {
|
||||
t.Fatalf("failed to create, got error %v", err)
|
||||
}
|
||||
|
||||
newProd := &CompositeKeyProduct{}
|
||||
if err := DB.First(&newProd).Error; err != nil {
|
||||
t.Fatalf("errors happened when query: %v", err)
|
||||
} else {
|
||||
AssertObjEqual(t, newProd, prod, "ProductID", "LanguageCode", "Code", "Name")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateOnConflictWithDefaultNull(t *testing.T) {
|
||||
type OnConflictUser struct {
|
||||
ID string
|
||||
Name string `gorm:"default:null"`
|
||||
Email string
|
||||
Mobile string `gorm:"default:'133xxxx'"`
|
||||
}
|
||||
|
||||
err := DB.Migrator().DropTable(&OnConflictUser{})
|
||||
AssertEqual(t, err, nil)
|
||||
err = DB.AutoMigrate(&OnConflictUser{})
|
||||
AssertEqual(t, err, nil)
|
||||
|
||||
u := OnConflictUser{
|
||||
ID: "on-conflict-user-id",
|
||||
Name: "on-conflict-user-name",
|
||||
Email: "on-conflict-user-email",
|
||||
Mobile: "on-conflict-user-mobile",
|
||||
}
|
||||
err = DB.Create(&u).Error
|
||||
AssertEqual(t, err, nil)
|
||||
|
||||
u.Name = "on-conflict-user-name-2"
|
||||
u.Email = "on-conflict-user-email-2"
|
||||
u.Mobile = ""
|
||||
err = DB.Clauses(clause.OnConflict{UpdateAll: true}).Create(&u).Error
|
||||
AssertEqual(t, err, nil)
|
||||
|
||||
var u2 OnConflictUser
|
||||
err = DB.Where("id = ?", u.ID).First(&u2).Error
|
||||
AssertEqual(t, err, nil)
|
||||
AssertEqual(t, u2.Name, "on-conflict-user-name-2")
|
||||
AssertEqual(t, u2.Email, "on-conflict-user-email-2")
|
||||
AssertEqual(t, u2.Mobile, "133xxxx")
|
||||
}
|
||||
|
||||
func TestCreateFromMapWithoutPK(t *testing.T) {
|
||||
if !isMysql() {
|
||||
t.Skipf("This test case skipped, because of only supporting for mysql")
|
||||
}
|
||||
|
||||
// case 1: one record, create from map[string]interface{}
|
||||
mapValue1 := map[string]interface{}{"name": "create_from_map_with_schema1", "age": 1}
|
||||
if err := DB.Model(&User{}).Create(mapValue1).Error; err != nil {
|
||||
t.Fatalf("failed to create data from map, got error: %v", err)
|
||||
}
|
||||
|
||||
if _, ok := mapValue1["id"]; !ok {
|
||||
t.Fatal("failed to create data from map with table, returning map has no primary key")
|
||||
}
|
||||
|
||||
var result1 User
|
||||
if err := DB.Where("name = ?", "create_from_map_with_schema1").First(&result1).Error; err != nil || result1.Age != 1 {
|
||||
t.Fatalf("failed to create from map, got error %v", err)
|
||||
}
|
||||
|
||||
var idVal int64
|
||||
_, ok := mapValue1["id"].(uint)
|
||||
if ok {
|
||||
t.Skipf("This test case skipped, because the db supports returning")
|
||||
}
|
||||
|
||||
idVal, ok = mapValue1["id"].(int64)
|
||||
if !ok {
|
||||
t.Fatal("ret result missing id")
|
||||
}
|
||||
|
||||
if int64(result1.ID) != idVal {
|
||||
t.Fatal("failed to create data from map with table, @id != id")
|
||||
}
|
||||
|
||||
// case2: one record, create from *map[string]interface{}
|
||||
mapValue2 := map[string]interface{}{"name": "create_from_map_with_schema2", "age": 1}
|
||||
if err := DB.Model(&User{}).Create(&mapValue2).Error; err != nil {
|
||||
t.Fatalf("failed to create data from map, got error: %v", err)
|
||||
}
|
||||
|
||||
if _, ok := mapValue2["id"]; !ok {
|
||||
t.Fatal("failed to create data from map with table, returning map has no primary key")
|
||||
}
|
||||
|
||||
var result2 User
|
||||
if err := DB.Where("name = ?", "create_from_map_with_schema2").First(&result2).Error; err != nil || result2.Age != 1 {
|
||||
t.Fatalf("failed to create from map, got error %v", err)
|
||||
}
|
||||
|
||||
_, ok = mapValue2["id"].(uint)
|
||||
if ok {
|
||||
t.Skipf("This test case skipped, because the db supports returning")
|
||||
}
|
||||
|
||||
idVal, ok = mapValue2["id"].(int64)
|
||||
if !ok {
|
||||
t.Fatal("ret result missing id")
|
||||
}
|
||||
|
||||
if int64(result2.ID) != idVal {
|
||||
t.Fatal("failed to create data from map with table, @id != id")
|
||||
}
|
||||
|
||||
// case 3: records
|
||||
values := []map[string]interface{}{
|
||||
{"name": "create_from_map_with_schema11", "age": 1}, {"name": "create_from_map_with_schema12", "age": 1},
|
||||
}
|
||||
|
||||
beforeLen := len(values)
|
||||
if err := DB.Model(&User{}).Create(&values).Error; err != nil {
|
||||
t.Fatalf("failed to create data from map, got error: %v", err)
|
||||
}
|
||||
|
||||
// mariadb with returning, values will be appended with id map
|
||||
if len(values) == beforeLen*2 {
|
||||
t.Skipf("This test case skipped, because the db supports returning")
|
||||
}
|
||||
|
||||
for i := range values {
|
||||
v, ok := values[i]["id"]
|
||||
if !ok {
|
||||
t.Fatal("failed to create data from map with table, returning map has no primary key")
|
||||
}
|
||||
|
||||
var result User
|
||||
if err := DB.Where("name = ?", fmt.Sprintf("create_from_map_with_schema1%d", i+1)).First(&result).Error; err != nil || result.Age != 1 {
|
||||
t.Fatalf("failed to create from map, got error %v", err)
|
||||
}
|
||||
if int64(result.ID) != v.(int64) {
|
||||
t.Fatal("failed to create data from map with table, @id != id")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateFromMapWithTable(t *testing.T) {
|
||||
tableDB := DB.Table("users")
|
||||
supportLastInsertID := isMysql() || isSqlite()
|
||||
|
||||
// case 1: create from map[string]interface{}
|
||||
record := map[string]interface{}{"name": "create_from_map_with_table", "age": 18}
|
||||
if err := tableDB.Create(record).Error; err != nil {
|
||||
t.Fatalf("failed to create data from map with table, got error: %v", err)
|
||||
}
|
||||
|
||||
if _, ok := record["@id"]; !ok && supportLastInsertID {
|
||||
t.Fatal("failed to create data from map with table, returning map has no key '@id'")
|
||||
}
|
||||
|
||||
var res map[string]interface{}
|
||||
if err := tableDB.Select([]string{"id", "name", "age"}).Where("name = ?", "create_from_map_with_table").Find(&res).Error; err != nil || res["age"] != int64(18) {
|
||||
t.Fatalf("failed to create from map, got error %v", err)
|
||||
}
|
||||
|
||||
if _, ok := record["@id"]; ok && fmt.Sprint(res["id"]) != fmt.Sprint(record["@id"]) {
|
||||
t.Fatalf("failed to create data from map with table, @id != id, got %v, expect %v", res["id"], record["@id"])
|
||||
}
|
||||
|
||||
// case 2: create from *map[string]interface{}
|
||||
record1 := map[string]interface{}{"name": "create_from_map_with_table_1", "age": 18}
|
||||
tableDB2 := DB.Table("users")
|
||||
if err := tableDB2.Create(&record1).Error; err != nil {
|
||||
t.Fatalf("failed to create data from map, got error: %v", err)
|
||||
}
|
||||
if _, ok := record1["@id"]; !ok && supportLastInsertID {
|
||||
t.Fatal("failed to create data from map with table, returning map has no key '@id'")
|
||||
}
|
||||
|
||||
var res1 map[string]interface{}
|
||||
if err := tableDB2.Select([]string{"id", "name", "age"}).Where("name = ?", "create_from_map_with_table_1").Find(&res1).Error; err != nil || res1["age"] != int64(18) {
|
||||
t.Fatalf("failed to create from map, got error %v", err)
|
||||
}
|
||||
|
||||
if _, ok := record1["@id"]; ok && fmt.Sprint(res1["id"]) != fmt.Sprint(record1["@id"]) {
|
||||
t.Fatal("failed to create data from map with table, @id != id")
|
||||
}
|
||||
|
||||
// case 3: create from []map[string]interface{}
|
||||
records := []map[string]interface{}{
|
||||
{"name": "create_from_map_with_table_2", "age": 19},
|
||||
{"name": "create_from_map_with_table_3", "age": 20},
|
||||
}
|
||||
|
||||
tableDB = DB.Table("users")
|
||||
if err := tableDB.Create(&records).Error; err != nil {
|
||||
t.Fatalf("failed to create data from slice of map, got error: %v", err)
|
||||
}
|
||||
|
||||
if _, ok := records[0]["@id"]; !ok && supportLastInsertID {
|
||||
t.Fatal("failed to create data from map with table, returning map has no key '@id'")
|
||||
}
|
||||
|
||||
if _, ok := records[1]["@id"]; !ok && supportLastInsertID {
|
||||
t.Fatal("failed to create data from map with table, returning map has no key '@id'")
|
||||
}
|
||||
|
||||
var res2 map[string]interface{}
|
||||
if err := tableDB.Select([]string{"id", "name", "age"}).Where("name = ?", "create_from_map_with_table_2").Find(&res2).Error; err != nil || res2["age"] != int64(19) {
|
||||
t.Fatalf("failed to query data after create from slice of map, got error %v", err)
|
||||
}
|
||||
|
||||
var res3 map[string]interface{}
|
||||
if err := DB.Table("users").Select([]string{"id", "name", "age"}).Where("name = ?", "create_from_map_with_table_3").Find(&res3).Error; err != nil || res3["age"] != int64(20) {
|
||||
t.Fatalf("failed to query data after create from slice of map, got error %v", err)
|
||||
}
|
||||
|
||||
if _, ok := records[0]["@id"]; ok && fmt.Sprint(res2["id"]) != fmt.Sprint(records[0]["@id"]) {
|
||||
t.Errorf("failed to create data from map with table, @id != id, got %v, expect %v", res2["id"], records[0]["@id"])
|
||||
}
|
||||
|
||||
if _, ok := records[1]["id"]; ok && fmt.Sprint(res3["id"]) != fmt.Sprint(records[1]["@id"]) {
|
||||
t.Errorf("failed to create data from map with table, @id != id")
|
||||
}
|
||||
}
|
||||
|
||||
@ -38,22 +38,4 @@ func TestDefaultValue(t *testing.T) {
|
||||
} else if result.Name != "foo" || result.Name2 != "foo" || result.Name3 != "" || result.Age != 18 || !result.Enabled || result.Created.Format("20060102") != "20000102" {
|
||||
t.Fatalf("Failed to find created data with default data, got %+v", result)
|
||||
}
|
||||
|
||||
type Harumph2 struct {
|
||||
ID int `gorm:"default:0"`
|
||||
Email string `gorm:"not null;index:,unique"`
|
||||
Name string `gorm:"notNull;default:foo"`
|
||||
Name2 string `gorm:"size:233;not null;default:'foo'"`
|
||||
Name3 string `gorm:"size:233;notNull;default:''"`
|
||||
Age int `gorm:"default:18"`
|
||||
Created time.Time `gorm:"default:2000-01-02"`
|
||||
Enabled bool `gorm:"default:true"`
|
||||
}
|
||||
|
||||
harumph2 := Harumph2{ID: 2, Email: "hello2@gorm.io"}
|
||||
if err := DB.Table("harumphs").Create(&harumph2).Error; err != nil {
|
||||
t.Fatalf("Failed to create data with default value, got error: %v", err)
|
||||
} else if harumph2.ID != 2 || harumph2.Name != "foo" || harumph2.Name2 != "foo" || harumph2.Name3 != "" || harumph2.Age != 18 || !harumph2.Enabled || harumph2.Created.Format("20060102") != "20000102" {
|
||||
t.Fatalf("Failed to create data with default value, got: %+v", harumph2)
|
||||
}
|
||||
}
|
||||
|
||||
@ -206,9 +206,9 @@ func TestDeleteSliceWithAssociations(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// only sqlite, postgres, gaussdb, sqlserver support returning
|
||||
// only sqlite, postgres support returning
|
||||
func TestSoftDeleteReturning(t *testing.T) {
|
||||
if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" && DB.Dialector.Name() != "gaussdb" && DB.Dialector.Name() != "sqlserver" {
|
||||
if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" {
|
||||
return
|
||||
}
|
||||
|
||||
@ -233,7 +233,7 @@ func TestSoftDeleteReturning(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDeleteReturning(t *testing.T) {
|
||||
if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" && DB.Dialector.Name() != "gaussdb" && DB.Dialector.Name() != "sqlserver" {
|
||||
if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" {
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@ -1,8 +1,10 @@
|
||||
version: '3'
|
||||
|
||||
services:
|
||||
mysql:
|
||||
image: 'mysql:latest'
|
||||
image: 'mysql/mysql-server:latest'
|
||||
ports:
|
||||
- "127.0.0.1:9910:3306"
|
||||
- 9910:3306
|
||||
environment:
|
||||
- MYSQL_DATABASE=gorm
|
||||
- MYSQL_USER=gorm
|
||||
@ -11,22 +13,19 @@ services:
|
||||
postgres:
|
||||
image: 'postgres:latest'
|
||||
ports:
|
||||
- "127.0.0.1:9920:5432"
|
||||
- 9920:5432
|
||||
environment:
|
||||
- TZ=Asia/Shanghai
|
||||
- POSTGRES_DB=gorm
|
||||
- POSTGRES_USER=gorm
|
||||
- POSTGRES_PASSWORD=gorm
|
||||
mssql:
|
||||
image: '${MSSQL_IMAGE}:latest'
|
||||
image: '${MSSQL_IMAGE:-mcmoe/mssqldocker}:latest'
|
||||
ports:
|
||||
- "127.0.0.1:9930:1433"
|
||||
- 9930:1433
|
||||
environment:
|
||||
- TZ=Asia/Shanghai
|
||||
- ACCEPT_EULA=Y
|
||||
- MSSQL_SA_PASSWORD=LoremIpsum86
|
||||
tidb:
|
||||
image: 'pingcap/tidb:v6.5.0'
|
||||
ports:
|
||||
- "127.0.0.1:9940:4000"
|
||||
command: /tidb-server -store unistore -path "" -lease 0s > tidb.log 2>&1 &
|
||||
- SA_PASSWORD=LoremIpsum86
|
||||
- MSSQL_DB=gorm
|
||||
- MSSQL_USER=gorm
|
||||
- MSSQL_PASSWORD=LoremIpsum86
|
||||
@ -4,9 +4,7 @@ import (
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
@ -38,7 +36,7 @@ func TestEmbeddedStruct(t *testing.T) {
|
||||
|
||||
type EngadgetPost struct {
|
||||
BasePost BasePost `gorm:"Embedded"`
|
||||
Author *Author `gorm:"Embedded;EmbeddedPrefix:author_"` // Embedded struct
|
||||
Author Author `gorm:"Embedded;EmbeddedPrefix:author_"` // Embedded struct
|
||||
ImageUrl string
|
||||
}
|
||||
|
||||
@ -76,26 +74,13 @@ func TestEmbeddedStruct(t *testing.T) {
|
||||
t.Errorf("embedded struct's value should be scanned correctly")
|
||||
}
|
||||
|
||||
DB.Save(&EngadgetPost{BasePost: BasePost{Title: "engadget_news"}, Author: &Author{Name: "Edward"}})
|
||||
DB.Save(&EngadgetPost{BasePost: BasePost{Title: "engadget_article"}, Author: &Author{Name: "George"}})
|
||||
DB.Save(&EngadgetPost{BasePost: BasePost{Title: "engadget_news"}})
|
||||
var egNews EngadgetPost
|
||||
if err := DB.First(&egNews, "title = ?", "engadget_news").Error; err != nil {
|
||||
t.Errorf("no error should happen when query with embedded struct, but got %v", err)
|
||||
} else if egNews.BasePost.Title != "engadget_news" {
|
||||
t.Errorf("embedded struct's value should be scanned correctly")
|
||||
}
|
||||
|
||||
var egPosts []EngadgetPost
|
||||
if err := DB.Order("author_name asc").Find(&egPosts).Error; err != nil {
|
||||
t.Fatalf("no error should happen when query with embedded struct, but got %v", err)
|
||||
}
|
||||
expectAuthors := []string{"Edward", "George"}
|
||||
for i, post := range egPosts {
|
||||
t.Log(i, post.Author)
|
||||
if want := expectAuthors[i]; post.Author.Name != want {
|
||||
t.Errorf("expected author %s got %s", want, post.Author.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEmbeddedPointerTypeStruct(t *testing.T) {
|
||||
@ -105,21 +90,9 @@ func TestEmbeddedPointerTypeStruct(t *testing.T) {
|
||||
URL string
|
||||
}
|
||||
|
||||
type Author struct {
|
||||
ID string
|
||||
Name string
|
||||
Email string
|
||||
Age int
|
||||
Content Content
|
||||
ContentPtr *Content
|
||||
Birthday time.Time
|
||||
BirthdayPtr *time.Time
|
||||
}
|
||||
|
||||
type HNPost struct {
|
||||
*BasePost
|
||||
Upvotes int32
|
||||
*Author `gorm:"EmbeddedPrefix:user_"` // Embedded struct
|
||||
}
|
||||
|
||||
DB.Migrator().DropTable(&HNPost{})
|
||||
@ -137,52 +110,6 @@ func TestEmbeddedPointerTypeStruct(t *testing.T) {
|
||||
if hnPost.Title != "embedded_pointer_type" {
|
||||
t.Errorf("Should find correct value for embedded pointer type")
|
||||
}
|
||||
|
||||
if hnPost.Author != nil {
|
||||
t.Errorf("Expected to get back a nil Author but got: %v", hnPost.Author)
|
||||
}
|
||||
|
||||
now := time.Now().Round(time.Second)
|
||||
NewPost := HNPost{
|
||||
BasePost: &BasePost{Title: "embedded_pointer_type2"},
|
||||
Author: &Author{
|
||||
Name: "test",
|
||||
Content: Content{"test"},
|
||||
ContentPtr: nil,
|
||||
Birthday: now,
|
||||
BirthdayPtr: nil,
|
||||
},
|
||||
}
|
||||
DB.Create(&NewPost)
|
||||
|
||||
hnPost = HNPost{}
|
||||
if err := DB.First(&hnPost, "title = ?", NewPost.Title).Error; err != nil {
|
||||
t.Errorf("No error should happen when find embedded pointer type, but got %v", err)
|
||||
}
|
||||
|
||||
if hnPost.Title != NewPost.Title {
|
||||
t.Errorf("Should find correct value for embedded pointer type")
|
||||
}
|
||||
|
||||
if hnPost.Author.Name != NewPost.Author.Name {
|
||||
t.Errorf("Expected to get Author name %v but got: %v", NewPost.Author.Name, hnPost.Author.Name)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(NewPost.Author.Content, hnPost.Author.Content) {
|
||||
t.Errorf("Expected to get Author content %v but got: %v", NewPost.Author.Content, hnPost.Author.Content)
|
||||
}
|
||||
|
||||
if hnPost.Author.ContentPtr != nil {
|
||||
t.Errorf("Expected to get nil Author contentPtr but got: %v", hnPost.Author.ContentPtr)
|
||||
}
|
||||
|
||||
if NewPost.Author.Birthday.UnixMilli() != hnPost.Author.Birthday.UnixMilli() {
|
||||
t.Errorf("Expected to get Author birthday with %+v but got: %+v", NewPost.Author.Birthday, hnPost.Author.Birthday)
|
||||
}
|
||||
|
||||
if hnPost.Author.BirthdayPtr != nil {
|
||||
t.Errorf("Expected to get nil Author birthdayPtr but got: %+v", hnPost.Author.BirthdayPtr)
|
||||
}
|
||||
}
|
||||
|
||||
type Content struct {
|
||||
@ -190,26 +117,18 @@ type Content struct {
|
||||
}
|
||||
|
||||
func (c Content) Value() (driver.Value, error) {
|
||||
// mssql driver with issue on handling null bytes https://github.com/denisenkom/go-mssqldb/issues/530,
|
||||
b, err := json.Marshal(c)
|
||||
return string(b[:]), err
|
||||
return json.Marshal(c)
|
||||
}
|
||||
|
||||
func (c *Content) Scan(src interface{}) error {
|
||||
var value Content
|
||||
str, ok := src.(string)
|
||||
b, ok := src.([]byte)
|
||||
if !ok {
|
||||
byt, ok := src.([]byte)
|
||||
if !ok {
|
||||
return errors.New("Embedded.Scan byte assertion failed")
|
||||
}
|
||||
if err := json.Unmarshal(byt, &value); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if err := json.Unmarshal([]byte(str), &value); err != nil {
|
||||
return err
|
||||
}
|
||||
return errors.New("Embedded.Scan byte assertion failed")
|
||||
}
|
||||
|
||||
var value Content
|
||||
if err := json.Unmarshal(b, &value); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*c = value
|
||||
@ -236,15 +155,8 @@ func TestEmbeddedScanValuer(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestEmbeddedRelations(t *testing.T) {
|
||||
type EmbUser struct {
|
||||
gorm.Model
|
||||
Name string
|
||||
Age uint
|
||||
Languages []Language `gorm:"many2many:EmbUserSpeak;"`
|
||||
}
|
||||
|
||||
type AdvancedUser struct {
|
||||
EmbUser `gorm:"embedded"`
|
||||
User `gorm:"embedded"`
|
||||
Advanced bool
|
||||
}
|
||||
|
||||
@ -279,6 +191,6 @@ func TestEmbeddedTagSetting(t *testing.T) {
|
||||
err = DB.Save(&t1).Error
|
||||
AssertEqual(t, err, nil)
|
||||
if t1.Tag1.Id == 0 {
|
||||
t.Errorf("embedded struct's primary field should be rewritten")
|
||||
t.Errorf("embedded struct's primary field should be rewrited")
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,111 +0,0 @@
|
||||
package tests_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
func TestDialectorWithErrorTranslatorSupport(t *testing.T) {
|
||||
// it shouldn't translate error when the TranslateError flag is false
|
||||
translatedErr := errors.New("translated error")
|
||||
untranslatedErr := errors.New("some random error")
|
||||
db, _ := gorm.Open(tests.DummyDialector{TranslatedErr: translatedErr})
|
||||
|
||||
err := db.AddError(untranslatedErr)
|
||||
if !errors.Is(err, untranslatedErr) {
|
||||
t.Fatalf("expected err: %v got err: %v", untranslatedErr, err)
|
||||
}
|
||||
|
||||
// it should translate error when the TranslateError flag is true
|
||||
db, _ = gorm.Open(tests.DummyDialector{TranslatedErr: translatedErr}, &gorm.Config{TranslateError: true})
|
||||
|
||||
err = db.AddError(untranslatedErr)
|
||||
if !errors.Is(err, translatedErr) {
|
||||
t.Fatalf("expected err: %v got err: %v", translatedErr, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSupportedDialectorWithErrDuplicatedKey(t *testing.T) {
|
||||
type City struct {
|
||||
gorm.Model
|
||||
Name string `gorm:"unique"`
|
||||
}
|
||||
|
||||
db, err := OpenTestConnection(&gorm.Config{TranslateError: true})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect database, got error %v", err)
|
||||
}
|
||||
|
||||
dialectors := map[string]bool{"sqlite": true, "postgres": true, "gaussdb": true, "mysql": true, "sqlserver": true}
|
||||
if supported, found := dialectors[db.Dialector.Name()]; !(found && supported) {
|
||||
return
|
||||
}
|
||||
|
||||
DB.Migrator().DropTable(&City{})
|
||||
|
||||
if err = db.AutoMigrate(&City{}); err != nil {
|
||||
t.Fatalf("failed to migrate cities table, got error: %v", err)
|
||||
}
|
||||
|
||||
err = db.Create(&City{Name: "Kabul"}).Error
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create record: %v", err)
|
||||
}
|
||||
|
||||
err = db.Create(&City{Name: "Kabul"}).Error
|
||||
if !errors.Is(err, gorm.ErrDuplicatedKey) {
|
||||
t.Fatalf("expected err: %v got err: %v", gorm.ErrDuplicatedKey, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSupportedDialectorWithErrForeignKeyViolated(t *testing.T) {
|
||||
tidbSkip(t, "not support the foreign key feature")
|
||||
|
||||
type City struct {
|
||||
gorm.Model
|
||||
Name string `gorm:"unique"`
|
||||
}
|
||||
|
||||
type Museum struct {
|
||||
gorm.Model
|
||||
Name string `gorm:"unique"`
|
||||
CityID uint
|
||||
City City `gorm:"Constraint:OnUpdate:CASCADE,OnDelete:CASCADE;FOREIGNKEY:CityID;References:ID"`
|
||||
}
|
||||
|
||||
db, err := OpenTestConnection(&gorm.Config{TranslateError: true})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect database, got error %v", err)
|
||||
}
|
||||
|
||||
dialectors := map[string]bool{"sqlite": true, "postgres": true, "gaussdb": true, "mysql": true, "sqlserver": true}
|
||||
if supported, found := dialectors[db.Dialector.Name()]; !(found && supported) {
|
||||
return
|
||||
}
|
||||
|
||||
DB.Migrator().DropTable(&City{}, &Museum{})
|
||||
|
||||
if err = db.AutoMigrate(&City{}, &Museum{}); err != nil {
|
||||
t.Fatalf("failed to migrate countries & cities tables, got error: %v", err)
|
||||
}
|
||||
|
||||
city := City{Name: "Amsterdam"}
|
||||
|
||||
err = db.Create(&city).Error
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create city: %v", err)
|
||||
}
|
||||
|
||||
err = db.Create(&Museum{Name: "Eye Filmmuseum", CityID: city.ID}).Error
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create museum: %v", err)
|
||||
}
|
||||
|
||||
err = db.Create(&Museum{Name: "Dungeon", CityID: 123}).Error
|
||||
if !errors.Is(err, gorm.ErrForeignKeyViolated) {
|
||||
t.Fatalf("expected err: %v got err: %v", gorm.ErrForeignKeyViolated, err)
|
||||
}
|
||||
}
|
||||
@ -1,248 +0,0 @@
|
||||
package tests_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/lib/pq"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
func TestGaussDBReturningIDWhichHasStringType(t *testing.T) {
|
||||
t.Skipf("This test case skipped, because of gaussdb not support pgcrypto extension and gen_random_uuid() function")
|
||||
if DB.Dialector.Name() != "gaussdb" {
|
||||
t.Skip()
|
||||
}
|
||||
|
||||
type Yasuo struct {
|
||||
// TODO: function gen_random_uuid() does not exist
|
||||
ID string `gorm:"default:gen_random_uuid()"`
|
||||
Name string
|
||||
CreatedAt time.Time `gorm:"type:TIMESTAMP WITHOUT TIME ZONE"`
|
||||
UpdatedAt time.Time `gorm:"type:TIMESTAMP WITHOUT TIME ZONE;default:current_timestamp"`
|
||||
}
|
||||
|
||||
if err := DB.Exec("CREATE EXTENSION IF NOT EXISTS pgcrypto;").Error; err != nil {
|
||||
t.Errorf("Failed to create extension pgcrypto, got error %v", err)
|
||||
}
|
||||
|
||||
DB.Migrator().DropTable(&Yasuo{})
|
||||
|
||||
if err := DB.AutoMigrate(&Yasuo{}); err != nil {
|
||||
t.Fatalf("Failed to migrate for uuid default value, got error: %v", err)
|
||||
}
|
||||
|
||||
yasuo := Yasuo{Name: "jinzhu"}
|
||||
if err := DB.Create(&yasuo).Error; err != nil {
|
||||
t.Fatalf("should be able to create data, but got %v", err)
|
||||
}
|
||||
|
||||
if yasuo.ID == "" {
|
||||
t.Fatal("should be able to has ID, but got zero value")
|
||||
}
|
||||
|
||||
var result Yasuo
|
||||
if err := DB.First(&result, "id = ?", yasuo.ID).Error; err != nil || yasuo.Name != "jinzhu" {
|
||||
t.Errorf("No error should happen, but got %v", err)
|
||||
}
|
||||
|
||||
if err := DB.Where("id = $1", yasuo.ID).First(&Yasuo{}).Error; err != nil || yasuo.Name != "jinzhu" {
|
||||
t.Errorf("No error should happen, but got %v", err)
|
||||
}
|
||||
|
||||
yasuo.Name = "jinzhu1"
|
||||
if err := DB.Save(&yasuo).Error; err != nil {
|
||||
t.Errorf("Failed to update date, got error %v", err)
|
||||
}
|
||||
|
||||
if err := DB.First(&result, "id = ?", yasuo.ID).Error; err != nil || yasuo.Name != "jinzhu1" {
|
||||
t.Errorf("No error should happen, but got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGaussDB(t *testing.T) {
|
||||
t.Skipf("This test case skipped, because of gaussdb not support pgcrypto extension and gen_random_uuid() function")
|
||||
if DB.Dialector.Name() != "gaussdb" {
|
||||
t.Skip()
|
||||
}
|
||||
|
||||
type Harumph struct {
|
||||
gorm.Model
|
||||
Name string `gorm:"check:name_checker,name <> ''"`
|
||||
// TODO: function gen_random_uuid() does not exist
|
||||
Test uuid.UUID `gorm:"type:uuid;not null;default:gen_random_uuid()"`
|
||||
CreatedAt time.Time `gorm:"type:TIMESTAMP WITHOUT TIME ZONE"`
|
||||
UpdatedAt time.Time `gorm:"type:TIMESTAMP WITHOUT TIME ZONE;default:current_timestamp"`
|
||||
Things pq.StringArray `gorm:"type:text[]"`
|
||||
}
|
||||
|
||||
if err := DB.Exec("CREATE EXTENSION IF NOT EXISTS pgcrypto;").Error; err != nil {
|
||||
t.Errorf("Failed to create extension pgcrypto, got error %v", err)
|
||||
}
|
||||
|
||||
DB.Migrator().DropTable(&Harumph{})
|
||||
|
||||
if err := DB.AutoMigrate(&Harumph{}); err != nil {
|
||||
t.Fatalf("Failed to migrate for uuid default value, got error: %v", err)
|
||||
}
|
||||
|
||||
harumph := Harumph{}
|
||||
if err := DB.Create(&harumph).Error; err == nil {
|
||||
t.Fatalf("should failed to create data, name can't be blank")
|
||||
}
|
||||
|
||||
harumph = Harumph{Name: "jinzhu"}
|
||||
if err := DB.Create(&harumph).Error; err != nil {
|
||||
t.Fatalf("should be able to create data, but got %v", err)
|
||||
}
|
||||
|
||||
var result Harumph
|
||||
if err := DB.First(&result, "id = ?", harumph.ID).Error; err != nil || harumph.Name != "jinzhu" {
|
||||
t.Errorf("No error should happen, but got %v", err)
|
||||
}
|
||||
|
||||
if err := DB.Where("id = $1", harumph.ID).First(&Harumph{}).Error; err != nil || harumph.Name != "jinzhu" {
|
||||
t.Errorf("No error should happen, but got %v", err)
|
||||
}
|
||||
|
||||
harumph.Name = "jinzhu1"
|
||||
if err := DB.Save(&harumph).Error; err != nil {
|
||||
t.Errorf("Failed to update date, got error %v", err)
|
||||
}
|
||||
|
||||
if err := DB.First(&result, "id = ?", harumph.ID).Error; err != nil || harumph.Name != "jinzhu1" {
|
||||
t.Errorf("No error should happen, but got %v", err)
|
||||
}
|
||||
|
||||
DB.Migrator().DropTable("log_usage")
|
||||
|
||||
if err := DB.Exec(`
|
||||
CREATE TABLE public.log_usage (
|
||||
log_id bigint NOT NULL
|
||||
);
|
||||
|
||||
ALTER TABLE public.log_usage ALTER COLUMN log_id ADD GENERATED BY DEFAULT AS IDENTITY (
|
||||
SEQUENCE NAME public.log_usage_log_id_seq
|
||||
START WITH 1
|
||||
INCREMENT BY 1
|
||||
NO MINVALUE
|
||||
NO MAXVALUE
|
||||
CACHE 1
|
||||
);
|
||||
`).Error; err != nil {
|
||||
t.Fatalf("failed to create table, got error %v", err)
|
||||
}
|
||||
|
||||
columns, err := DB.Migrator().ColumnTypes("log_usage")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get columns, got error %v", err)
|
||||
}
|
||||
|
||||
hasLogID := false
|
||||
for _, column := range columns {
|
||||
if column.Name() == "log_id" {
|
||||
hasLogID = true
|
||||
autoIncrement, ok := column.AutoIncrement()
|
||||
if !ok || !autoIncrement {
|
||||
t.Fatalf("column log_id should be auto incrementment")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !hasLogID {
|
||||
t.Fatalf("failed to found column log_id")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGaussDBMany2ManyWithDefaultValueUUID(t *testing.T) {
|
||||
t.Skipf("This test case skipped, because of gaussdb does not have 'uuid-ossp' extension")
|
||||
if DB.Dialector.Name() != "gaussdb" {
|
||||
t.Skip()
|
||||
}
|
||||
|
||||
if err := DB.Exec(`create extension if not exists "uuid-ossp"`).Error; err != nil {
|
||||
t.Fatalf("Failed to create 'uuid-ossp' extension, but got error %v", err)
|
||||
}
|
||||
|
||||
DB.Migrator().DropTable(&Post{}, &Category{}, "post_categories")
|
||||
DB.AutoMigrate(&Post{}, &Category{})
|
||||
|
||||
post := Post{
|
||||
Title: "Hello World",
|
||||
Categories: []*Category{
|
||||
{Title: "Coding"},
|
||||
{Title: "Golang"},
|
||||
},
|
||||
}
|
||||
|
||||
if err := DB.Create(&post).Error; err != nil {
|
||||
t.Errorf("Failed, got error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGaussDBOnConstraint(t *testing.T) {
|
||||
t.Skipf("This test case skipped, because of gaussdb not support 'ON CONSTRAINT' statement")
|
||||
if DB.Dialector.Name() != "gaussdb" {
|
||||
t.Skip()
|
||||
}
|
||||
|
||||
type Thing struct {
|
||||
gorm.Model
|
||||
SomeID string
|
||||
OtherID string
|
||||
Data string
|
||||
}
|
||||
|
||||
DB.Migrator().DropTable(&Thing{})
|
||||
DB.Migrator().CreateTable(&Thing{})
|
||||
if err := DB.Exec("ALTER TABLE things ADD CONSTRAINT some_id_other_id_unique UNIQUE (some_id, other_id)").Error; err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
thing := Thing{
|
||||
SomeID: "1234",
|
||||
OtherID: "1234",
|
||||
Data: "something",
|
||||
}
|
||||
|
||||
DB.Create(&thing)
|
||||
|
||||
thing2 := Thing{
|
||||
SomeID: "1234",
|
||||
OtherID: "1234",
|
||||
Data: "something else",
|
||||
}
|
||||
|
||||
result := DB.Clauses(clause.OnConflict{
|
||||
OnConstraint: "some_id_other_id_unique",
|
||||
UpdateAll: true,
|
||||
}).Create(&thing2)
|
||||
if result.Error != nil {
|
||||
t.Errorf("creating second thing: %v", result.Error)
|
||||
}
|
||||
|
||||
var things []Thing
|
||||
if err := DB.Find(&things).Error; err != nil {
|
||||
t.Errorf("Failed, got error: %v", err)
|
||||
}
|
||||
|
||||
if len(things) > 1 {
|
||||
t.Errorf("expected 1 thing got more")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGaussDBAlterColumnDataType(t *testing.T) {
|
||||
if DB.Dialector.Name() != "gaussdb" {
|
||||
t.Skip()
|
||||
}
|
||||
DB.Migrator().DropTable(&Company{})
|
||||
DB.AutoMigrate(Company{})
|
||||
if err := DB.Table("companies").Migrator().AlterColumn(CompanyNew{}, "name"); err != nil {
|
||||
t.Fatalf("failed to alter column from string to int, got error %v", err)
|
||||
}
|
||||
|
||||
DB.AutoMigrate(Company{})
|
||||
}
|
||||
@ -1,875 +0,0 @@
|
||||
package tests_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"gorm.io/driver/mysql"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
func TestGenericsCreate(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
user := User{Name: "TestGenericsCreate", Age: 18}
|
||||
err := gorm.G[User](DB).Create(ctx, &user)
|
||||
if err != nil {
|
||||
t.Fatalf("Create failed: %v", err)
|
||||
}
|
||||
if user.ID == 0 {
|
||||
t.Fatalf("no primary key found for %v", user)
|
||||
}
|
||||
|
||||
if u, err := gorm.G[User](DB).Where("name = ?", user.Name).First(ctx); err != nil {
|
||||
t.Fatalf("failed to find user, got error: %v", err)
|
||||
} else if u.Name != user.Name || u.ID != user.ID {
|
||||
t.Errorf("found invalid user, got %v, expect %v", u, user)
|
||||
}
|
||||
|
||||
if u, err := gorm.G[User](DB).Where("name = ?", user.Name).Take(ctx); err != nil {
|
||||
t.Fatalf("failed to find user, got error: %v", err)
|
||||
} else if u.Name != user.Name || u.ID != user.ID {
|
||||
t.Errorf("found invalid user, got %v, expect %v", u, user)
|
||||
}
|
||||
|
||||
if u, err := gorm.G[User](DB).Select("name").Where("name = ?", user.Name).First(ctx); err != nil {
|
||||
t.Fatalf("failed to find user, got error: %v", err)
|
||||
} else if u.Name != user.Name || u.Age != 0 {
|
||||
t.Errorf("found invalid user, got %v, expect %v", u, user)
|
||||
}
|
||||
|
||||
if u, err := gorm.G[User](DB).Omit("name").Where("name = ?", user.Name).First(ctx); err != nil {
|
||||
t.Fatalf("failed to find user, got error: %v", err)
|
||||
} else if u.Name != "" || u.Age != user.Age {
|
||||
t.Errorf("found invalid user, got %v, expect %v", u, user)
|
||||
}
|
||||
|
||||
result := struct {
|
||||
ID int
|
||||
Name string
|
||||
}{}
|
||||
if err := gorm.G[User](DB).Where("name = ?", user.Name).Scan(ctx, &result); err != nil {
|
||||
t.Fatalf("failed to scan user, got error: %v", err)
|
||||
} else if result.Name != user.Name || uint(result.ID) != user.ID {
|
||||
t.Errorf("found invalid user, got %v, expect %v", result, user)
|
||||
}
|
||||
|
||||
mapResult, err := gorm.G[map[string]interface{}](DB).Table("users").Where("name = ?", user.Name).MapColumns(map[string]string{"name": "user_name"}).Take(ctx)
|
||||
if v := mapResult["user_name"]; fmt.Sprint(v) != user.Name {
|
||||
t.Errorf("failed to find map results, got %v, err %v", mapResult, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenericsCreateInBatches(t *testing.T) {
|
||||
batch := []User{
|
||||
{Name: "GenericsCreateInBatches1"},
|
||||
{Name: "GenericsCreateInBatches2"},
|
||||
{Name: "GenericsCreateInBatches3"},
|
||||
}
|
||||
ctx := context.Background()
|
||||
|
||||
if err := gorm.G[User](DB).CreateInBatches(ctx, &batch, 2); err != nil {
|
||||
t.Fatalf("CreateInBatches failed: %v", err)
|
||||
}
|
||||
|
||||
for _, u := range batch {
|
||||
if u.ID == 0 {
|
||||
t.Fatalf("no primary key found for %v", u)
|
||||
}
|
||||
}
|
||||
|
||||
count, err := gorm.G[User](DB).Where("name like ?", "GenericsCreateInBatches%").Count(ctx, "*")
|
||||
if err != nil {
|
||||
t.Fatalf("Count failed: %v", err)
|
||||
}
|
||||
if count != 3 {
|
||||
t.Errorf("expected 3 records, got %d", count)
|
||||
}
|
||||
|
||||
found, err := gorm.G[User](DB).Raw("SELECT * FROM users WHERE name LIKE ?", "GenericsCreateInBatches%").Find(ctx)
|
||||
if len(found) != len(batch) {
|
||||
t.Errorf("expected %d from Raw Find, got %d", len(batch), len(found))
|
||||
}
|
||||
|
||||
found, err = gorm.G[User](DB).Where("name like ?", "GenericsCreateInBatches%").Limit(2).Find(ctx)
|
||||
if len(found) != 2 {
|
||||
t.Errorf("expected %d from Raw Find, got %d", 2, len(found))
|
||||
}
|
||||
|
||||
found, err = gorm.G[User](DB).Where("name like ?", "GenericsCreateInBatches%").Offset(2).Limit(2).Find(ctx)
|
||||
if len(found) != 1 {
|
||||
t.Errorf("expected %d from Raw Find, got %d", 1, len(found))
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenericsExecAndUpdate(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
name := "GenericsExec"
|
||||
if err := gorm.G[User](DB).Exec(ctx, "INSERT INTO users(name) VALUES(?)", name); err != nil {
|
||||
t.Fatalf("Exec insert failed: %v", err)
|
||||
}
|
||||
|
||||
u, err := gorm.G[User](DB).Table("users as u").Where("u.name = ?", name).First(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to find user, got error: %v", err)
|
||||
} else if u.Name != name || u.ID == 0 {
|
||||
t.Errorf("found invalid user, got %v", u)
|
||||
}
|
||||
|
||||
name += "Update"
|
||||
rows, err := gorm.G[User](DB).Where("id = ?", u.ID).Update(ctx, "name", name)
|
||||
if rows != 1 {
|
||||
t.Fatalf("failed to get affected rows, got %d, should be %d", rows, 1)
|
||||
}
|
||||
|
||||
nu, err := gorm.G[User](DB).Where("name = ?", name).First(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to find user, got error: %v", err)
|
||||
} else if nu.Name != name || u.ID != nu.ID {
|
||||
t.Fatalf("found invalid user, got %v, expect %v", nu.ID, u.ID)
|
||||
}
|
||||
|
||||
rows, err = gorm.G[User](DB).Where("id = ?", u.ID).Updates(ctx, User{Name: "GenericsExecUpdates", Age: 18})
|
||||
if rows != 1 {
|
||||
t.Fatalf("failed to get affected rows, got %d, should be %d", rows, 1)
|
||||
}
|
||||
|
||||
nu, err = gorm.G[User](DB).Where("id = ?", u.ID).Last(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to find user, got error: %v", err)
|
||||
} else if nu.Name != "GenericsExecUpdates" || nu.Age != 18 || u.ID != nu.ID {
|
||||
t.Fatalf("found invalid user, got %v, expect %v", nu.ID, u.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenericsRow(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
user := User{Name: "GenericsRow"}
|
||||
if err := gorm.G[User](DB).Create(ctx, &user); err != nil {
|
||||
t.Fatalf("Create failed: %v", err)
|
||||
}
|
||||
|
||||
row := gorm.G[User](DB).Raw("SELECT name FROM users WHERE id = ?", user.ID).Row(ctx)
|
||||
var name string
|
||||
if err := row.Scan(&name); err != nil {
|
||||
t.Fatalf("Row scan failed: %v", err)
|
||||
}
|
||||
if name != user.Name {
|
||||
t.Errorf("expected %s, got %s", user.Name, name)
|
||||
}
|
||||
|
||||
user2 := User{Name: "GenericsRow2"}
|
||||
if err := gorm.G[User](DB).Create(ctx, &user2); err != nil {
|
||||
t.Fatalf("Create failed: %v", err)
|
||||
}
|
||||
rows, err := gorm.G[User](DB).Raw("SELECT name FROM users WHERE id IN ?", []uint{user.ID, user2.ID}).Rows(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Rows failed: %v", err)
|
||||
}
|
||||
|
||||
count := 0
|
||||
for rows.Next() {
|
||||
var name string
|
||||
if err := rows.Scan(&name); err != nil {
|
||||
t.Fatalf("rows.Scan failed: %v", err)
|
||||
}
|
||||
count++
|
||||
}
|
||||
if count != 2 {
|
||||
t.Errorf("expected 2 rows, got %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenericsDelete(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
u := User{Name: "GenericsDelete"}
|
||||
if err := gorm.G[User](DB).Create(ctx, &u); err != nil {
|
||||
t.Fatalf("Create failed: %v", err)
|
||||
}
|
||||
|
||||
rows, err := gorm.G[User](DB).Where("id = ?", u.ID).Delete(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Delete failed: %v", err)
|
||||
}
|
||||
if rows != 1 {
|
||||
t.Errorf("expected 1 row deleted, got %d", rows)
|
||||
}
|
||||
|
||||
_, err = gorm.G[User](DB).Where("id = ?", u.ID).First(ctx)
|
||||
if err != gorm.ErrRecordNotFound {
|
||||
t.Fatalf("User after delete failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenericsFindInBatches(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
users := []User{
|
||||
{Name: "GenericsFindBatchA"},
|
||||
{Name: "GenericsFindBatchB"},
|
||||
{Name: "GenericsFindBatchC"},
|
||||
{Name: "GenericsFindBatchD"},
|
||||
{Name: "GenericsFindBatchE"},
|
||||
}
|
||||
if err := gorm.G[User](DB).CreateInBatches(ctx, &users, len(users)); err != nil {
|
||||
t.Fatalf("CreateInBatches failed: %v", err)
|
||||
}
|
||||
|
||||
total := 0
|
||||
err := gorm.G[User](DB).Where("name like ?", "GenericsFindBatch%").FindInBatches(ctx, 2, func(chunk []User, batch int) error {
|
||||
if len(chunk) > 2 {
|
||||
t.Errorf("batch size exceed 2: got %d", len(chunk))
|
||||
}
|
||||
|
||||
total += len(chunk)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("FindInBatches failed: %v", err)
|
||||
}
|
||||
|
||||
if total != len(users) {
|
||||
t.Errorf("expected total %d, got %d", len(users), total)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenericsScopes(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
users := []User{{Name: "GenericsScopes1"}, {Name: "GenericsScopes2"}, {Name: "GenericsScopes3"}}
|
||||
err := gorm.G[User](DB).CreateInBatches(ctx, &users, len(users))
|
||||
if err != nil {
|
||||
t.Fatalf("CreateInBatches failed: %v", err)
|
||||
}
|
||||
|
||||
filterName1 := func(stmt *gorm.Statement) {
|
||||
stmt.Where("name = ?", "GenericsScopes1")
|
||||
}
|
||||
|
||||
results, err := gorm.G[User](DB).Scopes(filterName1).Find(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Scopes failed: %v", err)
|
||||
}
|
||||
if len(results) != 1 || results[0].Name != "GenericsScopes1" {
|
||||
t.Fatalf("Scopes expected 1, got %d", len(results))
|
||||
}
|
||||
|
||||
notResult, err := gorm.G[User](DB).Where("name like ?", "GenericsScopes%").Not("name = ?", "GenericsScopes1").Order("name").Find(ctx)
|
||||
if len(notResult) != 2 {
|
||||
t.Fatalf("expected 2 results, got %d", len(notResult))
|
||||
} else if notResult[0].Name != "GenericsScopes2" || notResult[1].Name != "GenericsScopes3" {
|
||||
t.Fatalf("expected names 'GenericsScopes2' and 'GenericsScopes3', got %s and %s", notResult[0].Name, notResult[1].Name)
|
||||
}
|
||||
|
||||
orResult, err := gorm.G[User](DB).Or("name = ?", "GenericsScopes1").Or("name = ?", "GenericsScopes2").Order("name").Find(ctx)
|
||||
if len(orResult) != 2 {
|
||||
t.Fatalf("expected 2 results, got %d", len(notResult))
|
||||
} else if orResult[0].Name != "GenericsScopes1" || orResult[1].Name != "GenericsScopes2" {
|
||||
t.Fatalf("expected names 'GenericsScopes2' and 'GenericsScopes3', got %s and %s", orResult[0].Name, orResult[1].Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenericsJoins(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
db := gorm.G[User](DB)
|
||||
|
||||
u := User{Name: "GenericsJoins", Company: Company{Name: "GenericsCompany"}}
|
||||
u2 := User{Name: "GenericsJoins_2", Company: Company{Name: "GenericsCompany_2"}}
|
||||
u3 := User{Name: "GenericsJoins_3", Company: Company{Name: "GenericsCompany_3"}}
|
||||
db.CreateInBatches(ctx, &[]User{u3, u, u2}, 10)
|
||||
|
||||
// Inner JOIN + WHERE
|
||||
result, err := db.Joins(clause.Has("Company"), func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error {
|
||||
db.Where("?.name = ?", joinTable, u.Company.Name)
|
||||
return nil
|
||||
}).First(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Joins failed: %v", err)
|
||||
}
|
||||
if result.Name != u.Name || result.Company.Name != u.Company.Name {
|
||||
t.Fatalf("Joins expected %s, got %+v", u.Name, result)
|
||||
}
|
||||
|
||||
// Inner JOIN + WHERE with map
|
||||
result, err = db.Joins(clause.Has("Company"), func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error {
|
||||
db.Where(map[string]any{"name": u.Company.Name})
|
||||
return nil
|
||||
}).First(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Joins failed: %v", err)
|
||||
}
|
||||
if result.Name != u.Name || result.Company.Name != u.Company.Name {
|
||||
t.Fatalf("Joins expected %s, got %+v", u.Name, result)
|
||||
}
|
||||
|
||||
// Left JOIN w/o WHERE
|
||||
result, err = db.Joins(clause.LeftJoin.Association("Company"), nil).Where(map[string]any{"name": u.Name}).First(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Joins failed: %v", err)
|
||||
}
|
||||
if result.Name != u.Name || result.Company.Name != u.Company.Name {
|
||||
t.Fatalf("Joins expected %s, got %+v", u.Name, result)
|
||||
}
|
||||
|
||||
// Left JOIN + Alias WHERE
|
||||
result, err = db.Joins(clause.LeftJoin.Association("Company").As("t"), func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error {
|
||||
if joinTable.Name != "t" {
|
||||
t.Fatalf("Join table should be t, but got %v", joinTable.Name)
|
||||
}
|
||||
db.Where("?.name = ?", joinTable, u.Company.Name)
|
||||
return nil
|
||||
}).Where(map[string]any{"name": u.Name}).First(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Joins failed: %v", err)
|
||||
}
|
||||
if result.Name != u.Name || result.Company.Name != u.Company.Name {
|
||||
t.Fatalf("Joins expected %s, got %+v", u.Name, result)
|
||||
}
|
||||
|
||||
// Raw Subquery JOIN + WHERE
|
||||
result, err = db.Joins(clause.LeftJoin.AssociationFrom("Company", gorm.G[Company](DB)).As("t"),
|
||||
func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error {
|
||||
if joinTable.Name != "t" {
|
||||
t.Fatalf("Join table should be t, but got %v", joinTable.Name)
|
||||
}
|
||||
db.Where("?.name = ?", joinTable, u.Company.Name)
|
||||
return nil
|
||||
},
|
||||
).Where(map[string]any{"name": u2.Name}).First(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Raw subquery join failed: %v", err)
|
||||
}
|
||||
if result.Name != u2.Name || result.Company.Name != u.Company.Name || result.Company.ID == 0 {
|
||||
t.Fatalf("Joins expected %s, got %+v", u.Name, result)
|
||||
}
|
||||
|
||||
// Raw Subquery JOIN + WHERE + Select
|
||||
result, err = db.Joins(clause.LeftJoin.AssociationFrom("Company", gorm.G[Company](DB).Select("Name")).As("t"),
|
||||
func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error {
|
||||
if joinTable.Name != "t" {
|
||||
t.Fatalf("Join table should be t, but got %v", joinTable.Name)
|
||||
}
|
||||
db.Where("?.name = ?", joinTable, u.Company.Name)
|
||||
return nil
|
||||
},
|
||||
).Where(map[string]any{"name": u2.Name}).First(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Raw subquery join failed: %v", err)
|
||||
}
|
||||
if result.Name != u2.Name || result.Company.Name != u.Company.Name || result.Company.ID != 0 {
|
||||
t.Fatalf("Joins expected %s, got %+v", u.Name, result)
|
||||
}
|
||||
|
||||
_, err = db.Joins(clause.Has("Company"), func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error {
|
||||
return errors.New("join error")
|
||||
}).First(ctx)
|
||||
if err == nil {
|
||||
t.Fatalf("Joins should got error, but got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenericsNestedJoins(t *testing.T) {
|
||||
users := []User{
|
||||
{
|
||||
Name: "generics-nested-joins-1",
|
||||
Manager: &User{
|
||||
Name: "generics-nested-joins-manager-1",
|
||||
Company: Company{
|
||||
Name: "generics-nested-joins-manager-company-1",
|
||||
},
|
||||
NamedPet: &Pet{
|
||||
Name: "generics-nested-joins-manager-namepet-1",
|
||||
Toy: Toy{
|
||||
Name: "generics-nested-joins-manager-namepet-toy-1",
|
||||
},
|
||||
},
|
||||
},
|
||||
NamedPet: &Pet{Name: "generics-nested-joins-namepet-1", Toy: Toy{Name: "generics-nested-joins-namepet-toy-1"}},
|
||||
},
|
||||
{
|
||||
Name: "generics-nested-joins-2",
|
||||
Manager: GetUser("generics-nested-joins-manager-2", Config{Company: true, NamedPet: true}),
|
||||
NamedPet: &Pet{Name: "generics-nested-joins-namepet-2", Toy: Toy{Name: "generics-nested-joins-namepet-toy-2"}},
|
||||
},
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
db := gorm.G[User](DB)
|
||||
db.CreateInBatches(ctx, &users, 100)
|
||||
|
||||
var userIDs []uint
|
||||
for _, user := range users {
|
||||
userIDs = append(userIDs, user.ID)
|
||||
}
|
||||
|
||||
users2, err := db.Joins(clause.LeftJoin.Association("Manager"), nil).
|
||||
Joins(clause.LeftJoin.Association("Manager.Company"), nil).
|
||||
Joins(clause.LeftJoin.Association("Manager.NamedPet.Toy"), nil).
|
||||
Joins(clause.LeftJoin.Association("NamedPet.Toy"), nil).
|
||||
Joins(clause.LeftJoin.Association("NamedPet").As("t"), nil).
|
||||
Where(map[string]any{"id": userIDs}).Find(ctx)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load with joins, got error: %v", err)
|
||||
} else if len(users2) != len(users) {
|
||||
t.Fatalf("Failed to load join users, got: %v, expect: %v", len(users2), len(users))
|
||||
}
|
||||
|
||||
sort.Slice(users2, func(i, j int) bool {
|
||||
return users2[i].ID > users2[j].ID
|
||||
})
|
||||
|
||||
sort.Slice(users, func(i, j int) bool {
|
||||
return users[i].ID > users[j].ID
|
||||
})
|
||||
|
||||
for idx, user := range users {
|
||||
// user
|
||||
CheckUser(t, user, users2[idx])
|
||||
if users2[idx].Manager == nil {
|
||||
t.Fatalf("Failed to load Manager")
|
||||
}
|
||||
// manager
|
||||
CheckUser(t, *user.Manager, *users2[idx].Manager)
|
||||
// user pet
|
||||
if users2[idx].NamedPet == nil {
|
||||
t.Fatalf("Failed to load NamedPet")
|
||||
}
|
||||
CheckPet(t, *user.NamedPet, *users2[idx].NamedPet)
|
||||
// manager pet
|
||||
if users2[idx].Manager.NamedPet == nil {
|
||||
t.Fatalf("Failed to load NamedPet")
|
||||
}
|
||||
CheckPet(t, *user.Manager.NamedPet, *users2[idx].Manager.NamedPet)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenericsPreloads(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
db := gorm.G[User](DB)
|
||||
|
||||
u := *GetUser("GenericsPreloads_1", Config{Company: true, Pets: 3, Friends: 7})
|
||||
u2 := *GetUser("GenericsPreloads_2", Config{Company: true, Pets: 5, Friends: 5})
|
||||
u3 := *GetUser("GenericsPreloads_3", Config{Company: true, Pets: 7, Friends: 3})
|
||||
names := []string{u.Name, u2.Name, u3.Name}
|
||||
|
||||
db.CreateInBatches(ctx, &[]User{u3, u, u2}, 10)
|
||||
|
||||
result, err := db.Preload("Company", nil).Preload("Pets", nil).Where("name = ?", u.Name).First(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Preload failed: %v", err)
|
||||
}
|
||||
|
||||
if result.Name != u.Name || result.Company.Name != u.Company.Name || len(result.Pets) != len(u.Pets) {
|
||||
t.Fatalf("Preload expected %s, got %+v", u.Name, result)
|
||||
}
|
||||
|
||||
results, err := db.Preload("Company", func(db gorm.PreloadBuilder) error {
|
||||
db.Where("name = ?", u.Company.Name)
|
||||
return nil
|
||||
}).Where("name in ?", names).Find(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Preload failed: %v", err)
|
||||
}
|
||||
for _, result := range results {
|
||||
if result.Name == u.Name {
|
||||
if result.Company.Name != u.Company.Name {
|
||||
t.Fatalf("Preload user %v company should be %v, but got %+v", u.Name, u.Company.Name, result.Company.Name)
|
||||
}
|
||||
} else if result.Company.Name != "" {
|
||||
t.Fatalf("Preload other company should not loaded, user %v company expect %v but got %+v", u.Name, u.Company.Name, result.Company.Name)
|
||||
}
|
||||
}
|
||||
|
||||
_, err = db.Preload("Company", func(db gorm.PreloadBuilder) error {
|
||||
return errors.New("preload error")
|
||||
}).Where("name in ?", names).Find(ctx)
|
||||
if err == nil {
|
||||
t.Fatalf("Preload should failed, but got nil")
|
||||
}
|
||||
|
||||
if DB.Dialector.Name() == "mysql" {
|
||||
// mysql 5.7 doesn't support row_number()
|
||||
if strings.HasPrefix(DB.Dialector.(*mysql.Dialector).ServerVersion, "5.7") {
|
||||
return
|
||||
}
|
||||
}
|
||||
results, err = db.Preload("Pets", func(db gorm.PreloadBuilder) error {
|
||||
db.LimitPerRecord(5)
|
||||
return nil
|
||||
}).Where("name in ?", names).Find(ctx)
|
||||
|
||||
for _, result := range results {
|
||||
if result.Name == u.Name {
|
||||
if len(result.Pets) != len(u.Pets) {
|
||||
t.Fatalf("Preload user %v pets should be %v, but got %+v", u.Name, u.Pets, result.Pets)
|
||||
}
|
||||
} else if len(result.Pets) != 5 {
|
||||
t.Fatalf("Preload user %v pets should be 5, but got %+v", result.Name, result.Pets)
|
||||
}
|
||||
}
|
||||
|
||||
if DB.Dialector.Name() == "sqlserver" {
|
||||
// sqlserver doesn't support order by in subquery
|
||||
return
|
||||
}
|
||||
results, err = db.Preload("Pets", func(db gorm.PreloadBuilder) error {
|
||||
db.Order("name desc").LimitPerRecord(5)
|
||||
return nil
|
||||
}).Where("name in ?", names).Find(ctx)
|
||||
|
||||
for _, result := range results {
|
||||
if result.Name == u.Name {
|
||||
if len(result.Pets) != len(u.Pets) {
|
||||
t.Fatalf("Preload user %v pets should be %v, but got %+v", u.Name, u.Pets, result.Pets)
|
||||
}
|
||||
} else if len(result.Pets) != 5 {
|
||||
t.Fatalf("Preload user %v pets should be 5, but got %+v", result.Name, result.Pets)
|
||||
}
|
||||
for i := 1; i < len(result.Pets); i++ {
|
||||
if result.Pets[i-1].Name < result.Pets[i].Name {
|
||||
t.Fatalf("Preload user %v pets not ordered correctly, last %v, cur %v", result.Name, result.Pets[i-1], result.Pets[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
results, err = db.Preload("Pets", func(db gorm.PreloadBuilder) error {
|
||||
db.Order("name").LimitPerRecord(5)
|
||||
return nil
|
||||
}).Preload("Friends", func(db gorm.PreloadBuilder) error {
|
||||
db.Order("name")
|
||||
return nil
|
||||
}).Where("name in ?", names).Find(ctx)
|
||||
|
||||
for _, result := range results {
|
||||
if result.Name == u.Name {
|
||||
if len(result.Pets) != len(u.Pets) {
|
||||
t.Fatalf("Preload user %v pets should be %v, but got %+v", u.Name, u.Pets, result.Pets)
|
||||
}
|
||||
if len(result.Friends) != len(u.Friends) {
|
||||
t.Fatalf("Preload user %v pets should be %v, but got %+v", u.Name, u.Pets, result.Pets)
|
||||
}
|
||||
} else if len(result.Pets) != 5 || len(result.Friends) == 0 {
|
||||
t.Fatalf("Preload user %v pets should be 5, but got %+v", result.Name, result.Pets)
|
||||
}
|
||||
for i := 1; i < len(result.Pets); i++ {
|
||||
if result.Pets[i-1].Name > result.Pets[i].Name {
|
||||
t.Fatalf("Preload user %v pets not ordered correctly, last %v, cur %v", result.Name, result.Pets[i-1], result.Pets[i])
|
||||
}
|
||||
}
|
||||
for i := 1; i < len(result.Pets); i++ {
|
||||
if result.Pets[i-1].Name > result.Pets[i].Name {
|
||||
t.Fatalf("Preload user %v friends not ordered correctly, last %v, cur %v", result.Name, result.Pets[i-1], result.Pets[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenericsNestedPreloads(t *testing.T) {
|
||||
user := *GetUser("generics_nested_preload", Config{Pets: 2})
|
||||
user.Friends = []*User{GetUser("generics_nested_preload", Config{Pets: 5})}
|
||||
|
||||
ctx := context.Background()
|
||||
db := gorm.G[User](DB)
|
||||
|
||||
for idx, pet := range user.Pets {
|
||||
pet.Toy = Toy{Name: "toy_nested_preload_" + strconv.Itoa(idx+1)}
|
||||
}
|
||||
|
||||
if err := db.Create(ctx, &user); err != nil {
|
||||
t.Fatalf("errors happened when create: %v", err)
|
||||
}
|
||||
|
||||
user2, err := db.Preload("Pets.Toy", nil).Preload("Friends.Pets", func(db gorm.PreloadBuilder) error {
|
||||
return nil
|
||||
}).Where(user.ID).Take(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("failed to nested preload user")
|
||||
}
|
||||
CheckUser(t, user2, user)
|
||||
if len(user.Pets) == 0 || len(user.Friends) == 0 || len(user.Friends[0].Pets) == 0 {
|
||||
t.Fatalf("failed to nested preload")
|
||||
}
|
||||
|
||||
if DB.Dialector.Name() == "mysql" {
|
||||
// mysql 5.7 doesn't support row_number()
|
||||
if strings.HasPrefix(DB.Dialector.(*mysql.Dialector).ServerVersion, "5.7") {
|
||||
return
|
||||
}
|
||||
}
|
||||
if DB.Dialector.Name() == "sqlserver" {
|
||||
// sqlserver doesn't support order by in subquery
|
||||
return
|
||||
}
|
||||
|
||||
user3, err := db.Preload("Pets.Toy", nil).Preload("Friends.Pets", func(db gorm.PreloadBuilder) error {
|
||||
db.LimitPerRecord(3)
|
||||
return nil
|
||||
}).Where(user.ID).Take(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("failed to nested preload user")
|
||||
}
|
||||
CheckUser(t, user3, user)
|
||||
|
||||
if len(user3.Friends) != 1 || len(user3.Friends[0].Pets) != 3 {
|
||||
t.Errorf("failed to nested preload with limit per record")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenericsDistinct(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
batch := []User{
|
||||
{Name: "GenericsDistinctDup"},
|
||||
{Name: "GenericsDistinctDup"},
|
||||
{Name: "GenericsDistinctUnique"},
|
||||
}
|
||||
if err := gorm.G[User](DB).CreateInBatches(ctx, &batch, len(batch)); err != nil {
|
||||
t.Fatalf("CreateInBatches failed: %v", err)
|
||||
}
|
||||
|
||||
results, err := gorm.G[User](DB).Where("name like ?", "GenericsDistinct%").Distinct("name").Find(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Distinct Find failed: %v", err)
|
||||
}
|
||||
|
||||
if len(results) != 2 {
|
||||
t.Errorf("expected 2 distinct names, got %d", len(results))
|
||||
}
|
||||
|
||||
var names []string
|
||||
for _, u := range results {
|
||||
names = append(names, u.Name)
|
||||
}
|
||||
sort.Strings(names)
|
||||
expected := []string{"GenericsDistinctDup", "GenericsDistinctUnique"}
|
||||
if !reflect.DeepEqual(names, expected) {
|
||||
t.Errorf("expected names %v, got %v", expected, names)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenericsGroupHaving(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
batch := []User{
|
||||
{Name: "GenericsGroupHavingMulti"},
|
||||
{Name: "GenericsGroupHavingMulti"},
|
||||
{Name: "GenericsGroupHavingSingle"},
|
||||
}
|
||||
if err := gorm.G[User](DB).CreateInBatches(ctx, &batch, len(batch)); err != nil {
|
||||
t.Fatalf("CreateInBatches failed: %v", err)
|
||||
}
|
||||
|
||||
grouped, err := gorm.G[User](DB).Select("name").Where("name like ?", "GenericsGroupHaving%").Group("name").Having("COUNT(id) > ?", 1).Find(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Group+Having Find failed: %v", err)
|
||||
}
|
||||
|
||||
if len(grouped) != 1 {
|
||||
t.Errorf("expected 1 group with count>1, got %d", len(grouped))
|
||||
} else if grouped[0].Name != "GenericsGroupHavingMulti" {
|
||||
t.Errorf("expected group name 'GenericsGroupHavingMulti', got '%s'", grouped[0].Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenericsSubQuery(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
users := []User{
|
||||
{Name: "GenericsSubquery_1", Age: 10},
|
||||
{Name: "GenericsSubquery_2", Age: 20},
|
||||
{Name: "GenericsSubquery_3", Age: 30},
|
||||
{Name: "GenericsSubquery_4", Age: 40},
|
||||
}
|
||||
|
||||
if err := gorm.G[User](DB).CreateInBatches(ctx, &users, len(users)); err != nil {
|
||||
t.Fatalf("CreateInBatches failed: %v", err)
|
||||
}
|
||||
|
||||
results, err := gorm.G[User](DB).Where("name IN (?)", gorm.G[User](DB).Select("name").Where("name LIKE ?", "GenericsSubquery%")).Find(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("got error: %v", err)
|
||||
}
|
||||
|
||||
if len(results) != 4 {
|
||||
t.Errorf("Four users should be found, instead found %d", len(results))
|
||||
}
|
||||
|
||||
results, err = gorm.G[User](DB).Where("name IN (?)", gorm.G[User](DB).Select("name").Where("name IN ?", []string{"GenericsSubquery_1", "GenericsSubquery_2"}).Or("name = ?", "GenericsSubquery_3")).Find(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("got error: %v", err)
|
||||
}
|
||||
|
||||
if len(results) != 3 {
|
||||
t.Errorf("Three users should be found, instead found %d", len(results))
|
||||
}
|
||||
}
|
||||
|
||||
// TestGenericsUpsert verifies OnConflict handling through the generics API:
// DoNothing skips conflicting inserts, and DoUpdates applies column
// assignments when the conflict target ("code") matches.
func TestGenericsUpsert(t *testing.T) {
	ctx := context.Background()
	lang := Language{Code: "upsert", Name: "Upsert"}

	if err := gorm.G[Language](DB, clause.OnConflict{DoNothing: true}).Create(ctx, &lang); err != nil {
		t.Fatalf("failed to upsert, got %v", err)
	}

	// Second insert with the same code: DoNothing should swallow the
	// conflict instead of returning an error.
	lang2 := Language{Code: "upsert", Name: "Upsert"}
	if err := gorm.G[Language](DB, clause.OnConflict{DoNothing: true}).Create(ctx, &lang2); err != nil {
		t.Fatalf("failed to upsert, got %v", err)
	}

	// Exactly one row should exist for the code despite two Create calls.
	langs, err := gorm.G[Language](DB).Where("code = ?", lang.Code).Find(ctx)
	if err != nil {
		t.Errorf("no error should happen when find languages with code, but got %v", err)
	} else if len(langs) != 1 {
		t.Errorf("should only find only 1 languages, but got %+v", langs)
	}

	// Third insert: on conflict with column "code", update the name column
	// instead of skipping the row.
	lang3 := Language{Code: "upsert", Name: "Upsert"}
	if err := gorm.G[Language](DB, clause.OnConflict{
		Columns:   []clause.Column{{Name: "code"}},
		DoUpdates: clause.Assignments(map[string]interface{}{"name": "upsert-new"}),
	}).Create(ctx, &lang3); err != nil {
		t.Fatalf("failed to upsert, got %v", err)
	}

	// Still one row, and its name must now reflect the DoUpdates assignment.
	if langs, err := gorm.G[Language](DB).Where("code = ?", lang.Code).Find(ctx); err != nil {
		t.Errorf("no error should happen when find languages with code, but got %v", err)
	} else if len(langs) != 1 {
		t.Errorf("should only find only 1 languages, but got %+v", langs)
	} else if langs[0].Name != "upsert-new" {
		t.Errorf("should update name on conflict, but got name %+v", langs[0].Name)
	}
}
|
||||
|
||||
func TestGenericsWithResult(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
users := []User{{Name: "TestGenericsWithResult", Age: 18}, {Name: "TestGenericsWithResult2", Age: 18}}
|
||||
|
||||
result := gorm.WithResult()
|
||||
err := gorm.G[User](DB, result).CreateInBatches(ctx, &users, 2)
|
||||
if err != nil {
|
||||
t.Errorf("failed to create users WithResult")
|
||||
}
|
||||
|
||||
if result.RowsAffected != 2 {
|
||||
t.Errorf("failed to get affected rows, got %d, should be %d", result.RowsAffected, 2)
|
||||
}
|
||||
}
|
||||
|
||||
// TestGenericsReuse checks that a stored generics chain (here: a Where name
// filter) can be reused concurrently from multiple goroutines, with each use
// layering extra conditions without leaking them into the shared chain.
func TestGenericsReuse(t *testing.T) {
	ctx := context.Background()
	users := []User{{Name: "TestGenericsReuse1", Age: 18}, {Name: "TestGenericsReuse2", Age: 18}}

	err := gorm.G[User](DB).CreateInBatches(ctx, &users, 2)
	if err != nil {
		t.Errorf("failed to create users")
	}

	// The shared, reusable query: every query below builds on this filter.
	reusedb := gorm.G[User](DB).Where("name like ?", "TestGenericsReuse%")

	sg := sync.WaitGroup{}
	for i := 0; i < 5; i++ {
		sg.Add(1)

		go func() {
			// Each goroutine runs three independent queries against the
			// shared chain; extra Where conditions must stay local to each.
			if u1, err := reusedb.Where("id = ?", users[0].ID).First(ctx); err != nil {
				t.Errorf("failed to find user, got error: %v", err)
			} else if u1.Name != users[0].Name || u1.ID != users[0].ID {
				t.Errorf("found invalid user, got %v, expect %v", u1, users[0])
			}

			if u2, err := reusedb.Where("id = ?", users[1].ID).First(ctx); err != nil {
				t.Errorf("failed to find user, got error: %v", err)
			} else if u2.Name != users[1].Name || u2.ID != users[1].ID {
				t.Errorf("found invalid user, got %v, expect %v", u2, users[1])
			}

			// Shadows the outer `users` slice inside this if-scope only.
			if users, err := reusedb.Where("id IN ?", []uint{users[0].ID, users[1].ID}).Find(ctx); err != nil {
				t.Errorf("failed to find user, got error: %v", err)
			} else if len(users) != 2 {
				t.Errorf("should find 2 users, but got %d", len(users))
			}
			// NOTE(review): sg.Done() is not deferred, so a panic inside this
			// goroutine would leave sg.Wait() hanging — consider `defer`.
			sg.Done()
		}()
	}
	sg.Wait()
}
|
||||
|
||||
func TestGenericsWithTransaction(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
tx := DB.Begin()
|
||||
if tx.Error != nil {
|
||||
t.Fatalf("failed to begin transaction: %v", tx.Error)
|
||||
}
|
||||
|
||||
users := []User{{Name: "TestGenericsTransaction", Age: 18}, {Name: "TestGenericsTransaction2", Age: 18}}
|
||||
err := gorm.G[User](tx).CreateInBatches(ctx, &users, 2)
|
||||
|
||||
count, err := gorm.G[User](tx).Where("name like ?", "TestGenericsTransaction%").Count(ctx, "*")
|
||||
if err != nil {
|
||||
t.Fatalf("Count failed: %v", err)
|
||||
}
|
||||
if count != 2 {
|
||||
t.Errorf("expected 2 records, got %d", count)
|
||||
}
|
||||
|
||||
if err := tx.Rollback().Error; err != nil {
|
||||
t.Fatalf("failed to rollback transaction: %v", err)
|
||||
}
|
||||
|
||||
count2, err := gorm.G[User](DB).Where("name like ?", "TestGenericsTransaction%").Count(ctx, "*")
|
||||
if err != nil {
|
||||
t.Fatalf("Count failed: %v", err)
|
||||
}
|
||||
if count2 != 0 {
|
||||
t.Errorf("expected 0 records after rollback, got %d", count2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenericsToSQL(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
sql := DB.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
||||
gorm.G[User](tx).Limit(10).Find(ctx)
|
||||
return tx
|
||||
})
|
||||
|
||||
if !regexp.MustCompile("SELECT \\* FROM .users..* 10").MatchString(sql) {
|
||||
t.Errorf("ToSQL: got wrong sql with Generics API %v", sql)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenericsScanUUID(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
users := []User{
|
||||
{Name: uuid.NewString(), Age: 21},
|
||||
{Name: uuid.NewString(), Age: 22},
|
||||
{Name: uuid.NewString(), Age: 23},
|
||||
}
|
||||
|
||||
if err := gorm.G[User](DB).CreateInBatches(ctx, &users, 2); err != nil {
|
||||
t.Fatalf("CreateInBatches failed: %v", err)
|
||||
}
|
||||
|
||||
userIds := []uuid.UUID{}
|
||||
if err := gorm.G[User](DB).Select("name").Where("id in ?", []uint{users[0].ID, users[1].ID, users[2].ID}).Order("age").Scan(ctx, &userIds); err != nil || len(users) != 3 {
|
||||
t.Fatalf("Scan failed: %v, userids %v", err, userIds)
|
||||
}
|
||||
|
||||
if userIds[0].String() != users[0].Name || userIds[1].String() != users[1].Name || userIds[2].String() != users[2].Name {
|
||||
t.Fatalf("wrong uuid scanned")
|
||||
}
|
||||
}
|
||||
44
tests/go.mod
44
tests/go.mod
@ -1,40 +1,20 @@
|
||||
module gorm.io/gorm/tests
|
||||
|
||||
go 1.23.0
|
||||
go 1.16
|
||||
|
||||
require (
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/jinzhu/now v1.1.5
|
||||
github.com/lib/pq v1.10.9
|
||||
github.com/stretchr/testify v1.10.0
|
||||
gorm.io/driver/gaussdb v0.1.0
|
||||
gorm.io/driver/mysql v1.6.0
|
||||
gorm.io/driver/postgres v1.6.0
|
||||
gorm.io/driver/sqlite v1.6.0
|
||||
gorm.io/driver/sqlserver v1.6.1
|
||||
gorm.io/gorm v1.30.0
|
||||
)
|
||||
|
||||
require (
|
||||
filippo.io/edwards25519 v1.1.0 // indirect
|
||||
github.com/HuaweiCloudDeveloper/gaussdb-go v1.0.0-rc1 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/go-sql-driver/mysql v1.9.3 // indirect
|
||||
github.com/denisenkom/go-mssqldb v0.12.2 // indirect
|
||||
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect
|
||||
github.com/golang-sql/sqlexp v0.1.0 // indirect
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
||||
github.com/jackc/pgx/v5 v5.7.5 // indirect
|
||||
github.com/jackc/puddle/v2 v2.2.2 // indirect
|
||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||
github.com/mattn/go-sqlite3 v1.14.28 // indirect
|
||||
github.com/microsoft/go-mssqldb v1.9.2 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/tjfoc/gmsm v1.4.1 // indirect
|
||||
golang.org/x/crypto v0.40.0 // indirect
|
||||
golang.org/x/sync v0.16.0 // indirect
|
||||
golang.org/x/text v0.27.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
github.com/google/uuid v1.3.0
|
||||
github.com/jinzhu/now v1.1.5
|
||||
github.com/lib/pq v1.10.7
|
||||
github.com/mattn/go-sqlite3 v1.14.15 // indirect
|
||||
golang.org/x/crypto v0.0.0-20220919173607-35f4265a4bc0 // indirect
|
||||
gorm.io/driver/mysql v1.3.6
|
||||
gorm.io/driver/postgres v1.3.10
|
||||
gorm.io/driver/sqlite v1.3.6
|
||||
gorm.io/driver/sqlserver v1.3.2
|
||||
gorm.io/gorm v1.23.9
|
||||
)
|
||||
|
||||
replace gorm.io/gorm => ../
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user