Started with expression extension

Improved questionmark parameter placeholder replacementw

Added support for subqueries in Where and Having clauses

Queries can be transformed into subqueries by calling .Subquery() on a db object

See main_test.go:TestQueryBuilderSubselectInWhere

Fixed comment spacing

Refactoring, adding Having Subquery support, allowing db.T for tablenames

Removed quoting from tablename in db.T, use db.QT for that

Refactoring, adding Having Subquery support, allowing db.T for tablenames

Added changes

Started with expression extension

Refactoring, adding Having Subquery support, allowing db.T for tablenames

Added method to easily update fields of the Model struct

Added column comparison and Join support

Added subquery support for InnerJoin querybuilder

Fixed column comparison

Added support for column prefixes

Models can set their column prefix by implementing the method ColumnPrefix() string

Fixed multi-parameter subselects and introduced aliasing

Improved Related method

Improved Related method to search for foreign key struct fields with the suffix "ID" (additional to "Id")

Got QueryExpr support from upstream

Added support for subqueries in Where and Having clauses

Queries can be transformed into subqueries by calling .Subquery() on a db object

See main_test.go:TestQueryBuilderSubselectInWhere

Improved questionmark parameter placeholder replacementw

Refactoring, adding Having Subquery support, allowing db.T for tablenames

Removed quoting from tablename in db.T, use db.QT for that

Removed quoting from tablename in db.T, use db.QT for that

Added changes

Added method to easily update fields of the Model struct

Fixed column comparison

Added support for column prefixes

Models can set their column prefix by implementing the method ColumnPrefix() string

Fixed multi-parameter subselects and introduced aliasing

Improved Related method

Improved Related method to search for foreign key struct fields with the suffix "ID" (additional to "Id")

Added select extension for multiple columns

Added support for LEFT RIGHT OUTER joins

Fixed slice support for lexpr.In()

Publizised LExpr

Added DateFormatting for all dialects

Added SUM function for columns

Fixed FormatDate

Added count for column

Removed literal expressions LExpr

Rewrote LExpr methods to work with expr structs.

Added methods BAnd and BOr (bitwise & and | )

Added SetLogWriter method

Added NotIn query expression

Added Distinct query expression

Added DistinctColumn query expression

Same as Distinct but returns a string

Added method OnExp to jexpr

Improved query expression .Eq() method for nil pointers

Fixed rebase errors
This commit is contained in:
Gerhard Gruber 2017-07-27 17:14:00 +02:00 committed by daniel
parent 48a20a6e9f
commit c807fe3202
13 changed files with 636 additions and 201 deletions

View File

@ -50,6 +50,19 @@ type Dialect interface {
// CurrentDatabase return current database name
CurrentDatabase() string
// Formates the date and tries to be as similar as possible with all databases
// | | gorm | SQLITE | MYSQL | MSSQL | POSTGRES |
// | ----------- | ---- | ----------------- | ----------------- | ---------------- | ---------------- |
// | YEAR | %y | %Y (0000-9999) | %Y (0000-9999) | YYYY (0000-9999) | YYYY (0000-9999) |
// | MONTH | %m | %m (01-12) | %m (01-12) | MM (00-12) | MM (0000-9999) |
// | WEEK | %w | %W (00-53) | %u (00-53) | -- | WW (1-53) |
// | DAY | %d | %d (00-31) | %d (00-31) | dd (00-31) | DD (00-31) |
// | DAY OF WEEK | %D | %w (SUN 0-6 SAT) | %w (SUN 0-6 SAT) | -- | D (SUN 1-7 SAT) |
// | HOUR | %h | %H (00-24) | %H (00-23) | HH (00-31) | HH24 (00-23) |
// | MINUTE | %m | %M (00-59) | %i (00-59) | mm (00-59) | MI (00-59) |
// | SECOND | %s | %S (00-59) | %S (00-59) | ss (00-59) | SS (00-59) |
FormatDate(*expr, string) *expr
}
var dialectsMap = map[string]Dialect{}
@ -128,3 +141,29 @@ func currentDatabaseAndTable(dialect Dialect, tableName string) (string, string)
}
return dialect.CurrentDatabase(), tableName
}
func parseDateFormat(format string, mapping map[rune]string) string {
var parsedFormat string
isFormatter := false
for _, rune := range format {
if !isFormatter {
if rune == '%' {
isFormatter = true
} else {
parsedFormat += string(rune)
}
continue
}
isFormatter = false
if sign, ok := mapping[rune]; ok {
parsedFormat += sign
} else {
parsedFormat += "%" + string(rune)
}
}
return parsedFormat
}

