Refactor ParseWithSchemaTable method and improve test. (#4789)
* Refactor ParseWithSchemaTable method and improve test. * Fix schema.ParseWithSchemaTable method for only use schemaTable in migrator and improve test. * Rename `schemaTable` to `specialTableName` for clearly argument.
This commit is contained in:
		
							parent
							
								
									38e55f1117
								
							
						
					
					
						commit
						d3211908a0
					
				| @ -43,7 +43,7 @@ func (m Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error | |||||||
| 
 | 
 | ||||||
| 	if table, ok := value.(string); ok { | 	if table, ok := value.(string); ok { | ||||||
| 		stmt.Table = table | 		stmt.Table = table | ||||||
| 	} else if err := stmt.Parse(value); err != nil { | 	} else if err := stmt.ParseWithSpecialTableName(value, stmt.Table); err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -73,15 +73,11 @@ type Tabler interface { | |||||||
| 
 | 
 | ||||||
| // Parse get data type from dialector
 | // Parse get data type from dialector
 | ||||||
| func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) { | func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) { | ||||||
| 	return parse(dest, cacheStore, namer, "") | 	return ParseWithSpecialTableName(dest, cacheStore, namer, "") | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // ParseWithSchemaTable get data type from dialector with extra schema table
 | // ParseWithSpecialTableName get data type from dialector with extra schema table
 | ||||||
| func ParseWithSchemaTable(dest interface{}, cacheStore *sync.Map, namer Namer, schemaTable string) (*Schema, error) { | func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Namer, specialTableName string) (*Schema, error) { | ||||||
| 	return parse(dest, cacheStore, namer, schemaTable) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func parse(dest interface{}, cacheStore *sync.Map, namer Namer, schemaTable string) (*Schema, error) { |  | ||||||
| 	if dest == nil { | 	if dest == nil { | ||||||
| 		return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) | 		return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) | ||||||
| 	} | 	} | ||||||
| @ -107,7 +103,17 @@ func parse(dest interface{}, cacheStore *sync.Map, namer Namer, schemaTable stri | |||||||
| 		return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name()) | 		return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name()) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if v, ok := cacheStore.Load(modelType); ok { | 	// Cache the Schema for performance,
 | ||||||
|  | 	// Use the modelType or modelType + schemaTable (if it present) as cache key.
 | ||||||
|  | 	var schemaCacheKey interface{} | ||||||
|  | 	if specialTableName != "" { | ||||||
|  | 		schemaCacheKey = fmt.Sprintf("%p-%s", modelType, specialTableName) | ||||||
|  | 	} else { | ||||||
|  | 		schemaCacheKey = modelType | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Load exist schmema cache, return if exists
 | ||||||
|  | 	if v, ok := cacheStore.Load(schemaCacheKey); ok { | ||||||
| 		s := v.(*Schema) | 		s := v.(*Schema) | ||||||
| 		// Wait for the initialization of other goroutines to complete
 | 		// Wait for the initialization of other goroutines to complete
 | ||||||
| 		<-s.initialized | 		<-s.initialized | ||||||
| @ -116,15 +122,15 @@ func parse(dest interface{}, cacheStore *sync.Map, namer Namer, schemaTable stri | |||||||
| 
 | 
 | ||||||
| 	modelValue := reflect.New(modelType) | 	modelValue := reflect.New(modelType) | ||||||
| 	tableName := namer.TableName(modelType.Name()) | 	tableName := namer.TableName(modelType.Name()) | ||||||
| 	if schemaTable != "" { |  | ||||||
| 		tableName = schemaTable |  | ||||||
| 	} |  | ||||||
| 	if tabler, ok := modelValue.Interface().(Tabler); ok { | 	if tabler, ok := modelValue.Interface().(Tabler); ok { | ||||||
| 		tableName = tabler.TableName() | 		tableName = tabler.TableName() | ||||||
| 	} | 	} | ||||||
| 	if en, ok := namer.(embeddedNamer); ok { | 	if en, ok := namer.(embeddedNamer); ok { | ||||||
| 		tableName = en.Table | 		tableName = en.Table | ||||||
| 	} | 	} | ||||||
|  | 	if specialTableName != "" && specialTableName != tableName { | ||||||
|  | 		tableName = specialTableName | ||||||
|  | 	} | ||||||
| 
 | 
 | ||||||
| 	schema := &Schema{ | 	schema := &Schema{ | ||||||
| 		Name:           modelType.Name(), | 		Name:           modelType.Name(), | ||||||
| @ -140,7 +146,8 @@ func parse(dest interface{}, cacheStore *sync.Map, namer Namer, schemaTable stri | |||||||
| 	// When the schema initialization is completed, the channel will be closed
 | 	// When the schema initialization is completed, the channel will be closed
 | ||||||
| 	defer close(schema.initialized) | 	defer close(schema.initialized) | ||||||
| 
 | 
 | ||||||
| 	if v, loaded := cacheStore.Load(modelType); loaded { | 	// Load exist schmema cache, return if exists
 | ||||||
|  | 	if v, ok := cacheStore.Load(schemaCacheKey); ok { | ||||||
| 		s := v.(*Schema) | 		s := v.(*Schema) | ||||||
| 		// Wait for the initialization of other goroutines to complete
 | 		// Wait for the initialization of other goroutines to complete
 | ||||||
| 		<-s.initialized | 		<-s.initialized | ||||||
| @ -247,13 +254,12 @@ func parse(dest interface{}, cacheStore *sync.Map, namer Namer, schemaTable stri | |||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if schemaTable == "" { | 	// Cache the schema
 | ||||||
| 		if v, loaded := cacheStore.LoadOrStore(modelType, schema); loaded { | 	if v, loaded := cacheStore.LoadOrStore(schemaCacheKey, schema); loaded { | ||||||
| 			s := v.(*Schema) | 		s := v.(*Schema) | ||||||
| 			// Wait for the initialization of other goroutines to complete
 | 		// Wait for the initialization of other goroutines to complete
 | ||||||
| 			<-s.initialized | 		<-s.initialized | ||||||
| 			return s, s.err | 		return s, s.err | ||||||
| 		} |  | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	defer func() { | 	defer func() { | ||||||
|  | |||||||
| @ -456,7 +456,11 @@ func (stmt *Statement) Build(clauses ...string) { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (stmt *Statement) Parse(value interface{}) (err error) { | func (stmt *Statement) Parse(value interface{}) (err error) { | ||||||
| 	if stmt.Schema, err = schema.ParseWithSchemaTable(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy, stmt.DB.Statement.Table); err == nil && stmt.Table == "" { | 	return stmt.ParseWithSpecialTableName(value, "") | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (stmt *Statement) ParseWithSpecialTableName(value interface{}, specialTableName string) (err error) { | ||||||
|  | 	if stmt.Schema, err = schema.ParseWithSpecialTableName(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy, specialTableName); err == nil && stmt.Table == "" { | ||||||
| 		if tables := strings.Split(stmt.Schema.Table, "."); len(tables) == 2 { | 		if tables := strings.Split(stmt.Schema.Table, "."); len(tables) == 2 { | ||||||
| 			stmt.TableExpr = &clause.Expr{SQL: stmt.Quote(stmt.Schema.Table)} | 			stmt.TableExpr = &clause.Expr{SQL: stmt.Quote(stmt.Schema.Table)} | ||||||
| 			stmt.Table = tables[1] | 			stmt.Table = tables[1] | ||||||
|  | |||||||
| @ -382,32 +382,41 @@ func TestMigrateConstraint(t *testing.T) { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| type MigrateUser struct { | type DynamicUser struct { | ||||||
| 	gorm.Model | 	gorm.Model | ||||||
| 	Name string `gorm:"index"` | 	Name      string | ||||||
|  | 	CompanyID string `gorm:"index"` | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // To test auto migrate crate indexes for dynamic table name
 | ||||||
| // https://github.com/go-gorm/gorm/issues/4752
 | // https://github.com/go-gorm/gorm/issues/4752
 | ||||||
| func TestMigrateIndexesWithDynamicTableName(t *testing.T) { | func TestMigrateIndexesWithDynamicTableName(t *testing.T) { | ||||||
| 	tableNameSuffixes := []string{"01", "02", "03"} | 	// Create primary table
 | ||||||
| 	for _, v := range tableNameSuffixes { | 	if err := DB.AutoMigrate(&DynamicUser{}); err != nil { | ||||||
| 		tableName := "migrate_user_" + v | 		t.Fatalf("AutoMigrate create table error: %#v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Create sub tables
 | ||||||
|  | 	for _, v := range []string{"01", "02", "03"} { | ||||||
|  | 		tableName := "dynamic_users_" + v | ||||||
| 		m := DB.Scopes(func(db *gorm.DB) *gorm.DB { | 		m := DB.Scopes(func(db *gorm.DB) *gorm.DB { | ||||||
| 			return db.Table(tableName) | 			return db.Table(tableName) | ||||||
| 		}).Migrator() | 		}).Migrator() | ||||||
| 
 | 
 | ||||||
| 		if err := m.AutoMigrate(&MigrateUser{}); err != nil { | 		if err := m.AutoMigrate(&DynamicUser{}); err != nil { | ||||||
| 			t.Fatalf("Failed to create table for %#v", tableName) | 			t.Fatalf("AutoMigrate create table error: %#v", err) | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		if !m.HasTable(tableName) { | 		if !m.HasTable(tableName) { | ||||||
| 			t.Fatalf("Failed to create table for %#v", tableName) | 			t.Fatalf("AutoMigrate expected %#v exist, but not.", tableName) | ||||||
| 		} | 		} | ||||||
| 		if !m.HasIndex(&MigrateUser{}, "Name") { | 
 | ||||||
| 			t.Fatalf("Should find index for %s's name after AutoMigrate", tableName) | 		if !m.HasIndex(&DynamicUser{}, "CompanyID") { | ||||||
|  | 			t.Fatalf("Should have index on %s", "CompanyI.") | ||||||
| 		} | 		} | ||||||
| 		if !m.HasIndex(&MigrateUser{}, "DeletedAt") { | 
 | ||||||
| 			t.Fatalf("Should find index for %s's deleted_at after AutoMigrate", tableName) | 		if !m.HasIndex(&DynamicUser{}, "DeletedAt") { | ||||||
|  | 			t.Fatalf("Should have index on deleted_at.") | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jason Lee
						Jason Lee