Merge branch 'master' into test_migrate_column
This commit is contained in:
commit
04dc915cb7
@ -30,7 +30,7 @@ The fantastic ORM library for Golang, aims to be developer friendly.
|
||||
## Getting Started
|
||||
|
||||
* GORM Guides [https://gorm.io](https://gorm.io)
|
||||
* GORM Gen [gorm/gen](https://github.com/go-gorm/gen#gormgen)
|
||||
* Gen Guides [https://gorm.io/gen/index.html](https://gorm.io/gen/index.html)
|
||||
|
||||
## Contributing
|
||||
|
||||
|
@ -208,7 +208,10 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
|
||||
|
||||
cacheKey := utils.ToStringKey(relPrimaryValues...)
|
||||
if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] {
|
||||
identityMap[cacheKey] = true
|
||||
if cacheKey != "" { // has primary fields
|
||||
identityMap[cacheKey] = true
|
||||
}
|
||||
|
||||
if isPtr {
|
||||
elems = reflect.Append(elems, elem)
|
||||
} else {
|
||||
@ -294,7 +297,10 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
|
||||
|
||||
cacheKey := utils.ToStringKey(relPrimaryValues...)
|
||||
if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] {
|
||||
identityMap[cacheKey] = true
|
||||
if cacheKey != "" { // has primary fields
|
||||
identityMap[cacheKey] = true
|
||||
}
|
||||
|
||||
distinctElems = reflect.Append(distinctElems, elem)
|
||||
}
|
||||
|
||||
|
@ -117,12 +117,20 @@ func BuildQuerySQL(db *gorm.DB) {
|
||||
} else if relation, ok := db.Statement.Schema.Relationships.Relations[join.Name]; ok {
|
||||
tableAliasName := relation.Name
|
||||
|
||||
columnStmt := gorm.Statement{
|
||||
Table: tableAliasName, DB: db, Schema: relation.FieldSchema,
|
||||
Selects: join.Selects, Omits: join.Omits,
|
||||
}
|
||||
|
||||
selectColumns, restricted := columnStmt.SelectAndOmitColumns(false, false)
|
||||
for _, s := range relation.FieldSchema.DBNames {
|
||||
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{
|
||||
Table: tableAliasName,
|
||||
Name: s,
|
||||
Alias: tableAliasName + "__" + s,
|
||||
})
|
||||
if v, ok := selectColumns[s]; (ok && v) || (!ok && !restricted) {
|
||||
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{
|
||||
Table: tableAliasName,
|
||||
Name: s,
|
||||
Alias: tableAliasName + "__" + s,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
exprs := make([]clause.Expression, len(relation.References))
|
||||
|
@ -10,10 +10,11 @@ import (
|
||||
)
|
||||
|
||||
// Model specify the model you would like to run db operations
|
||||
// // update all users's name to `hello`
|
||||
// db.Model(&User{}).Update("name", "hello")
|
||||
// // if user's primary key is non-blank, will use it as condition, then will only update the user's name to `hello`
|
||||
// db.Model(&user).Update("name", "hello")
|
||||
//
|
||||
// // update all users's name to `hello`
|
||||
// db.Model(&User{}).Update("name", "hello")
|
||||
// // if user's primary key is non-blank, will use it as condition, then will only update the user's name to `hello`
|
||||
// db.Model(&user).Update("name", "hello")
|
||||
func (db *DB) Model(value interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.Model = value
|
||||
@ -179,18 +180,21 @@ func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) {
|
||||
}
|
||||
|
||||
// Joins specify Joins conditions
|
||||
// db.Joins("Account").Find(&user)
|
||||
// db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user)
|
||||
// db.Joins("Account", DB.Select("id").Where("user_id = users.id AND name = ?", "someName").Model(&Account{}))
|
||||
//
|
||||
// db.Joins("Account").Find(&user)
|
||||
// db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user)
|
||||
// db.Joins("Account", DB.Select("id").Where("user_id = users.id AND name = ?", "someName").Model(&Account{}))
|
||||
func (db *DB) Joins(query string, args ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
|
||||
if len(args) == 1 {
|
||||
if db, ok := args[0].(*DB); ok {
|
||||
j := join{Name: query, Conds: args, Selects: db.Statement.Selects, Omits: db.Statement.Omits}
|
||||
if where, ok := db.Statement.Clauses["WHERE"].Expression.(clause.Where); ok {
|
||||
tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args, On: &where})
|
||||
return
|
||||
j.On = &where
|
||||
}
|
||||
tx.Statement.Joins = append(tx.Statement.Joins, j)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@ -219,8 +223,9 @@ func (db *DB) Having(query interface{}, args ...interface{}) (tx *DB) {
|
||||
}
|
||||
|
||||
// Order specify order when retrieve records from database
|
||||
// db.Order("name DESC")
|
||||
// db.Order(clause.OrderByColumn{Column: clause.Column{Name: "name"}, Desc: true})
|
||||
//
|
||||
// db.Order("name DESC")
|
||||
// db.Order(clause.OrderByColumn{Column: clause.Column{Name: "name"}, Desc: true})
|
||||
func (db *DB) Order(value interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
|
||||
@ -244,7 +249,7 @@ func (db *DB) Order(value interface{}) (tx *DB) {
|
||||
// Limit specify the number of records to be retrieved
|
||||
func (db *DB) Limit(limit int) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.AddClause(clause.Limit{Limit: limit})
|
||||
tx.Statement.AddClause(clause.Limit{Limit: &limit})
|
||||
return
|
||||
}
|
||||
|
||||
@ -256,17 +261,18 @@ func (db *DB) Offset(offset int) (tx *DB) {
|
||||
}
|
||||
|
||||
// Scopes pass current database connection to arguments `func(DB) DB`, which could be used to add conditions dynamically
|
||||
// func AmountGreaterThan1000(db *gorm.DB) *gorm.DB {
|
||||
// return db.Where("amount > ?", 1000)
|
||||
// }
|
||||
//
|
||||
// func OrderStatus(status []string) func (db *gorm.DB) *gorm.DB {
|
||||
// return func (db *gorm.DB) *gorm.DB {
|
||||
// return db.Scopes(AmountGreaterThan1000).Where("status in (?)", status)
|
||||
// }
|
||||
// }
|
||||
// func AmountGreaterThan1000(db *gorm.DB) *gorm.DB {
|
||||
// return db.Where("amount > ?", 1000)
|
||||
// }
|
||||
//
|
||||
// db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders)
|
||||
// func OrderStatus(status []string) func (db *gorm.DB) *gorm.DB {
|
||||
// return func (db *gorm.DB) *gorm.DB {
|
||||
// return db.Scopes(AmountGreaterThan1000).Where("status in (?)", status)
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders)
|
||||
func (db *DB) Scopes(funcs ...func(*DB) *DB) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.scopes = append(tx.Statement.scopes, funcs...)
|
||||
@ -274,7 +280,8 @@ func (db *DB) Scopes(funcs ...func(*DB) *DB) (tx *DB) {
|
||||
}
|
||||
|
||||
// Preload preload associations with given conditions
|
||||
// db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users)
|
||||
//
|
||||
// db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users)
|
||||
func (db *DB) Preload(query string, args ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
if tx.Statement.Preloads == nil {
|
||||
|
@ -29,6 +29,7 @@ func BenchmarkSelect(b *testing.B) {
|
||||
func BenchmarkComplexSelect(b *testing.B) {
|
||||
user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
|
||||
|
||||
limit10 := 10
|
||||
for i := 0; i < b.N; i++ {
|
||||
stmt := gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}}
|
||||
clauses := []clause.Interface{
|
||||
@ -43,7 +44,7 @@ func BenchmarkComplexSelect(b *testing.B) {
|
||||
clause.Or(clause.Gt{Column: "score", Value: 100}, clause.Like{Column: "name", Value: "%linus%"}),
|
||||
}},
|
||||
clause.GroupBy{Columns: []clause.Column{{Name: "role"}}, Having: []clause.Expression{clause.Eq{"role", "admin"}}},
|
||||
clause.Limit{Limit: 10, Offset: 20},
|
||||
clause.Limit{Limit: &limit10, Offset: 20},
|
||||
clause.OrderBy{Columns: []clause.OrderByColumn{{Column: clause.PrimaryColumn, Desc: true}}},
|
||||
}
|
||||
|
||||
|
@ -9,7 +9,7 @@ const (
|
||||
RightJoin JoinType = "RIGHT"
|
||||
)
|
||||
|
||||
// Join join clause for from
|
||||
// Join clause for from
|
||||
type Join struct {
|
||||
Type JoinType
|
||||
Table Table
|
||||
|
101
clause/joins_test.go
Normal file
101
clause/joins_test.go
Normal file
@ -0,0 +1,101 @@
|
||||
package clause_test
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/schema"
|
||||
"gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
func TestJoin(t *testing.T) {
|
||||
results := []struct {
|
||||
name string
|
||||
join clause.Join
|
||||
sql string
|
||||
}{
|
||||
{
|
||||
name: "LEFT JOIN",
|
||||
join: clause.Join{
|
||||
Type: clause.LeftJoin,
|
||||
Table: clause.Table{Name: "user"},
|
||||
ON: clause.Where{
|
||||
Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}},
|
||||
},
|
||||
},
|
||||
sql: "LEFT JOIN `user` ON `user_info`.`user_id` = `users`.`id`",
|
||||
},
|
||||
{
|
||||
name: "RIGHT JOIN",
|
||||
join: clause.Join{
|
||||
Type: clause.RightJoin,
|
||||
Table: clause.Table{Name: "user"},
|
||||
ON: clause.Where{
|
||||
Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}},
|
||||
},
|
||||
},
|
||||
sql: "RIGHT JOIN `user` ON `user_info`.`user_id` = `users`.`id`",
|
||||
},
|
||||
{
|
||||
name: "INNER JOIN",
|
||||
join: clause.Join{
|
||||
Type: clause.InnerJoin,
|
||||
Table: clause.Table{Name: "user"},
|
||||
ON: clause.Where{
|
||||
Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}},
|
||||
},
|
||||
},
|
||||
sql: "INNER JOIN `user` ON `user_info`.`user_id` = `users`.`id`",
|
||||
},
|
||||
{
|
||||
name: "CROSS JOIN",
|
||||
join: clause.Join{
|
||||
Type: clause.CrossJoin,
|
||||
Table: clause.Table{Name: "user"},
|
||||
ON: clause.Where{
|
||||
Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}},
|
||||
},
|
||||
},
|
||||
sql: "CROSS JOIN `user` ON `user_info`.`user_id` = `users`.`id`",
|
||||
},
|
||||
{
|
||||
name: "USING",
|
||||
join: clause.Join{
|
||||
Type: clause.InnerJoin,
|
||||
Table: clause.Table{Name: "user"},
|
||||
Using: []string{"id"},
|
||||
},
|
||||
sql: "INNER JOIN `user` USING (`id`)",
|
||||
},
|
||||
{
|
||||
name: "Expression",
|
||||
join: clause.Join{
|
||||
// Invalid
|
||||
Type: clause.LeftJoin,
|
||||
Table: clause.Table{Name: "user"},
|
||||
ON: clause.Where{
|
||||
Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}},
|
||||
},
|
||||
// Valid
|
||||
Expression: clause.Join{
|
||||
Type: clause.InnerJoin,
|
||||
Table: clause.Table{Name: "user"},
|
||||
Using: []string{"id"},
|
||||
},
|
||||
},
|
||||
sql: "INNER JOIN `user` USING (`id`)",
|
||||
},
|
||||
}
|
||||
for _, result := range results {
|
||||
t.Run(result.name, func(t *testing.T) {
|
||||
user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
|
||||
stmt := &gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}}
|
||||
result.join.Build(stmt)
|
||||
if result.sql != stmt.SQL.String() {
|
||||
t.Errorf("want: %s, got: %s", result.sql, stmt.SQL.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@ -4,7 +4,7 @@ import "strconv"
|
||||
|
||||
// Limit limit clause
|
||||
type Limit struct {
|
||||
Limit int
|
||||
Limit *int
|
||||
Offset int
|
||||
}
|
||||
|
||||
@ -15,12 +15,12 @@ func (limit Limit) Name() string {
|
||||
|
||||
// Build build where clause
|
||||
func (limit Limit) Build(builder Builder) {
|
||||
if limit.Limit > 0 {
|
||||
if limit.Limit != nil && *limit.Limit >= 0 {
|
||||
builder.WriteString("LIMIT ")
|
||||
builder.WriteString(strconv.Itoa(limit.Limit))
|
||||
builder.WriteString(strconv.Itoa(*limit.Limit))
|
||||
}
|
||||
if limit.Offset > 0 {
|
||||
if limit.Limit > 0 {
|
||||
if limit.Limit != nil && *limit.Limit >= 0 {
|
||||
builder.WriteByte(' ')
|
||||
}
|
||||
builder.WriteString("OFFSET ")
|
||||
@ -33,7 +33,7 @@ func (limit Limit) MergeClause(clause *Clause) {
|
||||
clause.Name = ""
|
||||
|
||||
if v, ok := clause.Expression.(Limit); ok {
|
||||
if limit.Limit == 0 && v.Limit != 0 {
|
||||
if (limit.Limit == nil || *limit.Limit == 0) && (v.Limit != nil && *v.Limit != 0) {
|
||||
limit.Limit = v.Limit
|
||||
}
|
||||
|
||||
|
@ -8,6 +8,10 @@ import (
|
||||
)
|
||||
|
||||
func TestLimit(t *testing.T) {
|
||||
limit0 := 0
|
||||
limit10 := 10
|
||||
limit50 := 50
|
||||
limitNeg10 := -10
|
||||
results := []struct {
|
||||
Clauses []clause.Interface
|
||||
Result string
|
||||
@ -15,11 +19,15 @@ func TestLimit(t *testing.T) {
|
||||
}{
|
||||
{
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{
|
||||
Limit: 10,
|
||||
Limit: &limit10,
|
||||
Offset: 20,
|
||||
}},
|
||||
"SELECT * FROM `users` LIMIT 10 OFFSET 20", nil,
|
||||
},
|
||||
{
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit0}},
|
||||
"SELECT * FROM `users` LIMIT 0", nil,
|
||||
},
|
||||
{
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}},
|
||||
"SELECT * FROM `users` OFFSET 20", nil,
|
||||
@ -29,23 +37,23 @@ func TestLimit(t *testing.T) {
|
||||
"SELECT * FROM `users` OFFSET 30", nil,
|
||||
},
|
||||
{
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}, clause.Limit{Limit: 10}},
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}, clause.Limit{Limit: &limit10}},
|
||||
"SELECT * FROM `users` LIMIT 10 OFFSET 20", nil,
|
||||
},
|
||||
{
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}},
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}},
|
||||
"SELECT * FROM `users` LIMIT 10 OFFSET 30", nil,
|
||||
},
|
||||
{
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Offset: -10}},
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Offset: -10}},
|
||||
"SELECT * FROM `users` LIMIT 10", nil,
|
||||
},
|
||||
{
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: -10}},
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: &limitNeg10}},
|
||||
"SELECT * FROM `users` OFFSET 30", nil,
|
||||
},
|
||||
{
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: 50}},
|
||||
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: &limit50}},
|
||||
"SELECT * FROM `users` LIMIT 50 OFFSET 30", nil,
|
||||
},
|
||||
}
|
||||
|
@ -16,27 +16,27 @@ func (OnConflict) Name() string {
|
||||
|
||||
// Build build onConflict clause
|
||||
func (onConflict OnConflict) Build(builder Builder) {
|
||||
if len(onConflict.Columns) > 0 {
|
||||
builder.WriteByte('(')
|
||||
for idx, column := range onConflict.Columns {
|
||||
if idx > 0 {
|
||||
builder.WriteByte(',')
|
||||
}
|
||||
builder.WriteQuoted(column)
|
||||
}
|
||||
builder.WriteString(`) `)
|
||||
}
|
||||
|
||||
if len(onConflict.TargetWhere.Exprs) > 0 {
|
||||
builder.WriteString(" WHERE ")
|
||||
onConflict.TargetWhere.Build(builder)
|
||||
builder.WriteByte(' ')
|
||||
}
|
||||
|
||||
if onConflict.OnConstraint != "" {
|
||||
builder.WriteString("ON CONSTRAINT ")
|
||||
builder.WriteString(onConflict.OnConstraint)
|
||||
builder.WriteByte(' ')
|
||||
} else {
|
||||
if len(onConflict.Columns) > 0 {
|
||||
builder.WriteByte('(')
|
||||
for idx, column := range onConflict.Columns {
|
||||
if idx > 0 {
|
||||
builder.WriteByte(',')
|
||||
}
|
||||
builder.WriteQuoted(column)
|
||||
}
|
||||
builder.WriteString(`) `)
|
||||
}
|
||||
|
||||
if len(onConflict.TargetWhere.Exprs) > 0 {
|
||||
builder.WriteString(" WHERE ")
|
||||
onConflict.TargetWhere.Build(builder)
|
||||
builder.WriteByte(' ')
|
||||
}
|
||||
}
|
||||
|
||||
if onConflict.DoNothing {
|
||||
|
@ -185,7 +185,9 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat
|
||||
var totalSize int
|
||||
if c, ok := tx.Statement.Clauses["LIMIT"]; ok {
|
||||
if limit, ok := c.Expression.(clause.Limit); ok {
|
||||
totalSize = limit.Limit
|
||||
if limit.Limit != nil {
|
||||
totalSize = *limit.Limit
|
||||
}
|
||||
|
||||
if totalSize > 0 && batchSize > totalSize {
|
||||
batchSize = totalSize
|
||||
|
@ -406,17 +406,14 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
|
||||
fullDataType := strings.TrimSpace(strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL))
|
||||
realDataType := strings.ToLower(columnType.DatabaseTypeName())
|
||||
|
||||
alterColumn := false
|
||||
var (
|
||||
alterColumn, isSameType bool
|
||||
)
|
||||
|
||||
if !field.PrimaryKey {
|
||||
// check type
|
||||
var isSameType bool
|
||||
if strings.HasPrefix(fullDataType, realDataType) {
|
||||
isSameType = true
|
||||
}
|
||||
|
||||
// check type aliases
|
||||
if !isSameType {
|
||||
if !strings.HasPrefix(fullDataType, realDataType) {
|
||||
// check type aliases
|
||||
aliases := m.DB.Migrator().GetTypeAliases(realDataType)
|
||||
for _, alias := range aliases {
|
||||
if strings.HasPrefix(fullDataType, alias) {
|
||||
@ -424,32 +421,34 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !isSameType {
|
||||
alterColumn = true
|
||||
}
|
||||
}
|
||||
|
||||
// check size
|
||||
if length, ok := columnType.Length(); length != int64(field.Size) {
|
||||
if length > 0 && field.Size > 0 {
|
||||
alterColumn = true
|
||||
} else {
|
||||
// has size in data type and not equal
|
||||
// Since the following code is frequently called in the for loop, reg optimization is needed here
|
||||
matches2 := regFullDataType.FindAllStringSubmatch(fullDataType, -1)
|
||||
if !field.PrimaryKey &&
|
||||
(len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length) && ok) {
|
||||
if !isSameType {
|
||||
alterColumn = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// check precision
|
||||
if precision, _, ok := columnType.DecimalSize(); ok && int64(field.Precision) != precision {
|
||||
if regexp.MustCompile(fmt.Sprintf("[^0-9]%d[^0-9]", field.Precision)).MatchString(m.DataTypeOf(field)) {
|
||||
alterColumn = true
|
||||
if !isSameType {
|
||||
// check size
|
||||
if length, ok := columnType.Length(); length != int64(field.Size) {
|
||||
if length > 0 && field.Size > 0 {
|
||||
alterColumn = true
|
||||
} else {
|
||||
// has size in data type and not equal
|
||||
// Since the following code is frequently called in the for loop, reg optimization is needed here
|
||||
matches2 := regFullDataType.FindAllStringSubmatch(fullDataType, -1)
|
||||
if !field.PrimaryKey &&
|
||||
(len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length) && ok) {
|
||||
alterColumn = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// check precision
|
||||
if precision, _, ok := columnType.DecimalSize(); ok && int64(field.Precision) != precision {
|
||||
if regexp.MustCompile(fmt.Sprintf("[^0-9]%d[^0-9]", field.Precision)).MatchString(m.DataTypeOf(field)) {
|
||||
alterColumn = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -44,6 +44,18 @@ func (db *PreparedStmtDB) Close() {
|
||||
}
|
||||
}
|
||||
|
||||
func (db *PreparedStmtDB) Reset() {
|
||||
db.Mux.Lock()
|
||||
defer db.Mux.Unlock()
|
||||
for query, stmt := range db.Stmts {
|
||||
delete(db.Stmts, query)
|
||||
go stmt.Close()
|
||||
}
|
||||
|
||||
db.PreparedSQL = make([]string, 0, 100)
|
||||
db.Stmts = map[string](*Stmt){}
|
||||
}
|
||||
|
||||
func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) {
|
||||
db.Mux.RLock()
|
||||
if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) {
|
||||
|
@ -403,33 +403,30 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu
|
||||
case guessBelongs:
|
||||
primarySchema, foreignSchema = relation.FieldSchema, schema
|
||||
case guessEmbeddedBelongs:
|
||||
if field.OwnerSchema != nil {
|
||||
primarySchema, foreignSchema = relation.FieldSchema, field.OwnerSchema
|
||||
} else {
|
||||
if field.OwnerSchema == nil {
|
||||
reguessOrErr()
|
||||
return
|
||||
}
|
||||
primarySchema, foreignSchema = relation.FieldSchema, field.OwnerSchema
|
||||
case guessHas:
|
||||
case guessEmbeddedHas:
|
||||
if field.OwnerSchema != nil {
|
||||
primarySchema, foreignSchema = field.OwnerSchema, relation.FieldSchema
|
||||
} else {
|
||||
if field.OwnerSchema == nil {
|
||||
reguessOrErr()
|
||||
return
|
||||
}
|
||||
primarySchema, foreignSchema = field.OwnerSchema, relation.FieldSchema
|
||||
}
|
||||
|
||||
if len(relation.foreignKeys) > 0 {
|
||||
for _, foreignKey := range relation.foreignKeys {
|
||||
if f := foreignSchema.LookUpField(foreignKey); f != nil {
|
||||
foreignFields = append(foreignFields, f)
|
||||
} else {
|
||||
f := foreignSchema.LookUpField(foreignKey)
|
||||
if f == nil {
|
||||
reguessOrErr()
|
||||
return
|
||||
}
|
||||
foreignFields = append(foreignFields, f)
|
||||
}
|
||||
} else {
|
||||
var primaryFields []*Field
|
||||
var primarySchemaName = primarySchema.Name
|
||||
if primarySchemaName == "" {
|
||||
primarySchemaName = relation.FieldSchema.Name
|
||||
@ -466,10 +463,11 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu
|
||||
}
|
||||
}
|
||||
|
||||
if len(foreignFields) == 0 {
|
||||
switch {
|
||||
case len(foreignFields) == 0:
|
||||
reguessOrErr()
|
||||
return
|
||||
} else if len(relation.primaryKeys) > 0 {
|
||||
case len(relation.primaryKeys) > 0:
|
||||
for idx, primaryKey := range relation.primaryKeys {
|
||||
if f := primarySchema.LookUpField(primaryKey); f != nil {
|
||||
if len(primaryFields) < idx+1 {
|
||||
@ -483,7 +481,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu
|
||||
return
|
||||
}
|
||||
}
|
||||
} else if len(primaryFields) == 0 {
|
||||
case len(primaryFields) == 0:
|
||||
if len(foreignFields) == 1 && primarySchema.PrioritizedPrimaryField != nil {
|
||||
primaryFields = append(primaryFields, primarySchema.PrioritizedPrimaryField)
|
||||
} else if len(primarySchema.PrimaryFields) == len(foreignFields) {
|
||||
|
@ -71,6 +71,10 @@ type Tabler interface {
|
||||
TableName() string
|
||||
}
|
||||
|
||||
type TablerWithNamer interface {
|
||||
TableName(Namer) string
|
||||
}
|
||||
|
||||
// Parse get data type from dialector
|
||||
func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) {
|
||||
return ParseWithSpecialTableName(dest, cacheStore, namer, "")
|
||||
@ -125,6 +129,9 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
|
||||
if tabler, ok := modelValue.Interface().(Tabler); ok {
|
||||
tableName = tabler.TableName()
|
||||
}
|
||||
if tabler, ok := modelValue.Interface().(TablerWithNamer); ok {
|
||||
tableName = tabler.TableName(namer)
|
||||
}
|
||||
if en, ok := namer.(embeddedNamer); ok {
|
||||
tableName = en.Table
|
||||
}
|
||||
|
@ -100,6 +100,12 @@ func (JSONSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value,
|
||||
// Value implements serializer interface
|
||||
func (JSONSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) {
|
||||
result, err := json.Marshal(fieldValue)
|
||||
if string(result) == "null" {
|
||||
if field.TagSettings["NOT NULL"] != "" {
|
||||
return "", nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return string(result), err
|
||||
}
|
||||
|
||||
|
17
statement.go
17
statement.go
@ -49,9 +49,11 @@ type Statement struct {
|
||||
}
|
||||
|
||||
type join struct {
|
||||
Name string
|
||||
Conds []interface{}
|
||||
On *clause.Where
|
||||
Name string
|
||||
Conds []interface{}
|
||||
On *clause.Where
|
||||
Selects []string
|
||||
Omits []string
|
||||
}
|
||||
|
||||
// StatementModifier statement modifier interface
|
||||
@ -179,6 +181,10 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) {
|
||||
} else {
|
||||
stmt.AddVar(writer, v.GormValue(stmt.Context, stmt.DB))
|
||||
}
|
||||
case clause.Interface:
|
||||
c := clause.Clause{Name: v.Name()}
|
||||
v.MergeClause(&c)
|
||||
c.Build(stmt)
|
||||
case clause.Expression:
|
||||
v.Build(stmt)
|
||||
case driver.Valuer:
|
||||
@ -540,8 +546,9 @@ func (stmt *Statement) clone() *Statement {
|
||||
}
|
||||
|
||||
// SetColumn set column's value
|
||||
// stmt.SetColumn("Name", "jinzhu") // Hooks Method
|
||||
// stmt.SetColumn("Name", "jinzhu", true) // Callbacks Method
|
||||
//
|
||||
// stmt.SetColumn("Name", "jinzhu") // Hooks Method
|
||||
// stmt.SetColumn("Name", "jinzhu", true) // Callbacks Method
|
||||
func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks ...bool) {
|
||||
if v, ok := stmt.Dest.(map[string]interface{}); ok {
|
||||
v[name] = value
|
||||
|
@ -348,3 +348,45 @@ func TestAssociationEmptyQueryClause(t *testing.T) {
|
||||
AssertEqual(t, len(orgs), 0)
|
||||
}
|
||||
}
|
||||
|
||||
type AssociationEmptyUser struct {
|
||||
ID uint
|
||||
Name string
|
||||
Pets []AssociationEmptyPet
|
||||
}
|
||||
|
||||
type AssociationEmptyPet struct {
|
||||
AssociationEmptyUserID *uint `gorm:"uniqueIndex:uniq_user_id_name"`
|
||||
Name string `gorm:"uniqueIndex:uniq_user_id_name;size:256"`
|
||||
}
|
||||
|
||||
func TestAssociationEmptyPrimaryKey(t *testing.T) {
|
||||
if DB.Dialector.Name() != "mysql" {
|
||||
t.Skip()
|
||||
}
|
||||
DB.Migrator().DropTable(&AssociationEmptyUser{}, &AssociationEmptyPet{})
|
||||
DB.AutoMigrate(&AssociationEmptyUser{}, &AssociationEmptyPet{})
|
||||
|
||||
id := uint(100)
|
||||
user := AssociationEmptyUser{
|
||||
ID: id,
|
||||
Name: "jinzhu",
|
||||
Pets: []AssociationEmptyPet{
|
||||
{AssociationEmptyUserID: &id, Name: "bar"},
|
||||
{AssociationEmptyUserID: &id, Name: "foo"},
|
||||
},
|
||||
}
|
||||
|
||||
err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Create(&user).Error
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create, got error: %v", err)
|
||||
}
|
||||
|
||||
var result AssociationEmptyUser
|
||||
err = DB.Preload("Pets").First(&result, &id).Error
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to find, got error: %v", err)
|
||||
}
|
||||
|
||||
AssertEqual(t, result, user)
|
||||
}
|
||||
|
@ -116,7 +116,7 @@ func TestConnPoolWrapper(t *testing.T) {
|
||||
}
|
||||
}()
|
||||
|
||||
db, err := gorm.Open(mysql.New(mysql.Config{Conn: conn}))
|
||||
db, err := gorm.Open(mysql.New(mysql.Config{Conn: conn, DisableWithReturning: true}))
|
||||
if err != nil {
|
||||
t.Fatalf("Should open db success, but got %v", err)
|
||||
}
|
||||
|
16
tests/go.mod
16
tests/go.mod
@ -3,18 +3,16 @@ module gorm.io/gorm/tests
|
||||
go 1.16
|
||||
|
||||
require (
|
||||
github.com/denisenkom/go-mssqldb v0.12.2 // indirect
|
||||
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect
|
||||
github.com/google/uuid v1.3.0
|
||||
github.com/jinzhu/now v1.1.5
|
||||
github.com/lib/pq v1.10.7
|
||||
github.com/mattn/go-sqlite3 v1.14.15 // indirect
|
||||
golang.org/x/crypto v0.0.0-20220926161630-eccd6366d1be // indirect
|
||||
gorm.io/driver/mysql v1.3.6
|
||||
gorm.io/driver/postgres v1.3.10
|
||||
gorm.io/driver/sqlite v1.3.6
|
||||
gorm.io/driver/sqlserver v1.3.2
|
||||
gorm.io/gorm v1.23.10
|
||||
golang.org/x/crypto v0.0.0-20221012134737-56aed061732a // indirect
|
||||
golang.org/x/text v0.4.0 // indirect
|
||||
gorm.io/driver/mysql v1.4.3
|
||||
gorm.io/driver/postgres v1.4.4
|
||||
gorm.io/driver/sqlite v1.4.3
|
||||
gorm.io/driver/sqlserver v1.4.1
|
||||
gorm.io/gorm v1.24.0
|
||||
)
|
||||
|
||||
replace gorm.io/gorm => ../
|
||||
|
@ -260,3 +260,46 @@ func TestJoinWithSameColumnName(t *testing.T) {
|
||||
t.Fatalf("wrong pet name")
|
||||
}
|
||||
}
|
||||
|
||||
func TestJoinArgsWithDB(t *testing.T) {
|
||||
user := *GetUser("joins-args-db", Config{Pets: 2})
|
||||
DB.Save(&user)
|
||||
|
||||
// test where
|
||||
var user1 User
|
||||
onQuery := DB.Where(&Pet{Name: "joins-args-db_pet_2"})
|
||||
if err := DB.Joins("NamedPet", onQuery).Where("users.name = ?", user.Name).First(&user1).Error; err != nil {
|
||||
t.Fatalf("Failed to load with joins on, got error: %v", err)
|
||||
}
|
||||
|
||||
AssertEqual(t, user1.NamedPet.Name, "joins-args-db_pet_2")
|
||||
|
||||
// test where and omit
|
||||
onQuery2 := DB.Where(&Pet{Name: "joins-args-db_pet_2"}).Omit("Name")
|
||||
var user2 User
|
||||
if err := DB.Joins("NamedPet", onQuery2).Where("users.name = ?", user.Name).First(&user2).Error; err != nil {
|
||||
t.Fatalf("Failed to load with joins on, got error: %v", err)
|
||||
}
|
||||
AssertEqual(t, user2.NamedPet.ID, user1.NamedPet.ID)
|
||||
AssertEqual(t, user2.NamedPet.Name, "")
|
||||
|
||||
// test where and select
|
||||
onQuery3 := DB.Where(&Pet{Name: "joins-args-db_pet_2"}).Select("Name")
|
||||
var user3 User
|
||||
if err := DB.Joins("NamedPet", onQuery3).Where("users.name = ?", user.Name).First(&user3).Error; err != nil {
|
||||
t.Fatalf("Failed to load with joins on, got error: %v", err)
|
||||
}
|
||||
AssertEqual(t, user3.NamedPet.ID, 0)
|
||||
AssertEqual(t, user3.NamedPet.Name, "joins-args-db_pet_2")
|
||||
|
||||
// test select
|
||||
onQuery4 := DB.Select("ID")
|
||||
var user4 User
|
||||
if err := DB.Joins("NamedPet", onQuery4).Where("users.name = ?", user.Name).First(&user4).Error; err != nil {
|
||||
t.Fatalf("Failed to load with joins on, got error: %v", err)
|
||||
}
|
||||
if user4.NamedPet.ID == 0 {
|
||||
t.Fatal("Pet ID can not be empty")
|
||||
}
|
||||
AssertEqual(t, user4.NamedPet.Name, "")
|
||||
}
|
||||
|
@ -1,8 +1,10 @@
|
||||
package tests_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"os"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
@ -13,6 +15,7 @@ import (
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/driver/sqlserver"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
"gorm.io/gorm/schema"
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
)
|
||||
@ -75,6 +78,44 @@ func TestMigrate(t *testing.T) {
|
||||
t.Fatalf("Failed to find index for many2many for %v %v", indexes[0], indexes[1])
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestAutoMigrateInt8PG(t *testing.T) {
|
||||
if DB.Dialector.Name() != "postgres" {
|
||||
return
|
||||
}
|
||||
|
||||
type Smallint int8
|
||||
|
||||
type MigrateInt struct {
|
||||
Int8 Smallint
|
||||
}
|
||||
|
||||
tracer := Tracer{
|
||||
Logger: DB.Config.Logger,
|
||||
Test: func(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) {
|
||||
sql, _ := fc()
|
||||
if strings.HasPrefix(sql, "ALTER TABLE \"migrate_ints\" ALTER COLUMN \"int8\" TYPE smallint") {
|
||||
t.Fatalf("shouldn't execute ALTER COLUMN TYPE if such type is already existed in DB schema: sql: %s", sql)
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
DB.Migrator().DropTable(&MigrateInt{})
|
||||
|
||||
// The first AutoMigrate to make table with field with correct type
|
||||
if err := DB.AutoMigrate(&MigrateInt{}); err != nil {
|
||||
t.Fatalf("Failed to auto migrate: error: %v", err)
|
||||
}
|
||||
|
||||
// make new session to set custom logger tracer
|
||||
session := DB.Session(&gorm.Session{Logger: tracer})
|
||||
|
||||
// The second AutoMigrate to catch an error
|
||||
if err := session.AutoMigrate(&MigrateInt{}); err != nil {
|
||||
t.Fatalf("Failed to auto migrate: error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAutoMigrateSelfReferential(t *testing.T) {
|
||||
@ -403,7 +444,7 @@ func TestMigrateColumns(t *testing.T) {
|
||||
t.Fatalf("column code unique should be correct, name: %v, column: %#v", columnType.Name(), columnType)
|
||||
}
|
||||
if v, ok := columnType.DefaultValue(); !sqlserver && (!ok || v != "hello") {
|
||||
t.Fatalf("column code default value should be correct, name: %v, column: %#v", columnType.Name(), columnType)
|
||||
t.Fatalf("column code default value should be correct, name: %v, column: %#v, default value: %v", columnType.Name(), columnType, v)
|
||||
}
|
||||
if v, ok := columnType.Comment(); !sqlite && !sqlserver && (!ok || v != "my code2") {
|
||||
t.Fatalf("column code comment should be correct, name: %v, column: %#v", columnType.Name(), columnType)
|
||||
@ -853,7 +894,7 @@ func findColumnType(dest interface{}, columnName string) (
|
||||
return
|
||||
}
|
||||
|
||||
func TestInvalidCachedPlan(t *testing.T) {
|
||||
func TestInvalidCachedPlanSimpleProtocol(t *testing.T) {
|
||||
if DB.Dialector.Name() != "postgres" {
|
||||
return
|
||||
}
|
||||
@ -888,6 +929,101 @@ func TestInvalidCachedPlan(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestInvalidCachedPlanPrepareStmt(t *testing.T) {
|
||||
if DB.Dialector.Name() != "postgres" {
|
||||
return
|
||||
}
|
||||
|
||||
db, err := gorm.Open(postgres.Open(postgresDSN), &gorm.Config{PrepareStmt: true})
|
||||
if err != nil {
|
||||
t.Errorf("Open err:%v", err)
|
||||
}
|
||||
if debug := os.Getenv("DEBUG"); debug == "true" {
|
||||
db.Logger = db.Logger.LogMode(logger.Info)
|
||||
} else if debug == "false" {
|
||||
db.Logger = db.Logger.LogMode(logger.Silent)
|
||||
}
|
||||
|
||||
type Object1 struct {
|
||||
ID uint
|
||||
}
|
||||
type Object2 struct {
|
||||
ID uint
|
||||
Field1 int `gorm:"type:int8"`
|
||||
}
|
||||
type Object3 struct {
|
||||
ID uint
|
||||
Field1 int `gorm:"type:int4"`
|
||||
}
|
||||
type Object4 struct {
|
||||
ID uint
|
||||
Field2 int
|
||||
}
|
||||
db.Migrator().DropTable("objects")
|
||||
|
||||
err = db.Table("objects").AutoMigrate(&Object1{})
|
||||
if err != nil {
|
||||
t.Errorf("AutoMigrate err:%v", err)
|
||||
}
|
||||
err = db.Table("objects").Create(&Object1{}).Error
|
||||
if err != nil {
|
||||
t.Errorf("create err:%v", err)
|
||||
}
|
||||
|
||||
// AddColumn
|
||||
err = db.Table("objects").AutoMigrate(&Object2{})
|
||||
if err != nil {
|
||||
t.Errorf("AutoMigrate err:%v", err)
|
||||
}
|
||||
|
||||
err = db.Table("objects").Take(&Object2{}).Error
|
||||
if err != nil {
|
||||
t.Errorf("take err:%v", err)
|
||||
}
|
||||
|
||||
// AlterColumn
|
||||
err = db.Table("objects").AutoMigrate(&Object3{})
|
||||
if err != nil {
|
||||
t.Errorf("AutoMigrate err:%v", err)
|
||||
}
|
||||
|
||||
err = db.Table("objects").Take(&Object3{}).Error
|
||||
if err != nil {
|
||||
t.Errorf("take err:%v", err)
|
||||
}
|
||||
|
||||
// AddColumn
|
||||
err = db.Table("objects").AutoMigrate(&Object4{})
|
||||
if err != nil {
|
||||
t.Errorf("AutoMigrate err:%v", err)
|
||||
}
|
||||
|
||||
err = db.Table("objects").Take(&Object4{}).Error
|
||||
if err != nil {
|
||||
t.Errorf("take err:%v", err)
|
||||
}
|
||||
|
||||
db.Table("objects").Migrator().RenameColumn(&Object4{}, "field2", "field3")
|
||||
if err != nil {
|
||||
t.Errorf("RenameColumn err:%v", err)
|
||||
}
|
||||
|
||||
err = db.Table("objects").Take(&Object4{}).Error
|
||||
if err != nil {
|
||||
t.Errorf("take err:%v", err)
|
||||
}
|
||||
|
||||
db.Table("objects").Migrator().DropColumn(&Object4{}, "field3")
|
||||
if err != nil {
|
||||
t.Errorf("RenameColumn err:%v", err)
|
||||
}
|
||||
|
||||
err = db.Table("objects").Take(&Object4{}).Error
|
||||
if err != nil {
|
||||
t.Errorf("take err:%v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDifferentTypeWithoutDeclaredLength(t *testing.T) {
|
||||
type DiffType struct {
|
||||
ID uint
|
||||
|
@ -7,6 +7,8 @@ import (
|
||||
"github.com/google/uuid"
|
||||
"github.com/lib/pq"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
func TestPostgresReturningIDWhichHasStringType(t *testing.T) {
|
||||
@ -110,6 +112,45 @@ func TestPostgres(t *testing.T) {
|
||||
if err := DB.First(&result, "id = ?", harumph.ID).Error; err != nil || harumph.Name != "jinzhu1" {
|
||||
t.Errorf("No error should happen, but got %v", err)
|
||||
}
|
||||
|
||||
DB.Migrator().DropTable("log_usage")
|
||||
|
||||
if err := DB.Exec(`
|
||||
CREATE TABLE public.log_usage (
|
||||
log_id bigint NOT NULL
|
||||
);
|
||||
|
||||
ALTER TABLE public.log_usage ALTER COLUMN log_id ADD GENERATED BY DEFAULT AS IDENTITY (
|
||||
SEQUENCE NAME public.log_usage_log_id_seq
|
||||
START WITH 1
|
||||
INCREMENT BY 1
|
||||
NO MINVALUE
|
||||
NO MAXVALUE
|
||||
CACHE 1
|
||||
);
|
||||
`).Error; err != nil {
|
||||
t.Fatalf("failed to create table, got error %v", err)
|
||||
}
|
||||
|
||||
columns, err := DB.Migrator().ColumnTypes("log_usage")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get columns, got error %v", err)
|
||||
}
|
||||
|
||||
hasLogID := false
|
||||
for _, column := range columns {
|
||||
if column.Name() == "log_id" {
|
||||
hasLogID = true
|
||||
autoIncrement, ok := column.AutoIncrement()
|
||||
if !ok || !autoIncrement {
|
||||
t.Fatalf("column log_id should be auto incrementment")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !hasLogID {
|
||||
t.Fatalf("failed to found column log_id")
|
||||
}
|
||||
}
|
||||
|
||||
type Post struct {
|
||||
@ -148,3 +189,68 @@ func TestMany2ManyWithDefaultValueUUID(t *testing.T) {
|
||||
t.Errorf("Failed, got error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostgresOnConstraint(t *testing.T) {
|
||||
if DB.Dialector.Name() != "postgres" {
|
||||
t.Skip()
|
||||
}
|
||||
|
||||
type Thing struct {
|
||||
gorm.Model
|
||||
SomeID string
|
||||
OtherID string
|
||||
Data string
|
||||
}
|
||||
|
||||
DB.Migrator().DropTable(&Thing{})
|
||||
DB.Migrator().CreateTable(&Thing{})
|
||||
if err := DB.Exec("ALTER TABLE things ADD CONSTRAINT some_id_other_id_unique UNIQUE (some_id, other_id)").Error; err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
thing := Thing{
|
||||
SomeID: "1234",
|
||||
OtherID: "1234",
|
||||
Data: "something",
|
||||
}
|
||||
|
||||
DB.Create(&thing)
|
||||
|
||||
thing2 := Thing{
|
||||
SomeID: "1234",
|
||||
OtherID: "1234",
|
||||
Data: "something else",
|
||||
}
|
||||
|
||||
result := DB.Clauses(clause.OnConflict{
|
||||
OnConstraint: "some_id_other_id_unique",
|
||||
UpdateAll: true,
|
||||
}).Create(&thing2)
|
||||
if result.Error != nil {
|
||||
t.Errorf("creating second thing: %v", result.Error)
|
||||
}
|
||||
|
||||
var things []Thing
|
||||
if err := DB.Find(&things).Error; err != nil {
|
||||
t.Errorf("Failed, got error: %v", err)
|
||||
}
|
||||
|
||||
if len(things) > 1 {
|
||||
t.Errorf("expected 1 thing got more")
|
||||
}
|
||||
}
|
||||
|
||||
type CompanyNew struct {
|
||||
ID int
|
||||
Name int
|
||||
}
|
||||
|
||||
func TestAlterColumnDataType(t *testing.T) {
|
||||
DB.AutoMigrate(Company{})
|
||||
|
||||
if err := DB.Table("companies").Migrator().AlterColumn(CompanyNew{}, "name"); err != nil {
|
||||
t.Fatalf("failed to alter column from string to int, got error %v", err)
|
||||
}
|
||||
|
||||
DB.AutoMigrate(Company{})
|
||||
}
|
||||
|
@ -2,8 +2,8 @@ package tests_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"errors"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@ -168,3 +168,29 @@ func TestPreparedStmtInTransaction(t *testing.T) {
|
||||
t.Errorf("Failed, got error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPreparedStmtReset(t *testing.T) {
|
||||
tx := DB.Session(&gorm.Session{PrepareStmt: true})
|
||||
|
||||
user := *GetUser("prepared_stmt_reset", Config{})
|
||||
tx = tx.Create(&user)
|
||||
|
||||
pdb, ok := tx.ConnPool.(*gorm.PreparedStmtDB)
|
||||
if !ok {
|
||||
t.Fatalf("should assign PreparedStatement Manager back to database when using PrepareStmt mode")
|
||||
}
|
||||
|
||||
pdb.Mux.Lock()
|
||||
if len(pdb.Stmts) == 0 {
|
||||
pdb.Mux.Unlock()
|
||||
t.Fatalf("prepared stmt can not be empty")
|
||||
}
|
||||
pdb.Mux.Unlock()
|
||||
|
||||
pdb.Reset()
|
||||
pdb.Mux.Lock()
|
||||
defer pdb.Mux.Unlock()
|
||||
if len(pdb.Stmts) != 0 {
|
||||
t.Fatalf("prepared stmt should be empty")
|
||||
}
|
||||
}
|
||||
|
@ -18,6 +18,8 @@ type SerializerStruct struct {
|
||||
gorm.Model
|
||||
Name []byte `gorm:"json"`
|
||||
Roles Roles `gorm:"serializer:json"`
|
||||
Roles2 *Roles `gorm:"serializer:json"`
|
||||
Roles3 *Roles `gorm:"serializer:json;not null"`
|
||||
Contracts map[string]interface{} `gorm:"serializer:json"`
|
||||
JobInfo Job `gorm:"type:bytes;serializer:gob"`
|
||||
CreatedTime int64 `gorm:"serializer:unixtime;type:time"` // store time in db, use int as field type
|
||||
@ -108,7 +110,7 @@ func TestSerializer(t *testing.T) {
|
||||
}
|
||||
|
||||
var result SerializerStruct
|
||||
if err := DB.First(&result, data.ID).Error; err != nil {
|
||||
if err := DB.Where("roles2 IS NULL AND roles3 = ?", "").First(&result, data.ID).Error; err != nil {
|
||||
t.Fatalf("failed to query data, got error %v", err)
|
||||
}
|
||||
|
||||
|
@ -367,7 +367,7 @@ func TestToSQL(t *testing.T) {
|
||||
t.Skip("Skip SQL Server for this test, because it too difference with other dialects.")
|
||||
}
|
||||
|
||||
date, _ := time.Parse("2006-01-02", "2021-10-18")
|
||||
date, _ := time.ParseInLocation("2006-01-02", "2021-10-18", time.Local)
|
||||
|
||||
// find
|
||||
sql := DB.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
||||
@ -445,6 +445,14 @@ func TestToSQL(t *testing.T) {
|
||||
if DB.Statement.DryRun || DB.DryRun {
|
||||
t.Fatal("Failed expect DB.DryRun and DB.Statement.ToSQL to be false")
|
||||
}
|
||||
|
||||
// UpdateColumns
|
||||
sql = DB.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
||||
return tx.Raw("SELECT * FROM users ?", clause.OrderBy{
|
||||
Columns: []clause.OrderByColumn{{Column: clause.Column{Name: "id", Raw: true}, Desc: true}},
|
||||
})
|
||||
})
|
||||
assertEqualSQL(t, `SELECT * FROM users ORDER BY id DESC`, sql)
|
||||
}
|
||||
|
||||
// assertEqualSQL for assert that the sql is equal, this method will ignore quote, and dialect specials.
|
||||
|
@ -5,6 +5,8 @@ import (
|
||||
"testing"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/schema"
|
||||
"gorm.io/gorm/utils/tests"
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
@ -145,3 +147,27 @@ func TestTableWithAllFields(t *testing.T) {
|
||||
|
||||
AssertEqual(t, r.Statement.Vars, []interface{}{2, 4, 1, 3})
|
||||
}
|
||||
|
||||
type UserWithTableNamer struct {
|
||||
gorm.Model
|
||||
Name string
|
||||
}
|
||||
|
||||
func (UserWithTableNamer) TableName(namer schema.Namer) string {
|
||||
return namer.TableName("user")
|
||||
}
|
||||
|
||||
func TestTableWithNamer(t *testing.T) {
|
||||
var db, _ = gorm.Open(tests.DummyDialector{}, &gorm.Config{
|
||||
NamingStrategy: schema.NamingStrategy{
|
||||
TablePrefix: "t_",
|
||||
}})
|
||||
|
||||
sql := db.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
||||
return tx.Model(&UserWithTableNamer{}).Find(&UserWithTableNamer{})
|
||||
})
|
||||
|
||||
if !regexp.MustCompile("SELECT \\* FROM `t_users`").MatchString(sql) {
|
||||
t.Errorf("Table with namer, got %v", sql)
|
||||
}
|
||||
}
|
||||
|
34
tests/tracer_test.go
Normal file
34
tests/tracer_test.go
Normal file
@ -0,0 +1,34 @@
|
||||
package tests_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
type Tracer struct {
|
||||
Logger logger.Interface
|
||||
Test func(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error)
|
||||
}
|
||||
|
||||
func (S Tracer) LogMode(level logger.LogLevel) logger.Interface {
|
||||
return S.Logger.LogMode(level)
|
||||
}
|
||||
|
||||
func (S Tracer) Info(ctx context.Context, s string, i ...interface{}) {
|
||||
S.Logger.Info(ctx, s, i...)
|
||||
}
|
||||
|
||||
func (S Tracer) Warn(ctx context.Context, s string, i ...interface{}) {
|
||||
S.Logger.Warn(ctx, s, i...)
|
||||
}
|
||||
|
||||
func (S Tracer) Error(ctx context.Context, s string, i ...interface{}) {
|
||||
S.Logger.Error(ctx, s, i...)
|
||||
}
|
||||
|
||||
func (S Tracer) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) {
|
||||
S.Logger.Trace(ctx, begin, fc, err)
|
||||
S.Test(ctx, begin, fc, err)
|
||||
}
|
@ -2,6 +2,7 @@ package tests
|
||||
|
||||
import (
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/callbacks"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/logger"
|
||||
"gorm.io/gorm/schema"
|
||||
@ -13,7 +14,14 @@ func (DummyDialector) Name() string {
|
||||
return "dummy"
|
||||
}
|
||||
|
||||
func (DummyDialector) Initialize(*gorm.DB) error {
|
||||
func (DummyDialector) Initialize(db *gorm.DB) error {
|
||||
callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{
|
||||
CreateClauses: []string{"INSERT", "VALUES", "ON CONFLICT", "RETURNING"},
|
||||
UpdateClauses: []string{"UPDATE", "SET", "WHERE", "RETURNING"},
|
||||
DeleteClauses: []string{"DELETE", "FROM", "WHERE", "RETURNING"},
|
||||
LastInsertIDReversed: true,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -16,7 +16,7 @@ var gormSourceDir string
|
||||
func init() {
|
||||
_, file, _, _ := runtime.Caller(0)
|
||||
// compatible solution to get gorm source directory with various operating systems
|
||||
gormSourceDir = regexp.MustCompile(`utils.utils\.go`).ReplaceAllString(file, "")
|
||||
gormSourceDir = regexp.MustCompile(`gorm.utils.utils\.go`).ReplaceAllString(file, "")
|
||||
}
|
||||
|
||||
// FileWithLineNum return the file name and line number of the current file
|
||||
|
Loading…
x
Reference in New Issue
Block a user