Merge branch 'master' into test_migrate_column

This commit is contained in:
Jinzhu 2022-11-03 21:18:57 +08:00 committed by GitHub
commit 04dc915cb7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
31 changed files with 716 additions and 125 deletions

View File

@ -30,7 +30,7 @@ The fantastic ORM library for Golang, aims to be developer friendly.
## Getting Started
* GORM Guides [https://gorm.io](https://gorm.io)
* GORM Gen [gorm/gen](https://github.com/go-gorm/gen#gormgen)
* Gen Guides [https://gorm.io/gen/index.html](https://gorm.io/gen/index.html)
## Contributing

View File

@ -208,7 +208,10 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
cacheKey := utils.ToStringKey(relPrimaryValues...)
if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] {
if cacheKey != "" { // has primary fields
identityMap[cacheKey] = true
}
if isPtr {
elems = reflect.Append(elems, elem)
} else {
@ -294,7 +297,10 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
cacheKey := utils.ToStringKey(relPrimaryValues...)
if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] {
if cacheKey != "" { // has primary fields
identityMap[cacheKey] = true
}
distinctElems = reflect.Append(distinctElems, elem)
}

View File

@ -117,13 +117,21 @@ func BuildQuerySQL(db *gorm.DB) {
} else if relation, ok := db.Statement.Schema.Relationships.Relations[join.Name]; ok {
tableAliasName := relation.Name
columnStmt := gorm.Statement{
Table: tableAliasName, DB: db, Schema: relation.FieldSchema,
Selects: join.Selects, Omits: join.Omits,
}
selectColumns, restricted := columnStmt.SelectAndOmitColumns(false, false)
for _, s := range relation.FieldSchema.DBNames {
if v, ok := selectColumns[s]; (ok && v) || (!ok && !restricted) {
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{
Table: tableAliasName,
Name: s,
Alias: tableAliasName + "__" + s,
})
}
}
exprs := make([]clause.Expression, len(relation.References))
for idx, ref := range relation.References {

View File

@ -10,6 +10,7 @@ import (
)
// Model specify the model you would like to run db operations
//
// // update all users's name to `hello`
// db.Model(&User{}).Update("name", "hello")
// // if user's primary key is non-blank, will use it as condition, then will only update the user's name to `hello`
@ -179,6 +180,7 @@ func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) {
}
// Joins specify Joins conditions
//
// db.Joins("Account").Find(&user)
// db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user)
// db.Joins("Account", DB.Select("id").Where("user_id = users.id AND name = ?", "someName").Model(&Account{}))
@ -187,10 +189,12 @@ func (db *DB) Joins(query string, args ...interface{}) (tx *DB) {
if len(args) == 1 {
if db, ok := args[0].(*DB); ok {
j := join{Name: query, Conds: args, Selects: db.Statement.Selects, Omits: db.Statement.Omits}
if where, ok := db.Statement.Clauses["WHERE"].Expression.(clause.Where); ok {
tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args, On: &where})
return
j.On = &where
}
tx.Statement.Joins = append(tx.Statement.Joins, j)
return
}
}
@ -219,6 +223,7 @@ func (db *DB) Having(query interface{}, args ...interface{}) (tx *DB) {
}
// Order specify order when retrieve records from database
//
// db.Order("name DESC")
// db.Order(clause.OrderByColumn{Column: clause.Column{Name: "name"}, Desc: true})
func (db *DB) Order(value interface{}) (tx *DB) {
@ -244,7 +249,7 @@ func (db *DB) Order(value interface{}) (tx *DB) {
// Limit specify the number of records to be retrieved
func (db *DB) Limit(limit int) (tx *DB) {
tx = db.getInstance()
tx.Statement.AddClause(clause.Limit{Limit: limit})
tx.Statement.AddClause(clause.Limit{Limit: &limit})
return
}
@ -256,6 +261,7 @@ func (db *DB) Offset(offset int) (tx *DB) {
}
// Scopes pass current database connection to arguments `func(DB) DB`, which could be used to add conditions dynamically
//
// func AmountGreaterThan1000(db *gorm.DB) *gorm.DB {
// return db.Where("amount > ?", 1000)
// }
@ -274,6 +280,7 @@ func (db *DB) Scopes(funcs ...func(*DB) *DB) (tx *DB) {
}
// Preload preload associations with given conditions
//
// db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users)
func (db *DB) Preload(query string, args ...interface{}) (tx *DB) {
tx = db.getInstance()

View File

@ -29,6 +29,7 @@ func BenchmarkSelect(b *testing.B) {
func BenchmarkComplexSelect(b *testing.B) {
user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
limit10 := 10
for i := 0; i < b.N; i++ {
stmt := gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}}
clauses := []clause.Interface{
@ -43,7 +44,7 @@ func BenchmarkComplexSelect(b *testing.B) {
clause.Or(clause.Gt{Column: "score", Value: 100}, clause.Like{Column: "name", Value: "%linus%"}),
}},
clause.GroupBy{Columns: []clause.Column{{Name: "role"}}, Having: []clause.Expression{clause.Eq{"role", "admin"}}},
clause.Limit{Limit: 10, Offset: 20},
clause.Limit{Limit: &limit10, Offset: 20},
clause.OrderBy{Columns: []clause.OrderByColumn{{Column: clause.PrimaryColumn, Desc: true}}},
}

View File

@ -9,7 +9,7 @@ const (
RightJoin JoinType = "RIGHT"
)
// Join join clause for from
// Join clause for from
type Join struct {
Type JoinType
Table Table

101
clause/joins_test.go Normal file
View File

@ -0,0 +1,101 @@
package clause_test
import (
"sync"
"testing"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils/tests"
)
func TestJoin(t *testing.T) {
results := []struct {
name string
join clause.Join
sql string
}{
{
name: "LEFT JOIN",
join: clause.Join{
Type: clause.LeftJoin,
Table: clause.Table{Name: "user"},
ON: clause.Where{
Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}},
},
},
sql: "LEFT JOIN `user` ON `user_info`.`user_id` = `users`.`id`",
},
{
name: "RIGHT JOIN",
join: clause.Join{
Type: clause.RightJoin,
Table: clause.Table{Name: "user"},
ON: clause.Where{
Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}},
},
},
sql: "RIGHT JOIN `user` ON `user_info`.`user_id` = `users`.`id`",
},
{
name: "INNER JOIN",
join: clause.Join{
Type: clause.InnerJoin,
Table: clause.Table{Name: "user"},
ON: clause.Where{
Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}},
},
},
sql: "INNER JOIN `user` ON `user_info`.`user_id` = `users`.`id`",
},
{
name: "CROSS JOIN",
join: clause.Join{
Type: clause.CrossJoin,
Table: clause.Table{Name: "user"},
ON: clause.Where{
Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}},
},
},
sql: "CROSS JOIN `user` ON `user_info`.`user_id` = `users`.`id`",
},
{
name: "USING",
join: clause.Join{
Type: clause.InnerJoin,
Table: clause.Table{Name: "user"},
Using: []string{"id"},
},
sql: "INNER JOIN `user` USING (`id`)",
},
{
name: "Expression",
join: clause.Join{
// Invalid
Type: clause.LeftJoin,
Table: clause.Table{Name: "user"},
ON: clause.Where{
Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}},
},
// Valid
Expression: clause.Join{
Type: clause.InnerJoin,
Table: clause.Table{Name: "user"},
Using: []string{"id"},
},
},
sql: "INNER JOIN `user` USING (`id`)",
},
}
for _, result := range results {
t.Run(result.name, func(t *testing.T) {
user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
stmt := &gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}}
result.join.Build(stmt)
if result.sql != stmt.SQL.String() {
t.Errorf("want: %s, got: %s", result.sql, stmt.SQL.String())
}
})
}
}

