Merge remote-tracking branch 'gorm/master'
# Conflicts: # tests/go.mod
This commit is contained in:
		
						commit
						87ede35ed4
					
				| @ -89,6 +89,10 @@ func Create(config *Config) func(db *gorm.DB) { | |||||||
| 					db.AddError(rows.Close()) | 					db.AddError(rows.Close()) | ||||||
| 				}() | 				}() | ||||||
| 				gorm.Scan(rows, db, mode) | 				gorm.Scan(rows, db, mode) | ||||||
|  | 
 | ||||||
|  | 				if db.Statement.Result != nil { | ||||||
|  | 					db.Statement.Result.RowsAffected = db.RowsAffected | ||||||
|  | 				} | ||||||
| 			} | 			} | ||||||
| 
 | 
 | ||||||
| 			return | 			return | ||||||
| @ -103,6 +107,12 @@ func Create(config *Config) func(db *gorm.DB) { | |||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		db.RowsAffected, _ = result.RowsAffected() | 		db.RowsAffected, _ = result.RowsAffected() | ||||||
|  | 
 | ||||||
|  | 		if db.Statement.Result != nil { | ||||||
|  | 			db.Statement.Result.Result = result | ||||||
|  | 			db.Statement.Result.RowsAffected = db.RowsAffected | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
| 		if db.RowsAffected == 0 { | 		if db.RowsAffected == 0 { | ||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
|  | |||||||
| @ -157,8 +157,14 @@ func Delete(config *Config) func(db *gorm.DB) { | |||||||
| 			ok, mode := hasReturning(db, supportReturning) | 			ok, mode := hasReturning(db, supportReturning) | ||||||
| 			if !ok { | 			if !ok { | ||||||
| 				result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) | 				result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) | ||||||
|  | 
 | ||||||
| 				if db.AddError(err) == nil { | 				if db.AddError(err) == nil { | ||||||
| 					db.RowsAffected, _ = result.RowsAffected() | 					db.RowsAffected, _ = result.RowsAffected() | ||||||
|  | 
 | ||||||
|  | 					if db.Statement.Result != nil { | ||||||
|  | 						db.Statement.Result.Result = result | ||||||
|  | 						db.Statement.Result.RowsAffected = db.RowsAffected | ||||||
|  | 					} | ||||||
| 				} | 				} | ||||||
| 
 | 
 | ||||||
| 				return | 				return | ||||||
| @ -166,6 +172,10 @@ func Delete(config *Config) func(db *gorm.DB) { | |||||||
| 
 | 
 | ||||||
| 			if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil { | 			if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil { | ||||||
| 				gorm.Scan(rows, db, mode) | 				gorm.Scan(rows, db, mode) | ||||||
|  | 
 | ||||||
|  | 				if db.Statement.Result != nil { | ||||||
|  | 					db.Statement.Result.RowsAffected = db.RowsAffected | ||||||
|  | 				} | ||||||
| 				db.AddError(rows.Close()) | 				db.AddError(rows.Close()) | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
|  | |||||||
| @ -103,11 +103,11 @@ func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relati | |||||||
| 				joined = true | 				joined = true | ||||||
| 				continue | 				continue | ||||||
| 			} | 			} | ||||||
| 			joinNames := strings.SplitN(join, ".", 2) | 			join0, join1, cut := strings.Cut(join, ".") | ||||||
| 			if len(joinNames) == 2 { | 			if cut { | ||||||
| 				if _, ok := relationships.Relations[joinNames[0]]; ok && name == joinNames[0] { | 				if _, ok := relationships.Relations[join0]; ok && name == join0 { | ||||||
| 					joined = true | 					joined = true | ||||||
| 					nestedJoins = append(nestedJoins, joinNames[1]) | 					nestedJoins = append(nestedJoins, join1) | ||||||
| 				} | 				} | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| @ -275,6 +275,8 @@ func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preload | |||||||
| 	column, values := schema.ToQueryValues(clause.CurrentTable, relForeignKeys, foreignValues) | 	column, values := schema.ToQueryValues(clause.CurrentTable, relForeignKeys, foreignValues) | ||||||
| 
 | 
 | ||||||
| 	if len(values) != 0 { | 	if len(values) != 0 { | ||||||
|  | 		tx = tx.Model(reflectResults.Addr().Interface()).Where(clause.IN{Column: column, Values: values}) | ||||||
|  | 
 | ||||||
| 		for _, cond := range conds { | 		for _, cond := range conds { | ||||||
| 			if fc, ok := cond.(func(*gorm.DB) *gorm.DB); ok { | 			if fc, ok := cond.(func(*gorm.DB) *gorm.DB); ok { | ||||||
| 				tx = fc(tx) | 				tx = fc(tx) | ||||||
| @ -283,7 +285,11 @@ func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preload | |||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		if err := tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), inlineConds...).Error; err != nil { | 		if len(inlineConds) > 0 { | ||||||
|  | 			tx = tx.Where(inlineConds[0], inlineConds[1:]...) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if err := tx.Find(reflectResults.Addr().Interface()).Error; err != nil { | ||||||
| 			return err | 			return err | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  | |||||||
| @ -25,6 +25,10 @@ func Query(db *gorm.DB) { | |||||||
| 				db.AddError(rows.Close()) | 				db.AddError(rows.Close()) | ||||||
| 			}() | 			}() | ||||||
| 			gorm.Scan(rows, db, 0) | 			gorm.Scan(rows, db, 0) | ||||||
|  | 
 | ||||||
|  | 			if db.Statement.Result != nil { | ||||||
|  | 				db.Statement.Result.RowsAffected = db.RowsAffected | ||||||
|  | 			} | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| @ -110,7 +114,7 @@ func BuildQuerySQL(db *gorm.DB) { | |||||||
| 				} | 				} | ||||||
| 			} | 			} | ||||||
| 
 | 
 | ||||||
| 			specifiedRelationsName := make(map[string]interface{}) | 			specifiedRelationsName := map[string]string{clause.CurrentTable: clause.CurrentTable} | ||||||
| 			for _, join := range db.Statement.Joins { | 			for _, join := range db.Statement.Joins { | ||||||
| 				if db.Statement.Schema != nil { | 				if db.Statement.Schema != nil { | ||||||
| 					var isRelations bool // is relations or raw sql
 | 					var isRelations bool // is relations or raw sql
 | ||||||
| @ -124,12 +128,12 @@ func BuildQuerySQL(db *gorm.DB) { | |||||||
| 						nestedJoinNames := strings.Split(join.Name, ".") | 						nestedJoinNames := strings.Split(join.Name, ".") | ||||||
| 						if len(nestedJoinNames) > 1 { | 						if len(nestedJoinNames) > 1 { | ||||||
| 							isNestedJoin := true | 							isNestedJoin := true | ||||||
| 							gussNestedRelations := make([]*schema.Relationship, 0, len(nestedJoinNames)) | 							guessNestedRelations := make([]*schema.Relationship, 0, len(nestedJoinNames)) | ||||||
| 							currentRelations := db.Statement.Schema.Relationships.Relations | 							currentRelations := db.Statement.Schema.Relationships.Relations | ||||||
| 							for _, relname := range nestedJoinNames { | 							for _, relname := range nestedJoinNames { | ||||||
| 								// incomplete match, only treated as raw sql
 | 								// incomplete match, only treated as raw sql
 | ||||||
| 								if relation, ok = currentRelations[relname]; ok { | 								if relation, ok = currentRelations[relname]; ok { | ||||||
| 									gussNestedRelations = append(gussNestedRelations, relation) | 									guessNestedRelations = append(guessNestedRelations, relation) | ||||||
| 									currentRelations = relation.FieldSchema.Relationships.Relations | 									currentRelations = relation.FieldSchema.Relationships.Relations | ||||||
| 								} else { | 								} else { | ||||||
| 									isNestedJoin = false | 									isNestedJoin = false | ||||||
| @ -139,18 +143,13 @@ func BuildQuerySQL(db *gorm.DB) { | |||||||
| 
 | 
 | ||||||
| 							if isNestedJoin { | 							if isNestedJoin { | ||||||
| 								isRelations = true | 								isRelations = true | ||||||
| 								relations = gussNestedRelations | 								relations = guessNestedRelations | ||||||
| 							} | 							} | ||||||
| 						} | 						} | ||||||
| 					} | 					} | ||||||
| 
 | 
 | ||||||
| 					if isRelations { | 					if isRelations { | ||||||
| 						genJoinClause := func(joinType clause.JoinType, parentTableName string, relation *schema.Relationship) clause.Join { | 						genJoinClause := func(joinType clause.JoinType, tableAliasName string, parentTableName string, relation *schema.Relationship) clause.Join { | ||||||
| 							tableAliasName := relation.Name |  | ||||||
| 							if parentTableName != clause.CurrentTable { |  | ||||||
| 								tableAliasName = utils.NestedRelationName(parentTableName, tableAliasName) |  | ||||||
| 							} |  | ||||||
| 
 |  | ||||||
| 							columnStmt := gorm.Statement{ | 							columnStmt := gorm.Statement{ | ||||||
| 								Table: tableAliasName, DB: db, Schema: relation.FieldSchema, | 								Table: tableAliasName, DB: db, Schema: relation.FieldSchema, | ||||||
| 								Selects: join.Selects, Omits: join.Omits, | 								Selects: join.Selects, Omits: join.Omits, | ||||||
| @ -167,6 +166,13 @@ func BuildQuerySQL(db *gorm.DB) { | |||||||
| 								} | 								} | ||||||
| 							} | 							} | ||||||
| 
 | 
 | ||||||
|  | 							if join.Expression != nil { | ||||||
|  | 								return clause.Join{ | ||||||
|  | 									Type:       join.JoinType, | ||||||
|  | 									Expression: join.Expression, | ||||||
|  | 								} | ||||||
|  | 							} | ||||||
|  | 
 | ||||||
| 							exprs := make([]clause.Expression, len(relation.References)) | 							exprs := make([]clause.Expression, len(relation.References)) | ||||||
| 							for idx, ref := range relation.References { | 							for idx, ref := range relation.References { | ||||||
| 								if ref.OwnPrimaryKey { | 								if ref.OwnPrimaryKey { | ||||||
| @ -226,19 +232,24 @@ func BuildQuerySQL(db *gorm.DB) { | |||||||
| 						} | 						} | ||||||
| 
 | 
 | ||||||
| 						parentTableName := clause.CurrentTable | 						parentTableName := clause.CurrentTable | ||||||
| 						for _, rel := range relations { | 						for idx, rel := range relations { | ||||||
| 							// joins table alias like "Manager, Company, Manager__Company"
 | 							// joins table alias like "Manager, Company, Manager__Company"
 | ||||||
| 							nestedAlias := utils.NestedRelationName(parentTableName, rel.Name) | 							curAliasName := rel.Name | ||||||
| 							if _, ok := specifiedRelationsName[nestedAlias]; !ok { | 							if parentTableName != clause.CurrentTable { | ||||||
| 								fromClause.Joins = append(fromClause.Joins, genJoinClause(join.JoinType, parentTableName, rel)) | 								curAliasName = utils.NestedRelationName(parentTableName, curAliasName) | ||||||
| 								specifiedRelationsName[nestedAlias] = nil |  | ||||||
| 							} | 							} | ||||||
| 
 | 
 | ||||||
| 							if parentTableName != clause.CurrentTable { | 							if _, ok := specifiedRelationsName[curAliasName]; !ok { | ||||||
| 								parentTableName = utils.NestedRelationName(parentTableName, rel.Name) | 								aliasName := curAliasName | ||||||
| 							} else { | 								if idx == len(relations)-1 && join.Alias != "" { | ||||||
| 								parentTableName = rel.Name | 									aliasName = join.Alias | ||||||
|  | 								} | ||||||
|  | 
 | ||||||
|  | 								fromClause.Joins = append(fromClause.Joins, genJoinClause(join.JoinType, aliasName, specifiedRelationsName[parentTableName], rel)) | ||||||
|  | 								specifiedRelationsName[curAliasName] = aliasName | ||||||
| 							} | 							} | ||||||
|  | 
 | ||||||
|  | 							parentTableName = curAliasName | ||||||
| 						} | 						} | ||||||
| 					} else { | 					} else { | ||||||
| 						fromClause.Joins = append(fromClause.Joins, clause.Join{ | 						fromClause.Joins = append(fromClause.Joins, clause.Join{ | ||||||
|  | |||||||
| @ -13,5 +13,10 @@ func RawExec(db *gorm.DB) { | |||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		db.RowsAffected, _ = result.RowsAffected() | 		db.RowsAffected, _ = result.RowsAffected() | ||||||
|  | 
 | ||||||
|  | 		if db.Statement.Result != nil { | ||||||
|  | 			db.Statement.Result.Result = result | ||||||
|  | 			db.Statement.Result.RowsAffected = db.RowsAffected | ||||||
|  | 		} | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  | |||||||
| @ -92,6 +92,10 @@ func Update(config *Config) func(db *gorm.DB) { | |||||||
| 					gorm.Scan(rows, db, mode) | 					gorm.Scan(rows, db, mode) | ||||||
| 					db.Statement.Dest = dest | 					db.Statement.Dest = dest | ||||||
| 					db.AddError(rows.Close()) | 					db.AddError(rows.Close()) | ||||||
|  | 
 | ||||||
|  | 					if db.Statement.Result != nil { | ||||||
|  | 						db.Statement.Result.RowsAffected = db.RowsAffected | ||||||
|  | 					} | ||||||
| 				} | 				} | ||||||
| 			} else { | 			} else { | ||||||
| 				result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) | 				result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) | ||||||
| @ -99,6 +103,11 @@ func Update(config *Config) func(db *gorm.DB) { | |||||||
| 				if db.AddError(err) == nil { | 				if db.AddError(err) == nil { | ||||||
| 					db.RowsAffected, _ = result.RowsAffected() | 					db.RowsAffected, _ = result.RowsAffected() | ||||||
| 				} | 				} | ||||||
|  | 
 | ||||||
|  | 				if db.Statement.Result != nil { | ||||||
|  | 					db.Statement.Result.Result = result | ||||||
|  | 					db.Statement.Result.RowsAffected = db.RowsAffected | ||||||
|  | 				} | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  | |||||||
| @ -448,9 +448,10 @@ func (db *DB) Assign(attrs ...interface{}) (tx *DB) { | |||||||
| // Unscoped allows queries to include records marked as deleted,
 | // Unscoped allows queries to include records marked as deleted,
 | ||||||
| // overriding the soft deletion behavior.
 | // overriding the soft deletion behavior.
 | ||||||
| // Example:
 | // Example:
 | ||||||
| //    var users []User
 | //
 | ||||||
| //    db.Unscoped().Find(&users)
 | //	var users []User
 | ||||||
| //    // Retrieves all users, including deleted ones.
 | //	db.Unscoped().Find(&users)
 | ||||||
|  | //	// Retrieves all users, including deleted ones.
 | ||||||
| func (db *DB) Unscoped() (tx *DB) { | func (db *DB) Unscoped() (tx *DB) { | ||||||
| 	tx = db.getInstance() | 	tx = db.getInstance() | ||||||
| 	tx.Statement.Unscoped = true | 	tx.Statement.Unscoped = true | ||||||
|  | |||||||
| @ -1,5 +1,7 @@ | |||||||
| package clause | package clause | ||||||
| 
 | 
 | ||||||
|  | import "gorm.io/gorm/utils" | ||||||
|  | 
 | ||||||
| type JoinType string | type JoinType string | ||||||
| 
 | 
 | ||||||
| const ( | const ( | ||||||
| @ -9,6 +11,30 @@ const ( | |||||||
| 	RightJoin JoinType = "RIGHT" | 	RightJoin JoinType = "RIGHT" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | type JoinTarget struct { | ||||||
|  | 	Type        JoinType | ||||||
|  | 	Association string | ||||||
|  | 	Subquery    Expression | ||||||
|  | 	Table       string | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func Has(name string) JoinTarget { | ||||||
|  | 	return JoinTarget{Type: InnerJoin, Association: name} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (jt JoinType) Association(name string) JoinTarget { | ||||||
|  | 	return JoinTarget{Type: jt, Association: name} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (jt JoinType) AssociationFrom(name string, subquery Expression) JoinTarget { | ||||||
|  | 	return JoinTarget{Type: jt, Association: name, Subquery: subquery} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (jt JoinTarget) As(name string) JoinTarget { | ||||||
|  | 	jt.Table = name | ||||||
|  | 	return jt | ||||||
|  | } | ||||||
|  | 
 | ||||||
| // Join clause for from
 | // Join clause for from
 | ||||||
| type Join struct { | type Join struct { | ||||||
| 	Type       JoinType | 	Type       JoinType | ||||||
| @ -18,6 +44,12 @@ type Join struct { | |||||||
| 	Expression Expression | 	Expression Expression | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func JoinTable(names ...string) Table { | ||||||
|  | 	return Table{ | ||||||
|  | 		Name: utils.JoinNestedRelationNames(names), | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func (join Join) Build(builder Builder) { | func (join Join) Build(builder Builder) { | ||||||
| 	if join.Expression != nil { | 	if join.Expression != nil { | ||||||
| 		join.Expression.Build(builder) | 		join.Expression.Build(builder) | ||||||
|  | |||||||
| @ -1,6 +1,7 @@ | |||||||
| package gorm | package gorm | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
|  | 	"context" | ||||||
| 	"database/sql" | 	"database/sql" | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| @ -673,11 +674,18 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB { | |||||||
| 		opt = opts[0] | 		opt = opts[0] | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	ctx := tx.Statement.Context | ||||||
|  | 	if _, ok := ctx.Deadline(); !ok { | ||||||
|  | 		if db.Config.DefaultTransactionTimeout > 0 { | ||||||
|  | 			ctx, _ = context.WithTimeout(ctx, db.Config.DefaultTransactionTimeout) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	switch beginner := tx.Statement.ConnPool.(type) { | 	switch beginner := tx.Statement.ConnPool.(type) { | ||||||
| 	case TxBeginner: | 	case TxBeginner: | ||||||
| 		tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) | 		tx.Statement.ConnPool, err = beginner.BeginTx(ctx, opt) | ||||||
| 	case ConnPoolBeginner: | 	case ConnPoolBeginner: | ||||||
| 		tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) | 		tx.Statement.ConnPool, err = beginner.BeginTx(ctx, opt) | ||||||
| 	default: | 	default: | ||||||
| 		err = ErrInvalidTransaction | 		err = ErrInvalidTransaction | ||||||
| 	} | 	} | ||||||
|  | |||||||
							
								
								
									
										605
									
								
								generics.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										605
									
								
								generics.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,605 @@ | |||||||
|  | package gorm | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"database/sql" | ||||||
|  | 	"fmt" | ||||||
|  | 	"sort" | ||||||
|  | 	"strings" | ||||||
|  | 
 | ||||||
|  | 	"gorm.io/gorm/clause" | ||||||
|  | 	"gorm.io/gorm/logger" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | type result struct { | ||||||
|  | 	Result       sql.Result | ||||||
|  | 	RowsAffected int64 | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (info *result) ModifyStatement(stmt *Statement) { | ||||||
|  | 	stmt.Result = info | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Build implements clause.Expression interface
 | ||||||
|  | func (result) Build(clause.Builder) { | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func WithResult() *result { | ||||||
|  | 	return &result{} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type Interface[T any] interface { | ||||||
|  | 	Raw(sql string, values ...interface{}) ExecInterface[T] | ||||||
|  | 	Exec(ctx context.Context, sql string, values ...interface{}) error | ||||||
|  | 	CreateInterface[T] | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type CreateInterface[T any] interface { | ||||||
|  | 	ChainInterface[T] | ||||||
|  | 	Table(name string, args ...interface{}) CreateInterface[T] | ||||||
|  | 	Create(ctx context.Context, r *T) error | ||||||
|  | 	CreateInBatches(ctx context.Context, r *[]T, batchSize int) error | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type ChainInterface[T any] interface { | ||||||
|  | 	ExecInterface[T] | ||||||
|  | 	Scopes(scopes ...func(db *Statement)) ChainInterface[T] | ||||||
|  | 	Where(query interface{}, args ...interface{}) ChainInterface[T] | ||||||
|  | 	Not(query interface{}, args ...interface{}) ChainInterface[T] | ||||||
|  | 	Or(query interface{}, args ...interface{}) ChainInterface[T] | ||||||
|  | 	Limit(offset int) ChainInterface[T] | ||||||
|  | 	Offset(offset int) ChainInterface[T] | ||||||
|  | 	Joins(query clause.JoinTarget, on func(db JoinBuilder, joinTable clause.Table, curTable clause.Table) error) ChainInterface[T] | ||||||
|  | 	Preload(association string, query func(db PreloadBuilder) error) ChainInterface[T] | ||||||
|  | 	Select(query string, args ...interface{}) ChainInterface[T] | ||||||
|  | 	Omit(columns ...string) ChainInterface[T] | ||||||
|  | 	MapColumns(m map[string]string) ChainInterface[T] | ||||||
|  | 	Distinct(args ...interface{}) ChainInterface[T] | ||||||
|  | 	Group(name string) ChainInterface[T] | ||||||
|  | 	Having(query interface{}, args ...interface{}) ChainInterface[T] | ||||||
|  | 	Order(value interface{}) ChainInterface[T] | ||||||
|  | 
 | ||||||
|  | 	Build(builder clause.Builder) | ||||||
|  | 
 | ||||||
|  | 	Delete(ctx context.Context) (rowsAffected int, err error) | ||||||
|  | 	Update(ctx context.Context, name string, value any) (rowsAffected int, err error) | ||||||
|  | 	Updates(ctx context.Context, t T) (rowsAffected int, err error) | ||||||
|  | 	Count(ctx context.Context, column string) (result int64, err error) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type ExecInterface[T any] interface { | ||||||
|  | 	Scan(ctx context.Context, r interface{}) error | ||||||
|  | 	First(context.Context) (T, error) | ||||||
|  | 	Last(ctx context.Context) (T, error) | ||||||
|  | 	Take(context.Context) (T, error) | ||||||
|  | 	Find(ctx context.Context) ([]T, error) | ||||||
|  | 	FindInBatches(ctx context.Context, batchSize int, fc func(data []T, batch int) error) error | ||||||
|  | 	Row(ctx context.Context) *sql.Row | ||||||
|  | 	Rows(ctx context.Context) (*sql.Rows, error) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type JoinBuilder interface { | ||||||
|  | 	Select(...string) JoinBuilder | ||||||
|  | 	Omit(...string) JoinBuilder | ||||||
|  | 	Where(query interface{}, args ...interface{}) JoinBuilder | ||||||
|  | 	Not(query interface{}, args ...interface{}) JoinBuilder | ||||||
|  | 	Or(query interface{}, args ...interface{}) JoinBuilder | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type PreloadBuilder interface { | ||||||
|  | 	Select(...string) PreloadBuilder | ||||||
|  | 	Omit(...string) PreloadBuilder | ||||||
|  | 	Where(query interface{}, args ...interface{}) PreloadBuilder | ||||||
|  | 	Not(query interface{}, args ...interface{}) PreloadBuilder | ||||||
|  | 	Or(query interface{}, args ...interface{}) PreloadBuilder | ||||||
|  | 	Limit(offset int) PreloadBuilder | ||||||
|  | 	Offset(offset int) PreloadBuilder | ||||||
|  | 	Order(value interface{}) PreloadBuilder | ||||||
|  | 	LimitPerRecord(num int) PreloadBuilder | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type op func(*DB) *DB | ||||||
|  | 
 | ||||||
|  | func G[T any](db *DB, opts ...clause.Expression) Interface[T] { | ||||||
|  | 	v := &g[T]{ | ||||||
|  | 		db:  db, | ||||||
|  | 		ops: make([]op, 0, 5), | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if len(opts) > 0 { | ||||||
|  | 		v.ops = append(v.ops, func(db *DB) *DB { | ||||||
|  | 			return db.Clauses(opts...) | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	v.createG = &createG[T]{ | ||||||
|  | 		chainG: chainG[T]{ | ||||||
|  | 			execG: execG[T]{g: v}, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | 	return v | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type g[T any] struct { | ||||||
|  | 	*createG[T] | ||||||
|  | 	db  *DB | ||||||
|  | 	ops []op | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (g *g[T]) apply(ctx context.Context) *DB { | ||||||
|  | 	db := g.db | ||||||
|  | 	if !db.DryRun { | ||||||
|  | 		db = db.Session(&Session{NewDB: true, Context: ctx}).getInstance() | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for _, op := range g.ops { | ||||||
|  | 		db = op(db) | ||||||
|  | 	} | ||||||
|  | 	return db | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (c *g[T]) Raw(sql string, values ...interface{}) ExecInterface[T] { | ||||||
|  | 	return execG[T]{g: &g[T]{ | ||||||
|  | 		db: c.db, | ||||||
|  | 		ops: append(c.ops, func(db *DB) *DB { | ||||||
|  | 			return db.Raw(sql, values...) | ||||||
|  | 		}), | ||||||
|  | 	}} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (c *g[T]) Exec(ctx context.Context, sql string, values ...interface{}) error { | ||||||
|  | 	return c.apply(ctx).Exec(sql, values...).Error | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type createG[T any] struct { | ||||||
|  | 	chainG[T] | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (c createG[T]) Table(name string, args ...interface{}) CreateInterface[T] { | ||||||
|  | 	return createG[T]{c.with(func(db *DB) *DB { | ||||||
|  | 		return db.Table(name, args...) | ||||||
|  | 	})} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (c createG[T]) Create(ctx context.Context, r *T) error { | ||||||
|  | 	return c.g.apply(ctx).Create(r).Error | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (c createG[T]) CreateInBatches(ctx context.Context, r *[]T, batchSize int) error { | ||||||
|  | 	return c.g.apply(ctx).CreateInBatches(r, batchSize).Error | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type chainG[T any] struct { | ||||||
|  | 	execG[T] | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (c chainG[T]) getInstance() *DB { | ||||||
|  | 	var r T | ||||||
|  | 	return c.g.apply(context.Background()).Model(r).getInstance() | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (c chainG[T]) with(v op) chainG[T] { | ||||||
|  | 	return chainG[T]{ | ||||||
|  | 		execG: execG[T]{g: &g[T]{ | ||||||
|  | 			db:  c.g.db, | ||||||
|  | 			ops: append(append([]op(nil), c.g.ops...), v), | ||||||
|  | 		}}, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (c chainG[T]) Scopes(scopes ...func(db *Statement)) ChainInterface[T] { | ||||||
|  | 	return c.with(func(db *DB) *DB { | ||||||
|  | 		for _, fc := range scopes { | ||||||
|  | 			fc(db.Statement) | ||||||
|  | 		} | ||||||
|  | 		return db | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (c chainG[T]) Table(name string, args ...interface{}) ChainInterface[T] { | ||||||
|  | 	return c.with(func(db *DB) *DB { | ||||||
|  | 		return db.Table(name, args...) | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (c chainG[T]) Where(query interface{}, args ...interface{}) ChainInterface[T] { | ||||||
|  | 	return c.with(func(db *DB) *DB { | ||||||
|  | 		return db.Where(query, args...) | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (c chainG[T]) Not(query interface{}, args ...interface{}) ChainInterface[T] { | ||||||
|  | 	return c.with(func(db *DB) *DB { | ||||||
|  | 		return db.Not(query, args...) | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (c chainG[T]) Or(query interface{}, args ...interface{}) ChainInterface[T] { | ||||||
|  | 	return c.with(func(db *DB) *DB { | ||||||
|  | 		return db.Or(query, args...) | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (c chainG[T]) Limit(offset int) ChainInterface[T] { | ||||||
|  | 	return c.with(func(db *DB) *DB { | ||||||
|  | 		return db.Limit(offset) | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (c chainG[T]) Offset(offset int) ChainInterface[T] { | ||||||
|  | 	return c.with(func(db *DB) *DB { | ||||||
|  | 		return db.Offset(offset) | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type joinBuilder struct { | ||||||
|  | 	db *DB | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (q *joinBuilder) Where(query interface{}, args ...interface{}) JoinBuilder { | ||||||
|  | 	q.db.Where(query, args...) | ||||||
|  | 	return q | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (q *joinBuilder) Or(query interface{}, args ...interface{}) JoinBuilder { | ||||||
|  | 	q.db.Where(query, args...) | ||||||
|  | 	return q | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (q *joinBuilder) Not(query interface{}, args ...interface{}) JoinBuilder { | ||||||
|  | 	q.db.Where(query, args...) | ||||||
|  | 	return q | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (q *joinBuilder) Select(columns ...string) JoinBuilder { | ||||||
|  | 	q.db.Select(columns) | ||||||
|  | 	return q | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (q *joinBuilder) Omit(columns ...string) JoinBuilder { | ||||||
|  | 	q.db.Omit(columns...) | ||||||
|  | 	return q | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type preloadBuilder struct { | ||||||
|  | 	limitPerRecord int | ||||||
|  | 	db             *DB | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (q *preloadBuilder) Where(query interface{}, args ...interface{}) PreloadBuilder { | ||||||
|  | 	q.db.Where(query, args...) | ||||||
|  | 	return q | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (q *preloadBuilder) Or(query interface{}, args ...interface{}) PreloadBuilder { | ||||||
|  | 	q.db.Where(query, args...) | ||||||
|  | 	return q | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (q *preloadBuilder) Not(query interface{}, args ...interface{}) PreloadBuilder { | ||||||
|  | 	q.db.Where(query, args...) | ||||||
|  | 	return q | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (q *preloadBuilder) Select(columns ...string) PreloadBuilder { | ||||||
|  | 	q.db.Select(columns) | ||||||
|  | 	return q | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (q *preloadBuilder) Omit(columns ...string) PreloadBuilder { | ||||||
|  | 	q.db.Omit(columns...) | ||||||
|  | 	return q | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (q *preloadBuilder) Limit(limit int) PreloadBuilder { | ||||||
|  | 	q.db.Limit(limit) | ||||||
|  | 	return q | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (q *preloadBuilder) Offset(offset int) PreloadBuilder { | ||||||
|  | 	q.db.Offset(offset) | ||||||
|  | 	return q | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (q *preloadBuilder) Order(value interface{}) PreloadBuilder { | ||||||
|  | 	q.db.Order(value) | ||||||
|  | 	return q | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (q *preloadBuilder) LimitPerRecord(num int) PreloadBuilder { | ||||||
|  | 	q.limitPerRecord = num | ||||||
|  | 	return q | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (c chainG[T]) Joins(jt clause.JoinTarget, on func(db JoinBuilder, joinTable clause.Table, curTable clause.Table) error) ChainInterface[T] { | ||||||
|  | 	return c.with(func(db *DB) *DB { | ||||||
|  | 		if jt.Table == "" { | ||||||
|  | 			jt.Table = clause.JoinTable(strings.Split(jt.Association, ".")...).Name | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		q := joinBuilder{db: db.Session(&Session{NewDB: true, Initialized: true}).Table(jt.Table)} | ||||||
|  | 		if on != nil { | ||||||
|  | 			if err := on(&q, clause.Table{Name: jt.Table}, clause.Table{Name: clause.CurrentTable}); err != nil { | ||||||
|  | 				db.AddError(err) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		j := join{ | ||||||
|  | 			Name:     jt.Association, | ||||||
|  | 			Alias:    jt.Table, | ||||||
|  | 			Selects:  q.db.Statement.Selects, | ||||||
|  | 			Omits:    q.db.Statement.Omits, | ||||||
|  | 			JoinType: jt.Type, | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if where, ok := q.db.Statement.Clauses["WHERE"].Expression.(clause.Where); ok { | ||||||
|  | 			j.On = &where | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if jt.Subquery != nil { | ||||||
|  | 			joinType := j.JoinType | ||||||
|  | 			if joinType == "" { | ||||||
|  | 				joinType = clause.LeftJoin | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			if db, ok := jt.Subquery.(interface{ getInstance() *DB }); ok { | ||||||
|  | 				stmt := db.getInstance().Statement | ||||||
|  | 				if len(j.Selects) == 0 { | ||||||
|  | 					j.Selects = stmt.Selects | ||||||
|  | 				} | ||||||
|  | 				if len(j.Omits) == 0 { | ||||||
|  | 					j.Omits = stmt.Omits | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			expr := clause.NamedExpr{SQL: fmt.Sprintf("%s JOIN (?) AS ?", joinType), Vars: []interface{}{jt.Subquery, clause.Table{Name: j.Alias}}} | ||||||
|  | 
 | ||||||
|  | 			if j.On != nil { | ||||||
|  | 				expr.SQL += " ON ?" | ||||||
|  | 				expr.Vars = append(expr.Vars, clause.AndConditions{Exprs: j.On.Exprs}) | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			j.Expression = expr | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		db.Statement.Joins = append(db.Statement.Joins, j) | ||||||
|  | 		sort.Slice(db.Statement.Joins, func(i, j int) bool { | ||||||
|  | 			return db.Statement.Joins[i].Name < db.Statement.Joins[j].Name | ||||||
|  | 		}) | ||||||
|  | 		return db | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (c chainG[T]) Select(query string, args ...interface{}) ChainInterface[T] { | ||||||
|  | 	return c.with(func(db *DB) *DB { | ||||||
|  | 		return db.Select(query, args...) | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (c chainG[T]) Omit(columns ...string) ChainInterface[T] { | ||||||
|  | 	return c.with(func(db *DB) *DB { | ||||||
|  | 		return db.Omit(columns...) | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (c chainG[T]) MapColumns(m map[string]string) ChainInterface[T] { | ||||||
|  | 	return c.with(func(db *DB) *DB { | ||||||
|  | 		return db.MapColumns(m) | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (c chainG[T]) Distinct(args ...interface{}) ChainInterface[T] { | ||||||
|  | 	return c.with(func(db *DB) *DB { | ||||||
|  | 		return db.Distinct(args...) | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (c chainG[T]) Group(name string) ChainInterface[T] { | ||||||
|  | 	return c.with(func(db *DB) *DB { | ||||||
|  | 		return db.Group(name) | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (c chainG[T]) Having(query interface{}, args ...interface{}) ChainInterface[T] { | ||||||
|  | 	return c.with(func(db *DB) *DB { | ||||||
|  | 		return db.Having(query, args...) | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (c chainG[T]) Order(value interface{}) ChainInterface[T] { | ||||||
|  | 	return c.with(func(db *DB) *DB { | ||||||
|  | 		return db.Order(value) | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (c chainG[T]) Preload(association string, query func(db PreloadBuilder) error) ChainInterface[T] { | ||||||
|  | 	return c.with(func(db *DB) *DB { | ||||||
|  | 		return db.Preload(association, func(tx *DB) *DB { | ||||||
|  | 			q := preloadBuilder{db: tx.getInstance()} | ||||||
|  | 			if query != nil { | ||||||
|  | 				if err := query(&q); err != nil { | ||||||
|  | 					db.AddError(err) | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			relation, ok := db.Statement.Schema.Relationships.Relations[association] | ||||||
|  | 			if !ok { | ||||||
|  | 				if preloadFields := strings.Split(association, "."); len(preloadFields) > 1 { | ||||||
|  | 					relationships := db.Statement.Schema.Relationships | ||||||
|  | 					for _, field := range preloadFields { | ||||||
|  | 						var ok bool | ||||||
|  | 						relation, ok = relationships.Relations[field] | ||||||
|  | 						if ok { | ||||||
|  | 							relationships = relation.FieldSchema.Relationships | ||||||
|  | 						} else { | ||||||
|  | 							db.AddError(fmt.Errorf("relation %s not found", association)) | ||||||
|  | 							return nil | ||||||
|  | 						} | ||||||
|  | 					} | ||||||
|  | 				} else { | ||||||
|  | 					db.AddError(fmt.Errorf("relation %s not found", association)) | ||||||
|  | 					return nil | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			if q.limitPerRecord > 0 { | ||||||
|  | 				if relation.JoinTable != nil { | ||||||
|  | 					tx.AddError(fmt.Errorf("many2many relation %s don't support LimitPerRecord", association)) | ||||||
|  | 					return tx | ||||||
|  | 				} | ||||||
|  | 
 | ||||||
|  | 				refColumns := []clause.Column{} | ||||||
|  | 				for _, rel := range relation.References { | ||||||
|  | 					if rel.OwnPrimaryKey { | ||||||
|  | 						refColumns = append(refColumns, clause.Column{Name: rel.ForeignKey.DBName}) | ||||||
|  | 					} | ||||||
|  | 				} | ||||||
|  | 
 | ||||||
|  | 				if len(refColumns) != 0 { | ||||||
|  | 					selectExpr := clause.CommaExpression{} | ||||||
|  | 					for _, column := range q.db.Statement.Selects { | ||||||
|  | 						selectExpr.Exprs = append(selectExpr.Exprs, clause.Expr{SQL: "?", Vars: []interface{}{clause.Column{Name: column}}}) | ||||||
|  | 					} | ||||||
|  | 
 | ||||||
|  | 					if len(selectExpr.Exprs) == 0 { | ||||||
|  | 						selectExpr.Exprs = []clause.Expression{clause.Expr{SQL: "*", Vars: []interface{}{}}} | ||||||
|  | 					} | ||||||
|  | 
 | ||||||
|  | 					partitionBy := clause.CommaExpression{} | ||||||
|  | 					for _, column := range refColumns { | ||||||
|  | 						partitionBy.Exprs = append(partitionBy.Exprs, clause.Expr{SQL: "?", Vars: []interface{}{clause.Column{Name: column.Name}}}) | ||||||
|  | 					} | ||||||
|  | 
 | ||||||
|  | 					rnnColumn := clause.Column{Name: "gorm_preload_rnn"} | ||||||
|  | 					sql := "ROW_NUMBER() OVER (PARTITION BY ? ?)" | ||||||
|  | 					vars := []interface{}{partitionBy} | ||||||
|  | 					if orderBy, ok := q.db.Statement.Clauses["ORDER BY"]; ok { | ||||||
|  | 						vars = append(vars, orderBy) | ||||||
|  | 					} else { | ||||||
|  | 						vars = append(vars, clause.Clause{Name: "ORDER BY", Expression: clause.OrderBy{ | ||||||
|  | 							Columns: []clause.OrderByColumn{{Column: clause.PrimaryColumn, Desc: true}}, | ||||||
|  | 						}}) | ||||||
|  | 					} | ||||||
|  | 					vars = append(vars, rnnColumn) | ||||||
|  | 
 | ||||||
|  | 					selectExpr.Exprs = append(selectExpr.Exprs, clause.Expr{SQL: sql + " AS ?", Vars: vars}) | ||||||
|  | 
 | ||||||
|  | 					q.db.Clauses(clause.Select{Expression: selectExpr}) | ||||||
|  | 
 | ||||||
|  | 					return q.db.Session(&Session{NewDB: true}).Unscoped().Table("(?) t", q.db).Where("? <= ?", rnnColumn, q.limitPerRecord) | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			return q.db | ||||||
|  | 		}) | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (c chainG[T]) Delete(ctx context.Context) (rowsAffected int, err error) { | ||||||
|  | 	r := new(T) | ||||||
|  | 	res := c.g.apply(ctx).Delete(r) | ||||||
|  | 	return int(res.RowsAffected), res.Error | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (c chainG[T]) Update(ctx context.Context, name string, value any) (rowsAffected int, err error) { | ||||||
|  | 	var r T | ||||||
|  | 	res := c.g.apply(ctx).Model(r).Update(name, value) | ||||||
|  | 	return int(res.RowsAffected), res.Error | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (c chainG[T]) Updates(ctx context.Context, t T) (rowsAffected int, err error) { | ||||||
|  | 	res := c.g.apply(ctx).Updates(t) | ||||||
|  | 	return int(res.RowsAffected), res.Error | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (c chainG[T]) Count(ctx context.Context, column string) (result int64, err error) { | ||||||
|  | 	var r T | ||||||
|  | 	err = c.g.apply(ctx).Model(r).Select(column).Count(&result).Error | ||||||
|  | 	return | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (c chainG[T]) Build(builder clause.Builder) { | ||||||
|  | 	subdb := c.getInstance() | ||||||
|  | 	subdb.Logger = logger.Discard | ||||||
|  | 	subdb.DryRun = true | ||||||
|  | 
 | ||||||
|  | 	if stmt, ok := builder.(*Statement); ok { | ||||||
|  | 		if subdb.Statement.SQL.Len() > 0 { | ||||||
|  | 			var ( | ||||||
|  | 				vars = subdb.Statement.Vars | ||||||
|  | 				sql  = subdb.Statement.SQL.String() | ||||||
|  | 			) | ||||||
|  | 
 | ||||||
|  | 			subdb.Statement.Vars = make([]interface{}, 0, len(vars)) | ||||||
|  | 			for _, vv := range vars { | ||||||
|  | 				subdb.Statement.Vars = append(subdb.Statement.Vars, vv) | ||||||
|  | 				bindvar := strings.Builder{} | ||||||
|  | 				subdb.BindVarTo(&bindvar, subdb.Statement, vv) | ||||||
|  | 				sql = strings.Replace(sql, bindvar.String(), "?", 1) | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			subdb.Statement.SQL.Reset() | ||||||
|  | 			subdb.Statement.Vars = stmt.Vars | ||||||
|  | 			if strings.Contains(sql, "@") { | ||||||
|  | 				clause.NamedExpr{SQL: sql, Vars: vars}.Build(subdb.Statement) | ||||||
|  | 			} else { | ||||||
|  | 				clause.Expr{SQL: sql, Vars: vars}.Build(subdb.Statement) | ||||||
|  | 			} | ||||||
|  | 		} else { | ||||||
|  | 			subdb.Statement.Vars = append(stmt.Vars, subdb.Statement.Vars...) | ||||||
|  | 			subdb.callbacks.Query().Execute(subdb) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		builder.WriteString(subdb.Statement.SQL.String()) | ||||||
|  | 		stmt.Vars = subdb.Statement.Vars | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type execG[T any] struct { | ||||||
|  | 	g *g[T] | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (g execG[T]) First(ctx context.Context) (T, error) { | ||||||
|  | 	var r T | ||||||
|  | 	err := g.g.apply(ctx).First(&r).Error | ||||||
|  | 	return r, err | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (g execG[T]) Scan(ctx context.Context, result interface{}) error { | ||||||
|  | 	var r T | ||||||
|  | 	err := g.g.apply(ctx).Model(r).Find(result).Error | ||||||
|  | 	return err | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (g execG[T]) Last(ctx context.Context) (T, error) { | ||||||
|  | 	var r T | ||||||
|  | 	err := g.g.apply(ctx).Last(&r).Error | ||||||
|  | 	return r, err | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (g execG[T]) Take(ctx context.Context) (T, error) { | ||||||
|  | 	var r T | ||||||
|  | 	err := g.g.apply(ctx).Take(&r).Error | ||||||
|  | 	return r, err | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (g execG[T]) Find(ctx context.Context) ([]T, error) { | ||||||
|  | 	var r []T | ||||||
|  | 	err := g.g.apply(ctx).Find(&r).Error | ||||||
|  | 	return r, err | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (g execG[T]) FindInBatches(ctx context.Context, batchSize int, fc func(data []T, batch int) error) error { | ||||||
|  | 	var data []T | ||||||
|  | 	return g.g.apply(ctx).FindInBatches(&data, batchSize, func(tx *DB, batch int) error { | ||||||
|  | 		return fc(data, batch) | ||||||
|  | 	}).Error | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (g execG[T]) Row(ctx context.Context) *sql.Row { | ||||||
|  | 	return g.g.apply(ctx).Row() | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (g execG[T]) Rows(ctx context.Context) (*sql.Rows, error) { | ||||||
|  | 	return g.g.apply(ctx).Rows() | ||||||
|  | } | ||||||
							
								
								
									
										14
									
								
								gorm.go
									
									
									
									
									
								
							
							
						
						
									
										14
									
								
								gorm.go
									
									
									
									
									
								
							| @ -21,7 +21,9 @@ const preparedStmtDBKey = "preparedStmt" | |||||||
| type Config struct { | type Config struct { | ||||||
| 	// GORM perform single create, update, delete operations in transactions by default to ensure database data integrity
 | 	// GORM perform single create, update, delete operations in transactions by default to ensure database data integrity
 | ||||||
| 	// You can disable it by setting `SkipDefaultTransaction` to true
 | 	// You can disable it by setting `SkipDefaultTransaction` to true
 | ||||||
| 	SkipDefaultTransaction bool | 	SkipDefaultTransaction    bool | ||||||
|  | 	DefaultTransactionTimeout time.Duration | ||||||
|  | 
 | ||||||
| 	// NamingStrategy tables, columns naming strategy
 | 	// NamingStrategy tables, columns naming strategy
 | ||||||
| 	NamingStrategy schema.Namer | 	NamingStrategy schema.Namer | ||||||
| 	// FullSaveAssociations full save associations
 | 	// FullSaveAssociations full save associations
 | ||||||
| @ -135,12 +137,16 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) { | |||||||
| 		return isConfig && !isConfig2 | 		return isConfig && !isConfig2 | ||||||
| 	}) | 	}) | ||||||
| 
 | 
 | ||||||
|  | 	var skipAfterInitialize bool | ||||||
| 	for _, opt := range opts { | 	for _, opt := range opts { | ||||||
| 		if opt != nil { | 		if opt != nil { | ||||||
| 			if applyErr := opt.Apply(config); applyErr != nil { | 			if applyErr := opt.Apply(config); applyErr != nil { | ||||||
| 				return nil, applyErr | 				return nil, applyErr | ||||||
| 			} | 			} | ||||||
| 			defer func(opt Option) { | 			defer func(opt Option) { | ||||||
|  | 				if skipAfterInitialize { | ||||||
|  | 					return | ||||||
|  | 				} | ||||||
| 				if errr := opt.AfterInitialize(db); errr != nil { | 				if errr := opt.AfterInitialize(db); errr != nil { | ||||||
| 					err = errr | 					err = errr | ||||||
| 				} | 				} | ||||||
| @ -192,6 +198,10 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) { | |||||||
| 			if db, _ := db.DB(); db != nil { | 			if db, _ := db.DB(); db != nil { | ||||||
| 				_ = db.Close() | 				_ = db.Close() | ||||||
| 			} | 			} | ||||||
|  | 
 | ||||||
|  | 			// DB is not initialized, so we skip AfterInitialize
 | ||||||
|  | 			skipAfterInitialize = true | ||||||
|  | 			return | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		if config.TranslateError { | 		if config.TranslateError { | ||||||
| @ -519,7 +529,7 @@ func (db *DB) Use(plugin Plugin) error { | |||||||
| //				.First(&User{})
 | //				.First(&User{})
 | ||||||
| //	})
 | //	})
 | ||||||
| func (db *DB) ToSQL(queryFn func(tx *DB) *DB) string { | func (db *DB) ToSQL(queryFn func(tx *DB) *DB) string { | ||||||
| 	tx := queryFn(db.Session(&Session{DryRun: true, SkipDefaultTransaction: true})) | 	tx := queryFn(db.Session(&Session{DryRun: true, SkipDefaultTransaction: true}).getInstance()) | ||||||
| 	stmt := tx.Statement | 	stmt := tx.Statement | ||||||
| 
 | 
 | ||||||
| 	return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...) | 	return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...) | ||||||
|  | |||||||
							
								
								
									
										9
									
								
								scan.go
									
									
									
									
									
								
							
							
						
						
									
										9
									
								
								scan.go
									
									
									
									
									
								
							| @ -4,6 +4,7 @@ import ( | |||||||
| 	"database/sql" | 	"database/sql" | ||||||
| 	"database/sql/driver" | 	"database/sql/driver" | ||||||
| 	"reflect" | 	"reflect" | ||||||
|  | 	"strings" | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
| 	"gorm.io/gorm/schema" | 	"gorm.io/gorm/schema" | ||||||
| @ -244,6 +245,14 @@ func Scan(rows Rows, db *DB, mode ScanMode) { | |||||||
| 							matchedFieldCount[column] = 1 | 							matchedFieldCount[column] = 1 | ||||||
| 						} | 						} | ||||||
| 					} else if names := utils.SplitNestedRelationName(column); len(names) > 1 { // has nested relation
 | 					} else if names := utils.SplitNestedRelationName(column); len(names) > 1 { // has nested relation
 | ||||||
|  | 						aliasName := utils.JoinNestedRelationNames(names[0 : len(names)-1]) | ||||||
|  | 						for _, join := range db.Statement.Joins { | ||||||
|  | 							if join.Alias == aliasName { | ||||||
|  | 								names = append(strings.Split(join.Name, "."), names[len(names)-1]) | ||||||
|  | 								break | ||||||
|  | 							} | ||||||
|  | 						} | ||||||
|  | 
 | ||||||
| 						if rel, ok := sch.Relationships.Relations[names[0]]; ok { | 						if rel, ok := sch.Relationships.Relations[names[0]]; ok { | ||||||
| 							subNameCount := len(names) | 							subNameCount := len(names) | ||||||
| 							// nested relation fields
 | 							// nested relation fields
 | ||||||
|  | |||||||
| @ -318,9 +318,10 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { | |||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if val, ok := field.TagSettings["TYPE"]; ok { | 	if val, ok := field.TagSettings["TYPE"]; ok { | ||||||
| 		switch DataType(strings.ToLower(val)) { | 		lowerVal := DataType(strings.ToLower(val)) | ||||||
|  | 		switch lowerVal { | ||||||
| 		case Bool, Int, Uint, Float, String, Time, Bytes: | 		case Bool, Int, Uint, Float, String, Time, Bytes: | ||||||
| 			field.DataType = DataType(strings.ToLower(val)) | 			field.DataType = lowerVal | ||||||
| 		default: | 		default: | ||||||
| 			field.DataType = DataType(val) | 			field.DataType = DataType(val) | ||||||
| 		} | 		} | ||||||
|  | |||||||
| @ -105,7 +105,7 @@ func parseFieldIndexes(field *Field) (indexes []Index, err error) { | |||||||
| 				var ( | 				var ( | ||||||
| 					name       string | 					name       string | ||||||
| 					tag        = strings.Join(v[1:], ":") | 					tag        = strings.Join(v[1:], ":") | ||||||
| 					idx        = strings.Index(tag, ",") | 					idx        = strings.IndexByte(tag, ',') | ||||||
| 					tagSetting = strings.Join(strings.Split(tag, ",")[1:], ",") | 					tagSetting = strings.Join(strings.Split(tag, ",")[1:], ",") | ||||||
| 					settings   = ParseTagSetting(tagSetting, ",") | 					settings   = ParseTagSetting(tagSetting, ",") | ||||||
| 					length, _  = strconv.Atoi(settings["LENGTH"]) | 					length, _  = strconv.Atoi(settings["LENGTH"]) | ||||||
|  | |||||||
| @ -78,7 +78,7 @@ func (schema *Schema) parseRelation(field *Field) *Relationship { | |||||||
| 	cacheStore := schema.cacheStore | 	cacheStore := schema.cacheStore | ||||||
| 
 | 
 | ||||||
| 	if relation.FieldSchema, err = getOrParse(fieldValue, cacheStore, schema.namer); err != nil { | 	if relation.FieldSchema, err = getOrParse(fieldValue, cacheStore, schema.namer); err != nil { | ||||||
| 		schema.err = err | 		schema.err = fmt.Errorf("failed to parse field: %s, error: %w", field.Name, err) | ||||||
| 		return nil | 		return nil | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| @ -663,6 +663,7 @@ func (rel *Relationship) ParseConstraint() *Constraint { | |||||||
| 					if !(rel.References[idx].PrimaryKey == ref.PrimaryKey && rel.References[idx].ForeignKey == ref.ForeignKey && | 					if !(rel.References[idx].PrimaryKey == ref.PrimaryKey && rel.References[idx].ForeignKey == ref.ForeignKey && | ||||||
| 						rel.References[idx].PrimaryValue == ref.PrimaryValue) { | 						rel.References[idx].PrimaryValue == ref.PrimaryValue) { | ||||||
| 						matched = false | 						matched = false | ||||||
|  | 						break | ||||||
| 					} | 					} | ||||||
| 				} | 				} | ||||||
| 
 | 
 | ||||||
| @ -675,7 +676,7 @@ func (rel *Relationship) ParseConstraint() *Constraint { | |||||||
| 
 | 
 | ||||||
| 	var ( | 	var ( | ||||||
| 		name     string | 		name     string | ||||||
| 		idx      = strings.Index(str, ",") | 		idx      = strings.IndexByte(str, ',') | ||||||
| 		settings = ParseTagSetting(str, ",") | 		settings = ParseTagSetting(str, ",") | ||||||
| 	) | 	) | ||||||
| 
 | 
 | ||||||
| @ -762,8 +763,9 @@ func (rel *Relationship) ToQueryConditions(ctx context.Context, reflectValue ref | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func copyableDataType(str DataType) bool { | func copyableDataType(str DataType) bool { | ||||||
|  | 	lowerStr := strings.ToLower(string(str)) | ||||||
| 	for _, s := range []string{"auto_increment", "primary key"} { | 	for _, s := range []string{"auto_increment", "primary key"} { | ||||||
| 		if strings.Contains(strings.ToLower(string(str)), s) { | 		if strings.Contains(lowerStr, s) { | ||||||
| 			return false | 			return false | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  | |||||||
							
								
								
									
										57
									
								
								statement.go
									
									
									
									
									
								
							
							
						
						
									
										57
									
								
								statement.go
									
									
									
									
									
								
							| @ -47,15 +47,18 @@ type Statement struct { | |||||||
| 	attrs                []interface{} | 	attrs                []interface{} | ||||||
| 	assigns              []interface{} | 	assigns              []interface{} | ||||||
| 	scopes               []func(*DB) *DB | 	scopes               []func(*DB) *DB | ||||||
|  | 	Result               *result | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| type join struct { | type join struct { | ||||||
| 	Name     string | 	Name       string | ||||||
| 	Conds    []interface{} | 	Alias      string | ||||||
| 	On       *clause.Where | 	Conds      []interface{} | ||||||
| 	Selects  []string | 	On         *clause.Where | ||||||
| 	Omits    []string | 	Selects    []string | ||||||
| 	JoinType clause.JoinType | 	Omits      []string | ||||||
|  | 	Expression clause.Expression | ||||||
|  | 	JoinType   clause.JoinType | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // StatementModifier statement modifier interface
 | // StatementModifier statement modifier interface
 | ||||||
| @ -205,19 +208,21 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) { | |||||||
| 			} else { | 			} else { | ||||||
| 				writer.WriteString("(NULL)") | 				writer.WriteString("(NULL)") | ||||||
| 			} | 			} | ||||||
| 		case *DB: | 		case interface{ getInstance() *DB }: | ||||||
| 			subdb := v.Session(&Session{Logger: logger.Discard, DryRun: true}).getInstance() | 			cv := v.getInstance() | ||||||
| 			if v.Statement.SQL.Len() > 0 { | 
 | ||||||
|  | 			subdb := cv.Session(&Session{Logger: logger.Discard, DryRun: true}).getInstance() | ||||||
|  | 			if cv.Statement.SQL.Len() > 0 { | ||||||
| 				var ( | 				var ( | ||||||
| 					vars = subdb.Statement.Vars | 					vars = subdb.Statement.Vars | ||||||
| 					sql  = v.Statement.SQL.String() | 					sql  = cv.Statement.SQL.String() | ||||||
| 				) | 				) | ||||||
| 
 | 
 | ||||||
| 				subdb.Statement.Vars = make([]interface{}, 0, len(vars)) | 				subdb.Statement.Vars = make([]interface{}, 0, len(vars)) | ||||||
| 				for _, vv := range vars { | 				for _, vv := range vars { | ||||||
| 					subdb.Statement.Vars = append(subdb.Statement.Vars, vv) | 					subdb.Statement.Vars = append(subdb.Statement.Vars, vv) | ||||||
| 					bindvar := strings.Builder{} | 					bindvar := strings.Builder{} | ||||||
| 					v.Dialector.BindVarTo(&bindvar, subdb.Statement, vv) | 					cv.BindVarTo(&bindvar, subdb.Statement, vv) | ||||||
| 					sql = strings.Replace(sql, bindvar.String(), "?", 1) | 					sql = strings.Replace(sql, bindvar.String(), "?", 1) | ||||||
| 				} | 				} | ||||||
| 
 | 
 | ||||||
| @ -321,6 +326,11 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] | |||||||
| 			arg, _ = valuer.Value() | 			arg, _ = valuer.Value() | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
|  | 		curTable := stmt.Table | ||||||
|  | 		if curTable == "" { | ||||||
|  | 			curTable = clause.CurrentTable | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
| 		switch v := arg.(type) { | 		switch v := arg.(type) { | ||||||
| 		case clause.Expression: | 		case clause.Expression: | ||||||
| 			conds = append(conds, v) | 			conds = append(conds, v) | ||||||
| @ -351,7 +361,8 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] | |||||||
| 			sort.Strings(keys) | 			sort.Strings(keys) | ||||||
| 
 | 
 | ||||||
| 			for _, key := range keys { | 			for _, key := range keys { | ||||||
| 				conds = append(conds, clause.Eq{Column: key, Value: v[key]}) | 				column := clause.Column{Name: key, Table: curTable} | ||||||
|  | 				conds = append(conds, clause.Eq{Column: column, Value: v[key]}) | ||||||
| 			} | 			} | ||||||
| 		case map[string]interface{}: | 		case map[string]interface{}: | ||||||
| 			keys := make([]string, 0, len(v)) | 			keys := make([]string, 0, len(v)) | ||||||
| @ -362,12 +373,13 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] | |||||||
| 
 | 
 | ||||||
| 			for _, key := range keys { | 			for _, key := range keys { | ||||||
| 				reflectValue := reflect.Indirect(reflect.ValueOf(v[key])) | 				reflectValue := reflect.Indirect(reflect.ValueOf(v[key])) | ||||||
|  | 				column := clause.Column{Name: key, Table: curTable} | ||||||
| 				switch reflectValue.Kind() { | 				switch reflectValue.Kind() { | ||||||
| 				case reflect.Slice, reflect.Array: | 				case reflect.Slice, reflect.Array: | ||||||
| 					if _, ok := v[key].(driver.Valuer); ok { | 					if _, ok := v[key].(driver.Valuer); ok { | ||||||
| 						conds = append(conds, clause.Eq{Column: key, Value: v[key]}) | 						conds = append(conds, clause.Eq{Column: column, Value: v[key]}) | ||||||
| 					} else if _, ok := v[key].(Valuer); ok { | 					} else if _, ok := v[key].(Valuer); ok { | ||||||
| 						conds = append(conds, clause.Eq{Column: key, Value: v[key]}) | 						conds = append(conds, clause.Eq{Column: column, Value: v[key]}) | ||||||
| 					} else { | 					} else { | ||||||
| 						// optimize reflect value length
 | 						// optimize reflect value length
 | ||||||
| 						valueLen := reflectValue.Len() | 						valueLen := reflectValue.Len() | ||||||
| @ -376,10 +388,10 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] | |||||||
| 							values[i] = reflectValue.Index(i).Interface() | 							values[i] = reflectValue.Index(i).Interface() | ||||||
| 						} | 						} | ||||||
| 
 | 
 | ||||||
| 						conds = append(conds, clause.IN{Column: key, Values: values}) | 						conds = append(conds, clause.IN{Column: column, Values: values}) | ||||||
| 					} | 					} | ||||||
| 				default: | 				default: | ||||||
| 					conds = append(conds, clause.Eq{Column: key, Value: v[key]}) | 					conds = append(conds, clause.Eq{Column: column, Value: v[key]}) | ||||||
| 				} | 				} | ||||||
| 			} | 			} | ||||||
| 		default: | 		default: | ||||||
| @ -406,9 +418,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] | |||||||
| 						if selected || (!restricted && field.Readable) { | 						if selected || (!restricted && field.Readable) { | ||||||
| 							if v, isZero := field.ValueOf(stmt.Context, reflectValue); !isZero || selected { | 							if v, isZero := field.ValueOf(stmt.Context, reflectValue); !isZero || selected { | ||||||
| 								if field.DBName != "" { | 								if field.DBName != "" { | ||||||
| 									conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v}) | 									conds = append(conds, clause.Eq{Column: clause.Column{Table: curTable, Name: field.DBName}, Value: v}) | ||||||
| 								} else if field.DataType != "" { | 								} else if field.DataType != "" { | ||||||
| 									conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v}) | 									conds = append(conds, clause.Eq{Column: clause.Column{Table: curTable, Name: field.Name}, Value: v}) | ||||||
| 								} | 								} | ||||||
| 							} | 							} | ||||||
| 						} | 						} | ||||||
| @ -420,9 +432,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] | |||||||
| 							if selected || (!restricted && field.Readable) { | 							if selected || (!restricted && field.Readable) { | ||||||
| 								if v, isZero := field.ValueOf(stmt.Context, reflectValue.Index(i)); !isZero || selected { | 								if v, isZero := field.ValueOf(stmt.Context, reflectValue.Index(i)); !isZero || selected { | ||||||
| 									if field.DBName != "" { | 									if field.DBName != "" { | ||||||
| 										conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v}) | 										conds = append(conds, clause.Eq{Column: clause.Column{Table: curTable, Name: field.DBName}, Value: v}) | ||||||
| 									} else if field.DataType != "" { | 									} else if field.DataType != "" { | ||||||
| 										conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v}) | 										conds = append(conds, clause.Eq{Column: clause.Column{Table: curTable, Name: field.Name}, Value: v}) | ||||||
| 									} | 									} | ||||||
| 								} | 								} | ||||||
| 							} | 							} | ||||||
| @ -447,14 +459,14 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] | |||||||
| 						} | 						} | ||||||
| 
 | 
 | ||||||
| 						if len(values) > 0 { | 						if len(values) > 0 { | ||||||
| 							conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: values}) | 							conds = append(conds, clause.IN{Column: clause.Column{Table: curTable, Name: clause.PrimaryKey}, Values: values}) | ||||||
| 							return []clause.Expression{clause.And(conds...)} | 							return []clause.Expression{clause.And(conds...)} | ||||||
| 						} | 						} | ||||||
| 						return nil | 						return nil | ||||||
| 					} | 					} | ||||||
| 				} | 				} | ||||||
| 
 | 
 | ||||||
| 				conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: args}) | 				conds = append(conds, clause.IN{Column: clause.Column{Table: curTable, Name: clause.PrimaryKey}, Values: args}) | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| @ -521,6 +533,7 @@ func (stmt *Statement) clone() *Statement { | |||||||
| 		Context:              stmt.Context, | 		Context:              stmt.Context, | ||||||
| 		RaiseErrorOnNotFound: stmt.RaiseErrorOnNotFound, | 		RaiseErrorOnNotFound: stmt.RaiseErrorOnNotFound, | ||||||
| 		SkipHooks:            stmt.SkipHooks, | 		SkipHooks:            stmt.SkipHooks, | ||||||
|  | 		Result:               stmt.Result, | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if stmt.SQL.Len() > 0 { | 	if stmt.SQL.Len() > 0 { | ||||||
|  | |||||||
| @ -18,7 +18,7 @@ services: | |||||||
|       - POSTGRES_USER=gorm |       - POSTGRES_USER=gorm | ||||||
|       - POSTGRES_PASSWORD=gorm |       - POSTGRES_PASSWORD=gorm | ||||||
|   mssql: |   mssql: | ||||||
|     image: '${MSSQL_IMAGE}:2022-latest' |     image: '${MSSQL_IMAGE}:latest' | ||||||
|     ports: |     ports: | ||||||
|       - "127.0.0.1:9930:1433" |       - "127.0.0.1:9930:1433" | ||||||
|     environment: |     environment: | ||||||
|  | |||||||
| @ -119,6 +119,7 @@ func TestConnPoolWrapper(t *testing.T) { | |||||||
| 	}() | 	}() | ||||||
| 
 | 
 | ||||||
| 	db, err := gorm.Open(mysql.New(mysql.Config{Conn: conn, DisableWithReturning: true})) | 	db, err := gorm.Open(mysql.New(mysql.Config{Conn: conn, DisableWithReturning: true})) | ||||||
|  | 	db.Logger = DB.Logger | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatalf("Should open db success, but got %v", err) | 		t.Fatalf("Should open db success, but got %v", err) | ||||||
| 	} | 	} | ||||||
|  | |||||||
							
								
								
									
										875
									
								
								tests/generics_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										875
									
								
								tests/generics_test.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,875 @@ | |||||||
|  | package tests_test | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"errors" | ||||||
|  | 	"fmt" | ||||||
|  | 	"reflect" | ||||||
|  | 	"regexp" | ||||||
|  | 	"sort" | ||||||
|  | 	"strconv" | ||||||
|  | 	"strings" | ||||||
|  | 	"sync" | ||||||
|  | 	"testing" | ||||||
|  | 
 | ||||||
|  | 	"github.com/google/uuid" | ||||||
|  | 	"gorm.io/driver/mysql" | ||||||
|  | 	"gorm.io/gorm" | ||||||
|  | 	"gorm.io/gorm/clause" | ||||||
|  | 	. "gorm.io/gorm/utils/tests" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | func TestGenericsCreate(t *testing.T) { | ||||||
|  | 	ctx := context.Background() | ||||||
|  | 
 | ||||||
|  | 	user := User{Name: "TestGenericsCreate", Age: 18} | ||||||
|  | 	err := gorm.G[User](DB).Create(ctx, &user) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("Create failed: %v", err) | ||||||
|  | 	} | ||||||
|  | 	if user.ID == 0 { | ||||||
|  | 		t.Fatalf("no primary key found for %v", user) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if u, err := gorm.G[User](DB).Where("name = ?", user.Name).First(ctx); err != nil { | ||||||
|  | 		t.Fatalf("failed to find user, got error: %v", err) | ||||||
|  | 	} else if u.Name != user.Name || u.ID != user.ID { | ||||||
|  | 		t.Errorf("found invalid user, got %v, expect %v", u, user) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if u, err := gorm.G[User](DB).Where("name = ?", user.Name).Take(ctx); err != nil { | ||||||
|  | 		t.Fatalf("failed to find user, got error: %v", err) | ||||||
|  | 	} else if u.Name != user.Name || u.ID != user.ID { | ||||||
|  | 		t.Errorf("found invalid user, got %v, expect %v", u, user) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if u, err := gorm.G[User](DB).Select("name").Where("name = ?", user.Name).First(ctx); err != nil { | ||||||
|  | 		t.Fatalf("failed to find user, got error: %v", err) | ||||||
|  | 	} else if u.Name != user.Name || u.Age != 0 { | ||||||
|  | 		t.Errorf("found invalid user, got %v, expect %v", u, user) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if u, err := gorm.G[User](DB).Omit("name").Where("name = ?", user.Name).First(ctx); err != nil { | ||||||
|  | 		t.Fatalf("failed to find user, got error: %v", err) | ||||||
|  | 	} else if u.Name != "" || u.Age != user.Age { | ||||||
|  | 		t.Errorf("found invalid user, got %v, expect %v", u, user) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	result := struct { | ||||||
|  | 		ID   int | ||||||
|  | 		Name string | ||||||
|  | 	}{} | ||||||
|  | 	if err := gorm.G[User](DB).Where("name = ?", user.Name).Scan(ctx, &result); err != nil { | ||||||
|  | 		t.Fatalf("failed to scan user, got error: %v", err) | ||||||
|  | 	} else if result.Name != user.Name || uint(result.ID) != user.ID { | ||||||
|  | 		t.Errorf("found invalid user, got %v, expect %v", result, user) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	mapResult, err := gorm.G[map[string]interface{}](DB).Table("users").Where("name = ?", user.Name).MapColumns(map[string]string{"name": "user_name"}).Take(ctx) | ||||||
|  | 	if v := mapResult["user_name"]; fmt.Sprint(v) != user.Name { | ||||||
|  | 		t.Errorf("failed to find map results, got %v, err %v", mapResult, err) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestGenericsCreateInBatches(t *testing.T) { | ||||||
|  | 	batch := []User{ | ||||||
|  | 		{Name: "GenericsCreateInBatches1"}, | ||||||
|  | 		{Name: "GenericsCreateInBatches2"}, | ||||||
|  | 		{Name: "GenericsCreateInBatches3"}, | ||||||
|  | 	} | ||||||
|  | 	ctx := context.Background() | ||||||
|  | 
 | ||||||
|  | 	if err := gorm.G[User](DB).CreateInBatches(ctx, &batch, 2); err != nil { | ||||||
|  | 		t.Fatalf("CreateInBatches failed: %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for _, u := range batch { | ||||||
|  | 		if u.ID == 0 { | ||||||
|  | 			t.Fatalf("no primary key found for %v", u) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	count, err := gorm.G[User](DB).Where("name like ?", "GenericsCreateInBatches%").Count(ctx, "*") | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("Count failed: %v", err) | ||||||
|  | 	} | ||||||
|  | 	if count != 3 { | ||||||
|  | 		t.Errorf("expected 3 records, got %d", count) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	found, err := gorm.G[User](DB).Raw("SELECT * FROM users WHERE name LIKE ?", "GenericsCreateInBatches%").Find(ctx) | ||||||
|  | 	if len(found) != len(batch) { | ||||||
|  | 		t.Errorf("expected %d from Raw Find, got %d", len(batch), len(found)) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	found, err = gorm.G[User](DB).Where("name like ?", "GenericsCreateInBatches%").Limit(2).Find(ctx) | ||||||
|  | 	if len(found) != 2 { | ||||||
|  | 		t.Errorf("expected %d from Raw Find, got %d", 2, len(found)) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	found, err = gorm.G[User](DB).Where("name like ?", "GenericsCreateInBatches%").Offset(2).Limit(2).Find(ctx) | ||||||
|  | 	if len(found) != 1 { | ||||||
|  | 		t.Errorf("expected %d from Raw Find, got %d", 1, len(found)) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestGenericsExecAndUpdate(t *testing.T) { | ||||||
|  | 	ctx := context.Background() | ||||||
|  | 
 | ||||||
|  | 	name := "GenericsExec" | ||||||
|  | 	if err := gorm.G[User](DB).Exec(ctx, "INSERT INTO users(name) VALUES(?)", name); err != nil { | ||||||
|  | 		t.Fatalf("Exec insert failed: %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	u, err := gorm.G[User](DB).Table("users as u").Where("u.name = ?", name).First(ctx) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("failed to find user, got error: %v", err) | ||||||
|  | 	} else if u.Name != name || u.ID == 0 { | ||||||
|  | 		t.Errorf("found invalid user, got %v", u) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	name += "Update" | ||||||
|  | 	rows, err := gorm.G[User](DB).Where("id = ?", u.ID).Update(ctx, "name", name) | ||||||
|  | 	if rows != 1 { | ||||||
|  | 		t.Fatalf("failed to get affected rows, got %d, should be %d", rows, 1) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	nu, err := gorm.G[User](DB).Where("name = ?", name).First(ctx) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("failed to find user, got error: %v", err) | ||||||
|  | 	} else if nu.Name != name || u.ID != nu.ID { | ||||||
|  | 		t.Fatalf("found invalid user, got %v, expect %v", nu.ID, u.ID) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	rows, err = gorm.G[User](DB).Where("id = ?", u.ID).Updates(ctx, User{Name: "GenericsExecUpdates", Age: 18}) | ||||||
|  | 	if rows != 1 { | ||||||
|  | 		t.Fatalf("failed to get affected rows, got %d, should be %d", rows, 1) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	nu, err = gorm.G[User](DB).Where("id = ?", u.ID).Last(ctx) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("failed to find user, got error: %v", err) | ||||||
|  | 	} else if nu.Name != "GenericsExecUpdates" || nu.Age != 18 || u.ID != nu.ID { | ||||||
|  | 		t.Fatalf("found invalid user, got %v, expect %v", nu.ID, u.ID) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestGenericsRow(t *testing.T) { | ||||||
|  | 	ctx := context.Background() | ||||||
|  | 
 | ||||||
|  | 	user := User{Name: "GenericsRow"} | ||||||
|  | 	if err := gorm.G[User](DB).Create(ctx, &user); err != nil { | ||||||
|  | 		t.Fatalf("Create failed: %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	row := gorm.G[User](DB).Raw("SELECT name FROM users WHERE id = ?", user.ID).Row(ctx) | ||||||
|  | 	var name string | ||||||
|  | 	if err := row.Scan(&name); err != nil { | ||||||
|  | 		t.Fatalf("Row scan failed: %v", err) | ||||||
|  | 	} | ||||||
|  | 	if name != user.Name { | ||||||
|  | 		t.Errorf("expected %s, got %s", user.Name, name) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	user2 := User{Name: "GenericsRow2"} | ||||||
|  | 	if err := gorm.G[User](DB).Create(ctx, &user2); err != nil { | ||||||
|  | 		t.Fatalf("Create failed: %v", err) | ||||||
|  | 	} | ||||||
|  | 	rows, err := gorm.G[User](DB).Raw("SELECT name FROM users WHERE id IN ?", []uint{user.ID, user2.ID}).Rows(ctx) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("Rows failed: %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	count := 0 | ||||||
|  | 	for rows.Next() { | ||||||
|  | 		var name string | ||||||
|  | 		if err := rows.Scan(&name); err != nil { | ||||||
|  | 			t.Fatalf("rows.Scan failed: %v", err) | ||||||
|  | 		} | ||||||
|  | 		count++ | ||||||
|  | 	} | ||||||
|  | 	if count != 2 { | ||||||
|  | 		t.Errorf("expected 2 rows, got %d", count) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestGenericsDelete(t *testing.T) { | ||||||
|  | 	ctx := context.Background() | ||||||
|  | 
 | ||||||
|  | 	u := User{Name: "GenericsDelete"} | ||||||
|  | 	if err := gorm.G[User](DB).Create(ctx, &u); err != nil { | ||||||
|  | 		t.Fatalf("Create failed: %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	rows, err := gorm.G[User](DB).Where("id = ?", u.ID).Delete(ctx) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("Delete failed: %v", err) | ||||||
|  | 	} | ||||||
|  | 	if rows != 1 { | ||||||
|  | 		t.Errorf("expected 1 row deleted, got %d", rows) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	_, err = gorm.G[User](DB).Where("id = ?", u.ID).First(ctx) | ||||||
|  | 	if err != gorm.ErrRecordNotFound { | ||||||
|  | 		t.Fatalf("User after delete failed: %v", err) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestGenericsFindInBatches(t *testing.T) { | ||||||
|  | 	ctx := context.Background() | ||||||
|  | 
 | ||||||
|  | 	users := []User{ | ||||||
|  | 		{Name: "GenericsFindBatchA"}, | ||||||
|  | 		{Name: "GenericsFindBatchB"}, | ||||||
|  | 		{Name: "GenericsFindBatchC"}, | ||||||
|  | 		{Name: "GenericsFindBatchD"}, | ||||||
|  | 		{Name: "GenericsFindBatchE"}, | ||||||
|  | 	} | ||||||
|  | 	if err := gorm.G[User](DB).CreateInBatches(ctx, &users, len(users)); err != nil { | ||||||
|  | 		t.Fatalf("CreateInBatches failed: %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	total := 0 | ||||||
|  | 	err := gorm.G[User](DB).Where("name like ?", "GenericsFindBatch%").FindInBatches(ctx, 2, func(chunk []User, batch int) error { | ||||||
|  | 		if len(chunk) > 2 { | ||||||
|  | 			t.Errorf("batch size exceed 2: got %d", len(chunk)) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		total += len(chunk) | ||||||
|  | 		return nil | ||||||
|  | 	}) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("FindInBatches failed: %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if total != len(users) { | ||||||
|  | 		t.Errorf("expected total %d, got %d", len(users), total) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestGenericsScopes(t *testing.T) { | ||||||
|  | 	ctx := context.Background() | ||||||
|  | 
 | ||||||
|  | 	users := []User{{Name: "GenericsScopes1"}, {Name: "GenericsScopes2"}, {Name: "GenericsScopes3"}} | ||||||
|  | 	err := gorm.G[User](DB).CreateInBatches(ctx, &users, len(users)) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("CreateInBatches failed: %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	filterName1 := func(stmt *gorm.Statement) { | ||||||
|  | 		stmt.Where("name = ?", "GenericsScopes1") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	results, err := gorm.G[User](DB).Scopes(filterName1).Find(ctx) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("Scopes failed: %v", err) | ||||||
|  | 	} | ||||||
|  | 	if len(results) != 1 || results[0].Name != "GenericsScopes1" { | ||||||
|  | 		t.Fatalf("Scopes expected 1, got %d", len(results)) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	notResult, err := gorm.G[User](DB).Where("name like ?", "GenericsScopes%").Not("name = ?", "GenericsScopes1").Order("name").Find(ctx) | ||||||
|  | 	if len(notResult) != 2 { | ||||||
|  | 		t.Fatalf("expected 2 results, got %d", len(notResult)) | ||||||
|  | 	} else if notResult[0].Name != "GenericsScopes2" || notResult[1].Name != "GenericsScopes3" { | ||||||
|  | 		t.Fatalf("expected names 'GenericsScopes2' and 'GenericsScopes3', got %s and %s", notResult[0].Name, notResult[1].Name) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	orResult, err := gorm.G[User](DB).Or("name = ?", "GenericsScopes1").Or("name = ?", "GenericsScopes2").Order("name").Find(ctx) | ||||||
|  | 	if len(orResult) != 2 { | ||||||
|  | 		t.Fatalf("expected 2 results, got %d", len(notResult)) | ||||||
|  | 	} else if orResult[0].Name != "GenericsScopes1" || orResult[1].Name != "GenericsScopes2" { | ||||||
|  | 		t.Fatalf("expected names 'GenericsScopes2' and 'GenericsScopes3', got %s and %s", orResult[0].Name, orResult[1].Name) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestGenericsJoins(t *testing.T) { | ||||||
|  | 	ctx := context.Background() | ||||||
|  | 	db := gorm.G[User](DB) | ||||||
|  | 
 | ||||||
|  | 	u := User{Name: "GenericsJoins", Company: Company{Name: "GenericsCompany"}} | ||||||
|  | 	u2 := User{Name: "GenericsJoins_2", Company: Company{Name: "GenericsCompany_2"}} | ||||||
|  | 	u3 := User{Name: "GenericsJoins_3", Company: Company{Name: "GenericsCompany_3"}} | ||||||
|  | 	db.CreateInBatches(ctx, &[]User{u3, u, u2}, 10) | ||||||
|  | 
 | ||||||
|  | 	// Inner JOIN + WHERE
 | ||||||
|  | 	result, err := db.Joins(clause.Has("Company"), func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error { | ||||||
|  | 		db.Where("?.name = ?", joinTable, u.Company.Name) | ||||||
|  | 		return nil | ||||||
|  | 	}).First(ctx) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("Joins failed: %v", err) | ||||||
|  | 	} | ||||||
|  | 	if result.Name != u.Name || result.Company.Name != u.Company.Name { | ||||||
|  | 		t.Fatalf("Joins expected %s, got %+v", u.Name, result) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Inner JOIN + WHERE with map
 | ||||||
|  | 	result, err = db.Joins(clause.Has("Company"), func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error { | ||||||
|  | 		db.Where(map[string]any{"name": u.Company.Name}) | ||||||
|  | 		return nil | ||||||
|  | 	}).First(ctx) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("Joins failed: %v", err) | ||||||
|  | 	} | ||||||
|  | 	if result.Name != u.Name || result.Company.Name != u.Company.Name { | ||||||
|  | 		t.Fatalf("Joins expected %s, got %+v", u.Name, result) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Left JOIN w/o WHERE
 | ||||||
|  | 	result, err = db.Joins(clause.LeftJoin.Association("Company"), nil).Where(map[string]any{"name": u.Name}).First(ctx) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("Joins failed: %v", err) | ||||||
|  | 	} | ||||||
|  | 	if result.Name != u.Name || result.Company.Name != u.Company.Name { | ||||||
|  | 		t.Fatalf("Joins expected %s, got %+v", u.Name, result) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Left JOIN + Alias WHERE
 | ||||||
|  | 	result, err = db.Joins(clause.LeftJoin.Association("Company").As("t"), func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error { | ||||||
|  | 		if joinTable.Name != "t" { | ||||||
|  | 			t.Fatalf("Join table should be t, but got %v", joinTable.Name) | ||||||
|  | 		} | ||||||
|  | 		db.Where("?.name = ?", joinTable, u.Company.Name) | ||||||
|  | 		return nil | ||||||
|  | 	}).Where(map[string]any{"name": u.Name}).First(ctx) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("Joins failed: %v", err) | ||||||
|  | 	} | ||||||
|  | 	if result.Name != u.Name || result.Company.Name != u.Company.Name { | ||||||
|  | 		t.Fatalf("Joins expected %s, got %+v", u.Name, result) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Raw Subquery JOIN + WHERE
 | ||||||
|  | 	result, err = db.Joins(clause.LeftJoin.AssociationFrom("Company", gorm.G[Company](DB)).As("t"), | ||||||
|  | 		func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error { | ||||||
|  | 			if joinTable.Name != "t" { | ||||||
|  | 				t.Fatalf("Join table should be t, but got %v", joinTable.Name) | ||||||
|  | 			} | ||||||
|  | 			db.Where("?.name = ?", joinTable, u.Company.Name) | ||||||
|  | 			return nil | ||||||
|  | 		}, | ||||||
|  | 	).Where(map[string]any{"name": u2.Name}).First(ctx) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("Raw subquery join failed: %v", err) | ||||||
|  | 	} | ||||||
|  | 	if result.Name != u2.Name || result.Company.Name != u.Company.Name || result.Company.ID == 0 { | ||||||
|  | 		t.Fatalf("Joins expected %s, got %+v", u.Name, result) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Raw Subquery JOIN + WHERE + Select
 | ||||||
|  | 	result, err = db.Joins(clause.LeftJoin.AssociationFrom("Company", gorm.G[Company](DB).Select("Name")).As("t"), | ||||||
|  | 		func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error { | ||||||
|  | 			if joinTable.Name != "t" { | ||||||
|  | 				t.Fatalf("Join table should be t, but got %v", joinTable.Name) | ||||||
|  | 			} | ||||||
|  | 			db.Where("?.name = ?", joinTable, u.Company.Name) | ||||||
|  | 			return nil | ||||||
|  | 		}, | ||||||
|  | 	).Where(map[string]any{"name": u2.Name}).First(ctx) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("Raw subquery join failed: %v", err) | ||||||
|  | 	} | ||||||
|  | 	if result.Name != u2.Name || result.Company.Name != u.Company.Name || result.Company.ID != 0 { | ||||||
|  | 		t.Fatalf("Joins expected %s, got %+v", u.Name, result) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	_, err = db.Joins(clause.Has("Company"), func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error { | ||||||
|  | 		return errors.New("join error") | ||||||
|  | 	}).First(ctx) | ||||||
|  | 	if err == nil { | ||||||
|  | 		t.Fatalf("Joins should got error, but got nil") | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestGenericsNestedJoins(t *testing.T) { | ||||||
|  | 	users := []User{ | ||||||
|  | 		{ | ||||||
|  | 			Name: "generics-nested-joins-1", | ||||||
|  | 			Manager: &User{ | ||||||
|  | 				Name: "generics-nested-joins-manager-1", | ||||||
|  | 				Company: Company{ | ||||||
|  | 					Name: "generics-nested-joins-manager-company-1", | ||||||
|  | 				}, | ||||||
|  | 				NamedPet: &Pet{ | ||||||
|  | 					Name: "generics-nested-joins-manager-namepet-1", | ||||||
|  | 					Toy: Toy{ | ||||||
|  | 						Name: "generics-nested-joins-manager-namepet-toy-1", | ||||||
|  | 					}, | ||||||
|  | 				}, | ||||||
|  | 			}, | ||||||
|  | 			NamedPet: &Pet{Name: "generics-nested-joins-namepet-1", Toy: Toy{Name: "generics-nested-joins-namepet-toy-1"}}, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			Name:     "generics-nested-joins-2", | ||||||
|  | 			Manager:  GetUser("generics-nested-joins-manager-2", Config{Company: true, NamedPet: true}), | ||||||
|  | 			NamedPet: &Pet{Name: "generics-nested-joins-namepet-2", Toy: Toy{Name: "generics-nested-joins-namepet-toy-2"}}, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	ctx := context.Background() | ||||||
|  | 	db := gorm.G[User](DB) | ||||||
|  | 	db.CreateInBatches(ctx, &users, 100) | ||||||
|  | 
 | ||||||
|  | 	var userIDs []uint | ||||||
|  | 	for _, user := range users { | ||||||
|  | 		userIDs = append(userIDs, user.ID) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	users2, err := db.Joins(clause.LeftJoin.Association("Manager"), nil). | ||||||
|  | 		Joins(clause.LeftJoin.Association("Manager.Company"), nil). | ||||||
|  | 		Joins(clause.LeftJoin.Association("Manager.NamedPet.Toy"), nil). | ||||||
|  | 		Joins(clause.LeftJoin.Association("NamedPet.Toy"), nil). | ||||||
|  | 		Joins(clause.LeftJoin.Association("NamedPet").As("t"), nil). | ||||||
|  | 		Where(map[string]any{"id": userIDs}).Find(ctx) | ||||||
|  | 
 | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("Failed to load with joins, got error: %v", err) | ||||||
|  | 	} else if len(users2) != len(users) { | ||||||
|  | 		t.Fatalf("Failed to load join users, got: %v, expect: %v", len(users2), len(users)) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	sort.Slice(users2, func(i, j int) bool { | ||||||
|  | 		return users2[i].ID > users2[j].ID | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	sort.Slice(users, func(i, j int) bool { | ||||||
|  | 		return users[i].ID > users[j].ID | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	for idx, user := range users { | ||||||
|  | 		// user
 | ||||||
|  | 		CheckUser(t, user, users2[idx]) | ||||||
|  | 		if users2[idx].Manager == nil { | ||||||
|  | 			t.Fatalf("Failed to load Manager") | ||||||
|  | 		} | ||||||
|  | 		// manager
 | ||||||
|  | 		CheckUser(t, *user.Manager, *users2[idx].Manager) | ||||||
|  | 		// user pet
 | ||||||
|  | 		if users2[idx].NamedPet == nil { | ||||||
|  | 			t.Fatalf("Failed to load NamedPet") | ||||||
|  | 		} | ||||||
|  | 		CheckPet(t, *user.NamedPet, *users2[idx].NamedPet) | ||||||
|  | 		// manager pet
 | ||||||
|  | 		if users2[idx].Manager.NamedPet == nil { | ||||||
|  | 			t.Fatalf("Failed to load NamedPet") | ||||||
|  | 		} | ||||||
|  | 		CheckPet(t, *user.Manager.NamedPet, *users2[idx].Manager.NamedPet) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestGenericsPreloads(t *testing.T) { | ||||||
|  | 	ctx := context.Background() | ||||||
|  | 	db := gorm.G[User](DB) | ||||||
|  | 
 | ||||||
|  | 	u := *GetUser("GenericsPreloads_1", Config{Company: true, Pets: 3, Friends: 7}) | ||||||
|  | 	u2 := *GetUser("GenericsPreloads_2", Config{Company: true, Pets: 5, Friends: 5}) | ||||||
|  | 	u3 := *GetUser("GenericsPreloads_3", Config{Company: true, Pets: 7, Friends: 3}) | ||||||
|  | 	names := []string{u.Name, u2.Name, u3.Name} | ||||||
|  | 
 | ||||||
|  | 	db.CreateInBatches(ctx, &[]User{u3, u, u2}, 10) | ||||||
|  | 
 | ||||||
|  | 	result, err := db.Preload("Company", nil).Preload("Pets", nil).Where("name = ?", u.Name).First(ctx) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("Preload failed: %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if result.Name != u.Name || result.Company.Name != u.Company.Name || len(result.Pets) != len(u.Pets) { | ||||||
|  | 		t.Fatalf("Preload expected %s, got %+v", u.Name, result) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	results, err := db.Preload("Company", func(db gorm.PreloadBuilder) error { | ||||||
|  | 		db.Where("name = ?", u.Company.Name) | ||||||
|  | 		return nil | ||||||
|  | 	}).Where("name in ?", names).Find(ctx) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("Preload failed: %v", err) | ||||||
|  | 	} | ||||||
|  | 	for _, result := range results { | ||||||
|  | 		if result.Name == u.Name { | ||||||
|  | 			if result.Company.Name != u.Company.Name { | ||||||
|  | 				t.Fatalf("Preload user %v company should be %v, but got %+v", u.Name, u.Company.Name, result.Company.Name) | ||||||
|  | 			} | ||||||
|  | 		} else if result.Company.Name != "" { | ||||||
|  | 			t.Fatalf("Preload other company should not loaded, user %v company expect %v but got %+v", u.Name, u.Company.Name, result.Company.Name) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	_, err = db.Preload("Company", func(db gorm.PreloadBuilder) error { | ||||||
|  | 		return errors.New("preload error") | ||||||
|  | 	}).Where("name in ?", names).Find(ctx) | ||||||
|  | 	if err == nil { | ||||||
|  | 		t.Fatalf("Preload should failed, but got nil") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if DB.Dialector.Name() == "mysql" { | ||||||
|  | 		// mysql 5.7 doesn't support row_number()
 | ||||||
|  | 		if strings.HasPrefix(DB.Dialector.(*mysql.Dialector).ServerVersion, "5.7") { | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	results, err = db.Preload("Pets", func(db gorm.PreloadBuilder) error { | ||||||
|  | 		db.LimitPerRecord(5) | ||||||
|  | 		return nil | ||||||
|  | 	}).Where("name in ?", names).Find(ctx) | ||||||
|  | 
 | ||||||
|  | 	for _, result := range results { | ||||||
|  | 		if result.Name == u.Name { | ||||||
|  | 			if len(result.Pets) != len(u.Pets) { | ||||||
|  | 				t.Fatalf("Preload user %v pets should be %v, but got %+v", u.Name, u.Pets, result.Pets) | ||||||
|  | 			} | ||||||
|  | 		} else if len(result.Pets) != 5 { | ||||||
|  | 			t.Fatalf("Preload user %v pets should be 5, but got %+v", result.Name, result.Pets) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if DB.Dialector.Name() == "sqlserver" { | ||||||
|  | 		// sqlserver doesn't support order by in subquery
 | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	results, err = db.Preload("Pets", func(db gorm.PreloadBuilder) error { | ||||||
|  | 		db.Order("name desc").LimitPerRecord(5) | ||||||
|  | 		return nil | ||||||
|  | 	}).Where("name in ?", names).Find(ctx) | ||||||
|  | 
 | ||||||
|  | 	for _, result := range results { | ||||||
|  | 		if result.Name == u.Name { | ||||||
|  | 			if len(result.Pets) != len(u.Pets) { | ||||||
|  | 				t.Fatalf("Preload user %v pets should be %v, but got %+v", u.Name, u.Pets, result.Pets) | ||||||
|  | 			} | ||||||
|  | 		} else if len(result.Pets) != 5 { | ||||||
|  | 			t.Fatalf("Preload user %v pets should be 5, but got %+v", result.Name, result.Pets) | ||||||
|  | 		} | ||||||
|  | 		for i := 1; i < len(result.Pets); i++ { | ||||||
|  | 			if result.Pets[i-1].Name < result.Pets[i].Name { | ||||||
|  | 				t.Fatalf("Preload user %v pets not ordered correctly, last %v, cur %v", result.Name, result.Pets[i-1], result.Pets[i]) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	results, err = db.Preload("Pets", func(db gorm.PreloadBuilder) error { | ||||||
|  | 		db.Order("name").LimitPerRecord(5) | ||||||
|  | 		return nil | ||||||
|  | 	}).Preload("Friends", func(db gorm.PreloadBuilder) error { | ||||||
|  | 		db.Order("name") | ||||||
|  | 		return nil | ||||||
|  | 	}).Where("name in ?", names).Find(ctx) | ||||||
|  | 
 | ||||||
|  | 	for _, result := range results { | ||||||
|  | 		if result.Name == u.Name { | ||||||
|  | 			if len(result.Pets) != len(u.Pets) { | ||||||
|  | 				t.Fatalf("Preload user %v pets should be %v, but got %+v", u.Name, u.Pets, result.Pets) | ||||||
|  | 			} | ||||||
|  | 			if len(result.Friends) != len(u.Friends) { | ||||||
|  | 				t.Fatalf("Preload user %v pets should be %v, but got %+v", u.Name, u.Pets, result.Pets) | ||||||
|  | 			} | ||||||
|  | 		} else if len(result.Pets) != 5 || len(result.Friends) == 0 { | ||||||
|  | 			t.Fatalf("Preload user %v pets should be 5, but got %+v", result.Name, result.Pets) | ||||||
|  | 		} | ||||||
|  | 		for i := 1; i < len(result.Pets); i++ { | ||||||
|  | 			if result.Pets[i-1].Name > result.Pets[i].Name { | ||||||
|  | 				t.Fatalf("Preload user %v pets not ordered correctly, last %v, cur %v", result.Name, result.Pets[i-1], result.Pets[i]) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 		for i := 1; i < len(result.Pets); i++ { | ||||||
|  | 			if result.Pets[i-1].Name > result.Pets[i].Name { | ||||||
|  | 				t.Fatalf("Preload user %v friends not ordered correctly, last %v, cur %v", result.Name, result.Pets[i-1], result.Pets[i]) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestGenericsNestedPreloads(t *testing.T) { | ||||||
|  | 	user := *GetUser("generics_nested_preload", Config{Pets: 2}) | ||||||
|  | 	user.Friends = []*User{GetUser("generics_nested_preload", Config{Pets: 5})} | ||||||
|  | 
 | ||||||
|  | 	ctx := context.Background() | ||||||
|  | 	db := gorm.G[User](DB) | ||||||
|  | 
 | ||||||
|  | 	for idx, pet := range user.Pets { | ||||||
|  | 		pet.Toy = Toy{Name: "toy_nested_preload_" + strconv.Itoa(idx+1)} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if err := db.Create(ctx, &user); err != nil { | ||||||
|  | 		t.Fatalf("errors happened when create: %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	user2, err := db.Preload("Pets.Toy", nil).Preload("Friends.Pets", func(db gorm.PreloadBuilder) error { | ||||||
|  | 		return nil | ||||||
|  | 	}).Where(user.ID).Take(ctx) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Errorf("failed to nested preload user") | ||||||
|  | 	} | ||||||
|  | 	CheckUser(t, user2, user) | ||||||
|  | 	if len(user.Pets) == 0 || len(user.Friends) == 0 || len(user.Friends[0].Pets) == 0 { | ||||||
|  | 		t.Fatalf("failed to nested preload") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if DB.Dialector.Name() == "mysql" { | ||||||
|  | 		// mysql 5.7 doesn't support row_number()
 | ||||||
|  | 		if strings.HasPrefix(DB.Dialector.(*mysql.Dialector).ServerVersion, "5.7") { | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	if DB.Dialector.Name() == "sqlserver" { | ||||||
|  | 		// sqlserver doesn't support order by in subquery
 | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	user3, err := db.Preload("Pets.Toy", nil).Preload("Friends.Pets", func(db gorm.PreloadBuilder) error { | ||||||
|  | 		db.LimitPerRecord(3) | ||||||
|  | 		return nil | ||||||
|  | 	}).Where(user.ID).Take(ctx) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Errorf("failed to nested preload user") | ||||||
|  | 	} | ||||||
|  | 	CheckUser(t, user3, user) | ||||||
|  | 
 | ||||||
|  | 	if len(user3.Friends) != 1 || len(user3.Friends[0].Pets) != 3 { | ||||||
|  | 		t.Errorf("failed to nested preload with limit per record") | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestGenericsDistinct(t *testing.T) { | ||||||
|  | 	ctx := context.Background() | ||||||
|  | 
 | ||||||
|  | 	batch := []User{ | ||||||
|  | 		{Name: "GenericsDistinctDup"}, | ||||||
|  | 		{Name: "GenericsDistinctDup"}, | ||||||
|  | 		{Name: "GenericsDistinctUnique"}, | ||||||
|  | 	} | ||||||
|  | 	if err := gorm.G[User](DB).CreateInBatches(ctx, &batch, len(batch)); err != nil { | ||||||
|  | 		t.Fatalf("CreateInBatches failed: %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	results, err := gorm.G[User](DB).Where("name like ?", "GenericsDistinct%").Distinct("name").Find(ctx) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("Distinct Find failed: %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if len(results) != 2 { | ||||||
|  | 		t.Errorf("expected 2 distinct names, got %d", len(results)) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	var names []string | ||||||
|  | 	for _, u := range results { | ||||||
|  | 		names = append(names, u.Name) | ||||||
|  | 	} | ||||||
|  | 	sort.Strings(names) | ||||||
|  | 	expected := []string{"GenericsDistinctDup", "GenericsDistinctUnique"} | ||||||
|  | 	if !reflect.DeepEqual(names, expected) { | ||||||
|  | 		t.Errorf("expected names %v, got %v", expected, names) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestGenericsGroupHaving(t *testing.T) { | ||||||
|  | 	ctx := context.Background() | ||||||
|  | 
 | ||||||
|  | 	batch := []User{ | ||||||
|  | 		{Name: "GenericsGroupHavingMulti"}, | ||||||
|  | 		{Name: "GenericsGroupHavingMulti"}, | ||||||
|  | 		{Name: "GenericsGroupHavingSingle"}, | ||||||
|  | 	} | ||||||
|  | 	if err := gorm.G[User](DB).CreateInBatches(ctx, &batch, len(batch)); err != nil { | ||||||
|  | 		t.Fatalf("CreateInBatches failed: %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	grouped, err := gorm.G[User](DB).Select("name").Where("name like ?", "GenericsGroupHaving%").Group("name").Having("COUNT(id) > ?", 1).Find(ctx) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("Group+Having Find failed: %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if len(grouped) != 1 { | ||||||
|  | 		t.Errorf("expected 1 group with count>1, got %d", len(grouped)) | ||||||
|  | 	} else if grouped[0].Name != "GenericsGroupHavingMulti" { | ||||||
|  | 		t.Errorf("expected group name 'GenericsGroupHavingMulti', got '%s'", grouped[0].Name) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestGenericsSubQuery(t *testing.T) { | ||||||
|  | 	ctx := context.Background() | ||||||
|  | 	users := []User{ | ||||||
|  | 		{Name: "GenericsSubquery_1", Age: 10}, | ||||||
|  | 		{Name: "GenericsSubquery_2", Age: 20}, | ||||||
|  | 		{Name: "GenericsSubquery_3", Age: 30}, | ||||||
|  | 		{Name: "GenericsSubquery_4", Age: 40}, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if err := gorm.G[User](DB).CreateInBatches(ctx, &users, len(users)); err != nil { | ||||||
|  | 		t.Fatalf("CreateInBatches failed: %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	results, err := gorm.G[User](DB).Where("name IN (?)", gorm.G[User](DB).Select("name").Where("name LIKE ?", "GenericsSubquery%")).Find(ctx) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("got error: %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if len(results) != 4 { | ||||||
|  | 		t.Errorf("Four users should be found, instead found %d", len(results)) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	results, err = gorm.G[User](DB).Where("name IN (?)", gorm.G[User](DB).Select("name").Where("name IN ?", []string{"GenericsSubquery_1", "GenericsSubquery_2"}).Or("name = ?", "GenericsSubquery_3")).Find(ctx) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("got error: %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if len(results) != 3 { | ||||||
|  | 		t.Errorf("Three users should be found, instead found %d", len(results)) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestGenericsUpsert(t *testing.T) { | ||||||
|  | 	ctx := context.Background() | ||||||
|  | 	lang := Language{Code: "upsert", Name: "Upsert"} | ||||||
|  | 
 | ||||||
|  | 	if err := gorm.G[Language](DB, clause.OnConflict{DoNothing: true}).Create(ctx, &lang); err != nil { | ||||||
|  | 		t.Fatalf("failed to upsert, got %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	lang2 := Language{Code: "upsert", Name: "Upsert"} | ||||||
|  | 	if err := gorm.G[Language](DB, clause.OnConflict{DoNothing: true}).Create(ctx, &lang2); err != nil { | ||||||
|  | 		t.Fatalf("failed to upsert, got %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	langs, err := gorm.G[Language](DB).Where("code = ?", lang.Code).Find(ctx) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Errorf("no error should happen when find languages with code, but got %v", err) | ||||||
|  | 	} else if len(langs) != 1 { | ||||||
|  | 		t.Errorf("should only find only 1 languages, but got %+v", langs) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	lang3 := Language{Code: "upsert", Name: "Upsert"} | ||||||
|  | 	if err := gorm.G[Language](DB, clause.OnConflict{ | ||||||
|  | 		Columns:   []clause.Column{{Name: "code"}}, | ||||||
|  | 		DoUpdates: clause.Assignments(map[string]interface{}{"name": "upsert-new"}), | ||||||
|  | 	}).Create(ctx, &lang3); err != nil { | ||||||
|  | 		t.Fatalf("failed to upsert, got %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if langs, err := gorm.G[Language](DB).Where("code = ?", lang.Code).Find(ctx); err != nil { | ||||||
|  | 		t.Errorf("no error should happen when find languages with code, but got %v", err) | ||||||
|  | 	} else if len(langs) != 1 { | ||||||
|  | 		t.Errorf("should only find only 1 languages, but got %+v", langs) | ||||||
|  | 	} else if langs[0].Name != "upsert-new" { | ||||||
|  | 		t.Errorf("should update name on conflict, but got name %+v", langs[0].Name) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestGenericsWithResult(t *testing.T) { | ||||||
|  | 	ctx := context.Background() | ||||||
|  | 	users := []User{{Name: "TestGenericsWithResult", Age: 18}, {Name: "TestGenericsWithResult2", Age: 18}} | ||||||
|  | 
 | ||||||
|  | 	result := gorm.WithResult() | ||||||
|  | 	err := gorm.G[User](DB, result).CreateInBatches(ctx, &users, 2) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Errorf("failed to create users WithResult") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if result.RowsAffected != 2 { | ||||||
|  | 		t.Errorf("failed to get affected rows, got %d, should be %d", result.RowsAffected, 2) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestGenericsReuse(t *testing.T) { | ||||||
|  | 	ctx := context.Background() | ||||||
|  | 	users := []User{{Name: "TestGenericsReuse1", Age: 18}, {Name: "TestGenericsReuse2", Age: 18}} | ||||||
|  | 
 | ||||||
|  | 	err := gorm.G[User](DB).CreateInBatches(ctx, &users, 2) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Errorf("failed to create users") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	reusedb := gorm.G[User](DB).Where("name like ?", "TestGenericsReuse%") | ||||||
|  | 
 | ||||||
|  | 	sg := sync.WaitGroup{} | ||||||
|  | 	for i := 0; i < 5; i++ { | ||||||
|  | 		sg.Add(1) | ||||||
|  | 
 | ||||||
|  | 		go func() { | ||||||
|  | 			if u1, err := reusedb.Where("id = ?", users[0].ID).First(ctx); err != nil { | ||||||
|  | 				t.Errorf("failed to find user, got error: %v", err) | ||||||
|  | 			} else if u1.Name != users[0].Name || u1.ID != users[0].ID { | ||||||
|  | 				t.Errorf("found invalid user, got %v, expect %v", u1, users[0]) | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			if u2, err := reusedb.Where("id = ?", users[1].ID).First(ctx); err != nil { | ||||||
|  | 				t.Errorf("failed to find user, got error: %v", err) | ||||||
|  | 			} else if u2.Name != users[1].Name || u2.ID != users[1].ID { | ||||||
|  | 				t.Errorf("found invalid user, got %v, expect %v", u2, users[1]) | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			if users, err := reusedb.Where("id IN ?", []uint{users[0].ID, users[1].ID}).Find(ctx); err != nil { | ||||||
|  | 				t.Errorf("failed to find user, got error: %v", err) | ||||||
|  | 			} else if len(users) != 2 { | ||||||
|  | 				t.Errorf("should find 2 users, but got %d", len(users)) | ||||||
|  | 			} | ||||||
|  | 			sg.Done() | ||||||
|  | 		}() | ||||||
|  | 	} | ||||||
|  | 	sg.Wait() | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestGenericsWithTransaction(t *testing.T) { | ||||||
|  | 	ctx := context.Background() | ||||||
|  | 	tx := DB.Begin() | ||||||
|  | 	if tx.Error != nil { | ||||||
|  | 		t.Fatalf("failed to begin transaction: %v", tx.Error) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	users := []User{{Name: "TestGenericsTransaction", Age: 18}, {Name: "TestGenericsTransaction2", Age: 18}} | ||||||
|  | 	err := gorm.G[User](tx).CreateInBatches(ctx, &users, 2) | ||||||
|  | 
 | ||||||
|  | 	count, err := gorm.G[User](tx).Where("name like ?", "TestGenericsTransaction%").Count(ctx, "*") | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("Count failed: %v", err) | ||||||
|  | 	} | ||||||
|  | 	if count != 2 { | ||||||
|  | 		t.Errorf("expected 2 records, got %d", count) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if err := tx.Rollback().Error; err != nil { | ||||||
|  | 		t.Fatalf("failed to rollback transaction: %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	count2, err := gorm.G[User](DB).Where("name like ?", "TestGenericsTransaction%").Count(ctx, "*") | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("Count failed: %v", err) | ||||||
|  | 	} | ||||||
|  | 	if count2 != 0 { | ||||||
|  | 		t.Errorf("expected 0 records after rollback, got %d", count2) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestGenericsToSQL(t *testing.T) { | ||||||
|  | 	ctx := context.Background() | ||||||
|  | 	sql := DB.ToSQL(func(tx *gorm.DB) *gorm.DB { | ||||||
|  | 		gorm.G[User](tx).Limit(10).Find(ctx) | ||||||
|  | 		return tx | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	if !regexp.MustCompile("SELECT \\* FROM .users..* 10").MatchString(sql) { | ||||||
|  | 		t.Errorf("ToSQL: got wrong sql with Generics API %v", sql) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestGenericsScanUUID(t *testing.T) { | ||||||
|  | 	ctx := context.Background() | ||||||
|  | 	users := []User{ | ||||||
|  | 		{Name: uuid.NewString(), Age: 21}, | ||||||
|  | 		{Name: uuid.NewString(), Age: 22}, | ||||||
|  | 		{Name: uuid.NewString(), Age: 23}, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if err := gorm.G[User](DB).CreateInBatches(ctx, &users, 2); err != nil { | ||||||
|  | 		t.Fatalf("CreateInBatches failed: %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	userIds := []uuid.UUID{} | ||||||
|  | 	if err := gorm.G[User](DB).Select("name").Where("id in ?", []uint{users[0].ID, users[1].ID, users[2].ID}).Order("age").Scan(ctx, &userIds); err != nil || len(users) != 3 { | ||||||
|  | 		t.Fatalf("Scan failed: %v, userids %v", err, userIds) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if userIds[0].String() != users[0].Name || userIds[1].String() != users[1].Name || userIds[2].String() != users[2].Name { | ||||||
|  | 		t.Fatalf("wrong uuid scanned") | ||||||
|  | 	} | ||||||
|  | } | ||||||
							
								
								
									
										10
									
								
								tests/go.mod
									
									
									
									
									
								
							
							
						
						
									
										10
									
								
								tests/go.mod
									
									
									
									
									
								
							| @ -2,15 +2,13 @@ module gorm.io/gorm/tests | |||||||
| 
 | 
 | ||||||
| go 1.23.0 | go 1.23.0 | ||||||
| 
 | 
 | ||||||
| toolchain go1.24.2 |  | ||||||
| 
 |  | ||||||
| require ( | require ( | ||||||
| 	github.com/google/uuid v1.6.0 | 	github.com/google/uuid v1.6.0 | ||||||
| 	github.com/jinzhu/now v1.1.5 | 	github.com/jinzhu/now v1.1.5 | ||||||
| 	github.com/lib/pq v1.10.9 | 	github.com/lib/pq v1.10.9 | ||||||
| 	github.com/stretchr/testify v1.10.0 | 	github.com/stretchr/testify v1.10.0 | ||||||
| 	gorm.io/driver/mysql v1.5.7 | 	gorm.io/driver/mysql v1.5.7 | ||||||
| 	gorm.io/driver/postgres v1.5.11 | 	gorm.io/driver/postgres v1.6.0 | ||||||
| 	gorm.io/driver/sqlite v1.5.7 | 	gorm.io/driver/sqlite v1.5.7 | ||||||
| 	gorm.io/driver/sqlserver v1.6.0 | 	gorm.io/driver/sqlserver v1.6.0 | ||||||
| 	gorm.io/gorm v1.30.0 | 	gorm.io/gorm v1.30.0 | ||||||
| @ -25,6 +23,7 @@ require ( | |||||||
| 	github.com/jackc/pgpassfile v1.0.0 // indirect | 	github.com/jackc/pgpassfile v1.0.0 // indirect | ||||||
| 	github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect | 	github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect | ||||||
| 	github.com/jackc/pgx/v5 v5.7.5 // indirect | 	github.com/jackc/pgx/v5 v5.7.5 // indirect | ||||||
|  | 	github.com/jackc/puddle/v2 v2.2.2 // indirect | ||||||
| 	github.com/jinzhu/inflection v1.0.0 // indirect | 	github.com/jinzhu/inflection v1.0.0 // indirect | ||||||
| 	github.com/kr/text v0.2.0 // indirect | 	github.com/kr/text v0.2.0 // indirect | ||||||
| 	github.com/mattn/go-sqlite3 v1.14.28 // indirect | 	github.com/mattn/go-sqlite3 v1.14.28 // indirect | ||||||
| @ -32,12 +31,9 @@ require ( | |||||||
| 	github.com/pmezard/go-difflib v1.0.0 // indirect | 	github.com/pmezard/go-difflib v1.0.0 // indirect | ||||||
| 	github.com/rogpeppe/go-internal v1.12.0 // indirect | 	github.com/rogpeppe/go-internal v1.12.0 // indirect | ||||||
| 	golang.org/x/crypto v0.38.0 // indirect | 	golang.org/x/crypto v0.38.0 // indirect | ||||||
|  | 	golang.org/x/sync v0.14.0 // indirect | ||||||
| 	golang.org/x/text v0.25.0 // indirect | 	golang.org/x/text v0.25.0 // indirect | ||||||
| 	gopkg.in/yaml.v3 v3.0.1 // indirect | 	gopkg.in/yaml.v3 v3.0.1 // indirect | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| replace gorm.io/gorm => ../ | replace gorm.io/gorm => ../ | ||||||
| 
 |  | ||||||
| replace github.com/jackc/pgx/v5 => github.com/jackc/pgx/v5 v5.4.3 |  | ||||||
| 
 |  | ||||||
| replace github.com/microsoft/go-mssqldb => github.com/microsoft/go-mssqldb v1.7.0 |  | ||||||
|  | |||||||
| @ -419,7 +419,7 @@ func TestJoinsPreload_Issue7013(t *testing.T) { | |||||||
| 	var entries []User | 	var entries []User | ||||||
| 	assert.NotPanics(t, func() { | 	assert.NotPanics(t, func() { | ||||||
| 		assert.NoError(t, | 		assert.NoError(t, | ||||||
| 			DB.Debug().Preload("Manager.Team"). | 			DB.Preload("Manager.Team"). | ||||||
| 				Joins("Manager.Company"). | 				Joins("Manager.Company"). | ||||||
| 				Find(&entries).Error) | 				Find(&entries).Error) | ||||||
| 	}) | 	}) | ||||||
| @ -456,7 +456,7 @@ func TestJoinsPreload_Issue7013_RelationEmpty(t *testing.T) { | |||||||
| 	var entries []Building | 	var entries []Building | ||||||
| 	assert.NotPanics(t, func() { | 	assert.NotPanics(t, func() { | ||||||
| 		assert.NoError(t, | 		assert.NoError(t, | ||||||
| 			DB.Debug().Preload("Owner.Furnitures"). | 			DB.Preload("Owner.Furnitures"). | ||||||
| 				Joins("Owner.Company"). | 				Joins("Owner.Company"). | ||||||
| 				Find(&entries).Error) | 				Find(&entries).Error) | ||||||
| 	}) | 	}) | ||||||
| @ -468,7 +468,7 @@ func TestJoinsPreload_Issue7013_NoEntries(t *testing.T) { | |||||||
| 	var entries []User | 	var entries []User | ||||||
| 	assert.NotPanics(t, func() { | 	assert.NotPanics(t, func() { | ||||||
| 		assert.NoError(t, | 		assert.NoError(t, | ||||||
| 			DB.Debug().Preload("Manager.Team"). | 			DB.Preload("Manager.Team"). | ||||||
| 				Joins("Manager.Company"). | 				Joins("Manager.Company"). | ||||||
| 				Where("1 <> 1"). | 				Where("1 <> 1"). | ||||||
| 				Find(&entries).Error) | 				Find(&entries).Error) | ||||||
|  | |||||||
| @ -520,38 +520,6 @@ func TestLRURemoveOldest(t *testing.T) { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func ExampleLRU() { |  | ||||||
| 	// make cache with 10ms TTL and 5 max keys
 |  | ||||||
| 	cache := lru.NewLRU[string, string](5, nil, time.Millisecond*10) |  | ||||||
| 
 |  | ||||||
| 	// set value under key1.
 |  | ||||||
| 	cache.Add("key1", "val1") |  | ||||||
| 
 |  | ||||||
| 	// get value under key1
 |  | ||||||
| 	r, ok := cache.Get("key1") |  | ||||||
| 
 |  | ||||||
| 	// check for OK value
 |  | ||||||
| 	if ok { |  | ||||||
| 		fmt.Printf("value before expiration is found: %v, value: %q\n", ok, r) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	// wait for cache to expire
 |  | ||||||
| 	time.Sleep(time.Millisecond * 100) |  | ||||||
| 
 |  | ||||||
| 	// get value under key1 after key expiration
 |  | ||||||
| 	r, ok = cache.Get("key1") |  | ||||||
| 	fmt.Printf("value after expiration is found: %v, value: %q\n", ok, r) |  | ||||||
| 
 |  | ||||||
| 	// set value under key2, would evict old entry because it is already expired.
 |  | ||||||
| 	cache.Add("key2", "val2") |  | ||||||
| 
 |  | ||||||
| 	fmt.Printf("Cache len: %d\n", cache.Len()) |  | ||||||
| 	// Output:
 |  | ||||||
| 	// value before expiration is found: true, value: "val1"
 |  | ||||||
| 	// value after expiration is found: false, value: ""
 |  | ||||||
| 	// Cache len: 1
 |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func getRand(tb testing.TB) int64 { | func getRand(tb testing.TB) int64 { | ||||||
| 	out, err := rand.Int(rand.Reader, big.NewInt(math.MaxInt64)) | 	out, err := rand.Int(rand.Reader, big.NewInt(math.MaxInt64)) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
|  | |||||||
| @ -5,7 +5,6 @@ import ( | |||||||
| 	"database/sql" | 	"database/sql" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"math/rand" | 	"math/rand" | ||||||
| 	"os" |  | ||||||
| 	"reflect" | 	"reflect" | ||||||
| 	"strconv" | 	"strconv" | ||||||
| 	"strings" | 	"strings" | ||||||
| @ -17,7 +16,6 @@ import ( | |||||||
| 
 | 
 | ||||||
| 	"gorm.io/gorm" | 	"gorm.io/gorm" | ||||||
| 	"gorm.io/gorm/clause" | 	"gorm.io/gorm/clause" | ||||||
| 	"gorm.io/gorm/logger" |  | ||||||
| 	"gorm.io/gorm/migrator" | 	"gorm.io/gorm/migrator" | ||||||
| 	"gorm.io/gorm/schema" | 	"gorm.io/gorm/schema" | ||||||
| 	"gorm.io/gorm/utils" | 	"gorm.io/gorm/utils" | ||||||
| @ -1211,101 +1209,6 @@ func TestInvalidCachedPlanSimpleProtocol(t *testing.T) { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestInvalidCachedPlanPrepareStmt(t *testing.T) { |  | ||||||
| 	if DB.Dialector.Name() != "postgres" { |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	db, err := gorm.Open(postgres.Open(postgresDSN), &gorm.Config{PrepareStmt: true}) |  | ||||||
| 	if err != nil { |  | ||||||
| 		t.Errorf("Open err:%v", err) |  | ||||||
| 	} |  | ||||||
| 	if debug := os.Getenv("DEBUG"); debug == "true" { |  | ||||||
| 		db.Logger = db.Logger.LogMode(logger.Info) |  | ||||||
| 	} else if debug == "false" { |  | ||||||
| 		db.Logger = db.Logger.LogMode(logger.Silent) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	type Object1 struct { |  | ||||||
| 		ID uint |  | ||||||
| 	} |  | ||||||
| 	type Object2 struct { |  | ||||||
| 		ID     uint |  | ||||||
| 		Field1 int `gorm:"type:int8"` |  | ||||||
| 	} |  | ||||||
| 	type Object3 struct { |  | ||||||
| 		ID     uint |  | ||||||
| 		Field1 int `gorm:"type:int4"` |  | ||||||
| 	} |  | ||||||
| 	type Object4 struct { |  | ||||||
| 		ID     uint |  | ||||||
| 		Field2 int |  | ||||||
| 	} |  | ||||||
| 	db.Migrator().DropTable("objects") |  | ||||||
| 
 |  | ||||||
| 	err = db.Table("objects").AutoMigrate(&Object1{}) |  | ||||||
| 	if err != nil { |  | ||||||
| 		t.Errorf("AutoMigrate err:%v", err) |  | ||||||
| 	} |  | ||||||
| 	err = db.Table("objects").Create(&Object1{}).Error |  | ||||||
| 	if err != nil { |  | ||||||
| 		t.Errorf("create err:%v", err) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	// AddColumn
 |  | ||||||
| 	err = db.Table("objects").AutoMigrate(&Object2{}) |  | ||||||
| 	if err != nil { |  | ||||||
| 		t.Errorf("AutoMigrate err:%v", err) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	err = db.Table("objects").Take(&Object2{}).Error |  | ||||||
| 	if err != nil { |  | ||||||
| 		t.Errorf("take err:%v", err) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	// AlterColumn
 |  | ||||||
| 	err = db.Table("objects").AutoMigrate(&Object3{}) |  | ||||||
| 	if err != nil { |  | ||||||
| 		t.Errorf("AutoMigrate err:%v", err) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	err = db.Table("objects").Take(&Object3{}).Error |  | ||||||
| 	if err != nil { |  | ||||||
| 		t.Errorf("take err:%v", err) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	// AddColumn
 |  | ||||||
| 	err = db.Table("objects").AutoMigrate(&Object4{}) |  | ||||||
| 	if err != nil { |  | ||||||
| 		t.Errorf("AutoMigrate err:%v", err) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	err = db.Table("objects").Take(&Object4{}).Error |  | ||||||
| 	if err != nil { |  | ||||||
| 		t.Errorf("take err:%v", err) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	db.Table("objects").Migrator().RenameColumn(&Object4{}, "field2", "field3") |  | ||||||
| 	if err != nil { |  | ||||||
| 		t.Errorf("RenameColumn err:%v", err) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	err = db.Table("objects").Take(&Object4{}).Error |  | ||||||
| 	if err != nil { |  | ||||||
| 		t.Errorf("take err:%v", err) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	db.Table("objects").Migrator().DropColumn(&Object4{}, "field3") |  | ||||||
| 	if err != nil { |  | ||||||
| 		t.Errorf("RenameColumn err:%v", err) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	err = db.Table("objects").Take(&Object4{}).Error |  | ||||||
| 	if err != nil { |  | ||||||
| 		t.Errorf("take err:%v", err) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func TestDifferentTypeWithoutDeclaredLength(t *testing.T) { | func TestDifferentTypeWithoutDeclaredLength(t *testing.T) { | ||||||
| 	type DiffType struct { | 	type DiffType struct { | ||||||
| 		ID   uint | 		ID   uint | ||||||
|  | |||||||
| @ -4,6 +4,7 @@ import ( | |||||||
| 	"context" | 	"context" | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"testing" | 	"testing" | ||||||
|  | 	"time" | ||||||
| 
 | 
 | ||||||
| 	"gorm.io/gorm" | 	"gorm.io/gorm" | ||||||
| 	. "gorm.io/gorm/utils/tests" | 	. "gorm.io/gorm/utils/tests" | ||||||
| @ -459,7 +460,6 @@ func TestTransactionWithHooks(t *testing.T) { | |||||||
| 			return tx2.Scan(&User{}).Error | 			return tx2.Scan(&User{}).Error | ||||||
| 		}) | 		}) | ||||||
| 	}) | 	}) | ||||||
| 
 |  | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Error(err) | 		t.Error(err) | ||||||
| 	} | 	} | ||||||
| @ -473,8 +473,20 @@ func TestTransactionWithHooks(t *testing.T) { | |||||||
| 			return tx3.Where("user_id", user.ID).Delete(&Account{}).Error | 			return tx3.Where("user_id", user.ID).Delete(&Account{}).Error | ||||||
| 		}) | 		}) | ||||||
| 	}) | 	}) | ||||||
| 
 |  | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Error(err) | 		t.Error(err) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | func TestTransactionWithDefaultTimeout(t *testing.T) { | ||||||
|  | 	db, err := OpenTestConnection(&gorm.Config{DefaultTransactionTimeout: 2 * time.Second}) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("failed to connect database, got error %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	tx := db.Begin() | ||||||
|  | 	time.Sleep(3 * time.Second) | ||||||
|  | 	if err = tx.Find(&User{}).Error; err == nil { | ||||||
|  | 		t.Errorf("should return error when transaction timeout, got error %v", err) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Chise1
						Chise1