Rename schemaTable to specialTableName for clearly argument.

This commit is contained in:
Jason Lee 2021-10-21 18:11:03 +08:00
parent b1626b1b46
commit e3649fd45d
5 changed files with 12 additions and 17 deletions

View File

@ -43,7 +43,7 @@ func (m Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error
if table, ok := value.(string); ok { if table, ok := value.(string); ok {
stmt.Table = table stmt.Table = table
} else if err := stmt.ParseWithTableName(value, stmt.Table); err != nil { } else if err := stmt.ParseWithSpecialTableName(value, stmt.Table); err != nil {
return err return err
} }

View File

@ -73,15 +73,11 @@ type Tabler interface {
// Parse get data type from dialector // Parse get data type from dialector
func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) { func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) {
return parse(dest, cacheStore, namer, "") return ParseWithSpecialTableName(dest, cacheStore, namer, "")
} }
// ParseWithSchemaTable get data type from dialector with extra schema table // ParseWithSpecialTableName get data type from dialector with extra schema table
func ParseWithSchemaTable(dest interface{}, cacheStore *sync.Map, namer Namer, schemaTable string) (*Schema, error) { func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Namer, specialTableName string) (*Schema, error) {
return parse(dest, cacheStore, namer, schemaTable)
}
func parse(dest interface{}, cacheStore *sync.Map, namer Namer, schemaTable string) (*Schema, error) {
if dest == nil { if dest == nil {
return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
} }
@ -110,8 +106,8 @@ func parse(dest interface{}, cacheStore *sync.Map, namer Namer, schemaTable stri
// Cache the Schema for performance, // Cache the Schema for performance,
// Use the modelType or modelType + schemaTable (if it present) as cache key. // Use the modelType or modelType + schemaTable (if it present) as cache key.
var schemaCacheKey interface{} var schemaCacheKey interface{}
if schemaTable != "" { if specialTableName != "" {
schemaCacheKey = fmt.Sprintf("%p-%s", &modelType, schemaTable) schemaCacheKey = fmt.Sprintf("%p-%s", modelType, specialTableName)
} else { } else {
schemaCacheKey = modelType schemaCacheKey = modelType
} }
@ -132,8 +128,8 @@ func parse(dest interface{}, cacheStore *sync.Map, namer Namer, schemaTable stri
if en, ok := namer.(embeddedNamer); ok { if en, ok := namer.(embeddedNamer); ok {
tableName = en.Table tableName = en.Table
} }
if schemaTable != "" && schemaTable != tableName { if specialTableName != "" && specialTableName != tableName {
tableName = schemaTable tableName = specialTableName
} }
schema := &Schema{ schema := &Schema{

View File

@ -456,11 +456,11 @@ func (stmt *Statement) Build(clauses ...string) {
} }
func (stmt *Statement) Parse(value interface{}) (err error) { func (stmt *Statement) Parse(value interface{}) (err error) {
return stmt.ParseWithTableName(value, "") return stmt.ParseWithSpecialTableName(value, "")
} }
func (stmt *Statement) ParseWithTableName(value interface{}, schemaTable string) (err error) { func (stmt *Statement) ParseWithSpecialTableName(value interface{}, specialTableName string) (err error) {
if stmt.Schema, err = schema.ParseWithSchemaTable(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy, schemaTable); err == nil && stmt.Table == "" { if stmt.Schema, err = schema.ParseWithSpecialTableName(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy, specialTableName); err == nil && stmt.Table == "" {
if tables := strings.Split(stmt.Schema.Table, "."); len(tables) == 2 { if tables := strings.Split(stmt.Schema.Table, "."); len(tables) == 2 {
stmt.TableExpr = &clause.Expr{SQL: stmt.Quote(stmt.Schema.Table)} stmt.TableExpr = &clause.Expr{SQL: stmt.Quote(stmt.Schema.Table)}
stmt.Table = tables[1] stmt.Table = tables[1]

View File

@ -4,10 +4,10 @@ import (
"database/sql/driver" "database/sql/driver"
"encoding/json" "encoding/json"
"errors" "errors"
"testing"
"gorm.io/gorm" "gorm.io/gorm"
. "gorm.io/gorm/utils/tests" . "gorm.io/gorm/utils/tests"
"testing"
) )
func TestEmbeddedStruct(t *testing.T) { func TestEmbeddedStruct(t *testing.T) {

View File

@ -6,7 +6,6 @@ require (
github.com/google/uuid v1.3.0 github.com/google/uuid v1.3.0
github.com/jinzhu/now v1.1.2 github.com/jinzhu/now v1.1.2
github.com/lib/pq v1.10.3 github.com/lib/pq v1.10.3
github.com/mattn/go-sqlite3 v1.14.9 // indirect
gorm.io/driver/mysql v1.1.2 gorm.io/driver/mysql v1.1.2
gorm.io/driver/postgres v1.1.2 gorm.io/driver/postgres v1.1.2
gorm.io/driver/sqlite v1.1.6 gorm.io/driver/sqlite v1.1.6