Compare commits


No commits in common. "master" and "v1.21.16" have entirely different histories.

157 changed files with 2250 additions and 16427 deletions


@@ -1,20 +0,0 @@
name-template: 'v Release $NEXT_PATCH_VERSION 🌈'
tag-template: 'v$NEXT_PATCH_VERSION'
categories:
  - title: '🚀 Features'
    labels:
      - 'feature'
      - 'enhancement'
  - title: '🐛 Bug Fixes'
    labels:
      - 'fix'
      - 'bugfix'
      - 'bug'
  - title: '🧰 Maintenance'
    label: 'chore'
change-template: '- $TITLE @$AUTHOR (#$NUMBER)'
change-title-escapes: '\<*_&'
template: |
  ## Changes
  $CHANGES


@@ -1,31 +0,0 @@
name: Create Release
on:
  push:
    tags:
      - 'v*.*.*'
permissions:
  contents: write
  pull-requests: read
jobs:
  create_release:
    name: Create Release
    runs-on: ubuntu-latest
    steps:
      - name: Checkout code
        uses: actions/checkout@v4
      - name: Generate Release Notes and Publish
        id: generate_release_notes
        uses: release-drafter/release-drafter@v6
        with:
          config-name: 'release-drafter.yml'
          name: "Release ${{ github.ref_name }}"
          tag: ${{ github.ref_name }}
          publish: true
          prerelease: false
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}


@@ -1,26 +0,0 @@
name: golangci-lint
on:
  push:
    branches:
      - main
      - master
  pull_request:
permissions:
  contents: read
  pull-requests: read
jobs:
  golangci:
    name: lint
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-go@v5
        with:
          go-version: stable
      - name: golangci-lint
        uses: golangci/golangci-lint-action@v7
        with:
          version: v2.0
          only-new-issues: true


@@ -3,26 +3,20 @@ on:
   schedule:
     - cron: "*/10 * * * *"
-permissions:
-  contents: read
 jobs:
   stale:
-    permissions:
-      issues: write # for actions/stale to close stale issues
-      pull-requests: write # for actions/stale to close stale PRs
     runs-on: ubuntu-latest
     env:
       ACTIONS_STEP_DEBUG: true
     steps:
       - name: Close Stale Issues
-        uses: actions/stale@v8
+        uses: actions/stale@v4
         with:
           repo-token: ${{ secrets.GITHUB_TOKEN }}
-          stale-issue-message: "This issue has been marked as invalid question, please give more information by following the `Question` template, if you believe there is a bug of GORM, please create a pull request that could reproduce the issue on [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground), the issue will be closed in 30 days if no further activity occurs. most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨"
+          stale-issue-message: "This issue has been marked as invalid question, please give more information by following the `Question` template, if you believe there is a bug of GORM, please create a pull request that could reproduce the issue on [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground), the issue will be closed in 2 days if no further activity occurs. most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨"
           stale-issue-label: "status:stale"
           days-before-stale: 0
-          days-before-close: 30
+          days-before-close: 2
           remove-stale-when-updated: true
           only-labels: "type:invalid question"


@@ -11,7 +11,7 @@ jobs:
     name: Label issues and pull requests
     steps:
       - name: check out
-        uses: actions/checkout@v4
+        uses: actions/checkout@v2
      - name: labeler
        uses: jinzhu/super-labeler-action@develop


@@ -3,25 +3,19 @@ on:
   schedule:
     - cron: "*/10 * * * *"
-permissions:
-  contents: read
 jobs:
   stale:
-    permissions:
-      issues: write # for actions/stale to close stale issues
-      pull-requests: write # for actions/stale to close stale PRs
     runs-on: ubuntu-latest
     env:
       ACTIONS_STEP_DEBUG: true
     steps:
       - name: Close Stale Issues
-        uses: actions/stale@v8
+        uses: actions/stale@v4
         with:
           repo-token: ${{ secrets.GITHUB_TOKEN }}
-          stale-issue-message: "The issue has been automatically marked as stale as it missing playground pull request link, which is important to help others understand your issue effectively and make sure the issue hasn't been fixed on latest master, checkout [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground) for details. it will be closed in 30 days if no further activity occurs. if you are asking question, please use the `Question` template, most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨"
+          stale-issue-message: "The issue has been automatically marked as stale as it missing playground pull request link, which is important to help others understand your issue effectively and make sure the issue hasn't been fixed on latest master, checkout [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground) for details. it will be closed in 2 days if no further activity occurs. if you are asking question, please use the `Question` template, most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨"
           stale-issue-label: "status:stale"
           days-before-stale: 0
-          days-before-close: 30
+          days-before-close: 2
           remove-stale-when-updated: true
           only-labels: "type:missing reproduction steps"

.github/workflows/reviewdog.yml (new file)

@@ -0,0 +1,11 @@
name: reviewdog
on: [pull_request]
jobs:
  golangci-lint:
    name: runner / golangci-lint
    runs-on: ubuntu-latest
    steps:
      - name: Check out code into the Go module directory
        uses: actions/checkout@v1
      - name: golangci-lint
        uses: reviewdog/action-golangci-lint@v2


@@ -3,25 +3,19 @@ on:
   schedule:
     - cron: "0 2 * * *"
-permissions:
-  contents: read
 jobs:
   stale:
-    permissions:
-      issues: write # for actions/stale to close stale issues
-      pull-requests: write # for actions/stale to close stale PRs
     runs-on: ubuntu-latest
     env:
       ACTIONS_STEP_DEBUG: true
     steps:
       - name: Close Stale Issues
-        uses: actions/stale@v8
+        uses: actions/stale@v4
         with:
           repo-token: ${{ secrets.GITHUB_TOKEN }}
-          stale-issue-message: "This issue has been automatically marked as stale because it has been open 360 days with no activity. Remove stale label or comment or this will be closed in 180 days"
+          stale-issue-message: "This issue has been automatically marked as stale because it has been open 60 days with no activity. Remove stale label or comment or this will be closed in 30 days"
-          days-before-stale: 360
+          days-before-stale: 60
-          days-before-close: 180
+          days-before-close: 30
           stale-issue-label: "status:stale"
           exempt-issue-labels: 'type:feature,type:with reproduction steps,type:has pull request'
           stale-pr-label: 'status:stale'


@@ -8,41 +8,38 @@ on:
     branches-ignore:
       - 'gh-pages'
-permissions:
-  contents: read
 jobs:
   # Label of the container job
   sqlite:
     strategy:
       matrix:
-        go: ['1.23', '1.24']
+        go: ['1.17', '1.16']
         platform: [ubuntu-latest] # can not run in windows OS
     runs-on: ${{ matrix.platform }}
     steps:
       - name: Set up Go 1.x
-        uses: actions/setup-go@v4
+        uses: actions/setup-go@v2
         with:
           go-version: ${{ matrix.go }}
       - name: Check out code into the Go module directory
-        uses: actions/checkout@v4
+        uses: actions/checkout@v2
       - name: go mod package cache
-        uses: actions/cache@v4
+        uses: actions/cache@v2
         with:
           path: ~/go/pkg/mod
           key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
       - name: Tests
-        run: GITHUB_ACTION=true GORM_DIALECT=sqlite ./tests/tests_all.sh
+        run: GORM_DIALECT=sqlite ./tests/tests_all.sh
   mysql:
     strategy:
       matrix:
-        dbversion: ['mysql:9', 'mysql:8', 'mysql:5.7']
+        dbversion: ['mysql:latest', 'mysql:5.7', 'mariadb:latest']
-        go: ['1.23', '1.24']
+        go: ['1.17', '1.16']
         platform: [ubuntu-latest]
     runs-on: ${{ matrix.platform }}
@@ -65,70 +62,28 @@ jobs:
     steps:
       - name: Set up Go 1.x
-        uses: actions/setup-go@v4
+        uses: actions/setup-go@v2
         with:
           go-version: ${{ matrix.go }}
       - name: Check out code into the Go module directory
-        uses: actions/checkout@v4
+        uses: actions/checkout@v2
       - name: go mod package cache
-        uses: actions/cache@v4
+        uses: actions/cache@v2
         with:
           path: ~/go/pkg/mod
           key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
       - name: Tests
-        run: GITHUB_ACTION=true GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True" ./tests/tests_all.sh
+        run: GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True" ./tests/tests_all.sh
-  mariadb:
-    strategy:
-      matrix:
-        dbversion: [ 'mariadb:latest' ]
-        go: ['1.23', '1.24']
-        platform: [ ubuntu-latest ]
-    runs-on: ${{ matrix.platform }}
-    services:
-      mysql:
-        image: ${{ matrix.dbversion }}
-        env:
-          MYSQL_DATABASE: gorm
-          MYSQL_USER: gorm
-          MYSQL_PASSWORD: gorm
-          MYSQL_RANDOM_ROOT_PASSWORD: "yes"
-        ports:
-          - 9910:3306
-        options: >-
-          --health-cmd "mariadb-admin ping -ugorm -pgorm"
-          --health-interval 10s
-          --health-start-period 10s
-          --health-timeout 5s
-          --health-retries 10
-    steps:
-      - name: Set up Go 1.x
-        uses: actions/setup-go@v4
-        with:
-          go-version: ${{ matrix.go }}
-      - name: Check out code into the Go module directory
-        uses: actions/checkout@v4
-      - name: go mod package cache
-        uses: actions/cache@v4
-        with:
-          path: ~/go/pkg/mod
-          key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
-      - name: Tests
-        run: GITHUB_ACTION=true GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True" ./tests/tests_all.sh
   postgres:
     strategy:
       matrix:
-        dbversion: ['postgres:latest', 'postgres:15', 'postgres:14', 'postgres:13']
+        dbversion: ['postgres:latest', 'postgres:13', 'postgres:12', 'postgres:11', 'postgres:10']
-        go: ['1.23', '1.24']
+        go: ['1.17', '1.16']
         platform: [ubuntu-latest] # can not run in macOS and Windows
     runs-on: ${{ matrix.platform }}
@@ -151,40 +106,42 @@ jobs:
     steps:
       - name: Set up Go 1.x
-        uses: actions/setup-go@v4
+        uses: actions/setup-go@v2
         with:
           go-version: ${{ matrix.go }}
       - name: Check out code into the Go module directory
-        uses: actions/checkout@v4
+        uses: actions/checkout@v2
       - name: go mod package cache
-        uses: actions/cache@v4
+        uses: actions/cache@v2
         with:
           path: ~/go/pkg/mod
           key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
       - name: Tests
-        run: GITHUB_ACTION=true GORM_DIALECT=postgres GORM_DSN="user=gorm password=gorm dbname=gorm host=localhost port=9920 sslmode=disable TimeZone=Asia/Shanghai" ./tests/tests_all.sh
+        run: GORM_DIALECT=postgres GORM_DSN="user=gorm password=gorm dbname=gorm host=localhost port=9920 sslmode=disable TimeZone=Asia/Shanghai" ./tests/tests_all.sh
   sqlserver:
     strategy:
       matrix:
-        go: ['1.23', '1.24']
+        go: ['1.17', '1.16']
         platform: [ubuntu-latest] # can not run test in macOS and windows
     runs-on: ${{ matrix.platform }}
     services:
       mssql:
-        image: mcr.microsoft.com/mssql/server:2022-latest
+        image: mcmoe/mssqldocker:latest
         env:
-          TZ: Asia/Shanghai
           ACCEPT_EULA: Y
-          MSSQL_SA_PASSWORD: LoremIpsum86
+          SA_PASSWORD: LoremIpsum86
+          MSSQL_DB: gorm
+          MSSQL_USER: gorm
+          MSSQL_PASSWORD: LoremIpsum86
         ports:
          - 9930:1433
        options: >-
-          --health-cmd="/opt/mssql-tools18/bin/sqlcmd -S localhost -U sa -P ${MSSQL_SA_PASSWORD} -N -C -l 30 -Q \"SELECT 1\" || exit 1"
+          --health-cmd="/opt/mssql-tools/bin/sqlcmd -S localhost -U sa -P LoremIpsum86 -l 30 -Q \"SELECT 1\" || exit 1"
           --health-start-period 10s
           --health-interval 10s
           --health-timeout 5s
@@ -192,119 +149,18 @@ jobs:
     steps:
       - name: Set up Go 1.x
-        uses: actions/setup-go@v4
+        uses: actions/setup-go@v2
         with:
           go-version: ${{ matrix.go }}
       - name: Check out code into the Go module directory
-        uses: actions/checkout@v4
+        uses: actions/checkout@v2
       - name: go mod package cache
-        uses: actions/cache@v4
+        uses: actions/cache@v2
         with:
           path: ~/go/pkg/mod
           key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
       - name: Tests
-        run: GITHUB_ACTION=true GORM_DIALECT=sqlserver GORM_DSN="sqlserver://sa:LoremIpsum86@localhost:9930?database=master" ./tests/tests_all.sh
+        run: GORM_DIALECT=sqlserver GORM_DSN="sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm" ./tests/tests_all.sh
-  tidb:
-    strategy:
-      matrix:
-        dbversion: [ 'v6.5.0' ]
-        go: ['1.23', '1.24']
-        platform: [ ubuntu-latest ]
-    runs-on: ${{ matrix.platform }}
-    steps:
-      - name: Setup TiDB
-        uses: Icemap/tidb-action@main
-        with:
-          port: 9940
-          version: ${{matrix.dbversion}}
-      - name: Set up Go 1.x
-        uses: actions/setup-go@v4
-        with:
-          go-version: ${{ matrix.go }}
-      - name: Check out code into the Go module directory
-        uses: actions/checkout@v4
-      - name: go mod package cache
-        uses: actions/cache@v4
-        with:
-          path: ~/go/pkg/mod
-          key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
-      - name: Tests
-        run: GITHUB_ACTION=true GORM_DIALECT=tidb GORM_DSN="root:@tcp(localhost:9940)/test?charset=utf8&parseTime=True&loc=Local" ./tests/tests_all.sh
-  gaussdb:
-    strategy:
-      matrix:
-        dbversion: ['opengauss/opengauss:7.0.0-RC1.B023']
-        go: ['1.23', '1.24']
-        platform: [ubuntu-latest] # can not run in macOS and Windows
-    runs-on: ${{ matrix.platform }}
-    services:
-      gaussdb:
-        image: ${{ matrix.dbversion }}
-        env:
-          # GaussDB has password limitations
-          GS_PASSWORD: Gaussdb@123
-          TZ: Asia/Shanghai
-        ports:
-          - 9950:5432
-    steps:
-      - name: Set up Go 1.x
-        uses: actions/setup-go@v4
-        with:
-          go-version: ${{ matrix.go }}
-      - name: Check out code into the Go module directory
-        uses: actions/checkout@v4
-      - name: Waiting for GaussDB to be ready
-        run: |
-          container_name=$(docker ps --filter "ancestor=opengauss/opengauss:7.0.0-RC1.B023" --format "{{.Names}}")
-          if [ -z "$container_name" ]; then
-            echo "Error: failed to find a container created from the 'opengauss/opengauss:7.0.0-RC1.B023' image."
-            exit 1
-          fi
-          max_retries=12
-          retry_count=0
-          if [ -t 0 ]; then
-            TTY_FLAG="-t"
-          else
-            TTY_FLAG=""
-          fi
-          while [ $retry_count -lt $max_retries ]; do
-            if docker exec -i "${container_name}" bash -c "su - omm -c 'gsql -U omm -c \"select 1;\"'"
-            then
-              echo "Creating database gorm..."
-              sql_file='/tmp/create_database.sql'
-              echo "CREATE DATABASE gorm DBCOMPATIBILITY 'PG';" > ${sql_file}
-              docker cp "${sql_file}" "${container_name}":"${sql_file}"
-              docker exec -i ${TTY_FLAG} "${container_name}" bash -c "su - omm -c 'gsql -U omm -f ${sql_file}'"
-              echo "Database initialization completed."
-              break
-            fi
-            echo "Waiting for database to be ready... (attempt $((retry_count + 1))/$max_retries)"
-            sleep 10
-            ((++retry_count))
-          done
-          exit 0
-      - name: go mod package cache
-        uses: actions/cache@v4
-        with:
-          path: ~/go/pkg/mod
-          key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
-      - name: Tests
-        run: GITHUB_ACTION=true GORM_DIALECT=gaussdb GORM_DSN="user=gaussdb password=Gaussdb@123 dbname=gorm host=localhost port=9950 sslmode=disable TimeZone=Asia/Shanghai" ./tests/tests_all.sh

.gitignore

@@ -3,5 +3,3 @@ documents
 coverage.txt
 _book
 .idea
-vendor
-.vscode


@@ -1,19 +0,0 @@
version: "2"
linters:
  default: standard
  enable:
    - cyclop
    - gocritic
    - gosec
    - ineffassign
    - misspell
    - prealloc
    - unconvert
    - unparam
    - whitespace
formatters:
  enable:
    - gofumpt
    - goimports


@@ -1,128 +0,0 @@
# Contributor Covenant Code of Conduct
## Our Pledge
We as members, contributors, and leaders pledge to make participation in our
community a harassment-free experience for everyone, regardless of age, body
size, visible or invisible disability, ethnicity, sex characteristics, gender
identity and expression, level of experience, education, socio-economic status,
nationality, personal appearance, race, religion, or sexual identity
and orientation.
We pledge to act and interact in ways that contribute to an open, welcoming,
diverse, inclusive, and healthy community.
## Our Standards
Examples of behavior that contributes to a positive environment for our
community includes:
* Demonstrating empathy and kindness toward other people
* Being respectful of differing opinions, viewpoints, and experiences
* Giving and gracefully accepting constructive feedback
* Accepting responsibility and apologizing to those affected by our mistakes,
and learning from the experience
* Focusing on what is best not just for us as individuals, but for the
overall community
Examples of unacceptable behavior include:
* The use of sexualized language or imagery, and sexual attention or
advances of any kind
* Trolling, insulting or derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or email
address, without their explicit permission
* Other conduct which could reasonably be considered inappropriate in a
professional setting
## Enforcement Responsibilities
Community leaders are responsible for clarifying and enforcing our standards of
acceptable behavior and will take appropriate and fair corrective action in
response to any behavior that they deem inappropriate, threatening, offensive,
or harmful.
Community leaders have the right and responsibility to remove, edit, or reject
comments, commits, code, wiki edits, issues, and other contributions that are
not aligned to this Code of Conduct, and will communicate reasons for moderation
decisions when appropriate.
## Scope
This Code of Conduct applies within all community spaces and also applies when
an individual is officially representing the community in public spaces.
Examples of representing our community include using an official e-mail address,
posting via an official social media account, or acting as an appointed
representative at an online or offline event.
## Enforcement
Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported to the community leaders responsible for enforcement at
.
All complaints will be reviewed and investigated promptly and fairly.
All community leaders are obligated to respect the privacy and security of the
reporter of any incident.
## Enforcement Guidelines
Community leaders will follow these Community Impact Guidelines in determining
the consequences for any action they deem in violation of this Code of Conduct:
### 1. Correction
**Community Impact**: Use of inappropriate language or other behavior deemed
unprofessional or unwelcome in the community.
**Consequence**: A private, written warning from community leaders, providing
clarity around the nature of the violation and an explanation of why the
behavior was inappropriate. A public apology may be requested.
### 2. Warning
**Community Impact**: A violation through a single incident or series
of actions.
**Consequence**: A warning with consequences for continued behavior. No
interaction with the people involved, including unsolicited interaction with
those enforcing the Code of Conduct, for a specified period. This
includes avoiding interactions in community spaces and external channels
like social media. Violating these terms may lead to a temporary or
permanent ban.
### 3. Temporary Ban
**Community Impact**: A serious violation of community standards, including
sustained inappropriate behavior.
**Consequence**: A temporary ban from any interaction or public
communication with the community for a specified period. No public or
private interaction with the people involved, including unsolicited interaction
with those enforcing the Code of Conduct, is allowed during this period.
Violating these terms may lead to a permanent ban.
### 4. Permanent Ban
**Community Impact**: Demonstrating a pattern of violation of community
standards, including sustained inappropriate behavior, harassment of an
individual, or aggression toward or disparagement of classes of individuals.
**Consequence**: A permanent ban from any sort of public interaction within
the community.
## Attribution
This Code of Conduct is adapted from the [Contributor Covenant][homepage],
version 2.0, available at
https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
Community Impact Guidelines were inspired by [Mozilla's code of conduct
enforcement ladder](https://github.com/mozilla/diversity).
[homepage]: https://www.contributor-covenant.org
For answers to common questions about this code of conduct, see the FAQ at
https://www.contributor-covenant.org/faq. Translations are available at
https://www.contributor-covenant.org/translations.


@@ -1,6 +1,6 @@
 The MIT License (MIT)
-Copyright (c) 2013-present Jinzhu <wosmvp@gmail.com>
+Copyright (c) 2013-NOW Jinzhu <wosmvp@gmail.com>
 Permission is hereby granted, free of charge, to any person obtaining a copy
 of this software and associated documentation files (the "Software"), to deal


@@ -4,6 +4,9 @@ The fantastic ORM library for Golang, aims to be developer friendly.
 [![go report card](https://goreportcard.com/badge/github.com/go-gorm/gorm "go report card")](https://goreportcard.com/report/github.com/go-gorm/gorm)
 [![test status](https://github.com/go-gorm/gorm/workflows/tests/badge.svg?branch=master "test status")](https://github.com/go-gorm/gorm/actions)
+[![Join the chat at https://gitter.im/jinzhu/gorm](https://img.shields.io/gitter/room/jinzhu/gorm.svg)](https://gitter.im/jinzhu/gorm?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)
+[![Open Collective Backer](https://opencollective.com/gorm/tiers/backer/badge.svg?label=backer&color=brightgreen "Open Collective Backer")](https://opencollective.com/gorm)
+[![Open Collective Sponsor](https://opencollective.com/gorm/tiers/sponsor/badge.svg?label=sponsor&color=brightgreen "Open Collective Sponsor")](https://opencollective.com/gorm)
 [![MIT license](https://img.shields.io/badge/license-MIT-brightgreen.svg)](https://opensource.org/licenses/MIT)
 [![Go.Dev reference](https://img.shields.io/badge/go.dev-reference-blue?logo=go&logoColor=white)](https://pkg.go.dev/gorm.io/gorm?tab=doc)
@@ -27,18 +30,13 @@ The fantastic ORM library for Golang, aims to be developer friendly.
 ## Getting Started
 * GORM Guides [https://gorm.io](https://gorm.io)
-* Gen Guides [https://gorm.io/gen/index.html](https://gorm.io/gen/index.html)
 ## Contributing
 [You can help to deliver a better GORM, check out things you can do](https://gorm.io/contribute.html)
-## Contributors
-[Thank you](https://github.com/go-gorm/gorm/graphs/contributors) for contributing to the GORM framework!
 ## License
 © Jinzhu, 2013~time.Now
-Released under the [MIT License](https://github.com/go-gorm/gorm/blob/master/LICENSE)
+Released under the [MIT License](https://github.com/go-gorm/gorm/blob/master/License)


@@ -14,7 +14,6 @@ import (
 type Association struct {
 	DB           *DB
 	Relationship *schema.Relationship
-	Unscope      bool
 	Error        error
 }
@@ -41,15 +40,6 @@ func (db *DB) Association(column string) *Association {
 	return association
 }
-func (association *Association) Unscoped() *Association {
-	return &Association{
-		DB:           association.DB,
-		Relationship: association.Relationship,
-		Error:        association.Error,
-		Unscope:      true,
-	}
-}
 func (association *Association) Find(out interface{}, conds ...interface{}) error {
 	if association.Error == nil {
 		association.Error = association.buildCondition().Find(out, conds...).Error
@@ -74,30 +64,14 @@ func (association *Association) Append(values ...interface{}) error {
 func (association *Association) Replace(values ...interface{}) error {
 	if association.Error == nil {
-		reflectValue := association.DB.Statement.ReflectValue
-		rel := association.Relationship
-		var oldBelongsToExpr clause.Expression
-		// we have to record the old BelongsTo value
-		if association.Unscope && rel.Type == schema.BelongsTo {
-			var foreignFields []*schema.Field
-			for _, ref := range rel.References {
-				if !ref.OwnPrimaryKey {
-					foreignFields = append(foreignFields, ref.ForeignKey)
-				}
-			}
-			if _, fvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, foreignFields); len(fvs) > 0 {
-				column, values := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, fvs)
-				oldBelongsToExpr = clause.IN{Column: column, Values: values}
-			}
-		}
 		// save associations
 		if association.saveAssociation( /*clear*/ true, values...); association.Error != nil {
 			return association.Error
 		}
 		// set old associations's foreign key to null
+		reflectValue := association.DB.Statement.ReflectValue
+		rel := association.Relationship
 		switch rel.Type {
 		case schema.BelongsTo:
 			if len(values) == 0 {
@@ -105,10 +79,10 @@ func (association *Association) Replace(values ...interface{}) error {
 				switch reflectValue.Kind() {
 				case reflect.Slice, reflect.Array:
 					for i := 0; i < reflectValue.Len(); i++ {
-						association.Error = rel.Field.Set(association.DB.Statement.Context, reflectValue.Index(i), reflect.Zero(rel.Field.FieldType).Interface())
+						association.Error = rel.Field.Set(reflectValue.Index(i), reflect.Zero(rel.Field.FieldType).Interface())
 					}
 				case reflect.Struct:
-					association.Error = rel.Field.Set(association.DB.Statement.Context, reflectValue, reflect.Zero(rel.Field.FieldType).Interface())
+					association.Error = rel.Field.Set(reflectValue, reflect.Zero(rel.Field.FieldType).Interface())
 				}
 				for _, ref := range rel.References {
@@ -117,20 +91,17 @@ func (association *Association) Replace(values ...interface{}) error {
 				association.Error = association.DB.UpdateColumns(updateMap).Error
 			}
-			if association.Unscope && oldBelongsToExpr != nil {
-				association.Error = association.DB.Model(nil).Where(oldBelongsToExpr).Delete(reflect.New(rel.FieldSchema.ModelType).Interface()).Error
-			}
 		case schema.HasOne, schema.HasMany:
 			var (
 				primaryFields []*schema.Field
 				foreignKeys   []string
 				updateMap     = map[string]interface{}{}
-				relValues     = schema.GetRelationsValues(association.DB.Statement.Context, reflectValue, []*schema.Relationship{rel})
+				relValues     = schema.GetRelationsValues(reflectValue, []*schema.Relationship{rel})
 				modelValue    = reflect.New(rel.FieldSchema.ModelType).Interface()
 				tx            = association.DB.Model(modelValue)
 			)
-			if _, rvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, relValues, rel.FieldSchema.PrimaryFields); len(rvs) > 0 {
+			if _, rvs := schema.GetIdentityFieldValuesMap(relValues, rel.FieldSchema.PrimaryFields); len(rvs) > 0 {
 				if column, values := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs); len(values) > 0 {
 					tx.Not(clause.IN{Column: column, Values: values})
 				}
@@ -146,13 +117,9 @@ func (association *Association) Replace(values ...interface{}) error {
 				}
 			}
-			if _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields); len(pvs) > 0 {
+			if _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields); len(pvs) > 0 {
 				column, values := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs)
-				if association.Unscope {
-					association.Error = tx.Where(clause.IN{Column: column, Values: values}).Delete(modelValue).Error
-				} else {
-					association.Error = tx.Where(clause.IN{Column: column, Values: values}).UpdateColumns(updateMap).Error
-				}
+				association.Error = tx.Where(clause.IN{Column: column, Values: values}).UpdateColumns(updateMap).Error
 			}
 		case schema.Many2Many:
 			var (
@@ -176,14 +143,14 @@ func (association *Association) Replace(values ...interface{}) error {
 				}
 			}
-			_, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields)
+			_, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields)
 			if column, values := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs); len(values) > 0 {
 				tx.Where(clause.IN{Column: column, Values: values})
 			} else {
 				return ErrPrimaryKeyRequired
 			}
-			_, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, relPrimaryFields)
+			_, rvs := schema.GetIdentityFieldValuesMapFromValues(values, relPrimaryFields)
 			if relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs); len(relValues) > 0 {
 				tx.Where(clause.Not(clause.IN{Column: relColumn, Values: relValues}))
 			}
@@ -217,53 +184,29 @@ func (association *Association) Delete(values ...interface{}) error {
 		switch rel.Type {
 		case schema.BelongsTo:
-			associationDB := association.DB.Session(&Session{})
-			tx := associationDB.Model(reflect.New(rel.Schema.ModelType).Interface())
+			tx := association.DB.Model(reflect.New(rel.Schema.ModelType).Interface())
-			_, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, rel.Schema.PrimaryFields)
+			_, pvs := schema.GetIdentityFieldValuesMap(reflectValue, rel.Schema.PrimaryFields)
-			if pcolumn, pvalues := schema.ToQueryValues(rel.Schema.Table, rel.Schema.PrimaryFieldDBNames, pvs); len(pvalues) > 0 {
-				conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues})
-			} else {
-				return ErrPrimaryKeyRequired
-			}
+			pcolumn, pvalues := schema.ToQueryValues(rel.Schema.Table, rel.Schema.PrimaryFieldDBNames, pvs)
+			conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues})
-			_, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, primaryFields)
+			_, rvs := schema.GetIdentityFieldValuesMapFromValues(values, primaryFields)
 			relColumn, relValues := schema.ToQueryValues(rel.Schema.Table, foreignKeys, rvs)
 			conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
 			association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error
-			if association.Unscope {
-				var foreignFields []*schema.Field
-				for _, ref := range rel.References {
-					if !ref.OwnPrimaryKey {
-						foreignFields = append(foreignFields, ref.ForeignKey)
-					}
-				}
-				if _, fvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, foreignFields); len(fvs) > 0 {
-					column, values := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, fvs)
-					association.Error = associationDB.Model(nil).Where(clause.IN{Column: column, Values: values}).Delete(reflect.New(rel.FieldSchema.ModelType).Interface()).Error
-				}
-			}
 		case schema.HasOne, schema.HasMany:
-			model := reflect.New(rel.FieldSchema.ModelType).Interface()
-			tx := association.DB.Model(model)
+			tx := association.DB.Model(reflect.New(rel.FieldSchema.ModelType).Interface())
-			_, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields)
+			_, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields)
-			if pcolumn, pvalues := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs); len(pvalues) > 0 {
-				conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues})
-			} else {
-				return ErrPrimaryKeyRequired
-			}
+			pcolumn, pvalues := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs)
+			conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues})
-			_, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, rel.FieldSchema.PrimaryFields)
+			_, rvs := schema.GetIdentityFieldValuesMapFromValues(values, rel.FieldSchema.PrimaryFields)
 			relColumn, relValues := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs)
 			conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
-			if association.Unscope {
-				association.Error = tx.Clauses(conds...).Delete(model).Error
-			} else {
-				association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error
-			}
+			association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error
 		case schema.Many2Many:
 			var (
 				primaryFields, relPrimaryFields []*schema.Field
@@ -285,14 +228,11 @@ func (association *Association) Delete(values ...interface{}) error {
 				}
 			}
-			_, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields)
+			_, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields)
-			if pcolumn, pvalues := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs); len(pvalues) > 0 {
-				conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues})
-			} else {
-				return ErrPrimaryKeyRequired
-			}
+			pcolumn, pvalues := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs)
+			conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues})
-			_, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, relPrimaryFields)
+			_, rvs := schema.GetIdentityFieldValuesMapFromValues(values, relPrimaryFields)
 			relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs)
 			conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
@@ -301,11 +241,11 @@ func (association *Association) Delete(values ...interface{}) error {
 		if association.Error == nil {
 			// clean up deleted values's foreign key
-			relValuesMap, _ := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, rel.FieldSchema.PrimaryFields)
+			relValuesMap, _ := schema.GetIdentityFieldValuesMapFromValues(values, rel.FieldSchema.PrimaryFields)
 			cleanUpDeletedRelations := func(data reflect.Value) {
-				if _, zero := rel.Field.ValueOf(association.DB.Statement.Context, data); !zero {
-					fieldValue := reflect.Indirect(rel.Field.ReflectValueOf(association.DB.Statement.Context, data))
+				if _, zero := rel.Field.ValueOf(data); !zero {
+					fieldValue := reflect.Indirect(rel.Field.ReflectValueOf(data))
 					primaryValues := make([]interface{}, len(rel.FieldSchema.PrimaryFields))
 					switch fieldValue.Kind() {
@@ -313,7 +253,7 @@ func (association *Association) Delete(values ...interface{}) error {
 						validFieldValues := reflect.Zero(rel.Field.IndirectFieldType)
 						for i := 0; i < fieldValue.Len(); i++ {
 							for idx, field := range rel.FieldSchema.PrimaryFields {
-								primaryValues[idx], _ = field.ValueOf(association.DB.Statement.Context, fieldValue.Index(i))
+								primaryValues[idx], _ = field.ValueOf(fieldValue.Index(i))
 							}
 							if _, ok := relValuesMap[utils.ToStringKey(primaryValues...)]; !ok {
@@ -321,23 +261,23 @@ func (association *Association) Delete(values ...interface{}) error {
 							}
 						}
-						association.Error = rel.Field.Set(association.DB.Statement.Context, data, validFieldValues.Interface())
+						association.Error = rel.Field.Set(data, validFieldValues.Interface())
 					case reflect.Struct:
 						for idx, field := range rel.FieldSchema.PrimaryFields {
-							primaryValues[idx], _ = field.ValueOf(association.DB.Statement.Context, fieldValue)
+							primaryValues[idx], _ = field.ValueOf(fieldValue)
 						}
 						if _, ok := relValuesMap[utils.ToStringKey(primaryValues...)]; ok {
-							if association.Error = rel.Field.Set(association.DB.Statement.Context, data, reflect.Zero(rel.FieldSchema.ModelType).Interface()); association.Error != nil {
+							if association.Error = rel.Field.Set(data, reflect.Zero(rel.FieldSchema.ModelType).Interface()); association.Error != nil {
 								break
 							}
 							if rel.JoinTable == nil {
 								for _, ref := range rel.References {
 									if ref.OwnPrimaryKey || ref.PrimaryValue != "" {
-										association.Error = ref.ForeignKey.Set(association.DB.Statement.Context, fieldValue, reflect.Zero(ref.ForeignKey.FieldType).Interface())
+										association.Error = ref.ForeignKey.Set(fieldValue, reflect.Zero(ref.ForeignKey.FieldType).Interface())
 									} else {
-										association.Error = ref.ForeignKey.Set(association.DB.Statement.Context, data, reflect.Zero(ref.ForeignKey.FieldType).Interface())
+										association.Error = ref.ForeignKey.Set(data, reflect.Zero(ref.ForeignKey.FieldType).Interface())
 									}
 								}
 							}
@@ -389,18 +329,14 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
 			switch rv.Kind() {
 			case reflect.Slice, reflect.Array:
 				if rv.Len() > 0 {
-					association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, source, rv.Index(0).Addr().Interface())
+					association.Error = association.Relationship.Field.Set(source, rv.Index(0).Addr().Interface())
 					if association.Relationship.Field.FieldType.Kind() == reflect.Struct {
 						assignBacks = append(assignBacks, assignBack{Source: source, Dest: rv.Index(0)})
 					}
 				}
 			case reflect.Struct:
-				if !rv.CanAddr() {
-					association.Error = ErrInvalidValue
-					return
-				}
-				association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, source, rv.Addr().Interface())
+				association.Error = association.Relationship.Field.Set(source, rv.Addr().Interface())
 				if association.Relationship.Field.FieldType.Kind() == reflect.Struct {
 					assignBacks = append(assignBacks, assignBack{Source: source, Dest: rv})
@@ -408,13 +344,9 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
 			}
 		case schema.HasMany, schema.Many2Many:
 			elemType := association.Relationship.Field.IndirectFieldType.Elem()
-			oldFieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(association.DB.Statement.Context, source))
-			var fieldValue reflect.Value
+			fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(source))
 			if clear {
-				fieldValue = reflect.MakeSlice(oldFieldValue.Type(), 0, oldFieldValue.Cap())
-			} else {
-				fieldValue = reflect.MakeSlice(oldFieldValue.Type(), oldFieldValue.Len(), oldFieldValue.Cap())
-				reflect.Copy(fieldValue, oldFieldValue)
+				fieldValue = reflect.New(association.Relationship.Field.IndirectFieldType).Elem()
 			}
 			appendToFieldValues := func(ev reflect.Value) {
@@ -437,15 +369,11 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
 					appendToFieldValues(reflect.Indirect(rv.Index(i)).Addr())
 				}
 			case reflect.Struct:
-				if !rv.CanAddr() {
-					association.Error = ErrInvalidValue
-					return
-				}
 				appendToFieldValues(rv.Addr())
 			}
 			if association.Error == nil {
-				association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, source, fieldValue.Interface())
+				association.Error = association.Relationship.Field.Set(source, fieldValue.Interface())
 			}
 		}
 	}
@@ -493,7 +421,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
 		// clear old data
 		if clear && len(values) == 0 {
 			for i := 0; i < reflectValue.Len(); i++ {
-				if err := association.Relationship.Field.Set(association.DB.Statement.Context, reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType).Interface()); err != nil {
+				if err := association.Relationship.Field.Set(reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType).Interface()); err != nil {
 					association.Error = err
 					break
 				}
@@ -501,7 +429,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
 				if association.Relationship.JoinTable == nil {
 					for _, ref := range association.Relationship.References {
 						if !ref.OwnPrimaryKey && ref.PrimaryValue == "" {
-							if err := ref.ForeignKey.Set(association.DB.Statement.Context, reflectValue.Index(i), reflect.Zero(ref.ForeignKey.FieldType).Interface()); err != nil {
+							if err := ref.ForeignKey.Set(reflectValue.Index(i), reflect.Zero(ref.ForeignKey.FieldType).Interface()); err != nil {
 								association.Error = err
 								break
 							}
@@ -518,9 +446,6 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
 		for i := 0; i < reflectValue.Len(); i++ {
 			appendToRelations(reflectValue.Index(i), reflect.Indirect(reflect.ValueOf(values[i])), clear)
-			if association.Error != nil {
-				return
-			}
 			// TODO support save slice data, sql with case?
 			association.Error = associationDB.Updates(reflectValue.Index(i).Addr().Interface()).Error
@@ -528,12 +453,12 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
 	case reflect.Struct:
 		// clear old data
 		if clear && len(values) == 0 {
-			association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType).Interface())
+			association.Error = association.Relationship.Field.Set(reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType).Interface())
 			if association.Relationship.JoinTable == nil && association.Error == nil {
 				for _, ref := range association.Relationship.References {
 					if !ref.OwnPrimaryKey && ref.PrimaryValue == "" {
-						association.Error = ref.ForeignKey.Set(association.DB.Statement.Context, reflectValue, reflect.Zero(ref.ForeignKey.FieldType).Interface())
+						association.Error = ref.ForeignKey.Set(reflectValue, reflect.Zero(ref.ForeignKey.FieldType).Interface())
 					}
 				}
 			}
@@ -542,9 +467,6 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
 		for idx, value := range values {
 			rv := reflect.Indirect(reflect.ValueOf(value))
 			appendToRelations(reflectValue, rv, clear && idx == 0)
-			if association.Error != nil {
-				return
-			}
 		}
 		if len(values) > 0 {
@@ -553,7 +475,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
 	}
 	for _, assignBack := range assignBacks {
-		fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(association.DB.Statement.Context, assignBack.Source))
+		fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(assignBack.Source))
 		if assignBack.Index > 0 {
 			reflect.Indirect(assignBack.Dest).Set(fieldValue.Index(assignBack.Index - 1))
 		} else {
@@ -564,21 +486,19 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
 func (association *Association) buildCondition() *DB {
 	var (
-		queryConds = association.Relationship.ToQueryConditions(association.DB.Statement.Context, association.DB.Statement.ReflectValue)
+		queryConds = association.Relationship.ToQueryConditions(association.DB.Statement.ReflectValue)
 		modelValue = reflect.New(association.Relationship.FieldSchema.ModelType).Interface()
 		tx         = association.DB.Model(modelValue)
 	)
 	if association.Relationship.JoinTable != nil {
 		if !tx.Statement.Unscoped && len(association.Relationship.JoinTable.QueryClauses) > 0 {
-			joinStmt := Statement{DB: tx, Context: tx.Statement.Context, Schema: association.Relationship.JoinTable, Table: association.Relationship.JoinTable.Table, Clauses: map[string]clause.Clause{}}
+			joinStmt := Statement{DB: tx, Schema: association.Relationship.JoinTable, Table: association.Relationship.JoinTable.Table, Clauses: map[string]clause.Clause{}}
 			for _, queryClause := range association.Relationship.JoinTable.QueryClauses {
 				joinStmt.AddClause(queryClause)
 			}
 			joinStmt.Build("WHERE")
-			if len(joinStmt.SQL.String()) > 0 {
-				tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars})
-			}
+			tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars})
 		}
 		tx = tx.Session(&Session{QueryFields: true}).Clauses(clause.From{Joins: []clause.Join{{


@@ -75,7 +75,11 @@ func (cs *callbacks) Raw() *processor {
 func (p *processor) Execute(db *DB) *DB {
 	// call scopes
 	for len(db.Statement.scopes) > 0 {
-		db = db.executeScopes()
+		scopes := db.Statement.scopes
+		db.Statement.scopes = nil
+		for _, scope := range scopes {
+			db = scope(db)
+		}
 	}
 	var (
@@ -89,10 +93,6 @@ func (p *processor) Execute(db *DB) *DB {
 		resetBuildClauses = true
 	}
-	if optimizer, ok := db.Statement.Dest.(StatementModifier); ok {
-		optimizer.ModifyStatement(stmt)
-	}
 	// assign model values
 	if stmt.Model == nil {
 		stmt.Model = stmt.Dest
@@ -130,15 +130,9 @@ func (p *processor) Execute(db *DB) *DB {
 		f(db)
 	}
-	if stmt.SQL.Len() > 0 {
-		db.Logger.Trace(stmt.Context, curTime, func() (string, int64) {
-			sql, vars := stmt.SQL.String(), stmt.Vars
-			if filter, ok := db.Logger.(ParamsFilter); ok {
-				sql, vars = filter.ParamsFilter(stmt.Context, stmt.SQL.String(), stmt.Vars...)
-			}
-			return db.Dialector.Explain(sql, vars...), db.RowsAffected
-		}, db.Error)
-	}
+	db.Logger.Trace(stmt.Context, curTime, func() (string, int64) {
+		return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...), db.RowsAffected
+	}, db.Error)
 	if !stmt.DB.DryRun {
 		stmt.SQL.Reset()
@@ -187,18 +181,10 @@ func (p *processor) Replace(name string, fn func(*DB)) error {
 func (p *processor) compile() (err error) {
 	var callbacks []*callback
-	removedMap := map[string]bool{}
 	for _, callback := range p.callbacks {
 		if callback.match == nil || callback.match(p.db) {
 			callbacks = append(callbacks, callback)
 		}
-		if callback.remove {
-			removedMap[callback.name] = true
-		}
-	}
-	if len(removedMap) > 0 {
-		callbacks = removeCallbacks(callbacks, removedMap)
 	}
 	p.callbacks = callbacks
@@ -257,14 +243,8 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {
 		names, sorted []string
 		sortCallback  func(*callback) error
 	)
-	sort.SliceStable(cs, func(i, j int) bool {
-		if cs[j].before == "*" && cs[i].before != "*" {
-			return true
-		}
-		if cs[j].after == "*" && cs[i].after != "*" {
-			return true
-		}
-		return false
-	})
+	sort.Slice(cs, func(i, j int) bool {
+		return cs[j].before == "*" || cs[j].after == "*"
+	})
 	for _, c := range cs {
@@ -347,14 +327,3 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {
 	return
 }
-func removeCallbacks(cs []*callback, nameMap map[string]bool) []*callback {
-	callbacks := make([]*callback, 0, len(cs))
-	for _, callback := range cs {
-		if nameMap[callback.name] {
-			continue
-		}
-		callbacks = append(callbacks, callback)
-	}
-	return callbacks
-}
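For context on the `compile()` change above: on the master side a callback flagged with `remove` is collected into `removedMap` and filtered out by `removeCallbacks`, which is the bookkeeping behind the public `Register`/`Remove` callback API present in both versions. A small sketch of that API, assuming a hypothetical in-memory sqlite handle and a hypothetical callback name:

```go
package main

import (
	"gorm.io/driver/sqlite"
	"gorm.io/gorm"
)

func main() {
	db, err := gorm.Open(sqlite.Open("file::memory:"), &gorm.Config{})
	if err != nil {
		panic(err)
	}

	// Register a custom callback before GORM's built-in create callback.
	// processor.compile() is what folds registered callbacks into the chain.
	logCreate := func(tx *gorm.DB) {
		if tx.Statement.Schema != nil {
			tx.Logger.Info(tx.Statement.Context, "creating %s", tx.Statement.Schema.Table)
		}
	}
	if err := db.Callback().Create().Before("gorm:create").Register("example:log_create", logCreate); err != nil {
		panic(err)
	}

	// Unregister it by name. On the master side of the diff, Remove marks the
	// callback and compile() drops it via removedMap/removeCallbacks.
	if err := db.Callback().Create().Remove("example:log_create"); err != nil {
		panic(err)
	}
}
```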


@@ -7,7 +7,6 @@ import (
 	"gorm.io/gorm"
 	"gorm.io/gorm/clause"
 	"gorm.io/gorm/schema"
-	"gorm.io/gorm/utils"
 )
 func SaveBeforeAssociations(create bool) func(db *gorm.DB) {
@@ -24,8 +23,8 @@ func SaveBeforeAssociations(create bool) func(db *gorm.DB) {
 		setupReferences := func(obj reflect.Value, elem reflect.Value) {
 			for _, ref := range rel.References {
 				if !ref.OwnPrimaryKey {
-					pv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, elem)
-					db.AddError(ref.ForeignKey.Set(db.Statement.Context, obj, pv))
+					pv, _ := ref.PrimaryKey.ValueOf(elem)
+					db.AddError(ref.ForeignKey.Set(obj, pv))
 					if dest, ok := db.Statement.Dest.(map[string]interface{}); ok {
 						dest[ref.ForeignKey.DBName] = pv
@@ -40,64 +39,49 @@ func SaveBeforeAssociations(create bool) func(db *gorm.DB) {
 		switch db.Statement.ReflectValue.Kind() {
 		case reflect.Slice, reflect.Array:
 			var (
-				rValLen   = db.Statement.ReflectValue.Len()
-				objs      = make([]reflect.Value, 0, rValLen)
+				objs      = make([]reflect.Value, 0, db.Statement.ReflectValue.Len())
 				fieldType = rel.Field.FieldType
 				isPtr     = fieldType.Kind() == reflect.Ptr
 			)
 			if !isPtr {
-				fieldType = reflect.PointerTo(fieldType)
+				fieldType = reflect.PtrTo(fieldType)
 			}
 			elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
-			distinctElems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
-			identityMap := map[string]bool{}
-			for i := 0; i < rValLen; i++ {
+			for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
 				obj := db.Statement.ReflectValue.Index(i)
-				if reflect.Indirect(obj).Kind() != reflect.Struct {
-					break
-				}
-				if _, zero := rel.Field.ValueOf(db.Statement.Context, obj); !zero { // check belongs to relation value
-					rv := rel.Field.ReflectValueOf(db.Statement.Context, obj) // relation reflect value
-					if !isPtr {
-						rv = rv.Addr()
-					}
-					objs = append(objs, obj)
-					elems = reflect.Append(elems, rv)
-					relPrimaryValues := make([]interface{}, 0, len(rel.FieldSchema.PrimaryFields))
-					for _, pf := range rel.FieldSchema.PrimaryFields {
-						if pfv, ok := pf.ValueOf(db.Statement.Context, rv); !ok {
-							relPrimaryValues = append(relPrimaryValues, pfv)
-						}
-					}
-					cacheKey := utils.ToStringKey(relPrimaryValues...)
-					if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] {
-						if cacheKey != "" { // has primary fields
-							identityMap[cacheKey] = true
-						}
-						distinctElems = reflect.Append(distinctElems, rv)
-					}
+				if reflect.Indirect(obj).Kind() == reflect.Struct {
+					if _, zero := rel.Field.ValueOf(obj); !zero { // check belongs to relation value
+						rv := rel.Field.ReflectValueOf(obj) // relation reflect value
+						objs = append(objs, obj)
+						if isPtr {
+							elems = reflect.Append(elems, rv)
+						} else {
+							elems = reflect.Append(elems, rv.Addr())
+						}
+					}
+				} else {
+					break
 				}
 			}
 			if elems.Len() > 0 {
-				if saveAssociations(db, rel, distinctElems, selectColumns, restricted, nil) == nil {
+				if saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, nil) == nil {
 					for i := 0; i < elems.Len(); i++ {
 						setupReferences(objs[i], elems.Index(i))
 					}
 				}
 			}
 		case reflect.Struct:
-			if _, zero := rel.Field.ValueOf(db.Statement.Context, db.Statement.ReflectValue); !zero {
-				rv := rel.Field.ReflectValueOf(db.Statement.Context, db.Statement.ReflectValue) // relation reflect value
+			if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero {
+				rv := rel.Field.ReflectValueOf(db.Statement.ReflectValue) // relation reflect value
 				if rv.Kind() != reflect.Ptr {
 					rv = rv.Addr()
 				}
-				if saveAssociations(db, rel, rv, selectColumns, restricted, nil) == nil {
+				if saveAssociations(db, rel, rv.Interface(), selectColumns, restricted, nil) == nil {
 					setupReferences(db.Statement.ReflectValue, rv)
 				}
 			}
@@ -126,7 +110,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
 			)
 			if !isPtr {
-				fieldType = reflect.PointerTo(fieldType)
+				fieldType = reflect.PtrTo(fieldType)
 			}
 			elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
@@ -135,18 +119,18 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
 				obj := db.Statement.ReflectValue.Index(i)
 				if reflect.Indirect(obj).Kind() == reflect.Struct {
-					if _, zero := rel.Field.ValueOf(db.Statement.Context, obj); !zero {
+					if _, zero := rel.Field.ValueOf(obj); !zero {
-						rv := rel.Field.ReflectValueOf(db.Statement.Context, obj)
+						rv := rel.Field.ReflectValueOf(obj)
 						if rv.Kind() != reflect.Ptr {
 							rv = rv.Addr()
 						}
 						for _, ref := range rel.References {
 							if ref.OwnPrimaryKey {
-								fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, obj)
-								db.AddError(ref.ForeignKey.Set(db.Statement.Context, rv, fv))
+								fv, _ := ref.PrimaryKey.ValueOf(obj)
+								db.AddError(ref.ForeignKey.Set(rv, fv))
 							} else if ref.PrimaryValue != "" {
-								db.AddError(ref.ForeignKey.Set(db.Statement.Context, rv, ref.PrimaryValue))
+								db.AddError(ref.ForeignKey.Set(rv, ref.PrimaryValue))
 							}
 						}
@@ -161,11 +145,11 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
 					assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
 				}
-				saveAssociations(db, rel, elems, selectColumns, restricted, assignmentColumns)
+				saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, assignmentColumns)
 			}
 		case reflect.Struct:
-			if _, zero := rel.Field.ValueOf(db.Statement.Context, db.Statement.ReflectValue); !zero {
-				f := rel.Field.ReflectValueOf(db.Statement.Context, db.Statement.ReflectValue)
+			if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero {
+				f := rel.Field.ReflectValueOf(db.Statement.ReflectValue)
if f.Kind() != reflect.Ptr { if f.Kind() != reflect.Ptr {
f = f.Addr() f = f.Addr()
} }
@ -173,15 +157,15 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
assignmentColumns := make([]string, 0, len(rel.References)) assignmentColumns := make([]string, 0, len(rel.References))
for _, ref := range rel.References { for _, ref := range rel.References {
if ref.OwnPrimaryKey { if ref.OwnPrimaryKey {
fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, db.Statement.ReflectValue) fv, _ := ref.PrimaryKey.ValueOf(db.Statement.ReflectValue)
db.AddError(ref.ForeignKey.Set(db.Statement.Context, f, fv)) ref.ForeignKey.Set(f, fv)
} else if ref.PrimaryValue != "" { } else if ref.PrimaryValue != "" {
db.AddError(ref.ForeignKey.Set(db.Statement.Context, f, ref.PrimaryValue)) ref.ForeignKey.Set(f, ref.PrimaryValue)
} }
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
} }
saveAssociations(db, rel, f, selectColumns, restricted, assignmentColumns) saveAssociations(db, rel, f.Interface(), selectColumns, restricted, assignmentColumns)
} }
} }
} }
@ -195,43 +179,28 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
fieldType := rel.Field.IndirectFieldType.Elem() fieldType := rel.Field.IndirectFieldType.Elem()
isPtr := fieldType.Kind() == reflect.Ptr isPtr := fieldType.Kind() == reflect.Ptr
if !isPtr { if !isPtr {
fieldType = reflect.PointerTo(fieldType) fieldType = reflect.PtrTo(fieldType)
} }
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
identityMap := map[string]bool{}
appendToElems := func(v reflect.Value) { appendToElems := func(v reflect.Value) {
if _, zero := rel.Field.ValueOf(db.Statement.Context, v); !zero { if _, zero := rel.Field.ValueOf(v); !zero {
f := reflect.Indirect(rel.Field.ReflectValueOf(db.Statement.Context, v)) f := reflect.Indirect(rel.Field.ReflectValueOf(v))
for i := 0; i < f.Len(); i++ { for i := 0; i < f.Len(); i++ {
elem := f.Index(i) elem := f.Index(i)
for _, ref := range rel.References { for _, ref := range rel.References {
if ref.OwnPrimaryKey { if ref.OwnPrimaryKey {
pv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, v) pv, _ := ref.PrimaryKey.ValueOf(v)
db.AddError(ref.ForeignKey.Set(db.Statement.Context, elem, pv)) ref.ForeignKey.Set(elem, pv)
} else if ref.PrimaryValue != "" { } else if ref.PrimaryValue != "" {
db.AddError(ref.ForeignKey.Set(db.Statement.Context, elem, ref.PrimaryValue)) ref.ForeignKey.Set(elem, ref.PrimaryValue)
} }
} }
relPrimaryValues := make([]interface{}, 0, len(rel.FieldSchema.PrimaryFields)) if isPtr {
for _, pf := range rel.FieldSchema.PrimaryFields { elems = reflect.Append(elems, elem)
if pfv, ok := pf.ValueOf(db.Statement.Context, elem); !ok { } else {
relPrimaryValues = append(relPrimaryValues, pfv) elems = reflect.Append(elems, elem.Addr())
}
}
cacheKey := utils.ToStringKey(relPrimaryValues...)
if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] {
if cacheKey != "" { // has primary fields
identityMap[cacheKey] = true
}
if isPtr {
elems = reflect.Append(elems, elem)
} else {
elems = reflect.Append(elems, elem.Addr())
}
} }
} }
} }
@ -255,7 +224,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
} }
saveAssociations(db, rel, elems, selectColumns, restricted, assignmentColumns) saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, assignmentColumns)
} }
} }
@ -268,57 +237,41 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
fieldType := rel.Field.IndirectFieldType.Elem() fieldType := rel.Field.IndirectFieldType.Elem()
isPtr := fieldType.Kind() == reflect.Ptr isPtr := fieldType.Kind() == reflect.Ptr
if !isPtr { if !isPtr {
fieldType = reflect.PointerTo(fieldType) fieldType = reflect.PtrTo(fieldType)
} }
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
distinctElems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) joins := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(rel.JoinTable.ModelType)), 0, 10)
joins := reflect.MakeSlice(reflect.SliceOf(reflect.PointerTo(rel.JoinTable.ModelType)), 0, 10)
objs := []reflect.Value{} objs := []reflect.Value{}
appendToJoins := func(obj reflect.Value, elem reflect.Value) { appendToJoins := func(obj reflect.Value, elem reflect.Value) {
joinValue := reflect.New(rel.JoinTable.ModelType) joinValue := reflect.New(rel.JoinTable.ModelType)
for _, ref := range rel.References { for _, ref := range rel.References {
if ref.OwnPrimaryKey { if ref.OwnPrimaryKey {
fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, obj) fv, _ := ref.PrimaryKey.ValueOf(obj)
db.AddError(ref.ForeignKey.Set(db.Statement.Context, joinValue, fv)) ref.ForeignKey.Set(joinValue, fv)
} else if ref.PrimaryValue != "" { } else if ref.PrimaryValue != "" {
db.AddError(ref.ForeignKey.Set(db.Statement.Context, joinValue, ref.PrimaryValue)) ref.ForeignKey.Set(joinValue, ref.PrimaryValue)
} else { } else {
fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, elem) fv, _ := ref.PrimaryKey.ValueOf(elem)
db.AddError(ref.ForeignKey.Set(db.Statement.Context, joinValue, fv)) ref.ForeignKey.Set(joinValue, fv)
} }
} }
joins = reflect.Append(joins, joinValue) joins = reflect.Append(joins, joinValue)
} }
identityMap := map[string]bool{}
appendToElems := func(v reflect.Value) { appendToElems := func(v reflect.Value) {
if _, zero := rel.Field.ValueOf(db.Statement.Context, v); !zero { if _, zero := rel.Field.ValueOf(v); !zero {
f := reflect.Indirect(rel.Field.ReflectValueOf(db.Statement.Context, v)) f := reflect.Indirect(rel.Field.ReflectValueOf(v))
for i := 0; i < f.Len(); i++ { for i := 0; i < f.Len(); i++ {
elem := f.Index(i) elem := f.Index(i)
if !isPtr {
elem = elem.Addr()
}
objs = append(objs, v) objs = append(objs, v)
elems = reflect.Append(elems, elem) if isPtr {
elems = reflect.Append(elems, elem)
relPrimaryValues := make([]interface{}, 0, len(rel.FieldSchema.PrimaryFields)) } else {
for _, pf := range rel.FieldSchema.PrimaryFields { elems = reflect.Append(elems, elem.Addr())
if pfv, ok := pf.ValueOf(db.Statement.Context, elem); !ok {
relPrimaryValues = append(relPrimaryValues, pfv)
}
} }
cacheKey := utils.ToStringKey(relPrimaryValues...)
if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] {
if cacheKey != "" { // has primary fields
identityMap[cacheKey] = true
}
distinctElems = reflect.Append(distinctElems, elem)
}
} }
} }
} }
@ -338,7 +291,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
// optimize elems of reflect value length // optimize elems of reflect value length
if elemLen := elems.Len(); elemLen > 0 { if elemLen := elems.Len(); elemLen > 0 {
if v, ok := selectColumns[rel.Name+".*"]; !ok || v { if v, ok := selectColumns[rel.Name+".*"]; !ok || v {
saveAssociations(db, rel, distinctElems, selectColumns, restricted, nil) saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, nil)
} }
for i := 0; i < elemLen; i++ { for i := 0; i < elemLen; i++ {
@ -357,7 +310,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
} }
} }
func onConflictOption(stmt *gorm.Statement, s *schema.Schema, defaultUpdatingColumns []string) (onConflict clause.OnConflict) { func onConflictOption(stmt *gorm.Statement, s *schema.Schema, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) (onConflict clause.OnConflict) {
if len(defaultUpdatingColumns) > 0 || stmt.DB.FullSaveAssociations { if len(defaultUpdatingColumns) > 0 || stmt.DB.FullSaveAssociations {
onConflict.Columns = make([]clause.Column, 0, len(s.PrimaryFieldDBNames)) onConflict.Columns = make([]clause.Column, 0, len(s.PrimaryFieldDBNames))
for _, dbName := range s.PrimaryFieldDBNames { for _, dbName := range s.PrimaryFieldDBNames {
@ -375,17 +328,11 @@ func onConflictOption(stmt *gorm.Statement, s *schema.Schema, defaultUpdatingCol
return return
} }
func saveAssociations(db *gorm.DB, rel *schema.Relationship, rValues reflect.Value, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) error { func saveAssociations(db *gorm.DB, rel *schema.Relationship, values interface{}, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) error {
// stop save association loop
if checkAssociationsSaved(db, rValues) {
return nil
}
var ( var (
selects, omits []string selects, omits []string
onConflict = onConflictOption(db.Statement, rel.FieldSchema, defaultUpdatingColumns) onConflict = onConflictOption(db.Statement, rel.FieldSchema, selectColumns, restricted, defaultUpdatingColumns)
refName = rel.Name + "." refName = rel.Name + "."
values = rValues.Interface()
) )
for name, ok := range selectColumns { for name, ok := range selectColumns {
@ -430,24 +377,3 @@ func saveAssociations(db *gorm.DB, rel *schema.Relationship, rValues reflect.Val
return db.AddError(tx.Create(values).Error) return db.AddError(tx.Create(values).Error)
} }
// check association values has been saved
// if values kind is Struct, check it has been saved
// if values kind is Slice/Array, check all items have been saved
var visitMapStoreKey = "gorm:saved_association_map"
func checkAssociationsSaved(db *gorm.DB, values reflect.Value) bool {
if visit, ok := db.Get(visitMapStoreKey); ok {
if v, ok := visit.(*visitMap); ok {
if loadOrStoreVisitMap(v, values) {
return true
}
}
} else {
vistMap := make(visitMap)
loadOrStoreVisitMap(&vistMap, values)
db.Set(visitMapStoreKey, &vistMap)
}
return false
}
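
One thread running through the master side of this file is deduplication: association elements are keyed on their primary-key values via utils.ToStringKey and collected into distinctElems so the same associated row is only saved once. A rough, self-contained sketch of that idea, using a local stand-in for the key helper and a made-up pet type:

package main

import (
	"fmt"
	"strings"
)

// toStringKey is a local stand-in for utils.ToStringKey: it folds the
// primary-key values of an associated record into a single map key.
func toStringKey(values ...interface{}) string {
	parts := make([]string, len(values))
	for i, v := range values {
		parts[i] = fmt.Sprint(v)
	}
	return strings.Join(parts, "_")
}

type pet struct {
	ID   uint
	Name string
}

func main() {
	// Two parents referencing the same pet (ID 1): only one distinct element
	// should be handed on for saving.
	pets := []*pet{{ID: 1, Name: "dog"}, {ID: 1, Name: "dog"}, {ID: 2, Name: "cat"}}

	identityMap := map[string]bool{}
	distinct := make([]*pet, 0, len(pets))
	for _, p := range pets {
		key := toStringKey(p.ID)
		if p.ID != 0 && identityMap[key] {
			continue // primary key already seen, skip the duplicate
		}
		if p.ID != 0 { // has primary fields
			identityMap[key] = true
		}
		distinct = append(distinct, p)
	}
	fmt.Println(len(distinct)) // 2
}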

View File

@ -13,6 +13,7 @@ var (
type Config struct { type Config struct {
LastInsertIDReversed bool LastInsertIDReversed bool
WithReturning bool
CreateClauses []string CreateClauses []string
QueryClauses []string QueryClauses []string
UpdateClauses []string UpdateClauses []string
@ -24,19 +25,6 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) {
return !db.SkipDefaultTransaction return !db.SkipDefaultTransaction
} }
if len(config.CreateClauses) == 0 {
config.CreateClauses = createClauses
}
if len(config.QueryClauses) == 0 {
config.QueryClauses = queryClauses
}
if len(config.DeleteClauses) == 0 {
config.DeleteClauses = deleteClauses
}
if len(config.UpdateClauses) == 0 {
config.UpdateClauses = updateClauses
}
createCallback := db.Callback().Create() createCallback := db.Callback().Create()
createCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction) createCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction)
createCallback.Register("gorm:before_create", BeforeCreate) createCallback.Register("gorm:before_create", BeforeCreate)
@ -45,21 +33,30 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) {
createCallback.Register("gorm:save_after_associations", SaveAfterAssociations(true)) createCallback.Register("gorm:save_after_associations", SaveAfterAssociations(true))
createCallback.Register("gorm:after_create", AfterCreate) createCallback.Register("gorm:after_create", AfterCreate)
createCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) createCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
if len(config.CreateClauses) == 0 {
config.CreateClauses = createClauses
}
createCallback.Clauses = config.CreateClauses createCallback.Clauses = config.CreateClauses
queryCallback := db.Callback().Query() queryCallback := db.Callback().Query()
queryCallback.Register("gorm:query", Query) queryCallback.Register("gorm:query", Query)
queryCallback.Register("gorm:preload", Preload) queryCallback.Register("gorm:preload", Preload)
queryCallback.Register("gorm:after_query", AfterQuery) queryCallback.Register("gorm:after_query", AfterQuery)
if len(config.QueryClauses) == 0 {
config.QueryClauses = queryClauses
}
queryCallback.Clauses = config.QueryClauses queryCallback.Clauses = config.QueryClauses
deleteCallback := db.Callback().Delete() deleteCallback := db.Callback().Delete()
deleteCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction) deleteCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction)
deleteCallback.Register("gorm:before_delete", BeforeDelete) deleteCallback.Register("gorm:before_delete", BeforeDelete)
deleteCallback.Register("gorm:delete_before_associations", DeleteBeforeAssociations) deleteCallback.Register("gorm:delete_before_associations", DeleteBeforeAssociations)
deleteCallback.Register("gorm:delete", Delete(config)) deleteCallback.Register("gorm:delete", Delete)
deleteCallback.Register("gorm:after_delete", AfterDelete) deleteCallback.Register("gorm:after_delete", AfterDelete)
deleteCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) deleteCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
if len(config.DeleteClauses) == 0 {
config.DeleteClauses = deleteClauses
}
deleteCallback.Clauses = config.DeleteClauses deleteCallback.Clauses = config.DeleteClauses
updateCallback := db.Callback().Update() updateCallback := db.Callback().Update()
@ -67,10 +64,13 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) {
updateCallback.Register("gorm:setup_reflect_value", SetupUpdateReflectValue) updateCallback.Register("gorm:setup_reflect_value", SetupUpdateReflectValue)
updateCallback.Register("gorm:before_update", BeforeUpdate) updateCallback.Register("gorm:before_update", BeforeUpdate)
updateCallback.Register("gorm:save_before_associations", SaveBeforeAssociations(false)) updateCallback.Register("gorm:save_before_associations", SaveBeforeAssociations(false))
updateCallback.Register("gorm:update", Update(config)) updateCallback.Register("gorm:update", Update)
updateCallback.Register("gorm:save_after_associations", SaveAfterAssociations(false)) updateCallback.Register("gorm:save_after_associations", SaveAfterAssociations(false))
updateCallback.Register("gorm:after_update", AfterUpdate) updateCallback.Register("gorm:after_update", AfterUpdate)
updateCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) updateCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
if len(config.UpdateClauses) == 0 {
config.UpdateClauses = updateClauses
}
updateCallback.Clauses = config.UpdateClauses updateCallback.Clauses = config.UpdateClauses
rowCallback := db.Callback().Row() rowCallback := db.Callback().Row()
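
For readers unfamiliar with the processors being wired up here, a short sketch of how a plugin typically hooks into the same pipeline; the callback names used below are hypothetical, only the registration pattern is taken from GORM's public callback API.

package callbacksdemo

import "gorm.io/gorm"

// registerAuditCallbacks registers two hypothetical callbacks around the
// defaults set up above: one running before gorm:create, one after gorm:delete.
func registerAuditCallbacks(db *gorm.DB) error {
	if err := db.Callback().Create().Before("gorm:create").
		Register("audit:before_create", func(tx *gorm.DB) {
			// stamp audit columns, emit metrics, etc.
		}); err != nil {
		return err
	}
	return db.Callback().Delete().After("gorm:delete").
		Register("audit:after_delete", func(tx *gorm.DB) {
			// record what was removed
		})
}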

View File

@ -13,20 +13,11 @@ func callMethod(db *gorm.DB, fc func(value interface{}, tx *gorm.DB) bool) {
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
db.Statement.CurDestIndex = 0 db.Statement.CurDestIndex = 0
for i := 0; i < db.Statement.ReflectValue.Len(); i++ { for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
if value := reflect.Indirect(db.Statement.ReflectValue.Index(i)); value.CanAddr() { fc(reflect.Indirect(db.Statement.ReflectValue.Index(i)).Addr().Interface(), tx)
fc(value.Addr().Interface(), tx)
} else {
db.AddError(gorm.ErrInvalidValue)
return
}
db.Statement.CurDestIndex++ db.Statement.CurDestIndex++
} }
case reflect.Struct: case reflect.Struct:
if db.Statement.ReflectValue.CanAddr() { fc(db.Statement.ReflectValue.Addr().Interface(), tx)
fc(db.Statement.ReflectValue.Addr().Interface(), tx)
} else {
db.AddError(gorm.ErrInvalidValue)
}
} }
} }
} }
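
The CanAddr guards added on the master side exist because reflect values reached by value copy cannot be addressed; calling Addr() on them panics. A tiny sketch of that distinction, independent of GORM (the user type is made up):

package main

import (
	"fmt"
	"reflect"
)

type user struct{ Name string }

func main() {
	// A value obtained from a plain copy is not addressable: v.Addr() would
	// panic, which is why the callback now checks CanAddr and reports
	// gorm.ErrInvalidValue instead.
	v := reflect.ValueOf(user{Name: "jinzhu"})
	fmt.Println(v.CanAddr()) // false

	// Reaching the struct through a pointer makes it addressable.
	u := user{Name: "jinzhu"}
	pv := reflect.ValueOf(&u).Elem()
	fmt.Println(pv.CanAddr())                       // true
	fmt.Println(pv.Addr().Interface().(*user).Name) // jinzhu
}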

View File

@ -3,15 +3,12 @@ package callbacks
import ( import (
"fmt" "fmt"
"reflect" "reflect"
"strings"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
"gorm.io/gorm/schema" "gorm.io/gorm/schema"
"gorm.io/gorm/utils"
) )
// BeforeCreate before create hooks
func BeforeCreate(db *gorm.DB) { func BeforeCreate(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) { if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) {
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
@ -33,38 +30,23 @@ func BeforeCreate(db *gorm.DB) {
} }
} }
// Create create hook
func Create(config *Config) func(db *gorm.DB) { func Create(config *Config) func(db *gorm.DB) {
supportReturning := utils.Contains(config.CreateClauses, "RETURNING") if config.WithReturning {
return CreateWithReturning
}
return func(db *gorm.DB) { return func(db *gorm.DB) {
if db.Error != nil { if db.Error != nil {
return return
} }
if db.Statement.Schema != nil { if db.Statement.Schema != nil && !db.Statement.Unscoped {
if !db.Statement.Unscoped { for _, c := range db.Statement.Schema.CreateClauses {
for _, c := range db.Statement.Schema.CreateClauses { db.Statement.AddClause(c)
db.Statement.AddClause(c)
}
}
if supportReturning && len(db.Statement.Schema.FieldsWithDefaultDBValue) > 0 {
if _, ok := db.Statement.Clauses["RETURNING"]; !ok {
fromColumns := make([]clause.Column, 0, len(db.Statement.Schema.FieldsWithDefaultDBValue))
for _, field := range db.Statement.Schema.FieldsWithDefaultDBValue {
if field.Readable {
fromColumns = append(fromColumns, clause.Column{Name: field.DBName})
}
}
if len(fromColumns) > 0 {
db.Statement.AddClause(clause.Returning{Columns: fromColumns})
}
}
} }
} }
if db.Statement.SQL.Len() == 0 { if db.Statement.SQL.String() == "" {
db.Statement.SQL.Grow(180) db.Statement.SQL.Grow(180)
db.Statement.AddClauseIfNotExists(clause.Insert{}) db.Statement.AddClauseIfNotExists(clause.Insert{})
db.Statement.AddClause(ConvertToCreateValues(db.Statement)) db.Statement.AddClause(ConvertToCreateValues(db.Statement))
@ -72,167 +54,172 @@ func Create(config *Config) func(db *gorm.DB) {
db.Statement.Build(db.Statement.BuildClauses...) db.Statement.Build(db.Statement.BuildClauses...)
} }
isDryRun := !db.DryRun && db.Error == nil if !db.DryRun && db.Error == nil {
if !isDryRun { result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
return
}
ok, mode := hasReturning(db, supportReturning) if err != nil {
if ok {
if c, ok := db.Statement.Clauses["ON CONFLICT"]; ok {
if onConflict, _ := c.Expression.(clause.OnConflict); onConflict.DoNothing {
mode |= gorm.ScanOnConflictDoNothing
}
}
rows, err := db.Statement.ConnPool.QueryContext(
db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...,
)
if db.AddError(err) == nil {
defer func() {
db.AddError(rows.Close())
}()
gorm.Scan(rows, db, mode)
if db.Statement.Result != nil {
db.Statement.Result.RowsAffected = db.RowsAffected
}
}
return
}
result, err := db.Statement.ConnPool.ExecContext(
db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...,
)
if err != nil {
db.AddError(err)
return
}
db.RowsAffected, _ = result.RowsAffected()
if db.Statement.Result != nil {
db.Statement.Result.Result = result
db.Statement.Result.RowsAffected = db.RowsAffected
}
if db.RowsAffected == 0 {
return
}
var (
pkField *schema.Field
pkFieldName = "@id"
)
if db.Statement.Schema != nil {
if db.Statement.Schema.PrioritizedPrimaryField == nil ||
!db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue ||
!db.Statement.Schema.PrioritizedPrimaryField.Readable {
return
}
pkField = db.Statement.Schema.PrioritizedPrimaryField
pkFieldName = db.Statement.Schema.PrioritizedPrimaryField.DBName
}
insertID, err := result.LastInsertId()
insertOk := err == nil && insertID > 0
if !insertOk {
if !supportReturning {
db.AddError(err) db.AddError(err)
}
return
}
// append @id column with value for auto-increment primary key
// the @id value is correct, when: 1. without setting auto-increment primary key, 2. database AutoIncrementIncrement = 1
switch values := db.Statement.Dest.(type) {
case map[string]interface{}:
values[pkFieldName] = insertID
case *map[string]interface{}:
(*values)[pkFieldName] = insertID
case []map[string]interface{}, *[]map[string]interface{}:
mapValues, ok := values.([]map[string]interface{})
if !ok {
if v, ok := values.(*[]map[string]interface{}); ok {
if *v != nil {
mapValues = *v
}
}
}
if config.LastInsertIDReversed {
insertID -= int64(len(mapValues)-1) * schema.DefaultAutoIncrementIncrement
}
for _, mapValue := range mapValues {
if mapValue != nil {
mapValue[pkFieldName] = insertID
}
insertID += schema.DefaultAutoIncrementIncrement
}
default:
if pkField == nil {
return return
} }
switch db.Statement.ReflectValue.Kind() { db.RowsAffected, _ = result.RowsAffected()
case reflect.Slice, reflect.Array:
if config.LastInsertIDReversed {
for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- {
rv := db.Statement.ReflectValue.Index(i)
if reflect.Indirect(rv).Kind() != reflect.Struct {
break
}
_, isZero := pkField.ValueOf(db.Statement.Context, rv) if db.RowsAffected != 0 && db.Statement.Schema != nil &&
if isZero { db.Statement.Schema.PrioritizedPrimaryField != nil && db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue {
db.AddError(pkField.Set(db.Statement.Context, rv, insertID)) if insertID, err := result.LastInsertId(); err == nil && insertID > 0 {
insertID -= pkField.AutoIncrementIncrement switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
if config.LastInsertIDReversed {
for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- {
rv := db.Statement.ReflectValue.Index(i)
if reflect.Indirect(rv).Kind() != reflect.Struct {
break
}
_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv)
if isZero {
db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID)
insertID -= db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement
}
}
} else {
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
rv := db.Statement.ReflectValue.Index(i)
if reflect.Indirect(rv).Kind() != reflect.Struct {
break
}
if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv); isZero {
db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID)
insertID += db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement
}
}
}
case reflect.Struct:
if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue); isZero {
db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID)
} }
} }
} else { } else {
for i := 0; i < db.Statement.ReflectValue.Len(); i++ { db.AddError(err)
rv := db.Statement.ReflectValue.Index(i)
if reflect.Indirect(rv).Kind() != reflect.Struct {
break
}
if _, isZero := pkField.ValueOf(db.Statement.Context, rv); isZero {
db.AddError(pkField.Set(db.Statement.Context, rv, insertID))
insertID += pkField.AutoIncrementIncrement
}
}
}
case reflect.Struct:
_, isZero := pkField.ValueOf(db.Statement.Context, db.Statement.ReflectValue)
if isZero {
db.AddError(pkField.Set(db.Statement.Context, db.Statement.ReflectValue, insertID))
} }
} }
} }
} }
} }
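
A simplified model of the insert-ID back-fill performed above for batch creates when the driver only reports LastInsertId: rows whose primary key is still zero are filled in one by one, stepping by the auto-increment increment, walking backwards when the dialect reports the last row's ID (the LastInsertIDReversed case) rather than the first. The int64 slice stands in for the reflected rows.

package main

import "fmt"

// backfill assigns primary keys to rows inserted with a zero ID, mirroring
// the loop in the create callback in simplified form.
func backfill(ids []int64, lastInsertID, increment int64, reversed bool) {
	if reversed {
		// LastInsertId refers to the last inserted row: walk backwards, decrementing.
		for i := len(ids) - 1; i >= 0; i-- {
			if ids[i] == 0 {
				ids[i] = lastInsertID
				lastInsertID -= increment
			}
		}
		return
	}
	// LastInsertId refers to the first inserted row: walk forwards, incrementing.
	for i := range ids {
		if ids[i] == 0 {
			ids[i] = lastInsertID
			lastInsertID += increment
		}
	}
}

func main() {
	ids := []int64{0, 42, 0} // the second row already had a key set by the caller
	backfill(ids, 7, 1, false)
	fmt.Println(ids) // [7 42 8]
}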
// AfterCreate after create hooks func CreateWithReturning(db *gorm.DB) {
if db.Error == nil {
if db.Statement.Schema != nil && !db.Statement.Unscoped {
for _, c := range db.Statement.Schema.CreateClauses {
db.Statement.AddClause(c)
}
}
if db.Statement.SQL.String() == "" {
db.Statement.AddClauseIfNotExists(clause.Insert{})
db.Statement.AddClause(ConvertToCreateValues(db.Statement))
db.Statement.Build(db.Statement.BuildClauses...)
}
if sch := db.Statement.Schema; sch != nil && len(sch.FieldsWithDefaultDBValue) > 0 {
db.Statement.WriteString(" RETURNING ")
var (
fields = make([]*schema.Field, len(sch.FieldsWithDefaultDBValue))
values = make([]interface{}, len(sch.FieldsWithDefaultDBValue))
)
for idx, field := range sch.FieldsWithDefaultDBValue {
if idx > 0 {
db.Statement.WriteByte(',')
}
fields[idx] = field
db.Statement.WriteQuoted(field.DBName)
}
if !db.DryRun && db.Error == nil {
db.RowsAffected = 0
rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
if err == nil {
defer rows.Close()
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
c := db.Statement.Clauses["ON CONFLICT"]
onConflict, _ := c.Expression.(clause.OnConflict)
for rows.Next() {
BEGIN:
reflectValue := db.Statement.ReflectValue.Index(int(db.RowsAffected))
if reflect.Indirect(reflectValue).Kind() != reflect.Struct {
break
}
for idx, field := range fields {
fieldValue := field.ReflectValueOf(reflectValue)
if onConflict.DoNothing && !fieldValue.IsZero() {
db.RowsAffected++
if int(db.RowsAffected) >= db.Statement.ReflectValue.Len() {
return
}
goto BEGIN
}
values[idx] = fieldValue.Addr().Interface()
}
db.RowsAffected++
if err := rows.Scan(values...); err != nil {
db.AddError(err)
}
}
case reflect.Struct:
for idx, field := range fields {
values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface()
}
if rows.Next() {
db.RowsAffected++
db.AddError(rows.Scan(values...))
}
}
} else {
db.AddError(err)
}
}
} else if !db.DryRun && db.Error == nil {
if result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); err == nil {
db.RowsAffected, _ = result.RowsAffected()
} else {
db.AddError(err)
}
}
}
}
func AfterCreate(db *gorm.DB) { func AfterCreate(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) { if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) {
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
if db.Statement.Schema.AfterCreate {
if i, ok := value.(AfterCreateInterface); ok {
called = true
db.AddError(i.AfterCreate(tx))
}
}
if db.Statement.Schema.AfterSave { if db.Statement.Schema.AfterSave {
if i, ok := value.(AfterSaveInterface); ok { if i, ok := value.(AfterSaveInterface); ok {
called = true called = true
db.AddError(i.AfterSave(tx)) db.AddError(i.AfterSave(tx))
} }
} }
if db.Statement.Schema.AfterCreate {
if i, ok := value.(AfterCreateInterface); ok {
called = true
db.AddError(i.AfterCreate(tx))
}
}
return called return called
}) })
} }
@ -271,18 +258,15 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
switch stmt.ReflectValue.Kind() { switch stmt.ReflectValue.Kind() {
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
rValLen := stmt.ReflectValue.Len() stmt.SQL.Grow(stmt.ReflectValue.Len() * 18)
if rValLen == 0 { values.Values = make([][]interface{}, stmt.ReflectValue.Len())
defaultValueFieldsHavingValue := map[*schema.Field][]interface{}{}
if stmt.ReflectValue.Len() == 0 {
stmt.AddError(gorm.ErrEmptySlice) stmt.AddError(gorm.ErrEmptySlice)
return return
} }
stmt.SQL.Grow(rValLen * 18) for i := 0; i < stmt.ReflectValue.Len(); i++ {
stmt.Vars = make([]interface{}, 0, rValLen*len(values.Columns))
values.Values = make([][]interface{}, rValLen)
defaultValueFieldsHavingValue := map[*schema.Field][]interface{}{}
for i := 0; i < rValLen; i++ {
rv := reflect.Indirect(stmt.ReflectValue.Index(i)) rv := reflect.Indirect(stmt.ReflectValue.Index(i))
if !rv.IsValid() { if !rv.IsValid() {
stmt.AddError(fmt.Errorf("slice data #%v is invalid: %w", i, gorm.ErrInvalidData)) stmt.AddError(fmt.Errorf("slice data #%v is invalid: %w", i, gorm.ErrInvalidData))
@ -292,41 +276,39 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
values.Values[i] = make([]interface{}, len(values.Columns)) values.Values[i] = make([]interface{}, len(values.Columns))
for idx, column := range values.Columns { for idx, column := range values.Columns {
field := stmt.Schema.FieldsByDBName[column.Name] field := stmt.Schema.FieldsByDBName[column.Name]
if values.Values[i][idx], isZero = field.ValueOf(stmt.Context, rv); isZero { if values.Values[i][idx], isZero = field.ValueOf(rv); isZero {
if field.DefaultValueInterface != nil { if field.DefaultValueInterface != nil {
values.Values[i][idx] = field.DefaultValueInterface values.Values[i][idx] = field.DefaultValueInterface
stmt.AddError(field.Set(stmt.Context, rv, field.DefaultValueInterface)) field.Set(rv, field.DefaultValueInterface)
} else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 { } else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 {
stmt.AddError(field.Set(stmt.Context, rv, curTime)) field.Set(rv, curTime)
values.Values[i][idx], _ = field.ValueOf(stmt.Context, rv) values.Values[i][idx], _ = field.ValueOf(rv)
} }
} else if field.AutoUpdateTime > 0 && updateTrackTime { } else if field.AutoUpdateTime > 0 && updateTrackTime {
stmt.AddError(field.Set(stmt.Context, rv, curTime)) field.Set(rv, curTime)
values.Values[i][idx], _ = field.ValueOf(stmt.Context, rv) values.Values[i][idx], _ = field.ValueOf(rv)
} }
} }
for _, field := range stmt.Schema.FieldsWithDefaultDBValue { for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
if rvOfvalue, isZero := field.ValueOf(stmt.Context, rv); !isZero { if v, isZero := field.ValueOf(rv); !isZero {
if len(defaultValueFieldsHavingValue[field]) == 0 { if len(defaultValueFieldsHavingValue[field]) == 0 {
defaultValueFieldsHavingValue[field] = make([]interface{}, rValLen) defaultValueFieldsHavingValue[field] = make([]interface{}, stmt.ReflectValue.Len())
} }
defaultValueFieldsHavingValue[field][i] = rvOfvalue defaultValueFieldsHavingValue[field][i] = v
} }
} }
} }
} }
for _, field := range stmt.Schema.FieldsWithDefaultDBValue { for field, vs := range defaultValueFieldsHavingValue {
if vs, ok := defaultValueFieldsHavingValue[field]; ok { values.Columns = append(values.Columns, clause.Column{Name: field.DBName})
values.Columns = append(values.Columns, clause.Column{Name: field.DBName}) for idx := range values.Values {
for idx := range values.Values { if vs[idx] == nil {
if vs[idx] == nil { values.Values[idx] = append(values.Values[idx], stmt.Dialector.DefaultValueOf(field))
values.Values[idx] = append(values.Values[idx], stmt.DefaultValueOf(field)) } else {
} else { values.Values[idx] = append(values.Values[idx], vs[idx])
values.Values[idx] = append(values.Values[idx], vs[idx])
}
} }
} }
} }
@ -334,25 +316,25 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
values.Values = [][]interface{}{make([]interface{}, len(values.Columns))} values.Values = [][]interface{}{make([]interface{}, len(values.Columns))}
for idx, column := range values.Columns { for idx, column := range values.Columns {
field := stmt.Schema.FieldsByDBName[column.Name] field := stmt.Schema.FieldsByDBName[column.Name]
if values.Values[0][idx], isZero = field.ValueOf(stmt.Context, stmt.ReflectValue); isZero { if values.Values[0][idx], isZero = field.ValueOf(stmt.ReflectValue); isZero {
if field.DefaultValueInterface != nil { if field.DefaultValueInterface != nil {
values.Values[0][idx] = field.DefaultValueInterface values.Values[0][idx] = field.DefaultValueInterface
stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, field.DefaultValueInterface)) field.Set(stmt.ReflectValue, field.DefaultValueInterface)
} else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 { } else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 {
stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, curTime)) field.Set(stmt.ReflectValue, curTime)
values.Values[0][idx], _ = field.ValueOf(stmt.Context, stmt.ReflectValue) values.Values[0][idx], _ = field.ValueOf(stmt.ReflectValue)
} }
} else if field.AutoUpdateTime > 0 && updateTrackTime { } else if field.AutoUpdateTime > 0 && updateTrackTime {
stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, curTime)) field.Set(stmt.ReflectValue, curTime)
values.Values[0][idx], _ = field.ValueOf(stmt.Context, stmt.ReflectValue) values.Values[0][idx], _ = field.ValueOf(stmt.ReflectValue)
} }
} }
for _, field := range stmt.Schema.FieldsWithDefaultDBValue { for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) && field.DefaultValueInterface == nil { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
if rvOfvalue, isZero := field.ValueOf(stmt.Context, stmt.ReflectValue); !isZero { if v, isZero := field.ValueOf(stmt.ReflectValue); !isZero {
values.Columns = append(values.Columns, clause.Column{Name: field.DBName}) values.Columns = append(values.Columns, clause.Column{Name: field.DBName})
values.Values[0] = append(values.Values[0], rvOfvalue) values.Values[0] = append(values.Values[0], v)
} }
} }
} }
@ -370,15 +352,14 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
for _, column := range values.Columns { for _, column := range values.Columns {
if field := stmt.Schema.LookUpField(column.Name); field != nil { if field := stmt.Schema.LookUpField(column.Name); field != nil {
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
if !field.PrimaryKey && (!field.HasDefaultValue || field.DefaultValueInterface != nil || if !field.PrimaryKey && (!field.HasDefaultValue || field.DefaultValueInterface != nil) && field.AutoCreateTime == 0 {
strings.EqualFold(field.DefaultValue, "NULL")) && field.AutoCreateTime == 0 {
if field.AutoUpdateTime > 0 { if field.AutoUpdateTime > 0 {
assignment := clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: curTime} assignment := clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: curTime}
switch field.AutoUpdateTime { switch field.AutoUpdateTime {
case schema.UnixNanosecond: case schema.UnixNanosecond:
assignment.Value = curTime.UnixNano() assignment.Value = curTime.UnixNano()
case schema.UnixMillisecond: case schema.UnixMillisecond:
assignment.Value = curTime.UnixMilli() assignment.Value = curTime.UnixNano() / 1e6
case schema.UnixSecond: case schema.UnixSecond:
assignment.Value = curTime.Unix() assignment.Value = curTime.Unix()
} }
@ -393,9 +374,6 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
} }
onConflict.DoUpdates = append(onConflict.DoUpdates, clause.AssignmentColumns(columns)...) onConflict.DoUpdates = append(onConflict.DoUpdates, clause.AssignmentColumns(columns)...)
if len(onConflict.DoUpdates) == 0 {
onConflict.DoNothing = true
}
// use primary fields as default OnConflict columns // use primary fields as default OnConflict columns
if len(onConflict.Columns) == 0 { if len(onConflict.Columns) == 0 {

View File

@ -1,71 +0,0 @@
package callbacks
import (
"reflect"
"sync"
"testing"
"time"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
)
var schemaCache = &sync.Map{}
func TestConvertToCreateValues_DestType_Slice(t *testing.T) {
type user struct {
ID int `gorm:"primaryKey"`
Name string
Email string `gorm:"default:(-)"`
Age int `gorm:"default:(-)"`
}
s, err := schema.Parse(&user{}, schemaCache, schema.NamingStrategy{})
if err != nil {
t.Errorf("parse schema error: %v, is not expected", err)
return
}
dest := []*user{
{
ID: 1,
Name: "alice",
Email: "email",
Age: 18,
},
{
ID: 2,
Name: "bob",
Email: "email",
Age: 19,
},
}
stmt := &gorm.Statement{
DB: &gorm.DB{
Config: &gorm.Config{
NowFunc: func() time.Time { return time.Time{} },
},
Statement: &gorm.Statement{
Settings: sync.Map{},
Schema: s,
},
},
ReflectValue: reflect.ValueOf(dest),
Dest: dest,
}
stmt.Schema = s
values := ConvertToCreateValues(stmt)
expected := clause.Values{
// column has value + defaultValue column has value (which should have a stable order)
Columns: []clause.Column{{Name: "name"}, {Name: "email"}, {Name: "age"}, {Name: "id"}},
Values: [][]interface{}{
{"alice", "email", 18, 1},
{"bob", "email", 19, 2},
},
}
if !reflect.DeepEqual(expected, values) {
t.Errorf("expected: %v got %v", expected, values)
}
}

View File

@ -7,7 +7,6 @@ import (
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
"gorm.io/gorm/schema" "gorm.io/gorm/schema"
"gorm.io/gorm/utils"
) )
func BeforeDelete(db *gorm.DB) { func BeforeDelete(db *gorm.DB) {
@ -26,110 +25,99 @@ func BeforeDelete(db *gorm.DB) {
func DeleteBeforeAssociations(db *gorm.DB) { func DeleteBeforeAssociations(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil { if db.Error == nil && db.Statement.Schema != nil {
selectColumns, restricted := db.Statement.SelectAndOmitColumns(true, false) selectColumns, restricted := db.Statement.SelectAndOmitColumns(true, false)
if !restricted {
return
}
for column, v := range selectColumns { if restricted {
if !v { for column, v := range selectColumns {
continue if v {
} if rel, ok := db.Statement.Schema.Relationships.Relations[column]; ok {
switch rel.Type {
case schema.HasOne, schema.HasMany:
queryConds := rel.ToQueryConditions(db.Statement.ReflectValue)
modelValue := reflect.New(rel.FieldSchema.ModelType).Interface()
tx := db.Session(&gorm.Session{NewDB: true}).Model(modelValue)
withoutConditions := false
if db.Statement.Unscoped {
tx = tx.Unscoped()
}
rel, ok := db.Statement.Schema.Relationships.Relations[column] if len(db.Statement.Selects) > 0 {
if !ok { selects := make([]string, 0, len(db.Statement.Selects))
continue for _, s := range db.Statement.Selects {
} if s == clause.Associations {
selects = append(selects, s)
} else if strings.HasPrefix(s, column+".") {
selects = append(selects, strings.TrimPrefix(s, column+"."))
}
}
switch rel.Type { if len(selects) > 0 {
case schema.HasOne, schema.HasMany: tx = tx.Select(selects)
queryConds := rel.ToQueryConditions(db.Statement.Context, db.Statement.ReflectValue) }
modelValue := reflect.New(rel.FieldSchema.ModelType).Interface() }
tx := db.Session(&gorm.Session{NewDB: true}).Model(modelValue)
withoutConditions := false
if db.Statement.Unscoped {
tx = tx.Unscoped()
}
if len(db.Statement.Selects) > 0 { for _, cond := range queryConds {
selects := make([]string, 0, len(db.Statement.Selects)) if c, ok := cond.(clause.IN); ok && len(c.Values) == 0 {
for _, s := range db.Statement.Selects { withoutConditions = true
if s == clause.Associations { break
selects = append(selects, s) }
} else if columnPrefix := column + "."; strings.HasPrefix(s, columnPrefix) { }
selects = append(selects, strings.TrimPrefix(s, columnPrefix))
if !withoutConditions {
if db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil {
return
}
}
case schema.Many2Many:
var (
queryConds = make([]clause.Expression, 0, len(rel.References))
foreignFields = make([]*schema.Field, 0, len(rel.References))
relForeignKeys = make([]string, 0, len(rel.References))
modelValue = reflect.New(rel.JoinTable.ModelType).Interface()
table = rel.JoinTable.Table
tx = db.Session(&gorm.Session{NewDB: true}).Model(modelValue).Table(table)
)
for _, ref := range rel.References {
if ref.OwnPrimaryKey {
foreignFields = append(foreignFields, ref.PrimaryKey)
relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName)
} else if ref.PrimaryValue != "" {
queryConds = append(queryConds, clause.Eq{
Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName},
Value: ref.PrimaryValue,
})
}
}
_, foreignValues := schema.GetIdentityFieldValuesMap(db.Statement.ReflectValue, foreignFields)
column, values := schema.ToQueryValues(table, relForeignKeys, foreignValues)
queryConds = append(queryConds, clause.IN{Column: column, Values: values})
if db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil {
return
}
} }
} }
if len(selects) > 0 {
tx = tx.Select(selects)
}
}
for _, cond := range queryConds {
if c, ok := cond.(clause.IN); ok && len(c.Values) == 0 {
withoutConditions = true
break
}
}
if !withoutConditions && db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil {
return
}
case schema.Many2Many:
var (
queryConds = make([]clause.Expression, 0, len(rel.References))
foreignFields = make([]*schema.Field, 0, len(rel.References))
relForeignKeys = make([]string, 0, len(rel.References))
modelValue = reflect.New(rel.JoinTable.ModelType).Interface()
table = rel.JoinTable.Table
tx = db.Session(&gorm.Session{NewDB: true}).Model(modelValue).Table(table)
)
for _, ref := range rel.References {
if ref.OwnPrimaryKey {
foreignFields = append(foreignFields, ref.PrimaryKey)
relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName)
} else if ref.PrimaryValue != "" {
queryConds = append(queryConds, clause.Eq{
Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName},
Value: ref.PrimaryValue,
})
}
}
_, foreignValues := schema.GetIdentityFieldValuesMap(db.Statement.Context, db.Statement.ReflectValue, foreignFields)
column, values := schema.ToQueryValues(table, relForeignKeys, foreignValues)
queryConds = append(queryConds, clause.IN{Column: column, Values: values})
if db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil {
return
} }
} }
} }
} }
} }
func Delete(config *Config) func(db *gorm.DB) { func Delete(db *gorm.DB) {
supportReturning := utils.Contains(config.DeleteClauses, "RETURNING") if db.Error == nil {
if db.Statement.Schema != nil && !db.Statement.Unscoped {
return func(db *gorm.DB) {
if db.Error != nil {
return
}
if db.Statement.Schema != nil {
for _, c := range db.Statement.Schema.DeleteClauses { for _, c := range db.Statement.Schema.DeleteClauses {
db.Statement.AddClause(c) db.Statement.AddClause(c)
} }
} }
if db.Statement.SQL.Len() == 0 { if db.Statement.SQL.String() == "" {
db.Statement.SQL.Grow(100) db.Statement.SQL.Grow(100)
db.Statement.AddClauseIfNotExists(clause.Delete{}) db.Statement.AddClauseIfNotExists(clause.Delete{})
if db.Statement.Schema != nil { if db.Statement.Schema != nil {
_, queryValues := schema.GetIdentityFieldValuesMap(db.Statement.Context, db.Statement.ReflectValue, db.Statement.Schema.PrimaryFields) _, queryValues := schema.GetIdentityFieldValuesMap(db.Statement.ReflectValue, db.Statement.Schema.PrimaryFields)
column, values := schema.ToQueryValues(db.Statement.Table, db.Statement.Schema.PrimaryFieldDBNames, queryValues) column, values := schema.ToQueryValues(db.Statement.Table, db.Statement.Schema.PrimaryFieldDBNames, queryValues)
if len(values) > 0 { if len(values) > 0 {
@ -137,7 +125,7 @@ func Delete(config *Config) func(db *gorm.DB) {
} }
if db.Statement.ReflectValue.CanAddr() && db.Statement.Dest != db.Statement.Model && db.Statement.Model != nil { if db.Statement.ReflectValue.CanAddr() && db.Statement.Dest != db.Statement.Model && db.Statement.Model != nil {
_, queryValues = schema.GetIdentityFieldValuesMap(db.Statement.Context, reflect.ValueOf(db.Statement.Model), db.Statement.Schema.PrimaryFields) _, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(db.Statement.Model), db.Statement.Schema.PrimaryFields)
column, values = schema.ToQueryValues(db.Statement.Table, db.Statement.Schema.PrimaryFieldDBNames, queryValues) column, values = schema.ToQueryValues(db.Statement.Table, db.Statement.Schema.PrimaryFieldDBNames, queryValues)
if len(values) > 0 { if len(values) > 0 {
@ -147,36 +135,21 @@ func Delete(config *Config) func(db *gorm.DB) {
} }
db.Statement.AddClauseIfNotExists(clause.From{}) db.Statement.AddClauseIfNotExists(clause.From{})
db.Statement.Build(db.Statement.BuildClauses...) db.Statement.Build(db.Statement.BuildClauses...)
} }
checkMissingWhereConditions(db) if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok && db.Error == nil {
db.AddError(gorm.ErrMissingWhereClause)
return
}
if !db.DryRun && db.Error == nil { if !db.DryRun && db.Error == nil {
ok, mode := hasReturning(db, supportReturning) result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
if !ok {
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
if db.AddError(err) == nil { if err == nil {
db.RowsAffected, _ = result.RowsAffected() db.RowsAffected, _ = result.RowsAffected()
} else {
if db.Statement.Result != nil { db.AddError(err)
db.Statement.Result.Result = result
db.Statement.Result.RowsAffected = db.RowsAffected
}
}
return
}
if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil {
gorm.Scan(rows, db, mode)
if db.Statement.Result != nil {
db.Statement.Result.RowsAffected = db.RowsAffected
}
db.AddError(rows.Close())
} }
} }
} }
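
From the caller's side, DeleteBeforeAssociations only does work when relations are explicitly selected on the delete. A small usage sketch of that trigger; the User/Pet models are hypothetical, the Select pattern is GORM's documented "delete with Select" usage.

package deletedemo

import (
	"gorm.io/gorm"
	"gorm.io/gorm/clause"
)

type Pet struct {
	ID     uint
	UserID uint
}

type User struct {
	ID   uint
	Pets []Pet
}

// deleteUserAndPets selects the Pets relation on the delete, which is what
// makes the callback above issue a delete for the associated rows first.
func deleteUserAndPets(db *gorm.DB, u *User) error {
	return db.Select("Pets").Delete(u).Error
}

// deleteUserAndAllAssociations selects every relation at once.
func deleteUserAndAllAssociations(db *gorm.DB, u *User) error {
	return db.Select(clause.Associations).Delete(u).Error
}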

View File

@ -1,7 +1,6 @@
package callbacks package callbacks
import ( import (
"reflect"
"sort" "sort"
"gorm.io/gorm" "gorm.io/gorm"
@ -13,7 +12,7 @@ func ConvertMapToValuesForCreate(stmt *gorm.Statement, mapValue map[string]inter
values.Columns = make([]clause.Column, 0, len(mapValue)) values.Columns = make([]clause.Column, 0, len(mapValue))
selectColumns, restricted := stmt.SelectAndOmitColumns(true, false) selectColumns, restricted := stmt.SelectAndOmitColumns(true, false)
keys := make([]string, 0, len(mapValue)) var keys = make([]string, 0, len(mapValue))
for k := range mapValue { for k := range mapValue {
keys = append(keys, k) keys = append(keys, k)
} }
@ -41,7 +40,9 @@ func ConvertMapToValuesForCreate(stmt *gorm.Statement, mapValue map[string]inter
// ConvertSliceOfMapToValuesForCreate convert slice of map to values // ConvertSliceOfMapToValuesForCreate convert slice of map to values
func ConvertSliceOfMapToValuesForCreate(stmt *gorm.Statement, mapValues []map[string]interface{}) (values clause.Values) { func ConvertSliceOfMapToValuesForCreate(stmt *gorm.Statement, mapValues []map[string]interface{}) (values clause.Values) {
columns := make([]string, 0, len(mapValues)) var (
columns = make([]string, 0, len(mapValues))
)
// when the length of mapValues is zero,return directly here // when the length of mapValues is zero,return directly here
// no need to call stmt.SelectAndOmitColumns method // no need to call stmt.SelectAndOmitColumns method
@ -92,61 +93,3 @@ func ConvertSliceOfMapToValuesForCreate(stmt *gorm.Statement, mapValues []map[st
} }
return return
} }
func hasReturning(tx *gorm.DB, supportReturning bool) (bool, gorm.ScanMode) {
if supportReturning {
if c, ok := tx.Statement.Clauses["RETURNING"]; ok {
returning, _ := c.Expression.(clause.Returning)
if len(returning.Columns) == 0 || (len(returning.Columns) == 1 && returning.Columns[0].Name == "*") {
return true, 0
}
return true, gorm.ScanUpdate
}
}
return false, 0
}
func checkMissingWhereConditions(db *gorm.DB) {
if !db.AllowGlobalUpdate && db.Error == nil {
where, withCondition := db.Statement.Clauses["WHERE"]
if withCondition {
if _, withSoftDelete := db.Statement.Clauses["soft_delete_enabled"]; withSoftDelete {
whereClause, _ := where.Expression.(clause.Where)
withCondition = len(whereClause.Exprs) > 1
}
}
if !withCondition {
db.AddError(gorm.ErrMissingWhereClause)
}
return
}
}
type visitMap = map[reflect.Value]bool
// Check if circular values, return true if loaded
func loadOrStoreVisitMap(visitMap *visitMap, v reflect.Value) (loaded bool) {
if v.Kind() == reflect.Ptr {
v = v.Elem()
}
switch v.Kind() {
case reflect.Slice, reflect.Array:
loaded = true
for i := 0; i < v.Len(); i++ {
if !loadOrStoreVisitMap(visitMap, v.Index(i)) {
loaded = false
}
}
case reflect.Struct, reflect.Interface:
if v.CanAddr() {
p := v.Addr()
if _, ok := (*visitMap)[p]; ok {
return true
}
(*visitMap)[p] = true
}
}
return
}
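
checkMissingWhereConditions is what backs GORM's global-delete/update protection; seen from the public API it behaves roughly like this sketch (the User model is made up):

package guarddemo

import (
	"errors"

	"gorm.io/gorm"
)

type User struct{ ID uint }

// globalDeleteGuard shows the two outcomes the helper above distinguishes:
// a conditionless delete is rejected, an explicitly allowed one is not.
func globalDeleteGuard(db *gorm.DB) (rejected bool, globalErr error) {
	err := db.Delete(&User{}).Error // zero primary key, no WHERE condition
	rejected = errors.Is(err, gorm.ErrMissingWhereClause)

	// Opting in via AllowGlobalUpdate bypasses the guard.
	globalErr = db.Session(&gorm.Session{AllowGlobalUpdate: true}).Delete(&User{}).Error
	return rejected, globalErr
}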

View File

@ -1,157 +0,0 @@
package callbacks
import (
"reflect"
"testing"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
func TestLoadOrStoreVisitMap(t *testing.T) {
var vm visitMap
var loaded bool
type testM struct {
Name string
}
t1 := testM{Name: "t1"}
t2 := testM{Name: "t2"}
t3 := testM{Name: "t3"}
vm = make(visitMap)
if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf(&t1)); loaded {
t.Fatalf("loaded should be false")
}
if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf(&t1)); !loaded {
t.Fatalf("loaded should be true")
}
// t1 already exists but t2 and t3 do not
if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf([]*testM{&t1, &t2, &t3})); loaded {
t.Fatalf("loaded should be false")
}
if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf([]*testM{&t2, &t3})); !loaded {
t.Fatalf("loaded should be true")
}
}
func TestConvertMapToValuesForCreate(t *testing.T) {
testCase := []struct {
name string
input map[string]interface{}
expect clause.Values
}{
{
name: "Test convert string value",
input: map[string]interface{}{
"name": "my name",
},
expect: clause.Values{
Columns: []clause.Column{{Name: "name"}},
Values: [][]interface{}{{"my name"}},
},
},
{
name: "Test convert int value",
input: map[string]interface{}{
"age": 18,
},
expect: clause.Values{
Columns: []clause.Column{{Name: "age"}},
Values: [][]interface{}{{18}},
},
},
{
name: "Test convert float value",
input: map[string]interface{}{
"score": 99.5,
},
expect: clause.Values{
Columns: []clause.Column{{Name: "score"}},
Values: [][]interface{}{{99.5}},
},
},
{
name: "Test convert bool value",
input: map[string]interface{}{
"active": true,
},
expect: clause.Values{
Columns: []clause.Column{{Name: "active"}},
Values: [][]interface{}{{true}},
},
},
}
for _, tc := range testCase {
t.Run(tc.name, func(t *testing.T) {
actual := ConvertMapToValuesForCreate(&gorm.Statement{}, tc.input)
if !reflect.DeepEqual(actual, tc.expect) {
t.Errorf("expect %v got %v", tc.expect, actual)
}
})
}
}
func TestConvertSliceOfMapToValuesForCreate(t *testing.T) {
testCase := []struct {
name string
input []map[string]interface{}
expect clause.Values
}{
{
name: "Test convert slice of string value",
input: []map[string]interface{}{
{"name": "my name"},
},
expect: clause.Values{
Columns: []clause.Column{{Name: "name"}},
Values: [][]interface{}{{"my name"}},
},
},
{
name: "Test convert slice of int value",
input: []map[string]interface{}{
{"age": 18},
},
expect: clause.Values{
Columns: []clause.Column{{Name: "age"}},
Values: [][]interface{}{{18}},
},
},
{
name: "Test convert slice of float value",
input: []map[string]interface{}{
{"score": 99.5},
},
expect: clause.Values{
Columns: []clause.Column{{Name: "score"}},
Values: [][]interface{}{{99.5}},
},
},
{
name: "Test convert slice of bool value",
input: []map[string]interface{}{
{"active": true},
},
expect: clause.Values{
Columns: []clause.Column{{Name: "active"}},
Values: [][]interface{}{{true}},
},
},
}
for _, tc := range testCase {
t.Run(tc.name, func(t *testing.T) {
actual := ConvertSliceOfMapToValuesForCreate(&gorm.Statement{}, tc.input)
if !reflect.DeepEqual(actual, tc.expect) {
t.Errorf("expected %v but got %v", tc.expect, actual)
}
})
}
}

View File

@ -3,8 +3,6 @@ package callbacks
import ( import (
"fmt" "fmt"
"reflect" "reflect"
"sort"
"strings"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
@ -12,179 +10,10 @@ import (
"gorm.io/gorm/utils" "gorm.io/gorm/utils"
) )
// parsePreloadMap extracts nested preloads. e.g. func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) {
//
// // schema has a "k0" relation and a "k7.k8" embedded relation
// parsePreloadMap(schema, map[string][]interface{}{
// clause.Associations: {"arg1"},
// "k1": {"arg2"},
// "k2.k3": {"arg3"},
// "k4.k5.k6": {"arg4"},
// })
// // preloadMap is
// map[string]map[string][]interface{}{
// "k0": {},
// "k7": {
// "k8": {},
// },
// "k1": {},
// "k2": {
// "k3": {"arg3"},
// },
// "k4": {
// "k5.k6": {"arg4"},
// },
// }
func parsePreloadMap(s *schema.Schema, preloads map[string][]interface{}) map[string]map[string][]interface{} {
preloadMap := map[string]map[string][]interface{}{}
setPreloadMap := func(name, value string, args []interface{}) {
if _, ok := preloadMap[name]; !ok {
preloadMap[name] = map[string][]interface{}{}
}
if value != "" {
preloadMap[name][value] = args
}
}
for name, args := range preloads {
preloadFields := strings.Split(name, ".")
value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), ".")
if preloadFields[0] == clause.Associations {
for _, relation := range s.Relationships.Relations {
if relation.Schema == s {
setPreloadMap(relation.Name, value, args)
}
}
for embedded, embeddedRelations := range s.Relationships.EmbeddedRelations {
for _, value := range embeddedValues(embeddedRelations) {
setPreloadMap(embedded, value, args)
}
}
} else {
setPreloadMap(preloadFields[0], value, args)
}
}
return preloadMap
}
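
The name-splitting rule described in the parsePreloadMap comment above is easy to check in isolation; a stripped-down, runnable version of just that part (the relation names are invented):

package main

import (
	"fmt"
	"strings"
)

// splitPreload mirrors the name handling in parsePreloadMap: the first path
// segment picks the relation, the remainder becomes the nested preload key.
func splitPreload(name string) (top, nested string) {
	preloadFields := strings.Split(name, ".")
	nested = strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), ".")
	return preloadFields[0], nested
}

func main() {
	for _, name := range []string{"Orders", "Orders.Items", "Orders.Items.Product"} {
		top, nested := splitPreload(name)
		fmt.Printf("%-22s -> top=%q nested=%q\n", name, top, nested)
	}
}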
func embeddedValues(embeddedRelations *schema.Relationships) []string {
if embeddedRelations == nil {
return nil
}
names := make([]string, 0, len(embeddedRelations.Relations)+len(embeddedRelations.EmbeddedRelations))
for _, relation := range embeddedRelations.Relations {
// skip first struct name
names = append(names, strings.Join(relation.Field.EmbeddedBindNames[1:], "."))
}
for _, relations := range embeddedRelations.EmbeddedRelations {
names = append(names, embeddedValues(relations)...)
}
return names
}
// preloadEntryPoint enters layer by layer. It will call real preload if it finds the right entry point.
// If the current relationship is embedded or joined, current query will be ignored.
//
//nolint:cyclop
func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relationships, preloads map[string][]interface{}, associationsConds []interface{}) error {
preloadMap := parsePreloadMap(db.Statement.Schema, preloads)
// avoid random traversal of the map
preloadNames := make([]string, 0, len(preloadMap))
for key := range preloadMap {
preloadNames = append(preloadNames, key)
}
sort.Strings(preloadNames)
isJoined := func(name string) (joined bool, nestedJoins []string) {
for _, join := range joins {
if _, ok := relationships.Relations[join]; ok && name == join {
joined = true
continue
}
join0, join1, cut := strings.Cut(join, ".")
if cut {
if _, ok := relationships.Relations[join0]; ok && name == join0 {
joined = true
nestedJoins = append(nestedJoins, join1)
}
}
}
return joined, nestedJoins
}
for _, name := range preloadNames {
if relations := relationships.EmbeddedRelations[name]; relations != nil {
if err := preloadEntryPoint(db, joins, relations, preloadMap[name], associationsConds); err != nil {
return err
}
} else if rel := relationships.Relations[name]; rel != nil {
if joined, nestedJoins := isJoined(name); joined {
switch rv := db.Statement.ReflectValue; rv.Kind() {
case reflect.Slice, reflect.Array:
if rv.Len() > 0 {
reflectValue := rel.FieldSchema.MakeSlice().Elem()
for i := 0; i < rv.Len(); i++ {
frv := rel.Field.ReflectValueOf(db.Statement.Context, rv.Index(i))
if frv.Kind() != reflect.Ptr {
reflectValue = reflect.Append(reflectValue, frv.Addr())
} else {
if frv.IsNil() {
continue
}
reflectValue = reflect.Append(reflectValue, frv)
}
}
tx := preloadDB(db, reflectValue, reflectValue.Interface())
if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil {
return err
}
}
case reflect.Struct, reflect.Pointer:
reflectValue := rel.Field.ReflectValueOf(db.Statement.Context, rv)
tx := preloadDB(db, reflectValue, reflectValue.Interface())
if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil {
return err
}
default:
return gorm.ErrInvalidData
}
} else {
tx := db.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks})
tx.Statement.ReflectValue = db.Statement.ReflectValue
tx.Statement.Unscoped = db.Statement.Unscoped
if err := preload(tx, rel, append(preloads[name], associationsConds...), preloadMap[name]); err != nil {
return err
}
}
} else {
return fmt.Errorf("%s: %w for schema %s", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name)
}
}
return nil
}
func preloadDB(db *gorm.DB, reflectValue reflect.Value, dest interface{}) *gorm.DB {
tx := db.Session(&gorm.Session{Context: db.Statement.Context, NewDB: true, SkipHooks: db.Statement.SkipHooks, Initialized: true})
db.Statement.Settings.Range(func(k, v interface{}) bool {
tx.Statement.Settings.Store(k, v)
return true
})
if err := tx.Statement.Parse(dest); err != nil {
tx.AddError(err)
return tx
}
tx.Statement.ReflectValue = reflectValue
tx.Statement.Unscoped = db.Statement.Unscoped
return tx
}
func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) error {
var ( var (
reflectValue = tx.Statement.ReflectValue reflectValue = db.Statement.ReflectValue
tx = db.Session(&gorm.Session{NewDB: true}).Model(nil).Session(&gorm.Session{SkipHooks: db.Statement.SkipHooks})
relForeignKeys []string relForeignKeys []string
relForeignFields []*schema.Field relForeignFields []*schema.Field
foreignFields []*schema.Field foreignFields []*schema.Field
@ -193,6 +22,11 @@ func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
inlineConds []interface{} inlineConds []interface{}
) )
db.Statement.Settings.Range(func(k, v interface{}) bool {
tx.Statement.Settings.Store(k, v)
return true
})
if rel.JoinTable != nil { if rel.JoinTable != nil {
var ( var (
joinForeignFields = make([]*schema.Field, 0, len(rel.References)) joinForeignFields = make([]*schema.Field, 0, len(rel.References))
@ -214,28 +48,25 @@ func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
} }
} }
joinIdentityMap, joinForeignValues := schema.GetIdentityFieldValuesMap(tx.Statement.Context, reflectValue, foreignFields) joinIdentityMap, joinForeignValues := schema.GetIdentityFieldValuesMap(reflectValue, foreignFields)
if len(joinForeignValues) == 0 { if len(joinForeignValues) == 0 {
return nil return
} }
joinResults := rel.JoinTable.MakeSlice().Elem() joinResults := rel.JoinTable.MakeSlice().Elem()
column, values := schema.ToQueryValues(clause.CurrentTable, joinForeignKeys, joinForeignValues) column, values := schema.ToQueryValues(clause.CurrentTable, joinForeignKeys, joinForeignValues)
if err := tx.Where(clause.IN{Column: column, Values: values}).Find(joinResults.Addr().Interface()).Error; err != nil { db.AddError(tx.Where(clause.IN{Column: column, Values: values}).Find(joinResults.Addr().Interface()).Error)
return err
}
// convert join identity map to relation identity map // convert join identity map to relation identity map
fieldValues := make([]interface{}, len(joinForeignFields)) fieldValues := make([]interface{}, len(joinForeignFields))
joinFieldValues := make([]interface{}, len(joinRelForeignFields)) joinFieldValues := make([]interface{}, len(joinRelForeignFields))
for i := 0; i < joinResults.Len(); i++ { for i := 0; i < joinResults.Len(); i++ {
joinIndexValue := joinResults.Index(i)
for idx, field := range joinForeignFields { for idx, field := range joinForeignFields {
fieldValues[idx], _ = field.ValueOf(tx.Statement.Context, joinIndexValue) fieldValues[idx], _ = field.ValueOf(joinResults.Index(i))
} }
for idx, field := range joinRelForeignFields { for idx, field := range joinRelForeignFields {
joinFieldValues[idx], _ = field.ValueOf(tx.Statement.Context, joinIndexValue) joinFieldValues[idx], _ = field.ValueOf(joinResults.Index(i))
} }
if results, ok := joinIdentityMap[utils.ToStringKey(fieldValues...)]; ok { if results, ok := joinIdentityMap[utils.ToStringKey(fieldValues...)]; ok {
@ -244,7 +75,7 @@ func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
} }
} }
_, foreignValues = schema.GetIdentityFieldValuesMap(tx.Statement.Context, joinResults, joinRelForeignFields) _, foreignValues = schema.GetIdentityFieldValuesMap(joinResults, joinRelForeignFields)
} else { } else {
for _, ref := range rel.References { for _, ref := range rel.References {
if ref.OwnPrimaryKey { if ref.OwnPrimaryKey {
@ -260,9 +91,9 @@ func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
} }
} }
identityMap, foreignValues = schema.GetIdentityFieldValuesMap(tx.Statement.Context, reflectValue, foreignFields) identityMap, foreignValues = schema.GetIdentityFieldValuesMap(reflectValue, foreignFields)
if len(foreignValues) == 0 { if len(foreignValues) == 0 {
return nil return
} }
} }
@ -275,8 +106,6 @@ func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
column, values := schema.ToQueryValues(clause.CurrentTable, relForeignKeys, foreignValues) column, values := schema.ToQueryValues(clause.CurrentTable, relForeignKeys, foreignValues)
if len(values) != 0 { if len(values) != 0 {
tx = tx.Model(reflectResults.Addr().Interface()).Where(clause.IN{Column: column, Values: values})
for _, cond := range conds { for _, cond := range conds {
if fc, ok := cond.(func(*gorm.DB) *gorm.DB); ok { if fc, ok := cond.(func(*gorm.DB) *gorm.DB); ok {
tx = fc(tx) tx = fc(tx)
@ -285,13 +114,7 @@ func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
} }
} }
if len(inlineConds) > 0 { db.AddError(tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), inlineConds...).Error)
tx = tx.Where(inlineConds[0], inlineConds[1:]...)
}
if err := tx.Find(reflectResults.Addr().Interface()).Error; err != nil {
return err
}
} }
fieldValues := make([]interface{}, len(relForeignFields)) fieldValues := make([]interface{}, len(relForeignFields))
@ -301,17 +124,17 @@ func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
case reflect.Struct: case reflect.Struct:
switch rel.Type { switch rel.Type {
case schema.HasMany, schema.Many2Many: case schema.HasMany, schema.Many2Many:
tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue, reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface())) rel.Field.Set(reflectValue, reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface())
default: default:
tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue, reflect.New(rel.Field.FieldType).Interface())) rel.Field.Set(reflectValue, reflect.New(rel.Field.FieldType).Interface())
} }
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
for i := 0; i < reflectValue.Len(); i++ { for i := 0; i < reflectValue.Len(); i++ {
switch rel.Type { switch rel.Type {
case schema.HasMany, schema.Many2Many: case schema.HasMany, schema.Many2Many:
tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue.Index(i), reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface())) rel.Field.Set(reflectValue.Index(i), reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface())
default: default:
tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue.Index(i), reflect.New(rel.Field.FieldType).Interface())) rel.Field.Set(reflectValue.Index(i), reflect.New(rel.Field.FieldType).Interface())
} }
} }
} }
@ -319,33 +142,30 @@ func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
for i := 0; i < reflectResults.Len(); i++ { for i := 0; i < reflectResults.Len(); i++ {
elem := reflectResults.Index(i) elem := reflectResults.Index(i)
for idx, field := range relForeignFields { for idx, field := range relForeignFields {
fieldValues[idx], _ = field.ValueOf(tx.Statement.Context, elem) fieldValues[idx], _ = field.ValueOf(elem)
} }
datas, ok := identityMap[utils.ToStringKey(fieldValues...)] if datas, ok := identityMap[utils.ToStringKey(fieldValues...)]; ok {
if !ok { for _, data := range datas {
return fmt.Errorf("failed to assign association %#v, make sure foreign fields exists", elem.Interface()) reflectFieldValue := rel.Field.ReflectValueOf(data)
} if reflectFieldValue.Kind() == reflect.Ptr && reflectFieldValue.IsNil() {
reflectFieldValue.Set(reflect.New(rel.Field.FieldType.Elem()))
}
for _, data := range datas { reflectFieldValue = reflect.Indirect(reflectFieldValue)
reflectFieldValue := rel.Field.ReflectValueOf(tx.Statement.Context, data) switch reflectFieldValue.Kind() {
if reflectFieldValue.Kind() == reflect.Ptr && reflectFieldValue.IsNil() { case reflect.Struct:
reflectFieldValue.Set(reflect.New(rel.Field.FieldType.Elem())) rel.Field.Set(data, reflectResults.Index(i).Interface())
} case reflect.Slice, reflect.Array:
if reflectFieldValue.Type().Elem().Kind() == reflect.Ptr {
reflectFieldValue = reflect.Indirect(reflectFieldValue) rel.Field.Set(data, reflect.Append(reflectFieldValue, elem).Interface())
switch reflectFieldValue.Kind() { } else {
case reflect.Struct: rel.Field.Set(data, reflect.Append(reflectFieldValue, elem.Elem()).Interface())
tx.AddError(rel.Field.Set(tx.Statement.Context, data, elem.Interface())) }
case reflect.Slice, reflect.Array:
if reflectFieldValue.Type().Elem().Kind() == reflect.Ptr {
tx.AddError(rel.Field.Set(tx.Statement.Context, data, reflect.Append(reflectFieldValue, elem).Interface()))
} else {
tx.AddError(rel.Field.Set(tx.Statement.Context, data, reflect.Append(reflectFieldValue, elem.Elem()).Interface()))
} }
} }
} else {
db.AddError(fmt.Errorf("failed to assign association %#v, make sure foreign fields exists", elem.Interface()))
} }
} }
return tx.Error
} }
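
The preload pipeline above is driven purely by the strings handed to Preload. A hedged sketch of the call shapes parsePreloadMap and preloadEntryPoint are meant to handle, using illustrative User/Manager/Company models (the model names and fields are assumptions, not part of this change):

package example

import (
	"gorm.io/gorm"
	"gorm.io/gorm/clause"
)

type Company struct {
	ID   uint
	Name string
}

type Manager struct {
	ID        uint
	CompanyID uint
	Company   Company
}

type User struct {
	ID        uint
	Name      string
	ManagerID uint
	Manager   Manager
}

// loadUsers exercises the shapes split apart above: a top-level relation,
// a nested "Parent.Child" relation, and clause.Associations.
func loadUsers(db *gorm.DB) ([]User, error) {
	var users []User
	err := db.
		Preload("Manager").         // top-level relation
		Preload("Manager.Company"). // nested relation, split on "."
		Find(&users).Error
	if err != nil {
		return nil, err
	}

	var all []User
	err = db.Preload(clause.Associations).Find(&all).Error // every direct relation
	return append(users, all...), err
}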

View File

@ -3,12 +3,11 @@ package callbacks
import ( import (
"fmt" "fmt"
"reflect" "reflect"
"sort"
"strings" "strings"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils"
) )
func Query(db *gorm.DB) { func Query(db *gorm.DB) {
@ -21,33 +20,28 @@ func Query(db *gorm.DB) {
db.AddError(err) db.AddError(err)
return return
} }
defer func() { defer rows.Close()
db.AddError(rows.Close())
}()
gorm.Scan(rows, db, 0)
if db.Statement.Result != nil { gorm.Scan(rows, db, false)
db.Statement.Result.RowsAffected = db.RowsAffected
}
} }
} }
} }
func BuildQuerySQL(db *gorm.DB) { func BuildQuerySQL(db *gorm.DB) {
if db.Statement.Schema != nil { if db.Statement.Schema != nil && !db.Statement.Unscoped {
for _, c := range db.Statement.Schema.QueryClauses { for _, c := range db.Statement.Schema.QueryClauses {
db.Statement.AddClause(c) db.Statement.AddClause(c)
} }
} }
if db.Statement.SQL.Len() == 0 { if db.Statement.SQL.String() == "" {
db.Statement.SQL.Grow(100) db.Statement.SQL.Grow(100)
clauseSelect := clause.Select{Distinct: db.Statement.Distinct} clauseSelect := clause.Select{Distinct: db.Statement.Distinct}
if db.Statement.ReflectValue.Kind() == reflect.Struct && db.Statement.ReflectValue.Type() == db.Statement.Schema.ModelType { if db.Statement.ReflectValue.Kind() == reflect.Struct && db.Statement.ReflectValue.Type() == db.Statement.Schema.ModelType {
var conds []clause.Expression var conds []clause.Expression
for _, primaryField := range db.Statement.Schema.PrimaryFields { for _, primaryField := range db.Statement.Schema.PrimaryFields {
if v, isZero := primaryField.ValueOf(db.Statement.Context, db.Statement.ReflectValue); !isZero { if v, isZero := primaryField.ValueOf(db.Statement.ReflectValue); !isZero {
conds = append(conds, clause.Eq{Column: clause.Column{Table: db.Statement.Table, Name: primaryField.DBName}, Value: v}) conds = append(conds, clause.Eq{Column: clause.Column{Table: db.Statement.Table, Name: primaryField.DBName}, Value: v})
} }
} }
@ -101,169 +95,86 @@ func BuildQuerySQL(db *gorm.DB) {
} }
// inline joins // inline joins
fromClause := clause.From{} joins := []clause.Join{}
if v, ok := db.Statement.Clauses["FROM"].Expression.(clause.From); ok { if fromClause, ok := db.Statement.Clauses["FROM"].Expression.(clause.From); ok {
fromClause = v joins = fromClause.Joins
} }
if len(db.Statement.Joins) != 0 || len(fromClause.Joins) != 0 { if len(db.Statement.Joins) != 0 || len(joins) != 0 {
if len(db.Statement.Selects) == 0 && len(db.Statement.Omits) == 0 && db.Statement.Schema != nil { if len(db.Statement.Selects) == 0 && db.Statement.Schema != nil {
clauseSelect.Columns = make([]clause.Column, len(db.Statement.Schema.DBNames)) clauseSelect.Columns = make([]clause.Column, len(db.Statement.Schema.DBNames))
for idx, dbName := range db.Statement.Schema.DBNames { for idx, dbName := range db.Statement.Schema.DBNames {
clauseSelect.Columns[idx] = clause.Column{Table: db.Statement.Table, Name: dbName} clauseSelect.Columns[idx] = clause.Column{Table: db.Statement.Table, Name: dbName}
} }
} }
specifiedRelationsName := map[string]string{clause.CurrentTable: clause.CurrentTable}
for _, join := range db.Statement.Joins { for _, join := range db.Statement.Joins {
if db.Statement.Schema != nil { if db.Statement.Schema == nil {
var isRelations bool // is relations or raw sql joins = append(joins, clause.Join{
var relations []*schema.Relationship Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds},
relation, ok := db.Statement.Schema.Relationships.Relations[join.Name] })
if ok { } else if relation, ok := db.Statement.Schema.Relationships.Relations[join.Name]; ok {
isRelations = true tableAliasName := relation.Name
relations = append(relations, relation)
} else {
// handle nested join like "Manager.Company"
nestedJoinNames := strings.Split(join.Name, ".")
if len(nestedJoinNames) > 1 {
isNestedJoin := true
guessNestedRelations := make([]*schema.Relationship, 0, len(nestedJoinNames))
currentRelations := db.Statement.Schema.Relationships.Relations
for _, relname := range nestedJoinNames {
// incomplete match, only treated as raw sql
if relation, ok = currentRelations[relname]; ok {
guessNestedRelations = append(guessNestedRelations, relation)
currentRelations = relation.FieldSchema.Relationships.Relations
} else {
isNestedJoin = false
break
}
}
if isNestedJoin { for _, s := range relation.FieldSchema.DBNames {
isRelations = true clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{
relations = guessNestedRelations Table: tableAliasName,
} Name: s,
} Alias: tableAliasName + "__" + s,
}
if isRelations {
genJoinClause := func(joinType clause.JoinType, tableAliasName string, parentTableName string, relation *schema.Relationship) clause.Join {
columnStmt := gorm.Statement{
Table: tableAliasName, DB: db, Schema: relation.FieldSchema,
Selects: join.Selects, Omits: join.Omits,
}
selectColumns, restricted := columnStmt.SelectAndOmitColumns(false, false)
for _, s := range relation.FieldSchema.DBNames {
if v, ok := selectColumns[s]; (ok && v) || (!ok && !restricted) {
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{
Table: tableAliasName,
Name: s,
Alias: utils.NestedRelationName(tableAliasName, s),
})
}
}
if join.Expression != nil {
return clause.Join{
Type: join.JoinType,
Expression: join.Expression,
}
}
exprs := make([]clause.Expression, len(relation.References))
for idx, ref := range relation.References {
if ref.OwnPrimaryKey {
exprs[idx] = clause.Eq{
Column: clause.Column{Table: parentTableName, Name: ref.PrimaryKey.DBName},
Value: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
}
} else {
if ref.PrimaryValue == "" {
exprs[idx] = clause.Eq{
Column: clause.Column{Table: parentTableName, Name: ref.ForeignKey.DBName},
Value: clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName},
}
} else {
exprs[idx] = clause.Eq{
Column: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
Value: ref.PrimaryValue,
}
}
}
}
{
onStmt := gorm.Statement{Table: tableAliasName, DB: db, Clauses: map[string]clause.Clause{}}
for _, c := range relation.FieldSchema.QueryClauses {
onStmt.AddClause(c)
}
if join.On != nil {
onStmt.AddClause(join.On)
}
if cs, ok := onStmt.Clauses["WHERE"]; ok {
if where, ok := cs.Expression.(clause.Where); ok {
where.Build(&onStmt)
if onSQL := onStmt.SQL.String(); onSQL != "" {
vars := onStmt.Vars
for idx, v := range vars {
bindvar := strings.Builder{}
onStmt.Vars = vars[0 : idx+1]
db.Dialector.BindVarTo(&bindvar, &onStmt, v)
onSQL = strings.Replace(onSQL, bindvar.String(), "?", 1)
}
exprs = append(exprs, clause.Expr{SQL: onSQL, Vars: vars})
}
}
}
}
return clause.Join{
Type: joinType,
Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName},
ON: clause.Where{Exprs: exprs},
}
}
parentTableName := clause.CurrentTable
for idx, rel := range relations {
// joins table alias like "Manager, Company, Manager__Company"
curAliasName := rel.Name
if parentTableName != clause.CurrentTable {
curAliasName = utils.NestedRelationName(parentTableName, curAliasName)
}
if _, ok := specifiedRelationsName[curAliasName]; !ok {
aliasName := curAliasName
if idx == len(relations)-1 && join.Alias != "" {
aliasName = join.Alias
}
fromClause.Joins = append(fromClause.Joins, genJoinClause(join.JoinType, aliasName, specifiedRelationsName[parentTableName], rel))
specifiedRelationsName[curAliasName] = aliasName
}
parentTableName = curAliasName
}
} else {
fromClause.Joins = append(fromClause.Joins, clause.Join{
Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds},
}) })
} }
exprs := make([]clause.Expression, len(relation.References))
for idx, ref := range relation.References {
if ref.OwnPrimaryKey {
exprs[idx] = clause.Eq{
Column: clause.Column{Table: clause.CurrentTable, Name: ref.PrimaryKey.DBName},
Value: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
}
} else {
if ref.PrimaryValue == "" {
exprs[idx] = clause.Eq{
Column: clause.Column{Table: clause.CurrentTable, Name: ref.ForeignKey.DBName},
Value: clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName},
}
} else {
exprs[idx] = clause.Eq{
Column: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
Value: ref.PrimaryValue,
}
}
}
}
if join.On != nil {
onStmt := gorm.Statement{Table: tableAliasName, DB: db}
join.On.Build(&onStmt)
onSQL := onStmt.SQL.String()
vars := onStmt.Vars
for idx, v := range onStmt.Vars {
bindvar := strings.Builder{}
onStmt.Vars = vars[0 : idx+1]
db.Dialector.BindVarTo(&bindvar, &onStmt, v)
onSQL = strings.Replace(onSQL, bindvar.String(), "?", 1)
}
exprs = append(exprs, clause.Expr{SQL: onSQL, Vars: vars})
}
joins = append(joins, clause.Join{
Type: clause.LeftJoin,
Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName},
ON: clause.Where{Exprs: exprs},
})
} else { } else {
fromClause.Joins = append(fromClause.Joins, clause.Join{ joins = append(joins, clause.Join{
Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds}, Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds},
}) })
} }
} }
db.Statement.AddClause(fromClause) db.Statement.Joins = nil
db.Statement.AddClause(clause.From{Joins: joins})
} else { } else {
db.Statement.AddClauseIfNotExists(clause.From{}) db.Statement.AddClauseIfNotExists(clause.From{})
} }
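
The join builder above resolves relation names, including nested "Parent.Child" forms, into aliased joins, while raw SQL passes through as a NamedExpr. A sketch of the corresponding calls, building on the illustrative User/Manager/Company models from the preload sketch earlier (still assumptions for illustration):

package example

import "gorm.io/gorm"

// findWithJoins sketches the shapes resolved above: a relation join, a nested
// relation join (columns aliased with the "__" convention, e.g. Manager__Company),
// and a raw SQL join that is left untouched.
func findWithJoins(db *gorm.DB) ([]User, error) {
	var users []User
	err := db.
		Joins("Manager").
		Joins("Manager.Company").
		Joins("LEFT JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").
		Find(&users).Error
	return users, err
}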
@ -276,32 +187,49 @@ func BuildQuerySQL(db *gorm.DB) {
func Preload(db *gorm.DB) { func Preload(db *gorm.DB) {
if db.Error == nil && len(db.Statement.Preloads) > 0 { if db.Error == nil && len(db.Statement.Preloads) > 0 {
if db.Statement.Schema == nil { preloadMap := map[string]map[string][]interface{}{}
db.AddError(fmt.Errorf("%w when using preload", gorm.ErrModelValueRequired)) for name := range db.Statement.Preloads {
return preloadFields := strings.Split(name, ".")
if preloadFields[0] == clause.Associations {
for _, rel := range db.Statement.Schema.Relationships.Relations {
if rel.Schema == db.Statement.Schema {
if _, ok := preloadMap[rel.Name]; !ok {
preloadMap[rel.Name] = map[string][]interface{}{}
}
if value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), "."); value != "" {
preloadMap[rel.Name][value] = db.Statement.Preloads[name]
}
}
}
} else {
if _, ok := preloadMap[preloadFields[0]]; !ok {
preloadMap[preloadFields[0]] = map[string][]interface{}{}
}
if value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), "."); value != "" {
preloadMap[preloadFields[0]][value] = db.Statement.Preloads[name]
}
}
} }
joins := make([]string, 0, len(db.Statement.Joins)) preloadNames := make([]string, 0, len(preloadMap))
for _, join := range db.Statement.Joins { for key := range preloadMap {
joins = append(joins, join.Name) preloadNames = append(preloadNames, key)
} }
sort.Strings(preloadNames)
tx := preloadDB(db, db.Statement.ReflectValue, db.Statement.Dest) for _, name := range preloadNames {
if tx.Error != nil { if rel := db.Statement.Schema.Relationships.Relations[name]; rel != nil {
return preload(db, rel, db.Statement.Preloads[name], preloadMap[name])
} else {
db.AddError(fmt.Errorf("%s: %w for schema %s", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name))
}
} }
db.AddError(preloadEntryPoint(tx, joins, &tx.Statement.Schema.Relationships, db.Statement.Preloads, db.Statement.Preloads[clause.Associations]))
} }
} }
func AfterQuery(db *gorm.DB) { func AfterQuery(db *gorm.DB) {
// clear the joins after query because preload need it
if v, ok := db.Statement.Clauses["FROM"].Expression.(clause.From); ok {
fromClause := db.Statement.Clauses["FROM"]
fromClause.Expression = clause.From{Tables: v.Tables, Joins: utils.RTrimSlice(v.Joins, len(db.Statement.Joins))} // keep the original From Joins
db.Statement.Clauses["FROM"] = fromClause
}
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.AfterFind && db.RowsAffected > 0 { if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.AfterFind && db.RowsAffected > 0 {
callMethod(db, func(value interface{}, tx *gorm.DB) bool { callMethod(db, func(value interface{}, tx *gorm.DB) bool {
if i, ok := value.(AfterFindInterface); ok { if i, ok := value.(AfterFindInterface); ok {

View File

@ -9,14 +9,8 @@ func RawExec(db *gorm.DB) {
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
if err != nil { if err != nil {
db.AddError(err) db.AddError(err)
return } else {
} db.RowsAffected, _ = result.RowsAffected()
db.RowsAffected, _ = result.RowsAffected()
if db.Statement.Result != nil {
db.Statement.Result.Result = result
db.Statement.Result.RowsAffected = db.RowsAffected
} }
} }
} }
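
RawExec above serves db.Exec; the added Statement.Result plumbing is how the raw driver result is surfaced when a caller asks for it. A small usage sketch of the Exec entry point (table and column names are illustrative):

package example

import "gorm.io/gorm"

// bumpAges goes through the RawExec callback; RowsAffected is populated from
// the driver result on success.
func bumpAges(db *gorm.DB) (int64, error) {
	res := db.Exec("UPDATE users SET age = age + 1 WHERE active = ?", true)
	return res.RowsAffected, res.Error
}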

View File

@ -7,17 +7,16 @@ import (
func RowQuery(db *gorm.DB) { func RowQuery(db *gorm.DB) {
if db.Error == nil { if db.Error == nil {
BuildQuerySQL(db) BuildQuerySQL(db)
if db.DryRun || db.Error != nil {
return
}
if isRows, ok := db.Get("rows"); ok && isRows.(bool) { if !db.DryRun {
db.Statement.Settings.Delete("rows") if isRows, ok := db.Get("rows"); ok && isRows.(bool) {
db.Statement.Dest, db.Error = db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) db.Statement.Settings.Delete("rows")
} else { db.Statement.Dest, db.Error = db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
db.Statement.Dest = db.Statement.ConnPool.QueryRowContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) } else {
} db.Statement.Dest = db.Statement.ConnPool.QueryRowContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
}
db.RowsAffected = -1 db.RowsAffected = -1
}
} }
} }
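
RowQuery above backs the Row and Rows helpers; the "rows" setting is what distinguishes the two. A small usage sketch, assuming an open db and a users table with an age column (assumptions for illustration):

package example

import "gorm.io/gorm"

// scanAges shows both entry points served by RowQuery: Row() for a single
// row and Rows() for manual iteration.
func scanAges(db *gorm.DB) (int, []int, error) {
	var maxAge int
	if err := db.Table("users").Select("MAX(age)").Row().Scan(&maxAge); err != nil {
		return 0, nil, err
	}

	rows, err := db.Table("users").Select("age").Rows()
	if err != nil {
		return 0, nil, err
	}
	defer rows.Close()

	var ages []int
	for rows.Next() {
		var age int
		if err := rows.Scan(&age); err != nil {
			return 0, nil, err
		}
		ages = append(ages, age)
	}
	return maxAge, ages, rows.Err()
}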

View File

@ -5,7 +5,7 @@ import (
) )
func BeginTransaction(db *gorm.DB) { func BeginTransaction(db *gorm.DB) {
if !db.Config.SkipDefaultTransaction && db.Error == nil { if !db.Config.SkipDefaultTransaction {
if tx := db.Begin(); tx.Error == nil { if tx := db.Begin(); tx.Error == nil {
db.Statement.ConnPool = tx.Statement.ConnPool db.Statement.ConnPool = tx.Statement.ConnPool
db.InstanceSet("gorm:started_transaction", true) db.InstanceSet("gorm:started_transaction", true)
@ -20,12 +20,11 @@ func BeginTransaction(db *gorm.DB) {
func CommitOrRollbackTransaction(db *gorm.DB) { func CommitOrRollbackTransaction(db *gorm.DB) {
if !db.Config.SkipDefaultTransaction { if !db.Config.SkipDefaultTransaction {
if _, ok := db.InstanceGet("gorm:started_transaction"); ok { if _, ok := db.InstanceGet("gorm:started_transaction"); ok {
if db.Error != nil { if db.Error == nil {
db.Rollback()
} else {
db.Commit() db.Commit()
} else {
db.Rollback()
} }
db.Statement.ConnPool = db.ConnPool db.Statement.ConnPool = db.ConnPool
} }
} }
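
The callbacks above wrap each write in an implicit transaction unless SkipDefaultTransaction is set. A hedged configuration sketch; the dialector is left to the caller and the option trade-off is write throughput versus per-statement atomicity:

package example

import "gorm.io/gorm"

// openWithoutDefaultTx shows the config knob honoured by BeginTransaction and
// CommitOrRollbackTransaction above: with SkipDefaultTransaction set, single
// writes skip the implicit BEGIN/COMMIT, and explicit transactions remain
// available via db.Transaction.
func openWithoutDefaultTx(dialector gorm.Dialector) (*gorm.DB, error) {
	return gorm.Open(dialector, &gorm.Config{
		SkipDefaultTransaction: true,
	})
}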

View File

@ -7,7 +7,6 @@ import (
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
"gorm.io/gorm/schema" "gorm.io/gorm/schema"
"gorm.io/gorm/utils"
) )
func SetupUpdateReflectValue(db *gorm.DB) { func SetupUpdateReflectValue(db *gorm.DB) {
@ -21,7 +20,7 @@ func SetupUpdateReflectValue(db *gorm.DB) {
if dest, ok := db.Statement.Dest.(map[string]interface{}); ok { if dest, ok := db.Statement.Dest.(map[string]interface{}); ok {
for _, rel := range db.Statement.Schema.Relationships.BelongsTo { for _, rel := range db.Statement.Schema.Relationships.BelongsTo {
if _, ok := dest[rel.Name]; ok { if _, ok := dest[rel.Name]; ok {
db.AddError(rel.Field.Set(db.Statement.Context, db.Statement.ReflectValue, dest[rel.Name])) rel.Field.Set(db.Statement.ReflectValue, dest[rel.Name])
} }
} }
} }
@ -29,7 +28,6 @@ func SetupUpdateReflectValue(db *gorm.DB) {
} }
} }
// BeforeUpdate before update hooks
func BeforeUpdate(db *gorm.DB) { func BeforeUpdate(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) { if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) {
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
@ -52,78 +50,47 @@ func BeforeUpdate(db *gorm.DB) {
} }
} }
// Update update hook func Update(db *gorm.DB) {
func Update(config *Config) func(db *gorm.DB) { if db.Error != nil {
supportReturning := utils.Contains(config.UpdateClauses, "RETURNING") return
}
return func(db *gorm.DB) { if db.Statement.Schema != nil && !db.Statement.Unscoped {
if db.Error != nil { for _, c := range db.Statement.Schema.UpdateClauses {
db.Statement.AddClause(c)
}
}
if db.Statement.SQL.String() == "" {
db.Statement.SQL.Grow(180)
db.Statement.AddClauseIfNotExists(clause.Update{})
if set := ConvertToAssignments(db.Statement); len(set) != 0 {
db.Statement.AddClause(set)
} else {
return return
} }
db.Statement.Build(db.Statement.BuildClauses...)
}
if db.Statement.Schema != nil { if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok {
for _, c := range db.Statement.Schema.UpdateClauses { db.AddError(gorm.ErrMissingWhereClause)
db.Statement.AddClause(c) return
} }
}
if db.Statement.SQL.Len() == 0 { if !db.DryRun && db.Error == nil {
db.Statement.SQL.Grow(180) result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
db.Statement.AddClauseIfNotExists(clause.Update{})
if _, ok := db.Statement.Clauses["SET"]; !ok {
if set := ConvertToAssignments(db.Statement); len(set) != 0 {
defer delete(db.Statement.Clauses, "SET")
db.Statement.AddClause(set)
} else {
return
}
}
db.Statement.Build(db.Statement.BuildClauses...) if err == nil {
} db.RowsAffected, _ = result.RowsAffected()
} else {
checkMissingWhereConditions(db) db.AddError(err)
if !db.DryRun && db.Error == nil {
if ok, mode := hasReturning(db, supportReturning); ok {
if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil {
dest := db.Statement.Dest
db.Statement.Dest = db.Statement.ReflectValue.Addr().Interface()
gorm.Scan(rows, db, mode)
db.Statement.Dest = dest
db.AddError(rows.Close())
if db.Statement.Result != nil {
db.Statement.Result.RowsAffected = db.RowsAffected
}
}
} else {
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
if db.AddError(err) == nil {
db.RowsAffected, _ = result.RowsAffected()
}
if db.Statement.Result != nil {
db.Statement.Result.Result = result
db.Statement.Result.RowsAffected = db.RowsAffected
}
}
} }
} }
} }
// AfterUpdate after update hooks
func AfterUpdate(db *gorm.DB) { func AfterUpdate(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) { if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) {
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
if db.Statement.Schema.AfterUpdate {
if i, ok := value.(AfterUpdateInterface); ok {
called = true
db.AddError(i.AfterUpdate(tx))
}
}
if db.Statement.Schema.AfterSave { if db.Statement.Schema.AfterSave {
if i, ok := value.(AfterSaveInterface); ok { if i, ok := value.(AfterSaveInterface); ok {
called = true called = true
@ -131,6 +98,12 @@ func AfterUpdate(db *gorm.DB) {
} }
} }
if db.Statement.Schema.AfterUpdate {
if i, ok := value.(AfterUpdateInterface); ok {
called = true
db.AddError(i.AfterUpdate(tx))
}
}
return called return called
}) })
} }
@ -147,15 +120,13 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
assignValue = func(field *schema.Field, value interface{}) { assignValue = func(field *schema.Field, value interface{}) {
for i := 0; i < stmt.ReflectValue.Len(); i++ { for i := 0; i < stmt.ReflectValue.Len(); i++ {
if stmt.ReflectValue.CanAddr() { field.Set(stmt.ReflectValue.Index(i), value)
field.Set(stmt.Context, stmt.ReflectValue.Index(i), value)
}
} }
} }
case reflect.Struct: case reflect.Struct:
assignValue = func(field *schema.Field, value interface{}) { assignValue = func(field *schema.Field, value interface{}) {
if stmt.ReflectValue.CanAddr() { if stmt.ReflectValue.CanAddr() {
field.Set(stmt.Context, stmt.ReflectValue, value) field.Set(stmt.ReflectValue, value)
} }
} }
default: default:
@ -171,26 +142,23 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
if !updatingValue.CanAddr() || stmt.Dest != stmt.Model { if !updatingValue.CanAddr() || stmt.Dest != stmt.Model {
switch stmt.ReflectValue.Kind() { switch stmt.ReflectValue.Kind() {
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
if size := stmt.ReflectValue.Len(); size > 0 { var primaryKeyExprs []clause.Expression
var isZero bool for i := 0; i < stmt.ReflectValue.Len(); i++ {
for i := 0; i < size; i++ { var exprs = make([]clause.Expression, len(stmt.Schema.PrimaryFields))
for _, field := range stmt.Schema.PrimaryFields { var notZero bool
_, isZero = field.ValueOf(stmt.Context, stmt.ReflectValue.Index(i)) for idx, field := range stmt.Schema.PrimaryFields {
if !isZero { value, isZero := field.ValueOf(stmt.ReflectValue.Index(i))
break exprs[idx] = clause.Eq{Column: field.DBName, Value: value}
} notZero = notZero || !isZero
}
} }
if notZero {
if !isZero { primaryKeyExprs = append(primaryKeyExprs, clause.And(exprs...))
_, primaryValues := schema.GetIdentityFieldValuesMap(stmt.Context, stmt.ReflectValue, stmt.Schema.PrimaryFields)
column, values := schema.ToQueryValues("", stmt.Schema.PrimaryFieldDBNames, primaryValues)
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
} }
} }
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(primaryKeyExprs...)}})
case reflect.Struct: case reflect.Struct:
for _, field := range stmt.Schema.PrimaryFields { for _, field := range stmt.Schema.PrimaryFields {
if value, isZero := field.ValueOf(stmt.Context, stmt.ReflectValue); !isZero { if value, isZero := field.ValueOf(stmt.ReflectValue); !isZero {
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}}) stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}})
} }
} }
@ -243,25 +211,23 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
if field.AutoUpdateTime == schema.UnixNanosecond { if field.AutoUpdateTime == schema.UnixNanosecond {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano()}) set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano()})
} else if field.AutoUpdateTime == schema.UnixMillisecond { } else if field.AutoUpdateTime == schema.UnixMillisecond {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixMilli()}) set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano() / 1e6})
} else if field.AutoUpdateTime == schema.UnixSecond { } else if field.GORMDataType == schema.Time {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.Unix()})
} else {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now}) set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now})
} else {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.Unix()})
} }
} }
} }
} }
} }
default: default:
updatingSchema := stmt.Schema var updatingSchema = stmt.Schema
var isDiffSchema bool
if !updatingValue.CanAddr() || stmt.Dest != stmt.Model { if !updatingValue.CanAddr() || stmt.Dest != stmt.Model {
// different schema // different schema
updatingStmt := &gorm.Statement{DB: stmt.DB} updatingStmt := &gorm.Statement{DB: stmt.DB}
if err := updatingStmt.Parse(stmt.Dest); err == nil { if err := updatingStmt.Parse(stmt.Dest); err == nil {
updatingSchema = updatingStmt.Schema updatingSchema = updatingStmt.Schema
isDiffSchema = true
} }
} }
@ -272,33 +238,27 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
if field := updatingSchema.LookUpField(dbName); field != nil { if field := updatingSchema.LookUpField(dbName); field != nil {
if !field.PrimaryKey || !updatingValue.CanAddr() || stmt.Dest != stmt.Model { if !field.PrimaryKey || !updatingValue.CanAddr() || stmt.Dest != stmt.Model {
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && (!restricted || (!stmt.SkipHooks && field.AutoUpdateTime > 0))) { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && (!restricted || (!stmt.SkipHooks && field.AutoUpdateTime > 0))) {
value, isZero := field.ValueOf(stmt.Context, updatingValue) value, isZero := field.ValueOf(updatingValue)
if !stmt.SkipHooks && field.AutoUpdateTime > 0 { if !stmt.SkipHooks && field.AutoUpdateTime > 0 {
if field.AutoUpdateTime == schema.UnixNanosecond { if field.AutoUpdateTime == schema.UnixNanosecond {
value = stmt.DB.NowFunc().UnixNano() value = stmt.DB.NowFunc().UnixNano()
} else if field.AutoUpdateTime == schema.UnixMillisecond { } else if field.AutoUpdateTime == schema.UnixMillisecond {
value = stmt.DB.NowFunc().UnixMilli() value = stmt.DB.NowFunc().UnixNano() / 1e6
} else if field.AutoUpdateTime == schema.UnixSecond { } else if field.GORMDataType == schema.Time {
value = stmt.DB.NowFunc().Unix()
} else {
value = stmt.DB.NowFunc() value = stmt.DB.NowFunc()
} else {
value = stmt.DB.NowFunc().Unix()
} }
isZero = false isZero = false
} }
if (ok || !isZero) && field.Updatable { if (ok || !isZero) && field.Updatable {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value}) set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value})
assignField := field assignValue(field, value)
if isDiffSchema {
if originField := stmt.Schema.LookUpField(dbName); originField != nil {
assignField = originField
}
}
assignValue(assignField, value)
} }
} }
} else { } else {
if value, isZero := field.ValueOf(stmt.Context, updatingValue); !isZero { if value, isZero := field.ValueOf(updatingValue); !isZero {
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}}) stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}})
} }
} }
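
The AutoUpdateTime branches above pick UnixNano, UnixMilli, Unix, or a plain time.Time depending on the tag's unit. A model sketch of the tag forms involved (struct and field names are illustrative assumptions):

package example

import "time"

// Post illustrates the autoUpdateTime variants handled above; the unit suffix
// selects nanoseconds or milliseconds, a bare tag on an integer field stores
// Unix seconds, and a time.Time field stores the timestamp itself.
type Post struct {
	ID        uint
	Title     string
	UpdatedAt time.Time // conventional field, stored as a timestamp
	EditedAtN int64     `gorm:"autoUpdateTime:nano"`  // Unix nanoseconds
	EditedAtM int64     `gorm:"autoUpdateTime:milli"` // Unix milliseconds
	EditedAtS int64     `gorm:"autoUpdateTime"`       // Unix seconds
}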

View File

@ -10,11 +10,10 @@ import (
) )
// Model specify the model you would like to run db operations // Model specify the model you would like to run db operations
// // // update all users's name to `hello`
// // update all users's name to `hello` // db.Model(&User{}).Update("name", "hello")
// db.Model(&User{}).Update("name", "hello") // // if user's primary key is non-blank, will use it as condition, then will only update the user's name to `hello`
// // if user's primary key is non-blank, will use it as condition, then will only update that user's name to `hello` // db.Model(&user).Update("name", "hello")
// db.Model(&user).Update("name", "hello")
func (db *DB) Model(value interface{}) (tx *DB) { func (db *DB) Model(value interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.Model = value tx.Statement.Model = value
@ -22,19 +21,6 @@ func (db *DB) Model(value interface{}) (tx *DB) {
} }
// Clauses Add clauses // Clauses Add clauses
//
// This supports both standard clauses (clause.OrderBy, clause.Limit, clause.Where) and more
// advanced techniques like specifying lock strength and optimizer hints. See the
// [docs] for more depth.
//
// // add a simple limit clause
// db.Clauses(clause.Limit{Limit: 1}).Find(&User{})
// // tell the optimizer to use the `idx_user_name` index
// db.Clauses(hints.UseIndex("idx_user_name")).Find(&User{})
// // specify the lock strength to UPDATE
// db.Clauses(clause.Locking{Strength: "UPDATE"}).Find(&users)
//
// [docs]: https://gorm.io/docs/sql_builder.html#Clauses
func (db *DB) Clauses(conds ...clause.Expression) (tx *DB) { func (db *DB) Clauses(conds ...clause.Expression) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
var whereConds []interface{} var whereConds []interface{}
@ -55,42 +41,27 @@ func (db *DB) Clauses(conds ...clause.Expression) (tx *DB) {
return return
} }
var tableRegexp = regexp.MustCompile(`(?i)(?:.+? AS (\w+)\s*(?:$|,)|^\w+\s+(\w+)$)`) var tableRegexp = regexp.MustCompile(`(?i).+? AS (\w+)\s*(?:$|,)`)
// Table specify the table you would like to run db operations // Table specify the table you would like to run db operations
//
// // Get a user
// db.Table("users").Take(&result)
func (db *DB) Table(name string, args ...interface{}) (tx *DB) { func (db *DB) Table(name string, args ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
if strings.Contains(name, " ") || strings.Contains(name, "`") || len(args) > 0 { if strings.Contains(name, " ") || strings.Contains(name, "`") || len(args) > 0 {
tx.Statement.TableExpr = &clause.Expr{SQL: name, Vars: args} tx.Statement.TableExpr = &clause.Expr{SQL: name, Vars: args}
if results := tableRegexp.FindStringSubmatch(name); len(results) == 3 { if results := tableRegexp.FindStringSubmatch(name); len(results) == 2 {
if results[1] != "" { tx.Statement.Table = results[1]
tx.Statement.Table = results[1]
} else {
tx.Statement.Table = results[2]
}
} }
} else if tables := strings.Split(name, "."); len(tables) == 2 { } else if tables := strings.Split(name, "."); len(tables) == 2 {
tx.Statement.TableExpr = &clause.Expr{SQL: tx.Statement.Quote(name)} tx.Statement.TableExpr = &clause.Expr{SQL: tx.Statement.Quote(name)}
tx.Statement.Table = tables[1] tx.Statement.Table = tables[1]
} else if name != "" { } else {
tx.Statement.TableExpr = &clause.Expr{SQL: tx.Statement.Quote(name)} tx.Statement.TableExpr = &clause.Expr{SQL: tx.Statement.Quote(name)}
tx.Statement.Table = name tx.Statement.Table = name
} else {
tx.Statement.TableExpr = nil
tx.Statement.Table = ""
} }
return return
} }
// Distinct specify distinct fields that you want querying // Distinct specify distinct fields that you want querying
//
// // Select distinct names of users
// db.Distinct("name").Find(&results)
// // Select distinct name/age pairs from users
// db.Distinct("name", "age").Find(&results)
func (db *DB) Distinct(args ...interface{}) (tx *DB) { func (db *DB) Distinct(args ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.Distinct = true tx.Statement.Distinct = true
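
The widened tableRegexp above lets Table pick up an alias both from the "x AS alias" form and from the bare "table alias" form. A hedged sketch of the name shapes it is meant to recognise (table names and the subquery are illustrative):

package example

import "gorm.io/gorm"

// tableForms sketches the inputs the updated tableRegexp handles; the comment
// on each line notes the Statement.Table it should resolve to.
func tableForms(db *gorm.DB, subQuery *gorm.DB) {
	var rows []map[string]interface{}
	db.Table("users").Find(&rows)              // Table = "users"
	db.Table("users AS u").Find(&rows)         // alias via AS, Table = "u"
	db.Table("users u").Find(&rows)            // bare alias form, Table = "u"
	db.Table("public.users").Find(&rows)       // schema-qualified, Table = "users"
	db.Table("(?) as t", subQuery).Find(&rows) // expression with args, Table = "t"
}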
@ -101,14 +72,6 @@ func (db *DB) Distinct(args ...interface{}) (tx *DB) {
} }
// Select specify fields that you want when querying, creating, updating // Select specify fields that you want when querying, creating, updating
//
// Use Select when you only want a subset of the fields. By default, GORM will select all fields.
// Select accepts both string arguments and arrays.
//
// // Select name and age of user using multiple arguments
// db.Select("name", "age").Find(&users)
// // Select name and age of user using an array
// db.Select([]string{"name", "age"}).Find(&users)
func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) { func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
@ -127,11 +90,7 @@ func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) {
return return
} }
} }
delete(tx.Statement.Clauses, "SELECT")
if clause, ok := tx.Statement.Clauses["SELECT"]; ok {
clause.Expression = nil
tx.Statement.Clauses["SELECT"] = clause
}
case string: case string:
if strings.Count(v, "?") >= len(args) && len(args) > 0 { if strings.Count(v, "?") >= len(args) && len(args) > 0 {
tx.Statement.AddClause(clause.Select{ tx.Statement.AddClause(clause.Select{
@ -161,10 +120,7 @@ func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) {
} }
} }
if clause, ok := tx.Statement.Clauses["SELECT"]; ok { delete(tx.Statement.Clauses, "SELECT")
clause.Expression = nil
tx.Statement.Clauses["SELECT"] = clause
}
} }
default: default:
tx.AddError(fmt.Errorf("unsupported select args %v %v", query, args)) tx.AddError(fmt.Errorf("unsupported select args %v %v", query, args))
@ -185,25 +141,7 @@ func (db *DB) Omit(columns ...string) (tx *DB) {
return return
} }
// MapColumns modify the column names in the query results to facilitate align to the corresponding structural fields
func (db *DB) MapColumns(m map[string]string) (tx *DB) {
tx = db.getInstance()
tx.Statement.ColumnMapping = m
return
}
// Where add conditions // Where add conditions
//
// See the [docs] for details on the various formats that where clauses can take. By default, where clauses chain with AND.
//
// // Find the first user with name jinzhu
// db.Where("name = ?", "jinzhu").First(&user)
// // Find the first user with name jinzhu and age 20
// db.Where(&User{Name: "jinzhu", Age: 20}).First(&user)
// // Find the first user with name jinzhu and age not equal to 20
// db.Where("name = ?", "jinzhu").Where("age <> ?", "20").First(&user)
//
// [docs]: https://gorm.io/docs/query.html#Conditions
func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) { func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 { if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 {
@ -213,11 +151,6 @@ func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) {
} }
// Not add NOT conditions // Not add NOT conditions
//
// Not works similarly to where, and has the same syntax.
//
// // Find the first user with name not equal to jinzhu
// db.Not("name = ?", "jinzhu").First(&user)
func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) { func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 { if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 {
@ -227,11 +160,6 @@ func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) {
} }
// Or add OR conditions // Or add OR conditions
//
// Or is used to chain together queries with an OR.
//
// // Find the first user with name equal to jinzhu or john
// db.Where("name = ?", "jinzhu").Or("name = ?", "john").First(&user)
func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) { func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 { if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 {
@ -241,45 +169,26 @@ func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) {
} }
// Joins specify Joins conditions // Joins specify Joins conditions
// // db.Joins("Account").Find(&user)
// db.Joins("Account").Find(&user) // db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user)
// db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user) // db.Joins("Account", DB.Select("id").Where("user_id = users.id AND name = ?", "someName").Model(&Account{}))
// db.Joins("Account", DB.Select("id").Where("user_id = users.id AND name = ?", "someName").Model(&Account{}))
func (db *DB) Joins(query string, args ...interface{}) (tx *DB) { func (db *DB) Joins(query string, args ...interface{}) (tx *DB) {
return joins(db, clause.LeftJoin, query, args...)
}
// InnerJoins specify inner joins conditions
// db.InnerJoins("Account").Find(&user)
func (db *DB) InnerJoins(query string, args ...interface{}) (tx *DB) {
return joins(db, clause.InnerJoin, query, args...)
}
func joins(db *DB, joinType clause.JoinType, query string, args ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
if len(args) == 1 { if len(args) == 1 {
if db, ok := args[0].(*DB); ok { if db, ok := args[0].(*DB); ok {
j := join{
Name: query, Conds: args, Selects: db.Statement.Selects,
Omits: db.Statement.Omits, JoinType: joinType,
}
if where, ok := db.Statement.Clauses["WHERE"].Expression.(clause.Where); ok { if where, ok := db.Statement.Clauses["WHERE"].Expression.(clause.Where); ok {
j.On = &where tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args, On: &where})
return
} }
tx.Statement.Joins = append(tx.Statement.Joins, j)
return
} }
} }
tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args, JoinType: joinType}) tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args})
return return
} }
// Group specify the group method on the find // Group specify the group method on the find
//
// // Select the sum age of users with given names
// db.Model(&User{}).Select("name, sum(age) as total").Group("name").Find(&results)
func (db *DB) Group(name string) (tx *DB) { func (db *DB) Group(name string) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
@ -291,9 +200,6 @@ func (db *DB) Group(name string) (tx *DB) {
} }
// Having specify HAVING conditions for GROUP BY // Having specify HAVING conditions for GROUP BY
//
// // Select the sum age of users with name jinzhu
// db.Model(&User{}).Select("name, sum(age) as total").Group("name").Having("name = ?", "jinzhu").Find(&result)
func (db *DB) Having(query interface{}, args ...interface{}) (tx *DB) { func (db *DB) Having(query interface{}, args ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.AddClause(clause.GroupBy{ tx.Statement.AddClause(clause.GroupBy{
@ -302,20 +208,13 @@ func (db *DB) Having(query interface{}, args ...interface{}) (tx *DB) {
return return
} }
// Order specify order when retrieving records from database // Order specify order when retrieve records from database
// // db.Order("name DESC")
// db.Order("name DESC") // db.Order(clause.OrderByColumn{Column: clause.Column{Name: "name"}, Desc: true})
// db.Order(clause.OrderByColumn{Column: clause.Column{Name: "name"}, Desc: true})
// db.Order(clause.OrderBy{Columns: []clause.OrderByColumn{
// {Column: clause.Column{Name: "name"}, Desc: true},
// {Column: clause.Column{Name: "age"}, Desc: true},
// }})
func (db *DB) Order(value interface{}) (tx *DB) { func (db *DB) Order(value interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
switch v := value.(type) { switch v := value.(type) {
case clause.OrderBy:
tx.Statement.AddClause(v)
case clause.OrderByColumn: case clause.OrderByColumn:
tx.Statement.AddClause(clause.OrderBy{ tx.Statement.AddClause(clause.OrderBy{
Columns: []clause.OrderByColumn{v}, Columns: []clause.OrderByColumn{v},
@ -333,27 +232,13 @@ func (db *DB) Order(value interface{}) (tx *DB) {
} }
// Limit specify the number of records to be retrieved // Limit specify the number of records to be retrieved
//
// Limit conditions can be cancelled by using `Limit(-1)`.
//
// // retrieve 3 users
// db.Limit(3).Find(&users)
// // retrieve 3 users into users1, and all users into users2
// db.Limit(3).Find(&users1).Limit(-1).Find(&users2)
func (db *DB) Limit(limit int) (tx *DB) { func (db *DB) Limit(limit int) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.AddClause(clause.Limit{Limit: &limit}) tx.Statement.AddClause(clause.Limit{Limit: limit})
return return
} }
// Offset specify the number of records to skip before starting to return the records // Offset specify the number of records to skip before starting to return the records
//
// Offset conditions can be cancelled by using `Offset(-1)`.
//
// // select the third user
// db.Offset(2).First(&user)
// // select the first user by cancelling an earlier chained offset
// db.Offset(5).Offset(-1).First(&user)
func (db *DB) Offset(offset int) (tx *DB) { func (db *DB) Offset(offset int) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.AddClause(clause.Limit{Offset: offset}) tx.Statement.AddClause(clause.Limit{Offset: offset})
@ -361,37 +246,25 @@ func (db *DB) Offset(offset int) (tx *DB) {
} }
// Scopes pass current database connection to arguments `func(DB) DB`, which could be used to add conditions dynamically // Scopes pass current database connection to arguments `func(DB) DB`, which could be used to add conditions dynamically
// func AmountGreaterThan1000(db *gorm.DB) *gorm.DB {
// return db.Where("amount > ?", 1000)
// }
// //
// func AmountGreaterThan1000(db *gorm.DB) *gorm.DB { // func OrderStatus(status []string) func (db *gorm.DB) *gorm.DB {
// return db.Where("amount > ?", 1000) // return func (db *gorm.DB) *gorm.DB {
// } // return db.Scopes(AmountGreaterThan1000).Where("status in (?)", status)
// }
// }
// //
// func OrderStatus(status []string) func (db *gorm.DB) *gorm.DB { // db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders)
// return func (db *gorm.DB) *gorm.DB {
// return db.Scopes(AmountGreaterThan1000).Where("status in (?)", status)
// }
// }
//
// db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders)
func (db *DB) Scopes(funcs ...func(*DB) *DB) (tx *DB) { func (db *DB) Scopes(funcs ...func(*DB) *DB) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.scopes = append(tx.Statement.scopes, funcs...) tx.Statement.scopes = append(tx.Statement.scopes, funcs...)
return tx return tx
} }
func (db *DB) executeScopes() (tx *DB) {
scopes := db.Statement.scopes
db.Statement.scopes = nil
for _, scope := range scopes {
db = scope(db)
}
return db
}
// Preload preload associations with given conditions // Preload preload associations with given conditions
// // db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users)
// // get all users, and preload all non-cancelled orders
// db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users)
func (db *DB) Preload(query string, args ...interface{}) (tx *DB) { func (db *DB) Preload(query string, args ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
if tx.Statement.Preloads == nil { if tx.Statement.Preloads == nil {
@ -401,57 +274,18 @@ func (db *DB) Preload(query string, args ...interface{}) (tx *DB) {
return return
} }
// Attrs provide attributes used in [FirstOrCreate] or [FirstOrInit]
//
// Attrs only adds attributes if the record is not found.
//
// // assign an email if the record is not found
// db.Where(User{Name: "non_existing"}).Attrs(User{Email: "fake@fake.org"}).FirstOrInit(&user)
// // user -> User{Name: "non_existing", Email: "fake@fake.org"}
//
// // assign an email if the record is not found, otherwise ignore provided email
// db.Where(User{Name: "jinzhu"}).Attrs(User{Email: "fake@fake.org"}).FirstOrInit(&user)
// // user -> User{Name: "jinzhu", Age: 20}
//
// [FirstOrCreate]: https://gorm.io/docs/advanced_query.html#FirstOrCreate
// [FirstOrInit]: https://gorm.io/docs/advanced_query.html#FirstOrInit
func (db *DB) Attrs(attrs ...interface{}) (tx *DB) { func (db *DB) Attrs(attrs ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.attrs = attrs tx.Statement.attrs = attrs
return return
} }
// Assign provide attributes used in [FirstOrCreate] or [FirstOrInit]
//
// Assign adds attributes even if the record is found. If using FirstOrCreate, this means that
// records will be updated even if they are found.
//
// // assign an email regardless of if the record is not found
// db.Where(User{Name: "non_existing"}).Assign(User{Email: "fake@fake.org"}).FirstOrInit(&user)
// // user -> User{Name: "non_existing", Email: "fake@fake.org"}
//
// // assign email regardless of if record is found
// db.Where(User{Name: "jinzhu"}).Assign(User{Email: "fake@fake.org"}).FirstOrInit(&user)
// // user -> User{Name: "jinzhu", Age: 20, Email: "fake@fake.org"}
//
// [FirstOrCreate]: https://gorm.io/docs/advanced_query.html#FirstOrCreate
// [FirstOrInit]: https://gorm.io/docs/advanced_query.html#FirstOrInit
func (db *DB) Assign(attrs ...interface{}) (tx *DB) { func (db *DB) Assign(attrs ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.assigns = attrs tx.Statement.assigns = attrs
return return
} }
// Unscoped disables the global scope of soft deletion in a query.
// By default, GORM uses soft deletion, marking records as "deleted"
// by setting a timestamp on a specific field (e.g., `deleted_at`).
// Unscoped allows queries to include records marked as deleted,
// overriding the soft deletion behavior.
// Example:
//
// var users []User
// db.Unscoped().Find(&users)
// // Retrieves all users, including deleted ones.
func (db *DB) Unscoped() (tx *DB) { func (db *DB) Unscoped() (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.Unscoped = true tx.Statement.Unscoped = true
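
Unscoped also changes what Delete does: with a soft-delete field a plain Delete only stamps deleted_at, while Unscoped issues a real DELETE. A sketch, assuming User has a gorm.DeletedAt field:

db.Delete(&user)            // UPDATE users SET deleted_at=... WHERE id = ...
db.Unscoped().Delete(&user) // DELETE FROM users WHERE id = ...
db.Unscoped().Find(&users)  // includes soft-deleted rows
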

View File

@ -29,12 +29,10 @@ func BenchmarkSelect(b *testing.B) {
func BenchmarkComplexSelect(b *testing.B) { func BenchmarkComplexSelect(b *testing.B) {
user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
limit10 := 10
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
stmt := gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}} stmt := gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}}
clauses := []clause.Interface{ clauses := []clause.Interface{
clause.Select{}, clause.Select{}, clause.From{},
clause.From{},
clause.Where{Exprs: []clause.Expression{ clause.Where{Exprs: []clause.Expression{
clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Eq{Column: clause.PrimaryColumn, Value: "1"},
clause.Gt{Column: "age", Value: 18}, clause.Gt{Column: "age", Value: 18},
@ -44,7 +42,7 @@ func BenchmarkComplexSelect(b *testing.B) {
clause.Or(clause.Gt{Column: "score", Value: 100}, clause.Like{Column: "name", Value: "%linus%"}), clause.Or(clause.Gt{Column: "score", Value: 100}, clause.Like{Column: "name", Value: "%linus%"}),
}}, }},
clause.GroupBy{Columns: []clause.Column{{Name: "role"}}, Having: []clause.Expression{clause.Eq{"role", "admin"}}}, clause.GroupBy{Columns: []clause.Column{{Name: "role"}}, Having: []clause.Expression{clause.Eq{"role", "admin"}}},
clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Limit: 10, Offset: 20},
clause.OrderBy{Columns: []clause.OrderByColumn{{Column: clause.PrimaryColumn, Desc: true}}}, clause.OrderBy{Columns: []clause.OrderByColumn{{Column: clause.PrimaryColumn, Desc: true}}},
} }

View File

@ -20,7 +20,6 @@ type Builder interface {
Writer Writer
WriteQuoted(field interface{}) WriteQuoted(field interface{})
AddVar(Writer, ...interface{}) AddVar(Writer, ...interface{})
AddError(error) error
} }
// Clause // Clause

View File

@ -67,12 +67,6 @@ func (expr Expr) Build(builder Builder) {
builder.WriteByte(v) builder.WriteByte(v)
} }
} }
if idx < len(expr.Vars) {
for _, v := range expr.Vars[idx:] {
builder.AddVar(builder, sql.NamedArg{Value: v})
}
}
} }
// NamedExpr raw expression for named expr // NamedExpr raw expression for named expr
@ -126,8 +120,8 @@ func (expr NamedExpr) Build(builder Builder) {
for _, v := range []byte(expr.SQL) { for _, v := range []byte(expr.SQL) {
if v == '@' && !inName { if v == '@' && !inName {
inName = true inName = true
name = name[:0] name = []byte{}
} else if v == ' ' || v == ',' || v == ')' || v == '"' || v == '\'' || v == '`' || v == '\r' || v == '\n' || v == ';' { } else if v == ' ' || v == ',' || v == ')' || v == '"' || v == '\'' || v == '`' || v == '\n' || v == ';' {
if inName { if inName {
if nv, ok := namedMap[string(name)]; ok { if nv, ok := namedMap[string(name)]; ok {
builder.AddVar(builder, nv) builder.AddVar(builder, nv)
@ -246,19 +240,15 @@ func (eq Eq) Build(builder Builder) {
switch eq.Value.(type) { switch eq.Value.(type) {
case []string, []int, []int32, []int64, []uint, []uint32, []uint64, []interface{}: case []string, []int, []int32, []int64, []uint, []uint32, []uint64, []interface{}:
builder.WriteString(" IN (")
rv := reflect.ValueOf(eq.Value) rv := reflect.ValueOf(eq.Value)
if rv.Len() == 0 { for i := 0; i < rv.Len(); i++ {
builder.WriteString(" IN (NULL)") if i > 0 {
} else { builder.WriteByte(',')
builder.WriteString(" IN (")
for i := 0; i < rv.Len(); i++ {
if i > 0 {
builder.WriteByte(',')
}
builder.AddVar(builder, rv.Index(i).Interface())
} }
builder.WriteByte(')') builder.AddVar(builder, rv.Index(i).Interface())
} }
builder.WriteByte(')')
default: default:
if eqNil(eq.Value) { if eqNil(eq.Value) {
builder.WriteString(" IS NULL") builder.WriteString(" IS NULL")
@ -372,7 +362,7 @@ func (like Like) NegationBuild(builder Builder) {
} }
func eqNil(value interface{}) bool { func eqNil(value interface{}) bool {
if valuer, ok := value.(driver.Valuer); ok && !eqNilReflect(valuer) { if valuer, ok := value.(driver.Valuer); ok {
value, _ = valuer.Value() value, _ = valuer.Value()
} }
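
The master side special-cases an empty slice so Eq renders IN (NULL), which matches nothing, instead of the invalid IN () that the v1.21.16 code would emit. A sketch in the style of the clause tests, assuming a *gorm.DB opened with tests.DummyDialector as those tests do:

stmt := &gorm.Statement{DB: db, Clauses: map[string]clause.Clause{}}
clause.Eq{Column: clause.Column{Name: "id"}, Value: []int{}}.Build(stmt)
// stmt.SQL.String() -> "`id` IN (NULL)" on master, "`id` IN ()" before the guard
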

View File

@ -94,16 +94,6 @@ func TestNamedExpr(t *testing.T) {
Vars: []interface{}{sql.Named("name", "jinzhu")}, Vars: []interface{}{sql.Named("name", "jinzhu")},
Result: "name1 = ? AND name2 = ?;", Result: "name1 = ? AND name2 = ?;",
ExpectedVars: []interface{}{"jinzhu", "jinzhu"}, ExpectedVars: []interface{}{"jinzhu", "jinzhu"},
}, {
SQL: "name1 = @name1\r\n AND name2 = @name2",
Vars: []interface{}{map[string]interface{}{"name1": "jinzhu", "name2": "jinzhu"}},
Result: "name1 = ?\r\n AND name2 = ?",
ExpectedVars: []interface{}{"jinzhu", "jinzhu"},
}, {
SQL: "name1 = @name1\r AND name2 = @name2",
Vars: []interface{}{map[string]interface{}{"name1": "jinzhu", "name2": "jinzhu"}},
Result: "name1 = ?\r AND name2 = ?",
ExpectedVars: []interface{}{"jinzhu", "jinzhu"},
}, { }, {
SQL: "?", SQL: "?",
Vars: []interface{}{clause.Column{Table: "table", Name: "col"}}, Vars: []interface{}{clause.Column{Table: "table", Name: "col"}},
@ -199,11 +189,6 @@ func TestExpression(t *testing.T) {
}, },
ExpectedVars: []interface{}{"a", "b"}, ExpectedVars: []interface{}{"a", "b"},
Result: "`column-name` NOT IN (?,?)", Result: "`column-name` NOT IN (?,?)",
}, {
Expressions: []clause.Expression{
clause.Eq{Column: column, Value: []string{}},
},
Result: "`column-name` IN (NULL)",
}, { }, {
Expressions: []clause.Expression{ Expressions: []clause.Expression{
clause.Eq{Column: clause.Expr{SQL: "SUM(?)", Vars: []interface{}{clause.Column{Name: "id"}}}, Value: 100}, clause.Eq{Column: clause.Expr{SQL: "SUM(?)", Vars: []interface{}{clause.Column{Name: "id"}}}, Value: 100},

View File

@ -18,8 +18,7 @@ func TestGroupBy(t *testing.T) {
Columns: []clause.Column{{Name: "role"}}, Columns: []clause.Column{{Name: "role"}},
Having: []clause.Expression{clause.Eq{"role", "admin"}}, Having: []clause.Expression{clause.Eq{"role", "admin"}},
}}, }},
"SELECT * FROM `users` GROUP BY `role` HAVING `role` = ?", "SELECT * FROM `users` GROUP BY `role` HAVING `role` = ?", []interface{}{"admin"},
[]interface{}{"admin"},
}, },
{ {
[]clause.Interface{clause.Select{}, clause.From{}, clause.GroupBy{ []clause.Interface{clause.Select{}, clause.From{}, clause.GroupBy{
@ -29,8 +28,7 @@ func TestGroupBy(t *testing.T) {
Columns: []clause.Column{{Name: "gender"}}, Columns: []clause.Column{{Name: "gender"}},
Having: []clause.Expression{clause.Neq{"gender", "U"}}, Having: []clause.Expression{clause.Neq{"gender", "U"}},
}}, }},
"SELECT * FROM `users` GROUP BY `role`,`gender` HAVING `role` = ? AND `gender` <> ?", "SELECT * FROM `users` GROUP BY `role`,`gender` HAVING `role` = ? AND `gender` <> ?", []interface{}{"admin", "U"},
[]interface{}{"admin", "U"},
}, },
} }

View File

@ -1,7 +1,5 @@
package clause package clause
import "gorm.io/gorm/utils"
type JoinType string type JoinType string
const ( const (
@ -11,31 +9,7 @@ const (
RightJoin JoinType = "RIGHT" RightJoin JoinType = "RIGHT"
) )
type JoinTarget struct { // Join join clause for from
Type JoinType
Association string
Subquery Expression
Table string
}
func Has(name string) JoinTarget {
return JoinTarget{Type: InnerJoin, Association: name}
}
func (jt JoinType) Association(name string) JoinTarget {
return JoinTarget{Type: jt, Association: name}
}
func (jt JoinType) AssociationFrom(name string, subquery Expression) JoinTarget {
return JoinTarget{Type: jt, Association: name, Subquery: subquery}
}
func (jt JoinTarget) As(name string) JoinTarget {
jt.Table = name
return jt
}
// Join clause for from
type Join struct { type Join struct {
Type JoinType Type JoinType
Table Table Table Table
@ -44,12 +18,6 @@ type Join struct {
Expression Expression Expression Expression
} }
func JoinTable(names ...string) Table {
return Table{
Name: utils.JoinNestedRelationNames(names),
}
}
func (join Join) Build(builder Builder) { func (join Join) Build(builder Builder) {
if join.Expression != nil { if join.Expression != nil {
join.Expression.Build(builder) join.Expression.Build(builder)

View File

@ -1,101 +0,0 @@
package clause_test
import (
"sync"
"testing"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils/tests"
)
func TestJoin(t *testing.T) {
results := []struct {
name string
join clause.Join
sql string
}{
{
name: "LEFT JOIN",
join: clause.Join{
Type: clause.LeftJoin,
Table: clause.Table{Name: "user"},
ON: clause.Where{
Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}},
},
},
sql: "LEFT JOIN `user` ON `user_info`.`user_id` = `users`.`id`",
},
{
name: "RIGHT JOIN",
join: clause.Join{
Type: clause.RightJoin,
Table: clause.Table{Name: "user"},
ON: clause.Where{
Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}},
},
},
sql: "RIGHT JOIN `user` ON `user_info`.`user_id` = `users`.`id`",
},
{
name: "INNER JOIN",
join: clause.Join{
Type: clause.InnerJoin,
Table: clause.Table{Name: "user"},
ON: clause.Where{
Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}},
},
},
sql: "INNER JOIN `user` ON `user_info`.`user_id` = `users`.`id`",
},
{
name: "CROSS JOIN",
join: clause.Join{
Type: clause.CrossJoin,
Table: clause.Table{Name: "user"},
ON: clause.Where{
Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}},
},
},
sql: "CROSS JOIN `user` ON `user_info`.`user_id` = `users`.`id`",
},
{
name: "USING",
join: clause.Join{
Type: clause.InnerJoin,
Table: clause.Table{Name: "user"},
Using: []string{"id"},
},
sql: "INNER JOIN `user` USING (`id`)",
},
{
name: "Expression",
join: clause.Join{
// Invalid
Type: clause.LeftJoin,
Table: clause.Table{Name: "user"},
ON: clause.Where{
Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}},
},
// Valid
Expression: clause.Join{
Type: clause.InnerJoin,
Table: clause.Table{Name: "user"},
Using: []string{"id"},
},
},
sql: "INNER JOIN `user` USING (`id`)",
},
}
for _, result := range results {
t.Run(result.name, func(t *testing.T) {
user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
stmt := &gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}}
result.join.Build(stmt)
if result.sql != stmt.SQL.String() {
t.Errorf("want: %s, got: %s", result.sql, stmt.SQL.String())
}
})
}
}

View File

@ -1,8 +1,10 @@
package clause package clause
import "strconv"
// Limit limit clause // Limit limit clause
type Limit struct { type Limit struct {
Limit *int Limit int
Offset int Offset int
} }
@ -13,16 +15,16 @@ func (limit Limit) Name() string {
// Build build where clause // Build build where clause
func (limit Limit) Build(builder Builder) { func (limit Limit) Build(builder Builder) {
if limit.Limit != nil && *limit.Limit >= 0 { if limit.Limit > 0 {
builder.WriteString("LIMIT ") builder.WriteString("LIMIT ")
builder.AddVar(builder, *limit.Limit) builder.WriteString(strconv.Itoa(limit.Limit))
} }
if limit.Offset > 0 { if limit.Offset > 0 {
if limit.Limit != nil && *limit.Limit >= 0 { if limit.Limit > 0 {
builder.WriteByte(' ') builder.WriteString(" ")
} }
builder.WriteString("OFFSET ") builder.WriteString("OFFSET ")
builder.AddVar(builder, limit.Offset) builder.WriteString(strconv.Itoa(limit.Offset))
} }
} }
@ -31,7 +33,7 @@ func (limit Limit) MergeClause(clause *Clause) {
clause.Name = "" clause.Name = ""
if v, ok := clause.Expression.(Limit); ok { if v, ok := clause.Expression.(Limit); ok {
if (limit.Limit == nil || *limit.Limit == 0) && v.Limit != nil { if limit.Limit == 0 && v.Limit != 0 {
limit.Limit = v.Limit limit.Limit = v.Limit
} }
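
With Limit now a *int on the master side, a zero limit is expressible and both values are bound as placeholders instead of being written into the SQL via strconv. Constructing the clause by hand changes accordingly; db.Limit(10).Offset(20) remains the usual route:

limit := 10
db.Clauses(clause.Limit{Limit: &limit, Offset: 20}).Find(&users)
// master:   SELECT * FROM `users` LIMIT ? OFFSET ?  with vars [10 20]
// v1.21.16: clause.Limit{Limit: 10, Offset: 20} renders LIMIT 10 OFFSET 20
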

View File

@ -8,10 +8,6 @@ import (
) )
func TestLimit(t *testing.T) { func TestLimit(t *testing.T) {
limit0 := 0
limit10 := 10
limit50 := 50
limitNeg10 := -10
results := []struct { results := []struct {
Clauses []clause.Interface Clauses []clause.Interface
Result string Result string
@ -19,56 +15,38 @@ func TestLimit(t *testing.T) {
}{ }{
{ {
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{ []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{
Limit: &limit10, Limit: 10,
Offset: 20, Offset: 20,
}}, }},
"SELECT * FROM `users` LIMIT ? OFFSET ?", "SELECT * FROM `users` LIMIT 10 OFFSET 20", nil,
[]interface{}{limit10, 20},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit0}},
"SELECT * FROM `users` LIMIT ?",
[]interface{}{limit0},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit0}, clause.Limit{Offset: 0}},
"SELECT * FROM `users` LIMIT ?",
[]interface{}{limit0},
}, },
{ {
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}}, []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}},
"SELECT * FROM `users` OFFSET ?", "SELECT * FROM `users` OFFSET 20", nil,
[]interface{}{20},
}, },
{ {
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}, clause.Limit{Offset: 30}}, []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}, clause.Limit{Offset: 30}},
"SELECT * FROM `users` OFFSET ?", "SELECT * FROM `users` OFFSET 30", nil,
[]interface{}{30},
}, },
{ {
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}, clause.Limit{Limit: &limit10}}, []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}, clause.Limit{Limit: 10}},
"SELECT * FROM `users` LIMIT ? OFFSET ?", "SELECT * FROM `users` LIMIT 10 OFFSET 20", nil,
[]interface{}{limit10, 20},
}, },
{ {
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}}, []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}},
"SELECT * FROM `users` LIMIT ? OFFSET ?", "SELECT * FROM `users` LIMIT 10 OFFSET 30", nil,
[]interface{}{limit10, 30},
}, },
{ {
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Offset: -10}}, []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Offset: -10}},
"SELECT * FROM `users` LIMIT ?", "SELECT * FROM `users` LIMIT 10", nil,
[]interface{}{limit10},
}, },
{ {
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: &limitNeg10}}, []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: -10}},
"SELECT * FROM `users` OFFSET ?", "SELECT * FROM `users` OFFSET 30", nil,
[]interface{}{30},
}, },
{ {
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: &limit50}}, []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: 50}},
"SELECT * FROM `users` LIMIT ? OFFSET ?", "SELECT * FROM `users` LIMIT 50 OFFSET 30", nil,
[]interface{}{limit50, 30},
}, },
} }

View File

@ -1,12 +1,5 @@
package clause package clause
const (
LockingStrengthUpdate = "UPDATE"
LockingStrengthShare = "SHARE"
LockingOptionsSkipLocked = "SKIP LOCKED"
LockingOptionsNoWait = "NOWAIT"
)
type Locking struct { type Locking struct {
Strength string Strength string
Table Table Table Table
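
The locking-strength constants appear only on the master side and simply name the strings the tests below spell out inline (the plain string forms "UPDATE", "SHARE", "NOWAIT", "SKIP LOCKED" work on both versions). Typical row-locking usage:

// SELECT ... FOR UPDATE
db.Clauses(clause.Locking{Strength: clause.LockingStrengthUpdate}).Find(&users)

// SELECT ... FOR UPDATE SKIP LOCKED
db.Clauses(clause.Locking{
	Strength: clause.LockingStrengthUpdate,
	Options:  clause.LockingOptionsSkipLocked,
}).Find(&users)
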

View File

@ -14,21 +14,17 @@ func TestLocking(t *testing.T) {
Vars []interface{} Vars []interface{}
}{ }{
{ {
[]clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: clause.LockingStrengthUpdate}}, []clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: "UPDATE"}},
"SELECT * FROM `users` FOR UPDATE", nil, "SELECT * FROM `users` FOR UPDATE", nil,
}, },
{ {
[]clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: clause.LockingStrengthShare, Table: clause.Table{Name: clause.CurrentTable}}}, []clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: "SHARE", Table: clause.Table{Name: clause.CurrentTable}}},
"SELECT * FROM `users` FOR SHARE OF `users`", nil, "SELECT * FROM `users` FOR SHARE OF `users`", nil,
}, },
{ {
[]clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: clause.LockingStrengthUpdate, Options: clause.LockingOptionsNoWait}}, []clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: "UPDATE"}, clause.Locking{Strength: "UPDATE", Options: "NOWAIT"}},
"SELECT * FROM `users` FOR UPDATE NOWAIT", nil, "SELECT * FROM `users` FOR UPDATE NOWAIT", nil,
}, },
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: clause.LockingStrengthUpdate, Options: clause.LockingOptionsSkipLocked}},
"SELECT * FROM `users` FOR UPDATE SKIP LOCKED", nil,
},
} }
for idx, result := range results { for idx, result := range results {

View File

@ -16,27 +16,27 @@ func (OnConflict) Name() string {
// Build build onConflict clause // Build build onConflict clause
func (onConflict OnConflict) Build(builder Builder) { func (onConflict OnConflict) Build(builder Builder) {
if len(onConflict.Columns) > 0 {
builder.WriteByte('(')
for idx, column := range onConflict.Columns {
if idx > 0 {
builder.WriteByte(',')
}
builder.WriteQuoted(column)
}
builder.WriteString(`) `)
}
if len(onConflict.TargetWhere.Exprs) > 0 {
builder.WriteString(" WHERE ")
onConflict.TargetWhere.Build(builder)
builder.WriteByte(' ')
}
if onConflict.OnConstraint != "" { if onConflict.OnConstraint != "" {
builder.WriteString("ON CONSTRAINT ") builder.WriteString("ON CONSTRAINT ")
builder.WriteString(onConflict.OnConstraint) builder.WriteString(onConflict.OnConstraint)
builder.WriteByte(' ') builder.WriteByte(' ')
} else {
if len(onConflict.Columns) > 0 {
builder.WriteByte('(')
for idx, column := range onConflict.Columns {
if idx > 0 {
builder.WriteByte(',')
}
builder.WriteQuoted(column)
}
builder.WriteString(`) `)
}
if len(onConflict.TargetWhere.Exprs) > 0 {
builder.WriteString(" WHERE ")
onConflict.TargetWhere.Build(builder)
builder.WriteByte(' ')
}
} }
if onConflict.DoNothing { if onConflict.DoNothing {
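
Whichever side builds the column list, the clause is supplied through Clauses; a typical upsert sketch (the user value and column names are illustrative):

// insert, or update name/age when id already exists
db.Clauses(clause.OnConflict{
	Columns:   []clause.Column{{Name: "id"}},
	DoUpdates: clause.AssignmentColumns([]string{"name", "age"}),
}).Create(&user)

// or leave an existing row untouched
db.Clauses(clause.OnConflict{DoNothing: true}).Create(&user)
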

View File

@ -45,8 +45,7 @@ func TestOrderBy(t *testing.T) {
Expression: clause.Expr{SQL: "FIELD(id, ?)", Vars: []interface{}{[]int{1, 2, 3}}, WithoutParentheses: true}, Expression: clause.Expr{SQL: "FIELD(id, ?)", Vars: []interface{}{[]int{1, 2, 3}}, WithoutParentheses: true},
}, },
}, },
"SELECT * FROM `users` ORDER BY FIELD(id, ?,?,?)", "SELECT * FROM `users` ORDER BY FIELD(id, ?,?,?)", []interface{}{1, 2, 3},
[]interface{}{1, 2, 3},
}, },
} }

View File

@ -11,27 +11,20 @@ func (returning Returning) Name() string {
// Build build where clause // Build build where clause
func (returning Returning) Build(builder Builder) { func (returning Returning) Build(builder Builder) {
if len(returning.Columns) > 0 { for idx, column := range returning.Columns {
for idx, column := range returning.Columns { if idx > 0 {
if idx > 0 { builder.WriteByte(',')
builder.WriteByte(',')
}
builder.WriteQuoted(column)
} }
} else {
builder.WriteByte('*') builder.WriteQuoted(column)
} }
} }
// MergeClause merge order by clauses // MergeClause merge order by clauses
func (returning Returning) MergeClause(clause *Clause) { func (returning Returning) MergeClause(clause *Clause) {
if v, ok := clause.Expression.(Returning); ok && len(returning.Columns) > 0 { if v, ok := clause.Expression.(Returning); ok {
if v.Columns != nil { returning.Columns = append(v.Columns, returning.Columns...)
returning.Columns = append(v.Columns, returning.Columns...)
} else {
returning.Columns = nil
}
} }
clause.Expression = returning clause.Expression = returning
} }
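
On the master side an empty column list renders as RETURNING *, and merging an empty Returning clause widens an earlier column list to *. Usage sketch (RETURNING needs a database that supports it, e.g. PostgreSQL):

// return only selected columns of the deleted rows
db.Clauses(clause.Returning{Columns: []clause.Column{{Name: "id"}, {Name: "name"}}}).Delete(&users)

// master: an empty clause means RETURNING *
db.Clauses(clause.Returning{}).Delete(&users)
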

View File

@ -26,22 +26,6 @@ func TestReturning(t *testing.T) {
}}, }},
"SELECT * FROM `users` RETURNING `users`.`id`,`name`,`age`", nil, "SELECT * FROM `users` RETURNING `users`.`id`,`name`,`age`", nil,
}, },
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Returning{
[]clause.Column{clause.PrimaryColumn},
}, clause.Returning{}, clause.Returning{
[]clause.Column{{Name: "name"}, {Name: "age"}},
}},
"SELECT * FROM `users` RETURNING *", nil,
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Returning{
[]clause.Column{clause.PrimaryColumn},
}, clause.Returning{
[]clause.Column{{Name: "name"}, {Name: "age"}},
}, clause.Returning{}},
"SELECT * FROM `users` RETURNING *", nil,
},
} }
for idx, result := range results { for idx, result := range results {

View File

@ -43,25 +43,6 @@ func TestSelect(t *testing.T) {
}, clause.From{}}, }, clause.From{}},
"SELECT `id`, `name`, LENGTH(`mobile`) FROM `users`", nil, "SELECT `id`, `name`, LENGTH(`mobile`) FROM `users`", nil,
}, },
{
[]clause.Interface{clause.Select{
Expression: clause.CommaExpression{
Exprs: []clause.Expression{
clause.Expr{
SQL: "? as name",
Vars: []interface{}{
clause.Eq{
Column: clause.Column{Name: "age"},
Value: 18,
},
},
},
},
},
}, clause.From{}},
"SELECT `age` = ? as name FROM `users`",
[]interface{}{18},
},
} }
for idx, result := range results { for idx, result := range results {

View File

@ -24,9 +24,9 @@ func (set Set) Build(builder Builder) {
builder.AddVar(builder, assignment.Value) builder.AddVar(builder, assignment.Value)
} }
} else { } else {
builder.WriteQuoted(Column{Name: PrimaryKey}) builder.WriteQuoted(PrimaryColumn)
builder.WriteByte('=') builder.WriteByte('=')
builder.WriteQuoted(Column{Name: PrimaryKey}) builder.WriteQuoted(PrimaryColumn)
} }
} }

View File

@ -20,8 +20,7 @@ func TestSet(t *testing.T) {
clause.Update{}, clause.Update{},
clause.Set([]clause.Assignment{{clause.PrimaryColumn, 1}}), clause.Set([]clause.Assignment{{clause.PrimaryColumn, 1}}),
}, },
"UPDATE `users` SET `users`.`id`=?", "UPDATE `users` SET `users`.`id`=?", []interface{}{1},
[]interface{}{1},
}, },
{ {
[]clause.Interface{ []clause.Interface{
@ -29,8 +28,7 @@ func TestSet(t *testing.T) {
clause.Set([]clause.Assignment{{clause.PrimaryColumn, 1}}), clause.Set([]clause.Assignment{{clause.PrimaryColumn, 1}}),
clause.Set([]clause.Assignment{{clause.Column{Name: "name"}, "jinzhu"}}), clause.Set([]clause.Assignment{{clause.Column{Name: "name"}, "jinzhu"}}),
}, },
"UPDATE `users` SET `name`=?", "UPDATE `users` SET `name`=?", []interface{}{"jinzhu"},
[]interface{}{"jinzhu"},
}, },
} }

View File

@ -21,8 +21,7 @@ func TestValues(t *testing.T) {
Values: [][]interface{}{{"jinzhu", 18}, {"josh", 1}}, Values: [][]interface{}{{"jinzhu", 18}, {"josh", 1}},
}, },
}, },
"INSERT INTO `users` (`name`,`age`) VALUES (?,?),(?,?)", "INSERT INTO `users` (`name`,`age`) VALUES (?,?),(?,?)", []interface{}{"jinzhu", 18, "josh", 1},
[]interface{}{"jinzhu", 18, "josh", 1},
}, },
} }

View File

@ -4,11 +4,6 @@ import (
"strings" "strings"
) )
const (
AndWithSpace = " AND "
OrWithSpace = " OR "
)
// Where where clause // Where where clause
type Where struct { type Where struct {
Exprs []Expression Exprs []Expression
@ -21,12 +16,6 @@ func (where Where) Name() string {
// Build build where clause // Build build where clause
func (where Where) Build(builder Builder) { func (where Where) Build(builder Builder) {
if len(where.Exprs) == 1 {
if andCondition, ok := where.Exprs[0].(AndConditions); ok {
where.Exprs = andCondition.Exprs
}
}
// Switch position if the first query expression is a single Or condition // Switch position if the first query expression is a single Or condition
for idx, expr := range where.Exprs { for idx, expr := range where.Exprs {
if v, ok := expr.(OrConditions); !ok || len(v.Exprs) > 1 { if v, ok := expr.(OrConditions); !ok || len(v.Exprs) > 1 {
@ -37,7 +26,7 @@ func (where Where) Build(builder Builder) {
} }
} }
buildExprs(where.Exprs, builder, AndWithSpace) buildExprs(where.Exprs, builder, " AND ")
} }
func buildExprs(exprs []Expression, builder Builder, joinCond string) { func buildExprs(exprs []Expression, builder Builder, joinCond string) {
@ -46,7 +35,7 @@ func buildExprs(exprs []Expression, builder Builder, joinCond string) {
for idx, expr := range exprs { for idx, expr := range exprs {
if idx > 0 { if idx > 0 {
if v, ok := expr.(OrConditions); ok && len(v.Exprs) == 1 { if v, ok := expr.(OrConditions); ok && len(v.Exprs) == 1 {
builder.WriteString(OrWithSpace) builder.WriteString(" OR ")
} else { } else {
builder.WriteString(joinCond) builder.WriteString(joinCond)
} }
@ -57,30 +46,27 @@ func buildExprs(exprs []Expression, builder Builder, joinCond string) {
case OrConditions: case OrConditions:
if len(v.Exprs) == 1 { if len(v.Exprs) == 1 {
if e, ok := v.Exprs[0].(Expr); ok { if e, ok := v.Exprs[0].(Expr); ok {
sql := strings.ToUpper(e.SQL) sql := strings.ToLower(e.SQL)
wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace) wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or")
} }
} }
case AndConditions: case AndConditions:
if len(v.Exprs) == 1 { if len(v.Exprs) == 1 {
if e, ok := v.Exprs[0].(Expr); ok { if e, ok := v.Exprs[0].(Expr); ok {
sql := strings.ToUpper(e.SQL) sql := strings.ToLower(e.SQL)
wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace) wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or")
} }
} }
case Expr: case Expr:
sql := strings.ToUpper(v.SQL) sql := strings.ToLower(v.SQL)
wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace) wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or")
case NamedExpr:
sql := strings.ToUpper(v.SQL)
wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace)
} }
} }
if wrapInParentheses { if wrapInParentheses {
builder.WriteByte('(') builder.WriteString(`(`)
expr.Build(builder) expr.Build(builder)
builder.WriteByte(')') builder.WriteString(`)`)
wrapInParentheses = false wrapInParentheses = false
} else { } else {
expr.Build(builder) expr.Build(builder)
@ -103,14 +89,9 @@ func (where Where) MergeClause(clause *Clause) {
func And(exprs ...Expression) Expression { func And(exprs ...Expression) Expression {
if len(exprs) == 0 { if len(exprs) == 0 {
return nil return nil
} else if len(exprs) == 1 {
return exprs[0]
} }
if len(exprs) == 1 {
if _, ok := exprs[0].(OrConditions); !ok {
return exprs[0]
}
}
return AndConditions{Exprs: exprs} return AndConditions{Exprs: exprs}
} }
@ -121,10 +102,10 @@ type AndConditions struct {
func (and AndConditions) Build(builder Builder) { func (and AndConditions) Build(builder Builder) {
if len(and.Exprs) > 1 { if len(and.Exprs) > 1 {
builder.WriteByte('(') builder.WriteByte('(')
buildExprs(and.Exprs, builder, AndWithSpace) buildExprs(and.Exprs, builder, " AND ")
builder.WriteByte(')') builder.WriteByte(')')
} else { } else {
buildExprs(and.Exprs, builder, AndWithSpace) buildExprs(and.Exprs, builder, " AND ")
} }
} }
@ -142,10 +123,10 @@ type OrConditions struct {
func (or OrConditions) Build(builder Builder) { func (or OrConditions) Build(builder Builder) {
if len(or.Exprs) > 1 { if len(or.Exprs) > 1 {
builder.WriteByte('(') builder.WriteByte('(')
buildExprs(or.Exprs, builder, OrWithSpace) buildExprs(or.Exprs, builder, " OR ")
builder.WriteByte(')') builder.WriteByte(')')
} else { } else {
buildExprs(or.Exprs, builder, OrWithSpace) buildExprs(or.Exprs, builder, " OR ")
} }
} }
@ -153,11 +134,6 @@ func Not(exprs ...Expression) Expression {
if len(exprs) == 0 { if len(exprs) == 0 {
return nil return nil
} }
if len(exprs) == 1 {
if andCondition, ok := exprs[0].(AndConditions); ok {
exprs = andCondition.Exprs
}
}
return NotConditions{Exprs: exprs} return NotConditions{Exprs: exprs}
} }
@ -166,67 +142,23 @@ type NotConditions struct {
} }
func (not NotConditions) Build(builder Builder) { func (not NotConditions) Build(builder Builder) {
anyNegationBuilder := false if len(not.Exprs) > 1 {
for _, c := range not.Exprs { builder.WriteByte('(')
if _, ok := c.(NegationExpressionBuilder); ok {
anyNegationBuilder = true
break
}
} }
if anyNegationBuilder { for idx, c := range not.Exprs {
if len(not.Exprs) > 1 { if idx > 0 {
builder.WriteByte('(') builder.WriteString(" AND ")
} }
for idx, c := range not.Exprs { if negationBuilder, ok := c.(NegationExpressionBuilder); ok {
if idx > 0 { negationBuilder.NegationBuild(builder)
builder.WriteString(AndWithSpace) } else {
} builder.WriteString("NOT ")
if negationBuilder, ok := c.(NegationExpressionBuilder); ok {
negationBuilder.NegationBuild(builder)
} else {
builder.WriteString("NOT ")
e, wrapInParentheses := c.(Expr)
if wrapInParentheses {
sql := strings.ToUpper(e.SQL)
if wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace); wrapInParentheses {
builder.WriteByte('(')
}
}
c.Build(builder)
if wrapInParentheses {
builder.WriteByte(')')
}
}
}
if len(not.Exprs) > 1 {
builder.WriteByte(')')
}
} else {
builder.WriteString("NOT ")
if len(not.Exprs) > 1 {
builder.WriteByte('(')
}
for idx, c := range not.Exprs {
if idx > 0 {
switch c.(type) {
case OrConditions:
builder.WriteString(OrWithSpace)
default:
builder.WriteString(AndWithSpace)
}
}
e, wrapInParentheses := c.(Expr) e, wrapInParentheses := c.(Expr)
if wrapInParentheses { if wrapInParentheses {
sql := strings.ToUpper(e.SQL) sql := strings.ToLower(e.SQL)
if wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace); wrapInParentheses { if wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or"); wrapInParentheses {
builder.WriteByte('(') builder.WriteByte('(')
} }
} }
@ -237,9 +169,9 @@ func (not NotConditions) Build(builder Builder) {
builder.WriteByte(')') builder.WriteByte(')')
} }
} }
}
if len(not.Exprs) > 1 { if len(not.Exprs) > 1 {
builder.WriteByte(')') builder.WriteByte(')')
}
} }
} }
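
The master-side Not negates the grouped condition as a whole instead of negating each expression separately. A sketch mirroring the new test case:

db.Where(clause.Not(
	clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}},
	clause.Expr{SQL: "`age` <= ?", Vars: []interface{}{60}},
)).Find(&users)
// master:   WHERE NOT (`score` <= ? AND `age` <= ?)
// v1.21.16: WHERE (NOT `score` <= ? AND NOT `age` <= ?)
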

View File

@ -17,29 +17,25 @@ func TestWhere(t *testing.T) {
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{ []clause.Interface{clause.Select{}, clause.From{}, clause.Where{
Exprs: []clause.Expression{clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}, clause.Or(clause.Neq{Column: "name", Value: "jinzhu"})}, Exprs: []clause.Expression{clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}, clause.Or(clause.Neq{Column: "name", Value: "jinzhu"})},
}}, }},
"SELECT * FROM `users` WHERE `users`.`id` = ? AND `age` > ? OR `name` <> ?", "SELECT * FROM `users` WHERE `users`.`id` = ? AND `age` > ? OR `name` <> ?", []interface{}{"1", 18, "jinzhu"},
[]interface{}{"1", 18, "jinzhu"},
}, },
{ {
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{ []clause.Interface{clause.Select{}, clause.From{}, clause.Where{
Exprs: []clause.Expression{clause.Or(clause.Neq{Column: "name", Value: "jinzhu"}), clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}}, Exprs: []clause.Expression{clause.Or(clause.Neq{Column: "name", Value: "jinzhu"}), clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}},
}}, }},
"SELECT * FROM `users` WHERE `users`.`id` = ? OR `name` <> ? AND `age` > ?", "SELECT * FROM `users` WHERE `users`.`id` = ? OR `name` <> ? AND `age` > ?", []interface{}{"1", "jinzhu", 18},
[]interface{}{"1", "jinzhu", 18},
}, },
{ {
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{ []clause.Interface{clause.Select{}, clause.From{}, clause.Where{
Exprs: []clause.Expression{clause.Or(clause.Neq{Column: "name", Value: "jinzhu"}), clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}}, Exprs: []clause.Expression{clause.Or(clause.Neq{Column: "name", Value: "jinzhu"}), clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}},
}}, }},
"SELECT * FROM `users` WHERE `users`.`id` = ? OR `name` <> ? AND `age` > ?", "SELECT * FROM `users` WHERE `users`.`id` = ? OR `name` <> ? AND `age` > ?", []interface{}{"1", "jinzhu", 18},
[]interface{}{"1", "jinzhu", 18},
}, },
{ {
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{ []clause.Interface{clause.Select{}, clause.From{}, clause.Where{
Exprs: []clause.Expression{clause.Or(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}), clause.Or(clause.Neq{Column: "name", Value: "jinzhu"})}, Exprs: []clause.Expression{clause.Or(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}), clause.Or(clause.Neq{Column: "name", Value: "jinzhu"})},
}}, }},
"SELECT * FROM `users` WHERE `users`.`id` = ? OR `name` <> ?", "SELECT * FROM `users` WHERE `users`.`id` = ? OR `name` <> ?", []interface{}{"1", "jinzhu"},
[]interface{}{"1", "jinzhu"},
}, },
{ {
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{ []clause.Interface{clause.Select{}, clause.From{}, clause.Where{
@ -47,8 +43,7 @@ func TestWhere(t *testing.T) {
}, clause.Where{ }, clause.Where{
Exprs: []clause.Expression{clause.Or(clause.Gt{Column: "score", Value: 100}, clause.Like{Column: "name", Value: "%linus%"})}, Exprs: []clause.Expression{clause.Or(clause.Gt{Column: "score", Value: 100}, clause.Like{Column: "name", Value: "%linus%"})},
}}, }},
"SELECT * FROM `users` WHERE `users`.`id` = ? AND `age` > ? OR `name` <> ? AND (`score` > ? OR `name` LIKE ?)", "SELECT * FROM `users` WHERE `users`.`id` = ? AND `age` > ? OR `name` <> ? AND (`score` > ? OR `name` LIKE ?)", []interface{}{"1", 18, "jinzhu", 100, "%linus%"},
[]interface{}{"1", 18, "jinzhu", 100, "%linus%"},
}, },
{ {
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{ []clause.Interface{clause.Select{}, clause.From{}, clause.Where{
@ -56,78 +51,13 @@ func TestWhere(t *testing.T) {
}, clause.Where{ }, clause.Where{
Exprs: []clause.Expression{clause.Or(clause.Not(clause.Gt{Column: "score", Value: 100}), clause.Like{Column: "name", Value: "%linus%"})}, Exprs: []clause.Expression{clause.Or(clause.Not(clause.Gt{Column: "score", Value: 100}), clause.Like{Column: "name", Value: "%linus%"})},
}}, }},
"SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `age` <= ?) OR `name` <> ? AND (`score` <= ? OR `name` LIKE ?)", "SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `age` <= ?) OR `name` <> ? AND (`score` <= ? OR `name` LIKE ?)", []interface{}{"1", 18, "jinzhu", 100, "%linus%"},
[]interface{}{"1", 18, "jinzhu", 100, "%linus%"},
}, },
{ {
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{ []clause.Interface{clause.Select{}, clause.From{}, clause.Where{
Exprs: []clause.Expression{clause.And(clause.Eq{Column: "age", Value: 18}, clause.Or(clause.Neq{Column: "name", Value: "jinzhu"}))}, Exprs: []clause.Expression{clause.And(clause.Eq{Column: "age", Value: 18}, clause.Or(clause.Neq{Column: "name", Value: "jinzhu"}))},
}}, }},
"SELECT * FROM `users` WHERE `age` = ? OR `name` <> ?", "SELECT * FROM `users` WHERE (`age` = ? OR `name` <> ?)", []interface{}{18, "jinzhu"},
[]interface{}{18, "jinzhu"},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
Exprs: []clause.Expression{clause.Not(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}), clause.And(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false})},
}},
"SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `age` <= ?) AND `score` <= ?",
[]interface{}{"1", 18, 100},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
Exprs: []clause.Expression{clause.Not(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}), clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false}},
}},
"SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `age` <= ?) AND `score` <= ?",
[]interface{}{"1", 18, 100},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
Exprs: []clause.Expression{clause.Not(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}), clause.Or(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false})},
}},
"SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `age` <= ?) OR `score` <= ?",
[]interface{}{"1", 18, 100},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
Exprs: []clause.Expression{
clause.And(clause.Not(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}),
clause.And(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false})),
},
}},
"SELECT * FROM `users` WHERE `users`.`id` <> ? AND `score` <= ?",
[]interface{}{"1", 100},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
Exprs: []clause.Expression{clause.Not(clause.Eq{Column: clause.PrimaryColumn, Value: "1"},
clause.And(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false}))},
}},
"SELECT * FROM `users` WHERE (`users`.`id` <> ? AND NOT `score` <= ?)",
[]interface{}{"1", 100},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
Exprs: []clause.Expression{clause.Not(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}},
clause.Expr{SQL: "`age` <= ?", Vars: []interface{}{60}})},
}},
"SELECT * FROM `users` WHERE NOT (`score` <= ? AND `age` <= ?)",
[]interface{}{100, 60},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
Exprs: []clause.Expression{
clause.Not(clause.AndConditions{
Exprs: []clause.Expression{
clause.Eq{Column: clause.PrimaryColumn, Value: "1"},
clause.Gt{Column: "age", Value: 18},
}}, clause.OrConditions{
Exprs: []clause.Expression{
clause.Lt{Column: "score", Value: 100},
},
}),
}}},
"SELECT * FROM `users` WHERE NOT ((`users`.`id` = ? AND `age` > ?) OR `score` < ?)",
[]interface{}{"1", 18, 100},
}, },
} }

View File

@ -1,3 +1,4 @@
package clause package clause
type With struct{} type With struct {
}

View File

@ -21,10 +21,6 @@ var (
ErrPrimaryKeyRequired = errors.New("primary key required") ErrPrimaryKeyRequired = errors.New("primary key required")
// ErrModelValueRequired model value required // ErrModelValueRequired model value required
ErrModelValueRequired = errors.New("model value required") ErrModelValueRequired = errors.New("model value required")
// ErrModelAccessibleFieldsRequired model accessible fields required
ErrModelAccessibleFieldsRequired = errors.New("model accessible fields required")
// ErrSubQueryRequired sub query required
ErrSubQueryRequired = errors.New("sub query required")
// ErrInvalidData unsupported data // ErrInvalidData unsupported data
ErrInvalidData = errors.New("unsupported data") ErrInvalidData = errors.New("unsupported data")
// ErrUnsupportedDriver unsupported driver // ErrUnsupportedDriver unsupported driver
@ -43,12 +39,4 @@ var (
ErrInvalidValue = errors.New("invalid value, should be pointer to struct or slice") ErrInvalidValue = errors.New("invalid value, should be pointer to struct or slice")
// ErrInvalidValueOfLength invalid values do not match length // ErrInvalidValueOfLength invalid values do not match length
ErrInvalidValueOfLength = errors.New("invalid association values, length doesn't match") ErrInvalidValueOfLength = errors.New("invalid association values, length doesn't match")
// ErrPreloadNotAllowed preload is not allowed when count is used
ErrPreloadNotAllowed = errors.New("preload is not allowed when count is used")
// ErrDuplicatedKey occurs when there is a unique key constraint violation
ErrDuplicatedKey = errors.New("duplicated key not allowed")
// ErrForeignKeyViolated occurs when there is a foreign key constraint violation
ErrForeignKeyViolated = errors.New("violates foreign key constraint")
// ErrCheckConstraintViolated occurs when there is a check constraint violation
ErrCheckConstraintViolated = errors.New("violates check constraint")
) )
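
The error values present only on master (ErrDuplicatedKey, ErrForeignKeyViolated, ErrCheckConstraintViolated) are intended to be checked with errors.Is once the dialector translates driver errors (the TranslateError config option on recent versions). A sketch:

if err := db.Create(&user).Error; err != nil {
	switch {
	case errors.Is(err, gorm.ErrDuplicatedKey):
		// unique constraint violated
	case errors.Is(err, gorm.ErrForeignKeyViolated):
		// referenced row missing
	default:
		// anything else
	}
}
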

View File

@ -1,11 +1,9 @@
package gorm package gorm
import ( import (
"context"
"database/sql" "database/sql"
"errors" "errors"
"fmt" "fmt"
"hash/maphash"
"reflect" "reflect"
"strings" "strings"
@ -15,7 +13,7 @@ import (
"gorm.io/gorm/utils" "gorm.io/gorm/utils"
) )
// Create inserts value, returning the inserted data's primary key in value's id // Create insert the value into database
func (db *DB) Create(value interface{}) (tx *DB) { func (db *DB) Create(value interface{}) (tx *DB) {
if db.CreateBatchSize > 0 { if db.CreateBatchSize > 0 {
return db.CreateInBatches(value, db.CreateBatchSize) return db.CreateInBatches(value, db.CreateBatchSize)
@ -26,7 +24,7 @@ func (db *DB) Create(value interface{}) (tx *DB) {
return tx.callbacks.Create().Execute(tx) return tx.callbacks.Create().Execute(tx)
} }
// CreateInBatches inserts value in batches of batchSize // CreateInBatches insert the value in batches into database
func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) { func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) {
reflectValue := reflect.Indirect(reflect.ValueOf(value)) reflectValue := reflect.Indirect(reflect.ValueOf(value))
@ -35,10 +33,9 @@ func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) {
var rowsAffected int64 var rowsAffected int64
tx = db.getInstance() tx = db.getInstance()
// the reflection length judgment of the optimized value
reflectLen := reflectValue.Len()
callFc := func(tx *DB) error { callFc := func(tx *DB) error {
// the reflection length judgment of the optimized value
reflectLen := reflectValue.Len()
for i := 0; i < reflectLen; i += batchSize { for i := 0; i < reflectLen; i += batchSize {
ends := i + batchSize ends := i + batchSize
if ends > reflectLen { if ends > reflectLen {
@ -56,7 +53,7 @@ func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) {
return nil return nil
} }
if tx.SkipDefaultTransaction || reflectLen <= batchSize { if tx.SkipDefaultTransaction {
tx.AddError(callFc(tx.Session(&Session{}))) tx.AddError(callFc(tx.Session(&Session{})))
} else { } else {
tx.AddError(tx.Transaction(callFc)) tx.AddError(tx.Transaction(callFc))
@ -71,16 +68,12 @@ func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) {
return return
} }
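
The master version also skips the wrapping transaction when the whole slice fits into a single batch. Day-to-day usage is unchanged (the User values are illustrative):

users := []User{{Name: "a"}, {Name: "b"}, {Name: "c"}}
db.CreateInBatches(&users, 2) // two INSERTs: 2 rows, then 1 row

// or set a global default so plain Create batches automatically:
// gorm.Open(dialector, &gorm.Config{CreateBatchSize: 100})
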
// Save updates value in database. If value doesn't contain a matching primary key, value is inserted. // Save update value in database, if the value doesn't have primary key, will insert it
func (db *DB) Save(value interface{}) (tx *DB) { func (db *DB) Save(value interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.Dest = value tx.Statement.Dest = value
reflectValue := reflect.Indirect(reflect.ValueOf(value)) reflectValue := reflect.Indirect(reflect.ValueOf(value))
for reflectValue.Kind() == reflect.Ptr || reflectValue.Kind() == reflect.Interface {
reflectValue = reflect.Indirect(reflectValue)
}
switch reflectValue.Kind() { switch reflectValue.Kind() {
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
if _, ok := tx.Statement.Clauses["ON CONFLICT"]; !ok { if _, ok := tx.Statement.Clauses["ON CONFLICT"]; !ok {
@ -90,7 +83,7 @@ func (db *DB) Save(value interface{}) (tx *DB) {
case reflect.Struct: case reflect.Struct:
if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil { if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil {
for _, pf := range tx.Statement.Schema.PrimaryFields { for _, pf := range tx.Statement.Schema.PrimaryFields {
if _, isZero := pf.ValueOf(tx.Statement.Context, reflectValue); isZero { if _, isZero := pf.ValueOf(reflectValue); isZero {
return tx.callbacks.Create().Execute(tx) return tx.callbacks.Create().Execute(tx)
} }
} }
@ -104,19 +97,20 @@ func (db *DB) Save(value interface{}) (tx *DB) {
tx.Statement.Selects = append(tx.Statement.Selects, "*") tx.Statement.Selects = append(tx.Statement.Selects, "*")
} }
updateTx := tx.callbacks.Update().Execute(tx.Session(&Session{Initialized: true})) tx = tx.callbacks.Update().Execute(tx)
if updateTx.Error == nil && updateTx.RowsAffected == 0 && !updateTx.DryRun && !selectedUpdate { if tx.Error == nil && tx.RowsAffected == 0 && !tx.DryRun && !selectedUpdate {
return tx.Session(&Session{SkipHooks: true}).Clauses(clause.OnConflict{UpdateAll: true}).Create(value) result := reflect.New(tx.Statement.Schema.ModelType).Interface()
if err := tx.Session(&Session{}).First(result).Error; errors.Is(err, ErrRecordNotFound) {
return tx.Create(value)
}
} }
return updateTx
} }
return return
} }
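
The visible difference is the zero-rows-updated path: v1.21.16 re-queries and falls back to Create, while master falls back to an ON CONFLICT upsert-style create. Caller-facing behaviour of Save is the same sketch on both:

user := User{Name: "jinzhu", Age: 20}
db.Save(&user) // no primary key yet -> INSERT, ID gets populated

user.Age = 21
db.Save(&user) // primary key now set -> UPDATE ... WHERE id = ?
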
// First finds the first record ordered by primary key, matching given conditions conds // First find first record that match given conditions, order by primary key
func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) { func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) {
tx = db.Limit(1).Order(clause.OrderByColumn{ tx = db.Limit(1).Order(clause.OrderByColumn{
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
@ -131,7 +125,7 @@ func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) {
return tx.callbacks.Query().Execute(tx) return tx.callbacks.Query().Execute(tx)
} }
// Take finds the first record returned by the database in no specified order, matching given conditions conds // Take return a record that match given conditions, the order will depend on the database implementation
func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) { func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) {
tx = db.Limit(1) tx = db.Limit(1)
if len(conds) > 0 { if len(conds) > 0 {
@ -144,7 +138,7 @@ func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) {
return tx.callbacks.Query().Execute(tx) return tx.callbacks.Query().Execute(tx)
} }
// Last finds the last record ordered by primary key, matching given conditions conds // Last find last record that match given conditions, order by primary key
func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) { func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) {
tx = db.Limit(1).Order(clause.OrderByColumn{ tx = db.Limit(1).Order(clause.OrderByColumn{
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
@ -160,7 +154,7 @@ func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) {
return tx.callbacks.Query().Execute(tx) return tx.callbacks.Query().Execute(tx)
} }
// Find finds all records matching given conditions conds // Find find records that match given conditions
func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) { func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
if len(conds) > 0 { if len(conds) > 0 {
@ -172,7 +166,7 @@ func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) {
return tx.callbacks.Query().Execute(tx) return tx.callbacks.Query().Execute(tx)
} }
// FindInBatches finds all records in batches of batchSize // FindInBatches find records in batches
func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, batch int) error) *DB { func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, batch int) error) *DB {
var ( var (
tx = db.Order(clause.OrderByColumn{ tx = db.Order(clause.OrderByColumn{
@ -183,32 +177,13 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat
batch int batch int
) )
// user-specified offset or limit
var totalSize int
if c, ok := tx.Statement.Clauses["LIMIT"]; ok {
if limit, ok := c.Expression.(clause.Limit); ok {
if limit.Limit != nil {
totalSize = *limit.Limit
}
if totalSize > 0 && batchSize > totalSize {
batchSize = totalSize
}
// reset the offset to 0 for the next batch
tx = tx.Offset(-1).Session(&Session{})
}
}
for { for {
result := queryDB.Limit(batchSize).Find(dest) result := queryDB.Limit(batchSize).Find(dest)
rowsAffected += result.RowsAffected rowsAffected += result.RowsAffected
batch++ batch++
if result.Error == nil && result.RowsAffected != 0 { if result.Error == nil && result.RowsAffected != 0 {
fcTx := result.Session(&Session{NewDB: true}) tx.AddError(fc(result, batch))
fcTx.RowsAffected = result.RowsAffected
tx.AddError(fc(fcTx, batch))
} else if result.Error != nil { } else if result.Error != nil {
tx.AddError(result.Error) tx.AddError(result.Error)
} }
@ -217,15 +192,6 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat
break break
} }
if totalSize > 0 {
if totalSize <= int(rowsAffected) {
break
}
if totalSize/batchSize == batch {
batchSize = totalSize % batchSize
}
}
// Optimize for-break // Optimize for-break
resultsValue := reflect.Indirect(reflect.ValueOf(dest)) resultsValue := reflect.Indirect(reflect.ValueOf(dest))
if result.Statement.Schema.PrioritizedPrimaryField == nil { if result.Statement.Schema.PrioritizedPrimaryField == nil {
@ -233,11 +199,7 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat
break break
} }
primaryValue, zero := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1)) primaryValue, _ := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(resultsValue.Index(resultsValue.Len() - 1))
if zero {
tx.AddError(ErrPrimaryKeyRequired)
break
}
queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue}) queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue})
} }
@ -245,7 +207,7 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat
return tx return tx
} }
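
A typical FindInBatches call; on master a caller-supplied LIMIT caps the total rows processed and a zero primary key in the last row aborts with ErrPrimaryKeyRequired. The Processed field below is illustrative:

var results []User
db.Where("processed = ?", false).FindInBatches(&results, 100, func(tx *gorm.DB, batch int) error {
	for i := range results {
		results[i].Processed = true
	}
	// returning an error stops the remaining batches
	return tx.Save(&results).Error
})
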
func (db *DB) assignInterfacesToValue(values ...interface{}) { func (tx *DB) assignInterfacesToValue(values ...interface{}) {
for _, value := range values { for _, value := range values {
switch v := value.(type) { switch v := value.(type) {
case []clause.Expression: case []clause.Expression:
@ -253,40 +215,40 @@ func (db *DB) assignInterfacesToValue(values ...interface{}) {
if eq, ok := expr.(clause.Eq); ok { if eq, ok := expr.(clause.Eq); ok {
switch column := eq.Column.(type) { switch column := eq.Column.(type) {
case string: case string:
if field := db.Statement.Schema.LookUpField(column); field != nil { if field := tx.Statement.Schema.LookUpField(column); field != nil {
db.AddError(field.Set(db.Statement.Context, db.Statement.ReflectValue, eq.Value)) tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value))
} }
case clause.Column: case clause.Column:
if field := db.Statement.Schema.LookUpField(column.Name); field != nil { if field := tx.Statement.Schema.LookUpField(column.Name); field != nil {
db.AddError(field.Set(db.Statement.Context, db.Statement.ReflectValue, eq.Value)) tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value))
} }
} }
} else if andCond, ok := expr.(clause.AndConditions); ok { } else if andCond, ok := expr.(clause.AndConditions); ok {
db.assignInterfacesToValue(andCond.Exprs) tx.assignInterfacesToValue(andCond.Exprs)
} }
} }
case clause.Expression, map[string]string, map[interface{}]interface{}, map[string]interface{}: case clause.Expression, map[string]string, map[interface{}]interface{}, map[string]interface{}:
if exprs := db.Statement.BuildCondition(value); len(exprs) > 0 { if exprs := tx.Statement.BuildCondition(value); len(exprs) > 0 {
db.assignInterfacesToValue(exprs) tx.assignInterfacesToValue(exprs)
} }
default: default:
if s, err := schema.Parse(value, db.cacheStore, db.NamingStrategy); err == nil { if s, err := schema.Parse(value, tx.cacheStore, tx.NamingStrategy); err == nil {
reflectValue := reflect.Indirect(reflect.ValueOf(value)) reflectValue := reflect.Indirect(reflect.ValueOf(value))
switch reflectValue.Kind() { switch reflectValue.Kind() {
case reflect.Struct: case reflect.Struct:
for _, f := range s.Fields { for _, f := range s.Fields {
if f.Readable { if f.Readable {
if v, isZero := f.ValueOf(db.Statement.Context, reflectValue); !isZero { if v, isZero := f.ValueOf(reflectValue); !isZero {
if field := db.Statement.Schema.LookUpField(f.Name); field != nil { if field := tx.Statement.Schema.LookUpField(f.Name); field != nil {
db.AddError(field.Set(db.Statement.Context, db.Statement.ReflectValue, v)) tx.AddError(field.Set(tx.Statement.ReflectValue, v))
} }
} }
} }
} }
} }
} else if len(values) > 0 { } else if len(values) > 0 {
if exprs := db.Statement.BuildCondition(values[0], values[1:]...); len(exprs) > 0 { if exprs := tx.Statement.BuildCondition(values[0], values[1:]...); len(exprs) > 0 {
db.assignInterfacesToValue(exprs) tx.assignInterfacesToValue(exprs)
} }
return return
} }
@ -294,24 +256,12 @@ func (db *DB) assignInterfacesToValue(values ...interface{}) {
} }
} }
// FirstOrInit finds the first matching record, or initializes a new instance with the given conds if none is found.
// Each of the conds must be a struct or map.
//
// FirstOrInit never modifies the database. It is often used with Assign and Attrs.
//
// // assign an email if the record is not found
// db.Where(User{Name: "non_existing"}).Attrs(User{Email: "fake@fake.org"}).FirstOrInit(&user)
// // user -> User{Name: "non_existing", Email: "fake@fake.org"}
//
// // assign an email regardless of whether the record is found
// db.Where(User{Name: "jinzhu"}).Assign(User{Email: "fake@fake.org"}).FirstOrInit(&user)
// // user -> User{Name: "jinzhu", Age: 20, Email: "fake@fake.org"}
func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) {
queryTx := db.Limit(1).Order(clause.OrderByColumn{ queryTx := db.Limit(1).Order(clause.OrderByColumn{
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
}) })
if tx = queryTx.Find(dest, conds...); tx.RowsAffected == 0 { if tx = queryTx.Find(dest, conds...); queryTx.RowsAffected == 0 {
if c, ok := tx.Statement.Clauses["WHERE"]; ok { if c, ok := tx.Statement.Clauses["WHERE"]; ok {
if where, ok := c.Expression.(clause.Where); ok { if where, ok := c.Expression.(clause.Where); ok {
tx.assignInterfacesToValue(where.Exprs) tx.assignInterfacesToValue(where.Exprs)
@ -331,64 +281,40 @@ func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) {
return return
} }
// FirstOrCreate finds the first matching record, or creates a new instance with the given conds if none is found.
// Each of the conds must be a struct or map.
//
// Using FirstOrCreate in conjunction with Assign will result in an update to the database even if the record exists.
//
// // assign an email if the record is not found
// result := db.Where(User{Name: "non_existing"}).Attrs(User{Email: "fake@fake.org"}).FirstOrCreate(&user)
// // user -> User{Name: "non_existing", Email: "fake@fake.org"}
// // result.RowsAffected -> 1
//
// // assign an email regardless of whether the record is found
// result := db.Where(User{Name: "jinzhu"}).Assign(User{Email: "fake@fake.org"}).FirstOrCreate(&user)
// // user -> User{Name: "jinzhu", Age: 20, Email: "fake@fake.org"}
// // result.RowsAffected -> 1
func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) { func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) {
tx = db.getInstance() queryTx := db.Limit(1).Order(clause.OrderByColumn{
queryTx := db.Session(&Session{}).Limit(1).Order(clause.OrderByColumn{
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
}) })
result := queryTx.Find(dest, conds...) if tx = queryTx.Find(dest, conds...); queryTx.RowsAffected == 0 {
if result.Error != nil { if c, ok := tx.Statement.Clauses["WHERE"]; ok {
tx.Error = result.Error
return tx
}
if result.RowsAffected == 0 {
if c, ok := result.Statement.Clauses["WHERE"]; ok {
if where, ok := c.Expression.(clause.Where); ok { if where, ok := c.Expression.(clause.Where); ok {
result.assignInterfacesToValue(where.Exprs) tx.assignInterfacesToValue(where.Exprs)
} }
} }
// initialize with attrs, conds // initialize with attrs, conds
if len(db.Statement.attrs) > 0 { if len(tx.Statement.attrs) > 0 {
result.assignInterfacesToValue(db.Statement.attrs...) tx.assignInterfacesToValue(tx.Statement.attrs...)
} }
// initialize with attrs, conds // initialize with attrs, conds
if len(db.Statement.assigns) > 0 { if len(tx.Statement.assigns) > 0 {
result.assignInterfacesToValue(db.Statement.assigns...) tx.assignInterfacesToValue(tx.Statement.assigns...)
} }
return tx.Create(dest) return tx.Create(dest)
} else if len(db.Statement.assigns) > 0 { } else if len(db.Statement.assigns) > 0 {
exprs := tx.Statement.BuildCondition(db.Statement.assigns[0], db.Statement.assigns[1:]...) exprs := tx.Statement.BuildCondition(db.Statement.assigns[0], db.Statement.assigns[1:]...)
assigns := map[string]interface{}{} assigns := map[string]interface{}{}
for i := 0; i < len(exprs); i++ { for _, expr := range exprs {
expr := exprs[i] if eq, ok := expr.(clause.Eq); ok {
if eq, ok := expr.(clause.AndConditions); ok {
exprs = append(exprs, eq.Exprs...)
} else if eq, ok := expr.(clause.Eq); ok {
switch column := eq.Column.(type) { switch column := eq.Column.(type) {
case string: case string:
assigns[column] = eq.Value assigns[column] = eq.Value
case clause.Column: case clause.Column:
assigns[column.Name] = eq.Value assigns[column.Name] = eq.Value
default:
} }
} }
} }
@ -396,17 +322,17 @@ func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) {
return tx.Model(dest).Updates(assigns) return tx.Model(dest).Updates(assigns)
} }
return tx return db
} }
// Update updates column with value using callbacks. Reference: https://gorm.io/docs/update.html#Update-Changed-Fields // Update update attributes with callbacks, refer: https://gorm.io/docs/update.html#Update-Changed-Fields
func (db *DB) Update(column string, value interface{}) (tx *DB) { func (db *DB) Update(column string, value interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.Dest = map[string]interface{}{column: value} tx.Statement.Dest = map[string]interface{}{column: value}
return tx.callbacks.Update().Execute(tx) return tx.callbacks.Update().Execute(tx)
} }
// Updates updates attributes using callbacks. values must be a struct or map. Reference: https://gorm.io/docs/update.html#Update-Changed-Fields // Updates update attributes with callbacks, refer: https://gorm.io/docs/update.html#Update-Changed-Fields
func (db *DB) Updates(values interface{}) (tx *DB) { func (db *DB) Updates(values interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.Dest = values tx.Statement.Dest = values
@ -427,9 +353,7 @@ func (db *DB) UpdateColumns(values interface{}) (tx *DB) {
return tx.callbacks.Update().Execute(tx) return tx.callbacks.Update().Execute(tx)
} }
// Delete deletes value matching given conditions. If value contains primary key it is included in the conditions. If // Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition
// value includes a deleted_at field, then Delete performs a soft delete instead by setting deleted_at with the current
// time if null.
func (db *DB) Delete(value interface{}, conds ...interface{}) (tx *DB) { func (db *DB) Delete(value interface{}, conds ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
if len(conds) > 0 { if len(conds) > 0 {
@ -495,11 +419,9 @@ func (db *DB) Count(count *int64) (tx *DB) {
tx.Statement.Dest = count tx.Statement.Dest = count
tx = tx.callbacks.Query().Execute(tx) tx = tx.callbacks.Query().Execute(tx)
if tx.RowsAffected != 1 {
if _, ok := db.Statement.Clauses["GROUP BY"]; ok || tx.RowsAffected != 1 {
*count = tx.RowsAffected *count = tx.RowsAffected
} }
return return
} }
@ -523,7 +445,7 @@ func (db *DB) Rows() (*sql.Rows, error) {
return rows, tx.Error return rows, tx.Error
} }
// Scan scans selected value to the struct dest // Scan scan value to a struct
func (db *DB) Scan(dest interface{}) (tx *DB) { func (db *DB) Scan(dest interface{}) (tx *DB) {
config := *db.Config config := *db.Config
currentLogger, newLogger := config.Logger, logger.Recorder.New() currentLogger, newLogger := config.Logger, logger.Recorder.New()
@ -532,14 +454,15 @@ func (db *DB) Scan(dest interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Config = &config tx.Config = &config
if rows, err := tx.Rows(); err == nil { if rows, err := tx.Rows(); err != nil {
tx.AddError(err)
} else {
defer rows.Close()
if rows.Next() { if rows.Next() {
tx.ScanRows(rows, dest) tx.ScanRows(rows, dest)
} else { } else {
tx.RowsAffected = 0 tx.RowsAffected = 0
tx.AddError(rows.Err())
} }
tx.AddError(rows.Close())
} }
currentLogger.Trace(tx.Statement.Context, newLogger.BeginAt, func() (string, int64) { currentLogger.Trace(tx.Statement.Context, newLogger.BeginAt, func() (string, int64) {
@ -549,10 +472,9 @@ func (db *DB) Scan(dest interface{}) (tx *DB) {
return return
} }
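A short Scan sketch, reusing the hypothetical User type from the earlier sketch; the ageResult struct is invented for illustration. Scan copies the selected columns into any destination by column name.

func exampleScan(db *gorm.DB) error {
    type ageResult struct {
        Name string
        Age  int
    }
    var res ageResult
    // Only name and age are selected; Scan fills res from those columns.
    return db.Model(&User{}).
        Select("name, age").
        Where("name = ?", "jinzhu").
        Scan(&res).Error
}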
// Pluck queries a single column from a model, returning in the slice dest. E.g.: // Pluck used to query single column from a model as a map
// // var ages []int64
// var ages []int64 // db.Model(&users).Pluck("age", &ages)
// db.Model(&users).Pluck("age", &ages)
func (db *DB) Pluck(column string, dest interface{}) (tx *DB) { func (db *DB) Pluck(column string, dest interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
if tx.Statement.Model != nil { if tx.Statement.Model != nil {
@ -589,60 +511,31 @@ func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error {
} }
tx.Statement.ReflectValue = elem tx.Statement.ReflectValue = elem
} }
Scan(rows, tx, ScanInitialized) Scan(rows, tx, true)
return tx.Error return tx.Error
} }
// Connection uses a db connection to execute an arbitrary number of commands in fc. When finished, the connection is // Transaction start a transaction as a block, return error will rollback, otherwise to commit.
// returned to the connection pool.
func (db *DB) Connection(fc func(tx *DB) error) (err error) {
if db.Error != nil {
return db.Error
}
tx := db.getInstance()
sqlDB, err := tx.DB()
if err != nil {
return
}
conn, err := sqlDB.Conn(tx.Statement.Context)
if err != nil {
return
}
defer conn.Close()
tx.Statement.ConnPool = conn
return fc(tx)
}
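Connection is handy when several statements must see the same connection-scoped state (session variables, temporary tables, and so on). A hedged sketch; the SET statement is PostgreSQL-specific and purely illustrative.

func exampleConnection(db *gorm.DB) error {
    return db.Connection(func(tx *gorm.DB) error {
        // Both statements below run on the same pooled connection.
        if err := tx.Exec("SET search_path TO analytics").Error; err != nil {
            return err
        }
        var count int64
        return tx.Model(&User{}).Count(&count).Error
    })
}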
// Transaction start a transaction as a block, return error will rollback, otherwise to commit. Transaction executes an
// arbitrary number of commands in fc within a transaction. On success the changes are committed; if an error occurs
// they are rolled back.
func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err error) { func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err error) {
panicked := true panicked := true
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil { if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
// nested transaction // nested transaction
if !db.DisableNestedTransaction { if !db.DisableNestedTransaction {
spID := new(maphash.Hash).Sum64() err = db.SavePoint(fmt.Sprintf("sp%p", fc)).Error
err = db.SavePoint(fmt.Sprintf("sp%d", spID)).Error
if err != nil {
return
}
defer func() { defer func() {
// Make sure to rollback when panic, Block error or Commit error // Make sure to rollback when panic, Block error or Commit error
if panicked || err != nil { if panicked || err != nil {
db.RollbackTo(fmt.Sprintf("sp%d", spID)) db.RollbackTo(fmt.Sprintf("sp%p", fc))
} }
}() }()
} }
err = fc(db.Session(&Session{NewDB: db.clone == 1}))
if err == nil {
err = fc(db.Session(&Session{}))
}
} else { } else {
tx := db.Begin(opts...) tx := db.Begin(opts...)
if tx.Error != nil {
return tx.Error
}
defer func() { defer func() {
// Make sure to rollback when panic, Block error or Commit error // Make sure to rollback when panic, Block error or Commit error
@ -651,9 +544,12 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er
} }
}() }()
if err = fc(tx); err == nil { if err = tx.Error; err == nil {
panicked = false err = fc(tx)
return tx.Commit().Error }
if err == nil {
err = tx.Commit().Error
} }
} }
@ -661,11 +557,11 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er
return return
} }
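A minimal Transaction sketch: returning a non-nil error (or panicking) rolls the whole block back, returning nil commits it. The model values are hypothetical.

func exampleTransaction(db *gorm.DB) error {
    return db.Transaction(func(tx *gorm.DB) error {
        user := User{Name: "jinzhu"}
        if err := tx.Create(&user).Error; err != nil {
            return err // rolls back
        }
        // user.ID was filled in by Create, so this updates exactly that row.
        if err := tx.Model(&user).Update("age", 20).Error; err != nil {
            return err // rolls back
        }
        return nil // commits
    })
}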
// Begin begins a transaction with any transaction options opts // Begin begins a transaction
func (db *DB) Begin(opts ...*sql.TxOptions) *DB { func (db *DB) Begin(opts ...*sql.TxOptions) *DB {
var ( var (
// clone statement // clone statement
tx = db.getInstance().Session(&Session{Context: db.Statement.Context, NewDB: db.clone == 1}) tx = db.getInstance().Session(&Session{Context: db.Statement.Context})
opt *sql.TxOptions opt *sql.TxOptions
err error err error
) )
@ -674,19 +570,11 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB {
opt = opts[0] opt = opts[0]
} }
ctx := tx.Statement.Context if beginner, ok := tx.Statement.ConnPool.(TxBeginner); ok {
if _, ok := ctx.Deadline(); !ok { tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
if db.Config.DefaultTransactionTimeout > 0 { } else if beginner, ok := tx.Statement.ConnPool.(ConnPoolBeginner); ok {
ctx, _ = context.WithTimeout(ctx, db.Config.DefaultTransactionTimeout) tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
} } else {
}
switch beginner := tx.Statement.ConnPool.(type) {
case TxBeginner:
tx.Statement.ConnPool, err = beginner.BeginTx(ctx, opt)
case ConnPoolBeginner:
tx.Statement.ConnPool, err = beginner.BeginTx(ctx, opt)
default:
err = ErrInvalidTransaction err = ErrInvalidTransaction
} }
@ -697,7 +585,7 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB {
return tx return tx
} }
// Commit commits the changes in a transaction // Commit commit a transaction
func (db *DB) Commit() *DB { func (db *DB) Commit() *DB {
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil && !reflect.ValueOf(committer).IsNil() { if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil && !reflect.ValueOf(committer).IsNil() {
db.AddError(committer.Commit()) db.AddError(committer.Commit())
@ -707,7 +595,7 @@ func (db *DB) Commit() *DB {
return db return db
} }
// Rollback rollbacks the changes in a transaction // Rollback rollback a transaction
func (db *DB) Rollback() *DB { func (db *DB) Rollback() *DB {
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil { if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
if !reflect.ValueOf(committer).IsNil() { if !reflect.ValueOf(committer).IsNil() {
@ -721,21 +609,7 @@ func (db *DB) Rollback() *DB {
func (db *DB) SavePoint(name string) *DB { func (db *DB) SavePoint(name string) *DB {
if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok { if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok {
// close prepared statement, because SavePoint not support prepared statement.
// e.g. mysql8.0 doc: https://dev.mysql.com/doc/refman/8.0/en/sql-prepared-statements.html
var (
preparedStmtTx *PreparedStmtTX
isPreparedStmtTx bool
)
// close prepared statement, because SavePoint not support prepared statement.
if preparedStmtTx, isPreparedStmtTx = db.Statement.ConnPool.(*PreparedStmtTX); isPreparedStmtTx {
db.Statement.ConnPool = preparedStmtTx.Tx
}
db.AddError(savePointer.SavePoint(db, name)) db.AddError(savePointer.SavePoint(db, name))
// restore prepared statement
if isPreparedStmtTx {
db.Statement.ConnPool = preparedStmtTx
}
} else { } else {
db.AddError(ErrUnsupportedDriver) db.AddError(ErrUnsupportedDriver)
} }
@ -744,28 +618,14 @@ func (db *DB) SavePoint(name string) *DB {
func (db *DB) RollbackTo(name string) *DB { func (db *DB) RollbackTo(name string) *DB {
if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok { if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok {
// close prepared statement, because RollbackTo not support prepared statement.
// e.g. mysql8.0 doc: https://dev.mysql.com/doc/refman/8.0/en/sql-prepared-statements.html
var (
preparedStmtTx *PreparedStmtTX
isPreparedStmtTx bool
)
// close prepared statement, because SavePoint not support prepared statement.
if preparedStmtTx, isPreparedStmtTx = db.Statement.ConnPool.(*PreparedStmtTX); isPreparedStmtTx {
db.Statement.ConnPool = preparedStmtTx.Tx
}
db.AddError(savePointer.RollbackTo(db, name)) db.AddError(savePointer.RollbackTo(db, name))
// restore prepared statement
if isPreparedStmtTx {
db.Statement.ConnPool = preparedStmtTx
}
} else { } else {
db.AddError(ErrUnsupportedDriver) db.AddError(ErrUnsupportedDriver)
} }
return db return db
} }
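The same primitives can also be driven by hand; a sketch of a partial rollback via savepoints (available only when the dialector implements SavePointerDialectorInterface, as the code above shows):

func exampleSavePoint(db *gorm.DB) error {
    tx := db.Begin()
    if tx.Error != nil {
        return tx.Error
    }
    if err := tx.Create(&User{Name: "u1"}).Error; err != nil {
        tx.Rollback()
        return err
    }
    tx.SavePoint("sp1")
    if err := tx.Create(&User{Name: "u2"}).Error; err != nil {
        // Undo only the work done after sp1; u1 stays in the transaction.
        tx.RollbackTo("sp1")
    }
    return tx.Commit().Error
}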
// Exec executes raw sql // Exec execute raw sql
func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) { func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.SQL = strings.Builder{} tx.Statement.SQL = strings.Builder{}


@ -1,605 +0,0 @@
package gorm
import (
"context"
"database/sql"
"fmt"
"sort"
"strings"
"gorm.io/gorm/clause"
"gorm.io/gorm/logger"
)
type result struct {
Result sql.Result
RowsAffected int64
}
func (info *result) ModifyStatement(stmt *Statement) {
stmt.Result = info
}
// Build implements clause.Expression interface
func (result) Build(clause.Builder) {
}
func WithResult() *result {
return &result{}
}
type Interface[T any] interface {
Raw(sql string, values ...interface{}) ExecInterface[T]
Exec(ctx context.Context, sql string, values ...interface{}) error
CreateInterface[T]
}
type CreateInterface[T any] interface {
ChainInterface[T]
Table(name string, args ...interface{}) CreateInterface[T]
Create(ctx context.Context, r *T) error
CreateInBatches(ctx context.Context, r *[]T, batchSize int) error
}
type ChainInterface[T any] interface {
ExecInterface[T]
Scopes(scopes ...func(db *Statement)) ChainInterface[T]
Where(query interface{}, args ...interface{}) ChainInterface[T]
Not(query interface{}, args ...interface{}) ChainInterface[T]
Or(query interface{}, args ...interface{}) ChainInterface[T]
Limit(offset int) ChainInterface[T]
Offset(offset int) ChainInterface[T]
Joins(query clause.JoinTarget, on func(db JoinBuilder, joinTable clause.Table, curTable clause.Table) error) ChainInterface[T]
Preload(association string, query func(db PreloadBuilder) error) ChainInterface[T]
Select(query string, args ...interface{}) ChainInterface[T]
Omit(columns ...string) ChainInterface[T]
MapColumns(m map[string]string) ChainInterface[T]
Distinct(args ...interface{}) ChainInterface[T]
Group(name string) ChainInterface[T]
Having(query interface{}, args ...interface{}) ChainInterface[T]
Order(value interface{}) ChainInterface[T]
Build(builder clause.Builder)
Delete(ctx context.Context) (rowsAffected int, err error)
Update(ctx context.Context, name string, value any) (rowsAffected int, err error)
Updates(ctx context.Context, t T) (rowsAffected int, err error)
Count(ctx context.Context, column string) (result int64, err error)
}
type ExecInterface[T any] interface {
Scan(ctx context.Context, r interface{}) error
First(context.Context) (T, error)
Last(ctx context.Context) (T, error)
Take(context.Context) (T, error)
Find(ctx context.Context) ([]T, error)
FindInBatches(ctx context.Context, batchSize int, fc func(data []T, batch int) error) error
Row(ctx context.Context) *sql.Row
Rows(ctx context.Context) (*sql.Rows, error)
}
type JoinBuilder interface {
Select(...string) JoinBuilder
Omit(...string) JoinBuilder
Where(query interface{}, args ...interface{}) JoinBuilder
Not(query interface{}, args ...interface{}) JoinBuilder
Or(query interface{}, args ...interface{}) JoinBuilder
}
type PreloadBuilder interface {
Select(...string) PreloadBuilder
Omit(...string) PreloadBuilder
Where(query interface{}, args ...interface{}) PreloadBuilder
Not(query interface{}, args ...interface{}) PreloadBuilder
Or(query interface{}, args ...interface{}) PreloadBuilder
Limit(offset int) PreloadBuilder
Offset(offset int) PreloadBuilder
Order(value interface{}) PreloadBuilder
LimitPerRecord(num int) PreloadBuilder
}
type op func(*DB) *DB
func G[T any](db *DB, opts ...clause.Expression) Interface[T] {
v := &g[T]{
db: db,
ops: make([]op, 0, 5),
}
if len(opts) > 0 {
v.ops = append(v.ops, func(db *DB) *DB {
return db.Clauses(opts...)
})
}
v.createG = &createG[T]{
chainG: chainG[T]{
execG: execG[T]{g: v},
},
}
return v
}
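A sketch of the generics entry point defined above, reusing the hypothetical User type; imports of context and gorm.io/gorm are assumed. Every finisher takes a context.Context explicitly instead of relying on WithContext.

func exampleGenerics(ctx context.Context, db *gorm.DB) error {
    // Create goes through CreateInterface[T].
    if err := gorm.G[User](db).Create(ctx, &User{Name: "jinzhu"}); err != nil {
        return err
    }

    // Chained query, finished with Find(ctx); the result is a typed []User.
    users, err := gorm.G[User](db).
        Where("age > ?", 18).
        Order("name").
        Limit(10).
        Find(ctx)
    if err != nil {
        return err
    }
    _ = users
    return nil
}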
type g[T any] struct {
*createG[T]
db *DB
ops []op
}
func (g *g[T]) apply(ctx context.Context) *DB {
db := g.db
if !db.DryRun {
db = db.Session(&Session{NewDB: true, Context: ctx}).getInstance()
}
for _, op := range g.ops {
db = op(db)
}
return db
}
func (c *g[T]) Raw(sql string, values ...interface{}) ExecInterface[T] {
return execG[T]{g: &g[T]{
db: c.db,
ops: append(c.ops, func(db *DB) *DB {
return db.Raw(sql, values...)
}),
}}
}
func (c *g[T]) Exec(ctx context.Context, sql string, values ...interface{}) error {
return c.apply(ctx).Exec(sql, values...).Error
}
type createG[T any] struct {
chainG[T]
}
func (c createG[T]) Table(name string, args ...interface{}) CreateInterface[T] {
return createG[T]{c.with(func(db *DB) *DB {
return db.Table(name, args...)
})}
}
func (c createG[T]) Create(ctx context.Context, r *T) error {
return c.g.apply(ctx).Create(r).Error
}
func (c createG[T]) CreateInBatches(ctx context.Context, r *[]T, batchSize int) error {
return c.g.apply(ctx).CreateInBatches(r, batchSize).Error
}
type chainG[T any] struct {
execG[T]
}
func (c chainG[T]) getInstance() *DB {
var r T
return c.g.apply(context.Background()).Model(r).getInstance()
}
func (c chainG[T]) with(v op) chainG[T] {
return chainG[T]{
execG: execG[T]{g: &g[T]{
db: c.g.db,
ops: append(append([]op(nil), c.g.ops...), v),
}},
}
}
func (c chainG[T]) Scopes(scopes ...func(db *Statement)) ChainInterface[T] {
return c.with(func(db *DB) *DB {
for _, fc := range scopes {
fc(db.Statement)
}
return db
})
}
func (c chainG[T]) Table(name string, args ...interface{}) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Table(name, args...)
})
}
func (c chainG[T]) Where(query interface{}, args ...interface{}) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Where(query, args...)
})
}
func (c chainG[T]) Not(query interface{}, args ...interface{}) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Not(query, args...)
})
}
func (c chainG[T]) Or(query interface{}, args ...interface{}) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Or(query, args...)
})
}
func (c chainG[T]) Limit(offset int) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Limit(offset)
})
}
func (c chainG[T]) Offset(offset int) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Offset(offset)
})
}
type joinBuilder struct {
db *DB
}
func (q *joinBuilder) Where(query interface{}, args ...interface{}) JoinBuilder {
q.db.Where(query, args...)
return q
}
func (q *joinBuilder) Or(query interface{}, args ...interface{}) JoinBuilder {
q.db.Where(query, args...)
return q
}
func (q *joinBuilder) Not(query interface{}, args ...interface{}) JoinBuilder {
q.db.Where(query, args...)
return q
}
func (q *joinBuilder) Select(columns ...string) JoinBuilder {
q.db.Select(columns)
return q
}
func (q *joinBuilder) Omit(columns ...string) JoinBuilder {
q.db.Omit(columns...)
return q
}
type preloadBuilder struct {
limitPerRecord int
db *DB
}
func (q *preloadBuilder) Where(query interface{}, args ...interface{}) PreloadBuilder {
q.db.Where(query, args...)
return q
}
func (q *preloadBuilder) Or(query interface{}, args ...interface{}) PreloadBuilder {
q.db.Where(query, args...)
return q
}
func (q *preloadBuilder) Not(query interface{}, args ...interface{}) PreloadBuilder {
q.db.Where(query, args...)
return q
}
func (q *preloadBuilder) Select(columns ...string) PreloadBuilder {
q.db.Select(columns)
return q
}
func (q *preloadBuilder) Omit(columns ...string) PreloadBuilder {
q.db.Omit(columns...)
return q
}
func (q *preloadBuilder) Limit(limit int) PreloadBuilder {
q.db.Limit(limit)
return q
}
func (q *preloadBuilder) Offset(offset int) PreloadBuilder {
q.db.Offset(offset)
return q
}
func (q *preloadBuilder) Order(value interface{}) PreloadBuilder {
q.db.Order(value)
return q
}
func (q *preloadBuilder) LimitPerRecord(num int) PreloadBuilder {
q.limitPerRecord = num
return q
}
func (c chainG[T]) Joins(jt clause.JoinTarget, on func(db JoinBuilder, joinTable clause.Table, curTable clause.Table) error) ChainInterface[T] {
return c.with(func(db *DB) *DB {
if jt.Table == "" {
jt.Table = clause.JoinTable(strings.Split(jt.Association, ".")...).Name
}
q := joinBuilder{db: db.Session(&Session{NewDB: true, Initialized: true}).Table(jt.Table)}
if on != nil {
if err := on(&q, clause.Table{Name: jt.Table}, clause.Table{Name: clause.CurrentTable}); err != nil {
db.AddError(err)
}
}
j := join{
Name: jt.Association,
Alias: jt.Table,
Selects: q.db.Statement.Selects,
Omits: q.db.Statement.Omits,
JoinType: jt.Type,
}
if where, ok := q.db.Statement.Clauses["WHERE"].Expression.(clause.Where); ok {
j.On = &where
}
if jt.Subquery != nil {
joinType := j.JoinType
if joinType == "" {
joinType = clause.LeftJoin
}
if db, ok := jt.Subquery.(interface{ getInstance() *DB }); ok {
stmt := db.getInstance().Statement
if len(j.Selects) == 0 {
j.Selects = stmt.Selects
}
if len(j.Omits) == 0 {
j.Omits = stmt.Omits
}
}
expr := clause.NamedExpr{SQL: fmt.Sprintf("%s JOIN (?) AS ?", joinType), Vars: []interface{}{jt.Subquery, clause.Table{Name: j.Alias}}}
if j.On != nil {
expr.SQL += " ON ?"
expr.Vars = append(expr.Vars, clause.AndConditions{Exprs: j.On.Exprs})
}
j.Expression = expr
}
db.Statement.Joins = append(db.Statement.Joins, j)
sort.Slice(db.Statement.Joins, func(i, j int) bool {
return db.Statement.Joins[i].Name < db.Statement.Joins[j].Name
})
return db
})
}
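A hedged sketch of the generic Joins method above. The clause.JoinTarget literal is built directly from the fields this code reads (Type, Association); helper constructors may exist but are not shown in this diff. The Company association and the name condition are assumptions; imports of context, gorm.io/gorm and gorm.io/gorm/clause are assumed.

func exampleGenericJoin(ctx context.Context, db *gorm.DB) ([]User, error) {
    return gorm.G[User](db).
        Joins(clause.JoinTarget{Type: clause.LeftJoin, Association: "Company"},
            func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error {
                // Conditions added here are attached to the join's ON clause.
                db.Where("name = ?", "acme")
                return nil
            }).
        Find(ctx)
}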
func (c chainG[T]) Select(query string, args ...interface{}) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Select(query, args...)
})
}
func (c chainG[T]) Omit(columns ...string) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Omit(columns...)
})
}
func (c chainG[T]) MapColumns(m map[string]string) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.MapColumns(m)
})
}
func (c chainG[T]) Distinct(args ...interface{}) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Distinct(args...)
})
}
func (c chainG[T]) Group(name string) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Group(name)
})
}
func (c chainG[T]) Having(query interface{}, args ...interface{}) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Having(query, args...)
})
}
func (c chainG[T]) Order(value interface{}) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Order(value)
})
}
func (c chainG[T]) Preload(association string, query func(db PreloadBuilder) error) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Preload(association, func(tx *DB) *DB {
q := preloadBuilder{db: tx.getInstance()}
if query != nil {
if err := query(&q); err != nil {
db.AddError(err)
}
}
relation, ok := db.Statement.Schema.Relationships.Relations[association]
if !ok {
if preloadFields := strings.Split(association, "."); len(preloadFields) > 1 {
relationships := db.Statement.Schema.Relationships
for _, field := range preloadFields {
var ok bool
relation, ok = relationships.Relations[field]
if ok {
relationships = relation.FieldSchema.Relationships
} else {
db.AddError(fmt.Errorf("relation %s not found", association))
return nil
}
}
} else {
db.AddError(fmt.Errorf("relation %s not found", association))
return nil
}
}
if q.limitPerRecord > 0 {
if relation.JoinTable != nil {
tx.AddError(fmt.Errorf("many2many relation %s don't support LimitPerRecord", association))
return tx
}
refColumns := []clause.Column{}
for _, rel := range relation.References {
if rel.OwnPrimaryKey {
refColumns = append(refColumns, clause.Column{Name: rel.ForeignKey.DBName})
}
}
if len(refColumns) != 0 {
selectExpr := clause.CommaExpression{}
for _, column := range q.db.Statement.Selects {
selectExpr.Exprs = append(selectExpr.Exprs, clause.Expr{SQL: "?", Vars: []interface{}{clause.Column{Name: column}}})
}
if len(selectExpr.Exprs) == 0 {
selectExpr.Exprs = []clause.Expression{clause.Expr{SQL: "*", Vars: []interface{}{}}}
}
partitionBy := clause.CommaExpression{}
for _, column := range refColumns {
partitionBy.Exprs = append(partitionBy.Exprs, clause.Expr{SQL: "?", Vars: []interface{}{clause.Column{Name: column.Name}}})
}
rnnColumn := clause.Column{Name: "gorm_preload_rnn"}
sql := "ROW_NUMBER() OVER (PARTITION BY ? ?)"
vars := []interface{}{partitionBy}
if orderBy, ok := q.db.Statement.Clauses["ORDER BY"]; ok {
vars = append(vars, orderBy)
} else {
vars = append(vars, clause.Clause{Name: "ORDER BY", Expression: clause.OrderBy{
Columns: []clause.OrderByColumn{{Column: clause.PrimaryColumn, Desc: true}},
}})
}
vars = append(vars, rnnColumn)
selectExpr.Exprs = append(selectExpr.Exprs, clause.Expr{SQL: sql + " AS ?", Vars: vars})
q.db.Clauses(clause.Select{Expression: selectExpr})
return q.db.Session(&Session{NewDB: true}).Unscoped().Table("(?) t", q.db).Where("? <= ?", rnnColumn, q.limitPerRecord)
}
}
return q.db
})
})
}
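A sketch of the generic Preload with LimitPerRecord, which the code above rewrites into a ROW_NUMBER() window query. The Pets association is an assumption, and LimitPerRecord is rejected for many2many relations, as noted above.

func examplePreload(ctx context.Context, db *gorm.DB) ([]User, error) {
    return gorm.G[User](db).
        Preload("Pets", func(db gorm.PreloadBuilder) error {
            // Keep only the two newest pets per parent record.
            db.Order("created_at DESC").LimitPerRecord(2)
            return nil
        }).
        Find(ctx)
}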
func (c chainG[T]) Delete(ctx context.Context) (rowsAffected int, err error) {
r := new(T)
res := c.g.apply(ctx).Delete(r)
return int(res.RowsAffected), res.Error
}
func (c chainG[T]) Update(ctx context.Context, name string, value any) (rowsAffected int, err error) {
var r T
res := c.g.apply(ctx).Model(r).Update(name, value)
return int(res.RowsAffected), res.Error
}
func (c chainG[T]) Updates(ctx context.Context, t T) (rowsAffected int, err error) {
res := c.g.apply(ctx).Updates(t)
return int(res.RowsAffected), res.Error
}
func (c chainG[T]) Count(ctx context.Context, column string) (result int64, err error) {
var r T
err = c.g.apply(ctx).Model(r).Select(column).Count(&result).Error
return
}
func (c chainG[T]) Build(builder clause.Builder) {
subdb := c.getInstance()
subdb.Logger = logger.Discard
subdb.DryRun = true
if stmt, ok := builder.(*Statement); ok {
if subdb.Statement.SQL.Len() > 0 {
var (
vars = subdb.Statement.Vars
sql = subdb.Statement.SQL.String()
)
subdb.Statement.Vars = make([]interface{}, 0, len(vars))
for _, vv := range vars {
subdb.Statement.Vars = append(subdb.Statement.Vars, vv)
bindvar := strings.Builder{}
subdb.BindVarTo(&bindvar, subdb.Statement, vv)
sql = strings.Replace(sql, bindvar.String(), "?", 1)
}
subdb.Statement.SQL.Reset()
subdb.Statement.Vars = stmt.Vars
if strings.Contains(sql, "@") {
clause.NamedExpr{SQL: sql, Vars: vars}.Build(subdb.Statement)
} else {
clause.Expr{SQL: sql, Vars: vars}.Build(subdb.Statement)
}
} else {
subdb.Statement.Vars = append(stmt.Vars, subdb.Statement.Vars...)
subdb.callbacks.Query().Execute(subdb)
}
builder.WriteString(subdb.Statement.SQL.String())
stmt.Vars = subdb.Statement.Vars
}
}
type execG[T any] struct {
g *g[T]
}
func (g execG[T]) First(ctx context.Context) (T, error) {
var r T
err := g.g.apply(ctx).First(&r).Error
return r, err
}
func (g execG[T]) Scan(ctx context.Context, result interface{}) error {
var r T
err := g.g.apply(ctx).Model(r).Find(result).Error
return err
}
func (g execG[T]) Last(ctx context.Context) (T, error) {
var r T
err := g.g.apply(ctx).Last(&r).Error
return r, err
}
func (g execG[T]) Take(ctx context.Context) (T, error) {
var r T
err := g.g.apply(ctx).Take(&r).Error
return r, err
}
func (g execG[T]) Find(ctx context.Context) ([]T, error) {
var r []T
err := g.g.apply(ctx).Find(&r).Error
return r, err
}
func (g execG[T]) FindInBatches(ctx context.Context, batchSize int, fc func(data []T, batch int) error) error {
var data []T
return g.g.apply(ctx).FindInBatches(&data, batchSize, func(tx *DB, batch int) error {
return fc(data, batch)
}).Error
}
func (g execG[T]) Row(ctx context.Context) *sql.Row {
return g.g.apply(ctx).Row()
}
func (g execG[T]) Rows(ctx context.Context) (*sql.Rows, error) {
return g.g.apply(ctx).Rows()
}

go.mod

@ -1,9 +1,8 @@
module gorm.io/gorm module gorm.io/gorm
go 1.18 go 1.14
require ( require (
github.com/jinzhu/inflection v1.0.0 github.com/jinzhu/inflection v1.0.0
github.com/jinzhu/now v1.1.5 github.com/jinzhu/now v1.1.2
golang.org/x/text v0.20.0
) )

go.sum

@ -1,6 +1,4 @@
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= github.com/jinzhu/now v1.1.2 h1:eVKgfIdy9b6zbWBMgFpfDPoAMifwSZagU9HmEU6zgiI=
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= github.com/jinzhu/now v1.1.2/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
golang.org/x/text v0.20.0 h1:gK/Kv2otX8gz+wn7Rmb3vT96ZwuoxnQlY+HlJVj7Qug=
golang.org/x/text v0.20.0/go.mod h1:D4IsuqiFMhST5bX19pQ9ikHC2GsaKyk/oF+pn3ducp4=

gorm.go

@ -4,7 +4,6 @@ import (
"context" "context"
"database/sql" "database/sql"
"fmt" "fmt"
"reflect"
"sort" "sort"
"sync" "sync"
"time" "time"
@ -21,9 +20,7 @@ const preparedStmtDBKey = "preparedStmt"
type Config struct { type Config struct {
// GORM perform single create, update, delete operations in transactions by default to ensure database data integrity // GORM perform single create, update, delete operations in transactions by default to ensure database data integrity
// You can disable it by setting `SkipDefaultTransaction` to true // You can disable it by setting `SkipDefaultTransaction` to true
SkipDefaultTransaction bool SkipDefaultTransaction bool
DefaultTransactionTimeout time.Duration
// NamingStrategy tables, columns naming strategy // NamingStrategy tables, columns naming strategy
NamingStrategy schema.Namer NamingStrategy schema.Namer
// FullSaveAssociations full save associations // FullSaveAssociations full save associations
@ -36,17 +33,10 @@ type Config struct {
DryRun bool DryRun bool
// PrepareStmt executes the given query in cached statement // PrepareStmt executes the given query in cached statement
PrepareStmt bool PrepareStmt bool
// PrepareStmt cache support LRU expired,
// default maxsize=int64 Max value and ttl=1h
PrepareStmtMaxSize int
PrepareStmtTTL time.Duration
// DisableAutomaticPing // DisableAutomaticPing
DisableAutomaticPing bool DisableAutomaticPing bool
// DisableForeignKeyConstraintWhenMigrating // DisableForeignKeyConstraintWhenMigrating
DisableForeignKeyConstraintWhenMigrating bool DisableForeignKeyConstraintWhenMigrating bool
// IgnoreRelationshipsWhenMigrating
IgnoreRelationshipsWhenMigrating bool
// DisableNestedTransaction disable nested transaction // DisableNestedTransaction disable nested transaction
DisableNestedTransaction bool DisableNestedTransaction bool
// AllowGlobalUpdate allow global update // AllowGlobalUpdate allow global update
@ -55,10 +45,6 @@ type Config struct {
QueryFields bool QueryFields bool
// CreateBatchSize default create batch size // CreateBatchSize default create batch size
CreateBatchSize int CreateBatchSize int
// TranslateError enabling error translation
TranslateError bool
// PropagateUnscoped propagate Unscoped to every other nested statement
PropagateUnscoped bool
// ClauseBuilders clause builder // ClauseBuilders clause builder
ClauseBuilders map[string]clause.ClauseBuilder ClauseBuilders map[string]clause.ClauseBuilder
@ -73,7 +59,6 @@ type Config struct {
cacheStore *sync.Map cacheStore *sync.Map
} }
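A hedged example of wiring the newer Config options together at Open time. The sqlite driver (gorm.io/driver/sqlite) and the file name are assumptions used only to keep the sketch concrete; any dialector works, and an import of time is assumed.

func exampleOpen() (*gorm.DB, error) {
    return gorm.Open(sqlite.Open("test.db"), &gorm.Config{
        PrepareStmt:               true,
        PrepareStmtMaxSize:        1024,             // LRU capacity for cached statements
        PrepareStmtTTL:            30 * time.Minute, // per-statement expiry
        TranslateError:            true,             // dialector errors mapped to gorm.Err* where supported
        DefaultTransactionTimeout: 10 * time.Second, // applied when the context has no deadline
    })
}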
// Apply update config to new config
func (c *Config) Apply(config *Config) error { func (c *Config) Apply(config *Config) error {
if config != c { if config != c {
*config = *c *config = *c
@ -81,7 +66,6 @@ func (c *Config) Apply(config *Config) error {
return nil return nil
} }
// AfterInitialize initialize plugins after db connected
func (c *Config) AfterInitialize(db *DB) error { func (c *Config) AfterInitialize(db *DB) error {
if db != nil { if db != nil {
for _, plugin := range c.Plugins { for _, plugin := range c.Plugins {
@ -93,7 +77,6 @@ func (c *Config) AfterInitialize(db *DB) error {
return nil return nil
} }
// Option gorm option interface
type Option interface { type Option interface {
Apply(*Config) error Apply(*Config) error
AfterInitialize(*DB) error AfterInitialize(*DB) error
@ -113,13 +96,11 @@ type Session struct {
DryRun bool DryRun bool
PrepareStmt bool PrepareStmt bool
NewDB bool NewDB bool
Initialized bool
SkipHooks bool SkipHooks bool
SkipDefaultTransaction bool SkipDefaultTransaction bool
DisableNestedTransaction bool DisableNestedTransaction bool
AllowGlobalUpdate bool AllowGlobalUpdate bool
FullSaveAssociations bool FullSaveAssociations bool
PropagateUnscoped bool
QueryFields bool QueryFields bool
Context context.Context Context context.Context
Logger logger.Interface Logger logger.Interface
@ -137,24 +118,12 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) {
return isConfig && !isConfig2 return isConfig && !isConfig2
}) })
if len(opts) > 0 {
if c, ok := opts[0].(*Config); ok {
config = c
} else {
opts = append([]Option{config}, opts...)
}
}
var skipAfterInitialize bool
for _, opt := range opts { for _, opt := range opts {
if opt != nil { if opt != nil {
if applyErr := opt.Apply(config); applyErr != nil { if err := opt.Apply(config); err != nil {
return nil, applyErr return nil, err
} }
defer func(opt Option) { defer func(opt Option) {
if skipAfterInitialize {
return
}
if errr := opt.AfterInitialize(db); errr != nil { if errr := opt.AfterInitialize(db); errr != nil {
err = errr err = errr
} }
@ -169,7 +138,7 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) {
} }
if config.NamingStrategy == nil { if config.NamingStrategy == nil {
config.NamingStrategy = schema.NamingStrategy{IdentifierMaxLength: 64} // Default Identifier length is 64 config.NamingStrategy = schema.NamingStrategy{}
} }
if config.Logger == nil { if config.Logger == nil {
@ -202,26 +171,17 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) {
if config.Dialector != nil { if config.Dialector != nil {
err = config.Dialector.Initialize(db) err = config.Dialector.Initialize(db)
if err != nil {
if db, _ := db.DB(); db != nil {
_ = db.Close()
}
// DB is not initialized, so we skip AfterInitialize
skipAfterInitialize = true
return
}
if config.TranslateError {
if _, ok := db.Dialector.(ErrorTranslator); !ok {
config.Logger.Warn(context.Background(), "The TranslateError option is enabled, but the Dialector %s does not implement ErrorTranslator.", db.Dialector.Name())
}
}
} }
preparedStmt := &PreparedStmtDB{
ConnPool: db.ConnPool,
Stmts: map[string]Stmt{},
Mux: &sync.RWMutex{},
PreparedSQL: make([]string, 0, 100),
}
db.cacheStore.Store(preparedStmtDBKey, preparedStmt)
if config.PrepareStmt { if config.PrepareStmt {
preparedStmt := NewPreparedStmtDB(db.ConnPool, config.PrepareStmtMaxSize, config.PrepareStmtTTL)
db.cacheStore.Store(preparedStmtDBKey, preparedStmt)
db.ConnPool = preparedStmt db.ConnPool = preparedStmt
} }
@ -272,10 +232,6 @@ func (db *DB) Session(config *Session) *DB {
txConfig.FullSaveAssociations = true txConfig.FullSaveAssociations = true
} }
if config.PropagateUnscoped {
txConfig.PropagateUnscoped = true
}
if config.Context != nil || config.PrepareStmt || config.SkipHooks { if config.Context != nil || config.PrepareStmt || config.SkipHooks {
tx.Statement = tx.Statement.clone() tx.Statement = tx.Statement.clone()
tx.Statement.DB = tx tx.Statement.DB = tx
@ -286,30 +242,16 @@ func (db *DB) Session(config *Session) *DB {
} }
if config.PrepareStmt { if config.PrepareStmt {
var preparedStmt *PreparedStmtDB
if v, ok := db.cacheStore.Load(preparedStmtDBKey); ok { if v, ok := db.cacheStore.Load(preparedStmtDBKey); ok {
preparedStmt = v.(*PreparedStmtDB) preparedStmt := v.(*PreparedStmtDB)
} else {
preparedStmt = NewPreparedStmtDB(db.ConnPool, db.PrepareStmtMaxSize, db.PrepareStmtTTL)
db.cacheStore.Store(preparedStmtDBKey, preparedStmt)
}
switch t := tx.Statement.ConnPool.(type) {
case Tx:
tx.Statement.ConnPool = &PreparedStmtTX{
Tx: t,
PreparedStmtDB: preparedStmt,
}
default:
tx.Statement.ConnPool = &PreparedStmtDB{ tx.Statement.ConnPool = &PreparedStmtDB{
ConnPool: db.Config.ConnPool, ConnPool: db.Config.ConnPool,
Mux: preparedStmt.Mux, Mux: preparedStmt.Mux,
Stmts: preparedStmt.Stmts, Stmts: preparedStmt.Stmts,
} }
txConfig.ConnPool = tx.Statement.ConnPool
txConfig.PrepareStmt = true
} }
txConfig.ConnPool = tx.Statement.ConnPool
txConfig.PrepareStmt = true
} }
if config.SkipHooks { if config.SkipHooks {
@ -340,10 +282,6 @@ func (db *DB) Session(config *Session) *DB {
tx.Config.NowFunc = config.NowFunc tx.Config.NowFunc = config.NowFunc
} }
if config.Initialized {
tx = tx.getInstance()
}
return tx return tx
} }
@ -354,8 +292,7 @@ func (db *DB) WithContext(ctx context.Context) *DB {
// Debug start debug mode // Debug start debug mode
func (db *DB) Debug() (tx *DB) { func (db *DB) Debug() (tx *DB) {
tx = db.getInstance() return db.Session(&Session{
return tx.Session(&Session{
Logger: db.Logger.LogMode(logger.Info), Logger: db.Logger.LogMode(logger.Info),
}) })
} }
@ -391,18 +328,10 @@ func (db *DB) Callback() *callbacks {
// AddError add error to db // AddError add error to db
func (db *DB) AddError(err error) error { func (db *DB) AddError(err error) error {
if err != nil { if db.Error == nil {
if db.Config.TranslateError { db.Error = err
if errTranslator, ok := db.Dialector.(ErrorTranslator); ok { } else if err != nil {
err = errTranslator.Translate(err) db.Error = fmt.Errorf("%v; %w", db.Error, err)
}
}
if db.Error == nil {
db.Error = err
} else {
db.Error = fmt.Errorf("%v; %w", db.Error, err)
}
} }
return db.Error return db.Error
} }
@ -410,20 +339,12 @@ func (db *DB) AddError(err error) error {
// DB returns `*sql.DB` // DB returns `*sql.DB`
func (db *DB) DB() (*sql.DB, error) { func (db *DB) DB() (*sql.DB, error) {
connPool := db.ConnPool connPool := db.ConnPool
if db.Statement != nil && db.Statement.ConnPool != nil {
connPool = db.Statement.ConnPool
}
if tx, ok := connPool.(*sql.Tx); ok && tx != nil {
return (*sql.DB)(reflect.ValueOf(tx).Elem().FieldByName("db").UnsafePointer()), nil
}
if dbConnector, ok := connPool.(GetDBConnector); ok && dbConnector != nil { if dbConnector, ok := connPool.(GetDBConnector); ok && dbConnector != nil {
if sqldb, err := dbConnector.GetDBConn(); sqldb != nil || err != nil { return dbConnector.GetDBConn()
return sqldb, err
}
} }
if sqldb, ok := connPool.(*sql.DB); ok && sqldb != nil { if sqldb, ok := connPool.(*sql.DB); ok {
return sqldb, nil return sqldb, nil
} }
@ -437,15 +358,11 @@ func (db *DB) getInstance() *DB {
if db.clone == 1 { if db.clone == 1 {
// clone with new statement // clone with new statement
tx.Statement = &Statement{ tx.Statement = &Statement{
DB: tx, DB: tx,
ConnPool: db.Statement.ConnPool, ConnPool: db.Statement.ConnPool,
Context: db.Statement.Context, Context: db.Statement.Context,
Clauses: map[string]clause.Clause{}, Clauses: map[string]clause.Clause{},
Vars: make([]interface{}, 0, 8), Vars: make([]interface{}, 0, 8),
SkipHooks: db.Statement.SkipHooks,
}
if db.Config.PropagateUnscoped {
tx.Statement.Unscoped = db.Statement.Unscoped
} }
} else { } else {
// with clone statement // with clone statement
@ -459,12 +376,10 @@ func (db *DB) getInstance() *DB {
return db return db
} }
// Expr returns clause.Expr, which can be used to pass SQL expression as params
func Expr(expr string, args ...interface{}) clause.Expr { func Expr(expr string, args ...interface{}) clause.Expr {
return clause.Expr{SQL: expr, Vars: args} return clause.Expr{SQL: expr, Vars: args}
} }
// SetupJoinTable setup join table schema
func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interface{}) error { func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interface{}) error {
var ( var (
tx = db.getInstance() tx = db.getInstance()
@ -487,7 +402,7 @@ func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interfac
relation, ok := modelSchema.Relationships.Relations[field] relation, ok := modelSchema.Relationships.Relations[field]
isRelation := ok && relation.JoinTable != nil isRelation := ok && relation.JoinTable != nil
if !isRelation { if !isRelation {
return fmt.Errorf("failed to find relation: %s", field) return fmt.Errorf("failed to found relation: %s", field)
} }
for _, ref := range relation.References { for _, ref := range relation.References {
@ -515,7 +430,6 @@ func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interfac
return nil return nil
} }
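A usage sketch for SetupJoinTable with hypothetical models: Person carries a many2many Addresses association and PersonAddress is the custom join table with extra columns (an import of time is assumed).

type Person struct {
    ID        uint
    Addresses []Address `gorm:"many2many:person_addresses"`
}

type Address struct {
    ID uint
}

type PersonAddress struct {
    PersonID  uint `gorm:"primaryKey"`
    AddressID uint `gorm:"primaryKey"`
    CreatedAt time.Time
    DeletedAt gorm.DeletedAt
}

func exampleSetupJoinTable(db *gorm.DB) error {
    // After this call the association uses PersonAddress as its join model,
    // so its extra columns and hooks take effect.
    return db.SetupJoinTable(&Person{}, "Addresses", &PersonAddress{})
}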
// Use use plugin
func (db *DB) Use(plugin Plugin) error { func (db *DB) Use(plugin Plugin) error {
name := plugin.Name() name := plugin.Name()
if _, ok := db.Plugins[name]; ok { if _, ok := db.Plugins[name]; ok {
@ -527,18 +441,3 @@ func (db *DB) Use(plugin Plugin) error {
db.Plugins[name] = plugin db.Plugins[name] = plugin
return nil return nil
} }
// ToSQL for generate SQL string.
//
// db.ToSQL(func(tx *gorm.DB) *gorm.DB {
// return tx.Model(&User{}).Where(&User{Name: "foo", Age: 20})
// .Limit(10).Offset(5)
// .Order("name ASC")
// .First(&User{})
// })
func (db *DB) ToSQL(queryFn func(tx *DB) *DB) string {
tx := queryFn(db.Session(&Session{DryRun: true, SkipDefaultTransaction: true}).getInstance())
stmt := tx.Statement
return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...)
}


@ -26,10 +26,6 @@ type Plugin interface {
Initialize(*DB) error Initialize(*DB) error
} }
type ParamsFilter interface {
ParamsFilter(ctx context.Context, sql string, params ...interface{}) (string, []interface{})
}
// ConnPool db conns pool interface // ConnPool db conns pool interface
type ConnPool interface { type ConnPool interface {
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
@ -44,49 +40,24 @@ type SavePointerDialectorInterface interface {
RollbackTo(tx *DB, name string) error RollbackTo(tx *DB, name string) error
} }
// TxBeginner tx beginner
type TxBeginner interface { type TxBeginner interface {
BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
} }
// ConnPoolBeginner conn pool beginner
type ConnPoolBeginner interface { type ConnPoolBeginner interface {
BeginTx(ctx context.Context, opts *sql.TxOptions) (ConnPool, error) BeginTx(ctx context.Context, opts *sql.TxOptions) (ConnPool, error)
} }
// TxCommitter tx committer
type TxCommitter interface { type TxCommitter interface {
Commit() error Commit() error
Rollback() error Rollback() error
} }
// Tx sql.Tx interface
type Tx interface {
ConnPool
TxCommitter
StmtContext(ctx context.Context, stmt *sql.Stmt) *sql.Stmt
}
// Valuer gorm valuer interface // Valuer gorm valuer interface
type Valuer interface { type Valuer interface {
GormValue(context.Context, *DB) clause.Expr GormValue(context.Context, *DB) clause.Expr
} }
// GetDBConnector SQL db connector
type GetDBConnector interface { type GetDBConnector interface {
GetDBConn() (*sql.DB, error) GetDBConn() (*sql.DB, error)
} }
// Rows rows interface
type Rows interface {
Columns() ([]string, error)
ColumnTypes() ([]*sql.ColumnType, error)
Next() bool
Scan(dest ...interface{}) error
Err() error
Close() error
}
type ErrorTranslator interface {
Translate(err error) error
}


@ -1,493 +0,0 @@
package lru
// golang -lru
// https://github.com/hashicorp/golang-lru
import (
"sync"
"time"
)
// EvictCallback is used to get a callback when a cache entry is evicted
type EvictCallback[K comparable, V any] func(key K, value V)
// LRU implements a thread-safe LRU with expirable entries.
type LRU[K comparable, V any] struct {
size int
evictList *LruList[K, V]
items map[K]*Entry[K, V]
onEvict EvictCallback[K, V]
// expirable options
mu sync.Mutex
ttl time.Duration
done chan struct{}
// buckets for expiration
buckets []bucket[K, V]
// uint8 because it's number between 0 and numBuckets
nextCleanupBucket uint8
}
// bucket is a container for holding entries to be expired
type bucket[K comparable, V any] struct {
entries map[K]*Entry[K, V]
newestEntry time.Time
}
// noEvictionTTL - very long ttl to prevent eviction
const noEvictionTTL = time.Hour * 24 * 365 * 10
// because of uint8 usage for nextCleanupBucket, should not exceed 256.
// casting it as uint8 explicitly requires type conversions in multiple places
const numBuckets = 100
// NewLRU returns a new thread-safe cache with expirable entries.
//
// Size parameter set to 0 makes cache of unlimited size, e.g. turns LRU mechanism off.
//
// Providing 0 TTL turns expiring off.
//
// Delete expired entries every 1/100th of ttl value. Goroutine which deletes expired entries runs indefinitely.
func NewLRU[K comparable, V any](size int, onEvict EvictCallback[K, V], ttl time.Duration) *LRU[K, V] {
if size < 0 {
size = 0
}
if ttl <= 0 {
ttl = noEvictionTTL
}
res := LRU[K, V]{
ttl: ttl,
size: size,
evictList: NewList[K, V](),
items: make(map[K]*Entry[K, V]),
onEvict: onEvict,
done: make(chan struct{}),
}
// initialize the buckets
res.buckets = make([]bucket[K, V], numBuckets)
for i := 0; i < numBuckets; i++ {
res.buckets[i] = bucket[K, V]{entries: make(map[K]*Entry[K, V])}
}
// enable deleteExpired() running in separate goroutine for cache with non-zero TTL
//
// Important: done channel is never closed, so deleteExpired() goroutine will never exit,
// it's decided to add functionality to close it in the version later than v2.
if res.ttl != noEvictionTTL {
go func(done <-chan struct{}) {
ticker := time.NewTicker(res.ttl / numBuckets)
defer ticker.Stop()
for {
select {
case <-done:
return
case <-ticker.C:
res.deleteExpired()
}
}
}(res.done)
}
return &res
}
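A small sketch of the expirable LRU above. The package lives under gorm.io/gorm/internal/lru, so it is not importable outside the GORM module; treat this as illustrative only (imports of fmt and time assumed).

func exampleLRU() {
    onEvict := func(key string, value int) {
        fmt.Println("evicted", key, value)
    }
    // Capacity 2; entries expire one minute after they were last added.
    cache := lru.NewLRU[string, int](2, onEvict, time.Minute)

    cache.Add("a", 1)
    cache.Add("b", 2)
    cache.Add("c", 3) // evicts "a", the least recently used entry

    if v, ok := cache.Get("b"); ok {
        fmt.Println("b =", v)
    }
}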
// Purge clears the cache completely.
// onEvict is called for each evicted key.
func (c *LRU[K, V]) Purge() {
c.mu.Lock()
defer c.mu.Unlock()
for k, v := range c.items {
if c.onEvict != nil {
c.onEvict(k, v.Value)
}
delete(c.items, k)
}
for _, b := range c.buckets {
for _, ent := range b.entries {
delete(b.entries, ent.Key)
}
}
c.evictList.Init()
}
// Add adds a value to the cache. Returns true if an eviction occurred.
// Returns false if there was no eviction: the item was already in the cache,
// or the size was not exceeded.
func (c *LRU[K, V]) Add(key K, value V) (evicted bool) {
c.mu.Lock()
defer c.mu.Unlock()
now := time.Now()
// Check for existing item
if ent, ok := c.items[key]; ok {
c.evictList.MoveToFront(ent)
c.removeFromBucket(ent) // remove the entry from its current bucket as expiresAt is renewed
ent.Value = value
ent.ExpiresAt = now.Add(c.ttl)
c.addToBucket(ent)
return false
}
// Add new item
ent := c.evictList.PushFrontExpirable(key, value, now.Add(c.ttl))
c.items[key] = ent
c.addToBucket(ent) // adds the entry to the appropriate bucket and sets entry.expireBucket
evict := c.size > 0 && c.evictList.Length() > c.size
// Verify size not exceeded
if evict {
c.removeOldest()
}
return evict
}
// Get looks up a key's value from the cache.
func (c *LRU[K, V]) Get(key K) (value V, ok bool) {
c.mu.Lock()
defer c.mu.Unlock()
var ent *Entry[K, V]
if ent, ok = c.items[key]; ok {
// Expired item check
if time.Now().After(ent.ExpiresAt) {
return value, false
}
c.evictList.MoveToFront(ent)
return ent.Value, true
}
return
}
// Contains checks if a key is in the cache, without updating the recent-ness
// or deleting it for being stale.
func (c *LRU[K, V]) Contains(key K) (ok bool) {
c.mu.Lock()
defer c.mu.Unlock()
_, ok = c.items[key]
return ok
}
// Peek returns the key value (or undefined if not found) without updating
// the "recently used"-ness of the key.
func (c *LRU[K, V]) Peek(key K) (value V, ok bool) {
c.mu.Lock()
defer c.mu.Unlock()
var ent *Entry[K, V]
if ent, ok = c.items[key]; ok {
// Expired item check
if time.Now().After(ent.ExpiresAt) {
return value, false
}
return ent.Value, true
}
return
}
// Remove removes the provided key from the cache, returning if the
// key was contained.
func (c *LRU[K, V]) Remove(key K) bool {
c.mu.Lock()
defer c.mu.Unlock()
if ent, ok := c.items[key]; ok {
c.removeElement(ent)
return true
}
return false
}
// RemoveOldest removes the oldest item from the cache.
func (c *LRU[K, V]) RemoveOldest() (key K, value V, ok bool) {
c.mu.Lock()
defer c.mu.Unlock()
if ent := c.evictList.Back(); ent != nil {
c.removeElement(ent)
return ent.Key, ent.Value, true
}
return
}
// GetOldest returns the oldest entry
func (c *LRU[K, V]) GetOldest() (key K, value V, ok bool) {
c.mu.Lock()
defer c.mu.Unlock()
if ent := c.evictList.Back(); ent != nil {
return ent.Key, ent.Value, true
}
return
}
func (c *LRU[K, V]) KeyValues() map[K]V {
c.mu.Lock()
defer c.mu.Unlock()
maps := make(map[K]V)
now := time.Now()
for ent := c.evictList.Back(); ent != nil; ent = ent.PrevEntry() {
if now.After(ent.ExpiresAt) {
continue
}
maps[ent.Key] = ent.Value
// keys = append(keys, ent.Key)
}
return maps
}
// Keys returns a slice of the keys in the cache, from oldest to newest.
// Expired entries are filtered out.
func (c *LRU[K, V]) Keys() []K {
c.mu.Lock()
defer c.mu.Unlock()
keys := make([]K, 0, len(c.items))
now := time.Now()
for ent := c.evictList.Back(); ent != nil; ent = ent.PrevEntry() {
if now.After(ent.ExpiresAt) {
continue
}
keys = append(keys, ent.Key)
}
return keys
}
// Values returns a slice of the values in the cache, from oldest to newest.
// Expired entries are filtered out.
func (c *LRU[K, V]) Values() []V {
c.mu.Lock()
defer c.mu.Unlock()
values := make([]V, 0, len(c.items))
now := time.Now()
for ent := c.evictList.Back(); ent != nil; ent = ent.PrevEntry() {
if now.After(ent.ExpiresAt) {
continue
}
values = append(values, ent.Value)
}
return values
}
// Len returns the number of items in the cache.
func (c *LRU[K, V]) Len() int {
c.mu.Lock()
defer c.mu.Unlock()
return c.evictList.Length()
}
// Resize changes the cache size. Size of 0 means unlimited.
func (c *LRU[K, V]) Resize(size int) (evicted int) {
c.mu.Lock()
defer c.mu.Unlock()
if size <= 0 {
c.size = 0
return 0
}
diff := c.evictList.Length() - size
if diff < 0 {
diff = 0
}
for i := 0; i < diff; i++ {
c.removeOldest()
}
c.size = size
return diff
}
// Close destroys cleanup goroutine. To clean up the cache, run Purge() before Close().
// func (c *LRU[K, V]) Close() {
// c.mu.Lock()
// defer c.mu.Unlock()
// select {
// case <-c.done:
// return
// default:
// }
// close(c.done)
// }
// removeOldest removes the oldest item from the cache. Has to be called with lock!
func (c *LRU[K, V]) removeOldest() {
if ent := c.evictList.Back(); ent != nil {
c.removeElement(ent)
}
}
// removeElement is used to remove a given list element from the cache. Has to be called with lock!
func (c *LRU[K, V]) removeElement(e *Entry[K, V]) {
c.evictList.Remove(e)
delete(c.items, e.Key)
c.removeFromBucket(e)
if c.onEvict != nil {
c.onEvict(e.Key, e.Value)
}
}
// deleteExpired deletes expired records from the oldest bucket, waiting for the newest entry
// in it to expire first.
func (c *LRU[K, V]) deleteExpired() {
c.mu.Lock()
bucketIdx := c.nextCleanupBucket
timeToExpire := time.Until(c.buckets[bucketIdx].newestEntry)
// wait for newest entry to expire before cleanup without holding lock
if timeToExpire > 0 {
c.mu.Unlock()
time.Sleep(timeToExpire)
c.mu.Lock()
}
for _, ent := range c.buckets[bucketIdx].entries {
c.removeElement(ent)
}
c.nextCleanupBucket = (c.nextCleanupBucket + 1) % numBuckets
c.mu.Unlock()
}
// addToBucket adds entry to expire bucket so that it will be cleaned up when the time comes. Has to be called with lock!
func (c *LRU[K, V]) addToBucket(e *Entry[K, V]) {
bucketID := (numBuckets + c.nextCleanupBucket - 1) % numBuckets
e.ExpireBucket = bucketID
c.buckets[bucketID].entries[e.Key] = e
if c.buckets[bucketID].newestEntry.Before(e.ExpiresAt) {
c.buckets[bucketID].newestEntry = e.ExpiresAt
}
}
// removeFromBucket removes the entry from its corresponding bucket. Has to be called with lock!
func (c *LRU[K, V]) removeFromBucket(e *Entry[K, V]) {
delete(c.buckets[e.ExpireBucket].entries, e.Key)
}
// Cap returns the capacity of the cache
func (c *LRU[K, V]) Cap() int {
return c.size
}
// Entry is an LRU Entry
type Entry[K comparable, V any] struct {
// Next and previous pointers in the doubly-linked list of elements.
// To simplify the implementation, internally a list l is implemented
// as a ring, such that &l.root is both the next element of the last
// list element (l.Back()) and the previous element of the first list
// element (l.Front()).
next, prev *Entry[K, V]
// The list to which this element belongs.
list *LruList[K, V]
// The LRU Key of this element.
Key K
// The Value stored with this element.
Value V
// The time this element would be cleaned up, optional
ExpiresAt time.Time
// The expiry bucket item was put in, optional
ExpireBucket uint8
}
// PrevEntry returns the previous list element or nil.
func (e *Entry[K, V]) PrevEntry() *Entry[K, V] {
if p := e.prev; e.list != nil && p != &e.list.root {
return p
}
return nil
}
// LruList represents a doubly linked list.
// The zero Value for LruList is an empty list ready to use.
type LruList[K comparable, V any] struct {
root Entry[K, V] // sentinel list element, only &root, root.prev, and root.next are used
len int // current list Length excluding (this) sentinel element
}
// Init initializes or clears list l.
func (l *LruList[K, V]) Init() *LruList[K, V] {
l.root.next = &l.root
l.root.prev = &l.root
l.len = 0
return l
}
// NewList returns an initialized list.
func NewList[K comparable, V any]() *LruList[K, V] { return new(LruList[K, V]).Init() }
// Length returns the number of elements of list l.
// The complexity is O(1).
func (l *LruList[K, V]) Length() int { return l.len }
// Back returns the last element of list l or nil if the list is empty.
func (l *LruList[K, V]) Back() *Entry[K, V] {
if l.len == 0 {
return nil
}
return l.root.prev
}
// lazyInit lazily initializes a zero List Value.
func (l *LruList[K, V]) lazyInit() {
if l.root.next == nil {
l.Init()
}
}
// insert inserts e after at, increments l.len, and returns e.
func (l *LruList[K, V]) insert(e, at *Entry[K, V]) *Entry[K, V] {
e.prev = at
e.next = at.next
e.prev.next = e
e.next.prev = e
e.list = l
l.len++
return e
}
// insertValue is a convenience wrapper for insert(&Entry{Value: v, ExpiresAt: ExpiresAt}, at).
func (l *LruList[K, V]) insertValue(k K, v V, expiresAt time.Time, at *Entry[K, V]) *Entry[K, V] {
return l.insert(&Entry[K, V]{Value: v, Key: k, ExpiresAt: expiresAt}, at)
}
// Remove removes e from its list, decrements l.len
func (l *LruList[K, V]) Remove(e *Entry[K, V]) V {
e.prev.next = e.next
e.next.prev = e.prev
e.next = nil // avoid memory leaks
e.prev = nil // avoid memory leaks
e.list = nil
l.len--
return e.Value
}
// move moves e to next to at.
func (l *LruList[K, V]) move(e, at *Entry[K, V]) {
if e == at {
return
}
e.prev.next = e.next
e.next.prev = e.prev
e.prev = at
e.next = at.next
e.prev.next = e
e.next.prev = e
}
// PushFront inserts a new element e with value v at the front of list l and returns e.
func (l *LruList[K, V]) PushFront(k K, v V) *Entry[K, V] {
l.lazyInit()
return l.insertValue(k, v, time.Time{}, &l.root)
}
// PushFrontExpirable inserts a new expirable element e with Value v at the front of list l and returns e.
func (l *LruList[K, V]) PushFrontExpirable(k K, v V, expiresAt time.Time) *Entry[K, V] {
l.lazyInit()
return l.insertValue(k, v, expiresAt, &l.root)
}
// MoveToFront moves element e to the front of list l.
// If e is not an element of l, the list is not modified.
// The element must not be nil.
func (l *LruList[K, V]) MoveToFront(e *Entry[K, V]) {
if e.list != l || l.root.next == e {
return
}
// see comment in List.Remove about initialization of l
l.move(e, &l.root)
}


@ -1,183 +0,0 @@
package stmt_store
import (
"context"
"database/sql"
"math"
"sync"
"time"
"gorm.io/gorm/internal/lru"
)
type Stmt struct {
*sql.Stmt
Transaction bool
prepared chan struct{}
prepareErr error
}
func (stmt *Stmt) Error() error {
return stmt.prepareErr
}
func (stmt *Stmt) Close() error {
<-stmt.prepared
if stmt.Stmt != nil {
return stmt.Stmt.Close()
}
return nil
}
// Store defines an interface for managing the caching operations of SQL statements (Stmt).
// This interface provides methods for creating new statements, retrieving all cache keys,
// getting cached statements, setting cached statements, and deleting cached statements.
type Store interface {
// New creates a new Stmt object and caches it.
// Parameters:
// ctx: The context for the request, which can carry deadlines, cancellation signals, etc.
// key: The key representing the SQL query, used for caching and preparing the statement.
// isTransaction: Indicates whether this operation is part of a transaction, which may affect the caching strategy.
// connPool: A connection pool that provides database connections.
// locker: A synchronization lock that is unlocked after initialization to avoid deadlocks.
// Returns:
// *Stmt: A newly created statement object for executing SQL operations.
// error: An error if the statement preparation fails.
New(ctx context.Context, key string, isTransaction bool, connPool ConnPool, locker sync.Locker) (*Stmt, error)
// Keys returns a slice of all cache keys in the store.
Keys() []string
// Get retrieves a Stmt object from the store based on the given key.
// Parameters:
// key: The key used to look up the Stmt object.
// Returns:
// *Stmt: The found Stmt object, or nil if not found.
// bool: Indicates whether the corresponding Stmt object was successfully found.
Get(key string) (*Stmt, bool)
// Set stores the given Stmt object in the store and associates it with the specified key.
// Parameters:
// key: The key used to associate the Stmt object.
// value: The Stmt object to be stored.
Set(key string, value *Stmt)
// Delete removes the Stmt object corresponding to the specified key from the store.
// Parameters:
// key: The key associated with the Stmt object to be deleted.
Delete(key string)
}
// defaultMaxSize defines the default maximum capacity of the cache.
// Its value is math.MaxInt, the largest value the int type can hold, which means that when the
// cache size is not specified, the cache can effectively store an unbounded number of elements.
const (
defaultMaxSize = math.MaxInt
// defaultTTL defines the default time-to-live (TTL) for each cache entry.
// When the TTL for cache entries is not specified, each cache entry will expire after 24 hours.
defaultTTL = time.Hour * 24
)
// New creates and returns a new Store instance.
//
// Parameters:
// - size: The maximum capacity of the cache. If the provided size is less than or equal to 0,
// it defaults to defaultMaxSize.
// - ttl: The time-to-live duration for each cache entry. If the provided ttl is less than or equal to 0,
// it defaults to defaultTTL.
//
// This function defines an onEvicted callback that is invoked when a cache entry is evicted.
// The callback ensures that if the evicted value (v) is not nil, its Close method is called asynchronously
// to release associated resources.
//
// Returns:
// - A Store instance implemented by lruStore, which internally uses an LRU cache with the specified size,
// eviction callback, and TTL.
func New(size int, ttl time.Duration) Store {
if size <= 0 {
size = defaultMaxSize
}
if ttl <= 0 {
ttl = defaultTTL
}
onEvicted := func(k string, v *Stmt) {
if v != nil {
go v.Close()
}
}
return &lruStore{lru: lru.NewLRU[string, *Stmt](size, onEvicted, ttl)}
}
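
For reference, a small sketch of how the constructor defaults behave; the helper name newStores is illustrative and the code is assumed to live in the same internal package:

package stmt_store

import "time"

// newStores shows the two typical ways to call New: non-positive arguments
// fall back to defaultMaxSize and defaultTTL, positive arguments bound the
// cache explicitly.
func newStores() (Store, Store) {
    defaulted := New(0, 0)             // defaultMaxSize entries, 24h TTL
    bounded := New(128, 5*time.Minute) // at most 128 prepared statements, each expiring after 5 minutes
    return defaulted, bounded
}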
type lruStore struct {
lru *lru.LRU[string, *Stmt]
}
func (s *lruStore) Keys() []string {
return s.lru.Keys()
}
func (s *lruStore) Get(key string) (*Stmt, bool) {
stmt, ok := s.lru.Get(key)
if ok && stmt != nil {
<-stmt.prepared
}
return stmt, ok
}
func (s *lruStore) Set(key string, value *Stmt) {
s.lru.Add(key, value)
}
func (s *lruStore) Delete(key string) {
s.lru.Remove(key)
}
type ConnPool interface {
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
}
// New creates a new Stmt object for executing SQL queries.
// It caches the Stmt object for future use and handles preparation and error states.
// Parameters:
//
// ctx: Context for the request, used to carry deadlines, cancellation signals, etc.
// key: The key representing the SQL query, used for caching and preparing the statement.
// isTransaction: Indicates whether this operation is part of a transaction, affecting cache strategy.
// conn: A connection pool that provides database connections.
// locker: A synchronization lock that is unlocked after initialization to avoid deadlocks.
//
// Returns:
//
// *Stmt: A newly created statement object for executing SQL operations.
// error: An error if the statement preparation fails.
func (s *lruStore) New(ctx context.Context, key string, isTransaction bool, conn ConnPool, locker sync.Locker) (_ *Stmt, err error) {
// Create a Stmt object and set its Transaction property.
// The prepared channel is used to synchronize the statement preparation state.
cacheStmt := &Stmt{
Transaction: isTransaction,
prepared: make(chan struct{}),
}
// Cache the Stmt object with the associated key.
s.Set(key, cacheStmt)
// Unlock after completing initialization to prevent deadlocks.
locker.Unlock()
// Ensure the prepared channel is closed after the function execution completes.
defer close(cacheStmt.prepared)
// Prepare the SQL statement using the provided connection.
cacheStmt.Stmt, err = conn.PrepareContext(ctx, key)
if err != nil {
// If statement preparation fails, record the error and remove the invalid Stmt object from the cache.
cacheStmt.prepareErr = err
s.Delete(key)
return &Stmt{}, err
}
// Return the successfully prepared Stmt object.
return cacheStmt, nil
}
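
To make the locking contract above concrete, here is a hedged sketch of the caller side; prepareOrReuse is an illustrative name, not part of the package. The caller hands the still-locked mutex to New, which stores a placeholder, unlocks, and prepares, while Get blocks concurrent callers on the prepared channel:

package stmt_store

import (
    "context"
    "database/sql"
    "sync"
)

// prepareOrReuse looks the query up under the lock; on a miss it hands the
// still-locked mutex to New, which unlocks it after caching a placeholder
// and then prepares the statement. Get blocks until that preparation is done.
func prepareOrReuse(ctx context.Context, s Store, mu *sync.Mutex, conn ConnPool, query string) (*sql.Stmt, error) {
    mu.Lock()
    if cached, ok := s.Get(query); ok {
        mu.Unlock()
        if err := cached.Error(); err != nil {
            return nil, err
        }
        return cached.Stmt, nil
    }

    // New unlocks mu itself once the placeholder entry has been stored.
    stmt, err := s.New(ctx, query, false, conn, mu)
    if err != nil {
        return nil, err
    }
    return stmt.Stmt, nil
}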

View File

@ -4,7 +4,7 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"io" "io/ioutil"
"log" "log"
"os" "os"
"time" "time"
@ -12,7 +12,6 @@ import (
"gorm.io/gorm/utils" "gorm.io/gorm/utils"
) )
// ErrRecordNotFound record not found error
var ErrRecordNotFound = errors.New("record not found") var ErrRecordNotFound = errors.New("record not found")
// Colors // Colors
@ -31,17 +30,13 @@ const (
YellowBold = "\033[33;1m" YellowBold = "\033[33;1m"
) )
// LogLevel log level // LogLevel
type LogLevel int type LogLevel int
const ( const (
// Silent silent log level
Silent LogLevel = iota + 1 Silent LogLevel = iota + 1
// Error error log level
Error Error
// Warn warn log level
Warn Warn
// Info info log level
Info Info
) )
@ -50,12 +45,10 @@ type Writer interface {
Printf(string, ...interface{}) Printf(string, ...interface{})
} }
// Config logger config
type Config struct { type Config struct {
SlowThreshold time.Duration SlowThreshold time.Duration
Colorful bool Colorful bool
IgnoreRecordNotFoundError bool IgnoreRecordNotFoundError bool
ParameterizedQueries bool
LogLevel LogLevel LogLevel LogLevel
} }
@ -69,25 +62,16 @@ type Interface interface {
} }
var ( var (
// Discard logger will print any log to io.Discard Discard = New(log.New(ioutil.Discard, "", log.LstdFlags), Config{})
Discard = New(log.New(io.Discard, "", log.LstdFlags), Config{})
// Default Default logger
Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{ Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{
SlowThreshold: 200 * time.Millisecond, SlowThreshold: 200 * time.Millisecond,
LogLevel: Warn, LogLevel: Warn,
IgnoreRecordNotFoundError: false, IgnoreRecordNotFoundError: false,
Colorful: true, Colorful: true,
}) })
// Recorder logger records running SQL into a recorder instance
Recorder = traceRecorder{Interface: Default, BeginAt: time.Now()} Recorder = traceRecorder{Interface: Default, BeginAt: time.Now()}
// RecorderParamsFilter defaults to a no-op and can be overridden by a different implementation
RecorderParamsFilter = func(ctx context.Context, sql string, params ...interface{}) (string, []interface{}) {
return sql, params
}
) )
// New initializes a logger with the given writer and config
func New(writer Writer, config Config) Interface { func New(writer Writer, config Config) Interface {
var ( var (
infoStr = "%s\n[info] " infoStr = "%s\n[info] "
@ -134,30 +118,29 @@ func (l *logger) LogMode(level LogLevel) Interface {
} }
// Info print info // Info print info
func (l *logger) Info(ctx context.Context, msg string, data ...interface{}) { func (l logger) Info(ctx context.Context, msg string, data ...interface{}) {
if l.LogLevel >= Info { if l.LogLevel >= Info {
l.Printf(l.infoStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...) l.Printf(l.infoStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...)
} }
} }
// Warn print warn messages // Warn print warn messages
func (l *logger) Warn(ctx context.Context, msg string, data ...interface{}) { func (l logger) Warn(ctx context.Context, msg string, data ...interface{}) {
if l.LogLevel >= Warn { if l.LogLevel >= Warn {
l.Printf(l.warnStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...) l.Printf(l.warnStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...)
} }
} }
// Error print error messages // Error print error messages
func (l *logger) Error(ctx context.Context, msg string, data ...interface{}) { func (l logger) Error(ctx context.Context, msg string, data ...interface{}) {
if l.LogLevel >= Error { if l.LogLevel >= Error {
l.Printf(l.errStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...) l.Printf(l.errStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...)
} }
} }
// Trace print sql message // Trace print sql message
// func (l logger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) {
//nolint:cyclop
func (l *logger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) {
if l.LogLevel <= Silent { if l.LogLevel <= Silent {
return return
} }
@ -189,14 +172,6 @@ func (l *logger) Trace(ctx context.Context, begin time.Time, fc func() (string,
} }
} }
// ParamsFilter filter params
func (l *logger) ParamsFilter(ctx context.Context, sql string, params ...interface{}) (string, []interface{}) {
if l.Config.ParameterizedQueries {
return sql, nil
}
return sql, params
}
type traceRecorder struct { type traceRecorder struct {
Interface Interface
BeginAt time.Time BeginAt time.Time
@ -205,21 +180,12 @@ type traceRecorder struct {
Err error Err error
} }
// New trace recorder func (l traceRecorder) New() *traceRecorder {
func (l *traceRecorder) New() *traceRecorder {
return &traceRecorder{Interface: l.Interface, BeginAt: time.Now()} return &traceRecorder{Interface: l.Interface, BeginAt: time.Now()}
} }
// Trace implements the logger interface
func (l *traceRecorder) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) { func (l *traceRecorder) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) {
l.BeginAt = begin l.BeginAt = begin
l.SQL, l.RowsAffected = fc() l.SQL, l.RowsAffected = fc()
l.Err = err l.Err = err
} }
func (l *traceRecorder) ParamsFilter(ctx context.Context, sql string, params ...interface{}) (string, []interface{}) {
if RecorderParamsFilter == nil {
return sql, params
}
return RecorderParamsFilter(ctx, sql, params...)
}
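
Below is a hedged sketch of how the configuration shown in this file is typically wired into a *gorm.DB on the master side; the package layout and function names are illustrative, and the dialector is left to the caller:

package example

import (
    "log"
    "os"
    "time"

    "gorm.io/gorm"
    "gorm.io/gorm/logger"
)

// newLogger builds a logger.Interface from the knobs shown above.
func newLogger() logger.Interface {
    return logger.New(
        log.New(os.Stdout, "\r\n", log.LstdFlags), // any Writer with Printf works
        logger.Config{
            SlowThreshold:             200 * time.Millisecond, // queries slower than this are reported as SLOW SQL
            LogLevel:                  logger.Warn,
            IgnoreRecordNotFoundError: true, // skip traces for logger.ErrRecordNotFound
            ParameterizedQueries:      true, // log SQL with placeholders instead of bound values
            Colorful:                  false,
        },
    )
}

// open plugs the logger into a session; the dialector is supplied by the caller.
func open(dialector gorm.Dialector) (*gorm.DB, error) {
    return gorm.Open(dialector, &gorm.Config{Logger: newLogger()})
}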

View File

@ -19,40 +19,20 @@ const (
nullStr = "NULL" nullStr = "NULL"
) )
func isPrintable(s string) bool { func isPrintable(s []byte) bool {
for _, r := range s { for _, r := range s {
if !unicode.IsPrint(r) { if !unicode.IsPrint(rune(r)) {
return false return false
} }
} }
return true return true
} }
// A list of Go types that should be converted to SQL primitives
var convertibleTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeOf(false), reflect.TypeOf([]byte{})} var convertibleTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeOf(false), reflect.TypeOf([]byte{})}
// numericPlaceholderRe matches numeric placeholder tokens of the form $1$, $2$, ...
var numericPlaceholderRe = regexp.MustCompile(`\$\d+\$`)
func isNumeric(k reflect.Kind) bool {
switch k {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return true
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return true
case reflect.Float32, reflect.Float64:
return true
default:
return false
}
}
// ExplainSQL generates a SQL string with the given parameters; the generated SQL is only intended for the logger, executing it might introduce a SQL injection vulnerability
func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, avars ...interface{}) string { func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, avars ...interface{}) string {
var ( var convertParams func(interface{}, int)
convertParams func(interface{}, int) var vars = make([]string, len(avars))
vars = make([]string, len(avars))
)
convertParams = func(v interface{}, idx int) { convertParams = func(v interface{}, idx int) {
switch v := v.(type) { switch v := v.(type) {
@ -84,36 +64,23 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a
} }
case fmt.Stringer: case fmt.Stringer:
reflectValue := reflect.ValueOf(v) reflectValue := reflect.ValueOf(v)
switch reflectValue.Kind() { if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: vars[idx] = escaper + strings.Replace(fmt.Sprintf("%v", v), escaper, "\\"+escaper, -1) + escaper
vars[idx] = fmt.Sprintf("%d", reflectValue.Interface()) } else {
case reflect.Float32, reflect.Float64: vars[idx] = nullStr
vars[idx] = fmt.Sprintf("%.6f", reflectValue.Interface())
case reflect.Bool:
vars[idx] = fmt.Sprintf("%t", reflectValue.Interface())
case reflect.String:
vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, escaper+escaper) + escaper
default:
if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) {
vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, escaper+escaper) + escaper
} else {
vars[idx] = nullStr
}
} }
case []byte: case []byte:
if s := string(v); isPrintable(s) { if isPrintable(v) {
vars[idx] = escaper + strings.ReplaceAll(s, escaper, escaper+escaper) + escaper vars[idx] = escaper + strings.Replace(string(v), escaper, "\\"+escaper, -1) + escaper
} else { } else {
vars[idx] = escaper + "<binary>" + escaper vars[idx] = escaper + "<binary>" + escaper
} }
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
vars[idx] = utils.ToString(v) vars[idx] = utils.ToString(v)
case float32: case float64, float32:
vars[idx] = strconv.FormatFloat(float64(v), 'f', -1, 32) vars[idx] = fmt.Sprintf("%.6f", v)
case float64:
vars[idx] = strconv.FormatFloat(v, 'f', -1, 64)
case string: case string:
vars[idx] = escaper + strings.ReplaceAll(v, escaper, escaper+escaper) + escaper vars[idx] = escaper + strings.Replace(v, escaper, "\\"+escaper, -1) + escaper
default: default:
rv := reflect.ValueOf(v) rv := reflect.ValueOf(v)
if v == nil || !rv.IsValid() || rv.Kind() == reflect.Ptr && rv.IsNil() { if v == nil || !rv.IsValid() || rv.Kind() == reflect.Ptr && rv.IsNil() {
@ -123,12 +90,6 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a
convertParams(v, idx) convertParams(v, idx)
} else if rv.Kind() == reflect.Ptr && !rv.IsZero() { } else if rv.Kind() == reflect.Ptr && !rv.IsZero() {
convertParams(reflect.Indirect(rv).Interface(), idx) convertParams(reflect.Indirect(rv).Interface(), idx)
} else if isNumeric(rv.Kind()) {
if rv.CanInt() || rv.CanUint() {
vars[idx] = fmt.Sprintf("%d", rv.Interface())
} else {
vars[idx] = fmt.Sprintf("%.6f", rv.Interface())
}
} else { } else {
for _, t := range convertibleTypes { for _, t := range convertibleTypes {
if rv.Type().ConvertibleTo(t) { if rv.Type().ConvertibleTo(t) {
@ -136,7 +97,7 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a
return return
} }
} }
vars[idx] = escaper + strings.ReplaceAll(fmt.Sprint(v), escaper, escaper+escaper) + escaper vars[idx] = escaper + strings.Replace(fmt.Sprint(v), escaper, "\\"+escaper, -1) + escaper
} }
} }
} }
@ -163,18 +124,9 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a
sql = newSQL.String() sql = newSQL.String()
} else { } else {
sql = numericPlaceholder.ReplaceAllString(sql, "$$$1$$") sql = numericPlaceholder.ReplaceAllString(sql, "$$$1$$")
for idx, v := range vars {
sql = numericPlaceholderRe.ReplaceAllStringFunc(sql, func(v string) string { sql = strings.Replace(sql, "$"+strconv.Itoa(idx+1)+"$", v, 1)
num := v[1 : len(v)-1] }
n, _ := strconv.Atoi(num)
// positional vars start from 1 ($1, $2)
n -= 1
if n >= 0 && n <= len(vars)-1 {
return vars[n]
}
return v
})
} }
return sql return sql
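
To make the placeholder handling concrete, a small usage sketch of the master-side ExplainSQL; the expected output in the comments follows from the code above:

package example

import (
    "fmt"
    "regexp"

    "gorm.io/gorm/logger"
)

// explainExamples shows both placeholder styles handled by ExplainSQL.
func explainExamples() {
    // `?` placeholders are filled in argument order.
    fmt.Println(logger.ExplainSQL(
        "SELECT * FROM users WHERE name = ? AND age > ?",
        nil, `"`, "jinzhu", 18,
    ))
    // SELECT * FROM users WHERE name = "jinzhu" AND age > 18

    // Numeric placeholders ($1, $2, ...) are mapped by position.
    fmt.Println(logger.ExplainSQL(
        "SELECT * FROM users WHERE age > $2 AND name = $1",
        regexp.MustCompile(`\$(\d+)`), `'`, "jinzhu", 18,
    ))
    // SELECT * FROM users WHERE age > 18 AND name = 'jinzhu'
}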

View File

@ -31,24 +31,20 @@ func (s ExampleStruct) Value() (driver.Value, error) {
} }
func format(v []byte, escaper string) string { func format(v []byte, escaper string) string {
return escaper + strings.ReplaceAll(string(v), escaper, escaper+escaper) + escaper return escaper + strings.Replace(string(v), escaper, "\\"+escaper, -1) + escaper
} }
func TestExplainSQL(t *testing.T) { func TestExplainSQL(t *testing.T) {
type role string type role string
type password []byte type password []byte
type intType int
type floatType float64
var ( var (
tt = now.MustParse("2020-02-23 11:10:10") tt = now.MustParse("2020-02-23 11:10:10")
myrole = role("admin") myrole = role("admin")
pwd = password("pass") pwd = password([]byte("pass"))
jsVal = []byte(`{"Name":"test","Val":"test"}`) jsVal = []byte(`{"Name":"test","Val":"test"}`)
js = JSON(jsVal) js = JSON(jsVal)
esVal = []byte(`{"Name":"test","Val":"test"}`) esVal = []byte(`{"Name":"test","Val":"test"}`)
es = ExampleStruct{Name: "test", Val: "test"} es = ExampleStruct{Name: "test", Val: "test"}
intVal intType = 1
floatVal floatType = 1.23
) )
results := []struct { results := []struct {
@ -61,67 +57,43 @@ func TestExplainSQL(t *testing.T) {
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
NumericRegexp: nil, NumericRegexp: nil,
Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd}, Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd},
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass")`, Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass")`,
}, },
{ {
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
NumericRegexp: nil, NumericRegexp: nil,
Vars: []interface{}{"jinzhu?", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd}, Vars: []interface{}{"jinzhu?", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd},
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu?", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass")`, Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu?", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass")`,
}, },
{ {
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (@p1, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9, @p10, @p11)", SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (@p1, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9, @p10, @p11)",
NumericRegexp: regexp.MustCompile(`@p(\d+)`), NumericRegexp: regexp.MustCompile(`@p(\d+)`),
Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.com", myrole, pwd}, Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.com", myrole, pwd},
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`, Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`,
}, },
{ {
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ($3, $4, $1, $2, $7, $8, $5, $6, $9, $10, $11)", SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ($3, $4, $1, $2, $7, $8, $5, $6, $9, $10, $11)",
NumericRegexp: regexp.MustCompile(`\$(\d+)`), NumericRegexp: regexp.MustCompile(`\$(\d+)`),
Vars: []interface{}{999.99, true, "jinzhu", 1, &tt, nil, []byte("12345"), tt, "w@g.com", myrole, pwd}, Vars: []interface{}{999.99, true, "jinzhu", 1, &tt, nil, []byte("12345"), tt, "w@g.com", myrole, pwd},
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`, Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`,
}, },
{ {
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (@p1, @p11, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9, @p10)", SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (@p1, @p11, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9, @p10)",
NumericRegexp: regexp.MustCompile(`@p(\d+)`), NumericRegexp: regexp.MustCompile(`@p(\d+)`),
Vars: []interface{}{"jinzhu", 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.com", myrole, pwd, 1}, Vars: []interface{}{"jinzhu", 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.com", myrole, pwd, 1},
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`, Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`,
}, },
{ {
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
NumericRegexp: nil, NumericRegexp: nil,
Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, js, es}, Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, js, es},
Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)), Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
}, },
{ {
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
NumericRegexp: nil, NumericRegexp: nil,
Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es}, Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es},
Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)), Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
},
{
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
NumericRegexp: nil,
Vars: []interface{}{"jinzhu", 1, 0.1753607109, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es},
Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 0.1753607109, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
},
{
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
NumericRegexp: nil,
Vars: []interface{}{"jinzhu", 1, float32(999.99), true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es},
Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
},
{
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, int_val) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
NumericRegexp: nil,
Vars: []interface{}{"jinzhu?", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, intVal},
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, int_val) values ("jinzhu?", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", 1)`,
},
{
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, float_val) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
NumericRegexp: nil,
Vars: []interface{}{"jinzhu?", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, floatVal},
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, float_val) values ("jinzhu?", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", 1.230000)`,
}, },
} }

View File

@ -1,8 +1,6 @@
package gorm package gorm
import ( import (
"reflect"
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
"gorm.io/gorm/schema" "gorm.io/gorm/schema"
) )
@ -13,7 +11,11 @@ func (db *DB) Migrator() Migrator {
// apply scopes to migrator // apply scopes to migrator
for len(tx.Statement.scopes) > 0 { for len(tx.Statement.scopes) > 0 {
tx = tx.executeScopes() scopes := tx.Statement.scopes
tx.Statement.scopes = nil
for _, scope := range scopes {
tx = scope(tx)
}
} }
return tx.Dialector.Migrator(tx.Session(&Session{})) return tx.Dialector.Migrator(tx.Session(&Session{}))
@ -26,45 +28,19 @@ func (db *DB) AutoMigrate(dst ...interface{}) error {
// ViewOption view option // ViewOption view option
type ViewOption struct { type ViewOption struct {
Replace bool // If true, exec `CREATE OR REPLACE`. If false, exec `CREATE` Replace bool
CheckOption string // optional. e.g. `WITH [ CASCADED | LOCAL ] CHECK OPTION` CheckOption string
Query *DB // required subquery. Query *DB
} }
// ColumnType column type interface
type ColumnType interface { type ColumnType interface {
Name() string Name() string
DatabaseTypeName() string // varchar DatabaseTypeName() string
ColumnType() (columnType string, ok bool) // varchar(64)
PrimaryKey() (isPrimaryKey bool, ok bool)
AutoIncrement() (isAutoIncrement bool, ok bool)
Length() (length int64, ok bool) Length() (length int64, ok bool)
DecimalSize() (precision int64, scale int64, ok bool) DecimalSize() (precision int64, scale int64, ok bool)
Nullable() (nullable bool, ok bool) Nullable() (nullable bool, ok bool)
Unique() (unique bool, ok bool)
ScanType() reflect.Type
Comment() (value string, ok bool)
DefaultValue() (value string, ok bool)
} }
type Index interface {
Table() string
Name() string
Columns() []string
PrimaryKey() (isPrimaryKey bool, ok bool)
Unique() (unique bool, ok bool)
Option() string
}
// TableType table type interface
type TableType interface {
Schema() string
Name() string
Type() string
Comment() (comment string, ok bool)
}
// Migrator migrator interface
type Migrator interface { type Migrator interface {
// AutoMigrate // AutoMigrate
AutoMigrate(dst ...interface{}) error AutoMigrate(dst ...interface{}) error
@ -72,23 +48,18 @@ type Migrator interface {
// Database // Database
CurrentDatabase() string CurrentDatabase() string
FullDataTypeOf(*schema.Field) clause.Expr FullDataTypeOf(*schema.Field) clause.Expr
GetTypeAliases(databaseTypeName string) []string
// Tables // Tables
CreateTable(dst ...interface{}) error CreateTable(dst ...interface{}) error
DropTable(dst ...interface{}) error DropTable(dst ...interface{}) error
HasTable(dst interface{}) bool HasTable(dst interface{}) bool
RenameTable(oldName, newName interface{}) error RenameTable(oldName, newName interface{}) error
GetTables() (tableList []string, err error)
TableType(dst interface{}) (TableType, error)
// Columns // Columns
AddColumn(dst interface{}, field string) error AddColumn(dst interface{}, field string) error
DropColumn(dst interface{}, field string) error DropColumn(dst interface{}, field string) error
AlterColumn(dst interface{}, field string) error AlterColumn(dst interface{}, field string) error
MigrateColumn(dst interface{}, field *schema.Field, columnType ColumnType) error MigrateColumn(dst interface{}, field *schema.Field, columnType ColumnType) error
// MigrateColumnUnique migrate column's UNIQUE constraint, it's part of MigrateColumn.
MigrateColumnUnique(dst interface{}, field *schema.Field, columnType ColumnType) error
HasColumn(dst interface{}, field string) bool HasColumn(dst interface{}, field string) bool
RenameColumn(dst interface{}, oldName, field string) error RenameColumn(dst interface{}, oldName, field string) error
ColumnTypes(dst interface{}) ([]ColumnType, error) ColumnTypes(dst interface{}) ([]ColumnType, error)
@ -107,5 +78,4 @@ type Migrator interface {
DropIndex(dst interface{}, name string) error DropIndex(dst interface{}, name string) error
HasIndex(dst interface{}, name string) bool HasIndex(dst interface{}, name string) bool
RenameIndex(dst interface{}, oldName, newName string) error RenameIndex(dst interface{}, oldName, newName string) error
GetIndexes(dst interface{}) ([]Index, error)
} }

View File

@ -1,107 +0,0 @@
package migrator
import (
"database/sql"
"reflect"
)
// ColumnType column type, implements the gorm.ColumnType interface
type ColumnType struct {
SQLColumnType *sql.ColumnType
NameValue sql.NullString
DataTypeValue sql.NullString
ColumnTypeValue sql.NullString
PrimaryKeyValue sql.NullBool
UniqueValue sql.NullBool
AutoIncrementValue sql.NullBool
LengthValue sql.NullInt64
DecimalSizeValue sql.NullInt64
ScaleValue sql.NullInt64
NullableValue sql.NullBool
ScanTypeValue reflect.Type
CommentValue sql.NullString
DefaultValueValue sql.NullString
}
// Name returns the name or alias of the column.
func (ct ColumnType) Name() string {
if ct.NameValue.Valid {
return ct.NameValue.String
}
return ct.SQLColumnType.Name()
}
// DatabaseTypeName returns the database system name of the column type. If an empty
// string is returned, then the driver type name is not supported.
// Consult your driver documentation for a list of driver data types. Length specifiers
// are not included.
// Common type names include "VARCHAR", "TEXT", "NVARCHAR", "DECIMAL", "BOOL",
// "INT", and "BIGINT".
func (ct ColumnType) DatabaseTypeName() string {
if ct.DataTypeValue.Valid {
return ct.DataTypeValue.String
}
return ct.SQLColumnType.DatabaseTypeName()
}
// ColumnType returns the database type of the column. like `varchar(16)`
func (ct ColumnType) ColumnType() (columnType string, ok bool) {
return ct.ColumnTypeValue.String, ct.ColumnTypeValue.Valid
}
// PrimaryKey reports whether the column is a primary key.
func (ct ColumnType) PrimaryKey() (isPrimaryKey bool, ok bool) {
return ct.PrimaryKeyValue.Bool, ct.PrimaryKeyValue.Valid
}
// AutoIncrement reports whether the column auto-increments.
func (ct ColumnType) AutoIncrement() (isAutoIncrement bool, ok bool) {
return ct.AutoIncrementValue.Bool, ct.AutoIncrementValue.Valid
}
// Length returns the column type length for variable length column types
func (ct ColumnType) Length() (length int64, ok bool) {
if ct.LengthValue.Valid {
return ct.LengthValue.Int64, true
}
return ct.SQLColumnType.Length()
}
// DecimalSize returns the scale and precision of a decimal type.
func (ct ColumnType) DecimalSize() (precision int64, scale int64, ok bool) {
if ct.DecimalSizeValue.Valid {
return ct.DecimalSizeValue.Int64, ct.ScaleValue.Int64, true
}
return ct.SQLColumnType.DecimalSize()
}
// Nullable reports whether the column may be null.
func (ct ColumnType) Nullable() (nullable bool, ok bool) {
if ct.NullableValue.Valid {
return ct.NullableValue.Bool, true
}
return ct.SQLColumnType.Nullable()
}
// Unique reports whether the column is unique.
func (ct ColumnType) Unique() (unique bool, ok bool) {
return ct.UniqueValue.Bool, ct.UniqueValue.Valid
}
// ScanType returns a Go type suitable for scanning into using Rows.Scan.
func (ct ColumnType) ScanType() reflect.Type {
if ct.ScanTypeValue != nil {
return ct.ScanTypeValue
}
return ct.SQLColumnType.ScanType()
}
// Comment returns the comment of the current column.
func (ct ColumnType) Comment() (value string, ok bool) {
return ct.CommentValue.String, ct.CommentValue.Valid
}
// DefaultValue returns the default value of the current column.
func (ct ColumnType) DefaultValue() (value string, ok bool) {
return ct.DefaultValueValue.String, ct.DefaultValueValue.Valid
}
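
For orientation, a hedged sketch of how a dialector might populate this struct with illustrative values; any *Value field left invalid falls back to the wrapped *sql.ColumnType:

package migrator

import (
    "database/sql"
    "reflect"
)

// exampleColumnType fills the fields a dialector would typically set; every
// *Value field that is Valid overrides the wrapped *sql.ColumnType, anything
// left invalid falls through to the driver's answer.
func exampleColumnType(raw *sql.ColumnType) ColumnType {
    return ColumnType{
        SQLColumnType:   raw,                                         // fallback source
        NameValue:       sql.NullString{String: "name", Valid: true}, // overrides raw.Name()
        DataTypeValue:   sql.NullString{String: "varchar", Valid: true},
        ColumnTypeValue: sql.NullString{String: "varchar(64)", Valid: true},
        LengthValue:     sql.NullInt64{Int64: 64, Valid: true},
        NullableValue:   sql.NullBool{Bool: true, Valid: true},
        UniqueValue:     sql.NullBool{}, // not set, so Unique() reports ok == false
        ScanTypeValue:   reflect.TypeOf(""),
    }
}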

View File

@ -1,43 +0,0 @@
package migrator
import "database/sql"
// Index implements gorm.Index interface
type Index struct {
TableName string
NameValue string
ColumnList []string
PrimaryKeyValue sql.NullBool
UniqueValue sql.NullBool
OptionValue string
}
// Table returns the table name of the index.
func (idx Index) Table() string {
return idx.TableName
}
// Name returns the name of the index.
func (idx Index) Name() string {
return idx.NameValue
}
// Columns returns the columns of the index
func (idx Index) Columns() []string {
return idx.ColumnList
}
// PrimaryKey reports whether the index is a primary key.
func (idx Index) PrimaryKey() (isPrimaryKey bool, ok bool) {
return idx.PrimaryKeyValue.Bool, idx.PrimaryKeyValue.Valid
}
// Unique returns whether the index is unique or not.
func (idx Index) Unique() (unique bool, ok bool) {
return idx.UniqueValue.Bool, idx.UniqueValue.Valid
}
// Option returns the optional attribute of the index
func (idx Index) Option() string {
return idx.OptionValue
}
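
The index metadata works the same way; a short sketch with illustrative values:

package migrator

import "database/sql"

// exampleIndex is the kind of value a dialector would return for a unique
// index on users.email.
var exampleIndex = Index{
    TableName:       "users",
    NameValue:       "idx_users_email",
    ColumnList:      []string{"email"},
    PrimaryKeyValue: sql.NullBool{Bool: false, Valid: true},
    UniqueValue:     sql.NullBool{Bool: true, Valid: true},
    OptionValue:     "",
}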

View File

@ -3,32 +3,20 @@ package migrator
import ( import (
"context" "context"
"database/sql" "database/sql"
"errors"
"fmt" "fmt"
"reflect" "reflect"
"regexp" "regexp"
"strconv"
"strings" "strings"
"time"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
"gorm.io/gorm/logger"
"gorm.io/gorm/schema" "gorm.io/gorm/schema"
) )
// This regular expression seeks to find a sequence of digits (\d+) among zero or more non-digit characters (\D*), var (
// with a possible trailing non-digit character (\D?). regRealDataType = regexp.MustCompile(`[^\d](\d+)[^\d]?`)
regFullDataType = regexp.MustCompile(`[^\d]*(\d+)[^\d]?`)
// For example, values that can pass this regular expression are: )
// - "123"
// - "abc456"
// - "%$#@789"
var regFullDataType = regexp.MustCompile(`\D*(\d+)\D?`)
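
A short sketch of what that capture group yields in practice (the demo function name is illustrative); MigrateColumn below only trusts the captured size when there is exactly one match:

package migrator

import "fmt"

// demoRegFullDataType prints the submatches the size check relies on.
func demoRegFullDataType() {
    fmt.Println(regFullDataType.FindAllStringSubmatch("varchar(64)", -1))
    // [[varchar(64) 64]]        one match, captured size "64"

    fmt.Println(regFullDataType.FindAllStringSubmatch("decimal(10,2)", -1))
    // [[decimal(10, 10] [2) 2]] two matches, so the size comparison is skipped

    fmt.Println(regFullDataType.FindAllStringSubmatch("bigint", -1))
    // []                        no digits at all
}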
// TODO:? Create const vars for raw sql queries ?
var _ gorm.Migrator = (*Migrator)(nil)
// Migrator m struct // Migrator m struct
type Migrator struct { type Migrator struct {
@ -42,22 +30,10 @@ type Config struct {
gorm.Dialector gorm.Dialector
} }
type printSQLLogger struct {
logger.Interface
}
func (l *printSQLLogger) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) {
sql, _ := fc()
fmt.Println(sql + ";")
l.Interface.Trace(ctx, begin, fc, err)
}
// GormDataTypeInterface gorm data type interface
type GormDataTypeInterface interface { type GormDataTypeInterface interface {
GormDBDataType(*gorm.DB, *schema.Field) string GormDBDataType(*gorm.DB, *schema.Field) string
} }
// RunWithValue runs the migration callback fc with a statement parsed from value
func (m Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error) error { func (m Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error) error {
stmt := &gorm.Statement{DB: m.DB} stmt := &gorm.Statement{DB: m.DB}
if m.DB.Statement != nil { if m.DB.Statement != nil {
@ -67,14 +43,13 @@ func (m Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error
if table, ok := value.(string); ok { if table, ok := value.(string); ok {
stmt.Table = table stmt.Table = table
} else if err := stmt.ParseWithSpecialTableName(value, stmt.Table); err != nil { } else if err := stmt.Parse(value); err != nil {
return err return err
} }
return fc(stmt) return fc(stmt)
} }
// DataTypeOf returns the field's db data type
func (m Migrator) DataTypeOf(field *schema.Field) string { func (m Migrator) DataTypeOf(field *schema.Field) string {
fieldValue := reflect.New(field.IndirectFieldType) fieldValue := reflect.New(field.IndirectFieldType)
if dataTyper, ok := fieldValue.Interface().(GormDataTypeInterface); ok { if dataTyper, ok := fieldValue.Interface().(GormDataTypeInterface); ok {
@ -86,7 +61,6 @@ func (m Migrator) DataTypeOf(field *schema.Field) string {
return m.Dialector.DataTypeOf(field) return m.Dialector.DataTypeOf(field)
} }
// FullDataTypeOf returns the field's full db data type
func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) { func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) {
expr.SQL = m.DataTypeOf(field) expr.SQL = m.DataTypeOf(field)
@ -94,6 +68,10 @@ func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) {
expr.SQL += " NOT NULL" expr.SQL += " NOT NULL"
} }
if field.Unique {
expr.SQL += " UNIQUE"
}
if field.HasDefaultValue && (field.DefaultValueInterface != nil || field.DefaultValue != "") { if field.HasDefaultValue && (field.DefaultValueInterface != nil || field.DefaultValue != "") {
if field.DefaultValueInterface != nil { if field.DefaultValueInterface != nil {
defaultStmt := &gorm.Statement{Vars: []interface{}{field.DefaultValueInterface}} defaultStmt := &gorm.Statement{Vars: []interface{}{field.DefaultValueInterface}}
@ -107,44 +85,23 @@ func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) {
return return
} }
func (m Migrator) GetQueryAndExecTx() (queryTx, execTx *gorm.DB) { // AutoMigrate
queryTx = m.DB.Session(&gorm.Session{})
execTx = queryTx
if m.DB.DryRun {
queryTx.DryRun = false
execTx = m.DB.Session(&gorm.Session{Logger: &printSQLLogger{Interface: m.DB.Logger}})
}
return queryTx, execTx
}
// AutoMigrate automatically migrates the given values
func (m Migrator) AutoMigrate(values ...interface{}) error { func (m Migrator) AutoMigrate(values ...interface{}) error {
for _, value := range m.ReorderModels(values, true) { for _, value := range m.ReorderModels(values, true) {
queryTx, execTx := m.GetQueryAndExecTx() tx := m.DB.Session(&gorm.Session{})
if !queryTx.Migrator().HasTable(value) { if !tx.Migrator().HasTable(value) {
if err := execTx.Migrator().CreateTable(value); err != nil { if err := tx.Migrator().CreateTable(value); err != nil {
return err return err
} }
} else { } else {
if err := m.RunWithValue(value, func(stmt *gorm.Statement) error { if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) {
columnTypes, _ := m.DB.Migrator().ColumnTypes(value)
if stmt.Schema == nil { for _, field := range stmt.Schema.FieldsByDBName {
return errors.New("failed to get schema")
}
columnTypes, err := queryTx.Migrator().ColumnTypes(value)
if err != nil {
return err
}
var (
parseIndexes = stmt.Schema.ParseIndexes()
parseCheckConstraints = stmt.Schema.ParseCheckConstraints()
)
for _, dbName := range stmt.Schema.DBNames {
var foundColumn gorm.ColumnType var foundColumn gorm.ColumnType
for _, columnType := range columnTypes { for _, columnType := range columnTypes {
if columnType.Name() == dbName { if columnType.Name() == field.DBName {
foundColumn = columnType foundColumn = columnType
break break
} }
@ -152,43 +109,37 @@ func (m Migrator) AutoMigrate(values ...interface{}) error {
if foundColumn == nil { if foundColumn == nil {
// not found, add column // not found, add column
if err = execTx.Migrator().AddColumn(value, dbName); err != nil { if err := tx.Migrator().AddColumn(value, field.DBName); err != nil {
return err
}
} else {
// found, smartly migrate
field := stmt.Schema.FieldsByDBName[dbName]
if err = execTx.Migrator().MigrateColumn(value, field, foundColumn); err != nil {
return err return err
} }
} else if err := m.DB.Migrator().MigrateColumn(value, field, foundColumn); err != nil {
// found, smart migrate
return err
} }
} }
if !m.DB.DisableForeignKeyConstraintWhenMigrating && !m.DB.IgnoreRelationshipsWhenMigrating { for _, rel := range stmt.Schema.Relationships.Relations {
for _, rel := range stmt.Schema.Relationships.Relations { if !m.DB.Config.DisableForeignKeyConstraintWhenMigrating {
if rel.Field.IgnoreMigration {
continue
}
if constraint := rel.ParseConstraint(); constraint != nil && if constraint := rel.ParseConstraint(); constraint != nil &&
constraint.Schema == stmt.Schema && !queryTx.Migrator().HasConstraint(value, constraint.Name) { constraint.Schema == stmt.Schema && !tx.Migrator().HasConstraint(value, constraint.Name) {
if err := execTx.Migrator().CreateConstraint(value, constraint.Name); err != nil { if err := tx.Migrator().CreateConstraint(value, constraint.Name); err != nil {
return err
}
}
}
for _, chk := range stmt.Schema.ParseCheckConstraints() {
if !tx.Migrator().HasConstraint(value, chk.Name) {
if err := tx.Migrator().CreateConstraint(value, chk.Name); err != nil {
return err return err
} }
} }
} }
} }
for _, chk := range parseCheckConstraints { for _, idx := range stmt.Schema.ParseIndexes() {
if !queryTx.Migrator().HasConstraint(value, chk.Name) { if !tx.Migrator().HasIndex(value, idx.Name) {
if err := execTx.Migrator().CreateConstraint(value, chk.Name); err != nil { if err := tx.Migrator().CreateIndex(value, idx.Name); err != nil {
return err
}
}
}
for _, idx := range parseIndexes {
if !queryTx.Migrator().HasIndex(value, idx.Name) {
if err := execTx.Migrator().CreateIndex(value, idx.Name); err != nil {
return err return err
} }
} }
@ -204,23 +155,10 @@ func (m Migrator) AutoMigrate(values ...interface{}) error {
return nil return nil
} }
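
For completeness, the public entry point that drives the loop above is db.AutoMigrate; a minimal sketch with an illustrative Product model:

package example

import "gorm.io/gorm"

type Product struct {
    ID    uint
    Code  string `gorm:"size:32;uniqueIndex"`
    Price uint
}

// migrate creates the products table when it is missing, otherwise it adds
// or alters columns, constraints and indexes to match the model.
func migrate(db *gorm.DB) error {
    return db.AutoMigrate(&Product{})
}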
// GetTables returns tables
func (m Migrator) GetTables() (tableList []string, err error) {
err = m.DB.Raw("SELECT TABLE_NAME FROM information_schema.tables where TABLE_SCHEMA=?", m.CurrentDatabase()).
Scan(&tableList).Error
return
}
// CreateTable creates tables in the database for the given values
func (m Migrator) CreateTable(values ...interface{}) error { func (m Migrator) CreateTable(values ...interface{}) error {
for _, value := range m.ReorderModels(values, false) { for _, value := range m.ReorderModels(values, false) {
tx := m.DB.Session(&gorm.Session{}) tx := m.DB.Session(&gorm.Session{})
if err := m.RunWithValue(value, func(stmt *gorm.Statement) (err error) { if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) {
if stmt.Schema == nil {
return errors.New("failed to get schema")
}
var ( var (
createTableSQL = "CREATE TABLE ? (" createTableSQL = "CREATE TABLE ? ("
values = []interface{}{m.CurrentTable(stmt)} values = []interface{}{m.CurrentTable(stmt)}
@ -231,7 +169,7 @@ func (m Migrator) CreateTable(values ...interface{}) error {
field := stmt.Schema.FieldsByDBName[dbName] field := stmt.Schema.FieldsByDBName[dbName]
if !field.IgnoreMigration { if !field.IgnoreMigration {
createTableSQL += "? ?" createTableSQL += "? ?"
hasPrimaryKeyInDataType = hasPrimaryKeyInDataType || strings.Contains(strings.ToUpper(m.DataTypeOf(field)), "PRIMARY KEY") hasPrimaryKeyInDataType = hasPrimaryKeyInDataType || strings.Contains(strings.ToUpper(string(field.DataType)), "PRIMARY KEY")
values = append(values, clause.Column{Name: dbName}, m.DB.Migrator().FullDataTypeOf(field)) values = append(values, clause.Column{Name: dbName}, m.DB.Migrator().FullDataTypeOf(field))
createTableSQL += "," createTableSQL += ","
} }
@ -239,7 +177,7 @@ func (m Migrator) CreateTable(values ...interface{}) error {
if !hasPrimaryKeyInDataType && len(stmt.Schema.PrimaryFields) > 0 { if !hasPrimaryKeyInDataType && len(stmt.Schema.PrimaryFields) > 0 {
createTableSQL += "PRIMARY KEY ?," createTableSQL += "PRIMARY KEY ?,"
primaryKeys := make([]interface{}, 0, len(stmt.Schema.PrimaryFields)) primaryKeys := []interface{}{}
for _, field := range stmt.Schema.PrimaryFields { for _, field := range stmt.Schema.PrimaryFields {
primaryKeys = append(primaryKeys, clause.Column{Name: field.DBName}) primaryKeys = append(primaryKeys, clause.Column{Name: field.DBName})
} }
@ -250,8 +188,8 @@ func (m Migrator) CreateTable(values ...interface{}) error {
for _, idx := range stmt.Schema.ParseIndexes() { for _, idx := range stmt.Schema.ParseIndexes() {
if m.CreateIndexAfterCreateTable { if m.CreateIndexAfterCreateTable {
defer func(value interface{}, name string) { defer func(value interface{}, name string) {
if err == nil { if errr == nil {
err = tx.Migrator().CreateIndex(value, name) errr = tx.Migrator().CreateIndex(value, name)
} }
}(value, idx.Name) }(value, idx.Name)
} else { } else {
@ -269,18 +207,15 @@ func (m Migrator) CreateTable(values ...interface{}) error {
} }
createTableSQL += "," createTableSQL += ","
values = append(values, clause.Column{Name: idx.Name}, tx.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt)) values = append(values, clause.Expr{SQL: idx.Name}, tx.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt))
} }
} }
if !m.DB.DisableForeignKeyConstraintWhenMigrating && !m.DB.IgnoreRelationshipsWhenMigrating { for _, rel := range stmt.Schema.Relationships.Relations {
for _, rel := range stmt.Schema.Relationships.Relations { if !m.DB.DisableForeignKeyConstraintWhenMigrating {
if rel.Field.IgnoreMigration {
continue
}
if constraint := rel.ParseConstraint(); constraint != nil { if constraint := rel.ParseConstraint(); constraint != nil {
if constraint.Schema == stmt.Schema { if constraint.Schema == stmt.Schema {
sql, vars := constraint.Build() sql, vars := buildConstraint(constraint)
createTableSQL += sql + "," createTableSQL += sql + ","
values = append(values, vars...) values = append(values, vars...)
} }
@ -288,11 +223,6 @@ func (m Migrator) CreateTable(values ...interface{}) error {
} }
} }
for _, uni := range stmt.Schema.ParseUniqueConstraints() {
createTableSQL += "CONSTRAINT ? UNIQUE (?),"
values = append(values, clause.Column{Name: uni.Name}, clause.Expr{SQL: stmt.Quote(uni.Field.DBName)})
}
for _, chk := range stmt.Schema.ParseCheckConstraints() { for _, chk := range stmt.Schema.ParseCheckConstraints() {
createTableSQL += "CONSTRAINT ? CHECK (?)," createTableSQL += "CONSTRAINT ? CHECK (?),"
values = append(values, clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}) values = append(values, clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint})
@ -306,8 +236,8 @@ func (m Migrator) CreateTable(values ...interface{}) error {
createTableSQL += fmt.Sprint(tableOption) createTableSQL += fmt.Sprint(tableOption)
} }
err = tx.Exec(createTableSQL, values...).Error errr = tx.Exec(createTableSQL, values...).Error
return err return errr
}); err != nil { }); err != nil {
return err return err
} }
@ -315,7 +245,6 @@ func (m Migrator) CreateTable(values ...interface{}) error {
return nil return nil
} }
// DropTable drops the tables for the given values
func (m Migrator) DropTable(values ...interface{}) error { func (m Migrator) DropTable(values ...interface{}) error {
values = m.ReorderModels(values, false) values = m.ReorderModels(values, false)
for i := len(values) - 1; i >= 0; i-- { for i := len(values) - 1; i >= 0; i-- {
@ -329,7 +258,6 @@ func (m Migrator) DropTable(values ...interface{}) error {
return nil return nil
} }
// HasTable reports whether a table exists for value; value can be a struct or a table name string
func (m Migrator) HasTable(value interface{}) bool { func (m Migrator) HasTable(value interface{}) bool {
var count int64 var count int64
@ -341,7 +269,6 @@ func (m Migrator) HasTable(value interface{}) bool {
return count > 0 return count > 0
} }
// RenameTable renames the table from oldName to newName
func (m Migrator) RenameTable(oldName, newName interface{}) error { func (m Migrator) RenameTable(oldName, newName interface{}) error {
var oldTable, newTable interface{} var oldTable, newTable interface{}
if v, ok := oldName.(string); ok { if v, ok := oldName.(string); ok {
@ -369,16 +296,12 @@ func (m Migrator) RenameTable(oldName, newName interface{}) error {
return m.DB.Exec("ALTER TABLE ? RENAME TO ?", oldTable, newTable).Error return m.DB.Exec("ALTER TABLE ? RENAME TO ?", oldTable, newTable).Error
} }
// AddColumn creates the `name` column for value func (m Migrator) AddColumn(value interface{}, field string) error {
func (m Migrator) AddColumn(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error {
// avoid using the same name field // avoid using the same name field
if stmt.Schema == nil { f := stmt.Schema.LookUpField(field)
return errors.New("failed to get schema")
}
f := stmt.Schema.LookUpField(name)
if f == nil { if f == nil {
return fmt.Errorf("failed to look up field with name: %s", name) return fmt.Errorf("failed to look up field with name: %s", field)
} }
if !f.IgnoreMigration { if !f.IgnoreMigration {
@ -392,13 +315,10 @@ func (m Migrator) AddColumn(value interface{}, name string) error {
}) })
} }
// DropColumn drops value's `name` column
func (m Migrator) DropColumn(value interface{}, name string) error { func (m Migrator) DropColumn(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if stmt.Schema != nil { if field := stmt.Schema.LookUpField(name); field != nil {
if field := stmt.Schema.LookUpField(name); field != nil { name = field.DBName
name = field.DBName
}
} }
return m.DB.Exec( return m.DB.Exec(
@ -407,33 +327,27 @@ func (m Migrator) DropColumn(value interface{}, name string) error {
}) })
} }
// AlterColumn alters the type of value's `field` column based on the schema definition
func (m Migrator) AlterColumn(value interface{}, field string) error { func (m Migrator) AlterColumn(value interface{}, field string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if stmt.Schema != nil { if field := stmt.Schema.LookUpField(field); field != nil {
if field := stmt.Schema.LookUpField(field); field != nil { fileType := clause.Expr{SQL: m.DataTypeOf(field)}
fileType := m.FullDataTypeOf(field) return m.DB.Exec(
return m.DB.Exec( "ALTER TABLE ? ALTER COLUMN ? TYPE ?",
"ALTER TABLE ? ALTER COLUMN ? TYPE ?", m.CurrentTable(stmt), clause.Column{Name: field.DBName}, fileType,
m.CurrentTable(stmt), clause.Column{Name: field.DBName}, fileType, ).Error
).Error
}
} }
return fmt.Errorf("failed to look up field with name: %s", field) return fmt.Errorf("failed to look up field with name: %s", field)
}) })
} }
// HasColumn reports whether value has a column named `field`
func (m Migrator) HasColumn(value interface{}, field string) bool { func (m Migrator) HasColumn(value interface{}, field string) bool {
var count int64 var count int64
m.RunWithValue(value, func(stmt *gorm.Statement) error { m.RunWithValue(value, func(stmt *gorm.Statement) error {
currentDatabase := m.DB.Migrator().CurrentDatabase() currentDatabase := m.DB.Migrator().CurrentDatabase()
name := field name := field
if stmt.Schema != nil { if field := stmt.Schema.LookUpField(field); field != nil {
if field := stmt.Schema.LookUpField(field); field != nil { name = field.DBName
name = field.DBName
}
} }
return m.DB.Raw( return m.DB.Raw(
@ -445,17 +359,14 @@ func (m Migrator) HasColumn(value interface{}, field string) bool {
return count > 0 return count > 0
} }
// RenameColumn renames value's column from oldName to newName
func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error { func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if stmt.Schema != nil { if field := stmt.Schema.LookUpField(oldName); field != nil {
if field := stmt.Schema.LookUpField(oldName); field != nil { oldName = field.DBName
oldName = field.DBName }
}
if field := stmt.Schema.LookUpField(newName); field != nil { if field := stmt.Schema.LookUpField(newName); field != nil {
newName = field.DBName newName = field.DBName
}
} }
return m.DB.Exec( return m.DB.Exec(
@ -465,160 +376,61 @@ func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error
}) })
} }
// MigrateColumn smartly migrates an existing column to match the schema field
func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnType gorm.ColumnType) error { func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnType gorm.ColumnType) error {
if field.IgnoreMigration {
return nil
}
// found, smart migrate // found, smart migrate
fullDataType := strings.TrimSpace(strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL)) fullDataType := strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL)
realDataType := strings.ToLower(columnType.DatabaseTypeName()) realDataType := strings.ToLower(columnType.DatabaseTypeName())
var (
alterColumn bool
isSameType = fullDataType == realDataType
)
if !field.PrimaryKey { alterColumn := false
// check type
if !strings.HasPrefix(fullDataType, realDataType) {
// check type aliases
aliases := m.DB.Migrator().GetTypeAliases(realDataType)
for _, alias := range aliases {
if strings.HasPrefix(fullDataType, alias) {
isSameType = true
break
}
}
if !isSameType { // check size
if length, _ := columnType.Length(); length != int64(field.Size) {
if length > 0 && field.Size > 0 {
alterColumn = true
} else {
// has size in data type and not equal
// Since the following code is frequently called in the for loop, reg optimization is needed here
matches := regRealDataType.FindAllStringSubmatch(realDataType, -1)
matches2 := regFullDataType.FindAllStringSubmatch(fullDataType, -1)
if (len(matches) == 1 && matches[0][1] != fmt.Sprint(field.Size) || !field.PrimaryKey) &&
(len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length)) {
alterColumn = true alterColumn = true
} }
} }
} }
if !isSameType {
// check size
if length, ok := columnType.Length(); length != int64(field.Size) {
if length > 0 && field.Size > 0 {
alterColumn = true
} else {
// has size in data type and not equal
// Since the following code is called frequently inside the column loop, the regexp is precompiled at package level (regFullDataType)
matches2 := regFullDataType.FindAllStringSubmatch(fullDataType, -1)
if !field.PrimaryKey &&
(len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length) && ok) {
alterColumn = true
}
}
}
}
// check precision // check precision
if realDataType == "decimal" || realDataType == "numeric" && if precision, _, ok := columnType.DecimalSize(); ok && int64(field.Precision) != precision {
regexp.MustCompile(realDataType+`\(.*\)`).FindString(fullDataType) != "" { // if realDataType has no precision,ignore if regexp.MustCompile(fmt.Sprintf("[^0-9]%d[^0-9]", field.Precision)).MatchString(m.DataTypeOf(field)) {
precision, scale, ok := columnType.DecimalSize() alterColumn = true
if ok {
if !strings.HasPrefix(fullDataType, fmt.Sprintf("%s(%d,%d)", realDataType, precision, scale)) &&
!strings.HasPrefix(fullDataType, fmt.Sprintf("%s(%d)", realDataType, precision)) {
alterColumn = true
}
}
} else {
if precision, _, ok := columnType.DecimalSize(); ok && int64(field.Precision) != precision {
if regexp.MustCompile(fmt.Sprintf("[^0-9]%d[^0-9]", field.Precision)).MatchString(m.DataTypeOf(field)) {
alterColumn = true
}
} }
} }
// check nullable // check nullable
if nullable, ok := columnType.Nullable(); ok && nullable == field.NotNull { if nullable, ok := columnType.Nullable(); ok && nullable == field.NotNull {
// not primary key & current database is non-nullable(to be nullable) // not primary key & database is nullable
if !field.PrimaryKey && !nullable { if !field.PrimaryKey && nullable {
alterColumn = true alterColumn = true
} }
} }
// check default value if alterColumn && !field.IgnoreMigration {
if !field.PrimaryKey { return m.DB.Migrator().AlterColumn(value, field.Name)
currentDefaultNotNull := field.HasDefaultValue && (field.DefaultValueInterface != nil || !strings.EqualFold(field.DefaultValue, "NULL"))
dv, dvNotNull := columnType.DefaultValue()
if dvNotNull && !currentDefaultNotNull {
// default value -> null
alterColumn = true
} else if !dvNotNull && currentDefaultNotNull {
// null -> default value
alterColumn = true
} else if currentDefaultNotNull || dvNotNull {
switch field.GORMDataType {
case schema.Time:
if !strings.EqualFold(strings.TrimSuffix(dv, "()"), strings.TrimSuffix(field.DefaultValue, "()")) {
alterColumn = true
}
case schema.Bool:
v1, _ := strconv.ParseBool(dv)
v2, _ := strconv.ParseBool(field.DefaultValue)
alterColumn = v1 != v2
default:
alterColumn = dv != field.DefaultValue
}
}
}
// check comment
if comment, ok := columnType.Comment(); ok && comment != field.Comment {
// not primary key
if !field.PrimaryKey {
alterColumn = true
}
}
if alterColumn {
if err := m.DB.Migrator().AlterColumn(value, field.DBName); err != nil {
return err
}
}
if err := m.DB.Migrator().MigrateColumnUnique(value, field, columnType); err != nil {
return err
} }
return nil return nil
} }
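The master-side MigrateColumn above compares full data type, size, precision, nullability, default value and comment against the live column before deciding to call AlterColumn, and AutoMigrate invokes it once per existing column. A hedged sketch of how that surfaces to callers (model names and the sqlite driver are illustrative; the exact ALTER statement depends on the dialector):

package main

import (
	"gorm.io/driver/sqlite" // illustrative; any dialector works
	"gorm.io/gorm"
)

// First migration: Name is an unconstrained string column.
type Product struct {
	ID   uint
	Name string
}

// Second migration: the same table, but Name now carries a size, NOT NULL and a default.
// MigrateColumn should detect the drift and call AlterColumn for the column.
type ProductV2 struct {
	ID   uint
	Name string `gorm:"size:100;not null;default:'unknown'"`
}

// TableName points both structs at the same table so the second AutoMigrate
// exercises the smart-migrate path instead of creating a new table.
func (ProductV2) TableName() string { return "products" }

func main() {
	db, err := gorm.Open(sqlite.Open("migrate_demo.db"), &gorm.Config{})
	if err != nil {
		panic(err)
	}
	_ = db.AutoMigrate(&Product{})   // creates the products table
	_ = db.AutoMigrate(&ProductV2{}) // column already exists, so MigrateColumn diffs it
}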
func (m Migrator) MigrateColumnUnique(value interface{}, field *schema.Field, columnType gorm.ColumnType) error {
unique, ok := columnType.Unique()
if !ok || field.PrimaryKey {
return nil // skip primary key
}
// By default, ColumnType's Unique is not affected by UniqueIndex, so we don't care about UniqueIndex.
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
// We're currently only receiving boolean values on `Unique` tag,
// so the UniqueConstraint name is fixed
constraint := m.DB.NamingStrategy.UniqueName(stmt.Table, field.DBName)
if unique && !field.Unique {
return m.DB.Migrator().DropConstraint(value, constraint)
}
if !unique && field.Unique {
return m.DB.Migrator().CreateConstraint(value, constraint)
}
return nil
})
}
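MigrateColumnUnique is new on the master side: it reconciles a column's UNIQUE constraint with the `unique` tag, using the constraint name produced by NamingStrategy.UniqueName. A hedged sketch of the tag it reacts to (the Account model and the expected constraint name are illustrative, not taken from this diff):

package demo

// With the default NamingStrategy the constraint is expected to be named
// uni_accounts_email; adding the tag should lead to CreateConstraint on the
// next AutoMigrate, and removing it to DropConstraint.
type Account struct {
	ID    uint
	Email string `gorm:"unique"`
}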
// ColumnTypes return columnTypes []gorm.ColumnType and execErr error // ColumnTypes return columnTypes []gorm.ColumnType and execErr error
func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) { func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
columnTypes := make([]gorm.ColumnType, 0) columnTypes := make([]gorm.ColumnType, 0)
execErr := m.RunWithValue(value, func(stmt *gorm.Statement) (err error) { execErr := m.RunWithValue(value, func(stmt *gorm.Statement) error {
rows, err := m.DB.Session(&gorm.Session{}).Table(stmt.Table).Limit(1).Rows() rows, err := m.DB.Session(&gorm.Session{}).Table(stmt.Table).Limit(1).Rows()
if err != nil { if err != nil {
return err return err
} }
defer func() { defer rows.Close()
err = rows.Close()
}()
var rawColumnTypes []*sql.ColumnType var rawColumnTypes []*sql.ColumnType
rawColumnTypes, err = rows.ColumnTypes() rawColumnTypes, err = rows.ColumnTypes()
@ -627,85 +439,53 @@ func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
} }
for _, c := range rawColumnTypes { for _, c := range rawColumnTypes {
columnTypes = append(columnTypes, ColumnType{SQLColumnType: c}) columnTypes = append(columnTypes, c)
} }
return return nil
}) })
return columnTypes, execErr return columnTypes, execErr
} }
// CreateView create view from Query in gorm.ViewOption.
// Query in gorm.ViewOption is a [subquery]
//
// // CREATE VIEW `user_view` AS SELECT * FROM `users` WHERE age > 20
// q := DB.Model(&User{}).Where("age > ?", 20)
// DB.Debug().Migrator().CreateView("user_view", gorm.ViewOption{Query: q})
//
// // CREATE OR REPLACE VIEW `users_view` AS SELECT * FROM `users` WITH CHECK OPTION
// q := DB.Model(&User{})
// DB.Debug().Migrator().CreateView("user_view", gorm.ViewOption{Query: q, Replace: true, CheckOption: "WITH CHECK OPTION"})
//
// [subquery]: https://gorm.io/docs/advanced_query.html#SubQuery
func (m Migrator) CreateView(name string, option gorm.ViewOption) error { func (m Migrator) CreateView(name string, option gorm.ViewOption) error {
if option.Query == nil { return gorm.ErrNotImplemented
return gorm.ErrSubQueryRequired
}
sql := new(strings.Builder)
sql.WriteString("CREATE ")
if option.Replace {
sql.WriteString("OR REPLACE ")
}
sql.WriteString("VIEW ")
m.QuoteTo(sql, name)
sql.WriteString(" AS ")
m.DB.Statement.AddVar(sql, option.Query)
if option.CheckOption != "" {
sql.WriteString(" ")
sql.WriteString(option.CheckOption)
}
return m.DB.Exec(m.Explain(sql.String(), m.DB.Statement.Vars...)).Error
} }
// DropView drop view
func (m Migrator) DropView(name string) error { func (m Migrator) DropView(name string) error {
return m.DB.Exec("DROP VIEW IF EXISTS ?", clause.Table{Name: name}).Error return gorm.ErrNotImplemented
} }
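Building on the CreateView/DropView implementations above, a hedged usage sketch: on the master side a nil Query yields gorm.ErrSubQueryRequired, and DropView uses DROP VIEW IF EXISTS, so dropping an absent view is not an error (model and view names are illustrative):

package demo

import "gorm.io/gorm"

type User struct {
	ID   uint
	Name string
	Age  int
}

// createAdultView is a sketch around CreateView/DropView: Replace maps to
// CREATE OR REPLACE VIEW, and DropView's IF EXISTS makes the cleanup idempotent.
func createAdultView(db *gorm.DB) error {
	q := db.Model(&User{}).Where("age >= ?", 18)
	if err := db.Migrator().CreateView("adult_users", gorm.ViewOption{Query: q, Replace: true}); err != nil {
		return err
	}
	return db.Migrator().DropView("adult_users")
}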
// GuessConstraintAndTable guess statement's constraint and it's table based on name func buildConstraint(constraint *schema.Constraint) (sql string, results []interface{}) {
// sql = "CONSTRAINT ? FOREIGN KEY ? REFERENCES ??"
// Deprecated: use GuessConstraintInterfaceAndTable instead. if constraint.OnDelete != "" {
func (m Migrator) GuessConstraintAndTable(stmt *gorm.Statement, name string) (*schema.Constraint, *schema.CheckConstraint, string) { sql += " ON DELETE " + constraint.OnDelete
constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name)
switch c := constraint.(type) {
case *schema.Constraint:
return c, nil, table
case *schema.CheckConstraint:
return nil, c, table
default:
return nil, nil, table
} }
if constraint.OnUpdate != "" {
sql += " ON UPDATE " + constraint.OnUpdate
}
var foreignKeys, references []interface{}
for _, field := range constraint.ForeignKeys {
foreignKeys = append(foreignKeys, clause.Column{Name: field.DBName})
}
for _, field := range constraint.References {
references = append(references, clause.Column{Name: field.DBName})
}
results = append(results, clause.Table{Name: constraint.Name}, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references)
return
} }
// GuessConstraintInterfaceAndTable guess statement's constraint and it's table based on name func (m Migrator) GuessConstraintAndTable(stmt *gorm.Statement, name string) (_ *schema.Constraint, _ *schema.Check, table string) {
// nolint:cyclop
func (m Migrator) GuessConstraintInterfaceAndTable(stmt *gorm.Statement, name string) (_ schema.ConstraintInterface, table string) {
if stmt.Schema == nil { if stmt.Schema == nil {
return nil, stmt.Table return nil, nil, stmt.Table
} }
checkConstraints := stmt.Schema.ParseCheckConstraints() checkConstraints := stmt.Schema.ParseCheckConstraints()
if chk, ok := checkConstraints[name]; ok { if chk, ok := checkConstraints[name]; ok {
return &chk, stmt.Table return nil, &chk, stmt.Table
}
uniqueConstraints := stmt.Schema.ParseUniqueConstraints()
if uni, ok := uniqueConstraints[name]; ok {
return &uni, stmt.Table
} }
getTable := func(rel *schema.Relationship) string { getTable := func(rel *schema.Relationship) string {
@ -720,7 +500,7 @@ func (m Migrator) GuessConstraintInterfaceAndTable(stmt *gorm.Statement, name st
for _, rel := range stmt.Schema.Relationships.Relations { for _, rel := range stmt.Schema.Relationships.Relations {
if constraint := rel.ParseConstraint(); constraint != nil && constraint.Name == name { if constraint := rel.ParseConstraint(); constraint != nil && constraint.Name == name {
return constraint, getTable(rel) return constraint, nil, getTable(rel)
} }
} }
@ -728,62 +508,64 @@ func (m Migrator) GuessConstraintInterfaceAndTable(stmt *gorm.Statement, name st
for k := range checkConstraints { for k := range checkConstraints {
if checkConstraints[k].Field == field { if checkConstraints[k].Field == field {
v := checkConstraints[k] v := checkConstraints[k]
return &v, stmt.Table return nil, &v, stmt.Table
}
}
for k := range uniqueConstraints {
if uniqueConstraints[k].Field == field {
v := uniqueConstraints[k]
return &v, stmt.Table
} }
} }
for _, rel := range stmt.Schema.Relationships.Relations { for _, rel := range stmt.Schema.Relationships.Relations {
if constraint := rel.ParseConstraint(); constraint != nil && rel.Field == field { if constraint := rel.ParseConstraint(); constraint != nil && rel.Field == field {
return constraint, getTable(rel) return constraint, nil, getTable(rel)
} }
} }
} }
return nil, stmt.Schema.Table return nil, nil, stmt.Schema.Table
} }
// CreateConstraint create constraint
func (m Migrator) CreateConstraint(value interface{}, name string) error { func (m Migrator) CreateConstraint(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error {
constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name) constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
if chk != nil {
return m.DB.Exec(
"ALTER TABLE ? ADD CONSTRAINT ? CHECK (?)",
m.CurrentTable(stmt), clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint},
).Error
}
if constraint != nil { if constraint != nil {
vars := []interface{}{clause.Table{Name: table}} var vars = []interface{}{clause.Table{Name: table}}
if stmt.TableExpr != nil { if stmt.TableExpr != nil {
vars[0] = stmt.TableExpr vars[0] = stmt.TableExpr
} }
sql, values := constraint.Build() sql, values := buildConstraint(constraint)
return m.DB.Exec("ALTER TABLE ? ADD "+sql, append(vars, values...)...).Error return m.DB.Exec("ALTER TABLE ? ADD "+sql, append(vars, values...)...).Error
} }
return nil return nil
}) })
} }
// DropConstraint drop constraint
func (m Migrator) DropConstraint(value interface{}, name string) error { func (m Migrator) DropConstraint(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error {
constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name) constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
if constraint != nil { if constraint != nil {
name = constraint.GetName() name = constraint.Name
} else if chk != nil {
name = chk.Name
} }
return m.DB.Exec("ALTER TABLE ? DROP CONSTRAINT ?", clause.Table{Name: table}, clause.Column{Name: name}).Error return m.DB.Exec("ALTER TABLE ? DROP CONSTRAINT ?", clause.Table{Name: table}, clause.Column{Name: name}).Error
}) })
} }
// HasConstraint check has constraint or not
func (m Migrator) HasConstraint(value interface{}, name string) bool { func (m Migrator) HasConstraint(value interface{}, name string) bool {
var count int64 var count int64
m.RunWithValue(value, func(stmt *gorm.Statement) error { m.RunWithValue(value, func(stmt *gorm.Statement) error {
currentDatabase := m.DB.Migrator().CurrentDatabase() currentDatabase := m.DB.Migrator().CurrentDatabase()
constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name) constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
if constraint != nil { if constraint != nil {
name = constraint.GetName() name = constraint.Name
} else if chk != nil {
name = chk.Name
} }
return m.DB.Raw( return m.DB.Raw(
@ -795,7 +577,6 @@ func (m Migrator) HasConstraint(value interface{}, name string) bool {
return count > 0 return count > 0
} }
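The Guess* helpers above let the constraint APIs accept either a constraint name or a field name. A hedged sketch with a check tag (the Profile model and constraint expression are illustrative):

package demo

import "gorm.io/gorm"

// Profile carries a named check constraint; the expression mirrors the
// length(phone) >= 10 example in the source comments.
type Profile struct {
	ID    uint
	Phone string `gorm:"check:phone_checker,length(phone) >= 10"`
}

// ensurePhoneCheck shows that HasConstraint/CreateConstraint accept the
// constraint name directly; the Guess* helpers also resolve field names.
func ensurePhoneCheck(db *gorm.DB) error {
	if !db.Migrator().HasConstraint(&Profile{}, "phone_checker") {
		return db.Migrator().CreateConstraint(&Profile{}, "phone_checker")
	}
	return nil
}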
// BuildIndexOptions build index options
func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) { func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) {
for _, opt := range opts { for _, opt := range opts {
str := stmt.Quote(opt.DBName) str := stmt.Quote(opt.DBName)
@ -817,17 +598,12 @@ func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statem
return return
} }
// BuildIndexOptionsInterface build index options interface
type BuildIndexOptionsInterface interface { type BuildIndexOptionsInterface interface {
BuildIndexOptions([]schema.IndexOption, *gorm.Statement) []interface{} BuildIndexOptions([]schema.IndexOption, *gorm.Statement) []interface{}
} }
// CreateIndex create index `name`
func (m Migrator) CreateIndex(value interface{}, name string) error { func (m Migrator) CreateIndex(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if stmt.Schema == nil {
return errors.New("failed to get schema")
}
if idx := stmt.Schema.LookIndex(name); idx != nil { if idx := stmt.Schema.LookIndex(name); idx != nil {
opts := m.DB.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt) opts := m.DB.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt)
values := []interface{}{clause.Column{Name: idx.Name}, m.CurrentTable(stmt), opts} values := []interface{}{clause.Column{Name: idx.Name}, m.CurrentTable(stmt), opts}
@ -857,28 +633,22 @@ func (m Migrator) CreateIndex(value interface{}, name string) error {
}) })
} }
// DropIndex drop index `name`
func (m Migrator) DropIndex(value interface{}, name string) error { func (m Migrator) DropIndex(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if stmt.Schema != nil { if idx := stmt.Schema.LookIndex(name); idx != nil {
if idx := stmt.Schema.LookIndex(name); idx != nil { name = idx.Name
name = idx.Name
}
} }
return m.DB.Exec("DROP INDEX ? ON ?", clause.Column{Name: name}, m.CurrentTable(stmt)).Error return m.DB.Exec("DROP INDEX ? ON ?", clause.Column{Name: name}, m.CurrentTable(stmt)).Error
}) })
} }
// HasIndex check has index `name` or not
func (m Migrator) HasIndex(value interface{}, name string) bool { func (m Migrator) HasIndex(value interface{}, name string) bool {
var count int64 var count int64
m.RunWithValue(value, func(stmt *gorm.Statement) error { m.RunWithValue(value, func(stmt *gorm.Statement) error {
currentDatabase := m.DB.Migrator().CurrentDatabase() currentDatabase := m.DB.Migrator().CurrentDatabase()
if stmt.Schema != nil { if idx := stmt.Schema.LookIndex(name); idx != nil {
if idx := stmt.Schema.LookIndex(name); idx != nil { name = idx.Name
name = idx.Name
}
} }
return m.DB.Raw( return m.DB.Raw(
@ -890,7 +660,6 @@ func (m Migrator) HasIndex(value interface{}, name string) bool {
return count > 0 return count > 0
} }
// RenameIndex rename index from oldName to newName
func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error { func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error {
return m.DB.Exec( return m.DB.Exec(
@ -900,7 +669,6 @@ func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error
}) })
} }
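A hedged sketch of the index helpers above; the model and index names are illustrative, and LookIndex lets these calls accept the field name ("Email") as well as the index name:

package demo

import "gorm.io/gorm"

type Reader struct {
	ID    uint
	Email string `gorm:"index:idx_reader_email"`
}

// rebuildEmailIndex drops, recreates and renames the tagged index; both the
// index name and the field name are accepted because LookIndex resolves
// field names to index definitions.
func rebuildEmailIndex(db *gorm.DB) error {
	m := db.Migrator()
	if m.HasIndex(&Reader{}, "idx_reader_email") {
		if err := m.DropIndex(&Reader{}, "idx_reader_email"); err != nil {
			return err
		}
	}
	if err := m.CreateIndex(&Reader{}, "idx_reader_email"); err != nil {
		return err
	}
	return m.RenameIndex(&Reader{}, "idx_reader_email", "idx_reader_email_v2")
}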
// CurrentDatabase returns current database name
func (m Migrator) CurrentDatabase() (name string) { func (m Migrator) CurrentDatabase() (name string) {
m.DB.Raw("SELECT DATABASE()").Row().Scan(&name) m.DB.Raw("SELECT DATABASE()").Row().Scan(&name)
return return
@ -927,8 +695,7 @@ func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []i
Statement: &gorm.Statement{DB: m.DB, Dest: value}, Statement: &gorm.Statement{DB: m.DB, Dest: value},
} }
beDependedOn := map[*schema.Schema]bool{} beDependedOn := map[*schema.Schema]bool{}
// support for special table name if err := dep.Parse(value); err != nil {
if err := dep.ParseWithSpecialTableName(value, m.DB.Statement.Table); err != nil {
m.DB.Logger.Error(context.Background(), "failed to parse value %#v, got error %v", value, err) m.DB.Logger.Error(context.Background(), "failed to parse value %#v, got error %v", value, err)
} }
if _, ok := parsedSchemas[dep.Statement.Schema]; ok { if _, ok := parsedSchemas[dep.Statement.Schema]; ok {
@ -936,31 +703,26 @@ func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []i
} }
parsedSchemas[dep.Statement.Schema] = true parsedSchemas[dep.Statement.Schema] = true
if !m.DB.IgnoreRelationshipsWhenMigrating { for _, rel := range dep.Schema.Relationships.Relations {
for _, rel := range dep.Schema.Relationships.Relations { if c := rel.ParseConstraint(); c != nil && c.Schema == dep.Statement.Schema && c.Schema != c.ReferenceSchema {
if rel.Field.IgnoreMigration { dep.Depends = append(dep.Depends, c.ReferenceSchema)
continue }
}
if c := rel.ParseConstraint(); c != nil && c.Schema == dep.Statement.Schema && c.Schema != c.ReferenceSchema {
dep.Depends = append(dep.Depends, c.ReferenceSchema)
}
if rel.Type == schema.HasOne || rel.Type == schema.HasMany { if rel.Type == schema.HasOne || rel.Type == schema.HasMany {
beDependedOn[rel.FieldSchema] = true beDependedOn[rel.FieldSchema] = true
} }
if rel.JoinTable != nil { if rel.JoinTable != nil {
// append join value // append join value
defer func(rel *schema.Relationship, joinValue interface{}) { defer func(rel *schema.Relationship, joinValue interface{}) {
if !beDependedOn[rel.FieldSchema] { if !beDependedOn[rel.FieldSchema] {
dep.Depends = append(dep.Depends, rel.FieldSchema) dep.Depends = append(dep.Depends, rel.FieldSchema)
} else { } else {
fieldValue := reflect.New(rel.FieldSchema.ModelType).Interface() fieldValue := reflect.New(rel.FieldSchema.ModelType).Interface()
parseDependence(fieldValue, autoAdd) parseDependence(fieldValue, autoAdd)
} }
parseDependence(joinValue, autoAdd) parseDependence(joinValue, autoAdd)
}(rel, reflect.New(rel.JoinTable.ModelType).Interface()) }(rel, reflect.New(rel.JoinTable.ModelType).Interface())
}
} }
} }
@ -1010,25 +772,9 @@ func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []i
return return
} }
// CurrentTable returns current statement's table expression
func (m Migrator) CurrentTable(stmt *gorm.Statement) interface{} { func (m Migrator) CurrentTable(stmt *gorm.Statement) interface{} {
if stmt.TableExpr != nil { if stmt.TableExpr != nil {
return *stmt.TableExpr return *stmt.TableExpr
} }
return clause.Table{Name: stmt.Table} return clause.Table{Name: stmt.Table}
} }
// GetIndexes return Indexes []gorm.Index and execErr error
func (m Migrator) GetIndexes(dst interface{}) ([]gorm.Index, error) {
return nil, errors.New("not support")
}
// GetTypeAliases return database type aliases
func (m Migrator) GetTypeAliases(databaseTypeName string) []string {
return nil
}
// TableType return tableType gorm.TableType and execErr error
func (m Migrator) TableType(dst interface{}) (gorm.TableType, error) {
return nil, errors.New("not support")
}
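The ReorderModels logic above sorts the values passed to AutoMigrate so that referenced schemas are migrated before the schemas that depend on them. A hedged sketch (models are illustrative):

package demo

import "gorm.io/gorm"

type Company struct {
	ID   uint
	Name string
}

type Employee struct {
	ID        uint
	CompanyID uint
	Company   Company // belongs-to: the foreign key makes Employee depend on Company
}

// migrateAll relies on ReorderModels: regardless of argument order, the
// companies table should be created before employees so the FK can be added.
func migrateAll(db *gorm.DB) error {
	return db.AutoMigrate(&Employee{}, &Company{})
}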

View File

@ -1,33 +0,0 @@
package migrator
import (
"database/sql"
)
// TableType table type implements TableType interface
type TableType struct {
SchemaValue string
NameValue string
TypeValue string
CommentValue sql.NullString
}
// Schema returns the schema of the table.
func (ct TableType) Schema() string {
return ct.SchemaValue
}
// Name returns the name of the table.
func (ct TableType) Name() string {
return ct.NameValue
}
// Type returns the type of the table.
func (ct TableType) Type() string {
return ct.TypeValue
}
// Comment returns the comment of current table.
func (ct TableType) Comment() (comment string, ok bool) {
return ct.CommentValue.String, ct.CommentValue.Valid
}

View File

@ -4,10 +4,9 @@ import "time"
// Model a basic GoLang struct which includes the following fields: ID, CreatedAt, UpdatedAt, DeletedAt // Model a basic GoLang struct which includes the following fields: ID, CreatedAt, UpdatedAt, DeletedAt
// It may be embedded into your model or you may build your own model without it // It may be embedded into your model or you may build your own model without it
// // type User struct {
// type User struct { // gorm.Model
// gorm.Model // }
// }
type Model struct { type Model struct {
ID uint `gorm:"primarykey"` ID uint `gorm:"primarykey"`
CreatedAt time.Time CreatedAt time.Time

View File

@ -3,86 +3,70 @@ package gorm
import ( import (
"context" "context"
"database/sql" "database/sql"
"database/sql/driver"
"errors"
"reflect"
"sync" "sync"
"time"
"gorm.io/gorm/internal/stmt_store"
) )
type Stmt struct {
*sql.Stmt
Transaction bool
}
type PreparedStmtDB struct { type PreparedStmtDB struct {
Stmts stmt_store.Store Stmts map[string]Stmt
Mux *sync.RWMutex PreparedSQL []string
Mux *sync.RWMutex
ConnPool ConnPool
} }
// NewPreparedStmtDB creates and initializes a new instance of PreparedStmtDB.
//
// Parameters:
// - connPool: A connection pool that implements the ConnPool interface, used for managing database connections.
// - maxSize: The maximum number of prepared statements that can be stored in the statement store.
// - ttl: The time-to-live duration for each prepared statement in the store. Statements older than this duration will be automatically removed.
//
// Returns:
// - A pointer to a PreparedStmtDB instance, which manages prepared statements using the provided connection pool and configuration.
func NewPreparedStmtDB(connPool ConnPool, maxSize int, ttl time.Duration) *PreparedStmtDB {
return &PreparedStmtDB{
ConnPool: connPool, // Assigns the provided connection pool to manage database connections.
Stmts: stmt_store.New(maxSize, ttl), // Initializes a new statement store with the specified maximum size and TTL.
Mux: &sync.RWMutex{}, // Sets up a read-write mutex for synchronizing access to the statement store.
}
}
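In normal use the statement store is enabled through gorm.Config rather than by calling NewPreparedStmtDB directly. A hedged sketch; PrepareStmt itself is long-standing, while PrepareStmtMaxSize and PrepareStmtTTL are assumed here to mirror the maxSize/ttl parameters above and only exist on versions that ship the stmt_store package:

package main

import (
	"time"

	"gorm.io/driver/sqlite" // illustrative driver
	"gorm.io/gorm"
)

func main() {
	db, err := gorm.Open(sqlite.Open("stmt_demo.db"), &gorm.Config{
		PrepareStmt: true, // wraps the connection pool in PreparedStmtDB
		// The two fields below are assumed to mirror NewPreparedStmtDB's
		// maxSize/ttl parameters; they only exist on versions with stmt_store.
		PrepareStmtMaxSize: 512,
		PrepareStmtTTL:     time.Hour,
	})
	if err != nil {
		panic(err)
	}
	sqlDB, err := db.DB()
	if err != nil {
		panic(err)
	}
	defer sqlDB.Close()
}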
// GetDBConn returns the underlying *sql.DB connection
func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) { func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) {
if sqldb, ok := db.ConnPool.(*sql.DB); ok {
return sqldb, nil
}
if dbConnector, ok := db.ConnPool.(GetDBConnector); ok && dbConnector != nil { if dbConnector, ok := db.ConnPool.(GetDBConnector); ok && dbConnector != nil {
return dbConnector.GetDBConn() return dbConnector.GetDBConn()
} }
if sqldb, ok := db.ConnPool.(*sql.DB); ok {
return sqldb, nil
}
return nil, ErrInvalidDB return nil, ErrInvalidDB
} }
// Close closes all prepared statements in the store
func (db *PreparedStmtDB) Close() { func (db *PreparedStmtDB) Close() {
db.Mux.Lock() db.Mux.Lock()
defer db.Mux.Unlock() defer db.Mux.Unlock()
for _, key := range db.Stmts.Keys() { for _, query := range db.PreparedSQL {
db.Stmts.Delete(key) if stmt, ok := db.Stmts[query]; ok {
delete(db.Stmts, query)
go stmt.Close()
}
} }
} }
// Reset Deprecated use Close instead func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) {
func (db *PreparedStmtDB) Reset() {
db.Close()
}
func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (_ *stmt_store.Stmt, err error) {
db.Mux.RLock() db.Mux.RLock()
if db.Stmts != nil { if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) {
if stmt, ok := db.Stmts.Get(query); ok && (!stmt.Transaction || isTransaction) { db.Mux.RUnlock()
db.Mux.RUnlock() return stmt, nil
return stmt, stmt.Error()
}
} }
db.Mux.RUnlock() db.Mux.RUnlock()
// retry
db.Mux.Lock() db.Mux.Lock()
if db.Stmts != nil { defer db.Mux.Unlock()
if stmt, ok := db.Stmts.Get(query); ok && (!stmt.Transaction || isTransaction) {
db.Mux.Unlock() // double check
return stmt, stmt.Error() if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) {
} return stmt, nil
} else if ok {
go stmt.Close()
} }
return db.Stmts.New(ctx, query, isTransaction, conn, db.Mux) stmt, err := conn.PrepareContext(ctx, query)
if err == nil {
db.Stmts[query] = Stmt{Stmt: stmt, Transaction: isTransaction}
db.PreparedSQL = append(db.PreparedSQL, query)
}
return db.Stmts[query], err
} }
func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (ConnPool, error) { func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (ConnPool, error) {
@ -90,19 +74,6 @@ func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (Conn
tx, err := beginner.BeginTx(ctx, opt) tx, err := beginner.BeginTx(ctx, opt)
return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err
} }
beginner, ok := db.ConnPool.(ConnPoolBeginner)
if !ok {
return nil, ErrInvalidTransaction
}
connPool, err := beginner.BeginTx(ctx, opt)
if err != nil {
return nil, err
}
if tx, ok := connPool.(Tx); ok {
return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, nil
}
return nil, ErrInvalidTransaction return nil, ErrInvalidTransaction
} }
@ -110,8 +81,11 @@ func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args ..
stmt, err := db.prepare(ctx, db.ConnPool, false, query) stmt, err := db.prepare(ctx, db.ConnPool, false, query)
if err == nil { if err == nil {
result, err = stmt.ExecContext(ctx, args...) result, err = stmt.ExecContext(ctx, args...)
if errors.Is(err, driver.ErrBadConn) { if err != nil {
db.Stmts.Delete(query) db.Mux.Lock()
defer db.Mux.Unlock()
go stmt.Close()
delete(db.Stmts, query)
} }
} }
return result, err return result, err
@ -121,8 +95,12 @@ func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args .
stmt, err := db.prepare(ctx, db.ConnPool, false, query) stmt, err := db.prepare(ctx, db.ConnPool, false, query)
if err == nil { if err == nil {
rows, err = stmt.QueryContext(ctx, args...) rows, err = stmt.QueryContext(ctx, args...)
if errors.Is(err, driver.ErrBadConn) { if err != nil {
db.Stmts.Delete(query) db.Mux.Lock()
defer db.Mux.Unlock()
go stmt.Close()
delete(db.Stmts, query)
} }
} }
return rows, err return rows, err
@ -136,32 +114,20 @@ func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string, arg
return &sql.Row{} return &sql.Row{}
} }
func (db *PreparedStmtDB) Ping() error {
conn, err := db.GetDBConn()
if err != nil {
return err
}
return conn.Ping()
}
type PreparedStmtTX struct { type PreparedStmtTX struct {
Tx *sql.Tx
PreparedStmtDB *PreparedStmtDB PreparedStmtDB *PreparedStmtDB
} }
func (db *PreparedStmtTX) GetDBConn() (*sql.DB, error) {
return db.PreparedStmtDB.GetDBConn()
}
func (tx *PreparedStmtTX) Commit() error { func (tx *PreparedStmtTX) Commit() error {
if tx.Tx != nil && !reflect.ValueOf(tx.Tx).IsNil() { if tx.Tx != nil {
return tx.Tx.Commit() return tx.Tx.Commit()
} }
return ErrInvalidTransaction return ErrInvalidTransaction
} }
func (tx *PreparedStmtTX) Rollback() error { func (tx *PreparedStmtTX) Rollback() error {
if tx.Tx != nil && !reflect.ValueOf(tx.Tx).IsNil() { if tx.Tx != nil {
return tx.Tx.Rollback() return tx.Tx.Rollback()
} }
return ErrInvalidTransaction return ErrInvalidTransaction
@ -171,8 +137,12 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ..
stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query) stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
if err == nil { if err == nil {
result, err = tx.Tx.StmtContext(ctx, stmt.Stmt).ExecContext(ctx, args...) result, err = tx.Tx.StmtContext(ctx, stmt.Stmt).ExecContext(ctx, args...)
if errors.Is(err, driver.ErrBadConn) { if err != nil {
tx.PreparedStmtDB.Stmts.Delete(query) tx.PreparedStmtDB.Mux.Lock()
defer tx.PreparedStmtDB.Mux.Unlock()
go stmt.Close()
delete(tx.PreparedStmtDB.Stmts, query)
} }
} }
return result, err return result, err
@ -181,9 +151,13 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ..
func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) { func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) {
stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query) stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
if err == nil { if err == nil {
rows, err = tx.Tx.StmtContext(ctx, stmt.Stmt).QueryContext(ctx, args...) rows, err = tx.Tx.Stmt(stmt.Stmt).QueryContext(ctx, args...)
if errors.Is(err, driver.ErrBadConn) { if err != nil {
tx.PreparedStmtDB.Stmts.Delete(query) tx.PreparedStmtDB.Mux.Lock()
defer tx.PreparedStmtDB.Mux.Unlock()
go stmt.Close()
delete(tx.PreparedStmtDB.Stmts, query)
} }
} }
return rows, err return rows, err
@ -196,11 +170,3 @@ func (tx *PreparedStmtTX) QueryRowContext(ctx context.Context, query string, arg
} }
return &sql.Row{} return &sql.Row{}
} }
func (tx *PreparedStmtTX) Ping() error {
conn, err := tx.GetDBConn()
if err != nil {
return err
}
return conn.Ping()
}

scan.go
View File

@ -8,15 +8,13 @@ import (
"time" "time"
"gorm.io/gorm/schema" "gorm.io/gorm/schema"
"gorm.io/gorm/utils"
) )
// prepareValues prepare values slice
func prepareValues(values []interface{}, db *DB, columnTypes []*sql.ColumnType, columns []string) { func prepareValues(values []interface{}, db *DB, columnTypes []*sql.ColumnType, columns []string) {
if db.Statement.Schema != nil { if db.Statement.Schema != nil {
for idx, name := range columns { for idx, name := range columns {
if field := db.Statement.Schema.LookUpField(name); field != nil { if field := db.Statement.Schema.LookUpField(name); field != nil {
values[idx] = reflect.New(reflect.PointerTo(field.FieldType)).Interface() values[idx] = reflect.New(reflect.PtrTo(field.FieldType)).Interface()
continue continue
} }
values[idx] = new(interface{}) values[idx] = new(interface{})
@ -24,7 +22,7 @@ func prepareValues(values []interface{}, db *DB, columnTypes []*sql.ColumnType,
} else if len(columnTypes) > 0 { } else if len(columnTypes) > 0 {
for idx, columnType := range columnTypes { for idx, columnType := range columnTypes {
if columnType.ScanType() != nil { if columnType.ScanType() != nil {
values[idx] = reflect.New(reflect.PointerTo(columnType.ScanType())).Interface() values[idx] = reflect.New(reflect.PtrTo(columnType.ScanType())).Interface()
} else { } else {
values[idx] = new(interface{}) values[idx] = new(interface{})
} }
@ -51,96 +49,9 @@ func scanIntoMap(mapValue map[string]interface{}, values []interface{}, columns
} }
} }
func (db *DB) scanIntoStruct(rows Rows, reflectValue reflect.Value, values []interface{}, fields []*schema.Field, joinFields [][]*schema.Field) { func Scan(rows *sql.Rows, db *DB, initialized bool) {
for idx, field := range fields { columns, _ := rows.Columns()
if field != nil { values := make([]interface{}, len(columns))
values[idx] = field.NewValuePool.Get()
} else if len(fields) == 1 {
if reflectValue.CanAddr() {
values[idx] = reflectValue.Addr().Interface()
} else {
values[idx] = reflectValue.Interface()
}
}
}
db.RowsAffected++
db.AddError(rows.Scan(values...))
joinedNestedSchemaMap := make(map[string]interface{})
for idx, field := range fields {
if field == nil {
continue
}
if len(joinFields) == 0 || len(joinFields[idx]) == 0 {
db.AddError(field.Set(db.Statement.Context, reflectValue, values[idx]))
} else { // joinFields count is larger than 2 when using join
var isNilPtrValue bool
var relValue reflect.Value
// does not contain raw dbname
nestedJoinSchemas := joinFields[idx][:len(joinFields[idx])-1]
// current reflect value
currentReflectValue := reflectValue
fullRels := make([]string, 0, len(nestedJoinSchemas))
for _, joinSchema := range nestedJoinSchemas {
fullRels = append(fullRels, joinSchema.Name)
relValue = joinSchema.ReflectValueOf(db.Statement.Context, currentReflectValue)
if relValue.Kind() == reflect.Ptr {
fullRelsName := utils.JoinNestedRelationNames(fullRels)
// same nested structure
if _, ok := joinedNestedSchemaMap[fullRelsName]; !ok {
if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() {
isNilPtrValue = true
break
}
relValue.Set(reflect.New(relValue.Type().Elem()))
joinedNestedSchemaMap[fullRelsName] = nil
}
}
currentReflectValue = relValue
}
if !isNilPtrValue { // ignore if value is nil
f := joinFields[idx][len(joinFields[idx])-1]
db.AddError(f.Set(db.Statement.Context, relValue, values[idx]))
}
}
// release data to pool
field.NewValuePool.Put(values[idx])
}
}
// ScanMode scan data mode
type ScanMode uint8
// scan modes
const (
ScanInitialized ScanMode = 1 << 0 // 1
ScanUpdate ScanMode = 1 << 1 // 2
ScanOnConflictDoNothing ScanMode = 1 << 2 // 4
)
// Scan scan rows into db statement
func Scan(rows Rows, db *DB, mode ScanMode) {
var (
columns, _ = rows.Columns()
values = make([]interface{}, len(columns))
initialized = mode&ScanInitialized != 0
update = mode&ScanUpdate != 0
onConflictDonothing = mode&ScanOnConflictDoNothing != 0
)
if len(db.Statement.ColumnMapping) > 0 {
for i, column := range columns {
v, ok := db.Statement.ColumnMapping[column]
if ok {
columns[i] = v
}
}
}
db.RowsAffected = 0 db.RowsAffected = 0
switch dest := db.Statement.Dest.(type) { switch dest := db.Statement.Dest.(type) {
@ -155,9 +66,6 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
mapValue, ok := dest.(map[string]interface{}) mapValue, ok := dest.(map[string]interface{})
if !ok { if !ok {
if v, ok := dest.(*map[string]interface{}); ok { if v, ok := dest.(*map[string]interface{}); ok {
if *v == nil {
*v = map[string]interface{}{}
}
mapValue = *v mapValue = *v
} }
} }
@ -188,171 +96,155 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
db.AddError(rows.Scan(dest)) db.AddError(rows.Scan(dest))
} }
default: default:
var ( Schema := db.Statement.Schema
fields = make([]*schema.Field, len(columns)) reflectValue := db.Statement.ReflectValue
joinFields [][]*schema.Field
sch = db.Statement.Schema
reflectValue = db.Statement.ReflectValue
)
if reflectValue.Kind() == reflect.Interface { if reflectValue.Kind() == reflect.Interface {
reflectValue = reflectValue.Elem() reflectValue = reflectValue.Elem()
} }
reflectValueType := reflectValue.Type()
switch reflectValueType.Kind() {
case reflect.Array, reflect.Slice:
reflectValueType = reflectValueType.Elem()
}
isPtr := reflectValueType.Kind() == reflect.Ptr
if isPtr {
reflectValueType = reflectValueType.Elem()
}
if sch != nil {
if reflectValueType != sch.ModelType && reflectValueType.Kind() == reflect.Struct {
sch, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy)
}
if len(columns) == 1 {
// Is Pluck
if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); (reflectValueType != sch.ModelType && ok) || // is scanner
reflectValueType.Kind() != reflect.Struct || // is not struct
sch.ModelType.ConvertibleTo(schema.TimeReflectType) { // is time
sch = nil
}
}
// Not Pluck
if sch != nil {
matchedFieldCount := make(map[string]int, len(columns))
for idx, column := range columns {
if field := sch.LookUpField(column); field != nil && field.Readable {
fields[idx] = field
if count, ok := matchedFieldCount[column]; ok {
// handle duplicate fields
for _, selectField := range sch.Fields {
if selectField.DBName == column && selectField.Readable {
if count == 0 {
matchedFieldCount[column]++
fields[idx] = selectField
break
}
count--
}
}
} else {
matchedFieldCount[column] = 1
}
} else if names := utils.SplitNestedRelationName(column); len(names) > 1 { // has nested relation
aliasName := utils.JoinNestedRelationNames(names[0 : len(names)-1])
for _, join := range db.Statement.Joins {
if join.Alias == aliasName {
names = append(strings.Split(join.Name, "."), names[len(names)-1])
break
}
}
if rel, ok := sch.Relationships.Relations[names[0]]; ok {
subNameCount := len(names)
// nested relation fields
relFields := make([]*schema.Field, 0, subNameCount-1)
relFields = append(relFields, rel.Field)
for _, name := range names[1 : subNameCount-1] {
rel = rel.FieldSchema.Relationships.Relations[name]
relFields = append(relFields, rel.Field)
}
// latest name is raw dbname
dbName := names[subNameCount-1]
if field := rel.FieldSchema.LookUpField(dbName); field != nil && field.Readable {
fields[idx] = field
if len(joinFields) == 0 {
joinFields = make([][]*schema.Field, len(columns))
}
relFields = append(relFields, field)
joinFields[idx] = relFields
continue
}
}
var val interface{}
values[idx] = &val
} else {
var val interface{}
values[idx] = &val
}
}
}
}
switch reflectValue.Kind() { switch reflectValue.Kind() {
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
var ( var (
elem reflect.Value reflectValueType = reflectValue.Type().Elem()
isArrayKind = reflectValue.Kind() == reflect.Array isPtr = reflectValueType.Kind() == reflect.Ptr
fields = make([]*schema.Field, len(columns))
joinFields [][2]*schema.Field
) )
if !update || reflectValue.Len() == 0 { if isPtr {
update = false reflectValueType = reflectValueType.Elem()
if isArrayKind { }
db.Statement.ReflectValue.Set(reflect.Zero(reflectValue.Type()))
} else { db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20))
// if the slice cap is externally initialized, the externally initialized slice is directly used here
if reflectValue.Cap() == 0 { if Schema != nil {
db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20)) if reflectValueType != Schema.ModelType && reflectValueType.Kind() == reflect.Struct {
Schema, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy)
}
for idx, column := range columns {
if field := Schema.LookUpField(column); field != nil && field.Readable {
fields[idx] = field
} else if names := strings.Split(column, "__"); len(names) > 1 {
if rel, ok := Schema.Relationships.Relations[names[0]]; ok {
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
fields[idx] = field
if len(joinFields) == 0 {
joinFields = make([][2]*schema.Field, len(columns))
}
joinFields[idx] = [2]*schema.Field{rel.Field, field}
continue
}
}
values[idx] = &sql.RawBytes{}
} else { } else {
reflectValue.SetLen(0) values[idx] = &sql.RawBytes{}
db.Statement.ReflectValue.Set(reflectValue)
} }
} }
} }
// pluck values into slice of data
isPluck := false
if len(fields) == 1 {
if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); ok || // is scanner
reflectValueType.Kind() != reflect.Struct || // is not struct
Schema.ModelType.ConvertibleTo(schema.TimeReflectType) { // is time
isPluck = true
}
}
for initialized || rows.Next() { for initialized || rows.Next() {
BEGIN:
initialized = false initialized = false
db.RowsAffected++
if update { elem := reflect.New(reflectValueType)
if int(db.RowsAffected) >= reflectValue.Len() { if isPluck {
return db.AddError(rows.Scan(elem.Interface()))
} else {
for idx, field := range fields {
if field != nil {
values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface()
}
} }
elem = reflectValue.Index(int(db.RowsAffected))
if onConflictDonothing { db.AddError(rows.Scan(values...))
for _, field := range fields {
if _, ok := field.ValueOf(db.Statement.Context, elem); !ok { for idx, field := range fields {
db.RowsAffected++ if len(joinFields) != 0 && joinFields[idx][0] != nil {
goto BEGIN value := reflect.ValueOf(values[idx]).Elem()
relValue := joinFields[idx][0].ReflectValueOf(elem)
if relValue.Kind() == reflect.Ptr && relValue.IsNil() {
if value.IsNil() {
continue
}
relValue.Set(reflect.New(relValue.Type().Elem()))
}
field.Set(relValue, values[idx])
} else if field != nil {
field.Set(elem, values[idx])
}
}
}
if isPtr {
reflectValue = reflect.Append(reflectValue, elem)
} else {
reflectValue = reflect.Append(reflectValue, elem.Elem())
}
}
db.Statement.ReflectValue.Set(reflectValue)
case reflect.Struct, reflect.Ptr:
if reflectValue.Type() != Schema.ModelType {
Schema, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy)
}
if initialized || rows.Next() {
for idx, column := range columns {
if field := Schema.LookUpField(column); field != nil && field.Readable {
values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface()
} else if names := strings.Split(column, "__"); len(names) > 1 {
if rel, ok := Schema.Relationships.Relations[names[0]]; ok {
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface()
continue
}
}
values[idx] = &sql.RawBytes{}
} else if len(columns) == 1 {
values[idx] = dest
} else {
values[idx] = &sql.RawBytes{}
}
}
db.RowsAffected++
db.AddError(rows.Scan(values...))
for idx, column := range columns {
if field := Schema.LookUpField(column); field != nil && field.Readable {
field.Set(reflectValue, values[idx])
} else if names := strings.Split(column, "__"); len(names) > 1 {
if rel, ok := Schema.Relationships.Relations[names[0]]; ok {
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
relValue := rel.Field.ReflectValueOf(reflectValue)
value := reflect.ValueOf(values[idx]).Elem()
if relValue.Kind() == reflect.Ptr && relValue.IsNil() {
if value.IsNil() {
continue
}
relValue.Set(reflect.New(relValue.Type().Elem()))
}
field.Set(relValue, values[idx])
} }
} }
} }
} else {
elem = reflect.New(reflectValueType)
} }
db.scanIntoStruct(rows, elem, values, fields, joinFields)
if !update {
if !isPtr {
elem = elem.Elem()
}
if isArrayKind {
if reflectValue.Len() >= int(db.RowsAffected) {
reflectValue.Index(int(db.RowsAffected - 1)).Set(elem)
}
} else {
reflectValue = reflect.Append(reflectValue, elem)
}
}
}
if !update {
db.Statement.ReflectValue.Set(reflectValue)
}
case reflect.Struct, reflect.Ptr:
if initialized || rows.Next() {
if mode == ScanInitialized && reflectValue.Kind() == reflect.Struct {
db.Statement.ReflectValue.Set(reflect.Zero(reflectValue.Type()))
}
db.scanIntoStruct(rows, reflectValue, values, fields, joinFields)
} }
default: default:
db.AddError(rows.Scan(dest)) db.AddError(rows.Scan(dest))
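Callers rarely pass ScanMode flags themselves; DB.ScanRows hands an already-advanced cursor to the Scan path shown above (ScanInitialized). A hedged iteration sketch (the User model is illustrative):

package demo

import "gorm.io/gorm"

type User struct {
	ID   uint
	Name string
	Age  int
}

// streamUsers iterates rows manually and lets ScanRows fill each struct via
// the Scan function above.
func streamUsers(db *gorm.DB, handle func(User)) error {
	rows, err := db.Model(&User{}).Where("age > ?", 20).Rows()
	if err != nil {
		return err
	}
	defer rows.Close()

	for rows.Next() {
		var u User
		if err := db.ScanRows(rows, &u); err != nil {
			return err
		}
		handle(u)
	}
	return rows.Err()
}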

View File

@ -9,7 +9,8 @@ import (
"gorm.io/gorm/schema" "gorm.io/gorm/schema"
) )
type UserWithCallback struct{} type UserWithCallback struct {
}
func (UserWithCallback) BeforeSave(*gorm.DB) error { func (UserWithCallback) BeforeSave(*gorm.DB) error {
return nil return nil

schema/check.go
View File

@ -0,0 +1,37 @@
package schema
import (
"regexp"
"strings"
)
var (
// reg match english letters and midline
regEnLetterAndMidline = regexp.MustCompile("^[A-Za-z-_]+$")
)
type Check struct {
Name string
Constraint string // length(phone) >= 10
*Field
}
// ParseCheckConstraints parse schema check constraints
func (schema *Schema) ParseCheckConstraints() map[string]Check {
var checks = map[string]Check{}
for _, field := range schema.FieldsByDBName {
if chk := field.TagSettings["CHECK"]; chk != "" {
names := strings.Split(chk, ",")
if len(names) > 1 && regEnLetterAndMidline.MatchString(names[0]) {
checks[names[0]] = Check{Name: names[0], Constraint: strings.Join(names[1:], ","), Field: field}
} else {
if names[0] == "" {
chk = strings.Join(names[1:], ",")
}
name := schema.namer.CheckerName(schema.Table, field.DBName)
checks[name] = Check{Name: name, Constraint: chk, Field: field}
}
}
}
return checks
}

View File

@ -6,7 +6,6 @@ import (
"testing" "testing"
"gorm.io/gorm/schema" "gorm.io/gorm/schema"
"gorm.io/gorm/utils/tests"
) )
type UserCheck struct { type UserCheck struct {
@ -21,7 +20,7 @@ func TestParseCheck(t *testing.T) {
t.Fatalf("failed to parse user check, got error %v", err) t.Fatalf("failed to parse user check, got error %v", err)
} }
results := map[string]schema.CheckConstraint{ results := map[string]schema.Check{
"name_checker": { "name_checker": {
Name: "name_checker", Name: "name_checker",
Constraint: "name <> 'jinzhu'", Constraint: "name <> 'jinzhu'",
@ -54,31 +53,3 @@ func TestParseCheck(t *testing.T) {
} }
} }
} }
func TestParseUniqueConstraints(t *testing.T) {
type UserUnique struct {
Name1 string `gorm:"unique"`
Name2 string `gorm:"uniqueIndex"`
}
user, err := schema.Parse(&UserUnique{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("failed to parse user unique, got error %v", err)
}
constraints := user.ParseUniqueConstraints()
results := map[string]schema.UniqueConstraint{
"uni_user_uniques_name1": {
Name: "uni_user_uniques_name1",
Field: &schema.Field{Name: "Name1", Unique: true},
},
}
for k, result := range results {
v, ok := constraints[k]
if !ok {
t.Errorf("Failed to found unique constraint %v from parsed constraints %+v", k, constraints)
}
tests.AssertObjEqual(t, result, v, "Name")
tests.AssertObjEqual(t, result.Field, v.Field, "Name", "Unique", "UniqueIndex")
}
}

View File

@ -1,66 +0,0 @@
package schema
import (
"regexp"
"strings"
"gorm.io/gorm/clause"
)
// reg match english letters and midline
var regEnLetterAndMidline = regexp.MustCompile(`^[\w-]+$`)
type CheckConstraint struct {
Name string
Constraint string // length(phone) >= 10
*Field
}
func (chk *CheckConstraint) GetName() string { return chk.Name }
func (chk *CheckConstraint) Build() (sql string, vars []interface{}) {
return "CONSTRAINT ? CHECK (?)", []interface{}{clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}}
}
// ParseCheckConstraints parse schema check constraints
func (schema *Schema) ParseCheckConstraints() map[string]CheckConstraint {
checks := map[string]CheckConstraint{}
for _, field := range schema.FieldsByDBName {
if chk := field.TagSettings["CHECK"]; chk != "" {
names := strings.Split(chk, ",")
if len(names) > 1 && regEnLetterAndMidline.MatchString(names[0]) {
checks[names[0]] = CheckConstraint{Name: names[0], Constraint: strings.Join(names[1:], ","), Field: field}
} else {
if names[0] == "" {
chk = strings.Join(names[1:], ",")
}
name := schema.namer.CheckerName(schema.Table, field.DBName)
checks[name] = CheckConstraint{Name: name, Constraint: chk, Field: field}
}
}
}
return checks
}
type UniqueConstraint struct {
Name string
Field *Field
}
func (uni *UniqueConstraint) GetName() string { return uni.Name }
func (uni *UniqueConstraint) Build() (sql string, vars []interface{}) {
return "CONSTRAINT ? UNIQUE (?)", []interface{}{clause.Column{Name: uni.Name}, clause.Column{Name: uni.Field.DBName}}
}
// ParseUniqueConstraints parse schema unique constraints
func (schema *Schema) ParseUniqueConstraints() map[string]UniqueConstraint {
uniques := make(map[string]UniqueConstraint)
for _, field := range schema.Fields {
if field.Unique {
name := schema.namer.UniqueName(schema.Table, field.DBName)
uniques[name] = UniqueConstraint{Name: name, Field: field}
}
}
return uniques
}
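A hedged sketch of the tags these parsers consume and the names the default namer is expected to produce (the Person model is illustrative; the uni_/chk_ prefixes follow the tests shown above):

package demo

// Person sketches the tag forms: an explicitly named check (name_checker), an
// anonymous check that should become chk_people_age, and a unique column that
// should become uni_people_email with the default NamingStrategy.
type Person struct {
	ID    uint
	Name  string `gorm:"check:name_checker,name <> 'jinzhu'"`
	Age   int    `gorm:"check:age > 0"`
	Email string `gorm:"unique"`
}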

File diff suppressed because it is too large

View File

@ -1,7 +1,6 @@
package schema_test package schema_test
import ( import (
"context"
"database/sql" "database/sql"
"reflect" "reflect"
"sync" "sync"
@ -58,7 +57,7 @@ func TestFieldValuerAndSetter(t *testing.T) {
} }
for k, v := range newValues { for k, v := range newValues {
if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil { if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil {
t.Errorf("no error should happen when assign value to field %v, but got %v", k, err) t.Errorf("no error should happen when assign value to field %v, but got %v", k, err)
} }
} }
@ -81,7 +80,7 @@ func TestFieldValuerAndSetter(t *testing.T) {
} }
for k, v := range newValues2 { for k, v := range newValues2 {
if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil { if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil {
t.Errorf("no error should happen when assign value to field %v, but got %v", k, err) t.Errorf("no error should happen when assign value to field %v, but got %v", k, err)
} }
} }
@ -133,7 +132,7 @@ func TestPointerFieldValuerAndSetter(t *testing.T) {
} }
for k, v := range newValues { for k, v := range newValues {
if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil { if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil {
t.Errorf("no error should happen when assign value to field %v, but got %v", k, err) t.Errorf("no error should happen when assign value to field %v, but got %v", k, err)
} }
} }
@ -152,7 +151,7 @@ func TestPointerFieldValuerAndSetter(t *testing.T) {
} }
for k, v := range newValues2 { for k, v := range newValues2 {
if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil { if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil {
t.Errorf("no error should happen when assign value to field %v, but got %v", k, err) t.Errorf("no error should happen when assign value to field %v, but got %v", k, err)
} }
} }
@ -203,7 +202,7 @@ func TestAdvancedDataTypeValuerAndSetter(t *testing.T) {
} }
for k, v := range newValues { for k, v := range newValues {
if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil { if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil {
t.Errorf("no error should happen when assign value to field %v, but got %v", k, err) t.Errorf("no error should happen when assign value to field %v, but got %v", k, err)
} }
} }
@ -220,7 +219,7 @@ func TestAdvancedDataTypeValuerAndSetter(t *testing.T) {
} }
for k, v := range newValues2 { for k, v := range newValues2 {
if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil { if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil {
t.Errorf("no error should happen when assign value to field %v, but got %v", k, err) t.Errorf("no error should happen when assign value to field %v, but got %v", k, err)
} }
} }
@ -245,7 +244,7 @@ func TestParseFieldWithPermission(t *testing.T) {
t.Fatalf("Failed to parse user with permission, got error %v", err) t.Fatalf("Failed to parse user with permission, got error %v", err)
} }
fields := []*schema.Field{ fields := []schema.Field{
{Name: "ID", DBName: "id", BindNames: []string{"ID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, Creatable: true, Updatable: true, Readable: true, HasDefaultValue: true, AutoIncrement: true}, {Name: "ID", DBName: "id", BindNames: []string{"ID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, Creatable: true, Updatable: true, Readable: true, HasDefaultValue: true, AutoIncrement: true},
{Name: "Name", DBName: "", BindNames: []string{"Name"}, DataType: "", Tag: `gorm:"-"`, Creatable: false, Updatable: false, Readable: false}, {Name: "Name", DBName: "", BindNames: []string{"Name"}, DataType: "", Tag: `gorm:"-"`, Creatable: false, Updatable: false, Readable: false},
{Name: "Name2", DBName: "name2", BindNames: []string{"Name2"}, DataType: schema.String, Tag: `gorm:"->"`, Creatable: false, Updatable: false, Readable: true}, {Name: "Name2", DBName: "name2", BindNames: []string{"Name2"}, DataType: schema.String, Tag: `gorm:"->"`, Creatable: false, Updatable: false, Readable: true},
@ -258,77 +257,6 @@ func TestParseFieldWithPermission(t *testing.T) {
} }
for _, f := range fields { for _, f := range fields {
checkSchemaField(t, user, f, func(f *schema.Field) {}) checkSchemaField(t, user, &f, func(f *schema.Field) {})
}
}
type (
ID int64
INT int
INT8 int8
INT16 int16
INT32 int32
INT64 int64
UINT uint
UINT8 uint8
UINT16 uint16
UINT32 uint32
UINT64 uint64
FLOAT32 float32
FLOAT64 float64
BOOL bool
STRING string
TIME time.Time
BYTES []byte
TypeAlias struct {
ID
INT `gorm:"column:fint"`
INT8 `gorm:"column:fint8"`
INT16 `gorm:"column:fint16"`
INT32 `gorm:"column:fint32"`
INT64 `gorm:"column:fint64"`
UINT `gorm:"column:fuint"`
UINT8 `gorm:"column:fuint8"`
UINT16 `gorm:"column:fuint16"`
UINT32 `gorm:"column:fuint32"`
UINT64 `gorm:"column:fuint64"`
FLOAT32 `gorm:"column:ffloat32"`
FLOAT64 `gorm:"column:ffloat64"`
BOOL `gorm:"column:fbool"`
STRING `gorm:"column:fstring"`
TIME `gorm:"column:ftime"`
BYTES `gorm:"column:fbytes"`
}
)
func TestTypeAliasField(t *testing.T) {
alias, err := schema.Parse(&TypeAlias{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("Failed to parse TypeAlias with permission, got error %v", err)
}
fields := []*schema.Field{
{Name: "ID", DBName: "id", BindNames: []string{"ID"}, DataType: schema.Int, Creatable: true, Updatable: true, Readable: true, Size: 64, PrimaryKey: true, HasDefaultValue: true, AutoIncrement: true},
{Name: "INT", DBName: "fint", BindNames: []string{"INT"}, DataType: schema.Int, Creatable: true, Updatable: true, Readable: true, Size: 64, Tag: `gorm:"column:fint"`},
{Name: "INT8", DBName: "fint8", BindNames: []string{"INT8"}, DataType: schema.Int, Creatable: true, Updatable: true, Readable: true, Size: 8, Tag: `gorm:"column:fint8"`},
{Name: "INT16", DBName: "fint16", BindNames: []string{"INT16"}, DataType: schema.Int, Creatable: true, Updatable: true, Readable: true, Size: 16, Tag: `gorm:"column:fint16"`},
{Name: "INT32", DBName: "fint32", BindNames: []string{"INT32"}, DataType: schema.Int, Creatable: true, Updatable: true, Readable: true, Size: 32, Tag: `gorm:"column:fint32"`},
{Name: "INT64", DBName: "fint64", BindNames: []string{"INT64"}, DataType: schema.Int, Creatable: true, Updatable: true, Readable: true, Size: 64, Tag: `gorm:"column:fint64"`},
{Name: "UINT", DBName: "fuint", BindNames: []string{"UINT"}, DataType: schema.Uint, Creatable: true, Updatable: true, Readable: true, Size: 64, Tag: `gorm:"column:fuint"`},
{Name: "UINT8", DBName: "fuint8", BindNames: []string{"UINT8"}, DataType: schema.Uint, Creatable: true, Updatable: true, Readable: true, Size: 8, Tag: `gorm:"column:fuint8"`},
{Name: "UINT16", DBName: "fuint16", BindNames: []string{"UINT16"}, DataType: schema.Uint, Creatable: true, Updatable: true, Readable: true, Size: 16, Tag: `gorm:"column:fuint16"`},
{Name: "UINT32", DBName: "fuint32", BindNames: []string{"UINT32"}, DataType: schema.Uint, Creatable: true, Updatable: true, Readable: true, Size: 32, Tag: `gorm:"column:fuint32"`},
{Name: "UINT64", DBName: "fuint64", BindNames: []string{"UINT64"}, DataType: schema.Uint, Creatable: true, Updatable: true, Readable: true, Size: 64, Tag: `gorm:"column:fuint64"`},
{Name: "FLOAT32", DBName: "ffloat32", BindNames: []string{"FLOAT32"}, DataType: schema.Float, Creatable: true, Updatable: true, Readable: true, Size: 32, Tag: `gorm:"column:ffloat32"`},
{Name: "FLOAT64", DBName: "ffloat64", BindNames: []string{"FLOAT64"}, DataType: schema.Float, Creatable: true, Updatable: true, Readable: true, Size: 64, Tag: `gorm:"column:ffloat64"`},
{Name: "BOOL", DBName: "fbool", BindNames: []string{"BOOL"}, DataType: schema.Bool, Creatable: true, Updatable: true, Readable: true, Tag: `gorm:"column:fbool"`},
{Name: "STRING", DBName: "fstring", BindNames: []string{"STRING"}, DataType: schema.String, Creatable: true, Updatable: true, Readable: true, Tag: `gorm:"column:fstring"`},
{Name: "TIME", DBName: "ftime", BindNames: []string{"TIME"}, DataType: schema.Time, Creatable: true, Updatable: true, Readable: true, Tag: `gorm:"column:ftime"`},
{Name: "BYTES", DBName: "fbytes", BindNames: []string{"BYTES"}, DataType: schema.Bytes, Creatable: true, Updatable: true, Readable: true, Tag: `gorm:"column:fbytes"`},
}
for _, f := range fields {
checkSchemaField(t, alias, f, func(f *schema.Field) {})
} }
} }

View File

@ -1,7 +1,6 @@
package schema package schema
import ( import (
"fmt"
"sort" "sort"
"strconv" "strconv"
"strings" "strings"
@ -13,8 +12,8 @@ type Index struct {
Type string // btree, hash, gist, spgist, gin, and brin Type string // btree, hash, gist, spgist, gin, and brin
Where string Where string
Comment string Comment string
Option string // WITH PARSER parser_name Option string // WITH PARSER parser_name
Fields []IndexOption // Note: IndexOption's Field maybe the same Fields []IndexOption
} }
type IndexOption struct { type IndexOption struct {
@ -23,28 +22,17 @@ type IndexOption struct {
Sort string // DESC, ASC Sort string // DESC, ASC
Collate string Collate string
Length int Length int
Priority int priority int
} }
// ParseIndexes parse schema indexes // ParseIndexes parse schema indexes
func (schema *Schema) ParseIndexes() []*Index { func (schema *Schema) ParseIndexes() map[string]Index {
indexesByName := map[string]*Index{} var indexes = map[string]Index{}
indexes := []*Index{}
for _, field := range schema.Fields { for _, field := range schema.Fields {
if field.TagSettings["INDEX"] != "" || field.TagSettings["UNIQUEINDEX"] != "" { if field.TagSettings["INDEX"] != "" || field.TagSettings["UNIQUEINDEX"] != "" {
fieldIndexes, err := parseFieldIndexes(field) for _, index := range parseFieldIndexes(field) {
if err != nil { idx := indexes[index.Name]
schema.err = err
break
}
for _, index := range fieldIndexes {
idx := indexesByName[index.Name]
if idx == nil {
idx = &Index{Name: index.Name}
indexesByName[index.Name] = idx
indexes = append(indexes, idx)
}
idx.Name = index.Name idx.Name = index.Name
if idx.Class == "" { if idx.Class == "" {
idx.Class = index.Class idx.Class = index.Class
@ -64,16 +52,14 @@ func (schema *Schema) ParseIndexes() []*Index {
idx.Fields = append(idx.Fields, index.Fields...) idx.Fields = append(idx.Fields, index.Fields...)
sort.Slice(idx.Fields, func(i, j int) bool { sort.Slice(idx.Fields, func(i, j int) bool {
return idx.Fields[i].Priority < idx.Fields[j].Priority return idx.Fields[i].priority < idx.Fields[j].priority
}) })
indexes[index.Name] = idx
} }
} }
} }
for _, index := range indexes {
if index.Class == "UNIQUE" && len(index.Fields) == 1 {
index.Fields[0].Field.UniqueIndex = index.Name
}
}
return indexes return indexes
} }
@ -82,12 +68,12 @@ func (schema *Schema) LookIndex(name string) *Index {
indexes := schema.ParseIndexes() indexes := schema.ParseIndexes()
for _, index := range indexes { for _, index := range indexes {
if index.Name == name { if index.Name == name {
return index return &index
} }
for _, field := range index.Fields { for _, field := range index.Fields {
if field.Name == name { if field.Name == name {
return index return &index
} }
} }
} }
@ -96,41 +82,30 @@ func (schema *Schema) LookIndex(name string) *Index {
return nil return nil
} }
func parseFieldIndexes(field *Field) (indexes []Index, err error) { func parseFieldIndexes(field *Field) (indexes []Index) {
for _, value := range strings.Split(field.Tag.Get("gorm"), ";") { for _, value := range strings.Split(field.Tag.Get("gorm"), ";") {
if value != "" { if value != "" {
v := strings.Split(value, ":") v := strings.Split(value, ":")
k := strings.TrimSpace(strings.ToUpper(v[0])) k := strings.TrimSpace(strings.ToUpper(v[0]))
if k == "INDEX" || k == "UNIQUEINDEX" { if k == "INDEX" || k == "UNIQUEINDEX" {
var ( var (
name string name string
tag = strings.Join(v[1:], ":") tag = strings.Join(v[1:], ":")
idx = strings.IndexByte(tag, ',') idx = strings.Index(tag, ",")
tagSetting = strings.Join(strings.Split(tag, ",")[1:], ",") settings = ParseTagSetting(tag, ",")
settings = ParseTagSetting(tagSetting, ",") length, _ = strconv.Atoi(settings["LENGTH"])
length, _ = strconv.Atoi(settings["LENGTH"])
) )
if idx == -1 { if idx == -1 {
idx = len(tag) idx = len(tag)
} }
name = tag[0:idx] if idx != -1 {
name = tag[0:idx]
}
if name == "" { if name == "" {
subName := field.Name name = field.Schema.namer.IndexName(field.Schema.Table, field.Name)
const key = "COMPOSITE"
if composite, found := settings[key]; found {
if len(composite) == 0 || composite == key {
err = fmt.Errorf(
"the composite tag of %s.%s cannot be empty",
field.Schema.Name,
field.Name)
return
}
subName = composite
}
name = field.Schema.namer.IndexName(
field.Schema.Table, subName)
} }
if (k == "UNIQUEINDEX") || settings["UNIQUE"] != "" { if (k == "UNIQUEINDEX") || settings["UNIQUE"] != "" {
@ -155,13 +130,12 @@ func parseFieldIndexes(field *Field) (indexes []Index, err error) {
Sort: settings["SORT"], Sort: settings["SORT"],
Collate: settings["COLLATE"], Collate: settings["COLLATE"],
Length: length, Length: length,
Priority: priority, priority: priority,
}}, }},
}) })
} }
} }
} }
err = nil
return return
} }
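The signature of ParseIndexes differs between the two sides of this diff: master returns []*schema.Index while v1.21.16 returns map[string]schema.Index. The sketch below is hedged to compile against either shape by ranging with a blank key; the Member model is made up for the example.

package main

import (
	"fmt"
	"sync"

	"gorm.io/gorm/schema"
)

// Member is a stand-in model, not one of the test fixtures.
type Member struct {
	Name  string `gorm:"index"`
	Email string `gorm:"index:idx_member_email,unique"`
}

func main() {
	s, err := schema.Parse(&Member{}, &sync.Map{}, schema.NamingStrategy{})
	if err != nil {
		panic(err)
	}
	// master: []*schema.Index; v1.21.16: map[string]schema.Index.
	for _, idx := range s.ParseIndexes() {
		fmt.Println(idx.Name, idx.Class, len(idx.Fields))
	}
}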

View File

@ -1,11 +1,11 @@
package schema_test package schema_test
import ( import (
"reflect"
"sync" "sync"
"testing" "testing"
"gorm.io/gorm/schema" "gorm.io/gorm/schema"
"gorm.io/gorm/utils/tests"
) )
type UserIndex struct { type UserIndex struct {
@ -18,41 +18,6 @@ type UserIndex struct {
Age int64 `gorm:"index:profile,expression:ABS(age),option:WITH PARSER parser_name"` Age int64 `gorm:"index:profile,expression:ABS(age),option:WITH PARSER parser_name"`
OID int64 `gorm:"index:idx_id;index:idx_oid,unique"` OID int64 `gorm:"index:idx_id;index:idx_oid,unique"`
MemberNumber string `gorm:"index:idx_id,priority:1"` MemberNumber string `gorm:"index:idx_id,priority:1"`
Name7 string `gorm:"index:type"`
Name8 string `gorm:"index:,length:10;index:,collate:utf8"`
CompName1 string `gorm:"index:,unique,composite:idx_compname_1,option:NULLS NOT DISTINCT;not null"`
CompName2 string `gorm:"index:,composite:idx_compname_1"`
// Composite Index: Flattened structure.
Data0A string `gorm:"index:,composite:comp_id0"`
Data0B string `gorm:"index:,composite:comp_id0"`
// Composite Index: Nested structure.
Data1A string `gorm:"index:,composite:comp_id1"`
CompIdxLevel1C
// Composite Index: Unique and priority.
Data2A string `gorm:"index:,unique,composite:comp_id2,priority:2"`
CompIdxLevel2C
}
type CompIdxLevel1C struct {
CompIdxLevel1B
Data1C string `gorm:"index:,composite:comp_id1"`
}
type CompIdxLevel1B struct {
Data1B string `gorm:"index:,composite:comp_id1"`
}
type CompIdxLevel2C struct {
CompIdxLevel2B
Data2C string `gorm:"index:,unique,composite:comp_id2,priority:1"`
}
type CompIdxLevel2B struct {
Data2B string `gorm:"index:,unique,composite:comp_id2,priority:3"`
} }
func TestParseIndex(t *testing.T) { func TestParseIndex(t *testing.T) {
@ -61,17 +26,17 @@ func TestParseIndex(t *testing.T) {
t.Fatalf("failed to parse user index, got error %v", err) t.Fatalf("failed to parse user index, got error %v", err)
} }
results := []*schema.Index{ results := map[string]schema.Index{
{ "idx_user_indices_name": {
Name: "idx_user_indices_name", Name: "idx_user_indices_name",
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name"}}}, Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name"}}},
}, },
{ "idx_name": {
Name: "idx_name", Name: "idx_name",
Class: "UNIQUE", Class: "UNIQUE",
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name2", UniqueIndex: "idx_name"}}}, Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name2"}}},
}, },
{ "idx_user_indices_name3": {
Name: "idx_user_indices_name3", Name: "idx_user_indices_name3",
Type: "btree", Type: "btree",
Where: "name3 != 'jinzhu'", Where: "name3 != 'jinzhu'",
@ -82,19 +47,19 @@ func TestParseIndex(t *testing.T) {
Length: 10, Length: 10,
}}, }},
}, },
{ "idx_user_indices_name4": {
Name: "idx_user_indices_name4", Name: "idx_user_indices_name4",
Class: "UNIQUE", Class: "UNIQUE",
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name4", UniqueIndex: "idx_user_indices_name4"}}}, Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name4"}}},
}, },
{ "idx_user_indices_name5": {
Name: "idx_user_indices_name5", Name: "idx_user_indices_name5",
Class: "FULLTEXT", Class: "FULLTEXT",
Comment: "hello , world", Comment: "hello , world",
Where: "age > 10", Where: "age > 10",
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name5"}}}, Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name5"}}},
}, },
{ "profile": {
Name: "profile", Name: "profile",
Comment: "hello , world", Comment: "hello , world",
Where: "age > 10", Where: "age > 10",
@ -104,172 +69,48 @@ func TestParseIndex(t *testing.T) {
Expression: "ABS(age)", Expression: "ABS(age)",
}}, }},
}, },
{ "idx_id": {
Name: "idx_id", Name: "idx_id",
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "MemberNumber"}}, {Field: &schema.Field{Name: "OID", UniqueIndex: "idx_oid"}}}, Fields: []schema.IndexOption{{Field: &schema.Field{Name: "MemberNumber"}}, {Field: &schema.Field{Name: "OID"}}},
}, },
{ "idx_oid": {
Name: "idx_oid", Name: "idx_oid",
Class: "UNIQUE", Class: "UNIQUE",
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "OID", UniqueIndex: "idx_oid"}}}, Fields: []schema.IndexOption{{Field: &schema.Field{Name: "OID"}}},
},
{
Name: "type",
Type: "",
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name7"}}},
},
{
Name: "idx_user_indices_name8",
Type: "",
Fields: []schema.IndexOption{
{Field: &schema.Field{Name: "Name8"}, Length: 10},
// Note: Duplicate Columns
{Field: &schema.Field{Name: "Name8"}, Collate: "utf8"},
},
},
{
Class: "UNIQUE",
Name: "idx_user_indices_idx_compname_1",
Option: "NULLS NOT DISTINCT",
Fields: []schema.IndexOption{
{Field: &schema.Field{Name: "CompName1", NotNull: true}},
{Field: &schema.Field{Name: "CompName2"}},
},
},
{
Name: "idx_user_indices_comp_id0",
Type: "",
Fields: []schema.IndexOption{{
Field: &schema.Field{Name: "Data0A"},
}, {
Field: &schema.Field{Name: "Data0B"},
}},
},
{
Name: "idx_user_indices_comp_id1",
Fields: []schema.IndexOption{{
Field: &schema.Field{Name: "Data1A"},
}, {
Field: &schema.Field{Name: "Data1B"},
}, {
Field: &schema.Field{Name: "Data1C"},
}},
},
{
Name: "idx_user_indices_comp_id2",
Class: "UNIQUE",
Fields: []schema.IndexOption{{
Field: &schema.Field{Name: "Data2C"},
}, {
Field: &schema.Field{Name: "Data2A"},
}, {
Field: &schema.Field{Name: "Data2B"},
}},
}, },
} }
CheckIndices(t, results, user.ParseIndexes()) indices := user.ParseIndexes()
}
func TestParseIndexWithUniqueIndexAndUnique(t *testing.T) { for k, result := range results {
type IndexTest struct { v, ok := indices[k]
FieldA string `gorm:"unique;index"` // unique and index if !ok {
FieldB string `gorm:"unique"` // unique t.Fatalf("Failed to found index %v from parsed indices %+v", k, indices)
}
FieldC string `gorm:"index:,unique"` // uniqueIndex for _, name := range []string{"Name", "Class", "Type", "Where", "Comment", "Option"} {
FieldD string `gorm:"uniqueIndex;index"` // uniqueIndex and index if reflect.ValueOf(result).FieldByName(name).Interface() != reflect.ValueOf(v).FieldByName(name).Interface() {
t.Errorf(
FieldE1 string `gorm:"uniqueIndex:uniq_field_e1_e2"` // mul uniqueIndex "index %v %v should equal, expects %v, got %v",
FieldE2 string `gorm:"uniqueIndex:uniq_field_e1_e2"` k, name, reflect.ValueOf(result).FieldByName(name).Interface(), reflect.ValueOf(v).FieldByName(name).Interface(),
)
FieldF1 string `gorm:"uniqueIndex:uniq_field_f1_f2;index"` // mul uniqueIndex and index
FieldF2 string `gorm:"uniqueIndex:uniq_field_f1_f2;"`
FieldG string `gorm:"unique;uniqueIndex"` // unique and uniqueIndex
FieldH1 string `gorm:"unique;uniqueIndex:uniq_field_h1_h2"` // unique and mul uniqueIndex
FieldH2 string `gorm:"uniqueIndex:uniq_field_h1_h2"` // unique and mul uniqueIndex
}
indexSchema, err := schema.Parse(&IndexTest{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("failed to parse user index, got error %v", err)
}
indices := indexSchema.ParseIndexes()
expectedIndices := []*schema.Index{
{
Name: "idx_index_tests_field_a",
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "FieldA", Unique: true}}},
},
{
Name: "idx_index_tests_field_c",
Class: "UNIQUE",
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "FieldC", UniqueIndex: "idx_index_tests_field_c"}}},
},
{
Name: "idx_index_tests_field_d",
Class: "UNIQUE",
Fields: []schema.IndexOption{
{Field: &schema.Field{Name: "FieldD"}},
// Note: Duplicate Columns
{Field: &schema.Field{Name: "FieldD"}},
},
},
{
Name: "uniq_field_e1_e2",
Class: "UNIQUE",
Fields: []schema.IndexOption{
{Field: &schema.Field{Name: "FieldE1"}},
{Field: &schema.Field{Name: "FieldE2"}},
},
},
{
Name: "uniq_field_f1_f2",
Class: "UNIQUE",
Fields: []schema.IndexOption{
{Field: &schema.Field{Name: "FieldF1"}},
{Field: &schema.Field{Name: "FieldF2"}},
},
},
{
Name: "idx_index_tests_field_f1",
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "FieldF1"}}},
},
{
Name: "idx_index_tests_field_g",
Class: "UNIQUE",
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "FieldG", Unique: true, UniqueIndex: "idx_index_tests_field_g"}}},
},
{
Name: "uniq_field_h1_h2",
Class: "UNIQUE",
Fields: []schema.IndexOption{
{Field: &schema.Field{Name: "FieldH1", Unique: true}},
{Field: &schema.Field{Name: "FieldH2"}},
},
},
}
CheckIndices(t, expectedIndices, indices)
}
func CheckIndices(t *testing.T, expected, actual []*schema.Index) {
if len(expected) != len(actual) {
t.Errorf("expected %d indices, but got %d", len(expected), len(actual))
return
}
for i, ei := range expected {
t.Run(ei.Name, func(t *testing.T) {
ai := actual[i]
tests.AssertObjEqual(t, ai, ei, "Name", "Class", "Type", "Where", "Comment", "Option")
if len(ei.Fields) != len(ai.Fields) {
t.Errorf("expected index %q field length is %d but actual %d", ei.Name, len(ei.Fields), len(ai.Fields))
return
} }
for i, ef := range ei.Fields { }
af := ai.Fields[i]
tests.AssertObjEqual(t, af, ef, "Name", "Unique", "UniqueIndex", "Expression", "Sort", "Collate", "Length", "NotNull") for idx, ef := range result.Fields {
rf := v.Fields[idx]
if rf.Field.Name != ef.Field.Name {
t.Fatalf("index field should equal, expects %v, got %v", rf.Field.Name, ef.Field.Name)
} }
})
for _, name := range []string{"Expression", "Sort", "Collate", "Length"} {
if reflect.ValueOf(ef).FieldByName(name).Interface() != reflect.ValueOf(rf).FieldByName(name).Interface() {
t.Errorf(
"index %v field #%v's %v should equal, expects %v, got %v", k, idx+1, name,
reflect.ValueOf(ef).FieldByName(name).Interface(), reflect.ValueOf(rf).FieldByName(name).Interface(),
)
}
}
}
} }
} }
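TestParseIndexWithUniqueIndexAndUnique above (master side only) separates the `unique` column setting from an `index:,unique` declaration. A small hedged sketch of that difference as it surfaces through the parsed schema; Account is a made-up model, and the Field.Unique/UniqueIndex split is asserted by the master-side expectations, so the output may differ on v1.21.16.

package main

import (
	"fmt"
	"sync"

	"gorm.io/gorm/schema"
)

// Account is invented for the example: Email gets a column-level unique,
// Token a unique index.
type Account struct {
	Email string `gorm:"unique"`
	Token string `gorm:"index:,unique"`
}

func main() {
	s, err := schema.Parse(&Account{}, &sync.Map{}, schema.NamingStrategy{})
	if err != nil {
		panic(err)
	}
	fmt.Println(s.LookUpField("Email").Unique) // true on the master side
	// Only the index tag produces an entry here; `unique` alone does not.
	for _, idx := range s.ParseIndexes() {
		fmt.Println(idx.Name, idx.Class)
	}
}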

View File

@ -4,39 +4,22 @@ import (
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
) )
// ConstraintInterface database constraint interface
type ConstraintInterface interface {
GetName() string
Build() (sql string, vars []interface{})
}
// GormDataTypeInterface gorm data type interface
type GormDataTypeInterface interface { type GormDataTypeInterface interface {
GormDataType() string GormDataType() string
} }
// FieldNewValuePool field new scan value pool
type FieldNewValuePool interface {
Get() interface{}
Put(interface{})
}
// CreateClausesInterface create clauses interface
type CreateClausesInterface interface { type CreateClausesInterface interface {
CreateClauses(*Field) []clause.Interface CreateClauses(*Field) []clause.Interface
} }
// QueryClausesInterface query clauses interface
type QueryClausesInterface interface { type QueryClausesInterface interface {
QueryClauses(*Field) []clause.Interface QueryClauses(*Field) []clause.Interface
} }
// UpdateClausesInterface update clauses interface
type UpdateClausesInterface interface { type UpdateClausesInterface interface {
UpdateClauses(*Field) []clause.Interface UpdateClauses(*Field) []clause.Interface
} }
// DeleteClausesInterface delete clauses interface
type DeleteClausesInterface interface { type DeleteClausesInterface interface {
DeleteClauses(*Field) []clause.Interface DeleteClauses(*Field) []clause.Interface
} }
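These interfaces are the hooks the schema parser looks for on field types. As a hedged illustration of the one kept on both sides, a custom type can steer its DataType by implementing GormDataTypeInterface; the JSON and Document types below are invented for the example, and the printed value assumes the parser honors GormDataType as described here.

package main

import (
	"fmt"
	"sync"

	"gorm.io/gorm/schema"
)

// JSON is an invented custom column type.
type JSON []byte

// GormDataType satisfies schema.GormDataTypeInterface.
func (JSON) GormDataType() string { return "json" }

type Document struct {
	ID   uint
	Body JSON
}

func main() {
	var _ schema.GormDataTypeInterface = JSON{}

	s, err := schema.Parse(&Document{}, &sync.Map{}, schema.NamingStrategy{})
	if err != nil {
		panic(err)
	}
	// Expected to print the declared data type, "json".
	fmt.Println(s.LookUpField("Body").DataType)
}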

View File

@ -26,11 +26,9 @@ type User struct {
Active *bool Active *bool
} }
type ( type mytime time.Time
mytime time.Time type myint int
myint int type mybool = bool
mybool = bool
)
type AdvancedDataTypeUser struct { type AdvancedDataTypeUser struct {
ID sql.NullInt64 ID sql.NullInt64

View File

@ -3,13 +3,12 @@ package schema
import ( import (
"crypto/sha1" "crypto/sha1"
"encoding/hex" "encoding/hex"
"fmt"
"regexp" "regexp"
"strings" "strings"
"unicode/utf8" "unicode/utf8"
"github.com/jinzhu/inflection" "github.com/jinzhu/inflection"
"golang.org/x/text/cases"
"golang.org/x/text/language"
) )
// Namer namer interface // Namer namer interface
@ -21,7 +20,6 @@ type Namer interface {
RelationshipFKName(Relationship) string RelationshipFKName(Relationship) string
CheckerName(table, column string) string CheckerName(table, column string) string
IndexName(table, column string) string IndexName(table, column string) string
UniqueName(table, column string) string
} }
// Replacer replacer interface like strings.Replacer // Replacer replacer interface like strings.Replacer
@ -29,15 +27,12 @@ type Replacer interface {
Replace(name string) string Replace(name string) string
} }
var _ Namer = (*NamingStrategy)(nil)
// NamingStrategy tables, columns naming strategy // NamingStrategy tables, columns naming strategy
type NamingStrategy struct { type NamingStrategy struct {
TablePrefix string TablePrefix string
SingularTable bool SingularTable bool
NameReplacer Replacer NameReplacer Replacer
NoLowerCase bool NoLowerCase bool
IdentifierMaxLength int
} }
// TableName convert string to table name // TableName convert string to table name
@ -90,26 +85,17 @@ func (ns NamingStrategy) IndexName(table, column string) string {
return ns.formatName("idx", table, ns.toDBName(column)) return ns.formatName("idx", table, ns.toDBName(column))
} }
// UniqueName generate unique constraint name
func (ns NamingStrategy) UniqueName(table, column string) string {
return ns.formatName("uni", table, ns.toDBName(column))
}
func (ns NamingStrategy) formatName(prefix, table, name string) string { func (ns NamingStrategy) formatName(prefix, table, name string) string {
formattedName := strings.ReplaceAll(strings.Join([]string{ formattedName := strings.Replace(strings.Join([]string{
prefix, table, name, prefix, table, name,
}, "_"), ".", "_") }, "_"), ".", "_", -1)
if ns.IdentifierMaxLength == 0 { if utf8.RuneCountInString(formattedName) > 64 {
ns.IdentifierMaxLength = 64
}
if utf8.RuneCountInString(formattedName) > ns.IdentifierMaxLength {
h := sha1.New() h := sha1.New()
h.Write([]byte(formattedName)) h.Write([]byte(formattedName))
bs := h.Sum(nil) bs := h.Sum(nil)
formattedName = formattedName[0:ns.IdentifierMaxLength-8] + hex.EncodeToString(bs)[:8] formattedName = fmt.Sprintf("%v%v%v", prefix, table, name)[0:56] + hex.EncodeToString(bs)[:8]
} }
return formattedName return formattedName
} }
@ -123,7 +109,7 @@ var (
func init() { func init() {
commonInitialismsForReplacer := make([]string, 0, len(commonInitialisms)) commonInitialismsForReplacer := make([]string, 0, len(commonInitialisms))
for _, initialism := range commonInitialisms { for _, initialism := range commonInitialisms {
commonInitialismsForReplacer = append(commonInitialismsForReplacer, initialism, cases.Title(language.Und).String(initialism)) commonInitialismsForReplacer = append(commonInitialismsForReplacer, initialism, strings.Title(strings.ToLower(initialism)))
} }
commonInitialismsReplacer = strings.NewReplacer(commonInitialismsForReplacer...) commonInitialismsReplacer = strings.NewReplacer(commonInitialismsForReplacer...)
} }
@ -134,13 +120,7 @@ func (ns NamingStrategy) toDBName(name string) string {
} }
if ns.NameReplacer != nil { if ns.NameReplacer != nil {
tmpName := ns.NameReplacer.Replace(name) name = ns.NameReplacer.Replace(name)
if tmpName == "" {
return name
}
name = tmpName
} }
if ns.NoLowerCase { if ns.NoLowerCase {
@ -188,9 +168,9 @@ func (ns NamingStrategy) toDBName(name string) string {
} }
func (ns NamingStrategy) toSchemaName(name string) string { func (ns NamingStrategy) toSchemaName(name string) string {
result := strings.ReplaceAll(cases.Title(language.Und, cases.NoLower).String(strings.ReplaceAll(name, "_", " ")), " ", "") result := strings.Replace(strings.Title(strings.Replace(name, "_", " ", -1)), " ", "", -1)
for _, initialism := range commonInitialisms { for _, initialism := range commonInitialisms {
result = regexp.MustCompile(cases.Title(language.Und, cases.NoLower).String(strings.ToLower(initialism))+"([A-Z]|$|_)").ReplaceAllString(result, initialism+"$1") result = regexp.MustCompile(strings.Title(strings.ToLower(initialism))+"([A-Z]|$|_)").ReplaceAllString(result, initialism+"$1")
} }
return result return result
} }
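The formatName change above is where long generated identifiers get truncated and suffixed with an 8-character SHA-1 fragment; on the master side the threshold is the configurable IdentifierMaxLength (defaulting to 64), on v1.21.16 it is hard-coded to 64. A hedged sketch of how that surfaces through the public namer methods; note that IdentifierMaxLength only compiles against the master side.

package main

import (
	"fmt"

	"gorm.io/gorm/schema"
)

func main() {
	// IdentifierMaxLength exists only on the master side of this diff.
	ns := schema.NamingStrategy{TablePrefix: "public.", SingularTable: true, IdentifierMaxLength: 63}

	fmt.Println(ns.TableName("User"))                // public.user
	fmt.Println(ns.ColumnName("users", "CreatedAt")) // created_at
	// Anything longer than the limit comes back truncated plus an 8-char hash suffix.
	fmt.Println(ns.IndexName("users", "thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVeryLongColumnName"))
}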

View File

@ -6,7 +6,7 @@ import (
) )
func TestToDBName(t *testing.T) { func TestToDBName(t *testing.T) {
maps := map[string]string{ var maps = map[string]string{
"": "", "": "",
"x": "x", "x": "x",
"X": "x", "X": "x",
@ -56,7 +56,7 @@ func TestToDBName(t *testing.T) {
} }
func TestNamingStrategy(t *testing.T) { func TestNamingStrategy(t *testing.T) {
ns := NamingStrategy{ var ns = NamingStrategy{
TablePrefix: "public.", TablePrefix: "public.",
SingularTable: true, SingularTable: true,
NameReplacer: strings.NewReplacer("CID", "Cid"), NameReplacer: strings.NewReplacer("CID", "Cid"),
@ -102,7 +102,7 @@ func (r CustomReplacer) Replace(name string) string {
} }
func TestCustomReplacer(t *testing.T) { func TestCustomReplacer(t *testing.T) {
ns := NamingStrategy{ var ns = NamingStrategy{
TablePrefix: "public.", TablePrefix: "public.",
SingularTable: true, SingularTable: true,
NameReplacer: CustomReplacer{ NameReplacer: CustomReplacer{
@ -146,7 +146,7 @@ func TestCustomReplacer(t *testing.T) {
} }
func TestCustomReplacerWithNoLowerCase(t *testing.T) { func TestCustomReplacerWithNoLowerCase(t *testing.T) {
ns := NamingStrategy{ var ns = NamingStrategy{
TablePrefix: "public.", TablePrefix: "public.",
SingularTable: true, SingularTable: true,
NameReplacer: CustomReplacer{ NameReplacer: CustomReplacer{
@ -189,31 +189,11 @@ func TestCustomReplacerWithNoLowerCase(t *testing.T) {
} }
} }
func TestFormatNameWithStringLongerThan63Characters(t *testing.T) {
ns := NamingStrategy{IdentifierMaxLength: 63}
formattedName := ns.formatName("prefix", "table", "thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVeryLongString")
if formattedName != "prefix_table_thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVer180f2c67" {
t.Errorf("invalid formatted name generated, got %v", formattedName)
}
}
func TestFormatNameWithStringLongerThan64Characters(t *testing.T) { func TestFormatNameWithStringLongerThan64Characters(t *testing.T) {
ns := NamingStrategy{IdentifierMaxLength: 64} var ns = NamingStrategy{}
formattedName := ns.formatName("prefix", "table", "thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVeryLongString") formattedName := ns.formatName("prefix", "table", "thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVeryLongString")
if formattedName != "prefix_table_thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVery180f2c67" { if formattedName != "prefixtablethisIsAVeryVeryVeryVeryVeryVeryVeryVeryVeryLo180f2c67" {
t.Errorf("invalid formatted name generated, got %v", formattedName) t.Errorf("invalid formatted name generated, got %v", formattedName)
} }
} }
func TestReplaceEmptyTableName(t *testing.T) {
ns := NamingStrategy{
SingularTable: true,
NameReplacer: strings.NewReplacer("Model", ""),
}
tableName := ns.TableName("Model")
if tableName != "Model" {
t.Errorf("invalid table name generated, got %v", tableName)
}
}
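TestReplaceEmptyTableName above exists only on the master side: when the NameReplacer strips a name down to nothing, master falls back to the original name instead of producing an empty identifier. A small hedged sketch of that behavior; on v1.21.16 the guard is absent, so the result there would differ.

package main

import (
	"fmt"
	"strings"

	"gorm.io/gorm/schema"
)

func main() {
	ns := schema.NamingStrategy{
		SingularTable: true,
		NameReplacer:  strings.NewReplacer("Model", ""),
	}
	// master keeps "Model" because the replacement result is empty.
	fmt.Println(ns.TableName("Model"))
}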

View File

@ -1,19 +0,0 @@
package schema
import (
"reflect"
"sync"
)
// sync pools
var (
normalPool sync.Map
poolInitializer = func(reflectType reflect.Type) FieldNewValuePool {
v, _ := normalPool.LoadOrStore(reflectType, &sync.Pool{
New: func() interface{} {
return reflect.New(reflectType).Interface()
},
})
return v.(FieldNewValuePool)
}
)
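pool.go exists only on the master side and keeps one sync.Pool of scan destinations per reflect.Type. The helper below is a standalone sketch of the same pattern with invented names; it is not the gorm implementation itself.

package main

import (
	"fmt"
	"reflect"
	"sync"
)

var pools sync.Map // reflect.Type -> *sync.Pool

// newValueFor hands out a freshly allocated or pooled *T for the given type.
func newValueFor(t reflect.Type) interface{} {
	p, _ := pools.LoadOrStore(t, &sync.Pool{
		New: func() interface{} { return reflect.New(t).Interface() },
	})
	return p.(*sync.Pool).Get()
}

func main() {
	v := newValueFor(reflect.TypeOf(int64(0)))
	fmt.Printf("%T\n", v) // *int64
	// In real use the caller would Put the value back after scanning into it.
}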

View File

@ -1,16 +1,11 @@
package schema package schema
import ( import (
"context"
"fmt" "fmt"
"reflect" "reflect"
"strings" "strings"
"sync"
"github.com/jinzhu/inflection" "github.com/jinzhu/inflection"
"golang.org/x/text/cases"
"golang.org/x/text/language"
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
) )
@ -31,10 +26,6 @@ type Relationships struct {
HasMany []*Relationship HasMany []*Relationship
Many2Many []*Relationship Many2Many []*Relationship
Relations map[string]*Relationship Relations map[string]*Relationship
EmbeddedRelations map[string]*Relationships
Mux sync.RWMutex
} }
type Relationship struct { type Relationship struct {
@ -78,12 +69,12 @@ func (schema *Schema) parseRelation(field *Field) *Relationship {
cacheStore := schema.cacheStore cacheStore := schema.cacheStore
if relation.FieldSchema, err = getOrParse(fieldValue, cacheStore, schema.namer); err != nil { if relation.FieldSchema, err = getOrParse(fieldValue, cacheStore, schema.namer); err != nil {
schema.err = fmt.Errorf("failed to parse field: %s, error: %w", field.Name, err) schema.err = err
return nil return nil
} }
if hasPolymorphicRelation(field.TagSettings) { if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" {
schema.buildPolymorphicRelation(relation, field) schema.buildPolymorphicRelation(relation, field, polymorphic)
} else if many2many := field.TagSettings["MANY2MANY"]; many2many != "" { } else if many2many := field.TagSettings["MANY2MANY"]; many2many != "" {
schema.buildMany2ManyRelation(relation, field, many2many) schema.buildMany2ManyRelation(relation, field, many2many)
} else if belongsTo := field.TagSettings["BELONGSTO"]; belongsTo != "" { } else if belongsTo := field.TagSettings["BELONGSTO"]; belongsTo != "" {
@ -95,16 +86,14 @@ func (schema *Schema) parseRelation(field *Field) *Relationship {
case reflect.Slice: case reflect.Slice:
schema.guessRelation(relation, field, guessHas) schema.guessRelation(relation, field, guessHas)
default: default:
schema.err = fmt.Errorf("unsupported data type %v for %v on field %s", relation.FieldSchema, schema, schema.err = fmt.Errorf("unsupported data type %v for %v on field %s", relation.FieldSchema, schema, field.Name)
field.Name)
} }
} }
if relation.Type == has { if relation.Type == has {
// don't add relations to embedded schema, which might be shared
if relation.FieldSchema != relation.Schema && relation.Polymorphic == nil && field.OwnerSchema == nil { if relation.FieldSchema != relation.Schema && relation.Polymorphic == nil && field.OwnerSchema == nil {
relation.FieldSchema.Relationships.Mux.Lock()
relation.FieldSchema.Relationships.Relations["_"+relation.Schema.Name+"_"+relation.Name] = relation relation.FieldSchema.Relationships.Relations["_"+relation.Schema.Name+"_"+relation.Name] = relation
relation.FieldSchema.Relationships.Mux.Unlock()
} }
switch field.IndirectFieldType.Kind() { switch field.IndirectFieldType.Kind() {
@ -116,7 +105,7 @@ func (schema *Schema) parseRelation(field *Field) *Relationship {
} }
if schema.err == nil { if schema.err == nil {
schema.setRelation(relation) schema.Relationships.Relations[relation.Name] = relation
switch relation.Type { switch relation.Type {
case HasOne: case HasOne:
schema.Relationships.HasOne = append(schema.Relationships.HasOne, relation) schema.Relationships.HasOne = append(schema.Relationships.HasOne, relation)
@ -132,100 +121,34 @@ func (schema *Schema) parseRelation(field *Field) *Relationship {
return relation return relation
} }
// hasPolymorphicRelation checks whether the field has a polymorphic relation, declared by either:
// 1. `POLYMORPHIC` tag
// 2. `POLYMORPHICTYPE` and `POLYMORPHICID` tag
func hasPolymorphicRelation(tagSettings map[string]string) bool {
if _, ok := tagSettings["POLYMORPHIC"]; ok {
return true
}
_, hasType := tagSettings["POLYMORPHICTYPE"]
_, hasId := tagSettings["POLYMORPHICID"]
return hasType && hasId
}
func (schema *Schema) setRelation(relation *Relationship) {
// set non-embedded relation
if rel := schema.Relationships.Relations[relation.Name]; rel != nil {
if len(rel.Field.BindNames) > 1 {
schema.Relationships.Relations[relation.Name] = relation
}
} else {
schema.Relationships.Relations[relation.Name] = relation
}
// set embedded relation
if len(relation.Field.EmbeddedBindNames) <= 1 {
return
}
relationships := &schema.Relationships
for i, name := range relation.Field.EmbeddedBindNames {
if i < len(relation.Field.EmbeddedBindNames)-1 {
if relationships.EmbeddedRelations == nil {
relationships.EmbeddedRelations = map[string]*Relationships{}
}
if r := relationships.EmbeddedRelations[name]; r == nil {
relationships.EmbeddedRelations[name] = &Relationships{}
}
relationships = relationships.EmbeddedRelations[name]
} else {
if relationships.Relations == nil {
relationships.Relations = map[string]*Relationship{}
}
relationships.Relations[relation.Name] = relation
}
}
}
// User has many Toys, its `Polymorphic` is `Owner`, Pet has one Toy, its `Polymorphic` is `Owner` // User has many Toys, its `Polymorphic` is `Owner`, Pet has one Toy, its `Polymorphic` is `Owner`
// // type User struct {
// type User struct { // Toys []Toy `gorm:"polymorphic:Owner;"`
// Toys []Toy `gorm:"polymorphic:Owner;"` // }
// } // type Pet struct {
// type Pet struct { // Toy Toy `gorm:"polymorphic:Owner;"`
// Toy Toy `gorm:"polymorphic:Owner;"` // }
// } // type Toy struct {
// type Toy struct { // OwnerID int
// OwnerID int // OwnerType string
// OwnerType string // }
// } func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Field, polymorphic string) {
func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Field) {
polymorphic := field.TagSettings["POLYMORPHIC"]
relation.Polymorphic = &Polymorphic{ relation.Polymorphic = &Polymorphic{
Value: schema.Table, Value: schema.Table,
PolymorphicType: relation.FieldSchema.FieldsByName[polymorphic+"Type"],
PolymorphicID: relation.FieldSchema.FieldsByName[polymorphic+"ID"],
} }
var (
typeName = polymorphic + "Type"
typeId = polymorphic + "ID"
)
if value, ok := field.TagSettings["POLYMORPHICTYPE"]; ok {
typeName = strings.TrimSpace(value)
}
if value, ok := field.TagSettings["POLYMORPHICID"]; ok {
typeId = strings.TrimSpace(value)
}
relation.Polymorphic.PolymorphicType = relation.FieldSchema.FieldsByName[typeName]
relation.Polymorphic.PolymorphicID = relation.FieldSchema.FieldsByName[typeId]
if value, ok := field.TagSettings["POLYMORPHICVALUE"]; ok { if value, ok := field.TagSettings["POLYMORPHICVALUE"]; ok {
relation.Polymorphic.Value = strings.TrimSpace(value) relation.Polymorphic.Value = strings.TrimSpace(value)
} }
if relation.Polymorphic.PolymorphicType == nil { if relation.Polymorphic.PolymorphicType == nil {
schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s", schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s", relation.FieldSchema, schema, field.Name, polymorphic+"Type")
relation.FieldSchema, schema, field.Name, polymorphic+"Type")
} }
if relation.Polymorphic.PolymorphicID == nil { if relation.Polymorphic.PolymorphicID == nil {
schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s", schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s", relation.FieldSchema, schema, field.Name, polymorphic+"ID")
relation.FieldSchema, schema, field.Name, polymorphic+"ID")
} }
if schema.err == nil { if schema.err == nil {
@ -237,17 +160,10 @@ func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Fi
primaryKeyField := schema.PrioritizedPrimaryField primaryKeyField := schema.PrioritizedPrimaryField
if len(relation.foreignKeys) > 0 { if len(relation.foreignKeys) > 0 {
if primaryKeyField = schema.LookUpField(relation.foreignKeys[0]); primaryKeyField == nil || len(relation.foreignKeys) > 1 { if primaryKeyField = schema.LookUpField(relation.foreignKeys[0]); primaryKeyField == nil || len(relation.foreignKeys) > 1 {
schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %s", relation.foreignKeys, schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %s", relation.foreignKeys, schema, field.Name)
schema, field.Name)
} }
} }
if primaryKeyField == nil {
schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing primaryKey field",
relation.FieldSchema, schema, field.Name)
return
}
// use same data type for foreign keys // use same data type for foreign keys
if copyableDataType(primaryKeyField.DataType) { if copyableDataType(primaryKeyField.DataType) {
relation.Polymorphic.PolymorphicID.DataType = primaryKeyField.DataType relation.Polymorphic.PolymorphicID.DataType = primaryKeyField.DataType
@ -274,8 +190,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
err error err error
joinTableFields []reflect.StructField joinTableFields []reflect.StructField
fieldsMap = map[string]*Field{} fieldsMap = map[string]*Field{}
ownFieldsMap = map[string]*Field{} // fix self join many2many ownFieldsMap = map[string]bool{} // fix self join many2many
referFieldsMap = map[string]*Field{}
joinForeignKeys = toColumns(field.TagSettings["JOINFOREIGNKEY"]) joinForeignKeys = toColumns(field.TagSettings["JOINFOREIGNKEY"])
joinReferences = toColumns(field.TagSettings["JOINREFERENCES"]) joinReferences = toColumns(field.TagSettings["JOINREFERENCES"])
) )
@ -308,24 +223,26 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
} }
for idx, ownField := range ownForeignFields { for idx, ownField := range ownForeignFields {
joinFieldName := cases.Title(language.Und, cases.NoLower).String(schema.Name) + ownField.Name joinFieldName := strings.Title(schema.Name) + ownField.Name
if len(joinForeignKeys) > idx { if len(joinForeignKeys) > idx {
joinFieldName = cases.Title(language.Und, cases.NoLower).String(joinForeignKeys[idx]) joinFieldName = strings.Title(joinForeignKeys[idx])
} }
ownFieldsMap[joinFieldName] = ownField ownFieldsMap[joinFieldName] = true
fieldsMap[joinFieldName] = ownField fieldsMap[joinFieldName] = ownField
joinTableFields = append(joinTableFields, reflect.StructField{ joinTableFields = append(joinTableFields, reflect.StructField{
Name: joinFieldName, Name: joinFieldName,
PkgPath: ownField.StructField.PkgPath, PkgPath: ownField.StructField.PkgPath,
Type: ownField.StructField.Type, Type: ownField.StructField.Type,
Tag: removeSettingFromTag(appendSettingFromTag(ownField.StructField.Tag, "primaryKey"), Tag: removeSettingFromTag(ownField.StructField.Tag, "column", "autoincrement", "index", "unique", "uniqueindex"),
"column", "autoincrement", "index", "unique", "uniqueindex"),
}) })
} }
for idx, relField := range refForeignFields { for idx, relField := range refForeignFields {
joinFieldName := cases.Title(language.Und, cases.NoLower).String(relation.FieldSchema.Name) + relField.Name joinFieldName := strings.Title(relation.FieldSchema.Name) + relField.Name
if len(joinReferences) > idx {
joinFieldName = strings.Title(joinReferences[idx])
}
if _, ok := ownFieldsMap[joinFieldName]; ok { if _, ok := ownFieldsMap[joinFieldName]; ok {
if field.Name != relation.FieldSchema.Name { if field.Name != relation.FieldSchema.Name {
@ -335,32 +252,22 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
} }
} }
if len(joinReferences) > idx { fieldsMap[joinFieldName] = relField
joinFieldName = cases.Title(language.Und, cases.NoLower).String(joinReferences[idx]) joinTableFields = append(joinTableFields, reflect.StructField{
} Name: joinFieldName,
PkgPath: relField.StructField.PkgPath,
referFieldsMap[joinFieldName] = relField Type: relField.StructField.Type,
Tag: removeSettingFromTag(relField.StructField.Tag, "column", "autoincrement", "index", "unique", "uniqueindex"),
if _, ok := fieldsMap[joinFieldName]; !ok { })
fieldsMap[joinFieldName] = relField
joinTableFields = append(joinTableFields, reflect.StructField{
Name: joinFieldName,
PkgPath: relField.StructField.PkgPath,
Type: relField.StructField.Type,
Tag: removeSettingFromTag(appendSettingFromTag(relField.StructField.Tag, "primaryKey"),
"column", "autoincrement", "index", "unique", "uniqueindex"),
})
}
} }
joinTableFields = append(joinTableFields, reflect.StructField{ joinTableFields = append(joinTableFields, reflect.StructField{
Name: cases.Title(language.Und, cases.NoLower).String(schema.Name) + field.Name, Name: strings.Title(schema.Name) + field.Name,
Type: schema.ModelType, Type: schema.ModelType,
Tag: `gorm:"-"`, Tag: `gorm:"-"`,
}) })
if relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, if relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, schema.namer); err != nil {
schema.namer); err != nil {
schema.err = err schema.err = err
} }
relation.JoinTable.Name = many2many relation.JoinTable.Name = many2many
@ -407,37 +314,31 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
f.Size = fieldsMap[f.Name].Size f.Size = fieldsMap[f.Name].Size
} }
relation.JoinTable.PrimaryFields = append(relation.JoinTable.PrimaryFields, f) relation.JoinTable.PrimaryFields = append(relation.JoinTable.PrimaryFields, f)
ownPrimaryField := schema == fieldsMap[f.Name].Schema && ownFieldsMap[f.Name]
if of, ok := ownFieldsMap[f.Name]; ok { if ownPrimaryField {
joinRel := relation.JoinTable.Relationships.Relations[relName] joinRel := relation.JoinTable.Relationships.Relations[relName]
joinRel.Field = relation.Field joinRel.Field = relation.Field
joinRel.References = append(joinRel.References, &Reference{ joinRel.References = append(joinRel.References, &Reference{
PrimaryKey: of, PrimaryKey: fieldsMap[f.Name],
ForeignKey: f, ForeignKey: f,
}) })
} else {
relation.References = append(relation.References, &Reference{
PrimaryKey: of,
ForeignKey: f,
OwnPrimaryKey: true,
})
}
if rf, ok := referFieldsMap[f.Name]; ok {
joinRefRel := relation.JoinTable.Relationships.Relations[relRefName] joinRefRel := relation.JoinTable.Relationships.Relations[relRefName]
if joinRefRel.Field == nil { if joinRefRel.Field == nil {
joinRefRel.Field = relation.Field joinRefRel.Field = relation.Field
} }
joinRefRel.References = append(joinRefRel.References, &Reference{ joinRefRel.References = append(joinRefRel.References, &Reference{
PrimaryKey: rf, PrimaryKey: fieldsMap[f.Name],
ForeignKey: f,
})
relation.References = append(relation.References, &Reference{
PrimaryKey: rf,
ForeignKey: f, ForeignKey: f,
}) })
} }
relation.References = append(relation.References, &Reference{
PrimaryKey: fieldsMap[f.Name],
ForeignKey: f,
OwnPrimaryKey: ownPrimaryField,
})
} }
} }
} }
@ -479,8 +380,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu
schema.guessRelation(relation, field, guessEmbeddedHas) schema.guessRelation(relation, field, guessEmbeddedHas)
// case guessEmbeddedHas: // case guessEmbeddedHas:
default: default:
schema.err = fmt.Errorf("invalid field found for struct %v's field %s: define a valid foreign key for relations or implement the Valuer/Scanner interface", schema.err = fmt.Errorf("invalid field found for struct %v's field %s: define a valid foreign key for relations or implement the Valuer/Scanner interface", schema, field.Name)
schema, field.Name)
} }
} }
@ -488,34 +388,33 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu
case guessBelongs: case guessBelongs:
primarySchema, foreignSchema = relation.FieldSchema, schema primarySchema, foreignSchema = relation.FieldSchema, schema
case guessEmbeddedBelongs: case guessEmbeddedBelongs:
if field.OwnerSchema == nil { if field.OwnerSchema != nil {
primarySchema, foreignSchema = relation.FieldSchema, field.OwnerSchema
} else {
reguessOrErr() reguessOrErr()
return return
} }
primarySchema, foreignSchema = relation.FieldSchema, field.OwnerSchema
case guessHas: case guessHas:
case guessEmbeddedHas: case guessEmbeddedHas:
if field.OwnerSchema == nil { if field.OwnerSchema != nil {
primarySchema, foreignSchema = field.OwnerSchema, relation.FieldSchema
} else {
reguessOrErr() reguessOrErr()
return return
} }
primarySchema, foreignSchema = field.OwnerSchema, relation.FieldSchema
} }
if len(relation.foreignKeys) > 0 { if len(relation.foreignKeys) > 0 {
for _, foreignKey := range relation.foreignKeys { for _, foreignKey := range relation.foreignKeys {
f := foreignSchema.LookUpField(foreignKey) if f := foreignSchema.LookUpField(foreignKey); f != nil {
if f == nil { foreignFields = append(foreignFields, f)
} else {
reguessOrErr() reguessOrErr()
return return
} }
foreignFields = append(foreignFields, f)
} }
} else { } else {
primarySchemaName := primarySchema.Name var primaryFields []*Field
if primarySchemaName == "" {
primarySchemaName = relation.FieldSchema.Name
}
if len(relation.primaryKeys) > 0 { if len(relation.primaryKeys) > 0 {
for _, primaryKey := range relation.primaryKeys { for _, primaryKey := range relation.primaryKeys {
@ -527,42 +426,31 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu
primaryFields = primarySchema.PrimaryFields primaryFields = primarySchema.PrimaryFields
} }
primaryFieldLoop:
for _, primaryField := range primaryFields { for _, primaryField := range primaryFields {
lookUpName := primarySchemaName + primaryField.Name lookUpName := primarySchema.Name + primaryField.Name
if gl == guessBelongs { if gl == guessBelongs {
lookUpName = field.Name + primaryField.Name lookUpName = field.Name + primaryField.Name
} }
lookUpNames := []string{lookUpName} lookUpNames := []string{lookUpName}
if len(primaryFields) == 1 { if len(primaryFields) == 1 {
lookUpNames = append(lookUpNames, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID", lookUpNames = append(lookUpNames, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID", strings.TrimSuffix(lookUpName, primaryField.Name)+"Id", schema.namer.ColumnName(foreignSchema.Table, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID"))
strings.TrimSuffix(lookUpName, primaryField.Name)+"Id", schema.namer.ColumnName(foreignSchema.Table,
strings.TrimSuffix(lookUpName, primaryField.Name)+"ID"))
} }
for _, name := range lookUpNames {
if f := foreignSchema.LookUpFieldByBindName(field.BindNames, name); f != nil {
foreignFields = append(foreignFields, f)
primaryFields = append(primaryFields, primaryField)
continue primaryFieldLoop
}
}
for _, name := range lookUpNames { for _, name := range lookUpNames {
if f := foreignSchema.LookUpField(name); f != nil { if f := foreignSchema.LookUpField(name); f != nil {
foreignFields = append(foreignFields, f) foreignFields = append(foreignFields, f)
primaryFields = append(primaryFields, primaryField) primaryFields = append(primaryFields, primaryField)
continue primaryFieldLoop break
} }
} }
} }
} }
switch { if len(foreignFields) == 0 {
case len(foreignFields) == 0:
reguessOrErr() reguessOrErr()
return return
case len(relation.primaryKeys) > 0: } else if len(relation.primaryKeys) > 0 {
for idx, primaryKey := range relation.primaryKeys { for idx, primaryKey := range relation.primaryKeys {
if f := primarySchema.LookUpField(primaryKey); f != nil { if f := primarySchema.LookUpField(primaryKey); f != nil {
if len(primaryFields) < idx+1 { if len(primaryFields) < idx+1 {
@ -576,7 +464,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu
return return
} }
} }
case len(primaryFields) == 0: } else if len(primaryFields) == 0 {
if len(foreignFields) == 1 && primarySchema.PrioritizedPrimaryField != nil { if len(foreignFields) == 1 && primarySchema.PrioritizedPrimaryField != nil {
primaryFields = append(primaryFields, primarySchema.PrioritizedPrimaryField) primaryFields = append(primaryFields, primarySchema.PrioritizedPrimaryField)
} else if len(primarySchema.PrimaryFields) == len(foreignFields) { } else if len(primarySchema.PrimaryFields) == len(foreignFields) {
@ -612,7 +500,6 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu
} }
} }
// Constraint is ForeignKey Constraint
type Constraint struct { type Constraint struct {
Name string Name string
Field *Field Field *Field
@ -624,31 +511,6 @@ type Constraint struct {
OnUpdate string OnUpdate string
} }
func (constraint *Constraint) GetName() string { return constraint.Name }
func (constraint *Constraint) Build() (sql string, vars []interface{}) {
sql = "CONSTRAINT ? FOREIGN KEY ? REFERENCES ??"
if constraint.OnDelete != "" {
sql += " ON DELETE " + constraint.OnDelete
}
if constraint.OnUpdate != "" {
sql += " ON UPDATE " + constraint.OnUpdate
}
foreignKeys := make([]interface{}, 0, len(constraint.ForeignKeys))
for _, field := range constraint.ForeignKeys {
foreignKeys = append(foreignKeys, clause.Column{Name: field.DBName})
}
references := make([]interface{}, 0, len(constraint.References))
for _, field := range constraint.References {
references = append(references, clause.Column{Name: field.DBName})
}
vars = append(vars, clause.Table{Name: constraint.Name}, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references)
return
}
func (rel *Relationship) ParseConstraint() *Constraint { func (rel *Relationship) ParseConstraint() *Constraint {
str := rel.Field.TagSettings["CONSTRAINT"] str := rel.Field.TagSettings["CONSTRAINT"]
if str == "-" { if str == "-" {
@ -657,13 +519,12 @@ func (rel *Relationship) ParseConstraint() *Constraint {
if rel.Type == BelongsTo { if rel.Type == BelongsTo {
for _, r := range rel.FieldSchema.Relationships.Relations { for _, r := range rel.FieldSchema.Relationships.Relations {
if r != rel && r.FieldSchema == rel.Schema && len(rel.References) == len(r.References) { if r.FieldSchema == rel.Schema && len(rel.References) == len(r.References) {
matched := true matched := true
for idx, ref := range r.References { for idx, ref := range r.References {
if !(rel.References[idx].PrimaryKey == ref.PrimaryKey && rel.References[idx].ForeignKey == ref.ForeignKey && if !(rel.References[idx].PrimaryKey == ref.PrimaryKey && rel.References[idx].ForeignKey == ref.ForeignKey &&
rel.References[idx].PrimaryValue == ref.PrimaryValue) { rel.References[idx].PrimaryValue == ref.PrimaryValue) {
matched = false matched = false
break
} }
} }
@ -676,7 +537,7 @@ func (rel *Relationship) ParseConstraint() *Constraint {
var ( var (
name string name string
idx = strings.IndexByte(str, ',') idx = strings.Index(str, ",")
settings = ParseTagSetting(str, ",") settings = ParseTagSetting(str, ",")
) )
@ -715,7 +576,7 @@ func (rel *Relationship) ParseConstraint() *Constraint {
return &constraint return &constraint
} }
func (rel *Relationship) ToQueryConditions(ctx context.Context, reflectValue reflect.Value) (conds []clause.Expression) { func (rel *Relationship) ToQueryConditions(reflectValue reflect.Value) (conds []clause.Expression) {
table := rel.FieldSchema.Table table := rel.FieldSchema.Table
foreignFields := []*Field{} foreignFields := []*Field{}
relForeignKeys := []string{} relForeignKeys := []string{}
@ -755,7 +616,7 @@ func (rel *Relationship) ToQueryConditions(ctx context.Context, reflectValue ref
} }
} }
_, foreignValues := GetIdentityFieldValuesMap(ctx, reflectValue, foreignFields) _, foreignValues := GetIdentityFieldValuesMap(reflectValue, foreignFields)
column, values := ToQueryValues(table, relForeignKeys, foreignValues) column, values := ToQueryValues(table, relForeignKeys, foreignValues)
conds = append(conds, clause.IN{Column: column, Values: values}) conds = append(conds, clause.IN{Column: column, Values: values})
@ -763,9 +624,8 @@ func (rel *Relationship) ToQueryConditions(ctx context.Context, reflectValue ref
} }
func copyableDataType(str DataType) bool { func copyableDataType(str DataType) bool {
lowerStr := strings.ToLower(string(str))
for _, s := range []string{"auto_increment", "primary key"} { for _, s := range []string{"auto_increment", "primary key"} {
if strings.Contains(lowerStr, s) { if strings.Contains(strings.ToLower(string(str)), s) {
return false return false
} }
} }
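Both sides of relationship.go end with ParseConstraint, which turns the `constraint` tag into a Constraint value (master additionally lets it render SQL via Build). A hedged sketch of reading it back from a parsed schema; Company and Employee are made-up models, and the printed constraint name is only a likely default.

package main

import (
	"fmt"
	"sync"

	"gorm.io/gorm"
	"gorm.io/gorm/schema"
)

type Company struct {
	gorm.Model
	Name string
}

type Employee struct {
	gorm.Model
	CompanyID uint
	Company   Company `gorm:"constraint:OnUpdate:CASCADE,OnDelete:SET NULL"`
}

func main() {
	s, err := schema.Parse(&Employee{}, &sync.Map{}, schema.NamingStrategy{})
	if err != nil {
		panic(err)
	}
	c := s.Relationships.Relations["Company"].ParseConstraint()
	// e.g. fk_employees_company CASCADE SET NULL
	fmt.Println(c.Name, c.OnUpdate, c.OnDelete)
}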

View File

@ -10,7 +10,7 @@ import (
func checkStructRelation(t *testing.T, data interface{}, relations ...Relation) { func checkStructRelation(t *testing.T, data interface{}, relations ...Relation) {
if s, err := schema.Parse(data, &sync.Map{}, schema.NamingStrategy{}); err != nil { if s, err := schema.Parse(data, &sync.Map{}, schema.NamingStrategy{}); err != nil {
t.Errorf("Failed to parse schema, got error %v", err) t.Errorf("Failed to parse schema")
} else { } else {
for _, rel := range relations { for _, rel := range relations {
checkSchemaRelation(t, s, rel) checkSchemaRelation(t, s, rel)
@ -93,20 +93,6 @@ func TestBelongsToWithOnlyReferences2(t *testing.T) {
}) })
} }
func TestSelfReferentialBelongsTo(t *testing.T) {
type User struct {
ID int32 `gorm:"primaryKey"`
Name string
CreatorID *int32
Creator *User
}
checkStructRelation(t, &User{}, Relation{
Name: "Creator", Type: schema.BelongsTo, Schema: "User", FieldSchema: "User",
References: []Reference{{"ID", "User", "CreatorID", "User", "", false}},
})
}
func TestSelfReferentialBelongsToOverrideReferences(t *testing.T) { func TestSelfReferentialBelongsToOverrideReferences(t *testing.T) {
type User struct { type User struct {
ID int32 `gorm:"primaryKey"` ID int32 `gorm:"primaryKey"`
@ -121,29 +107,6 @@ func TestSelfReferentialBelongsToOverrideReferences(t *testing.T) {
}) })
} }
func TestBelongsToWithMixin(t *testing.T) {
type Profile struct {
gorm.Model
Refer string
Name string
}
type ProfileMixin struct {
Profile Profile `gorm:"References:Refer"`
ProfileRefer int
}
type User struct {
gorm.Model
ProfileMixin
}
checkStructRelation(t, &User{}, Relation{
Name: "Profile", Type: schema.BelongsTo, Schema: "User", FieldSchema: "Profile",
References: []Reference{{"Refer", "Profile", "ProfileRefer", "User", "", false}},
})
}
func TestHasOneOverrideForeignKey(t *testing.T) { func TestHasOneOverrideForeignKey(t *testing.T) {
type Profile struct { type Profile struct {
gorm.Model gorm.Model
@ -182,6 +145,7 @@ func TestHasOneOverrideReferences(t *testing.T) {
} }
func TestHasOneOverrideReferences2(t *testing.T) { func TestHasOneOverrideReferences2(t *testing.T) {
type Profile struct { type Profile struct {
gorm.Model gorm.Model
Name string Name string
@ -328,33 +292,6 @@ func TestMany2ManyOverrideForeignKey(t *testing.T) {
}) })
} }
func TestMany2ManySharedForeignKey(t *testing.T) {
type Profile struct {
gorm.Model
Name string
Kind string
ProfileRefer uint
}
type User struct {
gorm.Model
Profiles []Profile `gorm:"many2many:user_profiles;foreignKey:Refer,Kind;joinForeignKey:UserRefer,Kind;References:ProfileRefer,Kind;joinReferences:ProfileR,Kind"`
Kind string
Refer uint
}
checkStructRelation(t, &User{}, Relation{
Name: "Profiles", Type: schema.Many2Many, Schema: "User", FieldSchema: "Profile",
JoinTable: JoinTable{Name: "user_profiles", Table: "user_profiles"},
References: []Reference{
{"Refer", "User", "UserRefer", "user_profiles", "", true},
{"Kind", "User", "Kind", "user_profiles", "", true},
{"ProfileRefer", "Profile", "ProfileR", "user_profiles", "", false},
{"Kind", "Profile", "Kind", "user_profiles", "", false},
},
})
}
func TestMany2ManyOverrideJoinForeignKey(t *testing.T) { func TestMany2ManyOverrideJoinForeignKey(t *testing.T) {
type Profile struct { type Profile struct {
gorm.Model gorm.Model
@ -541,340 +478,6 @@ func TestEmbeddedRelation(t *testing.T) {
} }
} }
func TestEmbeddedHas(t *testing.T) {
type Toy struct {
ID int
Name string
OwnerID int
OwnerType string
}
type User struct {
ID int
Cat struct {
Name string
Toy Toy `gorm:"polymorphic:Owner;"`
Toys []Toy `gorm:"polymorphic:Owner;"`
} `gorm:"embedded;embeddedPrefix:cat_"`
Dog struct {
ID int
Name string
UserID int
Toy Toy `gorm:"polymorphic:Owner;"`
Toys []Toy `gorm:"polymorphic:Owner;"`
}
Toys []Toy `gorm:"polymorphic:Owner;"`
}
s, err := schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("Failed to parse schema, got error %v", err)
}
checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
"Cat": {
Relations: map[string]Relation{
"Toy": {
Name: "Toy",
Type: schema.HasOne,
Schema: "User",
FieldSchema: "Toy",
Polymorphic: Polymorphic{ID: "OwnerID", Type: "OwnerType", Value: "users"},
References: []Reference{
{ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"},
{ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"},
},
},
"Toys": {
Name: "Toys",
Type: schema.HasMany,
Schema: "User",
FieldSchema: "Toy",
Polymorphic: Polymorphic{ID: "OwnerID", Type: "OwnerType", Value: "users"},
References: []Reference{
{ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"},
{ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"},
},
},
},
},
})
}
func TestPolymorphic(t *testing.T) {
t.Run("has one", func(t *testing.T) {
type Toy struct {
ID int
Name string
OwnerID int
OwnerType string
}
type Cat struct {
ID int
Name string
Toy Toy `gorm:"polymorphic:Owner;"`
}
s, err := schema.Parse(&Cat{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("Failed to parse schema, got error %v", err)
}
checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
"Cat": {
Relations: map[string]Relation{
"Toy": {
Name: "Toy",
Type: schema.HasOne,
Schema: "User",
FieldSchema: "Toy",
Polymorphic: Polymorphic{ID: "OwnerID", Type: "OwnerType", Value: "users"},
References: []Reference{
{ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"},
},
},
},
},
})
})
t.Run("has one with custom polymorphic type and id", func(t *testing.T) {
type Toy struct {
ID int
Name string
RefId int
Type string
}
type Cat struct {
ID int
Name string
Toy Toy `gorm:"polymorphic:Owner;polymorphicType:Type;polymorphicId:RefId"`
}
s, err := schema.Parse(&Cat{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("Failed to parse schema, got error %v", err)
}
checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
"Cat": {
Relations: map[string]Relation{
"Toy": {
Name: "Toy",
Type: schema.HasOne,
Schema: "User",
FieldSchema: "Toy",
Polymorphic: Polymorphic{ID: "ref_id", Type: "Type", Value: "users"},
References: []Reference{
{ForeignKey: "Type", ForeignSchema: "Toy", PrimaryValue: "users"},
},
},
},
},
})
})
t.Run("has one with only polymorphic type", func(t *testing.T) {
type Toy struct {
ID int
Name string
OwnerID int
Type string
}
type Cat struct {
ID int
Name string
Toy Toy `gorm:"polymorphic:Owner;polymorphicType:Type"`
}
s, err := schema.Parse(&Cat{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("Failed to parse schema, got error %v", err)
}
checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
"Cat": {
Relations: map[string]Relation{
"Toy": {
Name: "Toy",
Type: schema.HasOne,
Schema: "User",
FieldSchema: "Toy",
Polymorphic: Polymorphic{ID: "owner_id", Type: "Type", Value: "users"},
References: []Reference{
{ForeignKey: "Type", ForeignSchema: "Toy", PrimaryValue: "users"},
},
},
},
},
})
})
t.Run("has many", func(t *testing.T) {
type Toy struct {
ID int
Name string
OwnerID int
OwnerType string
}
type Cat struct {
ID int
Name string
Toys []Toy `gorm:"polymorphic:Owner;"`
}
s, err := schema.Parse(&Cat{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("Failed to parse schema, got error %v", err)
}
checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
"Cat": {
Relations: map[string]Relation{
"Toys": {
Name: "Toys",
Type: schema.HasMany,
Schema: "User",
FieldSchema: "Toy",
Polymorphic: Polymorphic{ID: "OwnerID", Type: "OwnerType", Value: "users"},
References: []Reference{
{ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"},
},
},
},
},
})
})
t.Run("has many with custom polymorphic type and id", func(t *testing.T) {
type Toy struct {
ID int
Name string
RefId int
Type string
}
type Cat struct {
ID int
Name string
Toys []Toy `gorm:"polymorphicType:Type;polymorphicId:RefId"`
}
s, err := schema.Parse(&Cat{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("Failed to parse schema, got error %v", err)
}
checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
"Cat": {
Relations: map[string]Relation{
"Toys": {
Name: "Toys",
Type: schema.HasMany,
Schema: "User",
FieldSchema: "Toy",
Polymorphic: Polymorphic{ID: "ref_id", Type: "Type", Value: "users"},
References: []Reference{
{ForeignKey: "Type", ForeignSchema: "Toy", PrimaryValue: "users"},
},
},
},
},
})
})
}
func TestEmbeddedBelongsTo(t *testing.T) {
type Country struct {
ID int `gorm:"primaryKey"`
Name string
}
type Address struct {
CountryID int
Country Country
}
type NestedAddress struct {
Address
}
type CountryMixin struct {
CountryID int
Country Country
}
type Org struct {
ID int
PostalAddress Address `gorm:"embedded;embeddedPrefix:postal_address_"`
VisitingAddress Address `gorm:"embedded;embeddedPrefix:visiting_address_"`
AddressID int
Address struct {
ID int
Address
}
NestedAddress *NestedAddress `gorm:"embedded;embeddedPrefix:nested_address_"`
CountryMixin
}
s, err := schema.Parse(&Org{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Errorf("Failed to parse schema, got error %v", err)
}
checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
"PostalAddress": {
Relations: map[string]Relation{
"Country": {
Name: "Country", Type: schema.BelongsTo, Schema: "Org", FieldSchema: "Country",
References: []Reference{
{PrimaryKey: "ID", PrimarySchema: "Country", ForeignKey: "CountryID", ForeignSchema: "Org"},
},
},
},
},
"VisitingAddress": {
Relations: map[string]Relation{
"Country": {
Name: "Country", Type: schema.BelongsTo, Schema: "Org", FieldSchema: "Country",
References: []Reference{
{PrimaryKey: "ID", PrimarySchema: "Country", ForeignKey: "CountryID", ForeignSchema: "Org"},
},
},
},
},
"NestedAddress": {
Relations: map[string]Relation{
"Country": {
Name: "Country", Type: schema.BelongsTo, Schema: "Org", FieldSchema: "Country",
References: []Reference{
{PrimaryKey: "ID", PrimarySchema: "Country", ForeignKey: "CountryID", ForeignSchema: "Org"},
},
},
},
},
})
}
func TestVariableRelation(t *testing.T) {
var result struct {
User
}
checkStructRelation(t, &result, Relation{
Name: "Account", Type: schema.HasOne, Schema: "", FieldSchema: "Account",
References: []Reference{
{"ID", "", "UserID", "Account", "", true},
},
})
checkStructRelation(t, &result, Relation{
Name: "Company", Type: schema.BelongsTo, Schema: "", FieldSchema: "Company",
References: []Reference{
{"ID", "Company", "CompanyID", "", "", false},
},
})
}
func TestSameForeignKey(t *testing.T) { func TestSameForeignKey(t *testing.T) {
type UserAux struct { type UserAux struct {
gorm.Model gorm.Model
@ -900,6 +503,7 @@ func TestSameForeignKey(t *testing.T) {
} }
func TestBelongsToSameForeignKey(t *testing.T) { func TestBelongsToSameForeignKey(t *testing.T) {
type User struct { type User struct {
gorm.Model gorm.Model
Name string Name string
@ -960,39 +564,3 @@ func TestHasManySameForeignKey(t *testing.T) {
References: []Reference{{"ID", "User", "UserRefer", "Profile", "", true}}, References: []Reference{{"ID", "User", "UserRefer", "Profile", "", true}},
}) })
} }
type Author struct {
gorm.Model
}
type Book struct {
gorm.Model
Author Author
AuthorID uint
}
func (Book) TableName() string {
return "my_schema.a_very_very_very_very_very_very_very_very_long_table_name"
}
func TestParseConstraintNameWithSchemaQualifiedLongTableName(t *testing.T) {
s, err := schema.Parse(
&Book{},
&sync.Map{},
schema.NamingStrategy{IdentifierMaxLength: 64},
)
if err != nil {
t.Fatalf("Failed to parse schema")
}
expectedConstraintName := "fk_my_schema_a_very_very_very_very_very_very_very_very_l4db13eec"
constraint := s.Relationships.Relations["Author"].ParseConstraint()
if constraint.Name != expectedConstraintName {
t.Fatalf(
"expected constraint name %s, got %s",
expectedConstraintName,
constraint.Name,
)
}
}

View File

@ -5,29 +5,13 @@ import (
"errors" "errors"
"fmt" "fmt"
"go/ast" "go/ast"
"path"
"reflect" "reflect"
"strings"
"sync" "sync"
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
"gorm.io/gorm/logger" "gorm.io/gorm/logger"
) )
type callbackType string
const (
callbackTypeBeforeCreate callbackType = "BeforeCreate"
callbackTypeBeforeUpdate callbackType = "BeforeUpdate"
callbackTypeAfterCreate callbackType = "AfterCreate"
callbackTypeAfterUpdate callbackType = "AfterUpdate"
callbackTypeBeforeSave callbackType = "BeforeSave"
callbackTypeAfterSave callbackType = "AfterSave"
callbackTypeBeforeDelete callbackType = "BeforeDelete"
callbackTypeAfterDelete callbackType = "AfterDelete"
callbackTypeAfterFind callbackType = "AfterFind"
)
// ErrUnsupportedDataType unsupported data type // ErrUnsupportedDataType unsupported data type
var ErrUnsupportedDataType = errors.New("unsupported data type") var ErrUnsupportedDataType = errors.New("unsupported data type")
@ -41,7 +25,6 @@ type Schema struct {
PrimaryFieldDBNames []string PrimaryFieldDBNames []string
Fields []*Field Fields []*Field
FieldsByName map[string]*Field FieldsByName map[string]*Field
FieldsByBindName map[string]*Field // embedded fields is 'Embed.Field'
FieldsByDBName map[string]*Field FieldsByDBName map[string]*Field
FieldsWithDefaultDBValue []*Field // fields with default value assigned by database FieldsWithDefaultDBValue []*Field // fields with default value assigned by database
Relationships Relationships Relationships Relationships
@ -68,10 +51,9 @@ func (schema Schema) String() string {
} }
func (schema Schema) MakeSlice() reflect.Value { func (schema Schema) MakeSlice() reflect.Value {
slice := reflect.MakeSlice(reflect.SliceOf(reflect.PointerTo(schema.ModelType)), 0, 20) slice := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(schema.ModelType)), 0, 20)
results := reflect.New(slice.Type()) results := reflect.New(slice.Type())
results.Elem().Set(slice) results.Elem().Set(slice)
return results return results
} }
@ -85,52 +67,17 @@ func (schema Schema) LookUpField(name string) *Field {
return nil return nil
} }
// LookUpFieldByBindName looks for the closest field in the embedded struct.
//
// type Struct struct {
// Embedded struct {
// ID string // is selected by LookUpFieldByBindName([]string{"Embedded", "ID"}, "ID")
// }
// ID string // is selected by LookUpFieldByBindName([]string{"ID"}, "ID")
// }
func (schema Schema) LookUpFieldByBindName(bindNames []string, name string) *Field {
if len(bindNames) == 0 {
return nil
}
for i := len(bindNames) - 1; i >= 0; i-- {
find := strings.Join(bindNames[:i], ".") + "." + name
if field, ok := schema.FieldsByBindName[find]; ok {
return field
}
}
return nil
}
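
The doc comment above walks the bind-name path from the most deeply nested candidate outwards, so an embedded field wins over a top-level field with the same name. A minimal sketch of calling it against a parsed schema, as it exists on the master side of this diff; the Blog model and its embedded Author struct are illustrative assumptions, only schema.Parse and LookUpFieldByBindName come from the code above.

package main // illustrative

import (
	"fmt"
	"sync"

	"gorm.io/gorm/schema"
)

// Blog is an assumed model: the embedded tag flattens Author into the
// schema, giving its Name field the bind names ["Author", "Name"].
type Blog struct {
	ID     uint
	Author struct {
		Name  string
		Email string
	} `gorm:"embedded;embeddedPrefix:author_"`
	Name string
}

func main() {
	s, err := schema.Parse(&Blog{}, &sync.Map{}, schema.NamingStrategy{})
	if err != nil {
		panic(err)
	}
	// The closest match for "Name" under ["Author", "Name"] is the embedded
	// Author.Name, not the top-level Name field.
	if f := s.LookUpFieldByBindName([]string{"Author", "Name"}, "Name"); f != nil {
		fmt.Println(f.BindNames) // [Author Name]
	}
}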
type Tabler interface { type Tabler interface {
TableName() string TableName() string
} }
type TablerWithNamer interface {
TableName(Namer) string
}
// Parse get data type from dialector // Parse get data type from dialector
func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) { func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) {
return ParseWithSpecialTableName(dest, cacheStore, namer, "")
}
// ParseWithSpecialTableName get data type from dialector with extra schema table
func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Namer, specialTableName string) (*Schema, error) {
if dest == nil { if dest == nil {
return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
} }
value := reflect.ValueOf(dest) modelType := reflect.Indirect(reflect.ValueOf(dest)).Type()
if value.Kind() == reflect.Ptr && value.IsNil() {
value = reflect.New(value.Type().Elem())
}
modelType := reflect.Indirect(value).Type()
if modelType.Kind() == reflect.Interface { if modelType.Kind() == reflect.Interface {
modelType = reflect.Indirect(reflect.ValueOf(dest)).Elem().Type() modelType = reflect.Indirect(reflect.ValueOf(dest)).Elem().Type()
} }
@ -146,17 +93,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name()) return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name())
} }
// Cache the Schema for performance, if v, ok := cacheStore.Load(modelType); ok {
// Use the modelType or modelType + schemaTable (if it present) as cache key.
var schemaCacheKey interface{}
if specialTableName != "" {
schemaCacheKey = fmt.Sprintf("%p-%s", modelType, specialTableName)
} else {
schemaCacheKey = modelType
}
// Load exist schema cache, return if exists
if v, ok := cacheStore.Load(schemaCacheKey); ok {
s := v.(*Schema) s := v.(*Schema)
// Wait for the initialization of other goroutines to complete // Wait for the initialization of other goroutines to complete
<-s.initialized <-s.initialized
@ -168,33 +105,25 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
if tabler, ok := modelValue.Interface().(Tabler); ok { if tabler, ok := modelValue.Interface().(Tabler); ok {
tableName = tabler.TableName() tableName = tabler.TableName()
} }
if tabler, ok := modelValue.Interface().(TablerWithNamer); ok {
tableName = tabler.TableName(namer)
}
if en, ok := namer.(embeddedNamer); ok { if en, ok := namer.(embeddedNamer); ok {
tableName = en.Table tableName = en.Table
} }
if specialTableName != "" && specialTableName != tableName {
tableName = specialTableName
}
schema := &Schema{ schema := &Schema{
Name: modelType.Name(), Name: modelType.Name(),
ModelType: modelType, ModelType: modelType,
Table: tableName, Table: tableName,
FieldsByName: map[string]*Field{}, FieldsByName: map[string]*Field{},
FieldsByBindName: map[string]*Field{}, FieldsByDBName: map[string]*Field{},
FieldsByDBName: map[string]*Field{}, Relationships: Relationships{Relations: map[string]*Relationship{}},
Relationships: Relationships{Relations: map[string]*Relationship{}}, cacheStore: cacheStore,
cacheStore: cacheStore, namer: namer,
namer: namer, initialized: make(chan struct{}),
initialized: make(chan struct{}),
} }
// When the schema initialization is completed, the channel will be closed // When the schema initialization is completed, the channel will be closed
defer close(schema.initialized) defer close(schema.initialized)
// Load exist schema cache, return if exists if v, loaded := cacheStore.Load(modelType); loaded {
if v, ok := cacheStore.Load(schemaCacheKey); ok {
s := v.(*Schema) s := v.(*Schema)
// Wait for the initialization of other goroutines to complete // Wait for the initialization of other goroutines to complete
<-s.initialized <-s.initialized
@ -216,7 +145,6 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
field.DBName = namer.ColumnName(schema.Table, field.Name) field.DBName = namer.ColumnName(schema.Table, field.Name)
} }
bindName := field.BindName()
if field.DBName != "" { if field.DBName != "" {
// nonexistence or shortest path or first appear prioritized if has permission // nonexistence or shortest path or first appear prioritized if has permission
if v, ok := schema.FieldsByDBName[field.DBName]; !ok || ((field.Creatable || field.Updatable || field.Readable) && len(field.BindNames) < len(v.BindNames)) { if v, ok := schema.FieldsByDBName[field.DBName]; !ok || ((field.Creatable || field.Updatable || field.Readable) && len(field.BindNames) < len(v.BindNames)) {
@ -225,7 +153,6 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
} }
schema.FieldsByDBName[field.DBName] = field schema.FieldsByDBName[field.DBName] = field
schema.FieldsByName[field.Name] = field schema.FieldsByName[field.Name] = field
schema.FieldsByBindName[bindName] = field
if v != nil && v.PrimaryKey { if v != nil && v.PrimaryKey {
for idx, f := range schema.PrimaryFields { for idx, f := range schema.PrimaryFields {
@ -244,11 +171,8 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
if of, ok := schema.FieldsByName[field.Name]; !ok || of.TagSettings["-"] == "-" { if of, ok := schema.FieldsByName[field.Name]; !ok || of.TagSettings["-"] == "-" {
schema.FieldsByName[field.Name] = field schema.FieldsByName[field.Name] = field
} }
if of, ok := schema.FieldsByBindName[bindName]; !ok || of.TagSettings["-"] == "-" {
schema.FieldsByBindName[bindName] = field
}
field.setupValuerAndSetter(modelType) field.setupValuerAndSetter()
} }
prioritizedPrimaryField := schema.LookUpField("id") prioritizedPrimaryField := schema.LookUpField("id")
@ -266,26 +190,16 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
} }
} }
if schema.PrioritizedPrimaryField == nil { if schema.PrioritizedPrimaryField == nil && len(schema.PrimaryFields) == 1 {
if len(schema.PrimaryFields) == 1 { schema.PrioritizedPrimaryField = schema.PrimaryFields[0]
schema.PrioritizedPrimaryField = schema.PrimaryFields[0]
} else if len(schema.PrimaryFields) > 1 {
// If there are multiple primary keys, the AUTOINCREMENT field is prioritized
for _, field := range schema.PrimaryFields {
if field.AutoIncrement {
schema.PrioritizedPrimaryField = field
break
}
}
}
} }
for _, field := range schema.PrimaryFields { for _, field := range schema.PrimaryFields {
schema.PrimaryFieldDBNames = append(schema.PrimaryFieldDBNames, field.DBName) schema.PrimaryFieldDBNames = append(schema.PrimaryFieldDBNames, field.DBName)
} }
for _, field := range schema.Fields { for _, field := range schema.FieldsByDBName {
if field.DataType != "" && field.HasDefaultValue && field.DefaultValueInterface == nil { if field.HasDefaultValue && field.DefaultValueInterface == nil {
schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field) schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field)
} }
} }
@ -304,32 +218,19 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
} }
} }
callbackTypes := []callbackType{ callbacks := []string{"BeforeCreate", "AfterCreate", "BeforeUpdate", "AfterUpdate", "BeforeSave", "AfterSave", "BeforeDelete", "AfterDelete", "AfterFind"}
callbackTypeBeforeCreate, callbackTypeAfterCreate, for _, name := range callbacks {
callbackTypeBeforeUpdate, callbackTypeAfterUpdate, if methodValue := modelValue.MethodByName(name); methodValue.IsValid() {
callbackTypeBeforeSave, callbackTypeAfterSave,
callbackTypeBeforeDelete, callbackTypeAfterDelete,
callbackTypeAfterFind,
}
for _, cbName := range callbackTypes {
if methodValue := callBackToMethodValue(modelValue, cbName); methodValue.IsValid() {
switch methodValue.Type().String() { switch methodValue.Type().String() {
case "func(*gorm.DB) error": case "func(*gorm.DB) error": // TODO hack
expectedPkgPath := path.Dir(reflect.TypeOf(schema).Elem().PkgPath()) reflect.Indirect(reflect.ValueOf(schema)).FieldByName(name).SetBool(true)
if inVarPkg := methodValue.Type().In(0).Elem().PkgPath(); inVarPkg == expectedPkgPath {
reflect.Indirect(reflect.ValueOf(schema)).FieldByName(string(cbName)).SetBool(true)
} else {
logger.Default.Warn(context.Background(), "In model %v, the hook function `%v(*gorm.DB) error` has an incorrect parameter type. The expected parameter type is `%v`, but the provided type is `%v`.", schema, cbName, expectedPkgPath, inVarPkg)
// PASS
}
default: default:
logger.Default.Warn(context.Background(), "Model %v don't match %vInterface, should be `%v(*gorm.DB) error`. Please see https://gorm.io/docs/hooks.html", schema, cbName, cbName) logger.Default.Warn(context.Background(), "Model %v don't match %vInterface, should be `%v(*gorm.DB) error`. Please see https://gorm.io/docs/hooks.html", schema, name, name)
} }
} }
} }
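
On the master side, the loop above only marks a hook as present when the model's method has the exact signature func(*gorm.DB) error and its parameter really comes from the gorm package; anything else produces the warning instead of silently registering a broken hook. A minimal sketch of a hook that passes this check; the User model and its validation rule are illustrative assumptions.

package model // illustrative

import (
	"errors"

	"gorm.io/gorm"
)

type User struct {
	gorm.Model
	Name string
}

// BeforeCreate matches func(*gorm.DB) error, so Parse flips the schema's
// BeforeCreate flag and the callback runner invokes it before inserts.
func (u *User) BeforeCreate(tx *gorm.DB) error {
	if u.Name == "" {
		return errors.New("name is required")
	}
	return nil
}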
// Cache the schema if v, loaded := cacheStore.LoadOrStore(modelType, schema); loaded {
if v, loaded := cacheStore.LoadOrStore(schemaCacheKey, schema); loaded {
s := v.(*Schema) s := v.(*Schema)
// Wait for the initialization of other goroutines to complete // Wait for the initialization of other goroutines to complete
<-s.initialized <-s.initialized
@ -345,12 +246,11 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
if _, embedded := schema.cacheStore.Load(embeddedCacheKey); !embedded { if _, embedded := schema.cacheStore.Load(embeddedCacheKey); !embedded {
for _, field := range schema.Fields { for _, field := range schema.Fields {
if field.DataType == "" && field.GORMDataType == "" && (field.Creatable || field.Updatable || field.Readable) { if field.DataType == "" && (field.Creatable || field.Updatable || field.Readable) {
if schema.parseRelation(field); schema.err != nil { if schema.parseRelation(field); schema.err != nil {
return schema, schema.err return schema, schema.err
} else { } else {
schema.FieldsByName[field.Name] = field schema.FieldsByName[field.Name] = field
schema.FieldsByBindName[field.BindName()] = field
} }
} }
@ -377,39 +277,6 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
return schema, schema.err return schema, schema.err
} }
// This unrolling is needed to show to the compiler the exact set of methods
// that can be used on the modelType.
// Prior to go1.22 any use of MethodByName would cause the linker to
// abandon dead code elimination for the entire binary.
// As of go1.22 the compiler supports one special case of a string constant
// being passed to MethodByName. For enterprise customers or those building
// large binaries, this gives a significant reduction in binary size.
// https://github.com/golang/go/issues/62257
func callBackToMethodValue(modelType reflect.Value, cbType callbackType) reflect.Value {
switch cbType {
case callbackTypeBeforeCreate:
return modelType.MethodByName(string(callbackTypeBeforeCreate))
case callbackTypeAfterCreate:
return modelType.MethodByName(string(callbackTypeAfterCreate))
case callbackTypeBeforeUpdate:
return modelType.MethodByName(string(callbackTypeBeforeUpdate))
case callbackTypeAfterUpdate:
return modelType.MethodByName(string(callbackTypeAfterUpdate))
case callbackTypeBeforeSave:
return modelType.MethodByName(string(callbackTypeBeforeSave))
case callbackTypeAfterSave:
return modelType.MethodByName(string(callbackTypeAfterSave))
case callbackTypeBeforeDelete:
return modelType.MethodByName(string(callbackTypeBeforeDelete))
case callbackTypeAfterDelete:
return modelType.MethodByName(string(callbackTypeAfterDelete))
case callbackTypeAfterFind:
return modelType.MethodByName(string(callbackTypeAfterFind))
default:
return reflect.ValueOf(nil)
}
}
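
The comment above the switch explains the shape: from Go 1.22 the linker keeps dead-code elimination only when MethodByName receives a string constant, so the master side spells out one case per hook instead of ranging over a slice of names (golang/go#62257). A reduced sketch of the same pattern outside GORM, kept deliberately repetitive so every MethodByName argument stays a constant; the package, function and hook names here are illustrative.

package hooks // illustrative

import "reflect"

// methodByConstName mirrors callBackToMethodValue: each branch passes a
// string literal to MethodByName, the one form the Go 1.22+ linker can
// still analyse without abandoning dead-code elimination.
func methodByConstName(v reflect.Value, name string) reflect.Value {
	switch name {
	case "BeforeCreate":
		return v.MethodByName("BeforeCreate")
	case "AfterCreate":
		return v.MethodByName("AfterCreate")
	default:
		return reflect.Value{}
	}
}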
func getOrParse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) { func getOrParse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) {
modelType := reflect.ValueOf(dest).Type() modelType := reflect.ValueOf(dest).Type()
for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr { for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr {

View File

@ -1,7 +1,6 @@
package schema_test package schema_test
import ( import (
"context"
"fmt" "fmt"
"reflect" "reflect"
"strings" "strings"
@ -163,8 +162,8 @@ func checkSchemaRelation(t *testing.T, s *schema.Schema, relation Relation) {
t.Errorf("schema %v relation's join table tablename expects %v, but got %v", s, relation.JoinTable.Table, r.JoinTable.Table) t.Errorf("schema %v relation's join table tablename expects %v, but got %v", s, relation.JoinTable.Table, r.JoinTable.Table)
} }
for i := range relation.JoinTable.Fields { for _, f := range relation.JoinTable.Fields {
checkSchemaField(t, r.JoinTable, &relation.JoinTable.Fields[i], nil) checkSchemaField(t, r.JoinTable, &f, nil)
} }
} }
@ -201,41 +200,10 @@ func checkSchemaRelation(t *testing.T, s *schema.Schema, relation Relation) {
}) })
} }
type EmbeddedRelations struct {
Relations map[string]Relation
EmbeddedRelations map[string]EmbeddedRelations
}
func checkEmbeddedRelations(t *testing.T, actual map[string]*schema.Relationships, expected map[string]EmbeddedRelations) {
for name, relations := range actual {
rs := expected[name]
t.Run("CheckEmbeddedRelations/"+name, func(t *testing.T) {
if len(relations.Relations) != len(rs.Relations) {
t.Errorf("schema relations count don't match, expects %d, got %d", len(rs.Relations), len(relations.Relations))
}
if len(relations.EmbeddedRelations) != len(rs.EmbeddedRelations) {
t.Errorf("schema embedded relations count don't match, expects %d, got %d", len(rs.EmbeddedRelations), len(relations.EmbeddedRelations))
}
for n, rel := range relations.Relations {
if r, ok := rs.Relations[n]; !ok {
t.Errorf("failed to find relation by name %s", n)
} else {
checkSchemaRelation(t, &schema.Schema{
Relationships: schema.Relationships{
Relations: map[string]*schema.Relationship{n: rel},
},
}, r)
}
}
checkEmbeddedRelations(t, relations.EmbeddedRelations, rs.EmbeddedRelations)
})
}
}
func checkField(t *testing.T, s *schema.Schema, value reflect.Value, values map[string]interface{}) { func checkField(t *testing.T, s *schema.Schema, value reflect.Value, values map[string]interface{}) {
for k, v := range values { for k, v := range values {
t.Run("CheckField/"+k, func(t *testing.T) { t.Run("CheckField/"+k, func(t *testing.T) {
fv, _ := s.FieldsByDBName[k].ValueOf(context.Background(), value) fv, _ := s.FieldsByDBName[k].ValueOf(value)
tests.AssertEqual(t, v, fv) tests.AssertEqual(t, v, fv)
}) })
} }

View File

@ -19,22 +19,6 @@ func TestParseSchema(t *testing.T) {
checkUserSchema(t, user) checkUserSchema(t, user)
} }
func TestParseSchemaWithMap(t *testing.T) {
type User struct {
tests.User
Attrs map[string]string `gorm:"type:Map(String,String);"`
}
user, err := schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("failed to parse user with map, got error %v", err)
}
if field := user.FieldsByName["Attrs"]; field.DataType != "Map(String,String)" {
t.Errorf("failed to parse user field Attrs")
}
}
func TestParseSchemaWithPointerFields(t *testing.T) { func TestParseSchemaWithPointerFields(t *testing.T) {
user, err := schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{}) user, err := schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil { if err != nil {
@ -62,8 +46,8 @@ func checkUserSchema(t *testing.T, user *schema.Schema) {
{Name: "Active", DBName: "active", BindNames: []string{"Active"}, DataType: schema.Bool}, {Name: "Active", DBName: "active", BindNames: []string{"Active"}, DataType: schema.Bool},
} }
for i := range fields { for _, f := range fields {
checkSchemaField(t, user, &fields[i], func(f *schema.Field) { checkSchemaField(t, user, &f, func(f *schema.Field) {
f.Creatable = true f.Creatable = true
f.Updatable = true f.Updatable = true
f.Readable = true f.Readable = true
@ -152,8 +136,8 @@ func TestParseSchemaWithAdvancedDataType(t *testing.T) {
{Name: "Admin", DBName: "admin", BindNames: []string{"Admin"}, DataType: schema.Bool}, {Name: "Admin", DBName: "admin", BindNames: []string{"Admin"}, DataType: schema.Bool},
} }
for i := range fields { for _, f := range fields {
checkSchemaField(t, user, &fields[i], func(f *schema.Field) { checkSchemaField(t, user, &f, func(f *schema.Field) {
f.Creatable = true f.Creatable = true
f.Updatable = true f.Updatable = true
f.Readable = true f.Readable = true
@ -161,7 +145,8 @@ func TestParseSchemaWithAdvancedDataType(t *testing.T) {
} }
} }
type CustomizeTable struct{} type CustomizeTable struct {
}
func (CustomizeTable) TableName() string { func (CustomizeTable) TableName() string {
return "customize" return "customize"
@ -180,6 +165,7 @@ func TestCustomizeTableName(t *testing.T) {
func TestNestedModel(t *testing.T) { func TestNestedModel(t *testing.T) {
versionUser, err := schema.Parse(&VersionUser{}, &sync.Map{}, schema.NamingStrategy{}) versionUser, err := schema.Parse(&VersionUser{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil { if err != nil {
t.Fatalf("failed to parse nested user, got error %v", err) t.Fatalf("failed to parse nested user, got error %v", err)
} }
@ -218,6 +204,7 @@ func TestEmbeddedStruct(t *testing.T) {
} }
cropSchema, err := schema.Parse(&Corp{}, &sync.Map{}, schema.NamingStrategy{}) cropSchema, err := schema.Parse(&Corp{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil { if err != nil {
t.Fatalf("failed to parse embedded struct with primary key, got error %v", err) t.Fatalf("failed to parse embedded struct with primary key, got error %v", err)
} }
@ -286,6 +273,7 @@ func TestEmbeddedStructForCustomizedNamingStrategy(t *testing.T) {
} }
cropSchema, err := schema.Parse(&Corp{}, &sync.Map{}, CustomizedNamingStrategy{schema.NamingStrategy{}}) cropSchema, err := schema.Parse(&Corp{}, &sync.Map{}, CustomizedNamingStrategy{schema.NamingStrategy{}})
if err != nil { if err != nil {
t.Fatalf("failed to parse embedded struct with primary key, got error %v", err) t.Fatalf("failed to parse embedded struct with primary key, got error %v", err)
} }
@ -309,44 +297,3 @@ func TestEmbeddedStructForCustomizedNamingStrategy(t *testing.T) {
}) })
} }
} }
func TestCompositePrimaryKeyWithAutoIncrement(t *testing.T) {
type Product struct {
ProductID uint `gorm:"primaryKey;autoIncrement"`
LanguageCode uint `gorm:"primaryKey"`
Code string
Name string
}
type ProductNonAutoIncrement struct {
ProductID uint `gorm:"primaryKey;autoIncrement:false"`
LanguageCode uint `gorm:"primaryKey"`
Code string
Name string
}
product, err := schema.Parse(&Product{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("failed to parse product struct with composite primary key, got error %v", err)
}
prioritizedPrimaryField := schema.Field{
Name: "ProductID", DBName: "product_id", BindNames: []string{"ProductID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, HasDefaultValue: true, AutoIncrement: true, TagSettings: map[string]string{"PRIMARYKEY": "PRIMARYKEY", "AUTOINCREMENT": "AUTOINCREMENT"},
}
product.Fields = []*schema.Field{product.PrioritizedPrimaryField}
checkSchemaField(t, product, &prioritizedPrimaryField, func(f *schema.Field) {
f.Creatable = true
f.Updatable = true
f.Readable = true
})
productNonAutoIncrement, err := schema.Parse(&ProductNonAutoIncrement{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("failed to parse productNonAutoIncrement struct with composite primary key, got error %v", err)
}
if productNonAutoIncrement.PrioritizedPrimaryField != nil {
t.Fatalf("PrioritizedPrimaryField of non autoincrement composite key should be nil")
}
}

View File

@ -1,173 +0,0 @@
package schema
import (
"bytes"
"context"
"database/sql"
"database/sql/driver"
"encoding/gob"
"encoding/json"
"fmt"
"reflect"
"strings"
"sync"
"time"
)
var serializerMap = sync.Map{}
// RegisterSerializer register serializer
func RegisterSerializer(name string, serializer SerializerInterface) {
serializerMap.Store(strings.ToLower(name), serializer)
}
// GetSerializer get serializer
func GetSerializer(name string) (serializer SerializerInterface, ok bool) {
v, ok := serializerMap.Load(strings.ToLower(name))
if ok {
serializer, ok = v.(SerializerInterface)
}
return serializer, ok
}
func init() {
RegisterSerializer("json", JSONSerializer{})
RegisterSerializer("unixtime", UnixSecondSerializer{})
RegisterSerializer("gob", GobSerializer{})
}
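
RegisterSerializer and GetSerializer above (present on master, absent in v1.21.16) keep a process-wide registry keyed by lower-cased name, seeded with json, unixtime and gob in init. A hedged sketch of plugging in a custom serializer: only RegisterSerializer and the SerializerInterface signatures come from the file above; the CSV type, the model and the serializer:csv tag wiring are assumptions about how the rest of the tree consumes the registry.

package model // illustrative

import (
	"context"
	"reflect"
	"strings"

	"gorm.io/gorm/schema"
)

// CSVSerializer stores a []string field as a comma-separated column.
type CSVSerializer struct{}

func (CSVSerializer) Scan(ctx context.Context, field *schema.Field, dst reflect.Value, dbValue interface{}) error {
	var s string
	switch v := dbValue.(type) {
	case []byte:
		s = string(v)
	case string:
		s = v
	}
	if s == "" {
		return field.Set(ctx, dst, []string(nil))
	}
	return field.Set(ctx, dst, strings.Split(s, ","))
}

func (CSVSerializer) Value(ctx context.Context, field *schema.Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) {
	tags, _ := fieldValue.([]string)
	return strings.Join(tags, ","), nil
}

func init() {
	schema.RegisterSerializer("csv", CSVSerializer{})
}

// Assumed usage on a model field:
//
//	type Note struct {
//		ID   uint
//		Tags []string `gorm:"serializer:csv"`
//	}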
// Serializer field value serializer
type serializer struct {
Field *Field
Serializer SerializerInterface
SerializeValuer SerializerValuerInterface
Destination reflect.Value
Context context.Context
value interface{}
fieldValue interface{}
}
// Scan implements sql.Scanner interface
func (s *serializer) Scan(value interface{}) error {
s.value = value
return nil
}
// Value implements driver.Valuer interface
func (s serializer) Value() (driver.Value, error) {
return s.SerializeValuer.Value(s.Context, s.Field, s.Destination, s.fieldValue)
}
// SerializerInterface serializer interface
type SerializerInterface interface {
Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) error
SerializerValuerInterface
}
// SerializerValuerInterface serializer valuer interface
type SerializerValuerInterface interface {
Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (interface{}, error)
}
// JSONSerializer json serializer
type JSONSerializer struct{}
// Scan implements serializer interface
func (JSONSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) (err error) {
fieldValue := reflect.New(field.FieldType)
if dbValue != nil {
var bytes []byte
switch v := dbValue.(type) {
case []byte:
bytes = v
case string:
bytes = []byte(v)
default:
bytes, err = json.Marshal(v)
if err != nil {
return err
}
}
if len(bytes) > 0 {
err = json.Unmarshal(bytes, fieldValue.Interface())
}
}
field.ReflectValueOf(ctx, dst).Set(fieldValue.Elem())
return
}
// Value implements serializer interface
func (JSONSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) {
result, err := json.Marshal(fieldValue)
if string(result) == "null" {
if field.TagSettings["NOT NULL"] != "" {
return "", nil
}
return nil, err
}
return string(result), err
}
// UnixSecondSerializer json serializer
type UnixSecondSerializer struct{}
// Scan implements serializer interface
func (UnixSecondSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) (err error) {
t := sql.NullTime{}
if err = t.Scan(dbValue); err == nil && t.Valid {
err = field.Set(ctx, dst, t.Time.Unix())
}
return
}
// Value implements serializer interface
func (UnixSecondSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (result interface{}, err error) {
rv := reflect.ValueOf(fieldValue)
switch v := fieldValue.(type) {
case int64, int, uint, uint64, int32, uint32, int16, uint16:
result = time.Unix(reflect.Indirect(rv).Int(), 0).UTC()
case *int64, *int, *uint, *uint64, *int32, *uint32, *int16, *uint16:
if rv.IsZero() {
return nil, nil
}
result = time.Unix(reflect.Indirect(rv).Int(), 0).UTC()
default:
err = fmt.Errorf("invalid field type %#v for UnixSecondSerializer, only int, uint supported", v)
}
return
}
// GobSerializer gob serializer
type GobSerializer struct{}
// Scan implements serializer interface
func (GobSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) (err error) {
fieldValue := reflect.New(field.FieldType)
if dbValue != nil {
var bytesValue []byte
switch v := dbValue.(type) {
case []byte:
bytesValue = v
default:
return fmt.Errorf("failed to unmarshal gob value: %#v", dbValue)
}
if len(bytesValue) > 0 {
decoder := gob.NewDecoder(bytes.NewBuffer(bytesValue))
err = decoder.Decode(fieldValue.Interface())
}
}
field.ReflectValueOf(ctx, dst).Set(fieldValue.Elem())
return
}
// Value implements serializer interface
func (GobSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) {
buf := new(bytes.Buffer)
err := gob.NewEncoder(buf).Encode(fieldValue)
return buf.Bytes(), err
}

View File

@ -1,8 +1,6 @@
package schema package schema
import ( import (
"context"
"fmt"
"reflect" "reflect"
"regexp" "regexp"
"strings" "strings"
@ -60,22 +58,14 @@ func removeSettingFromTag(tag reflect.StructTag, names ...string) reflect.Struct
return tag return tag
} }
func appendSettingFromTag(tag reflect.StructTag, value string) reflect.StructTag {
t := tag.Get("gorm")
if strings.Contains(t, value) {
return tag
}
return reflect.StructTag(fmt.Sprintf(`gorm:"%s;%s"`, value, t))
}
// GetRelationsValues get relations's values from a reflect value // GetRelationsValues get relations's values from a reflect value
func GetRelationsValues(ctx context.Context, reflectValue reflect.Value, rels []*Relationship) (reflectResults reflect.Value) { func GetRelationsValues(reflectValue reflect.Value, rels []*Relationship) (reflectResults reflect.Value) {
for _, rel := range rels { for _, rel := range rels {
reflectResults = reflect.MakeSlice(reflect.SliceOf(reflect.PointerTo(rel.FieldSchema.ModelType)), 0, 1) reflectResults = reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(rel.FieldSchema.ModelType)), 0, 1)
appendToResults := func(value reflect.Value) { appendToResults := func(value reflect.Value) {
if _, isZero := rel.Field.ValueOf(ctx, value); !isZero { if _, isZero := rel.Field.ValueOf(value); !isZero {
result := reflect.Indirect(rel.Field.ReflectValueOf(ctx, value)) result := reflect.Indirect(rel.Field.ReflectValueOf(value))
switch result.Kind() { switch result.Kind() {
case reflect.Struct: case reflect.Struct:
reflectResults = reflect.Append(reflectResults, result.Addr()) reflectResults = reflect.Append(reflectResults, result.Addr())
@ -107,7 +97,7 @@ func GetRelationsValues(ctx context.Context, reflectValue reflect.Value, rels []
} }
// GetIdentityFieldValuesMap get identity map from fields // GetIdentityFieldValuesMap get identity map from fields
func GetIdentityFieldValuesMap(ctx context.Context, reflectValue reflect.Value, fields []*Field) (map[string][]reflect.Value, [][]interface{}) { func GetIdentityFieldValuesMap(reflectValue reflect.Value, fields []*Field) (map[string][]reflect.Value, [][]interface{}) {
var ( var (
results = [][]interface{}{} results = [][]interface{}{}
dataResults = map[string][]reflect.Value{} dataResults = map[string][]reflect.Value{}
@ -115,17 +105,12 @@ func GetIdentityFieldValuesMap(ctx context.Context, reflectValue reflect.Value,
notZero, zero bool notZero, zero bool
) )
if reflectValue.Kind() == reflect.Ptr ||
reflectValue.Kind() == reflect.Interface {
reflectValue = reflectValue.Elem()
}
switch reflectValue.Kind() { switch reflectValue.Kind() {
case reflect.Struct: case reflect.Struct:
results = [][]interface{}{make([]interface{}, len(fields))} results = [][]interface{}{make([]interface{}, len(fields))}
for idx, field := range fields { for idx, field := range fields {
results[0][idx], zero = field.ValueOf(ctx, reflectValue) results[0][idx], zero = field.ValueOf(reflectValue)
notZero = notZero || !zero notZero = notZero || !zero
} }
@ -138,7 +123,7 @@ func GetIdentityFieldValuesMap(ctx context.Context, reflectValue reflect.Value,
for i := 0; i < reflectValue.Len(); i++ { for i := 0; i < reflectValue.Len(); i++ {
elem := reflectValue.Index(i) elem := reflectValue.Index(i)
elemKey := elem.Interface() elemKey := elem.Interface()
if elem.Kind() != reflect.Ptr && elem.CanAddr() { if elem.Kind() != reflect.Ptr {
elemKey = elem.Addr().Interface() elemKey = elem.Addr().Interface()
} }
@ -150,7 +135,7 @@ func GetIdentityFieldValuesMap(ctx context.Context, reflectValue reflect.Value,
fieldValues := make([]interface{}, len(fields)) fieldValues := make([]interface{}, len(fields))
notZero = false notZero = false
for idx, field := range fields { for idx, field := range fields {
fieldValues[idx], zero = field.ValueOf(ctx, elem) fieldValues[idx], zero = field.ValueOf(elem)
notZero = notZero || !zero notZero = notZero || !zero
} }
@ -170,12 +155,12 @@ func GetIdentityFieldValuesMap(ctx context.Context, reflectValue reflect.Value,
} }
// GetIdentityFieldValuesMapFromValues get identity map from fields // GetIdentityFieldValuesMapFromValues get identity map from fields
func GetIdentityFieldValuesMapFromValues(ctx context.Context, values []interface{}, fields []*Field) (map[string][]reflect.Value, [][]interface{}) { func GetIdentityFieldValuesMapFromValues(values []interface{}, fields []*Field) (map[string][]reflect.Value, [][]interface{}) {
resultsMap := map[string][]reflect.Value{} resultsMap := map[string][]reflect.Value{}
results := [][]interface{}{} results := [][]interface{}{}
for _, v := range values { for _, v := range values {
rm, rs := GetIdentityFieldValuesMap(ctx, reflect.Indirect(reflect.ValueOf(v)), fields) rm, rs := GetIdentityFieldValuesMap(reflect.Indirect(reflect.ValueOf(v)), fields)
for k, v := range rm { for k, v := range rm {
resultsMap[k] = append(resultsMap[k], v...) resultsMap[k] = append(resultsMap[k], v...)
} }

View File

@ -6,7 +6,6 @@ import (
"encoding/json" "encoding/json"
"reflect" "reflect"
"github.com/jinzhu/now"
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
"gorm.io/gorm/schema" "gorm.io/gorm/schema"
) )
@ -46,21 +45,11 @@ func (n *DeletedAt) UnmarshalJSON(b []byte) error {
} }
func (DeletedAt) QueryClauses(f *schema.Field) []clause.Interface { func (DeletedAt) QueryClauses(f *schema.Field) []clause.Interface {
return []clause.Interface{SoftDeleteQueryClause{Field: f, ZeroValue: parseZeroValueTag(f)}} return []clause.Interface{SoftDeleteQueryClause{Field: f}}
}
func parseZeroValueTag(f *schema.Field) sql.NullString {
if v, ok := f.TagSettings["ZEROVALUE"]; ok {
if _, err := now.Parse(v); err == nil {
return sql.NullString{String: v, Valid: true}
}
}
return sql.NullString{Valid: false}
} }
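
parseZeroValueTag (master side) lets a model override what "not deleted" looks like: when the field's ZEROVALUE tag setting parses with github.com/jinzhu/now it becomes the comparison value, otherwise the clauses fall back to an invalid sql.NullString, i.e. a plain NULL check. A hedged model sketch; the exact tag spelling below is an assumption inferred from the TagSettings["ZEROVALUE"] lookup, since the tag parser itself is outside this diff.

package model // illustrative

import "gorm.io/gorm"

// Order assumes GORM's tag parser uppercases zeroValue to ZEROVALUE and that
// the literal is parseable by github.com/jinzhu/now; queries would then
// filter on deleted_at = '2006-01-02 15:04:05' instead of deleted_at IS NULL.
type Order struct {
	ID        uint
	DeletedAt gorm.DeletedAt `gorm:"zeroValue:2006-01-02 15:04:05"`
}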
type SoftDeleteQueryClause struct { type SoftDeleteQueryClause struct {
ZeroValue sql.NullString Field *schema.Field
Field *schema.Field
} }
func (sd SoftDeleteQueryClause) Name() string { func (sd SoftDeleteQueryClause) Name() string {
@ -74,9 +63,9 @@ func (sd SoftDeleteQueryClause) MergeClause(*clause.Clause) {
} }
func (sd SoftDeleteQueryClause) ModifyStatement(stmt *Statement) { func (sd SoftDeleteQueryClause) ModifyStatement(stmt *Statement) {
if _, ok := stmt.Clauses["soft_delete_enabled"]; !ok && !stmt.Statement.Unscoped { if _, ok := stmt.Clauses["soft_delete_enabled"]; !ok {
if c, ok := stmt.Clauses["WHERE"]; ok { if c, ok := stmt.Clauses["WHERE"]; ok {
if where, ok := c.Expression.(clause.Where); ok && len(where.Exprs) >= 1 { if where, ok := c.Expression.(clause.Where); ok && len(where.Exprs) > 1 {
for _, expr := range where.Exprs { for _, expr := range where.Exprs {
if orCond, ok := expr.(clause.OrConditions); ok && len(orCond.Exprs) == 1 { if orCond, ok := expr.(clause.OrConditions); ok && len(orCond.Exprs) == 1 {
where.Exprs = []clause.Expression{clause.And(where.Exprs...)} where.Exprs = []clause.Expression{clause.And(where.Exprs...)}
@ -89,19 +78,18 @@ func (sd SoftDeleteQueryClause) ModifyStatement(stmt *Statement) {
} }
stmt.AddClause(clause.Where{Exprs: []clause.Expression{ stmt.AddClause(clause.Where{Exprs: []clause.Expression{
clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: sd.Field.DBName}, Value: sd.ZeroValue}, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: sd.Field.DBName}, Value: nil},
}}) }})
stmt.Clauses["soft_delete_enabled"] = clause.Clause{} stmt.Clauses["soft_delete_enabled"] = clause.Clause{}
} }
} }
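
ModifyStatement above injects the soft-delete filter at most once, keyed by the "soft_delete_enabled" clause, and on the master side skips it entirely for Unscoped statements. A short usage sketch of the resulting behaviour; db is assumed to be an initialised *gorm.DB and User an assumed model with a gorm.DeletedAt field and an age column.

package app // illustrative

import "gorm.io/gorm"

// User is an assumed model; DeletedAt is what enables the soft-delete clauses.
type User struct {
	ID        uint
	Age       int
	DeletedAt gorm.DeletedAt
}

// Adds WHERE users.deleted_at IS NULL (or the tag's zero value) exactly once.
func activeAdults(db *gorm.DB) ([]User, error) {
	var users []User
	err := db.Where("age > ?", 18).Find(&users).Error
	return users, err
}

// Unscoped marks the statement so SoftDeleteQueryClause leaves it alone and
// soft-deleted rows come back too.
func adultsIncludingDeleted(db *gorm.DB) ([]User, error) {
	var users []User
	err := db.Unscoped().Where("age > ?", 18).Find(&users).Error
	return users, err
}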
func (DeletedAt) UpdateClauses(f *schema.Field) []clause.Interface { func (DeletedAt) UpdateClauses(f *schema.Field) []clause.Interface {
return []clause.Interface{SoftDeleteUpdateClause{Field: f, ZeroValue: parseZeroValueTag(f)}} return []clause.Interface{SoftDeleteUpdateClause{Field: f}}
} }
type SoftDeleteUpdateClause struct { type SoftDeleteUpdateClause struct {
ZeroValue sql.NullString Field *schema.Field
Field *schema.Field
} }
func (sd SoftDeleteUpdateClause) Name() string { func (sd SoftDeleteUpdateClause) Name() string {
@ -115,18 +103,19 @@ func (sd SoftDeleteUpdateClause) MergeClause(*clause.Clause) {
} }
func (sd SoftDeleteUpdateClause) ModifyStatement(stmt *Statement) { func (sd SoftDeleteUpdateClause) ModifyStatement(stmt *Statement) {
if stmt.SQL.Len() == 0 && !stmt.Statement.Unscoped { if stmt.SQL.String() == "" {
SoftDeleteQueryClause(sd).ModifyStatement(stmt) if _, ok := stmt.Clauses["WHERE"]; stmt.DB.AllowGlobalUpdate || ok {
SoftDeleteQueryClause(sd).ModifyStatement(stmt)
}
} }
} }
func (DeletedAt) DeleteClauses(f *schema.Field) []clause.Interface { func (DeletedAt) DeleteClauses(f *schema.Field) []clause.Interface {
return []clause.Interface{SoftDeleteDeleteClause{Field: f, ZeroValue: parseZeroValueTag(f)}} return []clause.Interface{SoftDeleteDeleteClause{Field: f}}
} }
type SoftDeleteDeleteClause struct { type SoftDeleteDeleteClause struct {
ZeroValue sql.NullString Field *schema.Field
Field *schema.Field
} }
func (sd SoftDeleteDeleteClause) Name() string { func (sd SoftDeleteDeleteClause) Name() string {
@ -140,13 +129,13 @@ func (sd SoftDeleteDeleteClause) MergeClause(*clause.Clause) {
} }
func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) { func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) {
if stmt.SQL.Len() == 0 && !stmt.Statement.Unscoped { if stmt.SQL.String() == "" {
curTime := stmt.DB.NowFunc() curTime := stmt.DB.NowFunc()
stmt.AddClause(clause.Set{{Column: clause.Column{Name: sd.Field.DBName}, Value: curTime}}) stmt.AddClause(clause.Set{{Column: clause.Column{Name: sd.Field.DBName}, Value: curTime}})
stmt.SetColumn(sd.Field.DBName, curTime, true) stmt.SetColumn(sd.Field.DBName, curTime, true)
if stmt.Schema != nil { if stmt.Schema != nil {
_, queryValues := schema.GetIdentityFieldValuesMap(stmt.Context, stmt.ReflectValue, stmt.Schema.PrimaryFields) _, queryValues := schema.GetIdentityFieldValuesMap(stmt.ReflectValue, stmt.Schema.PrimaryFields)
column, values := schema.ToQueryValues(stmt.Table, stmt.Schema.PrimaryFieldDBNames, queryValues) column, values := schema.ToQueryValues(stmt.Table, stmt.Schema.PrimaryFieldDBNames, queryValues)
if len(values) > 0 { if len(values) > 0 {
@ -154,7 +143,7 @@ func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) {
} }
if stmt.ReflectValue.CanAddr() && stmt.Dest != stmt.Model && stmt.Model != nil { if stmt.ReflectValue.CanAddr() && stmt.Dest != stmt.Model && stmt.Model != nil {
_, queryValues = schema.GetIdentityFieldValuesMap(stmt.Context, reflect.ValueOf(stmt.Model), stmt.Schema.PrimaryFields) _, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(stmt.Model), stmt.Schema.PrimaryFields)
column, values = schema.ToQueryValues(stmt.Table, stmt.Schema.PrimaryFieldDBNames, queryValues) column, values = schema.ToQueryValues(stmt.Table, stmt.Schema.PrimaryFieldDBNames, queryValues)
if len(values) > 0 { if len(values) > 0 {
@ -163,8 +152,13 @@ func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) {
} }
} }
SoftDeleteQueryClause(sd).ModifyStatement(stmt) if _, ok := stmt.Clauses["WHERE"]; !stmt.DB.AllowGlobalUpdate && !ok {
stmt.DB.AddError(ErrMissingWhereClause)
} else {
SoftDeleteQueryClause(sd).ModifyStatement(stmt)
}
stmt.AddClauseIfNotExists(clause.Update{}) stmt.AddClauseIfNotExists(clause.Update{})
stmt.Build(stmt.DB.Callback().Update().Clauses...) stmt.Build("UPDATE", "SET", "WHERE")
} }
} }

View File

@ -30,9 +30,8 @@ type Statement struct {
Clauses map[string]clause.Clause Clauses map[string]clause.Clause
BuildClauses []string BuildClauses []string
Distinct bool Distinct bool
Selects []string // selected columns Selects []string // selected columns
Omits []string // omit columns Omits []string // omit columns
ColumnMapping map[string]string // map columns
Joins []join Joins []join
Preloads map[string][]interface{} Preloads map[string][]interface{}
Settings sync.Map Settings sync.Map
@ -47,18 +46,12 @@ type Statement struct {
attrs []interface{} attrs []interface{}
assigns []interface{} assigns []interface{}
scopes []func(*DB) *DB scopes []func(*DB) *DB
Result *result
} }
type join struct { type join struct {
Name string Name string
Alias string Conds []interface{}
Conds []interface{} On *clause.Where
On *clause.Where
Selects []string
Omits []string
Expression clause.Expression
JoinType clause.JoinType
} }
// StatementModifier statement modifier interface // StatementModifier statement modifier interface
@ -124,8 +117,6 @@ func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) {
write(v.Raw, stmt.Schema.PrioritizedPrimaryField.DBName) write(v.Raw, stmt.Schema.PrioritizedPrimaryField.DBName)
} else if len(stmt.Schema.DBNames) > 0 { } else if len(stmt.Schema.DBNames) > 0 {
write(v.Raw, stmt.Schema.DBNames[0]) write(v.Raw, stmt.Schema.DBNames[0])
} else {
stmt.DB.AddError(ErrModelAccessibleFieldsRequired) //nolint:typecheck,errcheck
} }
} else { } else {
write(v.Raw, v.Name) write(v.Raw, v.Name)
@ -139,7 +130,7 @@ func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) {
writer.WriteByte('(') writer.WriteByte('(')
for idx, d := range v { for idx, d := range v {
if idx > 0 { if idx > 0 {
writer.WriteByte(',') writer.WriteString(",")
} }
stmt.QuoteTo(writer, d) stmt.QuoteTo(writer, d)
} }
@ -152,7 +143,7 @@ func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) {
writer.WriteByte('(') writer.WriteByte('(')
for idx, d := range v { for idx, d := range v {
if idx > 0 { if idx > 0 {
writer.WriteByte(',') writer.WriteString(",")
} }
stmt.DB.Dialector.QuoteTo(writer, d) stmt.DB.Dialector.QuoteTo(writer, d)
} }
@ -182,17 +173,10 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) {
case clause.Column, clause.Table: case clause.Column, clause.Table:
stmt.QuoteTo(writer, v) stmt.QuoteTo(writer, v)
case Valuer: case Valuer:
reflectValue := reflect.ValueOf(v) stmt.AddVar(writer, v.GormValue(stmt.Context, stmt.DB))
if reflectValue.Kind() == reflect.Ptr && reflectValue.IsNil() { case clause.Expr:
stmt.AddVar(writer, nil) v.Build(stmt)
} else { case *clause.Expr:
stmt.AddVar(writer, v.GormValue(stmt.Context, stmt.DB))
}
case clause.Interface:
c := clause.Clause{Name: v.Name()}
v.MergeClause(&c)
c.Build(stmt)
case clause.Expression:
v.Build(stmt) v.Build(stmt)
case driver.Valuer: case driver.Valuer:
stmt.Vars = append(stmt.Vars, v) stmt.Vars = append(stmt.Vars, v)
@ -208,21 +192,19 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) {
} else { } else {
writer.WriteString("(NULL)") writer.WriteString("(NULL)")
} }
case interface{ getInstance() *DB }: case *DB:
cv := v.getInstance() subdb := v.Session(&Session{Logger: logger.Discard, DryRun: true}).getInstance()
if v.Statement.SQL.Len() > 0 {
subdb := cv.Session(&Session{Logger: logger.Discard, DryRun: true}).getInstance()
if cv.Statement.SQL.Len() > 0 {
var ( var (
vars = subdb.Statement.Vars vars = subdb.Statement.Vars
sql = cv.Statement.SQL.String() sql = v.Statement.SQL.String()
) )
subdb.Statement.Vars = make([]interface{}, 0, len(vars)) subdb.Statement.Vars = make([]interface{}, 0, len(vars))
for _, vv := range vars { for _, vv := range vars {
subdb.Statement.Vars = append(subdb.Statement.Vars, vv) subdb.Statement.Vars = append(subdb.Statement.Vars, vv)
bindvar := strings.Builder{} bindvar := strings.Builder{}
cv.BindVarTo(&bindvar, subdb.Statement, vv) v.Dialector.BindVarTo(&bindvar, subdb.Statement, vv)
sql = strings.Replace(sql, bindvar.String(), "?", 1) sql = strings.Replace(sql, bindvar.String(), "?", 1)
} }
@ -245,9 +227,6 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) {
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
if rv.Len() == 0 { if rv.Len() == 0 {
writer.WriteString("(NULL)") writer.WriteString("(NULL)")
} else if rv.Type().Elem() == reflect.TypeOf(uint8(0)) {
stmt.Vars = append(stmt.Vars, v)
stmt.DB.Dialector.BindVarTo(writer, stmt, v)
} else { } else {
writer.WriteByte('(') writer.WriteByte('(')
for i := 0; i < rv.Len(); i++ { for i := 0; i < rv.Len(); i++ {
@ -305,11 +284,6 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
return []clause.Expression{clause.NamedExpr{SQL: s, Vars: args}} return []clause.Expression{clause.NamedExpr{SQL: s, Vars: args}}
} }
if strings.Contains(strings.TrimSpace(s), " ") {
// looks like a where condition
return []clause.Expression{clause.Expr{SQL: s, Vars: args}}
}
if len(args) == 1 { if len(args) == 1 {
return []clause.Expression{clause.Eq{Column: s, Value: args[0]}} return []clause.Expression{clause.Eq{Column: s, Value: args[0]}}
} }
@ -319,31 +293,19 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
conds := make([]clause.Expression, 0, 4) conds := make([]clause.Expression, 0, 4)
args = append([]interface{}{query}, args...) args = append([]interface{}{query}, args...)
for idx, arg := range args { for idx, arg := range args {
if arg == nil {
continue
}
if valuer, ok := arg.(driver.Valuer); ok { if valuer, ok := arg.(driver.Valuer); ok {
arg, _ = valuer.Value() arg, _ = valuer.Value()
} }
curTable := stmt.Table
if curTable == "" {
curTable = clause.CurrentTable
}
switch v := arg.(type) { switch v := arg.(type) {
case clause.Expression: case clause.Expression:
conds = append(conds, v) conds = append(conds, v)
case *DB: case *DB:
v.executeScopes()
if cs, ok := v.Statement.Clauses["WHERE"]; ok { if cs, ok := v.Statement.Clauses["WHERE"]; ok {
if where, ok := cs.Expression.(clause.Where); ok { if where, ok := cs.Expression.(clause.Where); ok {
if len(where.Exprs) == 1 { if len(where.Exprs) == 1 {
if orConds, ok := where.Exprs[0].(clause.OrConditions); ok { if orConds, ok := where.Exprs[0].(clause.OrConditions); ok {
if len(orConds.Exprs) == 1 { where.Exprs[0] = clause.AndConditions(orConds)
where.Exprs[0] = clause.AndConditions(orConds)
}
} }
} }
conds = append(conds, clause.And(where.Exprs...)) conds = append(conds, clause.And(where.Exprs...))
@ -356,21 +318,17 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
conds = append(conds, clause.Eq{Column: i, Value: j}) conds = append(conds, clause.Eq{Column: i, Value: j})
} }
case map[string]string: case map[string]string:
keys := make([]string, 0, len(v)) var keys = make([]string, 0, len(v))
for i := range v { for i := range v {
keys = append(keys, i) keys = append(keys, i)
} }
sort.Strings(keys) sort.Strings(keys)
for _, key := range keys { for _, key := range keys {
column := clause.Column{Name: key, Table: curTable} conds = append(conds, clause.Eq{Column: key, Value: v[key]})
if strings.Contains(key, ".") {
column = clause.Column{Name: key}
}
conds = append(conds, clause.Eq{Column: column, Value: v[key]})
} }
case map[string]interface{}: case map[string]interface{}:
keys := make([]string, 0, len(v)) var keys = make([]string, 0, len(v))
for i := range v { for i := range v {
keys = append(keys, i) keys = append(keys, i)
} }
@ -378,16 +336,12 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
for _, key := range keys { for _, key := range keys {
reflectValue := reflect.Indirect(reflect.ValueOf(v[key])) reflectValue := reflect.Indirect(reflect.ValueOf(v[key]))
column := clause.Column{Name: key, Table: curTable}
if strings.Contains(key, ".") {
column = clause.Column{Name: key}
}
switch reflectValue.Kind() { switch reflectValue.Kind() {
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
if _, ok := v[key].(driver.Valuer); ok { if _, ok := v[key].(driver.Valuer); ok {
conds = append(conds, clause.Eq{Column: column, Value: v[key]}) conds = append(conds, clause.Eq{Column: key, Value: v[key]})
} else if _, ok := v[key].(Valuer); ok { } else if _, ok := v[key].(Valuer); ok {
conds = append(conds, clause.Eq{Column: column, Value: v[key]}) conds = append(conds, clause.Eq{Column: key, Value: v[key]})
} else { } else {
// optimize reflect value length // optimize reflect value length
valueLen := reflectValue.Len() valueLen := reflectValue.Len()
@ -396,10 +350,10 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
values[i] = reflectValue.Index(i).Interface() values[i] = reflectValue.Index(i).Interface()
} }
conds = append(conds, clause.IN{Column: column, Values: values}) conds = append(conds, clause.IN{Column: key, Values: values})
} }
default: default:
conds = append(conds, clause.Eq{Column: column, Value: v[key]}) conds = append(conds, clause.Eq{Column: key, Value: v[key]})
} }
} }
default: default:
@ -424,11 +378,11 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
for _, field := range s.Fields { for _, field := range s.Fields {
selected := selectedColumns[field.DBName] || selectedColumns[field.Name] selected := selectedColumns[field.DBName] || selectedColumns[field.Name]
if selected || (!restricted && field.Readable) { if selected || (!restricted && field.Readable) {
if v, isZero := field.ValueOf(stmt.Context, reflectValue); !isZero || selected { if v, isZero := field.ValueOf(reflectValue); !isZero || selected {
if field.DBName != "" { if field.DBName != "" {
conds = append(conds, clause.Eq{Column: clause.Column{Table: curTable, Name: field.DBName}, Value: v}) conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v})
} else if field.DataType != "" { } else if field.DataType != "" {
conds = append(conds, clause.Eq{Column: clause.Column{Table: curTable, Name: field.Name}, Value: v}) conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v})
} }
} }
} }
@ -438,11 +392,11 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
for _, field := range s.Fields { for _, field := range s.Fields {
selected := selectedColumns[field.DBName] || selectedColumns[field.Name] selected := selectedColumns[field.DBName] || selectedColumns[field.Name]
if selected || (!restricted && field.Readable) { if selected || (!restricted && field.Readable) {
if v, isZero := field.ValueOf(stmt.Context, reflectValue.Index(i)); !isZero || selected { if v, isZero := field.ValueOf(reflectValue.Index(i)); !isZero || selected {
if field.DBName != "" { if field.DBName != "" {
conds = append(conds, clause.Eq{Column: clause.Column{Table: curTable, Name: field.DBName}, Value: v}) conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v})
} else if field.DataType != "" { } else if field.DataType != "" {
conds = append(conds, clause.Eq{Column: clause.Column{Table: curTable, Name: field.Name}, Value: v}) conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v})
} }
} }
} }
@ -467,22 +421,18 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
} }
if len(values) > 0 { if len(values) > 0 {
conds = append(conds, clause.IN{Column: clause.Column{Table: curTable, Name: clause.PrimaryKey}, Values: values}) conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: values})
return []clause.Expression{clause.And(conds...)}
} }
return nil return conds
} }
} }
conds = append(conds, clause.IN{Column: clause.Column{Table: curTable, Name: clause.PrimaryKey}, Values: args}) conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: args})
} }
} }
} }
if len(conds) > 0 { return conds
return []clause.Expression{clause.And(conds...)}
}
return nil
} }
// Build build sql with clauses names // Build build sql with clauses names
@ -506,11 +456,7 @@ func (stmt *Statement) Build(clauses ...string) {
} }
func (stmt *Statement) Parse(value interface{}) (err error) { func (stmt *Statement) Parse(value interface{}) (err error) {
return stmt.ParseWithSpecialTableName(value, "") if stmt.Schema, err = schema.Parse(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil && stmt.Table == "" {
}
func (stmt *Statement) ParseWithSpecialTableName(value interface{}, specialTableName string) (err error) {
if stmt.Schema, err = schema.ParseWithSpecialTableName(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy, specialTableName); err == nil && stmt.Table == "" {
if tables := strings.Split(stmt.Schema.Table, "."); len(tables) == 2 { if tables := strings.Split(stmt.Schema.Table, "."); len(tables) == 2 {
stmt.TableExpr = &clause.Expr{SQL: stmt.Quote(stmt.Schema.Table)} stmt.TableExpr = &clause.Expr{SQL: stmt.Quote(stmt.Schema.Table)}
stmt.Table = tables[1] stmt.Table = tables[1]
@ -534,14 +480,12 @@ func (stmt *Statement) clone() *Statement {
Distinct: stmt.Distinct, Distinct: stmt.Distinct,
Selects: stmt.Selects, Selects: stmt.Selects,
Omits: stmt.Omits, Omits: stmt.Omits,
ColumnMapping: stmt.ColumnMapping,
Preloads: map[string][]interface{}{}, Preloads: map[string][]interface{}{},
ConnPool: stmt.ConnPool, ConnPool: stmt.ConnPool,
Schema: stmt.Schema, Schema: stmt.Schema,
Context: stmt.Context, Context: stmt.Context,
RaiseErrorOnNotFound: stmt.RaiseErrorOnNotFound, RaiseErrorOnNotFound: stmt.RaiseErrorOnNotFound,
SkipHooks: stmt.SkipHooks, SkipHooks: stmt.SkipHooks,
Result: stmt.Result,
} }
if stmt.SQL.Len() > 0 { if stmt.SQL.Len() > 0 {
@ -577,9 +521,8 @@ func (stmt *Statement) clone() *Statement {
} }
// SetColumn set column's value // SetColumn set column's value
// // stmt.SetColumn("Name", "jinzhu") // Hooks Method
// stmt.SetColumn("Name", "jinzhu") // Hooks Method // stmt.SetColumn("Name", "jinzhu", true) // Callbacks Method
// stmt.SetColumn("Name", "jinzhu", true) // Callbacks Method
func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks ...bool) { func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks ...bool) {
if v, ok := stmt.Dest.(map[string]interface{}); ok { if v, ok := stmt.Dest.(map[string]interface{}); ok {
v[name] = value v[name] = value
@ -604,7 +547,7 @@ func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks .
switch destValue.Kind() { switch destValue.Kind() {
case reflect.Struct: case reflect.Struct:
stmt.AddError(field.Set(stmt.Context, destValue, value)) field.Set(destValue, value)
default: default:
stmt.AddError(ErrInvalidData) stmt.AddError(ErrInvalidData)
} }
@ -614,10 +557,10 @@ func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks .
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
if len(fromCallbacks) > 0 { if len(fromCallbacks) > 0 {
for i := 0; i < stmt.ReflectValue.Len(); i++ { for i := 0; i < stmt.ReflectValue.Len(); i++ {
stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue.Index(i), value)) field.Set(stmt.ReflectValue.Index(i), value)
} }
} else { } else {
stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue.Index(stmt.CurDestIndex), value)) field.Set(stmt.ReflectValue.Index(stmt.CurDestIndex), value)
} }
case reflect.Struct: case reflect.Struct:
if !stmt.ReflectValue.CanAddr() { if !stmt.ReflectValue.CanAddr() {
@ -625,7 +568,7 @@ func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks .
return return
} }
stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, value)) field.Set(stmt.ReflectValue, value)
} }
} else { } else {
stmt.AddError(ErrInvalidField) stmt.AddError(ErrInvalidField)
@ -645,12 +588,12 @@ func (stmt *Statement) Changed(fields ...string) bool {
selectColumns, restricted := stmt.SelectAndOmitColumns(false, true) selectColumns, restricted := stmt.SelectAndOmitColumns(false, true)
changed := func(field *schema.Field) bool { changed := func(field *schema.Field) bool {
fieldValue, _ := field.ValueOf(stmt.Context, modelValue) fieldValue, _ := field.ValueOf(modelValue)
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
if mv, mok := stmt.Dest.(map[string]interface{}); mok { if v, ok := stmt.Dest.(map[string]interface{}); ok {
if fv, ok := mv[field.Name]; ok { if fv, ok := v[field.Name]; ok {
return !utils.AssertEqual(fv, fieldValue) return !utils.AssertEqual(fv, fieldValue)
} else if fv, ok := mv[field.DBName]; ok { } else if fv, ok := v[field.DBName]; ok {
return !utils.AssertEqual(fv, fieldValue) return !utils.AssertEqual(fv, fieldValue)
} }
} else { } else {
@ -659,10 +602,7 @@ func (stmt *Statement) Changed(fields ...string) bool {
destValue = destValue.Elem() destValue = destValue.Elem()
} }
changedValue, zero := field.ValueOf(stmt.Context, destValue) changedValue, zero := field.ValueOf(destValue)
if v {
return !utils.AssertEqual(changedValue, fieldValue)
}
return !zero && !utils.AssertEqual(changedValue, fieldValue) return !zero && !utils.AssertEqual(changedValue, fieldValue)
} }
} }
@ -688,62 +628,50 @@ func (stmt *Statement) Changed(fields ...string) bool {
return false return false
} }
var matchName = func() func(tableColumn string) (table, column string) { var nameMatcher = regexp.MustCompile(`^[\W]?(?:[a-z_]+?)[\W]?\.[\W]?([a-z_]+?)[\W]?$`)
nameMatcher := regexp.MustCompile(`^(?:\W?(\w+?)\W?\.)?(?:(\*)|\W?(\w+?)\W?)$`)
return func(tableColumn string) (table, column string) {
if matches := nameMatcher.FindStringSubmatch(tableColumn); len(matches) == 4 {
table = matches[1]
star := matches[2]
columnName := matches[3]
if star != "" {
return table, star
}
return table, columnName
}
return "", ""
}
}()
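The matchName helper introduced on the master side generalises the old package-level nameMatcher regexp from v1.21.16: it also captures the table part and accepts a bare * as the column. A standalone re-implementation for illustration only, using the same pattern (the real helper is unexported inside the gorm package):

```go
package main

import (
	"fmt"
	"regexp"
)

// matchTableColumn mirrors the matchName helper above: it splits an optionally
// quoted "table.column" reference and treats a trailing "*" as a wildcard.
func matchTableColumn(tableColumn string) (table, column string) {
	nameMatcher := regexp.MustCompile(`^(?:\W?(\w+?)\W?\.)?(?:(\*)|\W?(\w+?)\W?)$`)
	if m := nameMatcher.FindStringSubmatch(tableColumn); len(m) == 4 {
		if m[2] != "" {
			return m[1], m[2] // table.*
		}
		return m[1], m[3]
	}
	return "", ""
}

func main() {
	for _, s := range []string{"users.name", "`users`.`name`", "`my_table`.*", "name"} {
		t, c := matchTableColumn(s)
		fmt.Printf("%-18s -> table=%q column=%q\n", s, t, c)
	}
}
```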
// SelectAndOmitColumns get select and omit columns, select -> true, omit -> false // SelectAndOmitColumns get select and omit columns, select -> true, omit -> false
func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (map[string]bool, bool) { func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (map[string]bool, bool) {
results := map[string]bool{} results := map[string]bool{}
notRestricted := false notRestricted := false
processColumn := func(column string, result bool) { // select columns
for _, column := range stmt.Selects {
if stmt.Schema == nil { if stmt.Schema == nil {
results[column] = result results[column] = true
} else if column == "*" { } else if column == "*" {
notRestricted = result notRestricted = true
for _, dbName := range stmt.Schema.DBNames { for _, dbName := range stmt.Schema.DBNames {
results[dbName] = result results[dbName] = true
} }
} else if column == clause.Associations { } else if column == clause.Associations {
for _, rel := range stmt.Schema.Relationships.Relations { for _, rel := range stmt.Schema.Relationships.Relations {
results[rel.Name] = result results[rel.Name] = true
} }
} else if field := stmt.Schema.LookUpField(column); field != nil && field.DBName != "" { } else if field := stmt.Schema.LookUpField(column); field != nil && field.DBName != "" {
results[field.DBName] = result results[field.DBName] = true
} else if table, col := matchName(column); col != "" && (table == stmt.Table || table == "") { } else if matches := nameMatcher.FindStringSubmatch(column); len(matches) == 2 {
if col == "*" { results[matches[1]] = true
for _, dbName := range stmt.Schema.DBNames {
results[dbName] = result
}
} else {
results[col] = result
}
} else { } else {
results[column] = result results[column] = true
} }
} }
// select columns
for _, column := range stmt.Selects {
processColumn(column, true)
}
// omit columns // omit columns
for _, column := range stmt.Omits { for _, omit := range stmt.Omits {
processColumn(column, false) if stmt.Schema == nil {
results[omit] = false
} else if omit == clause.Associations {
for _, rel := range stmt.Schema.Relationships.Relations {
results[rel.Name] = false
}
} else if field := stmt.Schema.LookUpField(omit); field != nil && field.DBName != "" {
results[field.DBName] = false
} else if matches := nameMatcher.FindStringSubmatch(omit); len(matches) == 2 {
results[matches[1]] = false
} else {
results[omit] = false
}
} }
if stmt.Schema != nil { if stmt.Schema != nil {
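The SelectAndOmitColumns refactor on the master side folds the separate select and omit loops into one processColumn closure and, via matchName, also resolves table-qualified names such as my_table.* for both Select and Omit. A hedged sketch of the user-facing effect when creating records (the Account model and the sqlite driver are illustrative assumptions):

```go
package main

import (
	"fmt"

	"gorm.io/driver/sqlite"
	"gorm.io/gorm"
)

type Account struct {
	ID     uint
	Name   string
	Secret string
}

func main() {
	db, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{})
	if err != nil {
		panic(err)
	}
	if err := db.AutoMigrate(&Account{}); err != nil {
		panic(err)
	}

	// SelectAndOmitColumns decides which columns make it into the INSERT:
	// "Secret" is resolved to its DB name and marked false, so it is skipped.
	acc := Account{Name: "a1", Secret: "will-not-be-saved"}
	db.Omit("Secret").Create(&acc)

	var got Account
	db.First(&got, acc.ID)
	fmt.Printf("name=%q secret=%q\n", got.Name, got.Secret) // secret stays empty
}
```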


@ -35,36 +35,15 @@ func TestWhereCloneCorruption(t *testing.T) {
} }
} }
func TestNilCondition(t *testing.T) {
s := new(Statement)
if len(s.BuildCondition(nil)) != 0 {
t.Errorf("Nil condition should be empty")
}
}
func TestNameMatcher(t *testing.T) { func TestNameMatcher(t *testing.T) {
for k, v := range map[string][]string{ for k, v := range map[string]string{
"table.name": {"table", "name"}, "table.name": "name",
"`table`.`name`": {"table", "name"}, "`table`.`name`": "name",
"'table'.'name'": {"table", "name"}, "'table'.'name'": "name",
"'table'.name": {"table", "name"}, "'table'.name": "name",
"table1.name_23": {"table1", "name_23"},
"`table_1`.`name23`": {"table_1", "name23"},
"'table23'.'name_1'": {"table23", "name_1"},
"'table23'.name1": {"table23", "name1"},
"'name1'": {"", "name1"},
"`name_1`": {"", "name_1"},
"`Name_1`": {"", "Name_1"},
"`Table`.`nAme`": {"Table", "nAme"},
"my_table.*": {"my_table", "*"},
"`my_table`.*": {"my_table", "*"},
"User__Company.*": {"User__Company", "*"},
"`User__Company`.*": {"User__Company", "*"},
`"User__Company".*`: {"User__Company", "*"},
`"table"."*"`: {"", ""},
} { } {
if table, column := matchName(k); table != v[0] || column != v[1] { if matches := nameMatcher.FindStringSubmatch(k); len(matches) < 2 || matches[1] != v {
t.Errorf("failed to match value: %v, got %v, expect: %v", k, []string{table, column}, v) t.Errorf("failed to match value: %v, got %v, expect: %v", k, matches, v)
} }
} }
} }


@ -3,12 +3,11 @@ package tests_test
import ( import (
"testing" "testing"
"gorm.io/gorm"
. "gorm.io/gorm/utils/tests" . "gorm.io/gorm/utils/tests"
) )
func TestBelongsToAssociation(t *testing.T) { func TestBelongsToAssociation(t *testing.T) {
user := *GetUser("belongs-to", Config{Company: true, Manager: true}) var user = *GetUser("belongs-to", Config{Company: true, Manager: true})
if err := DB.Create(&user).Error; err != nil { if err := DB.Create(&user).Error; err != nil {
t.Fatalf("errors happened when create: %v", err) t.Fatalf("errors happened when create: %v", err)
@ -32,8 +31,8 @@ func TestBelongsToAssociation(t *testing.T) {
AssertAssociationCount(t, user, "Manager", 1, "") AssertAssociationCount(t, user, "Manager", 1, "")
// Append // Append
company := Company{Name: "company-belongs-to-append"} var company = Company{Name: "company-belongs-to-append"}
manager := GetUser("manager-belongs-to-append", Config{}) var manager = GetUser("manager-belongs-to-append", Config{})
if err := DB.Model(&user2).Association("Company").Append(&company); err != nil { if err := DB.Model(&user2).Association("Company").Append(&company); err != nil {
t.Fatalf("Error happened when append Company, got %v", err) t.Fatalf("Error happened when append Company, got %v", err)
@ -61,8 +60,8 @@ func TestBelongsToAssociation(t *testing.T) {
AssertAssociationCount(t, user2, "Manager", 1, "AfterAppend") AssertAssociationCount(t, user2, "Manager", 1, "AfterAppend")
// Replace // Replace
company2 := Company{Name: "company-belongs-to-replace"} var company2 = Company{Name: "company-belongs-to-replace"}
manager2 := GetUser("manager-belongs-to-replace", Config{}) var manager2 = GetUser("manager-belongs-to-replace", Config{})
if err := DB.Model(&user2).Association("Company").Replace(&company2); err != nil { if err := DB.Model(&user2).Association("Company").Replace(&company2); err != nil {
t.Fatalf("Error happened when replace Company, got %v", err) t.Fatalf("Error happened when replace Company, got %v", err)
@ -133,18 +132,10 @@ func TestBelongsToAssociation(t *testing.T) {
AssertAssociationCount(t, user2, "Company", 0, "after clear") AssertAssociationCount(t, user2, "Company", 0, "after clear")
AssertAssociationCount(t, user2, "Manager", 0, "after clear") AssertAssociationCount(t, user2, "Manager", 0, "after clear")
// unexist company id
unexistCompanyID := company.ID + 9999999
user = User{Name: "invalid-user-with-invalid-belongs-to-foreign-key", CompanyID: &unexistCompanyID}
if err := DB.Create(&user).Error; err == nil {
tidbSkip(t, "not support the foreign key feature")
t.Errorf("should have gotten foreign key violation error")
}
} }
func TestBelongsToAssociationForSlice(t *testing.T) { func TestBelongsToAssociationForSlice(t *testing.T) {
users := []User{ var users = []User{
*GetUser("slice-belongs-to-1", Config{Company: true, Manager: true}), *GetUser("slice-belongs-to-1", Config{Company: true, Manager: true}),
*GetUser("slice-belongs-to-2", Config{Company: true, Manager: false}), *GetUser("slice-belongs-to-2", Config{Company: true, Manager: false}),
*GetUser("slice-belongs-to-3", Config{Company: true, Manager: true}), *GetUser("slice-belongs-to-3", Config{Company: true, Manager: true}),
@ -226,81 +217,3 @@ func TestBelongsToAssociationForSlice(t *testing.T) {
AssertAssociationCount(t, users[0], "Company", 0, "After Delete") AssertAssociationCount(t, users[0], "Company", 0, "After Delete")
AssertAssociationCount(t, users[1], "Company", 1, "After other user Delete") AssertAssociationCount(t, users[1], "Company", 1, "After other user Delete")
} }
func TestBelongsToDefaultValue(t *testing.T) {
type Org struct {
ID string
}
type BelongsToUser struct {
OrgID string
Org Org `gorm:"default:NULL"`
}
tx := DB.Session(&gorm.Session{})
tx.Config.DisableForeignKeyConstraintWhenMigrating = true
AssertEqual(t, DB.Config.DisableForeignKeyConstraintWhenMigrating, false)
tx.Migrator().DropTable(&BelongsToUser{}, &Org{})
tx.AutoMigrate(&BelongsToUser{}, &Org{})
user := &BelongsToUser{
Org: Org{
ID: "BelongsToUser_Org_1",
},
}
err := DB.Create(&user).Error
AssertEqual(t, err, nil)
}
func TestBelongsToAssociationUnscoped(t *testing.T) {
type ItemParent struct {
gorm.Model
Logo string `gorm:"not null;type:varchar(50)"`
}
type ItemChild struct {
gorm.Model
Name string `gorm:"type:varchar(50)"`
ItemParentID uint
ItemParent ItemParent
}
tx := DB.Session(&gorm.Session{})
tx.Migrator().DropTable(&ItemParent{}, &ItemChild{})
tx.AutoMigrate(&ItemParent{}, &ItemChild{})
item := ItemChild{
Name: "name",
ItemParent: ItemParent{
Logo: "logo",
},
}
if err := tx.Create(&item).Error; err != nil {
t.Fatalf("failed to create items, got error: %v", err)
}
// test replace
if err := tx.Model(&item).Association("ItemParent").Unscoped().Replace(&ItemParent{
Logo: "updated logo",
}); err != nil {
t.Errorf("failed to replace item parent, got error: %v", err)
}
var parents []ItemParent
if err := tx.Find(&parents).Error; err != nil {
t.Errorf("failed to find item parent, got error: %v", err)
}
if len(parents) != 1 {
t.Errorf("expected %d parents, got %d", 1, len(parents))
}
// test delete
if err := tx.Model(&item).Association("ItemParent").Unscoped().Delete(&parents); err != nil {
t.Errorf("failed to delete item parent, got error: %v", err)
}
if err := tx.Find(&parents).Error; err != nil {
t.Errorf("failed to find item parent, got error: %v", err)
}
if len(parents) != 0 {
t.Errorf("expected %d parents, got %d", 0, len(parents))
}
}
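TestBelongsToAssociationUnscoped above (added on the master side) exercises Association(...).Unscoped(), which makes Replace and Delete remove the detached rows outright instead of only clearing the reference. A small sketch of the same call pattern with hypothetical Parent/Child models, assuming the behaviour the test asserts:

```go
package main

import (
	"fmt"

	"gorm.io/driver/sqlite"
	"gorm.io/gorm"
)

type Parent struct {
	gorm.Model
	Logo string
}

type Child struct {
	gorm.Model
	Name     string
	ParentID uint
	Parent   Parent
}

func main() {
	db, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{})
	if err != nil {
		panic(err)
	}
	if err := db.AutoMigrate(&Parent{}, &Child{}); err != nil {
		panic(err)
	}

	child := Child{Name: "c", Parent: Parent{Logo: "logo"}}
	db.Create(&child)

	// Unscoped() makes Replace delete the old parent row outright instead of
	// merely detaching it, mirroring the assertions in the test above.
	if err := db.Model(&child).Association("Parent").Unscoped().Replace(&Parent{Logo: "new logo"}); err != nil {
		panic(err)
	}

	var count int64
	db.Model(&Parent{}).Count(&count)
	fmt.Println("parents left:", count) // expected: 1
}
```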


@ -3,12 +3,11 @@ package tests_test
import ( import (
"testing" "testing"
"gorm.io/gorm"
. "gorm.io/gorm/utils/tests" . "gorm.io/gorm/utils/tests"
) )
func TestHasManyAssociation(t *testing.T) { func TestHasManyAssociation(t *testing.T) {
user := *GetUser("hasmany", Config{Pets: 2}) var user = *GetUser("hasmany", Config{Pets: 2})
if err := DB.Create(&user).Error; err != nil { if err := DB.Create(&user).Error; err != nil {
t.Fatalf("errors happened when create: %v", err) t.Fatalf("errors happened when create: %v", err)
@ -43,7 +42,7 @@ func TestHasManyAssociation(t *testing.T) {
AssertAssociationCount(t, user, "Pets", 2, "") AssertAssociationCount(t, user, "Pets", 2, "")
// Append // Append
pet := Pet{Name: "pet-has-many-append"} var pet = Pet{Name: "pet-has-many-append"}
if err := DB.Model(&user2).Association("Pets").Append(&pet); err != nil { if err := DB.Model(&user2).Association("Pets").Append(&pet); err != nil {
t.Fatalf("Error happened when append account, got %v", err) t.Fatalf("Error happened when append account, got %v", err)
@ -58,14 +57,14 @@ func TestHasManyAssociation(t *testing.T) {
AssertAssociationCount(t, user, "Pets", 3, "AfterAppend") AssertAssociationCount(t, user, "Pets", 3, "AfterAppend")
pets2 := []Pet{{Name: "pet-has-many-append-1-1"}, {Name: "pet-has-many-append-1-1"}} var pets2 = []Pet{{Name: "pet-has-many-append-1-1"}, {Name: "pet-has-many-append-1-1"}}
if err := DB.Model(&user2).Association("Pets").Append(&pets2); err != nil { if err := DB.Model(&user2).Association("Pets").Append(&pets2); err != nil {
t.Fatalf("Error happened when append pet, got %v", err) t.Fatalf("Error happened when append pet, got %v", err)
} }
for _, pet := range pets2 { for _, pet := range pets2 {
pet := pet var pet = pet
if pet.ID == 0 { if pet.ID == 0 {
t.Fatalf("Pet's ID should be created") t.Fatalf("Pet's ID should be created")
} }
@ -78,7 +77,7 @@ func TestHasManyAssociation(t *testing.T) {
AssertAssociationCount(t, user, "Pets", 5, "AfterAppendSlice") AssertAssociationCount(t, user, "Pets", 5, "AfterAppendSlice")
// Replace // Replace
pet2 := Pet{Name: "pet-has-many-replace"} var pet2 = Pet{Name: "pet-has-many-replace"}
if err := DB.Model(&user2).Association("Pets").Replace(&pet2); err != nil { if err := DB.Model(&user2).Association("Pets").Replace(&pet2); err != nil {
t.Fatalf("Error happened when append pet, got %v", err) t.Fatalf("Error happened when append pet, got %v", err)
@ -120,7 +119,7 @@ func TestHasManyAssociation(t *testing.T) {
} }
func TestSingleTableHasManyAssociation(t *testing.T) { func TestSingleTableHasManyAssociation(t *testing.T) {
user := *GetUser("hasmany", Config{Team: 2}) var user = *GetUser("hasmany", Config{Team: 2})
if err := DB.Create(&user).Error; err != nil { if err := DB.Create(&user).Error; err != nil {
t.Fatalf("errors happened when create: %v", err) t.Fatalf("errors happened when create: %v", err)
@ -138,7 +137,7 @@ func TestSingleTableHasManyAssociation(t *testing.T) {
AssertAssociationCount(t, user, "Team", 2, "") AssertAssociationCount(t, user, "Team", 2, "")
// Append // Append
team := *GetUser("team", Config{}) var team = *GetUser("team", Config{})
if err := DB.Model(&user2).Association("Team").Append(&team); err != nil { if err := DB.Model(&user2).Association("Team").Append(&team); err != nil {
t.Fatalf("Error happened when append account, got %v", err) t.Fatalf("Error happened when append account, got %v", err)
@ -153,14 +152,14 @@ func TestSingleTableHasManyAssociation(t *testing.T) {
AssertAssociationCount(t, user, "Team", 3, "AfterAppend") AssertAssociationCount(t, user, "Team", 3, "AfterAppend")
teams := []User{*GetUser("team-append-1", Config{}), *GetUser("team-append-2", Config{})} var teams = []User{*GetUser("team-append-1", Config{}), *GetUser("team-append-2", Config{})}
if err := DB.Model(&user2).Association("Team").Append(&teams); err != nil { if err := DB.Model(&user2).Association("Team").Append(&teams); err != nil {
t.Fatalf("Error happened when append team, got %v", err) t.Fatalf("Error happened when append team, got %v", err)
} }
for _, team := range teams { for _, team := range teams {
team := team var team = team
if team.ID == 0 { if team.ID == 0 {
t.Fatalf("Team's ID should be created") t.Fatalf("Team's ID should be created")
} }
@ -173,7 +172,7 @@ func TestSingleTableHasManyAssociation(t *testing.T) {
AssertAssociationCount(t, user, "Team", 5, "AfterAppendSlice") AssertAssociationCount(t, user, "Team", 5, "AfterAppendSlice")
// Replace // Replace
team2 := *GetUser("team-replace", Config{}) var team2 = *GetUser("team-replace", Config{})
if err := DB.Model(&user2).Association("Team").Replace(&team2); err != nil { if err := DB.Model(&user2).Association("Team").Replace(&team2); err != nil {
t.Fatalf("Error happened when append team, got %v", err) t.Fatalf("Error happened when append team, got %v", err)
@ -215,7 +214,7 @@ func TestSingleTableHasManyAssociation(t *testing.T) {
} }
func TestHasManyAssociationForSlice(t *testing.T) { func TestHasManyAssociationForSlice(t *testing.T) {
users := []User{ var users = []User{
*GetUser("slice-hasmany-1", Config{Pets: 2}), *GetUser("slice-hasmany-1", Config{Pets: 2}),
*GetUser("slice-hasmany-2", Config{Pets: 0}), *GetUser("slice-hasmany-2", Config{Pets: 0}),
*GetUser("slice-hasmany-3", Config{Pets: 4}), *GetUser("slice-hasmany-3", Config{Pets: 4}),
@ -269,7 +268,7 @@ func TestHasManyAssociationForSlice(t *testing.T) {
} }
func TestSingleTableHasManyAssociationForSlice(t *testing.T) { func TestSingleTableHasManyAssociationForSlice(t *testing.T) {
users := []User{ var users = []User{
*GetUser("slice-hasmany-1", Config{Team: 2}), *GetUser("slice-hasmany-1", Config{Team: 2}),
*GetUser("slice-hasmany-2", Config{Team: 0}), *GetUser("slice-hasmany-2", Config{Team: 0}),
*GetUser("slice-hasmany-3", Config{Team: 4}), *GetUser("slice-hasmany-3", Config{Team: 4}),
@ -325,7 +324,7 @@ func TestSingleTableHasManyAssociationForSlice(t *testing.T) {
} }
func TestPolymorphicHasManyAssociation(t *testing.T) { func TestPolymorphicHasManyAssociation(t *testing.T) {
user := *GetUser("hasmany", Config{Toys: 2}) var user = *GetUser("hasmany", Config{Toys: 2})
if err := DB.Create(&user).Error; err != nil { if err := DB.Create(&user).Error; err != nil {
t.Fatalf("errors happened when create: %v", err) t.Fatalf("errors happened when create: %v", err)
@ -343,7 +342,7 @@ func TestPolymorphicHasManyAssociation(t *testing.T) {
AssertAssociationCount(t, user, "Toys", 2, "") AssertAssociationCount(t, user, "Toys", 2, "")
// Append // Append
toy := Toy{Name: "toy-has-many-append"} var toy = Toy{Name: "toy-has-many-append"}
if err := DB.Model(&user2).Association("Toys").Append(&toy); err != nil { if err := DB.Model(&user2).Association("Toys").Append(&toy); err != nil {
t.Fatalf("Error happened when append account, got %v", err) t.Fatalf("Error happened when append account, got %v", err)
@ -358,14 +357,14 @@ func TestPolymorphicHasManyAssociation(t *testing.T) {
AssertAssociationCount(t, user, "Toys", 3, "AfterAppend") AssertAssociationCount(t, user, "Toys", 3, "AfterAppend")
toys := []Toy{{Name: "toy-has-many-append-1-1"}, {Name: "toy-has-many-append-1-1"}} var toys = []Toy{{Name: "toy-has-many-append-1-1"}, {Name: "toy-has-many-append-1-1"}}
if err := DB.Model(&user2).Association("Toys").Append(&toys); err != nil { if err := DB.Model(&user2).Association("Toys").Append(&toys); err != nil {
t.Fatalf("Error happened when append toy, got %v", err) t.Fatalf("Error happened when append toy, got %v", err)
} }
for _, toy := range toys { for _, toy := range toys {
toy := toy var toy = toy
if toy.ID == 0 { if toy.ID == 0 {
t.Fatalf("Toy's ID should be created") t.Fatalf("Toy's ID should be created")
} }
@ -378,7 +377,7 @@ func TestPolymorphicHasManyAssociation(t *testing.T) {
AssertAssociationCount(t, user, "Toys", 5, "AfterAppendSlice") AssertAssociationCount(t, user, "Toys", 5, "AfterAppendSlice")
// Replace // Replace
toy2 := Toy{Name: "toy-has-many-replace"} var toy2 = Toy{Name: "toy-has-many-replace"}
if err := DB.Model(&user2).Association("Toys").Replace(&toy2); err != nil { if err := DB.Model(&user2).Association("Toys").Replace(&toy2); err != nil {
t.Fatalf("Error happened when append toy, got %v", err) t.Fatalf("Error happened when append toy, got %v", err)
@ -420,9 +419,9 @@ func TestPolymorphicHasManyAssociation(t *testing.T) {
} }
func TestPolymorphicHasManyAssociationForSlice(t *testing.T) { func TestPolymorphicHasManyAssociationForSlice(t *testing.T) {
users := []User{ var users = []User{
*GetUser("slice-hasmany-1", Config{Toys: 2}), *GetUser("slice-hasmany-1", Config{Toys: 2}),
*GetUser("slice-hasmany-2", Config{Toys: 0, Tools: 2}), *GetUser("slice-hasmany-2", Config{Toys: 0}),
*GetUser("slice-hasmany-3", Config{Toys: 4}), *GetUser("slice-hasmany-3", Config{Toys: 4}),
} }
@ -430,7 +429,6 @@ func TestPolymorphicHasManyAssociationForSlice(t *testing.T) {
// Count // Count
AssertAssociationCount(t, users, "Toys", 6, "") AssertAssociationCount(t, users, "Toys", 6, "")
AssertAssociationCount(t, users, "Tools", 2, "")
// Find // Find
var toys []Toy var toys []Toy
@ -438,14 +436,6 @@ func TestPolymorphicHasManyAssociationForSlice(t *testing.T) {
t.Errorf("toys count should be %v, but got %v", 6, len(toys)) t.Errorf("toys count should be %v, but got %v", 6, len(toys))
} }
// Find Tools (polymorphic with custom type and id)
var tools []Tools
DB.Model(&users).Association("Tools").Find(&tools)
if len(tools) != 2 {
t.Errorf("tools count should be %v, but got %v", 2, len(tools))
}
// Append // Append
DB.Model(&users).Association("Toys").Append( DB.Model(&users).Association("Toys").Append(
&Toy{Name: "toy-slice-append-1"}, &Toy{Name: "toy-slice-append-1"},
@ -481,88 +471,3 @@ func TestPolymorphicHasManyAssociationForSlice(t *testing.T) {
DB.Model(&users).Association("Toys").Clear() DB.Model(&users).Association("Toys").Clear()
AssertAssociationCount(t, users, "Toys", 0, "After Clear") AssertAssociationCount(t, users, "Toys", 0, "After Clear")
} }
func TestHasManyAssociationUnscoped(t *testing.T) {
type ItemContent struct {
gorm.Model
ItemID uint `gorm:"not null"`
Name string `gorm:"not null;type:varchar(50)"`
LanguageCode string `gorm:"not null;type:varchar(2)"`
}
type Item struct {
gorm.Model
Logo string `gorm:"not null;type:varchar(50)"`
Contents []ItemContent `gorm:"foreignKey:ItemID"`
}
tx := DB.Session(&gorm.Session{})
tx.Migrator().DropTable(&ItemContent{}, &Item{})
tx.AutoMigrate(&ItemContent{}, &Item{})
item := Item{
Logo: "logo",
Contents: []ItemContent{
{Name: "name", LanguageCode: "en"},
{Name: "ar name", LanguageCode: "ar"},
},
}
if err := tx.Create(&item).Error; err != nil {
t.Fatalf("failed to create items, got error: %v", err)
}
// test Replace
if err := tx.Model(&item).Association("Contents").Unscoped().Replace([]ItemContent{
{Name: "updated name", LanguageCode: "en"},
{Name: "ar updated name", LanguageCode: "ar"},
{Name: "le nom", LanguageCode: "fr"},
}); err != nil {
t.Errorf("failed to replace item content, got error: %v", err)
}
if count := tx.Model(&item).Association("Contents").Count(); count != 3 {
t.Errorf("expected %d contents, got %d", 3, count)
}
var contents []ItemContent
if err := tx.Find(&contents).Error; err != nil {
t.Errorf("failed to find contents, got error: %v", err)
}
if len(contents) != 3 {
t.Errorf("expected %d contents, got %d", 3, len(contents))
}
// test delete
if err := tx.Model(&item).Association("Contents").Unscoped().Delete(&contents[0]); err != nil {
t.Errorf("failed to delete Contents, got error: %v", err)
}
if count := tx.Model(&item).Association("Contents").Count(); count != 2 {
t.Errorf("expected %d contents, got %d", 2, count)
}
// test clear
if err := tx.Model(&item).Association("Contents").Unscoped().Clear(); err != nil {
t.Errorf("failed to clear contents association, got error: %v", err)
}
if count := tx.Model(&item).Association("Contents").Count(); count != 0 {
t.Errorf("expected %d contents, got %d", 0, count)
}
if err := tx.Find(&contents).Error; err != nil {
t.Errorf("failed to find contents, got error: %v", err)
}
if len(contents) != 0 {
t.Errorf("expected %d contents, got %d", 0, len(contents))
}
}
func TestHasManyAssociationReplaceWithNonValidValue(t *testing.T) {
user := User{Name: "jinzhu", Languages: []Language{{Name: "EN"}}}
if err := DB.Create(&user).Error; err != nil {
t.Fatalf("errors happened when create: %v", err)
}
if err := DB.Model(&user).Association("Languages").Replace(Language{Name: "DE"}, Language{Name: "FR"}); err == nil {
t.Error("expected association error to be not nil")
}
}
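The has-many tests on the master side also gain a second polymorphic association (Tools, with a customised polymorphic type/id) alongside Toys. The exact Tools model lives in gorm's test helpers; the sketch below only shows the generic shape of a polymorphic has-many with the standard polymorphic:Owner tag (all type names here are hypothetical):

```go
package main

import (
	"fmt"

	"gorm.io/driver/sqlite"
	"gorm.io/gorm"
)

// Tool carries the OwnerID/OwnerType pair that the polymorphic tag fills in.
type Tool struct {
	ID        uint
	Name      string
	OwnerID   uint
	OwnerType string
}

type Workshop struct {
	ID    uint
	Name  string
	Tools []Tool `gorm:"polymorphic:Owner"`
}

func main() {
	db, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{})
	if err != nil {
		panic(err)
	}
	if err := db.AutoMigrate(&Workshop{}, &Tool{}); err != nil {
		panic(err)
	}

	db.Create(&Workshop{Name: "w1", Tools: []Tool{{Name: "hammer"}, {Name: "saw"}}})

	var w Workshop
	db.First(&w, "name = ?", "w1")

	var tools []Tool
	// Association Find resolves the polymorphic owner columns automatically.
	db.Model(&w).Association("Tools").Find(&tools)
	fmt.Println("tools:", len(tools)) // expected: 2
}
```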


@ -7,7 +7,7 @@ import (
) )
func TestHasOneAssociation(t *testing.T) { func TestHasOneAssociation(t *testing.T) {
user := *GetUser("hasone", Config{Account: true}) var user = *GetUser("hasone", Config{Account: true})
if err := DB.Create(&user).Error; err != nil { if err := DB.Create(&user).Error; err != nil {
t.Fatalf("errors happened when create: %v", err) t.Fatalf("errors happened when create: %v", err)
@ -25,7 +25,7 @@ func TestHasOneAssociation(t *testing.T) {
AssertAssociationCount(t, user, "Account", 1, "") AssertAssociationCount(t, user, "Account", 1, "")
// Append // Append
account := Account{Number: "account-has-one-append"} var account = Account{Number: "account-has-one-append"}
if err := DB.Model(&user2).Association("Account").Append(&account); err != nil { if err := DB.Model(&user2).Association("Account").Append(&account); err != nil {
t.Fatalf("Error happened when append account, got %v", err) t.Fatalf("Error happened when append account, got %v", err)
@ -41,7 +41,7 @@ func TestHasOneAssociation(t *testing.T) {
AssertAssociationCount(t, user, "Account", 1, "AfterAppend") AssertAssociationCount(t, user, "Account", 1, "AfterAppend")
// Replace // Replace
account2 := Account{Number: "account-has-one-replace"} var account2 = Account{Number: "account-has-one-replace"}
if err := DB.Model(&user2).Association("Account").Replace(&account2); err != nil { if err := DB.Model(&user2).Association("Account").Replace(&account2); err != nil {
t.Fatalf("Error happened when append Account, got %v", err) t.Fatalf("Error happened when append Account, got %v", err)
@ -84,7 +84,7 @@ func TestHasOneAssociation(t *testing.T) {
} }
func TestHasOneAssociationWithSelect(t *testing.T) { func TestHasOneAssociationWithSelect(t *testing.T) {
user := *GetUser("hasone", Config{Account: true}) var user = *GetUser("hasone", Config{Account: true})
DB.Omit("Account.Number").Create(&user) DB.Omit("Account.Number").Create(&user)
@ -98,7 +98,7 @@ func TestHasOneAssociationWithSelect(t *testing.T) {
} }
func TestHasOneAssociationForSlice(t *testing.T) { func TestHasOneAssociationForSlice(t *testing.T) {
users := []User{ var users = []User{
*GetUser("slice-hasone-1", Config{Account: true}), *GetUser("slice-hasone-1", Config{Account: true}),
*GetUser("slice-hasone-2", Config{Account: false}), *GetUser("slice-hasone-2", Config{Account: false}),
*GetUser("slice-hasone-3", Config{Account: true}), *GetUser("slice-hasone-3", Config{Account: true}),
@ -139,7 +139,7 @@ func TestHasOneAssociationForSlice(t *testing.T) {
} }
func TestPolymorphicHasOneAssociation(t *testing.T) { func TestPolymorphicHasOneAssociation(t *testing.T) {
pet := Pet{Name: "hasone", Toy: Toy{Name: "toy-has-one"}} var pet = Pet{Name: "hasone", Toy: Toy{Name: "toy-has-one"}}
if err := DB.Create(&pet).Error; err != nil { if err := DB.Create(&pet).Error; err != nil {
t.Fatalf("errors happened when create: %v", err) t.Fatalf("errors happened when create: %v", err)
@ -157,7 +157,7 @@ func TestPolymorphicHasOneAssociation(t *testing.T) {
AssertAssociationCount(t, pet, "Toy", 1, "") AssertAssociationCount(t, pet, "Toy", 1, "")
// Append // Append
toy := Toy{Name: "toy-has-one-append"} var toy = Toy{Name: "toy-has-one-append"}
if err := DB.Model(&pet2).Association("Toy").Append(&toy); err != nil { if err := DB.Model(&pet2).Association("Toy").Append(&toy); err != nil {
t.Fatalf("Error happened when append toy, got %v", err) t.Fatalf("Error happened when append toy, got %v", err)
@ -173,7 +173,7 @@ func TestPolymorphicHasOneAssociation(t *testing.T) {
AssertAssociationCount(t, pet, "Toy", 1, "AfterAppend") AssertAssociationCount(t, pet, "Toy", 1, "AfterAppend")
// Replace // Replace
toy2 := Toy{Name: "toy-has-one-replace"} var toy2 = Toy{Name: "toy-has-one-replace"}
if err := DB.Model(&pet2).Association("Toy").Replace(&toy2); err != nil { if err := DB.Model(&pet2).Association("Toy").Replace(&toy2); err != nil {
t.Fatalf("Error happened when append Toy, got %v", err) t.Fatalf("Error happened when append Toy, got %v", err)
@ -216,7 +216,7 @@ func TestPolymorphicHasOneAssociation(t *testing.T) {
} }
func TestPolymorphicHasOneAssociationForSlice(t *testing.T) { func TestPolymorphicHasOneAssociationForSlice(t *testing.T) {
pets := []Pet{ var pets = []Pet{
{Name: "hasone-1", Toy: Toy{Name: "toy-has-one"}}, {Name: "hasone-1", Toy: Toy{Name: "toy-has-one"}},
{Name: "hasone-2", Toy: Toy{}}, {Name: "hasone-2", Toy: Toy{}},
{Name: "hasone-3", Toy: Toy{Name: "toy-has-one"}}, {Name: "hasone-3", Toy: Toy{Name: "toy-has-one"}},
@ -255,15 +255,3 @@ func TestPolymorphicHasOneAssociationForSlice(t *testing.T) {
DB.Model(&pets).Association("Toy").Clear() DB.Model(&pets).Association("Toy").Clear()
AssertAssociationCount(t, pets, "Toy", 0, "After Clear") AssertAssociationCount(t, pets, "Toy", 0, "After Clear")
} }
func TestHasOneAssociationReplaceWithNonValidValue(t *testing.T) {
user := User{Name: "jinzhu", Account: Account{Number: "1"}}
if err := DB.Create(&user).Error; err != nil {
t.Fatalf("errors happened when create: %v", err)
}
if err := DB.Model(&user).Association("Languages").Replace(Account{Number: "2"}); err == nil {
t.Error("expected association error to be not nil")
}
}


@ -1,17 +1,13 @@
package tests_test package tests_test
import ( import (
"fmt"
"sync"
"testing" "testing"
"gorm.io/gorm"
"gorm.io/gorm/clause"
. "gorm.io/gorm/utils/tests" . "gorm.io/gorm/utils/tests"
) )
func TestMany2ManyAssociation(t *testing.T) { func TestMany2ManyAssociation(t *testing.T) {
user := *GetUser("many2many", Config{Languages: 2}) var user = *GetUser("many2many", Config{Languages: 2})
if err := DB.Create(&user).Error; err != nil { if err := DB.Create(&user).Error; err != nil {
t.Fatalf("errors happened when create: %v", err) t.Fatalf("errors happened when create: %v", err)
@ -30,7 +26,7 @@ func TestMany2ManyAssociation(t *testing.T) {
AssertAssociationCount(t, user, "Languages", 2, "") AssertAssociationCount(t, user, "Languages", 2, "")
// Append // Append
language := Language{Code: "language-many2many-append", Name: "language-many2many-append"} var language = Language{Code: "language-many2many-append", Name: "language-many2many-append"}
DB.Create(&language) DB.Create(&language)
if err := DB.Model(&user2).Association("Languages").Append(&language); err != nil { if err := DB.Model(&user2).Association("Languages").Append(&language); err != nil {
@ -42,7 +38,7 @@ func TestMany2ManyAssociation(t *testing.T) {
AssertAssociationCount(t, user, "Languages", 3, "AfterAppend") AssertAssociationCount(t, user, "Languages", 3, "AfterAppend")
languages := []Language{ var languages = []Language{
{Code: "language-many2many-append-1-1", Name: "language-many2many-append-1-1"}, {Code: "language-many2many-append-1-1", Name: "language-many2many-append-1-1"},
{Code: "language-many2many-append-2-1", Name: "language-many2many-append-2-1"}, {Code: "language-many2many-append-2-1", Name: "language-many2many-append-2-1"},
} }
@ -59,7 +55,7 @@ func TestMany2ManyAssociation(t *testing.T) {
AssertAssociationCount(t, user, "Languages", 5, "AfterAppendSlice") AssertAssociationCount(t, user, "Languages", 5, "AfterAppendSlice")
// Replace // Replace
language2 := Language{Code: "language-many2many-replace", Name: "language-many2many-replace"} var language2 = Language{Code: "language-many2many-replace", Name: "language-many2many-replace"}
DB.Create(&language2) DB.Create(&language2)
if err := DB.Model(&user2).Association("Languages").Replace(&language2); err != nil { if err := DB.Model(&user2).Association("Languages").Replace(&language2); err != nil {
@ -98,9 +94,7 @@ func TestMany2ManyAssociation(t *testing.T) {
} }
func TestMany2ManyOmitAssociations(t *testing.T) { func TestMany2ManyOmitAssociations(t *testing.T) {
tidbSkip(t, "not support the foreign key feature") var user = *GetUser("many2many_omit_associations", Config{Languages: 2})
user := *GetUser("many2many_omit_associations", Config{Languages: 2})
if err := DB.Omit("Languages.*").Create(&user).Error; err == nil { if err := DB.Omit("Languages.*").Create(&user).Error; err == nil {
t.Fatalf("should raise error when create users without languages reference") t.Fatalf("should raise error when create users without languages reference")
@ -120,14 +114,14 @@ func TestMany2ManyOmitAssociations(t *testing.T) {
t.Errorf("languages count should be %v, but got %v", 2, len(languages)) t.Errorf("languages count should be %v, but got %v", 2, len(languages))
} }
newLang := Language{Code: "omitmany2many", Name: "omitmany2many"} var newLang = Language{Code: "omitmany2many", Name: "omitmany2many"}
if err := DB.Model(&user).Omit("Languages.*").Association("Languages").Replace(&newLang); err == nil { if err := DB.Model(&user).Omit("Languages.*").Association("Languages").Replace(&newLang); err == nil {
t.Errorf("should failed to insert languages due to constraint failed, error: %v", err) t.Errorf("should failed to insert languages due to constraint failed, error: %v", err)
} }
} }
func TestMany2ManyAssociationForSlice(t *testing.T) { func TestMany2ManyAssociationForSlice(t *testing.T) {
users := []User{ var users = []User{
*GetUser("slice-many2many-1", Config{Languages: 2}), *GetUser("slice-many2many-1", Config{Languages: 2}),
*GetUser("slice-many2many-2", Config{Languages: 0}), *GetUser("slice-many2many-2", Config{Languages: 0}),
*GetUser("slice-many2many-3", Config{Languages: 4}), *GetUser("slice-many2many-3", Config{Languages: 4}),
@ -145,11 +139,11 @@ func TestMany2ManyAssociationForSlice(t *testing.T) {
} }
// Append // Append
languages1 := []Language{ var languages1 = []Language{
{Code: "language-many2many-append-1", Name: "language-many2many-append-1"}, {Code: "language-many2many-append-1", Name: "language-many2many-append-1"},
} }
languages2 := []Language{} var languages2 = []Language{}
languages3 := []Language{ var languages3 = []Language{
{Code: "language-many2many-append-3-1", Name: "language-many2many-append-3-1"}, {Code: "language-many2many-append-3-1", Name: "language-many2many-append-3-1"},
{Code: "language-many2many-append-3-2", Name: "language-many2many-append-3-2"}, {Code: "language-many2many-append-3-2", Name: "language-many2many-append-3-2"},
} }
@ -197,7 +191,7 @@ func TestMany2ManyAssociationForSlice(t *testing.T) {
} }
func TestSingleTableMany2ManyAssociation(t *testing.T) { func TestSingleTableMany2ManyAssociation(t *testing.T) {
user := *GetUser("many2many", Config{Friends: 2}) var user = *GetUser("many2many", Config{Friends: 2})
if err := DB.Create(&user).Error; err != nil { if err := DB.Create(&user).Error; err != nil {
t.Fatalf("errors happened when create: %v", err) t.Fatalf("errors happened when create: %v", err)
@ -216,7 +210,7 @@ func TestSingleTableMany2ManyAssociation(t *testing.T) {
AssertAssociationCount(t, user, "Friends", 2, "") AssertAssociationCount(t, user, "Friends", 2, "")
// Append // Append
friend := *GetUser("friend", Config{}) var friend = *GetUser("friend", Config{})
if err := DB.Model(&user2).Association("Friends").Append(&friend); err != nil { if err := DB.Model(&user2).Association("Friends").Append(&friend); err != nil {
t.Fatalf("Error happened when append account, got %v", err) t.Fatalf("Error happened when append account, got %v", err)
@ -227,7 +221,7 @@ func TestSingleTableMany2ManyAssociation(t *testing.T) {
AssertAssociationCount(t, user, "Friends", 3, "AfterAppend") AssertAssociationCount(t, user, "Friends", 3, "AfterAppend")
friends := []*User{GetUser("friend-append-1", Config{}), GetUser("friend-append-2", Config{})} var friends = []*User{GetUser("friend-append-1", Config{}), GetUser("friend-append-2", Config{})}
if err := DB.Model(&user2).Association("Friends").Append(&friends); err != nil { if err := DB.Model(&user2).Association("Friends").Append(&friends); err != nil {
t.Fatalf("Error happened when append friend, got %v", err) t.Fatalf("Error happened when append friend, got %v", err)
@ -240,7 +234,7 @@ func TestSingleTableMany2ManyAssociation(t *testing.T) {
AssertAssociationCount(t, user, "Friends", 5, "AfterAppendSlice") AssertAssociationCount(t, user, "Friends", 5, "AfterAppendSlice")
// Replace // Replace
friend2 := *GetUser("friend-replace-2", Config{}) var friend2 = *GetUser("friend-replace-2", Config{})
if err := DB.Model(&user2).Association("Friends").Replace(&friend2); err != nil { if err := DB.Model(&user2).Association("Friends").Replace(&friend2); err != nil {
t.Fatalf("Error happened when append friend, got %v", err) t.Fatalf("Error happened when append friend, got %v", err)
@ -278,7 +272,7 @@ func TestSingleTableMany2ManyAssociation(t *testing.T) {
} }
func TestSingleTableMany2ManyAssociationForSlice(t *testing.T) { func TestSingleTableMany2ManyAssociationForSlice(t *testing.T) {
users := []User{ var users = []User{
*GetUser("slice-many2many-1", Config{Team: 2}), *GetUser("slice-many2many-1", Config{Team: 2}),
*GetUser("slice-many2many-2", Config{Team: 0}), *GetUser("slice-many2many-2", Config{Team: 0}),
*GetUser("slice-many2many-3", Config{Team: 4}), *GetUser("slice-many2many-3", Config{Team: 4}),
@ -296,17 +290,17 @@ func TestSingleTableMany2ManyAssociationForSlice(t *testing.T) {
} }
// Append // Append
teams1 := []User{*GetUser("friend-append-1", Config{})} var teams1 = []User{*GetUser("friend-append-1", Config{})}
teams2 := []User{} var teams2 = []User{}
teams3 := []*User{GetUser("friend-append-3-1", Config{}), GetUser("friend-append-3-2", Config{})} var teams3 = []*User{GetUser("friend-append-3-1", Config{}), GetUser("friend-append-3-2", Config{})}
DB.Model(&users).Association("Team").Append(&teams1, &teams2, &teams3) DB.Model(&users).Association("Team").Append(&teams1, &teams2, &teams3)
AssertAssociationCount(t, users, "Team", 9, "After Append") AssertAssociationCount(t, users, "Team", 9, "After Append")
teams2_1 := []User{*GetUser("friend-replace-1", Config{}), *GetUser("friend-replace-2", Config{})} var teams2_1 = []User{*GetUser("friend-replace-1", Config{}), *GetUser("friend-replace-2", Config{})}
teams2_2 := []User{*GetUser("friend-replace-2-1", Config{}), *GetUser("friend-replace-2-2", Config{})} var teams2_2 = []User{*GetUser("friend-replace-2-1", Config{}), *GetUser("friend-replace-2-2", Config{})}
teams2_3 := GetUser("friend-replace-3-1", Config{}) var teams2_3 = GetUser("friend-replace-3-1", Config{})
// Replace // Replace
DB.Model(&users).Association("Team").Replace(&teams2_1, &teams2_2, teams2_3) DB.Model(&users).Association("Team").Replace(&teams2_1, &teams2_2, teams2_3)
@ -330,96 +324,3 @@ func TestSingleTableMany2ManyAssociationForSlice(t *testing.T) {
DB.Model(&users).Association("Team").Clear() DB.Model(&users).Association("Team").Clear()
AssertAssociationCount(t, users, "Team", 0, "After Clear") AssertAssociationCount(t, users, "Team", 0, "After Clear")
} }
func TestDuplicateMany2ManyAssociation(t *testing.T) {
user1 := User{Name: "TestDuplicateMany2ManyAssociation-1", Languages: []Language{
{Code: "TestDuplicateMany2ManyAssociation-language-1"},
{Code: "TestDuplicateMany2ManyAssociation-language-2"},
}}
user2 := User{Name: "TestDuplicateMany2ManyAssociation-1", Languages: []Language{
{Code: "TestDuplicateMany2ManyAssociation-language-1"},
{Code: "TestDuplicateMany2ManyAssociation-language-3"},
}}
users := []*User{&user1, &user2}
var err error
err = DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(users).Error
AssertEqual(t, nil, err)
var findUser1 User
err = DB.Preload("Languages").Where("id = ?", user1.ID).First(&findUser1).Error
AssertEqual(t, nil, err)
AssertEqual(t, user1, findUser1)
var findUser2 User
err = DB.Preload("Languages").Where("id = ?", user2.ID).First(&findUser2).Error
AssertEqual(t, nil, err)
AssertEqual(t, user2, findUser2)
}
func TestConcurrentMany2ManyAssociation(t *testing.T) {
db, err := OpenTestConnection(&gorm.Config{})
if err != nil {
t.Fatalf("open test connection failed, err: %+v", err)
}
count := 3
var languages []Language
for i := 0; i < count; i++ {
language := Language{Code: fmt.Sprintf("consurrent %d", i)}
db.Create(&language)
languages = append(languages, language)
}
user := User{}
db.Create(&user)
db.Preload("Languages").FirstOrCreate(&user)
var wg sync.WaitGroup
for i := 0; i < count; i++ {
wg.Add(1)
go func(user User, language Language) {
err := db.Model(&user).Association("Languages").Append(&language)
AssertEqual(t, err, nil)
wg.Done()
}(user, languages[i])
}
wg.Wait()
var find User
err = db.Preload(clause.Associations).Where("id = ?", user.ID).First(&find).Error
AssertEqual(t, err, nil)
AssertAssociationCount(t, find, "Languages", int64(count), "after concurrent append")
}
func TestMany2ManyDuplicateBelongsToAssociation(t *testing.T) {
user1 := User{Name: "TestMany2ManyDuplicateBelongsToAssociation-1", Friends: []*User{
{Name: "TestMany2ManyDuplicateBelongsToAssociation-friend-1", Company: Company{
ID: 1,
Name: "Test-company-1",
}},
}}
user2 := User{Name: "TestMany2ManyDuplicateBelongsToAssociation-2", Friends: []*User{
{Name: "TestMany2ManyDuplicateBelongsToAssociation-friend-2", Company: Company{
ID: 1,
Name: "Test-company-1",
}},
}}
users := []*User{&user1, &user2}
var err error
err = DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(users).Error
AssertEqual(t, nil, err)
var findUser1 User
err = DB.Preload("Friends.Company").Where("id = ?", user1.ID).First(&findUser1).Error
AssertEqual(t, nil, err)
AssertEqual(t, user1, findUser1)
var findUser2 User
err = DB.Preload("Friends.Company").Where("id = ?", user2.ID).First(&findUser2).Error
AssertEqual(t, nil, err)
AssertEqual(t, user2, findUser2)
}
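The tests appended at the end of this file on the master side cover duplicate and concurrent many-to-many writes. A condensed sketch of the FullSaveAssociations case with hypothetical Speaker/Language models, assuming the upsert behaviour that TestDuplicateMany2ManyAssociation asserts:

```go
package main

import (
	"fmt"

	"gorm.io/driver/sqlite"
	"gorm.io/gorm"
)

type Language struct {
	Code string `gorm:"primarykey"`
	Name string
}

type Speaker struct {
	gorm.Model
	Name      string
	Languages []Language `gorm:"many2many:speaker_languages"`
}

func main() {
	db, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{})
	if err != nil {
		panic(err)
	}
	if err := db.AutoMigrate(&Speaker{}, &Language{}); err != nil {
		panic(err)
	}

	// Two records share the "EN" language; FullSaveAssociations upserts the
	// shared row instead of failing on the duplicate key.
	speakers := []*Speaker{
		{Name: "a", Languages: []Language{{Code: "EN"}, {Code: "DE"}}},
		{Name: "b", Languages: []Language{{Code: "EN"}, {Code: "FR"}}},
	}
	if err := db.Session(&gorm.Session{FullSaveAssociations: true}).Save(speakers).Error; err != nil {
		panic(err)
	}

	var count int64
	db.Model(&Language{}).Count(&count)
	fmt.Println("languages:", count) // expected: 3 (EN is shared)
}
```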

Some files were not shown because too many files have changed in this diff.