Merge remote-tracking branch 'upstream/master' into test_invalid_cache_preprare

This commit is contained in:
a631807682 2022-10-19 15:08:09 +08:00
commit 4c5d727449
No known key found for this signature in database
GPG Key ID: 137D1D75522168AB
12 changed files with 213 additions and 34 deletions

View File

@ -208,7 +208,10 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
cacheKey := utils.ToStringKey(relPrimaryValues...) cacheKey := utils.ToStringKey(relPrimaryValues...)
if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] { if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] {
identityMap[cacheKey] = true if cacheKey != "" { // has primary fields
identityMap[cacheKey] = true
}
if isPtr { if isPtr {
elems = reflect.Append(elems, elem) elems = reflect.Append(elems, elem)
} else { } else {
@ -294,7 +297,10 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
cacheKey := utils.ToStringKey(relPrimaryValues...) cacheKey := utils.ToStringKey(relPrimaryValues...)
if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] { if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] {
identityMap[cacheKey] = true if cacheKey != "" { // has primary fields
identityMap[cacheKey] = true
}
distinctElems = reflect.Append(distinctElems, elem) distinctElems = reflect.Append(distinctElems, elem)
} }

View File

@ -406,17 +406,14 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
fullDataType := strings.TrimSpace(strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL)) fullDataType := strings.TrimSpace(strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL))
realDataType := strings.ToLower(columnType.DatabaseTypeName()) realDataType := strings.ToLower(columnType.DatabaseTypeName())
alterColumn := false var (
alterColumn, isSameType bool
)
if !field.PrimaryKey { if !field.PrimaryKey {
// check type // check type
var isSameType bool if !strings.HasPrefix(fullDataType, realDataType) {
if strings.HasPrefix(fullDataType, realDataType) { // check type aliases
isSameType = true
}
// check type aliases
if !isSameType {
aliases := m.DB.Migrator().GetTypeAliases(realDataType) aliases := m.DB.Migrator().GetTypeAliases(realDataType)
for _, alias := range aliases { for _, alias := range aliases {
if strings.HasPrefix(fullDataType, alias) { if strings.HasPrefix(fullDataType, alias) {
@ -424,32 +421,34 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
break break
} }
} }
}
if !isSameType { if !isSameType {
alterColumn = true
}
}
// check size
if length, ok := columnType.Length(); length != int64(field.Size) {
if length > 0 && field.Size > 0 {
alterColumn = true
} else {
// has size in data type and not equal
// Since the following code is frequently called in the for loop, reg optimization is needed here
matches2 := regFullDataType.FindAllStringSubmatch(fullDataType, -1)
if !field.PrimaryKey &&
(len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length) && ok) {
alterColumn = true alterColumn = true
} }
} }
} }
// check precision if !isSameType {
if precision, _, ok := columnType.DecimalSize(); ok && int64(field.Precision) != precision { // check size
if regexp.MustCompile(fmt.Sprintf("[^0-9]%d[^0-9]", field.Precision)).MatchString(m.DataTypeOf(field)) { if length, ok := columnType.Length(); length != int64(field.Size) {
alterColumn = true if length > 0 && field.Size > 0 {
alterColumn = true
} else {
// has size in data type and not equal
// Since the following code is frequently called in the for loop, reg optimization is needed here
matches2 := regFullDataType.FindAllStringSubmatch(fullDataType, -1)
if !field.PrimaryKey &&
(len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length) && ok) {
alterColumn = true
}
}
}
// check precision
if precision, _, ok := columnType.DecimalSize(); ok && int64(field.Precision) != precision {
if regexp.MustCompile(fmt.Sprintf("[^0-9]%d[^0-9]", field.Precision)).MatchString(m.DataTypeOf(field)) {
alterColumn = true
}
} }
} }

View File

@ -44,6 +44,18 @@ func (db *PreparedStmtDB) Close() {
} }
} }
func (db *PreparedStmtDB) Reset() {
db.Mux.Lock()
defer db.Mux.Unlock()
for query, stmt := range db.Stmts {
delete(db.Stmts, query)
go stmt.Close()
}
db.PreparedSQL = make([]string, 0, 100)
db.Stmts = map[string](*Stmt){}
}
func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) { func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) {
db.Mux.RLock() db.Mux.RLock()
if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) { if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) {

View File

@ -100,6 +100,12 @@ func (JSONSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value,
// Value implements serializer interface // Value implements serializer interface
func (JSONSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) { func (JSONSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) {
result, err := json.Marshal(fieldValue) result, err := json.Marshal(fieldValue)
if string(result) == "null" {
if field.TagSettings["NOT NULL"] != "" {
return "", nil
}
return nil, err
}
return string(result), err return string(result), err
} }

View File

@ -179,6 +179,10 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) {
} else { } else {
stmt.AddVar(writer, v.GormValue(stmt.Context, stmt.DB)) stmt.AddVar(writer, v.GormValue(stmt.Context, stmt.DB))
} }
case clause.Interface:
c := clause.Clause{Name: v.Name()}
v.MergeClause(&c)
c.Build(stmt)
case clause.Expression: case clause.Expression:
v.Build(stmt) v.Build(stmt)
case driver.Valuer: case driver.Valuer:

