Update sqlbuilder

This commit is contained in:
Jinzhu 2018-03-01 01:12:29 +08:00
parent 8b567b49d0
commit 23457d28ce
6 changed files with 174 additions and 59 deletions

2
api.go
View File

@ -10,7 +10,7 @@ func (s *DB) Where(query interface{}, args ...interface{}) *DB {
// Not add NOT condition
func (s *DB) Not(query interface{}, args ...interface{}) *DB {
tx := s.init()
tx.Statement.AddConditions(Not(tx.Statement.BuildCondition(query, args...)))
tx.Statement.AddConditions(Not([]ConditionInterface{tx.Statement.BuildCondition(query, args...)}))
return tx
}

View File

@ -6,4 +6,6 @@ type Dialect interface {
Query(*DB) error
Update(*DB) error
Delete(*DB) error
Quote(string) string
}

View File

@ -1,63 +1,112 @@
package sqlbuilder
import (
"bytes"
"fmt"
"github.com/jinzhu/gorm"
"github.com/jinzhu/gorm/model"
"github.com/jinzhu/gorm/schema"
"github.com/jinzhu/inflection"
)
// GetTable get table name for current db operation
func GetTable(tx *gorm.DB) chan string {
tableChan := make(chan string)
func buildCondition(tx *gorm.DB, c gorm.ConditionInterface, s *bytes.Buffer) []interface{} {
args := []interface{}{}
switch cond := c.(type) {
case gorm.And:
s.WriteString("(")
for i, v := range cond {
if i > 0 {
s.WriteString(" AND ")
}
args = append(args, buildCondition(tx, v, s)...)
}
s.WriteString(")")
case gorm.Or:
s.WriteString("(")
for i, v := range cond {
if i > 0 {
s.WriteString(" OR ")
}
args = append(args, buildCondition(tx, v, s)...)
}
s.WriteString(")")
case gorm.Not:
s.WriteString("NOT (")
for i, v := range cond {
if i > 0 {
s.WriteString(" AND ")
}
args = append(args, buildCondition(tx, v, s)...)
}
s.WriteString(")")
case gorm.Raw:
s.WriteString(cond.SQL)
args = append(args, cond.Args...)
case gorm.Eq:
if cond.Value == nil {
s.WriteString(tx.Dialect().Quote(cond.Column))
s.WriteString(" IS NULL")
} else {
s.WriteString(tx.Dialect().Quote(cond.Column))
s.WriteString(" = ?")
args = append(args, cond.Value)
}
case gorm.Neq:
if cond.Value == nil {
s.WriteString(tx.Dialect().Quote(cond.Column))
s.WriteString(" IS NOT NULL")
} else {
s.WriteString(tx.Dialect().Quote(cond.Column))
s.WriteString(" <> ?")
args = append(args, cond.Value)
}
case gorm.Gt:
s.WriteString(tx.Dialect().Quote(cond.Column))
s.WriteString(" > ?")
args = append(args, cond.Value)
case gorm.Gte:
s.WriteString(tx.Dialect().Quote(cond.Column))
s.WriteString(" >= ?")
args = append(args, cond.Value)
case gorm.Lt:
s.WriteString(tx.Dialect().Quote(cond.Column))
s.WriteString(" < ?")
args = append(args, cond.Value)
case gorm.Lte:
s.WriteString(tx.Dialect().Quote(cond.Column))
s.WriteString(" <= ?")
args = append(args, cond.Value)
default:
if sqlCond, ok := cond.(ConditionInterface); ok {
sql, as := sqlCond.ToSQL(tx)
s.WriteString(sql)
args = append(args, as)
} else {
tx.AddError(fmt.Errorf("unsupported condition: %#v", cond))
}
}
return args
}
// ConditionInterface condition interface
type ConditionInterface interface {
ToSQL(*gorm.DB) (string, []interface{})
}
// BuildConditions build conditions
func BuildConditions(tx *gorm.DB) chan string {
queryChan := make(chan string)
go func() {
var tableName string
if name, ok := tx.Statement.Table.(string); ok {
tableName = name
} else {
for _, v := range []interface{}{tx.Statement.Table, tx.Statement.Dest} {
if v != nil {
if t, ok := v.(tabler); ok {
tableName = t.TableName()
} else if t, ok := v.(dbTabler); ok {
tableName = t.TableName(tx)
} else if s := schema.Parse(v); s != nil {
if s.TableName != "" {
tableName = s.TableName
} else {
tableName = schema.ToDBName(s.ModelType.Name())
if !tx.Config.SingularTable {
tableName = inflection.Plural(tableName)
}
}
}
s := bytes.NewBufferString("")
args := []interface{}{}
if tableName != "" {
break
}
}
for i, c := range tx.Statement.Conditions {
if i > 0 {
s.WriteString(" AND ")
}
}
if tableName != "" {
if model.DefaultTableNameHandler != nil {
tableChan <- model.DefaultTableNameHandler(tx, tableName)
} else {
tableChan <- tableName
}
} else {
tx.AddError(ErrInvalidTable)
args = append(args, buildCondition(tx, c, s)...)
}
}()
return tableChan
}
type tabler interface {
TableName() string
}
type dbTabler interface {
TableName(*gorm.DB) string
return queryChan
}

View File

@ -0,0 +1,63 @@
package sqlbuilder
import (
"github.com/jinzhu/gorm"
"github.com/jinzhu/gorm/model"
"github.com/jinzhu/gorm/schema"
"github.com/jinzhu/inflection"
)
// GetTable get table name for current db operation
func GetTable(tx *gorm.DB) chan string {
tableChan := make(chan string)
go func() {
var tableName string
if name, ok := tx.Statement.Table.(string); ok {
tableName = name
} else {
for _, v := range []interface{}{tx.Statement.Table, tx.Statement.Dest} {
if v != nil {
if t, ok := v.(tabler); ok {
tableName = t.TableName()
} else if t, ok := v.(dbTabler); ok {
tableName = t.TableName(tx)
} else if s := schema.Parse(v); s != nil {
if s.TableName != "" {
tableName = s.TableName
} else {
tableName = schema.ToDBName(s.ModelType.Name())
if !tx.Config.SingularTable {
tableName = inflection.Plural(tableName)
}
}
}
if tableName != "" {
break
}
}
}
}
if tableName != "" {
if model.DefaultTableNameHandler != nil {
tableChan <- model.DefaultTableNameHandler(tx, tableName)
} else {
tableChan <- tableName
}
} else {
tx.AddError(ErrInvalidTable)
}
}()
return tableChan
}
type tabler interface {
TableName() string
}
type dbTabler interface {
TableName(*gorm.DB) string
}

View File

@ -6,7 +6,8 @@ import (
"fmt"
"github.com/jinzhu/gorm"
"github.com/jinzhu/gorm/dialects/common/destination"
"github.com/jinzhu/gorm/dialects/common/sqlbuilder"
"github.com/jinzhu/gorm/model"
)
// Dialect Sqlite3 Dialect for GORM
@ -23,9 +24,9 @@ func (dialect Dialect) Quote(name string) string {
func (dialect *Dialect) Insert(tx *gorm.DB) (err error) {
var (
args []interface{}
assignmentsChan = destination.GetAssignments(tx)
tableNameChan = destination.GetTable(tx)
primaryFields []*destination.Field
assignmentsChan = sqlbuilder.GetAssignmentFields(tx)
tableNameChan = sqlbuilder.GetTable(tx)
primaryFields []*model.Field
)
s := bytes.NewBufferString("INSERT INTO ")
@ -41,7 +42,7 @@ func (dialect *Dialect) Insert(tx *gorm.DB) (err error) {
valueBuffer := bytes.NewBufferString("VALUES ")
for idx, fields := range assignments {
var primaryField *destination.Field
var primaryField *model.Field
if idx != 0 {
valueBuffer.WriteString(",")
}

View File

@ -105,8 +105,8 @@ func (stmt *Statement) AddConditions(conds ...ConditionInterface) {
// Raw raw sql
type Raw struct {
Value string
Args []interface{} // TODO NamedArg
SQL string
Args []interface{} // TODO NamedArg
}
// Eq equal to
@ -153,7 +153,7 @@ type Lte struct {
type And []ConditionInterface
// Not TRUE if condition is false
type Not ConditionInterface
type Not []ConditionInterface
// Or TRUE if any of the conditions is TRUE
type Or []ConditionInterface