Refactor ParseWithSchemaTable method and improve test. (#4789)

* Refactor ParseWithSchemaTable method and improve test.

* Fix schema.ParseWithSchemaTable method for only use schemaTable in migrator and improve test.

* Rename `schemaTable` to `specialTableName` for clearly argument.
This commit is contained in:
Jason Lee 2021-10-25 11:26:44 +08:00 committed by GitHub
parent 38e55f1117
commit d3211908a0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 52 additions and 33 deletions

View File

@ -43,7 +43,7 @@ func (m Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error
if table, ok := value.(string); ok { if table, ok := value.(string); ok {
stmt.Table = table stmt.Table = table
} else if err := stmt.Parse(value); err != nil { } else if err := stmt.ParseWithSpecialTableName(value, stmt.Table); err != nil {
return err return err
} }

View File

@ -73,15 +73,11 @@ type Tabler interface {
// Parse get data type from dialector // Parse get data type from dialector
func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) { func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) {
return parse(dest, cacheStore, namer, "") return ParseWithSpecialTableName(dest, cacheStore, namer, "")
} }
// ParseWithSchemaTable get data type from dialector with extra schema table // ParseWithSpecialTableName get data type from dialector with extra schema table
func ParseWithSchemaTable(dest interface{}, cacheStore *sync.Map, namer Namer, schemaTable string) (*Schema, error) { func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Namer, specialTableName string) (*Schema, error) {
return parse(dest, cacheStore, namer, schemaTable)
}
func parse(dest interface{}, cacheStore *sync.Map, namer Namer, schemaTable string) (*Schema, error) {
if dest == nil { if dest == nil {
return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
} }
@ -107,7 +103,17 @@ func parse(dest interface{}, cacheStore *sync.Map, namer Namer, schemaTable stri
return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name()) return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name())
} }
if v, ok := cacheStore.Load(modelType); ok { // Cache the Schema for performance,
// Use the modelType or modelType + schemaTable (if it present) as cache key.
var schemaCacheKey interface{}
if specialTableName != "" {
schemaCacheKey = fmt.Sprintf("%p-%s", modelType, specialTableName)
} else {
schemaCacheKey = modelType
}
// Load exist schmema cache, return if exists
if v, ok := cacheStore.Load(schemaCacheKey); ok {
s := v.(*Schema) s := v.(*Schema)
// Wait for the initialization of other goroutines to complete // Wait for the initialization of other goroutines to complete
<-s.initialized <-s.initialized
@ -116,15 +122,15 @@ func parse(dest interface{}, cacheStore *sync.Map, namer Namer, schemaTable stri
modelValue := reflect.New(modelType) modelValue := reflect.New(modelType)
tableName := namer.TableName(modelType.Name()) tableName := namer.TableName(modelType.Name())
if schemaTable != "" {
tableName = schemaTable
}
if tabler, ok := modelValue.Interface().(Tabler); ok { if tabler, ok := modelValue.Interface().(Tabler); ok {
tableName = tabler.TableName() tableName = tabler.TableName()
} }
if en, ok := namer.(embeddedNamer); ok { if en, ok := namer.(embeddedNamer); ok {
tableName = en.Table tableName = en.Table
} }
if specialTableName != "" && specialTableName != tableName {
tableName = specialTableName
}
schema := &Schema{ schema := &Schema{
Name: modelType.Name(), Name: modelType.Name(),
@ -140,7 +146,8 @@ func parse(dest interface{}, cacheStore *sync.Map, namer Namer, schemaTable stri
// When the schema initialization is completed, the channel will be closed // When the schema initialization is completed, the channel will be closed
defer close(schema.initialized) defer close(schema.initialized)
if v, loaded := cacheStore.Load(modelType); loaded { // Load exist schmema cache, return if exists
if v, ok := cacheStore.Load(schemaCacheKey); ok {
s := v.(*Schema) s := v.(*Schema)
// Wait for the initialization of other goroutines to complete // Wait for the initialization of other goroutines to complete
<-s.initialized <-s.initialized
@ -247,13 +254,12 @@ func parse(dest interface{}, cacheStore *sync.Map, namer Namer, schemaTable stri
} }
} }
if schemaTable == "" { // Cache the schema
if v, loaded := cacheStore.LoadOrStore(modelType, schema); loaded { if v, loaded := cacheStore.LoadOrStore(schemaCacheKey, schema); loaded {
s := v.(*Schema) s := v.(*Schema)
// Wait for the initialization of other goroutines to complete // Wait for the initialization of other goroutines to complete
<-s.initialized <-s.initialized
return s, s.err return s, s.err
}
} }
defer func() { defer func() {

View File

@ -456,7 +456,11 @@ func (stmt *Statement) Build(clauses ...string) {
} }
func (stmt *Statement) Parse(value interface{}) (err error) { func (stmt *Statement) Parse(value interface{}) (err error) {
if stmt.Schema, err = schema.ParseWithSchemaTable(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy, stmt.DB.Statement.Table); err == nil && stmt.Table == "" { return stmt.ParseWithSpecialTableName(value, "")
}
func (stmt *Statement) ParseWithSpecialTableName(value interface{}, specialTableName string) (err error) {
if stmt.Schema, err = schema.ParseWithSpecialTableName(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy, specialTableName); err == nil && stmt.Table == "" {
if tables := strings.Split(stmt.Schema.Table, "."); len(tables) == 2 { if tables := strings.Split(stmt.Schema.Table, "."); len(tables) == 2 {
stmt.TableExpr = &clause.Expr{SQL: stmt.Quote(stmt.Schema.Table)} stmt.TableExpr = &clause.Expr{SQL: stmt.Quote(stmt.Schema.Table)}
stmt.Table = tables[1] stmt.Table = tables[1]

View File

@ -382,32 +382,41 @@ func TestMigrateConstraint(t *testing.T) {
} }
} }
type MigrateUser struct { type DynamicUser struct {
gorm.Model gorm.Model
Name string `gorm:"index"` Name string
CompanyID string `gorm:"index"`
} }
// To test auto migrate crate indexes for dynamic table name
// https://github.com/go-gorm/gorm/issues/4752 // https://github.com/go-gorm/gorm/issues/4752
func TestMigrateIndexesWithDynamicTableName(t *testing.T) { func TestMigrateIndexesWithDynamicTableName(t *testing.T) {
tableNameSuffixes := []string{"01", "02", "03"} // Create primary table
for _, v := range tableNameSuffixes { if err := DB.AutoMigrate(&DynamicUser{}); err != nil {
tableName := "migrate_user_" + v t.Fatalf("AutoMigrate create table error: %#v", err)
}
// Create sub tables
for _, v := range []string{"01", "02", "03"} {
tableName := "dynamic_users_" + v
m := DB.Scopes(func(db *gorm.DB) *gorm.DB { m := DB.Scopes(func(db *gorm.DB) *gorm.DB {
return db.Table(tableName) return db.Table(tableName)
}).Migrator() }).Migrator()
if err := m.AutoMigrate(&MigrateUser{}); err != nil { if err := m.AutoMigrate(&DynamicUser{}); err != nil {
t.Fatalf("Failed to create table for %#v", tableName) t.Fatalf("AutoMigrate create table error: %#v", err)
} }
if !m.HasTable(tableName) { if !m.HasTable(tableName) {
t.Fatalf("Failed to create table for %#v", tableName) t.Fatalf("AutoMigrate expected %#v exist, but not.", tableName)
} }
if !m.HasIndex(&MigrateUser{}, "Name") {
t.Fatalf("Should find index for %s's name after AutoMigrate", tableName) if !m.HasIndex(&DynamicUser{}, "CompanyID") {
t.Fatalf("Should have index on %s", "CompanyI.")
} }
if !m.HasIndex(&MigrateUser{}, "DeletedAt") {
t.Fatalf("Should find index for %s's deleted_at after AutoMigrate", tableName) if !m.HasIndex(&DynamicUser{}, "DeletedAt") {
t.Fatalf("Should have index on deleted_at.")
} }
} }
} }