Compare commits

...

26 Commits

Author SHA1 Message Date
贾一饼
49eaeacb89
optimize: field.ReflectValueOf (#7530)
Some checks failed
tests / mysql (mysql:5.7, 1.24, ubuntu-latest) (push) Failing after 59s
tests / mysql (mysql:8, 1.23, ubuntu-latest) (push) Failing after 59s
tests / mysql (mysql:8, 1.24, ubuntu-latest) (push) Failing after 5s
tests / sqlite (1.23, ubuntu-latest) (push) Failing after 1m26s
tests / mysql (mysql:9, 1.23, ubuntu-latest) (push) Failing after 36s
tests / mysql (mysql:9, 1.24, ubuntu-latest) (push) Failing after 32s
tests / mariadb (mariadb:latest, 1.24, ubuntu-latest) (push) Failing after 9s
tests / mariadb (mariadb:latest, 1.23, ubuntu-latest) (push) Failing after 22s
tests / postgres (postgres:13, 1.23, ubuntu-latest) (push) Failing after 24s
tests / mysql (mysql:5.7, 1.23, ubuntu-latest) (push) Failing after 2m1s
tests / postgres (postgres:14, 1.23, ubuntu-latest) (push) Failing after 20s
tests / postgres (postgres:14, 1.24, ubuntu-latest) (push) Failing after 6s
tests / postgres (postgres:15, 1.24, ubuntu-latest) (push) Failing after 14s
tests / postgres (postgres:15, 1.23, ubuntu-latest) (push) Failing after 29s
tests / postgres (postgres:latest, 1.23, ubuntu-latest) (push) Failing after 26s
tests / postgres (postgres:latest, 1.24, ubuntu-latest) (push) Failing after 18s
tests / postgres (postgres:13, 1.24, ubuntu-latest) (push) Failing after 2m55s
tests / sqlite (1.24, ubuntu-latest) (push) Failing after 4m53s
tests / tidb (v6.5.0, 1.24, ubuntu-latest) (push) Failing after 2m6s
golangci-lint / lint (push) Failing after 6m41s
tests / gaussdb (opengauss/opengauss:7.0.0-RC1.B023, 1.23, ubuntu-latest) (push) Failing after 3m31s
tests / gaussdb (opengauss/opengauss:7.0.0-RC1.B023, 1.24, ubuntu-latest) (push) Failing after 3m38s
tests / sqlserver (1.23, ubuntu-latest) (push) Failing after 7m36s
tests / tidb (v6.5.0, 1.23, ubuntu-latest) (push) Failing after 12m41s
tests / sqlserver (1.24, ubuntu-latest) (push) Failing after 14m12s
Stale / stale (push) Successful in 6s
Close Missing Playground issues / stale (push) Successful in 6s
Close invalid questions issues / stale (push) Successful in 5s
2025-07-23 13:02:49 +08:00
贾一饼
52b4744410
optimize: performance optimization (#7526) 2025-07-22 18:32:28 +08:00
Jinzhu
9af6d510b5 Fix query when map keys include table-qualified column names, close #7507 2025-07-22 14:21:04 +08:00
Jinzhu
c63374f5d1 Don't request LastInsertID from database if not necessary, close #7469 2025-07-21 17:55:20 +08:00
jc
b9c7e562b0
fix(schema): check the hook function parameter type (#7468)
* fix(schema): Check the callback function parameter type

* fix log

* fix
2025-07-21 17:06:16 +08:00
Riseif
985940f0d8
should check inner condition length (#7512) 2025-07-21 11:57:12 +08:00
moseszane168
991c2d4891
Add GaussDB Database Support (#7508)
* support gaussdb

* use github CI

* change function name

* use gorm.io/driver/gaussdb

---------

Co-authored-by: bing.ma <bing.ma@daocloud.io>
2025-07-21 10:46:58 +08:00
Jinzhu
751a6dde7a
Call after initialize for gorm.Config (#7518) 2025-07-15 12:05:03 +08:00
贾一饼
2f4925e017
A little optimization for filed.ValueOf (#7499)
Co-authored-by: 贾一饼 <Boyang.Liang@apulis.com>
2025-07-07 11:15:10 +08:00
Eshan-Jogwar
1e8baf5459
fixes #7486 (#7492)
* fixes #7486

* Added A test case for subset changes of model

* completed the test file for check_subset_model_change_test.go
2025-06-25 11:11:08 +08:00
Salent Olivick
842ee527eb
fix decimal migrate error.(#7450) (#7450)
Signed-off-by: Chise1 <chise123@live.com>
2025-06-06 10:35:01 +08:00
enomotodev
23c0d7cf05
test: update MySQL test matrix to use official images and add 9.0, 8.4 versions (#7476)
* test: update MySQL test matrix to use official images and add 9.0, 8.4 versions

* test: use major version tags for MySQL test matrix
2025-06-06 10:10:23 +08:00
Jinzhu
718eae4fdd fix tests for mysql 9.0 2025-06-05 19:34:13 +08:00
Jinzhu
49b01a3e93 Fix Generics Scan, close https://github.com/go-gorm/playground/pull/803 2025-05-29 14:23:57 +08:00
Jinzhu
c44405a25b
Implement Generics API (#7424)
* Implement Generics API

* Add more generics tests

* Add more tests and Take method

* use delayed‑ops pipeline for generics API

* fix generics tests for mysql

* Support SubQuery for Generics

* Add clause.JoinTable helper method

* Fix golangci-lint error

* Complete the design and implementation of generic version Join

* improve generics version Joins support

* allow configuring select/omit columns for joins via subqueries

* finish generic version Preload

* handle error of generics Joins/Preload

* fix tests

* Add LimitPerRecord for generic version Preload

* fix tests for mysql 5.7

* test for nested generic version Join/Preload

* Add WithResult support for generics API

* test reuse generics db conditions

* fix data race

* remove ExampleLRU test

* Add default transaction timeout support

* fix test
2025-05-25 15:40:40 +08:00
Name
751c1d6b45
perf(schema): avoid redundant strings.ToLower call (#7464)
Co-authored-by: 1911860538 <alxps1911@gmail.com>
2025-05-25 09:27:21 +08:00
codingplz
8e7ab46c1b
fix: return init dialector error (#7379)
* fix: return init dialector error

* mock defer

* fix: skip AfterInitialize

---------

Co-authored-by: wenyazhou.13 <wenyazhou.13@bytedance.com>
2025-05-22 10:53:47 +08:00
Name
e3037e4ef0
perf: break early on match failure in ParseConstraint (#7402)
Co-authored-by: 1911860538 <alxps1911@gmail.com>
2025-05-22 10:49:19 +08:00
pipipipipip
1204330419
feat: error message show field name (#7452)
* feat: error message show field name

* feat: failed to parse field

* feat: failed to parse field

---------

Co-authored-by: zyp <zyp>
2025-05-21 11:13:31 +08:00
Name
9703eb775f
perf: use strings.IndexByte to replace strings.Index (#7454)
Co-authored-by: 1911860538 <alxps1911@gmail.com>
2025-05-21 10:35:56 +08:00
Name
1c966e0d25
perf: use strings.Cut to replace strings.SplitN (#7455)
Co-authored-by: 1911860538 <alxps1911@gmail.com>
2025-05-21 10:35:23 +08:00
Jinzhu
e5b867e785 remove unnecessary session-level configuration for prepared statements 2025-05-07 14:56:49 +08:00
iTanken
8c4e8e2d2a
fix: int type variable defaultMaxSize overflows in 32-bit environment (#7439)
Refs: #7435
2025-04-27 14:05:16 +08:00
Zhaodong Xie
a827495be1
Preparestmt use LRU Map instead default map (#7435)
* 支持lru淘汰preparestmt cache

* 支持lru淘汰preparestmt cache

* 支持lru淘汰preparestmt cache

* 只使用lru

* 只使用lru

* 只使用lru

* 只使用lru

* 只使用lru

* 只使用lru

* 只使用lru

* 只使用lru

* 只使用lru

* change const export

* Add stmt_store

* refact prepare stmt store

* Rename lru store

* change const export

* ADD UT

* format code and add session level prepare stmt config

* code format according to golinter ci

* ADD UT

---------

Co-authored-by: xiezhaodong <xiezhaodong@bytedance.com>
Co-authored-by: Jinzhu <wosmvp@gmail.com>
2025-04-25 16:22:26 +08:00
Jinzhu
489a563293 only check new issues for golangci linter 2025-04-17 15:30:17 +08:00
Jinzhu
42bd4f603c
use golangci replace reviewdog (#7426)
* use golangci replace reviewdog

* Update golangci config
2025-04-17 11:55:13 +08:00
48 changed files with 3951 additions and 443 deletions

26
.github/workflows/golangci-lint.yml vendored Normal file
View File

@ -0,0 +1,26 @@
name: golangci-lint
on:
push:
branches:
- main
- master
pull_request:
permissions:
contents: read
pull-requests: read
jobs:
golangci:
name: lint
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
with:
go-version: stable
- name: golangci-lint
uses: golangci/golangci-lint-action@v7
with:
version: v2.0
only-new-issues: true

View File

@ -1,22 +0,0 @@
name: reviewdog
on: [pull_request]
jobs:
golangci-lint:
name: runner / golangci-lint
runs-on: ubuntu-latest
steps:
- name: Check out code into the Go module directory
uses: actions/checkout@v4
- name: golangci-lint
uses: reviewdog/action-golangci-lint@v2
- name: Setup reviewdog
uses: reviewdog/action-setup@v1
- name: gofumpt -s with reviewdog
env:
REVIEWDOG_GITHUB_API_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
go install mvdan.cc/gofumpt@v0.2.0
gofumpt -e -d . | \
reviewdog -name="gofumpt" -f=diff -f.diff.strip=0 -reporter=github-pr-review

View File

@ -41,7 +41,7 @@ jobs:
mysql:
strategy:
matrix:
dbversion: ['mysql/mysql-server:latest', 'mysql:5.7']
dbversion: ['mysql:9', 'mysql:8', 'mysql:5.7']
go: ['1.23', '1.24']
platform: [ubuntu-latest]
runs-on: ${{ matrix.platform }}
@ -240,3 +240,71 @@ jobs:
- name: Tests
run: GITHUB_ACTION=true GORM_DIALECT=tidb GORM_DSN="root:@tcp(localhost:9940)/test?charset=utf8&parseTime=True&loc=Local" ./tests/tests_all.sh
gaussdb:
strategy:
matrix:
dbversion: ['opengauss/opengauss:7.0.0-RC1.B023']
go: ['1.23', '1.24']
platform: [ubuntu-latest] # can not run in macOS and Windows
runs-on: ${{ matrix.platform }}
services:
gaussdb:
image: ${{ matrix.dbversion }}
env:
# GaussDB has password limitations
GS_PASSWORD: Gaussdb@123
TZ: Asia/Shanghai
ports:
- 9950:5432
steps:
- name: Set up Go 1.x
uses: actions/setup-go@v4
with:
go-version: ${{ matrix.go }}
- name: Check out code into the Go module directory
uses: actions/checkout@v4
- name: Waiting for GaussDB to be ready
run: |
container_name=$(docker ps --filter "ancestor=opengauss/opengauss:7.0.0-RC1.B023" --format "{{.Names}}")
if [ -z "$container_name" ]; then
echo "Error: failed to find a container created from the 'opengauss/opengauss:7.0.0-RC1.B023' image."
exit 1
fi
max_retries=12
retry_count=0
if [ -t 0 ]; then
TTY_FLAG="-t"
else
TTY_FLAG=""
fi
while [ $retry_count -lt $max_retries ]; do
if docker exec -i "${container_name}" bash -c "su - omm -c 'gsql -U omm -c \"select 1;\"'"
then
echo "Creating database gorm..."
sql_file='/tmp/create_database.sql'
echo "CREATE DATABASE gorm DBCOMPATIBILITY 'PG';" > ${sql_file}
docker cp "${sql_file}" "${container_name}":"${sql_file}"
docker exec -i ${TTY_FLAG} "${container_name}" bash -c "su - omm -c 'gsql -U omm -f ${sql_file}'"
echo "Database initialization completed."
break
fi
echo "Waiting for database to be ready... (attempt $((retry_count + 1))/$max_retries)"
sleep 10
((++retry_count))
done
exit 0
- name: go mod package cache
uses: actions/cache@v4
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
- name: Tests
run: GITHUB_ACTION=true GORM_DIALECT=gaussdb GORM_DSN="user=gaussdb password=Gaussdb@123 dbname=gorm host=localhost port=9950 sslmode=disable TimeZone=Asia/Shanghai" ./tests/tests_all.sh

View File

@ -1,7 +1,9 @@
version: "2"
linters:
default: standard
enable:
- cyclop
- exportloopref
- gocritic
- gosec
- ineffassign
@ -9,12 +11,9 @@ linters:
- prealloc
- unconvert
- unparam
- goimports
- whitespace
linters-settings:
whitespace:
multi-func: true
goimports:
local-prefixes: gorm.io/gorm
formatters:
enable:
- gofumpt
- goimports

View File

@ -53,12 +53,16 @@ func Create(config *Config) func(db *gorm.DB) {
if _, ok := db.Statement.Clauses["RETURNING"]; !ok {
fromColumns := make([]clause.Column, 0, len(db.Statement.Schema.FieldsWithDefaultDBValue))
for _, field := range db.Statement.Schema.FieldsWithDefaultDBValue {
if field.Readable {
fromColumns = append(fromColumns, clause.Column{Name: field.DBName})
}
}
if len(fromColumns) > 0 {
db.Statement.AddClause(clause.Returning{Columns: fromColumns})
}
}
}
}
if db.Statement.SQL.Len() == 0 {
db.Statement.SQL.Grow(180)
@ -89,6 +93,10 @@ func Create(config *Config) func(db *gorm.DB) {
db.AddError(rows.Close())
}()
gorm.Scan(rows, db, mode)
if db.Statement.Result != nil {
db.Statement.Result.RowsAffected = db.RowsAffected
}
}
return
@ -103,6 +111,12 @@ func Create(config *Config) func(db *gorm.DB) {
}
db.RowsAffected, _ = result.RowsAffected()
if db.Statement.Result != nil {
db.Statement.Result.Result = result
db.Statement.Result.RowsAffected = db.RowsAffected
}
if db.RowsAffected == 0 {
return
}
@ -112,6 +126,16 @@ func Create(config *Config) func(db *gorm.DB) {
pkFieldName = "@id"
)
if db.Statement.Schema != nil {
if db.Statement.Schema.PrioritizedPrimaryField == nil ||
!db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue ||
!db.Statement.Schema.PrioritizedPrimaryField.Readable {
return
}
pkField = db.Statement.Schema.PrioritizedPrimaryField
pkFieldName = db.Statement.Schema.PrioritizedPrimaryField.DBName
}
insertID, err := result.LastInsertId()
insertOk := err == nil && insertID > 0
@ -122,14 +146,6 @@ func Create(config *Config) func(db *gorm.DB) {
return
}
if db.Statement.Schema != nil {
if db.Statement.Schema.PrioritizedPrimaryField == nil || !db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue {
return
}
pkField = db.Statement.Schema.PrioritizedPrimaryField
pkFieldName = db.Statement.Schema.PrioritizedPrimaryField.DBName
}
// append @id column with value for auto-increment primary key
// the @id value is correct, when: 1. without setting auto-increment primary key, 2. database AutoIncrementIncrement = 1
switch values := db.Statement.Dest.(type) {

View File

@ -157,8 +157,14 @@ func Delete(config *Config) func(db *gorm.DB) {
ok, mode := hasReturning(db, supportReturning)
if !ok {
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
if db.AddError(err) == nil {
db.RowsAffected, _ = result.RowsAffected()
if db.Statement.Result != nil {
db.Statement.Result.Result = result
db.Statement.Result.RowsAffected = db.RowsAffected
}
}
return
@ -166,6 +172,10 @@ func Delete(config *Config) func(db *gorm.DB) {
if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil {
gorm.Scan(rows, db, mode)
if db.Statement.Result != nil {
db.Statement.Result.RowsAffected = db.RowsAffected
}
db.AddError(rows.Close())
}
}

View File

@ -103,11 +103,11 @@ func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relati
joined = true
continue
}
joinNames := strings.SplitN(join, ".", 2)
if len(joinNames) == 2 {
if _, ok := relationships.Relations[joinNames[0]]; ok && name == joinNames[0] {
join0, join1, cut := strings.Cut(join, ".")
if cut {
if _, ok := relationships.Relations[join0]; ok && name == join0 {
joined = true
nestedJoins = append(nestedJoins, joinNames[1])
nestedJoins = append(nestedJoins, join1)
}
}
}
@ -275,6 +275,8 @@ func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
column, values := schema.ToQueryValues(clause.CurrentTable, relForeignKeys, foreignValues)
if len(values) != 0 {
tx = tx.Model(reflectResults.Addr().Interface()).Where(clause.IN{Column: column, Values: values})
for _, cond := range conds {
if fc, ok := cond.(func(*gorm.DB) *gorm.DB); ok {
tx = fc(tx)
@ -283,7 +285,11 @@ func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
}
}
if err := tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), inlineConds...).Error; err != nil {
if len(inlineConds) > 0 {
tx = tx.Where(inlineConds[0], inlineConds[1:]...)
}
if err := tx.Find(reflectResults.Addr().Interface()).Error; err != nil {
return err
}
}

View File

@ -25,6 +25,10 @@ func Query(db *gorm.DB) {
db.AddError(rows.Close())
}()
gorm.Scan(rows, db, 0)
if db.Statement.Result != nil {
db.Statement.Result.RowsAffected = db.RowsAffected
}
}
}
}
@ -110,7 +114,7 @@ func BuildQuerySQL(db *gorm.DB) {
}
}
specifiedRelationsName := make(map[string]interface{})
specifiedRelationsName := map[string]string{clause.CurrentTable: clause.CurrentTable}
for _, join := range db.Statement.Joins {
if db.Statement.Schema != nil {
var isRelations bool // is relations or raw sql
@ -124,12 +128,12 @@ func BuildQuerySQL(db *gorm.DB) {
nestedJoinNames := strings.Split(join.Name, ".")
if len(nestedJoinNames) > 1 {
isNestedJoin := true
gussNestedRelations := make([]*schema.Relationship, 0, len(nestedJoinNames))
guessNestedRelations := make([]*schema.Relationship, 0, len(nestedJoinNames))
currentRelations := db.Statement.Schema.Relationships.Relations
for _, relname := range nestedJoinNames {
// incomplete match, only treated as raw sql
if relation, ok = currentRelations[relname]; ok {
gussNestedRelations = append(gussNestedRelations, relation)
guessNestedRelations = append(guessNestedRelations, relation)
currentRelations = relation.FieldSchema.Relationships.Relations
} else {
isNestedJoin = false
@ -139,18 +143,13 @@ func BuildQuerySQL(db *gorm.DB) {
if isNestedJoin {
isRelations = true
relations = gussNestedRelations
relations = guessNestedRelations
}
}
}
if isRelations {
genJoinClause := func(joinType clause.JoinType, parentTableName string, relation *schema.Relationship) clause.Join {
tableAliasName := relation.Name
if parentTableName != clause.CurrentTable {
tableAliasName = utils.NestedRelationName(parentTableName, tableAliasName)
}
genJoinClause := func(joinType clause.JoinType, tableAliasName string, parentTableName string, relation *schema.Relationship) clause.Join {
columnStmt := gorm.Statement{
Table: tableAliasName, DB: db, Schema: relation.FieldSchema,
Selects: join.Selects, Omits: join.Omits,
@ -167,6 +166,13 @@ func BuildQuerySQL(db *gorm.DB) {
}
}
if join.Expression != nil {
return clause.Join{
Type: join.JoinType,
Expression: join.Expression,
}
}
exprs := make([]clause.Expression, len(relation.References))
for idx, ref := range relation.References {
if ref.OwnPrimaryKey {
@ -226,19 +232,24 @@ func BuildQuerySQL(db *gorm.DB) {
}
parentTableName := clause.CurrentTable
for _, rel := range relations {
for idx, rel := range relations {
// joins table alias like "Manager, Company, Manager__Company"
nestedAlias := utils.NestedRelationName(parentTableName, rel.Name)
if _, ok := specifiedRelationsName[nestedAlias]; !ok {
fromClause.Joins = append(fromClause.Joins, genJoinClause(join.JoinType, parentTableName, rel))
specifiedRelationsName[nestedAlias] = nil
curAliasName := rel.Name
if parentTableName != clause.CurrentTable {
curAliasName = utils.NestedRelationName(parentTableName, curAliasName)
}
if parentTableName != clause.CurrentTable {
parentTableName = utils.NestedRelationName(parentTableName, rel.Name)
} else {
parentTableName = rel.Name
if _, ok := specifiedRelationsName[curAliasName]; !ok {
aliasName := curAliasName
if idx == len(relations)-1 && join.Alias != "" {
aliasName = join.Alias
}
fromClause.Joins = append(fromClause.Joins, genJoinClause(join.JoinType, aliasName, specifiedRelationsName[parentTableName], rel))
specifiedRelationsName[curAliasName] = aliasName
}
parentTableName = curAliasName
}
} else {
fromClause.Joins = append(fromClause.Joins, clause.Join{

View File

@ -13,5 +13,10 @@ func RawExec(db *gorm.DB) {
}
db.RowsAffected, _ = result.RowsAffected()
if db.Statement.Result != nil {
db.Statement.Result.Result = result
db.Statement.Result.RowsAffected = db.RowsAffected
}
}
}

View File

@ -92,6 +92,10 @@ func Update(config *Config) func(db *gorm.DB) {
gorm.Scan(rows, db, mode)
db.Statement.Dest = dest
db.AddError(rows.Close())
if db.Statement.Result != nil {
db.Statement.Result.RowsAffected = db.RowsAffected
}
}
} else {
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
@ -99,6 +103,11 @@ func Update(config *Config) func(db *gorm.DB) {
if db.AddError(err) == nil {
db.RowsAffected, _ = result.RowsAffected()
}
if db.Statement.Result != nil {
db.Statement.Result.Result = result
db.Statement.Result.RowsAffected = db.RowsAffected
}
}
}
}

View File

@ -448,6 +448,7 @@ func (db *DB) Assign(attrs ...interface{}) (tx *DB) {
// Unscoped allows queries to include records marked as deleted,
// overriding the soft deletion behavior.
// Example:
//
// var users []User
// db.Unscoped().Find(&users)
// // Retrieves all users, including deleted ones.

View File

@ -1,5 +1,7 @@
package clause
import "gorm.io/gorm/utils"
type JoinType string
const (
@ -9,6 +11,30 @@ const (
RightJoin JoinType = "RIGHT"
)
type JoinTarget struct {
Type JoinType
Association string
Subquery Expression
Table string
}
func Has(name string) JoinTarget {
return JoinTarget{Type: InnerJoin, Association: name}
}
func (jt JoinType) Association(name string) JoinTarget {
return JoinTarget{Type: jt, Association: name}
}
func (jt JoinType) AssociationFrom(name string, subquery Expression) JoinTarget {
return JoinTarget{Type: jt, Association: name, Subquery: subquery}
}
func (jt JoinTarget) As(name string) JoinTarget {
jt.Table = name
return jt
}
// Join clause for from
type Join struct {
Type JoinType
@ -18,6 +44,12 @@ type Join struct {
Expression Expression
}
func JoinTable(names ...string) Table {
return Table{
Name: utils.JoinNestedRelationNames(names),
}
}
func (join Join) Build(builder Builder) {
if join.Expression != nil {
join.Expression.Build(builder)

View File

@ -1,6 +1,7 @@
package gorm
import (
"context"
"database/sql"
"errors"
"fmt"
@ -673,11 +674,18 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB {
opt = opts[0]
}
ctx := tx.Statement.Context
if _, ok := ctx.Deadline(); !ok {
if db.Config.DefaultTransactionTimeout > 0 {
ctx, _ = context.WithTimeout(ctx, db.Config.DefaultTransactionTimeout)
}
}
switch beginner := tx.Statement.ConnPool.(type) {
case TxBeginner:
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
tx.Statement.ConnPool, err = beginner.BeginTx(ctx, opt)
case ConnPoolBeginner:
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
tx.Statement.ConnPool, err = beginner.BeginTx(ctx, opt)
default:
err = ErrInvalidTransaction
}

605
generics.go Normal file
View File

@ -0,0 +1,605 @@
package gorm
import (
"context"
"database/sql"
"fmt"
"sort"
"strings"
"gorm.io/gorm/clause"
"gorm.io/gorm/logger"
)
type result struct {
Result sql.Result
RowsAffected int64
}
func (info *result) ModifyStatement(stmt *Statement) {
stmt.Result = info
}
// Build implements clause.Expression interface
func (result) Build(clause.Builder) {
}
func WithResult() *result {
return &result{}
}
type Interface[T any] interface {
Raw(sql string, values ...interface{}) ExecInterface[T]
Exec(ctx context.Context, sql string, values ...interface{}) error
CreateInterface[T]
}
type CreateInterface[T any] interface {
ChainInterface[T]
Table(name string, args ...interface{}) CreateInterface[T]
Create(ctx context.Context, r *T) error
CreateInBatches(ctx context.Context, r *[]T, batchSize int) error
}
type ChainInterface[T any] interface {
ExecInterface[T]
Scopes(scopes ...func(db *Statement)) ChainInterface[T]
Where(query interface{}, args ...interface{}) ChainInterface[T]
Not(query interface{}, args ...interface{}) ChainInterface[T]
Or(query interface{}, args ...interface{}) ChainInterface[T]
Limit(offset int) ChainInterface[T]
Offset(offset int) ChainInterface[T]
Joins(query clause.JoinTarget, on func(db JoinBuilder, joinTable clause.Table, curTable clause.Table) error) ChainInterface[T]
Preload(association string, query func(db PreloadBuilder) error) ChainInterface[T]
Select(query string, args ...interface{}) ChainInterface[T]
Omit(columns ...string) ChainInterface[T]
MapColumns(m map[string]string) ChainInterface[T]
Distinct(args ...interface{}) ChainInterface[T]
Group(name string) ChainInterface[T]
Having(query interface{}, args ...interface{}) ChainInterface[T]
Order(value interface{}) ChainInterface[T]
Build(builder clause.Builder)
Delete(ctx context.Context) (rowsAffected int, err error)
Update(ctx context.Context, name string, value any) (rowsAffected int, err error)
Updates(ctx context.Context, t T) (rowsAffected int, err error)
Count(ctx context.Context, column string) (result int64, err error)
}
type ExecInterface[T any] interface {
Scan(ctx context.Context, r interface{}) error
First(context.Context) (T, error)
Last(ctx context.Context) (T, error)
Take(context.Context) (T, error)
Find(ctx context.Context) ([]T, error)
FindInBatches(ctx context.Context, batchSize int, fc func(data []T, batch int) error) error
Row(ctx context.Context) *sql.Row
Rows(ctx context.Context) (*sql.Rows, error)
}
type JoinBuilder interface {
Select(...string) JoinBuilder
Omit(...string) JoinBuilder
Where(query interface{}, args ...interface{}) JoinBuilder
Not(query interface{}, args ...interface{}) JoinBuilder
Or(query interface{}, args ...interface{}) JoinBuilder
}
type PreloadBuilder interface {
Select(...string) PreloadBuilder
Omit(...string) PreloadBuilder
Where(query interface{}, args ...interface{}) PreloadBuilder
Not(query interface{}, args ...interface{}) PreloadBuilder
Or(query interface{}, args ...interface{}) PreloadBuilder
Limit(offset int) PreloadBuilder
Offset(offset int) PreloadBuilder
Order(value interface{}) PreloadBuilder
LimitPerRecord(num int) PreloadBuilder
}
type op func(*DB) *DB
func G[T any](db *DB, opts ...clause.Expression) Interface[T] {
v := &g[T]{
db: db,
ops: make([]op, 0, 5),
}
if len(opts) > 0 {
v.ops = append(v.ops, func(db *DB) *DB {
return db.Clauses(opts...)
})
}
v.createG = &createG[T]{
chainG: chainG[T]{
execG: execG[T]{g: v},
},
}
return v
}
type g[T any] struct {
*createG[T]
db *DB
ops []op
}
func (g *g[T]) apply(ctx context.Context) *DB {
db := g.db
if !db.DryRun {
db = db.Session(&Session{NewDB: true, Context: ctx}).getInstance()
}
for _, op := range g.ops {
db = op(db)
}
return db
}
func (c *g[T]) Raw(sql string, values ...interface{}) ExecInterface[T] {
return execG[T]{g: &g[T]{
db: c.db,
ops: append(c.ops, func(db *DB) *DB {
return db.Raw(sql, values...)
}),
}}
}
func (c *g[T]) Exec(ctx context.Context, sql string, values ...interface{}) error {
return c.apply(ctx).Exec(sql, values...).Error
}
type createG[T any] struct {
chainG[T]
}
func (c createG[T]) Table(name string, args ...interface{}) CreateInterface[T] {
return createG[T]{c.with(func(db *DB) *DB {
return db.Table(name, args...)
})}
}
func (c createG[T]) Create(ctx context.Context, r *T) error {
return c.g.apply(ctx).Create(r).Error
}
func (c createG[T]) CreateInBatches(ctx context.Context, r *[]T, batchSize int) error {
return c.g.apply(ctx).CreateInBatches(r, batchSize).Error
}
type chainG[T any] struct {
execG[T]
}
func (c chainG[T]) getInstance() *DB {
var r T
return c.g.apply(context.Background()).Model(r).getInstance()
}
func (c chainG[T]) with(v op) chainG[T] {
return chainG[T]{
execG: execG[T]{g: &g[T]{
db: c.g.db,
ops: append(append([]op(nil), c.g.ops...), v),
}},
}
}
func (c chainG[T]) Scopes(scopes ...func(db *Statement)) ChainInterface[T] {
return c.with(func(db *DB) *DB {
for _, fc := range scopes {
fc(db.Statement)
}
return db
})
}
func (c chainG[T]) Table(name string, args ...interface{}) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Table(name, args...)
})
}
func (c chainG[T]) Where(query interface{}, args ...interface{}) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Where(query, args...)
})
}
func (c chainG[T]) Not(query interface{}, args ...interface{}) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Not(query, args...)
})
}
func (c chainG[T]) Or(query interface{}, args ...interface{}) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Or(query, args...)
})
}
func (c chainG[T]) Limit(offset int) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Limit(offset)
})
}
func (c chainG[T]) Offset(offset int) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Offset(offset)
})
}
type joinBuilder struct {
db *DB
}
func (q *joinBuilder) Where(query interface{}, args ...interface{}) JoinBuilder {
q.db.Where(query, args...)
return q
}
func (q *joinBuilder) Or(query interface{}, args ...interface{}) JoinBuilder {
q.db.Where(query, args...)
return q
}
func (q *joinBuilder) Not(query interface{}, args ...interface{}) JoinBuilder {
q.db.Where(query, args...)
return q
}
func (q *joinBuilder) Select(columns ...string) JoinBuilder {
q.db.Select(columns)
return q
}
func (q *joinBuilder) Omit(columns ...string) JoinBuilder {
q.db.Omit(columns...)
return q
}
type preloadBuilder struct {
limitPerRecord int
db *DB
}
func (q *preloadBuilder) Where(query interface{}, args ...interface{}) PreloadBuilder {
q.db.Where(query, args...)
return q
}
func (q *preloadBuilder) Or(query interface{}, args ...interface{}) PreloadBuilder {
q.db.Where(query, args...)
return q
}
func (q *preloadBuilder) Not(query interface{}, args ...interface{}) PreloadBuilder {
q.db.Where(query, args...)
return q
}
func (q *preloadBuilder) Select(columns ...string) PreloadBuilder {
q.db.Select(columns)
return q
}
func (q *preloadBuilder) Omit(columns ...string) PreloadBuilder {
q.db.Omit(columns...)
return q
}
func (q *preloadBuilder) Limit(limit int) PreloadBuilder {
q.db.Limit(limit)
return q
}
func (q *preloadBuilder) Offset(offset int) PreloadBuilder {
q.db.Offset(offset)
return q
}
func (q *preloadBuilder) Order(value interface{}) PreloadBuilder {
q.db.Order(value)
return q
}
func (q *preloadBuilder) LimitPerRecord(num int) PreloadBuilder {
q.limitPerRecord = num
return q
}
func (c chainG[T]) Joins(jt clause.JoinTarget, on func(db JoinBuilder, joinTable clause.Table, curTable clause.Table) error) ChainInterface[T] {
return c.with(func(db *DB) *DB {
if jt.Table == "" {
jt.Table = clause.JoinTable(strings.Split(jt.Association, ".")...).Name
}
q := joinBuilder{db: db.Session(&Session{NewDB: true, Initialized: true}).Table(jt.Table)}
if on != nil {
if err := on(&q, clause.Table{Name: jt.Table}, clause.Table{Name: clause.CurrentTable}); err != nil {
db.AddError(err)
}
}
j := join{
Name: jt.Association,
Alias: jt.Table,
Selects: q.db.Statement.Selects,
Omits: q.db.Statement.Omits,
JoinType: jt.Type,
}
if where, ok := q.db.Statement.Clauses["WHERE"].Expression.(clause.Where); ok {
j.On = &where
}
if jt.Subquery != nil {
joinType := j.JoinType
if joinType == "" {
joinType = clause.LeftJoin
}
if db, ok := jt.Subquery.(interface{ getInstance() *DB }); ok {
stmt := db.getInstance().Statement
if len(j.Selects) == 0 {
j.Selects = stmt.Selects
}
if len(j.Omits) == 0 {
j.Omits = stmt.Omits
}
}
expr := clause.NamedExpr{SQL: fmt.Sprintf("%s JOIN (?) AS ?", joinType), Vars: []interface{}{jt.Subquery, clause.Table{Name: j.Alias}}}
if j.On != nil {
expr.SQL += " ON ?"
expr.Vars = append(expr.Vars, clause.AndConditions{Exprs: j.On.Exprs})
}
j.Expression = expr
}
db.Statement.Joins = append(db.Statement.Joins, j)
sort.Slice(db.Statement.Joins, func(i, j int) bool {
return db.Statement.Joins[i].Name < db.Statement.Joins[j].Name
})
return db
})
}
func (c chainG[T]) Select(query string, args ...interface{}) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Select(query, args...)
})
}
func (c chainG[T]) Omit(columns ...string) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Omit(columns...)
})
}
func (c chainG[T]) MapColumns(m map[string]string) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.MapColumns(m)
})
}
func (c chainG[T]) Distinct(args ...interface{}) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Distinct(args...)
})
}
func (c chainG[T]) Group(name string) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Group(name)
})
}
func (c chainG[T]) Having(query interface{}, args ...interface{}) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Having(query, args...)
})
}
func (c chainG[T]) Order(value interface{}) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Order(value)
})
}
func (c chainG[T]) Preload(association string, query func(db PreloadBuilder) error) ChainInterface[T] {
return c.with(func(db *DB) *DB {
return db.Preload(association, func(tx *DB) *DB {
q := preloadBuilder{db: tx.getInstance()}
if query != nil {
if err := query(&q); err != nil {
db.AddError(err)
}
}
relation, ok := db.Statement.Schema.Relationships.Relations[association]
if !ok {
if preloadFields := strings.Split(association, "."); len(preloadFields) > 1 {
relationships := db.Statement.Schema.Relationships
for _, field := range preloadFields {
var ok bool
relation, ok = relationships.Relations[field]
if ok {
relationships = relation.FieldSchema.Relationships
} else {
db.AddError(fmt.Errorf("relation %s not found", association))
return nil
}
}
} else {
db.AddError(fmt.Errorf("relation %s not found", association))
return nil
}
}
if q.limitPerRecord > 0 {
if relation.JoinTable != nil {
tx.AddError(fmt.Errorf("many2many relation %s don't support LimitPerRecord", association))
return tx
}
refColumns := []clause.Column{}
for _, rel := range relation.References {
if rel.OwnPrimaryKey {
refColumns = append(refColumns, clause.Column{Name: rel.ForeignKey.DBName})
}
}
if len(refColumns) != 0 {
selectExpr := clause.CommaExpression{}
for _, column := range q.db.Statement.Selects {
selectExpr.Exprs = append(selectExpr.Exprs, clause.Expr{SQL: "?", Vars: []interface{}{clause.Column{Name: column}}})
}
if len(selectExpr.Exprs) == 0 {
selectExpr.Exprs = []clause.Expression{clause.Expr{SQL: "*", Vars: []interface{}{}}}
}
partitionBy := clause.CommaExpression{}
for _, column := range refColumns {
partitionBy.Exprs = append(partitionBy.Exprs, clause.Expr{SQL: "?", Vars: []interface{}{clause.Column{Name: column.Name}}})
}
rnnColumn := clause.Column{Name: "gorm_preload_rnn"}
sql := "ROW_NUMBER() OVER (PARTITION BY ? ?)"
vars := []interface{}{partitionBy}
if orderBy, ok := q.db.Statement.Clauses["ORDER BY"]; ok {
vars = append(vars, orderBy)
} else {
vars = append(vars, clause.Clause{Name: "ORDER BY", Expression: clause.OrderBy{
Columns: []clause.OrderByColumn{{Column: clause.PrimaryColumn, Desc: true}},
}})
}
vars = append(vars, rnnColumn)
selectExpr.Exprs = append(selectExpr.Exprs, clause.Expr{SQL: sql + " AS ?", Vars: vars})
q.db.Clauses(clause.Select{Expression: selectExpr})
return q.db.Session(&Session{NewDB: true}).Unscoped().Table("(?) t", q.db).Where("? <= ?", rnnColumn, q.limitPerRecord)
}
}
return q.db
})
})
}
func (c chainG[T]) Delete(ctx context.Context) (rowsAffected int, err error) {
r := new(T)
res := c.g.apply(ctx).Delete(r)
return int(res.RowsAffected), res.Error
}
func (c chainG[T]) Update(ctx context.Context, name string, value any) (rowsAffected int, err error) {
var r T
res := c.g.apply(ctx).Model(r).Update(name, value)
return int(res.RowsAffected), res.Error
}
func (c chainG[T]) Updates(ctx context.Context, t T) (rowsAffected int, err error) {
res := c.g.apply(ctx).Updates(t)
return int(res.RowsAffected), res.Error
}
func (c chainG[T]) Count(ctx context.Context, column string) (result int64, err error) {
var r T
err = c.g.apply(ctx).Model(r).Select(column).Count(&result).Error
return
}
func (c chainG[T]) Build(builder clause.Builder) {
subdb := c.getInstance()
subdb.Logger = logger.Discard
subdb.DryRun = true
if stmt, ok := builder.(*Statement); ok {
if subdb.Statement.SQL.Len() > 0 {
var (
vars = subdb.Statement.Vars
sql = subdb.Statement.SQL.String()
)
subdb.Statement.Vars = make([]interface{}, 0, len(vars))
for _, vv := range vars {
subdb.Statement.Vars = append(subdb.Statement.Vars, vv)
bindvar := strings.Builder{}
subdb.BindVarTo(&bindvar, subdb.Statement, vv)
sql = strings.Replace(sql, bindvar.String(), "?", 1)
}
subdb.Statement.SQL.Reset()
subdb.Statement.Vars = stmt.Vars
if strings.Contains(sql, "@") {
clause.NamedExpr{SQL: sql, Vars: vars}.Build(subdb.Statement)
} else {
clause.Expr{SQL: sql, Vars: vars}.Build(subdb.Statement)
}
} else {
subdb.Statement.Vars = append(stmt.Vars, subdb.Statement.Vars...)
subdb.callbacks.Query().Execute(subdb)
}
builder.WriteString(subdb.Statement.SQL.String())
stmt.Vars = subdb.Statement.Vars
}
}
type execG[T any] struct {
g *g[T]
}
func (g execG[T]) First(ctx context.Context) (T, error) {
var r T
err := g.g.apply(ctx).First(&r).Error
return r, err
}
func (g execG[T]) Scan(ctx context.Context, result interface{}) error {
var r T
err := g.g.apply(ctx).Model(r).Find(result).Error
return err
}
func (g execG[T]) Last(ctx context.Context) (T, error) {
var r T
err := g.g.apply(ctx).Last(&r).Error
return r, err
}
func (g execG[T]) Take(ctx context.Context) (T, error) {
var r T
err := g.g.apply(ctx).Take(&r).Error
return r, err
}
func (g execG[T]) Find(ctx context.Context) ([]T, error) {
var r []T
err := g.g.apply(ctx).Find(&r).Error
return r, err
}
func (g execG[T]) FindInBatches(ctx context.Context, batchSize int, fc func(data []T, batch int) error) error {
var data []T
return g.g.apply(ctx).FindInBatches(&data, batchSize, func(tx *DB, batch int) error {
return fc(data, batch)
}).Error
}
func (g execG[T]) Row(ctx context.Context) *sql.Row {
return g.g.apply(ctx).Row()
}
func (g execG[T]) Rows(ctx context.Context) (*sql.Rows, error) {
return g.g.apply(ctx).Rows()
}

