diff --git a/main.go b/main.go index 67e5f58e..7294de9a 100644 --- a/main.go +++ b/main.go @@ -518,6 +518,11 @@ func (s *DB) Table(name string) *DB { return clone } +// Table specify the table you would like to run db operations +func (s *DB) ParameterisedTable(query string, args ...interface{}) *DB { + return s.clone().search.ParameterisedTable(query, args...).db +} + // Debug start debug mode func (s *DB) Debug() *DB { return s.clone().LogMode(true) diff --git a/scope.go b/scope.go index 541fe522..9322cdb6 100644 --- a/scope.go +++ b/scope.go @@ -329,8 +329,12 @@ type dbTabler interface { // TableName return table name func (scope *Scope) TableName() string { - if scope.Search != nil && len(scope.Search.tableName) > 0 { - return scope.Search.tableName + if scope.Search != nil && scope.Search.tableName != nil { + if str, ok := scope.Search.tableName.(string); ok { + return str + } else if exp, ok := scope.Search.tableName.(*expr); ok { + return exp.expr + } } if tabler, ok := scope.Value.(tabler); ok { @@ -346,11 +350,18 @@ func (scope *Scope) TableName() string { // QuotedTableName return quoted table name func (scope *Scope) QuotedTableName() (name string) { - if scope.Search != nil && len(scope.Search.tableName) > 0 { - if strings.Contains(scope.Search.tableName, " ") { - return scope.Search.tableName + if scope.Search != nil && scope.Search.tableName != nil { + var tableName string + if str, ok := scope.Search.tableName.(string); ok { + tableName = str + } else if exp, ok := scope.Search.tableName.(*expr); ok { + tableName = exp.expr } - return scope.Quote(scope.Search.tableName) + + if strings.Contains(tableName, " ") { + return tableName + } + return scope.Quote(tableName) } return scope.Quote(scope.TableName()) @@ -859,11 +870,25 @@ func (scope *Scope) joinsSQL() string { return strings.Join(joinConditions, " ") + " " } +func (scope *Scope) tableSQL() string { + if str, ok := scope.Search.tableName.(string); ok { + return str + } else if expr, ok := scope.Search.tableName.(*expr); ok { + exp := expr.expr + for _, arg := range expr.args { + exp = strings.Replace(exp, "?", scope.AddToVars(arg), 1) + } + return exp + } else { + return scope.QuotedTableName() + } +} + func (scope *Scope) prepareQuerySQL() { if scope.Search.raw { scope.Raw(scope.CombinedConditionSql()) } else { - scope.Raw(fmt.Sprintf("SELECT %v FROM %v %v", scope.selectSQL(), scope.QuotedTableName(), scope.CombinedConditionSql())) + scope.Raw(fmt.Sprintf("SELECT %v FROM %v %v", scope.selectSQL(), scope.tableSQL(), scope.CombinedConditionSql())) } return } diff --git a/search.go b/search.go index 90138595..3318221c 100644 --- a/search.go +++ b/search.go @@ -20,7 +20,7 @@ type search struct { offset interface{} limit interface{} group string - tableName string + tableName interface{} raw bool Unscoped bool ignoreOrderQuery bool @@ -133,6 +133,19 @@ func (s *search) unscoped() *search { return s } +//func (s *search) Table(name string) *search { +// s.tableName = name +// return s +//} + +func (s *search) ParameterisedTable(query string, values ...interface{}) *search { + s.tableName = &expr{ + expr: query, + args: values, + } + return s +} + func (s *search) Table(name string) *search { s.tableName = name return s