Merge branch 'master' into test_migrate_column

This commit is contained in:
Jinzhu 2022-11-03 21:18:57 +08:00 committed by GitHub
commit 04dc915cb7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
31 changed files with 716 additions and 125 deletions

View File

@ -30,7 +30,7 @@ The fantastic ORM library for Golang, aims to be developer friendly.
## Getting Started ## Getting Started
* GORM Guides [https://gorm.io](https://gorm.io) * GORM Guides [https://gorm.io](https://gorm.io)
* GORM Gen [gorm/gen](https://github.com/go-gorm/gen#gormgen) * Gen Guides [https://gorm.io/gen/index.html](https://gorm.io/gen/index.html)
## Contributing ## Contributing

View File

@ -208,7 +208,10 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
cacheKey := utils.ToStringKey(relPrimaryValues...) cacheKey := utils.ToStringKey(relPrimaryValues...)
if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] { if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] {
if cacheKey != "" { // has primary fields
identityMap[cacheKey] = true identityMap[cacheKey] = true
}
if isPtr { if isPtr {
elems = reflect.Append(elems, elem) elems = reflect.Append(elems, elem)
} else { } else {
@ -294,7 +297,10 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
cacheKey := utils.ToStringKey(relPrimaryValues...) cacheKey := utils.ToStringKey(relPrimaryValues...)
if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] { if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] {
if cacheKey != "" { // has primary fields
identityMap[cacheKey] = true identityMap[cacheKey] = true
}
distinctElems = reflect.Append(distinctElems, elem) distinctElems = reflect.Append(distinctElems, elem)
} }

View File

@ -117,13 +117,21 @@ func BuildQuerySQL(db *gorm.DB) {
} else if relation, ok := db.Statement.Schema.Relationships.Relations[join.Name]; ok { } else if relation, ok := db.Statement.Schema.Relationships.Relations[join.Name]; ok {
tableAliasName := relation.Name tableAliasName := relation.Name
columnStmt := gorm.Statement{
Table: tableAliasName, DB: db, Schema: relation.FieldSchema,
Selects: join.Selects, Omits: join.Omits,
}
selectColumns, restricted := columnStmt.SelectAndOmitColumns(false, false)
for _, s := range relation.FieldSchema.DBNames { for _, s := range relation.FieldSchema.DBNames {
if v, ok := selectColumns[s]; (ok && v) || (!ok && !restricted) {
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{
Table: tableAliasName, Table: tableAliasName,
Name: s, Name: s,
Alias: tableAliasName + "__" + s, Alias: tableAliasName + "__" + s,
}) })
} }
}
exprs := make([]clause.Expression, len(relation.References)) exprs := make([]clause.Expression, len(relation.References))
for idx, ref := range relation.References { for idx, ref := range relation.References {

View File

@ -10,6 +10,7 @@ import (
) )
// Model specify the model you would like to run db operations // Model specify the model you would like to run db operations
//
// // update all users's name to `hello` // // update all users's name to `hello`
// db.Model(&User{}).Update("name", "hello") // db.Model(&User{}).Update("name", "hello")
// // if user's primary key is non-blank, will use it as condition, then will only update the user's name to `hello` // // if user's primary key is non-blank, will use it as condition, then will only update the user's name to `hello`
@ -179,6 +180,7 @@ func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) {
} }
// Joins specify Joins conditions // Joins specify Joins conditions
//
// db.Joins("Account").Find(&user) // db.Joins("Account").Find(&user)
// db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user) // db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user)
// db.Joins("Account", DB.Select("id").Where("user_id = users.id AND name = ?", "someName").Model(&Account{})) // db.Joins("Account", DB.Select("id").Where("user_id = users.id AND name = ?", "someName").Model(&Account{}))
@ -187,10 +189,12 @@ func (db *DB) Joins(query string, args ...interface{}) (tx *DB) {
if len(args) == 1 { if len(args) == 1 {
if db, ok := args[0].(*DB); ok { if db, ok := args[0].(*DB); ok {
j := join{Name: query, Conds: args, Selects: db.Statement.Selects, Omits: db.Statement.Omits}
if where, ok := db.Statement.Clauses["WHERE"].Expression.(clause.Where); ok { if where, ok := db.Statement.Clauses["WHERE"].Expression.(clause.Where); ok {
tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args, On: &where}) j.On = &where
return
} }
tx.Statement.Joins = append(tx.Statement.Joins, j)
return
} }
} }
@ -219,6 +223,7 @@ func (db *DB) Having(query interface{}, args ...interface{}) (tx *DB) {
} }
// Order specify order when retrieve records from database // Order specify order when retrieve records from database
//
// db.Order("name DESC") // db.Order("name DESC")
// db.Order(clause.OrderByColumn{Column: clause.Column{Name: "name"}, Desc: true}) // db.Order(clause.OrderByColumn{Column: clause.Column{Name: "name"}, Desc: true})
func (db *DB) Order(value interface{}) (tx *DB) { func (db *DB) Order(value interface{}) (tx *DB) {
@ -244,7 +249,7 @@ func (db *DB) Order(value interface{}) (tx *DB) {
// Limit specify the number of records to be retrieved // Limit specify the number of records to be retrieved
func (db *DB) Limit(limit int) (tx *DB) { func (db *DB) Limit(limit int) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.AddClause(clause.Limit{Limit: limit}) tx.Statement.AddClause(clause.Limit{Limit: &limit})
return return
} }
@ -256,6 +261,7 @@ func (db *DB) Offset(offset int) (tx *DB) {
} }
// Scopes pass current database connection to arguments `func(DB) DB`, which could be used to add conditions dynamically // Scopes pass current database connection to arguments `func(DB) DB`, which could be used to add conditions dynamically
//
// func AmountGreaterThan1000(db *gorm.DB) *gorm.DB { // func AmountGreaterThan1000(db *gorm.DB) *gorm.DB {
// return db.Where("amount > ?", 1000) // return db.Where("amount > ?", 1000)
// } // }
@ -274,6 +280,7 @@ func (db *DB) Scopes(funcs ...func(*DB) *DB) (tx *DB) {
} }
// Preload preload associations with given conditions // Preload preload associations with given conditions
//
// db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users) // db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users)
func (db *DB) Preload(query string, args ...interface{}) (tx *DB) { func (db *DB) Preload(query string, args ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()

View File

@ -29,6 +29,7 @@ func BenchmarkSelect(b *testing.B) {
func BenchmarkComplexSelect(b *testing.B) { func BenchmarkComplexSelect(b *testing.B) {
user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
limit10 := 10
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
stmt := gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}} stmt := gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}}
clauses := []clause.Interface{ clauses := []clause.Interface{
@ -43,7 +44,7 @@ func BenchmarkComplexSelect(b *testing.B) {
clause.Or(clause.Gt{Column: "score", Value: 100}, clause.Like{Column: "name", Value: "%linus%"}), clause.Or(clause.Gt{Column: "score", Value: 100}, clause.Like{Column: "name", Value: "%linus%"}),
}}, }},
clause.GroupBy{Columns: []clause.Column{{Name: "role"}}, Having: []clause.Expression{clause.Eq{"role", "admin"}}}, clause.GroupBy{Columns: []clause.Column{{Name: "role"}}, Having: []clause.Expression{clause.Eq{"role", "admin"}}},
clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Limit: &limit10, Offset: 20},
clause.OrderBy{Columns: []clause.OrderByColumn{{Column: clause.PrimaryColumn, Desc: true}}}, clause.OrderBy{Columns: []clause.OrderByColumn{{Column: clause.PrimaryColumn, Desc: true}}},
} }

View File

@ -9,7 +9,7 @@ const (
RightJoin JoinType = "RIGHT" RightJoin JoinType = "RIGHT"
) )
// Join join clause for from // Join clause for from
type Join struct { type Join struct {
Type JoinType Type JoinType
Table Table Table Table

101
clause/joins_test.go Normal file
View File

@ -0,0 +1,101 @@
package clause_test
import (
"sync"
"testing"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils/tests"
)
func TestJoin(t *testing.T) {
results := []struct {
name string
join clause.Join
sql string
}{
{
name: "LEFT JOIN",
join: clause.Join{
Type: clause.LeftJoin,
Table: clause.Table{Name: "user"},
ON: clause.Where{
Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}},
},
},
sql: "LEFT JOIN `user` ON `user_info`.`user_id` = `users`.`id`",
},
{
name: "RIGHT JOIN",
join: clause.Join{
Type: clause.RightJoin,
Table: clause.Table{Name: "user"},
ON: clause.Where{
Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}},
},
},
sql: "RIGHT JOIN `user` ON `user_info`.`user_id` = `users`.`id`",
},
{
name: "INNER JOIN",
join: clause.Join{
Type: clause.InnerJoin,
Table: clause.Table{Name: "user"},
ON: clause.Where{
Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}},
},
},
sql: "INNER JOIN `user` ON `user_info`.`user_id` = `users`.`id`",
},
{
name: "CROSS JOIN",
join: clause.Join{
Type: clause.CrossJoin,
Table: clause.Table{Name: "user"},
ON: clause.Where{
Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}},
},
},
sql: "CROSS JOIN `user` ON `user_info`.`user_id` = `users`.`id`",
},
{
name: "USING",
join: clause.Join{
Type: clause.InnerJoin,
Table: clause.Table{Name: "user"},
Using: []string{"id"},
},
sql: "INNER JOIN `user` USING (`id`)",
},
{
name: "Expression",
join: clause.Join{
// Invalid
Type: clause.LeftJoin,
Table: clause.Table{Name: "user"},
ON: clause.Where{
Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}},
},
// Valid
Expression: clause.Join{
Type: clause.InnerJoin,
Table: clause.Table{Name: "user"},
Using: []string{"id"},
},
},
sql: "INNER JOIN `user` USING (`id`)",
},
}
for _, result := range results {
t.Run(result.name, func(t *testing.T) {
user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
stmt := &gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}}
result.join.Build(stmt)
if result.sql != stmt.SQL.String() {
t.Errorf("want: %s, got: %s", result.sql, stmt.SQL.String())
}
})
}
}

View File

