Compare commits
87 Commits
fix/create
...
master
Author | SHA1 | Date | |
---|---|---|---|
![]() |
49eaeacb89 | ||
![]() |
52b4744410 | ||
![]() |
9af6d510b5 | ||
![]() |
c63374f5d1 | ||
![]() |
b9c7e562b0 | ||
![]() |
985940f0d8 | ||
![]() |
991c2d4891 | ||
![]() |
751a6dde7a | ||
![]() |
2f4925e017 | ||
![]() |
1e8baf5459 | ||
![]() |
842ee527eb | ||
![]() |
23c0d7cf05 | ||
![]() |
718eae4fdd | ||
![]() |
49b01a3e93 | ||
![]() |
c44405a25b | ||
![]() |
751c1d6b45 | ||
![]() |
8e7ab46c1b | ||
![]() |
e3037e4ef0 | ||
![]() |
1204330419 | ||
![]() |
9703eb775f | ||
![]() |
1c966e0d25 | ||
![]() |
e5b867e785 | ||
![]() |
8c4e8e2d2a | ||
![]() |
a827495be1 | ||
![]() |
489a563293 | ||
![]() |
42bd4f603c | ||
![]() |
a9d27293de | ||
![]() |
3876ffe4bb | ||
![]() |
ee3b549d7d | ||
![]() |
9f273777f5 | ||
![]() |
9ca84b3dde | ||
![]() |
86b1d22911 | ||
![]() |
fed49230cb | ||
![]() |
8503287ca4 | ||
![]() |
4ef3af10ed | ||
![]() |
f482f25c71 | ||
![]() |
6bfccf8afa | ||
![]() |
49bbaa637f | ||
![]() |
b0d70a26d1 | ||
![]() |
deceebfab8 | ||
![]() |
52e3b353eb | ||
![]() |
8020e8c166 | ||
![]() |
62bd0b9331 | ||
![]() |
c6ac54812a | ||
![]() |
68434b76eb | ||
![]() |
c2515ce260 | ||
![]() |
7f75b12bb2 | ||
![]() |
0daaf1747c | ||
![]() |
0dbfda5d7e | ||
![]() |
4a50b36f63 | ||
![]() |
11c4331058 | ||
![]() |
8a0af58cc5 | ||
![]() |
4f6291154b | ||
![]() |
109f239fae | ||
![]() |
79bf7f92ed | ||
![]() |
3d09f7947f | ||
![]() |
73a988ceb2 | ||
![]() |
05167fd591 | ||
![]() |
78c6dfd712 | ||
![]() |
3fe7fcf356 | ||
![]() |
9c4070ed19 | ||
![]() |
49d524aaea | ||
![]() |
49d94c173c | ||
![]() |
0f105ec163 | ||
![]() |
5e599a07ec | ||
![]() |
9d370bcb3e | ||
![]() |
78920199f0 | ||
![]() |
ac59252327 | ||
![]() |
207f1ac68f | ||
![]() |
85299bfca7 | ||
![]() |
5553ff3dcb | ||
![]() |
bc49365de2 | ||
![]() |
d0b4ceb726 | ||
![]() |
9a61ef2af8 | ||
![]() |
1e13fd7543 | ||
![]() |
1b48aa072d | ||
![]() |
26195e6d16 | ||
![]() |
956f7ce843 | ||
![]() |
0d6c5345f3 | ||
![]() |
57603882ea | ||
![]() |
81536f823c | ||
![]() |
1b0aa802df | ||
![]() |
e0c3be03fb | ||
![]() |
303de6e7c8 | ||
![]() |
f7ebf049da | ||
![]() |
ab89d54d87 | ||
![]() |
281f3e369a |
20
.github/release-drafter.yml
vendored
Normal file
20
.github/release-drafter.yml
vendored
Normal file
@ -0,0 +1,20 @@
|
||||
name-template: 'v Release $NEXT_PATCH_VERSION 🌈'
|
||||
tag-template: 'v$NEXT_PATCH_VERSION'
|
||||
categories:
|
||||
- title: '🚀 Features'
|
||||
labels:
|
||||
- 'feature'
|
||||
- 'enhancement'
|
||||
- title: '🐛 Bug Fixes'
|
||||
labels:
|
||||
- 'fix'
|
||||
- 'bugfix'
|
||||
- 'bug'
|
||||
- title: '🧰 Maintenance'
|
||||
label: 'chore'
|
||||
change-template: '- $TITLE @$AUTHOR (#$NUMBER)'
|
||||
change-title-escapes: '\<*_&'
|
||||
template: |
|
||||
## Changes
|
||||
|
||||
$CHANGES
|
31
.github/workflows/create-release.yml
vendored
Normal file
31
.github/workflows/create-release.yml
vendored
Normal file
@ -0,0 +1,31 @@
|
||||
name: Create Release
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- 'v*.*.*'
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: read
|
||||
|
||||
jobs:
|
||||
create_release:
|
||||
name: Create Release
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Generate Release Notes and Publish
|
||||
id: generate_release_notes
|
||||
uses: release-drafter/release-drafter@v6
|
||||
with:
|
||||
config-name: 'release-drafter.yml'
|
||||
name: "Release ${{ github.ref_name }}"
|
||||
tag: ${{ github.ref_name }}
|
||||
publish: true
|
||||
prerelease: false
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
26
.github/workflows/golangci-lint.yml
vendored
Normal file
26
.github/workflows/golangci-lint.yml
vendored
Normal file
@ -0,0 +1,26 @@
|
||||
name: golangci-lint
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
- master
|
||||
pull_request:
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: read
|
||||
|
||||
jobs:
|
||||
golangci:
|
||||
name: lint
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: stable
|
||||
- name: golangci-lint
|
||||
uses: golangci/golangci-lint-action@v7
|
||||
with:
|
||||
version: v2.0
|
||||
only-new-issues: true
|
22
.github/workflows/reviewdog.yml
vendored
22
.github/workflows/reviewdog.yml
vendored
@ -1,22 +0,0 @@
|
||||
name: reviewdog
|
||||
on: [pull_request]
|
||||
jobs:
|
||||
golangci-lint:
|
||||
name: runner / golangci-lint
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Check out code into the Go module directory
|
||||
uses: actions/checkout@v4
|
||||
- name: golangci-lint
|
||||
uses: reviewdog/action-golangci-lint@v2
|
||||
|
||||
- name: Setup reviewdog
|
||||
uses: reviewdog/action-setup@v1
|
||||
|
||||
- name: gofumpt -s with reviewdog
|
||||
env:
|
||||
REVIEWDOG_GITHUB_API_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: |
|
||||
go install mvdan.cc/gofumpt@v0.2.0
|
||||
gofumpt -e -d . | \
|
||||
reviewdog -name="gofumpt" -f=diff -f.diff.strip=0 -reporter=github-pr-review
|
96
.github/workflows/tests.yml
vendored
96
.github/workflows/tests.yml
vendored
@ -16,7 +16,7 @@ jobs:
|
||||
sqlite:
|
||||
strategy:
|
||||
matrix:
|
||||
go: ['1.21', '1.20', '1.19']
|
||||
go: ['1.23', '1.24']
|
||||
platform: [ubuntu-latest] # can not run in windows OS
|
||||
runs-on: ${{ matrix.platform }}
|
||||
|
||||
@ -41,8 +41,8 @@ jobs:
|
||||
mysql:
|
||||
strategy:
|
||||
matrix:
|
||||
dbversion: ['mysql:latest', 'mysql:5.7']
|
||||
go: ['1.21', '1.20', '1.19']
|
||||
dbversion: ['mysql:9', 'mysql:8', 'mysql:5.7']
|
||||
go: ['1.23', '1.24']
|
||||
platform: [ubuntu-latest]
|
||||
runs-on: ${{ matrix.platform }}
|
||||
|
||||
@ -85,7 +85,7 @@ jobs:
|
||||
strategy:
|
||||
matrix:
|
||||
dbversion: [ 'mariadb:latest' ]
|
||||
go: ['1.21', '1.20', '1.19']
|
||||
go: ['1.23', '1.24']
|
||||
platform: [ ubuntu-latest ]
|
||||
runs-on: ${{ matrix.platform }}
|
||||
|
||||
@ -127,8 +127,8 @@ jobs:
|
||||
postgres:
|
||||
strategy:
|
||||
matrix:
|
||||
dbversion: ['postgres:latest', 'postgres:13', 'postgres:12', 'postgres:11', 'postgres:10']
|
||||
go: ['1.21', '1.20', '1.19']
|
||||
dbversion: ['postgres:latest', 'postgres:15', 'postgres:14', 'postgres:13']
|
||||
go: ['1.23', '1.24']
|
||||
platform: [ubuntu-latest] # can not run in macOS and Windows
|
||||
runs-on: ${{ matrix.platform }}
|
||||
|
||||
@ -170,23 +170,21 @@ jobs:
|
||||
sqlserver:
|
||||
strategy:
|
||||
matrix:
|
||||
go: ['1.21', '1.20', '1.19']
|
||||
go: ['1.23', '1.24']
|
||||
platform: [ubuntu-latest] # can not run test in macOS and windows
|
||||
runs-on: ${{ matrix.platform }}
|
||||
|
||||
services:
|
||||
mssql:
|
||||
image: mcmoe/mssqldocker:latest
|
||||
image: mcr.microsoft.com/mssql/server:2022-latest
|
||||
env:
|
||||
TZ: Asia/Shanghai
|
||||
ACCEPT_EULA: Y
|
||||
SA_PASSWORD: LoremIpsum86
|
||||
MSSQL_DB: gorm
|
||||
MSSQL_USER: gorm
|
||||
MSSQL_PASSWORD: LoremIpsum86
|
||||
MSSQL_SA_PASSWORD: LoremIpsum86
|
||||
ports:
|
||||
- 9930:1433
|
||||
options: >-
|
||||
--health-cmd="/opt/mssql-tools/bin/sqlcmd -S localhost -U sa -P LoremIpsum86 -l 30 -Q \"SELECT 1\" || exit 1"
|
||||
--health-cmd="/opt/mssql-tools18/bin/sqlcmd -S localhost -U sa -P ${MSSQL_SA_PASSWORD} -N -C -l 30 -Q \"SELECT 1\" || exit 1"
|
||||
--health-start-period 10s
|
||||
--health-interval 10s
|
||||
--health-timeout 5s
|
||||
@ -208,13 +206,13 @@ jobs:
|
||||
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
|
||||
|
||||
- name: Tests
|
||||
run: GITHUB_ACTION=true GORM_DIALECT=sqlserver GORM_DSN="sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm" ./tests/tests_all.sh
|
||||
run: GITHUB_ACTION=true GORM_DIALECT=sqlserver GORM_DSN="sqlserver://sa:LoremIpsum86@localhost:9930?database=master" ./tests/tests_all.sh
|
||||
|
||||
tidb:
|
||||
strategy:
|
||||
matrix:
|
||||
dbversion: [ 'v6.5.0' ]
|
||||
go: ['1.21', '1.20', '1.19']
|
||||
go: ['1.23', '1.24']
|
||||
platform: [ ubuntu-latest ]
|
||||
runs-on: ${{ matrix.platform }}
|
||||
|
||||
@ -242,3 +240,71 @@ jobs:
|
||||
|
||||
- name: Tests
|
||||
run: GITHUB_ACTION=true GORM_DIALECT=tidb GORM_DSN="root:@tcp(localhost:9940)/test?charset=utf8&parseTime=True&loc=Local" ./tests/tests_all.sh
|
||||
|
||||
gaussdb:
|
||||
strategy:
|
||||
matrix:
|
||||
dbversion: ['opengauss/opengauss:7.0.0-RC1.B023']
|
||||
go: ['1.23', '1.24']
|
||||
platform: [ubuntu-latest] # can not run in macOS and Windows
|
||||
runs-on: ${{ matrix.platform }}
|
||||
|
||||
services:
|
||||
gaussdb:
|
||||
image: ${{ matrix.dbversion }}
|
||||
env:
|
||||
# GaussDB has password limitations
|
||||
GS_PASSWORD: Gaussdb@123
|
||||
TZ: Asia/Shanghai
|
||||
ports:
|
||||
- 9950:5432
|
||||
|
||||
steps:
|
||||
- name: Set up Go 1.x
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: ${{ matrix.go }}
|
||||
|
||||
- name: Check out code into the Go module directory
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Waiting for GaussDB to be ready
|
||||
run: |
|
||||
container_name=$(docker ps --filter "ancestor=opengauss/opengauss:7.0.0-RC1.B023" --format "{{.Names}}")
|
||||
if [ -z "$container_name" ]; then
|
||||
echo "Error: failed to find a container created from the 'opengauss/opengauss:7.0.0-RC1.B023' image."
|
||||
exit 1
|
||||
fi
|
||||
max_retries=12
|
||||
retry_count=0
|
||||
if [ -t 0 ]; then
|
||||
TTY_FLAG="-t"
|
||||
else
|
||||
TTY_FLAG=""
|
||||
fi
|
||||
while [ $retry_count -lt $max_retries ]; do
|
||||
if docker exec -i "${container_name}" bash -c "su - omm -c 'gsql -U omm -c \"select 1;\"'"
|
||||
then
|
||||
echo "Creating database gorm..."
|
||||
sql_file='/tmp/create_database.sql'
|
||||
echo "CREATE DATABASE gorm DBCOMPATIBILITY 'PG';" > ${sql_file}
|
||||
docker cp "${sql_file}" "${container_name}":"${sql_file}"
|
||||
docker exec -i ${TTY_FLAG} "${container_name}" bash -c "su - omm -c 'gsql -U omm -f ${sql_file}'"
|
||||
echo "Database initialization completed."
|
||||
break
|
||||
fi
|
||||
|
||||
echo "Waiting for database to be ready... (attempt $((retry_count + 1))/$max_retries)"
|
||||
sleep 10
|
||||
((++retry_count))
|
||||
done
|
||||
exit 0
|
||||
|
||||
- name: go mod package cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/go/pkg/mod
|
||||
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
|
||||
|
||||
- name: Tests
|
||||
run: GITHUB_ACTION=true GORM_DIALECT=gaussdb GORM_DSN="user=gaussdb password=Gaussdb@123 dbname=gorm host=localhost port=9950 sslmode=disable TimeZone=Asia/Shanghai" ./tests/tests_all.sh
|
@ -1,7 +1,9 @@
|
||||
version: "2"
|
||||
|
||||
linters:
|
||||
default: standard
|
||||
enable:
|
||||
- cyclop
|
||||
- exportloopref
|
||||
- gocritic
|
||||
- gosec
|
||||
- ineffassign
|
||||
@ -9,12 +11,9 @@ linters:
|
||||
- prealloc
|
||||
- unconvert
|
||||
- unparam
|
||||
- goimports
|
||||
- whitespace
|
||||
|
||||
linters-settings:
|
||||
whitespace:
|
||||
multi-func: true
|
||||
goimports:
|
||||
local-prefixes: gorm.io/gorm
|
||||
|
||||
formatters:
|
||||
enable:
|
||||
- gofumpt
|
||||
- goimports
|
||||
|
128
CODE_OF_CONDUCT.md
Normal file
128
CODE_OF_CONDUCT.md
Normal file
@ -0,0 +1,128 @@
|
||||
# Contributor Covenant Code of Conduct
|
||||
|
||||
## Our Pledge
|
||||
|
||||
We as members, contributors, and leaders pledge to participate in our
|
||||
community a harassment-free experience for everyone, regardless of age, body
|
||||
size, visible or invisible disability, ethnicity, sex characteristics, gender
|
||||
identity and expression, level of experience, education, socio-economic status,
|
||||
nationality, personal appearance, race, religion, or sexual identity
|
||||
and orientation.
|
||||
|
||||
We pledge to act and interact in ways that contribute to an open, welcoming,
|
||||
diverse, inclusive, and healthy community.
|
||||
|
||||
## Our Standards
|
||||
|
||||
Examples of behavior that contributes to a positive environment for our
|
||||
community includes:
|
||||
|
||||
* Demonstrating empathy and kindness toward other people
|
||||
* Being respectful of differing opinions, viewpoints, and experiences
|
||||
* Giving and gracefully accepting constructive feedback
|
||||
* Accepting responsibility and apologizing to those affected by our mistakes,
|
||||
and learning from the experience
|
||||
* Focusing on what is best not just for us as individuals, but for the
|
||||
overall community
|
||||
|
||||
Examples of unacceptable behavior include:
|
||||
|
||||
* The use of sexualized language or imagery, and sexual attention or
|
||||
advances of any kind
|
||||
* Trolling, insulting or derogatory comments, and personal or political attacks
|
||||
* Public or private harassment
|
||||
* Publishing others' private information, such as a physical or email
|
||||
address, without their explicit permission
|
||||
* Other conduct which could reasonably be considered inappropriate in a
|
||||
professional setting
|
||||
|
||||
## Enforcement Responsibilities
|
||||
|
||||
Community leaders are responsible for clarifying and enforcing our standards of
|
||||
acceptable behavior and will take appropriate and fair corrective action in
|
||||
response to any behavior that they deem inappropriate, threatening, offensive,
|
||||
or harmful.
|
||||
|
||||
Community leaders have the right and responsibility to remove, edit, or reject
|
||||
comments, commits, code, wiki edits, issues, and other contributions that are
|
||||
not aligned to this Code of Conduct, and will communicate reasons for moderation
|
||||
decisions when appropriate.
|
||||
|
||||
## Scope
|
||||
|
||||
This Code of Conduct applies within all community spaces and also applies when
|
||||
an individual is officially representing the community in public spaces.
|
||||
Examples of representing our community include using an official e-mail address,
|
||||
posting via an official social media account, or acting as an appointed
|
||||
representative at an online or offline event.
|
||||
|
||||
## Enforcement
|
||||
|
||||
Instances of abusive, harassing, or otherwise unacceptable behavior may be
|
||||
reported to the community leaders responsible for enforcement at
|
||||
.
|
||||
All complaints will be reviewed and investigated promptly and fairly.
|
||||
|
||||
All community leaders are obligated to respect the privacy and security of the
|
||||
reporter of any incident.
|
||||
|
||||
## Enforcement Guidelines
|
||||
|
||||
Community leaders will follow these Community Impact Guidelines in determining
|
||||
the consequences for any action they deem in violation of this Code of Conduct:
|
||||
|
||||
### 1. Correction
|
||||
|
||||
**Community Impact**: Use of inappropriate language or other behavior deemed
|
||||
unprofessional or unwelcome in the community.
|
||||
|
||||
**Consequence**: A private, written warning from community leaders, providing
|
||||
clarity around the nature of the violation and an explanation of why the
|
||||
behavior was inappropriate. A public apology may be requested.
|
||||
|
||||
### 2. Warning
|
||||
|
||||
**Community Impact**: A violation through a single incident or series
|
||||
of actions.
|
||||
|
||||
**Consequence**: A warning with consequences for continued behavior. No
|
||||
interaction with the people involved, including unsolicited interaction with
|
||||
those enforcing the Code of Conduct, for a specified period. This
|
||||
includes avoiding interactions in community spaces and external channels
|
||||
like social media. Violating these terms may lead to a temporary or
|
||||
permanent ban.
|
||||
|
||||
### 3. Temporary Ban
|
||||
|
||||
**Community Impact**: A serious violation of community standards, including
|
||||
sustained inappropriate behavior.
|
||||
|
||||
**Consequence**: A temporary ban from any interaction or public
|
||||
communication with the community for a specified period. No public or
|
||||
private interaction with the people involved, including unsolicited interaction
|
||||
with those enforcing the Code of Conduct, is allowed during this period.
|
||||
Violating these terms may lead to a permanent ban.
|
||||
|
||||
### 4. Permanent Ban
|
||||
|
||||
**Community Impact**: Demonstrating a pattern of violation of community
|
||||
standards, including sustained inappropriate behavior, harassment of an
|
||||
individual, or aggression toward or disparagement of classes of individuals.
|
||||
|
||||
**Consequence**: A permanent ban from any sort of public interaction within
|
||||
the community.
|
||||
|
||||
## Attribution
|
||||
|
||||
This Code of Conduct is adapted from the [Contributor Covenant][homepage],
|
||||
version 2.0, available at
|
||||
https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
|
||||
|
||||
Community Impact Guidelines were inspired by [Mozilla's code of conduct
|
||||
enforcement ladder](https://github.com/mozilla/diversity).
|
||||
|
||||
[homepage]: https://www.contributor-covenant.org
|
||||
|
||||
For answers to common questions about this code of conduct, see the FAQ at
|
||||
https://www.contributor-covenant.org/faq. Translations are available at
|
||||
https://www.contributor-covenant.org/translations.
|
2
LICENSE
2
LICENSE
@ -1,6 +1,6 @@
|
||||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2013-NOW Jinzhu <wosmvp@gmail.com>
|
||||
Copyright (c) 2013-present Jinzhu <wosmvp@gmail.com>
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
|
@ -396,6 +396,10 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
|
||||
}
|
||||
}
|
||||
case reflect.Struct:
|
||||
if !rv.CanAddr() {
|
||||
association.Error = ErrInvalidValue
|
||||
return
|
||||
}
|
||||
association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, source, rv.Addr().Interface())
|
||||
|
||||
if association.Relationship.Field.FieldType.Kind() == reflect.Struct {
|
||||
@ -433,6 +437,10 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
|
||||
appendToFieldValues(reflect.Indirect(rv.Index(i)).Addr())
|
||||
}
|
||||
case reflect.Struct:
|
||||
if !rv.CanAddr() {
|
||||
association.Error = ErrInvalidValue
|
||||
return
|
||||
}
|
||||
appendToFieldValues(rv.Addr())
|
||||
}
|
||||
|
||||
@ -510,6 +518,9 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
|
||||
|
||||
for i := 0; i < reflectValue.Len(); i++ {
|
||||
appendToRelations(reflectValue.Index(i), reflect.Indirect(reflect.ValueOf(values[i])), clear)
|
||||
if association.Error != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// TODO support save slice data, sql with case?
|
||||
association.Error = associationDB.Updates(reflectValue.Index(i).Addr().Interface()).Error
|
||||
@ -531,6 +542,9 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
|
||||
for idx, value := range values {
|
||||
rv := reflect.Indirect(reflect.ValueOf(value))
|
||||
appendToRelations(reflectValue, rv, clear && idx == 0)
|
||||
if association.Error != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if len(values) > 0 {
|
||||
|
19
callbacks.go
19
callbacks.go
@ -187,10 +187,18 @@ func (p *processor) Replace(name string, fn func(*DB)) error {
|
||||
|
||||
func (p *processor) compile() (err error) {
|
||||
var callbacks []*callback
|
||||
removedMap := map[string]bool{}
|
||||
for _, callback := range p.callbacks {
|
||||
if callback.match == nil || callback.match(p.db) {
|
||||
callbacks = append(callbacks, callback)
|
||||
}
|
||||
if callback.remove {
|
||||
removedMap[callback.name] = true
|
||||
}
|
||||
}
|
||||
|
||||
if len(removedMap) > 0 {
|
||||
callbacks = removeCallbacks(callbacks, removedMap)
|
||||
}
|
||||
p.callbacks = callbacks
|
||||
|
||||
@ -339,3 +347,14 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func removeCallbacks(cs []*callback, nameMap map[string]bool) []*callback {
|
||||
callbacks := make([]*callback, 0, len(cs))
|
||||
for _, callback := range cs {
|
||||
if nameMap[callback.name] {
|
||||
continue
|
||||
}
|
||||
callbacks = append(callbacks, callback)
|
||||
}
|
||||
return callbacks
|
||||
}
|
||||
|
@ -47,7 +47,7 @@ func SaveBeforeAssociations(create bool) func(db *gorm.DB) {
|
||||
)
|
||||
|
||||
if !isPtr {
|
||||
fieldType = reflect.PtrTo(fieldType)
|
||||
fieldType = reflect.PointerTo(fieldType)
|
||||
}
|
||||
|
||||
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
|
||||
@ -126,7 +126,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
|
||||
)
|
||||
|
||||
if !isPtr {
|
||||
fieldType = reflect.PtrTo(fieldType)
|
||||
fieldType = reflect.PointerTo(fieldType)
|
||||
}
|
||||
|
||||
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
|
||||
@ -195,7 +195,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
|
||||
fieldType := rel.Field.IndirectFieldType.Elem()
|
||||
isPtr := fieldType.Kind() == reflect.Ptr
|
||||
if !isPtr {
|
||||
fieldType = reflect.PtrTo(fieldType)
|
||||
fieldType = reflect.PointerTo(fieldType)
|
||||
}
|
||||
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
|
||||
identityMap := map[string]bool{}
|
||||
@ -268,11 +268,11 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
|
||||
fieldType := rel.Field.IndirectFieldType.Elem()
|
||||
isPtr := fieldType.Kind() == reflect.Ptr
|
||||
if !isPtr {
|
||||
fieldType = reflect.PtrTo(fieldType)
|
||||
fieldType = reflect.PointerTo(fieldType)
|
||||
}
|
||||
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
|
||||
distinctElems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
|
||||
joins := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(rel.JoinTable.ModelType)), 0, 10)
|
||||
joins := reflect.MakeSlice(reflect.SliceOf(reflect.PointerTo(rel.JoinTable.ModelType)), 0, 10)
|
||||
objs := []reflect.Value{}
|
||||
|
||||
appendToJoins := func(obj reflect.Value, elem reflect.Value) {
|
||||
|
@ -53,9 +53,13 @@ func Create(config *Config) func(db *gorm.DB) {
|
||||
if _, ok := db.Statement.Clauses["RETURNING"]; !ok {
|
||||
fromColumns := make([]clause.Column, 0, len(db.Statement.Schema.FieldsWithDefaultDBValue))
|
||||
for _, field := range db.Statement.Schema.FieldsWithDefaultDBValue {
|
||||
fromColumns = append(fromColumns, clause.Column{Name: field.DBName})
|
||||
if field.Readable {
|
||||
fromColumns = append(fromColumns, clause.Column{Name: field.DBName})
|
||||
}
|
||||
}
|
||||
if len(fromColumns) > 0 {
|
||||
db.Statement.AddClause(clause.Returning{Columns: fromColumns})
|
||||
}
|
||||
db.Statement.AddClause(clause.Returning{Columns: fromColumns})
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -89,6 +93,10 @@ func Create(config *Config) func(db *gorm.DB) {
|
||||
db.AddError(rows.Close())
|
||||
}()
|
||||
gorm.Scan(rows, db, mode)
|
||||
|
||||
if db.Statement.Result != nil {
|
||||
db.Statement.Result.RowsAffected = db.RowsAffected
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
@ -103,6 +111,12 @@ func Create(config *Config) func(db *gorm.DB) {
|
||||
}
|
||||
|
||||
db.RowsAffected, _ = result.RowsAffected()
|
||||
|
||||
if db.Statement.Result != nil {
|
||||
db.Statement.Result.Result = result
|
||||
db.Statement.Result.RowsAffected = db.RowsAffected
|
||||
}
|
||||
|
||||
if db.RowsAffected == 0 {
|
||||
return
|
||||
}
|
||||
@ -111,8 +125,11 @@ func Create(config *Config) func(db *gorm.DB) {
|
||||
pkField *schema.Field
|
||||
pkFieldName = "@id"
|
||||
)
|
||||
|
||||
if db.Statement.Schema != nil {
|
||||
if db.Statement.Schema.PrioritizedPrimaryField == nil || !db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue {
|
||||
if db.Statement.Schema.PrioritizedPrimaryField == nil ||
|
||||
!db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue ||
|
||||
!db.Statement.Schema.PrioritizedPrimaryField.Readable {
|
||||
return
|
||||
}
|
||||
pkField = db.Statement.Schema.PrioritizedPrimaryField
|
||||
@ -121,8 +138,11 @@ func Create(config *Config) func(db *gorm.DB) {
|
||||
|
||||
insertID, err := result.LastInsertId()
|
||||
insertOk := err == nil && insertID > 0
|
||||
|
||||
if !insertOk {
|
||||
db.AddError(err)
|
||||
if !supportReturning {
|
||||
db.AddError(err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@ -142,6 +162,11 @@ func Create(config *Config) func(db *gorm.DB) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if config.LastInsertIDReversed {
|
||||
insertID -= int64(len(mapValues)-1) * schema.DefaultAutoIncrementIncrement
|
||||
}
|
||||
|
||||
for _, mapValue := range mapValues {
|
||||
if mapValue != nil {
|
||||
mapValue[pkFieldName] = insertID
|
||||
@ -293,13 +318,15 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
|
||||
}
|
||||
}
|
||||
|
||||
for field, vs := range defaultValueFieldsHavingValue {
|
||||
values.Columns = append(values.Columns, clause.Column{Name: field.DBName})
|
||||
for idx := range values.Values {
|
||||
if vs[idx] == nil {
|
||||
values.Values[idx] = append(values.Values[idx], stmt.Dialector.DefaultValueOf(field))
|
||||
} else {
|
||||
values.Values[idx] = append(values.Values[idx], vs[idx])
|
||||
for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
|
||||
if vs, ok := defaultValueFieldsHavingValue[field]; ok {
|
||||
values.Columns = append(values.Columns, clause.Column{Name: field.DBName})
|
||||
for idx := range values.Values {
|
||||
if vs[idx] == nil {
|
||||
values.Values[idx] = append(values.Values[idx], stmt.DefaultValueOf(field))
|
||||
} else {
|
||||
values.Values[idx] = append(values.Values[idx], vs[idx])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -322,7 +349,7 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
|
||||
}
|
||||
|
||||
for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
|
||||
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
|
||||
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) && field.DefaultValueInterface == nil {
|
||||
if rvOfvalue, isZero := field.ValueOf(stmt.Context, stmt.ReflectValue); !isZero {
|
||||
values.Columns = append(values.Columns, clause.Column{Name: field.DBName})
|
||||
values.Values[0] = append(values.Values[0], rvOfvalue)
|
||||
@ -351,7 +378,7 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
|
||||
case schema.UnixNanosecond:
|
||||
assignment.Value = curTime.UnixNano()
|
||||
case schema.UnixMillisecond:
|
||||
assignment.Value = curTime.UnixNano() / 1e6
|
||||
assignment.Value = curTime.UnixMilli()
|
||||
case schema.UnixSecond:
|
||||
assignment.Value = curTime.Unix()
|
||||
}
|
||||
|
71
callbacks/create_test.go
Normal file
71
callbacks/create_test.go
Normal file
@ -0,0 +1,71 @@
|
||||
package callbacks
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/schema"
|
||||
)
|
||||
|
||||
var schemaCache = &sync.Map{}
|
||||
|
||||
func TestConvertToCreateValues_DestType_Slice(t *testing.T) {
|
||||
type user struct {
|
||||
ID int `gorm:"primaryKey"`
|
||||
Name string
|
||||
Email string `gorm:"default:(-)"`
|
||||
Age int `gorm:"default:(-)"`
|
||||
}
|
||||
|
||||
s, err := schema.Parse(&user{}, schemaCache, schema.NamingStrategy{})
|
||||
if err != nil {
|
||||
t.Errorf("parse schema error: %v, is not expected", err)
|
||||
return
|
||||
}
|
||||
dest := []*user{
|
||||
{
|
||||
ID: 1,
|
||||
Name: "alice",
|
||||
Email: "email",
|
||||
Age: 18,
|
||||
},
|
||||
{
|
||||
ID: 2,
|
||||
Name: "bob",
|
||||
Email: "email",
|
||||
Age: 19,
|
||||
},
|
||||
}
|
||||
stmt := &gorm.Statement{
|
||||
DB: &gorm.DB{
|
||||
Config: &gorm.Config{
|
||||
NowFunc: func() time.Time { return time.Time{} },
|
||||
},
|
||||
Statement: &gorm.Statement{
|
||||
Settings: sync.Map{},
|
||||
Schema: s,
|
||||
},
|
||||
},
|
||||
ReflectValue: reflect.ValueOf(dest),
|
||||
Dest: dest,
|
||||
}
|
||||
|
||||
stmt.Schema = s
|
||||
|
||||
values := ConvertToCreateValues(stmt)
|
||||
expected := clause.Values{
|
||||
// column has value + defaultValue column has value (which should have a stable order)
|
||||
Columns: []clause.Column{{Name: "name"}, {Name: "email"}, {Name: "age"}, {Name: "id"}},
|
||||
Values: [][]interface{}{
|
||||
{"alice", "email", 18, 1},
|
||||
{"bob", "email", 19, 2},
|
||||
},
|
||||
}
|
||||
if !reflect.DeepEqual(expected, values) {
|
||||
t.Errorf("expected: %v got %v", expected, values)
|
||||
}
|
||||
}
|
@ -157,8 +157,14 @@ func Delete(config *Config) func(db *gorm.DB) {
|
||||
ok, mode := hasReturning(db, supportReturning)
|
||||
if !ok {
|
||||
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||
|
||||
if db.AddError(err) == nil {
|
||||
db.RowsAffected, _ = result.RowsAffected()
|
||||
|
||||
if db.Statement.Result != nil {
|
||||
db.Statement.Result.Result = result
|
||||
db.Statement.Result.RowsAffected = db.RowsAffected
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
@ -166,6 +172,10 @@ func Delete(config *Config) func(db *gorm.DB) {
|
||||
|
||||
if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil {
|
||||
gorm.Scan(rows, db, mode)
|
||||
|
||||
if db.Statement.Result != nil {
|
||||
db.Statement.Result.RowsAffected = db.RowsAffected
|
||||
}
|
||||
db.AddError(rows.Close())
|
||||
}
|
||||
}
|
||||
|
@ -75,7 +75,7 @@ func embeddedValues(embeddedRelations *schema.Relationships) []string {
|
||||
names := make([]string, 0, len(embeddedRelations.Relations)+len(embeddedRelations.EmbeddedRelations))
|
||||
for _, relation := range embeddedRelations.Relations {
|
||||
// skip first struct name
|
||||
names = append(names, strings.Join(relation.Field.BindNames[1:], "."))
|
||||
names = append(names, strings.Join(relation.Field.EmbeddedBindNames[1:], "."))
|
||||
}
|
||||
for _, relations := range embeddedRelations.EmbeddedRelations {
|
||||
names = append(names, embeddedValues(relations)...)
|
||||
@ -103,11 +103,11 @@ func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relati
|
||||
joined = true
|
||||
continue
|
||||
}
|
||||
joinNames := strings.SplitN(join, ".", 2)
|
||||
if len(joinNames) == 2 {
|
||||
if _, ok := relationships.Relations[joinNames[0]]; ok && name == joinNames[0] {
|
||||
join0, join1, cut := strings.Cut(join, ".")
|
||||
if cut {
|
||||
if _, ok := relationships.Relations[join0]; ok && name == join0 {
|
||||
joined = true
|
||||
nestedJoins = append(nestedJoins, joinNames[1])
|
||||
nestedJoins = append(nestedJoins, join1)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -123,14 +123,26 @@ func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relati
|
||||
if joined, nestedJoins := isJoined(name); joined {
|
||||
switch rv := db.Statement.ReflectValue; rv.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
for i := 0; i < rv.Len(); i++ {
|
||||
reflectValue := rel.Field.ReflectValueOf(db.Statement.Context, rv.Index(i))
|
||||
if rv.Len() > 0 {
|
||||
reflectValue := rel.FieldSchema.MakeSlice().Elem()
|
||||
for i := 0; i < rv.Len(); i++ {
|
||||
frv := rel.Field.ReflectValueOf(db.Statement.Context, rv.Index(i))
|
||||
if frv.Kind() != reflect.Ptr {
|
||||
reflectValue = reflect.Append(reflectValue, frv.Addr())
|
||||
} else {
|
||||
if frv.IsNil() {
|
||||
continue
|
||||
}
|
||||
reflectValue = reflect.Append(reflectValue, frv)
|
||||
}
|
||||
}
|
||||
|
||||
tx := preloadDB(db, reflectValue, reflectValue.Interface())
|
||||
if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
case reflect.Struct:
|
||||
case reflect.Struct, reflect.Pointer:
|
||||
reflectValue := rel.Field.ReflectValueOf(db.Statement.Context, rv)
|
||||
tx := preloadDB(db, reflectValue, reflectValue.Interface())
|
||||
if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil {
|
||||
@ -263,6 +275,8 @@ func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
|
||||
column, values := schema.ToQueryValues(clause.CurrentTable, relForeignKeys, foreignValues)
|
||||
|
||||
if len(values) != 0 {
|
||||
tx = tx.Model(reflectResults.Addr().Interface()).Where(clause.IN{Column: column, Values: values})
|
||||
|
||||
for _, cond := range conds {
|
||||
if fc, ok := cond.(func(*gorm.DB) *gorm.DB); ok {
|
||||
tx = fc(tx)
|
||||
@ -271,7 +285,11 @@ func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
|
||||
}
|
||||
}
|
||||
|
||||
if err := tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), inlineConds...).Error; err != nil {
|
||||
if len(inlineConds) > 0 {
|
||||
tx = tx.Where(inlineConds[0], inlineConds[1:]...)
|
||||
}
|
||||
|
||||
if err := tx.Find(reflectResults.Addr().Interface()).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
@ -25,6 +25,10 @@ func Query(db *gorm.DB) {
|
||||
db.AddError(rows.Close())
|
||||
}()
|
||||
gorm.Scan(rows, db, 0)
|
||||
|
||||
if db.Statement.Result != nil {
|
||||
db.Statement.Result.RowsAffected = db.RowsAffected
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -110,7 +114,7 @@ func BuildQuerySQL(db *gorm.DB) {
|
||||
}
|
||||
}
|
||||
|
||||
specifiedRelationsName := make(map[string]interface{})
|
||||
specifiedRelationsName := map[string]string{clause.CurrentTable: clause.CurrentTable}
|
||||
for _, join := range db.Statement.Joins {
|
||||
if db.Statement.Schema != nil {
|
||||
var isRelations bool // is relations or raw sql
|
||||
@ -124,12 +128,12 @@ func BuildQuerySQL(db *gorm.DB) {
|
||||
nestedJoinNames := strings.Split(join.Name, ".")
|
||||
if len(nestedJoinNames) > 1 {
|
||||
isNestedJoin := true
|
||||
gussNestedRelations := make([]*schema.Relationship, 0, len(nestedJoinNames))
|
||||
guessNestedRelations := make([]*schema.Relationship, 0, len(nestedJoinNames))
|
||||
currentRelations := db.Statement.Schema.Relationships.Relations
|
||||
for _, relname := range nestedJoinNames {
|
||||
// incomplete match, only treated as raw sql
|
||||
if relation, ok = currentRelations[relname]; ok {
|
||||
gussNestedRelations = append(gussNestedRelations, relation)
|
||||
guessNestedRelations = append(guessNestedRelations, relation)
|
||||
currentRelations = relation.FieldSchema.Relationships.Relations
|
||||
} else {
|
||||
isNestedJoin = false
|
||||
@ -139,18 +143,13 @@ func BuildQuerySQL(db *gorm.DB) {
|
||||
|
||||
if isNestedJoin {
|
||||
isRelations = true
|
||||
relations = gussNestedRelations
|
||||
relations = guessNestedRelations
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if isRelations {
|
||||
genJoinClause := func(joinType clause.JoinType, parentTableName string, relation *schema.Relationship) clause.Join {
|
||||
tableAliasName := relation.Name
|
||||
if parentTableName != clause.CurrentTable {
|
||||
tableAliasName = utils.NestedRelationName(parentTableName, tableAliasName)
|
||||
}
|
||||
|
||||
genJoinClause := func(joinType clause.JoinType, tableAliasName string, parentTableName string, relation *schema.Relationship) clause.Join {
|
||||
columnStmt := gorm.Statement{
|
||||
Table: tableAliasName, DB: db, Schema: relation.FieldSchema,
|
||||
Selects: join.Selects, Omits: join.Omits,
|
||||
@ -167,6 +166,13 @@ func BuildQuerySQL(db *gorm.DB) {
|
||||
}
|
||||
}
|
||||
|
||||
if join.Expression != nil {
|
||||
return clause.Join{
|
||||
Type: join.JoinType,
|
||||
Expression: join.Expression,
|
||||
}
|
||||
}
|
||||
|
||||
exprs := make([]clause.Expression, len(relation.References))
|
||||
for idx, ref := range relation.References {
|
||||
if ref.OwnPrimaryKey {
|
||||
@ -226,19 +232,24 @@ func BuildQuerySQL(db *gorm.DB) {
|
||||
}
|
||||
|
||||
parentTableName := clause.CurrentTable
|
||||
for _, rel := range relations {
|
||||
for idx, rel := range relations {
|
||||
// joins table alias like "Manager, Company, Manager__Company"
|
||||
nestedAlias := utils.NestedRelationName(parentTableName, rel.Name)
|
||||
if _, ok := specifiedRelationsName[nestedAlias]; !ok {
|
||||
fromClause.Joins = append(fromClause.Joins, genJoinClause(join.JoinType, parentTableName, rel))
|
||||
specifiedRelationsName[nestedAlias] = nil
|
||||
curAliasName := rel.Name
|
||||
if parentTableName != clause.CurrentTable {
|
||||
curAliasName = utils.NestedRelationName(parentTableName, curAliasName)
|
||||
}
|
||||
|
||||
if parentTableName != clause.CurrentTable {
|
||||
parentTableName = utils.NestedRelationName(parentTableName, rel.Name)
|
||||
} else {
|
||||
parentTableName = rel.Name
|
||||
if _, ok := specifiedRelationsName[curAliasName]; !ok {
|
||||
aliasName := curAliasName
|
||||
if idx == len(relations)-1 && join.Alias != "" {
|
||||
aliasName = join.Alias
|
||||
}
|
||||
|
||||
fromClause.Joins = append(fromClause.Joins, genJoinClause(join.JoinType, aliasName, specifiedRelationsName[parentTableName], rel))
|
||||
specifiedRelationsName[curAliasName] = aliasName
|
||||
}
|
||||
|
||||
parentTableName = curAliasName
|
||||
}
|
||||
} else {
|
||||
fromClause.Joins = append(fromClause.Joins, clause.Join{
|
||||
@ -286,7 +297,11 @@ func Preload(db *gorm.DB) {
|
||||
|
||||
func AfterQuery(db *gorm.DB) {
|
||||
// clear the joins after query because preload need it
|
||||
db.Statement.Joins = nil
|
||||
if v, ok := db.Statement.Clauses["FROM"].Expression.(clause.From); ok {
|
||||
fromClause := db.Statement.Clauses["FROM"]
|
||||
fromClause.Expression = clause.From{Tables: v.Tables, Joins: utils.RTrimSlice(v.Joins, len(db.Statement.Joins))} // keep the original From Joins
|
||||
db.Statement.Clauses["FROM"] = fromClause
|
||||
}
|
||||
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.AfterFind && db.RowsAffected > 0 {
|
||||
callMethod(db, func(value interface{}, tx *gorm.DB) bool {
|
||||
if i, ok := value.(AfterFindInterface); ok {
|
||||
|
@ -13,5 +13,10 @@ func RawExec(db *gorm.DB) {
|
||||
}
|
||||
|
||||
db.RowsAffected, _ = result.RowsAffected()
|
||||
|
||||
if db.Statement.Result != nil {
|
||||
db.Statement.Result.Result = result
|
||||
db.Statement.Result.RowsAffected = db.RowsAffected
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -92,6 +92,10 @@ func Update(config *Config) func(db *gorm.DB) {
|
||||
gorm.Scan(rows, db, mode)
|
||||
db.Statement.Dest = dest
|
||||
db.AddError(rows.Close())
|
||||
|
||||
if db.Statement.Result != nil {
|
||||
db.Statement.Result.RowsAffected = db.RowsAffected
|
||||
}
|
||||
}
|
||||
} else {
|
||||
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||
@ -99,6 +103,11 @@ func Update(config *Config) func(db *gorm.DB) {
|
||||
if db.AddError(err) == nil {
|
||||
db.RowsAffected, _ = result.RowsAffected()
|
||||
}
|
||||
|
||||
if db.Statement.Result != nil {
|
||||
db.Statement.Result.Result = result
|
||||
db.Statement.Result.RowsAffected = db.RowsAffected
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -234,7 +243,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
|
||||
if field.AutoUpdateTime == schema.UnixNanosecond {
|
||||
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano()})
|
||||
} else if field.AutoUpdateTime == schema.UnixMillisecond {
|
||||
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano() / 1e6})
|
||||
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixMilli()})
|
||||
} else if field.AutoUpdateTime == schema.UnixSecond {
|
||||
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.Unix()})
|
||||
} else {
|
||||
@ -268,7 +277,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
|
||||
if field.AutoUpdateTime == schema.UnixNanosecond {
|
||||
value = stmt.DB.NowFunc().UnixNano()
|
||||
} else if field.AutoUpdateTime == schema.UnixMillisecond {
|
||||
value = stmt.DB.NowFunc().UnixNano() / 1e6
|
||||
value = stmt.DB.NowFunc().UnixMilli()
|
||||
} else if field.AutoUpdateTime == schema.UnixSecond {
|
||||
value = stmt.DB.NowFunc().Unix()
|
||||
} else {
|
||||
|
@ -185,6 +185,13 @@ func (db *DB) Omit(columns ...string) (tx *DB) {
|
||||
return
|
||||
}
|
||||
|
||||
// MapColumns modify the column names in the query results to facilitate align to the corresponding structural fields
|
||||
func (db *DB) MapColumns(m map[string]string) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.ColumnMapping = m
|
||||
return
|
||||
}
|
||||
|
||||
// Where add conditions
|
||||
//
|
||||
// See the [docs] for details on the various formats that where clauses can take. By default, where clauses chain with AND.
|
||||
@ -299,10 +306,16 @@ func (db *DB) Having(query interface{}, args ...interface{}) (tx *DB) {
|
||||
//
|
||||
// db.Order("name DESC")
|
||||
// db.Order(clause.OrderByColumn{Column: clause.Column{Name: "name"}, Desc: true})
|
||||
// db.Order(clause.OrderBy{Columns: []clause.OrderByColumn{
|
||||
// {Column: clause.Column{Name: "name"}, Desc: true},
|
||||
// {Column: clause.Column{Name: "age"}, Desc: true},
|
||||
// }})
|
||||
func (db *DB) Order(value interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
|
||||
switch v := value.(type) {
|
||||
case clause.OrderBy:
|
||||
tx.Statement.AddClause(v)
|
||||
case clause.OrderByColumn:
|
||||
tx.Statement.AddClause(clause.OrderBy{
|
||||
Columns: []clause.OrderByColumn{v},
|
||||
@ -429,6 +442,16 @@ func (db *DB) Assign(attrs ...interface{}) (tx *DB) {
|
||||
return
|
||||
}
|
||||
|
||||
// Unscoped disables the global scope of soft deletion in a query.
|
||||
// By default, GORM uses soft deletion, marking records as "deleted"
|
||||
// by setting a timestamp on a specific field (e.g., `deleted_at`).
|
||||
// Unscoped allows queries to include records marked as deleted,
|
||||
// overriding the soft deletion behavior.
|
||||
// Example:
|
||||
//
|
||||
// var users []User
|
||||
// db.Unscoped().Find(&users)
|
||||
// // Retrieves all users, including deleted ones.
|
||||
func (db *DB) Unscoped() (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.Unscoped = true
|
||||
|
@ -1,5 +1,7 @@
|
||||
package clause
|
||||
|
||||
import "gorm.io/gorm/utils"
|
||||
|
||||
type JoinType string
|
||||
|
||||
const (
|
||||
@ -9,6 +11,30 @@ const (
|
||||
RightJoin JoinType = "RIGHT"
|
||||
)
|
||||
|
||||
type JoinTarget struct {
|
||||
Type JoinType
|
||||
Association string
|
||||
Subquery Expression
|
||||
Table string
|
||||
}
|
||||
|
||||
func Has(name string) JoinTarget {
|
||||
return JoinTarget{Type: InnerJoin, Association: name}
|
||||
}
|
||||
|
||||
func (jt JoinType) Association(name string) JoinTarget {
|
||||
return JoinTarget{Type: jt, Association: name}
|
||||
}
|
||||
|
||||
func (jt JoinType) AssociationFrom(name string, subquery Expression) JoinTarget {
|
||||
return JoinTarget{Type: jt, Association: name, Subquery: subquery}
|
||||
}
|
||||
|
||||
func (jt JoinTarget) As(name string) JoinTarget {
|
||||
jt.Table = name
|
||||
return jt
|
||||
}
|
||||
|
||||
// Join clause for from
|
||||
type Join struct {
|
||||
Type JoinType
|
||||
@ -18,6 +44,12 @@ type Join struct {
|
||||
Expression Expression
|
||||
}
|
||||
|
||||
func JoinTable(names ...string) Table {
|
||||
return Table{
|
||||
Name: utils.JoinNestedRelationNames(names),
|
||||
}
|
||||
}
|
||||
|
||||
func (join Join) Build(builder Builder) {
|
||||
if join.Expression != nil {
|
||||
join.Expression.Build(builder)
|
||||
|
@ -26,9 +26,12 @@ func (returning Returning) Build(builder Builder) {
|
||||
|
||||
// MergeClause merge order by clauses
|
||||
func (returning Returning) MergeClause(clause *Clause) {
|
||||
if v, ok := clause.Expression.(Returning); ok {
|
||||
returning.Columns = append(v.Columns, returning.Columns...)
|
||||
if v, ok := clause.Expression.(Returning); ok && len(returning.Columns) > 0 {
|
||||
if v.Columns != nil {
|
||||
returning.Columns = append(v.Columns, returning.Columns...)
|
||||
} else {
|
||||
returning.Columns = nil
|
||||
}
|
||||
}
|
||||
|
||||
clause.Expression = returning
|
||||
}
|
||||
|
@ -26,6 +26,22 @@ func TestReturning(t *testing.T) {
|
||||
}},
|
||||
"SELECT * FROM `users` RETURNING `users`.`id`,`name`,`age`", nil,
|
||||
},
|
||||
{
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Returning{
|
||||
[]clause.Column{clause.PrimaryColumn},
|
||||
}, clause.Returning{}, clause.Returning{
|
||||
[]clause.Column{{Name: "name"}, {Name: "age"}},
|
||||
}},
|
||||
"SELECT * FROM `users` RETURNING *", nil,
|
||||
},
|
||||
{
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Returning{
|
||||
[]clause.Column{clause.PrimaryColumn},
|
||||
}, clause.Returning{
|
||||
[]clause.Column{{Name: "name"}, {Name: "age"}},
|
||||
}, clause.Returning{}},
|
||||
"SELECT * FROM `users` RETURNING *", nil,
|
||||
},
|
||||
}
|
||||
|
||||
for idx, result := range results {
|
||||
|
@ -215,7 +215,12 @@ func (not NotConditions) Build(builder Builder) {
|
||||
|
||||
for idx, c := range not.Exprs {
|
||||
if idx > 0 {
|
||||
builder.WriteString(AndWithSpace)
|
||||
switch c.(type) {
|
||||
case OrConditions:
|
||||
builder.WriteString(OrWithSpace)
|
||||
default:
|
||||
builder.WriteString(AndWithSpace)
|
||||
}
|
||||
}
|
||||
|
||||
e, wrapInParentheses := c.(Expr)
|
||||
|
@ -113,6 +113,22 @@ func TestWhere(t *testing.T) {
|
||||
"SELECT * FROM `users` WHERE NOT (`score` <= ? AND `age` <= ?)",
|
||||
[]interface{}{100, 60},
|
||||
},
|
||||
{
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
|
||||
Exprs: []clause.Expression{
|
||||
clause.Not(clause.AndConditions{
|
||||
Exprs: []clause.Expression{
|
||||
clause.Eq{Column: clause.PrimaryColumn, Value: "1"},
|
||||
clause.Gt{Column: "age", Value: 18},
|
||||
}}, clause.OrConditions{
|
||||
Exprs: []clause.Expression{
|
||||
clause.Lt{Column: "score", Value: 100},
|
||||
},
|
||||
}),
|
||||
}}},
|
||||
"SELECT * FROM `users` WHERE NOT ((`users`.`id` = ? AND `age` > ?) OR `score` < ?)",
|
||||
[]interface{}{"1", 18, 100},
|
||||
},
|
||||
}
|
||||
|
||||
for idx, result := range results {
|
||||
|
@ -49,4 +49,6 @@ var (
|
||||
ErrDuplicatedKey = errors.New("duplicated key not allowed")
|
||||
// ErrForeignKeyViolated occurs when there is a foreign key constraint violation
|
||||
ErrForeignKeyViolated = errors.New("violates foreign key constraint")
|
||||
// ErrCheckConstraintViolated occurs when there is a check constraint violation
|
||||
ErrCheckConstraintViolated = errors.New("violates check constraint")
|
||||
)
|
||||
|
@ -1,9 +1,11 @@
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash/maphash"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
@ -623,14 +625,15 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er
|
||||
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
|
||||
// nested transaction
|
||||
if !db.DisableNestedTransaction {
|
||||
err = db.SavePoint(fmt.Sprintf("sp%p", fc)).Error
|
||||
spID := new(maphash.Hash).Sum64()
|
||||
err = db.SavePoint(fmt.Sprintf("sp%d", spID)).Error
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
// Make sure to rollback when panic, Block error or Commit error
|
||||
if panicked || err != nil {
|
||||
db.RollbackTo(fmt.Sprintf("sp%p", fc))
|
||||
db.RollbackTo(fmt.Sprintf("sp%d", spID))
|
||||
}
|
||||
}()
|
||||
}
|
||||
@ -671,11 +674,18 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB {
|
||||
opt = opts[0]
|
||||
}
|
||||
|
||||
ctx := tx.Statement.Context
|
||||
if _, ok := ctx.Deadline(); !ok {
|
||||
if db.Config.DefaultTransactionTimeout > 0 {
|
||||
ctx, _ = context.WithTimeout(ctx, db.Config.DefaultTransactionTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
switch beginner := tx.Statement.ConnPool.(type) {
|
||||
case TxBeginner:
|
||||
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
|
||||
tx.Statement.ConnPool, err = beginner.BeginTx(ctx, opt)
|
||||
case ConnPoolBeginner:
|
||||
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
|
||||
tx.Statement.ConnPool, err = beginner.BeginTx(ctx, opt)
|
||||
default:
|
||||
err = ErrInvalidTransaction
|
||||
}
|
||||
|
605
generics.go
Normal file
605
generics.go
Normal file
@ -0,0 +1,605 @@
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
type result struct {
|
||||
Result sql.Result
|
||||
RowsAffected int64
|
||||
}
|
||||
|
||||
func (info *result) ModifyStatement(stmt *Statement) {
|
||||
stmt.Result = info
|
||||
}
|
||||
|
||||
// Build implements clause.Expression interface
|
||||
func (result) Build(clause.Builder) {
|
||||
}
|
||||
|
||||
func WithResult() *result {
|
||||
return &result{}
|
||||
}
|
||||
|
||||
type Interface[T any] interface {
|
||||
Raw(sql string, values ...interface{}) ExecInterface[T]
|
||||
Exec(ctx context.Context, sql string, values ...interface{}) error
|
||||
CreateInterface[T]
|
||||
}
|
||||
|
||||
type CreateInterface[T any] interface {
|
||||
ChainInterface[T]
|
||||
Table(name string, args ...interface{}) CreateInterface[T]
|
||||
Create(ctx context.Context, r *T) error
|
||||
CreateInBatches(ctx context.Context, r *[]T, batchSize int) error
|
||||
}
|
||||
|
||||
type ChainInterface[T any] interface {
|
||||
ExecInterface[T]
|
||||
Scopes(scopes ...func(db *Statement)) ChainInterface[T]
|
||||
Where(query interface{}, args ...interface{}) ChainInterface[T]
|
||||
Not(query interface{}, args ...interface{}) ChainInterface[T]
|
||||
Or(query interface{}, args ...interface{}) ChainInterface[T]
|
||||
Limit(offset int) ChainInterface[T]
|
||||
Offset(offset int) ChainInterface[T]
|
||||
Joins(query clause.JoinTarget, on func(db JoinBuilder, joinTable clause.Table, curTable clause.Table) error) ChainInterface[T]
|
||||
Preload(association string, query func(db PreloadBuilder) error) ChainInterface[T]
|
||||
Select(query string, args ...interface{}) ChainInterface[T]
|
||||
Omit(columns ...string) ChainInterface[T]
|
||||
MapColumns(m map[string]string) ChainInterface[T]
|
||||
Distinct(args ...interface{}) ChainInterface[T]
|
||||
Group(name string) ChainInterface[T]
|
||||
Having(query interface{}, args ...interface{}) ChainInterface[T]
|
||||
Order(value interface{}) ChainInterface[T]
|
||||
|
||||
Build(builder clause.Builder)
|
||||
|
||||
Delete(ctx context.Context) (rowsAffected int, err error)
|
||||
Update(ctx context.Context, name string, value any) (rowsAffected int, err error)
|
||||
Updates(ctx context.Context, t T) (rowsAffected int, err error)
|
||||
Count(ctx context.Context, column string) (result int64, err error)
|
||||
}
|
||||
|
||||
type ExecInterface[T any] interface {
|
||||
Scan(ctx context.Context, r interface{}) error
|
||||
First(context.Context) (T, error)
|
||||
Last(ctx context.Context) (T, error)
|
||||
Take(context.Context) (T, error)
|
||||
Find(ctx context.Context) ([]T, error)
|
||||
FindInBatches(ctx context.Context, batchSize int, fc func(data []T, batch int) error) error
|
||||
Row(ctx context.Context) *sql.Row
|
||||
Rows(ctx context.Context) (*sql.Rows, error)
|
||||
}
|
||||
|
||||
type JoinBuilder interface {
|
||||
Select(...string) JoinBuilder
|
||||
Omit(...string) JoinBuilder
|
||||
Where(query interface{}, args ...interface{}) JoinBuilder
|
||||
Not(query interface{}, args ...interface{}) JoinBuilder
|
||||
Or(query interface{}, args ...interface{}) JoinBuilder
|
||||
}
|
||||
|
||||
type PreloadBuilder interface {
|
||||
Select(...string) PreloadBuilder
|
||||
Omit(...string) PreloadBuilder
|
||||
Where(query interface{}, args ...interface{}) PreloadBuilder
|
||||
Not(query interface{}, args ...interface{}) PreloadBuilder
|
||||
Or(query interface{}, args ...interface{}) PreloadBuilder
|
||||
Limit(offset int) PreloadBuilder
|
||||
Offset(offset int) PreloadBuilder
|
||||
Order(value interface{}) PreloadBuilder
|
||||
LimitPerRecord(num int) PreloadBuilder
|
||||
}
|
||||
|
||||
type op func(*DB) *DB
|
||||
|
||||
func G[T any](db *DB, opts ...clause.Expression) Interface[T] {
|
||||
v := &g[T]{
|
||||
db: db,
|
||||
ops: make([]op, 0, 5),
|
||||
}
|
||||
|
||||
if len(opts) > 0 {
|
||||
v.ops = append(v.ops, func(db *DB) *DB {
|
||||
return db.Clauses(opts...)
|
||||
})
|
||||
}
|
||||
|
||||
v.createG = &createG[T]{
|
||||
chainG: chainG[T]{
|
||||
execG: execG[T]{g: v},
|
||||
},
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
type g[T any] struct {
|
||||
*createG[T]
|
||||
db *DB
|
||||
ops []op
|
||||
}
|
||||
|
||||
func (g *g[T]) apply(ctx context.Context) *DB {
|
||||
db := g.db
|
||||
if !db.DryRun {
|
||||
db = db.Session(&Session{NewDB: true, Context: ctx}).getInstance()
|
||||
}
|
||||
|
||||
for _, op := range g.ops {
|
||||
db = op(db)
|
||||
}
|
||||
return db
|
||||
}
|
||||
|
||||
func (c *g[T]) Raw(sql string, values ...interface{}) ExecInterface[T] {
|
||||
return execG[T]{g: &g[T]{
|
||||
db: c.db,
|
||||
ops: append(c.ops, func(db *DB) *DB {
|
||||
return db.Raw(sql, values...)
|
||||
}),
|
||||
}}
|
||||
}
|
||||
|
||||
func (c *g[T]) Exec(ctx context.Context, sql string, values ...interface{}) error {
|
||||
return c.apply(ctx).Exec(sql, values...).Error
|
||||
}
|
||||
|
||||
type createG[T any] struct {
|
||||
chainG[T]
|
||||
}
|
||||
|
||||
func (c createG[T]) Table(name string, args ...interface{}) CreateInterface[T] {
|
||||
return createG[T]{c.with(func(db *DB) *DB {
|
||||
return db.Table(name, args...)
|
||||
})}
|
||||
}
|
||||
|
||||
func (c createG[T]) Create(ctx context.Context, r *T) error {
|
||||
return c.g.apply(ctx).Create(r).Error
|
||||
}
|
||||
|
||||
func (c createG[T]) CreateInBatches(ctx context.Context, r *[]T, batchSize int) error {
|
||||
return c.g.apply(ctx).CreateInBatches(r, batchSize).Error
|
||||
}
|
||||
|
||||
type chainG[T any] struct {
|
||||
execG[T]
|
||||
}
|
||||
|
||||
func (c chainG[T]) getInstance() *DB {
|
||||
var r T
|
||||
return c.g.apply(context.Background()).Model(r).getInstance()
|
||||
}
|
||||
|
||||
func (c chainG[T]) with(v op) chainG[T] {
|
||||
return chainG[T]{
|
||||
execG: execG[T]{g: &g[T]{
|
||||
db: c.g.db,
|
||||
ops: append(append([]op(nil), c.g.ops...), v),
|
||||
}},
|
||||
}
|
||||
}
|
||||
|
||||
func (c chainG[T]) Scopes(scopes ...func(db *Statement)) ChainInterface[T] {
|
||||
return c.with(func(db *DB) *DB {
|
||||
for _, fc := range scopes {
|
||||
fc(db.Statement)
|
||||
}
|
||||
return db
|
||||
})
|
||||
}
|
||||
|
||||
func (c chainG[T]) Table(name string, args ...interface{}) ChainInterface[T] {
|
||||
return c.with(func(db *DB) *DB {
|
||||
return db.Table(name, args...)
|
||||
})
|
||||
}
|
||||
|
||||
func (c chainG[T]) Where(query interface{}, args ...interface{}) ChainInterface[T] {
|
||||
return c.with(func(db *DB) *DB {
|
||||
return db.Where(query, args...)
|
||||
})
|
||||
}
|
||||
|
||||
func (c chainG[T]) Not(query interface{}, args ...interface{}) ChainInterface[T] {
|
||||
return c.with(func(db *DB) *DB {
|
||||
return db.Not(query, args...)
|
||||
})
|
||||
}
|
||||
|
||||
func (c chainG[T]) Or(query interface{}, args ...interface{}) ChainInterface[T] {
|
||||
return c.with(func(db *DB) *DB {
|
||||
return db.Or(query, args...)
|
||||
})
|
||||
}
|
||||
|
||||
func (c chainG[T]) Limit(offset int) ChainInterface[T] {
|
||||
return c.with(func(db *DB) *DB {
|
||||
return db.Limit(offset)
|
||||
})
|
||||
}
|
||||
|
||||
func (c chainG[T]) Offset(offset int) ChainInterface[T] {
|
||||
return c.with(func(db *DB) *DB {
|
||||
return db.Offset(offset)
|
||||
})
|
||||
}
|
||||
|
||||
type joinBuilder struct {
|
||||
db *DB
|
||||
}
|
||||
|
||||
func (q *joinBuilder) Where(query interface{}, args ...interface{}) JoinBuilder {
|
||||
q.db.Where(query, args...)
|
||||
return q
|
||||
}
|
||||
|
||||
func (q *joinBuilder) Or(query interface{}, args ...interface{}) JoinBuilder {
|
||||
q.db.Where(query, args...)
|
||||
return q
|
||||
}
|
||||
|
||||
func (q *joinBuilder) Not(query interface{}, args ...interface{}) JoinBuilder {
|
||||
q.db.Where(query, args...)
|
||||
return q
|
||||
}
|
||||
|
||||
func (q *joinBuilder) Select(columns ...string) JoinBuilder {
|
||||
q.db.Select(columns)
|
||||
return q
|
||||
}
|
||||
|
||||
func (q *joinBuilder) Omit(columns ...string) JoinBuilder {
|
||||
q.db.Omit(columns...)
|
||||
return q
|
||||
}
|
||||
|
||||
type preloadBuilder struct {
|
||||
limitPerRecord int
|
||||
db *DB
|
||||
}
|
||||
|
||||
func (q *preloadBuilder) Where(query interface{}, args ...interface{}) PreloadBuilder {
|
||||
q.db.Where(query, args...)
|
||||
return q
|
||||
}
|
||||
|
||||
func (q *preloadBuilder) Or(query interface{}, args ...interface{}) PreloadBuilder {
|
||||
q.db.Where(query, args...)
|
||||
return q
|
||||
}
|
||||
|
||||
func (q *preloadBuilder) Not(query interface{}, args ...interface{}) PreloadBuilder {
|
||||
q.db.Where(query, args...)
|
||||
return q
|
||||
}
|
||||
|
||||
func (q *preloadBuilder) Select(columns ...string) PreloadBuilder {
|
||||
q.db.Select(columns)
|
||||
return q
|
||||
}
|
||||
|
||||
func (q *preloadBuilder) Omit(columns ...string) PreloadBuilder {
|
||||
q.db.Omit(columns...)
|
||||
return q
|
||||
}
|
||||
|
||||
func (q *preloadBuilder) Limit(limit int) PreloadBuilder {
|
||||
q.db.Limit(limit)
|
||||
return q
|
||||
}
|
||||
|
||||
func (q *preloadBuilder) Offset(offset int) PreloadBuilder {
|
||||
q.db.Offset(offset)
|
||||
return q
|
||||
}
|
||||
|
||||
func (q *preloadBuilder) Order(value interface{}) PreloadBuilder {
|
||||
q.db.Order(value)
|
||||
return q
|
||||
}
|
||||
|
||||
func (q *preloadBuilder) LimitPerRecord(num int) PreloadBuilder {
|
||||
q.limitPerRecord = num
|
||||
return q
|
||||
}
|
||||
|
||||
func (c chainG[T]) Joins(jt clause.JoinTarget, on func(db JoinBuilder, joinTable clause.Table, curTable clause.Table) error) ChainInterface[T] {
|
||||
return c.with(func(db *DB) *DB {
|
||||
if jt.Table == "" {
|
||||
jt.Table = clause.JoinTable(strings.Split(jt.Association, ".")...).Name
|
||||
}
|
||||
|
||||
q := joinBuilder{db: db.Session(&Session{NewDB: true, Initialized: true}).Table(jt.Table)}
|
||||
if on != nil {
|
||||
if err := on(&q, clause.Table{Name: jt.Table}, clause.Table{Name: clause.CurrentTable}); err != nil {
|
||||
db.AddError(err)
|
||||
}
|
||||
}
|
||||
|
||||
j := join{
|
||||
Name: jt.Association,
|
||||
Alias: jt.Table,
|
||||
Selects: q.db.Statement.Selects,
|
||||
Omits: q.db.Statement.Omits,
|
||||
JoinType: jt.Type,
|
||||
}
|
||||
|
||||
if where, ok := q.db.Statement.Clauses["WHERE"].Expression.(clause.Where); ok {
|
||||
j.On = &where
|
||||
}
|
||||
|
||||
if jt.Subquery != nil {
|
||||
joinType := j.JoinType
|
||||
if joinType == "" {
|
||||
joinType = clause.LeftJoin
|
||||
}
|
||||
|
||||
if db, ok := jt.Subquery.(interface{ getInstance() *DB }); ok {
|
||||
stmt := db.getInstance().Statement
|
||||
if len(j.Selects) == 0 {
|
||||
j.Selects = stmt.Selects
|
||||
}
|
||||
if len(j.Omits) == 0 {
|
||||
j.Omits = stmt.Omits
|
||||
}
|
||||
}
|
||||
|
||||
expr := clause.NamedExpr{SQL: fmt.Sprintf("%s JOIN (?) AS ?", joinType), Vars: []interface{}{jt.Subquery, clause.Table{Name: j.Alias}}}
|
||||
|
||||
if j.On != nil {
|
||||
expr.SQL += " ON ?"
|
||||
expr.Vars = append(expr.Vars, clause.AndConditions{Exprs: j.On.Exprs})
|
||||
}
|
||||
|
||||
j.Expression = expr
|
||||
}
|
||||
|
||||
db.Statement.Joins = append(db.Statement.Joins, j)
|
||||
sort.Slice(db.Statement.Joins, func(i, j int) bool {
|
||||
return db.Statement.Joins[i].Name < db.Statement.Joins[j].Name
|
||||
})
|
||||
return db
|
||||
})
|
||||
}
|
||||
|
||||
func (c chainG[T]) Select(query string, args ...interface{}) ChainInterface[T] {
|
||||
return c.with(func(db *DB) *DB {
|
||||
return db.Select(query, args...)
|
||||
})
|
||||
}
|
||||
|
||||
func (c chainG[T]) Omit(columns ...string) ChainInterface[T] {
|
||||
return c.with(func(db *DB) *DB {
|
||||
return db.Omit(columns...)
|
||||
})
|
||||
}
|
||||
|
||||
func (c chainG[T]) MapColumns(m map[string]string) ChainInterface[T] {
|
||||
return c.with(func(db *DB) *DB {
|
||||
return db.MapColumns(m)
|
||||
})
|
||||
}
|
||||
|
||||
func (c chainG[T]) Distinct(args ...interface{}) ChainInterface[T] {
|
||||
return c.with(func(db *DB) *DB {
|
||||
return db.Distinct(args...)
|
||||
})
|
||||
}
|
||||
|
||||
func (c chainG[T]) Group(name string) ChainInterface[T] {
|
||||
return c.with(func(db *DB) *DB {
|
||||
return db.Group(name)
|
||||
})
|
||||
}
|
||||
|
||||
func (c chainG[T]) Having(query interface{}, args ...interface{}) ChainInterface[T] {
|
||||
return c.with(func(db *DB) *DB {
|
||||
return db.Having(query, args...)
|
||||
})
|
||||
}
|
||||
|
||||
func (c chainG[T]) Order(value interface{}) ChainInterface[T] {
|
||||
return c.with(func(db *DB) *DB {
|
||||
return db.Order(value)
|
||||
})
|
||||
}
|
||||
|
||||
func (c chainG[T]) Preload(association string, query func(db PreloadBuilder) error) ChainInterface[T] {
|
||||
return c.with(func(db *DB) *DB {
|
||||
return db.Preload(association, func(tx *DB) *DB {
|
||||
q := preloadBuilder{db: tx.getInstance()}
|
||||
if query != nil {
|
||||
if err := query(&q); err != nil {
|
||||
db.AddError(err)
|
||||
}
|
||||
}
|
||||
|
||||
relation, ok := db.Statement.Schema.Relationships.Relations[association]
|
||||
if !ok {
|
||||
if preloadFields := strings.Split(association, "."); len(preloadFields) > 1 {
|
||||
relationships := db.Statement.Schema.Relationships
|
||||
for _, field := range preloadFields {
|
||||
var ok bool
|
||||
relation, ok = relationships.Relations[field]
|
||||
if ok {
|
||||
relationships = relation.FieldSchema.Relationships
|
||||
} else {
|
||||
db.AddError(fmt.Errorf("relation %s not found", association))
|
||||
return nil
|
||||
}
|
||||
}
|
||||
} else {
|
||||
db.AddError(fmt.Errorf("relation %s not found", association))
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
if q.limitPerRecord > 0 {
|
||||
if relation.JoinTable != nil {
|
||||
tx.AddError(fmt.Errorf("many2many relation %s don't support LimitPerRecord", association))
|
||||
return tx
|
||||
}
|
||||
|
||||
refColumns := []clause.Column{}
|
||||
for _, rel := range relation.References {
|
||||
if rel.OwnPrimaryKey {
|
||||
refColumns = append(refColumns, clause.Column{Name: rel.ForeignKey.DBName})
|
||||
}
|
||||
}
|
||||
|
||||
if len(refColumns) != 0 {
|
||||
selectExpr := clause.CommaExpression{}
|
||||
for _, column := range q.db.Statement.Selects {
|
||||
selectExpr.Exprs = append(selectExpr.Exprs, clause.Expr{SQL: "?", Vars: []interface{}{clause.Column{Name: column}}})
|
||||
}
|
||||
|
||||
if len(selectExpr.Exprs) == 0 {
|
||||
selectExpr.Exprs = []clause.Expression{clause.Expr{SQL: "*", Vars: []interface{}{}}}
|
||||
}
|
||||
|
||||
partitionBy := clause.CommaExpression{}
|
||||
for _, column := range refColumns {
|
||||
partitionBy.Exprs = append(partitionBy.Exprs, clause.Expr{SQL: "?", Vars: []interface{}{clause.Column{Name: column.Name}}})
|
||||
}
|
||||
|
||||
rnnColumn := clause.Column{Name: "gorm_preload_rnn"}
|
||||
sql := "ROW_NUMBER() OVER (PARTITION BY ? ?)"
|
||||
vars := []interface{}{partitionBy}
|
||||
if orderBy, ok := q.db.Statement.Clauses["ORDER BY"]; ok {
|
||||
vars = append(vars, orderBy)
|
||||
} else {
|
||||
vars = append(vars, clause.Clause{Name: "ORDER BY", Expression: clause.OrderBy{
|
||||
Columns: []clause.OrderByColumn{{Column: clause.PrimaryColumn, Desc: true}},
|
||||
}})
|
||||
}
|
||||
vars = append(vars, rnnColumn)
|
||||
|
||||
selectExpr.Exprs = append(selectExpr.Exprs, clause.Expr{SQL: sql + " AS ?", Vars: vars})
|
||||
|
||||
q.db.Clauses(clause.Select{Expression: selectExpr})
|
||||
|
||||
return q.db.Session(&Session{NewDB: true}).Unscoped().Table("(?) t", q.db).Where("? <= ?", rnnColumn, q.limitPerRecord)
|
||||
}
|
||||
}
|
||||
|
||||
return q.db
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func (c chainG[T]) Delete(ctx context.Context) (rowsAffected int, err error) {
|
||||
r := new(T)
|
||||
res := c.g.apply(ctx).Delete(r)
|
||||
return int(res.RowsAffected), res.Error
|
||||
}
|
||||
|
||||
func (c chainG[T]) Update(ctx context.Context, name string, value any) (rowsAffected int, err error) {
|
||||
var r T
|
||||
res := c.g.apply(ctx).Model(r).Update(name, value)
|
||||
return int(res.RowsAffected), res.Error
|
||||
}
|
||||
|
||||
func (c chainG[T]) Updates(ctx context.Context, t T) (rowsAffected int, err error) {
|
||||
res := c.g.apply(ctx).Updates(t)
|
||||
return int(res.RowsAffected), res.Error
|
||||
}
|
||||
|
||||
func (c chainG[T]) Count(ctx context.Context, column string) (result int64, err error) {
|
||||
var r T
|
||||
err = c.g.apply(ctx).Model(r).Select(column).Count(&result).Error
|
||||
return
|
||||
}
|
||||
|
||||
func (c chainG[T]) Build(builder clause.Builder) {
|
||||
subdb := c.getInstance()
|
||||
subdb.Logger = logger.Discard
|
||||
subdb.DryRun = true
|
||||
|
||||
if stmt, ok := builder.(*Statement); ok {
|
||||
if subdb.Statement.SQL.Len() > 0 {
|
||||
var (
|
||||
vars = subdb.Statement.Vars
|
||||
sql = subdb.Statement.SQL.String()
|
||||
)
|
||||
|
||||
subdb.Statement.Vars = make([]interface{}, 0, len(vars))
|
||||
for _, vv := range vars {
|
||||
subdb.Statement.Vars = append(subdb.Statement.Vars, vv)
|
||||
bindvar := strings.Builder{}
|
||||
subdb.BindVarTo(&bindvar, subdb.Statement, vv)
|
||||
sql = strings.Replace(sql, bindvar.String(), "?", 1)
|
||||
}
|
||||
|
||||
subdb.Statement.SQL.Reset()
|
||||
subdb.Statement.Vars = stmt.Vars
|
||||
if strings.Contains(sql, "@") {
|
||||
clause.NamedExpr{SQL: sql, Vars: vars}.Build(subdb.Statement)
|
||||
} else {
|
||||
clause.Expr{SQL: sql, Vars: vars}.Build(subdb.Statement)
|
||||
}
|
||||
} else {
|
||||
subdb.Statement.Vars = append(stmt.Vars, subdb.Statement.Vars...)
|
||||
subdb.callbacks.Query().Execute(subdb)
|
||||
}
|
||||
|
||||
builder.WriteString(subdb.Statement.SQL.String())
|
||||
stmt.Vars = subdb.Statement.Vars
|
||||
}
|
||||
}
|
||||
|
||||
type execG[T any] struct {
|
||||
g *g[T]
|
||||
}
|
||||
|
||||
func (g execG[T]) First(ctx context.Context) (T, error) {
|
||||
var r T
|
||||
err := g.g.apply(ctx).First(&r).Error
|
||||
return r, err
|
||||
}
|
||||
|
||||
func (g execG[T]) Scan(ctx context.Context, result interface{}) error {
|
||||
var r T
|
||||
err := g.g.apply(ctx).Model(r).Find(result).Error
|
||||
return err
|
||||
}
|
||||
|
||||
func (g execG[T]) Last(ctx context.Context) (T, error) {
|
||||
var r T
|
||||
err := g.g.apply(ctx).Last(&r).Error
|
||||
return r, err
|
||||
}
|
||||
|
||||
func (g execG[T]) Take(ctx context.Context) (T, error) {
|
||||
var r T
|
||||
err := g.g.apply(ctx).Take(&r).Error
|
||||
return r, err
|
||||
}
|
||||
|
||||
func (g execG[T]) Find(ctx context.Context) ([]T, error) {
|
||||
var r []T
|
||||
err := g.g.apply(ctx).Find(&r).Error
|
||||
return r, err
|
||||
}
|
||||
|
||||
func (g execG[T]) FindInBatches(ctx context.Context, batchSize int, fc func(data []T, batch int) error) error {
|
||||
var data []T
|
||||
return g.g.apply(ctx).FindInBatches(&data, batchSize, func(tx *DB, batch int) error {
|
||||
return fc(data, batch)
|
||||
}).Error
|
||||
}
|
||||
|
||||
func (g execG[T]) Row(ctx context.Context) *sql.Row {
|
||||
return g.g.apply(ctx).Row()
|
||||
}
|
||||
|
||||
func (g execG[T]) Rows(ctx context.Context) (*sql.Rows, error) {
|
||||
return g.g.apply(ctx).Rows()
|
||||
}
|
1
go.mod
1
go.mod
@ -5,4 +5,5 @@ go 1.18
|
||||
require (
|
||||
github.com/jinzhu/inflection v1.0.0
|
||||
github.com/jinzhu/now v1.1.5
|
||||
golang.org/x/text v0.20.0
|
||||
)
|
||||
|
2
go.sum
2
go.sum
@ -2,3 +2,5 @@ github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD
|
||||
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
||||
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
|
||||
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
||||
golang.org/x/text v0.20.0 h1:gK/Kv2otX8gz+wn7Rmb3vT96ZwuoxnQlY+HlJVj7Qug=
|
||||
golang.org/x/text v0.20.0/go.mod h1:D4IsuqiFMhST5bX19pQ9ikHC2GsaKyk/oF+pn3ducp4=
|
||||
|
48
gorm.go
48
gorm.go
@ -21,7 +21,9 @@ const preparedStmtDBKey = "preparedStmt"
|
||||
type Config struct {
|
||||
// GORM perform single create, update, delete operations in transactions by default to ensure database data integrity
|
||||
// You can disable it by setting `SkipDefaultTransaction` to true
|
||||
SkipDefaultTransaction bool
|
||||
SkipDefaultTransaction bool
|
||||
DefaultTransactionTimeout time.Duration
|
||||
|
||||
// NamingStrategy tables, columns naming strategy
|
||||
NamingStrategy schema.Namer
|
||||
// FullSaveAssociations full save associations
|
||||
@ -34,6 +36,11 @@ type Config struct {
|
||||
DryRun bool
|
||||
// PrepareStmt executes the given query in cached statement
|
||||
PrepareStmt bool
|
||||
// PrepareStmt cache support LRU expired,
|
||||
// default maxsize=int64 Max value and ttl=1h
|
||||
PrepareStmtMaxSize int
|
||||
PrepareStmtTTL time.Duration
|
||||
|
||||
// DisableAutomaticPing
|
||||
DisableAutomaticPing bool
|
||||
// DisableForeignKeyConstraintWhenMigrating
|
||||
@ -50,6 +57,8 @@ type Config struct {
|
||||
CreateBatchSize int
|
||||
// TranslateError enabling error translation
|
||||
TranslateError bool
|
||||
// PropagateUnscoped propagate Unscoped to every other nested statement
|
||||
PropagateUnscoped bool
|
||||
|
||||
// ClauseBuilders clause builder
|
||||
ClauseBuilders map[string]clause.ClauseBuilder
|
||||
@ -110,6 +119,7 @@ type Session struct {
|
||||
DisableNestedTransaction bool
|
||||
AllowGlobalUpdate bool
|
||||
FullSaveAssociations bool
|
||||
PropagateUnscoped bool
|
||||
QueryFields bool
|
||||
Context context.Context
|
||||
Logger logger.Interface
|
||||
@ -127,12 +137,24 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) {
|
||||
return isConfig && !isConfig2
|
||||
})
|
||||
|
||||
if len(opts) > 0 {
|
||||
if c, ok := opts[0].(*Config); ok {
|
||||
config = c
|
||||
} else {
|
||||
opts = append([]Option{config}, opts...)
|
||||
}
|
||||
}
|
||||
|
||||
var skipAfterInitialize bool
|
||||
for _, opt := range opts {
|
||||
if opt != nil {
|
||||
if applyErr := opt.Apply(config); applyErr != nil {
|
||||
return nil, applyErr
|
||||
}
|
||||
defer func(opt Option) {
|
||||
if skipAfterInitialize {
|
||||
return
|
||||
}
|
||||
if errr := opt.AfterInitialize(db); errr != nil {
|
||||
err = errr
|
||||
}
|
||||
@ -180,16 +202,25 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) {
|
||||
|
||||
if config.Dialector != nil {
|
||||
err = config.Dialector.Initialize(db)
|
||||
|
||||
if err != nil {
|
||||
if db, _ := db.DB(); db != nil {
|
||||
_ = db.Close()
|
||||
}
|
||||
|
||||
// DB is not initialized, so we skip AfterInitialize
|
||||
skipAfterInitialize = true
|
||||
return
|
||||
}
|
||||
|
||||
if config.TranslateError {
|
||||
if _, ok := db.Dialector.(ErrorTranslator); !ok {
|
||||
config.Logger.Warn(context.Background(), "The TranslateError option is enabled, but the Dialector %s does not implement ErrorTranslator.", db.Dialector.Name())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if config.PrepareStmt {
|
||||
preparedStmt := NewPreparedStmtDB(db.ConnPool)
|
||||
preparedStmt := NewPreparedStmtDB(db.ConnPool, config.PrepareStmtMaxSize, config.PrepareStmtTTL)
|
||||
db.cacheStore.Store(preparedStmtDBKey, preparedStmt)
|
||||
db.ConnPool = preparedStmt
|
||||
}
|
||||
@ -241,6 +272,10 @@ func (db *DB) Session(config *Session) *DB {
|
||||
txConfig.FullSaveAssociations = true
|
||||
}
|
||||
|
||||
if config.PropagateUnscoped {
|
||||
txConfig.PropagateUnscoped = true
|
||||
}
|
||||
|
||||
if config.Context != nil || config.PrepareStmt || config.SkipHooks {
|
||||
tx.Statement = tx.Statement.clone()
|
||||
tx.Statement.DB = tx
|
||||
@ -256,7 +291,7 @@ func (db *DB) Session(config *Session) *DB {
|
||||
if v, ok := db.cacheStore.Load(preparedStmtDBKey); ok {
|
||||
preparedStmt = v.(*PreparedStmtDB)
|
||||
} else {
|
||||
preparedStmt = NewPreparedStmtDB(db.ConnPool)
|
||||
preparedStmt = NewPreparedStmtDB(db.ConnPool, db.PrepareStmtMaxSize, db.PrepareStmtTTL)
|
||||
db.cacheStore.Store(preparedStmtDBKey, preparedStmt)
|
||||
}
|
||||
|
||||
@ -409,6 +444,9 @@ func (db *DB) getInstance() *DB {
|
||||
Vars: make([]interface{}, 0, 8),
|
||||
SkipHooks: db.Statement.SkipHooks,
|
||||
}
|
||||
if db.Config.PropagateUnscoped {
|
||||
tx.Statement.Unscoped = db.Statement.Unscoped
|
||||
}
|
||||
} else {
|
||||
// with clone statement
|
||||
tx.Statement = db.Statement.clone()
|
||||
@ -499,7 +537,7 @@ func (db *DB) Use(plugin Plugin) error {
|
||||
// .First(&User{})
|
||||
// })
|
||||
func (db *DB) ToSQL(queryFn func(tx *DB) *DB) string {
|
||||
tx := queryFn(db.Session(&Session{DryRun: true, SkipDefaultTransaction: true}))
|
||||
tx := queryFn(db.Session(&Session{DryRun: true, SkipDefaultTransaction: true}).getInstance())
|
||||
stmt := tx.Statement
|
||||
|
||||
return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...)
|
||||
|
493
internal/lru/lru.go
Normal file
493
internal/lru/lru.go
Normal file
@ -0,0 +1,493 @@
|
||||
package lru
|
||||
|
||||
// golang -lru
|
||||
// https://github.com/hashicorp/golang-lru
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// EvictCallback is used to get a callback when a cache entry is evicted
|
||||
type EvictCallback[K comparable, V any] func(key K, value V)
|
||||
|
||||
// LRU implements a thread-safe LRU with expirable entries.
|
||||
type LRU[K comparable, V any] struct {
|
||||
size int
|
||||
evictList *LruList[K, V]
|
||||
items map[K]*Entry[K, V]
|
||||
onEvict EvictCallback[K, V]
|
||||
|
||||
// expirable options
|
||||
mu sync.Mutex
|
||||
ttl time.Duration
|
||||
done chan struct{}
|
||||
|
||||
// buckets for expiration
|
||||
buckets []bucket[K, V]
|
||||
// uint8 because it's number between 0 and numBuckets
|
||||
nextCleanupBucket uint8
|
||||
}
|
||||
|
||||
// bucket is a container for holding entries to be expired
|
||||
type bucket[K comparable, V any] struct {
|
||||
entries map[K]*Entry[K, V]
|
||||
newestEntry time.Time
|
||||
}
|
||||
|
||||
// noEvictionTTL - very long ttl to prevent eviction
|
||||
const noEvictionTTL = time.Hour * 24 * 365 * 10
|
||||
|
||||
// because of uint8 usage for nextCleanupBucket, should not exceed 256.
|
||||
// casting it as uint8 explicitly requires type conversions in multiple places
|
||||
const numBuckets = 100
|
||||
|
||||
// NewLRU returns a new thread-safe cache with expirable entries.
|
||||
//
|
||||
// Size parameter set to 0 makes cache of unlimited size, e.g. turns LRU mechanism off.
|
||||
//
|
||||
// Providing 0 TTL turns expiring off.
|
||||
//
|
||||
// Delete expired entries every 1/100th of ttl value. Goroutine which deletes expired entries runs indefinitely.
|
||||
func NewLRU[K comparable, V any](size int, onEvict EvictCallback[K, V], ttl time.Duration) *LRU[K, V] {
|
||||
if size < 0 {
|
||||
size = 0
|
||||
}
|
||||
if ttl <= 0 {
|
||||
ttl = noEvictionTTL
|
||||
}
|
||||
|
||||
res := LRU[K, V]{
|
||||
ttl: ttl,
|
||||
size: size,
|
||||
evictList: NewList[K, V](),
|
||||
items: make(map[K]*Entry[K, V]),
|
||||
onEvict: onEvict,
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
|
||||
// initialize the buckets
|
||||
res.buckets = make([]bucket[K, V], numBuckets)
|
||||
for i := 0; i < numBuckets; i++ {
|
||||
res.buckets[i] = bucket[K, V]{entries: make(map[K]*Entry[K, V])}
|
||||
}
|
||||
|
||||
// enable deleteExpired() running in separate goroutine for cache with non-zero TTL
|
||||
//
|
||||
// Important: done channel is never closed, so deleteExpired() goroutine will never exit,
|
||||
// it's decided to add functionality to close it in the version later than v2.
|
||||
if res.ttl != noEvictionTTL {
|
||||
go func(done <-chan struct{}) {
|
||||
ticker := time.NewTicker(res.ttl / numBuckets)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-done:
|
||||
return
|
||||
case <-ticker.C:
|
||||
res.deleteExpired()
|
||||
}
|
||||
}
|
||||
}(res.done)
|
||||
}
|
||||
return &res
|
||||
}
|
||||
|
||||
// Purge clears the cache completely.
|
||||
// onEvict is called for each evicted key.
|
||||
func (c *LRU[K, V]) Purge() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
for k, v := range c.items {
|
||||
if c.onEvict != nil {
|
||||
c.onEvict(k, v.Value)
|
||||
}
|
||||
delete(c.items, k)
|
||||
}
|
||||
for _, b := range c.buckets {
|
||||
for _, ent := range b.entries {
|
||||
delete(b.entries, ent.Key)
|
||||
}
|
||||
}
|
||||
c.evictList.Init()
|
||||
}
|
||||
|
||||
// Add adds a value to the cache. Returns true if an eviction occurred.
|
||||
// Returns false if there was no eviction: the item was already in the cache,
|
||||
// or the size was not exceeded.
|
||||
func (c *LRU[K, V]) Add(key K, value V) (evicted bool) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
now := time.Now()
|
||||
|
||||
// Check for existing item
|
||||
if ent, ok := c.items[key]; ok {
|
||||
c.evictList.MoveToFront(ent)
|
||||
c.removeFromBucket(ent) // remove the entry from its current bucket as expiresAt is renewed
|
||||
ent.Value = value
|
||||
ent.ExpiresAt = now.Add(c.ttl)
|
||||
c.addToBucket(ent)
|
||||
return false
|
||||
}
|
||||
|
||||
// Add new item
|
||||
ent := c.evictList.PushFrontExpirable(key, value, now.Add(c.ttl))
|
||||
c.items[key] = ent
|
||||
c.addToBucket(ent) // adds the entry to the appropriate bucket and sets entry.expireBucket
|
||||
|
||||
evict := c.size > 0 && c.evictList.Length() > c.size
|
||||
// Verify size not exceeded
|
||||
if evict {
|
||||
c.removeOldest()
|
||||
}
|
||||
return evict
|
||||
}
|
||||
|
||||
// Get looks up a key's value from the cache.
|
||||
func (c *LRU[K, V]) Get(key K) (value V, ok bool) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
var ent *Entry[K, V]
|
||||
if ent, ok = c.items[key]; ok {
|
||||
// Expired item check
|
||||
if time.Now().After(ent.ExpiresAt) {
|
||||
return value, false
|
||||
}
|
||||
c.evictList.MoveToFront(ent)
|
||||
return ent.Value, true
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Contains checks if a key is in the cache, without updating the recent-ness
|
||||
// or deleting it for being stale.
|
||||
func (c *LRU[K, V]) Contains(key K) (ok bool) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
_, ok = c.items[key]
|
||||
return ok
|
||||
}
|
||||
|
||||
// Peek returns the key value (or undefined if not found) without updating
|
||||
// the "recently used"-ness of the key.
|
||||
func (c *LRU[K, V]) Peek(key K) (value V, ok bool) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
var ent *Entry[K, V]
|
||||
if ent, ok = c.items[key]; ok {
|
||||
// Expired item check
|
||||
if time.Now().After(ent.ExpiresAt) {
|
||||
return value, false
|
||||
}
|
||||
return ent.Value, true
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Remove removes the provided key from the cache, returning if the
|
||||
// key was contained.
|
||||
func (c *LRU[K, V]) Remove(key K) bool {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if ent, ok := c.items[key]; ok {
|
||||
c.removeElement(ent)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// RemoveOldest removes the oldest item from the cache.
|
||||
func (c *LRU[K, V]) RemoveOldest() (key K, value V, ok bool) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if ent := c.evictList.Back(); ent != nil {
|
||||
c.removeElement(ent)
|
||||
return ent.Key, ent.Value, true
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// GetOldest returns the oldest entry
|
||||
func (c *LRU[K, V]) GetOldest() (key K, value V, ok bool) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if ent := c.evictList.Back(); ent != nil {
|
||||
return ent.Key, ent.Value, true
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *LRU[K, V]) KeyValues() map[K]V {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
maps := make(map[K]V)
|
||||
now := time.Now()
|
||||
for ent := c.evictList.Back(); ent != nil; ent = ent.PrevEntry() {
|
||||
if now.After(ent.ExpiresAt) {
|
||||
continue
|
||||
}
|
||||
maps[ent.Key] = ent.Value
|
||||
// keys = append(keys, ent.Key)
|
||||
}
|
||||
return maps
|
||||
}
|
||||
|
||||
// Keys returns a slice of the keys in the cache, from oldest to newest.
|
||||
// Expired entries are filtered out.
|
||||
func (c *LRU[K, V]) Keys() []K {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
keys := make([]K, 0, len(c.items))
|
||||
now := time.Now()
|
||||
for ent := c.evictList.Back(); ent != nil; ent = ent.PrevEntry() {
|
||||
if now.After(ent.ExpiresAt) {
|
||||
continue
|
||||
}
|
||||
keys = append(keys, ent.Key)
|
||||
}
|
||||
return keys
|
||||
}
|
||||
|
||||
// Values returns a slice of the values in the cache, from oldest to newest.
|
||||
// Expired entries are filtered out.
|
||||
func (c *LRU[K, V]) Values() []V {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
values := make([]V, 0, len(c.items))
|
||||
now := time.Now()
|
||||
for ent := c.evictList.Back(); ent != nil; ent = ent.PrevEntry() {
|
||||
if now.After(ent.ExpiresAt) {
|
||||
continue
|
||||
}
|
||||
values = append(values, ent.Value)
|
||||
}
|
||||
return values
|
||||
}
|
||||
|
||||
// Len returns the number of items in the cache.
|
||||
func (c *LRU[K, V]) Len() int {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return c.evictList.Length()
|
||||
}
|
||||
|
||||
// Resize changes the cache size. Size of 0 means unlimited.
|
||||
func (c *LRU[K, V]) Resize(size int) (evicted int) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if size <= 0 {
|
||||
c.size = 0
|
||||
return 0
|
||||
}
|
||||
diff := c.evictList.Length() - size
|
||||
if diff < 0 {
|
||||
diff = 0
|
||||
}
|
||||
for i := 0; i < diff; i++ {
|
||||
c.removeOldest()
|
||||
}
|
||||
c.size = size
|
||||
return diff
|
||||
}
|
||||
|
||||
// Close destroys cleanup goroutine. To clean up the cache, run Purge() before Close().
|
||||
// func (c *LRU[K, V]) Close() {
|
||||
// c.mu.Lock()
|
||||
// defer c.mu.Unlock()
|
||||
// select {
|
||||
// case <-c.done:
|
||||
// return
|
||||
// default:
|
||||
// }
|
||||
// close(c.done)
|
||||
// }
|
||||
|
||||
// removeOldest removes the oldest item from the cache. Has to be called with lock!
|
||||
func (c *LRU[K, V]) removeOldest() {
|
||||
if ent := c.evictList.Back(); ent != nil {
|
||||
c.removeElement(ent)
|
||||
}
|
||||
}
|
||||
|
||||
// removeElement is used to remove a given list element from the cache. Has to be called with lock!
|
||||
func (c *LRU[K, V]) removeElement(e *Entry[K, V]) {
|
||||
c.evictList.Remove(e)
|
||||
delete(c.items, e.Key)
|
||||
c.removeFromBucket(e)
|
||||
if c.onEvict != nil {
|
||||
c.onEvict(e.Key, e.Value)
|
||||
}
|
||||
}
|
||||
|
||||
// deleteExpired deletes expired records from the oldest bucket, waiting for the newest entry
|
||||
// in it to expire first.
|
||||
func (c *LRU[K, V]) deleteExpired() {
|
||||
c.mu.Lock()
|
||||
bucketIdx := c.nextCleanupBucket
|
||||
timeToExpire := time.Until(c.buckets[bucketIdx].newestEntry)
|
||||
// wait for newest entry to expire before cleanup without holding lock
|
||||
if timeToExpire > 0 {
|
||||
c.mu.Unlock()
|
||||
time.Sleep(timeToExpire)
|
||||
c.mu.Lock()
|
||||
}
|
||||
for _, ent := range c.buckets[bucketIdx].entries {
|
||||
c.removeElement(ent)
|
||||
}
|
||||
c.nextCleanupBucket = (c.nextCleanupBucket + 1) % numBuckets
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
// addToBucket adds entry to expire bucket so that it will be cleaned up when the time comes. Has to be called with lock!
|
||||
func (c *LRU[K, V]) addToBucket(e *Entry[K, V]) {
|
||||
bucketID := (numBuckets + c.nextCleanupBucket - 1) % numBuckets
|
||||
e.ExpireBucket = bucketID
|
||||
c.buckets[bucketID].entries[e.Key] = e
|
||||
if c.buckets[bucketID].newestEntry.Before(e.ExpiresAt) {
|
||||
c.buckets[bucketID].newestEntry = e.ExpiresAt
|
||||
}
|
||||
}
|
||||
|
||||
// removeFromBucket removes the entry from its corresponding bucket. Has to be called with lock!
|
||||
func (c *LRU[K, V]) removeFromBucket(e *Entry[K, V]) {
|
||||
delete(c.buckets[e.ExpireBucket].entries, e.Key)
|
||||
}
|
||||
|
||||
// Cap returns the capacity of the cache
|
||||
func (c *LRU[K, V]) Cap() int {
|
||||
return c.size
|
||||
}
|
||||
|
||||
// Entry is an LRU Entry
|
||||
type Entry[K comparable, V any] struct {
|
||||
// Next and previous pointers in the doubly-linked list of elements.
|
||||
// To simplify the implementation, internally a list l is implemented
|
||||
// as a ring, such that &l.root is both the next element of the last
|
||||
// list element (l.Back()) and the previous element of the first list
|
||||
// element (l.Front()).
|
||||
next, prev *Entry[K, V]
|
||||
|
||||
// The list to which this element belongs.
|
||||
list *LruList[K, V]
|
||||
|
||||
// The LRU Key of this element.
|
||||
Key K
|
||||
|
||||
// The Value stored with this element.
|
||||
Value V
|
||||
|
||||
// The time this element would be cleaned up, optional
|
||||
ExpiresAt time.Time
|
||||
|
||||
// The expiry bucket item was put in, optional
|
||||
ExpireBucket uint8
|
||||
}
|
||||
|
||||
// PrevEntry returns the previous list element or nil.
|
||||
func (e *Entry[K, V]) PrevEntry() *Entry[K, V] {
|
||||
if p := e.prev; e.list != nil && p != &e.list.root {
|
||||
return p
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// LruList represents a doubly linked list.
|
||||
// The zero Value for LruList is an empty list ready to use.
|
||||
type LruList[K comparable, V any] struct {
|
||||
root Entry[K, V] // sentinel list element, only &root, root.prev, and root.next are used
|
||||
len int // current list Length excluding (this) sentinel element
|
||||
}
|
||||
|
||||
// Init initializes or clears list l.
|
||||
func (l *LruList[K, V]) Init() *LruList[K, V] {
|
||||
l.root.next = &l.root
|
||||
l.root.prev = &l.root
|
||||
l.len = 0
|
||||
return l
|
||||
}
|
||||
|
||||
// NewList returns an initialized list.
|
||||
func NewList[K comparable, V any]() *LruList[K, V] { return new(LruList[K, V]).Init() }
|
||||
|
||||
// Length returns the number of elements of list l.
|
||||
// The complexity is O(1).
|
||||
func (l *LruList[K, V]) Length() int { return l.len }
|
||||
|
||||
// Back returns the last element of list l or nil if the list is empty.
|
||||
func (l *LruList[K, V]) Back() *Entry[K, V] {
|
||||
if l.len == 0 {
|
||||
return nil
|
||||
}
|
||||
return l.root.prev
|
||||
}
|
||||
|
||||
// lazyInit lazily initializes a zero List Value.
|
||||
func (l *LruList[K, V]) lazyInit() {
|
||||
if l.root.next == nil {
|
||||
l.Init()
|
||||
}
|
||||
}
|
||||
|
||||
// insert inserts e after at, increments l.len, and returns e.
|
||||
func (l *LruList[K, V]) insert(e, at *Entry[K, V]) *Entry[K, V] {
|
||||
e.prev = at
|
||||
e.next = at.next
|
||||
e.prev.next = e
|
||||
e.next.prev = e
|
||||
e.list = l
|
||||
l.len++
|
||||
return e
|
||||
}
|
||||
|
||||
// insertValue is a convenience wrapper for insert(&Entry{Value: v, ExpiresAt: ExpiresAt}, at).
|
||||
func (l *LruList[K, V]) insertValue(k K, v V, expiresAt time.Time, at *Entry[K, V]) *Entry[K, V] {
|
||||
return l.insert(&Entry[K, V]{Value: v, Key: k, ExpiresAt: expiresAt}, at)
|
||||
}
|
||||
|
||||
// Remove removes e from its list, decrements l.len
|
||||
func (l *LruList[K, V]) Remove(e *Entry[K, V]) V {
|
||||
e.prev.next = e.next
|
||||
e.next.prev = e.prev
|
||||
e.next = nil // avoid memory leaks
|
||||
e.prev = nil // avoid memory leaks
|
||||
e.list = nil
|
||||
l.len--
|
||||
|
||||
return e.Value
|
||||
}
|
||||
|
||||
// move moves e to next to at.
|
||||
func (l *LruList[K, V]) move(e, at *Entry[K, V]) {
|
||||
if e == at {
|
||||
return
|
||||
}
|
||||
e.prev.next = e.next
|
||||
e.next.prev = e.prev
|
||||
|
||||
e.prev = at
|
||||
e.next = at.next
|
||||
e.prev.next = e
|
||||
e.next.prev = e
|
||||
}
|
||||
|
||||
// PushFront inserts a new element e with value v at the front of list l and returns e.
|
||||
func (l *LruList[K, V]) PushFront(k K, v V) *Entry[K, V] {
|
||||
l.lazyInit()
|
||||
return l.insertValue(k, v, time.Time{}, &l.root)
|
||||
}
|
||||
|
||||
// PushFrontExpirable inserts a new expirable element e with Value v at the front of list l and returns e.
|
||||
func (l *LruList[K, V]) PushFrontExpirable(k K, v V, expiresAt time.Time) *Entry[K, V] {
|
||||
l.lazyInit()
|
||||
return l.insertValue(k, v, expiresAt, &l.root)
|
||||
}
|
||||
|
||||
// MoveToFront moves element e to the front of list l.
|
||||
// If e is not an element of l, the list is not modified.
|
||||
// The element must not be nil.
|
||||
func (l *LruList[K, V]) MoveToFront(e *Entry[K, V]) {
|
||||
if e.list != l || l.root.next == e {
|
||||
return
|
||||
}
|
||||
// see comment in List.Remove about initialization of l
|
||||
l.move(e, &l.root)
|
||||
}
|
183
internal/stmt_store/stmt_store.go
Normal file
183
internal/stmt_store/stmt_store.go
Normal file
@ -0,0 +1,183 @@
|
||||
package stmt_store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"math"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm/internal/lru"
|
||||
)
|
||||
|
||||
type Stmt struct {
|
||||
*sql.Stmt
|
||||
Transaction bool
|
||||
prepared chan struct{}
|
||||
prepareErr error
|
||||
}
|
||||
|
||||
func (stmt *Stmt) Error() error {
|
||||
return stmt.prepareErr
|
||||
}
|
||||
|
||||
func (stmt *Stmt) Close() error {
|
||||
<-stmt.prepared
|
||||
|
||||
if stmt.Stmt != nil {
|
||||
return stmt.Stmt.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Store defines an interface for managing the caching operations of SQL statements (Stmt).
|
||||
// This interface provides methods for creating new statements, retrieving all cache keys,
|
||||
// getting cached statements, setting cached statements, and deleting cached statements.
|
||||
type Store interface {
|
||||
// New creates a new Stmt object and caches it.
|
||||
// Parameters:
|
||||
// ctx: The context for the request, which can carry deadlines, cancellation signals, etc.
|
||||
// key: The key representing the SQL query, used for caching and preparing the statement.
|
||||
// isTransaction: Indicates whether this operation is part of a transaction, which may affect the caching strategy.
|
||||
// connPool: A connection pool that provides database connections.
|
||||
// locker: A synchronization lock that is unlocked after initialization to avoid deadlocks.
|
||||
// Returns:
|
||||
// *Stmt: A newly created statement object for executing SQL operations.
|
||||
// error: An error if the statement preparation fails.
|
||||
New(ctx context.Context, key string, isTransaction bool, connPool ConnPool, locker sync.Locker) (*Stmt, error)
|
||||
|
||||
// Keys returns a slice of all cache keys in the store.
|
||||
Keys() []string
|
||||
|
||||
// Get retrieves a Stmt object from the store based on the given key.
|
||||
// Parameters:
|
||||
// key: The key used to look up the Stmt object.
|
||||
// Returns:
|
||||
// *Stmt: The found Stmt object, or nil if not found.
|
||||
// bool: Indicates whether the corresponding Stmt object was successfully found.
|
||||
Get(key string) (*Stmt, bool)
|
||||
|
||||
// Set stores the given Stmt object in the store and associates it with the specified key.
|
||||
// Parameters:
|
||||
// key: The key used to associate the Stmt object.
|
||||
// value: The Stmt object to be stored.
|
||||
Set(key string, value *Stmt)
|
||||
|
||||
// Delete removes the Stmt object corresponding to the specified key from the store.
|
||||
// Parameters:
|
||||
// key: The key associated with the Stmt object to be deleted.
|
||||
Delete(key string)
|
||||
}
|
||||
|
||||
// defaultMaxSize defines the default maximum capacity of the cache.
|
||||
// Its value is the maximum value of the int64 type, which means that when the cache size is not specified,
|
||||
// the cache can theoretically store as many elements as possible.
|
||||
// (1 << 63) - 1 is the maximum value that an int64 type can represent.
|
||||
const (
|
||||
defaultMaxSize = math.MaxInt
|
||||
// defaultTTL defines the default time-to-live (TTL) for each cache entry.
|
||||
// When the TTL for cache entries is not specified, each cache entry will expire after 24 hours.
|
||||
defaultTTL = time.Hour * 24
|
||||
)
|
||||
|
||||
// New creates and returns a new Store instance.
|
||||
//
|
||||
// Parameters:
|
||||
// - size: The maximum capacity of the cache. If the provided size is less than or equal to 0,
|
||||
// it defaults to defaultMaxSize.
|
||||
// - ttl: The time-to-live duration for each cache entry. If the provided ttl is less than or equal to 0,
|
||||
// it defaults to defaultTTL.
|
||||
//
|
||||
// This function defines an onEvicted callback that is invoked when a cache entry is evicted.
|
||||
// The callback ensures that if the evicted value (v) is not nil, its Close method is called asynchronously
|
||||
// to release associated resources.
|
||||
//
|
||||
// Returns:
|
||||
// - A Store instance implemented by lruStore, which internally uses an LRU cache with the specified size,
|
||||
// eviction callback, and TTL.
|
||||
func New(size int, ttl time.Duration) Store {
|
||||
if size <= 0 {
|
||||
size = defaultMaxSize
|
||||
}
|
||||
|
||||
if ttl <= 0 {
|
||||
ttl = defaultTTL
|
||||
}
|
||||
|
||||
onEvicted := func(k string, v *Stmt) {
|
||||
if v != nil {
|
||||
go v.Close()
|
||||
}
|
||||
}
|
||||
return &lruStore{lru: lru.NewLRU[string, *Stmt](size, onEvicted, ttl)}
|
||||
}
|
||||
|
||||
type lruStore struct {
|
||||
lru *lru.LRU[string, *Stmt]
|
||||
}
|
||||
|
||||
func (s *lruStore) Keys() []string {
|
||||
return s.lru.Keys()
|
||||
}
|
||||
|
||||
func (s *lruStore) Get(key string) (*Stmt, bool) {
|
||||
stmt, ok := s.lru.Get(key)
|
||||
if ok && stmt != nil {
|
||||
<-stmt.prepared
|
||||
}
|
||||
return stmt, ok
|
||||
}
|
||||
|
||||
func (s *lruStore) Set(key string, value *Stmt) {
|
||||
s.lru.Add(key, value)
|
||||
}
|
||||
|
||||
func (s *lruStore) Delete(key string) {
|
||||
s.lru.Remove(key)
|
||||
}
|
||||
|
||||
type ConnPool interface {
|
||||
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
|
||||
}
|
||||
|
||||
// New creates a new Stmt object for executing SQL queries.
|
||||
// It caches the Stmt object for future use and handles preparation and error states.
|
||||
// Parameters:
|
||||
//
|
||||
// ctx: Context for the request, used to carry deadlines, cancellation signals, etc.
|
||||
// key: The key representing the SQL query, used for caching and preparing the statement.
|
||||
// isTransaction: Indicates whether this operation is part of a transaction, affecting cache strategy.
|
||||
// conn: A connection pool that provides database connections.
|
||||
// locker: A synchronization lock that is unlocked after initialization to avoid deadlocks.
|
||||
//
|
||||
// Returns:
|
||||
//
|
||||
// *Stmt: A newly created statement object for executing SQL operations.
|
||||
// error: An error if the statement preparation fails.
|
||||
func (s *lruStore) New(ctx context.Context, key string, isTransaction bool, conn ConnPool, locker sync.Locker) (_ *Stmt, err error) {
|
||||
// Create a Stmt object and set its Transaction property.
|
||||
// The prepared channel is used to synchronize the statement preparation state.
|
||||
cacheStmt := &Stmt{
|
||||
Transaction: isTransaction,
|
||||
prepared: make(chan struct{}),
|
||||
}
|
||||
// Cache the Stmt object with the associated key.
|
||||
s.Set(key, cacheStmt)
|
||||
// Unlock after completing initialization to prevent deadlocks.
|
||||
locker.Unlock()
|
||||
|
||||
// Ensure the prepared channel is closed after the function execution completes.
|
||||
defer close(cacheStmt.prepared)
|
||||
|
||||
// Prepare the SQL statement using the provided connection.
|
||||
cacheStmt.Stmt, err = conn.PrepareContext(ctx, key)
|
||||
if err != nil {
|
||||
// If statement preparation fails, record the error and remove the invalid Stmt object from the cache.
|
||||
cacheStmt.prepareErr = err
|
||||
s.Delete(key)
|
||||
return &Stmt{}, err
|
||||
}
|
||||
|
||||
// Return the successfully prepared Stmt object.
|
||||
return cacheStmt, nil
|
||||
}
|
@ -80,6 +80,11 @@ var (
|
||||
})
|
||||
// Recorder logger records running SQL into a recorder instance
|
||||
Recorder = traceRecorder{Interface: Default, BeginAt: time.Now()}
|
||||
|
||||
// RecorderParamsFilter defaults to no-op, allows to be run-over by a different implementation
|
||||
RecorderParamsFilter = func(ctx context.Context, sql string, params ...interface{}) (string, []interface{}) {
|
||||
return sql, params
|
||||
}
|
||||
)
|
||||
|
||||
// New initialize logger
|
||||
@ -211,3 +216,10 @@ func (l *traceRecorder) Trace(ctx context.Context, begin time.Time, fc func() (s
|
||||
l.SQL, l.RowsAffected = fc()
|
||||
l.Err = err
|
||||
}
|
||||
|
||||
func (l *traceRecorder) ParamsFilter(ctx context.Context, sql string, params ...interface{}) (string, []interface{}) {
|
||||
if RecorderParamsFilter == nil {
|
||||
return sql, params
|
||||
}
|
||||
return RecorderParamsFilter(ctx, sql, params...)
|
||||
}
|
||||
|
@ -34,6 +34,19 @@ var convertibleTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeO
|
||||
// RegEx matches only numeric values
|
||||
var numericPlaceholderRe = regexp.MustCompile(`\$\d+\$`)
|
||||
|
||||
func isNumeric(k reflect.Kind) bool {
|
||||
switch k {
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
return true
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
return true
|
||||
case reflect.Float32, reflect.Float64:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// ExplainSQL generate SQL string with given parameters, the generated SQL is expected to be used in logger, execute it might introduce a SQL injection vulnerability
|
||||
func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, avars ...interface{}) string {
|
||||
var (
|
||||
@ -110,6 +123,12 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a
|
||||
convertParams(v, idx)
|
||||
} else if rv.Kind() == reflect.Ptr && !rv.IsZero() {
|
||||
convertParams(reflect.Indirect(rv).Interface(), idx)
|
||||
} else if isNumeric(rv.Kind()) {
|
||||
if rv.CanInt() || rv.CanUint() {
|
||||
vars[idx] = fmt.Sprintf("%d", rv.Interface())
|
||||
} else {
|
||||
vars[idx] = fmt.Sprintf("%.6f", rv.Interface())
|
||||
}
|
||||
} else {
|
||||
for _, t := range convertibleTypes {
|
||||
if rv.Type().ConvertibleTo(t) {
|
||||
|
@ -37,14 +37,18 @@ func format(v []byte, escaper string) string {
|
||||
func TestExplainSQL(t *testing.T) {
|
||||
type role string
|
||||
type password []byte
|
||||
type intType int
|
||||
type floatType float64
|
||||
var (
|
||||
tt = now.MustParse("2020-02-23 11:10:10")
|
||||
myrole = role("admin")
|
||||
pwd = password("pass")
|
||||
jsVal = []byte(`{"Name":"test","Val":"test"}`)
|
||||
js = JSON(jsVal)
|
||||
esVal = []byte(`{"Name":"test","Val":"test"}`)
|
||||
es = ExampleStruct{Name: "test", Val: "test"}
|
||||
tt = now.MustParse("2020-02-23 11:10:10")
|
||||
myrole = role("admin")
|
||||
pwd = password("pass")
|
||||
jsVal = []byte(`{"Name":"test","Val":"test"}`)
|
||||
js = JSON(jsVal)
|
||||
esVal = []byte(`{"Name":"test","Val":"test"}`)
|
||||
es = ExampleStruct{Name: "test", Val: "test"}
|
||||
intVal intType = 1
|
||||
floatVal floatType = 1.23
|
||||
)
|
||||
|
||||
results := []struct {
|
||||
@ -107,6 +111,18 @@ func TestExplainSQL(t *testing.T) {
|
||||
Vars: []interface{}{"jinzhu", 1, float32(999.99), true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es},
|
||||
Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
|
||||
},
|
||||
{
|
||||
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, int_val) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
NumericRegexp: nil,
|
||||
Vars: []interface{}{"jinzhu?", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, intVal},
|
||||
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, int_val) values ("jinzhu?", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", 1)`,
|
||||
},
|
||||
{
|
||||
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, float_val) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
NumericRegexp: nil,
|
||||
Vars: []interface{}{"jinzhu?", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, floatVal},
|
||||
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, float_val) values ("jinzhu?", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", 1.230000)`,
|
||||
},
|
||||
}
|
||||
|
||||
for idx, r := range results {
|
||||
|
@ -7,6 +7,7 @@ import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@ -126,6 +127,11 @@ func (m Migrator) AutoMigrate(values ...interface{}) error {
|
||||
}
|
||||
} else {
|
||||
if err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
|
||||
if stmt.Schema == nil {
|
||||
return errors.New("failed to get schema")
|
||||
}
|
||||
|
||||
columnTypes, err := queryTx.Migrator().ColumnTypes(value)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -210,6 +216,11 @@ func (m Migrator) CreateTable(values ...interface{}) error {
|
||||
for _, value := range m.ReorderModels(values, false) {
|
||||
tx := m.DB.Session(&gorm.Session{})
|
||||
if err := m.RunWithValue(value, func(stmt *gorm.Statement) (err error) {
|
||||
|
||||
if stmt.Schema == nil {
|
||||
return errors.New("failed to get schema")
|
||||
}
|
||||
|
||||
var (
|
||||
createTableSQL = "CREATE TABLE ? ("
|
||||
values = []interface{}{m.CurrentTable(stmt)}
|
||||
@ -362,6 +373,9 @@ func (m Migrator) RenameTable(oldName, newName interface{}) error {
|
||||
func (m Migrator) AddColumn(value interface{}, name string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
// avoid using the same name field
|
||||
if stmt.Schema == nil {
|
||||
return errors.New("failed to get schema")
|
||||
}
|
||||
f := stmt.Schema.LookUpField(name)
|
||||
if f == nil {
|
||||
return fmt.Errorf("failed to look up field with name: %s", name)
|
||||
@ -381,8 +395,10 @@ func (m Migrator) AddColumn(value interface{}, name string) error {
|
||||
// DropColumn drop value's `name` column
|
||||
func (m Migrator) DropColumn(value interface{}, name string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
if field := stmt.Schema.LookUpField(name); field != nil {
|
||||
name = field.DBName
|
||||
if stmt.Schema != nil {
|
||||
if field := stmt.Schema.LookUpField(name); field != nil {
|
||||
name = field.DBName
|
||||
}
|
||||
}
|
||||
|
||||
return m.DB.Exec(
|
||||
@ -394,13 +410,15 @@ func (m Migrator) DropColumn(value interface{}, name string) error {
|
||||
// AlterColumn alter value's `field` column' type based on schema definition
|
||||
func (m Migrator) AlterColumn(value interface{}, field string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
if field := stmt.Schema.LookUpField(field); field != nil {
|
||||
fileType := m.FullDataTypeOf(field)
|
||||
return m.DB.Exec(
|
||||
"ALTER TABLE ? ALTER COLUMN ? TYPE ?",
|
||||
m.CurrentTable(stmt), clause.Column{Name: field.DBName}, fileType,
|
||||
).Error
|
||||
if stmt.Schema != nil {
|
||||
if field := stmt.Schema.LookUpField(field); field != nil {
|
||||
fileType := m.FullDataTypeOf(field)
|
||||
return m.DB.Exec(
|
||||
"ALTER TABLE ? ALTER COLUMN ? TYPE ?",
|
||||
m.CurrentTable(stmt), clause.Column{Name: field.DBName}, fileType,
|
||||
).Error
|
||||
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("failed to look up field with name: %s", field)
|
||||
})
|
||||
@ -412,8 +430,10 @@ func (m Migrator) HasColumn(value interface{}, field string) bool {
|
||||
m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
currentDatabase := m.DB.Migrator().CurrentDatabase()
|
||||
name := field
|
||||
if field := stmt.Schema.LookUpField(field); field != nil {
|
||||
name = field.DBName
|
||||
if stmt.Schema != nil {
|
||||
if field := stmt.Schema.LookUpField(field); field != nil {
|
||||
name = field.DBName
|
||||
}
|
||||
}
|
||||
|
||||
return m.DB.Raw(
|
||||
@ -428,12 +448,14 @@ func (m Migrator) HasColumn(value interface{}, field string) bool {
|
||||
// RenameColumn rename value's field name from oldName to newName
|
||||
func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
if field := stmt.Schema.LookUpField(oldName); field != nil {
|
||||
oldName = field.DBName
|
||||
}
|
||||
if stmt.Schema != nil {
|
||||
if field := stmt.Schema.LookUpField(oldName); field != nil {
|
||||
oldName = field.DBName
|
||||
}
|
||||
|
||||
if field := stmt.Schema.LookUpField(newName); field != nil {
|
||||
newName = field.DBName
|
||||
if field := stmt.Schema.LookUpField(newName); field != nil {
|
||||
newName = field.DBName
|
||||
}
|
||||
}
|
||||
|
||||
return m.DB.Exec(
|
||||
@ -452,7 +474,6 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
|
||||
// found, smart migrate
|
||||
fullDataType := strings.TrimSpace(strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL))
|
||||
realDataType := strings.ToLower(columnType.DatabaseTypeName())
|
||||
|
||||
var (
|
||||
alterColumn bool
|
||||
isSameType = fullDataType == realDataType
|
||||
@ -491,8 +512,19 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// check precision
|
||||
// check precision
|
||||
if realDataType == "decimal" || realDataType == "numeric" &&
|
||||
regexp.MustCompile(realDataType+`\(.*\)`).FindString(fullDataType) != "" { // if realDataType has no precision,ignore
|
||||
precision, scale, ok := columnType.DecimalSize()
|
||||
if ok {
|
||||
if !strings.HasPrefix(fullDataType, fmt.Sprintf("%s(%d,%d)", realDataType, precision, scale)) &&
|
||||
!strings.HasPrefix(fullDataType, fmt.Sprintf("%s(%d)", realDataType, precision)) {
|
||||
alterColumn = true
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if precision, _, ok := columnType.DecimalSize(); ok && int64(field.Precision) != precision {
|
||||
if regexp.MustCompile(fmt.Sprintf("[^0-9]%d[^0-9]", field.Precision)).MatchString(m.DataTypeOf(field)) {
|
||||
alterColumn = true
|
||||
@ -502,8 +534,8 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
|
||||
|
||||
// check nullable
|
||||
if nullable, ok := columnType.Nullable(); ok && nullable == field.NotNull {
|
||||
// not primary key & database is nullable
|
||||
if !field.PrimaryKey && nullable {
|
||||
// not primary key & current database is non-nullable(to be nullable)
|
||||
if !field.PrimaryKey && !nullable {
|
||||
alterColumn = true
|
||||
}
|
||||
}
|
||||
@ -518,12 +550,18 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
|
||||
} else if !dvNotNull && currentDefaultNotNull {
|
||||
// null -> default value
|
||||
alterColumn = true
|
||||
} else if (field.GORMDataType != schema.Time && dv != field.DefaultValue) ||
|
||||
(field.GORMDataType == schema.Time && !strings.EqualFold(strings.TrimSuffix(dv, "()"), strings.TrimSuffix(field.DefaultValue, "()"))) {
|
||||
// default value not equal
|
||||
// not both null
|
||||
if currentDefaultNotNull || dvNotNull {
|
||||
alterColumn = true
|
||||
} else if currentDefaultNotNull || dvNotNull {
|
||||
switch field.GORMDataType {
|
||||
case schema.Time:
|
||||
if !strings.EqualFold(strings.TrimSuffix(dv, "()"), strings.TrimSuffix(field.DefaultValue, "()")) {
|
||||
alterColumn = true
|
||||
}
|
||||
case schema.Bool:
|
||||
v1, _ := strconv.ParseBool(dv)
|
||||
v2, _ := strconv.ParseBool(field.DefaultValue)
|
||||
alterColumn = v1 != v2
|
||||
default:
|
||||
alterColumn = dv != field.DefaultValue
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -787,6 +825,9 @@ type BuildIndexOptionsInterface interface {
|
||||
// CreateIndex create index `name`
|
||||
func (m Migrator) CreateIndex(value interface{}, name string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
if stmt.Schema == nil {
|
||||
return errors.New("failed to get schema")
|
||||
}
|
||||
if idx := stmt.Schema.LookIndex(name); idx != nil {
|
||||
opts := m.DB.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt)
|
||||
values := []interface{}{clause.Column{Name: idx.Name}, m.CurrentTable(stmt), opts}
|
||||
@ -819,8 +860,10 @@ func (m Migrator) CreateIndex(value interface{}, name string) error {
|
||||
// DropIndex drop index `name`
|
||||
func (m Migrator) DropIndex(value interface{}, name string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
if idx := stmt.Schema.LookIndex(name); idx != nil {
|
||||
name = idx.Name
|
||||
if stmt.Schema != nil {
|
||||
if idx := stmt.Schema.LookIndex(name); idx != nil {
|
||||
name = idx.Name
|
||||
}
|
||||
}
|
||||
|
||||
return m.DB.Exec("DROP INDEX ? ON ?", clause.Column{Name: name}, m.CurrentTable(stmt)).Error
|
||||
@ -832,8 +875,10 @@ func (m Migrator) HasIndex(value interface{}, name string) bool {
|
||||
var count int64
|
||||
m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
currentDatabase := m.DB.Migrator().CurrentDatabase()
|
||||
if idx := stmt.Schema.LookIndex(name); idx != nil {
|
||||
name = idx.Name
|
||||
if stmt.Schema != nil {
|
||||
if idx := stmt.Schema.LookIndex(name); idx != nil {
|
||||
name = idx.Name
|
||||
}
|
||||
}
|
||||
|
||||
return m.DB.Raw(
|
||||
|
160
prepare_stmt.go
160
prepare_stmt.go
@ -3,33 +3,39 @@ package gorm
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"reflect"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm/internal/stmt_store"
|
||||
)
|
||||
|
||||
type Stmt struct {
|
||||
*sql.Stmt
|
||||
Transaction bool
|
||||
prepared chan struct{}
|
||||
prepareErr error
|
||||
}
|
||||
|
||||
type PreparedStmtDB struct {
|
||||
Stmts map[string]*Stmt
|
||||
PreparedSQL []string
|
||||
Mux *sync.RWMutex
|
||||
Stmts stmt_store.Store
|
||||
Mux *sync.RWMutex
|
||||
ConnPool
|
||||
}
|
||||
|
||||
func NewPreparedStmtDB(connPool ConnPool) *PreparedStmtDB {
|
||||
// NewPreparedStmtDB creates and initializes a new instance of PreparedStmtDB.
|
||||
//
|
||||
// Parameters:
|
||||
// - connPool: A connection pool that implements the ConnPool interface, used for managing database connections.
|
||||
// - maxSize: The maximum number of prepared statements that can be stored in the statement store.
|
||||
// - ttl: The time-to-live duration for each prepared statement in the store. Statements older than this duration will be automatically removed.
|
||||
//
|
||||
// Returns:
|
||||
// - A pointer to a PreparedStmtDB instance, which manages prepared statements using the provided connection pool and configuration.
|
||||
func NewPreparedStmtDB(connPool ConnPool, maxSize int, ttl time.Duration) *PreparedStmtDB {
|
||||
return &PreparedStmtDB{
|
||||
ConnPool: connPool,
|
||||
Stmts: make(map[string]*Stmt),
|
||||
Mux: &sync.RWMutex{},
|
||||
PreparedSQL: make([]string, 0, 100),
|
||||
ConnPool: connPool, // Assigns the provided connection pool to manage database connections.
|
||||
Stmts: stmt_store.New(maxSize, ttl), // Initializes a new statement store with the specified maximum size and TTL.
|
||||
Mux: &sync.RWMutex{}, // Sets up a read-write mutex for synchronizing access to the statement store.
|
||||
}
|
||||
}
|
||||
|
||||
// GetDBConn returns the underlying *sql.DB connection
|
||||
func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) {
|
||||
if sqldb, ok := db.ConnPool.(*sql.DB); ok {
|
||||
return sqldb, nil
|
||||
@ -42,84 +48,41 @@ func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) {
|
||||
return nil, ErrInvalidDB
|
||||
}
|
||||
|
||||
// Close closes all prepared statements in the store
|
||||
func (db *PreparedStmtDB) Close() {
|
||||
db.Mux.Lock()
|
||||
defer db.Mux.Unlock()
|
||||
|
||||
for _, query := range db.PreparedSQL {
|
||||
if stmt, ok := db.Stmts[query]; ok {
|
||||
delete(db.Stmts, query)
|
||||
go stmt.Close()
|
||||
}
|
||||
for _, key := range db.Stmts.Keys() {
|
||||
db.Stmts.Delete(key)
|
||||
}
|
||||
}
|
||||
|
||||
func (sdb *PreparedStmtDB) Reset() {
|
||||
sdb.Mux.Lock()
|
||||
defer sdb.Mux.Unlock()
|
||||
|
||||
for _, stmt := range sdb.Stmts {
|
||||
go stmt.Close()
|
||||
}
|
||||
sdb.PreparedSQL = make([]string, 0, 100)
|
||||
sdb.Stmts = make(map[string]*Stmt)
|
||||
// Reset Deprecated use Close instead
|
||||
func (db *PreparedStmtDB) Reset() {
|
||||
db.Close()
|
||||
}
|
||||
|
||||
func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) {
|
||||
func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (_ *stmt_store.Stmt, err error) {
|
||||
db.Mux.RLock()
|
||||
if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) {
|
||||
db.Mux.RUnlock()
|
||||
// wait for other goroutines prepared
|
||||
<-stmt.prepared
|
||||
if stmt.prepareErr != nil {
|
||||
return Stmt{}, stmt.prepareErr
|
||||
if db.Stmts != nil {
|
||||
if stmt, ok := db.Stmts.Get(query); ok && (!stmt.Transaction || isTransaction) {
|
||||
db.Mux.RUnlock()
|
||||
return stmt, stmt.Error()
|
||||
}
|
||||
|
||||
return *stmt, nil
|
||||
}
|
||||
db.Mux.RUnlock()
|
||||
|
||||
// retry
|
||||
db.Mux.Lock()
|
||||
// double check
|
||||
if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) {
|
||||
db.Mux.Unlock()
|
||||
// wait for other goroutines prepared
|
||||
<-stmt.prepared
|
||||
if stmt.prepareErr != nil {
|
||||
return Stmt{}, stmt.prepareErr
|
||||
if db.Stmts != nil {
|
||||
if stmt, ok := db.Stmts.Get(query); ok && (!stmt.Transaction || isTransaction) {
|
||||
db.Mux.Unlock()
|
||||
return stmt, stmt.Error()
|
||||
}
|
||||
|
||||
return *stmt, nil
|
||||
}
|
||||
|
||||
// cache preparing stmt first
|
||||
cacheStmt := Stmt{Transaction: isTransaction, prepared: make(chan struct{})}
|
||||
db.Stmts[query] = &cacheStmt
|
||||
db.Mux.Unlock()
|
||||
|
||||
// prepare completed
|
||||
defer close(cacheStmt.prepared)
|
||||
|
||||
// Reason why cannot lock conn.PrepareContext
|
||||
// suppose the maxopen is 1, g1 is creating record and g2 is querying record.
|
||||
// 1. g1 begin tx, g1 is requeue because of waiting for the system call, now `db.ConnPool` db.numOpen == 1.
|
||||
// 2. g2 select lock `conn.PrepareContext(ctx, query)`, now db.numOpen == db.maxOpen , wait for release.
|
||||
// 3. g1 tx exec insert, wait for unlock `conn.PrepareContext(ctx, query)` to finish tx and release.
|
||||
stmt, err := conn.PrepareContext(ctx, query)
|
||||
if err != nil {
|
||||
cacheStmt.prepareErr = err
|
||||
db.Mux.Lock()
|
||||
delete(db.Stmts, query)
|
||||
db.Mux.Unlock()
|
||||
return Stmt{}, err
|
||||
}
|
||||
|
||||
db.Mux.Lock()
|
||||
cacheStmt.Stmt = stmt
|
||||
db.PreparedSQL = append(db.PreparedSQL, query)
|
||||
db.Mux.Unlock()
|
||||
|
||||
return cacheStmt, nil
|
||||
return db.Stmts.New(ctx, query, isTransaction, conn, db.Mux)
|
||||
}
|
||||
|
||||
func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (ConnPool, error) {
|
||||
@ -147,11 +110,8 @@ func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args ..
|
||||
stmt, err := db.prepare(ctx, db.ConnPool, false, query)
|
||||
if err == nil {
|
||||
result, err = stmt.ExecContext(ctx, args...)
|
||||
if err != nil {
|
||||
db.Mux.Lock()
|
||||
defer db.Mux.Unlock()
|
||||
go stmt.Close()
|
||||
delete(db.Stmts, query)
|
||||
if errors.Is(err, driver.ErrBadConn) {
|
||||
db.Stmts.Delete(query)
|
||||
}
|
||||
}
|
||||
return result, err
|
||||
@ -161,12 +121,8 @@ func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args .
|
||||
stmt, err := db.prepare(ctx, db.ConnPool, false, query)
|
||||
if err == nil {
|
||||
rows, err = stmt.QueryContext(ctx, args...)
|
||||
if err != nil {
|
||||
db.Mux.Lock()
|
||||
defer db.Mux.Unlock()
|
||||
|
||||
go stmt.Close()
|
||||
delete(db.Stmts, query)
|
||||
if errors.Is(err, driver.ErrBadConn) {
|
||||
db.Stmts.Delete(query)
|
||||
}
|
||||
}
|
||||
return rows, err
|
||||
@ -180,6 +136,14 @@ func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string, arg
|
||||
return &sql.Row{}
|
||||
}
|
||||
|
||||
func (db *PreparedStmtDB) Ping() error {
|
||||
conn, err := db.GetDBConn()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return conn.Ping()
|
||||
}
|
||||
|
||||
type PreparedStmtTX struct {
|
||||
Tx
|
||||
PreparedStmtDB *PreparedStmtDB
|
||||
@ -207,12 +171,8 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ..
|
||||
stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
|
||||
if err == nil {
|
||||
result, err = tx.Tx.StmtContext(ctx, stmt.Stmt).ExecContext(ctx, args...)
|
||||
if err != nil {
|
||||
tx.PreparedStmtDB.Mux.Lock()
|
||||
defer tx.PreparedStmtDB.Mux.Unlock()
|
||||
|
||||
go stmt.Close()
|
||||
delete(tx.PreparedStmtDB.Stmts, query)
|
||||
if errors.Is(err, driver.ErrBadConn) {
|
||||
tx.PreparedStmtDB.Stmts.Delete(query)
|
||||
}
|
||||
}
|
||||
return result, err
|
||||
@ -222,12 +182,8 @@ func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args .
|
||||
stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
|
||||
if err == nil {
|
||||
rows, err = tx.Tx.StmtContext(ctx, stmt.Stmt).QueryContext(ctx, args...)
|
||||
if err != nil {
|
||||
tx.PreparedStmtDB.Mux.Lock()
|
||||
defer tx.PreparedStmtDB.Mux.Unlock()
|
||||
|
||||
go stmt.Close()
|
||||
delete(tx.PreparedStmtDB.Stmts, query)
|
||||
if errors.Is(err, driver.ErrBadConn) {
|
||||
tx.PreparedStmtDB.Stmts.Delete(query)
|
||||
}
|
||||
}
|
||||
return rows, err
|
||||
@ -240,3 +196,11 @@ func (tx *PreparedStmtTX) QueryRowContext(ctx context.Context, query string, arg
|
||||
}
|
||||
return &sql.Row{}
|
||||
}
|
||||
|
||||
func (tx *PreparedStmtTX) Ping() error {
|
||||
conn, err := tx.GetDBConn()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return conn.Ping()
|
||||
}
|
||||
|
33
scan.go
33
scan.go
@ -4,6 +4,7 @@ import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm/schema"
|
||||
@ -15,7 +16,7 @@ func prepareValues(values []interface{}, db *DB, columnTypes []*sql.ColumnType,
|
||||
if db.Statement.Schema != nil {
|
||||
for idx, name := range columns {
|
||||
if field := db.Statement.Schema.LookUpField(name); field != nil {
|
||||
values[idx] = reflect.New(reflect.PtrTo(field.FieldType)).Interface()
|
||||
values[idx] = reflect.New(reflect.PointerTo(field.FieldType)).Interface()
|
||||
continue
|
||||
}
|
||||
values[idx] = new(interface{})
|
||||
@ -23,7 +24,7 @@ func prepareValues(values []interface{}, db *DB, columnTypes []*sql.ColumnType,
|
||||
} else if len(columnTypes) > 0 {
|
||||
for idx, columnType := range columnTypes {
|
||||
if columnType.ScanType() != nil {
|
||||
values[idx] = reflect.New(reflect.PtrTo(columnType.ScanType())).Interface()
|
||||
values[idx] = reflect.New(reflect.PointerTo(columnType.ScanType())).Interface()
|
||||
} else {
|
||||
values[idx] = new(interface{})
|
||||
}
|
||||
@ -131,6 +132,15 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
|
||||
onConflictDonothing = mode&ScanOnConflictDoNothing != 0
|
||||
)
|
||||
|
||||
if len(db.Statement.ColumnMapping) > 0 {
|
||||
for i, column := range columns {
|
||||
v, ok := db.Statement.ColumnMapping[column]
|
||||
if ok {
|
||||
columns[i] = v
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
db.RowsAffected = 0
|
||||
|
||||
switch dest := db.Statement.Dest.(type) {
|
||||
@ -235,6 +245,14 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
|
||||
matchedFieldCount[column] = 1
|
||||
}
|
||||
} else if names := utils.SplitNestedRelationName(column); len(names) > 1 { // has nested relation
|
||||
aliasName := utils.JoinNestedRelationNames(names[0 : len(names)-1])
|
||||
for _, join := range db.Statement.Joins {
|
||||
if join.Alias == aliasName {
|
||||
names = append(strings.Split(join.Name, "."), names[len(names)-1])
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if rel, ok := sch.Relationships.Relations[names[0]]; ok {
|
||||
subNameCount := len(names)
|
||||
// nested relation fields
|
||||
@ -244,7 +262,7 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
|
||||
rel = rel.FieldSchema.Relationships.Relations[name]
|
||||
relFields = append(relFields, rel.Field)
|
||||
}
|
||||
// lastest name is raw dbname
|
||||
// latest name is raw dbname
|
||||
dbName := names[subNameCount-1]
|
||||
if field := rel.FieldSchema.LookUpField(dbName); field != nil && field.Readable {
|
||||
fields[idx] = field
|
||||
@ -257,9 +275,11 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
values[idx] = &sql.RawBytes{}
|
||||
var val interface{}
|
||||
values[idx] = &val
|
||||
} else {
|
||||
values[idx] = &sql.RawBytes{}
|
||||
var val interface{}
|
||||
values[idx] = &val
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -329,6 +349,9 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
|
||||
}
|
||||
case reflect.Struct, reflect.Ptr:
|
||||
if initialized || rows.Next() {
|
||||
if mode == ScanInitialized && reflectValue.Kind() == reflect.Struct {
|
||||
db.Statement.ReflectValue.Set(reflect.Zero(reflectValue.Type()))
|
||||
}
|
||||
db.scanIntoStruct(rows, reflectValue, values, fields, joinFields)
|
||||
}
|
||||
default:
|
||||
|
@ -8,7 +8,7 @@ import (
|
||||
)
|
||||
|
||||
// reg match english letters and midline
|
||||
var regEnLetterAndMidline = regexp.MustCompile("^[A-Za-z-_]+$")
|
||||
var regEnLetterAndMidline = regexp.MustCompile(`^[\w-]+$`)
|
||||
|
||||
type CheckConstraint struct {
|
||||
Name string
|
||||
|
@ -56,6 +56,7 @@ type Field struct {
|
||||
Name string
|
||||
DBName string
|
||||
BindNames []string
|
||||
EmbeddedBindNames []string
|
||||
DataType DataType
|
||||
GORMDataType DataType
|
||||
PrimaryKey bool
|
||||
@ -112,6 +113,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
||||
Name: fieldStruct.Name,
|
||||
DBName: tagSetting["COLUMN"],
|
||||
BindNames: []string{fieldStruct.Name},
|
||||
EmbeddedBindNames: []string{fieldStruct.Name},
|
||||
FieldType: fieldStruct.Type,
|
||||
IndirectFieldType: fieldStruct.Type,
|
||||
StructField: fieldStruct,
|
||||
@ -316,9 +318,10 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
||||
}
|
||||
|
||||
if val, ok := field.TagSettings["TYPE"]; ok {
|
||||
switch DataType(strings.ToLower(val)) {
|
||||
lowerVal := DataType(strings.ToLower(val))
|
||||
switch lowerVal {
|
||||
case Bool, Int, Uint, Float, String, Time, Bytes:
|
||||
field.DataType = DataType(strings.ToLower(val))
|
||||
field.DataType = lowerVal
|
||||
default:
|
||||
field.DataType = DataType(val)
|
||||
}
|
||||
@ -403,6 +406,9 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
||||
ef.Schema = schema
|
||||
ef.OwnerSchema = field.EmbeddedSchema
|
||||
ef.BindNames = append([]string{fieldStruct.Name}, ef.BindNames...)
|
||||
if _, ok := field.TagSettings["EMBEDDED"]; ok || !fieldStruct.Anonymous {
|
||||
ef.EmbeddedBindNames = append([]string{fieldStruct.Name}, ef.EmbeddedBindNames...)
|
||||
}
|
||||
// index is negative means is pointer
|
||||
if field.FieldType.Kind() == reflect.Struct {
|
||||
ef.StructField.Index = append([]int{fieldStruct.Index[0]}, ef.StructField.Index...)
|
||||
@ -442,21 +448,30 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
||||
}
|
||||
|
||||
// create valuer, setter when parse struct
|
||||
func (field *Field) setupValuerAndSetter() {
|
||||
func (field *Field) setupValuerAndSetter(modelType reflect.Type) {
|
||||
// Setup NewValuePool
|
||||
field.setupNewValuePool()
|
||||
|
||||
// ValueOf returns field's value and if it is zero
|
||||
fieldIndex := field.StructField.Index[0]
|
||||
switch {
|
||||
case len(field.StructField.Index) == 1 && fieldIndex > 0:
|
||||
field.ValueOf = func(ctx context.Context, value reflect.Value) (interface{}, bool) {
|
||||
fieldValue := reflect.Indirect(value).Field(fieldIndex)
|
||||
case len(field.StructField.Index) == 1 && fieldIndex >= 0:
|
||||
field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) {
|
||||
v = reflect.Indirect(v)
|
||||
if v.Type() != modelType {
|
||||
fieldValue := v.FieldByName(field.Name)
|
||||
return fieldValue.Interface(), fieldValue.IsZero()
|
||||
}
|
||||
fieldValue := v.Field(fieldIndex)
|
||||
return fieldValue.Interface(), fieldValue.IsZero()
|
||||
}
|
||||
default:
|
||||
field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) {
|
||||
v = reflect.Indirect(v)
|
||||
if v.Type() != modelType {
|
||||
fieldValue := v.FieldByName(field.Name)
|
||||
return fieldValue.Interface(), fieldValue.IsZero()
|
||||
}
|
||||
for _, fieldIdx := range field.StructField.Index {
|
||||
if fieldIdx >= 0 {
|
||||
v = v.Field(fieldIdx)
|
||||
@ -498,13 +513,20 @@ func (field *Field) setupValuerAndSetter() {
|
||||
|
||||
// ReflectValueOf returns field's reflect value
|
||||
switch {
|
||||
case len(field.StructField.Index) == 1 && fieldIndex > 0:
|
||||
field.ReflectValueOf = func(ctx context.Context, value reflect.Value) reflect.Value {
|
||||
return reflect.Indirect(value).Field(fieldIndex)
|
||||
case len(field.StructField.Index) == 1 && fieldIndex >= 0:
|
||||
field.ReflectValueOf = func(ctx context.Context, v reflect.Value) reflect.Value {
|
||||
v = reflect.Indirect(v)
|
||||
if v.Type() != modelType {
|
||||
return v.FieldByName(field.Name)
|
||||
}
|
||||
return v.Field(fieldIndex)
|
||||
}
|
||||
default:
|
||||
field.ReflectValueOf = func(ctx context.Context, v reflect.Value) reflect.Value {
|
||||
v = reflect.Indirect(v)
|
||||
if v.Type() != modelType {
|
||||
return v.FieldByName(field.Name)
|
||||
}
|
||||
for idx, fieldIdx := range field.StructField.Index {
|
||||
if fieldIdx >= 0 {
|
||||
v = v.Field(fieldIdx)
|
||||
@ -664,7 +686,7 @@ func (field *Field) setupValuerAndSetter() {
|
||||
if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond {
|
||||
field.ReflectValueOf(ctx, value).SetInt(data.UnixNano())
|
||||
} else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond {
|
||||
field.ReflectValueOf(ctx, value).SetInt(data.UnixNano() / 1e6)
|
||||
field.ReflectValueOf(ctx, value).SetInt(data.UnixMilli())
|
||||
} else {
|
||||
field.ReflectValueOf(ctx, value).SetInt(data.Unix())
|
||||
}
|
||||
@ -673,7 +695,7 @@ func (field *Field) setupValuerAndSetter() {
|
||||
if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond {
|
||||
field.ReflectValueOf(ctx, value).SetInt(data.UnixNano())
|
||||
} else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond {
|
||||
field.ReflectValueOf(ctx, value).SetInt(data.UnixNano() / 1e6)
|
||||
field.ReflectValueOf(ctx, value).SetInt(data.UnixMilli())
|
||||
} else {
|
||||
field.ReflectValueOf(ctx, value).SetInt(data.Unix())
|
||||
}
|
||||
@ -738,7 +760,7 @@ func (field *Field) setupValuerAndSetter() {
|
||||
if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond {
|
||||
field.ReflectValueOf(ctx, value).SetUint(uint64(data.UnixNano()))
|
||||
} else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond {
|
||||
field.ReflectValueOf(ctx, value).SetUint(uint64(data.UnixNano() / 1e6))
|
||||
field.ReflectValueOf(ctx, value).SetUint(uint64(data.UnixMilli()))
|
||||
} else {
|
||||
field.ReflectValueOf(ctx, value).SetUint(uint64(data.Unix()))
|
||||
}
|
||||
@ -991,6 +1013,6 @@ func (field *Field) setupNewValuePool() {
|
||||
}
|
||||
|
||||
if field.NewValuePool == nil {
|
||||
field.NewValuePool = poolInitializer(reflect.PtrTo(field.IndirectFieldType))
|
||||
field.NewValuePool = poolInitializer(reflect.PointerTo(field.IndirectFieldType))
|
||||
}
|
||||
}
|
||||
|
@ -23,12 +23,13 @@ type IndexOption struct {
|
||||
Sort string // DESC, ASC
|
||||
Collate string
|
||||
Length int
|
||||
priority int
|
||||
Priority int
|
||||
}
|
||||
|
||||
// ParseIndexes parse schema indexes
|
||||
func (schema *Schema) ParseIndexes() map[string]Index {
|
||||
indexes := map[string]Index{}
|
||||
func (schema *Schema) ParseIndexes() []*Index {
|
||||
indexesByName := map[string]*Index{}
|
||||
indexes := []*Index{}
|
||||
|
||||
for _, field := range schema.Fields {
|
||||
if field.TagSettings["INDEX"] != "" || field.TagSettings["UNIQUEINDEX"] != "" {
|
||||
@ -38,7 +39,12 @@ func (schema *Schema) ParseIndexes() map[string]Index {
|
||||
break
|
||||
}
|
||||
for _, index := range fieldIndexes {
|
||||
idx := indexes[index.Name]
|
||||
idx := indexesByName[index.Name]
|
||||
if idx == nil {
|
||||
idx = &Index{Name: index.Name}
|
||||
indexesByName[index.Name] = idx
|
||||
indexes = append(indexes, idx)
|
||||
}
|
||||
idx.Name = index.Name
|
||||
if idx.Class == "" {
|
||||
idx.Class = index.Class
|
||||
@ -58,10 +64,8 @@ func (schema *Schema) ParseIndexes() map[string]Index {
|
||||
|
||||
idx.Fields = append(idx.Fields, index.Fields...)
|
||||
sort.Slice(idx.Fields, func(i, j int) bool {
|
||||
return idx.Fields[i].priority < idx.Fields[j].priority
|
||||
return idx.Fields[i].Priority < idx.Fields[j].Priority
|
||||
})
|
||||
|
||||
indexes[index.Name] = idx
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -78,12 +82,12 @@ func (schema *Schema) LookIndex(name string) *Index {
|
||||
indexes := schema.ParseIndexes()
|
||||
for _, index := range indexes {
|
||||
if index.Name == name {
|
||||
return &index
|
||||
return index
|
||||
}
|
||||
|
||||
for _, field := range index.Fields {
|
||||
if field.Name == name {
|
||||
return &index
|
||||
return index
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -101,7 +105,7 @@ func parseFieldIndexes(field *Field) (indexes []Index, err error) {
|
||||
var (
|
||||
name string
|
||||
tag = strings.Join(v[1:], ":")
|
||||
idx = strings.Index(tag, ",")
|
||||
idx = strings.IndexByte(tag, ',')
|
||||
tagSetting = strings.Join(strings.Split(tag, ",")[1:], ",")
|
||||
settings = ParseTagSetting(tagSetting, ",")
|
||||
length, _ = strconv.Atoi(settings["LENGTH"])
|
||||
@ -111,17 +115,14 @@ func parseFieldIndexes(field *Field) (indexes []Index, err error) {
|
||||
idx = len(tag)
|
||||
}
|
||||
|
||||
if idx != -1 {
|
||||
name = tag[0:idx]
|
||||
}
|
||||
|
||||
name = tag[0:idx]
|
||||
if name == "" {
|
||||
subName := field.Name
|
||||
const key = "COMPOSITE"
|
||||
if composite, found := settings[key]; found {
|
||||
if len(composite) == 0 || composite == key {
|
||||
err = fmt.Errorf(
|
||||
"The composite tag of %s.%s cannot be empty",
|
||||
"the composite tag of %s.%s cannot be empty",
|
||||
field.Schema.Name,
|
||||
field.Name)
|
||||
return
|
||||
@ -154,7 +155,7 @@ func parseFieldIndexes(field *Field) (indexes []Index, err error) {
|
||||
Sort: settings["SORT"],
|
||||
Collate: settings["COLLATE"],
|
||||
Length: length,
|
||||
priority: priority,
|
||||
Priority: priority,
|
||||
}},
|
||||
})
|
||||
}
|
||||
|
@ -21,6 +21,9 @@ type UserIndex struct {
|
||||
Name7 string `gorm:"index:type"`
|
||||
Name8 string `gorm:"index:,length:10;index:,collate:utf8"`
|
||||
|
||||
CompName1 string `gorm:"index:,unique,composite:idx_compname_1,option:NULLS NOT DISTINCT;not null"`
|
||||
CompName2 string `gorm:"index:,composite:idx_compname_1"`
|
||||
|
||||
// Composite Index: Flattened structure.
|
||||
Data0A string `gorm:"index:,composite:comp_id0"`
|
||||
Data0B string `gorm:"index:,composite:comp_id0"`
|
||||
@ -58,17 +61,17 @@ func TestParseIndex(t *testing.T) {
|
||||
t.Fatalf("failed to parse user index, got error %v", err)
|
||||
}
|
||||
|
||||
results := map[string]schema.Index{
|
||||
"idx_user_indices_name": {
|
||||
results := []*schema.Index{
|
||||
{
|
||||
Name: "idx_user_indices_name",
|
||||
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name"}}},
|
||||
},
|
||||
"idx_name": {
|
||||
{
|
||||
Name: "idx_name",
|
||||
Class: "UNIQUE",
|
||||
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name2", UniqueIndex: "idx_name"}}},
|
||||
},
|
||||
"idx_user_indices_name3": {
|
||||
{
|
||||
Name: "idx_user_indices_name3",
|
||||
Type: "btree",
|
||||
Where: "name3 != 'jinzhu'",
|
||||
@ -79,19 +82,19 @@ func TestParseIndex(t *testing.T) {
|
||||
Length: 10,
|
||||
}},
|
||||
},
|
||||
"idx_user_indices_name4": {
|
||||
{
|
||||
Name: "idx_user_indices_name4",
|
||||
Class: "UNIQUE",
|
||||
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name4", UniqueIndex: "idx_user_indices_name4"}}},
|
||||
},
|
||||
"idx_user_indices_name5": {
|
||||
{
|
||||
Name: "idx_user_indices_name5",
|
||||
Class: "FULLTEXT",
|
||||
Comment: "hello , world",
|
||||
Where: "age > 10",
|
||||
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name5"}}},
|
||||
},
|
||||
"profile": {
|
||||
{
|
||||
Name: "profile",
|
||||
Comment: "hello , world",
|
||||
Where: "age > 10",
|
||||
@ -101,21 +104,21 @@ func TestParseIndex(t *testing.T) {
|
||||
Expression: "ABS(age)",
|
||||
}},
|
||||
},
|
||||
"idx_id": {
|
||||
{
|
||||
Name: "idx_id",
|
||||
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "MemberNumber"}}, {Field: &schema.Field{Name: "OID", UniqueIndex: "idx_oid"}}},
|
||||
},
|
||||
"idx_oid": {
|
||||
{
|
||||
Name: "idx_oid",
|
||||
Class: "UNIQUE",
|
||||
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "OID", UniqueIndex: "idx_oid"}}},
|
||||
},
|
||||
"type": {
|
||||
{
|
||||
Name: "type",
|
||||
Type: "",
|
||||
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name7"}}},
|
||||
},
|
||||
"idx_user_indices_name8": {
|
||||
{
|
||||
Name: "idx_user_indices_name8",
|
||||
Type: "",
|
||||
Fields: []schema.IndexOption{
|
||||
@ -124,7 +127,16 @@ func TestParseIndex(t *testing.T) {
|
||||
{Field: &schema.Field{Name: "Name8"}, Collate: "utf8"},
|
||||
},
|
||||
},
|
||||
"idx_user_indices_comp_id0": {
|
||||
{
|
||||
Class: "UNIQUE",
|
||||
Name: "idx_user_indices_idx_compname_1",
|
||||
Option: "NULLS NOT DISTINCT",
|
||||
Fields: []schema.IndexOption{
|
||||
{Field: &schema.Field{Name: "CompName1", NotNull: true}},
|
||||
{Field: &schema.Field{Name: "CompName2"}},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "idx_user_indices_comp_id0",
|
||||
Type: "",
|
||||
Fields: []schema.IndexOption{{
|
||||
@ -133,7 +145,7 @@ func TestParseIndex(t *testing.T) {
|
||||
Field: &schema.Field{Name: "Data0B"},
|
||||
}},
|
||||
},
|
||||
"idx_user_indices_comp_id1": {
|
||||
{
|
||||
Name: "idx_user_indices_comp_id1",
|
||||
Fields: []schema.IndexOption{{
|
||||
Field: &schema.Field{Name: "Data1A"},
|
||||
@ -143,7 +155,7 @@ func TestParseIndex(t *testing.T) {
|
||||
Field: &schema.Field{Name: "Data1C"},
|
||||
}},
|
||||
},
|
||||
"idx_user_indices_comp_id2": {
|
||||
{
|
||||
Name: "idx_user_indices_comp_id2",
|
||||
Class: "UNIQUE",
|
||||
Fields: []schema.IndexOption{{
|
||||
@ -183,17 +195,17 @@ func TestParseIndexWithUniqueIndexAndUnique(t *testing.T) {
|
||||
t.Fatalf("failed to parse user index, got error %v", err)
|
||||
}
|
||||
indices := indexSchema.ParseIndexes()
|
||||
CheckIndices(t, map[string]schema.Index{
|
||||
"idx_index_tests_field_a": {
|
||||
expectedIndices := []*schema.Index{
|
||||
{
|
||||
Name: "idx_index_tests_field_a",
|
||||
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "FieldA", Unique: true}}},
|
||||
},
|
||||
"idx_index_tests_field_c": {
|
||||
{
|
||||
Name: "idx_index_tests_field_c",
|
||||
Class: "UNIQUE",
|
||||
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "FieldC", UniqueIndex: "idx_index_tests_field_c"}}},
|
||||
},
|
||||
"idx_index_tests_field_d": {
|
||||
{
|
||||
Name: "idx_index_tests_field_d",
|
||||
Class: "UNIQUE",
|
||||
Fields: []schema.IndexOption{
|
||||
@ -202,7 +214,7 @@ func TestParseIndexWithUniqueIndexAndUnique(t *testing.T) {
|
||||
{Field: &schema.Field{Name: "FieldD"}},
|
||||
},
|
||||
},
|
||||
"uniq_field_e1_e2": {
|
||||
{
|
||||
Name: "uniq_field_e1_e2",
|
||||
Class: "UNIQUE",
|
||||
Fields: []schema.IndexOption{
|
||||
@ -210,11 +222,7 @@ func TestParseIndexWithUniqueIndexAndUnique(t *testing.T) {
|
||||
{Field: &schema.Field{Name: "FieldE2"}},
|
||||
},
|
||||
},
|
||||
"idx_index_tests_field_f1": {
|
||||
Name: "idx_index_tests_field_f1",
|
||||
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "FieldF1"}}},
|
||||
},
|
||||
"uniq_field_f1_f2": {
|
||||
{
|
||||
Name: "uniq_field_f1_f2",
|
||||
Class: "UNIQUE",
|
||||
Fields: []schema.IndexOption{
|
||||
@ -222,12 +230,16 @@ func TestParseIndexWithUniqueIndexAndUnique(t *testing.T) {
|
||||
{Field: &schema.Field{Name: "FieldF2"}},
|
||||
},
|
||||
},
|
||||
"idx_index_tests_field_g": {
|
||||
{
|
||||
Name: "idx_index_tests_field_f1",
|
||||
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "FieldF1"}}},
|
||||
},
|
||||
{
|
||||
Name: "idx_index_tests_field_g",
|
||||
Class: "UNIQUE",
|
||||
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "FieldG", Unique: true, UniqueIndex: "idx_index_tests_field_g"}}},
|
||||
},
|
||||
"uniq_field_h1_h2": {
|
||||
{
|
||||
Name: "uniq_field_h1_h2",
|
||||
Class: "UNIQUE",
|
||||
Fields: []schema.IndexOption{
|
||||
@ -235,30 +247,29 @@ func TestParseIndexWithUniqueIndexAndUnique(t *testing.T) {
|
||||
{Field: &schema.Field{Name: "FieldH2"}},
|
||||
},
|
||||
},
|
||||
}, indices)
|
||||
}
|
||||
CheckIndices(t, expectedIndices, indices)
|
||||
}
|
||||
|
||||
func CheckIndices(t *testing.T, expected, actual map[string]schema.Index) {
|
||||
for k, ei := range expected {
|
||||
t.Run(k, func(t *testing.T) {
|
||||
ai, ok := actual[k]
|
||||
if !ok {
|
||||
t.Errorf("expected index %q but actual missing", k)
|
||||
return
|
||||
}
|
||||
func CheckIndices(t *testing.T, expected, actual []*schema.Index) {
|
||||
if len(expected) != len(actual) {
|
||||
t.Errorf("expected %d indices, but got %d", len(expected), len(actual))
|
||||
return
|
||||
}
|
||||
|
||||
for i, ei := range expected {
|
||||
t.Run(ei.Name, func(t *testing.T) {
|
||||
ai := actual[i]
|
||||
tests.AssertObjEqual(t, ai, ei, "Name", "Class", "Type", "Where", "Comment", "Option")
|
||||
|
||||
if len(ei.Fields) != len(ai.Fields) {
|
||||
t.Errorf("expected index %q field length is %d but actual %d", k, len(ei.Fields), len(ai.Fields))
|
||||
t.Errorf("expected index %q field length is %d but actual %d", ei.Name, len(ei.Fields), len(ai.Fields))
|
||||
return
|
||||
}
|
||||
for i, ef := range ei.Fields {
|
||||
af := ai.Fields[i]
|
||||
tests.AssertObjEqual(t, af, ef, "Name", "Unique", "UniqueIndex", "Expression", "Sort", "Collate", "Length")
|
||||
tests.AssertObjEqual(t, af, ef, "Name", "Unique", "UniqueIndex", "Expression", "Sort", "Collate", "Length", "NotNull")
|
||||
}
|
||||
})
|
||||
delete(actual, k)
|
||||
}
|
||||
for k := range actual {
|
||||
t.Errorf("unexpected index %q", k)
|
||||
}
|
||||
}
|
||||
|
@ -8,6 +8,8 @@ import (
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/jinzhu/inflection"
|
||||
"golang.org/x/text/cases"
|
||||
"golang.org/x/text/language"
|
||||
)
|
||||
|
||||
// Namer namer interface
|
||||
@ -121,7 +123,7 @@ var (
|
||||
func init() {
|
||||
commonInitialismsForReplacer := make([]string, 0, len(commonInitialisms))
|
||||
for _, initialism := range commonInitialisms {
|
||||
commonInitialismsForReplacer = append(commonInitialismsForReplacer, initialism, strings.Title(strings.ToLower(initialism)))
|
||||
commonInitialismsForReplacer = append(commonInitialismsForReplacer, initialism, cases.Title(language.Und).String(initialism))
|
||||
}
|
||||
commonInitialismsReplacer = strings.NewReplacer(commonInitialismsForReplacer...)
|
||||
}
|
||||
@ -186,9 +188,9 @@ func (ns NamingStrategy) toDBName(name string) string {
|
||||
}
|
||||
|
||||
func (ns NamingStrategy) toSchemaName(name string) string {
|
||||
result := strings.ReplaceAll(strings.Title(strings.ReplaceAll(name, "_", " ")), " ", "")
|
||||
result := strings.ReplaceAll(cases.Title(language.Und, cases.NoLower).String(strings.ReplaceAll(name, "_", " ")), " ", "")
|
||||
for _, initialism := range commonInitialisms {
|
||||
result = regexp.MustCompile(strings.Title(strings.ToLower(initialism))+"([A-Z]|$|_)").ReplaceAllString(result, initialism+"$1")
|
||||
result = regexp.MustCompile(cases.Title(language.Und, cases.NoLower).String(strings.ToLower(initialism))+"([A-Z]|$|_)").ReplaceAllString(result, initialism+"$1")
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
@ -5,8 +5,12 @@ import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/jinzhu/inflection"
|
||||
"golang.org/x/text/cases"
|
||||
"golang.org/x/text/language"
|
||||
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
@ -29,6 +33,8 @@ type Relationships struct {
|
||||
Relations map[string]*Relationship
|
||||
|
||||
EmbeddedRelations map[string]*Relationships
|
||||
|
||||
Mux sync.RWMutex
|
||||
}
|
||||
|
||||
type Relationship struct {
|
||||
@ -72,7 +78,7 @@ func (schema *Schema) parseRelation(field *Field) *Relationship {
|
||||
cacheStore := schema.cacheStore
|
||||
|
||||
if relation.FieldSchema, err = getOrParse(fieldValue, cacheStore, schema.namer); err != nil {
|
||||
schema.err = err
|
||||
schema.err = fmt.Errorf("failed to parse field: %s, error: %w", field.Name, err)
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -95,9 +101,10 @@ func (schema *Schema) parseRelation(field *Field) *Relationship {
|
||||
}
|
||||
|
||||
if relation.Type == has {
|
||||
// don't add relations to embedded schema, which might be shared
|
||||
if relation.FieldSchema != relation.Schema && relation.Polymorphic == nil && field.OwnerSchema == nil {
|
||||
relation.FieldSchema.Relationships.Mux.Lock()
|
||||
relation.FieldSchema.Relationships.Relations["_"+relation.Schema.Name+"_"+relation.Name] = relation
|
||||
relation.FieldSchema.Relationships.Mux.Unlock()
|
||||
}
|
||||
|
||||
switch field.IndirectFieldType.Kind() {
|
||||
@ -150,12 +157,12 @@ func (schema *Schema) setRelation(relation *Relationship) {
|
||||
}
|
||||
|
||||
// set embedded relation
|
||||
if len(relation.Field.BindNames) <= 1 {
|
||||
if len(relation.Field.EmbeddedBindNames) <= 1 {
|
||||
return
|
||||
}
|
||||
relationships := &schema.Relationships
|
||||
for i, name := range relation.Field.BindNames {
|
||||
if i < len(relation.Field.BindNames)-1 {
|
||||
for i, name := range relation.Field.EmbeddedBindNames {
|
||||
if i < len(relation.Field.EmbeddedBindNames)-1 {
|
||||
if relationships.EmbeddedRelations == nil {
|
||||
relationships.EmbeddedRelations = map[string]*Relationships{}
|
||||
}
|
||||
@ -301,9 +308,9 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
|
||||
}
|
||||
|
||||
for idx, ownField := range ownForeignFields {
|
||||
joinFieldName := strings.Title(schema.Name) + ownField.Name
|
||||
joinFieldName := cases.Title(language.Und, cases.NoLower).String(schema.Name) + ownField.Name
|
||||
if len(joinForeignKeys) > idx {
|
||||
joinFieldName = strings.Title(joinForeignKeys[idx])
|
||||
joinFieldName = cases.Title(language.Und, cases.NoLower).String(joinForeignKeys[idx])
|
||||
}
|
||||
|
||||
ownFieldsMap[joinFieldName] = ownField
|
||||
@ -318,7 +325,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
|
||||
}
|
||||
|
||||
for idx, relField := range refForeignFields {
|
||||
joinFieldName := strings.Title(relation.FieldSchema.Name) + relField.Name
|
||||
joinFieldName := cases.Title(language.Und, cases.NoLower).String(relation.FieldSchema.Name) + relField.Name
|
||||
|
||||
if _, ok := ownFieldsMap[joinFieldName]; ok {
|
||||
if field.Name != relation.FieldSchema.Name {
|
||||
@ -329,7 +336,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
|
||||
}
|
||||
|
||||
if len(joinReferences) > idx {
|
||||
joinFieldName = strings.Title(joinReferences[idx])
|
||||
joinFieldName = cases.Title(language.Und, cases.NoLower).String(joinReferences[idx])
|
||||
}
|
||||
|
||||
referFieldsMap[joinFieldName] = relField
|
||||
@ -347,7 +354,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
|
||||
}
|
||||
|
||||
joinTableFields = append(joinTableFields, reflect.StructField{
|
||||
Name: strings.Title(schema.Name) + field.Name,
|
||||
Name: cases.Title(language.Und, cases.NoLower).String(schema.Name) + field.Name,
|
||||
Type: schema.ModelType,
|
||||
Tag: `gorm:"-"`,
|
||||
})
|
||||
@ -656,6 +663,7 @@ func (rel *Relationship) ParseConstraint() *Constraint {
|
||||
if !(rel.References[idx].PrimaryKey == ref.PrimaryKey && rel.References[idx].ForeignKey == ref.ForeignKey &&
|
||||
rel.References[idx].PrimaryValue == ref.PrimaryValue) {
|
||||
matched = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
@ -668,7 +676,7 @@ func (rel *Relationship) ParseConstraint() *Constraint {
|
||||
|
||||
var (
|
||||
name string
|
||||
idx = strings.Index(str, ",")
|
||||
idx = strings.IndexByte(str, ',')
|
||||
settings = ParseTagSetting(str, ",")
|
||||
)
|
||||
|
||||
@ -755,8 +763,9 @@ func (rel *Relationship) ToQueryConditions(ctx context.Context, reflectValue ref
|
||||
}
|
||||
|
||||
func copyableDataType(str DataType) bool {
|
||||
lowerStr := strings.ToLower(string(str))
|
||||
for _, s := range []string{"auto_increment", "primary key"} {
|
||||
if strings.Contains(strings.ToLower(string(str)), s) {
|
||||
if strings.Contains(lowerStr, s) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
@ -121,6 +121,29 @@ func TestSelfReferentialBelongsToOverrideReferences(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestBelongsToWithMixin(t *testing.T) {
|
||||
type Profile struct {
|
||||
gorm.Model
|
||||
Refer string
|
||||
Name string
|
||||
}
|
||||
|
||||
type ProfileMixin struct {
|
||||
Profile Profile `gorm:"References:Refer"`
|
||||
ProfileRefer int
|
||||
}
|
||||
|
||||
type User struct {
|
||||
gorm.Model
|
||||
ProfileMixin
|
||||
}
|
||||
|
||||
checkStructRelation(t, &User{}, Relation{
|
||||
Name: "Profile", Type: schema.BelongsTo, Schema: "User", FieldSchema: "Profile",
|
||||
References: []Reference{{"Refer", "Profile", "ProfileRefer", "User", "", false}},
|
||||
})
|
||||
}
|
||||
|
||||
func TestHasOneOverrideForeignKey(t *testing.T) {
|
||||
type Profile struct {
|
||||
gorm.Model
|
||||
@ -776,6 +799,10 @@ func TestEmbeddedBelongsTo(t *testing.T) {
|
||||
type NestedAddress struct {
|
||||
Address
|
||||
}
|
||||
type CountryMixin struct {
|
||||
CountryID int
|
||||
Country Country
|
||||
}
|
||||
type Org struct {
|
||||
ID int
|
||||
PostalAddress Address `gorm:"embedded;embeddedPrefix:postal_address_"`
|
||||
@ -786,6 +813,7 @@ func TestEmbeddedBelongsTo(t *testing.T) {
|
||||
Address
|
||||
}
|
||||
NestedAddress *NestedAddress `gorm:"embedded;embeddedPrefix:nested_address_"`
|
||||
CountryMixin
|
||||
}
|
||||
|
||||
s, err := schema.Parse(&Org{}, &sync.Map{}, schema.NamingStrategy{})
|
||||
@ -815,15 +843,11 @@ func TestEmbeddedBelongsTo(t *testing.T) {
|
||||
},
|
||||
},
|
||||
"NestedAddress": {
|
||||
EmbeddedRelations: map[string]EmbeddedRelations{
|
||||
"Address": {
|
||||
Relations: map[string]Relation{
|
||||
"Country": {
|
||||
Name: "Country", Type: schema.BelongsTo, Schema: "Org", FieldSchema: "Country",
|
||||
References: []Reference{
|
||||
{PrimaryKey: "ID", PrimarySchema: "Country", ForeignKey: "CountryID", ForeignSchema: "Org"},
|
||||
},
|
||||
},
|
||||
Relations: map[string]Relation{
|
||||
"Country": {
|
||||
Name: "Country", Type: schema.BelongsTo, Schema: "Org", FieldSchema: "Country",
|
||||
References: []Reference{
|
||||
{PrimaryKey: "ID", PrimarySchema: "Country", ForeignKey: "CountryID", ForeignSchema: "Org"},
|
||||
},
|
||||
},
|
||||
},
|
||||
|
@ -5,6 +5,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"go/ast"
|
||||
"path"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
@ -67,9 +68,10 @@ func (schema Schema) String() string {
|
||||
}
|
||||
|
||||
func (schema Schema) MakeSlice() reflect.Value {
|
||||
slice := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(schema.ModelType)), 0, 20)
|
||||
slice := reflect.MakeSlice(reflect.SliceOf(reflect.PointerTo(schema.ModelType)), 0, 20)
|
||||
results := reflect.New(slice.Type())
|
||||
results.Elem().Set(slice)
|
||||
|
||||
return results
|
||||
}
|
||||
|
||||
@ -246,7 +248,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
|
||||
schema.FieldsByBindName[bindName] = field
|
||||
}
|
||||
|
||||
field.setupValuerAndSetter()
|
||||
field.setupValuerAndSetter(modelType)
|
||||
}
|
||||
|
||||
prioritizedPrimaryField := schema.LookUpField("id")
|
||||
@ -312,8 +314,14 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
|
||||
for _, cbName := range callbackTypes {
|
||||
if methodValue := callBackToMethodValue(modelValue, cbName); methodValue.IsValid() {
|
||||
switch methodValue.Type().String() {
|
||||
case "func(*gorm.DB) error": // TODO hack
|
||||
reflect.Indirect(reflect.ValueOf(schema)).FieldByName(string(cbName)).SetBool(true)
|
||||
case "func(*gorm.DB) error":
|
||||
expectedPkgPath := path.Dir(reflect.TypeOf(schema).Elem().PkgPath())
|
||||
if inVarPkg := methodValue.Type().In(0).Elem().PkgPath(); inVarPkg == expectedPkgPath {
|
||||
reflect.Indirect(reflect.ValueOf(schema)).FieldByName(string(cbName)).SetBool(true)
|
||||
} else {
|
||||
logger.Default.Warn(context.Background(), "In model %v, the hook function `%v(*gorm.DB) error` has an incorrect parameter type. The expected parameter type is `%v`, but the provided type is `%v`.", schema, cbName, expectedPkgPath, inVarPkg)
|
||||
// PASS
|
||||
}
|
||||
default:
|
||||
logger.Default.Warn(context.Background(), "Model %v don't match %vInterface, should be `%v(*gorm.DB) error`. Please see https://gorm.io/docs/hooks.html", schema, cbName, cbName)
|
||||
}
|
||||
@ -337,7 +345,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
|
||||
|
||||
if _, embedded := schema.cacheStore.Load(embeddedCacheKey); !embedded {
|
||||
for _, field := range schema.Fields {
|
||||
if field.DataType == "" && (field.Creatable || field.Updatable || field.Readable) {
|
||||
if field.DataType == "" && field.GORMDataType == "" && (field.Creatable || field.Updatable || field.Readable) {
|
||||
if schema.parseRelation(field); schema.err != nil {
|
||||
return schema, schema.err
|
||||
} else {
|
||||
|
@ -19,6 +19,22 @@ func TestParseSchema(t *testing.T) {
|
||||
checkUserSchema(t, user)
|
||||
}
|
||||
|
||||
func TestParseSchemaWithMap(t *testing.T) {
|
||||
type User struct {
|
||||
tests.User
|
||||
Attrs map[string]string `gorm:"type:Map(String,String);"`
|
||||
}
|
||||
|
||||
user, err := schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse user with map, got error %v", err)
|
||||
}
|
||||
|
||||
if field := user.FieldsByName["Attrs"]; field.DataType != "Map(String,String)" {
|
||||
t.Errorf("failed to parse user field Attrs")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSchemaWithPointerFields(t *testing.T) {
|
||||
user, err := schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{})
|
||||
if err != nil {
|
||||
|
@ -84,7 +84,10 @@ func (JSONSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value,
|
||||
case string:
|
||||
bytes = []byte(v)
|
||||
default:
|
||||
return fmt.Errorf("failed to unmarshal JSONB value: %#v", dbValue)
|
||||
bytes, err = json.Marshal(v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if len(bytes) > 0 {
|
||||
@ -126,12 +129,12 @@ func (UnixSecondSerializer) Value(ctx context.Context, field *Field, dst reflect
|
||||
rv := reflect.ValueOf(fieldValue)
|
||||
switch v := fieldValue.(type) {
|
||||
case int64, int, uint, uint64, int32, uint32, int16, uint16:
|
||||
result = time.Unix(reflect.Indirect(rv).Int(), 0)
|
||||
result = time.Unix(reflect.Indirect(rv).Int(), 0).UTC()
|
||||
case *int64, *int, *uint, *uint64, *int32, *uint32, *int16, *uint16:
|
||||
if rv.IsZero() {
|
||||
return nil, nil
|
||||
}
|
||||
result = time.Unix(reflect.Indirect(rv).Int(), 0)
|
||||
result = time.Unix(reflect.Indirect(rv).Int(), 0).UTC()
|
||||
default:
|
||||
err = fmt.Errorf("invalid field type %#v for UnixSecondSerializer, only int, uint supported", v)
|
||||
}
|
||||
|
@ -71,7 +71,7 @@ func appendSettingFromTag(tag reflect.StructTag, value string) reflect.StructTag
|
||||
// GetRelationsValues get relations's values from a reflect value
|
||||
func GetRelationsValues(ctx context.Context, reflectValue reflect.Value, rels []*Relationship) (reflectResults reflect.Value) {
|
||||
for _, rel := range rels {
|
||||
reflectResults = reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(rel.FieldSchema.ModelType)), 0, 1)
|
||||
reflectResults = reflect.MakeSlice(reflect.SliceOf(reflect.PointerTo(rel.FieldSchema.ModelType)), 0, 1)
|
||||
|
||||
appendToResults := func(value reflect.Value) {
|
||||
if _, isZero := rel.Field.ValueOf(ctx, value); !isZero {
|
||||
|
73
statement.go
73
statement.go
@ -30,8 +30,9 @@ type Statement struct {
|
||||
Clauses map[string]clause.Clause
|
||||
BuildClauses []string
|
||||
Distinct bool
|
||||
Selects []string // selected columns
|
||||
Omits []string // omit columns
|
||||
Selects []string // selected columns
|
||||
Omits []string // omit columns
|
||||
ColumnMapping map[string]string // map columns
|
||||
Joins []join
|
||||
Preloads map[string][]interface{}
|
||||
Settings sync.Map
|
||||
@ -46,15 +47,18 @@ type Statement struct {
|
||||
attrs []interface{}
|
||||
assigns []interface{}
|
||||
scopes []func(*DB) *DB
|
||||
Result *result
|
||||
}
|
||||
|
||||
type join struct {
|
||||
Name string
|
||||
Conds []interface{}
|
||||
On *clause.Where
|
||||
Selects []string
|
||||
Omits []string
|
||||
JoinType clause.JoinType
|
||||
Name string
|
||||
Alias string
|
||||
Conds []interface{}
|
||||
On *clause.Where
|
||||
Selects []string
|
||||
Omits []string
|
||||
Expression clause.Expression
|
||||
JoinType clause.JoinType
|
||||
}
|
||||
|
||||
// StatementModifier statement modifier interface
|
||||
@ -204,19 +208,21 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) {
|
||||
} else {
|
||||
writer.WriteString("(NULL)")
|
||||
}
|
||||
case *DB:
|
||||
subdb := v.Session(&Session{Logger: logger.Discard, DryRun: true}).getInstance()
|
||||
if v.Statement.SQL.Len() > 0 {
|
||||
case interface{ getInstance() *DB }:
|
||||
cv := v.getInstance()
|
||||
|
||||
subdb := cv.Session(&Session{Logger: logger.Discard, DryRun: true}).getInstance()
|
||||
if cv.Statement.SQL.Len() > 0 {
|
||||
var (
|
||||
vars = subdb.Statement.Vars
|
||||
sql = v.Statement.SQL.String()
|
||||
sql = cv.Statement.SQL.String()
|
||||
)
|
||||
|
||||
subdb.Statement.Vars = make([]interface{}, 0, len(vars))
|
||||
for _, vv := range vars {
|
||||
subdb.Statement.Vars = append(subdb.Statement.Vars, vv)
|
||||
bindvar := strings.Builder{}
|
||||
v.Dialector.BindVarTo(&bindvar, subdb.Statement, vv)
|
||||
cv.BindVarTo(&bindvar, subdb.Statement, vv)
|
||||
sql = strings.Replace(sql, bindvar.String(), "?", 1)
|
||||
}
|
||||
|
||||
@ -320,6 +326,11 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
|
||||
arg, _ = valuer.Value()
|
||||
}
|
||||
|
||||
curTable := stmt.Table
|
||||
if curTable == "" {
|
||||
curTable = clause.CurrentTable
|
||||
}
|
||||
|
||||
switch v := arg.(type) {
|
||||
case clause.Expression:
|
||||
conds = append(conds, v)
|
||||
@ -330,7 +341,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
|
||||
if where, ok := cs.Expression.(clause.Where); ok {
|
||||
if len(where.Exprs) == 1 {
|
||||
if orConds, ok := where.Exprs[0].(clause.OrConditions); ok {
|
||||
where.Exprs[0] = clause.AndConditions(orConds)
|
||||
if len(orConds.Exprs) == 1 {
|
||||
where.Exprs[0] = clause.AndConditions(orConds)
|
||||
}
|
||||
}
|
||||
}
|
||||
conds = append(conds, clause.And(where.Exprs...))
|
||||
@ -350,7 +363,11 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
|
||||
sort.Strings(keys)
|
||||
|
||||
for _, key := range keys {
|
||||
conds = append(conds, clause.Eq{Column: key, Value: v[key]})
|
||||
column := clause.Column{Name: key, Table: curTable}
|
||||
if strings.Contains(key, ".") {
|
||||
column = clause.Column{Name: key}
|
||||
}
|
||||
conds = append(conds, clause.Eq{Column: column, Value: v[key]})
|
||||
}
|
||||
case map[string]interface{}:
|
||||
keys := make([]string, 0, len(v))
|
||||
@ -361,12 +378,16 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
|
||||
|
||||
for _, key := range keys {
|
||||
reflectValue := reflect.Indirect(reflect.ValueOf(v[key]))
|
||||
column := clause.Column{Name: key, Table: curTable}
|
||||
if strings.Contains(key, ".") {
|
||||
column = clause.Column{Name: key}
|
||||
}
|
||||
switch reflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
if _, ok := v[key].(driver.Valuer); ok {
|
||||
conds = append(conds, clause.Eq{Column: key, Value: v[key]})
|
||||
conds = append(conds, clause.Eq{Column: column, Value: v[key]})
|
||||
} else if _, ok := v[key].(Valuer); ok {
|
||||
conds = append(conds, clause.Eq{Column: key, Value: v[key]})
|
||||
conds = append(conds, clause.Eq{Column: column, Value: v[key]})
|
||||
} else {
|
||||
// optimize reflect value length
|
||||
valueLen := reflectValue.Len()
|
||||
@ -375,10 +396,10 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
|
||||
values[i] = reflectValue.Index(i).Interface()
|
||||
}
|
||||
|
||||
conds = append(conds, clause.IN{Column: key, Values: values})
|
||||
conds = append(conds, clause.IN{Column: column, Values: values})
|
||||
}
|
||||
default:
|
||||
conds = append(conds, clause.Eq{Column: key, Value: v[key]})
|
||||
conds = append(conds, clause.Eq{Column: column, Value: v[key]})
|
||||
}
|
||||
}
|
||||
default:
|
||||
@ -405,9 +426,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
|
||||
if selected || (!restricted && field.Readable) {
|
||||
if v, isZero := field.ValueOf(stmt.Context, reflectValue); !isZero || selected {
|
||||
if field.DBName != "" {
|
||||
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v})
|
||||
conds = append(conds, clause.Eq{Column: clause.Column{Table: curTable, Name: field.DBName}, Value: v})
|
||||
} else if field.DataType != "" {
|
||||
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v})
|
||||
conds = append(conds, clause.Eq{Column: clause.Column{Table: curTable, Name: field.Name}, Value: v})
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -419,9 +440,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
|
||||
if selected || (!restricted && field.Readable) {
|
||||
if v, isZero := field.ValueOf(stmt.Context, reflectValue.Index(i)); !isZero || selected {
|
||||
if field.DBName != "" {
|
||||
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v})
|
||||
conds = append(conds, clause.Eq{Column: clause.Column{Table: curTable, Name: field.DBName}, Value: v})
|
||||
} else if field.DataType != "" {
|
||||
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v})
|
||||
conds = append(conds, clause.Eq{Column: clause.Column{Table: curTable, Name: field.Name}, Value: v})
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -446,14 +467,14 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
|
||||
}
|
||||
|
||||
if len(values) > 0 {
|
||||
conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: values})
|
||||
conds = append(conds, clause.IN{Column: clause.Column{Table: curTable, Name: clause.PrimaryKey}, Values: values})
|
||||
return []clause.Expression{clause.And(conds...)}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: args})
|
||||
conds = append(conds, clause.IN{Column: clause.Column{Table: curTable, Name: clause.PrimaryKey}, Values: args})
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -513,12 +534,14 @@ func (stmt *Statement) clone() *Statement {
|
||||
Distinct: stmt.Distinct,
|
||||
Selects: stmt.Selects,
|
||||
Omits: stmt.Omits,
|
||||
ColumnMapping: stmt.ColumnMapping,
|
||||
Preloads: map[string][]interface{}{},
|
||||
ConnPool: stmt.ConnPool,
|
||||
Schema: stmt.Schema,
|
||||
Context: stmt.Context,
|
||||
RaiseErrorOnNotFound: stmt.RaiseErrorOnNotFound,
|
||||
SkipHooks: stmt.SkipHooks,
|
||||
Result: stmt.Result,
|
||||
}
|
||||
|
||||
if stmt.SQL.Len() > 0 {
|
||||
|
@ -554,3 +554,15 @@ func TestHasManyAssociationUnscoped(t *testing.T) {
|
||||
t.Errorf("expected %d contents, got %d", 0, len(contents))
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasManyAssociationReplaceWithNonValidValue(t *testing.T) {
|
||||
user := User{Name: "jinzhu", Languages: []Language{{Name: "EN"}}}
|
||||
|
||||
if err := DB.Create(&user).Error; err != nil {
|
||||
t.Fatalf("errors happened when create: %v", err)
|
||||
}
|
||||
|
||||
if err := DB.Model(&user).Association("Languages").Replace(Language{Name: "DE"}, Language{Name: "FR"}); err == nil {
|
||||
t.Error("expected association error to be not nil")
|
||||
}
|
||||
}
|
||||
|
@ -255,3 +255,15 @@ func TestPolymorphicHasOneAssociationForSlice(t *testing.T) {
|
||||
DB.Model(&pets).Association("Toy").Clear()
|
||||
AssertAssociationCount(t, pets, "Toy", 0, "After Clear")
|
||||
}
|
||||
|
||||
func TestHasOneAssociationReplaceWithNonValidValue(t *testing.T) {
|
||||
user := User{Name: "jinzhu", Account: Account{Number: "1"}}
|
||||
|
||||
if err := DB.Create(&user).Error; err != nil {
|
||||
t.Fatalf("errors happened when create: %v", err)
|
||||
}
|
||||
|
||||
if err := DB.Model(&user).Association("Languages").Replace(Account{Number: "2"}); err == nil {
|
||||
t.Error("expected association error to be not nil")
|
||||
}
|
||||
}
|
||||
|
@ -91,7 +91,7 @@ func TestCallbacks(t *testing.T) {
|
||||
},
|
||||
{
|
||||
callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3}, {h: c4}, {h: c5}, {h: c2, remove: true}},
|
||||
results: []string{"c1", "c5", "c3", "c4"},
|
||||
results: []string{"c1", "c3", "c4", "c5"},
|
||||
},
|
||||
{
|
||||
callbacks: []callback{{h: c1}, {name: "c", h: c2}, {h: c3}, {name: "c", h: c4, replace: true}},
|
||||
@ -206,3 +206,49 @@ func TestPluginCallbacks(t *testing.T) {
|
||||
t.Errorf("callbacks tests failed, got %v", msg)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCallbacksGet(t *testing.T) {
|
||||
db, _ := gorm.Open(nil, nil)
|
||||
createCallback := db.Callback().Create()
|
||||
|
||||
createCallback.Before("*").Register("c1", c1)
|
||||
if cb := createCallback.Get("c1"); reflect.DeepEqual(cb, c1) {
|
||||
t.Errorf("callbacks tests failed, got: %p, want: %p", cb, c1)
|
||||
}
|
||||
|
||||
createCallback.Remove("c1")
|
||||
if cb := createCallback.Get("c2"); cb != nil {
|
||||
t.Errorf("callbacks test failed. got: %p, want: nil", cb)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCallbacksRemove(t *testing.T) {
|
||||
db, _ := gorm.Open(nil, nil)
|
||||
createCallback := db.Callback().Create()
|
||||
|
||||
createCallback.Before("*").Register("c1", c1)
|
||||
createCallback.After("*").Register("c2", c2)
|
||||
createCallback.Before("c4").Register("c3", c3)
|
||||
createCallback.After("c2").Register("c4", c4)
|
||||
|
||||
// callbacks: []string{"c1", "c3", "c4", "c2"}
|
||||
createCallback.Remove("c1")
|
||||
if ok, msg := assertCallbacks(createCallback, []string{"c3", "c4", "c2"}); !ok {
|
||||
t.Errorf("callbacks tests failed, got %v", msg)
|
||||
}
|
||||
|
||||
createCallback.Remove("c4")
|
||||
if ok, msg := assertCallbacks(createCallback, []string{"c3", "c2"}); !ok {
|
||||
t.Errorf("callbacks tests failed, got %v", msg)
|
||||
}
|
||||
|
||||
createCallback.Remove("c2")
|
||||
if ok, msg := assertCallbacks(createCallback, []string{"c3"}); !ok {
|
||||
t.Errorf("callbacks tests failed, got %v", msg)
|
||||
}
|
||||
|
||||
createCallback.Remove("c3")
|
||||
if ok, msg := assertCallbacks(createCallback, []string{}); !ok {
|
||||
t.Errorf("callbacks tests failed, got %v", msg)
|
||||
}
|
||||
}
|
||||
|
88
tests/check_subset_model_change_test.go
Normal file
88
tests/check_subset_model_change_test.go
Normal file
@ -0,0 +1,88 @@
|
||||
package tests_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type Man struct {
|
||||
ID int
|
||||
Age int
|
||||
Name string
|
||||
Detail string
|
||||
}
|
||||
|
||||
// Panic-safe BeforeUpdate hook that checks for Changed("age")
|
||||
func (m *Man) BeforeUpdate(tx *gorm.DB) (err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = fmt.Errorf("panic in BeforeUpdate: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
if !tx.Statement.Changed("age") {
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Man) update(data interface{}) error {
|
||||
return DB.Set("data", data).Model(m).Where("id = ?", m.ID).Updates(data).Error
|
||||
}
|
||||
|
||||
func TestBeforeUpdateStatementChanged(t *testing.T) {
|
||||
DB.AutoMigrate(&Man{})
|
||||
type TestCase struct {
|
||||
BaseObjects Man
|
||||
change interface{}
|
||||
expectError bool
|
||||
}
|
||||
|
||||
testCases := []TestCase{
|
||||
{
|
||||
BaseObjects: Man{ID: 1, Age: 18, Name: "random-name"},
|
||||
change: struct {
|
||||
Age int
|
||||
}{Age: 20},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
BaseObjects: Man{ID: 2, Age: 18, Name: "random-name"},
|
||||
change: struct {
|
||||
Name string
|
||||
}{Name: "name-only"},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
BaseObjects: Man{ID: 2, Age: 18, Name: "random-name"},
|
||||
change: struct {
|
||||
Name string
|
||||
Age int
|
||||
}{Name: "name-only", Age: 20},
|
||||
expectError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
DB.Create(&test.BaseObjects)
|
||||
|
||||
// below comment is stored for future reference
|
||||
// err := DB.Set("data", test.change).Model(&test.BaseObjects).Where("id = ?", test.BaseObjects.ID).Updates(test.change).Error
|
||||
err := test.BaseObjects.update(test.change)
|
||||
if strings.Contains(fmt.Sprint(err), "panic in BeforeUpdate") {
|
||||
if !test.expectError {
|
||||
t.Errorf("unexpected panic in BeforeUpdate for input: %+v\nerror: %v", test.change, err)
|
||||
}
|
||||
} else {
|
||||
if test.expectError {
|
||||
t.Errorf("expected panic did not occur for input: %+v", test.change)
|
||||
}
|
||||
if err != nil {
|
||||
t.Errorf("unexpected GORM error: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@ -1,10 +1,8 @@
|
||||
version: '3'
|
||||
|
||||
services:
|
||||
mysql:
|
||||
image: 'mysql/mysql-server:latest'
|
||||
image: 'mysql:latest'
|
||||
ports:
|
||||
- "9910:3306"
|
||||
- "127.0.0.1:9910:3306"
|
||||
environment:
|
||||
- MYSQL_DATABASE=gorm
|
||||
- MYSQL_USER=gorm
|
||||
@ -13,24 +11,22 @@ services:
|
||||
postgres:
|
||||
image: 'postgres:latest'
|
||||
ports:
|
||||
- "9920:5432"
|
||||
- "127.0.0.1:9920:5432"
|
||||
environment:
|
||||
- TZ=Asia/Shanghai
|
||||
- POSTGRES_DB=gorm
|
||||
- POSTGRES_USER=gorm
|
||||
- POSTGRES_PASSWORD=gorm
|
||||
mssql:
|
||||
image: '${MSSQL_IMAGE:-mcmoe/mssqldocker}:latest'
|
||||
image: '${MSSQL_IMAGE}:latest'
|
||||
ports:
|
||||
- "9930:1433"
|
||||
- "127.0.0.1:9930:1433"
|
||||
environment:
|
||||
- TZ=Asia/Shanghai
|
||||
- ACCEPT_EULA=Y
|
||||
- SA_PASSWORD=LoremIpsum86
|
||||
- MSSQL_DB=gorm
|
||||
- MSSQL_USER=gorm
|
||||
- MSSQL_PASSWORD=LoremIpsum86
|
||||
- MSSQL_SA_PASSWORD=LoremIpsum86
|
||||
tidb:
|
||||
image: 'pingcap/tidb:v6.5.0'
|
||||
ports:
|
||||
- "9940:4000"
|
||||
- "127.0.0.1:9940:4000"
|
||||
command: /tidb-server -store unistore -path "" -lease 0s > tidb.log 2>&1 &
|
@ -119,6 +119,7 @@ func TestConnPoolWrapper(t *testing.T) {
|
||||
}()
|
||||
|
||||
db, err := gorm.Open(mysql.New(mysql.Config{Conn: conn, DisableWithReturning: true}))
|
||||
db.Logger = DB.Logger
|
||||
if err != nil {
|
||||
t.Fatalf("Should open db success, but got %v", err)
|
||||
}
|
||||
|
@ -14,31 +14,48 @@ import (
|
||||
)
|
||||
|
||||
func TestCreate(t *testing.T) {
|
||||
user := *GetUser("create", Config{})
|
||||
u1 := *GetUser("create", Config{})
|
||||
|
||||
if results := DB.Create(&user); results.Error != nil {
|
||||
if results := DB.Create(&u1); results.Error != nil {
|
||||
t.Fatalf("errors happened when create: %v", results.Error)
|
||||
} else if results.RowsAffected != 1 {
|
||||
t.Fatalf("rows affected expects: %v, got %v", 1, results.RowsAffected)
|
||||
}
|
||||
|
||||
if user.ID == 0 {
|
||||
t.Errorf("user's primary key should has value after create, got : %v", user.ID)
|
||||
if u1.ID == 0 {
|
||||
t.Errorf("user's primary key should has value after create, got : %v", u1.ID)
|
||||
}
|
||||
|
||||
if user.CreatedAt.IsZero() {
|
||||
if u1.CreatedAt.IsZero() {
|
||||
t.Errorf("user's created at should be not zero")
|
||||
}
|
||||
|
||||
if user.UpdatedAt.IsZero() {
|
||||
if u1.UpdatedAt.IsZero() {
|
||||
t.Errorf("user's updated at should be not zero")
|
||||
}
|
||||
|
||||
var newUser User
|
||||
if err := DB.Where("id = ?", user.ID).First(&newUser).Error; err != nil {
|
||||
if err := DB.Where("id = ?", u1.ID).First(&newUser).Error; err != nil {
|
||||
t.Fatalf("errors happened when query: %v", err)
|
||||
} else {
|
||||
CheckUser(t, newUser, user)
|
||||
CheckUser(t, newUser, u1)
|
||||
}
|
||||
|
||||
type user struct {
|
||||
ID int `gorm:"primaryKey;->:false"`
|
||||
Name string
|
||||
Age int
|
||||
}
|
||||
|
||||
var u2 user
|
||||
if results := DB.Create(&u2); results.Error != nil {
|
||||
t.Fatalf("errors happened when create: %v", results.Error)
|
||||
} else if results.RowsAffected != 1 {
|
||||
t.Fatalf("rows affected expects: %v, got %v", 1, results.RowsAffected)
|
||||
}
|
||||
|
||||
if u2.ID != 0 {
|
||||
t.Errorf("don't have the permission to read primary key from db, but got %v", u2.ID)
|
||||
}
|
||||
}
|
||||
|
||||
@ -713,18 +730,16 @@ func TestCreateFromMapWithoutPK(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestCreateFromMapWithTable(t *testing.T) {
|
||||
if !isMysql() {
|
||||
t.Skipf("This test case skipped, because of only supportting for mysql")
|
||||
}
|
||||
tableDB := DB.Table("`users`")
|
||||
tableDB := DB.Table("users")
|
||||
supportLastInsertID := isMysql() || isSqlite()
|
||||
|
||||
// case 1: create from map[string]interface{}
|
||||
record := map[string]interface{}{"`name`": "create_from_map_with_table", "`age`": 18}
|
||||
record := map[string]interface{}{"name": "create_from_map_with_table", "age": 18}
|
||||
if err := tableDB.Create(record).Error; err != nil {
|
||||
t.Fatalf("failed to create data from map with table, got error: %v", err)
|
||||
}
|
||||
|
||||
if _, ok := record["@id"]; !ok {
|
||||
if _, ok := record["@id"]; !ok && supportLastInsertID {
|
||||
t.Fatal("failed to create data from map with table, returning map has no key '@id'")
|
||||
}
|
||||
|
||||
@ -733,8 +748,8 @@ func TestCreateFromMapWithTable(t *testing.T) {
|
||||
t.Fatalf("failed to create from map, got error %v", err)
|
||||
}
|
||||
|
||||
if int64(res["id"].(uint64)) != record["@id"] {
|
||||
t.Fatal("failed to create data from map with table, @id != id")
|
||||
if _, ok := record["@id"]; ok && fmt.Sprint(res["id"]) != fmt.Sprint(record["@id"]) {
|
||||
t.Fatalf("failed to create data from map with table, @id != id, got %v, expect %v", res["id"], record["@id"])
|
||||
}
|
||||
|
||||
// case 2: create from *map[string]interface{}
|
||||
@ -743,7 +758,7 @@ func TestCreateFromMapWithTable(t *testing.T) {
|
||||
if err := tableDB2.Create(&record1).Error; err != nil {
|
||||
t.Fatalf("failed to create data from map, got error: %v", err)
|
||||
}
|
||||
if _, ok := record1["@id"]; !ok {
|
||||
if _, ok := record1["@id"]; !ok && supportLastInsertID {
|
||||
t.Fatal("failed to create data from map with table, returning map has no key '@id'")
|
||||
}
|
||||
|
||||
@ -752,7 +767,7 @@ func TestCreateFromMapWithTable(t *testing.T) {
|
||||
t.Fatalf("failed to create from map, got error %v", err)
|
||||
}
|
||||
|
||||
if int64(res1["id"].(uint64)) != record1["@id"] {
|
||||
if _, ok := record1["@id"]; ok && fmt.Sprint(res1["id"]) != fmt.Sprint(record1["@id"]) {
|
||||
t.Fatal("failed to create data from map with table, @id != id")
|
||||
}
|
||||
|
||||
@ -767,11 +782,11 @@ func TestCreateFromMapWithTable(t *testing.T) {
|
||||
t.Fatalf("failed to create data from slice of map, got error: %v", err)
|
||||
}
|
||||
|
||||
if _, ok := records[0]["@id"]; !ok {
|
||||
if _, ok := records[0]["@id"]; !ok && supportLastInsertID {
|
||||
t.Fatal("failed to create data from map with table, returning map has no key '@id'")
|
||||
}
|
||||
|
||||
if _, ok := records[1]["@id"]; !ok {
|
||||
if _, ok := records[1]["@id"]; !ok && supportLastInsertID {
|
||||
t.Fatal("failed to create data from map with table, returning map has no key '@id'")
|
||||
}
|
||||
|
||||
@ -785,11 +800,11 @@ func TestCreateFromMapWithTable(t *testing.T) {
|
||||
t.Fatalf("failed to query data after create from slice of map, got error %v", err)
|
||||
}
|
||||
|
||||
if int64(res2["id"].(uint64)) != records[0]["@id"] {
|
||||
t.Fatal("failed to create data from map with table, @id != id")
|
||||
if _, ok := records[0]["@id"]; ok && fmt.Sprint(res2["id"]) != fmt.Sprint(records[0]["@id"]) {
|
||||
t.Errorf("failed to create data from map with table, @id != id, got %v, expect %v", res2["id"], records[0]["@id"])
|
||||
}
|
||||
|
||||
if int64(res3["id"].(uint64)) != records[1]["@id"] {
|
||||
t.Fatal("failed to create data from map with table, @id != id")
|
||||
if _, ok := records[1]["id"]; ok && fmt.Sprint(res3["id"]) != fmt.Sprint(records[1]["@id"]) {
|
||||
t.Errorf("failed to create data from map with table, @id != id")
|
||||
}
|
||||
}
|
||||
|
@ -38,4 +38,22 @@ func TestDefaultValue(t *testing.T) {
|
||||
} else if result.Name != "foo" || result.Name2 != "foo" || result.Name3 != "" || result.Age != 18 || !result.Enabled || result.Created.Format("20060102") != "20000102" {
|
||||
t.Fatalf("Failed to find created data with default data, got %+v", result)
|
||||
}
|
||||
|
||||
type Harumph2 struct {
|
||||
ID int `gorm:"default:0"`
|
||||
Email string `gorm:"not null;index:,unique"`
|
||||
Name string `gorm:"notNull;default:foo"`
|
||||
Name2 string `gorm:"size:233;not null;default:'foo'"`
|
||||
Name3 string `gorm:"size:233;notNull;default:''"`
|
||||
Age int `gorm:"default:18"`
|
||||
Created time.Time `gorm:"default:2000-01-02"`
|
||||
Enabled bool `gorm:"default:true"`
|
||||
}
|
||||
|
||||
harumph2 := Harumph2{ID: 2, Email: "hello2@gorm.io"}
|
||||
if err := DB.Table("harumphs").Create(&harumph2).Error; err != nil {
|
||||
t.Fatalf("Failed to create data with default value, got error: %v", err)
|
||||
} else if harumph2.ID != 2 || harumph2.Name != "foo" || harumph2.Name2 != "foo" || harumph2.Name3 != "" || harumph2.Age != 18 || !harumph2.Enabled || harumph2.Created.Format("20060102") != "20000102" {
|
||||
t.Fatalf("Failed to create data with default value, got: %+v", harumph2)
|
||||
}
|
||||
}
|
||||
|
@ -206,9 +206,9 @@ func TestDeleteSliceWithAssociations(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// only sqlite, postgres, sqlserver support returning
|
||||
// only sqlite, postgres, gaussdb, sqlserver support returning
|
||||
func TestSoftDeleteReturning(t *testing.T) {
|
||||
if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" && DB.Dialector.Name() != "sqlserver" {
|
||||
if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" && DB.Dialector.Name() != "gaussdb" && DB.Dialector.Name() != "sqlserver" {
|
||||
return
|
||||
}
|
||||
|
||||
@ -233,7 +233,7 @@ func TestSoftDeleteReturning(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDeleteReturning(t *testing.T) {
|
||||
if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" && DB.Dialector.Name() != "sqlserver" {
|
||||
if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" && DB.Dialector.Name() != "gaussdb" && DB.Dialector.Name() != "sqlserver" {
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -279,6 +279,6 @@ func TestEmbeddedTagSetting(t *testing.T) {
|
||||
err = DB.Save(&t1).Error
|
||||
AssertEqual(t, err, nil)
|
||||
if t1.Tag1.Id == 0 {
|
||||
t.Errorf("embedded struct's primary field should be rewrited")
|
||||
t.Errorf("embedded struct's primary field should be rewritten")
|
||||
}
|
||||
}
|
||||
|
@ -39,7 +39,7 @@ func TestSupportedDialectorWithErrDuplicatedKey(t *testing.T) {
|
||||
t.Fatalf("failed to connect database, got error %v", err)
|
||||
}
|
||||
|
||||
dialectors := map[string]bool{"sqlite": true, "postgres": true, "mysql": true, "sqlserver": true}
|
||||
dialectors := map[string]bool{"sqlite": true, "postgres": true, "gaussdb": true, "mysql": true, "sqlserver": true}
|
||||
if supported, found := dialectors[db.Dialector.Name()]; !(found && supported) {
|
||||
return
|
||||
}
|
||||
@ -81,7 +81,7 @@ func TestSupportedDialectorWithErrForeignKeyViolated(t *testing.T) {
|
||||
t.Fatalf("failed to connect database, got error %v", err)
|
||||
}
|
||||
|
||||
dialectors := map[string]bool{"sqlite": true, "postgres": true, "mysql": true, "sqlserver": true}
|
||||
dialectors := map[string]bool{"sqlite": true, "postgres": true, "gaussdb": true, "mysql": true, "sqlserver": true}
|
||||
if supported, found := dialectors[db.Dialector.Name()]; !(found && supported) {
|
||||
return
|
||||
}
|
||||
|
248
tests/gaussdb_test.go
Normal file
248
tests/gaussdb_test.go
Normal file
@ -0,0 +1,248 @@
|
||||
package tests_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/lib/pq"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
func TestGaussDBReturningIDWhichHasStringType(t *testing.T) {
|
||||
t.Skipf("This test case skipped, because of gaussdb not support pgcrypto extension and gen_random_uuid() function")
|
||||
if DB.Dialector.Name() != "gaussdb" {
|
||||
t.Skip()
|
||||
}
|
||||
|
||||
type Yasuo struct {
|
||||
// TODO: function gen_random_uuid() does not exist
|
||||
ID string `gorm:"default:gen_random_uuid()"`
|
||||
Name string
|
||||
CreatedAt time.Time `gorm:"type:TIMESTAMP WITHOUT TIME ZONE"`
|
||||
UpdatedAt time.Time `gorm:"type:TIMESTAMP WITHOUT TIME ZONE;default:current_timestamp"`
|
||||
}
|
||||
|
||||
if err := DB.Exec("CREATE EXTENSION IF NOT EXISTS pgcrypto;").Error; err != nil {
|
||||
t.Errorf("Failed to create extension pgcrypto, got error %v", err)
|
||||
}
|
||||
|
||||
DB.Migrator().DropTable(&Yasuo{})
|
||||
|
||||
if err := DB.AutoMigrate(&Yasuo{}); err != nil {
|
||||
t.Fatalf("Failed to migrate for uuid default value, got error: %v", err)
|
||||
}
|
||||
|
||||
yasuo := Yasuo{Name: "jinzhu"}
|
||||
if err := DB.Create(&yasuo).Error; err != nil {
|
||||
t.Fatalf("should be able to create data, but got %v", err)
|
||||
}
|
||||
|
||||
if yasuo.ID == "" {
|
||||
t.Fatal("should be able to has ID, but got zero value")
|
||||
}
|
||||
|
||||
var result Yasuo
|
||||
if err := DB.First(&result, "id = ?", yasuo.ID).Error; err != nil || yasuo.Name != "jinzhu" {
|
||||
t.Errorf("No error should happen, but got %v", err)
|
||||
}
|
||||
|
||||
if err := DB.Where("id = $1", yasuo.ID).First(&Yasuo{}).Error; err != nil || yasuo.Name != "jinzhu" {
|
||||
t.Errorf("No error should happen, but got %v", err)
|
||||
}
|
||||
|
||||
yasuo.Name = "jinzhu1"
|
||||
if err := DB.Save(&yasuo).Error; err != nil {
|
||||
t.Errorf("Failed to update date, got error %v", err)
|
||||
}
|
||||
|
||||
if err := DB.First(&result, "id = ?", yasuo.ID).Error; err != nil || yasuo.Name != "jinzhu1" {
|
||||
t.Errorf("No error should happen, but got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGaussDB(t *testing.T) {
|
||||
t.Skipf("This test case skipped, because of gaussdb not support pgcrypto extension and gen_random_uuid() function")
|
||||
if DB.Dialector.Name() != "gaussdb" {
|
||||
t.Skip()
|
||||
}
|
||||
|
||||
type Harumph struct {
|
||||
gorm.Model
|
||||
Name string `gorm:"check:name_checker,name <> ''"`
|
||||
// TODO: function gen_random_uuid() does not exist
|
||||
Test uuid.UUID `gorm:"type:uuid;not null;default:gen_random_uuid()"`
|
||||
CreatedAt time.Time `gorm:"type:TIMESTAMP WITHOUT TIME ZONE"`
|
||||
UpdatedAt time.Time `gorm:"type:TIMESTAMP WITHOUT TIME ZONE;default:current_timestamp"`
|
||||
Things pq.StringArray `gorm:"type:text[]"`
|
||||
}
|
||||
|
||||
if err := DB.Exec("CREATE EXTENSION IF NOT EXISTS pgcrypto;").Error; err != nil {
|
||||
t.Errorf("Failed to create extension pgcrypto, got error %v", err)
|
||||
}
|
||||
|
||||
DB.Migrator().DropTable(&Harumph{})
|
||||
|
||||
if err := DB.AutoMigrate(&Harumph{}); err != nil {
|
||||
t.Fatalf("Failed to migrate for uuid default value, got error: %v", err)
|
||||
}
|
||||
|
||||
harumph := Harumph{}
|
||||
if err := DB.Create(&harumph).Error; err == nil {
|
||||
t.Fatalf("should failed to create data, name can't be blank")
|
||||
}
|
||||
|
||||
harumph = Harumph{Name: "jinzhu"}
|
||||
if err := DB.Create(&harumph).Error; err != nil {
|
||||
t.Fatalf("should be able to create data, but got %v", err)
|
||||
}
|
||||
|
||||
var result Harumph
|
||||
if err := DB.First(&result, "id = ?", harumph.ID).Error; err != nil || harumph.Name != "jinzhu" {
|
||||
t.Errorf("No error should happen, but got %v", err)
|
||||
}
|
||||
|
||||
if err := DB.Where("id = $1", harumph.ID).First(&Harumph{}).Error; err != nil || harumph.Name != "jinzhu" {
|
||||
t.Errorf("No error should happen, but got %v", err)
|
||||
}
|
||||
|
||||
harumph.Name = "jinzhu1"
|
||||
if err := DB.Save(&harumph).Error; err != nil {
|
||||
t.Errorf("Failed to update date, got error %v", err)
|
||||
}
|
||||
|
||||
if err := DB.First(&result, "id = ?", harumph.ID).Error; err != nil || harumph.Name != "jinzhu1" {
|
||||
t.Errorf("No error should happen, but got %v", err)
|
||||
}
|
||||
|
||||
DB.Migrator().DropTable("log_usage")
|
||||
|
||||
if err := DB.Exec(`
|
||||
CREATE TABLE public.log_usage (
|
||||
log_id bigint NOT NULL
|
||||
);
|
||||
|
||||
ALTER TABLE public.log_usage ALTER COLUMN log_id ADD GENERATED BY DEFAULT AS IDENTITY (
|
||||
SEQUENCE NAME public.log_usage_log_id_seq
|
||||
START WITH 1
|
||||
INCREMENT BY 1
|
||||
NO MINVALUE
|
||||
NO MAXVALUE
|
||||
CACHE 1
|
||||
);
|
||||
`).Error; err != nil {
|
||||
t.Fatalf("failed to create table, got error %v", err)
|
||||
}
|
||||
|
||||
columns, err := DB.Migrator().ColumnTypes("log_usage")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get columns, got error %v", err)
|
||||
}
|
||||
|
||||
hasLogID := false
|
||||
for _, column := range columns {
|
||||
if column.Name() == "log_id" {
|
||||
hasLogID = true
|
||||
autoIncrement, ok := column.AutoIncrement()
|
||||
if !ok || !autoIncrement {
|
||||
t.Fatalf("column log_id should be auto incrementment")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !hasLogID {
|
||||
t.Fatalf("failed to found column log_id")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGaussDBMany2ManyWithDefaultValueUUID(t *testing.T) {
|
||||
t.Skipf("This test case skipped, because of gaussdb does not have 'uuid-ossp' extension")
|
||||
if DB.Dialector.Name() != "gaussdb" {
|
||||
t.Skip()
|
||||
}
|
||||
|
||||
if err := DB.Exec(`create extension if not exists "uuid-ossp"`).Error; err != nil {
|
||||
t.Fatalf("Failed to create 'uuid-ossp' extension, but got error %v", err)
|
||||
}
|
||||
|
||||
DB.Migrator().DropTable(&Post{}, &Category{}, "post_categories")
|
||||
DB.AutoMigrate(&Post{}, &Category{})
|
||||
|
||||
post := Post{
|
||||
Title: "Hello World",
|
||||
Categories: []*Category{
|
||||
{Title: "Coding"},
|
||||
{Title: "Golang"},
|
||||
},
|
||||
}
|
||||
|
||||
if err := DB.Create(&post).Error; err != nil {
|
||||
t.Errorf("Failed, got error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGaussDBOnConstraint(t *testing.T) {
|
||||
t.Skipf("This test case skipped, because of gaussdb not support 'ON CONSTRAINT' statement")
|
||||
if DB.Dialector.Name() != "gaussdb" {
|
||||
t.Skip()
|
||||
}
|
||||
|
||||
type Thing struct {
|
||||
gorm.Model
|
||||
SomeID string
|
||||
OtherID string
|
||||
Data string
|
||||
}
|
||||
|
||||
DB.Migrator().DropTable(&Thing{})
|
||||
DB.Migrator().CreateTable(&Thing{})
|
||||
if err := DB.Exec("ALTER TABLE things ADD CONSTRAINT some_id_other_id_unique UNIQUE (some_id, other_id)").Error; err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
thing := Thing{
|
||||
SomeID: "1234",
|
||||
OtherID: "1234",
|
||||
Data: "something",
|
||||
}
|
||||
|
||||
DB.Create(&thing)
|
||||
|
||||
thing2 := Thing{
|
||||
SomeID: "1234",
|
||||
OtherID: "1234",
|
||||
Data: "something else",
|
||||
}
|
||||
|
||||
result := DB.Clauses(clause.OnConflict{
|
||||
OnConstraint: "some_id_other_id_unique",
|
||||
UpdateAll: true,
|
||||
}).Create(&thing2)
|
||||
if result.Error != nil {
|
||||
t.Errorf("creating second thing: %v", result.Error)
|
||||
}
|
||||
|
||||
var things []Thing
|
||||
if err := DB.Find(&things).Error; err != nil {
|
||||
t.Errorf("Failed, got error: %v", err)
|
||||
}
|
||||
|
||||
if len(things) > 1 {
|
||||
t.Errorf("expected 1 thing got more")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGaussDBAlterColumnDataType(t *testing.T) {
|
||||
if DB.Dialector.Name() != "gaussdb" {
|
||||
t.Skip()
|
||||
}
|
||||
DB.Migrator().DropTable(&Company{})
|
||||
DB.AutoMigrate(Company{})
|
||||
if err := DB.Table("companies").Migrator().AlterColumn(CompanyNew{}, "name"); err != nil {
|
||||
t.Fatalf("failed to alter column from string to int, got error %v", err)
|
||||
}
|
||||
|
||||
DB.AutoMigrate(Company{})
|
||||
}
|
875
tests/generics_test.go
Normal file
875
tests/generics_test.go
Normal file
@ -0,0 +1,875 @@
|
||||
package tests_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"gorm.io/driver/mysql"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
func TestGenericsCreate(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
user := User{Name: "TestGenericsCreate", Age: 18}
|
||||
err := gorm.G[User](DB).Create(ctx, &user)
|
||||
if err != nil {
|
||||
t.Fatalf("Create failed: %v", err)
|
||||
}
|
||||
if user.ID == 0 {
|
||||
t.Fatalf("no primary key found for %v", user)
|
||||
}
|
||||
|
||||
if u, err := gorm.G[User](DB).Where("name = ?", user.Name).First(ctx); err != nil {
|
||||
t.Fatalf("failed to find user, got error: %v", err)
|
||||
} else if u.Name != user.Name || u.ID != user.ID {
|
||||
t.Errorf("found invalid user, got %v, expect %v", u, user)
|
||||
}
|
||||
|
||||
if u, err := gorm.G[User](DB).Where("name = ?", user.Name).Take(ctx); err != nil {
|
||||
t.Fatalf("failed to find user, got error: %v", err)
|
||||
} else if u.Name != user.Name || u.ID != user.ID {
|
||||
t.Errorf("found invalid user, got %v, expect %v", u, user)
|
||||
}
|
||||
|
||||
if u, err := gorm.G[User](DB).Select("name").Where("name = ?", user.Name).First(ctx); err != nil {
|
||||
t.Fatalf("failed to find user, got error: %v", err)
|
||||
} else if u.Name != user.Name || u.Age != 0 {
|
||||
t.Errorf("found invalid user, got %v, expect %v", u, user)
|
||||
}
|
||||
|
||||
if u, err := gorm.G[User](DB).Omit("name").Where("name = ?", user.Name).First(ctx); err != nil {
|
||||
t.Fatalf("failed to find user, got error: %v", err)
|
||||
} else if u.Name != "" || u.Age != user.Age {
|
||||
t.Errorf("found invalid user, got %v, expect %v", u, user)
|
||||
}
|
||||
|
||||
result := struct {
|
||||
ID int
|
||||
Name string
|
||||
}{}
|
||||
if err := gorm.G[User](DB).Where("name = ?", user.Name).Scan(ctx, &result); err != nil {
|
||||
t.Fatalf("failed to scan user, got error: %v", err)
|
||||
} else if result.Name != user.Name || uint(result.ID) != user.ID {
|
||||
t.Errorf("found invalid user, got %v, expect %v", result, user)
|
||||
}
|
||||
|
||||
mapResult, err := gorm.G[map[string]interface{}](DB).Table("users").Where("name = ?", user.Name).MapColumns(map[string]string{"name": "user_name"}).Take(ctx)
|
||||
if v := mapResult["user_name"]; fmt.Sprint(v) != user.Name {
|
||||
t.Errorf("failed to find map results, got %v, err %v", mapResult, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenericsCreateInBatches(t *testing.T) {
|
||||
batch := []User{
|
||||
{Name: "GenericsCreateInBatches1"},
|
||||
{Name: "GenericsCreateInBatches2"},
|
||||
{Name: "GenericsCreateInBatches3"},
|
||||
}
|
||||
ctx := context.Background()
|
||||
|
||||
if err := gorm.G[User](DB).CreateInBatches(ctx, &batch, 2); err != nil {
|
||||
t.Fatalf("CreateInBatches failed: %v", err)
|
||||
}
|
||||
|
||||
for _, u := range batch {
|
||||
if u.ID == 0 {
|
||||
t.Fatalf("no primary key found for %v", u)
|
||||
}
|
||||
}
|
||||
|
||||
count, err := gorm.G[User](DB).Where("name like ?", "GenericsCreateInBatches%").Count(ctx, "*")
|
||||
if err != nil {
|
||||
t.Fatalf("Count failed: %v", err)
|
||||
}
|
||||
if count != 3 {
|
||||
t.Errorf("expected 3 records, got %d", count)
|
||||
}
|
||||
|
||||
found, err := gorm.G[User](DB).Raw("SELECT * FROM users WHERE name LIKE ?", "GenericsCreateInBatches%").Find(ctx)
|
||||
if len(found) != len(batch) {
|
||||
t.Errorf("expected %d from Raw Find, got %d", len(batch), len(found))
|
||||
}
|
||||
|
||||
found, err = gorm.G[User](DB).Where("name like ?", "GenericsCreateInBatches%").Limit(2).Find(ctx)
|
||||
if len(found) != 2 {
|
||||
t.Errorf("expected %d from Raw Find, got %d", 2, len(found))
|
||||
}
|
||||
|
||||
found, err = gorm.G[User](DB).Where("name like ?", "GenericsCreateInBatches%").Offset(2).Limit(2).Find(ctx)
|
||||
if len(found) != 1 {
|
||||
t.Errorf("expected %d from Raw Find, got %d", 1, len(found))
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenericsExecAndUpdate(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
name := "GenericsExec"
|
||||
if err := gorm.G[User](DB).Exec(ctx, "INSERT INTO users(name) VALUES(?)", name); err != nil {
|
||||
t.Fatalf("Exec insert failed: %v", err)
|
||||
}
|
||||
|
||||
u, err := gorm.G[User](DB).Table("users as u").Where("u.name = ?", name).First(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to find user, got error: %v", err)
|
||||
} else if u.Name != name || u.ID == 0 {
|
||||
t.Errorf("found invalid user, got %v", u)
|
||||
}
|
||||
|
||||
name += "Update"
|
||||
rows, err := gorm.G[User](DB).Where("id = ?", u.ID).Update(ctx, "name", name)
|
||||
if rows != 1 {
|
||||
t.Fatalf("failed to get affected rows, got %d, should be %d", rows, 1)
|
||||
}
|
||||
|
||||
nu, err := gorm.G[User](DB).Where("name = ?", name).First(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to find user, got error: %v", err)
|
||||
} else if nu.Name != name || u.ID != nu.ID {
|
||||
t.Fatalf("found invalid user, got %v, expect %v", nu.ID, u.ID)
|
||||
}
|
||||
|
||||
rows, err = gorm.G[User](DB).Where("id = ?", u.ID).Updates(ctx, User{Name: "GenericsExecUpdates", Age: 18})
|
||||
if rows != 1 {
|
||||
t.Fatalf("failed to get affected rows, got %d, should be %d", rows, 1)
|
||||
}
|
||||
|
||||
nu, err = gorm.G[User](DB).Where("id = ?", u.ID).Last(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to find user, got error: %v", err)
|
||||
} else if nu.Name != "GenericsExecUpdates" || nu.Age != 18 || u.ID != nu.ID {
|
||||
t.Fatalf("found invalid user, got %v, expect %v", nu.ID, u.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenericsRow(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
user := User{Name: "GenericsRow"}
|
||||
if err := gorm.G[User](DB).Create(ctx, &user); err != nil {
|
||||
t.Fatalf("Create failed: %v", err)
|
||||
}
|
||||
|
||||
row := gorm.G[User](DB).Raw("SELECT name FROM users WHERE id = ?", user.ID).Row(ctx)
|
||||
var name string
|
||||
if err := row.Scan(&name); err != nil {
|
||||
t.Fatalf("Row scan failed: %v", err)
|
||||
}
|
||||
if name != user.Name {
|
||||
t.Errorf("expected %s, got %s", user.Name, name)
|
||||
}
|
||||
|
||||
user2 := User{Name: "GenericsRow2"}
|
||||
if err := gorm.G[User](DB).Create(ctx, &user2); err != nil {
|
||||
t.Fatalf("Create failed: %v", err)
|
||||
}
|
||||
rows, err := gorm.G[User](DB).Raw("SELECT name FROM users WHERE id IN ?", []uint{user.ID, user2.ID}).Rows(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Rows failed: %v", err)
|
||||
}
|
||||
|
||||
count := 0
|
||||
for rows.Next() {
|
||||
var name string
|
||||
if err := rows.Scan(&name); err != nil {
|
||||
t.Fatalf("rows.Scan failed: %v", err)
|
||||
}
|
||||
count++
|
||||
}
|
||||
if count != 2 {
|
||||
t.Errorf("expected 2 rows, got %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenericsDelete(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
u := User{Name: "GenericsDelete"}
|
||||
if err := gorm.G[User](DB).Create(ctx, &u); err != nil {
|
||||
t.Fatalf("Create failed: %v", err)
|
||||
}
|
||||
|
||||
rows, err := gorm.G[User](DB).Where("id = ?", u.ID).Delete(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Delete failed: %v", err)
|
||||
}
|
||||
if rows != 1 {
|
||||
t.Errorf("expected 1 row deleted, got %d", rows)
|
||||
}
|
||||
|
||||
_, err = gorm.G[User](DB).Where("id = ?", u.ID).First(ctx)
|
||||
if err != gorm.ErrRecordNotFound {
|
||||
t.Fatalf("User after delete failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenericsFindInBatches(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
users := []User{
|
||||
{Name: "GenericsFindBatchA"},
|
||||
{Name: "GenericsFindBatchB"},
|
||||
{Name: "GenericsFindBatchC"},
|
||||
{Name: "GenericsFindBatchD"},
|
||||
{Name: "GenericsFindBatchE"},
|
||||
}
|
||||
if err := gorm.G[User](DB).CreateInBatches(ctx, &users, len(users)); err != nil {
|
||||
t.Fatalf("CreateInBatches failed: %v", err)
|
||||
}
|
||||
|
||||
total := 0
|
||||
err := gorm.G[User](DB).Where("name like ?", "GenericsFindBatch%").FindInBatches(ctx, 2, func(chunk []User, batch int) error {
|
||||
if len(chunk) > 2 {
|
||||
t.Errorf("batch size exceed 2: got %d", len(chunk))
|
||||
}
|
||||
|
||||
total += len(chunk)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("FindInBatches failed: %v", err)
|
||||
}
|
||||
|
||||
if total != len(users) {
|
||||
t.Errorf("expected total %d, got %d", len(users), total)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenericsScopes(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
users := []User{{Name: "GenericsScopes1"}, {Name: "GenericsScopes2"}, {Name: "GenericsScopes3"}}
|
||||
err := gorm.G[User](DB).CreateInBatches(ctx, &users, len(users))
|
||||
if err != nil {
|
||||
t.Fatalf("CreateInBatches failed: %v", err)
|
||||
}
|
||||
|
||||
filterName1 := func(stmt *gorm.Statement) {
|
||||
stmt.Where("name = ?", "GenericsScopes1")
|
||||
}
|
||||
|
||||
results, err := gorm.G[User](DB).Scopes(filterName1).Find(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Scopes failed: %v", err)
|
||||
}
|
||||
if len(results) != 1 || results[0].Name != "GenericsScopes1" {
|
||||
t.Fatalf("Scopes expected 1, got %d", len(results))
|
||||
}
|
||||
|
||||
notResult, err := gorm.G[User](DB).Where("name like ?", "GenericsScopes%").Not("name = ?", "GenericsScopes1").Order("name").Find(ctx)
|
||||
if len(notResult) != 2 {
|
||||
t.Fatalf("expected 2 results, got %d", len(notResult))
|
||||
} else if notResult[0].Name != "GenericsScopes2" || notResult[1].Name != "GenericsScopes3" {
|
||||
t.Fatalf("expected names 'GenericsScopes2' and 'GenericsScopes3', got %s and %s", notResult[0].Name, notResult[1].Name)
|
||||
}
|
||||
|
||||
orResult, err := gorm.G[User](DB).Or("name = ?", "GenericsScopes1").Or("name = ?", "GenericsScopes2").Order("name").Find(ctx)
|
||||
if len(orResult) != 2 {
|
||||
t.Fatalf("expected 2 results, got %d", len(notResult))
|
||||
} else if orResult[0].Name != "GenericsScopes1" || orResult[1].Name != "GenericsScopes2" {
|
||||
t.Fatalf("expected names 'GenericsScopes2' and 'GenericsScopes3', got %s and %s", orResult[0].Name, orResult[1].Name)
|
||||
}
|
||||
}
|
||||
|
||||
// TestGenericsJoins exercises the generics Joins API: inner joins via
// clause.Has, left joins (with and without aliases), raw subquery joins via
// AssociationFrom, and error propagation from the join-builder callback.
func TestGenericsJoins(t *testing.T) {
	ctx := context.Background()
	db := gorm.G[User](DB)

	u := User{Name: "GenericsJoins", Company: Company{Name: "GenericsCompany"}}
	u2 := User{Name: "GenericsJoins_2", Company: Company{Name: "GenericsCompany_2"}}
	u3 := User{Name: "GenericsJoins_3", Company: Company{Name: "GenericsCompany_3"}}
	db.CreateInBatches(ctx, &[]User{u3, u, u2}, 10)

	// Inner JOIN + WHERE
	result, err := db.Joins(clause.Has("Company"), func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error {
		// Filter on the joined table; ? expands to its (possibly aliased) name.
		db.Where("?.name = ?", joinTable, u.Company.Name)
		return nil
	}).First(ctx)
	if err != nil {
		t.Fatalf("Joins failed: %v", err)
	}
	if result.Name != u.Name || result.Company.Name != u.Company.Name {
		t.Fatalf("Joins expected %s, got %+v", u.Name, result)
	}

	// Inner JOIN + WHERE with map
	result, err = db.Joins(clause.Has("Company"), func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error {
		db.Where(map[string]any{"name": u.Company.Name})
		return nil
	}).First(ctx)
	if err != nil {
		t.Fatalf("Joins failed: %v", err)
	}
	if result.Name != u.Name || result.Company.Name != u.Company.Name {
		t.Fatalf("Joins expected %s, got %+v", u.Name, result)
	}

	// Left JOIN w/o WHERE
	result, err = db.Joins(clause.LeftJoin.Association("Company"), nil).Where(map[string]any{"name": u.Name}).First(ctx)
	if err != nil {
		t.Fatalf("Joins failed: %v", err)
	}
	if result.Name != u.Name || result.Company.Name != u.Company.Name {
		t.Fatalf("Joins expected %s, got %+v", u.Name, result)
	}

	// Left JOIN + Alias WHERE; the builder must see the alias, not the table name.
	result, err = db.Joins(clause.LeftJoin.Association("Company").As("t"), func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error {
		if joinTable.Name != "t" {
			t.Fatalf("Join table should be t, but got %v", joinTable.Name)
		}
		db.Where("?.name = ?", joinTable, u.Company.Name)
		return nil
	}).Where(map[string]any{"name": u.Name}).First(ctx)
	if err != nil {
		t.Fatalf("Joins failed: %v", err)
	}
	if result.Name != u.Name || result.Company.Name != u.Company.Name {
		t.Fatalf("Joins expected %s, got %+v", u.Name, result)
	}

	// Raw Subquery JOIN + WHERE
	result, err = db.Joins(clause.LeftJoin.AssociationFrom("Company", gorm.G[Company](DB)).As("t"),
		func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error {
			if joinTable.Name != "t" {
				t.Fatalf("Join table should be t, but got %v", joinTable.Name)
			}
			db.Where("?.name = ?", joinTable, u.Company.Name)
			return nil
		},
	).Where(map[string]any{"name": u2.Name}).First(ctx)
	if err != nil {
		t.Fatalf("Raw subquery join failed: %v", err)
	}
	if result.Name != u2.Name || result.Company.Name != u.Company.Name || result.Company.ID == 0 {
		t.Fatalf("Joins expected %s, got %+v", u.Name, result)
	}

	// Raw Subquery JOIN + WHERE + Select
	result, err = db.Joins(clause.LeftJoin.AssociationFrom("Company", gorm.G[Company](DB).Select("Name")).As("t"),
		func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error {
			if joinTable.Name != "t" {
				t.Fatalf("Join table should be t, but got %v", joinTable.Name)
			}
			db.Where("?.name = ?", joinTable, u.Company.Name)
			return nil
		},
	).Where(map[string]any{"name": u2.Name}).First(ctx)
	if err != nil {
		t.Fatalf("Raw subquery join failed: %v", err)
	}
	// Select("Name") narrows the projection, so Company.ID stays zero here.
	if result.Name != u2.Name || result.Company.Name != u.Company.Name || result.Company.ID != 0 {
		t.Fatalf("Joins expected %s, got %+v", u.Name, result)
	}

	// An error returned from the join builder must surface to the caller.
	_, err = db.Joins(clause.Has("Company"), func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error {
		return errors.New("join error")
	}).First(ctx)
	if err == nil {
		t.Fatalf("Joins should got error, but got nil")
	}
}
|
||||
|
||||
// TestGenericsNestedJoins loads users together with nested associations
// (Manager, Manager.Company, Manager.NamedPet.Toy, NamedPet.Toy) through
// chained generics joins, then compares the loaded graph with the seed data.
func TestGenericsNestedJoins(t *testing.T) {
	users := []User{
		{
			Name: "generics-nested-joins-1",
			Manager: &User{
				Name: "generics-nested-joins-manager-1",
				Company: Company{
					Name: "generics-nested-joins-manager-company-1",
				},
				NamedPet: &Pet{
					Name: "generics-nested-joins-manager-namepet-1",
					Toy: Toy{
						Name: "generics-nested-joins-manager-namepet-toy-1",
					},
				},
			},
			NamedPet: &Pet{Name: "generics-nested-joins-namepet-1", Toy: Toy{Name: "generics-nested-joins-namepet-toy-1"}},
		},
		{
			Name:     "generics-nested-joins-2",
			Manager:  GetUser("generics-nested-joins-manager-2", Config{Company: true, NamedPet: true}),
			NamedPet: &Pet{Name: "generics-nested-joins-namepet-2", Toy: Toy{Name: "generics-nested-joins-namepet-toy-2"}},
		},
	}

	ctx := context.Background()
	db := gorm.G[User](DB)
	db.CreateInBatches(ctx, &users, 100)

	var userIDs []uint
	for _, user := range users {
		userIDs = append(userIDs, user.ID)
	}

	// Join every nested association in one query; the aliased NamedPet join
	// exercises As() alongside the dotted nested-association paths.
	users2, err := db.Joins(clause.LeftJoin.Association("Manager"), nil).
		Joins(clause.LeftJoin.Association("Manager.Company"), nil).
		Joins(clause.LeftJoin.Association("Manager.NamedPet.Toy"), nil).
		Joins(clause.LeftJoin.Association("NamedPet.Toy"), nil).
		Joins(clause.LeftJoin.Association("NamedPet").As("t"), nil).
		Where(map[string]any{"id": userIDs}).Find(ctx)

	if err != nil {
		t.Fatalf("Failed to load with joins, got error: %v", err)
	} else if len(users2) != len(users) {
		t.Fatalf("Failed to load join users, got: %v, expect: %v", len(users2), len(users))
	}

	// Sort both slices identically so they can be compared index by index.
	sort.Slice(users2, func(i, j int) bool {
		return users2[i].ID > users2[j].ID
	})

	sort.Slice(users, func(i, j int) bool {
		return users[i].ID > users[j].ID
	})

	for idx, user := range users {
		// user
		CheckUser(t, user, users2[idx])
		if users2[idx].Manager == nil {
			t.Fatalf("Failed to load Manager")
		}
		// manager
		CheckUser(t, *user.Manager, *users2[idx].Manager)
		// user pet
		if users2[idx].NamedPet == nil {
			t.Fatalf("Failed to load NamedPet")
		}
		CheckPet(t, *user.NamedPet, *users2[idx].NamedPet)
		// manager pet
		if users2[idx].Manager.NamedPet == nil {
			t.Fatalf("Failed to load NamedPet")
		}
		CheckPet(t, *user.Manager.NamedPet, *users2[idx].Manager.NamedPet)
	}
}
|
||||
|
||||
func TestGenericsPreloads(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
db := gorm.G[User](DB)
|
||||
|
||||
u := *GetUser("GenericsPreloads_1", Config{Company: true, Pets: 3, Friends: 7})
|
||||
u2 := *GetUser("GenericsPreloads_2", Config{Company: true, Pets: 5, Friends: 5})
|
||||
u3 := *GetUser("GenericsPreloads_3", Config{Company: true, Pets: 7, Friends: 3})
|
||||
names := []string{u.Name, u2.Name, u3.Name}
|
||||
|
||||
db.CreateInBatches(ctx, &[]User{u3, u, u2}, 10)
|
||||
|
||||
result, err := db.Preload("Company", nil).Preload("Pets", nil).Where("name = ?", u.Name).First(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Preload failed: %v", err)
|
||||
}
|
||||
|
||||
if result.Name != u.Name || result.Company.Name != u.Company.Name || len(result.Pets) != len(u.Pets) {
|
||||
t.Fatalf("Preload expected %s, got %+v", u.Name, result)
|
||||
}
|
||||
|
||||
results, err := db.Preload("Company", func(db gorm.PreloadBuilder) error {
|
||||
db.Where("name = ?", u.Company.Name)
|
||||
return nil
|
||||
}).Where("name in ?", names).Find(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Preload failed: %v", err)
|
||||
}
|
||||
for _, result := range results {
|
||||
if result.Name == u.Name {
|
||||
if result.Company.Name != u.Company.Name {
|
||||
t.Fatalf("Preload user %v company should be %v, but got %+v", u.Name, u.Company.Name, result.Company.Name)
|
||||
}
|
||||
} else if result.Company.Name != "" {
|
||||
t.Fatalf("Preload other company should not loaded, user %v company expect %v but got %+v", u.Name, u.Company.Name, result.Company.Name)
|
||||
}
|
||||
}
|
||||
|
||||
_, err = db.Preload("Company", func(db gorm.PreloadBuilder) error {
|
||||
return errors.New("preload error")
|
||||
}).Where("name in ?", names).Find(ctx)
|
||||
if err == nil {
|
||||
t.Fatalf("Preload should failed, but got nil")
|
||||
}
|
||||
|
||||
if DB.Dialector.Name() == "mysql" {
|
||||
// mysql 5.7 doesn't support row_number()
|
||||
if strings.HasPrefix(DB.Dialector.(*mysql.Dialector).ServerVersion, "5.7") {
|
||||
return
|
||||
}
|
||||
}
|
||||
results, err = db.Preload("Pets", func(db gorm.PreloadBuilder) error {
|
||||
db.LimitPerRecord(5)
|
||||
return nil
|
||||
}).Where("name in ?", names).Find(ctx)
|
||||
|
||||
for _, result := range results {
|
||||
if result.Name == u.Name {
|
||||
if len(result.Pets) != len(u.Pets) {
|
||||
t.Fatalf("Preload user %v pets should be %v, but got %+v", u.Name, u.Pets, result.Pets)
|
||||
}
|
||||
} else if len(result.Pets) != 5 {
|
||||
t.Fatalf("Preload user %v pets should be 5, but got %+v", result.Name, result.Pets)
|
||||
}
|
||||
}
|
||||
|
||||
if DB.Dialector.Name() == "sqlserver" {
|
||||
// sqlserver doesn't support order by in subquery
|
||||
return
|
||||
}
|
||||
results, err = db.Preload("Pets", func(db gorm.PreloadBuilder) error {
|
||||
db.Order("name desc").LimitPerRecord(5)
|
||||
return nil
|
||||
}).Where("name in ?", names).Find(ctx)
|
||||
|
||||
for _, result := range results {
|
||||
if result.Name == u.Name {
|
||||
if len(result.Pets) != len(u.Pets) {
|
||||
t.Fatalf("Preload user %v pets should be %v, but got %+v", u.Name, u.Pets, result.Pets)
|
||||
}
|
||||
} else if len(result.Pets) != 5 {
|
||||
t.Fatalf("Preload user %v pets should be 5, but got %+v", result.Name, result.Pets)
|
||||
}
|
||||
for i := 1; i < len(result.Pets); i++ {
|
||||
if result.Pets[i-1].Name < result.Pets[i].Name {
|
||||
t.Fatalf("Preload user %v pets not ordered correctly, last %v, cur %v", result.Name, result.Pets[i-1], result.Pets[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
results, err = db.Preload("Pets", func(db gorm.PreloadBuilder) error {
|
||||
db.Order("name").LimitPerRecord(5)
|
||||
return nil
|
||||
}).Preload("Friends", func(db gorm.PreloadBuilder) error {
|
||||
db.Order("name")
|
||||
return nil
|
||||
}).Where("name in ?", names).Find(ctx)
|
||||
|
||||
for _, result := range results {
|
||||
if result.Name == u.Name {
|
||||
if len(result.Pets) != len(u.Pets) {
|
||||
t.Fatalf("Preload user %v pets should be %v, but got %+v", u.Name, u.Pets, result.Pets)
|
||||
}
|
||||
if len(result.Friends) != len(u.Friends) {
|
||||
t.Fatalf("Preload user %v pets should be %v, but got %+v", u.Name, u.Pets, result.Pets)
|
||||
}
|
||||
} else if len(result.Pets) != 5 || len(result.Friends) == 0 {
|
||||
t.Fatalf("Preload user %v pets should be 5, but got %+v", result.Name, result.Pets)
|
||||
}
|
||||
for i := 1; i < len(result.Pets); i++ {
|
||||
if result.Pets[i-1].Name > result.Pets[i].Name {
|
||||
t.Fatalf("Preload user %v pets not ordered correctly, last %v, cur %v", result.Name, result.Pets[i-1], result.Pets[i])
|
||||
}
|
||||
}
|
||||
for i := 1; i < len(result.Pets); i++ {
|
||||
if result.Pets[i-1].Name > result.Pets[i].Name {
|
||||
t.Fatalf("Preload user %v friends not ordered correctly, last %v, cur %v", result.Name, result.Pets[i-1], result.Pets[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestGenericsNestedPreloads verifies nested preloading (Pets.Toy,
// Friends.Pets) through the generics API, plus LimitPerRecord on a nested
// association (skipped on dialects lacking the required SQL support).
func TestGenericsNestedPreloads(t *testing.T) {
	user := *GetUser("generics_nested_preload", Config{Pets: 2})
	user.Friends = []*User{GetUser("generics_nested_preload", Config{Pets: 5})}

	ctx := context.Background()
	db := gorm.G[User](DB)

	// Give each pet a toy so the nested Pets.Toy preload has data to load.
	for idx, pet := range user.Pets {
		pet.Toy = Toy{Name: "toy_nested_preload_" + strconv.Itoa(idx+1)}
	}

	if err := db.Create(ctx, &user); err != nil {
		t.Fatalf("errors happened when create: %v", err)
	}

	user2, err := db.Preload("Pets.Toy", nil).Preload("Friends.Pets", func(db gorm.PreloadBuilder) error {
		return nil
	}).Where(user.ID).Take(ctx)
	if err != nil {
		t.Errorf("failed to nested preload user")
	}
	CheckUser(t, user2, user)
	// NOTE(review): this inspects the in-memory `user`, not the freshly loaded
	// `user2`; presumably `user2` was intended — confirm.
	if len(user.Pets) == 0 || len(user.Friends) == 0 || len(user.Friends[0].Pets) == 0 {
		t.Fatalf("failed to nested preload")
	}

	if DB.Dialector.Name() == "mysql" {
		// mysql 5.7 doesn't support row_number()
		if strings.HasPrefix(DB.Dialector.(*mysql.Dialector).ServerVersion, "5.7") {
			return
		}
	}
	if DB.Dialector.Name() == "sqlserver" {
		// sqlserver doesn't support order by in subquery
		return
	}

	// LimitPerRecord(3) caps each friend's preloaded pets at three (the seed
	// friend has five).
	user3, err := db.Preload("Pets.Toy", nil).Preload("Friends.Pets", func(db gorm.PreloadBuilder) error {
		db.LimitPerRecord(3)
		return nil
	}).Where(user.ID).Take(ctx)
	if err != nil {
		t.Errorf("failed to nested preload user")
	}
	CheckUser(t, user3, user)

	if len(user3.Friends) != 1 || len(user3.Friends[0].Pets) != 3 {
		t.Errorf("failed to nested preload with limit per record")
	}
}
|
||||
|
||||
func TestGenericsDistinct(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
batch := []User{
|
||||
{Name: "GenericsDistinctDup"},
|
||||
{Name: "GenericsDistinctDup"},
|
||||
{Name: "GenericsDistinctUnique"},
|
||||
}
|
||||
if err := gorm.G[User](DB).CreateInBatches(ctx, &batch, len(batch)); err != nil {
|
||||
t.Fatalf("CreateInBatches failed: %v", err)
|
||||
}
|
||||
|
||||
results, err := gorm.G[User](DB).Where("name like ?", "GenericsDistinct%").Distinct("name").Find(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Distinct Find failed: %v", err)
|
||||
}
|
||||
|
||||
if len(results) != 2 {
|
||||
t.Errorf("expected 2 distinct names, got %d", len(results))
|
||||
}
|
||||
|
||||
var names []string
|
||||
for _, u := range results {
|
||||
names = append(names, u.Name)
|
||||
}
|
||||
sort.Strings(names)
|
||||
expected := []string{"GenericsDistinctDup", "GenericsDistinctUnique"}
|
||||
if !reflect.DeepEqual(names, expected) {
|
||||
t.Errorf("expected names %v, got %v", expected, names)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenericsGroupHaving(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
batch := []User{
|
||||
{Name: "GenericsGroupHavingMulti"},
|
||||
{Name: "GenericsGroupHavingMulti"},
|
||||
{Name: "GenericsGroupHavingSingle"},
|
||||
}
|
||||
if err := gorm.G[User](DB).CreateInBatches(ctx, &batch, len(batch)); err != nil {
|
||||
t.Fatalf("CreateInBatches failed: %v", err)
|
||||
}
|
||||
|
||||
grouped, err := gorm.G[User](DB).Select("name").Where("name like ?", "GenericsGroupHaving%").Group("name").Having("COUNT(id) > ?", 1).Find(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Group+Having Find failed: %v", err)
|
||||
}
|
||||
|
||||
if len(grouped) != 1 {
|
||||
t.Errorf("expected 1 group with count>1, got %d", len(grouped))
|
||||
} else if grouped[0].Name != "GenericsGroupHavingMulti" {
|
||||
t.Errorf("expected group name 'GenericsGroupHavingMulti', got '%s'", grouped[0].Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenericsSubQuery(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
users := []User{
|
||||
{Name: "GenericsSubquery_1", Age: 10},
|
||||
{Name: "GenericsSubquery_2", Age: 20},
|
||||
{Name: "GenericsSubquery_3", Age: 30},
|
||||
{Name: "GenericsSubquery_4", Age: 40},
|
||||
}
|
||||
|
||||
if err := gorm.G[User](DB).CreateInBatches(ctx, &users, len(users)); err != nil {
|
||||
t.Fatalf("CreateInBatches failed: %v", err)
|
||||
}
|
||||
|
||||
results, err := gorm.G[User](DB).Where("name IN (?)", gorm.G[User](DB).Select("name").Where("name LIKE ?", "GenericsSubquery%")).Find(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("got error: %v", err)
|
||||
}
|
||||
|
||||
if len(results) != 4 {
|
||||
t.Errorf("Four users should be found, instead found %d", len(results))
|
||||
}
|
||||
|
||||
results, err = gorm.G[User](DB).Where("name IN (?)", gorm.G[User](DB).Select("name").Where("name IN ?", []string{"GenericsSubquery_1", "GenericsSubquery_2"}).Or("name = ?", "GenericsSubquery_3")).Find(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("got error: %v", err)
|
||||
}
|
||||
|
||||
if len(results) != 3 {
|
||||
t.Errorf("Three users should be found, instead found %d", len(results))
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenericsUpsert(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
lang := Language{Code: "upsert", Name: "Upsert"}
|
||||
|
||||
if err := gorm.G[Language](DB, clause.OnConflict{DoNothing: true}).Create(ctx, &lang); err != nil {
|
||||
t.Fatalf("failed to upsert, got %v", err)
|
||||
}
|
||||
|
||||
lang2 := Language{Code: "upsert", Name: "Upsert"}
|
||||
if err := gorm.G[Language](DB, clause.OnConflict{DoNothing: true}).Create(ctx, &lang2); err != nil {
|
||||
t.Fatalf("failed to upsert, got %v", err)
|
||||
}
|
||||
|
||||
langs, err := gorm.G[Language](DB).Where("code = ?", lang.Code).Find(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("no error should happen when find languages with code, but got %v", err)
|
||||
} else if len(langs) != 1 {
|
||||
t.Errorf("should only find only 1 languages, but got %+v", langs)
|
||||
}
|
||||
|
||||
lang3 := Language{Code: "upsert", Name: "Upsert"}
|
||||
if err := gorm.G[Language](DB, clause.OnConflict{
|
||||
Columns: []clause.Column{{Name: "code"}},
|
||||
DoUpdates: clause.Assignments(map[string]interface{}{"name": "upsert-new"}),
|
||||
}).Create(ctx, &lang3); err != nil {
|
||||
t.Fatalf("failed to upsert, got %v", err)
|
||||
}
|
||||
|
||||
if langs, err := gorm.G[Language](DB).Where("code = ?", lang.Code).Find(ctx); err != nil {
|
||||
t.Errorf("no error should happen when find languages with code, but got %v", err)
|
||||
} else if len(langs) != 1 {
|
||||
t.Errorf("should only find only 1 languages, but got %+v", langs)
|
||||
} else if langs[0].Name != "upsert-new" {
|
||||
t.Errorf("should update name on conflict, but got name %+v", langs[0].Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenericsWithResult(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
users := []User{{Name: "TestGenericsWithResult", Age: 18}, {Name: "TestGenericsWithResult2", Age: 18}}
|
||||
|
||||
result := gorm.WithResult()
|
||||
err := gorm.G[User](DB, result).CreateInBatches(ctx, &users, 2)
|
||||
if err != nil {
|
||||
t.Errorf("failed to create users WithResult")
|
||||
}
|
||||
|
||||
if result.RowsAffected != 2 {
|
||||
t.Errorf("failed to get affected rows, got %d, should be %d", result.RowsAffected, 2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenericsReuse(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
users := []User{{Name: "TestGenericsReuse1", Age: 18}, {Name: "TestGenericsReuse2", Age: 18}}
|
||||
|
||||
err := gorm.G[User](DB).CreateInBatches(ctx, &users, 2)
|
||||
if err != nil {
|
||||
t.Errorf("failed to create users")
|
||||
}
|
||||
|
||||
reusedb := gorm.G[User](DB).Where("name like ?", "TestGenericsReuse%")
|
||||
|
||||
sg := sync.WaitGroup{}
|
||||
for i := 0; i < 5; i++ {
|
||||
sg.Add(1)
|
||||
|
||||
go func() {
|
||||
if u1, err := reusedb.Where("id = ?", users[0].ID).First(ctx); err != nil {
|
||||
t.Errorf("failed to find user, got error: %v", err)
|
||||
} else if u1.Name != users[0].Name || u1.ID != users[0].ID {
|
||||
t.Errorf("found invalid user, got %v, expect %v", u1, users[0])
|
||||
}
|
||||
|
||||
if u2, err := reusedb.Where("id = ?", users[1].ID).First(ctx); err != nil {
|
||||
t.Errorf("failed to find user, got error: %v", err)
|
||||
} else if u2.Name != users[1].Name || u2.ID != users[1].ID {
|
||||
t.Errorf("found invalid user, got %v, expect %v", u2, users[1])
|
||||
}
|
||||
|
||||
if users, err := reusedb.Where("id IN ?", []uint{users[0].ID, users[1].ID}).Find(ctx); err != nil {
|
||||
t.Errorf("failed to find user, got error: %v", err)
|
||||
} else if len(users) != 2 {
|
||||
t.Errorf("should find 2 users, but got %d", len(users))
|
||||
}
|
||||
sg.Done()
|
||||
}()
|
||||
}
|
||||
sg.Wait()
|
||||
}
|
||||
|
||||
func TestGenericsWithTransaction(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
tx := DB.Begin()
|
||||
if tx.Error != nil {
|
||||
t.Fatalf("failed to begin transaction: %v", tx.Error)
|
||||
}
|
||||
|
||||
users := []User{{Name: "TestGenericsTransaction", Age: 18}, {Name: "TestGenericsTransaction2", Age: 18}}
|
||||
err := gorm.G[User](tx).CreateInBatches(ctx, &users, 2)
|
||||
|
||||
count, err := gorm.G[User](tx).Where("name like ?", "TestGenericsTransaction%").Count(ctx, "*")
|
||||
if err != nil {
|
||||
t.Fatalf("Count failed: %v", err)
|
||||
}
|
||||
if count != 2 {
|
||||
t.Errorf("expected 2 records, got %d", count)
|
||||
}
|
||||
|
||||
if err := tx.Rollback().Error; err != nil {
|
||||
t.Fatalf("failed to rollback transaction: %v", err)
|
||||
}
|
||||
|
||||
count2, err := gorm.G[User](DB).Where("name like ?", "TestGenericsTransaction%").Count(ctx, "*")
|
||||
if err != nil {
|
||||
t.Fatalf("Count failed: %v", err)
|
||||
}
|
||||
if count2 != 0 {
|
||||
t.Errorf("expected 0 records after rollback, got %d", count2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenericsToSQL(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
sql := DB.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
||||
gorm.G[User](tx).Limit(10).Find(ctx)
|
||||
return tx
|
||||
})
|
||||
|
||||
if !regexp.MustCompile("SELECT \\* FROM .users..* 10").MatchString(sql) {
|
||||
t.Errorf("ToSQL: got wrong sql with Generics API %v", sql)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenericsScanUUID(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
users := []User{
|
||||
{Name: uuid.NewString(), Age: 21},
|
||||
{Name: uuid.NewString(), Age: 22},
|
||||
{Name: uuid.NewString(), Age: 23},
|
||||
}
|
||||
|
||||
if err := gorm.G[User](DB).CreateInBatches(ctx, &users, 2); err != nil {
|
||||
t.Fatalf("CreateInBatches failed: %v", err)
|
||||
}
|
||||
|
||||
userIds := []uuid.UUID{}
|
||||
if err := gorm.G[User](DB).Select("name").Where("id in ?", []uint{users[0].ID, users[1].ID, users[2].ID}).Order("age").Scan(ctx, &userIds); err != nil || len(users) != 3 {
|
||||
t.Fatalf("Scan failed: %v, userids %v", err, userIds)
|
||||
}
|
||||
|
||||
if userIds[0].String() != users[0].Name || userIds[1].String() != users[1].Name || userIds[2].String() != users[2].Name {
|
||||
t.Fatalf("wrong uuid scanned")
|
||||
}
|
||||
}
|
38
tests/go.mod
38
tests/go.mod
@ -1,38 +1,40 @@
|
||||
module gorm.io/gorm/tests
|
||||
|
||||
go 1.18
|
||||
go 1.23.0
|
||||
|
||||
require (
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/jinzhu/now v1.1.5
|
||||
github.com/lib/pq v1.10.9
|
||||
github.com/stretchr/testify v1.8.4
|
||||
gorm.io/driver/mysql v1.5.4
|
||||
gorm.io/driver/postgres v1.5.6
|
||||
gorm.io/driver/sqlite v1.5.5
|
||||
gorm.io/driver/sqlserver v1.5.3
|
||||
gorm.io/gorm v1.25.7-0.20240204074919-46816ad31dde
|
||||
github.com/stretchr/testify v1.10.0
|
||||
gorm.io/driver/gaussdb v0.1.0
|
||||
gorm.io/driver/mysql v1.6.0
|
||||
gorm.io/driver/postgres v1.6.0
|
||||
gorm.io/driver/sqlite v1.6.0
|
||||
gorm.io/driver/sqlserver v1.6.1
|
||||
gorm.io/gorm v1.30.0
|
||||
)
|
||||
|
||||
require (
|
||||
filippo.io/edwards25519 v1.1.0 // indirect
|
||||
github.com/HuaweiCloudDeveloper/gaussdb-go v1.0.0-rc1 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/go-sql-driver/mysql v1.7.1 // indirect
|
||||
github.com/go-sql-driver/mysql v1.9.3 // indirect
|
||||
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect
|
||||
github.com/golang-sql/sqlexp v0.1.0 // indirect
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 // indirect
|
||||
github.com/jackc/pgx/v5 v5.5.3 // indirect
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
||||
github.com/jackc/pgx/v5 v5.7.5 // indirect
|
||||
github.com/jackc/puddle/v2 v2.2.2 // indirect
|
||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||
github.com/kr/text v0.2.0 // indirect
|
||||
github.com/mattn/go-sqlite3 v1.14.22 // indirect
|
||||
github.com/microsoft/go-mssqldb v1.6.0 // indirect
|
||||
github.com/mattn/go-sqlite3 v1.14.28 // indirect
|
||||
github.com/microsoft/go-mssqldb v1.9.2 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/rogpeppe/go-internal v1.12.0 // indirect
|
||||
golang.org/x/crypto v0.18.0 // indirect
|
||||
golang.org/x/text v0.14.0 // indirect
|
||||
github.com/tjfoc/gmsm v1.4.1 // indirect
|
||||
golang.org/x/crypto v0.40.0 // indirect
|
||||
golang.org/x/sync v0.16.0 // indirect
|
||||
golang.org/x/text v0.27.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
replace gorm.io/gorm => ../
|
||||
|
||||
replace github.com/jackc/pgx/v5 => github.com/jackc/pgx/v5 v5.4.3
|
||||
|
@ -281,6 +281,10 @@ func isMysql() bool {
|
||||
return os.Getenv("GORM_DIALECT") == "mysql"
|
||||
}
|
||||
|
||||
func isSqlite() bool {
|
||||
return os.Getenv("GORM_DIALECT") == "sqlite"
|
||||
}
|
||||
|
||||
func db(unscoped bool) *gorm.DB {
|
||||
if unscoped {
|
||||
return DB.Unscoped()
|
||||
|
@ -2,6 +2,8 @@ package tests_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"log"
|
||||
"os"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
@ -566,3 +568,44 @@ func TestUpdateCallbacks(t *testing.T) {
|
||||
t.Fatalf("before update should not be called")
|
||||
}
|
||||
}
|
||||
|
||||
type Product6 struct {
|
||||
gorm.Model
|
||||
Name string
|
||||
Item *ProductItem2
|
||||
}
|
||||
|
||||
type ProductItem2 struct {
|
||||
gorm.Model
|
||||
Product6ID uint
|
||||
}
|
||||
|
||||
func (p *Product6) BeforeDelete(tx *gorm.DB) error {
|
||||
if err := tx.Delete(&p.Item).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestPropagateUnscoped(t *testing.T) {
|
||||
_DB, err := OpenTestConnection(&gorm.Config{
|
||||
PropagateUnscoped: true,
|
||||
})
|
||||
if err != nil {
|
||||
log.Printf("failed to connect database, got error %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
_DB.Migrator().DropTable(&Product6{}, &ProductItem2{})
|
||||
_DB.AutoMigrate(&Product6{}, &ProductItem2{})
|
||||
|
||||
p := Product6{
|
||||
Name: "unique_code",
|
||||
Item: &ProductItem2{},
|
||||
}
|
||||
_DB.Model(&Product6{}).Create(&p)
|
||||
|
||||
if err := _DB.Unscoped().Delete(&p).Error; err != nil {
|
||||
t.Fatalf("unscoped did not propagate")
|
||||
}
|
||||
}
|
||||
|
@ -1,10 +1,12 @@
|
||||
package tests_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"sort"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"gorm.io/gorm"
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
)
|
||||
@ -184,14 +186,12 @@ func TestJoinCount(t *testing.T) {
|
||||
DB.Create(&user)
|
||||
|
||||
query := DB.Model(&User{}).Joins("Company")
|
||||
// Bug happens when .Count is called on a query.
|
||||
// Removing the below two lines or downgrading to gorm v1.20.12 will make this test pass.
|
||||
|
||||
var total int64
|
||||
query.Count(&total)
|
||||
|
||||
var result User
|
||||
|
||||
// Incorrectly generates a 'SELECT *' query which causes companies.id to overwrite users.id
|
||||
if err := query.First(&result, user.ID).Error; err != nil {
|
||||
t.Fatalf("Failed, got error: %v", err)
|
||||
}
|
||||
@ -199,6 +199,10 @@ func TestJoinCount(t *testing.T) {
|
||||
if result.ID != user.ID {
|
||||
t.Fatalf("result's id, %d, doesn't match user's id, %d", result.ID, user.ID)
|
||||
}
|
||||
// should find company
|
||||
if result.Company.ID != *user.CompanyID {
|
||||
t.Fatalf("result's id, %d, doesn't match user's company id, %d", result.Company.ID, *user.CompanyID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestJoinWithSoftDeleted(t *testing.T) {
|
||||
@ -400,3 +404,75 @@ func TestNestedJoins(t *testing.T) {
|
||||
CheckPet(t, *user.Manager.NamedPet, *users2[idx].Manager.NamedPet)
|
||||
}
|
||||
}
|
||||
|
||||
func TestJoinsPreload_Issue7013(t *testing.T) {
|
||||
manager := &User{Name: "Manager"}
|
||||
DB.Create(manager)
|
||||
|
||||
var userIDs []uint
|
||||
for i := 0; i < 21; i++ {
|
||||
user := &User{Name: fmt.Sprintf("User%d", i), ManagerID: &manager.ID}
|
||||
DB.Create(user)
|
||||
userIDs = append(userIDs, user.ID)
|
||||
}
|
||||
|
||||
var entries []User
|
||||
assert.NotPanics(t, func() {
|
||||
assert.NoError(t,
|
||||
DB.Preload("Manager.Team").
|
||||
Joins("Manager.Company").
|
||||
Find(&entries).Error)
|
||||
})
|
||||
}
|
||||
|
||||
func TestJoinsPreload_Issue7013_RelationEmpty(t *testing.T) {
|
||||
type (
|
||||
Furniture struct {
|
||||
gorm.Model
|
||||
OwnerID *uint
|
||||
}
|
||||
|
||||
Owner struct {
|
||||
gorm.Model
|
||||
Furnitures []Furniture
|
||||
CompanyID *uint
|
||||
Company Company
|
||||
}
|
||||
|
||||
Building struct {
|
||||
gorm.Model
|
||||
Name string
|
||||
OwnerID *uint
|
||||
Owner Owner
|
||||
}
|
||||
)
|
||||
|
||||
DB.Migrator().DropTable(&Building{}, &Owner{}, &Furniture{})
|
||||
DB.Migrator().AutoMigrate(&Building{}, &Owner{}, &Furniture{})
|
||||
|
||||
home := &Building{Name: "relation_empty"}
|
||||
DB.Create(home)
|
||||
|
||||
var entries []Building
|
||||
assert.NotPanics(t, func() {
|
||||
assert.NoError(t,
|
||||
DB.Preload("Owner.Furnitures").
|
||||
Joins("Owner.Company").
|
||||
Find(&entries).Error)
|
||||
})
|
||||
|
||||
AssertEqual(t, entries, []Building{{Model: home.Model, Name: "relation_empty", Owner: Owner{Company: Company{}}}})
|
||||
}
|
||||
|
||||
func TestJoinsPreload_Issue7013_NoEntries(t *testing.T) {
|
||||
var entries []User
|
||||
assert.NotPanics(t, func() {
|
||||
assert.NoError(t,
|
||||
DB.Preload("Manager.Team").
|
||||
Joins("Manager.Company").
|
||||
Where("1 <> 1").
|
||||
Find(&entries).Error)
|
||||
})
|
||||
|
||||
AssertEqual(t, len(entries), 0)
|
||||
}
|
||||
|
529
tests/lru_test.go
Normal file
529
tests/lru_test.go
Normal file
@ -0,0 +1,529 @@
|
||||
package tests_test
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"gorm.io/gorm/internal/lru"
|
||||
"math"
|
||||
"math/big"
|
||||
"reflect"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestLRU_Add_ExistingKey_UpdatesValueAndExpiresAt(t *testing.T) {
|
||||
lru := lru.NewLRU[string, int](10, nil, time.Hour)
|
||||
lru.Add("key1", 1)
|
||||
lru.Add("key1", 2)
|
||||
|
||||
if value, ok := lru.Get("key1"); !ok || value != 2 {
|
||||
t.Errorf("Expected value to be updated to 2, got %v", value)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLRU_Add_NewKey_AddsEntry(t *testing.T) {
|
||||
lru := lru.NewLRU[string, int](10, nil, time.Hour)
|
||||
lru.Add("key1", 1)
|
||||
|
||||
if value, ok := lru.Get("key1"); !ok || value != 1 {
|
||||
t.Errorf("Expected key1 to be added with value 1, got %v", value)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLRU_Add_ExceedsSize_RemovesOldest(t *testing.T) {
|
||||
lru := lru.NewLRU[string, int](2, nil, time.Hour)
|
||||
lru.Add("key1", 1)
|
||||
lru.Add("key2", 2)
|
||||
lru.Add("key3", 3)
|
||||
|
||||
if _, ok := lru.Get("key1"); ok {
|
||||
t.Errorf("Expected key1 to be removed, but it still exists")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLRU_Add_UnlimitedSize_NoEviction(t *testing.T) {
|
||||
lru := lru.NewLRU[string, int](0, nil, time.Hour)
|
||||
lru.Add("key1", 1)
|
||||
lru.Add("key2", 2)
|
||||
lru.Add("key3", 3)
|
||||
|
||||
if _, ok := lru.Get("key1"); !ok {
|
||||
t.Errorf("Expected key1 to exist, but it was evicted")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLRU_Add_Eviction(t *testing.T) {
|
||||
lru := lru.NewLRU[string, int](0, nil, time.Second*2)
|
||||
lru.Add("key1", 1)
|
||||
lru.Add("key2", 2)
|
||||
lru.Add("key3", 3)
|
||||
time.Sleep(time.Second * 3)
|
||||
if lru.Cap() != 0 {
|
||||
t.Errorf("Expected lru to be empty, but it was not")
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func BenchmarkLRU_Rand_NoExpire(b *testing.B) {
|
||||
l := lru.NewLRU[int64, int64](8192, nil, 0)
|
||||
|
||||
trace := make([]int64, b.N*2)
|
||||
for i := 0; i < b.N*2; i++ {
|
||||
trace[i] = getRand(b) % 32768
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
var hit, miss int
|
||||
for i := 0; i < 2*b.N; i++ {
|
||||
if i%2 == 0 {
|
||||
l.Add(trace[i], trace[i])
|
||||
} else {
|
||||
if _, ok := l.Get(trace[i]); ok {
|
||||
hit++
|
||||
} else {
|
||||
miss++
|
||||
}
|
||||
}
|
||||
}
|
||||
b.Logf("hit: %d miss: %d ratio: %f", hit, miss, float64(hit)/float64(hit+miss))
|
||||
}
|
||||
|
||||
func BenchmarkLRU_Freq_NoExpire(b *testing.B) {
|
||||
l := lru.NewLRU[int64, int64](8192, nil, 0)
|
||||
|
||||
trace := make([]int64, b.N*2)
|
||||
for i := 0; i < b.N*2; i++ {
|
||||
if i%2 == 0 {
|
||||
trace[i] = getRand(b) % 16384
|
||||
} else {
|
||||
trace[i] = getRand(b) % 32768
|
||||
}
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
l.Add(trace[i], trace[i])
|
||||
}
|
||||
var hit, miss int
|
||||
for i := 0; i < b.N; i++ {
|
||||
if _, ok := l.Get(trace[i]); ok {
|
||||
hit++
|
||||
} else {
|
||||
miss++
|
||||
}
|
||||
}
|
||||
b.Logf("hit: %d miss: %d ratio: %f", hit, miss, float64(hit)/float64(hit+miss))
|
||||
}
|
||||
|
||||
func BenchmarkLRU_Rand_WithExpire(b *testing.B) {
|
||||
l := lru.NewLRU[int64, int64](8192, nil, time.Millisecond*10)
|
||||
|
||||
trace := make([]int64, b.N*2)
|
||||
for i := 0; i < b.N*2; i++ {
|
||||
trace[i] = getRand(b) % 32768
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
var hit, miss int
|
||||
for i := 0; i < 2*b.N; i++ {
|
||||
if i%2 == 0 {
|
||||
l.Add(trace[i], trace[i])
|
||||
} else {
|
||||
if _, ok := l.Get(trace[i]); ok {
|
||||
hit++
|
||||
} else {
|
||||
miss++
|
||||
}
|
||||
}
|
||||
}
|
||||
b.Logf("hit: %d miss: %d ratio: %f", hit, miss, float64(hit)/float64(hit+miss))
|
||||
}
|
||||
|
||||
func BenchmarkLRU_Freq_WithExpire(b *testing.B) {
|
||||
l := lru.NewLRU[int64, int64](8192, nil, time.Millisecond*10)
|
||||
|
||||
trace := make([]int64, b.N*2)
|
||||
for i := 0; i < b.N*2; i++ {
|
||||
if i%2 == 0 {
|
||||
trace[i] = getRand(b) % 16384
|
||||
} else {
|
||||
trace[i] = getRand(b) % 32768
|
||||
}
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
l.Add(trace[i], trace[i])
|
||||
}
|
||||
var hit, miss int
|
||||
for i := 0; i < b.N; i++ {
|
||||
if _, ok := l.Get(trace[i]); ok {
|
||||
hit++
|
||||
} else {
|
||||
miss++
|
||||
}
|
||||
}
|
||||
b.Logf("hit: %d miss: %d ratio: %f", hit, miss, float64(hit)/float64(hit+miss))
|
||||
}
|
||||
|
||||
func TestLRUNoPurge(t *testing.T) {
|
||||
lc := lru.NewLRU[string, string](10, nil, 0)
|
||||
|
||||
lc.Add("key1", "val1")
|
||||
if lc.Len() != 1 {
|
||||
t.Fatalf("length differs from expected")
|
||||
}
|
||||
|
||||
v, ok := lc.Peek("key1")
|
||||
if v != "val1" {
|
||||
t.Fatalf("value differs from expected")
|
||||
}
|
||||
if !ok {
|
||||
t.Fatalf("should be true")
|
||||
}
|
||||
|
||||
if !lc.Contains("key1") {
|
||||
t.Fatalf("should contain key1")
|
||||
}
|
||||
if lc.Contains("key2") {
|
||||
t.Fatalf("should not contain key2")
|
||||
}
|
||||
|
||||
v, ok = lc.Peek("key2")
|
||||
if v != "" {
|
||||
t.Fatalf("should be empty")
|
||||
}
|
||||
if ok {
|
||||
t.Fatalf("should be false")
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(lc.Keys(), []string{"key1"}) {
|
||||
t.Fatalf("value differs from expected")
|
||||
}
|
||||
|
||||
if lc.Resize(0) != 0 {
|
||||
t.Fatalf("evicted count differs from expected")
|
||||
}
|
||||
if lc.Resize(2) != 0 {
|
||||
t.Fatalf("evicted count differs from expected")
|
||||
}
|
||||
lc.Add("key2", "val2")
|
||||
if lc.Resize(1) != 1 {
|
||||
t.Fatalf("evicted count differs from expected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLRUEdgeCases(t *testing.T) {
|
||||
lc := lru.NewLRU[string, *string](2, nil, 0)
|
||||
|
||||
// Adding a nil value
|
||||
lc.Add("key1", nil)
|
||||
|
||||
value, exists := lc.Get("key1")
|
||||
if value != nil || !exists {
|
||||
t.Fatalf("unexpected value or existence flag for key1: value=%v, exists=%v", value, exists)
|
||||
}
|
||||
|
||||
// Adding an entry with the same key but different value
|
||||
newVal := "val1"
|
||||
lc.Add("key1", &newVal)
|
||||
|
||||
value, exists = lc.Get("key1")
|
||||
if value != &newVal || !exists {
|
||||
t.Fatalf("unexpected value or existence flag for key1: value=%v, exists=%v", value, exists)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLRU_Values(t *testing.T) {
|
||||
lc := lru.NewLRU[string, string](3, nil, 0)
|
||||
|
||||
lc.Add("key1", "val1")
|
||||
lc.Add("key2", "val2")
|
||||
lc.Add("key3", "val3")
|
||||
|
||||
values := lc.Values()
|
||||
if !reflect.DeepEqual(values, []string{"val1", "val2", "val3"}) {
|
||||
t.Fatalf("values differs from expected")
|
||||
}
|
||||
}
|
||||
|
||||
// func TestExpirableMultipleClose(_ *testing.T) {
|
||||
// lc :=lru.NewLRU[string, string](10, nil, 0)
|
||||
// lc.Close()
|
||||
// // should not panic
|
||||
// lc.Close()
|
||||
// }
|
||||
|
||||
func TestLRUWithPurge(t *testing.T) {
|
||||
var evicted []string
|
||||
lc := lru.NewLRU(10, func(key string, value string) { evicted = append(evicted, key, value) }, 150*time.Millisecond)
|
||||
|
||||
k, v, ok := lc.GetOldest()
|
||||
if k != "" {
|
||||
t.Fatalf("should be empty")
|
||||
}
|
||||
if v != "" {
|
||||
t.Fatalf("should be empty")
|
||||
}
|
||||
if ok {
|
||||
t.Fatalf("should be false")
|
||||
}
|
||||
|
||||
lc.Add("key1", "val1")
|
||||
|
||||
time.Sleep(100 * time.Millisecond) // not enough to expire
|
||||
if lc.Len() != 1 {
|
||||
t.Fatalf("length differs from expected")
|
||||
}
|
||||
|
||||
v, ok = lc.Get("key1")
|
||||
if v != "val1" {
|
||||
t.Fatalf("value differs from expected")
|
||||
}
|
||||
if !ok {
|
||||
t.Fatalf("should be true")
|
||||
}
|
||||
|
||||
time.Sleep(200 * time.Millisecond) // expire
|
||||
v, ok = lc.Get("key1")
|
||||
if ok {
|
||||
t.Fatalf("should be false")
|
||||
}
|
||||
if v != "" {
|
||||
t.Fatalf("should be nil")
|
||||
}
|
||||
|
||||
if lc.Len() != 0 {
|
||||
t.Fatalf("length differs from expected")
|
||||
}
|
||||
if !reflect.DeepEqual(evicted, []string{"key1", "val1"}) {
|
||||
t.Fatalf("value differs from expected")
|
||||
}
|
||||
|
||||
// add new entry
|
||||
lc.Add("key2", "val2")
|
||||
if lc.Len() != 1 {
|
||||
t.Fatalf("length differs from expected")
|
||||
}
|
||||
|
||||
k, v, ok = lc.GetOldest()
|
||||
if k != "key2" {
|
||||
t.Fatalf("value differs from expected")
|
||||
}
|
||||
if v != "val2" {
|
||||
t.Fatalf("value differs from expected")
|
||||
}
|
||||
if !ok {
|
||||
t.Fatalf("should be true")
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestLRUWithPurgeEnforcedBySize(t *testing.T) {
|
||||
lc := lru.NewLRU[string, string](10, nil, time.Hour)
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
i := i
|
||||
lc.Add(fmt.Sprintf("key%d", i), fmt.Sprintf("val%d", i))
|
||||
v, ok := lc.Get(fmt.Sprintf("key%d", i))
|
||||
if v != fmt.Sprintf("val%d", i) {
|
||||
t.Fatalf("value differs from expected")
|
||||
}
|
||||
if !ok {
|
||||
t.Fatalf("should be true")
|
||||
}
|
||||
if lc.Len() > 20 {
|
||||
t.Fatalf("length should be less than 20")
|
||||
}
|
||||
}
|
||||
|
||||
if lc.Len() != 10 {
|
||||
t.Fatalf("length differs from expected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLRUConcurrency(t *testing.T) {
|
||||
lc := lru.NewLRU[string, string](0, nil, 0)
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(1000)
|
||||
for i := 0; i < 1000; i++ {
|
||||
go func(i int) {
|
||||
lc.Add(fmt.Sprintf("key-%d", i/10), fmt.Sprintf("val-%d", i/10))
|
||||
wg.Done()
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
if lc.Len() != 100 {
|
||||
t.Fatalf("length differs from expected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLRUInvalidateAndEvict(t *testing.T) {
|
||||
var evicted int
|
||||
lc := lru.NewLRU(-1, func(_, _ string) { evicted++ }, 0)
|
||||
|
||||
lc.Add("key1", "val1")
|
||||
lc.Add("key2", "val2")
|
||||
|
||||
val, ok := lc.Get("key1")
|
||||
if !ok {
|
||||
t.Fatalf("should be true")
|
||||
}
|
||||
if val != "val1" {
|
||||
t.Fatalf("value differs from expected")
|
||||
}
|
||||
if evicted != 0 {
|
||||
t.Fatalf("value differs from expected")
|
||||
}
|
||||
|
||||
lc.Remove("key1")
|
||||
if evicted != 1 {
|
||||
t.Fatalf("value differs from expected")
|
||||
}
|
||||
val, ok = lc.Get("key1")
|
||||
if val != "" {
|
||||
t.Fatalf("should be empty")
|
||||
}
|
||||
if ok {
|
||||
t.Fatalf("should be false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadingExpired(t *testing.T) {
|
||||
lc := lru.NewLRU[string, string](0, nil, time.Millisecond*5)
|
||||
|
||||
lc.Add("key1", "val1")
|
||||
if lc.Len() != 1 {
|
||||
t.Fatalf("length differs from expected")
|
||||
}
|
||||
|
||||
v, ok := lc.Peek("key1")
|
||||
if v != "val1" {
|
||||
t.Fatalf("value differs from expected")
|
||||
}
|
||||
if !ok {
|
||||
t.Fatalf("should be true")
|
||||
}
|
||||
|
||||
v, ok = lc.Get("key1")
|
||||
if v != "val1" {
|
||||
t.Fatalf("value differs from expected")
|
||||
}
|
||||
if !ok {
|
||||
t.Fatalf("should be true")
|
||||
}
|
||||
|
||||
for {
|
||||
result, ok := lc.Get("key1")
|
||||
if ok && result == "" {
|
||||
t.Fatalf("ok should return a result")
|
||||
}
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
time.Sleep(time.Millisecond * 100) // wait for expiration reaper
|
||||
if lc.Len() != 0 {
|
||||
t.Fatalf("length differs from expected")
|
||||
}
|
||||
|
||||
v, ok = lc.Peek("key1")
|
||||
if v != "" {
|
||||
t.Fatalf("should be empty")
|
||||
}
|
||||
if ok {
|
||||
t.Fatalf("should be false")
|
||||
}
|
||||
|
||||
v, ok = lc.Get("key1")
|
||||
if v != "" {
|
||||
t.Fatalf("should be empty")
|
||||
}
|
||||
if ok {
|
||||
t.Fatalf("should be false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLRURemoveOldest(t *testing.T) {
|
||||
lc := lru.NewLRU[string, string](2, nil, 0)
|
||||
|
||||
if lc.Cap() != 2 {
|
||||
t.Fatalf("expect cap is 2")
|
||||
}
|
||||
|
||||
k, v, ok := lc.RemoveOldest()
|
||||
if k != "" {
|
||||
t.Fatalf("should be empty")
|
||||
}
|
||||
if v != "" {
|
||||
t.Fatalf("should be empty")
|
||||
}
|
||||
if ok {
|
||||
t.Fatalf("should be false")
|
||||
}
|
||||
|
||||
ok = lc.Remove("non_existent")
|
||||
if ok {
|
||||
t.Fatalf("should be false")
|
||||
}
|
||||
|
||||
lc.Add("key1", "val1")
|
||||
if lc.Len() != 1 {
|
||||
t.Fatalf("length differs from expected")
|
||||
}
|
||||
|
||||
v, ok = lc.Get("key1")
|
||||
if !ok {
|
||||
t.Fatalf("should be true")
|
||||
}
|
||||
if v != "val1" {
|
||||
t.Fatalf("value differs from expected")
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(lc.Keys(), []string{"key1"}) {
|
||||
t.Fatalf("value differs from expected")
|
||||
}
|
||||
if lc.Len() != 1 {
|
||||
t.Fatalf("length differs from expected")
|
||||
}
|
||||
|
||||
lc.Add("key2", "val2")
|
||||
if !reflect.DeepEqual(lc.Keys(), []string{"key1", "key2"}) {
|
||||
t.Fatalf("value differs from expected")
|
||||
}
|
||||
if lc.Len() != 2 {
|
||||
t.Fatalf("length differs from expected")
|
||||
}
|
||||
|
||||
k, v, ok = lc.RemoveOldest()
|
||||
if k != "key1" {
|
||||
t.Fatalf("value differs from expected")
|
||||
}
|
||||
if v != "val1" {
|
||||
t.Fatalf("value differs from expected")
|
||||
}
|
||||
if !ok {
|
||||
t.Fatalf("should be true")
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(lc.Keys(), []string{"key2"}) {
|
||||
t.Fatalf("value differs from expected")
|
||||
}
|
||||
if lc.Len() != 1 {
|
||||
t.Fatalf("length differs from expected")
|
||||
}
|
||||
}
|
||||
|
||||
func getRand(tb testing.TB) int64 {
|
||||
out, err := rand.Int(rand.Reader, big.NewInt(math.MaxInt64))
|
||||
if err != nil {
|
||||
tb.Fatal(err)
|
||||
}
|
||||
return out.Int64()
|
||||
}
|
@ -5,18 +5,18 @@ import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"os"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"gorm.io/driver/gaussdb"
|
||||
"gorm.io/driver/postgres"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/logger"
|
||||
"gorm.io/gorm/migrator"
|
||||
"gorm.io/gorm/schema"
|
||||
"gorm.io/gorm/utils"
|
||||
@ -83,8 +83,8 @@ func TestMigrate(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAutoMigrateInt8PG(t *testing.T) {
|
||||
if DB.Dialector.Name() != "postgres" {
|
||||
func TestAutoMigrateInt8PGAndGaussDB(t *testing.T) {
|
||||
if DB.Dialector.Name() != "postgres" && DB.Dialector.Name() != "gaussdb" {
|
||||
return
|
||||
}
|
||||
|
||||
@ -140,8 +140,137 @@ func TestAutoMigrateSelfReferential(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAutoMigrateNullable(t *testing.T) {
|
||||
type MigrateNullableColumn struct {
|
||||
ID uint
|
||||
Bonus float64 `gorm:"not null"`
|
||||
Stock float64
|
||||
}
|
||||
|
||||
DB.Migrator().DropTable(&MigrateNullableColumn{})
|
||||
|
||||
DB.AutoMigrate(&MigrateNullableColumn{})
|
||||
|
||||
type MigrateNullableColumn2 struct {
|
||||
ID uint
|
||||
Bonus float64
|
||||
Stock float64 `gorm:"not null"`
|
||||
}
|
||||
|
||||
if err := DB.Table("migrate_nullable_columns").AutoMigrate(&MigrateNullableColumn2{}); err != nil {
|
||||
t.Fatalf("failed to auto migrate, got error: %v", err)
|
||||
}
|
||||
|
||||
columnTypes, err := DB.Table("migrate_nullable_columns").Migrator().ColumnTypes(&MigrateNullableColumn{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get column types, got error: %v", err)
|
||||
}
|
||||
|
||||
for _, columnType := range columnTypes {
|
||||
switch columnType.Name() {
|
||||
case "bonus":
|
||||
// allow to change non-nullable to nullable
|
||||
if nullable, _ := columnType.Nullable(); !nullable {
|
||||
t.Fatalf("bonus's nullable should be true, bug got %t", nullable)
|
||||
}
|
||||
case "stock":
|
||||
// do not allow to change nullable to non-nullable
|
||||
if nullable, _ := columnType.Nullable(); !nullable {
|
||||
t.Fatalf("stock's nullable should be true, bug got %t", nullable)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSmartMigrateColumn(t *testing.T) {
|
||||
fullSupported := map[string]bool{"mysql": true, "postgres": true}[DB.Dialector.Name()]
|
||||
fullSupported := map[string]bool{"mysql": true, "postgres": true, "gaussdb": true}[DB.Dialector.Name()]
|
||||
|
||||
type UserMigrateColumn struct {
|
||||
ID uint
|
||||
Name string
|
||||
Salary float64
|
||||
Birthday time.Time `gorm:"precision:4"`
|
||||
}
|
||||
|
||||
DB.Migrator().DropTable(&UserMigrateColumn{})
|
||||
|
||||
DB.AutoMigrate(&UserMigrateColumn{})
|
||||
|
||||
type UserMigrateColumn2 struct {
|
||||
ID uint
|
||||
Name string `gorm:"size:128"`
|
||||
Salary float64 `gorm:"precision:2"`
|
||||
Birthday time.Time `gorm:"precision:2"`
|
||||
NameIgnoreMigration string `gorm:"size:100"`
|
||||
}
|
||||
|
||||
if err := DB.Table("user_migrate_columns").AutoMigrate(&UserMigrateColumn2{}); err != nil {
|
||||
t.Fatalf("failed to auto migrate, got error: %v", err)
|
||||
}
|
||||
|
||||
columnTypes, err := DB.Table("user_migrate_columns").Migrator().ColumnTypes(&UserMigrateColumn{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get column types, got error: %v", err)
|
||||
}
|
||||
|
||||
for _, columnType := range columnTypes {
|
||||
switch columnType.Name() {
|
||||
case "name":
|
||||
if length, _ := columnType.Length(); (fullSupported || length != 0) && length != 128 {
|
||||
t.Fatalf("name's length should be 128, but got %v", length)
|
||||
}
|
||||
case "salary":
|
||||
if precision, o, _ := columnType.DecimalSize(); (fullSupported || precision != 0) && precision != 2 {
|
||||
t.Fatalf("salary's precision should be 2, but got %v %v", precision, o)
|
||||
}
|
||||
case "birthday":
|
||||
if precision, _, _ := columnType.DecimalSize(); (fullSupported || precision != 0) && precision != 2 {
|
||||
t.Fatalf("birthday's precision should be 2, but got %v", precision)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type UserMigrateColumn3 struct {
|
||||
ID uint
|
||||
Name string `gorm:"size:256"`
|
||||
Salary float64 `gorm:"precision:3"`
|
||||
Birthday time.Time `gorm:"precision:3"`
|
||||
NameIgnoreMigration string `gorm:"size:128;-:migration"`
|
||||
}
|
||||
|
||||
if err := DB.Table("user_migrate_columns").AutoMigrate(&UserMigrateColumn3{}); err != nil {
|
||||
t.Fatalf("failed to auto migrate, got error: %v", err)
|
||||
}
|
||||
|
||||
columnTypes, err = DB.Table("user_migrate_columns").Migrator().ColumnTypes(&UserMigrateColumn{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get column types, got error: %v", err)
|
||||
}
|
||||
|
||||
for _, columnType := range columnTypes {
|
||||
switch columnType.Name() {
|
||||
case "name":
|
||||
if length, _ := columnType.Length(); (fullSupported || length != 0) && length != 256 {
|
||||
t.Fatalf("name's length should be 128, but got %v", length)
|
||||
}
|
||||
case "salary":
|
||||
if precision, _, _ := columnType.DecimalSize(); (fullSupported || precision != 0) && precision != 3 {
|
||||
t.Fatalf("salary's precision should be 2, but got %v", precision)
|
||||
}
|
||||
case "birthday":
|
||||
if precision, _, _ := columnType.DecimalSize(); (fullSupported || precision != 0) && precision != 3 {
|
||||
t.Fatalf("birthday's precision should be 2, but got %v", precision)
|
||||
}
|
||||
case "name_ignore_migration":
|
||||
if length, _ := columnType.Length(); (fullSupported || length != 0) && length != 100 {
|
||||
t.Fatalf("name_ignore_migration's length should still be 100 but got %v", length)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSmartMigrateColumnGaussDB(t *testing.T) {
|
||||
fullSupported := map[string]bool{"mysql": true, "gaussdb": true}[DB.Dialector.Name()]
|
||||
|
||||
type UserMigrateColumn struct {
|
||||
ID uint
|
||||
@ -809,7 +938,7 @@ func TestMigrateColumnOrder(t *testing.T) {
|
||||
|
||||
// https://github.com/go-gorm/gorm/issues/5047
|
||||
func TestMigrateSerialColumn(t *testing.T) {
|
||||
if DB.Dialector.Name() != "postgres" {
|
||||
if DB.Dialector.Name() != "postgres" && DB.Dialector.Name() != "gaussdb" {
|
||||
return
|
||||
}
|
||||
|
||||
@ -968,6 +1097,42 @@ func TestPrimarykeyID(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrimarykeyIDGaussDB(t *testing.T) {
|
||||
t.Skipf("This test case skipped, because of gaussdb not support uuid-ossp plugin (SQLSTATE 58P01)")
|
||||
if DB.Dialector.Name() != "gaussdb" {
|
||||
return
|
||||
}
|
||||
|
||||
type MissPKLanguage struct {
|
||||
ID string `gorm:"type:uuid;default:uuid_generate_v4()"`
|
||||
Name string
|
||||
}
|
||||
|
||||
type MissPKUser struct {
|
||||
ID string `gorm:"type:uuid;default:uuid_generate_v4()"`
|
||||
MissPKLanguages []MissPKLanguage `gorm:"many2many:miss_pk_user_languages;"`
|
||||
}
|
||||
|
||||
var err error
|
||||
err = DB.Migrator().DropTable(&MissPKUser{}, &MissPKLanguage{})
|
||||
if err != nil {
|
||||
t.Fatalf("DropTable err:%v", err)
|
||||
}
|
||||
// TODO: ERROR: could not open extension control file: No such file or directory (SQLSTATE 58P01)
|
||||
DB.Exec(`CREATE EXTENSION IF NOT EXISTS "uuid-ossp";`)
|
||||
|
||||
err = DB.AutoMigrate(&MissPKUser{}, &MissPKLanguage{})
|
||||
if err != nil {
|
||||
t.Fatalf("AutoMigrate err:%v", err)
|
||||
}
|
||||
|
||||
// patch
|
||||
err = DB.AutoMigrate(&MissPKUser{}, &MissPKLanguage{})
|
||||
if err != nil {
|
||||
t.Fatalf("AutoMigrate err:%v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCurrentTimestamp(t *testing.T) {
|
||||
if DB.Dialector.Name() != "mysql" {
|
||||
return
|
||||
@ -1168,35 +1333,24 @@ func TestInvalidCachedPlanSimpleProtocol(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestInvalidCachedPlanPrepareStmt(t *testing.T) {
|
||||
if DB.Dialector.Name() != "postgres" {
|
||||
// TODO: ERROR: must have at least one column (SQLSTATE 0A000)
|
||||
func TestInvalidCachedPlanSimpleProtocolGaussDB(t *testing.T) {
|
||||
t.Skipf("This test case skipped, because of gaussdb not support creaing empty table(SQLSTATE 0A000)")
|
||||
if DB.Dialector.Name() != "gaussdb" {
|
||||
return
|
||||
}
|
||||
|
||||
db, err := gorm.Open(postgres.Open(postgresDSN), &gorm.Config{PrepareStmt: true})
|
||||
db, err := gorm.Open(gaussdb.Open(gaussdbDSN), &gorm.Config{})
|
||||
if err != nil {
|
||||
t.Errorf("Open err:%v", err)
|
||||
}
|
||||
if debug := os.Getenv("DEBUG"); debug == "true" {
|
||||
db.Logger = db.Logger.LogMode(logger.Info)
|
||||
} else if debug == "false" {
|
||||
db.Logger = db.Logger.LogMode(logger.Silent)
|
||||
}
|
||||
|
||||
type Object1 struct {
|
||||
ID uint
|
||||
}
|
||||
type Object1 struct{}
|
||||
type Object2 struct {
|
||||
ID uint
|
||||
Field1 int `gorm:"type:int8"`
|
||||
Field1 string
|
||||
}
|
||||
type Object3 struct {
|
||||
ID uint
|
||||
Field1 int `gorm:"type:int4"`
|
||||
}
|
||||
type Object4 struct {
|
||||
ID uint
|
||||
Field2 int
|
||||
Field2 string
|
||||
}
|
||||
db.Migrator().DropTable("objects")
|
||||
|
||||
@ -1204,63 +1358,16 @@ func TestInvalidCachedPlanPrepareStmt(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Errorf("AutoMigrate err:%v", err)
|
||||
}
|
||||
err = db.Table("objects").Create(&Object1{}).Error
|
||||
if err != nil {
|
||||
t.Errorf("create err:%v", err)
|
||||
}
|
||||
|
||||
// AddColumn
|
||||
err = db.Table("objects").AutoMigrate(&Object2{})
|
||||
if err != nil {
|
||||
t.Errorf("AutoMigrate err:%v", err)
|
||||
}
|
||||
|
||||
err = db.Table("objects").Take(&Object2{}).Error
|
||||
if err != nil {
|
||||
t.Errorf("take err:%v", err)
|
||||
}
|
||||
|
||||
// AlterColumn
|
||||
err = db.Table("objects").AutoMigrate(&Object3{})
|
||||
if err != nil {
|
||||
t.Errorf("AutoMigrate err:%v", err)
|
||||
}
|
||||
|
||||
err = db.Table("objects").Take(&Object3{}).Error
|
||||
if err != nil {
|
||||
t.Errorf("take err:%v", err)
|
||||
}
|
||||
|
||||
// AddColumn
|
||||
err = db.Table("objects").AutoMigrate(&Object4{})
|
||||
if err != nil {
|
||||
t.Errorf("AutoMigrate err:%v", err)
|
||||
}
|
||||
|
||||
err = db.Table("objects").Take(&Object4{}).Error
|
||||
if err != nil {
|
||||
t.Errorf("take err:%v", err)
|
||||
}
|
||||
|
||||
db.Table("objects").Migrator().RenameColumn(&Object4{}, "field2", "field3")
|
||||
if err != nil {
|
||||
t.Errorf("RenameColumn err:%v", err)
|
||||
}
|
||||
|
||||
err = db.Table("objects").Take(&Object4{}).Error
|
||||
if err != nil {
|
||||
t.Errorf("take err:%v", err)
|
||||
}
|
||||
|
||||
db.Table("objects").Migrator().DropColumn(&Object4{}, "field3")
|
||||
if err != nil {
|
||||
t.Errorf("RenameColumn err:%v", err)
|
||||
}
|
||||
|
||||
err = db.Table("objects").Take(&Object4{}).Error
|
||||
if err != nil {
|
||||
t.Errorf("take err:%v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDifferentTypeWithoutDeclaredLength(t *testing.T) {
|
||||
@ -1303,7 +1410,7 @@ func TestDifferentTypeWithoutDeclaredLength(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestMigrateArrayTypeModel(t *testing.T) {
|
||||
if DB.Dialector.Name() != "postgres" {
|
||||
if DB.Dialector.Name() != "postgres" && DB.Dialector.Name() != "gaussdb" {
|
||||
return
|
||||
}
|
||||
|
||||
@ -1420,7 +1527,7 @@ func TestMigrateSameEmbeddedFieldName(t *testing.T) {
|
||||
AssertEqual(t, nil, err)
|
||||
}
|
||||
|
||||
func TestMigrateDefaultNullString(t *testing.T) {
|
||||
func TestMigrateWithDefaultValue(t *testing.T) {
|
||||
if DB.Dialector.Name() == "sqlserver" {
|
||||
// sqlserver driver treats NULL and 'NULL' the same
|
||||
t.Skip("skip sqlserver")
|
||||
@ -1434,6 +1541,7 @@ func TestMigrateDefaultNullString(t *testing.T) {
|
||||
type NullStringModel struct {
|
||||
ID uint
|
||||
Content string `gorm:"default:'null'"`
|
||||
Active bool `gorm:"default:false"`
|
||||
}
|
||||
|
||||
tableName := "null_string_model"
|
||||
@ -1454,6 +1562,14 @@ func TestMigrateDefaultNullString(t *testing.T) {
|
||||
AssertEqual(t, defVal, "null")
|
||||
AssertEqual(t, ok, true)
|
||||
|
||||
columnType2, err := findColumnType(tableName, "active")
|
||||
AssertEqual(t, err, nil)
|
||||
|
||||
defVal, ok = columnType2.DefaultValue()
|
||||
bv, _ := strconv.ParseBool(defVal)
|
||||
AssertEqual(t, bv, false)
|
||||
AssertEqual(t, ok, true)
|
||||
|
||||
// default 'null' -> 'null'
|
||||
session := DB.Session(&gorm.Session{Logger: Tracer{
|
||||
Logger: DB.Config.Logger,
|
||||
@ -1616,8 +1732,8 @@ func TestMigrateView(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestMigrateExistingBoolColumnPG(t *testing.T) {
|
||||
if DB.Dialector.Name() != "postgres" {
|
||||
func TestMigrateExistingBoolColumnPGAndGaussDB(t *testing.T) {
|
||||
if DB.Dialector.Name() != "postgres" && DB.Dialector.Name() != "gaussdb" {
|
||||
return
|
||||
}
|
||||
|
||||
@ -1935,3 +2051,114 @@ func TestMigrateWithUniqueIndexAndUnique(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func testAutoMigrateDecimal(t *testing.T, model1, model2 any) []string {
|
||||
tracer := Tracer{
|
||||
Logger: DB.Config.Logger,
|
||||
Test: func(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) {
|
||||
sql, _ := fc()
|
||||
if strings.HasPrefix(sql, "ALTER TABLE ") {
|
||||
t.Fatalf("shouldn't execute ALTER COLUMN TYPE if decimal is not change: sql: %s", sql)
|
||||
}
|
||||
},
|
||||
}
|
||||
session := DB.Session(&gorm.Session{Logger: tracer})
|
||||
|
||||
DB.Migrator().DropTable(model1)
|
||||
var modifySql []string
|
||||
if err := session.AutoMigrate(model1); err != nil {
|
||||
t.Fatalf("failed to auto migrate, got error: %v", err)
|
||||
}
|
||||
if err := session.AutoMigrate(model1); err != nil {
|
||||
t.Fatalf("failed to auto migrate, got error: %v", err)
|
||||
}
|
||||
tracer2 := Tracer{
|
||||
Logger: DB.Config.Logger,
|
||||
Test: func(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) {
|
||||
sql, _ := fc()
|
||||
modifySql = append(modifySql, sql)
|
||||
},
|
||||
}
|
||||
session2 := DB.Session(&gorm.Session{Logger: tracer2})
|
||||
err := session2.Table("migrate_decimal_columns").Migrator().AutoMigrate(model2)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get column types, got error: %v", err)
|
||||
}
|
||||
return modifySql
|
||||
}
|
||||
|
||||
func decimalColumnsTest[T, T2 any](t *testing.T, expectedSql []string) {
|
||||
var t1 T
|
||||
var t2 T2
|
||||
modSql := testAutoMigrateDecimal(t, t1, t2)
|
||||
var alterSQL []string
|
||||
for _, sql := range modSql {
|
||||
if strings.HasPrefix(sql, "ALTER TABLE ") {
|
||||
alterSQL = append(alterSQL, sql)
|
||||
}
|
||||
}
|
||||
|
||||
if len(alterSQL) != 3 {
|
||||
t.Fatalf("decimal changed error,expected: %+v,got: %+v.", expectedSql, alterSQL)
|
||||
}
|
||||
for i := range alterSQL {
|
||||
if alterSQL[i] != expectedSql[i] {
|
||||
t.Fatalf("decimal changed error,expected: %+v,got: %+v.", expectedSql, alterSQL)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestAutoMigrateDecimal(t *testing.T) {
|
||||
if DB.Dialector.Name() == "sqlserver" { // database/sql will replace numeric to decimal. so only support decimal.
|
||||
type MigrateDecimalColumn struct {
|
||||
RecID1 int64 `gorm:"column:recid1;type:decimal(9,0);not null" json:"recid1"`
|
||||
RecID2 int64 `gorm:"column:recid2;type:decimal(8);not null" json:"recid2"`
|
||||
RecID3 int64 `gorm:"column:recid3;type:decimal(8,1);not null" json:"recid3"`
|
||||
}
|
||||
type MigrateDecimalColumn2 struct {
|
||||
RecID1 int64 `gorm:"column:recid1;type:decimal(8);not null" json:"recid1"`
|
||||
RecID2 int64 `gorm:"column:recid2;type:decimal(9,1);not null" json:"recid2"`
|
||||
RecID3 int64 `gorm:"column:recid3;type:decimal(9,2);not null" json:"recid3"`
|
||||
}
|
||||
expectedSql := []string{
|
||||
`ALTER TABLE "migrate_decimal_columns" ALTER COLUMN "recid1" decimal(8) NOT NULL`,
|
||||
`ALTER TABLE "migrate_decimal_columns" ALTER COLUMN "recid2" decimal(9,1) NOT NULL`,
|
||||
`ALTER TABLE "migrate_decimal_columns" ALTER COLUMN "recid3" decimal(9,2) NOT NULL`,
|
||||
}
|
||||
decimalColumnsTest[MigrateDecimalColumn, MigrateDecimalColumn2](t, expectedSql)
|
||||
} else if DB.Dialector.Name() == "postgres" || DB.Dialector.Name() == "gaussdb" {
|
||||
type MigrateDecimalColumn struct {
|
||||
RecID1 int64 `gorm:"column:recid1;type:numeric(9,0);not null" json:"recid1"`
|
||||
RecID2 int64 `gorm:"column:recid2;type:numeric(8);not null" json:"recid2"`
|
||||
RecID3 int64 `gorm:"column:recid3;type:numeric(8,1);not null" json:"recid3"`
|
||||
}
|
||||
type MigrateDecimalColumn2 struct {
|
||||
RecID1 int64 `gorm:"column:recid1;type:numeric(8);not null" json:"recid1"`
|
||||
RecID2 int64 `gorm:"column:recid2;type:numeric(9,1);not null" json:"recid2"`
|
||||
RecID3 int64 `gorm:"column:recid3;type:numeric(9,2);not null" json:"recid3"`
|
||||
}
|
||||
expectedSql := []string{
|
||||
`ALTER TABLE "migrate_decimal_columns" ALTER COLUMN "recid1" TYPE numeric(8) USING "recid1"::numeric(8)`,
|
||||
`ALTER TABLE "migrate_decimal_columns" ALTER COLUMN "recid2" TYPE numeric(9,1) USING "recid2"::numeric(9,1)`,
|
||||
`ALTER TABLE "migrate_decimal_columns" ALTER COLUMN "recid3" TYPE numeric(9,2) USING "recid3"::numeric(9,2)`,
|
||||
}
|
||||
decimalColumnsTest[MigrateDecimalColumn, MigrateDecimalColumn2](t, expectedSql)
|
||||
} else if DB.Dialector.Name() == "mysql" {
|
||||
type MigrateDecimalColumn struct {
|
||||
RecID1 int64 `gorm:"column:recid1;type:decimal(9,0);not null" json:"recid1"`
|
||||
RecID2 int64 `gorm:"column:recid2;type:decimal(8);not null" json:"recid2"`
|
||||
RecID3 int64 `gorm:"column:recid3;type:decimal(8,1);not null" json:"recid3"`
|
||||
}
|
||||
type MigrateDecimalColumn2 struct {
|
||||
RecID1 int64 `gorm:"column:recid1;type:decimal(8);not null" json:"recid1"`
|
||||
RecID2 int64 `gorm:"column:recid2;type:decimal(9,1);not null" json:"recid2"`
|
||||
RecID3 int64 `gorm:"column:recid3;type:decimal(9,2);not null" json:"recid3"`
|
||||
}
|
||||
expectedSql := []string{
|
||||
"ALTER TABLE `migrate_decimal_columns` MODIFY COLUMN `recid1` decimal(8) NOT NULL",
|
||||
"ALTER TABLE `migrate_decimal_columns` MODIFY COLUMN `recid2` decimal(9,1) NOT NULL",
|
||||
"ALTER TABLE `migrate_decimal_columns` MODIFY COLUMN `recid3` decimal(9,2) NOT NULL",
|
||||
}
|
||||
decimalColumnsTest[MigrateDecimalColumn, MigrateDecimalColumn2](t, expectedSql)
|
||||
}
|
||||
}
|
||||
|
@ -41,7 +41,7 @@ func TestManyToManyWithMultiPrimaryKeys(t *testing.T) {
|
||||
t.Skip("skip sqlite, sqlserver due to it doesn't support multiple primary keys with auto increment")
|
||||
}
|
||||
|
||||
if name := DB.Dialector.Name(); name == "postgres" {
|
||||
if name := DB.Dialector.Name(); name == "postgres" || name == "mysql" || name == "gaussdb" {
|
||||
stmt := gorm.Statement{DB: DB}
|
||||
stmt.Parse(&Blog{})
|
||||
stmt.Schema.LookUpField("ID").Unique = true
|
||||
@ -142,6 +142,9 @@ func TestManyToManyWithCustomizedForeignKeys(t *testing.T) {
|
||||
if name := DB.Dialector.Name(); name == "postgres" {
|
||||
t.Skip("skip postgres due to it only allow unique constraint matching given keys")
|
||||
}
|
||||
if name := DB.Dialector.Name(); name == "gaussdb" {
|
||||
t.Skip("skip gaussdb due to it only allow unique constraint matching given keys")
|
||||
}
|
||||
|
||||
DB.Migrator().DropTable(&Blog{}, &Tag{}, "blog_tags", "locale_blog_tags", "shared_blog_tags")
|
||||
if err := DB.AutoMigrate(&Blog{}, &Tag{}); err != nil {
|
||||
@ -264,10 +267,14 @@ func TestManyToManyWithCustomizedForeignKeys2(t *testing.T) {
|
||||
t.Skip("skip sqlite, sqlserver due to it doesn't support multiple primary keys with auto increment")
|
||||
}
|
||||
|
||||
if name := DB.Dialector.Name(); name == "postgres" {
|
||||
if name := DB.Dialector.Name(); name == "postgres" || name == "mysql" {
|
||||
t.Skip("skip postgres due to it only allow unique constraint matching given keys")
|
||||
}
|
||||
|
||||
if name := DB.Dialector.Name(); name == "gaussdb" {
|
||||
t.Skip("skip gaussdb due to it only allow unique constraint matching given keys")
|
||||
}
|
||||
|
||||
DB.Migrator().DropTable(&Blog{}, &Tag{}, "blog_tags", "locale_blog_tags", "shared_blog_tags")
|
||||
if err := DB.AutoMigrate(&Blog{}, &Tag{}); err != nil {
|
||||
t.Fatalf("Failed to auto migrate, got error: %v", err)
|
||||
@ -332,7 +339,7 @@ func TestManyToManyWithCustomizedForeignKeys2(t *testing.T) {
|
||||
|
||||
DB.Model(&blog2).Association("LocaleTags").Find(&tags)
|
||||
if !compareTags(tags, []string{"tag4"}) {
|
||||
t.Fatalf("Should find 1 tags for EN Blog")
|
||||
t.Fatalf("Should find 1 tags for EN Blog, but got %v", tags)
|
||||
}
|
||||
|
||||
// Replace
|
||||
|
@ -37,7 +37,7 @@ func TestNonStdPrimaryKeyAndDefaultValues(t *testing.T) {
|
||||
}
|
||||
|
||||
animal = Animal{From: "somewhere"} // No name fields, should be filled with the default value (galeone)
|
||||
DB.Save(&animal).Update("From", "a nice place") // The name field shoul be untouched
|
||||
DB.Save(&animal).Update("From", "a nice place") // The name field should be untouched
|
||||
DB.First(&animal, animal.Counter)
|
||||
if animal.Name != "galeone" {
|
||||
t.Errorf("Name fields shouldn't be changed if untouched, but got %v", animal.Name)
|
||||
|
@ -696,6 +696,10 @@ func TestManyToManyPreloadWithMultiPrimaryKeys(t *testing.T) {
|
||||
t.Skip("skip sqlite, sqlserver due to it doesn't support multiple primary keys with auto increment")
|
||||
}
|
||||
|
||||
if name := DB.Dialector.Name(); name == "mysql" {
|
||||
t.Skip("skip mysql due to it only allow unique constraint matching given keys")
|
||||
}
|
||||
|
||||
type (
|
||||
Level1 struct {
|
||||
ID uint `gorm:"primary_key;"`
|
||||
|
@ -1,14 +1,14 @@
|
||||
package tests_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strconv"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
@ -337,7 +337,7 @@ func TestNestedPreloadWithNestedJoin(t *testing.T) {
|
||||
DB.Migrator().DropTable(&Preload{}, &Join{}, &Nested{}, &Value{})
|
||||
DB.Migrator().AutoMigrate(&Preload{}, &Join{}, &Nested{}, &Value{})
|
||||
|
||||
value := Value{
|
||||
value1 := Value{
|
||||
Name: "value",
|
||||
Nested: Nested{
|
||||
Preloads: []*Preload{
|
||||
@ -346,32 +346,150 @@ func TestNestedPreloadWithNestedJoin(t *testing.T) {
|
||||
Join: Join{Value: "j1"},
|
||||
},
|
||||
}
|
||||
if err := DB.Create(&value).Error; err != nil {
|
||||
value2 := Value{
|
||||
Name: "value2",
|
||||
Nested: Nested{
|
||||
Preloads: []*Preload{
|
||||
{Value: "p3"}, {Value: "p4"}, {Value: "p5"},
|
||||
},
|
||||
Join: Join{Value: "j2"},
|
||||
},
|
||||
}
|
||||
|
||||
values := []*Value{&value1, &value2}
|
||||
if err := DB.Create(&values).Error; err != nil {
|
||||
t.Errorf("failed to create value, got err: %v", err)
|
||||
}
|
||||
|
||||
var find1 Value
|
||||
err := DB.Joins("Nested").Joins("Nested.Join").Preload("Nested.Preloads").First(&find1).Error
|
||||
err := DB.Joins("Nested").Joins("Nested.Join").Preload("Nested.Preloads").First(&find1, value1.ID).Error
|
||||
if err != nil {
|
||||
t.Errorf("failed to find value, got err: %v", err)
|
||||
}
|
||||
AssertEqual(t, find1, value)
|
||||
AssertEqual(t, find1, value1)
|
||||
|
||||
var find2 Value
|
||||
// Joins will automatically add Nested queries.
|
||||
err = DB.Joins("Nested.Join").Preload("Nested.Preloads").First(&find2).Error
|
||||
err = DB.Joins("Nested.Join").Preload("Nested.Preloads").First(&find2, value2.ID).Error
|
||||
if err != nil {
|
||||
t.Errorf("failed to find value, got err: %v", err)
|
||||
}
|
||||
AssertEqual(t, find2, value)
|
||||
AssertEqual(t, find2, value2)
|
||||
|
||||
var finds []Value
|
||||
err = DB.Joins("Nested.Join").Joins("Nested").Preload("Nested.Preloads").Find(&finds).Error
|
||||
if err != nil {
|
||||
t.Errorf("failed to find value, got err: %v", err)
|
||||
}
|
||||
require.Len(t, finds, 1)
|
||||
AssertEqual(t, finds[0], value)
|
||||
AssertEqual(t, len(finds), 2)
|
||||
AssertEqual(t, finds[0], value1)
|
||||
AssertEqual(t, finds[1], value2)
|
||||
}
|
||||
|
||||
func TestMergeNestedPreloadWithNestedJoin(t *testing.T) {
|
||||
users := []User{
|
||||
{
|
||||
Name: "TestMergeNestedPreloadWithNestedJoin-1",
|
||||
Manager: &User{
|
||||
Name: "Alexis Manager",
|
||||
Tools: []Tools{
|
||||
{Name: "Alexis Tool 1"},
|
||||
{Name: "Alexis Tool 2"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "TestMergeNestedPreloadWithNestedJoin-2",
|
||||
Manager: &User{
|
||||
Name: "Jinzhu Manager",
|
||||
Tools: []Tools{
|
||||
{Name: "Jinzhu Tool 1"},
|
||||
{Name: "Jinzhu Tool 2"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
DB.Create(&users)
|
||||
|
||||
query := make([]string, 0)
|
||||
sess := DB.Session(&gorm.Session{Logger: Tracer{
|
||||
Logger: DB.Config.Logger,
|
||||
Test: func(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) {
|
||||
sql, _ := fc()
|
||||
query = append(query, sql)
|
||||
},
|
||||
}})
|
||||
|
||||
var result []User
|
||||
err := sess.
|
||||
Joins("Manager").
|
||||
Preload("Manager.Tools").
|
||||
Where("users.name Like ?", "TestMergeNestedPreloadWithNestedJoin%").
|
||||
Find(&result).Error
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("failed to preload and find users: %v", err)
|
||||
}
|
||||
|
||||
AssertEqual(t, result, users)
|
||||
AssertEqual(t, len(query), 2) // Check preload queries are merged
|
||||
|
||||
if !regexp.MustCompile(`SELECT \* FROM .*tools.* WHERE .*IN.*`).MatchString(query[0]) {
|
||||
t.Fatalf("Expected first query to preload manager tools, got: %s", query[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestNestedPreloadWithPointerJoin(t *testing.T) {
|
||||
type (
|
||||
Preload struct {
|
||||
ID uint
|
||||
Value string
|
||||
JoinID uint
|
||||
}
|
||||
Join struct {
|
||||
ID uint
|
||||
Value string
|
||||
Preload Preload
|
||||
NestedID uint
|
||||
}
|
||||
Nested struct {
|
||||
ID uint
|
||||
Join Join
|
||||
ValueID uint
|
||||
}
|
||||
Value struct {
|
||||
ID uint
|
||||
Name string
|
||||
Nested *Nested
|
||||
}
|
||||
)
|
||||
|
||||
DB.Migrator().DropTable(&Preload{}, &Join{}, &Nested{}, &Value{})
|
||||
DB.Migrator().AutoMigrate(&Preload{}, &Join{}, &Nested{}, &Value{})
|
||||
|
||||
value := Value{
|
||||
Name: "value",
|
||||
Nested: &Nested{
|
||||
Join: Join{
|
||||
Value: "j1",
|
||||
Preload: Preload{
|
||||
Value: "p1",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if err := DB.Create(&value).Error; err != nil {
|
||||
t.Errorf("failed to create value, got err: %v", err)
|
||||
}
|
||||
|
||||
var find1 Value
|
||||
err := DB.Table("values").Joins("Nested").Joins("Nested.Join").Preload("Nested.Join.Preload").First(&find1).Error
|
||||
if err != nil {
|
||||
t.Errorf("failed to find value, got err: %v", err)
|
||||
}
|
||||
AssertEqual(t, find1, value)
|
||||
}
|
||||
|
||||
func TestEmbedPreload(t *testing.T) {
|
||||
@ -466,7 +584,7 @@ func TestEmbedPreload(t *testing.T) {
|
||||
},
|
||||
}, {
|
||||
name: "nested address country",
|
||||
preloads: map[string][]interface{}{"NestedAddress.EmbeddedAddress.Country": {}},
|
||||
preloads: map[string][]interface{}{"NestedAddress.Country": {}},
|
||||
expect: Org{
|
||||
ID: org.ID,
|
||||
PostalAddress: EmbeddedAddress{
|
||||
|
@ -91,6 +91,65 @@ func TestPreparedStmtFromTransaction(t *testing.T) {
|
||||
tx2.Commit()
|
||||
}
|
||||
|
||||
func TestPreparedStmtLruFromTransaction(t *testing.T) {
|
||||
db, _ := OpenTestConnection(&gorm.Config{PrepareStmt: true, PrepareStmtMaxSize: 10, PrepareStmtTTL: 20 * time.Second})
|
||||
|
||||
tx := db.Begin()
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
tx.Rollback()
|
||||
}
|
||||
}()
|
||||
if err := tx.Error; err != nil {
|
||||
t.Errorf("Failed to start transaction, got error %v\n", err)
|
||||
}
|
||||
|
||||
if err := tx.Where("name=?", "zzjin").Delete(&User{}).Error; err != nil {
|
||||
tx.Rollback()
|
||||
t.Errorf("Failed to run one transaction, got error %v\n", err)
|
||||
}
|
||||
|
||||
if err := tx.Create(&User{Name: "zzjin"}).Error; err != nil {
|
||||
tx.Rollback()
|
||||
t.Errorf("Failed to run one transaction, got error %v\n", err)
|
||||
}
|
||||
|
||||
if err := tx.Commit().Error; err != nil {
|
||||
t.Errorf("Failed to commit transaction, got error %v\n", err)
|
||||
}
|
||||
|
||||
if result := db.Where("name=?", "zzjin").Delete(&User{}); result.Error != nil || result.RowsAffected != 1 {
|
||||
t.Fatalf("Failed, got error: %v, rows affected: %v", result.Error, result.RowsAffected)
|
||||
}
|
||||
|
||||
tx2 := db.Begin()
|
||||
if result := tx2.Where("name=?", "zzjin").Delete(&User{}); result.Error != nil || result.RowsAffected != 0 {
|
||||
t.Fatalf("Failed, got error: %v, rows affected: %v", result.Error, result.RowsAffected)
|
||||
}
|
||||
|
||||
tx2.Commit()
|
||||
// Attempt to convert the connection pool of tx to the *gorm.PreparedStmtDB type.
|
||||
// If the conversion is successful, ok will be true and conn will be the converted object;
|
||||
// otherwise, ok will be false and conn will be nil.
|
||||
conn, ok := tx.ConnPool.(*gorm.PreparedStmtDB)
|
||||
// Get the number of statement keys stored in the PreparedStmtDB.
|
||||
lens := len(conn.Stmts.Keys())
|
||||
// Check if the number of stored statement keys is 0.
|
||||
if lens == 0 {
|
||||
// If the number is 0, it means there are no statements stored in the LRU cache.
|
||||
// The test fails and an error message is output.
|
||||
t.Fatalf("lru should not be empty")
|
||||
}
|
||||
// Wait for 40 seconds to give the statements in the cache enough time to expire.
|
||||
time.Sleep(time.Second * 40)
|
||||
// Assert whether the connection pool of tx is successfully converted to the *gorm.PreparedStmtDB type.
|
||||
AssertEqual(t, ok, true)
|
||||
// Assert whether the number of statement keys stored in the PreparedStmtDB is 0 after 40 seconds.
|
||||
// If it is not 0, it means the statements in the cache have not expired as expected.
|
||||
AssertEqual(t, len(conn.Stmts.Keys()), 0)
|
||||
|
||||
}
|
||||
|
||||
func TestPreparedStmtDeadlock(t *testing.T) {
|
||||
tx, err := OpenTestConnection(&gorm.Config{})
|
||||
AssertEqual(t, err, nil)
|
||||
@ -116,9 +175,9 @@ func TestPreparedStmtDeadlock(t *testing.T) {
|
||||
|
||||
conn, ok := tx.ConnPool.(*gorm.PreparedStmtDB)
|
||||
AssertEqual(t, ok, true)
|
||||
AssertEqual(t, len(conn.Stmts), 2)
|
||||
for _, stmt := range conn.Stmts {
|
||||
if stmt == nil {
|
||||
AssertEqual(t, len(conn.Stmts.Keys()), 2)
|
||||
for _, stmt := range conn.Stmts.Keys() {
|
||||
if stmt == "" {
|
||||
t.Fatalf("stmt cannot bee nil")
|
||||
}
|
||||
}
|
||||
@ -126,33 +185,6 @@ func TestPreparedStmtDeadlock(t *testing.T) {
|
||||
AssertEqual(t, sqlDB.Stats().InUse, 0)
|
||||
}
|
||||
|
||||
func TestPreparedStmtError(t *testing.T) {
|
||||
tx, err := OpenTestConnection(&gorm.Config{})
|
||||
AssertEqual(t, err, nil)
|
||||
|
||||
sqlDB, _ := tx.DB()
|
||||
sqlDB.SetMaxOpenConns(1)
|
||||
|
||||
tx = tx.Session(&gorm.Session{PrepareStmt: true})
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
for i := 0; i < 10; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
// err prepare
|
||||
tag := Tag{Locale: "zh"}
|
||||
tx.Table("users").Find(&tag)
|
||||
wg.Done()
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
conn, ok := tx.ConnPool.(*gorm.PreparedStmtDB)
|
||||
AssertEqual(t, ok, true)
|
||||
AssertEqual(t, len(conn.Stmts), 0)
|
||||
AssertEqual(t, sqlDB.Stats().InUse, 0)
|
||||
}
|
||||
|
||||
func TestPreparedStmtInTransaction(t *testing.T) {
|
||||
user := User{Name: "jinzhu"}
|
||||
|
||||
@ -169,10 +201,10 @@ func TestPreparedStmtInTransaction(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestPreparedStmtReset(t *testing.T) {
|
||||
func TestPreparedStmtClose(t *testing.T) {
|
||||
tx := DB.Session(&gorm.Session{PrepareStmt: true})
|
||||
|
||||
user := *GetUser("prepared_stmt_reset", Config{})
|
||||
user := *GetUser("prepared_stmt_close", Config{})
|
||||
tx = tx.Create(&user)
|
||||
|
||||
pdb, ok := tx.ConnPool.(*gorm.PreparedStmtDB)
|
||||
@ -181,16 +213,77 @@ func TestPreparedStmtReset(t *testing.T) {
|
||||
}
|
||||
|
||||
pdb.Mux.Lock()
|
||||
if len(pdb.Stmts) == 0 {
|
||||
if len(pdb.Stmts.Keys()) == 0 {
|
||||
pdb.Mux.Unlock()
|
||||
t.Fatalf("prepared stmt can not be empty")
|
||||
}
|
||||
pdb.Mux.Unlock()
|
||||
|
||||
pdb.Reset()
|
||||
pdb.Close()
|
||||
pdb.Mux.Lock()
|
||||
defer pdb.Mux.Unlock()
|
||||
if len(pdb.Stmts) != 0 {
|
||||
if len(pdb.Stmts.Keys()) != 0 {
|
||||
t.Fatalf("prepared stmt should be empty")
|
||||
}
|
||||
}
|
||||
|
||||
func isUsingClosedConnError(err error) bool {
|
||||
// https://github.com/golang/go/blob/e705a2d16e4ece77e08e80c168382cdb02890f5b/src/database/sql/sql.go#L2717
|
||||
return err.Error() == "sql: statement is closed"
|
||||
}
|
||||
|
||||
// TestPreparedStmtConcurrentClose test calling close and executing SQL concurrently
|
||||
// this test making sure that the gorm would not get a Segmentation Fault, and the only error cause by this is using a closed Stmt
|
||||
func TestPreparedStmtConcurrentClose(t *testing.T) {
|
||||
name := "prepared_stmt_concurrent_close"
|
||||
user := *GetUser(name, Config{})
|
||||
createTx := DB.Session(&gorm.Session{}).Create(&user)
|
||||
if createTx.Error != nil {
|
||||
t.Fatalf("failed to prepare record due to %s, test cannot be continue", createTx.Error)
|
||||
}
|
||||
|
||||
// create a new connection to keep away from other tests
|
||||
tx, err := OpenTestConnection(&gorm.Config{PrepareStmt: true})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to open test connection due to %s", err)
|
||||
}
|
||||
pdb, ok := tx.ConnPool.(*gorm.PreparedStmtDB)
|
||||
if !ok {
|
||||
t.Fatalf("should assign PreparedStatement Manager back to database when using PrepareStmt mode")
|
||||
}
|
||||
|
||||
loopCount := 100
|
||||
var wg sync.WaitGroup
|
||||
var unexpectedError bool
|
||||
writerFinish := make(chan struct{})
|
||||
|
||||
wg.Add(1)
|
||||
go func(id uint) {
|
||||
defer wg.Done()
|
||||
defer close(writerFinish)
|
||||
|
||||
for j := 0; j < loopCount; j++ {
|
||||
var tmp User
|
||||
err := tx.Session(&gorm.Session{}).First(&tmp, id).Error
|
||||
if err == nil || isUsingClosedConnError(err) {
|
||||
continue
|
||||
}
|
||||
t.Errorf("failed to read user of id %d due to %s, there should not be error", id, err)
|
||||
unexpectedError = true
|
||||
break
|
||||
}
|
||||
}(user.ID)
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
<-writerFinish
|
||||
pdb.Close()
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if unexpectedError {
|
||||
t.Fatalf("should is a unexpected error")
|
||||
}
|
||||
}
|
||||
|
@ -559,6 +559,11 @@ func TestNot(t *testing.T) {
|
||||
if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE NOT \\(manager IS NULL AND age >= .+\\) AND .users.\\..deleted_at. IS NULL").MatchString(result.Statement.SQL.String()) {
|
||||
t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String())
|
||||
}
|
||||
|
||||
result = dryDB.Not(DB.Where("manager IS NULL").Or("age >= ?", 20)).Find(&User{})
|
||||
if !regexp.MustCompile(`SELECT \* FROM .*users.* WHERE NOT \(manager IS NULL OR age >= .+\) AND .users.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) {
|
||||
t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestNotWithAllFields(t *testing.T) {
|
||||
@ -627,6 +632,21 @@ func TestOr(t *testing.T) {
|
||||
t.Fatalf("Build OR condition, but got %v", result.Statement.SQL.String())
|
||||
}
|
||||
|
||||
sub := dryDB.Clauses(clause.Where{
|
||||
Exprs: []clause.Expression{
|
||||
clause.OrConditions{
|
||||
Exprs: []clause.Expression{
|
||||
clause.Expr{SQL: "role = ?", Vars: []interface{}{"super_admin"}},
|
||||
clause.Expr{SQL: "role = ?", Vars: []interface{}{"admin"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
result = dryDB.Where(sub).Find(&User{})
|
||||
if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*role.* = .+ OR .*role.* = .+").MatchString(result.Statement.SQL.String()) {
|
||||
t.Fatalf("Build OR condition, but got %v", result.Statement.SQL.String())
|
||||
}
|
||||
|
||||
result = dryDB.Where("role = ?", "admin").Or("role = ?", "super_admin").Find(&User{})
|
||||
if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*role.* = .+ OR .*role.* = .+").MatchString(result.Statement.SQL.String()) {
|
||||
t.Fatalf("Build OR condition, but got %v", result.Statement.SQL.String())
|
||||
@ -855,6 +875,28 @@ func TestOmitWithAllFields(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestMapColumns(t *testing.T) {
|
||||
user := User{Name: "MapColumnsUser", Age: 12}
|
||||
DB.Save(&user)
|
||||
|
||||
type result struct {
|
||||
Name string
|
||||
Nickname string
|
||||
Age uint
|
||||
}
|
||||
var res result
|
||||
DB.Table("users").Where("name = ?", user.Name).MapColumns(map[string]string{"name": "nickname"}).Scan(&res)
|
||||
if res.Nickname != user.Name {
|
||||
t.Errorf("Expected res.Nickname to be %s, but got %s", user.Name, res.Nickname)
|
||||
}
|
||||
if res.Name != "" {
|
||||
t.Errorf("Expected res.Name to be empty, but got %s", res.Name)
|
||||
}
|
||||
if res.Age != user.Age {
|
||||
t.Errorf("Expected res.Age to be %d, but got %d", user.Age, res.Age)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPluckWithSelect(t *testing.T) {
|
||||
users := []User{
|
||||
{Name: "pluck_with_select_1", Age: 25},
|
||||
@ -1085,6 +1127,10 @@ func TestSearchWithMap(t *testing.T) {
|
||||
DB.First(&user, map[string]interface{}{"name": users[0].Name})
|
||||
CheckUser(t, user, users[0])
|
||||
|
||||
user = User{}
|
||||
DB.First(&user, map[string]interface{}{"users.name": users[0].Name})
|
||||
CheckUser(t, user, users[0])
|
||||
|
||||
user = User{}
|
||||
DB.Where(map[string]interface{}{"name": users[1].Name}).First(&user)
|
||||
CheckUser(t, user, users[1])
|
||||
@ -1189,7 +1235,6 @@ func TestSubQueryWithRaw(t *testing.T) {
|
||||
Where("age >= ? and name in (?)", 20, []string{"subquery_raw_1", "subquery_raw_3"}).
|
||||
Group("name"),
|
||||
).Count(&count).Error
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Expected to get no errors, but got %v", err)
|
||||
}
|
||||
@ -1205,7 +1250,6 @@ func TestSubQueryWithRaw(t *testing.T) {
|
||||
Not("age <= ?", 10).Not("name IN (?)", []string{"subquery_raw_1", "subquery_raw_3"}).
|
||||
Group("name"),
|
||||
).Count(&count).Error
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Expected to get no errors, but got %v", err)
|
||||
}
|
||||
@ -1332,7 +1376,7 @@ func TestQueryResetNullValue(t *testing.T) {
|
||||
Number1 int64 `gorm:"default:NULL"`
|
||||
Number2 uint64 `gorm:"default:NULL"`
|
||||
Number3 float64 `gorm:"default:NULL"`
|
||||
Now *time.Time `gorm:"defalut:NULL"`
|
||||
Now *time.Time `gorm:"default:NULL"`
|
||||
Item1Id string
|
||||
Item1 *QueryResetItem `gorm:"references:ID"`
|
||||
Item2Id string
|
||||
@ -1422,7 +1466,7 @@ func TestQueryScanToArray(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if users[0] == nil || users[0].Name != "testname1" {
|
||||
t.Error("users[0] not covere")
|
||||
t.Error("users[0] not covered")
|
||||
}
|
||||
if users[1] != nil {
|
||||
t.Error("users[1] should be empty")
|
||||
|
@ -5,6 +5,7 @@ import (
|
||||
"sort"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
@ -126,7 +127,7 @@ func TestScanRows(t *testing.T) {
|
||||
|
||||
rows, err := DB.Table("users").Where("name = ? or name = ?", user2.Name, user3.Name).Select("name, age").Rows()
|
||||
if err != nil {
|
||||
t.Errorf("Not error should happen, got %v", err)
|
||||
t.Errorf("No error should happen, got %v", err)
|
||||
}
|
||||
|
||||
type Result struct {
|
||||
@ -148,7 +149,7 @@ func TestScanRows(t *testing.T) {
|
||||
})
|
||||
|
||||
if !reflect.DeepEqual(results, []Result{{Name: "ScanRowsUser2", Age: 10}, {Name: "ScanRowsUser3", Age: 20}}) {
|
||||
t.Errorf("Should find expected results")
|
||||
t.Errorf("Should find expected results, got %+v", results)
|
||||
}
|
||||
|
||||
var ages int
|
||||
@ -158,7 +159,105 @@ func TestScanRows(t *testing.T) {
|
||||
|
||||
var name string
|
||||
if err := DB.Table("users").Where("name = ?", user2.Name).Select("name").Scan(&name).Error; err != nil || name != user2.Name {
|
||||
t.Fatalf("failed to scan ages, got error %v, ages: %v", err, name)
|
||||
t.Fatalf("failed to scan name, got error %v, name: %v", err, name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScanRowsNullValuesScanToFieldDefault(t *testing.T) {
|
||||
DB.Save(&User{})
|
||||
|
||||
rows, err := DB.Table("users").
|
||||
Select(`
|
||||
NULL AS bool_field,
|
||||
NULL AS int_field,
|
||||
NULL AS int8_field,
|
||||
NULL AS int16_field,
|
||||
NULL AS int32_field,
|
||||
NULL AS int64_field,
|
||||
NULL AS uint_field,
|
||||
NULL AS uint8_field,
|
||||
NULL AS uint16_field,
|
||||
NULL AS uint32_field,
|
||||
NULL AS uint64_field,
|
||||
NULL AS float32_field,
|
||||
NULL AS float64_field,
|
||||
NULL AS string_field,
|
||||
NULL AS time_field,
|
||||
NULL AS time_ptr_field,
|
||||
NULL AS embedded_int_field,
|
||||
NULL AS nested_embedded_int_field,
|
||||
NULL AS embedded_ptr_int_field
|
||||
`).Rows()
|
||||
if err != nil {
|
||||
t.Errorf("No error should happen, got %v", err)
|
||||
}
|
||||
|
||||
type NestedEmbeddedStruct struct {
|
||||
NestedEmbeddedIntField int
|
||||
NestedEmbeddedIntFieldWithDefault int `gorm:"default:2"`
|
||||
}
|
||||
|
||||
type EmbeddedStruct struct {
|
||||
EmbeddedIntField int
|
||||
NestedEmbeddedStruct `gorm:"embedded"`
|
||||
}
|
||||
|
||||
type EmbeddedPtrStruct struct {
|
||||
EmbeddedPtrIntField int
|
||||
*NestedEmbeddedStruct `gorm:"embedded"`
|
||||
}
|
||||
|
||||
type Result struct {
|
||||
BoolField bool
|
||||
IntField int
|
||||
Int8Field int8
|
||||
Int16Field int16
|
||||
Int32Field int32
|
||||
Int64Field int64
|
||||
UIntField uint
|
||||
UInt8Field uint8
|
||||
UInt16Field uint16
|
||||
UInt32Field uint32
|
||||
UInt64Field uint64
|
||||
Float32Field float32
|
||||
Float64Field float64
|
||||
StringField string
|
||||
TimeField time.Time
|
||||
TimePtrField *time.Time
|
||||
EmbeddedStruct `gorm:"embedded"`
|
||||
*EmbeddedPtrStruct `gorm:"embedded"`
|
||||
}
|
||||
|
||||
currTime := time.Now()
|
||||
reusedVar := Result{
|
||||
BoolField: true,
|
||||
IntField: 1,
|
||||
Int8Field: 1,
|
||||
Int16Field: 1,
|
||||
Int32Field: 1,
|
||||
Int64Field: 1,
|
||||
UIntField: 1,
|
||||
UInt8Field: 1,
|
||||
UInt16Field: 1,
|
||||
UInt32Field: 1,
|
||||
UInt64Field: 1,
|
||||
Float32Field: 1.1,
|
||||
Float64Field: 1.1,
|
||||
StringField: "hello",
|
||||
TimeField: currTime,
|
||||
TimePtrField: &currTime,
|
||||
EmbeddedStruct: EmbeddedStruct{EmbeddedIntField: 1, NestedEmbeddedStruct: NestedEmbeddedStruct{NestedEmbeddedIntField: 1, NestedEmbeddedIntFieldWithDefault: 2}},
|
||||
EmbeddedPtrStruct: &EmbeddedPtrStruct{EmbeddedPtrIntField: 1, NestedEmbeddedStruct: &NestedEmbeddedStruct{NestedEmbeddedIntField: 1, NestedEmbeddedIntFieldWithDefault: 2}},
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
if err := DB.ScanRows(rows, &reusedVar); err != nil {
|
||||
t.Errorf("should get no error, but got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(reusedVar, Result{}) {
|
||||
t.Errorf("Should find zero values in struct fields, got %+v\n", reusedVar)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -45,7 +45,7 @@ type SerializerPostgresStruct struct {
|
||||
func (*SerializerPostgresStruct) TableName() string { return "serializer_structs" }
|
||||
|
||||
func adaptorSerializerModel(s *SerializerStruct) interface{} {
|
||||
if DB.Dialector.Name() == "postgres" {
|
||||
if DB.Dialector.Name() == "postgres" || DB.Dialector.Name() == "gaussdb" {
|
||||
sps := SerializerPostgresStruct(*s)
|
||||
return &sps
|
||||
}
|
||||
|
@ -487,7 +487,7 @@ func replaceQuoteInSQL(sql string) string {
|
||||
|
||||
// convert dialect special quote into double quote
|
||||
switch DB.Dialector.Name() {
|
||||
case "postgres":
|
||||
case "postgres", "gaussdb":
|
||||
sql = strings.ReplaceAll(sql, `"`, `"`)
|
||||
case "mysql", "sqlite":
|
||||
sql = strings.ReplaceAll(sql, "`", `"`)
|
||||
|
@ -5,6 +5,7 @@ import (
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"gorm.io/driver/gaussdb"
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/schema"
|
||||
@ -251,6 +252,82 @@ func TestPostgresTableWithIdentifierLength(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestGaussDBTableWithIdentifierLength(t *testing.T) {
|
||||
if DB.Dialector.Name() != "gaussdb" {
|
||||
return
|
||||
}
|
||||
|
||||
type LongString struct {
|
||||
ThisIsAVeryVeryVeryVeryVeryVeryVeryVeryVeryLongString string `gorm:"unique"`
|
||||
}
|
||||
|
||||
t.Run("default", func(t *testing.T) {
|
||||
db, _ := gorm.Open(gaussdb.Open(gaussdbDSN), &gorm.Config{})
|
||||
user, err := schema.Parse(&LongString{}, &sync.Map{}, db.Config.NamingStrategy)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse user unique, got error %v", err)
|
||||
}
|
||||
|
||||
constraints := user.ParseUniqueConstraints()
|
||||
if len(constraints) != 1 {
|
||||
t.Fatalf("failed to find unique constraint, got %v", constraints)
|
||||
}
|
||||
|
||||
for key := range constraints {
|
||||
if len(key) != 63 {
|
||||
t.Errorf("failed to find unique constraint, got %v", constraints)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("naming strategy", func(t *testing.T) {
|
||||
db, _ := gorm.Open(gaussdb.Open(gaussdbDSN), &gorm.Config{
|
||||
NamingStrategy: schema.NamingStrategy{},
|
||||
})
|
||||
|
||||
user, err := schema.Parse(&LongString{}, &sync.Map{}, db.Config.NamingStrategy)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse user unique, got error %v", err)
|
||||
}
|
||||
|
||||
constraints := user.ParseUniqueConstraints()
|
||||
if len(constraints) != 1 {
|
||||
t.Fatalf("failed to find unique constraint, got %v", constraints)
|
||||
}
|
||||
|
||||
for key := range constraints {
|
||||
if len(key) != 63 {
|
||||
t.Errorf("failed to find unique constraint, got %v", constraints)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("namer", func(t *testing.T) {
|
||||
uname := "custom_unique_name"
|
||||
db, _ := gorm.Open(gaussdb.Open(gaussdbDSN), &gorm.Config{
|
||||
NamingStrategy: mockUniqueNamingStrategy{
|
||||
UName: uname,
|
||||
},
|
||||
})
|
||||
|
||||
user, err := schema.Parse(&LongString{}, &sync.Map{}, db.Config.NamingStrategy)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse user unique, got error %v", err)
|
||||
}
|
||||
|
||||
constraints := user.ParseUniqueConstraints()
|
||||
if len(constraints) != 1 {
|
||||
t.Fatalf("failed to find unique constraint, got %v", constraints)
|
||||
}
|
||||
|
||||
for key := range constraints {
|
||||
if key != uname {
|
||||
t.Errorf("failed to find unique constraint, got %v", constraints)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
type mockUniqueNamingStrategy struct {
|
||||
UName string
|
||||
schema.NamingStrategy
|
||||
|
@ -1,6 +1,6 @@
|
||||
#!/bin/bash -e
|
||||
|
||||
dialects=("sqlite" "mysql" "postgres" "sqlserver" "tidb")
|
||||
dialects=("sqlite" "mysql" "postgres" "gaussdb" "sqlserver" "tidb")
|
||||
|
||||
if [[ $(pwd) == *"gorm/tests"* ]]; then
|
||||
cd ..
|
||||
@ -16,21 +16,22 @@ then
|
||||
fi
|
||||
|
||||
# SqlServer for Mac M1
|
||||
if [[ -z $GITHUB_ACTION ]]; then
|
||||
if [ -d tests ]
|
||||
then
|
||||
cd tests
|
||||
if [[ $(uname -a) == *" arm64" ]]; then
|
||||
MSSQL_IMAGE=mcr.microsoft.com/azure-sql-edge docker-compose start || true
|
||||
go install github.com/microsoft/go-sqlcmd/cmd/sqlcmd@latest || true
|
||||
SQLCMDPASSWORD=LoremIpsum86 sqlcmd -U sa -S localhost:9930 -Q "IF DB_ID('gorm') IS NULL CREATE DATABASE gorm" > /dev/null || true
|
||||
SQLCMDPASSWORD=LoremIpsum86 sqlcmd -U sa -S localhost:9930 -Q "IF SUSER_ID (N'gorm') IS NULL CREATE LOGIN gorm WITH PASSWORD = 'LoremIpsum86';" > /dev/null || true
|
||||
SQLCMDPASSWORD=LoremIpsum86 sqlcmd -U sa -S localhost:9930 -Q "IF USER_ID (N'gorm') IS NULL CREATE USER gorm FROM LOGIN gorm; ALTER SERVER ROLE sysadmin ADD MEMBER [gorm];" > /dev/null || true
|
||||
else
|
||||
docker-compose start
|
||||
fi
|
||||
cd ..
|
||||
if [[ -z $GITHUB_ACTION && -d tests ]]; then
|
||||
cd tests
|
||||
if [[ $(uname -a) == *" arm64" ]]; then
|
||||
MSSQL_IMAGE=mcr.microsoft.com/azure-sql-edge docker compose up -d --wait
|
||||
go install github.com/microsoft/go-sqlcmd/cmd/sqlcmd@latest || true
|
||||
for query in \
|
||||
"IF DB_ID('gorm') IS NULL CREATE DATABASE gorm" \
|
||||
"IF SUSER_ID (N'gorm') IS NULL CREATE LOGIN gorm WITH PASSWORD = 'LoremIpsum86';" \
|
||||
"IF USER_ID (N'gorm') IS NULL CREATE USER gorm FROM LOGIN gorm; ALTER SERVER ROLE sysadmin ADD MEMBER [gorm];"
|
||||
do
|
||||
SQLCMDPASSWORD=LoremIpsum86 sqlcmd -U sa -S localhost:9930 -Q "$query" > /dev/null || true
|
||||
done
|
||||
else
|
||||
MSSQL_IMAGE=mcr.microsoft.com/mssql/server docker compose up -d --wait
|
||||
fi
|
||||
cd ..
|
||||
fi
|
||||
|
||||
|
||||
|
@ -1,3 +1,4 @@
|
||||
//go:debug x509negativeserial=1
|
||||
package tests_test
|
||||
|
||||
import (
|
||||
@ -7,6 +8,7 @@ import (
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"gorm.io/driver/gaussdb"
|
||||
"gorm.io/driver/mysql"
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/driver/sqlite"
|
||||
@ -20,7 +22,8 @@ var DB *gorm.DB
|
||||
var (
|
||||
mysqlDSN = "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True&loc=Local"
|
||||
postgresDSN = "user=gorm password=gorm dbname=gorm host=localhost port=9920 sslmode=disable TimeZone=Asia/Shanghai"
|
||||
sqlserverDSN = "sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm"
|
||||
gaussdbDSN = "user=gaussdb password=Gaussdb@123 dbname=gorm host=localhost port=9950 sslmode=disable TimeZone=Asia/Shanghai"
|
||||
sqlserverDSN = "sqlserver://sa:LoremIpsum86@localhost:9930?database=master"
|
||||
tidbDSN = "root:@tcp(localhost:9940)/test?charset=utf8&parseTime=True&loc=Local"
|
||||
)
|
||||
|
||||
@ -64,6 +67,15 @@ func OpenTestConnection(cfg *gorm.Config) (db *gorm.DB, err error) {
|
||||
DSN: dbDSN,
|
||||
PreferSimpleProtocol: true,
|
||||
}), cfg)
|
||||
case "gaussdb":
|
||||
log.Println("testing gaussdb...")
|
||||
if dbDSN == "" {
|
||||
dbDSN = gaussdbDSN
|
||||
}
|
||||
db, err = gorm.Open(gaussdb.New(gaussdb.Config{
|
||||
DSN: dbDSN,
|
||||
PreferSimpleProtocol: true,
|
||||
}), cfg)
|
||||
case "sqlserver":
|
||||
// go install github.com/microsoft/go-sqlcmd/cmd/sqlcmd@latest
|
||||
// SQLCMDPASSWORD=LoremIpsum86 sqlcmd -U sa -S localhost:9930
|
||||
|
@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
@ -67,7 +68,7 @@ func TestTransaction(t *testing.T) {
|
||||
return tx5.First(&User{}, "name = ?", "transaction-2").Error
|
||||
})
|
||||
}); err != nil {
|
||||
t.Fatalf("prepare statement and nested transcation coexist" + err.Error())
|
||||
t.Fatalf("prepare statement and nested transaction coexist" + err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
@ -297,6 +298,74 @@ func TestNestedTransactionWithBlock(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeeplyNestedTransactionWithBlockAndWrappedCallback(t *testing.T) {
|
||||
transaction := func(ctx context.Context, db *gorm.DB, callback func(ctx context.Context, db *gorm.DB) error) error {
|
||||
return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
return callback(ctx, tx)
|
||||
})
|
||||
}
|
||||
var (
|
||||
user = *GetUser("transaction-nested", Config{})
|
||||
user1 = *GetUser("transaction-nested-1", Config{})
|
||||
user2 = *GetUser("transaction-nested-2", Config{})
|
||||
)
|
||||
|
||||
if err := transaction(context.Background(), DB, func(ctx context.Context, tx *gorm.DB) error {
|
||||
tx.Create(&user)
|
||||
|
||||
if err := tx.First(&User{}, "name = ?", user.Name).Error; err != nil {
|
||||
t.Fatalf("Should find saved record")
|
||||
}
|
||||
|
||||
if err := transaction(ctx, tx, func(ctx context.Context, tx1 *gorm.DB) error {
|
||||
tx1.Create(&user1)
|
||||
|
||||
if err := tx1.First(&User{}, "name = ?", user1.Name).Error; err != nil {
|
||||
t.Fatalf("Should find saved record")
|
||||
}
|
||||
|
||||
if err := transaction(ctx, tx1, func(ctx context.Context, tx2 *gorm.DB) error {
|
||||
tx2.Create(&user2)
|
||||
|
||||
if err := tx2.First(&User{}, "name = ?", user2.Name).Error; err != nil {
|
||||
t.Fatalf("Should find saved record")
|
||||
}
|
||||
|
||||
return errors.New("inner rollback")
|
||||
}); err == nil {
|
||||
t.Fatalf("nested transaction has no error")
|
||||
}
|
||||
|
||||
return errors.New("rollback")
|
||||
}); err == nil {
|
||||
t.Fatalf("nested transaction should returns error")
|
||||
}
|
||||
|
||||
if err := tx.First(&User{}, "name = ?", user1.Name).Error; err == nil {
|
||||
t.Fatalf("Should not find rollbacked record")
|
||||
}
|
||||
|
||||
if err := tx.First(&User{}, "name = ?", user2.Name).Error; err != nil {
|
||||
t.Fatalf("Should find saved record")
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
t.Fatalf("no error should return, but got %v", err)
|
||||
}
|
||||
|
||||
if err := DB.First(&User{}, "name = ?", user.Name).Error; err != nil {
|
||||
t.Fatalf("Should find saved record")
|
||||
}
|
||||
|
||||
if err := DB.First(&User{}, "name = ?", user1.Name).Error; err == nil {
|
||||
t.Fatalf("Should not find rollbacked parent record")
|
||||
}
|
||||
|
||||
if err := DB.First(&User{}, "name = ?", user2.Name).Error; err != nil {
|
||||
t.Fatalf("Should not find rollbacked nested record")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDisabledNestedTransaction(t *testing.T) {
|
||||
var (
|
||||
user = *GetUser("transaction-nested", Config{})
|
||||
@ -391,7 +460,6 @@ func TestTransactionWithHooks(t *testing.T) {
|
||||
return tx2.Scan(&User{}).Error
|
||||
})
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
@ -405,8 +473,20 @@ func TestTransactionWithHooks(t *testing.T) {
|
||||
return tx3.Where("user_id", user.ID).Delete(&Account{}).Error
|
||||
})
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransactionWithDefaultTimeout(t *testing.T) {
|
||||
db, err := OpenTestConnection(&gorm.Config{DefaultTransactionTimeout: 2 * time.Second})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect database, got error %v", err)
|
||||
}
|
||||
|
||||
tx := db.Begin()
|
||||
time.Sleep(3 * time.Second)
|
||||
if err = tx.Find(&User{}).Error; err == nil {
|
||||
t.Errorf("should return error when transaction timeout, got error %v", err)
|
||||
}
|
||||
}
|
||||
|
@ -765,9 +765,9 @@ func TestSaveWithPrimaryValue(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// only sqlite, postgres, sqlserver support returning
|
||||
// only sqlite, postgres, gaussdb, sqlserver support returning
|
||||
func TestUpdateReturning(t *testing.T) {
|
||||
if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" && DB.Dialector.Name() != "sqlserver" {
|
||||
if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" && DB.Dialector.Name() != "gaussdb" && DB.Dialector.Name() != "sqlserver" {
|
||||
return
|
||||
}
|
||||
|
||||
@ -883,9 +883,9 @@ func TestSaveWithHooks(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// only postgres, sqlserver, sqlite support update from
|
||||
// only postgres, gaussdb, sqlserver, sqlite support update from
|
||||
func TestUpdateFrom(t *testing.T) {
|
||||
if DB.Dialector.Name() != "postgres" && DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "sqlserver" {
|
||||
if DB.Dialector.Name() != "postgres" && DB.Dialector.Name() != "gaussdb" && DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "sqlserver" {
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -32,12 +32,16 @@ func sourceDir(file string) string {
|
||||
|
||||
// FileWithLineNum return the file name and line number of the current file
|
||||
func FileWithLineNum() string {
|
||||
// the second caller usually from gorm internal, so set i start from 2
|
||||
for i := 2; i < 15; i++ {
|
||||
_, file, line, ok := runtime.Caller(i)
|
||||
if ok && (!strings.HasPrefix(file, gormSourceDir) || strings.HasSuffix(file, "_test.go")) &&
|
||||
!strings.HasSuffix(file, ".gen.go") {
|
||||
return file + ":" + strconv.FormatInt(int64(line), 10)
|
||||
pcs := [13]uintptr{}
|
||||
// the third caller usually from gorm internal
|
||||
len := runtime.Callers(3, pcs[:])
|
||||
frames := runtime.CallersFrames(pcs[:len])
|
||||
for i := 0; i < len; i++ {
|
||||
// second return value is "more", not "ok"
|
||||
frame, _ := frames.Next()
|
||||
if (!strings.HasPrefix(frame.File, gormSourceDir) ||
|
||||
strings.HasSuffix(frame.File, "_test.go")) && !strings.HasSuffix(frame.File, ".gen.go") {
|
||||
return string(strconv.AppendInt(append([]byte(frame.File), ':'), int64(frame.Line), 10))
|
||||
}
|
||||
}
|
||||
|
||||
@ -162,3 +166,14 @@ func SplitNestedRelationName(name string) []string {
|
||||
func JoinNestedRelationNames(relationNames []string) string {
|
||||
return strings.Join(relationNames, nestedRelationSplit)
|
||||
}
|
||||
|
||||
// RTrimSlice Right trims the given slice by given length
|
||||
func RTrimSlice[T any](v []T, trimLen int) []T {
|
||||
if trimLen >= len(v) { // trimLen greater than slice len means fully sliced
|
||||
return v[:0]
|
||||
}
|
||||
if trimLen < 0 { // negative trimLen is ignored
|
||||
return v[:]
|
||||
}
|
||||
return v[:len(v)-trimLen]
|
||||
}
|
||||
|
@ -138,3 +138,64 @@ func TestToString(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRTrimSlice(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input []int
|
||||
trimLen int
|
||||
expected []int
|
||||
}{
|
||||
{
|
||||
name: "Trim two elements from end",
|
||||
input: []int{1, 2, 3, 4, 5},
|
||||
trimLen: 2,
|
||||
expected: []int{1, 2, 3},
|
||||
},
|
||||
{
|
||||
name: "Trim entire slice",
|
||||
input: []int{1, 2, 3},
|
||||
trimLen: 3,
|
||||
expected: []int{},
|
||||
},
|
||||
{
|
||||
name: "Trim length greater than slice length",
|
||||
input: []int{1, 2, 3},
|
||||
trimLen: 5,
|
||||
expected: []int{},
|
||||
},
|
||||
{
|
||||
name: "Zero trim length",
|
||||
input: []int{1, 2, 3},
|
||||
trimLen: 0,
|
||||
expected: []int{1, 2, 3},
|
||||
},
|
||||
{
|
||||
name: "Trim one element from end",
|
||||
input: []int{1, 2, 3},
|
||||
trimLen: 1,
|
||||
expected: []int{1, 2},
|
||||
},
|
||||
{
|
||||
name: "Empty slice",
|
||||
input: []int{},
|
||||
trimLen: 2,
|
||||
expected: []int{},
|
||||
},
|
||||
{
|
||||
name: "Negative trim length (should be treated as zero)",
|
||||
input: []int{1, 2, 3},
|
||||
trimLen: -1,
|
||||
expected: []int{1, 2, 3},
|
||||
},
|
||||
}
|
||||
|
||||
for _, testcase := range tests {
|
||||
t.Run(testcase.name, func(t *testing.T) {
|
||||
result := RTrimSlice(testcase.input, testcase.trimLen)
|
||||
if !AssertEqual(result, testcase.expected) {
|
||||
t.Errorf("RTrimSlice(%v, %d) = %v; want %v", testcase.input, testcase.trimLen, result, testcase.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user