View File

@ -4,7 +4,7 @@ import "strconv"
// Limit limit clause
type Limit struct {
Limit int
Limit *int
Offset int
}
@ -15,12 +15,12 @@ func (limit Limit) Name() string {
// Build build where clause
func (limit Limit) Build(builder Builder) {
if limit.Limit > 0 {
if limit.Limit != nil && *limit.Limit >= 0 {
builder.WriteString("LIMIT ")
builder.WriteString(strconv.Itoa(limit.Limit))
builder.WriteString(strconv.Itoa(*limit.Limit))
}
if limit.Offset > 0 {
if limit.Limit > 0 {
if limit.Limit != nil && *limit.Limit >= 0 {
builder.WriteByte(' ')
}
builder.WriteString("OFFSET ")
@ -33,7 +33,7 @@ func (limit Limit) MergeClause(clause *Clause) {
clause.Name = ""
if v, ok := clause.Expression.(Limit); ok {
if limit.Limit == 0 && v.Limit != 0 {
if (limit.Limit == nil || *limit.Limit == 0) && (v.Limit != nil && *v.Limit != 0) {
limit.Limit = v.Limit
}

View File

@ -8,6 +8,10 @@ import (
)
func TestLimit(t *testing.T) {
limit0 := 0
limit10 := 10
limit50 := 50
limitNeg10 := -10
results := []struct {
Clauses []clause.Interface
Result string
@ -15,11 +19,15 @@ func TestLimit(t *testing.T) {
}{
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{
Limit: 10,
Limit: &limit10,
Offset: 20,
}},
"SELECT * FROM `users` LIMIT 10 OFFSET 20", nil,
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit0}},
"SELECT * FROM `users` LIMIT 0", nil,
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}},
"SELECT * FROM `users` OFFSET 20", nil,
@ -29,23 +37,23 @@ func TestLimit(t *testing.T) {
"SELECT * FROM `users` OFFSET 30", nil,
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}, clause.Limit{Limit: 10}},
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}, clause.Limit{Limit: &limit10}},
"SELECT * FROM `users` LIMIT 10 OFFSET 20", nil,
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}},
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}},
"SELECT * FROM `users` LIMIT 10 OFFSET 30", nil,
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Offset: -10}},
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Offset: -10}},
"SELECT * FROM `users` LIMIT 10", nil,
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: -10}},
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: &limitNeg10}},
"SELECT * FROM `users` OFFSET 30", nil,
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: 10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: 50}},
[]clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: &limit50}},
"SELECT * FROM `users` LIMIT 50 OFFSET 30", nil,
},
}

