Refactor ParseWithSchemaTable method and improve test. (#4789)
* Refactor ParseWithSchemaTable method and improve test. * Fix schema.ParseWithSchemaTable method for only use schemaTable in migrator and improve test. * Rename `schemaTable` to `specialTableName` for clearly argument.
This commit is contained in:
parent
38e55f1117
commit
d3211908a0
@ -43,7 +43,7 @@ func (m Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error
|
|||||||
|
|
||||||
if table, ok := value.(string); ok {
|
if table, ok := value.(string); ok {
|
||||||
stmt.Table = table
|
stmt.Table = table
|
||||||
} else if err := stmt.Parse(value); err != nil {
|
} else if err := stmt.ParseWithSpecialTableName(value, stmt.Table); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -73,15 +73,11 @@ type Tabler interface {
|
|||||||
|
|
||||||
// Parse get data type from dialector
|
// Parse get data type from dialector
|
||||||
func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) {
|
func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) {
|
||||||
return parse(dest, cacheStore, namer, "")
|
return ParseWithSpecialTableName(dest, cacheStore, namer, "")
|
||||||
}
|
}
|
||||||
|
|
||||||
// ParseWithSchemaTable get data type from dialector with extra schema table
|
// ParseWithSpecialTableName get data type from dialector with extra schema table
|
||||||
func ParseWithSchemaTable(dest interface{}, cacheStore *sync.Map, namer Namer, schemaTable string) (*Schema, error) {
|
func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Namer, specialTableName string) (*Schema, error) {
|
||||||
return parse(dest, cacheStore, namer, schemaTable)
|
|
||||||
}
|
|
||||||
|
|
||||||
func parse(dest interface{}, cacheStore *sync.Map, namer Namer, schemaTable string) (*Schema, error) {
|
|
||||||
if dest == nil {
|
if dest == nil {
|
||||||
return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
|
return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
|
||||||
}
|
}
|
||||||
@ -107,7 +103,17 @@ func parse(dest interface{}, cacheStore *sync.Map, namer Namer, schemaTable stri
|
|||||||
return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name())
|
return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name())
|
||||||
}
|
}
|
||||||
|
|
||||||
if v, ok := cacheStore.Load(modelType); ok {
|
// Cache the Schema for performance,
|
||||||
|
// Use the modelType or modelType + schemaTable (if it present) as cache key.
|
||||||
|
var schemaCacheKey interface{}
|
||||||
|
if specialTableName != "" {
|
||||||
|
schemaCacheKey = fmt.Sprintf("%p-%s", modelType, specialTableName)
|
||||||
|
} else {
|
||||||
|
schemaCacheKey = modelType
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load exist schmema cache, return if exists
|
||||||
|
if v, ok := cacheStore.Load(schemaCacheKey); ok {
|
||||||
s := v.(*Schema)
|
s := v.(*Schema)
|
||||||
// Wait for the initialization of other goroutines to complete
|
// Wait for the initialization of other goroutines to complete
|
||||||
<-s.initialized
|
<-s.initialized
|
||||||
@ -116,15 +122,15 @@ func parse(dest interface{}, cacheStore *sync.Map, namer Namer, schemaTable stri
|
|||||||
|
|
||||||
modelValue := reflect.New(modelType)
|
modelValue := reflect.New(modelType)
|
||||||
tableName := namer.TableName(modelType.Name())
|
tableName := namer.TableName(modelType.Name())
|
||||||
if schemaTable != "" {
|
|
||||||
tableName = schemaTable
|
|
||||||
}
|
|
||||||
if tabler, ok := modelValue.Interface().(Tabler); ok {
|
if tabler, ok := modelValue.Interface().(Tabler); ok {
|
||||||
tableName = tabler.TableName()
|
tableName = tabler.TableName()
|
||||||
}
|
}
|
||||||
if en, ok := namer.(embeddedNamer); ok {
|
if en, ok := namer.(embeddedNamer); ok {
|
||||||
tableName = en.Table
|
tableName = en.Table
|
||||||
}
|
}
|
||||||
|
if specialTableName != "" && specialTableName != tableName {
|
||||||
|
tableName = specialTableName
|
||||||
|
}
|
||||||
|
|
||||||
schema := &Schema{
|
schema := &Schema{
|
||||||
Name: modelType.Name(),
|
Name: modelType.Name(),
|
||||||
@ -140,7 +146,8 @@ func parse(dest interface{}, cacheStore *sync.Map, namer Namer, schemaTable stri
|
|||||||
// When the schema initialization is completed, the channel will be closed
|
// When the schema initialization is completed, the channel will be closed
|
||||||
defer close(schema.initialized)
|
defer close(schema.initialized)
|
||||||
|
|
||||||
if v, loaded := cacheStore.Load(modelType); loaded {
|
// Load exist schmema cache, return if exists
|
||||||
|
if v, ok := cacheStore.Load(schemaCacheKey); ok {
|
||||||
s := v.(*Schema)
|
s := v.(*Schema)
|
||||||
// Wait for the initialization of other goroutines to complete
|
// Wait for the initialization of other goroutines to complete
|
||||||
<-s.initialized
|
<-s.initialized
|
||||||
@ -247,14 +254,13 @@ func parse(dest interface{}, cacheStore *sync.Map, namer Namer, schemaTable stri
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if schemaTable == "" {
|
// Cache the schema
|
||||||
if v, loaded := cacheStore.LoadOrStore(modelType, schema); loaded {
|
if v, loaded := cacheStore.LoadOrStore(schemaCacheKey, schema); loaded {
|
||||||
s := v.(*Schema)
|
s := v.(*Schema)
|
||||||
// Wait for the initialization of other goroutines to complete
|
// Wait for the initialization of other goroutines to complete
|
||||||
<-s.initialized
|
<-s.initialized
|
||||||
return s, s.err
|
return s, s.err
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if schema.err != nil {
|
if schema.err != nil {
|
||||||
|
@ -456,7 +456,11 @@ func (stmt *Statement) Build(clauses ...string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (stmt *Statement) Parse(value interface{}) (err error) {
|
func (stmt *Statement) Parse(value interface{}) (err error) {
|
||||||
if stmt.Schema, err = schema.ParseWithSchemaTable(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy, stmt.DB.Statement.Table); err == nil && stmt.Table == "" {
|
return stmt.ParseWithSpecialTableName(value, "")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (stmt *Statement) ParseWithSpecialTableName(value interface{}, specialTableName string) (err error) {
|
||||||
|
if stmt.Schema, err = schema.ParseWithSpecialTableName(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy, specialTableName); err == nil && stmt.Table == "" {
|
||||||
if tables := strings.Split(stmt.Schema.Table, "."); len(tables) == 2 {
|
if tables := strings.Split(stmt.Schema.Table, "."); len(tables) == 2 {
|
||||||
stmt.TableExpr = &clause.Expr{SQL: stmt.Quote(stmt.Schema.Table)}
|
stmt.TableExpr = &clause.Expr{SQL: stmt.Quote(stmt.Schema.Table)}
|
||||||
stmt.Table = tables[1]
|
stmt.Table = tables[1]
|
||||||
|
@ -382,32 +382,41 @@ func TestMigrateConstraint(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type MigrateUser struct {
|
type DynamicUser struct {
|
||||||
gorm.Model
|
gorm.Model
|
||||||
Name string `gorm:"index"`
|
Name string
|
||||||
|
CompanyID string `gorm:"index"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// To test auto migrate crate indexes for dynamic table name
|
||||||
// https://github.com/go-gorm/gorm/issues/4752
|
// https://github.com/go-gorm/gorm/issues/4752
|
||||||
func TestMigrateIndexesWithDynamicTableName(t *testing.T) {
|
func TestMigrateIndexesWithDynamicTableName(t *testing.T) {
|
||||||
tableNameSuffixes := []string{"01", "02", "03"}
|
// Create primary table
|
||||||
for _, v := range tableNameSuffixes {
|
if err := DB.AutoMigrate(&DynamicUser{}); err != nil {
|
||||||
tableName := "migrate_user_" + v
|
t.Fatalf("AutoMigrate create table error: %#v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create sub tables
|
||||||
|
for _, v := range []string{"01", "02", "03"} {
|
||||||
|
tableName := "dynamic_users_" + v
|
||||||
m := DB.Scopes(func(db *gorm.DB) *gorm.DB {
|
m := DB.Scopes(func(db *gorm.DB) *gorm.DB {
|
||||||
return db.Table(tableName)
|
return db.Table(tableName)
|
||||||
}).Migrator()
|
}).Migrator()
|
||||||
|
|
||||||
if err := m.AutoMigrate(&MigrateUser{}); err != nil {
|
if err := m.AutoMigrate(&DynamicUser{}); err != nil {
|
||||||
t.Fatalf("Failed to create table for %#v", tableName)
|
t.Fatalf("AutoMigrate create table error: %#v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !m.HasTable(tableName) {
|
if !m.HasTable(tableName) {
|
||||||
t.Fatalf("Failed to create table for %#v", tableName)
|
t.Fatalf("AutoMigrate expected %#v exist, but not.", tableName)
|
||||||
}
|
}
|
||||||
if !m.HasIndex(&MigrateUser{}, "Name") {
|
|
||||||
t.Fatalf("Should find index for %s's name after AutoMigrate", tableName)
|
if !m.HasIndex(&DynamicUser{}, "CompanyID") {
|
||||||
|
t.Fatalf("Should have index on %s", "CompanyI.")
|
||||||
}
|
}
|
||||||
if !m.HasIndex(&MigrateUser{}, "DeletedAt") {
|
|
||||||
t.Fatalf("Should find index for %s's deleted_at after AutoMigrate", tableName)
|
if !m.HasIndex(&DynamicUser{}, "DeletedAt") {
|
||||||
|
t.Fatalf("Should have index on deleted_at.")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user