Merge remote-tracking branch 'upstream/master' into test_invalid_cache_preprare
This commit is contained in:
commit
4c5d727449
@ -208,7 +208,10 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
|
||||
|
||||
cacheKey := utils.ToStringKey(relPrimaryValues...)
|
||||
if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] {
|
||||
if cacheKey != "" { // has primary fields
|
||||
identityMap[cacheKey] = true
|
||||
}
|
||||
|
||||
if isPtr {
|
||||
elems = reflect.Append(elems, elem)
|
||||
} else {
|
||||
@ -294,7 +297,10 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
|
||||
|
||||
cacheKey := utils.ToStringKey(relPrimaryValues...)
|
||||
if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] {
|
||||
if cacheKey != "" { // has primary fields
|
||||
identityMap[cacheKey] = true
|
||||
}
|
||||
|
||||
distinctElems = reflect.Append(distinctElems, elem)
|
||||
}
|
||||
|
||||
|
@ -406,17 +406,14 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
|
||||
fullDataType := strings.TrimSpace(strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL))
|
||||
realDataType := strings.ToLower(columnType.DatabaseTypeName())
|
||||
|
||||
alterColumn := false
|
||||
var (
|
||||
alterColumn, isSameType bool
|
||||
)
|
||||
|
||||
if !field.PrimaryKey {
|
||||
// check type
|
||||
var isSameType bool
|
||||
if strings.HasPrefix(fullDataType, realDataType) {
|
||||
isSameType = true
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(fullDataType, realDataType) {
|
||||
// check type aliases
|
||||
if !isSameType {
|
||||
aliases := m.DB.Migrator().GetTypeAliases(realDataType)
|
||||
for _, alias := range aliases {
|
||||
if strings.HasPrefix(fullDataType, alias) {
|
||||
@ -424,13 +421,14 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !isSameType {
|
||||
alterColumn = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !isSameType {
|
||||
// check size
|
||||
if length, ok := columnType.Length(); length != int64(field.Size) {
|
||||
if length > 0 && field.Size > 0 {
|
||||
@ -452,6 +450,7 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
|
||||
alterColumn = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// check nullable
|
||||
if nullable, ok := columnType.Nullable(); ok && nullable == field.NotNull {
|
||||
|
@ -44,6 +44,18 @@ func (db *PreparedStmtDB) Close() {
|
||||
}
|
||||
}
|
||||
|
||||
func (db *PreparedStmtDB) Reset() {
|
||||
db.Mux.Lock()
|
||||
defer db.Mux.Unlock()
|
||||
for query, stmt := range db.Stmts {
|
||||
delete(db.Stmts, query)
|
||||
go stmt.Close()
|
||||
}
|
||||
|
||||
db.PreparedSQL = make([]string, 0, 100)
|
||||
db.Stmts = map[string](*Stmt){}
|
||||
}
|
||||
|
||||
func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) {
|
||||
db.Mux.RLock()
|
||||
if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) {
|
||||
|
@ -100,6 +100,12 @@ func (JSONSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value,
|
||||
// Value implements serializer interface
|
||||
func (JSONSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) {
|
||||
result, err := json.Marshal(fieldValue)
|
||||
if string(result) == "null" {
|
||||
if field.TagSettings["NOT NULL"] != "" {
|
||||
return "", nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return string(result), err
|
||||
}
|
||||
|
||||
|
@ -179,6 +179,10 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) {
|
||||
} else {
|
||||
stmt.AddVar(writer, v.GormValue(stmt.Context, stmt.DB))
|
||||
}
|
||||
case clause.Interface:
|
||||
c := clause.Clause{Name: v.Name()}
|
||||
v.MergeClause(&c)
|
||||
c.Build(stmt)
|
||||
case clause.Expression:
|
||||
v.Build(stmt)
|
||||
case driver.Valuer:
|
||||
|
@ -348,3 +348,45 @@ func TestAssociationEmptyQueryClause(t *testing.T) {
|
||||
AssertEqual(t, len(orgs), 0)
|
||||
}
|
||||
}
|
||||
|
||||
type AssociationEmptyUser struct {
|
||||
ID uint
|
||||
Name string
|
||||
Pets []AssociationEmptyPet
|
||||
}
|
||||
|
||||
type AssociationEmptyPet struct {
|
||||
AssociationEmptyUserID *uint `gorm:"uniqueIndex:uniq_user_id_name"`
|
||||
Name string `gorm:"uniqueIndex:uniq_user_id_name;size:256"`
|
||||
}
|
||||
|
||||
func TestAssociationEmptyPrimaryKey(t *testing.T) {
|
||||
if DB.Dialector.Name() != "mysql" {
|
||||
t.Skip()
|
||||
}
|
||||
DB.Migrator().DropTable(&AssociationEmptyUser{}, &AssociationEmptyPet{})
|
||||
DB.AutoMigrate(&AssociationEmptyUser{}, &AssociationEmptyPet{})
|
||||
|
||||
id := uint(100)
|
||||
user := AssociationEmptyUser{
|
||||
ID: id,
|
||||
Name: "jinzhu",
|
||||
Pets: []AssociationEmptyPet{
|
||||
{AssociationEmptyUserID: &id, Name: "bar"},
|
||||
{AssociationEmptyUserID: &id, Name: "foo"},
|
||||
},
|
||||
}
|
||||
|
||||
err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Create(&user).Error
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create, got error: %v", err)
|
||||
}
|
||||
|
||||
var result AssociationEmptyUser
|
||||
err = DB.Preload("Pets").First(&result, &id).Error
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to find, got error: %v", err)
|
||||
}
|
||||
|
||||
AssertEqual(t, result, user)
|
||||
}
|
||||
|
@ -7,10 +7,10 @@ require (
|
||||
github.com/jinzhu/now v1.1.5
|
||||
github.com/lib/pq v1.10.7
|
||||
golang.org/x/crypto v0.0.0-20221012134737-56aed061732a // indirect
|
||||
golang.org/x/text v0.3.8 // indirect
|
||||
golang.org/x/text v0.4.0 // indirect
|
||||
gorm.io/driver/mysql v1.4.3
|
||||
gorm.io/driver/postgres v1.4.4
|
||||
gorm.io/driver/sqlite v1.4.2
|
||||
gorm.io/driver/sqlite v1.4.3
|
||||
gorm.io/driver/sqlserver v1.4.1
|
||||
gorm.io/gorm v1.24.0
|
||||
)
|
||||
|
@ -1,6 +1,7 @@
|
||||
package tests_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"reflect"
|
||||
@ -9,6 +10,7 @@ import (
|
||||
"time"
|
||||
|
||||
"gorm.io/driver/postgres"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/schema"
|
||||
. "gorm.io/gorm/utils/tests"
|
||||
@ -72,6 +74,44 @@ func TestMigrate(t *testing.T) {
|
||||
t.Fatalf("Failed to find index for many2many for %v %v", indexes[0], indexes[1])
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestAutoMigrateInt8PG(t *testing.T) {
|
||||
if DB.Dialector.Name() != "postgres" {
|
||||
return
|
||||
}
|
||||
|
||||
type Smallint int8
|
||||
|
||||
type MigrateInt struct {
|
||||
Int8 Smallint
|
||||
}
|
||||
|
||||
tracer := Tracer{
|
||||
Logger: DB.Config.Logger,
|
||||
Test: func(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) {
|
||||
sql, _ := fc()
|
||||
if strings.HasPrefix(sql, "ALTER TABLE \"migrate_ints\" ALTER COLUMN \"int8\" TYPE smallint") {
|
||||
t.Fatalf("shouldn't execute ALTER COLUMN TYPE if such type is already existed in DB schema: sql: %s", sql)
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
DB.Migrator().DropTable(&MigrateInt{})
|
||||
|
||||
// The first AutoMigrate to make table with field with correct type
|
||||
if err := DB.AutoMigrate(&MigrateInt{}); err != nil {
|
||||
t.Fatalf("Failed to auto migrate: error: %v", err)
|
||||
}
|
||||
|
||||
// make new session to set custom logger tracer
|
||||
session := DB.Session(&gorm.Session{Logger: tracer})
|
||||
|
||||
// The second AutoMigrate to catch an error
|
||||
if err := session.AutoMigrate(&MigrateInt{}); err != nil {
|
||||
t.Fatalf("Failed to auto migrate: error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAutoMigrateSelfReferential(t *testing.T) {
|
||||
|
@ -2,8 +2,8 @@ package tests_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"errors"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@ -168,3 +168,29 @@ func TestPreparedStmtInTransaction(t *testing.T) {
|
||||
t.Errorf("Failed, got error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPreparedStmtReset(t *testing.T) {
|
||||
tx := DB.Session(&gorm.Session{PrepareStmt: true})
|
||||
|
||||
user := *GetUser("prepared_stmt_reset", Config{})
|
||||
tx = tx.Create(&user)
|
||||
|
||||
pdb, ok := tx.ConnPool.(*gorm.PreparedStmtDB)
|
||||
if !ok {
|
||||
t.Fatalf("should assign PreparedStatement Manager back to database when using PrepareStmt mode")
|
||||
}
|
||||
|
||||
pdb.Mux.Lock()
|
||||
if len(pdb.Stmts) == 0 {
|
||||
pdb.Mux.Unlock()
|
||||
t.Fatalf("prepared stmt can not be empty")
|
||||
}
|
||||
pdb.Mux.Unlock()
|
||||
|
||||
pdb.Reset()
|
||||
pdb.Mux.Lock()
|
||||
defer pdb.Mux.Unlock()
|
||||
if len(pdb.Stmts) != 0 {
|
||||
t.Fatalf("prepared stmt should be empty")
|
||||
}
|
||||
}
|
||||
|
@ -18,6 +18,8 @@ type SerializerStruct struct {
|
||||
gorm.Model
|
||||
Name []byte `gorm:"json"`
|
||||
Roles Roles `gorm:"serializer:json"`
|
||||
Roles2 *Roles `gorm:"serializer:json"`
|
||||
Roles3 *Roles `gorm:"serializer:json;not null"`
|
||||
Contracts map[string]interface{} `gorm:"serializer:json"`
|
||||
JobInfo Job `gorm:"type:bytes;serializer:gob"`
|
||||
CreatedTime int64 `gorm:"serializer:unixtime;type:time"` // store time in db, use int as field type
|
||||
@ -108,7 +110,7 @@ func TestSerializer(t *testing.T) {
|
||||
}
|
||||
|
||||
var result SerializerStruct
|
||||
if err := DB.First(&result, data.ID).Error; err != nil {
|
||||
if err := DB.Where("roles2 IS NULL AND roles3 = ?", "").First(&result, data.ID).Error; err != nil {
|
||||
t.Fatalf("failed to query data, got error %v", err)
|
||||
}
|
||||
|
||||
|
@ -445,6 +445,14 @@ func TestToSQL(t *testing.T) {
|
||||
if DB.Statement.DryRun || DB.DryRun {
|
||||
t.Fatal("Failed expect DB.DryRun and DB.Statement.ToSQL to be false")
|
||||
}
|
||||
|
||||
// UpdateColumns
|
||||
sql = DB.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
||||
return tx.Raw("SELECT * FROM users ?", clause.OrderBy{
|
||||
Columns: []clause.OrderByColumn{{Column: clause.Column{Name: "id", Raw: true}, Desc: true}},
|
||||
})
|
||||
})
|
||||
assertEqualSQL(t, `SELECT * FROM users ORDER BY id DESC`, sql)
|
||||
}
|
||||
|
||||
// assertEqualSQL for assert that the sql is equal, this method will ignore quote, and dialect specials.
|
||||
|
34
tests/tracer_test.go
Normal file
34
tests/tracer_test.go
Normal file
@ -0,0 +1,34 @@
|
||||
package tests_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
type Tracer struct {
|
||||
Logger logger.Interface
|
||||
Test func(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error)
|
||||
}
|
||||
|
||||
func (S Tracer) LogMode(level logger.LogLevel) logger.Interface {
|
||||
return S.Logger.LogMode(level)
|
||||
}
|
||||
|
||||
func (S Tracer) Info(ctx context.Context, s string, i ...interface{}) {
|
||||
S.Logger.Info(ctx, s, i...)
|
||||
}
|
||||
|
||||
func (S Tracer) Warn(ctx context.Context, s string, i ...interface{}) {
|
||||
S.Logger.Warn(ctx, s, i...)
|
||||
}
|
||||
|
||||
func (S Tracer) Error(ctx context.Context, s string, i ...interface{}) {
|
||||
S.Logger.Error(ctx, s, i...)
|
||||
}
|
||||
|
||||
func (S Tracer) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) {
|
||||
S.Logger.Trace(ctx, begin, fc, err)
|
||||
S.Test(ctx, begin, fc, err)
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user