diff --git a/api.go b/api.go index 544cca5e..30e0b762 100644 --- a/api.go +++ b/api.go @@ -10,7 +10,7 @@ func (s *DB) Where(query interface{}, args ...interface{}) *DB { // Not add NOT condition func (s *DB) Not(query interface{}, args ...interface{}) *DB { tx := s.init() - tx.Statement.AddConditions(Not(tx.Statement.BuildCondition(query, args...))) + tx.Statement.AddConditions(Not([]ConditionInterface{tx.Statement.BuildCondition(query, args...)})) return tx } diff --git a/dialect.go b/dialect.go index 19b3fcea..6fb69498 100644 --- a/dialect.go +++ b/dialect.go @@ -6,4 +6,6 @@ type Dialect interface { Query(*DB) error Update(*DB) error Delete(*DB) error + + Quote(string) string } diff --git a/dialects/common/sqlbuilder/sqlbuilder.go b/dialects/common/sqlbuilder/sqlbuilder.go index 5debeaef..0f3fe7ef 100644 --- a/dialects/common/sqlbuilder/sqlbuilder.go +++ b/dialects/common/sqlbuilder/sqlbuilder.go @@ -1,63 +1,112 @@ package sqlbuilder import ( + "bytes" + "fmt" + "github.com/jinzhu/gorm" - "github.com/jinzhu/gorm/model" - "github.com/jinzhu/gorm/schema" - "github.com/jinzhu/inflection" ) -// GetTable get table name for current db operation -func GetTable(tx *gorm.DB) chan string { - tableChan := make(chan string) +func buildCondition(tx *gorm.DB, c gorm.ConditionInterface, s *bytes.Buffer) []interface{} { + args := []interface{}{} + + switch cond := c.(type) { + case gorm.And: + s.WriteString("(") + for i, v := range cond { + if i > 0 { + s.WriteString(" AND ") + } + args = append(args, buildCondition(tx, v, s)...) + } + s.WriteString(")") + case gorm.Or: + s.WriteString("(") + for i, v := range cond { + if i > 0 { + s.WriteString(" OR ") + } + args = append(args, buildCondition(tx, v, s)...) + } + s.WriteString(")") + case gorm.Not: + s.WriteString("NOT (") + for i, v := range cond { + if i > 0 { + s.WriteString(" AND ") + } + args = append(args, buildCondition(tx, v, s)...) + } + s.WriteString(")") + case gorm.Raw: + s.WriteString(cond.SQL) + args = append(args, cond.Args...) + case gorm.Eq: + if cond.Value == nil { + s.WriteString(tx.Dialect().Quote(cond.Column)) + s.WriteString(" IS NULL") + } else { + s.WriteString(tx.Dialect().Quote(cond.Column)) + s.WriteString(" = ?") + args = append(args, cond.Value) + } + case gorm.Neq: + if cond.Value == nil { + s.WriteString(tx.Dialect().Quote(cond.Column)) + s.WriteString(" IS NOT NULL") + } else { + s.WriteString(tx.Dialect().Quote(cond.Column)) + s.WriteString(" <> ?") + args = append(args, cond.Value) + } + case gorm.Gt: + s.WriteString(tx.Dialect().Quote(cond.Column)) + s.WriteString(" > ?") + args = append(args, cond.Value) + case gorm.Gte: + s.WriteString(tx.Dialect().Quote(cond.Column)) + s.WriteString(" >= ?") + args = append(args, cond.Value) + case gorm.Lt: + s.WriteString(tx.Dialect().Quote(cond.Column)) + s.WriteString(" < ?") + args = append(args, cond.Value) + case gorm.Lte: + s.WriteString(tx.Dialect().Quote(cond.Column)) + s.WriteString(" <= ?") + args = append(args, cond.Value) + default: + if sqlCond, ok := cond.(ConditionInterface); ok { + sql, as := sqlCond.ToSQL(tx) + s.WriteString(sql) + args = append(args, as) + } else { + tx.AddError(fmt.Errorf("unsupported condition: %#v", cond)) + } + } + + return args +} + +// ConditionInterface condition interface +type ConditionInterface interface { + ToSQL(*gorm.DB) (string, []interface{}) +} + +// BuildConditions build conditions +func BuildConditions(tx *gorm.DB) chan string { + queryChan := make(chan string) go func() { - var tableName string - if name, ok := tx.Statement.Table.(string); ok { - tableName = name - } else { - for _, v := range []interface{}{tx.Statement.Table, tx.Statement.Dest} { - if v != nil { - if t, ok := v.(tabler); ok { - tableName = t.TableName() - } else if t, ok := v.(dbTabler); ok { - tableName = t.TableName(tx) - } else if s := schema.Parse(v); s != nil { - if s.TableName != "" { - tableName = s.TableName - } else { - tableName = schema.ToDBName(s.ModelType.Name()) - if !tx.Config.SingularTable { - tableName = inflection.Plural(tableName) - } - } - } + s := bytes.NewBufferString("") + args := []interface{}{} - if tableName != "" { - break - } - } + for i, c := range tx.Statement.Conditions { + if i > 0 { + s.WriteString(" AND ") } - } - - if tableName != "" { - if model.DefaultTableNameHandler != nil { - tableChan <- model.DefaultTableNameHandler(tx, tableName) - } else { - tableChan <- tableName - } - } else { - tx.AddError(ErrInvalidTable) + args = append(args, buildCondition(tx, c, s)...) } }() - - return tableChan -} - -type tabler interface { - TableName() string -} - -type dbTabler interface { - TableName(*gorm.DB) string + return queryChan } diff --git a/dialects/common/sqlbuilder/table.go b/dialects/common/sqlbuilder/table.go new file mode 100644 index 00000000..5debeaef --- /dev/null +++ b/dialects/common/sqlbuilder/table.go @@ -0,0 +1,63 @@ +package sqlbuilder + +import ( + "github.com/jinzhu/gorm" + "github.com/jinzhu/gorm/model" + "github.com/jinzhu/gorm/schema" + "github.com/jinzhu/inflection" +) + +// GetTable get table name for current db operation +func GetTable(tx *gorm.DB) chan string { + tableChan := make(chan string) + + go func() { + var tableName string + if name, ok := tx.Statement.Table.(string); ok { + tableName = name + } else { + for _, v := range []interface{}{tx.Statement.Table, tx.Statement.Dest} { + if v != nil { + if t, ok := v.(tabler); ok { + tableName = t.TableName() + } else if t, ok := v.(dbTabler); ok { + tableName = t.TableName(tx) + } else if s := schema.Parse(v); s != nil { + if s.TableName != "" { + tableName = s.TableName + } else { + tableName = schema.ToDBName(s.ModelType.Name()) + if !tx.Config.SingularTable { + tableName = inflection.Plural(tableName) + } + } + } + + if tableName != "" { + break + } + } + } + } + + if tableName != "" { + if model.DefaultTableNameHandler != nil { + tableChan <- model.DefaultTableNameHandler(tx, tableName) + } else { + tableChan <- tableName + } + } else { + tx.AddError(ErrInvalidTable) + } + }() + + return tableChan +} + +type tabler interface { + TableName() string +} + +type dbTabler interface { + TableName(*gorm.DB) string +} diff --git a/dialects/sqlite/sqlite.go b/dialects/sqlite/sqlite.go index 7c5e9e7a..dd3d2149 100644 --- a/dialects/sqlite/sqlite.go +++ b/dialects/sqlite/sqlite.go @@ -6,7 +6,8 @@ import ( "fmt" "github.com/jinzhu/gorm" - "github.com/jinzhu/gorm/dialects/common/destination" + "github.com/jinzhu/gorm/dialects/common/sqlbuilder" + "github.com/jinzhu/gorm/model" ) // Dialect Sqlite3 Dialect for GORM @@ -23,9 +24,9 @@ func (dialect Dialect) Quote(name string) string { func (dialect *Dialect) Insert(tx *gorm.DB) (err error) { var ( args []interface{} - assignmentsChan = destination.GetAssignments(tx) - tableNameChan = destination.GetTable(tx) - primaryFields []*destination.Field + assignmentsChan = sqlbuilder.GetAssignmentFields(tx) + tableNameChan = sqlbuilder.GetTable(tx) + primaryFields []*model.Field ) s := bytes.NewBufferString("INSERT INTO ") @@ -41,7 +42,7 @@ func (dialect *Dialect) Insert(tx *gorm.DB) (err error) { valueBuffer := bytes.NewBufferString("VALUES ") for idx, fields := range assignments { - var primaryField *destination.Field + var primaryField *model.Field if idx != 0 { valueBuffer.WriteString(",") } diff --git a/statement.go b/statement.go index cbe7702f..99b48e9e 100644 --- a/statement.go +++ b/statement.go @@ -105,8 +105,8 @@ func (stmt *Statement) AddConditions(conds ...ConditionInterface) { // Raw raw sql type Raw struct { - Value string - Args []interface{} // TODO NamedArg + SQL string + Args []interface{} // TODO NamedArg } // Eq equal to @@ -153,7 +153,7 @@ type Lte struct { type And []ConditionInterface // Not TRUE if condition is false -type Not ConditionInterface +type Not []ConditionInterface // Or TRUE if any of the conditions is TRUE type Or []ConditionInterface