View File

@ -174,3 +174,7 @@ func (DefaultForeignKeyNamer) BuildKeyName(kind, tableName string, fields ...str
func IsByteArrayOrSlice(value reflect.Value) bool {
return (value.Kind() == reflect.Array || value.Kind() == reflect.Slice) && value.Type().Elem() == reflect.TypeOf(uint8(0))
}
func (commonDialect) FormatDate(e *expr, format string) *expr {
return e
}

209
dialect_mssql.go Normal file
View File

@ -0,0 +1,209 @@
package gorm
import (
"fmt"
"reflect"
"strconv"
"strings"
"time"
)
func setIdentityInsert(scope *Scope) {
if scope.Dialect().GetName() == "mssql" {
for _, field := range scope.PrimaryFields() {
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok && !field.IsBlank {
scope.NewDB().Exec(fmt.Sprintf("SET IDENTITY_INSERT %v ON", scope.TableName()))
scope.InstanceSet("mssql:identity_insert_on", true)
}
}
}
}
func turnOffIdentityInsert(scope *Scope) {
if scope.Dialect().GetName() == "mssql" {
if _, ok := scope.InstanceGet("mssql:identity_insert_on"); ok {
scope.NewDB().Exec(fmt.Sprintf("SET IDENTITY_INSERT %v OFF", scope.TableName()))
}
}
}
func init() {
DefaultCallback.Create().After("gorm:begin_transaction").Register("mssql:set_identity_insert", setIdentityInsert)
DefaultCallback.Create().Before("gorm:commit_or_rollback_transaction").Register("mssql:turn_off_identity_insert", turnOffIdentityInsert)
RegisterDialect("mssql", &mssql{})
}
type mssql struct {
db SQLCommon
DefaultForeignKeyNamer
}
func (mssql) GetName() string {
return "mssql"
}
func (s *mssql) SetDB(db SQLCommon) {
s.db = db
}
func (mssql) BindVar(i int) string {
return "$$$" // ?
}
func (mssql) Quote(key string) string {
return fmt.Sprintf(`[%s]`, key)
}
func (s *mssql) DataTypeOf(field *StructField) string {
var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s)
if sqlType == "" {
switch dataValue.Kind() {
case reflect.Bool:
sqlType = "bit"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
if s.fieldCanAutoIncrement(field) {
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
sqlType = "int IDENTITY(1,1)"
} else {
sqlType = "int"
}
case reflect.Int64, reflect.Uint64:
if s.fieldCanAutoIncrement(field) {
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
sqlType = "bigint IDENTITY(1,1)"
} else {
sqlType = "bigint"
}
case reflect.Float32, reflect.Float64:
sqlType = "float"
case reflect.String:
if size > 0 && size < 8000 {
sqlType = fmt.Sprintf("nvarchar(%d)", size)
} else {
sqlType = "nvarchar(max)"
}
case reflect.Struct:
if _, ok := dataValue.Interface().(time.Time); ok {
sqlType = "datetimeoffset"
}
default:
if IsByteArrayOrSlice(dataValue) {
if size > 0 && size < 8000 {
sqlType = fmt.Sprintf("varbinary(%d)", size)
} else {
sqlType = "varbinary(max)"
}
}
}
}
if sqlType == "" {
panic(fmt.Sprintf("invalid sql type %s (%s) for mssql", dataValue.Type().Name(), dataValue.Kind().String()))
}
if strings.TrimSpace(additionalType) == "" {
return sqlType
}
return fmt.Sprintf("%v %v", sqlType, additionalType)
}
func (s mssql) fieldCanAutoIncrement(field *StructField) bool {
if value, ok := field.TagSettings["AUTO_INCREMENT"]; ok {
return value != "FALSE"
}
return field.IsPrimaryKey
}
func (s mssql) HasIndex(tableName string, indexName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName).Scan(&count)
return count > 0
}
func (s mssql) RemoveIndex(tableName string, indexName string) error {
_, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName)))
return err
}
func (s mssql) HasForeignKey(tableName string, foreignKeyName string) bool {
return false
}
func (s mssql) HasTable(tableName string) bool {
var count int
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, currentDatabase).Scan(&count)
return count > 0
}
func (s mssql) HasColumn(tableName string, columnName string) bool {
var count int
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
s.db.QueryRow("SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", currentDatabase, tableName, columnName).Scan(&count)
return count > 0
}
func (s mssql) ModifyColumn(tableName string, columnName string, typ string) error {
_, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v %v", tableName, columnName, typ))
return err
}
func (s mssql) CurrentDatabase() (name string) {
s.db.QueryRow("SELECT DB_NAME() AS [Current Database]").Scan(&name)
return
}
func (mssql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {
if offset != nil {
if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset >= 0 {
sql += fmt.Sprintf(" OFFSET %d ROWS", parsedOffset)
}
}
if limit != nil {
if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 {
if sql == "" {
// add default zero offset
sql += " OFFSET 0 ROWS"
}
sql += fmt.Sprintf(" FETCH NEXT %d ROWS ONLY", parsedLimit)
}
}
return
}
func (mssql) SelectFromDummyTable() string {
return ""
}
func (mssql) LastInsertIDReturningSuffix(tableName, columnName string) string {
return ""
}
func (mssql) DefaultValueStr() string {
return "DEFAULT VALUES"
}
func currentDatabaseAndTable(dialect Dialect, tableName string) (string, string) {
if strings.Contains(tableName, ".") {
splitStrings := strings.SplitN(tableName, ".", 2)
return splitStrings[0], splitStrings[1]
}
return dialect.CurrentDatabase(), tableName
}
func (mssql) FormatDate(e *expr, format string) *expr {
mapping := map[rune]string{
'y': "YYYY",
'm': "MM",
'd': "dd",
'h': "HH",
'M': "mm",
's': "ss",
}
parsedFormat := parseDateFormat(format, mapping)
e.expr = "(format(" + e.expr + ", '" + parsedFormat + "'))"
return e
}

View File

@ -189,3 +189,20 @@ func (s mysql) BuildKeyName(kind, tableName string, fields ...string) string {
func (mysql) DefaultValueStr() string {
return "VALUES()"
}
func (mysql) FormatDate(e *expr, format string) *expr {
mapping := map[rune]string{
'y': "%Y",
'm': "%m",
'w': "%u",
'd': "%d",
'D': "%w",
'h': "%H",
'M': "%i",
's': "%S",
}
parsedFormat := parseDateFormat(format, mapping)
e.expr = "(DATE_FORMAT(" + e.expr + ", '" + parsedFormat + "'))"
return e
}

View File

@ -141,3 +141,20 @@ func isJSON(value reflect.Value) bool {
_, ok := value.Interface().(json.RawMessage)
return ok
}
func (postgres) FormatDate(e *expr, format string) *expr {
mapping := map[rune]string{
'y': "YYYY",
'm': "MM",
'w': "WW",
'd': "DD",
'D': "D",
'h': "HH24",
'M': "MI",
's': "SS",
}
parsedFormat := parseDateFormat(format, mapping)
e.expr = "(to_char(" + e.expr + ", '" + parsedFormat + "'))"
return e
}

View File

@ -105,3 +105,20 @@ func (s sqlite3) CurrentDatabase() (name string) {
}
return
}
func (sqlite3) FormatDate(e *expr, format string) *expr {
mapping := map[rune]string{
'y': "%Y",
'm': "%m",
'w': "%W",
'd': "%d",
'D': "%w",
'h': "%H",
'M': "%M",
's': "%S",
}
parsedFormat := parseDateFormat(format, mapping)
e.expr = "(strftime('" + parsedFormat + "', " + e.expr + "))"
return e
}

View File

@ -1,196 +1,3 @@
package mssql
import (
"fmt"
"reflect"
"strconv"
"strings"
"time"
_ "github.com/denisenkom/go-mssqldb"
"github.com/jinzhu/gorm"
)
func setIdentityInsert(scope *gorm.Scope) {
if scope.Dialect().GetName() == "mssql" {
for _, field := range scope.PrimaryFields() {
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok && !field.IsBlank {
scope.NewDB().Exec(fmt.Sprintf("SET IDENTITY_INSERT %v ON", scope.TableName()))
scope.InstanceSet("mssql:identity_insert_on", true)
}
}
}
}
func turnOffIdentityInsert(scope *gorm.Scope) {
if scope.Dialect().GetName() == "mssql" {
if _, ok := scope.InstanceGet("mssql:identity_insert_on"); ok {
scope.NewDB().Exec(fmt.Sprintf("SET IDENTITY_INSERT %v OFF", scope.TableName()))
}
}
}
func init() {
gorm.DefaultCallback.Create().After("gorm:begin_transaction").Register("mssql:set_identity_insert", setIdentityInsert)
gorm.DefaultCallback.Create().Before("gorm:commit_or_rollback_transaction").Register("mssql:turn_off_identity_insert", turnOffIdentityInsert)
gorm.RegisterDialect("mssql", &mssql{})
}
type mssql struct {
db gorm.SQLCommon
gorm.DefaultForeignKeyNamer
}
func (mssql) GetName() string {
return "mssql"
}
func (s *mssql) SetDB(db gorm.SQLCommon) {
s.db = db
}
func (mssql) BindVar(i int) string {
return "$$$" // ?
}
func (mssql) Quote(key string) string {
return fmt.Sprintf(`[%s]`, key)
}
func (s *mssql) DataTypeOf(field *gorm.StructField) string {
var dataValue, sqlType, size, additionalType = gorm.ParseFieldStructForDialect(field, s)
if sqlType == "" {
switch dataValue.Kind() {
case reflect.Bool:
sqlType = "bit"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
if s.fieldCanAutoIncrement(field) {
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
sqlType = "int IDENTITY(1,1)"
} else {
sqlType = "int"
}
case reflect.Int64, reflect.Uint64:
if s.fieldCanAutoIncrement(field) {
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
sqlType = "bigint IDENTITY(1,1)"
} else {
sqlType = "bigint"
}
case reflect.Float32, reflect.Float64:
sqlType = "float"
case reflect.String:
if size > 0 && size < 8000 {
sqlType = fmt.Sprintf("nvarchar(%d)", size)
} else {
sqlType = "nvarchar(max)"
}
case reflect.Struct:
if _, ok := dataValue.Interface().(time.Time); ok {
sqlType = "datetimeoffset"
}
default:
if gorm.IsByteArrayOrSlice(dataValue) {
if size > 0 && size < 8000 {
sqlType = fmt.Sprintf("varbinary(%d)", size)
} else {
sqlType = "varbinary(max)"
}
}
}
}
if sqlType == "" {
panic(fmt.Sprintf("invalid sql type %s (%s) for mssql", dataValue.Type().Name(), dataValue.Kind().String()))
}
if strings.TrimSpace(additionalType) == "" {
return sqlType
}
return fmt.Sprintf("%v %v", sqlType, additionalType)
}
func (s mssql) fieldCanAutoIncrement(field *gorm.StructField) bool {
if value, ok := field.TagSettings["AUTO_INCREMENT"]; ok {
return value != "FALSE"
}
return field.IsPrimaryKey
}
func (s mssql) HasIndex(tableName string, indexName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName).Scan(&count)
return count > 0
}
func (s mssql) RemoveIndex(tableName string, indexName string) error {
_, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName)))
return err
}
func (s mssql) HasForeignKey(tableName string, foreignKeyName string) bool {
return false
}
func (s mssql) HasTable(tableName string) bool {
var count int
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, currentDatabase).Scan(&count)
return count > 0
}
func (s mssql) HasColumn(tableName string, columnName string) bool {
var count int
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
s.db.QueryRow("SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", currentDatabase, tableName, columnName).Scan(&count)
return count > 0
}
func (s mssql) ModifyColumn(tableName string, columnName string, typ string) error {
_, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v %v", tableName, columnName, typ))
return err
}
func (s mssql) CurrentDatabase() (name string) {
s.db.QueryRow("SELECT DB_NAME() AS [Current Database]").Scan(&name)
return
}
func (mssql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {
if offset != nil {
if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset >= 0 {
sql += fmt.Sprintf(" OFFSET %d ROWS", parsedOffset)
}
}
if limit != nil {
if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 {
if sql == "" {
// add default zero offset
sql += " OFFSET 0 ROWS"
}
sql += fmt.Sprintf(" FETCH NEXT %d ROWS ONLY", parsedLimit)
}
}
return
}
func (mssql) SelectFromDummyTable() string {
return ""
}
func (mssql) LastInsertIDReturningSuffix(tableName, columnName string) string {
return ""
}
func (mssql) DefaultValueStr() string {
return "DEFAULT VALUES"
}
func currentDatabaseAndTable(dialect gorm.Dialect, tableName string) (string, string) {
if strings.Contains(tableName, ".") {
splitStrings := strings.SplitN(tableName, ".", 2)
return splitStrings[0], splitStrings[1]
}
return dialect.CurrentDatabase(), tableName
}
import _ "github.com/denisenkom/go-mssqldb"

277
expression_ext.go Normal file
View File

@ -0,0 +1,277 @@
package gorm
import (
"reflect"
"strings"
)
type jexpr struct {
expr string
args []interface{}
}
func join(joinType string, db *DB, model interface{}, alias ...string) *jexpr {
var al string
if len(alias) > 0 {
al = alias[0]
}
if val, ok := model.(*expr); ok {
return &jexpr{expr: " " + joinType + " JOIN (" + val.expr + ") " + al, args: val.args}
}
return &jexpr{expr: " " + joinType + " JOIN " + db.T(model) + " " + al}
}
func (db *DB) InnerJoin(model interface{}, alias ...string) *jexpr {
return join("INNER", db, model, alias...)
}
func (db *DB) LeftJoin(model interface{}, alias ...string) *jexpr {
return join("LEFT", db, model, alias...)
}
func (db *DB) RightJoin(model interface{}, alias ...string) *jexpr {
return join("RIGHT", db, model, alias...)
}
func (db *DB) OuterJoin(model interface{}, alias ...string) *jexpr {
return join("OUTER", db, model, alias...)
}
func (je *jexpr) On(col1 *expr, col2 *expr) *expr {
return &expr{expr: je.expr + " ON " + col1.expr + " = " + col2.expr, args: je.args}
}
func (je *jexpr) OnExp(e2 *expr) *expr {
e := &expr{expr: je.expr + " ON " + e2.expr, args: je.args}
e.args = append(e.args, e2.args...)
return e
}
func (db *DB) L(model interface{}, name string) *expr {
scope := db.NewScope(model)
field, _ := scope.FieldByName(name)
return &expr{expr: scope.Quote(scope.TableName()) + "." + scope.Quote(field.DBName)}
}
func (db *DB) LA(model interface{}, alias string, name string) *expr {
scope := db.NewScope(model)
field, _ := scope.FieldByName(name)
return &expr{expr: scope.Quote(alias) + "." + scope.Quote(field.DBName)}
}
func (db *DB) C(model interface{}, names ...string) string {
columns := make([]string, 0)
scope := db.NewScope(model)
for _, name := range names {
field, _ := scope.FieldByName(name)
columns = append(columns, field.DBName)
}
return strings.Join(columns, ", ")
}
func (db *DB) CA(model interface{}, alias string, names ...string) string {
columns := make([]string, 0)
for _, name := range names {
columns = append(columns, db.LA(model, alias, name).expr)
}
return strings.Join(columns, ", ")
}
func (db *DB) CQ(model interface{}, names ...string) string {
columns := make([]string, 0)
for _, name := range names {
columns = append(columns, db.L(model, name).expr)
}
return strings.Join(columns, ", ")
}
func (db *DB) T(model interface{}) string {
scope := db.NewScope(model)
return scope.TableName()
}
func (db *DB) QT(model interface{}) string {
scope := db.NewScope(model)
return scope.QuotedTableName()
}
func (e *expr) operator(operator string, value interface{}) *expr {
if value == nil {
e.expr = "(" + e.expr + " " + operator + " )"
return e
}
if _, ok := value.(*expr); ok {
e.expr = "(" + e.expr + " " + operator + " (?))"
} else {
e.expr = "(" + e.expr + " " + operator + " ?)"
}
e.args = append(e.args, value)
return e
}
func (e *expr) Gt(value interface{}) *expr {
return e.operator(">", value)
}
func (e *expr) Ge(value interface{}) *expr {
return e.operator(">=", value)
}
func (e *expr) Lt(value interface{}) *expr {
return e.operator("<", value)
}
func (e *expr) Le(value interface{}) *expr {
return e.operator("<=", value)
}
func (e *expr) BAnd(value interface{}) *expr {
return e.operator("&", value)
}
func (e *expr) BOr(value interface{}) *expr {
return e.operator("|", value)
}
func (e *expr) Like(value interface{}) *expr {
return e.operator("LIKE", value)
}
func (e *expr) Eq(value interface{}) *expr {
if value == nil {
return e.operator("IS NULL", value)
} else if val := reflect.ValueOf(value); val.Kind() == reflect.Ptr && val.IsNil() {
return e.operator("IS NULL", nil)
}
return e.operator("=", value)
}
func (e *expr) Neq(value interface{}) *expr {
if value == nil {
return e.operator("IS NOT NULL", value)
}
return e.operator("!=", value)
}
func (e *expr) Sum() string {
return "SUM(" + e.expr + ")"
}
func (e *expr) Count() string {
return "COUNT(" + e.expr + ")"
}
func (e *expr) Distinct() *expr {
e.expr = "DISTINCT " + e.expr
return e
}
func (e *expr) DistinctColumn() string {
return "DISTINCT " + e.expr
}
func (e *expr) in(operator string, values ...interface{}) *expr {
// NOTE: Maybe there is a better way to do this? :)
if len(values) == 1 {
if s := reflect.ValueOf(values[0]); s.Kind() == reflect.Slice {
vals := make([]interface{}, s.Len())
qm := make([]string, s.Len())
for i := 0; i < s.Len(); i++ {
vals[i] = s.Index(i).Interface()
qm[i] = "?"
}
e.expr = "(" + e.expr + operator + " IN (" + strings.Join(qm, ",") + "))"
e.args = append(e.args, vals...)
return e
}
}
qm := make([]string, len(values))
for i := 0; i < len(values); i++ {
qm[i] = "?"
}
e.expr = "(" + e.expr + operator + " IN (" + strings.Join(qm, ",") + "))"
e.args = append(e.args, values...)
return e
}
func (e *expr) In(values ...interface{}) *expr {
return e.in("", values...)
}
func (e *expr) NotIn(values ...interface{}) *expr {
return e.in(" NOT", values...)
}
func (e *expr) OrderAsc() string {
return e.expr + " ASC "
}
func (e *expr) OrderDesc() string {
return e.expr + " DESC "
}
func (e *expr) Or(e2 *expr) *expr {
e.expr = "(" + e.expr + " OR " + e2.expr + ")"
e.args = append(e.args, e2.args...)
return e
}
func (e *expr) And(e2 *expr) *expr {
e.expr = "(" + e.expr + " AND " + e2.expr + ")"
e.args = append(e.args, e2.args...)
return e
}
func (db *DB) UpdateFields(fields ...string) *DB {
sets := make(map[string]interface{})
m := reflect.ValueOf(db.Value).Elem()
for _, field := range fields {
sets[db.C(db.Value, field)] = m.FieldByName(field).Interface()
}
return db.Update(sets)
}
func (db *DB) SelectFields(fields ...string) *DB {
selects := strings.Join(fields, ", ")
return db.Select(selects)
}
func (e *expr) Intersect(e2 *expr) *expr {
e.expr = "((" + e.expr + ") INTERSECT (" + e2.expr + "))"
e.args = append(e.args, e2.args...)
return e
}
func (e *expr) Alias(alias string) *expr {
e.expr = e.expr + " " + alias + " "
return e
}
func (db *DB) FormatDate(e *expr, format string) *expr {
return db.Dialect().FormatDate(e, format)
}
func (db *DB) FormatDateColumn(e *expr, format string) string {
return db.FormatDate(e, format).expr
}

13
main.go
View File

@ -133,6 +133,11 @@ func (s *DB) SetLogger(log logger) {
s.logger = log
}
// SetLogWriter sets the LogWriter the default logger should write to
func (s *DB) SetLogWriter(log LogWriter) {
s.logger = Logger{log}
}
// LogMode set log mode, `true` for detailed logs, `false` for no log, default, will only print error logs
func (s *DB) LogMode(enable bool) *DB {
if enable {
@ -169,11 +174,15 @@ func (s *DB) NewScope(value interface{}) *Scope {
}
// QueryExpr returns the query as expr object
func (s *DB) QueryExpr() *expr {
func (s *DB) QueryExpr(alias ...string) *expr {
scope := s.NewScope(s.Value)
scope.InstanceSet("skip_bindvar", true)
scope.prepareQuerySQL()
if len(alias) > 0 {
return Expr("( "+scope.SQL+" ) "+alias[0]+" ", scope.SQLVars...)
}
return Expr(scope.SQL, scope.SQLVars...)
}
@ -242,7 +251,7 @@ func (s *DB) Having(query interface{}, values ...interface{}) *DB {
// Joins specify Joins conditions
// db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user)
func (s *DB) Joins(query string, args ...interface{}) *DB {
func (s *DB) Joins(query interface{}, args ...interface{}) *DB {
return s.clone().search.Joins(query, args...).db
}

View File

@ -629,6 +629,7 @@ func TestQueryBuilderSubselectInWhere(t *testing.T) {
if len(users) != 2 {
t.Errorf("Two users should be found, instead found %d", len(users))
}
DB.Delete(&User{})
}
func TestQueryBuilderRawQueryWithSubquery(t *testing.T) {
@ -689,6 +690,14 @@ func TestQueryBuilderSubselectInHaving(t *testing.T) {
if len(users) != 1 {
t.Errorf("Two user group should be found, instead found %d", len(users))
}
DB.Select("*").Where("name LIKE ?", "query_expr_having_%").Where("age >= (?)", DB.
Select("AVG(age)").Where("name LIKE ?", "query_expr_having_%").Table("users").QueryExpr()).Find(&users)
if len(users) != 2 {
t.Errorf("Two users should be found, instead found %d", len(users))
}
DB.Delete(&User{})
}
func DialectHasTzSupport() bool {

View File

@ -158,6 +158,11 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
modelStruct.ModelType = reflectType
var columnPrefix string
if columnPrefixer, ok := reflect.New(modelStruct.ModelType).Interface().(columnPrefixer); ok {
columnPrefix = columnPrefixer.ColumnPrefix()
}
// Get all fields
for i := 0; i < reflectType.NumField(); i++ {
if fieldStruct := reflectType.Field(i); ast.IsExported(fieldStruct.Name) {
@ -587,7 +592,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
if value, ok := field.TagSettings["COLUMN"]; ok {
field.DBName = value
} else {
field.DBName = ToDBName(fieldStruct.Name)
field.DBName = columnPrefix + ToDBName(fieldStruct.Name)
}
modelStruct.StructFields = append(modelStruct.StructFields, field)

View File

@ -142,6 +142,7 @@ func (scope *Scope) FieldByName(name string) (field *Field, ok bool) {
if field.Name == name || field.DBName == name {
return field, true
}
if field.DBName == dbName {
mostMatchedField = field
}
@ -274,6 +275,7 @@ func (scope *Scope) AddToVars(value interface{}) string {
if skipBindVar {
return "?"
}
return scope.Dialect().BindVar(len(scope.SQLVars))
}
@ -306,6 +308,10 @@ type tabler interface {
TableName() string
}
type columnPrefixer interface {
ColumnPrefix() string
}
type dbTabler interface {
TableName(*DB) string
}
@ -339,7 +345,7 @@ func (scope *Scope) QuotedTableName() (name string) {
return scope.Quote(scope.TableName())
}
// CombinedConditionSql return combined condition sql
// CombinedConditionSQL return combined condition SQL
func (scope *Scope) CombinedConditionSql() string {
joinSQL := scope.joinsSQL()
whereSQL := scope.whereSQL()
@ -473,6 +479,18 @@ func (scope *Scope) quoteIfPossible(str string) string {
return str
}
func (scope *Scope) replaceParameterPlaceholderLiteral(sql string, parameter interface{}, addToVars bool) string {
if val, ok := parameter.(string); ok && !addToVars {
return strings.Replace(sql, "?", val, 1)
}
return strings.Replace(sql, "?", scope.AddToVars(parameter), 1)
}
func (scope *Scope) replaceParameterPlaceholder(sql string, parameter interface{}) string {
return scope.replaceParameterPlaceholderLiteral(sql, parameter, true)
}
func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field) {
var (
ignored interface{}
@ -578,6 +596,9 @@ func (scope *Scope) buildCondition(clause map[string]interface{}, include bool)
}
}
return strings.Join(sqls, " AND ")
case *expr:
clause["args"] = []interface{}{value}
str = "?"
case interface{}:
var sqls []string
newScope := scope.New(value)
@ -1048,7 +1069,7 @@ func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope {
toScope := scope.db.NewScope(value)
tx := scope.db.Set("gorm:association:source", scope.Value)
for _, foreignKey := range append(foreignKeys, toScope.typeName()+"Id", scope.typeName()+"Id") {
for _, foreignKey := range append(foreignKeys, toScope.typeName()+"ID", scope.typeName()+"ID", toScope.typeName()+"Id", scope.typeName()+"Id") {
fromField, _ := scope.FieldByName(foreignKey)
toField, _ := toScope.FieldByName(foreignKey)

View File

@ -106,8 +106,12 @@ func (s *search) Having(query interface{}, values ...interface{}) *search {
return s
}
func (s *search) Joins(query string, values ...interface{}) *search {
s.joinConditions = append(s.joinConditions, map[string]interface{}{"query": query, "args": values})
func (s *search) Joins(query interface{}, values ...interface{}) *search {
if val, ok := query.(*expr); ok {
s.joinConditions = append(s.joinConditions, map[string]interface{}{"query": val.expr, "args": val.args})
} else {
s.joinConditions = append(s.joinConditions, map[string]interface{}{"query": query, "args": values})
}
return s
}