Fix schema.ParseWithSchemaTable method for only use schemaTable in migrator and improve test.

This commit is contained in:
Jason Lee 2021-10-20 20:48:17 +08:00
parent 09483e8928
commit b1626b1b46
6 changed files with 37 additions and 25 deletions

View File

@ -43,7 +43,7 @@ func (m Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error
if table, ok := value.(string); ok {
stmt.Table = table
} else if err := stmt.Parse(value); err != nil {
} else if err := stmt.ParseWithTableName(value, stmt.Table); err != nil {
return err
}

View File

@ -107,7 +107,17 @@ func parse(dest interface{}, cacheStore *sync.Map, namer Namer, schemaTable stri
return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name())
}
if v, ok := cacheStore.Load(modelType); ok {
// Cache the Schema for performance,
// Use the modelType or modelType + schemaTable (if it present) as cache key.
var schemaCacheKey interface{}
if schemaTable != "" {
schemaCacheKey = fmt.Sprintf("%p-%s", &modelType, schemaTable)
} else {
schemaCacheKey = modelType
}
// Load exist schmema cache, return if exists
if v, ok := cacheStore.Load(schemaCacheKey); ok {
s := v.(*Schema)
// Wait for the initialization of other goroutines to complete
<-s.initialized
@ -115,18 +125,15 @@ func parse(dest interface{}, cacheStore *sync.Map, namer Namer, schemaTable stri
}
modelValue := reflect.New(modelType)
// schemaTable for assignment table name directly
tableName := schemaTable
if schemaTable == "" {
tableName = namer.TableName(modelType.Name())
if tabler, ok := modelValue.Interface().(Tabler); ok {
tableName = tabler.TableName()
}
if en, ok := namer.(embeddedNamer); ok {
tableName = en.Table
}
tableName := namer.TableName(modelType.Name())
if tabler, ok := modelValue.Interface().(Tabler); ok {
tableName = tabler.TableName()
}
if en, ok := namer.(embeddedNamer); ok {
tableName = en.Table
}
if schemaTable != "" && schemaTable != tableName {
tableName = schemaTable
}
schema := &Schema{
@ -143,7 +150,8 @@ func parse(dest interface{}, cacheStore *sync.Map, namer Namer, schemaTable stri
// When the schema initialization is completed, the channel will be closed
defer close(schema.initialized)
if v, loaded := cacheStore.Load(modelType); loaded {
// Load exist schmema cache, return if exists
if v, ok := cacheStore.Load(schemaCacheKey); ok {
s := v.(*Schema)
// Wait for the initialization of other goroutines to complete
<-s.initialized
@ -250,13 +258,12 @@ func parse(dest interface{}, cacheStore *sync.Map, namer Namer, schemaTable stri
}
}
if schemaTable == "" {
if v, loaded := cacheStore.LoadOrStore(modelType, schema); loaded {
s := v.(*Schema)
// Wait for the initialization of other goroutines to complete
<-s.initialized
return s, s.err
}
// Cache the schema
if v, loaded := cacheStore.LoadOrStore(schemaCacheKey, schema); loaded {
s := v.(*Schema)
// Wait for the initialization of other goroutines to complete
<-s.initialized
return s, s.err
}
defer func() {

View File

@ -456,7 +456,11 @@ func (stmt *Statement) Build(clauses ...string) {
}
func (stmt *Statement) Parse(value interface{}) (err error) {
if stmt.Schema, err = schema.ParseWithSchemaTable(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy, stmt.DB.Statement.Table); err == nil && stmt.Table == "" {
return stmt.ParseWithTableName(value, "")
}
func (stmt *Statement) ParseWithTableName(value interface{}, schemaTable string) (err error) {
if stmt.Schema, err = schema.ParseWithSchemaTable(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy, schemaTable); err == nil && stmt.Table == "" {
if tables := strings.Split(stmt.Schema.Table, "."); len(tables) == 2 {
stmt.TableExpr = &clause.Expr{SQL: stmt.Quote(stmt.Schema.Table)}
stmt.Table = tables[1]

View File

@ -4,10 +4,10 @@ import (
"database/sql/driver"
"encoding/json"
"errors"
"testing"
"gorm.io/gorm"
. "gorm.io/gorm/utils/tests"
"testing"
)
func TestEmbeddedStruct(t *testing.T) {

View File

@ -6,6 +6,7 @@ require (
github.com/google/uuid v1.3.0
github.com/jinzhu/now v1.1.2
github.com/lib/pq v1.10.3
github.com/mattn/go-sqlite3 v1.14.9 // indirect
gorm.io/driver/mysql v1.1.2
gorm.io/driver/postgres v1.1.2
gorm.io/driver/sqlite v1.1.6

View File

@ -398,7 +398,7 @@ func TestMigrateIndexesWithDynamicTableName(t *testing.T) {
// Create sub tables
for _, v := range []string{"01", "02", "03"} {
tableName := "users_" + v
tableName := "dynamic_users_" + v
m := DB.Scopes(func(db *gorm.DB) *gorm.DB {
return db.Table(tableName)
}).Migrator()