Compare commits

..

1 Commits

Author SHA1 Message Date
qqxhb
cd9a163fd8 feat: unsafe pointer 2023-08-29 16:57:31 +08:00
104 changed files with 989 additions and 7402 deletions

View File

@ -1,20 +0,0 @@
name-template: 'v Release $NEXT_PATCH_VERSION 🌈'
tag-template: 'v$NEXT_PATCH_VERSION'
categories:
- title: '🚀 Features'
labels:
- 'feature'
- 'enhancement'
- title: '🐛 Bug Fixes'
labels:
- 'fix'
- 'bugfix'
- 'bug'
- title: '🧰 Maintenance'
label: 'chore'
change-template: '- $TITLE @$AUTHOR (#$NUMBER)'
change-title-escapes: '\<*_&'
template: |
## Changes
$CHANGES

View File

@ -1,31 +0,0 @@
name: Create Release
on:
push:
tags:
- 'v*.*.*'
permissions:
contents: write
pull-requests: read
jobs:
create_release:
name: Create Release
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Generate Release Notes and Publish
id: generate_release_notes
uses: release-drafter/release-drafter@v6
with:
config-name: 'release-drafter.yml'
name: "Release ${{ github.ref_name }}"
tag: ${{ github.ref_name }}
publish: true
prerelease: false
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

View File

@ -1,26 +0,0 @@
name: golangci-lint
on:
push:
branches:
- main
- master
pull_request:
permissions:
contents: read
pull-requests: read
jobs:
golangci:
name: lint
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
with:
go-version: stable
- name: golangci-lint
uses: golangci/golangci-lint-action@v7
with:
version: v2.0
only-new-issues: true

View File

@ -11,7 +11,7 @@ jobs:
name: Label issues and pull requests
steps:
- name: check out
uses: actions/checkout@v4
uses: actions/checkout@v3
- name: labeler
uses: jinzhu/super-labeler-action@develop

22
.github/workflows/reviewdog.yml vendored Normal file
View File

@ -0,0 +1,22 @@
name: reviewdog
on: [pull_request]
jobs:
golangci-lint:
name: runner / golangci-lint
runs-on: ubuntu-latest
steps:
- name: Check out code into the Go module directory
uses: actions/checkout@v3
- name: golangci-lint
uses: reviewdog/action-golangci-lint@v2
- name: Setup reviewdog
uses: reviewdog/action-setup@v1
- name: gofumpt -s with reviewdog
env:
REVIEWDOG_GITHUB_API_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
go install mvdan.cc/gofumpt@v0.2.0
gofumpt -e -d . | \
reviewdog -name="gofumpt" -f=diff -f.diff.strip=0 -reporter=github-pr-review

View File

@ -16,7 +16,7 @@ jobs:
sqlite:
strategy:
matrix:
go: ['1.23', '1.24']
go: ['1.21', '1.20', '1.19']
platform: [ubuntu-latest] # can not run in windows OS
runs-on: ${{ matrix.platform }}
@ -27,10 +27,10 @@ jobs:
go-version: ${{ matrix.go }}
- name: Check out code into the Go module directory
uses: actions/checkout@v4
uses: actions/checkout@v3
- name: go mod package cache
uses: actions/cache@v4
uses: actions/cache@v3
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
@ -41,8 +41,8 @@ jobs:
mysql:
strategy:
matrix:
dbversion: ['mysql:9', 'mysql:8', 'mysql:5.7']
go: ['1.23', '1.24']
dbversion: ['mysql:latest', 'mysql:5.7']
go: ['1.21', '1.20', '1.19']
platform: [ubuntu-latest]
runs-on: ${{ matrix.platform }}
@ -70,10 +70,10 @@ jobs:
go-version: ${{ matrix.go }}
- name: Check out code into the Go module directory
uses: actions/checkout@v4
uses: actions/checkout@v3
- name: go mod package cache
uses: actions/cache@v4
uses: actions/cache@v3
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
@ -85,7 +85,7 @@ jobs:
strategy:
matrix:
dbversion: [ 'mariadb:latest' ]
go: ['1.23', '1.24']
go: ['1.21', '1.20', '1.19']
platform: [ ubuntu-latest ]
runs-on: ${{ matrix.platform }}
@ -113,10 +113,10 @@ jobs:
go-version: ${{ matrix.go }}
- name: Check out code into the Go module directory
uses: actions/checkout@v4
uses: actions/checkout@v3
- name: go mod package cache
uses: actions/cache@v4
uses: actions/cache@v3
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
@ -127,8 +127,8 @@ jobs:
postgres:
strategy:
matrix:
dbversion: ['postgres:latest', 'postgres:15', 'postgres:14', 'postgres:13']
go: ['1.23', '1.24']
dbversion: ['postgres:latest', 'postgres:13', 'postgres:12', 'postgres:11', 'postgres:10']
go: ['1.21', '1.20', '1.19']
platform: [ubuntu-latest] # can not run in macOS and Windows
runs-on: ${{ matrix.platform }}
@ -156,10 +156,10 @@ jobs:
go-version: ${{ matrix.go }}
- name: Check out code into the Go module directory
uses: actions/checkout@v4
uses: actions/checkout@v3
- name: go mod package cache
uses: actions/cache@v4
uses: actions/cache@v3
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
@ -170,21 +170,23 @@ jobs:
sqlserver:
strategy:
matrix:
go: ['1.23', '1.24']
go: ['1.21', '1.20', '1.19']
platform: [ubuntu-latest] # can not run test in macOS and windows
runs-on: ${{ matrix.platform }}
services:
mssql:
image: mcr.microsoft.com/mssql/server:2022-latest
image: mcmoe/mssqldocker:latest
env:
TZ: Asia/Shanghai
ACCEPT_EULA: Y
MSSQL_SA_PASSWORD: LoremIpsum86
SA_PASSWORD: LoremIpsum86
MSSQL_DB: gorm
MSSQL_USER: gorm
MSSQL_PASSWORD: LoremIpsum86
ports:
- 9930:1433
options: >-
--health-cmd="/opt/mssql-tools18/bin/sqlcmd -S localhost -U sa -P ${MSSQL_SA_PASSWORD} -N -C -l 30 -Q \"SELECT 1\" || exit 1"
--health-cmd="/opt/mssql-tools/bin/sqlcmd -S localhost -U sa -P LoremIpsum86 -l 30 -Q \"SELECT 1\" || exit 1"
--health-start-period 10s
--health-interval 10s
--health-timeout 5s
@ -197,22 +199,22 @@ jobs:
go-version: ${{ matrix.go }}
- name: Check out code into the Go module directory
uses: actions/checkout@v4
uses: actions/checkout@v3
- name: go mod package cache
uses: actions/cache@v4
uses: actions/cache@v3
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
- name: Tests
run: GITHUB_ACTION=true GORM_DIALECT=sqlserver GORM_DSN="sqlserver://sa:LoremIpsum86@localhost:9930?database=master" ./tests/tests_all.sh
run: GITHUB_ACTION=true GORM_DIALECT=sqlserver GORM_DSN="sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm" ./tests/tests_all.sh
tidb:
strategy:
matrix:
dbversion: [ 'v6.5.0' ]
go: ['1.23', '1.24']
go: ['1.21', '1.20', '1.19']
platform: [ ubuntu-latest ]
runs-on: ${{ matrix.platform }}
@ -229,82 +231,14 @@ jobs:
go-version: ${{ matrix.go }}
- name: Check out code into the Go module directory
uses: actions/checkout@v4
uses: actions/checkout@v3
- name: go mod package cache
uses: actions/cache@v4
uses: actions/cache@v3
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
- name: Tests
run: GITHUB_ACTION=true GORM_DIALECT=tidb GORM_DSN="root:@tcp(localhost:9940)/test?charset=utf8&parseTime=True&loc=Local" ./tests/tests_all.sh
gaussdb:
strategy:
matrix:
dbversion: ['opengauss/opengauss:7.0.0-RC1.B023']
go: ['1.23', '1.24']
platform: [ubuntu-latest] # can not run in macOS and Windows
runs-on: ${{ matrix.platform }}
services:
gaussdb:
image: ${{ matrix.dbversion }}
env:
# GaussDB has password limitations
GS_PASSWORD: Gaussdb@123
TZ: Asia/Shanghai
ports:
- 9950:5432
steps:
- name: Set up Go 1.x
uses: actions/setup-go@v4
with:
go-version: ${{ matrix.go }}
- name: Check out code into the Go module directory
uses: actions/checkout@v4
- name: Waiting for GaussDB to be ready
run: |
container_name=$(docker ps --filter "ancestor=opengauss/opengauss:7.0.0-RC1.B023" --format "{{.Names}}")
if [ -z "$container_name" ]; then
echo "Error: failed to find a container created from the 'opengauss/opengauss:7.0.0-RC1.B023' image."
exit 1
fi
max_retries=12
retry_count=0
if [ -t 0 ]; then
TTY_FLAG="-t"
else
TTY_FLAG=""
fi
while [ $retry_count -lt $max_retries ]; do
if docker exec -i "${container_name}" bash -c "su - omm -c 'gsql -U omm -c \"select 1;\"'"
then
echo "Creating database gorm..."
sql_file='/tmp/create_database.sql'
echo "CREATE DATABASE gorm DBCOMPATIBILITY 'PG';" > ${sql_file}
docker cp "${sql_file}" "${container_name}":"${sql_file}"
docker exec -i ${TTY_FLAG} "${container_name}" bash -c "su - omm -c 'gsql -U omm -f ${sql_file}'"
echo "Database initialization completed."
break
fi
echo "Waiting for database to be ready... (attempt $((retry_count + 1))/$max_retries)"
sleep 10
((++retry_count))
done
exit 0
- name: go mod package cache
uses: actions/cache@v4
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
- name: Tests
run: GITHUB_ACTION=true GORM_DIALECT=gaussdb GORM_DSN="user=gaussdb password=Gaussdb@123 dbname=gorm host=localhost port=9950 sslmode=disable TimeZone=Asia/Shanghai" ./tests/tests_all.sh

View File

@ -1,9 +1,7 @@
version: "2"
linters:
default: standard
enable:
- cyclop
- exportloopref
- gocritic
- gosec
- ineffassign
@ -11,9 +9,12 @@ linters:
- prealloc
- unconvert
- unparam
- goimports
- whitespace
formatters:
enable:
- gofumpt
- goimports
linters-settings:
whitespace:
multi-func: true
goimports:
local-prefixes: gorm.io/gorm

View File

@ -1,128 +0,0 @@
# Contributor Covenant Code of Conduct
## Our Pledge
We as members, contributors, and leaders pledge to participate in our
community a harassment-free experience for everyone, regardless of age, body
size, visible or invisible disability, ethnicity, sex characteristics, gender
identity and expression, level of experience, education, socio-economic status,
nationality, personal appearance, race, religion, or sexual identity
and orientation.
We pledge to act and interact in ways that contribute to an open, welcoming,
diverse, inclusive, and healthy community.
## Our Standards
Examples of behavior that contributes to a positive environment for our
community includes:
* Demonstrating empathy and kindness toward other people
* Being respectful of differing opinions, viewpoints, and experiences
* Giving and gracefully accepting constructive feedback
* Accepting responsibility and apologizing to those affected by our mistakes,
and learning from the experience
* Focusing on what is best not just for us as individuals, but for the
overall community
Examples of unacceptable behavior include:
* The use of sexualized language or imagery, and sexual attention or
advances of any kind
* Trolling, insulting or derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or email
address, without their explicit permission
* Other conduct which could reasonably be considered inappropriate in a
professional setting
## Enforcement Responsibilities
Community leaders are responsible for clarifying and enforcing our standards of
acceptable behavior and will take appropriate and fair corrective action in
response to any behavior that they deem inappropriate, threatening, offensive,
or harmful.
Community leaders have the right and responsibility to remove, edit, or reject
comments, commits, code, wiki edits, issues, and other contributions that are
not aligned to this Code of Conduct, and will communicate reasons for moderation
decisions when appropriate.
## Scope
This Code of Conduct applies within all community spaces and also applies when
an individual is officially representing the community in public spaces.
Examples of representing our community include using an official e-mail address,
posting via an official social media account, or acting as an appointed
representative at an online or offline event.
## Enforcement
Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported to the community leaders responsible for enforcement at
.
All complaints will be reviewed and investigated promptly and fairly.
All community leaders are obligated to respect the privacy and security of the
reporter of any incident.
## Enforcement Guidelines
Community leaders will follow these Community Impact Guidelines in determining
the consequences for any action they deem in violation of this Code of Conduct:
### 1. Correction
**Community Impact**: Use of inappropriate language or other behavior deemed
unprofessional or unwelcome in the community.
**Consequence**: A private, written warning from community leaders, providing
clarity around the nature of the violation and an explanation of why the
behavior was inappropriate. A public apology may be requested.
### 2. Warning
**Community Impact**: A violation through a single incident or series
of actions.
**Consequence**: A warning with consequences for continued behavior. No
interaction with the people involved, including unsolicited interaction with
those enforcing the Code of Conduct, for a specified period. This
includes avoiding interactions in community spaces and external channels
like social media. Violating these terms may lead to a temporary or
permanent ban.
### 3. Temporary Ban
**Community Impact**: A serious violation of community standards, including
sustained inappropriate behavior.
**Consequence**: A temporary ban from any interaction or public
communication with the community for a specified period. No public or
private interaction with the people involved, including unsolicited interaction
with those enforcing the Code of Conduct, is allowed during this period.
Violating these terms may lead to a permanent ban.
### 4. Permanent Ban
**Community Impact**: Demonstrating a pattern of violation of community
standards, including sustained inappropriate behavior, harassment of an
individual, or aggression toward or disparagement of classes of individuals.
**Consequence**: A permanent ban from any sort of public interaction within
the community.
## Attribution
This Code of Conduct is adapted from the [Contributor Covenant][homepage],
version 2.0, available at
https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
Community Impact Guidelines were inspired by [Mozilla's code of conduct
enforcement ladder](https://github.com/mozilla/diversity).
[homepage]: https://www.contributor-covenant.org
For answers to common questions about this code of conduct, see the FAQ at
https://www.contributor-covenant.org/faq. Translations are available at
https://www.contributor-covenant.org/translations.

View File

@ -1,6 +1,6 @@
The MIT License (MIT)
Copyright (c) 2013-present Jinzhu <wosmvp@gmail.com>
Copyright (c) 2013-NOW Jinzhu <wosmvp@gmail.com>
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal

View File

@ -41,4 +41,4 @@ The fantastic ORM library for Golang, aims to be developer friendly.
© Jinzhu, 2013~time.Now
Released under the [MIT License](https://github.com/go-gorm/gorm/blob/master/LICENSE)
Released under the [MIT License](https://github.com/go-gorm/gorm/blob/master/License)

View File

@ -396,10 +396,6 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
}
}
case reflect.Struct:
if !rv.CanAddr() {
association.Error = ErrInvalidValue
return
}
association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, source, rv.Addr().Interface())
if association.Relationship.Field.FieldType.Kind() == reflect.Struct {
@ -437,10 +433,6 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
appendToFieldValues(reflect.Indirect(rv.Index(i)).Addr())
}
case reflect.Struct:
if !rv.CanAddr() {
association.Error = ErrInvalidValue
return
}
appendToFieldValues(rv.Addr())
}
@ -518,9 +510,6 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
for i := 0; i < reflectValue.Len(); i++ {
appendToRelations(reflectValue.Index(i), reflect.Indirect(reflect.ValueOf(values[i])), clear)
if association.Error != nil {
return
}
// TODO support save slice data, sql with case?
association.Error = associationDB.Updates(reflectValue.Index(i).Addr().Interface()).Error
@ -542,9 +531,6 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
for idx, value := range values {
rv := reflect.Indirect(reflect.ValueOf(value))
appendToRelations(reflectValue, rv, clear && idx == 0)
if association.Error != nil {
return
}
}
if len(values) > 0 {

View File

@ -187,18 +187,10 @@ func (p *processor) Replace(name string, fn func(*DB)) error {
func (p *processor) compile() (err error) {
var callbacks []*callback
removedMap := map[string]bool{}
for _, callback := range p.callbacks {
if callback.match == nil || callback.match(p.db) {
callbacks = append(callbacks, callback)
}
if callback.remove {
removedMap[callback.name] = true
}
}
if len(removedMap) > 0 {
callbacks = removeCallbacks(callbacks, removedMap)
}
p.callbacks = callbacks
@ -347,14 +339,3 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {
return
}
func removeCallbacks(cs []*callback, nameMap map[string]bool) []*callback {
callbacks := make([]*callback, 0, len(cs))
for _, callback := range cs {
if nameMap[callback.name] {
continue
}
callbacks = append(callbacks, callback)
}
return callbacks
}

View File

@ -47,7 +47,7 @@ func SaveBeforeAssociations(create bool) func(db *gorm.DB) {
)
if !isPtr {
fieldType = reflect.PointerTo(fieldType)
fieldType = reflect.PtrTo(fieldType)
}
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
@ -126,7 +126,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
)
if !isPtr {
fieldType = reflect.PointerTo(fieldType)
fieldType = reflect.PtrTo(fieldType)
}
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
@ -195,7 +195,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
fieldType := rel.Field.IndirectFieldType.Elem()
isPtr := fieldType.Kind() == reflect.Ptr
if !isPtr {
fieldType = reflect.PointerTo(fieldType)
fieldType = reflect.PtrTo(fieldType)
}
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
identityMap := map[string]bool{}
@ -268,11 +268,11 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
fieldType := rel.Field.IndirectFieldType.Elem()
isPtr := fieldType.Kind() == reflect.Ptr
if !isPtr {
fieldType = reflect.PointerTo(fieldType)
fieldType = reflect.PtrTo(fieldType)
}
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
distinctElems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
joins := reflect.MakeSlice(reflect.SliceOf(reflect.PointerTo(rel.JoinTable.ModelType)), 0, 10)
joins := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(rel.JoinTable.ModelType)), 0, 10)
objs := []reflect.Value{}
appendToJoins := func(obj reflect.Value, elem reflect.Value) {

View File

@ -53,16 +53,12 @@ func Create(config *Config) func(db *gorm.DB) {
if _, ok := db.Statement.Clauses["RETURNING"]; !ok {
fromColumns := make([]clause.Column, 0, len(db.Statement.Schema.FieldsWithDefaultDBValue))
for _, field := range db.Statement.Schema.FieldsWithDefaultDBValue {
if field.Readable {
fromColumns = append(fromColumns, clause.Column{Name: field.DBName})
}
}
if len(fromColumns) > 0 {
db.Statement.AddClause(clause.Returning{Columns: fromColumns})
}
}
}
}
if db.Statement.SQL.Len() == 0 {
db.Statement.SQL.Grow(180)
@ -93,10 +89,6 @@ func Create(config *Config) func(db *gorm.DB) {
db.AddError(rows.Close())
}()
gorm.Scan(rows, db, mode)
if db.Statement.Result != nil {
db.Statement.Result.RowsAffected = db.RowsAffected
}
}
return
@ -111,70 +103,13 @@ func Create(config *Config) func(db *gorm.DB) {
}
db.RowsAffected, _ = result.RowsAffected()
if db.Statement.Result != nil {
db.Statement.Result.Result = result
db.Statement.Result.RowsAffected = db.RowsAffected
}
if db.RowsAffected == 0 {
return
}
var (
pkField *schema.Field
pkFieldName = "@id"
)
if db.Statement.Schema != nil {
if db.Statement.Schema.PrioritizedPrimaryField == nil ||
!db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue ||
!db.Statement.Schema.PrioritizedPrimaryField.Readable {
return
}
pkField = db.Statement.Schema.PrioritizedPrimaryField
pkFieldName = db.Statement.Schema.PrioritizedPrimaryField.DBName
}
if db.RowsAffected != 0 && db.Statement.Schema != nil &&
db.Statement.Schema.PrioritizedPrimaryField != nil &&
db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue {
insertID, err := result.LastInsertId()
insertOk := err == nil && insertID > 0
if !insertOk {
if !supportReturning {
db.AddError(err)
}
return
}
// append @id column with value for auto-increment primary key
// the @id value is correct, when: 1. without setting auto-increment primary key, 2. database AutoIncrementIncrement = 1
switch values := db.Statement.Dest.(type) {
case map[string]interface{}:
values[pkFieldName] = insertID
case *map[string]interface{}:
(*values)[pkFieldName] = insertID
case []map[string]interface{}, *[]map[string]interface{}:
mapValues, ok := values.([]map[string]interface{})
if !ok {
if v, ok := values.(*[]map[string]interface{}); ok {
if *v != nil {
mapValues = *v
}
}
}
if config.LastInsertIDReversed {
insertID -= int64(len(mapValues)-1) * schema.DefaultAutoIncrementIncrement
}
for _, mapValue := range mapValues {
if mapValue != nil {
mapValue[pkFieldName] = insertID
}
insertID += schema.DefaultAutoIncrementIncrement
}
default:
if pkField == nil {
return
}
@ -187,10 +122,10 @@ func Create(config *Config) func(db *gorm.DB) {
break
}
_, isZero := pkField.ValueOf(db.Statement.Context, rv)
_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv)
if isZero {
db.AddError(pkField.Set(db.Statement.Context, rv, insertID))
insertID -= pkField.AutoIncrementIncrement
db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, insertID))
insertID -= db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement
}
}
} else {
@ -200,16 +135,16 @@ func Create(config *Config) func(db *gorm.DB) {
break
}
if _, isZero := pkField.ValueOf(db.Statement.Context, rv); isZero {
db.AddError(pkField.Set(db.Statement.Context, rv, insertID))
insertID += pkField.AutoIncrementIncrement
if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv); isZero {
db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, insertID))
insertID += db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement
}
}
}
case reflect.Struct:
_, isZero := pkField.ValueOf(db.Statement.Context, db.Statement.ReflectValue)
_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, db.Statement.ReflectValue)
if isZero {
db.AddError(pkField.Set(db.Statement.Context, db.Statement.ReflectValue, insertID))
db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, db.Statement.ReflectValue, insertID))
}
}
}
@ -318,18 +253,16 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
}
}
for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
if vs, ok := defaultValueFieldsHavingValue[field]; ok {
for field, vs := range defaultValueFieldsHavingValue {
values.Columns = append(values.Columns, clause.Column{Name: field.DBName})
for idx := range values.Values {
if vs[idx] == nil {
values.Values[idx] = append(values.Values[idx], stmt.DefaultValueOf(field))
values.Values[idx] = append(values.Values[idx], stmt.Dialector.DefaultValueOf(field))
} else {
values.Values[idx] = append(values.Values[idx], vs[idx])
}
}
}
}
case reflect.Struct:
values.Values = [][]interface{}{make([]interface{}, len(values.Columns))}
for idx, column := range values.Columns {
@ -349,7 +282,7 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
}
for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) && field.DefaultValueInterface == nil {
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
if rvOfvalue, isZero := field.ValueOf(stmt.Context, stmt.ReflectValue); !isZero {
values.Columns = append(values.Columns, clause.Column{Name: field.DBName})
values.Values[0] = append(values.Values[0], rvOfvalue)
@ -378,7 +311,7 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
case schema.UnixNanosecond:
assignment.Value = curTime.UnixNano()
case schema.UnixMillisecond:
assignment.Value = curTime.UnixMilli()
assignment.Value = curTime.UnixNano() / 1e6
case schema.UnixSecond:
assignment.Value = curTime.Unix()
}

View File

@ -1,71 +0,0 @@
package callbacks
import (
"reflect"
"sync"
"testing"
"time"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
)
var schemaCache = &sync.Map{}
func TestConvertToCreateValues_DestType_Slice(t *testing.T) {
type user struct {
ID int `gorm:"primaryKey"`
Name string
Email string `gorm:"default:(-)"`
Age int `gorm:"default:(-)"`
}
s, err := schema.Parse(&user{}, schemaCache, schema.NamingStrategy{})
if err != nil {
t.Errorf("parse schema error: %v, is not expected", err)
return
}
dest := []*user{
{
ID: 1,
Name: "alice",
Email: "email",
Age: 18,
},
{
ID: 2,
Name: "bob",
Email: "email",
Age: 19,
},
}
stmt := &gorm.Statement{
DB: &gorm.DB{
Config: &gorm.Config{
NowFunc: func() time.Time { return time.Time{} },
},
Statement: &gorm.Statement{
Settings: sync.Map{},
Schema: s,
},
},
ReflectValue: reflect.ValueOf(dest),
Dest: dest,
}
stmt.Schema = s
values := ConvertToCreateValues(stmt)
expected := clause.Values{
// column has value + defaultValue column has value (which should have a stable order)
Columns: []clause.Column{{Name: "name"}, {Name: "email"}, {Name: "age"}, {Name: "id"}},
Values: [][]interface{}{
{"alice", "email", 18, 1},
{"bob", "email", 19, 2},
},
}
if !reflect.DeepEqual(expected, values) {
t.Errorf("expected: %v got %v", expected, values)
}
}

View File

@ -157,14 +157,8 @@ func Delete(config *Config) func(db *gorm.DB) {
ok, mode := hasReturning(db, supportReturning)
if !ok {
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
if db.AddError(err) == nil {
db.RowsAffected, _ = result.RowsAffected()
if db.Statement.Result != nil {
db.Statement.Result.Result = result
db.Statement.Result.RowsAffected = db.RowsAffected
}
}
return
@ -172,10 +166,6 @@ func Delete(config *Config) func(db *gorm.DB) {
if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil {
gorm.Scan(rows, db, mode)
if db.Statement.Result != nil {
db.Statement.Result.RowsAffected = db.RowsAffected
}
db.AddError(rows.Close())
}
}

View File

@ -1,157 +0,0 @@
package callbacks
import (
"reflect"
"testing"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
func TestLoadOrStoreVisitMap(t *testing.T) {
var vm visitMap
var loaded bool
type testM struct {
Name string
}
t1 := testM{Name: "t1"}
t2 := testM{Name: "t2"}
t3 := testM{Name: "t3"}
vm = make(visitMap)
if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf(&t1)); loaded {
t.Fatalf("loaded should be false")
}
if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf(&t1)); !loaded {
t.Fatalf("loaded should be true")
}
// t1 already exist but t2 not
if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf([]*testM{&t1, &t2, &t3})); loaded {
t.Fatalf("loaded should be false")
}
if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf([]*testM{&t2, &t3})); !loaded {
t.Fatalf("loaded should be true")
}
}
func TestConvertMapToValuesForCreate(t *testing.T) {
testCase := []struct {
name string
input map[string]interface{}
expect clause.Values
}{
{
name: "Test convert string value",
input: map[string]interface{}{
"name": "my name",
},
expect: clause.Values{
Columns: []clause.Column{{Name: "name"}},
Values: [][]interface{}{{"my name"}},
},
},
{
name: "Test convert int value",
input: map[string]interface{}{
"age": 18,
},
expect: clause.Values{
Columns: []clause.Column{{Name: "age"}},
Values: [][]interface{}{{18}},
},
},
{
name: "Test convert float value",
input: map[string]interface{}{
"score": 99.5,
},
expect: clause.Values{
Columns: []clause.Column{{Name: "score"}},
Values: [][]interface{}{{99.5}},
},
},
{
name: "Test convert bool value",
input: map[string]interface{}{
"active": true,
},
expect: clause.Values{
Columns: []clause.Column{{Name: "active"}},
Values: [][]interface{}{{true}},
},
},
}
for _, tc := range testCase {
t.Run(tc.name, func(t *testing.T) {
actual := ConvertMapToValuesForCreate(&gorm.Statement{}, tc.input)
if !reflect.DeepEqual(actual, tc.expect) {
t.Errorf("expect %v got %v", tc.expect, actual)
}
})
}
}
func TestConvertSliceOfMapToValuesForCreate(t *testing.T) {
testCase := []struct {
name string
input []map[string]interface{}
expect clause.Values
}{
{
name: "Test convert slice of string value",
input: []map[string]interface{}{
{"name": "my name"},
},
expect: clause.Values{
Columns: []clause.Column{{Name: "name"}},
Values: [][]interface{}{{"my name"}},
},
},
{
name: "Test convert slice of int value",
input: []map[string]interface{}{
{"age": 18},
},
expect: clause.Values{
Columns: []clause.Column{{Name: "age"}},
Values: [][]interface{}{{18}},
},
},
{
name: "Test convert slice of float value",
input: []map[string]interface{}{
{"score": 99.5},
},
expect: clause.Values{
Columns: []clause.Column{{Name: "score"}},
Values: [][]interface{}{{99.5}},
},
},
{
name: "Test convert slice of bool value",
input: []map[string]interface{}{
{"active": true},
},
expect: clause.Values{
Columns: []clause.Column{{Name: "active"}},
Values: [][]interface{}{{true}},
},
},
}
for _, tc := range testCase {
t.Run(tc.name, func(t *testing.T) {
actual := ConvertSliceOfMapToValuesForCreate(&gorm.Statement{}, tc.input)
if !reflect.DeepEqual(actual, tc.expect) {
t.Errorf("expected %v but got %v", tc.expect, actual)
}
})
}
}

View File

@ -3,7 +3,6 @@ package callbacks
import (
"fmt"
"reflect"
"sort"
"strings"
"gorm.io/gorm"
@ -75,7 +74,7 @@ func embeddedValues(embeddedRelations *schema.Relationships) []string {
names := make([]string, 0, len(embeddedRelations.Relations)+len(embeddedRelations.EmbeddedRelations))
for _, relation := range embeddedRelations.Relations {
// skip first struct name
names = append(names, strings.Join(relation.Field.EmbeddedBindNames[1:], "."))
names = append(names, strings.Join(relation.Field.BindNames[1:], "."))
}
for _, relations := range embeddedRelations.EmbeddedRelations {
names = append(names, embeddedValues(relations)...)
@ -83,105 +82,27 @@ func embeddedValues(embeddedRelations *schema.Relationships) []string {
return names
}
// preloadEntryPoint enters layer by layer. It will call real preload if it finds the right entry point.
// If the current relationship is embedded or joined, current query will be ignored.
//
//nolint:cyclop
func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relationships, preloads map[string][]interface{}, associationsConds []interface{}) error {
preloadMap := parsePreloadMap(db.Statement.Schema, preloads)
// avoid random traversal of the map
preloadNames := make([]string, 0, len(preloadMap))
for key := range preloadMap {
preloadNames = append(preloadNames, key)
func preloadEmbedded(tx *gorm.DB, relationships *schema.Relationships, s *schema.Schema, preloads map[string][]interface{}, as []interface{}) error {
if relationships == nil {
return nil
}
sort.Strings(preloadNames)
isJoined := func(name string) (joined bool, nestedJoins []string) {
for _, join := range joins {
if _, ok := relationships.Relations[join]; ok && name == join {
joined = true
continue
}
join0, join1, cut := strings.Cut(join, ".")
if cut {
if _, ok := relationships.Relations[join0]; ok && name == join0 {
joined = true
nestedJoins = append(nestedJoins, join1)
}
}
}
return joined, nestedJoins
}
for _, name := range preloadNames {
if relations := relationships.EmbeddedRelations[name]; relations != nil {
if err := preloadEntryPoint(db, joins, relations, preloadMap[name], associationsConds); err != nil {
preloadMap := parsePreloadMap(s, preloads)
for name := range preloadMap {
if embeddedRelations := relationships.EmbeddedRelations[name]; embeddedRelations != nil {
if err := preloadEmbedded(tx, embeddedRelations, s, preloadMap[name], as); err != nil {
return err
}
} else if rel := relationships.Relations[name]; rel != nil {
if joined, nestedJoins := isJoined(name); joined {
switch rv := db.Statement.ReflectValue; rv.Kind() {
case reflect.Slice, reflect.Array:
if rv.Len() > 0 {
reflectValue := rel.FieldSchema.MakeSlice().Elem()
for i := 0; i < rv.Len(); i++ {
frv := rel.Field.ReflectValueOf(db.Statement.Context, rv.Index(i))
if frv.Kind() != reflect.Ptr {
reflectValue = reflect.Append(reflectValue, frv.Addr())
} else {
if frv.IsNil() {
continue
}
reflectValue = reflect.Append(reflectValue, frv)
}
}
tx := preloadDB(db, reflectValue, reflectValue.Interface())
if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil {
if err := preload(tx, rel, append(preloads[name], as), preloadMap[name]); err != nil {
return err
}
}
case reflect.Struct, reflect.Pointer:
reflectValue := rel.Field.ReflectValueOf(db.Statement.Context, rv)
tx := preloadDB(db, reflectValue, reflectValue.Interface())
if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil {
return err
}
default:
return gorm.ErrInvalidData
}
} else {
tx := db.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks})
tx.Statement.ReflectValue = db.Statement.ReflectValue
tx.Statement.Unscoped = db.Statement.Unscoped
if err := preload(tx, rel, append(preloads[name], associationsConds...), preloadMap[name]); err != nil {
return err
}
}
} else {
return fmt.Errorf("%s: %w for schema %s", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name)
return fmt.Errorf("%s: %w (embedded) for schema %s", name, gorm.ErrUnsupportedRelation, s.Name)
}
}
return nil
}
func preloadDB(db *gorm.DB, reflectValue reflect.Value, dest interface{}) *gorm.DB {
tx := db.Session(&gorm.Session{Context: db.Statement.Context, NewDB: true, SkipHooks: db.Statement.SkipHooks, Initialized: true})
db.Statement.Settings.Range(func(k, v interface{}) bool {
tx.Statement.Settings.Store(k, v)
return true
})
if err := tx.Statement.Parse(dest); err != nil {
tx.AddError(err)
return tx
}
tx.Statement.ReflectValue = reflectValue
tx.Statement.Unscoped = db.Statement.Unscoped
return tx
}
func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) error {
var (
reflectValue = tx.Statement.ReflectValue
@ -275,8 +196,6 @@ func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
column, values := schema.ToQueryValues(clause.CurrentTable, relForeignKeys, foreignValues)
if len(values) != 0 {
tx = tx.Model(reflectResults.Addr().Interface()).Where(clause.IN{Column: column, Values: values})
for _, cond := range conds {
if fc, ok := cond.(func(*gorm.DB) *gorm.DB); ok {
tx = fc(tx)
@ -285,11 +204,7 @@ func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
}
}
if len(inlineConds) > 0 {
tx = tx.Where(inlineConds[0], inlineConds[1:]...)
}
if err := tx.Find(reflectResults.Addr().Interface()).Error; err != nil {
if err := tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), inlineConds...).Error; err != nil {
return err
}
}

View File

@ -3,6 +3,7 @@ package callbacks
import (
"fmt"
"reflect"
"sort"
"strings"
"gorm.io/gorm"
@ -25,10 +26,6 @@ func Query(db *gorm.DB) {
db.AddError(rows.Close())
}()
gorm.Scan(rows, db, 0)
if db.Statement.Result != nil {
db.Statement.Result.RowsAffected = db.RowsAffected
}
}
}
}
@ -114,7 +111,7 @@ func BuildQuerySQL(db *gorm.DB) {
}
}
specifiedRelationsName := map[string]string{clause.CurrentTable: clause.CurrentTable}
specifiedRelationsName := make(map[string]interface{})
for _, join := range db.Statement.Joins {
if db.Statement.Schema != nil {
var isRelations bool // is relations or raw sql
@ -128,12 +125,12 @@ func BuildQuerySQL(db *gorm.DB) {
nestedJoinNames := strings.Split(join.Name, ".")
if len(nestedJoinNames) > 1 {
isNestedJoin := true
guessNestedRelations := make([]*schema.Relationship, 0, len(nestedJoinNames))
gussNestedRelations := make([]*schema.Relationship, 0, len(nestedJoinNames))
currentRelations := db.Statement.Schema.Relationships.Relations
for _, relname := range nestedJoinNames {
// incomplete match, only treated as raw sql
if relation, ok = currentRelations[relname]; ok {
guessNestedRelations = append(guessNestedRelations, relation)
gussNestedRelations = append(gussNestedRelations, relation)
currentRelations = relation.FieldSchema.Relationships.Relations
} else {
isNestedJoin = false
@ -143,13 +140,18 @@ func BuildQuerySQL(db *gorm.DB) {
if isNestedJoin {
isRelations = true
relations = guessNestedRelations
relations = gussNestedRelations
}
}
}
if isRelations {
genJoinClause := func(joinType clause.JoinType, tableAliasName string, parentTableName string, relation *schema.Relationship) clause.Join {
genJoinClause := func(joinType clause.JoinType, parentTableName string, relation *schema.Relationship) clause.Join {
tableAliasName := relation.Name
if parentTableName != clause.CurrentTable {
tableAliasName = utils.NestedRelationName(parentTableName, tableAliasName)
}
columnStmt := gorm.Statement{
Table: tableAliasName, DB: db, Schema: relation.FieldSchema,
Selects: join.Selects, Omits: join.Omits,
@ -166,13 +168,6 @@ func BuildQuerySQL(db *gorm.DB) {
}
}
if join.Expression != nil {
return clause.Join{
Type: join.JoinType,
Expression: join.Expression,
}
}
exprs := make([]clause.Expression, len(relation.References))
for idx, ref := range relation.References {
if ref.OwnPrimaryKey {
@ -232,24 +227,19 @@ func BuildQuerySQL(db *gorm.DB) {
}
parentTableName := clause.CurrentTable
for idx, rel := range relations {
for _, rel := range relations {
// joins table alias like "Manager, Company, Manager__Company"
curAliasName := rel.Name
nestedAlias := utils.NestedRelationName(parentTableName, rel.Name)
if _, ok := specifiedRelationsName[nestedAlias]; !ok {
fromClause.Joins = append(fromClause.Joins, genJoinClause(join.JoinType, parentTableName, rel))
specifiedRelationsName[nestedAlias] = nil
}
if parentTableName != clause.CurrentTable {
curAliasName = utils.NestedRelationName(parentTableName, curAliasName)
parentTableName = utils.NestedRelationName(parentTableName, rel.Name)
} else {
parentTableName = rel.Name
}
if _, ok := specifiedRelationsName[curAliasName]; !ok {
aliasName := curAliasName
if idx == len(relations)-1 && join.Alias != "" {
aliasName = join.Alias
}
fromClause.Joins = append(fromClause.Joins, genJoinClause(join.JoinType, aliasName, specifiedRelationsName[parentTableName], rel))
specifiedRelationsName[curAliasName] = aliasName
}
parentTableName = curAliasName
}
} else {
fromClause.Joins = append(fromClause.Joins, clause.Join{
@ -264,6 +254,7 @@ func BuildQuerySQL(db *gorm.DB) {
}
db.Statement.AddClause(fromClause)
db.Statement.Joins = nil
} else {
db.Statement.AddClauseIfNotExists(clause.From{})
}
@ -281,27 +272,38 @@ func Preload(db *gorm.DB) {
return
}
joins := make([]string, 0, len(db.Statement.Joins))
for _, join := range db.Statement.Joins {
joins = append(joins, join.Name)
preloadMap := parsePreloadMap(db.Statement.Schema, db.Statement.Preloads)
preloadNames := make([]string, 0, len(preloadMap))
for key := range preloadMap {
preloadNames = append(preloadNames, key)
}
sort.Strings(preloadNames)
tx := preloadDB(db, db.Statement.ReflectValue, db.Statement.Dest)
if tx.Error != nil {
preloadDB := db.Session(&gorm.Session{Context: db.Statement.Context, NewDB: true, SkipHooks: db.Statement.SkipHooks, Initialized: true})
db.Statement.Settings.Range(func(k, v interface{}) bool {
preloadDB.Statement.Settings.Store(k, v)
return true
})
if err := preloadDB.Statement.Parse(db.Statement.Dest); err != nil {
return
}
preloadDB.Statement.ReflectValue = db.Statement.ReflectValue
preloadDB.Statement.Unscoped = db.Statement.Unscoped
db.AddError(preloadEntryPoint(tx, joins, &tx.Statement.Schema.Relationships, db.Statement.Preloads, db.Statement.Preloads[clause.Associations]))
for _, name := range preloadNames {
if relations := preloadDB.Statement.Schema.Relationships.EmbeddedRelations[name]; relations != nil {
db.AddError(preloadEmbedded(preloadDB.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks}), relations, db.Statement.Schema, preloadMap[name], db.Statement.Preloads[clause.Associations]))
} else if rel := preloadDB.Statement.Schema.Relationships.Relations[name]; rel != nil {
db.AddError(preload(preloadDB.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks}), rel, append(db.Statement.Preloads[name], db.Statement.Preloads[clause.Associations]...), preloadMap[name]))
} else {
db.AddError(fmt.Errorf("%s: %w for schema %s", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name))
}
}
}
}
func AfterQuery(db *gorm.DB) {
// clear the joins after query because preload need it
if v, ok := db.Statement.Clauses["FROM"].Expression.(clause.From); ok {
fromClause := db.Statement.Clauses["FROM"]
fromClause.Expression = clause.From{Tables: v.Tables, Joins: utils.RTrimSlice(v.Joins, len(db.Statement.Joins))} // keep the original From Joins
db.Statement.Clauses["FROM"] = fromClause
}
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.AfterFind && db.RowsAffected > 0 {
callMethod(db, func(value interface{}, tx *gorm.DB) bool {
if i, ok := value.(AfterFindInterface); ok {

View File

@ -13,10 +13,5 @@ func RawExec(db *gorm.DB) {
}
db.RowsAffected, _ = result.RowsAffected()
if db.Statement.Result != nil {
db.Statement.Result.Result = result
db.Statement.Result.RowsAffected = db.RowsAffected
}
}
}

View File

@ -92,10 +92,6 @@ func Update(config *Config) func(db *gorm.DB) {
gorm.Scan(rows, db, mode)
db.Statement.Dest = dest
db.AddError(rows.Close())
if db.Statement.Result != nil {
db.Statement.Result.RowsAffected = db.RowsAffected
}
}
} else {
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
@ -103,11 +99,6 @@ func Update(config *Config) func(db *gorm.DB) {
if db.AddError(err) == nil {
db.RowsAffected, _ = result.RowsAffected()
}
if db.Statement.Result != nil {
db.Statement.Result.Result = result
db.Statement.Result.RowsAffected = db.RowsAffected
}
}
}
}
@ -243,7 +234,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
if field.AutoUpdateTime == schema.UnixNanosecond {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano()})
} else if field.AutoUpdateTime == schema.UnixMillisecond {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixMilli()})
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano() / 1e6})
} else if field.AutoUpdateTime == schema.UnixSecond {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.Unix()})
} else {
@ -277,7 +268,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
if field.AutoUpdateTime == schema.UnixNanosecond {
value = stmt.DB.NowFunc().UnixNano()
} else if field.AutoUpdateTime == schema.UnixMillisecond {
value = stmt.DB.NowFunc().UnixMilli()
value = stmt.DB.NowFunc().UnixNano() / 1e6
} else if field.AutoUpdateTime == schema.UnixSecond {
value = stmt.DB.NowFunc().Unix()
} else {

View File

@ -0,0 +1,36 @@
package callbacks
import (
"reflect"
"testing"
)
func TestLoadOrStoreVisitMap(t *testing.T) {
var vm visitMap
var loaded bool
type testM struct {
Name string
}
t1 := testM{Name: "t1"}
t2 := testM{Name: "t2"}
t3 := testM{Name: "t3"}
vm = make(visitMap)
if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf(&t1)); loaded {
t.Fatalf("loaded should be false")
}
if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf(&t1)); !loaded {
t.Fatalf("loaded should be true")
}
// t1 already exist but t2 not
if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf([]*testM{&t1, &t2, &t3})); loaded {
t.Fatalf("loaded should be false")
}
if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf([]*testM{&t2, &t3})); !loaded {
t.Fatalf("loaded should be true")
}
}

View File

@ -185,13 +185,6 @@ func (db *DB) Omit(columns ...string) (tx *DB) {
return
}
// MapColumns modify the column names in the query results to facilitate align to the corresponding structural fields
func (db *DB) MapColumns(m map[string]string) (tx *DB) {
tx = db.getInstance()
tx.Statement.ColumnMapping = m
return
}
// Where add conditions
//
// See the [docs] for details on the various formats that where clauses can take. By default, where clauses chain with AND.
@ -306,16 +299,10 @@ func (db *DB) Having(query interface{}, args ...interface{}) (tx *DB) {
//
// db.Order("name DESC")
// db.Order(clause.OrderByColumn{Column: clause.Column{Name: "name"}, Desc: true})
// db.Order(clause.OrderBy{Columns: []clause.OrderByColumn{
// {Column: clause.Column{Name: "name"}, Desc: true},
// {Column: clause.Column{Name: "age"}, Desc: true},
// }})
func (db *DB) Order(value interface{}) (tx *DB) {
tx = db.getInstance()
switch v := value.(type) {
case clause.OrderBy:
tx.Statement.AddClause(v)
case clause.OrderByColumn:
tx.Statement.AddClause(clause.OrderBy{
Columns: []clause.OrderByColumn{v},
@ -380,12 +367,33 @@ func (db *DB) Scopes(funcs ...func(*DB) *DB) (tx *DB) {
}
func (db *DB) executeScopes() (tx *DB) {
tx = db.getInstance()
scopes := db.Statement.scopes
db.Statement.scopes = nil
for _, scope := range scopes {
db = scope(db)
if len(scopes) == 0 {
return tx
}
return db
tx.Statement.scopes = nil
conditions := make([]clause.Interface, 0, 4)
if cs, ok := tx.Statement.Clauses["WHERE"]; ok && cs.Expression != nil {
conditions = append(conditions, cs.Expression.(clause.Interface))
cs.Expression = nil
tx.Statement.Clauses["WHERE"] = cs
}
for _, scope := range scopes {
tx = scope(tx)
if cs, ok := tx.Statement.Clauses["WHERE"]; ok && cs.Expression != nil {
conditions = append(conditions, cs.Expression.(clause.Interface))
cs.Expression = nil
tx.Statement.Clauses["WHERE"] = cs
}
}
for _, condition := range conditions {
tx.Statement.AddClause(condition)
}
return tx
}
// Preload preload associations with given conditions
@ -442,16 +450,6 @@ func (db *DB) Assign(attrs ...interface{}) (tx *DB) {
return
}
// Unscoped disables the global scope of soft deletion in a query.
// By default, GORM uses soft deletion, marking records as "deleted"
// by setting a timestamp on a specific field (e.g., `deleted_at`).
// Unscoped allows queries to include records marked as deleted,
// overriding the soft deletion behavior.
// Example:
//
// var users []User
// db.Unscoped().Find(&users)
// // Retrieves all users, including deleted ones.
func (db *DB) Unscoped() (tx *DB) {
tx = db.getInstance()
tx.Statement.Unscoped = true

View File

@ -126,7 +126,7 @@ func (expr NamedExpr) Build(builder Builder) {
for _, v := range []byte(expr.SQL) {
if v == '@' && !inName {
inName = true
name = name[:0]
name = []byte{}
} else if v == ' ' || v == ',' || v == ')' || v == '"' || v == '\'' || v == '`' || v == '\r' || v == '\n' || v == ';' {
if inName {
if nv, ok := namedMap[string(name)]; ok {

View File

@ -1,7 +1,5 @@
package clause
import "gorm.io/gorm/utils"
type JoinType string
const (
@ -11,30 +9,6 @@ const (
RightJoin JoinType = "RIGHT"
)
type JoinTarget struct {
Type JoinType
Association string
Subquery Expression
Table string
}
func Has(name string) JoinTarget {
return JoinTarget{Type: InnerJoin, Association: name}
}
func (jt JoinType) Association(name string) JoinTarget {
return JoinTarget{Type: jt, Association: name}
}
func (jt JoinType) AssociationFrom(name string, subquery Expression) JoinTarget {
return JoinTarget{Type: jt, Association: name, Subquery: subquery}
}
func (jt JoinTarget) As(name string) JoinTarget {
jt.Table = name
return jt
}
// Join clause for from
type Join struct {
Type JoinType
@ -44,12 +18,6 @@ type Join struct {
Expression Expression
}
func JoinTable(names ...string) Table {
return Table{
Name: utils.JoinNestedRelationNames(names),
}
}
func (join Join) Build(builder Builder) {
if join.Expression != nil {
join.Expression.Build(builder)

View File

@ -1,5 +1,7 @@
package clause
import "strconv"
// Limit limit clause
type Limit struct {
Limit *int
@ -15,14 +17,14 @@ func (limit Limit) Name() string {
func (limit Limit) Build(builder Builder) {
if limit.Limit != nil && *limit.Limit >= 0 {
builder.WriteString("LIMIT ")
builder.AddVar(builder, *limit.Limit)
builder.WriteString(strconv.Itoa(*limit.Limit))
}
if limit.Offset > 0 {
if limit.Limit != nil && *limit.Limit >= 0 {
builder.WriteByte(' ')
}
builder.WriteString("OFFSET ")
builder.AddVar(builder, limit.Offset)
builder.WriteString(strconv.Itoa(limit.Offset))
}
}

View File

@ -22,53 +22,43 @@ func TestLimit(t *testing.T) {
Limit: &limit10,
Offset: 20,
}},
"SELECT * FROM `users` LIMIT ? OFFSET ?",
[]interface{}{limit10, 20},
"SELECT * FROM `users` LIMIT 10 OFFSET 20", nil,
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit0}},
"SELECT * FROM `users` LIMIT ?",
[]interface{}{limit0},
"SELECT * FROM `users` LIMIT 0", nil,
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit0}, clause.Limit{Offset: 0}},
"SELECT * FROM `users` LIMIT ?",
[]interface{}{limit0},
"SELECT * FROM `users` LIMIT 0", nil,
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}},
"SELECT * FROM `users` OFFSET ?",
[]interface{}{20},
"SELECT * FROM `users` OFFSET 20", nil,
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}, clause.Limit{Offset: 30}},
"SELECT * FROM `users` OFFSET ?",
[]interface{}{30},
"SELECT * FROM `users` OFFSET 30", nil,
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}, clause.Limit{Limit: &limit10}},
"SELECT * FROM `users` LIMIT ? OFFSET ?",
[]interface{}{limit10, 20},
"SELECT * FROM `users` LIMIT 10 OFFSET 20", nil,
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}},
"SELECT * FROM `users` LIMIT ? OFFSET ?",
[]interface{}{limit10, 30},
"SELECT * FROM `users` LIMIT 10 OFFSET 30", nil,
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Offset: -10}},
"SELECT * FROM `users` LIMIT ?",
[]interface{}{limit10},
"SELECT * FROM `users` LIMIT 10", nil,
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: &limitNeg10}},
"SELECT * FROM `users` OFFSET ?",
[]interface{}{30},
"SELECT * FROM `users` OFFSET 30", nil,
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: &limit50}},
"SELECT * FROM `users` LIMIT ? OFFSET ?",
[]interface{}{limit50, 30},
"SELECT * FROM `users` LIMIT 50 OFFSET 30", nil,
},
}

View File

@ -1,12 +1,5 @@
package clause
const (
LockingStrengthUpdate = "UPDATE"
LockingStrengthShare = "SHARE"
LockingOptionsSkipLocked = "SKIP LOCKED"
LockingOptionsNoWait = "NOWAIT"
)
type Locking struct {
Strength string
Table Table

View File

@ -14,21 +14,17 @@ func TestLocking(t *testing.T) {
Vars []interface{}
}{
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: clause.LockingStrengthUpdate}},
[]clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: "UPDATE"}},
"SELECT * FROM `users` FOR UPDATE", nil,
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: clause.LockingStrengthShare, Table: clause.Table{Name: clause.CurrentTable}}},
[]clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: "SHARE", Table: clause.Table{Name: clause.CurrentTable}}},
"SELECT * FROM `users` FOR SHARE OF `users`", nil,
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: clause.LockingStrengthUpdate, Options: clause.LockingOptionsNoWait}},
[]clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: "UPDATE"}, clause.Locking{Strength: "UPDATE", Options: "NOWAIT"}},
"SELECT * FROM `users` FOR UPDATE NOWAIT", nil,
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: clause.LockingStrengthUpdate, Options: clause.LockingOptionsSkipLocked}},
"SELECT * FROM `users` FOR UPDATE SKIP LOCKED", nil,
},
}
for idx, result := range results {

View File

@ -26,12 +26,9 @@ func (returning Returning) Build(builder Builder) {
// MergeClause merge order by clauses
func (returning Returning) MergeClause(clause *Clause) {
if v, ok := clause.Expression.(Returning); ok && len(returning.Columns) > 0 {
if v.Columns != nil {
if v, ok := clause.Expression.(Returning); ok {
returning.Columns = append(v.Columns, returning.Columns...)
} else {
returning.Columns = nil
}
}
clause.Expression = returning
}

View File

@ -26,22 +26,6 @@ func TestReturning(t *testing.T) {
}},
"SELECT * FROM `users` RETURNING `users`.`id`,`name`,`age`", nil,
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Returning{
[]clause.Column{clause.PrimaryColumn},
}, clause.Returning{}, clause.Returning{
[]clause.Column{{Name: "name"}, {Name: "age"}},
}},
"SELECT * FROM `users` RETURNING *", nil,
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Returning{
[]clause.Column{clause.PrimaryColumn},
}, clause.Returning{
[]clause.Column{{Name: "name"}, {Name: "age"}},
}, clause.Returning{}},
"SELECT * FROM `users` RETURNING *", nil,
},
}
for idx, result := range results {

View File

@ -21,12 +21,6 @@ func (where Where) Name() string {
// Build build where clause
func (where Where) Build(builder Builder) {
if len(where.Exprs) == 1 {
if andCondition, ok := where.Exprs[0].(AndConditions); ok {
where.Exprs = andCondition.Exprs
}
}
// Switch position if the first query expression is a single Or condition
for idx, expr := range where.Exprs {
if v, ok := expr.(OrConditions); !ok || len(v.Exprs) > 1 {
@ -153,11 +147,6 @@ func Not(exprs ...Expression) Expression {
if len(exprs) == 0 {
return nil
}
if len(exprs) == 1 {
if andCondition, ok := exprs[0].(AndConditions); ok {
exprs = andCondition.Exprs
}
}
return NotConditions{Exprs: exprs}
}
@ -166,15 +155,6 @@ type NotConditions struct {
}
func (not NotConditions) Build(builder Builder) {
anyNegationBuilder := false
for _, c := range not.Exprs {
if _, ok := c.(NegationExpressionBuilder); ok {
anyNegationBuilder = true
break
}
}
if anyNegationBuilder {
if len(not.Exprs) > 1 {
builder.WriteByte('(')
}
@ -207,39 +187,4 @@ func (not NotConditions) Build(builder Builder) {
if len(not.Exprs) > 1 {
builder.WriteByte(')')
}
} else {
builder.WriteString("NOT ")
if len(not.Exprs) > 1 {
builder.WriteByte('(')
}
for idx, c := range not.Exprs {
if idx > 0 {
switch c.(type) {
case OrConditions:
builder.WriteString(OrWithSpace)
default:
builder.WriteString(AndWithSpace)
}
}
e, wrapInParentheses := c.(Expr)
if wrapInParentheses {
sql := strings.ToUpper(e.SQL)
if wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace); wrapInParentheses {
builder.WriteByte('(')
}
}
c.Build(builder)
if wrapInParentheses {
builder.WriteByte(')')
}
}
if len(not.Exprs) > 1 {
builder.WriteByte(')')
}
}
}

View File

@ -63,7 +63,7 @@ func TestWhere(t *testing.T) {
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
Exprs: []clause.Expression{clause.And(clause.Eq{Column: "age", Value: 18}, clause.Or(clause.Neq{Column: "name", Value: "jinzhu"}))},
}},
"SELECT * FROM `users` WHERE `age` = ? OR `name` <> ?",
"SELECT * FROM `users` WHERE (`age` = ? OR `name` <> ?)",
[]interface{}{18, "jinzhu"},
},
{
@ -94,7 +94,7 @@ func TestWhere(t *testing.T) {
clause.And(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false})),
},
}},
"SELECT * FROM `users` WHERE `users`.`id` <> ? AND `score` <= ?",
"SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `score` <= ?)",
[]interface{}{"1", 100},
},
{
@ -105,30 +105,6 @@ func TestWhere(t *testing.T) {
"SELECT * FROM `users` WHERE (`users`.`id` <> ? AND NOT `score` <= ?)",
[]interface{}{"1", 100},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
Exprs: []clause.Expression{clause.Not(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}},
clause.Expr{SQL: "`age` <= ?", Vars: []interface{}{60}})},
}},
"SELECT * FROM `users` WHERE NOT (`score` <= ? AND `age` <= ?)",
[]interface{}{100, 60},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
Exprs: []clause.Expression{
clause.Not(clause.AndConditions{
Exprs: []clause.Expression{
clause.Eq{Column: clause.PrimaryColumn, Value: "1"},
clause.Gt{Column: "age", Value: 18},
}}, clause.OrConditions{
Exprs: []clause.Expression{
clause.Lt{Column: "score", Value: 100},
},
}),
}}},
"SELECT * FROM `users` WHERE NOT ((`users`.`id` = ? AND `age` > ?) OR `score` < ?)",
[]interface{}{"1", 18, 100},
},
}
for idx, result := range results {

View File

@ -49,6 +49,4 @@ var (
ErrDuplicatedKey = errors.New("duplicated key not allowed")
// ErrForeignKeyViolated occurs when there is a foreign key constraint violation
ErrForeignKeyViolated = errors.New("violates foreign key constraint")
// ErrCheckConstraintViolated occurs when there is a check constraint violation
ErrCheckConstraintViolated = errors.New("violates check constraint")
)

View File

@ -1,11 +1,9 @@
package gorm
import (
"context"
"database/sql"
"errors"
"fmt"
"hash/maphash"
"reflect"
"strings"
@ -378,12 +376,8 @@ func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) {
} else if len(db.Statement.assigns) > 0 {
exprs := tx.Statement.BuildCondition(db.Statement.assigns[0], db.Statement.assigns[1:]...)
assigns := map[string]interface{}{}
for i := 0; i < len(exprs); i++ {
expr := exprs[i]
if eq, ok := expr.(clause.AndConditions); ok {
exprs = append(exprs, eq.Exprs...)
} else if eq, ok := expr.(clause.Eq); ok {
for _, expr := range exprs {
if eq, ok := expr.(clause.Eq); ok {
switch column := eq.Column.(type) {
case string:
assigns[column] = eq.Value
@ -625,15 +619,14 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
// nested transaction
if !db.DisableNestedTransaction {
spID := new(maphash.Hash).Sum64()
err = db.SavePoint(fmt.Sprintf("sp%d", spID)).Error
err = db.SavePoint(fmt.Sprintf("sp%p", fc)).Error
if err != nil {
return
}
defer func() {
// Make sure to rollback when panic, Block error or Commit error
if panicked || err != nil {
db.RollbackTo(fmt.Sprintf("sp%d", spID))
db.RollbackTo(fmt.Sprintf("sp%p", fc))
}
}()
}
@ -674,18 +667,11 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB {
opt = opts[0]
}
ctx := tx.Statement.Context
if _, ok := ctx.Deadline(); !ok {
if db.Config.DefaultTransactionTimeout > 0 {
ctx, _ = context.WithTimeout(ctx, db.Config.DefaultTransactionTimeout)
}
}
switch beginner := tx.Statement.ConnPool.(type) {
case TxBeginner:
tx.Statement.ConnPool, err = beginner.BeginTx(ctx, opt)
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
case ConnPoolBeginner:
tx.Statement.ConnPool, err = beginner.BeginTx(ctx, opt)
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
default:
err = ErrInvalidTransaction
}

View File

@ -1,605 +0,0 @@
package gorm
import (
"context"
"database/sql"
"fmt"
"sort"
"strings"
"gorm.io/gorm/clause"
"gorm.io/gorm/logger"
)
type result struct {
Result sql.Result
RowsAffected int64
}
func (info *result) ModifyStatement(stmt *Statement) {
stmt.Result = info
}
// Build implements clause.Expression interface
func (result) Build(clause.Builder) {
}
func WithResult() *result {
return &result{}
}
type Interface[T any] interface {
Raw(sql string, values ...interface{}) ExecInterface[T]
Exec(ctx context.Context, sql string, values ...interface{}) error
CreateInterface[T]
}
type CreateInterface[T any] interface {
ChainInterface[T]
Table(name string, args ...interface{}) CreateInterface[T]
Create(ctx context.Context, r *T) error
CreateInBatches(ctx context.Context, r *[]T, batchSize int) error
}
type ChainInterface[T any] interface {
ExecInterface[T]
Scopes(scopes ...func(db *Statement)) ChainInterface[T]
Where(query interface{}, args ...interface{}) ChainInterface[T]
Not(query interface{}, args ...interface{}) ChainInterface[T]
Or(query interface{}, args ...interface{}) ChainInterface[T]
Limit(offset int) ChainInterface[T]
Offset(offset int) ChainInterface[T]
Joins(query clause.JoinTarget, on func(db JoinBuilder, joinTable clause.Table, curTable clause.Table) error) ChainInterface[T]
Preload(association string, query func(db PreloadBuilder) error) ChainInterface[T]
Select(query string, args ...interface{}) ChainInterface[T]
Omit(columns ...string) ChainInterface[T]
MapColumns(m map[string]string) ChainInterface[T]
Distinct(args ...interface{}) ChainInterface[T]
Group(name string) ChainInterface[T]
Having(query interface{}, args ...interface{}) ChainInterface[T]
Order(value interface{}) ChainInterface[T]
Build(builder clause.Builder)
Delete(ctx context.Context) (rowsAffected int, err error)
Update(ctx context.Context, name string, value any) (rowsAffected int, err error)
Updates(ctx context.Context, t T) (rowsAffected int, err error)
Count(ctx context.Context, column string) (result int64, err error)
}
type ExecInterface[T any] interface {
Scan(ctx context.Context, r interface{}) error
First(context.Context) (T, error)
Last(ctx context.Context) (T, error)
Take(context.Context) (T, error)
Find(ctx context.Context) ([]T, error)
FindInBatches(ctx context.Context, batchSize int, fc func(data []T, batch int) error) error
Row(ctx context.Context) *sql.Row
Rows(ctx context.Context) (*sql.Rows, error)
}
type JoinBuilder interface {
Select(...string) JoinBuilder
Omit(...string) JoinBuilder
Where(query interface{}, args ...interface{}) JoinBuilder
Not(query interface{}, args ...interface{}) JoinBuilder
Or(query interface{}, args ...interface{}) JoinBuilder
}
type PreloadBuilder interface {
Select(...string) PreloadBuilder
Omit(...string) PreloadBuilder
Where(query interface{}, args ...interface{}) PreloadBuilder
Not(query interface{}, args ...interface{}) PreloadBuilder
Or(query interface{}, args ...interface{}) PreloadBuilder
Limit(offset int) PreloadBuilder
Offset(offset int) PreloadBuilder
Order(value interface{}) PreloadBuilder
LimitPerRecord(num int) PreloadBuilder
}
type op func(*DB) *DB
func G[T any](db *DB, opts ...clause.Expression) Interface[T] {
v := &g[T]{
db: db,
ops: make([]op, 0, 5),
}
if len(opts) > 0 {
v.ops = append(v.ops, func(db *DB) *DB {
return db.Clauses(opts...)
})
}
v.createG = &createG[T]{
chainG: chainG[T]{
execG: execG[T]{g: v},
},
}
return v
}
type g[T any] struct {
*createG[T]
db *DB
ops []op
}
func (g *g[T]) apply(ctx context.Context) *DB {
db := g.db
if !db.DryRun {
db = db.Session(&Session{NewDB: true, Context: ctx}).getInstance()
}
for _, op := range g.ops {
db = op(db)
}
return db
}
func (c *g[T]) Raw(sql string, values ...interface{}) ExecInterface[T] {
return execG[T]{g: &g[T]{
db: c.db,
ops: append(c.ops, func(db *DB) *DB {
return db.Raw(sql, values...)
}),
}}
}
func (c *g[T]) Exec(ctx context.Context, sql string, values ...interface{}) error {
return c.apply(ctx).Exec(sql, values...).Error
}
type createG[T any] struct {
chainG[T]
}
func (c createG[T]) Table(name string, args ...interface{}) CreateInterface[T] {
return createG[T]{c.with(func(db *DB) *DB {
return db.Table(name, args...)
})}
}
func (c createG[T]) Create(ctx context.Context, r *T) error {
return c.g.apply(ctx).Create(r).Error
}
func (c createG[T]) CreateInBatches(ctx context.Context, r *[]T, batchSize int) error {
return c.g.apply(ctx).CreateInBatches(r, batchSize).Error
}
type chainG[T any] struct {
execG[T]
}
func (c chainG[T]) getInstance() *DB {
var r T
return c.g.apply(context.Background()).Model(r).getInstance()
}
func (c chainG[T]) with(v op) chainG[T] {
return chainG[T]{
execG: execG[T]{g: &g[T]{
db: c.g.db,
ops: append(append([]op(nil), c.g.ops...), v),
}},
}
}
func (c chainG[T]) Scopes(scopes ...func(db *Statement)) ChainInterface[T] {
return c.with(func(db *DB) *DB {
for _, fc := range scopes {
fc(db.Statement)
}
return db
})
}
func (c chainG[T]) Table(name string, args ...interface{}) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Table(name, args...)
})
}
func (c chainG[T]) Where(query interface{}, args ...interface{}) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Where(query, args...)
})
}
func (c chainG[T]) Not(query interface{}, args ...interface{}) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Not(query, args...)
})
}
func (c chainG[T]) Or(query interface{}, args ...interface{}) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Or(query, args...)
})
}
func (c chainG[T]) Limit(offset int) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Limit(offset)
})
}
func (c chainG[T]) Offset(offset int) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Offset(offset)
})
}
type joinBuilder struct {
db *DB
}
func (q *joinBuilder) Where(query interface{}, args ...interface{}) JoinBuilder {
q.db.Where(query, args...)
return q
}
func (q *joinBuilder) Or(query interface{}, args ...interface{}) JoinBuilder {
q.db.Where(query, args...)
return q
}
func (q *joinBuilder) Not(query interface{}, args ...interface{}) JoinBuilder {
q.db.Where(query, args...)
return q
}
func (q *joinBuilder) Select(columns ...string) JoinBuilder {
q.db.Select(columns)
return q
}
func (q *joinBuilder) Omit(columns ...string) JoinBuilder {
q.db.Omit(columns...)
return q
}
type preloadBuilder struct {
limitPerRecord int
db *DB
}
func (q *preloadBuilder) Where(query interface{}, args ...interface{}) PreloadBuilder {
q.db.Where(query, args...)
return q
}
func (q *preloadBuilder) Or(query interface{}, args ...interface{}) PreloadBuilder {
q.db.Where(query, args...)
return q
}
func (q *preloadBuilder) Not(query interface{}, args ...interface{}) PreloadBuilder {
q.db.Where(query, args...)
return q
}
func (q *preloadBuilder) Select(columns ...string) PreloadBuilder {
q.db.Select(columns)
return q
}
func (q *preloadBuilder) Omit(columns ...string) PreloadBuilder {
q.db.Omit(columns...)
return q
}
func (q *preloadBuilder) Limit(limit int) PreloadBuilder {
q.db.Limit(limit)
return q
}
func (q *preloadBuilder) Offset(offset int) PreloadBuilder {
q.db.Offset(offset)
return q
}
func (q *preloadBuilder) Order(value interface{}) PreloadBuilder {
q.db.Order(value)
return q
}
func (q *preloadBuilder) LimitPerRecord(num int) PreloadBuilder {
q.limitPerRecord = num
return q
}
func (c chainG[T]) Joins(jt clause.JoinTarget, on func(db JoinBuilder, joinTable clause.Table, curTable clause.Table) error) ChainInterface[T] {
return c.with(func(db *DB) *DB {
if jt.Table == "" {
jt.Table = clause.JoinTable(strings.Split(jt.Association, ".")...).Name
}
q := joinBuilder{db: db.Session(&Session{NewDB: true, Initialized: true}).Table(jt.Table)}
if on != nil {
if err := on(&q, clause.Table{Name: jt.Table}, clause.Table{Name: clause.CurrentTable}); err != nil {
db.AddError(err)
}
}
j := join{
Name: jt.Association,
Alias: jt.Table,
Selects: q.db.Statement.Selects,
Omits: q.db.Statement.Omits,
JoinType: jt.Type,
}
if where, ok := q.db.Statement.Clauses["WHERE"].Expression.(clause.Where); ok {
j.On = &where
}
if jt.Subquery != nil {
joinType := j.JoinType
if joinType == "" {
joinType = clause.LeftJoin
}
if db, ok := jt.Subquery.(interface{ getInstance() *DB }); ok {
stmt := db.getInstance().Statement
if len(j.Selects) == 0 {
j.Selects = stmt.Selects
}
if len(j.Omits) == 0 {
j.Omits = stmt.Omits
}
}
expr := clause.NamedExpr{SQL: fmt.Sprintf("%s JOIN (?) AS ?", joinType), Vars: []interface{}{jt.Subquery, clause.Table{Name: j.Alias}}}
if j.On != nil {
expr.SQL += " ON ?"
expr.Vars = append(expr.Vars, clause.AndConditions{Exprs: j.On.Exprs})
}
j.Expression = expr
}
db.Statement.Joins = append(db.Statement.Joins, j)
sort.Slice(db.Statement.Joins, func(i, j int) bool {
return db.Statement.Joins[i].Name < db.Statement.Joins[j].Name
})
return db
})
}
func (c chainG[T]) Select(query string, args ...interface{}) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Select(query, args...)
})
}
func (c chainG[T]) Omit(columns ...string) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Omit(columns...)
})
}
func (c chainG[T]) MapColumns(m map[string]string) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.MapColumns(m)
})
}
func (c chainG[T]) Distinct(args ...interface{}) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Distinct(args...)
})
}
func (c chainG[T]) Group(name string) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Group(name)
})
}
func (c chainG[T]) Having(query interface{}, args ...interface{}) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Having(query, args...)
})
}
func (c chainG[T]) Order(value interface{}) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Order(value)
})
}
func (c chainG[T]) Preload(association string, query func(db PreloadBuilder) error) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Preload(association, func(tx *DB) *DB {
q := preloadBuilder{db: tx.getInstance()}
if query != nil {
if err := query(&q); err != nil {
db.AddError(err)
}
}
relation, ok := db.Statement.Schema.Relationships.Relations[association]
if !ok {
if preloadFields := strings.Split(association, "."); len(preloadFields) > 1 {
relationships := db.Statement.Schema.Relationships
for _, field := range preloadFields {
var ok bool
relation, ok = relationships.Relations[field]
if ok {
relationships = relation.FieldSchema.Relationships
} else {
db.AddError(fmt.Errorf("relation %s not found", association))
return nil
}
}
} else {
db.AddError(fmt.Errorf("relation %s not found", association))
return nil
}
}
if q.limitPerRecord > 0 {
if relation.JoinTable != nil {
tx.AddError(fmt.Errorf("many2many relation %s don't support LimitPerRecord", association))
return tx
}
refColumns := []clause.Column{}
for _, rel := range relation.References {
if rel.OwnPrimaryKey {
refColumns = append(refColumns, clause.Column{Name: rel.ForeignKey.DBName})
}
}
if len(refColumns) != 0 {
selectExpr := clause.CommaExpression{}
for _, column := range q.db.Statement.Selects {
selectExpr.Exprs = append(selectExpr.Exprs, clause.Expr{SQL: "?", Vars: []interface{}{clause.Column{Name: column}}})
}
if len(selectExpr.Exprs) == 0 {
selectExpr.Exprs = []clause.Expression{clause.Expr{SQL: "*", Vars: []interface{}{}}}
}
partitionBy := clause.CommaExpression{}
for _, column := range refColumns {
partitionBy.Exprs = append(partitionBy.Exprs, clause.Expr{SQL: "?", Vars: []interface{}{clause.Column{Name: column.Name}}})
}
rnnColumn := clause.Column{Name: "gorm_preload_rnn"}
sql := "ROW_NUMBER() OVER (PARTITION BY ? ?)"
vars := []interface{}{partitionBy}
if orderBy, ok := q.db.Statement.Clauses["ORDER BY"]; ok {
vars = append(vars, orderBy)
} else {
vars = append(vars, clause.Clause{Name: "ORDER BY", Expression: clause.OrderBy{
Columns: []clause.OrderByColumn{{Column: clause.PrimaryColumn, Desc: true}},
}})
}
vars = append(vars, rnnColumn)
selectExpr.Exprs = append(selectExpr.Exprs, clause.Expr{SQL: sql + " AS ?", Vars: vars})
q.db.Clauses(clause.Select{Expression: selectExpr})
return q.db.Session(&Session{NewDB: true}).Unscoped().Table("(?) t", q.db).Where("? <= ?", rnnColumn, q.limitPerRecord)
}
}
return q.db
})
})
}
func (c chainG[T]) Delete(ctx context.Context) (rowsAffected int, err error) {
r := new(T)
res := c.g.apply(ctx).Delete(r)
return int(res.RowsAffected), res.Error
}
func (c chainG[T]) Update(ctx context.Context, name string, value any) (rowsAffected int, err error) {
var r T
res := c.g.apply(ctx).Model(r).Update(name, value)
return int(res.RowsAffected), res.Error
}
func (c chainG[T]) Updates(ctx context.Context, t T) (rowsAffected int, err error) {
res := c.g.apply(ctx).Updates(t)
return int(res.RowsAffected), res.Error
}
func (c chainG[T]) Count(ctx context.Context, column string) (result int64, err error) {
var r T
err = c.g.apply(ctx).Model(r).Select(column).Count(&result).Error
return
}
func (c chainG[T]) Build(builder clause.Builder) {
subdb := c.getInstance()
subdb.Logger = logger.Discard
subdb.DryRun = true
if stmt, ok := builder.(*Statement); ok {
if subdb.Statement.SQL.Len() > 0 {
var (
vars = subdb.Statement.Vars
sql = subdb.Statement.SQL.String()
)
subdb.Statement.Vars = make([]interface{}, 0, len(vars))
for _, vv := range vars {
subdb.Statement.Vars = append(subdb.Statement.Vars, vv)
bindvar := strings.Builder{}
subdb.BindVarTo(&bindvar, subdb.Statement, vv)
sql = strings.Replace(sql, bindvar.String(), "?", 1)
}
subdb.Statement.SQL.Reset()
subdb.Statement.Vars = stmt.Vars
if strings.Contains(sql, "@") {
clause.NamedExpr{SQL: sql, Vars: vars}.Build(subdb.Statement)
} else {
clause.Expr{SQL: sql, Vars: vars}.Build(subdb.Statement)
}
} else {
subdb.Statement.Vars = append(stmt.Vars, subdb.Statement.Vars...)
subdb.callbacks.Query().Execute(subdb)
}
builder.WriteString(subdb.Statement.SQL.String())
stmt.Vars = subdb.Statement.Vars
}
}
type execG[T any] struct {
g *g[T]
}
func (g execG[T]) First(ctx context.Context) (T, error) {
var r T
err := g.g.apply(ctx).First(&r).Error
return r, err
}
func (g execG[T]) Scan(ctx context.Context, result interface{}) error {
var r T
err := g.g.apply(ctx).Model(r).Find(result).Error
return err
}
func (g execG[T]) Last(ctx context.Context) (T, error) {
var r T
err := g.g.apply(ctx).Last(&r).Error
return r, err
}
func (g execG[T]) Take(ctx context.Context) (T, error) {
var r T
err := g.g.apply(ctx).Take(&r).Error
return r, err
}
func (g execG[T]) Find(ctx context.Context) ([]T, error) {
var r []T
err := g.g.apply(ctx).Find(&r).Error
return r, err
}
func (g execG[T]) FindInBatches(ctx context.Context, batchSize int, fc func(data []T, batch int) error) error {
var data []T
return g.g.apply(ctx).FindInBatches(&data, batchSize, func(tx *DB, batch int) error {
return fc(data, batch)
}).Error
}
func (g execG[T]) Row(ctx context.Context) *sql.Row {
return g.g.apply(ctx).Row()
}
func (g execG[T]) Rows(ctx context.Context) (*sql.Rows, error) {
return g.g.apply(ctx).Rows()
}

1
go.mod
View File

@ -5,5 +5,4 @@ go 1.18
require (
github.com/jinzhu/inflection v1.0.0
github.com/jinzhu/now v1.1.5
golang.org/x/text v0.20.0
)

2
go.sum
View File

@ -2,5 +2,3 @@ github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
golang.org/x/text v0.20.0 h1:gK/Kv2otX8gz+wn7Rmb3vT96ZwuoxnQlY+HlJVj7Qug=
golang.org/x/text v0.20.0/go.mod h1:D4IsuqiFMhST5bX19pQ9ikHC2GsaKyk/oF+pn3ducp4=

49
gorm.go
View File

@ -8,6 +8,7 @@ import (
"sort"
"sync"
"time"
"unsafe"
"gorm.io/gorm/clause"
"gorm.io/gorm/logger"
@ -22,8 +23,6 @@ type Config struct {
// GORM perform single create, update, delete operations in transactions by default to ensure database data integrity
// You can disable it by setting `SkipDefaultTransaction` to true
SkipDefaultTransaction bool
DefaultTransactionTimeout time.Duration
// NamingStrategy tables, columns naming strategy
NamingStrategy schema.Namer
// FullSaveAssociations full save associations
@ -36,11 +35,6 @@ type Config struct {
DryRun bool
// PrepareStmt executes the given query in cached statement
PrepareStmt bool
// PrepareStmt cache support LRU expired,
// default maxsize=int64 Max value and ttl=1h
PrepareStmtMaxSize int
PrepareStmtTTL time.Duration
// DisableAutomaticPing
DisableAutomaticPing bool
// DisableForeignKeyConstraintWhenMigrating
@ -57,8 +51,6 @@ type Config struct {
CreateBatchSize int
// TranslateError enabling error translation
TranslateError bool
// PropagateUnscoped propagate Unscoped to every other nested statement
PropagateUnscoped bool
// ClauseBuilders clause builder
ClauseBuilders map[string]clause.ClauseBuilder
@ -119,7 +111,6 @@ type Session struct {
DisableNestedTransaction bool
AllowGlobalUpdate bool
FullSaveAssociations bool
PropagateUnscoped bool
QueryFields bool
Context context.Context
Logger logger.Interface
@ -137,24 +128,12 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) {
return isConfig && !isConfig2
})
if len(opts) > 0 {
if c, ok := opts[0].(*Config); ok {
config = c
} else {
opts = append([]Option{config}, opts...)
}
}
var skipAfterInitialize bool
for _, opt := range opts {
if opt != nil {
if applyErr := opt.Apply(config); applyErr != nil {
return nil, applyErr
}
defer func(opt Option) {
if skipAfterInitialize {
return
}
if errr := opt.AfterInitialize(db); errr != nil {
err = errr
}
@ -202,25 +181,16 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) {
if config.Dialector != nil {
err = config.Dialector.Initialize(db)
if err != nil {
if db, _ := db.DB(); db != nil {
_ = db.Close()
}
// DB is not initialized, so we skip AfterInitialize
skipAfterInitialize = true
return
}
if config.TranslateError {
if _, ok := db.Dialector.(ErrorTranslator); !ok {
config.Logger.Warn(context.Background(), "The TranslateError option is enabled, but the Dialector %s does not implement ErrorTranslator.", db.Dialector.Name())
}
}
}
if config.PrepareStmt {
preparedStmt := NewPreparedStmtDB(db.ConnPool, config.PrepareStmtMaxSize, config.PrepareStmtTTL)
preparedStmt := NewPreparedStmtDB(db.ConnPool)
db.cacheStore.Store(preparedStmtDBKey, preparedStmt)
db.ConnPool = preparedStmt
}
@ -272,10 +242,6 @@ func (db *DB) Session(config *Session) *DB {
txConfig.FullSaveAssociations = true
}
if config.PropagateUnscoped {
txConfig.PropagateUnscoped = true
}
if config.Context != nil || config.PrepareStmt || config.SkipHooks {
tx.Statement = tx.Statement.clone()
tx.Statement.DB = tx
@ -291,7 +257,7 @@ func (db *DB) Session(config *Session) *DB {
if v, ok := db.cacheStore.Load(preparedStmtDBKey); ok {
preparedStmt = v.(*PreparedStmtDB)
} else {
preparedStmt = NewPreparedStmtDB(db.ConnPool, db.PrepareStmtMaxSize, db.PrepareStmtTTL)
preparedStmt = NewPreparedStmtDB(db.ConnPool)
db.cacheStore.Store(preparedStmtDBKey, preparedStmt)
}
@ -414,7 +380,7 @@ func (db *DB) DB() (*sql.DB, error) {
connPool = db.Statement.ConnPool
}
if tx, ok := connPool.(*sql.Tx); ok && tx != nil {
return (*sql.DB)(reflect.ValueOf(tx).Elem().FieldByName("db").UnsafePointer()), nil
return (*sql.DB)(unsafe.Pointer(reflect.ValueOf(tx).Elem().FieldByName("db").Addr().Pointer())), nil
}
if dbConnector, ok := connPool.(GetDBConnector); ok && dbConnector != nil {
@ -444,9 +410,6 @@ func (db *DB) getInstance() *DB {
Vars: make([]interface{}, 0, 8),
SkipHooks: db.Statement.SkipHooks,
}
if db.Config.PropagateUnscoped {
tx.Statement.Unscoped = db.Statement.Unscoped
}
} else {
// with clone statement
tx.Statement = db.Statement.clone()
@ -537,7 +500,7 @@ func (db *DB) Use(plugin Plugin) error {
// .First(&User{})
// })
func (db *DB) ToSQL(queryFn func(tx *DB) *DB) string {
tx := queryFn(db.Session(&Session{DryRun: true, SkipDefaultTransaction: true}).getInstance())
tx := queryFn(db.Session(&Session{DryRun: true, SkipDefaultTransaction: true}))
stmt := tx.Statement
return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...)

View File

@ -1,493 +0,0 @@
package lru
// golang -lru
// https://github.com/hashicorp/golang-lru
import (
"sync"
"time"
)
// EvictCallback is used to get a callback when a cache entry is evicted
type EvictCallback[K comparable, V any] func(key K, value V)
// LRU implements a thread-safe LRU with expirable entries.
type LRU[K comparable, V any] struct {
size int
evictList *LruList[K, V]
items map[K]*Entry[K, V]
onEvict EvictCallback[K, V]
// expirable options
mu sync.Mutex
ttl time.Duration
done chan struct{}
// buckets for expiration
buckets []bucket[K, V]
// uint8 because it's number between 0 and numBuckets
nextCleanupBucket uint8
}
// bucket is a container for holding entries to be expired
type bucket[K comparable, V any] struct {
entries map[K]*Entry[K, V]
newestEntry time.Time
}
// noEvictionTTL - very long ttl to prevent eviction
const noEvictionTTL = time.Hour * 24 * 365 * 10
// because of uint8 usage for nextCleanupBucket, should not exceed 256.
// casting it as uint8 explicitly requires type conversions in multiple places
const numBuckets = 100
// NewLRU returns a new thread-safe cache with expirable entries.
//
// Size parameter set to 0 makes cache of unlimited size, e.g. turns LRU mechanism off.
//
// Providing 0 TTL turns expiring off.
//
// Delete expired entries every 1/100th of ttl value. Goroutine which deletes expired entries runs indefinitely.
func NewLRU[K comparable, V any](size int, onEvict EvictCallback[K, V], ttl time.Duration) *LRU[K, V] {
if size < 0 {
size = 0
}
if ttl <= 0 {
ttl = noEvictionTTL
}
res := LRU[K, V]{
ttl: ttl,
size: size,
evictList: NewList[K, V](),
items: make(map[K]*Entry[K, V]),
onEvict: onEvict,
done: make(chan struct{}),
}
// initialize the buckets
res.buckets = make([]bucket[K, V], numBuckets)
for i := 0; i < numBuckets; i++ {
res.buckets[i] = bucket[K, V]{entries: make(map[K]*Entry[K, V])}
}
// enable deleteExpired() running in separate goroutine for cache with non-zero TTL
//
// Important: done channel is never closed, so deleteExpired() goroutine will never exit,
// it's decided to add functionality to close it in the version later than v2.
if res.ttl != noEvictionTTL {
go func(done <-chan struct{}) {
ticker := time.NewTicker(res.ttl / numBuckets)
defer ticker.Stop()
for {
select {
case <-done:
return
case <-ticker.C:
res.deleteExpired()
}
}
}(res.done)
}
return &res
}
// Purge clears the cache completely.
// onEvict is called for each evicted key.
func (c *LRU[K, V]) Purge() {
c.mu.Lock()
defer c.mu.Unlock()
for k, v := range c.items {
if c.onEvict != nil {
c.onEvict(k, v.Value)
}
delete(c.items, k)
}
for _, b := range c.buckets {
for _, ent := range b.entries {
delete(b.entries, ent.Key)
}
}
c.evictList.Init()
}
// Add adds a value to the cache. Returns true if an eviction occurred.
// Returns false if there was no eviction: the item was already in the cache,
// or the size was not exceeded.
func (c *LRU[K, V]) Add(key K, value V) (evicted bool) {
c.mu.Lock()
defer c.mu.Unlock()
now := time.Now()
// Check for existing item
if ent, ok := c.items[key]; ok {
c.evictList.MoveToFront(ent)
c.removeFromBucket(ent) // remove the entry from its current bucket as expiresAt is renewed
ent.Value = value
ent.ExpiresAt = now.Add(c.ttl)
c.addToBucket(ent)
return false
}
// Add new item
ent := c.evictList.PushFrontExpirable(key, value, now.Add(c.ttl))
c.items[key] = ent
c.addToBucket(ent) // adds the entry to the appropriate bucket and sets entry.expireBucket
evict := c.size > 0 && c.evictList.Length() > c.size
// Verify size not exceeded
if evict {
c.removeOldest()
}
return evict
}
// Get looks up a key's value from the cache.
func (c *LRU[K, V]) Get(key K) (value V, ok bool) {
c.mu.Lock()
defer c.mu.Unlock()
var ent *Entry[K, V]
if ent, ok = c.items[key]; ok {
// Expired item check
if time.Now().After(ent.ExpiresAt) {
return value, false
}
c.evictList.MoveToFront(ent)
return ent.Value, true
}
return
}
// Contains checks if a key is in the cache, without updating the recent-ness
// or deleting it for being stale.
func (c *LRU[K, V]) Contains(key K) (ok bool) {
c.mu.Lock()
defer c.mu.Unlock()
_, ok = c.items[key]
return ok
}
// Peek returns the key value (or undefined if not found) without updating
// the "recently used"-ness of the key.
func (c *LRU[K, V]) Peek(key K) (value V, ok bool) {
c.mu.Lock()
defer c.mu.Unlock()
var ent *Entry[K, V]
if ent, ok = c.items[key]; ok {
// Expired item check
if time.Now().After(ent.ExpiresAt) {
return value, false
}
return ent.Value, true
}
return
}
// Remove removes the provided key from the cache, returning if the
// key was contained.
func (c *LRU[K, V]) Remove(key K) bool {
c.mu.Lock()
defer c.mu.Unlock()
if ent, ok := c.items[key]; ok {
c.removeElement(ent)
return true
}
return false
}
// RemoveOldest removes the oldest item from the cache.
func (c *LRU[K, V]) RemoveOldest() (key K, value V, ok bool) {
c.mu.Lock()
defer c.mu.Unlock()
if ent := c.evictList.Back(); ent != nil {
c.removeElement(ent)
return ent.Key, ent.Value, true
}
return
}
// GetOldest returns the oldest entry
func (c *LRU[K, V]) GetOldest() (key K, value V, ok bool) {
c.mu.Lock()
defer c.mu.Unlock()
if ent := c.evictList.Back(); ent != nil {
return ent.Key, ent.Value, true
}
return
}
func (c *LRU[K, V]) KeyValues() map[K]V {
c.mu.Lock()
defer c.mu.Unlock()
maps := make(map[K]V)
now := time.Now()
for ent := c.evictList.Back(); ent != nil; ent = ent.PrevEntry() {
if now.After(ent.ExpiresAt) {
continue
}
maps[ent.Key] = ent.Value
// keys = append(keys, ent.Key)
}
return maps
}
// Keys returns a slice of the keys in the cache, from oldest to newest.
// Expired entries are filtered out.
func (c *LRU[K, V]) Keys() []K {
c.mu.Lock()
defer c.mu.Unlock()
keys := make([]K, 0, len(c.items))
now := time.Now()
for ent := c.evictList.Back(); ent != nil; ent = ent.PrevEntry() {
if now.After(ent.ExpiresAt) {
continue
}
keys = append(keys, ent.Key)
}
return keys
}
// Values returns a slice of the values in the cache, from oldest to newest.
// Expired entries are filtered out.
func (c *LRU[K, V]) Values() []V {
c.mu.Lock()
defer c.mu.Unlock()
values := make([]V, 0, len(c.items))
now := time.Now()
for ent := c.evictList.Back(); ent != nil; ent = ent.PrevEntry() {
if now.After(ent.ExpiresAt) {
continue
}
values = append(values, ent.Value)
}
return values
}
// Len returns the number of items in the cache.
func (c *LRU[K, V]) Len() int {
c.mu.Lock()
defer c.mu.Unlock()
return c.evictList.Length()
}
// Resize changes the cache size. Size of 0 means unlimited.
func (c *LRU[K, V]) Resize(size int) (evicted int) {
c.mu.Lock()
defer c.mu.Unlock()
if size <= 0 {
c.size = 0
return 0
}
diff := c.evictList.Length() - size
if diff < 0 {
diff = 0
}
for i := 0; i < diff; i++ {
c.removeOldest()
}
c.size = size
return diff
}
// Close destroys cleanup goroutine. To clean up the cache, run Purge() before Close().
// func (c *LRU[K, V]) Close() {
// c.mu.Lock()
// defer c.mu.Unlock()
// select {
// case <-c.done:
// return
// default:
// }
// close(c.done)
// }
// removeOldest removes the oldest item from the cache. Has to be called with lock!
func (c *LRU[K, V]) removeOldest() {
if ent := c.evictList.Back(); ent != nil {
c.removeElement(ent)
}
}
// removeElement is used to remove a given list element from the cache. Has to be called with lock!
func (c *LRU[K, V]) removeElement(e *Entry[K, V]) {
c.evictList.Remove(e)
delete(c.items, e.Key)
c.removeFromBucket(e)
if c.onEvict != nil {
c.onEvict(e.Key, e.Value)
}
}
// deleteExpired deletes expired records from the oldest bucket, waiting for the newest entry
// in it to expire first.
func (c *LRU[K, V]) deleteExpired() {
c.mu.Lock()
bucketIdx := c.nextCleanupBucket
timeToExpire := time.Until(c.buckets[bucketIdx].newestEntry)
// wait for newest entry to expire before cleanup without holding lock
if timeToExpire > 0 {
c.mu.Unlock()
time.Sleep(timeToExpire)
c.mu.Lock()
}
for _, ent := range c.buckets[bucketIdx].entries {
c.removeElement(ent)
}
c.nextCleanupBucket = (c.nextCleanupBucket + 1) % numBuckets
c.mu.Unlock()
}
// addToBucket adds entry to expire bucket so that it will be cleaned up when the time comes. Has to be called with lock!
func (c *LRU[K, V]) addToBucket(e *Entry[K, V]) {
bucketID := (numBuckets + c.nextCleanupBucket - 1) % numBuckets
e.ExpireBucket = bucketID
c.buckets[bucketID].entries[e.Key] = e
if c.buckets[bucketID].newestEntry.Before(e.ExpiresAt) {
c.buckets[bucketID].newestEntry = e.ExpiresAt
}
}
// removeFromBucket removes the entry from its corresponding bucket. Has to be called with lock!
func (c *LRU[K, V]) removeFromBucket(e *Entry[K, V]) {
delete(c.buckets[e.ExpireBucket].entries, e.Key)
}
// Cap returns the capacity of the cache
func (c *LRU[K, V]) Cap() int {
return c.size
}
// Entry is an LRU Entry
type Entry[K comparable, V any] struct {
// Next and previous pointers in the doubly-linked list of elements.
// To simplify the implementation, internally a list l is implemented
// as a ring, such that &l.root is both the next element of the last
// list element (l.Back()) and the previous element of the first list
// element (l.Front()).
next, prev *Entry[K, V]
// The list to which this element belongs.
list *LruList[K, V]
// The LRU Key of this element.
Key K
// The Value stored with this element.
Value V
// The time this element would be cleaned up, optional
ExpiresAt time.Time
// The expiry bucket item was put in, optional
ExpireBucket uint8
}
// PrevEntry returns the previous list element or nil.
func (e *Entry[K, V]) PrevEntry() *Entry[K, V] {
if p := e.prev; e.list != nil && p != &e.list.root {
return p
}
return nil
}
// LruList represents a doubly linked list.
// The zero Value for LruList is an empty list ready to use.
type LruList[K comparable, V any] struct {
root Entry[K, V] // sentinel list element, only &root, root.prev, and root.next are used
len int // current list Length excluding (this) sentinel element
}
// Init initializes or clears list l.
func (l *LruList[K, V]) Init() *LruList[K, V] {
l.root.next = &l.root
l.root.prev = &l.root
l.len = 0
return l
}
// NewList returns an initialized list.
func NewList[K comparable, V any]() *LruList[K, V] { return new(LruList[K, V]).Init() }
// Length returns the number of elements of list l.
// The complexity is O(1).
func (l *LruList[K, V]) Length() int { return l.len }
// Back returns the last element of list l or nil if the list is empty.
func (l *LruList[K, V]) Back() *Entry[K, V] {
if l.len == 0 {
return nil
}
return l.root.prev
}
// lazyInit lazily initializes a zero List Value.
func (l *LruList[K, V]) lazyInit() {
if l.root.next == nil {
l.Init()
}
}
// insert inserts e after at, increments l.len, and returns e.
func (l *LruList[K, V]) insert(e, at *Entry[K, V]) *Entry[K, V] {
e.prev = at
e.next = at.next
e.prev.next = e
e.next.prev = e
e.list = l
l.len++
return e
}
// insertValue is a convenience wrapper for insert(&Entry{Value: v, ExpiresAt: ExpiresAt}, at).
func (l *LruList[K, V]) insertValue(k K, v V, expiresAt time.Time, at *Entry[K, V]) *Entry[K, V] {
return l.insert(&Entry[K, V]{Value: v, Key: k, ExpiresAt: expiresAt}, at)
}
// Remove removes e from its list, decrements l.len
func (l *LruList[K, V]) Remove(e *Entry[K, V]) V {
e.prev.next = e.next
e.next.prev = e.prev
e.next = nil // avoid memory leaks
e.prev = nil // avoid memory leaks
e.list = nil
l.len--
return e.Value
}
// move moves e to next to at.
func (l *LruList[K, V]) move(e, at *Entry[K, V]) {
if e == at {
return
}
e.prev.next = e.next
e.next.prev = e.prev
e.prev = at
e.next = at.next
e.prev.next = e
e.next.prev = e
}
// PushFront inserts a new element e with value v at the front of list l and returns e.
func (l *LruList[K, V]) PushFront(k K, v V) *Entry[K, V] {
l.lazyInit()
return l.insertValue(k, v, time.Time{}, &l.root)
}
// PushFrontExpirable inserts a new expirable element e with Value v at the front of list l and returns e.
func (l *LruList[K, V]) PushFrontExpirable(k K, v V, expiresAt time.Time) *Entry[K, V] {
l.lazyInit()
return l.insertValue(k, v, expiresAt, &l.root)
}
// MoveToFront moves element e to the front of list l.
// If e is not an element of l, the list is not modified.
// The element must not be nil.
func (l *LruList[K, V]) MoveToFront(e *Entry[K, V]) {
if e.list != l || l.root.next == e {
return
}
// see comment in List.Remove about initialization of l
l.move(e, &l.root)
}

View File

@ -1,183 +0,0 @@
package stmt_store
import (
"context"
"database/sql"
"math"
"sync"
"time"
"gorm.io/gorm/internal/lru"
)
type Stmt struct {
*sql.Stmt
Transaction bool
prepared chan struct{}
prepareErr error
}
func (stmt *Stmt) Error() error {
return stmt.prepareErr
}
func (stmt *Stmt) Close() error {
<-stmt.prepared
if stmt.Stmt != nil {
return stmt.Stmt.Close()
}
return nil
}
// Store defines an interface for managing the caching operations of SQL statements (Stmt).
// This interface provides methods for creating new statements, retrieving all cache keys,
// getting cached statements, setting cached statements, and deleting cached statements.
type Store interface {
// New creates a new Stmt object and caches it.
// Parameters:
// ctx: The context for the request, which can carry deadlines, cancellation signals, etc.
// key: The key representing the SQL query, used for caching and preparing the statement.
// isTransaction: Indicates whether this operation is part of a transaction, which may affect the caching strategy.
// connPool: A connection pool that provides database connections.
// locker: A synchronization lock that is unlocked after initialization to avoid deadlocks.
// Returns:
// *Stmt: A newly created statement object for executing SQL operations.
// error: An error if the statement preparation fails.
New(ctx context.Context, key string, isTransaction bool, connPool ConnPool, locker sync.Locker) (*Stmt, error)
// Keys returns a slice of all cache keys in the store.
Keys() []string
// Get retrieves a Stmt object from the store based on the given key.
// Parameters:
// key: The key used to look up the Stmt object.
// Returns:
// *Stmt: The found Stmt object, or nil if not found.
// bool: Indicates whether the corresponding Stmt object was successfully found.
Get(key string) (*Stmt, bool)
// Set stores the given Stmt object in the store and associates it with the specified key.
// Parameters:
// key: The key used to associate the Stmt object.
// value: The Stmt object to be stored.
Set(key string, value *Stmt)
// Delete removes the Stmt object corresponding to the specified key from the store.
// Parameters:
// key: The key associated with the Stmt object to be deleted.
Delete(key string)
}
// defaultMaxSize defines the default maximum capacity of the cache.
// Its value is the maximum value of the int64 type, which means that when the cache size is not specified,
// the cache can theoretically store as many elements as possible.
// (1 << 63) - 1 is the maximum value that an int64 type can represent.
const (
defaultMaxSize = math.MaxInt
// defaultTTL defines the default time-to-live (TTL) for each cache entry.
// When the TTL for cache entries is not specified, each cache entry will expire after 24 hours.
defaultTTL = time.Hour * 24
)
// New creates and returns a new Store instance.
//
// Parameters:
// - size: The maximum capacity of the cache. If the provided size is less than or equal to 0,
// it defaults to defaultMaxSize.
// - ttl: The time-to-live duration for each cache entry. If the provided ttl is less than or equal to 0,
// it defaults to defaultTTL.
//
// This function defines an onEvicted callback that is invoked when a cache entry is evicted.
// The callback ensures that if the evicted value (v) is not nil, its Close method is called asynchronously
// to release associated resources.
//
// Returns:
// - A Store instance implemented by lruStore, which internally uses an LRU cache with the specified size,
// eviction callback, and TTL.
func New(size int, ttl time.Duration) Store {
if size <= 0 {
size = defaultMaxSize
}
if ttl <= 0 {
ttl = defaultTTL
}
onEvicted := func(k string, v *Stmt) {
if v != nil {
go v.Close()
}
}
return &lruStore{lru: lru.NewLRU[string, *Stmt](size, onEvicted, ttl)}
}
type lruStore struct {
lru *lru.LRU[string, *Stmt]
}
func (s *lruStore) Keys() []string {
return s.lru.Keys()
}
func (s *lruStore) Get(key string) (*Stmt, bool) {
stmt, ok := s.lru.Get(key)
if ok && stmt != nil {
<-stmt.prepared
}
return stmt, ok
}
func (s *lruStore) Set(key string, value *Stmt) {
s.lru.Add(key, value)
}
func (s *lruStore) Delete(key string) {
s.lru.Remove(key)
}
type ConnPool interface {
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
}
// New creates a new Stmt object for executing SQL queries.
// It caches the Stmt object for future use and handles preparation and error states.
// Parameters:
//
// ctx: Context for the request, used to carry deadlines, cancellation signals, etc.
// key: The key representing the SQL query, used for caching and preparing the statement.
// isTransaction: Indicates whether this operation is part of a transaction, affecting cache strategy.
// conn: A connection pool that provides database connections.
// locker: A synchronization lock that is unlocked after initialization to avoid deadlocks.
//
// Returns:
//
// *Stmt: A newly created statement object for executing SQL operations.
// error: An error if the statement preparation fails.
func (s *lruStore) New(ctx context.Context, key string, isTransaction bool, conn ConnPool, locker sync.Locker) (_ *Stmt, err error) {
// Create a Stmt object and set its Transaction property.
// The prepared channel is used to synchronize the statement preparation state.
cacheStmt := &Stmt{
Transaction: isTransaction,
prepared: make(chan struct{}),
}
// Cache the Stmt object with the associated key.
s.Set(key, cacheStmt)
// Unlock after completing initialization to prevent deadlocks.
locker.Unlock()
// Ensure the prepared channel is closed after the function execution completes.
defer close(cacheStmt.prepared)
// Prepare the SQL statement using the provided connection.
cacheStmt.Stmt, err = conn.PrepareContext(ctx, key)
if err != nil {
// If statement preparation fails, record the error and remove the invalid Stmt object from the cache.
cacheStmt.prepareErr = err
s.Delete(key)
return &Stmt{}, err
}
// Return the successfully prepared Stmt object.
return cacheStmt, nil
}

View File

@ -69,7 +69,7 @@ type Interface interface {
}
var (
// Discard logger will print any log to io.Discard
// Discard Discard logger will print any log to io.Discard
Discard = New(log.New(io.Discard, "", log.LstdFlags), Config{})
// Default Default logger
Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{
@ -78,13 +78,8 @@ var (
IgnoreRecordNotFoundError: false,
Colorful: true,
})
// Recorder logger records running SQL into a recorder instance
// Recorder Recorder logger records running SQL into a recorder instance
Recorder = traceRecorder{Interface: Default, BeginAt: time.Now()}
// RecorderParamsFilter defaults to no-op, allows to be run-over by a different implementation
RecorderParamsFilter = func(ctx context.Context, sql string, params ...interface{}) (string, []interface{}) {
return sql, params
}
)
// New initialize logger
@ -134,30 +129,28 @@ func (l *logger) LogMode(level LogLevel) Interface {
}
// Info print info
func (l *logger) Info(ctx context.Context, msg string, data ...interface{}) {
func (l logger) Info(ctx context.Context, msg string, data ...interface{}) {
if l.LogLevel >= Info {
l.Printf(l.infoStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...)
}
}
// Warn print warn messages
func (l *logger) Warn(ctx context.Context, msg string, data ...interface{}) {
func (l logger) Warn(ctx context.Context, msg string, data ...interface{}) {
if l.LogLevel >= Warn {
l.Printf(l.warnStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...)
}
}
// Error print error messages
func (l *logger) Error(ctx context.Context, msg string, data ...interface{}) {
func (l logger) Error(ctx context.Context, msg string, data ...interface{}) {
if l.LogLevel >= Error {
l.Printf(l.errStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...)
}
}
// Trace print sql message
//
//nolint:cyclop
func (l *logger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) {
func (l logger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) {
if l.LogLevel <= Silent {
return
}
@ -189,8 +182,8 @@ func (l *logger) Trace(ctx context.Context, begin time.Time, fc func() (string,
}
}
// ParamsFilter filter params
func (l *logger) ParamsFilter(ctx context.Context, sql string, params ...interface{}) (string, []interface{}) {
// Trace print sql message
func (l logger) ParamsFilter(ctx context.Context, sql string, params ...interface{}) (string, []interface{}) {
if l.Config.ParameterizedQueries {
return sql, nil
}
@ -205,8 +198,8 @@ type traceRecorder struct {
Err error
}
// New trace recorder
func (l *traceRecorder) New() *traceRecorder {
// New new trace recorder
func (l traceRecorder) New() *traceRecorder {
return &traceRecorder{Interface: l.Interface, BeginAt: time.Now()}
}
@ -216,10 +209,3 @@ func (l *traceRecorder) Trace(ctx context.Context, begin time.Time, fc func() (s
l.SQL, l.RowsAffected = fc()
l.Err = err
}
func (l *traceRecorder) ParamsFilter(ctx context.Context, sql string, params ...interface{}) (string, []interface{}) {
if RecorderParamsFilter == nil {
return sql, params
}
return RecorderParamsFilter(ctx, sql, params...)
}

View File

@ -34,19 +34,6 @@ var convertibleTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeO
// RegEx matches only numeric values
var numericPlaceholderRe = regexp.MustCompile(`\$\d+\$`)
func isNumeric(k reflect.Kind) bool {
switch k {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return true
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return true
case reflect.Float32, reflect.Float64:
return true
default:
return false
}
}
// ExplainSQL generate SQL string with given parameters, the generated SQL is expected to be used in logger, execute it might introduce a SQL injection vulnerability
func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, avars ...interface{}) string {
var (
@ -92,17 +79,17 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a
case reflect.Bool:
vars[idx] = fmt.Sprintf("%t", reflectValue.Interface())
case reflect.String:
vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, escaper+escaper) + escaper
vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, "\\"+escaper) + escaper
default:
if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) {
vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, escaper+escaper) + escaper
vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, "\\"+escaper) + escaper
} else {
vars[idx] = nullStr
}
}
case []byte:
if s := string(v); isPrintable(s) {
vars[idx] = escaper + strings.ReplaceAll(s, escaper, escaper+escaper) + escaper
vars[idx] = escaper + strings.ReplaceAll(s, escaper, "\\"+escaper) + escaper
} else {
vars[idx] = escaper + "<binary>" + escaper
}
@ -113,7 +100,7 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a
case float64:
vars[idx] = strconv.FormatFloat(v, 'f', -1, 64)
case string:
vars[idx] = escaper + strings.ReplaceAll(v, escaper, escaper+escaper) + escaper
vars[idx] = escaper + strings.ReplaceAll(v, escaper, "\\"+escaper) + escaper
default:
rv := reflect.ValueOf(v)
if v == nil || !rv.IsValid() || rv.Kind() == reflect.Ptr && rv.IsNil() {
@ -123,12 +110,6 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a
convertParams(v, idx)
} else if rv.Kind() == reflect.Ptr && !rv.IsZero() {
convertParams(reflect.Indirect(rv).Interface(), idx)
} else if isNumeric(rv.Kind()) {
if rv.CanInt() || rv.CanUint() {
vars[idx] = fmt.Sprintf("%d", rv.Interface())
} else {
vars[idx] = fmt.Sprintf("%.6f", rv.Interface())
}
} else {
for _, t := range convertibleTypes {
if rv.Type().ConvertibleTo(t) {
@ -136,7 +117,7 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a
return
}
}
vars[idx] = escaper + strings.ReplaceAll(fmt.Sprint(v), escaper, escaper+escaper) + escaper
vars[idx] = escaper + strings.ReplaceAll(fmt.Sprint(v), escaper, "\\"+escaper) + escaper
}
}
}

View File

@ -31,24 +31,20 @@ func (s ExampleStruct) Value() (driver.Value, error) {
}
func format(v []byte, escaper string) string {
return escaper + strings.ReplaceAll(string(v), escaper, escaper+escaper) + escaper
return escaper + strings.ReplaceAll(string(v), escaper, "\\"+escaper) + escaper
}
func TestExplainSQL(t *testing.T) {
type role string
type password []byte
type intType int
type floatType float64
var (
tt = now.MustParse("2020-02-23 11:10:10")
myrole = role("admin")
pwd = password("pass")
pwd = password([]byte("pass"))
jsVal = []byte(`{"Name":"test","Val":"test"}`)
js = JSON(jsVal)
esVal = []byte(`{"Name":"test","Val":"test"}`)
es = ExampleStruct{Name: "test", Val: "test"}
intVal intType = 1
floatVal floatType = 1.23
)
results := []struct {
@ -61,13 +57,13 @@ func TestExplainSQL(t *testing.T) {
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
NumericRegexp: nil,
Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd},
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass")`,
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass")`,
},
{
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
NumericRegexp: nil,
Vars: []interface{}{"jinzhu?", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd},
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu?", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass")`,
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu?", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass")`,
},
{
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (@p1, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9, @p10, @p11)",
@ -91,37 +87,25 @@ func TestExplainSQL(t *testing.T) {
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
NumericRegexp: nil,
Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, js, es},
Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
},
{
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
NumericRegexp: nil,
Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es},
Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
},
{
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
NumericRegexp: nil,
Vars: []interface{}{"jinzhu", 1, 0.1753607109, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es},
Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 0.1753607109, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 0.1753607109, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
},
{
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
NumericRegexp: nil,
Vars: []interface{}{"jinzhu", 1, float32(999.99), true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es},
Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
},
{
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, int_val) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
NumericRegexp: nil,
Vars: []interface{}{"jinzhu?", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, intVal},
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, int_val) values ("jinzhu?", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", 1)`,
},
{
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, float_val) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
NumericRegexp: nil,
Vars: []interface{}{"jinzhu?", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, floatVal},
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, float_val) values ("jinzhu?", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", 1.230000)`,
Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
},
}

View File

@ -87,8 +87,6 @@ type Migrator interface {
DropColumn(dst interface{}, field string) error
AlterColumn(dst interface{}, field string) error
MigrateColumn(dst interface{}, field *schema.Field, columnType ColumnType) error
// MigrateColumnUnique migrate column's UNIQUE constraint, it's part of MigrateColumn.
MigrateColumnUnique(dst interface{}, field *schema.Field, columnType ColumnType) error
HasColumn(dst interface{}, field string) bool
RenameColumn(dst interface{}, oldName, field string) error
ColumnTypes(dst interface{}) ([]ColumnType, error)

View File

@ -7,7 +7,6 @@ import (
"fmt"
"reflect"
"regexp"
"strconv"
"strings"
"time"
@ -28,8 +27,6 @@ var regFullDataType = regexp.MustCompile(`\D*(\d+)\D?`)
// TODO:? Create const vars for raw sql queries ?
var _ gorm.Migrator = (*Migrator)(nil)
// Migrator m struct
type Migrator struct {
Config
@ -94,6 +91,10 @@ func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) {
expr.SQL += " NOT NULL"
}
if field.Unique {
expr.SQL += " UNIQUE"
}
if field.HasDefaultValue && (field.DefaultValueInterface != nil || field.DefaultValue != "") {
if field.DefaultValueInterface != nil {
defaultStmt := &gorm.Statement{Vars: []interface{}{field.DefaultValueInterface}}
@ -107,31 +108,21 @@ func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) {
return
}
func (m Migrator) GetQueryAndExecTx() (queryTx, execTx *gorm.DB) {
queryTx = m.DB.Session(&gorm.Session{})
execTx = queryTx
// AutoMigrate auto migrate values
func (m Migrator) AutoMigrate(values ...interface{}) error {
for _, value := range m.ReorderModels(values, true) {
queryTx := m.DB.Session(&gorm.Session{})
execTx := queryTx
if m.DB.DryRun {
queryTx.DryRun = false
execTx = m.DB.Session(&gorm.Session{Logger: &printSQLLogger{Interface: m.DB.Logger}})
}
return queryTx, execTx
}
// AutoMigrate auto migrate values
func (m Migrator) AutoMigrate(values ...interface{}) error {
for _, value := range m.ReorderModels(values, true) {
queryTx, execTx := m.GetQueryAndExecTx()
if !queryTx.Migrator().HasTable(value) {
if err := execTx.Migrator().CreateTable(value); err != nil {
return err
}
} else {
if err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
if stmt.Schema == nil {
return errors.New("failed to get schema")
}
columnTypes, err := queryTx.Migrator().ColumnTypes(value)
if err != nil {
return err
@ -216,11 +207,6 @@ func (m Migrator) CreateTable(values ...interface{}) error {
for _, value := range m.ReorderModels(values, false) {
tx := m.DB.Session(&gorm.Session{})
if err := m.RunWithValue(value, func(stmt *gorm.Statement) (err error) {
if stmt.Schema == nil {
return errors.New("failed to get schema")
}
var (
createTableSQL = "CREATE TABLE ? ("
values = []interface{}{m.CurrentTable(stmt)}
@ -231,7 +217,7 @@ func (m Migrator) CreateTable(values ...interface{}) error {
field := stmt.Schema.FieldsByDBName[dbName]
if !field.IgnoreMigration {
createTableSQL += "? ?"
hasPrimaryKeyInDataType = hasPrimaryKeyInDataType || strings.Contains(strings.ToUpper(m.DataTypeOf(field)), "PRIMARY KEY")
hasPrimaryKeyInDataType = hasPrimaryKeyInDataType || strings.Contains(strings.ToUpper(string(field.DataType)), "PRIMARY KEY")
values = append(values, clause.Column{Name: dbName}, m.DB.Migrator().FullDataTypeOf(field))
createTableSQL += ","
}
@ -280,7 +266,7 @@ func (m Migrator) CreateTable(values ...interface{}) error {
}
if constraint := rel.ParseConstraint(); constraint != nil {
if constraint.Schema == stmt.Schema {
sql, vars := constraint.Build()
sql, vars := buildConstraint(constraint)
createTableSQL += sql + ","
values = append(values, vars...)
}
@ -288,11 +274,6 @@ func (m Migrator) CreateTable(values ...interface{}) error {
}
}
for _, uni := range stmt.Schema.ParseUniqueConstraints() {
createTableSQL += "CONSTRAINT ? UNIQUE (?),"
values = append(values, clause.Column{Name: uni.Name}, clause.Expr{SQL: stmt.Quote(uni.Field.DBName)})
}
for _, chk := range stmt.Schema.ParseCheckConstraints() {
createTableSQL += "CONSTRAINT ? CHECK (?),"
values = append(values, clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint})
@ -373,9 +354,6 @@ func (m Migrator) RenameTable(oldName, newName interface{}) error {
func (m Migrator) AddColumn(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
// avoid using the same name field
if stmt.Schema == nil {
return errors.New("failed to get schema")
}
f := stmt.Schema.LookUpField(name)
if f == nil {
return fmt.Errorf("failed to look up field with name: %s", name)
@ -395,11 +373,9 @@ func (m Migrator) AddColumn(value interface{}, name string) error {
// DropColumn drop value's `name` column
func (m Migrator) DropColumn(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if stmt.Schema != nil {
if field := stmt.Schema.LookUpField(name); field != nil {
name = field.DBName
}
}
return m.DB.Exec(
"ALTER TABLE ? DROP COLUMN ?", m.CurrentTable(stmt), clause.Column{Name: name},
@ -410,7 +386,6 @@ func (m Migrator) DropColumn(value interface{}, name string) error {
// AlterColumn alter value's `field` column' type based on schema definition
func (m Migrator) AlterColumn(value interface{}, field string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if stmt.Schema != nil {
if field := stmt.Schema.LookUpField(field); field != nil {
fileType := m.FullDataTypeOf(field)
return m.DB.Exec(
@ -419,7 +394,6 @@ func (m Migrator) AlterColumn(value interface{}, field string) error {
).Error
}
}
return fmt.Errorf("failed to look up field with name: %s", field)
})
}
@ -430,11 +404,9 @@ func (m Migrator) HasColumn(value interface{}, field string) bool {
m.RunWithValue(value, func(stmt *gorm.Statement) error {
currentDatabase := m.DB.Migrator().CurrentDatabase()
name := field
if stmt.Schema != nil {
if field := stmt.Schema.LookUpField(field); field != nil {
name = field.DBName
}
}
return m.DB.Raw(
"SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_schema = ? AND table_name = ? AND column_name = ?",
@ -448,7 +420,6 @@ func (m Migrator) HasColumn(value interface{}, field string) bool {
// RenameColumn rename value's field name from oldName to newName
func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if stmt.Schema != nil {
if field := stmt.Schema.LookUpField(oldName); field != nil {
oldName = field.DBName
}
@ -456,7 +427,6 @@ func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error
if field := stmt.Schema.LookUpField(newName); field != nil {
newName = field.DBName
}
}
return m.DB.Exec(
"ALTER TABLE ? RENAME COLUMN ? TO ?",
@ -467,13 +437,10 @@ func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error
// MigrateColumn migrate column
func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnType gorm.ColumnType) error {
if field.IgnoreMigration {
return nil
}
// found, smart migrate
fullDataType := strings.TrimSpace(strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL))
realDataType := strings.ToLower(columnType.DatabaseTypeName())
var (
alterColumn bool
isSameType = fullDataType == realDataType
@ -512,19 +479,8 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
}
}
}
}
// check precision
if realDataType == "decimal" || realDataType == "numeric" &&
regexp.MustCompile(realDataType+`\(.*\)`).FindString(fullDataType) != "" { // if realDataType has no precision,ignore
precision, scale, ok := columnType.DecimalSize()
if ok {
if !strings.HasPrefix(fullDataType, fmt.Sprintf("%s(%d,%d)", realDataType, precision, scale)) &&
!strings.HasPrefix(fullDataType, fmt.Sprintf("%s(%d)", realDataType, precision)) {
alterColumn = true
}
}
} else {
if precision, _, ok := columnType.DecimalSize(); ok && int64(field.Precision) != precision {
if regexp.MustCompile(fmt.Sprintf("[^0-9]%d[^0-9]", field.Precision)).MatchString(m.DataTypeOf(field)) {
alterColumn = true
@ -534,8 +490,16 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
// check nullable
if nullable, ok := columnType.Nullable(); ok && nullable == field.NotNull {
// not primary key & current database is non-nullable(to be nullable)
if !field.PrimaryKey && !nullable {
// not primary key & database is nullable
if !field.PrimaryKey && nullable {
alterColumn = true
}
}
// check unique
if unique, ok := columnType.Unique(); ok && unique != field.Unique {
// not primary key
if !field.PrimaryKey {
alterColumn = true
}
}
@ -550,19 +514,13 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
} else if !dvNotNull && currentDefaultNotNull {
// null -> default value
alterColumn = true
} else if currentDefaultNotNull || dvNotNull {
switch field.GORMDataType {
case schema.Time:
if !strings.EqualFold(strings.TrimSuffix(dv, "()"), strings.TrimSuffix(field.DefaultValue, "()")) {
} else if (field.GORMDataType != schema.Time && dv != field.DefaultValue) ||
(field.GORMDataType == schema.Time && !strings.EqualFold(strings.TrimSuffix(dv, "()"), strings.TrimSuffix(field.DefaultValue, "()"))) {
// default value not equal
// not both null
if currentDefaultNotNull || dvNotNull {
alterColumn = true
}
case schema.Bool:
v1, _ := strconv.ParseBool(dv)
v2, _ := strconv.ParseBool(field.DefaultValue)
alterColumn = v1 != v2
default:
alterColumn = dv != field.DefaultValue
}
}
}
@ -574,39 +532,13 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
}
}
if alterColumn {
if err := m.DB.Migrator().AlterColumn(value, field.DBName); err != nil {
return err
}
}
if err := m.DB.Migrator().MigrateColumnUnique(value, field, columnType); err != nil {
return err
if alterColumn && !field.IgnoreMigration {
return m.DB.Migrator().AlterColumn(value, field.DBName)
}
return nil
}
func (m Migrator) MigrateColumnUnique(value interface{}, field *schema.Field, columnType gorm.ColumnType) error {
unique, ok := columnType.Unique()
if !ok || field.PrimaryKey {
return nil // skip primary key
}
// By default, ColumnType's Unique is not affected by UniqueIndex, so we don't care about UniqueIndex.
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
// We're currently only receiving boolean values on `Unique` tag,
// so the UniqueConstraint name is fixed
constraint := m.DB.NamingStrategy.UniqueName(stmt.Table, field.DBName)
if unique && !field.Unique {
return m.DB.Migrator().DropConstraint(value, constraint)
}
if !unique && field.Unique {
return m.DB.Migrator().CreateConstraint(value, constraint)
}
return nil
})
}
// ColumnTypes return columnTypes []gorm.ColumnType and execErr error
func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
columnTypes := make([]gorm.ColumnType, 0)
@ -676,36 +608,37 @@ func (m Migrator) DropView(name string) error {
return m.DB.Exec("DROP VIEW IF EXISTS ?", clause.Table{Name: name}).Error
}
// GuessConstraintAndTable guess statement's constraint and it's table based on name
//
// Deprecated: use GuessConstraintInterfaceAndTable instead.
func (m Migrator) GuessConstraintAndTable(stmt *gorm.Statement, name string) (*schema.Constraint, *schema.CheckConstraint, string) {
constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name)
switch c := constraint.(type) {
case *schema.Constraint:
return c, nil, table
case *schema.CheckConstraint:
return nil, c, table
default:
return nil, nil, table
func buildConstraint(constraint *schema.Constraint) (sql string, results []interface{}) {
sql = "CONSTRAINT ? FOREIGN KEY ? REFERENCES ??"
if constraint.OnDelete != "" {
sql += " ON DELETE " + constraint.OnDelete
}
if constraint.OnUpdate != "" {
sql += " ON UPDATE " + constraint.OnUpdate
}
var foreignKeys, references []interface{}
for _, field := range constraint.ForeignKeys {
foreignKeys = append(foreignKeys, clause.Column{Name: field.DBName})
}
for _, field := range constraint.References {
references = append(references, clause.Column{Name: field.DBName})
}
results = append(results, clause.Table{Name: constraint.Name}, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references)
return
}
// GuessConstraintInterfaceAndTable guess statement's constraint and it's table based on name
// nolint:cyclop
func (m Migrator) GuessConstraintInterfaceAndTable(stmt *gorm.Statement, name string) (_ schema.ConstraintInterface, table string) {
// GuessConstraintAndTable guess statement's constraint and it's table based on name
func (m Migrator) GuessConstraintAndTable(stmt *gorm.Statement, name string) (_ *schema.Constraint, _ *schema.Check, table string) {
if stmt.Schema == nil {
return nil, stmt.Table
return nil, nil, stmt.Table
}
checkConstraints := stmt.Schema.ParseCheckConstraints()
if chk, ok := checkConstraints[name]; ok {
return &chk, stmt.Table
}
uniqueConstraints := stmt.Schema.ParseUniqueConstraints()
if uni, ok := uniqueConstraints[name]; ok {
return &uni, stmt.Table
return nil, &chk, stmt.Table
}
getTable := func(rel *schema.Relationship) string {
@ -720,7 +653,7 @@ func (m Migrator) GuessConstraintInterfaceAndTable(stmt *gorm.Statement, name st
for _, rel := range stmt.Schema.Relationships.Relations {
if constraint := rel.ParseConstraint(); constraint != nil && constraint.Name == name {
return constraint, getTable(rel)
return constraint, nil, getTable(rel)
}
}
@ -728,39 +661,40 @@ func (m Migrator) GuessConstraintInterfaceAndTable(stmt *gorm.Statement, name st
for k := range checkConstraints {
if checkConstraints[k].Field == field {
v := checkConstraints[k]
return &v, stmt.Table
}
}
for k := range uniqueConstraints {
if uniqueConstraints[k].Field == field {
v := uniqueConstraints[k]
return &v, stmt.Table
return nil, &v, stmt.Table
}
}
for _, rel := range stmt.Schema.Relationships.Relations {
if constraint := rel.ParseConstraint(); constraint != nil && rel.Field == field {
return constraint, getTable(rel)
return constraint, nil, getTable(rel)
}
}
}
return nil, stmt.Schema.Table
return nil, nil, stmt.Schema.Table
}
// CreateConstraint create constraint
func (m Migrator) CreateConstraint(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name)
constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
if chk != nil {
return m.DB.Exec(
"ALTER TABLE ? ADD CONSTRAINT ? CHECK (?)",
m.CurrentTable(stmt), clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint},
).Error
}
if constraint != nil {
vars := []interface{}{clause.Table{Name: table}}
if stmt.TableExpr != nil {
vars[0] = stmt.TableExpr
}
sql, values := constraint.Build()
sql, values := buildConstraint(constraint)
return m.DB.Exec("ALTER TABLE ? ADD "+sql, append(vars, values...)...).Error
}
return nil
})
}
@ -768,9 +702,11 @@ func (m Migrator) CreateConstraint(value interface{}, name string) error {
// DropConstraint drop constraint
func (m Migrator) DropConstraint(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name)
constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
if constraint != nil {
name = constraint.GetName()
name = constraint.Name
} else if chk != nil {
name = chk.Name
}
return m.DB.Exec("ALTER TABLE ? DROP CONSTRAINT ?", clause.Table{Name: table}, clause.Column{Name: name}).Error
})
@ -781,9 +717,11 @@ func (m Migrator) HasConstraint(value interface{}, name string) bool {
var count int64
m.RunWithValue(value, func(stmt *gorm.Statement) error {
currentDatabase := m.DB.Migrator().CurrentDatabase()
constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name)
constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
if constraint != nil {
name = constraint.GetName()
name = constraint.Name
} else if chk != nil {
name = chk.Name
}
return m.DB.Raw(
@ -825,9 +763,6 @@ type BuildIndexOptionsInterface interface {
// CreateIndex create index `name`
func (m Migrator) CreateIndex(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if stmt.Schema == nil {
return errors.New("failed to get schema")
}
if idx := stmt.Schema.LookIndex(name); idx != nil {
opts := m.DB.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt)
values := []interface{}{clause.Column{Name: idx.Name}, m.CurrentTable(stmt), opts}
@ -860,11 +795,9 @@ func (m Migrator) CreateIndex(value interface{}, name string) error {
// DropIndex drop index `name`
func (m Migrator) DropIndex(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if stmt.Schema != nil {
if idx := stmt.Schema.LookIndex(name); idx != nil {
name = idx.Name
}
}
return m.DB.Exec("DROP INDEX ? ON ?", clause.Column{Name: name}, m.CurrentTable(stmt)).Error
})
@ -875,11 +808,9 @@ func (m Migrator) HasIndex(value interface{}, name string) bool {
var count int64
m.RunWithValue(value, func(stmt *gorm.Statement) error {
currentDatabase := m.DB.Migrator().CurrentDatabase()
if stmt.Schema != nil {
if idx := stmt.Schema.LookIndex(name); idx != nil {
name = idx.Name
}
}
return m.DB.Raw(
"SELECT count(*) FROM information_schema.statistics WHERE table_schema = ? AND table_name = ? AND index_name = ?",

View File

@ -3,39 +3,33 @@ package gorm
import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"reflect"
"sync"
"time"
"gorm.io/gorm/internal/stmt_store"
)
type Stmt struct {
*sql.Stmt
Transaction bool
prepared chan struct{}
prepareErr error
}
type PreparedStmtDB struct {
Stmts stmt_store.Store
Stmts map[string]*Stmt
PreparedSQL []string
Mux *sync.RWMutex
ConnPool
}
// NewPreparedStmtDB creates and initializes a new instance of PreparedStmtDB.
//
// Parameters:
// - connPool: A connection pool that implements the ConnPool interface, used for managing database connections.
// - maxSize: The maximum number of prepared statements that can be stored in the statement store.
// - ttl: The time-to-live duration for each prepared statement in the store. Statements older than this duration will be automatically removed.
//
// Returns:
// - A pointer to a PreparedStmtDB instance, which manages prepared statements using the provided connection pool and configuration.
func NewPreparedStmtDB(connPool ConnPool, maxSize int, ttl time.Duration) *PreparedStmtDB {
func NewPreparedStmtDB(connPool ConnPool) *PreparedStmtDB {
return &PreparedStmtDB{
ConnPool: connPool, // Assigns the provided connection pool to manage database connections.
Stmts: stmt_store.New(maxSize, ttl), // Initializes a new statement store with the specified maximum size and TTL.
Mux: &sync.RWMutex{}, // Sets up a read-write mutex for synchronizing access to the statement store.
ConnPool: connPool,
Stmts: make(map[string]*Stmt),
Mux: &sync.RWMutex{},
PreparedSQL: make([]string, 0, 100),
}
}
// GetDBConn returns the underlying *sql.DB connection
func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) {
if sqldb, ok := db.ConnPool.(*sql.DB); ok {
return sqldb, nil
@ -48,41 +42,84 @@ func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) {
return nil, ErrInvalidDB
}
// Close closes all prepared statements in the store
func (db *PreparedStmtDB) Close() {
db.Mux.Lock()
defer db.Mux.Unlock()
for _, key := range db.Stmts.Keys() {
db.Stmts.Delete(key)
for _, query := range db.PreparedSQL {
if stmt, ok := db.Stmts[query]; ok {
delete(db.Stmts, query)
go stmt.Close()
}
}
}
// Reset Deprecated use Close instead
func (db *PreparedStmtDB) Reset() {
db.Close()
func (sdb *PreparedStmtDB) Reset() {
sdb.Mux.Lock()
defer sdb.Mux.Unlock()
for _, stmt := range sdb.Stmts {
go stmt.Close()
}
sdb.PreparedSQL = make([]string, 0, 100)
sdb.Stmts = make(map[string]*Stmt)
}
func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (_ *stmt_store.Stmt, err error) {
func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) {
db.Mux.RLock()
if db.Stmts != nil {
if stmt, ok := db.Stmts.Get(query); ok && (!stmt.Transaction || isTransaction) {
if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) {
db.Mux.RUnlock()
return stmt, stmt.Error()
// wait for other goroutines prepared
<-stmt.prepared
if stmt.prepareErr != nil {
return Stmt{}, stmt.prepareErr
}
return *stmt, nil
}
db.Mux.RUnlock()
// retry
db.Mux.Lock()
if db.Stmts != nil {
if stmt, ok := db.Stmts.Get(query); ok && (!stmt.Transaction || isTransaction) {
// double check
if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) {
db.Mux.Unlock()
return stmt, stmt.Error()
}
// wait for other goroutines prepared
<-stmt.prepared
if stmt.prepareErr != nil {
return Stmt{}, stmt.prepareErr
}
return db.Stmts.New(ctx, query, isTransaction, conn, db.Mux)
return *stmt, nil
}
// cache preparing stmt first
cacheStmt := Stmt{Transaction: isTransaction, prepared: make(chan struct{})}
db.Stmts[query] = &cacheStmt
db.Mux.Unlock()
// prepare completed
defer close(cacheStmt.prepared)
// Reason why cannot lock conn.PrepareContext
// suppose the maxopen is 1, g1 is creating record and g2 is querying record.
// 1. g1 begin tx, g1 is requeue because of waiting for the system call, now `db.ConnPool` db.numOpen == 1.
// 2. g2 select lock `conn.PrepareContext(ctx, query)`, now db.numOpen == db.maxOpen , wait for release.
// 3. g1 tx exec insert, wait for unlock `conn.PrepareContext(ctx, query)` to finish tx and release.
stmt, err := conn.PrepareContext(ctx, query)
if err != nil {
cacheStmt.prepareErr = err
db.Mux.Lock()
delete(db.Stmts, query)
db.Mux.Unlock()
return Stmt{}, err
}
db.Mux.Lock()
cacheStmt.Stmt = stmt
db.PreparedSQL = append(db.PreparedSQL, query)
db.Mux.Unlock()
return cacheStmt, nil
}
func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (ConnPool, error) {
@ -110,8 +147,11 @@ func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args ..
stmt, err := db.prepare(ctx, db.ConnPool, false, query)
if err == nil {
result, err = stmt.ExecContext(ctx, args...)
if errors.Is(err, driver.ErrBadConn) {
db.Stmts.Delete(query)
if err != nil {
db.Mux.Lock()
defer db.Mux.Unlock()
go stmt.Close()
delete(db.Stmts, query)
}
}
return result, err
@ -121,8 +161,12 @@ func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args .
stmt, err := db.prepare(ctx, db.ConnPool, false, query)
if err == nil {
rows, err = stmt.QueryContext(ctx, args...)
if errors.Is(err, driver.ErrBadConn) {
db.Stmts.Delete(query)
if err != nil {
db.Mux.Lock()
defer db.Mux.Unlock()
go stmt.Close()
delete(db.Stmts, query)
}
}
return rows, err
@ -136,14 +180,6 @@ func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string, arg
return &sql.Row{}
}
func (db *PreparedStmtDB) Ping() error {
conn, err := db.GetDBConn()
if err != nil {
return err
}
return conn.Ping()
}
type PreparedStmtTX struct {
Tx
PreparedStmtDB *PreparedStmtDB
@ -171,8 +207,12 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ..
stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
if err == nil {
result, err = tx.Tx.StmtContext(ctx, stmt.Stmt).ExecContext(ctx, args...)
if errors.Is(err, driver.ErrBadConn) {
tx.PreparedStmtDB.Stmts.Delete(query)
if err != nil {
tx.PreparedStmtDB.Mux.Lock()
defer tx.PreparedStmtDB.Mux.Unlock()
go stmt.Close()
delete(tx.PreparedStmtDB.Stmts, query)
}
}
return result, err
@ -182,8 +222,12 @@ func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args .
stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
if err == nil {
rows, err = tx.Tx.StmtContext(ctx, stmt.Stmt).QueryContext(ctx, args...)
if errors.Is(err, driver.ErrBadConn) {
tx.PreparedStmtDB.Stmts.Delete(query)
if err != nil {
tx.PreparedStmtDB.Mux.Lock()
defer tx.PreparedStmtDB.Mux.Unlock()
go stmt.Close()
delete(tx.PreparedStmtDB.Stmts, query)
}
}
return rows, err
@ -196,11 +240,3 @@ func (tx *PreparedStmtTX) QueryRowContext(ctx context.Context, query string, arg
}
return &sql.Row{}
}
func (tx *PreparedStmtTX) Ping() error {
conn, err := tx.GetDBConn()
if err != nil {
return err
}
return conn.Ping()
}

39
scan.go
View File

@ -4,7 +4,6 @@ import (
"database/sql"
"database/sql/driver"
"reflect"
"strings"
"time"
"gorm.io/gorm/schema"
@ -16,7 +15,7 @@ func prepareValues(values []interface{}, db *DB, columnTypes []*sql.ColumnType,
if db.Statement.Schema != nil {
for idx, name := range columns {
if field := db.Statement.Schema.LookUpField(name); field != nil {
values[idx] = reflect.New(reflect.PointerTo(field.FieldType)).Interface()
values[idx] = reflect.New(reflect.PtrTo(field.FieldType)).Interface()
continue
}
values[idx] = new(interface{})
@ -24,7 +23,7 @@ func prepareValues(values []interface{}, db *DB, columnTypes []*sql.ColumnType,
} else if len(columnTypes) > 0 {
for idx, columnType := range columnTypes {
if columnType.ScanType() != nil {
values[idx] = reflect.New(reflect.PointerTo(columnType.ScanType())).Interface()
values[idx] = reflect.New(reflect.PtrTo(columnType.ScanType())).Interface()
} else {
values[idx] = new(interface{})
}
@ -132,15 +131,6 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
onConflictDonothing = mode&ScanOnConflictDoNothing != 0
)
if len(db.Statement.ColumnMapping) > 0 {
for i, column := range columns {
v, ok := db.Statement.ColumnMapping[column]
if ok {
columns[i] = v
}
}
}
db.RowsAffected = 0
switch dest := db.Statement.Dest.(type) {
@ -245,14 +235,6 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
matchedFieldCount[column] = 1
}
} else if names := utils.SplitNestedRelationName(column); len(names) > 1 { // has nested relation
aliasName := utils.JoinNestedRelationNames(names[0 : len(names)-1])
for _, join := range db.Statement.Joins {
if join.Alias == aliasName {
names = append(strings.Split(join.Name, "."), names[len(names)-1])
break
}
}
if rel, ok := sch.Relationships.Relations[names[0]]; ok {
subNameCount := len(names)
// nested relation fields
@ -262,7 +244,7 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
rel = rel.FieldSchema.Relationships.Relations[name]
relFields = append(relFields, rel.Field)
}
// latest name is raw dbname
// lastest name is raw dbname
dbName := names[subNameCount-1]
if field := rel.FieldSchema.LookUpField(dbName); field != nil && field.Readable {
fields[idx] = field
@ -275,11 +257,9 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
continue
}
}
var val interface{}
values[idx] = &val
values[idx] = &sql.RawBytes{}
} else {
var val interface{}
values[idx] = &val
values[idx] = &sql.RawBytes{}
}
}
}
@ -294,18 +274,14 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
if !update || reflectValue.Len() == 0 {
update = false
if isArrayKind {
db.Statement.ReflectValue.Set(reflect.Zero(reflectValue.Type()))
} else {
// if the slice cap is externally initialized, the externally initialized slice is directly used here
if reflectValue.Cap() == 0 {
db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20))
} else {
} else if !isArrayKind {
reflectValue.SetLen(0)
db.Statement.ReflectValue.Set(reflectValue)
}
}
}
for initialized || rows.Next() {
BEGIN:
@ -349,9 +325,6 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
}
case reflect.Struct, reflect.Ptr:
if initialized || rows.Next() {
if mode == ScanInitialized && reflectValue.Kind() == reflect.Struct {
db.Statement.ReflectValue.Set(reflect.Zero(reflectValue.Type()))
}
db.scanIntoStruct(rows, reflectValue, values, fields, joinFields)
}
default:

35
schema/check.go Normal file
View File

@ -0,0 +1,35 @@
package schema
import (
"regexp"
"strings"
)
// reg match english letters and midline
var regEnLetterAndMidline = regexp.MustCompile("^[A-Za-z-_]+$")
type Check struct {
Name string
Constraint string // length(phone) >= 10
*Field
}
// ParseCheckConstraints parse schema check constraints
func (schema *Schema) ParseCheckConstraints() map[string]Check {
checks := map[string]Check{}
for _, field := range schema.FieldsByDBName {
if chk := field.TagSettings["CHECK"]; chk != "" {
names := strings.Split(chk, ",")
if len(names) > 1 && regEnLetterAndMidline.MatchString(names[0]) {
checks[names[0]] = Check{Name: names[0], Constraint: strings.Join(names[1:], ","), Field: field}
} else {
if names[0] == "" {
chk = strings.Join(names[1:], ",")
}
name := schema.namer.CheckerName(schema.Table, field.DBName)
checks[name] = Check{Name: name, Constraint: chk, Field: field}
}
}
}
return checks
}

View File

@ -6,7 +6,6 @@ import (
"testing"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils/tests"
)
type UserCheck struct {
@ -21,7 +20,7 @@ func TestParseCheck(t *testing.T) {
t.Fatalf("failed to parse user check, got error %v", err)
}
results := map[string]schema.CheckConstraint{
results := map[string]schema.Check{
"name_checker": {
Name: "name_checker",
Constraint: "name <> 'jinzhu'",
@ -54,31 +53,3 @@ func TestParseCheck(t *testing.T) {
}
}
}
func TestParseUniqueConstraints(t *testing.T) {
type UserUnique struct {
Name1 string `gorm:"unique"`
Name2 string `gorm:"uniqueIndex"`
}
user, err := schema.Parse(&UserUnique{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("failed to parse user unique, got error %v", err)
}
constraints := user.ParseUniqueConstraints()
results := map[string]schema.UniqueConstraint{
"uni_user_uniques_name1": {
Name: "uni_user_uniques_name1",
Field: &schema.Field{Name: "Name1", Unique: true},
},
}
for k, result := range results {
v, ok := constraints[k]
if !ok {
t.Errorf("Failed to found unique constraint %v from parsed constraints %+v", k, constraints)
}
tests.AssertObjEqual(t, result, v, "Name")
tests.AssertObjEqual(t, result.Field, v.Field, "Name", "Unique", "UniqueIndex")
}
}

View File

@ -1,66 +0,0 @@
package schema
import (
"regexp"
"strings"
"gorm.io/gorm/clause"
)
// reg match english letters and midline
var regEnLetterAndMidline = regexp.MustCompile(`^[\w-]+$`)
type CheckConstraint struct {
Name string
Constraint string // length(phone) >= 10
*Field
}
func (chk *CheckConstraint) GetName() string { return chk.Name }
func (chk *CheckConstraint) Build() (sql string, vars []interface{}) {
return "CONSTRAINT ? CHECK (?)", []interface{}{clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}}
}
// ParseCheckConstraints parse schema check constraints
func (schema *Schema) ParseCheckConstraints() map[string]CheckConstraint {
checks := map[string]CheckConstraint{}
for _, field := range schema.FieldsByDBName {
if chk := field.TagSettings["CHECK"]; chk != "" {
names := strings.Split(chk, ",")
if len(names) > 1 && regEnLetterAndMidline.MatchString(names[0]) {
checks[names[0]] = CheckConstraint{Name: names[0], Constraint: strings.Join(names[1:], ","), Field: field}
} else {
if names[0] == "" {
chk = strings.Join(names[1:], ",")
}
name := schema.namer.CheckerName(schema.Table, field.DBName)
checks[name] = CheckConstraint{Name: name, Constraint: chk, Field: field}
}
}
}
return checks
}
type UniqueConstraint struct {
Name string
Field *Field
}
func (uni *UniqueConstraint) GetName() string { return uni.Name }
func (uni *UniqueConstraint) Build() (sql string, vars []interface{}) {
return "CONSTRAINT ? UNIQUE (?)", []interface{}{clause.Column{Name: uni.Name}, clause.Column{Name: uni.Field.DBName}}
}
// ParseUniqueConstraints parse schema unique constraints
func (schema *Schema) ParseUniqueConstraints() map[string]UniqueConstraint {
uniques := make(map[string]UniqueConstraint)
for _, field := range schema.Fields {
if field.Unique {
name := schema.namer.UniqueName(schema.Table, field.DBName)
uniques[name] = UniqueConstraint{Name: name, Field: field}
}
}
return uniques
}

View File

@ -49,14 +49,11 @@ const (
Bytes DataType = "bytes"
)
const DefaultAutoIncrementIncrement int64 = 1
// Field is the representation of model schema's field
type Field struct {
Name string
DBName string
BindNames []string
EmbeddedBindNames []string
DataType DataType
GORMDataType DataType
PrimaryKey bool
@ -90,12 +87,6 @@ type Field struct {
Set func(context.Context, reflect.Value, interface{}) error
Serializer SerializerInterface
NewValuePool FieldNewValuePool
// In some db (e.g. MySQL), Unique and UniqueIndex are indistinguishable.
// When a column has a (not Mul) UniqueIndex, Migrator always reports its gorm.ColumnType is Unique.
// It causes field unnecessarily migration.
// Therefore, we need to record the UniqueIndex on this column (exclude Mul UniqueIndex) for MigrateColumnUnique.
UniqueIndex string
}
func (field *Field) BindName() string {
@ -113,7 +104,6 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
Name: fieldStruct.Name,
DBName: tagSetting["COLUMN"],
BindNames: []string{fieldStruct.Name},
EmbeddedBindNames: []string{fieldStruct.Name},
FieldType: fieldStruct.Type,
IndirectFieldType: fieldStruct.Type,
StructField: fieldStruct,
@ -129,7 +119,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
NotNull: utils.CheckTruth(tagSetting["NOT NULL"], tagSetting["NOTNULL"]),
Unique: utils.CheckTruth(tagSetting["UNIQUE"]),
Comment: tagSetting["COMMENT"],
AutoIncrementIncrement: DefaultAutoIncrementIncrement,
AutoIncrementIncrement: 1,
}
for field.IndirectFieldType.Kind() == reflect.Ptr {
@ -318,10 +308,9 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
}
if val, ok := field.TagSettings["TYPE"]; ok {
lowerVal := DataType(strings.ToLower(val))
switch lowerVal {
switch DataType(strings.ToLower(val)) {
case Bool, Int, Uint, Float, String, Time, Bytes:
field.DataType = lowerVal
field.DataType = DataType(strings.ToLower(val))
default:
field.DataType = DataType(val)
}
@ -406,9 +395,6 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
ef.Schema = schema
ef.OwnerSchema = field.EmbeddedSchema
ef.BindNames = append([]string{fieldStruct.Name}, ef.BindNames...)
if _, ok := field.TagSettings["EMBEDDED"]; ok || !fieldStruct.Anonymous {
ef.EmbeddedBindNames = append([]string{fieldStruct.Name}, ef.EmbeddedBindNames...)
}
// index is negative means is pointer
if field.FieldType.Kind() == reflect.Struct {
ef.StructField.Index = append([]int{fieldStruct.Index[0]}, ef.StructField.Index...)
@ -448,30 +434,21 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
}
// create valuer, setter when parse struct
func (field *Field) setupValuerAndSetter(modelType reflect.Type) {
func (field *Field) setupValuerAndSetter() {
// Setup NewValuePool
field.setupNewValuePool()
// ValueOf returns field's value and if it is zero
fieldIndex := field.StructField.Index[0]
switch {
case len(field.StructField.Index) == 1 && fieldIndex >= 0:
field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) {
v = reflect.Indirect(v)
if v.Type() != modelType {
fieldValue := v.FieldByName(field.Name)
return fieldValue.Interface(), fieldValue.IsZero()
}
fieldValue := v.Field(fieldIndex)
case len(field.StructField.Index) == 1 && fieldIndex > 0:
field.ValueOf = func(ctx context.Context, value reflect.Value) (interface{}, bool) {
fieldValue := reflect.Indirect(value).Field(fieldIndex)
return fieldValue.Interface(), fieldValue.IsZero()
}
default:
field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) {
v = reflect.Indirect(v)
if v.Type() != modelType {
fieldValue := v.FieldByName(field.Name)
return fieldValue.Interface(), fieldValue.IsZero()
}
for _, fieldIdx := range field.StructField.Index {
if fieldIdx >= 0 {
v = v.Field(fieldIdx)
@ -513,20 +490,13 @@ func (field *Field) setupValuerAndSetter(modelType reflect.Type) {
// ReflectValueOf returns field's reflect value
switch {
case len(field.StructField.Index) == 1 && fieldIndex >= 0:
field.ReflectValueOf = func(ctx context.Context, v reflect.Value) reflect.Value {
v = reflect.Indirect(v)
if v.Type() != modelType {
return v.FieldByName(field.Name)
}
return v.Field(fieldIndex)
case len(field.StructField.Index) == 1 && fieldIndex > 0:
field.ReflectValueOf = func(ctx context.Context, value reflect.Value) reflect.Value {
return reflect.Indirect(value).Field(fieldIndex)
}
default:
field.ReflectValueOf = func(ctx context.Context, v reflect.Value) reflect.Value {
v = reflect.Indirect(v)
if v.Type() != modelType {
return v.FieldByName(field.Name)
}
for idx, fieldIdx := range field.StructField.Index {
if fieldIdx >= 0 {
v = v.Field(fieldIdx)
@ -686,7 +656,7 @@ func (field *Field) setupValuerAndSetter(modelType reflect.Type) {
if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond {
field.ReflectValueOf(ctx, value).SetInt(data.UnixNano())
} else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond {
field.ReflectValueOf(ctx, value).SetInt(data.UnixMilli())
field.ReflectValueOf(ctx, value).SetInt(data.UnixNano() / 1e6)
} else {
field.ReflectValueOf(ctx, value).SetInt(data.Unix())
}
@ -695,7 +665,7 @@ func (field *Field) setupValuerAndSetter(modelType reflect.Type) {
if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond {
field.ReflectValueOf(ctx, value).SetInt(data.UnixNano())
} else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond {
field.ReflectValueOf(ctx, value).SetInt(data.UnixMilli())
field.ReflectValueOf(ctx, value).SetInt(data.UnixNano() / 1e6)
} else {
field.ReflectValueOf(ctx, value).SetInt(data.Unix())
}
@ -760,7 +730,7 @@ func (field *Field) setupValuerAndSetter(modelType reflect.Type) {
if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond {
field.ReflectValueOf(ctx, value).SetUint(uint64(data.UnixNano()))
} else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond {
field.ReflectValueOf(ctx, value).SetUint(uint64(data.UnixMilli()))
field.ReflectValueOf(ctx, value).SetUint(uint64(data.UnixNano() / 1e6))
} else {
field.ReflectValueOf(ctx, value).SetUint(uint64(data.Unix()))
}
@ -1013,6 +983,6 @@ func (field *Field) setupNewValuePool() {
}
if field.NewValuePool == nil {
field.NewValuePool = poolInitializer(reflect.PointerTo(field.IndirectFieldType))
field.NewValuePool = poolInitializer(reflect.PtrTo(field.IndirectFieldType))
}
}

View File

@ -14,7 +14,7 @@ type Index struct {
Where string
Comment string
Option string // WITH PARSER parser_name
Fields []IndexOption // Note: IndexOption's Field maybe the same
Fields []IndexOption
}
type IndexOption struct {
@ -23,13 +23,12 @@ type IndexOption struct {
Sort string // DESC, ASC
Collate string
Length int
Priority int
priority int
}
// ParseIndexes parse schema indexes
func (schema *Schema) ParseIndexes() []*Index {
indexesByName := map[string]*Index{}
indexes := []*Index{}
func (schema *Schema) ParseIndexes() map[string]Index {
indexes := map[string]Index{}
for _, field := range schema.Fields {
if field.TagSettings["INDEX"] != "" || field.TagSettings["UNIQUEINDEX"] != "" {
@ -39,12 +38,7 @@ func (schema *Schema) ParseIndexes() []*Index {
break
}
for _, index := range fieldIndexes {
idx := indexesByName[index.Name]
if idx == nil {
idx = &Index{Name: index.Name}
indexesByName[index.Name] = idx
indexes = append(indexes, idx)
}
idx := indexes[index.Name]
idx.Name = index.Name
if idx.Class == "" {
idx.Class = index.Class
@ -64,14 +58,16 @@ func (schema *Schema) ParseIndexes() []*Index {
idx.Fields = append(idx.Fields, index.Fields...)
sort.Slice(idx.Fields, func(i, j int) bool {
return idx.Fields[i].Priority < idx.Fields[j].Priority
return idx.Fields[i].priority < idx.Fields[j].priority
})
indexes[index.Name] = idx
}
}
}
for _, index := range indexes {
if index.Class == "UNIQUE" && len(index.Fields) == 1 {
index.Fields[0].Field.UniqueIndex = index.Name
index.Fields[0].Field.Unique = true
}
}
return indexes
@ -82,12 +78,12 @@ func (schema *Schema) LookIndex(name string) *Index {
indexes := schema.ParseIndexes()
for _, index := range indexes {
if index.Name == name {
return index
return &index
}
for _, field := range index.Fields {
if field.Name == name {
return index
return &index
}
}
}
@ -105,7 +101,7 @@ func parseFieldIndexes(field *Field) (indexes []Index, err error) {
var (
name string
tag = strings.Join(v[1:], ":")
idx = strings.IndexByte(tag, ',')
idx = strings.Index(tag, ",")
tagSetting = strings.Join(strings.Split(tag, ",")[1:], ",")
settings = ParseTagSetting(tagSetting, ",")
length, _ = strconv.Atoi(settings["LENGTH"])
@ -115,14 +111,17 @@ func parseFieldIndexes(field *Field) (indexes []Index, err error) {
idx = len(tag)
}
if idx != -1 {
name = tag[0:idx]
}
if name == "" {
subName := field.Name
const key = "COMPOSITE"
if composite, found := settings[key]; found {
if len(composite) == 0 || composite == key {
err = fmt.Errorf(
"the composite tag of %s.%s cannot be empty",
"The composite tag of %s.%s cannot be empty",
field.Schema.Name,
field.Name)
return
@ -155,7 +154,7 @@ func parseFieldIndexes(field *Field) (indexes []Index, err error) {
Sort: settings["SORT"],
Collate: settings["COLLATE"],
Length: length,
Priority: priority,
priority: priority,
}},
})
}

View File

@ -1,11 +1,11 @@
package schema_test
import (
"reflect"
"sync"
"testing"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils/tests"
)
type UserIndex struct {
@ -19,10 +19,6 @@ type UserIndex struct {
OID int64 `gorm:"index:idx_id;index:idx_oid,unique"`
MemberNumber string `gorm:"index:idx_id,priority:1"`
Name7 string `gorm:"index:type"`
Name8 string `gorm:"index:,length:10;index:,collate:utf8"`
CompName1 string `gorm:"index:,unique,composite:idx_compname_1,option:NULLS NOT DISTINCT;not null"`
CompName2 string `gorm:"index:,composite:idx_compname_1"`
// Composite Index: Flattened structure.
Data0A string `gorm:"index:,composite:comp_id0"`
@ -61,17 +57,17 @@ func TestParseIndex(t *testing.T) {
t.Fatalf("failed to parse user index, got error %v", err)
}
results := []*schema.Index{
{
results := map[string]schema.Index{
"idx_user_indices_name": {
Name: "idx_user_indices_name",
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name"}}},
},
{
"idx_name": {
Name: "idx_name",
Class: "UNIQUE",
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name2", UniqueIndex: "idx_name"}}},
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name2", Unique: true}}},
},
{
"idx_user_indices_name3": {
Name: "idx_user_indices_name3",
Type: "btree",
Where: "name3 != 'jinzhu'",
@ -82,19 +78,19 @@ func TestParseIndex(t *testing.T) {
Length: 10,
}},
},
{
"idx_user_indices_name4": {
Name: "idx_user_indices_name4",
Class: "UNIQUE",
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name4", UniqueIndex: "idx_user_indices_name4"}}},
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name4", Unique: true}}},
},
{
"idx_user_indices_name5": {
Name: "idx_user_indices_name5",
Class: "FULLTEXT",
Comment: "hello , world",
Where: "age > 10",
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name5"}}},
},
{
"profile": {
Name: "profile",
Comment: "hello , world",
Where: "age > 10",
@ -104,39 +100,21 @@ func TestParseIndex(t *testing.T) {
Expression: "ABS(age)",
}},
},
{
"idx_id": {
Name: "idx_id",
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "MemberNumber"}}, {Field: &schema.Field{Name: "OID", UniqueIndex: "idx_oid"}}},
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "MemberNumber"}}, {Field: &schema.Field{Name: "OID", Unique: true}}},
},
{
"idx_oid": {
Name: "idx_oid",
Class: "UNIQUE",
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "OID", UniqueIndex: "idx_oid"}}},
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "OID", Unique: true}}},
},
{
"type": {
Name: "type",
Type: "",
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name7"}}},
},
{
Name: "idx_user_indices_name8",
Type: "",
Fields: []schema.IndexOption{
{Field: &schema.Field{Name: "Name8"}, Length: 10},
// Note: Duplicate Columns
{Field: &schema.Field{Name: "Name8"}, Collate: "utf8"},
},
},
{
Class: "UNIQUE",
Name: "idx_user_indices_idx_compname_1",
Option: "NULLS NOT DISTINCT",
Fields: []schema.IndexOption{
{Field: &schema.Field{Name: "CompName1", NotNull: true}},
{Field: &schema.Field{Name: "CompName2"}},
},
},
{
"idx_user_indices_comp_id0": {
Name: "idx_user_indices_comp_id0",
Type: "",
Fields: []schema.IndexOption{{
@ -145,7 +123,7 @@ func TestParseIndex(t *testing.T) {
Field: &schema.Field{Name: "Data0B"},
}},
},
{
"idx_user_indices_comp_id1": {
Name: "idx_user_indices_comp_id1",
Fields: []schema.IndexOption{{
Field: &schema.Field{Name: "Data1A"},
@ -155,7 +133,7 @@ func TestParseIndex(t *testing.T) {
Field: &schema.Field{Name: "Data1C"},
}},
},
{
"idx_user_indices_comp_id2": {
Name: "idx_user_indices_comp_id2",
Class: "UNIQUE",
Fields: []schema.IndexOption{{
@ -168,108 +146,40 @@ func TestParseIndex(t *testing.T) {
},
}
CheckIndices(t, results, user.ParseIndexes())
}
indices := user.ParseIndexes()
func TestParseIndexWithUniqueIndexAndUnique(t *testing.T) {
type IndexTest struct {
FieldA string `gorm:"unique;index"` // unique and index
FieldB string `gorm:"unique"` // unique
FieldC string `gorm:"index:,unique"` // uniqueIndex
FieldD string `gorm:"uniqueIndex;index"` // uniqueIndex and index
FieldE1 string `gorm:"uniqueIndex:uniq_field_e1_e2"` // mul uniqueIndex
FieldE2 string `gorm:"uniqueIndex:uniq_field_e1_e2"`
FieldF1 string `gorm:"uniqueIndex:uniq_field_f1_f2;index"` // mul uniqueIndex and index
FieldF2 string `gorm:"uniqueIndex:uniq_field_f1_f2;"`
FieldG string `gorm:"unique;uniqueIndex"` // unique and uniqueIndex
FieldH1 string `gorm:"unique;uniqueIndex:uniq_field_h1_h2"` // unique and mul uniqueIndex
FieldH2 string `gorm:"uniqueIndex:uniq_field_h1_h2"` // unique and mul uniqueIndex
}
indexSchema, err := schema.Parse(&IndexTest{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("failed to parse user index, got error %v", err)
}
indices := indexSchema.ParseIndexes()
expectedIndices := []*schema.Index{
{
Name: "idx_index_tests_field_a",
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "FieldA", Unique: true}}},
},
{
Name: "idx_index_tests_field_c",
Class: "UNIQUE",
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "FieldC", UniqueIndex: "idx_index_tests_field_c"}}},
},
{
Name: "idx_index_tests_field_d",
Class: "UNIQUE",
Fields: []schema.IndexOption{
{Field: &schema.Field{Name: "FieldD"}},
// Note: Duplicate Columns
{Field: &schema.Field{Name: "FieldD"}},
},
},
{
Name: "uniq_field_e1_e2",
Class: "UNIQUE",
Fields: []schema.IndexOption{
{Field: &schema.Field{Name: "FieldE1"}},
{Field: &schema.Field{Name: "FieldE2"}},
},
},
{
Name: "uniq_field_f1_f2",
Class: "UNIQUE",
Fields: []schema.IndexOption{
{Field: &schema.Field{Name: "FieldF1"}},
{Field: &schema.Field{Name: "FieldF2"}},
},
},
{
Name: "idx_index_tests_field_f1",
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "FieldF1"}}},
},
{
Name: "idx_index_tests_field_g",
Class: "UNIQUE",
Fields: []schema.IndexOption{{Field: &schema.Field{Name: "FieldG", Unique: true, UniqueIndex: "idx_index_tests_field_g"}}},
},
{
Name: "uniq_field_h1_h2",
Class: "UNIQUE",
Fields: []schema.IndexOption{
{Field: &schema.Field{Name: "FieldH1", Unique: true}},
{Field: &schema.Field{Name: "FieldH2"}},
},
},
}
CheckIndices(t, expectedIndices, indices)
}
func CheckIndices(t *testing.T, expected, actual []*schema.Index) {
if len(expected) != len(actual) {
t.Errorf("expected %d indices, but got %d", len(expected), len(actual))
return
for k, result := range results {
v, ok := indices[k]
if !ok {
t.Fatalf("Failed to found index %v from parsed indices %+v", k, indices)
}
for i, ei := range expected {
t.Run(ei.Name, func(t *testing.T) {
ai := actual[i]
tests.AssertObjEqual(t, ai, ei, "Name", "Class", "Type", "Where", "Comment", "Option")
for _, name := range []string{"Name", "Class", "Type", "Where", "Comment", "Option"} {
if reflect.ValueOf(result).FieldByName(name).Interface() != reflect.ValueOf(v).FieldByName(name).Interface() {
t.Errorf(
"index %v %v should equal, expects %v, got %v",
k, name, reflect.ValueOf(result).FieldByName(name).Interface(), reflect.ValueOf(v).FieldByName(name).Interface(),
)
}
}
if len(ei.Fields) != len(ai.Fields) {
t.Errorf("expected index %q field length is %d but actual %d", ei.Name, len(ei.Fields), len(ai.Fields))
return
for idx, ef := range result.Fields {
rf := v.Fields[idx]
if rf.Field.Name != ef.Field.Name {
t.Fatalf("index field should equal, expects %v, got %v", rf.Field.Name, ef.Field.Name)
}
if rf.Field.Unique != ef.Field.Unique {
t.Fatalf("index field '%s' should equal, expects %v, got %v", rf.Field.Name, rf.Field.Unique, ef.Field.Unique)
}
for _, name := range []string{"Expression", "Sort", "Collate", "Length"} {
if reflect.ValueOf(ef).FieldByName(name).Interface() != reflect.ValueOf(rf).FieldByName(name).Interface() {
t.Errorf(
"index %v field #%v's %v should equal, expects %v, got %v", k, idx+1, name,
reflect.ValueOf(ef).FieldByName(name).Interface(), reflect.ValueOf(rf).FieldByName(name).Interface(),
)
}
}
for i, ef := range ei.Fields {
af := ai.Fields[i]
tests.AssertObjEqual(t, af, ef, "Name", "Unique", "UniqueIndex", "Expression", "Sort", "Collate", "Length", "NotNull")
}
})
}
}

View File

@ -4,12 +4,6 @@ import (
"gorm.io/gorm/clause"
)
// ConstraintInterface database constraint interface
type ConstraintInterface interface {
GetName() string
Build() (sql string, vars []interface{})
}
// GormDataTypeInterface gorm data type interface
type GormDataTypeInterface interface {
GormDataType() string

View File

@ -8,8 +8,6 @@ import (
"unicode/utf8"
"github.com/jinzhu/inflection"
"golang.org/x/text/cases"
"golang.org/x/text/language"
)
// Namer namer interface
@ -21,7 +19,6 @@ type Namer interface {
RelationshipFKName(Relationship) string
CheckerName(table, column string) string
IndexName(table, column string) string
UniqueName(table, column string) string
}
// Replacer replacer interface like strings.Replacer
@ -29,8 +26,6 @@ type Replacer interface {
Replace(name string) string
}
var _ Namer = (*NamingStrategy)(nil)
// NamingStrategy tables, columns naming strategy
type NamingStrategy struct {
TablePrefix string
@ -90,11 +85,6 @@ func (ns NamingStrategy) IndexName(table, column string) string {
return ns.formatName("idx", table, ns.toDBName(column))
}
// UniqueName generate unique constraint name
func (ns NamingStrategy) UniqueName(table, column string) string {
return ns.formatName("uni", table, ns.toDBName(column))
}
func (ns NamingStrategy) formatName(prefix, table, name string) string {
formattedName := strings.ReplaceAll(strings.Join([]string{
prefix, table, name,
@ -123,7 +113,7 @@ var (
func init() {
commonInitialismsForReplacer := make([]string, 0, len(commonInitialisms))
for _, initialism := range commonInitialisms {
commonInitialismsForReplacer = append(commonInitialismsForReplacer, initialism, cases.Title(language.Und).String(initialism))
commonInitialismsForReplacer = append(commonInitialismsForReplacer, initialism, strings.Title(strings.ToLower(initialism)))
}
commonInitialismsReplacer = strings.NewReplacer(commonInitialismsForReplacer...)
}
@ -188,9 +178,9 @@ func (ns NamingStrategy) toDBName(name string) string {
}
func (ns NamingStrategy) toSchemaName(name string) string {
result := strings.ReplaceAll(cases.Title(language.Und, cases.NoLower).String(strings.ReplaceAll(name, "_", " ")), " ", "")
result := strings.ReplaceAll(strings.Title(strings.ReplaceAll(name, "_", " ")), " ", "")
for _, initialism := range commonInitialisms {
result = regexp.MustCompile(cases.Title(language.Und, cases.NoLower).String(strings.ToLower(initialism))+"([A-Z]|$|_)").ReplaceAllString(result, initialism+"$1")
result = regexp.MustCompile(strings.Title(strings.ToLower(initialism))+"([A-Z]|$|_)").ReplaceAllString(result, initialism+"$1")
}
return result
}

View File

@ -5,12 +5,8 @@ import (
"fmt"
"reflect"
"strings"
"sync"
"github.com/jinzhu/inflection"
"golang.org/x/text/cases"
"golang.org/x/text/language"
"gorm.io/gorm/clause"
)
@ -33,8 +29,6 @@ type Relationships struct {
Relations map[string]*Relationship
EmbeddedRelations map[string]*Relationships
Mux sync.RWMutex
}
type Relationship struct {
@ -78,12 +72,12 @@ func (schema *Schema) parseRelation(field *Field) *Relationship {
cacheStore := schema.cacheStore
if relation.FieldSchema, err = getOrParse(fieldValue, cacheStore, schema.namer); err != nil {
schema.err = fmt.Errorf("failed to parse field: %s, error: %w", field.Name, err)
schema.err = err
return nil
}
if hasPolymorphicRelation(field.TagSettings) {
schema.buildPolymorphicRelation(relation, field)
if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" {
schema.buildPolymorphicRelation(relation, field, polymorphic)
} else if many2many := field.TagSettings["MANY2MANY"]; many2many != "" {
schema.buildMany2ManyRelation(relation, field, many2many)
} else if belongsTo := field.TagSettings["BELONGSTO"]; belongsTo != "" {
@ -95,16 +89,14 @@ func (schema *Schema) parseRelation(field *Field) *Relationship {
case reflect.Slice:
schema.guessRelation(relation, field, guessHas)
default:
schema.err = fmt.Errorf("unsupported data type %v for %v on field %s", relation.FieldSchema, schema,
field.Name)
schema.err = fmt.Errorf("unsupported data type %v for %v on field %s", relation.FieldSchema, schema, field.Name)
}
}
if relation.Type == has {
// don't add relations to embedded schema, which might be shared
if relation.FieldSchema != relation.Schema && relation.Polymorphic == nil && field.OwnerSchema == nil {
relation.FieldSchema.Relationships.Mux.Lock()
relation.FieldSchema.Relationships.Relations["_"+relation.Schema.Name+"_"+relation.Name] = relation
relation.FieldSchema.Relationships.Mux.Unlock()
}
switch field.IndirectFieldType.Kind() {
@ -132,20 +124,6 @@ func (schema *Schema) parseRelation(field *Field) *Relationship {
return relation
}
// hasPolymorphicRelation check if has polymorphic relation
// 1. `POLYMORPHIC` tag
// 2. `POLYMORPHICTYPE` and `POLYMORPHICID` tag
func hasPolymorphicRelation(tagSettings map[string]string) bool {
if _, ok := tagSettings["POLYMORPHIC"]; ok {
return true
}
_, hasType := tagSettings["POLYMORPHICTYPE"]
_, hasId := tagSettings["POLYMORPHICID"]
return hasType && hasId
}
func (schema *Schema) setRelation(relation *Relationship) {
// set non-embedded relation
if rel := schema.Relationships.Relations[relation.Name]; rel != nil {
@ -157,12 +135,12 @@ func (schema *Schema) setRelation(relation *Relationship) {
}
// set embedded relation
if len(relation.Field.EmbeddedBindNames) <= 1 {
if len(relation.Field.BindNames) <= 1 {
return
}
relationships := &schema.Relationships
for i, name := range relation.Field.EmbeddedBindNames {
if i < len(relation.Field.EmbeddedBindNames)-1 {
for i, name := range relation.Field.BindNames {
if i < len(relation.Field.BindNames)-1 {
if relationships.EmbeddedRelations == nil {
relationships.EmbeddedRelations = map[string]*Relationships{}
}
@ -191,41 +169,23 @@ func (schema *Schema) setRelation(relation *Relationship) {
// OwnerID int
// OwnerType string
// }
func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Field) {
polymorphic := field.TagSettings["POLYMORPHIC"]
func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Field, polymorphic string) {
relation.Polymorphic = &Polymorphic{
Value: schema.Table,
PolymorphicType: relation.FieldSchema.FieldsByName[polymorphic+"Type"],
PolymorphicID: relation.FieldSchema.FieldsByName[polymorphic+"ID"],
}
var (
typeName = polymorphic + "Type"
typeId = polymorphic + "ID"
)
if value, ok := field.TagSettings["POLYMORPHICTYPE"]; ok {
typeName = strings.TrimSpace(value)
}
if value, ok := field.TagSettings["POLYMORPHICID"]; ok {
typeId = strings.TrimSpace(value)
}
relation.Polymorphic.PolymorphicType = relation.FieldSchema.FieldsByName[typeName]
relation.Polymorphic.PolymorphicID = relation.FieldSchema.FieldsByName[typeId]
if value, ok := field.TagSettings["POLYMORPHICVALUE"]; ok {
relation.Polymorphic.Value = strings.TrimSpace(value)
}
if relation.Polymorphic.PolymorphicType == nil {
schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s",
relation.FieldSchema, schema, field.Name, polymorphic+"Type")
schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s", relation.FieldSchema, schema, field.Name, polymorphic+"Type")
}
if relation.Polymorphic.PolymorphicID == nil {
schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s",
relation.FieldSchema, schema, field.Name, polymorphic+"ID")
schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s", relation.FieldSchema, schema, field.Name, polymorphic+"ID")
}
if schema.err == nil {
@ -237,14 +197,12 @@ func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Fi
primaryKeyField := schema.PrioritizedPrimaryField
if len(relation.foreignKeys) > 0 {
if primaryKeyField = schema.LookUpField(relation.foreignKeys[0]); primaryKeyField == nil || len(relation.foreignKeys) > 1 {
schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %s", relation.foreignKeys,
schema, field.Name)
schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %s", relation.foreignKeys, schema, field.Name)
}
}
if primaryKeyField == nil {
schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing primaryKey field",
relation.FieldSchema, schema, field.Name)
schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing primaryKey field", relation.FieldSchema, schema, field.Name)
return
}
@ -308,9 +266,9 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
}
for idx, ownField := range ownForeignFields {
joinFieldName := cases.Title(language.Und, cases.NoLower).String(schema.Name) + ownField.Name
joinFieldName := strings.Title(schema.Name) + ownField.Name
if len(joinForeignKeys) > idx {
joinFieldName = cases.Title(language.Und, cases.NoLower).String(joinForeignKeys[idx])
joinFieldName = strings.Title(joinForeignKeys[idx])
}
ownFieldsMap[joinFieldName] = ownField
@ -325,7 +283,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
}
for idx, relField := range refForeignFields {
joinFieldName := cases.Title(language.Und, cases.NoLower).String(relation.FieldSchema.Name) + relField.Name
joinFieldName := strings.Title(relation.FieldSchema.Name) + relField.Name
if _, ok := ownFieldsMap[joinFieldName]; ok {
if field.Name != relation.FieldSchema.Name {
@ -336,7 +294,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
}
if len(joinReferences) > idx {
joinFieldName = cases.Title(language.Und, cases.NoLower).String(joinReferences[idx])
joinFieldName = strings.Title(joinReferences[idx])
}
referFieldsMap[joinFieldName] = relField
@ -354,13 +312,12 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
}
joinTableFields = append(joinTableFields, reflect.StructField{
Name: cases.Title(language.Und, cases.NoLower).String(schema.Name) + field.Name,
Name: strings.Title(schema.Name) + field.Name,
Type: schema.ModelType,
Tag: `gorm:"-"`,
})
if relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore,
schema.namer); err != nil {
if relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, schema.namer); err != nil {
schema.err = err
}
relation.JoinTable.Name = many2many
@ -479,8 +436,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu
schema.guessRelation(relation, field, guessEmbeddedHas)
// case guessEmbeddedHas:
default:
schema.err = fmt.Errorf("invalid field found for struct %v's field %s: define a valid foreign key for relations or implement the Valuer/Scanner interface",
schema, field.Name)
schema.err = fmt.Errorf("invalid field found for struct %v's field %s: define a valid foreign key for relations or implement the Valuer/Scanner interface", schema, field.Name)
}
}
@ -536,9 +492,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu
lookUpNames := []string{lookUpName}
if len(primaryFields) == 1 {
lookUpNames = append(lookUpNames, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID",
strings.TrimSuffix(lookUpName, primaryField.Name)+"Id", schema.namer.ColumnName(foreignSchema.Table,
strings.TrimSuffix(lookUpName, primaryField.Name)+"ID"))
lookUpNames = append(lookUpNames, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID", strings.TrimSuffix(lookUpName, primaryField.Name)+"Id", schema.namer.ColumnName(foreignSchema.Table, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID"))
}
for _, name := range lookUpNames {
@ -612,7 +566,6 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu
}
}
// Constraint is ForeignKey Constraint
type Constraint struct {
Name string
Field *Field
@ -624,31 +577,6 @@ type Constraint struct {
OnUpdate string
}
func (constraint *Constraint) GetName() string { return constraint.Name }
func (constraint *Constraint) Build() (sql string, vars []interface{}) {
sql = "CONSTRAINT ? FOREIGN KEY ? REFERENCES ??"
if constraint.OnDelete != "" {
sql += " ON DELETE " + constraint.OnDelete
}
if constraint.OnUpdate != "" {
sql += " ON UPDATE " + constraint.OnUpdate
}
foreignKeys := make([]interface{}, 0, len(constraint.ForeignKeys))
for _, field := range constraint.ForeignKeys {
foreignKeys = append(foreignKeys, clause.Column{Name: field.DBName})
}
references := make([]interface{}, 0, len(constraint.References))
for _, field := range constraint.References {
references = append(references, clause.Column{Name: field.DBName})
}
vars = append(vars, clause.Table{Name: constraint.Name}, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references)
return
}
func (rel *Relationship) ParseConstraint() *Constraint {
str := rel.Field.TagSettings["CONSTRAINT"]
if str == "-" {
@ -663,7 +591,6 @@ func (rel *Relationship) ParseConstraint() *Constraint {
if !(rel.References[idx].PrimaryKey == ref.PrimaryKey && rel.References[idx].ForeignKey == ref.ForeignKey &&
rel.References[idx].PrimaryValue == ref.PrimaryValue) {
matched = false
break
}
}
@ -676,7 +603,7 @@ func (rel *Relationship) ParseConstraint() *Constraint {
var (
name string
idx = strings.IndexByte(str, ',')
idx = strings.Index(str, ",")
settings = ParseTagSetting(str, ",")
)
@ -763,9 +690,8 @@ func (rel *Relationship) ToQueryConditions(ctx context.Context, reflectValue ref
}
func copyableDataType(str DataType) bool {
lowerStr := strings.ToLower(string(str))
for _, s := range []string{"auto_increment", "primary key"} {
if strings.Contains(lowerStr, s) {
if strings.Contains(strings.ToLower(string(str)), s) {
return false
}
}

View File

@ -121,29 +121,6 @@ func TestSelfReferentialBelongsToOverrideReferences(t *testing.T) {
})
}
func TestBelongsToWithMixin(t *testing.T) {
type Profile struct {
gorm.Model
Refer string
Name string
}
type ProfileMixin struct {
Profile Profile `gorm:"References:Refer"`
ProfileRefer int
}
type User struct {
gorm.Model
ProfileMixin
}
checkStructRelation(t, &User{}, Relation{
Name: "Profile", Type: schema.BelongsTo, Schema: "User", FieldSchema: "Profile",
References: []Reference{{"Refer", "Profile", "ProfileRefer", "User", "", false}},
})
}
func TestHasOneOverrideForeignKey(t *testing.T) {
type Profile struct {
gorm.Model
@ -600,193 +577,6 @@ func TestEmbeddedHas(t *testing.T) {
})
}
func TestPolymorphic(t *testing.T) {
t.Run("has one", func(t *testing.T) {
type Toy struct {
ID int
Name string
OwnerID int
OwnerType string
}
type Cat struct {
ID int
Name string
Toy Toy `gorm:"polymorphic:Owner;"`
}
s, err := schema.Parse(&Cat{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("Failed to parse schema, got error %v", err)
}
checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
"Cat": {
Relations: map[string]Relation{
"Toy": {
Name: "Toy",
Type: schema.HasOne,
Schema: "User",
FieldSchema: "Toy",
Polymorphic: Polymorphic{ID: "OwnerID", Type: "OwnerType", Value: "users"},
References: []Reference{
{ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"},
},
},
},
},
})
})
t.Run("has one with custom polymorphic type and id", func(t *testing.T) {
type Toy struct {
ID int
Name string
RefId int
Type string
}
type Cat struct {
ID int
Name string
Toy Toy `gorm:"polymorphic:Owner;polymorphicType:Type;polymorphicId:RefId"`
}
s, err := schema.Parse(&Cat{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("Failed to parse schema, got error %v", err)
}
checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
"Cat": {
Relations: map[string]Relation{
"Toy": {
Name: "Toy",
Type: schema.HasOne,
Schema: "User",
FieldSchema: "Toy",
Polymorphic: Polymorphic{ID: "ref_id", Type: "Type", Value: "users"},
References: []Reference{
{ForeignKey: "Type", ForeignSchema: "Toy", PrimaryValue: "users"},
},
},
},
},
})
})
t.Run("has one with only polymorphic type", func(t *testing.T) {
type Toy struct {
ID int
Name string
OwnerID int
Type string
}
type Cat struct {
ID int
Name string
Toy Toy `gorm:"polymorphic:Owner;polymorphicType:Type"`
}
s, err := schema.Parse(&Cat{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("Failed to parse schema, got error %v", err)
}
checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
"Cat": {
Relations: map[string]Relation{
"Toy": {
Name: "Toy",
Type: schema.HasOne,
Schema: "User",
FieldSchema: "Toy",
Polymorphic: Polymorphic{ID: "owner_id", Type: "Type", Value: "users"},
References: []Reference{
{ForeignKey: "Type", ForeignSchema: "Toy", PrimaryValue: "users"},
},
},
},
},
})
})
t.Run("has many", func(t *testing.T) {
type Toy struct {
ID int
Name string
OwnerID int
OwnerType string
}
type Cat struct {
ID int
Name string
Toys []Toy `gorm:"polymorphic:Owner;"`
}
s, err := schema.Parse(&Cat{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("Failed to parse schema, got error %v", err)
}
checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
"Cat": {
Relations: map[string]Relation{
"Toys": {
Name: "Toys",
Type: schema.HasMany,
Schema: "User",
FieldSchema: "Toy",
Polymorphic: Polymorphic{ID: "OwnerID", Type: "OwnerType", Value: "users"},
References: []Reference{
{ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"},
},
},
},
},
})
})
t.Run("has many with custom polymorphic type and id", func(t *testing.T) {
type Toy struct {
ID int
Name string
RefId int
Type string
}
type Cat struct {
ID int
Name string
Toys []Toy `gorm:"polymorphicType:Type;polymorphicId:RefId"`
}
s, err := schema.Parse(&Cat{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("Failed to parse schema, got error %v", err)
}
checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
"Cat": {
Relations: map[string]Relation{
"Toys": {
Name: "Toys",
Type: schema.HasMany,
Schema: "User",
FieldSchema: "Toy",
Polymorphic: Polymorphic{ID: "ref_id", Type: "Type", Value: "users"},
References: []Reference{
{ForeignKey: "Type", ForeignSchema: "Toy", PrimaryValue: "users"},
},
},
},
},
})
})
}
func TestEmbeddedBelongsTo(t *testing.T) {
type Country struct {
ID int `gorm:"primaryKey"`
@ -799,10 +589,6 @@ func TestEmbeddedBelongsTo(t *testing.T) {
type NestedAddress struct {
Address
}
type CountryMixin struct {
CountryID int
Country Country
}
type Org struct {
ID int
PostalAddress Address `gorm:"embedded;embeddedPrefix:postal_address_"`
@ -813,7 +599,6 @@ func TestEmbeddedBelongsTo(t *testing.T) {
Address
}
NestedAddress *NestedAddress `gorm:"embedded;embeddedPrefix:nested_address_"`
CountryMixin
}
s, err := schema.Parse(&Org{}, &sync.Map{}, schema.NamingStrategy{})
@ -843,6 +628,8 @@ func TestEmbeddedBelongsTo(t *testing.T) {
},
},
"NestedAddress": {
EmbeddedRelations: map[string]EmbeddedRelations{
"Address": {
Relations: map[string]Relation{
"Country": {
Name: "Country", Type: schema.BelongsTo, Schema: "Org", FieldSchema: "Country",
@ -852,6 +639,8 @@ func TestEmbeddedBelongsTo(t *testing.T) {
},
},
},
},
},
})
}

View File

@ -5,7 +5,6 @@ import (
"errors"
"fmt"
"go/ast"
"path"
"reflect"
"strings"
"sync"
@ -14,20 +13,6 @@ import (
"gorm.io/gorm/logger"
)
type callbackType string
const (
callbackTypeBeforeCreate callbackType = "BeforeCreate"
callbackTypeBeforeUpdate callbackType = "BeforeUpdate"
callbackTypeAfterCreate callbackType = "AfterCreate"
callbackTypeAfterUpdate callbackType = "AfterUpdate"
callbackTypeBeforeSave callbackType = "BeforeSave"
callbackTypeAfterSave callbackType = "AfterSave"
callbackTypeBeforeDelete callbackType = "BeforeDelete"
callbackTypeAfterDelete callbackType = "AfterDelete"
callbackTypeAfterFind callbackType = "AfterFind"
)
// ErrUnsupportedDataType unsupported data type
var ErrUnsupportedDataType = errors.New("unsupported data type")
@ -68,10 +53,9 @@ func (schema Schema) String() string {
}
func (schema Schema) MakeSlice() reflect.Value {
slice := reflect.MakeSlice(reflect.SliceOf(reflect.PointerTo(schema.ModelType)), 0, 20)
slice := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(schema.ModelType)), 0, 20)
results := reflect.New(slice.Type())
results.Elem().Set(slice)
return results
}
@ -248,7 +232,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
schema.FieldsByBindName[bindName] = field
}
field.setupValuerAndSetter(modelType)
field.setupValuerAndSetter()
}
prioritizedPrimaryField := schema.LookUpField("id")
@ -304,26 +288,14 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
}
}
callbackTypes := []callbackType{
callbackTypeBeforeCreate, callbackTypeAfterCreate,
callbackTypeBeforeUpdate, callbackTypeAfterUpdate,
callbackTypeBeforeSave, callbackTypeAfterSave,
callbackTypeBeforeDelete, callbackTypeAfterDelete,
callbackTypeAfterFind,
}
for _, cbName := range callbackTypes {
if methodValue := callBackToMethodValue(modelValue, cbName); methodValue.IsValid() {
callbacks := []string{"BeforeCreate", "AfterCreate", "BeforeUpdate", "AfterUpdate", "BeforeSave", "AfterSave", "BeforeDelete", "AfterDelete", "AfterFind"}
for _, name := range callbacks {
if methodValue := modelValue.MethodByName(name); methodValue.IsValid() {
switch methodValue.Type().String() {
case "func(*gorm.DB) error":
expectedPkgPath := path.Dir(reflect.TypeOf(schema).Elem().PkgPath())
if inVarPkg := methodValue.Type().In(0).Elem().PkgPath(); inVarPkg == expectedPkgPath {
reflect.Indirect(reflect.ValueOf(schema)).FieldByName(string(cbName)).SetBool(true)
} else {
logger.Default.Warn(context.Background(), "In model %v, the hook function `%v(*gorm.DB) error` has an incorrect parameter type. The expected parameter type is `%v`, but the provided type is `%v`.", schema, cbName, expectedPkgPath, inVarPkg)
// PASS
}
case "func(*gorm.DB) error": // TODO hack
reflect.Indirect(reflect.ValueOf(schema)).FieldByName(name).SetBool(true)
default:
logger.Default.Warn(context.Background(), "Model %v don't match %vInterface, should be `%v(*gorm.DB) error`. Please see https://gorm.io/docs/hooks.html", schema, cbName, cbName)
logger.Default.Warn(context.Background(), "Model %v don't match %vInterface, should be `%v(*gorm.DB) error`. Please see https://gorm.io/docs/hooks.html", schema, name, name)
}
}
}
@ -345,7 +317,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
if _, embedded := schema.cacheStore.Load(embeddedCacheKey); !embedded {
for _, field := range schema.Fields {
if field.DataType == "" && field.GORMDataType == "" && (field.Creatable || field.Updatable || field.Readable) {
if field.DataType == "" && (field.Creatable || field.Updatable || field.Readable) {
if schema.parseRelation(field); schema.err != nil {
return schema, schema.err
} else {
@ -377,39 +349,6 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
return schema, schema.err
}
// This unrolling is needed to show to the compiler the exact set of methods
// that can be used on the modelType.
// Prior to go1.22 any use of MethodByName would cause the linker to
// abandon dead code elimination for the entire binary.
// As of go1.22 the compiler supports one special case of a string constant
// being passed to MethodByName. For enterprise customers or those building
// large binaries, this gives a significant reduction in binary size.
// https://github.com/golang/go/issues/62257
func callBackToMethodValue(modelType reflect.Value, cbType callbackType) reflect.Value {
switch cbType {
case callbackTypeBeforeCreate:
return modelType.MethodByName(string(callbackTypeBeforeCreate))
case callbackTypeAfterCreate:
return modelType.MethodByName(string(callbackTypeAfterCreate))
case callbackTypeBeforeUpdate:
return modelType.MethodByName(string(callbackTypeBeforeUpdate))
case callbackTypeAfterUpdate:
return modelType.MethodByName(string(callbackTypeAfterUpdate))
case callbackTypeBeforeSave:
return modelType.MethodByName(string(callbackTypeBeforeSave))
case callbackTypeAfterSave:
return modelType.MethodByName(string(callbackTypeAfterSave))
case callbackTypeBeforeDelete:
return modelType.MethodByName(string(callbackTypeBeforeDelete))
case callbackTypeAfterDelete:
return modelType.MethodByName(string(callbackTypeAfterDelete))
case callbackTypeAfterFind:
return modelType.MethodByName(string(callbackTypeAfterFind))
default:
return reflect.ValueOf(nil)
}
}
func getOrParse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) {
modelType := reflect.ValueOf(dest).Type()
for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr {

View File

@ -163,8 +163,8 @@ func checkSchemaRelation(t *testing.T, s *schema.Schema, relation Relation) {
t.Errorf("schema %v relation's join table tablename expects %v, but got %v", s, relation.JoinTable.Table, r.JoinTable.Table)
}
for i := range relation.JoinTable.Fields {
checkSchemaField(t, r.JoinTable, &relation.JoinTable.Fields[i], nil)
for _, f := range relation.JoinTable.Fields {
checkSchemaField(t, r.JoinTable, &f, nil)
}
}

View File

@ -19,22 +19,6 @@ func TestParseSchema(t *testing.T) {
checkUserSchema(t, user)
}
func TestParseSchemaWithMap(t *testing.T) {
type User struct {
tests.User
Attrs map[string]string `gorm:"type:Map(String,String);"`
}
user, err := schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("failed to parse user with map, got error %v", err)
}
if field := user.FieldsByName["Attrs"]; field.DataType != "Map(String,String)" {
t.Errorf("failed to parse user field Attrs")
}
}
func TestParseSchemaWithPointerFields(t *testing.T) {
user, err := schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
@ -62,8 +46,8 @@ func checkUserSchema(t *testing.T, user *schema.Schema) {
{Name: "Active", DBName: "active", BindNames: []string{"Active"}, DataType: schema.Bool},
}
for i := range fields {
checkSchemaField(t, user, &fields[i], func(f *schema.Field) {
for _, f := range fields {
checkSchemaField(t, user, &f, func(f *schema.Field) {
f.Creatable = true
f.Updatable = true
f.Readable = true
@ -152,8 +136,8 @@ func TestParseSchemaWithAdvancedDataType(t *testing.T) {
{Name: "Admin", DBName: "admin", BindNames: []string{"Admin"}, DataType: schema.Bool},
}
for i := range fields {
checkSchemaField(t, user, &fields[i], func(f *schema.Field) {
for _, f := range fields {
checkSchemaField(t, user, &f, func(f *schema.Field) {
f.Creatable = true
f.Updatable = true
f.Readable = true

View File

@ -84,10 +84,7 @@ func (JSONSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value,
case string:
bytes = []byte(v)
default:
bytes, err = json.Marshal(v)
if err != nil {
return err
}
return fmt.Errorf("failed to unmarshal JSONB value: %#v", dbValue)
}
if len(bytes) > 0 {
@ -129,12 +126,12 @@ func (UnixSecondSerializer) Value(ctx context.Context, field *Field, dst reflect
rv := reflect.ValueOf(fieldValue)
switch v := fieldValue.(type) {
case int64, int, uint, uint64, int32, uint32, int16, uint16:
result = time.Unix(reflect.Indirect(rv).Int(), 0).UTC()
result = time.Unix(reflect.Indirect(rv).Int(), 0)
case *int64, *int, *uint, *uint64, *int32, *uint32, *int16, *uint16:
if rv.IsZero() {
return nil, nil
}
result = time.Unix(reflect.Indirect(rv).Int(), 0).UTC()
result = time.Unix(reflect.Indirect(rv).Int(), 0)
default:
err = fmt.Errorf("invalid field type %#v for UnixSecondSerializer, only int, uint supported", v)
}

View File

@ -71,7 +71,7 @@ func appendSettingFromTag(tag reflect.StructTag, value string) reflect.StructTag
// GetRelationsValues get relations's values from a reflect value
func GetRelationsValues(ctx context.Context, reflectValue reflect.Value, rels []*Relationship) (reflectResults reflect.Value) {
for _, rel := range rels {
reflectResults = reflect.MakeSlice(reflect.SliceOf(reflect.PointerTo(rel.FieldSchema.ModelType)), 0, 1)
reflectResults = reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(rel.FieldSchema.ModelType)), 0, 1)
appendToResults := func(value reflect.Value) {
if _, isZero := rel.Field.ValueOf(ctx, value); !isZero {

View File

@ -32,7 +32,6 @@ type Statement struct {
Distinct bool
Selects []string // selected columns
Omits []string // omit columns
ColumnMapping map[string]string // map columns
Joins []join
Preloads map[string][]interface{}
Settings sync.Map
@ -47,17 +46,14 @@ type Statement struct {
attrs []interface{}
assigns []interface{}
scopes []func(*DB) *DB
Result *result
}
type join struct {
Name string
Alias string
Conds []interface{}
On *clause.Where
Selects []string
Omits []string
Expression clause.Expression
JoinType clause.JoinType
}
@ -208,21 +204,19 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) {
} else {
writer.WriteString("(NULL)")
}
case interface{ getInstance() *DB }:
cv := v.getInstance()
subdb := cv.Session(&Session{Logger: logger.Discard, DryRun: true}).getInstance()
if cv.Statement.SQL.Len() > 0 {
case *DB:
subdb := v.Session(&Session{Logger: logger.Discard, DryRun: true}).getInstance()
if v.Statement.SQL.Len() > 0 {
var (
vars = subdb.Statement.Vars
sql = cv.Statement.SQL.String()
sql = v.Statement.SQL.String()
)
subdb.Statement.Vars = make([]interface{}, 0, len(vars))
for _, vv := range vars {
subdb.Statement.Vars = append(subdb.Statement.Vars, vv)
bindvar := strings.Builder{}
cv.BindVarTo(&bindvar, subdb.Statement, vv)
v.Dialector.BindVarTo(&bindvar, subdb.Statement, vv)
sql = strings.Replace(sql, bindvar.String(), "?", 1)
}
@ -326,30 +320,27 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
arg, _ = valuer.Value()
}
curTable := stmt.Table
if curTable == "" {
curTable = clause.CurrentTable
}
switch v := arg.(type) {
case clause.Expression:
conds = append(conds, v)
case *DB:
v.executeScopes()
if cs, ok := v.Statement.Clauses["WHERE"]; ok {
if cs, ok := v.Statement.Clauses["WHERE"]; ok && cs.Expression != nil {
if where, ok := cs.Expression.(clause.Where); ok {
if len(where.Exprs) == 1 {
if orConds, ok := where.Exprs[0].(clause.OrConditions); ok {
if len(orConds.Exprs) == 1 {
where.Exprs[0] = clause.AndConditions(orConds)
}
}
}
conds = append(conds, clause.And(where.Exprs...))
} else if cs.Expression != nil {
} else {
conds = append(conds, cs.Expression)
}
if v.Statement == stmt {
cs.Expression = nil
stmt.Statement.Clauses["WHERE"] = cs
}
}
case map[interface{}]interface{}:
for i, j := range v {
@ -363,11 +354,7 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
sort.Strings(keys)
for _, key := range keys {
column := clause.Column{Name: key, Table: curTable}
if strings.Contains(key, ".") {
column = clause.Column{Name: key}
}
conds = append(conds, clause.Eq{Column: column, Value: v[key]})
conds = append(conds, clause.Eq{Column: key, Value: v[key]})
}
case map[string]interface{}:
keys := make([]string, 0, len(v))
@ -378,16 +365,12 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
for _, key := range keys {
reflectValue := reflect.Indirect(reflect.ValueOf(v[key]))
column := clause.Column{Name: key, Table: curTable}
if strings.Contains(key, ".") {
column = clause.Column{Name: key}
}
switch reflectValue.Kind() {
case reflect.Slice, reflect.Array:
if _, ok := v[key].(driver.Valuer); ok {
conds = append(conds, clause.Eq{Column: column, Value: v[key]})
conds = append(conds, clause.Eq{Column: key, Value: v[key]})
} else if _, ok := v[key].(Valuer); ok {
conds = append(conds, clause.Eq{Column: column, Value: v[key]})
conds = append(conds, clause.Eq{Column: key, Value: v[key]})
} else {
// optimize reflect value length
valueLen := reflectValue.Len()
@ -396,10 +379,10 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
values[i] = reflectValue.Index(i).Interface()
}
conds = append(conds, clause.IN{Column: column, Values: values})
conds = append(conds, clause.IN{Column: key, Values: values})
}
default:
conds = append(conds, clause.Eq{Column: column, Value: v[key]})
conds = append(conds, clause.Eq{Column: key, Value: v[key]})
}
}
default:
@ -426,9 +409,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
if selected || (!restricted && field.Readable) {
if v, isZero := field.ValueOf(stmt.Context, reflectValue); !isZero || selected {
if field.DBName != "" {
conds = append(conds, clause.Eq{Column: clause.Column{Table: curTable, Name: field.DBName}, Value: v})
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v})
} else if field.DataType != "" {
conds = append(conds, clause.Eq{Column: clause.Column{Table: curTable, Name: field.Name}, Value: v})
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v})
}
}
}
@ -440,9 +423,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
if selected || (!restricted && field.Readable) {
if v, isZero := field.ValueOf(stmt.Context, reflectValue.Index(i)); !isZero || selected {
if field.DBName != "" {
conds = append(conds, clause.Eq{Column: clause.Column{Table: curTable, Name: field.DBName}, Value: v})
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v})
} else if field.DataType != "" {
conds = append(conds, clause.Eq{Column: clause.Column{Table: curTable, Name: field.Name}, Value: v})
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v})
}
}
}
@ -467,22 +450,18 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
}
if len(values) > 0 {
conds = append(conds, clause.IN{Column: clause.Column{Table: curTable, Name: clause.PrimaryKey}, Values: values})
return []clause.Expression{clause.And(conds...)}
conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: values})
}
return nil
return conds
}
}
conds = append(conds, clause.IN{Column: clause.Column{Table: curTable, Name: clause.PrimaryKey}, Values: args})
conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: args})
}
}
}
if len(conds) > 0 {
return []clause.Expression{clause.And(conds...)}
}
return nil
return conds
}
// Build build sql with clauses names
@ -534,14 +513,12 @@ func (stmt *Statement) clone() *Statement {
Distinct: stmt.Distinct,
Selects: stmt.Selects,
Omits: stmt.Omits,
ColumnMapping: stmt.ColumnMapping,
Preloads: map[string][]interface{}{},
ConnPool: stmt.ConnPool,
Schema: stmt.Schema,
Context: stmt.Context,
RaiseErrorOnNotFound: stmt.RaiseErrorOnNotFound,
SkipHooks: stmt.SkipHooks,
Result: stmt.Result,
}
if stmt.SQL.Len() > 0 {
@ -688,21 +665,7 @@ func (stmt *Statement) Changed(fields ...string) bool {
return false
}
var matchName = func() func(tableColumn string) (table, column string) {
nameMatcher := regexp.MustCompile(`^(?:\W?(\w+?)\W?\.)?(?:(\*)|\W?(\w+?)\W?)$`)
return func(tableColumn string) (table, column string) {
if matches := nameMatcher.FindStringSubmatch(tableColumn); len(matches) == 4 {
table = matches[1]
star := matches[2]
columnName := matches[3]
if star != "" {
return table, star
}
return table, columnName
}
return "", ""
}
}()
var nameMatcher = regexp.MustCompile(`^(?:\W?(\w+?)\W?\.)?\W?(\w+?)\W?$`)
// SelectAndOmitColumns get select and omit columns, select -> true, omit -> false
func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (map[string]bool, bool) {
@ -723,13 +686,13 @@ func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (
}
} else if field := stmt.Schema.LookUpField(column); field != nil && field.DBName != "" {
results[field.DBName] = result
} else if table, col := matchName(column); col != "" && (table == stmt.Table || table == "") {
if col == "*" {
} else if matches := nameMatcher.FindStringSubmatch(column); len(matches) == 3 && (matches[1] == stmt.Table || matches[1] == "") {
if matches[2] == "*" {
for _, dbName := range stmt.Schema.DBNames {
results[dbName] = result
}
} else {
results[col] = result
results[matches[2]] = result
}
} else {
results[column] = result

View File

@ -56,15 +56,9 @@ func TestNameMatcher(t *testing.T) {
"`name_1`": {"", "name_1"},
"`Name_1`": {"", "Name_1"},
"`Table`.`nAme`": {"Table", "nAme"},
"my_table.*": {"my_table", "*"},
"`my_table`.*": {"my_table", "*"},
"User__Company.*": {"User__Company", "*"},
"`User__Company`.*": {"User__Company", "*"},
`"User__Company".*`: {"User__Company", "*"},
`"table"."*"`: {"", ""},
} {
if table, column := matchName(k); table != v[0] || column != v[1] {
t.Errorf("failed to match value: %v, got %v, expect: %v", k, []string{table, column}, v)
if matches := nameMatcher.FindStringSubmatch(k); len(matches) < 3 || matches[1] != v[0] || matches[2] != v[1] {
t.Errorf("failed to match value: %v, got %v, expect: %v", k, matches, v)
}
}
}

View File

@ -278,6 +278,8 @@ func TestBelongsToAssociationUnscoped(t *testing.T) {
t.Fatalf("failed to create items, got error: %v", err)
}
tx = tx.Debug()
// test replace
if err := tx.Model(&item).Association("ItemParent").Unscoped().Replace(&ItemParent{
Logo: "updated logo",

View File

@ -422,7 +422,7 @@ func TestPolymorphicHasManyAssociation(t *testing.T) {
func TestPolymorphicHasManyAssociationForSlice(t *testing.T) {
users := []User{
*GetUser("slice-hasmany-1", Config{Toys: 2}),
*GetUser("slice-hasmany-2", Config{Toys: 0, Tools: 2}),
*GetUser("slice-hasmany-2", Config{Toys: 0}),
*GetUser("slice-hasmany-3", Config{Toys: 4}),
}
@ -430,7 +430,6 @@ func TestPolymorphicHasManyAssociationForSlice(t *testing.T) {
// Count
AssertAssociationCount(t, users, "Toys", 6, "")
AssertAssociationCount(t, users, "Tools", 2, "")
// Find
var toys []Toy
@ -438,14 +437,6 @@ func TestPolymorphicHasManyAssociationForSlice(t *testing.T) {
t.Errorf("toys count should be %v, but got %v", 6, len(toys))
}
// Find Tools (polymorphic with custom type and id)
var tools []Tools
DB.Model(&users).Association("Tools").Find(&tools)
if len(tools) != 2 {
t.Errorf("tools count should be %v, but got %v", 2, len(tools))
}
// Append
DB.Model(&users).Association("Toys").Append(
&Toy{Name: "toy-slice-append-1"},
@ -554,15 +545,3 @@ func TestHasManyAssociationUnscoped(t *testing.T) {
t.Errorf("expected %d contents, got %d", 0, len(contents))
}
}
func TestHasManyAssociationReplaceWithNonValidValue(t *testing.T) {
user := User{Name: "jinzhu", Languages: []Language{{Name: "EN"}}}
if err := DB.Create(&user).Error; err != nil {
t.Fatalf("errors happened when create: %v", err)
}
if err := DB.Model(&user).Association("Languages").Replace(Language{Name: "DE"}, Language{Name: "FR"}); err == nil {
t.Error("expected association error to be not nil")
}
}

View File

@ -255,15 +255,3 @@ func TestPolymorphicHasOneAssociationForSlice(t *testing.T) {
DB.Model(&pets).Association("Toy").Clear()
AssertAssociationCount(t, pets, "Toy", 0, "After Clear")
}
func TestHasOneAssociationReplaceWithNonValidValue(t *testing.T) {
user := User{Name: "jinzhu", Account: Account{Number: "1"}}
if err := DB.Create(&user).Error; err != nil {
t.Fatalf("errors happened when create: %v", err)
}
if err := DB.Model(&user).Association("Languages").Replace(Account{Number: "2"}); err == nil {
t.Error("expected association error to be not nil")
}
}

View File

@ -91,7 +91,7 @@ func TestCallbacks(t *testing.T) {
},
{
callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3}, {h: c4}, {h: c5}, {h: c2, remove: true}},
results: []string{"c1", "c3", "c4", "c5"},
results: []string{"c1", "c5", "c3", "c4"},
},
{
callbacks: []callback{{h: c1}, {name: "c", h: c2}, {h: c3}, {name: "c", h: c4, replace: true}},
@ -206,49 +206,3 @@ func TestPluginCallbacks(t *testing.T) {
t.Errorf("callbacks tests failed, got %v", msg)
}
}
func TestCallbacksGet(t *testing.T) {
db, _ := gorm.Open(nil, nil)
createCallback := db.Callback().Create()
createCallback.Before("*").Register("c1", c1)
if cb := createCallback.Get("c1"); reflect.DeepEqual(cb, c1) {
t.Errorf("callbacks tests failed, got: %p, want: %p", cb, c1)
}
createCallback.Remove("c1")
if cb := createCallback.Get("c2"); cb != nil {
t.Errorf("callbacks test failed. got: %p, want: nil", cb)
}
}
func TestCallbacksRemove(t *testing.T) {
db, _ := gorm.Open(nil, nil)
createCallback := db.Callback().Create()
createCallback.Before("*").Register("c1", c1)
createCallback.After("*").Register("c2", c2)
createCallback.Before("c4").Register("c3", c3)
createCallback.After("c2").Register("c4", c4)
// callbacks: []string{"c1", "c3", "c4", "c2"}
createCallback.Remove("c1")
if ok, msg := assertCallbacks(createCallback, []string{"c3", "c4", "c2"}); !ok {
t.Errorf("callbacks tests failed, got %v", msg)
}
createCallback.Remove("c4")
if ok, msg := assertCallbacks(createCallback, []string{"c3", "c2"}); !ok {
t.Errorf("callbacks tests failed, got %v", msg)
}
createCallback.Remove("c2")
if ok, msg := assertCallbacks(createCallback, []string{"c3"}); !ok {
t.Errorf("callbacks tests failed, got %v", msg)
}
createCallback.Remove("c3")
if ok, msg := assertCallbacks(createCallback, []string{}); !ok {
t.Errorf("callbacks tests failed, got %v", msg)
}
}

View File

@ -1,88 +0,0 @@
package tests_test
import (
"fmt"
"strings"
"testing"
"gorm.io/gorm"
)
type Man struct {
ID int
Age int
Name string
Detail string
}
// Panic-safe BeforeUpdate hook that checks for Changed("age")
func (m *Man) BeforeUpdate(tx *gorm.DB) (err error) {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("panic in BeforeUpdate: %v", r)
}
}()
if !tx.Statement.Changed("age") {
return nil
}
return nil
}
func (m *Man) update(data interface{}) error {
return DB.Set("data", data).Model(m).Where("id = ?", m.ID).Updates(data).Error
}
func TestBeforeUpdateStatementChanged(t *testing.T) {
DB.AutoMigrate(&Man{})
type TestCase struct {
BaseObjects Man
change interface{}
expectError bool
}
testCases := []TestCase{
{
BaseObjects: Man{ID: 1, Age: 18, Name: "random-name"},
change: struct {
Age int
}{Age: 20},
expectError: false,
},
{
BaseObjects: Man{ID: 2, Age: 18, Name: "random-name"},
change: struct {
Name string
}{Name: "name-only"},
expectError: true,
},
{
BaseObjects: Man{ID: 2, Age: 18, Name: "random-name"},
change: struct {
Name string
Age int
}{Name: "name-only", Age: 20},
expectError: false,
},
}
for _, test := range testCases {
DB.Create(&test.BaseObjects)
// below comment is stored for future reference
// err := DB.Set("data", test.change).Model(&test.BaseObjects).Where("id = ?", test.BaseObjects.ID).Updates(test.change).Error
err := test.BaseObjects.update(test.change)
if strings.Contains(fmt.Sprint(err), "panic in BeforeUpdate") {
if !test.expectError {
t.Errorf("unexpected panic in BeforeUpdate for input: %+v\nerror: %v", test.change, err)
}
} else {
if test.expectError {
t.Errorf("expected panic did not occur for input: %+v", test.change)
}
if err != nil {
t.Errorf("unexpected GORM error: %v", err)
}
}
}
}

View File

@ -102,13 +102,13 @@ func TestConnPoolWrapper(t *testing.T) {
expect: []string{
"SELECT VERSION()",
"INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)",
"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT ?",
"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1",
"INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)",
"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT ?",
"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT ?",
"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1",
"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1",
"INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)",
"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT ?",
"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT ?",
"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1",
"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1",
},
}
@ -119,7 +119,6 @@ func TestConnPoolWrapper(t *testing.T) {
}()
db, err := gorm.Open(mysql.New(mysql.Config{Conn: conn, DisableWithReturning: true}))
db.Logger = DB.Logger
if err != nil {
t.Fatalf("Should open db success, but got %v", err)
}

View File

@ -29,7 +29,7 @@ func TestCountWithGroup(t *testing.T) {
}
var count2 int64
if err := DB.Model(&Company{}).Where("name in ?", []string{"company_count_group_b", "company_count_group_c"}).Group("name").Count(&count2).Error; err != nil {
if err := DB.Debug().Model(&Company{}).Where("name in ?", []string{"company_count_group_b", "company_count_group_c"}).Group("name").Count(&count2).Error; err != nil {
t.Errorf(fmt.Sprintf("Count should work, but got err %v", err))
}
if count2 != 2 {

View File

@ -2,7 +2,6 @@ package tests_test
import (
"errors"
"fmt"
"regexp"
"testing"
"time"
@ -14,48 +13,31 @@ import (
)
func TestCreate(t *testing.T) {
u1 := *GetUser("create", Config{})
user := *GetUser("create", Config{})
if results := DB.Create(&u1); results.Error != nil {
if results := DB.Create(&user); results.Error != nil {
t.Fatalf("errors happened when create: %v", results.Error)
} else if results.RowsAffected != 1 {
t.Fatalf("rows affected expects: %v, got %v", 1, results.RowsAffected)
}
if u1.ID == 0 {
t.Errorf("user's primary key should has value after create, got : %v", u1.ID)
if user.ID == 0 {
t.Errorf("user's primary key should has value after create, got : %v", user.ID)
}
if u1.CreatedAt.IsZero() {
if user.CreatedAt.IsZero() {
t.Errorf("user's created at should be not zero")
}
if u1.UpdatedAt.IsZero() {
if user.UpdatedAt.IsZero() {
t.Errorf("user's updated at should be not zero")
}
var newUser User
if err := DB.Where("id = ?", u1.ID).First(&newUser).Error; err != nil {
if err := DB.Where("id = ?", user.ID).First(&newUser).Error; err != nil {
t.Fatalf("errors happened when query: %v", err)
} else {
CheckUser(t, newUser, u1)
}
type user struct {
ID int `gorm:"primaryKey;->:false"`
Name string
Age int
}
var u2 user
if results := DB.Create(&u2); results.Error != nil {
t.Fatalf("errors happened when create: %v", results.Error)
} else if results.RowsAffected != 1 {
t.Fatalf("rows affected expects: %v, got %v", 1, results.RowsAffected)
}
if u2.ID != 0 {
t.Errorf("don't have the permission to read primary key from db, but got %v", u2.ID)
CheckUser(t, newUser, user)
}
}
@ -598,213 +580,38 @@ func TestCreateWithAutoIncrementCompositeKey(t *testing.T) {
}
}
func TestCreateOnConflictWithDefaultNull(t *testing.T) {
type OnConflictUser struct {
func TestCreateOnConfilctWithDefalutNull(t *testing.T) {
type OnConfilctUser struct {
ID string
Name string `gorm:"default:null"`
Email string
Mobile string `gorm:"default:'133xxxx'"`
}
err := DB.Migrator().DropTable(&OnConflictUser{})
err := DB.Migrator().DropTable(&OnConfilctUser{})
AssertEqual(t, err, nil)
err = DB.AutoMigrate(&OnConflictUser{})
err = DB.AutoMigrate(&OnConfilctUser{})
AssertEqual(t, err, nil)
u := OnConflictUser{
ID: "on-conflict-user-id",
Name: "on-conflict-user-name",
Email: "on-conflict-user-email",
Mobile: "on-conflict-user-mobile",
u := OnConfilctUser{
ID: "on-confilct-user-id",
Name: "on-confilct-user-name",
Email: "on-confilct-user-email",
Mobile: "on-confilct-user-mobile",
}
err = DB.Create(&u).Error
AssertEqual(t, err, nil)
u.Name = "on-conflict-user-name-2"
u.Email = "on-conflict-user-email-2"
u.Name = "on-confilct-user-name-2"
u.Email = "on-confilct-user-email-2"
u.Mobile = ""
err = DB.Clauses(clause.OnConflict{UpdateAll: true}).Create(&u).Error
AssertEqual(t, err, nil)
var u2 OnConflictUser
var u2 OnConfilctUser
err = DB.Where("id = ?", u.ID).First(&u2).Error
AssertEqual(t, err, nil)
AssertEqual(t, u2.Name, "on-conflict-user-name-2")
AssertEqual(t, u2.Email, "on-conflict-user-email-2")
AssertEqual(t, u2.Name, "on-confilct-user-name-2")
AssertEqual(t, u2.Email, "on-confilct-user-email-2")
AssertEqual(t, u2.Mobile, "133xxxx")
}
func TestCreateFromMapWithoutPK(t *testing.T) {
if !isMysql() {
t.Skipf("This test case skipped, because of only supporting for mysql")
}
// case 1: one record, create from map[string]interface{}
mapValue1 := map[string]interface{}{"name": "create_from_map_with_schema1", "age": 1}
if err := DB.Model(&User{}).Create(mapValue1).Error; err != nil {
t.Fatalf("failed to create data from map, got error: %v", err)
}
if _, ok := mapValue1["id"]; !ok {
t.Fatal("failed to create data from map with table, returning map has no primary key")
}
var result1 User
if err := DB.Where("name = ?", "create_from_map_with_schema1").First(&result1).Error; err != nil || result1.Age != 1 {
t.Fatalf("failed to create from map, got error %v", err)
}
var idVal int64
_, ok := mapValue1["id"].(uint)
if ok {
t.Skipf("This test case skipped, because the db supports returning")
}
idVal, ok = mapValue1["id"].(int64)
if !ok {
t.Fatal("ret result missing id")
}
if int64(result1.ID) != idVal {
t.Fatal("failed to create data from map with table, @id != id")
}
// case2: one record, create from *map[string]interface{}
mapValue2 := map[string]interface{}{"name": "create_from_map_with_schema2", "age": 1}
if err := DB.Model(&User{}).Create(&mapValue2).Error; err != nil {
t.Fatalf("failed to create data from map, got error: %v", err)
}
if _, ok := mapValue2["id"]; !ok {
t.Fatal("failed to create data from map with table, returning map has no primary key")
}
var result2 User
if err := DB.Where("name = ?", "create_from_map_with_schema2").First(&result2).Error; err != nil || result2.Age != 1 {
t.Fatalf("failed to create from map, got error %v", err)
}
_, ok = mapValue2["id"].(uint)
if ok {
t.Skipf("This test case skipped, because the db supports returning")
}
idVal, ok = mapValue2["id"].(int64)
if !ok {
t.Fatal("ret result missing id")
}
if int64(result2.ID) != idVal {
t.Fatal("failed to create data from map with table, @id != id")
}
// case 3: records
values := []map[string]interface{}{
{"name": "create_from_map_with_schema11", "age": 1}, {"name": "create_from_map_with_schema12", "age": 1},
}
beforeLen := len(values)
if err := DB.Model(&User{}).Create(&values).Error; err != nil {
t.Fatalf("failed to create data from map, got error: %v", err)
}
// mariadb with returning, values will be appended with id map
if len(values) == beforeLen*2 {
t.Skipf("This test case skipped, because the db supports returning")
}
for i := range values {
v, ok := values[i]["id"]
if !ok {
t.Fatal("failed to create data from map with table, returning map has no primary key")
}
var result User
if err := DB.Where("name = ?", fmt.Sprintf("create_from_map_with_schema1%d", i+1)).First(&result).Error; err != nil || result.Age != 1 {
t.Fatalf("failed to create from map, got error %v", err)
}
if int64(result.ID) != v.(int64) {
t.Fatal("failed to create data from map with table, @id != id")
}
}
}
func TestCreateFromMapWithTable(t *testing.T) {
tableDB := DB.Table("users")
supportLastInsertID := isMysql() || isSqlite()
// case 1: create from map[string]interface{}
record := map[string]interface{}{"name": "create_from_map_with_table", "age": 18}
if err := tableDB.Create(record).Error; err != nil {
t.Fatalf("failed to create data from map with table, got error: %v", err)
}
if _, ok := record["@id"]; !ok && supportLastInsertID {
t.Fatal("failed to create data from map with table, returning map has no key '@id'")
}
var res map[string]interface{}
if err := tableDB.Select([]string{"id", "name", "age"}).Where("name = ?", "create_from_map_with_table").Find(&res).Error; err != nil || res["age"] != int64(18) {
t.Fatalf("failed to create from map, got error %v", err)
}
if _, ok := record["@id"]; ok && fmt.Sprint(res["id"]) != fmt.Sprint(record["@id"]) {
t.Fatalf("failed to create data from map with table, @id != id, got %v, expect %v", res["id"], record["@id"])
}
// case 2: create from *map[string]interface{}
record1 := map[string]interface{}{"name": "create_from_map_with_table_1", "age": 18}
tableDB2 := DB.Table("users")
if err := tableDB2.Create(&record1).Error; err != nil {
t.Fatalf("failed to create data from map, got error: %v", err)
}
if _, ok := record1["@id"]; !ok && supportLastInsertID {
t.Fatal("failed to create data from map with table, returning map has no key '@id'")
}
var res1 map[string]interface{}
if err := tableDB2.Select([]string{"id", "name", "age"}).Where("name = ?", "create_from_map_with_table_1").Find(&res1).Error; err != nil || res1["age"] != int64(18) {
t.Fatalf("failed to create from map, got error %v", err)
}
if _, ok := record1["@id"]; ok && fmt.Sprint(res1["id"]) != fmt.Sprint(record1["@id"]) {
t.Fatal("failed to create data from map with table, @id != id")
}
// case 3: create from []map[string]interface{}
records := []map[string]interface{}{
{"name": "create_from_map_with_table_2", "age": 19},
{"name": "create_from_map_with_table_3", "age": 20},
}
tableDB = DB.Table("users")
if err := tableDB.Create(&records).Error; err != nil {
t.Fatalf("failed to create data from slice of map, got error: %v", err)
}
if _, ok := records[0]["@id"]; !ok && supportLastInsertID {
t.Fatal("failed to create data from map with table, returning map has no key '@id'")
}
if _, ok := records[1]["@id"]; !ok && supportLastInsertID {
t.Fatal("failed to create data from map with table, returning map has no key '@id'")
}
var res2 map[string]interface{}
if err := tableDB.Select([]string{"id", "name", "age"}).Where("name = ?", "create_from_map_with_table_2").Find(&res2).Error; err != nil || res2["age"] != int64(19) {
t.Fatalf("failed to query data after create from slice of map, got error %v", err)
}
var res3 map[string]interface{}
if err := DB.Table("users").Select([]string{"id", "name", "age"}).Where("name = ?", "create_from_map_with_table_3").Find(&res3).Error; err != nil || res3["age"] != int64(20) {
t.Fatalf("failed to query data after create from slice of map, got error %v", err)
}
if _, ok := records[0]["@id"]; ok && fmt.Sprint(res2["id"]) != fmt.Sprint(records[0]["@id"]) {
t.Errorf("failed to create data from map with table, @id != id, got %v, expect %v", res2["id"], records[0]["@id"])
}
if _, ok := records[1]["id"]; ok && fmt.Sprint(res3["id"]) != fmt.Sprint(records[1]["@id"]) {
t.Errorf("failed to create data from map with table, @id != id")
}
}

View File

@ -38,22 +38,4 @@ func TestDefaultValue(t *testing.T) {
} else if result.Name != "foo" || result.Name2 != "foo" || result.Name3 != "" || result.Age != 18 || !result.Enabled || result.Created.Format("20060102") != "20000102" {
t.Fatalf("Failed to find created data with default data, got %+v", result)
}
type Harumph2 struct {
ID int `gorm:"default:0"`
Email string `gorm:"not null;index:,unique"`
Name string `gorm:"notNull;default:foo"`
Name2 string `gorm:"size:233;not null;default:'foo'"`
Name3 string `gorm:"size:233;notNull;default:''"`
Age int `gorm:"default:18"`
Created time.Time `gorm:"default:2000-01-02"`
Enabled bool `gorm:"default:true"`
}
harumph2 := Harumph2{ID: 2, Email: "hello2@gorm.io"}
if err := DB.Table("harumphs").Create(&harumph2).Error; err != nil {
t.Fatalf("Failed to create data with default value, got error: %v", err)
} else if harumph2.ID != 2 || harumph2.Name != "foo" || harumph2.Name2 != "foo" || harumph2.Name3 != "" || harumph2.Age != 18 || !harumph2.Enabled || harumph2.Created.Format("20060102") != "20000102" {
t.Fatalf("Failed to create data with default value, got: %+v", harumph2)
}
}

View File

@ -206,9 +206,9 @@ func TestDeleteSliceWithAssociations(t *testing.T) {
}
}
// only sqlite, postgres, gaussdb, sqlserver support returning
// only sqlite, postgres support returning
func TestSoftDeleteReturning(t *testing.T) {
if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" && DB.Dialector.Name() != "gaussdb" && DB.Dialector.Name() != "sqlserver" {
if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" {
return
}
@ -233,7 +233,7 @@ func TestSoftDeleteReturning(t *testing.T) {
}
func TestDeleteReturning(t *testing.T) {
if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" && DB.Dialector.Name() != "gaussdb" && DB.Dialector.Name() != "sqlserver" {
if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" {
return
}

View File

@ -1,8 +1,10 @@
version: '3'
services:
mysql:
image: 'mysql:latest'
image: 'mysql/mysql-server:latest'
ports:
- "127.0.0.1:9910:3306"
- "9910:3306"
environment:
- MYSQL_DATABASE=gorm
- MYSQL_USER=gorm
@ -11,22 +13,24 @@ services:
postgres:
image: 'postgres:latest'
ports:
- "127.0.0.1:9920:5432"
- "9920:5432"
environment:
- TZ=Asia/Shanghai
- POSTGRES_DB=gorm
- POSTGRES_USER=gorm
- POSTGRES_PASSWORD=gorm
mssql:
image: '${MSSQL_IMAGE}:latest'
image: '${MSSQL_IMAGE:-mcmoe/mssqldocker}:latest'
ports:
- "127.0.0.1:9930:1433"
- "9930:1433"
environment:
- TZ=Asia/Shanghai
- ACCEPT_EULA=Y
- MSSQL_SA_PASSWORD=LoremIpsum86
- SA_PASSWORD=LoremIpsum86
- MSSQL_DB=gorm
- MSSQL_USER=gorm
- MSSQL_PASSWORD=LoremIpsum86
tidb:
image: 'pingcap/tidb:v6.5.0'
ports:
- "127.0.0.1:9940:4000"
- "9940:4000"
command: /tidb-server -store unistore -path "" -lease 0s > tidb.log 2>&1 &

View File

@ -236,15 +236,8 @@ func TestEmbeddedScanValuer(t *testing.T) {
}
func TestEmbeddedRelations(t *testing.T) {
type EmbUser struct {
gorm.Model
Name string
Age uint
Languages []Language `gorm:"many2many:EmbUserSpeak;"`
}
type AdvancedUser struct {
EmbUser `gorm:"embedded"`
User `gorm:"embedded"`
Advanced bool
}
@ -279,6 +272,6 @@ func TestEmbeddedTagSetting(t *testing.T) {
err = DB.Save(&t1).Error
AssertEqual(t, err, nil)
if t1.Tag1.Id == 0 {
t.Errorf("embedded struct's primary field should be rewritten")
t.Errorf("embedded struct's primary field should be rewrited")
}
}

View File

@ -39,7 +39,7 @@ func TestSupportedDialectorWithErrDuplicatedKey(t *testing.T) {
t.Fatalf("failed to connect database, got error %v", err)
}
dialectors := map[string]bool{"sqlite": true, "postgres": true, "gaussdb": true, "mysql": true, "sqlserver": true}
dialectors := map[string]bool{"sqlite": true, "postgres": true, "mysql": true, "sqlserver": true}
if supported, found := dialectors[db.Dialector.Name()]; !(found && supported) {
return
}
@ -81,7 +81,7 @@ func TestSupportedDialectorWithErrForeignKeyViolated(t *testing.T) {
t.Fatalf("failed to connect database, got error %v", err)
}
dialectors := map[string]bool{"sqlite": true, "postgres": true, "gaussdb": true, "mysql": true, "sqlserver": true}
dialectors := map[string]bool{"sqlite": true, "postgres": true, "mysql": true, "sqlserver": true}
if supported, found := dialectors[db.Dialector.Name()]; !(found && supported) {
return
}

View File

@ -1,248 +0,0 @@
package tests_test
import (
"testing"
"time"
"github.com/google/uuid"
"github.com/lib/pq"
"gorm.io/gorm"
"gorm.io/gorm/clause"
. "gorm.io/gorm/utils/tests"
)
func TestGaussDBReturningIDWhichHasStringType(t *testing.T) {
t.Skipf("This test case skipped, because of gaussdb not support pgcrypto extension and gen_random_uuid() function")
if DB.Dialector.Name() != "gaussdb" {
t.Skip()
}
type Yasuo struct {
// TODO: function gen_random_uuid() does not exist
ID string `gorm:"default:gen_random_uuid()"`
Name string
CreatedAt time.Time `gorm:"type:TIMESTAMP WITHOUT TIME ZONE"`
UpdatedAt time.Time `gorm:"type:TIMESTAMP WITHOUT TIME ZONE;default:current_timestamp"`
}
if err := DB.Exec("CREATE EXTENSION IF NOT EXISTS pgcrypto;").Error; err != nil {
t.Errorf("Failed to create extension pgcrypto, got error %v", err)
}
DB.Migrator().DropTable(&Yasuo{})
if err := DB.AutoMigrate(&Yasuo{}); err != nil {
t.Fatalf("Failed to migrate for uuid default value, got error: %v", err)
}
yasuo := Yasuo{Name: "jinzhu"}
if err := DB.Create(&yasuo).Error; err != nil {
t.Fatalf("should be able to create data, but got %v", err)
}
if yasuo.ID == "" {
t.Fatal("should be able to has ID, but got zero value")
}
var result Yasuo
if err := DB.First(&result, "id = ?", yasuo.ID).Error; err != nil || yasuo.Name != "jinzhu" {
t.Errorf("No error should happen, but got %v", err)
}
if err := DB.Where("id = $1", yasuo.ID).First(&Yasuo{}).Error; err != nil || yasuo.Name != "jinzhu" {
t.Errorf("No error should happen, but got %v", err)
}
yasuo.Name = "jinzhu1"
if err := DB.Save(&yasuo).Error; err != nil {
t.Errorf("Failed to update date, got error %v", err)
}
if err := DB.First(&result, "id = ?", yasuo.ID).Error; err != nil || yasuo.Name != "jinzhu1" {
t.Errorf("No error should happen, but got %v", err)
}
}
func TestGaussDB(t *testing.T) {
t.Skipf("This test case skipped, because of gaussdb not support pgcrypto extension and gen_random_uuid() function")
if DB.Dialector.Name() != "gaussdb" {
t.Skip()
}
type Harumph struct {
gorm.Model
Name string `gorm:"check:name_checker,name <> ''"`
// TODO: function gen_random_uuid() does not exist
Test uuid.UUID `gorm:"type:uuid;not null;default:gen_random_uuid()"`
CreatedAt time.Time `gorm:"type:TIMESTAMP WITHOUT TIME ZONE"`
UpdatedAt time.Time `gorm:"type:TIMESTAMP WITHOUT TIME ZONE;default:current_timestamp"`
Things pq.StringArray `gorm:"type:text[]"`
}
if err := DB.Exec("CREATE EXTENSION IF NOT EXISTS pgcrypto;").Error; err != nil {
t.Errorf("Failed to create extension pgcrypto, got error %v", err)
}
DB.Migrator().DropTable(&Harumph{})
if err := DB.AutoMigrate(&Harumph{}); err != nil {
t.Fatalf("Failed to migrate for uuid default value, got error: %v", err)
}
harumph := Harumph{}
if err := DB.Create(&harumph).Error; err == nil {
t.Fatalf("should failed to create data, name can't be blank")
}
harumph = Harumph{Name: "jinzhu"}
if err := DB.Create(&harumph).Error; err != nil {
t.Fatalf("should be able to create data, but got %v", err)
}
var result Harumph
if err := DB.First(&result, "id = ?", harumph.ID).Error; err != nil || harumph.Name != "jinzhu" {
t.Errorf("No error should happen, but got %v", err)
}
if err := DB.Where("id = $1", harumph.ID).First(&Harumph{}).Error; err != nil || harumph.Name != "jinzhu" {
t.Errorf("No error should happen, but got %v", err)
}
harumph.Name = "jinzhu1"
if err := DB.Save(&harumph).Error; err != nil {
t.Errorf("Failed to update date, got error %v", err)
}
if err := DB.First(&result, "id = ?", harumph.ID).Error; err != nil || harumph.Name != "jinzhu1" {
t.Errorf("No error should happen, but got %v", err)
}
DB.Migrator().DropTable("log_usage")
if err := DB.Exec(`
CREATE TABLE public.log_usage (
log_id bigint NOT NULL
);
ALTER TABLE public.log_usage ALTER COLUMN log_id ADD GENERATED BY DEFAULT AS IDENTITY (
SEQUENCE NAME public.log_usage_log_id_seq
START WITH 1
INCREMENT BY 1
NO MINVALUE
NO MAXVALUE
CACHE 1
);
`).Error; err != nil {
t.Fatalf("failed to create table, got error %v", err)
}
columns, err := DB.Migrator().ColumnTypes("log_usage")
if err != nil {
t.Fatalf("failed to get columns, got error %v", err)
}
hasLogID := false
for _, column := range columns {
if column.Name() == "log_id" {
hasLogID = true
autoIncrement, ok := column.AutoIncrement()
if !ok || !autoIncrement {
t.Fatalf("column log_id should be auto incrementment")
}
}
}
if !hasLogID {
t.Fatalf("failed to found column log_id")
}
}
func TestGaussDBMany2ManyWithDefaultValueUUID(t *testing.T) {
t.Skipf("This test case skipped, because of gaussdb does not have 'uuid-ossp' extension")
if DB.Dialector.Name() != "gaussdb" {
t.Skip()
}
if err := DB.Exec(`create extension if not exists "uuid-ossp"`).Error; err != nil {
t.Fatalf("Failed to create 'uuid-ossp' extension, but got error %v", err)
}
DB.Migrator().DropTable(&Post{}, &Category{}, "post_categories")
DB.AutoMigrate(&Post{}, &Category{})
post := Post{
Title: "Hello World",
Categories: []*Category{
{Title: "Coding"},
{Title: "Golang"},
},
}
if err := DB.Create(&post).Error; err != nil {
t.Errorf("Failed, got error: %v", err)
}
}
func TestGaussDBOnConstraint(t *testing.T) {
t.Skipf("This test case skipped, because of gaussdb not support 'ON CONSTRAINT' statement")
if DB.Dialector.Name() != "gaussdb" {
t.Skip()
}
type Thing struct {
gorm.Model
SomeID string
OtherID string
Data string
}
DB.Migrator().DropTable(&Thing{})
DB.Migrator().CreateTable(&Thing{})
if err := DB.Exec("ALTER TABLE things ADD CONSTRAINT some_id_other_id_unique UNIQUE (some_id, other_id)").Error; err != nil {
t.Error(err)
}
thing := Thing{
SomeID: "1234",
OtherID: "1234",
Data: "something",
}
DB.Create(&thing)
thing2 := Thing{
SomeID: "1234",
OtherID: "1234",
Data: "something else",
}
result := DB.Clauses(clause.OnConflict{
OnConstraint: "some_id_other_id_unique",
UpdateAll: true,
}).Create(&thing2)
if result.Error != nil {
t.Errorf("creating second thing: %v", result.Error)
}
var things []Thing
if err := DB.Find(&things).Error; err != nil {
t.Errorf("Failed, got error: %v", err)
}
if len(things) > 1 {
t.Errorf("expected 1 thing got more")
}
}
func TestGaussDBAlterColumnDataType(t *testing.T) {
if DB.Dialector.Name() != "gaussdb" {
t.Skip()
}
DB.Migrator().DropTable(&Company{})
DB.AutoMigrate(Company{})
if err := DB.Table("companies").Migrator().AlterColumn(CompanyNew{}, "name"); err != nil {
t.Fatalf("failed to alter column from string to int, got error %v", err)
}
DB.AutoMigrate(Company{})
}

View File

@ -1,875 +0,0 @@
package tests_test
import (
"context"
"errors"
"fmt"
"reflect"
"regexp"
"sort"
"strconv"
"strings"
"sync"
"testing"
"github.com/google/uuid"
"gorm.io/driver/mysql"
"gorm.io/gorm"
"gorm.io/gorm/clause"
. "gorm.io/gorm/utils/tests"
)
func TestGenericsCreate(t *testing.T) {
ctx := context.Background()
user := User{Name: "TestGenericsCreate", Age: 18}
err := gorm.G[User](DB).Create(ctx, &user)
if err != nil {
t.Fatalf("Create failed: %v", err)
}
if user.ID == 0 {
t.Fatalf("no primary key found for %v", user)
}
if u, err := gorm.G[User](DB).Where("name = ?", user.Name).First(ctx); err != nil {
t.Fatalf("failed to find user, got error: %v", err)
} else if u.Name != user.Name || u.ID != user.ID {
t.Errorf("found invalid user, got %v, expect %v", u, user)
}
if u, err := gorm.G[User](DB).Where("name = ?", user.Name).Take(ctx); err != nil {
t.Fatalf("failed to find user, got error: %v", err)
} else if u.Name != user.Name || u.ID != user.ID {
t.Errorf("found invalid user, got %v, expect %v", u, user)
}
if u, err := gorm.G[User](DB).Select("name").Where("name = ?", user.Name).First(ctx); err != nil {
t.Fatalf("failed to find user, got error: %v", err)
} else if u.Name != user.Name || u.Age != 0 {
t.Errorf("found invalid user, got %v, expect %v", u, user)
}
if u, err := gorm.G[User](DB).Omit("name").Where("name = ?", user.Name).First(ctx); err != nil {
t.Fatalf("failed to find user, got error: %v", err)
} else if u.Name != "" || u.Age != user.Age {
t.Errorf("found invalid user, got %v, expect %v", u, user)
}
result := struct {
ID int
Name string
}{}
if err := gorm.G[User](DB).Where("name = ?", user.Name).Scan(ctx, &result); err != nil {
t.Fatalf("failed to scan user, got error: %v", err)
} else if result.Name != user.Name || uint(result.ID) != user.ID {
t.Errorf("found invalid user, got %v, expect %v", result, user)
}
mapResult, err := gorm.G[map[string]interface{}](DB).Table("users").Where("name = ?", user.Name).MapColumns(map[string]string{"name": "user_name"}).Take(ctx)
if v := mapResult["user_name"]; fmt.Sprint(v) != user.Name {
t.Errorf("failed to find map results, got %v, err %v", mapResult, err)
}
}
func TestGenericsCreateInBatches(t *testing.T) {
batch := []User{
{Name: "GenericsCreateInBatches1"},
{Name: "GenericsCreateInBatches2"},
{Name: "GenericsCreateInBatches3"},
}
ctx := context.Background()
if err := gorm.G[User](DB).CreateInBatches(ctx, &batch, 2); err != nil {
t.Fatalf("CreateInBatches failed: %v", err)
}
for _, u := range batch {
if u.ID == 0 {
t.Fatalf("no primary key found for %v", u)
}
}
count, err := gorm.G[User](DB).Where("name like ?", "GenericsCreateInBatches%").Count(ctx, "*")
if err != nil {
t.Fatalf("Count failed: %v", err)
}
if count != 3 {
t.Errorf("expected 3 records, got %d", count)
}
found, err := gorm.G[User](DB).Raw("SELECT * FROM users WHERE name LIKE ?", "GenericsCreateInBatches%").Find(ctx)
if len(found) != len(batch) {
t.Errorf("expected %d from Raw Find, got %d", len(batch), len(found))
}
found, err = gorm.G[User](DB).Where("name like ?", "GenericsCreateInBatches%").Limit(2).Find(ctx)
if len(found) != 2 {
t.Errorf("expected %d from Raw Find, got %d", 2, len(found))
}
found, err = gorm.G[User](DB).Where("name like ?", "GenericsCreateInBatches%").Offset(2).Limit(2).Find(ctx)
if len(found) != 1 {
t.Errorf("expected %d from Raw Find, got %d", 1, len(found))
}
}
func TestGenericsExecAndUpdate(t *testing.T) {
ctx := context.Background()
name := "GenericsExec"
if err := gorm.G[User](DB).Exec(ctx, "INSERT INTO users(name) VALUES(?)", name); err != nil {
t.Fatalf("Exec insert failed: %v", err)
}
u, err := gorm.G[User](DB).Table("users as u").Where("u.name = ?", name).First(ctx)
if err != nil {
t.Fatalf("failed to find user, got error: %v", err)
} else if u.Name != name || u.ID == 0 {
t.Errorf("found invalid user, got %v", u)
}
name += "Update"
rows, err := gorm.G[User](DB).Where("id = ?", u.ID).Update(ctx, "name", name)
if rows != 1 {
t.Fatalf("failed to get affected rows, got %d, should be %d", rows, 1)
}
nu, err := gorm.G[User](DB).Where("name = ?", name).First(ctx)
if err != nil {
t.Fatalf("failed to find user, got error: %v", err)
} else if nu.Name != name || u.ID != nu.ID {
t.Fatalf("found invalid user, got %v, expect %v", nu.ID, u.ID)
}
rows, err = gorm.G[User](DB).Where("id = ?", u.ID).Updates(ctx, User{Name: "GenericsExecUpdates", Age: 18})
if rows != 1 {
t.Fatalf("failed to get affected rows, got %d, should be %d", rows, 1)
}
nu, err = gorm.G[User](DB).Where("id = ?", u.ID).Last(ctx)
if err != nil {
t.Fatalf("failed to find user, got error: %v", err)
} else if nu.Name != "GenericsExecUpdates" || nu.Age != 18 || u.ID != nu.ID {
t.Fatalf("found invalid user, got %v, expect %v", nu.ID, u.ID)
}
}
func TestGenericsRow(t *testing.T) {
ctx := context.Background()
user := User{Name: "GenericsRow"}
if err := gorm.G[User](DB).Create(ctx, &user); err != nil {
t.Fatalf("Create failed: %v", err)
}
row := gorm.G[User](DB).Raw("SELECT name FROM users WHERE id = ?", user.ID).Row(ctx)
var name string
if err := row.Scan(&name); err != nil {
t.Fatalf("Row scan failed: %v", err)
}
if name != user.Name {
t.Errorf("expected %s, got %s", user.Name, name)
}
user2 := User{Name: "GenericsRow2"}
if err := gorm.G[User](DB).Create(ctx, &user2); err != nil {
t.Fatalf("Create failed: %v", err)
}
rows, err := gorm.G[User](DB).Raw("SELECT name FROM users WHERE id IN ?", []uint{user.ID, user2.ID}).Rows(ctx)
if err != nil {
t.Fatalf("Rows failed: %v", err)
}
count := 0
for rows.Next() {
var name string
if err := rows.Scan(&name); err != nil {
t.Fatalf("rows.Scan failed: %v", err)
}
count++
}
if count != 2 {
t.Errorf("expected 2 rows, got %d", count)
}
}
func TestGenericsDelete(t *testing.T) {
ctx := context.Background()
u := User{Name: "GenericsDelete"}
if err := gorm.G[User](DB).Create(ctx, &u); err != nil {
t.Fatalf("Create failed: %v", err)
}
rows, err := gorm.G[User](DB).Where("id = ?", u.ID).Delete(ctx)
if err != nil {
t.Fatalf("Delete failed: %v", err)
}
if rows != 1 {
t.Errorf("expected 1 row deleted, got %d", rows)
}
_, err = gorm.G[User](DB).Where("id = ?", u.ID).First(ctx)
if err != gorm.ErrRecordNotFound {
t.Fatalf("User after delete failed: %v", err)
}
}
func TestGenericsFindInBatches(t *testing.T) {
ctx := context.Background()
users := []User{
{Name: "GenericsFindBatchA"},
{Name: "GenericsFindBatchB"},
{Name: "GenericsFindBatchC"},
{Name: "GenericsFindBatchD"},
{Name: "GenericsFindBatchE"},
}
if err := gorm.G[User](DB).CreateInBatches(ctx, &users, len(users)); err != nil {
t.Fatalf("CreateInBatches failed: %v", err)
}
total := 0
err := gorm.G[User](DB).Where("name like ?", "GenericsFindBatch%").FindInBatches(ctx, 2, func(chunk []User, batch int) error {
if len(chunk) > 2 {
t.Errorf("batch size exceed 2: got %d", len(chunk))
}
total += len(chunk)
return nil
})
if err != nil {
t.Fatalf("FindInBatches failed: %v", err)
}
if total != len(users) {
t.Errorf("expected total %d, got %d", len(users), total)
}
}
func TestGenericsScopes(t *testing.T) {
ctx := context.Background()
users := []User{{Name: "GenericsScopes1"}, {Name: "GenericsScopes2"}, {Name: "GenericsScopes3"}}
err := gorm.G[User](DB).CreateInBatches(ctx, &users, len(users))
if err != nil {
t.Fatalf("CreateInBatches failed: %v", err)
}
filterName1 := func(stmt *gorm.Statement) {
stmt.Where("name = ?", "GenericsScopes1")
}
results, err := gorm.G[User](DB).Scopes(filterName1).Find(ctx)
if err != nil {
t.Fatalf("Scopes failed: %v", err)
}
if len(results) != 1 || results[0].Name != "GenericsScopes1" {
t.Fatalf("Scopes expected 1, got %d", len(results))
}
notResult, err := gorm.G[User](DB).Where("name like ?", "GenericsScopes%").Not("name = ?", "GenericsScopes1").Order("name").Find(ctx)
if len(notResult) != 2 {
t.Fatalf("expected 2 results, got %d", len(notResult))
} else if notResult[0].Name != "GenericsScopes2" || notResult[1].Name != "GenericsScopes3" {
t.Fatalf("expected names 'GenericsScopes2' and 'GenericsScopes3', got %s and %s", notResult[0].Name, notResult[1].Name)
}
orResult, err := gorm.G[User](DB).Or("name = ?", "GenericsScopes1").Or("name = ?", "GenericsScopes2").Order("name").Find(ctx)
if len(orResult) != 2 {
t.Fatalf("expected 2 results, got %d", len(notResult))
} else if orResult[0].Name != "GenericsScopes1" || orResult[1].Name != "GenericsScopes2" {
t.Fatalf("expected names 'GenericsScopes2' and 'GenericsScopes3', got %s and %s", orResult[0].Name, orResult[1].Name)
}
}
func TestGenericsJoins(t *testing.T) {
ctx := context.Background()
db := gorm.G[User](DB)
u := User{Name: "GenericsJoins", Company: Company{Name: "GenericsCompany"}}
u2 := User{Name: "GenericsJoins_2", Company: Company{Name: "GenericsCompany_2"}}
u3 := User{Name: "GenericsJoins_3", Company: Company{Name: "GenericsCompany_3"}}
db.CreateInBatches(ctx, &[]User{u3, u, u2}, 10)
// Inner JOIN + WHERE
result, err := db.Joins(clause.Has("Company"), func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error {
db.Where("?.name = ?", joinTable, u.Company.Name)
return nil
}).First(ctx)
if err != nil {
t.Fatalf("Joins failed: %v", err)
}
if result.Name != u.Name || result.Company.Name != u.Company.Name {
t.Fatalf("Joins expected %s, got %+v", u.Name, result)
}
// Inner JOIN + WHERE with map
result, err = db.Joins(clause.Has("Company"), func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error {
db.Where(map[string]any{"name": u.Company.Name})
return nil
}).First(ctx)
if err != nil {
t.Fatalf("Joins failed: %v", err)
}
if result.Name != u.Name || result.Company.Name != u.Company.Name {
t.Fatalf("Joins expected %s, got %+v", u.Name, result)
}
// Left JOIN w/o WHERE
result, err = db.Joins(clause.LeftJoin.Association("Company"), nil).Where(map[string]any{"name": u.Name}).First(ctx)
if err != nil {
t.Fatalf("Joins failed: %v", err)
}
if result.Name != u.Name || result.Company.Name != u.Company.Name {
t.Fatalf("Joins expected %s, got %+v", u.Name, result)
}
// Left JOIN + Alias WHERE
result, err = db.Joins(clause.LeftJoin.Association("Company").As("t"), func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error {
if joinTable.Name != "t" {
t.Fatalf("Join table should be t, but got %v", joinTable.Name)
}
db.Where("?.name = ?", joinTable, u.Company.Name)
return nil
}).Where(map[string]any{"name": u.Name}).First(ctx)
if err != nil {
t.Fatalf("Joins failed: %v", err)
}
if result.Name != u.Name || result.Company.Name != u.Company.Name {
t.Fatalf("Joins expected %s, got %+v", u.Name, result)
}
// Raw Subquery JOIN + WHERE
result, err = db.Joins(clause.LeftJoin.AssociationFrom("Company", gorm.G[Company](DB)).As("t"),
func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error {
if joinTable.Name != "t" {
t.Fatalf("Join table should be t, but got %v", joinTable.Name)
}
db.Where("?.name = ?", joinTable, u.Company.Name)
return nil
},
).Where(map[string]any{"name": u2.Name}).First(ctx)
if err != nil {
t.Fatalf("Raw subquery join failed: %v", err)
}
if result.Name != u2.Name || result.Company.Name != u.Company.Name || result.Company.ID == 0 {
t.Fatalf("Joins expected %s, got %+v", u.Name, result)
}
// Raw Subquery JOIN + WHERE + Select
result, err = db.Joins(clause.LeftJoin.AssociationFrom("Company", gorm.G[Company](DB).Select("Name")).As("t"),
func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error {
if joinTable.Name != "t" {
t.Fatalf("Join table should be t, but got %v", joinTable.Name)
}
db.Where("?.name = ?", joinTable, u.Company.Name)
return nil
},
).Where(map[string]any{"name": u2.Name}).First(ctx)
if err != nil {
t.Fatalf("Raw subquery join failed: %v", err)
}
if result.Name != u2.Name || result.Company.Name != u.Company.Name || result.Company.ID != 0 {
t.Fatalf("Joins expected %s, got %+v", u.Name, result)
}
_, err = db.Joins(clause.Has("Company"), func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error {
return errors.New("join error")
}).First(ctx)
if err == nil {
t.Fatalf("Joins should got error, but got nil")
}
}
func TestGenericsNestedJoins(t *testing.T) {
users := []User{
{
Name: "generics-nested-joins-1",
Manager: &User{
Name: "generics-nested-joins-manager-1",
Company: Company{
Name: "generics-nested-joins-manager-company-1",
},
NamedPet: &Pet{
Name: "generics-nested-joins-manager-namepet-1",
Toy: Toy{
Name: "generics-nested-joins-manager-namepet-toy-1",
},
},
},
NamedPet: &Pet{Name: "generics-nested-joins-namepet-1", Toy: Toy{Name: "generics-nested-joins-namepet-toy-1"}},
},
{
Name: "generics-nested-joins-2",
Manager: GetUser("generics-nested-joins-manager-2", Config{Company: true, NamedPet: true}),
NamedPet: &Pet{Name: "generics-nested-joins-namepet-2", Toy: Toy{Name: "generics-nested-joins-namepet-toy-2"}},
},
}
ctx := context.Background()
db := gorm.G[User](DB)
db.CreateInBatches(ctx, &users, 100)
var userIDs []uint
for _, user := range users {
userIDs = append(userIDs, user.ID)
}
users2, err := db.Joins(clause.LeftJoin.Association("Manager"), nil).
Joins(clause.LeftJoin.Association("Manager.Company"), nil).
Joins(clause.LeftJoin.Association("Manager.NamedPet.Toy"), nil).
Joins(clause.LeftJoin.Association("NamedPet.Toy"), nil).
Joins(clause.LeftJoin.Association("NamedPet").As("t"), nil).
Where(map[string]any{"id": userIDs}).Find(ctx)
if err != nil {
t.Fatalf("Failed to load with joins, got error: %v", err)
} else if len(users2) != len(users) {
t.Fatalf("Failed to load join users, got: %v, expect: %v", len(users2), len(users))
}
sort.Slice(users2, func(i, j int) bool {
return users2[i].ID > users2[j].ID
})
sort.Slice(users, func(i, j int) bool {
return users[i].ID > users[j].ID
})
for idx, user := range users {
// user
CheckUser(t, user, users2[idx])
if users2[idx].Manager == nil {
t.Fatalf("Failed to load Manager")
}
// manager
CheckUser(t, *user.Manager, *users2[idx].Manager)
// user pet
if users2[idx].NamedPet == nil {
t.Fatalf("Failed to load NamedPet")
}
CheckPet(t, *user.NamedPet, *users2[idx].NamedPet)
// manager pet
if users2[idx].Manager.NamedPet == nil {
t.Fatalf("Failed to load NamedPet")
}
CheckPet(t, *user.Manager.NamedPet, *users2[idx].Manager.NamedPet)
}
}
func TestGenericsPreloads(t *testing.T) {
ctx := context.Background()
db := gorm.G[User](DB)
u := *GetUser("GenericsPreloads_1", Config{Company: true, Pets: 3, Friends: 7})
u2 := *GetUser("GenericsPreloads_2", Config{Company: true, Pets: 5, Friends: 5})
u3 := *GetUser("GenericsPreloads_3", Config{Company: true, Pets: 7, Friends: 3})
names := []string{u.Name, u2.Name, u3.Name}
db.CreateInBatches(ctx, &[]User{u3, u, u2}, 10)
result, err := db.Preload("Company", nil).Preload("Pets", nil).Where("name = ?", u.Name).First(ctx)
if err != nil {
t.Fatalf("Preload failed: %v", err)
}
if result.Name != u.Name || result.Company.Name != u.Company.Name || len(result.Pets) != len(u.Pets) {
t.Fatalf("Preload expected %s, got %+v", u.Name, result)
}
results, err := db.Preload("Company", func(db gorm.PreloadBuilder) error {
db.Where("name = ?", u.Company.Name)
return nil
}).Where("name in ?", names).Find(ctx)
if err != nil {
t.Fatalf("Preload failed: %v", err)
}
for _, result := range results {
if result.Name == u.Name {
if result.Company.Name != u.Company.Name {
t.Fatalf("Preload user %v company should be %v, but got %+v", u.Name, u.Company.Name, result.Company.Name)
}
} else if result.Company.Name != "" {
t.Fatalf("Preload other company should not loaded, user %v company expect %v but got %+v", u.Name, u.Company.Name, result.Company.Name)
}
}
_, err = db.Preload("Company", func(db gorm.PreloadBuilder) error {
return errors.New("preload error")
}).Where("name in ?", names).Find(ctx)
if err == nil {
t.Fatalf("Preload should failed, but got nil")
}
if DB.Dialector.Name() == "mysql" {
// mysql 5.7 doesn't support row_number()
if strings.HasPrefix(DB.Dialector.(*mysql.Dialector).ServerVersion, "5.7") {
return
}
}
results, err = db.Preload("Pets", func(db gorm.PreloadBuilder) error {
db.LimitPerRecord(5)
return nil
}).Where("name in ?", names).Find(ctx)
for _, result := range results {
if result.Name == u.Name {
if len(result.Pets) != len(u.Pets) {
t.Fatalf("Preload user %v pets should be %v, but got %+v", u.Name, u.Pets, result.Pets)
}
} else if len(result.Pets) != 5 {
t.Fatalf("Preload user %v pets should be 5, but got %+v", result.Name, result.Pets)
}
}
if DB.Dialector.Name() == "sqlserver" {
// sqlserver doesn't support order by in subquery
return
}
results, err = db.Preload("Pets", func(db gorm.PreloadBuilder) error {
db.Order("name desc").LimitPerRecord(5)
return nil
}).Where("name in ?", names).Find(ctx)
for _, result := range results {
if result.Name == u.Name {
if len(result.Pets) != len(u.Pets) {
t.Fatalf("Preload user %v pets should be %v, but got %+v", u.Name, u.Pets, result.Pets)
}
} else if len(result.Pets) != 5 {
t.Fatalf("Preload user %v pets should be 5, but got %+v", result.Name, result.Pets)
}
for i := 1; i < len(result.Pets); i++ {
if result.Pets[i-1].Name < result.Pets[i].Name {
t.Fatalf("Preload user %v pets not ordered correctly, last %v, cur %v", result.Name, result.Pets[i-1], result.Pets[i])
}
}
}
results, err = db.Preload("Pets", func(db gorm.PreloadBuilder) error {
db.Order("name").LimitPerRecord(5)
return nil
}).Preload("Friends", func(db gorm.PreloadBuilder) error {
db.Order("name")
return nil
}).Where("name in ?", names).Find(ctx)
for _, result := range results {
if result.Name == u.Name {
if len(result.Pets) != len(u.Pets) {
t.Fatalf("Preload user %v pets should be %v, but got %+v", u.Name, u.Pets, result.Pets)
}
if len(result.Friends) != len(u.Friends) {
t.Fatalf("Preload user %v pets should be %v, but got %+v", u.Name, u.Pets, result.Pets)
}
} else if len(result.Pets) != 5 || len(result.Friends) == 0 {
t.Fatalf("Preload user %v pets should be 5, but got %+v", result.Name, result.Pets)
}
for i := 1; i < len(result.Pets); i++ {
if result.Pets[i-1].Name > result.Pets[i].Name {
t.Fatalf("Preload user %v pets not ordered correctly, last %v, cur %v", result.Name, result.Pets[i-1], result.Pets[i])
}
}
for i := 1; i < len(result.Pets); i++ {
if result.Pets[i-1].Name > result.Pets[i].Name {
t.Fatalf("Preload user %v friends not ordered correctly, last %v, cur %v", result.Name, result.Pets[i-1], result.Pets[i])
}
}
}
}
func TestGenericsNestedPreloads(t *testing.T) {
user := *GetUser("generics_nested_preload", Config{Pets: 2})
user.Friends = []*User{GetUser("generics_nested_preload", Config{Pets: 5})}
ctx := context.Background()
db := gorm.G[User](DB)
for idx, pet := range user.Pets {
pet.Toy = Toy{Name: "toy_nested_preload_" + strconv.Itoa(idx+1)}
}
if err := db.Create(ctx, &user); err != nil {
t.Fatalf("errors happened when create: %v", err)
}
user2, err := db.Preload("Pets.Toy", nil).Preload("Friends.Pets", func(db gorm.PreloadBuilder) error {
return nil
}).Where(user.ID).Take(ctx)
if err != nil {
t.Errorf("failed to nested preload user")
}
CheckUser(t, user2, user)
if len(user.Pets) == 0 || len(user.Friends) == 0 || len(user.Friends[0].Pets) == 0 {
t.Fatalf("failed to nested preload")
}
if DB.Dialector.Name() == "mysql" {
// mysql 5.7 doesn't support row_number()
if strings.HasPrefix(DB.Dialector.(*mysql.Dialector).ServerVersion, "5.7") {
return
}
}
if DB.Dialector.Name() == "sqlserver" {
// sqlserver doesn't support order by in subquery
return
}
user3, err := db.Preload("Pets.Toy", nil).Preload("Friends.Pets", func(db gorm.PreloadBuilder) error {
db.LimitPerRecord(3)
return nil
}).Where(user.ID).Take(ctx)
if err != nil {
t.Errorf("failed to nested preload user")
}
CheckUser(t, user3, user)
if len(user3.Friends) != 1 || len(user3.Friends[0].Pets) != 3 {
t.Errorf("failed to nested preload with limit per record")
}
}
func TestGenericsDistinct(t *testing.T) {
ctx := context.Background()
batch := []User{
{Name: "GenericsDistinctDup"},
{Name: "GenericsDistinctDup"},
{Name: "GenericsDistinctUnique"},
}
if err := gorm.G[User](DB).CreateInBatches(ctx, &batch, len(batch)); err != nil {
t.Fatalf("CreateInBatches failed: %v", err)
}
results, err := gorm.G[User](DB).Where("name like ?", "GenericsDistinct%").Distinct("name").Find(ctx)
if err != nil {
t.Fatalf("Distinct Find failed: %v", err)
}
if len(results) != 2 {
t.Errorf("expected 2 distinct names, got %d", len(results))
}
var names []string
for _, u := range results {
names = append(names, u.Name)
}
sort.Strings(names)
expected := []string{"GenericsDistinctDup", "GenericsDistinctUnique"}
if !reflect.DeepEqual(names, expected) {
t.Errorf("expected names %v, got %v", expected, names)
}
}
func TestGenericsGroupHaving(t *testing.T) {
ctx := context.Background()
batch := []User{
{Name: "GenericsGroupHavingMulti"},
{Name: "GenericsGroupHavingMulti"},
{Name: "GenericsGroupHavingSingle"},
}
if err := gorm.G[User](DB).CreateInBatches(ctx, &batch, len(batch)); err != nil {
t.Fatalf("CreateInBatches failed: %v", err)
}
grouped, err := gorm.G[User](DB).Select("name").Where("name like ?", "GenericsGroupHaving%").Group("name").Having("COUNT(id) > ?", 1).Find(ctx)
if err != nil {
t.Fatalf("Group+Having Find failed: %v", err)
}
if len(grouped) != 1 {
t.Errorf("expected 1 group with count>1, got %d", len(grouped))
} else if grouped[0].Name != "GenericsGroupHavingMulti" {
t.Errorf("expected group name 'GenericsGroupHavingMulti', got '%s'", grouped[0].Name)
}
}
func TestGenericsSubQuery(t *testing.T) {
ctx := context.Background()
users := []User{
{Name: "GenericsSubquery_1", Age: 10},
{Name: "GenericsSubquery_2", Age: 20},
{Name: "GenericsSubquery_3", Age: 30},
{Name: "GenericsSubquery_4", Age: 40},
}
if err := gorm.G[User](DB).CreateInBatches(ctx, &users, len(users)); err != nil {
t.Fatalf("CreateInBatches failed: %v", err)
}
results, err := gorm.G[User](DB).Where("name IN (?)", gorm.G[User](DB).Select("name").Where("name LIKE ?", "GenericsSubquery%")).Find(ctx)
if err != nil {
t.Fatalf("got error: %v", err)
}
if len(results) != 4 {
t.Errorf("Four users should be found, instead found %d", len(results))
}
results, err = gorm.G[User](DB).Where("name IN (?)", gorm.G[User](DB).Select("name").Where("name IN ?", []string{"GenericsSubquery_1", "GenericsSubquery_2"}).Or("name = ?", "GenericsSubquery_3")).Find(ctx)
if err != nil {
t.Fatalf("got error: %v", err)
}
if len(results) != 3 {
t.Errorf("Three users should be found, instead found %d", len(results))
}
}
func TestGenericsUpsert(t *testing.T) {
ctx := context.Background()
lang := Language{Code: "upsert", Name: "Upsert"}
if err := gorm.G[Language](DB, clause.OnConflict{DoNothing: true}).Create(ctx, &lang); err != nil {
t.Fatalf("failed to upsert, got %v", err)
}
lang2 := Language{Code: "upsert", Name: "Upsert"}
if err := gorm.G[Language](DB, clause.OnConflict{DoNothing: true}).Create(ctx, &lang2); err != nil {
t.Fatalf("failed to upsert, got %v", err)
}
langs, err := gorm.G[Language](DB).Where("code = ?", lang.Code).Find(ctx)
if err != nil {
t.Errorf("no error should happen when find languages with code, but got %v", err)
} else if len(langs) != 1 {
t.Errorf("should only find only 1 languages, but got %+v", langs)
}
lang3 := Language{Code: "upsert", Name: "Upsert"}
if err := gorm.G[Language](DB, clause.OnConflict{
Columns: []clause.Column{{Name: "code"}},
DoUpdates: clause.Assignments(map[string]interface{}{"name": "upsert-new"}),
}).Create(ctx, &lang3); err != nil {
t.Fatalf("failed to upsert, got %v", err)
}
if langs, err := gorm.G[Language](DB).Where("code = ?", lang.Code).Find(ctx); err != nil {
t.Errorf("no error should happen when find languages with code, but got %v", err)
} else if len(langs) != 1 {
t.Errorf("should only find only 1 languages, but got %+v", langs)
} else if langs[0].Name != "upsert-new" {
t.Errorf("should update name on conflict, but got name %+v", langs[0].Name)
}
}
func TestGenericsWithResult(t *testing.T) {
ctx := context.Background()
users := []User{{Name: "TestGenericsWithResult", Age: 18}, {Name: "TestGenericsWithResult2", Age: 18}}
result := gorm.WithResult()
err := gorm.G[User](DB, result).CreateInBatches(ctx, &users, 2)
if err != nil {
t.Errorf("failed to create users WithResult")
}
if result.RowsAffected != 2 {
t.Errorf("failed to get affected rows, got %d, should be %d", result.RowsAffected, 2)
}
}
func TestGenericsReuse(t *testing.T) {
ctx := context.Background()
users := []User{{Name: "TestGenericsReuse1", Age: 18}, {Name: "TestGenericsReuse2", Age: 18}}
err := gorm.G[User](DB).CreateInBatches(ctx, &users, 2)
if err != nil {
t.Errorf("failed to create users")
}
reusedb := gorm.G[User](DB).Where("name like ?", "TestGenericsReuse%")
sg := sync.WaitGroup{}
for i := 0; i < 5; i++ {
sg.Add(1)
go func() {
if u1, err := reusedb.Where("id = ?", users[0].ID).First(ctx); err != nil {
t.Errorf("failed to find user, got error: %v", err)
} else if u1.Name != users[0].Name || u1.ID != users[0].ID {
t.Errorf("found invalid user, got %v, expect %v", u1, users[0])
}
if u2, err := reusedb.Where("id = ?", users[1].ID).First(ctx); err != nil {
t.Errorf("failed to find user, got error: %v", err)
} else if u2.Name != users[1].Name || u2.ID != users[1].ID {
t.Errorf("found invalid user, got %v, expect %v", u2, users[1])
}
if users, err := reusedb.Where("id IN ?", []uint{users[0].ID, users[1].ID}).Find(ctx); err != nil {
t.Errorf("failed to find user, got error: %v", err)
} else if len(users) != 2 {
t.Errorf("should find 2 users, but got %d", len(users))
}
sg.Done()
}()
}
sg.Wait()
}
func TestGenericsWithTransaction(t *testing.T) {
ctx := context.Background()
tx := DB.Begin()
if tx.Error != nil {
t.Fatalf("failed to begin transaction: %v", tx.Error)
}
users := []User{{Name: "TestGenericsTransaction", Age: 18}, {Name: "TestGenericsTransaction2", Age: 18}}
err := gorm.G[User](tx).CreateInBatches(ctx, &users, 2)
count, err := gorm.G[User](tx).Where("name like ?", "TestGenericsTransaction%").Count(ctx, "*")
if err != nil {
t.Fatalf("Count failed: %v", err)
}
if count != 2 {
t.Errorf("expected 2 records, got %d", count)
}
if err := tx.Rollback().Error; err != nil {
t.Fatalf("failed to rollback transaction: %v", err)
}
count2, err := gorm.G[User](DB).Where("name like ?", "TestGenericsTransaction%").Count(ctx, "*")
if err != nil {
t.Fatalf("Count failed: %v", err)
}
if count2 != 0 {
t.Errorf("expected 0 records after rollback, got %d", count2)
}
}
func TestGenericsToSQL(t *testing.T) {
ctx := context.Background()
sql := DB.ToSQL(func(tx *gorm.DB) *gorm.DB {
gorm.G[User](tx).Limit(10).Find(ctx)
return tx
})
if !regexp.MustCompile("SELECT \\* FROM .users..* 10").MatchString(sql) {
t.Errorf("ToSQL: got wrong sql with Generics API %v", sql)
}
}
func TestGenericsScanUUID(t *testing.T) {
ctx := context.Background()
users := []User{
{Name: uuid.NewString(), Age: 21},
{Name: uuid.NewString(), Age: 22},
{Name: uuid.NewString(), Age: 23},
}
if err := gorm.G[User](DB).CreateInBatches(ctx, &users, 2); err != nil {
t.Fatalf("CreateInBatches failed: %v", err)
}
userIds := []uuid.UUID{}
if err := gorm.G[User](DB).Select("name").Where("id in ?", []uint{users[0].ID, users[1].ID, users[2].ID}).Order("age").Scan(ctx, &userIds); err != nil || len(users) != 3 {
t.Fatalf("Scan failed: %v, userids %v", err, userIds)
}
if userIds[0].String() != users[0].Name || userIds[1].String() != users[1].Name || userIds[2].String() != users[2].Name {
t.Fatalf("wrong uuid scanned")
}
}

View File

@ -1,40 +1,30 @@
module gorm.io/gorm/tests
go 1.23.0
go 1.18
require (
github.com/google/uuid v1.6.0
github.com/google/uuid v1.3.1
github.com/jinzhu/now v1.1.5
github.com/lib/pq v1.10.9
github.com/stretchr/testify v1.10.0
gorm.io/driver/gaussdb v0.1.0
gorm.io/driver/mysql v1.6.0
gorm.io/driver/postgres v1.6.0
gorm.io/driver/sqlite v1.6.0
gorm.io/driver/sqlserver v1.6.1
gorm.io/gorm v1.30.0
gorm.io/driver/mysql v1.5.2-0.20230612053416-48b6526a21f0
gorm.io/driver/postgres v1.5.3-0.20230607070428-18bc84b75196
gorm.io/driver/sqlite v1.5.3
gorm.io/driver/sqlserver v1.5.2-0.20230613072041-6e2cde390b0a
gorm.io/gorm v1.25.4
)
require (
filippo.io/edwards25519 v1.1.0 // indirect
github.com/HuaweiCloudDeveloper/gaussdb-go v1.0.0-rc1 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/go-sql-driver/mysql v1.9.3 // indirect
github.com/go-sql-driver/mysql v1.7.1 // indirect
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect
github.com/golang-sql/sqlexp v0.1.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
github.com/jackc/pgx/v5 v5.7.5 // indirect
github.com/jackc/puddle/v2 v2.2.2 // indirect
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
github.com/jackc/pgx/v5 v5.4.3 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/mattn/go-sqlite3 v1.14.28 // indirect
github.com/microsoft/go-mssqldb v1.9.2 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/tjfoc/gmsm v1.4.1 // indirect
golang.org/x/crypto v0.40.0 // indirect
golang.org/x/sync v0.16.0 // indirect
golang.org/x/text v0.27.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
github.com/mattn/go-sqlite3 v1.14.17 // indirect
github.com/microsoft/go-mssqldb v1.5.0 // indirect
golang.org/x/crypto v0.12.0 // indirect
golang.org/x/text v0.12.0 // indirect
)
replace gorm.io/gorm => ../

View File

@ -23,7 +23,6 @@ type Config struct {
Languages int
Friends int
NamedPet bool
Tools int
}
func GetUser(name string, config Config) *User {
@ -48,10 +47,6 @@ func GetUser(name string, config Config) *User {
user.Toys = append(user.Toys, Toy{Name: name + "_toy_" + strconv.Itoa(i+1)})
}
for i := 0; i < config.Tools; i++ {
user.Tools = append(user.Tools, Tools{Name: name + "_tool_" + strconv.Itoa(i+1)})
}
if config.Company {
user.Company = Company{Name: "company-" + name}
}
@ -123,13 +118,11 @@ func doCheckUser(t *testing.T, user User, expect User, unscoped bool) {
if err := db(unscoped).Where("id = ?", user.ID).First(&newUser).Error; err != nil {
t.Fatalf("errors happened when query: %v", err)
} else {
AssertObjEqual(t, newUser, user, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday",
"CompanyID", "ManagerID", "Active")
AssertObjEqual(t, newUser, user, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active")
}
}
AssertObjEqual(t, user, expect, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID",
"ManagerID", "Active")
AssertObjEqual(t, user, expect, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active")
t.Run("Account", func(t *testing.T) {
AssertObjEqual(t, user.Account, expect.Account, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "UserID", "Number")
@ -140,8 +133,7 @@ func doCheckUser(t *testing.T, user User, expect User, unscoped bool) {
} else {
var account Account
db(unscoped).First(&account, "user_id = ?", user.ID)
AssertObjEqual(t, account, user.Account, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "UserID",
"Number")
AssertObjEqual(t, account, user.Account, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "UserID", "Number")
}
}
})
@ -201,10 +193,8 @@ func doCheckUser(t *testing.T, user User, expect User, unscoped bool) {
} else {
var manager User
db(unscoped).First(&manager, "id = ?", *user.ManagerID)
AssertObjEqual(t, manager, user.Manager, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age",
"Birthday", "CompanyID", "ManagerID", "Active")
AssertObjEqual(t, manager, expect.Manager, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age",
"Birthday", "CompanyID", "ManagerID", "Active")
AssertObjEqual(t, manager, user.Manager, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active")
AssertObjEqual(t, manager, expect.Manager, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active")
}
} else if user.ManagerID != nil {
t.Errorf("Manager should not be created for zero value, got: %+v", user.ManagerID)
@ -225,8 +215,7 @@ func doCheckUser(t *testing.T, user User, expect User, unscoped bool) {
})
for idx, team := range user.Team {
AssertObjEqual(t, team, expect.Team[idx], "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age",
"Birthday", "CompanyID", "ManagerID", "Active")
AssertObjEqual(t, team, expect.Team[idx], "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active")
}
})
@ -261,8 +250,7 @@ func doCheckUser(t *testing.T, user User, expect User, unscoped bool) {
})
for idx, friend := range user.Friends {
AssertObjEqual(t, friend, expect.Friends[idx], "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age",
"Birthday", "CompanyID", "ManagerID", "Active")
AssertObjEqual(t, friend, expect.Friends[idx], "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active")
}
})
}
@ -281,10 +269,6 @@ func isMysql() bool {
return os.Getenv("GORM_DIALECT") == "mysql"
}
func isSqlite() bool {
return os.Getenv("GORM_DIALECT") == "sqlite"
}
func db(unscoped bool) *gorm.DB {
if unscoped {
return DB.Unscoped()

View File

@ -2,8 +2,6 @@ package tests_test
import (
"errors"
"log"
"os"
"reflect"
"strings"
"testing"
@ -568,44 +566,3 @@ func TestUpdateCallbacks(t *testing.T) {
t.Fatalf("before update should not be called")
}
}
type Product6 struct {
gorm.Model
Name string
Item *ProductItem2
}
type ProductItem2 struct {
gorm.Model
Product6ID uint
}
func (p *Product6) BeforeDelete(tx *gorm.DB) error {
if err := tx.Delete(&p.Item).Error; err != nil {
return err
}
return nil
}
func TestPropagateUnscoped(t *testing.T) {
_DB, err := OpenTestConnection(&gorm.Config{
PropagateUnscoped: true,
})
if err != nil {
log.Printf("failed to connect database, got error %v", err)
os.Exit(1)
}
_DB.Migrator().DropTable(&Product6{}, &ProductItem2{})
_DB.AutoMigrate(&Product6{}, &ProductItem2{})
p := Product6{
Name: "unique_code",
Item: &ProductItem2{},
}
_DB.Model(&Product6{}).Create(&p)
if err := _DB.Unscoped().Delete(&p).Error; err != nil {
t.Fatalf("unscoped did not propagate")
}
}

View File

@ -1,12 +1,10 @@
package tests_test
import (
"fmt"
"regexp"
"sort"
"testing"
"github.com/stretchr/testify/assert"
"gorm.io/gorm"
. "gorm.io/gorm/utils/tests"
)
@ -186,12 +184,14 @@ func TestJoinCount(t *testing.T) {
DB.Create(&user)
query := DB.Model(&User{}).Joins("Company")
// Bug happens when .Count is called on a query.
// Removing the below two lines or downgrading to gorm v1.20.12 will make this test pass.
var total int64
query.Count(&total)
var result User
// Incorrectly generates a 'SELECT *' query which causes companies.id to overwrite users.id
if err := query.First(&result, user.ID).Error; err != nil {
t.Fatalf("Failed, got error: %v", err)
}
@ -199,10 +199,6 @@ func TestJoinCount(t *testing.T) {
if result.ID != user.ID {
t.Fatalf("result's id, %d, doesn't match user's id, %d", result.ID, user.ID)
}
// should find company
if result.Company.ID != *user.CompanyID {
t.Fatalf("result's id, %d, doesn't match user's company id, %d", result.Company.ID, *user.CompanyID)
}
}
func TestJoinWithSoftDeleted(t *testing.T) {
@ -404,75 +400,3 @@ func TestNestedJoins(t *testing.T) {
CheckPet(t, *user.Manager.NamedPet, *users2[idx].Manager.NamedPet)
}
}
func TestJoinsPreload_Issue7013(t *testing.T) {
manager := &User{Name: "Manager"}
DB.Create(manager)
var userIDs []uint
for i := 0; i < 21; i++ {
user := &User{Name: fmt.Sprintf("User%d", i), ManagerID: &manager.ID}
DB.Create(user)
userIDs = append(userIDs, user.ID)
}
var entries []User
assert.NotPanics(t, func() {
assert.NoError(t,
DB.Preload("Manager.Team").
Joins("Manager.Company").
Find(&entries).Error)
})
}
func TestJoinsPreload_Issue7013_RelationEmpty(t *testing.T) {
type (
Furniture struct {
gorm.Model
OwnerID *uint
}
Owner struct {
gorm.Model
Furnitures []Furniture
CompanyID *uint
Company Company
}
Building struct {
gorm.Model
Name string
OwnerID *uint
Owner Owner
}
)
DB.Migrator().DropTable(&Building{}, &Owner{}, &Furniture{})
DB.Migrator().AutoMigrate(&Building{}, &Owner{}, &Furniture{})
home := &Building{Name: "relation_empty"}
DB.Create(home)
var entries []Building
assert.NotPanics(t, func() {
assert.NoError(t,
DB.Preload("Owner.Furnitures").
Joins("Owner.Company").
Find(&entries).Error)
})
AssertEqual(t, entries, []Building{{Model: home.Model, Name: "relation_empty", Owner: Owner{Company: Company{}}}})
}
func TestJoinsPreload_Issue7013_NoEntries(t *testing.T) {
var entries []User
assert.NotPanics(t, func() {
assert.NoError(t,
DB.Preload("Manager.Team").
Joins("Manager.Company").
Where("1 <> 1").
Find(&entries).Error)
})
AssertEqual(t, len(entries), 0)
}

View File

@ -1,529 +0,0 @@
package tests_test
import (
"crypto/rand"
"fmt"
"gorm.io/gorm/internal/lru"
"math"
"math/big"
"reflect"
"sync"
"testing"
"time"
)
func TestLRU_Add_ExistingKey_UpdatesValueAndExpiresAt(t *testing.T) {
lru := lru.NewLRU[string, int](10, nil, time.Hour)
lru.Add("key1", 1)
lru.Add("key1", 2)
if value, ok := lru.Get("key1"); !ok || value != 2 {
t.Errorf("Expected value to be updated to 2, got %v", value)
}
}
func TestLRU_Add_NewKey_AddsEntry(t *testing.T) {
lru := lru.NewLRU[string, int](10, nil, time.Hour)
lru.Add("key1", 1)
if value, ok := lru.Get("key1"); !ok || value != 1 {
t.Errorf("Expected key1 to be added with value 1, got %v", value)
}
}
func TestLRU_Add_ExceedsSize_RemovesOldest(t *testing.T) {
lru := lru.NewLRU[string, int](2, nil, time.Hour)
lru.Add("key1", 1)
lru.Add("key2", 2)
lru.Add("key3", 3)
if _, ok := lru.Get("key1"); ok {
t.Errorf("Expected key1 to be removed, but it still exists")
}
}
func TestLRU_Add_UnlimitedSize_NoEviction(t *testing.T) {
lru := lru.NewLRU[string, int](0, nil, time.Hour)
lru.Add("key1", 1)
lru.Add("key2", 2)
lru.Add("key3", 3)
if _, ok := lru.Get("key1"); !ok {
t.Errorf("Expected key1 to exist, but it was evicted")
}
}
func TestLRU_Add_Eviction(t *testing.T) {
lru := lru.NewLRU[string, int](0, nil, time.Second*2)
lru.Add("key1", 1)
lru.Add("key2", 2)
lru.Add("key3", 3)
time.Sleep(time.Second * 3)
if lru.Cap() != 0 {
t.Errorf("Expected lru to be empty, but it was not")
}
}
func BenchmarkLRU_Rand_NoExpire(b *testing.B) {
l := lru.NewLRU[int64, int64](8192, nil, 0)
trace := make([]int64, b.N*2)
for i := 0; i < b.N*2; i++ {
trace[i] = getRand(b) % 32768
}
b.ResetTimer()
var hit, miss int
for i := 0; i < 2*b.N; i++ {
if i%2 == 0 {
l.Add(trace[i], trace[i])
} else {
if _, ok := l.Get(trace[i]); ok {
hit++
} else {
miss++
}
}
}
b.Logf("hit: %d miss: %d ratio: %f", hit, miss, float64(hit)/float64(hit+miss))
}
func BenchmarkLRU_Freq_NoExpire(b *testing.B) {
l := lru.NewLRU[int64, int64](8192, nil, 0)
trace := make([]int64, b.N*2)
for i := 0; i < b.N*2; i++ {
if i%2 == 0 {
trace[i] = getRand(b) % 16384
} else {
trace[i] = getRand(b) % 32768
}
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
l.Add(trace[i], trace[i])
}
var hit, miss int
for i := 0; i < b.N; i++ {
if _, ok := l.Get(trace[i]); ok {
hit++
} else {
miss++
}
}
b.Logf("hit: %d miss: %d ratio: %f", hit, miss, float64(hit)/float64(hit+miss))
}
func BenchmarkLRU_Rand_WithExpire(b *testing.B) {
l := lru.NewLRU[int64, int64](8192, nil, time.Millisecond*10)
trace := make([]int64, b.N*2)
for i := 0; i < b.N*2; i++ {
trace[i] = getRand(b) % 32768
}
b.ResetTimer()
var hit, miss int
for i := 0; i < 2*b.N; i++ {
if i%2 == 0 {
l.Add(trace[i], trace[i])
} else {
if _, ok := l.Get(trace[i]); ok {
hit++
} else {
miss++
}
}
}
b.Logf("hit: %d miss: %d ratio: %f", hit, miss, float64(hit)/float64(hit+miss))
}
func BenchmarkLRU_Freq_WithExpire(b *testing.B) {
l := lru.NewLRU[int64, int64](8192, nil, time.Millisecond*10)
trace := make([]int64, b.N*2)
for i := 0; i < b.N*2; i++ {
if i%2 == 0 {
trace[i] = getRand(b) % 16384
} else {
trace[i] = getRand(b) % 32768
}
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
l.Add(trace[i], trace[i])
}
var hit, miss int
for i := 0; i < b.N; i++ {
if _, ok := l.Get(trace[i]); ok {
hit++
} else {
miss++
}
}
b.Logf("hit: %d miss: %d ratio: %f", hit, miss, float64(hit)/float64(hit+miss))
}
func TestLRUNoPurge(t *testing.T) {
lc := lru.NewLRU[string, string](10, nil, 0)
lc.Add("key1", "val1")
if lc.Len() != 1 {
t.Fatalf("length differs from expected")
}
v, ok := lc.Peek("key1")
if v != "val1" {
t.Fatalf("value differs from expected")
}
if !ok {
t.Fatalf("should be true")
}
if !lc.Contains("key1") {
t.Fatalf("should contain key1")
}
if lc.Contains("key2") {
t.Fatalf("should not contain key2")
}
v, ok = lc.Peek("key2")
if v != "" {
t.Fatalf("should be empty")
}
if ok {
t.Fatalf("should be false")
}
if !reflect.DeepEqual(lc.Keys(), []string{"key1"}) {
t.Fatalf("value differs from expected")
}
if lc.Resize(0) != 0 {
t.Fatalf("evicted count differs from expected")
}
if lc.Resize(2) != 0 {
t.Fatalf("evicted count differs from expected")
}
lc.Add("key2", "val2")
if lc.Resize(1) != 1 {
t.Fatalf("evicted count differs from expected")
}
}
func TestLRUEdgeCases(t *testing.T) {
lc := lru.NewLRU[string, *string](2, nil, 0)
// Adding a nil value
lc.Add("key1", nil)
value, exists := lc.Get("key1")
if value != nil || !exists {
t.Fatalf("unexpected value or existence flag for key1: value=%v, exists=%v", value, exists)
}
// Adding an entry with the same key but different value
newVal := "val1"
lc.Add("key1", &newVal)
value, exists = lc.Get("key1")
if value != &newVal || !exists {
t.Fatalf("unexpected value or existence flag for key1: value=%v, exists=%v", value, exists)
}
}
func TestLRU_Values(t *testing.T) {
lc := lru.NewLRU[string, string](3, nil, 0)
lc.Add("key1", "val1")
lc.Add("key2", "val2")
lc.Add("key3", "val3")
values := lc.Values()
if !reflect.DeepEqual(values, []string{"val1", "val2", "val3"}) {
t.Fatalf("values differs from expected")
}
}
// func TestExpirableMultipleClose(_ *testing.T) {
// lc :=lru.NewLRU[string, string](10, nil, 0)
// lc.Close()
// // should not panic
// lc.Close()
// }
func TestLRUWithPurge(t *testing.T) {
var evicted []string
lc := lru.NewLRU(10, func(key string, value string) { evicted = append(evicted, key, value) }, 150*time.Millisecond)
k, v, ok := lc.GetOldest()
if k != "" {
t.Fatalf("should be empty")
}
if v != "" {
t.Fatalf("should be empty")
}
if ok {
t.Fatalf("should be false")
}
lc.Add("key1", "val1")
time.Sleep(100 * time.Millisecond) // not enough to expire
if lc.Len() != 1 {
t.Fatalf("length differs from expected")
}
v, ok = lc.Get("key1")
if v != "val1" {
t.Fatalf("value differs from expected")
}
if !ok {
t.Fatalf("should be true")
}
time.Sleep(200 * time.Millisecond) // expire
v, ok = lc.Get("key1")
if ok {
t.Fatalf("should be false")
}
if v != "" {
t.Fatalf("should be nil")
}
if lc.Len() != 0 {
t.Fatalf("length differs from expected")
}
if !reflect.DeepEqual(evicted, []string{"key1", "val1"}) {
t.Fatalf("value differs from expected")
}
// add new entry
lc.Add("key2", "val2")
if lc.Len() != 1 {
t.Fatalf("length differs from expected")
}
k, v, ok = lc.GetOldest()
if k != "key2" {
t.Fatalf("value differs from expected")
}
if v != "val2" {
t.Fatalf("value differs from expected")
}
if !ok {
t.Fatalf("should be true")
}
}
func TestLRUWithPurgeEnforcedBySize(t *testing.T) {
lc := lru.NewLRU[string, string](10, nil, time.Hour)
for i := 0; i < 100; i++ {
i := i
lc.Add(fmt.Sprintf("key%d", i), fmt.Sprintf("val%d", i))
v, ok := lc.Get(fmt.Sprintf("key%d", i))
if v != fmt.Sprintf("val%d", i) {
t.Fatalf("value differs from expected")
}
if !ok {
t.Fatalf("should be true")
}
if lc.Len() > 20 {
t.Fatalf("length should be less than 20")
}
}
if lc.Len() != 10 {
t.Fatalf("length differs from expected")
}
}
func TestLRUConcurrency(t *testing.T) {
lc := lru.NewLRU[string, string](0, nil, 0)
wg := sync.WaitGroup{}
wg.Add(1000)
for i := 0; i < 1000; i++ {
go func(i int) {
lc.Add(fmt.Sprintf("key-%d", i/10), fmt.Sprintf("val-%d", i/10))
wg.Done()
}(i)
}
wg.Wait()
if lc.Len() != 100 {
t.Fatalf("length differs from expected")
}
}
func TestLRUInvalidateAndEvict(t *testing.T) {
var evicted int
lc := lru.NewLRU(-1, func(_, _ string) { evicted++ }, 0)
lc.Add("key1", "val1")
lc.Add("key2", "val2")
val, ok := lc.Get("key1")
if !ok {
t.Fatalf("should be true")
}
if val != "val1" {
t.Fatalf("value differs from expected")
}
if evicted != 0 {
t.Fatalf("value differs from expected")
}
lc.Remove("key1")
if evicted != 1 {
t.Fatalf("value differs from expected")
}
val, ok = lc.Get("key1")
if val != "" {
t.Fatalf("should be empty")
}
if ok {
t.Fatalf("should be false")
}
}
func TestLoadingExpired(t *testing.T) {
lc := lru.NewLRU[string, string](0, nil, time.Millisecond*5)
lc.Add("key1", "val1")
if lc.Len() != 1 {
t.Fatalf("length differs from expected")
}
v, ok := lc.Peek("key1")
if v != "val1" {
t.Fatalf("value differs from expected")
}
if !ok {
t.Fatalf("should be true")
}
v, ok = lc.Get("key1")
if v != "val1" {
t.Fatalf("value differs from expected")
}
if !ok {
t.Fatalf("should be true")
}
for {
result, ok := lc.Get("key1")
if ok && result == "" {
t.Fatalf("ok should return a result")
}
if !ok {
break
}
}
time.Sleep(time.Millisecond * 100) // wait for expiration reaper
if lc.Len() != 0 {
t.Fatalf("length differs from expected")
}
v, ok = lc.Peek("key1")
if v != "" {
t.Fatalf("should be empty")
}
if ok {
t.Fatalf("should be false")
}
v, ok = lc.Get("key1")
if v != "" {
t.Fatalf("should be empty")
}
if ok {
t.Fatalf("should be false")
}
}
func TestLRURemoveOldest(t *testing.T) {
lc := lru.NewLRU[string, string](2, nil, 0)
if lc.Cap() != 2 {
t.Fatalf("expect cap is 2")
}
k, v, ok := lc.RemoveOldest()
if k != "" {
t.Fatalf("should be empty")
}
if v != "" {
t.Fatalf("should be empty")
}
if ok {
t.Fatalf("should be false")
}
ok = lc.Remove("non_existent")
if ok {
t.Fatalf("should be false")
}
lc.Add("key1", "val1")
if lc.Len() != 1 {
t.Fatalf("length differs from expected")
}
v, ok = lc.Get("key1")
if !ok {
t.Fatalf("should be true")
}
if v != "val1" {
t.Fatalf("value differs from expected")
}
if !reflect.DeepEqual(lc.Keys(), []string{"key1"}) {
t.Fatalf("value differs from expected")
}
if lc.Len() != 1 {
t.Fatalf("length differs from expected")
}
lc.Add("key2", "val2")
if !reflect.DeepEqual(lc.Keys(), []string{"key1", "key2"}) {
t.Fatalf("value differs from expected")
}
if lc.Len() != 2 {
t.Fatalf("length differs from expected")
}
k, v, ok = lc.RemoveOldest()
if k != "key1" {
t.Fatalf("value differs from expected")
}
if v != "val1" {
t.Fatalf("value differs from expected")
}
if !ok {
t.Fatalf("should be true")
}
if !reflect.DeepEqual(lc.Keys(), []string{"key2"}) {
t.Fatalf("value differs from expected")
}
if lc.Len() != 1 {
t.Fatalf("length differs from expected")
}
}
func getRand(tb testing.TB) int64 {
out, err := rand.Int(rand.Reader, big.NewInt(math.MaxInt64))
if err != nil {
tb.Fatal(err)
}
return out.Int64()
}

View File

@ -2,29 +2,23 @@ package tests_test
import (
"context"
"database/sql"
"fmt"
"math/rand"
"os"
"reflect"
"strconv"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"gorm.io/driver/gaussdb"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/migrator"
"gorm.io/gorm/logger"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils"
. "gorm.io/gorm/utils/tests"
)
func TestMigrate(t *testing.T) {
allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Tools{}}
allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}}
rand.Seed(time.Now().UnixNano())
rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] })
DB.Migrator().DropTable("user_speaks", "user_friends", "ccc")
@ -40,7 +34,7 @@ func TestMigrate(t *testing.T) {
if tables, err := DB.Migrator().GetTables(); err != nil {
t.Fatalf("Failed to get database all tables, but got error %v", err)
} else {
for _, t1 := range []string{"users", "accounts", "pets", "companies", "toys", "languages", "tools"} {
for _, t1 := range []string{"users", "accounts", "pets", "companies", "toys", "languages"} {
hasTable := false
for _, t2 := range tables {
if t2 == t1 {
@ -83,8 +77,8 @@ func TestMigrate(t *testing.T) {
}
}
func TestAutoMigrateInt8PGAndGaussDB(t *testing.T) {
if DB.Dialector.Name() != "postgres" && DB.Dialector.Name() != "gaussdb" {
func TestAutoMigrateInt8PG(t *testing.T) {
if DB.Dialector.Name() != "postgres" {
return
}
@ -99,8 +93,7 @@ func TestAutoMigrateInt8PGAndGaussDB(t *testing.T) {
Test: func(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) {
sql, _ := fc()
if strings.HasPrefix(sql, "ALTER TABLE \"migrate_ints\" ALTER COLUMN \"int8\" TYPE smallint") {
t.Fatalf("shouldn't execute ALTER COLUMN TYPE if such type is already existed in DB schema: sql: %s",
sql)
t.Fatalf("shouldn't execute ALTER COLUMN TYPE if such type is already existed in DB schema: sql: %s", sql)
}
},
}
@ -140,137 +133,8 @@ func TestAutoMigrateSelfReferential(t *testing.T) {
}
}
func TestAutoMigrateNullable(t *testing.T) {
type MigrateNullableColumn struct {
ID uint
Bonus float64 `gorm:"not null"`
Stock float64
}
DB.Migrator().DropTable(&MigrateNullableColumn{})
DB.AutoMigrate(&MigrateNullableColumn{})
type MigrateNullableColumn2 struct {
ID uint
Bonus float64
Stock float64 `gorm:"not null"`
}
if err := DB.Table("migrate_nullable_columns").AutoMigrate(&MigrateNullableColumn2{}); err != nil {
t.Fatalf("failed to auto migrate, got error: %v", err)
}
columnTypes, err := DB.Table("migrate_nullable_columns").Migrator().ColumnTypes(&MigrateNullableColumn{})
if err != nil {
t.Fatalf("failed to get column types, got error: %v", err)
}
for _, columnType := range columnTypes {
switch columnType.Name() {
case "bonus":
// allow to change non-nullable to nullable
if nullable, _ := columnType.Nullable(); !nullable {
t.Fatalf("bonus's nullable should be true, bug got %t", nullable)
}
case "stock":
// do not allow to change nullable to non-nullable
if nullable, _ := columnType.Nullable(); !nullable {
t.Fatalf("stock's nullable should be true, bug got %t", nullable)
}
}
}
}
func TestSmartMigrateColumn(t *testing.T) {
fullSupported := map[string]bool{"mysql": true, "postgres": true, "gaussdb": true}[DB.Dialector.Name()]
type UserMigrateColumn struct {
ID uint
Name string
Salary float64
Birthday time.Time `gorm:"precision:4"`
}
DB.Migrator().DropTable(&UserMigrateColumn{})
DB.AutoMigrate(&UserMigrateColumn{})
type UserMigrateColumn2 struct {
ID uint
Name string `gorm:"size:128"`
Salary float64 `gorm:"precision:2"`
Birthday time.Time `gorm:"precision:2"`
NameIgnoreMigration string `gorm:"size:100"`
}
if err := DB.Table("user_migrate_columns").AutoMigrate(&UserMigrateColumn2{}); err != nil {
t.Fatalf("failed to auto migrate, got error: %v", err)
}
columnTypes, err := DB.Table("user_migrate_columns").Migrator().ColumnTypes(&UserMigrateColumn{})
if err != nil {
t.Fatalf("failed to get column types, got error: %v", err)
}
for _, columnType := range columnTypes {
switch columnType.Name() {
case "name":
if length, _ := columnType.Length(); (fullSupported || length != 0) && length != 128 {
t.Fatalf("name's length should be 128, but got %v", length)
}
case "salary":
if precision, o, _ := columnType.DecimalSize(); (fullSupported || precision != 0) && precision != 2 {
t.Fatalf("salary's precision should be 2, but got %v %v", precision, o)
}
case "birthday":
if precision, _, _ := columnType.DecimalSize(); (fullSupported || precision != 0) && precision != 2 {
t.Fatalf("birthday's precision should be 2, but got %v", precision)
}
}
}
type UserMigrateColumn3 struct {
ID uint
Name string `gorm:"size:256"`
Salary float64 `gorm:"precision:3"`
Birthday time.Time `gorm:"precision:3"`
NameIgnoreMigration string `gorm:"size:128;-:migration"`
}
if err := DB.Table("user_migrate_columns").AutoMigrate(&UserMigrateColumn3{}); err != nil {
t.Fatalf("failed to auto migrate, got error: %v", err)
}
columnTypes, err = DB.Table("user_migrate_columns").Migrator().ColumnTypes(&UserMigrateColumn{})
if err != nil {
t.Fatalf("failed to get column types, got error: %v", err)
}
for _, columnType := range columnTypes {
switch columnType.Name() {
case "name":
if length, _ := columnType.Length(); (fullSupported || length != 0) && length != 256 {
t.Fatalf("name's length should be 128, but got %v", length)
}
case "salary":
if precision, _, _ := columnType.DecimalSize(); (fullSupported || precision != 0) && precision != 3 {
t.Fatalf("salary's precision should be 2, but got %v", precision)
}
case "birthday":
if precision, _, _ := columnType.DecimalSize(); (fullSupported || precision != 0) && precision != 3 {
t.Fatalf("birthday's precision should be 2, but got %v", precision)
}
case "name_ignore_migration":
if length, _ := columnType.Length(); (fullSupported || length != 0) && length != 100 {
t.Fatalf("name_ignore_migration's length should still be 100 but got %v", length)
}
}
}
}
func TestSmartMigrateColumnGaussDB(t *testing.T) {
fullSupported := map[string]bool{"mysql": true, "gaussdb": true}[DB.Dialector.Name()]
fullSupported := map[string]bool{"mysql": true, "postgres": true}[DB.Dialector.Name()]
type UserMigrateColumn struct {
ID uint
@ -568,50 +432,40 @@ func TestTiDBMigrateColumns(t *testing.T) {
switch columnType.Name() {
case "id":
if v, ok := columnType.PrimaryKey(); !ok || !v {
t.Fatalf("column id primary key should be correct, name: %v, column: %#v", columnType.Name(),
columnType)
t.Fatalf("column id primary key should be correct, name: %v, column: %#v", columnType.Name(), columnType)
}
case "name":
dataType := DB.Dialector.DataTypeOf(stmt.Schema.LookUpField(columnType.Name()))
if !strings.Contains(strings.ToUpper(dataType), strings.ToUpper(columnType.DatabaseTypeName())) {
t.Fatalf("column name type should be correct, name: %v, length: %v, expects: %v, column: %#v",
columnType.Name(), columnType.DatabaseTypeName(), dataType, columnType)
t.Fatalf("column name type should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), columnType.DatabaseTypeName(), dataType, columnType)
}
if length, ok := columnType.Length(); !ok || length != 100 {
t.Fatalf("column name length should be correct, name: %v, length: %v, expects: %v, column: %#v",
columnType.Name(), length, 100, columnType)
t.Fatalf("column name length should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), length, 100, columnType)
}
case "age":
if v, ok := columnType.DefaultValue(); !ok || v != "18" {
t.Fatalf("column age default value should be correct, name: %v, column: %#v", columnType.Name(),
columnType)
t.Fatalf("column age default value should be correct, name: %v, column: %#v", columnType.Name(), columnType)
}
if v, ok := columnType.Comment(); !ok || v != "my age" {
t.Fatalf("column age comment should be correct, name: %v, column: %#v", columnType.Name(),
columnType)
t.Fatalf("column age comment should be correct, name: %v, column: %#v", columnType.Name(), columnType)
}
case "code":
if v, ok := columnType.Unique(); !ok || !v {
t.Fatalf("column code unique should be correct, name: %v, column: %#v", columnType.Name(),
columnType)
t.Fatalf("column code unique should be correct, name: %v, column: %#v", columnType.Name(), columnType)
}
if v, ok := columnType.DefaultValue(); !ok || v != "hello" {
t.Fatalf("column code default value should be correct, name: %v, column: %#v, default value: %v",
columnType.Name(), columnType, v)
t.Fatalf("column code default value should be correct, name: %v, column: %#v, default value: %v", columnType.Name(), columnType, v)
}
if v, ok := columnType.Comment(); !ok || v != "my code2" {
t.Fatalf("column code comment should be correct, name: %v, column: %#v", columnType.Name(),
columnType)
t.Fatalf("column code comment should be correct, name: %v, column: %#v", columnType.Name(), columnType)
}
case "code2":
// Code2 string `gorm:"comment:my code2;default:hello"`
if v, ok := columnType.DefaultValue(); !ok || v != "hello" {
t.Fatalf("column code default value should be correct, name: %v, column: %#v, default value: %v",
columnType.Name(), columnType, v)
t.Fatalf("column code default value should be correct, name: %v, column: %#v, default value: %v", columnType.Name(), columnType, v)
}
if v, ok := columnType.Comment(); !ok || v != "my code2" {
t.Fatalf("column code comment should be correct, name: %v, column: %#v", columnType.Name(),
columnType)
t.Fatalf("column code comment should be correct, name: %v, column: %#v", columnType.Name(), columnType)
}
}
}
@ -643,8 +497,7 @@ func TestTiDBMigrateColumns(t *testing.T) {
t.Fatalf("Failed to add column, got %v", err)
}
if err := DB.Table("column_structs").Migrator().RenameColumn(&NewColumnStruct{}, "NewName",
"new_new_name"); err != nil {
if err := DB.Table("column_structs").Migrator().RenameColumn(&NewColumnStruct{}, "NewName", "new_new_name"); err != nil {
t.Fatalf("Failed to add column, got %v", err)
}
@ -708,45 +561,36 @@ func TestMigrateColumns(t *testing.T) {
switch columnType.Name() {
case "id":
if v, ok := columnType.PrimaryKey(); !ok || !v {
t.Fatalf("column id primary key should be correct, name: %v, column: %#v", columnType.Name(),
columnType)
t.Fatalf("column id primary key should be correct, name: %v, column: %#v", columnType.Name(), columnType)
}
case "name":
dataType := DB.Dialector.DataTypeOf(stmt.Schema.LookUpField(columnType.Name()))
if !strings.Contains(strings.ToUpper(dataType), strings.ToUpper(columnType.DatabaseTypeName())) {
t.Fatalf("column name type should be correct, name: %v, length: %v, expects: %v, column: %#v",
columnType.Name(), columnType.DatabaseTypeName(), dataType, columnType)
t.Fatalf("column name type should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), columnType.DatabaseTypeName(), dataType, columnType)
}
if length, ok := columnType.Length(); !sqlite && (!ok || length != 100) {
t.Fatalf("column name length should be correct, name: %v, length: %v, expects: %v, column: %#v",
columnType.Name(), length, 100, columnType)
t.Fatalf("column name length should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), length, 100, columnType)
}
case "age":
if v, ok := columnType.DefaultValue(); !ok || v != "18" {
t.Fatalf("column age default value should be correct, name: %v, column: %#v", columnType.Name(),
columnType)
t.Fatalf("column age default value should be correct, name: %v, column: %#v", columnType.Name(), columnType)
}
if v, ok := columnType.Comment(); !sqlite && !sqlserver && (!ok || v != "my age") {
t.Fatalf("column age comment should be correct, name: %v, column: %#v", columnType.Name(),
columnType)
t.Fatalf("column age comment should be correct, name: %v, column: %#v", columnType.Name(), columnType)
}
case "code":
if v, ok := columnType.Unique(); !ok || !v {
t.Fatalf("column code unique should be correct, name: %v, column: %#v", columnType.Name(),
columnType)
t.Fatalf("column code unique should be correct, name: %v, column: %#v", columnType.Name(), columnType)
}
if v, ok := columnType.DefaultValue(); !sqlserver && (!ok || v != "hello") {
t.Fatalf("column code default value should be correct, name: %v, column: %#v, default value: %v",
columnType.Name(), columnType, v)
t.Fatalf("column code default value should be correct, name: %v, column: %#v, default value: %v", columnType.Name(), columnType, v)
}
if v, ok := columnType.Comment(); !sqlite && !sqlserver && (!ok || v != "my code2") {
t.Fatalf("column code comment should be correct, name: %v, column: %#v", columnType.Name(),
columnType)
t.Fatalf("column code comment should be correct, name: %v, column: %#v", columnType.Name(), columnType)
}
case "code2":
if v, ok := columnType.Unique(); !sqlserver && (!ok || !v) {
t.Fatalf("column code2 unique should be correct, name: %v, column: %#v", columnType.Name(),
columnType)
t.Fatalf("column code2 unique should be correct, name: %v, column: %#v", columnType.Name(), columnType)
}
case "code3":
// TODO
@ -783,8 +627,7 @@ func TestMigrateColumns(t *testing.T) {
t.Fatalf("Failed to add column, got %v", err)
}
if err := DB.Table("column_structs").Migrator().RenameColumn(&NewColumnStruct{}, "NewName",
"new_new_name"); err != nil {
if err := DB.Table("column_structs").Migrator().RenameColumn(&NewColumnStruct{}, "NewName", "new_new_name"); err != nil {
t.Fatalf("Failed to add column, got %v", err)
}
@ -938,7 +781,7 @@ func TestMigrateColumnOrder(t *testing.T) {
// https://github.com/go-gorm/gorm/issues/5047
func TestMigrateSerialColumn(t *testing.T) {
if DB.Dialector.Name() != "postgres" && DB.Dialector.Name() != "gaussdb" {
if DB.Dialector.Name() != "postgres" {
return
}
@ -1019,48 +862,6 @@ func TestMigrateWithSpecialName(t *testing.T) {
AssertEqual(t, true, DB.Migrator().HasTable("coupon_product_2"))
}
// https://github.com/go-gorm/gorm/issues/4760
func TestMigrateAutoIncrement(t *testing.T) {
type AutoIncrementStruct struct {
ID int64 `gorm:"primarykey;autoIncrement"`
Field1 uint32 `gorm:"column:field1"`
Field2 float32 `gorm:"column:field2"`
}
if err := DB.AutoMigrate(&AutoIncrementStruct{}); err != nil {
t.Fatalf("AutoMigrate err: %v", err)
}
const ROWS = 10
for idx := 0; idx < ROWS; idx++ {
if err := DB.Create(&AutoIncrementStruct{}).Error; err != nil {
t.Fatalf("create auto_increment_struct fail, err: %v", err)
}
}
rows := make([]*AutoIncrementStruct, 0, ROWS)
if err := DB.Order("id ASC").Find(&rows).Error; err != nil {
t.Fatalf("find auto_increment_struct fail, err: %v", err)
}
ids := make([]int64, 0, len(rows))
for _, row := range rows {
ids = append(ids, row.ID)
}
lastID := ids[len(ids)-1]
if err := DB.Where("id IN (?)", ids).Delete(&AutoIncrementStruct{}).Error; err != nil {
t.Fatalf("delete auto_increment_struct fail, err: %v", err)
}
newRow := &AutoIncrementStruct{}
if err := DB.Create(newRow).Error; err != nil {
t.Fatalf("create auto_increment_struct fail, err: %v", err)
}
AssertEqual(t, newRow.ID, lastID+1)
}
// https://github.com/go-gorm/gorm/issues/5320
func TestPrimarykeyID(t *testing.T) {
if DB.Dialector.Name() != "postgres" {
@ -1097,42 +898,6 @@ func TestPrimarykeyID(t *testing.T) {
}
}
func TestPrimarykeyIDGaussDB(t *testing.T) {
t.Skipf("This test case skipped, because of gaussdb not support uuid-ossp plugin (SQLSTATE 58P01)")
if DB.Dialector.Name() != "gaussdb" {
return
}
type MissPKLanguage struct {
ID string `gorm:"type:uuid;default:uuid_generate_v4()"`
Name string
}
type MissPKUser struct {
ID string `gorm:"type:uuid;default:uuid_generate_v4()"`
MissPKLanguages []MissPKLanguage `gorm:"many2many:miss_pk_user_languages;"`
}
var err error
err = DB.Migrator().DropTable(&MissPKUser{}, &MissPKLanguage{})
if err != nil {
t.Fatalf("DropTable err:%v", err)
}
// TODO: ERROR: could not open extension control file: No such file or directory (SQLSTATE 58P01)
DB.Exec(`CREATE EXTENSION IF NOT EXISTS "uuid-ossp";`)
err = DB.AutoMigrate(&MissPKUser{}, &MissPKLanguage{})
if err != nil {
t.Fatalf("AutoMigrate err:%v", err)
}
// patch
err = DB.AutoMigrate(&MissPKUser{}, &MissPKLanguage{})
if err != nil {
t.Fatalf("AutoMigrate err:%v", err)
}
}
func TestCurrentTimestamp(t *testing.T) {
if DB.Dialector.Name() != "mysql" {
return
@ -1155,8 +920,7 @@ func TestCurrentTimestamp(t *testing.T) {
if err != nil {
t.Fatalf("AutoMigrate err:%v", err)
}
AssertEqual(t, true, DB.Migrator().HasConstraint(&CurrentTimestampTest{}, "uni_current_timestamp_tests_time_at"))
AssertEqual(t, false, DB.Migrator().HasIndex(&CurrentTimestampTest{}, "time_at"))
AssertEqual(t, true, DB.Migrator().HasIndex(&CurrentTimestampTest{}, "time_at"))
AssertEqual(t, false, DB.Migrator().HasIndex(&CurrentTimestampTest{}, "time_at_2"))
}
@ -1218,8 +982,7 @@ func TestUniqueColumn(t *testing.T) {
}
// not trigger alert column
AssertEqual(t, true, DB.Migrator().HasConstraint(&UniqueTest{}, "uni_unique_tests_name"))
AssertEqual(t, false, DB.Migrator().HasIndex(&UniqueTest{}, "name"))
AssertEqual(t, true, DB.Migrator().HasIndex(&UniqueTest{}, "name"))
AssertEqual(t, false, DB.Migrator().HasIndex(&UniqueTest{}, "name_1"))
AssertEqual(t, false, DB.Migrator().HasIndex(&UniqueTest{}, "name_2"))
@ -1333,24 +1096,35 @@ func TestInvalidCachedPlanSimpleProtocol(t *testing.T) {
}
}
// TODO: ERROR: must have at least one column (SQLSTATE 0A000)
func TestInvalidCachedPlanSimpleProtocolGaussDB(t *testing.T) {
t.Skipf("This test case skipped, because of gaussdb not support creaing empty table(SQLSTATE 0A000)")
if DB.Dialector.Name() != "gaussdb" {
func TestInvalidCachedPlanPrepareStmt(t *testing.T) {
if DB.Dialector.Name() != "postgres" {
return
}
db, err := gorm.Open(gaussdb.Open(gaussdbDSN), &gorm.Config{})
db, err := gorm.Open(postgres.Open(postgresDSN), &gorm.Config{PrepareStmt: true})
if err != nil {
t.Errorf("Open err:%v", err)
}
if debug := os.Getenv("DEBUG"); debug == "true" {
db.Logger = db.Logger.LogMode(logger.Info)
} else if debug == "false" {
db.Logger = db.Logger.LogMode(logger.Silent)
}
type Object1 struct{}
type Object1 struct {
ID uint
}
type Object2 struct {
Field1 string
ID uint
Field1 int `gorm:"type:int8"`
}
type Object3 struct {
Field2 string
ID uint
Field1 int `gorm:"type:int4"`
}
type Object4 struct {
ID uint
Field2 int
}
db.Migrator().DropTable("objects")
@ -1358,16 +1132,63 @@ func TestInvalidCachedPlanSimpleProtocolGaussDB(t *testing.T) {
if err != nil {
t.Errorf("AutoMigrate err:%v", err)
}
err = db.Table("objects").Create(&Object1{}).Error
if err != nil {
t.Errorf("create err:%v", err)
}
// AddColumn
err = db.Table("objects").AutoMigrate(&Object2{})
if err != nil {
t.Errorf("AutoMigrate err:%v", err)
}
err = db.Table("objects").Take(&Object2{}).Error
if err != nil {
t.Errorf("take err:%v", err)
}
// AlterColumn
err = db.Table("objects").AutoMigrate(&Object3{})
if err != nil {
t.Errorf("AutoMigrate err:%v", err)
}
err = db.Table("objects").Take(&Object3{}).Error
if err != nil {
t.Errorf("take err:%v", err)
}
// AddColumn
err = db.Table("objects").AutoMigrate(&Object4{})
if err != nil {
t.Errorf("AutoMigrate err:%v", err)
}
err = db.Table("objects").Take(&Object4{}).Error
if err != nil {
t.Errorf("take err:%v", err)
}
db.Table("objects").Migrator().RenameColumn(&Object4{}, "field2", "field3")
if err != nil {
t.Errorf("RenameColumn err:%v", err)
}
err = db.Table("objects").Take(&Object4{}).Error
if err != nil {
t.Errorf("take err:%v", err)
}
db.Table("objects").Migrator().DropColumn(&Object4{}, "field3")
if err != nil {
t.Errorf("RenameColumn err:%v", err)
}
err = db.Table("objects").Take(&Object4{}).Error
if err != nil {
t.Errorf("take err:%v", err)
}
}
func TestDifferentTypeWithoutDeclaredLength(t *testing.T) {
@ -1410,7 +1231,7 @@ func TestDifferentTypeWithoutDeclaredLength(t *testing.T) {
}
func TestMigrateArrayTypeModel(t *testing.T) {
if DB.Dialector.Name() != "postgres" && DB.Dialector.Name() != "gaussdb" {
if DB.Dialector.Name() != "postgres" {
return
}
@ -1520,14 +1341,14 @@ func TestMigrateSameEmbeddedFieldName(t *testing.T) {
err = DB.Table("game_users").AutoMigrate(&GameUser1{})
AssertEqual(t, nil, err)
_, err = findColumnType(&GameUser{}, "stat_ab_ground_destroy_count")
_, err = findColumnType(&GameUser{}, "stat_ab_ground_destory_count")
AssertEqual(t, nil, err)
_, err = findColumnType(&GameUser{}, "rate_ground_rb_ground_destroy_count")
_, err = findColumnType(&GameUser{}, "rate_ground_rb_ground_destory_count")
AssertEqual(t, nil, err)
}
func TestMigrateWithDefaultValue(t *testing.T) {
func TestMigrateDefaultNullString(t *testing.T) {
if DB.Dialector.Name() == "sqlserver" {
// sqlserver driver treats NULL and 'NULL' the same
t.Skip("skip sqlserver")
@ -1541,7 +1362,6 @@ func TestMigrateWithDefaultValue(t *testing.T) {
type NullStringModel struct {
ID uint
Content string `gorm:"default:'null'"`
Active bool `gorm:"default:false"`
}
tableName := "null_string_model"
@ -1562,14 +1382,6 @@ func TestMigrateWithDefaultValue(t *testing.T) {
AssertEqual(t, defVal, "null")
AssertEqual(t, ok, true)
columnType2, err := findColumnType(tableName, "active")
AssertEqual(t, err, nil)
defVal, ok = columnType2.DefaultValue()
bv, _ := strconv.ParseBool(defVal)
AssertEqual(t, bv, false)
AssertEqual(t, ok, true)
// default 'null' -> 'null'
session := DB.Session(&gorm.Session{Logger: Tracer{
Logger: DB.Config.Logger,
@ -1701,8 +1513,7 @@ func TestMigrateIgnoreRelations(t *testing.T) {
func TestMigrateView(t *testing.T) {
DB.Save(GetUser("joins-args-db", Config{Pets: 2}))
if err := DB.Migrator().CreateView("invalid_users_pets",
gorm.ViewOption{Query: nil}); err != gorm.ErrSubQueryRequired {
if err := DB.Migrator().CreateView("invalid_users_pets", gorm.ViewOption{Query: nil}); err != gorm.ErrSubQueryRequired {
t.Fatalf("no view should be created, got %v", err)
}
@ -1732,8 +1543,8 @@ func TestMigrateView(t *testing.T) {
}
}
func TestMigrateExistingBoolColumnPGAndGaussDB(t *testing.T) {
if DB.Dialector.Name() != "postgres" && DB.Dialector.Name() != "gaussdb" {
func TestMigrateExistingBoolColumnPG(t *testing.T) {
if DB.Dialector.Name() != "postgres" {
return
}
@ -1771,20 +1582,17 @@ func TestMigrateExistingBoolColumnPGAndGaussDB(t *testing.T) {
switch columnType.Name() {
case "id":
if v, ok := columnType.PrimaryKey(); !ok || !v {
t.Fatalf("column id primary key should be correct, name: %v, column: %#v", columnType.Name(),
columnType)
t.Fatalf("column id primary key should be correct, name: %v, column: %#v", columnType.Name(), columnType)
}
case "string_bool":
dataType := DB.Dialector.DataTypeOf(stmt.Schema.LookUpField(columnType.Name()))
if !strings.Contains(strings.ToUpper(dataType), strings.ToUpper(columnType.DatabaseTypeName())) {
t.Fatalf("column name type should be correct, name: %v, length: %v, expects: %v, column: %#v",
columnType.Name(), columnType.DatabaseTypeName(), dataType, columnType)
t.Fatalf("column name type should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), columnType.DatabaseTypeName(), dataType, columnType)
}
case "smallint_bool":
dataType := DB.Dialector.DataTypeOf(stmt.Schema.LookUpField(columnType.Name()))
if !strings.Contains(strings.ToUpper(dataType), strings.ToUpper(columnType.DatabaseTypeName())) {
t.Fatalf("column name type should be correct, name: %v, length: %v, expects: %v, column: %#v",
columnType.Name(), columnType.DatabaseTypeName(), dataType, columnType)
t.Fatalf("column name type should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), columnType.DatabaseTypeName(), dataType, columnType)
}
}
}
@ -1809,8 +1617,7 @@ func TestTableType(t *testing.T) {
DB.Migrator().DropTable(&City{})
if err := DB.Set("gorm:table_options",
fmt.Sprintf("ENGINE InnoDB COMMENT '%s'", tblComment)).AutoMigrate(&City{}); err != nil {
if err := DB.Set("gorm:table_options", fmt.Sprintf("ENGINE InnoDB COMMENT '%s'", tblComment)).AutoMigrate(&City{}); err != nil {
t.Fatalf("failed to migrate cities tables, got error: %v", err)
}
@ -1836,329 +1643,3 @@ func TestTableType(t *testing.T) {
t.Fatalf("expected comment %s got %s", tblComment, comment)
}
}
func TestMigrateWithUniqueIndexAndUnique(t *testing.T) {
const table = "unique_struct"
checkField := func(model interface{}, fieldName string, unique bool, uniqueIndex string) {
stmt := &gorm.Statement{DB: DB}
err := stmt.Parse(model)
if err != nil {
t.Fatalf("%v: failed to parse schema, got error: %v", utils.FileWithLineNum(), err)
}
_ = stmt.Schema.ParseIndexes()
field := stmt.Schema.LookUpField(fieldName)
if field == nil {
t.Fatalf("%v: failed to find column %q", utils.FileWithLineNum(), fieldName)
}
if field.Unique != unique {
t.Fatalf("%v: %q column %q unique should be %v but got %v", utils.FileWithLineNum(), stmt.Schema.Table, fieldName, unique, field.Unique)
}
if field.UniqueIndex != uniqueIndex {
t.Fatalf("%v: %q column %q uniqueIndex should be %v but got %v", utils.FileWithLineNum(), stmt.Schema, fieldName, uniqueIndex, field.UniqueIndex)
}
}
type ( // not unique
UniqueStruct1 struct {
Name string `gorm:"size:10"`
}
UniqueStruct2 struct {
Name string `gorm:"size:20"`
}
)
checkField(&UniqueStruct1{}, "name", false, "")
checkField(&UniqueStruct2{}, "name", false, "")
type ( // unique
UniqueStruct3 struct {
Name string `gorm:"size:30;unique"`
}
UniqueStruct4 struct {
Name string `gorm:"size:40;unique"`
}
)
checkField(&UniqueStruct3{}, "name", true, "")
checkField(&UniqueStruct4{}, "name", true, "")
type ( // uniqueIndex
UniqueStruct5 struct {
Name string `gorm:"size:50;uniqueIndex"`
}
UniqueStruct6 struct {
Name string `gorm:"size:60;uniqueIndex"`
}
UniqueStruct7 struct {
Name string `gorm:"size:70;uniqueIndex:idx_us6_all_names"`
NickName string `gorm:"size:70;uniqueIndex:idx_us6_all_names"`
}
)
checkField(&UniqueStruct5{}, "name", false, "idx_unique_struct5_name")
checkField(&UniqueStruct6{}, "name", false, "idx_unique_struct6_name")
checkField(&UniqueStruct7{}, "name", false, "")
checkField(&UniqueStruct7{}, "nick_name", false, "")
checkField(&UniqueStruct7{}, "nick_name", false, "")
type UniqueStruct8 struct { // unique and uniqueIndex
Name string `gorm:"size:60;unique;index:my_us8_index,unique;"`
}
checkField(&UniqueStruct8{}, "name", true, "my_us8_index")
type TestCase struct {
name string
from, to interface{}
checkFunc func(t *testing.T)
}
checkColumnType := func(t *testing.T, fieldName string, unique bool) {
columnTypes, err := DB.Migrator().ColumnTypes(table)
if err != nil {
t.Fatalf("%v: failed to get column types, got error: %v", utils.FileWithLineNum(), err)
}
var found gorm.ColumnType
for _, columnType := range columnTypes {
if columnType.Name() == fieldName {
found = columnType
}
}
if found == nil {
t.Fatalf("%v: failed to find column type %q", utils.FileWithLineNum(), fieldName)
}
if actualUnique, ok := found.Unique(); !ok || actualUnique != unique {
t.Fatalf("%v: column %q unique should be %v but got %v", utils.FileWithLineNum(), fieldName, unique, actualUnique)
}
}
checkIndex := func(t *testing.T, expected []gorm.Index) {
indexes, err := DB.Migrator().GetIndexes(table)
if err != nil {
t.Fatalf("%v: failed to get indexes, got error: %v", utils.FileWithLineNum(), err)
}
assert.ElementsMatch(t, expected, indexes)
}
uniqueIndex := &migrator.Index{TableName: table, NameValue: DB.Config.NamingStrategy.IndexName(table, "name"), ColumnList: []string{"name"}, PrimaryKeyValue: sql.NullBool{Bool: false, Valid: true}, UniqueValue: sql.NullBool{Bool: true, Valid: true}}
myIndex := &migrator.Index{TableName: table, NameValue: "my_us8_index", ColumnList: []string{"name"}, PrimaryKeyValue: sql.NullBool{Bool: false, Valid: true}, UniqueValue: sql.NullBool{Bool: true, Valid: true}}
mulIndex := &migrator.Index{TableName: table, NameValue: "idx_us6_all_names", ColumnList: []string{"name", "nick_name"}, PrimaryKeyValue: sql.NullBool{Bool: false, Valid: true}, UniqueValue: sql.NullBool{Bool: true, Valid: true}}
var checkNotUnique, checkUnique, checkUniqueIndex, checkMyIndex, checkMulIndex func(t *testing.T)
// UniqueAffectedByUniqueIndex is true
if DB.Dialector.Name() == "mysql" {
uniqueConstraintIndex := &migrator.Index{TableName: table, NameValue: DB.Config.NamingStrategy.UniqueName(table, "name"), ColumnList: []string{"name"}, PrimaryKeyValue: sql.NullBool{Bool: false, Valid: true}, UniqueValue: sql.NullBool{Bool: true, Valid: true}}
checkNotUnique = func(t *testing.T) {
checkColumnType(t, "name", false)
checkIndex(t, nil)
}
checkUnique = func(t *testing.T) {
checkColumnType(t, "name", true)
checkIndex(t, []gorm.Index{uniqueConstraintIndex})
}
checkUniqueIndex = func(t *testing.T) {
checkColumnType(t, "name", true)
checkIndex(t, []gorm.Index{uniqueIndex})
}
checkMyIndex = func(t *testing.T) {
checkColumnType(t, "name", true)
checkIndex(t, []gorm.Index{uniqueConstraintIndex, myIndex})
}
checkMulIndex = func(t *testing.T) {
checkColumnType(t, "name", false)
checkColumnType(t, "nick_name", false)
checkIndex(t, []gorm.Index{mulIndex})
}
} else {
checkNotUnique = func(t *testing.T) { checkColumnType(t, "name", false) }
checkUnique = func(t *testing.T) { checkColumnType(t, "name", true) }
checkUniqueIndex = func(t *testing.T) {
checkColumnType(t, "name", false)
checkIndex(t, []gorm.Index{uniqueIndex})
}
checkMyIndex = func(t *testing.T) {
checkColumnType(t, "name", true)
if !DB.Migrator().HasIndex(table, myIndex.Name()) {
t.Errorf("%v: should has index %s but not", utils.FileWithLineNum(), myIndex.Name())
}
}
checkMulIndex = func(t *testing.T) {
checkColumnType(t, "name", false)
checkColumnType(t, "nick_name", false)
if !DB.Migrator().HasIndex(table, mulIndex.Name()) {
t.Errorf("%v: should has index %s but not", utils.FileWithLineNum(), mulIndex.Name())
}
}
}
tests := []TestCase{
{name: "notUnique to notUnique", from: &UniqueStruct1{}, to: &UniqueStruct2{}, checkFunc: checkNotUnique},
{name: "notUnique to unique", from: &UniqueStruct1{}, to: &UniqueStruct3{}, checkFunc: checkUnique},
{name: "notUnique to uniqueIndex", from: &UniqueStruct1{}, to: &UniqueStruct5{}, checkFunc: checkUniqueIndex},
{name: "notUnique to uniqueAndUniqueIndex", from: &UniqueStruct1{}, to: &UniqueStruct8{}, checkFunc: checkMyIndex},
{name: "unique to unique", from: &UniqueStruct3{}, to: &UniqueStruct4{}, checkFunc: checkUnique},
{name: "unique to uniqueIndex", from: &UniqueStruct3{}, to: &UniqueStruct5{}, checkFunc: checkUniqueIndex},
{name: "unique to uniqueAndUniqueIndex", from: &UniqueStruct3{}, to: &UniqueStruct8{}, checkFunc: checkMyIndex},
{name: "uniqueIndex to uniqueIndex", from: &UniqueStruct5{}, to: &UniqueStruct6{}, checkFunc: checkUniqueIndex},
{name: "uniqueIndex to uniqueAndUniqueIndex", from: &UniqueStruct5{}, to: &UniqueStruct8{}, checkFunc: checkMyIndex},
{name: "uniqueIndex to multi uniqueIndex", from: &UniqueStruct5{}, to: &UniqueStruct7{}, checkFunc: checkMulIndex},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
if err := DB.Migrator().DropTable(table); err != nil {
t.Fatalf("failed to drop table, got error: %v", err)
}
if err := DB.Table(table).AutoMigrate(test.from); err != nil {
t.Fatalf("failed to migrate table, got error: %v", err)
}
if err := DB.Table(table).AutoMigrate(test.to); err != nil {
t.Fatalf("failed to migrate table, got error: %v", err)
}
test.checkFunc(t)
})
}
if DB.Dialector.Name() != "sqlserver" {
// In SQLServer, If an index or constraint depends on the column,
// this column will not be able to run ALTER
// see https://stackoverflow.com/questions/19460912/the-object-df-is-dependent-on-column-changing-int-to-double/19461205#19461205
// may we need to create another PR to fix it, see https://github.com/go-gorm/sqlserver/pull/106
tests = []TestCase{
{name: "unique to notUnique", from: &UniqueStruct3{}, to: &UniqueStruct1{}, checkFunc: checkNotUnique},
{name: "uniqueIndex to notUnique", from: &UniqueStruct5{}, to: &UniqueStruct2{}, checkFunc: checkNotUnique},
{name: "uniqueIndex to unique", from: &UniqueStruct5{}, to: &UniqueStruct3{}, checkFunc: checkUnique},
}
}
if DB.Dialector.Name() == "mysql" {
compatibilityTests := []TestCase{
{name: "oldUnique to notUnique", to: UniqueStruct1{}, checkFunc: checkNotUnique},
{name: "oldUnique to unique", to: UniqueStruct3{}, checkFunc: checkUnique},
{name: "oldUnique to uniqueIndex", to: UniqueStruct5{}, checkFunc: checkUniqueIndex},
{name: "oldUnique to uniqueAndUniqueIndex", to: UniqueStruct8{}, checkFunc: checkMyIndex},
}
for _, test := range compatibilityTests {
t.Run(test.name, func(t *testing.T) {
if err := DB.Migrator().DropTable(table); err != nil {
t.Fatalf("failed to drop table, got error: %v", err)
}
if err := DB.Exec("CREATE TABLE ? (`name` varchar(10) UNIQUE)", clause.Table{Name: table}).Error; err != nil {
t.Fatalf("failed to create table, got error: %v", err)
}
if err := DB.Table(table).AutoMigrate(test.to); err != nil {
t.Fatalf("failed to migrate table, got error: %v", err)
}
test.checkFunc(t)
})
}
}
}
func testAutoMigrateDecimal(t *testing.T, model1, model2 any) []string {
tracer := Tracer{
Logger: DB.Config.Logger,
Test: func(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) {
sql, _ := fc()
if strings.HasPrefix(sql, "ALTER TABLE ") {
t.Fatalf("shouldn't execute ALTER COLUMN TYPE if decimal is not change: sql: %s", sql)
}
},
}
session := DB.Session(&gorm.Session{Logger: tracer})
DB.Migrator().DropTable(model1)
var modifySql []string
if err := session.AutoMigrate(model1); err != nil {
t.Fatalf("failed to auto migrate, got error: %v", err)
}
if err := session.AutoMigrate(model1); err != nil {
t.Fatalf("failed to auto migrate, got error: %v", err)
}
tracer2 := Tracer{
Logger: DB.Config.Logger,
Test: func(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) {
sql, _ := fc()
modifySql = append(modifySql, sql)
},
}
session2 := DB.Session(&gorm.Session{Logger: tracer2})
err := session2.Table("migrate_decimal_columns").Migrator().AutoMigrate(model2)
if err != nil {
t.Fatalf("failed to get column types, got error: %v", err)
}
return modifySql
}
func decimalColumnsTest[T, T2 any](t *testing.T, expectedSql []string) {
var t1 T
var t2 T2
modSql := testAutoMigrateDecimal(t, t1, t2)
var alterSQL []string
for _, sql := range modSql {
if strings.HasPrefix(sql, "ALTER TABLE ") {
alterSQL = append(alterSQL, sql)
}
}
if len(alterSQL) != 3 {
t.Fatalf("decimal changed error,expected: %+v,got: %+v.", expectedSql, alterSQL)
}
for i := range alterSQL {
if alterSQL[i] != expectedSql[i] {
t.Fatalf("decimal changed error,expected: %+v,got: %+v.", expectedSql, alterSQL)
}
}
}
func TestAutoMigrateDecimal(t *testing.T) {
if DB.Dialector.Name() == "sqlserver" { // database/sql will replace numeric to decimal. so only support decimal.
type MigrateDecimalColumn struct {
RecID1 int64 `gorm:"column:recid1;type:decimal(9,0);not null" json:"recid1"`
RecID2 int64 `gorm:"column:recid2;type:decimal(8);not null" json:"recid2"`
RecID3 int64 `gorm:"column:recid3;type:decimal(8,1);not null" json:"recid3"`
}
type MigrateDecimalColumn2 struct {
RecID1 int64 `gorm:"column:recid1;type:decimal(8);not null" json:"recid1"`
RecID2 int64 `gorm:"column:recid2;type:decimal(9,1);not null" json:"recid2"`
RecID3 int64 `gorm:"column:recid3;type:decimal(9,2);not null" json:"recid3"`
}
expectedSql := []string{
`ALTER TABLE "migrate_decimal_columns" ALTER COLUMN "recid1" decimal(8) NOT NULL`,
`ALTER TABLE "migrate_decimal_columns" ALTER COLUMN "recid2" decimal(9,1) NOT NULL`,
`ALTER TABLE "migrate_decimal_columns" ALTER COLUMN "recid3" decimal(9,2) NOT NULL`,
}
decimalColumnsTest[MigrateDecimalColumn, MigrateDecimalColumn2](t, expectedSql)
} else if DB.Dialector.Name() == "postgres" || DB.Dialector.Name() == "gaussdb" {
type MigrateDecimalColumn struct {
RecID1 int64 `gorm:"column:recid1;type:numeric(9,0);not null" json:"recid1"`
RecID2 int64 `gorm:"column:recid2;type:numeric(8);not null" json:"recid2"`
RecID3 int64 `gorm:"column:recid3;type:numeric(8,1);not null" json:"recid3"`
}
type MigrateDecimalColumn2 struct {
RecID1 int64 `gorm:"column:recid1;type:numeric(8);not null" json:"recid1"`
RecID2 int64 `gorm:"column:recid2;type:numeric(9,1);not null" json:"recid2"`
RecID3 int64 `gorm:"column:recid3;type:numeric(9,2);not null" json:"recid3"`
}
expectedSql := []string{
`ALTER TABLE "migrate_decimal_columns" ALTER COLUMN "recid1" TYPE numeric(8) USING "recid1"::numeric(8)`,
`ALTER TABLE "migrate_decimal_columns" ALTER COLUMN "recid2" TYPE numeric(9,1) USING "recid2"::numeric(9,1)`,
`ALTER TABLE "migrate_decimal_columns" ALTER COLUMN "recid3" TYPE numeric(9,2) USING "recid3"::numeric(9,2)`,
}
decimalColumnsTest[MigrateDecimalColumn, MigrateDecimalColumn2](t, expectedSql)
} else if DB.Dialector.Name() == "mysql" {
type MigrateDecimalColumn struct {
RecID1 int64 `gorm:"column:recid1;type:decimal(9,0);not null" json:"recid1"`
RecID2 int64 `gorm:"column:recid2;type:decimal(8);not null" json:"recid2"`
RecID3 int64 `gorm:"column:recid3;type:decimal(8,1);not null" json:"recid3"`
}
type MigrateDecimalColumn2 struct {
RecID1 int64 `gorm:"column:recid1;type:decimal(8);not null" json:"recid1"`
RecID2 int64 `gorm:"column:recid2;type:decimal(9,1);not null" json:"recid2"`
RecID3 int64 `gorm:"column:recid3;type:decimal(9,2);not null" json:"recid3"`
}
expectedSql := []string{
"ALTER TABLE `migrate_decimal_columns` MODIFY COLUMN `recid1` decimal(8) NOT NULL",
"ALTER TABLE `migrate_decimal_columns` MODIFY COLUMN `recid2` decimal(9,1) NOT NULL",
"ALTER TABLE `migrate_decimal_columns` MODIFY COLUMN `recid3` decimal(9,2) NOT NULL",
}
decimalColumnsTest[MigrateDecimalColumn, MigrateDecimalColumn2](t, expectedSql)
}
}

View File

@ -41,7 +41,7 @@ func TestManyToManyWithMultiPrimaryKeys(t *testing.T) {
t.Skip("skip sqlite, sqlserver due to it doesn't support multiple primary keys with auto increment")
}
if name := DB.Dialector.Name(); name == "postgres" || name == "mysql" || name == "gaussdb" {
if name := DB.Dialector.Name(); name == "postgres" {
stmt := gorm.Statement{DB: DB}
stmt.Parse(&Blog{})
stmt.Schema.LookUpField("ID").Unique = true
@ -142,9 +142,6 @@ func TestManyToManyWithCustomizedForeignKeys(t *testing.T) {
if name := DB.Dialector.Name(); name == "postgres" {
t.Skip("skip postgres due to it only allow unique constraint matching given keys")
}
if name := DB.Dialector.Name(); name == "gaussdb" {
t.Skip("skip gaussdb due to it only allow unique constraint matching given keys")
}
DB.Migrator().DropTable(&Blog{}, &Tag{}, "blog_tags", "locale_blog_tags", "shared_blog_tags")
if err := DB.AutoMigrate(&Blog{}, &Tag{}); err != nil {
@ -267,14 +264,10 @@ func TestManyToManyWithCustomizedForeignKeys2(t *testing.T) {
t.Skip("skip sqlite, sqlserver due to it doesn't support multiple primary keys with auto increment")
}
if name := DB.Dialector.Name(); name == "postgres" || name == "mysql" {
if name := DB.Dialector.Name(); name == "postgres" {
t.Skip("skip postgres due to it only allow unique constraint matching given keys")
}
if name := DB.Dialector.Name(); name == "gaussdb" {
t.Skip("skip gaussdb due to it only allow unique constraint matching given keys")
}
DB.Migrator().DropTable(&Blog{}, &Tag{}, "blog_tags", "locale_blog_tags", "shared_blog_tags")
if err := DB.AutoMigrate(&Blog{}, &Tag{}); err != nil {
t.Fatalf("Failed to auto migrate, got error: %v", err)
@ -339,7 +332,7 @@ func TestManyToManyWithCustomizedForeignKeys2(t *testing.T) {
DB.Model(&blog2).Association("LocaleTags").Find(&tags)
if !compareTags(tags, []string{"tag4"}) {
t.Fatalf("Should find 1 tags for EN Blog, but got %v", tags)
t.Fatalf("Should find 1 tags for EN Blog")
}
// Replace

View File

@ -37,7 +37,7 @@ func TestNonStdPrimaryKeyAndDefaultValues(t *testing.T) {
}
animal = Animal{From: "somewhere"} // No name fields, should be filled with the default value (galeone)
DB.Save(&animal).Update("From", "a nice place") // The name field should be untouched
DB.Save(&animal).Update("From", "a nice place") // The name field shoul be untouched
DB.First(&animal, animal.Counter)
if animal.Name != "galeone" {
t.Errorf("Name fields shouldn't be changed if untouched, but got %v", animal.Name)

View File

@ -696,10 +696,6 @@ func TestManyToManyPreloadWithMultiPrimaryKeys(t *testing.T) {
t.Skip("skip sqlite, sqlserver due to it doesn't support multiple primary keys with auto increment")
}
if name := DB.Dialector.Name(); name == "mysql" {
t.Skip("skip mysql due to it only allow unique constraint matching given keys")
}
type (
Level1 struct {
ID uint `gorm:"primary_key;"`

View File

@ -1,14 +1,12 @@
package tests_test
import (
"context"
"encoding/json"
"regexp"
"sort"
"strconv"
"sync"
"testing"
"time"
"gorm.io/gorm"
"gorm.io/gorm/clause"
@ -309,189 +307,6 @@ func TestNestedPreloadWithUnscoped(t *testing.T) {
CheckUserUnscoped(t, *user6, user)
}
func TestNestedPreloadWithNestedJoin(t *testing.T) {
type (
Preload struct {
ID uint
Value string
NestedID uint
}
Join struct {
ID uint
Value string
NestedID uint
}
Nested struct {
ID uint
Preloads []*Preload
Join Join
ValueID uint
}
Value struct {
ID uint
Name string
Nested Nested
}
)
DB.Migrator().DropTable(&Preload{}, &Join{}, &Nested{}, &Value{})
DB.Migrator().AutoMigrate(&Preload{}, &Join{}, &Nested{}, &Value{})
value1 := Value{
Name: "value",
Nested: Nested{
Preloads: []*Preload{
{Value: "p1"}, {Value: "p2"},
},
Join: Join{Value: "j1"},
},
}
value2 := Value{
Name: "value2",
Nested: Nested{
Preloads: []*Preload{
{Value: "p3"}, {Value: "p4"}, {Value: "p5"},
},
Join: Join{Value: "j2"},
},
}
values := []*Value{&value1, &value2}
if err := DB.Create(&values).Error; err != nil {
t.Errorf("failed to create value, got err: %v", err)
}
var find1 Value
err := DB.Joins("Nested").Joins("Nested.Join").Preload("Nested.Preloads").First(&find1, value1.ID).Error
if err != nil {
t.Errorf("failed to find value, got err: %v", err)
}
AssertEqual(t, find1, value1)
var find2 Value
// Joins will automatically add Nested queries.
err = DB.Joins("Nested.Join").Preload("Nested.Preloads").First(&find2, value2.ID).Error
if err != nil {
t.Errorf("failed to find value, got err: %v", err)
}
AssertEqual(t, find2, value2)
var finds []Value
err = DB.Joins("Nested.Join").Joins("Nested").Preload("Nested.Preloads").Find(&finds).Error
if err != nil {
t.Errorf("failed to find value, got err: %v", err)
}
AssertEqual(t, len(finds), 2)
AssertEqual(t, finds[0], value1)
AssertEqual(t, finds[1], value2)
}
func TestMergeNestedPreloadWithNestedJoin(t *testing.T) {
users := []User{
{
Name: "TestMergeNestedPreloadWithNestedJoin-1",
Manager: &User{
Name: "Alexis Manager",
Tools: []Tools{
{Name: "Alexis Tool 1"},
{Name: "Alexis Tool 2"},
},
},
},
{
Name: "TestMergeNestedPreloadWithNestedJoin-2",
Manager: &User{
Name: "Jinzhu Manager",
Tools: []Tools{
{Name: "Jinzhu Tool 1"},
{Name: "Jinzhu Tool 2"},
},
},
},
}
DB.Create(&users)
query := make([]string, 0)
sess := DB.Session(&gorm.Session{Logger: Tracer{
Logger: DB.Config.Logger,
Test: func(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) {
sql, _ := fc()
query = append(query, sql)
},
}})
var result []User
err := sess.
Joins("Manager").
Preload("Manager.Tools").
Where("users.name Like ?", "TestMergeNestedPreloadWithNestedJoin%").
Find(&result).Error
if err != nil {
t.Fatalf("failed to preload and find users: %v", err)
}
AssertEqual(t, result, users)
AssertEqual(t, len(query), 2) // Check preload queries are merged
if !regexp.MustCompile(`SELECT \* FROM .*tools.* WHERE .*IN.*`).MatchString(query[0]) {
t.Fatalf("Expected first query to preload manager tools, got: %s", query[0])
}
}
func TestNestedPreloadWithPointerJoin(t *testing.T) {
type (
Preload struct {
ID uint
Value string
JoinID uint
}
Join struct {
ID uint
Value string
Preload Preload
NestedID uint
}
Nested struct {
ID uint
Join Join
ValueID uint
}
Value struct {
ID uint
Name string
Nested *Nested
}
)
DB.Migrator().DropTable(&Preload{}, &Join{}, &Nested{}, &Value{})
DB.Migrator().AutoMigrate(&Preload{}, &Join{}, &Nested{}, &Value{})
value := Value{
Name: "value",
Nested: &Nested{
Join: Join{
Value: "j1",
Preload: Preload{
Value: "p1",
},
},
},
}
if err := DB.Create(&value).Error; err != nil {
t.Errorf("failed to create value, got err: %v", err)
}
var find1 Value
err := DB.Table("values").Joins("Nested").Joins("Nested.Join").Preload("Nested.Join.Preload").First(&find1).Error
if err != nil {
t.Errorf("failed to find value, got err: %v", err)
}
AssertEqual(t, find1, value)
}
func TestEmbedPreload(t *testing.T) {
type Country struct {
ID int `gorm:"primaryKey"`
@ -584,7 +399,7 @@ func TestEmbedPreload(t *testing.T) {
},
}, {
name: "nested address country",
preloads: map[string][]interface{}{"NestedAddress.Country": {}},
preloads: map[string][]interface{}{"NestedAddress.EmbeddedAddress.Country": {}},
expect: Org{
ID: org.ID,
PostalAddress: EmbeddedAddress{
@ -614,6 +429,7 @@ func TestEmbedPreload(t *testing.T) {
},
}
DB = DB.Debug()
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
actual := Org{}

View File

@ -91,65 +91,6 @@ func TestPreparedStmtFromTransaction(t *testing.T) {
tx2.Commit()
}
func TestPreparedStmtLruFromTransaction(t *testing.T) {
db, _ := OpenTestConnection(&gorm.Config{PrepareStmt: true, PrepareStmtMaxSize: 10, PrepareStmtTTL: 20 * time.Second})
tx := db.Begin()
defer func() {
if r := recover(); r != nil {
tx.Rollback()
}
}()
if err := tx.Error; err != nil {
t.Errorf("Failed to start transaction, got error %v\n", err)
}
if err := tx.Where("name=?", "zzjin").Delete(&User{}).Error; err != nil {
tx.Rollback()
t.Errorf("Failed to run one transaction, got error %v\n", err)
}
if err := tx.Create(&User{Name: "zzjin"}).Error; err != nil {
tx.Rollback()
t.Errorf("Failed to run one transaction, got error %v\n", err)
}
if err := tx.Commit().Error; err != nil {
t.Errorf("Failed to commit transaction, got error %v\n", err)
}
if result := db.Where("name=?", "zzjin").Delete(&User{}); result.Error != nil || result.RowsAffected != 1 {
t.Fatalf("Failed, got error: %v, rows affected: %v", result.Error, result.RowsAffected)
}
tx2 := db.Begin()
if result := tx2.Where("name=?", "zzjin").Delete(&User{}); result.Error != nil || result.RowsAffected != 0 {
t.Fatalf("Failed, got error: %v, rows affected: %v", result.Error, result.RowsAffected)
}
tx2.Commit()
// Attempt to convert the connection pool of tx to the *gorm.PreparedStmtDB type.
// If the conversion is successful, ok will be true and conn will be the converted object;
// otherwise, ok will be false and conn will be nil.
conn, ok := tx.ConnPool.(*gorm.PreparedStmtDB)
// Get the number of statement keys stored in the PreparedStmtDB.
lens := len(conn.Stmts.Keys())
// Check if the number of stored statement keys is 0.
if lens == 0 {
// If the number is 0, it means there are no statements stored in the LRU cache.
// The test fails and an error message is output.
t.Fatalf("lru should not be empty")
}
// Wait for 40 seconds to give the statements in the cache enough time to expire.
time.Sleep(time.Second * 40)
// Assert whether the connection pool of tx is successfully converted to the *gorm.PreparedStmtDB type.
AssertEqual(t, ok, true)
// Assert whether the number of statement keys stored in the PreparedStmtDB is 0 after 40 seconds.
// If it is not 0, it means the statements in the cache have not expired as expected.
AssertEqual(t, len(conn.Stmts.Keys()), 0)
}
func TestPreparedStmtDeadlock(t *testing.T) {
tx, err := OpenTestConnection(&gorm.Config{})
AssertEqual(t, err, nil)
@ -175,9 +116,9 @@ func TestPreparedStmtDeadlock(t *testing.T) {
conn, ok := tx.ConnPool.(*gorm.PreparedStmtDB)
AssertEqual(t, ok, true)
AssertEqual(t, len(conn.Stmts.Keys()), 2)
for _, stmt := range conn.Stmts.Keys() {
if stmt == "" {
AssertEqual(t, len(conn.Stmts), 2)
for _, stmt := range conn.Stmts {
if stmt == nil {
t.Fatalf("stmt cannot bee nil")
}
}
@ -185,6 +126,33 @@ func TestPreparedStmtDeadlock(t *testing.T) {
AssertEqual(t, sqlDB.Stats().InUse, 0)
}
func TestPreparedStmtError(t *testing.T) {
tx, err := OpenTestConnection(&gorm.Config{})
AssertEqual(t, err, nil)
sqlDB, _ := tx.DB()
sqlDB.SetMaxOpenConns(1)
tx = tx.Session(&gorm.Session{PrepareStmt: true})
wg := sync.WaitGroup{}
for i := 0; i < 10; i++ {
wg.Add(1)
go func() {
// err prepare
tag := Tag{Locale: "zh"}
tx.Table("users").Find(&tag)
wg.Done()
}()
}
wg.Wait()
conn, ok := tx.ConnPool.(*gorm.PreparedStmtDB)
AssertEqual(t, ok, true)
AssertEqual(t, len(conn.Stmts), 0)
AssertEqual(t, sqlDB.Stats().InUse, 0)
}
func TestPreparedStmtInTransaction(t *testing.T) {
user := User{Name: "jinzhu"}
@ -201,10 +169,10 @@ func TestPreparedStmtInTransaction(t *testing.T) {
}
}
func TestPreparedStmtClose(t *testing.T) {
func TestPreparedStmtReset(t *testing.T) {
tx := DB.Session(&gorm.Session{PrepareStmt: true})
user := *GetUser("prepared_stmt_close", Config{})
user := *GetUser("prepared_stmt_reset", Config{})
tx = tx.Create(&user)
pdb, ok := tx.ConnPool.(*gorm.PreparedStmtDB)
@ -213,77 +181,16 @@ func TestPreparedStmtClose(t *testing.T) {
}
pdb.Mux.Lock()
if len(pdb.Stmts.Keys()) == 0 {
if len(pdb.Stmts) == 0 {
pdb.Mux.Unlock()
t.Fatalf("prepared stmt can not be empty")
}
pdb.Mux.Unlock()
pdb.Close()
pdb.Reset()
pdb.Mux.Lock()
defer pdb.Mux.Unlock()
if len(pdb.Stmts.Keys()) != 0 {
if len(pdb.Stmts) != 0 {
t.Fatalf("prepared stmt should be empty")
}
}
func isUsingClosedConnError(err error) bool {
// https://github.com/golang/go/blob/e705a2d16e4ece77e08e80c168382cdb02890f5b/src/database/sql/sql.go#L2717
return err.Error() == "sql: statement is closed"
}
// TestPreparedStmtConcurrentClose test calling close and executing SQL concurrently
// this test making sure that the gorm would not get a Segmentation Fault, and the only error cause by this is using a closed Stmt
func TestPreparedStmtConcurrentClose(t *testing.T) {
name := "prepared_stmt_concurrent_close"
user := *GetUser(name, Config{})
createTx := DB.Session(&gorm.Session{}).Create(&user)
if createTx.Error != nil {
t.Fatalf("failed to prepare record due to %s, test cannot be continue", createTx.Error)
}
// create a new connection to keep away from other tests
tx, err := OpenTestConnection(&gorm.Config{PrepareStmt: true})
if err != nil {
t.Fatalf("failed to open test connection due to %s", err)
}
pdb, ok := tx.ConnPool.(*gorm.PreparedStmtDB)
if !ok {
t.Fatalf("should assign PreparedStatement Manager back to database when using PrepareStmt mode")
}
loopCount := 100
var wg sync.WaitGroup
var unexpectedError bool
writerFinish := make(chan struct{})
wg.Add(1)
go func(id uint) {
defer wg.Done()
defer close(writerFinish)
for j := 0; j < loopCount; j++ {
var tmp User
err := tx.Session(&gorm.Session{}).First(&tmp, id).Error
if err == nil || isUsingClosedConnError(err) {
continue
}
t.Errorf("failed to read user of id %d due to %s, there should not be error", id, err)
unexpectedError = true
break
}
}(user.ID)
wg.Add(1)
go func() {
defer wg.Done()
<-writerFinish
pdb.Close()
}()
wg.Wait()
if unexpectedError {
t.Fatalf("should is a unexpected error")
}
}

View File

@ -554,16 +554,6 @@ func TestNot(t *testing.T) {
if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*users.*..*name.* <> .+ AND .*users.*..*age.* <> .+").MatchString(result.Statement.SQL.String()) {
t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String())
}
result = dryDB.Not(DB.Where("manager IS NULL").Where("age >= ?", 20)).Find(&User{})
if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE NOT \\(manager IS NULL AND age >= .+\\) AND .users.\\..deleted_at. IS NULL").MatchString(result.Statement.SQL.String()) {
t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String())
}
result = dryDB.Not(DB.Where("manager IS NULL").Or("age >= ?", 20)).Find(&User{})
if !regexp.MustCompile(`SELECT \* FROM .*users.* WHERE NOT \(manager IS NULL OR age >= .+\) AND .users.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) {
t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String())
}
}
func TestNotWithAllFields(t *testing.T) {
@ -632,21 +622,6 @@ func TestOr(t *testing.T) {
t.Fatalf("Build OR condition, but got %v", result.Statement.SQL.String())
}
sub := dryDB.Clauses(clause.Where{
Exprs: []clause.Expression{
clause.OrConditions{
Exprs: []clause.Expression{
clause.Expr{SQL: "role = ?", Vars: []interface{}{"super_admin"}},
clause.Expr{SQL: "role = ?", Vars: []interface{}{"admin"}},
},
},
},
})
result = dryDB.Where(sub).Find(&User{})
if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*role.* = .+ OR .*role.* = .+").MatchString(result.Statement.SQL.String()) {
t.Fatalf("Build OR condition, but got %v", result.Statement.SQL.String())
}
result = dryDB.Where("role = ?", "admin").Or("role = ?", "super_admin").Find(&User{})
if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*role.* = .+ OR .*role.* = .+").MatchString(result.Statement.SQL.String()) {
t.Fatalf("Build OR condition, but got %v", result.Statement.SQL.String())
@ -875,28 +850,6 @@ func TestOmitWithAllFields(t *testing.T) {
}
}
func TestMapColumns(t *testing.T) {
user := User{Name: "MapColumnsUser", Age: 12}
DB.Save(&user)
type result struct {
Name string
Nickname string
Age uint
}
var res result
DB.Table("users").Where("name = ?", user.Name).MapColumns(map[string]string{"name": "nickname"}).Scan(&res)
if res.Nickname != user.Name {
t.Errorf("Expected res.Nickname to be %s, but got %s", user.Name, res.Nickname)
}
if res.Name != "" {
t.Errorf("Expected res.Name to be empty, but got %s", res.Name)
}
if res.Age != user.Age {
t.Errorf("Expected res.Age to be %d, but got %d", user.Age, res.Age)
}
}
func TestPluckWithSelect(t *testing.T) {
users := []User{
{Name: "pluck_with_select_1", Age: 25},
@ -1127,10 +1080,6 @@ func TestSearchWithMap(t *testing.T) {
DB.First(&user, map[string]interface{}{"name": users[0].Name})
CheckUser(t, user, users[0])
user = User{}
DB.First(&user, map[string]interface{}{"users.name": users[0].Name})
CheckUser(t, user, users[0])
user = User{}
DB.Where(map[string]interface{}{"name": users[1].Name}).First(&user)
CheckUser(t, user, users[1])
@ -1169,12 +1118,12 @@ func TestSearchWithStruct(t *testing.T) {
}
result = dryRunDB.Where(User{Name: "jinzhu", Age: 18}).Find(&User{})
if !regexp.MustCompile(`WHERE \(.users.\..name. = .{1,3} AND .users.\..age. = .{1,3}\) AND .users.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) {
if !regexp.MustCompile(`WHERE .users.\..name. = .{1,3} AND .users.\..age. = .{1,3} AND .users.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) {
t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String())
}
result = dryRunDB.Where(User{Name: "jinzhu"}, "name", "Age").Find(&User{})
if !regexp.MustCompile(`WHERE \(.users.\..name. = .{1,3} AND .users.\..age. = .{1,3}\) AND .users.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) {
if !regexp.MustCompile(`WHERE .users.\..name. = .{1,3} AND .users.\..age. = .{1,3} AND .users.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) {
t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String())
}
@ -1235,6 +1184,7 @@ func TestSubQueryWithRaw(t *testing.T) {
Where("age >= ? and name in (?)", 20, []string{"subquery_raw_1", "subquery_raw_3"}).
Group("name"),
).Count(&count).Error
if err != nil {
t.Errorf("Expected to get no errors, but got %v", err)
}
@ -1250,6 +1200,7 @@ func TestSubQueryWithRaw(t *testing.T) {
Not("age <= ?", 10).Not("name IN (?)", []string{"subquery_raw_1", "subquery_raw_3"}).
Group("name"),
).Count(&count).Error
if err != nil {
t.Errorf("Expected to get no errors, but got %v", err)
}
@ -1376,7 +1327,7 @@ func TestQueryResetNullValue(t *testing.T) {
Number1 int64 `gorm:"default:NULL"`
Number2 uint64 `gorm:"default:NULL"`
Number3 float64 `gorm:"default:NULL"`
Now *time.Time `gorm:"default:NULL"`
Now *time.Time `gorm:"defalut:NULL"`
Item1Id string
Item1 *QueryResetItem `gorm:"references:ID"`
Item2Id string
@ -1453,22 +1404,3 @@ func TestQueryError(t *testing.T) {
}, Value: 1}).Scan(&p2).Error
AssertEqual(t, err, gorm.ErrModelValueRequired)
}
func TestQueryScanToArray(t *testing.T) {
err := DB.Create(&User{Name: "testname1", Age: 10}).Error
if err != nil {
t.Fatal(err)
}
users := [2]*User{{Name: "1"}, {Name: "2"}}
err = DB.Model(&User{}).Where("name = ?", "testname1").Find(&users).Error
if err != nil {
t.Fatal(err)
}
if users[0] == nil || users[0].Name != "testname1" {
t.Error("users[0] not covered")
}
if users[1] != nil {
t.Error("users[1] should be empty")
}
}

View File

@ -5,7 +5,6 @@ import (
"sort"
"strings"
"testing"
"time"
"gorm.io/gorm"
. "gorm.io/gorm/utils/tests"
@ -127,7 +126,7 @@ func TestScanRows(t *testing.T) {
rows, err := DB.Table("users").Where("name = ? or name = ?", user2.Name, user3.Name).Select("name, age").Rows()
if err != nil {
t.Errorf("No error should happen, got %v", err)
t.Errorf("Not error should happen, got %v", err)
}
type Result struct {
@ -149,7 +148,7 @@ func TestScanRows(t *testing.T) {
})
if !reflect.DeepEqual(results, []Result{{Name: "ScanRowsUser2", Age: 10}, {Name: "ScanRowsUser3", Age: 20}}) {
t.Errorf("Should find expected results, got %+v", results)
t.Errorf("Should find expected results")
}
var ages int
@ -159,105 +158,7 @@ func TestScanRows(t *testing.T) {
var name string
if err := DB.Table("users").Where("name = ?", user2.Name).Select("name").Scan(&name).Error; err != nil || name != user2.Name {
t.Fatalf("failed to scan name, got error %v, name: %v", err, name)
}
}
func TestScanRowsNullValuesScanToFieldDefault(t *testing.T) {
DB.Save(&User{})
rows, err := DB.Table("users").
Select(`
NULL AS bool_field,
NULL AS int_field,
NULL AS int8_field,
NULL AS int16_field,
NULL AS int32_field,
NULL AS int64_field,
NULL AS uint_field,
NULL AS uint8_field,
NULL AS uint16_field,
NULL AS uint32_field,
NULL AS uint64_field,
NULL AS float32_field,
NULL AS float64_field,
NULL AS string_field,
NULL AS time_field,
NULL AS time_ptr_field,
NULL AS embedded_int_field,
NULL AS nested_embedded_int_field,
NULL AS embedded_ptr_int_field
`).Rows()
if err != nil {
t.Errorf("No error should happen, got %v", err)
}
type NestedEmbeddedStruct struct {
NestedEmbeddedIntField int
NestedEmbeddedIntFieldWithDefault int `gorm:"default:2"`
}
type EmbeddedStruct struct {
EmbeddedIntField int
NestedEmbeddedStruct `gorm:"embedded"`
}
type EmbeddedPtrStruct struct {
EmbeddedPtrIntField int
*NestedEmbeddedStruct `gorm:"embedded"`
}
type Result struct {
BoolField bool
IntField int
Int8Field int8
Int16Field int16
Int32Field int32
Int64Field int64
UIntField uint
UInt8Field uint8
UInt16Field uint16
UInt32Field uint32
UInt64Field uint64
Float32Field float32
Float64Field float64
StringField string
TimeField time.Time
TimePtrField *time.Time
EmbeddedStruct `gorm:"embedded"`
*EmbeddedPtrStruct `gorm:"embedded"`
}
currTime := time.Now()
reusedVar := Result{
BoolField: true,
IntField: 1,
Int8Field: 1,
Int16Field: 1,
Int32Field: 1,
Int64Field: 1,
UIntField: 1,
UInt8Field: 1,
UInt16Field: 1,
UInt32Field: 1,
UInt64Field: 1,
Float32Field: 1.1,
Float64Field: 1.1,
StringField: "hello",
TimeField: currTime,
TimePtrField: &currTime,
EmbeddedStruct: EmbeddedStruct{EmbeddedIntField: 1, NestedEmbeddedStruct: NestedEmbeddedStruct{NestedEmbeddedIntField: 1, NestedEmbeddedIntFieldWithDefault: 2}},
EmbeddedPtrStruct: &EmbeddedPtrStruct{EmbeddedPtrIntField: 1, NestedEmbeddedStruct: &NestedEmbeddedStruct{NestedEmbeddedIntField: 1, NestedEmbeddedIntFieldWithDefault: 2}},
}
for rows.Next() {
if err := DB.ScanRows(rows, &reusedVar); err != nil {
t.Errorf("should get no error, but got %v", err)
}
}
if !reflect.DeepEqual(reusedVar, Result{}) {
t.Errorf("Should find zero values in struct fields, got %+v\n", reusedVar)
t.Fatalf("failed to scan ages, got error %v, ages: %v", err, name)
}
}

View File

@ -84,9 +84,7 @@ func TestComplexScopes(t *testing.T) {
queryFn: func(tx *gorm.DB) *gorm.DB {
return tx.Scopes(
func(d *gorm.DB) *gorm.DB { return d.Where("a = 1") },
func(d *gorm.DB) *gorm.DB {
return d.Where(DB.Or("b = 2").Or("c = 3"))
},
func(d *gorm.DB) *gorm.DB { return d.Where(d.Or("b = 2").Or("c = 3")) },
).Find(&Language{})
},
expected: `SELECT * FROM "languages" WHERE a = 1 AND (b = 2 OR c = 3)`,
@ -95,9 +93,7 @@ func TestComplexScopes(t *testing.T) {
queryFn: func(tx *gorm.DB) *gorm.DB {
return tx.Where("z = 0").Scopes(
func(d *gorm.DB) *gorm.DB { return d.Where("a = 1") },
func(d *gorm.DB) *gorm.DB {
return d.Or(DB.Where("b = 2").Or("c = 3"))
},
func(d *gorm.DB) *gorm.DB { return d.Or(d.Where("b = 2").Or("c = 3")) },
).Find(&Language{})
},
expected: `SELECT * FROM "languages" WHERE z = 0 AND a = 1 OR (b = 2 OR c = 3)`,
@ -108,7 +104,7 @@ func TestComplexScopes(t *testing.T) {
func(d *gorm.DB) *gorm.DB { return d.Model(&Language{}) },
func(d *gorm.DB) *gorm.DB {
return d.
Or(DB.Scopes(
Or(d.Scopes(
func(d *gorm.DB) *gorm.DB { return d.Where("a = 1") },
func(d *gorm.DB) *gorm.DB { return d.Where("b = 2") },
)).

View File

@ -45,7 +45,7 @@ type SerializerPostgresStruct struct {
func (*SerializerPostgresStruct) TableName() string { return "serializer_structs" }
func adaptorSerializerModel(s *SerializerStruct) interface{} {
if DB.Dialector.Name() == "postgres" || DB.Dialector.Name() == "gaussdb" {
if DB.Dialector.Name() == "postgres" {
sps := SerializerPostgresStruct(*s)
return &sps
}

View File

@ -388,7 +388,7 @@ func TestToSQL(t *testing.T) {
sql = DB.ToSQL(func(tx *gorm.DB) *gorm.DB {
return tx.Model(&User{}).Where(&User{Name: "foo", Age: 20}).Limit(10).Offset(5).Order("name ASC").First(&User{})
})
assertEqualSQL(t, `SELECT * FROM "users" WHERE ("users"."name" = 'foo' AND "users"."age" = 20) AND "users"."deleted_at" IS NULL ORDER BY name ASC,"users"."id" LIMIT 1 OFFSET 5`, sql)
assertEqualSQL(t, `SELECT * FROM "users" WHERE "users"."name" = 'foo' AND "users"."age" = 20 AND "users"."deleted_at" IS NULL ORDER BY name ASC,"users"."id" LIMIT 1 OFFSET 5`, sql)
// last and unscoped
sql = DB.ToSQL(func(tx *gorm.DB) *gorm.DB {
@ -487,7 +487,7 @@ func replaceQuoteInSQL(sql string) string {
// convert dialect special quote into double quote
switch DB.Dialector.Name() {
case "postgres", "gaussdb":
case "postgres":
sql = strings.ReplaceAll(sql, `"`, `"`)
case "mysql", "sqlite":
sql = strings.ReplaceAll(sql, "`", `"`)

View File

@ -2,11 +2,8 @@ package tests_test
import (
"regexp"
"sync"
"testing"
"gorm.io/driver/gaussdb"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils/tests"
@ -175,164 +172,3 @@ func TestTableWithNamer(t *testing.T) {
t.Errorf("Table with namer, got %v", sql)
}
}
func TestPostgresTableWithIdentifierLength(t *testing.T) {
if DB.Dialector.Name() != "postgres" {
return
}
type LongString struct {
ThisIsAVeryVeryVeryVeryVeryVeryVeryVeryVeryLongString string `gorm:"unique"`
}
t.Run("default", func(t *testing.T) {
db, _ := gorm.Open(postgres.Open(postgresDSN), &gorm.Config{})
user, err := schema.Parse(&LongString{}, &sync.Map{}, db.Config.NamingStrategy)
if err != nil {
t.Fatalf("failed to parse user unique, got error %v", err)
}
constraints := user.ParseUniqueConstraints()
if len(constraints) != 1 {
t.Fatalf("failed to find unique constraint, got %v", constraints)
}
for key := range constraints {
if len(key) != 63 {
t.Errorf("failed to find unique constraint, got %v", constraints)
}
}
})
t.Run("naming strategy", func(t *testing.T) {
db, _ := gorm.Open(postgres.Open(postgresDSN), &gorm.Config{
NamingStrategy: schema.NamingStrategy{},
})
user, err := schema.Parse(&LongString{}, &sync.Map{}, db.Config.NamingStrategy)
if err != nil {
t.Fatalf("failed to parse user unique, got error %v", err)
}
constraints := user.ParseUniqueConstraints()
if len(constraints) != 1 {
t.Fatalf("failed to find unique constraint, got %v", constraints)
}
for key := range constraints {
if len(key) != 63 {
t.Errorf("failed to find unique constraint, got %v", constraints)
}
}
})
t.Run("namer", func(t *testing.T) {
uname := "custom_unique_name"
db, _ := gorm.Open(postgres.Open(postgresDSN), &gorm.Config{
NamingStrategy: mockUniqueNamingStrategy{
UName: uname,
},
})
user, err := schema.Parse(&LongString{}, &sync.Map{}, db.Config.NamingStrategy)
if err != nil {
t.Fatalf("failed to parse user unique, got error %v", err)
}
constraints := user.ParseUniqueConstraints()
if len(constraints) != 1 {
t.Fatalf("failed to find unique constraint, got %v", constraints)
}
for key := range constraints {
if key != uname {
t.Errorf("failed to find unique constraint, got %v", constraints)
}
}
})
}
func TestGaussDBTableWithIdentifierLength(t *testing.T) {
if DB.Dialector.Name() != "gaussdb" {
return
}
type LongString struct {
ThisIsAVeryVeryVeryVeryVeryVeryVeryVeryVeryLongString string `gorm:"unique"`
}
t.Run("default", func(t *testing.T) {
db, _ := gorm.Open(gaussdb.Open(gaussdbDSN), &gorm.Config{})
user, err := schema.Parse(&LongString{}, &sync.Map{}, db.Config.NamingStrategy)
if err != nil {
t.Fatalf("failed to parse user unique, got error %v", err)
}
constraints := user.ParseUniqueConstraints()
if len(constraints) != 1 {
t.Fatalf("failed to find unique constraint, got %v", constraints)
}
for key := range constraints {
if len(key) != 63 {
t.Errorf("failed to find unique constraint, got %v", constraints)
}
}
})
t.Run("naming strategy", func(t *testing.T) {
db, _ := gorm.Open(gaussdb.Open(gaussdbDSN), &gorm.Config{
NamingStrategy: schema.NamingStrategy{},
})
user, err := schema.Parse(&LongString{}, &sync.Map{}, db.Config.NamingStrategy)
if err != nil {
t.Fatalf("failed to parse user unique, got error %v", err)
}
constraints := user.ParseUniqueConstraints()
if len(constraints) != 1 {
t.Fatalf("failed to find unique constraint, got %v", constraints)
}
for key := range constraints {
if len(key) != 63 {
t.Errorf("failed to find unique constraint, got %v", constraints)
}
}
})
t.Run("namer", func(t *testing.T) {
uname := "custom_unique_name"
db, _ := gorm.Open(gaussdb.Open(gaussdbDSN), &gorm.Config{
NamingStrategy: mockUniqueNamingStrategy{
UName: uname,
},
})
user, err := schema.Parse(&LongString{}, &sync.Map{}, db.Config.NamingStrategy)
if err != nil {
t.Fatalf("failed to parse user unique, got error %v", err)
}
constraints := user.ParseUniqueConstraints()
if len(constraints) != 1 {
t.Fatalf("failed to find unique constraint, got %v", constraints)
}
for key := range constraints {
if key != uname {
t.Errorf("failed to find unique constraint, got %v", constraints)
}
}
})
}
type mockUniqueNamingStrategy struct {
UName string
schema.NamingStrategy
}
func (a mockUniqueNamingStrategy) UniqueName(table, column string) string {
return a.UName
}

View File

@ -1,6 +1,6 @@
#!/bin/bash -e
dialects=("sqlite" "mysql" "postgres" "gaussdb" "sqlserver" "tidb")
dialects=("sqlite" "mysql" "postgres" "sqlserver" "tidb")
if [[ $(pwd) == *"gorm/tests"* ]]; then
cd ..
@ -16,22 +16,21 @@ then
fi
# SqlServer for Mac M1
if [[ -z $GITHUB_ACTION && -d tests ]]; then
if [[ -z $GITHUB_ACTION ]]; then
if [ -d tests ]
then
cd tests
if [[ $(uname -a) == *" arm64" ]]; then
MSSQL_IMAGE=mcr.microsoft.com/azure-sql-edge docker compose up -d --wait
MSSQL_IMAGE=mcr.microsoft.com/azure-sql-edge docker-compose start || true
go install github.com/microsoft/go-sqlcmd/cmd/sqlcmd@latest || true
for query in \
"IF DB_ID('gorm') IS NULL CREATE DATABASE gorm" \
"IF SUSER_ID (N'gorm') IS NULL CREATE LOGIN gorm WITH PASSWORD = 'LoremIpsum86';" \
"IF USER_ID (N'gorm') IS NULL CREATE USER gorm FROM LOGIN gorm; ALTER SERVER ROLE sysadmin ADD MEMBER [gorm];"
do
SQLCMDPASSWORD=LoremIpsum86 sqlcmd -U sa -S localhost:9930 -Q "$query" > /dev/null || true
done
SQLCMDPASSWORD=LoremIpsum86 sqlcmd -U sa -S localhost:9930 -Q "IF DB_ID('gorm') IS NULL CREATE DATABASE gorm" > /dev/null || true
SQLCMDPASSWORD=LoremIpsum86 sqlcmd -U sa -S localhost:9930 -Q "IF SUSER_ID (N'gorm') IS NULL CREATE LOGIN gorm WITH PASSWORD = 'LoremIpsum86';" > /dev/null || true
SQLCMDPASSWORD=LoremIpsum86 sqlcmd -U sa -S localhost:9930 -Q "IF USER_ID (N'gorm') IS NULL CREATE USER gorm FROM LOGIN gorm; ALTER SERVER ROLE sysadmin ADD MEMBER [gorm];" > /dev/null || true
else
MSSQL_IMAGE=mcr.microsoft.com/mssql/server docker compose up -d --wait
docker-compose start
fi
cd ..
fi
fi

View File

@ -1,4 +1,3 @@
//go:debug x509negativeserial=1
package tests_test
import (
@ -8,7 +7,6 @@ import (
"path/filepath"
"time"
"gorm.io/driver/gaussdb"
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
"gorm.io/driver/sqlite"
@ -22,8 +20,7 @@ var DB *gorm.DB
var (
mysqlDSN = "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True&loc=Local"
postgresDSN = "user=gorm password=gorm dbname=gorm host=localhost port=9920 sslmode=disable TimeZone=Asia/Shanghai"
gaussdbDSN = "user=gaussdb password=Gaussdb@123 dbname=gorm host=localhost port=9950 sslmode=disable TimeZone=Asia/Shanghai"
sqlserverDSN = "sqlserver://sa:LoremIpsum86@localhost:9930?database=master"
sqlserverDSN = "sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm"
tidbDSN = "root:@tcp(localhost:9940)/test?charset=utf8&parseTime=True&loc=Local"
)
@ -46,6 +43,9 @@ func init() {
}
RunMigrations()
if DB.Dialector.Name() == "sqlite" {
DB.Exec("PRAGMA foreign_keys = ON")
}
}
}
@ -67,15 +67,6 @@ func OpenTestConnection(cfg *gorm.Config) (db *gorm.DB, err error) {
DSN: dbDSN,
PreferSimpleProtocol: true,
}), cfg)
case "gaussdb":
log.Println("testing gaussdb...")
if dbDSN == "" {
dbDSN = gaussdbDSN
}
db, err = gorm.Open(gaussdb.New(gaussdb.Config{
DSN: dbDSN,
PreferSimpleProtocol: true,
}), cfg)
case "sqlserver":
// go install github.com/microsoft/go-sqlcmd/cmd/sqlcmd@latest
// SQLCMDPASSWORD=LoremIpsum86 sqlcmd -U sa -S localhost:9930
@ -98,10 +89,7 @@ func OpenTestConnection(cfg *gorm.Config) (db *gorm.DB, err error) {
db, err = gorm.Open(mysql.Open(dbDSN), cfg)
default:
log.Println("testing sqlite3...")
db, err = gorm.Open(sqlite.Open(filepath.Join(os.TempDir(), "gorm.db")), cfg)
if err == nil {
db.Exec("PRAGMA foreign_keys = ON")
}
db, err = gorm.Open(sqlite.Open(filepath.Join(os.TempDir(), "gorm.db?_foreign_keys=on")), cfg)
}
if err != nil {
@ -119,7 +107,7 @@ func OpenTestConnection(cfg *gorm.Config) (db *gorm.DB, err error) {
func RunMigrations() {
var err error
allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Coupon{}, &CouponProduct{}, &Order{}, &Parent{}, &Child{}, &Tools{}}
allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Coupon{}, &CouponProduct{}, &Order{}, &Parent{}, &Child{}}
rand.Seed(time.Now().UnixNano())
rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] })

View File

@ -4,7 +4,6 @@ import (
"context"
"errors"
"testing"
"time"
"gorm.io/gorm"
. "gorm.io/gorm/utils/tests"
@ -68,7 +67,7 @@ func TestTransaction(t *testing.T) {
return tx5.First(&User{}, "name = ?", "transaction-2").Error
})
}); err != nil {
t.Fatalf("prepare statement and nested transaction coexist" + err.Error())
t.Fatalf("prepare statement and nested transcation coexist" + err.Error())
}
})
}
@ -298,74 +297,6 @@ func TestNestedTransactionWithBlock(t *testing.T) {
}
}
func TestDeeplyNestedTransactionWithBlockAndWrappedCallback(t *testing.T) {
transaction := func(ctx context.Context, db *gorm.DB, callback func(ctx context.Context, db *gorm.DB) error) error {
return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
return callback(ctx, tx)
})
}
var (
user = *GetUser("transaction-nested", Config{})
user1 = *GetUser("transaction-nested-1", Config{})
user2 = *GetUser("transaction-nested-2", Config{})
)
if err := transaction(context.Background(), DB, func(ctx context.Context, tx *gorm.DB) error {
tx.Create(&user)
if err := tx.First(&User{}, "name = ?", user.Name).Error; err != nil {
t.Fatalf("Should find saved record")
}
if err := transaction(ctx, tx, func(ctx context.Context, tx1 *gorm.DB) error {
tx1.Create(&user1)
if err := tx1.First(&User{}, "name = ?", user1.Name).Error; err != nil {
t.Fatalf("Should find saved record")
}
if err := transaction(ctx, tx1, func(ctx context.Context, tx2 *gorm.DB) error {
tx2.Create(&user2)
if err := tx2.First(&User{}, "name = ?", user2.Name).Error; err != nil {
t.Fatalf("Should find saved record")
}
return errors.New("inner rollback")
}); err == nil {
t.Fatalf("nested transaction has no error")
}
return errors.New("rollback")
}); err == nil {
t.Fatalf("nested transaction should returns error")
}
if err := tx.First(&User{}, "name = ?", user1.Name).Error; err == nil {
t.Fatalf("Should not find rollbacked record")
}
if err := tx.First(&User{}, "name = ?", user2.Name).Error; err != nil {
t.Fatalf("Should find saved record")
}
return nil
}); err != nil {
t.Fatalf("no error should return, but got %v", err)
}
if err := DB.First(&User{}, "name = ?", user.Name).Error; err != nil {
t.Fatalf("Should find saved record")
}
if err := DB.First(&User{}, "name = ?", user1.Name).Error; err == nil {
t.Fatalf("Should not find rollbacked parent record")
}
if err := DB.First(&User{}, "name = ?", user2.Name).Error; err != nil {
t.Fatalf("Should not find rollbacked nested record")
}
}
func TestDisabledNestedTransaction(t *testing.T) {
var (
user = *GetUser("transaction-nested", Config{})
@ -460,6 +391,7 @@ func TestTransactionWithHooks(t *testing.T) {
return tx2.Scan(&User{}).Error
})
})
if err != nil {
t.Error(err)
}
@ -473,20 +405,8 @@ func TestTransactionWithHooks(t *testing.T) {
return tx3.Where("user_id", user.ID).Delete(&Account{}).Error
})
})
if err != nil {
t.Error(err)
}
}
func TestTransactionWithDefaultTimeout(t *testing.T) {
db, err := OpenTestConnection(&gorm.Config{DefaultTransactionTimeout: 2 * time.Second})
if err != nil {
t.Fatalf("failed to connect database, got error %v", err)
}
tx := db.Begin()
time.Sleep(3 * time.Second)
if err = tx.Find(&User{}).Error; err == nil {
t.Errorf("should return error when transaction timeout, got error %v", err)
}
}

Some files were not shown because too many files have changed in this diff Show More