diff --git a/migrator/migrator.go b/migrator/migrator.go index 48db151e..f594dde3 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -43,7 +43,7 @@ func (m Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error if table, ok := value.(string); ok { stmt.Table = table - } else if err := stmt.Parse(value); err != nil { + } else if err := stmt.ParseWithTableName(value, stmt.Table); err != nil { return err } diff --git a/schema/schema.go b/schema/schema.go index 7e9431c9..4e863aa4 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -107,7 +107,17 @@ func parse(dest interface{}, cacheStore *sync.Map, namer Namer, schemaTable stri return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name()) } - if v, ok := cacheStore.Load(modelType); ok { + // Cache the Schema for performance, + // Use the modelType or modelType + schemaTable (if it present) as cache key. + var schemaCacheKey interface{} + if schemaTable != "" { + schemaCacheKey = fmt.Sprintf("%p-%s", &modelType, schemaTable) + } else { + schemaCacheKey = modelType + } + + // Load exist schmema cache, return if exists + if v, ok := cacheStore.Load(schemaCacheKey); ok { s := v.(*Schema) // Wait for the initialization of other goroutines to complete <-s.initialized @@ -115,18 +125,15 @@ func parse(dest interface{}, cacheStore *sync.Map, namer Namer, schemaTable stri } modelValue := reflect.New(modelType) - - // schemaTable for assignment table name directly - tableName := schemaTable - if schemaTable == "" { - tableName = namer.TableName(modelType.Name()) - - if tabler, ok := modelValue.Interface().(Tabler); ok { - tableName = tabler.TableName() - } - if en, ok := namer.(embeddedNamer); ok { - tableName = en.Table - } + tableName := namer.TableName(modelType.Name()) + if tabler, ok := modelValue.Interface().(Tabler); ok { + tableName = tabler.TableName() + } + if en, ok := namer.(embeddedNamer); ok { + tableName = en.Table + } + if schemaTable != "" && schemaTable != tableName { + tableName = schemaTable } schema := &Schema{ @@ -143,7 +150,8 @@ func parse(dest interface{}, cacheStore *sync.Map, namer Namer, schemaTable stri // When the schema initialization is completed, the channel will be closed defer close(schema.initialized) - if v, loaded := cacheStore.Load(modelType); loaded { + // Load exist schmema cache, return if exists + if v, ok := cacheStore.Load(schemaCacheKey); ok { s := v.(*Schema) // Wait for the initialization of other goroutines to complete <-s.initialized @@ -250,13 +258,12 @@ func parse(dest interface{}, cacheStore *sync.Map, namer Namer, schemaTable stri } } - if schemaTable == "" { - if v, loaded := cacheStore.LoadOrStore(modelType, schema); loaded { - s := v.(*Schema) - // Wait for the initialization of other goroutines to complete - <-s.initialized - return s, s.err - } + // Cache the schema + if v, loaded := cacheStore.LoadOrStore(schemaCacheKey, schema); loaded { + s := v.(*Schema) + // Wait for the initialization of other goroutines to complete + <-s.initialized + return s, s.err } defer func() { diff --git a/statement.go b/statement.go index bbe00106..a9b6f5fc 100644 --- a/statement.go +++ b/statement.go @@ -456,7 +456,11 @@ func (stmt *Statement) Build(clauses ...string) { } func (stmt *Statement) Parse(value interface{}) (err error) { - if stmt.Schema, err = schema.ParseWithSchemaTable(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy, stmt.DB.Statement.Table); err == nil && stmt.Table == "" { + return stmt.ParseWithTableName(value, "") +} + +func (stmt *Statement) ParseWithTableName(value interface{}, schemaTable string) (err error) { + if stmt.Schema, err = schema.ParseWithSchemaTable(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy, schemaTable); err == nil && stmt.Table == "" { if tables := strings.Split(stmt.Schema.Table, "."); len(tables) == 2 { stmt.TableExpr = &clause.Expr{SQL: stmt.Quote(stmt.Schema.Table)} stmt.Table = tables[1] diff --git a/tests/embedded_struct_test.go b/tests/embedded_struct_test.go index 312a5c37..dcd2794c 100644 --- a/tests/embedded_struct_test.go +++ b/tests/embedded_struct_test.go @@ -4,10 +4,10 @@ import ( "database/sql/driver" "encoding/json" "errors" - "testing" "gorm.io/gorm" . "gorm.io/gorm/utils/tests" + "testing" ) func TestEmbeddedStruct(t *testing.T) { diff --git a/tests/go.mod b/tests/go.mod index e18dc1dc..7906f575 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -6,6 +6,7 @@ require ( github.com/google/uuid v1.3.0 github.com/jinzhu/now v1.1.2 github.com/lib/pq v1.10.3 + github.com/mattn/go-sqlite3 v1.14.9 // indirect gorm.io/driver/mysql v1.1.2 gorm.io/driver/postgres v1.1.2 gorm.io/driver/sqlite v1.1.6 diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 8d2f474d..0354e84e 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -398,7 +398,7 @@ func TestMigrateIndexesWithDynamicTableName(t *testing.T) { // Create sub tables for _, v := range []string{"01", "02", "03"} { - tableName := "users_" + v + tableName := "dynamic_users_" + v m := DB.Scopes(func(db *gorm.DB) *gorm.DB { return db.Table(tableName) }).Migrator()