@ -4,7 +4,7 @@ import "strconv"
// Limit limit clause // Limit limit clause
type Limit struct { type Limit struct {
Limit int Limit *int
Offset int Offset int
} }
@ -15,12 +15,12 @@ func (limit Limit) Name() string {
// Build build where clause // Build build where clause
func (limit Limit) Build(builder Builder) { func (limit Limit) Build(builder Builder) {
if limit.Limit > 0 { if limit.Limit != nil && *limit.Limit >= 0 {
builder.WriteString("LIMIT ") builder.WriteString("LIMIT ")
builder.WriteString(strconv.Itoa(limit.Limit)) builder.WriteString(strconv.Itoa(*limit.Limit))
} }
if limit.Offset > 0 { if limit.Offset > 0 {
if limit.Limit > 0 { if limit.Limit != nil && *limit.Limit >= 0 {
builder.WriteByte(' ') builder.WriteByte(' ')
} }
builder.WriteString("OFFSET ") builder.WriteString("OFFSET ")
@ -33,7 +33,7 @@ func (limit Limit) MergeClause(clause *Clause) {
clause.Name = "" clause.Name = ""
if v, ok := clause.Expression.(Limit); ok { if v, ok := clause.Expression.(Limit); ok {
if limit.Limit == 0 && v.Limit != 0 { if (limit.Limit == nil || *limit.Limit == 0) && (v.Limit != nil && *v.Limit != 0) {
limit.Limit = v.Limit limit.Limit = v.Limit
} }

View File

@ -8,6 +8,10 @@ import (
) )
func TestLimit(t *testing.T) { func TestLimit(t *testing.T) {
limit0 := 0
limit10 := 10
limit50 := 50
limitNeg10 := -10
results := []struct { results := []struct {
Clauses []clause.Interface Clauses []clause.Interface
Result string Result string
@ -15,11 +19,15 @@ func TestLimit(t *testing.T) {
}{ }{
{ {
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{ []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{
Limit: 10, Limit: &limit10,
Offset: 20, Offset: 20,
}}, }},
"SELECT * FROM `users` LIMIT 10 OFFSET 20", nil, "SELECT * FROM `users` LIMIT 10 OFFSET 20", nil,
}, },
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit0}},
"SELECT * FROM `users` LIMIT 0", nil,
},
{ {
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}}, []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}},
"SELECT * FROM `users` OFFSET 20", nil, "SELECT * FROM `users` OFFSET 20", nil,
@ -29,23 +37,23 @@ func TestLimit(t *testing.T) {
"SELECT * FROM `users` OFFSET 30", nil, "SELECT * FROM `users` OFFSET 30", nil,
}, },
{ {
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}, clause.Limit{Limit: 10}}, []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}, clause.Limit{Limit: &limit10}},
"SELECT * FROM `users` LIMIT 10 OFFSET 20", nil, "SELECT * FROM `users` LIMIT 10 OFFSET 20", nil,
}, },
{ {
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}}, []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}},
"SELECT * FROM `users` LIMIT 10 OFFSET 30", nil, "SELECT * FROM `users` LIMIT 10 OFFSET 30", nil,
}, },
{ {
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Offset: -10}}, []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Offset: -10}},
"SELECT * FROM `users` LIMIT 10", nil, "SELECT * FROM `users` LIMIT 10", nil,
}, },
{ {
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: -10}}, []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: &limitNeg10}},
"SELECT * FROM `users` OFFSET 30", nil, "SELECT * FROM `users` OFFSET 30", nil,
}, },
{ {
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: 50}}, []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: &limit50}},
"SELECT * FROM `users` LIMIT 50 OFFSET 30", nil, "SELECT * FROM `users` LIMIT 50 OFFSET 30", nil,
}, },
} }

View File