View File

@ -348,3 +348,45 @@ func TestAssociationEmptyQueryClause(t *testing.T) {
AssertEqual(t, len(orgs), 0) AssertEqual(t, len(orgs), 0)
} }
} }
type AssociationEmptyUser struct {
ID uint
Name string
Pets []AssociationEmptyPet
}
type AssociationEmptyPet struct {
AssociationEmptyUserID *uint `gorm:"uniqueIndex:uniq_user_id_name"`
Name string `gorm:"uniqueIndex:uniq_user_id_name;size:256"`
}
func TestAssociationEmptyPrimaryKey(t *testing.T) {
if DB.Dialector.Name() != "mysql" {
t.Skip()
}
DB.Migrator().DropTable(&AssociationEmptyUser{}, &AssociationEmptyPet{})
DB.AutoMigrate(&AssociationEmptyUser{}, &AssociationEmptyPet{})
id := uint(100)
user := AssociationEmptyUser{
ID: id,
Name: "jinzhu",
Pets: []AssociationEmptyPet{
{AssociationEmptyUserID: &id, Name: "bar"},
{AssociationEmptyUserID: &id, Name: "foo"},
},
}
err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Create(&user).Error
if err != nil {
t.Fatalf("Failed to create, got error: %v", err)
}
var result AssociationEmptyUser
err = DB.Preload("Pets").First(&result, &id).Error
if err != nil {
t.Fatalf("Failed to find, got error: %v", err)
}
AssertEqual(t, result, user)
}

View File

@ -7,10 +7,10 @@ require (
github.com/jinzhu/now v1.1.5 github.com/jinzhu/now v1.1.5
github.com/lib/pq v1.10.7 github.com/lib/pq v1.10.7
golang.org/x/crypto v0.0.0-20221012134737-56aed061732a // indirect golang.org/x/crypto v0.0.0-20221012134737-56aed061732a // indirect
golang.org/x/text v0.3.8 // indirect golang.org/x/text v0.4.0 // indirect
gorm.io/driver/mysql v1.4.3 gorm.io/driver/mysql v1.4.3
gorm.io/driver/postgres v1.4.4 gorm.io/driver/postgres v1.4.4
gorm.io/driver/sqlite v1.4.2 gorm.io/driver/sqlite v1.4.3
gorm.io/driver/sqlserver v1.4.1 gorm.io/driver/sqlserver v1.4.1
gorm.io/gorm v1.24.0 gorm.io/gorm v1.24.0
) )

View File

