From acbe3e367214f438d460d368ce761c3dcaad1858 Mon Sep 17 00:00:00 2001 From: zhenghaotian Date: Thu, 29 May 2025 21:13:34 +0800 Subject: [PATCH] feat: Add table name separator configuration function to support multi-level database table names --- gorm.go | 17 +++++++++++++---- statement.go | 5 +++-- 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/gorm.go b/gorm.go index 67889262..6fe114f2 100644 --- a/gorm.go +++ b/gorm.go @@ -17,6 +17,9 @@ import ( // for Config.cacheStore store PreparedStmtDB key const preparedStmtDBKey = "preparedStmt" +// DefaultTableNameSplit default table name split +var DefaultTableNameSplit = 2 + // Config GORM config type Config struct { // GORM perform single create, update, delete operations in transactions by default to ensure database data integrity @@ -73,6 +76,11 @@ type Config struct { cacheStore *sync.Map } +// SetTableNameSplit set default table name split +func SetTableNameSplit(split int) { + DefaultTableNameSplit = split +} + // Apply update config to new config func (c *Config) Apply(config *Config) error { if config != c { @@ -218,10 +226,11 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) { } db.Statement = &Statement{ - DB: db, - ConnPool: db.ConnPool, - Context: context.Background(), - Clauses: map[string]clause.Clause{}, + DB: db, + ConnPool: db.ConnPool, + Context: context.Background(), + Clauses: map[string]clause.Clause{}, + TableNameSplit: DefaultTableNameSplit, } if err == nil && !config.DisableAutomaticPing { diff --git a/statement.go b/statement.go index c6183724..e219b47b 100644 --- a/statement.go +++ b/statement.go @@ -48,6 +48,7 @@ type Statement struct { assigns []interface{} scopes []func(*DB) *DB Result *result + TableNameSplit int } type join struct { @@ -503,9 +504,9 @@ func (stmt *Statement) Parse(value interface{}) (err error) { func (stmt *Statement) ParseWithSpecialTableName(value interface{}, specialTableName string) (err error) { if stmt.Schema, err = schema.ParseWithSpecialTableName(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy, specialTableName); err == nil && stmt.Table == "" { - if tables := strings.Split(stmt.Schema.Table, "."); len(tables) == 2 { + if tables := strings.Split(stmt.Schema.Table, "."); len(tables) == stmt.TableNameSplit { stmt.TableExpr = &clause.Expr{SQL: stmt.Quote(stmt.Schema.Table)} - stmt.Table = tables[1] + stmt.Table = tables[stmt.TableNameSplit-1] return }