@ -16,6 +16,11 @@ func (OnConflict) Name() string {
// Build build onConflict clause // Build build onConflict clause
func (onConflict OnConflict) Build(builder Builder) { func (onConflict OnConflict) Build(builder Builder) {
if onConflict.OnConstraint != "" {
builder.WriteString("ON CONSTRAINT ")
builder.WriteString(onConflict.OnConstraint)
builder.WriteByte(' ')
} else {
if len(onConflict.Columns) > 0 { if len(onConflict.Columns) > 0 {
builder.WriteByte('(') builder.WriteByte('(')
for idx, column := range onConflict.Columns { for idx, column := range onConflict.Columns {
@ -32,11 +37,6 @@ func (onConflict OnConflict) Build(builder Builder) {
onConflict.TargetWhere.Build(builder) onConflict.TargetWhere.Build(builder)
builder.WriteByte(' ') builder.WriteByte(' ')
} }
if onConflict.OnConstraint != "" {
builder.WriteString("ON CONSTRAINT ")
builder.WriteString(onConflict.OnConstraint)
builder.WriteByte(' ')
} }
if onConflict.DoNothing { if onConflict.DoNothing {

View File

@ -185,7 +185,9 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat
var totalSize int var totalSize int
if c, ok := tx.Statement.Clauses["LIMIT"]; ok { if c, ok := tx.Statement.Clauses["LIMIT"]; ok {
if limit, ok := c.Expression.(clause.Limit); ok { if limit, ok := c.Expression.(clause.Limit); ok {
totalSize = limit.Limit if limit.Limit != nil {
totalSize = *limit.Limit
}
if totalSize > 0 && batchSize > totalSize { if totalSize > 0 && batchSize > totalSize {
batchSize = totalSize batchSize = totalSize

View File

@ -406,17 +406,14 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
fullDataType := strings.TrimSpace(strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL)) fullDataType := strings.TrimSpace(strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL))
realDataType := strings.ToLower(columnType.DatabaseTypeName()) realDataType := strings.ToLower(columnType.DatabaseTypeName())
alterColumn := false var (
alterColumn, isSameType bool
)
if !field.PrimaryKey { if !field.PrimaryKey {
// check type // check type
var isSameType bool if !strings.HasPrefix(fullDataType, realDataType) {
if strings.HasPrefix(fullDataType, realDataType) {
isSameType = true
}
// check type aliases // check type aliases
if !isSameType {
aliases := m.DB.Migrator().GetTypeAliases(realDataType) aliases := m.DB.Migrator().GetTypeAliases(realDataType)
for _, alias := range aliases { for _, alias := range aliases {
if strings.HasPrefix(fullDataType, alias) { if strings.HasPrefix(fullDataType, alias) {
@ -424,13 +421,14 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
break break
} }
} }
}
if !isSameType { if !isSameType {
alterColumn = true alterColumn = true
} }
} }
}
if !isSameType {
// check size // check size
if length, ok := columnType.Length(); length != int64(field.Size) { if length, ok := columnType.Length(); length != int64(field.Size) {
if length > 0 && field.Size > 0 { if length > 0 && field.Size > 0 {
@ -452,6 +450,7 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
alterColumn = true alterColumn = true
} }
} }
}
// check nullable // check nullable
if nullable, ok := columnType.Nullable(); ok && nullable == field.NotNull { if nullable, ok := columnType.Nullable(); ok && nullable == field.NotNull {

View File

@ -44,6 +44,18 @@ func (db *PreparedStmtDB) Close() {
} }
} }
func (db *PreparedStmtDB) Reset() {
db.Mux.Lock()
defer db.Mux.Unlock()
for query, stmt := range db.Stmts {
delete(db.Stmts, query)
go stmt.Close()
}
db.PreparedSQL = make([]string, 0, 100)
db.Stmts = map[string](*Stmt){}
}
func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) { func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) {
db.Mux.RLock() db.Mux.RLock()
if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) { if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) {

View File

@ -403,33 +403,30 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu
case guessBelongs: case guessBelongs:
primarySchema, foreignSchema = relation.FieldSchema, schema primarySchema, foreignSchema = relation.FieldSchema, schema
case guessEmbeddedBelongs: case guessEmbeddedBelongs:
if field.OwnerSchema != nil { if field.OwnerSchema == nil {
primarySchema, foreignSchema = relation.FieldSchema, field.OwnerSchema
} else {
reguessOrErr() reguessOrErr()
return return
} }
primarySchema, foreignSchema = relation.FieldSchema, field.OwnerSchema
case guessHas: case guessHas:
case guessEmbeddedHas: case guessEmbeddedHas:
if field.OwnerSchema != nil { if field.OwnerSchema == nil {
primarySchema, foreignSchema = field.OwnerSchema, relation.FieldSchema
} else {
reguessOrErr() reguessOrErr()
return return
} }
primarySchema, foreignSchema = field.OwnerSchema, relation.FieldSchema
} }
if len(relation.foreignKeys) > 0 { if len(relation.foreignKeys) > 0 {
for _, foreignKey := range relation.foreignKeys { for _, foreignKey := range relation.foreignKeys {
if f := foreignSchema.LookUpField(foreignKey); f != nil { f := foreignSchema.LookUpField(foreignKey)
foreignFields = append(foreignFields, f) if f == nil {
} else {
reguessOrErr() reguessOrErr()
return return
} }
foreignFields = append(foreignFields, f)
} }
} else { } else {
var primaryFields []*Field
var primarySchemaName = primarySchema.Name var primarySchemaName = primarySchema.Name
if primarySchemaName == "" { if primarySchemaName == "" {
primarySchemaName = relation.FieldSchema.Name primarySchemaName = relation.FieldSchema.Name
@ -466,10 +463,11 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu
} }
} }
if len(foreignFields) == 0 { switch {
case len(foreignFields) == 0:
reguessOrErr() reguessOrErr()
return return
} else if len(relation.primaryKeys) > 0 { case len(relation.primaryKeys) > 0:
for idx, primaryKey := range relation.primaryKeys { for idx, primaryKey := range relation.primaryKeys {
if f := primarySchema.LookUpField(primaryKey); f != nil { if f := primarySchema.LookUpField(primaryKey); f != nil {
if len(primaryFields) < idx+1 { if len(primaryFields) < idx+1 {
@ -483,7 +481,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu
return return
} }
} }
} else if len(primaryFields) == 0 { case len(primaryFields) == 0:
if len(foreignFields) == 1 && primarySchema.PrioritizedPrimaryField != nil { if len(foreignFields) == 1 && primarySchema.PrioritizedPrimaryField != nil {
primaryFields = append(primaryFields, primarySchema.PrioritizedPrimaryField) primaryFields = append(primaryFields, primarySchema.PrioritizedPrimaryField)
} else if len(primarySchema.PrimaryFields) == len(foreignFields) { } else if len(primarySchema.PrimaryFields) == len(foreignFields) {

View File

@ -71,6 +71,10 @@ type Tabler interface {
TableName() string TableName() string
} }
type TablerWithNamer interface {
TableName(Namer) string
}
// Parse get data type from dialector // Parse get data type from dialector
func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) { func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) {
return ParseWithSpecialTableName(dest, cacheStore, namer, "") return ParseWithSpecialTableName(dest, cacheStore, namer, "")
@ -125,6 +129,9 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
if tabler, ok := modelValue.Interface().(Tabler); ok { if tabler, ok := modelValue.Interface().(Tabler); ok {
tableName = tabler.TableName() tableName = tabler.TableName()
} }
if tabler, ok := modelValue.Interface().(TablerWithNamer); ok {
tableName = tabler.TableName(namer)
}
if en, ok := namer.(embeddedNamer); ok { if en, ok := namer.(embeddedNamer); ok {
tableName = en.Table tableName = en.Table
} }

View File

@ -100,6 +100,12 @@ func (JSONSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value,
// Value implements serializer interface // Value implements serializer interface
func (JSONSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) { func (JSONSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) {
result, err := json.Marshal(fieldValue) result, err := json.Marshal(fieldValue)
if string(result) == "null" {
if field.TagSettings["NOT NULL"] != "" {
return "", nil
}
return nil, err
}
return string(result), err return string(result), err
} }

View File

@ -52,6 +52,8 @@ type join struct {
Name string Name string
Conds []interface{} Conds []interface{}
On *clause.Where On *clause.Where
Selects []string
Omits []string
} }
// StatementModifier statement modifier interface // StatementModifier statement modifier interface
@ -179,6 +181,10 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) {
} else { } else {
stmt.AddVar(writer, v.GormValue(stmt.Context, stmt.DB)) stmt.AddVar(writer, v.GormValue(stmt.Context, stmt.DB))
} }
case clause.Interface:
c := clause.Clause{Name: v.Name()}
v.MergeClause(&c)
c.Build(stmt)
case clause.Expression: case clause.Expression:
v.Build(stmt) v.Build(stmt)
case driver.Valuer: case driver.Valuer:
@ -540,6 +546,7 @@ func (stmt *Statement) clone() *Statement {
} }
// SetColumn set column's value // SetColumn set column's value
//
// stmt.SetColumn("Name", "jinzhu") // Hooks Method // stmt.SetColumn("Name", "jinzhu") // Hooks Method
// stmt.SetColumn("Name", "jinzhu", true) // Callbacks Method // stmt.SetColumn("Name", "jinzhu", true) // Callbacks Method
func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks ...bool) { func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks ...bool) {

View File

@ -348,3 +348,45 @@ func TestAssociationEmptyQueryClause(t *testing.T) {
AssertEqual(t, len(orgs), 0) AssertEqual(t, len(orgs), 0)
} }
} }
type AssociationEmptyUser struct {
ID uint
Name string
Pets []AssociationEmptyPet
}
type AssociationEmptyPet struct {
AssociationEmptyUserID *uint `gorm:"uniqueIndex:uniq_user_id_name"`
Name string `gorm:"uniqueIndex:uniq_user_id_name;size:256"`
}
func TestAssociationEmptyPrimaryKey(t *testing.T) {
if DB.Dialector.Name() != "mysql" {
t.Skip()
}
DB.Migrator().DropTable(&AssociationEmptyUser{}, &AssociationEmptyPet{})
DB.AutoMigrate(&AssociationEmptyUser{}, &AssociationEmptyPet{})
id := uint(100)
user := AssociationEmptyUser{
ID: id,
Name: "jinzhu",
Pets: []AssociationEmptyPet{
{AssociationEmptyUserID: &id, Name: "bar"},
{AssociationEmptyUserID: &id, Name: "foo"},
},
}
err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Create(&user).Error
if err != nil {
t.Fatalf("Failed to create, got error: %v", err)
}
var result AssociationEmptyUser
err = DB.Preload("Pets").First(&result, &id).Error
if err != nil {
t.Fatalf("Failed to find, got error: %v", err)
}
AssertEqual(t, result, user)
}

View File

@ -116,7 +116,7 @@ func TestConnPoolWrapper(t *testing.T) {
} }
}() }()
db, err := gorm.Open(mysql.New(mysql.Config{Conn: conn})) db, err := gorm.Open(mysql.New(mysql.Config{Conn: conn, DisableWithReturning: true}))
if err != nil { if err != nil {
t.Fatalf("Should open db success, but got %v", err) t.Fatalf("Should open db success, but got %v", err)
} }

View File

@ -3,18 +3,16 @@ module gorm.io/gorm/tests
go 1.16 go 1.16
require ( require (
github.com/denisenkom/go-mssqldb v0.12.2 // indirect
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect
github.com/google/uuid v1.3.0 github.com/google/uuid v1.3.0
github.com/jinzhu/now v1.1.5 github.com/jinzhu/now v1.1.5
github.com/lib/pq v1.10.7 github.com/lib/pq v1.10.7
github.com/mattn/go-sqlite3 v1.14.15 // indirect golang.org/x/crypto v0.0.0-20221012134737-56aed061732a // indirect
golang.org/x/crypto v0.0.0-20220926161630-eccd6366d1be // indirect golang.org/x/text v0.4.0 // indirect
gorm.io/driver/mysql v1.3.6 gorm.io/driver/mysql v1.4.3
gorm.io/driver/postgres v1.3.10 gorm.io/driver/postgres v1.4.4
gorm.io/driver/sqlite v1.3.6 gorm.io/driver/sqlite v1.4.3
gorm.io/driver/sqlserver v1.3.2 gorm.io/driver/sqlserver v1.4.1
gorm.io/gorm v1.23.10 gorm.io/gorm v1.24.0
) )
replace gorm.io/gorm => ../ replace gorm.io/gorm => ../

View File

@ -260,3 +260,46 @@ func TestJoinWithSameColumnName(t *testing.T) {
t.Fatalf("wrong pet name") t.Fatalf("wrong pet name")
} }
} }
func TestJoinArgsWithDB(t *testing.T) {
user := *GetUser("joins-args-db", Config{Pets: 2})
DB.Save(&user)
// test where
var user1 User
onQuery := DB.Where(&Pet{Name: "joins-args-db_pet_2"})
if err := DB.Joins("NamedPet", onQuery).Where("users.name = ?", user.Name).First(&user1).Error; err != nil {
t.Fatalf("Failed to load with joins on, got error: %v", err)
}
AssertEqual(t, user1.NamedPet.Name, "joins-args-db_pet_2")
// test where and omit
onQuery2 := DB.Where(&Pet{Name: "joins-args-db_pet_2"}).Omit("Name")
var user2 User
if err := DB.Joins("NamedPet", onQuery2).Where("users.name = ?", user.Name).First(&user2).Error; err != nil {
t.Fatalf("Failed to load with joins on, got error: %v", err)
}
AssertEqual(t, user2.NamedPet.ID, user1.NamedPet.ID)
AssertEqual(t, user2.NamedPet.Name, "")
// test where and select
onQuery3 := DB.Where(&Pet{Name: "joins-args-db_pet_2"}).Select("Name")
var user3 User
if err := DB.Joins("NamedPet", onQuery3).Where("users.name = ?", user.Name).First(&user3).Error; err != nil {
t.Fatalf("Failed to load with joins on, got error: %v", err)
}
AssertEqual(t, user3.NamedPet.ID, 0)
AssertEqual(t, user3.NamedPet.Name, "joins-args-db_pet_2")
// test select
onQuery4 := DB.Select("ID")
var user4 User
if err := DB.Joins("NamedPet", onQuery4).Where("users.name = ?", user.Name).First(&user4).Error; err != nil {
t.Fatalf("Failed to load with joins on, got error: %v", err)
}
if user4.NamedPet.ID == 0 {
t.Fatal("Pet ID can not be empty")
}
AssertEqual(t, user4.NamedPet.Name, "")
}

View File

@ -1,8 +1,10 @@
package tests_test package tests_test
import ( import (
"context"
"fmt" "fmt"
"math/rand" "math/rand"
"os"
"reflect" "reflect"
"strings" "strings"
"testing" "testing"
@ -13,6 +15,7 @@ import (
"gorm.io/driver/sqlite" "gorm.io/driver/sqlite"
"gorm.io/driver/sqlserver" "gorm.io/driver/sqlserver"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/logger"
"gorm.io/gorm/schema" "gorm.io/gorm/schema"
. "gorm.io/gorm/utils/tests" . "gorm.io/gorm/utils/tests"
) )
@ -75,6 +78,44 @@ func TestMigrate(t *testing.T) {
t.Fatalf("Failed to find index for many2many for %v %v", indexes[0], indexes[1]) t.Fatalf("Failed to find index for many2many for %v %v", indexes[0], indexes[1])
} }
} }
}
func TestAutoMigrateInt8PG(t *testing.T) {
if DB.Dialector.Name() != "postgres" {
return
}
type Smallint int8
type MigrateInt struct {
Int8 Smallint
}
tracer := Tracer{
Logger: DB.Config.Logger,
Test: func(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) {
sql, _ := fc()
if strings.HasPrefix(sql, "ALTER TABLE \"migrate_ints\" ALTER COLUMN \"int8\" TYPE smallint") {
t.Fatalf("shouldn't execute ALTER COLUMN TYPE if such type is already existed in DB schema: sql: %s", sql)
}
},
}
DB.Migrator().DropTable(&MigrateInt{})
// The first AutoMigrate to make table with field with correct type
if err := DB.AutoMigrate(&MigrateInt{}); err != nil {
t.Fatalf("Failed to auto migrate: error: %v", err)
}
// make new session to set custom logger tracer
session := DB.Session(&gorm.Session{Logger: tracer})
// The second AutoMigrate to catch an error
if err := session.AutoMigrate(&MigrateInt{}); err != nil {
t.Fatalf("Failed to auto migrate: error: %v", err)
}
} }
func TestAutoMigrateSelfReferential(t *testing.T) { func TestAutoMigrateSelfReferential(t *testing.T) {
@ -403,7 +444,7 @@ func TestMigrateColumns(t *testing.T) {
t.Fatalf("column code unique should be correct, name: %v, column: %#v", columnType.Name(), columnType) t.Fatalf("column code unique should be correct, name: %v, column: %#v", columnType.Name(), columnType)
} }
if v, ok := columnType.DefaultValue(); !sqlserver && (!ok || v != "hello") { if v, ok := columnType.DefaultValue(); !sqlserver && (!ok || v != "hello") {
t.Fatalf("column code default value should be correct, name: %v, column: %#v", columnType.Name(), columnType) t.Fatalf("column code default value should be correct, name: %v, column: %#v, default value: %v", columnType.Name(), columnType, v)
} }
if v, ok := columnType.Comment(); !sqlite && !sqlserver && (!ok || v != "my code2") { if v, ok := columnType.Comment(); !sqlite && !sqlserver && (!ok || v != "my code2") {
t.Fatalf("column code comment should be correct, name: %v, column: %#v", columnType.Name(), columnType) t.Fatalf("column code comment should be correct, name: %v, column: %#v", columnType.Name(), columnType)
@ -853,7 +894,7 @@ func findColumnType(dest interface{}, columnName string) (
return return
} }
func TestInvalidCachedPlan(t *testing.T) { func TestInvalidCachedPlanSimpleProtocol(t *testing.T) {
if DB.Dialector.Name() != "postgres" { if DB.Dialector.Name() != "postgres" {
return return
} }
@ -888,6 +929,101 @@ func TestInvalidCachedPlan(t *testing.T) {
} }
} }
func TestInvalidCachedPlanPrepareStmt(t *testing.T) {
if DB.Dialector.Name() != "postgres" {
return
}
db, err := gorm.Open(postgres.Open(postgresDSN), &gorm.Config{PrepareStmt: true})
if err != nil {
t.Errorf("Open err:%v", err)
}
if debug := os.Getenv("DEBUG"); debug == "true" {
db.Logger = db.Logger.LogMode(logger.Info)
} else if debug == "false" {
db.Logger = db.Logger.LogMode(logger.Silent)
}
type Object1 struct {
ID uint
}
type Object2 struct {
ID uint
Field1 int `gorm:"type:int8"`
}
type Object3 struct {
ID uint
Field1 int `gorm:"type:int4"`
}
type Object4 struct {
ID uint
Field2 int
}
db.Migrator().DropTable("objects")
err = db.Table("objects").AutoMigrate(&Object1{})
if err != nil {
t.Errorf("AutoMigrate err:%v", err)
}
err = db.Table("objects").Create(&Object1{}).Error
if err != nil {
t.Errorf("create err:%v", err)
}
// AddColumn
err = db.Table("objects").AutoMigrate(&Object2{})
if err != nil {
t.Errorf("AutoMigrate err:%v", err)
}
err = db.Table("objects").Take(&Object2{}).Error
if err != nil {
t.Errorf("take err:%v", err)
}
// AlterColumn
err = db.Table("objects").AutoMigrate(&Object3{})
if err != nil {
t.Errorf("AutoMigrate err:%v", err)
}
err = db.Table("objects").Take(&Object3{}).Error
if err != nil {
t.Errorf("take err:%v", err)
}
// AddColumn
err = db.Table("objects").AutoMigrate(&Object4{})
if err != nil {
t.Errorf("AutoMigrate err:%v", err)
}
err = db.Table("objects").Take(&Object4{}).Error
if err != nil {
t.Errorf("take err:%v", err)
}
db.Table("objects").Migrator().RenameColumn(&Object4{}, "field2", "field3")
if err != nil {
t.Errorf("RenameColumn err:%v", err)
}
err = db.Table("objects").Take(&Object4{}).Error
if err != nil {
t.Errorf("take err:%v", err)
}
db.Table("objects").Migrator().DropColumn(&Object4{}, "field3")
if err != nil {
t.Errorf("RenameColumn err:%v", err)
}
err = db.Table("objects").Take(&Object4{}).Error
if err != nil {
t.Errorf("take err:%v", err)
}
}
func TestDifferentTypeWithoutDeclaredLength(t *testing.T) { func TestDifferentTypeWithoutDeclaredLength(t *testing.T) {
type DiffType struct { type DiffType struct {
ID uint ID uint

View File

@ -7,6 +7,8 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
"github.com/lib/pq" "github.com/lib/pq"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/clause"
. "gorm.io/gorm/utils/tests"
) )
func TestPostgresReturningIDWhichHasStringType(t *testing.T) { func TestPostgresReturningIDWhichHasStringType(t *testing.T) {
@ -110,6 +112,45 @@ func TestPostgres(t *testing.T) {
if err := DB.First(&result, "id = ?", harumph.ID).Error; err != nil || harumph.Name != "jinzhu1" { if err := DB.First(&result, "id = ?", harumph.ID).Error; err != nil || harumph.Name != "jinzhu1" {
t.Errorf("No error should happen, but got %v", err) t.Errorf("No error should happen, but got %v", err)
} }
DB.Migrator().DropTable("log_usage")
if err := DB.Exec(`
CREATE TABLE public.log_usage (
log_id bigint NOT NULL
);
ALTER TABLE public.log_usage ALTER COLUMN log_id ADD GENERATED BY DEFAULT AS IDENTITY (
SEQUENCE NAME public.log_usage_log_id_seq
START WITH 1
INCREMENT BY 1
NO MINVALUE
NO MAXVALUE
CACHE 1
);
`).Error; err != nil {
t.Fatalf("failed to create table, got error %v", err)
}
columns, err := DB.Migrator().ColumnTypes("log_usage")
if err != nil {
t.Fatalf("failed to get columns, got error %v", err)
}
hasLogID := false
for _, column := range columns {
if column.Name() == "log_id" {
hasLogID = true
autoIncrement, ok := column.AutoIncrement()
if !ok || !autoIncrement {
t.Fatalf("column log_id should be auto incrementment")
}
}
}
if !hasLogID {
t.Fatalf("failed to found column log_id")
}
} }
type Post struct { type Post struct {
@ -148,3 +189,68 @@ func TestMany2ManyWithDefaultValueUUID(t *testing.T) {
t.Errorf("Failed, got error: %v", err) t.Errorf("Failed, got error: %v", err)
} }
} }
func TestPostgresOnConstraint(t *testing.T) {
if DB.Dialector.Name() != "postgres" {
t.Skip()
}
type Thing struct {
gorm.Model
SomeID string
OtherID string
Data string
}
DB.Migrator().DropTable(&Thing{})
DB.Migrator().CreateTable(&Thing{})
if err := DB.Exec("ALTER TABLE things ADD CONSTRAINT some_id_other_id_unique UNIQUE (some_id, other_id)").Error; err != nil {
t.Error(err)
}
thing := Thing{
SomeID: "1234",
OtherID: "1234",
Data: "something",
}
DB.Create(&thing)
thing2 := Thing{
SomeID: "1234",
OtherID: "1234",
Data: "something else",
}
result := DB.Clauses(clause.OnConflict{
OnConstraint: "some_id_other_id_unique",
UpdateAll: true,
}).Create(&thing2)
if result.Error != nil {
t.Errorf("creating second thing: %v", result.Error)
}
var things []Thing
if err := DB.Find(&things).Error; err != nil {
t.Errorf("Failed, got error: %v", err)
}
if len(things) > 1 {
t.Errorf("expected 1 thing got more")
}
}
type CompanyNew struct {
ID int
Name int
}
func TestAlterColumnDataType(t *testing.T) {
DB.AutoMigrate(Company{})
if err := DB.Table("companies").Migrator().AlterColumn(CompanyNew{}, "name"); err != nil {
t.Fatalf("failed to alter column from string to int, got error %v", err)
}
DB.AutoMigrate(Company{})
}

View File

@ -2,8 +2,8 @@ package tests_test
import ( import (
"context" "context"
"sync"
"errors" "errors"
"sync"
"testing" "testing"
"time" "time"
@ -168,3 +168,29 @@ func TestPreparedStmtInTransaction(t *testing.T) {
t.Errorf("Failed, got error: %v", err) t.Errorf("Failed, got error: %v", err)
} }
} }
func TestPreparedStmtReset(t *testing.T) {
tx := DB.Session(&gorm.Session{PrepareStmt: true})
user := *GetUser("prepared_stmt_reset", Config{})
tx = tx.Create(&user)
pdb, ok := tx.ConnPool.(*gorm.PreparedStmtDB)
if !ok {
t.Fatalf("should assign PreparedStatement Manager back to database when using PrepareStmt mode")
}
pdb.Mux.Lock()
if len(pdb.Stmts) == 0 {
pdb.Mux.Unlock()
t.Fatalf("prepared stmt can not be empty")
}
pdb.Mux.Unlock()
pdb.Reset()
pdb.Mux.Lock()
defer pdb.Mux.Unlock()
if len(pdb.Stmts) != 0 {
t.Fatalf("prepared stmt should be empty")
}
}

View File

@ -18,6 +18,8 @@ type SerializerStruct struct {
gorm.Model gorm.Model
Name []byte `gorm:"json"` Name []byte `gorm:"json"`
Roles Roles `gorm:"serializer:json"` Roles Roles `gorm:"serializer:json"`
Roles2 *Roles `gorm:"serializer:json"`
Roles3 *Roles `gorm:"serializer:json;not null"`
Contracts map[string]interface{} `gorm:"serializer:json"` Contracts map[string]interface{} `gorm:"serializer:json"`
JobInfo Job `gorm:"type:bytes;serializer:gob"` JobInfo Job `gorm:"type:bytes;serializer:gob"`
CreatedTime int64 `gorm:"serializer:unixtime;type:time"` // store time in db, use int as field type CreatedTime int64 `gorm:"serializer:unixtime;type:time"` // store time in db, use int as field type
@ -108,7 +110,7 @@ func TestSerializer(t *testing.T) {
} }
var result SerializerStruct var result SerializerStruct
if err := DB.First(&result, data.ID).Error; err != nil { if err := DB.Where("roles2 IS NULL AND roles3 = ?", "").First(&result, data.ID).Error; err != nil {
t.Fatalf("failed to query data, got error %v", err) t.Fatalf("failed to query data, got error %v", err)
} }

View File

@ -367,7 +367,7 @@ func TestToSQL(t *testing.T) {
t.Skip("Skip SQL Server for this test, because it too difference with other dialects.") t.Skip("Skip SQL Server for this test, because it too difference with other dialects.")
} }
date, _ := time.Parse("2006-01-02", "2021-10-18") date, _ := time.ParseInLocation("2006-01-02", "2021-10-18", time.Local)
// find // find
sql := DB.ToSQL(func(tx *gorm.DB) *gorm.DB { sql := DB.ToSQL(func(tx *gorm.DB) *gorm.DB {
@ -445,6 +445,14 @@ func TestToSQL(t *testing.T) {
if DB.Statement.DryRun || DB.DryRun { if DB.Statement.DryRun || DB.DryRun {
t.Fatal("Failed expect DB.DryRun and DB.Statement.ToSQL to be false") t.Fatal("Failed expect DB.DryRun and DB.Statement.ToSQL to be false")
} }
// UpdateColumns
sql = DB.ToSQL(func(tx *gorm.DB) *gorm.DB {
return tx.Raw("SELECT * FROM users ?", clause.OrderBy{
Columns: []clause.OrderByColumn{{Column: clause.Column{Name: "id", Raw: true}, Desc: true}},
})
})
assertEqualSQL(t, `SELECT * FROM users ORDER BY id DESC`, sql)
} }
// assertEqualSQL for assert that the sql is equal, this method will ignore quote, and dialect specials. // assertEqualSQL for assert that the sql is equal, this method will ignore quote, and dialect specials.

View File

@ -5,6 +5,8 @@ import (
"testing" "testing"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils/tests"
. "gorm.io/gorm/utils/tests" . "gorm.io/gorm/utils/tests"
) )
@ -145,3 +147,27 @@ func TestTableWithAllFields(t *testing.T) {
AssertEqual(t, r.Statement.Vars, []interface{}{2, 4, 1, 3}) AssertEqual(t, r.Statement.Vars, []interface{}{2, 4, 1, 3})
} }
type UserWithTableNamer struct {
gorm.Model
Name string
}
func (UserWithTableNamer) TableName(namer schema.Namer) string {
return namer.TableName("user")
}
func TestTableWithNamer(t *testing.T) {
var db, _ = gorm.Open(tests.DummyDialector{}, &gorm.Config{
NamingStrategy: schema.NamingStrategy{
TablePrefix: "t_",
}})
sql := db.ToSQL(func(tx *gorm.DB) *gorm.DB {
return tx.Model(&UserWithTableNamer{}).Find(&UserWithTableNamer{})
})
if !regexp.MustCompile("SELECT \\* FROM `t_users`").MatchString(sql) {
t.Errorf("Table with namer, got %v", sql)
}
}

34
tests/tracer_test.go Normal file
View File

@ -0,0 +1,34 @@
package tests_test
import (
"context"
"time"
"gorm.io/gorm/logger"
)
type Tracer struct {
Logger logger.Interface
Test func(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error)
}
func (S Tracer) LogMode(level logger.LogLevel) logger.Interface {
return S.Logger.LogMode(level)
}
func (S Tracer) Info(ctx context.Context, s string, i ...interface{}) {
S.Logger.Info(ctx, s, i...)
}
func (S Tracer) Warn(ctx context.Context, s string, i ...interface{}) {
S.Logger.Warn(ctx, s, i...)
}
func (S Tracer) Error(ctx context.Context, s string, i ...interface{}) {
S.Logger.Error(ctx, s, i...)
}
func (S Tracer) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) {
S.Logger.Trace(ctx, begin, fc, err)
S.Test(ctx, begin, fc, err)
}

View File

@ -2,6 +2,7 @@ package tests
import ( import (
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/callbacks"
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
"gorm.io/gorm/logger" "gorm.io/gorm/logger"
"gorm.io/gorm/schema" "gorm.io/gorm/schema"
@ -13,7 +14,14 @@ func (DummyDialector) Name() string {
return "dummy" return "dummy"
} }
func (DummyDialector) Initialize(*gorm.DB) error { func (DummyDialector) Initialize(db *gorm.DB) error {
callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{
CreateClauses: []string{"INSERT", "VALUES", "ON CONFLICT", "RETURNING"},
UpdateClauses: []string{"UPDATE", "SET", "WHERE", "RETURNING"},
DeleteClauses: []string{"DELETE", "FROM", "WHERE", "RETURNING"},
LastInsertIDReversed: true,
})
return nil return nil
} }

View File

@ -16,7 +16,7 @@ var gormSourceDir string
func init() { func init() {
_, file, _, _ := runtime.Caller(0) _, file, _, _ := runtime.Caller(0)
// compatible solution to get gorm source directory with various operating systems // compatible solution to get gorm source directory with various operating systems
gormSourceDir = regexp.MustCompile(`utils.utils\.go`).ReplaceAllString(file, "") gormSourceDir = regexp.MustCompile(`gorm.utils.utils\.go`).ReplaceAllString(file, "")
} }
// FileWithLineNum return the file name and line number of the current file // FileWithLineNum return the file name and line number of the current file