Compare commits


No commits in common. "master" and "v0.2.0" have entirely different histories.

175 changed files with 3584 additions and 26077 deletions

.github/FUNDING.yml vendored

@@ -1,5 +0,0 @@
# These are supported funding model platforms
github: [jinzhu]
patreon: jinzhu
open_collective: gorm

.github/ISSUE_TEMPLATE.md vendored Normal file

@@ -0,0 +1,45 @@
Your issue may already be reported! Please search the [issue tracker](https://github.com/jinzhu/gorm/issues) before creating one.
### What version of Go are you using (`go version`)?
### Which database and its version are you using?
### Please provide a complete runnable program to reproduce your issue. **IMPORTANT**
It needs to be runnable with [GORM's docker compose config](https://github.com/jinzhu/gorm/blob/master/docker-compose.yml), or please provide your own config.
```go
package main

import (
	"fmt"

	"github.com/jinzhu/gorm"
	_ "github.com/jinzhu/gorm/dialects/mssql"
	_ "github.com/jinzhu/gorm/dialects/mysql"
	_ "github.com/jinzhu/gorm/dialects/postgres"
	_ "github.com/jinzhu/gorm/dialects/sqlite"
)

var db *gorm.DB

func init() {
	var err error
	db, err = gorm.Open("sqlite3", "test.db")
	// db, err = gorm.Open("postgres", "user=gorm password=gorm dbname=gorm port=9920 sslmode=disable")
	// db, err = gorm.Open("mysql", "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True")
	// db, err = gorm.Open("mssql", "sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm")
	if err != nil {
		panic(err)
	}
	db.LogMode(true)
}

func main() {
	failed := false // replace with your failure condition
	if failed {
		fmt.Println("failed")
	} else {
		fmt.Println("success")
	}
}
```

.github/PULL_REQUEST_TEMPLATE.md vendored Normal file

@@ -0,0 +1,9 @@
Make sure these boxes are checked before submitting your pull request:
- [ ] Do only one thing
- [ ] No API-breaking changes
- [ ] New code/logic commented & tested
For significant changes such as big bug fixes or new features, please open an issue first to agree on an implementation design/plan before starting work.
### What did this pull request do?


@@ -1,15 +0,0 @@
---
version: 2
updates:
- package-ecosystem: gomod
directory: /
schedule:
interval: weekly
- package-ecosystem: github-actions
directory: /
schedule:
interval: weekly
- package-ecosystem: gomod
directory: /tests
schedule:
interval: weekly

.github/labels.json vendored

@@ -1,166 +0,0 @@
{
"labels": {
"critical": {
"name": "type:critical",
"colour": "#E84137",
"description": "critical questions"
},
"question": {
"name": "type:question",
"colour": "#EDEDED",
"description": "general questions"
},
"feature": {
"name": "type:feature_request",
"colour": "#43952A",
"description": "feature request"
},
"invalid_question": {
"name": "type:invalid question",
"colour": "#CF2E1F",
"description": "invalid question (not related to GORM or described in document or not enough information provided)"
},
"with_playground": {
"name": "type:with reproduction steps",
"colour": "#00ff00",
"description": "with reproduction steps"
},
"without_playground": {
"name": "type:missing reproduction steps",
"colour": "#CF2E1F",
"description": "missing reproduction steps"
},
"has_pr": {
"name": "type:has pull request",
"colour": "#43952A",
"description": "has pull request"
},
"not_tested": {
"name": "type:not tested",
"colour": "#CF2E1F",
"description": "not tested"
},
"tested": {
"name": "type:tested",
"colour": "#00ff00",
"description": "tested"
},
"breaking_change": {
"name": "type:breaking change",
"colour": "#CF2E1F",
"description": "breaking change"
}
},
"issue": {
"with_playground": {
"requires": 1,
"conditions": [
{
"type": "descriptionMatches",
"pattern": "/github.com\/go-gorm\/playground\/pull\/\\d\\d+/s"
}
]
},
"critical": {
"requires": 1,
"conditions": [
{
"type": "descriptionMatches",
"pattern": "/(critical|urgent)/i"
},
{
"type": "titleMatches",
"pattern": "/(critical|urgent)/i"
}
]
},
"question": {
"requires": 1,
"conditions": [
{
"type": "titleMatches",
"pattern": "/question/i"
},
{
"type": "descriptionMatches",
"pattern": "/question/i"
}
]
},
"feature": {
"requires": 1,
"conditions": [
{
"type": "titleMatches",
"pattern": "/feature/i"
},
{
"type": "descriptionMatches",
"pattern": "/Describe the feature/i"
}
]
},
"without_playground": {
"requires": 6,
"conditions": [
{
"type": "descriptionMatches",
"pattern": "/^((?!github.com\/go-gorm\/playground\/pull\/\\d\\d+).)*$/s"
},
{
"type": "titleMatches",
"pattern": "/^((?!question).)*$/s"
},
{
"type": "descriptionMatches",
"pattern": "/^((?!question).)*$/is"
},
{
"type": "descriptionMatches",
"pattern": "/^((?!Describe the feature).)*$/is"
},
{
"type": "titleMatches",
"pattern": "/^((?!critical|urgent).)*$/s"
},
{
"type": "descriptionMatches",
"pattern": "/^((?!critical|urgent).)*$/s"
}
]
}
},
"pr": {
"critical": {
"requires": 1,
"conditions": [
{
"type": "descriptionMatches",
"pattern": "/(critical|urgent)/i"
},
{
"type": "titleMatches",
"pattern": "/(critical|urgent)/i"
}
]
},
"not_tested": {
"requires": 1,
"conditions": [
{
"type": "descriptionMatches",
"pattern": "/\\[\\] Tested/"
}
]
},
"breaking_change": {
"requires": 1,
"conditions": [
{
"type": "descriptionMatches",
"pattern": "/\\[\\] Non breaking API changes/"
}
]
}
}
}


@@ -1,20 +0,0 @@
name-template: 'v Release $NEXT_PATCH_VERSION 🌈'
tag-template: 'v$NEXT_PATCH_VERSION'
categories:
- title: '🚀 Features'
labels:
- 'feature'
- 'enhancement'
- title: '🐛 Bug Fixes'
labels:
- 'fix'
- 'bugfix'
- 'bug'
- title: '🧰 Maintenance'
label: 'chore'
change-template: '- $TITLE @$AUTHOR (#$NUMBER)'
change-title-escapes: '\<*_&'
template: |
## Changes
$CHANGES


@@ -1,31 +0,0 @@
name: Create Release
on:
push:
tags:
- 'v*.*.*'
permissions:
contents: write
pull-requests: read
jobs:
create_release:
name: Create Release
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Generate Release Notes and Publish
id: generate_release_notes
uses: release-drafter/release-drafter@v6
with:
config-name: 'release-drafter.yml'
name: "Release ${{ github.ref_name }}"
tag: ${{ github.ref_name }}
publish: true
prerelease: false
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}


@@ -1,26 +0,0 @@
name: golangci-lint
on:
push:
branches:
- main
- master
pull_request:
permissions:
contents: read
pull-requests: read
jobs:
golangci:
name: lint
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
with:
go-version: stable
- name: golangci-lint
uses: golangci/golangci-lint-action@v7
with:
version: v2.0
only-new-issues: true


@@ -1,28 +0,0 @@
name: "Close invalid questions issues"
on:
schedule:
- cron: "*/10 * * * *"
permissions:
contents: read
jobs:
stale:
permissions:
issues: write # for actions/stale to close stale issues
pull-requests: write # for actions/stale to close stale PRs
runs-on: ubuntu-latest
env:
ACTIONS_STEP_DEBUG: true
steps:
- name: Close Stale Issues
uses: actions/stale@v8
with:
repo-token: ${{ secrets.GITHUB_TOKEN }}
stale-issue-message: "This issue has been marked as invalid question, please give more information by following the `Question` template, if you believe there is a bug of GORM, please create a pull request that could reproduce the issue on [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground), the issue will be closed in 30 days if no further activity occurs. most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨"
stale-issue-label: "status:stale"
days-before-stale: 0
days-before-close: 30
remove-stale-when-updated: true
only-labels: "type:invalid question"


@@ -1,19 +0,0 @@
name: "Issue Labeler"
on:
issues:
types: [opened, edited, reopened]
pull_request:
types: [opened, edited, reopened]
jobs:
triage:
runs-on: ubuntu-latest
name: Label issues and pull requests
steps:
- name: check out
uses: actions/checkout@v4
- name: labeler
uses: jinzhu/super-labeler-action@develop
with:
GITHUB_TOKEN: "${{ secrets.GITHUB_TOKEN }}"


@@ -1,27 +0,0 @@
name: "Close Missing Playground issues"
on:
schedule:
- cron: "*/10 * * * *"
permissions:
contents: read
jobs:
stale:
permissions:
issues: write # for actions/stale to close stale issues
pull-requests: write # for actions/stale to close stale PRs
runs-on: ubuntu-latest
env:
ACTIONS_STEP_DEBUG: true
steps:
- name: Close Stale Issues
uses: actions/stale@v8
with:
repo-token: ${{ secrets.GITHUB_TOKEN }}
stale-issue-message: "The issue has been automatically marked as stale as it missing playground pull request link, which is important to help others understand your issue effectively and make sure the issue hasn't been fixed on latest master, checkout [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground) for details. it will be closed in 30 days if no further activity occurs. if you are asking question, please use the `Question` template, most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨"
stale-issue-label: "status:stale"
days-before-stale: 0
days-before-close: 30
remove-stale-when-updated: true
only-labels: "type:missing reproduction steps"


@@ -1,28 +0,0 @@
name: "Stale"
on:
schedule:
- cron: "0 2 * * *"
permissions:
contents: read
jobs:
stale:
permissions:
issues: write # for actions/stale to close stale issues
pull-requests: write # for actions/stale to close stale PRs
runs-on: ubuntu-latest
env:
ACTIONS_STEP_DEBUG: true
steps:
- name: Close Stale Issues
uses: actions/stale@v8
with:
repo-token: ${{ secrets.GITHUB_TOKEN }}
stale-issue-message: "This issue has been automatically marked as stale because it has been open 360 days with no activity. Remove stale label or comment or this will be closed in 180 days"
days-before-stale: 360
days-before-close: 180
stale-issue-label: "status:stale"
exempt-issue-labels: 'type:feature,type:with reproduction steps,type:has pull request'
stale-pr-label: 'status:stale'
exempt-pr-labels: 'type:feature,type:with reproduction steps,type:has pull request'


@@ -1,310 +0,0 @@
name: tests
on:
push:
branches-ignore:
- 'gh-pages'
pull_request:
branches-ignore:
- 'gh-pages'
permissions:
contents: read
jobs:
# Label of the container job
sqlite:
strategy:
matrix:
go: ['1.23', '1.24']
platform: [ubuntu-latest] # can not run in windows OS
runs-on: ${{ matrix.platform }}
steps:
- name: Set up Go 1.x
uses: actions/setup-go@v4
with:
go-version: ${{ matrix.go }}
- name: Check out code into the Go module directory
uses: actions/checkout@v4
- name: go mod package cache
uses: actions/cache@v4
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
- name: Tests
run: GITHUB_ACTION=true GORM_DIALECT=sqlite ./tests/tests_all.sh
mysql:
strategy:
matrix:
dbversion: ['mysql:9', 'mysql:8', 'mysql:5.7']
go: ['1.23', '1.24']
platform: [ubuntu-latest]
runs-on: ${{ matrix.platform }}
services:
mysql:
image: ${{ matrix.dbversion }}
env:
MYSQL_DATABASE: gorm
MYSQL_USER: gorm
MYSQL_PASSWORD: gorm
MYSQL_RANDOM_ROOT_PASSWORD: "yes"
ports:
- 9910:3306
options: >-
--health-cmd "mysqladmin ping -ugorm -pgorm"
--health-interval 10s
--health-start-period 10s
--health-timeout 5s
--health-retries 10
steps:
- name: Set up Go 1.x
uses: actions/setup-go@v4
with:
go-version: ${{ matrix.go }}
- name: Check out code into the Go module directory
uses: actions/checkout@v4
- name: go mod package cache
uses: actions/cache@v4
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
- name: Tests
run: GITHUB_ACTION=true GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True" ./tests/tests_all.sh
mariadb:
strategy:
matrix:
dbversion: [ 'mariadb:latest' ]
go: ['1.23', '1.24']
platform: [ ubuntu-latest ]
runs-on: ${{ matrix.platform }}
services:
mysql:
image: ${{ matrix.dbversion }}
env:
MYSQL_DATABASE: gorm
MYSQL_USER: gorm
MYSQL_PASSWORD: gorm
MYSQL_RANDOM_ROOT_PASSWORD: "yes"
ports:
- 9910:3306
options: >-
--health-cmd "mariadb-admin ping -ugorm -pgorm"
--health-interval 10s
--health-start-period 10s
--health-timeout 5s
--health-retries 10
steps:
- name: Set up Go 1.x
uses: actions/setup-go@v4
with:
go-version: ${{ matrix.go }}
- name: Check out code into the Go module directory
uses: actions/checkout@v4
- name: go mod package cache
uses: actions/cache@v4
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
- name: Tests
run: GITHUB_ACTION=true GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True" ./tests/tests_all.sh
postgres:
strategy:
matrix:
dbversion: ['postgres:latest', 'postgres:15', 'postgres:14', 'postgres:13']
go: ['1.23', '1.24']
platform: [ubuntu-latest] # can not run in macOS and Windows
runs-on: ${{ matrix.platform }}
services:
postgres:
image: ${{ matrix.dbversion }}
env:
POSTGRES_PASSWORD: gorm
POSTGRES_USER: gorm
POSTGRES_DB: gorm
TZ: Asia/Shanghai
ports:
- 9920:5432
# Set health checks to wait until postgres has started
options: >-
--health-cmd pg_isready
--health-interval 10s
--health-timeout 5s
--health-retries 5
steps:
- name: Set up Go 1.x
uses: actions/setup-go@v4
with:
go-version: ${{ matrix.go }}
- name: Check out code into the Go module directory
uses: actions/checkout@v4
- name: go mod package cache
uses: actions/cache@v4
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
- name: Tests
run: GITHUB_ACTION=true GORM_DIALECT=postgres GORM_DSN="user=gorm password=gorm dbname=gorm host=localhost port=9920 sslmode=disable TimeZone=Asia/Shanghai" ./tests/tests_all.sh
sqlserver:
strategy:
matrix:
go: ['1.23', '1.24']
platform: [ubuntu-latest] # can not run test in macOS and windows
runs-on: ${{ matrix.platform }}
services:
mssql:
image: mcr.microsoft.com/mssql/server:2022-latest
env:
TZ: Asia/Shanghai
ACCEPT_EULA: Y
MSSQL_SA_PASSWORD: LoremIpsum86
ports:
- 9930:1433
options: >-
--health-cmd="/opt/mssql-tools18/bin/sqlcmd -S localhost -U sa -P ${MSSQL_SA_PASSWORD} -N -C -l 30 -Q \"SELECT 1\" || exit 1"
--health-start-period 10s
--health-interval 10s
--health-timeout 5s
--health-retries 10
steps:
- name: Set up Go 1.x
uses: actions/setup-go@v4
with:
go-version: ${{ matrix.go }}
- name: Check out code into the Go module directory
uses: actions/checkout@v4
- name: go mod package cache
uses: actions/cache@v4
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
- name: Tests
run: GITHUB_ACTION=true GORM_DIALECT=sqlserver GORM_DSN="sqlserver://sa:LoremIpsum86@localhost:9930?database=master" ./tests/tests_all.sh
tidb:
strategy:
matrix:
dbversion: [ 'v6.5.0' ]
go: ['1.23', '1.24']
platform: [ ubuntu-latest ]
runs-on: ${{ matrix.platform }}
steps:
- name: Setup TiDB
uses: Icemap/tidb-action@main
with:
port: 9940
version: ${{matrix.dbversion}}
- name: Set up Go 1.x
uses: actions/setup-go@v4
with:
go-version: ${{ matrix.go }}
- name: Check out code into the Go module directory
uses: actions/checkout@v4
- name: go mod package cache
uses: actions/cache@v4
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
- name: Tests
run: GITHUB_ACTION=true GORM_DIALECT=tidb GORM_DSN="root:@tcp(localhost:9940)/test?charset=utf8&parseTime=True&loc=Local" ./tests/tests_all.sh
gaussdb:
strategy:
matrix:
dbversion: ['opengauss/opengauss:7.0.0-RC1.B023']
go: ['1.23', '1.24']
platform: [ubuntu-latest] # can not run in macOS and Windows
runs-on: ${{ matrix.platform }}
services:
gaussdb:
image: ${{ matrix.dbversion }}
env:
# GaussDB has password limitations
GS_PASSWORD: Gaussdb@123
TZ: Asia/Shanghai
ports:
- 9950:5432
steps:
- name: Set up Go 1.x
uses: actions/setup-go@v4
with:
go-version: ${{ matrix.go }}
- name: Check out code into the Go module directory
uses: actions/checkout@v4
- name: Waiting for GaussDB to be ready
run: |
container_name=$(docker ps --filter "ancestor=opengauss/opengauss:7.0.0-RC1.B023" --format "{{.Names}}")
if [ -z "$container_name" ]; then
echo "Error: failed to find a container created from the 'opengauss/opengauss:7.0.0-RC1.B023' image."
exit 1
fi
max_retries=12
retry_count=0
if [ -t 0 ]; then
TTY_FLAG="-t"
else
TTY_FLAG=""
fi
while [ $retry_count -lt $max_retries ]; do
if docker exec -i "${container_name}" bash -c "su - omm -c 'gsql -U omm -c \"select 1;\"'"
then
echo "Creating database gorm..."
sql_file='/tmp/create_database.sql'
echo "CREATE DATABASE gorm DBCOMPATIBILITY 'PG';" > ${sql_file}
docker cp "${sql_file}" "${container_name}":"${sql_file}"
docker exec -i ${TTY_FLAG} "${container_name}" bash -c "su - omm -c 'gsql -U omm -f ${sql_file}'"
echo "Database initialization completed."
break
fi
echo "Waiting for database to be ready... (attempt $((retry_count + 1))/$max_retries)"
sleep 10
((++retry_count))
done
exit 0
- name: go mod package cache
uses: actions/cache@v4
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
- name: Tests
run: GITHUB_ACTION=true GORM_DIALECT=gaussdb GORM_DSN="user=gaussdb password=Gaussdb@123 dbname=gorm host=localhost port=9950 sslmode=disable TimeZone=Asia/Shanghai" ./tests/tests_all.sh

.gitignore vendored

@@ -2,6 +2,3 @@ TODO*
documents
coverage.txt
_book
.idea
vendor
.vscode


@@ -1,19 +0,0 @@
version: "2"
linters:
default: standard
enable:
- cyclop
- gocritic
- gosec
- ineffassign
- misspell
- prealloc
- unconvert
- unparam
- whitespace
formatters:
enable:
- gofumpt
- goimports


@@ -1,128 +0,0 @@
# Contributor Covenant Code of Conduct
## Our Pledge
We as members, contributors, and leaders pledge to make participation in our
community a harassment-free experience for everyone, regardless of age, body
size, visible or invisible disability, ethnicity, sex characteristics, gender
identity and expression, level of experience, education, socio-economic status,
nationality, personal appearance, race, religion, or sexual identity
and orientation.
We pledge to act and interact in ways that contribute to an open, welcoming,
diverse, inclusive, and healthy community.
## Our Standards
Examples of behavior that contributes to a positive environment for our
community include:
* Demonstrating empathy and kindness toward other people
* Being respectful of differing opinions, viewpoints, and experiences
* Giving and gracefully accepting constructive feedback
* Accepting responsibility and apologizing to those affected by our mistakes,
and learning from the experience
* Focusing on what is best not just for us as individuals, but for the
overall community
Examples of unacceptable behavior include:
* The use of sexualized language or imagery, and sexual attention or
advances of any kind
* Trolling, insulting or derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or email
address, without their explicit permission
* Other conduct which could reasonably be considered inappropriate in a
professional setting
## Enforcement Responsibilities
Community leaders are responsible for clarifying and enforcing our standards of
acceptable behavior and will take appropriate and fair corrective action in
response to any behavior that they deem inappropriate, threatening, offensive,
or harmful.
Community leaders have the right and responsibility to remove, edit, or reject
comments, commits, code, wiki edits, issues, and other contributions that are
not aligned to this Code of Conduct, and will communicate reasons for moderation
decisions when appropriate.
## Scope
This Code of Conduct applies within all community spaces and also applies when
an individual is officially representing the community in public spaces.
Examples of representing our community include using an official e-mail address,
posting via an official social media account, or acting as an appointed
representative at an online or offline event.
## Enforcement
Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported to the community leaders responsible for enforcement at
.
All complaints will be reviewed and investigated promptly and fairly.
All community leaders are obligated to respect the privacy and security of the
reporter of any incident.
## Enforcement Guidelines
Community leaders will follow these Community Impact Guidelines in determining
the consequences for any action they deem in violation of this Code of Conduct:
### 1. Correction
**Community Impact**: Use of inappropriate language or other behavior deemed
unprofessional or unwelcome in the community.
**Consequence**: A private, written warning from community leaders, providing
clarity around the nature of the violation and an explanation of why the
behavior was inappropriate. A public apology may be requested.
### 2. Warning
**Community Impact**: A violation through a single incident or series
of actions.
**Consequence**: A warning with consequences for continued behavior. No
interaction with the people involved, including unsolicited interaction with
those enforcing the Code of Conduct, for a specified period. This
includes avoiding interactions in community spaces and external channels
like social media. Violating these terms may lead to a temporary or
permanent ban.
### 3. Temporary Ban
**Community Impact**: A serious violation of community standards, including
sustained inappropriate behavior.
**Consequence**: A temporary ban from any interaction or public
communication with the community for a specified period. No public or
private interaction with the people involved, including unsolicited interaction
with those enforcing the Code of Conduct, is allowed during this period.
Violating these terms may lead to a permanent ban.
### 4. Permanent Ban
**Community Impact**: Demonstrating a pattern of violation of community
standards, including sustained inappropriate behavior, harassment of an
individual, or aggression toward or disparagement of classes of individuals.
**Consequence**: A permanent ban from any sort of public interaction within
the community.
## Attribution
This Code of Conduct is adapted from the [Contributor Covenant][homepage],
version 2.0, available at
https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
Community Impact Guidelines were inspired by [Mozilla's code of conduct
enforcement ladder](https://github.com/mozilla/diversity).
[homepage]: https://www.contributor-covenant.org
For answers to common questions about this code of conduct, see the FAQ at
https://www.contributor-covenant.org/faq. Translations are available at
https://www.contributor-covenant.org/translations.


@@ -1,6 +1,6 @@
The MIT License (MIT)
Copyright (c) 2013-present Jinzhu <wosmvp@gmail.com>
Copyright (c) 2013-NOW Jinzhu <wosmvp@gmail.com>
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal


@@ -2,43 +2,41 @@
The fantastic ORM library for Golang, aims to be developer friendly.
[![go report card](https://goreportcard.com/badge/github.com/go-gorm/gorm "go report card")](https://goreportcard.com/report/github.com/go-gorm/gorm)
[![test status](https://github.com/go-gorm/gorm/workflows/tests/badge.svg?branch=master "test status")](https://github.com/go-gorm/gorm/actions)
[![go report card](https://goreportcard.com/badge/gorm.io/gorm "go report card")](https://goreportcard.com/report/gorm.io/gorm)
[![wercker status](https://app.wercker.com/status/8596cace912c9947dd9c8542ecc8cb8b/s/master "wercker status")](https://app.wercker.com/project/byKey/8596cace912c9947dd9c8542ecc8cb8b)
[![codecov](https://codecov.io/gh/jinzhu/gorm/branch/master/graph/badge.svg)](https://codecov.io/gh/jinzhu/gorm)
[![Join the chat at https://gitter.im/jinzhu/gorm](https://img.shields.io/gitter/room/jinzhu/gorm.svg)](https://gitter.im/jinzhu/gorm?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)
[![Open Collective Backer](https://opencollective.com/gorm/tiers/backer/badge.svg?label=backer&color=brightgreen "Open Collective Backer")](https://opencollective.com/gorm)
[![Open Collective Sponsor](https://opencollective.com/gorm/tiers/sponsor/badge.svg?label=sponsor&color=brightgreen "Open Collective Sponsor")](https://opencollective.com/gorm)
[![MIT license](https://img.shields.io/badge/license-MIT-brightgreen.svg)](https://opensource.org/licenses/MIT)
[![Go.Dev reference](https://img.shields.io/badge/go.dev-reference-blue?logo=go&logoColor=white)](https://pkg.go.dev/gorm.io/gorm?tab=doc)
[![GoDoc](https://godoc.org/gorm.io/gorm?status.svg)](https://godoc.org/gorm.io/gorm)
## Overview
* Full-Featured ORM
* Associations (Has One, Has Many, Belongs To, Many To Many, Polymorphism, Single-table inheritance)
* Full-Featured ORM (almost)
* Associations (Has One, Has Many, Belongs To, Many To Many, Polymorphism)
* Hooks (Before/After Create/Save/Update/Delete/Find)
* Eager loading with `Preload`, `Joins`
* Transactions, Nested Transactions, Save Point, RollbackTo to Saved Point
* Context, Prepared Statement Mode, DryRun Mode
* Batch Insert, FindInBatches, Find To Map
* SQL Builder, Upsert, Locking, Optimizer/Index/Comment Hints, NamedArg, Search/Update/Create with SQL Expr
* Preloading (eager loading)
* Transactions
* Composite Primary Key
* SQL Builder
* Auto Migrations
* Logger
* Extendable, flexible plugin API: Database Resolver (Multiple Databases, Read/Write Splitting) / Prometheus…
* Extendable, write Plugins based on GORM callbacks
* Every feature comes with tests
* Developer Friendly
## Getting Started
* GORM Guides [https://gorm.io](https://gorm.io)
* Gen Guides [https://gorm.io/gen/index.html](https://gorm.io/gen/index.html)
## Contributing
[You can help to deliver a better GORM, check out things you can do](https://gorm.io/contribute.html)
## Contributors
[Thank you](https://github.com/go-gorm/gorm/graphs/contributors) for contributing to the GORM framework!
## License
© Jinzhu, 2013~time.Now
Released under the [MIT License](https://github.com/go-gorm/gorm/blob/master/LICENSE)
Released under the [MIT License](https://github.com/jinzhu/gorm/blob/master/License)
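
The feature list above is easiest to see in code. Below is a minimal sketch against the master-branch API (`gorm.io/gorm` with the SQLite driver); the `User`/`Order` models and the `test.db` DSN are illustrative only and are not part of this diff.

```go
package main

import (
	"fmt"

	"gorm.io/driver/sqlite"
	"gorm.io/gorm"
)

// Illustrative models, not taken from the repository.
type Order struct {
	ID     uint
	UserID uint
	Amount int
}

type User struct {
	ID     uint
	Name   string
	Orders []Order // has-many association
}

func main() {
	db, err := gorm.Open(sqlite.Open("test.db"), &gorm.Config{})
	if err != nil {
		panic(err)
	}

	// Auto migration creates tables for both models.
	if err := db.AutoMigrate(&User{}, &Order{}); err != nil {
		panic(err)
	}

	// Create a record together with its association.
	db.Create(&User{Name: "jinzhu", Orders: []Order{{Amount: 100}}})

	// Eager loading with Preload.
	var users []User
	db.Preload("Orders").Find(&users)
	fmt.Println(len(users))

	// Transactions via a closure; returning an error rolls the whole block back.
	db.Transaction(func(tx *gorm.DB) error {
		return tx.Model(&User{}).Where("name = ?", "jinzhu").Update("name", "jinzhu 2").Error
	})
}
```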


@@ -1,6 +1,7 @@
package gorm
import (
"errors"
"fmt"
"reflect"
"strings"
@@ -14,7 +15,6 @@ import (
type Association struct {
DB *DB
Relationship *schema.Relationship
Unscope bool
Error error
}
@@ -27,13 +27,10 @@ func (db *DB) Association(column string) *Association {
association.Relationship = db.Statement.Schema.Relationships.Relations[column]
if association.Relationship == nil {
association.Error = fmt.Errorf("%w: %s", ErrUnsupportedRelation, column)
association.Error = fmt.Errorf("%w: %v", ErrUnsupportedRelation, column)
}
db.Statement.ReflectValue = reflect.ValueOf(db.Statement.Model)
for db.Statement.ReflectValue.Kind() == reflect.Ptr {
db.Statement.ReflectValue = db.Statement.ReflectValue.Elem()
}
db.Statement.ReflectValue = reflect.Indirect(reflect.ValueOf(db.Statement.Model))
} else {
association.Error = err
}
@@ -41,19 +38,34 @@ func (db *DB) Association(column string) *Association {
return association
}
func (association *Association) Unscoped() *Association {
return &Association{
DB: association.DB,
Relationship: association.Relationship,
Error: association.Error,
Unscope: true,
}
}
func (association *Association) Find(out interface{}, conds ...interface{}) error {
if association.Error == nil {
association.Error = association.buildCondition().Find(out, conds...).Error
var (
queryConds = association.Relationship.ToQueryConditions(association.DB.Statement.ReflectValue)
tx = association.DB.Model(out)
)
if association.Relationship.JoinTable != nil {
if !tx.Statement.Unscoped && len(association.Relationship.JoinTable.QueryClauses) > 0 {
joinStmt := Statement{DB: tx, Schema: association.Relationship.JoinTable, Table: association.Relationship.JoinTable.Table, Clauses: map[string]clause.Clause{}}
for _, queryClause := range association.Relationship.JoinTable.QueryClauses {
joinStmt.AddClause(queryClause)
}
joinStmt.Build("WHERE", "LIMIT")
tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars})
}
tx.Clauses(clause.From{Joins: []clause.Join{{
Table: clause.Table{Name: association.Relationship.JoinTable.Table},
ON: clause.Where{Exprs: queryConds},
}}})
} else {
tx.Clauses(clause.Where{Exprs: queryConds})
}
association.Error = tx.Find(out, conds...).Error
}
return association.Error
}
@@ -65,7 +77,7 @@ func (association *Association) Append(values ...interface{}) error {
association.Error = association.Replace(values...)
}
default:
association.saveAssociation( /*clear*/ false, values...)
association.saveAssociation(false, values...)
}
}
@@ -74,30 +86,12 @@ func (association *Association) Append(values ...interface{}) error {
func (association *Association) Replace(values ...interface{}) error {
if association.Error == nil {
reflectValue := association.DB.Statement.ReflectValue
rel := association.Relationship
var oldBelongsToExpr clause.Expression
// we have to record the old BelongsTo value
if association.Unscope && rel.Type == schema.BelongsTo {
var foreignFields []*schema.Field
for _, ref := range rel.References {
if !ref.OwnPrimaryKey {
foreignFields = append(foreignFields, ref.ForeignKey)
}
}
if _, fvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, foreignFields); len(fvs) > 0 {
column, values := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, fvs)
oldBelongsToExpr = clause.IN{Column: column, Values: values}
}
}
// save associations
if association.saveAssociation( /*clear*/ true, values...); association.Error != nil {
return association.Error
}
association.saveAssociation(true, values...)
// set old associations's foreign key to null
reflectValue := association.DB.Statement.ReflectValue
rel := association.Relationship
switch rel.Type {
case schema.BelongsTo:
if len(values) == 0 {
@@ -105,33 +99,30 @@ func (association *Association) Replace(values ...interface{}) error {
switch reflectValue.Kind() {
case reflect.Slice, reflect.Array:
for i := 0; i < reflectValue.Len(); i++ {
association.Error = rel.Field.Set(association.DB.Statement.Context, reflectValue.Index(i), reflect.Zero(rel.Field.FieldType).Interface())
rel.Field.Set(reflectValue.Index(i), reflect.Zero(rel.Field.FieldType).Interface())
}
case reflect.Struct:
association.Error = rel.Field.Set(association.DB.Statement.Context, reflectValue, reflect.Zero(rel.Field.FieldType).Interface())
rel.Field.Set(reflectValue, reflect.Zero(rel.Field.FieldType).Interface())
}
for _, ref := range rel.References {
updateMap[ref.ForeignKey.DBName] = nil
}
association.Error = association.DB.UpdateColumns(updateMap).Error
}
if association.Unscope && oldBelongsToExpr != nil {
association.Error = association.DB.Model(nil).Where(oldBelongsToExpr).Delete(reflect.New(rel.FieldSchema.ModelType).Interface()).Error
association.DB.UpdateColumns(updateMap)
}
case schema.HasOne, schema.HasMany:
var (
primaryFields []*schema.Field
foreignKeys []string
updateMap = map[string]interface{}{}
relValues = schema.GetRelationsValues(association.DB.Statement.Context, reflectValue, []*schema.Relationship{rel})
relValues = schema.GetRelationsValues(reflectValue, []*schema.Relationship{rel})
modelValue = reflect.New(rel.FieldSchema.ModelType).Interface()
tx = association.DB.Model(modelValue)
)
if _, rvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, relValues, rel.FieldSchema.PrimaryFields); len(rvs) > 0 {
if column, values := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs); len(values) > 0 {
if _, rvs := schema.GetIdentityFieldValuesMap(relValues, rel.FieldSchema.PrimaryFields); len(rvs) > 0 {
if column, values := schema.ToQueryValues(rel.FieldSchema.PrimaryFieldDBNames, rvs); len(values) > 0 {
tx.Not(clause.IN{Column: column, Values: values})
}
}
@@ -146,13 +137,9 @@ func (association *Association) Replace(values ...interface{}) error {
}
}
if _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields); len(pvs) > 0 {
column, values := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs)
if association.Unscope {
association.Error = tx.Where(clause.IN{Column: column, Values: values}).Delete(modelValue).Error
} else {
association.Error = tx.Where(clause.IN{Column: column, Values: values}).UpdateColumns(updateMap).Error
}
if _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields); len(pvs) > 0 {
column, values := schema.ToQueryValues(foreignKeys, pvs)
tx.Where(clause.IN{Column: column, Values: values}).UpdateColumns(updateMap)
}
case schema.Many2Many:
var (
@@ -176,19 +163,19 @@ func (association *Association) Replace(values ...interface{}) error {
}
}
_, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields)
if column, values := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs); len(values) > 0 {
_, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields)
if column, values := schema.ToQueryValues(joinPrimaryKeys, pvs); len(values) > 0 {
tx.Where(clause.IN{Column: column, Values: values})
} else {
return ErrPrimaryKeyRequired
return ErrorPrimaryKeyRequired
}
_, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, relPrimaryFields)
if relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs); len(relValues) > 0 {
_, rvs := schema.GetIdentityFieldValuesMapFromValues(values, relPrimaryFields)
if relColumn, relValues := schema.ToQueryValues(joinRelPrimaryKeys, rvs); len(relValues) > 0 {
tx.Where(clause.Not(clause.IN{Column: relColumn, Values: relValues}))
}
association.Error = tx.Delete(modelValue).Error
tx.Delete(modelValue)
}
}
return association.Error
@@ -197,17 +184,18 @@ func (association *Association) Replace(values ...interface{}) error {
func (association *Association) Delete(values ...interface{}) error {
if association.Error == nil {
var (
reflectValue = association.DB.Statement.ReflectValue
rel = association.Relationship
primaryFields []*schema.Field
foreignKeys []string
updateAttrs = map[string]interface{}{}
conds []clause.Expression
reflectValue = association.DB.Statement.ReflectValue
rel = association.Relationship
primaryFields, foreignFields []*schema.Field
foreignKeys []string
updateAttrs = map[string]interface{}{}
conds []clause.Expression
)
for _, ref := range rel.References {
if ref.PrimaryValue == "" {
primaryFields = append(primaryFields, ref.PrimaryKey)
foreignFields = append(foreignFields, ref.ForeignKey)
foreignKeys = append(foreignKeys, ref.ForeignKey.DBName)
updateAttrs[ref.ForeignKey.DBName] = nil
} else {
@@ -217,58 +205,34 @@ func (association *Association) Delete(values ...interface{}) error {
switch rel.Type {
case schema.BelongsTo:
associationDB := association.DB.Session(&Session{})
tx := associationDB.Model(reflect.New(rel.Schema.ModelType).Interface())
tx := association.DB.Model(reflect.New(rel.Schema.ModelType).Interface())
_, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, rel.Schema.PrimaryFields)
if pcolumn, pvalues := schema.ToQueryValues(rel.Schema.Table, rel.Schema.PrimaryFieldDBNames, pvs); len(pvalues) > 0 {
conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues})
} else {
return ErrPrimaryKeyRequired
}
_, pvs := schema.GetIdentityFieldValuesMap(reflectValue, rel.Schema.PrimaryFields)
pcolumn, pvalues := schema.ToQueryValues(rel.Schema.PrimaryFieldDBNames, pvs)
conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues})
_, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, primaryFields)
relColumn, relValues := schema.ToQueryValues(rel.Schema.Table, foreignKeys, rvs)
_, rvs := schema.GetIdentityFieldValuesMapFromValues(values, primaryFields)
relColumn, relValues := schema.ToQueryValues(foreignKeys, rvs)
conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error
if association.Unscope {
var foreignFields []*schema.Field
for _, ref := range rel.References {
if !ref.OwnPrimaryKey {
foreignFields = append(foreignFields, ref.ForeignKey)
}
}
if _, fvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, foreignFields); len(fvs) > 0 {
column, values := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, fvs)
association.Error = associationDB.Model(nil).Where(clause.IN{Column: column, Values: values}).Delete(reflect.New(rel.FieldSchema.ModelType).Interface()).Error
}
}
case schema.HasOne, schema.HasMany:
model := reflect.New(rel.FieldSchema.ModelType).Interface()
tx := association.DB.Model(model)
tx := association.DB.Model(reflect.New(rel.FieldSchema.ModelType).Interface())
_, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields)
if pcolumn, pvalues := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs); len(pvalues) > 0 {
conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues})
} else {
return ErrPrimaryKeyRequired
}
_, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields)
pcolumn, pvalues := schema.ToQueryValues(foreignKeys, pvs)
conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues})
_, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, rel.FieldSchema.PrimaryFields)
relColumn, relValues := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs)
_, rvs := schema.GetIdentityFieldValuesMapFromValues(values, rel.FieldSchema.PrimaryFields)
relColumn, relValues := schema.ToQueryValues(rel.FieldSchema.PrimaryFieldDBNames, rvs)
conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
if association.Unscope {
association.Error = tx.Clauses(conds...).Delete(model).Error
} else {
association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error
}
association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error
case schema.Many2Many:
var (
primaryFields, relPrimaryFields []*schema.Field
joinPrimaryKeys, joinRelPrimaryKeys []string
joinValue = reflect.New(rel.JoinTable.ModelType).Interface()
modelValue = reflect.New(rel.JoinTable.ModelType).Interface()
)
for _, ref := range rel.References {
@@ -285,27 +249,23 @@ func (association *Association) Delete(values ...interface{}) error {
}
}
_, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields)
if pcolumn, pvalues := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs); len(pvalues) > 0 {
conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues})
} else {
return ErrPrimaryKeyRequired
}
_, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields)
pcolumn, pvalues := schema.ToQueryValues(joinPrimaryKeys, pvs)
conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues})
_, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, relPrimaryFields)
relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs)
_, rvs := schema.GetIdentityFieldValuesMapFromValues(values, relPrimaryFields)
relColumn, relValues := schema.ToQueryValues(joinRelPrimaryKeys, rvs)
conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
association.Error = association.DB.Where(clause.Where{Exprs: conds}).Model(nil).Delete(joinValue).Error
association.Error = association.DB.Where(clause.Where{Exprs: conds}).Model(nil).Delete(modelValue).Error
}
if association.Error == nil {
// clean up deleted values's foreign key
relValuesMap, _ := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, rel.FieldSchema.PrimaryFields)
relValuesMap, _ := schema.GetIdentityFieldValuesMapFromValues(values, rel.FieldSchema.PrimaryFields)
cleanUpDeletedRelations := func(data reflect.Value) {
if _, zero := rel.Field.ValueOf(association.DB.Statement.Context, data); !zero {
fieldValue := reflect.Indirect(rel.Field.ReflectValueOf(association.DB.Statement.Context, data))
if _, zero := rel.Field.ValueOf(data); !zero {
fieldValue := reflect.Indirect(rel.Field.ReflectValueOf(data))
primaryValues := make([]interface{}, len(rel.FieldSchema.PrimaryFields))
switch fieldValue.Kind() {
@@ -313,7 +273,7 @@ func (association *Association) Delete(values ...interface{}) error {
validFieldValues := reflect.Zero(rel.Field.IndirectFieldType)
for i := 0; i < fieldValue.Len(); i++ {
for idx, field := range rel.FieldSchema.PrimaryFields {
primaryValues[idx], _ = field.ValueOf(association.DB.Statement.Context, fieldValue.Index(i))
primaryValues[idx], _ = field.ValueOf(fieldValue.Index(i))
}
if _, ok := relValuesMap[utils.ToStringKey(primaryValues...)]; !ok {
@@ -321,23 +281,21 @@ func (association *Association) Delete(values ...interface{}) error {
}
}
association.Error = rel.Field.Set(association.DB.Statement.Context, data, validFieldValues.Interface())
rel.Field.Set(data, validFieldValues.Interface())
case reflect.Struct:
for idx, field := range rel.FieldSchema.PrimaryFields {
primaryValues[idx], _ = field.ValueOf(association.DB.Statement.Context, fieldValue)
primaryValues[idx], _ = field.ValueOf(fieldValue)
}
if _, ok := relValuesMap[utils.ToStringKey(primaryValues...)]; ok {
if association.Error = rel.Field.Set(association.DB.Statement.Context, data, reflect.Zero(rel.FieldSchema.ModelType).Interface()); association.Error != nil {
break
}
rel.Field.Set(data, reflect.Zero(rel.FieldSchema.ModelType).Interface())
if rel.JoinTable == nil {
for _, ref := range rel.References {
if ref.OwnPrimaryKey || ref.PrimaryValue != "" {
association.Error = ref.ForeignKey.Set(association.DB.Statement.Context, fieldValue, reflect.Zero(ref.ForeignKey.FieldType).Interface())
ref.ForeignKey.Set(fieldValue, reflect.Zero(ref.ForeignKey.FieldType).Interface())
} else {
association.Error = ref.ForeignKey.Set(association.DB.Statement.Context, data, reflect.Zero(ref.ForeignKey.FieldType).Interface())
ref.ForeignKey.Set(data, reflect.Zero(ref.ForeignKey.FieldType).Interface())
}
}
}
@@ -366,8 +324,33 @@ func (association *Association) Clear() error {
func (association *Association) Count() (count int64) {
if association.Error == nil {
association.Error = association.buildCondition().Count(&count).Error
var (
conds = association.Relationship.ToQueryConditions(association.DB.Statement.ReflectValue)
modelValue = reflect.New(association.Relationship.FieldSchema.ModelType).Interface()
tx = association.DB.Model(modelValue)
)
if association.Relationship.JoinTable != nil {
if !tx.Statement.Unscoped && len(association.Relationship.JoinTable.QueryClauses) > 0 {
joinStmt := Statement{DB: tx, Schema: association.Relationship.JoinTable, Table: association.Relationship.JoinTable.Table, Clauses: map[string]clause.Clause{}}
for _, queryClause := range association.Relationship.JoinTable.QueryClauses {
joinStmt.AddClause(queryClause)
}
joinStmt.Build("WHERE", "LIMIT")
tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars})
}
tx.Clauses(clause.From{Joins: []clause.Join{{
Table: clause.Table{Name: association.Relationship.JoinTable.Table},
ON: clause.Where{Exprs: conds},
}}})
} else {
tx.Clauses(clause.Where{Exprs: conds})
}
association.Error = tx.Count(&count).Error
}
return
}
@@ -389,18 +372,14 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
switch rv.Kind() {
case reflect.Slice, reflect.Array:
if rv.Len() > 0 {
association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, source, rv.Index(0).Addr().Interface())
association.Error = association.Relationship.Field.Set(source, rv.Index(0).Addr().Interface())
if association.Relationship.Field.FieldType.Kind() == reflect.Struct {
assignBacks = append(assignBacks, assignBack{Source: source, Dest: rv.Index(0)})
}
}
case reflect.Struct:
if !rv.CanAddr() {
association.Error = ErrInvalidValue
return
}
association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, source, rv.Addr().Interface())
association.Error = association.Relationship.Field.Set(source, rv.Addr().Interface())
if association.Relationship.Field.FieldType.Kind() == reflect.Struct {
assignBacks = append(assignBacks, assignBack{Source: source, Dest: rv})
@@ -408,13 +387,9 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
}
case schema.HasMany, schema.Many2Many:
elemType := association.Relationship.Field.IndirectFieldType.Elem()
oldFieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(association.DB.Statement.Context, source))
var fieldValue reflect.Value
fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(source))
if clear {
fieldValue = reflect.MakeSlice(oldFieldValue.Type(), 0, oldFieldValue.Cap())
} else {
fieldValue = reflect.MakeSlice(oldFieldValue.Type(), oldFieldValue.Len(), oldFieldValue.Cap())
reflect.Copy(fieldValue, oldFieldValue)
fieldValue = reflect.New(association.Relationship.Field.IndirectFieldType).Elem()
}
appendToFieldValues := func(ev reflect.Value) {
@@ -423,7 +398,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
} else if ev.Type().Elem().AssignableTo(elemType) {
fieldValue = reflect.Append(fieldValue, ev.Elem())
} else {
association.Error = fmt.Errorf("unsupported data type: %v for relation %s", ev.Type(), association.Relationship.Name)
association.Error = fmt.Errorf("unsupported data type: %v for relation %v", ev.Type(), association.Relationship.Name)
}
if elemType.Kind() == reflect.Struct {
@@ -437,74 +412,33 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
appendToFieldValues(reflect.Indirect(rv.Index(i)).Addr())
}
case reflect.Struct:
if !rv.CanAddr() {
association.Error = ErrInvalidValue
return
}
appendToFieldValues(rv.Addr())
}
if association.Error == nil {
association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, source, fieldValue.Interface())
association.Error = association.Relationship.Field.Set(source, fieldValue.Interface())
}
}
}
selectedSaveColumns := []string{association.Relationship.Name}
omitColumns := []string{}
selectColumns, _ := association.DB.Statement.SelectAndOmitColumns(true, false)
for name, ok := range selectColumns {
columnName := ""
if strings.HasPrefix(name, association.Relationship.Name) {
if columnName = strings.TrimPrefix(name, association.Relationship.Name); columnName == ".*" {
columnName = name
}
} else if strings.HasPrefix(name, clause.Associations) {
columnName = name
}
if columnName != "" {
if ok {
selectedSaveColumns = append(selectedSaveColumns, columnName)
} else {
omitColumns = append(omitColumns, columnName)
}
}
}
for _, ref := range association.Relationship.References {
if !ref.OwnPrimaryKey {
selectedSaveColumns = append(selectedSaveColumns, ref.ForeignKey.Name)
}
}
associationDB := association.DB.Session(&Session{}).Model(nil)
if !association.DB.FullSaveAssociations {
associationDB.Select(selectedSaveColumns)
}
if len(omitColumns) > 0 {
associationDB.Omit(omitColumns...)
}
associationDB = associationDB.Session(&Session{})
switch reflectValue.Kind() {
case reflect.Slice, reflect.Array:
if len(values) != reflectValue.Len() {
// clear old data
if clear && len(values) == 0 {
for i := 0; i < reflectValue.Len(); i++ {
if err := association.Relationship.Field.Set(association.DB.Statement.Context, reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType).Interface()); err != nil {
association.Error = err
break
}
association.Relationship.Field.Set(reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType).Interface())
if association.Relationship.JoinTable == nil {
for _, ref := range association.Relationship.References {
if !ref.OwnPrimaryKey && ref.PrimaryValue == "" {
if err := ref.ForeignKey.Set(association.DB.Statement.Context, reflectValue.Index(i), reflect.Zero(ref.ForeignKey.FieldType).Interface()); err != nil {
association.Error = err
break
}
ref.ForeignKey.Set(reflectValue.Index(i), reflect.Zero(ref.ForeignKey.FieldType).Interface())
}
}
}
@@ -512,28 +446,24 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
break
}
association.Error = ErrInvalidValueOfLength
association.Error = errors.New("invalid association values, length doesn't match")
return
}
for i := 0; i < reflectValue.Len(); i++ {
appendToRelations(reflectValue.Index(i), reflect.Indirect(reflect.ValueOf(values[i])), clear)
if association.Error != nil {
return
}
// TODO support save slice data, sql with case?
association.Error = associationDB.Updates(reflectValue.Index(i).Addr().Interface()).Error
association.Error = association.DB.Session(&Session{}).Select(selectedSaveColumns).Model(nil).Save(reflectValue.Index(i).Addr().Interface()).Error
}
case reflect.Struct:
// clear old data
if clear && len(values) == 0 {
association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType).Interface())
association.Relationship.Field.Set(reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType).Interface())
if association.Relationship.JoinTable == nil && association.Error == nil {
if association.Relationship.JoinTable == nil {
for _, ref := range association.Relationship.References {
if !ref.OwnPrimaryKey && ref.PrimaryValue == "" {
association.Error = ref.ForeignKey.Set(association.DB.Statement.Context, reflectValue, reflect.Zero(ref.ForeignKey.FieldType).Interface())
ref.ForeignKey.Set(reflectValue, reflect.Zero(ref.ForeignKey.FieldType).Interface())
}
}
}
@@ -542,18 +472,15 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
for idx, value := range values {
rv := reflect.Indirect(reflect.ValueOf(value))
appendToRelations(reflectValue, rv, clear && idx == 0)
if association.Error != nil {
return
}
}
if len(values) > 0 {
association.Error = associationDB.Updates(reflectValue.Addr().Interface()).Error
association.Error = association.DB.Session(&Session{}).Select(selectedSaveColumns).Model(nil).Save(reflectValue.Addr().Interface()).Error
}
}
for _, assignBack := range assignBacks {
fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(association.DB.Statement.Context, assignBack.Source))
fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(assignBack.Source))
if assignBack.Index > 0 {
reflect.Indirect(assignBack.Dest).Set(fieldValue.Index(assignBack.Index - 1))
} else {
@@ -561,33 +488,3 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
}
}
}
func (association *Association) buildCondition() *DB {
var (
queryConds = association.Relationship.ToQueryConditions(association.DB.Statement.Context, association.DB.Statement.ReflectValue)
modelValue = reflect.New(association.Relationship.FieldSchema.ModelType).Interface()
tx = association.DB.Model(modelValue)
)
if association.Relationship.JoinTable != nil {
if !tx.Statement.Unscoped && len(association.Relationship.JoinTable.QueryClauses) > 0 {
joinStmt := Statement{DB: tx, Context: tx.Statement.Context, Schema: association.Relationship.JoinTable, Table: association.Relationship.JoinTable.Table, Clauses: map[string]clause.Clause{}}
for _, queryClause := range association.Relationship.JoinTable.QueryClauses {
joinStmt.AddClause(queryClause)
}
joinStmt.Build("WHERE")
if len(joinStmt.SQL.String()) > 0 {
tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars})
}
}
tx = tx.Session(&Session{QueryFields: true}).Clauses(clause.From{Joins: []clause.Join{{
Table: clause.Table{Name: association.Relationship.JoinTable.Table},
ON: clause.Where{Exprs: queryConds},
}}})
} else {
tx.Clauses(clause.Where{Exprs: queryConds})
}
return tx
}
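
For context, the association-mode API whose internals change above is used roughly as follows. This is a minimal sketch against the master-branch API with an illustrative `User`/`Pet` schema that is not part of this diff; the `Unscoped` helper exists only on the master side.

```go
package main

import (
	"fmt"

	"gorm.io/driver/sqlite"
	"gorm.io/gorm"
)

// Illustrative models, not taken from the diff.
type Pet struct {
	ID     uint
	Name   string
	UserID uint
}

type User struct {
	ID   uint
	Name string
	Pets []Pet
}

func main() {
	db, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{})
	if err != nil {
		panic(err)
	}
	db.AutoMigrate(&User{}, &Pet{})

	user := User{Name: "jinzhu", Pets: []Pet{{Name: "dog"}, {Name: "cat"}}}
	db.Create(&user)

	// Association mode: query and count the related rows.
	var pets []Pet
	db.Model(&user).Association("Pets").Find(&pets)
	fmt.Println(db.Model(&user).Association("Pets").Count())

	// Replace detaches the old rows by clearing their foreign keys; with the
	// master-side Unscoped() helper the detached rows are deleted instead.
	db.Model(&user).Association("Pets").Unscoped().Replace(&Pet{Name: "bird"})
}
```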


@@ -5,9 +5,9 @@ import (
"errors"
"fmt"
"reflect"
"sort"
"time"
"gorm.io/gorm/logger"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils"
)
@@ -15,12 +12,12 @@ import (
func initializeCallbacks(db *DB) *callbacks {
return &callbacks{
processors: map[string]*processor{
"create": {db: db},
"query": {db: db},
"update": {db: db},
"delete": {db: db},
"row": {db: db},
"raw": {db: db},
"create": &processor{db: db},
"query": &processor{db: db},
"update": &processor{db: db},
"delete": &processor{db: db},
"row": &processor{db: db},
"raw": &processor{db: db},
},
}
}
@@ -32,7 +32,6 @@ type callbacks struct {
type processor struct {
db *DB
Clauses []string
fns []func(*DB)
callbacks []*callback
}
@@ -72,57 +71,28 @@ func (cs *callbacks) Raw() *processor {
return cs.processors["raw"]
}
func (p *processor) Execute(db *DB) *DB {
// call scopes
for len(db.Statement.scopes) > 0 {
db = db.executeScopes()
}
func (p *processor) Execute(db *DB) {
curTime := time.Now()
db.RowsAffected = 0
if stmt := db.Statement; stmt != nil {
if stmt.Model == nil {
stmt.Model = stmt.Dest
}
var (
curTime = time.Now()
stmt = db.Statement
resetBuildClauses bool
)
if len(stmt.BuildClauses) == 0 {
stmt.BuildClauses = p.Clauses
resetBuildClauses = true
}
if optimizer, ok := db.Statement.Dest.(StatementModifier); ok {
optimizer.ModifyStatement(stmt)
}
// assign model values
if stmt.Model == nil {
stmt.Model = stmt.Dest
} else if stmt.Dest == nil {
stmt.Dest = stmt.Model
}
// parse model values
if stmt.Model != nil {
if err := stmt.Parse(stmt.Model); err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || (stmt.Table == "" && stmt.TableExpr == nil && stmt.SQL.Len() == 0)) {
if errors.Is(err, schema.ErrUnsupportedDataType) && stmt.Table == "" && stmt.TableExpr == nil {
db.AddError(fmt.Errorf("%w: Table not set, please set it like: db.Model(&user) or db.Table(\"users\")", err))
} else {
if stmt.Model != nil {
if err := stmt.Parse(stmt.Model); err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || (stmt.Table == "" && stmt.SQL.Len() == 0)) {
db.AddError(err)
}
}
}
// assign stmt.ReflectValue
if stmt.Dest != nil {
stmt.ReflectValue = reflect.ValueOf(stmt.Dest)
for stmt.ReflectValue.Kind() == reflect.Ptr {
if stmt.ReflectValue.IsNil() && stmt.ReflectValue.CanAddr() {
stmt.ReflectValue.Set(reflect.New(stmt.ReflectValue.Type().Elem()))
if stmt.Dest != nil {
stmt.ReflectValue = reflect.Indirect(reflect.ValueOf(stmt.Dest))
for stmt.ReflectValue.Kind() == reflect.Ptr {
stmt.ReflectValue = stmt.ReflectValue.Elem()
}
if !stmt.ReflectValue.IsValid() {
db.AddError(fmt.Errorf("invalid value"))
}
stmt.ReflectValue = stmt.ReflectValue.Elem()
}
if !stmt.ReflectValue.IsValid() {
db.AddError(ErrInvalidValue)
}
}
@@ -130,26 +100,14 @@ func (p *processor) Execute(db *DB) *DB {
f(db)
}
if stmt.SQL.Len() > 0 {
if stmt := db.Statement; stmt != nil {
db.Logger.Trace(stmt.Context, curTime, func() (string, int64) {
sql, vars := stmt.SQL.String(), stmt.Vars
if filter, ok := db.Logger.(ParamsFilter); ok {
sql, vars = filter.ParamsFilter(stmt.Context, stmt.SQL.String(), stmt.Vars...)
}
return db.Dialector.Explain(sql, vars...), db.RowsAffected
return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...), db.RowsAffected
}, db.Error)
}
if !stmt.DB.DryRun {
stmt.SQL.Reset()
stmt.Vars = nil
stmt.reinit()
// db.Config.statementPool.Put(stmt)
}
if resetBuildClauses {
stmt.BuildClauses = nil
}
return db
}
func (p *processor) Get(name string) func(*DB) {
@@ -187,23 +145,14 @@ func (p *processor) Replace(name string, fn func(*DB)) error {
func (p *processor) compile() (err error) {
var callbacks []*callback
removedMap := map[string]bool{}
for _, callback := range p.callbacks {
if callback.match == nil || callback.match(p.db) {
callbacks = append(callbacks, callback)
}
if callback.remove {
removedMap[callback.name] = true
}
}
if len(removedMap) > 0 {
callbacks = removeCallbacks(callbacks, removedMap)
}
p.callbacks = callbacks
if p.fns, err = sortCallbacks(p.callbacks); err != nil {
p.db.Logger.Error(context.Background(), "Got error when compile callbacks, got %v", err)
logger.Default.Error(context.Background(), "Got error when compile callbacks, got %v", err)
}
return
}
@@ -226,7 +175,7 @@ func (c *callback) Register(name string, fn func(*DB)) error {
}
func (c *callback) Remove(name string) error {
c.processor.db.Logger.Warn(context.Background(), "removing callback `%s` from %s\n", name, utils.FileWithLineNum())
logger.Default.Warn(context.Background(), "removing callback `%v` from %v\n", name, utils.FileWithLineNum())
c.name = name
c.remove = true
c.processor.callbacks = append(c.processor.callbacks, c)
@ -234,7 +183,7 @@ func (c *callback) Remove(name string) error {
}
func (c *callback) Replace(name string, fn func(*DB)) error {
c.processor.db.Logger.Info(context.Background(), "replacing callback `%s` from %s\n", name, utils.FileWithLineNum())
logger.Default.Info(context.Background(), "replacing callback `%v` from %v\n", name, utils.FileWithLineNum())
c.name = name
c.handler = fn
c.replace = true
@ -257,36 +206,23 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {
names, sorted []string
sortCallback func(*callback) error
)
sort.SliceStable(cs, func(i, j int) bool {
if cs[j].before == "*" && cs[i].before != "*" {
return true
}
if cs[j].after == "*" && cs[i].after != "*" {
return true
}
return false
})
for _, c := range cs {
// show a warning message if the callback name already exists
if idx := getRIndex(names, c.name); idx > -1 && !c.replace && !c.remove && !cs[idx].remove {
c.processor.db.Logger.Warn(context.Background(), "duplicated callback `%s` from %s\n", c.name, utils.FileWithLineNum())
logger.Default.Warn(context.Background(), "duplicated callback `%v` from %v\n", c.name, utils.FileWithLineNum())
}
names = append(names, c.name)
}
sortCallback = func(c *callback) error {
if c.before != "" { // if defined before callback
if c.before == "*" && len(sorted) > 0 {
if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
sorted = append([]string{c.name}, sorted...)
}
} else if sortedIdx := getRIndex(sorted, c.before); sortedIdx != -1 {
if sortedIdx := getRIndex(sorted, c.before); sortedIdx != -1 {
if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
// if before callback already sorted, append current callback just after it
sorted = append(sorted[:sortedIdx], append([]string{c.name}, sorted[sortedIdx:]...)...)
} else if curIdx > sortedIdx {
return fmt.Errorf("conflicting callback %s with before %s", c.name, c.before)
return fmt.Errorf("conflicting callback %v with before %v", c.name, c.before)
}
} else if idx := getRIndex(names, c.before); idx != -1 {
// if before callback exists
@ -295,16 +231,12 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {
}
if c.after != "" { // if defined after callback
if c.after == "*" && len(sorted) > 0 {
if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
sorted = append(sorted, c.name)
}
} else if sortedIdx := getRIndex(sorted, c.after); sortedIdx != -1 {
if sortedIdx := getRIndex(sorted, c.after); sortedIdx != -1 {
if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
// if after callback already sorted, append current callback at the end
sorted = append(sorted, c.name)
} else if curIdx < sortedIdx {
return fmt.Errorf("conflicting callback %s with before %s", c.name, c.after)
return fmt.Errorf("conflicting callback %v with before %v", c.name, c.after)
}
} else if idx := getRIndex(names, c.after); idx != -1 {
// if after callback exists but hasn't been sorted yet
@ -347,14 +279,3 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {
return
}
func removeCallbacks(cs []*callback, nameMap map[string]bool) []*callback {
callbacks := make([]*callback, 0, len(cs))
for _, callback := range cs {
if nameMap[callback.name] {
continue
}
callbacks = append(callbacks, callback)
}
return callbacks
}
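`sortCallbacks` resolves the `before`/`after` constraints (including the `*` wildcard) into a final execution order, and `removeCallbacks` drops callbacks flagged for removal at compile time. A small sketch of how user code exercises this, assuming gorm v2 and the `gorm.io/driver/sqlite` dialector; the `demo:`-prefixed callback names and the `User` model are hypothetical:

```go
package main

import (
	"fmt"

	"gorm.io/driver/sqlite"
	"gorm.io/gorm"
)

type User struct {
	ID   uint
	Name string
}

func main() {
	db, err := gorm.Open(sqlite.Open("test.db"), &gorm.Config{})
	if err != nil {
		panic(err)
	}
	_ = db.AutoMigrate(&User{})

	// Sorted before the built-in gorm:create callback by sortCallbacks.
	db.Callback().Create().Before("gorm:create").Register("demo:before_create", func(tx *gorm.DB) {
		fmt.Println("about to insert into", tx.Statement.Table)
	})

	// The "*" wildcard pushes this callback to the very end of the chain.
	db.Callback().Create().After("*").Register("demo:finalize", func(tx *gorm.DB) {
		fmt.Println("rows affected:", tx.RowsAffected)
	})

	db.Create(&User{Name: "alice"})

	// Removing a named callback takes effect the next time the chain is compiled.
	_ = db.Callback().Create().Remove("demo:finalize")
}
```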

View File

@ -2,7 +2,6 @@ package callbacks
import (
"reflect"
"strings"
"gorm.io/gorm"
"gorm.io/gorm/clause"
@ -10,443 +9,303 @@ import (
"gorm.io/gorm/utils"
)
func SaveBeforeAssociations(create bool) func(db *gorm.DB) {
return func(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil {
selectColumns, restricted := db.Statement.SelectAndOmitColumns(create, !create)
func SaveBeforeAssociations(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil {
selectColumns, restricted := SelectAndOmitColumns(db.Statement, true, false)
// Save Belongs To associations
for _, rel := range db.Statement.Schema.Relationships.BelongsTo {
if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) {
continue
}
// Save Belongs To associations
for _, rel := range db.Statement.Schema.Relationships.BelongsTo {
if !saveAssociationCheck(db, rel, selectColumns, restricted) {
continue
}
setupReferences := func(obj reflect.Value, elem reflect.Value) {
for _, ref := range rel.References {
if !ref.OwnPrimaryKey {
pv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, elem)
db.AddError(ref.ForeignKey.Set(db.Statement.Context, obj, pv))
setupReferences := func(obj reflect.Value, elem reflect.Value) {
for _, ref := range rel.References {
if !ref.OwnPrimaryKey {
pv, _ := ref.PrimaryKey.ValueOf(elem)
ref.ForeignKey.Set(obj, pv)
if dest, ok := db.Statement.Dest.(map[string]interface{}); ok {
dest[ref.ForeignKey.DBName] = pv
if _, ok := dest[rel.Name]; ok {
dest[rel.Name] = elem.Interface()
}
if dest, ok := db.Statement.Dest.(map[string]interface{}); ok {
dest[ref.ForeignKey.DBName] = pv
if _, ok := dest[rel.Name]; ok {
dest[rel.Name] = elem.Interface()
}
}
}
}
}
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
var (
rValLen = db.Statement.ReflectValue.Len()
objs = make([]reflect.Value, 0, rValLen)
fieldType = rel.Field.FieldType
isPtr = fieldType.Kind() == reflect.Ptr
)
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
var (
objs []reflect.Value
fieldType = rel.Field.FieldType
isPtr = fieldType.Kind() == reflect.Ptr
)
if !isPtr {
fieldType = reflect.PointerTo(fieldType)
}
if !isPtr {
fieldType = reflect.PtrTo(fieldType)
}
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
distinctElems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
identityMap := map[string]bool{}
for i := 0; i < rValLen; i++ {
obj := db.Statement.ReflectValue.Index(i)
if reflect.Indirect(obj).Kind() != reflect.Struct {
break
}
if _, zero := rel.Field.ValueOf(db.Statement.Context, obj); !zero { // check belongs to relation value
rv := rel.Field.ReflectValueOf(db.Statement.Context, obj) // relation reflect value
if !isPtr {
rv = rv.Addr()
}
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 0)
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
obj := db.Statement.ReflectValue.Index(i)
if _, zero := rel.Field.ValueOf(obj); !zero { // check belongs to relation value
rv := rel.Field.ReflectValueOf(obj) // relation reflect value
if _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(rv); isZero {
objs = append(objs, obj)
elems = reflect.Append(elems, rv)
relPrimaryValues := make([]interface{}, 0, len(rel.FieldSchema.PrimaryFields))
for _, pf := range rel.FieldSchema.PrimaryFields {
if pfv, ok := pf.ValueOf(db.Statement.Context, rv); !ok {
relPrimaryValues = append(relPrimaryValues, pfv)
}
}
cacheKey := utils.ToStringKey(relPrimaryValues...)
if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] {
if cacheKey != "" { // has primary fields
identityMap[cacheKey] = true
}
distinctElems = reflect.Append(distinctElems, rv)
if isPtr {
elems = reflect.Append(elems, rv)
} else {
elems = reflect.Append(elems, rv.Addr())
}
} else {
setupReferences(obj, rv)
}
}
}
if elems.Len() > 0 {
if saveAssociations(db, rel, distinctElems, selectColumns, restricted, nil) == nil {
for i := 0; i < elems.Len(); i++ {
setupReferences(objs[i], elems.Index(i))
}
if elems.Len() > 0 {
if db.AddError(db.Session(&gorm.Session{}).Create(elems.Interface()).Error) == nil {
for i := 0; i < elems.Len(); i++ {
setupReferences(objs[i], elems.Index(i))
}
}
case reflect.Struct:
if _, zero := rel.Field.ValueOf(db.Statement.Context, db.Statement.ReflectValue); !zero {
rv := rel.Field.ReflectValueOf(db.Statement.Context, db.Statement.ReflectValue) // relation reflect value
}
case reflect.Struct:
if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero {
rv := rel.Field.ReflectValueOf(db.Statement.ReflectValue) // relation reflect value
if rv.Kind() != reflect.Ptr {
rv = rv.Addr()
}
if _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(rv); isZero {
db.Session(&gorm.Session{}).Create(rv.Interface())
}
setupReferences(db.Statement.ReflectValue, rv)
}
}
}
}
}
func SaveAfterAssociations(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil {
selectColumns, restricted := SelectAndOmitColumns(db.Statement, true, false)
// Save Has One associations
for _, rel := range db.Statement.Schema.Relationships.HasOne {
if !saveAssociationCheck(db, rel, selectColumns, restricted) {
continue
}
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
var (
fieldType = rel.Field.FieldType
isPtr = fieldType.Kind() == reflect.Ptr
)
if !isPtr {
fieldType = reflect.PtrTo(fieldType)
}
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 0)
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
obj := db.Statement.ReflectValue.Index(i)
if _, zero := rel.Field.ValueOf(obj); !zero {
rv := rel.Field.ReflectValueOf(obj)
if rv.Kind() != reflect.Ptr {
rv = rv.Addr()
}
if saveAssociations(db, rel, rv, selectColumns, restricted, nil) == nil {
setupReferences(db.Statement.ReflectValue, rv)
}
}
}
}
}
}
}
func SaveAfterAssociations(create bool) func(db *gorm.DB) {
return func(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil {
selectColumns, restricted := db.Statement.SelectAndOmitColumns(create, !create)
// Save Has One associations
for _, rel := range db.Statement.Schema.Relationships.HasOne {
if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) {
continue
}
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
var (
fieldType = rel.Field.FieldType
isPtr = fieldType.Kind() == reflect.Ptr
)
if !isPtr {
fieldType = reflect.PointerTo(fieldType)
}
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
obj := db.Statement.ReflectValue.Index(i)
if reflect.Indirect(obj).Kind() == reflect.Struct {
if _, zero := rel.Field.ValueOf(db.Statement.Context, obj); !zero {
rv := rel.Field.ReflectValueOf(db.Statement.Context, obj)
if rv.Kind() != reflect.Ptr {
rv = rv.Addr()
}
for _, ref := range rel.References {
if ref.OwnPrimaryKey {
fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, obj)
db.AddError(ref.ForeignKey.Set(db.Statement.Context, rv, fv))
} else if ref.PrimaryValue != "" {
db.AddError(ref.ForeignKey.Set(db.Statement.Context, rv, ref.PrimaryValue))
}
}
elems = reflect.Append(elems, rv)
}
}
}
if elems.Len() > 0 {
assignmentColumns := make([]string, 0, len(rel.References))
for _, ref := range rel.References {
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
}
saveAssociations(db, rel, elems, selectColumns, restricted, assignmentColumns)
}
case reflect.Struct:
if _, zero := rel.Field.ValueOf(db.Statement.Context, db.Statement.ReflectValue); !zero {
f := rel.Field.ReflectValueOf(db.Statement.Context, db.Statement.ReflectValue)
if f.Kind() != reflect.Ptr {
f = f.Addr()
}
assignmentColumns := make([]string, 0, len(rel.References))
for _, ref := range rel.References {
if ref.OwnPrimaryKey {
fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, db.Statement.ReflectValue)
db.AddError(ref.ForeignKey.Set(db.Statement.Context, f, fv))
fv, _ := ref.PrimaryKey.ValueOf(obj)
ref.ForeignKey.Set(rv, fv)
} else if ref.PrimaryValue != "" {
db.AddError(ref.ForeignKey.Set(db.Statement.Context, f, ref.PrimaryValue))
}
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
}
saveAssociations(db, rel, f, selectColumns, restricted, assignmentColumns)
}
}
}
// Save Has Many associations
for _, rel := range db.Statement.Schema.Relationships.HasMany {
if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) {
continue
}
fieldType := rel.Field.IndirectFieldType.Elem()
isPtr := fieldType.Kind() == reflect.Ptr
if !isPtr {
fieldType = reflect.PointerTo(fieldType)
}
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
identityMap := map[string]bool{}
appendToElems := func(v reflect.Value) {
if _, zero := rel.Field.ValueOf(db.Statement.Context, v); !zero {
f := reflect.Indirect(rel.Field.ReflectValueOf(db.Statement.Context, v))
for i := 0; i < f.Len(); i++ {
elem := f.Index(i)
for _, ref := range rel.References {
if ref.OwnPrimaryKey {
pv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, v)
db.AddError(ref.ForeignKey.Set(db.Statement.Context, elem, pv))
} else if ref.PrimaryValue != "" {
db.AddError(ref.ForeignKey.Set(db.Statement.Context, elem, ref.PrimaryValue))
}
}
relPrimaryValues := make([]interface{}, 0, len(rel.FieldSchema.PrimaryFields))
for _, pf := range rel.FieldSchema.PrimaryFields {
if pfv, ok := pf.ValueOf(db.Statement.Context, elem); !ok {
relPrimaryValues = append(relPrimaryValues, pfv)
}
}
cacheKey := utils.ToStringKey(relPrimaryValues...)
if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] {
if cacheKey != "" { // has primary fields
identityMap[cacheKey] = true
}
if isPtr {
elems = reflect.Append(elems, elem)
} else {
elems = reflect.Append(elems, elem.Addr())
}
ref.ForeignKey.Set(rv, ref.PrimaryValue)
}
}
}
}
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
obj := db.Statement.ReflectValue.Index(i)
if reflect.Indirect(obj).Kind() == reflect.Struct {
appendToElems(obj)
if _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(rv); isZero {
elems = reflect.Append(elems, rv)
} else {
db.Session(&gorm.Session{}).Save(rv.Addr().Interface())
}
}
case reflect.Struct:
appendToElems(db.Statement.ReflectValue)
}
if elems.Len() > 0 {
assignmentColumns := make([]string, 0, len(rel.References))
for _, ref := range rel.References {
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
db.Session(&gorm.Session{}).Create(elems.Interface())
}
case reflect.Struct:
if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero {
f := rel.Field.ReflectValueOf(db.Statement.ReflectValue)
if f.Kind() != reflect.Ptr {
f = f.Addr()
}
saveAssociations(db, rel, elems, selectColumns, restricted, assignmentColumns)
}
}
// Save Many2Many associations
for _, rel := range db.Statement.Schema.Relationships.Many2Many {
if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) {
continue
}
fieldType := rel.Field.IndirectFieldType.Elem()
isPtr := fieldType.Kind() == reflect.Ptr
if !isPtr {
fieldType = reflect.PointerTo(fieldType)
}
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
distinctElems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
joins := reflect.MakeSlice(reflect.SliceOf(reflect.PointerTo(rel.JoinTable.ModelType)), 0, 10)
objs := []reflect.Value{}
appendToJoins := func(obj reflect.Value, elem reflect.Value) {
joinValue := reflect.New(rel.JoinTable.ModelType)
for _, ref := range rel.References {
if ref.OwnPrimaryKey {
fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, obj)
db.AddError(ref.ForeignKey.Set(db.Statement.Context, joinValue, fv))
fv, _ := ref.PrimaryKey.ValueOf(db.Statement.ReflectValue)
ref.ForeignKey.Set(f, fv)
} else if ref.PrimaryValue != "" {
db.AddError(ref.ForeignKey.Set(db.Statement.Context, joinValue, ref.PrimaryValue))
ref.ForeignKey.Set(f, ref.PrimaryValue)
}
}
if _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(f); isZero {
db.Session(&gorm.Session{}).Create(f.Interface())
} else {
db.Session(&gorm.Session{}).Save(f.Interface())
}
}
}
}
// Save Has Many associations
for _, rel := range db.Statement.Schema.Relationships.HasMany {
if !saveAssociationCheck(db, rel, selectColumns, restricted) {
continue
}
fieldType := rel.Field.IndirectFieldType.Elem()
isPtr := fieldType.Kind() == reflect.Ptr
if !isPtr {
fieldType = reflect.PtrTo(fieldType)
}
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 0)
appendToElems := func(v reflect.Value) {
if _, zero := rel.Field.ValueOf(v); !zero {
f := reflect.Indirect(rel.Field.ReflectValueOf(v))
for i := 0; i < f.Len(); i++ {
elem := f.Index(i)
for _, ref := range rel.References {
if ref.OwnPrimaryKey {
pv, _ := ref.PrimaryKey.ValueOf(v)
ref.ForeignKey.Set(elem, pv)
} else if ref.PrimaryValue != "" {
ref.ForeignKey.Set(elem, ref.PrimaryValue)
}
}
if _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(elem); isZero {
if isPtr {
elems = reflect.Append(elems, elem)
} else {
elems = reflect.Append(elems, elem.Addr())
}
} else {
fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, elem)
db.AddError(ref.ForeignKey.Set(db.Statement.Context, joinValue, fv))
db.Session(&gorm.Session{}).Save(elem.Addr().Interface())
}
}
joins = reflect.Append(joins, joinValue)
}
}
identityMap := map[string]bool{}
appendToElems := func(v reflect.Value) {
if _, zero := rel.Field.ValueOf(db.Statement.Context, v); !zero {
f := reflect.Indirect(rel.Field.ReflectValueOf(db.Statement.Context, v))
for i := 0; i < f.Len(); i++ {
elem := f.Index(i)
if !isPtr {
elem = elem.Addr()
}
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
appendToElems(db.Statement.ReflectValue.Index(i))
}
case reflect.Struct:
appendToElems(db.Statement.ReflectValue)
}
if elems.Len() > 0 {
db.Session(&gorm.Session{}).Create(elems.Interface())
}
}
// Save Many2Many associations
for _, rel := range db.Statement.Schema.Relationships.Many2Many {
if !saveAssociationCheck(db, rel, selectColumns, restricted) {
continue
}
fieldType := rel.Field.IndirectFieldType.Elem()
isPtr := fieldType.Kind() == reflect.Ptr
if !isPtr {
fieldType = reflect.PtrTo(fieldType)
}
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 0)
joins := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(rel.JoinTable.ModelType)), 0, 0)
objs := []reflect.Value{}
appendToJoins := func(obj reflect.Value, elem reflect.Value) {
joinValue := reflect.New(rel.JoinTable.ModelType)
for _, ref := range rel.References {
if ref.OwnPrimaryKey {
fv, _ := ref.PrimaryKey.ValueOf(obj)
ref.ForeignKey.Set(joinValue, fv)
} else if ref.PrimaryValue != "" {
ref.ForeignKey.Set(joinValue, ref.PrimaryValue)
} else {
fv, _ := ref.PrimaryKey.ValueOf(elem)
ref.ForeignKey.Set(joinValue, fv)
}
}
joins = reflect.Append(joins, joinValue)
}
appendToElems := func(v reflect.Value) {
if _, zero := rel.Field.ValueOf(v); !zero {
f := reflect.Indirect(rel.Field.ReflectValueOf(v))
for i := 0; i < f.Len(); i++ {
elem := f.Index(i)
if _, isZero := rel.FieldSchema.PrioritizedPrimaryField.ValueOf(elem); isZero {
objs = append(objs, v)
elems = reflect.Append(elems, elem)
relPrimaryValues := make([]interface{}, 0, len(rel.FieldSchema.PrimaryFields))
for _, pf := range rel.FieldSchema.PrimaryFields {
if pfv, ok := pf.ValueOf(db.Statement.Context, elem); !ok {
relPrimaryValues = append(relPrimaryValues, pfv)
}
if isPtr {
elems = reflect.Append(elems, elem)
} else {
elems = reflect.Append(elems, elem.Addr())
}
cacheKey := utils.ToStringKey(relPrimaryValues...)
if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] {
if cacheKey != "" { // has primary fields
identityMap[cacheKey] = true
}
distinctElems = reflect.Append(distinctElems, elem)
}
} else {
appendToJoins(v, elem)
}
}
}
}
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
obj := db.Statement.ReflectValue.Index(i)
if reflect.Indirect(obj).Kind() == reflect.Struct {
appendToElems(obj)
}
}
case reflect.Struct:
appendToElems(db.Statement.ReflectValue)
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
appendToElems(db.Statement.ReflectValue.Index(i))
}
case reflect.Struct:
appendToElems(db.Statement.ReflectValue)
}
// cache the elems length to avoid re-evaluating it in the loop below
if elemLen := elems.Len(); elemLen > 0 {
if v, ok := selectColumns[rel.Name+".*"]; !ok || v {
saveAssociations(db, rel, distinctElems, selectColumns, restricted, nil)
}
if elems.Len() > 0 {
db.Session(&gorm.Session{}).Create(elems.Interface())
for i := 0; i < elemLen; i++ {
appendToJoins(objs[i], elems.Index(i))
}
for i := 0; i < elems.Len(); i++ {
appendToJoins(objs[i], elems.Index(i))
}
}
if joins.Len() > 0 {
db.AddError(db.Session(&gorm.Session{NewDB: true}).Clauses(clause.OnConflict{DoNothing: true}).Session(&gorm.Session{
SkipHooks: db.Statement.SkipHooks,
DisableNestedTransaction: true,
}).Create(joins.Interface()).Error)
}
if joins.Len() > 0 {
db.Session(&gorm.Session{}).Clauses(clause.OnConflict{DoNothing: true}).Create(joins.Interface())
}
}
}
}
func onConflictOption(stmt *gorm.Statement, s *schema.Schema, defaultUpdatingColumns []string) (onConflict clause.OnConflict) {
if len(defaultUpdatingColumns) > 0 || stmt.DB.FullSaveAssociations {
onConflict.Columns = make([]clause.Column, 0, len(s.PrimaryFieldDBNames))
for _, dbName := range s.PrimaryFieldDBNames {
onConflict.Columns = append(onConflict.Columns, clause.Column{Name: dbName})
func saveAssociationCheck(db *gorm.DB, rel *schema.Relationship, selectColumns map[string]bool, restricted bool) bool {
savable := true
if value, ok := db.Get("gorm:save_association"); ok {
savable = utils.CheckTruth(value)
}
if savable {
if v, ok := selectColumns[rel.Name]; (ok && v) || (!ok && !restricted) {
return true
}
onConflict.UpdateAll = stmt.DB.FullSaveAssociations
if !onConflict.UpdateAll {
onConflict.DoUpdates = clause.AssignmentColumns(defaultUpdatingColumns)
}
} else {
onConflict.DoNothing = true
}
return
}
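`onConflictOption` above switches between `DoNothing` and an upsert (`UpdateAll` or `DoUpdates` on the assignment columns) depending on `FullSaveAssociations` and the default updating columns. A hedged usage sketch of the difference from the caller's side, assuming gorm v2, the `gorm.io/driver/sqlite` dialector and hypothetical `User`/`Address` models:

```go
package main

import (
	"gorm.io/driver/sqlite"
	"gorm.io/gorm"
)

// Hypothetical models for this sketch.
type Address struct {
	ID     uint
	UserID uint
	City   string
}

type User struct {
	ID      uint
	Name    string
	Address Address
}

func main() {
	db, err := gorm.Open(sqlite.Open("test.db"), &gorm.Config{})
	if err != nil {
		panic(err)
	}
	_ = db.AutoMigrate(&User{}, &Address{})

	user := User{Name: "jinzhu", Address: Address{City: "Chengdu"}}
	db.Create(&user)

	// With FullSaveAssociations the association rows are upserted with
	// UpdateAll assignments; without it, existing association rows are left
	// untouched (ON CONFLICT ... DO NOTHING).
	user.Address.City = "Shanghai"
	db.Session(&gorm.Session{FullSaveAssociations: true}).Updates(&user)
}
```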
func saveAssociations(db *gorm.DB, rel *schema.Relationship, rValues reflect.Value, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) error {
// stop save association loop
if checkAssociationsSaved(db, rValues) {
return nil
}
var (
selects, omits []string
onConflict = onConflictOption(db.Statement, rel.FieldSchema, defaultUpdatingColumns)
refName = rel.Name + "."
values = rValues.Interface()
)
for name, ok := range selectColumns {
columnName := ""
if strings.HasPrefix(name, refName) {
columnName = strings.TrimPrefix(name, refName)
}
if columnName != "" {
if ok {
selects = append(selects, columnName)
} else {
omits = append(omits, columnName)
}
}
}
tx := db.Session(&gorm.Session{NewDB: true}).Clauses(onConflict).Session(&gorm.Session{
FullSaveAssociations: db.FullSaveAssociations,
SkipHooks: db.Statement.SkipHooks,
DisableNestedTransaction: true,
})
db.Statement.Settings.Range(func(k, v interface{}) bool {
tx.Statement.Settings.Store(k, v)
return true
})
if tx.Statement.FullSaveAssociations {
tx = tx.Set("gorm:update_track_time", true)
}
if len(selects) > 0 {
tx = tx.Select(selects)
} else if restricted && len(omits) == 0 {
tx = tx.Omit(clause.Associations)
}
if len(omits) > 0 {
tx = tx.Omit(omits...)
}
return db.AddError(tx.Create(values).Error)
}
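`saveAssociations` forwards per-association `Select`/`Omit` settings (the `rel.Name + "."` prefix handling) and, when the column list is restricted, omits `clause.Associations` on the nested session. A minimal sketch of skipping association saves from user code, assuming gorm v2, the sqlite dialector and hypothetical models:

```go
package main

import (
	"gorm.io/driver/sqlite"
	"gorm.io/gorm"
	"gorm.io/gorm/clause"
)

// Hypothetical models for this sketch.
type CreditCard struct {
	ID     uint
	UserID uint
	Number string
}

type User struct {
	ID         uint
	Name       string
	CreditCard CreditCard
}

func main() {
	db, err := gorm.Open(sqlite.Open("test.db"), &gorm.Config{})
	if err != nil {
		panic(err)
	}
	_ = db.AutoMigrate(&User{}, &CreditCard{})

	// Skip saving every association: the save*Associations callbacks bail out
	// for all relations.
	db.Omit(clause.Associations).Create(&User{
		Name:       "jinzhu",
		CreditCard: CreditCard{Number: "411111111111"},
	})

	// Or skip a single association by its field name.
	db.Omit("CreditCard").Create(&User{
		Name:       "jinzhu-2",
		CreditCard: CreditCard{Number: "400000000000"},
	})
}
```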
// checkAssociationsSaved reports whether the association values have already been saved:
// if the values kind is Struct, check whether it has been saved;
// if the values kind is Slice/Array, check whether all items have been saved
var visitMapStoreKey = "gorm:saved_association_map"
func checkAssociationsSaved(db *gorm.DB, values reflect.Value) bool {
if visit, ok := db.Get(visitMapStoreKey); ok {
if v, ok := visit.(*visitMap); ok {
if loadOrStoreVisitMap(v, values) {
return true
}
}
} else {
vistMap := make(visitMap)
loadOrStoreVisitMap(&vistMap, values)
db.Set(visitMapStoreKey, &vistMap)
}
return false

View File

@ -4,19 +4,9 @@ import (
"gorm.io/gorm"
)
var (
createClauses = []string{"INSERT", "VALUES", "ON CONFLICT"}
queryClauses = []string{"SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR"}
updateClauses = []string{"UPDATE", "SET", "WHERE"}
deleteClauses = []string{"DELETE", "FROM", "WHERE"}
)
type Config struct {
LastInsertIDReversed bool
CreateClauses []string
QueryClauses []string
UpdateClauses []string
DeleteClauses []string
WithReturning bool
}
func RegisterDefaultCallbacks(db *gorm.DB, config *Config) {
@ -24,60 +14,37 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) {
return !db.SkipDefaultTransaction
}
if len(config.CreateClauses) == 0 {
config.CreateClauses = createClauses
}
if len(config.QueryClauses) == 0 {
config.QueryClauses = queryClauses
}
if len(config.DeleteClauses) == 0 {
config.DeleteClauses = deleteClauses
}
if len(config.UpdateClauses) == 0 {
config.UpdateClauses = updateClauses
}
createCallback := db.Callback().Create()
createCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction)
createCallback.Register("gorm:before_create", BeforeCreate)
createCallback.Register("gorm:save_before_associations", SaveBeforeAssociations(true))
createCallback.Register("gorm:save_before_associations", SaveBeforeAssociations)
createCallback.Register("gorm:create", Create(config))
createCallback.Register("gorm:save_after_associations", SaveAfterAssociations(true))
createCallback.Register("gorm:save_after_associations", SaveAfterAssociations)
createCallback.Register("gorm:after_create", AfterCreate)
createCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
createCallback.Clauses = config.CreateClauses
queryCallback := db.Callback().Query()
queryCallback.Register("gorm:query", Query)
queryCallback.Register("gorm:preload", Preload)
queryCallback.Register("gorm:after_query", AfterQuery)
queryCallback.Clauses = config.QueryClauses
deleteCallback := db.Callback().Delete()
deleteCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction)
deleteCallback.Register("gorm:before_delete", BeforeDelete)
deleteCallback.Register("gorm:delete_before_associations", DeleteBeforeAssociations)
deleteCallback.Register("gorm:delete", Delete(config))
deleteCallback.Register("gorm:delete", Delete)
deleteCallback.Register("gorm:after_delete", AfterDelete)
deleteCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
deleteCallback.Clauses = config.DeleteClauses
updateCallback := db.Callback().Update()
updateCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction)
updateCallback.Register("gorm:setup_reflect_value", SetupUpdateReflectValue)
updateCallback.Register("gorm:before_update", BeforeUpdate)
updateCallback.Register("gorm:save_before_associations", SaveBeforeAssociations(false))
updateCallback.Register("gorm:update", Update(config))
updateCallback.Register("gorm:save_after_associations", SaveAfterAssociations(false))
updateCallback.Register("gorm:save_before_associations", SaveBeforeAssociations)
updateCallback.Register("gorm:update", Update)
updateCallback.Register("gorm:save_after_associations", SaveAfterAssociations)
updateCallback.Register("gorm:after_update", AfterUpdate)
updateCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
updateCallback.Clauses = config.UpdateClauses
rowCallback := db.Callback().Row()
rowCallback.Register("gorm:row", RowQuery)
rowCallback.Clauses = config.QueryClauses
rawCallback := db.Callback().Raw()
rawCallback.Register("gorm:raw", RawExec)
rawCallback.Clauses = config.QueryClauses
db.Callback().Row().Register("gorm:raw", RowQuery)
db.Callback().Raw().Register("gorm:raw", RawExec)
}
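The transaction callbacks registered above only match when `SkipDefaultTransaction` is false, so disabling the default wrapping transaction is a single config switch. A sketch, assuming gorm v2 and the `gorm.io/driver/sqlite` dialector:

```go
package main

import (
	"gorm.io/driver/sqlite"
	"gorm.io/gorm"
)

func main() {
	// The gorm:begin_transaction / gorm:commit_or_rollback_transaction callbacks
	// registered above only match when SkipDefaultTransaction is false, so this
	// runs each write without the wrapping transaction.
	db, err := gorm.Open(sqlite.Open("test.db"), &gorm.Config{
		SkipDefaultTransaction: true,
	})
	if err != nil {
		panic(err)
	}
	_ = db
}
```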

View File

@ -1,32 +0,0 @@
package callbacks
import (
"reflect"
"gorm.io/gorm"
)
func callMethod(db *gorm.DB, fc func(value interface{}, tx *gorm.DB) bool) {
tx := db.Session(&gorm.Session{NewDB: true})
if called := fc(db.Statement.ReflectValue.Interface(), tx); !called {
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
db.Statement.CurDestIndex = 0
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
if value := reflect.Indirect(db.Statement.ReflectValue.Index(i)); value.CanAddr() {
fc(value.Addr().Interface(), tx)
} else {
db.AddError(gorm.ErrInvalidValue)
return
}
db.Statement.CurDestIndex++
}
case reflect.Struct:
if db.Statement.ReflectValue.CanAddr() {
fc(db.Statement.ReflectValue.Addr().Interface(), tx)
} else {
db.AddError(gorm.ErrInvalidValue)
}
}
}
}
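`callMethod` dispatches the hook interfaces (`BeforeCreate`, `AfterSave`, and so on) once per destination element, falling back to iterating a slice when the top-level value does not implement the hook itself. A sketch of a model-side hook, assuming gorm v2, the sqlite dialector and a hypothetical `User` model:

```go
package main

import (
	"errors"
	"strings"

	"gorm.io/driver/sqlite"
	"gorm.io/gorm"
)

// User is a hypothetical model implementing the hook interfaces dispatched
// by callMethod.
type User struct {
	ID   uint
	Name string
}

func (u *User) BeforeCreate(tx *gorm.DB) error {
	if strings.TrimSpace(u.Name) == "" {
		return errors.New("name must not be empty") // aborts the create
	}
	return nil
}

func (u *User) AfterCreate(tx *gorm.DB) error {
	// tx is a fresh session, so queries here do not inherit the outer
	// statement's conditions.
	return nil
}

func main() {
	db, err := gorm.Open(sqlite.Open("test.db"), &gorm.Config{})
	if err != nil {
		panic(err)
	}
	_ = db.AutoMigrate(&User{})

	// When the destination is a slice, the hooks run once per element.
	db.Create(&[]User{{Name: "alice"}, {Name: "bob"}})
}
```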

View File

@ -1,269 +1,237 @@
package callbacks
import (
"fmt"
"reflect"
"strings"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils"
)
// BeforeCreate before create hooks
func BeforeCreate(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) {
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) {
tx := db.Session(&gorm.Session{})
callMethod := func(value interface{}) bool {
var ok bool
if db.Statement.Schema.BeforeSave {
if i, ok := value.(BeforeSaveInterface); ok {
called = true
if i, ok := value.(gorm.BeforeSaveInterface); ok {
ok = true
db.AddError(i.BeforeSave(tx))
}
}
if db.Statement.Schema.BeforeCreate {
if i, ok := value.(BeforeCreateInterface); ok {
called = true
if i, ok := value.(gorm.BeforeCreateInterface); ok {
ok = true
db.AddError(i.BeforeCreate(tx))
}
}
return called
})
}
}
// Create create hook
func Create(config *Config) func(db *gorm.DB) {
supportReturning := utils.Contains(config.CreateClauses, "RETURNING")
return func(db *gorm.DB) {
if db.Error != nil {
return
return ok
}
if db.Statement.Schema != nil {
if !db.Statement.Unscoped {
for _, c := range db.Statement.Schema.CreateClauses {
db.Statement.AddClause(c)
}
}
if supportReturning && len(db.Statement.Schema.FieldsWithDefaultDBValue) > 0 {
if _, ok := db.Statement.Clauses["RETURNING"]; !ok {
fromColumns := make([]clause.Column, 0, len(db.Statement.Schema.FieldsWithDefaultDBValue))
for _, field := range db.Statement.Schema.FieldsWithDefaultDBValue {
if field.Readable {
fromColumns = append(fromColumns, clause.Column{Name: field.DBName})
}
}
if len(fromColumns) > 0 {
db.Statement.AddClause(clause.Returning{Columns: fromColumns})
}
}
}
}
if db.Statement.SQL.Len() == 0 {
db.Statement.SQL.Grow(180)
db.Statement.AddClauseIfNotExists(clause.Insert{})
db.Statement.AddClause(ConvertToCreateValues(db.Statement))
db.Statement.Build(db.Statement.BuildClauses...)
}
isDryRun := !db.DryRun && db.Error == nil
if !isDryRun {
return
}
ok, mode := hasReturning(db, supportReturning)
if ok {
if c, ok := db.Statement.Clauses["ON CONFLICT"]; ok {
if onConflict, _ := c.Expression.(clause.OnConflict); onConflict.DoNothing {
mode |= gorm.ScanOnConflictDoNothing
}
}
rows, err := db.Statement.ConnPool.QueryContext(
db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...,
)
if db.AddError(err) == nil {
defer func() {
db.AddError(rows.Close())
}()
gorm.Scan(rows, db, mode)
if db.Statement.Result != nil {
db.Statement.Result.RowsAffected = db.RowsAffected
}
}
return
}
result, err := db.Statement.ConnPool.ExecContext(
db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...,
)
if err != nil {
db.AddError(err)
return
}
db.RowsAffected, _ = result.RowsAffected()
if db.Statement.Result != nil {
db.Statement.Result.Result = result
db.Statement.Result.RowsAffected = db.RowsAffected
}
if db.RowsAffected == 0 {
return
}
var (
pkField *schema.Field
pkFieldName = "@id"
)
if db.Statement.Schema != nil {
if db.Statement.Schema.PrioritizedPrimaryField == nil ||
!db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue ||
!db.Statement.Schema.PrioritizedPrimaryField.Readable {
return
}
pkField = db.Statement.Schema.PrioritizedPrimaryField
pkFieldName = db.Statement.Schema.PrioritizedPrimaryField.DBName
}
insertID, err := result.LastInsertId()
insertOk := err == nil && insertID > 0
if !insertOk {
if !supportReturning {
db.AddError(err)
}
return
}
// append the @id column with its value for the auto-increment primary key
// the @id value is correct when: 1. the auto-increment primary key was not set explicitly, and 2. the database AutoIncrementIncrement equals 1
switch values := db.Statement.Dest.(type) {
case map[string]interface{}:
values[pkFieldName] = insertID
case *map[string]interface{}:
(*values)[pkFieldName] = insertID
case []map[string]interface{}, *[]map[string]interface{}:
mapValues, ok := values.([]map[string]interface{})
if !ok {
if v, ok := values.(*[]map[string]interface{}); ok {
if *v != nil {
mapValues = *v
}
}
}
if config.LastInsertIDReversed {
insertID -= int64(len(mapValues)-1) * schema.DefaultAutoIncrementIncrement
}
for _, mapValue := range mapValues {
if mapValue != nil {
mapValue[pkFieldName] = insertID
}
insertID += schema.DefaultAutoIncrementIncrement
}
default:
if pkField == nil {
return
}
if ok := callMethod(db.Statement.Dest); !ok {
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
if config.LastInsertIDReversed {
for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- {
rv := db.Statement.ReflectValue.Index(i)
if reflect.Indirect(rv).Kind() != reflect.Struct {
break
}
_, isZero := pkField.ValueOf(db.Statement.Context, rv)
if isZero {
db.AddError(pkField.Set(db.Statement.Context, rv, insertID))
insertID -= pkField.AutoIncrementIncrement
}
}
} else {
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
rv := db.Statement.ReflectValue.Index(i)
if reflect.Indirect(rv).Kind() != reflect.Struct {
break
}
if _, isZero := pkField.ValueOf(db.Statement.Context, rv); isZero {
db.AddError(pkField.Set(db.Statement.Context, rv, insertID))
insertID += pkField.AutoIncrementIncrement
}
}
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
callMethod(db.Statement.ReflectValue.Index(i).Interface())
}
case reflect.Struct:
_, isZero := pkField.ValueOf(db.Statement.Context, db.Statement.ReflectValue)
if isZero {
db.AddError(pkField.Set(db.Statement.Context, db.Statement.ReflectValue, insertID))
callMethod(db.Statement.ReflectValue.Interface())
}
}
}
}
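When the destination is a map, the branch above writes the generated primary key back into the map, keyed by the primary key's DB name (or `@id` when no schema is available). A hedged sketch of that path, assuming gorm v2 and the sqlite dialector; whether the value comes from `LastInsertId` or a `RETURNING` scan depends on the dialector's clause support:

```go
package main

import (
	"fmt"

	"gorm.io/driver/sqlite"
	"gorm.io/gorm"
)

type User struct {
	ID   uint
	Name string
}

func main() {
	db, err := gorm.Open(sqlite.Open("test.db"), &gorm.Config{})
	if err != nil {
		panic(err)
	}
	_ = db.AutoMigrate(&User{})

	// Creating from a map: the primary key generated by the database is
	// written back into the map (under "id" here, since a schema is parsed
	// from the Model value).
	values := map[string]interface{}{"Name": "alice"}
	db.Model(&User{}).Create(values)
	fmt.Println(values)
}
```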
func Create(config *Config) func(db *gorm.DB) {
if config.WithReturning {
return CreateWithReturning
} else {
return func(db *gorm.DB) {
if db.Error == nil {
if db.Statement.Schema != nil && !db.Statement.Unscoped {
for _, c := range db.Statement.Schema.CreateClauses {
db.Statement.AddClause(c)
}
}
if db.Statement.SQL.String() == "" {
db.Statement.AddClauseIfNotExists(clause.Insert{
Table: clause.Table{Name: db.Statement.Table},
})
db.Statement.AddClause(ConvertToCreateValues(db.Statement))
db.Statement.Build("INSERT", "VALUES", "ON CONFLICT")
}
if !db.DryRun {
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
if err == nil {
if db.Statement.Schema != nil && db.Statement.Schema.PrioritizedPrimaryField != nil {
if _, ok := db.Statement.Schema.FieldsWithDefaultDBValue[db.Statement.Schema.PrioritizedPrimaryField.DBName]; ok {
if insertID, err := result.LastInsertId(); err == nil {
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
if config.LastInsertIDReversed {
for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- {
db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID)
insertID--
}
} else {
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID)
insertID++
}
}
case reflect.Struct:
db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID)
}
} else {
db.AddError(err)
}
}
}
db.RowsAffected, _ = result.RowsAffected()
} else {
db.AddError(err)
}
}
}
}
}
}
// AfterCreate after create hooks
func AfterCreate(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) {
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
if db.Statement.Schema.AfterCreate {
if i, ok := value.(AfterCreateInterface); ok {
called = true
db.AddError(i.AfterCreate(tx))
func CreateWithReturning(db *gorm.DB) {
if db.Error == nil {
if db.Statement.Schema != nil && !db.Statement.Unscoped {
for _, c := range db.Statement.Schema.CreateClauses {
db.Statement.AddClause(c)
}
}
if db.Statement.SQL.String() == "" {
db.Statement.AddClauseIfNotExists(clause.Insert{
Table: clause.Table{Name: db.Statement.Table},
})
db.Statement.AddClause(ConvertToCreateValues(db.Statement))
db.Statement.Build("INSERT", "VALUES", "ON CONFLICT")
}
if sch := db.Statement.Schema; sch != nil && len(sch.FieldsWithDefaultDBValue) > 0 {
db.Statement.WriteString(" RETURNING ")
var (
idx int
fields = make([]*schema.Field, len(sch.FieldsWithDefaultDBValue))
values = make([]interface{}, len(sch.FieldsWithDefaultDBValue))
)
for dbName, field := range sch.FieldsWithDefaultDBValue {
if idx != 0 {
db.Statement.WriteByte(',')
}
fields[idx] = field
db.Statement.WriteQuoted(dbName)
idx++
}
if !db.DryRun {
rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
if err == nil {
defer rows.Close()
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
for rows.Next() {
for idx, field := range fields {
values[idx] = field.ReflectValueOf(db.Statement.ReflectValue.Index(int(db.RowsAffected))).Addr().Interface()
}
if err := rows.Scan(values...); err != nil {
db.AddError(err)
}
db.RowsAffected++
}
case reflect.Struct:
for idx, field := range fields {
values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface()
}
if rows.Next() {
db.RowsAffected++
err = rows.Scan(values...)
}
}
}
if err != nil {
db.AddError(err)
}
}
} else if !db.DryRun {
if result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); err == nil {
db.RowsAffected, _ = result.RowsAffected()
} else {
db.AddError(err)
}
}
}
}
func AfterCreate(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) {
tx := db.Session(&gorm.Session{})
callMethod := func(value interface{}) bool {
var ok bool
if db.Statement.Schema.AfterSave {
if i, ok := value.(AfterSaveInterface); ok {
called = true
if i, ok := value.(gorm.AfterSaveInterface); ok {
ok = true
db.AddError(i.AfterSave(tx))
}
}
return called
})
if db.Statement.Schema.AfterCreate {
if i, ok := value.(gorm.AfterCreateInterface); ok {
ok = true
db.AddError(i.AfterCreate(tx))
}
}
return ok
}
if ok := callMethod(db.Statement.Dest); !ok {
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
callMethod(db.Statement.ReflectValue.Index(i).Interface())
}
case reflect.Struct:
callMethod(db.Statement.ReflectValue.Interface())
}
}
}
}
// ConvertToCreateValues converts the statement dest into create values
func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
curTime := stmt.DB.NowFunc()
func ConvertToCreateValues(stmt *gorm.Statement) clause.Values {
switch value := stmt.Dest.(type) {
case map[string]interface{}:
values = ConvertMapToValuesForCreate(stmt, value)
case *map[string]interface{}:
values = ConvertMapToValuesForCreate(stmt, *value)
return ConvertMapToValuesForCreate(stmt, value)
case []map[string]interface{}:
values = ConvertSliceOfMapToValuesForCreate(stmt, value)
case *[]map[string]interface{}:
values = ConvertSliceOfMapToValuesForCreate(stmt, *value)
return ConvertSliceOfMapToValuesForCreate(stmt, value)
default:
var (
selectColumns, restricted = stmt.SelectAndOmitColumns(true, false)
_, updateTrackTime = stmt.Get("gorm:update_track_time")
isZero bool
values = clause.Values{}
selectColumns, restricted = SelectAndOmitColumns(stmt, true, false)
curTime = stmt.DB.NowFunc()
isZero = false
)
stmt.Settings.Delete("gorm:update_track_time")
values = clause.Values{Columns: make([]clause.Column, 0, len(stmt.Schema.DBNames))}
for _, db := range stmt.Schema.DBNames {
if field := stmt.Schema.FieldsByDBName[db]; !field.HasDefaultValue || field.DefaultValueInterface != nil {
if v, ok := selectColumns[db]; (ok && v) || (!ok && (!restricted || field.AutoCreateTime > 0 || field.AutoUpdateTime > 0)) {
if stmt.Schema.FieldsWithDefaultDBValue[db] == nil {
if v, ok := selectColumns[db]; (ok && v) || (!ok && !restricted) {
values.Columns = append(values.Columns, clause.Column{Name: db})
}
}
@ -271,62 +239,43 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
switch stmt.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
rValLen := stmt.ReflectValue.Len()
if rValLen == 0 {
stmt.AddError(gorm.ErrEmptySlice)
return
}
stmt.SQL.Grow(rValLen * 18)
stmt.Vars = make([]interface{}, 0, rValLen*len(values.Columns))
values.Values = make([][]interface{}, rValLen)
defaultValueFieldsHavingValue := map[*schema.Field][]interface{}{}
for i := 0; i < rValLen; i++ {
values.Values = make([][]interface{}, stmt.ReflectValue.Len())
defaultValueFieldsHavingValue := map[string][]interface{}{}
for i := 0; i < stmt.ReflectValue.Len(); i++ {
rv := reflect.Indirect(stmt.ReflectValue.Index(i))
if !rv.IsValid() {
stmt.AddError(fmt.Errorf("slice data #%v is invalid: %w", i, gorm.ErrInvalidData))
return
}
values.Values[i] = make([]interface{}, len(values.Columns))
for idx, column := range values.Columns {
field := stmt.Schema.FieldsByDBName[column.Name]
if values.Values[i][idx], isZero = field.ValueOf(stmt.Context, rv); isZero {
if values.Values[i][idx], isZero = field.ValueOf(rv); isZero {
if field.DefaultValueInterface != nil {
values.Values[i][idx] = field.DefaultValueInterface
stmt.AddError(field.Set(stmt.Context, rv, field.DefaultValueInterface))
field.Set(rv, field.DefaultValueInterface)
} else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 {
stmt.AddError(field.Set(stmt.Context, rv, curTime))
values.Values[i][idx], _ = field.ValueOf(stmt.Context, rv)
field.Set(rv, curTime)
values.Values[i][idx], _ = field.ValueOf(rv)
}
} else if field.AutoUpdateTime > 0 && updateTrackTime {
stmt.AddError(field.Set(stmt.Context, rv, curTime))
values.Values[i][idx], _ = field.ValueOf(stmt.Context, rv)
}
}
for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
if rvOfvalue, isZero := field.ValueOf(stmt.Context, rv); !isZero {
if len(defaultValueFieldsHavingValue[field]) == 0 {
defaultValueFieldsHavingValue[field] = make([]interface{}, rValLen)
for db, field := range stmt.Schema.FieldsWithDefaultDBValue {
if v, ok := selectColumns[db]; (ok && v) || (!ok && !restricted) {
if v, isZero := field.ValueOf(rv); !isZero {
if len(defaultValueFieldsHavingValue[db]) == 0 {
defaultValueFieldsHavingValue[db] = make([]interface{}, stmt.ReflectValue.Len())
}
defaultValueFieldsHavingValue[field][i] = rvOfvalue
defaultValueFieldsHavingValue[db][i] = v
}
}
}
}
for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
if vs, ok := defaultValueFieldsHavingValue[field]; ok {
values.Columns = append(values.Columns, clause.Column{Name: field.DBName})
for idx := range values.Values {
if vs[idx] == nil {
values.Values[idx] = append(values.Values[idx], stmt.DefaultValueOf(field))
} else {
values.Values[idx] = append(values.Values[idx], vs[idx])
}
for db, vs := range defaultValueFieldsHavingValue {
values.Columns = append(values.Columns, clause.Column{Name: db})
for idx := range values.Values {
if vs[idx] == nil {
values.Values[idx] = append(values.Values[idx], clause.Expr{SQL: "DEFAULT"})
} else {
values.Values[idx] = append(values.Values[idx], vs[idx])
}
}
}
@ -334,79 +283,27 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
values.Values = [][]interface{}{make([]interface{}, len(values.Columns))}
for idx, column := range values.Columns {
field := stmt.Schema.FieldsByDBName[column.Name]
if values.Values[0][idx], isZero = field.ValueOf(stmt.Context, stmt.ReflectValue); isZero {
if values.Values[0][idx], isZero = field.ValueOf(stmt.ReflectValue); isZero {
if field.DefaultValueInterface != nil {
values.Values[0][idx] = field.DefaultValueInterface
stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, field.DefaultValueInterface))
field.Set(stmt.ReflectValue, field.DefaultValueInterface)
} else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 {
stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, curTime))
values.Values[0][idx], _ = field.ValueOf(stmt.Context, stmt.ReflectValue)
}
} else if field.AutoUpdateTime > 0 && updateTrackTime {
stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, curTime))
values.Values[0][idx], _ = field.ValueOf(stmt.Context, stmt.ReflectValue)
}
}
for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) && field.DefaultValueInterface == nil {
if rvOfvalue, isZero := field.ValueOf(stmt.Context, stmt.ReflectValue); !isZero {
values.Columns = append(values.Columns, clause.Column{Name: field.DBName})
values.Values[0] = append(values.Values[0], rvOfvalue)
field.Set(stmt.ReflectValue, curTime)
values.Values[0][idx], _ = field.ValueOf(stmt.ReflectValue)
}
}
}
default:
stmt.AddError(gorm.ErrInvalidData)
}
}
if c, ok := stmt.Clauses["ON CONFLICT"]; ok {
if onConflict, _ := c.Expression.(clause.OnConflict); onConflict.UpdateAll {
if stmt.Schema != nil && len(values.Columns) >= 1 {
selectColumns, restricted := stmt.SelectAndOmitColumns(true, true)
columns := make([]string, 0, len(values.Columns)-1)
for _, column := range values.Columns {
if field := stmt.Schema.LookUpField(column.Name); field != nil {
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
if !field.PrimaryKey && (!field.HasDefaultValue || field.DefaultValueInterface != nil ||
strings.EqualFold(field.DefaultValue, "NULL")) && field.AutoCreateTime == 0 {
if field.AutoUpdateTime > 0 {
assignment := clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: curTime}
switch field.AutoUpdateTime {
case schema.UnixNanosecond:
assignment.Value = curTime.UnixNano()
case schema.UnixMillisecond:
assignment.Value = curTime.UnixMilli()
case schema.UnixSecond:
assignment.Value = curTime.Unix()
}
onConflict.DoUpdates = append(onConflict.DoUpdates, assignment)
} else {
columns = append(columns, column.Name)
}
}
}
for db, field := range stmt.Schema.FieldsWithDefaultDBValue {
if v, ok := selectColumns[db]; (ok && v) || (!ok && !restricted) {
if v, isZero := field.ValueOf(stmt.ReflectValue); !isZero {
values.Columns = append(values.Columns, clause.Column{Name: db})
values.Values[0] = append(values.Values[0], v)
}
}
onConflict.DoUpdates = append(onConflict.DoUpdates, clause.AssignmentColumns(columns)...)
if len(onConflict.DoUpdates) == 0 {
onConflict.DoNothing = true
}
// use primary fields as default OnConflict columns
if len(onConflict.Columns) == 0 {
for _, field := range stmt.Schema.PrimaryFields {
onConflict.Columns = append(onConflict.Columns, clause.Column{Name: field.DBName})
}
}
stmt.AddClause(onConflict)
}
}
}
return values
return values
}
}
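`ConvertToCreateValues` also rewrites an `ON CONFLICT` clause with `UpdateAll` into concrete assignment columns, excluding primary keys, auto-create timestamps and database-default columns. A usage sketch of the clause as callers would pass it, assuming gorm v2, the sqlite dialector and a hypothetical `User` model:

```go
package main

import (
	"gorm.io/driver/sqlite"
	"gorm.io/gorm"
	"gorm.io/gorm/clause"
)

type User struct {
	ID   uint
	Name string
	Age  int
}

func main() {
	db, err := gorm.Open(sqlite.Open("test.db"), &gorm.Config{})
	if err != nil {
		panic(err)
	}
	_ = db.AutoMigrate(&User{})

	users := []User{{ID: 1, Name: "alice", Age: 18}, {ID: 2, Name: "bob", Age: 20}}
	db.Create(&users)

	// Do nothing when the row already exists.
	db.Clauses(clause.OnConflict{DoNothing: true}).Create(&users)

	// Update selected columns when the primary key already exists.
	db.Clauses(clause.OnConflict{
		Columns:   []clause.Column{{Name: "id"}},
		DoUpdates: clause.AssignmentColumns([]string{"name"}),
	}).Create(&users)

	// UpdateAll is expanded into assignments for the insertable, non-primary columns.
	db.Clauses(clause.OnConflict{UpdateAll: true}).Create(&users)
}
```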

View File

@ -1,71 +0,0 @@
package callbacks
import (
"reflect"
"sync"
"testing"
"time"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
)
var schemaCache = &sync.Map{}
func TestConvertToCreateValues_DestType_Slice(t *testing.T) {
type user struct {
ID int `gorm:"primaryKey"`
Name string
Email string `gorm:"default:(-)"`
Age int `gorm:"default:(-)"`
}
s, err := schema.Parse(&user{}, schemaCache, schema.NamingStrategy{})
if err != nil {
t.Errorf("parse schema error: %v, is not expected", err)
return
}
dest := []*user{
{
ID: 1,
Name: "alice",
Email: "email",
Age: 18,
},
{
ID: 2,
Name: "bob",
Email: "email",
Age: 19,
},
}
stmt := &gorm.Statement{
DB: &gorm.DB{
Config: &gorm.Config{
NowFunc: func() time.Time { return time.Time{} },
},
Statement: &gorm.Statement{
Settings: sync.Map{},
Schema: s,
},
},
ReflectValue: reflect.ValueOf(dest),
Dest: dest,
}
stmt.Schema = s
values := ConvertToCreateValues(stmt)
expected := clause.Values{
// columns that have values, plus default-value columns that have values (which should keep a stable order)
Columns: []clause.Column{{Name: "name"}, {Name: "email"}, {Name: "age"}, {Name: "id"}},
Values: [][]interface{}{
{"alice", "email", 18, 1},
{"bob", "email", 19, 2},
},
}
if !reflect.DeepEqual(expected, values) {
t.Errorf("expected: %v got %v", expected, values)
}
}

View File

@ -2,143 +2,60 @@ package callbacks
import (
"reflect"
"strings"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils"
)
func BeforeDelete(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.BeforeDelete {
callMethod(db, func(value interface{}, tx *gorm.DB) bool {
if i, ok := value.(BeforeDeleteInterface); ok {
db.AddError(i.BeforeDelete(tx))
return true
if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.BeforeDelete {
tx := db.Session(&gorm.Session{})
callMethod := func(value interface{}) bool {
if db.Statement.Schema.BeforeDelete {
if i, ok := value.(gorm.BeforeDeleteInterface); ok {
db.AddError(i.BeforeDelete(tx))
return true
}
}
return false
})
}
if ok := callMethod(db.Statement.Dest); !ok {
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
callMethod(db.Statement.ReflectValue.Index(i).Interface())
}
case reflect.Struct:
callMethod(db.Statement.ReflectValue.Interface())
}
}
}
}
func DeleteBeforeAssociations(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil {
selectColumns, restricted := db.Statement.SelectAndOmitColumns(true, false)
if !restricted {
return
}
for column, v := range selectColumns {
if !v {
continue
}
rel, ok := db.Statement.Schema.Relationships.Relations[column]
if !ok {
continue
}
switch rel.Type {
case schema.HasOne, schema.HasMany:
queryConds := rel.ToQueryConditions(db.Statement.Context, db.Statement.ReflectValue)
modelValue := reflect.New(rel.FieldSchema.ModelType).Interface()
tx := db.Session(&gorm.Session{NewDB: true}).Model(modelValue)
withoutConditions := false
if db.Statement.Unscoped {
tx = tx.Unscoped()
}
if len(db.Statement.Selects) > 0 {
selects := make([]string, 0, len(db.Statement.Selects))
for _, s := range db.Statement.Selects {
if s == clause.Associations {
selects = append(selects, s)
} else if columnPrefix := column + "."; strings.HasPrefix(s, columnPrefix) {
selects = append(selects, strings.TrimPrefix(s, columnPrefix))
}
}
if len(selects) > 0 {
tx = tx.Select(selects)
}
}
for _, cond := range queryConds {
if c, ok := cond.(clause.IN); ok && len(c.Values) == 0 {
withoutConditions = true
break
}
}
if !withoutConditions && db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil {
return
}
case schema.Many2Many:
var (
queryConds = make([]clause.Expression, 0, len(rel.References))
foreignFields = make([]*schema.Field, 0, len(rel.References))
relForeignKeys = make([]string, 0, len(rel.References))
modelValue = reflect.New(rel.JoinTable.ModelType).Interface()
table = rel.JoinTable.Table
tx = db.Session(&gorm.Session{NewDB: true}).Model(modelValue).Table(table)
)
for _, ref := range rel.References {
if ref.OwnPrimaryKey {
foreignFields = append(foreignFields, ref.PrimaryKey)
relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName)
} else if ref.PrimaryValue != "" {
queryConds = append(queryConds, clause.Eq{
Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName},
Value: ref.PrimaryValue,
})
}
}
_, foreignValues := schema.GetIdentityFieldValuesMap(db.Statement.Context, db.Statement.ReflectValue, foreignFields)
column, values := schema.ToQueryValues(table, relForeignKeys, foreignValues)
queryConds = append(queryConds, clause.IN{Column: column, Values: values})
if db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil {
return
}
}
}
}
}
func Delete(config *Config) func(db *gorm.DB) {
supportReturning := utils.Contains(config.DeleteClauses, "RETURNING")
return func(db *gorm.DB) {
if db.Error != nil {
return
}
if db.Statement.Schema != nil {
func Delete(db *gorm.DB) {
if db.Error == nil {
if db.Statement.Schema != nil && !db.Statement.Unscoped {
for _, c := range db.Statement.Schema.DeleteClauses {
db.Statement.AddClause(c)
}
}
if db.Statement.SQL.Len() == 0 {
db.Statement.SQL.Grow(100)
if db.Statement.SQL.String() == "" {
db.Statement.AddClauseIfNotExists(clause.Delete{})
if db.Statement.Schema != nil {
_, queryValues := schema.GetIdentityFieldValuesMap(db.Statement.Context, db.Statement.ReflectValue, db.Statement.Schema.PrimaryFields)
column, values := schema.ToQueryValues(db.Statement.Table, db.Statement.Schema.PrimaryFieldDBNames, queryValues)
_, queryValues := schema.GetIdentityFieldValuesMap(db.Statement.ReflectValue, db.Statement.Schema.PrimaryFields)
column, values := schema.ToQueryValues(db.Statement.Schema.PrimaryFieldDBNames, queryValues)
if len(values) > 0 {
db.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
}
if db.Statement.ReflectValue.CanAddr() && db.Statement.Dest != db.Statement.Model && db.Statement.Model != nil {
_, queryValues = schema.GetIdentityFieldValuesMap(db.Statement.Context, reflect.ValueOf(db.Statement.Model), db.Statement.Schema.PrimaryFields)
column, values = schema.ToQueryValues(db.Statement.Table, db.Statement.Schema.PrimaryFieldDBNames, queryValues)
if db.Statement.Dest != db.Statement.Model && db.Statement.Model != nil {
_, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(db.Statement.Model), db.Statement.Schema.PrimaryFields)
column, values = schema.ToQueryValues(db.Statement.Schema.PrimaryFieldDBNames, queryValues)
if len(values) > 0 {
db.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
@ -146,50 +63,49 @@ func Delete(config *Config) func(db *gorm.DB) {
}
}
db.Statement.AddClauseIfNotExists(clause.From{})
db.Statement.Build(db.Statement.BuildClauses...)
}
checkMissingWhereConditions(db)
if !db.DryRun && db.Error == nil {
ok, mode := hasReturning(db, supportReturning)
if !ok {
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
if db.AddError(err) == nil {
db.RowsAffected, _ = result.RowsAffected()
if db.Statement.Result != nil {
db.Statement.Result.Result = result
db.Statement.Result.RowsAffected = db.RowsAffected
}
}
if _, ok := db.Statement.Clauses["WHERE"]; !ok {
db.AddError(gorm.ErrMissingWhereClause)
return
}
if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil {
gorm.Scan(rows, db, mode)
db.Statement.AddClauseIfNotExists(clause.From{})
db.Statement.Build("DELETE", "FROM", "WHERE")
}
if db.Statement.Result != nil {
db.Statement.Result.RowsAffected = db.RowsAffected
}
db.AddError(rows.Close())
if !db.DryRun {
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
if err == nil {
db.RowsAffected, _ = result.RowsAffected()
} else {
db.AddError(err)
}
}
}
}
func AfterDelete(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.AfterDelete {
callMethod(db, func(value interface{}, tx *gorm.DB) bool {
if i, ok := value.(AfterDeleteInterface); ok {
db.AddError(i.AfterDelete(tx))
return true
if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.AfterDelete {
tx := db.Session(&gorm.Session{})
callMethod := func(value interface{}) bool {
if db.Statement.Schema.AfterDelete {
if i, ok := value.(gorm.AfterDeleteInterface); ok {
db.AddError(i.AfterDelete(tx))
return true
}
}
return false
})
}
if ok := callMethod(db.Statement.Dest); !ok {
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
callMethod(db.Statement.ReflectValue.Index(i).Interface())
}
case reflect.Struct:
callMethod(db.Statement.ReflectValue.Interface())
}
}
}
}
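`DeleteBeforeAssociations` only runs for associations that appear in the restricted select list, which is what `Select("Account")` or `Select(clause.Associations)` produces on a delete. A hedged sketch, assuming gorm v2, the sqlite dialector and hypothetical `User`/`Account` models:

```go
package main

import (
	"gorm.io/driver/sqlite"
	"gorm.io/gorm"
	"gorm.io/gorm/clause"
)

// Hypothetical models for this sketch.
type Account struct {
	ID     uint
	UserID uint
	Number string
}

type User struct {
	ID      uint
	Name    string
	Account Account
}

func main() {
	db, err := gorm.Open(sqlite.Open("test.db"), &gorm.Config{})
	if err != nil {
		panic(err)
	}
	_ = db.AutoMigrate(&User{}, &Account{})

	user := User{Name: "jinzhu", Account: Account{Number: "A-001"}}
	db.Create(&user)

	// Selecting the association makes DeleteBeforeAssociations remove the
	// dependent account row before the user row.
	db.Select("Account").Delete(&user)

	// clause.Associations cascades over every selectable association.
	db.Select(clause.Associations).Delete(&User{ID: user.ID})
}
```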

View File

@ -1,38 +1,80 @@
package callbacks
import (
"reflect"
"sort"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
// SelectAndOmitColumns gets the selected and omitted columns: select -> true, omit -> false
func SelectAndOmitColumns(stmt *gorm.Statement, requireCreate, requireUpdate bool) (map[string]bool, bool) {
results := map[string]bool{}
notRestricted := false
// select columns
for _, column := range stmt.Selects {
if column == "*" {
notRestricted = true
for _, dbName := range stmt.Schema.DBNames {
results[dbName] = true
}
break
}
if field := stmt.Schema.LookUpField(column); field != nil && field.DBName != "" {
results[field.DBName] = true
} else {
results[column] = true
}
}
// omit columns
for _, omit := range stmt.Omits {
if field := stmt.Schema.LookUpField(omit); field != nil && field.DBName != "" {
results[field.DBName] = false
} else {
results[omit] = false
}
}
if stmt.Schema != nil {
for _, field := range stmt.Schema.Fields {
name := field.DBName
if name == "" {
name = field.Name
}
if requireCreate && !field.Creatable {
results[name] = false
} else if requireUpdate && !field.Updatable {
results[name] = false
}
}
}
return results, !notRestricted && len(stmt.Selects) > 0
}
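`SelectAndOmitColumns` maps every selected column to `true`, every omitted column to `false`, and reports whether the statement is restricted to the selected set. A minimal sketch of how that surfaces in the public API, assuming gorm v2, the sqlite dialector and a hypothetical `User` model:

```go
package main

import (
	"gorm.io/driver/sqlite"
	"gorm.io/gorm"
)

type User struct {
	ID   uint
	Name string
	Age  int
}

func main() {
	db, err := gorm.Open(sqlite.Open("test.db"), &gorm.Config{})
	if err != nil {
		panic(err)
	}
	_ = db.AutoMigrate(&User{})

	user := User{Name: "alice", Age: 18}
	db.Create(&user)

	// Only the selected column is written; every other field is restricted.
	db.Model(&user).Select("name").Updates(User{Name: "bob", Age: 99}) // Age stays 18

	// Omit maps the named column to false, so it is skipped on insert.
	db.Omit("age").Create(&User{Name: "carol", Age: 30})
}
```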
// ConvertMapToValuesForCreate converts a map into create values
func ConvertMapToValuesForCreate(stmt *gorm.Statement, mapValue map[string]interface{}) (values clause.Values) {
values.Columns = make([]clause.Column, 0, len(mapValue))
selectColumns, restricted := stmt.SelectAndOmitColumns(true, false)
columns := make([]string, 0, len(mapValue))
selectColumns, restricted := SelectAndOmitColumns(stmt, true, false)
keys := make([]string, 0, len(mapValue))
for k := range mapValue {
var keys []string
for k, _ := range mapValue {
keys = append(keys, k)
}
sort.Strings(keys)
for _, k := range keys {
value := mapValue[k]
if stmt.Schema != nil {
if field := stmt.Schema.LookUpField(k); field != nil {
k = field.DBName
}
if field := stmt.Schema.LookUpField(k); field != nil {
k = field.DBName
}
if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) {
values.Columns = append(values.Columns, clause.Column{Name: k})
if len(values.Values) == 0 {
values.Values = [][]interface{}{{}}
}
columns = append(columns, k)
values.Values[0] = append(values.Values[0], value)
}
}
@ -41,26 +83,16 @@ func ConvertMapToValuesForCreate(stmt *gorm.Statement, mapValue map[string]inter
// ConvertSliceOfMapToValuesForCreate convert slice of map to values
func ConvertSliceOfMapToValuesForCreate(stmt *gorm.Statement, mapValues []map[string]interface{}) (values clause.Values) {
columns := make([]string, 0, len(mapValues))
// when the length of mapValues is zero, return directly here
// no need to call stmt.SelectAndOmitColumns method
if len(mapValues) == 0 {
stmt.AddError(gorm.ErrEmptySlice)
return
}
var (
result = make(map[string][]interface{}, len(mapValues))
selectColumns, restricted = stmt.SelectAndOmitColumns(true, false)
columns = []string{}
result = map[string][]interface{}{}
selectColumns, restricted = SelectAndOmitColumns(stmt, true, false)
)
for idx, mapValue := range mapValues {
for k, v := range mapValue {
if stmt.Schema != nil {
if field := stmt.Schema.LookUpField(k); field != nil {
k = field.DBName
}
if field := stmt.Schema.LookUpField(k); field != nil {
k = field.DBName
}
if _, ok := result[k]; !ok {
@ -78,75 +110,13 @@ func ConvertSliceOfMapToValuesForCreate(stmt *gorm.Statement, mapValues []map[st
sort.Strings(columns)
values.Values = make([][]interface{}, len(mapValues))
values.Columns = make([]clause.Column, len(columns))
for idx, column := range columns {
values.Columns[idx] = clause.Column{Name: column}
for i, v := range result[column] {
if len(values.Values[i]) == 0 {
if i == 0 {
values.Values[i] = make([]interface{}, len(columns))
}
values.Values[i][idx] = v
}
}
return
}
func hasReturning(tx *gorm.DB, supportReturning bool) (bool, gorm.ScanMode) {
if supportReturning {
if c, ok := tx.Statement.Clauses["RETURNING"]; ok {
returning, _ := c.Expression.(clause.Returning)
if len(returning.Columns) == 0 || (len(returning.Columns) == 1 && returning.Columns[0].Name == "*") {
return true, 0
}
return true, gorm.ScanUpdate
}
}
return false, 0
}
func checkMissingWhereConditions(db *gorm.DB) {
if !db.AllowGlobalUpdate && db.Error == nil {
where, withCondition := db.Statement.Clauses["WHERE"]
if withCondition {
if _, withSoftDelete := db.Statement.Clauses["soft_delete_enabled"]; withSoftDelete {
whereClause, _ := where.Expression.(clause.Where)
withCondition = len(whereClause.Exprs) > 1
}
}
if !withCondition {
db.AddError(gorm.ErrMissingWhereClause)
}
return
}
}
type visitMap = map[reflect.Value]bool
// Check for circular values; return true if the value was already loaded
func loadOrStoreVisitMap(visitMap *visitMap, v reflect.Value) (loaded bool) {
if v.Kind() == reflect.Ptr {
v = v.Elem()
}
switch v.Kind() {
case reflect.Slice, reflect.Array:
loaded = true
for i := 0; i < v.Len(); i++ {
if !loadOrStoreVisitMap(visitMap, v.Index(i)) {
loaded = false
}
}
case reflect.Struct, reflect.Interface:
if v.CanAddr() {
p := v.Addr()
if _, ok := (*visitMap)[p]; ok {
return true
}
(*visitMap)[p] = true
}
}
return
}
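A hedged illustration of the caller-facing behaviour these helpers implement, assuming an already-opened `*gorm.DB` and a hypothetical `User` model: `SelectAndOmitColumns` decides which keys survive, and `ConvertMapToValuesForCreate` turns the remaining map entries into the `clause.Values` used by the INSERT.

```go
package example

import "gorm.io/gorm"

// User is a hypothetical model used only for illustration.
type User struct {
	ID   uint
	Name string
	Age  int
}

// createFromMap creates a row from a map; only the selected column survives
// the select/omit filtering performed by the helpers above.
func createFromMap(db *gorm.DB) error {
	return db.Model(&User{}).
		Select("name"). // "age" is dropped by the select/omit filtering
		Create(map[string]interface{}{"name": "jinzhu", "age": 18}).
		Error
}
```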

View File

@ -1,157 +0,0 @@
package callbacks
import (
"reflect"
"testing"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
func TestLoadOrStoreVisitMap(t *testing.T) {
var vm visitMap
var loaded bool
type testM struct {
Name string
}
t1 := testM{Name: "t1"}
t2 := testM{Name: "t2"}
t3 := testM{Name: "t3"}
vm = make(visitMap)
if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf(&t1)); loaded {
t.Fatalf("loaded should be false")
}
if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf(&t1)); !loaded {
t.Fatalf("loaded should be true")
}
// t1 already exists but t2 does not
if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf([]*testM{&t1, &t2, &t3})); loaded {
t.Fatalf("loaded should be false")
}
if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf([]*testM{&t2, &t3})); !loaded {
t.Fatalf("loaded should be true")
}
}
func TestConvertMapToValuesForCreate(t *testing.T) {
testCase := []struct {
name string
input map[string]interface{}
expect clause.Values
}{
{
name: "Test convert string value",
input: map[string]interface{}{
"name": "my name",
},
expect: clause.Values{
Columns: []clause.Column{{Name: "name"}},
Values: [][]interface{}{{"my name"}},
},
},
{
name: "Test convert int value",
input: map[string]interface{}{
"age": 18,
},
expect: clause.Values{
Columns: []clause.Column{{Name: "age"}},
Values: [][]interface{}{{18}},
},
},
{
name: "Test convert float value",
input: map[string]interface{}{
"score": 99.5,
},
expect: clause.Values{
Columns: []clause.Column{{Name: "score"}},
Values: [][]interface{}{{99.5}},
},
},
{
name: "Test convert bool value",
input: map[string]interface{}{
"active": true,
},
expect: clause.Values{
Columns: []clause.Column{{Name: "active"}},
Values: [][]interface{}{{true}},
},
},
}
for _, tc := range testCase {
t.Run(tc.name, func(t *testing.T) {
actual := ConvertMapToValuesForCreate(&gorm.Statement{}, tc.input)
if !reflect.DeepEqual(actual, tc.expect) {
t.Errorf("expect %v got %v", tc.expect, actual)
}
})
}
}
func TestConvertSliceOfMapToValuesForCreate(t *testing.T) {
testCase := []struct {
name string
input []map[string]interface{}
expect clause.Values
}{
{
name: "Test convert slice of string value",
input: []map[string]interface{}{
{"name": "my name"},
},
expect: clause.Values{
Columns: []clause.Column{{Name: "name"}},
Values: [][]interface{}{{"my name"}},
},
},
{
name: "Test convert slice of int value",
input: []map[string]interface{}{
{"age": 18},
},
expect: clause.Values{
Columns: []clause.Column{{Name: "age"}},
Values: [][]interface{}{{18}},
},
},
{
name: "Test convert slice of float value",
input: []map[string]interface{}{
{"score": 99.5},
},
expect: clause.Values{
Columns: []clause.Column{{Name: "score"}},
Values: [][]interface{}{{99.5}},
},
},
{
name: "Test convert slice of bool value",
input: []map[string]interface{}{
{"active": true},
},
expect: clause.Values{
Columns: []clause.Column{{Name: "active"}},
Values: [][]interface{}{{true}},
},
},
}
for _, tc := range testCase {
t.Run(tc.name, func(t *testing.T) {
actual := ConvertSliceOfMapToValuesForCreate(&gorm.Statement{}, tc.input)
if !reflect.DeepEqual(actual, tc.expect) {
t.Errorf("expected %v but got %v", tc.expect, actual)
}
})
}
}

11
callbacks/interface.go Normal file
View File

@ -0,0 +1,11 @@
package callbacks
import "gorm.io/gorm"
type beforeSaveInterface interface {
BeforeSave(*gorm.DB) error
}
type beforeCreateInterface interface {
BeforeCreate(*gorm.DB) error
}

View File

@ -1,39 +0,0 @@
package callbacks
import "gorm.io/gorm"
type BeforeCreateInterface interface {
BeforeCreate(*gorm.DB) error
}
type AfterCreateInterface interface {
AfterCreate(*gorm.DB) error
}
type BeforeUpdateInterface interface {
BeforeUpdate(*gorm.DB) error
}
type AfterUpdateInterface interface {
AfterUpdate(*gorm.DB) error
}
type BeforeSaveInterface interface {
BeforeSave(*gorm.DB) error
}
type AfterSaveInterface interface {
AfterSave(*gorm.DB) error
}
type BeforeDeleteInterface interface {
BeforeDelete(*gorm.DB) error
}
type AfterDeleteInterface interface {
AfterDelete(*gorm.DB) error
}
type AfterFindInterface interface {
AfterFind(*gorm.DB) error
}
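For orientation, a small sketch of how a model opts into these interfaces; the `Product` model and the rules it enforces are made up for illustration. GORM detects the methods through the type assertions shown in the callbacks above.

```go
package example

import (
	"errors"

	"gorm.io/gorm"
)

// Product is a hypothetical model used only for illustration.
type Product struct {
	ID    uint
	Code  string
	Price int
}

// BeforeCreate matches BeforeCreateInterface; returning an error aborts the insert.
func (p *Product) BeforeCreate(tx *gorm.DB) error {
	if p.Code == "" {
		return errors.New("product code is required")
	}
	return nil
}

// AfterFind matches AfterFindInterface and runs after rows are scanned.
func (p *Product) AfterFind(tx *gorm.DB) error {
	if p.Price == 0 {
		p.Price = 100 // hypothetical default value
	}
	return nil
}
```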

View File

@ -1,10 +1,7 @@
package callbacks
import (
"fmt"
"reflect"
"sort"
"strings"
"gorm.io/gorm"
"gorm.io/gorm/clause"
@ -12,179 +9,11 @@ import (
"gorm.io/gorm/utils"
)
// parsePreloadMap extracts nested preloads. e.g.
//
// // schema has a "k0" relation and a "k7.k8" embedded relation
// parsePreloadMap(schema, map[string][]interface{}{
// clause.Associations: {"arg1"},
// "k1": {"arg2"},
// "k2.k3": {"arg3"},
// "k4.k5.k6": {"arg4"},
// })
// // preloadMap is
// map[string]map[string][]interface{}{
// "k0": {},
// "k7": {
// "k8": {},
// },
// "k1": {},
// "k2": {
// "k3": {"arg3"},
// },
// "k4": {
// "k5.k6": {"arg4"},
// },
// }
func parsePreloadMap(s *schema.Schema, preloads map[string][]interface{}) map[string]map[string][]interface{} {
preloadMap := map[string]map[string][]interface{}{}
setPreloadMap := func(name, value string, args []interface{}) {
if _, ok := preloadMap[name]; !ok {
preloadMap[name] = map[string][]interface{}{}
}
if value != "" {
preloadMap[name][value] = args
}
}
for name, args := range preloads {
preloadFields := strings.Split(name, ".")
value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), ".")
if preloadFields[0] == clause.Associations {
for _, relation := range s.Relationships.Relations {
if relation.Schema == s {
setPreloadMap(relation.Name, value, args)
}
}
for embedded, embeddedRelations := range s.Relationships.EmbeddedRelations {
for _, value := range embeddedValues(embeddedRelations) {
setPreloadMap(embedded, value, args)
}
}
} else {
setPreloadMap(preloadFields[0], value, args)
}
}
return preloadMap
}
func embeddedValues(embeddedRelations *schema.Relationships) []string {
if embeddedRelations == nil {
return nil
}
names := make([]string, 0, len(embeddedRelations.Relations)+len(embeddedRelations.EmbeddedRelations))
for _, relation := range embeddedRelations.Relations {
// skip first struct name
names = append(names, strings.Join(relation.Field.EmbeddedBindNames[1:], "."))
}
for _, relations := range embeddedRelations.EmbeddedRelations {
names = append(names, embeddedValues(relations)...)
}
return names
}
// preloadEntryPoint enters layer by layer. It will call real preload if it finds the right entry point.
// If the current relationship is embedded or joined, the current query will be ignored.
//
//nolint:cyclop
func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relationships, preloads map[string][]interface{}, associationsConds []interface{}) error {
preloadMap := parsePreloadMap(db.Statement.Schema, preloads)
// avoid random traversal of the map
preloadNames := make([]string, 0, len(preloadMap))
for key := range preloadMap {
preloadNames = append(preloadNames, key)
}
sort.Strings(preloadNames)
isJoined := func(name string) (joined bool, nestedJoins []string) {
for _, join := range joins {
if _, ok := relationships.Relations[join]; ok && name == join {
joined = true
continue
}
join0, join1, cut := strings.Cut(join, ".")
if cut {
if _, ok := relationships.Relations[join0]; ok && name == join0 {
joined = true
nestedJoins = append(nestedJoins, join1)
}
}
}
return joined, nestedJoins
}
for _, name := range preloadNames {
if relations := relationships.EmbeddedRelations[name]; relations != nil {
if err := preloadEntryPoint(db, joins, relations, preloadMap[name], associationsConds); err != nil {
return err
}
} else if rel := relationships.Relations[name]; rel != nil {
if joined, nestedJoins := isJoined(name); joined {
switch rv := db.Statement.ReflectValue; rv.Kind() {
case reflect.Slice, reflect.Array:
if rv.Len() > 0 {
reflectValue := rel.FieldSchema.MakeSlice().Elem()
for i := 0; i < rv.Len(); i++ {
frv := rel.Field.ReflectValueOf(db.Statement.Context, rv.Index(i))
if frv.Kind() != reflect.Ptr {
reflectValue = reflect.Append(reflectValue, frv.Addr())
} else {
if frv.IsNil() {
continue
}
reflectValue = reflect.Append(reflectValue, frv)
}
}
tx := preloadDB(db, reflectValue, reflectValue.Interface())
if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil {
return err
}
}
case reflect.Struct, reflect.Pointer:
reflectValue := rel.Field.ReflectValueOf(db.Statement.Context, rv)
tx := preloadDB(db, reflectValue, reflectValue.Interface())
if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil {
return err
}
default:
return gorm.ErrInvalidData
}
} else {
tx := db.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks})
tx.Statement.ReflectValue = db.Statement.ReflectValue
tx.Statement.Unscoped = db.Statement.Unscoped
if err := preload(tx, rel, append(preloads[name], associationsConds...), preloadMap[name]); err != nil {
return err
}
}
} else {
return fmt.Errorf("%s: %w for schema %s", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name)
}
}
return nil
}
func preloadDB(db *gorm.DB, reflectValue reflect.Value, dest interface{}) *gorm.DB {
tx := db.Session(&gorm.Session{Context: db.Statement.Context, NewDB: true, SkipHooks: db.Statement.SkipHooks, Initialized: true})
db.Statement.Settings.Range(func(k, v interface{}) bool {
tx.Statement.Settings.Store(k, v)
return true
})
if err := tx.Statement.Parse(dest); err != nil {
tx.AddError(err)
return tx
}
tx.Statement.ReflectValue = reflectValue
tx.Statement.Unscoped = db.Statement.Unscoped
return tx
}
func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) error {
func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) {
var (
reflectValue = tx.Statement.ReflectValue
reflectValue = db.Statement.ReflectValue
rel = rels[len(rels)-1]
tx = db.Session(&gorm.Session{})
relForeignKeys []string
relForeignFields []*schema.Field
foreignFields []*schema.Field
@ -193,13 +22,13 @@ func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
inlineConds []interface{}
)
if rel.JoinTable != nil {
var (
joinForeignFields = make([]*schema.Field, 0, len(rel.References))
joinRelForeignFields = make([]*schema.Field, 0, len(rel.References))
joinForeignKeys = make([]string, 0, len(rel.References))
)
if len(rels) > 1 {
reflectValue = schema.GetRelationsValues(reflectValue, rels[:len(rels)-1])
}
if rel.JoinTable != nil {
var joinForeignFields, joinRelForeignFields []*schema.Field
var joinForeignKeys []string
for _, ref := range rel.References {
if ref.OwnPrimaryKey {
joinForeignKeys = append(joinForeignKeys, ref.ForeignKey.DBName)
@ -214,28 +43,25 @@ func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
}
}
joinIdentityMap, joinForeignValues := schema.GetIdentityFieldValuesMap(tx.Statement.Context, reflectValue, foreignFields)
joinIdentityMap, joinForeignValues := schema.GetIdentityFieldValuesMap(reflectValue, foreignFields)
if len(joinForeignValues) == 0 {
return nil
return
}
joinResults := rel.JoinTable.MakeSlice().Elem()
column, values := schema.ToQueryValues(clause.CurrentTable, joinForeignKeys, joinForeignValues)
if err := tx.Where(clause.IN{Column: column, Values: values}).Find(joinResults.Addr().Interface()).Error; err != nil {
return err
}
column, values := schema.ToQueryValues(joinForeignKeys, joinForeignValues)
tx.Where(clause.IN{Column: column, Values: values}).Find(joinResults.Addr().Interface())
// convert join identity map to relation identity map
fieldValues := make([]interface{}, len(joinForeignFields))
joinFieldValues := make([]interface{}, len(joinRelForeignFields))
for i := 0; i < joinResults.Len(); i++ {
joinIndexValue := joinResults.Index(i)
for idx, field := range joinForeignFields {
fieldValues[idx], _ = field.ValueOf(tx.Statement.Context, joinIndexValue)
fieldValues[idx], _ = field.ValueOf(joinResults.Index(i))
}
for idx, field := range joinRelForeignFields {
joinFieldValues[idx], _ = field.ValueOf(tx.Statement.Context, joinIndexValue)
joinFieldValues[idx], _ = field.ValueOf(joinResults.Index(i))
}
if results, ok := joinIdentityMap[utils.ToStringKey(fieldValues...)]; ok {
@ -244,7 +70,7 @@ func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
}
}
_, foreignValues = schema.GetIdentityFieldValuesMap(tx.Statement.Context, joinResults, joinRelForeignFields)
_, foreignValues = schema.GetIdentityFieldValuesMap(joinResults, joinRelForeignFields)
} else {
for _, ref := range rel.References {
if ref.OwnPrimaryKey {
@ -260,75 +86,35 @@ func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
}
}
identityMap, foreignValues = schema.GetIdentityFieldValuesMap(tx.Statement.Context, reflectValue, foreignFields)
identityMap, foreignValues = schema.GetIdentityFieldValuesMap(reflectValue, foreignFields)
if len(foreignValues) == 0 {
return nil
return
}
}
// nested preload
for p, pvs := range preloads {
tx = tx.Preload(p, pvs...)
}
reflectResults := rel.FieldSchema.MakeSlice().Elem()
column, values := schema.ToQueryValues(clause.CurrentTable, relForeignKeys, foreignValues)
column, values := schema.ToQueryValues(relForeignKeys, foreignValues)
if len(values) != 0 {
tx = tx.Model(reflectResults.Addr().Interface()).Where(clause.IN{Column: column, Values: values})
for _, cond := range conds {
if fc, ok := cond.(func(*gorm.DB) *gorm.DB); ok {
tx = fc(tx)
} else {
inlineConds = append(inlineConds, cond)
}
}
if len(inlineConds) > 0 {
tx = tx.Where(inlineConds[0], inlineConds[1:]...)
}
if err := tx.Find(reflectResults.Addr().Interface()).Error; err != nil {
return err
for _, cond := range conds {
if fc, ok := cond.(func(*gorm.DB) *gorm.DB); ok {
tx = fc(tx)
} else {
inlineConds = append(inlineConds, cond)
}
}
tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), inlineConds...)
fieldValues := make([]interface{}, len(relForeignFields))
// clean up old values before preloading
switch reflectValue.Kind() {
case reflect.Struct:
switch rel.Type {
case schema.HasMany, schema.Many2Many:
tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue, reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface()))
default:
tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue, reflect.New(rel.Field.FieldType).Interface()))
}
case reflect.Slice, reflect.Array:
for i := 0; i < reflectValue.Len(); i++ {
switch rel.Type {
case schema.HasMany, schema.Many2Many:
tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue.Index(i), reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface()))
default:
tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue.Index(i), reflect.New(rel.Field.FieldType).Interface()))
}
}
}
for i := 0; i < reflectResults.Len(); i++ {
elem := reflectResults.Index(i)
for idx, field := range relForeignFields {
fieldValues[idx], _ = field.ValueOf(tx.Statement.Context, elem)
fieldValues[idx], _ = field.ValueOf(elem)
}
datas, ok := identityMap[utils.ToStringKey(fieldValues...)]
if !ok {
return fmt.Errorf("failed to assign association %#v, make sure foreign fields exists", elem.Interface())
}
for _, data := range datas {
reflectFieldValue := rel.Field.ReflectValueOf(tx.Statement.Context, data)
for _, data := range identityMap[utils.ToStringKey(fieldValues...)] {
reflectFieldValue := rel.Field.ReflectValueOf(data)
if reflectFieldValue.Kind() == reflect.Ptr && reflectFieldValue.IsNil() {
reflectFieldValue.Set(reflect.New(rel.Field.FieldType.Elem()))
}
@ -336,16 +122,14 @@ func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
reflectFieldValue = reflect.Indirect(reflectFieldValue)
switch reflectFieldValue.Kind() {
case reflect.Struct:
tx.AddError(rel.Field.Set(tx.Statement.Context, data, elem.Interface()))
rel.Field.Set(data, reflectResults.Index(i).Interface())
case reflect.Slice, reflect.Array:
if reflectFieldValue.Type().Elem().Kind() == reflect.Ptr {
tx.AddError(rel.Field.Set(tx.Statement.Context, data, reflect.Append(reflectFieldValue, elem).Interface()))
rel.Field.Set(data, reflect.Append(reflectFieldValue, elem).Interface())
} else {
tx.AddError(rel.Field.Set(tx.Statement.Context, data, reflect.Append(reflectFieldValue, elem.Elem()).Interface()))
rel.Field.Set(data, reflect.Append(reflectFieldValue, elem.Elem()).Interface())
}
}
}
}
return tx.Error
}
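To ground the code above, a minimal usage sketch (the models and conditions are assumptions): each `Preload` level becomes one extra query, and its results are matched back onto the parent rows through the identity map built in `preload`.

```go
package example

import "gorm.io/gorm"

// Order and User are hypothetical models used only for illustration.
type Order struct {
	ID     uint
	UserID uint
	State  string
}

type User struct {
	ID     uint
	Name   string
	Orders []Order
}

// loadUsersWithOrders preloads the has-many relation with an inline condition;
// the orders are fetched in a second query and matched back by user_id.
func loadUsersWithOrders(db *gorm.DB) ([]User, error) {
	var users []User
	err := db.Preload("Orders", "state = ?", "paid").Find(&users).Error
	return users, err
}
```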

View File

@ -3,312 +3,211 @@ package callbacks
import (
"fmt"
"reflect"
"sort"
"strings"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils"
)
func Query(db *gorm.DB) {
if db.Error == nil {
BuildQuerySQL(db)
if db.Statement.Schema != nil && !db.Statement.Unscoped {
for _, c := range db.Statement.Schema.QueryClauses {
db.Statement.AddClause(c)
}
}
if !db.DryRun && db.Error == nil {
if db.Statement.SQL.String() == "" {
BuildQuerySQL(db)
}
if !db.DryRun {
rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
if err != nil {
db.AddError(err)
return
}
defer func() {
db.AddError(rows.Close())
}()
gorm.Scan(rows, db, 0)
defer rows.Close()
if db.Statement.Result != nil {
db.Statement.Result.RowsAffected = db.RowsAffected
}
gorm.Scan(rows, db, false)
}
}
}
func BuildQuerySQL(db *gorm.DB) {
if db.Statement.Schema != nil {
for _, c := range db.Statement.Schema.QueryClauses {
db.Statement.AddClause(c)
clauseSelect := clause.Select{}
if db.Statement.ReflectValue.Kind() == reflect.Struct {
var conds []clause.Expression
for _, primaryField := range db.Statement.Schema.PrimaryFields {
if v, isZero := primaryField.ValueOf(db.Statement.ReflectValue); !isZero {
conds = append(conds, clause.Eq{Column: clause.Column{Table: db.Statement.Table, Name: primaryField.DBName}, Value: v})
}
}
if len(conds) > 0 {
db.Statement.AddClause(clause.Where{Exprs: conds})
}
}
if db.Statement.SQL.Len() == 0 {
db.Statement.SQL.Grow(100)
clauseSelect := clause.Select{Distinct: db.Statement.Distinct}
if db.Statement.ReflectValue.Kind() == reflect.Struct && db.Statement.ReflectValue.Type() == db.Statement.Schema.ModelType {
var conds []clause.Expression
for _, primaryField := range db.Statement.Schema.PrimaryFields {
if v, isZero := primaryField.ValueOf(db.Statement.Context, db.Statement.ReflectValue); !isZero {
conds = append(conds, clause.Eq{Column: clause.Column{Table: db.Statement.Table, Name: primaryField.DBName}, Value: v})
}
}
if len(conds) > 0 {
db.Statement.AddClause(clause.Where{Exprs: conds})
if len(db.Statement.Selects) > 0 {
for _, name := range db.Statement.Selects {
if db.Statement.Schema == nil {
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{
Name: name,
Raw: true,
})
} else if f := db.Statement.Schema.LookUpField(name); f != nil {
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{
Name: f.DBName,
})
} else {
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{
Name: name,
Raw: true,
})
}
}
}
if len(db.Statement.Selects) > 0 {
clauseSelect.Columns = make([]clause.Column, len(db.Statement.Selects))
for idx, name := range db.Statement.Selects {
if db.Statement.Schema == nil {
clauseSelect.Columns[idx] = clause.Column{Name: name, Raw: true}
} else if f := db.Statement.Schema.LookUpField(name); f != nil {
clauseSelect.Columns[idx] = clause.Column{Name: f.DBName}
} else {
clauseSelect.Columns[idx] = clause.Column{Name: name, Raw: true}
}
}
} else if db.Statement.Schema != nil && len(db.Statement.Omits) > 0 {
selectColumns, _ := db.Statement.SelectAndOmitColumns(false, false)
clauseSelect.Columns = make([]clause.Column, 0, len(db.Statement.Schema.DBNames))
// inline joins
if len(db.Statement.Joins) != 0 {
joins := []clause.Join{}
if len(db.Statement.Selects) == 0 {
for _, dbName := range db.Statement.Schema.DBNames {
if v, ok := selectColumns[dbName]; (ok && v) || !ok {
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{Table: db.Statement.Table, Name: dbName})
}
}
} else if db.Statement.Schema != nil && db.Statement.ReflectValue.IsValid() {
queryFields := db.QueryFields
if !queryFields {
switch db.Statement.ReflectValue.Kind() {
case reflect.Struct:
queryFields = db.Statement.ReflectValue.Type() != db.Statement.Schema.ModelType
case reflect.Slice:
queryFields = db.Statement.ReflectValue.Type().Elem() != db.Statement.Schema.ModelType
}
}
if queryFields {
stmt := gorm.Statement{DB: db}
// smaller struct
if err := stmt.Parse(db.Statement.Dest); err == nil && (db.QueryFields || stmt.Schema.ModelType != db.Statement.Schema.ModelType) {
clauseSelect.Columns = make([]clause.Column, len(stmt.Schema.DBNames))
for idx, dbName := range stmt.Schema.DBNames {
clauseSelect.Columns[idx] = clause.Column{Table: db.Statement.Table, Name: dbName}
}
}
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{
Table: db.Statement.Table,
Name: dbName,
})
}
}
// inline joins
fromClause := clause.From{}
if v, ok := db.Statement.Clauses["FROM"].Expression.(clause.From); ok {
fromClause = v
}
for name, conds := range db.Statement.Joins {
if db.Statement.Schema == nil {
joins = append(joins, clause.Join{
Expression: clause.Expr{SQL: name, Vars: conds},
})
} else if relation, ok := db.Statement.Schema.Relationships.Relations[name]; ok {
tableAliasName := relation.Name
if len(db.Statement.Joins) != 0 || len(fromClause.Joins) != 0 {
if len(db.Statement.Selects) == 0 && len(db.Statement.Omits) == 0 && db.Statement.Schema != nil {
clauseSelect.Columns = make([]clause.Column, len(db.Statement.Schema.DBNames))
for idx, dbName := range db.Statement.Schema.DBNames {
clauseSelect.Columns[idx] = clause.Column{Table: db.Statement.Table, Name: dbName}
}
}
specifiedRelationsName := map[string]string{clause.CurrentTable: clause.CurrentTable}
for _, join := range db.Statement.Joins {
if db.Statement.Schema != nil {
var isRelations bool // is relations or raw sql
var relations []*schema.Relationship
relation, ok := db.Statement.Schema.Relationships.Relations[join.Name]
if ok {
isRelations = true
relations = append(relations, relation)
} else {
// handle nested join like "Manager.Company"
nestedJoinNames := strings.Split(join.Name, ".")
if len(nestedJoinNames) > 1 {
isNestedJoin := true
guessNestedRelations := make([]*schema.Relationship, 0, len(nestedJoinNames))
currentRelations := db.Statement.Schema.Relationships.Relations
for _, relname := range nestedJoinNames {
// incomplete match, only treated as raw sql
if relation, ok = currentRelations[relname]; ok {
guessNestedRelations = append(guessNestedRelations, relation)
currentRelations = relation.FieldSchema.Relationships.Relations
} else {
isNestedJoin = false
break
}
}
if isNestedJoin {
isRelations = true
relations = guessNestedRelations
}
}
}
if isRelations {
genJoinClause := func(joinType clause.JoinType, tableAliasName string, parentTableName string, relation *schema.Relationship) clause.Join {
columnStmt := gorm.Statement{
Table: tableAliasName, DB: db, Schema: relation.FieldSchema,
Selects: join.Selects, Omits: join.Omits,
}
selectColumns, restricted := columnStmt.SelectAndOmitColumns(false, false)
for _, s := range relation.FieldSchema.DBNames {
if v, ok := selectColumns[s]; (ok && v) || (!ok && !restricted) {
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{
Table: tableAliasName,
Name: s,
Alias: utils.NestedRelationName(tableAliasName, s),
})
}
}
if join.Expression != nil {
return clause.Join{
Type: join.JoinType,
Expression: join.Expression,
}
}
exprs := make([]clause.Expression, len(relation.References))
for idx, ref := range relation.References {
if ref.OwnPrimaryKey {
exprs[idx] = clause.Eq{
Column: clause.Column{Table: parentTableName, Name: ref.PrimaryKey.DBName},
Value: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
}
} else {
if ref.PrimaryValue == "" {
exprs[idx] = clause.Eq{
Column: clause.Column{Table: parentTableName, Name: ref.ForeignKey.DBName},
Value: clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName},
}
} else {
exprs[idx] = clause.Eq{
Column: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
Value: ref.PrimaryValue,
}
}
}
}
{
onStmt := gorm.Statement{Table: tableAliasName, DB: db, Clauses: map[string]clause.Clause{}}
for _, c := range relation.FieldSchema.QueryClauses {
onStmt.AddClause(c)
}
if join.On != nil {
onStmt.AddClause(join.On)
}
if cs, ok := onStmt.Clauses["WHERE"]; ok {
if where, ok := cs.Expression.(clause.Where); ok {
where.Build(&onStmt)
if onSQL := onStmt.SQL.String(); onSQL != "" {
vars := onStmt.Vars
for idx, v := range vars {
bindvar := strings.Builder{}
onStmt.Vars = vars[0 : idx+1]
db.Dialector.BindVarTo(&bindvar, &onStmt, v)
onSQL = strings.Replace(onSQL, bindvar.String(), "?", 1)
}
exprs = append(exprs, clause.Expr{SQL: onSQL, Vars: vars})
}
}
}
}
return clause.Join{
Type: joinType,
Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName},
ON: clause.Where{Exprs: exprs},
}
}
parentTableName := clause.CurrentTable
for idx, rel := range relations {
// joins table alias like "Manager, Company, Manager__Company"
curAliasName := rel.Name
if parentTableName != clause.CurrentTable {
curAliasName = utils.NestedRelationName(parentTableName, curAliasName)
}
if _, ok := specifiedRelationsName[curAliasName]; !ok {
aliasName := curAliasName
if idx == len(relations)-1 && join.Alias != "" {
aliasName = join.Alias
}
fromClause.Joins = append(fromClause.Joins, genJoinClause(join.JoinType, aliasName, specifiedRelationsName[parentTableName], rel))
specifiedRelationsName[curAliasName] = aliasName
}
parentTableName = curAliasName
}
} else {
fromClause.Joins = append(fromClause.Joins, clause.Join{
Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds},
})
}
} else {
fromClause.Joins = append(fromClause.Joins, clause.Join{
Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds},
for _, s := range relation.FieldSchema.DBNames {
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{
Table: tableAliasName,
Name: s,
Alias: tableAliasName + "__" + s,
})
}
}
db.Statement.AddClause(fromClause)
} else {
db.Statement.AddClauseIfNotExists(clause.From{})
var exprs []clause.Expression
for _, ref := range relation.References {
if ref.OwnPrimaryKey {
exprs = append(exprs, clause.Eq{
Column: clause.Column{Table: db.Statement.Schema.Table, Name: ref.PrimaryKey.DBName},
Value: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
})
} else {
if ref.PrimaryValue == "" {
exprs = append(exprs, clause.Eq{
Column: clause.Column{Table: db.Statement.Schema.Table, Name: ref.ForeignKey.DBName},
Value: clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName},
})
} else {
exprs = append(exprs, clause.Eq{
Column: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
Value: ref.PrimaryValue,
})
}
}
}
joins = append(joins, clause.Join{
Type: clause.LeftJoin,
Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName},
ON: clause.Where{Exprs: exprs},
})
} else {
joins = append(joins, clause.Join{
Expression: clause.Expr{SQL: name, Vars: conds},
})
}
}
db.Statement.AddClauseIfNotExists(clauseSelect)
db.Statement.Build(db.Statement.BuildClauses...)
db.Statement.AddClause(clause.From{Joins: joins})
} else {
db.Statement.AddClauseIfNotExists(clause.From{})
}
db.Statement.AddClauseIfNotExists(clauseSelect)
db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR")
}
func Preload(db *gorm.DB) {
if db.Error == nil && len(db.Statement.Preloads) > 0 {
if db.Statement.Schema == nil {
db.AddError(fmt.Errorf("%w when using preload", gorm.ErrModelValueRequired))
return
}
if db.Error == nil {
if len(db.Statement.Preloads) > 0 {
preloadMap := map[string][]string{}
for name := range db.Statement.Preloads {
preloadFields := strings.Split(name, ".")
for idx := range preloadFields {
preloadMap[strings.Join(preloadFields[:idx+1], ".")] = preloadFields[:idx+1]
}
}
joins := make([]string, 0, len(db.Statement.Joins))
for _, join := range db.Statement.Joins {
joins = append(joins, join.Name)
}
preloadNames := make([]string, len(preloadMap))
idx := 0
for key := range preloadMap {
preloadNames[idx] = key
idx++
}
sort.Strings(preloadNames)
tx := preloadDB(db, db.Statement.ReflectValue, db.Statement.Dest)
if tx.Error != nil {
return
}
for _, name := range preloadNames {
var (
curSchema = db.Statement.Schema
preloadFields = preloadMap[name]
rels = make([]*schema.Relationship, len(preloadFields))
)
db.AddError(preloadEntryPoint(tx, joins, &tx.Statement.Schema.Relationships, db.Statement.Preloads, db.Statement.Preloads[clause.Associations]))
for idx, preloadField := range preloadFields {
if rel := curSchema.Relationships.Relations[preloadField]; rel != nil {
rels[idx] = rel
curSchema = rel.FieldSchema
} else {
db.AddError(fmt.Errorf("%v: %w", name, gorm.ErrUnsupportedRelation))
}
}
preload(db, rels, db.Statement.Preloads[name])
}
}
}
}
func AfterQuery(db *gorm.DB) {
// clear the joins after the query because preload needs them
if v, ok := db.Statement.Clauses["FROM"].Expression.(clause.From); ok {
fromClause := db.Statement.Clauses["FROM"]
fromClause.Expression = clause.From{Tables: v.Tables, Joins: utils.RTrimSlice(v.Joins, len(db.Statement.Joins))} // keep the original From Joins
db.Statement.Clauses["FROM"] = fromClause
}
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.AfterFind && db.RowsAffected > 0 {
callMethod(db, func(value interface{}, tx *gorm.DB) bool {
if i, ok := value.(AfterFindInterface); ok {
db.AddError(i.AfterFind(tx))
return true
if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.AfterFind {
tx := db.Session(&gorm.Session{})
callMethod := func(value interface{}) bool {
if db.Statement.Schema.AfterFind {
if i, ok := value.(gorm.AfterFindInterface); ok {
db.AddError(i.AfterFind(tx))
return true
}
}
return false
})
}
if ok := callMethod(db.Statement.Dest); !ok {
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
callMethod(db.Statement.ReflectValue.Index(i).Interface())
}
case reflect.Struct:
callMethod(db.Statement.ReflectValue.Interface())
}
}
}
}
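A short, hypothetical example of the relation-join path in `BuildQuerySQL` on the master side of this diff: the join name is looked up in `Schema.Relationships`, rendered as a LEFT JOIN, and the joined columns are aliased (`Company__name` style) so they scan back into the nested struct. The `Employee` and `Company` models are assumptions.

```go
package example

import "gorm.io/gorm"

// Company and Employee are hypothetical models used only for illustration.
type Company struct {
	ID   uint
	Name string
}

type Employee struct {
	ID        uint
	Name      string
	CompanyID uint
	Company   Company
}

// findWithJoin loads employees and their company in a single query by joining
// the "Company" relation instead of preloading it.
func findWithJoin(db *gorm.DB) ([]Employee, error) {
	var employees []Employee
	err := db.Joins("Company").Find(&employees).Error
	return employees, err
}
```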

View File

@ -5,18 +5,12 @@ import (
)
func RawExec(db *gorm.DB) {
if db.Error == nil && !db.DryRun {
if db.Error == nil {
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
if err != nil {
db.AddError(err)
return
}
db.RowsAffected, _ = result.RowsAffected()
if db.Statement.Result != nil {
db.Statement.Result.Result = result
db.Statement.Result.RowsAffected = db.RowsAffected
} else {
db.RowsAffected, _ = result.RowsAffected()
}
}
}
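A minimal sketch of what reaches `RawExec`, assuming an open `*gorm.DB` and a `users` table: `db.Exec` runs the statement directly and fills `RowsAffected` from the driver result.

```go
package example

import "gorm.io/gorm"

// deactivateUser issues a raw statement; RowsAffected comes from the driver result.
func deactivateUser(db *gorm.DB, id uint) (int64, error) {
	res := db.Exec("UPDATE users SET active = ? WHERE id = ?", false, id)
	return res.RowsAffected, res.Error
}
```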

View File

@ -6,18 +6,14 @@ import (
func RowQuery(db *gorm.DB) {
if db.Error == nil {
BuildQuerySQL(db)
if db.DryRun || db.Error != nil {
return
if db.Statement.SQL.String() == "" {
BuildQuerySQL(db)
}
if isRows, ok := db.Get("rows"); ok && isRows.(bool) {
db.Statement.Settings.Delete("rows")
if _, ok := db.Get("rows"); ok {
db.Statement.Dest, db.Error = db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
} else {
db.Statement.Dest = db.Statement.ConnPool.QueryRowContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
}
db.RowsAffected = -1
}
}
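A hedged example exercising both branches of `RowQuery`, assuming a `users` table exists: `Row()` goes through `QueryRowContext`, while `Rows()` sets the `rows` flag and goes through `QueryContext`.

```go
package example

import "gorm.io/gorm"

// countAndListUsers uses Row for a single value and Rows for iteration.
func countAndListUsers(db *gorm.DB) (int64, []string, error) {
	var count int64
	if err := db.Table("users").Select("count(*)").Row().Scan(&count); err != nil {
		return 0, nil, err
	}

	rows, err := db.Table("users").Select("name").Rows()
	if err != nil {
		return count, nil, err
	}
	defer rows.Close()

	var names []string
	for rows.Next() {
		var name string
		if err := rows.Scan(&name); err != nil {
			return count, names, err
		}
		names = append(names, name)
	}
	return count, names, rows.Err()
}
```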

View File

@ -5,28 +5,21 @@ import (
)
func BeginTransaction(db *gorm.DB) {
if !db.Config.SkipDefaultTransaction && db.Error == nil {
if tx := db.Begin(); tx.Error == nil {
db.Statement.ConnPool = tx.Statement.ConnPool
db.InstanceSet("gorm:started_transaction", true)
} else if tx.Error == gorm.ErrInvalidTransaction {
tx.Error = nil
} else {
db.Error = tx.Error
}
if tx := db.Begin(); tx.Error == nil {
db.Statement.ConnPool = tx.Statement.ConnPool
tx.InstanceSet("gorm:started_transaction", true)
} else {
tx.Error = nil
}
}
func CommitOrRollbackTransaction(db *gorm.DB) {
if !db.Config.SkipDefaultTransaction {
if _, ok := db.InstanceGet("gorm:started_transaction"); ok {
if db.Error != nil {
db.Rollback()
} else {
db.Commit()
}
db.Statement.ConnPool = db.ConnPool
if _, ok := db.InstanceGet("gorm:started_transaction"); ok {
if db.Error == nil {
db.Commit()
} else {
db.Rollback()
}
db.Statement.ConnPool = db.ConnPool
}
}
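For context, the default transaction wrapped by these callbacks can be disabled with `gorm.Config{SkipDefaultTransaction: true}`, or replaced by an explicit transaction. A minimal sketch follows; the `accounts` table and the transfer logic are assumptions.

```go
package example

import "gorm.io/gorm"

// transfer runs two writes atomically: a returned error rolls the transaction
// back, a nil return commits it.
func transfer(db *gorm.DB, fromID, toID uint, amount int) error {
	return db.Transaction(func(tx *gorm.DB) error {
		if err := tx.Exec("UPDATE accounts SET balance = balance - ? WHERE id = ?", amount, fromID).Error; err != nil {
			return err
		}
		return tx.Exec("UPDATE accounts SET balance = balance + ? WHERE id = ?", amount, toID).Error
	})
}
```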

View File

@ -7,11 +7,10 @@ import (
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils"
)
func SetupUpdateReflectValue(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil {
if db.Error == nil {
if !db.Statement.ReflectValue.CanAddr() || db.Statement.Model != db.Statement.Dest {
db.Statement.ReflectValue = reflect.ValueOf(db.Statement.Model)
for db.Statement.ReflectValue.Kind() == reflect.Ptr {
@ -21,7 +20,7 @@ func SetupUpdateReflectValue(db *gorm.DB) {
if dest, ok := db.Statement.Dest.(map[string]interface{}); ok {
for _, rel := range db.Statement.Schema.Relationships.BelongsTo {
if _, ok := dest[rel.Name]; ok {
db.AddError(rel.Field.Set(db.Statement.Context, db.Statement.ReflectValue, dest[rel.Name]))
rel.Field.Set(db.Statement.ReflectValue, dest[rel.Name])
}
}
}
@ -29,117 +28,113 @@ func SetupUpdateReflectValue(db *gorm.DB) {
}
}
// BeforeUpdate before update hooks
func BeforeUpdate(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) {
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) {
tx := db.Session(&gorm.Session{})
callMethod := func(value interface{}) bool {
var ok bool
if db.Statement.Schema.BeforeSave {
if i, ok := value.(BeforeSaveInterface); ok {
called = true
if i, ok := value.(gorm.BeforeSaveInterface); ok {
ok = true
db.AddError(i.BeforeSave(tx))
}
}
if db.Statement.Schema.BeforeUpdate {
if i, ok := value.(BeforeUpdateInterface); ok {
called = true
if i, ok := value.(gorm.BeforeUpdateInterface); ok {
ok = true
db.AddError(i.BeforeUpdate(tx))
}
}
return ok
}
return called
})
if ok := callMethod(db.Statement.Dest); !ok {
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
callMethod(db.Statement.ReflectValue.Index(i).Interface())
}
case reflect.Struct:
callMethod(db.Statement.ReflectValue.Interface())
}
}
}
}
// Update update hook
func Update(config *Config) func(db *gorm.DB) {
supportReturning := utils.Contains(config.UpdateClauses, "RETURNING")
return func(db *gorm.DB) {
if db.Error != nil {
return
}
if db.Statement.Schema != nil {
func Update(db *gorm.DB) {
if db.Error == nil {
if db.Statement.Schema != nil && !db.Statement.Unscoped {
for _, c := range db.Statement.Schema.UpdateClauses {
db.Statement.AddClause(c)
}
}
if db.Statement.SQL.Len() == 0 {
db.Statement.SQL.Grow(180)
if db.Statement.SQL.String() == "" {
db.Statement.AddClauseIfNotExists(clause.Update{})
if _, ok := db.Statement.Clauses["SET"]; !ok {
if set := ConvertToAssignments(db.Statement); len(set) != 0 {
defer delete(db.Statement.Clauses, "SET")
db.Statement.AddClause(set)
} else {
return
}
if set := ConvertToAssignments(db.Statement); len(set) != 0 {
db.Statement.AddClause(set)
} else {
return
}
db.Statement.Build(db.Statement.BuildClauses...)
db.Statement.Build("UPDATE", "SET", "WHERE")
}
checkMissingWhereConditions(db)
if _, ok := db.Statement.Clauses["WHERE"]; !ok {
db.AddError(gorm.ErrMissingWhereClause)
return
}
if !db.DryRun && db.Error == nil {
if ok, mode := hasReturning(db, supportReturning); ok {
if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil {
dest := db.Statement.Dest
db.Statement.Dest = db.Statement.ReflectValue.Addr().Interface()
gorm.Scan(rows, db, mode)
db.Statement.Dest = dest
db.AddError(rows.Close())
if !db.DryRun {
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
if db.Statement.Result != nil {
db.Statement.Result.RowsAffected = db.RowsAffected
}
}
if err == nil {
db.RowsAffected, _ = result.RowsAffected()
} else {
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
if db.AddError(err) == nil {
db.RowsAffected, _ = result.RowsAffected()
}
if db.Statement.Result != nil {
db.Statement.Result.Result = result
db.Statement.Result.RowsAffected = db.RowsAffected
}
db.AddError(err)
}
}
}
}
// AfterUpdate after update hooks
func AfterUpdate(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) {
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
if db.Statement.Schema.AfterUpdate {
if i, ok := value.(AfterUpdateInterface); ok {
called = true
db.AddError(i.AfterUpdate(tx))
}
}
if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) {
tx := db.Session(&gorm.Session{})
callMethod := func(value interface{}) bool {
var ok bool
if db.Statement.Schema.AfterSave {
if i, ok := value.(AfterSaveInterface); ok {
called = true
if i, ok := value.(gorm.AfterSaveInterface); ok {
ok = true
db.AddError(i.AfterSave(tx))
}
}
return called
})
if db.Statement.Schema.AfterUpdate {
if i, ok := value.(gorm.AfterUpdateInterface); ok {
ok = true
db.AddError(i.AfterUpdate(tx))
}
}
return ok
}
if ok := callMethod(db.Statement.Dest); !ok {
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
callMethod(db.Statement.ReflectValue.Index(i).Interface())
}
case reflect.Struct:
callMethod(db.Statement.ReflectValue.Interface())
}
}
}
}
// ConvertToAssignments convert to update assignments
func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
var (
selectColumns, restricted = stmt.SelectAndOmitColumns(false, true)
selectColumns, restricted = SelectAndOmitColumns(stmt, false, true)
assignValue func(field *schema.Field, value interface{})
)
@ -147,15 +142,13 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
case reflect.Slice, reflect.Array:
assignValue = func(field *schema.Field, value interface{}) {
for i := 0; i < stmt.ReflectValue.Len(); i++ {
if stmt.ReflectValue.CanAddr() {
field.Set(stmt.Context, stmt.ReflectValue.Index(i), value)
}
field.Set(stmt.ReflectValue.Index(i), value)
}
}
case reflect.Struct:
assignValue = func(field *schema.Field, value interface{}) {
if stmt.ReflectValue.CanAddr() {
field.Set(stmt.Context, stmt.ReflectValue, value)
field.Set(stmt.ReflectValue, value)
}
}
default:
@ -168,146 +161,94 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
updatingValue = updatingValue.Elem()
}
switch value := updatingValue.Interface().(type) {
case map[string]interface{}:
set = make([]clause.Assignment, 0, len(value))
var keys []string
for k, _ := range value {
keys = append(keys, k)
}
sort.Strings(keys)
for _, k := range keys {
if field := stmt.Schema.LookUpField(k); field != nil {
if field.DBName != "" {
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value[k]})
assignValue(field, value[k])
}
} else if v, ok := selectColumns[field.Name]; (ok && v) || (!ok && !restricted) {
assignValue(field, value[k])
}
} else if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) {
set = append(set, clause.Assignment{Column: clause.Column{Name: k}, Value: value[k]})
}
}
if !stmt.DisableUpdateTime {
for _, field := range stmt.Schema.FieldsByDBName {
if field.AutoUpdateTime > 0 && value[field.Name] == nil && value[field.DBName] == nil {
now := stmt.DB.NowFunc()
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now})
assignValue(field, now)
}
}
}
default:
switch updatingValue.Kind() {
case reflect.Struct:
set = make([]clause.Assignment, 0, len(stmt.Schema.FieldsByDBName))
for _, field := range stmt.Schema.FieldsByDBName {
if !field.PrimaryKey || (!updatingValue.CanAddr() || stmt.Dest != stmt.Model) {
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
value, isZero := field.ValueOf(updatingValue)
if !stmt.DisableUpdateTime {
if field.AutoUpdateTime > 0 {
value = stmt.DB.NowFunc()
isZero = false
}
}
if ok || !isZero {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value})
assignValue(field, value)
}
}
} else {
if value, isZero := field.ValueOf(updatingValue); !isZero {
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}})
}
}
}
}
}
if !updatingValue.CanAddr() || stmt.Dest != stmt.Model {
switch stmt.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
if size := stmt.ReflectValue.Len(); size > 0 {
var isZero bool
for i := 0; i < size; i++ {
for _, field := range stmt.Schema.PrimaryFields {
_, isZero = field.ValueOf(stmt.Context, stmt.ReflectValue.Index(i))
if !isZero {
break
}
}
var primaryKeyExprs []clause.Expression
for i := 0; i < stmt.ReflectValue.Len(); i++ {
var exprs = make([]clause.Expression, len(stmt.Schema.PrimaryFields))
var notZero bool
for idx, field := range stmt.Schema.PrimaryFields {
value, isZero := field.ValueOf(stmt.ReflectValue.Index(i))
exprs[idx] = clause.Eq{Column: field.DBName, Value: value}
notZero = notZero || !isZero
}
if !isZero {
_, primaryValues := schema.GetIdentityFieldValuesMap(stmt.Context, stmt.ReflectValue, stmt.Schema.PrimaryFields)
column, values := schema.ToQueryValues("", stmt.Schema.PrimaryFieldDBNames, primaryValues)
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
if notZero {
primaryKeyExprs = append(primaryKeyExprs, clause.And(exprs...))
}
}
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(primaryKeyExprs...)}})
case reflect.Struct:
for _, field := range stmt.Schema.PrimaryFields {
if value, isZero := field.ValueOf(stmt.Context, stmt.ReflectValue); !isZero {
if value, isZero := field.ValueOf(stmt.ReflectValue); !isZero {
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}})
}
}
}
}
switch value := updatingValue.Interface().(type) {
case map[string]interface{}:
set = make([]clause.Assignment, 0, len(value))
keys := make([]string, 0, len(value))
for k := range value {
keys = append(keys, k)
}
sort.Strings(keys)
for _, k := range keys {
kv := value[k]
if _, ok := kv.(*gorm.DB); ok {
kv = []interface{}{kv}
}
if stmt.Schema != nil {
if field := stmt.Schema.LookUpField(k); field != nil {
if field.DBName != "" {
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: kv})
assignValue(field, value[k])
}
} else if v, ok := selectColumns[field.Name]; (ok && v) || (!ok && !restricted) {
assignValue(field, value[k])
}
continue
}
}
if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) {
set = append(set, clause.Assignment{Column: clause.Column{Name: k}, Value: kv})
}
}
if !stmt.SkipHooks && stmt.Schema != nil {
for _, dbName := range stmt.Schema.DBNames {
field := stmt.Schema.LookUpField(dbName)
if field.AutoUpdateTime > 0 && value[field.Name] == nil && value[field.DBName] == nil {
if v, ok := selectColumns[field.DBName]; (ok && v) || !ok {
now := stmt.DB.NowFunc()
assignValue(field, now)
if field.AutoUpdateTime == schema.UnixNanosecond {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano()})
} else if field.AutoUpdateTime == schema.UnixMillisecond {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixMilli()})
} else if field.AutoUpdateTime == schema.UnixSecond {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.Unix()})
} else {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now})
}
}
}
}
}
default:
updatingSchema := stmt.Schema
var isDiffSchema bool
if !updatingValue.CanAddr() || stmt.Dest != stmt.Model {
// different schema
updatingStmt := &gorm.Statement{DB: stmt.DB}
if err := updatingStmt.Parse(stmt.Dest); err == nil {
updatingSchema = updatingStmt.Schema
isDiffSchema = true
}
}
switch updatingValue.Kind() {
case reflect.Struct:
set = make([]clause.Assignment, 0, len(stmt.Schema.FieldsByDBName))
for _, dbName := range stmt.Schema.DBNames {
if field := updatingSchema.LookUpField(dbName); field != nil {
if !field.PrimaryKey || !updatingValue.CanAddr() || stmt.Dest != stmt.Model {
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && (!restricted || (!stmt.SkipHooks && field.AutoUpdateTime > 0))) {
value, isZero := field.ValueOf(stmt.Context, updatingValue)
if !stmt.SkipHooks && field.AutoUpdateTime > 0 {
if field.AutoUpdateTime == schema.UnixNanosecond {
value = stmt.DB.NowFunc().UnixNano()
} else if field.AutoUpdateTime == schema.UnixMillisecond {
value = stmt.DB.NowFunc().UnixMilli()
} else if field.AutoUpdateTime == schema.UnixSecond {
value = stmt.DB.NowFunc().Unix()
} else {
value = stmt.DB.NowFunc()
}
isZero = false
}
if (ok || !isZero) && field.Updatable {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value})
assignField := field
if isDiffSchema {
if originField := stmt.Schema.LookUpField(dbName); originField != nil {
assignField = originField
}
}
assignValue(assignField, value)
}
}
} else {
if value, isZero := field.ValueOf(stmt.Context, updatingValue); !isZero {
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}})
}
}
}
}
default:
stmt.AddError(gorm.ErrInvalidData)
}
}
return
}
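A brief usage sketch of the two inputs `ConvertToAssignments` handles, with a hypothetical `User` model: map updates write every selected key, while struct updates skip zero-value fields unless they are explicitly selected.

```go
package example

import "gorm.io/gorm"

// User is a hypothetical model used only for illustration.
type User struct {
	ID   uint
	Name string
	Age  int
}

func updateExamples(db *gorm.DB) error {
	// Map input: every selected key becomes an assignment, including zero values.
	if err := db.Model(&User{}).Where("id = ?", 1).
		Updates(map[string]interface{}{"name": "jinzhu", "age": 0}).Error; err != nil {
		return err
	}

	// Struct input: zero-value fields (Age here) are skipped unless selected.
	return db.Model(&User{}).Where("id = ?", 1).
		Updates(User{Name: "jinzhu 2"}).Error
}
```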

View File

@ -2,7 +2,6 @@ package gorm
import (
"fmt"
"regexp"
"strings"
"gorm.io/gorm/clause"
@ -10,11 +9,10 @@ import (
)
// Model specify the model you would like to run db operations
//
// // update all users's name to `hello`
// db.Model(&User{}).Update("name", "hello")
// // if user's primary key is non-blank, will use it as condition, then will only update that user's name to `hello`
// db.Model(&user).Update("name", "hello")
// // update all users's name to `hello`
// db.Model(&User{}).Update("name", "hello")
// // if user's primary key is non-blank, will use it as condition, then will only update the user's name to `hello`
// db.Model(&user).Update("name", "hello")
func (db *DB) Model(value interface{}) (tx *DB) {
tx = db.getInstance()
tx.Statement.Model = value
@ -22,19 +20,6 @@ func (db *DB) Model(value interface{}) (tx *DB) {
}
// Clauses Add clauses
//
// This supports both standard clauses (clause.OrderBy, clause.Limit, clause.Where) and more
// advanced techniques like specifying lock strength and optimizer hints. See the
// [docs] for more depth.
//
// // add a simple limit clause
// db.Clauses(clause.Limit{Limit: 1}).Find(&User{})
// // tell the optimizer to use the `idx_user_name` index
// db.Clauses(hints.UseIndex("idx_user_name")).Find(&User{})
// // specify the lock strength to UPDATE
// db.Clauses(clause.Locking{Strength: "UPDATE"}).Find(&users)
//
// [docs]: https://gorm.io/docs/sql_builder.html#Clauses
func (db *DB) Clauses(conds ...clause.Expression) (tx *DB) {
tx = db.getInstance()
var whereConds []interface{}
@ -42,73 +27,25 @@ func (db *DB) Clauses(conds ...clause.Expression) (tx *DB) {
for _, cond := range conds {
if c, ok := cond.(clause.Interface); ok {
tx.Statement.AddClause(c)
} else if optimizer, ok := cond.(StatementModifier); ok {
optimizer.ModifyStatement(tx.Statement)
} else {
whereConds = append(whereConds, cond)
}
}
if len(whereConds) > 0 {
tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondition(whereConds[0], whereConds[1:]...)})
tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondtion(whereConds[0], whereConds[1:]...)})
}
return
}
var tableRegexp = regexp.MustCompile(`(?i)(?:.+? AS (\w+)\s*(?:$|,)|^\w+\s+(\w+)$)`)
// Table specify the table you would like to run db operations
//
// // Get a user
// db.Table("users").Take(&result)
func (db *DB) Table(name string, args ...interface{}) (tx *DB) {
func (db *DB) Table(name string) (tx *DB) {
tx = db.getInstance()
if strings.Contains(name, " ") || strings.Contains(name, "`") || len(args) > 0 {
tx.Statement.TableExpr = &clause.Expr{SQL: name, Vars: args}
if results := tableRegexp.FindStringSubmatch(name); len(results) == 3 {
if results[1] != "" {
tx.Statement.Table = results[1]
} else {
tx.Statement.Table = results[2]
}
}
} else if tables := strings.Split(name, "."); len(tables) == 2 {
tx.Statement.TableExpr = &clause.Expr{SQL: tx.Statement.Quote(name)}
tx.Statement.Table = tables[1]
} else if name != "" {
tx.Statement.TableExpr = &clause.Expr{SQL: tx.Statement.Quote(name)}
tx.Statement.Table = name
} else {
tx.Statement.TableExpr = nil
tx.Statement.Table = ""
}
return
}
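A hedged example of the cases the master-side `Table` implementation distinguishes (plain names versus aliased expressions matched by `tableRegexp`); the `users` table is an assumption.

```go
package example

import "gorm.io/gorm"

// tableUsage queries through explicit table names rather than a model.
func tableUsage(db *gorm.DB) error {
	var names []string
	// Plain table name: quoted as-is.
	if err := db.Table("users").Pluck("name", &names).Error; err != nil {
		return err
	}
	// Aliased expression: Statement.Table becomes the alias "u".
	return db.Table("users as u").Where("u.age > ?", 18).Pluck("u.name", &names).Error
}
```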
// Distinct specify distinct fields that you want querying
//
// // Select distinct names of users
// db.Distinct("name").Find(&results)
// // Select distinct name/age pairs from users
// db.Distinct("name", "age").Find(&results)
func (db *DB) Distinct(args ...interface{}) (tx *DB) {
tx = db.getInstance()
tx.Statement.Distinct = true
if len(args) > 0 {
tx = tx.Select(args[0], args[1:]...)
}
tx.Statement.Table = name
return
}
// Select specify fields that you want when querying, creating, updating
//
// Use Select when you only want a subset of the fields. By default, GORM will select all fields.
// Select accepts both string arguments and arrays.
//
// // Select name and age of user using multiple arguments
// db.Select("name", "age").Find(&users)
// // Select name and age of user using an array
// db.Select([]string{"name", "age"}).Find(&users)
func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) {
tx = db.getInstance()
@ -127,24 +64,12 @@ func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) {
return
}
}
if clause, ok := tx.Statement.Clauses["SELECT"]; ok {
clause.Expression = nil
tx.Statement.Clauses["SELECT"] = clause
}
case string:
if strings.Count(v, "?") >= len(args) && len(args) > 0 {
tx.Statement.AddClause(clause.Select{
Distinct: db.Statement.Distinct,
Expression: clause.Expr{SQL: v, Vars: args},
})
} else if strings.Count(v, "@") > 0 && len(args) > 0 {
tx.Statement.AddClause(clause.Select{
Distinct: db.Statement.Distinct,
Expression: clause.NamedExpr{SQL: v, Vars: args},
})
} else {
tx.Statement.Selects = []string{v}
fields := strings.FieldsFunc(v, utils.IsChar)
// normal field names
if len(fields) == 1 || (len(fields) == 3 && strings.ToUpper(fields[1]) == "AS") {
tx.Statement.Selects = fields
for _, arg := range args {
switch arg := arg.(type) {
@ -154,17 +79,15 @@ func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) {
tx.Statement.Selects = append(tx.Statement.Selects, arg...)
default:
tx.Statement.AddClause(clause.Select{
Distinct: db.Statement.Distinct,
Expression: clause.Expr{SQL: v, Vars: args},
})
return
}
}
if clause, ok := tx.Statement.Clauses["SELECT"]; ok {
clause.Expression = nil
tx.Statement.Clauses["SELECT"] = clause
}
} else {
tx.Statement.AddClause(clause.Select{
Expression: clause.Expr{SQL: v, Vars: args},
})
}
default:
tx.AddError(fmt.Errorf("unsupported select args %v %v", query, args))
@ -178,182 +101,99 @@ func (db *DB) Omit(columns ...string) (tx *DB) {
tx = db.getInstance()
if len(columns) == 1 && strings.ContainsRune(columns[0], ',') {
tx.Statement.Omits = strings.FieldsFunc(columns[0], utils.IsValidDBNameChar)
tx.Statement.Omits = strings.FieldsFunc(columns[0], utils.IsChar)
} else {
tx.Statement.Omits = columns
}
return
}
// MapColumns modify the column names in the query results to facilitate align to the corresponding structural fields
func (db *DB) MapColumns(m map[string]string) (tx *DB) {
tx = db.getInstance()
tx.Statement.ColumnMapping = m
return
}
// Where add conditions
//
// See the [docs] for details on the various formats that where clauses can take. By default, where clauses chain with AND.
//
// // Find the first user with name jinzhu
// db.Where("name = ?", "jinzhu").First(&user)
// // Find the first user with name jinzhu and age 20
// db.Where(&User{Name: "jinzhu", Age: 20}).First(&user)
// // Find the first user with name jinzhu and age not equal to 20
// db.Where("name = ?", "jinzhu").Where("age <> ?", "20").First(&user)
//
// [docs]: https://gorm.io/docs/query.html#Conditions
func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) {
tx = db.getInstance()
if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 {
if conds := tx.Statement.BuildCondtion(query, args...); len(conds) > 0 {
tx.Statement.AddClause(clause.Where{Exprs: conds})
}
return
}
// Not add NOT conditions
//
// Not works similarly to where, and has the same syntax.
//
// // Find the first user with name not equal to jinzhu
// db.Not("name = ?", "jinzhu").First(&user)
func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) {
tx = db.getInstance()
if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 {
if conds := tx.Statement.BuildCondtion(query, args...); len(conds) > 0 {
tx.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.Not(conds...)}})
}
return
}
// Or add OR conditions
//
// Or is used to chain together queries with an OR.
//
// // Find the first user with name equal to jinzhu or john
// db.Where("name = ?", "jinzhu").Or("name = ?", "john").First(&user)
func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) {
tx = db.getInstance()
if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 {
tx.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(clause.And(conds...))}})
if conds := tx.Statement.BuildCondtion(query, args...); len(conds) > 0 {
tx.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(conds...)}})
}
return
}
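Not and Or follow the same condition syntax as Where; a brief sketch with the same assumed db/users variables:
// roughly: WHERE NOT (name = 'jinzhu')
db.Not("name = ?", "jinzhu").Find(&users)
// roughly: WHERE name = 'jinzhu' OR name = 'john'
db.Where("name = ?", "jinzhu").Or("name = ?", "john").Find(&users)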
// Joins specify Joins conditions
//
// db.Joins("Account").Find(&user)
// db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user)
// db.Joins("Account", DB.Select("id").Where("user_id = users.id AND name = ?", "someName").Model(&Account{}))
// db.Joins("Account").Find(&user)
// db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user)
func (db *DB) Joins(query string, args ...interface{}) (tx *DB) {
return joins(db, clause.LeftJoin, query, args...)
}
// InnerJoins specify inner joins conditions
// db.InnerJoins("Account").Find(&user)
func (db *DB) InnerJoins(query string, args ...interface{}) (tx *DB) {
return joins(db, clause.InnerJoin, query, args...)
}
func joins(db *DB, joinType clause.JoinType, query string, args ...interface{}) (tx *DB) {
tx = db.getInstance()
if len(args) == 1 {
if db, ok := args[0].(*DB); ok {
j := join{
Name: query, Conds: args, Selects: db.Statement.Selects,
Omits: db.Statement.Omits, JoinType: joinType,
}
if where, ok := db.Statement.Clauses["WHERE"].Expression.(clause.Where); ok {
j.On = &where
}
tx.Statement.Joins = append(tx.Statement.Joins, j)
return
}
if tx.Statement.Joins == nil {
tx.Statement.Joins = map[string][]interface{}{}
}
tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args, JoinType: joinType})
tx.Statement.Joins[query] = args
return
}
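A usage sketch for the join helpers above; the emails table and the Account association are illustrative assumptions only:
// raw SQL join with a bind variable
db.Joins("LEFT JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&users)
// join a named association and select it in the same query
db.Joins("Account").Find(&users)
// inner join variant
db.InnerJoins("Account").Find(&users)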
// Group specify the group method on the find
//
// // Select the sum age of users with given names
// db.Model(&User{}).Select("name, sum(age) as total").Group("name").Find(&results)
func (db *DB) Group(name string) (tx *DB) {
tx = db.getInstance()
fields := strings.FieldsFunc(name, utils.IsValidDBNameChar)
tx.Statement.AddClause(clause.GroupBy{
Columns: []clause.Column{{Name: name, Raw: len(fields) != 1}},
Columns: []clause.Column{{Name: name}},
})
return
}
// Having specify HAVING conditions for GROUP BY
//
// // Select the sum age of users with name jinzhu
// db.Model(&User{}).Select("name, sum(age) as total").Group("name").Having("name = ?", "jinzhu").Find(&result)
func (db *DB) Having(query interface{}, args ...interface{}) (tx *DB) {
tx = db.getInstance()
tx.Statement.AddClause(clause.GroupBy{
Having: tx.Statement.BuildCondition(query, args...),
Having: tx.Statement.BuildCondtion(query, args...),
})
return
}
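Group and Having are typically paired with an aggregate Select; a sketch assuming a hypothetical Result struct that is not part of this diff:
type Result struct {
	Name  string
	Total int
}

var results []Result
db.Model(&User{}).
	Select("name, sum(age) as total").
	Group("name").
	Having("sum(age) > ?", 100).
	Scan(&results)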
// Order specify order when retrieving records from database
//
// db.Order("name DESC")
// db.Order(clause.OrderByColumn{Column: clause.Column{Name: "name"}, Desc: true})
// db.Order(clause.OrderBy{Columns: []clause.OrderByColumn{
// {Column: clause.Column{Name: "name"}, Desc: true},
// {Column: clause.Column{Name: "age"}, Desc: true},
// }})
// Order specify order when retrieve records from database
// db.Order("name DESC")
// db.Order(gorm.Expr("name = ? DESC", "first")) // sql expression
func (db *DB) Order(value interface{}) (tx *DB) {
tx = db.getInstance()
switch v := value.(type) {
case clause.OrderBy:
tx.Statement.AddClause(v)
case clause.OrderByColumn:
tx.Statement.AddClause(clause.OrderBy{
Columns: []clause.OrderByColumn{v},
})
case string:
if v != "" {
tx.Statement.AddClause(clause.OrderBy{
Columns: []clause.OrderByColumn{{
Column: clause.Column{Name: v, Raw: true},
}},
})
}
default:
tx.Statement.AddClause(clause.OrderBy{
Columns: []clause.OrderByColumn{{
Column: clause.Column{Name: fmt.Sprint(value), Raw: true},
}},
})
}
return
}
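A brief sketch of both ordering styles accepted above (the clause.OrderByColumn form follows the master side of this diff):
// raw ordering string
db.Order("age DESC, name").Find(&users)
// structured ordering via the clause package
db.Order(clause.OrderByColumn{Column: clause.Column{Name: "name"}, Desc: true}).Find(&users)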
// Limit specify the number of records to be retrieved
//
// Limit conditions can be cancelled by using `Limit(-1)`.
//
// // retrieve 3 users
// db.Limit(3).Find(&users)
// // retrieve 3 users into users1, and all users into users2
// db.Limit(3).Find(&users1).Limit(-1).Find(&users2)
func (db *DB) Limit(limit int) (tx *DB) {
tx = db.getInstance()
tx.Statement.AddClause(clause.Limit{Limit: &limit})
tx.Statement.AddClause(clause.Limit{Limit: limit})
return
}
// Offset specify the number of records to skip before starting to return the records
//
// Offset conditions can be cancelled by using `Offset(-1)`.
//
// // select the third user
// db.Offset(2).First(&user)
// // select the first user by cancelling an earlier chained offset
// db.Offset(5).Offset(-1).First(&user)
func (db *DB) Offset(offset int) (tx *DB) {
tx = db.getInstance()
tx.Statement.AddClause(clause.Limit{Offset: offset})
@ -361,37 +201,26 @@ func (db *DB) Offset(offset int) (tx *DB) {
}
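Limit and Offset combine naturally for paging; a sketch with assumed page variables:
page, pageSize := 3, 10
db.Limit(pageSize).Offset((page - 1) * pageSize).Find(&users)
// negative values cancel an earlier limit/offset in the same chain
db.Limit(10).Find(&users1).Limit(-1).Find(&users2)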
// Scopes pass current database connection to arguments `func(DB) DB`, which could be used to add conditions dynamically
// func AmountGreaterThan1000(db *gorm.DB) *gorm.DB {
// return db.Where("amount > ?", 1000)
// }
//
// func AmountGreaterThan1000(db *gorm.DB) *gorm.DB {
// return db.Where("amount > ?", 1000)
// }
// func OrderStatus(status []string) func (db *gorm.DB) *gorm.DB {
// return func (db *gorm.DB) *gorm.DB {
// return db.Scopes(AmountGreaterThan1000).Where("status in (?)", status)
// }
// }
//
// func OrderStatus(status []string) func (db *gorm.DB) *gorm.DB {
// return func (db *gorm.DB) *gorm.DB {
// return db.Scopes(AmountGreaterThan1000).Where("status in (?)", status)
// }
// }
//
// db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders)
func (db *DB) Scopes(funcs ...func(*DB) *DB) (tx *DB) {
tx = db.getInstance()
tx.Statement.scopes = append(tx.Statement.scopes, funcs...)
return tx
}
func (db *DB) executeScopes() (tx *DB) {
scopes := db.Statement.scopes
db.Statement.scopes = nil
for _, scope := range scopes {
db = scope(db)
// db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders)
func (db *DB) Scopes(funcs ...func(*DB) *DB) *DB {
for _, f := range funcs {
db = f(db)
}
return db
}
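A typical scope is a reusable query fragment; Paginate below is a hypothetical example, not part of either version in this diff:
func Paginate(page, pageSize int) func(*gorm.DB) *gorm.DB {
	return func(db *gorm.DB) *gorm.DB {
		return db.Offset((page - 1) * pageSize).Limit(pageSize)
	}
}

db.Scopes(Paginate(2, 20)).Find(&users)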
// Preload preload associations with given conditions
//
// // get all users, and preload all non-cancelled orders
// db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users)
// db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users)
func (db *DB) Preload(query string, args ...interface{}) (tx *DB) {
tx = db.getInstance()
if tx.Statement.Preloads == nil {
@ -401,57 +230,18 @@ func (db *DB) Preload(query string, args ...interface{}) (tx *DB) {
return
}
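A sketch of Preload usage; the Orders, Profile and Orders.Items association names are illustrative assumptions:
// preload two associations, one of them filtered
db.Preload("Orders", "state NOT IN (?)", "cancelled").Preload("Profile").Find(&users)
// nested associations use dot-separated paths
db.Preload("Orders.Items").Find(&users)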
// Attrs provide attributes used in [FirstOrCreate] or [FirstOrInit]
//
// Attrs only adds attributes if the record is not found.
//
// // assign an email if the record is not found
// db.Where(User{Name: "non_existing"}).Attrs(User{Email: "fake@fake.org"}).FirstOrInit(&user)
// // user -> User{Name: "non_existing", Email: "fake@fake.org"}
//
// // assign an email if the record is not found, otherwise ignore provided email
// db.Where(User{Name: "jinzhu"}).Attrs(User{Email: "fake@fake.org"}).FirstOrInit(&user)
// // user -> User{Name: "jinzhu", Age: 20}
//
// [FirstOrCreate]: https://gorm.io/docs/advanced_query.html#FirstOrCreate
// [FirstOrInit]: https://gorm.io/docs/advanced_query.html#FirstOrInit
func (db *DB) Attrs(attrs ...interface{}) (tx *DB) {
tx = db.getInstance()
tx.Statement.attrs = attrs
return
}
// Assign provide attributes used in [FirstOrCreate] or [FirstOrInit]
//
// Assign adds attributes even if the record is found. If using FirstOrCreate, this means that
// records will be updated even if they are found.
//
// // assign an email regardless of if the record is not found
// db.Where(User{Name: "non_existing"}).Assign(User{Email: "fake@fake.org"}).FirstOrInit(&user)
// // user -> User{Name: "non_existing", Email: "fake@fake.org"}
//
// // assign email regardless of if record is found
// db.Where(User{Name: "jinzhu"}).Assign(User{Email: "fake@fake.org"}).FirstOrInit(&user)
// // user -> User{Name: "jinzhu", Age: 20, Email: "fake@fake.org"}
//
// [FirstOrCreate]: https://gorm.io/docs/advanced_query.html#FirstOrCreate
// [FirstOrInit]: https://gorm.io/docs/advanced_query.html#FirstOrInit
func (db *DB) Assign(attrs ...interface{}) (tx *DB) {
tx = db.getInstance()
tx.Statement.assigns = attrs
return
}
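The practical difference between Attrs and Assign shows up with FirstOrCreate; a sketch using the assumed User model:
// Attrs: only applied when no matching record exists
db.Where(User{Name: "non_existing"}).Attrs(User{Age: 20}).FirstOrCreate(&user)
// Assign: applied (and persisted by FirstOrCreate) whether or not a record exists
db.Where(User{Name: "jinzhu"}).Assign(User{Age: 30}).FirstOrCreate(&user)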
// Unscoped disables the global scope of soft deletion in a query.
// By default, GORM uses soft deletion, marking records as "deleted"
// by setting a timestamp on a specific field (e.g., `deleted_at`).
// Unscoped allows queries to include records marked as deleted,
// overriding the soft deletion behavior.
// Example:
//
// var users []User
// db.Unscoped().Find(&users)
// // Retrieves all users, including deleted ones.
func (db *DB) Unscoped() (tx *DB) {
tx = db.getInstance()
tx.Statement.Unscoped = true
@ -461,11 +251,6 @@ func (db *DB) Unscoped() (tx *DB) {
func (db *DB) Raw(sql string, values ...interface{}) (tx *DB) {
tx = db.getInstance()
tx.Statement.SQL = strings.Builder{}
if strings.Contains(sql, "@") {
clause.NamedExpr{SQL: sql, Vars: values}.Build(tx.Statement)
} else {
clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement)
}
clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement)
return
}
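A sketch of Raw combined with Scan; the users table and its columns are assumptions carried over from the surrounding examples:
var user User
db.Raw("SELECT id, name, age FROM users WHERE id = ?", 1).Scan(&user)

var total int64
db.Raw("SELECT count(*) FROM users WHERE age > ?", 18).Scan(&total)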

View File

@ -29,12 +29,10 @@ func BenchmarkSelect(b *testing.B) {
func BenchmarkComplexSelect(b *testing.B) {
user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
limit10 := 10
for i := 0; i < b.N; i++ {
stmt := gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}}
clauses := []clause.Interface{
clause.Select{},
clause.From{},
clause.Select{}, clause.From{},
clause.Where{Exprs: []clause.Expression{
clause.Eq{Column: clause.PrimaryColumn, Value: "1"},
clause.Gt{Column: "age", Value: 18},
@ -44,7 +42,7 @@ func BenchmarkComplexSelect(b *testing.B) {
clause.Or(clause.Gt{Column: "score", Value: 100}, clause.Like{Column: "name", Value: "%linus%"}),
}},
clause.GroupBy{Columns: []clause.Column{{Name: "role"}}, Having: []clause.Expression{clause.Eq{"role", "admin"}}},
clause.Limit{Limit: &limit10, Offset: 20},
clause.Limit{Limit: 10, Offset: 20},
clause.OrderBy{Columns: []clause.OrderByColumn{{Column: clause.PrimaryColumn, Desc: true}}},
}

View File

@ -7,7 +7,7 @@ type Interface interface {
MergeClause(*Clause)
}
// ClauseBuilder clause builder, allows to customize how to build clause
// ClauseBuilder clause builder, allows to custmize how to build clause
type ClauseBuilder func(Clause, Builder)
type Writer interface {
@ -18,54 +18,48 @@ type Writer interface {
// Builder builder interface
type Builder interface {
Writer
WriteQuoted(field interface{})
WriteQuoted(field interface{}) error
AddVar(Writer, ...interface{})
AddError(error) error
}
// Clause
type Clause struct {
Name string // WHERE
BeforeExpression Expression
AfterNameExpression Expression
AfterExpression Expression
Expression Expression
Builder ClauseBuilder
Name string // WHERE
Priority float64
BeforeExpressions []Expression
AfterNameExpressions []Expression
AfterExpressions []Expression
Expression Expression
Builder ClauseBuilder
}
// Build build clause
func (c Clause) Build(builder Builder) {
if c.Builder != nil {
c.Builder(c, builder)
} else if c.Expression != nil {
if c.BeforeExpression != nil {
c.BeforeExpression.Build(builder)
builder.WriteByte(' ')
}
} else {
builders := c.BeforeExpressions
if c.Name != "" {
builder.WriteString(c.Name)
builder.WriteByte(' ')
builders = append(builders, Expr{SQL: c.Name})
}
if c.AfterNameExpression != nil {
c.AfterNameExpression.Build(builder)
builder.WriteByte(' ')
builders = append(builders, c.AfterNameExpressions...)
if c.Expression != nil {
builders = append(builders, c.Expression)
}
c.Expression.Build(builder)
if c.AfterExpression != nil {
builder.WriteByte(' ')
c.AfterExpression.Build(builder)
for idx, expr := range append(builders, c.AfterExpressions...) {
if idx != 0 {
builder.WriteByte(' ')
}
expr.Build(builder)
}
}
}
const (
PrimaryKey string = "~~~py~~~" // primary key
CurrentTable string = "~~~ct~~~" // current table
Associations string = "~~~as~~~" // associations
PrimaryKey string = "@@@priamry_key@@@"
CurrentTable string = "@@@table@@@"
)
var (

View File

@ -1,9 +1,7 @@
package clause
import (
"database/sql"
"database/sql/driver"
"go/ast"
"reflect"
)
@ -19,9 +17,8 @@ type NegationExpressionBuilder interface {
// Expr raw expression
type Expr struct {
SQL string
Vars []interface{}
WithoutParentheses bool
SQL string
Vars []interface{}
}
// Build build raw expression
@ -32,130 +29,18 @@ func (expr Expr) Build(builder Builder) {
)
for _, v := range []byte(expr.SQL) {
if v == '?' && len(expr.Vars) > idx {
if afterParenthesis || expr.WithoutParentheses {
if _, ok := expr.Vars[idx].(driver.Valuer); ok {
builder.AddVar(builder, expr.Vars[idx])
} else {
switch rv := reflect.ValueOf(expr.Vars[idx]); rv.Kind() {
case reflect.Slice, reflect.Array:
if rv.Len() == 0 {
builder.AddVar(builder, nil)
} else {
for i := 0; i < rv.Len(); i++ {
if i > 0 {
builder.WriteByte(',')
}
builder.AddVar(builder, rv.Index(i).Interface())
}
}
default:
builder.AddVar(builder, expr.Vars[idx])
}
}
} else {
builder.AddVar(builder, expr.Vars[idx])
}
idx++
} else {
if v == '(' {
afterParenthesis = true
} else {
afterParenthesis = false
}
builder.WriteByte(v)
}
}
if idx < len(expr.Vars) {
for _, v := range expr.Vars[idx:] {
builder.AddVar(builder, sql.NamedArg{Value: v})
}
}
}
// NamedExpr raw expression for named expr
type NamedExpr struct {
SQL string
Vars []interface{}
}
// Build build raw expression
func (expr NamedExpr) Build(builder Builder) {
var (
idx int
inName bool
afterParenthesis bool
namedMap = make(map[string]interface{}, len(expr.Vars))
)
for _, v := range expr.Vars {
switch value := v.(type) {
case sql.NamedArg:
namedMap[value.Name] = value.Value
case map[string]interface{}:
for k, v := range value {
namedMap[k] = v
}
default:
var appendFieldsToMap func(reflect.Value)
appendFieldsToMap = func(reflectValue reflect.Value) {
reflectValue = reflect.Indirect(reflectValue)
switch reflectValue.Kind() {
case reflect.Struct:
modelType := reflectValue.Type()
for i := 0; i < modelType.NumField(); i++ {
if fieldStruct := modelType.Field(i); ast.IsExported(fieldStruct.Name) {
namedMap[fieldStruct.Name] = reflectValue.Field(i).Interface()
if fieldStruct.Anonymous {
appendFieldsToMap(reflectValue.Field(i))
}
}
}
}
}
appendFieldsToMap(reflect.ValueOf(value))
}
}
name := make([]byte, 0, 10)
for _, v := range []byte(expr.SQL) {
if v == '@' && !inName {
inName = true
name = name[:0]
} else if v == ' ' || v == ',' || v == ')' || v == '"' || v == '\'' || v == '`' || v == '\r' || v == '\n' || v == ';' {
if inName {
if nv, ok := namedMap[string(name)]; ok {
builder.AddVar(builder, nv)
} else {
builder.WriteByte('@')
builder.WriteString(string(name))
}
inName = false
}
afterParenthesis = false
builder.WriteByte(v)
} else if v == '?' && len(expr.Vars) > idx {
if v == '?' {
if afterParenthesis {
if _, ok := expr.Vars[idx].(driver.Valuer); ok {
builder.AddVar(builder, expr.Vars[idx])
} else {
switch rv := reflect.ValueOf(expr.Vars[idx]); rv.Kind() {
case reflect.Slice, reflect.Array:
if rv.Len() == 0 {
builder.AddVar(builder, nil)
} else {
for i := 0; i < rv.Len(); i++ {
if i > 0 {
builder.WriteByte(',')
}
builder.AddVar(builder, rv.Index(i).Interface())
for i := 0; i < rv.Len(); i++ {
if i > 0 {
builder.WriteByte(',')
}
builder.AddVar(builder, rv.Index(i).Interface())
}
default:
builder.AddVar(builder, expr.Vars[idx])
@ -166,8 +51,6 @@ func (expr NamedExpr) Build(builder Builder) {
}
idx++
} else if inName {
name = append(name, v)
} else {
if v == '(' {
afterParenthesis = true
@ -177,15 +60,6 @@ func (expr NamedExpr) Build(builder Builder) {
builder.WriteByte(v)
}
}
if inName {
if nv, ok := namedMap[string(name)]; ok {
builder.AddVar(builder, nv)
} else {
builder.WriteByte('@')
builder.WriteString(string(name))
}
}
}
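The named-argument handling above (present on the master side of this diff) is what lets conditions use @name placeholders; a sketch assuming the database/sql import and the usual db/users variables:
db.Where("name1 = @name OR name2 = @name", sql.Named("name", "jinzhu")).Find(&users)
db.Where("name = @name AND age = @age", map[string]interface{}{"name": "jinzhu", "age": 18}).Find(&users)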
// IN Whether a value is within a set of values
@ -201,13 +75,8 @@ func (in IN) Build(builder Builder) {
case 0:
builder.WriteString(" IN (NULL)")
case 1:
if _, ok := in.Values[0].([]interface{}); !ok {
builder.WriteString(" = ")
builder.AddVar(builder, in.Values[0])
break
}
fallthrough
builder.WriteString(" = ")
builder.AddVar(builder, in.Values...)
default:
builder.WriteString(" IN (")
builder.AddVar(builder, in.Values...)
@ -216,19 +85,14 @@ func (in IN) Build(builder Builder) {
}
func (in IN) NegationBuild(builder Builder) {
builder.WriteQuoted(in.Column)
switch len(in.Values) {
case 0:
builder.WriteString(" IS NOT NULL")
case 1:
if _, ok := in.Values[0].([]interface{}); !ok {
builder.WriteString(" <> ")
builder.AddVar(builder, in.Values[0])
break
}
fallthrough
builder.WriteQuoted(in.Column)
builder.WriteString(" <> ")
builder.AddVar(builder, in.Values...)
default:
builder.WriteQuoted(in.Column)
builder.WriteString(" NOT IN (")
builder.AddVar(builder, in.Values...)
builder.WriteByte(')')
@ -244,33 +108,16 @@ type Eq struct {
func (eq Eq) Build(builder Builder) {
builder.WriteQuoted(eq.Column)
switch eq.Value.(type) {
case []string, []int, []int32, []int64, []uint, []uint32, []uint64, []interface{}:
rv := reflect.ValueOf(eq.Value)
if rv.Len() == 0 {
builder.WriteString(" IN (NULL)")
} else {
builder.WriteString(" IN (")
for i := 0; i < rv.Len(); i++ {
if i > 0 {
builder.WriteByte(',')
}
builder.AddVar(builder, rv.Index(i).Interface())
}
builder.WriteByte(')')
}
default:
if eqNil(eq.Value) {
builder.WriteString(" IS NULL")
} else {
builder.WriteString(" = ")
builder.AddVar(builder, eq.Value)
}
if eq.Value == nil {
builder.WriteString(" IS NULL")
} else {
builder.WriteString(" = ")
builder.AddVar(builder, eq.Value)
}
}
func (eq Eq) NegationBuild(builder Builder) {
Neq(eq).Build(builder)
Neq{eq.Column, eq.Value}.Build(builder)
}
// Neq not equal to for where
@ -279,29 +126,16 @@ type Neq Eq
func (neq Neq) Build(builder Builder) {
builder.WriteQuoted(neq.Column)
switch neq.Value.(type) {
case []string, []int, []int32, []int64, []uint, []uint32, []uint64, []interface{}:
builder.WriteString(" NOT IN (")
rv := reflect.ValueOf(neq.Value)
for i := 0; i < rv.Len(); i++ {
if i > 0 {
builder.WriteByte(',')
}
builder.AddVar(builder, rv.Index(i).Interface())
}
builder.WriteByte(')')
default:
if eqNil(neq.Value) {
builder.WriteString(" IS NOT NULL")
} else {
builder.WriteString(" <> ")
builder.AddVar(builder, neq.Value)
}
if neq.Value == nil {
builder.WriteString(" IS NOT NULL")
} else {
builder.WriteString(" <> ")
builder.AddVar(builder, neq.Value)
}
}
func (neq Neq) NegationBuild(builder Builder) {
Eq(neq).Build(builder)
Eq{neq.Column, neq.Value}.Build(builder)
}
// Gt greater than for where
@ -314,7 +148,7 @@ func (gt Gt) Build(builder Builder) {
}
func (gt Gt) NegationBuild(builder Builder) {
Lte(gt).Build(builder)
Lte{gt.Column, gt.Value}.Build(builder)
}
// Gte greater than or equal to for where
@ -327,7 +161,7 @@ func (gte Gte) Build(builder Builder) {
}
func (gte Gte) NegationBuild(builder Builder) {
Lt(gte).Build(builder)
Lt{gte.Column, gte.Value}.Build(builder)
}
// Lt less than for where
@ -340,7 +174,7 @@ func (lt Lt) Build(builder Builder) {
}
func (lt Lt) NegationBuild(builder Builder) {
Gte(lt).Build(builder)
Gte{lt.Column, lt.Value}.Build(builder)
}
// Lte less than or equal to for where
@ -353,7 +187,7 @@ func (lte Lte) Build(builder Builder) {
}
func (lte Lte) NegationBuild(builder Builder) {
Gt(lte).Build(builder)
Gt{lte.Column, lte.Value}.Build(builder)
}
// Like whether string matches regular expression
@ -370,16 +204,3 @@ func (like Like) NegationBuild(builder Builder) {
builder.WriteString(" NOT LIKE ")
builder.AddVar(builder, like.Value)
}
func eqNil(value interface{}) bool {
if valuer, ok := value.(driver.Valuer); ok && !eqNilReflect(valuer) {
value, _ = valuer.Value()
}
return value == nil || eqNilReflect(value)
}
func eqNilReflect(value interface{}) bool {
reflectValue := reflect.ValueOf(value)
return reflectValue.Kind() == reflect.Ptr && reflectValue.IsNil()
}
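These comparison expressions can be passed to query methods directly; a sketch that assumes recent GORM behaviour, where nil maps to IS NULL / IS NOT NULL and (on the master side of this diff) slice values map to IN / NOT IN:
db.Where(clause.Eq{Column: "deleted_by", Value: nil}).Find(&users)  // `deleted_by` IS NULL
db.Where(clause.Neq{Column: "role", Value: "guest"}).Find(&users)   // `role` <> ?
db.Where(clause.IN{Column: "role", Values: []interface{}{"admin", "maintainer"}}).Find(&users) // `role` IN (?,?)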

View File

@ -1,9 +1,7 @@
package clause_test
import (
"database/sql"
"fmt"
"reflect"
"sync"
"testing"
@ -35,203 +33,3 @@ func TestExpr(t *testing.T) {
})
}
}
func TestNamedExpr(t *testing.T) {
type Base struct {
Name2 string
}
type NamedArgument struct {
Name1 string
Base
}
results := []struct {
SQL string
Result string
Vars []interface{}
ExpectedVars []interface{}
}{{
SQL: "create table ? (? ?, ? ?)",
Vars: []interface{}{clause.Table{Name: "users"}, clause.Column{Name: "id"}, clause.Expr{SQL: "int"}, clause.Column{Name: "name"}, clause.Expr{SQL: "text"}},
Result: "create table `users` (`id` int, `name` text)",
}, {
SQL: "name1 = @name AND name2 = @name",
Vars: []interface{}{sql.Named("name", "jinzhu")},
Result: "name1 = ? AND name2 = ?",
ExpectedVars: []interface{}{"jinzhu", "jinzhu"},
}, {
SQL: "name1 = @name AND name2 = @@name",
Vars: []interface{}{map[string]interface{}{"name": "jinzhu"}},
Result: "name1 = ? AND name2 = @@name",
ExpectedVars: []interface{}{"jinzhu"},
}, {
SQL: "name1 = @name1 AND name2 = @name2 AND name3 = @name1",
Vars: []interface{}{sql.Named("name1", "jinzhu"), sql.Named("name2", "jinzhu2")},
Result: "name1 = ? AND name2 = ? AND name3 = ?",
ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu"},
}, {
SQL: "name1 = @name1 AND name2 = @name2 AND name3 = @name1",
Vars: []interface{}{map[string]interface{}{"name1": "jinzhu", "name2": "jinzhu2"}},
Result: "name1 = ? AND name2 = ? AND name3 = ?",
ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu"},
}, {
SQL: "@@test AND name1 = @name1 AND name2 = @name2 AND name3 = @name1 @notexist",
Vars: []interface{}{sql.Named("name1", "jinzhu"), sql.Named("name2", "jinzhu2")},
Result: "@@test AND name1 = ? AND name2 = ? AND name3 = ? @notexist",
ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu"},
}, {
SQL: "@@test AND name1 = @Name1 AND name2 = @Name2 AND name3 = @Name1 @notexist",
Vars: []interface{}{NamedArgument{Name1: "jinzhu", Base: Base{Name2: "jinzhu2"}}},
Result: "@@test AND name1 = ? AND name2 = ? AND name3 = ? @notexist",
ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu"},
}, {
SQL: "create table ? (? ?, ? ?)",
Vars: []interface{}{},
Result: "create table ? (? ?, ? ?)",
}, {
SQL: "name1 = @name AND name2 = @name;",
Vars: []interface{}{sql.Named("name", "jinzhu")},
Result: "name1 = ? AND name2 = ?;",
ExpectedVars: []interface{}{"jinzhu", "jinzhu"},
}, {
SQL: "name1 = @name1\r\n AND name2 = @name2",
Vars: []interface{}{map[string]interface{}{"name1": "jinzhu", "name2": "jinzhu"}},
Result: "name1 = ?\r\n AND name2 = ?",
ExpectedVars: []interface{}{"jinzhu", "jinzhu"},
}, {
SQL: "name1 = @name1\r AND name2 = @name2",
Vars: []interface{}{map[string]interface{}{"name1": "jinzhu", "name2": "jinzhu"}},
Result: "name1 = ?\r AND name2 = ?",
ExpectedVars: []interface{}{"jinzhu", "jinzhu"},
}, {
SQL: "?",
Vars: []interface{}{clause.Column{Table: "table", Name: "col"}},
Result: "`table`.`col`",
}, {
SQL: "?",
Vars: []interface{}{clause.Column{Table: "table", Name: "col", Raw: true}},
Result: "table.col",
}, {
SQL: "?",
Vars: []interface{}{clause.Column{Table: "table", Name: clause.PrimaryKey, Raw: true}},
Result: "table.id",
}, {
SQL: "?",
Vars: []interface{}{clause.Column{Table: "table", Name: "col", Alias: "alias"}},
Result: "`table`.`col` AS `alias`",
}, {
SQL: "?",
Vars: []interface{}{clause.Column{Table: "table", Name: "col", Alias: "alias", Raw: true}},
Result: "table.col AS alias",
}, {
SQL: "?",
Vars: []interface{}{clause.Table{Name: "table", Alias: "alias"}},
Result: "`table` `alias`",
}, {
SQL: "?",
Vars: []interface{}{clause.Table{Name: "table", Alias: "alias", Raw: true}},
Result: "table alias",
}}
for idx, result := range results {
t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) {
user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
stmt := &gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}}
clause.NamedExpr{SQL: result.SQL, Vars: result.Vars}.Build(stmt)
if stmt.SQL.String() != result.Result {
t.Errorf("generated SQL is not equal, expects %v, but got %v", result.Result, stmt.SQL.String())
}
if !reflect.DeepEqual(result.ExpectedVars, stmt.Vars) {
t.Errorf("generated vars is not equal, expects %v, but got %v", result.ExpectedVars, stmt.Vars)
}
})
}
}
func TestExpression(t *testing.T) {
column := "column-name"
results := []struct {
Expressions []clause.Expression
ExpectedVars []interface{}
Result string
}{{
Expressions: []clause.Expression{
clause.Eq{Column: column, Value: "column-value"},
},
ExpectedVars: []interface{}{"column-value"},
Result: "`column-name` = ?",
}, {
Expressions: []clause.Expression{
clause.Eq{Column: column, Value: nil},
clause.Eq{Column: column, Value: (*string)(nil)},
clause.Eq{Column: column, Value: (*int)(nil)},
clause.Eq{Column: column, Value: (*bool)(nil)},
clause.Eq{Column: column, Value: (interface{})(nil)},
clause.Eq{Column: column, Value: sql.NullString{String: "", Valid: false}},
},
Result: "`column-name` IS NULL",
}, {
Expressions: []clause.Expression{
clause.Neq{Column: column, Value: "column-value"},
},
ExpectedVars: []interface{}{"column-value"},
Result: "`column-name` <> ?",
}, {
Expressions: []clause.Expression{
clause.Neq{Column: column, Value: nil},
clause.Neq{Column: column, Value: (*string)(nil)},
clause.Neq{Column: column, Value: (*int)(nil)},
clause.Neq{Column: column, Value: (*bool)(nil)},
clause.Neq{Column: column, Value: (interface{})(nil)},
},
Result: "`column-name` IS NOT NULL",
}, {
Expressions: []clause.Expression{
clause.Eq{Column: column, Value: []string{"a", "b"}},
},
ExpectedVars: []interface{}{"a", "b"},
Result: "`column-name` IN (?,?)",
}, {
Expressions: []clause.Expression{
clause.Neq{Column: column, Value: []string{"a", "b"}},
},
ExpectedVars: []interface{}{"a", "b"},
Result: "`column-name` NOT IN (?,?)",
}, {
Expressions: []clause.Expression{
clause.Eq{Column: column, Value: []string{}},
},
Result: "`column-name` IN (NULL)",
}, {
Expressions: []clause.Expression{
clause.Eq{Column: clause.Expr{SQL: "SUM(?)", Vars: []interface{}{clause.Column{Name: "id"}}}, Value: 100},
},
ExpectedVars: []interface{}{100},
Result: "SUM(`id`) = ?",
}, {
Expressions: []clause.Expression{
clause.Gte{Column: clause.Expr{SQL: "SUM(?)", Vars: []interface{}{clause.Column{Table: "users", Name: "id"}}}, Value: 100},
},
ExpectedVars: []interface{}{100},
Result: "SUM(`users`.`id`) >= ?",
}}
for idx, result := range results {
for idy, expression := range result.Expressions {
t.Run(fmt.Sprintf("case #%v.%v", idx, idy), func(t *testing.T) {
user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
stmt := &gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}}
expression.Build(stmt)
if stmt.SQL.String() != result.Result {
t.Errorf("generated SQL is not equal, expects %v, but got %v", result.Result, stmt.SQL.String())
}
if !reflect.DeepEqual(result.ExpectedVars, stmt.Vars) {
t.Errorf("generated vars is not equal, expects %v, but got %v", result.ExpectedVars, stmt.Vars)
}
})
}
}
}

View File

@ -33,5 +33,9 @@ func (from From) Build(builder Builder) {
// MergeClause merge from clause
func (from From) MergeClause(clause *Clause) {
if v, ok := clause.Expression.(From); ok {
from.Tables = append(v.Tables, from.Tables...)
from.Joins = append(v.Joins, from.Joins...)
}
clause.Expression = from
}

View File

@ -38,16 +38,6 @@ func TestFrom(t *testing.T) {
[]clause.Interface{
clause.Select{}, clause.From{
Tables: []clause.Table{{Name: "users"}},
Joins: []clause.Join{
{
Type: clause.RightJoin,
Table: clause.Table{Name: "profiles"},
ON: clause.Where{
[]clause.Expression{clause.Eq{clause.Column{Table: "profiles", Name: "email"}, clause.Column{Table: clause.CurrentTable, Name: "email"}}},
},
},
},
}, clause.From{
Joins: []clause.Join{
{
Type: clause.InnerJoin,
@ -61,9 +51,19 @@ func TestFrom(t *testing.T) {
Using: []string{"company_name"},
},
},
}, clause.From{
Joins: []clause.Join{
{
Type: clause.RightJoin,
Table: clause.Table{Name: "profiles"},
ON: clause.Where{
[]clause.Expression{clause.Eq{clause.Column{Table: "profiles", Name: "email"}, clause.Column{Table: clause.CurrentTable, Name: "email"}}},
},
},
},
},
},
"SELECT * FROM `users` INNER JOIN `articles` ON `articles`.`id` = `users`.`id` LEFT JOIN `companies` USING (`company_name`)", nil,
"SELECT * FROM `users` INNER JOIN `articles` ON `articles`.`id` = `users`.`id` LEFT JOIN `companies` USING (`company_name`) RIGHT JOIN `profiles` ON `profiles`.`email` = `users`.`email`", nil,
},
}

View File

@ -30,19 +30,8 @@ func (groupBy GroupBy) Build(builder Builder) {
// MergeClause merge group by clause
func (groupBy GroupBy) MergeClause(clause *Clause) {
if v, ok := clause.Expression.(GroupBy); ok {
copiedColumns := make([]Column, len(v.Columns))
copy(copiedColumns, v.Columns)
groupBy.Columns = append(copiedColumns, groupBy.Columns...)
copiedHaving := make([]Expression, len(v.Having))
copy(copiedHaving, v.Having)
groupBy.Having = append(copiedHaving, groupBy.Having...)
groupBy.Columns = append(v.Columns, groupBy.Columns...)
groupBy.Having = append(v.Having, groupBy.Having...)
}
clause.Expression = groupBy
if len(groupBy.Columns) == 0 {
clause.Name = ""
} else {
clause.Name = groupBy.Name()
}
}

View File

@ -18,8 +18,7 @@ func TestGroupBy(t *testing.T) {
Columns: []clause.Column{{Name: "role"}},
Having: []clause.Expression{clause.Eq{"role", "admin"}},
}},
"SELECT * FROM `users` GROUP BY `role` HAVING `role` = ?",
[]interface{}{"admin"},
"SELECT * FROM `users` GROUP BY `role` HAVING `role` = ?", []interface{}{"admin"},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.GroupBy{
@ -29,8 +28,7 @@ func TestGroupBy(t *testing.T) {
Columns: []clause.Column{{Name: "gender"}},
Having: []clause.Expression{clause.Neq{"gender", "U"}},
}},
"SELECT * FROM `users` GROUP BY `role`,`gender` HAVING `role` = ? AND `gender` <> ?",
[]interface{}{"admin", "U"},
"SELECT * FROM `users` GROUP BY `role`,`gender` HAVING `role` = ? AND `gender` <> ?", []interface{}{"admin", "U"},
},
}

View File

@ -1,41 +1,15 @@
package clause
import "gorm.io/gorm/utils"
type JoinType string
const (
CrossJoin JoinType = "CROSS"
InnerJoin JoinType = "INNER"
LeftJoin JoinType = "LEFT"
RightJoin JoinType = "RIGHT"
InnerJoin = "INNER"
LeftJoin = "LEFT"
RightJoin = "RIGHT"
)
type JoinTarget struct {
Type JoinType
Association string
Subquery Expression
Table string
}
func Has(name string) JoinTarget {
return JoinTarget{Type: InnerJoin, Association: name}
}
func (jt JoinType) Association(name string) JoinTarget {
return JoinTarget{Type: jt, Association: name}
}
func (jt JoinType) AssociationFrom(name string, subquery Expression) JoinTarget {
return JoinTarget{Type: jt, Association: name, Subquery: subquery}
}
func (jt JoinTarget) As(name string) JoinTarget {
jt.Table = name
return jt
}
// Join clause for from
// Join join clause for from
type Join struct {
Type JoinType
Table Table
@ -44,12 +18,6 @@ type Join struct {
Expression Expression
}
func JoinTable(names ...string) Table {
return Table{
Name: utils.JoinNestedRelationNames(names),
}
}
func (join Join) Build(builder Builder) {
if join.Expression != nil {
join.Expression.Build(builder)

View File

@ -1,101 +0,0 @@
package clause_test
import (
"sync"
"testing"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils/tests"
)
func TestJoin(t *testing.T) {
results := []struct {
name string
join clause.Join
sql string
}{
{
name: "LEFT JOIN",
join: clause.Join{
Type: clause.LeftJoin,
Table: clause.Table{Name: "user"},
ON: clause.Where{
Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}},
},
},
sql: "LEFT JOIN `user` ON `user_info`.`user_id` = `users`.`id`",
},
{
name: "RIGHT JOIN",
join: clause.Join{
Type: clause.RightJoin,
Table: clause.Table{Name: "user"},
ON: clause.Where{
Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}},
},
},
sql: "RIGHT JOIN `user` ON `user_info`.`user_id` = `users`.`id`",
},
{
name: "INNER JOIN",
join: clause.Join{
Type: clause.InnerJoin,
Table: clause.Table{Name: "user"},
ON: clause.Where{
Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}},
},
},
sql: "INNER JOIN `user` ON `user_info`.`user_id` = `users`.`id`",
},
{
name: "CROSS JOIN",
join: clause.Join{
Type: clause.CrossJoin,
Table: clause.Table{Name: "user"},
ON: clause.Where{
Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}},
},
},
sql: "CROSS JOIN `user` ON `user_info`.`user_id` = `users`.`id`",
},
{
name: "USING",
join: clause.Join{
Type: clause.InnerJoin,
Table: clause.Table{Name: "user"},
Using: []string{"id"},
},
sql: "INNER JOIN `user` USING (`id`)",
},
{
name: "Expression",
join: clause.Join{
// Invalid
Type: clause.LeftJoin,
Table: clause.Table{Name: "user"},
ON: clause.Where{
Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}},
},
// Valid
Expression: clause.Join{
Type: clause.InnerJoin,
Table: clause.Table{Name: "user"},
Using: []string{"id"},
},
},
sql: "INNER JOIN `user` USING (`id`)",
},
}
for _, result := range results {
t.Run(result.name, func(t *testing.T) {
user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
stmt := &gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}}
result.join.Build(stmt)
if result.sql != stmt.SQL.String() {
t.Errorf("want: %s, got: %s", result.sql, stmt.SQL.String())
}
})
}
}

View File

@ -1,8 +1,10 @@
package clause
import "strconv"
// Limit limit clause
type Limit struct {
Limit *int
Limit int
Offset int
}
@ -13,16 +15,14 @@ func (limit Limit) Name() string {
// Build build where clause
func (limit Limit) Build(builder Builder) {
if limit.Limit != nil && *limit.Limit >= 0 {
if limit.Limit > 0 {
builder.WriteString("LIMIT ")
builder.AddVar(builder, *limit.Limit)
}
if limit.Offset > 0 {
if limit.Limit != nil && *limit.Limit >= 0 {
builder.WriteByte(' ')
builder.WriteString(strconv.Itoa(limit.Limit))
if limit.Offset > 0 {
builder.WriteString(" OFFSET ")
builder.WriteString(strconv.Itoa(limit.Offset))
}
builder.WriteString("OFFSET ")
builder.AddVar(builder, limit.Offset)
}
}
@ -31,8 +31,10 @@ func (limit Limit) MergeClause(clause *Clause) {
clause.Name = ""
if v, ok := clause.Expression.(Limit); ok {
if (limit.Limit == nil || *limit.Limit == 0) && v.Limit != nil {
if limit.Limit == 0 && v.Limit > 0 {
limit.Limit = v.Limit
} else if limit.Limit < 0 {
limit.Limit = 0
}
if limit.Offset == 0 && v.Offset > 0 {

View File

@ -8,10 +8,6 @@ import (
)
func TestLimit(t *testing.T) {
limit0 := 0
limit10 := 10
limit50 := 50
limitNeg10 := -10
results := []struct {
Clauses []clause.Interface
Result string
@ -19,56 +15,26 @@ func TestLimit(t *testing.T) {
}{
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{
Limit: &limit10,
Limit: 10,
Offset: 20,
}},
"SELECT * FROM `users` LIMIT ? OFFSET ?",
[]interface{}{limit10, 20},
"SELECT * FROM `users` LIMIT 10 OFFSET 20", nil,
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit0}},
"SELECT * FROM `users` LIMIT ?",
[]interface{}{limit0},
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}},
"SELECT * FROM `users` LIMIT 10 OFFSET 30", nil,
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit0}, clause.Limit{Offset: 0}},
"SELECT * FROM `users` LIMIT ?",
[]interface{}{limit0},
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Offset: -10}},
"SELECT * FROM `users` LIMIT 10", nil,
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}},
"SELECT * FROM `users` OFFSET ?",
[]interface{}{20},
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: -10}},
"SELECT * FROM `users`", nil,
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}, clause.Limit{Offset: 30}},
"SELECT * FROM `users` OFFSET ?",
[]interface{}{30},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}, clause.Limit{Limit: &limit10}},
"SELECT * FROM `users` LIMIT ? OFFSET ?",
[]interface{}{limit10, 20},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}},
"SELECT * FROM `users` LIMIT ? OFFSET ?",
[]interface{}{limit10, 30},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Offset: -10}},
"SELECT * FROM `users` LIMIT ?",
[]interface{}{limit10},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: &limitNeg10}},
"SELECT * FROM `users` OFFSET ?",
[]interface{}{30},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: &limit50}},
"SELECT * FROM `users` LIMIT ? OFFSET ?",
[]interface{}{limit50, 30},
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: 50}},
"SELECT * FROM `users` LIMIT 50 OFFSET 30", nil,
},
}

View File

@ -1,11 +1,8 @@
package clause
const (
LockingStrengthUpdate = "UPDATE"
LockingStrengthShare = "SHARE"
LockingOptionsSkipLocked = "SKIP LOCKED"
LockingOptionsNoWait = "NOWAIT"
)
type For struct {
Lockings []Locking
}
type Locking struct {
Strength string
@ -14,25 +11,38 @@ type Locking struct {
}
// Name where clause name
func (locking Locking) Name() string {
func (f For) Name() string {
return "FOR"
}
// Build build where clause
func (locking Locking) Build(builder Builder) {
builder.WriteString(locking.Strength)
if locking.Table.Name != "" {
builder.WriteString(" OF ")
builder.WriteQuoted(locking.Table)
}
func (f For) Build(builder Builder) {
for idx, locking := range f.Lockings {
if idx > 0 {
builder.WriteByte(' ')
}
if locking.Options != "" {
builder.WriteByte(' ')
builder.WriteString(locking.Options)
builder.WriteString("FOR ")
builder.WriteString(locking.Strength)
if locking.Table.Name != "" {
builder.WriteString(" OF ")
builder.WriteQuoted(locking.Table)
}
if locking.Options != "" {
builder.WriteByte(' ')
builder.WriteString(locking.Options)
}
}
}
// MergeClause merge order by clauses
func (locking Locking) MergeClause(clause *Clause) {
clause.Expression = locking
func (f For) MergeClause(clause *Clause) {
clause.Name = ""
if v, ok := clause.Expression.(For); ok {
f.Lockings = append(v.Lockings, f.Lockings...)
}
clause.Expression = f
}
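Row locking is normally injected through db.Clauses; a sketch using the master-side clause.Locking shape and constants (the v0.2.0 side wraps lockings in a For clause instead):
db.Clauses(clause.Locking{Strength: clause.LockingStrengthUpdate}).Find(&users) // ... FOR UPDATE
db.Clauses(clause.Locking{Strength: "SHARE", Table: clause.Table{Name: clause.CurrentTable}, Options: "NOWAIT"}).Find(&users) // ... FOR SHARE OF `users` NOWAIT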

View File

@ -7,27 +7,31 @@ import (
"gorm.io/gorm/clause"
)
func TestLocking(t *testing.T) {
func TestFor(t *testing.T) {
results := []struct {
Clauses []clause.Interface
Result string
Vars []interface{}
}{
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: clause.LockingStrengthUpdate}},
[]clause.Interface{clause.Select{}, clause.From{}, clause.For{
Lockings: []clause.Locking{{Strength: "UPDATE"}},
}},
"SELECT * FROM `users` FOR UPDATE", nil,
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: clause.LockingStrengthShare, Table: clause.Table{Name: clause.CurrentTable}}},
"SELECT * FROM `users` FOR SHARE OF `users`", nil,
[]clause.Interface{clause.Select{}, clause.From{}, clause.For{
Lockings: []clause.Locking{{Strength: "UPDATE"}, {Strength: "SHARE", Table: clause.Table{Name: clause.CurrentTable}}},
}},
"SELECT * FROM `users` FOR UPDATE FOR SHARE OF `users`", nil,
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: clause.LockingStrengthUpdate, Options: clause.LockingOptionsNoWait}},
"SELECT * FROM `users` FOR UPDATE NOWAIT", nil,
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: clause.LockingStrengthUpdate, Options: clause.LockingOptionsSkipLocked}},
"SELECT * FROM `users` FOR UPDATE SKIP LOCKED", nil,
[]clause.Interface{clause.Select{}, clause.From{}, clause.For{
Lockings: []clause.Locking{{Strength: "UPDATE"}, {Strength: "SHARE", Table: clause.Table{Name: clause.CurrentTable}}},
}, clause.For{
Lockings: []clause.Locking{{Strength: "UPDATE", Options: "NOWAIT"}},
}},
"SELECT * FROM `users` FOR UPDATE FOR SHARE OF `users` FOR UPDATE NOWAIT", nil,
},
}

View File

@ -1,13 +1,10 @@
package clause
type OnConflict struct {
Columns []Column
Where Where
TargetWhere Where
OnConstraint string
DoNothing bool
DoUpdates Set
UpdateAll bool
Columns []Column
Where Where
DoNothing bool
DoUpdates Set
}
func (OnConflict) Name() string {
@ -16,27 +13,15 @@ func (OnConflict) Name() string {
// Build build onConflict clause
func (onConflict OnConflict) Build(builder Builder) {
if onConflict.OnConstraint != "" {
builder.WriteString("ON CONSTRAINT ")
builder.WriteString(onConflict.OnConstraint)
if len(onConflict.Columns) > 0 {
builder.WriteQuoted(onConflict.Columns) // FIXME columns
builder.WriteByte(' ')
} else {
if len(onConflict.Columns) > 0 {
builder.WriteByte('(')
for idx, column := range onConflict.Columns {
if idx > 0 {
builder.WriteByte(',')
}
builder.WriteQuoted(column)
}
builder.WriteString(`) `)
}
}
if len(onConflict.TargetWhere.Exprs) > 0 {
builder.WriteString(" WHERE ")
onConflict.TargetWhere.Build(builder)
builder.WriteByte(' ')
}
if len(onConflict.Where.Exprs) > 0 {
builder.WriteString("WHERE ")
onConflict.Where.Build(builder)
builder.WriteByte(' ')
}
if onConflict.DoNothing {
@ -45,12 +30,6 @@ func (onConflict OnConflict) Build(builder Builder) {
builder.WriteString("DO UPDATE SET ")
onConflict.DoUpdates.Build(builder)
}
if len(onConflict.Where.Exprs) > 0 {
builder.WriteString(" WHERE ")
onConflict.Where.Build(builder)
builder.WriteByte(' ')
}
}
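OnConflict drives upserts via db.Clauses; a sketch in which DoNothing exists on both sides of this diff, while clause.AssignmentColumns appears only on the master side:
// insert, ignoring rows that hit a conflict
db.Clauses(clause.OnConflict{DoNothing: true}).Create(&users)
// insert, updating selected columns on conflict
db.Clauses(clause.OnConflict{
	Columns:   []clause.Column{{Name: "id"}},
	DoUpdates: clause.AssignmentColumns([]string{"name", "age"}),
}).Create(&users)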
// MergeClause merge onConflict clauses

View File

@ -7,8 +7,7 @@ type OrderByColumn struct {
}
type OrderBy struct {
Columns []OrderByColumn
Expression Expression
Columns []OrderByColumn
}
// Name where clause name
@ -18,18 +17,14 @@ func (orderBy OrderBy) Name() string {
// Build build where clause
func (orderBy OrderBy) Build(builder Builder) {
if orderBy.Expression != nil {
orderBy.Expression.Build(builder)
} else {
for idx, column := range orderBy.Columns {
if idx > 0 {
builder.WriteByte(',')
}
for idx, column := range orderBy.Columns {
if idx > 0 {
builder.WriteByte(',')
}
builder.WriteQuoted(column.Column)
if column.Desc {
builder.WriteString(" DESC")
}
builder.WriteQuoted(column.Column)
if column.Desc {
builder.WriteString(" DESC")
}
}
}
@ -45,9 +40,7 @@ func (orderBy OrderBy) MergeClause(clause *Clause) {
}
}
copiedColumns := make([]OrderByColumn, len(v.Columns))
copy(copiedColumns, v.Columns)
orderBy.Columns = append(copiedColumns, orderBy.Columns...)
orderBy.Columns = append(v.Columns, orderBy.Columns...)
}
clause.Expression = orderBy

View File

@ -39,15 +39,6 @@ func TestOrderBy(t *testing.T) {
},
"SELECT * FROM `users` ORDER BY `name`", nil,
},
{
[]clause.Interface{
clause.Select{}, clause.From{}, clause.OrderBy{
Expression: clause.Expr{SQL: "FIELD(id, ?)", Vars: []interface{}{[]int{1, 2, 3}}, WithoutParentheses: true},
},
},
"SELECT * FROM `users` ORDER BY FIELD(id, ?,?,?)",
[]interface{}{1, 2, 3},
},
}
for idx, result := range results {

View File

@ -11,27 +11,20 @@ func (returning Returning) Name() string {
// Build build where clause
func (returning Returning) Build(builder Builder) {
if len(returning.Columns) > 0 {
for idx, column := range returning.Columns {
if idx > 0 {
builder.WriteByte(',')
}
builder.WriteQuoted(column)
for idx, column := range returning.Columns {
if idx > 0 {
builder.WriteByte(',')
}
} else {
builder.WriteByte('*')
builder.WriteQuoted(column)
}
}
// MergeClause merge order by clauses
func (returning Returning) MergeClause(clause *Clause) {
if v, ok := clause.Expression.(Returning); ok && len(returning.Columns) > 0 {
if v.Columns != nil {
returning.Columns = append(v.Columns, returning.Columns...)
} else {
returning.Columns = nil
}
if v, ok := clause.Expression.(Returning); ok {
returning.Columns = append(v.Columns, returning.Columns...)
}
clause.Expression = returning
}
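Returning is likewise applied through db.Clauses on databases that support it (e.g. PostgreSQL); a hedged sketch, noting that the bare RETURNING * form is master-side behaviour:
db.Clauses(clause.Returning{Columns: []clause.Column{{Name: "id"}, {Name: "name"}}}).Create(&users)
db.Clauses(clause.Returning{}).Delete(&users) // RETURNING * on the master side of this diff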

View File

@ -26,22 +26,6 @@ func TestReturning(t *testing.T) {
}},
"SELECT * FROM `users` RETURNING `users`.`id`,`name`,`age`", nil,
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Returning{
[]clause.Column{clause.PrimaryColumn},
}, clause.Returning{}, clause.Returning{
[]clause.Column{{Name: "name"}, {Name: "age"}},
}},
"SELECT * FROM `users` RETURNING *", nil,
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Returning{
[]clause.Column{clause.PrimaryColumn},
}, clause.Returning{
[]clause.Column{{Name: "name"}, {Name: "age"}},
}, clause.Returning{}},
"SELECT * FROM `users` RETURNING *", nil,
},
}
for idx, result := range results {

View File

@ -2,7 +2,6 @@ package clause
// Select select attrs when querying, updating, creating
type Select struct {
Distinct bool
Columns []Column
Expression Expression
}
@ -13,10 +12,6 @@ func (s Select) Name() string {
func (s Select) Build(builder Builder) {
if len(s.Columns) > 0 {
if s.Distinct {
builder.WriteString("DISTINCT ")
}
for idx, column := range s.Columns {
if idx > 0 {
builder.WriteByte(',')
@ -30,30 +25,8 @@ func (s Select) Build(builder Builder) {
func (s Select) MergeClause(clause *Clause) {
if s.Expression != nil {
if s.Distinct {
if expr, ok := s.Expression.(Expr); ok {
expr.SQL = "DISTINCT " + expr.SQL
clause.Expression = expr
return
}
}
clause.Expression = s.Expression
} else {
clause.Expression = s
}
}
// CommaExpression represents a group of expressions separated by commas.
type CommaExpression struct {
Exprs []Expression
}
func (comma CommaExpression) Build(builder Builder) {
for idx, expr := range comma.Exprs {
if idx > 0 {
_, _ = builder.WriteString(", ")
}
expr.Build(builder)
}
}

View File

@ -31,37 +31,6 @@ func TestSelect(t *testing.T) {
}, clause.From{}},
"SELECT `name` FROM `users`", nil,
},
{
[]clause.Interface{clause.Select{
Expression: clause.CommaExpression{
Exprs: []clause.Expression{
clause.NamedExpr{"?", []interface{}{clause.Column{Name: "id"}}},
clause.NamedExpr{"?", []interface{}{clause.Column{Name: "name"}}},
clause.NamedExpr{"LENGTH(?)", []interface{}{clause.Column{Name: "mobile"}}},
},
},
}, clause.From{}},
"SELECT `id`, `name`, LENGTH(`mobile`) FROM `users`", nil,
},
{
[]clause.Interface{clause.Select{
Expression: clause.CommaExpression{
Exprs: []clause.Expression{
clause.Expr{
SQL: "? as name",
Vars: []interface{}{
clause.Eq{
Column: clause.Column{Name: "age"},
Value: 18,
},
},
},
},
},
}, clause.From{}},
"SELECT `age` = ? as name FROM `users`",
[]interface{}{18},
},
}
for idx, result := range results {

View File

@ -1,7 +1,5 @@
package clause
import "sort"
type Set []Assignment
type Assignment struct {
@ -24,37 +22,13 @@ func (set Set) Build(builder Builder) {
builder.AddVar(builder, assignment.Value)
}
} else {
builder.WriteQuoted(Column{Name: PrimaryKey})
builder.WriteQuoted(PrimaryColumn)
builder.WriteByte('=')
builder.WriteQuoted(Column{Name: PrimaryKey})
builder.WriteQuoted(PrimaryColumn)
}
}
// MergeClause merge assignments clauses
func (set Set) MergeClause(clause *Clause) {
copiedAssignments := make([]Assignment, len(set))
copy(copiedAssignments, set)
clause.Expression = Set(copiedAssignments)
}
func Assignments(values map[string]interface{}) Set {
keys := make([]string, 0, len(values))
for key := range values {
keys = append(keys, key)
}
sort.Strings(keys)
assignments := make([]Assignment, len(keys))
for idx, key := range keys {
assignments[idx] = Assignment{Column: Column{Name: key}, Value: values[key]}
}
return assignments
}
func AssignmentColumns(values []string) Set {
assignments := make([]Assignment, len(values))
for idx, value := range values {
assignments[idx] = Assignment{Column: Column{Name: value}, Value: Column{Table: "excluded", Name: value}}
}
return assignments
clause.Expression = set
}

View File

@ -2,8 +2,6 @@ package clause_test
import (
"fmt"
"sort"
"strings"
"testing"
"gorm.io/gorm/clause"
@ -20,8 +18,7 @@ func TestSet(t *testing.T) {
clause.Update{},
clause.Set([]clause.Assignment{{clause.PrimaryColumn, 1}}),
},
"UPDATE `users` SET `users`.`id`=?",
[]interface{}{1},
"UPDATE `users` SET `users`.`id`=?", []interface{}{1},
},
{
[]clause.Interface{
@ -29,8 +26,7 @@ func TestSet(t *testing.T) {
clause.Set([]clause.Assignment{{clause.PrimaryColumn, 1}}),
clause.Set([]clause.Assignment{{clause.Column{Name: "name"}, "jinzhu"}}),
},
"UPDATE `users` SET `name`=?",
[]interface{}{"jinzhu"},
"UPDATE `users` SET `name`=?", []interface{}{"jinzhu"},
},
}
@ -40,20 +36,3 @@ func TestSet(t *testing.T) {
})
}
}
func TestAssignments(t *testing.T) {
set := clause.Assignments(map[string]interface{}{
"name": "jinzhu",
"age": 18,
})
assignments := []clause.Assignment(set)
sort.Slice(assignments, func(i, j int) bool {
return strings.Compare(assignments[i].Column.Name, assignments[j].Column.Name) > 0
})
if len(assignments) != 2 || assignments[0].Column.Name != "name" || assignments[0].Value.(string) != "jinzhu" || assignments[1].Column.Name != "age" || assignments[1].Value.(int) != 18 {
t.Errorf("invalid assignments, got %v", assignments)
}
}

View File

@ -21,8 +21,7 @@ func TestValues(t *testing.T) {
Values: [][]interface{}{{"jinzhu", 18}, {"josh", 1}},
},
},
"INSERT INTO `users` (`name`,`age`) VALUES (?,?),(?,?)",
[]interface{}{"jinzhu", 18, "josh", 1},
"INSERT INTO `users` (`name`,`age`) VALUES (?,?),(?,?)", []interface{}{"jinzhu", 18, "josh", 1},
},
}

View File

@ -1,14 +1,5 @@
package clause
import (
"strings"
)
const (
AndWithSpace = " AND "
OrWithSpace = " OR "
)
// Where where clause
type Where struct {
Exprs []Expression
@ -21,15 +12,9 @@ func (where Where) Name() string {
// Build build where clause
func (where Where) Build(builder Builder) {
if len(where.Exprs) == 1 {
if andCondition, ok := where.Exprs[0].(AndConditions); ok {
where.Exprs = andCondition.Exprs
}
}
// Switch position if the first query expression is a single Or condition
for idx, expr := range where.Exprs {
if v, ok := expr.(OrConditions); !ok || len(v.Exprs) > 1 {
if v, ok := expr.(OrConditions); (!ok && expr != nil) || len(v.Exprs) > 1 {
if idx != 0 {
where.Exprs[0], where.Exprs[idx] = where.Exprs[idx], where.Exprs[0]
}
@ -37,64 +22,27 @@ func (where Where) Build(builder Builder) {
}
}
buildExprs(where.Exprs, builder, AndWithSpace)
}
func buildExprs(exprs []Expression, builder Builder, joinCond string) {
wrapInParentheses := false
for idx, expr := range exprs {
if idx > 0 {
if v, ok := expr.(OrConditions); ok && len(v.Exprs) == 1 {
builder.WriteString(OrWithSpace)
} else {
builder.WriteString(joinCond)
}
}
if len(exprs) > 1 {
switch v := expr.(type) {
case OrConditions:
if len(v.Exprs) == 1 {
if e, ok := v.Exprs[0].(Expr); ok {
sql := strings.ToUpper(e.SQL)
wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace)
}
for idx, expr := range where.Exprs {
if expr != nil {
if idx > 0 {
if v, ok := expr.(OrConditions); ok && len(v.Exprs) == 1 {
builder.WriteString(" OR ")
} else {
builder.WriteString(" AND ")
}
case AndConditions:
if len(v.Exprs) == 1 {
if e, ok := v.Exprs[0].(Expr); ok {
sql := strings.ToUpper(e.SQL)
wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace)
}
}
case Expr:
sql := strings.ToUpper(v.SQL)
wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace)
case NamedExpr:
sql := strings.ToUpper(v.SQL)
wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace)
}
}
if wrapInParentheses {
builder.WriteByte('(')
expr.Build(builder)
builder.WriteByte(')')
wrapInParentheses = false
} else {
expr.Build(builder)
}
}
return
}
// MergeClause merge where clauses
func (where Where) MergeClause(clause *Clause) {
if w, ok := clause.Expression.(Where); ok {
exprs := make([]Expression, len(w.Exprs)+len(where.Exprs))
copy(exprs, w.Exprs)
copy(exprs[len(w.Exprs):], where.Exprs)
where.Exprs = exprs
where.Exprs = append(w.Exprs, where.Exprs...)
}
clause.Expression = where
@ -104,13 +52,6 @@ func And(exprs ...Expression) Expression {
if len(exprs) == 0 {
return nil
}
if len(exprs) == 1 {
if _, ok := exprs[0].(OrConditions); !ok {
return exprs[0]
}
}
return AndConditions{Exprs: exprs}
}
@ -121,10 +62,15 @@ type AndConditions struct {
func (and AndConditions) Build(builder Builder) {
if len(and.Exprs) > 1 {
builder.WriteByte('(')
buildExprs(and.Exprs, builder, AndWithSpace)
}
for idx, c := range and.Exprs {
if idx > 0 {
builder.WriteString(" AND ")
}
c.Build(builder)
}
if len(and.Exprs) > 1 {
builder.WriteByte(')')
} else {
buildExprs(and.Exprs, builder, AndWithSpace)
}
}
@ -142,10 +88,15 @@ type OrConditions struct {
func (or OrConditions) Build(builder Builder) {
if len(or.Exprs) > 1 {
builder.WriteByte('(')
buildExprs(or.Exprs, builder, OrWithSpace)
}
for idx, c := range or.Exprs {
if idx > 0 {
builder.WriteString(" OR ")
}
c.Build(builder)
}
if len(or.Exprs) > 1 {
builder.WriteByte(')')
} else {
buildExprs(or.Exprs, builder, OrWithSpace)
}
}
@ -153,11 +104,6 @@ func Not(exprs ...Expression) Expression {
if len(exprs) == 0 {
return nil
}
if len(exprs) == 1 {
if andCondition, ok := exprs[0].(AndConditions); ok {
exprs = andCondition.Exprs
}
}
return NotConditions{Exprs: exprs}
}
@ -166,80 +112,22 @@ type NotConditions struct {
}
func (not NotConditions) Build(builder Builder) {
anyNegationBuilder := false
for _, c := range not.Exprs {
if _, ok := c.(NegationExpressionBuilder); ok {
anyNegationBuilder = true
break
if len(not.Exprs) > 1 {
builder.WriteByte('(')
}
for idx, c := range not.Exprs {
if idx > 0 {
builder.WriteString(" AND ")
}
if negationBuilder, ok := c.(NegationExpressionBuilder); ok {
negationBuilder.NegationBuild(builder)
} else {
builder.WriteString(" NOT ")
c.Build(builder)
}
}
if anyNegationBuilder {
if len(not.Exprs) > 1 {
builder.WriteByte('(')
}
for idx, c := range not.Exprs {
if idx > 0 {
builder.WriteString(AndWithSpace)
}
if negationBuilder, ok := c.(NegationExpressionBuilder); ok {
negationBuilder.NegationBuild(builder)
} else {
builder.WriteString("NOT ")
e, wrapInParentheses := c.(Expr)
if wrapInParentheses {
sql := strings.ToUpper(e.SQL)
if wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace); wrapInParentheses {
builder.WriteByte('(')
}
}
c.Build(builder)
if wrapInParentheses {
builder.WriteByte(')')
}
}
}
if len(not.Exprs) > 1 {
builder.WriteByte(')')
}
} else {
builder.WriteString("NOT ")
if len(not.Exprs) > 1 {
builder.WriteByte('(')
}
for idx, c := range not.Exprs {
if idx > 0 {
switch c.(type) {
case OrConditions:
builder.WriteString(OrWithSpace)
default:
builder.WriteString(AndWithSpace)
}
}
e, wrapInParentheses := c.(Expr)
if wrapInParentheses {
sql := strings.ToUpper(e.SQL)
if wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace); wrapInParentheses {
builder.WriteByte('(')
}
}
c.Build(builder)
if wrapInParentheses {
builder.WriteByte(')')
}
}
if len(not.Exprs) > 1 {
builder.WriteByte(')')
}
if len(not.Exprs) > 1 {
builder.WriteByte(')')
}
}
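The And/Or/Not constructors compose into nested conditions; a sketch assuming they are passed straight to Where, which accepts clause expressions:
db.Where(
	clause.And(
		clause.Eq{Column: "role", Value: "admin"},
		clause.Or(clause.Gt{Column: "score", Value: 100}, clause.Like{Column: "name", Value: "%linus%"}),
	),
).Find(&users)
// roughly: WHERE `role` = ? AND (`score` > ? OR `name` LIKE ?)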

View File

@ -17,29 +17,25 @@ func TestWhere(t *testing.T) {
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
Exprs: []clause.Expression{clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}, clause.Or(clause.Neq{Column: "name", Value: "jinzhu"})},
}},
"SELECT * FROM `users` WHERE `users`.`id` = ? AND `age` > ? OR `name` <> ?",
[]interface{}{"1", 18, "jinzhu"},
"SELECT * FROM `users` WHERE `users`.`id` = ? AND `age` > ? OR `name` <> ?", []interface{}{"1", 18, "jinzhu"},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
Exprs: []clause.Expression{clause.Or(clause.Neq{Column: "name", Value: "jinzhu"}), clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}},
}},
"SELECT * FROM `users` WHERE `users`.`id` = ? OR `name` <> ? AND `age` > ?",
[]interface{}{"1", "jinzhu", 18},
"SELECT * FROM `users` WHERE `users`.`id` = ? OR `name` <> ? AND `age` > ?", []interface{}{"1", "jinzhu", 18},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
Exprs: []clause.Expression{clause.Or(clause.Neq{Column: "name", Value: "jinzhu"}), clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}},
Exprs: []clause.Expression{clause.Or(), clause.Or(clause.Neq{Column: "name", Value: "jinzhu"}), clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}},
}},
"SELECT * FROM `users` WHERE `users`.`id` = ? OR `name` <> ? AND `age` > ?",
[]interface{}{"1", "jinzhu", 18},
"SELECT * FROM `users` WHERE `users`.`id` = ? OR `name` <> ? AND `age` > ?", []interface{}{"1", "jinzhu", 18},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
Exprs: []clause.Expression{clause.Or(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}), clause.Or(clause.Neq{Column: "name", Value: "jinzhu"})},
}},
"SELECT * FROM `users` WHERE `users`.`id` = ? OR `name` <> ?",
[]interface{}{"1", "jinzhu"},
"SELECT * FROM `users` WHERE `users`.`id` = ? OR `name` <> ?", []interface{}{"1", "jinzhu"},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
@ -47,8 +43,7 @@ func TestWhere(t *testing.T) {
}, clause.Where{
Exprs: []clause.Expression{clause.Or(clause.Gt{Column: "score", Value: 100}, clause.Like{Column: "name", Value: "%linus%"})},
}},
"SELECT * FROM `users` WHERE `users`.`id` = ? AND `age` > ? OR `name` <> ? AND (`score` > ? OR `name` LIKE ?)",
[]interface{}{"1", 18, "jinzhu", 100, "%linus%"},
"SELECT * FROM `users` WHERE `users`.`id` = ? AND `age` > ? OR `name` <> ? AND (`score` > ? OR `name` LIKE ?)", []interface{}{"1", 18, "jinzhu", 100, "%linus%"},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
@ -56,78 +51,7 @@ func TestWhere(t *testing.T) {
}, clause.Where{
Exprs: []clause.Expression{clause.Or(clause.Not(clause.Gt{Column: "score", Value: 100}), clause.Like{Column: "name", Value: "%linus%"})},
}},
"SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `age` <= ?) OR `name` <> ? AND (`score` <= ? OR `name` LIKE ?)",
[]interface{}{"1", 18, "jinzhu", 100, "%linus%"},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
Exprs: []clause.Expression{clause.And(clause.Eq{Column: "age", Value: 18}, clause.Or(clause.Neq{Column: "name", Value: "jinzhu"}))},
}},
"SELECT * FROM `users` WHERE `age` = ? OR `name` <> ?",
[]interface{}{18, "jinzhu"},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
Exprs: []clause.Expression{clause.Not(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}), clause.And(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false})},
}},
"SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `age` <= ?) AND `score` <= ?",
[]interface{}{"1", 18, 100},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
Exprs: []clause.Expression{clause.Not(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}), clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false}},
}},
"SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `age` <= ?) AND `score` <= ?",
[]interface{}{"1", 18, 100},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
Exprs: []clause.Expression{clause.Not(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}), clause.Or(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false})},
}},
"SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `age` <= ?) OR `score` <= ?",
[]interface{}{"1", 18, 100},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
Exprs: []clause.Expression{
clause.And(clause.Not(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}),
clause.And(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false})),
},
}},
"SELECT * FROM `users` WHERE `users`.`id` <> ? AND `score` <= ?",
[]interface{}{"1", 100},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
Exprs: []clause.Expression{clause.Not(clause.Eq{Column: clause.PrimaryColumn, Value: "1"},
clause.And(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false}))},
}},
"SELECT * FROM `users` WHERE (`users`.`id` <> ? AND NOT `score` <= ?)",
[]interface{}{"1", 100},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
Exprs: []clause.Expression{clause.Not(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}},
clause.Expr{SQL: "`age` <= ?", Vars: []interface{}{60}})},
}},
"SELECT * FROM `users` WHERE NOT (`score` <= ? AND `age` <= ?)",
[]interface{}{100, 60},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
Exprs: []clause.Expression{
clause.Not(clause.AndConditions{
Exprs: []clause.Expression{
clause.Eq{Column: clause.PrimaryColumn, Value: "1"},
clause.Gt{Column: "age", Value: 18},
}}, clause.OrConditions{
Exprs: []clause.Expression{
clause.Lt{Column: "score", Value: 100},
},
}),
}}},
"SELECT * FROM `users` WHERE NOT ((`users`.`id` = ? AND `age` > ?) OR `score` < ?)",
[]interface{}{"1", 18, 100},
"SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `age` <= ?) OR `name` <> ? AND (`score` <= ? OR `name` LIKE ?)", []interface{}{"1", 18, "jinzhu", 100, "%linus%"},
},
}


@ -1,3 +1,4 @@
package clause
type With struct{}
type With struct {
}


@ -2,53 +2,25 @@ package gorm
import (
"errors"
"gorm.io/gorm/logger"
)
var (
// ErrRecordNotFound record not found error
ErrRecordNotFound = logger.ErrRecordNotFound
ErrRecordNotFound = errors.New("record not found")
// ErrInvalidSQL invalid SQL error, happens when you passed invalid SQL
ErrInvalidSQL = errors.New("invalid SQL")
// ErrInvalidTransaction invalid transaction when you are trying to `Commit` or `Rollback`
ErrInvalidTransaction = errors.New("invalid transaction")
ErrInvalidTransaction = errors.New("no valid transaction")
// ErrUnaddressable unaddressable value
ErrUnaddressable = errors.New("using unaddressable value")
// ErrNotImplemented not implemented
ErrNotImplemented = errors.New("not implemented")
// ErrMissingWhereClause missing where clause
ErrMissingWhereClause = errors.New("WHERE conditions required")
// ErrUnsupportedRelation unsupported relations
ErrUnsupportedRelation = errors.New("unsupported relations")
// ErrPrimaryKeyRequired primary keys required
ErrPrimaryKeyRequired = errors.New("primary key required")
// ErrModelValueRequired model value required
ErrModelValueRequired = errors.New("model value required")
// ErrModelAccessibleFieldsRequired model accessible fields required
ErrModelAccessibleFieldsRequired = errors.New("model accessible fields required")
// ErrSubQueryRequired sub query required
ErrSubQueryRequired = errors.New("sub query required")
// ErrInvalidData unsupported data
ErrInvalidData = errors.New("unsupported data")
// ErrUnsupportedDriver unsupported driver
ErrUnsupportedDriver = errors.New("unsupported driver")
// ErrRegistered registered
ErrRegistered = errors.New("registered")
// ErrInvalidField invalid field
ErrInvalidField = errors.New("invalid field")
// ErrEmptySlice empty slice found
ErrEmptySlice = errors.New("empty slice found")
// ErrDryRunModeUnsupported dry run mode unsupported
ErrDryRunModeUnsupported = errors.New("dry run mode unsupported")
// ErrInvalidDB invalid db
ErrInvalidDB = errors.New("invalid db")
// ErrInvalidValue invalid value
ErrInvalidValue = errors.New("invalid value, should be pointer to struct or slice")
// ErrInvalidValueOfLength invalid values do not match length
ErrInvalidValueOfLength = errors.New("invalid association values, length doesn't match")
// ErrPreloadNotAllowed preload is not allowed when count is used
ErrPreloadNotAllowed = errors.New("preload is not allowed when count is used")
// ErrDuplicatedKey occurs when there is a unique key constraint violation
ErrDuplicatedKey = errors.New("duplicated key not allowed")
// ErrForeignKeyViolated occurs when there is a foreign key constraint violation
ErrForeignKeyViolated = errors.New("violates foreign key constraint")
// ErrCheckConstraintViolated occurs when there is a check constraint violation
ErrCheckConstraintViolated = errors.New("violates check constraint")
// ErrPtrStructSupported only ptr of struct supported
ErrPtrStructSupported = errors.New("only ptr of struct supported")
// ErrorPrimaryKeyRequired primary keys required
ErrorPrimaryKeyRequired = errors.New("primary key required")
)
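
A short sketch of how these sentinel errors are typically consumed; it assumes the separate gorm.io/driver/sqlite module, which is not part of this diff.

```go
package main

import (
	"errors"
	"fmt"
	"log"

	"gorm.io/driver/sqlite"
	"gorm.io/gorm"
)

type User struct {
	ID   uint
	Name string
}

func main() {
	db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
	if err != nil {
		log.Fatal(err)
	}
	if err := db.AutoMigrate(&User{}); err != nil {
		log.Fatal(err)
	}

	var user User
	// First raises ErrRecordNotFound when nothing matches; compare with errors.Is,
	// not ==, since the master side aliases it to logger.ErrRecordNotFound.
	if err := db.First(&user, 1).Error; errors.Is(err, gorm.ErrRecordNotFound) {
		fmt.Println("no user with id 1 yet")
	}
}
```
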


@ -1,394 +1,183 @@
package gorm
import (
"context"
"database/sql"
"errors"
"fmt"
"hash/maphash"
"reflect"
"strings"
"gorm.io/gorm/clause"
"gorm.io/gorm/logger"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils"
)
// Create inserts value, returning the inserted data's primary key in value's id
// Create insert the value into database
func (db *DB) Create(value interface{}) (tx *DB) {
if db.CreateBatchSize > 0 {
return db.CreateInBatches(value, db.CreateBatchSize)
}
tx = db.getInstance()
tx.Statement.Dest = value
return tx.callbacks.Create().Execute(tx)
}
// CreateInBatches inserts value in batches of batchSize
func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) {
reflectValue := reflect.Indirect(reflect.ValueOf(value))
switch reflectValue.Kind() {
case reflect.Slice, reflect.Array:
var rowsAffected int64
tx = db.getInstance()
// the reflection length judgment of the optimized value
reflectLen := reflectValue.Len()
callFc := func(tx *DB) error {
for i := 0; i < reflectLen; i += batchSize {
ends := i + batchSize
if ends > reflectLen {
ends = reflectLen
}
subtx := tx.getInstance()
subtx.Statement.Dest = reflectValue.Slice(i, ends).Interface()
subtx.callbacks.Create().Execute(subtx)
if subtx.Error != nil {
return subtx.Error
}
rowsAffected += subtx.RowsAffected
}
return nil
}
if tx.SkipDefaultTransaction || reflectLen <= batchSize {
tx.AddError(callFc(tx.Session(&Session{})))
} else {
tx.AddError(tx.Transaction(callFc))
}
tx.RowsAffected = rowsAffected
default:
tx = db.getInstance()
tx.Statement.Dest = value
tx = tx.callbacks.Create().Execute(tx)
}
tx.callbacks.Create().Execute(tx)
return
}
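
A compile-only sketch of CreateInBatches; the package name, the already-opened db handle, and the migrated users table are assumptions.

```go
package examples

import (
	"fmt"

	"gorm.io/gorm"
)

type User struct {
	ID   uint
	Name string
}

// seedUsers inserts 1,000 rows in chunks of 100.
func seedUsers(db *gorm.DB) error {
	users := make([]User, 0, 1000)
	for i := 0; i < 1000; i++ {
		users = append(users, User{Name: fmt.Sprintf("user-%d", i)})
	}
	// On the master side all chunks run inside one wrapping transaction unless
	// SkipDefaultTransaction is set or the slice fits in a single batch.
	return db.CreateInBatches(&users, 100).Error
}
```
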
// Save updates value in database. If value doesn't contain a matching primary key, value is inserted.
// Save update value in database, if the value doesn't have primary key, will insert it
func (db *DB) Save(value interface{}) (tx *DB) {
tx = db.getInstance()
tx.Statement.Dest = value
reflectValue := reflect.Indirect(reflect.ValueOf(value))
for reflectValue.Kind() == reflect.Ptr || reflectValue.Kind() == reflect.Interface {
reflectValue = reflect.Indirect(reflectValue)
}
switch reflectValue.Kind() {
case reflect.Slice, reflect.Array:
if _, ok := tx.Statement.Clauses["ON CONFLICT"]; !ok {
tx = tx.Clauses(clause.OnConflict{UpdateAll: true})
}
tx = tx.callbacks.Create().Execute(tx.Set("gorm:update_track_time", true))
case reflect.Struct:
if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil {
for _, pf := range tx.Statement.Schema.PrimaryFields {
if _, isZero := pf.ValueOf(tx.Statement.Context, reflectValue); isZero {
return tx.callbacks.Create().Execute(tx)
if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil {
where := clause.Where{Exprs: make([]clause.Expression, len(tx.Statement.Schema.PrimaryFields))}
reflectValue := reflect.Indirect(reflect.ValueOf(value))
switch reflectValue.Kind() {
case reflect.Slice, reflect.Array:
tx.AddError(ErrPtrStructSupported)
case reflect.Struct:
for idx, pf := range tx.Statement.Schema.PrimaryFields {
if pv, isZero := pf.ValueOf(reflectValue); isZero {
tx.callbacks.Create().Execute(tx)
where.Exprs[idx] = clause.Eq{Column: pf.DBName, Value: pv}
return
}
}
}
fallthrough
default:
selectedUpdate := len(tx.Statement.Selects) != 0
// when updating, use all fields including those zero-value fields
if !selectedUpdate {
tx.Statement.Selects = append(tx.Statement.Selects, "*")
}
updateTx := tx.callbacks.Update().Execute(tx.Session(&Session{Initialized: true}))
if updateTx.Error == nil && updateTx.RowsAffected == 0 && !updateTx.DryRun && !selectedUpdate {
return tx.Session(&Session{SkipHooks: true}).Clauses(clause.OnConflict{UpdateAll: true}).Create(value)
}
return updateTx
tx.Statement.AddClause(where)
}
if len(tx.Statement.Selects) == 0 {
tx.Statement.Selects = append(tx.Statement.Selects, "*")
}
tx.callbacks.Update().Execute(tx)
return
}
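
A compile-only sketch of the Save semantics described above; db is assumed to be an already-opened *gorm.DB.

```go
package examples

import "gorm.io/gorm"

type User struct {
	ID   uint
	Name string
}

// upsertUser relies on Save's branching: a zero primary key falls through to
// Create; a non-zero key updates all fields and, on the master side, falls back
// to an ON CONFLICT upsert when the UPDATE matches no rows.
func upsertUser(db *gorm.DB, u *User) error {
	return db.Save(u).Error
}
```
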
// First finds the first record ordered by primary key, matching given conditions conds
// First find first record that match given conditions, order by primary key
func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) {
tx = db.Limit(1).Order(clause.OrderByColumn{
tx = db.getInstance().Limit(1).Order(clause.OrderByColumn{
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
})
if len(conds) > 0 {
if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
tx.Statement.AddClause(clause.Where{Exprs: exprs})
}
tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondtion(conds[0], conds[1:]...)})
}
tx.Statement.RaiseErrorOnNotFound = true
tx.Statement.Dest = dest
return tx.callbacks.Query().Execute(tx)
tx.callbacks.Query().Execute(tx)
return
}
// Take finds the first record returned by the database in no specified order, matching given conditions conds
// Take return a record that match given conditions, the order will depend on the database implementation
func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) {
tx = db.Limit(1)
tx = db.getInstance().Limit(1)
if len(conds) > 0 {
if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
tx.Statement.AddClause(clause.Where{Exprs: exprs})
}
tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondtion(conds[0], conds[1:]...)})
}
tx.Statement.RaiseErrorOnNotFound = true
tx.Statement.Dest = dest
return tx.callbacks.Query().Execute(tx)
tx.callbacks.Query().Execute(tx)
return
}
// Last finds the last record ordered by primary key, matching given conditions conds
// Last find last record that match given conditions, order by primary key
func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) {
tx = db.Limit(1).Order(clause.OrderByColumn{
tx = db.getInstance().Limit(1).Order(clause.OrderByColumn{
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
Desc: true,
})
if len(conds) > 0 {
if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
tx.Statement.AddClause(clause.Where{Exprs: exprs})
}
tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondtion(conds[0], conds[1:]...)})
}
tx.Statement.RaiseErrorOnNotFound = true
tx.Statement.Dest = dest
return tx.callbacks.Query().Execute(tx)
tx.callbacks.Query().Execute(tx)
return
}
// Find finds all records matching given conditions conds
// Find find records that match given conditions
func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) {
tx = db.getInstance()
if len(conds) > 0 {
if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
tx.Statement.AddClause(clause.Where{Exprs: exprs})
}
tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondtion(conds[0], conds[1:]...)})
}
tx.Statement.Dest = dest
return tx.callbacks.Query().Execute(tx)
tx.callbacks.Query().Execute(tx)
return
}
// FindInBatches finds all records in batches of batchSize
func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, batch int) error) *DB {
var (
tx = db.Order(clause.OrderByColumn{
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
}).Session(&Session{})
queryDB = tx
rowsAffected int64
batch int
)
// user specified offset or limit
var totalSize int
if c, ok := tx.Statement.Clauses["LIMIT"]; ok {
if limit, ok := c.Expression.(clause.Limit); ok {
if limit.Limit != nil {
totalSize = *limit.Limit
}
if totalSize > 0 && batchSize > totalSize {
batchSize = totalSize
}
// reset to offset to 0 in next batch
tx = tx.Offset(-1).Session(&Session{})
}
}
for {
result := queryDB.Limit(batchSize).Find(dest)
rowsAffected += result.RowsAffected
batch++
if result.Error == nil && result.RowsAffected != 0 {
fcTx := result.Session(&Session{NewDB: true})
fcTx.RowsAffected = result.RowsAffected
tx.AddError(fc(fcTx, batch))
} else if result.Error != nil {
tx.AddError(result.Error)
}
if tx.Error != nil || int(result.RowsAffected) < batchSize {
break
}
if totalSize > 0 {
if totalSize <= int(rowsAffected) {
break
}
if totalSize/batchSize == batch {
batchSize = totalSize % batchSize
}
}
// Optimize for-break
resultsValue := reflect.Indirect(reflect.ValueOf(dest))
if result.Statement.Schema.PrioritizedPrimaryField == nil {
tx.AddError(ErrPrimaryKeyRequired)
break
}
primaryValue, zero := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1))
if zero {
tx.AddError(ErrPrimaryKeyRequired)
break
}
queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue})
}
tx.RowsAffected = rowsAffected
return tx
}
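
A compile-only sketch of the FindInBatches callback shape; the batch size, field names, and the already-opened db handle are assumptions.

```go
package examples

import "gorm.io/gorm"

type User struct {
	ID  uint
	Age int
}

// bumpAges walks matching rows 200 at a time, keyed by primary key as in the
// ORDER BY / Gt clause logic above.
func bumpAges(db *gorm.DB) error {
	var batch []User
	result := db.Where("age > ?", 18).FindInBatches(&batch, 200, func(tx *gorm.DB, batchNo int) error {
		for i := range batch {
			batch[i].Age++
		}
		// tx is the per-batch session; returning an error stops the iteration
		// and is collected into the outer result.
		return tx.Save(&batch).Error
	})
	return result.Error
}
```
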
func (db *DB) assignInterfacesToValue(values ...interface{}) {
for _, value := range values {
switch v := value.(type) {
case []clause.Expression:
for _, expr := range v {
if eq, ok := expr.(clause.Eq); ok {
switch column := eq.Column.(type) {
case string:
if field := db.Statement.Schema.LookUpField(column); field != nil {
db.AddError(field.Set(db.Statement.Context, db.Statement.ReflectValue, eq.Value))
}
case clause.Column:
if field := db.Statement.Schema.LookUpField(column.Name); field != nil {
db.AddError(field.Set(db.Statement.Context, db.Statement.ReflectValue, eq.Value))
}
}
} else if andCond, ok := expr.(clause.AndConditions); ok {
db.assignInterfacesToValue(andCond.Exprs)
func (tx *DB) assignExprsToValue(exprs []clause.Expression) {
for _, expr := range exprs {
if eq, ok := expr.(clause.Eq); ok {
switch column := eq.Column.(type) {
case string:
if field := tx.Statement.Schema.LookUpField(column); field != nil {
field.Set(tx.Statement.ReflectValue, eq.Value)
}
}
case clause.Expression, map[string]string, map[interface{}]interface{}, map[string]interface{}:
if exprs := db.Statement.BuildCondition(value); len(exprs) > 0 {
db.assignInterfacesToValue(exprs)
}
default:
if s, err := schema.Parse(value, db.cacheStore, db.NamingStrategy); err == nil {
reflectValue := reflect.Indirect(reflect.ValueOf(value))
switch reflectValue.Kind() {
case reflect.Struct:
for _, f := range s.Fields {
if f.Readable {
if v, isZero := f.ValueOf(db.Statement.Context, reflectValue); !isZero {
if field := db.Statement.Schema.LookUpField(f.Name); field != nil {
db.AddError(field.Set(db.Statement.Context, db.Statement.ReflectValue, v))
}
}
}
}
case clause.Column:
if field := tx.Statement.Schema.LookUpField(column.Name); field != nil {
field.Set(tx.Statement.ReflectValue, eq.Value)
}
} else if len(values) > 0 {
if exprs := db.Statement.BuildCondition(values[0], values[1:]...); len(exprs) > 0 {
db.assignInterfacesToValue(exprs)
}
return
default:
}
}
}
}
// FirstOrInit finds the first matching record, otherwise if not found initializes a new instance with given conds.
// Each conds must be a struct or map.
//
// FirstOrInit never modifies the database. It is often used with Assign and Attrs.
//
// // assign an email if the record is not found
// db.Where(User{Name: "non_existing"}).Attrs(User{Email: "fake@fake.org"}).FirstOrInit(&user)
// // user -> User{Name: "non_existing", Email: "fake@fake.org"}
//
// // assign email regardless of if record is found
// db.Where(User{Name: "jinzhu"}).Assign(User{Email: "fake@fake.org"}).FirstOrInit(&user)
// // user -> User{Name: "jinzhu", Age: 20, Email: "fake@fake.org"}
func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) {
queryTx := db.Limit(1).Order(clause.OrderByColumn{
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
})
if tx = queryTx.Find(dest, conds...); tx.RowsAffected == 0 {
tx = db.getInstance()
if tx = tx.First(dest, conds...); errors.Is(tx.Error, ErrRecordNotFound) {
if c, ok := tx.Statement.Clauses["WHERE"]; ok {
if where, ok := c.Expression.(clause.Where); ok {
tx.assignInterfacesToValue(where.Exprs)
tx.assignExprsToValue(where.Exprs)
}
}
// initialize with attrs, conds
if len(tx.Statement.attrs) > 0 {
tx.assignInterfacesToValue(tx.Statement.attrs...)
exprs := tx.Statement.BuildCondtion(tx.Statement.attrs[0], tx.Statement.attrs[1:]...)
tx.assignExprsToValue(exprs)
}
tx.Error = nil
}
// initialize with attrs, conds
if len(tx.Statement.assigns) > 0 {
tx.assignInterfacesToValue(tx.Statement.assigns...)
exprs := tx.Statement.BuildCondtion(tx.Statement.assigns[0], tx.Statement.assigns[1:]...)
tx.assignExprsToValue(exprs)
}
return
}
// FirstOrCreate finds the first matching record, otherwise if not found creates a new instance with given conds.
// Each conds must be a struct or map.
//
// Using FirstOrCreate in conjunction with Assign will result in an update to the database even if the record exists.
//
// // assign an email if the record is not found
// result := db.Where(User{Name: "non_existing"}).Attrs(User{Email: "fake@fake.org"}).FirstOrCreate(&user)
// // user -> User{Name: "non_existing", Email: "fake@fake.org"}
// // result.RowsAffected -> 1
//
// // assign email regardless of if record is found
// result := db.Where(User{Name: "jinzhu"}).Assign(User{Email: "fake@fake.org"}).FirstOrCreate(&user)
// // user -> User{Name: "jinzhu", Age: 20, Email: "fake@fake.org"}
// // result.RowsAffected -> 1
func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) {
tx = db.getInstance()
queryTx := db.Session(&Session{}).Limit(1).Order(clause.OrderByColumn{
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
})
if err := tx.First(dest, conds...).Error; errors.Is(err, ErrRecordNotFound) {
tx.Error = nil
result := queryTx.Find(dest, conds...)
if result.Error != nil {
tx.Error = result.Error
return tx
}
if result.RowsAffected == 0 {
if c, ok := result.Statement.Clauses["WHERE"]; ok {
if c, ok := tx.Statement.Clauses["WHERE"]; ok {
if where, ok := c.Expression.(clause.Where); ok {
result.assignInterfacesToValue(where.Exprs)
tx.assignExprsToValue(where.Exprs)
}
}
// initialize with attrs, conds
if len(db.Statement.attrs) > 0 {
result.assignInterfacesToValue(db.Statement.attrs...)
if len(tx.Statement.attrs) > 0 {
exprs := tx.Statement.BuildCondtion(tx.Statement.attrs[0], tx.Statement.attrs[1:]...)
tx.assignExprsToValue(exprs)
}
// initialize with attrs, conds
if len(db.Statement.assigns) > 0 {
result.assignInterfacesToValue(db.Statement.assigns...)
if len(tx.Statement.assigns) > 0 {
exprs := tx.Statement.BuildCondtion(tx.Statement.assigns[0], tx.Statement.assigns[1:]...)
tx.assignExprsToValue(exprs)
}
return tx.Create(dest)
} else if len(db.Statement.assigns) > 0 {
exprs := tx.Statement.BuildCondition(db.Statement.assigns[0], db.Statement.assigns[1:]...)
} else if len(tx.Statement.assigns) > 0 {
exprs := tx.Statement.BuildCondtion(tx.Statement.assigns[0], tx.Statement.assigns[1:])
assigns := map[string]interface{}{}
for i := 0; i < len(exprs); i++ {
expr := exprs[i]
if eq, ok := expr.(clause.AndConditions); ok {
exprs = append(exprs, eq.Exprs...)
} else if eq, ok := expr.(clause.Eq); ok {
for _, expr := range exprs {
if eq, ok := expr.(clause.Eq); ok {
switch column := eq.Column.(type) {
case string:
assigns[column] = eq.Value
case clause.Column:
assigns[column.Name] = eq.Value
default:
}
}
}
@ -396,385 +185,178 @@ func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) {
return tx.Model(dest).Updates(assigns)
}
return tx
return
}
// Update updates column with value using callbacks. Reference: https://gorm.io/docs/update.html#Update-Changed-Fields
// Update update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update
func (db *DB) Update(column string, value interface{}) (tx *DB) {
tx = db.getInstance()
tx.Statement.Dest = map[string]interface{}{column: value}
return tx.callbacks.Update().Execute(tx)
tx.callbacks.Update().Execute(tx)
return
}
// Updates updates attributes using callbacks. values must be a struct or map. Reference: https://gorm.io/docs/update.html#Update-Changed-Fields
// Updates update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update
func (db *DB) Updates(values interface{}) (tx *DB) {
tx = db.getInstance()
tx.Statement.Dest = values
return tx.callbacks.Update().Execute(tx)
tx.callbacks.Update().Execute(tx)
return
}
func (db *DB) UpdateColumn(column string, value interface{}) (tx *DB) {
tx = db.getInstance()
tx.Statement.Dest = map[string]interface{}{column: value}
tx.Statement.SkipHooks = true
return tx.callbacks.Update().Execute(tx)
tx.Statement.DisableUpdateTime = true
tx.callbacks.Update().Execute(tx)
return
}
func (db *DB) UpdateColumns(values interface{}) (tx *DB) {
tx = db.getInstance()
tx.Statement.Dest = values
tx.Statement.SkipHooks = true
return tx.callbacks.Update().Execute(tx)
tx.Statement.DisableUpdateTime = true
tx.callbacks.Update().Execute(tx)
return
}
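
A compile-only sketch contrasting Updates with UpdateColumn; the column names and db handle are assumptions, and gorm.Expr builds a raw SQL fragment for the assignment.

```go
package examples

import "gorm.io/gorm"

type User struct {
	ID   uint
	Name string
	Age  int
}

func touchUser(db *gorm.DB, id uint) error {
	// Struct-based Updates skips zero-value fields and runs hooks.
	if err := db.Model(&User{}).Where("id = ?", id).Updates(User{Name: "jinzhu-2"}).Error; err != nil {
		return err
	}
	// UpdateColumn writes the column directly, skipping hooks (master: SkipHooks;
	// v0.2.0: DisableUpdateTime).
	return db.Model(&User{}).Where("id = ?", id).
		UpdateColumn("age", gorm.Expr("age + ?", 1)).Error
}
```
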
// Delete deletes value matching given conditions. If value contains primary key it is included in the conditions. If
// value includes a deleted_at field, then Delete performs a soft delete instead by setting deleted_at with the current
// time if null.
// Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition
func (db *DB) Delete(value interface{}, conds ...interface{}) (tx *DB) {
tx = db.getInstance()
if len(conds) > 0 {
if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
tx.Statement.AddClause(clause.Where{Exprs: exprs})
}
tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondtion(conds[0], conds[1:]...)})
}
tx.Statement.Dest = value
return tx.callbacks.Delete().Execute(tx)
tx.callbacks.Delete().Execute(tx)
return
}
func (db *DB) Count(count *int64) (tx *DB) {
tx = db.getInstance()
if s, ok := tx.Statement.Clauses["SELECT"].Expression.(clause.Select); !ok || len(s.Columns) == 0 {
tx.Statement.AddClause(clause.Select{Expression: clause.Expr{SQL: "count(1)"}})
}
if tx.Statement.Model == nil {
tx.Statement.Model = tx.Statement.Dest
defer func() {
tx.Statement.Model = nil
}()
}
if selectClause, ok := db.Statement.Clauses["SELECT"]; ok {
defer func() {
tx.Statement.Clauses["SELECT"] = selectClause
}()
} else {
defer delete(tx.Statement.Clauses, "SELECT")
}
if len(tx.Statement.Selects) == 0 {
tx.Statement.AddClause(clause.Select{Expression: clause.Expr{SQL: "count(*)"}})
} else if !strings.HasPrefix(strings.TrimSpace(strings.ToLower(tx.Statement.Selects[0])), "count(") {
expr := clause.Expr{SQL: "count(*)"}
if len(tx.Statement.Selects) == 1 {
dbName := tx.Statement.Selects[0]
fields := strings.FieldsFunc(dbName, utils.IsValidDBNameChar)
if len(fields) == 1 || (len(fields) == 3 && (strings.ToUpper(fields[1]) == "AS" || fields[1] == ".")) {
if tx.Statement.Parse(tx.Statement.Model) == nil {
if f := tx.Statement.Schema.LookUpField(dbName); f != nil {
dbName = f.DBName
}
}
if tx.Statement.Distinct {
expr = clause.Expr{SQL: "COUNT(DISTINCT(?))", Vars: []interface{}{clause.Column{Name: dbName}}}
} else if dbName != "*" {
expr = clause.Expr{SQL: "COUNT(?)", Vars: []interface{}{clause.Column{Name: dbName}}}
}
}
}
tx.Statement.AddClause(clause.Select{Expression: expr})
}
if orderByClause, ok := db.Statement.Clauses["ORDER BY"]; ok {
if _, ok := db.Statement.Clauses["GROUP BY"]; !ok {
delete(tx.Statement.Clauses, "ORDER BY")
defer func() {
tx.Statement.Clauses["ORDER BY"] = orderByClause
}()
}
}
tx.Statement.Dest = count
tx = tx.callbacks.Query().Execute(tx)
if _, ok := db.Statement.Clauses["GROUP BY"]; ok || tx.RowsAffected != 1 {
*count = tx.RowsAffected
tx.callbacks.Query().Execute(tx)
if db.RowsAffected != 1 {
*count = db.RowsAffected
}
return
}
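
A compile-only sketch of Count with a single selected column, which the master-side code above rewrites into COUNT(...) / COUNT(DISTINCT(...)); db is assumed to be already opened.

```go
package examples

import "gorm.io/gorm"

type User struct {
	ID   uint
	Name string
	Age  int
}

// countDistinctNames counts distinct names among adults; the SELECT and ORDER BY
// clauses are swapped out and restored around the query as in the code above.
func countDistinctNames(db *gorm.DB) (int64, error) {
	var n int64
	err := db.Model(&User{}).Where("age >= ?", 18).Distinct("name").Count(&n).Error
	return n, err
}
```
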
func (db *DB) Row() *sql.Row {
tx := db.getInstance().Set("rows", false)
tx = tx.callbacks.Row().Execute(tx)
row, ok := tx.Statement.Dest.(*sql.Row)
if !ok && tx.DryRun {
db.Logger.Error(tx.Statement.Context, ErrDryRunModeUnsupported.Error())
}
return row
tx := db.getInstance()
tx.callbacks.Row().Execute(tx)
return tx.Statement.Dest.(*sql.Row)
}
func (db *DB) Rows() (*sql.Rows, error) {
tx := db.getInstance().Set("rows", true)
tx = tx.callbacks.Row().Execute(tx)
rows, ok := tx.Statement.Dest.(*sql.Rows)
if !ok && tx.DryRun && tx.Error == nil {
tx.Error = ErrDryRunModeUnsupported
}
return rows, tx.Error
tx := db.Set("rows", true)
tx.callbacks.Row().Execute(tx)
return tx.Statement.Dest.(*sql.Rows), tx.Error
}
// Scan scans selected value to the struct dest
// Scan scan value to a struct
func (db *DB) Scan(dest interface{}) (tx *DB) {
config := *db.Config
currentLogger, newLogger := config.Logger, logger.Recorder.New()
config.Logger = newLogger
tx = db.getInstance()
tx.Config = &config
if rows, err := tx.Rows(); err == nil {
if rows.Next() {
tx.ScanRows(rows, dest)
} else {
tx.RowsAffected = 0
tx.AddError(rows.Err())
}
tx.AddError(rows.Close())
}
currentLogger.Trace(tx.Statement.Context, newLogger.BeginAt, func() (string, int64) {
return newLogger.SQL, tx.RowsAffected
}, tx.Error)
tx.Logger = currentLogger
tx.Statement.Dest = dest
tx.callbacks.Query().Execute(tx)
return
}
// Pluck queries a single column from a model, returning in the slice dest. E.g.:
//
// var ages []int64
// db.Model(&users).Pluck("age", &ages)
// Pluck used to query single column from a model as a map
// var ages []int64
// db.Find(&users).Pluck("age", &ages)
func (db *DB) Pluck(column string, dest interface{}) (tx *DB) {
tx = db.getInstance()
if tx.Statement.Model != nil {
if tx.Statement.Parse(tx.Statement.Model) == nil {
if f := tx.Statement.Schema.LookUpField(column); f != nil {
column = f.DBName
}
}
}
if len(tx.Statement.Selects) != 1 {
fields := strings.FieldsFunc(column, utils.IsValidDBNameChar)
tx.Statement.AddClauseIfNotExists(clause.Select{
Distinct: tx.Statement.Distinct,
Columns: []clause.Column{{Name: column, Raw: len(fields) != 1}},
})
}
tx.Statement.AddClauseIfNotExists(clause.Select{Columns: []clause.Column{{Name: column}}})
tx.Statement.Dest = dest
return tx.callbacks.Query().Execute(tx)
tx.callbacks.Query().Execute(tx)
return
}
func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error {
tx := db.getInstance()
if err := tx.Statement.Parse(dest); !errors.Is(err, schema.ErrUnsupportedDataType) {
tx.AddError(err)
}
tx.Error = tx.Statement.Parse(dest)
tx.Statement.Dest = dest
tx.Statement.ReflectValue = reflect.ValueOf(dest)
for tx.Statement.ReflectValue.Kind() == reflect.Ptr {
elem := tx.Statement.ReflectValue.Elem()
if !elem.IsValid() {
elem = reflect.New(tx.Statement.ReflectValue.Type().Elem())
tx.Statement.ReflectValue.Set(elem)
}
tx.Statement.ReflectValue = elem
}
Scan(rows, tx, ScanInitialized)
tx.Statement.ReflectValue = reflect.Indirect(reflect.ValueOf(dest))
Scan(rows, tx, true)
return tx.Error
}
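
A compile-only sketch of streaming rows with Rows and mapping each one back through ScanRows; db is assumed to be already opened and migrated.

```go
package examples

import "gorm.io/gorm"

type User struct {
	ID   uint
	Name string
	Age  int
}

func streamUsers(db *gorm.DB, handle func(User) error) error {
	rows, err := db.Model(&User{}).Where("age > ?", 18).Rows()
	if err != nil {
		return err
	}
	defer rows.Close()

	for rows.Next() {
		var u User
		// ScanRows maps the current row onto the struct via the parsed schema.
		if err := db.ScanRows(rows, &u); err != nil {
			return err
		}
		if err := handle(u); err != nil {
			return err
		}
	}
	return rows.Err()
}
```
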
// Connection uses a db connection to execute an arbitrary number of commands in fc. When finished, the connection is
// returned to the connection pool.
func (db *DB) Connection(fc func(tx *DB) error) (err error) {
if db.Error != nil {
return db.Error
}
tx := db.getInstance()
sqlDB, err := tx.DB()
if err != nil {
return
}
conn, err := sqlDB.Conn(tx.Statement.Context)
if err != nil {
return
}
defer conn.Close()
tx.Statement.ConnPool = conn
return fc(tx)
}
// Transaction start a transaction as a block, return error will rollback, otherwise to commit. Transaction executes an
// arbitrary number of commands in fc within a transaction. On success the changes are committed; if an error occurs
// they are rolled back.
// Transaction start a transaction as a block, return error will rollback, otherwise to commit.
func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err error) {
panicked := true
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
// nested transaction
if !db.DisableNestedTransaction {
spID := new(maphash.Hash).Sum64()
err = db.SavePoint(fmt.Sprintf("sp%d", spID)).Error
if err != nil {
return
}
defer func() {
// Make sure to rollback when panic, Block error or Commit error
if panicked || err != nil {
db.RollbackTo(fmt.Sprintf("sp%d", spID))
}
}()
}
err = fc(db.Session(&Session{NewDB: db.clone == 1}))
} else {
tx := db.Begin(opts...)
if tx.Error != nil {
return tx.Error
tx := db.Begin(opts...)
defer func() {
// Make sure to rollback when panic, Block error or Commit error
if panicked || err != nil {
tx.Rollback()
}
}()
defer func() {
// Make sure to rollback when panic, Block error or Commit error
if panicked || err != nil {
tx.Rollback()
}
}()
err = fc(tx.Session(&Session{}))
if err = fc(tx); err == nil {
panicked = false
return tx.Commit().Error
}
if err == nil {
err = tx.Commit().Error
}
panicked = false
return
}
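
A minimal runnable sketch of the closure-style Transaction: returning an error (or panicking) rolls the block back, returning nil commits, and on the master side nested calls become savepoints. The sqlite driver module is an assumption.

```go
package main

import (
	"log"

	"gorm.io/driver/sqlite"
	"gorm.io/gorm"
)

type User struct {
	ID   uint
	Name string
}

func main() {
	db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
	if err != nil {
		log.Fatal(err)
	}
	if err := db.AutoMigrate(&User{}); err != nil {
		log.Fatal(err)
	}

	err = db.Transaction(func(tx *gorm.DB) error {
		if err := tx.Create(&User{Name: "jinzhu"}).Error; err != nil {
			return err // rolls back both inserts
		}
		return tx.Create(&User{Name: "linus"}).Error
	})
	if err != nil {
		log.Fatal(err)
	}
}
```
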
// Begin begins a transaction with any transaction options opts
func (db *DB) Begin(opts ...*sql.TxOptions) *DB {
var (
// clone statement
tx = db.getInstance().Session(&Session{Context: db.Statement.Context, NewDB: db.clone == 1})
opt *sql.TxOptions
err error
)
if len(opts) > 0 {
opt = opts[0]
}
ctx := tx.Statement.Context
if _, ok := ctx.Deadline(); !ok {
if db.Config.DefaultTransactionTimeout > 0 {
ctx, _ = context.WithTimeout(ctx, db.Config.DefaultTransactionTimeout)
// Begin begins a transaction
func (db *DB) Begin(opts ...*sql.TxOptions) (tx *DB) {
tx = db.getInstance()
if beginner, ok := tx.Statement.ConnPool.(TxBeginner); ok {
var opt *sql.TxOptions
var err error
if len(opts) > 0 {
opt = opts[0]
}
}
switch beginner := tx.Statement.ConnPool.(type) {
case TxBeginner:
tx.Statement.ConnPool, err = beginner.BeginTx(ctx, opt)
case ConnPoolBeginner:
tx.Statement.ConnPool, err = beginner.BeginTx(ctx, opt)
default:
err = ErrInvalidTransaction
if tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt); err != nil {
tx.AddError(err)
}
} else {
tx.AddError(ErrInvalidTransaction)
}
if err != nil {
tx.AddError(err)
}
return tx
return
}
// Commit commits the changes in a transaction
// Commit commit a transaction
func (db *DB) Commit() *DB {
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil && !reflect.ValueOf(committer).IsNil() {
db.AddError(committer.Commit())
if comminter, ok := db.Statement.ConnPool.(TxCommiter); ok && comminter != nil {
db.AddError(comminter.Commit())
} else {
db.AddError(ErrInvalidTransaction)
}
return db
}
// Rollback rollbacks the changes in a transaction
// Rollback rollback a transaction
func (db *DB) Rollback() *DB {
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
if !reflect.ValueOf(committer).IsNil() {
db.AddError(committer.Rollback())
}
if comminter, ok := db.Statement.ConnPool.(TxCommiter); ok && comminter != nil {
db.AddError(comminter.Rollback())
} else {
db.AddError(ErrInvalidTransaction)
}
return db
}
func (db *DB) SavePoint(name string) *DB {
if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok {
// close prepared statement, because SavePoint not support prepared statement.
// e.g. mysql8.0 doc: https://dev.mysql.com/doc/refman/8.0/en/sql-prepared-statements.html
var (
preparedStmtTx *PreparedStmtTX
isPreparedStmtTx bool
)
// close prepared statement, because SavePoint not support prepared statement.
if preparedStmtTx, isPreparedStmtTx = db.Statement.ConnPool.(*PreparedStmtTX); isPreparedStmtTx {
db.Statement.ConnPool = preparedStmtTx.Tx
}
db.AddError(savePointer.SavePoint(db, name))
// restore prepared statement
if isPreparedStmtTx {
db.Statement.ConnPool = preparedStmtTx
}
} else {
db.AddError(ErrUnsupportedDriver)
}
return db
}
func (db *DB) RollbackTo(name string) *DB {
if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok {
// close prepared statement, because RollbackTo not support prepared statement.
// e.g. mysql8.0 doc: https://dev.mysql.com/doc/refman/8.0/en/sql-prepared-statements.html
var (
preparedStmtTx *PreparedStmtTX
isPreparedStmtTx bool
)
// close prepared statement, because SavePoint not support prepared statement.
if preparedStmtTx, isPreparedStmtTx = db.Statement.ConnPool.(*PreparedStmtTX); isPreparedStmtTx {
db.Statement.ConnPool = preparedStmtTx.Tx
}
db.AddError(savePointer.RollbackTo(db, name))
// restore prepared statement
if isPreparedStmtTx {
db.Statement.ConnPool = preparedStmtTx
}
} else {
db.AddError(ErrUnsupportedDriver)
}
return db
}
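
A compile-only sketch of a manual transaction with a savepoint; it assumes a dialector that implements SavePointerDialectorInterface (the official drivers do) and an already-opened db.

```go
package examples

import "gorm.io/gorm"

type User struct {
	ID   uint
	Name string
}

// partialRollback commits the first insert and discards the second one by
// rolling back to the savepoint before committing.
func partialRollback(db *gorm.DB) error {
	tx := db.Begin()
	if tx.Error != nil {
		return tx.Error
	}

	if err := tx.Create(&User{Name: "kept"}).Error; err != nil {
		tx.Rollback()
		return err
	}

	tx.SavePoint("sp1")
	_ = tx.Create(&User{Name: "discarded"}).Error
	tx.RollbackTo("sp1") // undo only the work done after the savepoint

	return tx.Commit().Error
}
```
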
// Exec executes raw sql
// Exec execute raw sql
func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) {
tx = db.getInstance()
tx.Statement.SQL = strings.Builder{}
if strings.Contains(sql, "@") {
clause.NamedExpr{SQL: sql, Vars: values}.Build(tx.Statement)
} else {
clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement)
}
return tx.callbacks.Raw().Execute(tx)
clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement)
tx.callbacks.Raw().Execute(tx)
return
}
func (db *DB) RecordNotFound() bool {
return errors.Is(db.Error, ErrRecordNotFound)
}
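
A compile-only sketch of Exec with named placeholders, which the master-side code above routes through clause.NamedExpr whenever the SQL contains "@"; the table and column names are assumptions.

```go
package examples

import (
	"database/sql"

	"gorm.io/gorm"
)

// renameEverywhere issues a raw UPDATE using sql.Named values for the
// @new/@old placeholders.
func renameEverywhere(db *gorm.DB) error {
	return db.Exec(
		"UPDATE users SET name = @new WHERE name = @old",
		sql.Named("new", "jinzhu2"), sql.Named("old", "jinzhu"),
	).Error
}
```
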


@ -1,605 +0,0 @@
package gorm
import (
"context"
"database/sql"
"fmt"
"sort"
"strings"
"gorm.io/gorm/clause"
"gorm.io/gorm/logger"
)
type result struct {
Result sql.Result
RowsAffected int64
}
func (info *result) ModifyStatement(stmt *Statement) {
stmt.Result = info
}
// Build implements clause.Expression interface
func (result) Build(clause.Builder) {
}
func WithResult() *result {
return &result{}
}
type Interface[T any] interface {
Raw(sql string, values ...interface{}) ExecInterface[T]
Exec(ctx context.Context, sql string, values ...interface{}) error
CreateInterface[T]
}
type CreateInterface[T any] interface {
ChainInterface[T]
Table(name string, args ...interface{}) CreateInterface[T]
Create(ctx context.Context, r *T) error
CreateInBatches(ctx context.Context, r *[]T, batchSize int) error
}
type ChainInterface[T any] interface {
ExecInterface[T]
Scopes(scopes ...func(db *Statement)) ChainInterface[T]
Where(query interface{}, args ...interface{}) ChainInterface[T]
Not(query interface{}, args ...interface{}) ChainInterface[T]
Or(query interface{}, args ...interface{}) ChainInterface[T]
Limit(offset int) ChainInterface[T]
Offset(offset int) ChainInterface[T]
Joins(query clause.JoinTarget, on func(db JoinBuilder, joinTable clause.Table, curTable clause.Table) error) ChainInterface[T]
Preload(association string, query func(db PreloadBuilder) error) ChainInterface[T]
Select(query string, args ...interface{}) ChainInterface[T]
Omit(columns ...string) ChainInterface[T]
MapColumns(m map[string]string) ChainInterface[T]
Distinct(args ...interface{}) ChainInterface[T]
Group(name string) ChainInterface[T]
Having(query interface{}, args ...interface{}) ChainInterface[T]
Order(value interface{}) ChainInterface[T]
Build(builder clause.Builder)
Delete(ctx context.Context) (rowsAffected int, err error)
Update(ctx context.Context, name string, value any) (rowsAffected int, err error)
Updates(ctx context.Context, t T) (rowsAffected int, err error)
Count(ctx context.Context, column string) (result int64, err error)
}
type ExecInterface[T any] interface {
Scan(ctx context.Context, r interface{}) error
First(context.Context) (T, error)
Last(ctx context.Context) (T, error)
Take(context.Context) (T, error)
Find(ctx context.Context) ([]T, error)
FindInBatches(ctx context.Context, batchSize int, fc func(data []T, batch int) error) error
Row(ctx context.Context) *sql.Row
Rows(ctx context.Context) (*sql.Rows, error)
}
type JoinBuilder interface {
Select(...string) JoinBuilder
Omit(...string) JoinBuilder
Where(query interface{}, args ...interface{}) JoinBuilder
Not(query interface{}, args ...interface{}) JoinBuilder
Or(query interface{}, args ...interface{}) JoinBuilder
}
type PreloadBuilder interface {
Select(...string) PreloadBuilder
Omit(...string) PreloadBuilder
Where(query interface{}, args ...interface{}) PreloadBuilder
Not(query interface{}, args ...interface{}) PreloadBuilder
Or(query interface{}, args ...interface{}) PreloadBuilder
Limit(offset int) PreloadBuilder
Offset(offset int) PreloadBuilder
Order(value interface{}) PreloadBuilder
LimitPerRecord(num int) PreloadBuilder
}
type op func(*DB) *DB
func G[T any](db *DB, opts ...clause.Expression) Interface[T] {
v := &g[T]{
db: db,
ops: make([]op, 0, 5),
}
if len(opts) > 0 {
v.ops = append(v.ops, func(db *DB) *DB {
return db.Clauses(opts...)
})
}
v.createG = &createG[T]{
chainG: chainG[T]{
execG: execG[T]{g: v},
},
}
return v
}
type g[T any] struct {
*createG[T]
db *DB
ops []op
}
func (g *g[T]) apply(ctx context.Context) *DB {
db := g.db
if !db.DryRun {
db = db.Session(&Session{NewDB: true, Context: ctx}).getInstance()
}
for _, op := range g.ops {
db = op(db)
}
return db
}
func (c *g[T]) Raw(sql string, values ...interface{}) ExecInterface[T] {
return execG[T]{g: &g[T]{
db: c.db,
ops: append(c.ops, func(db *DB) *DB {
return db.Raw(sql, values...)
}),
}}
}
func (c *g[T]) Exec(ctx context.Context, sql string, values ...interface{}) error {
return c.apply(ctx).Exec(sql, values...).Error
}
type createG[T any] struct {
chainG[T]
}
func (c createG[T]) Table(name string, args ...interface{}) CreateInterface[T] {
return createG[T]{c.with(func(db *DB) *DB {
return db.Table(name, args...)
})}
}
func (c createG[T]) Create(ctx context.Context, r *T) error {
return c.g.apply(ctx).Create(r).Error
}
func (c createG[T]) CreateInBatches(ctx context.Context, r *[]T, batchSize int) error {
return c.g.apply(ctx).CreateInBatches(r, batchSize).Error
}
type chainG[T any] struct {
execG[T]
}
func (c chainG[T]) getInstance() *DB {
var r T
return c.g.apply(context.Background()).Model(r).getInstance()
}
func (c chainG[T]) with(v op) chainG[T] {
return chainG[T]{
execG: execG[T]{g: &g[T]{
db: c.g.db,
ops: append(append([]op(nil), c.g.ops...), v),
}},
}
}
func (c chainG[T]) Scopes(scopes ...func(db *Statement)) ChainInterface[T] {
return c.with(func(db *DB) *DB {
for _, fc := range scopes {
fc(db.Statement)
}
return db
})
}
func (c chainG[T]) Table(name string, args ...interface{}) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Table(name, args...)
})
}
func (c chainG[T]) Where(query interface{}, args ...interface{}) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Where(query, args...)
})
}
func (c chainG[T]) Not(query interface{}, args ...interface{}) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Not(query, args...)
})
}
func (c chainG[T]) Or(query interface{}, args ...interface{}) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Or(query, args...)
})
}
func (c chainG[T]) Limit(offset int) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Limit(offset)
})
}
func (c chainG[T]) Offset(offset int) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Offset(offset)
})
}
type joinBuilder struct {
db *DB
}
func (q *joinBuilder) Where(query interface{}, args ...interface{}) JoinBuilder {
q.db.Where(query, args...)
return q
}
func (q *joinBuilder) Or(query interface{}, args ...interface{}) JoinBuilder {
q.db.Where(query, args...)
return q
}
func (q *joinBuilder) Not(query interface{}, args ...interface{}) JoinBuilder {
q.db.Where(query, args...)
return q
}
func (q *joinBuilder) Select(columns ...string) JoinBuilder {
q.db.Select(columns)
return q
}
func (q *joinBuilder) Omit(columns ...string) JoinBuilder {
q.db.Omit(columns...)
return q
}
type preloadBuilder struct {
limitPerRecord int
db *DB
}
func (q *preloadBuilder) Where(query interface{}, args ...interface{}) PreloadBuilder {
q.db.Where(query, args...)
return q
}
func (q *preloadBuilder) Or(query interface{}, args ...interface{}) PreloadBuilder {
q.db.Where(query, args...)
return q
}
func (q *preloadBuilder) Not(query interface{}, args ...interface{}) PreloadBuilder {
q.db.Where(query, args...)
return q
}
func (q *preloadBuilder) Select(columns ...string) PreloadBuilder {
q.db.Select(columns)
return q
}
func (q *preloadBuilder) Omit(columns ...string) PreloadBuilder {
q.db.Omit(columns...)
return q
}
func (q *preloadBuilder) Limit(limit int) PreloadBuilder {
q.db.Limit(limit)
return q
}
func (q *preloadBuilder) Offset(offset int) PreloadBuilder {
q.db.Offset(offset)
return q
}
func (q *preloadBuilder) Order(value interface{}) PreloadBuilder {
q.db.Order(value)
return q
}
func (q *preloadBuilder) LimitPerRecord(num int) PreloadBuilder {
q.limitPerRecord = num
return q
}
func (c chainG[T]) Joins(jt clause.JoinTarget, on func(db JoinBuilder, joinTable clause.Table, curTable clause.Table) error) ChainInterface[T] {
return c.with(func(db *DB) *DB {
if jt.Table == "" {
jt.Table = clause.JoinTable(strings.Split(jt.Association, ".")...).Name
}
q := joinBuilder{db: db.Session(&Session{NewDB: true, Initialized: true}).Table(jt.Table)}
if on != nil {
if err := on(&q, clause.Table{Name: jt.Table}, clause.Table{Name: clause.CurrentTable}); err != nil {
db.AddError(err)
}
}
j := join{
Name: jt.Association,
Alias: jt.Table,
Selects: q.db.Statement.Selects,
Omits: q.db.Statement.Omits,
JoinType: jt.Type,
}
if where, ok := q.db.Statement.Clauses["WHERE"].Expression.(clause.Where); ok {
j.On = &where
}
if jt.Subquery != nil {
joinType := j.JoinType
if joinType == "" {
joinType = clause.LeftJoin
}
if db, ok := jt.Subquery.(interface{ getInstance() *DB }); ok {
stmt := db.getInstance().Statement
if len(j.Selects) == 0 {
j.Selects = stmt.Selects
}
if len(j.Omits) == 0 {
j.Omits = stmt.Omits
}
}
expr := clause.NamedExpr{SQL: fmt.Sprintf("%s JOIN (?) AS ?", joinType), Vars: []interface{}{jt.Subquery, clause.Table{Name: j.Alias}}}
if j.On != nil {
expr.SQL += " ON ?"
expr.Vars = append(expr.Vars, clause.AndConditions{Exprs: j.On.Exprs})
}
j.Expression = expr
}
db.Statement.Joins = append(db.Statement.Joins, j)
sort.Slice(db.Statement.Joins, func(i, j int) bool {
return db.Statement.Joins[i].Name < db.Statement.Joins[j].Name
})
return db
})
}
func (c chainG[T]) Select(query string, args ...interface{}) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Select(query, args...)
})
}
func (c chainG[T]) Omit(columns ...string) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Omit(columns...)
})
}
func (c chainG[T]) MapColumns(m map[string]string) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.MapColumns(m)
})
}
func (c chainG[T]) Distinct(args ...interface{}) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Distinct(args...)
})
}
func (c chainG[T]) Group(name string) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Group(name)
})
}
func (c chainG[T]) Having(query interface{}, args ...interface{}) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Having(query, args...)
})
}
func (c chainG[T]) Order(value interface{}) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Order(value)
})
}
func (c chainG[T]) Preload(association string, query func(db PreloadBuilder) error) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Preload(association, func(tx *DB) *DB {
q := preloadBuilder{db: tx.getInstance()}
if query != nil {
if err := query(&q); err != nil {
db.AddError(err)
}
}
relation, ok := db.Statement.Schema.Relationships.Relations[association]
if !ok {
if preloadFields := strings.Split(association, "."); len(preloadFields) > 1 {
relationships := db.Statement.Schema.Relationships
for _, field := range preloadFields {
var ok bool
relation, ok = relationships.Relations[field]
if ok {
relationships = relation.FieldSchema.Relationships
} else {
db.AddError(fmt.Errorf("relation %s not found", association))
return nil
}
}
} else {
db.AddError(fmt.Errorf("relation %s not found", association))
return nil
}
}
if q.limitPerRecord > 0 {
if relation.JoinTable != nil {
tx.AddError(fmt.Errorf("many2many relation %s don't support LimitPerRecord", association))
return tx
}
refColumns := []clause.Column{}
for _, rel := range relation.References {
if rel.OwnPrimaryKey {
refColumns = append(refColumns, clause.Column{Name: rel.ForeignKey.DBName})
}
}
if len(refColumns) != 0 {
selectExpr := clause.CommaExpression{}
for _, column := range q.db.Statement.Selects {
selectExpr.Exprs = append(selectExpr.Exprs, clause.Expr{SQL: "?", Vars: []interface{}{clause.Column{Name: column}}})
}
if len(selectExpr.Exprs) == 0 {
selectExpr.Exprs = []clause.Expression{clause.Expr{SQL: "*", Vars: []interface{}{}}}
}
partitionBy := clause.CommaExpression{}
for _, column := range refColumns {
partitionBy.Exprs = append(partitionBy.Exprs, clause.Expr{SQL: "?", Vars: []interface{}{clause.Column{Name: column.Name}}})
}
rnnColumn := clause.Column{Name: "gorm_preload_rnn"}
sql := "ROW_NUMBER() OVER (PARTITION BY ? ?)"
vars := []interface{}{partitionBy}
if orderBy, ok := q.db.Statement.Clauses["ORDER BY"]; ok {
vars = append(vars, orderBy)
} else {
vars = append(vars, clause.Clause{Name: "ORDER BY", Expression: clause.OrderBy{
Columns: []clause.OrderByColumn{{Column: clause.PrimaryColumn, Desc: true}},
}})
}
vars = append(vars, rnnColumn)
selectExpr.Exprs = append(selectExpr.Exprs, clause.Expr{SQL: sql + " AS ?", Vars: vars})
q.db.Clauses(clause.Select{Expression: selectExpr})
return q.db.Session(&Session{NewDB: true}).Unscoped().Table("(?) t", q.db).Where("? <= ?", rnnColumn, q.limitPerRecord)
}
}
return q.db
})
})
}
func (c chainG[T]) Delete(ctx context.Context) (rowsAffected int, err error) {
r := new(T)
res := c.g.apply(ctx).Delete(r)
return int(res.RowsAffected), res.Error
}
func (c chainG[T]) Update(ctx context.Context, name string, value any) (rowsAffected int, err error) {
var r T
res := c.g.apply(ctx).Model(r).Update(name, value)
return int(res.RowsAffected), res.Error
}
func (c chainG[T]) Updates(ctx context.Context, t T) (rowsAffected int, err error) {
res := c.g.apply(ctx).Updates(t)
return int(res.RowsAffected), res.Error
}
func (c chainG[T]) Count(ctx context.Context, column string) (result int64, err error) {
var r T
err = c.g.apply(ctx).Model(r).Select(column).Count(&result).Error
return
}
func (c chainG[T]) Build(builder clause.Builder) {
subdb := c.getInstance()
subdb.Logger = logger.Discard
subdb.DryRun = true
if stmt, ok := builder.(*Statement); ok {
if subdb.Statement.SQL.Len() > 0 {
var (
vars = subdb.Statement.Vars
sql = subdb.Statement.SQL.String()
)
subdb.Statement.Vars = make([]interface{}, 0, len(vars))
for _, vv := range vars {
subdb.Statement.Vars = append(subdb.Statement.Vars, vv)
bindvar := strings.Builder{}
subdb.BindVarTo(&bindvar, subdb.Statement, vv)
sql = strings.Replace(sql, bindvar.String(), "?", 1)
}
subdb.Statement.SQL.Reset()
subdb.Statement.Vars = stmt.Vars
if strings.Contains(sql, "@") {
clause.NamedExpr{SQL: sql, Vars: vars}.Build(subdb.Statement)
} else {
clause.Expr{SQL: sql, Vars: vars}.Build(subdb.Statement)
}
} else {
subdb.Statement.Vars = append(stmt.Vars, subdb.Statement.Vars...)
subdb.callbacks.Query().Execute(subdb)
}
builder.WriteString(subdb.Statement.SQL.String())
stmt.Vars = subdb.Statement.Vars
}
}
type execG[T any] struct {
g *g[T]
}
func (g execG[T]) First(ctx context.Context) (T, error) {
var r T
err := g.g.apply(ctx).First(&r).Error
return r, err
}
func (g execG[T]) Scan(ctx context.Context, result interface{}) error {
var r T
err := g.g.apply(ctx).Model(r).Find(result).Error
return err
}
func (g execG[T]) Last(ctx context.Context) (T, error) {
var r T
err := g.g.apply(ctx).Last(&r).Error
return r, err
}
func (g execG[T]) Take(ctx context.Context) (T, error) {
var r T
err := g.g.apply(ctx).Take(&r).Error
return r, err
}
func (g execG[T]) Find(ctx context.Context) ([]T, error) {
var r []T
err := g.g.apply(ctx).Find(&r).Error
return r, err
}
func (g execG[T]) FindInBatches(ctx context.Context, batchSize int, fc func(data []T, batch int) error) error {
var data []T
return g.g.apply(ctx).FindInBatches(&data, batchSize, func(tx *DB, batch int) error {
return fc(data, batch)
}).Error
}
func (g execG[T]) Row(ctx context.Context) *sql.Row {
return g.g.apply(ctx).Row()
}
func (g execG[T]) Rows(ctx context.Context) (*sql.Rows, error) {
return g.g.apply(ctx).Rows()
}
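
A minimal runnable sketch of the master-side generics API defined above (it is absent on the v0.2.0 side of this compare); it assumes Go 1.18+ and the separate sqlite driver module.

```go
package main

import (
	"context"
	"fmt"
	"log"

	"gorm.io/driver/sqlite"
	"gorm.io/gorm"
)

type User struct {
	ID   uint
	Name string
	Age  int
}

func main() {
	db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
	if err != nil {
		log.Fatal(err)
	}
	if err := db.AutoMigrate(&User{}); err != nil {
		log.Fatal(err)
	}
	ctx := context.Background()

	// CreateInterface[T].Create takes a context and returns a plain error.
	if err := gorm.G[User](db).Create(ctx, &User{Name: "jinzhu", Age: 20}); err != nil {
		log.Fatal(err)
	}

	// ChainInterface[T] mirrors the classic chainable API but is type-safe.
	adults, err := gorm.G[User](db).Where("age >= ?", 18).Order("name").Find(ctx)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(adults)
}
```
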

go.mod (5 changed lines)

@ -1,9 +1,8 @@
module gorm.io/gorm
go 1.18
go 1.14
require (
github.com/jinzhu/inflection v1.0.0
github.com/jinzhu/now v1.1.5
golang.org/x/text v0.20.0
github.com/jinzhu/now v1.1.1
)

go.sum (6 changed lines)

@ -1,6 +0,0 @@
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
golang.org/x/text v0.20.0 h1:gK/Kv2otX8gz+wn7Rmb3vT96ZwuoxnQlY+HlJVj7Qug=
golang.org/x/text v0.20.0/go.mod h1:D4IsuqiFMhST5bX19pQ9ikHC2GsaKyk/oF+pn3ducp4=

gorm.go (475 changed lines)

@ -2,10 +2,7 @@ package gorm
import (
"context"
"database/sql"
"fmt"
"reflect"
"sort"
"sync"
"time"
@ -14,51 +11,19 @@ import (
"gorm.io/gorm/schema"
)
// for Config.cacheStore store PreparedStmtDB key
const preparedStmtDBKey = "preparedStmt"
// Config GORM config
type Config struct {
// GORM perform single create, update, delete operations in transactions by default to ensure database data integrity
// You can disable it by setting `SkipDefaultTransaction` to true
SkipDefaultTransaction bool
DefaultTransactionTimeout time.Duration
SkipDefaultTransaction bool
// NamingStrategy tables, columns naming strategy
NamingStrategy schema.Namer
// FullSaveAssociations full save associations
FullSaveAssociations bool
// Logger
Logger logger.Interface
// NowFunc the function to be used when creating a new timestamp
NowFunc func() time.Time
// DryRun generate sql without execute
DryRun bool
// PrepareStmt executes the given query in cached statement
PrepareStmt bool
// PrepareStmt cache support LRU expired,
// default maxsize=int64 Max value and ttl=1h
PrepareStmtMaxSize int
PrepareStmtTTL time.Duration
// DisableAutomaticPing
DisableAutomaticPing bool
// DisableForeignKeyConstraintWhenMigrating
DisableForeignKeyConstraintWhenMigrating bool
// IgnoreRelationshipsWhenMigrating
IgnoreRelationshipsWhenMigrating bool
// DisableNestedTransaction disable nested transaction
DisableNestedTransaction bool
// AllowGlobalUpdate allow global update
AllowGlobalUpdate bool
// QueryFields executes the SQL query with all fields of the table
QueryFields bool
// CreateBatchSize default create batch size
CreateBatchSize int
// TranslateError enabling error translation
TranslateError bool
// PropagateUnscoped propagate Unscoped to every other nested statement
PropagateUnscoped bool
// ClauseBuilders clause builder
ClauseBuilders map[string]clause.ClauseBuilder
@ -66,39 +31,11 @@ type Config struct {
ConnPool ConnPool
// Dialector database dialector
Dialector
// Plugins registered plugins
Plugins map[string]Plugin
callbacks *callbacks
cacheStore *sync.Map
}
// Apply update config to new config
func (c *Config) Apply(config *Config) error {
if config != c {
*config = *c
}
return nil
}
// AfterInitialize initialize plugins after db connected
func (c *Config) AfterInitialize(db *DB) error {
if db != nil {
for _, plugin := range c.Plugins {
if err := plugin.Initialize(db); err != nil {
return err
}
}
}
return nil
}
// Option gorm option interface
type Option interface {
Apply(*Config) error
AfterInitialize(*DB) error
}
// DB GORM DB definition
type DB struct {
*Config
@ -110,66 +47,21 @@ type DB struct {
// Session session config when create session with Session() method
type Session struct {
DryRun bool
PrepareStmt bool
NewDB bool
Initialized bool
SkipHooks bool
SkipDefaultTransaction bool
DisableNestedTransaction bool
AllowGlobalUpdate bool
FullSaveAssociations bool
PropagateUnscoped bool
QueryFields bool
Context context.Context
Logger logger.Interface
NowFunc func() time.Time
CreateBatchSize int
DryRun bool
WithConditions bool
Context context.Context
Logger logger.Interface
NowFunc func() time.Time
}
// Open initialize db session based on dialector
func Open(dialector Dialector, opts ...Option) (db *DB, err error) {
config := &Config{}
sort.Slice(opts, func(i, j int) bool {
_, isConfig := opts[i].(*Config)
_, isConfig2 := opts[j].(*Config)
return isConfig && !isConfig2
})
if len(opts) > 0 {
if c, ok := opts[0].(*Config); ok {
config = c
} else {
opts = append([]Option{config}, opts...)
}
}
var skipAfterInitialize bool
for _, opt := range opts {
if opt != nil {
if applyErr := opt.Apply(config); applyErr != nil {
return nil, applyErr
}
defer func(opt Option) {
if skipAfterInitialize {
return
}
if errr := opt.AfterInitialize(db); errr != nil {
err = errr
}
}(opt)
}
}
if d, ok := dialector.(interface{ Apply(*Config) error }); ok {
if err = d.Apply(config); err != nil {
return
}
func Open(dialector Dialector, config *Config) (db *DB, err error) {
if config == nil {
config = &Config{}
}
if config.NamingStrategy == nil {
config.NamingStrategy = schema.NamingStrategy{IdentifierMaxLength: 64} // Default Identifier length is 64
config.NamingStrategy = schema.NamingStrategy{}
}
if config.Logger == nil {
@ -184,10 +76,6 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) {
config.Dialector = dialector
}
if config.Plugins == nil {
config.Plugins = map[string]Plugin{}
}
if config.cacheStore == nil {
config.cacheStore = &sync.Map{}
}
@ -200,48 +88,9 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) {
config.ClauseBuilders = map[string]clause.ClauseBuilder{}
}
if config.Dialector != nil {
err = config.Dialector.Initialize(db)
if err != nil {
if db, _ := db.DB(); db != nil {
_ = db.Close()
}
// DB is not initialized, so we skip AfterInitialize
skipAfterInitialize = true
return
}
if config.TranslateError {
if _, ok := db.Dialector.(ErrorTranslator); !ok {
config.Logger.Warn(context.Background(), "The TranslateError option is enabled, but the Dialector %s does not implement ErrorTranslator.", db.Dialector.Name())
}
}
if dialector != nil {
err = dialector.Initialize(db)
}
if config.PrepareStmt {
preparedStmt := NewPreparedStmtDB(db.ConnPool, config.PrepareStmtMaxSize, config.PrepareStmtTTL)
db.cacheStore.Store(preparedStmtDBKey, preparedStmt)
db.ConnPool = preparedStmt
}
db.Statement = &Statement{
DB: db,
ConnPool: db.ConnPool,
Context: context.Background(),
Clauses: map[string]clause.Clause{},
}
if err == nil && !config.DisableAutomaticPing {
if pinger, ok := db.ConnPool.(interface{ Ping() error }); ok {
err = pinger.Ping()
}
}
if err != nil {
config.Logger.Error(context.Background(), "failed to initialize database, got error %v", err)
}
return
}
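The two Open signatures in this hunk are called differently; a minimal sketch of each, assuming the gorm.io/driver/sqlite dialector (the driver is not part of this diff):

```go
package main

import (
	"gorm.io/driver/sqlite" // dialector import assumed; any Dialector works
	"gorm.io/gorm"
)

func main() {
	// master: Open takes variadic options; a *gorm.Config is one such option.
	db, err := gorm.Open(sqlite.Open("gorm.db"), &gorm.Config{
		PrepareStmt:    true,
		TranslateError: true,
	})
	if err != nil {
		panic(err)
	}
	_ = db

	// v0.2.0: Open took a single *Config, and nil was allowed:
	//   db, err := gorm.Open(sqlite.Open("gorm.db"), nil)
}
```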
@ -252,86 +101,33 @@ func (db *DB) Session(config *Session) *DB {
tx = &DB{
Config: &txConfig,
Statement: db.Statement,
Error: db.Error,
clone: 1,
}
)
if config.CreateBatchSize > 0 {
tx.Config.CreateBatchSize = config.CreateBatchSize
}
if config.SkipDefaultTransaction {
tx.Config.SkipDefaultTransaction = true
}
if config.AllowGlobalUpdate {
txConfig.AllowGlobalUpdate = true
}
if config.FullSaveAssociations {
txConfig.FullSaveAssociations = true
}
if config.PropagateUnscoped {
txConfig.PropagateUnscoped = true
}
if config.Context != nil || config.PrepareStmt || config.SkipHooks {
tx.Statement = tx.Statement.clone()
tx.Statement.DB = tx
}
if config.Context != nil {
if tx.Statement != nil {
tx.Statement = tx.Statement.clone()
tx.Statement.DB = tx
} else {
tx.Statement = &Statement{
DB: tx,
Clauses: map[string]clause.Clause{},
ConnPool: tx.ConnPool,
}
}
tx.Statement.Context = config.Context
}
if config.PrepareStmt {
var preparedStmt *PreparedStmtDB
if v, ok := db.cacheStore.Load(preparedStmtDBKey); ok {
preparedStmt = v.(*PreparedStmtDB)
} else {
preparedStmt = NewPreparedStmtDB(db.ConnPool, db.PrepareStmtMaxSize, db.PrepareStmtTTL)
db.cacheStore.Store(preparedStmtDBKey, preparedStmt)
}
switch t := tx.Statement.ConnPool.(type) {
case Tx:
tx.Statement.ConnPool = &PreparedStmtTX{
Tx: t,
PreparedStmtDB: preparedStmt,
}
default:
tx.Statement.ConnPool = &PreparedStmtDB{
ConnPool: db.Config.ConnPool,
Mux: preparedStmt.Mux,
Stmts: preparedStmt.Stmts,
}
}
txConfig.ConnPool = tx.Statement.ConnPool
txConfig.PrepareStmt = true
}
if config.SkipHooks {
tx.Statement.SkipHooks = true
}
if config.DisableNestedTransaction {
txConfig.DisableNestedTransaction = true
}
if !config.NewDB {
tx.clone = 2
if config.WithConditions {
tx.clone = 3
}
if config.DryRun {
tx.Config.DryRun = true
}
if config.QueryFields {
tx.Config.QueryFields = true
}
if config.Logger != nil {
tx.Config.Logger = config.Logger
}
@ -340,23 +136,19 @@ func (db *DB) Session(config *Session) *DB {
tx.Config.NowFunc = config.NowFunc
}
if config.Initialized {
tx = tx.getInstance()
}
return tx
}
// WithContext change current instance db's context to ctx
func (db *DB) WithContext(ctx context.Context) *DB {
return db.Session(&Session{Context: ctx})
return db.Session(&Session{WithConditions: true, Context: ctx})
}
// Debug start debug mode
func (db *DB) Debug() (tx *DB) {
tx = db.getInstance()
return tx.Session(&Session{
Logger: db.Logger.LogMode(logger.Info),
return db.Session(&Session{
WithConditions: true,
Logger: db.Logger.LogMode(logger.Info),
})
}
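WithContext and Debug are thin wrappers around Session on both sides of the diff; a short usage sketch (the User model is illustrative):

```go
package example

import (
	"context"

	"gorm.io/gorm"
)

type User struct {
	ID   uint
	Name string
}

// findUser runs the query with a request-scoped context and, in debug mode,
// logs the generated statement at Info level.
func findUser(ctx context.Context, db *gorm.DB, name string) (User, error) {
	var user User
	err := db.WithContext(ctx).Debug().Where("name = ?", name).First(&user).Error
	return user, err
}
```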
@ -369,7 +161,10 @@ func (db *DB) Set(key string, value interface{}) *DB {
// Get get value with key from current db instance's context
func (db *DB) Get(key string) (interface{}, bool) {
return db.Statement.Settings.Load(key)
if db.Statement != nil {
return db.Statement.Settings.Load(key)
}
return nil, false
}
// InstanceSet store value with key into current db instance's context
@ -381,7 +176,47 @@ func (db *DB) InstanceSet(key string, value interface{}) *DB {
// InstanceGet get value with key from current db instance's context
func (db *DB) InstanceGet(key string) (interface{}, bool) {
return db.Statement.Settings.Load(fmt.Sprintf("%p", db.Statement) + key)
if db.Statement != nil {
return db.Statement.Settings.Load(fmt.Sprintf("%p", db.Statement) + key)
}
return nil, false
}
func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interface{}) error {
var (
tx = db.getInstance()
stmt = tx.Statement
modelSchema, joinSchema *schema.Schema
)
if err := stmt.Parse(model); err == nil {
modelSchema = stmt.Schema
} else {
return err
}
if err := stmt.Parse(joinTable); err == nil {
joinSchema = stmt.Schema
} else {
return err
}
if relation, ok := modelSchema.Relationships.Relations[field]; ok && relation.JoinTable != nil {
for _, ref := range relation.References {
if f := joinSchema.LookUpField(ref.ForeignKey.DBName); f != nil {
f.DataType = ref.ForeignKey.DataType
ref.ForeignKey = f
} else {
return fmt.Errorf("missing field %v for join table", ref.ForeignKey.DBName)
}
}
relation.JoinTable = joinSchema
} else {
return fmt.Errorf("failed to found relation: %v", field)
}
return nil
}
// Callback returns callback manager
@ -389,68 +224,50 @@ func (db *DB) Callback() *callbacks {
return db.callbacks
}
// AutoMigrate run auto migration for given models
func (db *DB) AutoMigrate(dst ...interface{}) error {
return db.Migrator().AutoMigrate(dst...)
}
// AddError add error to db
func (db *DB) AddError(err error) error {
if err != nil {
if db.Config.TranslateError {
if errTranslator, ok := db.Dialector.(ErrorTranslator); ok {
err = errTranslator.Translate(err)
}
}
if db.Error == nil {
db.Error = err
} else {
db.Error = fmt.Errorf("%v; %w", db.Error, err)
}
if db.Error == nil {
db.Error = err
} else if err != nil {
db.Error = fmt.Errorf("%v; %w", db.Error, err)
}
return db.Error
}
// DB returns `*sql.DB`
func (db *DB) DB() (*sql.DB, error) {
connPool := db.ConnPool
if db.Statement != nil && db.Statement.ConnPool != nil {
connPool = db.Statement.ConnPool
}
if tx, ok := connPool.(*sql.Tx); ok && tx != nil {
return (*sql.DB)(reflect.ValueOf(tx).Elem().FieldByName("db").UnsafePointer()), nil
}
if dbConnector, ok := connPool.(GetDBConnector); ok && dbConnector != nil {
if sqldb, err := dbConnector.GetDBConn(); sqldb != nil || err != nil {
return sqldb, err
}
}
if sqldb, ok := connPool.(*sql.DB); ok && sqldb != nil {
return sqldb, nil
}
return nil, ErrInvalidDB
}
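DB exposes the underlying *sql.DB, which is where connection-pool tuning happens; a minimal sketch against the master-side signature (the limits are just examples):

```go
package example

import (
	"time"

	"gorm.io/gorm"
)

// tunePool pulls the underlying *sql.DB out of GORM and sets pool limits.
func tunePool(db *gorm.DB) error {
	sqlDB, err := db.DB()
	if err != nil {
		return err
	}
	sqlDB.SetMaxOpenConns(50)
	sqlDB.SetMaxIdleConns(10)
	sqlDB.SetConnMaxLifetime(time.Hour)
	return nil
}
```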
func (db *DB) getInstance() *DB {
if db.clone > 0 {
tx := &DB{Config: db.Config, Error: db.Error}
tx := &DB{Config: db.Config}
if db.clone == 1 {
// clone with new statement
switch db.clone {
case 1: // clone with new statement
case 2: // with old statement, generate new statement for future call, used to pass to callbacks
db.clone = 1
tx.Statement = db.Statement
case 3: // with clone statement
if db.Statement != nil {
tx.Statement = db.Statement.clone()
tx.Statement.DB = tx
}
}
if tx.Statement == nil {
tx.Statement = &Statement{
DB: tx,
ConnPool: db.Statement.ConnPool,
Context: db.Statement.Context,
Clauses: map[string]clause.Clause{},
Vars: make([]interface{}, 0, 8),
SkipHooks: db.Statement.SkipHooks,
}
if db.Config.PropagateUnscoped {
tx.Statement.Unscoped = db.Statement.Unscoped
DB: tx,
Clauses: map[string]clause.Clause{},
}
}
if db.Statement != nil {
tx.Statement.Context = db.Statement.Context
tx.Statement.ConnPool = db.Statement.ConnPool
} else {
// with clone statement
tx.Statement = db.Statement.clone()
tx.Statement.DB = tx
tx.Statement.Context = context.Background()
tx.Statement.ConnPool = db.ConnPool
}
return tx
@ -459,86 +276,6 @@ func (db *DB) getInstance() *DB {
return db
}
// Expr returns clause.Expr, which can be used to pass SQL expression as params
func Expr(expr string, args ...interface{}) clause.Expr {
return clause.Expr{SQL: expr, Vars: args}
}
// SetupJoinTable setup join table schema
func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interface{}) error {
var (
tx = db.getInstance()
stmt = tx.Statement
modelSchema, joinSchema *schema.Schema
)
err := stmt.Parse(model)
if err != nil {
return err
}
modelSchema = stmt.Schema
err = stmt.Parse(joinTable)
if err != nil {
return err
}
joinSchema = stmt.Schema
relation, ok := modelSchema.Relationships.Relations[field]
isRelation := ok && relation.JoinTable != nil
if !isRelation {
return fmt.Errorf("failed to find relation: %s", field)
}
for _, ref := range relation.References {
f := joinSchema.LookUpField(ref.ForeignKey.DBName)
if f == nil {
return fmt.Errorf("missing field %s for join table", ref.ForeignKey.DBName)
}
f.DataType = ref.ForeignKey.DataType
f.GORMDataType = ref.ForeignKey.GORMDataType
if f.Size == 0 {
f.Size = ref.ForeignKey.Size
}
ref.ForeignKey = f
}
for name, rel := range relation.JoinTable.Relationships.Relations {
if _, ok := joinSchema.Relationships.Relations[name]; !ok {
rel.Schema = joinSchema
joinSchema.Relationships.Relations[name] = rel
}
}
relation.JoinTable = joinSchema
return nil
}
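SetupJoinTable swaps the implicit many2many join table for a user-defined model; a sketch of the kind of call it expects (the models are illustrative):

```go
package example

import (
	"time"

	"gorm.io/gorm"
)

type Person struct {
	ID        uint
	Addresses []Address `gorm:"many2many:person_addresses"`
}

type Address struct {
	ID uint
}

// PersonAddress replaces the auto-generated join table and adds extra columns.
type PersonAddress struct {
	PersonID  uint `gorm:"primaryKey"`
	AddressID uint `gorm:"primaryKey"`
	CreatedAt time.Time
	DeletedAt gorm.DeletedAt
}

// useCustomJoinTable wires the custom join model into the Addresses relation.
func useCustomJoinTable(db *gorm.DB) error {
	return db.SetupJoinTable(&Person{}, "Addresses", &PersonAddress{})
}
```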
// Use use plugin
func (db *DB) Use(plugin Plugin) error {
name := plugin.Name()
if _, ok := db.Plugins[name]; ok {
return ErrRegistered
}
if err := plugin.Initialize(db); err != nil {
return err
}
db.Plugins[name] = plugin
return nil
}
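Use registers a Plugin at most once per name; a minimal plugin sketch (the plugin itself is hypothetical):

```go
package example

import "gorm.io/gorm"

// metricsPlugin is a hypothetical Plugin that registers a callback when installed.
type metricsPlugin struct{}

func (metricsPlugin) Name() string { return "metrics" }

func (metricsPlugin) Initialize(db *gorm.DB) error {
	// Hook in after the query callback; the body is left as a stub.
	return db.Callback().Query().After("gorm:query").Register("metrics:after_query", func(tx *gorm.DB) {
		// record tx.Statement.SQL.String(), tx.RowsAffected, elapsed time, ...
	})
}
```

A second `db.Use(metricsPlugin{})` with the same name would return ErrRegistered, as the code above shows.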
// ToSQL for generate SQL string.
//
// db.ToSQL(func(tx *gorm.DB) *gorm.DB {
// return tx.Model(&User{}).Where(&User{Name: "foo", Age: 20})
// .Limit(10).Offset(5)
// .Order("name ASC")
// .First(&User{})
// })
func (db *DB) ToSQL(queryFn func(tx *DB) *DB) string {
tx := queryFn(db.Session(&Session{DryRun: true, SkipDefaultTransaction: true}).getInstance())
stmt := tx.Statement
return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...)
}

View File

@ -14,79 +14,60 @@ type Dialector interface {
Initialize(*DB) error
Migrator(db *DB) Migrator
DataTypeOf(*schema.Field) string
DefaultValueOf(*schema.Field) clause.Expression
BindVarTo(writer clause.Writer, stmt *Statement, v interface{})
QuoteTo(clause.Writer, string)
Explain(sql string, vars ...interface{}) string
}
// Plugin GORM plugin interface
type Plugin interface {
Name() string
Initialize(*DB) error
}
type ParamsFilter interface {
ParamsFilter(ctx context.Context, sql string, params ...interface{}) (string, []interface{})
}
// ConnPool db conns pool interface
type ConnPool interface {
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
}
// SavePointerDialectorInterface save pointer interface
type SavePointerDialectorInterface interface {
SavePoint(tx *DB, name string) error
RollbackTo(tx *DB, name string) error
}
// TxBeginner tx beginner
type TxBeginner interface {
BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
}
// ConnPoolBeginner conn pool beginner
type ConnPoolBeginner interface {
BeginTx(ctx context.Context, opts *sql.TxOptions) (ConnPool, error)
}
// TxCommitter tx committer
type TxCommitter interface {
type TxCommiter interface {
Commit() error
Rollback() error
}
// Tx sql.Tx interface
type Tx interface {
ConnPool
TxCommitter
StmtContext(ctx context.Context, stmt *sql.Stmt) *sql.Stmt
type BeforeCreateInterface interface {
BeforeCreate(*DB) error
}
// Valuer gorm valuer interface
type Valuer interface {
GormValue(context.Context, *DB) clause.Expr
type AfterCreateInterface interface {
AfterCreate(*DB) error
}
// GetDBConnector SQL db connector
type GetDBConnector interface {
GetDBConn() (*sql.DB, error)
type BeforeUpdateInterface interface {
BeforeUpdate(*DB) error
}
// Rows rows interface
type Rows interface {
Columns() ([]string, error)
ColumnTypes() ([]*sql.ColumnType, error)
Next() bool
Scan(dest ...interface{}) error
Err() error
Close() error
type AfterUpdateInterface interface {
AfterUpdate(*DB) error
}
type ErrorTranslator interface {
Translate(err error) error
type BeforeSaveInterface interface {
BeforeSave(*DB) error
}
type AfterSaveInterface interface {
AfterSave(*DB) error
}
type BeforeDeleteInterface interface {
BeforeDelete(*DB) error
}
type AfterDeleteInterface interface {
AfterDelete(*DB) error
}
type AfterFindInterface interface {
AfterFind(*DB) error
}
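On the v0.2.0 side these per-hook interfaces are what the callbacks look for on a model; a sketch of a model implementing two of them (the fields are illustrative, and the same method signatures satisfy the master-side hooks as well):

```go
package example

import (
	"errors"

	"gorm.io/gorm"
)

type User struct {
	ID   uint
	Name string
}

// BeforeCreate rejects empty names before the INSERT is built.
func (u *User) BeforeCreate(tx *gorm.DB) error {
	if u.Name == "" {
		return errors.New("name must not be empty")
	}
	return nil
}

// AfterFind normalizes data loaded from the database.
func (u *User) AfterFind(tx *gorm.DB) error {
	if u.Name == "" {
		u.Name = "anonymous"
	}
	return nil
}
```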

View File

@ -1,493 +0,0 @@
package lru
// golang -lru
// https://github.com/hashicorp/golang-lru
import (
"sync"
"time"
)
// EvictCallback is used to get a callback when a cache entry is evicted
type EvictCallback[K comparable, V any] func(key K, value V)
// LRU implements a thread-safe LRU with expirable entries.
type LRU[K comparable, V any] struct {
size int
evictList *LruList[K, V]
items map[K]*Entry[K, V]
onEvict EvictCallback[K, V]
// expirable options
mu sync.Mutex
ttl time.Duration
done chan struct{}
// buckets for expiration
buckets []bucket[K, V]
// uint8 because it's number between 0 and numBuckets
nextCleanupBucket uint8
}
// bucket is a container for holding entries to be expired
type bucket[K comparable, V any] struct {
entries map[K]*Entry[K, V]
newestEntry time.Time
}
// noEvictionTTL - very long ttl to prevent eviction
const noEvictionTTL = time.Hour * 24 * 365 * 10
// because of uint8 usage for nextCleanupBucket, should not exceed 256.
// casting it as uint8 explicitly requires type conversions in multiple places
const numBuckets = 100
// NewLRU returns a new thread-safe cache with expirable entries.
//
// Size parameter set to 0 makes cache of unlimited size, e.g. turns LRU mechanism off.
//
// Providing 0 TTL turns expiring off.
//
// Delete expired entries every 1/100th of ttl value. Goroutine which deletes expired entries runs indefinitely.
func NewLRU[K comparable, V any](size int, onEvict EvictCallback[K, V], ttl time.Duration) *LRU[K, V] {
if size < 0 {
size = 0
}
if ttl <= 0 {
ttl = noEvictionTTL
}
res := LRU[K, V]{
ttl: ttl,
size: size,
evictList: NewList[K, V](),
items: make(map[K]*Entry[K, V]),
onEvict: onEvict,
done: make(chan struct{}),
}
// initialize the buckets
res.buckets = make([]bucket[K, V], numBuckets)
for i := 0; i < numBuckets; i++ {
res.buckets[i] = bucket[K, V]{entries: make(map[K]*Entry[K, V])}
}
// enable deleteExpired() running in separate goroutine for cache with non-zero TTL
//
// Important: done channel is never closed, so deleteExpired() goroutine will never exit,
// it's decided to add functionality to close it in the version later than v2.
if res.ttl != noEvictionTTL {
go func(done <-chan struct{}) {
ticker := time.NewTicker(res.ttl / numBuckets)
defer ticker.Stop()
for {
select {
case <-done:
return
case <-ticker.C:
res.deleteExpired()
}
}
}(res.done)
}
return &res
}
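NewLRU ties capacity, eviction callback and TTL together; a small in-package sketch of the behaviour (the key/value types are illustrative, and the package lives under gorm.io/gorm/internal, so it cannot be imported from outside the module):

```go
package lru

import (
	"fmt"
	"time"
)

// demoCache sketches basic Add/Get behaviour with a 2-entry cache and a
// one-minute TTL; evicted keys are reported through the callback.
func demoCache() {
	onEvict := func(key string, value int) { fmt.Println("evicted:", key) }
	cache := NewLRU[string, int](2, onEvict, time.Minute)

	cache.Add("a", 1)
	cache.Add("b", 2)
	cache.Add("c", 3) // capacity is 2, so the oldest entry ("a") is evicted

	if v, ok := cache.Get("b"); ok {
		fmt.Println("b =", v) // prints "b = 2"
	}
}
```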
// Purge clears the cache completely.
// onEvict is called for each evicted key.
func (c *LRU[K, V]) Purge() {
c.mu.Lock()
defer c.mu.Unlock()
for k, v := range c.items {
if c.onEvict != nil {
c.onEvict(k, v.Value)
}
delete(c.items, k)
}
for _, b := range c.buckets {
for _, ent := range b.entries {
delete(b.entries, ent.Key)
}
}
c.evictList.Init()
}
// Add adds a value to the cache. Returns true if an eviction occurred.
// Returns false if there was no eviction: the item was already in the cache,
// or the size was not exceeded.
func (c *LRU[K, V]) Add(key K, value V) (evicted bool) {
c.mu.Lock()
defer c.mu.Unlock()
now := time.Now()
// Check for existing item
if ent, ok := c.items[key]; ok {
c.evictList.MoveToFront(ent)
c.removeFromBucket(ent) // remove the entry from its current bucket as expiresAt is renewed
ent.Value = value
ent.ExpiresAt = now.Add(c.ttl)
c.addToBucket(ent)
return false
}
// Add new item
ent := c.evictList.PushFrontExpirable(key, value, now.Add(c.ttl))
c.items[key] = ent
c.addToBucket(ent) // adds the entry to the appropriate bucket and sets entry.expireBucket
evict := c.size > 0 && c.evictList.Length() > c.size
// Verify size not exceeded
if evict {
c.removeOldest()
}
return evict
}
// Get looks up a key's value from the cache.
func (c *LRU[K, V]) Get(key K) (value V, ok bool) {
c.mu.Lock()
defer c.mu.Unlock()
var ent *Entry[K, V]
if ent, ok = c.items[key]; ok {
// Expired item check
if time.Now().After(ent.ExpiresAt) {
return value, false
}
c.evictList.MoveToFront(ent)
return ent.Value, true
}
return
}
// Contains checks if a key is in the cache, without updating the recent-ness
// or deleting it for being stale.
func (c *LRU[K, V]) Contains(key K) (ok bool) {
c.mu.Lock()
defer c.mu.Unlock()
_, ok = c.items[key]
return ok
}
// Peek returns the key value (or undefined if not found) without updating
// the "recently used"-ness of the key.
func (c *LRU[K, V]) Peek(key K) (value V, ok bool) {
c.mu.Lock()
defer c.mu.Unlock()
var ent *Entry[K, V]
if ent, ok = c.items[key]; ok {
// Expired item check
if time.Now().After(ent.ExpiresAt) {
return value, false
}
return ent.Value, true
}
return
}
// Remove removes the provided key from the cache, returning if the
// key was contained.
func (c *LRU[K, V]) Remove(key K) bool {
c.mu.Lock()
defer c.mu.Unlock()
if ent, ok := c.items[key]; ok {
c.removeElement(ent)
return true
}
return false
}
// RemoveOldest removes the oldest item from the cache.
func (c *LRU[K, V]) RemoveOldest() (key K, value V, ok bool) {
c.mu.Lock()
defer c.mu.Unlock()
if ent := c.evictList.Back(); ent != nil {
c.removeElement(ent)
return ent.Key, ent.Value, true
}
return
}
// GetOldest returns the oldest entry
func (c *LRU[K, V]) GetOldest() (key K, value V, ok bool) {
c.mu.Lock()
defer c.mu.Unlock()
if ent := c.evictList.Back(); ent != nil {
return ent.Key, ent.Value, true
}
return
}
func (c *LRU[K, V]) KeyValues() map[K]V {
c.mu.Lock()
defer c.mu.Unlock()
maps := make(map[K]V)
now := time.Now()
for ent := c.evictList.Back(); ent != nil; ent = ent.PrevEntry() {
if now.After(ent.ExpiresAt) {
continue
}
maps[ent.Key] = ent.Value
// keys = append(keys, ent.Key)
}
return maps
}
// Keys returns a slice of the keys in the cache, from oldest to newest.
// Expired entries are filtered out.
func (c *LRU[K, V]) Keys() []K {
c.mu.Lock()
defer c.mu.Unlock()
keys := make([]K, 0, len(c.items))
now := time.Now()
for ent := c.evictList.Back(); ent != nil; ent = ent.PrevEntry() {
if now.After(ent.ExpiresAt) {
continue
}
keys = append(keys, ent.Key)
}
return keys
}
// Values returns a slice of the values in the cache, from oldest to newest.
// Expired entries are filtered out.
func (c *LRU[K, V]) Values() []V {
c.mu.Lock()
defer c.mu.Unlock()
values := make([]V, 0, len(c.items))
now := time.Now()
for ent := c.evictList.Back(); ent != nil; ent = ent.PrevEntry() {
if now.After(ent.ExpiresAt) {
continue
}
values = append(values, ent.Value)
}
return values
}
// Len returns the number of items in the cache.
func (c *LRU[K, V]) Len() int {
c.mu.Lock()
defer c.mu.Unlock()
return c.evictList.Length()
}
// Resize changes the cache size. Size of 0 means unlimited.
func (c *LRU[K, V]) Resize(size int) (evicted int) {
c.mu.Lock()
defer c.mu.Unlock()
if size <= 0 {
c.size = 0
return 0
}
diff := c.evictList.Length() - size
if diff < 0 {
diff = 0
}
for i := 0; i < diff; i++ {
c.removeOldest()
}
c.size = size
return diff
}
// Close destroys cleanup goroutine. To clean up the cache, run Purge() before Close().
// func (c *LRU[K, V]) Close() {
// c.mu.Lock()
// defer c.mu.Unlock()
// select {
// case <-c.done:
// return
// default:
// }
// close(c.done)
// }
// removeOldest removes the oldest item from the cache. Has to be called with lock!
func (c *LRU[K, V]) removeOldest() {
if ent := c.evictList.Back(); ent != nil {
c.removeElement(ent)
}
}
// removeElement is used to remove a given list element from the cache. Has to be called with lock!
func (c *LRU[K, V]) removeElement(e *Entry[K, V]) {
c.evictList.Remove(e)
delete(c.items, e.Key)
c.removeFromBucket(e)
if c.onEvict != nil {
c.onEvict(e.Key, e.Value)
}
}
// deleteExpired deletes expired records from the oldest bucket, waiting for the newest entry
// in it to expire first.
func (c *LRU[K, V]) deleteExpired() {
c.mu.Lock()
bucketIdx := c.nextCleanupBucket
timeToExpire := time.Until(c.buckets[bucketIdx].newestEntry)
// wait for newest entry to expire before cleanup without holding lock
if timeToExpire > 0 {
c.mu.Unlock()
time.Sleep(timeToExpire)
c.mu.Lock()
}
for _, ent := range c.buckets[bucketIdx].entries {
c.removeElement(ent)
}
c.nextCleanupBucket = (c.nextCleanupBucket + 1) % numBuckets
c.mu.Unlock()
}
// addToBucket adds entry to expire bucket so that it will be cleaned up when the time comes. Has to be called with lock!
func (c *LRU[K, V]) addToBucket(e *Entry[K, V]) {
bucketID := (numBuckets + c.nextCleanupBucket - 1) % numBuckets
e.ExpireBucket = bucketID
c.buckets[bucketID].entries[e.Key] = e
if c.buckets[bucketID].newestEntry.Before(e.ExpiresAt) {
c.buckets[bucketID].newestEntry = e.ExpiresAt
}
}
// removeFromBucket removes the entry from its corresponding bucket. Has to be called with lock!
func (c *LRU[K, V]) removeFromBucket(e *Entry[K, V]) {
delete(c.buckets[e.ExpireBucket].entries, e.Key)
}
// Cap returns the capacity of the cache
func (c *LRU[K, V]) Cap() int {
return c.size
}
// Entry is an LRU Entry
type Entry[K comparable, V any] struct {
// Next and previous pointers in the doubly-linked list of elements.
// To simplify the implementation, internally a list l is implemented
// as a ring, such that &l.root is both the next element of the last
// list element (l.Back()) and the previous element of the first list
// element (l.Front()).
next, prev *Entry[K, V]
// The list to which this element belongs.
list *LruList[K, V]
// The LRU Key of this element.
Key K
// The Value stored with this element.
Value V
// The time this element would be cleaned up, optional
ExpiresAt time.Time
// The expiry bucket item was put in, optional
ExpireBucket uint8
}
// PrevEntry returns the previous list element or nil.
func (e *Entry[K, V]) PrevEntry() *Entry[K, V] {
if p := e.prev; e.list != nil && p != &e.list.root {
return p
}
return nil
}
// LruList represents a doubly linked list.
// The zero Value for LruList is an empty list ready to use.
type LruList[K comparable, V any] struct {
root Entry[K, V] // sentinel list element, only &root, root.prev, and root.next are used
len int // current list Length excluding (this) sentinel element
}
// Init initializes or clears list l.
func (l *LruList[K, V]) Init() *LruList[K, V] {
l.root.next = &l.root
l.root.prev = &l.root
l.len = 0
return l
}
// NewList returns an initialized list.
func NewList[K comparable, V any]() *LruList[K, V] { return new(LruList[K, V]).Init() }
// Length returns the number of elements of list l.
// The complexity is O(1).
func (l *LruList[K, V]) Length() int { return l.len }
// Back returns the last element of list l or nil if the list is empty.
func (l *LruList[K, V]) Back() *Entry[K, V] {
if l.len == 0 {
return nil
}
return l.root.prev
}
// lazyInit lazily initializes a zero List Value.
func (l *LruList[K, V]) lazyInit() {
if l.root.next == nil {
l.Init()
}
}
// insert inserts e after at, increments l.len, and returns e.
func (l *LruList[K, V]) insert(e, at *Entry[K, V]) *Entry[K, V] {
e.prev = at
e.next = at.next
e.prev.next = e
e.next.prev = e
e.list = l
l.len++
return e
}
// insertValue is a convenience wrapper for insert(&Entry{Value: v, ExpiresAt: ExpiresAt}, at).
func (l *LruList[K, V]) insertValue(k K, v V, expiresAt time.Time, at *Entry[K, V]) *Entry[K, V] {
return l.insert(&Entry[K, V]{Value: v, Key: k, ExpiresAt: expiresAt}, at)
}
// Remove removes e from its list, decrements l.len
func (l *LruList[K, V]) Remove(e *Entry[K, V]) V {
e.prev.next = e.next
e.next.prev = e.prev
e.next = nil // avoid memory leaks
e.prev = nil // avoid memory leaks
e.list = nil
l.len--
return e.Value
}
// move moves e to next to at.
func (l *LruList[K, V]) move(e, at *Entry[K, V]) {
if e == at {
return
}
e.prev.next = e.next
e.next.prev = e.prev
e.prev = at
e.next = at.next
e.prev.next = e
e.next.prev = e
}
// PushFront inserts a new element e with value v at the front of list l and returns e.
func (l *LruList[K, V]) PushFront(k K, v V) *Entry[K, V] {
l.lazyInit()
return l.insertValue(k, v, time.Time{}, &l.root)
}
// PushFrontExpirable inserts a new expirable element e with Value v at the front of list l and returns e.
func (l *LruList[K, V]) PushFrontExpirable(k K, v V, expiresAt time.Time) *Entry[K, V] {
l.lazyInit()
return l.insertValue(k, v, expiresAt, &l.root)
}
// MoveToFront moves element e to the front of list l.
// If e is not an element of l, the list is not modified.
// The element must not be nil.
func (l *LruList[K, V]) MoveToFront(e *Entry[K, V]) {
if e.list != l || l.root.next == e {
return
}
// see comment in List.Remove about initialization of l
l.move(e, &l.root)
}

View File

@ -1,183 +0,0 @@
package stmt_store
import (
"context"
"database/sql"
"math"
"sync"
"time"
"gorm.io/gorm/internal/lru"
)
type Stmt struct {
*sql.Stmt
Transaction bool
prepared chan struct{}
prepareErr error
}
func (stmt *Stmt) Error() error {
return stmt.prepareErr
}
func (stmt *Stmt) Close() error {
<-stmt.prepared
if stmt.Stmt != nil {
return stmt.Stmt.Close()
}
return nil
}
// Store defines an interface for managing the caching operations of SQL statements (Stmt).
// This interface provides methods for creating new statements, retrieving all cache keys,
// getting cached statements, setting cached statements, and deleting cached statements.
type Store interface {
// New creates a new Stmt object and caches it.
// Parameters:
// ctx: The context for the request, which can carry deadlines, cancellation signals, etc.
// key: The key representing the SQL query, used for caching and preparing the statement.
// isTransaction: Indicates whether this operation is part of a transaction, which may affect the caching strategy.
// connPool: A connection pool that provides database connections.
// locker: A synchronization lock that is unlocked after initialization to avoid deadlocks.
// Returns:
// *Stmt: A newly created statement object for executing SQL operations.
// error: An error if the statement preparation fails.
New(ctx context.Context, key string, isTransaction bool, connPool ConnPool, locker sync.Locker) (*Stmt, error)
// Keys returns a slice of all cache keys in the store.
Keys() []string
// Get retrieves a Stmt object from the store based on the given key.
// Parameters:
// key: The key used to look up the Stmt object.
// Returns:
// *Stmt: The found Stmt object, or nil if not found.
// bool: Indicates whether the corresponding Stmt object was successfully found.
Get(key string) (*Stmt, bool)
// Set stores the given Stmt object in the store and associates it with the specified key.
// Parameters:
// key: The key used to associate the Stmt object.
// value: The Stmt object to be stored.
Set(key string, value *Stmt)
// Delete removes the Stmt object corresponding to the specified key from the store.
// Parameters:
// key: The key associated with the Stmt object to be deleted.
Delete(key string)
}
// defaultMaxSize defines the default maximum capacity of the cache.
// Its value is the maximum value of the int64 type, which means that when the cache size is not specified,
// the cache can theoretically store as many elements as possible.
// (1 << 63) - 1 is the maximum value that an int64 type can represent.
const (
defaultMaxSize = math.MaxInt
// defaultTTL defines the default time-to-live (TTL) for each cache entry.
// When the TTL for cache entries is not specified, each cache entry will expire after 24 hours.
defaultTTL = time.Hour * 24
)
// New creates and returns a new Store instance.
//
// Parameters:
// - size: The maximum capacity of the cache. If the provided size is less than or equal to 0,
// it defaults to defaultMaxSize.
// - ttl: The time-to-live duration for each cache entry. If the provided ttl is less than or equal to 0,
// it defaults to defaultTTL.
//
// This function defines an onEvicted callback that is invoked when a cache entry is evicted.
// The callback ensures that if the evicted value (v) is not nil, its Close method is called asynchronously
// to release associated resources.
//
// Returns:
// - A Store instance implemented by lruStore, which internally uses an LRU cache with the specified size,
// eviction callback, and TTL.
func New(size int, ttl time.Duration) Store {
if size <= 0 {
size = defaultMaxSize
}
if ttl <= 0 {
ttl = defaultTTL
}
onEvicted := func(k string, v *Stmt) {
if v != nil {
go v.Close()
}
}
return &lruStore{lru: lru.NewLRU[string, *Stmt](size, onEvicted, ttl)}
}
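The Store built here is what GORM's prepared-statement mode uses to cache statements keyed by their SQL text; an in-package sketch of the intended lookup flow (the helper is illustrative, and the mutex stands in for the lock the caller is expected to hold around this sequence):

```go
package stmt_store

import (
	"context"
	"sync"
)

// getOrPrepare looks a statement up first and only prepares (and caches) it on
// a miss.
func getOrPrepare(ctx context.Context, store Store, conn ConnPool, mu *sync.Mutex, query string) (*Stmt, error) {
	mu.Lock()
	if stmt, ok := store.Get(query); ok {
		mu.Unlock()
		return stmt, stmt.Error()
	}
	// store.New unlocks mu itself once the placeholder entry is in the cache.
	return store.New(ctx, query, false, conn, mu)
}
```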
type lruStore struct {
lru *lru.LRU[string, *Stmt]
}
func (s *lruStore) Keys() []string {
return s.lru.Keys()
}
func (s *lruStore) Get(key string) (*Stmt, bool) {
stmt, ok := s.lru.Get(key)
if ok && stmt != nil {
<-stmt.prepared
}
return stmt, ok
}
func (s *lruStore) Set(key string, value *Stmt) {
s.lru.Add(key, value)
}
func (s *lruStore) Delete(key string) {
s.lru.Remove(key)
}
type ConnPool interface {
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
}
// New creates a new Stmt object for executing SQL queries.
// It caches the Stmt object for future use and handles preparation and error states.
// Parameters:
//
// ctx: Context for the request, used to carry deadlines, cancellation signals, etc.
// key: The key representing the SQL query, used for caching and preparing the statement.
// isTransaction: Indicates whether this operation is part of a transaction, affecting cache strategy.
// conn: A connection pool that provides database connections.
// locker: A synchronization lock that is unlocked after initialization to avoid deadlocks.
//
// Returns:
//
// *Stmt: A newly created statement object for executing SQL operations.
// error: An error if the statement preparation fails.
func (s *lruStore) New(ctx context.Context, key string, isTransaction bool, conn ConnPool, locker sync.Locker) (_ *Stmt, err error) {
// Create a Stmt object and set its Transaction property.
// The prepared channel is used to synchronize the statement preparation state.
cacheStmt := &Stmt{
Transaction: isTransaction,
prepared: make(chan struct{}),
}
// Cache the Stmt object with the associated key.
s.Set(key, cacheStmt)
// Unlock after completing initialization to prevent deadlocks.
locker.Unlock()
// Ensure the prepared channel is closed after the function execution completes.
defer close(cacheStmt.prepared)
// Prepare the SQL statement using the provided connection.
cacheStmt.Stmt, err = conn.PrepareContext(ctx, key)
if err != nil {
// If statement preparation fails, record the error and remove the invalid Stmt object from the cache.
cacheStmt.prepareErr = err
s.Delete(key)
return &Stmt{}, err
}
// Return the successfully prepared Stmt object.
return cacheStmt, nil
}

View File

@ -2,9 +2,6 @@ package logger
import (
"context"
"errors"
"fmt"
"io"
"log"
"os"
"time"
@ -12,9 +9,6 @@ import (
"gorm.io/gorm/utils"
)
// ErrRecordNotFound record not found error
var ErrRecordNotFound = errors.New("record not found")
// Colors
const (
Reset = "\033[0m"
@ -25,23 +19,18 @@ const (
Magenta = "\033[35m"
Cyan = "\033[36m"
White = "\033[37m"
BlueBold = "\033[34;1m"
MagentaBold = "\033[35;1m"
RedBold = "\033[31;1m"
YellowBold = "\033[33;1m"
)
// LogLevel log level
// LogLevel
type LogLevel int
const (
// Silent silent log level
Silent LogLevel = iota + 1
// Error error log level
Error
// Warn warn log level
Warn
// Info info log level
Info
)
@ -50,13 +39,10 @@ type Writer interface {
Printf(string, ...interface{})
}
// Config logger config
type Config struct {
SlowThreshold time.Duration
Colorful bool
IgnoreRecordNotFoundError bool
ParameterizedQueries bool
LogLevel LogLevel
SlowThreshold time.Duration
Colorful bool
LogLevel LogLevel
}
// Interface logger interface
@ -65,46 +51,32 @@ type Interface interface {
Info(context.Context, string, ...interface{})
Warn(context.Context, string, ...interface{})
Error(context.Context, string, ...interface{})
Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error)
Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error)
}
var (
// Discard logger will print any log to io.Discard
Discard = New(log.New(io.Discard, "", log.LstdFlags), Config{})
// Default Default logger
Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{
SlowThreshold: 200 * time.Millisecond,
LogLevel: Warn,
IgnoreRecordNotFoundError: false,
Colorful: true,
})
// Recorder logger records running SQL into a recorder instance
Recorder = traceRecorder{Interface: Default, BeginAt: time.Now()}
var Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{
SlowThreshold: 100 * time.Millisecond,
LogLevel: Warn,
Colorful: true,
})
// RecorderParamsFilter defaults to a no-op and can be overridden by a different implementation
RecorderParamsFilter = func(ctx context.Context, sql string, params ...interface{}) (string, []interface{}) {
return sql, params
}
)
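Both sides of the diff build their Default logger through New; a sketch of constructing a custom logger with the master-side Config fields shown above (the writer and thresholds are just examples):

```go
package example

import (
	"log"
	"os"
	"time"

	"gorm.io/gorm/logger"
)

// newVerboseLogger prints every statement, colors its output, and flags
// queries slower than 200ms as slow SQL.
func newVerboseLogger() logger.Interface {
	return logger.New(
		log.New(os.Stdout, "\r\n", log.LstdFlags),
		logger.Config{
			SlowThreshold:             200 * time.Millisecond,
			LogLevel:                  logger.Info,
			IgnoreRecordNotFoundError: true,
			Colorful:                  true,
		},
	)
}
```

The result can be passed through gorm.Config{Logger: ...} or Session{Logger: ...}, as the earlier hunks show.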
// New initialize logger
func New(writer Writer, config Config) Interface {
var (
infoStr = "%s\n[info] "
warnStr = "%s\n[warn] "
errStr = "%s\n[error] "
traceStr = "%s\n[%.3fms] [rows:%v] %s"
traceWarnStr = "%s %s\n[%.3fms] [rows:%v] %s"
traceErrStr = "%s %s\n[%.3fms] [rows:%v] %s"
traceStr = "%s\n[%v] [rows:%d] %s"
traceWarnStr = "%s\n[%v] [rows:%d] %s"
traceErrStr = "%s %s\n[%v] [rows:%d] %s"
)
if config.Colorful {
infoStr = Green + "%s\n" + Reset + Green + "[info] " + Reset
warnStr = BlueBold + "%s\n" + Reset + Magenta + "[warn] " + Reset
warnStr = Blue + "%s\n" + Reset + Magenta + "[warn] " + Reset
errStr = Magenta + "%s\n" + Reset + Red + "[error] " + Reset
traceStr = Green + "%s\n" + Reset + Yellow + "[%.3fms] " + BlueBold + "[rows:%v]" + Reset + " %s"
traceWarnStr = Green + "%s " + Yellow + "%s\n" + Reset + RedBold + "[%.3fms] " + Yellow + "[rows:%v]" + Magenta + " %s" + Reset
traceErrStr = RedBold + "%s " + MagentaBold + "%s\n" + Reset + Yellow + "[%.3fms] " + BlueBold + "[rows:%v]" + Reset + " %s"
traceStr = Green + "%s\n" + Reset + Yellow + "[%.3fms] " + Blue + "[rows:%d]" + Reset + " %s"
traceWarnStr = Green + "%s\n" + Reset + RedBold + "[%.3fms] " + Yellow + "[rows:%d]" + Magenta + " %s" + Reset
traceErrStr = RedBold + "%s " + MagentaBold + "%s\n" + Reset + Yellow + "[%.3fms] " + Blue + "[rows:%d]" + Reset + " %s"
}
return &logger{
@ -134,92 +106,40 @@ func (l *logger) LogMode(level LogLevel) Interface {
}
// Info print info
func (l *logger) Info(ctx context.Context, msg string, data ...interface{}) {
func (l logger) Info(ctx context.Context, msg string, data ...interface{}) {
if l.LogLevel >= Info {
l.Printf(l.infoStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...)
}
}
// Warn print warn messages
func (l *logger) Warn(ctx context.Context, msg string, data ...interface{}) {
func (l logger) Warn(ctx context.Context, msg string, data ...interface{}) {
if l.LogLevel >= Warn {
l.Printf(l.warnStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...)
}
}
// Error print error messages
func (l *logger) Error(ctx context.Context, msg string, data ...interface{}) {
func (l logger) Error(ctx context.Context, msg string, data ...interface{}) {
if l.LogLevel >= Error {
l.Printf(l.errStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...)
}
}
// Trace print sql message
//
//nolint:cyclop
func (l *logger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) {
if l.LogLevel <= Silent {
return
}
elapsed := time.Since(begin)
switch {
case err != nil && l.LogLevel >= Error && (!errors.Is(err, ErrRecordNotFound) || !l.IgnoreRecordNotFoundError):
sql, rows := fc()
if rows == -1 {
l.Printf(l.traceErrStr, utils.FileWithLineNum(), err, float64(elapsed.Nanoseconds())/1e6, "-", sql)
} else {
func (l logger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) {
if l.LogLevel > 0 {
elapsed := time.Now().Sub(begin)
switch {
case err != nil && l.LogLevel >= Error:
sql, rows := fc()
l.Printf(l.traceErrStr, utils.FileWithLineNum(), err, float64(elapsed.Nanoseconds())/1e6, rows, sql)
}
case elapsed > l.SlowThreshold && l.SlowThreshold != 0 && l.LogLevel >= Warn:
sql, rows := fc()
slowLog := fmt.Sprintf("SLOW SQL >= %v", l.SlowThreshold)
if rows == -1 {
l.Printf(l.traceWarnStr, utils.FileWithLineNum(), slowLog, float64(elapsed.Nanoseconds())/1e6, "-", sql)
} else {
l.Printf(l.traceWarnStr, utils.FileWithLineNum(), slowLog, float64(elapsed.Nanoseconds())/1e6, rows, sql)
}
case l.LogLevel == Info:
sql, rows := fc()
if rows == -1 {
l.Printf(l.traceStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, "-", sql)
} else {
case elapsed > l.SlowThreshold && l.SlowThreshold != 0 && l.LogLevel >= Warn:
sql, rows := fc()
l.Printf(l.traceWarnStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, rows, sql)
case l.LogLevel >= Info:
sql, rows := fc()
l.Printf(l.traceStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, rows, sql)
}
}
}
// ParamsFilter filter params
func (l *logger) ParamsFilter(ctx context.Context, sql string, params ...interface{}) (string, []interface{}) {
if l.Config.ParameterizedQueries {
return sql, nil
}
return sql, params
}
type traceRecorder struct {
Interface
BeginAt time.Time
SQL string
RowsAffected int64
Err error
}
// New trace recorder
func (l *traceRecorder) New() *traceRecorder {
return &traceRecorder{Interface: l.Interface, BeginAt: time.Now()}
}
// Trace implement logger interface
func (l *traceRecorder) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) {
l.BeginAt = begin
l.SQL, l.RowsAffected = fc()
l.Err = err
}
func (l *traceRecorder) ParamsFilter(ctx context.Context, sql string, params ...interface{}) (string, []interface{}) {
if RecorderParamsFilter == nil {
return sql, params
}
return RecorderParamsFilter(ctx, sql, params...)
}

View File

@ -9,172 +9,88 @@ import (
"strings"
"time"
"unicode"
"gorm.io/gorm/utils"
)
const (
tmFmtWithMS = "2006-01-02 15:04:05.999"
tmFmtZero = "0000-00-00 00:00:00"
nullStr = "NULL"
)
func isPrintable(s string) bool {
func isPrintable(s []byte) bool {
for _, r := range s {
if !unicode.IsPrint(r) {
if !unicode.IsPrint(rune(r)) {
return false
}
}
return true
}
// A list of Go types that should be converted to SQL primitives
var convertibleTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeOf(false), reflect.TypeOf([]byte{})}
var convertableTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeOf(false), reflect.TypeOf([]byte{})}
// RegEx matches only numeric values
var numericPlaceholderRe = regexp.MustCompile(`\$\d+\$`)
func isNumeric(k reflect.Kind) bool {
switch k {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return true
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return true
case reflect.Float32, reflect.Float64:
return true
default:
return false
}
}
// ExplainSQL generates a SQL string with the given parameters; the generated SQL is intended for the logger only, executing it might introduce a SQL injection vulnerability
func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, avars ...interface{}) string {
var (
convertParams func(interface{}, int)
vars = make([]string, len(avars))
)
var convertParams func(interface{}, int)
var vars = make([]interface{}, len(avars))
copy(vars, avars)
convertParams = func(v interface{}, idx int) {
switch v := v.(type) {
case bool:
vars[idx] = strconv.FormatBool(v)
vars[idx] = fmt.Sprint(v)
case time.Time:
if v.IsZero() {
vars[idx] = escaper + tmFmtZero + escaper
vars[idx] = escaper + "0000-00-00 00:00:00" + escaper
} else {
vars[idx] = escaper + v.Format(tmFmtWithMS) + escaper
}
case *time.Time:
if v != nil {
if v.IsZero() {
vars[idx] = escaper + tmFmtZero + escaper
} else {
vars[idx] = escaper + v.Format(tmFmtWithMS) + escaper
}
} else {
vars[idx] = nullStr
}
case driver.Valuer:
reflectValue := reflect.ValueOf(v)
if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) {
r, _ := v.Value()
convertParams(r, idx)
} else {
vars[idx] = nullStr
}
case fmt.Stringer:
reflectValue := reflect.ValueOf(v)
switch reflectValue.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
vars[idx] = fmt.Sprintf("%d", reflectValue.Interface())
case reflect.Float32, reflect.Float64:
vars[idx] = fmt.Sprintf("%.6f", reflectValue.Interface())
case reflect.Bool:
vars[idx] = fmt.Sprintf("%t", reflectValue.Interface())
case reflect.String:
vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, escaper+escaper) + escaper
default:
if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) {
vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, escaper+escaper) + escaper
} else {
vars[idx] = nullStr
}
vars[idx] = escaper + v.Format("2006-01-02 15:04:05.999") + escaper
}
case []byte:
if s := string(v); isPrintable(s) {
vars[idx] = escaper + strings.ReplaceAll(s, escaper, escaper+escaper) + escaper
if isPrintable(v) {
vars[idx] = escaper + strings.Replace(string(v), escaper, "\\"+escaper, -1) + escaper
} else {
vars[idx] = escaper + "<binary>" + escaper
}
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
vars[idx] = utils.ToString(v)
case float32:
vars[idx] = strconv.FormatFloat(float64(v), 'f', -1, 32)
case float64:
vars[idx] = strconv.FormatFloat(v, 'f', -1, 64)
vars[idx] = fmt.Sprintf("%d", v)
case float64, float32:
vars[idx] = fmt.Sprintf("%.6f", v)
case string:
vars[idx] = escaper + strings.ReplaceAll(v, escaper, escaper+escaper) + escaper
vars[idx] = escaper + strings.Replace(v, escaper, "\\"+escaper, -1) + escaper
default:
rv := reflect.ValueOf(v)
if v == nil || !rv.IsValid() || rv.Kind() == reflect.Ptr && rv.IsNil() {
vars[idx] = nullStr
} else if valuer, ok := v.(driver.Valuer); ok {
v, _ = valuer.Value()
convertParams(v, idx)
} else if rv.Kind() == reflect.Ptr && !rv.IsZero() {
convertParams(reflect.Indirect(rv).Interface(), idx)
} else if isNumeric(rv.Kind()) {
if rv.CanInt() || rv.CanUint() {
vars[idx] = fmt.Sprintf("%d", rv.Interface())
} else {
vars[idx] = fmt.Sprintf("%.6f", rv.Interface())
}
if v == nil {
vars[idx] = "NULL"
} else {
for _, t := range convertibleTypes {
if rv.Type().ConvertibleTo(t) {
convertParams(rv.Convert(t).Interface(), idx)
return
rv := reflect.ValueOf(v)
if !rv.IsValid() {
vars[idx] = "NULL"
} else if rv.Kind() == reflect.Ptr && rv.IsNil() {
vars[idx] = "NULL"
} else if valuer, ok := v.(driver.Valuer); ok {
v, _ = valuer.Value()
convertParams(v, idx)
} else if rv.Kind() == reflect.Ptr && !rv.IsZero() {
convertParams(reflect.Indirect(rv).Interface(), idx)
} else {
for _, t := range convertableTypes {
if rv.Type().ConvertibleTo(t) {
convertParams(rv.Convert(t).Interface(), idx)
return
}
}
vars[idx] = escaper + strings.Replace(fmt.Sprint(v), escaper, "\\"+escaper, -1) + escaper
}
vars[idx] = escaper + strings.ReplaceAll(fmt.Sprint(v), escaper, escaper+escaper) + escaper
}
}
}
for idx, v := range avars {
for idx, v := range vars {
convertParams(v, idx)
}
if numericPlaceholder == nil {
var idx int
var newSQL strings.Builder
for _, v := range []byte(sql) {
if v == '?' {
if len(vars) > idx {
newSQL.WriteString(vars[idx])
idx++
continue
}
}
newSQL.WriteByte(v)
for _, v := range vars {
sql = strings.Replace(sql, "?", v.(string), 1)
}
sql = newSQL.String()
} else {
sql = numericPlaceholder.ReplaceAllString(sql, "$$$1$$")
sql = numericPlaceholderRe.ReplaceAllStringFunc(sql, func(v string) string {
num := v[1 : len(v)-1]
n, _ := strconv.Atoi(num)
// position var start from 1 ($1, $2)
n -= 1
if n >= 0 && n <= len(vars)-1 {
return vars[n]
}
return v
})
for idx, v := range vars {
sql = strings.Replace(sql, "$"+strconv.Itoa(idx+1)+"$", v.(string), 1)
}
}
return sql
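ExplainSQL is what Trace ultimately relies on to inline parameters into the logged SQL; a short sketch of calling it directly (the queries and values are illustrative, and the output is meant for logging only, never for execution):

```go
package example

import (
	"fmt"
	"regexp"

	"gorm.io/gorm/logger"
)

func demoExplainSQL() {
	// '?' placeholders: pass a nil regexp.
	fmt.Println(logger.ExplainSQL(
		"SELECT * FROM users WHERE name = ? AND age > ?",
		nil, `'`, "jinzhu", 18,
	))

	// Numbered placeholders such as $1, $2: pass a matching regexp.
	fmt.Println(logger.ExplainSQL(
		"SELECT * FROM users WHERE name = $1 AND age > $2",
		regexp.MustCompile(`\$(\d+)`), `'`, "jinzhu", 18,
	))
}
```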

View File

@ -1,54 +1,20 @@
package logger_test
import (
"database/sql/driver"
"encoding/json"
"fmt"
"regexp"
"strings"
"testing"
"github.com/jinzhu/now"
"gorm.io/gorm/logger"
)
type JSON json.RawMessage
func (j JSON) Value() (driver.Value, error) {
if len(j) == 0 {
return nil, nil
}
return json.RawMessage(j).MarshalJSON()
}
type ExampleStruct struct {
Name string
Val string
}
func (s ExampleStruct) Value() (driver.Value, error) {
return json.Marshal(s)
}
func format(v []byte, escaper string) string {
return escaper + strings.ReplaceAll(string(v), escaper, escaper+escaper) + escaper
}
func TestExplainSQL(t *testing.T) {
type role string
type password []byte
type intType int
type floatType float64
var (
tt = now.MustParse("2020-02-23 11:10:10")
myrole = role("admin")
pwd = password("pass")
jsVal = []byte(`{"Name":"test","Val":"test"}`)
js = JSON(jsVal)
esVal = []byte(`{"Name":"test","Val":"test"}`)
es = ExampleStruct{Name: "test", Val: "test"}
intVal intType = 1
floatVal floatType = 1.23
tt = now.MustParse("2020-02-23 11:10:10")
myrole = role("admin")
pwd = password([]byte("pass"))
)
results := []struct {
@ -61,67 +27,25 @@ func TestExplainSQL(t *testing.T) {
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
NumericRegexp: nil,
Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd},
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass")`,
},
{
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
NumericRegexp: nil,
Vars: []interface{}{"jinzhu?", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd},
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu?", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass")`,
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass")`,
},
{
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (@p1, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9, @p10, @p11)",
NumericRegexp: regexp.MustCompile(`@p(\d+)`),
NumericRegexp: regexp.MustCompile("@p(\\d+)"),
Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.com", myrole, pwd},
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`,
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`,
},
{
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ($3, $4, $1, $2, $7, $8, $5, $6, $9, $10, $11)",
NumericRegexp: regexp.MustCompile(`\$(\d+)`),
NumericRegexp: regexp.MustCompile("\\$(\\d+)"),
Vars: []interface{}{999.99, true, "jinzhu", 1, &tt, nil, []byte("12345"), tt, "w@g.com", myrole, pwd},
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`,
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`,
},
{
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (@p1, @p11, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9, @p10)",
NumericRegexp: regexp.MustCompile(`@p(\d+)`),
NumericRegexp: regexp.MustCompile("@p(\\d+)"),
Vars: []interface{}{"jinzhu", 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.com", myrole, pwd, 1},
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`,
},
{
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
NumericRegexp: nil,
Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, js, es},
Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
},
{
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
NumericRegexp: nil,
Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es},
Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
},
{
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
NumericRegexp: nil,
Vars: []interface{}{"jinzhu", 1, 0.1753607109, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es},
Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 0.1753607109, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
},
{
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
NumericRegexp: nil,
Vars: []interface{}{"jinzhu", 1, float32(999.99), true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es},
Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
},
{
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, int_val) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
NumericRegexp: nil,
Vars: []interface{}{"jinzhu?", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, intVal},
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, int_val) values ("jinzhu?", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", 1)`,
},
{
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, float_val) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
NumericRegexp: nil,
Vars: []interface{}{"jinzhu?", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, floatVal},
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, float_val) values ("jinzhu?", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", 1.230000)`,
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`,
},
}

View File

@ -1,97 +1,41 @@
package gorm
import (
"reflect"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
"database/sql"
)
// Migrator returns migrator
func (db *DB) Migrator() Migrator {
tx := db.getInstance()
// apply scopes to migrator
for len(tx.Statement.scopes) > 0 {
tx = tx.executeScopes()
}
return tx.Dialector.Migrator(tx.Session(&Session{}))
}
// AutoMigrate run auto migration for given models
func (db *DB) AutoMigrate(dst ...interface{}) error {
return db.Migrator().AutoMigrate(dst...)
return db.Dialector.Migrator(db)
}
// ViewOption view option
type ViewOption struct {
Replace bool // If true, exec `CREATE OR REPLACE`. If false, exec `CREATE`
CheckOption string // optional. e.g. `WITH [ CASCADED | LOCAL ] CHECK OPTION`
Query *DB // required subquery.
Replace bool
CheckOption string
Query *DB
}
// ColumnType column type interface
type ColumnType interface {
Name() string
DatabaseTypeName() string // varchar
ColumnType() (columnType string, ok bool) // varchar(64)
PrimaryKey() (isPrimaryKey bool, ok bool)
AutoIncrement() (isAutoIncrement bool, ok bool)
Length() (length int64, ok bool)
DecimalSize() (precision int64, scale int64, ok bool)
Nullable() (nullable bool, ok bool)
Unique() (unique bool, ok bool)
ScanType() reflect.Type
Comment() (value string, ok bool)
DefaultValue() (value string, ok bool)
}
type Index interface {
Table() string
Name() string
Columns() []string
PrimaryKey() (isPrimaryKey bool, ok bool)
Unique() (unique bool, ok bool)
Option() string
}
// TableType table type interface
type TableType interface {
Schema() string
Name() string
Type() string
Comment() (comment string, ok bool)
}
// Migrator migrator interface
type Migrator interface {
// AutoMigrate
AutoMigrate(dst ...interface{}) error
// Database
CurrentDatabase() string
FullDataTypeOf(*schema.Field) clause.Expr
GetTypeAliases(databaseTypeName string) []string
// Tables
CreateTable(dst ...interface{}) error
DropTable(dst ...interface{}) error
HasTable(dst interface{}) bool
RenameTable(oldName, newName interface{}) error
GetTables() (tableList []string, err error)
TableType(dst interface{}) (TableType, error)
// Columns
AddColumn(dst interface{}, field string) error
DropColumn(dst interface{}, field string) error
AlterColumn(dst interface{}, field string) error
MigrateColumn(dst interface{}, field *schema.Field, columnType ColumnType) error
// MigrateColumnUnique migrate column's UNIQUE constraint, it's part of MigrateColumn.
MigrateColumnUnique(dst interface{}, field *schema.Field, columnType ColumnType) error
HasColumn(dst interface{}, field string) bool
RenameColumn(dst interface{}, oldName, field string) error
ColumnTypes(dst interface{}) ([]ColumnType, error)
ColumnTypes(dst interface{}) ([]*sql.ColumnType, error)
// Views
CreateView(name string, option ViewOption) error
@ -107,5 +51,4 @@ type Migrator interface {
DropIndex(dst interface{}, name string) error
HasIndex(dst interface{}, name string) bool
RenameIndex(dst interface{}, oldName, newName string) error
GetIndexes(dst interface{}) ([]Index, error)
}
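AutoMigrate and the rest of the interface are reached through db.Migrator(); a small sketch against the master-side methods (the model is illustrative):

```go
package example

import (
	"fmt"

	"gorm.io/gorm"
)

type Product struct {
	ID    uint
	Code  string `gorm:"size:64;uniqueIndex"`
	Price uint
}

// migrateProducts creates or updates the products table, then inspects it.
func migrateProducts(db *gorm.DB) error {
	if err := db.AutoMigrate(&Product{}); err != nil {
		return err
	}

	m := db.Migrator()
	if !m.HasTable(&Product{}) {
		return fmt.Errorf("products table was not created")
	}
	columnTypes, err := m.ColumnTypes(&Product{})
	if err != nil {
		return err
	}
	for _, ct := range columnTypes {
		_ = ct.Name() // e.g. "id", "code", "price"
	}
	return nil
}
```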

View File

@ -1,107 +0,0 @@
package migrator
import (
"database/sql"
"reflect"
)
// ColumnType column type implements ColumnType interface
type ColumnType struct {
SQLColumnType *sql.ColumnType
NameValue sql.NullString
DataTypeValue sql.NullString
ColumnTypeValue sql.NullString
PrimaryKeyValue sql.NullBool
UniqueValue sql.NullBool
AutoIncrementValue sql.NullBool
LengthValue sql.NullInt64
DecimalSizeValue sql.NullInt64
ScaleValue sql.NullInt64
NullableValue sql.NullBool
ScanTypeValue reflect.Type
CommentValue sql.NullString
DefaultValueValue sql.NullString
}
// Name returns the name or alias of the column.
func (ct ColumnType) Name() string {
if ct.NameValue.Valid {
return ct.NameValue.String
}
return ct.SQLColumnType.Name()
}
// DatabaseTypeName returns the database system name of the column type. If an empty
// string is returned, then the driver type name is not supported.
// Consult your driver documentation for a list of driver data types. Length specifiers
// are not included.
// Common type names include "VARCHAR", "TEXT", "NVARCHAR", "DECIMAL", "BOOL",
// "INT", and "BIGINT".
func (ct ColumnType) DatabaseTypeName() string {
if ct.DataTypeValue.Valid {
return ct.DataTypeValue.String
}
return ct.SQLColumnType.DatabaseTypeName()
}
// ColumnType returns the database type of the column, e.g. `varchar(16)`
func (ct ColumnType) ColumnType() (columnType string, ok bool) {
return ct.ColumnTypeValue.String, ct.ColumnTypeValue.Valid
}
// PrimaryKey reports whether the column is a primary key.
func (ct ColumnType) PrimaryKey() (isPrimaryKey bool, ok bool) {
return ct.PrimaryKeyValue.Bool, ct.PrimaryKeyValue.Valid
}
// AutoIncrement reports whether the column is auto-incrementing.
func (ct ColumnType) AutoIncrement() (isAutoIncrement bool, ok bool) {
return ct.AutoIncrementValue.Bool, ct.AutoIncrementValue.Valid
}
// Length returns the column type length for variable length column types
func (ct ColumnType) Length() (length int64, ok bool) {
if ct.LengthValue.Valid {
return ct.LengthValue.Int64, true
}
return ct.SQLColumnType.Length()
}
// DecimalSize returns the precision and scale of a decimal type.
func (ct ColumnType) DecimalSize() (precision int64, scale int64, ok bool) {
if ct.DecimalSizeValue.Valid {
return ct.DecimalSizeValue.Int64, ct.ScaleValue.Int64, true
}
return ct.SQLColumnType.DecimalSize()
}
// Nullable reports whether the column may be null.
func (ct ColumnType) Nullable() (nullable bool, ok bool) {
if ct.NullableValue.Valid {
return ct.NullableValue.Bool, true
}
return ct.SQLColumnType.Nullable()
}
// Unique reports whether the column is unique.
func (ct ColumnType) Unique() (unique bool, ok bool) {
return ct.UniqueValue.Bool, ct.UniqueValue.Valid
}
// ScanType returns a Go type suitable for scanning into using Rows.Scan.
func (ct ColumnType) ScanType() reflect.Type {
if ct.ScanTypeValue != nil {
return ct.ScanTypeValue
}
return ct.SQLColumnType.ScanType()
}
// Comment returns the comment of the current column.
func (ct ColumnType) Comment() (value string, ok bool) {
return ct.CommentValue.String, ct.CommentValue.Valid
}
// DefaultValue returns the default value of the current column.
func (ct ColumnType) DefaultValue() (value string, ok bool) {
return ct.DefaultValueValue.String, ct.DefaultValueValue.Valid
}
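
This struct is what dialectors hand back from `ColumnTypes`: any field left invalid falls through to the embedded `*sql.ColumnType`. A hedged construction sketch; the column values are illustrative, real dialectors derive them from information_schema queries:

```go
package example

import (
    "database/sql"

    "gorm.io/gorm/migrator"
)

// buildColumnType wraps raw driver metadata with values a dialector resolved itself.
func buildColumnType(raw *sql.ColumnType) migrator.ColumnType {
    return migrator.ColumnType{
        SQLColumnType:   raw, // fallback for Name, Length, Nullable, ScanType, ...
        NameValue:       sql.NullString{String: "email", Valid: true},
        ColumnTypeValue: sql.NullString{String: "varchar(64)", Valid: true},
        NullableValue:   sql.NullBool{Bool: true, Valid: true},
        UniqueValue:     sql.NullBool{Bool: true, Valid: true},
    }
}
```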

View File

@ -1,43 +0,0 @@
package migrator
import "database/sql"
// Index implements gorm.Index interface
type Index struct {
TableName string
NameValue string
ColumnList []string
PrimaryKeyValue sql.NullBool
UniqueValue sql.NullBool
OptionValue string
}
// Table returns the table name of the index.
func (idx Index) Table() string {
return idx.TableName
}
// Name returns the name of the index.
func (idx Index) Name() string {
return idx.NameValue
}
// Columns returns the columns of the index.
func (idx Index) Columns() []string {
return idx.ColumnList
}
// PrimaryKey reports whether the index is a primary key.
func (idx Index) PrimaryKey() (isPrimaryKey bool, ok bool) {
return idx.PrimaryKeyValue.Bool, idx.PrimaryKeyValue.Valid
}
// Unique reports whether the index is unique.
func (idx Index) Unique() (unique bool, ok bool) {
return idx.UniqueValue.Bool, idx.UniqueValue.Valid
}
// Option returns the optional attribute of the index.
func (idx Index) Option() string {
return idx.OptionValue
}
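
A matching sketch for index metadata; the table, index, and column names are illustrative. The null types let a dialector say "unknown" by leaving a value invalid:

```go
package example

import (
    "database/sql"

    "gorm.io/gorm/migrator"
)

// exampleIndex shows how a dialector might report one index row.
func exampleIndex() migrator.Index {
    idx := migrator.Index{
        TableName:       "users",
        NameValue:       "idx_users_email",
        ColumnList:      []string{"email"},
        PrimaryKeyValue: sql.NullBool{Bool: false, Valid: true},
        UniqueValue:     sql.NullBool{Bool: true, Valid: true},
    }
    if unique, ok := idx.Unique(); ok {
        _ = unique // true in this sketch; ok confirms the driver knew it
    }
    return idx
}
```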

File diff suppressed because it is too large.

View File

@ -1,33 +0,0 @@
package migrator
import (
"database/sql"
)
// TableType table type implements TableType interface
type TableType struct {
SchemaValue string
NameValue string
TypeValue string
CommentValue sql.NullString
}
// Schema returns the schema of the table.
func (ct TableType) Schema() string {
return ct.SchemaValue
}
// Name returns the name of the table.
func (ct TableType) Name() string {
return ct.NameValue
}
// Type returns the type of the table.
func (ct TableType) Type() string {
return ct.TypeValue
}
// Comment returns the comment of the current table.
func (ct TableType) Comment() (comment string, ok bool) {
return ct.CommentValue.String, ct.CommentValue.Valid
}

View File

@ -3,11 +3,10 @@ package gorm
import "time"
// Model is a basic Go struct which includes the following fields: ID, CreatedAt, UpdatedAt, DeletedAt
// It may be embedded into your model or you may build your own model without it
//
// type User struct {
// gorm.Model
// }
// It may be embedded into your model or you may build your own model without it
// type User struct {
// gorm.Model
// }
type Model struct {
ID uint `gorm:"primarykey"`
CreatedAt time.Time

View File

@ -1,206 +0,0 @@
package gorm
import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"reflect"
"sync"
"time"
"gorm.io/gorm/internal/stmt_store"
)
type PreparedStmtDB struct {
Stmts stmt_store.Store
Mux *sync.RWMutex
ConnPool
}
// NewPreparedStmtDB creates and initializes a new instance of PreparedStmtDB.
//
// Parameters:
// - connPool: A connection pool that implements the ConnPool interface, used for managing database connections.
// - maxSize: The maximum number of prepared statements that can be stored in the statement store.
// - ttl: The time-to-live duration for each prepared statement in the store. Statements older than this duration will be automatically removed.
//
// Returns:
// - A pointer to a PreparedStmtDB instance, which manages prepared statements using the provided connection pool and configuration.
func NewPreparedStmtDB(connPool ConnPool, maxSize int, ttl time.Duration) *PreparedStmtDB {
return &PreparedStmtDB{
ConnPool: connPool, // Assigns the provided connection pool to manage database connections.
Stmts: stmt_store.New(maxSize, ttl), // Initializes a new statement store with the specified maximum size and TTL.
Mux: &sync.RWMutex{}, // Sets up a read-write mutex for synchronizing access to the statement store.
}
}
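
A hedged construction sketch. GORM normally wires this up itself when `PrepareStmt` is enabled in the config, so direct use like this is unusual; the maxSize (100) and ttl (one hour) values are arbitrary examples:

```go
package example

import (
    "database/sql"
    "time"

    "gorm.io/gorm"
)

// newStmtCache wraps an existing *sql.DB in a prepared-statement cache.
func newStmtCache(sqlDB *sql.DB) *gorm.PreparedStmtDB {
    return gorm.NewPreparedStmtDB(sqlDB, 100, time.Hour)
}
```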
// GetDBConn returns the underlying *sql.DB connection
func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) {
if sqldb, ok := db.ConnPool.(*sql.DB); ok {
return sqldb, nil
}
if dbConnector, ok := db.ConnPool.(GetDBConnector); ok && dbConnector != nil {
return dbConnector.GetDBConn()
}
return nil, ErrInvalidDB
}
// Close closes all prepared statements in the store
func (db *PreparedStmtDB) Close() {
db.Mux.Lock()
defer db.Mux.Unlock()
for _, key := range db.Stmts.Keys() {
db.Stmts.Delete(key)
}
}
// Reset is deprecated; use Close instead
func (db *PreparedStmtDB) Reset() {
db.Close()
}
func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (_ *stmt_store.Stmt, err error) {
db.Mux.RLock()
if db.Stmts != nil {
if stmt, ok := db.Stmts.Get(query); ok && (!stmt.Transaction || isTransaction) {
db.Mux.RUnlock()
return stmt, stmt.Error()
}
}
db.Mux.RUnlock()
// retry
db.Mux.Lock()
if db.Stmts != nil {
if stmt, ok := db.Stmts.Get(query); ok && (!stmt.Transaction || isTransaction) {
db.Mux.Unlock()
return stmt, stmt.Error()
}
}
return db.Stmts.New(ctx, query, isTransaction, conn, db.Mux)
}
func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (ConnPool, error) {
if beginner, ok := db.ConnPool.(TxBeginner); ok {
tx, err := beginner.BeginTx(ctx, opt)
return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err
}
beginner, ok := db.ConnPool.(ConnPoolBeginner)
if !ok {
return nil, ErrInvalidTransaction
}
connPool, err := beginner.BeginTx(ctx, opt)
if err != nil {
return nil, err
}
if tx, ok := connPool.(Tx); ok {
return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, nil
}
return nil, ErrInvalidTransaction
}
func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args ...interface{}) (result sql.Result, err error) {
stmt, err := db.prepare(ctx, db.ConnPool, false, query)
if err == nil {
result, err = stmt.ExecContext(ctx, args...)
if errors.Is(err, driver.ErrBadConn) {
db.Stmts.Delete(query)
}
}
return result, err
}
func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) {
stmt, err := db.prepare(ctx, db.ConnPool, false, query)
if err == nil {
rows, err = stmt.QueryContext(ctx, args...)
if errors.Is(err, driver.ErrBadConn) {
db.Stmts.Delete(query)
}
}
return rows, err
}
func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
stmt, err := db.prepare(ctx, db.ConnPool, false, query)
if err == nil {
return stmt.QueryRowContext(ctx, args...)
}
return &sql.Row{}
}
func (db *PreparedStmtDB) Ping() error {
conn, err := db.GetDBConn()
if err != nil {
return err
}
return conn.Ping()
}
type PreparedStmtTX struct {
Tx
PreparedStmtDB *PreparedStmtDB
}
func (db *PreparedStmtTX) GetDBConn() (*sql.DB, error) {
return db.PreparedStmtDB.GetDBConn()
}
func (tx *PreparedStmtTX) Commit() error {
if tx.Tx != nil && !reflect.ValueOf(tx.Tx).IsNil() {
return tx.Tx.Commit()
}
return ErrInvalidTransaction
}
func (tx *PreparedStmtTX) Rollback() error {
if tx.Tx != nil && !reflect.ValueOf(tx.Tx).IsNil() {
return tx.Tx.Rollback()
}
return ErrInvalidTransaction
}
func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ...interface{}) (result sql.Result, err error) {
stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
if err == nil {
result, err = tx.Tx.StmtContext(ctx, stmt.Stmt).ExecContext(ctx, args...)
if errors.Is(err, driver.ErrBadConn) {
tx.PreparedStmtDB.Stmts.Delete(query)
}
}
return result, err
}
func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) {
stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
if err == nil {
rows, err = tx.Tx.StmtContext(ctx, stmt.Stmt).QueryContext(ctx, args...)
if errors.Is(err, driver.ErrBadConn) {
tx.PreparedStmtDB.Stmts.Delete(query)
}
}
return rows, err
}
func (tx *PreparedStmtTX) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
if err == nil {
return tx.Tx.StmtContext(ctx, stmt.Stmt).QueryRowContext(ctx, args...)
}
return &sql.Row{}
}
func (tx *PreparedStmtTX) Ping() error {
conn, err := tx.GetDBConn()
if err != nil {
return err
}
return conn.Ping()
}

406
scan.go
View File

@ -2,368 +2,148 @@ package gorm
import (
"database/sql"
"database/sql/driver"
"reflect"
"strings"
"time"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils"
)
// prepareValues prepares the values slice
func prepareValues(values []interface{}, db *DB, columnTypes []*sql.ColumnType, columns []string) {
if db.Statement.Schema != nil {
for idx, name := range columns {
if field := db.Statement.Schema.LookUpField(name); field != nil {
values[idx] = reflect.New(reflect.PointerTo(field.FieldType)).Interface()
continue
}
values[idx] = new(interface{})
}
} else if len(columnTypes) > 0 {
for idx, columnType := range columnTypes {
if columnType.ScanType() != nil {
values[idx] = reflect.New(reflect.PointerTo(columnType.ScanType())).Interface()
} else {
values[idx] = new(interface{})
}
}
} else {
for idx := range columns {
values[idx] = new(interface{})
}
}
}
func scanIntoMap(mapValue map[string]interface{}, values []interface{}, columns []string) {
for idx, column := range columns {
if reflectValue := reflect.Indirect(reflect.Indirect(reflect.ValueOf(values[idx]))); reflectValue.IsValid() {
mapValue[column] = reflectValue.Interface()
if valuer, ok := mapValue[column].(driver.Valuer); ok {
mapValue[column], _ = valuer.Value()
} else if b, ok := mapValue[column].(sql.RawBytes); ok {
mapValue[column] = string(b)
}
} else {
mapValue[column] = nil
}
}
}
func (db *DB) scanIntoStruct(rows Rows, reflectValue reflect.Value, values []interface{}, fields []*schema.Field, joinFields [][]*schema.Field) {
for idx, field := range fields {
if field != nil {
values[idx] = field.NewValuePool.Get()
} else if len(fields) == 1 {
if reflectValue.CanAddr() {
values[idx] = reflectValue.Addr().Interface()
} else {
values[idx] = reflectValue.Interface()
}
}
}
db.RowsAffected++
db.AddError(rows.Scan(values...))
joinedNestedSchemaMap := make(map[string]interface{})
for idx, field := range fields {
if field == nil {
continue
}
if len(joinFields) == 0 || len(joinFields[idx]) == 0 {
db.AddError(field.Set(db.Statement.Context, reflectValue, values[idx]))
} else { // joinFields count is larger than 2 when using join
var isNilPtrValue bool
var relValue reflect.Value
// does not contain raw dbname
nestedJoinSchemas := joinFields[idx][:len(joinFields[idx])-1]
// current reflect value
currentReflectValue := reflectValue
fullRels := make([]string, 0, len(nestedJoinSchemas))
for _, joinSchema := range nestedJoinSchemas {
fullRels = append(fullRels, joinSchema.Name)
relValue = joinSchema.ReflectValueOf(db.Statement.Context, currentReflectValue)
if relValue.Kind() == reflect.Ptr {
fullRelsName := utils.JoinNestedRelationNames(fullRels)
// same nested structure
if _, ok := joinedNestedSchemaMap[fullRelsName]; !ok {
if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() {
isNilPtrValue = true
break
}
relValue.Set(reflect.New(relValue.Type().Elem()))
joinedNestedSchemaMap[fullRelsName] = nil
}
}
currentReflectValue = relValue
}
if !isNilPtrValue { // ignore if value is nil
f := joinFields[idx][len(joinFields[idx])-1]
db.AddError(f.Set(db.Statement.Context, relValue, values[idx]))
}
}
// release data to pool
field.NewValuePool.Put(values[idx])
}
}
// ScanMode scan data mode
type ScanMode uint8
// scan modes
const (
ScanInitialized ScanMode = 1 << 0 // 1
ScanUpdate ScanMode = 1 << 1 // 2
ScanOnConflictDoNothing ScanMode = 1 << 2 // 4
)
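The modes are bit flags, so callers combine them with `|` and `Scan` tests them with `&`. An illustrative fragment (the helper name is an assumption) written as if inside the gorm package:

```go
func exampleScanModes() {
    mode := ScanInitialized | ScanOnConflictDoNothing
    _ = mode&ScanInitialized != 0         // true
    _ = mode&ScanUpdate != 0              // false
    _ = mode&ScanOnConflictDoNothing != 0 // true
}
```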
// Scan scans rows into the db statement
func Scan(rows Rows, db *DB, mode ScanMode) {
var (
columns, _ = rows.Columns()
values = make([]interface{}, len(columns))
initialized = mode&ScanInitialized != 0
update = mode&ScanUpdate != 0
onConflictDonothing = mode&ScanOnConflictDoNothing != 0
)
if len(db.Statement.ColumnMapping) > 0 {
for i, column := range columns {
v, ok := db.Statement.ColumnMapping[column]
if ok {
columns[i] = v
}
}
}
db.RowsAffected = 0
func Scan(rows *sql.Rows, db *DB, initialized bool) {
columns, _ := rows.Columns()
values := make([]interface{}, len(columns))
switch dest := db.Statement.Dest.(type) {
case map[string]interface{}, *map[string]interface{}:
if initialized || rows.Next() {
columnTypes, _ := rows.ColumnTypes()
prepareValues(values, db, columnTypes, columns)
for idx := range columns {
values[idx] = new(interface{})
}
if initialized || rows.Next() {
db.RowsAffected++
db.AddError(rows.Scan(values...))
}
mapValue, ok := dest.(map[string]interface{})
if !ok {
if v, ok := dest.(*map[string]interface{}); ok {
if *v == nil {
*v = map[string]interface{}{}
}
mapValue = *v
}
mapValue, ok := dest.(map[string]interface{})
if ok {
if v, ok := dest.(*map[string]interface{}); ok {
mapValue = *v
}
scanIntoMap(mapValue, values, columns)
}
for idx, column := range columns {
mapValue[column] = *(values[idx].(*interface{}))
}
case *[]map[string]interface{}:
columnTypes, _ := rows.ColumnTypes()
for initialized || rows.Next() {
prepareValues(values, db, columnTypes, columns)
for idx := range columns {
values[idx] = new(interface{})
}
for initialized || rows.Next() {
initialized = false
db.RowsAffected++
db.AddError(rows.Scan(values...))
mapValue := map[string]interface{}{}
scanIntoMap(mapValue, values, columns)
*dest = append(*dest, mapValue)
v := map[string]interface{}{}
for idx, column := range columns {
v[column] = *(values[idx].(*interface{}))
}
*dest = append(*dest, v)
}
case *int, *int8, *int16, *int32, *int64,
*uint, *uint8, *uint16, *uint32, *uint64, *uintptr,
*float32, *float64,
*bool, *string, *time.Time,
*sql.NullInt32, *sql.NullInt64, *sql.NullFloat64,
*sql.NullBool, *sql.NullString, *sql.NullTime:
case *int, *int64, *uint, *uint64:
for initialized || rows.Next() {
initialized = false
db.RowsAffected++
db.AddError(rows.Scan(dest))
}
default:
var (
fields = make([]*schema.Field, len(columns))
joinFields [][]*schema.Field
sch = db.Statement.Schema
reflectValue = db.Statement.ReflectValue
)
if reflectValue.Kind() == reflect.Interface {
reflectValue = reflectValue.Elem()
}
reflectValueType := reflectValue.Type()
switch reflectValueType.Kind() {
case reflect.Array, reflect.Slice:
reflectValueType = reflectValueType.Elem()
}
isPtr := reflectValueType.Kind() == reflect.Ptr
if isPtr {
reflectValueType = reflectValueType.Elem()
}
if sch != nil {
if reflectValueType != sch.ModelType && reflectValueType.Kind() == reflect.Struct {
sch, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy)
}
if len(columns) == 1 {
// Is Pluck
if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); (reflectValueType != sch.ModelType && ok) || // is scanner
reflectValueType.Kind() != reflect.Struct || // is not struct
sch.ModelType.ConvertibleTo(schema.TimeReflectType) { // is time
sch = nil
}
}
// Not Pluck
if sch != nil {
matchedFieldCount := make(map[string]int, len(columns))
for idx, column := range columns {
if field := sch.LookUpField(column); field != nil && field.Readable {
fields[idx] = field
if count, ok := matchedFieldCount[column]; ok {
// handle duplicate fields
for _, selectField := range sch.Fields {
if selectField.DBName == column && selectField.Readable {
if count == 0 {
matchedFieldCount[column]++
fields[idx] = selectField
break
}
count--
}
}
} else {
matchedFieldCount[column] = 1
}
} else if names := utils.SplitNestedRelationName(column); len(names) > 1 { // has nested relation
aliasName := utils.JoinNestedRelationNames(names[0 : len(names)-1])
for _, join := range db.Statement.Joins {
if join.Alias == aliasName {
names = append(strings.Split(join.Name, "."), names[len(names)-1])
break
}
}
if rel, ok := sch.Relationships.Relations[names[0]]; ok {
subNameCount := len(names)
// nested relation fields
relFields := make([]*schema.Field, 0, subNameCount-1)
relFields = append(relFields, rel.Field)
for _, name := range names[1 : subNameCount-1] {
rel = rel.FieldSchema.Relationships.Relations[name]
relFields = append(relFields, rel.Field)
}
// latest name is raw dbname
dbName := names[subNameCount-1]
if field := rel.FieldSchema.LookUpField(dbName); field != nil && field.Readable {
fields[idx] = field
if len(joinFields) == 0 {
joinFields = make([][]*schema.Field, len(columns))
}
relFields = append(relFields, field)
joinFields[idx] = relFields
continue
}
}
var val interface{}
values[idx] = &val
} else {
var val interface{}
values[idx] = &val
}
}
}
}
switch reflectValue.Kind() {
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
var (
elem reflect.Value
isArrayKind = reflectValue.Kind() == reflect.Array
)
reflectValueType := db.Statement.ReflectValue.Type().Elem()
isPtr := reflectValueType.Kind() == reflect.Ptr
if isPtr {
reflectValueType = reflectValueType.Elem()
}
if !update || reflectValue.Len() == 0 {
update = false
if isArrayKind {
db.Statement.ReflectValue.Set(reflect.Zero(reflectValue.Type()))
} else {
// if the slice cap is externally initialized, the externally initialized slice is directly used here
if reflectValue.Cap() == 0 {
db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20))
} else {
reflectValue.SetLen(0)
db.Statement.ReflectValue.Set(reflectValue)
db.Statement.ReflectValue.Set(reflect.MakeSlice(db.Statement.ReflectValue.Type(), 0, 0))
fields := make([]*schema.Field, len(columns))
joinFields := make([][2]*schema.Field, len(columns))
for idx, column := range columns {
if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable {
fields[idx] = field
} else if names := strings.Split(column, "__"); len(names) > 1 {
if rel, ok := db.Statement.Schema.Relationships.Relations[names[0]]; ok {
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
joinFields[idx] = [2]*schema.Field{rel.Field, field}
continue
}
}
values[idx] = &sql.RawBytes{}
} else {
values[idx] = &sql.RawBytes{}
}
}
for initialized || rows.Next() {
BEGIN:
initialized = false
elem := reflect.New(reflectValueType).Elem()
if update {
if int(db.RowsAffected) >= reflectValue.Len() {
return
}
elem = reflectValue.Index(int(db.RowsAffected))
if onConflictDonothing {
for _, field := range fields {
if _, ok := field.ValueOf(db.Statement.Context, elem); !ok {
db.RowsAffected++
goto BEGIN
}
}
}
if reflectValueType.Kind() != reflect.Struct && len(fields) == 1 {
values[0] = elem.Addr().Interface()
} else {
elem = reflect.New(reflectValueType)
}
for idx, field := range fields {
if field != nil {
values[idx] = field.ReflectValueOf(elem).Addr().Interface()
} else if joinFields[idx][0] != nil {
relValue := joinFields[idx][0].ReflectValueOf(elem)
if relValue.Kind() == reflect.Ptr && relValue.IsNil() {
relValue.Set(reflect.New(relValue.Type().Elem()))
}
db.scanIntoStruct(rows, elem, values, fields, joinFields)
if !update {
if !isPtr {
elem = elem.Elem()
}
if isArrayKind {
if reflectValue.Len() >= int(db.RowsAffected) {
reflectValue.Index(int(db.RowsAffected - 1)).Set(elem)
values[idx] = joinFields[idx][1].ReflectValueOf(relValue).Addr().Interface()
}
} else {
reflectValue = reflect.Append(reflectValue, elem)
}
}
db.RowsAffected++
db.AddError(rows.Scan(values...))
if isPtr {
db.Statement.ReflectValue.Set(reflect.Append(db.Statement.ReflectValue, elem.Addr()))
} else {
db.Statement.ReflectValue.Set(reflect.Append(db.Statement.ReflectValue, elem))
}
}
case reflect.Struct:
for idx, column := range columns {
if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable {
values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface()
} else if names := strings.Split(column, "__"); len(names) > 1 {
if rel, ok := db.Statement.Schema.Relationships.Relations[names[0]]; ok {
relValue := rel.Field.ReflectValueOf(db.Statement.ReflectValue)
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
if relValue.Kind() == reflect.Ptr && relValue.IsNil() {
relValue.Set(reflect.New(relValue.Type().Elem()))
}
values[idx] = field.ReflectValueOf(relValue).Addr().Interface()
continue
}
}
values[idx] = &sql.RawBytes{}
} else {
values[idx] = &sql.RawBytes{}
}
}
if !update {
db.Statement.ReflectValue.Set(reflectValue)
}
case reflect.Struct, reflect.Ptr:
if initialized || rows.Next() {
if mode == ScanInitialized && reflectValue.Kind() == reflect.Struct {
db.Statement.ReflectValue.Set(reflect.Zero(reflectValue.Type()))
}
db.scanIntoStruct(rows, reflectValue, values, fields, joinFields)
db.RowsAffected++
db.AddError(rows.Scan(values...))
}
default:
db.AddError(rows.Scan(dest))
}
}
if err := rows.Err(); err != nil && err != db.Error {
db.AddError(err)
}
if db.RowsAffected == 0 && db.Statement.RaiseErrorOnNotFound && db.Error == nil {
if db.RowsAffected == 0 && db.Statement.RaiseErrorOnNotFound {
db.AddError(ErrRecordNotFound)
}
}

View File

@ -9,7 +9,8 @@ import (
"gorm.io/gorm/schema"
)
type UserWithCallback struct{}
type UserWithCallback struct {
}
func (UserWithCallback) BeforeSave(*gorm.DB) error {
return nil

32
schema/check.go Normal file
View File

@ -0,0 +1,32 @@
package schema
import (
"regexp"
"strings"
)
type Check struct {
Name string
Constraint string // length(phone) >= 10
*Field
}
// ParseCheckConstraints parses schema check constraints
func (schema *Schema) ParseCheckConstraints() map[string]Check {
var checks = map[string]Check{}
for _, field := range schema.FieldsByDBName {
if chk := field.TagSettings["CHECK"]; chk != "" {
names := strings.Split(chk, ",")
if len(names) > 1 && regexp.MustCompile("^[A-Za-z-_]+$").MatchString(names[0]) {
checks[names[0]] = Check{Name: names[0], Constraint: strings.Join(names[1:], ","), Field: field}
} else {
if names[0] == "" {
chk = strings.Join(names[1:], ",")
}
name := schema.namer.CheckerName(schema.Table, field.DBName)
checks[name] = Check{Name: name, Constraint: chk, Field: field}
}
}
}
return checks
}
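Both tag forms below are picked up by this parser; the struct is an illustrative fragment, not part of this diff. The first field names its constraint explicitly, the second lets the namer generate one (e.g. `chk_accounts_age` under the default naming strategy), as the test file that follows also shows:

```go
// Account is an illustrative model, not part of the original code.
type Account struct {
    Name string `gorm:"check:name_checker,name <> 'jinzhu'"` // named check constraint
    Age  int    `gorm:"check:age > 0"`                       // name generated by the namer
}
```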

View File

@ -6,7 +6,6 @@ import (
"testing"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils/tests"
)
type UserCheck struct {
@ -21,7 +20,7 @@ func TestParseCheck(t *testing.T) {
t.Fatalf("failed to parse user check, got error %v", err)
}
results := map[string]schema.CheckConstraint{
results := map[string]schema.Check{
"name_checker": {
Name: "name_checker",
Constraint: "name <> 'jinzhu'",
@ -54,31 +53,3 @@ func TestParseCheck(t *testing.T) {
}
}
}
func TestParseUniqueConstraints(t *testing.T) {
type UserUnique struct {
Name1 string `gorm:"unique"`
Name2 string `gorm:"uniqueIndex"`
}
user, err := schema.Parse(&UserUnique{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("failed to parse user unique, got error %v", err)
}
constraints := user.ParseUniqueConstraints()
results := map[string]schema.UniqueConstraint{
"uni_user_uniques_name1": {
Name: "uni_user_uniques_name1",
Field: &schema.Field{Name: "Name1", Unique: true},
},
}
for k, result := range results {
v, ok := constraints[k]
if !ok {
t.Errorf("Failed to found unique constraint %v from parsed constraints %+v", k, constraints)
}
tests.AssertObjEqual(t, result, v, "Name")
tests.AssertObjEqual(t, result.Field, v.Field, "Name", "Unique", "UniqueIndex")
}
}

View File

@ -1,66 +0,0 @@
package schema
import (
"regexp"
"strings"
"gorm.io/gorm/clause"
)
// regEnLetterAndMidline matches word characters (letters, digits, underscore) and hyphens
var regEnLetterAndMidline = regexp.MustCompile(`^[\w-]+$`)
type CheckConstraint struct {
Name string
Constraint string // length(phone) >= 10
*Field
}
func (chk *CheckConstraint) GetName() string { return chk.Name }
func (chk *CheckConstraint) Build() (sql string, vars []interface{}) {
return "CONSTRAINT ? CHECK (?)", []interface{}{clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}}
}
// ParseCheckConstraints parses schema check constraints
func (schema *Schema) ParseCheckConstraints() map[string]CheckConstraint {
checks := map[string]CheckConstraint{}
for _, field := range schema.FieldsByDBName {
if chk := field.TagSettings["CHECK"]; chk != "" {
names := strings.Split(chk, ",")
if len(names) > 1 && regEnLetterAndMidline.MatchString(names[0]) {
checks[names[0]] = CheckConstraint{Name: names[0], Constraint: strings.Join(names[1:], ","), Field: field}
} else {
if names[0] == "" {
chk = strings.Join(names[1:], ",")
}
name := schema.namer.CheckerName(schema.Table, field.DBName)
checks[name] = CheckConstraint{Name: name, Constraint: chk, Field: field}
}
}
}
return checks
}
type UniqueConstraint struct {
Name string
Field *Field
}
func (uni *UniqueConstraint) GetName() string { return uni.Name }
func (uni *UniqueConstraint) Build() (sql string, vars []interface{}) {
return "CONSTRAINT ? UNIQUE (?)", []interface{}{clause.Column{Name: uni.Name}, clause.Column{Name: uni.Field.DBName}}
}
// ParseUniqueConstraints parses schema unique constraints
func (schema *Schema) ParseUniqueConstraints() map[string]UniqueConstraint {
uniques := make(map[string]UniqueConstraint)
for _, field := range schema.Fields {
if field.Unique {
name := schema.namer.UniqueName(schema.Table, field.DBName)
uniques[name] = UniqueConstraint{Name: name, Field: field}
}
}
return uniques
}
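
Both constraint types satisfy the same build contract used while creating tables: a SQL template plus bind variables. A hedged sketch of what `Build` produces; the constraint names and field are illustrative:

```go
package example

import (
    "fmt"

    "gorm.io/gorm/schema"
)

func showConstraintSQL() {
    chk := &schema.CheckConstraint{Name: "chk_users_age", Constraint: "age > 0"}
    sql, vars := chk.Build() // "CONSTRAINT ? CHECK (?)" with the name and expression as vars
    fmt.Println(sql, vars)

    uni := &schema.UniqueConstraint{Name: "uni_users_email", Field: &schema.Field{DBName: "email"}}
    sql, vars = uni.Build() // "CONSTRAINT ? UNIQUE (?)" with the name and column as vars
    fmt.Println(sql, vars)
}
```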

File diff suppressed because it is too large.

View File

@ -1,7 +1,6 @@
package schema_test
import (
"context"
"database/sql"
"reflect"
"sync"
@ -20,7 +19,6 @@ func TestFieldValuerAndSetter(t *testing.T) {
Model: gorm.Model{
ID: 10,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
DeletedAt: gorm.DeletedAt{Time: time.Now(), Valid: true},
},
Name: "valuer_and_setter",
@ -36,7 +34,6 @@ func TestFieldValuerAndSetter(t *testing.T) {
"name": user.Name,
"id": user.ID,
"created_at": user.CreatedAt,
"updated_at": user.UpdatedAt,
"deleted_at": user.DeletedAt,
"age": user.Age,
"birthday": user.Birthday,
@ -44,36 +41,30 @@ func TestFieldValuerAndSetter(t *testing.T) {
}
checkField(t, userSchema, reflectValue, values)
var f *bool
// test setter
newValues := map[string]interface{}{
"name": "valuer_and_setter_2",
"id": 2,
"created_at": time.Now(),
"updated_at": nil,
"deleted_at": time.Now(),
"age": 20,
"birthday": time.Now(),
"active": f,
"active": false,
}
for k, v := range newValues {
if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil {
if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil {
t.Errorf("no error should happen when assign value to field %v, but got %v", k, err)
}
}
newValues["updated_at"] = time.Time{}
newValues["active"] = false
checkField(t, userSchema, reflectValue, newValues)
// test valuer and other type
age := myint(10)
var nilTime *time.Time
newValues2 := map[string]interface{}{
"name": sql.NullString{String: "valuer_and_setter_3", Valid: true},
"id": &sql.NullInt64{Int64: 3, Valid: true},
"created_at": tests.Now(),
"updated_at": nilTime,
"deleted_at": time.Now(),
"age": &age,
"birthday": mytime(time.Now()),
@ -81,11 +72,10 @@ func TestFieldValuerAndSetter(t *testing.T) {
}
for k, v := range newValues2 {
if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil {
if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil {
t.Errorf("no error should happen when assign value to field %v, but got %v", k, err)
}
}
newValues2["updated_at"] = time.Time{}
checkField(t, userSchema, reflectValue, newValues2)
}
@ -133,7 +123,7 @@ func TestPointerFieldValuerAndSetter(t *testing.T) {
}
for k, v := range newValues {
if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil {
if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil {
t.Errorf("no error should happen when assign value to field %v, but got %v", k, err)
}
}
@ -152,7 +142,7 @@ func TestPointerFieldValuerAndSetter(t *testing.T) {
}
for k, v := range newValues2 {
if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil {
if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil {
t.Errorf("no error should happen when assign value to field %v, but got %v", k, err)
}
}
@ -203,7 +193,7 @@ func TestAdvancedDataTypeValuerAndSetter(t *testing.T) {
}
for k, v := range newValues {
if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil {
if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil {
t.Errorf("no error should happen when assign value to field %v, but got %v", k, err)
}
}
@ -220,7 +210,7 @@ func TestAdvancedDataTypeValuerAndSetter(t *testing.T) {
}
for k, v := range newValues2 {
if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil {
if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil {
t.Errorf("no error should happen when assign value to field %v, but got %v", k, err)
}
}
@ -235,8 +225,6 @@ type UserWithPermissionControl struct {
Name4 string `gorm:"<-:create"`
Name5 string `gorm:"<-:update"`
Name6 string `gorm:"<-:create,update"`
Name7 string `gorm:"->:false;<-:create,update"`
Name8 string `gorm:"->;-:migration"`
}
func TestParseFieldWithPermission(t *testing.T) {
@ -245,90 +233,17 @@ func TestParseFieldWithPermission(t *testing.T) {
t.Fatalf("Failed to parse user with permission, got error %v", err)
}
fields := []*schema.Field{
{Name: "ID", DBName: "id", BindNames: []string{"ID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, Creatable: true, Updatable: true, Readable: true, HasDefaultValue: true, AutoIncrement: true},
{Name: "Name", DBName: "", BindNames: []string{"Name"}, DataType: "", Tag: `gorm:"-"`, Creatable: false, Updatable: false, Readable: false},
fields := []schema.Field{
{Name: "ID", DBName: "id", BindNames: []string{"ID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, Creatable: true, Updatable: true, Readable: true},
{Name: "Name", DBName: "name", BindNames: []string{"Name"}, DataType: schema.String, Tag: `gorm:"-"`, Creatable: false, Updatable: false, Readable: false},
{Name: "Name2", DBName: "name2", BindNames: []string{"Name2"}, DataType: schema.String, Tag: `gorm:"->"`, Creatable: false, Updatable: false, Readable: true},
{Name: "Name3", DBName: "name3", BindNames: []string{"Name3"}, DataType: schema.String, Tag: `gorm:"<-"`, Creatable: true, Updatable: true, Readable: true},
{Name: "Name4", DBName: "name4", BindNames: []string{"Name4"}, DataType: schema.String, Tag: `gorm:"<-:create"`, Creatable: true, Updatable: false, Readable: true},
{Name: "Name5", DBName: "name5", BindNames: []string{"Name5"}, DataType: schema.String, Tag: `gorm:"<-:update"`, Creatable: false, Updatable: true, Readable: true},
{Name: "Name6", DBName: "name6", BindNames: []string{"Name6"}, DataType: schema.String, Tag: `gorm:"<-:create,update"`, Creatable: true, Updatable: true, Readable: true},
{Name: "Name7", DBName: "name7", BindNames: []string{"Name7"}, DataType: schema.String, Tag: `gorm:"->:false;<-:create,update"`, Creatable: true, Updatable: true, Readable: false},
{Name: "Name8", DBName: "name8", BindNames: []string{"Name8"}, DataType: schema.String, Tag: `gorm:"->;-:migration"`, Creatable: false, Updatable: false, Readable: true, IgnoreMigration: true},
{Name: "Name3", DBName: "name3", BindNames: []string{"Name3"}, DataType: schema.String, Tag: `gorm:"<-"`, Creatable: true, Updatable: true, Readable: false},
{Name: "Name4", DBName: "name4", BindNames: []string{"Name4"}, DataType: schema.String, Tag: `gorm:"<-:create"`, Creatable: true, Updatable: false, Readable: false},
{Name: "Name5", DBName: "name5", BindNames: []string{"Name5"}, DataType: schema.String, Tag: `gorm:"<-:update"`, Creatable: false, Updatable: true, Readable: false},
{Name: "Name6", DBName: "name6", BindNames: []string{"Name6"}, DataType: schema.String, Tag: `gorm:"<-:create,update"`, Creatable: true, Updatable: true, Readable: false},
}
for _, f := range fields {
checkSchemaField(t, user, f, func(f *schema.Field) {})
}
}
type (
ID int64
INT int
INT8 int8
INT16 int16
INT32 int32
INT64 int64
UINT uint
UINT8 uint8
UINT16 uint16
UINT32 uint32
UINT64 uint64
FLOAT32 float32
FLOAT64 float64
BOOL bool
STRING string
TIME time.Time
BYTES []byte
TypeAlias struct {
ID
INT `gorm:"column:fint"`
INT8 `gorm:"column:fint8"`
INT16 `gorm:"column:fint16"`
INT32 `gorm:"column:fint32"`
INT64 `gorm:"column:fint64"`
UINT `gorm:"column:fuint"`
UINT8 `gorm:"column:fuint8"`
UINT16 `gorm:"column:fuint16"`
UINT32 `gorm:"column:fuint32"`
UINT64 `gorm:"column:fuint64"`
FLOAT32 `gorm:"column:ffloat32"`
FLOAT64 `gorm:"column:ffloat64"`
BOOL `gorm:"column:fbool"`
STRING `gorm:"column:fstring"`
TIME `gorm:"column:ftime"`
BYTES `gorm:"column:fbytes"`
}
)
func TestTypeAliasField(t *testing.T) {
alias, err := schema.Parse(&TypeAlias{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("Failed to parse TypeAlias with permission, got error %v", err)
}
fields := []*schema.Field{
{Name: "ID", DBName: "id", BindNames: []string{"ID"}, DataType: schema.Int, Creatable: true, Updatable: true, Readable: true, Size: 64, PrimaryKey: true, HasDefaultValue: true, AutoIncrement: true},
{Name: "INT", DBName: "fint", BindNames: []string{"INT"}, DataType: schema.Int, Creatable: true, Updatable: true, Readable: true, Size: 64, Tag: `gorm:"column:fint"`},
{Name: "INT8", DBName: "fint8", BindNames: []string{"INT8"}, DataType: schema.Int, Creatable: true, Updatable: true, Readable: true, Size: 8, Tag: `gorm:"column:fint8"`},
{Name: "INT16", DBName: "fint16", BindNames: []string{"INT16"}, DataType: schema.Int, Creatable: true, Updatable: true, Readable: true, Size: 16, Tag: `gorm:"column:fint16"`},
{Name: "INT32", DBName: "fint32", BindNames: []string{"INT32"}, DataType: schema.Int, Creatable: true, Updatable: true, Readable: true, Size: 32, Tag: `gorm:"column:fint32"`},
{Name: "INT64", DBName: "fint64", BindNames: []string{"INT64"}, DataType: schema.Int, Creatable: true, Updatable: true, Readable: true, Size: 64, Tag: `gorm:"column:fint64"`},
{Name: "UINT", DBName: "fuint", BindNames: []string{"UINT"}, DataType: schema.Uint, Creatable: true, Updatable: true, Readable: true, Size: 64, Tag: `gorm:"column:fuint"`},
{Name: "UINT8", DBName: "fuint8", BindNames: []string{"UINT8"}, DataType: schema.Uint, Creatable: true, Updatable: true, Readable: true, Size: 8, Tag: `gorm:"column:fuint8"`},
{Name: "UINT16", DBName: "fuint16", BindNames: []string{"UINT16"}, DataType: schema.Uint, Creatable: true, Updatable: true, Readable: true, Size: 16, Tag: `gorm:"column:fuint16"`},
{Name: "UINT32", DBName: "fuint32", BindNames: []string{"UINT32"}, DataType: schema.Uint, Creatable: true, Updatable: true, Readable: true, Size: 32, Tag: `gorm:"column:fuint32"`},
{Name: "UINT64", DBName: "fuint64", BindNames: []string{"UINT64"}, DataType: schema.Uint, Creatable: true, Updatable: true, Readable: true, Size: 64, Tag: `gorm:"column:fuint64"`},
{Name: "FLOAT32", DBName: "ffloat32", BindNames: []string{"FLOAT32"}, DataType: schema.Float, Creatable: true, Updatable: true, Readable: true, Size: 32, Tag: `gorm:"column:ffloat32"`},
{Name: "FLOAT64", DBName: "ffloat64", BindNames: []string{"FLOAT64"}, DataType: schema.Float, Creatable: true, Updatable: true, Readable: true, Size: 64, Tag: `gorm:"column:ffloat64"`},
{Name: "BOOL", DBName: "fbool", BindNames: []string{"BOOL"}, DataType: schema.Bool, Creatable: true, Updatable: true, Readable: true, Tag: `gorm:"column:fbool"`},
{Name: "STRING", DBName: "fstring", BindNames: []string{"STRING"}, DataType: schema.String, Creatable: true, Updatable: true, Readable: true, Tag: `gorm:"column:fstring"`},
{Name: "TIME", DBName: "ftime", BindNames: []string{"TIME"}, DataType: schema.Time, Creatable: true, Updatable: true, Readable: true, Tag: `gorm:"column:ftime"`},
{Name: "BYTES", DBName: "fbytes", BindNames: []string{"BYTES"}, DataType: schema.Bytes, Creatable: true, Updatable: true, Readable: true, Tag: `gorm:"column:fbytes"`},
}
for _, f := range fields {
checkSchemaField(t, alias, f, func(f *schema.Field) {})
checkSchemaField(t, user, &f, func(f *schema.Field) {})
}
}

View File

@ -1,8 +1,6 @@
package schema
import (
"fmt"
"sort"
"strconv"
"strings"
)
@ -13,8 +11,7 @@ type Index struct {
Type string // btree, hash, gist, spgist, gin, and brin
Where string
Comment string
Option string // WITH PARSER parser_name
Fields []IndexOption // Note: IndexOption's Field may be the same
Fields []IndexOption
}
type IndexOption struct {
@ -23,28 +20,16 @@ type IndexOption struct {
Sort string // DESC, ASC
Collate string
Length int
Priority int
}
// ParseIndexes parses schema indexes
func (schema *Schema) ParseIndexes() []*Index {
indexesByName := map[string]*Index{}
indexes := []*Index{}
func (schema *Schema) ParseIndexes() map[string]Index {
var indexes = map[string]Index{}
for _, field := range schema.Fields {
if field.TagSettings["INDEX"] != "" || field.TagSettings["UNIQUEINDEX"] != "" {
fieldIndexes, err := parseFieldIndexes(field)
if err != nil {
schema.err = err
break
}
for _, index := range fieldIndexes {
idx := indexesByName[index.Name]
if idx == nil {
idx = &Index{Name: index.Name}
indexesByName[index.Name] = idx
indexes = append(indexes, idx)
}
if field.TagSettings["INDEX"] != "" || field.TagSettings["UNIQUE_INDEX"] != "" {
for _, index := range parseFieldIndexes(field) {
idx := indexes[index.Name]
idx.Name = index.Name
if idx.Class == "" {
idx.Class = index.Class
@ -58,37 +43,25 @@ func (schema *Schema) ParseIndexes() []*Index {
if idx.Comment == "" {
idx.Comment = index.Comment
}
if idx.Option == "" {
idx.Option = index.Option
}
idx.Fields = append(idx.Fields, index.Fields...)
sort.Slice(idx.Fields, func(i, j int) bool {
return idx.Fields[i].Priority < idx.Fields[j].Priority
})
indexes[index.Name] = idx
}
}
}
for _, index := range indexes {
if index.Class == "UNIQUE" && len(index.Fields) == 1 {
index.Fields[0].Field.UniqueIndex = index.Name
}
}
return indexes
}
func (schema *Schema) LookIndex(name string) *Index {
if schema != nil {
indexes := schema.ParseIndexes()
for _, index := range indexes {
if index.Name == name {
return index
}
indexes := schema.ParseIndexes()
for _, index := range indexes {
if index.Name == name {
return &index
}
for _, field := range index.Fields {
if field.Name == name {
return index
}
for _, field := range index.Fields {
if field.Name == name {
return &index
}
}
}
@ -96,72 +69,53 @@ func (schema *Schema) LookIndex(name string) *Index {
return nil
}
func parseFieldIndexes(field *Field) (indexes []Index, err error) {
func parseFieldIndexes(field *Field) (indexes []Index) {
for _, value := range strings.Split(field.Tag.Get("gorm"), ";") {
if value != "" {
v := strings.Split(value, ":")
k := strings.TrimSpace(strings.ToUpper(v[0]))
if k == "INDEX" || k == "UNIQUEINDEX" {
if k == "INDEX" || k == "UNIQUE_INDEX" {
var (
name string
tag = strings.Join(v[1:], ":")
idx = strings.IndexByte(tag, ',')
tagSetting = strings.Join(strings.Split(tag, ",")[1:], ",")
settings = ParseTagSetting(tagSetting, ",")
length, _ = strconv.Atoi(settings["LENGTH"])
name string
tag = strings.Join(v[1:], ":")
idx = strings.Index(tag, ",")
settings = ParseTagSetting(tag, ",")
length, _ = strconv.Atoi(settings["LENGTH"])
)
if idx == -1 {
idx = len(tag)
}
name = tag[0:idx]
if idx != -1 {
name = tag[0:idx]
}
if name == "" {
subName := field.Name
const key = "COMPOSITE"
if composite, found := settings[key]; found {
if len(composite) == 0 || composite == key {
err = fmt.Errorf(
"the composite tag of %s.%s cannot be empty",
field.Schema.Name,
field.Name)
return
}
subName = composite
}
name = field.Schema.namer.IndexName(
field.Schema.Table, subName)
name = field.Schema.namer.IndexName(field.Schema.Table, field.Name)
}
if (k == "UNIQUEINDEX") || settings["UNIQUE"] != "" {
if (k == "UNIQUE_INDEX") || settings["UNIQUE"] != "" {
settings["CLASS"] = "UNIQUE"
}
priority, err := strconv.Atoi(settings["PRIORITY"])
if err != nil {
priority = 10
}
indexes = append(indexes, Index{
Name: name,
Class: settings["CLASS"],
Type: settings["TYPE"],
Where: settings["WHERE"],
Comment: settings["COMMENT"],
Option: settings["OPTION"],
Fields: []IndexOption{{
Field: field,
Expression: settings["EXPRESSION"],
Sort: settings["SORT"],
Collate: settings["COLLATE"],
Length: length,
Priority: priority,
}},
})
}
}
}
err = nil
return
}

View File

@ -1,58 +1,23 @@
package schema_test
import (
"reflect"
"sync"
"testing"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils/tests"
)
type UserIndex struct {
Name string `gorm:"index"`
Name2 string `gorm:"index:idx_name,unique"`
Name3 string `gorm:"index:,sort:desc,collate:utf8,type:btree,length:10,where:name3 != 'jinzhu'"`
Name4 string `gorm:"uniqueIndex"`
Name4 string `gorm:"unique_index"`
Name5 int64 `gorm:"index:,class:FULLTEXT,comment:hello \\, world,where:age > 10"`
Name6 int64 `gorm:"index:profile,comment:hello \\, world,where:age > 10"`
Age int64 `gorm:"index:profile,expression:ABS(age),option:WITH PARSER parser_name"`
OID int64 `gorm:"index:idx_id;index:idx_oid,unique"`
MemberNumber string `gorm:"index:idx_id,priority:1"`
Name7 string `gorm:"index:type"`
Name8 string `gorm:"index:,length:10;index:,collate:utf8"`
CompName1 string `gorm:"index:,unique,composite:idx_compname_1,option:NULLS NOT DISTINCT;not null"`
CompName2 string `gorm:"index:,composite:idx_compname_1"`
// Composite Index: Flattened structure.
Data0A string `gorm:"index:,composite:comp_id0"`
Data0B string `gorm:"index:,composite:comp_id0"`
// Composite Index: Nested structure.
Data1A string `gorm:"index:,composite:comp_id1"`
CompIdxLevel1C
// Composite Index: Unique and priority.
Data2A string `gorm:"index:,unique,composite:comp_id2,priority:2"`
CompIdxLevel2C
}
type CompIdxLevel1C struct {
CompIdxLevel1B
Data1C string `gorm:"index:,composite:comp_id1"`
}
type CompIdxLevel1B struct {
Data1B string `gorm:"index:,composite:comp_id1"`
}
type CompIdxLevel2C struct {
CompIdxLevel2B
Data2C string `gorm:"index:,unique,composite:comp_id2,priority:1"`
}
type CompIdxLevel2B struct {
Data2B string `gorm:"index:,unique,composite:comp_id2,priority:3"`
Age int64 `gorm:"index:profile,expression:ABS(age)"`
OID int64 `gorm:"index:idx_id"`
MemberNumber string `gorm:"index:idx_id"`
}
func TestParseIndex(t *testing.T) {
@ -61,215 +26,79 @@ func TestParseIndex(t *testing.T) {
t.Fatalf("failed to parse user index, got error %v", err)
}
results := []*schema.Index{
{
results := map[string]schema.Index{
"idx_user_indices_name": {
Name: "idx_user_indices_name",
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name"}}},
Fields: []schema.IndexOption{{}},
},
{
"idx_name": {
Name: "idx_name",
Class: "UNIQUE",
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name2", UniqueIndex: "idx_name"}}},
Fields: []schema.IndexOption{{}},
},
{
"idx_user_indices_name3": {
Name: "idx_user_indices_name3",
Type: "btree",
Where: "name3 != 'jinzhu'",
Fields: []schema.IndexOption{{
Field: &schema.Field{Name: "Name3"},
Sort: "desc",
Collate: "utf8",
Length: 10,
}},
},
{
"idx_user_indices_name4": {
Name: "idx_user_indices_name4",
Class: "UNIQUE",
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name4", UniqueIndex: "idx_user_indices_name4"}}},
Fields: []schema.IndexOption{{}},
},
{
"idx_user_indices_name5": {
Name: "idx_user_indices_name5",
Class: "FULLTEXT",
Comment: "hello , world",
Where: "age > 10",
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name5"}}},
Fields: []schema.IndexOption{{}},
},
{
"profile": {
Name: "profile",
Comment: "hello , world",
Where: "age > 10",
Option: "WITH PARSER parser_name",
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name6"}}, {
Field: &schema.Field{Name: "Age"},
Fields: []schema.IndexOption{{}, {
Expression: "ABS(age)",
}},
},
{
"idx_id": {
Name: "idx_id",
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "MemberNumber"}}, {Field: &schema.Field{Name: "OID", UniqueIndex: "idx_oid"}}},
},
{
Name: "idx_oid",
Class: "UNIQUE",
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "OID", UniqueIndex: "idx_oid"}}},
},
{
Name: "type",
Type: "",
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name7"}}},
},
{
Name: "idx_user_indices_name8",
Type: "",
Fields: []schema.IndexOption{
{Field: &schema.Field{Name: "Name8"}, Length: 10},
// Note: Duplicate Columns
{Field: &schema.Field{Name: "Name8"}, Collate: "utf8"},
},
},
{
Class: "UNIQUE",
Name: "idx_user_indices_idx_compname_1",
Option: "NULLS NOT DISTINCT",
Fields: []schema.IndexOption{
{Field: &schema.Field{Name: "CompName1", NotNull: true}},
{Field: &schema.Field{Name: "CompName2"}},
},
},
{
Name: "idx_user_indices_comp_id0",
Type: "",
Fields: []schema.IndexOption{{
Field: &schema.Field{Name: "Data0A"},
}, {
Field: &schema.Field{Name: "Data0B"},
}},
},
{
Name: "idx_user_indices_comp_id1",
Fields: []schema.IndexOption{{
Field: &schema.Field{Name: "Data1A"},
}, {
Field: &schema.Field{Name: "Data1B"},
}, {
Field: &schema.Field{Name: "Data1C"},
}},
},
{
Name: "idx_user_indices_comp_id2",
Class: "UNIQUE",
Fields: []schema.IndexOption{{
Field: &schema.Field{Name: "Data2C"},
}, {
Field: &schema.Field{Name: "Data2A"},
}, {
Field: &schema.Field{Name: "Data2B"},
}},
Fields: []schema.IndexOption{{}, {}},
},
}
CheckIndices(t, results, user.ParseIndexes())
}
indices := user.ParseIndexes()
func TestParseIndexWithUniqueIndexAndUnique(t *testing.T) {
type IndexTest struct {
FieldA string `gorm:"unique;index"` // unique and index
FieldB string `gorm:"unique"` // unique
for k, result := range results {
v, ok := indices[k]
if !ok {
t.Fatalf("Failed to found index %v from parsed indices %+v", k, indices)
}
FieldC string `gorm:"index:,unique"` // uniqueIndex
FieldD string `gorm:"uniqueIndex;index"` // uniqueIndex and index
FieldE1 string `gorm:"uniqueIndex:uniq_field_e1_e2"` // mul uniqueIndex
FieldE2 string `gorm:"uniqueIndex:uniq_field_e1_e2"`
FieldF1 string `gorm:"uniqueIndex:uniq_field_f1_f2;index"` // mul uniqueIndex and index
FieldF2 string `gorm:"uniqueIndex:uniq_field_f1_f2;"`
FieldG string `gorm:"unique;uniqueIndex"` // unique and uniqueIndex
FieldH1 string `gorm:"unique;uniqueIndex:uniq_field_h1_h2"` // unique and mul uniqueIndex
FieldH2 string `gorm:"uniqueIndex:uniq_field_h1_h2"` // unique and mul uniqueIndex
}
indexSchema, err := schema.Parse(&IndexTest{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("failed to parse user index, got error %v", err)
}
indices := indexSchema.ParseIndexes()
expectedIndices := []*schema.Index{
{
Name: "idx_index_tests_field_a",
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "FieldA", Unique: true}}},
},
{
Name: "idx_index_tests_field_c",
Class: "UNIQUE",
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "FieldC", UniqueIndex: "idx_index_tests_field_c"}}},
},
{
Name: "idx_index_tests_field_d",
Class: "UNIQUE",
Fields: []schema.IndexOption{
{Field: &schema.Field{Name: "FieldD"}},
// Note: Duplicate Columns
{Field: &schema.Field{Name: "FieldD"}},
},
},
{
Name: "uniq_field_e1_e2",
Class: "UNIQUE",
Fields: []schema.IndexOption{
{Field: &schema.Field{Name: "FieldE1"}},
{Field: &schema.Field{Name: "FieldE2"}},
},
},
{
Name: "uniq_field_f1_f2",
Class: "UNIQUE",
Fields: []schema.IndexOption{
{Field: &schema.Field{Name: "FieldF1"}},
{Field: &schema.Field{Name: "FieldF2"}},
},
},
{
Name: "idx_index_tests_field_f1",
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "FieldF1"}}},
},
{
Name: "idx_index_tests_field_g",
Class: "UNIQUE",
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "FieldG", Unique: true, UniqueIndex: "idx_index_tests_field_g"}}},
},
{
Name: "uniq_field_h1_h2",
Class: "UNIQUE",
Fields: []schema.IndexOption{
{Field: &schema.Field{Name: "FieldH1", Unique: true}},
{Field: &schema.Field{Name: "FieldH2"}},
},
},
}
CheckIndices(t, expectedIndices, indices)
}
func CheckIndices(t *testing.T, expected, actual []*schema.Index) {
if len(expected) != len(actual) {
t.Errorf("expected %d indices, but got %d", len(expected), len(actual))
return
}
for i, ei := range expected {
t.Run(ei.Name, func(t *testing.T) {
ai := actual[i]
tests.AssertObjEqual(t, ai, ei, "Name", "Class", "Type", "Where", "Comment", "Option")
if len(ei.Fields) != len(ai.Fields) {
t.Errorf("expected index %q field length is %d but actual %d", ei.Name, len(ei.Fields), len(ai.Fields))
return
for _, name := range []string{"Name", "Class", "Type", "Where", "Comment"} {
if reflect.ValueOf(result).FieldByName(name).Interface() != reflect.ValueOf(v).FieldByName(name).Interface() {
t.Errorf(
"index %v %v should equal, expects %v, got %v",
k, name, reflect.ValueOf(result).FieldByName(name).Interface(), reflect.ValueOf(v).FieldByName(name).Interface(),
)
}
for i, ef := range ei.Fields {
af := ai.Fields[i]
tests.AssertObjEqual(t, af, ef, "Name", "Unique", "UniqueIndex", "Expression", "Sort", "Collate", "Length", "NotNull")
}
for idx, ef := range result.Fields {
rf := v.Fields[idx]
for _, name := range []string{"Expression", "Sort", "Collate", "Length"} {
if reflect.ValueOf(ef).FieldByName(name).Interface() != reflect.ValueOf(rf).FieldByName(name).Interface() {
t.Errorf(
"index %v field #%v's %v should equal, expects %v, got %v", k, idx+1, name,
reflect.ValueOf(ef).FieldByName(name).Interface(), reflect.ValueOf(rf).FieldByName(name).Interface(),
)
}
}
})
}
}
}

View File

@ -1,42 +0,0 @@
package schema
import (
"gorm.io/gorm/clause"
)
// ConstraintInterface database constraint interface
type ConstraintInterface interface {
GetName() string
Build() (sql string, vars []interface{})
}
// GormDataTypeInterface gorm data type interface
type GormDataTypeInterface interface {
GormDataType() string
}
// FieldNewValuePool field new scan value pool
type FieldNewValuePool interface {
Get() interface{}
Put(interface{})
}
// CreateClausesInterface create clauses interface
type CreateClausesInterface interface {
CreateClauses(*Field) []clause.Interface
}
// QueryClausesInterface query clauses interface
type QueryClausesInterface interface {
QueryClauses(*Field) []clause.Interface
}
// UpdateClausesInterface update clauses interface
type UpdateClausesInterface interface {
UpdateClauses(*Field) []clause.Interface
}
// DeleteClausesInterface delete clauses interface
type DeleteClausesInterface interface {
DeleteClauses(*Field) []clause.Interface
}
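
A short fragment showing one of these hooks in use; the `JSONB` type and the `jsonb` column type are assumptions for illustration, not part of this diff:

```go
// JSONB implements GormDataTypeInterface so schema parsing records "jsonb"
// as the field's data type instead of inferring one from []byte.
type JSONB []byte

func (JSONB) GormDataType() string { return "jsonb" }
```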

View File

@ -26,11 +26,9 @@ type User struct {
Active *bool
}
type (
mytime time.Time
myint int
mybool = bool
)
type mytime time.Time
type myint int
type mybool = bool
type AdvancedDataTypeUser struct {
ID sql.NullInt64
@ -41,24 +39,3 @@ type AdvancedDataTypeUser struct {
Active mybool
Admin *mybool
}
type BaseModel struct {
ID uint
CreatedAt time.Time
CreatedBy *int
Created *VersionUser `gorm:"foreignKey:CreatedBy"`
UpdatedAt time.Time
DeletedAt gorm.DeletedAt `gorm:"index"`
}
type VersionModel struct {
BaseModel
Version int
}
type VersionUser struct {
VersionModel
Name string
Age uint
Birthday *time.Time
}

View File

@ -2,149 +2,92 @@ package schema
import (
"crypto/sha1"
"encoding/hex"
"regexp"
"fmt"
"strings"
"sync"
"unicode/utf8"
"github.com/jinzhu/inflection"
"golang.org/x/text/cases"
"golang.org/x/text/language"
)
// Namer namer interface
type Namer interface {
TableName(table string) string
SchemaName(table string) string
ColumnName(table, column string) string
JoinTableName(joinTable string) string
JoinTableName(table string) string
RelationshipFKName(Relationship) string
CheckerName(table, column string) string
IndexName(table, column string) string
UniqueName(table, column string) string
}
// Replacer replacer interface like strings.Replacer
type Replacer interface {
Replace(name string) string
}
var _ Namer = (*NamingStrategy)(nil)
// NamingStrategy is the default naming strategy for tables and columns
type NamingStrategy struct {
TablePrefix string
SingularTable bool
NameReplacer Replacer
NoLowerCase bool
IdentifierMaxLength int
TablePrefix string
SingularTable bool
}
// TableName converts a string to a table name
func (ns NamingStrategy) TableName(str string) string {
if ns.SingularTable {
return ns.TablePrefix + ns.toDBName(str)
return ns.TablePrefix + toDBName(str)
}
return ns.TablePrefix + inflection.Plural(ns.toDBName(str))
}
// SchemaName generates the schema name from the table name; it is not guaranteed to be the reverse of TableName
func (ns NamingStrategy) SchemaName(table string) string {
table = strings.TrimPrefix(table, ns.TablePrefix)
if ns.SingularTable {
return ns.toSchemaName(table)
}
return ns.toSchemaName(inflection.Singular(table))
return ns.TablePrefix + inflection.Plural(toDBName(str))
}
// ColumnName converts a string to a column name
func (ns NamingStrategy) ColumnName(table, column string) string {
return ns.toDBName(column)
return toDBName(column)
}
// JoinTableName converts a string to a join table name
func (ns NamingStrategy) JoinTableName(str string) string {
if !ns.NoLowerCase && strings.ToLower(str) == str {
return ns.TablePrefix + str
}
if ns.SingularTable {
return ns.TablePrefix + ns.toDBName(str)
}
return ns.TablePrefix + inflection.Plural(ns.toDBName(str))
return ns.TablePrefix + inflection.Plural(toDBName(str))
}
// RelationshipFKName generates the foreign key name for a relation
func (ns NamingStrategy) RelationshipFKName(rel Relationship) string {
return ns.formatName("fk", rel.Schema.Table, ns.toDBName(rel.Name))
return fmt.Sprintf("fk_%s_%s", rel.Schema.Table, toDBName(rel.Field.Name))
}
// CheckerName generates a check constraint name
func (ns NamingStrategy) CheckerName(table, column string) string {
return ns.formatName("chk", table, column)
return fmt.Sprintf("chk_%s_%s", table, column)
}
// IndexName generates an index name
func (ns NamingStrategy) IndexName(table, column string) string {
return ns.formatName("idx", table, ns.toDBName(column))
}
idxName := fmt.Sprintf("idx_%v_%v", table, toDBName(column))
// UniqueName generates a unique constraint name
func (ns NamingStrategy) UniqueName(table, column string) string {
return ns.formatName("uni", table, ns.toDBName(column))
}
func (ns NamingStrategy) formatName(prefix, table, name string) string {
formattedName := strings.ReplaceAll(strings.Join([]string{
prefix, table, name,
}, "_"), ".", "_")
if ns.IdentifierMaxLength == 0 {
ns.IdentifierMaxLength = 64
}
if utf8.RuneCountInString(formattedName) > ns.IdentifierMaxLength {
if utf8.RuneCountInString(idxName) > 64 {
h := sha1.New()
h.Write([]byte(formattedName))
h.Write([]byte(idxName))
bs := h.Sum(nil)
formattedName = formattedName[0:ns.IdentifierMaxLength-8] + hex.EncodeToString(bs)[:8]
idxName = fmt.Sprintf("idx%v%v", table, column)[0:56] + string(bs)[:8]
}
return formattedName
return idxName
}
var (
smap sync.Map
// https://github.com/golang/lint/blob/master/lint.go#L770
commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UID", "UI", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"}
commonInitialismsReplacer *strings.Replacer
)
func init() {
commonInitialismsForReplacer := make([]string, 0, len(commonInitialisms))
for _, initialism := range commonInitialisms {
commonInitialismsForReplacer = append(commonInitialismsForReplacer, initialism, cases.Title(language.Und).String(initialism))
}
commonInitialismsReplacer = strings.NewReplacer(commonInitialismsForReplacer...)
}
func (ns NamingStrategy) toDBName(name string) string {
if name == "" {
return ""
}
if ns.NameReplacer != nil {
tmpName := ns.NameReplacer.Replace(name)
if tmpName == "" {
return name
}
name = tmpName
}
if ns.NoLowerCase {
return name
} else if v, ok := smap.Load(name); ok {
return fmt.Sprint(v)
}
var (
@ -183,14 +126,6 @@ func (ns NamingStrategy) toDBName(name string) string {
} else {
buf.WriteByte(value[len(value)-1])
}
ret := buf.String()
return ret
}
func (ns NamingStrategy) toSchemaName(name string) string {
result := strings.ReplaceAll(cases.Title(language.Und, cases.NoLower).String(strings.ReplaceAll(name, "_", " ")), " ", "")
for _, initialism := range commonInitialisms {
result = regexp.MustCompile(cases.Title(language.Und, cases.NoLower).String(strings.ToLower(initialism))+"([A-Z]|$|_)").ReplaceAllString(result, initialism+"$1")
}
return result
}


@ -1,12 +1,11 @@
package schema
import (
"strings"
"testing"
)
func TestToDBName(t *testing.T) {
maps := map[string]string{
"": "",
"x": "x",
"X": "x",
@ -27,193 +26,9 @@ func TestToDBName(t *testing.T) {
"ThisIsActuallyATestSoWeMayBeAbleToUseThisCodeInGormPackageAlsoIdCanBeUsedAtTheEndAsID": "this_is_actually_a_test_so_we_may_be_able_to_use_this_code_in_gorm_package_also_id_can_be_used_at_the_end_as_id",
}
ns := NamingStrategy{}
for key, value := range maps {
if ns.toDBName(key) != value {
t.Errorf("%v toName should equal %v, but got %v", key, value, ns.toDBName(key))
}
}
maps = map[string]string{
"x": "X",
"user_restrictions": "UserRestriction",
"this_is_a_test": "ThisIsATest",
"abc_and_jkl": "AbcAndJkl",
"employee_id": "EmployeeID",
"field_x": "FieldX",
"http_and_smtp": "HTTPAndSMTP",
"http_server_handler_for_url_id": "HTTPServerHandlerForURLID",
"uuid": "UUID",
"http_url": "HTTPURL",
"sha256_hash": "Sha256Hash",
"this_is_actually_a_test_so_we_may_be_able_to_use_this_code_in_gorm_package_also_id_can_be_used_at_the_end_as_id": "ThisIsActuallyATestSoWeMayBeAbleToUseThisCodeInGormPackageAlsoIDCanBeUsedAtTheEndAsID",
}
for key, value := range maps {
if ns.SchemaName(key) != value {
t.Errorf("%v schema name should equal %v, but got %v", key, value, ns.SchemaName(key))
}
}
}
func TestNamingStrategy(t *testing.T) {
ns := NamingStrategy{
TablePrefix: "public.",
SingularTable: true,
NameReplacer: strings.NewReplacer("CID", "Cid"),
}
idxName := ns.IndexName("public.table", "name")
if idxName != "idx_public_table_name" {
t.Errorf("invalid index name generated, got %v", idxName)
}
chkName := ns.CheckerName("public.table", "name")
if chkName != "chk_public_table_name" {
t.Errorf("invalid checker name generated, got %v", chkName)
}
joinTable := ns.JoinTableName("user_languages")
if joinTable != "public.user_languages" {
t.Errorf("invalid join table generated, got %v", joinTable)
}
joinTable2 := ns.JoinTableName("UserLanguage")
if joinTable2 != "public.user_language" {
t.Errorf("invalid join table generated, got %v", joinTable2)
}
tableName := ns.TableName("Company")
if tableName != "public.company" {
t.Errorf("invalid table name generated, got %v", tableName)
}
columdName := ns.ColumnName("", "NameCID")
if columdName != "name_cid" {
t.Errorf("invalid column name generated, got %v", columdName)
}
}
type CustomReplacer struct {
f func(string) string
}
func (r CustomReplacer) Replace(name string) string {
return r.f(name)
}
func TestCustomReplacer(t *testing.T) {
ns := NamingStrategy{
TablePrefix: "public.",
SingularTable: true,
NameReplacer: CustomReplacer{
func(name string) string {
replaced := "REPLACED_" + strings.ToUpper(name)
return strings.NewReplacer("CID", "_Cid").Replace(replaced)
},
},
NoLowerCase: false,
}
idxName := ns.IndexName("public.table", "name")
if idxName != "idx_public_table_replaced_name" {
t.Errorf("invalid index name generated, got %v", idxName)
}
chkName := ns.CheckerName("public.table", "name")
if chkName != "chk_public_table_name" {
t.Errorf("invalid checker name generated, got %v", chkName)
}
joinTable := ns.JoinTableName("user_languages")
if joinTable != "public.user_languages" { // Seems like a bug in NamingStrategy to skip the Replacer when the name is lowercase here.
t.Errorf("invalid join table generated, got %v", joinTable)
}
joinTable2 := ns.JoinTableName("UserLanguage")
if joinTable2 != "public.replaced_userlanguage" {
t.Errorf("invalid join table generated, got %v", joinTable2)
}
tableName := ns.TableName("Company")
if tableName != "public.replaced_company" {
t.Errorf("invalid table name generated, got %v", tableName)
}
columdName := ns.ColumnName("", "NameCID")
if columdName != "replaced_name_cid" {
t.Errorf("invalid column name generated, got %v", columdName)
}
}
func TestCustomReplacerWithNoLowerCase(t *testing.T) {
ns := NamingStrategy{
TablePrefix: "public.",
SingularTable: true,
NameReplacer: CustomReplacer{
func(name string) string {
replaced := "REPLACED_" + strings.ToUpper(name)
return strings.NewReplacer("CID", "_Cid").Replace(replaced)
},
},
NoLowerCase: true,
}
idxName := ns.IndexName("public.table", "name")
if idxName != "idx_public_table_REPLACED_NAME" {
t.Errorf("invalid index name generated, got %v", idxName)
}
chkName := ns.CheckerName("public.table", "name")
if chkName != "chk_public_table_name" {
t.Errorf("invalid checker name generated, got %v", chkName)
}
joinTable := ns.JoinTableName("user_languages")
if joinTable != "public.REPLACED_USER_LANGUAGES" {
t.Errorf("invalid join table generated, got %v", joinTable)
}
joinTable2 := ns.JoinTableName("UserLanguage")
if joinTable2 != "public.REPLACED_USERLANGUAGE" {
t.Errorf("invalid join table generated, got %v", joinTable2)
}
tableName := ns.TableName("Company")
if tableName != "public.REPLACED_COMPANY" {
t.Errorf("invalid table name generated, got %v", tableName)
}
columdName := ns.ColumnName("", "NameCID")
if columdName != "REPLACED_NAME_Cid" {
t.Errorf("invalid column name generated, got %v", columdName)
}
}
func TestFormatNameWithStringLongerThan63Characters(t *testing.T) {
ns := NamingStrategy{IdentifierMaxLength: 63}
formattedName := ns.formatName("prefix", "table", "thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVeryLongString")
if formattedName != "prefix_table_thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVer180f2c67" {
t.Errorf("invalid formatted name generated, got %v", formattedName)
}
}
func TestFormatNameWithStringLongerThan64Characters(t *testing.T) {
ns := NamingStrategy{IdentifierMaxLength: 64}
formattedName := ns.formatName("prefix", "table", "thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVeryLongString")
if formattedName != "prefix_table_thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVery180f2c67" {
t.Errorf("invalid formatted name generated, got %v", formattedName)
}
}
func TestReplaceEmptyTableName(t *testing.T) {
ns := NamingStrategy{
SingularTable: true,
NameReplacer: strings.NewReplacer("Model", ""),
}
tableName := ns.TableName("Model")
if tableName != "Model" {
t.Errorf("invalid table name generated, got %v", tableName)
}
}
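As a usage note beyond this test file: the same strategy is normally wired in through gorm.Config when opening a connection. A minimal sketch, assuming the standard gorm.io/driver/sqlite driver; the crm_ prefix is purely illustrative:

```go
package main

import (
	"gorm.io/driver/sqlite"
	"gorm.io/gorm"
	"gorm.io/gorm/schema"
)

type Company struct {
	ID   uint
	Name string
}

func main() {
	// With SingularTable and a prefix, Company maps to "crm_company"
	// instead of the default pluralized "companies".
	db, err := gorm.Open(sqlite.Open("test.db"), &gorm.Config{
		NamingStrategy: schema.NamingStrategy{
			TablePrefix:   "crm_",
			SingularTable: true,
		},
	})
	if err != nil {
		panic(err)
	}
	_ = db.AutoMigrate(&Company{})
}
```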


@ -1,19 +0,0 @@
package schema
import (
"reflect"
"sync"
)
// sync pools
var (
normalPool sync.Map
poolInitializer = func(reflectType reflect.Type) FieldNewValuePool {
v, _ := normalPool.LoadOrStore(reflectType, &sync.Pool{
New: func() interface{} {
return reflect.New(reflectType).Interface()
},
})
return v.(FieldNewValuePool)
}
)
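A standalone sketch of the pattern used here, to make the intent concrete: one sync.Pool per reflect.Type, lazily registered in a sync.Map, so reflect-allocated values can be recycled. The User type and the reset step are illustrative, not taken from the diff:

```go
package main

import (
	"fmt"
	"reflect"
	"sync"
)

var pools sync.Map // reflect.Type -> *sync.Pool

// poolFor returns the per-type pool, creating it on first use.
func poolFor(t reflect.Type) *sync.Pool {
	v, _ := pools.LoadOrStore(t, &sync.Pool{
		New: func() interface{} { return reflect.New(t).Interface() },
	})
	return v.(*sync.Pool)
}

type User struct{ Name string }

func main() {
	p := poolFor(reflect.TypeOf(User{}))
	u := p.Get().(*User) // a fresh or recycled *User
	u.Name = "jinzhu"
	fmt.Println(u.Name)
	*u = User{} // reset before handing it back
	p.Put(u)
}
```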


@ -1,16 +1,12 @@
package schema
import (
"context"
"fmt"
"reflect"
"regexp"
"strings"
"sync"
"github.com/jinzhu/inflection"
"golang.org/x/text/cases"
"golang.org/x/text/language"
"gorm.io/gorm/clause"
)
@ -22,7 +18,6 @@ const (
HasMany RelationshipType = "has_many" // HasManyRel has many relationship
BelongsTo RelationshipType = "belongs_to" // BelongsToRel belongs to relationship
Many2Many RelationshipType = "many_to_many" // Many2ManyRel many to many relationship
has RelationshipType = "has"
)
type Relationships struct {
@ -31,10 +26,6 @@ type Relationships struct {
HasMany []*Relationship
Many2Many []*Relationship
Relations map[string]*Relationship
EmbeddedRelations map[string]*Relationships
Mux sync.RWMutex
}
type Relationship struct {
@ -62,7 +53,7 @@ type Reference struct {
OwnPrimaryKey bool
}
func (schema *Schema) parseRelation(field *Field) *Relationship {
var (
err error
fieldValue = reflect.New(field.IndirectFieldType).Interface()
@ -75,38 +66,25 @@ func (schema *Schema) parseRelation(field *Field) *Relationship {
}
)
cacheStore := schema.cacheStore
if relation.FieldSchema, err = getOrParse(fieldValue, cacheStore, schema.namer); err != nil {
schema.err = fmt.Errorf("failed to parse field: %s, error: %w", field.Name, err)
return nil
}
if hasPolymorphicRelation(field.TagSettings) {
schema.buildPolymorphicRelation(relation, field)
} else if many2many := field.TagSettings["MANY2MANY"]; many2many != "" {
schema.buildMany2ManyRelation(relation, field, many2many)
} else if belongsTo := field.TagSettings["BELONGSTO"]; belongsTo != "" {
schema.guessRelation(relation, field, guessBelongs)
} else {
switch field.IndirectFieldType.Kind() {
case reflect.Struct:
schema.guessRelation(relation, field, guessGuess)
case reflect.Slice:
schema.guessRelation(relation, field, guessHas)
default:
schema.err = fmt.Errorf("unsupported data type %v for %v on field %s", relation.FieldSchema, schema, field.Name)
}
}
if relation.Type == has {
if relation.FieldSchema != relation.Schema && relation.Polymorphic == nil && field.OwnerSchema == nil {
relation.FieldSchema.Relationships.Mux.Lock()
relation.FieldSchema.Relationships.Relations["_"+relation.Schema.Name+"_"+relation.Name] = relation
relation.FieldSchema.Relationships.Mux.Unlock()
}
switch field.IndirectFieldType.Kind() {
case reflect.Struct:
relation.Type = HasOne
@ -116,7 +94,7 @@ func (schema *Schema) parseRelation(field *Field) *Relationship {
}
if schema.err == nil {
schema.setRelation(relation)
switch relation.Type {
case HasOne:
schema.Relationships.HasOne = append(schema.Relationships.HasOne, relation)
@ -128,104 +106,36 @@ func (schema *Schema) parseRelation(field *Field) *Relationship {
schema.Relationships.Many2Many = append(schema.Relationships.Many2Many, relation)
}
}
return relation
}
// hasPolymorphicRelation check if has polymorphic relation
// 1. `POLYMORPHIC` tag
// 2. `POLYMORPHICTYPE` and `POLYMORPHICID` tag
func hasPolymorphicRelation(tagSettings map[string]string) bool {
if _, ok := tagSettings["POLYMORPHIC"]; ok {
return true
}
_, hasType := tagSettings["POLYMORPHICTYPE"]
_, hasId := tagSettings["POLYMORPHICID"]
return hasType && hasId
}
func (schema *Schema) setRelation(relation *Relationship) {
// set non-embedded relation
if rel := schema.Relationships.Relations[relation.Name]; rel != nil {
if len(rel.Field.BindNames) > 1 {
schema.Relationships.Relations[relation.Name] = relation
}
} else {
schema.Relationships.Relations[relation.Name] = relation
}
// set embedded relation
if len(relation.Field.EmbeddedBindNames) <= 1 {
return
}
relationships := &schema.Relationships
for i, name := range relation.Field.EmbeddedBindNames {
if i < len(relation.Field.EmbeddedBindNames)-1 {
if relationships.EmbeddedRelations == nil {
relationships.EmbeddedRelations = map[string]*Relationships{}
}
if r := relationships.EmbeddedRelations[name]; r == nil {
relationships.EmbeddedRelations[name] = &Relationships{}
}
relationships = relationships.EmbeddedRelations[name]
} else {
if relationships.Relations == nil {
relationships.Relations = map[string]*Relationship{}
}
relationships.Relations[relation.Name] = relation
}
}
}
// User has many Toys, its `Polymorphic` is `Owner`, Pet has one Toy, its `Polymorphic` is `Owner`
//
// type User struct {
// Toys []Toy `gorm:"polymorphic:Owner;"`
// }
// type Pet struct {
// Toy Toy `gorm:"polymorphic:Owner;"`
// }
// type Toy struct {
// OwnerID int
// OwnerType string
// }
func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Field) {
polymorphic := field.TagSettings["POLYMORPHIC"]
relation.Polymorphic = &Polymorphic{
Value: schema.Table,
}
var (
typeName = polymorphic + "Type"
typeId = polymorphic + "ID"
)
if value, ok := field.TagSettings["POLYMORPHICTYPE"]; ok {
typeName = strings.TrimSpace(value)
}
if value, ok := field.TagSettings["POLYMORPHICID"]; ok {
typeId = strings.TrimSpace(value)
}
relation.Polymorphic.PolymorphicType = relation.FieldSchema.FieldsByName[typeName]
relation.Polymorphic.PolymorphicID = relation.FieldSchema.FieldsByName[typeId]
if value, ok := field.TagSettings["POLYMORPHICVALUE"]; ok {
relation.Polymorphic.Value = strings.TrimSpace(value)
}
if relation.Polymorphic.PolymorphicType == nil {
schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s", relation.FieldSchema, schema, field.Name, polymorphic+"Type")
}
if relation.Polymorphic.PolymorphicID == nil {
schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s", relation.FieldSchema, schema, field.Name, polymorphic+"ID")
}
if schema.err == nil {
@ -237,25 +147,12 @@ func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Fi
primaryKeyField := schema.PrioritizedPrimaryField
if len(relation.foreignKeys) > 0 {
if primaryKeyField = schema.LookUpField(relation.foreignKeys[0]); primaryKeyField == nil || len(relation.foreignKeys) > 1 {
schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %s", relation.foreignKeys, schema, field.Name)
}
}
if primaryKeyField == nil {
schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing primaryKey field", relation.FieldSchema, schema, field.Name)
return
}
// use same data type for foreign keys
if copyableDataType(primaryKeyField.DataType) {
relation.Polymorphic.PolymorphicID.DataType = primaryKeyField.DataType
}
relation.Polymorphic.PolymorphicID.GORMDataType = primaryKeyField.GORMDataType
if relation.Polymorphic.PolymorphicID.Size == 0 {
relation.Polymorphic.PolymorphicID.Size = primaryKeyField.Size
}
relation.References = append(relation.References, &Reference{
PrimaryKey: primaryKeyField,
@ -264,7 +161,7 @@ func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Fi
})
}
relation.Type = has
}
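Spelling out the tag variants this function handles (the Cat/Tool/RefID/RefType names below are illustrative; the behaviour follows the POLYMORPHIC, POLYMORPHICTYPE, POLYMORPHICID and POLYMORPHICVALUE settings read above):

```go
// Default convention: OwnerID/OwnerType columns, with OwnerType storing the
// owner's table name ("users", "pets", ...).
type Toy struct {
	ID        int
	Name      string
	OwnerID   int
	OwnerType string
}

type User struct {
	ID   int
	Toys []Toy `gorm:"polymorphic:Owner;"`
}

// Custom columns and a custom stored value via the extra tag settings.
type Tool struct {
	ID      int
	RefID   int
	RefType string
}

type Cat struct {
	ID    int
	Tools []Tool `gorm:"polymorphic:Owner;polymorphicType:RefType;polymorphicId:RefID;polymorphicValue:cats"`
}
```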
func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Field, many2many string) {
@ -274,8 +171,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
err error
joinTableFields []reflect.StructField
fieldsMap = map[string]*Field{}
ownFieldsMap = map[string]*Field{} // fix self join many2many
referFieldsMap = map[string]*Field{}
joinForeignKeys = toColumns(field.TagSettings["JOINFOREIGNKEY"])
joinReferences = toColumns(field.TagSettings["JOINREFERENCES"])
)
@ -289,7 +185,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
if field := schema.LookUpField(foreignKey); field != nil {
ownForeignFields = append(ownForeignFields, field)
} else {
schema.err = fmt.Errorf("invalid foreign key: %s", foreignKey)
return
}
}
@ -301,31 +197,33 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
if field := relation.FieldSchema.LookUpField(foreignKey); field != nil {
refForeignFields = append(refForeignFields, field)
} else {
schema.err = fmt.Errorf("invalid foreign key: %s", foreignKey)
return
}
}
}
for idx, ownField := range ownForeignFields {
joinFieldName := cases.Title(language.Und, cases.NoLower).String(schema.Name) + ownField.Name
if len(joinForeignKeys) > idx {
joinFieldName = cases.Title(language.Und, cases.NoLower).String(joinForeignKeys[idx])
}
ownFieldsMap[joinFieldName] = ownField
fieldsMap[joinFieldName] = ownField
joinTableFields = append(joinTableFields, reflect.StructField{
Name: joinFieldName,
PkgPath: ownField.StructField.PkgPath,
Type: ownField.StructField.Type,
Tag: removeSettingFromTag(appendSettingFromTag(ownField.StructField.Tag, "primaryKey"), "column", "autoincrement", "index", "unique", "uniqueindex"),
})
}
for idx, relField := range refForeignFields {
joinFieldName := cases.Title(language.Und, cases.NoLower).String(relation.FieldSchema.Name) + relField.Name
if _, ok := ownFieldsMap[joinFieldName]; ok {
if field.Name != relation.FieldSchema.Name {
@ -335,254 +233,100 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
}
}
if len(joinReferences) > idx {
joinFieldName = cases.Title(language.Und, cases.NoLower).String(joinReferences[idx])
}
referFieldsMap[joinFieldName] = relField
if _, ok := fieldsMap[joinFieldName]; !ok {
fieldsMap[joinFieldName] = relField
joinTableFields = append(joinTableFields, reflect.StructField{
Name: joinFieldName,
PkgPath: relField.StructField.PkgPath,
Type: relField.StructField.Type,
Tag: removeSettingFromTag(appendSettingFromTag(relField.StructField.Tag, "primaryKey"),
"column", "autoincrement", "index", "unique", "uniqueindex"),
})
}
}
joinTableFields = append(joinTableFields, reflect.StructField{
Name: cases.Title(language.Und, cases.NoLower).String(schema.Name) + field.Name,
Type: schema.ModelType,
Tag: `gorm:"-"`,
})
if relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, schema.namer); err != nil {
schema.err = err
}
relation.JoinTable.Name = many2many
relation.JoinTable.Table = schema.namer.JoinTableName(many2many)
relation.JoinTable.PrimaryFields = make([]*Field, 0, len(relation.JoinTable.Fields))
relName := relation.Schema.Name
relRefName := relation.FieldSchema.Name
if relName == relRefName {
relRefName = relation.Field.Name
}
if _, ok := relation.JoinTable.Relationships.Relations[relName]; !ok {
relation.JoinTable.Relationships.Relations[relName] = &Relationship{
Name: relName,
Type: BelongsTo,
Schema: relation.JoinTable,
FieldSchema: relation.Schema,
}
} else {
relation.JoinTable.Relationships.Relations[relName].References = []*Reference{}
}
if _, ok := relation.JoinTable.Relationships.Relations[relRefName]; !ok {
relation.JoinTable.Relationships.Relations[relRefName] = &Relationship{
Name: relRefName,
Type: BelongsTo,
Schema: relation.JoinTable,
FieldSchema: relation.FieldSchema,
}
} else {
relation.JoinTable.Relationships.Relations[relRefName].References = []*Reference{}
}
// build references
for _, f := range relation.JoinTable.Fields {
if f.Creatable || f.Readable || f.Updatable {
// use same data type for foreign keys
if copyableDataType(fieldsMap[f.Name].DataType) {
f.DataType = fieldsMap[f.Name].DataType
}
f.GORMDataType = fieldsMap[f.Name].GORMDataType
if f.Size == 0 {
f.Size = fieldsMap[f.Name].Size
}
relation.JoinTable.PrimaryFields = append(relation.JoinTable.PrimaryFields, f)
if of, ok := ownFieldsMap[f.Name]; ok {
joinRel := relation.JoinTable.Relationships.Relations[relName]
joinRel.Field = relation.Field
joinRel.References = append(joinRel.References, &Reference{
PrimaryKey: of,
ForeignKey: f,
})
relation.References = append(relation.References, &Reference{
PrimaryKey: of,
ForeignKey: f,
OwnPrimaryKey: true,
})
}
if rf, ok := referFieldsMap[f.Name]; ok {
joinRefRel := relation.JoinTable.Relationships.Relations[relRefName]
if joinRefRel.Field == nil {
joinRefRel.Field = relation.Field
}
joinRefRel.References = append(joinRefRel.References, &Reference{
PrimaryKey: rf,
ForeignKey: f,
})
relation.References = append(relation.References, &Reference{
PrimaryKey: rf,
ForeignKey: f,
})
}
}
}
}
return
}
type guessLevel int
const (
guessGuess guessLevel = iota
guessBelongs
guessEmbeddedBelongs
guessHas
guessEmbeddedHas
)
func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl guessLevel) {
var (
primaryFields, foreignFields []*Field
primarySchema, foreignSchema = schema, relation.FieldSchema
gl = cgl
)
if gl == guessGuess {
if field.Schema == relation.FieldSchema {
gl = guessBelongs
} else {
gl = guessHas
}
}
reguessOrErr := func() {
switch cgl {
case guessGuess:
schema.guessRelation(relation, field, guessBelongs)
case guessBelongs:
schema.guessRelation(relation, field, guessEmbeddedBelongs)
case guessEmbeddedBelongs:
schema.guessRelation(relation, field, guessHas)
case guessHas:
schema.guessRelation(relation, field, guessEmbeddedHas)
// case guessEmbeddedHas:
default:
schema.err = fmt.Errorf("invalid field found for struct %v's field %s: define a valid foreign key for relations or implement the Valuer/Scanner interface",
schema, field.Name)
}
}
switch gl {
case guessBelongs:
primarySchema, foreignSchema = relation.FieldSchema, schema
case guessEmbeddedBelongs:
if field.OwnerSchema == nil {
reguessOrErr()
return
}
primarySchema, foreignSchema = relation.FieldSchema, field.OwnerSchema
case guessHas:
case guessEmbeddedHas:
if field.OwnerSchema == nil {
reguessOrErr()
return
}
primarySchema, foreignSchema = field.OwnerSchema, relation.FieldSchema
}
if len(relation.foreignKeys) > 0 {
for _, foreignKey := range relation.foreignKeys {
f := foreignSchema.LookUpField(foreignKey)
if f == nil {
reguessOrErr()
return
}
foreignFields = append(foreignFields, f)
}
} else {
primarySchemaName := primarySchema.Name
if primarySchemaName == "" {
primarySchemaName = relation.FieldSchema.Name
}
if len(relation.primaryKeys) > 0 {
for _, primaryKey := range relation.primaryKeys {
if f := primarySchema.LookUpField(primaryKey); f != nil {
primaryFields = append(primaryFields, f)
}
}
} else {
primaryFields = primarySchema.PrimaryFields
}
primaryFieldLoop:
for _, primaryField := range primaryFields {
lookUpName := primarySchemaName + primaryField.Name
if gl == guessBelongs {
lookUpName = field.Name + primaryField.Name
}
lookUpNames := []string{lookUpName}
if len(primaryFields) == 1 {
lookUpNames = append(lookUpNames, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID",
strings.TrimSuffix(lookUpName, primaryField.Name)+"Id", schema.namer.ColumnName(foreignSchema.Table,
strings.TrimSuffix(lookUpName, primaryField.Name)+"ID"))
}
for _, name := range lookUpNames {
if f := foreignSchema.LookUpFieldByBindName(field.BindNames, name); f != nil {
foreignFields = append(foreignFields, f)
primaryFields = append(primaryFields, primaryField)
continue primaryFieldLoop
}
}
for _, name := range lookUpNames {
if f := foreignSchema.LookUpField(name); f != nil {
foreignFields = append(foreignFields, f)
primaryFields = append(primaryFields, primaryField)
continue primaryFieldLoop
}
}
}
}
switch {
case len(foreignFields) == 0:
reguessOrErr()
return
case len(relation.primaryKeys) > 0:
for idx, primaryKey := range relation.primaryKeys {
if f := primarySchema.LookUpField(primaryKey); f != nil {
if len(primaryFields) < idx+1 {
primaryFields = append(primaryFields, f)
} else if f != primaryFields[idx] {
reguessOrErr()
return
}
} else {
reguessOrErr()
return
}
}
case len(primaryFields) == 0:
if len(foreignFields) == 1 && primarySchema.PrioritizedPrimaryField != nil {
primaryFields = append(primaryFields, primarySchema.PrioritizedPrimaryField)
} else if len(primarySchema.PrimaryFields) == len(foreignFields) {
primaryFields = append(primaryFields, primarySchema.PrimaryFields...)
} else {
reguessOrErr()
return
}
}
@ -590,29 +334,22 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu
// build references
for idx, foreignField := range foreignFields {
// use same data type for foreign keys
if copyableDataType(primaryFields[idx].DataType) {
foreignField.DataType = primaryFields[idx].DataType
}
foreignField.GORMDataType = primaryFields[idx].GORMDataType
if foreignField.Size == 0 {
foreignField.Size = primaryFields[idx].Size
}
relation.References = append(relation.References, &Reference{
PrimaryKey: primaryFields[idx],
ForeignKey: foreignField,
OwnPrimaryKey: (schema == primarySchema && gl == guessHas) || (field.OwnerSchema == primarySchema && gl == guessEmbeddedHas),
})
}
if gl == guessHas || gl == guessEmbeddedHas {
relation.Type = has
} else {
relation.Type = BelongsTo
}
}
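When the conventions guessed above do not match the schema, the tags exercised by the tests further down can pin the relation explicitly instead of relying on reguessing; a minimal belongs-to sketch (field names are illustrative):

```go
package models

import "gorm.io/gorm"

type Profile struct {
	gorm.Model
	Refer string
	Name  string
}

type User struct {
	gorm.Model
	// foreignKey names the column on User, references the column on Profile,
	// so no guessing pass is needed.
	Profile      Profile `gorm:"foreignKey:ProfileRefer;references:Refer"`
	ProfileRefer string
}
```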
// Constraint is ForeignKey Constraint
type Constraint struct {
Name string
Field *Field
@ -624,67 +361,19 @@ type Constraint struct {
OnUpdate string
}
func (constraint *Constraint) GetName() string { return constraint.Name }
func (constraint *Constraint) Build() (sql string, vars []interface{}) {
sql = "CONSTRAINT ? FOREIGN KEY ? REFERENCES ??"
if constraint.OnDelete != "" {
sql += " ON DELETE " + constraint.OnDelete
}
if constraint.OnUpdate != "" {
sql += " ON UPDATE " + constraint.OnUpdate
}
foreignKeys := make([]interface{}, 0, len(constraint.ForeignKeys))
for _, field := range constraint.ForeignKeys {
foreignKeys = append(foreignKeys, clause.Column{Name: field.DBName})
}
references := make([]interface{}, 0, len(constraint.References))
for _, field := range constraint.References {
references = append(references, clause.Column{Name: field.DBName})
}
vars = append(vars, clause.Table{Name: constraint.Name}, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references)
return
}
func (rel *Relationship) ParseConstraint() *Constraint {
str := rel.Field.TagSettings["CONSTRAINT"]
if str == "-" {
return nil
}
if rel.Type == BelongsTo {
for _, r := range rel.FieldSchema.Relationships.Relations {
if r != rel && r.FieldSchema == rel.Schema && len(rel.References) == len(r.References) {
matched := true
for idx, ref := range r.References {
if !(rel.References[idx].PrimaryKey == ref.PrimaryKey && rel.References[idx].ForeignKey == ref.ForeignKey &&
rel.References[idx].PrimaryValue == ref.PrimaryValue) {
matched = false
break
}
}
if matched {
return nil
}
}
}
}
var (
name string
idx = strings.IndexByte(str, ',')
settings = ParseTagSetting(str, ",")
)
// This check runs inside loops elsewhere; the regular expression is compiled once
// at package level (regEnLetterAndMidline) to avoid recompiling it on every call.
if idx != -1 && regEnLetterAndMidline.MatchString(str[0:idx]) {
name = str[0:idx]
} else {
name = rel.Schema.namer.RelationshipFKName(*rel)
@ -695,33 +384,29 @@ func (rel *Relationship) ParseConstraint() *Constraint {
Field: rel.Field,
OnUpdate: settings["ONUPDATE"],
OnDelete: settings["ONDELETE"],
}
for _, ref := range rel.References {
if ref.PrimaryKey != nil && (rel.JoinTable == nil || ref.OwnPrimaryKey) {
constraint.ForeignKeys = append(constraint.ForeignKeys, ref.ForeignKey)
constraint.References = append(constraint.References, ref.PrimaryKey)
if ref.OwnPrimaryKey {
constraint.Schema = ref.ForeignKey.Schema
constraint.ReferenceSchema = rel.Schema
} else {
constraint.Schema = rel.Schema
constraint.ReferenceSchema = ref.PrimaryKey.Schema
}
}
}
if rel.JoinTable != nil || constraint.ReferenceSchema == nil {
return nil
}
return &constraint
}
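In model form, the CONSTRAINT tag parsed here looks like the sketch below (Company/Employee are illustrative); OnUpdate/OnDelete feed the ON UPDATE / ON DELETE clauses emitted by Constraint.Build, and `constraint:-` skips the foreign key entirely:

```go
package models

type Company struct {
	ID   int
	Name string
}

type Employee struct {
	ID        int
	CompanyID int
	// Generates a foreign key named by RelationshipFKName (fk_employees_company
	// under the default NamingStrategy) with the given referential actions.
	Company Company `gorm:"constraint:OnUpdate:CASCADE,OnDelete:SET NULL;"`
}
```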
func (rel *Relationship) ToQueryConditions(ctx context.Context, reflectValue reflect.Value) (conds []clause.Expression) {
table := rel.FieldSchema.Table
foreignFields := []*Field{}
relForeignKeys := []string{}
if rel.JoinTable != nil {
table = rel.JoinTable.Table
for _, ref := range rel.References {
if ref.OwnPrimaryKey {
foreignFields = append(foreignFields, ref.PrimaryKey)
@ -755,19 +440,9 @@ func (rel *Relationship) ToQueryConditions(ctx context.Context, reflectValue ref
}
}
_, foreignValues := GetIdentityFieldValuesMap(ctx, reflectValue, foreignFields)
column, values := ToQueryValues(table, relForeignKeys, foreignValues)
conds = append(conds, clause.IN{Column: column, Values: values})
return
}
func copyableDataType(str DataType) bool {
lowerStr := strings.ToLower(string(str))
for _, s := range []string{"auto_increment", "primary key"} {
if strings.Contains(lowerStr, s) {
return false
}
}
return true
}


@ -10,7 +10,7 @@ import (
func checkStructRelation(t *testing.T, data interface{}, relations ...Relation) {
if s, err := schema.Parse(data, &sync.Map{}, schema.NamingStrategy{}); err != nil {
t.Errorf("Failed to parse schema, got error %v", err)
} else {
for _, rel := range relations {
checkSchemaRelation(t, s, rel)
@ -55,95 +55,6 @@ func TestBelongsToOverrideReferences(t *testing.T) {
})
}
func TestBelongsToWithOnlyReferences(t *testing.T) {
type Profile struct {
gorm.Model
Refer string
Name string
}
type User struct {
gorm.Model
Profile Profile `gorm:"References:Refer"`
ProfileRefer int
}
checkStructRelation(t, &User{}, Relation{
Name: "Profile", Type: schema.BelongsTo, Schema: "User", FieldSchema: "Profile",
References: []Reference{{"Refer", "Profile", "ProfileRefer", "User", "", false}},
})
}
func TestBelongsToWithOnlyReferences2(t *testing.T) {
type Profile struct {
gorm.Model
Refer string
Name string
}
type User struct {
gorm.Model
Profile Profile `gorm:"References:Refer"`
ProfileID int
}
checkStructRelation(t, &User{}, Relation{
Name: "Profile", Type: schema.BelongsTo, Schema: "User", FieldSchema: "Profile",
References: []Reference{{"Refer", "Profile", "ProfileID", "User", "", false}},
})
}
func TestSelfReferentialBelongsTo(t *testing.T) {
type User struct {
ID int32 `gorm:"primaryKey"`
Name string
CreatorID *int32
Creator *User
}
checkStructRelation(t, &User{}, Relation{
Name: "Creator", Type: schema.BelongsTo, Schema: "User", FieldSchema: "User",
References: []Reference{{"ID", "User", "CreatorID", "User", "", false}},
})
}
func TestSelfReferentialBelongsToOverrideReferences(t *testing.T) {
type User struct {
ID int32 `gorm:"primaryKey"`
Name string
CreatedBy *int32
Creator *User `gorm:"foreignKey:CreatedBy;references:ID"`
}
checkStructRelation(t, &User{}, Relation{
Name: "Creator", Type: schema.BelongsTo, Schema: "User", FieldSchema: "User",
References: []Reference{{"ID", "User", "CreatedBy", "User", "", false}},
})
}
func TestBelongsToWithMixin(t *testing.T) {
type Profile struct {
gorm.Model
Refer string
Name string
}
type ProfileMixin struct {
Profile Profile `gorm:"References:Refer"`
ProfileRefer int
}
type User struct {
gorm.Model
ProfileMixin
}
checkStructRelation(t, &User{}, Relation{
Name: "Profile", Type: schema.BelongsTo, Schema: "User", FieldSchema: "Profile",
References: []Reference{{"Refer", "Profile", "ProfileRefer", "User", "", false}},
})
}
func TestHasOneOverrideForeignKey(t *testing.T) {
type Profile struct {
gorm.Model
@ -181,62 +92,6 @@ func TestHasOneOverrideReferences(t *testing.T) {
})
}
func TestHasOneOverrideReferences2(t *testing.T) {
type Profile struct {
gorm.Model
Name string
}
type User struct {
gorm.Model
ProfileID uint `gorm:"column:profile_id"`
Profile *Profile `gorm:"foreignKey:ID;references:ProfileID"`
}
checkStructRelation(t, &User{}, Relation{
Name: "Profile", Type: schema.HasOne, Schema: "User", FieldSchema: "Profile",
References: []Reference{{"ProfileID", "User", "ID", "Profile", "", true}},
})
}
func TestHasOneWithOnlyReferences(t *testing.T) {
type Profile struct {
gorm.Model
Name string
UserRefer uint
}
type User struct {
gorm.Model
Refer string
Profile Profile `gorm:"References:Refer"`
}
checkStructRelation(t, &User{}, Relation{
Name: "Profile", Type: schema.HasOne, Schema: "User", FieldSchema: "Profile",
References: []Reference{{"Refer", "User", "UserRefer", "Profile", "", true}},
})
}
func TestHasOneWithOnlyReferences2(t *testing.T) {
type Profile struct {
gorm.Model
Name string
UserID uint
}
type User struct {
gorm.Model
Refer string
Profile Profile `gorm:"References:Refer"`
}
checkStructRelation(t, &User{}, Relation{
Name: "Profile", Type: schema.HasOne, Schema: "User", FieldSchema: "Profile",
References: []Reference{{"Refer", "User", "UserID", "Profile", "", true}},
})
}
func TestHasManyOverrideForeignKey(t *testing.T) {
type Profile struct {
gorm.Model
@ -283,9 +138,8 @@ func TestMany2ManyOverrideForeignKeyAndReferences(t *testing.T) {
type User struct {
gorm.Model
Profiles []Profile `gorm:"many2many:user_profiles;ForeignKey:Refer;JoinForeignKey:UserReferID;References:UserRefer;JoinReferences:ProfileRefer"`
Profiles2 []Profile `gorm:"many2many:user_profiles2;ForeignKey:refer;JoinForeignKey:user_refer_id;References:user_refer;JoinReferences:profile_refer"`
Refer uint
}
checkStructRelation(t, &User{}, Relation{
@ -295,13 +149,6 @@ func TestMany2ManyOverrideForeignKeyAndReferences(t *testing.T) {
{"Refer", "User", "UserReferID", "user_profiles", "", true},
{"UserRefer", "Profile", "ProfileRefer", "user_profiles", "", false},
},
}, Relation{
Name: "Profiles2", Type: schema.Many2Many, Schema: "User", FieldSchema: "Profile",
JoinTable: JoinTable{Name: "user_profiles2", Table: "user_profiles2"},
References: []Reference{
{"Refer", "User", "User_refer_id", "user_profiles2", "", true},
{"UserRefer", "Profile", "Profile_refer", "user_profiles2", "", false},
},
})
}
@ -328,33 +175,6 @@ func TestMany2ManyOverrideForeignKey(t *testing.T) {
})
}
func TestMany2ManySharedForeignKey(t *testing.T) {
type Profile struct {
gorm.Model
Name string
Kind string
ProfileRefer uint
}
type User struct {
gorm.Model
Profiles []Profile `gorm:"many2many:user_profiles;foreignKey:Refer,Kind;joinForeignKey:UserRefer,Kind;References:ProfileRefer,Kind;joinReferences:ProfileR,Kind"`
Kind string
Refer uint
}
checkStructRelation(t, &User{}, Relation{
Name: "Profiles", Type: schema.Many2Many, Schema: "User", FieldSchema: "Profile",
JoinTable: JoinTable{Name: "user_profiles", Table: "user_profiles"},
References: []Reference{
{"Refer", "User", "UserRefer", "user_profiles", "", true},
{"Kind", "User", "Kind", "user_profiles", "", true},
{"ProfileRefer", "Profile", "ProfileR", "user_profiles", "", false},
{"Kind", "Profile", "Kind", "user_profiles", "", false},
},
})
}
func TestMany2ManyOverrideJoinForeignKey(t *testing.T) {
type Profile struct {
gorm.Model
@ -364,39 +184,16 @@ func TestMany2ManyOverrideJoinForeignKey(t *testing.T) {
type User struct {
gorm.Model
Profiles []Profile `gorm:"many2many:user_profile;JoinForeignKey:UserReferID;JoinReferences:ProfileRefer"`
Refer uint
}
checkStructRelation(t, &User{}, Relation{
Name: "Profiles", Type: schema.Many2Many, Schema: "User", FieldSchema: "Profile",
JoinTable: JoinTable{Name: "user_profile", Table: "user_profile"},
References: []Reference{
{"ID", "User", "UserReferID", "user_profile", "", true},
{"ID", "Profile", "ProfileRefer", "user_profile", "", false},
},
})
}
func TestBuildReadonlyMany2ManyRelation(t *testing.T) {
type Profile struct {
gorm.Model
Name string
UserRefer uint
}
type User struct {
gorm.Model
Profiles []Profile `gorm:"->;many2many:user_profile;JoinForeignKey:UserReferID;JoinReferences:ProfileRefer"`
Refer uint
}
checkStructRelation(t, &User{}, Relation{
Name: "Profiles", Type: schema.Many2Many, Schema: "User", FieldSchema: "Profile",
JoinTable: JoinTable{Name: "user_profile", Table: "user_profile"},
References: []Reference{
{"ID", "User", "UserReferID", "user_profile", "", true},
{"ID", "Profile", "ProfileRefer", "user_profile", "", false},
},
})
}
@ -448,551 +245,3 @@ func TestMany2ManyWithMultiPrimaryKeys(t *testing.T) {
},
)
}
func TestMultipleMany2Many(t *testing.T) {
type Thing struct {
ID int
}
type Person struct {
ID int
Likes []Thing `gorm:"many2many:likes"`
Dislikes []Thing `gorm:"many2many:dislikes"`
}
checkStructRelation(t, &Person{},
Relation{
Name: "Likes", Type: schema.Many2Many, Schema: "Person", FieldSchema: "Thing",
JoinTable: JoinTable{Name: "likes", Table: "likes"},
References: []Reference{
{"ID", "Person", "PersonID", "likes", "", true},
{"ID", "Thing", "ThingID", "likes", "", false},
},
},
Relation{
Name: "Dislikes", Type: schema.Many2Many, Schema: "Person", FieldSchema: "Thing",
JoinTable: JoinTable{Name: "dislikes", Table: "dislikes"},
References: []Reference{
{"ID", "Person", "PersonID", "dislikes", "", true},
{"ID", "Thing", "ThingID", "dislikes", "", false},
},
},
)
}
func TestSelfReferentialMany2Many(t *testing.T) {
type User struct {
ID int32 `gorm:"primaryKey"`
Name string
CreatedBy int32
Creators []User `gorm:"foreignKey:CreatedBy"`
AnotherPro interface{} `gorm:"-"`
}
checkStructRelation(t, &User{}, Relation{
Name: "Creators", Type: schema.HasMany, Schema: "User", FieldSchema: "User",
References: []Reference{{"ID", "User", "CreatedBy", "User", "", true}},
})
user, err := schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("failed to parse schema")
}
relSchema := user.Relationships.Relations["Creators"].FieldSchema
if user != relSchema {
t.Fatalf("schema should be same, expects %p but got %p", user, relSchema)
}
}
type CreatedByModel struct {
CreatedByID uint
CreatedBy *CreatedUser
}
type CreatedUser struct {
gorm.Model
CreatedByModel
}
func TestEmbeddedRelation(t *testing.T) {
checkStructRelation(t, &CreatedUser{}, Relation{
Name: "CreatedBy", Type: schema.BelongsTo, Schema: "CreatedUser", FieldSchema: "CreatedUser",
References: []Reference{
{"ID", "CreatedUser", "CreatedByID", "CreatedUser", "", false},
},
})
userSchema, err := schema.Parse(&CreatedUser{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("failed to parse schema, got error %v", err)
}
if len(userSchema.Relationships.Relations) != 1 {
t.Fatalf("expects 1 relations, but got %v", len(userSchema.Relationships.Relations))
}
if createdByRel, ok := userSchema.Relationships.Relations["CreatedBy"]; ok {
if createdByRel.FieldSchema != userSchema {
t.Fatalf("expects same field schema, but got new %p, old %p", createdByRel.FieldSchema, userSchema)
}
} else {
t.Fatalf("expects created by relations, but not found")
}
}
func TestEmbeddedHas(t *testing.T) {
type Toy struct {
ID int
Name string
OwnerID int
OwnerType string
}
type User struct {
ID int
Cat struct {
Name string
Toy Toy `gorm:"polymorphic:Owner;"`
Toys []Toy `gorm:"polymorphic:Owner;"`
} `gorm:"embedded;embeddedPrefix:cat_"`
Dog struct {
ID int
Name string
UserID int
Toy Toy `gorm:"polymorphic:Owner;"`
Toys []Toy `gorm:"polymorphic:Owner;"`
}
Toys []Toy `gorm:"polymorphic:Owner;"`
}
s, err := schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("Failed to parse schema, got error %v", err)
}
checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
"Cat": {
Relations: map[string]Relation{
"Toy": {
Name: "Toy",
Type: schema.HasOne,
Schema: "User",
FieldSchema: "Toy",
Polymorphic: Polymorphic{ID: "OwnerID", Type: "OwnerType", Value: "users"},
References: []Reference{
{ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"},
{ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"},
},
},
"Toys": {
Name: "Toys",
Type: schema.HasMany,
Schema: "User",
FieldSchema: "Toy",
Polymorphic: Polymorphic{ID: "OwnerID", Type: "OwnerType", Value: "users"},
References: []Reference{
{ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"},
{ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"},
},
},
},
},
})
}
func TestPolymorphic(t *testing.T) {
t.Run("has one", func(t *testing.T) {
type Toy struct {
ID int
Name string
OwnerID int
OwnerType string
}
type Cat struct {
ID int
Name string
Toy Toy `gorm:"polymorphic:Owner;"`
}
s, err := schema.Parse(&Cat{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("Failed to parse schema, got error %v", err)
}
checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
"Cat": {
Relations: map[string]Relation{
"Toy": {
Name: "Toy",
Type: schema.HasOne,
Schema: "User",
FieldSchema: "Toy",
Polymorphic: Polymorphic{ID: "OwnerID", Type: "OwnerType", Value: "users"},
References: []Reference{
{ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"},
},
},
},
},
})
})
t.Run("has one with custom polymorphic type and id", func(t *testing.T) {
type Toy struct {
ID int
Name string
RefId int
Type string
}
type Cat struct {
ID int
Name string
Toy Toy `gorm:"polymorphic:Owner;polymorphicType:Type;polymorphicId:RefId"`
}
s, err := schema.Parse(&Cat{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("Failed to parse schema, got error %v", err)
}
checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
"Cat": {
Relations: map[string]Relation{
"Toy": {
Name: "Toy",
Type: schema.HasOne,
Schema: "User",
FieldSchema: "Toy",
Polymorphic: Polymorphic{ID: "ref_id", Type: "Type", Value: "users"},
References: []Reference{
{ForeignKey: "Type", ForeignSchema: "Toy", PrimaryValue: "users"},
},
},
},
},
})
})
t.Run("has one with only polymorphic type", func(t *testing.T) {
type Toy struct {
ID int
Name string
OwnerID int
Type string
}
type Cat struct {
ID int
Name string
Toy Toy `gorm:"polymorphic:Owner;polymorphicType:Type"`
}
s, err := schema.Parse(&Cat{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("Failed to parse schema, got error %v", err)
}
checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
"Cat": {
Relations: map[string]Relation{
"Toy": {
Name: "Toy",
Type: schema.HasOne,
Schema: "User",
FieldSchema: "Toy",
Polymorphic: Polymorphic{ID: "owner_id", Type: "Type", Value: "users"},
References: []Reference{
{ForeignKey: "Type", ForeignSchema: "Toy", PrimaryValue: "users"},
},
},
},
},
})
})
t.Run("has many", func(t *testing.T) {
type Toy struct {
ID int
Name string
OwnerID int
OwnerType string
}
type Cat struct {
ID int
Name string
Toys []Toy `gorm:"polymorphic:Owner;"`
}
s, err := schema.Parse(&Cat{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("Failed to parse schema, got error %v", err)
}
checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
"Cat": {
Relations: map[string]Relation{
"Toys": {
Name: "Toys",
Type: schema.HasMany,
Schema: "User",
FieldSchema: "Toy",
Polymorphic: Polymorphic{ID: "OwnerID", Type: "OwnerType", Value: "users"},
References: []Reference{
{ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"},
},
},
},
},
})
})
t.Run("has many with custom polymorphic type and id", func(t *testing.T) {
type Toy struct {
ID int
Name string
RefId int
Type string
}
type Cat struct {
ID int
Name string
Toys []Toy `gorm:"polymorphicType:Type;polymorphicId:RefId"`
}
s, err := schema.Parse(&Cat{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("Failed to parse schema, got error %v", err)
}
checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
"Cat": {
Relations: map[string]Relation{
"Toys": {
Name: "Toys",
Type: schema.HasMany,
Schema: "User",
FieldSchema: "Toy",
Polymorphic: Polymorphic{ID: "ref_id", Type: "Type", Value: "users"},
References: []Reference{
{ForeignKey: "Type", ForeignSchema: "Toy", PrimaryValue: "users"},
},
},
},
},
})
})
}
func TestEmbeddedBelongsTo(t *testing.T) {
type Country struct {
ID int `gorm:"primaryKey"`
Name string
}
type Address struct {
CountryID int
Country Country
}
type NestedAddress struct {
Address
}
type CountryMixin struct {
CountryID int
Country Country
}
type Org struct {
ID int
PostalAddress Address `gorm:"embedded;embeddedPrefix:postal_address_"`
VisitingAddress Address `gorm:"embedded;embeddedPrefix:visiting_address_"`
AddressID int
Address struct {
ID int
Address
}
NestedAddress *NestedAddress `gorm:"embedded;embeddedPrefix:nested_address_"`
CountryMixin
}
s, err := schema.Parse(&Org{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Errorf("Failed to parse schema, got error %v", err)
}
checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
"PostalAddress": {
Relations: map[string]Relation{
"Country": {
Name: "Country", Type: schema.BelongsTo, Schema: "Org", FieldSchema: "Country",
References: []Reference{
{PrimaryKey: "ID", PrimarySchema: "Country", ForeignKey: "CountryID", ForeignSchema: "Org"},
},
},
},
},
"VisitingAddress": {
Relations: map[string]Relation{
"Country": {
Name: "Country", Type: schema.BelongsTo, Schema: "Org", FieldSchema: "Country",
References: []Reference{
{PrimaryKey: "ID", PrimarySchema: "Country", ForeignKey: "CountryID", ForeignSchema: "Org"},
},
},
},
},
"NestedAddress": {
Relations: map[string]Relation{
"Country": {
Name: "Country", Type: schema.BelongsTo, Schema: "Org", FieldSchema: "Country",
References: []Reference{
{PrimaryKey: "ID", PrimarySchema: "Country", ForeignKey: "CountryID", ForeignSchema: "Org"},
},
},
},
},
})
}
func TestVariableRelation(t *testing.T) {
var result struct {
User
}
checkStructRelation(t, &result, Relation{
Name: "Account", Type: schema.HasOne, Schema: "", FieldSchema: "Account",
References: []Reference{
{"ID", "", "UserID", "Account", "", true},
},
})
checkStructRelation(t, &result, Relation{
Name: "Company", Type: schema.BelongsTo, Schema: "", FieldSchema: "Company",
References: []Reference{
{"ID", "Company", "CompanyID", "", "", false},
},
})
}
func TestSameForeignKey(t *testing.T) {
type UserAux struct {
gorm.Model
Aux string
UUID string
}
type User struct {
gorm.Model
Name string
UUID string
Aux *UserAux `gorm:"foreignkey:UUID;references:UUID"`
}
checkStructRelation(t, &User{},
Relation{
Name: "Aux", Type: schema.HasOne, Schema: "User", FieldSchema: "UserAux",
References: []Reference{
{"UUID", "User", "UUID", "UserAux", "", true},
},
},
)
}
func TestBelongsToSameForeignKey(t *testing.T) {
type User struct {
gorm.Model
Name string
UUID string
}
type UserAux struct {
gorm.Model
Aux string
UUID string
User User `gorm:"ForeignKey:UUID;references:UUID;belongsTo"`
}
checkStructRelation(t, &UserAux{},
Relation{
Name: "User", Type: schema.BelongsTo, Schema: "UserAux", FieldSchema: "User",
References: []Reference{
{"UUID", "User", "UUID", "UserAux", "", false},
},
},
)
}
func TestHasOneWithSameForeignKey(t *testing.T) {
type Profile struct {
gorm.Model
Name string
ProfileRefer int // not used in relationship
}
type User struct {
gorm.Model
Profile Profile `gorm:"ForeignKey:ID;references:ProfileRefer"`
ProfileRefer int
}
checkStructRelation(t, &User{}, Relation{
Name: "Profile", Type: schema.HasOne, Schema: "User", FieldSchema: "Profile",
References: []Reference{{"ProfileRefer", "User", "ID", "Profile", "", true}},
})
}
func TestHasManySameForeignKey(t *testing.T) {
type Profile struct {
gorm.Model
Name string
UserRefer uint
}
type User struct {
gorm.Model
UserRefer uint
Profile []Profile `gorm:"ForeignKey:UserRefer"`
}
checkStructRelation(t, &User{}, Relation{
Name: "Profile", Type: schema.HasMany, Schema: "User", FieldSchema: "Profile",
References: []Reference{{"ID", "User", "UserRefer", "Profile", "", true}},
})
}
type Author struct {
gorm.Model
}
type Book struct {
gorm.Model
Author Author
AuthorID uint
}
func (Book) TableName() string {
return "my_schema.a_very_very_very_very_very_very_very_very_long_table_name"
}
func TestParseConstraintNameWithSchemaQualifiedLongTableName(t *testing.T) {
s, err := schema.Parse(
&Book{},
&sync.Map{},
schema.NamingStrategy{IdentifierMaxLength: 64},
)
if err != nil {
t.Fatalf("Failed to parse schema")
}
expectedConstraintName := "fk_my_schema_a_very_very_very_very_very_very_very_very_l4db13eec"
constraint := s.Relationships.Relations["Author"].ParseConstraint()
if constraint.Name != expectedConstraintName {
t.Fatalf(
"expected constraint name %s, got %s",
expectedConstraintName,
constraint.Name,
)
}
}


@ -5,29 +5,13 @@ import (
"errors"
"fmt"
"go/ast"
"path"
"reflect"
"strings"
"sync"
"gorm.io/gorm/clause"
"gorm.io/gorm/logger"
)
type callbackType string
const (
callbackTypeBeforeCreate callbackType = "BeforeCreate"
callbackTypeBeforeUpdate callbackType = "BeforeUpdate"
callbackTypeAfterCreate callbackType = "AfterCreate"
callbackTypeAfterUpdate callbackType = "AfterUpdate"
callbackTypeBeforeSave callbackType = "BeforeSave"
callbackTypeAfterSave callbackType = "AfterSave"
callbackTypeBeforeDelete callbackType = "BeforeDelete"
callbackTypeAfterDelete callbackType = "AfterDelete"
callbackTypeAfterFind callbackType = "AfterFind"
)
// ErrUnsupportedDataType unsupported data type
var ErrUnsupportedDataType = errors.New("unsupported data type")
@ -41,9 +25,8 @@ type Schema struct {
PrimaryFieldDBNames []string
Fields []*Field
FieldsByName map[string]*Field
FieldsByBindName map[string]*Field // embedded fields is 'Embed.Field'
FieldsByDBName map[string]*Field
FieldsWithDefaultDBValue []*Field // fields with default value assigned by database
Relationships Relationships
CreateClauses []clause.Interface
QueryClauses []clause.Interface
@ -55,23 +38,37 @@ type Schema struct {
BeforeSave, AfterSave bool
AfterFind bool
err error
initialized chan struct{}
namer Namer
cacheStore *sync.Map
}
type CreateClausesInterface interface {
CreateClauses() []clause.Interface
}
type QueryClausesInterface interface {
QueryClauses() []clause.Interface
}
type UpdateClausesInterface interface {
UpdateClauses() []clause.Interface
}
type DeleteClausesInterface interface {
DeleteClauses() []clause.Interface
}
func (schema Schema) String() string {
if schema.ModelType.Name() == "" {
return fmt.Sprintf("%s(%s)", schema.Name, schema.Table)
}
return fmt.Sprintf("%s.%s", schema.ModelType.PkgPath(), schema.ModelType.Name())
}
func (schema Schema) MakeSlice() reflect.Value {
slice := reflect.MakeSlice(reflect.SliceOf(reflect.PointerTo(schema.ModelType)), 0, 20)
results := reflect.New(slice.Type())
results.Elem().Set(slice)
return results
}
@ -85,57 +82,10 @@ func (schema Schema) LookUpField(name string) *Field {
return nil
}
// LookUpFieldByBindName looks for the closest field in the embedded struct.
//
// type Struct struct {
// Embedded struct {
// ID string // is selected by LookUpFieldByBindName([]string{"Embedded", "ID"}, "ID")
// }
// ID string // is selected by LookUpFieldByBindName([]string{"ID"}, "ID")
// }
func (schema Schema) LookUpFieldByBindName(bindNames []string, name string) *Field {
if len(bindNames) == 0 {
return nil
}
for i := len(bindNames) - 1; i >= 0; i-- {
find := strings.Join(bindNames[:i], ".") + "." + name
if field, ok := schema.FieldsByBindName[find]; ok {
return field
}
}
return nil
}
type Tabler interface {
TableName() string
}
type TablerWithNamer interface {
TableName(Namer) string
}
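For reference, a model opts into a custom name by satisfying Tabler (checked a few lines below when the table name is resolved); TablerWithNamer additionally receives the active Namer. The Invoice type is illustrative:

```go
type Invoice struct {
	ID uint
}

// TableName overrides the name derived from the NamingStrategy.
func (Invoice) TableName() string {
	return "billing_invoices"
}
```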
// Parse get data type from dialector
func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) {
return ParseWithSpecialTableName(dest, cacheStore, namer, "")
}
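Parse is also what the tests in this changeset call directly; a compact, self-contained sketch (the User model is illustrative):

```go
package main

import (
	"fmt"
	"sync"

	"gorm.io/gorm/schema"
)

type User struct {
	ID   uint
	Name string
}

func main() {
	s, err := schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{})
	if err != nil {
		panic(err)
	}
	fmt.Println(s.Table, s.PrimaryFieldDBNames) // users [id]
}
```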
// ParseWithSpecialTableName get data type from dialector with extra schema table
func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Namer, specialTableName string) (*Schema, error) {
if dest == nil {
return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
}
value := reflect.ValueOf(dest)
if value.Kind() == reflect.Ptr && value.IsNil() {
value = reflect.New(value.Type().Elem())
}
modelType := reflect.Indirect(value).Type()
if modelType.Kind() == reflect.Interface {
modelType = reflect.Indirect(reflect.ValueOf(dest)).Elem().Type()
}
for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr {
modelType = modelType.Elem()
}
@ -143,63 +93,30 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
if modelType.PkgPath() == "" {
return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
}
return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name())
return nil, fmt.Errorf("%w: %v.%v", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name())
}
// Cache the Schema for performance,
// Use the modelType or modelType + schemaTable (if it present) as cache key.
var schemaCacheKey interface{}
if specialTableName != "" {
schemaCacheKey = fmt.Sprintf("%p-%s", modelType, specialTableName)
} else {
schemaCacheKey = modelType
}
// Load exist schema cache, return if exists
if v, ok := cacheStore.Load(schemaCacheKey); ok {
s := v.(*Schema)
// Wait for the initialization of other goroutines to complete
<-s.initialized
return s, s.err
}
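On the master side of this hunk a cached schema carries an `initialized` channel, so concurrent callers block until the goroutine that stored the entry finishes parsing rather than observing a half-built Schema; v0.2.0 simply returns whatever is in the cache. A hedged sketch of the caller-visible behaviour, with `Invoice` invented:

```go
package main

import (
	"fmt"
	"sync"

	"gorm.io/gorm/schema"
)

type Invoice struct {
	ID     uint
	Number string
}

func main() {
	cache := &sync.Map{}

	var wg sync.WaitGroup
	for i := 0; i < 4; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			// All goroutines share one cache; whoever wins the race parses,
			// the rest wait for initialization and reuse the stored result.
			s, err := schema.Parse(&Invoice{}, cache, schema.NamingStrategy{})
			if err != nil {
				panic(err)
			}
			fmt.Println(s.Table) // "invoices" from every goroutine
		}()
	}
	wg.Wait()
}
```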
modelValue := reflect.New(modelType)
tableName := namer.TableName(modelType.Name())
if tabler, ok := modelValue.Interface().(Tabler); ok {
tableName = tabler.TableName()
}
if tabler, ok := modelValue.Interface().(TablerWithNamer); ok {
tableName = tabler.TableName(namer)
}
if en, ok := namer.(embeddedNamer); ok {
tableName = en.Table
}
if specialTableName != "" && specialTableName != tableName {
tableName = specialTableName
if v, ok := cacheStore.Load(modelType); ok {
return v.(*Schema), nil
}
schema := &Schema{
Name: modelType.Name(),
ModelType: modelType,
Table: tableName,
FieldsByName: map[string]*Field{},
FieldsByBindName: map[string]*Field{},
FieldsByDBName: map[string]*Field{},
Relationships: Relationships{Relations: map[string]*Relationship{}},
cacheStore: cacheStore,
namer: namer,
initialized: make(chan struct{}),
Name: modelType.Name(),
ModelType: modelType,
Table: namer.TableName(modelType.Name()),
FieldsByName: map[string]*Field{},
FieldsByDBName: map[string]*Field{},
Relationships: Relationships{Relations: map[string]*Relationship{}},
cacheStore: cacheStore,
namer: namer,
}
// When the schema initialization is completed, the channel will be closed
defer close(schema.initialized)
// Load exist schema cache, return if exists
if v, ok := cacheStore.Load(schemaCacheKey); ok {
s := v.(*Schema)
// Wait for the initialization of other goroutines to complete
<-s.initialized
return s, s.err
}
defer func() {
if schema.err != nil {
logger.Default.Error(context.Background(), schema.err.Error())
cacheStore.Delete(modelType)
}
}()
for i := 0; i < modelType.NumField(); i++ {
if fieldStruct := modelType.Field(i); ast.IsExported(fieldStruct.Name) {
@@ -216,67 +133,52 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
field.DBName = namer.ColumnName(schema.Table, field.Name)
}
bindName := field.BindName()
if field.DBName != "" {
// nonexistence or shortest path or first appear prioritized if has permission
if v, ok := schema.FieldsByDBName[field.DBName]; !ok || ((field.Creatable || field.Updatable || field.Readable) && len(field.BindNames) < len(v.BindNames)) {
if v, ok := schema.FieldsByDBName[field.DBName]; !ok || (field.Creatable && len(field.BindNames) < len(v.BindNames)) {
if _, ok := schema.FieldsByDBName[field.DBName]; !ok {
schema.DBNames = append(schema.DBNames, field.DBName)
}
schema.FieldsByDBName[field.DBName] = field
schema.FieldsByName[field.Name] = field
schema.FieldsByBindName[bindName] = field
if v != nil && v.PrimaryKey {
if schema.PrioritizedPrimaryField == v {
schema.PrioritizedPrimaryField = nil
}
for idx, f := range schema.PrimaryFields {
if f == v {
schema.PrimaryFields = append(schema.PrimaryFields[0:idx], schema.PrimaryFields[idx+1:]...)
} else if schema.PrioritizedPrimaryField == nil {
schema.PrioritizedPrimaryField = f
}
}
}
if field.PrimaryKey {
if schema.PrioritizedPrimaryField == nil {
schema.PrioritizedPrimaryField = field
}
schema.PrimaryFields = append(schema.PrimaryFields, field)
}
}
}
if of, ok := schema.FieldsByName[field.Name]; !ok || of.TagSettings["-"] == "-" {
if _, ok := schema.FieldsByName[field.Name]; !ok {
schema.FieldsByName[field.Name] = field
}
if of, ok := schema.FieldsByBindName[bindName]; !ok || of.TagSettings["-"] == "-" {
schema.FieldsByBindName[bindName] = field
}
field.setupValuerAndSetter(modelType)
field.setupValuerAndSetter()
}
prioritizedPrimaryField := schema.LookUpField("id")
if prioritizedPrimaryField == nil {
prioritizedPrimaryField = schema.LookUpField("ID")
}
if prioritizedPrimaryField != nil {
if prioritizedPrimaryField.PrimaryKey {
schema.PrioritizedPrimaryField = prioritizedPrimaryField
if f := schema.LookUpField("id"); f != nil {
if f.PrimaryKey {
schema.PrioritizedPrimaryField = f
} else if len(schema.PrimaryFields) == 0 {
prioritizedPrimaryField.PrimaryKey = true
schema.PrioritizedPrimaryField = prioritizedPrimaryField
schema.PrimaryFields = append(schema.PrimaryFields, prioritizedPrimaryField)
}
}
if schema.PrioritizedPrimaryField == nil {
if len(schema.PrimaryFields) == 1 {
schema.PrioritizedPrimaryField = schema.PrimaryFields[0]
} else if len(schema.PrimaryFields) > 1 {
// If there are multiple primary keys, the AUTOINCREMENT field is prioritized
for _, field := range schema.PrimaryFields {
if field.AutoIncrement {
schema.PrioritizedPrimaryField = field
break
}
}
f.PrimaryKey = true
schema.PrioritizedPrimaryField = f
schema.PrimaryFields = append(schema.PrimaryFields, f)
}
}
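This block is the ID-promotion rule: a field named `ID` (or `id`) becomes the prioritized primary key even without a `primaryKey` tag, and an untagged integer primary key is later marked auto-increment on the master side. A hedged sketch of the observable result; `Note` is an invented model:

```go
package main

import (
	"fmt"
	"sync"

	"gorm.io/gorm/schema"
)

// Note has no gorm tags at all; the parser still promotes ID to primary key.
type Note struct {
	ID   uint
	Body string
}

func main() {
	s, err := schema.Parse(&Note{}, &sync.Map{}, schema.NamingStrategy{})
	if err != nil {
		panic(err)
	}
	pk := s.PrioritizedPrimaryField
	// ID true true (AutoIncrement is a master-side promotion; v0.2.0 only
	// records the field as having a database-assigned default).
	fmt.Println(pk.Name, pk.PrimaryKey, pk.AutoIncrement)
}
```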
@@ -284,148 +186,43 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
schema.PrimaryFieldDBNames = append(schema.PrimaryFieldDBNames, field.DBName)
}
for _, field := range schema.Fields {
if field.DataType != "" && field.HasDefaultValue && field.DefaultValueInterface == nil {
schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field)
schema.FieldsWithDefaultDBValue = map[string]*Field{}
for db, field := range schema.FieldsByDBName {
if field.HasDefaultValue && field.DefaultValueInterface == nil {
schema.FieldsWithDefaultDBValue[db] = field
}
}
if field := schema.PrioritizedPrimaryField; field != nil {
switch field.GORMDataType {
if schema.PrioritizedPrimaryField != nil {
switch schema.PrioritizedPrimaryField.DataType {
case Int, Uint:
if _, ok := field.TagSettings["AUTOINCREMENT"]; !ok {
if !field.HasDefaultValue || field.DefaultValueInterface != nil {
schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field)
}
field.HasDefaultValue = true
field.AutoIncrement = true
}
schema.FieldsWithDefaultDBValue[schema.PrioritizedPrimaryField.DBName] = schema.PrioritizedPrimaryField
}
}
callbackTypes := []callbackType{
callbackTypeBeforeCreate, callbackTypeAfterCreate,
callbackTypeBeforeUpdate, callbackTypeAfterUpdate,
callbackTypeBeforeSave, callbackTypeAfterSave,
callbackTypeBeforeDelete, callbackTypeAfterDelete,
callbackTypeAfterFind,
}
for _, cbName := range callbackTypes {
if methodValue := callBackToMethodValue(modelValue, cbName); methodValue.IsValid() {
reflectValue := reflect.New(modelType)
callbacks := []string{"BeforeCreate", "AfterCreate", "BeforeUpdate", "AfterUpdate", "BeforeSave", "AfterSave", "BeforeDelete", "AfterDelete", "AfterFind"}
for _, name := range callbacks {
if methodValue := reflectValue.MethodByName(name); methodValue.IsValid() {
switch methodValue.Type().String() {
case "func(*gorm.DB) error":
expectedPkgPath := path.Dir(reflect.TypeOf(schema).Elem().PkgPath())
if inVarPkg := methodValue.Type().In(0).Elem().PkgPath(); inVarPkg == expectedPkgPath {
reflect.Indirect(reflect.ValueOf(schema)).FieldByName(string(cbName)).SetBool(true)
} else {
logger.Default.Warn(context.Background(), "In model %v, the hook function `%v(*gorm.DB) error` has an incorrect parameter type. The expected parameter type is `%v`, but the provided type is `%v`.", schema, cbName, expectedPkgPath, inVarPkg)
// PASS
}
case "func(*gorm.DB) error": // TODO hack
reflect.Indirect(reflect.ValueOf(schema)).FieldByName(name).SetBool(true)
default:
logger.Default.Warn(context.Background(), "Model %v don't match %vInterface, should be `%v(*gorm.DB) error`. Please see https://gorm.io/docs/hooks.html", schema, cbName, cbName)
logger.Default.Warn(context.Background(), "Model %v don't match %vInterface, should be %v(*gorm.DB)", schema, name, name)
}
}
}
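The loop above only flips a schema's hook flag when the model defines a method with the exact `func(*gorm.DB) error` signature (master additionally warns if the parameter is not gorm's own *DB type). A hedged example of a hook the parser would accept; `Subscription` and its one-year default are invented:

```go
package main

import (
	"fmt"
	"sync"
	"time"

	"gorm.io/gorm"
	"gorm.io/gorm/schema"
)

type Subscription struct {
	ID        uint
	ExpiresAt time.Time
}

// BeforeCreate matches the expected `func(*gorm.DB) error` shape, so parsing
// the model sets the schema's BeforeCreate flag and gorm runs it on insert.
func (s *Subscription) BeforeCreate(tx *gorm.DB) error {
	if s.ExpiresAt.IsZero() {
		s.ExpiresAt = time.Now().AddDate(1, 0, 0)
	}
	return nil
}

func main() {
	s, err := schema.Parse(&Subscription{}, &sync.Map{}, schema.NamingStrategy{})
	if err != nil {
		panic(err)
	}
	fmt.Println(s.BeforeCreate) // true
}
```

A hook with any other signature, say `BeforeCreate() error`, would instead trigger the "don't match %vInterface" warning logged above.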
// Cache the schema
if v, loaded := cacheStore.LoadOrStore(schemaCacheKey, schema); loaded {
s := v.(*Schema)
// Wait for the initialization of other goroutines to complete
<-s.initialized
return s, s.err
}
cacheStore.Store(modelType, schema)
defer func() {
if schema.err != nil {
logger.Default.Error(context.Background(), schema.err.Error())
cacheStore.Delete(modelType)
}
}()
if _, embedded := schema.cacheStore.Load(embeddedCacheKey); !embedded {
for _, field := range schema.Fields {
if field.DataType == "" && field.GORMDataType == "" && (field.Creatable || field.Updatable || field.Readable) {
if schema.parseRelation(field); schema.err != nil {
return schema, schema.err
} else {
schema.FieldsByName[field.Name] = field
schema.FieldsByBindName[field.BindName()] = field
}
}
fieldValue := reflect.New(field.IndirectFieldType)
fieldInterface := fieldValue.Interface()
if fc, ok := fieldInterface.(CreateClausesInterface); ok {
field.Schema.CreateClauses = append(field.Schema.CreateClauses, fc.CreateClauses(field)...)
}
if fc, ok := fieldInterface.(QueryClausesInterface); ok {
field.Schema.QueryClauses = append(field.Schema.QueryClauses, fc.QueryClauses(field)...)
}
if fc, ok := fieldInterface.(UpdateClausesInterface); ok {
field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, fc.UpdateClauses(field)...)
}
if fc, ok := fieldInterface.(DeleteClausesInterface); ok {
field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, fc.DeleteClauses(field)...)
// parse relations for unidentified fields
for _, field := range schema.Fields {
if field.DataType == "" && field.Creatable {
if schema.parseRelation(field); schema.err != nil {
return schema, schema.err
}
}
}
return schema, schema.err
}
// This unrolling is needed to show to the compiler the exact set of methods
// that can be used on the modelType.
// Prior to go1.22 any use of MethodByName would cause the linker to
// abandon dead code elimination for the entire binary.
// As of go1.22 the compiler supports one special case of a string constant
// being passed to MethodByName. For enterprise customers or those building
// large binaries, this gives a significant reduction in binary size.
// https://github.com/golang/go/issues/62257
func callBackToMethodValue(modelType reflect.Value, cbType callbackType) reflect.Value {
switch cbType {
case callbackTypeBeforeCreate:
return modelType.MethodByName(string(callbackTypeBeforeCreate))
case callbackTypeAfterCreate:
return modelType.MethodByName(string(callbackTypeAfterCreate))
case callbackTypeBeforeUpdate:
return modelType.MethodByName(string(callbackTypeBeforeUpdate))
case callbackTypeAfterUpdate:
return modelType.MethodByName(string(callbackTypeAfterUpdate))
case callbackTypeBeforeSave:
return modelType.MethodByName(string(callbackTypeBeforeSave))
case callbackTypeAfterSave:
return modelType.MethodByName(string(callbackTypeAfterSave))
case callbackTypeBeforeDelete:
return modelType.MethodByName(string(callbackTypeBeforeDelete))
case callbackTypeAfterDelete:
return modelType.MethodByName(string(callbackTypeAfterDelete))
case callbackTypeAfterFind:
return modelType.MethodByName(string(callbackTypeAfterFind))
default:
return reflect.ValueOf(nil)
}
}
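The unrolled switch above exists so that every MethodByName call site receives a string constant; with a variable argument the linker has to abandon dead code elimination for the whole binary (before the go1.22 special case the comment cites). A small sketch of the two call shapes — the type is invented, and the size effect only shows up when you build and compare real binaries:

```go
package main

import (
	"fmt"
	"reflect"
)

type model struct{}

func (model) AfterFind() {}

func main() {
	v := reflect.ValueOf(model{})

	// Constant-string lookup: the form the unrolled switch guarantees, and
	// the one the linker's dead code elimination can still reason about.
	fmt.Println(v.MethodByName("AfterFind").IsValid()) // true

	// Variable-string lookup: functionally identical here, but because the
	// argument is no longer a constant the linker must assume any exported
	// method could be reached by name.
	name := "AfterFind"
	fmt.Println(v.MethodByName(name).IsValid()) // true
}
```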
func getOrParse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) {
modelType := reflect.ValueOf(dest).Type()
for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr {
modelType = modelType.Elem()
}
if modelType.Kind() != reflect.Struct {
if modelType.PkgPath() == "" {
return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
}
return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name())
}
if v, ok := cacheStore.Load(modelType); ok {
return v.(*Schema), nil
}
return Parse(dest, cacheStore, namer)
}