@ -1,6 +1,7 @@
package tests_test package tests_test
import ( import (
"context"
"fmt" "fmt"
"math/rand" "math/rand"
"reflect" "reflect"
@ -9,6 +10,7 @@ import (
"time" "time"
"gorm.io/driver/postgres" "gorm.io/driver/postgres"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/schema" "gorm.io/gorm/schema"
. "gorm.io/gorm/utils/tests" . "gorm.io/gorm/utils/tests"
@ -72,6 +74,44 @@ func TestMigrate(t *testing.T) {
t.Fatalf("Failed to find index for many2many for %v %v", indexes[0], indexes[1]) t.Fatalf("Failed to find index for many2many for %v %v", indexes[0], indexes[1])
} }
} }
}
func TestAutoMigrateInt8PG(t *testing.T) {
if DB.Dialector.Name() != "postgres" {
return
}
type Smallint int8
type MigrateInt struct {
Int8 Smallint
}
tracer := Tracer{
Logger: DB.Config.Logger,
Test: func(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) {
sql, _ := fc()
if strings.HasPrefix(sql, "ALTER TABLE \"migrate_ints\" ALTER COLUMN \"int8\" TYPE smallint") {
t.Fatalf("shouldn't execute ALTER COLUMN TYPE if such type is already existed in DB schema: sql: %s", sql)
}
},
}
DB.Migrator().DropTable(&MigrateInt{})
// The first AutoMigrate to make table with field with correct type
if err := DB.AutoMigrate(&MigrateInt{}); err != nil {
t.Fatalf("Failed to auto migrate: error: %v", err)
}
// make new session to set custom logger tracer
session := DB.Session(&gorm.Session{Logger: tracer})
// The second AutoMigrate to catch an error
if err := session.AutoMigrate(&MigrateInt{}); err != nil {
t.Fatalf("Failed to auto migrate: error: %v", err)
}
} }
func TestAutoMigrateSelfReferential(t *testing.T) { func TestAutoMigrateSelfReferential(t *testing.T) {

View File

@ -2,8 +2,8 @@ package tests_test
import ( import (
"context" "context"
"sync"
"errors" "errors"
"sync"
"testing" "testing"
"time" "time"
@ -168,3 +168,29 @@ func TestPreparedStmtInTransaction(t *testing.T) {
t.Errorf("Failed, got error: %v", err) t.Errorf("Failed, got error: %v", err)
} }
} }
func TestPreparedStmtReset(t *testing.T) {
tx := DB.Session(&gorm.Session{PrepareStmt: true})
user := *GetUser("prepared_stmt_reset", Config{})
tx = tx.Create(&user)
pdb, ok := tx.ConnPool.(*gorm.PreparedStmtDB)
if !ok {
t.Fatalf("should assign PreparedStatement Manager back to database when using PrepareStmt mode")
}
pdb.Mux.Lock()
if len(pdb.Stmts) == 0 {
pdb.Mux.Unlock()
t.Fatalf("prepared stmt can not be empty")
}
pdb.Mux.Unlock()
pdb.Reset()
pdb.Mux.Lock()
defer pdb.Mux.Unlock()
if len(pdb.Stmts) != 0 {
t.Fatalf("prepared stmt should be empty")
}
}

View File

@ -18,6 +18,8 @@ type SerializerStruct struct {
gorm.Model gorm.Model
Name []byte `gorm:"json"` Name []byte `gorm:"json"`
Roles Roles `gorm:"serializer:json"` Roles Roles `gorm:"serializer:json"`
Roles2 *Roles `gorm:"serializer:json"`
Roles3 *Roles `gorm:"serializer:json;not null"`
Contracts map[string]interface{} `gorm:"serializer:json"` Contracts map[string]interface{} `gorm:"serializer:json"`
JobInfo Job `gorm:"type:bytes;serializer:gob"` JobInfo Job `gorm:"type:bytes;serializer:gob"`
CreatedTime int64 `gorm:"serializer:unixtime;type:time"` // store time in db, use int as field type CreatedTime int64 `gorm:"serializer:unixtime;type:time"` // store time in db, use int as field type
@ -108,7 +110,7 @@ func TestSerializer(t *testing.T) {
} }
var result SerializerStruct var result SerializerStruct
if err := DB.First(&result, data.ID).Error; err != nil { if err := DB.Where("roles2 IS NULL AND roles3 = ?", "").First(&result, data.ID).Error; err != nil {
t.Fatalf("failed to query data, got error %v", err) t.Fatalf("failed to query data, got error %v", err)
} }

View File

@ -445,6 +445,14 @@ func TestToSQL(t *testing.T) {
if DB.Statement.DryRun || DB.DryRun { if DB.Statement.DryRun || DB.DryRun {
t.Fatal("Failed expect DB.DryRun and DB.Statement.ToSQL to be false") t.Fatal("Failed expect DB.DryRun and DB.Statement.ToSQL to be false")
} }
// UpdateColumns
sql = DB.ToSQL(func(tx *gorm.DB) *gorm.DB {
return tx.Raw("SELECT * FROM users ?", clause.OrderBy{
Columns: []clause.OrderByColumn{{Column: clause.Column{Name: "id", Raw: true}, Desc: true}},
})
})
assertEqualSQL(t, `SELECT * FROM users ORDER BY id DESC`, sql)
} }
// assertEqualSQL for assert that the sql is equal, this method will ignore quote, and dialect specials. // assertEqualSQL for assert that the sql is equal, this method will ignore quote, and dialect specials.

34
tests/tracer_test.go Normal file
View File

@ -0,0 +1,34 @@
package tests_test
import (
"context"
"time"
"gorm.io/gorm/logger"
)
type Tracer struct {
Logger logger.Interface
Test func(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error)
}
func (S Tracer) LogMode(level logger.LogLevel) logger.Interface {
return S.Logger.LogMode(level)
}
func (S Tracer) Info(ctx context.Context, s string, i ...interface{}) {
S.Logger.Info(ctx, s, i...)
}
func (S Tracer) Warn(ctx context.Context, s string, i ...interface{}) {
S.Logger.Warn(ctx, s, i...)
}
func (S Tracer) Error(ctx context.Context, s string, i ...interface{}) {
S.Logger.Error(ctx, s, i...)
}
func (S Tracer) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) {
S.Logger.Trace(ctx, begin, fc, err)
S.Test(ctx, begin, fc, err)
}