29
gorm.go
View File

@ -22,6 +22,8 @@ type Config struct {
// GORM perform single create, update, delete operations in transactions by default to ensure database data integrity
// You can disable it by setting `SkipDefaultTransaction` to true
SkipDefaultTransaction bool
DefaultTransactionTimeout time.Duration
// NamingStrategy tables, columns naming strategy
NamingStrategy schema.Namer
// FullSaveAssociations full save associations
@ -34,6 +36,11 @@ type Config struct {
DryRun bool
// PrepareStmt executes the given query in cached statement
PrepareStmt bool
// PrepareStmt cache support LRU expired,
// default maxsize=int64 Max value and ttl=1h
PrepareStmtMaxSize int
PrepareStmtTTL time.Duration
// DisableAutomaticPing
DisableAutomaticPing bool
// DisableForeignKeyConstraintWhenMigrating
@ -130,12 +137,24 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) {
return isConfig && !isConfig2
})
if len(opts) > 0 {
if c, ok := opts[0].(*Config); ok {
config = c
} else {
opts = append([]Option{config}, opts...)
}
}
var skipAfterInitialize bool
for _, opt := range opts {
if opt != nil {
if applyErr := opt.Apply(config); applyErr != nil {
return nil, applyErr
}
defer func(opt Option) {
if skipAfterInitialize {
return
}
if errr := opt.AfterInitialize(db); errr != nil {
err = errr
}
@ -187,6 +206,10 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) {
if db, _ := db.DB(); db != nil {
_ = db.Close()
}
// DB is not initialized, so we skip AfterInitialize
skipAfterInitialize = true
return
}
if config.TranslateError {
@ -197,7 +220,7 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) {
}
if config.PrepareStmt {
preparedStmt := NewPreparedStmtDB(db.ConnPool)
preparedStmt := NewPreparedStmtDB(db.ConnPool, config.PrepareStmtMaxSize, config.PrepareStmtTTL)
db.cacheStore.Store(preparedStmtDBKey, preparedStmt)
db.ConnPool = preparedStmt
}
@ -268,7 +291,7 @@ func (db *DB) Session(config *Session) *DB {
if v, ok := db.cacheStore.Load(preparedStmtDBKey); ok {
preparedStmt = v.(*PreparedStmtDB)
} else {
preparedStmt = NewPreparedStmtDB(db.ConnPool)
preparedStmt = NewPreparedStmtDB(db.ConnPool, db.PrepareStmtMaxSize, db.PrepareStmtTTL)
db.cacheStore.Store(preparedStmtDBKey, preparedStmt)
}
@ -514,7 +537,7 @@ func (db *DB) Use(plugin Plugin) error {
// .First(&User{})
// })
func (db *DB) ToSQL(queryFn func(tx *DB) *DB) string {
tx := queryFn(db.Session(&Session{DryRun: true, SkipDefaultTransaction: true}))
tx := queryFn(db.Session(&Session{DryRun: true, SkipDefaultTransaction: true}).getInstance())
stmt := tx.Statement
return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...)

493
internal/lru/lru.go Normal file
View File

@ -0,0 +1,493 @@
package lru
// golang -lru
// https://github.com/hashicorp/golang-lru
import (
"sync"
"time"
)
// EvictCallback is used to get a callback when a cache entry is evicted
type EvictCallback[K comparable, V any] func(key K, value V)
// LRU implements a thread-safe LRU with expirable entries.
type LRU[K comparable, V any] struct {
size int
evictList *LruList[K, V]
items map[K]*Entry[K, V]
onEvict EvictCallback[K, V]
// expirable options
mu sync.Mutex
ttl time.Duration
done chan struct{}
// buckets for expiration
buckets []bucket[K, V]
// uint8 because it's number between 0 and numBuckets
nextCleanupBucket uint8
}
// bucket is a container for holding entries to be expired
type bucket[K comparable, V any] struct {
entries map[K]*Entry[K, V]
newestEntry time.Time
}
// noEvictionTTL - very long ttl to prevent eviction
const noEvictionTTL = time.Hour * 24 * 365 * 10
// because of uint8 usage for nextCleanupBucket, should not exceed 256.
// casting it as uint8 explicitly requires type conversions in multiple places
const numBuckets = 100
// NewLRU returns a new thread-safe cache with expirable entries.
//
// Size parameter set to 0 makes cache of unlimited size, e.g. turns LRU mechanism off.
//
// Providing 0 TTL turns expiring off.
//
// Delete expired entries every 1/100th of ttl value. Goroutine which deletes expired entries runs indefinitely.
func NewLRU[K comparable, V any](size int, onEvict EvictCallback[K, V], ttl time.Duration) *LRU[K, V] {
if size < 0 {
size = 0
}
if ttl <= 0 {
ttl = noEvictionTTL
}
res := LRU[K, V]{
ttl: ttl,
size: size,
evictList: NewList[K, V](),
items: make(map[K]*Entry[K, V]),
onEvict: onEvict,
done: make(chan struct{}),
}
// initialize the buckets
res.buckets = make([]bucket[K, V], numBuckets)
for i := 0; i < numBuckets; i++ {
res.buckets[i] = bucket[K, V]{entries: make(map[K]*Entry[K, V])}
}
// enable deleteExpired() running in separate goroutine for cache with non-zero TTL
//
// Important: done channel is never closed, so deleteExpired() goroutine will never exit,
// it's decided to add functionality to close it in the version later than v2.
if res.ttl != noEvictionTTL {
go func(done <-chan struct{}) {
ticker := time.NewTicker(res.ttl / numBuckets)
defer ticker.Stop()
for {
select {
case <-done:
return
case <-ticker.C:
res.deleteExpired()
}
}
}(res.done)
}
return &res
}
// Purge clears the cache completely.
// onEvict is called for each evicted key.
func (c *LRU[K, V]) Purge() {
c.mu.Lock()
defer c.mu.Unlock()
for k, v := range c.items {
if c.onEvict != nil {
c.onEvict(k, v.Value)
}
delete(c.items, k)
}
for _, b := range c.buckets {
for _, ent := range b.entries {
delete(b.entries, ent.Key)
}
}
c.evictList.Init()
}
// Add adds a value to the cache. Returns true if an eviction occurred.
// Returns false if there was no eviction: the item was already in the cache,
// or the size was not exceeded.
func (c *LRU[K, V]) Add(key K, value V) (evicted bool) {
c.mu.Lock()
defer c.mu.Unlock()
now := time.Now()
// Check for existing item
if ent, ok := c.items[key]; ok {
c.evictList.MoveToFront(ent)
c.removeFromBucket(ent) // remove the entry from its current bucket as expiresAt is renewed
ent.Value = value
ent.ExpiresAt = now.Add(c.ttl)
c.addToBucket(ent)
return false
}
// Add new item
ent := c.evictList.PushFrontExpirable(key, value, now.Add(c.ttl))
c.items[key] = ent
c.addToBucket(ent) // adds the entry to the appropriate bucket and sets entry.expireBucket
evict := c.size > 0 && c.evictList.Length() > c.size
// Verify size not exceeded
if evict {
c.removeOldest()
}
return evict
}
// Get looks up a key's value from the cache.
func (c *LRU[K, V]) Get(key K) (value V, ok bool) {
c.mu.Lock()
defer c.mu.Unlock()
var ent *Entry[K, V]
if ent, ok = c.items[key]; ok {
// Expired item check
if time.Now().After(ent.ExpiresAt) {
return value, false
}
c.evictList.MoveToFront(ent)
return ent.Value, true
}
return
}
// Contains checks if a key is in the cache, without updating the recent-ness
// or deleting it for being stale.
func (c *LRU[K, V]) Contains(key K) (ok bool) {
c.mu.Lock()
defer c.mu.Unlock()
_, ok = c.items[key]
return ok
}
// Peek returns the key value (or undefined if not found) without updating
// the "recently used"-ness of the key.
func (c *LRU[K, V]) Peek(key K) (value V, ok bool) {
c.mu.Lock()
defer c.mu.Unlock()
var ent *Entry[K, V]
if ent, ok = c.items[key]; ok {
// Expired item check
if time.Now().After(ent.ExpiresAt) {
return value, false
}
return ent.Value, true
}
return
}
// Remove removes the provided key from the cache, returning if the
// key was contained.
func (c *LRU[K, V]) Remove(key K) bool {
c.mu.Lock()
defer c.mu.Unlock()
if ent, ok := c.items[key]; ok {
c.removeElement(ent)
return true
}
return false
}
// RemoveOldest removes the oldest item from the cache.
func (c *LRU[K, V]) RemoveOldest() (key K, value V, ok bool) {
c.mu.Lock()
defer c.mu.Unlock()
if ent := c.evictList.Back(); ent != nil {
c.removeElement(ent)
return ent.Key, ent.Value, true
}
return
}
// GetOldest returns the oldest entry
func (c *LRU[K, V]) GetOldest() (key K, value V, ok bool) {
c.mu.Lock()
defer c.mu.Unlock()
if ent := c.evictList.Back(); ent != nil {
return ent.Key, ent.Value, true
}
return
}
func (c *LRU[K, V]) KeyValues() map[K]V {
c.mu.Lock()
defer c.mu.Unlock()
maps := make(map[K]V)
now := time.Now()
for ent := c.evictList.Back(); ent != nil; ent = ent.PrevEntry() {
if now.After(ent.ExpiresAt) {
continue
}
maps[ent.Key] = ent.Value
// keys = append(keys, ent.Key)
}
return maps
}
// Keys returns a slice of the keys in the cache, from oldest to newest.
// Expired entries are filtered out.
func (c *LRU[K, V]) Keys() []K {
c.mu.Lock()
defer c.mu.Unlock()
keys := make([]K, 0, len(c.items))
now := time.Now()
for ent := c.evictList.Back(); ent != nil; ent = ent.PrevEntry() {
if now.After(ent.ExpiresAt) {
continue
}
keys = append(keys, ent.Key)
}
return keys
}
// Values returns a slice of the values in the cache, from oldest to newest.
// Expired entries are filtered out.
func (c *LRU[K, V]) Values() []V {
c.mu.Lock()
defer c.mu.Unlock()
values := make([]V, 0, len(c.items))
now := time.Now()
for ent := c.evictList.Back(); ent != nil; ent = ent.PrevEntry() {
if now.After(ent.ExpiresAt) {
continue
}
values = append(values, ent.Value)
}
return values
}
// Len returns the number of items in the cache.
func (c *LRU[K, V]) Len() int {
c.mu.Lock()
defer c.mu.Unlock()
return c.evictList.Length()
}
// Resize changes the cache size. Size of 0 means unlimited.
func (c *LRU[K, V]) Resize(size int) (evicted int) {
c.mu.Lock()
defer c.mu.Unlock()
if size <= 0 {
c.size = 0
return 0
}
diff := c.evictList.Length() - size
if diff < 0 {
diff = 0
}
for i := 0; i < diff; i++ {
c.removeOldest()
}
c.size = size
return diff
}
// Close destroys cleanup goroutine. To clean up the cache, run Purge() before Close().
// func (c *LRU[K, V]) Close() {
// c.mu.Lock()
// defer c.mu.Unlock()
// select {
// case <-c.done:
// return
// default:
// }
// close(c.done)
// }
// removeOldest removes the oldest item from the cache. Has to be called with lock!
func (c *LRU[K, V]) removeOldest() {
if ent := c.evictList.Back(); ent != nil {
c.removeElement(ent)
}
}
// removeElement is used to remove a given list element from the cache. Has to be called with lock!
func (c *LRU[K, V]) removeElement(e *Entry[K, V]) {
c.evictList.Remove(e)
delete(c.items, e.Key)
c.removeFromBucket(e)
if c.onEvict != nil {
c.onEvict(e.Key, e.Value)
}
}
// deleteExpired deletes expired records from the oldest bucket, waiting for the newest entry
// in it to expire first.
func (c *LRU[K, V]) deleteExpired() {
c.mu.Lock()
bucketIdx := c.nextCleanupBucket
timeToExpire := time.Until(c.buckets[bucketIdx].newestEntry)
// wait for newest entry to expire before cleanup without holding lock
if timeToExpire > 0 {
c.mu.Unlock()
time.Sleep(timeToExpire)
c.mu.Lock()
}
for _, ent := range c.buckets[bucketIdx].entries {
c.removeElement(ent)
}
c.nextCleanupBucket = (c.nextCleanupBucket + 1) % numBuckets
c.mu.Unlock()
}
// addToBucket adds entry to expire bucket so that it will be cleaned up when the time comes. Has to be called with lock!
func (c *LRU[K, V]) addToBucket(e *Entry[K, V]) {
bucketID := (numBuckets + c.nextCleanupBucket - 1) % numBuckets
e.ExpireBucket = bucketID
c.buckets[bucketID].entries[e.Key] = e
if c.buckets[bucketID].newestEntry.Before(e.ExpiresAt) {
c.buckets[bucketID].newestEntry = e.ExpiresAt
}
}
// removeFromBucket removes the entry from its corresponding bucket. Has to be called with lock!
func (c *LRU[K, V]) removeFromBucket(e *Entry[K, V]) {
delete(c.buckets[e.ExpireBucket].entries, e.Key)
}
// Cap returns the capacity of the cache
func (c *LRU[K, V]) Cap() int {
return c.size
}
// Entry is an LRU Entry
type Entry[K comparable, V any] struct {
// Next and previous pointers in the doubly-linked list of elements.
// To simplify the implementation, internally a list l is implemented
// as a ring, such that &l.root is both the next element of the last
// list element (l.Back()) and the previous element of the first list
// element (l.Front()).
next, prev *Entry[K, V]
// The list to which this element belongs.
list *LruList[K, V]
// The LRU Key of this element.
Key K
// The Value stored with this element.
Value V
// The time this element would be cleaned up, optional
ExpiresAt time.Time
// The expiry bucket item was put in, optional
ExpireBucket uint8
}
// PrevEntry returns the previous list element or nil.
func (e *Entry[K, V]) PrevEntry() *Entry[K, V] {
if p := e.prev; e.list != nil && p != &e.list.root {
return p
}
return nil
}
// LruList represents a doubly linked list.
// The zero Value for LruList is an empty list ready to use.
type LruList[K comparable, V any] struct {
root Entry[K, V] // sentinel list element, only &root, root.prev, and root.next are used
len int // current list Length excluding (this) sentinel element
}
// Init initializes or clears list l.
func (l *LruList[K, V]) Init() *LruList[K, V] {
l.root.next = &l.root
l.root.prev = &l.root
l.len = 0
return l
}
// NewList returns an initialized list.
func NewList[K comparable, V any]() *LruList[K, V] { return new(LruList[K, V]).Init() }
// Length returns the number of elements of list l.
// The complexity is O(1).
func (l *LruList[K, V]) Length() int { return l.len }
// Back returns the last element of list l or nil if the list is empty.
func (l *LruList[K, V]) Back() *Entry[K, V] {
if l.len == 0 {
return nil
}
return l.root.prev
}
// lazyInit lazily initializes a zero List Value.
func (l *LruList[K, V]) lazyInit() {
if l.root.next == nil {
l.Init()
}
}
// insert inserts e after at, increments l.len, and returns e.
func (l *LruList[K, V]) insert(e, at *Entry[K, V]) *Entry[K, V] {
e.prev = at
e.next = at.next
e.prev.next = e
e.next.prev = e
e.list = l
l.len++
return e
}
// insertValue is a convenience wrapper for insert(&Entry{Value: v, ExpiresAt: ExpiresAt}, at).
func (l *LruList[K, V]) insertValue(k K, v V, expiresAt time.Time, at *Entry[K, V]) *Entry[K, V] {
return l.insert(&Entry[K, V]{Value: v, Key: k, ExpiresAt: expiresAt}, at)
}
// Remove removes e from its list, decrements l.len
func (l *LruList[K, V]) Remove(e *Entry[K, V]) V {
e.prev.next = e.next
e.next.prev = e.prev
e.next = nil // avoid memory leaks
e.prev = nil // avoid memory leaks
e.list = nil
l.len--
return e.Value
}
// move moves e to next to at.
func (l *LruList[K, V]) move(e, at *Entry[K, V]) {
if e == at {
return
}
e.prev.next = e.next
e.next.prev = e.prev
e.prev = at
e.next = at.next
e.prev.next = e
e.next.prev = e
}
// PushFront inserts a new element e with value v at the front of list l and returns e.
func (l *LruList[K, V]) PushFront(k K, v V) *Entry[K, V] {
l.lazyInit()
return l.insertValue(k, v, time.Time{}, &l.root)
}
// PushFrontExpirable inserts a new expirable element e with Value v at the front of list l and returns e.
func (l *LruList[K, V]) PushFrontExpirable(k K, v V, expiresAt time.Time) *Entry[K, V] {
l.lazyInit()
return l.insertValue(k, v, expiresAt, &l.root)
}
// MoveToFront moves element e to the front of list l.
// If e is not an element of l, the list is not modified.
// The element must not be nil.
func (l *LruList[K, V]) MoveToFront(e *Entry[K, V]) {
if e.list != l || l.root.next == e {
return
}
// see comment in List.Remove about initialization of l
l.move(e, &l.root)
}

View File

@ -0,0 +1,183 @@
package stmt_store
import (
"context"
"database/sql"
"math"
"sync"
"time"
"gorm.io/gorm/internal/lru"
)
type Stmt struct {
*sql.Stmt
Transaction bool
prepared chan struct{}
prepareErr error
}
func (stmt *Stmt) Error() error {
return stmt.prepareErr
}
func (stmt *Stmt) Close() error {
<-stmt.prepared
if stmt.Stmt != nil {
return stmt.Stmt.Close()
}
return nil
}
// Store defines an interface for managing the caching operations of SQL statements (Stmt).
// This interface provides methods for creating new statements, retrieving all cache keys,
// getting cached statements, setting cached statements, and deleting cached statements.
type Store interface {
// New creates a new Stmt object and caches it.
// Parameters:
// ctx: The context for the request, which can carry deadlines, cancellation signals, etc.
// key: The key representing the SQL query, used for caching and preparing the statement.
// isTransaction: Indicates whether this operation is part of a transaction, which may affect the caching strategy.
// connPool: A connection pool that provides database connections.
// locker: A synchronization lock that is unlocked after initialization to avoid deadlocks.
// Returns:
// *Stmt: A newly created statement object for executing SQL operations.
// error: An error if the statement preparation fails.
New(ctx context.Context, key string, isTransaction bool, connPool ConnPool, locker sync.Locker) (*Stmt, error)
// Keys returns a slice of all cache keys in the store.
Keys() []string
// Get retrieves a Stmt object from the store based on the given key.
// Parameters:
// key: The key used to look up the Stmt object.
// Returns:
// *Stmt: The found Stmt object, or nil if not found.
// bool: Indicates whether the corresponding Stmt object was successfully found.
Get(key string) (*Stmt, bool)
// Set stores the given Stmt object in the store and associates it with the specified key.
// Parameters:
// key: The key used to associate the Stmt object.
// value: The Stmt object to be stored.
Set(key string, value *Stmt)
// Delete removes the Stmt object corresponding to the specified key from the store.
// Parameters:
// key: The key associated with the Stmt object to be deleted.
Delete(key string)
}
// defaultMaxSize defines the default maximum capacity of the cache.
// Its value is the maximum value of the int64 type, which means that when the cache size is not specified,
// the cache can theoretically store as many elements as possible.
// (1 << 63) - 1 is the maximum value that an int64 type can represent.
const (
defaultMaxSize = math.MaxInt
// defaultTTL defines the default time-to-live (TTL) for each cache entry.
// When the TTL for cache entries is not specified, each cache entry will expire after 24 hours.
defaultTTL = time.Hour * 24
)
// New creates and returns a new Store instance.
//
// Parameters:
// - size: The maximum capacity of the cache. If the provided size is less than or equal to 0,
// it defaults to defaultMaxSize.
// - ttl: The time-to-live duration for each cache entry. If the provided ttl is less than or equal to 0,
// it defaults to defaultTTL.
//
// This function defines an onEvicted callback that is invoked when a cache entry is evicted.
// The callback ensures that if the evicted value (v) is not nil, its Close method is called asynchronously
// to release associated resources.
//
// Returns:
// - A Store instance implemented by lruStore, which internally uses an LRU cache with the specified size,
// eviction callback, and TTL.
func New(size int, ttl time.Duration) Store {
if size <= 0 {
size = defaultMaxSize
}
if ttl <= 0 {
ttl = defaultTTL
}
onEvicted := func(k string, v *Stmt) {
if v != nil {
go v.Close()
}
}
return &lruStore{lru: lru.NewLRU[string, *Stmt](size, onEvicted, ttl)}
}
type lruStore struct {
lru *lru.LRU[string, *Stmt]
}
func (s *lruStore) Keys() []string {
return s.lru.Keys()
}
func (s *lruStore) Get(key string) (*Stmt, bool) {
stmt, ok := s.lru.Get(key)
if ok && stmt != nil {
<-stmt.prepared
}
return stmt, ok
}
func (s *lruStore) Set(key string, value *Stmt) {
s.lru.Add(key, value)
}
func (s *lruStore) Delete(key string) {
s.lru.Remove(key)
}
type ConnPool interface {
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
}
// New creates a new Stmt object for executing SQL queries.
// It caches the Stmt object for future use and handles preparation and error states.
// Parameters:
//
// ctx: Context for the request, used to carry deadlines, cancellation signals, etc.
// key: The key representing the SQL query, used for caching and preparing the statement.
// isTransaction: Indicates whether this operation is part of a transaction, affecting cache strategy.
// conn: A connection pool that provides database connections.
// locker: A synchronization lock that is unlocked after initialization to avoid deadlocks.
//
// Returns:
//
// *Stmt: A newly created statement object for executing SQL operations.
// error: An error if the statement preparation fails.
func (s *lruStore) New(ctx context.Context, key string, isTransaction bool, conn ConnPool, locker sync.Locker) (_ *Stmt, err error) {
// Create a Stmt object and set its Transaction property.
// The prepared channel is used to synchronize the statement preparation state.
cacheStmt := &Stmt{
Transaction: isTransaction,
prepared: make(chan struct{}),
}
// Cache the Stmt object with the associated key.
s.Set(key, cacheStmt)
// Unlock after completing initialization to prevent deadlocks.
locker.Unlock()
// Ensure the prepared channel is closed after the function execution completes.
defer close(cacheStmt.prepared)
// Prepare the SQL statement using the provided connection.
cacheStmt.Stmt, err = conn.PrepareContext(ctx, key)
if err != nil {
// If statement preparation fails, record the error and remove the invalid Stmt object from the cache.
cacheStmt.prepareErr = err
s.Delete(key)
return &Stmt{}, err
}
// Return the successfully prepared Stmt object.
return cacheStmt, nil
}

View File

@ -474,7 +474,6 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
// found, smart migrate
fullDataType := strings.TrimSpace(strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL))
realDataType := strings.ToLower(columnType.DatabaseTypeName())
var (
alterColumn bool
isSameType = fullDataType == realDataType
@ -513,8 +512,19 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
}
}
}
}
// check precision
if realDataType == "decimal" || realDataType == "numeric" &&
regexp.MustCompile(realDataType+`\(.*\)`).FindString(fullDataType) != "" { // if realDataType has no precision,ignore
precision, scale, ok := columnType.DecimalSize()
if ok {
if !strings.HasPrefix(fullDataType, fmt.Sprintf("%s(%d,%d)", realDataType, precision, scale)) &&
!strings.HasPrefix(fullDataType, fmt.Sprintf("%s(%d)", realDataType, precision)) {
alterColumn = true
}
}
} else {
if precision, _, ok := columnType.DecimalSize(); ok && int64(field.Precision) != precision {
if regexp.MustCompile(fmt.Sprintf("[^0-9]%d[^0-9]", field.Precision)).MatchString(m.DataTypeOf(field)) {
alterColumn = true

View File

@ -7,29 +7,35 @@ import (
"errors"
"reflect"
"sync"
"time"
"gorm.io/gorm/internal/stmt_store"
)
type Stmt struct {
*sql.Stmt
Transaction bool
prepared chan struct{}
prepareErr error
}
type PreparedStmtDB struct {
Stmts map[string]*Stmt
Stmts stmt_store.Store
Mux *sync.RWMutex
ConnPool
}
func NewPreparedStmtDB(connPool ConnPool) *PreparedStmtDB {
// NewPreparedStmtDB creates and initializes a new instance of PreparedStmtDB.
//
// Parameters:
// - connPool: A connection pool that implements the ConnPool interface, used for managing database connections.
// - maxSize: The maximum number of prepared statements that can be stored in the statement store.
// - ttl: The time-to-live duration for each prepared statement in the store. Statements older than this duration will be automatically removed.
//
// Returns:
// - A pointer to a PreparedStmtDB instance, which manages prepared statements using the provided connection pool and configuration.
func NewPreparedStmtDB(connPool ConnPool, maxSize int, ttl time.Duration) *PreparedStmtDB {
return &PreparedStmtDB{
ConnPool: connPool,
Stmts: make(map[string]*Stmt),
Mux: &sync.RWMutex{},
ConnPool: connPool, // Assigns the provided connection pool to manage database connections.
Stmts: stmt_store.New(maxSize, ttl), // Initializes a new statement store with the specified maximum size and TTL.
Mux: &sync.RWMutex{}, // Sets up a read-write mutex for synchronizing access to the statement store.
}
}
// GetDBConn returns the underlying *sql.DB connection
func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) {
if sqldb, ok := db.ConnPool.(*sql.DB); ok {
return sqldb, nil
@ -42,98 +48,41 @@ func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) {
return nil, ErrInvalidDB
}
// Close closes all prepared statements in the store
func (db *PreparedStmtDB) Close() {
db.Mux.Lock()
defer db.Mux.Unlock()
for _, stmt := range db.Stmts {
go func(s *Stmt) {
// make sure the stmt must finish preparation first
<-s.prepared
if s.Stmt != nil {
_ = s.Close()
for _, key := range db.Stmts.Keys() {
db.Stmts.Delete(key)
}
}(stmt)
}
// setting db.Stmts to nil to avoid further using
db.Stmts = nil
}
func (sdb *PreparedStmtDB) Reset() {
sdb.Mux.Lock()
defer sdb.Mux.Unlock()
for _, stmt := range sdb.Stmts {
go func(s *Stmt) {
// make sure the stmt must finish preparation first
<-s.prepared
if s.Stmt != nil {
_ = s.Close()
}
}(stmt)
}
sdb.Stmts = make(map[string]*Stmt)
// Reset Deprecated use Close instead
func (db *PreparedStmtDB) Reset() {
db.Close()
}
func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) {
func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (_ *stmt_store.Stmt, err error) {
db.Mux.RLock()
if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) {
if db.Stmts != nil {
if stmt, ok := db.Stmts.Get(query); ok && (!stmt.Transaction || isTransaction) {
db.Mux.RUnlock()
// wait for other goroutines prepared
<-stmt.prepared
if stmt.prepareErr != nil {
return Stmt{}, stmt.prepareErr
return stmt, stmt.Error()
}
return *stmt, nil
}
db.Mux.RUnlock()
// retry
db.Mux.Lock()
// double check
if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) {
if db.Stmts != nil {
if stmt, ok := db.Stmts.Get(query); ok && (!stmt.Transaction || isTransaction) {
db.Mux.Unlock()
// wait for other goroutines prepared
<-stmt.prepared
if stmt.prepareErr != nil {
return Stmt{}, stmt.prepareErr
return stmt, stmt.Error()
}
}
return *stmt, nil
}
// check db.Stmts first to avoid Segmentation Fault(setting value to nil map)
// which cause by calling Close and executing SQL concurrently
if db.Stmts == nil {
db.Mux.Unlock()
return Stmt{}, ErrInvalidDB
}
// cache preparing stmt first
cacheStmt := Stmt{Transaction: isTransaction, prepared: make(chan struct{})}
db.Stmts[query] = &cacheStmt
db.Mux.Unlock()
// prepare completed
defer close(cacheStmt.prepared)
// Reason why cannot lock conn.PrepareContext
// suppose the maxopen is 1, g1 is creating record and g2 is querying record.
// 1. g1 begin tx, g1 is requeue because of waiting for the system call, now `db.ConnPool` db.numOpen == 1.
// 2. g2 select lock `conn.PrepareContext(ctx, query)`, now db.numOpen == db.maxOpen , wait for release.
// 3. g1 tx exec insert, wait for unlock `conn.PrepareContext(ctx, query)` to finish tx and release.
stmt, err := conn.PrepareContext(ctx, query)
if err != nil {
cacheStmt.prepareErr = err
db.Mux.Lock()
delete(db.Stmts, query)
db.Mux.Unlock()
return Stmt{}, err
}
db.Mux.Lock()
cacheStmt.Stmt = stmt
db.Mux.Unlock()
return cacheStmt, nil
return db.Stmts.New(ctx, query, isTransaction, conn, db.Mux)
}
func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (ConnPool, error) {
@ -162,10 +111,7 @@ func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args ..
if err == nil {
result, err = stmt.ExecContext(ctx, args...)
if errors.Is(err, driver.ErrBadConn) {
db.Mux.Lock()
defer db.Mux.Unlock()
go stmt.Close()
delete(db.Stmts, query)
db.Stmts.Delete(query)
}
}
return result, err
@ -176,11 +122,7 @@ func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args .
if err == nil {
rows, err = stmt.QueryContext(ctx, args...)
if errors.Is(err, driver.ErrBadConn) {
db.Mux.Lock()
defer db.Mux.Unlock()
go stmt.Close()
delete(db.Stmts, query)
db.Stmts.Delete(query)
}
}
return rows, err
@ -230,11 +172,7 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ..
if err == nil {
result, err = tx.Tx.StmtContext(ctx, stmt.Stmt).ExecContext(ctx, args...)
if errors.Is(err, driver.ErrBadConn) {
tx.PreparedStmtDB.Mux.Lock()
defer tx.PreparedStmtDB.Mux.Unlock()
go stmt.Close()
delete(tx.PreparedStmtDB.Stmts, query)
tx.PreparedStmtDB.Stmts.Delete(query)
}
}
return result, err
@ -245,11 +183,7 @@ func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args .
if err == nil {
rows, err = tx.Tx.StmtContext(ctx, stmt.Stmt).QueryContext(ctx, args...)
if errors.Is(err, driver.ErrBadConn) {
tx.PreparedStmtDB.Mux.Lock()
defer tx.PreparedStmtDB.Mux.Unlock()
go stmt.Close()
delete(tx.PreparedStmtDB.Stmts, query)
tx.PreparedStmtDB.Stmts.Delete(query)
}
}
return rows, err

View File

@ -4,6 +4,7 @@ import (
"database/sql"
"database/sql/driver"
"reflect"
"strings"
"time"
"gorm.io/gorm/schema"
@ -244,6 +245,14 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
matchedFieldCount[column] = 1
}
} else if names := utils.SplitNestedRelationName(column); len(names) > 1 { // has nested relation
aliasName := utils.JoinNestedRelationNames(names[0 : len(names)-1])
for _, join := range db.Statement.Joins {
if join.Alias == aliasName {
names = append(strings.Split(join.Name, "."), names[len(names)-1])
break
}
}
if rel, ok := sch.Relationships.Relations[names[0]]; ok {
subNameCount := len(names)
// nested relation fields

View File

@ -318,9 +318,10 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
}
if val, ok := field.TagSettings["TYPE"]; ok {
switch DataType(strings.ToLower(val)) {
lowerVal := DataType(strings.ToLower(val))
switch lowerVal {
case Bool, Int, Uint, Float, String, Time, Bytes:
field.DataType = DataType(strings.ToLower(val))
field.DataType = lowerVal
default:
field.DataType = DataType(val)
}
@ -447,21 +448,30 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
}
// create valuer, setter when parse struct
func (field *Field) setupValuerAndSetter() {
func (field *Field) setupValuerAndSetter(modelType reflect.Type) {
// Setup NewValuePool
field.setupNewValuePool()
// ValueOf returns field's value and if it is zero
fieldIndex := field.StructField.Index[0]
switch {
case len(field.StructField.Index) == 1 && fieldIndex > 0:
field.ValueOf = func(ctx context.Context, value reflect.Value) (interface{}, bool) {
fieldValue := reflect.Indirect(value).Field(fieldIndex)
case len(field.StructField.Index) == 1 && fieldIndex >= 0:
field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) {
v = reflect.Indirect(v)
if v.Type() != modelType {
fieldValue := v.FieldByName(field.Name)
return fieldValue.Interface(), fieldValue.IsZero()
}
fieldValue := v.Field(fieldIndex)
return fieldValue.Interface(), fieldValue.IsZero()
}
default:
field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) {
v = reflect.Indirect(v)
if v.Type() != modelType {
fieldValue := v.FieldByName(field.Name)
return fieldValue.Interface(), fieldValue.IsZero()
}
for _, fieldIdx := range field.StructField.Index {
if fieldIdx >= 0 {
v = v.Field(fieldIdx)
@ -503,13 +513,20 @@ func (field *Field) setupValuerAndSetter() {
// ReflectValueOf returns field's reflect value
switch {
case len(field.StructField.Index) == 1 && fieldIndex > 0:
field.ReflectValueOf = func(ctx context.Context, value reflect.Value) reflect.Value {
return reflect.Indirect(value).Field(fieldIndex)
case len(field.StructField.Index) == 1 && fieldIndex >= 0:
field.ReflectValueOf = func(ctx context.Context, v reflect.Value) reflect.Value {
v = reflect.Indirect(v)
if v.Type() != modelType {
return v.FieldByName(field.Name)
}
return v.Field(fieldIndex)
}
default:
field.ReflectValueOf = func(ctx context.Context, v reflect.Value) reflect.Value {
v = reflect.Indirect(v)
if v.Type() != modelType {
return v.FieldByName(field.Name)
}
for idx, fieldIdx := range field.StructField.Index {
if fieldIdx >= 0 {
v = v.Field(fieldIdx)

View File

@ -105,7 +105,7 @@ func parseFieldIndexes(field *Field) (indexes []Index, err error) {
var (
name string
tag = strings.Join(v[1:], ":")
idx = strings.Index(tag, ",")
idx = strings.IndexByte(tag, ',')
tagSetting = strings.Join(strings.Split(tag, ",")[1:], ",")
settings = ParseTagSetting(tagSetting, ",")
length, _ = strconv.Atoi(settings["LENGTH"])
@ -122,7 +122,7 @@ func parseFieldIndexes(field *Field) (indexes []Index, err error) {
if composite, found := settings[key]; found {
if len(composite) == 0 || composite == key {
err = fmt.Errorf(
"The composite tag of %s.%s cannot be empty",
"the composite tag of %s.%s cannot be empty",
field.Schema.Name,
field.Name)
return

View File

@ -78,7 +78,7 @@ func (schema *Schema) parseRelation(field *Field) *Relationship {
cacheStore := schema.cacheStore
if relation.FieldSchema, err = getOrParse(fieldValue, cacheStore, schema.namer); err != nil {
schema.err = err
schema.err = fmt.Errorf("failed to parse field: %s, error: %w", field.Name, err)
return nil
}
@ -663,6 +663,7 @@ func (rel *Relationship) ParseConstraint() *Constraint {
if !(rel.References[idx].PrimaryKey == ref.PrimaryKey && rel.References[idx].ForeignKey == ref.ForeignKey &&
rel.References[idx].PrimaryValue == ref.PrimaryValue) {
matched = false
break
}
}
@ -675,7 +676,7 @@ func (rel *Relationship) ParseConstraint() *Constraint {
var (
name string
idx = strings.Index(str, ",")
idx = strings.IndexByte(str, ',')
settings = ParseTagSetting(str, ",")
)
@ -762,8 +763,9 @@ func (rel *Relationship) ToQueryConditions(ctx context.Context, reflectValue ref
}
func copyableDataType(str DataType) bool {
lowerStr := strings.ToLower(string(str))
for _, s := range []string{"auto_increment", "primary key"} {
if strings.Contains(strings.ToLower(string(str)), s) {
if strings.Contains(lowerStr, s) {
return false
}
}

View File

@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"go/ast"
"path"
"reflect"
"strings"
"sync"
@ -247,7 +248,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
schema.FieldsByBindName[bindName] = field
}
field.setupValuerAndSetter()
field.setupValuerAndSetter(modelType)
}
prioritizedPrimaryField := schema.LookUpField("id")
@ -313,8 +314,14 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
for _, cbName := range callbackTypes {
if methodValue := callBackToMethodValue(modelValue, cbName); methodValue.IsValid() {
switch methodValue.Type().String() {
case "func(*gorm.DB) error": // TODO hack
case "func(*gorm.DB) error":
expectedPkgPath := path.Dir(reflect.TypeOf(schema).Elem().PkgPath())
if inVarPkg := methodValue.Type().In(0).Elem().PkgPath(); inVarPkg == expectedPkgPath {
reflect.Indirect(reflect.ValueOf(schema)).FieldByName(string(cbName)).SetBool(true)
} else {
logger.Default.Warn(context.Background(), "In model %v, the hook function `%v(*gorm.DB) error` has an incorrect parameter type. The expected parameter type is `%v`, but the provided type is `%v`.", schema, cbName, expectedPkgPath, inVarPkg)
// PASS
}
default:
logger.Default.Warn(context.Background(), "Model %v don't match %vInterface, should be `%v(*gorm.DB) error`. Please see https://gorm.io/docs/hooks.html", schema, cbName, cbName)
}

View File

@ -47,14 +47,17 @@ type Statement struct {
attrs []interface{}
assigns []interface{}
scopes []func(*DB) *DB
Result *result
}
type join struct {
Name string
Alias string
Conds []interface{}
On *clause.Where
Selects []string
Omits []string
Expression clause.Expression
JoinType clause.JoinType
}
@ -205,19 +208,21 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) {
} else {
writer.WriteString("(NULL)")
}
case *DB:
subdb := v.Session(&Session{Logger: logger.Discard, DryRun: true}).getInstance()
if v.Statement.SQL.Len() > 0 {
case interface{ getInstance() *DB }:
cv := v.getInstance()
subdb := cv.Session(&Session{Logger: logger.Discard, DryRun: true}).getInstance()
if cv.Statement.SQL.Len() > 0 {
var (
vars = subdb.Statement.Vars
sql = v.Statement.SQL.String()
sql = cv.Statement.SQL.String()
)
subdb.Statement.Vars = make([]interface{}, 0, len(vars))
for _, vv := range vars {
subdb.Statement.Vars = append(subdb.Statement.Vars, vv)
bindvar := strings.Builder{}
v.Dialector.BindVarTo(&bindvar, subdb.Statement, vv)
cv.BindVarTo(&bindvar, subdb.Statement, vv)
sql = strings.Replace(sql, bindvar.String(), "?", 1)
}
@ -321,6 +326,11 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
arg, _ = valuer.Value()
}
curTable := stmt.Table
if curTable == "" {
curTable = clause.CurrentTable
}
switch v := arg.(type) {
case clause.Expression:
conds = append(conds, v)
@ -331,9 +341,11 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
if where, ok := cs.Expression.(clause.Where); ok {
if len(where.Exprs) == 1 {
if orConds, ok := where.Exprs[0].(clause.OrConditions); ok {
if len(orConds.Exprs) == 1 {
where.Exprs[0] = clause.AndConditions(orConds)
}
}
}
conds = append(conds, clause.And(where.Exprs...))
} else if cs.Expression != nil {
conds = append(conds, cs.Expression)
@ -351,7 +363,11 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
sort.Strings(keys)
for _, key := range keys {
conds = append(conds, clause.Eq{Column: key, Value: v[key]})
column := clause.Column{Name: key, Table: curTable}
if strings.Contains(key, ".") {
column = clause.Column{Name: key}
}
conds = append(conds, clause.Eq{Column: column, Value: v[key]})
}
case map[string]interface{}:
keys := make([]string, 0, len(v))
@ -362,12 +378,16 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
for _, key := range keys {
reflectValue := reflect.Indirect(reflect.ValueOf(v[key]))
column := clause.Column{Name: key, Table: curTable}
if strings.Contains(key, ".") {
column = clause.Column{Name: key}
}
switch reflectValue.Kind() {
case reflect.Slice, reflect.Array:
if _, ok := v[key].(driver.Valuer); ok {
conds = append(conds, clause.Eq{Column: key, Value: v[key]})
conds = append(conds, clause.Eq{Column: column, Value: v[key]})
} else if _, ok := v[key].(Valuer); ok {
conds = append(conds, clause.Eq{Column: key, Value: v[key]})
conds = append(conds, clause.Eq{Column: column, Value: v[key]})
} else {
// optimize reflect value length
valueLen := reflectValue.Len()
@ -376,10 +396,10 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
values[i] = reflectValue.Index(i).Interface()
}
conds = append(conds, clause.IN{Column: key, Values: values})
conds = append(conds, clause.IN{Column: column, Values: values})
}
default:
conds = append(conds, clause.Eq{Column: key, Value: v[key]})
conds = append(conds, clause.Eq{Column: column, Value: v[key]})
}
}
default:
@ -406,9 +426,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
if selected || (!restricted && field.Readable) {
if v, isZero := field.ValueOf(stmt.Context, reflectValue); !isZero || selected {
if field.DBName != "" {
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v})
conds = append(conds, clause.Eq{Column: clause.Column{Table: curTable, Name: field.DBName}, Value: v})
} else if field.DataType != "" {
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v})
conds = append(conds, clause.Eq{Column: clause.Column{Table: curTable, Name: field.Name}, Value: v})
}
}
}
@ -420,9 +440,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
if selected || (!restricted && field.Readable) {
if v, isZero := field.ValueOf(stmt.Context, reflectValue.Index(i)); !isZero || selected {
if field.DBName != "" {
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v})
conds = append(conds, clause.Eq{Column: clause.Column{Table: curTable, Name: field.DBName}, Value: v})
} else if field.DataType != "" {
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v})
conds = append(conds, clause.Eq{Column: clause.Column{Table: curTable, Name: field.Name}, Value: v})
}
}
}
@ -447,14 +467,14 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
}
if len(values) > 0 {
conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: values})
conds = append(conds, clause.IN{Column: clause.Column{Table: curTable, Name: clause.PrimaryKey}, Values: values})
return []clause.Expression{clause.And(conds...)}
}
return nil
}
}
conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: args})
conds = append(conds, clause.IN{Column: clause.Column{Table: curTable, Name: clause.PrimaryKey}, Values: args})
}
}
}
@ -521,6 +541,7 @@ func (stmt *Statement) clone() *Statement {
Context: stmt.Context,
RaiseErrorOnNotFound: stmt.RaiseErrorOnNotFound,
SkipHooks: stmt.SkipHooks,
Result: stmt.Result,
}
if stmt.SQL.Len() > 0 {

View File

@ -0,0 +1,88 @@
package tests_test
import (
"fmt"
"strings"
"testing"
"gorm.io/gorm"
)
type Man struct {
ID int
Age int
Name string
Detail string
}
// Panic-safe BeforeUpdate hook that checks for Changed("age")
func (m *Man) BeforeUpdate(tx *gorm.DB) (err error) {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("panic in BeforeUpdate: %v", r)
}
}()
if !tx.Statement.Changed("age") {
return nil
}
return nil
}
func (m *Man) update(data interface{}) error {
return DB.Set("data", data).Model(m).Where("id = ?", m.ID).Updates(data).Error
}
func TestBeforeUpdateStatementChanged(t *testing.T) {
DB.AutoMigrate(&Man{})
type TestCase struct {
BaseObjects Man
change interface{}
expectError bool
}
testCases := []TestCase{
{
BaseObjects: Man{ID: 1, Age: 18, Name: "random-name"},
change: struct {
Age int
}{Age: 20},
expectError: false,
},
{
BaseObjects: Man{ID: 2, Age: 18, Name: "random-name"},
change: struct {
Name string
}{Name: "name-only"},
expectError: true,
},
{
BaseObjects: Man{ID: 2, Age: 18, Name: "random-name"},
change: struct {
Name string
Age int
}{Name: "name-only", Age: 20},
expectError: false,
},
}
for _, test := range testCases {
DB.Create(&test.BaseObjects)
// below comment is stored for future reference
// err := DB.Set("data", test.change).Model(&test.BaseObjects).Where("id = ?", test.BaseObjects.ID).Updates(test.change).Error
err := test.BaseObjects.update(test.change)
if strings.Contains(fmt.Sprint(err), "panic in BeforeUpdate") {
if !test.expectError {
t.Errorf("unexpected panic in BeforeUpdate for input: %+v\nerror: %v", test.change, err)
}
} else {
if test.expectError {
t.Errorf("expected panic did not occur for input: %+v", test.change)
}
if err != nil {
t.Errorf("unexpected GORM error: %v", err)
}
}
}
}

View File

@ -1,6 +1,6 @@
services:
mysql:
image: 'mysql/mysql-server:latest'
image: 'mysql:latest'
ports:
- "127.0.0.1:9910:3306"
environment:
@ -18,7 +18,7 @@ services:
- POSTGRES_USER=gorm
- POSTGRES_PASSWORD=gorm
mssql:
image: '${MSSQL_IMAGE}:2022-latest'
image: '${MSSQL_IMAGE}:latest'
ports:
- "127.0.0.1:9930:1433"
environment:

View File

@ -119,6 +119,7 @@ func TestConnPoolWrapper(t *testing.T) {
}()
db, err := gorm.Open(mysql.New(mysql.Config{Conn: conn, DisableWithReturning: true}))
db.Logger = DB.Logger
if err != nil {
t.Fatalf("Should open db success, but got %v", err)
}

View File

@ -14,31 +14,48 @@ import (
)
func TestCreate(t *testing.T) {
user := *GetUser("create", Config{})
u1 := *GetUser("create", Config{})
if results := DB.Create(&user); results.Error != nil {
if results := DB.Create(&u1); results.Error != nil {
t.Fatalf("errors happened when create: %v", results.Error)
} else if results.RowsAffected != 1 {
t.Fatalf("rows affected expects: %v, got %v", 1, results.RowsAffected)
}
if user.ID == 0 {
t.Errorf("user's primary key should has value after create, got : %v", user.ID)
if u1.ID == 0 {
t.Errorf("user's primary key should has value after create, got : %v", u1.ID)
}
if user.CreatedAt.IsZero() {
if u1.CreatedAt.IsZero() {
t.Errorf("user's created at should be not zero")
}
if user.UpdatedAt.IsZero() {
if u1.UpdatedAt.IsZero() {
t.Errorf("user's updated at should be not zero")
}
var newUser User
if err := DB.Where("id = ?", user.ID).First(&newUser).Error; err != nil {
if err := DB.Where("id = ?", u1.ID).First(&newUser).Error; err != nil {
t.Fatalf("errors happened when query: %v", err)
} else {
CheckUser(t, newUser, user)
CheckUser(t, newUser, u1)
}
type user struct {
ID int `gorm:"primaryKey;->:false"`
Name string
Age int
}
var u2 user
if results := DB.Create(&u2); results.Error != nil {
t.Fatalf("errors happened when create: %v", results.Error)
} else if results.RowsAffected != 1 {
t.Fatalf("rows affected expects: %v, got %v", 1, results.RowsAffected)
}
if u2.ID != 0 {
t.Errorf("don't have the permission to read primary key from db, but got %v", u2.ID)
}
}

View File

@ -206,9 +206,9 @@ func TestDeleteSliceWithAssociations(t *testing.T) {
}
}
// only sqlite, postgres, sqlserver support returning
// only sqlite, postgres, gaussdb, sqlserver support returning
func TestSoftDeleteReturning(t *testing.T) {
if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" && DB.Dialector.Name() != "sqlserver" {
if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" && DB.Dialector.Name() != "gaussdb" && DB.Dialector.Name() != "sqlserver" {
return
}
@ -233,7 +233,7 @@ func TestSoftDeleteReturning(t *testing.T) {
}
func TestDeleteReturning(t *testing.T) {
if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" && DB.Dialector.Name() != "sqlserver" {
if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" && DB.Dialector.Name() != "gaussdb" && DB.Dialector.Name() != "sqlserver" {
return
}

View File

@ -39,7 +39,7 @@ func TestSupportedDialectorWithErrDuplicatedKey(t *testing.T) {
t.Fatalf("failed to connect database, got error %v", err)
}
dialectors := map[string]bool{"sqlite": true, "postgres": true, "mysql": true, "sqlserver": true}
dialectors := map[string]bool{"sqlite": true, "postgres": true, "gaussdb": true, "mysql": true, "sqlserver": true}
if supported, found := dialectors[db.Dialector.Name()]; !(found && supported) {
return
}
@ -81,7 +81,7 @@ func TestSupportedDialectorWithErrForeignKeyViolated(t *testing.T) {
t.Fatalf("failed to connect database, got error %v", err)
}
dialectors := map[string]bool{"sqlite": true, "postgres": true, "mysql": true, "sqlserver": true}
dialectors := map[string]bool{"sqlite": true, "postgres": true, "gaussdb": true, "mysql": true, "sqlserver": true}
if supported, found := dialectors[db.Dialector.Name()]; !(found && supported) {
return
}

248
tests/gaussdb_test.go Normal file
View File

@ -0,0 +1,248 @@
package tests_test
import (
"testing"
"time"
"github.com/google/uuid"
"github.com/lib/pq"
"gorm.io/gorm"
"gorm.io/gorm/clause"
. "gorm.io/gorm/utils/tests"
)
func TestGaussDBReturningIDWhichHasStringType(t *testing.T) {
t.Skipf("This test case skipped, because of gaussdb not support pgcrypto extension and gen_random_uuid() function")
if DB.Dialector.Name() != "gaussdb" {
t.Skip()
}
type Yasuo struct {
// TODO: function gen_random_uuid() does not exist
ID string `gorm:"default:gen_random_uuid()"`
Name string
CreatedAt time.Time `gorm:"type:TIMESTAMP WITHOUT TIME ZONE"`
UpdatedAt time.Time `gorm:"type:TIMESTAMP WITHOUT TIME ZONE;default:current_timestamp"`
}
if err := DB.Exec("CREATE EXTENSION IF NOT EXISTS pgcrypto;").Error; err != nil {
t.Errorf("Failed to create extension pgcrypto, got error %v", err)
}
DB.Migrator().DropTable(&Yasuo{})
if err := DB.AutoMigrate(&Yasuo{}); err != nil {
t.Fatalf("Failed to migrate for uuid default value, got error: %v", err)
}
yasuo := Yasuo{Name: "jinzhu"}
if err := DB.Create(&yasuo).Error; err != nil {
t.Fatalf("should be able to create data, but got %v", err)
}
if yasuo.ID == "" {
t.Fatal("should be able to has ID, but got zero value")
}
var result Yasuo
if err := DB.First(&result, "id = ?", yasuo.ID).Error; err != nil || yasuo.Name != "jinzhu" {
t.Errorf("No error should happen, but got %v", err)
}
if err := DB.Where("id = $1", yasuo.ID).First(&Yasuo{}).Error; err != nil || yasuo.Name != "jinzhu" {
t.Errorf("No error should happen, but got %v", err)
}
yasuo.Name = "jinzhu1"
if err := DB.Save(&yasuo).Error; err != nil {
t.Errorf("Failed to update date, got error %v", err)
}
if err := DB.First(&result, "id = ?", yasuo.ID).Error; err != nil || yasuo.Name != "jinzhu1" {
t.Errorf("No error should happen, but got %v", err)
}
}
func TestGaussDB(t *testing.T) {
t.Skipf("This test case skipped, because of gaussdb not support pgcrypto extension and gen_random_uuid() function")
if DB.Dialector.Name() != "gaussdb" {
t.Skip()
}
type Harumph struct {
gorm.Model
Name string `gorm:"check:name_checker,name <> ''"`
// TODO: function gen_random_uuid() does not exist
Test uuid.UUID `gorm:"type:uuid;not null;default:gen_random_uuid()"`
CreatedAt time.Time `gorm:"type:TIMESTAMP WITHOUT TIME ZONE"`
UpdatedAt time.Time `gorm:"type:TIMESTAMP WITHOUT TIME ZONE;default:current_timestamp"`
Things pq.StringArray `gorm:"type:text[]"`
}
if err := DB.Exec("CREATE EXTENSION IF NOT EXISTS pgcrypto;").Error; err != nil {
t.Errorf("Failed to create extension pgcrypto, got error %v", err)
}
DB.Migrator().DropTable(&Harumph{})
if err := DB.AutoMigrate(&Harumph{}); err != nil {
t.Fatalf("Failed to migrate for uuid default value, got error: %v", err)
}
harumph := Harumph{}
if err := DB.Create(&harumph).Error; err == nil {
t.Fatalf("should failed to create data, name can't be blank")
}
harumph = Harumph{Name: "jinzhu"}
if err := DB.Create(&harumph).Error; err != nil {
t.Fatalf("should be able to create data, but got %v", err)
}
var result Harumph
if err := DB.First(&result, "id = ?", harumph.ID).Error; err != nil || harumph.Name != "jinzhu" {
t.Errorf("No error should happen, but got %v", err)
}
if err := DB.Where("id = $1", harumph.ID).First(&Harumph{}).Error; err != nil || harumph.Name != "jinzhu" {
t.Errorf("No error should happen, but got %v", err)
}
harumph.Name = "jinzhu1"
if err := DB.Save(&harumph).Error; err != nil {
t.Errorf("Failed to update date, got error %v", err)
}
if err := DB.First(&result, "id = ?", harumph.ID).Error; err != nil || harumph.Name != "jinzhu1" {
t.Errorf("No error should happen, but got %v", err)
}
DB.Migrator().DropTable("log_usage")
if err := DB.Exec(`
CREATE TABLE public.log_usage (
log_id bigint NOT NULL
);
ALTER TABLE public.log_usage ALTER COLUMN log_id ADD GENERATED BY DEFAULT AS IDENTITY (
SEQUENCE NAME public.log_usage_log_id_seq
START WITH 1
INCREMENT BY 1
NO MINVALUE
NO MAXVALUE
CACHE 1
);
`).Error; err != nil {
t.Fatalf("failed to create table, got error %v", err)
}
columns, err := DB.Migrator().ColumnTypes("log_usage")
if err != nil {
t.Fatalf("failed to get columns, got error %v", err)
}
hasLogID := false
for _, column := range columns {
if column.Name() == "log_id" {
hasLogID = true
autoIncrement, ok := column.AutoIncrement()
if !ok || !autoIncrement {
t.Fatalf("column log_id should be auto incrementment")
}
}
}
if !hasLogID {
t.Fatalf("failed to found column log_id")
}
}
func TestGaussDBMany2ManyWithDefaultValueUUID(t *testing.T) {
t.Skipf("This test case skipped, because of gaussdb does not have 'uuid-ossp' extension")
if DB.Dialector.Name() != "gaussdb" {
t.Skip()
}
if err := DB.Exec(`create extension if not exists "uuid-ossp"`).Error; err != nil {
t.Fatalf("Failed to create 'uuid-ossp' extension, but got error %v", err)
}
DB.Migrator().DropTable(&Post{}, &Category{}, "post_categories")
DB.AutoMigrate(&Post{}, &Category{})
post := Post{
Title: "Hello World",
Categories: []*Category{
{Title: "Coding"},
{Title: "Golang"},
},
}
if err := DB.Create(&post).Error; err != nil {
t.Errorf("Failed, got error: %v", err)
}
}
func TestGaussDBOnConstraint(t *testing.T) {
t.Skipf("This test case skipped, because of gaussdb not support 'ON CONSTRAINT' statement")
if DB.Dialector.Name() != "gaussdb" {
t.Skip()
}
type Thing struct {
gorm.Model
SomeID string
OtherID string
Data string
}
DB.Migrator().DropTable(&Thing{})
DB.Migrator().CreateTable(&Thing{})
if err := DB.Exec("ALTER TABLE things ADD CONSTRAINT some_id_other_id_unique UNIQUE (some_id, other_id)").Error; err != nil {
t.Error(err)
}
thing := Thing{
SomeID: "1234",
OtherID: "1234",
Data: "something",
}
DB.Create(&thing)
thing2 := Thing{
SomeID: "1234",
OtherID: "1234",
Data: "something else",
}
result := DB.Clauses(clause.OnConflict{
OnConstraint: "some_id_other_id_unique",
UpdateAll: true,
}).Create(&thing2)
if result.Error != nil {
t.Errorf("creating second thing: %v", result.Error)
}
var things []Thing
if err := DB.Find(&things).Error; err != nil {
t.Errorf("Failed, got error: %v", err)
}
if len(things) > 1 {
t.Errorf("expected 1 thing got more")
}
}
func TestGaussDBAlterColumnDataType(t *testing.T) {
if DB.Dialector.Name() != "gaussdb" {
t.Skip()
}
DB.Migrator().DropTable(&Company{})
DB.AutoMigrate(Company{})
if err := DB.Table("companies").Migrator().AlterColumn(CompanyNew{}, "name"); err != nil {
t.Fatalf("failed to alter column from string to int, got error %v", err)
}
DB.AutoMigrate(Company{})
}

875
tests/generics_test.go Normal file
View File

@ -0,0 +1,875 @@
package tests_test
import (
"context"
"errors"
"fmt"
"reflect"
"regexp"
"sort"
"strconv"
"strings"
"sync"
"testing"
"github.com/google/uuid"
"gorm.io/driver/mysql"
"gorm.io/gorm"
"gorm.io/gorm/clause"
. "gorm.io/gorm/utils/tests"
)
func TestGenericsCreate(t *testing.T) {
ctx := context.Background()
user := User{Name: "TestGenericsCreate", Age: 18}
err := gorm.G[User](DB).Create(ctx, &user)
if err != nil {
t.Fatalf("Create failed: %v", err)
}
if user.ID == 0 {
t.Fatalf("no primary key found for %v", user)
}
if u, err := gorm.G[User](DB).Where("name = ?", user.Name).First(ctx); err != nil {
t.Fatalf("failed to find user, got error: %v", err)
} else if u.Name != user.Name || u.ID != user.ID {
t.Errorf("found invalid user, got %v, expect %v", u, user)
}
if u, err := gorm.G[User](DB).Where("name = ?", user.Name).Take(ctx); err != nil {
t.Fatalf("failed to find user, got error: %v", err)
} else if u.Name != user.Name || u.ID != user.ID {
t.Errorf("found invalid user, got %v, expect %v", u, user)
}
if u, err := gorm.G[User](DB).Select("name").Where("name = ?", user.Name).First(ctx); err != nil {
t.Fatalf("failed to find user, got error: %v", err)
} else if u.Name != user.Name || u.Age != 0 {
t.Errorf("found invalid user, got %v, expect %v", u, user)
}
if u, err := gorm.G[User](DB).Omit("name").Where("name = ?", user.Name).First(ctx); err != nil {
t.Fatalf("failed to find user, got error: %v", err)
} else if u.Name != "" || u.Age != user.Age {
t.Errorf("found invalid user, got %v, expect %v", u, user)
}
result := struct {
ID int
Name string
}{}
if err := gorm.G[User](DB).Where("name = ?", user.Name).Scan(ctx, &result); err != nil {
t.Fatalf("failed to scan user, got error: %v", err)
} else if result.Name != user.Name || uint(result.ID) != user.ID {
t.Errorf("found invalid user, got %v, expect %v", result, user)
}
mapResult, err := gorm.G[map[string]interface{}](DB).Table("users").Where("name = ?", user.Name).MapColumns(map[string]string{"name": "user_name"}).Take(ctx)
if v := mapResult["user_name"]; fmt.Sprint(v) != user.Name {
t.Errorf("failed to find map results, got %v, err %v", mapResult, err)
}
}
func TestGenericsCreateInBatches(t *testing.T) {
batch := []User{
{Name: "GenericsCreateInBatches1"},
{Name: "GenericsCreateInBatches2"},
{Name: "GenericsCreateInBatches3"},
}
ctx := context.Background()
if err := gorm.G[User](DB).CreateInBatches(ctx, &batch, 2); err != nil {
t.Fatalf("CreateInBatches failed: %v", err)
}
for _, u := range batch {
if u.ID == 0 {
t.Fatalf("no primary key found for %v", u)
}
}
count, err := gorm.G[User](DB).Where("name like ?", "GenericsCreateInBatches%").Count(ctx, "*")
if err != nil {
t.Fatalf("Count failed: %v", err)
}
if count != 3 {
t.Errorf("expected 3 records, got %d", count)
}
found, err := gorm.G[User](DB).Raw("SELECT * FROM users WHERE name LIKE ?", "GenericsCreateInBatches%").Find(ctx)
if len(found) != len(batch) {
t.Errorf("expected %d from Raw Find, got %d", len(batch), len(found))
}
found, err = gorm.G[User](DB).Where("name like ?", "GenericsCreateInBatches%").Limit(2).Find(ctx)
if len(found) != 2 {
t.Errorf("expected %d from Raw Find, got %d", 2, len(found))
}
found, err = gorm.G[User](DB).Where("name like ?", "GenericsCreateInBatches%").Offset(2).Limit(2).Find(ctx)
if len(found) != 1 {
t.Errorf("expected %d from Raw Find, got %d", 1, len(found))
}
}
func TestGenericsExecAndUpdate(t *testing.T) {
ctx := context.Background()
name := "GenericsExec"
if err := gorm.G[User](DB).Exec(ctx, "INSERT INTO users(name) VALUES(?)", name); err != nil {
t.Fatalf("Exec insert failed: %v", err)
}
u, err := gorm.G[User](DB).Table("users as u").Where("u.name = ?", name).First(ctx)
if err != nil {
t.Fatalf("failed to find user, got error: %v", err)
} else if u.Name != name || u.ID == 0 {
t.Errorf("found invalid user, got %v", u)
}
name += "Update"
rows, err := gorm.G[User](DB).Where("id = ?", u.ID).Update(ctx, "name", name)
if rows != 1 {
t.Fatalf("failed to get affected rows, got %d, should be %d", rows, 1)
}
nu, err := gorm.G[User](DB).Where("name = ?", name).First(ctx)
if err != nil {
t.Fatalf("failed to find user, got error: %v", err)
} else if nu.Name != name || u.ID != nu.ID {
t.Fatalf("found invalid user, got %v, expect %v", nu.ID, u.ID)
}
rows, err = gorm.G[User](DB).Where("id = ?", u.ID).Updates(ctx, User{Name: "GenericsExecUpdates", Age: 18})
if rows != 1 {
t.Fatalf("failed to get affected rows, got %d, should be %d", rows, 1)
}
nu, err = gorm.G[User](DB).Where("id = ?", u.ID).Last(ctx)
if err != nil {
t.Fatalf("failed to find user, got error: %v", err)
} else if nu.Name != "GenericsExecUpdates" || nu.Age != 18 || u.ID != nu.ID {
t.Fatalf("found invalid user, got %v, expect %v", nu.ID, u.ID)
}
}
func TestGenericsRow(t *testing.T) {
ctx := context.Background()
user := User{Name: "GenericsRow"}
if err := gorm.G[User](DB).Create(ctx, &user); err != nil {
t.Fatalf("Create failed: %v", err)
}
row := gorm.G[User](DB).Raw("SELECT name FROM users WHERE id = ?", user.ID).Row(ctx)
var name string
if err := row.Scan(&name); err != nil {
t.Fatalf("Row scan failed: %v", err)
}
if name != user.Name {
t.Errorf("expected %s, got %s", user.Name, name)
}
user2 := User{Name: "GenericsRow2"}
if err := gorm.G[User](DB).Create(ctx, &user2); err != nil {
t.Fatalf("Create failed: %v", err)
}
rows, err := gorm.G[User](DB).Raw("SELECT name FROM users WHERE id IN ?", []uint{user.ID, user2.ID}).Rows(ctx)
if err != nil {
t.Fatalf("Rows failed: %v", err)
}
count := 0
for rows.Next() {
var name string
if err := rows.Scan(&name); err != nil {
t.Fatalf("rows.Scan failed: %v", err)
}
count++
}
if count != 2 {
t.Errorf("expected 2 rows, got %d", count)
}
}
func TestGenericsDelete(t *testing.T) {
ctx := context.Background()
u := User{Name: "GenericsDelete"}
if err := gorm.G[User](DB).Create(ctx, &u); err != nil {
t.Fatalf("Create failed: %v", err)
}
rows, err := gorm.G[User](DB).Where("id = ?", u.ID).Delete(ctx)
if err != nil {
t.Fatalf("Delete failed: %v", err)
}
if rows != 1 {
t.Errorf("expected 1 row deleted, got %d", rows)
}
_, err = gorm.G[User](DB).Where("id = ?", u.ID).First(ctx)
if err != gorm.ErrRecordNotFound {
t.Fatalf("User after delete failed: %v", err)
}
}
func TestGenericsFindInBatches(t *testing.T) {
ctx := context.Background()
users := []User{
{Name: "GenericsFindBatchA"},
{Name: "GenericsFindBatchB"},
{Name: "GenericsFindBatchC"},
{Name: "GenericsFindBatchD"},
{Name: "GenericsFindBatchE"},
}
if err := gorm.G[User](DB).CreateInBatches(ctx, &users, len(users)); err != nil {
t.Fatalf("CreateInBatches failed: %v", err)
}
total := 0
err := gorm.G[User](DB).Where("name like ?", "GenericsFindBatch%").FindInBatches(ctx, 2, func(chunk []User, batch int) error {
if len(chunk) > 2 {
t.Errorf("batch size exceed 2: got %d", len(chunk))
}
total += len(chunk)
return nil
})
if err != nil {
t.Fatalf("FindInBatches failed: %v", err)
}
if total != len(users) {
t.Errorf("expected total %d, got %d", len(users), total)
}
}
func TestGenericsScopes(t *testing.T) {
ctx := context.Background()
users := []User{{Name: "GenericsScopes1"}, {Name: "GenericsScopes2"}, {Name: "GenericsScopes3"}}
err := gorm.G[User](DB).CreateInBatches(ctx, &users, len(users))
if err != nil {
t.Fatalf("CreateInBatches failed: %v", err)
}
filterName1 := func(stmt *gorm.Statement) {
stmt.Where("name = ?", "GenericsScopes1")
}
results, err := gorm.G[User](DB).Scopes(filterName1).Find(ctx)
if err != nil {
t.Fatalf("Scopes failed: %v", err)
}
if len(results) != 1 || results[0].Name != "GenericsScopes1" {
t.Fatalf("Scopes expected 1, got %d", len(results))
}
notResult, err := gorm.G[User](DB).Where("name like ?", "GenericsScopes%").Not("name = ?", "GenericsScopes1").Order("name").Find(ctx)
if len(notResult) != 2 {
t.Fatalf("expected 2 results, got %d", len(notResult))
} else if notResult[0].Name != "GenericsScopes2" || notResult[1].Name != "GenericsScopes3" {
t.Fatalf("expected names 'GenericsScopes2' and 'GenericsScopes3', got %s and %s", notResult[0].Name, notResult[1].Name)
}
orResult, err := gorm.G[User](DB).Or("name = ?", "GenericsScopes1").Or("name = ?", "GenericsScopes2").Order("name").Find(ctx)
if len(orResult) != 2 {
t.Fatalf("expected 2 results, got %d", len(notResult))
} else if orResult[0].Name != "GenericsScopes1" || orResult[1].Name != "GenericsScopes2" {
t.Fatalf("expected names 'GenericsScopes2' and 'GenericsScopes3', got %s and %s", orResult[0].Name, orResult[1].Name)
}
}
func TestGenericsJoins(t *testing.T) {
ctx := context.Background()
db := gorm.G[User](DB)
u := User{Name: "GenericsJoins", Company: Company{Name: "GenericsCompany"}}
u2 := User{Name: "GenericsJoins_2", Company: Company{Name: "GenericsCompany_2"}}
u3 := User{Name: "GenericsJoins_3", Company: Company{Name: "GenericsCompany_3"}}
db.CreateInBatches(ctx, &[]User{u3, u, u2}, 10)
// Inner JOIN + WHERE
result, err := db.Joins(clause.Has("Company"), func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error {
db.Where("?.name = ?", joinTable, u.Company.Name)
return nil
}).First(ctx)
if err != nil {
t.Fatalf("Joins failed: %v", err)
}
if result.Name != u.Name || result.Company.Name != u.Company.Name {
t.Fatalf("Joins expected %s, got %+v", u.Name, result)
}
// Inner JOIN + WHERE with map
result, err = db.Joins(clause.Has("Company"), func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error {
db.Where(map[string]any{"name": u.Company.Name})
return nil
}).First(ctx)
if err != nil {
t.Fatalf("Joins failed: %v", err)
}
if result.Name != u.Name || result.Company.Name != u.Company.Name {
t.Fatalf("Joins expected %s, got %+v", u.Name, result)
}
// Left JOIN w/o WHERE
result, err = db.Joins(clause.LeftJoin.Association("Company"), nil).Where(map[string]any{"name": u.Name}).First(ctx)
if err != nil {
t.Fatalf("Joins failed: %v", err)
}
if result.Name != u.Name || result.Company.Name != u.Company.Name {
t.Fatalf("Joins expected %s, got %+v", u.Name, result)
}
// Left JOIN + Alias WHERE
result, err = db.Joins(clause.LeftJoin.Association("Company").As("t"), func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error {
if joinTable.Name != "t" {
t.Fatalf("Join table should be t, but got %v", joinTable.Name)
}
db.Where("?.name = ?", joinTable, u.Company.Name)
return nil
}).Where(map[string]any{"name": u.Name}).First(ctx)
if err != nil {
t.Fatalf("Joins failed: %v", err)
}
if result.Name != u.Name || result.Company.Name != u.Company.Name {
t.Fatalf("Joins expected %s, got %+v", u.Name, result)
}
// Raw Subquery JOIN + WHERE
result, err = db.Joins(clause.LeftJoin.AssociationFrom("Company", gorm.G[Company](DB)).As("t"),
func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error {
if joinTable.Name != "t" {
t.Fatalf("Join table should be t, but got %v", joinTable.Name)
}
db.Where("?.name = ?", joinTable, u.Company.Name)
return nil
},
).Where(map[string]any{"name": u2.Name}).First(ctx)
if err != nil {
t.Fatalf("Raw subquery join failed: %v", err)
}
if result.Name != u2.Name || result.Company.Name != u.Company.Name || result.Company.ID == 0 {
t.Fatalf("Joins expected %s, got %+v", u.Name, result)
}
// Raw Subquery JOIN + WHERE + Select
result, err = db.Joins(clause.LeftJoin.AssociationFrom("Company", gorm.G[Company](DB).Select("Name")).As("t"),
func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error {
if joinTable.Name != "t" {
t.Fatalf("Join table should be t, but got %v", joinTable.Name)
}
db.Where("?.name = ?", joinTable, u.Company.Name)
return nil
},
).Where(map[string]any{"name": u2.Name}).First(ctx)
if err != nil {
t.Fatalf("Raw subquery join failed: %v", err)
}
if result.Name != u2.Name || result.Company.Name != u.Company.Name || result.Company.ID != 0 {
t.Fatalf("Joins expected %s, got %+v", u.Name, result)
}
_, err = db.Joins(clause.Has("Company"), func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error {
return errors.New("join error")
}).First(ctx)
if err == nil {
t.Fatalf("Joins should got error, but got nil")
}
}
func TestGenericsNestedJoins(t *testing.T) {
users := []User{
{
Name: "generics-nested-joins-1",
Manager: &User{
Name: "generics-nested-joins-manager-1",
Company: Company{
Name: "generics-nested-joins-manager-company-1",
},
NamedPet: &Pet{
Name: "generics-nested-joins-manager-namepet-1",
Toy: Toy{
Name: "generics-nested-joins-manager-namepet-toy-1",
},
},
},
NamedPet: &Pet{Name: "generics-nested-joins-namepet-1", Toy: Toy{Name: "generics-nested-joins-namepet-toy-1"}},
},
{
Name: "generics-nested-joins-2",
Manager: GetUser("generics-nested-joins-manager-2", Config{Company: true, NamedPet: true}),
NamedPet: &Pet{Name: "generics-nested-joins-namepet-2", Toy: Toy{Name: "generics-nested-joins-namepet-toy-2"}},
},
}
ctx := context.Background()
db := gorm.G[User](DB)
db.CreateInBatches(ctx, &users, 100)
var userIDs []uint
for _, user := range users {
userIDs = append(userIDs, user.ID)
}
users2, err := db.Joins(clause.LeftJoin.Association("Manager"), nil).
Joins(clause.LeftJoin.Association("Manager.Company"), nil).
Joins(clause.LeftJoin.Association("Manager.NamedPet.Toy"), nil).
Joins(clause.LeftJoin.Association("NamedPet.Toy"), nil).
Joins(clause.LeftJoin.Association("NamedPet").As("t"), nil).
Where(map[string]any{"id": userIDs}).Find(ctx)
if err != nil {
t.Fatalf("Failed to load with joins, got error: %v", err)
} else if len(users2) != len(users) {
t.Fatalf("Failed to load join users, got: %v, expect: %v", len(users2), len(users))
}
sort.Slice(users2, func(i, j int) bool {
return users2[i].ID > users2[j].ID
})
sort.Slice(users, func(i, j int) bool {
return users[i].ID > users[j].ID
})
for idx, user := range users {
// user
CheckUser(t, user, users2[idx])
if users2[idx].Manager == nil {
t.Fatalf("Failed to load Manager")
}
// manager
CheckUser(t, *user.Manager, *users2[idx].Manager)
// user pet
if users2[idx].NamedPet == nil {
t.Fatalf("Failed to load NamedPet")
}
CheckPet(t, *user.NamedPet, *users2[idx].NamedPet)
// manager pet
if users2[idx].Manager.NamedPet == nil {
t.Fatalf("Failed to load NamedPet")
}
CheckPet(t, *user.Manager.NamedPet, *users2[idx].Manager.NamedPet)
}
}
func TestGenericsPreloads(t *testing.T) {
ctx := context.Background()
db := gorm.G[User](DB)
u := *GetUser("GenericsPreloads_1", Config{Company: true, Pets: 3, Friends: 7})
u2 := *GetUser("GenericsPreloads_2", Config{Company: true, Pets: 5, Friends: 5})
u3 := *GetUser("GenericsPreloads_3", Config{Company: true, Pets: 7, Friends: 3})
names := []string{u.Name, u2.Name, u3.Name}
db.CreateInBatches(ctx, &[]User{u3, u, u2}, 10)
result, err := db.Preload("Company", nil).Preload("Pets", nil).Where("name = ?", u.Name).First(ctx)
if err != nil {
t.Fatalf("Preload failed: %v", err)
}
if result.Name != u.Name || result.Company.Name != u.Company.Name || len(result.Pets) != len(u.Pets) {
t.Fatalf("Preload expected %s, got %+v", u.Name, result)
}
results, err := db.Preload("Company", func(db gorm.PreloadBuilder) error {
db.Where("name = ?", u.Company.Name)
return nil
}).Where("name in ?", names).Find(ctx)
if err != nil {
t.Fatalf("Preload failed: %v", err)
}
for _, result := range results {
if result.Name == u.Name {
if result.Company.Name != u.Company.Name {
t.Fatalf("Preload user %v company should be %v, but got %+v", u.Name, u.Company.Name, result.Company.Name)
}
} else if result.Company.Name != "" {
t.Fatalf("Preload other company should not loaded, user %v company expect %v but got %+v", u.Name, u.Company.Name, result.Company.Name)
}
}
_, err = db.Preload("Company", func(db gorm.PreloadBuilder) error {
return errors.New("preload error")
}).Where("name in ?", names).Find(ctx)
if err == nil {
t.Fatalf("Preload should failed, but got nil")
}
if DB.Dialector.Name() == "mysql" {
// mysql 5.7 doesn't support row_number()
if strings.HasPrefix(DB.Dialector.(*mysql.Dialector).ServerVersion, "5.7") {
return
}
}
results, err = db.Preload("Pets", func(db gorm.PreloadBuilder) error {
db.LimitPerRecord(5)
return nil
}).Where("name in ?", names).Find(ctx)
for _, result := range results {
if result.Name == u.Name {
if len(result.Pets) != len(u.Pets) {
t.Fatalf("Preload user %v pets should be %v, but got %+v", u.Name, u.Pets, result.Pets)
}
} else if len(result.Pets) != 5 {
t.Fatalf("Preload user %v pets should be 5, but got %+v", result.Name, result.Pets)
}
}
if DB.Dialector.Name() == "sqlserver" {
// sqlserver doesn't support order by in subquery
return
}
results, err = db.Preload("Pets", func(db gorm.PreloadBuilder) error {
db.Order("name desc").LimitPerRecord(5)
return nil
}).Where("name in ?", names).Find(ctx)
for _, result := range results {
if result.Name == u.Name {
if len(result.Pets) != len(u.Pets) {
t.Fatalf("Preload user %v pets should be %v, but got %+v", u.Name, u.Pets, result.Pets)
}
} else if len(result.Pets) != 5 {
t.Fatalf("Preload user %v pets should be 5, but got %+v", result.Name, result.Pets)
}
for i := 1; i < len(result.Pets); i++ {
if result.Pets[i-1].Name < result.Pets[i].Name {
t.Fatalf("Preload user %v pets not ordered correctly, last %v, cur %v", result.Name, result.Pets[i-1], result.Pets[i])
}
}
}
results, err = db.Preload("Pets", func(db gorm.PreloadBuilder) error {
db.Order("name").LimitPerRecord(5)
return nil
}).Preload("Friends", func(db gorm.PreloadBuilder) error {
db.Order("name")
return nil
}).Where("name in ?", names).Find(ctx)
for _, result := range results {
if result.Name == u.Name {
if len(result.Pets) != len(u.Pets) {
t.Fatalf("Preload user %v pets should be %v, but got %+v", u.Name, u.Pets, result.Pets)
}
if len(result.Friends) != len(u.Friends) {
t.Fatalf("Preload user %v pets should be %v, but got %+v", u.Name, u.Pets, result.Pets)
}
} else if len(result.Pets) != 5 || len(result.Friends) == 0 {
t.Fatalf("Preload user %v pets should be 5, but got %+v", result.Name, result.Pets)
}
for i := 1; i < len(result.Pets); i++ {
if result.Pets[i-1].Name > result.Pets[i].Name {
t.Fatalf("Preload user %v pets not ordered correctly, last %v, cur %v", result.Name, result.Pets[i-1], result.Pets[i])
}
}
for i := 1; i < len(result.Pets); i++ {
if result.Pets[i-1].Name > result.Pets[i].Name {
t.Fatalf("Preload user %v friends not ordered correctly, last %v, cur %v", result.Name, result.Pets[i-1], result.Pets[i])
}
}
}
}
func TestGenericsNestedPreloads(t *testing.T) {
user := *GetUser("generics_nested_preload", Config{Pets: 2})
user.Friends = []*User{GetUser("generics_nested_preload", Config{Pets: 5})}
ctx := context.Background()
db := gorm.G[User](DB)
for idx, pet := range user.Pets {
pet.Toy = Toy{Name: "toy_nested_preload_" + strconv.Itoa(idx+1)}
}
if err := db.Create(ctx, &user); err != nil {
t.Fatalf("errors happened when create: %v", err)
}
user2, err := db.Preload("Pets.Toy", nil).Preload("Friends.Pets", func(db gorm.PreloadBuilder) error {
return nil
}).Where(user.ID).Take(ctx)
if err != nil {
t.Errorf("failed to nested preload user")
}
CheckUser(t, user2, user)
if len(user.Pets) == 0 || len(user.Friends) == 0 || len(user.Friends[0].Pets) == 0 {
t.Fatalf("failed to nested preload")
}
if DB.Dialector.Name() == "mysql" {
// mysql 5.7 doesn't support row_number()
if strings.HasPrefix(DB.Dialector.(*mysql.Dialector).ServerVersion, "5.7") {
return
}
}
if DB.Dialector.Name() == "sqlserver" {
// sqlserver doesn't support order by in subquery
return
}
user3, err := db.Preload("Pets.Toy", nil).Preload("Friends.Pets", func(db gorm.PreloadBuilder) error {
db.LimitPerRecord(3)
return nil
}).Where(user.ID).Take(ctx)
if err != nil {
t.Errorf("failed to nested preload user")
}
CheckUser(t, user3, user)
if len(user3.Friends) != 1 || len(user3.Friends[0].Pets) != 3 {
t.Errorf("failed to nested preload with limit per record")
}
}
func TestGenericsDistinct(t *testing.T) {
ctx := context.Background()
batch := []User{
{Name: "GenericsDistinctDup"},
{Name: "GenericsDistinctDup"},
{Name: "GenericsDistinctUnique"},
}
if err := gorm.G[User](DB).CreateInBatches(ctx, &batch, len(batch)); err != nil {
t.Fatalf("CreateInBatches failed: %v", err)
}
results, err := gorm.G[User](DB).Where("name like ?", "GenericsDistinct%").Distinct("name").Find(ctx)
if err != nil {
t.Fatalf("Distinct Find failed: %v", err)
}
if len(results) != 2 {
t.Errorf("expected 2 distinct names, got %d", len(results))
}
var names []string
for _, u := range results {
names = append(names, u.Name)
}
sort.Strings(names)
expected := []string{"GenericsDistinctDup", "GenericsDistinctUnique"}
if !reflect.DeepEqual(names, expected) {
t.Errorf("expected names %v, got %v", expected, names)
}
}
func TestGenericsGroupHaving(t *testing.T) {
ctx := context.Background()
batch := []User{
{Name: "GenericsGroupHavingMulti"},
{Name: "GenericsGroupHavingMulti"},
{Name: "GenericsGroupHavingSingle"},
}
if err := gorm.G[User](DB).CreateInBatches(ctx, &batch, len(batch)); err != nil {
t.Fatalf("CreateInBatches failed: %v", err)
}
grouped, err := gorm.G[User](DB).Select("name").Where("name like ?", "GenericsGroupHaving%").Group("name").Having("COUNT(id) > ?", 1).Find(ctx)
if err != nil {
t.Fatalf("Group+Having Find failed: %v", err)
}
if len(grouped) != 1 {
t.Errorf("expected 1 group with count>1, got %d", len(grouped))
} else if grouped[0].Name != "GenericsGroupHavingMulti" {
t.Errorf("expected group name 'GenericsGroupHavingMulti', got '%s'", grouped[0].Name)
}
}
func TestGenericsSubQuery(t *testing.T) {
ctx := context.Background()
users := []User{
{Name: "GenericsSubquery_1", Age: 10},
{Name: "GenericsSubquery_2", Age: 20},
{Name: "GenericsSubquery_3", Age: 30},
{Name: "GenericsSubquery_4", Age: 40},
}
if err := gorm.G[User](DB).CreateInBatches(ctx, &users, len(users)); err != nil {
t.Fatalf("CreateInBatches failed: %v", err)
}
results, err := gorm.G[User](DB).Where("name IN (?)", gorm.G[User](DB).Select("name").Where("name LIKE ?", "GenericsSubquery%")).Find(ctx)
if err != nil {
t.Fatalf("got error: %v", err)
}
if len(results) != 4 {
t.Errorf("Four users should be found, instead found %d", len(results))
}
results, err = gorm.G[User](DB).Where("name IN (?)", gorm.G[User](DB).Select("name").Where("name IN ?", []string{"GenericsSubquery_1", "GenericsSubquery_2"}).Or("name = ?", "GenericsSubquery_3")).Find(ctx)
if err != nil {
t.Fatalf("got error: %v", err)
}
if len(results) != 3 {
t.Errorf("Three users should be found, instead found %d", len(results))
}
}
func TestGenericsUpsert(t *testing.T) {
ctx := context.Background()
lang := Language{Code: "upsert", Name: "Upsert"}
if err := gorm.G[Language](DB, clause.OnConflict{DoNothing: true}).Create(ctx, &lang); err != nil {
t.Fatalf("failed to upsert, got %v", err)
}
lang2 := Language{Code: "upsert", Name: "Upsert"}
if err := gorm.G[Language](DB, clause.OnConflict{DoNothing: true}).Create(ctx, &lang2); err != nil {
t.Fatalf("failed to upsert, got %v", err)
}
langs, err := gorm.G[Language](DB).Where("code = ?", lang.Code).Find(ctx)
if err != nil {
t.Errorf("no error should happen when find languages with code, but got %v", err)
} else if len(langs) != 1 {
t.Errorf("should only find only 1 languages, but got %+v", langs)
}
lang3 := Language{Code: "upsert", Name: "Upsert"}
if err := gorm.G[Language](DB, clause.OnConflict{
Columns: []clause.Column{{Name: "code"}},
DoUpdates: clause.Assignments(map[string]interface{}{"name": "upsert-new"}),
}).Create(ctx, &lang3); err != nil {
t.Fatalf("failed to upsert, got %v", err)
}
if langs, err := gorm.G[Language](DB).Where("code = ?", lang.Code).Find(ctx); err != nil {
t.Errorf("no error should happen when find languages with code, but got %v", err)
} else if len(langs) != 1 {
t.Errorf("should only find only 1 languages, but got %+v", langs)
} else if langs[0].Name != "upsert-new" {
t.Errorf("should update name on conflict, but got name %+v", langs[0].Name)
}
}
func TestGenericsWithResult(t *testing.T) {
ctx := context.Background()
users := []User{{Name: "TestGenericsWithResult", Age: 18}, {Name: "TestGenericsWithResult2", Age: 18}}
result := gorm.WithResult()
err := gorm.G[User](DB, result).CreateInBatches(ctx, &users, 2)
if err != nil {
t.Errorf("failed to create users WithResult")
}
if result.RowsAffected != 2 {
t.Errorf("failed to get affected rows, got %d, should be %d", result.RowsAffected, 2)
}
}
func TestGenericsReuse(t *testing.T) {
ctx := context.Background()
users := []User{{Name: "TestGenericsReuse1", Age: 18}, {Name: "TestGenericsReuse2", Age: 18}}
err := gorm.G[User](DB).CreateInBatches(ctx, &users, 2)
if err != nil {
t.Errorf("failed to create users")
}
reusedb := gorm.G[User](DB).Where("name like ?", "TestGenericsReuse%")
sg := sync.WaitGroup{}
for i := 0; i < 5; i++ {
sg.Add(1)
go func() {
if u1, err := reusedb.Where("id = ?", users[0].ID).First(ctx); err != nil {
t.Errorf("failed to find user, got error: %v", err)
} else if u1.Name != users[0].Name || u1.ID != users[0].ID {
t.Errorf("found invalid user, got %v, expect %v", u1, users[0])
}
if u2, err := reusedb.Where("id = ?", users[1].ID).First(ctx); err != nil {
t.Errorf("failed to find user, got error: %v", err)
} else if u2.Name != users[1].Name || u2.ID != users[1].ID {
t.Errorf("found invalid user, got %v, expect %v", u2, users[1])
}
if users, err := reusedb.Where("id IN ?", []uint{users[0].ID, users[1].ID}).Find(ctx); err != nil {
t.Errorf("failed to find user, got error: %v", err)
} else if len(users) != 2 {
t.Errorf("should find 2 users, but got %d", len(users))
}
sg.Done()
}()
}
sg.Wait()
}
func TestGenericsWithTransaction(t *testing.T) {
ctx := context.Background()
tx := DB.Begin()
if tx.Error != nil {
t.Fatalf("failed to begin transaction: %v", tx.Error)
}
users := []User{{Name: "TestGenericsTransaction", Age: 18}, {Name: "TestGenericsTransaction2", Age: 18}}
err := gorm.G[User](tx).CreateInBatches(ctx, &users, 2)
count, err := gorm.G[User](tx).Where("name like ?", "TestGenericsTransaction%").Count(ctx, "*")
if err != nil {
t.Fatalf("Count failed: %v", err)
}
if count != 2 {
t.Errorf("expected 2 records, got %d", count)
}
if err := tx.Rollback().Error; err != nil {
t.Fatalf("failed to rollback transaction: %v", err)
}
count2, err := gorm.G[User](DB).Where("name like ?", "TestGenericsTransaction%").Count(ctx, "*")
if err != nil {
t.Fatalf("Count failed: %v", err)
}
if count2 != 0 {
t.Errorf("expected 0 records after rollback, got %d", count2)
}
}
func TestGenericsToSQL(t *testing.T) {
ctx := context.Background()
sql := DB.ToSQL(func(tx *gorm.DB) *gorm.DB {
gorm.G[User](tx).Limit(10).Find(ctx)
return tx
})
if !regexp.MustCompile("SELECT \\* FROM .users..* 10").MatchString(sql) {
t.Errorf("ToSQL: got wrong sql with Generics API %v", sql)
}
}
func TestGenericsScanUUID(t *testing.T) {
ctx := context.Background()
users := []User{
{Name: uuid.NewString(), Age: 21},
{Name: uuid.NewString(), Age: 22},
{Name: uuid.NewString(), Age: 23},
}
if err := gorm.G[User](DB).CreateInBatches(ctx, &users, 2); err != nil {
t.Fatalf("CreateInBatches failed: %v", err)
}
userIds := []uuid.UUID{}
if err := gorm.G[User](DB).Select("name").Where("id in ?", []uint{users[0].ID, users[1].ID, users[2].ID}).Order("age").Scan(ctx, &userIds); err != nil || len(users) != 3 {
t.Fatalf("Scan failed: %v, userids %v", err, userIds)
}
if userIds[0].String() != users[0].Name || userIds[1].String() != users[1].Name || userIds[2].String() != users[2].Name {
t.Fatalf("wrong uuid scanned")
}
}

View File

@ -1,41 +1,40 @@
module gorm.io/gorm/tests
go 1.18
go 1.23.0
require (
github.com/google/uuid v1.6.0
github.com/jinzhu/now v1.1.5
github.com/lib/pq v1.10.9
github.com/stretchr/testify v1.9.0
gorm.io/driver/mysql v1.5.7
gorm.io/driver/postgres v1.5.10
gorm.io/driver/sqlite v1.5.6
gorm.io/driver/sqlserver v1.5.4
gorm.io/gorm v1.25.12
github.com/stretchr/testify v1.10.0
gorm.io/driver/gaussdb v0.1.0
gorm.io/driver/mysql v1.6.0
gorm.io/driver/postgres v1.6.0
gorm.io/driver/sqlite v1.6.0
gorm.io/driver/sqlserver v1.6.1
gorm.io/gorm v1.30.0
)
require (
filippo.io/edwards25519 v1.1.0 // indirect
github.com/HuaweiCloudDeveloper/gaussdb-go v1.0.0-rc1 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/go-sql-driver/mysql v1.8.1 // indirect
github.com/go-sql-driver/mysql v1.9.3 // indirect
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect
github.com/golang-sql/sqlexp v0.1.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
github.com/jackc/pgx/v5 v5.7.1 // indirect
github.com/jackc/pgx/v5 v5.7.5 // indirect
github.com/jackc/puddle/v2 v2.2.2 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/kr/text v0.2.0 // indirect
github.com/mattn/go-sqlite3 v1.14.24 // indirect
github.com/microsoft/go-mssqldb v1.7.2 // indirect
github.com/mattn/go-sqlite3 v1.14.28 // indirect
github.com/microsoft/go-mssqldb v1.9.2 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rogpeppe/go-internal v1.12.0 // indirect
golang.org/x/crypto v0.29.0 // indirect
golang.org/x/text v0.20.0 // indirect
github.com/tjfoc/gmsm v1.4.1 // indirect
golang.org/x/crypto v0.40.0 // indirect
golang.org/x/sync v0.16.0 // indirect
golang.org/x/text v0.27.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
replace gorm.io/gorm => ../
replace github.com/jackc/pgx/v5 => github.com/jackc/pgx/v5 v5.4.3
replace github.com/microsoft/go-mssqldb => github.com/microsoft/go-mssqldb v1.7.0

View File

@ -419,7 +419,7 @@ func TestJoinsPreload_Issue7013(t *testing.T) {
var entries []User
assert.NotPanics(t, func() {
assert.NoError(t,
DB.Debug().Preload("Manager.Team").
DB.Preload("Manager.Team").
Joins("Manager.Company").
Find(&entries).Error)
})
@ -456,7 +456,7 @@ func TestJoinsPreload_Issue7013_RelationEmpty(t *testing.T) {
var entries []Building
assert.NotPanics(t, func() {
assert.NoError(t,
DB.Debug().Preload("Owner.Furnitures").
DB.Preload("Owner.Furnitures").
Joins("Owner.Company").
Find(&entries).Error)
})
@ -468,7 +468,7 @@ func TestJoinsPreload_Issue7013_NoEntries(t *testing.T) {
var entries []User
assert.NotPanics(t, func() {
assert.NoError(t,
DB.Debug().Preload("Manager.Team").
DB.Preload("Manager.Team").
Joins("Manager.Company").
Where("1 <> 1").
Find(&entries).Error)

529
tests/lru_test.go Normal file
View File

@ -0,0 +1,529 @@
package tests_test
import (
"crypto/rand"
"fmt"
"gorm.io/gorm/internal/lru"
"math"
"math/big"
"reflect"
"sync"
"testing"
"time"
)
func TestLRU_Add_ExistingKey_UpdatesValueAndExpiresAt(t *testing.T) {
lru := lru.NewLRU[string, int](10, nil, time.Hour)
lru.Add("key1", 1)
lru.Add("key1", 2)
if value, ok := lru.Get("key1"); !ok || value != 2 {
t.Errorf("Expected value to be updated to 2, got %v", value)
}
}
func TestLRU_Add_NewKey_AddsEntry(t *testing.T) {
lru := lru.NewLRU[string, int](10, nil, time.Hour)
lru.Add("key1", 1)
if value, ok := lru.Get("key1"); !ok || value != 1 {
t.Errorf("Expected key1 to be added with value 1, got %v", value)
}
}
func TestLRU_Add_ExceedsSize_RemovesOldest(t *testing.T) {
lru := lru.NewLRU[string, int](2, nil, time.Hour)
lru.Add("key1", 1)
lru.Add("key2", 2)
lru.Add("key3", 3)
if _, ok := lru.Get("key1"); ok {
t.Errorf("Expected key1 to be removed, but it still exists")
}
}
func TestLRU_Add_UnlimitedSize_NoEviction(t *testing.T) {
lru := lru.NewLRU[string, int](0, nil, time.Hour)
lru.Add("key1", 1)
lru.Add("key2", 2)
lru.Add("key3", 3)
if _, ok := lru.Get("key1"); !ok {
t.Errorf("Expected key1 to exist, but it was evicted")
}
}
func TestLRU_Add_Eviction(t *testing.T) {
lru := lru.NewLRU[string, int](0, nil, time.Second*2)
lru.Add("key1", 1)
lru.Add("key2", 2)
lru.Add("key3", 3)
time.Sleep(time.Second * 3)
if lru.Cap() != 0 {
t.Errorf("Expected lru to be empty, but it was not")
}
}
func BenchmarkLRU_Rand_NoExpire(b *testing.B) {
l := lru.NewLRU[int64, int64](8192, nil, 0)
trace := make([]int64, b.N*2)
for i := 0; i < b.N*2; i++ {
trace[i] = getRand(b) % 32768
}
b.ResetTimer()
var hit, miss int
for i := 0; i < 2*b.N; i++ {
if i%2 == 0 {
l.Add(trace[i], trace[i])
} else {
if _, ok := l.Get(trace[i]); ok {
hit++
} else {
miss++
}
}
}
b.Logf("hit: %d miss: %d ratio: %f", hit, miss, float64(hit)/float64(hit+miss))
}
func BenchmarkLRU_Freq_NoExpire(b *testing.B) {
l := lru.NewLRU[int64, int64](8192, nil, 0)
trace := make([]int64, b.N*2)
for i := 0; i < b.N*2; i++ {
if i%2 == 0 {
trace[i] = getRand(b) % 16384
} else {
trace[i] = getRand(b) % 32768
}
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
l.Add(trace[i], trace[i])
}
var hit, miss int
for i := 0; i < b.N; i++ {
if _, ok := l.Get(trace[i]); ok {
hit++
} else {
miss++
}
}
b.Logf("hit: %d miss: %d ratio: %f", hit, miss, float64(hit)/float64(hit+miss))
}
func BenchmarkLRU_Rand_WithExpire(b *testing.B) {
l := lru.NewLRU[int64, int64](8192, nil, time.Millisecond*10)
trace := make([]int64, b.N*2)
for i := 0; i < b.N*2; i++ {
trace[i] = getRand(b) % 32768
}
b.ResetTimer()
var hit, miss int
for i := 0; i < 2*b.N; i++ {
if i%2 == 0 {
l.Add(trace[i], trace[i])
} else {
if _, ok := l.Get(trace[i]); ok {
hit++
} else {
miss++
}
}
}
b.Logf("hit: %d miss: %d ratio: %f", hit, miss, float64(hit)/float64(hit+miss))
}
func BenchmarkLRU_Freq_WithExpire(b *testing.B) {
l := lru.NewLRU[int64, int64](8192, nil, time.Millisecond*10)
trace := make([]int64, b.N*2)
for i := 0; i < b.N*2; i++ {
if i%2 == 0 {
trace[i] = getRand(b) % 16384
} else {
trace[i] = getRand(b) % 32768
}
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
l.Add(trace[i], trace[i])
}
var hit, miss int
for i := 0; i < b.N; i++ {
if _, ok := l.Get(trace[i]); ok {
hit++
} else {
miss++
}
}
b.Logf("hit: %d miss: %d ratio: %f", hit, miss, float64(hit)/float64(hit+miss))
}
func TestLRUNoPurge(t *testing.T) {
lc := lru.NewLRU[string, string](10, nil, 0)
lc.Add("key1", "val1")
if lc.Len() != 1 {
t.Fatalf("length differs from expected")
}
v, ok := lc.Peek("key1")
if v != "val1" {
t.Fatalf("value differs from expected")
}
if !ok {
t.Fatalf("should be true")
}
if !lc.Contains("key1") {
t.Fatalf("should contain key1")
}
if lc.Contains("key2") {
t.Fatalf("should not contain key2")
}
v, ok = lc.Peek("key2")
if v != "" {
t.Fatalf("should be empty")
}
if ok {
t.Fatalf("should be false")
}
if !reflect.DeepEqual(lc.Keys(), []string{"key1"}) {
t.Fatalf("value differs from expected")
}
if lc.Resize(0) != 0 {
t.Fatalf("evicted count differs from expected")
}
if lc.Resize(2) != 0 {
t.Fatalf("evicted count differs from expected")
}
lc.Add("key2", "val2")
if lc.Resize(1) != 1 {
t.Fatalf("evicted count differs from expected")
}
}
func TestLRUEdgeCases(t *testing.T) {
lc := lru.NewLRU[string, *string](2, nil, 0)
// Adding a nil value
lc.Add("key1", nil)
value, exists := lc.Get("key1")
if value != nil || !exists {
t.Fatalf("unexpected value or existence flag for key1: value=%v, exists=%v", value, exists)
}
// Adding an entry with the same key but different value
newVal := "val1"
lc.Add("key1", &newVal)
value, exists = lc.Get("key1")
if value != &newVal || !exists {
t.Fatalf("unexpected value or existence flag for key1: value=%v, exists=%v", value, exists)
}
}
func TestLRU_Values(t *testing.T) {
lc := lru.NewLRU[string, string](3, nil, 0)
lc.Add("key1", "val1")
lc.Add("key2", "val2")
lc.Add("key3", "val3")
values := lc.Values()
if !reflect.DeepEqual(values, []string{"val1", "val2", "val3"}) {
t.Fatalf("values differs from expected")
}
}
// func TestExpirableMultipleClose(_ *testing.T) {
// lc :=lru.NewLRU[string, string](10, nil, 0)
// lc.Close()
// // should not panic
// lc.Close()
// }
func TestLRUWithPurge(t *testing.T) {
var evicted []string
lc := lru.NewLRU(10, func(key string, value string) { evicted = append(evicted, key, value) }, 150*time.Millisecond)
k, v, ok := lc.GetOldest()
if k != "" {
t.Fatalf("should be empty")
}
if v != "" {
t.Fatalf("should be empty")
}
if ok {
t.Fatalf("should be false")
}
lc.Add("key1", "val1")
time.Sleep(100 * time.Millisecond) // not enough to expire
if lc.Len() != 1 {
t.Fatalf("length differs from expected")
}
v, ok = lc.Get("key1")
if v != "val1" {
t.Fatalf("value differs from expected")
}
if !ok {
t.Fatalf("should be true")
}
time.Sleep(200 * time.Millisecond) // expire
v, ok = lc.Get("key1")
if ok {
t.Fatalf("should be false")
}
if v != "" {
t.Fatalf("should be nil")
}
if lc.Len() != 0 {
t.Fatalf("length differs from expected")
}
if !reflect.DeepEqual(evicted, []string{"key1", "val1"}) {
t.Fatalf("value differs from expected")
}
// add new entry
lc.Add("key2", "val2")
if lc.Len() != 1 {
t.Fatalf("length differs from expected")
}
k, v, ok = lc.GetOldest()
if k != "key2" {
t.Fatalf("value differs from expected")
}
if v != "val2" {
t.Fatalf("value differs from expected")
}
if !ok {
t.Fatalf("should be true")
}
}
func TestLRUWithPurgeEnforcedBySize(t *testing.T) {
lc := lru.NewLRU[string, string](10, nil, time.Hour)
for i := 0; i < 100; i++ {
i := i
lc.Add(fmt.Sprintf("key%d", i), fmt.Sprintf("val%d", i))
v, ok := lc.Get(fmt.Sprintf("key%d", i))
if v != fmt.Sprintf("val%d", i) {
t.Fatalf("value differs from expected")
}
if !ok {
t.Fatalf("should be true")
}
if lc.Len() > 20 {
t.Fatalf("length should be less than 20")
}
}
if lc.Len() != 10 {
t.Fatalf("length differs from expected")
}
}
func TestLRUConcurrency(t *testing.T) {
lc := lru.NewLRU[string, string](0, nil, 0)
wg := sync.WaitGroup{}
wg.Add(1000)
for i := 0; i < 1000; i++ {
go func(i int) {
lc.Add(fmt.Sprintf("key-%d", i/10), fmt.Sprintf("val-%d", i/10))
wg.Done()
}(i)
}
wg.Wait()
if lc.Len() != 100 {
t.Fatalf("length differs from expected")
}
}
func TestLRUInvalidateAndEvict(t *testing.T) {
var evicted int
lc := lru.NewLRU(-1, func(_, _ string) { evicted++ }, 0)
lc.Add("key1", "val1")
lc.Add("key2", "val2")
val, ok := lc.Get("key1")
if !ok {
t.Fatalf("should be true")
}
if val != "val1" {
t.Fatalf("value differs from expected")
}
if evicted != 0 {
t.Fatalf("value differs from expected")
}
lc.Remove("key1")
if evicted != 1 {
t.Fatalf("value differs from expected")
}
val, ok = lc.Get("key1")
if val != "" {
t.Fatalf("should be empty")
}
if ok {
t.Fatalf("should be false")
}
}
func TestLoadingExpired(t *testing.T) {
lc := lru.NewLRU[string, string](0, nil, time.Millisecond*5)
lc.Add("key1", "val1")
if lc.Len() != 1 {
t.Fatalf("length differs from expected")
}
v, ok := lc.Peek("key1")
if v != "val1" {
t.Fatalf("value differs from expected")
}
if !ok {
t.Fatalf("should be true")
}
v, ok = lc.Get("key1")
if v != "val1" {
t.Fatalf("value differs from expected")
}
if !ok {
t.Fatalf("should be true")
}
for {
result, ok := lc.Get("key1")
if ok && result == "" {
t.Fatalf("ok should return a result")
}
if !ok {
break
}
}
time.Sleep(time.Millisecond * 100) // wait for expiration reaper
if lc.Len() != 0 {
t.Fatalf("length differs from expected")
}
v, ok = lc.Peek("key1")
if v != "" {
t.Fatalf("should be empty")
}
if ok {
t.Fatalf("should be false")
}
v, ok = lc.Get("key1")
if v != "" {
t.Fatalf("should be empty")
}
if ok {
t.Fatalf("should be false")
}
}
func TestLRURemoveOldest(t *testing.T) {
lc := lru.NewLRU[string, string](2, nil, 0)
if lc.Cap() != 2 {
t.Fatalf("expect cap is 2")
}
k, v, ok := lc.RemoveOldest()
if k != "" {
t.Fatalf("should be empty")
}
if v != "" {
t.Fatalf("should be empty")
}
if ok {
t.Fatalf("should be false")
}
ok = lc.Remove("non_existent")
if ok {
t.Fatalf("should be false")
}
lc.Add("key1", "val1")
if lc.Len() != 1 {
t.Fatalf("length differs from expected")
}
v, ok = lc.Get("key1")
if !ok {
t.Fatalf("should be true")
}
if v != "val1" {
t.Fatalf("value differs from expected")
}
if !reflect.DeepEqual(lc.Keys(), []string{"key1"}) {
t.Fatalf("value differs from expected")
}
if lc.Len() != 1 {
t.Fatalf("length differs from expected")
}
lc.Add("key2", "val2")
if !reflect.DeepEqual(lc.Keys(), []string{"key1", "key2"}) {
t.Fatalf("value differs from expected")
}
if lc.Len() != 2 {
t.Fatalf("length differs from expected")
}
k, v, ok = lc.RemoveOldest()
if k != "key1" {
t.Fatalf("value differs from expected")
}
if v != "val1" {
t.Fatalf("value differs from expected")
}
if !ok {
t.Fatalf("should be true")
}
if !reflect.DeepEqual(lc.Keys(), []string{"key2"}) {
t.Fatalf("value differs from expected")
}
if lc.Len() != 1 {
t.Fatalf("length differs from expected")
}
}
func getRand(tb testing.TB) int64 {
out, err := rand.Int(rand.Reader, big.NewInt(math.MaxInt64))
if err != nil {
tb.Fatal(err)
}
return out.Int64()
}

View File

@ -5,7 +5,6 @@ import (
"database/sql"
"fmt"
"math/rand"
"os"
"reflect"
"strconv"
"strings"
@ -13,11 +12,11 @@ import (
"time"
"github.com/stretchr/testify/assert"
"gorm.io/driver/gaussdb"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/logger"
"gorm.io/gorm/migrator"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils"
@ -84,8 +83,8 @@ func TestMigrate(t *testing.T) {
}
}
func TestAutoMigrateInt8PG(t *testing.T) {
if DB.Dialector.Name() != "postgres" {
func TestAutoMigrateInt8PGAndGaussDB(t *testing.T) {
if DB.Dialector.Name() != "postgres" && DB.Dialector.Name() != "gaussdb" {
return
}
@ -184,7 +183,94 @@ func TestAutoMigrateNullable(t *testing.T) {
}
func TestSmartMigrateColumn(t *testing.T) {
fullSupported := map[string]bool{"mysql": true, "postgres": true}[DB.Dialector.Name()]
fullSupported := map[string]bool{"mysql": true, "postgres": true, "gaussdb": true}[DB.Dialector.Name()]
type UserMigrateColumn struct {
ID uint
Name string
Salary float64
Birthday time.Time `gorm:"precision:4"`
}
DB.Migrator().DropTable(&UserMigrateColumn{})
DB.AutoMigrate(&UserMigrateColumn{})
type UserMigrateColumn2 struct {
ID uint
Name string `gorm:"size:128"`
Salary float64 `gorm:"precision:2"`
Birthday time.Time `gorm:"precision:2"`
NameIgnoreMigration string `gorm:"size:100"`
}
if err := DB.Table("user_migrate_columns").AutoMigrate(&UserMigrateColumn2{}); err != nil {
t.Fatalf("failed to auto migrate, got error: %v", err)
}
columnTypes, err := DB.Table("user_migrate_columns").Migrator().ColumnTypes(&UserMigrateColumn{})
if err != nil {
t.Fatalf("failed to get column types, got error: %v", err)
}
for _, columnType := range columnTypes {
switch columnType.Name() {
case "name":
if length, _ := columnType.Length(); (fullSupported || length != 0) && length != 128 {
t.Fatalf("name's length should be 128, but got %v", length)
}
case "salary":
if precision, o, _ := columnType.DecimalSize(); (fullSupported || precision != 0) && precision != 2 {
t.Fatalf("salary's precision should be 2, but got %v %v", precision, o)
}
case "birthday":
if precision, _, _ := columnType.DecimalSize(); (fullSupported || precision != 0) && precision != 2 {
t.Fatalf("birthday's precision should be 2, but got %v", precision)
}
}
}
type UserMigrateColumn3 struct {
ID uint
Name string `gorm:"size:256"`
Salary float64 `gorm:"precision:3"`
Birthday time.Time `gorm:"precision:3"`
NameIgnoreMigration string `gorm:"size:128;-:migration"`
}
if err := DB.Table("user_migrate_columns").AutoMigrate(&UserMigrateColumn3{}); err != nil {
t.Fatalf("failed to auto migrate, got error: %v", err)
}
columnTypes, err = DB.Table("user_migrate_columns").Migrator().ColumnTypes(&UserMigrateColumn{})
if err != nil {
t.Fatalf("failed to get column types, got error: %v", err)
}
for _, columnType := range columnTypes {
switch columnType.Name() {
case "name":
if length, _ := columnType.Length(); (fullSupported || length != 0) && length != 256 {
t.Fatalf("name's length should be 128, but got %v", length)
}
case "salary":
if precision, _, _ := columnType.DecimalSize(); (fullSupported || precision != 0) && precision != 3 {
t.Fatalf("salary's precision should be 2, but got %v", precision)
}
case "birthday":
if precision, _, _ := columnType.DecimalSize(); (fullSupported || precision != 0) && precision != 3 {
t.Fatalf("birthday's precision should be 2, but got %v", precision)
}
case "name_ignore_migration":
if length, _ := columnType.Length(); (fullSupported || length != 0) && length != 100 {
t.Fatalf("name_ignore_migration's length should still be 100 but got %v", length)
}
}
}
}
func TestSmartMigrateColumnGaussDB(t *testing.T) {
fullSupported := map[string]bool{"mysql": true, "gaussdb": true}[DB.Dialector.Name()]
type UserMigrateColumn struct {
ID uint
@ -852,7 +938,7 @@ func TestMigrateColumnOrder(t *testing.T) {
// https://github.com/go-gorm/gorm/issues/5047
func TestMigrateSerialColumn(t *testing.T) {
if DB.Dialector.Name() != "postgres" {
if DB.Dialector.Name() != "postgres" && DB.Dialector.Name() != "gaussdb" {
return
}
@ -1011,6 +1097,42 @@ func TestPrimarykeyID(t *testing.T) {
}
}
func TestPrimarykeyIDGaussDB(t *testing.T) {
t.Skipf("This test case skipped, because of gaussdb not support uuid-ossp plugin (SQLSTATE 58P01)")
if DB.Dialector.Name() != "gaussdb" {
return
}
type MissPKLanguage struct {
ID string `gorm:"type:uuid;default:uuid_generate_v4()"`
Name string
}
type MissPKUser struct {
ID string `gorm:"type:uuid;default:uuid_generate_v4()"`
MissPKLanguages []MissPKLanguage `gorm:"many2many:miss_pk_user_languages;"`
}
var err error
err = DB.Migrator().DropTable(&MissPKUser{}, &MissPKLanguage{})
if err != nil {
t.Fatalf("DropTable err:%v", err)
}
// TODO: ERROR: could not open extension control file: No such file or directory (SQLSTATE 58P01)
DB.Exec(`CREATE EXTENSION IF NOT EXISTS "uuid-ossp";`)
err = DB.AutoMigrate(&MissPKUser{}, &MissPKLanguage{})
if err != nil {
t.Fatalf("AutoMigrate err:%v", err)
}
// patch
err = DB.AutoMigrate(&MissPKUser{}, &MissPKLanguage{})
if err != nil {
t.Fatalf("AutoMigrate err:%v", err)
}
}
func TestCurrentTimestamp(t *testing.T) {
if DB.Dialector.Name() != "mysql" {
return
@ -1211,35 +1333,24 @@ func TestInvalidCachedPlanSimpleProtocol(t *testing.T) {
}
}
func TestInvalidCachedPlanPrepareStmt(t *testing.T) {
if DB.Dialector.Name() != "postgres" {
// TODO: ERROR: must have at least one column (SQLSTATE 0A000)
func TestInvalidCachedPlanSimpleProtocolGaussDB(t *testing.T) {
t.Skipf("This test case skipped, because of gaussdb not support creaing empty table(SQLSTATE 0A000)")
if DB.Dialector.Name() != "gaussdb" {
return
}
db, err := gorm.Open(postgres.Open(postgresDSN), &gorm.Config{PrepareStmt: true})
db, err := gorm.Open(gaussdb.Open(gaussdbDSN), &gorm.Config{})
if err != nil {
t.Errorf("Open err:%v", err)
}
if debug := os.Getenv("DEBUG"); debug == "true" {
db.Logger = db.Logger.LogMode(logger.Info)
} else if debug == "false" {
db.Logger = db.Logger.LogMode(logger.Silent)
}
type Object1 struct {
ID uint
}
type Object1 struct{}
type Object2 struct {
ID uint
Field1 int `gorm:"type:int8"`
Field1 string
}
type Object3 struct {
ID uint
Field1 int `gorm:"type:int4"`
}
type Object4 struct {
ID uint
Field2 int
Field2 string
}
db.Migrator().DropTable("objects")
@ -1247,63 +1358,16 @@ func TestInvalidCachedPlanPrepareStmt(t *testing.T) {
if err != nil {
t.Errorf("AutoMigrate err:%v", err)
}
err = db.Table("objects").Create(&Object1{}).Error
if err != nil {
t.Errorf("create err:%v", err)
}
// AddColumn
err = db.Table("objects").AutoMigrate(&Object2{})
if err != nil {
t.Errorf("AutoMigrate err:%v", err)
}
err = db.Table("objects").Take(&Object2{}).Error
if err != nil {
t.Errorf("take err:%v", err)
}
// AlterColumn
err = db.Table("objects").AutoMigrate(&Object3{})
if err != nil {
t.Errorf("AutoMigrate err:%v", err)
}
err = db.Table("objects").Take(&Object3{}).Error
if err != nil {
t.Errorf("take err:%v", err)
}
// AddColumn
err = db.Table("objects").AutoMigrate(&Object4{})
if err != nil {
t.Errorf("AutoMigrate err:%v", err)
}
err = db.Table("objects").Take(&Object4{}).Error
if err != nil {
t.Errorf("take err:%v", err)
}
db.Table("objects").Migrator().RenameColumn(&Object4{}, "field2", "field3")
if err != nil {
t.Errorf("RenameColumn err:%v", err)
}
err = db.Table("objects").Take(&Object4{}).Error
if err != nil {
t.Errorf("take err:%v", err)
}
db.Table("objects").Migrator().DropColumn(&Object4{}, "field3")
if err != nil {
t.Errorf("RenameColumn err:%v", err)
}
err = db.Table("objects").Take(&Object4{}).Error
if err != nil {
t.Errorf("take err:%v", err)
}
}
func TestDifferentTypeWithoutDeclaredLength(t *testing.T) {
@ -1346,7 +1410,7 @@ func TestDifferentTypeWithoutDeclaredLength(t *testing.T) {
}
func TestMigrateArrayTypeModel(t *testing.T) {
if DB.Dialector.Name() != "postgres" {
if DB.Dialector.Name() != "postgres" && DB.Dialector.Name() != "gaussdb" {
return
}
@ -1668,8 +1732,8 @@ func TestMigrateView(t *testing.T) {
}
}
func TestMigrateExistingBoolColumnPG(t *testing.T) {
if DB.Dialector.Name() != "postgres" {
func TestMigrateExistingBoolColumnPGAndGaussDB(t *testing.T) {
if DB.Dialector.Name() != "postgres" && DB.Dialector.Name() != "gaussdb" {
return
}
@ -1987,3 +2051,114 @@ func TestMigrateWithUniqueIndexAndUnique(t *testing.T) {
}
}
}
func testAutoMigrateDecimal(t *testing.T, model1, model2 any) []string {
tracer := Tracer{
Logger: DB.Config.Logger,
Test: func(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) {
sql, _ := fc()
if strings.HasPrefix(sql, "ALTER TABLE ") {
t.Fatalf("shouldn't execute ALTER COLUMN TYPE if decimal is not change: sql: %s", sql)
}
},
}
session := DB.Session(&gorm.Session{Logger: tracer})
DB.Migrator().DropTable(model1)
var modifySql []string
if err := session.AutoMigrate(model1); err != nil {
t.Fatalf("failed to auto migrate, got error: %v", err)
}
if err := session.AutoMigrate(model1); err != nil {
t.Fatalf("failed to auto migrate, got error: %v", err)
}
tracer2 := Tracer{
Logger: DB.Config.Logger,
Test: func(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) {
sql, _ := fc()
modifySql = append(modifySql, sql)
},
}
session2 := DB.Session(&gorm.Session{Logger: tracer2})
err := session2.Table("migrate_decimal_columns").Migrator().AutoMigrate(model2)
if err != nil {
t.Fatalf("failed to get column types, got error: %v", err)
}
return modifySql
}
func decimalColumnsTest[T, T2 any](t *testing.T, expectedSql []string) {
var t1 T
var t2 T2
modSql := testAutoMigrateDecimal(t, t1, t2)
var alterSQL []string
for _, sql := range modSql {
if strings.HasPrefix(sql, "ALTER TABLE ") {
alterSQL = append(alterSQL, sql)
}
}
if len(alterSQL) != 3 {
t.Fatalf("decimal changed error,expected: %+v,got: %+v.", expectedSql, alterSQL)
}
for i := range alterSQL {
if alterSQL[i] != expectedSql[i] {
t.Fatalf("decimal changed error,expected: %+v,got: %+v.", expectedSql, alterSQL)
}
}
}
func TestAutoMigrateDecimal(t *testing.T) {
if DB.Dialector.Name() == "sqlserver" { // database/sql will replace numeric to decimal. so only support decimal.
type MigrateDecimalColumn struct {
RecID1 int64 `gorm:"column:recid1;type:decimal(9,0);not null" json:"recid1"`
RecID2 int64 `gorm:"column:recid2;type:decimal(8);not null" json:"recid2"`
RecID3 int64 `gorm:"column:recid3;type:decimal(8,1);not null" json:"recid3"`
}
type MigrateDecimalColumn2 struct {
RecID1 int64 `gorm:"column:recid1;type:decimal(8);not null" json:"recid1"`
RecID2 int64 `gorm:"column:recid2;type:decimal(9,1);not null" json:"recid2"`
RecID3 int64 `gorm:"column:recid3;type:decimal(9,2);not null" json:"recid3"`
}
expectedSql := []string{
`ALTER TABLE "migrate_decimal_columns" ALTER COLUMN "recid1" decimal(8) NOT NULL`,
`ALTER TABLE "migrate_decimal_columns" ALTER COLUMN "recid2" decimal(9,1) NOT NULL`,
`ALTER TABLE "migrate_decimal_columns" ALTER COLUMN "recid3" decimal(9,2) NOT NULL`,
}
decimalColumnsTest[MigrateDecimalColumn, MigrateDecimalColumn2](t, expectedSql)
} else if DB.Dialector.Name() == "postgres" || DB.Dialector.Name() == "gaussdb" {
type MigrateDecimalColumn struct {
RecID1 int64 `gorm:"column:recid1;type:numeric(9,0);not null" json:"recid1"`
RecID2 int64 `gorm:"column:recid2;type:numeric(8);not null" json:"recid2"`
RecID3 int64 `gorm:"column:recid3;type:numeric(8,1);not null" json:"recid3"`
}
type MigrateDecimalColumn2 struct {
RecID1 int64 `gorm:"column:recid1;type:numeric(8);not null" json:"recid1"`
RecID2 int64 `gorm:"column:recid2;type:numeric(9,1);not null" json:"recid2"`
RecID3 int64 `gorm:"column:recid3;type:numeric(9,2);not null" json:"recid3"`
}
expectedSql := []string{
`ALTER TABLE "migrate_decimal_columns" ALTER COLUMN "recid1" TYPE numeric(8) USING "recid1"::numeric(8)`,
`ALTER TABLE "migrate_decimal_columns" ALTER COLUMN "recid2" TYPE numeric(9,1) USING "recid2"::numeric(9,1)`,
`ALTER TABLE "migrate_decimal_columns" ALTER COLUMN "recid3" TYPE numeric(9,2) USING "recid3"::numeric(9,2)`,
}
decimalColumnsTest[MigrateDecimalColumn, MigrateDecimalColumn2](t, expectedSql)
} else if DB.Dialector.Name() == "mysql" {
type MigrateDecimalColumn struct {
RecID1 int64 `gorm:"column:recid1;type:decimal(9,0);not null" json:"recid1"`
RecID2 int64 `gorm:"column:recid2;type:decimal(8);not null" json:"recid2"`
RecID3 int64 `gorm:"column:recid3;type:decimal(8,1);not null" json:"recid3"`
}
type MigrateDecimalColumn2 struct {
RecID1 int64 `gorm:"column:recid1;type:decimal(8);not null" json:"recid1"`
RecID2 int64 `gorm:"column:recid2;type:decimal(9,1);not null" json:"recid2"`
RecID3 int64 `gorm:"column:recid3;type:decimal(9,2);not null" json:"recid3"`
}
expectedSql := []string{
"ALTER TABLE `migrate_decimal_columns` MODIFY COLUMN `recid1` decimal(8) NOT NULL",
"ALTER TABLE `migrate_decimal_columns` MODIFY COLUMN `recid2` decimal(9,1) NOT NULL",
"ALTER TABLE `migrate_decimal_columns` MODIFY COLUMN `recid3` decimal(9,2) NOT NULL",
}
decimalColumnsTest[MigrateDecimalColumn, MigrateDecimalColumn2](t, expectedSql)
}
}

View File

@ -41,7 +41,7 @@ func TestManyToManyWithMultiPrimaryKeys(t *testing.T) {
t.Skip("skip sqlite, sqlserver due to it doesn't support multiple primary keys with auto increment")
}
if name := DB.Dialector.Name(); name == "postgres" {
if name := DB.Dialector.Name(); name == "postgres" || name == "mysql" || name == "gaussdb" {
stmt := gorm.Statement{DB: DB}
stmt.Parse(&Blog{})
stmt.Schema.LookUpField("ID").Unique = true
@ -142,6 +142,9 @@ func TestManyToManyWithCustomizedForeignKeys(t *testing.T) {
if name := DB.Dialector.Name(); name == "postgres" {
t.Skip("skip postgres due to it only allow unique constraint matching given keys")
}
if name := DB.Dialector.Name(); name == "gaussdb" {
t.Skip("skip gaussdb due to it only allow unique constraint matching given keys")
}
DB.Migrator().DropTable(&Blog{}, &Tag{}, "blog_tags", "locale_blog_tags", "shared_blog_tags")
if err := DB.AutoMigrate(&Blog{}, &Tag{}); err != nil {
@ -264,10 +267,14 @@ func TestManyToManyWithCustomizedForeignKeys2(t *testing.T) {
t.Skip("skip sqlite, sqlserver due to it doesn't support multiple primary keys with auto increment")
}
if name := DB.Dialector.Name(); name == "postgres" {
if name := DB.Dialector.Name(); name == "postgres" || name == "mysql" {
t.Skip("skip postgres due to it only allow unique constraint matching given keys")
}
if name := DB.Dialector.Name(); name == "gaussdb" {
t.Skip("skip gaussdb due to it only allow unique constraint matching given keys")
}
DB.Migrator().DropTable(&Blog{}, &Tag{}, "blog_tags", "locale_blog_tags", "shared_blog_tags")
if err := DB.AutoMigrate(&Blog{}, &Tag{}); err != nil {
t.Fatalf("Failed to auto migrate, got error: %v", err)
@ -332,7 +339,7 @@ func TestManyToManyWithCustomizedForeignKeys2(t *testing.T) {
DB.Model(&blog2).Association("LocaleTags").Find(&tags)
if !compareTags(tags, []string{"tag4"}) {
t.Fatalf("Should find 1 tags for EN Blog")
t.Fatalf("Should find 1 tags for EN Blog, but got %v", tags)
}
// Replace

View File

@ -696,6 +696,10 @@ func TestManyToManyPreloadWithMultiPrimaryKeys(t *testing.T) {
t.Skip("skip sqlite, sqlserver due to it doesn't support multiple primary keys with auto increment")
}
if name := DB.Dialector.Name(); name == "mysql" {
t.Skip("skip mysql due to it only allow unique constraint matching given keys")
}
type (
Level1 struct {
ID uint `gorm:"primary_key;"`

View File

@ -4,7 +4,6 @@ import (
"context"
"errors"
"sync"
"sync/atomic"
"testing"
"time"
@ -92,6 +91,65 @@ func TestPreparedStmtFromTransaction(t *testing.T) {
tx2.Commit()
}
func TestPreparedStmtLruFromTransaction(t *testing.T) {
db, _ := OpenTestConnection(&gorm.Config{PrepareStmt: true, PrepareStmtMaxSize: 10, PrepareStmtTTL: 20 * time.Second})
tx := db.Begin()
defer func() {
if r := recover(); r != nil {
tx.Rollback()
}
}()
if err := tx.Error; err != nil {
t.Errorf("Failed to start transaction, got error %v\n", err)
}
if err := tx.Where("name=?", "zzjin").Delete(&User{}).Error; err != nil {
tx.Rollback()
t.Errorf("Failed to run one transaction, got error %v\n", err)
}
if err := tx.Create(&User{Name: "zzjin"}).Error; err != nil {
tx.Rollback()
t.Errorf("Failed to run one transaction, got error %v\n", err)
}
if err := tx.Commit().Error; err != nil {
t.Errorf("Failed to commit transaction, got error %v\n", err)
}
if result := db.Where("name=?", "zzjin").Delete(&User{}); result.Error != nil || result.RowsAffected != 1 {
t.Fatalf("Failed, got error: %v, rows affected: %v", result.Error, result.RowsAffected)
}
tx2 := db.Begin()
if result := tx2.Where("name=?", "zzjin").Delete(&User{}); result.Error != nil || result.RowsAffected != 0 {
t.Fatalf("Failed, got error: %v, rows affected: %v", result.Error, result.RowsAffected)
}
tx2.Commit()
// Attempt to convert the connection pool of tx to the *gorm.PreparedStmtDB type.
// If the conversion is successful, ok will be true and conn will be the converted object;
// otherwise, ok will be false and conn will be nil.
conn, ok := tx.ConnPool.(*gorm.PreparedStmtDB)
// Get the number of statement keys stored in the PreparedStmtDB.
lens := len(conn.Stmts.Keys())
// Check if the number of stored statement keys is 0.
if lens == 0 {
// If the number is 0, it means there are no statements stored in the LRU cache.
// The test fails and an error message is output.
t.Fatalf("lru should not be empty")
}
// Wait for 40 seconds to give the statements in the cache enough time to expire.
time.Sleep(time.Second * 40)
// Assert whether the connection pool of tx is successfully converted to the *gorm.PreparedStmtDB type.
AssertEqual(t, ok, true)
// Assert whether the number of statement keys stored in the PreparedStmtDB is 0 after 40 seconds.
// If it is not 0, it means the statements in the cache have not expired as expected.
AssertEqual(t, len(conn.Stmts.Keys()), 0)
}
func TestPreparedStmtDeadlock(t *testing.T) {
tx, err := OpenTestConnection(&gorm.Config{})
AssertEqual(t, err, nil)
@ -117,9 +175,9 @@ func TestPreparedStmtDeadlock(t *testing.T) {
conn, ok := tx.ConnPool.(*gorm.PreparedStmtDB)
AssertEqual(t, ok, true)
AssertEqual(t, len(conn.Stmts), 2)
for _, stmt := range conn.Stmts {
if stmt == nil {
AssertEqual(t, len(conn.Stmts.Keys()), 2)
for _, stmt := range conn.Stmts.Keys() {
if stmt == "" {
t.Fatalf("stmt cannot bee nil")
}
}
@ -143,10 +201,10 @@ func TestPreparedStmtInTransaction(t *testing.T) {
}
}
func TestPreparedStmtReset(t *testing.T) {
func TestPreparedStmtClose(t *testing.T) {
tx := DB.Session(&gorm.Session{PrepareStmt: true})
user := *GetUser("prepared_stmt_reset", Config{})
user := *GetUser("prepared_stmt_close", Config{})
tx = tx.Create(&user)
pdb, ok := tx.ConnPool.(*gorm.PreparedStmtDB)
@ -155,16 +213,16 @@ func TestPreparedStmtReset(t *testing.T) {
}
pdb.Mux.Lock()
if len(pdb.Stmts) == 0 {
if len(pdb.Stmts.Keys()) == 0 {
pdb.Mux.Unlock()
t.Fatalf("prepared stmt can not be empty")
}
pdb.Mux.Unlock()
pdb.Reset()
pdb.Close()
pdb.Mux.Lock()
defer pdb.Mux.Unlock()
if len(pdb.Stmts) != 0 {
if len(pdb.Stmts.Keys()) != 0 {
t.Fatalf("prepared stmt should be empty")
}
}
@ -174,10 +232,10 @@ func isUsingClosedConnError(err error) bool {
return err.Error() == "sql: statement is closed"
}
// TestPreparedStmtConcurrentReset test calling reset and executing SQL concurrently
// TestPreparedStmtConcurrentClose test calling close and executing SQL concurrently
// this test making sure that the gorm would not get a Segmentation Fault, and the only error cause by this is using a closed Stmt
func TestPreparedStmtConcurrentReset(t *testing.T) {
name := "prepared_stmt_concurrent_reset"
func TestPreparedStmtConcurrentClose(t *testing.T) {
name := "prepared_stmt_concurrent_close"
user := *GetUser(name, Config{})
createTx := DB.Session(&gorm.Session{}).Create(&user)
if createTx.Error != nil {
@ -220,7 +278,7 @@ func TestPreparedStmtConcurrentReset(t *testing.T) {
go func() {
defer wg.Done()
<-writerFinish
pdb.Reset()
pdb.Close()
}()
wg.Wait()
@ -229,88 +287,3 @@ func TestPreparedStmtConcurrentReset(t *testing.T) {
t.Fatalf("should is a unexpected error")
}
}
// TestPreparedStmtConcurrentClose test calling close and executing SQL concurrently
// for example: one goroutine found error and just close the database, and others are executing SQL
// this test making sure that the gorm would not get a Segmentation Fault,
// and the only error cause by this is using a closed Stmt or gorm.ErrInvalidDB
// and all of the goroutine must got gorm.ErrInvalidDB after database close
func TestPreparedStmtConcurrentClose(t *testing.T) {
name := "prepared_stmt_concurrent_close"
user := *GetUser(name, Config{})
createTx := DB.Session(&gorm.Session{}).Create(&user)
if createTx.Error != nil {
t.Fatalf("failed to prepare record due to %s, test cannot be continue", createTx.Error)
}
// create a new connection to keep away from other tests
tx, err := OpenTestConnection(&gorm.Config{PrepareStmt: true})
if err != nil {
t.Fatalf("failed to open test connection due to %s", err)
}
pdb, ok := tx.ConnPool.(*gorm.PreparedStmtDB)
if !ok {
t.Fatalf("should assign PreparedStatement Manager back to database when using PrepareStmt mode")
}
loopCount := 100
var wg sync.WaitGroup
var lastErr error
closeValid := make(chan struct{}, loopCount)
closeStartIdx := loopCount / 2 // close the database at the middle of the execution
var lastRunIndex int
var closeFinishedAt int64
wg.Add(1)
go func(id uint) {
defer wg.Done()
defer close(closeValid)
for lastRunIndex = 1; lastRunIndex <= loopCount; lastRunIndex++ {
if lastRunIndex == closeStartIdx {
closeValid <- struct{}{}
}
var tmp User
now := time.Now().UnixNano()
err := tx.Session(&gorm.Session{}).First(&tmp, id).Error
if err == nil {
closeFinishedAt := atomic.LoadInt64(&closeFinishedAt)
if (closeFinishedAt != 0) && (now > closeFinishedAt) {
lastErr = errors.New("must got error after database closed")
break
}
continue
}
lastErr = err
break
}
}(user.ID)
wg.Add(1)
go func() {
defer wg.Done()
for range closeValid {
for i := 0; i < loopCount; i++ {
pdb.Close() // the Close method must can be call multiple times
atomic.CompareAndSwapInt64(&closeFinishedAt, 0, time.Now().UnixNano())
}
}
}()
wg.Wait()
var tmp User
err = tx.Session(&gorm.Session{}).First(&tmp, user.ID).Error
if err != gorm.ErrInvalidDB {
t.Fatalf("must got a gorm.ErrInvalidDB while execution after db close, got %+v instead", err)
}
// must be error
if lastErr != gorm.ErrInvalidDB && !isUsingClosedConnError(lastErr) {
t.Fatalf("exp error gorm.ErrInvalidDB, got %+v instead", lastErr)
}
if lastRunIndex >= loopCount || lastRunIndex < closeStartIdx {
t.Fatalf("exp loop times between (closeStartIdx %d <=) and (< loopCount %d), got %d instead", closeStartIdx, loopCount, lastRunIndex)
}
if pdb.Stmts != nil {
t.Fatalf("stmts must be nil")
}
}

View File

@ -632,6 +632,21 @@ func TestOr(t *testing.T) {
t.Fatalf("Build OR condition, but got %v", result.Statement.SQL.String())
}
sub := dryDB.Clauses(clause.Where{
Exprs: []clause.Expression{
clause.OrConditions{
Exprs: []clause.Expression{
clause.Expr{SQL: "role = ?", Vars: []interface{}{"super_admin"}},
clause.Expr{SQL: "role = ?", Vars: []interface{}{"admin"}},
},
},
},
})
result = dryDB.Where(sub).Find(&User{})
if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*role.* = .+ OR .*role.* = .+").MatchString(result.Statement.SQL.String()) {
t.Fatalf("Build OR condition, but got %v", result.Statement.SQL.String())
}
result = dryDB.Where("role = ?", "admin").Or("role = ?", "super_admin").Find(&User{})
if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*role.* = .+ OR .*role.* = .+").MatchString(result.Statement.SQL.String()) {
t.Fatalf("Build OR condition, but got %v", result.Statement.SQL.String())
@ -1112,6 +1127,10 @@ func TestSearchWithMap(t *testing.T) {
DB.First(&user, map[string]interface{}{"name": users[0].Name})
CheckUser(t, user, users[0])
user = User{}
DB.First(&user, map[string]interface{}{"users.name": users[0].Name})
CheckUser(t, user, users[0])
user = User{}
DB.Where(map[string]interface{}{"name": users[1].Name}).First(&user)
CheckUser(t, user, users[1])

View File

@ -45,7 +45,7 @@ type SerializerPostgresStruct struct {
func (*SerializerPostgresStruct) TableName() string { return "serializer_structs" }
func adaptorSerializerModel(s *SerializerStruct) interface{} {
if DB.Dialector.Name() == "postgres" {
if DB.Dialector.Name() == "postgres" || DB.Dialector.Name() == "gaussdb" {
sps := SerializerPostgresStruct(*s)
return &sps
}

View File

@ -487,7 +487,7 @@ func replaceQuoteInSQL(sql string) string {
// convert dialect special quote into double quote
switch DB.Dialector.Name() {
case "postgres":
case "postgres", "gaussdb":
sql = strings.ReplaceAll(sql, `"`, `"`)
case "mysql", "sqlite":
sql = strings.ReplaceAll(sql, "`", `"`)

View File

@ -5,6 +5,7 @@ import (
"sync"
"testing"
"gorm.io/driver/gaussdb"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/schema"
@ -251,6 +252,82 @@ func TestPostgresTableWithIdentifierLength(t *testing.T) {
})
}
func TestGaussDBTableWithIdentifierLength(t *testing.T) {
if DB.Dialector.Name() != "gaussdb" {
return
}
type LongString struct {
ThisIsAVeryVeryVeryVeryVeryVeryVeryVeryVeryLongString string `gorm:"unique"`
}
t.Run("default", func(t *testing.T) {
db, _ := gorm.Open(gaussdb.Open(gaussdbDSN), &gorm.Config{})
user, err := schema.Parse(&LongString{}, &sync.Map{}, db.Config.NamingStrategy)
if err != nil {
t.Fatalf("failed to parse user unique, got error %v", err)
}
constraints := user.ParseUniqueConstraints()
if len(constraints) != 1 {
t.Fatalf("failed to find unique constraint, got %v", constraints)
}
for key := range constraints {
if len(key) != 63 {
t.Errorf("failed to find unique constraint, got %v", constraints)
}
}
})
t.Run("naming strategy", func(t *testing.T) {
db, _ := gorm.Open(gaussdb.Open(gaussdbDSN), &gorm.Config{
NamingStrategy: schema.NamingStrategy{},
})
user, err := schema.Parse(&LongString{}, &sync.Map{}, db.Config.NamingStrategy)
if err != nil {
t.Fatalf("failed to parse user unique, got error %v", err)
}
constraints := user.ParseUniqueConstraints()
if len(constraints) != 1 {
t.Fatalf("failed to find unique constraint, got %v", constraints)
}
for key := range constraints {
if len(key) != 63 {
t.Errorf("failed to find unique constraint, got %v", constraints)
}
}
})
t.Run("namer", func(t *testing.T) {
uname := "custom_unique_name"
db, _ := gorm.Open(gaussdb.Open(gaussdbDSN), &gorm.Config{
NamingStrategy: mockUniqueNamingStrategy{
UName: uname,
},
})
user, err := schema.Parse(&LongString{}, &sync.Map{}, db.Config.NamingStrategy)
if err != nil {
t.Fatalf("failed to parse user unique, got error %v", err)
}
constraints := user.ParseUniqueConstraints()
if len(constraints) != 1 {
t.Fatalf("failed to find unique constraint, got %v", constraints)
}
for key := range constraints {
if key != uname {
t.Errorf("failed to find unique constraint, got %v", constraints)
}
}
})
}
type mockUniqueNamingStrategy struct {
UName string
schema.NamingStrategy

View File

@ -1,6 +1,6 @@
#!/bin/bash -e
dialects=("sqlite" "mysql" "postgres" "sqlserver" "tidb")
dialects=("sqlite" "mysql" "postgres" "gaussdb" "sqlserver" "tidb")
if [[ $(pwd) == *"gorm/tests"* ]]; then
cd ..

View File

@ -8,6 +8,7 @@ import (
"path/filepath"
"time"
"gorm.io/driver/gaussdb"
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
"gorm.io/driver/sqlite"
@ -21,6 +22,7 @@ var DB *gorm.DB
var (
mysqlDSN = "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True&loc=Local"
postgresDSN = "user=gorm password=gorm dbname=gorm host=localhost port=9920 sslmode=disable TimeZone=Asia/Shanghai"
gaussdbDSN = "user=gaussdb password=Gaussdb@123 dbname=gorm host=localhost port=9950 sslmode=disable TimeZone=Asia/Shanghai"
sqlserverDSN = "sqlserver://sa:LoremIpsum86@localhost:9930?database=master"
tidbDSN = "root:@tcp(localhost:9940)/test?charset=utf8&parseTime=True&loc=Local"
)
@ -65,6 +67,15 @@ func OpenTestConnection(cfg *gorm.Config) (db *gorm.DB, err error) {
DSN: dbDSN,
PreferSimpleProtocol: true,
}), cfg)
case "gaussdb":
log.Println("testing gaussdb...")
if dbDSN == "" {
dbDSN = gaussdbDSN
}
db, err = gorm.Open(gaussdb.New(gaussdb.Config{
DSN: dbDSN,
PreferSimpleProtocol: true,
}), cfg)
case "sqlserver":
// go install github.com/microsoft/go-sqlcmd/cmd/sqlcmd@latest
// SQLCMDPASSWORD=LoremIpsum86 sqlcmd -U sa -S localhost:9930

View File

@ -4,6 +4,7 @@ import (
"context"
"errors"
"testing"
"time"
"gorm.io/gorm"
. "gorm.io/gorm/utils/tests"
@ -459,7 +460,6 @@ func TestTransactionWithHooks(t *testing.T) {
return tx2.Scan(&User{}).Error
})
})
if err != nil {
t.Error(err)
}
@ -473,8 +473,20 @@ func TestTransactionWithHooks(t *testing.T) {
return tx3.Where("user_id", user.ID).Delete(&Account{}).Error
})
})
if err != nil {
t.Error(err)
}
}
func TestTransactionWithDefaultTimeout(t *testing.T) {
db, err := OpenTestConnection(&gorm.Config{DefaultTransactionTimeout: 2 * time.Second})
if err != nil {
t.Fatalf("failed to connect database, got error %v", err)
}
tx := db.Begin()
time.Sleep(3 * time.Second)
if err = tx.Find(&User{}).Error; err == nil {
t.Errorf("should return error when transaction timeout, got error %v", err)
}
}

View File

@ -765,9 +765,9 @@ func TestSaveWithPrimaryValue(t *testing.T) {
}
}
// only sqlite, postgres, sqlserver support returning
// only sqlite, postgres, gaussdb, sqlserver support returning
func TestUpdateReturning(t *testing.T) {
if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" && DB.Dialector.Name() != "sqlserver" {
if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" && DB.Dialector.Name() != "gaussdb" && DB.Dialector.Name() != "sqlserver" {
return
}
@ -883,9 +883,9 @@ func TestSaveWithHooks(t *testing.T) {
}
}
// only postgres, sqlserver, sqlite support update from
// only postgres, gaussdb, sqlserver, sqlite support update from
func TestUpdateFrom(t *testing.T) {
if DB.Dialector.Name() != "postgres" && DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "sqlserver" {
if DB.Dialector.Name() != "postgres" && DB.Dialector.Name() != "gaussdb" && DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "sqlserver" {
return
}