Merge remote-tracking branch 'gorm/master'
# Conflicts: # tests/go.mod
This commit is contained in:
		
						commit
						87ede35ed4
					
				| @ -89,6 +89,10 @@ func Create(config *Config) func(db *gorm.DB) { | ||||
| 					db.AddError(rows.Close()) | ||||
| 				}() | ||||
| 				gorm.Scan(rows, db, mode) | ||||
| 
 | ||||
| 				if db.Statement.Result != nil { | ||||
| 					db.Statement.Result.RowsAffected = db.RowsAffected | ||||
| 				} | ||||
| 			} | ||||
| 
 | ||||
| 			return | ||||
| @ -103,6 +107,12 @@ func Create(config *Config) func(db *gorm.DB) { | ||||
| 		} | ||||
| 
 | ||||
| 		db.RowsAffected, _ = result.RowsAffected() | ||||
| 
 | ||||
| 		if db.Statement.Result != nil { | ||||
| 			db.Statement.Result.Result = result | ||||
| 			db.Statement.Result.RowsAffected = db.RowsAffected | ||||
| 		} | ||||
| 
 | ||||
| 		if db.RowsAffected == 0 { | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| @ -157,8 +157,14 @@ func Delete(config *Config) func(db *gorm.DB) { | ||||
| 			ok, mode := hasReturning(db, supportReturning) | ||||
| 			if !ok { | ||||
| 				result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) | ||||
| 
 | ||||
| 				if db.AddError(err) == nil { | ||||
| 					db.RowsAffected, _ = result.RowsAffected() | ||||
| 
 | ||||
| 					if db.Statement.Result != nil { | ||||
| 						db.Statement.Result.Result = result | ||||
| 						db.Statement.Result.RowsAffected = db.RowsAffected | ||||
| 					} | ||||
| 				} | ||||
| 
 | ||||
| 				return | ||||
| @ -166,6 +172,10 @@ func Delete(config *Config) func(db *gorm.DB) { | ||||
| 
 | ||||
| 			if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil { | ||||
| 				gorm.Scan(rows, db, mode) | ||||
| 
 | ||||
| 				if db.Statement.Result != nil { | ||||
| 					db.Statement.Result.RowsAffected = db.RowsAffected | ||||
| 				} | ||||
| 				db.AddError(rows.Close()) | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| @ -103,11 +103,11 @@ func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relati | ||||
| 				joined = true | ||||
| 				continue | ||||
| 			} | ||||
| 			joinNames := strings.SplitN(join, ".", 2) | ||||
| 			if len(joinNames) == 2 { | ||||
| 				if _, ok := relationships.Relations[joinNames[0]]; ok && name == joinNames[0] { | ||||
| 			join0, join1, cut := strings.Cut(join, ".") | ||||
| 			if cut { | ||||
| 				if _, ok := relationships.Relations[join0]; ok && name == join0 { | ||||
| 					joined = true | ||||
| 					nestedJoins = append(nestedJoins, joinNames[1]) | ||||
| 					nestedJoins = append(nestedJoins, join1) | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| @ -275,6 +275,8 @@ func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preload | ||||
| 	column, values := schema.ToQueryValues(clause.CurrentTable, relForeignKeys, foreignValues) | ||||
| 
 | ||||
| 	if len(values) != 0 { | ||||
| 		tx = tx.Model(reflectResults.Addr().Interface()).Where(clause.IN{Column: column, Values: values}) | ||||
| 
 | ||||
| 		for _, cond := range conds { | ||||
| 			if fc, ok := cond.(func(*gorm.DB) *gorm.DB); ok { | ||||
| 				tx = fc(tx) | ||||
| @ -283,7 +285,11 @@ func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preload | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		if err := tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), inlineConds...).Error; err != nil { | ||||
| 		if len(inlineConds) > 0 { | ||||
| 			tx = tx.Where(inlineConds[0], inlineConds[1:]...) | ||||
| 		} | ||||
| 
 | ||||
| 		if err := tx.Find(reflectResults.Addr().Interface()).Error; err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| @ -25,6 +25,10 @@ func Query(db *gorm.DB) { | ||||
| 				db.AddError(rows.Close()) | ||||
| 			}() | ||||
| 			gorm.Scan(rows, db, 0) | ||||
| 
 | ||||
| 			if db.Statement.Result != nil { | ||||
| 				db.Statement.Result.RowsAffected = db.RowsAffected | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| @ -110,7 +114,7 @@ func BuildQuerySQL(db *gorm.DB) { | ||||
| 				} | ||||
| 			} | ||||
| 
 | ||||
| 			specifiedRelationsName := make(map[string]interface{}) | ||||
| 			specifiedRelationsName := map[string]string{clause.CurrentTable: clause.CurrentTable} | ||||
| 			for _, join := range db.Statement.Joins { | ||||
| 				if db.Statement.Schema != nil { | ||||
| 					var isRelations bool // is relations or raw sql
 | ||||
| @ -124,12 +128,12 @@ func BuildQuerySQL(db *gorm.DB) { | ||||
| 						nestedJoinNames := strings.Split(join.Name, ".") | ||||
| 						if len(nestedJoinNames) > 1 { | ||||
| 							isNestedJoin := true | ||||
| 							gussNestedRelations := make([]*schema.Relationship, 0, len(nestedJoinNames)) | ||||
| 							guessNestedRelations := make([]*schema.Relationship, 0, len(nestedJoinNames)) | ||||
| 							currentRelations := db.Statement.Schema.Relationships.Relations | ||||
| 							for _, relname := range nestedJoinNames { | ||||
| 								// incomplete match, only treated as raw sql
 | ||||
| 								if relation, ok = currentRelations[relname]; ok { | ||||
| 									gussNestedRelations = append(gussNestedRelations, relation) | ||||
| 									guessNestedRelations = append(guessNestedRelations, relation) | ||||
| 									currentRelations = relation.FieldSchema.Relationships.Relations | ||||
| 								} else { | ||||
| 									isNestedJoin = false | ||||
| @ -139,18 +143,13 @@ func BuildQuerySQL(db *gorm.DB) { | ||||
| 
 | ||||
| 							if isNestedJoin { | ||||
| 								isRelations = true | ||||
| 								relations = gussNestedRelations | ||||
| 								relations = guessNestedRelations | ||||
| 							} | ||||
| 						} | ||||
| 					} | ||||
| 
 | ||||
| 					if isRelations { | ||||
| 						genJoinClause := func(joinType clause.JoinType, parentTableName string, relation *schema.Relationship) clause.Join { | ||||
| 							tableAliasName := relation.Name | ||||
| 							if parentTableName != clause.CurrentTable { | ||||
| 								tableAliasName = utils.NestedRelationName(parentTableName, tableAliasName) | ||||
| 							} | ||||
| 
 | ||||
| 						genJoinClause := func(joinType clause.JoinType, tableAliasName string, parentTableName string, relation *schema.Relationship) clause.Join { | ||||
| 							columnStmt := gorm.Statement{ | ||||
| 								Table: tableAliasName, DB: db, Schema: relation.FieldSchema, | ||||
| 								Selects: join.Selects, Omits: join.Omits, | ||||
| @ -167,6 +166,13 @@ func BuildQuerySQL(db *gorm.DB) { | ||||
| 								} | ||||
| 							} | ||||
| 
 | ||||
| 							if join.Expression != nil { | ||||
| 								return clause.Join{ | ||||
| 									Type:       join.JoinType, | ||||
| 									Expression: join.Expression, | ||||
| 								} | ||||
| 							} | ||||
| 
 | ||||
| 							exprs := make([]clause.Expression, len(relation.References)) | ||||
| 							for idx, ref := range relation.References { | ||||
| 								if ref.OwnPrimaryKey { | ||||
| @ -226,19 +232,24 @@ func BuildQuerySQL(db *gorm.DB) { | ||||
| 						} | ||||
| 
 | ||||
| 						parentTableName := clause.CurrentTable | ||||
| 						for _, rel := range relations { | ||||
| 						for idx, rel := range relations { | ||||
| 							// joins table alias like "Manager, Company, Manager__Company"
 | ||||
| 							nestedAlias := utils.NestedRelationName(parentTableName, rel.Name) | ||||
| 							if _, ok := specifiedRelationsName[nestedAlias]; !ok { | ||||
| 								fromClause.Joins = append(fromClause.Joins, genJoinClause(join.JoinType, parentTableName, rel)) | ||||
| 								specifiedRelationsName[nestedAlias] = nil | ||||
| 							curAliasName := rel.Name | ||||
| 							if parentTableName != clause.CurrentTable { | ||||
| 								curAliasName = utils.NestedRelationName(parentTableName, curAliasName) | ||||
| 							} | ||||
| 
 | ||||
| 							if parentTableName != clause.CurrentTable { | ||||
| 								parentTableName = utils.NestedRelationName(parentTableName, rel.Name) | ||||
| 							} else { | ||||
| 								parentTableName = rel.Name | ||||
| 							if _, ok := specifiedRelationsName[curAliasName]; !ok { | ||||
| 								aliasName := curAliasName | ||||
| 								if idx == len(relations)-1 && join.Alias != "" { | ||||
| 									aliasName = join.Alias | ||||
| 								} | ||||
| 
 | ||||
| 								fromClause.Joins = append(fromClause.Joins, genJoinClause(join.JoinType, aliasName, specifiedRelationsName[parentTableName], rel)) | ||||
| 								specifiedRelationsName[curAliasName] = aliasName | ||||
| 							} | ||||
| 
 | ||||
| 							parentTableName = curAliasName | ||||
| 						} | ||||
| 					} else { | ||||
| 						fromClause.Joins = append(fromClause.Joins, clause.Join{ | ||||
|  | ||||
| @ -13,5 +13,10 @@ func RawExec(db *gorm.DB) { | ||||
| 		} | ||||
| 
 | ||||
| 		db.RowsAffected, _ = result.RowsAffected() | ||||
| 
 | ||||
| 		if db.Statement.Result != nil { | ||||
| 			db.Statement.Result.Result = result | ||||
| 			db.Statement.Result.RowsAffected = db.RowsAffected | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @ -92,6 +92,10 @@ func Update(config *Config) func(db *gorm.DB) { | ||||
| 					gorm.Scan(rows, db, mode) | ||||
| 					db.Statement.Dest = dest | ||||
| 					db.AddError(rows.Close()) | ||||
| 
 | ||||
| 					if db.Statement.Result != nil { | ||||
| 						db.Statement.Result.RowsAffected = db.RowsAffected | ||||
| 					} | ||||
| 				} | ||||
| 			} else { | ||||
| 				result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) | ||||
| @ -99,6 +103,11 @@ func Update(config *Config) func(db *gorm.DB) { | ||||
| 				if db.AddError(err) == nil { | ||||
| 					db.RowsAffected, _ = result.RowsAffected() | ||||
| 				} | ||||
| 
 | ||||
| 				if db.Statement.Result != nil { | ||||
| 					db.Statement.Result.Result = result | ||||
| 					db.Statement.Result.RowsAffected = db.RowsAffected | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| @ -448,6 +448,7 @@ func (db *DB) Assign(attrs ...interface{}) (tx *DB) { | ||||
| // Unscoped allows queries to include records marked as deleted,
 | ||||
| // overriding the soft deletion behavior.
 | ||||
| // Example:
 | ||||
| //
 | ||||
| //	var users []User
 | ||||
| //	db.Unscoped().Find(&users)
 | ||||
| //	// Retrieves all users, including deleted ones.
 | ||||
|  | ||||
| @ -1,5 +1,7 @@ | ||||
| package clause | ||||
| 
 | ||||
| import "gorm.io/gorm/utils" | ||||
| 
 | ||||
| type JoinType string | ||||
| 
 | ||||
| const ( | ||||
| @ -9,6 +11,30 @@ const ( | ||||
| 	RightJoin JoinType = "RIGHT" | ||||
| ) | ||||
| 
 | ||||
| type JoinTarget struct { | ||||
| 	Type        JoinType | ||||
| 	Association string | ||||
| 	Subquery    Expression | ||||
| 	Table       string | ||||
| } | ||||
| 
 | ||||
| func Has(name string) JoinTarget { | ||||
| 	return JoinTarget{Type: InnerJoin, Association: name} | ||||
| } | ||||
| 
 | ||||
| func (jt JoinType) Association(name string) JoinTarget { | ||||
| 	return JoinTarget{Type: jt, Association: name} | ||||
| } | ||||
| 
 | ||||
| func (jt JoinType) AssociationFrom(name string, subquery Expression) JoinTarget { | ||||
| 	return JoinTarget{Type: jt, Association: name, Subquery: subquery} | ||||
| } | ||||
| 
 | ||||
| func (jt JoinTarget) As(name string) JoinTarget { | ||||
| 	jt.Table = name | ||||
| 	return jt | ||||
| } | ||||
| 
 | ||||
| // Join clause for from
 | ||||
| type Join struct { | ||||
| 	Type       JoinType | ||||
| @ -18,6 +44,12 @@ type Join struct { | ||||
| 	Expression Expression | ||||
| } | ||||
| 
 | ||||
| func JoinTable(names ...string) Table { | ||||
| 	return Table{ | ||||
| 		Name: utils.JoinNestedRelationNames(names), | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (join Join) Build(builder Builder) { | ||||
| 	if join.Expression != nil { | ||||
| 		join.Expression.Build(builder) | ||||
|  | ||||
| @ -1,6 +1,7 @@ | ||||
| package gorm | ||||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"database/sql" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| @ -673,11 +674,18 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB { | ||||
| 		opt = opts[0] | ||||
| 	} | ||||
| 
 | ||||
| 	ctx := tx.Statement.Context | ||||
| 	if _, ok := ctx.Deadline(); !ok { | ||||
| 		if db.Config.DefaultTransactionTimeout > 0 { | ||||
| 			ctx, _ = context.WithTimeout(ctx, db.Config.DefaultTransactionTimeout) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	switch beginner := tx.Statement.ConnPool.(type) { | ||||
| 	case TxBeginner: | ||||
| 		tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) | ||||
| 		tx.Statement.ConnPool, err = beginner.BeginTx(ctx, opt) | ||||
| 	case ConnPoolBeginner: | ||||
| 		tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) | ||||
| 		tx.Statement.ConnPool, err = beginner.BeginTx(ctx, opt) | ||||
| 	default: | ||||
| 		err = ErrInvalidTransaction | ||||
| 	} | ||||
|  | ||||
							
								
								
									
										605
									
								
								generics.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										605
									
								
								generics.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,605 @@ | ||||
| package gorm | ||||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"database/sql" | ||||
| 	"fmt" | ||||
| 	"sort" | ||||
| 	"strings" | ||||
| 
 | ||||
| 	"gorm.io/gorm/clause" | ||||
| 	"gorm.io/gorm/logger" | ||||
| ) | ||||
| 
 | ||||
| type result struct { | ||||
| 	Result       sql.Result | ||||
| 	RowsAffected int64 | ||||
| } | ||||
| 
 | ||||
| func (info *result) ModifyStatement(stmt *Statement) { | ||||
| 	stmt.Result = info | ||||
| } | ||||
| 
 | ||||
| // Build implements clause.Expression interface
 | ||||
| func (result) Build(clause.Builder) { | ||||
| } | ||||
| 
 | ||||
| func WithResult() *result { | ||||
| 	return &result{} | ||||
| } | ||||
| 
 | ||||
| type Interface[T any] interface { | ||||
| 	Raw(sql string, values ...interface{}) ExecInterface[T] | ||||
| 	Exec(ctx context.Context, sql string, values ...interface{}) error | ||||
| 	CreateInterface[T] | ||||
| } | ||||
| 
 | ||||
| type CreateInterface[T any] interface { | ||||
| 	ChainInterface[T] | ||||
| 	Table(name string, args ...interface{}) CreateInterface[T] | ||||
| 	Create(ctx context.Context, r *T) error | ||||
| 	CreateInBatches(ctx context.Context, r *[]T, batchSize int) error | ||||
| } | ||||
| 
 | ||||
| type ChainInterface[T any] interface { | ||||
| 	ExecInterface[T] | ||||
| 	Scopes(scopes ...func(db *Statement)) ChainInterface[T] | ||||
| 	Where(query interface{}, args ...interface{}) ChainInterface[T] | ||||
| 	Not(query interface{}, args ...interface{}) ChainInterface[T] | ||||
| 	Or(query interface{}, args ...interface{}) ChainInterface[T] | ||||
| 	Limit(offset int) ChainInterface[T] | ||||
| 	Offset(offset int) ChainInterface[T] | ||||
| 	Joins(query clause.JoinTarget, on func(db JoinBuilder, joinTable clause.Table, curTable clause.Table) error) ChainInterface[T] | ||||
| 	Preload(association string, query func(db PreloadBuilder) error) ChainInterface[T] | ||||
| 	Select(query string, args ...interface{}) ChainInterface[T] | ||||
| 	Omit(columns ...string) ChainInterface[T] | ||||
| 	MapColumns(m map[string]string) ChainInterface[T] | ||||
| 	Distinct(args ...interface{}) ChainInterface[T] | ||||
| 	Group(name string) ChainInterface[T] | ||||
| 	Having(query interface{}, args ...interface{}) ChainInterface[T] | ||||
| 	Order(value interface{}) ChainInterface[T] | ||||
| 
 | ||||
| 	Build(builder clause.Builder) | ||||
| 
 | ||||
| 	Delete(ctx context.Context) (rowsAffected int, err error) | ||||
| 	Update(ctx context.Context, name string, value any) (rowsAffected int, err error) | ||||
| 	Updates(ctx context.Context, t T) (rowsAffected int, err error) | ||||
| 	Count(ctx context.Context, column string) (result int64, err error) | ||||
| } | ||||
| 
 | ||||
| type ExecInterface[T any] interface { | ||||
| 	Scan(ctx context.Context, r interface{}) error | ||||
| 	First(context.Context) (T, error) | ||||
| 	Last(ctx context.Context) (T, error) | ||||
| 	Take(context.Context) (T, error) | ||||
| 	Find(ctx context.Context) ([]T, error) | ||||
| 	FindInBatches(ctx context.Context, batchSize int, fc func(data []T, batch int) error) error | ||||
| 	Row(ctx context.Context) *sql.Row | ||||
| 	Rows(ctx context.Context) (*sql.Rows, error) | ||||
| } | ||||
| 
 | ||||
| type JoinBuilder interface { | ||||
| 	Select(...string) JoinBuilder | ||||
| 	Omit(...string) JoinBuilder | ||||
| 	Where(query interface{}, args ...interface{}) JoinBuilder | ||||
| 	Not(query interface{}, args ...interface{}) JoinBuilder | ||||
| 	Or(query interface{}, args ...interface{}) JoinBuilder | ||||
| } | ||||
| 
 | ||||
| type PreloadBuilder interface { | ||||
| 	Select(...string) PreloadBuilder | ||||
| 	Omit(...string) PreloadBuilder | ||||
| 	Where(query interface{}, args ...interface{}) PreloadBuilder | ||||
| 	Not(query interface{}, args ...interface{}) PreloadBuilder | ||||
| 	Or(query interface{}, args ...interface{}) PreloadBuilder | ||||
| 	Limit(offset int) PreloadBuilder | ||||
| 	Offset(offset int) PreloadBuilder | ||||
| 	Order(value interface{}) PreloadBuilder | ||||
| 	LimitPerRecord(num int) PreloadBuilder | ||||
| } | ||||
| 
 | ||||
| type op func(*DB) *DB | ||||
| 
 | ||||
| func G[T any](db *DB, opts ...clause.Expression) Interface[T] { | ||||
| 	v := &g[T]{ | ||||
| 		db:  db, | ||||
| 		ops: make([]op, 0, 5), | ||||
| 	} | ||||
| 
 | ||||
| 	if len(opts) > 0 { | ||||
| 		v.ops = append(v.ops, func(db *DB) *DB { | ||||
| 			return db.Clauses(opts...) | ||||
| 		}) | ||||
| 	} | ||||
| 
 | ||||
| 	v.createG = &createG[T]{ | ||||
| 		chainG: chainG[T]{ | ||||
| 			execG: execG[T]{g: v}, | ||||
| 		}, | ||||
| 	} | ||||
| 	return v | ||||
| } | ||||
| 
 | ||||
| type g[T any] struct { | ||||
| 	*createG[T] | ||||
| 	db  *DB | ||||
| 	ops []op | ||||
| } | ||||
| 
 | ||||
| func (g *g[T]) apply(ctx context.Context) *DB { | ||||
| 	db := g.db | ||||
| 	if !db.DryRun { | ||||
| 		db = db.Session(&Session{NewDB: true, Context: ctx}).getInstance() | ||||
| 	} | ||||
| 
 | ||||
| 	for _, op := range g.ops { | ||||
| 		db = op(db) | ||||
| 	} | ||||
| 	return db | ||||
| } | ||||
| 
 | ||||
| func (c *g[T]) Raw(sql string, values ...interface{}) ExecInterface[T] { | ||||
| 	return execG[T]{g: &g[T]{ | ||||
| 		db: c.db, | ||||
| 		ops: append(c.ops, func(db *DB) *DB { | ||||
| 			return db.Raw(sql, values...) | ||||
| 		}), | ||||
| 	}} | ||||
| } | ||||
| 
 | ||||
| func (c *g[T]) Exec(ctx context.Context, sql string, values ...interface{}) error { | ||||
| 	return c.apply(ctx).Exec(sql, values...).Error | ||||
| } | ||||
| 
 | ||||
| type createG[T any] struct { | ||||
| 	chainG[T] | ||||
| } | ||||
| 
 | ||||
| func (c createG[T]) Table(name string, args ...interface{}) CreateInterface[T] { | ||||
| 	return createG[T]{c.with(func(db *DB) *DB { | ||||
| 		return db.Table(name, args...) | ||||
| 	})} | ||||
| } | ||||
| 
 | ||||
| func (c createG[T]) Create(ctx context.Context, r *T) error { | ||||
| 	return c.g.apply(ctx).Create(r).Error | ||||
| } | ||||
| 
 | ||||
| func (c createG[T]) CreateInBatches(ctx context.Context, r *[]T, batchSize int) error { | ||||
| 	return c.g.apply(ctx).CreateInBatches(r, batchSize).Error | ||||
| } | ||||
| 
 | ||||
| type chainG[T any] struct { | ||||
| 	execG[T] | ||||
| } | ||||
| 
 | ||||
| func (c chainG[T]) getInstance() *DB { | ||||
| 	var r T | ||||
| 	return c.g.apply(context.Background()).Model(r).getInstance() | ||||
| } | ||||
| 
 | ||||
| func (c chainG[T]) with(v op) chainG[T] { | ||||
| 	return chainG[T]{ | ||||
| 		execG: execG[T]{g: &g[T]{ | ||||
| 			db:  c.g.db, | ||||
| 			ops: append(append([]op(nil), c.g.ops...), v), | ||||
| 		}}, | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (c chainG[T]) Scopes(scopes ...func(db *Statement)) ChainInterface[T] { | ||||
| 	return c.with(func(db *DB) *DB { | ||||
| 		for _, fc := range scopes { | ||||
| 			fc(db.Statement) | ||||
| 		} | ||||
| 		return db | ||||
| 	}) | ||||
| } | ||||
| 
 | ||||
| func (c chainG[T]) Table(name string, args ...interface{}) ChainInterface[T] { | ||||
| 	return c.with(func(db *DB) *DB { | ||||
| 		return db.Table(name, args...) | ||||
| 	}) | ||||
| } | ||||
| 
 | ||||
| func (c chainG[T]) Where(query interface{}, args ...interface{}) ChainInterface[T] { | ||||
| 	return c.with(func(db *DB) *DB { | ||||
| 		return db.Where(query, args...) | ||||
| 	}) | ||||
| } | ||||
| 
 | ||||
| func (c chainG[T]) Not(query interface{}, args ...interface{}) ChainInterface[T] { | ||||
| 	return c.with(func(db *DB) *DB { | ||||
| 		return db.Not(query, args...) | ||||
| 	}) | ||||
| } | ||||
| 
 | ||||
| func (c chainG[T]) Or(query interface{}, args ...interface{}) ChainInterface[T] { | ||||
| 	return c.with(func(db *DB) *DB { | ||||
| 		return db.Or(query, args...) | ||||
| 	}) | ||||
| } | ||||
| 
 | ||||
| func (c chainG[T]) Limit(offset int) ChainInterface[T] { | ||||
| 	return c.with(func(db *DB) *DB { | ||||
| 		return db.Limit(offset) | ||||
| 	}) | ||||
| } | ||||
| 
 | ||||
| func (c chainG[T]) Offset(offset int) ChainInterface[T] { | ||||
| 	return c.with(func(db *DB) *DB { | ||||
| 		return db.Offset(offset) | ||||
| 	}) | ||||
| } | ||||
| 
 | ||||
| type joinBuilder struct { | ||||
| 	db *DB | ||||
| } | ||||
| 
 | ||||
| func (q *joinBuilder) Where(query interface{}, args ...interface{}) JoinBuilder { | ||||
| 	q.db.Where(query, args...) | ||||
| 	return q | ||||
| } | ||||
| 
 | ||||
| func (q *joinBuilder) Or(query interface{}, args ...interface{}) JoinBuilder { | ||||
| 	q.db.Where(query, args...) | ||||
| 	return q | ||||
| } | ||||
| 
 | ||||
| func (q *joinBuilder) Not(query interface{}, args ...interface{}) JoinBuilder { | ||||
| 	q.db.Where(query, args...) | ||||
| 	return q | ||||
| } | ||||
| 
 | ||||
| func (q *joinBuilder) Select(columns ...string) JoinBuilder { | ||||
| 	q.db.Select(columns) | ||||
| 	return q | ||||
| } | ||||
| 
 | ||||
| func (q *joinBuilder) Omit(columns ...string) JoinBuilder { | ||||
| 	q.db.Omit(columns...) | ||||
| 	return q | ||||
| } | ||||
| 
 | ||||
| type preloadBuilder struct { | ||||
| 	limitPerRecord int | ||||
| 	db             *DB | ||||
| } | ||||
| 
 | ||||
| func (q *preloadBuilder) Where(query interface{}, args ...interface{}) PreloadBuilder { | ||||
| 	q.db.Where(query, args...) | ||||
| 	return q | ||||
| } | ||||
| 
 | ||||
| func (q *preloadBuilder) Or(query interface{}, args ...interface{}) PreloadBuilder { | ||||
| 	q.db.Where(query, args...) | ||||
| 	return q | ||||
| } | ||||
| 
 | ||||
| func (q *preloadBuilder) Not(query interface{}, args ...interface{}) PreloadBuilder { | ||||
| 	q.db.Where(query, args...) | ||||
| 	return q | ||||
| } | ||||
| 
 | ||||
| func (q *preloadBuilder) Select(columns ...string) PreloadBuilder { | ||||
| 	q.db.Select(columns) | ||||
| 	return q | ||||
| } | ||||
| 
 | ||||
| func (q *preloadBuilder) Omit(columns ...string) PreloadBuilder { | ||||
| 	q.db.Omit(columns...) | ||||
| 	return q | ||||
| } | ||||
| 
 | ||||
| func (q *preloadBuilder) Limit(limit int) PreloadBuilder { | ||||
| 	q.db.Limit(limit) | ||||
| 	return q | ||||
| } | ||||
| 
 | ||||
| func (q *preloadBuilder) Offset(offset int) PreloadBuilder { | ||||
| 	q.db.Offset(offset) | ||||
| 	return q | ||||
| } | ||||
| 
 | ||||
| func (q *preloadBuilder) Order(value interface{}) PreloadBuilder { | ||||
| 	q.db.Order(value) | ||||
| 	return q | ||||
| } | ||||
| 
 | ||||
| func (q *preloadBuilder) LimitPerRecord(num int) PreloadBuilder { | ||||
| 	q.limitPerRecord = num | ||||
| 	return q | ||||
| } | ||||
| 
 | ||||
| func (c chainG[T]) Joins(jt clause.JoinTarget, on func(db JoinBuilder, joinTable clause.Table, curTable clause.Table) error) ChainInterface[T] { | ||||
| 	return c.with(func(db *DB) *DB { | ||||
| 		if jt.Table == "" { | ||||
| 			jt.Table = clause.JoinTable(strings.Split(jt.Association, ".")...).Name | ||||
| 		} | ||||
| 
 | ||||
| 		q := joinBuilder{db: db.Session(&Session{NewDB: true, Initialized: true}).Table(jt.Table)} | ||||
| 		if on != nil { | ||||
| 			if err := on(&q, clause.Table{Name: jt.Table}, clause.Table{Name: clause.CurrentTable}); err != nil { | ||||
| 				db.AddError(err) | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		j := join{ | ||||
| 			Name:     jt.Association, | ||||
| 			Alias:    jt.Table, | ||||
| 			Selects:  q.db.Statement.Selects, | ||||
| 			Omits:    q.db.Statement.Omits, | ||||
| 			JoinType: jt.Type, | ||||
| 		} | ||||
| 
 | ||||
| 		if where, ok := q.db.Statement.Clauses["WHERE"].Expression.(clause.Where); ok { | ||||
| 			j.On = &where | ||||
| 		} | ||||
| 
 | ||||
| 		if jt.Subquery != nil { | ||||
| 			joinType := j.JoinType | ||||
| 			if joinType == "" { | ||||
| 				joinType = clause.LeftJoin | ||||
| 			} | ||||
| 
 | ||||
| 			if db, ok := jt.Subquery.(interface{ getInstance() *DB }); ok { | ||||
| 				stmt := db.getInstance().Statement | ||||
| 				if len(j.Selects) == 0 { | ||||
| 					j.Selects = stmt.Selects | ||||
| 				} | ||||
| 				if len(j.Omits) == 0 { | ||||
| 					j.Omits = stmt.Omits | ||||
| 				} | ||||
| 			} | ||||
| 
 | ||||
| 			expr := clause.NamedExpr{SQL: fmt.Sprintf("%s JOIN (?) AS ?", joinType), Vars: []interface{}{jt.Subquery, clause.Table{Name: j.Alias}}} | ||||
| 
 | ||||
| 			if j.On != nil { | ||||
| 				expr.SQL += " ON ?" | ||||
| 				expr.Vars = append(expr.Vars, clause.AndConditions{Exprs: j.On.Exprs}) | ||||
| 			} | ||||
| 
 | ||||
| 			j.Expression = expr | ||||
| 		} | ||||
| 
 | ||||
| 		db.Statement.Joins = append(db.Statement.Joins, j) | ||||
| 		sort.Slice(db.Statement.Joins, func(i, j int) bool { | ||||
| 			return db.Statement.Joins[i].Name < db.Statement.Joins[j].Name | ||||
| 		}) | ||||
| 		return db | ||||
| 	}) | ||||
| } | ||||
| 
 | ||||
| func (c chainG[T]) Select(query string, args ...interface{}) ChainInterface[T] { | ||||
| 	return c.with(func(db *DB) *DB { | ||||
| 		return db.Select(query, args...) | ||||
| 	}) | ||||
| } | ||||
| 
 | ||||
| func (c chainG[T]) Omit(columns ...string) ChainInterface[T] { | ||||
| 	return c.with(func(db *DB) *DB { | ||||
| 		return db.Omit(columns...) | ||||
| 	}) | ||||
| } | ||||
| 
 | ||||
| func (c chainG[T]) MapColumns(m map[string]string) ChainInterface[T] { | ||||
| 	return c.with(func(db *DB) *DB { | ||||
| 		return db.MapColumns(m) | ||||
| 	}) | ||||
| } | ||||
| 
 | ||||
| func (c chainG[T]) Distinct(args ...interface{}) ChainInterface[T] { | ||||
| 	return c.with(func(db *DB) *DB { | ||||
| 		return db.Distinct(args...) | ||||
| 	}) | ||||
| } | ||||
| 
 | ||||
| func (c chainG[T]) Group(name string) ChainInterface[T] { | ||||
| 	return c.with(func(db *DB) *DB { | ||||
| 		return db.Group(name) | ||||
| 	}) | ||||
| } | ||||
| 
 | ||||
| func (c chainG[T]) Having(query interface{}, args ...interface{}) ChainInterface[T] { | ||||
| 	return c.with(func(db *DB) *DB { | ||||
| 		return db.Having(query, args...) | ||||
| 	}) | ||||
| } | ||||
| 
 | ||||
| func (c chainG[T]) Order(value interface{}) ChainInterface[T] { | ||||
| 	return c.with(func(db *DB) *DB { | ||||
| 		return db.Order(value) | ||||
| 	}) | ||||
| } | ||||
| 
 | ||||
| func (c chainG[T]) Preload(association string, query func(db PreloadBuilder) error) ChainInterface[T] { | ||||
| 	return c.with(func(db *DB) *DB { | ||||
| 		return db.Preload(association, func(tx *DB) *DB { | ||||
| 			q := preloadBuilder{db: tx.getInstance()} | ||||
| 			if query != nil { | ||||
| 				if err := query(&q); err != nil { | ||||
| 					db.AddError(err) | ||||
| 				} | ||||
| 			} | ||||
| 
 | ||||
| 			relation, ok := db.Statement.Schema.Relationships.Relations[association] | ||||
| 			if !ok { | ||||
| 				if preloadFields := strings.Split(association, "."); len(preloadFields) > 1 { | ||||
| 					relationships := db.Statement.Schema.Relationships | ||||
| 					for _, field := range preloadFields { | ||||
| 						var ok bool | ||||
| 						relation, ok = relationships.Relations[field] | ||||
| 						if ok { | ||||
| 							relationships = relation.FieldSchema.Relationships | ||||
| 						} else { | ||||
| 							db.AddError(fmt.Errorf("relation %s not found", association)) | ||||
| 							return nil | ||||
| 						} | ||||
| 					} | ||||
| 				} else { | ||||
| 					db.AddError(fmt.Errorf("relation %s not found", association)) | ||||
| 					return nil | ||||
| 				} | ||||
| 			} | ||||
| 
 | ||||
| 			if q.limitPerRecord > 0 { | ||||
| 				if relation.JoinTable != nil { | ||||
| 					tx.AddError(fmt.Errorf("many2many relation %s don't support LimitPerRecord", association)) | ||||
| 					return tx | ||||
| 				} | ||||
| 
 | ||||
| 				refColumns := []clause.Column{} | ||||
| 				for _, rel := range relation.References { | ||||
| 					if rel.OwnPrimaryKey { | ||||
| 						refColumns = append(refColumns, clause.Column{Name: rel.ForeignKey.DBName}) | ||||
| 					} | ||||
| 				} | ||||
| 
 | ||||
| 				if len(refColumns) != 0 { | ||||
| 					selectExpr := clause.CommaExpression{} | ||||
| 					for _, column := range q.db.Statement.Selects { | ||||
| 						selectExpr.Exprs = append(selectExpr.Exprs, clause.Expr{SQL: "?", Vars: []interface{}{clause.Column{Name: column}}}) | ||||
| 					} | ||||
| 
 | ||||
| 					if len(selectExpr.Exprs) == 0 { | ||||
| 						selectExpr.Exprs = []clause.Expression{clause.Expr{SQL: "*", Vars: []interface{}{}}} | ||||
| 					} | ||||
| 
 | ||||
| 					partitionBy := clause.CommaExpression{} | ||||
| 					for _, column := range refColumns { | ||||
| 						partitionBy.Exprs = append(partitionBy.Exprs, clause.Expr{SQL: "?", Vars: []interface{}{clause.Column{Name: column.Name}}}) | ||||
| 					} | ||||
| 
 | ||||
| 					rnnColumn := clause.Column{Name: "gorm_preload_rnn"} | ||||
| 					sql := "ROW_NUMBER() OVER (PARTITION BY ? ?)" | ||||
| 					vars := []interface{}{partitionBy} | ||||
| 					if orderBy, ok := q.db.Statement.Clauses["ORDER BY"]; ok { | ||||
| 						vars = append(vars, orderBy) | ||||
| 					} else { | ||||
| 						vars = append(vars, clause.Clause{Name: "ORDER BY", Expression: clause.OrderBy{ | ||||
| 							Columns: []clause.OrderByColumn{{Column: clause.PrimaryColumn, Desc: true}}, | ||||
| 						}}) | ||||
| 					} | ||||
| 					vars = append(vars, rnnColumn) | ||||
| 
 | ||||
| 					selectExpr.Exprs = append(selectExpr.Exprs, clause.Expr{SQL: sql + " AS ?", Vars: vars}) | ||||
| 
 | ||||
| 					q.db.Clauses(clause.Select{Expression: selectExpr}) | ||||
| 
 | ||||
| 					return q.db.Session(&Session{NewDB: true}).Unscoped().Table("(?) t", q.db).Where("? <= ?", rnnColumn, q.limitPerRecord) | ||||
| 				} | ||||
| 			} | ||||
| 
 | ||||
| 			return q.db | ||||
| 		}) | ||||
| 	}) | ||||
| } | ||||
| 
 | ||||
| func (c chainG[T]) Delete(ctx context.Context) (rowsAffected int, err error) { | ||||
| 	r := new(T) | ||||
| 	res := c.g.apply(ctx).Delete(r) | ||||
| 	return int(res.RowsAffected), res.Error | ||||
| } | ||||
| 
 | ||||
| func (c chainG[T]) Update(ctx context.Context, name string, value any) (rowsAffected int, err error) { | ||||
| 	var r T | ||||
| 	res := c.g.apply(ctx).Model(r).Update(name, value) | ||||
| 	return int(res.RowsAffected), res.Error | ||||
| } | ||||
| 
 | ||||
| func (c chainG[T]) Updates(ctx context.Context, t T) (rowsAffected int, err error) { | ||||
| 	res := c.g.apply(ctx).Updates(t) | ||||
| 	return int(res.RowsAffected), res.Error | ||||
| } | ||||
| 
 | ||||
| func (c chainG[T]) Count(ctx context.Context, column string) (result int64, err error) { | ||||
| 	var r T | ||||
| 	err = c.g.apply(ctx).Model(r).Select(column).Count(&result).Error | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func (c chainG[T]) Build(builder clause.Builder) { | ||||
| 	subdb := c.getInstance() | ||||
| 	subdb.Logger = logger.Discard | ||||
| 	subdb.DryRun = true | ||||
| 
 | ||||
| 	if stmt, ok := builder.(*Statement); ok { | ||||
| 		if subdb.Statement.SQL.Len() > 0 { | ||||
| 			var ( | ||||
| 				vars = subdb.Statement.Vars | ||||
| 				sql  = subdb.Statement.SQL.String() | ||||
| 			) | ||||
| 
 | ||||
| 			subdb.Statement.Vars = make([]interface{}, 0, len(vars)) | ||||
| 			for _, vv := range vars { | ||||
| 				subdb.Statement.Vars = append(subdb.Statement.Vars, vv) | ||||
| 				bindvar := strings.Builder{} | ||||
| 				subdb.BindVarTo(&bindvar, subdb.Statement, vv) | ||||
| 				sql = strings.Replace(sql, bindvar.String(), "?", 1) | ||||
| 			} | ||||
| 
 | ||||
| 			subdb.Statement.SQL.Reset() | ||||
| 			subdb.Statement.Vars = stmt.Vars | ||||
| 			if strings.Contains(sql, "@") { | ||||
| 				clause.NamedExpr{SQL: sql, Vars: vars}.Build(subdb.Statement) | ||||
| 			} else { | ||||
| 				clause.Expr{SQL: sql, Vars: vars}.Build(subdb.Statement) | ||||
| 			} | ||||
| 		} else { | ||||
| 			subdb.Statement.Vars = append(stmt.Vars, subdb.Statement.Vars...) | ||||
| 			subdb.callbacks.Query().Execute(subdb) | ||||
| 		} | ||||
| 
 | ||||
| 		builder.WriteString(subdb.Statement.SQL.String()) | ||||
| 		stmt.Vars = subdb.Statement.Vars | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| type execG[T any] struct { | ||||
| 	g *g[T] | ||||
| } | ||||
| 
 | ||||
| func (g execG[T]) First(ctx context.Context) (T, error) { | ||||
| 	var r T | ||||
| 	err := g.g.apply(ctx).First(&r).Error | ||||
| 	return r, err | ||||
| } | ||||
| 
 | ||||
| func (g execG[T]) Scan(ctx context.Context, result interface{}) error { | ||||
| 	var r T | ||||
| 	err := g.g.apply(ctx).Model(r).Find(result).Error | ||||
| 	return err | ||||
| } | ||||
| 
 | ||||
| func (g execG[T]) Last(ctx context.Context) (T, error) { | ||||
| 	var r T | ||||
| 	err := g.g.apply(ctx).Last(&r).Error | ||||
| 	return r, err | ||||
| } | ||||
| 
 | ||||
| func (g execG[T]) Take(ctx context.Context) (T, error) { | ||||
| 	var r T | ||||
| 	err := g.g.apply(ctx).Take(&r).Error | ||||
| 	return r, err | ||||
| } | ||||
| 
 | ||||
| func (g execG[T]) Find(ctx context.Context) ([]T, error) { | ||||
| 	var r []T | ||||
| 	err := g.g.apply(ctx).Find(&r).Error | ||||
| 	return r, err | ||||
| } | ||||
| 
 | ||||
| func (g execG[T]) FindInBatches(ctx context.Context, batchSize int, fc func(data []T, batch int) error) error { | ||||
| 	var data []T | ||||
| 	return g.g.apply(ctx).FindInBatches(&data, batchSize, func(tx *DB, batch int) error { | ||||
| 		return fc(data, batch) | ||||
| 	}).Error | ||||
| } | ||||
| 
 | ||||
| func (g execG[T]) Row(ctx context.Context) *sql.Row { | ||||
| 	return g.g.apply(ctx).Row() | ||||
| } | ||||
| 
 | ||||
| func (g execG[T]) Rows(ctx context.Context) (*sql.Rows, error) { | ||||
| 	return g.g.apply(ctx).Rows() | ||||
| } | ||||
							
								
								
									
										12
									
								
								gorm.go
									
									
									
									
									
								
							
							
						
						
									
										12
									
								
								gorm.go
									
									
									
									
									
								
							| @ -22,6 +22,8 @@ type Config struct { | ||||
| 	// GORM perform single create, update, delete operations in transactions by default to ensure database data integrity
 | ||||
| 	// You can disable it by setting `SkipDefaultTransaction` to true
 | ||||
| 	SkipDefaultTransaction    bool | ||||
| 	DefaultTransactionTimeout time.Duration | ||||
| 
 | ||||
| 	// NamingStrategy tables, columns naming strategy
 | ||||
| 	NamingStrategy schema.Namer | ||||
| 	// FullSaveAssociations full save associations
 | ||||
| @ -135,12 +137,16 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) { | ||||
| 		return isConfig && !isConfig2 | ||||
| 	}) | ||||
| 
 | ||||
| 	var skipAfterInitialize bool | ||||
| 	for _, opt := range opts { | ||||
| 		if opt != nil { | ||||
| 			if applyErr := opt.Apply(config); applyErr != nil { | ||||
| 				return nil, applyErr | ||||
| 			} | ||||
| 			defer func(opt Option) { | ||||
| 				if skipAfterInitialize { | ||||
| 					return | ||||
| 				} | ||||
| 				if errr := opt.AfterInitialize(db); errr != nil { | ||||
| 					err = errr | ||||
| 				} | ||||
| @ -192,6 +198,10 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) { | ||||
| 			if db, _ := db.DB(); db != nil { | ||||
| 				_ = db.Close() | ||||
| 			} | ||||
| 
 | ||||
| 			// DB is not initialized, so we skip AfterInitialize
 | ||||
| 			skipAfterInitialize = true | ||||
| 			return | ||||
| 		} | ||||
| 
 | ||||
| 		if config.TranslateError { | ||||
| @ -519,7 +529,7 @@ func (db *DB) Use(plugin Plugin) error { | ||||
| //				.First(&User{})
 | ||||
| //	})
 | ||||
| func (db *DB) ToSQL(queryFn func(tx *DB) *DB) string { | ||||
| 	tx := queryFn(db.Session(&Session{DryRun: true, SkipDefaultTransaction: true})) | ||||
| 	tx := queryFn(db.Session(&Session{DryRun: true, SkipDefaultTransaction: true}).getInstance()) | ||||
| 	stmt := tx.Statement | ||||
| 
 | ||||
| 	return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...) | ||||
|  | ||||
							
								
								
									
										9
									
								
								scan.go
									
									
									
									
									
								
							
							
						
						
									
										9
									
								
								scan.go
									
									
									
									
									
								
							| @ -4,6 +4,7 @@ import ( | ||||
| 	"database/sql" | ||||
| 	"database/sql/driver" | ||||
| 	"reflect" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"gorm.io/gorm/schema" | ||||
| @ -244,6 +245,14 @@ func Scan(rows Rows, db *DB, mode ScanMode) { | ||||
| 							matchedFieldCount[column] = 1 | ||||
| 						} | ||||
| 					} else if names := utils.SplitNestedRelationName(column); len(names) > 1 { // has nested relation
 | ||||
| 						aliasName := utils.JoinNestedRelationNames(names[0 : len(names)-1]) | ||||
| 						for _, join := range db.Statement.Joins { | ||||
| 							if join.Alias == aliasName { | ||||
| 								names = append(strings.Split(join.Name, "."), names[len(names)-1]) | ||||
| 								break | ||||
| 							} | ||||
| 						} | ||||
| 
 | ||||
| 						if rel, ok := sch.Relationships.Relations[names[0]]; ok { | ||||
| 							subNameCount := len(names) | ||||
| 							// nested relation fields
 | ||||
|  | ||||
| @ -318,9 +318,10 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { | ||||
| 	} | ||||
| 
 | ||||
| 	if val, ok := field.TagSettings["TYPE"]; ok { | ||||
| 		switch DataType(strings.ToLower(val)) { | ||||
| 		lowerVal := DataType(strings.ToLower(val)) | ||||
| 		switch lowerVal { | ||||
| 		case Bool, Int, Uint, Float, String, Time, Bytes: | ||||
| 			field.DataType = DataType(strings.ToLower(val)) | ||||
| 			field.DataType = lowerVal | ||||
| 		default: | ||||
| 			field.DataType = DataType(val) | ||||
| 		} | ||||
|  | ||||
| @ -105,7 +105,7 @@ func parseFieldIndexes(field *Field) (indexes []Index, err error) { | ||||
| 				var ( | ||||
| 					name       string | ||||
| 					tag        = strings.Join(v[1:], ":") | ||||
| 					idx        = strings.Index(tag, ",") | ||||
| 					idx        = strings.IndexByte(tag, ',') | ||||
| 					tagSetting = strings.Join(strings.Split(tag, ",")[1:], ",") | ||||
| 					settings   = ParseTagSetting(tagSetting, ",") | ||||
| 					length, _  = strconv.Atoi(settings["LENGTH"]) | ||||
|  | ||||
| @ -78,7 +78,7 @@ func (schema *Schema) parseRelation(field *Field) *Relationship { | ||||
| 	cacheStore := schema.cacheStore | ||||
| 
 | ||||
| 	if relation.FieldSchema, err = getOrParse(fieldValue, cacheStore, schema.namer); err != nil { | ||||
| 		schema.err = err | ||||
| 		schema.err = fmt.Errorf("failed to parse field: %s, error: %w", field.Name, err) | ||||
| 		return nil | ||||
| 	} | ||||
| 
 | ||||
| @ -663,6 +663,7 @@ func (rel *Relationship) ParseConstraint() *Constraint { | ||||
| 					if !(rel.References[idx].PrimaryKey == ref.PrimaryKey && rel.References[idx].ForeignKey == ref.ForeignKey && | ||||
| 						rel.References[idx].PrimaryValue == ref.PrimaryValue) { | ||||
| 						matched = false | ||||
| 						break | ||||
| 					} | ||||
| 				} | ||||
| 
 | ||||
| @ -675,7 +676,7 @@ func (rel *Relationship) ParseConstraint() *Constraint { | ||||
| 
 | ||||
| 	var ( | ||||
| 		name     string | ||||
| 		idx      = strings.Index(str, ",") | ||||
| 		idx      = strings.IndexByte(str, ',') | ||||
| 		settings = ParseTagSetting(str, ",") | ||||
| 	) | ||||
| 
 | ||||
| @ -762,8 +763,9 @@ func (rel *Relationship) ToQueryConditions(ctx context.Context, reflectValue ref | ||||
| } | ||||
| 
 | ||||
| func copyableDataType(str DataType) bool { | ||||
| 	lowerStr := strings.ToLower(string(str)) | ||||
| 	for _, s := range []string{"auto_increment", "primary key"} { | ||||
| 		if strings.Contains(strings.ToLower(string(str)), s) { | ||||
| 		if strings.Contains(lowerStr, s) { | ||||
| 			return false | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
							
								
								
									
										45
									
								
								statement.go
									
									
									
									
									
								
							
							
						
						
									
										45
									
								
								statement.go
									
									
									
									
									
								
							| @ -47,14 +47,17 @@ type Statement struct { | ||||
| 	attrs                []interface{} | ||||
| 	assigns              []interface{} | ||||
| 	scopes               []func(*DB) *DB | ||||
| 	Result               *result | ||||
| } | ||||
| 
 | ||||
| type join struct { | ||||
| 	Name       string | ||||
| 	Alias      string | ||||
| 	Conds      []interface{} | ||||
| 	On         *clause.Where | ||||
| 	Selects    []string | ||||
| 	Omits      []string | ||||
| 	Expression clause.Expression | ||||
| 	JoinType   clause.JoinType | ||||
| } | ||||
| 
 | ||||
| @ -205,19 +208,21 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) { | ||||
| 			} else { | ||||
| 				writer.WriteString("(NULL)") | ||||
| 			} | ||||
| 		case *DB: | ||||
| 			subdb := v.Session(&Session{Logger: logger.Discard, DryRun: true}).getInstance() | ||||
| 			if v.Statement.SQL.Len() > 0 { | ||||
| 		case interface{ getInstance() *DB }: | ||||
| 			cv := v.getInstance() | ||||
| 
 | ||||
| 			subdb := cv.Session(&Session{Logger: logger.Discard, DryRun: true}).getInstance() | ||||
| 			if cv.Statement.SQL.Len() > 0 { | ||||
| 				var ( | ||||
| 					vars = subdb.Statement.Vars | ||||
| 					sql  = v.Statement.SQL.String() | ||||
| 					sql  = cv.Statement.SQL.String() | ||||
| 				) | ||||
| 
 | ||||
| 				subdb.Statement.Vars = make([]interface{}, 0, len(vars)) | ||||
| 				for _, vv := range vars { | ||||
| 					subdb.Statement.Vars = append(subdb.Statement.Vars, vv) | ||||
| 					bindvar := strings.Builder{} | ||||
| 					v.Dialector.BindVarTo(&bindvar, subdb.Statement, vv) | ||||
| 					cv.BindVarTo(&bindvar, subdb.Statement, vv) | ||||
| 					sql = strings.Replace(sql, bindvar.String(), "?", 1) | ||||
| 				} | ||||
| 
 | ||||
| @ -321,6 +326,11 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] | ||||
| 			arg, _ = valuer.Value() | ||||
| 		} | ||||
| 
 | ||||
| 		curTable := stmt.Table | ||||
| 		if curTable == "" { | ||||
| 			curTable = clause.CurrentTable | ||||
| 		} | ||||
| 
 | ||||
| 		switch v := arg.(type) { | ||||
| 		case clause.Expression: | ||||
| 			conds = append(conds, v) | ||||
| @ -351,7 +361,8 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] | ||||
| 			sort.Strings(keys) | ||||
| 
 | ||||
| 			for _, key := range keys { | ||||
| 				conds = append(conds, clause.Eq{Column: key, Value: v[key]}) | ||||
| 				column := clause.Column{Name: key, Table: curTable} | ||||
| 				conds = append(conds, clause.Eq{Column: column, Value: v[key]}) | ||||
| 			} | ||||
| 		case map[string]interface{}: | ||||
| 			keys := make([]string, 0, len(v)) | ||||
| @ -362,12 +373,13 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] | ||||
| 
 | ||||
| 			for _, key := range keys { | ||||
| 				reflectValue := reflect.Indirect(reflect.ValueOf(v[key])) | ||||
| 				column := clause.Column{Name: key, Table: curTable} | ||||
| 				switch reflectValue.Kind() { | ||||
| 				case reflect.Slice, reflect.Array: | ||||
| 					if _, ok := v[key].(driver.Valuer); ok { | ||||
| 						conds = append(conds, clause.Eq{Column: key, Value: v[key]}) | ||||
| 						conds = append(conds, clause.Eq{Column: column, Value: v[key]}) | ||||
| 					} else if _, ok := v[key].(Valuer); ok { | ||||
| 						conds = append(conds, clause.Eq{Column: key, Value: v[key]}) | ||||
| 						conds = append(conds, clause.Eq{Column: column, Value: v[key]}) | ||||
| 					} else { | ||||
| 						// optimize reflect value length
 | ||||
| 						valueLen := reflectValue.Len() | ||||
| @ -376,10 +388,10 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] | ||||
| 							values[i] = reflectValue.Index(i).Interface() | ||||
| 						} | ||||
| 
 | ||||
| 						conds = append(conds, clause.IN{Column: key, Values: values}) | ||||
| 						conds = append(conds, clause.IN{Column: column, Values: values}) | ||||
| 					} | ||||
| 				default: | ||||
| 					conds = append(conds, clause.Eq{Column: key, Value: v[key]}) | ||||
| 					conds = append(conds, clause.Eq{Column: column, Value: v[key]}) | ||||
| 				} | ||||
| 			} | ||||
| 		default: | ||||
| @ -406,9 +418,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] | ||||
| 						if selected || (!restricted && field.Readable) { | ||||
| 							if v, isZero := field.ValueOf(stmt.Context, reflectValue); !isZero || selected { | ||||
| 								if field.DBName != "" { | ||||
| 									conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v}) | ||||
| 									conds = append(conds, clause.Eq{Column: clause.Column{Table: curTable, Name: field.DBName}, Value: v}) | ||||
| 								} else if field.DataType != "" { | ||||
| 									conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v}) | ||||
| 									conds = append(conds, clause.Eq{Column: clause.Column{Table: curTable, Name: field.Name}, Value: v}) | ||||
| 								} | ||||
| 							} | ||||
| 						} | ||||
| @ -420,9 +432,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] | ||||
| 							if selected || (!restricted && field.Readable) { | ||||
| 								if v, isZero := field.ValueOf(stmt.Context, reflectValue.Index(i)); !isZero || selected { | ||||
| 									if field.DBName != "" { | ||||
| 										conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v}) | ||||
| 										conds = append(conds, clause.Eq{Column: clause.Column{Table: curTable, Name: field.DBName}, Value: v}) | ||||
| 									} else if field.DataType != "" { | ||||
| 										conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v}) | ||||
| 										conds = append(conds, clause.Eq{Column: clause.Column{Table: curTable, Name: field.Name}, Value: v}) | ||||
| 									} | ||||
| 								} | ||||
| 							} | ||||
| @ -447,14 +459,14 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] | ||||
| 						} | ||||
| 
 | ||||
| 						if len(values) > 0 { | ||||
| 							conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: values}) | ||||
| 							conds = append(conds, clause.IN{Column: clause.Column{Table: curTable, Name: clause.PrimaryKey}, Values: values}) | ||||
| 							return []clause.Expression{clause.And(conds...)} | ||||
| 						} | ||||
| 						return nil | ||||
| 					} | ||||
| 				} | ||||
| 
 | ||||
| 				conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: args}) | ||||
| 				conds = append(conds, clause.IN{Column: clause.Column{Table: curTable, Name: clause.PrimaryKey}, Values: args}) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| @ -521,6 +533,7 @@ func (stmt *Statement) clone() *Statement { | ||||
| 		Context:              stmt.Context, | ||||
| 		RaiseErrorOnNotFound: stmt.RaiseErrorOnNotFound, | ||||
| 		SkipHooks:            stmt.SkipHooks, | ||||
| 		Result:               stmt.Result, | ||||
| 	} | ||||
| 
 | ||||
| 	if stmt.SQL.Len() > 0 { | ||||
|  | ||||
| @ -18,7 +18,7 @@ services: | ||||
|       - POSTGRES_USER=gorm | ||||
|       - POSTGRES_PASSWORD=gorm | ||||
|   mssql: | ||||
|     image: '${MSSQL_IMAGE}:2022-latest' | ||||
|     image: '${MSSQL_IMAGE}:latest' | ||||
|     ports: | ||||
|       - "127.0.0.1:9930:1433" | ||||
|     environment: | ||||
|  | ||||
| @ -119,6 +119,7 @@ func TestConnPoolWrapper(t *testing.T) { | ||||
| 	}() | ||||
| 
 | ||||
| 	db, err := gorm.Open(mysql.New(mysql.Config{Conn: conn, DisableWithReturning: true})) | ||||
| 	db.Logger = DB.Logger | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Should open db success, but got %v", err) | ||||
| 	} | ||||
|  | ||||
							
								
								
									
										875
									
								
								tests/generics_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										875
									
								
								tests/generics_test.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,875 @@ | ||||
| package tests_test | ||||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"reflect" | ||||
| 	"regexp" | ||||
| 	"sort" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| 	"sync" | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"github.com/google/uuid" | ||||
| 	"gorm.io/driver/mysql" | ||||
| 	"gorm.io/gorm" | ||||
| 	"gorm.io/gorm/clause" | ||||
| 	. "gorm.io/gorm/utils/tests" | ||||
| ) | ||||
| 
 | ||||
| func TestGenericsCreate(t *testing.T) { | ||||
| 	ctx := context.Background() | ||||
| 
 | ||||
| 	user := User{Name: "TestGenericsCreate", Age: 18} | ||||
| 	err := gorm.G[User](DB).Create(ctx, &user) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Create failed: %v", err) | ||||
| 	} | ||||
| 	if user.ID == 0 { | ||||
| 		t.Fatalf("no primary key found for %v", user) | ||||
| 	} | ||||
| 
 | ||||
| 	if u, err := gorm.G[User](DB).Where("name = ?", user.Name).First(ctx); err != nil { | ||||
| 		t.Fatalf("failed to find user, got error: %v", err) | ||||
| 	} else if u.Name != user.Name || u.ID != user.ID { | ||||
| 		t.Errorf("found invalid user, got %v, expect %v", u, user) | ||||
| 	} | ||||
| 
 | ||||
| 	if u, err := gorm.G[User](DB).Where("name = ?", user.Name).Take(ctx); err != nil { | ||||
| 		t.Fatalf("failed to find user, got error: %v", err) | ||||
| 	} else if u.Name != user.Name || u.ID != user.ID { | ||||
| 		t.Errorf("found invalid user, got %v, expect %v", u, user) | ||||
| 	} | ||||
| 
 | ||||
| 	if u, err := gorm.G[User](DB).Select("name").Where("name = ?", user.Name).First(ctx); err != nil { | ||||
| 		t.Fatalf("failed to find user, got error: %v", err) | ||||
| 	} else if u.Name != user.Name || u.Age != 0 { | ||||
| 		t.Errorf("found invalid user, got %v, expect %v", u, user) | ||||
| 	} | ||||
| 
 | ||||
| 	if u, err := gorm.G[User](DB).Omit("name").Where("name = ?", user.Name).First(ctx); err != nil { | ||||
| 		t.Fatalf("failed to find user, got error: %v", err) | ||||
| 	} else if u.Name != "" || u.Age != user.Age { | ||||
| 		t.Errorf("found invalid user, got %v, expect %v", u, user) | ||||
| 	} | ||||
| 
 | ||||
| 	result := struct { | ||||
| 		ID   int | ||||
| 		Name string | ||||
| 	}{} | ||||
| 	if err := gorm.G[User](DB).Where("name = ?", user.Name).Scan(ctx, &result); err != nil { | ||||
| 		t.Fatalf("failed to scan user, got error: %v", err) | ||||
| 	} else if result.Name != user.Name || uint(result.ID) != user.ID { | ||||
| 		t.Errorf("found invalid user, got %v, expect %v", result, user) | ||||
| 	} | ||||
| 
 | ||||
| 	mapResult, err := gorm.G[map[string]interface{}](DB).Table("users").Where("name = ?", user.Name).MapColumns(map[string]string{"name": "user_name"}).Take(ctx) | ||||
| 	if v := mapResult["user_name"]; fmt.Sprint(v) != user.Name { | ||||
| 		t.Errorf("failed to find map results, got %v, err %v", mapResult, err) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestGenericsCreateInBatches(t *testing.T) { | ||||
| 	batch := []User{ | ||||
| 		{Name: "GenericsCreateInBatches1"}, | ||||
| 		{Name: "GenericsCreateInBatches2"}, | ||||
| 		{Name: "GenericsCreateInBatches3"}, | ||||
| 	} | ||||
| 	ctx := context.Background() | ||||
| 
 | ||||
| 	if err := gorm.G[User](DB).CreateInBatches(ctx, &batch, 2); err != nil { | ||||
| 		t.Fatalf("CreateInBatches failed: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	for _, u := range batch { | ||||
| 		if u.ID == 0 { | ||||
| 			t.Fatalf("no primary key found for %v", u) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	count, err := gorm.G[User](DB).Where("name like ?", "GenericsCreateInBatches%").Count(ctx, "*") | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Count failed: %v", err) | ||||
| 	} | ||||
| 	if count != 3 { | ||||
| 		t.Errorf("expected 3 records, got %d", count) | ||||
| 	} | ||||
| 
 | ||||
| 	found, err := gorm.G[User](DB).Raw("SELECT * FROM users WHERE name LIKE ?", "GenericsCreateInBatches%").Find(ctx) | ||||
| 	if len(found) != len(batch) { | ||||
| 		t.Errorf("expected %d from Raw Find, got %d", len(batch), len(found)) | ||||
| 	} | ||||
| 
 | ||||
| 	found, err = gorm.G[User](DB).Where("name like ?", "GenericsCreateInBatches%").Limit(2).Find(ctx) | ||||
| 	if len(found) != 2 { | ||||
| 		t.Errorf("expected %d from Raw Find, got %d", 2, len(found)) | ||||
| 	} | ||||
| 
 | ||||
| 	found, err = gorm.G[User](DB).Where("name like ?", "GenericsCreateInBatches%").Offset(2).Limit(2).Find(ctx) | ||||
| 	if len(found) != 1 { | ||||
| 		t.Errorf("expected %d from Raw Find, got %d", 1, len(found)) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestGenericsExecAndUpdate(t *testing.T) { | ||||
| 	ctx := context.Background() | ||||
| 
 | ||||
| 	name := "GenericsExec" | ||||
| 	if err := gorm.G[User](DB).Exec(ctx, "INSERT INTO users(name) VALUES(?)", name); err != nil { | ||||
| 		t.Fatalf("Exec insert failed: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	u, err := gorm.G[User](DB).Table("users as u").Where("u.name = ?", name).First(ctx) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("failed to find user, got error: %v", err) | ||||
| 	} else if u.Name != name || u.ID == 0 { | ||||
| 		t.Errorf("found invalid user, got %v", u) | ||||
| 	} | ||||
| 
 | ||||
| 	name += "Update" | ||||
| 	rows, err := gorm.G[User](DB).Where("id = ?", u.ID).Update(ctx, "name", name) | ||||
| 	if rows != 1 { | ||||
| 		t.Fatalf("failed to get affected rows, got %d, should be %d", rows, 1) | ||||
| 	} | ||||
| 
 | ||||
| 	nu, err := gorm.G[User](DB).Where("name = ?", name).First(ctx) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("failed to find user, got error: %v", err) | ||||
| 	} else if nu.Name != name || u.ID != nu.ID { | ||||
| 		t.Fatalf("found invalid user, got %v, expect %v", nu.ID, u.ID) | ||||
| 	} | ||||
| 
 | ||||
| 	rows, err = gorm.G[User](DB).Where("id = ?", u.ID).Updates(ctx, User{Name: "GenericsExecUpdates", Age: 18}) | ||||
| 	if rows != 1 { | ||||
| 		t.Fatalf("failed to get affected rows, got %d, should be %d", rows, 1) | ||||
| 	} | ||||
| 
 | ||||
| 	nu, err = gorm.G[User](DB).Where("id = ?", u.ID).Last(ctx) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("failed to find user, got error: %v", err) | ||||
| 	} else if nu.Name != "GenericsExecUpdates" || nu.Age != 18 || u.ID != nu.ID { | ||||
| 		t.Fatalf("found invalid user, got %v, expect %v", nu.ID, u.ID) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestGenericsRow(t *testing.T) { | ||||
| 	ctx := context.Background() | ||||
| 
 | ||||
| 	user := User{Name: "GenericsRow"} | ||||
| 	if err := gorm.G[User](DB).Create(ctx, &user); err != nil { | ||||
| 		t.Fatalf("Create failed: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	row := gorm.G[User](DB).Raw("SELECT name FROM users WHERE id = ?", user.ID).Row(ctx) | ||||
| 	var name string | ||||
| 	if err := row.Scan(&name); err != nil { | ||||
| 		t.Fatalf("Row scan failed: %v", err) | ||||
| 	} | ||||
| 	if name != user.Name { | ||||
| 		t.Errorf("expected %s, got %s", user.Name, name) | ||||
| 	} | ||||
| 
 | ||||
| 	user2 := User{Name: "GenericsRow2"} | ||||
| 	if err := gorm.G[User](DB).Create(ctx, &user2); err != nil { | ||||
| 		t.Fatalf("Create failed: %v", err) | ||||
| 	} | ||||
| 	rows, err := gorm.G[User](DB).Raw("SELECT name FROM users WHERE id IN ?", []uint{user.ID, user2.ID}).Rows(ctx) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Rows failed: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	count := 0 | ||||
| 	for rows.Next() { | ||||
| 		var name string | ||||
| 		if err := rows.Scan(&name); err != nil { | ||||
| 			t.Fatalf("rows.Scan failed: %v", err) | ||||
| 		} | ||||
| 		count++ | ||||
| 	} | ||||
| 	if count != 2 { | ||||
| 		t.Errorf("expected 2 rows, got %d", count) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestGenericsDelete(t *testing.T) { | ||||
| 	ctx := context.Background() | ||||
| 
 | ||||
| 	u := User{Name: "GenericsDelete"} | ||||
| 	if err := gorm.G[User](DB).Create(ctx, &u); err != nil { | ||||
| 		t.Fatalf("Create failed: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	rows, err := gorm.G[User](DB).Where("id = ?", u.ID).Delete(ctx) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Delete failed: %v", err) | ||||
| 	} | ||||
| 	if rows != 1 { | ||||
| 		t.Errorf("expected 1 row deleted, got %d", rows) | ||||
| 	} | ||||
| 
 | ||||
| 	_, err = gorm.G[User](DB).Where("id = ?", u.ID).First(ctx) | ||||
| 	if err != gorm.ErrRecordNotFound { | ||||
| 		t.Fatalf("User after delete failed: %v", err) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestGenericsFindInBatches(t *testing.T) { | ||||
| 	ctx := context.Background() | ||||
| 
 | ||||
| 	users := []User{ | ||||
| 		{Name: "GenericsFindBatchA"}, | ||||
| 		{Name: "GenericsFindBatchB"}, | ||||
| 		{Name: "GenericsFindBatchC"}, | ||||
| 		{Name: "GenericsFindBatchD"}, | ||||
| 		{Name: "GenericsFindBatchE"}, | ||||
| 	} | ||||
| 	if err := gorm.G[User](DB).CreateInBatches(ctx, &users, len(users)); err != nil { | ||||
| 		t.Fatalf("CreateInBatches failed: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	total := 0 | ||||
| 	err := gorm.G[User](DB).Where("name like ?", "GenericsFindBatch%").FindInBatches(ctx, 2, func(chunk []User, batch int) error { | ||||
| 		if len(chunk) > 2 { | ||||
| 			t.Errorf("batch size exceed 2: got %d", len(chunk)) | ||||
| 		} | ||||
| 
 | ||||
| 		total += len(chunk) | ||||
| 		return nil | ||||
| 	}) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("FindInBatches failed: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if total != len(users) { | ||||
| 		t.Errorf("expected total %d, got %d", len(users), total) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestGenericsScopes(t *testing.T) { | ||||
| 	ctx := context.Background() | ||||
| 
 | ||||
| 	users := []User{{Name: "GenericsScopes1"}, {Name: "GenericsScopes2"}, {Name: "GenericsScopes3"}} | ||||
| 	err := gorm.G[User](DB).CreateInBatches(ctx, &users, len(users)) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("CreateInBatches failed: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	filterName1 := func(stmt *gorm.Statement) { | ||||
| 		stmt.Where("name = ?", "GenericsScopes1") | ||||
| 	} | ||||
| 
 | ||||
| 	results, err := gorm.G[User](DB).Scopes(filterName1).Find(ctx) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Scopes failed: %v", err) | ||||
| 	} | ||||
| 	if len(results) != 1 || results[0].Name != "GenericsScopes1" { | ||||
| 		t.Fatalf("Scopes expected 1, got %d", len(results)) | ||||
| 	} | ||||
| 
 | ||||
| 	notResult, err := gorm.G[User](DB).Where("name like ?", "GenericsScopes%").Not("name = ?", "GenericsScopes1").Order("name").Find(ctx) | ||||
| 	if len(notResult) != 2 { | ||||
| 		t.Fatalf("expected 2 results, got %d", len(notResult)) | ||||
| 	} else if notResult[0].Name != "GenericsScopes2" || notResult[1].Name != "GenericsScopes3" { | ||||
| 		t.Fatalf("expected names 'GenericsScopes2' and 'GenericsScopes3', got %s and %s", notResult[0].Name, notResult[1].Name) | ||||
| 	} | ||||
| 
 | ||||
| 	orResult, err := gorm.G[User](DB).Or("name = ?", "GenericsScopes1").Or("name = ?", "GenericsScopes2").Order("name").Find(ctx) | ||||
| 	if len(orResult) != 2 { | ||||
| 		t.Fatalf("expected 2 results, got %d", len(notResult)) | ||||
| 	} else if orResult[0].Name != "GenericsScopes1" || orResult[1].Name != "GenericsScopes2" { | ||||
| 		t.Fatalf("expected names 'GenericsScopes2' and 'GenericsScopes3', got %s and %s", orResult[0].Name, orResult[1].Name) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestGenericsJoins(t *testing.T) { | ||||
| 	ctx := context.Background() | ||||
| 	db := gorm.G[User](DB) | ||||
| 
 | ||||
| 	u := User{Name: "GenericsJoins", Company: Company{Name: "GenericsCompany"}} | ||||
| 	u2 := User{Name: "GenericsJoins_2", Company: Company{Name: "GenericsCompany_2"}} | ||||
| 	u3 := User{Name: "GenericsJoins_3", Company: Company{Name: "GenericsCompany_3"}} | ||||
| 	db.CreateInBatches(ctx, &[]User{u3, u, u2}, 10) | ||||
| 
 | ||||
| 	// Inner JOIN + WHERE
 | ||||
| 	result, err := db.Joins(clause.Has("Company"), func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error { | ||||
| 		db.Where("?.name = ?", joinTable, u.Company.Name) | ||||
| 		return nil | ||||
| 	}).First(ctx) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Joins failed: %v", err) | ||||
| 	} | ||||
| 	if result.Name != u.Name || result.Company.Name != u.Company.Name { | ||||
| 		t.Fatalf("Joins expected %s, got %+v", u.Name, result) | ||||
| 	} | ||||
| 
 | ||||
| 	// Inner JOIN + WHERE with map
 | ||||
| 	result, err = db.Joins(clause.Has("Company"), func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error { | ||||
| 		db.Where(map[string]any{"name": u.Company.Name}) | ||||
| 		return nil | ||||
| 	}).First(ctx) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Joins failed: %v", err) | ||||
| 	} | ||||
| 	if result.Name != u.Name || result.Company.Name != u.Company.Name { | ||||
| 		t.Fatalf("Joins expected %s, got %+v", u.Name, result) | ||||
| 	} | ||||
| 
 | ||||
| 	// Left JOIN w/o WHERE
 | ||||
| 	result, err = db.Joins(clause.LeftJoin.Association("Company"), nil).Where(map[string]any{"name": u.Name}).First(ctx) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Joins failed: %v", err) | ||||
| 	} | ||||
| 	if result.Name != u.Name || result.Company.Name != u.Company.Name { | ||||
| 		t.Fatalf("Joins expected %s, got %+v", u.Name, result) | ||||
| 	} | ||||
| 
 | ||||
| 	// Left JOIN + Alias WHERE
 | ||||
| 	result, err = db.Joins(clause.LeftJoin.Association("Company").As("t"), func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error { | ||||
| 		if joinTable.Name != "t" { | ||||
| 			t.Fatalf("Join table should be t, but got %v", joinTable.Name) | ||||
| 		} | ||||
| 		db.Where("?.name = ?", joinTable, u.Company.Name) | ||||
| 		return nil | ||||
| 	}).Where(map[string]any{"name": u.Name}).First(ctx) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Joins failed: %v", err) | ||||
| 	} | ||||
| 	if result.Name != u.Name || result.Company.Name != u.Company.Name { | ||||
| 		t.Fatalf("Joins expected %s, got %+v", u.Name, result) | ||||
| 	} | ||||
| 
 | ||||
| 	// Raw Subquery JOIN + WHERE
 | ||||
| 	result, err = db.Joins(clause.LeftJoin.AssociationFrom("Company", gorm.G[Company](DB)).As("t"), | ||||
| 		func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error { | ||||
| 			if joinTable.Name != "t" { | ||||
| 				t.Fatalf("Join table should be t, but got %v", joinTable.Name) | ||||
| 			} | ||||
| 			db.Where("?.name = ?", joinTable, u.Company.Name) | ||||
| 			return nil | ||||
| 		}, | ||||
| 	).Where(map[string]any{"name": u2.Name}).First(ctx) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Raw subquery join failed: %v", err) | ||||
| 	} | ||||
| 	if result.Name != u2.Name || result.Company.Name != u.Company.Name || result.Company.ID == 0 { | ||||
| 		t.Fatalf("Joins expected %s, got %+v", u.Name, result) | ||||
| 	} | ||||
| 
 | ||||
| 	// Raw Subquery JOIN + WHERE + Select
 | ||||
| 	result, err = db.Joins(clause.LeftJoin.AssociationFrom("Company", gorm.G[Company](DB).Select("Name")).As("t"), | ||||
| 		func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error { | ||||
| 			if joinTable.Name != "t" { | ||||
| 				t.Fatalf("Join table should be t, but got %v", joinTable.Name) | ||||
| 			} | ||||
| 			db.Where("?.name = ?", joinTable, u.Company.Name) | ||||
| 			return nil | ||||
| 		}, | ||||
| 	).Where(map[string]any{"name": u2.Name}).First(ctx) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Raw subquery join failed: %v", err) | ||||
| 	} | ||||
| 	if result.Name != u2.Name || result.Company.Name != u.Company.Name || result.Company.ID != 0 { | ||||
| 		t.Fatalf("Joins expected %s, got %+v", u.Name, result) | ||||
| 	} | ||||
| 
 | ||||
| 	_, err = db.Joins(clause.Has("Company"), func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error { | ||||
| 		return errors.New("join error") | ||||
| 	}).First(ctx) | ||||
| 	if err == nil { | ||||
| 		t.Fatalf("Joins should got error, but got nil") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestGenericsNestedJoins(t *testing.T) { | ||||
| 	users := []User{ | ||||
| 		{ | ||||
| 			Name: "generics-nested-joins-1", | ||||
| 			Manager: &User{ | ||||
| 				Name: "generics-nested-joins-manager-1", | ||||
| 				Company: Company{ | ||||
| 					Name: "generics-nested-joins-manager-company-1", | ||||
| 				}, | ||||
| 				NamedPet: &Pet{ | ||||
| 					Name: "generics-nested-joins-manager-namepet-1", | ||||
| 					Toy: Toy{ | ||||
| 						Name: "generics-nested-joins-manager-namepet-toy-1", | ||||
| 					}, | ||||
| 				}, | ||||
| 			}, | ||||
| 			NamedPet: &Pet{Name: "generics-nested-joins-namepet-1", Toy: Toy{Name: "generics-nested-joins-namepet-toy-1"}}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Name:     "generics-nested-joins-2", | ||||
| 			Manager:  GetUser("generics-nested-joins-manager-2", Config{Company: true, NamedPet: true}), | ||||
| 			NamedPet: &Pet{Name: "generics-nested-joins-namepet-2", Toy: Toy{Name: "generics-nested-joins-namepet-toy-2"}}, | ||||
| 		}, | ||||
| 	} | ||||
| 
 | ||||
| 	ctx := context.Background() | ||||
| 	db := gorm.G[User](DB) | ||||
| 	db.CreateInBatches(ctx, &users, 100) | ||||
| 
 | ||||
| 	var userIDs []uint | ||||
| 	for _, user := range users { | ||||
| 		userIDs = append(userIDs, user.ID) | ||||
| 	} | ||||
| 
 | ||||
| 	users2, err := db.Joins(clause.LeftJoin.Association("Manager"), nil). | ||||
| 		Joins(clause.LeftJoin.Association("Manager.Company"), nil). | ||||
| 		Joins(clause.LeftJoin.Association("Manager.NamedPet.Toy"), nil). | ||||
| 		Joins(clause.LeftJoin.Association("NamedPet.Toy"), nil). | ||||
| 		Joins(clause.LeftJoin.Association("NamedPet").As("t"), nil). | ||||
| 		Where(map[string]any{"id": userIDs}).Find(ctx) | ||||
| 
 | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Failed to load with joins, got error: %v", err) | ||||
| 	} else if len(users2) != len(users) { | ||||
| 		t.Fatalf("Failed to load join users, got: %v, expect: %v", len(users2), len(users)) | ||||
| 	} | ||||
| 
 | ||||
| 	sort.Slice(users2, func(i, j int) bool { | ||||
| 		return users2[i].ID > users2[j].ID | ||||
| 	}) | ||||
| 
 | ||||
| 	sort.Slice(users, func(i, j int) bool { | ||||
| 		return users[i].ID > users[j].ID | ||||
| 	}) | ||||
| 
 | ||||
| 	for idx, user := range users { | ||||
| 		// user
 | ||||
| 		CheckUser(t, user, users2[idx]) | ||||
| 		if users2[idx].Manager == nil { | ||||
| 			t.Fatalf("Failed to load Manager") | ||||
| 		} | ||||
| 		// manager
 | ||||
| 		CheckUser(t, *user.Manager, *users2[idx].Manager) | ||||
| 		// user pet
 | ||||
| 		if users2[idx].NamedPet == nil { | ||||
| 			t.Fatalf("Failed to load NamedPet") | ||||
| 		} | ||||
| 		CheckPet(t, *user.NamedPet, *users2[idx].NamedPet) | ||||
| 		// manager pet
 | ||||
| 		if users2[idx].Manager.NamedPet == nil { | ||||
| 			t.Fatalf("Failed to load NamedPet") | ||||
| 		} | ||||
| 		CheckPet(t, *user.Manager.NamedPet, *users2[idx].Manager.NamedPet) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestGenericsPreloads(t *testing.T) { | ||||
| 	ctx := context.Background() | ||||
| 	db := gorm.G[User](DB) | ||||
| 
 | ||||
| 	u := *GetUser("GenericsPreloads_1", Config{Company: true, Pets: 3, Friends: 7}) | ||||
| 	u2 := *GetUser("GenericsPreloads_2", Config{Company: true, Pets: 5, Friends: 5}) | ||||
| 	u3 := *GetUser("GenericsPreloads_3", Config{Company: true, Pets: 7, Friends: 3}) | ||||
| 	names := []string{u.Name, u2.Name, u3.Name} | ||||
| 
 | ||||
| 	db.CreateInBatches(ctx, &[]User{u3, u, u2}, 10) | ||||
| 
 | ||||
| 	result, err := db.Preload("Company", nil).Preload("Pets", nil).Where("name = ?", u.Name).First(ctx) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Preload failed: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if result.Name != u.Name || result.Company.Name != u.Company.Name || len(result.Pets) != len(u.Pets) { | ||||
| 		t.Fatalf("Preload expected %s, got %+v", u.Name, result) | ||||
| 	} | ||||
| 
 | ||||
| 	results, err := db.Preload("Company", func(db gorm.PreloadBuilder) error { | ||||
| 		db.Where("name = ?", u.Company.Name) | ||||
| 		return nil | ||||
| 	}).Where("name in ?", names).Find(ctx) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Preload failed: %v", err) | ||||
| 	} | ||||
| 	for _, result := range results { | ||||
| 		if result.Name == u.Name { | ||||
| 			if result.Company.Name != u.Company.Name { | ||||
| 				t.Fatalf("Preload user %v company should be %v, but got %+v", u.Name, u.Company.Name, result.Company.Name) | ||||
| 			} | ||||
| 		} else if result.Company.Name != "" { | ||||
| 			t.Fatalf("Preload other company should not loaded, user %v company expect %v but got %+v", u.Name, u.Company.Name, result.Company.Name) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	_, err = db.Preload("Company", func(db gorm.PreloadBuilder) error { | ||||
| 		return errors.New("preload error") | ||||
| 	}).Where("name in ?", names).Find(ctx) | ||||
| 	if err == nil { | ||||
| 		t.Fatalf("Preload should failed, but got nil") | ||||
| 	} | ||||
| 
 | ||||
| 	if DB.Dialector.Name() == "mysql" { | ||||
| 		// mysql 5.7 doesn't support row_number()
 | ||||
| 		if strings.HasPrefix(DB.Dialector.(*mysql.Dialector).ServerVersion, "5.7") { | ||||
| 			return | ||||
| 		} | ||||
| 	} | ||||
| 	results, err = db.Preload("Pets", func(db gorm.PreloadBuilder) error { | ||||
| 		db.LimitPerRecord(5) | ||||
| 		return nil | ||||
| 	}).Where("name in ?", names).Find(ctx) | ||||
| 
 | ||||
| 	for _, result := range results { | ||||
| 		if result.Name == u.Name { | ||||
| 			if len(result.Pets) != len(u.Pets) { | ||||
| 				t.Fatalf("Preload user %v pets should be %v, but got %+v", u.Name, u.Pets, result.Pets) | ||||
| 			} | ||||
| 		} else if len(result.Pets) != 5 { | ||||
| 			t.Fatalf("Preload user %v pets should be 5, but got %+v", result.Name, result.Pets) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	if DB.Dialector.Name() == "sqlserver" { | ||||
| 		// sqlserver doesn't support order by in subquery
 | ||||
| 		return | ||||
| 	} | ||||
| 	results, err = db.Preload("Pets", func(db gorm.PreloadBuilder) error { | ||||
| 		db.Order("name desc").LimitPerRecord(5) | ||||
| 		return nil | ||||
| 	}).Where("name in ?", names).Find(ctx) | ||||
| 
 | ||||
| 	for _, result := range results { | ||||
| 		if result.Name == u.Name { | ||||
| 			if len(result.Pets) != len(u.Pets) { | ||||
| 				t.Fatalf("Preload user %v pets should be %v, but got %+v", u.Name, u.Pets, result.Pets) | ||||
| 			} | ||||
| 		} else if len(result.Pets) != 5 { | ||||
| 			t.Fatalf("Preload user %v pets should be 5, but got %+v", result.Name, result.Pets) | ||||
| 		} | ||||
| 		for i := 1; i < len(result.Pets); i++ { | ||||
| 			if result.Pets[i-1].Name < result.Pets[i].Name { | ||||
| 				t.Fatalf("Preload user %v pets not ordered correctly, last %v, cur %v", result.Name, result.Pets[i-1], result.Pets[i]) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	results, err = db.Preload("Pets", func(db gorm.PreloadBuilder) error { | ||||
| 		db.Order("name").LimitPerRecord(5) | ||||
| 		return nil | ||||
| 	}).Preload("Friends", func(db gorm.PreloadBuilder) error { | ||||
| 		db.Order("name") | ||||
| 		return nil | ||||
| 	}).Where("name in ?", names).Find(ctx) | ||||
| 
 | ||||
| 	for _, result := range results { | ||||
| 		if result.Name == u.Name { | ||||
| 			if len(result.Pets) != len(u.Pets) { | ||||
| 				t.Fatalf("Preload user %v pets should be %v, but got %+v", u.Name, u.Pets, result.Pets) | ||||
| 			} | ||||
| 			if len(result.Friends) != len(u.Friends) { | ||||
| 				t.Fatalf("Preload user %v pets should be %v, but got %+v", u.Name, u.Pets, result.Pets) | ||||
| 			} | ||||
| 		} else if len(result.Pets) != 5 || len(result.Friends) == 0 { | ||||
| 			t.Fatalf("Preload user %v pets should be 5, but got %+v", result.Name, result.Pets) | ||||
| 		} | ||||
| 		for i := 1; i < len(result.Pets); i++ { | ||||
| 			if result.Pets[i-1].Name > result.Pets[i].Name { | ||||
| 				t.Fatalf("Preload user %v pets not ordered correctly, last %v, cur %v", result.Name, result.Pets[i-1], result.Pets[i]) | ||||
| 			} | ||||
| 		} | ||||
| 		for i := 1; i < len(result.Pets); i++ { | ||||
| 			if result.Pets[i-1].Name > result.Pets[i].Name { | ||||
| 				t.Fatalf("Preload user %v friends not ordered correctly, last %v, cur %v", result.Name, result.Pets[i-1], result.Pets[i]) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestGenericsNestedPreloads(t *testing.T) { | ||||
| 	user := *GetUser("generics_nested_preload", Config{Pets: 2}) | ||||
| 	user.Friends = []*User{GetUser("generics_nested_preload", Config{Pets: 5})} | ||||
| 
 | ||||
| 	ctx := context.Background() | ||||
| 	db := gorm.G[User](DB) | ||||
| 
 | ||||
| 	for idx, pet := range user.Pets { | ||||
| 		pet.Toy = Toy{Name: "toy_nested_preload_" + strconv.Itoa(idx+1)} | ||||
| 	} | ||||
| 
 | ||||
| 	if err := db.Create(ctx, &user); err != nil { | ||||
| 		t.Fatalf("errors happened when create: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	user2, err := db.Preload("Pets.Toy", nil).Preload("Friends.Pets", func(db gorm.PreloadBuilder) error { | ||||
| 		return nil | ||||
| 	}).Where(user.ID).Take(ctx) | ||||
| 	if err != nil { | ||||
| 		t.Errorf("failed to nested preload user") | ||||
| 	} | ||||
| 	CheckUser(t, user2, user) | ||||
| 	if len(user.Pets) == 0 || len(user.Friends) == 0 || len(user.Friends[0].Pets) == 0 { | ||||
| 		t.Fatalf("failed to nested preload") | ||||
| 	} | ||||
| 
 | ||||
| 	if DB.Dialector.Name() == "mysql" { | ||||
| 		// mysql 5.7 doesn't support row_number()
 | ||||
| 		if strings.HasPrefix(DB.Dialector.(*mysql.Dialector).ServerVersion, "5.7") { | ||||
| 			return | ||||
| 		} | ||||
| 	} | ||||
| 	if DB.Dialector.Name() == "sqlserver" { | ||||
| 		// sqlserver doesn't support order by in subquery
 | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	user3, err := db.Preload("Pets.Toy", nil).Preload("Friends.Pets", func(db gorm.PreloadBuilder) error { | ||||
| 		db.LimitPerRecord(3) | ||||
| 		return nil | ||||
| 	}).Where(user.ID).Take(ctx) | ||||
| 	if err != nil { | ||||
| 		t.Errorf("failed to nested preload user") | ||||
| 	} | ||||
| 	CheckUser(t, user3, user) | ||||
| 
 | ||||
| 	if len(user3.Friends) != 1 || len(user3.Friends[0].Pets) != 3 { | ||||
| 		t.Errorf("failed to nested preload with limit per record") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestGenericsDistinct(t *testing.T) { | ||||
| 	ctx := context.Background() | ||||
| 
 | ||||
| 	batch := []User{ | ||||
| 		{Name: "GenericsDistinctDup"}, | ||||
| 		{Name: "GenericsDistinctDup"}, | ||||
| 		{Name: "GenericsDistinctUnique"}, | ||||
| 	} | ||||
| 	if err := gorm.G[User](DB).CreateInBatches(ctx, &batch, len(batch)); err != nil { | ||||
| 		t.Fatalf("CreateInBatches failed: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	results, err := gorm.G[User](DB).Where("name like ?", "GenericsDistinct%").Distinct("name").Find(ctx) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Distinct Find failed: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if len(results) != 2 { | ||||
| 		t.Errorf("expected 2 distinct names, got %d", len(results)) | ||||
| 	} | ||||
| 
 | ||||
| 	var names []string | ||||
| 	for _, u := range results { | ||||
| 		names = append(names, u.Name) | ||||
| 	} | ||||
| 	sort.Strings(names) | ||||
| 	expected := []string{"GenericsDistinctDup", "GenericsDistinctUnique"} | ||||
| 	if !reflect.DeepEqual(names, expected) { | ||||
| 		t.Errorf("expected names %v, got %v", expected, names) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestGenericsGroupHaving(t *testing.T) { | ||||
| 	ctx := context.Background() | ||||
| 
 | ||||
| 	batch := []User{ | ||||
| 		{Name: "GenericsGroupHavingMulti"}, | ||||
| 		{Name: "GenericsGroupHavingMulti"}, | ||||
| 		{Name: "GenericsGroupHavingSingle"}, | ||||
| 	} | ||||
| 	if err := gorm.G[User](DB).CreateInBatches(ctx, &batch, len(batch)); err != nil { | ||||
| 		t.Fatalf("CreateInBatches failed: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	grouped, err := gorm.G[User](DB).Select("name").Where("name like ?", "GenericsGroupHaving%").Group("name").Having("COUNT(id) > ?", 1).Find(ctx) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Group+Having Find failed: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if len(grouped) != 1 { | ||||
| 		t.Errorf("expected 1 group with count>1, got %d", len(grouped)) | ||||
| 	} else if grouped[0].Name != "GenericsGroupHavingMulti" { | ||||
| 		t.Errorf("expected group name 'GenericsGroupHavingMulti', got '%s'", grouped[0].Name) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestGenericsSubQuery(t *testing.T) { | ||||
| 	ctx := context.Background() | ||||
| 	users := []User{ | ||||
| 		{Name: "GenericsSubquery_1", Age: 10}, | ||||
| 		{Name: "GenericsSubquery_2", Age: 20}, | ||||
| 		{Name: "GenericsSubquery_3", Age: 30}, | ||||
| 		{Name: "GenericsSubquery_4", Age: 40}, | ||||
| 	} | ||||
| 
 | ||||
| 	if err := gorm.G[User](DB).CreateInBatches(ctx, &users, len(users)); err != nil { | ||||
| 		t.Fatalf("CreateInBatches failed: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	results, err := gorm.G[User](DB).Where("name IN (?)", gorm.G[User](DB).Select("name").Where("name LIKE ?", "GenericsSubquery%")).Find(ctx) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("got error: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if len(results) != 4 { | ||||
| 		t.Errorf("Four users should be found, instead found %d", len(results)) | ||||
| 	} | ||||
| 
 | ||||
| 	results, err = gorm.G[User](DB).Where("name IN (?)", gorm.G[User](DB).Select("name").Where("name IN ?", []string{"GenericsSubquery_1", "GenericsSubquery_2"}).Or("name = ?", "GenericsSubquery_3")).Find(ctx) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("got error: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if len(results) != 3 { | ||||
| 		t.Errorf("Three users should be found, instead found %d", len(results)) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestGenericsUpsert(t *testing.T) { | ||||
| 	ctx := context.Background() | ||||
| 	lang := Language{Code: "upsert", Name: "Upsert"} | ||||
| 
 | ||||
| 	if err := gorm.G[Language](DB, clause.OnConflict{DoNothing: true}).Create(ctx, &lang); err != nil { | ||||
| 		t.Fatalf("failed to upsert, got %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	lang2 := Language{Code: "upsert", Name: "Upsert"} | ||||
| 	if err := gorm.G[Language](DB, clause.OnConflict{DoNothing: true}).Create(ctx, &lang2); err != nil { | ||||
| 		t.Fatalf("failed to upsert, got %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	langs, err := gorm.G[Language](DB).Where("code = ?", lang.Code).Find(ctx) | ||||
| 	if err != nil { | ||||
| 		t.Errorf("no error should happen when find languages with code, but got %v", err) | ||||
| 	} else if len(langs) != 1 { | ||||
| 		t.Errorf("should only find only 1 languages, but got %+v", langs) | ||||
| 	} | ||||
| 
 | ||||
| 	lang3 := Language{Code: "upsert", Name: "Upsert"} | ||||
| 	if err := gorm.G[Language](DB, clause.OnConflict{ | ||||
| 		Columns:   []clause.Column{{Name: "code"}}, | ||||
| 		DoUpdates: clause.Assignments(map[string]interface{}{"name": "upsert-new"}), | ||||
| 	}).Create(ctx, &lang3); err != nil { | ||||
| 		t.Fatalf("failed to upsert, got %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if langs, err := gorm.G[Language](DB).Where("code = ?", lang.Code).Find(ctx); err != nil { | ||||
| 		t.Errorf("no error should happen when find languages with code, but got %v", err) | ||||
| 	} else if len(langs) != 1 { | ||||
| 		t.Errorf("should only find only 1 languages, but got %+v", langs) | ||||
| 	} else if langs[0].Name != "upsert-new" { | ||||
| 		t.Errorf("should update name on conflict, but got name %+v", langs[0].Name) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestGenericsWithResult(t *testing.T) { | ||||
| 	ctx := context.Background() | ||||
| 	users := []User{{Name: "TestGenericsWithResult", Age: 18}, {Name: "TestGenericsWithResult2", Age: 18}} | ||||
| 
 | ||||
| 	result := gorm.WithResult() | ||||
| 	err := gorm.G[User](DB, result).CreateInBatches(ctx, &users, 2) | ||||
| 	if err != nil { | ||||
| 		t.Errorf("failed to create users WithResult") | ||||
| 	} | ||||
| 
 | ||||
| 	if result.RowsAffected != 2 { | ||||
| 		t.Errorf("failed to get affected rows, got %d, should be %d", result.RowsAffected, 2) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestGenericsReuse(t *testing.T) { | ||||
| 	ctx := context.Background() | ||||
| 	users := []User{{Name: "TestGenericsReuse1", Age: 18}, {Name: "TestGenericsReuse2", Age: 18}} | ||||
| 
 | ||||
| 	err := gorm.G[User](DB).CreateInBatches(ctx, &users, 2) | ||||
| 	if err != nil { | ||||
| 		t.Errorf("failed to create users") | ||||
| 	} | ||||
| 
 | ||||
| 	reusedb := gorm.G[User](DB).Where("name like ?", "TestGenericsReuse%") | ||||
| 
 | ||||
| 	sg := sync.WaitGroup{} | ||||
| 	for i := 0; i < 5; i++ { | ||||
| 		sg.Add(1) | ||||
| 
 | ||||
| 		go func() { | ||||
| 			if u1, err := reusedb.Where("id = ?", users[0].ID).First(ctx); err != nil { | ||||
| 				t.Errorf("failed to find user, got error: %v", err) | ||||
| 			} else if u1.Name != users[0].Name || u1.ID != users[0].ID { | ||||
| 				t.Errorf("found invalid user, got %v, expect %v", u1, users[0]) | ||||
| 			} | ||||
| 
 | ||||
| 			if u2, err := reusedb.Where("id = ?", users[1].ID).First(ctx); err != nil { | ||||
| 				t.Errorf("failed to find user, got error: %v", err) | ||||
| 			} else if u2.Name != users[1].Name || u2.ID != users[1].ID { | ||||
| 				t.Errorf("found invalid user, got %v, expect %v", u2, users[1]) | ||||
| 			} | ||||
| 
 | ||||
| 			if users, err := reusedb.Where("id IN ?", []uint{users[0].ID, users[1].ID}).Find(ctx); err != nil { | ||||
| 				t.Errorf("failed to find user, got error: %v", err) | ||||
| 			} else if len(users) != 2 { | ||||
| 				t.Errorf("should find 2 users, but got %d", len(users)) | ||||
| 			} | ||||
| 			sg.Done() | ||||
| 		}() | ||||
| 	} | ||||
| 	sg.Wait() | ||||
| } | ||||
| 
 | ||||
| func TestGenericsWithTransaction(t *testing.T) { | ||||
| 	ctx := context.Background() | ||||
| 	tx := DB.Begin() | ||||
| 	if tx.Error != nil { | ||||
| 		t.Fatalf("failed to begin transaction: %v", tx.Error) | ||||
| 	} | ||||
| 
 | ||||
| 	users := []User{{Name: "TestGenericsTransaction", Age: 18}, {Name: "TestGenericsTransaction2", Age: 18}} | ||||
| 	err := gorm.G[User](tx).CreateInBatches(ctx, &users, 2) | ||||
| 
 | ||||
| 	count, err := gorm.G[User](tx).Where("name like ?", "TestGenericsTransaction%").Count(ctx, "*") | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Count failed: %v", err) | ||||
| 	} | ||||
| 	if count != 2 { | ||||
| 		t.Errorf("expected 2 records, got %d", count) | ||||
| 	} | ||||
| 
 | ||||
| 	if err := tx.Rollback().Error; err != nil { | ||||
| 		t.Fatalf("failed to rollback transaction: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	count2, err := gorm.G[User](DB).Where("name like ?", "TestGenericsTransaction%").Count(ctx, "*") | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Count failed: %v", err) | ||||
| 	} | ||||
| 	if count2 != 0 { | ||||
| 		t.Errorf("expected 0 records after rollback, got %d", count2) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestGenericsToSQL(t *testing.T) { | ||||
| 	ctx := context.Background() | ||||
| 	sql := DB.ToSQL(func(tx *gorm.DB) *gorm.DB { | ||||
| 		gorm.G[User](tx).Limit(10).Find(ctx) | ||||
| 		return tx | ||||
| 	}) | ||||
| 
 | ||||
| 	if !regexp.MustCompile("SELECT \\* FROM .users..* 10").MatchString(sql) { | ||||
| 		t.Errorf("ToSQL: got wrong sql with Generics API %v", sql) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestGenericsScanUUID(t *testing.T) { | ||||
| 	ctx := context.Background() | ||||
| 	users := []User{ | ||||
| 		{Name: uuid.NewString(), Age: 21}, | ||||
| 		{Name: uuid.NewString(), Age: 22}, | ||||
| 		{Name: uuid.NewString(), Age: 23}, | ||||
| 	} | ||||
| 
 | ||||
| 	if err := gorm.G[User](DB).CreateInBatches(ctx, &users, 2); err != nil { | ||||
| 		t.Fatalf("CreateInBatches failed: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	userIds := []uuid.UUID{} | ||||
| 	if err := gorm.G[User](DB).Select("name").Where("id in ?", []uint{users[0].ID, users[1].ID, users[2].ID}).Order("age").Scan(ctx, &userIds); err != nil || len(users) != 3 { | ||||
| 		t.Fatalf("Scan failed: %v, userids %v", err, userIds) | ||||
| 	} | ||||
| 
 | ||||
| 	if userIds[0].String() != users[0].Name || userIds[1].String() != users[1].Name || userIds[2].String() != users[2].Name { | ||||
| 		t.Fatalf("wrong uuid scanned") | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										10
									
								
								tests/go.mod
									
									
									
									
									
								
							
							
						
						
									
										10
									
								
								tests/go.mod
									
									
									
									
									
								
							| @ -2,15 +2,13 @@ module gorm.io/gorm/tests | ||||
| 
 | ||||
| go 1.23.0 | ||||
| 
 | ||||
| toolchain go1.24.2 | ||||
| 
 | ||||
| require ( | ||||
| 	github.com/google/uuid v1.6.0 | ||||
| 	github.com/jinzhu/now v1.1.5 | ||||
| 	github.com/lib/pq v1.10.9 | ||||
| 	github.com/stretchr/testify v1.10.0 | ||||
| 	gorm.io/driver/mysql v1.5.7 | ||||
| 	gorm.io/driver/postgres v1.5.11 | ||||
| 	gorm.io/driver/postgres v1.6.0 | ||||
| 	gorm.io/driver/sqlite v1.5.7 | ||||
| 	gorm.io/driver/sqlserver v1.6.0 | ||||
| 	gorm.io/gorm v1.30.0 | ||||
| @ -25,6 +23,7 @@ require ( | ||||
| 	github.com/jackc/pgpassfile v1.0.0 // indirect | ||||
| 	github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect | ||||
| 	github.com/jackc/pgx/v5 v5.7.5 // indirect | ||||
| 	github.com/jackc/puddle/v2 v2.2.2 // indirect | ||||
| 	github.com/jinzhu/inflection v1.0.0 // indirect | ||||
| 	github.com/kr/text v0.2.0 // indirect | ||||
| 	github.com/mattn/go-sqlite3 v1.14.28 // indirect | ||||
| @ -32,12 +31,9 @@ require ( | ||||
| 	github.com/pmezard/go-difflib v1.0.0 // indirect | ||||
| 	github.com/rogpeppe/go-internal v1.12.0 // indirect | ||||
| 	golang.org/x/crypto v0.38.0 // indirect | ||||
| 	golang.org/x/sync v0.14.0 // indirect | ||||
| 	golang.org/x/text v0.25.0 // indirect | ||||
| 	gopkg.in/yaml.v3 v3.0.1 // indirect | ||||
| ) | ||||
| 
 | ||||
| replace gorm.io/gorm => ../ | ||||
| 
 | ||||
| replace github.com/jackc/pgx/v5 => github.com/jackc/pgx/v5 v5.4.3 | ||||
| 
 | ||||
| replace github.com/microsoft/go-mssqldb => github.com/microsoft/go-mssqldb v1.7.0 | ||||
|  | ||||
| @ -419,7 +419,7 @@ func TestJoinsPreload_Issue7013(t *testing.T) { | ||||
| 	var entries []User | ||||
| 	assert.NotPanics(t, func() { | ||||
| 		assert.NoError(t, | ||||
| 			DB.Debug().Preload("Manager.Team"). | ||||
| 			DB.Preload("Manager.Team"). | ||||
| 				Joins("Manager.Company"). | ||||
| 				Find(&entries).Error) | ||||
| 	}) | ||||
| @ -456,7 +456,7 @@ func TestJoinsPreload_Issue7013_RelationEmpty(t *testing.T) { | ||||
| 	var entries []Building | ||||
| 	assert.NotPanics(t, func() { | ||||
| 		assert.NoError(t, | ||||
| 			DB.Debug().Preload("Owner.Furnitures"). | ||||
| 			DB.Preload("Owner.Furnitures"). | ||||
| 				Joins("Owner.Company"). | ||||
| 				Find(&entries).Error) | ||||
| 	}) | ||||
| @ -468,7 +468,7 @@ func TestJoinsPreload_Issue7013_NoEntries(t *testing.T) { | ||||
| 	var entries []User | ||||
| 	assert.NotPanics(t, func() { | ||||
| 		assert.NoError(t, | ||||
| 			DB.Debug().Preload("Manager.Team"). | ||||
| 			DB.Preload("Manager.Team"). | ||||
| 				Joins("Manager.Company"). | ||||
| 				Where("1 <> 1"). | ||||
| 				Find(&entries).Error) | ||||
|  | ||||
| @ -520,38 +520,6 @@ func TestLRURemoveOldest(t *testing.T) { | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func ExampleLRU() { | ||||
| 	// make cache with 10ms TTL and 5 max keys
 | ||||
| 	cache := lru.NewLRU[string, string](5, nil, time.Millisecond*10) | ||||
| 
 | ||||
| 	// set value under key1.
 | ||||
| 	cache.Add("key1", "val1") | ||||
| 
 | ||||
| 	// get value under key1
 | ||||
| 	r, ok := cache.Get("key1") | ||||
| 
 | ||||
| 	// check for OK value
 | ||||
| 	if ok { | ||||
| 		fmt.Printf("value before expiration is found: %v, value: %q\n", ok, r) | ||||
| 	} | ||||
| 
 | ||||
| 	// wait for cache to expire
 | ||||
| 	time.Sleep(time.Millisecond * 100) | ||||
| 
 | ||||
| 	// get value under key1 after key expiration
 | ||||
| 	r, ok = cache.Get("key1") | ||||
| 	fmt.Printf("value after expiration is found: %v, value: %q\n", ok, r) | ||||
| 
 | ||||
| 	// set value under key2, would evict old entry because it is already expired.
 | ||||
| 	cache.Add("key2", "val2") | ||||
| 
 | ||||
| 	fmt.Printf("Cache len: %d\n", cache.Len()) | ||||
| 	// Output:
 | ||||
| 	// value before expiration is found: true, value: "val1"
 | ||||
| 	// value after expiration is found: false, value: ""
 | ||||
| 	// Cache len: 1
 | ||||
| } | ||||
| 
 | ||||
| func getRand(tb testing.TB) int64 { | ||||
| 	out, err := rand.Int(rand.Reader, big.NewInt(math.MaxInt64)) | ||||
| 	if err != nil { | ||||
|  | ||||
| @ -5,7 +5,6 @@ import ( | ||||
| 	"database/sql" | ||||
| 	"fmt" | ||||
| 	"math/rand" | ||||
| 	"os" | ||||
| 	"reflect" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| @ -17,7 +16,6 @@ import ( | ||||
| 
 | ||||
| 	"gorm.io/gorm" | ||||
| 	"gorm.io/gorm/clause" | ||||
| 	"gorm.io/gorm/logger" | ||||
| 	"gorm.io/gorm/migrator" | ||||
| 	"gorm.io/gorm/schema" | ||||
| 	"gorm.io/gorm/utils" | ||||
| @ -1211,101 +1209,6 @@ func TestInvalidCachedPlanSimpleProtocol(t *testing.T) { | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestInvalidCachedPlanPrepareStmt(t *testing.T) { | ||||
| 	if DB.Dialector.Name() != "postgres" { | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	db, err := gorm.Open(postgres.Open(postgresDSN), &gorm.Config{PrepareStmt: true}) | ||||
| 	if err != nil { | ||||
| 		t.Errorf("Open err:%v", err) | ||||
| 	} | ||||
| 	if debug := os.Getenv("DEBUG"); debug == "true" { | ||||
| 		db.Logger = db.Logger.LogMode(logger.Info) | ||||
| 	} else if debug == "false" { | ||||
| 		db.Logger = db.Logger.LogMode(logger.Silent) | ||||
| 	} | ||||
| 
 | ||||
| 	type Object1 struct { | ||||
| 		ID uint | ||||
| 	} | ||||
| 	type Object2 struct { | ||||
| 		ID     uint | ||||
| 		Field1 int `gorm:"type:int8"` | ||||
| 	} | ||||
| 	type Object3 struct { | ||||
| 		ID     uint | ||||
| 		Field1 int `gorm:"type:int4"` | ||||
| 	} | ||||
| 	type Object4 struct { | ||||
| 		ID     uint | ||||
| 		Field2 int | ||||
| 	} | ||||
| 	db.Migrator().DropTable("objects") | ||||
| 
 | ||||
| 	err = db.Table("objects").AutoMigrate(&Object1{}) | ||||
| 	if err != nil { | ||||
| 		t.Errorf("AutoMigrate err:%v", err) | ||||
| 	} | ||||
| 	err = db.Table("objects").Create(&Object1{}).Error | ||||
| 	if err != nil { | ||||
| 		t.Errorf("create err:%v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	// AddColumn
 | ||||
| 	err = db.Table("objects").AutoMigrate(&Object2{}) | ||||
| 	if err != nil { | ||||
| 		t.Errorf("AutoMigrate err:%v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	err = db.Table("objects").Take(&Object2{}).Error | ||||
| 	if err != nil { | ||||
| 		t.Errorf("take err:%v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	// AlterColumn
 | ||||
| 	err = db.Table("objects").AutoMigrate(&Object3{}) | ||||
| 	if err != nil { | ||||
| 		t.Errorf("AutoMigrate err:%v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	err = db.Table("objects").Take(&Object3{}).Error | ||||
| 	if err != nil { | ||||
| 		t.Errorf("take err:%v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	// AddColumn
 | ||||
| 	err = db.Table("objects").AutoMigrate(&Object4{}) | ||||
| 	if err != nil { | ||||
| 		t.Errorf("AutoMigrate err:%v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	err = db.Table("objects").Take(&Object4{}).Error | ||||
| 	if err != nil { | ||||
| 		t.Errorf("take err:%v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	db.Table("objects").Migrator().RenameColumn(&Object4{}, "field2", "field3") | ||||
| 	if err != nil { | ||||
| 		t.Errorf("RenameColumn err:%v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	err = db.Table("objects").Take(&Object4{}).Error | ||||
| 	if err != nil { | ||||
| 		t.Errorf("take err:%v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	db.Table("objects").Migrator().DropColumn(&Object4{}, "field3") | ||||
| 	if err != nil { | ||||
| 		t.Errorf("RenameColumn err:%v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	err = db.Table("objects").Take(&Object4{}).Error | ||||
| 	if err != nil { | ||||
| 		t.Errorf("take err:%v", err) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestDifferentTypeWithoutDeclaredLength(t *testing.T) { | ||||
| 	type DiffType struct { | ||||
| 		ID   uint | ||||
|  | ||||
| @ -4,6 +4,7 @@ import ( | ||||
| 	"context" | ||||
| 	"errors" | ||||
| 	"testing" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"gorm.io/gorm" | ||||
| 	. "gorm.io/gorm/utils/tests" | ||||
| @ -459,7 +460,6 @@ func TestTransactionWithHooks(t *testing.T) { | ||||
| 			return tx2.Scan(&User{}).Error | ||||
| 		}) | ||||
| 	}) | ||||
| 
 | ||||
| 	if err != nil { | ||||
| 		t.Error(err) | ||||
| 	} | ||||
| @ -473,8 +473,20 @@ func TestTransactionWithHooks(t *testing.T) { | ||||
| 			return tx3.Where("user_id", user.ID).Delete(&Account{}).Error | ||||
| 		}) | ||||
| 	}) | ||||
| 
 | ||||
| 	if err != nil { | ||||
| 		t.Error(err) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestTransactionWithDefaultTimeout(t *testing.T) { | ||||
| 	db, err := OpenTestConnection(&gorm.Config{DefaultTransactionTimeout: 2 * time.Second}) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("failed to connect database, got error %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	tx := db.Begin() | ||||
| 	time.Sleep(3 * time.Second) | ||||
| 	if err = tx.Find(&User{}).Error; err == nil { | ||||
| 		t.Errorf("should return error when transaction timeout, got error %v", err) | ||||
| 	} | ||||
| } | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Chise1
						Chise1