From e3649fd45d8d7f0bb3243079525b391989c2b4a4 Mon Sep 17 00:00:00 2001 From: Jason Lee Date: Thu, 21 Oct 2021 18:11:03 +0800 Subject: [PATCH] Rename `schemaTable` to `specialTableName` for clearly argument. --- migrator/migrator.go | 2 +- schema/schema.go | 18 +++++++----------- statement.go | 6 +++--- tests/embedded_struct_test.go | 2 +- tests/go.mod | 1 - 5 files changed, 12 insertions(+), 17 deletions(-) diff --git a/migrator/migrator.go b/migrator/migrator.go index f594dde3..30586a8c 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -43,7 +43,7 @@ func (m Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error if table, ok := value.(string); ok { stmt.Table = table - } else if err := stmt.ParseWithTableName(value, stmt.Table); err != nil { + } else if err := stmt.ParseWithSpecialTableName(value, stmt.Table); err != nil { return err } diff --git a/schema/schema.go b/schema/schema.go index 4e863aa4..ce7cf3b1 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -73,15 +73,11 @@ type Tabler interface { // Parse get data type from dialector func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) { - return parse(dest, cacheStore, namer, "") + return ParseWithSpecialTableName(dest, cacheStore, namer, "") } -// ParseWithSchemaTable get data type from dialector with extra schema table -func ParseWithSchemaTable(dest interface{}, cacheStore *sync.Map, namer Namer, schemaTable string) (*Schema, error) { - return parse(dest, cacheStore, namer, schemaTable) -} - -func parse(dest interface{}, cacheStore *sync.Map, namer Namer, schemaTable string) (*Schema, error) { +// ParseWithSpecialTableName get data type from dialector with extra schema table +func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Namer, specialTableName string) (*Schema, error) { if dest == nil { return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) } @@ -110,8 +106,8 @@ func parse(dest interface{}, cacheStore *sync.Map, namer Namer, schemaTable stri // Cache the Schema for performance, // Use the modelType or modelType + schemaTable (if it present) as cache key. var schemaCacheKey interface{} - if schemaTable != "" { - schemaCacheKey = fmt.Sprintf("%p-%s", &modelType, schemaTable) + if specialTableName != "" { + schemaCacheKey = fmt.Sprintf("%p-%s", modelType, specialTableName) } else { schemaCacheKey = modelType } @@ -132,8 +128,8 @@ func parse(dest interface{}, cacheStore *sync.Map, namer Namer, schemaTable stri if en, ok := namer.(embeddedNamer); ok { tableName = en.Table } - if schemaTable != "" && schemaTable != tableName { - tableName = schemaTable + if specialTableName != "" && specialTableName != tableName { + tableName = specialTableName } schema := &Schema{ diff --git a/statement.go b/statement.go index a9b6f5fc..85432e48 100644 --- a/statement.go +++ b/statement.go @@ -456,11 +456,11 @@ func (stmt *Statement) Build(clauses ...string) { } func (stmt *Statement) Parse(value interface{}) (err error) { - return stmt.ParseWithTableName(value, "") + return stmt.ParseWithSpecialTableName(value, "") } -func (stmt *Statement) ParseWithTableName(value interface{}, schemaTable string) (err error) { - if stmt.Schema, err = schema.ParseWithSchemaTable(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy, schemaTable); err == nil && stmt.Table == "" { +func (stmt *Statement) ParseWithSpecialTableName(value interface{}, specialTableName string) (err error) { + if stmt.Schema, err = schema.ParseWithSpecialTableName(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy, specialTableName); err == nil && stmt.Table == "" { if tables := strings.Split(stmt.Schema.Table, "."); len(tables) == 2 { stmt.TableExpr = &clause.Expr{SQL: stmt.Quote(stmt.Schema.Table)} stmt.Table = tables[1] diff --git a/tests/embedded_struct_test.go b/tests/embedded_struct_test.go index dcd2794c..312a5c37 100644 --- a/tests/embedded_struct_test.go +++ b/tests/embedded_struct_test.go @@ -4,10 +4,10 @@ import ( "database/sql/driver" "encoding/json" "errors" + "testing" "gorm.io/gorm" . "gorm.io/gorm/utils/tests" - "testing" ) func TestEmbeddedStruct(t *testing.T) { diff --git a/tests/go.mod b/tests/go.mod index 7906f575..e18dc1dc 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -6,7 +6,6 @@ require ( github.com/google/uuid v1.3.0 github.com/jinzhu/now v1.1.2 github.com/lib/pq v1.10.3 - github.com/mattn/go-sqlite3 v1.14.9 // indirect gorm.io/driver/mysql v1.1.2 gorm.io/driver/postgres v1.1.2 gorm.io/driver/sqlite v1.1.6