View File

@ -16,6 +16,11 @@ func (OnConflict) Name() string {
// Build build onConflict clause
func (onConflict OnConflict) Build(builder Builder) {
if onConflict.OnConstraint != "" {
builder.WriteString("ON CONSTRAINT ")
builder.WriteString(onConflict.OnConstraint)
builder.WriteByte(' ')
} else {
if len(onConflict.Columns) > 0 {
builder.WriteByte('(')
for idx, column := range onConflict.Columns {
@ -32,11 +37,6 @@ func (onConflict OnConflict) Build(builder Builder) {
onConflict.TargetWhere.Build(builder)
builder.WriteByte(' ')
}
if onConflict.OnConstraint != "" {
builder.WriteString("ON CONSTRAINT ")
builder.WriteString(onConflict.OnConstraint)
builder.WriteByte(' ')
}
if onConflict.DoNothing {

View File

@ -185,7 +185,9 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat
var totalSize int
if c, ok := tx.Statement.Clauses["LIMIT"]; ok {
if limit, ok := c.Expression.(clause.Limit); ok {
totalSize = limit.Limit
if limit.Limit != nil {
totalSize = *limit.Limit
}
if totalSize > 0 && batchSize > totalSize {
batchSize = totalSize

View File

@ -406,17 +406,14 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
fullDataType := strings.TrimSpace(strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL))
realDataType := strings.ToLower(columnType.DatabaseTypeName())
alterColumn := false
var (
alterColumn, isSameType bool
)
if !field.PrimaryKey {
// check type
var isSameType bool
if strings.HasPrefix(fullDataType, realDataType) {
isSameType = true
}
if !strings.HasPrefix(fullDataType, realDataType) {
// check type aliases
if !isSameType {
aliases := m.DB.Migrator().GetTypeAliases(realDataType)
for _, alias := range aliases {
if strings.HasPrefix(fullDataType, alias) {
@ -424,13 +421,14 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
break
}
}
}
if !isSameType {
alterColumn = true
}
}
}
if !isSameType {
// check size
if length, ok := columnType.Length(); length != int64(field.Size) {
if length > 0 && field.Size > 0 {
@ -452,6 +450,7 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
alterColumn = true
}
}
}
// check nullable
if nullable, ok := columnType.Nullable(); ok && nullable == field.NotNull {

View File

@ -44,6 +44,18 @@ func (db *PreparedStmtDB) Close() {
}
}
func (db *PreparedStmtDB) Reset() {
db.Mux.Lock()
defer db.Mux.Unlock()
for query, stmt := range db.Stmts {
delete(db.Stmts, query)
go stmt.Close()
}
db.PreparedSQL = make([]string, 0, 100)
db.Stmts = map[string](*Stmt){}
}
func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) {
db.Mux.RLock()
if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) {

View File

@ -403,33 +403,30 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu
case guessBelongs:
primarySchema, foreignSchema = relation.FieldSchema, schema
case guessEmbeddedBelongs:
if field.OwnerSchema != nil {
primarySchema, foreignSchema = relation.FieldSchema, field.OwnerSchema
} else {
if field.OwnerSchema == nil {
reguessOrErr()
return
}
primarySchema, foreignSchema = relation.FieldSchema, field.OwnerSchema
case guessHas:
case guessEmbeddedHas:
if field.OwnerSchema != nil {
primarySchema, foreignSchema = field.OwnerSchema, relation.FieldSchema
} else {
if field.OwnerSchema == nil {
reguessOrErr()
return
}
primarySchema, foreignSchema = field.OwnerSchema, relation.FieldSchema
}
if len(relation.foreignKeys) > 0 {
for _, foreignKey := range relation.foreignKeys {
if f := foreignSchema.LookUpField(foreignKey); f != nil {
foreignFields = append(foreignFields, f)
} else {
f := foreignSchema.LookUpField(foreignKey)
if f == nil {
reguessOrErr()
return
}
foreignFields = append(foreignFields, f)
}
} else {
var primaryFields []*Field
var primarySchemaName = primarySchema.Name
if primarySchemaName == "" {
primarySchemaName = relation.FieldSchema.Name
@ -466,10 +463,11 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu
}
}
if len(foreignFields) == 0 {
switch {
case len(foreignFields) == 0:
reguessOrErr()
return
} else if len(relation.primaryKeys) > 0 {
case len(relation.primaryKeys) > 0:
for idx, primaryKey := range relation.primaryKeys {
if f := primarySchema.LookUpField(primaryKey); f != nil {
if len(primaryFields) < idx+1 {
@ -483,7 +481,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu
return
}
}
} else if len(primaryFields) == 0 {
case len(primaryFields) == 0:
if len(foreignFields) == 1 && primarySchema.PrioritizedPrimaryField != nil {
primaryFields = append(primaryFields, primarySchema.PrioritizedPrimaryField)
} else if len(primarySchema.PrimaryFields) == len(foreignFields) {

View File

@ -71,6 +71,10 @@ type Tabler interface {
TableName() string
}
type TablerWithNamer interface {
TableName(Namer) string
}
// Parse get data type from dialector
func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) {
return ParseWithSpecialTableName(dest, cacheStore, namer, "")
@ -125,6 +129,9 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
if tabler, ok := modelValue.Interface().(Tabler); ok {
tableName = tabler.TableName()
}
if tabler, ok := modelValue.Interface().(TablerWithNamer); ok {
tableName = tabler.TableName(namer)
}
if en, ok := namer.(embeddedNamer); ok {
tableName = en.Table
}

View File

@ -100,6 +100,12 @@ func (JSONSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value,
// Value implements serializer interface
func (JSONSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) {
result, err := json.Marshal(fieldValue)
if string(result) == "null" {
if field.TagSettings["NOT NULL"] != "" {
return "", nil
}
return nil, err
}
return string(result), err
}

View File

@ -52,6 +52,8 @@ type join struct {
Name string
Conds []interface{}
On *clause.Where
Selects []string
Omits []string
}
// StatementModifier statement modifier interface
@ -179,6 +181,10 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) {
} else {
stmt.AddVar(writer, v.GormValue(stmt.Context, stmt.DB))
}
case clause.Interface:
c := clause.Clause{Name: v.Name()}
v.MergeClause(&c)
c.Build(stmt)
case clause.Expression:
v.Build(stmt)
case driver.Valuer:
@ -540,6 +546,7 @@ func (stmt *Statement) clone() *Statement {
}
// SetColumn set column's value
//
// stmt.SetColumn("Name", "jinzhu") // Hooks Method
// stmt.SetColumn("Name", "jinzhu", true) // Callbacks Method
func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks ...bool) {

View File

@ -348,3 +348,45 @@ func TestAssociationEmptyQueryClause(t *testing.T) {
AssertEqual(t, len(orgs), 0)
}
}
type AssociationEmptyUser struct {
ID uint
Name string
Pets []AssociationEmptyPet
}
type AssociationEmptyPet struct {
AssociationEmptyUserID *uint `gorm:"uniqueIndex:uniq_user_id_name"`
Name string `gorm:"uniqueIndex:uniq_user_id_name;size:256"`
}
func TestAssociationEmptyPrimaryKey(t *testing.T) {
if DB.Dialector.Name() != "mysql" {
t.Skip()
}
DB.Migrator().DropTable(&AssociationEmptyUser{}, &AssociationEmptyPet{})
DB.AutoMigrate(&AssociationEmptyUser{}, &AssociationEmptyPet{})
id := uint(100)
user := AssociationEmptyUser{
ID: id,
Name: "jinzhu",
Pets: []AssociationEmptyPet{
{AssociationEmptyUserID: &id, Name: "bar"},
{AssociationEmptyUserID: &id, Name: "foo"},
},
}
err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Create(&user).Error
if err != nil {
t.Fatalf("Failed to create, got error: %v", err)
}
var result AssociationEmptyUser
err = DB.Preload("Pets").First(&result, &id).Error
if err != nil {
t.Fatalf("Failed to find, got error: %v", err)
}
AssertEqual(t, result, user)
}

View File

@ -116,7 +116,7 @@ func TestConnPoolWrapper(t *testing.T) {
}
}()
db, err := gorm.Open(mysql.New(mysql.Config{Conn: conn}))
db, err := gorm.Open(mysql.New(mysql.Config{Conn: conn, DisableWithReturning: true}))
if err != nil {
t.Fatalf("Should open db success, but got %v", err)
}

View File

@ -3,18 +3,16 @@ module gorm.io/gorm/tests
go 1.16
require (
github.com/denisenkom/go-mssqldb v0.12.2 // indirect
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect
github.com/google/uuid v1.3.0
github.com/jinzhu/now v1.1.5
github.com/lib/pq v1.10.7
github.com/mattn/go-sqlite3 v1.14.15 // indirect
golang.org/x/crypto v0.0.0-20220926161630-eccd6366d1be // indirect
gorm.io/driver/mysql v1.3.6
gorm.io/driver/postgres v1.3.10
gorm.io/driver/sqlite v1.3.6
gorm.io/driver/sqlserver v1.3.2
gorm.io/gorm v1.23.10
golang.org/x/crypto v0.0.0-20221012134737-56aed061732a // indirect
golang.org/x/text v0.4.0 // indirect
gorm.io/driver/mysql v1.4.3
gorm.io/driver/postgres v1.4.4
gorm.io/driver/sqlite v1.4.3
gorm.io/driver/sqlserver v1.4.1
gorm.io/gorm v1.24.0
)
replace gorm.io/gorm => ../

View File

@ -260,3 +260,46 @@ func TestJoinWithSameColumnName(t *testing.T) {
t.Fatalf("wrong pet name")
}
}
func TestJoinArgsWithDB(t *testing.T) {
user := *GetUser("joins-args-db", Config{Pets: 2})
DB.Save(&user)
// test where
var user1 User
onQuery := DB.Where(&Pet{Name: "joins-args-db_pet_2"})
if err := DB.Joins("NamedPet", onQuery).Where("users.name = ?", user.Name).First(&user1).Error; err != nil {
t.Fatalf("Failed to load with joins on, got error: %v", err)
}
AssertEqual(t, user1.NamedPet.Name, "joins-args-db_pet_2")
// test where and omit
onQuery2 := DB.Where(&Pet{Name: "joins-args-db_pet_2"}).Omit("Name")
var user2 User
if err := DB.Joins("NamedPet", onQuery2).Where("users.name = ?", user.Name).First(&user2).Error; err != nil {
t.Fatalf("Failed to load with joins on, got error: %v", err)
}
AssertEqual(t, user2.NamedPet.ID, user1.NamedPet.ID)
AssertEqual(t, user2.NamedPet.Name, "")
// test where and select
onQuery3 := DB.Where(&Pet{Name: "joins-args-db_pet_2"}).Select("Name")
var user3 User
if err := DB.Joins("NamedPet", onQuery3).Where("users.name = ?", user.Name).First(&user3).Error; err != nil {
t.Fatalf("Failed to load with joins on, got error: %v", err)
}
AssertEqual(t, user3.NamedPet.ID, 0)
AssertEqual(t, user3.NamedPet.Name, "joins-args-db_pet_2")
// test select
onQuery4 := DB.Select("ID")
var user4 User
if err := DB.Joins("NamedPet", onQuery4).Where("users.name = ?", user.Name).First(&user4).Error; err != nil {
t.Fatalf("Failed to load with joins on, got error: %v", err)
}
if user4.NamedPet.ID == 0 {
t.Fatal("Pet ID can not be empty")
}
AssertEqual(t, user4.NamedPet.Name, "")
}

View File

@ -1,8 +1,10 @@
package tests_test
import (
"context"
"fmt"
"math/rand"
"os"
"reflect"
"strings"
"testing"
@ -13,6 +15,7 @@ import (
"gorm.io/driver/sqlite"
"gorm.io/driver/sqlserver"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"gorm.io/gorm/schema"
. "gorm.io/gorm/utils/tests"
)
@ -75,6 +78,44 @@ func TestMigrate(t *testing.T) {
t.Fatalf("Failed to find index for many2many for %v %v", indexes[0], indexes[1])
}
}
}
func TestAutoMigrateInt8PG(t *testing.T) {
if DB.Dialector.Name() != "postgres" {
return
}
type Smallint int8
type MigrateInt struct {
Int8 Smallint
}
tracer := Tracer{
Logger: DB.Config.Logger,
Test: func(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) {
sql, _ := fc()
if strings.HasPrefix(sql, "ALTER TABLE \"migrate_ints\" ALTER COLUMN \"int8\" TYPE smallint") {
t.Fatalf("shouldn't execute ALTER COLUMN TYPE if such type is already existed in DB schema: sql: %s", sql)
}
},
}
DB.Migrator().DropTable(&MigrateInt{})
// The first AutoMigrate to make table with field with correct type
if err := DB.AutoMigrate(&MigrateInt{}); err != nil {
t.Fatalf("Failed to auto migrate: error: %v", err)
}
// make new session to set custom logger tracer
session := DB.Session(&gorm.Session{Logger: tracer})
// The second AutoMigrate to catch an error
if err := session.AutoMigrate(&MigrateInt{}); err != nil {
t.Fatalf("Failed to auto migrate: error: %v", err)
}
}
func TestAutoMigrateSelfReferential(t *testing.T) {
@ -403,7 +444,7 @@ func TestMigrateColumns(t *testing.T) {
t.Fatalf("column code unique should be correct, name: %v, column: %#v", columnType.Name(), columnType)
}
if v, ok := columnType.DefaultValue(); !sqlserver && (!ok || v != "hello") {
t.Fatalf("column code default value should be correct, name: %v, column: %#v", columnType.Name(), columnType)
t.Fatalf("column code default value should be correct, name: %v, column: %#v, default value: %v", columnType.Name(), columnType, v)
}
if v, ok := columnType.Comment(); !sqlite && !sqlserver && (!ok || v != "my code2") {
t.Fatalf("column code comment should be correct, name: %v, column: %#v", columnType.Name(), columnType)
@ -853,7 +894,7 @@ func findColumnType(dest interface{}, columnName string) (
return
}
func TestInvalidCachedPlan(t *testing.T) {
func TestInvalidCachedPlanSimpleProtocol(t *testing.T) {
if DB.Dialector.Name() != "postgres" {
return
}
@ -888,6 +929,101 @@ func TestInvalidCachedPlan(t *testing.T) {
}
}
func TestInvalidCachedPlanPrepareStmt(t *testing.T) {
if DB.Dialector.Name() != "postgres" {
return
}
db, err := gorm.Open(postgres.Open(postgresDSN), &gorm.Config{PrepareStmt: true})
if err != nil {
t.Errorf("Open err:%v", err)
}
if debug := os.Getenv("DEBUG"); debug == "true" {
db.Logger = db.Logger.LogMode(logger.Info)
} else if debug == "false" {
db.Logger = db.Logger.LogMode(logger.Silent)
}
type Object1 struct {
ID uint
}
type Object2 struct {
ID uint
Field1 int `gorm:"type:int8"`
}
type Object3 struct {
ID uint
Field1 int `gorm:"type:int4"`
}
type Object4 struct {
ID uint
Field2 int
}
db.Migrator().DropTable("objects")
err = db.Table("objects").AutoMigrate(&Object1{})
if err != nil {
t.Errorf("AutoMigrate err:%v", err)
}
err = db.Table("objects").Create(&Object1{}).Error
if err != nil {
t.Errorf("create err:%v", err)
}
// AddColumn
err = db.Table("objects").AutoMigrate(&Object2{})
if err != nil {
t.Errorf("AutoMigrate err:%v", err)
}
err = db.Table("objects").Take(&Object2{}).Error
if err != nil {
t.Errorf("take err:%v", err)
}
// AlterColumn
err = db.Table("objects").AutoMigrate(&Object3{})
if err != nil {
t.Errorf("AutoMigrate err:%v", err)
}
err = db.Table("objects").Take(&Object3{}).Error
if err != nil {
t.Errorf("take err:%v", err)
}
// AddColumn
err = db.Table("objects").AutoMigrate(&Object4{})
if err != nil {
t.Errorf("AutoMigrate err:%v", err)
}
err = db.Table("objects").Take(&Object4{}).Error
if err != nil {
t.Errorf("take err:%v", err)
}
db.Table("objects").Migrator().RenameColumn(&Object4{}, "field2", "field3")
if err != nil {
t.Errorf("RenameColumn err:%v", err)
}
err = db.Table("objects").Take(&Object4{}).Error
if err != nil {
t.Errorf("take err:%v", err)
}
db.Table("objects").Migrator().DropColumn(&Object4{}, "field3")
if err != nil {
t.Errorf("RenameColumn err:%v", err)
}
err = db.Table("objects").Take(&Object4{}).Error
if err != nil {
t.Errorf("take err:%v", err)
}
}
func TestDifferentTypeWithoutDeclaredLength(t *testing.T) {
type DiffType struct {
ID uint

View File

@ -7,6 +7,8 @@ import (
"github.com/google/uuid"
"github.com/lib/pq"
"gorm.io/gorm"
"gorm.io/gorm/clause"
. "gorm.io/gorm/utils/tests"
)
func TestPostgresReturningIDWhichHasStringType(t *testing.T) {
@ -110,6 +112,45 @@ func TestPostgres(t *testing.T) {
if err := DB.First(&result, "id = ?", harumph.ID).Error; err != nil || harumph.Name != "jinzhu1" {
t.Errorf("No error should happen, but got %v", err)
}
DB.Migrator().DropTable("log_usage")
if err := DB.Exec(`
CREATE TABLE public.log_usage (
log_id bigint NOT NULL
);
ALTER TABLE public.log_usage ALTER COLUMN log_id ADD GENERATED BY DEFAULT AS IDENTITY (
SEQUENCE NAME public.log_usage_log_id_seq
START WITH 1
INCREMENT BY 1
NO MINVALUE
NO MAXVALUE
CACHE 1
);
`).Error; err != nil {
t.Fatalf("failed to create table, got error %v", err)
}
columns, err := DB.Migrator().ColumnTypes("log_usage")
if err != nil {
t.Fatalf("failed to get columns, got error %v", err)
}
hasLogID := false
for _, column := range columns {
if column.Name() == "log_id" {
hasLogID = true
autoIncrement, ok := column.AutoIncrement()
if !ok || !autoIncrement {
t.Fatalf("column log_id should be auto incrementment")
}
}
}
if !hasLogID {
t.Fatalf("failed to found column log_id")
}
}
type Post struct {
@ -148,3 +189,68 @@ func TestMany2ManyWithDefaultValueUUID(t *testing.T) {
t.Errorf("Failed, got error: %v", err)
}
}
func TestPostgresOnConstraint(t *testing.T) {
if DB.Dialector.Name() != "postgres" {
t.Skip()
}
type Thing struct {
gorm.Model
SomeID string
OtherID string
Data string
}
DB.Migrator().DropTable(&Thing{})
DB.Migrator().CreateTable(&Thing{})
if err := DB.Exec("ALTER TABLE things ADD CONSTRAINT some_id_other_id_unique UNIQUE (some_id, other_id)").Error; err != nil {
t.Error(err)
}
thing := Thing{
SomeID: "1234",
OtherID: "1234",
Data: "something",
}
DB.Create(&thing)
thing2 := Thing{
SomeID: "1234",
OtherID: "1234",
Data: "something else",
}
result := DB.Clauses(clause.OnConflict{
OnConstraint: "some_id_other_id_unique",
UpdateAll: true,
}).Create(&thing2)
if result.Error != nil {
t.Errorf("creating second thing: %v", result.Error)
}
var things []Thing
if err := DB.Find(&things).Error; err != nil {
t.Errorf("Failed, got error: %v", err)
}
if len(things) > 1 {
t.Errorf("expected 1 thing got more")
}
}
type CompanyNew struct {
ID int
Name int
}
func TestAlterColumnDataType(t *testing.T) {
DB.AutoMigrate(Company{})
if err := DB.Table("companies").Migrator().AlterColumn(CompanyNew{}, "name"); err != nil {
t.Fatalf("failed to alter column from string to int, got error %v", err)
}
DB.AutoMigrate(Company{})
}

View File

@ -2,8 +2,8 @@ package tests_test
import (
"context"
"sync"
"errors"
"sync"
"testing"
"time"
@ -168,3 +168,29 @@ func TestPreparedStmtInTransaction(t *testing.T) {
t.Errorf("Failed, got error: %v", err)
}
}
func TestPreparedStmtReset(t *testing.T) {
tx := DB.Session(&gorm.Session{PrepareStmt: true})
user := *GetUser("prepared_stmt_reset", Config{})
tx = tx.Create(&user)
pdb, ok := tx.ConnPool.(*gorm.PreparedStmtDB)
if !ok {
t.Fatalf("should assign PreparedStatement Manager back to database when using PrepareStmt mode")
}
pdb.Mux.Lock()
if len(pdb.Stmts) == 0 {
pdb.Mux.Unlock()
t.Fatalf("prepared stmt can not be empty")
}
pdb.Mux.Unlock()
pdb.Reset()
pdb.Mux.Lock()
defer pdb.Mux.Unlock()
if len(pdb.Stmts) != 0 {
t.Fatalf("prepared stmt should be empty")
}
}

View File

@ -18,6 +18,8 @@ type SerializerStruct struct {
gorm.Model
Name []byte `gorm:"json"`
Roles Roles `gorm:"serializer:json"`
Roles2 *Roles `gorm:"serializer:json"`
Roles3 *Roles `gorm:"serializer:json;not null"`
Contracts map[string]interface{} `gorm:"serializer:json"`
JobInfo Job `gorm:"type:bytes;serializer:gob"`
CreatedTime int64 `gorm:"serializer:unixtime;type:time"` // store time in db, use int as field type
@ -108,7 +110,7 @@ func TestSerializer(t *testing.T) {
}
var result SerializerStruct
if err := DB.First(&result, data.ID).Error; err != nil {
if err := DB.Where("roles2 IS NULL AND roles3 = ?", "").First(&result, data.ID).Error; err != nil {
t.Fatalf("failed to query data, got error %v", err)
}

View File

@ -367,7 +367,7 @@ func TestToSQL(t *testing.T) {
t.Skip("Skip SQL Server for this test, because it too difference with other dialects.")
}
date, _ := time.Parse("2006-01-02", "2021-10-18")
date, _ := time.ParseInLocation("2006-01-02", "2021-10-18", time.Local)
// find
sql := DB.ToSQL(func(tx *gorm.DB) *gorm.DB {
@ -445,6 +445,14 @@ func TestToSQL(t *testing.T) {
if DB.Statement.DryRun || DB.DryRun {
t.Fatal("Failed expect DB.DryRun and DB.Statement.ToSQL to be false")
}
// UpdateColumns
sql = DB.ToSQL(func(tx *gorm.DB) *gorm.DB {
return tx.Raw("SELECT * FROM users ?", clause.OrderBy{
Columns: []clause.OrderByColumn{{Column: clause.Column{Name: "id", Raw: true}, Desc: true}},
})
})
assertEqualSQL(t, `SELECT * FROM users ORDER BY id DESC`, sql)
}
// assertEqualSQL for assert that the sql is equal, this method will ignore quote, and dialect specials.

View File

@ -5,6 +5,8 @@ import (
"testing"
"gorm.io/gorm"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils/tests"
. "gorm.io/gorm/utils/tests"
)
@ -145,3 +147,27 @@ func TestTableWithAllFields(t *testing.T) {
AssertEqual(t, r.Statement.Vars, []interface{}{2, 4, 1, 3})
}
type UserWithTableNamer struct {
gorm.Model
Name string
}
func (UserWithTableNamer) TableName(namer schema.Namer) string {
return namer.TableName("user")
}
func TestTableWithNamer(t *testing.T) {
var db, _ = gorm.Open(tests.DummyDialector{}, &gorm.Config{
NamingStrategy: schema.NamingStrategy{
TablePrefix: "t_",
}})
sql := db.ToSQL(func(tx *gorm.DB) *gorm.DB {
return tx.Model(&UserWithTableNamer{}).Find(&UserWithTableNamer{})
})
if !regexp.MustCompile("SELECT \\* FROM `t_users`").MatchString(sql) {
t.Errorf("Table with namer, got %v", sql)
}
}

34
tests/tracer_test.go Normal file
View File

@ -0,0 +1,34 @@
package tests_test
import (
"context"
"time"
"gorm.io/gorm/logger"
)
type Tracer struct {
Logger logger.Interface
Test func(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error)
}
func (S Tracer) LogMode(level logger.LogLevel) logger.Interface {
return S.Logger.LogMode(level)
}
func (S Tracer) Info(ctx context.Context, s string, i ...interface{}) {
S.Logger.Info(ctx, s, i...)
}
func (S Tracer) Warn(ctx context.Context, s string, i ...interface{}) {
S.Logger.Warn(ctx, s, i...)
}
func (S Tracer) Error(ctx context.Context, s string, i ...interface{}) {
S.Logger.Error(ctx, s, i...)
}
func (S Tracer) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) {
S.Logger.Trace(ctx, begin, fc, err)
S.Test(ctx, begin, fc, err)
}

View File

@ -2,6 +2,7 @@ package tests
import (
"gorm.io/gorm"
"gorm.io/gorm/callbacks"
"gorm.io/gorm/clause"
"gorm.io/gorm/logger"
"gorm.io/gorm/schema"
@ -13,7 +14,14 @@ func (DummyDialector) Name() string {
return "dummy"
}
func (DummyDialector) Initialize(*gorm.DB) error {
func (DummyDialector) Initialize(db *gorm.DB) error {
callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{
CreateClauses: []string{"INSERT", "VALUES", "ON CONFLICT", "RETURNING"},
UpdateClauses: []string{"UPDATE", "SET", "WHERE", "RETURNING"},
DeleteClauses: []string{"DELETE", "FROM", "WHERE", "RETURNING"},
LastInsertIDReversed: true,
})
return nil
}

View File

@ -16,7 +16,7 @@ var gormSourceDir string
func init() {
_, file, _, _ := runtime.Caller(0)
// compatible solution to get gorm source directory with various operating systems
gormSourceDir = regexp.MustCompile(`utils.utils\.go`).ReplaceAllString(file, "")
gormSourceDir = regexp.MustCompile(`gorm.utils.utils\.go`).ReplaceAllString(file, "")
}
// FileWithLineNum return the file name and line number of the current file