Implement Generics API (#7424)
* Implement Generics API * Add more generics tests * Add more tests and Take method * use delayed‑ops pipeline for generics API * fix generics tests for mysql * Support SubQuery for Generics * Add clause.JoinTable helper method * Fix golangci-lint error * Complete the design and implementation of generic version Join * improve generics version Joins support * allow configuring select/omit columns for joins via subqueries * finish generic version Preload * handle error of generics Joins/Preload * fix tests * Add LimitPerRecord for generic version Preload * fix tests for mysql 5.7 * test for nested generic version Join/Preload * Add WithResult support for generics API * test reuse generics db conditions * fix data race * remove ExampleLRU test * Add default transaction timeout support * fix test
This commit is contained in:
parent
751c1d6b45
commit
c44405a25b
@ -89,6 +89,10 @@ func Create(config *Config) func(db *gorm.DB) {
|
|||||||
db.AddError(rows.Close())
|
db.AddError(rows.Close())
|
||||||
}()
|
}()
|
||||||
gorm.Scan(rows, db, mode)
|
gorm.Scan(rows, db, mode)
|
||||||
|
|
||||||
|
if db.Statement.Result != nil {
|
||||||
|
db.Statement.Result.RowsAffected = db.RowsAffected
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return
|
return
|
||||||
@ -103,6 +107,12 @@ func Create(config *Config) func(db *gorm.DB) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
db.RowsAffected, _ = result.RowsAffected()
|
db.RowsAffected, _ = result.RowsAffected()
|
||||||
|
|
||||||
|
if db.Statement.Result != nil {
|
||||||
|
db.Statement.Result.Result = result
|
||||||
|
db.Statement.Result.RowsAffected = db.RowsAffected
|
||||||
|
}
|
||||||
|
|
||||||
if db.RowsAffected == 0 {
|
if db.RowsAffected == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -157,8 +157,14 @@ func Delete(config *Config) func(db *gorm.DB) {
|
|||||||
ok, mode := hasReturning(db, supportReturning)
|
ok, mode := hasReturning(db, supportReturning)
|
||||||
if !ok {
|
if !ok {
|
||||||
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||||
|
|
||||||
if db.AddError(err) == nil {
|
if db.AddError(err) == nil {
|
||||||
db.RowsAffected, _ = result.RowsAffected()
|
db.RowsAffected, _ = result.RowsAffected()
|
||||||
|
|
||||||
|
if db.Statement.Result != nil {
|
||||||
|
db.Statement.Result.Result = result
|
||||||
|
db.Statement.Result.RowsAffected = db.RowsAffected
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return
|
return
|
||||||
@ -166,6 +172,10 @@ func Delete(config *Config) func(db *gorm.DB) {
|
|||||||
|
|
||||||
if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil {
|
if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil {
|
||||||
gorm.Scan(rows, db, mode)
|
gorm.Scan(rows, db, mode)
|
||||||
|
|
||||||
|
if db.Statement.Result != nil {
|
||||||
|
db.Statement.Result.RowsAffected = db.RowsAffected
|
||||||
|
}
|
||||||
db.AddError(rows.Close())
|
db.AddError(rows.Close())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -275,6 +275,8 @@ func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
|
|||||||
column, values := schema.ToQueryValues(clause.CurrentTable, relForeignKeys, foreignValues)
|
column, values := schema.ToQueryValues(clause.CurrentTable, relForeignKeys, foreignValues)
|
||||||
|
|
||||||
if len(values) != 0 {
|
if len(values) != 0 {
|
||||||
|
tx = tx.Model(reflectResults.Addr().Interface()).Where(clause.IN{Column: column, Values: values})
|
||||||
|
|
||||||
for _, cond := range conds {
|
for _, cond := range conds {
|
||||||
if fc, ok := cond.(func(*gorm.DB) *gorm.DB); ok {
|
if fc, ok := cond.(func(*gorm.DB) *gorm.DB); ok {
|
||||||
tx = fc(tx)
|
tx = fc(tx)
|
||||||
@ -283,7 +285,11 @@ func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), inlineConds...).Error; err != nil {
|
if len(inlineConds) > 0 {
|
||||||
|
tx = tx.Where(inlineConds[0], inlineConds[1:]...)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx.Find(reflectResults.Addr().Interface()).Error; err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -25,6 +25,10 @@ func Query(db *gorm.DB) {
|
|||||||
db.AddError(rows.Close())
|
db.AddError(rows.Close())
|
||||||
}()
|
}()
|
||||||
gorm.Scan(rows, db, 0)
|
gorm.Scan(rows, db, 0)
|
||||||
|
|
||||||
|
if db.Statement.Result != nil {
|
||||||
|
db.Statement.Result.RowsAffected = db.RowsAffected
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -110,7 +114,7 @@ func BuildQuerySQL(db *gorm.DB) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
specifiedRelationsName := make(map[string]interface{})
|
specifiedRelationsName := map[string]string{clause.CurrentTable: clause.CurrentTable}
|
||||||
for _, join := range db.Statement.Joins {
|
for _, join := range db.Statement.Joins {
|
||||||
if db.Statement.Schema != nil {
|
if db.Statement.Schema != nil {
|
||||||
var isRelations bool // is relations or raw sql
|
var isRelations bool // is relations or raw sql
|
||||||
@ -124,12 +128,12 @@ func BuildQuerySQL(db *gorm.DB) {
|
|||||||
nestedJoinNames := strings.Split(join.Name, ".")
|
nestedJoinNames := strings.Split(join.Name, ".")
|
||||||
if len(nestedJoinNames) > 1 {
|
if len(nestedJoinNames) > 1 {
|
||||||
isNestedJoin := true
|
isNestedJoin := true
|
||||||
gussNestedRelations := make([]*schema.Relationship, 0, len(nestedJoinNames))
|
guessNestedRelations := make([]*schema.Relationship, 0, len(nestedJoinNames))
|
||||||
currentRelations := db.Statement.Schema.Relationships.Relations
|
currentRelations := db.Statement.Schema.Relationships.Relations
|
||||||
for _, relname := range nestedJoinNames {
|
for _, relname := range nestedJoinNames {
|
||||||
// incomplete match, only treated as raw sql
|
// incomplete match, only treated as raw sql
|
||||||
if relation, ok = currentRelations[relname]; ok {
|
if relation, ok = currentRelations[relname]; ok {
|
||||||
gussNestedRelations = append(gussNestedRelations, relation)
|
guessNestedRelations = append(guessNestedRelations, relation)
|
||||||
currentRelations = relation.FieldSchema.Relationships.Relations
|
currentRelations = relation.FieldSchema.Relationships.Relations
|
||||||
} else {
|
} else {
|
||||||
isNestedJoin = false
|
isNestedJoin = false
|
||||||
@ -139,18 +143,13 @@ func BuildQuerySQL(db *gorm.DB) {
|
|||||||
|
|
||||||
if isNestedJoin {
|
if isNestedJoin {
|
||||||
isRelations = true
|
isRelations = true
|
||||||
relations = gussNestedRelations
|
relations = guessNestedRelations
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if isRelations {
|
if isRelations {
|
||||||
genJoinClause := func(joinType clause.JoinType, parentTableName string, relation *schema.Relationship) clause.Join {
|
genJoinClause := func(joinType clause.JoinType, tableAliasName string, parentTableName string, relation *schema.Relationship) clause.Join {
|
||||||
tableAliasName := relation.Name
|
|
||||||
if parentTableName != clause.CurrentTable {
|
|
||||||
tableAliasName = utils.NestedRelationName(parentTableName, tableAliasName)
|
|
||||||
}
|
|
||||||
|
|
||||||
columnStmt := gorm.Statement{
|
columnStmt := gorm.Statement{
|
||||||
Table: tableAliasName, DB: db, Schema: relation.FieldSchema,
|
Table: tableAliasName, DB: db, Schema: relation.FieldSchema,
|
||||||
Selects: join.Selects, Omits: join.Omits,
|
Selects: join.Selects, Omits: join.Omits,
|
||||||
@ -167,6 +166,13 @@ func BuildQuerySQL(db *gorm.DB) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if join.Expression != nil {
|
||||||
|
return clause.Join{
|
||||||
|
Type: join.JoinType,
|
||||||
|
Expression: join.Expression,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
exprs := make([]clause.Expression, len(relation.References))
|
exprs := make([]clause.Expression, len(relation.References))
|
||||||
for idx, ref := range relation.References {
|
for idx, ref := range relation.References {
|
||||||
if ref.OwnPrimaryKey {
|
if ref.OwnPrimaryKey {
|
||||||
@ -226,19 +232,24 @@ func BuildQuerySQL(db *gorm.DB) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
parentTableName := clause.CurrentTable
|
parentTableName := clause.CurrentTable
|
||||||
for _, rel := range relations {
|
for idx, rel := range relations {
|
||||||
// joins table alias like "Manager, Company, Manager__Company"
|
// joins table alias like "Manager, Company, Manager__Company"
|
||||||
nestedAlias := utils.NestedRelationName(parentTableName, rel.Name)
|
curAliasName := rel.Name
|
||||||
if _, ok := specifiedRelationsName[nestedAlias]; !ok {
|
if parentTableName != clause.CurrentTable {
|
||||||
fromClause.Joins = append(fromClause.Joins, genJoinClause(join.JoinType, parentTableName, rel))
|
curAliasName = utils.NestedRelationName(parentTableName, curAliasName)
|
||||||
specifiedRelationsName[nestedAlias] = nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if parentTableName != clause.CurrentTable {
|
if _, ok := specifiedRelationsName[curAliasName]; !ok {
|
||||||
parentTableName = utils.NestedRelationName(parentTableName, rel.Name)
|
aliasName := curAliasName
|
||||||
} else {
|
if idx == len(relations)-1 && join.Alias != "" {
|
||||||
parentTableName = rel.Name
|
aliasName = join.Alias
|
||||||
|
}
|
||||||
|
|
||||||
|
fromClause.Joins = append(fromClause.Joins, genJoinClause(join.JoinType, aliasName, specifiedRelationsName[parentTableName], rel))
|
||||||
|
specifiedRelationsName[curAliasName] = aliasName
|
||||||
}
|
}
|
||||||
|
|
||||||
|
parentTableName = curAliasName
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
fromClause.Joins = append(fromClause.Joins, clause.Join{
|
fromClause.Joins = append(fromClause.Joins, clause.Join{
|
||||||
|
@ -13,5 +13,10 @@ func RawExec(db *gorm.DB) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
db.RowsAffected, _ = result.RowsAffected()
|
db.RowsAffected, _ = result.RowsAffected()
|
||||||
|
|
||||||
|
if db.Statement.Result != nil {
|
||||||
|
db.Statement.Result.Result = result
|
||||||
|
db.Statement.Result.RowsAffected = db.RowsAffected
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -92,6 +92,10 @@ func Update(config *Config) func(db *gorm.DB) {
|
|||||||
gorm.Scan(rows, db, mode)
|
gorm.Scan(rows, db, mode)
|
||||||
db.Statement.Dest = dest
|
db.Statement.Dest = dest
|
||||||
db.AddError(rows.Close())
|
db.AddError(rows.Close())
|
||||||
|
|
||||||
|
if db.Statement.Result != nil {
|
||||||
|
db.Statement.Result.RowsAffected = db.RowsAffected
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||||
@ -99,6 +103,11 @@ func Update(config *Config) func(db *gorm.DB) {
|
|||||||
if db.AddError(err) == nil {
|
if db.AddError(err) == nil {
|
||||||
db.RowsAffected, _ = result.RowsAffected()
|
db.RowsAffected, _ = result.RowsAffected()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if db.Statement.Result != nil {
|
||||||
|
db.Statement.Result.Result = result
|
||||||
|
db.Statement.Result.RowsAffected = db.RowsAffected
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -448,9 +448,10 @@ func (db *DB) Assign(attrs ...interface{}) (tx *DB) {
|
|||||||
// Unscoped allows queries to include records marked as deleted,
|
// Unscoped allows queries to include records marked as deleted,
|
||||||
// overriding the soft deletion behavior.
|
// overriding the soft deletion behavior.
|
||||||
// Example:
|
// Example:
|
||||||
// var users []User
|
//
|
||||||
// db.Unscoped().Find(&users)
|
// var users []User
|
||||||
// // Retrieves all users, including deleted ones.
|
// db.Unscoped().Find(&users)
|
||||||
|
// // Retrieves all users, including deleted ones.
|
||||||
func (db *DB) Unscoped() (tx *DB) {
|
func (db *DB) Unscoped() (tx *DB) {
|
||||||
tx = db.getInstance()
|
tx = db.getInstance()
|
||||||
tx.Statement.Unscoped = true
|
tx.Statement.Unscoped = true
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
package clause
|
package clause
|
||||||
|
|
||||||
|
import "gorm.io/gorm/utils"
|
||||||
|
|
||||||
type JoinType string
|
type JoinType string
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -9,6 +11,30 @@ const (
|
|||||||
RightJoin JoinType = "RIGHT"
|
RightJoin JoinType = "RIGHT"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type JoinTarget struct {
|
||||||
|
Type JoinType
|
||||||
|
Association string
|
||||||
|
Subquery Expression
|
||||||
|
Table string
|
||||||
|
}
|
||||||
|
|
||||||
|
func Has(name string) JoinTarget {
|
||||||
|
return JoinTarget{Type: InnerJoin, Association: name}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (jt JoinType) Association(name string) JoinTarget {
|
||||||
|
return JoinTarget{Type: jt, Association: name}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (jt JoinType) AssociationFrom(name string, subquery Expression) JoinTarget {
|
||||||
|
return JoinTarget{Type: jt, Association: name, Subquery: subquery}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (jt JoinTarget) As(name string) JoinTarget {
|
||||||
|
jt.Table = name
|
||||||
|
return jt
|
||||||
|
}
|
||||||
|
|
||||||
// Join clause for from
|
// Join clause for from
|
||||||
type Join struct {
|
type Join struct {
|
||||||
Type JoinType
|
Type JoinType
|
||||||
@ -18,6 +44,12 @@ type Join struct {
|
|||||||
Expression Expression
|
Expression Expression
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func JoinTable(names ...string) Table {
|
||||||
|
return Table{
|
||||||
|
Name: utils.JoinNestedRelationNames(names),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (join Join) Build(builder Builder) {
|
func (join Join) Build(builder Builder) {
|
||||||
if join.Expression != nil {
|
if join.Expression != nil {
|
||||||
join.Expression.Build(builder)
|
join.Expression.Build(builder)
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package gorm
|
package gorm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
@ -673,11 +674,18 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB {
|
|||||||
opt = opts[0]
|
opt = opts[0]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ctx := tx.Statement.Context
|
||||||
|
if _, ok := ctx.Deadline(); !ok {
|
||||||
|
if db.Config.DefaultTransactionTimeout > 0 {
|
||||||
|
ctx, _ = context.WithTimeout(ctx, db.Config.DefaultTransactionTimeout)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
switch beginner := tx.Statement.ConnPool.(type) {
|
switch beginner := tx.Statement.ConnPool.(type) {
|
||||||
case TxBeginner:
|
case TxBeginner:
|
||||||
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
|
tx.Statement.ConnPool, err = beginner.BeginTx(ctx, opt)
|
||||||
case ConnPoolBeginner:
|
case ConnPoolBeginner:
|
||||||
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
|
tx.Statement.ConnPool, err = beginner.BeginTx(ctx, opt)
|
||||||
default:
|
default:
|
||||||
err = ErrInvalidTransaction
|
err = ErrInvalidTransaction
|
||||||
}
|
}
|
||||||
|
605
generics.go
Normal file
605
generics.go
Normal file
@ -0,0 +1,605 @@
|
|||||||
|
package gorm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"gorm.io/gorm/clause"
|
||||||
|
"gorm.io/gorm/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
type result struct {
|
||||||
|
Result sql.Result
|
||||||
|
RowsAffected int64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (info *result) ModifyStatement(stmt *Statement) {
|
||||||
|
stmt.Result = info
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build implements clause.Expression interface
|
||||||
|
func (result) Build(clause.Builder) {
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithResult() *result {
|
||||||
|
return &result{}
|
||||||
|
}
|
||||||
|
|
||||||
|
type Interface[T any] interface {
|
||||||
|
Raw(sql string, values ...interface{}) ExecInterface[T]
|
||||||
|
Exec(ctx context.Context, sql string, values ...interface{}) error
|
||||||
|
CreateInterface[T]
|
||||||
|
}
|
||||||
|
|
||||||
|
type CreateInterface[T any] interface {
|
||||||
|
ChainInterface[T]
|
||||||
|
Table(name string, args ...interface{}) CreateInterface[T]
|
||||||
|
Create(ctx context.Context, r *T) error
|
||||||
|
CreateInBatches(ctx context.Context, r *[]T, batchSize int) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type ChainInterface[T any] interface {
|
||||||
|
ExecInterface[T]
|
||||||
|
Scopes(scopes ...func(db *Statement)) ChainInterface[T]
|
||||||
|
Where(query interface{}, args ...interface{}) ChainInterface[T]
|
||||||
|
Not(query interface{}, args ...interface{}) ChainInterface[T]
|
||||||
|
Or(query interface{}, args ...interface{}) ChainInterface[T]
|
||||||
|
Limit(offset int) ChainInterface[T]
|
||||||
|
Offset(offset int) ChainInterface[T]
|
||||||
|
Joins(query clause.JoinTarget, on func(db JoinBuilder, joinTable clause.Table, curTable clause.Table) error) ChainInterface[T]
|
||||||
|
Preload(association string, query func(db PreloadBuilder) error) ChainInterface[T]
|
||||||
|
Select(query string, args ...interface{}) ChainInterface[T]
|
||||||
|
Omit(columns ...string) ChainInterface[T]
|
||||||
|
MapColumns(m map[string]string) ChainInterface[T]
|
||||||
|
Distinct(args ...interface{}) ChainInterface[T]
|
||||||
|
Group(name string) ChainInterface[T]
|
||||||
|
Having(query interface{}, args ...interface{}) ChainInterface[T]
|
||||||
|
Order(value interface{}) ChainInterface[T]
|
||||||
|
|
||||||
|
Build(builder clause.Builder)
|
||||||
|
|
||||||
|
Delete(ctx context.Context) (rowsAffected int, err error)
|
||||||
|
Update(ctx context.Context, name string, value any) (rowsAffected int, err error)
|
||||||
|
Updates(ctx context.Context, t T) (rowsAffected int, err error)
|
||||||
|
Count(ctx context.Context, column string) (result int64, err error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type ExecInterface[T any] interface {
|
||||||
|
Scan(ctx context.Context, r interface{}) error
|
||||||
|
First(context.Context) (T, error)
|
||||||
|
Last(ctx context.Context) (T, error)
|
||||||
|
Take(context.Context) (T, error)
|
||||||
|
Find(ctx context.Context) ([]T, error)
|
||||||
|
FindInBatches(ctx context.Context, batchSize int, fc func(data []T, batch int) error) error
|
||||||
|
Row(ctx context.Context) *sql.Row
|
||||||
|
Rows(ctx context.Context) (*sql.Rows, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type JoinBuilder interface {
|
||||||
|
Select(...string) JoinBuilder
|
||||||
|
Omit(...string) JoinBuilder
|
||||||
|
Where(query interface{}, args ...interface{}) JoinBuilder
|
||||||
|
Not(query interface{}, args ...interface{}) JoinBuilder
|
||||||
|
Or(query interface{}, args ...interface{}) JoinBuilder
|
||||||
|
}
|
||||||
|
|
||||||
|
type PreloadBuilder interface {
|
||||||
|
Select(...string) PreloadBuilder
|
||||||
|
Omit(...string) PreloadBuilder
|
||||||
|
Where(query interface{}, args ...interface{}) PreloadBuilder
|
||||||
|
Not(query interface{}, args ...interface{}) PreloadBuilder
|
||||||
|
Or(query interface{}, args ...interface{}) PreloadBuilder
|
||||||
|
Limit(offset int) PreloadBuilder
|
||||||
|
Offset(offset int) PreloadBuilder
|
||||||
|
Order(value interface{}) PreloadBuilder
|
||||||
|
LimitPerRecord(num int) PreloadBuilder
|
||||||
|
}
|
||||||
|
|
||||||
|
type op func(*DB) *DB
|
||||||
|
|
||||||
|
func G[T any](db *DB, opts ...clause.Expression) Interface[T] {
|
||||||
|
v := &g[T]{
|
||||||
|
db: db,
|
||||||
|
ops: make([]op, 0, 5),
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(opts) > 0 {
|
||||||
|
v.ops = append(v.ops, func(db *DB) *DB {
|
||||||
|
return db.Clauses(opts...)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
v.createG = &createG[T]{
|
||||||
|
chainG: chainG[T]{
|
||||||
|
execG: execG[T]{g: v},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
|
||||||
|
type g[T any] struct {
|
||||||
|
*createG[T]
|
||||||
|
db *DB
|
||||||
|
ops []op
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *g[T]) apply(ctx context.Context) *DB {
|
||||||
|
db := g.db
|
||||||
|
if !db.DryRun {
|
||||||
|
db = db.Session(&Session{NewDB: true, Context: ctx}).getInstance()
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, op := range g.ops {
|
||||||
|
db = op(db)
|
||||||
|
}
|
||||||
|
return db
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *g[T]) Raw(sql string, values ...interface{}) ExecInterface[T] {
|
||||||
|
return execG[T]{g: &g[T]{
|
||||||
|
db: c.db,
|
||||||
|
ops: append(c.ops, func(db *DB) *DB {
|
||||||
|
return db.Raw(sql, values...)
|
||||||
|
}),
|
||||||
|
}}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *g[T]) Exec(ctx context.Context, sql string, values ...interface{}) error {
|
||||||
|
return c.apply(ctx).Exec(sql, values...).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
type createG[T any] struct {
|
||||||
|
chainG[T]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c createG[T]) Table(name string, args ...interface{}) CreateInterface[T] {
|
||||||
|
return createG[T]{c.with(func(db *DB) *DB {
|
||||||
|
return db.Table(name, args...)
|
||||||
|
})}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c createG[T]) Create(ctx context.Context, r *T) error {
|
||||||
|
return c.g.apply(ctx).Create(r).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c createG[T]) CreateInBatches(ctx context.Context, r *[]T, batchSize int) error {
|
||||||
|
return c.g.apply(ctx).CreateInBatches(r, batchSize).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
type chainG[T any] struct {
|
||||||
|
execG[T]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c chainG[T]) getInstance() *DB {
|
||||||
|
var r T
|
||||||
|
return c.g.apply(context.Background()).Model(r).getInstance()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c chainG[T]) with(v op) chainG[T] {
|
||||||
|
return chainG[T]{
|
||||||
|
execG: execG[T]{g: &g[T]{
|
||||||
|
db: c.g.db,
|
||||||
|
ops: append(append([]op(nil), c.g.ops...), v),
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c chainG[T]) Scopes(scopes ...func(db *Statement)) ChainInterface[T] {
|
||||||
|
return c.with(func(db *DB) *DB {
|
||||||
|
for _, fc := range scopes {
|
||||||
|
fc(db.Statement)
|
||||||
|
}
|
||||||
|
return db
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c chainG[T]) Table(name string, args ...interface{}) ChainInterface[T] {
|
||||||
|
return c.with(func(db *DB) *DB {
|
||||||
|
return db.Table(name, args...)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c chainG[T]) Where(query interface{}, args ...interface{}) ChainInterface[T] {
|
||||||
|
return c.with(func(db *DB) *DB {
|
||||||
|
return db.Where(query, args...)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c chainG[T]) Not(query interface{}, args ...interface{}) ChainInterface[T] {
|
||||||
|
return c.with(func(db *DB) *DB {
|
||||||
|
return db.Not(query, args...)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c chainG[T]) Or(query interface{}, args ...interface{}) ChainInterface[T] {
|
||||||
|
return c.with(func(db *DB) *DB {
|
||||||
|
return db.Or(query, args...)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c chainG[T]) Limit(offset int) ChainInterface[T] {
|
||||||
|
return c.with(func(db *DB) *DB {
|
||||||
|
return db.Limit(offset)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c chainG[T]) Offset(offset int) ChainInterface[T] {
|
||||||
|
return c.with(func(db *DB) *DB {
|
||||||
|
return db.Offset(offset)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
type joinBuilder struct {
|
||||||
|
db *DB
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *joinBuilder) Where(query interface{}, args ...interface{}) JoinBuilder {
|
||||||
|
q.db.Where(query, args...)
|
||||||
|
return q
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *joinBuilder) Or(query interface{}, args ...interface{}) JoinBuilder {
|
||||||
|
q.db.Where(query, args...)
|
||||||
|
return q
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *joinBuilder) Not(query interface{}, args ...interface{}) JoinBuilder {
|
||||||
|
q.db.Where(query, args...)
|
||||||
|
return q
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *joinBuilder) Select(columns ...string) JoinBuilder {
|
||||||
|
q.db.Select(columns)
|
||||||
|
return q
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *joinBuilder) Omit(columns ...string) JoinBuilder {
|
||||||
|
q.db.Omit(columns...)
|
||||||
|
return q
|
||||||
|
}
|
||||||
|
|
||||||
|
type preloadBuilder struct {
|
||||||
|
limitPerRecord int
|
||||||
|
db *DB
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *preloadBuilder) Where(query interface{}, args ...interface{}) PreloadBuilder {
|
||||||
|
q.db.Where(query, args...)
|
||||||
|
return q
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *preloadBuilder) Or(query interface{}, args ...interface{}) PreloadBuilder {
|
||||||
|
q.db.Where(query, args...)
|
||||||
|
return q
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *preloadBuilder) Not(query interface{}, args ...interface{}) PreloadBuilder {
|
||||||
|
q.db.Where(query, args...)
|
||||||
|
return q
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *preloadBuilder) Select(columns ...string) PreloadBuilder {
|
||||||
|
q.db.Select(columns)
|
||||||
|
return q
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *preloadBuilder) Omit(columns ...string) PreloadBuilder {
|
||||||
|
q.db.Omit(columns...)
|
||||||
|
return q
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *preloadBuilder) Limit(limit int) PreloadBuilder {
|
||||||
|
q.db.Limit(limit)
|
||||||
|
return q
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *preloadBuilder) Offset(offset int) PreloadBuilder {
|
||||||
|
q.db.Offset(offset)
|
||||||
|
return q
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *preloadBuilder) Order(value interface{}) PreloadBuilder {
|
||||||
|
q.db.Order(value)
|
||||||
|
return q
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *preloadBuilder) LimitPerRecord(num int) PreloadBuilder {
|
||||||
|
q.limitPerRecord = num
|
||||||
|
return q
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c chainG[T]) Joins(jt clause.JoinTarget, on func(db JoinBuilder, joinTable clause.Table, curTable clause.Table) error) ChainInterface[T] {
|
||||||
|
return c.with(func(db *DB) *DB {
|
||||||
|
if jt.Table == "" {
|
||||||
|
jt.Table = clause.JoinTable(strings.Split(jt.Association, ".")...).Name
|
||||||
|
}
|
||||||
|
|
||||||
|
q := joinBuilder{db: db.Session(&Session{NewDB: true, Initialized: true}).Table(jt.Table)}
|
||||||
|
if on != nil {
|
||||||
|
if err := on(&q, clause.Table{Name: jt.Table}, clause.Table{Name: clause.CurrentTable}); err != nil {
|
||||||
|
db.AddError(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
j := join{
|
||||||
|
Name: jt.Association,
|
||||||
|
Alias: jt.Table,
|
||||||
|
Selects: q.db.Statement.Selects,
|
||||||
|
Omits: q.db.Statement.Omits,
|
||||||
|
JoinType: jt.Type,
|
||||||
|
}
|
||||||
|
|
||||||
|
if where, ok := q.db.Statement.Clauses["WHERE"].Expression.(clause.Where); ok {
|
||||||
|
j.On = &where
|
||||||
|
}
|
||||||
|
|
||||||
|
if jt.Subquery != nil {
|
||||||
|
joinType := j.JoinType
|
||||||
|
if joinType == "" {
|
||||||
|
joinType = clause.LeftJoin
|
||||||
|
}
|
||||||
|
|
||||||
|
if db, ok := jt.Subquery.(interface{ getInstance() *DB }); ok {
|
||||||
|
stmt := db.getInstance().Statement
|
||||||
|
if len(j.Selects) == 0 {
|
||||||
|
j.Selects = stmt.Selects
|
||||||
|
}
|
||||||
|
if len(j.Omits) == 0 {
|
||||||
|
j.Omits = stmt.Omits
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
expr := clause.NamedExpr{SQL: fmt.Sprintf("%s JOIN (?) AS ?", joinType), Vars: []interface{}{jt.Subquery, clause.Table{Name: j.Alias}}}
|
||||||
|
|
||||||
|
if j.On != nil {
|
||||||
|
expr.SQL += " ON ?"
|
||||||
|
expr.Vars = append(expr.Vars, clause.AndConditions{Exprs: j.On.Exprs})
|
||||||
|
}
|
||||||
|
|
||||||
|
j.Expression = expr
|
||||||
|
}
|
||||||
|
|
||||||
|
db.Statement.Joins = append(db.Statement.Joins, j)
|
||||||
|
sort.Slice(db.Statement.Joins, func(i, j int) bool {
|
||||||
|
return db.Statement.Joins[i].Name < db.Statement.Joins[j].Name
|
||||||
|
})
|
||||||
|
return db
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c chainG[T]) Select(query string, args ...interface{}) ChainInterface[T] {
|
||||||
|
return c.with(func(db *DB) *DB {
|
||||||
|
return db.Select(query, args...)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c chainG[T]) Omit(columns ...string) ChainInterface[T] {
|
||||||
|
return c.with(func(db *DB) *DB {
|
||||||
|
return db.Omit(columns...)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c chainG[T]) MapColumns(m map[string]string) ChainInterface[T] {
|
||||||
|
return c.with(func(db *DB) *DB {
|
||||||
|
return db.MapColumns(m)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c chainG[T]) Distinct(args ...interface{}) ChainInterface[T] {
|
||||||
|
return c.with(func(db *DB) *DB {
|
||||||
|
return db.Distinct(args...)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c chainG[T]) Group(name string) ChainInterface[T] {
|
||||||
|
return c.with(func(db *DB) *DB {
|
||||||
|
return db.Group(name)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c chainG[T]) Having(query interface{}, args ...interface{}) ChainInterface[T] {
|
||||||
|
return c.with(func(db *DB) *DB {
|
||||||
|
return db.Having(query, args...)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c chainG[T]) Order(value interface{}) ChainInterface[T] {
|
||||||
|
return c.with(func(db *DB) *DB {
|
||||||
|
return db.Order(value)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c chainG[T]) Preload(association string, query func(db PreloadBuilder) error) ChainInterface[T] {
|
||||||
|
return c.with(func(db *DB) *DB {
|
||||||
|
return db.Preload(association, func(tx *DB) *DB {
|
||||||
|
q := preloadBuilder{db: tx.getInstance()}
|
||||||
|
if query != nil {
|
||||||
|
if err := query(&q); err != nil {
|
||||||
|
db.AddError(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
relation, ok := db.Statement.Schema.Relationships.Relations[association]
|
||||||
|
if !ok {
|
||||||
|
if preloadFields := strings.Split(association, "."); len(preloadFields) > 1 {
|
||||||
|
relationships := db.Statement.Schema.Relationships
|
||||||
|
for _, field := range preloadFields {
|
||||||
|
var ok bool
|
||||||
|
relation, ok = relationships.Relations[field]
|
||||||
|
if ok {
|
||||||
|
relationships = relation.FieldSchema.Relationships
|
||||||
|
} else {
|
||||||
|
db.AddError(fmt.Errorf("relation %s not found", association))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
db.AddError(fmt.Errorf("relation %s not found", association))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if q.limitPerRecord > 0 {
|
||||||
|
if relation.JoinTable != nil {
|
||||||
|
tx.AddError(fmt.Errorf("many2many relation %s don't support LimitPerRecord", association))
|
||||||
|
return tx
|
||||||
|
}
|
||||||
|
|
||||||
|
refColumns := []clause.Column{}
|
||||||
|
for _, rel := range relation.References {
|
||||||
|
if rel.OwnPrimaryKey {
|
||||||
|
refColumns = append(refColumns, clause.Column{Name: rel.ForeignKey.DBName})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(refColumns) != 0 {
|
||||||
|
selectExpr := clause.CommaExpression{}
|
||||||
|
for _, column := range q.db.Statement.Selects {
|
||||||
|
selectExpr.Exprs = append(selectExpr.Exprs, clause.Expr{SQL: "?", Vars: []interface{}{clause.Column{Name: column}}})
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(selectExpr.Exprs) == 0 {
|
||||||
|
selectExpr.Exprs = []clause.Expression{clause.Expr{SQL: "*", Vars: []interface{}{}}}
|
||||||
|
}
|
||||||
|
|
||||||
|
partitionBy := clause.CommaExpression{}
|
||||||
|
for _, column := range refColumns {
|
||||||
|
partitionBy.Exprs = append(partitionBy.Exprs, clause.Expr{SQL: "?", Vars: []interface{}{clause.Column{Name: column.Name}}})
|
||||||
|
}
|
||||||
|
|
||||||
|
rnnColumn := clause.Column{Name: "gorm_preload_rnn"}
|
||||||
|
sql := "ROW_NUMBER() OVER (PARTITION BY ? ?)"
|
||||||
|
vars := []interface{}{partitionBy}
|
||||||
|
if orderBy, ok := q.db.Statement.Clauses["ORDER BY"]; ok {
|
||||||
|
vars = append(vars, orderBy)
|
||||||
|
} else {
|
||||||
|
vars = append(vars, clause.Clause{Name: "ORDER BY", Expression: clause.OrderBy{
|
||||||
|
Columns: []clause.OrderByColumn{{Column: clause.PrimaryColumn, Desc: true}},
|
||||||
|
}})
|
||||||
|
}
|
||||||
|
vars = append(vars, rnnColumn)
|
||||||
|
|
||||||
|
selectExpr.Exprs = append(selectExpr.Exprs, clause.Expr{SQL: sql + " AS ?", Vars: vars})
|
||||||
|
|
||||||
|
q.db.Clauses(clause.Select{Expression: selectExpr})
|
||||||
|
|
||||||
|
return q.db.Session(&Session{NewDB: true}).Unscoped().Table("(?) t", q.db).Where("? <= ?", rnnColumn, q.limitPerRecord)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return q.db
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c chainG[T]) Delete(ctx context.Context) (rowsAffected int, err error) {
|
||||||
|
r := new(T)
|
||||||
|
res := c.g.apply(ctx).Delete(r)
|
||||||
|
return int(res.RowsAffected), res.Error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c chainG[T]) Update(ctx context.Context, name string, value any) (rowsAffected int, err error) {
|
||||||
|
var r T
|
||||||
|
res := c.g.apply(ctx).Model(r).Update(name, value)
|
||||||
|
return int(res.RowsAffected), res.Error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c chainG[T]) Updates(ctx context.Context, t T) (rowsAffected int, err error) {
|
||||||
|
res := c.g.apply(ctx).Updates(t)
|
||||||
|
return int(res.RowsAffected), res.Error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c chainG[T]) Count(ctx context.Context, column string) (result int64, err error) {
|
||||||
|
var r T
|
||||||
|
err = c.g.apply(ctx).Model(r).Select(column).Count(&result).Error
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c chainG[T]) Build(builder clause.Builder) {
|
||||||
|
subdb := c.getInstance()
|
||||||
|
subdb.Logger = logger.Discard
|
||||||
|
subdb.DryRun = true
|
||||||
|
|
||||||
|
if stmt, ok := builder.(*Statement); ok {
|
||||||
|
if subdb.Statement.SQL.Len() > 0 {
|
||||||
|
var (
|
||||||
|
vars = subdb.Statement.Vars
|
||||||
|
sql = subdb.Statement.SQL.String()
|
||||||
|
)
|
||||||
|
|
||||||
|
subdb.Statement.Vars = make([]interface{}, 0, len(vars))
|
||||||
|
for _, vv := range vars {
|
||||||
|
subdb.Statement.Vars = append(subdb.Statement.Vars, vv)
|
||||||
|
bindvar := strings.Builder{}
|
||||||
|
subdb.BindVarTo(&bindvar, subdb.Statement, vv)
|
||||||
|
sql = strings.Replace(sql, bindvar.String(), "?", 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
subdb.Statement.SQL.Reset()
|
||||||
|
subdb.Statement.Vars = stmt.Vars
|
||||||
|
if strings.Contains(sql, "@") {
|
||||||
|
clause.NamedExpr{SQL: sql, Vars: vars}.Build(subdb.Statement)
|
||||||
|
} else {
|
||||||
|
clause.Expr{SQL: sql, Vars: vars}.Build(subdb.Statement)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
subdb.Statement.Vars = append(stmt.Vars, subdb.Statement.Vars...)
|
||||||
|
subdb.callbacks.Query().Execute(subdb)
|
||||||
|
}
|
||||||
|
|
||||||
|
builder.WriteString(subdb.Statement.SQL.String())
|
||||||
|
stmt.Vars = subdb.Statement.Vars
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type execG[T any] struct {
|
||||||
|
g *g[T]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g execG[T]) First(ctx context.Context) (T, error) {
|
||||||
|
var r T
|
||||||
|
err := g.g.apply(ctx).First(&r).Error
|
||||||
|
return r, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g execG[T]) Scan(ctx context.Context, result interface{}) error {
|
||||||
|
var r T
|
||||||
|
err := g.g.apply(ctx).Model(r).Find(&result).Error
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g execG[T]) Last(ctx context.Context) (T, error) {
|
||||||
|
var r T
|
||||||
|
err := g.g.apply(ctx).Last(&r).Error
|
||||||
|
return r, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g execG[T]) Take(ctx context.Context) (T, error) {
|
||||||
|
var r T
|
||||||
|
err := g.g.apply(ctx).Take(&r).Error
|
||||||
|
return r, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g execG[T]) Find(ctx context.Context) ([]T, error) {
|
||||||
|
var r []T
|
||||||
|
err := g.g.apply(ctx).Find(&r).Error
|
||||||
|
return r, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g execG[T]) FindInBatches(ctx context.Context, batchSize int, fc func(data []T, batch int) error) error {
|
||||||
|
var data []T
|
||||||
|
return g.g.apply(ctx).FindInBatches(&data, batchSize, func(tx *DB, batch int) error {
|
||||||
|
return fc(data, batch)
|
||||||
|
}).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g execG[T]) Row(ctx context.Context) *sql.Row {
|
||||||
|
return g.g.apply(ctx).Row()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g execG[T]) Rows(ctx context.Context) (*sql.Rows, error) {
|
||||||
|
return g.g.apply(ctx).Rows()
|
||||||
|
}
|
6
gorm.go
6
gorm.go
@ -21,7 +21,9 @@ const preparedStmtDBKey = "preparedStmt"
|
|||||||
type Config struct {
|
type Config struct {
|
||||||
// GORM perform single create, update, delete operations in transactions by default to ensure database data integrity
|
// GORM perform single create, update, delete operations in transactions by default to ensure database data integrity
|
||||||
// You can disable it by setting `SkipDefaultTransaction` to true
|
// You can disable it by setting `SkipDefaultTransaction` to true
|
||||||
SkipDefaultTransaction bool
|
SkipDefaultTransaction bool
|
||||||
|
DefaultTransactionTimeout time.Duration
|
||||||
|
|
||||||
// NamingStrategy tables, columns naming strategy
|
// NamingStrategy tables, columns naming strategy
|
||||||
NamingStrategy schema.Namer
|
NamingStrategy schema.Namer
|
||||||
// FullSaveAssociations full save associations
|
// FullSaveAssociations full save associations
|
||||||
@ -527,7 +529,7 @@ func (db *DB) Use(plugin Plugin) error {
|
|||||||
// .First(&User{})
|
// .First(&User{})
|
||||||
// })
|
// })
|
||||||
func (db *DB) ToSQL(queryFn func(tx *DB) *DB) string {
|
func (db *DB) ToSQL(queryFn func(tx *DB) *DB) string {
|
||||||
tx := queryFn(db.Session(&Session{DryRun: true, SkipDefaultTransaction: true}))
|
tx := queryFn(db.Session(&Session{DryRun: true, SkipDefaultTransaction: true}).getInstance())
|
||||||
stmt := tx.Statement
|
stmt := tx.Statement
|
||||||
|
|
||||||
return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...)
|
return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...)
|
||||||
|
9
scan.go
9
scan.go
@ -4,6 +4,7 @@ import (
|
|||||||
"database/sql"
|
"database/sql"
|
||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"gorm.io/gorm/schema"
|
"gorm.io/gorm/schema"
|
||||||
@ -244,6 +245,14 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
|
|||||||
matchedFieldCount[column] = 1
|
matchedFieldCount[column] = 1
|
||||||
}
|
}
|
||||||
} else if names := utils.SplitNestedRelationName(column); len(names) > 1 { // has nested relation
|
} else if names := utils.SplitNestedRelationName(column); len(names) > 1 { // has nested relation
|
||||||
|
aliasName := utils.JoinNestedRelationNames(names[0 : len(names)-1])
|
||||||
|
for _, join := range db.Statement.Joins {
|
||||||
|
if join.Alias == aliasName {
|
||||||
|
names = append(strings.Split(join.Name, "."), names[len(names)-1])
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if rel, ok := sch.Relationships.Relations[names[0]]; ok {
|
if rel, ok := sch.Relationships.Relations[names[0]]; ok {
|
||||||
subNameCount := len(names)
|
subNameCount := len(names)
|
||||||
// nested relation fields
|
// nested relation fields
|
||||||
|
57
statement.go
57
statement.go
@ -47,15 +47,18 @@ type Statement struct {
|
|||||||
attrs []interface{}
|
attrs []interface{}
|
||||||
assigns []interface{}
|
assigns []interface{}
|
||||||
scopes []func(*DB) *DB
|
scopes []func(*DB) *DB
|
||||||
|
Result *result
|
||||||
}
|
}
|
||||||
|
|
||||||
type join struct {
|
type join struct {
|
||||||
Name string
|
Name string
|
||||||
Conds []interface{}
|
Alias string
|
||||||
On *clause.Where
|
Conds []interface{}
|
||||||
Selects []string
|
On *clause.Where
|
||||||
Omits []string
|
Selects []string
|
||||||
JoinType clause.JoinType
|
Omits []string
|
||||||
|
Expression clause.Expression
|
||||||
|
JoinType clause.JoinType
|
||||||
}
|
}
|
||||||
|
|
||||||
// StatementModifier statement modifier interface
|
// StatementModifier statement modifier interface
|
||||||
@ -205,19 +208,21 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) {
|
|||||||
} else {
|
} else {
|
||||||
writer.WriteString("(NULL)")
|
writer.WriteString("(NULL)")
|
||||||
}
|
}
|
||||||
case *DB:
|
case interface{ getInstance() *DB }:
|
||||||
subdb := v.Session(&Session{Logger: logger.Discard, DryRun: true}).getInstance()
|
cv := v.getInstance()
|
||||||
if v.Statement.SQL.Len() > 0 {
|
|
||||||
|
subdb := cv.Session(&Session{Logger: logger.Discard, DryRun: true}).getInstance()
|
||||||
|
if cv.Statement.SQL.Len() > 0 {
|
||||||
var (
|
var (
|
||||||
vars = subdb.Statement.Vars
|
vars = subdb.Statement.Vars
|
||||||
sql = v.Statement.SQL.String()
|
sql = cv.Statement.SQL.String()
|
||||||
)
|
)
|
||||||
|
|
||||||
subdb.Statement.Vars = make([]interface{}, 0, len(vars))
|
subdb.Statement.Vars = make([]interface{}, 0, len(vars))
|
||||||
for _, vv := range vars {
|
for _, vv := range vars {
|
||||||
subdb.Statement.Vars = append(subdb.Statement.Vars, vv)
|
subdb.Statement.Vars = append(subdb.Statement.Vars, vv)
|
||||||
bindvar := strings.Builder{}
|
bindvar := strings.Builder{}
|
||||||
v.Dialector.BindVarTo(&bindvar, subdb.Statement, vv)
|
cv.BindVarTo(&bindvar, subdb.Statement, vv)
|
||||||
sql = strings.Replace(sql, bindvar.String(), "?", 1)
|
sql = strings.Replace(sql, bindvar.String(), "?", 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -321,6 +326,11 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
|
|||||||
arg, _ = valuer.Value()
|
arg, _ = valuer.Value()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
curTable := stmt.Table
|
||||||
|
if curTable == "" {
|
||||||
|
curTable = clause.CurrentTable
|
||||||
|
}
|
||||||
|
|
||||||
switch v := arg.(type) {
|
switch v := arg.(type) {
|
||||||
case clause.Expression:
|
case clause.Expression:
|
||||||
conds = append(conds, v)
|
conds = append(conds, v)
|
||||||
@ -351,7 +361,8 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
|
|||||||
sort.Strings(keys)
|
sort.Strings(keys)
|
||||||
|
|
||||||
for _, key := range keys {
|
for _, key := range keys {
|
||||||
conds = append(conds, clause.Eq{Column: key, Value: v[key]})
|
column := clause.Column{Name: key, Table: curTable}
|
||||||
|
conds = append(conds, clause.Eq{Column: column, Value: v[key]})
|
||||||
}
|
}
|
||||||
case map[string]interface{}:
|
case map[string]interface{}:
|
||||||
keys := make([]string, 0, len(v))
|
keys := make([]string, 0, len(v))
|
||||||
@ -362,12 +373,13 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
|
|||||||
|
|
||||||
for _, key := range keys {
|
for _, key := range keys {
|
||||||
reflectValue := reflect.Indirect(reflect.ValueOf(v[key]))
|
reflectValue := reflect.Indirect(reflect.ValueOf(v[key]))
|
||||||
|
column := clause.Column{Name: key, Table: curTable}
|
||||||
switch reflectValue.Kind() {
|
switch reflectValue.Kind() {
|
||||||
case reflect.Slice, reflect.Array:
|
case reflect.Slice, reflect.Array:
|
||||||
if _, ok := v[key].(driver.Valuer); ok {
|
if _, ok := v[key].(driver.Valuer); ok {
|
||||||
conds = append(conds, clause.Eq{Column: key, Value: v[key]})
|
conds = append(conds, clause.Eq{Column: column, Value: v[key]})
|
||||||
} else if _, ok := v[key].(Valuer); ok {
|
} else if _, ok := v[key].(Valuer); ok {
|
||||||
conds = append(conds, clause.Eq{Column: key, Value: v[key]})
|
conds = append(conds, clause.Eq{Column: column, Value: v[key]})
|
||||||
} else {
|
} else {
|
||||||
// optimize reflect value length
|
// optimize reflect value length
|
||||||
valueLen := reflectValue.Len()
|
valueLen := reflectValue.Len()
|
||||||
@ -376,10 +388,10 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
|
|||||||
values[i] = reflectValue.Index(i).Interface()
|
values[i] = reflectValue.Index(i).Interface()
|
||||||
}
|
}
|
||||||
|
|
||||||
conds = append(conds, clause.IN{Column: key, Values: values})
|
conds = append(conds, clause.IN{Column: column, Values: values})
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
conds = append(conds, clause.Eq{Column: key, Value: v[key]})
|
conds = append(conds, clause.Eq{Column: column, Value: v[key]})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
@ -406,9 +418,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
|
|||||||
if selected || (!restricted && field.Readable) {
|
if selected || (!restricted && field.Readable) {
|
||||||
if v, isZero := field.ValueOf(stmt.Context, reflectValue); !isZero || selected {
|
if v, isZero := field.ValueOf(stmt.Context, reflectValue); !isZero || selected {
|
||||||
if field.DBName != "" {
|
if field.DBName != "" {
|
||||||
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v})
|
conds = append(conds, clause.Eq{Column: clause.Column{Table: curTable, Name: field.DBName}, Value: v})
|
||||||
} else if field.DataType != "" {
|
} else if field.DataType != "" {
|
||||||
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v})
|
conds = append(conds, clause.Eq{Column: clause.Column{Table: curTable, Name: field.Name}, Value: v})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -420,9 +432,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
|
|||||||
if selected || (!restricted && field.Readable) {
|
if selected || (!restricted && field.Readable) {
|
||||||
if v, isZero := field.ValueOf(stmt.Context, reflectValue.Index(i)); !isZero || selected {
|
if v, isZero := field.ValueOf(stmt.Context, reflectValue.Index(i)); !isZero || selected {
|
||||||
if field.DBName != "" {
|
if field.DBName != "" {
|
||||||
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v})
|
conds = append(conds, clause.Eq{Column: clause.Column{Table: curTable, Name: field.DBName}, Value: v})
|
||||||
} else if field.DataType != "" {
|
} else if field.DataType != "" {
|
||||||
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v})
|
conds = append(conds, clause.Eq{Column: clause.Column{Table: curTable, Name: field.Name}, Value: v})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -447,14 +459,14 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(values) > 0 {
|
if len(values) > 0 {
|
||||||
conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: values})
|
conds = append(conds, clause.IN{Column: clause.Column{Table: curTable, Name: clause.PrimaryKey}, Values: values})
|
||||||
return []clause.Expression{clause.And(conds...)}
|
return []clause.Expression{clause.And(conds...)}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: args})
|
conds = append(conds, clause.IN{Column: clause.Column{Table: curTable, Name: clause.PrimaryKey}, Values: args})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -521,6 +533,7 @@ func (stmt *Statement) clone() *Statement {
|
|||||||
Context: stmt.Context,
|
Context: stmt.Context,
|
||||||
RaiseErrorOnNotFound: stmt.RaiseErrorOnNotFound,
|
RaiseErrorOnNotFound: stmt.RaiseErrorOnNotFound,
|
||||||
SkipHooks: stmt.SkipHooks,
|
SkipHooks: stmt.SkipHooks,
|
||||||
|
Result: stmt.Result,
|
||||||
}
|
}
|
||||||
|
|
||||||
if stmt.SQL.Len() > 0 {
|
if stmt.SQL.Len() > 0 {
|
||||||
|
@ -18,7 +18,7 @@ services:
|
|||||||
- POSTGRES_USER=gorm
|
- POSTGRES_USER=gorm
|
||||||
- POSTGRES_PASSWORD=gorm
|
- POSTGRES_PASSWORD=gorm
|
||||||
mssql:
|
mssql:
|
||||||
image: '${MSSQL_IMAGE}:2022-latest'
|
image: '${MSSQL_IMAGE}:latest'
|
||||||
ports:
|
ports:
|
||||||
- "127.0.0.1:9930:1433"
|
- "127.0.0.1:9930:1433"
|
||||||
environment:
|
environment:
|
||||||
|
@ -119,6 +119,7 @@ func TestConnPoolWrapper(t *testing.T) {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
db, err := gorm.Open(mysql.New(mysql.Config{Conn: conn, DisableWithReturning: true}))
|
db, err := gorm.Open(mysql.New(mysql.Config{Conn: conn, DisableWithReturning: true}))
|
||||||
|
db.Logger = DB.Logger
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Should open db success, but got %v", err)
|
t.Fatalf("Should open db success, but got %v", err)
|
||||||
}
|
}
|
||||||
|
852
tests/generics_test.go
Normal file
852
tests/generics_test.go
Normal file
@ -0,0 +1,852 @@
|
|||||||
|
package tests_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"regexp"
|
||||||
|
"sort"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"gorm.io/driver/mysql"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
"gorm.io/gorm/clause"
|
||||||
|
. "gorm.io/gorm/utils/tests"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGenericsCreate(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
user := User{Name: "TestGenericsCreate", Age: 18}
|
||||||
|
err := gorm.G[User](DB).Create(ctx, &user)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Create failed: %v", err)
|
||||||
|
}
|
||||||
|
if user.ID == 0 {
|
||||||
|
t.Fatalf("no primary key found for %v", user)
|
||||||
|
}
|
||||||
|
|
||||||
|
if u, err := gorm.G[User](DB).Where("name = ?", user.Name).First(ctx); err != nil {
|
||||||
|
t.Fatalf("failed to find user, got error: %v", err)
|
||||||
|
} else if u.Name != user.Name || u.ID != user.ID {
|
||||||
|
t.Errorf("found invalid user, got %v, expect %v", u, user)
|
||||||
|
}
|
||||||
|
|
||||||
|
if u, err := gorm.G[User](DB).Where("name = ?", user.Name).Take(ctx); err != nil {
|
||||||
|
t.Fatalf("failed to find user, got error: %v", err)
|
||||||
|
} else if u.Name != user.Name || u.ID != user.ID {
|
||||||
|
t.Errorf("found invalid user, got %v, expect %v", u, user)
|
||||||
|
}
|
||||||
|
|
||||||
|
if u, err := gorm.G[User](DB).Select("name").Where("name = ?", user.Name).First(ctx); err != nil {
|
||||||
|
t.Fatalf("failed to find user, got error: %v", err)
|
||||||
|
} else if u.Name != user.Name || u.Age != 0 {
|
||||||
|
t.Errorf("found invalid user, got %v, expect %v", u, user)
|
||||||
|
}
|
||||||
|
|
||||||
|
if u, err := gorm.G[User](DB).Omit("name").Where("name = ?", user.Name).First(ctx); err != nil {
|
||||||
|
t.Fatalf("failed to find user, got error: %v", err)
|
||||||
|
} else if u.Name != "" || u.Age != user.Age {
|
||||||
|
t.Errorf("found invalid user, got %v, expect %v", u, user)
|
||||||
|
}
|
||||||
|
|
||||||
|
result := struct {
|
||||||
|
ID int
|
||||||
|
Name string
|
||||||
|
}{}
|
||||||
|
if err := gorm.G[User](DB).Where("name = ?", user.Name).Scan(ctx, &result); err != nil {
|
||||||
|
t.Fatalf("failed to scan user, got error: %v", err)
|
||||||
|
} else if result.Name != user.Name || uint(result.ID) != user.ID {
|
||||||
|
t.Errorf("found invalid user, got %v, expect %v", result, user)
|
||||||
|
}
|
||||||
|
|
||||||
|
mapResult, err := gorm.G[map[string]interface{}](DB).Table("users").Where("name = ?", user.Name).MapColumns(map[string]string{"name": "user_name"}).Take(ctx)
|
||||||
|
if v := mapResult["user_name"]; fmt.Sprint(v) != user.Name {
|
||||||
|
t.Errorf("failed to find map results, got %v, err %v", mapResult, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenericsCreateInBatches(t *testing.T) {
|
||||||
|
batch := []User{
|
||||||
|
{Name: "GenericsCreateInBatches1"},
|
||||||
|
{Name: "GenericsCreateInBatches2"},
|
||||||
|
{Name: "GenericsCreateInBatches3"},
|
||||||
|
}
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
if err := gorm.G[User](DB).CreateInBatches(ctx, &batch, 2); err != nil {
|
||||||
|
t.Fatalf("CreateInBatches failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, u := range batch {
|
||||||
|
if u.ID == 0 {
|
||||||
|
t.Fatalf("no primary key found for %v", u)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
count, err := gorm.G[User](DB).Where("name like ?", "GenericsCreateInBatches%").Count(ctx, "*")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Count failed: %v", err)
|
||||||
|
}
|
||||||
|
if count != 3 {
|
||||||
|
t.Errorf("expected 3 records, got %d", count)
|
||||||
|
}
|
||||||
|
|
||||||
|
found, err := gorm.G[User](DB).Raw("SELECT * FROM users WHERE name LIKE ?", "GenericsCreateInBatches%").Find(ctx)
|
||||||
|
if len(found) != len(batch) {
|
||||||
|
t.Errorf("expected %d from Raw Find, got %d", len(batch), len(found))
|
||||||
|
}
|
||||||
|
|
||||||
|
found, err = gorm.G[User](DB).Where("name like ?", "GenericsCreateInBatches%").Limit(2).Find(ctx)
|
||||||
|
if len(found) != 2 {
|
||||||
|
t.Errorf("expected %d from Raw Find, got %d", 2, len(found))
|
||||||
|
}
|
||||||
|
|
||||||
|
found, err = gorm.G[User](DB).Where("name like ?", "GenericsCreateInBatches%").Offset(2).Limit(2).Find(ctx)
|
||||||
|
if len(found) != 1 {
|
||||||
|
t.Errorf("expected %d from Raw Find, got %d", 1, len(found))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenericsExecAndUpdate(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
name := "GenericsExec"
|
||||||
|
if err := gorm.G[User](DB).Exec(ctx, "INSERT INTO users(name) VALUES(?)", name); err != nil {
|
||||||
|
t.Fatalf("Exec insert failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
u, err := gorm.G[User](DB).Table("users as u").Where("u.name = ?", name).First(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to find user, got error: %v", err)
|
||||||
|
} else if u.Name != name || u.ID == 0 {
|
||||||
|
t.Errorf("found invalid user, got %v", u)
|
||||||
|
}
|
||||||
|
|
||||||
|
name += "Update"
|
||||||
|
rows, err := gorm.G[User](DB).Where("id = ?", u.ID).Update(ctx, "name", name)
|
||||||
|
if rows != 1 {
|
||||||
|
t.Fatalf("failed to get affected rows, got %d, should be %d", rows, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
nu, err := gorm.G[User](DB).Where("name = ?", name).First(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to find user, got error: %v", err)
|
||||||
|
} else if nu.Name != name || u.ID != nu.ID {
|
||||||
|
t.Fatalf("found invalid user, got %v, expect %v", nu.ID, u.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err = gorm.G[User](DB).Where("id = ?", u.ID).Updates(ctx, User{Name: "GenericsExecUpdates", Age: 18})
|
||||||
|
if rows != 1 {
|
||||||
|
t.Fatalf("failed to get affected rows, got %d, should be %d", rows, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
nu, err = gorm.G[User](DB).Where("id = ?", u.ID).Last(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to find user, got error: %v", err)
|
||||||
|
} else if nu.Name != "GenericsExecUpdates" || nu.Age != 18 || u.ID != nu.ID {
|
||||||
|
t.Fatalf("found invalid user, got %v, expect %v", nu.ID, u.ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenericsRow(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
user := User{Name: "GenericsRow"}
|
||||||
|
if err := gorm.G[User](DB).Create(ctx, &user); err != nil {
|
||||||
|
t.Fatalf("Create failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
row := gorm.G[User](DB).Raw("SELECT name FROM users WHERE id = ?", user.ID).Row(ctx)
|
||||||
|
var name string
|
||||||
|
if err := row.Scan(&name); err != nil {
|
||||||
|
t.Fatalf("Row scan failed: %v", err)
|
||||||
|
}
|
||||||
|
if name != user.Name {
|
||||||
|
t.Errorf("expected %s, got %s", user.Name, name)
|
||||||
|
}
|
||||||
|
|
||||||
|
user2 := User{Name: "GenericsRow2"}
|
||||||
|
if err := gorm.G[User](DB).Create(ctx, &user2); err != nil {
|
||||||
|
t.Fatalf("Create failed: %v", err)
|
||||||
|
}
|
||||||
|
rows, err := gorm.G[User](DB).Raw("SELECT name FROM users WHERE id IN ?", []uint{user.ID, user2.ID}).Rows(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Rows failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
count := 0
|
||||||
|
for rows.Next() {
|
||||||
|
var name string
|
||||||
|
if err := rows.Scan(&name); err != nil {
|
||||||
|
t.Fatalf("rows.Scan failed: %v", err)
|
||||||
|
}
|
||||||
|
count++
|
||||||
|
}
|
||||||
|
if count != 2 {
|
||||||
|
t.Errorf("expected 2 rows, got %d", count)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenericsDelete(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
u := User{Name: "GenericsDelete"}
|
||||||
|
if err := gorm.G[User](DB).Create(ctx, &u); err != nil {
|
||||||
|
t.Fatalf("Create failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := gorm.G[User](DB).Where("id = ?", u.ID).Delete(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Delete failed: %v", err)
|
||||||
|
}
|
||||||
|
if rows != 1 {
|
||||||
|
t.Errorf("expected 1 row deleted, got %d", rows)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = gorm.G[User](DB).Where("id = ?", u.ID).First(ctx)
|
||||||
|
if err != gorm.ErrRecordNotFound {
|
||||||
|
t.Fatalf("User after delete failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenericsFindInBatches(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
users := []User{
|
||||||
|
{Name: "GenericsFindBatchA"},
|
||||||
|
{Name: "GenericsFindBatchB"},
|
||||||
|
{Name: "GenericsFindBatchC"},
|
||||||
|
{Name: "GenericsFindBatchD"},
|
||||||
|
{Name: "GenericsFindBatchE"},
|
||||||
|
}
|
||||||
|
if err := gorm.G[User](DB).CreateInBatches(ctx, &users, len(users)); err != nil {
|
||||||
|
t.Fatalf("CreateInBatches failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
total := 0
|
||||||
|
err := gorm.G[User](DB).Where("name like ?", "GenericsFindBatch%").FindInBatches(ctx, 2, func(chunk []User, batch int) error {
|
||||||
|
if len(chunk) > 2 {
|
||||||
|
t.Errorf("batch size exceed 2: got %d", len(chunk))
|
||||||
|
}
|
||||||
|
|
||||||
|
total += len(chunk)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("FindInBatches failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if total != len(users) {
|
||||||
|
t.Errorf("expected total %d, got %d", len(users), total)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenericsScopes(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
users := []User{{Name: "GenericsScopes1"}, {Name: "GenericsScopes2"}, {Name: "GenericsScopes3"}}
|
||||||
|
err := gorm.G[User](DB).CreateInBatches(ctx, &users, len(users))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreateInBatches failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
filterName1 := func(stmt *gorm.Statement) {
|
||||||
|
stmt.Where("name = ?", "GenericsScopes1")
|
||||||
|
}
|
||||||
|
|
||||||
|
results, err := gorm.G[User](DB).Scopes(filterName1).Find(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Scopes failed: %v", err)
|
||||||
|
}
|
||||||
|
if len(results) != 1 || results[0].Name != "GenericsScopes1" {
|
||||||
|
t.Fatalf("Scopes expected 1, got %d", len(results))
|
||||||
|
}
|
||||||
|
|
||||||
|
notResult, err := gorm.G[User](DB).Where("name like ?", "GenericsScopes%").Not("name = ?", "GenericsScopes1").Order("name").Find(ctx)
|
||||||
|
if len(notResult) != 2 {
|
||||||
|
t.Fatalf("expected 2 results, got %d", len(notResult))
|
||||||
|
} else if notResult[0].Name != "GenericsScopes2" || notResult[1].Name != "GenericsScopes3" {
|
||||||
|
t.Fatalf("expected names 'GenericsScopes2' and 'GenericsScopes3', got %s and %s", notResult[0].Name, notResult[1].Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
orResult, err := gorm.G[User](DB).Or("name = ?", "GenericsScopes1").Or("name = ?", "GenericsScopes2").Order("name").Find(ctx)
|
||||||
|
if len(orResult) != 2 {
|
||||||
|
t.Fatalf("expected 2 results, got %d", len(notResult))
|
||||||
|
} else if orResult[0].Name != "GenericsScopes1" || orResult[1].Name != "GenericsScopes2" {
|
||||||
|
t.Fatalf("expected names 'GenericsScopes2' and 'GenericsScopes3', got %s and %s", orResult[0].Name, orResult[1].Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenericsJoins(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
db := gorm.G[User](DB)
|
||||||
|
|
||||||
|
u := User{Name: "GenericsJoins", Company: Company{Name: "GenericsCompany"}}
|
||||||
|
u2 := User{Name: "GenericsJoins_2", Company: Company{Name: "GenericsCompany_2"}}
|
||||||
|
u3 := User{Name: "GenericsJoins_3", Company: Company{Name: "GenericsCompany_3"}}
|
||||||
|
db.CreateInBatches(ctx, &[]User{u3, u, u2}, 10)
|
||||||
|
|
||||||
|
// Inner JOIN + WHERE
|
||||||
|
result, err := db.Joins(clause.Has("Company"), func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error {
|
||||||
|
db.Where("?.name = ?", joinTable, u.Company.Name)
|
||||||
|
return nil
|
||||||
|
}).First(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Joins failed: %v", err)
|
||||||
|
}
|
||||||
|
if result.Name != u.Name || result.Company.Name != u.Company.Name {
|
||||||
|
t.Fatalf("Joins expected %s, got %+v", u.Name, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Inner JOIN + WHERE with map
|
||||||
|
result, err = db.Joins(clause.Has("Company"), func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error {
|
||||||
|
db.Where(map[string]any{"name": u.Company.Name})
|
||||||
|
return nil
|
||||||
|
}).First(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Joins failed: %v", err)
|
||||||
|
}
|
||||||
|
if result.Name != u.Name || result.Company.Name != u.Company.Name {
|
||||||
|
t.Fatalf("Joins expected %s, got %+v", u.Name, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Left JOIN w/o WHERE
|
||||||
|
result, err = db.Joins(clause.LeftJoin.Association("Company"), nil).Where(map[string]any{"name": u.Name}).First(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Joins failed: %v", err)
|
||||||
|
}
|
||||||
|
if result.Name != u.Name || result.Company.Name != u.Company.Name {
|
||||||
|
t.Fatalf("Joins expected %s, got %+v", u.Name, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Left JOIN + Alias WHERE
|
||||||
|
result, err = db.Joins(clause.LeftJoin.Association("Company").As("t"), func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error {
|
||||||
|
if joinTable.Name != "t" {
|
||||||
|
t.Fatalf("Join table should be t, but got %v", joinTable.Name)
|
||||||
|
}
|
||||||
|
db.Where("?.name = ?", joinTable, u.Company.Name)
|
||||||
|
return nil
|
||||||
|
}).Where(map[string]any{"name": u.Name}).First(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Joins failed: %v", err)
|
||||||
|
}
|
||||||
|
if result.Name != u.Name || result.Company.Name != u.Company.Name {
|
||||||
|
t.Fatalf("Joins expected %s, got %+v", u.Name, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Raw Subquery JOIN + WHERE
|
||||||
|
result, err = db.Joins(clause.LeftJoin.AssociationFrom("Company", gorm.G[Company](DB)).As("t"),
|
||||||
|
func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error {
|
||||||
|
if joinTable.Name != "t" {
|
||||||
|
t.Fatalf("Join table should be t, but got %v", joinTable.Name)
|
||||||
|
}
|
||||||
|
db.Where("?.name = ?", joinTable, u.Company.Name)
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
).Where(map[string]any{"name": u2.Name}).First(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Raw subquery join failed: %v", err)
|
||||||
|
}
|
||||||
|
if result.Name != u2.Name || result.Company.Name != u.Company.Name || result.Company.ID == 0 {
|
||||||
|
t.Fatalf("Joins expected %s, got %+v", u.Name, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Raw Subquery JOIN + WHERE + Select
|
||||||
|
result, err = db.Joins(clause.LeftJoin.AssociationFrom("Company", gorm.G[Company](DB).Select("Name")).As("t"),
|
||||||
|
func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error {
|
||||||
|
if joinTable.Name != "t" {
|
||||||
|
t.Fatalf("Join table should be t, but got %v", joinTable.Name)
|
||||||
|
}
|
||||||
|
db.Where("?.name = ?", joinTable, u.Company.Name)
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
).Where(map[string]any{"name": u2.Name}).First(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Raw subquery join failed: %v", err)
|
||||||
|
}
|
||||||
|
if result.Name != u2.Name || result.Company.Name != u.Company.Name || result.Company.ID != 0 {
|
||||||
|
t.Fatalf("Joins expected %s, got %+v", u.Name, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = db.Joins(clause.Has("Company"), func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error {
|
||||||
|
return errors.New("join error")
|
||||||
|
}).First(ctx)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("Joins should got error, but got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenericsNestedJoins(t *testing.T) {
|
||||||
|
users := []User{
|
||||||
|
{
|
||||||
|
Name: "generics-nested-joins-1",
|
||||||
|
Manager: &User{
|
||||||
|
Name: "generics-nested-joins-manager-1",
|
||||||
|
Company: Company{
|
||||||
|
Name: "generics-nested-joins-manager-company-1",
|
||||||
|
},
|
||||||
|
NamedPet: &Pet{
|
||||||
|
Name: "generics-nested-joins-manager-namepet-1",
|
||||||
|
Toy: Toy{
|
||||||
|
Name: "generics-nested-joins-manager-namepet-toy-1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
NamedPet: &Pet{Name: "generics-nested-joins-namepet-1", Toy: Toy{Name: "generics-nested-joins-namepet-toy-1"}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "generics-nested-joins-2",
|
||||||
|
Manager: GetUser("generics-nested-joins-manager-2", Config{Company: true, NamedPet: true}),
|
||||||
|
NamedPet: &Pet{Name: "generics-nested-joins-namepet-2", Toy: Toy{Name: "generics-nested-joins-namepet-toy-2"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
db := gorm.G[User](DB)
|
||||||
|
db.CreateInBatches(ctx, &users, 100)
|
||||||
|
|
||||||
|
var userIDs []uint
|
||||||
|
for _, user := range users {
|
||||||
|
userIDs = append(userIDs, user.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
users2, err := db.Joins(clause.LeftJoin.Association("Manager"), nil).
|
||||||
|
Joins(clause.LeftJoin.Association("Manager.Company"), nil).
|
||||||
|
Joins(clause.LeftJoin.Association("Manager.NamedPet.Toy"), nil).
|
||||||
|
Joins(clause.LeftJoin.Association("NamedPet.Toy"), nil).
|
||||||
|
Joins(clause.LeftJoin.Association("NamedPet").As("t"), nil).
|
||||||
|
Where(map[string]any{"id": userIDs}).Find(ctx)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to load with joins, got error: %v", err)
|
||||||
|
} else if len(users2) != len(users) {
|
||||||
|
t.Fatalf("Failed to load join users, got: %v, expect: %v", len(users2), len(users))
|
||||||
|
}
|
||||||
|
|
||||||
|
sort.Slice(users2, func(i, j int) bool {
|
||||||
|
return users2[i].ID > users2[j].ID
|
||||||
|
})
|
||||||
|
|
||||||
|
sort.Slice(users, func(i, j int) bool {
|
||||||
|
return users[i].ID > users[j].ID
|
||||||
|
})
|
||||||
|
|
||||||
|
for idx, user := range users {
|
||||||
|
// user
|
||||||
|
CheckUser(t, user, users2[idx])
|
||||||
|
if users2[idx].Manager == nil {
|
||||||
|
t.Fatalf("Failed to load Manager")
|
||||||
|
}
|
||||||
|
// manager
|
||||||
|
CheckUser(t, *user.Manager, *users2[idx].Manager)
|
||||||
|
// user pet
|
||||||
|
if users2[idx].NamedPet == nil {
|
||||||
|
t.Fatalf("Failed to load NamedPet")
|
||||||
|
}
|
||||||
|
CheckPet(t, *user.NamedPet, *users2[idx].NamedPet)
|
||||||
|
// manager pet
|
||||||
|
if users2[idx].Manager.NamedPet == nil {
|
||||||
|
t.Fatalf("Failed to load NamedPet")
|
||||||
|
}
|
||||||
|
CheckPet(t, *user.Manager.NamedPet, *users2[idx].Manager.NamedPet)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenericsPreloads(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
db := gorm.G[User](DB)
|
||||||
|
|
||||||
|
u := *GetUser("GenericsPreloads_1", Config{Company: true, Pets: 3, Friends: 7})
|
||||||
|
u2 := *GetUser("GenericsPreloads_2", Config{Company: true, Pets: 5, Friends: 5})
|
||||||
|
u3 := *GetUser("GenericsPreloads_3", Config{Company: true, Pets: 7, Friends: 3})
|
||||||
|
names := []string{u.Name, u2.Name, u3.Name}
|
||||||
|
|
||||||
|
db.CreateInBatches(ctx, &[]User{u3, u, u2}, 10)
|
||||||
|
|
||||||
|
result, err := db.Preload("Company", nil).Preload("Pets", nil).Where("name = ?", u.Name).First(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Preload failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.Name != u.Name || result.Company.Name != u.Company.Name || len(result.Pets) != len(u.Pets) {
|
||||||
|
t.Fatalf("Preload expected %s, got %+v", u.Name, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
results, err := db.Preload("Company", func(db gorm.PreloadBuilder) error {
|
||||||
|
db.Where("name = ?", u.Company.Name)
|
||||||
|
return nil
|
||||||
|
}).Where("name in ?", names).Find(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Preload failed: %v", err)
|
||||||
|
}
|
||||||
|
for _, result := range results {
|
||||||
|
if result.Name == u.Name {
|
||||||
|
if result.Company.Name != u.Company.Name {
|
||||||
|
t.Fatalf("Preload user %v company should be %v, but got %+v", u.Name, u.Company.Name, result.Company.Name)
|
||||||
|
}
|
||||||
|
} else if result.Company.Name != "" {
|
||||||
|
t.Fatalf("Preload other company should not loaded, user %v company expect %v but got %+v", u.Name, u.Company.Name, result.Company.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = db.Preload("Company", func(db gorm.PreloadBuilder) error {
|
||||||
|
return errors.New("preload error")
|
||||||
|
}).Where("name in ?", names).Find(ctx)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("Preload should failed, but got nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Dialector.Name() == "mysql" {
|
||||||
|
// mysql 5.7 doesn't support row_number()
|
||||||
|
if strings.HasPrefix(DB.Dialector.(*mysql.Dialector).ServerVersion, "5.7") {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
results, err = db.Preload("Pets", func(db gorm.PreloadBuilder) error {
|
||||||
|
db.LimitPerRecord(5)
|
||||||
|
return nil
|
||||||
|
}).Where("name in ?", names).Find(ctx)
|
||||||
|
|
||||||
|
for _, result := range results {
|
||||||
|
if result.Name == u.Name {
|
||||||
|
if len(result.Pets) != len(u.Pets) {
|
||||||
|
t.Fatalf("Preload user %v pets should be %v, but got %+v", u.Name, u.Pets, result.Pets)
|
||||||
|
}
|
||||||
|
} else if len(result.Pets) != 5 {
|
||||||
|
t.Fatalf("Preload user %v pets should be 5, but got %+v", result.Name, result.Pets)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Dialector.Name() == "sqlserver" {
|
||||||
|
// sqlserver doesn't support order by in subquery
|
||||||
|
return
|
||||||
|
}
|
||||||
|
results, err = db.Preload("Pets", func(db gorm.PreloadBuilder) error {
|
||||||
|
db.Order("name desc").LimitPerRecord(5)
|
||||||
|
return nil
|
||||||
|
}).Where("name in ?", names).Find(ctx)
|
||||||
|
|
||||||
|
for _, result := range results {
|
||||||
|
if result.Name == u.Name {
|
||||||
|
if len(result.Pets) != len(u.Pets) {
|
||||||
|
t.Fatalf("Preload user %v pets should be %v, but got %+v", u.Name, u.Pets, result.Pets)
|
||||||
|
}
|
||||||
|
} else if len(result.Pets) != 5 {
|
||||||
|
t.Fatalf("Preload user %v pets should be 5, but got %+v", result.Name, result.Pets)
|
||||||
|
}
|
||||||
|
for i := 1; i < len(result.Pets); i++ {
|
||||||
|
if result.Pets[i-1].Name < result.Pets[i].Name {
|
||||||
|
t.Fatalf("Preload user %v pets not ordered correctly, last %v, cur %v", result.Name, result.Pets[i-1], result.Pets[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
results, err = db.Preload("Pets", func(db gorm.PreloadBuilder) error {
|
||||||
|
db.Order("name").LimitPerRecord(5)
|
||||||
|
return nil
|
||||||
|
}).Preload("Friends", func(db gorm.PreloadBuilder) error {
|
||||||
|
db.Order("name")
|
||||||
|
return nil
|
||||||
|
}).Where("name in ?", names).Find(ctx)
|
||||||
|
|
||||||
|
for _, result := range results {
|
||||||
|
if result.Name == u.Name {
|
||||||
|
if len(result.Pets) != len(u.Pets) {
|
||||||
|
t.Fatalf("Preload user %v pets should be %v, but got %+v", u.Name, u.Pets, result.Pets)
|
||||||
|
}
|
||||||
|
if len(result.Friends) != len(u.Friends) {
|
||||||
|
t.Fatalf("Preload user %v pets should be %v, but got %+v", u.Name, u.Pets, result.Pets)
|
||||||
|
}
|
||||||
|
} else if len(result.Pets) != 5 || len(result.Friends) == 0 {
|
||||||
|
t.Fatalf("Preload user %v pets should be 5, but got %+v", result.Name, result.Pets)
|
||||||
|
}
|
||||||
|
for i := 1; i < len(result.Pets); i++ {
|
||||||
|
if result.Pets[i-1].Name > result.Pets[i].Name {
|
||||||
|
t.Fatalf("Preload user %v pets not ordered correctly, last %v, cur %v", result.Name, result.Pets[i-1], result.Pets[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for i := 1; i < len(result.Pets); i++ {
|
||||||
|
if result.Pets[i-1].Name > result.Pets[i].Name {
|
||||||
|
t.Fatalf("Preload user %v friends not ordered correctly, last %v, cur %v", result.Name, result.Pets[i-1], result.Pets[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenericsNestedPreloads(t *testing.T) {
|
||||||
|
user := *GetUser("generics_nested_preload", Config{Pets: 2})
|
||||||
|
user.Friends = []*User{GetUser("generics_nested_preload", Config{Pets: 5})}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
db := gorm.G[User](DB)
|
||||||
|
|
||||||
|
for idx, pet := range user.Pets {
|
||||||
|
pet.Toy = Toy{Name: "toy_nested_preload_" + strconv.Itoa(idx+1)}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := db.Create(ctx, &user); err != nil {
|
||||||
|
t.Fatalf("errors happened when create: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
user2, err := db.Preload("Pets.Toy", nil).Preload("Friends.Pets", func(db gorm.PreloadBuilder) error {
|
||||||
|
return nil
|
||||||
|
}).Where(user.ID).Take(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to nested preload user")
|
||||||
|
}
|
||||||
|
CheckUser(t, user2, user)
|
||||||
|
if len(user.Pets) == 0 || len(user.Friends) == 0 || len(user.Friends[0].Pets) == 0 {
|
||||||
|
t.Fatalf("failed to nested preload")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Dialector.Name() == "mysql" {
|
||||||
|
// mysql 5.7 doesn't support row_number()
|
||||||
|
if strings.HasPrefix(DB.Dialector.(*mysql.Dialector).ServerVersion, "5.7") {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if DB.Dialector.Name() == "sqlserver" {
|
||||||
|
// sqlserver doesn't support order by in subquery
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
user3, err := db.Preload("Pets.Toy", nil).Preload("Friends.Pets", func(db gorm.PreloadBuilder) error {
|
||||||
|
db.LimitPerRecord(3)
|
||||||
|
return nil
|
||||||
|
}).Where(user.ID).Take(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to nested preload user")
|
||||||
|
}
|
||||||
|
CheckUser(t, user3, user)
|
||||||
|
|
||||||
|
if len(user3.Friends) != 1 || len(user3.Friends[0].Pets) != 3 {
|
||||||
|
t.Errorf("failed to nested preload with limit per record")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenericsDistinct(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
batch := []User{
|
||||||
|
{Name: "GenericsDistinctDup"},
|
||||||
|
{Name: "GenericsDistinctDup"},
|
||||||
|
{Name: "GenericsDistinctUnique"},
|
||||||
|
}
|
||||||
|
if err := gorm.G[User](DB).CreateInBatches(ctx, &batch, len(batch)); err != nil {
|
||||||
|
t.Fatalf("CreateInBatches failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
results, err := gorm.G[User](DB).Where("name like ?", "GenericsDistinct%").Distinct("name").Find(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Distinct Find failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(results) != 2 {
|
||||||
|
t.Errorf("expected 2 distinct names, got %d", len(results))
|
||||||
|
}
|
||||||
|
|
||||||
|
var names []string
|
||||||
|
for _, u := range results {
|
||||||
|
names = append(names, u.Name)
|
||||||
|
}
|
||||||
|
sort.Strings(names)
|
||||||
|
expected := []string{"GenericsDistinctDup", "GenericsDistinctUnique"}
|
||||||
|
if !reflect.DeepEqual(names, expected) {
|
||||||
|
t.Errorf("expected names %v, got %v", expected, names)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenericsGroupHaving(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
batch := []User{
|
||||||
|
{Name: "GenericsGroupHavingMulti"},
|
||||||
|
{Name: "GenericsGroupHavingMulti"},
|
||||||
|
{Name: "GenericsGroupHavingSingle"},
|
||||||
|
}
|
||||||
|
if err := gorm.G[User](DB).CreateInBatches(ctx, &batch, len(batch)); err != nil {
|
||||||
|
t.Fatalf("CreateInBatches failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
grouped, err := gorm.G[User](DB).Select("name").Where("name like ?", "GenericsGroupHaving%").Group("name").Having("COUNT(id) > ?", 1).Find(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Group+Having Find failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(grouped) != 1 {
|
||||||
|
t.Errorf("expected 1 group with count>1, got %d", len(grouped))
|
||||||
|
} else if grouped[0].Name != "GenericsGroupHavingMulti" {
|
||||||
|
t.Errorf("expected group name 'GenericsGroupHavingMulti', got '%s'", grouped[0].Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenericsSubQuery(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
users := []User{
|
||||||
|
{Name: "GenericsSubquery_1", Age: 10},
|
||||||
|
{Name: "GenericsSubquery_2", Age: 20},
|
||||||
|
{Name: "GenericsSubquery_3", Age: 30},
|
||||||
|
{Name: "GenericsSubquery_4", Age: 40},
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := gorm.G[User](DB).CreateInBatches(ctx, &users, len(users)); err != nil {
|
||||||
|
t.Fatalf("CreateInBatches failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
results, err := gorm.G[User](DB).Where("name IN (?)", gorm.G[User](DB).Select("name").Where("name LIKE ?", "GenericsSubquery%")).Find(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("got error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(results) != 4 {
|
||||||
|
t.Errorf("Four users should be found, instead found %d", len(results))
|
||||||
|
}
|
||||||
|
|
||||||
|
results, err = gorm.G[User](DB).Where("name IN (?)", gorm.G[User](DB).Select("name").Where("name IN ?", []string{"GenericsSubquery_1", "GenericsSubquery_2"}).Or("name = ?", "GenericsSubquery_3")).Find(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("got error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(results) != 3 {
|
||||||
|
t.Errorf("Three users should be found, instead found %d", len(results))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenericsUpsert(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
lang := Language{Code: "upsert", Name: "Upsert"}
|
||||||
|
|
||||||
|
if err := gorm.G[Language](DB, clause.OnConflict{DoNothing: true}).Create(ctx, &lang); err != nil {
|
||||||
|
t.Fatalf("failed to upsert, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
lang2 := Language{Code: "upsert", Name: "Upsert"}
|
||||||
|
if err := gorm.G[Language](DB, clause.OnConflict{DoNothing: true}).Create(ctx, &lang2); err != nil {
|
||||||
|
t.Fatalf("failed to upsert, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
langs, err := gorm.G[Language](DB).Where("code = ?", lang.Code).Find(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("no error should happen when find languages with code, but got %v", err)
|
||||||
|
} else if len(langs) != 1 {
|
||||||
|
t.Errorf("should only find only 1 languages, but got %+v", langs)
|
||||||
|
}
|
||||||
|
|
||||||
|
lang3 := Language{Code: "upsert", Name: "Upsert"}
|
||||||
|
if err := gorm.G[Language](DB, clause.OnConflict{
|
||||||
|
Columns: []clause.Column{{Name: "code"}},
|
||||||
|
DoUpdates: clause.Assignments(map[string]interface{}{"name": "upsert-new"}),
|
||||||
|
}).Create(ctx, &lang3); err != nil {
|
||||||
|
t.Fatalf("failed to upsert, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if langs, err := gorm.G[Language](DB).Where("code = ?", lang.Code).Find(ctx); err != nil {
|
||||||
|
t.Errorf("no error should happen when find languages with code, but got %v", err)
|
||||||
|
} else if len(langs) != 1 {
|
||||||
|
t.Errorf("should only find only 1 languages, but got %+v", langs)
|
||||||
|
} else if langs[0].Name != "upsert-new" {
|
||||||
|
t.Errorf("should update name on conflict, but got name %+v", langs[0].Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenericsWithResult(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
users := []User{{Name: "TestGenericsWithResult", Age: 18}, {Name: "TestGenericsWithResult2", Age: 18}}
|
||||||
|
|
||||||
|
result := gorm.WithResult()
|
||||||
|
err := gorm.G[User](DB, result).CreateInBatches(ctx, &users, 2)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to create users WithResult")
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.RowsAffected != 2 {
|
||||||
|
t.Errorf("failed to get affected rows, got %d, should be %d", result.RowsAffected, 2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenericsReuse(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
users := []User{{Name: "TestGenericsReuse1", Age: 18}, {Name: "TestGenericsReuse2", Age: 18}}
|
||||||
|
|
||||||
|
err := gorm.G[User](DB).CreateInBatches(ctx, &users, 2)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to create users")
|
||||||
|
}
|
||||||
|
|
||||||
|
reusedb := gorm.G[User](DB).Where("name like ?", "TestGenericsReuse%")
|
||||||
|
|
||||||
|
sg := sync.WaitGroup{}
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
sg.Add(1)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
if u1, err := reusedb.Where("id = ?", users[0].ID).First(ctx); err != nil {
|
||||||
|
t.Errorf("failed to find user, got error: %v", err)
|
||||||
|
} else if u1.Name != users[0].Name || u1.ID != users[0].ID {
|
||||||
|
t.Errorf("found invalid user, got %v, expect %v", u1, users[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
if u2, err := reusedb.Where("id = ?", users[1].ID).First(ctx); err != nil {
|
||||||
|
t.Errorf("failed to find user, got error: %v", err)
|
||||||
|
} else if u2.Name != users[1].Name || u2.ID != users[1].ID {
|
||||||
|
t.Errorf("found invalid user, got %v, expect %v", u2, users[1])
|
||||||
|
}
|
||||||
|
|
||||||
|
if users, err := reusedb.Where("id IN ?", []uint{users[0].ID, users[1].ID}).Find(ctx); err != nil {
|
||||||
|
t.Errorf("failed to find user, got error: %v", err)
|
||||||
|
} else if len(users) != 2 {
|
||||||
|
t.Errorf("should find 2 users, but got %d", len(users))
|
||||||
|
}
|
||||||
|
sg.Done()
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
sg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenericsWithTransaction(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
tx := DB.Begin()
|
||||||
|
if tx.Error != nil {
|
||||||
|
t.Fatalf("failed to begin transaction: %v", tx.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
users := []User{{Name: "TestGenericsTransaction", Age: 18}, {Name: "TestGenericsTransaction2", Age: 18}}
|
||||||
|
err := gorm.G[User](tx).CreateInBatches(ctx, &users, 2)
|
||||||
|
|
||||||
|
count, err := gorm.G[User](tx).Where("name like ?", "TestGenericsTransaction%").Count(ctx, "*")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Count failed: %v", err)
|
||||||
|
}
|
||||||
|
if count != 2 {
|
||||||
|
t.Errorf("expected 2 records, got %d", count)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx.Rollback().Error; err != nil {
|
||||||
|
t.Fatalf("failed to rollback transaction: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
count2, err := gorm.G[User](DB).Where("name like ?", "TestGenericsTransaction%").Count(ctx, "*")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Count failed: %v", err)
|
||||||
|
}
|
||||||
|
if count2 != 0 {
|
||||||
|
t.Errorf("expected 0 records after rollback, got %d", count2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenericsToSQL(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
sql := DB.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
||||||
|
gorm.G[User](tx).Limit(10).Find(ctx)
|
||||||
|
return tx
|
||||||
|
})
|
||||||
|
|
||||||
|
if !regexp.MustCompile("SELECT \\* FROM .users..* 10").MatchString(sql) {
|
||||||
|
t.Errorf("ToSQL: got wrong sql with Generics API %v", sql)
|
||||||
|
}
|
||||||
|
}
|
@ -2,8 +2,6 @@ module gorm.io/gorm/tests
|
|||||||
|
|
||||||
go 1.23.0
|
go 1.23.0
|
||||||
|
|
||||||
toolchain go1.24.2
|
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/google/uuid v1.6.0
|
github.com/google/uuid v1.6.0
|
||||||
github.com/jinzhu/now v1.1.5
|
github.com/jinzhu/now v1.1.5
|
||||||
@ -13,7 +11,7 @@ require (
|
|||||||
gorm.io/driver/postgres v1.5.11
|
gorm.io/driver/postgres v1.5.11
|
||||||
gorm.io/driver/sqlite v1.5.7
|
gorm.io/driver/sqlite v1.5.7
|
||||||
gorm.io/driver/sqlserver v1.5.4
|
gorm.io/driver/sqlserver v1.5.4
|
||||||
gorm.io/gorm v1.25.12
|
gorm.io/gorm v1.26.1
|
||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
@ -24,11 +22,11 @@ require (
|
|||||||
github.com/golang-sql/sqlexp v0.1.0 // indirect
|
github.com/golang-sql/sqlexp v0.1.0 // indirect
|
||||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
||||||
github.com/jackc/pgx/v5 v5.7.1 // indirect
|
github.com/jackc/pgx/v5 v5.7.5 // indirect
|
||||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||||
github.com/kr/text v0.2.0 // indirect
|
github.com/kr/text v0.2.0 // indirect
|
||||||
github.com/mattn/go-sqlite3 v1.14.28 // indirect
|
github.com/mattn/go-sqlite3 v1.14.28 // indirect
|
||||||
github.com/microsoft/go-mssqldb v1.7.2 // indirect
|
github.com/microsoft/go-mssqldb v1.8.1 // indirect
|
||||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||||
github.com/rogpeppe/go-internal v1.12.0 // indirect
|
github.com/rogpeppe/go-internal v1.12.0 // indirect
|
||||||
golang.org/x/crypto v0.38.0 // indirect
|
golang.org/x/crypto v0.38.0 // indirect
|
||||||
|
@ -419,7 +419,7 @@ func TestJoinsPreload_Issue7013(t *testing.T) {
|
|||||||
var entries []User
|
var entries []User
|
||||||
assert.NotPanics(t, func() {
|
assert.NotPanics(t, func() {
|
||||||
assert.NoError(t,
|
assert.NoError(t,
|
||||||
DB.Debug().Preload("Manager.Team").
|
DB.Preload("Manager.Team").
|
||||||
Joins("Manager.Company").
|
Joins("Manager.Company").
|
||||||
Find(&entries).Error)
|
Find(&entries).Error)
|
||||||
})
|
})
|
||||||
@ -456,7 +456,7 @@ func TestJoinsPreload_Issue7013_RelationEmpty(t *testing.T) {
|
|||||||
var entries []Building
|
var entries []Building
|
||||||
assert.NotPanics(t, func() {
|
assert.NotPanics(t, func() {
|
||||||
assert.NoError(t,
|
assert.NoError(t,
|
||||||
DB.Debug().Preload("Owner.Furnitures").
|
DB.Preload("Owner.Furnitures").
|
||||||
Joins("Owner.Company").
|
Joins("Owner.Company").
|
||||||
Find(&entries).Error)
|
Find(&entries).Error)
|
||||||
})
|
})
|
||||||
@ -468,7 +468,7 @@ func TestJoinsPreload_Issue7013_NoEntries(t *testing.T) {
|
|||||||
var entries []User
|
var entries []User
|
||||||
assert.NotPanics(t, func() {
|
assert.NotPanics(t, func() {
|
||||||
assert.NoError(t,
|
assert.NoError(t,
|
||||||
DB.Debug().Preload("Manager.Team").
|
DB.Preload("Manager.Team").
|
||||||
Joins("Manager.Company").
|
Joins("Manager.Company").
|
||||||
Where("1 <> 1").
|
Where("1 <> 1").
|
||||||
Find(&entries).Error)
|
Find(&entries).Error)
|
||||||
|
@ -520,38 +520,6 @@ func TestLRURemoveOldest(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func ExampleLRU() {
|
|
||||||
// make cache with 10ms TTL and 5 max keys
|
|
||||||
cache := lru.NewLRU[string, string](5, nil, time.Millisecond*10)
|
|
||||||
|
|
||||||
// set value under key1.
|
|
||||||
cache.Add("key1", "val1")
|
|
||||||
|
|
||||||
// get value under key1
|
|
||||||
r, ok := cache.Get("key1")
|
|
||||||
|
|
||||||
// check for OK value
|
|
||||||
if ok {
|
|
||||||
fmt.Printf("value before expiration is found: %v, value: %q\n", ok, r)
|
|
||||||
}
|
|
||||||
|
|
||||||
// wait for cache to expire
|
|
||||||
time.Sleep(time.Millisecond * 100)
|
|
||||||
|
|
||||||
// get value under key1 after key expiration
|
|
||||||
r, ok = cache.Get("key1")
|
|
||||||
fmt.Printf("value after expiration is found: %v, value: %q\n", ok, r)
|
|
||||||
|
|
||||||
// set value under key2, would evict old entry because it is already expired.
|
|
||||||
cache.Add("key2", "val2")
|
|
||||||
|
|
||||||
fmt.Printf("Cache len: %d\n", cache.Len())
|
|
||||||
// Output:
|
|
||||||
// value before expiration is found: true, value: "val1"
|
|
||||||
// value after expiration is found: false, value: ""
|
|
||||||
// Cache len: 1
|
|
||||||
}
|
|
||||||
|
|
||||||
func getRand(tb testing.TB) int64 {
|
func getRand(tb testing.TB) int64 {
|
||||||
out, err := rand.Int(rand.Reader, big.NewInt(math.MaxInt64))
|
out, err := rand.Int(rand.Reader, big.NewInt(math.MaxInt64))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
. "gorm.io/gorm/utils/tests"
|
. "gorm.io/gorm/utils/tests"
|
||||||
@ -459,7 +460,6 @@ func TestTransactionWithHooks(t *testing.T) {
|
|||||||
return tx2.Scan(&User{}).Error
|
return tx2.Scan(&User{}).Error
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
@ -473,8 +473,20 @@ func TestTransactionWithHooks(t *testing.T) {
|
|||||||
return tx3.Where("user_id", user.ID).Delete(&Account{}).Error
|
return tx3.Where("user_id", user.ID).Delete(&Account{}).Error
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestTransactionWithDefaultTimeout(t *testing.T) {
|
||||||
|
db, err := OpenTestConnection(&gorm.Config{DefaultTransactionTimeout: 2 * time.Second})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to connect database, got error %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tx := db.Begin()
|
||||||
|
time.Sleep(3 * time.Second)
|
||||||
|
if err = tx.Find(&User{}).Error; err == nil {
|
||||||
|
t.Errorf("should return error when transaction timeout, got error %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user