diff --git a/main.go b/main.go index 97cff7db..0384c704 100644 --- a/main.go +++ b/main.go @@ -446,6 +446,12 @@ func (s *DB) Table(name string) *DB { return clone } +func (s *DB) Prefix(name string) *DB { + clone := s.clone() + clone.search.Prefix(name) + return clone +} + // Debug start debug mode func (s *DB) Debug() *DB { return s.clone().LogMode(true) diff --git a/main_test.go b/main_test.go index 3b1433cf..c90c4d1e 100644 --- a/main_test.go +++ b/main_test.go @@ -259,7 +259,19 @@ func TestTableName(t *testing.T) { if DB.NewScope([]Cart{}).TableName() != "shopping_cart" { t.Errorf("[]Cart's singular table name should be shopping_cart") } + + // Test prefix + if DB.Prefix("my_prefix_").NewScope(Order{}).TableName() != "my_prefix_order" { + t.Errorf("Order's table name should be my_prefix_order") + } + DB.SingularTable(false) + + // Test prefix + if DB.Prefix("my_prefix_").NewScope(Order{}).TableName() != "my_prefix_orders" { + t.Errorf("Order's table name should be my_prefix_orders") + } + } func TestNullValues(t *testing.T) { diff --git a/scope.go b/scope.go index 9a237998..d447efaf 100644 --- a/scope.go +++ b/scope.go @@ -301,17 +301,29 @@ type dbTabler interface { // TableName return table name func (scope *Scope) TableName() string { if scope.Search != nil && len(scope.Search.tableName) > 0 { - return scope.Search.tableName + return scope.Search.prefix + scope.Search.tableName } if tabler, ok := scope.Value.(tabler); ok { + if scope.Search != nil { + return scope.Search.prefix + tabler.TableName() + } + return tabler.TableName() } if tabler, ok := scope.Value.(dbTabler); ok { + if scope.Search != nil { + return scope.Search.prefix + tabler.TableName(scope.db) + } + return tabler.TableName(scope.db) } + if scope.Search != nil { + return scope.Search.prefix + scope.GetModelStruct().TableName(scope.db.Model(scope.Value)) + } + return scope.GetModelStruct().TableName(scope.db.Model(scope.Value)) } @@ -319,9 +331,9 @@ func (scope *Scope) TableName() string { func (scope *Scope) QuotedTableName() (name string) { if scope.Search != nil && len(scope.Search.tableName) > 0 { if strings.Index(scope.Search.tableName, " ") != -1 { - return scope.Search.tableName + return scope.Search.prefix + scope.Search.tableName } - return scope.Quote(scope.Search.tableName) + return scope.Quote(scope.Search.prefix + scope.Search.tableName) } return scope.Quote(scope.TableName()) diff --git a/search.go b/search.go index 23dac2c3..5f23954c 100644 --- a/search.go +++ b/search.go @@ -22,6 +22,7 @@ type search struct { limit interface{} group string tableName string + prefix string raw bool Unscoped bool ignoreOrderQuery bool @@ -141,6 +142,11 @@ func (s *search) Table(name string) *search { return s } +func (s *search) Prefix(name string) *search { + s.prefix = name + return s +} + func (s *search) getInterfaceAsSQL(value interface{}) (str string) { switch value.(type) { case string, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: